diff --git a/.devcontainer/README.md b/.devcontainer/README.md index 2b18630a21..359e2e5aef 100644 --- a/.devcontainer/README.md +++ b/.devcontainer/README.md @@ -1,23 +1,26 @@ # Development with devcontainer + This project includes a devcontainer configuration that allows you to open the project in a container with a fully configured development environment. Both frontend and backend environments are initialized when the container is started. + ## GitHub Codespaces + [![Open in GitHub Codespaces](https://github.com/codespaces/badge.svg)](https://codespaces.new/langgenius/dify) you can simply click the button above to open this project in GitHub Codespaces. For more info, check out the [GitHub documentation](https://docs.github.com/en/free-pro-team@latest/github/developing-online-with-codespaces/creating-a-codespace#creating-a-codespace). - ## VS Code Dev Containers + [![Open in Dev Containers](https://img.shields.io/static/v1?label=Dev%20Containers&message=Open&color=blue&logo=visualstudiocode)](https://vscode.dev/redirect?url=vscode://ms-vscode-remote.remote-containers/cloneInVolume?url=https://github.com/langgenius/dify) if you have VS Code installed, you can click the button above to open this project in VS Code Dev Containers. You can learn more in the [Dev Containers documentation](https://code.visualstudio.com/docs/devcontainers/containers). - ## Pros of Devcontainer + Unified Development Environment: By using devcontainers, you can ensure that all developers are developing in the same environment, reducing the occurrence of "it works on my machine" type of issues. Quick Start: New developers can set up their development environment in a few simple steps, without spending a lot of time on environment configuration. @@ -25,11 +28,13 @@ Quick Start: New developers can set up their development environment in a few si Isolation: Devcontainers isolate your project from your host operating system, reducing the chance of OS updates or other application installations impacting the development environment. ## Cons of Devcontainer + Learning Curve: For developers unfamiliar with Docker and VS Code, using devcontainers may be somewhat complex. Performance Impact: While usually minimal, programs running inside a devcontainer may be slightly slower than those running directly on the host. ## Troubleshooting + if you see such error message when you open this project in codespaces: ![Alt text](troubleshooting.png) diff --git a/.devcontainer/post_create_command.sh b/.devcontainer/post_create_command.sh index 022f71bfb4..39a653953e 100755 --- a/.devcontainer/post_create_command.sh +++ b/.devcontainer/post_create_command.sh @@ -1,11 +1,11 @@ #!/bin/bash -npm add -g pnpm@10.13.1 +npm add -g pnpm@10.15.0 cd web && pnpm install pipx install uv echo 'alias start-api="cd /workspaces/dify/api && uv run python -m flask run --host 0.0.0.0 --port=5001 --debug"' >> ~/.bashrc -echo 'alias start-worker="cd /workspaces/dify/api && uv run python -m celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion"' >> ~/.bashrc +echo 'alias start-worker="cd /workspaces/dify/api && uv run python -m celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage"' >> ~/.bashrc echo 'alias start-web="cd /workspaces/dify/web && pnpm dev"' >> ~/.bashrc echo 'alias start-web-prod="cd /workspaces/dify/web && pnpm build && pnpm start"' >> ~/.bashrc echo 'alias start-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env up -d"' >> ~/.bashrc diff --git a/.github/CODE_OF_CONDUCT.md b/.github/CODE_OF_CONDUCT.md index 47e2453f41..a59630d112 100644 --- a/.github/CODE_OF_CONDUCT.md +++ b/.github/CODE_OF_CONDUCT.md @@ -17,27 +17,25 @@ diverse, inclusive, and healthy community. Examples of behavior that contributes to a positive environment for our community include: -* Demonstrating empathy and kindness toward other people -* Being respectful of differing opinions, viewpoints, and experiences -* Giving and gracefully accepting constructive feedback -* Accepting responsibility and apologizing to those affected by our mistakes, +- Demonstrating empathy and kindness toward other people +- Being respectful of differing opinions, viewpoints, and experiences +- Giving and gracefully accepting constructive feedback +- Accepting responsibility and apologizing to those affected by our mistakes, and learning from the experience -* Focusing on what is best not just for us as individuals, but for the +- Focusing on what is best not just for us as individuals, but for the overall community Examples of unacceptable behavior include: -* The use of sexualized language or imagery, and sexual attention or +- The use of sexualized language or imagery, and sexual attention or advances of any kind -* Trolling, insulting or derogatory comments, and personal or political attacks -* Public or private harassment -* Publishing others' private information, such as a physical or email +- Trolling, insulting or derogatory comments, and personal or political attacks +- Public or private harassment +- Publishing others' private information, such as a physical or email address, without their explicit permission -* Other conduct which could reasonably be considered inappropriate in a +- Other conduct which could reasonably be considered inappropriate in a professional setting ## Language Policy To facilitate clear and effective communication, all discussions, comments, documentation, and pull requests in this project should be conducted in English. This ensures that all contributors can participate and collaborate effectively. - - diff --git a/.github/ISSUE_TEMPLATE/refactor.yml b/.github/ISSUE_TEMPLATE/refactor.yml new file mode 100644 index 0000000000..cf74dcc546 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/refactor.yml @@ -0,0 +1,44 @@ +name: "✨ Refactor" +description: Refactor existing code for improved readability and maintainability. +title: "[Chore/Refactor] " +labels: + - refactor +body: + - type: checkboxes + attributes: + label: Self Checks + description: "To make sure we get to you in time, please check the following :)" + options: + - label: I have read the [Contributing Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) and [Language Policy](https://github.com/langgenius/dify/issues/1542). + required: true + - label: This is only for refactoring, if you would like to ask a question, please head to [Discussions](https://github.com/langgenius/dify/discussions/categories/general). + required: true + - label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones. + required: true + - label: I confirm that I am using English to submit this report, otherwise it will be closed. + required: true + - label: 【中文用户 & Non English User】请使用英语提交,否则会被关闭 :) + required: true + - label: "Please do not modify this template :) and fill in all the required fields." + required: true + - type: textarea + id: description + attributes: + label: Description + placeholder: "Describe the refactor you are proposing." + validations: + required: true + - type: textarea + id: motivation + attributes: + label: Motivation + placeholder: "Explain why this refactor is necessary." + validations: + required: false + - type: textarea + id: additional-context + attributes: + label: Additional Context + placeholder: "Add any other context or screenshots about the request here." + validations: + required: false diff --git a/.github/actions/setup-uv/action.yml b/.github/actions/setup-uv/action.yml index 0499b44dba..6990f6becf 100644 --- a/.github/actions/setup-uv/action.yml +++ b/.github/actions/setup-uv/action.yml @@ -8,7 +8,7 @@ inputs: uv-version: description: UV version to set up required: true - default: '~=0.7.11' + default: '0.8.9' uv-lockfile: description: Path to the UV lockfile to restore cache from required: true @@ -26,7 +26,7 @@ runs: python-version: ${{ inputs.python-version }} - name: Install uv - uses: astral-sh/setup-uv@v5 + uses: astral-sh/setup-uv@v6 with: version: ${{ inputs.uv-version }} python-version: ${{ inputs.python-version }} diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index f4a5f754e0..aa5a50918a 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -1,8 +1,8 @@ > [!IMPORTANT] > > 1. Make sure you have read our [contribution guidelines](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) -> 2. Ensure there is an associated issue and you have been assigned to it -> 3. Use the correct syntax to link this PR: `Fixes #`. +> 1. Ensure there is an associated issue and you have been assigned to it +> 1. Use the correct syntax to link this PR: `Fixes #`. ## Summary @@ -12,7 +12,7 @@ | Before | After | |--------|-------| -| ... | ... | +| ... | ... | ## Checklist diff --git a/.github/workflows/api-tests.yml b/.github/workflows/api-tests.yml index a5a5071fae..63d681e7ed 100644 --- a/.github/workflows/api-tests.yml +++ b/.github/workflows/api-tests.yml @@ -47,7 +47,16 @@ jobs: - name: Run Unit tests run: | uv run --project api bash dev/pytest/pytest_unit_tests.sh - + - name: Run ty check + run: | + cd api + uv add --dev ty + uv run ty check || true + - name: Run pyrefly check + run: | + cd api + uv add --dev pyrefly + uv run pyrefly check || true - name: Coverage Summary run: | set -x @@ -99,3 +108,6 @@ jobs: - name: Run Tool run: uv run --project api bash dev/pytest/pytest_tools.sh + + - name: Run TestContainers + run: uv run --project api bash dev/pytest/pytest_testcontainers.sh diff --git a/.github/workflows/autofix.yml b/.github/workflows/autofix.yml index 5e290c5d02..dada6229db 100644 --- a/.github/workflows/autofix.yml +++ b/.github/workflows/autofix.yml @@ -9,6 +9,7 @@ permissions: jobs: autofix: + if: github.repository == 'langgenius/dify' runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 @@ -22,6 +23,10 @@ jobs: uv run ruff check --fix-only . # Format code uv run ruff format . - + - name: ast-grep + run: | + uvx --from ast-grep-cli sg --pattern 'db.session.query($WHATEVER).filter($HERE)' --rewrite 'db.session.query($WHATEVER).where($HERE)' -l py --update-all + - name: mdformat + run: | + uvx mdformat . - uses: autofix-ci/action@635ffb0c9798bd160680f18fd73371e355b85f27 - diff --git a/.github/workflows/build-push.yml b/.github/workflows/build-push.yml index b933560a5e..17af047267 100644 --- a/.github/workflows/build-push.yml +++ b/.github/workflows/build-push.yml @@ -7,6 +7,7 @@ on: - "deploy/dev" - "deploy/enterprise" - "build/**" + - "release/e-*" tags: - "*" diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml index a283f8d5ca..9aad9558b0 100644 --- a/.github/workflows/style.yml +++ b/.github/workflows/style.yml @@ -49,8 +49,8 @@ jobs: if: steps.changed-files.outputs.any_changed == 'true' run: | uv run --directory api ruff --version - uv run --directory api ruff check --diff ./ - uv run --directory api ruff format --check --diff ./ + uv run --directory api ruff check ./ + uv run --directory api ruff format --check ./ - name: Dotenv check if: steps.changed-files.outputs.any_changed == 'true' @@ -82,7 +82,7 @@ jobs: - name: Install pnpm uses: pnpm/action-setup@v4 with: - version: 10 + package_json_file: web/package.json run_install: false - name: Setup NodeJS @@ -95,10 +95,12 @@ jobs: - name: Web dependencies if: steps.changed-files.outputs.any_changed == 'true' + working-directory: ./web run: pnpm install --frozen-lockfile - name: Web style check if: steps.changed-files.outputs.any_changed == 'true' + working-directory: ./web run: pnpm run lint docker-compose-template: diff --git a/.github/workflows/translate-i18n-base-on-english.yml b/.github/workflows/translate-i18n-base-on-english.yml index c79d58563f..c004836808 100644 --- a/.github/workflows/translate-i18n-base-on-english.yml +++ b/.github/workflows/translate-i18n-base-on-english.yml @@ -1,13 +1,18 @@ name: Check i18n Files and Create PR on: - pull_request: - types: [closed] + push: branches: [main] + paths: + - 'web/i18n/en-US/*.ts' + +permissions: + contents: write + pull-requests: write jobs: check-and-update: - if: github.event.pull_request.merged == true + if: github.repository == 'langgenius/dify' runs-on: ubuntu-latest defaults: run: @@ -15,8 +20,8 @@ jobs: steps: - uses: actions/checkout@v4 with: - fetch-depth: 2 # last 2 commits - persist-credentials: false + fetch-depth: 2 + token: ${{ secrets.GITHUB_TOKEN }} - name: Check for file changes in i18n/en-US id: check_files @@ -27,6 +32,13 @@ jobs: echo "Changed files: $changed_files" if [ -n "$changed_files" ]; then echo "FILES_CHANGED=true" >> $GITHUB_ENV + file_args="" + for file in $changed_files; do + filename=$(basename "$file" .ts) + file_args="$file_args --file=$filename" + done + echo "FILE_ARGS=$file_args" >> $GITHUB_ENV + echo "File arguments: $file_args" else echo "FILES_CHANGED=false" >> $GITHUB_ENV fi @@ -34,7 +46,7 @@ jobs: - name: Install pnpm uses: pnpm/action-setup@v4 with: - version: 10 + package_json_file: web/package.json run_install: false - name: Set up Node.js @@ -47,16 +59,19 @@ jobs: - name: Install dependencies if: env.FILES_CHANGED == 'true' + working-directory: ./web run: pnpm install --frozen-lockfile - - name: Run npm script + - name: Generate i18n translations if: env.FILES_CHANGED == 'true' - run: pnpm run auto-gen-i18n + working-directory: ./web + run: pnpm run auto-gen-i18n ${{ env.FILE_ARGS }} - name: Create Pull Request if: env.FILES_CHANGED == 'true' uses: peter-evans/create-pull-request@v6 with: + token: ${{ secrets.GITHUB_TOKEN }} commit-message: Update i18n files based on en-US changes title: 'chore: translate i18n files' body: This PR was automatically created to update i18n files based on changes in en-US locale. diff --git a/.github/workflows/web-tests.yml b/.github/workflows/web-tests.yml index c3f8fdbaf6..d104d69947 100644 --- a/.github/workflows/web-tests.yml +++ b/.github/workflows/web-tests.yml @@ -35,7 +35,7 @@ jobs: if: steps.changed-files.outputs.any_changed == 'true' uses: pnpm/action-setup@v4 with: - version: 10 + package_json_file: web/package.json run_install: false - name: Setup Node.js @@ -48,8 +48,10 @@ jobs: - name: Install dependencies if: steps.changed-files.outputs.any_changed == 'true' + working-directory: ./web run: pnpm install --frozen-lockfile - name: Run tests if: steps.changed-files.outputs.any_changed == 'true' + working-directory: ./web run: pnpm test diff --git a/.gitignore b/.gitignore index dd4673a3d2..30432c4302 100644 --- a/.gitignore +++ b/.gitignore @@ -197,6 +197,8 @@ sdks/python-client/dify_client.egg-info !.vscode/README.md pyrightconfig.json api/.vscode +# vscode Code History Extension +.history .idea/ @@ -215,3 +217,4 @@ mise.toml # AI Assistant .roo/ api/.env.backup +/clickzetta diff --git a/.vscode/README.md b/.vscode/README.md index 26516f0540..87b45787c3 100644 --- a/.vscode/README.md +++ b/.vscode/README.md @@ -4,10 +4,10 @@ This `launch.json.template` file provides various debug configurations for the D ## How to Use -1. **Create `launch.json`**: If you don't have one, create a file named `launch.json` inside the `.vscode` directory. -2. **Copy Content**: Copy the entire content from `launch.json.template` into your newly created `launch.json` file. -3. **Select Debug Configuration**: Go to the Run and Debug view in VS Code / Cursor (Ctrl+Shift+D or Cmd+Shift+D). -4. **Start Debugging**: Select the desired configuration from the dropdown menu and click the green play button. +1. **Create `launch.json`**: If you don't have one, create a file named `launch.json` inside the `.vscode` directory. +1. **Copy Content**: Copy the entire content from `launch.json.template` into your newly created `launch.json` file. +1. **Select Debug Configuration**: Go to the Run and Debug view in VS Code / Cursor (Ctrl+Shift+D or Cmd+Shift+D). +1. **Start Debugging**: Select the desired configuration from the dropdown menu and click the green play button. ## Tips diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000000..fd437d7bf0 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,88 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Project Overview + +Dify is an open-source platform for developing LLM applications with an intuitive interface combining agentic AI workflows, RAG pipelines, agent capabilities, and model management. + +The codebase consists of: + +- **Backend API** (`/api`): Python Flask application with Domain-Driven Design architecture +- **Frontend Web** (`/web`): Next.js 15 application with TypeScript and React 19 +- **Docker deployment** (`/docker`): Containerized deployment configurations + +## Development Commands + +### Backend (API) + +All Python commands must be prefixed with `uv run --project api`: + +```bash +# Start development servers +./dev/start-api # Start API server +./dev/start-worker # Start Celery worker + +# Run tests +uv run --project api pytest # Run all tests +uv run --project api pytest tests/unit_tests/ # Unit tests only +uv run --project api pytest tests/integration_tests/ # Integration tests + +# Code quality +./dev/reformat # Run all formatters and linters +uv run --project api ruff check --fix ./ # Fix linting issues +uv run --project api ruff format ./ # Format code +uv run --project api mypy . # Type checking +``` + +### Frontend (Web) + +```bash +cd web +pnpm lint # Run ESLint +pnpm eslint-fix # Fix ESLint issues +pnpm test # Run Jest tests +``` + +## Testing Guidelines + +### Backend Testing + +- Use `pytest` for all backend tests +- Write tests first (TDD approach) +- Test structure: Arrange-Act-Assert + +## Code Style Requirements + +### Python + +- Use type hints for all functions and class attributes +- No `Any` types unless absolutely necessary +- Implement special methods (`__repr__`, `__str__`) appropriately + +### TypeScript/JavaScript + +- Strict TypeScript configuration +- ESLint with Prettier integration +- Avoid `any` type + +## Important Notes + +- **Environment Variables**: Always use UV for Python commands: `uv run --project api ` +- **Comments**: Only write meaningful comments that explain "why", not "what" +- **File Creation**: Always prefer editing existing files over creating new ones +- **Documentation**: Don't create documentation files unless explicitly requested +- **Code Quality**: Always run `./dev/reformat` before committing backend changes + +## Common Development Tasks + +### Adding a New API Endpoint + +1. Create controller in `/api/controllers/` +1. Add service logic in `/api/services/` +1. Update routes in controller's `__init__.py` +1. Write tests in `/api/tests/` + +## Project-Specific Conventions + +- All async tasks use Celery with Redis as broker diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 5d4ba36485..fdc414b047 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -34,11 +34,11 @@ Don't forget to link an existing issue or open a new issue in the PR's descripti How we prioritize: - | Issue Type | Priority | - | ------------------------------------------------------------ | --------------- | - | Bugs in core functions (cloud service, cannot login, applications not working, security loopholes) | Critical | - | Non-critical bugs, performance boosts | Medium Priority | - | Minor fixes (typos, confusing but working UI) | Low Priority | +| Issue Type | Priority | +| ------------------------------------------------------------ | --------------- | +| Bugs in core functions (cloud service, cannot login, applications not working, security loopholes) | Critical | +| Non-critical bugs, performance boosts | Medium Priority | +| Minor fixes (typos, confusing but working UI) | Low Priority | ### Feature requests @@ -52,23 +52,25 @@ How we prioritize: How we prioritize: - | Feature Type | Priority | - | ------------------------------------------------------------ | --------------- | - | High-Priority Features as being labeled by a team member | High Priority | - | Popular feature requests from our [community feedback board](https://github.com/langgenius/dify/discussions/categories/feedbacks) | Medium Priority | - | Non-core features and minor enhancements | Low Priority | - | Valuable but not immediate | Future-Feature | +| Feature Type | Priority | +| ------------------------------------------------------------ | --------------- | +| High-Priority Features as being labeled by a team member | High Priority | +| Popular feature requests from our [community feedback board](https://github.com/langgenius/dify/discussions/categories/feedbacks) | Medium Priority | +| Non-core features and minor enhancements | Low Priority | +| Valuable but not immediate | Future-Feature | + ## Submitting your PR ### Pull Request Process 1. Fork the repository -2. Before you draft a PR, please create an issue to discuss the changes you want to make -3. Create a new branch for your changes -4. Please add tests for your changes accordingly -5. Ensure your code passes the existing tests -6. Please link the issue in the PR description, `fixes #` -7. Get merged! +1. Before you draft a PR, please create an issue to discuss the changes you want to make +1. Create a new branch for your changes +1. Please add tests for your changes accordingly +1. Ensure your code passes the existing tests +1. Please link the issue in the PR description, `fixes #` +1. Get merged! + ### Setup the project #### Frontend @@ -82,12 +84,14 @@ For setting up the backend service, kindly refer to our detailed [instructions]( #### Other things to note We recommend reviewing this document carefully before proceeding with the setup, as it contains essential information about: + - Prerequisites and dependencies - Installation steps - Configuration details - Common troubleshooting tips Feel free to reach out if you encounter any issues during the setup process. + ## Getting Help If you ever get stuck or get a burning question while contributing, simply shoot your queries our way via the related GitHub issue, or hop onto our [Discord](https://discord.gg/8Tpq4AcN9c) for a quick chat. diff --git a/CONTRIBUTING_CN.md b/CONTRIBUTING_CN.md index 69ae7071bb..c278c8fd7a 100644 --- a/CONTRIBUTING_CN.md +++ b/CONTRIBUTING_CN.md @@ -34,12 +34,11 @@ 优先级划分: - | 问题类型 | 优先级 | - | -------------------------------------------------- | ---------- | - | 核心功能 bug(云服务、登录失败、应用无法使用、安全漏洞) | 紧急 | - | 非关键 bug、性能优化 | 中等优先级 | - | 小修复(拼写错误、界面混乱但可用) | 低优先级 | - +| 问题类型 | 优先级 | +| -------------------------------------------------- | ---------- | +| 核心功能 bug(云服务、登录失败、应用无法使用、安全漏洞) | 紧急 | +| 非关键 bug、性能优化 | 中等优先级 | +| 小修复(拼写错误、界面混乱但可用) | 低优先级 | ### 功能请求 @@ -53,12 +52,12 @@ 优先级划分: - | 功能类型 | 优先级 | - | -------------------------------------------------- | ---------- | - | 被团队成员标记为高优先级的功能 | 高优先级 | - | 来自[社区反馈板](https://github.com/langgenius/dify/discussions/categories/feedbacks)的热门功能请求 | 中等优先级 | - | 非核心功能和小改进 | 低优先级 | - | 有价值但非紧急的功能 | 未来特性 | +| 功能类型 | 优先级 | +| -------------------------------------------------- | ---------- | +| 被团队成员标记为高优先级的功能 | 高优先级 | +| 来自[社区反馈板](https://github.com/langgenius/dify/discussions/categories/feedbacks)的热门功能请求 | 中等优先级 | +| 非核心功能和小改进 | 低优先级 | +| 有价值但非紧急的功能 | 未来特性 | ## 提交 PR @@ -67,12 +66,12 @@ ### PR 提交流程 1. Fork 本仓库 -2. 在提交 PR 之前,请先创建 issue 讨论你想要做的修改 -3. 为你的修改创建一个新的分支 -4. 请为你的修改添加相应的测试 -5. 确保你的代码能通过现有的测试 -6. 请在 PR 描述中关联相关 issue,格式为 `fixes #` -7. 等待合并! +1. 在提交 PR 之前,请先创建 issue 讨论你想要做的修改 +1. 为你的修改创建一个新的分支 +1. 请为你的修改添加相应的测试 +1. 确保你的代码能通过现有的测试 +1. 请在 PR 描述中关联相关 issue,格式为 `fixes #` +1. 等待合并! #### 前端 @@ -85,6 +84,7 @@ #### 其他注意事项 我们建议在开始设置之前仔细阅读本文档,因为它包含以下重要信息: + - 前置条件和依赖项 - 安装步骤 - 配置细节 diff --git a/CONTRIBUTING_DE.md b/CONTRIBUTING_DE.md index ddbf3abc55..f819e80bbb 100644 --- a/CONTRIBUTING_DE.md +++ b/CONTRIBUTING_DE.md @@ -32,11 +32,11 @@ Vergessen Sie nicht, in der PR-Beschreibung ein bestehendes Issue zu verlinken o Unsere Priorisierung: - | Fehlertyp | Priorität | - | ------------------------------------------------------------ | --------------- | - | Fehler in Kernfunktionen (Cloud-Service, Login nicht möglich, Anwendungen funktionieren nicht, Sicherheitslücken) | Kritisch | - | Nicht-kritische Fehler, Leistungsverbesserungen | Mittlere Priorität | - | Kleinere Korrekturen (Tippfehler, verwirrende aber funktionierende UI) | Niedrige Priorität | +| Fehlertyp | Priorität | +| ------------------------------------------------------------ | --------------- | +| Fehler in Kernfunktionen (Cloud-Service, Login nicht möglich, Anwendungen funktionieren nicht, Sicherheitslücken) | Kritisch | +| Nicht-kritische Fehler, Leistungsverbesserungen | Mittlere Priorität | +| Kleinere Korrekturen (Tippfehler, verwirrende aber funktionierende UI) | Niedrige Priorität | ### Feature-Anfragen @@ -50,24 +50,24 @@ Unsere Priorisierung: Unsere Priorisierung: - | Feature-Typ | Priorität | - | ------------------------------------------------------------ | --------------- | - | Hochprioritäre Features (durch Teammitglied gekennzeichnet) | Hohe Priorität | - | Beliebte Feature-Anfragen aus unserem [Community-Feedback-Board](https://github.com/langgenius/dify/discussions/categories/feedbacks) | Mittlere Priorität | - | Nicht-Kernfunktionen und kleinere Verbesserungen | Niedrige Priorität | - | Wertvoll, aber nicht dringend | Zukunfts-Feature | +| Feature-Typ | Priorität | +| ------------------------------------------------------------ | --------------- | +| Hochprioritäre Features (durch Teammitglied gekennzeichnet) | Hohe Priorität | +| Beliebte Feature-Anfragen aus unserem [Community-Feedback-Board](https://github.com/langgenius/dify/discussions/categories/feedbacks) | Mittlere Priorität | +| Nicht-Kernfunktionen und kleinere Verbesserungen | Niedrige Priorität | +| Wertvoll, aber nicht dringend | Zukunfts-Feature | ## Einreichen Ihres PRs ### Pull-Request-Prozess 1. Repository forken -2. Vor dem Erstellen eines PRs bitte ein Issue zur Diskussion der Änderungen erstellen -3. Einen neuen Branch für Ihre Änderungen erstellen -4. Tests für Ihre Änderungen hinzufügen -5. Sicherstellen, dass Ihr Code die bestehenden Tests besteht -6. Issue in der PR-Beschreibung verlinken (`fixes #`) -7. Auf den Merge warten! +1. Vor dem Erstellen eines PRs bitte ein Issue zur Diskussion der Änderungen erstellen +1. Einen neuen Branch für Ihre Änderungen erstellen +1. Tests für Ihre Änderungen hinzufügen +1. Sicherstellen, dass Ihr Code die bestehenden Tests besteht +1. Issue in der PR-Beschreibung verlinken (`fixes #`) +1. Auf den Merge warten! ### Projekt einrichten @@ -82,6 +82,7 @@ Für die Einrichtung des Backend-Service folgen Sie bitte unseren detaillierten #### Weitere Hinweise Wir empfehlen, dieses Dokument sorgfältig zu lesen, da es wichtige Informationen enthält über: + - Voraussetzungen und Abhängigkeiten - Installationsschritte - Konfigurationsdetails @@ -92,4 +93,3 @@ Bei Problemen während der Einrichtung können Sie sich gerne an uns wenden. ## Hilfe bekommen Wenn Sie beim Mitwirken Fragen haben oder nicht weiterkommen, stellen Sie Ihre Fragen einfach im entsprechenden GitHub Issue oder besuchen Sie unseren [Discord](https://discord.gg/8Tpq4AcN9c) für einen schnellen Austausch. - diff --git a/CONTRIBUTING_ES.md b/CONTRIBUTING_ES.md index 98cbb5b457..e19d958c65 100644 --- a/CONTRIBUTING_ES.md +++ b/CONTRIBUTING_ES.md @@ -34,11 +34,11 @@ No olvides vincular un issue existente o abrir uno nuevo en la descripción del Cómo priorizamos: - | Tipo de Issue | Prioridad | - | ------------------------------------------------------------ | --------------- | - | Errores en funciones principales (servicio en la nube, no poder iniciar sesión, aplicaciones que no funcionan, fallos de seguridad) | Crítica | - | Errores no críticos, mejoras de rendimiento | Prioridad Media | - | Correcciones menores (errores tipográficos, UI confusa pero funcional) | Prioridad Baja | +| Tipo de Issue | Prioridad | +| ------------------------------------------------------------ | --------------- | +| Errores en funciones principales (servicio en la nube, no poder iniciar sesión, aplicaciones que no funcionan, fallos de seguridad) | Crítica | +| Errores no críticos, mejoras de rendimiento | Prioridad Media | +| Correcciones menores (errores tipográficos, UI confusa pero funcional) | Prioridad Baja | ### Solicitudes de funcionalidades @@ -52,23 +52,25 @@ Cómo priorizamos: Cómo priorizamos: - | Tipo de Funcionalidad | Prioridad | - | ------------------------------------------------------------ | --------------- | - | Funcionalidades de alta prioridad etiquetadas por un miembro del equipo | Prioridad Alta | - | Solicitudes populares de funcionalidades de nuestro [tablero de comentarios de la comunidad](https://github.com/langgenius/dify/discussions/categories/feedbacks) | Prioridad Media | - | Funcionalidades no principales y mejoras menores | Prioridad Baja | - | Valiosas pero no inmediatas | Futura-Funcionalidad | +| Tipo de Funcionalidad | Prioridad | +| ------------------------------------------------------------ | --------------- | +| Funcionalidades de alta prioridad etiquetadas por un miembro del equipo | Prioridad Alta | +| Solicitudes populares de funcionalidades de nuestro [tablero de comentarios de la comunidad](https://github.com/langgenius/dify/discussions/categories/feedbacks) | Prioridad Media | +| Funcionalidades no principales y mejoras menores | Prioridad Baja | +| Valiosas pero no inmediatas | Futura-Funcionalidad | + ## Enviando tu PR ### Proceso de Pull Request 1. Haz un fork del repositorio -2. Antes de redactar un PR, por favor crea un issue para discutir los cambios que quieres hacer -3. Crea una nueva rama para tus cambios -4. Por favor añade pruebas para tus cambios en consecuencia -5. Asegúrate de que tu código pasa las pruebas existentes -6. Por favor vincula el issue en la descripción del PR, `fixes #` -7. ¡Fusiona tu código! +1. Antes de redactar un PR, por favor crea un issue para discutir los cambios que quieres hacer +1. Crea una nueva rama para tus cambios +1. Por favor añade pruebas para tus cambios en consecuencia +1. Asegúrate de que tu código pasa las pruebas existentes +1. Por favor vincula el issue en la descripción del PR, `fixes #` +1. ¡Fusiona tu código! + ### Configuración del proyecto #### Frontend @@ -82,12 +84,14 @@ Para configurar el servicio backend, por favor consulta nuestras [instrucciones #### Otras cosas a tener en cuenta Recomendamos revisar este documento cuidadosamente antes de proceder con la configuración, ya que contiene información esencial sobre: + - Requisitos previos y dependencias - Pasos de instalación - Detalles de configuración - Consejos comunes de solución de problemas No dudes en contactarnos si encuentras algún problema durante el proceso de configuración. + ## Obteniendo Ayuda -Si alguna vez te quedas atascado o tienes una pregunta urgente mientras contribuyes, simplemente envíanos tus consultas a través del issue relacionado de GitHub, o únete a nuestro [Discord](https://discord.gg/8Tpq4AcN9c) para una charla rápida. +Si alguna vez te quedas atascado o tienes una pregunta urgente mientras contribuyes, simplemente envíanos tus consultas a través del issue relacionado de GitHub, o únete a nuestro [Discord](https://discord.gg/8Tpq4AcN9c) para una charla rápida. diff --git a/CONTRIBUTING_FR.md b/CONTRIBUTING_FR.md index fc8410dfd6..335e943fcd 100644 --- a/CONTRIBUTING_FR.md +++ b/CONTRIBUTING_FR.md @@ -34,11 +34,11 @@ N'oubliez pas de lier un problème existant ou d'ouvrir un nouveau problème dan Comment nous priorisons : - | Type de Problème | Priorité | - | ------------------------------------------------------------ | --------------- | - | Bugs dans les fonctions principales (service cloud, impossibilité de se connecter, applications qui ne fonctionnent pas, failles de sécurité) | Critique | - | Bugs non critiques, améliorations de performance | Priorité Moyenne | - | Corrections mineures (fautes de frappe, UI confuse mais fonctionnelle) | Priorité Basse | +| Type de Problème | Priorité | +| ------------------------------------------------------------ | --------------- | +| Bugs dans les fonctions principales (service cloud, impossibilité de se connecter, applications qui ne fonctionnent pas, failles de sécurité) | Critique | +| Bugs non critiques, améliorations de performance | Priorité Moyenne | +| Corrections mineures (fautes de frappe, UI confuse mais fonctionnelle) | Priorité Basse | ### Demandes de fonctionnalités @@ -52,23 +52,25 @@ Comment nous priorisons : Comment nous priorisons : - | Type de Fonctionnalité | Priorité | - | ------------------------------------------------------------ | --------------- | - | Fonctionnalités hautement prioritaires étiquetées par un membre de l'équipe | Priorité Haute | - | Demandes populaires de fonctionnalités de notre [tableau de feedback communautaire](https://github.com/langgenius/dify/discussions/categories/feedbacks) | Priorité Moyenne | - | Fonctionnalités non essentielles et améliorations mineures | Priorité Basse | - | Précieuses mais non immédiates | Fonctionnalité Future | +| Type de Fonctionnalité | Priorité | +| ------------------------------------------------------------ | --------------- | +| Fonctionnalités hautement prioritaires étiquetées par un membre de l'équipe | Priorité Haute | +| Demandes populaires de fonctionnalités de notre [tableau de feedback communautaire](https://github.com/langgenius/dify/discussions/categories/feedbacks) | Priorité Moyenne | +| Fonctionnalités non essentielles et améliorations mineures | Priorité Basse | +| Précieuses mais non immédiates | Fonctionnalité Future | + ## Soumettre votre PR ### Processus de Pull Request 1. Forkez le dépôt -2. Avant de rédiger une PR, veuillez créer un problème pour discuter des changements que vous souhaitez apporter -3. Créez une nouvelle branche pour vos changements -4. Veuillez ajouter des tests pour vos changements en conséquence -5. Assurez-vous que votre code passe les tests existants -6. Veuillez lier le problème dans la description de la PR, `fixes #` -7. Faites fusionner votre code ! +1. Avant de rédiger une PR, veuillez créer un problème pour discuter des changements que vous souhaitez apporter +1. Créez une nouvelle branche pour vos changements +1. Veuillez ajouter des tests pour vos changements en conséquence +1. Assurez-vous que votre code passe les tests existants +1. Veuillez lier le problème dans la description de la PR, `fixes #` +1. Faites fusionner votre code ! + ### Configuration du projet #### Frontend @@ -82,12 +84,14 @@ Pour configurer le service backend, veuillez consulter nos [instructions détail #### Autres choses à noter Nous recommandons de revoir attentivement ce document avant de procéder à la configuration, car il contient des informations essentielles sur : + - Prérequis et dépendances - Étapes d'installation - Détails de configuration - Conseils courants de dépannage N'hésitez pas à nous contacter si vous rencontrez des problèmes pendant le processus de configuration. + ## Obtenir de l'aide -Si jamais vous êtes bloqué ou avez une question urgente en contribuant, envoyez-nous simplement vos questions via le problème GitHub concerné, ou rejoignez notre [Discord](https://discord.gg/8Tpq4AcN9c) pour une discussion rapide. +Si jamais vous êtes bloqué ou avez une question urgente en contribuant, envoyez-nous simplement vos questions via le problème GitHub concerné, ou rejoignez notre [Discord](https://discord.gg/8Tpq4AcN9c) pour une discussion rapide. diff --git a/CONTRIBUTING_JA.md b/CONTRIBUTING_JA.md index e991d0263e..2d0d79fc16 100644 --- a/CONTRIBUTING_JA.md +++ b/CONTRIBUTING_JA.md @@ -34,11 +34,11 @@ PRの説明には、既存のイシューへのリンクを含めるか、新し 優先順位の付け方: - | 問題の種類 | 優先度 | - | ------------------------------------------------------------ | --------- | - | コア機能のバグ(クラウドサービス、ログイン不可、アプリケーション不具合、セキュリティ脆弱性) | 最重要 | - | 重要度の低いバグ、パフォーマンス改善 | 中程度 | - | 軽微な修正(タイプミス、分かりにくいが動作するUI) | 低 | +| 問題の種類 | 優先度 | +| ------------------------------------------------------------ | --------- | +| コア機能のバグ(クラウドサービス、ログイン不可、アプリケーション不具合、セキュリティ脆弱性) | 最重要 | +| 重要度の低いバグ、パフォーマンス改善 | 中程度 | +| 軽微な修正(タイプミス、分かりにくいが動作するUI) | 低 | ### 機能リクエスト @@ -52,24 +52,24 @@ PRの説明には、既存のイシューへのリンクを含めるか、新し 優先順位の付け方: - | 機能の種類 | 優先度 | - | ------------------------------------------------------------ | --------- | - | チームメンバーによって高優先度とラベル付けされた機能 | 高 | - | [コミュニティフィードボード](https://github.com/langgenius/dify/discussions/categories/feedbacks)での人気の機能リクエスト | 中程度 | - | 非コア機能と軽微な改善 | 低 | - | 価値はあるが緊急性の低いもの | 将来対応 | +| 機能の種類 | 優先度 | +| ------------------------------------------------------------ | --------- | +| チームメンバーによって高優先度とラベル付けされた機能 | 高 | +| [コミュニティフィードボード](https://github.com/langgenius/dify/discussions/categories/feedbacks)での人気の機能リクエスト | 中程度 | +| 非コア機能と軽微な改善 | 低 | +| 価値はあるが緊急性の低いもの | 将来対応 | ## PRの提出 ### プルリクエストのプロセス 1. リポジトリをフォークする -2. PRを作成する前に、変更内容についてイシューで議論する -3. 変更用の新しいブランチを作成する -4. 変更に応じたテストを追加する -5. 既存のテストをパスすることを確認する -6. PRの説明文にイシューをリンクする(`fixes #`) -7. マージ完了! +1. PRを作成する前に、変更内容についてイシューで議論する +1. 変更用の新しいブランチを作成する +1. 変更に応じたテストを追加する +1. 既存のテストをパスすることを確認する +1. PRの説明文にイシューをリンクする(`fixes #`) +1. マージ完了! ### プロジェクトのセットアップ @@ -84,6 +84,7 @@ PRの説明には、既存のイシューへのリンクを含めるか、新し #### その他の注意点 セットアップを進める前に、以下の重要な情報が含まれているため、このドキュメントを注意深く確認することをお勧めします: + - 前提条件と依存関係 - インストール手順 - 設定の詳細 @@ -94,4 +95,3 @@ PRの説明には、既存のイシューへのリンクを含めるか、新し ## サポートを受ける 貢献中に行き詰まったり、緊急の質問がある場合は、関連するGitHubイシューで質問するか、[Discord](https://discord.gg/8Tpq4AcN9c)で気軽にチャットしてください。 - diff --git a/CONTRIBUTING_KR.md b/CONTRIBUTING_KR.md index 78d3f38c47..14b1c9a9ca 100644 --- a/CONTRIBUTING_KR.md +++ b/CONTRIBUTING_KR.md @@ -34,11 +34,11 @@ PR 설명에 기존 이슈를 연결하거나 새 이슈를 여는 것을 잊지 우선순위 결정 방법: - | 이슈 유형 | 우선순위 | - | ------------------------------------------------------------ | --------------- | - | 핵심 기능의 버그(클라우드 서비스, 로그인 불가, 애플리케이션 작동 불능, 보안 취약점) | 중대 | - | 비중요 버그, 성능 향상 | 중간 우선순위 | - | 사소한 수정(오타, 혼란스럽지만 작동하는 UI) | 낮은 우선순위 | +| 이슈 유형 | 우선순위 | +| ------------------------------------------------------------ | --------------- | +| 핵심 기능의 버그(클라우드 서비스, 로그인 불가, 애플리케이션 작동 불능, 보안 취약점) | 중대 | +| 비중요 버그, 성능 향상 | 중간 우선순위 | +| 사소한 수정(오타, 혼란스럽지만 작동하는 UI) | 낮은 우선순위 | ### 기능 요청 @@ -52,23 +52,25 @@ PR 설명에 기존 이슈를 연결하거나 새 이슈를 여는 것을 잊지 우선순위 결정 방법: - | 기능 유형 | 우선순위 | - | ------------------------------------------------------------ | --------------- | - | 팀 구성원에 의해 레이블이 지정된 고우선순위 기능 | 높은 우선순위 | - | 우리의 [커뮤니티 피드백 보드](https://github.com/langgenius/dify/discussions/categories/feedbacks)에서 인기 있는 기능 요청 | 중간 우선순위 | - | 비핵심 기능 및 사소한 개선 | 낮은 우선순위 | - | 가치 있지만 즉시 필요하지 않은 기능 | 미래 기능 | +| 기능 유형 | 우선순위 | +| ------------------------------------------------------------ | --------------- | +| 팀 구성원에 의해 레이블이 지정된 고우선순위 기능 | 높은 우선순위 | +| 우리의 [커뮤니티 피드백 보드](https://github.com/langgenius/dify/discussions/categories/feedbacks)에서 인기 있는 기능 요청 | 중간 우선순위 | +| 비핵심 기능 및 사소한 개선 | 낮은 우선순위 | +| 가치 있지만 즉시 필요하지 않은 기능 | 미래 기능 | + ## PR 제출하기 ### Pull Request 프로세스 1. 저장소를 포크하세요 -2. PR을 작성하기 전에, 변경하고자 하는 내용에 대해 논의하기 위한 이슈를 생성해 주세요 -3. 변경 사항을 위한 새 브랜치를 만드세요 -4. 변경 사항에 대한 테스트를 적절히 추가해 주세요 -5. 코드가 기존 테스트를 통과하는지 확인하세요 -6. PR 설명에 이슈를 연결해 주세요, `fixes #<이슈_번호>` -7. 병합 완료! +1. PR을 작성하기 전에, 변경하고자 하는 내용에 대해 논의하기 위한 이슈를 생성해 주세요 +1. 변경 사항을 위한 새 브랜치를 만드세요 +1. 변경 사항에 대한 테스트를 적절히 추가해 주세요 +1. 코드가 기존 테스트를 통과하는지 확인하세요 +1. PR 설명에 이슈를 연결해 주세요, `fixes #<이슈_번호>` +1. 병합 완료! + ### 프로젝트 설정하기 #### 프론트엔드 @@ -82,12 +84,14 @@ PR 설명에 기존 이슈를 연결하거나 새 이슈를 여는 것을 잊지 #### 기타 참고 사항 설정을 진행하기 전에 이 문서를 주의 깊게 검토하는 것을 권장합니다. 다음과 같은 필수 정보가 포함되어 있습니다: + - 필수 조건 및 종속성 - 설치 단계 - 구성 세부 정보 - 일반적인 문제 해결 팁 설정 과정에서 문제가 발생하면 언제든지 연락해 주세요. + ## 도움 받기 -기여하는 동안 막히거나 긴급한 질문이 있으면, 관련 GitHub 이슈를 통해 질문을 보내거나, 빠른 대화를 위해 우리의 [Discord](https://discord.gg/8Tpq4AcN9c)에 참여하세요. +기여하는 동안 막히거나 긴급한 질문이 있으면, 관련 GitHub 이슈를 통해 질문을 보내거나, 빠른 대화를 위해 우리의 [Discord](https://discord.gg/8Tpq4AcN9c)에 참여하세요. diff --git a/CONTRIBUTING_PT.md b/CONTRIBUTING_PT.md index 7347fd7f9c..aeabcad51f 100644 --- a/CONTRIBUTING_PT.md +++ b/CONTRIBUTING_PT.md @@ -34,11 +34,11 @@ Não se esqueça de vincular um problema existente ou abrir um novo problema na Como priorizamos: - | Tipo de Problema | Prioridade | - | ------------------------------------------------------------ | --------------- | - | Bugs em funções centrais (serviço em nuvem, não conseguir fazer login, aplicações não funcionando, falhas de segurança) | Crítica | - | Bugs não críticos, melhorias de desempenho | Prioridade Média | - | Correções menores (erros de digitação, interface confusa mas funcional) | Prioridade Baixa | +| Tipo de Problema | Prioridade | +| ------------------------------------------------------------ | --------------- | +| Bugs em funções centrais (serviço em nuvem, não conseguir fazer login, aplicações não funcionando, falhas de segurança) | Crítica | +| Bugs não críticos, melhorias de desempenho | Prioridade Média | +| Correções menores (erros de digitação, interface confusa mas funcional) | Prioridade Baixa | ### Solicitações de recursos @@ -52,23 +52,25 @@ Como priorizamos: Como priorizamos: - | Tipo de Recurso | Prioridade | - | ------------------------------------------------------------ | --------------- | - | Recursos de alta prioridade conforme rotulado por um membro da equipe | Prioridade Alta | - | Solicitações populares de recursos do nosso [quadro de feedback da comunidade](https://github.com/langgenius/dify/discussions/categories/feedbacks) | Prioridade Média | - | Recursos não essenciais e melhorias menores | Prioridade Baixa | - | Valiosos mas não imediatos | Recurso Futuro | +| Tipo de Recurso | Prioridade | +| ------------------------------------------------------------ | --------------- | +| Recursos de alta prioridade conforme rotulado por um membro da equipe | Prioridade Alta | +| Solicitações populares de recursos do nosso [quadro de feedback da comunidade](https://github.com/langgenius/dify/discussions/categories/feedbacks) | Prioridade Média | +| Recursos não essenciais e melhorias menores | Prioridade Baixa | +| Valiosos mas não imediatos | Recurso Futuro | + ## Enviando seu PR ### Processo de Pull Request 1. Faça um fork do repositório -2. Antes de elaborar um PR, por favor crie um problema para discutir as mudanças que você quer fazer -3. Crie um novo branch para suas alterações -4. Por favor, adicione testes para suas alterações conforme apropriado -5. Certifique-se de que seu código passa nos testes existentes -6. Por favor, vincule o problema na descrição do PR, `fixes #` -7. Faça o merge do seu código! +1. Antes de elaborar um PR, por favor crie um problema para discutir as mudanças que você quer fazer +1. Crie um novo branch para suas alterações +1. Por favor, adicione testes para suas alterações conforme apropriado +1. Certifique-se de que seu código passa nos testes existentes +1. Por favor, vincule o problema na descrição do PR, `fixes #` +1. Faça o merge do seu código! + ### Configurando o projeto #### Frontend @@ -82,12 +84,14 @@ Para configurar o serviço backend, por favor consulte nossas [instruções deta #### Outras coisas a observar Recomendamos revisar este documento cuidadosamente antes de prosseguir com a configuração, pois ele contém informações essenciais sobre: + - Pré-requisitos e dependências - Etapas de instalação - Detalhes de configuração - Dicas comuns de solução de problemas Sinta-se à vontade para entrar em contato se encontrar quaisquer problemas durante o processo de configuração. + ## Obtendo Ajuda -Se você ficar preso ou tiver uma dúvida urgente enquanto contribui, simplesmente envie suas perguntas através do problema relacionado no GitHub, ou entre no nosso [Discord](https://discord.gg/8Tpq4AcN9c) para uma conversa rápida. +Se você ficar preso ou tiver uma dúvida urgente enquanto contribui, simplesmente envie suas perguntas através do problema relacionado no GitHub, ou entre no nosso [Discord](https://discord.gg/8Tpq4AcN9c) para uma conversa rápida. diff --git a/CONTRIBUTING_TR.md b/CONTRIBUTING_TR.md index 681f05689b..d016802a53 100644 --- a/CONTRIBUTING_TR.md +++ b/CONTRIBUTING_TR.md @@ -34,11 +34,11 @@ PR açıklamasında mevcut bir sorunu bağlamayı veya yeni bir sorun açmayı u Nasıl önceliklendiriyoruz: - | Sorun Türü | Öncelik | - | ------------------------------------------------------------ | --------------- | - | Temel işlevlerdeki hatalar (bulut hizmeti, giriş yapamama, çalışmayan uygulamalar, güvenlik açıkları) | Kritik | - | Kritik olmayan hatalar, performans artışları | Orta Öncelik | - | Küçük düzeltmeler (yazım hataları, kafa karıştırıcı ama çalışan UI) | Düşük Öncelik | +| Sorun Türü | Öncelik | +| ------------------------------------------------------------ | --------------- | +| Temel işlevlerdeki hatalar (bulut hizmeti, giriş yapamama, çalışmayan uygulamalar, güvenlik açıkları) | Kritik | +| Kritik olmayan hatalar, performans artışları | Orta Öncelik | +| Küçük düzeltmeler (yazım hataları, kafa karıştırıcı ama çalışan UI) | Düşük Öncelik | ### Özellik İstekleri @@ -52,23 +52,25 @@ Nasıl önceliklendiriyoruz: Nasıl önceliklendiriyoruz: - | Özellik Türü | Öncelik | - | ------------------------------------------------------------ | --------------- | - | Bir ekip üyesi tarafından etiketlenen Yüksek Öncelikli Özellikler | Yüksek Öncelik | - | [Topluluk geri bildirim panosundan](https://github.com/langgenius/dify/discussions/categories/feedbacks) popüler özellik istekleri | Orta Öncelik | - | Temel olmayan özellikler ve küçük geliştirmeler | Düşük Öncelik | - | Değerli ama acil olmayan | Gelecek-Özellik | +| Özellik Türü | Öncelik | +| ------------------------------------------------------------ | --------------- | +| Bir ekip üyesi tarafından etiketlenen Yüksek Öncelikli Özellikler | Yüksek Öncelik | +| [Topluluk geri bildirim panosundan](https://github.com/langgenius/dify/discussions/categories/feedbacks) popüler özellik istekleri | Orta Öncelik | +| Temel olmayan özellikler ve küçük geliştirmeler | Düşük Öncelik | +| Değerli ama acil olmayan | Gelecek-Özellik | + ## PR'nizi Göndermek ### Pull Request Süreci 1. Depoyu fork edin -2. Bir PR taslağı oluşturmadan önce, yapmak istediğiniz değişiklikleri tartışmak için lütfen bir sorun oluşturun -3. Değişiklikleriniz için yeni bir dal oluşturun -4. Lütfen değişiklikleriniz için uygun testler ekleyin -5. Kodunuzun mevcut testleri geçtiğinden emin olun -6. Lütfen PR açıklamasında sorunu bağlayın, `fixes #` -7. Kodunuzu birleştirin! +1. Bir PR taslağı oluşturmadan önce, yapmak istediğiniz değişiklikleri tartışmak için lütfen bir sorun oluşturun +1. Değişiklikleriniz için yeni bir dal oluşturun +1. Lütfen değişiklikleriniz için uygun testler ekleyin +1. Kodunuzun mevcut testleri geçtiğinden emin olun +1. Lütfen PR açıklamasında sorunu bağlayın, `fixes #` +1. Kodunuzu birleştirin! + ### Projeyi Kurma #### Frontend @@ -82,12 +84,14 @@ Backend hizmetini kurmak için, lütfen `api/README.md` dosyasındaki detaylı [ #### Dikkat Edilecek Diğer Şeyler Kuruluma geçmeden önce bu belgeyi dikkatlice incelemenizi öneririz, çünkü şunlar hakkında temel bilgiler içerir: + - Ön koşullar ve bağımlılıklar - Kurulum adımları - Yapılandırma detayları - Yaygın sorun giderme ipuçları Kurulum süreci sırasında herhangi bir sorunla karşılaşırsanız bizimle iletişime geçmekten çekinmeyin. + ## Yardım Almak -Katkıda bulunurken takılırsanız veya yanıcı bir sorunuz olursa, sorularınızı ilgili GitHub sorunu aracılığıyla bize gönderin veya hızlı bir sohbet için [Discord'umuza](https://discord.gg/8Tpq4AcN9c) katılın. +Katkıda bulunurken takılırsanız veya yanıcı bir sorunuz olursa, sorularınızı ilgili GitHub sorunu aracılığıyla bize gönderin veya hızlı bir sohbet için [Discord'umuza](https://discord.gg/8Tpq4AcN9c) katılın. diff --git a/CONTRIBUTING_TW.md b/CONTRIBUTING_TW.md index a61ea918c5..5c4d7022fe 100644 --- a/CONTRIBUTING_TW.md +++ b/CONTRIBUTING_TW.md @@ -22,7 +22,7 @@ ### 錯誤回報 -> [!IMPORTANT] +> [!IMPORTANT]\ > 提交錯誤回報時,請務必包含以下資訊: - 清晰明確的標題 @@ -34,15 +34,15 @@ 優先順序評估: - | 議題類型 | 優先級 | - | -------- | ------ | - | 核心功能錯誤(雲端服務、無法登入、應用程式無法運作、安全漏洞) | 緊急 | - | 非緊急錯誤、效能優化 | 中等 | - | 次要修正(拼字錯誤、介面混淆但可運作) | 低 | +| 議題類型 | 優先級 | +| -------- | ------ | +| 核心功能錯誤(雲端服務、無法登入、應用程式無法運作、安全漏洞) | 緊急 | +| 非緊急錯誤、效能優化 | 中等 | +| 次要修正(拼字錯誤、介面混淆但可運作) | 低 | ### 功能請求 -> [!NOTE] +> [!NOTE]\ > 提交功能請求時,請務必包含以下資訊: - 清晰明確的標題 @@ -52,24 +52,24 @@ 優先順序評估: - | 功能類型 | 優先級 | - | -------- | ------ | - | 團隊成員標記為高優先級的功能 | 高 | - | 來自[社群回饋板](https://github.com/langgenius/dify/discussions/categories/feedbacks)的熱門功能請求 | 中 | - | 非核心功能和小幅改進 | 低 | - | 有價值但非急迫的功能 | 未來功能 | +| 功能類型 | 優先級 | +| -------- | ------ | +| 團隊成員標記為高優先級的功能 | 高 | +| 來自[社群回饋板](https://github.com/langgenius/dify/discussions/categories/feedbacks)的熱門功能請求 | 中 | +| 非核心功能和小幅改進 | 低 | +| 有價值但非急迫的功能 | 未來功能 | ## 提交 PR ### PR 流程 1. Fork 專案 -2. 在開始撰寫 PR 前,請先建立議題討論你想做的更改 -3. 為你的更改建立新分支 -4. 請為你的更改新增相應的測試 -5. 確保你的程式碼通過現有測試 -6. 請在 PR 描述中連結相關議題,使用 `fixes #` -7. 等待合併! +1. 在開始撰寫 PR 前,請先建立議題討論你想做的更改 +1. 為你的更改建立新分支 +1. 請為你的更改新增相應的測試 +1. 確保你的程式碼通過現有測試 +1. 請在 PR 描述中連結相關議題,使用 `fixes #` +1. 等待合併! ### 專案設定 @@ -84,6 +84,7 @@ #### 其他注意事項 我們建議在開始設定前仔細閱讀此文件,因為它包含以下重要資訊: + - 前置需求和相依性 - 安裝步驟 - 設定細節 @@ -94,4 +95,3 @@ ## 尋求協助 如果你在貢獻過程中遇到困難或有急切的問題,可以透過相關的 GitHub 議題詢問,或加入我們的 [Discord](https://discord.gg/8Tpq4AcN9c) 進行即時交流。 - diff --git a/CONTRIBUTING_VI.md b/CONTRIBUTING_VI.md index 807054acce..2ad431296a 100644 --- a/CONTRIBUTING_VI.md +++ b/CONTRIBUTING_VI.md @@ -22,7 +22,7 @@ Hãy tham gia, đóng góp và cùng nhau xây dựng điều tuyệt vời! ### Báo cáo lỗi -> [!QUAN TRỌNG] +> [!QUAN TRỌNG]\ > Vui lòng đảm bảo cung cấp các thông tin sau khi gửi báo cáo lỗi: - Tiêu đề rõ ràng và mô tả @@ -34,11 +34,11 @@ Hãy tham gia, đóng góp và cùng nhau xây dựng điều tuyệt vời! Cách chúng tôi ưu tiên: - | Loại vấn đề | Mức độ ưu tiên | - | ----------- | -------------- | - | Lỗi trong các chức năng cốt lõi (dịch vụ đám mây, không thể đăng nhập, ứng dụng không hoạt động, lỗ hổng bảo mật) | Quan trọng | - | Lỗi không nghiêm trọng, cải thiện hiệu suất | Ưu tiên trung bình | - | Sửa lỗi nhỏ (lỗi chính tả, UI gây nhầm lẫn nhưng vẫn hoạt động) | Ưu tiên thấp | +| Loại vấn đề | Mức độ ưu tiên | +| ----------- | -------------- | +| Lỗi trong các chức năng cốt lõi (dịch vụ đám mây, không thể đăng nhập, ứng dụng không hoạt động, lỗ hổng bảo mật) | Quan trọng | +| Lỗi không nghiêm trọng, cải thiện hiệu suất | Ưu tiên trung bình | +| Sửa lỗi nhỏ (lỗi chính tả, UI gây nhầm lẫn nhưng vẫn hoạt động) | Ưu tiên thấp | ### Yêu cầu tính năng @@ -52,24 +52,24 @@ Cách chúng tôi ưu tiên: Cách chúng tôi ưu tiên: - | Loại tính năng | Mức độ ưu tiên | - | -------------- | -------------- | - | Tính năng ưu tiên cao được gắn nhãn bởi thành viên nhóm | Ưu tiên cao | - | Yêu cầu tính năng phổ biến từ [bảng phản hồi cộng đồng](https://github.com/langgenius/dify/discussions/categories/feedbacks) | Ưu tiên trung bình | - | Tính năng không cốt lõi và cải tiến nhỏ | Ưu tiên thấp | - | Có giá trị nhưng không cấp bách | Tính năng tương lai | +| Loại tính năng | Mức độ ưu tiên | +| -------------- | -------------- | +| Tính năng ưu tiên cao được gắn nhãn bởi thành viên nhóm | Ưu tiên cao | +| Yêu cầu tính năng phổ biến từ [bảng phản hồi cộng đồng](https://github.com/langgenius/dify/discussions/categories/feedbacks) | Ưu tiên trung bình | +| Tính năng không cốt lõi và cải tiến nhỏ | Ưu tiên thấp | +| Có giá trị nhưng không cấp bách | Tính năng tương lai | ## Gửi PR của bạn ### Quy trình tạo Pull Request 1. Fork repository -2. Trước khi soạn PR, vui lòng tạo issue để thảo luận về các thay đổi bạn muốn thực hiện -3. Tạo nhánh mới cho các thay đổi của bạn -4. Vui lòng thêm test cho các thay đổi tương ứng -5. Đảm bảo code của bạn vượt qua các test hiện có -6. Vui lòng liên kết issue trong mô tả PR, `fixes #` -7. Được merge! +1. Trước khi soạn PR, vui lòng tạo issue để thảo luận về các thay đổi bạn muốn thực hiện +1. Tạo nhánh mới cho các thay đổi của bạn +1. Vui lòng thêm test cho các thay đổi tương ứng +1. Đảm bảo code của bạn vượt qua các test hiện có +1. Vui lòng liên kết issue trong mô tả PR, `fixes #` +1. Được merge! ### Thiết lập dự án @@ -84,6 +84,7 @@ Cách chúng tôi ưu tiên: #### Các điểm cần lưu ý khác Chúng tôi khuyến nghị xem xét kỹ tài liệu này trước khi tiến hành thiết lập, vì nó chứa thông tin thiết yếu về: + - Điều kiện tiên quyết và dependencies - Các bước cài đặt - Chi tiết cấu hình @@ -94,4 +95,3 @@ Chúng tôi khuyến nghị xem xét kỹ tài liệu này trước khi tiến h ## Nhận trợ giúp Nếu bạn bị mắc kẹt hoặc có câu hỏi cấp bách trong quá trình đóng góp, chỉ cần gửi câu hỏi của bạn thông qua issue GitHub liên quan, hoặc tham gia [Discord](https://discord.gg/8Tpq4AcN9c) của chúng tôi để trò chuyện nhanh. - diff --git a/README.md b/README.md index 2909e0e6cf..90da1d3def 100644 --- a/README.md +++ b/README.md @@ -107,74 +107,6 @@ Monitor and analyze application logs and performance over time. You could contin **7. Backend-as-a-Service**: All of Dify's offerings come with corresponding APIs, so you could effortlessly integrate Dify into your own business logic. -## Feature Comparison - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FeatureDify.AILangChainFlowiseOpenAI Assistants API
Programming ApproachAPI + App-orientedPython CodeApp-orientedAPI-oriented
Supported LLMsRich VarietyRich VarietyRich VarietyOpenAI-only
RAG Engine
Agent
Workflow
Observability
Enterprise Feature (SSO/Access control)
Local Deployment
- ## Using Dify - **Cloud
** @@ -185,7 +117,8 @@ All of Dify's offerings come with corresponding APIs, so you could effortlessly Use our [documentation](https://docs.dify.ai) for further references and more in-depth instructions. - **Dify for enterprise / organizations
** - We provide additional enterprise-centric features. [Log your questions for us through this chatbot](https://udify.app/chat/22L1zSxg6yW1cWQg) or [send us an email](mailto:business@dify.ai?subject=[GitHub]Business%20License%20Inquiry) to discuss enterprise needs.
+ We provide additional enterprise-centric features. [Log your questions for us through this chatbot](https://udify.app/chat/22L1zSxg6yW1cWQg) or [send us an email](mailto:business@dify.ai?subject=%5BGitHub%5DBusiness%20License%20Inquiry) to discuss enterprise needs.
+ > For startups and small businesses using AWS, check out [Dify Premium on AWS Marketplace](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6) and deploy it to your own AWS VPC with one click. It's an affordable AMI offering with the option to create apps with custom logo and branding. ## Staying ahead @@ -225,23 +158,27 @@ Deploy Dify to AWS with [CDK](https://aws.amazon.com/cdk/) ##### AWS -- [AWS CDK by @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) +- [AWS CDK by @KevinZhao (EKS based)](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) +- [AWS CDK by @tmokmss (ECS based)](https://github.com/aws-samples/dify-self-hosted-on-aws) #### Using Alibaba Cloud Computing Nest -Quickly deploy Dify to Alibaba cloud with [Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88) +Quickly deploy Dify to Alibaba cloud with [Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88) #### Using Alibaba Cloud Data Management -One-Click deploy Dify to Alibaba Cloud with [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/) +One-Click deploy Dify to Alibaba Cloud with [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/) +#### Deploy to AKS with Azure Devops Pipeline + +One-Click deploy Dify to AKS with [Azure Devops Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS) ## Contributing For those who'd like to contribute code, see our [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). At the same time, please consider supporting Dify by sharing it on social media and at events and conferences. -> We are looking for contributors to help translate Dify into languages other than Mandarin or English. If you are interested in helping, please see the [i18n README](https://github.com/langgenius/dify/blob/main/web/i18n/README.md) for more information, and leave us a comment in the `global-users` channel of our [Discord Community Server](https://discord.gg/8Tpq4AcN9c). +> We are looking for contributors to help translate Dify into languages other than Mandarin or English. If you are interested in helping, please see the [i18n README](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md) for more information, and leave us a comment in the `global-users` channel of our [Discord Community Server](https://discord.gg/8Tpq4AcN9c). ## Community & contact diff --git a/README_AR.md b/README_AR.md index e959ca0f78..2451757ab5 100644 --- a/README_AR.md +++ b/README_AR.md @@ -52,7 +52,7 @@ مشروع Dify هو منصة تطوير تطبيقات الذكاء الصناعي مفتوحة المصدر. تجمع واجهته البديهية بين سير العمل الذكي بالذكاء الاصطناعي وخط أنابيب RAG وقدرات الوكيل وإدارة النماذج وميزات الملاحظة وأكثر من ذلك، مما يتيح لك الانتقال بسرعة من المرحلة التجريبية إلى الإنتاج. إليك قائمة بالميزات الأساسية:

-**1. سير العمل**: قم ببناء واختبار سير عمل الذكاء الاصطناعي القوي على قماش بصري، مستفيدًا من جميع الميزات التالية وأكثر. +**1. سير العمل**: قم ببناء واختبار سير عمل الذكاء الاصطناعي القوي على قماش بصري، مستفيدًا من جميع الميزات التالية وأكثر. **2. الدعم الشامل للنماذج**: تكامل سلس مع مئات من LLMs الخاصة / مفتوحة المصدر من عشرات من موفري التحليل والحلول المستضافة ذاتيًا، مما يغطي GPT و Mistral و Llama3 وأي نماذج متوافقة مع واجهة OpenAI API. يمكن العثور على قائمة كاملة بمزودي النموذج المدعومين [هنا](https://docs.dify.ai/getting-started/readme/model-providers). @@ -68,88 +68,20 @@ **7.الواجهة الخلفية (Backend) كخدمة**: تأتي جميع عروض Dify مع APIs مطابقة، حتى يمكنك دمج Dify بسهولة في منطق أعمالك الخاص. -## مقارنة الميزات - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
الميزةDify.AILangChainFlowiseOpenAI Assistants API
نهج البرمجةموجّه لـ تطبيق + واجهة برمجة تطبيق (API)برمجة Pythonموجه لتطبيقواجهة برمجة تطبيق (API)
LLMs المدعومةتنوع غنيتنوع غنيتنوع غنيفقط OpenAI
محرك RAG
الوكيل
سير العمل
الملاحظة
ميزات الشركات (SSO / مراقبة الوصول)
نشر محلي
- ## استخدام Dify - **سحابة
** -نحن نستضيف [خدمة Dify Cloud](https://dify.ai) لأي شخص لتجربتها بدون أي إعدادات. توفر كل قدرات النسخة التي تمت استضافتها ذاتيًا، وتتضمن 200 أمر GPT-4 مجانًا في خطة الصندوق الرملي. + نحن نستضيف [خدمة Dify Cloud](https://dify.ai) لأي شخص لتجربتها بدون أي إعدادات. توفر كل قدرات النسخة التي تمت استضافتها ذاتيًا، وتتضمن 200 أمر GPT-4 مجانًا في خطة الصندوق الرملي. - **استضافة ذاتية لنسخة المجتمع Dify
** -ابدأ سريعًا في تشغيل Dify في بيئتك باستخدام [دليل البدء السريع](#البدء السريع). -استخدم [توثيقنا](https://docs.dify.ai) للمزيد من المراجع والتعليمات الأعمق. + ابدأ سريعًا في تشغيل Dify في بيئتك باستخدام \[دليل البدء السريع\](#البدء السريع). + استخدم [توثيقنا](https://docs.dify.ai) للمزيد من المراجع والتعليمات الأعمق. - **مشروع Dify للشركات / المؤسسات
** -نحن نوفر ميزات إضافية مركزة على الشركات. [جدول اجتماع معنا](https://cal.com/guchenhe/30min) أو [أرسل لنا بريدًا إلكترونيًا](mailto:business@dify.ai?subject=[GitHub]Business%20License%20Inquiry) لمناقشة احتياجات الشركات.
+ نحن نوفر ميزات إضافية مركزة على الشركات. [جدول اجتماع معنا](https://cal.com/guchenhe/30min) أو [أرسل لنا بريدًا إلكترونيًا](mailto:business@dify.ai?subject=%5BGitHub%5DBusiness%20License%20Inquiry) لمناقشة احتياجات الشركات.
> بالنسبة للشركات الناشئة والشركات الصغيرة التي تستخدم خدمات AWS، تحقق من [Dify Premium على AWS Marketplace](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6) ونشرها في شبكتك الخاصة على AWS VPC بنقرة واحدة. إنها عرض AMI بأسعار معقولة مع خيار إنشاء تطبيقات بشعار وعلامة تجارية مخصصة. -> + ## البقاء قدمًا قم بإضافة نجمة إلى Dify على GitHub وتلق تنبيهًا فوريًا بالإصدارات الجديدة. @@ -157,11 +89,11 @@ ![نجمنا](https://github.com/langgenius/dify/assets/13230914/b823edc1-6388-4e25-ad45-2f6b187adbb4) ## البداية السريعة -> + > قبل تثبيت Dify، تأكد من أن جهازك يلبي الحد الأدنى من متطلبات النظام التالية: > ->- معالج >= 2 نواة ->- ذاكرة وصول عشوائي (RAM) >= 4 جيجابايت +> - معالج >= 2 نواة +> - ذاكرة وصول عشوائي (RAM) >= 4 جيجابايت
@@ -208,22 +140,27 @@ docker compose up -d ##### AWS -- [AWS CDK بواسطة @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) +- [AWS CDK بواسطة @KevinZhao (EKS based)](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) +- [AWS CDK بواسطة @tmokmss (ECS based)](https://github.com/aws-samples/dify-self-hosted-on-aws) #### استخدام Alibaba Cloud للنشر - [بسرعة نشر Dify إلى سحابة علي بابا مع عش الحوسبة السحابية علي بابا](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88) - + +[بسرعة نشر Dify إلى سحابة علي بابا مع عش الحوسبة السحابية علي بابا](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88) + #### استخدام Alibaba Cloud Data Management للنشر انشر ​​Dify على علي بابا كلاود بنقرة واحدة باستخدام [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/) +#### استخدام Azure Devops Pipeline للنشر على AKS + +انشر Dify على AKS بنقرة واحدة باستخدام [Azure Devops Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS) ## المساهمة لأولئك الذين يرغبون في المساهمة، انظر إلى [دليل المساهمة](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) لدينا. في الوقت نفسه، يرجى النظر في دعم Dify عن طريق مشاركته على وسائل التواصل الاجتماعي وفي الفعاليات والمؤتمرات. -> نحن نبحث عن مساهمين لمساعدة في ترجمة Dify إلى لغات أخرى غير اللغة الصينية المندرين أو الإنجليزية. إذا كنت مهتمًا بالمساعدة، يرجى الاطلاع على [README للترجمة](https://github.com/langgenius/dify/blob/main/web/i18n/README.md) لمزيد من المعلومات، واترك لنا تعليقًا في قناة `global-users` على [خادم المجتمع على Discord](https://discord.gg/8Tpq4AcN9c). +> نحن نبحث عن مساهمين لمساعدة في ترجمة Dify إلى لغات أخرى غير اللغة الصينية المندرين أو الإنجليزية. إذا كنت مهتمًا بالمساعدة، يرجى الاطلاع على [README للترجمة](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md) لمزيد من المعلومات، واترك لنا تعليقًا في قناة `global-users` على [خادم المجتمع على Discord](https://discord.gg/8Tpq4AcN9c). **المساهمون** @@ -232,6 +169,7 @@ docker compose up -d ## المجتمع والاتصال + - [مناقشة GitHub](https://github.com/langgenius/dify/discussions). الأفضل لـ: مشاركة التعليقات وطرح الأسئلة. - [المشكلات على GitHub](https://github.com/langgenius/dify/issues). الأفضل لـ: الأخطاء التي تواجهها في استخدام Dify.AI، واقتراحات الميزات. انظر [دليل المساهمة](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). - [Discord](https://discord.gg/FngNHpbcY7). الأفضل لـ: مشاركة تطبيقاتك والترفيه مع المجتمع. diff --git a/README_BN.md b/README_BN.md index 29d7374ea5..ef24dea171 100644 --- a/README_BN.md +++ b/README_BN.md @@ -56,133 +56,67 @@ ডিফাই একটি ওপেন-সোর্স LLM অ্যাপ ডেভেলপমেন্ট প্ল্যাটফর্ম। এটি ইন্টুইটিভ ইন্টারফেস, এজেন্টিক AI ওয়ার্কফ্লো, RAG পাইপলাইন, এজেন্ট ক্যাপাবিলিটি, মডেল ম্যানেজমেন্ট, মনিটরিং সুবিধা এবং আরও অনেক কিছু একত্রিত করে, যা দ্রুত প্রোটোটাইপ থেকে প্রোডাকশন পর্যন্ত নিয়ে যেতে সহায়তা করে। ## কুইক স্টার্ট + +> ডিফাই ইনস্টল করার আগে, নিশ্চিত করুন যে আপনার মেশিন নিম্নলিখিত ন্যূনতম কনফিগারেশনের প্রয়োজনীয়তা পূরন করে : > -> ডিফাই ইনস্টল করার আগে, নিশ্চিত করুন যে আপনার মেশিন নিম্নলিখিত ন্যূনতম কনফিগারেশনের প্রয়োজনীয়তা পূরন করে : -> ->- সিপিউ >= 2 কোর ->- র‍্যাম >= 4 জিবি +> - সিপিউ >= 2 কোর +> - র‍্যাম >= 4 জিবি
ডিফাই সার্ভার চালু করার সবচেয়ে সহজ উপায় [docker compose](docker/docker-compose.yaml) মাধ্যমে। নিম্নলিখিত কমান্ডগুলো ব্যবহার করে ডিফাই চালানোর আগে, নিশ্চিত করুন যে আপনার মেশিনে [Docker](https://docs.docker.com/get-docker/) এবং [Docker Compose](https://docs.docker.com/compose/install/) ইনস্টল করা আছে : + ```bash cd dify cd docker cp .env.example .env docker compose up -d ``` + চালানোর পর, আপনি আপনার ব্রাউজারে [http://localhost/install](http://localhost/install)-এ ডিফাই ড্যাশবোর্ডে অ্যাক্সেস করতে পারেন এবং ইনিশিয়ালাইজেশন প্রক্রিয়া শুরু করতে পারেন। #### সাহায্যের খোঁজে -ডিফাই সেট আপ করতে সমস্যা হলে দয়া করে আমাদের [FAQ](https://docs.dify.ai/getting-started/install-self-hosted/faqs) দেখুন। যদি তবুও সমস্যা থেকে থাকে, তাহলে [কমিউনিটি এবং আমাদের](#community--contact) সাথে যোগাযোগ করুন। +ডিফাই সেট আপ করতে সমস্যা হলে দয়া করে আমাদের [FAQ](https://docs.dify.ai/getting-started/install-self-hosted/faqs) দেখুন। যদি তবুও সমস্যা থেকে থাকে, তাহলে [কমিউনিটি এবং আমাদের](#community--contact) সাথে যোগাযোগ করুন। > যদি আপনি ডিফাইতে অবদান রাখতে বা অতিরিক্ত উন্নয়ন করতে চান, আমাদের [সোর্স কোড থেকে ডিপ্লয়মেন্টের গাইড](https://docs.dify.ai/getting-started/install-self-hosted/local-source-code) দেখুন। ## প্রধান ফিচারসমূহ **১. ওয়ার্কফ্লো**: - ভিজ্যুয়াল ক্যানভাসে AI ওয়ার্কফ্লো তৈরি এবং পরীক্ষা করুন, নিম্নলিখিত সব ফিচার এবং তার বাইরেও আরও অনেক কিছু ব্যবহার করে। +ভিজ্যুয়াল ক্যানভাসে AI ওয়ার্কফ্লো তৈরি এবং পরীক্ষা করুন, নিম্নলিখিত সব ফিচার এবং তার বাইরেও আরও অনেক কিছু ব্যবহার করে। -**২. মডেল সাপোর্ট**: - GPT, Mistral, Llama3, এবং যেকোনো OpenAI API-সামঞ্জস্যপূর্ণ মডেলসহ, কয়েক ডজন ইনফারেন্স প্রদানকারী এবং সেল্ফ-হোস্টেড সমাধান থেকে শুরু করে প্রোপ্রাইটরি/ওপেন-সোর্স LLM-এর সাথে সহজে ইন্টিগ্রেশন। সমর্থিত মডেল প্রদানকারীদের একটি সম্পূর্ণ তালিকা পাওয়া যাবে [এখানে](https://docs.dify.ai/getting-started/readme/model-providers)। +**২. মডেল সাপোর্ট**: +GPT, Mistral, Llama3, এবং যেকোনো OpenAI API-সামঞ্জস্যপূর্ণ মডেলসহ, কয়েক ডজন ইনফারেন্স প্রদানকারী এবং সেল্ফ-হোস্টেড সমাধান থেকে শুরু করে প্রোপ্রাইটরি/ওপেন-সোর্স LLM-এর সাথে সহজে ইন্টিগ্রেশন। সমর্থিত মডেল প্রদানকারীদের একটি সম্পূর্ণ তালিকা পাওয়া যাবে [এখানে](https://docs.dify.ai/getting-started/readme/model-providers)। ![providers-v5](https://github.com/langgenius/dify/assets/13230914/5a17bdbe-097a-4100-8363-40255b70f6e3) -**3. প্রম্পট IDE**: - প্রম্পট তৈরি, মডেলের পারফরম্যান্স তুলনা এবং চ্যাট-বেজড অ্যাপে টেক্সট-টু-স্পিচের মতো বৈশিষ্ট্য যুক্ত করার জন্য ইন্টুইটিভ ইন্টারফেস। +**3. প্রম্পট IDE**: +প্রম্পট তৈরি, মডেলের পারফরম্যান্স তুলনা এবং চ্যাট-বেজড অ্যাপে টেক্সট-টু-স্পিচের মতো বৈশিষ্ট্য যুক্ত করার জন্য ইন্টুইটিভ ইন্টারফেস। **4. RAG পাইপলাইন**: - ডকুমেন্ট ইনজেশন থেকে শুরু করে রিট্রিভ পর্যন্ত সবকিছুই বিস্তৃত RAG ক্যাপাবিলিটির আওতাভুক্ত। PDF, PPT এবং অন্যান্য সাধারণ ডকুমেন্ট ফর্ম্যাট থেকে টেক্সট এক্সট্রাকশনের জন্য আউট-অফ-বক্স সাপোর্ট। +ডকুমেন্ট ইনজেশন থেকে শুরু করে রিট্রিভ পর্যন্ত সবকিছুই বিস্তৃত RAG ক্যাপাবিলিটির আওতাভুক্ত। PDF, PPT এবং অন্যান্য সাধারণ ডকুমেন্ট ফর্ম্যাট থেকে টেক্সট এক্সট্রাকশনের জন্য আউট-অফ-বক্স সাপোর্ট। -**5. এজেন্ট ক্যাপাবিলিটি**: - LLM ফাংশন কলিং বা ReAct উপর ভিত্তি করে এজেন্ট ডিফাইন করতে পারেন এবং এজেন্টের জন্য পূর্ব-নির্মিত বা কাস্টম টুলস যুক্ত করতে পারেন। Dify AI এজেন্টদের জন্য 50+ বিল্ট-ইন টুলস সরবরাহ করে, যেমন Google Search, DALL·E, Stable Diffusion এবং WolframAlpha। +**5. এজেন্ট ক্যাপাবিলিটি**: +LLM ফাংশন কলিং বা ReAct উপর ভিত্তি করে এজেন্ট ডিফাইন করতে পারেন এবং এজেন্টের জন্য পূর্ব-নির্মিত বা কাস্টম টুলস যুক্ত করতে পারেন। Dify AI এজেন্টদের জন্য 50+ বিল্ট-ইন টুলস সরবরাহ করে, যেমন Google Search, DALL·E, Stable Diffusion এবং WolframAlpha। -**6. এলএলএম-অপ্স**: - সময়ের সাথে সাথে অ্যাপ্লিকেশন লগ এবং পারফরম্যান্স মনিটর এবং বিশ্লেষণ করুন। প্রডাকশন ডেটা এবং annotation এর উপর ভিত্তি করে প্রম্পট, ডেটাসেট এবং মডেলগুলিকে ক্রমাগত উন্নত করতে পারেন। +**6. এলএলএম-অপ্স**: +সময়ের সাথে সাথে অ্যাপ্লিকেশন লগ এবং পারফরম্যান্স মনিটর এবং বিশ্লেষণ করুন। প্রডাকশন ডেটা এবং annotation এর উপর ভিত্তি করে প্রম্পট, ডেটাসেট এবং মডেলগুলিকে ক্রমাগত উন্নত করতে পারেন। **7. ব্যাকএন্ড-অ্যাজ-এ-সার্ভিস**: - ডিফাই-এর সমস্ত অফার সংশ্লিষ্ট API-সহ আছে, যাতে আপনি অনায়াসে ডিফাইকে আপনার নিজস্ব বিজনেস লজিকে ইন্টেগ্রেট করতে পারেন। +ডিফাই-এর সমস্ত অফার সংশ্লিষ্ট API-সহ আছে, যাতে আপনি অনায়াসে ডিফাইকে আপনার নিজস্ব বিজনেস লজিকে ইন্টেগ্রেট করতে পারেন। -## বৈশিষ্ট্য তুলনা - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
বৈশিষ্ট্যDify.AILangChainFlowiseOpenAI Assistants API
প্রোগ্রামিং পদ্ধতিAPI + App-orientedPython CodeApp-orientedAPI-oriented
সাপোর্টেড LLMsRich VarietyRich VarietyRich VarietyOpenAI-only
RAG ইঞ্জিন
এজেন্ট
ওয়ার্কফ্লো
অবজার্ভেবল
এন্টারপ্রাইজ ফিচার (SSO/Access control)
লোকাল ডেপ্লয়মেন্ট
- -## ডিফাই-এর ব্যবহার +## ডিফাই-এর ব্যবহার - **ক্লাউড
** -জিরো সেটাপে ব্যবহার করতে আমাদের [Dify Cloud](https://dify.ai) সার্ভিসটি ব্যবহার করতে পারেন। এখানে সেল্ফহোস্টিং-এর সকল ফিচার ও ক্যাপাবিলিটিসহ স্যান্ডবক্সে ২০০ জিপিটি-৪ কল ফ্রি পাবেন। + জিরো সেটাপে ব্যবহার করতে আমাদের [Dify Cloud](https://dify.ai) সার্ভিসটি ব্যবহার করতে পারেন। এখানে সেল্ফহোস্টিং-এর সকল ফিচার ও ক্যাপাবিলিটিসহ স্যান্ডবক্সে ২০০ জিপিটি-৪ কল ফ্রি পাবেন। - **সেল্ফহোস্টিং ডিফাই কমিউনিটি সংস্করণ
** -সেল্ফহোস্ট করতে এই [স্টার্টার গাইড](#quick-start) ব্যবহার করে দ্রুত আপনার এনভায়রনমেন্টে ডিফাই চালান। -আরো ইন-ডেপথ রেফারেন্সের জন্য [ডকুমেন্টেশন](https://docs.dify.ai) দেখেন। + সেল্ফহোস্ট করতে এই [স্টার্টার গাইড](#quick-start) ব্যবহার করে দ্রুত আপনার এনভায়রনমেন্টে ডিফাই চালান। + আরো ইন-ডেপথ রেফারেন্সের জন্য [ডকুমেন্টেশন](https://docs.dify.ai) দেখেন। - **এন্টারপ্রাইজ / প্রতিষ্ঠানের জন্য Dify
** -আমরা এন্টারপ্রাইজ/প্রতিষ্ঠান-কেন্দ্রিক সেবা প্রদান করে থাকি । [এই চ্যাটবটের মাধ্যমে আপনার প্রশ্নগুলি আমাদের জন্য লগ করুন।](https://udify.app/chat/22L1zSxg6yW1cWQg) অথবা [আমাদের ইমেল পাঠান](mailto:business@dify.ai?subject=[GitHub]Business%20License%20Inquiry) আপনার চাহিদা সম্পর্কে আলোচনা করার জন্য।
+ আমরা এন্টারপ্রাইজ/প্রতিষ্ঠান-কেন্দ্রিক সেবা প্রদান করে থাকি । [এই চ্যাটবটের মাধ্যমে আপনার প্রশ্নগুলি আমাদের জন্য লগ করুন।](https://udify.app/chat/22L1zSxg6yW1cWQg) অথবা [আমাদের ইমেল পাঠান](mailto:business@dify.ai?subject=%5BGitHub%5DBusiness%20License%20Inquiry) আপনার চাহিদা সম্পর্কে আলোচনা করার জন্য।
> AWS ব্যবহারকারী স্টার্টআপ এবং ছোট ব্যবসার জন্য, [AWS মার্কেটপ্লেসে Dify Premium](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6) দেখুন এবং এক-ক্লিকের মাধ্যমে এটি আপনার নিজস্ব AWS VPC-তে ডিপ্লয় করুন। এটি একটি সাশ্রয়ী মূল্যের AMI অফার, যাতে কাস্টম লোগো এবং ব্র্যান্ডিং সহ অ্যাপ তৈরির সুবিধা আছে। @@ -194,10 +128,10 @@ GitHub-এ ডিফাইকে স্টার দিয়ে রাখুন ## Advanced Setup -যদি আপনার কনফিগারেশনটি কাস্টমাইজ করার প্রয়োজন হয়, তাহলে অনুগ্রহ করে আমাদের [.env.example](docker/.env.example) ফাইল দেখুন এবং আপনার `.env` ফাইলে সংশ্লিষ্ট মানগুলি আপডেট করুন। এছাড়াও, আপনার নির্দিষ্ট এনভায়রনমেন্ট এবং প্রয়োজনীয়তার উপর ভিত্তি করে আপনাকে `docker-compose.yaml` ফাইলে সমন্বয় করতে হতে পারে, যেমন ইমেজ ভার্সন পরিবর্তন করা, পোর্ট ম্যাপিং করা, অথবা ভলিউম মাউন্ট করা। +যদি আপনার কনফিগারেশনটি কাস্টমাইজ করার প্রয়োজন হয়, তাহলে অনুগ্রহ করে আমাদের [.env.example](docker/.env.example) ফাইল দেখুন এবং আপনার `.env` ফাইলে সংশ্লিষ্ট মানগুলি আপডেট করুন। এছাড়াও, আপনার নির্দিষ্ট এনভায়রনমেন্ট এবং প্রয়োজনীয়তার উপর ভিত্তি করে আপনাকে `docker-compose.yaml` ফাইলে সমন্বয় করতে হতে পারে, যেমন ইমেজ ভার্সন পরিবর্তন করা, পোর্ট ম্যাপিং করা, অথবা ভলিউম মাউন্ট করা। যেকোনো পরিবর্তন করার পর, অনুগ্রহ করে `docker-compose up -d` পুনরায় চালান। ভেরিয়েবলের সম্পূর্ণ তালিকা [এখানে] (https://docs.dify.ai/getting-started/install-self-hosted/environments) খুঁজে পেতে পারেন। -যদি আপনি একটি হাইলি এভেইলেবল সেটআপ কনফিগার করতে চান, তাহলে কমিউনিটি [Helm Charts](https://helm.sh/) এবং YAML ফাইল রয়েছে যা Dify কে Kubernetes-এ ডিপ্লয় করার প্রক্রিয়া বর্ণনা করে। +যদি আপনি একটি হাইলি এভেইলেবল সেটআপ কনফিগার করতে চান, তাহলে কমিউনিটি [Helm Charts](https://helm.sh/) এবং YAML ফাইল রয়েছে যা Dify কে Kubernetes-এ ডিপ্লয় করার প্রক্রিয়া বর্ণনা করে। - [Helm Chart by @LeoQuote](https://github.com/douban/charts/tree/master/charts/dify) - [Helm Chart by @BorisPolonsky](https://github.com/BorisPolonsky/dify-helm) @@ -206,7 +140,6 @@ GitHub-এ ডিফাইকে স্টার দিয়ে রাখুন - [YAML file by @wyy-holding](https://github.com/wyy-holding/dify-k8s) - [🚀 নতুন! YAML ফাইলসমূহ (Dify v1.6.0 সমর্থিত) তৈরি করেছেন @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes) - #### টেরাফর্ম ব্যবহার করে ডিপ্লয় [terraform](https://www.terraform.io/) ব্যবহার করে এক ক্লিকেই ক্লাউড প্ল্যাটফর্মে Dify ডিপ্লয় করুন। @@ -225,30 +158,34 @@ GitHub-এ ডিফাইকে স্টার দিয়ে রাখুন ##### AWS -- [AWS CDK by @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) +- [AWS CDK by @KevinZhao (EKS based)](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) +- [AWS CDK by @tmokmss (ECS based)](https://github.com/aws-samples/dify-self-hosted-on-aws) #### Alibaba Cloud ব্যবহার করে ডিপ্লয় - [Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88) +[Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88) #### Alibaba Cloud Data Management ব্যবহার করে ডিপ্লয় - [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/) +[Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/) +#### AKS-এ ডিপ্লয় করার জন্য Azure Devops Pipeline ব্যবহার + +[Azure Devops Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS) ব্যবহার করে Dify কে AKS-এ এক ক্লিকে ডিপ্লয় করুন ## Contributing যারা কোড অবদান রাখতে চান, তাদের জন্য আমাদের [অবদান নির্দেশিকা] দেখুন (https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)। একই সাথে, সোশ্যাল মিডিয়া এবং ইভেন্ট এবং কনফারেন্সে এটি শেয়ার করে Dify কে সমর্থন করুন। -> আমরা ম্যান্ডারিন বা ইংরেজি ছাড়া অন্য ভাষায় Dify অনুবাদ করতে সাহায্য করার জন্য অবদানকারীদের খুঁজছি। আপনি যদি সাহায্য করতে আগ্রহী হন, তাহলে আরও তথ্যের জন্য [i18n README](https://github.com/langgenius/dify/blob/main/web/i18n/README.md) দেখুন এবং আমাদের [ডিসকর্ড কমিউনিটি সার্ভার](https://discord.gg/8Tpq4AcN9c) এর `গ্লোবাল-ইউজারস` চ্যানেলে আমাদের একটি মন্তব্য করুন। +> আমরা ম্যান্ডারিন বা ইংরেজি ছাড়া অন্য ভাষায় Dify অনুবাদ করতে সাহায্য করার জন্য অবদানকারীদের খুঁজছি। আপনি যদি সাহায্য করতে আগ্রহী হন, তাহলে আরও তথ্যের জন্য [i18n README](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md) দেখুন এবং আমাদের [ডিসকর্ড কমিউনিটি সার্ভার](https://discord.gg/8Tpq4AcN9c) এর `গ্লোবাল-ইউজারস` চ্যানেলে আমাদের একটি মন্তব্য করুন। ## কমিউনিটি এবং যোগাযোগ - [GitHub Discussion](https://github.com/langgenius/dify/discussions) ফিডব্যাক এবং প্রতিক্রিয়া জানানোর মাধ্যম। -- [GitHub Issues](https://github.com/langgenius/dify/issues). Dify.AI ব্যবহার করে আপনি যেসব বাগের সম্মুখীন হন এবং ফিচার প্রস্তাবনা। আমাদের [অবদান নির্দেশিকা](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) দেখুন। -- [Discord](https://discord.gg/FngNHpbcY7) আপনার এপ্লিকেশন শেয়ার এবং কমিউনিটি আড্ডার মাধ্যম। -- [X(Twitter)](https://twitter.com/dify_ai) আপনার এপ্লিকেশন শেয়ার এবং কমিউনিটি আড্ডার মাধ্যম। +- [GitHub Issues](https://github.com/langgenius/dify/issues). Dify.AI ব্যবহার করে আপনি যেসব বাগের সম্মুখীন হন এবং ফিচার প্রস্তাবনা। আমাদের [অবদান নির্দেশিকা](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) দেখুন। +- [Discord](https://discord.gg/FngNHpbcY7) আপনার এপ্লিকেশন শেয়ার এবং কমিউনিটি আড্ডার মাধ্যম। +- [X(Twitter)](https://twitter.com/dify_ai) আপনার এপ্লিকেশন শেয়ার এবং কমিউনিটি আড্ডার মাধ্যম। **অবদানকারীদের তালিকা** @@ -260,7 +197,7 @@ GitHub-এ ডিফাইকে স্টার দিয়ে রাখুন [![Star History Chart](https://api.star-history.com/svg?repos=langgenius/dify&type=Date)](https://star-history.com/#langgenius/dify&Date) -## নিরাপত্তা বিষয়ক +## নিরাপত্তা বিষয়ক আপনার গোপনীয়তা রক্ষা করতে, অনুগ্রহ করে GitHub-এ নিরাপত্তা সংক্রান্ত সমস্যা পোস্ট করা এড়িয়ে চলুন। পরিবর্তে, আপনার প্রশ্নগুলি ঠিকানায় পাঠান এবং আমরা আপনাকে আরও বিস্তারিত উত্তর প্রদান করব। diff --git a/README_CN.md b/README_CN.md index 486a368c09..2949b38867 100644 --- a/README_CN.md +++ b/README_CN.md @@ -48,8 +48,7 @@ README in বাংলা - -# +#
langgenius%2Fdify | 趋势转变 @@ -58,109 +57,41 @@ Dify 是一个开源的 LLM 应用开发平台。其直观的界面结合了 AI 工作流、RAG 管道、Agent、模型管理、可观测性功能等,让您可以快速从原型到生产。以下是其核心功能列表:

-**1. 工作流**: - 在画布上构建和测试功能强大的 AI 工作流程,利用以下所有功能以及更多功能。 +**1. 工作流**: +在画布上构建和测试功能强大的 AI 工作流程,利用以下所有功能以及更多功能。 -**2. 全面的模型支持**: - 与数百种专有/开源 LLMs 以及数十种推理提供商和自托管解决方案无缝集成,涵盖 GPT、Mistral、Llama3 以及任何与 OpenAI API 兼容的模型。完整的支持模型提供商列表可在[此处](https://docs.dify.ai/getting-started/readme/model-providers)找到。 +**2. 全面的模型支持**: +与数百种专有/开源 LLMs 以及数十种推理提供商和自托管解决方案无缝集成,涵盖 GPT、Mistral、Llama3 以及任何与 OpenAI API 兼容的模型。完整的支持模型提供商列表可在[此处](https://docs.dify.ai/getting-started/readme/model-providers)找到。 ![providers-v5](https://github.com/langgenius/dify/assets/13230914/5a17bdbe-097a-4100-8363-40255b70f6e3) +**3. Prompt IDE**: +用于制作提示、比较模型性能以及向基于聊天的应用程序添加其他功能(如文本转语音)的直观界面。 -**3. Prompt IDE**: - 用于制作提示、比较模型性能以及向基于聊天的应用程序添加其他功能(如文本转语音)的直观界面。 +**4. RAG Pipeline**: +广泛的 RAG 功能,涵盖从文档摄入到检索的所有内容,支持从 PDF、PPT 和其他常见文档格式中提取文本的开箱即用的支持。 -**4. RAG Pipeline**: - 广泛的 RAG 功能,涵盖从文档摄入到检索的所有内容,支持从 PDF、PPT 和其他常见文档格式中提取文本的开箱即用的支持。 +**5. Agent 智能体**: +您可以基于 LLM 函数调用或 ReAct 定义 Agent,并为 Agent 添加预构建或自定义工具。Dify 为 AI Agent 提供了 50 多种内置工具,如谷歌搜索、DALL·E、Stable Diffusion 和 WolframAlpha 等。 -**5. Agent 智能体**: - 您可以基于 LLM 函数调用或 ReAct 定义 Agent,并为 Agent 添加预构建或自定义工具。Dify 为 AI Agent 提供了 50 多种内置工具,如谷歌搜索、DALL·E、Stable Diffusion 和 WolframAlpha 等。 +**6. LLMOps**: +随时间监视和分析应用程序日志和性能。您可以根据生产数据和标注持续改进提示、数据集和模型。 -**6. LLMOps**: - 随时间监视和分析应用程序日志和性能。您可以根据生产数据和标注持续改进提示、数据集和模型。 - -**7. 后端即服务**: - 所有 Dify 的功能都带有相应的 API,因此您可以轻松地将 Dify 集成到自己的业务逻辑中。 - - -## 功能比较 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
功能Dify.AILangChainFlowiseOpenAI Assistant API
编程方法API + 应用程序导向Python 代码应用程序导向API 导向
支持的 LLMs丰富多样丰富多样丰富多样仅限 OpenAI
RAG 引擎
Agent
工作流
可观测性
企业功能(SSO/访问控制)
本地部署
+**7. 后端即服务**: +所有 Dify 的功能都带有相应的 API,因此您可以轻松地将 Dify 集成到自己的业务逻辑中。 ## 使用 Dify - **云
** -我们提供[ Dify 云服务](https://dify.ai),任何人都可以零设置尝试。它提供了自部署版本的所有功能,并在沙盒计划中包含 200 次免费的 GPT-4 调用。 + 我们提供[ Dify 云服务](https://dify.ai),任何人都可以零设置尝试。它提供了自部署版本的所有功能,并在沙盒计划中包含 200 次免费的 GPT-4 调用。 - **自托管 Dify 社区版
** -使用这个[入门指南](#快速启动)快速在您的环境中运行 Dify。 -使用我们的[文档](https://docs.dify.ai)进行进一步的参考和更深入的说明。 + 使用这个[入门指南](#%E5%BF%AB%E9%80%9F%E5%90%AF%E5%8A%A8)快速在您的环境中运行 Dify。 + 使用我们的[文档](https://docs.dify.ai)进行进一步的参考和更深入的说明。 - **面向企业/组织的 Dify
** -我们提供额外的面向企业的功能。[给我们发送电子邮件](mailto:business@dify.ai?subject=[GitHub]Business%20License%20Inquiry)讨论企业需求。
+ 我们提供额外的面向企业的功能。[给我们发送电子邮件](mailto:business@dify.ai?subject=%5BGitHub%5DBusiness%20License%20Inquiry)讨论企业需求。
+ > 对于使用 AWS 的初创公司和中小型企业,请查看 [AWS Marketplace 上的 Dify 高级版](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6),并使用一键部署到您自己的 AWS VPC。它是一个价格实惠的 AMI 产品,提供了使用自定义徽标和品牌创建应用程序的选项。 ## 保持领先 @@ -199,31 +130,37 @@ docker compose up -d 使用 [Helm Chart](https://helm.sh/) 版本或者 Kubernetes 资源清单(YAML),可以在 Kubernetes 上部署 Dify。 - [Helm Chart by @LeoQuote](https://github.com/douban/charts/tree/master/charts/dify) + - [Helm Chart by @BorisPolonsky](https://github.com/BorisPolonsky/dify-helm) + - [Helm Chart by @magicsong](https://github.com/magicsong/ai-charts) + - [YAML 文件 by @Winson-030](https://github.com/Winson-030/dify-kubernetes) + - [YAML file by @wyy-holding](https://github.com/wyy-holding/dify-k8s) - [🚀 NEW! YAML 文件 (支持 Dify v1.6.0) by @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes) - - #### 使用 Terraform 部署 使用 [terraform](https://www.terraform.io/) 一键将 Dify 部署到云平台 ##### Azure Global + - [Azure Terraform by @nikawang](https://github.com/nikawang/dify-azure-terraform) ##### Google Cloud + - [Google Cloud Terraform by @sotazum](https://github.com/DeNA/dify-google-cloud-terraform) #### 使用 AWS CDK 部署 使用 [CDK](https://aws.amazon.com/cdk/) 将 Dify 部署到 AWS -##### AWS -- [AWS CDK by @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) +##### AWS + +- [AWS CDK by @KevinZhao (EKS based)](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) +- [AWS CDK by @tmokmss (ECS based)](https://github.com/aws-samples/dify-self-hosted-on-aws) #### 使用 阿里云计算巢 部署 @@ -233,18 +170,20 @@ docker compose up -d 使用 [阿里云数据管理DMS](https://help.aliyun.com/zh/dms/dify-in-invitational-preview) 将 Dify 一键部署到 阿里云 +#### 使用 Azure Devops Pipeline 部署到AKS + +使用[Azure Devops Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS) 将 Dify 一键部署到 AKS ## Star History [![Star History Chart](https://api.star-history.com/svg?repos=langgenius/dify&type=Date)](https://star-history.com/#langgenius/dify&Date) - ## Contributing 对于那些想要贡献代码的人,请参阅我们的[贡献指南](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)。 同时,请考虑通过社交媒体、活动和会议来支持 Dify 的分享。 -> 我们正在寻找贡献者来帮助将 Dify 翻译成除了中文和英文之外的其他语言。如果您有兴趣帮助,请参阅我们的[i18n README](https://github.com/langgenius/dify/blob/main/web/i18n/README.md)获取更多信息,并在我们的[Discord 社区服务器](https://discord.gg/8Tpq4AcN9c)的`global-users`频道中留言。 +> 我们正在寻找贡献者来帮助将 Dify 翻译成除了中文和英文之外的其他语言。如果您有兴趣帮助,请参阅我们的[i18n README](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md)获取更多信息,并在我们的[Discord 社区服务器](https://discord.gg/8Tpq4AcN9c)的`global-users`频道中留言。 **Contributors** @@ -258,10 +197,10 @@ docker compose up -d - [GitHub Discussion](https://github.com/langgenius/dify/discussions). 👉:分享您的应用程序并与社区交流。 - [GitHub Issues](https://github.com/langgenius/dify/issues)。👉:使用 Dify.AI 时遇到的错误和问题,请参阅[贡献指南](CONTRIBUTING.md)。 -- [电子邮件支持](mailto:hello@dify.ai?subject=[GitHub]Questions%20About%20Dify)。👉:关于使用 Dify.AI 的问题。 +- [电子邮件支持](mailto:hello@dify.ai?subject=%5BGitHub%5DQuestions%20About%20Dify)。👉:关于使用 Dify.AI 的问题。 - [Discord](https://discord.gg/FngNHpbcY7)。👉:分享您的应用程序并与社区交流。 - [X(Twitter)](https://twitter.com/dify_ai)。👉:分享您的应用程序并与社区交流。 -- [商业许可](mailto:business@dify.ai?subject=[GitHub]Business%20License%20Inquiry)。👉:有关商业用途许可 Dify.AI 的商业咨询。 +- [商业许可](mailto:business@dify.ai?subject=%5BGitHub%5DBusiness%20License%20Inquiry)。👉:有关商业用途许可 Dify.AI 的商业咨询。 ## 安全问题 diff --git a/README_DE.md b/README_DE.md index fce52c34c2..a593a12abf 100644 --- a/README_DE.md +++ b/README_DE.md @@ -56,10 +56,11 @@ Dify ist eine Open-Source-Plattform zur Entwicklung von LLM-Anwendungen. Ihre intuitive Benutzeroberfläche vereint agentenbasierte KI-Workflows, RAG-Pipelines, Agentenfunktionen, Modellverwaltung, Überwachungsfunktionen und mehr, sodass Sie schnell von einem Prototyp in die Produktion übergehen können. ## Schnellstart + > Bevor Sie Dify installieren, stellen Sie sicher, dass Ihr System die folgenden Mindestanforderungen erfüllt: -> ->- CPU >= 2 Core ->- RAM >= 4 GiB +> +> - CPU >= 2 Core +> - RAM >= 4 GiB
@@ -75,115 +76,48 @@ docker compose up -d Nachdem Sie den Server gestartet haben, können Sie über Ihren Browser auf das Dify Dashboard unter [http://localhost/install](http://localhost/install) zugreifen und den Initialisierungsprozess starten. #### Hilfe suchen + Bitte beachten Sie unsere [FAQ](https://docs.dify.ai/getting-started/install-self-hosted/faqs), wenn Sie Probleme bei der Einrichtung von Dify haben. Wenden Sie sich an [die Community und uns](#community--contact), falls weiterhin Schwierigkeiten auftreten. > Wenn Sie zu Dify beitragen oder zusätzliche Entwicklungen durchführen möchten, lesen Sie bitte unseren [Leitfaden zur Bereitstellung aus dem Quellcode](https://docs.dify.ai/getting-started/install-self-hosted/local-source-code). ## Wesentliche Merkmale -**1. Workflow**: - Erstellen und testen Sie leistungsstarke KI-Workflows auf einer visuellen Oberfläche, wobei Sie alle der folgenden Funktionen und darüber hinaus nutzen können. -**2. Umfassende Modellunterstützung**: - Nahtlose Integration mit Hunderten von proprietären und Open-Source-LLMs von Dutzenden Inferenzanbietern und selbstgehosteten Lösungen, die GPT, Mistral, Llama3 und alle mit der OpenAI API kompatiblen Modelle abdecken. Eine vollständige Liste der unterstützten Modellanbieter finden Sie [hier](https://docs.dify.ai/getting-started/readme/model-providers). +**1. Workflow**: +Erstellen und testen Sie leistungsstarke KI-Workflows auf einer visuellen Oberfläche, wobei Sie alle der folgenden Funktionen und darüber hinaus nutzen können. +**2. Umfassende Modellunterstützung**: +Nahtlose Integration mit Hunderten von proprietären und Open-Source-LLMs von Dutzenden Inferenzanbietern und selbstgehosteten Lösungen, die GPT, Mistral, Llama3 und alle mit der OpenAI API kompatiblen Modelle abdecken. Eine vollständige Liste der unterstützten Modellanbieter finden Sie [hier](https://docs.dify.ai/getting-started/readme/model-providers). ![providers-v5](https://github.com/langgenius/dify/assets/13230914/5a17bdbe-097a-4100-8363-40255b70f6e3) +**3. Prompt IDE**: +Intuitive Benutzeroberfläche zum Erstellen von Prompts, zum Vergleichen der Modellleistung und zum Hinzufügen zusätzlicher Funktionen wie Text-to-Speech in einer chatbasierten Anwendung. -**3. Prompt IDE**: - Intuitive Benutzeroberfläche zum Erstellen von Prompts, zum Vergleichen der Modellleistung und zum Hinzufügen zusätzlicher Funktionen wie Text-to-Speech in einer chatbasierten Anwendung. +**4. RAG Pipeline**: +Umfassende RAG-Funktionalitäten, die alles von der Dokumenteneinlesung bis zur -abfrage abdecken, mit sofort einsatzbereiter Unterstützung für die Textextraktion aus PDFs, PPTs und anderen gängigen Dokumentformaten. -**4. RAG Pipeline**: - Umfassende RAG-Funktionalitäten, die alles von der Dokumenteneinlesung bis zur -abfrage abdecken, mit sofort einsatzbereiter Unterstützung für die Textextraktion aus PDFs, PPTs und anderen gängigen Dokumentformaten. +**5. Fähigkeiten des Agenten**: +Sie können Agenten basierend auf LLM Function Calling oder ReAct definieren und vorgefertigte oder benutzerdefinierte Tools für den Agenten hinzufügen. Dify stellt über 50 integrierte Tools für KI-Agenten bereit, wie zum Beispiel Google Search, DALL·E, Stable Diffusion und WolframAlpha. -**5. Fähigkeiten des Agenten**: - Sie können Agenten basierend auf LLM Function Calling oder ReAct definieren und vorgefertigte oder benutzerdefinierte Tools für den Agenten hinzufügen. Dify stellt über 50 integrierte Tools für KI-Agenten bereit, wie zum Beispiel Google Search, DALL·E, Stable Diffusion und WolframAlpha. +**6. LLMOps**: +Überwachen und analysieren Sie Anwendungsprotokolle und die Leistung im Laufe der Zeit. Sie können kontinuierlich Prompts, Datensätze und Modelle basierend auf Produktionsdaten und Annotationen verbessern. -**6. LLMOps**: - Überwachen und analysieren Sie Anwendungsprotokolle und die Leistung im Laufe der Zeit. Sie können kontinuierlich Prompts, Datensätze und Modelle basierend auf Produktionsdaten und Annotationen verbessern. - -**7. Backend-as-a-Service**: - Alle Dify-Angebote kommen mit entsprechenden APIs, sodass Sie Dify mühelos in Ihre eigene Geschäftslogik integrieren können. - -## Vergleich der Merkmale - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FeatureDify.AILangChainFlowiseOpenAI Assistants API
Programming ApproachAPI + App-orientedPython CodeApp-orientedAPI-oriented
Supported LLMsRich VarietyRich VarietyRich VarietyOpenAI-only
RAG Engine
Agent
Workflow
Observability
Enterprise Feature (SSO/Access control)
Local Deployment
+**7. Backend-as-a-Service**: +Alle Dify-Angebote kommen mit entsprechenden APIs, sodass Sie Dify mühelos in Ihre eigene Geschäftslogik integrieren können. ## Dify verwenden - **Cloud
** -Wir hosten einen [Dify Cloud](https://dify.ai)-Service, den jeder ohne Einrichtung ausprobieren kann. Er bietet alle Funktionen der selbstgehosteten Version und beinhaltet 200 kostenlose GPT-4-Aufrufe im Sandbox-Plan. + Wir hosten einen [Dify Cloud](https://dify.ai)-Service, den jeder ohne Einrichtung ausprobieren kann. Er bietet alle Funktionen der selbstgehosteten Version und beinhaltet 200 kostenlose GPT-4-Aufrufe im Sandbox-Plan. - **Selbstgehostete Dify Community Edition
** -Starten Sie Dify schnell in Ihrer Umgebung mit diesem [Schnellstart-Leitfaden](#quick-start). Nutzen Sie unsere [Dokumentation](https://docs.dify.ai) für weiterführende Informationen und detaillierte Anweisungen. + Starten Sie Dify schnell in Ihrer Umgebung mit diesem [Schnellstart-Leitfaden](#quick-start). Nutzen Sie unsere [Dokumentation](https://docs.dify.ai) für weiterführende Informationen und detaillierte Anweisungen. - **Dify für Unternehmen / Organisationen
** -Wir bieten zusätzliche, unternehmensspezifische Funktionen. [Über diesen Chatbot können Sie uns Ihre Fragen mitteilen](https://udify.app/chat/22L1zSxg6yW1cWQg) oder [senden Sie uns eine E-Mail](mailto:business@dify.ai?subject=[GitHub]Business%20License%20Inquiry), um Ihre unternehmerischen Bedürfnisse zu besprechen.
- > Für Startups und kleine Unternehmen, die AWS nutzen, schauen Sie sich [Dify Premium on AWS Marketplace](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6) an und stellen Sie es mit nur einem Klick in Ihrer eigenen AWS VPC bereit. Es handelt sich um ein erschwingliches AMI-Angebot mit der Option, Apps mit individuellem Logo und Branding zu erstellen. + Wir bieten zusätzliche, unternehmensspezifische Funktionen. [Über diesen Chatbot können Sie uns Ihre Fragen mitteilen](https://udify.app/chat/22L1zSxg6yW1cWQg) oder [senden Sie uns eine E-Mail](mailto:business@dify.ai?subject=%5BGitHub%5DBusiness%20License%20Inquiry), um Ihre unternehmerischen Bedürfnisse zu besprechen.
+ > Für Startups und kleine Unternehmen, die AWS nutzen, schauen Sie sich [Dify Premium on AWS Marketplace](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6) an und stellen Sie es mit nur einem Klick in Ihrer eigenen AWS VPC bereit. Es handelt sich um ein erschwingliches AMI-Angebot mit der Option, Apps mit individuellem Logo und Branding zu erstellen. ## Immer einen Schritt voraus @@ -191,7 +125,6 @@ Star Dify auf GitHub und lassen Sie sich sofort über neue Releases benachrichti ![star-us](https://github.com/langgenius/dify/assets/13230914/b823edc1-6388-4e25-ad45-2f6b187adbb4) - ## Erweiterte Einstellungen Falls Sie die Konfiguration anpassen müssen, lesen Sie bitte die Kommentare in unserer [.env.example](docker/.env.example)-Datei und aktualisieren Sie die entsprechenden Werte in Ihrer `.env`-Datei. Zusätzlich müssen Sie eventuell Anpassungen an der `docker-compose.yaml`-Datei vornehmen, wie zum Beispiel das Ändern von Image-Versionen, Portzuordnungen oder Volumen-Mounts, je nach Ihrer spezifischen Einsatzumgebung und Ihren Anforderungen. Nachdem Sie Änderungen vorgenommen haben, starten Sie `docker-compose up -d` erneut. Eine vollständige Liste der verfügbaren Umgebungsvariablen finden Sie [hier](https://docs.dify.ai/getting-started/install-self-hosted/environments). @@ -210,19 +143,23 @@ Falls Sie eine hochverfügbare Konfiguration einrichten möchten, gibt es von de Stellen Sie Dify mit nur einem Klick mithilfe von [terraform](https://www.terraform.io/) auf einer Cloud-Plattform bereit. ##### Azure Global + - [Azure Terraform by @nikawang](https://github.com/nikawang/dify-azure-terraform) ##### Google Cloud + - [Google Cloud Terraform by @sotazum](https://github.com/DeNA/dify-google-cloud-terraform) #### Verwendung von AWS CDK für die Bereitstellung Bereitstellung von Dify auf AWS mit [CDK](https://aws.amazon.com/cdk/) -##### AWS -- [AWS CDK by @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) +##### AWS -#### Alibaba Cloud +- [AWS CDK by @KevinZhao (EKS based)](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) +- [AWS CDK by @tmokmss (ECS based)](https://github.com/aws-samples/dify-self-hosted-on-aws) + +#### Alibaba Cloud [Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88) @@ -230,20 +167,22 @@ Bereitstellung von Dify auf AWS mit [CDK](https://aws.amazon.com/cdk/) Ein-Klick-Bereitstellung von Dify in der Alibaba Cloud mit [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/) +#### Verwendung von Azure Devops Pipeline für AKS-Bereitstellung + +Stellen Sie Dify mit einem Klick in AKS bereit, indem Sie [Azure Devops Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS) verwenden ## Contributing Falls Sie Code beitragen möchten, lesen Sie bitte unseren [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). Gleichzeitig bitten wir Sie, Dify zu unterstützen, indem Sie es in den sozialen Medien teilen und auf Veranstaltungen und Konferenzen präsentieren. - -> Wir suchen Mitwirkende, die dabei helfen, Dify in weitere Sprachen zu übersetzen – außer Mandarin oder Englisch. Wenn Sie Interesse an einer Mitarbeit haben, lesen Sie bitte die [i18n README](https://github.com/langgenius/dify/blob/main/web/i18n/README.md) für weitere Informationen und hinterlassen Sie einen Kommentar im `global-users`-Kanal unseres [Discord Community Servers](https://discord.gg/8Tpq4AcN9c). +> Wir suchen Mitwirkende, die dabei helfen, Dify in weitere Sprachen zu übersetzen – außer Mandarin oder Englisch. Wenn Sie Interesse an einer Mitarbeit haben, lesen Sie bitte die [i18n README](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md) für weitere Informationen und hinterlassen Sie einen Kommentar im `global-users`-Kanal unseres [Discord Community Servers](https://discord.gg/8Tpq4AcN9c). ## Gemeinschaft & Kontakt -* [GitHub Discussion](https://github.com/langgenius/dify/discussions). Am besten geeignet für: den Austausch von Feedback und das Stellen von Fragen. -* [GitHub Issues](https://github.com/langgenius/dify/issues). Am besten für: Fehler, auf die Sie bei der Verwendung von Dify.AI stoßen, und Funktionsvorschläge. Siehe unseren [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). -* [Discord](https://discord.gg/FngNHpbcY7). Am besten geeignet für: den Austausch von Bewerbungen und den Austausch mit der Community. -* [X(Twitter)](https://twitter.com/dify_ai). Am besten geeignet für: den Austausch von Bewerbungen und den Austausch mit der Community. +- [GitHub Discussion](https://github.com/langgenius/dify/discussions). Am besten geeignet für: den Austausch von Feedback und das Stellen von Fragen. +- [GitHub Issues](https://github.com/langgenius/dify/issues). Am besten für: Fehler, auf die Sie bei der Verwendung von Dify.AI stoßen, und Funktionsvorschläge. Siehe unseren [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). +- [Discord](https://discord.gg/FngNHpbcY7). Am besten geeignet für: den Austausch von Bewerbungen und den Austausch mit der Community. +- [X(Twitter)](https://twitter.com/dify_ai). Am besten geeignet für: den Austausch von Bewerbungen und den Austausch mit der Community. **Mitwirkende** @@ -255,7 +194,6 @@ Falls Sie Code beitragen möchten, lesen Sie bitte unseren [Contribution Guide]( [![Star History Chart](https://api.star-history.com/svg?repos=langgenius/dify&type=Date)](https://star-history.com/#langgenius/dify&Date) - ## Offenlegung der Sicherheit Um Ihre Privatsphäre zu schützen, vermeiden Sie es bitte, Sicherheitsprobleme auf GitHub zu posten. Schicken Sie Ihre Fragen stattdessen an security@dify.ai und wir werden Ihnen eine ausführlichere Antwort geben. @@ -263,4 +201,3 @@ Um Ihre Privatsphäre zu schützen, vermeiden Sie es bitte, Sicherheitsprobleme ## Lizenz Dieses Repository steht unter der [Dify Open Source License](LICENSE), die im Wesentlichen Apache 2.0 mit einigen zusätzlichen Einschränkungen ist. - diff --git a/README_ES.md b/README_ES.md index 6fd6dfcee8..c7a18dc675 100644 --- a/README_ES.md +++ b/README_ES.md @@ -48,7 +48,7 @@ README in বাংলা

-# +#

langgenius%2Fdify | Trendshift @@ -56,111 +56,42 @@ Dify es una plataforma de desarrollo de aplicaciones de LLM de código abierto. Su interfaz intuitiva combina flujo de trabajo de IA, pipeline RAG, capacidades de agente, gestión de modelos, características de observabilidad y más, lo que le permite pasar rápidamente de un prototipo a producción. Aquí hay una lista de las características principales:

-**1. Flujo de trabajo**: - Construye y prueba potentes flujos de trabajo de IA en un lienzo visual, aprovechando todas las siguientes características y más. +**1. Flujo de trabajo**: +Construye y prueba potentes flujos de trabajo de IA en un lienzo visual, aprovechando todas las siguientes características y más. -**2. Soporte de modelos completo**: - Integración perfecta con cientos de LLMs propietarios / de código abierto de docenas de proveedores de inferencia y soluciones auto-alojadas, que cubren GPT, Mistral, Llama3 y cualquier modelo compatible con la API de OpenAI. Se puede encontrar una lista completa de proveedores de modelos admitidos [aquí](https://docs.dify.ai/getting-started/readme/model-providers). +**2. Soporte de modelos completo**: +Integración perfecta con cientos de LLMs propietarios / de código abierto de docenas de proveedores de inferencia y soluciones auto-alojadas, que cubren GPT, Mistral, Llama3 y cualquier modelo compatible con la API de OpenAI. Se puede encontrar una lista completa de proveedores de modelos admitidos [aquí](https://docs.dify.ai/getting-started/readme/model-providers). ![proveedores-v5](https://github.com/langgenius/dify/assets/13230914/5a17bdbe-097a-4100-8363-40255b70f6e3) +**3. IDE de prompt**: +Interfaz intuitiva para crear prompts, comparar el rendimiento del modelo y agregar características adicionales como texto a voz a una aplicación basada en chat. -**3. IDE de prompt**: - Interfaz intuitiva para crear prompts, comparar el rendimiento del modelo y agregar características adicionales como texto a voz a una aplicación basada en chat. +**4. Pipeline RAG**: +Amplias capacidades de RAG que cubren todo, desde la ingestión de documentos hasta la recuperación, con soporte listo para usar para la extracción de texto de PDF, PPT y otros formatos de documento comunes. -**4. Pipeline RAG**: - Amplias capacidades de RAG que cubren todo, desde la ingestión de documentos hasta la recuperación, con soporte listo para usar para la extracción de texto de PDF, PPT y otros formatos de documento comunes. +**5. Capacidades de agente**: +Puedes definir agentes basados en LLM Function Calling o ReAct, y agregar herramientas preconstruidas o personalizadas para el agente. Dify proporciona más de 50 herramientas integradas para agentes de IA, como Búsqueda de Google, DALL·E, Difusión Estable y WolframAlpha. -**5. Capacidades de agente**: - Puedes definir agentes basados en LLM Function Calling o ReAct, y agregar herramientas preconstruidas o personalizadas para el agente. Dify proporciona más de 50 herramientas integradas para agentes de IA, como Búsqueda de Google, DALL·E, Difusión Estable y WolframAlpha. +**6. LLMOps**: +Supervisa y analiza registros de aplicaciones y rendimiento a lo largo del tiempo. Podrías mejorar continuamente prompts, conjuntos de datos y modelos basados en datos de producción y anotaciones. -**6. LLMOps**: - Supervisa y analiza registros de aplicaciones y rendimiento a lo largo del tiempo. Podrías mejorar continuamente prompts, conjuntos de datos y modelos basados en datos de producción y anotaciones. - -**7. Backend como servicio**: - Todas las ofertas de Dify vienen con APIs correspondientes, por lo que podrías integrar Dify sin esfuerzo en tu propia lógica empresarial. - - -## Comparación de características - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
CaracterísticaDify.AILangChainFlowiseAPI de Asistentes de OpenAI
Enfoque de programaciónAPI + orientado a la aplicaciónCódigo PythonOrientado a la aplicaciónOrientado a la API
LLMs admitidosGran variedadGran variedadGran variedadSolo OpenAI
Motor RAG
Agente
Flujo de trabajo
Observabilidad
Característica empresarial (SSO/Control de acceso)
Implementación local
+**7. Backend como servicio**: +Todas las ofertas de Dify vienen con APIs correspondientes, por lo que podrías integrar Dify sin esfuerzo en tu propia lógica empresarial. ## Usando Dify - **Nube
** -Hospedamos un servicio [Dify Cloud](https://dify.ai) para que cualquiera lo pruebe sin configuración. Proporciona todas las capacidades de la versión autoimplementada e incluye 200 llamadas gratuitas a GPT-4 en el plan sandbox. + Hospedamos un servicio [Dify Cloud](https://dify.ai) para que cualquiera lo pruebe sin configuración. Proporciona todas las capacidades de la versión autoimplementada e incluye 200 llamadas gratuitas a GPT-4 en el plan sandbox. - **Auto-alojamiento de Dify Community Edition
** -Pon rápidamente Dify en funcionamiento en tu entorno con esta [guía de inicio rápido](#quick-start). -Usa nuestra [documentación](https://docs.dify.ai) para más referencias e instrucciones más detalladas. + Pon rápidamente Dify en funcionamiento en tu entorno con esta [guía de inicio rápido](#quick-start). + Usa nuestra [documentación](https://docs.dify.ai) para más referencias e instrucciones más detalladas. - **Dify para Empresas / Organizaciones
** -Proporcionamos características adicionales centradas en la empresa. [Envíanos un correo electrónico](mailto:business@dify.ai?subject=[GitHub]Business%20License%20Inquiry) para discutir las necesidades empresariales.
- > Para startups y pequeñas empresas que utilizan AWS, echa un vistazo a [Dify Premium en AWS Marketplace](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6) e impleméntalo en tu propio VPC de AWS con un clic. Es una AMI asequible que ofrece la opción de crear aplicaciones con logotipo y marca personalizados. + Proporcionamos características adicionales centradas en la empresa. [Envíanos un correo electrónico](mailto:business@dify.ai?subject=%5BGitHub%5DBusiness%20License%20Inquiry) para discutir las necesidades empresariales.
+ > Para startups y pequeñas empresas que utilizan AWS, echa un vistazo a [Dify Premium en AWS Marketplace](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6) e impleméntalo en tu propio VPC de AWS con un clic. Es una AMI asequible que ofrece la opción de crear aplicaciones con logotipo y marca personalizados. ## Manteniéndote al tanto @@ -168,13 +99,12 @@ Dale estrella a Dify en GitHub y serás notificado instantáneamente de las nuev ![danos estrella](https://github.com/langgenius/dify/assets/13230914/b823edc1-6388-4e25-ad45-2f6b187adbb4) - - ## Inicio Rápido + > Antes de instalar Dify, asegúrate de que tu máquina cumpla con los siguientes requisitos mínimos del sistema: -> ->- CPU >= 2 núcleos ->- RAM >= 4GB +> +> - CPU >= 2 núcleos +> - RAM >= 4GB
@@ -210,17 +140,21 @@ Si desea configurar una configuración de alta disponibilidad, la comunidad prop Despliega Dify en una plataforma en la nube con un solo clic utilizando [terraform](https://www.terraform.io/) ##### Azure Global + - [Azure Terraform por @nikawang](https://github.com/nikawang/dify-azure-terraform) ##### Google Cloud + - [Google Cloud Terraform por @sotazum](https://github.com/DeNA/dify-google-cloud-terraform) #### Usando AWS CDK para el Despliegue Despliegue Dify en AWS usando [CDK](https://aws.amazon.com/cdk/) -##### AWS -- [AWS CDK por @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) +##### AWS + +- [AWS CDK por @KevinZhao (EKS based)](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) +- [AWS CDK por @tmokmss (ECS based)](https://github.com/aws-samples/dify-self-hosted-on-aws) #### Alibaba Cloud @@ -230,14 +164,16 @@ Despliegue Dify en AWS usando [CDK](https://aws.amazon.com/cdk/) Despliega Dify en Alibaba Cloud con un solo clic con [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/) +#### Uso de Azure Devops Pipeline para implementar en AKS + +Implementa Dify en AKS con un clic usando [Azure Devops Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS) ## Contribuir -Para aquellos que deseen contribuir con código, consulten nuestra [Guía de contribución](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). +Para aquellos que deseen contribuir con código, consulten nuestra [Guía de contribución](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). Al mismo tiempo, considera apoyar a Dify compartiéndolo en redes sociales y en eventos y conferencias. - -> Estamos buscando colaboradores para ayudar con la traducción de Dify a idiomas que no sean el mandarín o el inglés. Si estás interesado en ayudar, consulta el [README de i18n](https://github.com/langgenius/dify/blob/main/web/i18n/README.md) para obtener más información y déjanos un comentario en el canal `global-users` de nuestro [Servidor de Comunidad en Discord](https://discord.gg/8Tpq4AcN9c). +> Estamos buscando colaboradores para ayudar con la traducción de Dify a idiomas que no sean el mandarín o el inglés. Si estás interesado en ayudar, consulta el [README de i18n](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md) para obtener más información y déjanos un comentario en el canal `global-users` de nuestro [Servidor de Comunidad en Discord](https://discord.gg/8Tpq4AcN9c). **Contribuidores** @@ -247,15 +183,22 @@ Al mismo tiempo, considera apoyar a Dify compartiéndolo en redes sociales y en ## Comunidad y Contacto -* [Discusión en GitHub](https://github.com/langgenius/dify/discussions). Lo mejor para: compartir comentarios y hacer preguntas. -* [Reporte de problemas en GitHub](https://github.com/langgenius/dify/issues). Lo mejor para: errores que encuentres usando Dify.AI y propuestas de características. Consulta nuestra [Guía de contribución](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). -* [Discord](https://discord.gg/FngNHpbcY7). Lo mejor para: compartir tus aplicaciones y pasar el rato con la comunidad. -* [X(Twitter)](https://twitter.com/dify_ai). Lo mejor para: compartir tus aplicaciones y pasar el rato con la comunidad. +- [Discusión en GitHub](https://github.com/langgenius/dify/discussions). Lo mejor para: compartir comentarios y hacer preguntas. +- [Reporte de problemas en GitHub](https://github.com/langgenius/dify/issues). Lo mejor para: errores que encuentres usando Dify.AI y propuestas de características. Consulta nuestra [Guía de contribución](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). +- [Discord](https://discord.gg/FngNHpbcY7). Lo mejor para: compartir tus aplicaciones y pasar el rato con la comunidad. +- [X(Twitter)](https://twitter.com/dify_ai). Lo mejor para: compartir tus aplicaciones y pasar el rato con la comunidad. ## Historial de Estrellas [![Gráfico de Historial de Estrellas](https://api.star-history.com/svg?repos=langgenius/dify&type=Date)](https://star-history.com/#langgenius/dify&Date) +## Divulgación de Seguridad + +Para proteger tu privacidad, evita publicar problemas de seguridad en GitHub. En su lugar, envía tus preguntas a security@dify.ai y te proporcionaremos una respuesta más detallada. + +## Licencia + +Este repositorio está disponible bajo la [Licencia de Código Abierto de Dify](LICENSE), que es esencialmente Apache 2.0 con algunas restricciones adicionales. ## Divulgación de Seguridad @@ -264,10 +207,3 @@ Para proteger tu privacidad, evita publicar problemas de seguridad en GitHub. En ## Licencia Este repositorio está disponible bajo la [Licencia de Código Abierto de Dify](LICENSE), que es esencialmente Apache 2.0 con algunas restricciones adicionales. -## Divulgación de Seguridad - -Para proteger tu privacidad, evita publicar problemas de seguridad en GitHub. En su lugar, envía tus preguntas a security@dify.ai y te proporcionaremos una respuesta más detallada. - -## Licencia - -Este repositorio está disponible bajo la [Licencia de Código Abierto de Dify](LICENSE), que es esencialmente Apache 2.0 con algunas restricciones adicionales. diff --git a/README_FR.md b/README_FR.md index b2209fb495..316d50c929 100644 --- a/README_FR.md +++ b/README_FR.md @@ -48,7 +48,7 @@ README in বাংলা

-# +#

langgenius%2Fdify | Trendshift @@ -56,111 +56,42 @@ Dify est une plateforme de développement d'applications LLM open source. Son interface intuitive combine un flux de travail d'IA, un pipeline RAG, des capacités d'agent, une gestion de modèles, des fonctionnalités d'observabilité, et plus encore, vous permettant de passer rapidement du prototype à la production. Voici une liste des fonctionnalités principales:

-**1. Flux de travail** : - Construisez et testez des flux de travail d'IA puissants sur un canevas visuel, en utilisant toutes les fonctionnalités suivantes et plus encore. +**1. Flux de travail** : +Construisez et testez des flux de travail d'IA puissants sur un canevas visuel, en utilisant toutes les fonctionnalités suivantes et plus encore. -**2. Prise en charge complète des modèles** : - Intégration transparente avec des centaines de LLM propriétaires / open source provenant de dizaines de fournisseurs d'inférence et de solutions auto-hébergées, couvrant GPT, Mistral, Llama3, et tous les modèles compatibles avec l'API OpenAI. Une liste complète des fournisseurs de modèles pris en charge se trouve [ici](https://docs.dify.ai/getting-started/readme/model-providers). +**2. Prise en charge complète des modèles** : +Intégration transparente avec des centaines de LLM propriétaires / open source provenant de dizaines de fournisseurs d'inférence et de solutions auto-hébergées, couvrant GPT, Mistral, Llama3, et tous les modèles compatibles avec l'API OpenAI. Une liste complète des fournisseurs de modèles pris en charge se trouve [ici](https://docs.dify.ai/getting-started/readme/model-providers). ![providers-v5](https://github.com/langgenius/dify/assets/13230914/5a17bdbe-097a-4100-8363-40255b70f6e3) +**3. IDE de prompt** : +Interface intuitive pour créer des prompts, comparer les performances des modèles et ajouter des fonctionnalités supplémentaires telles que la synthèse vocale à une application basée sur des chats. -**3. IDE de prompt** : - Interface intuitive pour créer des prompts, comparer les performances des modèles et ajouter des fonctionnalités supplémentaires telles que la synthèse vocale à une application basée sur des chats. +**4. Pipeline RAG** : +Des capacités RAG étendues qui couvrent tout, de l'ingestion de documents à la récupération, avec un support prêt à l'emploi pour l'extraction de texte à partir de PDF, PPT et autres formats de document courants. -**4. Pipeline RAG** : - Des capacités RAG étendues qui couvrent tout, de l'ingestion de documents à la récupération, avec un support prêt à l'emploi pour l'extraction de texte à partir de PDF, PPT et autres formats de document courants. +**5. Capacités d'agent** : +Vous pouvez définir des agents basés sur l'appel de fonction LLM ou ReAct, et ajouter des outils pré-construits ou personnalisés pour l'agent. Dify fournit plus de 50 outils intégrés pour les agents d'IA, tels que la recherche Google, DALL·E, Stable Diffusion et WolframAlpha. -**5. Capacités d'agent** : - Vous pouvez définir des agents basés sur l'appel de fonction LLM ou ReAct, et ajouter des outils pré-construits ou personnalisés pour l'agent. Dify fournit plus de 50 outils intégrés pour les agents d'IA, tels que la recherche Google, DALL·E, Stable Diffusion et WolframAlpha. +**6. LLMOps** : +Surveillez et analysez les journaux d'application et les performances au fil du temps. Vous pouvez continuellement améliorer les prompts, les ensembles de données et les modèles en fonction des données de production et des annotations. -**6. LLMOps** : - Surveillez et analysez les journaux d'application et les performances au fil du temps. Vous pouvez continuellement améliorer les prompts, les ensembles de données et les modèles en fonction des données de production et des annotations. - -**7. Backend-as-a-Service** : - Toutes les offres de Dify sont accompagnées d'API correspondantes, vous permettant d'intégrer facilement Dify dans votre propre logique métier. - - -## Comparaison des fonctionnalités - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FonctionnalitéDify.AILangChainFlowiseOpenAI Assistants API
Approche de programmationAPI + ApplicationCode PythonApplicationAPI
LLMs pris en chargeGrande variétéGrande variétéGrande variétéUniquement OpenAI
Moteur RAG
Agent
Flux de travail
Observabilité
Fonctionnalité d'entreprise (SSO/Contrôle d'accès)
Déploiement local
+**7. Backend-as-a-Service** : +Toutes les offres de Dify sont accompagnées d'API correspondantes, vous permettant d'intégrer facilement Dify dans votre propre logique métier. ## Utiliser Dify - **Cloud
** -Nous hébergeons un service [Dify Cloud](https://dify.ai) pour que tout le monde puisse l'essayer sans aucune configuration. Il fournit toutes les capacités de la version auto-hébergée et comprend 200 appels GPT-4 gratuits dans le plan bac à sable. + Nous hébergeons un service [Dify Cloud](https://dify.ai) pour que tout le monde puisse l'essayer sans aucune configuration. Il fournit toutes les capacités de la version auto-hébergée et comprend 200 appels GPT-4 gratuits dans le plan bac à sable. - **Auto-hébergement Dify Community Edition
** -Lancez rapidement Dify dans votre environnement avec ce [guide de démarrage](#quick-start). -Utilisez notre [documentation](https://docs.dify.ai) pour plus de références et des instructions plus détaillées. + Lancez rapidement Dify dans votre environnement avec ce [guide de démarrage](#quick-start). + Utilisez notre [documentation](https://docs.dify.ai) pour plus de références et des instructions plus détaillées. - **Dify pour les entreprises / organisations
** -Nous proposons des fonctionnalités supplémentaires adaptées aux entreprises. [Envoyez-nous un e-mail](mailto:business@dify.ai?subject=[GitHub]Business%20License%20Inquiry) pour discuter des besoins de l'entreprise.
- > Pour les startups et les petites entreprises utilisant AWS, consultez [Dify Premium sur AWS Marketplace](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6) et déployez-le dans votre propre VPC AWS en un clic. C'est une offre AMI abordable avec la possibilité de créer des applications avec un logo et une marque personnalisés. + Nous proposons des fonctionnalités supplémentaires adaptées aux entreprises. [Envoyez-nous un e-mail](mailto:business@dify.ai?subject=%5BGitHub%5DBusiness%20License%20Inquiry) pour discuter des besoins de l'entreprise.
+ > Pour les startups et les petites entreprises utilisant AWS, consultez [Dify Premium sur AWS Marketplace](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6) et déployez-le dans votre propre VPC AWS en un clic. C'est une offre AMI abordable avec la possibilité de créer des applications avec un logo et une marque personnalisés. ## Rester en avance @@ -168,13 +99,12 @@ Mettez une étoile à Dify sur GitHub et soyez instantanément informé des nouv ![star-us](https://github.com/langgenius/dify/assets/13230914/b823edc1-6388-4e25-ad45-2f6b187adbb4) - - ## Démarrage rapide + > Avant d'installer Dify, assurez-vous que votre machine répond aux exigences système minimales suivantes: -> ->- CPU >= 2 cœurs ->- RAM >= 4 Go +> +> - CPU >= 2 cœurs +> - RAM >= 4 Go
@@ -208,17 +138,21 @@ Si vous souhaitez configurer une configuration haute disponibilité, la communau Déployez Dify sur une plateforme cloud en un clic en utilisant [terraform](https://www.terraform.io/) ##### Azure Global + - [Azure Terraform par @nikawang](https://github.com/nikawang/dify-azure-terraform) ##### Google Cloud + - [Google Cloud Terraform par @sotazum](https://github.com/DeNA/dify-google-cloud-terraform) #### Utilisation d'AWS CDK pour le déploiement Déployez Dify sur AWS en utilisant [CDK](https://aws.amazon.com/cdk/) -##### AWS -- [AWS CDK par @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) +##### AWS + +- [AWS CDK par @KevinZhao (EKS based)](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) +- [AWS CDK par @tmokmss (ECS based)](https://github.com/aws-samples/dify-self-hosted-on-aws) #### Alibaba Cloud @@ -228,14 +162,16 @@ Déployez Dify sur AWS en utilisant [CDK](https://aws.amazon.com/cdk/) Déployez Dify en un clic sur Alibaba Cloud avec [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/) +#### Utilisation d'Azure Devops Pipeline pour déployer sur AKS + +Déployez Dify sur AKS en un clic en utilisant [Azure Devops Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS) ## Contribuer -Pour ceux qui souhaitent contribuer du code, consultez notre [Guide de contribution](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). +Pour ceux qui souhaitent contribuer du code, consultez notre [Guide de contribution](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). Dans le même temps, veuillez envisager de soutenir Dify en le partageant sur les réseaux sociaux et lors d'événements et de conférences. - -> Nous recherchons des contributeurs pour aider à traduire Dify dans des langues autres que le mandarin ou l'anglais. Si vous êtes intéressé à aider, veuillez consulter le [README i18n](https://github.com/langgenius/dify/blob/main/web/i18n/README.md) pour plus d'informations, et laissez-nous un commentaire dans le canal `global-users` de notre [Serveur communautaire Discord](https://discord.gg/8Tpq4AcN9c). +> Nous recherchons des contributeurs pour aider à traduire Dify dans des langues autres que le mandarin ou l'anglais. Si vous êtes intéressé à aider, veuillez consulter le [README i18n](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md) pour plus d'informations, et laissez-nous un commentaire dans le canal `global-users` de notre [Serveur communautaire Discord](https://discord.gg/8Tpq4AcN9c). **Contributeurs** @@ -245,15 +181,22 @@ Dans le même temps, veuillez envisager de soutenir Dify en le partageant sur le ## Communauté & Contact -* [Discussion GitHub](https://github.com/langgenius/dify/discussions). Meilleur pour: partager des commentaires et poser des questions. -* [Problèmes GitHub](https://github.com/langgenius/dify/issues). Meilleur pour: les bogues que vous rencontrez en utilisant Dify.AI et les propositions de fonctionnalités. Consultez notre [Guide de contribution](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). -* [Discord](https://discord.gg/FngNHpbcY7). Meilleur pour: partager vos applications et passer du temps avec la communauté. -* [X(Twitter)](https://twitter.com/dify_ai). Meilleur pour: partager vos applications et passer du temps avec la communauté. +- [Discussion GitHub](https://github.com/langgenius/dify/discussions). Meilleur pour: partager des commentaires et poser des questions. +- [Problèmes GitHub](https://github.com/langgenius/dify/issues). Meilleur pour: les bogues que vous rencontrez en utilisant Dify.AI et les propositions de fonctionnalités. Consultez notre [Guide de contribution](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). +- [Discord](https://discord.gg/FngNHpbcY7). Meilleur pour: partager vos applications et passer du temps avec la communauté. +- [X(Twitter)](https://twitter.com/dify_ai). Meilleur pour: partager vos applications et passer du temps avec la communauté. ## Historique des étoiles [![Graphique de l'historique des étoiles](https://api.star-history.com/svg?repos=langgenius/dify&type=Date)](https://star-history.com/#langgenius/dify&Date) +## Divulgation de sécurité + +Pour protéger votre vie privée, veuillez éviter de publier des problèmes de sécurité sur GitHub. Au lieu de cela, envoyez vos questions à security@dify.ai et nous vous fournirons une réponse plus détaillée. + +## Licence + +Ce référentiel est disponible sous la [Licence open source Dify](LICENSE), qui est essentiellement l'Apache 2.0 avec quelques restrictions supplémentaires. ## Divulgation de sécurité @@ -262,10 +205,3 @@ Pour protéger votre vie privée, veuillez éviter de publier des problèmes de ## Licence Ce référentiel est disponible sous la [Licence open source Dify](LICENSE), qui est essentiellement l'Apache 2.0 avec quelques restrictions supplémentaires. -## Divulgation de sécurité - -Pour protéger votre vie privée, veuillez éviter de publier des problèmes de sécurité sur GitHub. Au lieu de cela, envoyez vos questions à security@dify.ai et nous vous fournirons une réponse plus détaillée. - -## Licence - -Ce référentiel est disponible sous la [Licence open source Dify](LICENSE), qui est essentiellement l'Apache 2.0 avec quelques restrictions supplémentaires. diff --git a/README_JA.md b/README_JA.md index c658225f90..785706a88a 100644 --- a/README_JA.md +++ b/README_JA.md @@ -48,7 +48,7 @@ README in বাংলা

-# +#

langgenius%2Fdify | Trendshift @@ -58,110 +58,41 @@ DifyはオープンソースのLLMアプリケーション開発プラットフ

**1. ワークフロー**: - 強力なAIワークフローをビジュアルキャンバス上で構築し、テストできます。すべての機能、および以下の機能を使用できます。 +強力なAIワークフローをビジュアルキャンバス上で構築し、テストできます。すべての機能、および以下の機能を使用できます。 **2. 総合的なモデルサポート**: - 数百ものプロプライエタリ/オープンソースのLLMと、数十もの推論プロバイダーおよびセルフホスティングソリューションとのシームレスな統合を提供します。GPT、Mistral、Llama3、OpenAI APIと互換性のあるすべてのモデルを統合されています。サポートされているモデルプロバイダーの完全なリストは[こちら](https://docs.dify.ai/getting-started/readme/model-providers)をご覧ください。 +数百ものプロプライエタリ/オープンソースのLLMと、数十もの推論プロバイダーおよびセルフホスティングソリューションとのシームレスな統合を提供します。GPT、Mistral、Llama3、OpenAI APIと互換性のあるすべてのモデルを統合されています。サポートされているモデルプロバイダーの完全なリストは[こちら](https://docs.dify.ai/getting-started/readme/model-providers)をご覧ください。 ![providers-v5](https://github.com/langgenius/dify/assets/13230914/5a17bdbe-097a-4100-8363-40255b70f6e3) - **3. プロンプトIDE**: - プロンプトの作成、モデルパフォーマンスの比較が行え、チャットベースのアプリに音声合成などの機能も追加できます。 +プロンプトの作成、モデルパフォーマンスの比較が行え、チャットベースのアプリに音声合成などの機能も追加できます。 **4. RAGパイプライン**: - ドキュメントの取り込みから検索までをカバーする広範なRAG機能ができます。ほかにもPDF、PPT、その他の一般的なドキュメントフォーマットからのテキスト抽出のサポートも提供します。 +ドキュメントの取り込みから検索までをカバーする広範なRAG機能ができます。ほかにもPDF、PPT、その他の一般的なドキュメントフォーマットからのテキスト抽出のサポートも提供します。 **5. エージェント機能**: - LLM Function CallingやReActに基づくエージェントの定義が可能で、AIエージェント用のプリビルトまたはカスタムツールを追加できます。Difyには、Google検索、DALL·E、Stable Diffusion、WolframAlphaなどのAIエージェント用の50以上の組み込みツールが提供します。 +LLM Function CallingやReActに基づくエージェントの定義が可能で、AIエージェント用のプリビルトまたはカスタムツールを追加できます。Difyには、Google検索、DALL·E、Stable Diffusion、WolframAlphaなどのAIエージェント用の50以上の組み込みツールが提供します。 **6. LLMOps**: - アプリケーションのログやパフォーマンスを監視と分析し、生産のデータと注釈に基づいて、プロンプト、データセット、モデルを継続的に改善できます。 +アプリケーションのログやパフォーマンスを監視と分析し、生産のデータと注釈に基づいて、プロンプト、データセット、モデルを継続的に改善できます。 **7. Backend-as-a-Service**: - すべての機能はAPIを提供されており、Difyを自分のビジネスロジックに簡単に統合できます。 - - -## 機能比較 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
機能Dify.AILangChainFlowiseOpenAI Assistants API
プログラミングアプローチAPI + アプリ指向Pythonコードアプリ指向API指向
サポートされているLLMバラエティ豊かバラエティ豊かバラエティ豊かOpenAIのみ
RAGエンジン
エージェント
ワークフロー
観測性
エンタープライズ機能(SSO/アクセス制御)
ローカル展開
+すべての機能はAPIを提供されており、Difyを自分のビジネスロジックに簡単に統合できます。 ## Difyの使用方法 - **クラウド
** -[こちら](https://dify.ai)のDify Cloudサービスを利用して、セットアップ不要で試すことができます。サンドボックスプランには、200回のGPT-4呼び出しが無料で含まれています。 + [こちら](https://dify.ai)のDify Cloudサービスを利用して、セットアップ不要で試すことができます。サンドボックスプランには、200回のGPT-4呼び出しが無料で含まれています。 - **Dify Community Editionのセルフホスティング
** -この[スタートガイド](#クイックスタート)を使用して、ローカル環境でDifyを簡単に実行できます。 -詳しくは[ドキュメント](https://docs.dify.ai)をご覧ください。 + この[スタートガイド](#%E3%82%AF%E3%82%A4%E3%83%83%E3%82%AF%E3%82%B9%E3%82%BF%E3%83%BC%E3%83%88)を使用して、ローカル環境でDifyを簡単に実行できます。 + 詳しくは[ドキュメント](https://docs.dify.ai)をご覧ください。 - **企業/組織向けのDify
** -企業中心の機能を提供しています。[メールを送信](mailto:business@dify.ai?subject=[GitHub]Business%20License%20Inquiry)して企業のニーズについて相談してください。
- > AWSを使用しているスタートアップ企業や中小企業の場合は、[AWS Marketplace](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6)のDify Premiumをチェックして、ワンクリックで自分のAWS VPCにデプロイできます。さらに、手頃な価格のAMIオファリングとして、ロゴやブランディングをカスタマイズしてアプリケーションを作成するオプションがあります。 + 企業中心の機能を提供しています。[メールを送信](mailto:business@dify.ai?subject=%5BGitHub%5DBusiness%20License%20Inquiry)して企業のニーズについて相談してください。
+ > AWSを使用しているスタートアップ企業や中小企業の場合は、[AWS Marketplace](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6)のDify Premiumをチェックして、ワンクリックで自分のAWS VPCにデプロイできます。さらに、手頃な価格のAMIオファリングとして、ロゴやブランディングをカスタマイズしてアプリケーションを作成するオプションがあります。 ## 最新の情報を入手 @@ -169,13 +100,12 @@ GitHub上でDifyにスターを付けることで、Difyに関する新しいニ ![star-us](https://github.com/langgenius/dify/assets/13230914/b823edc1-6388-4e25-ad45-2f6b187adbb4) - - ## クイックスタート + > Difyをインストールする前に、お使いのマシンが以下の最小システム要件を満たしていることを確認してください: > ->- CPU >= 2コア ->- RAM >= 4GB +> - CPU >= 2コア +> - RAM >= 4GB
@@ -209,9 +139,11 @@ docker compose up -d [terraform](https://www.terraform.io/) を使用して、ワンクリックでDifyをクラウドプラットフォームにデプロイします ##### Azure Global + - [@nikawangによるAzure Terraform](https://github.com/nikawang/dify-azure-terraform) ##### Google Cloud + - [@sotazumによるGoogle Cloud Terraform](https://github.com/DeNA/dify-google-cloud-terraform) #### AWS CDK を使用したデプロイ @@ -219,22 +151,28 @@ docker compose up -d [CDK](https://aws.amazon.com/cdk/) を使用して、DifyをAWSにデプロイします ##### AWS -- [@KevinZhaoによるAWS CDK](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) + +- [@KevinZhaoによるAWS CDK (EKS based)](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) +- [@tmokmssによるAWS CDK (ECS based)](https://github.com/aws-samples/dify-self-hosted-on-aws) #### Alibaba Cloud + [Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88) #### Alibaba Cloud Data Management + [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/) を利用して、DifyをAlibaba Cloudへワンクリックでデプロイできます +#### AKSへのデプロイにAzure Devops Pipelineを使用 + +[Azure Devops Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS)を使用してDifyをAKSにワンクリックでデプロイ ## 貢献 コードに貢献したい方は、[Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)を参照してください。 同時に、DifyをSNSやイベント、カンファレンスで共有してサポートしていただけると幸いです。 - -> Difyを英語または中国語以外の言語に翻訳してくれる貢献者を募集しています。興味がある場合は、詳細については[i18n README](https://github.com/langgenius/dify/blob/main/web/i18n/README.md)を参照してください。また、[Discordコミュニティサーバー](https://discord.gg/8Tpq4AcN9c)の`global-users`チャンネルにコメントを残してください。 +> Difyを英語または中国語以外の言語に翻訳してくれる貢献者を募集しています。興味がある場合は、詳細については[i18n README](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md)を参照してください。また、[Discordコミュニティサーバー](https://discord.gg/8Tpq4AcN9c)の`global-users`チャンネルにコメントを残してください。 **貢献者** @@ -244,12 +182,10 @@ docker compose up -d ## コミュニティ & お問い合わせ -* [GitHub Discussion](https://github.com/langgenius/dify/discussions). 主に: フィードバックの共有や質問。 -* [GitHub Issues](https://github.com/langgenius/dify/issues). 主に: Dify.AIを使用する際に発生するエラーや問題については、[貢献ガイド](CONTRIBUTING_JA.md)を参照してください -* [Discord](https://discord.gg/FngNHpbcY7). 主に: アプリケーションの共有やコミュニティとの交流。 -* [X(Twitter)](https://twitter.com/dify_ai). 主に: アプリケーションの共有やコミュニティとの交流。 - - +- [GitHub Discussion](https://github.com/langgenius/dify/discussions). 主に: フィードバックの共有や質問。 +- [GitHub Issues](https://github.com/langgenius/dify/issues). 主に: Dify.AIを使用する際に発生するエラーや問題については、[貢献ガイド](CONTRIBUTING_JA.md)を参照してください +- [Discord](https://discord.gg/FngNHpbcY7). 主に: アプリケーションの共有やコミュニティとの交流。 +- [X(Twitter)](https://twitter.com/dify_ai). 主に: アプリケーションの共有やコミュニティとの交流。 ## ライセンス diff --git a/README_KL.md b/README_KL.md index bfafcc7407..93da9a6140 100644 --- a/README_KL.md +++ b/README_KL.md @@ -48,7 +48,7 @@ README in বাংলা

-# +#

langgenius%2Fdify | Trendshift @@ -56,111 +56,42 @@ Dify is an open-source LLM app development platform. Its intuitive interface combines AI workflow, RAG pipeline, agent capabilities, model management, observability features and more, letting you quickly go from prototype to production. Here's a list of the core features:

-**1. Workflow**: - Build and test powerful AI workflows on a visual canvas, leveraging all the following features and beyond. +**1. Workflow**: +Build and test powerful AI workflows on a visual canvas, leveraging all the following features and beyond. -**2. Comprehensive model support**: - Seamless integration with hundreds of proprietary / open-source LLMs from dozens of inference providers and self-hosted solutions, covering GPT, Mistral, Llama3, and any OpenAI API-compatible models. A full list of supported model providers can be found [here](https://docs.dify.ai/getting-started/readme/model-providers). +**2. Comprehensive model support**: +Seamless integration with hundreds of proprietary / open-source LLMs from dozens of inference providers and self-hosted solutions, covering GPT, Mistral, Llama3, and any OpenAI API-compatible models. A full list of supported model providers can be found [here](https://docs.dify.ai/getting-started/readme/model-providers). ![providers-v5](https://github.com/langgenius/dify/assets/13230914/5a17bdbe-097a-4100-8363-40255b70f6e3) +**3. Prompt IDE**: +Intuitive interface for crafting prompts, comparing model performance, and adding additional features such as text-to-speech to a chat-based app. -**3. Prompt IDE**: - Intuitive interface for crafting prompts, comparing model performance, and adding additional features such as text-to-speech to a chat-based app. +**4. RAG Pipeline**: +Extensive RAG capabilities that cover everything from document ingestion to retrieval, with out-of-box support for text extraction from PDFs, PPTs, and other common document formats. -**4. RAG Pipeline**: - Extensive RAG capabilities that cover everything from document ingestion to retrieval, with out-of-box support for text extraction from PDFs, PPTs, and other common document formats. +**5. Agent capabilities**: +You can define agents based on LLM Function Calling or ReAct, and add pre-built or custom tools for the agent. Dify provides 50+ built-in tools for AI agents, such as Google Search, DALL·E, Stable Diffusion and WolframAlpha. -**5. Agent capabilities**: - You can define agents based on LLM Function Calling or ReAct, and add pre-built or custom tools for the agent. Dify provides 50+ built-in tools for AI agents, such as Google Search, DALL·E, Stable Diffusion and WolframAlpha. +**6. LLMOps**: +Monitor and analyze application logs and performance over time. You could continuously improve prompts, datasets, and models based on production data and annotations. -**6. LLMOps**: - Monitor and analyze application logs and performance over time. You could continuously improve prompts, datasets, and models based on production data and annotations. - -**7. Backend-as-a-Service**: - All of Dify's offerings come with corresponding APIs, so you could effortlessly integrate Dify into your own business logic. - - -## Feature Comparison - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FeatureDify.AILangChainFlowiseOpenAI Assistants API
Programming ApproachAPI + App-orientedPython CodeApp-orientedAPI-oriented
Supported LLMsRich VarietyRich VarietyRich VarietyOpenAI-only
RAG Engine
Agent
Workflow
Observability
Enterprise Feature (SSO/Access control)
Local Deployment
+**7. Backend-as-a-Service**: +All of Dify's offerings come with corresponding APIs, so you could effortlessly integrate Dify into your own business logic. ## Using Dify - **Cloud
** -We host a [Dify Cloud](https://dify.ai) service for anyone to try with zero setup. It provides all the capabilities of the self-deployed version, and includes 200 free GPT-4 calls in the sandbox plan. + We host a [Dify Cloud](https://dify.ai) service for anyone to try with zero setup. It provides all the capabilities of the self-deployed version, and includes 200 free GPT-4 calls in the sandbox plan. - **Self-hosting Dify Community Edition
** -Quickly get Dify running in your environment with this [starter guide](#quick-start). -Use our [documentation](https://docs.dify.ai) for further references and more in-depth instructions. + Quickly get Dify running in your environment with this [starter guide](#quick-start). + Use our [documentation](https://docs.dify.ai) for further references and more in-depth instructions. - **Dify for Enterprise / Organizations
** -We provide additional enterprise-centric features. [Send us an email](mailto:business@dify.ai?subject=[GitHub]Business%20License%20Inquiry) to discuss enterprise needs.
- > For startups and small businesses using AWS, check out [Dify Premium on AWS Marketplace](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6) and deploy it to your own AWS VPC with one-click. It's an affordable AMI offering with the option to create apps with custom logo and branding. + We provide additional enterprise-centric features. [Send us an email](mailto:business@dify.ai?subject=%5BGitHub%5DBusiness%20License%20Inquiry) to discuss enterprise needs.
+ > For startups and small businesses using AWS, check out [Dify Premium on AWS Marketplace](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6) and deploy it to your own AWS VPC with one-click. It's an affordable AMI offering with the option to create apps with custom logo and branding. ## Staying ahead @@ -168,13 +99,12 @@ Star Dify on GitHub and be instantly notified of new releases. ![star-us](https://github.com/langgenius/dify/assets/13230914/b823edc1-6388-4e25-ad45-2f6b187adbb4) - - ## Quick Start + > Before installing Dify, make sure your machine meets the following minimum system requirements: -> ->- CPU >= 2 Core ->- RAM >= 4GB +> +> - CPU >= 2 Core +> - RAM >= 4GB
@@ -208,17 +138,21 @@ If you'd like to configure a highly-available setup, there are community-contrib wa'logh nIqHom neH ghun deployment toy'wI' [terraform](https://www.terraform.io/) lo'laH. ##### Azure Global + - [Azure Terraform mung @nikawang](https://github.com/nikawang/dify-azure-terraform) ##### Google Cloud + - [Google Cloud Terraform qachlot @sotazum](https://github.com/DeNA/dify-google-cloud-terraform) #### AWS CDK atorlugh pilersitsineq wa'logh nIqHom neH ghun deployment toy'wI' [CDK](https://aws.amazon.com/cdk/) lo'laH. -##### AWS -- [AWS CDK qachlot @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) +##### AWS + +- [AWS CDK qachlot @KevinZhao (EKS based)](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) +- [AWS CDK qachlot @tmokmss (ECS based)](https://github.com/aws-samples/dify-self-hosted-on-aws) #### Alibaba Cloud @@ -228,14 +162,16 @@ wa'logh nIqHom neH ghun deployment toy'wI' [CDK](https://aws.amazon.com/cdk/) lo [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/) +#### AKS 'e' Deploy je Azure Devops Pipeline lo'laH + +[Azure Devops Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS) lo'laH Dify AKS 'e' wa'DIch click 'e' Deploy ## Contributing -For those who'd like to contribute code, see our [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). +For those who'd like to contribute code, see our [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). At the same time, please consider supporting Dify by sharing it on social media and at events and conferences. - -> We are looking for contributors to help with translating Dify to languages other than Mandarin or English. If you are interested in helping, please see the [i18n README](https://github.com/langgenius/dify/blob/main/web/i18n/README.md) for more information, and leave us a comment in the `global-users` channel of our [Discord Community Server](https://discord.gg/8Tpq4AcN9c). +> We are looking for contributors to help with translating Dify to languages other than Mandarin or English. If you are interested in helping, please see the [i18n README](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md) for more information, and leave us a comment in the `global-users` channel of our [Discord Community Server](https://discord.gg/8Tpq4AcN9c). **Contributors** @@ -245,18 +181,18 @@ At the same time, please consider supporting Dify by sharing it on social media ## Community & Contact -* [GitHub Discussion](https://github.com/langgenius/dify/discussions +- \[GitHub Discussion\](https://github.com/langgenius/dify/discussions ). Best for: sharing feedback and asking questions. -* [GitHub Issues](https://github.com/langgenius/dify/issues). Best for: bugs you encounter using Dify.AI, and feature proposals. See our [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). -* [Discord](https://discord.gg/FngNHpbcY7). Best for: sharing your applications and hanging out with the community. -* [X(Twitter)](https://twitter.com/dify_ai). Best for: sharing your applications and hanging out with the community. + +- [GitHub Issues](https://github.com/langgenius/dify/issues). Best for: bugs you encounter using Dify.AI, and feature proposals. See our [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). +- [Discord](https://discord.gg/FngNHpbcY7). Best for: sharing your applications and hanging out with the community. +- [X(Twitter)](https://twitter.com/dify_ai). Best for: sharing your applications and hanging out with the community. ## Star History [![Star History Chart](https://api.star-history.com/svg?repos=langgenius/dify&type=Date)](https://star-history.com/#langgenius/dify&Date) - ## Security Disclosure To protect your privacy, please avoid posting security issues on GitHub. Instead, send your questions to security@dify.ai and we will provide you with a more detailed answer. diff --git a/README_KR.md b/README_KR.md index 282117e776..3b58339e12 100644 --- a/README_KR.md +++ b/README_KR.md @@ -48,99 +48,30 @@ README in বাংলা

- - Dify는 오픈 소스 LLM 앱 개발 플랫폼입니다. 직관적인 인터페이스를 통해 AI 워크플로우, RAG 파이프라인, 에이전트 기능, 모델 관리, 관찰 기능 등을 결합하여 프로토타입에서 프로덕션까지 빠르게 전환할 수 있습니다. 주요 기능 목록은 다음과 같습니다:

+Dify는 오픈 소스 LLM 앱 개발 플랫폼입니다. 직관적인 인터페이스를 통해 AI 워크플로우, RAG 파이프라인, 에이전트 기능, 모델 관리, 관찰 기능 등을 결합하여 프로토타입에서 프로덕션까지 빠르게 전환할 수 있습니다. 주요 기능 목록은 다음과 같습니다:

**1. 워크플로우**: - 다음 기능들을 비롯한 다양한 기능을 활용하여 시각적 캔버스에서 강력한 AI 워크플로우를 구축하고 테스트하세요. +다음 기능들을 비롯한 다양한 기능을 활용하여 시각적 캔버스에서 강력한 AI 워크플로우를 구축하고 테스트하세요. -**2. 포괄적인 모델 지원:**: +**2. 포괄적인 모델 지원:**: 수십 개의 추론 제공업체와 자체 호스팅 솔루션에서 제공하는 수백 개의 독점 및 오픈 소스 LLM과 원활하게 통합되며, GPT, Mistral, Llama3 및 모든 OpenAI API 호환 모델을 포함합니다. 지원되는 모델 제공업체의 전체 목록은 [여기](https://docs.dify.ai/getting-started/readme/model-providers)에서 확인할 수 있습니다. ![providers-v5](https://github.com/langgenius/dify/assets/13230914/5a17bdbe-097a-4100-8363-40255b70f6e3) - **3. 통합 개발환경**: - 프롬프트를 작성하고, 모델 성능을 비교하며, 텍스트-음성 변환과 같은 추가 기능을 채팅 기반 앱에 추가할 수 있는 직관적인 인터페이스를 제공합니다. +프롬프트를 작성하고, 모델 성능을 비교하며, 텍스트-음성 변환과 같은 추가 기능을 채팅 기반 앱에 추가할 수 있는 직관적인 인터페이스를 제공합니다. -**4. RAG 파이프라인**: - 문서 수집부터 검색까지 모든 것을 다루며, PDF, PPT 및 기타 일반적인 문서 형식에서 텍스트 추출을 위한 기본 지원이 포함되어 있는 광범위한 RAG 기능을 제공합니다. +**4. RAG 파이프라인**: +문서 수집부터 검색까지 모든 것을 다루며, PDF, PPT 및 기타 일반적인 문서 형식에서 텍스트 추출을 위한 기본 지원이 포함되어 있는 광범위한 RAG 기능을 제공합니다. **5. 에이전트 기능**: - LLM 함수 호출 또는 ReAct를 기반으로 에이전트를 정의하고 에이전트에 대해 사전 구축된 도구나 사용자 정의 도구를 추가할 수 있습니다. Dify는 Google Search, DALL·E, Stable Diffusion, WolframAlpha 등 AI 에이전트를 위한 50개 이상의 내장 도구를 제공합니다. +LLM 함수 호출 또는 ReAct를 기반으로 에이전트를 정의하고 에이전트에 대해 사전 구축된 도구나 사용자 정의 도구를 추가할 수 있습니다. Dify는 Google Search, DALL·E, Stable Diffusion, WolframAlpha 등 AI 에이전트를 위한 50개 이상의 내장 도구를 제공합니다. **6. LLMOps**: - 시간 경과에 따른 애플리케이션 로그와 성능을 모니터링하고 분석합니다. 생산 데이터와 주석을 기반으로 프롬프트, 데이터세트, 모델을 지속적으로 개선할 수 있습니다. +시간 경과에 따른 애플리케이션 로그와 성능을 모니터링하고 분석합니다. 생산 데이터와 주석을 기반으로 프롬프트, 데이터세트, 모델을 지속적으로 개선할 수 있습니다. **7. Backend-as-a-Service**: - Dify의 모든 제품에는 해당 API가 함께 제공되므로 Dify를 자신의 비즈니스 로직에 쉽게 통합할 수 있습니다. - -## 기능 비교 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
기능Dify.AILangChainFlowiseOpenAI Assistants API
프로그래밍 접근 방식API + 앱 중심Python 코드앱 중심API 중심
지원되는 LLMs다양한 종류다양한 종류다양한 종류OpenAI 전용
RAG 엔진
에이전트
워크플로우
가시성
기업용 기능 (SSO/접근 제어)
로컬 배포
+Dify의 모든 제품에는 해당 API가 함께 제공되므로 Dify를 자신의 비즈니스 로직에 쉽게 통합할 수 있습니다. ## Dify 사용하기 @@ -148,27 +79,26 @@ 우리는 누구나 설정이 필요 없이 사용해 볼 수 있도록 [Dify 클라우드](https://dify.ai) 서비스를 호스팅합니다. 이는 자체 배포 버전의 모든 기능을 제공하며, 샌드박스 플랜에서 무료로 200회의 GPT-4 호출을 포함합니다. - **셀프-호스팅 Dify 커뮤니티 에디션
** - 환경에서 Dify를 빠르게 실행하려면 이 [스타터 가이드를](#quick-start) 참조하세요. + 환경에서 Dify를 빠르게 실행하려면 이 [스타터 가이드를](#quick-start) 참조하세요. 추가 참조 및 더 심층적인 지침은 [문서](https://docs.dify.ai)를 사용하세요. - **기업 / 조직을 위한 Dify
** - 우리는 추가적인 기업 중심 기능을 제공합니다. 잡거나 [이메일 보내기](mailto:business@dify.ai?subject=[GitHub]Business%20License%20Inquiry)를 통해 기업 요구 사항을 논의하십시오.
+ 우리는 추가적인 기업 중심 기능을 제공합니다. 잡거나 [이메일 보내기](mailto:business@dify.ai?subject=%5BGitHub%5DBusiness%20License%20Inquiry)를 통해 기업 요구 사항을 논의하십시오.
+ > AWS를 사용하는 스타트업 및 중소기업의 경우 [AWS Marketplace에서 Dify Premium](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6)을 확인하고 한 번의 클릭으로 자체 AWS VPC에 배포하십시오. 맞춤형 로고와 브랜딩이 포함된 앱을 생성할 수 있는 옵션이 포함된 저렴한 AMI 제품입니다. - - ## 앞서가기 GitHub에서 Dify에 별표를 찍어 새로운 릴리스를 즉시 알림 받으세요. ![star-us](https://github.com/langgenius/dify/assets/13230914/b823edc1-6388-4e25-ad45-2f6b187adbb4) - - ## 빠른 시작 ->Dify를 설치하기 전에 컴퓨터가 다음과 같은 최소 시스템 요구 사항을 충족하는지 확인하세요 : ->- CPU >= 2 Core ->- RAM >= 4GB + +> Dify를 설치하기 전에 컴퓨터가 다음과 같은 최소 시스템 요구 사항을 충족하는지 확인하세요 : +> +> - CPU >= 2 Core +> - RAM >= 4GB
@@ -202,17 +132,21 @@ Dify를 Kubernetes에 배포하고 프리미엄 스케일링 설정을 구성했 [terraform](https://www.terraform.io/)을 사용하여 단 한 번의 클릭으로 Dify를 클라우드 플랫폼에 배포하십시오 ##### Azure Global + - [nikawang의 Azure Terraform](https://github.com/nikawang/dify-azure-terraform) ##### Google Cloud + - [sotazum의 Google Cloud Terraform](https://github.com/DeNA/dify-google-cloud-terraform) #### AWS CDK를 사용한 배포 [CDK](https://aws.amazon.com/cdk/)를 사용하여 AWS에 Dify 배포 -##### AWS -- [KevinZhao의 AWS CDK](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) +##### AWS + +- [KevinZhao의 AWS CDK (EKS based)](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) +- [tmokmss의 AWS CDK (ECS based)](https://github.com/aws-samples/dify-self-hosted-on-aws) #### Alibaba Cloud @@ -222,14 +156,16 @@ Dify를 Kubernetes에 배포하고 프리미엄 스케일링 설정을 구성했 [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)를 통해 원클릭으로 Dify를 Alibaba Cloud에 배포할 수 있습니다 +#### AKS에 배포하기 위해 Azure Devops Pipeline 사용 + +[Azure Devops Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS)을 사용하여 Dify를 AKS에 원클릭으로 배포 ## 기여 코드에 기여하고 싶은 분들은 [기여 가이드](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)를 참조하세요. 동시에 Dify를 소셜 미디어와 행사 및 컨퍼런스에 공유하여 지원하는 것을 고려해 주시기 바랍니다. - -> 우리는 Dify를 중국어나 영어 이외의 언어로 번역하는 데 도움을 줄 수 있는 기여자를 찾고 있습니다. 도움을 주고 싶으시다면 [i18n README](https://github.com/langgenius/dify/blob/main/web/i18n/README.md)에서 더 많은 정보를 확인하시고 [Discord 커뮤니티 서버](https://discord.gg/8Tpq4AcN9c)의 `global-users` 채널에 댓글을 남겨주세요. +> 우리는 Dify를 중국어나 영어 이외의 언어로 번역하는 데 도움을 줄 수 있는 기여자를 찾고 있습니다. 도움을 주고 싶으시다면 [i18n README](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md)에서 더 많은 정보를 확인하시고 [Discord 커뮤니티 서버](https://discord.gg/8Tpq4AcN9c)의 `global-users` 채널에 댓글을 남겨주세요. **기여자** @@ -239,17 +175,15 @@ Dify를 Kubernetes에 배포하고 프리미엄 스케일링 설정을 구성했 ## 커뮤니티 & 연락처 -* [GitHub 토론](https://github.com/langgenius/dify/discussions). 피드백 공유 및 질문하기에 적합합니다. -* [GitHub 이슈](https://github.com/langgenius/dify/issues). Dify.AI 사용 중 발견한 버그와 기능 제안에 적합합니다. [기여 가이드](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)를 참조하세요. -* [디스코드](https://discord.gg/FngNHpbcY7). 애플리케이션 공유 및 커뮤니티와 소통하기에 적합합니다. -* [트위터](https://twitter.com/dify_ai). 애플리케이션 공유 및 커뮤니티와 소통하기에 적합합니다. - +- [GitHub 토론](https://github.com/langgenius/dify/discussions). 피드백 공유 및 질문하기에 적합합니다. +- [GitHub 이슈](https://github.com/langgenius/dify/issues). Dify.AI 사용 중 발견한 버그와 기능 제안에 적합합니다. [기여 가이드](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)를 참조하세요. +- [디스코드](https://discord.gg/FngNHpbcY7). 애플리케이션 공유 및 커뮤니티와 소통하기에 적합합니다. +- [트위터](https://twitter.com/dify_ai). 애플리케이션 공유 및 커뮤니티와 소통하기에 적합합니다. ## Star 히스토리 [![Star History Chart](https://api.star-history.com/svg?repos=langgenius/dify&type=Date)](https://star-history.com/#langgenius/dify&Date) - ## 보안 공개 개인정보 보호를 위해 보안 문제를 GitHub에 게시하지 마십시오. 대신 security@dify.ai로 질문을 보내주시면 더 자세한 답변을 드리겠습니다. diff --git a/README_PT.md b/README_PT.md index 576f6b48f7..ec2e4245f6 100644 --- a/README_PT.md +++ b/README_PT.md @@ -1,4 +1,5 @@ ![cover-v5-optimized](./images/GitHub_README_if.png) +

📌 Introduzindo o Dify Workflow com Upload de Arquivo: Recrie o Podcast Google NotebookLM

@@ -55,111 +56,42 @@ Dify é uma plataforma de desenvolvimento de aplicativos LLM de código aberto. Sua interface intuitiva combina workflow de IA, pipeline RAG, capacidades de agente, gerenciamento de modelos, recursos de observabilidade e muito mais, permitindo que você vá rapidamente do protótipo à produção. Aqui está uma lista das principais funcionalidades:

-**1. Workflow**: - Construa e teste workflows poderosos de IA em uma interface visual, aproveitando todos os recursos a seguir e muito mais. +**1. Workflow**: +Construa e teste workflows poderosos de IA em uma interface visual, aproveitando todos os recursos a seguir e muito mais. -**2. Suporte abrangente a modelos**: - Integração perfeita com centenas de LLMs proprietários e de código aberto de diversas provedoras e soluções auto-hospedadas, abrangendo GPT, Mistral, Llama3 e qualquer modelo compatível com a API da OpenAI. A lista completa de provedores suportados pode ser encontrada [aqui](https://docs.dify.ai/getting-started/readme/model-providers). +**2. Suporte abrangente a modelos**: +Integração perfeita com centenas de LLMs proprietários e de código aberto de diversas provedoras e soluções auto-hospedadas, abrangendo GPT, Mistral, Llama3 e qualquer modelo compatível com a API da OpenAI. A lista completa de provedores suportados pode ser encontrada [aqui](https://docs.dify.ai/getting-started/readme/model-providers). ![providers-v5](https://github.com/langgenius/dify/assets/13230914/5a17bdbe-097a-4100-8363-40255b70f6e3) +**3. IDE de Prompt**: +Interface intuitiva para criação de prompts, comparação de desempenho de modelos e adição de recursos como conversão de texto para fala em um aplicativo baseado em chat. -**3. IDE de Prompt**: - Interface intuitiva para criação de prompts, comparação de desempenho de modelos e adição de recursos como conversão de texto para fala em um aplicativo baseado em chat. +**4. Pipeline RAG**: +Extensas capacidades de RAG que cobrem desde a ingestão de documentos até a recuperação, com suporte nativo para extração de texto de PDFs, PPTs e outros formatos de documentos comuns. -**4. Pipeline RAG**: - Extensas capacidades de RAG que cobrem desde a ingestão de documentos até a recuperação, com suporte nativo para extração de texto de PDFs, PPTs e outros formatos de documentos comuns. +**5. Capacidades de agente**: +Você pode definir agentes com base em LLM Function Calling ou ReAct e adicionar ferramentas pré-construídas ou personalizadas para o agente. O Dify oferece mais de 50 ferramentas integradas para agentes de IA, como Google Search, DALL·E, Stable Diffusion e WolframAlpha. -**5. Capacidades de agente**: - Você pode definir agentes com base em LLM Function Calling ou ReAct e adicionar ferramentas pré-construídas ou personalizadas para o agente. O Dify oferece mais de 50 ferramentas integradas para agentes de IA, como Google Search, DALL·E, Stable Diffusion e WolframAlpha. +**6. LLMOps**: +Monitore e analise os registros e o desempenho do aplicativo ao longo do tempo. É possível melhorar continuamente prompts, conjuntos de dados e modelos com base nos dados de produção e anotações. -**6. LLMOps**: - Monitore e analise os registros e o desempenho do aplicativo ao longo do tempo. É possível melhorar continuamente prompts, conjuntos de dados e modelos com base nos dados de produção e anotações. - -**7. Backend como Serviço**: - Todas os recursos do Dify vêm com APIs correspondentes, permitindo que você integre o Dify sem esforço na lógica de negócios da sua empresa. - - -## Comparação de recursos - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
RecursoDify.AILangChainFlowiseOpenAI Assistants API
Abordagem de ProgramaçãoOrientada a API + AplicativoCódigo PythonOrientada a AplicativoOrientada a API
LLMs SuportadosVariedade RicaVariedade RicaVariedade RicaApenas OpenAI
RAG Engine
Agente
Workflow
Observabilidade
Recursos Empresariais (SSO/Controle de Acesso)
Implantação Local
+**7. Backend como Serviço**: +Todas os recursos do Dify vêm com APIs correspondentes, permitindo que você integre o Dify sem esforço na lógica de negócios da sua empresa. ## Usando o Dify - **Nuvem
** -Oferecemos o serviço [Dify Cloud](https://dify.ai) para qualquer pessoa experimentar sem nenhuma configuração. Ele fornece todas as funcionalidades da versão auto-hospedada, incluindo 200 chamadas GPT-4 gratuitas no plano sandbox. + Oferecemos o serviço [Dify Cloud](https://dify.ai) para qualquer pessoa experimentar sem nenhuma configuração. Ele fornece todas as funcionalidades da versão auto-hospedada, incluindo 200 chamadas GPT-4 gratuitas no plano sandbox. - **Auto-hospedagem do Dify Community Edition
** -Configure rapidamente o Dify no seu ambiente com este [guia inicial](#quick-start). -Use nossa [documentação](https://docs.dify.ai) para referências adicionais e instruções mais detalhadas. + Configure rapidamente o Dify no seu ambiente com este [guia inicial](#quick-start). + Use nossa [documentação](https://docs.dify.ai) para referências adicionais e instruções mais detalhadas. - **Dify para empresas/organizações
** -Oferecemos recursos adicionais voltados para empresas. [Envie suas perguntas através deste chatbot](https://udify.app/chat/22L1zSxg6yW1cWQg) ou [envie-nos um e-mail](mailto:business@dify.ai?subject=[GitHub]Business%20License%20Inquiry) para discutir necessidades empresariais.
- > Para startups e pequenas empresas que utilizam AWS, confira o [Dify Premium no AWS Marketplace](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6) e implemente no seu próprio AWS VPC com um clique. É uma oferta AMI acessível com a opção de criar aplicativos com logotipo e marca personalizados. + Oferecemos recursos adicionais voltados para empresas. [Envie suas perguntas através deste chatbot](https://udify.app/chat/22L1zSxg6yW1cWQg) ou [envie-nos um e-mail](mailto:business@dify.ai?subject=%5BGitHub%5DBusiness%20License%20Inquiry) para discutir necessidades empresariais.
+ > Para startups e pequenas empresas que utilizam AWS, confira o [Dify Premium no AWS Marketplace](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6) e implemente no seu próprio AWS VPC com um clique. É uma oferta AMI acessível com a opção de criar aplicativos com logotipo e marca personalizados. ## Mantendo-se atualizado @@ -167,13 +99,12 @@ Dê uma estrela no Dify no GitHub e seja notificado imediatamente sobre novos la ![star-us](https://github.com/langgenius/dify/assets/13230914/b823edc1-6388-4e25-ad45-2f6b187adbb4) - - ## Início rápido + > Antes de instalar o Dify, certifique-se de que sua máquina atenda aos seguintes requisitos mínimos de sistema: -> ->- CPU >= 2 Núcleos ->- RAM >= 4 GiB +> +> - CPU >= 2 Núcleos +> - RAM >= 4 GiB
@@ -207,17 +138,21 @@ Se deseja configurar uma instalação de alta disponibilidade, há [Helm Charts] Implante o Dify na Plataforma Cloud com um único clique usando [terraform](https://www.terraform.io/) ##### Azure Global + - [Azure Terraform por @nikawang](https://github.com/nikawang/dify-azure-terraform) ##### Google Cloud + - [Google Cloud Terraform por @sotazum](https://github.com/DeNA/dify-google-cloud-terraform) #### Usando AWS CDK para Implantação Implante o Dify na AWS usando [CDK](https://aws.amazon.com/cdk/) -##### AWS -- [AWS CDK por @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) +##### AWS + +- [AWS CDK por @KevinZhao (EKS based)](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) +- [AWS CDK por @tmokmss (ECS based)](https://github.com/aws-samples/dify-self-hosted-on-aws) #### Alibaba Cloud @@ -227,13 +162,16 @@ Implante o Dify na AWS usando [CDK](https://aws.amazon.com/cdk/) Implante o Dify na Alibaba Cloud com um clique usando o [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/) +#### Usando Azure Devops Pipeline para Implantar no AKS + +Implante o Dify no AKS com um clique usando [Azure Devops Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS) ## Contribuindo -Para aqueles que desejam contribuir com código, veja nosso [Guia de Contribuição](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). +Para aqueles que desejam contribuir com código, veja nosso [Guia de Contribuição](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). Ao mesmo tempo, considere apoiar o Dify compartilhando-o nas redes sociais e em eventos e conferências. -> Estamos buscando contribuidores para ajudar na tradução do Dify para idiomas além de Mandarim e Inglês. Se você tiver interesse em ajudar, consulte o [README i18n](https://github.com/langgenius/dify/blob/main/web/i18n/README.md) para mais informações e deixe-nos um comentário no canal `global-users` em nosso [Servidor da Comunidade no Discord](https://discord.gg/8Tpq4AcN9c). +> Estamos buscando contribuidores para ajudar na tradução do Dify para idiomas além de Mandarim e Inglês. Se você tiver interesse em ajudar, consulte o [README i18n](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md) para mais informações e deixe-nos um comentário no canal `global-users` em nosso [Servidor da Comunidade no Discord](https://discord.gg/8Tpq4AcN9c). **Contribuidores** @@ -243,10 +181,10 @@ Ao mesmo tempo, considere apoiar o Dify compartilhando-o nas redes sociais e em ## Comunidade e contato -* [Discussões no GitHub](https://github.com/langgenius/dify/discussions). Melhor para: compartilhar feedback e fazer perguntas. -* [Problemas no GitHub](https://github.com/langgenius/dify/issues). Melhor para: relatar bugs encontrados no Dify.AI e propor novos recursos. Veja nosso [Guia de Contribuição](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). -* [Discord](https://discord.gg/FngNHpbcY7). Melhor para: compartilhar suas aplicações e interagir com a comunidade. -* [X(Twitter)](https://twitter.com/dify_ai). Melhor para: compartilhar suas aplicações e interagir com a comunidade. +- [Discussões no GitHub](https://github.com/langgenius/dify/discussions). Melhor para: compartilhar feedback e fazer perguntas. +- [Problemas no GitHub](https://github.com/langgenius/dify/issues). Melhor para: relatar bugs encontrados no Dify.AI e propor novos recursos. Veja nosso [Guia de Contribuição](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). +- [Discord](https://discord.gg/FngNHpbcY7). Melhor para: compartilhar suas aplicações e interagir com a comunidade. +- [X(Twitter)](https://twitter.com/dify_ai). Melhor para: compartilhar suas aplicações e interagir com a comunidade. ## Histórico de estrelas diff --git a/README_SI.md b/README_SI.md index 7ded001d86..c20dc3484f 100644 --- a/README_SI.md +++ b/README_SI.md @@ -50,14 +50,14 @@ README in বাংলা

- -Dify je odprtokodna platforma za razvoj aplikacij LLM. Njegov intuitivni vmesnik združuje agentski potek dela z umetno inteligenco, cevovod RAG, zmogljivosti agentov, upravljanje modelov, funkcije opazovanja in več, kar vam omogoča hiter prehod od prototipa do proizvodnje. +Dify je odprtokodna platforma za razvoj aplikacij LLM. Njegov intuitivni vmesnik združuje agentski potek dela z umetno inteligenco, cevovod RAG, zmogljivosti agentov, upravljanje modelov, funkcije opazovanja in več, kar vam omogoča hiter prehod od prototipa do proizvodnje. ## Hitri začetek + > Preden namestite Dify, se prepričajte, da vaša naprava izpolnjuje naslednje minimalne sistemske zahteve: -> ->- CPU >= 2 Core ->- RAM >= 4 GiB +> +> - CPU >= 2 Core +> - RAM >= 4 GiB
@@ -73,116 +73,48 @@ docker compose up -d Po zagonu lahko dostopate do nadzorne plošče Dify v brskalniku na [http://localhost/install](http://localhost/install) in začnete postopek inicializacije. #### Iskanje pomoči + Prosimo, glejte naša pogosta vprašanja [FAQ](https://docs.dify.ai/getting-started/install-self-hosted/faqs) če naletite na težave pri nastavitvi Dify. Če imate še vedno težave, se obrnite na [skupnost ali nas](#community--contact). > Če želite prispevati k Difyju ali narediti dodaten razvoj, glejte naš vodnik za [uvajanje iz izvorne kode](https://docs.dify.ai/getting-started/install-self-hosted/local-source-code) ## Ključne značilnosti -**1. Potek dela**: - Zgradite in preizkusite zmogljive poteke dela AI na vizualnem platnu, pri čemer izkoristite vse naslednje funkcije in več. -**2. Celovita podpora za modele**: - Brezhibna integracija s stotinami lastniških/odprtokodnih LLM-jev ducatov ponudnikov sklepanja in samostojnih rešitev, ki pokrivajo GPT, Mistral, Llama3 in vse modele, združljive z API-jem OpenAI. Celoten seznam podprtih ponudnikov modelov najdete [tukaj](https://docs.dify.ai/getting-started/readme/model-providers). +**1. Potek dela**: +Zgradite in preizkusite zmogljive poteke dela AI na vizualnem platnu, pri čemer izkoristite vse naslednje funkcije in več. + +**2. Celovita podpora za modele**: +Brezhibna integracija s stotinami lastniških/odprtokodnih LLM-jev ducatov ponudnikov sklepanja in samostojnih rešitev, ki pokrivajo GPT, Mistral, Llama3 in vse modele, združljive z API-jem OpenAI. Celoten seznam podprtih ponudnikov modelov najdete [tukaj](https://docs.dify.ai/getting-started/readme/model-providers). ![providers-v5](https://github.com/langgenius/dify/assets/13230914/5a17bdbe-097a-4100-8363-40255b70f6e3) +**3. Prompt IDE**: +intuitivni vmesnik za ustvarjanje pozivov, primerjavo zmogljivosti modela in dodajanje dodatnih funkcij, kot je pretvorba besedila v govor, aplikaciji, ki temelji na klepetu. -**3. Prompt IDE**: - intuitivni vmesnik za ustvarjanje pozivov, primerjavo zmogljivosti modela in dodajanje dodatnih funkcij, kot je pretvorba besedila v govor, aplikaciji, ki temelji na klepetu. +**4. RAG Pipeline**: +E Obsežne zmogljivosti RAG, ki pokrivajo vse od vnosa dokumenta do priklica, s podporo za ekstrakcijo besedila iz datotek PDF, PPT in drugih običajnih formatov dokumentov. -**4. RAG Pipeline**: - E Obsežne zmogljivosti RAG, ki pokrivajo vse od vnosa dokumenta do priklica, s podporo za ekstrakcijo besedila iz datotek PDF, PPT in drugih običajnih formatov dokumentov. +**5. Agent capabilities**: +definirate lahko agente, ki temeljijo na klicanju funkcij LLM ali ReAct, in dodate vnaprej izdelana orodja ali orodja po meri za agenta. Dify ponuja več kot 50 vgrajenih orodij za agente AI, kot so Google Search, DALL·E, Stable Diffusion in WolframAlpha. -**5. Agent capabilities**: - definirate lahko agente, ki temeljijo na klicanju funkcij LLM ali ReAct, in dodate vnaprej izdelana orodja ali orodja po meri za agenta. Dify ponuja več kot 50 vgrajenih orodij za agente AI, kot so Google Search, DALL·E, Stable Diffusion in WolframAlpha. +**6. LLMOps**: +Spremljajte in analizirajte dnevnike aplikacij in učinkovitost skozi čas. Pozive, nabore podatkov in modele lahko nenehno izboljšujete na podlagi proizvodnih podatkov in opomb. -**6. LLMOps**: - Spremljajte in analizirajte dnevnike aplikacij in učinkovitost skozi čas. Pozive, nabore podatkov in modele lahko nenehno izboljšujete na podlagi proizvodnih podatkov in opomb. - -**7. Backend-as-a-Service**: - AVse ponudbe Difyja so opremljene z ustreznimi API-ji, tako da lahko Dify brez težav integrirate v svojo poslovno logiko. - -## Primerjava Funkcij - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FunkcijaDify.AILangChainFlowiseOpenAI Assistants API
Programski pristopAPI + usmerjeno v aplikacijePython kodaUsmerjeno v aplikacijeUsmerjeno v API
Podprti LLM-jiBogata izbiraBogata izbiraBogata izbiraSamo OpenAI
RAG pogon
Agent
Potek dela
Spremljanje
Funkcija za podjetja (SSO/nadzor dostopa)
Lokalna namestitev
+**7. Backend-as-a-Service**: +AVse ponudbe Difyja so opremljene z ustreznimi API-ji, tako da lahko Dify brez težav integrirate v svojo poslovno logiko. ## Uporaba Dify - **Cloud
** -Gostimo storitev Dify Cloud za vsakogar, ki jo lahko preizkusite brez nastavitev. Zagotavlja vse zmožnosti različice za samostojno namestitev in vključuje 200 brezplačnih klicev GPT-4 v načrtu peskovnika. + Gostimo storitev Dify Cloud za vsakogar, ki jo lahko preizkusite brez nastavitev. Zagotavlja vse zmožnosti različice za samostojno namestitev in vključuje 200 brezplačnih klicev GPT-4 v načrtu peskovnika. - **Self-hosting Dify Community Edition
** -Hitro zaženite Dify v svojem okolju s tem [začetnim vodnikom](#quick-start) . Za dodatne reference in podrobnejša navodila uporabite našo [dokumentacijo](https://docs.dify.ai) . - + Hitro zaženite Dify v svojem okolju s tem [začetnim vodnikom](#quick-start) . Za dodatne reference in podrobnejša navodila uporabite našo [dokumentacijo](https://docs.dify.ai) . - **Dify za podjetja/organizacije
** -Ponujamo dodatne funkcije, osredotočene na podjetja. Zabeležite svoja vprašanja prek tega klepetalnega robota ali nam pošljite e-pošto, da se pogovorimo o potrebah podjetja.
- > Za novoustanovljena podjetja in mala podjetja, ki uporabljajo AWS, si oglejte Dify Premium na AWS Marketplace in ga z enim klikom uvedite v svoj AWS VPC. To je cenovno ugodna ponudba AMI z možnostjo ustvarjanja aplikacij z logotipom in blagovno znamko po meri. + Ponujamo dodatne funkcije, osredotočene na podjetja. Zabeležite svoja vprašanja prek tega klepetalnega robota ali nam pošljite e-pošto, da se pogovorimo o potrebah podjetja.
+ > Za novoustanovljena podjetja in mala podjetja, ki uporabljajo AWS, si oglejte Dify Premium na AWS Marketplace in ga z enim klikom uvedite v svoj AWS VPC. To je cenovno ugodna ponudba AMI z možnostjo ustvarjanja aplikacij z logotipom in blagovno znamko po meri. ## Staying ahead @@ -190,7 +122,6 @@ Star Dify on GitHub and be instantly notified of new releases. ![star-us](https://github.com/langgenius/dify/assets/13230914/b823edc1-6388-4e25-ad45-2f6b187adbb4) - ## Napredne nastavitve Če morate prilagoditi konfiguracijo, si oglejte komentarje v naši datoteki .env.example in posodobite ustrezne vrednosti v svoji .env datoteki. Poleg tega boste morda morali prilagoditi docker-compose.yamlsamo datoteko, na primer spremeniti različice slike, preslikave vrat ali namestitve nosilca, glede na vaše specifično okolje in zahteve za uvajanje. Po kakršnih koli spremembah ponovno zaženite docker-compose up -d. Celoten seznam razpoložljivih spremenljivk okolja najdete tukaj . @@ -208,17 +139,21 @@ Star Dify on GitHub and be instantly notified of new releases. namestite Dify v Cloud Platform z enim klikom z uporabo [terraform](https://www.terraform.io/) ##### Azure Global + - [Azure Terraform by @nikawang](https://github.com/nikawang/dify-azure-terraform) ##### Google Cloud + - [Google Cloud Terraform by @sotazum](https://github.com/DeNA/dify-google-cloud-terraform) #### Uporaba AWS CDK za uvajanje Uvedite Dify v AWS z uporabo [CDK](https://aws.amazon.com/cdk/) -##### AWS -- [AWS CDK by @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) +##### AWS + +- [AWS CDK by @KevinZhao (EKS based)](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) +- [AWS CDK by @tmokmss (ECS based)](https://github.com/aws-samples/dify-self-hosted-on-aws) #### Alibaba Cloud @@ -228,21 +163,22 @@ Uvedite Dify v AWS z uporabo [CDK](https://aws.amazon.com/cdk/) Z enim klikom namestite Dify na Alibaba Cloud z [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/) +#### Uporaba Azure Devops Pipeline za uvajanje v AKS + +Z enim klikom namestite Dify v AKS z uporabo [Azure Devops Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS) ## Prispevam -Za tiste, ki bi radi prispevali kodo, si oglejte naš vodnik za prispevke . Hkrati vas prosimo, da podprete Dify tako, da ga delite na družbenih medijih ter na dogodkih in konferencah. - - +Za tiste, ki bi radi prispevali kodo, si oglejte naš vodnik za prispevke . Hkrati vas prosimo, da podprete Dify tako, da ga delite na družbenih medijih ter na dogodkih in konferencah. > Iščemo sodelavce za pomoč pri prevajanju Difyja v jezike, ki niso mandarinščina ali angleščina. Če želite pomagati, si oglejte i18n README za več informacij in nam pustite komentar v global-userskanalu našega strežnika skupnosti Discord . ## Skupnost in stik -* [GitHub Discussion](https://github.com/langgenius/dify/discussions). Najboljše za: izmenjavo povratnih informacij in postavljanje vprašanj. -* [GitHub Issues](https://github.com/langgenius/dify/issues). Najboljše za: hrošče, na katere naletite pri uporabi Dify.AI, in predloge funkcij. Oglejte si naš [vodnik za prispevke](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). -* [Discord](https://discord.gg/FngNHpbcY7). Najboljše za: deljenje vaših aplikacij in druženje s skupnostjo. -* [X(Twitter)](https://twitter.com/dify_ai). Najboljše za: deljenje vaših aplikacij in druženje s skupnostjo. +- [GitHub Discussion](https://github.com/langgenius/dify/discussions). Najboljše za: izmenjavo povratnih informacij in postavljanje vprašanj. +- [GitHub Issues](https://github.com/langgenius/dify/issues). Najboljše za: hrošče, na katere naletite pri uporabi Dify.AI, in predloge funkcij. Oglejte si naš [vodnik za prispevke](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). +- [Discord](https://discord.gg/FngNHpbcY7). Najboljše za: deljenje vaših aplikacij in druženje s skupnostjo. +- [X(Twitter)](https://twitter.com/dify_ai). Najboljše za: deljenje vaših aplikacij in druženje s skupnostjo. **Contributors** @@ -254,7 +190,6 @@ Za tiste, ki bi radi prispevali kodo, si oglejte naš vodnik za prispevke . Hkra [![Star History Chart](https://api.star-history.com/svg?repos=langgenius/dify&type=Date)](https://star-history.com/#langgenius/dify&Date) - ## Varnostno razkritje Zaradi zaščite vaše zasebnosti se izogibajte objavljanju varnostnih vprašanj na GitHub. Namesto tega pošljite vprašanja na security@dify.ai in zagotovili vam bomo podrobnejši odgovor. diff --git a/README_TR.md b/README_TR.md index 6e94e54fa0..510b112e68 100644 --- a/README_TR.md +++ b/README_TR.md @@ -48,11 +48,10 @@ README in বাংলা

- Dify, açık kaynaklı bir LLM uygulama geliştirme platformudur. Sezgisel arayüzü, AI iş akışı, RAG pipeline'ı, ajan yetenekleri, model yönetimi, gözlemlenebilirlik özellikleri ve daha fazlasını birleştirerek, prototipten üretime hızlıca geçmenizi sağlar. İşte temel özelliklerin bir listesi:

-**1. Workflow**: +**1. Workflow**: Görsel bir arayüz üzerinde güçlü AI iş akışları oluşturun ve test edin, aşağıdaki tüm özellikleri ve daha fazlasını kullanarak. **2. Kapsamlı model desteği**: @@ -60,101 +59,33 @@ Görsel bir arayüz üzerinde güçlü AI iş akışları oluşturun ve test edi ![providers-v5](https://github.com/langgenius/dify/assets/13230914/5a17bdbe-097a-4100-8363-40255b70f6e3) +**3. Prompt IDE**: +Komut istemlerini oluşturmak, model performansını karşılaştırmak ve sohbet tabanlı uygulamalara metin-konuşma gibi ek özellikler eklemek için kullanıcı dostu bir arayüz. -**3. Prompt IDE**: - Komut istemlerini oluşturmak, model performansını karşılaştırmak ve sohbet tabanlı uygulamalara metin-konuşma gibi ek özellikler eklemek için kullanıcı dostu bir arayüz. +**4. RAG Pipeline**: +Belge alımından bilgi çekmeye kadar geniş kapsamlı RAG yetenekleri. PDF'ler, PPT'ler ve diğer yaygın belge formatlarından metin çıkarma için hazır destek sunar. -**4. RAG Pipeline**: - Belge alımından bilgi çekmeye kadar geniş kapsamlı RAG yetenekleri. PDF'ler, PPT'ler ve diğer yaygın belge formatlarından metin çıkarma için hazır destek sunar. +**5. Ajan yetenekleri**: +LLM Fonksiyon Çağırma veya ReAct'a dayalı ajanlar tanımlayabilir ve bu ajanlara önceden hazırlanmış veya özel araçlar ekleyebilirsiniz. Dify, AI ajanları için Google Arama, DALL·E, Stable Diffusion ve WolframAlpha gibi 50'den fazla yerleşik araç sağlar. -**5. Ajan yetenekleri**: - LLM Fonksiyon Çağırma veya ReAct'a dayalı ajanlar tanımlayabilir ve bu ajanlara önceden hazırlanmış veya özel araçlar ekleyebilirsiniz. Dify, AI ajanları için Google Arama, DALL·E, Stable Diffusion ve WolframAlpha gibi 50'den fazla yerleşik araç sağlar. +**6. LLMOps**: +Uygulama loglarını ve performans metriklerini zaman içinde izleme ve analiz etme imkanı. Üretim ortamından elde edilen verilere ve kullanıcı geri bildirimlerine dayanarak, prompt'ları, veri setlerini ve modelleri sürekli olarak optimize edebilirsiniz. Bu sayede, AI uygulamanızın performansını ve doğruluğunu sürekli olarak artırabilirsiniz. -**6. LLMOps**: - Uygulama loglarını ve performans metriklerini zaman içinde izleme ve analiz etme imkanı. Üretim ortamından elde edilen verilere ve kullanıcı geri bildirimlerine dayanarak, prompt'ları, veri setlerini ve modelleri sürekli olarak optimize edebilirsiniz. Bu sayede, AI uygulamanızın performansını ve doğruluğunu sürekli olarak artırabilirsiniz. - -**7. Hizmet Olarak Backend**: - Dify'ın tüm özellikleri ilgili API'lerle birlikte gelir, böylece Dify'ı kendi iş mantığınıza kolayca entegre edebilirsiniz. - - -## Özellik karşılaştırması - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
ÖzellikDify.AILangChainFlowiseOpenAI Assistants API
Programlama YaklaşımıAPI + Uygulama odaklıPython KoduUygulama odaklıAPI odaklı
Desteklenen LLM'lerZengin ÇeşitlilikZengin ÇeşitlilikZengin ÇeşitlilikYalnızca OpenAI
RAG Motoru
Ajan
İş Akışı
Gözlemlenebilirlik
Kurumsal Özellikler (SSO/Erişim kontrolü)
Yerel Dağıtım
+**7. Hizmet Olarak Backend**: +Dify'ın tüm özellikleri ilgili API'lerle birlikte gelir, böylece Dify'ı kendi iş mantığınıza kolayca entegre edebilirsiniz. ## Dify'ı Kullanma - **Cloud
** -Herkesin sıfır kurulumla denemesi için bir [Dify Cloud](https://dify.ai) hizmeti sunuyoruz. Bu hizmet, kendi kendine dağıtılan versiyonun tüm yeteneklerini sağlar ve sandbox planında 200 ücretsiz GPT-4 çağrısı içerir. + Herkesin sıfır kurulumla denemesi için bir [Dify Cloud](https://dify.ai) hizmeti sunuyoruz. Bu hizmet, kendi kendine dağıtılan versiyonun tüm yeteneklerini sağlar ve sandbox planında 200 ücretsiz GPT-4 çağrısı içerir. - **Dify Topluluk Sürümünü Kendi Sunucunuzda Barındırma
** -Bu [başlangıç kılavuzu](#quick-start) ile Dify'ı kendi ortamınızda hızlıca çalıştırın. -Daha fazla referans ve detaylı talimatlar için [dokümantasyonumuzu](https://docs.dify.ai) kullanın. + Bu [başlangıç kılavuzu](#quick-start) ile Dify'ı kendi ortamınızda hızlıca çalıştırın. + Daha fazla referans ve detaylı talimatlar için [dokümantasyonumuzu](https://docs.dify.ai) kullanın. - **Kurumlar / organizasyonlar için Dify
** -Ek kurumsal odaklı özellikler sunuyoruz. Kurumsal ihtiyaçları görüşmek için [bize bir e-posta gönderin](mailto:business@dify.ai?subject=[GitHub]Business%20License%20Inquiry).
+ Ek kurumsal odaklı özellikler sunuyoruz. Kurumsal ihtiyaçları görüşmek için [bize bir e-posta gönderin](mailto:business@dify.ai?subject=%5BGitHub%5DBusiness%20License%20Inquiry).
+ > AWS kullanan startuplar ve küçük işletmeler için, [AWS Marketplace'deki Dify Premium'a](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6) göz atın ve tek tıklamayla kendi AWS VPC'nize dağıtın. Bu, özel logo ve marka ile uygulamalar oluşturma seçeneğine sahip uygun fiyatlı bir AMI teklifdir. ## Güncel Kalma @@ -163,13 +94,12 @@ GitHub'da Dify'a yıldız verin ve yeni sürümlerden anında haberdar olun. ![bizi-yıldızlayın](https://github.com/langgenius/dify/assets/13230914/b823edc1-6388-4e25-ad45-2f6b187adbb4) - - ## Hızlı başlangıç + > Dify'ı kurmadan önce, makinenizin aşağıdaki minimum sistem gereksinimlerini karşıladığından emin olun: -> ->- CPU >= 2 Çekirdek ->- RAM >= 4GB +> +> - CPU >= 2 Çekirdek +> - RAM >= 4GB
Dify sunucusunu başlatmanın en kolay yolu, [docker-compose.yml](docker/docker-compose.yaml) dosyamızı çalıştırmaktır. Kurulum komutunu çalıştırmadan önce, makinenizde [Docker](https://docs.docker.com/get-docker/) ve [Docker Compose](https://docs.docker.com/compose/install/)'un kurulu olduğundan emin olun: @@ -201,17 +131,21 @@ Yüksek kullanılabilirliğe sahip bir kurulum yapılandırmak isterseniz, Dify' Dify'ı bulut platformuna tek tıklamayla dağıtın [terraform](https://www.terraform.io/) kullanarak ##### Azure Global + - [Azure Terraform tarafından @nikawang](https://github.com/nikawang/dify-azure-terraform) ##### Google Cloud + - [Google Cloud Terraform tarafından @sotazum](https://github.com/DeNA/dify-google-cloud-terraform) #### AWS CDK ile Dağıtım [CDK](https://aws.amazon.com/cdk/) kullanarak Dify'ı AWS'ye dağıtın -##### AWS -- [AWS CDK tarafından @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) +##### AWS + +- [AWS CDK tarafından @KevinZhao (EKS based)](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) +- [AWS CDK tarafından @tmokmss (ECS based)](https://github.com/aws-samples/dify-self-hosted-on-aws) #### Alibaba Cloud @@ -221,13 +155,16 @@ Dify'ı bulut platformuna tek tıklamayla dağıtın [terraform](https://www.ter [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/) kullanarak Dify'ı tek tıkla Alibaba Cloud'a dağıtın +#### AKS'ye Dağıtım için Azure Devops Pipeline Kullanımı + +[Azure Devops Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS) kullanarak Dify'ı tek tıkla AKS'ye dağıtın ## Katkıda Bulunma Kod katkısında bulunmak isteyenler için [Katkı Kılavuzumuza](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) bakabilirsiniz. Aynı zamanda, lütfen Dify'ı sosyal medyada, etkinliklerde ve konferanslarda paylaşarak desteklemeyi düşünün. -> Dify'ı Mandarin veya İngilizce dışındaki dillere çevirmemize yardımcı olacak katkıda bulunanlara ihtiyacımız var. Yardımcı olmakla ilgileniyorsanız, lütfen daha fazla bilgi için [i18n README](https://github.com/langgenius/dify/blob/main/web/i18n/README.md) dosyasına bakın ve [Discord Topluluk Sunucumuzdaki](https://discord.gg/8Tpq4AcN9c) `global-users` kanalında bize bir yorum bırakın. +> Dify'ı Mandarin veya İngilizce dışındaki dillere çevirmemize yardımcı olacak katkıda bulunanlara ihtiyacımız var. Yardımcı olmakla ilgileniyorsanız, lütfen daha fazla bilgi için [i18n README](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md) dosyasına bakın ve [Discord Topluluk Sunucumuzdaki](https://discord.gg/8Tpq4AcN9c) `global-users` kanalında bize bir yorum bırakın. **Katkıda Bulunanlar** @@ -237,10 +174,10 @@ Aynı zamanda, lütfen Dify'ı sosyal medyada, etkinliklerde ve konferanslarda p ## Topluluk & iletişim -* [GitHub Tartışmaları](https://github.com/langgenius/dify/discussions). En uygun: geri bildirim paylaşmak ve soru sormak için. -* [GitHub Sorunları](https://github.com/langgenius/dify/issues). En uygun: Dify.AI kullanırken karşılaştığınız hatalar ve özellik önerileri için. [Katkı Kılavuzumuza](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) bakın. -* [Discord](https://discord.gg/FngNHpbcY7). En uygun: uygulamalarınızı paylaşmak ve toplulukla vakit geçirmek için. -* [X(Twitter)](https://twitter.com/dify_ai). En uygun: uygulamalarınızı paylaşmak ve toplulukla vakit geçirmek için. +- [GitHub Tartışmaları](https://github.com/langgenius/dify/discussions). En uygun: geri bildirim paylaşmak ve soru sormak için. +- [GitHub Sorunları](https://github.com/langgenius/dify/issues). En uygun: Dify.AI kullanırken karşılaştığınız hatalar ve özellik önerileri için. [Katkı Kılavuzumuza](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) bakın. +- [Discord](https://discord.gg/FngNHpbcY7). En uygun: uygulamalarınızı paylaşmak ve toplulukla vakit geçirmek için. +- [X(Twitter)](https://twitter.com/dify_ai). En uygun: uygulamalarınızı paylaşmak ve toplulukla vakit geçirmek için. ## Star history diff --git a/README_TW.md b/README_TW.md index 6e3e22b5c1..35a01fa16a 100644 --- a/README_TW.md +++ b/README_TW.md @@ -106,85 +106,18 @@ docker compose up -d **7. 後端即服務**: Dify 的所有功能都提供相應的 API,因此您可以輕鬆地將 Dify 整合到您自己的業務邏輯中。 -## 功能比較 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
功能Dify.AILangChainFlowiseOpenAI Assistants API
程式設計方法API + 應用導向Python 代碼應用導向API 導向
支援的 LLM 模型豐富多樣豐富多樣豐富多樣僅限 OpenAI
RAG 引擎
代理功能
工作流程
可觀察性
企業級功能 (SSO/存取控制)
本地部署
- ## 使用 Dify - **雲端服務
** 我們提供 [Dify Cloud](https://dify.ai) 服務,任何人都可以零配置嘗試。它提供與自部署版本相同的所有功能,並在沙盒計劃中包含 200 次免費 GPT-4 調用。 - **自託管 Dify 社區版
** - 使用這份[快速指南](#快速開始)在您的環境中快速運行 Dify。 + 使用這份[快速指南](#%E5%BF%AB%E9%80%9F%E9%96%8B%E5%A7%8B)在您的環境中快速運行 Dify。 使用我們的[文檔](https://docs.dify.ai)獲取更多參考和深入指導。 - **企業/組織版 Dify
** - 我們提供額外的企業中心功能。[通過這個聊天機器人記錄您的問題](https://udify.app/chat/22L1zSxg6yW1cWQg)或[發送電子郵件給我們](mailto:business@dify.ai?subject=[GitHub]Business%20License%20Inquiry)討論企業需求。
+ 我們提供額外的企業中心功能。[通過這個聊天機器人記錄您的問題](https://udify.app/chat/22L1zSxg6yW1cWQg)或[發送電子郵件給我們](mailto:business@dify.ai?subject=%5BGitHub%5DBusiness%20License%20Inquiry)討論企業需求。
+ > 對於使用 AWS 的初創企業和小型企業,請查看 [AWS Marketplace 上的 Dify Premium](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6),並一鍵部署到您自己的 AWS VPC。這是一個經濟實惠的 AMI 產品,可選擇使用自定義徽標和品牌創建應用。 ## 保持領先 @@ -223,7 +156,8 @@ Dify 的所有功能都提供相應的 API,因此您可以輕鬆地將 Dify ### AWS -- [由 @KevinZhao 提供的 AWS CDK](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) +- [由 @KevinZhao 提供的 AWS CDK (EKS based)](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) +- [由 @tmokmss 提供的 AWS CDK (ECS based)](https://github.com/aws-samples/dify-self-hosted-on-aws) #### 使用 阿里云计算巢進行部署 @@ -233,13 +167,16 @@ Dify 的所有功能都提供相應的 API,因此您可以輕鬆地將 Dify 透過 [阿里雲數據管理DMS](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/),一鍵將 Dify 部署至阿里雲 +#### 使用 Azure Devops Pipeline 部署到AKS + +使用[Azure Devops Pipeline Helm Chart by @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS) 將 Dify 一鍵部署到 AKS ## 貢獻 對於想要貢獻程式碼的開發者,請參閱我們的[貢獻指南](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)。 同時,也請考慮透過在社群媒體和各種活動與會議上分享 Dify 來支持我們。 -> 我們正在尋找貢獻者協助將 Dify 翻譯成中文和英文以外的語言。如果您有興趣幫忙,請查看 [i18n README](https://github.com/langgenius/dify/blob/main/web/i18n/README.md) 獲取更多資訊,並在我們的 [Discord 社群伺服器](https://discord.gg/8Tpq4AcN9c) 的 `global-users` 頻道留言給我們。 +> 我們正在尋找貢獻者協助將 Dify 翻譯成中文和英文以外的語言。如果您有興趣幫忙,請查看 [i18n README](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md) 獲取更多資訊,並在我們的 [Discord 社群伺服器](https://discord.gg/8Tpq4AcN9c) 的 `global-users` 頻道留言給我們。 ## 社群與聯絡方式 diff --git a/README_VI.md b/README_VI.md index 51314e6de5..f161b20f9d 100644 --- a/README_VI.md +++ b/README_VI.md @@ -48,115 +48,45 @@ README in বাংলা

- Dify là một nền tảng phát triển ứng dụng LLM mã nguồn mở. Giao diện trực quan kết hợp quy trình làm việc AI, mô hình RAG, khả năng tác nhân, quản lý mô hình, tính năng quan sát và hơn thế nữa, cho phép bạn nhanh chóng chuyển từ nguyên mẫu sang sản phẩm. Đây là danh sách các tính năng cốt lõi:

-**1. Quy trình làm việc**: - Xây dựng và kiểm tra các quy trình làm việc AI mạnh mẽ trên một canvas trực quan, tận dụng tất cả các tính năng sau đây và hơn thế nữa. +**1. Quy trình làm việc**: +Xây dựng và kiểm tra các quy trình làm việc AI mạnh mẽ trên một canvas trực quan, tận dụng tất cả các tính năng sau đây và hơn thế nữa. -**2. Hỗ trợ mô hình toàn diện**: - Tích hợp liền mạch với hàng trăm mô hình LLM độc quyền / mã nguồn mở từ hàng chục nhà cung cấp suy luận và giải pháp tự lưu trữ, bao gồm GPT, Mistral, Llama3, và bất kỳ mô hình tương thích API OpenAI nào. Danh sách đầy đủ các nhà cung cấp mô hình được hỗ trợ có thể được tìm thấy [tại đây](https://docs.dify.ai/getting-started/readme/model-providers). +**2. Hỗ trợ mô hình toàn diện**: +Tích hợp liền mạch với hàng trăm mô hình LLM độc quyền / mã nguồn mở từ hàng chục nhà cung cấp suy luận và giải pháp tự lưu trữ, bao gồm GPT, Mistral, Llama3, và bất kỳ mô hình tương thích API OpenAI nào. Danh sách đầy đủ các nhà cung cấp mô hình được hỗ trợ có thể được tìm thấy [tại đây](https://docs.dify.ai/getting-started/readme/model-providers). ![providers-v5](https://github.com/langgenius/dify/assets/13230914/5a17bdbe-097a-4100-8363-40255b70f6e3) +**3. IDE Prompt**: +Giao diện trực quan để tạo prompt, so sánh hiệu suất mô hình và thêm các tính năng bổ sung như chuyển văn bản thành giọng nói cho một ứng dụng dựa trên trò chuyện. -**3. IDE Prompt**: - Giao diện trực quan để tạo prompt, so sánh hiệu suất mô hình và thêm các tính năng bổ sung như chuyển văn bản thành giọng nói cho một ứng dụng dựa trên trò chuyện. +**4. Mô hình RAG**: +Khả năng RAG mở rộng bao gồm mọi thứ từ nhập tài liệu đến truy xuất, với hỗ trợ sẵn có cho việc trích xuất văn bản từ PDF, PPT và các định dạng tài liệu phổ biến khác. -**4. Mô hình RAG**: - Khả năng RAG mở rộng bao gồm mọi thứ từ nhập tài liệu đến truy xuất, với hỗ trợ sẵn có cho việc trích xuất văn bản từ PDF, PPT và các định dạng tài liệu phổ biến khác. +**5. Khả năng tác nhân**: +Bạn có thể định nghĩa các tác nhân dựa trên LLM Function Calling hoặc ReAct, và thêm các công cụ được xây dựng sẵn hoặc tùy chỉnh cho tác nhân. Dify cung cấp hơn 50 công cụ tích hợp sẵn cho các tác nhân AI, như Google Search, DALL·E, Stable Diffusion và WolframAlpha. -**5. Khả năng tác nhân**: - Bạn có thể định nghĩa các tác nhân dựa trên LLM Function Calling hoặc ReAct, và thêm các công cụ được xây dựng sẵn hoặc tùy chỉnh cho tác nhân. Dify cung cấp hơn 50 công cụ tích hợp sẵn cho các tác nhân AI, như Google Search, DALL·E, Stable Diffusion và WolframAlpha. +**6. LLMOps**: +Giám sát và phân tích nhật ký và hiệu suất ứng dụng theo thời gian. Bạn có thể liên tục cải thiện prompt, bộ dữ liệu và mô hình dựa trên dữ liệu sản xuất và chú thích. -**6. LLMOps**: - Giám sát và phân tích nhật ký và hiệu suất ứng dụng theo thời gian. Bạn có thể liên tục cải thiện prompt, bộ dữ liệu và mô hình dựa trên dữ liệu sản xuất và chú thích. - -**7. Backend-as-a-Service**: - Tất cả các dịch vụ của Dify đều đi kèm với các API tương ứng, vì vậy bạn có thể dễ dàng tích hợp Dify vào logic kinh doanh của riêng mình. - - -## So sánh tính năng - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
Tính năngDify.AILangChainFlowiseOpenAI Assistants API
Phương pháp lập trìnhHướng API + Ứng dụngMã PythonHướng ứng dụngHướng API
LLMs được hỗ trợĐa dạng phong phúĐa dạng phong phúĐa dạng phong phúChỉ OpenAI
RAG Engine
Agent
Quy trình làm việc
Khả năng quan sát
Tính năng doanh nghiệp (SSO/Kiểm soát truy cập)
Triển khai cục bộ
+**7. Backend-as-a-Service**: +Tất cả các dịch vụ của Dify đều đi kèm với các API tương ứng, vì vậy bạn có thể dễ dàng tích hợp Dify vào logic kinh doanh của riêng mình. ## Sử dụng Dify - **Cloud
** -Chúng tôi lưu trữ dịch vụ [Dify Cloud](https://dify.ai) cho bất kỳ ai muốn thử mà không cần cài đặt. Nó cung cấp tất cả các khả năng của phiên bản tự triển khai và bao gồm 200 lượt gọi GPT-4 miễn phí trong gói sandbox. + Chúng tôi lưu trữ dịch vụ [Dify Cloud](https://dify.ai) cho bất kỳ ai muốn thử mà không cần cài đặt. Nó cung cấp tất cả các khả năng của phiên bản tự triển khai và bao gồm 200 lượt gọi GPT-4 miễn phí trong gói sandbox. - **Tự triển khai Dify Community Edition
** -Nhanh chóng chạy Dify trong môi trường của bạn với [hướng dẫn bắt đầu](#quick-start) này. -Sử dụng [tài liệu](https://docs.dify.ai) của chúng tôi để tham khảo thêm và nhận hướng dẫn chi tiết hơn. + Nhanh chóng chạy Dify trong môi trường của bạn với [hướng dẫn bắt đầu](#quick-start) này. + Sử dụng [tài liệu](https://docs.dify.ai) của chúng tôi để tham khảo thêm và nhận hướng dẫn chi tiết hơn. - **Dify cho doanh nghiệp / tổ chức
** -Chúng tôi cung cấp các tính năng bổ sung tập trung vào doanh nghiệp. [Ghi lại câu hỏi của bạn cho chúng tôi thông qua chatbot này](https://udify.app/chat/22L1zSxg6yW1cWQg) hoặc [gửi email cho chúng tôi](mailto:business@dify.ai?subject=[GitHub]Business%20License%20Inquiry) để thảo luận về nhu cầu doanh nghiệp.
- > Đối với các công ty khởi nghiệp và doanh nghiệp nhỏ sử dụng AWS, hãy xem [Dify Premium trên AWS Marketplace](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6) và triển khai nó vào AWS VPC của riêng bạn chỉ với một cú nhấp chuột. Đây là một AMI giá cả phải chăng với tùy chọn tạo ứng dụng với logo và thương hiệu tùy chỉnh. + Chúng tôi cung cấp các tính năng bổ sung tập trung vào doanh nghiệp. [Ghi lại câu hỏi của bạn cho chúng tôi thông qua chatbot này](https://udify.app/chat/22L1zSxg6yW1cWQg) hoặc [gửi email cho chúng tôi](mailto:business@dify.ai?subject=%5BGitHub%5DBusiness%20License%20Inquiry) để thảo luận về nhu cầu doanh nghiệp.
+ > Đối với các công ty khởi nghiệp và doanh nghiệp nhỏ sử dụng AWS, hãy xem [Dify Premium trên AWS Marketplace](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6) và triển khai nó vào AWS VPC của riêng bạn chỉ với một cú nhấp chuột. Đây là một AMI giá cả phải chăng với tùy chọn tạo ứng dụng với logo và thương hiệu tùy chỉnh. ## Luôn cập nhật @@ -164,13 +94,12 @@ Yêu thích Dify trên GitHub và được thông báo ngay lập tức về cá ![star-us](https://github.com/langgenius/dify/assets/13230914/b823edc1-6388-4e25-ad45-2f6b187adbb4) - - ## Bắt đầu nhanh + > Trước khi cài đặt Dify, hãy đảm bảo máy của bạn đáp ứng các yêu cầu hệ thống tối thiểu sau: -> ->- CPU >= 2 Core ->- RAM >= 4GB +> +> - CPU >= 2 Core +> - RAM >= 4GB
@@ -203,18 +132,21 @@ Nếu bạn muốn cấu hình một cài đặt có độ sẵn sàng cao, có Triển khai Dify lên nền tảng đám mây với một cú nhấp chuột bằng cách sử dụng [terraform](https://www.terraform.io/) ##### Azure Global + - [Azure Terraform bởi @nikawang](https://github.com/nikawang/dify-azure-terraform) ##### Google Cloud + - [Google Cloud Terraform bởi @sotazum](https://github.com/DeNA/dify-google-cloud-terraform) #### Sử dụng AWS CDK để Triển khai Triển khai Dify trên AWS bằng [CDK](https://aws.amazon.com/cdk/) -##### AWS -- [AWS CDK bởi @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) +##### AWS +- [AWS CDK bởi @KevinZhao (EKS based)](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) +- [AWS CDK bởi @tmokmss (ECS based)](https://github.com/aws-samples/dify-self-hosted-on-aws) #### Alibaba Cloud @@ -224,14 +156,16 @@ Triển khai Dify trên AWS bằng [CDK](https://aws.amazon.com/cdk/) Triển khai Dify lên Alibaba Cloud chỉ với một cú nhấp chuột bằng [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/) +#### Sử dụng Azure Devops Pipeline để Triển khai lên AKS + +Triển khai Dify lên AKS chỉ với một cú nhấp chuột bằng [Azure Devops Pipeline Helm Chart bởi @LeoZhang](https://github.com/Ruiruiz30/Dify-helm-chart-AKS) ## Đóng góp -Đối với những người muốn đóng góp mã, xem [Hướng dẫn Đóng góp](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) của chúng tôi. +Đối với những người muốn đóng góp mã, xem [Hướng dẫn Đóng góp](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) của chúng tôi. Đồng thời, vui lòng xem xét hỗ trợ Dify bằng cách chia sẻ nó trên mạng xã hội và tại các sự kiện và hội nghị. - -> Chúng tôi đang tìm kiếm người đóng góp để giúp dịch Dify sang các ngôn ngữ khác ngoài tiếng Trung hoặc tiếng Anh. Nếu bạn quan tâm đến việc giúp đỡ, vui lòng xem [README i18n](https://github.com/langgenius/dify/blob/main/web/i18n/README.md) để biết thêm thông tin và để lại bình luận cho chúng tôi trong kênh `global-users` của [Máy chủ Cộng đồng Discord](https://discord.gg/8Tpq4AcN9c) của chúng tôi. +> Chúng tôi đang tìm kiếm người đóng góp để giúp dịch Dify sang các ngôn ngữ khác ngoài tiếng Trung hoặc tiếng Anh. Nếu bạn quan tâm đến việc giúp đỡ, vui lòng xem [README i18n](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md) để biết thêm thông tin và để lại bình luận cho chúng tôi trong kênh `global-users` của [Máy chủ Cộng đồng Discord](https://discord.gg/8Tpq4AcN9c) của chúng tôi. **Người đóng góp** @@ -241,10 +175,10 @@ Triển khai Dify lên Alibaba Cloud chỉ với một cú nhấp chuột bằng ## Cộng đồng & liên hệ -* [Thảo luận GitHub](https://github.com/langgenius/dify/discussions). Tốt nhất cho: chia sẻ phản hồi và đặt câu hỏi. -* [Vấn đề GitHub](https://github.com/langgenius/dify/issues). Tốt nhất cho: lỗi bạn gặp phải khi sử dụng Dify.AI và đề xuất tính năng. Xem [Hướng dẫn Đóng góp](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) của chúng tôi. -* [Discord](https://discord.gg/FngNHpbcY7). Tốt nhất cho: chia sẻ ứng dụng của bạn và giao lưu với cộng đồng. -* [X(Twitter)](https://twitter.com/dify_ai). Tốt nhất cho: chia sẻ ứng dụng của bạn và giao lưu với cộng đồng. +- [Thảo luận GitHub](https://github.com/langgenius/dify/discussions). Tốt nhất cho: chia sẻ phản hồi và đặt câu hỏi. +- [Vấn đề GitHub](https://github.com/langgenius/dify/issues). Tốt nhất cho: lỗi bạn gặp phải khi sử dụng Dify.AI và đề xuất tính năng. Xem [Hướng dẫn Đóng góp](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) của chúng tôi. +- [Discord](https://discord.gg/FngNHpbcY7). Tốt nhất cho: chia sẻ ứng dụng của bạn và giao lưu với cộng đồng. +- [X(Twitter)](https://twitter.com/dify_ai). Tốt nhất cho: chia sẻ ứng dụng của bạn và giao lưu với cộng đồng. ## Lịch sử Yêu thích diff --git a/api/.env.example b/api/.env.example index 80b1c12cd8..3052dbfe2b 100644 --- a/api/.env.example +++ b/api/.env.example @@ -4,6 +4,11 @@ # Alternatively you can set it with `SECRET_KEY` environment variable. SECRET_KEY= +# Ensure UTF-8 encoding +LANG=en_US.UTF-8 +LC_ALL=en_US.UTF-8 +PYTHONIOENCODING=utf-8 + # Console API base URL CONSOLE_API_URL=http://localhost:5001 CONSOLE_WEB_URL=http://localhost:3000 @@ -37,6 +42,15 @@ REDIS_PORT=6379 REDIS_USERNAME= REDIS_PASSWORD=difyai123456 REDIS_USE_SSL=false +# SSL configuration for Redis (when REDIS_USE_SSL=true) +REDIS_SSL_CERT_REQS=CERT_NONE +# Options: CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED +REDIS_SSL_CA_CERTS= +# Path to CA certificate file for SSL verification +REDIS_SSL_CERTFILE= +# Path to client certificate file for SSL authentication +REDIS_SSL_KEYFILE= +# Path to client private key file for SSL authentication REDIS_DB=0 # redis Sentinel configuration. @@ -227,6 +241,7 @@ TABLESTORE_ENDPOINT=https://instance-name.cn-hangzhou.ots.aliyuncs.com TABLESTORE_INSTANCE_NAME=instance-name TABLESTORE_ACCESS_KEY_ID=xxx TABLESTORE_ACCESS_KEY_SECRET=xxx +TABLESTORE_NORMALIZE_FULLTEXT_BM25_SCORE=false # Tidb Vector configuration TIDB_VECTOR_HOST=xxx.eu-central-1.xxx.aws.tidbcloud.com @@ -463,6 +478,13 @@ API_WORKFLOW_NODE_EXECUTION_REPOSITORY=repositories.sqlalchemy_api_workflow_node # API workflow run repository implementation API_WORKFLOW_RUN_REPOSITORY=repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository +# Workflow log cleanup configuration +# Enable automatic cleanup of workflow run logs to manage database size +WORKFLOW_LOG_CLEANUP_ENABLED=true +# Number of days to retain workflow run logs (default: 30 days) +WORKFLOW_LOG_RETENTION_DAYS=30 +# Batch size for workflow log cleanup operations (default: 100) +WORKFLOW_LOG_CLEANUP_BATCH_SIZE=100 # App configuration APP_MAX_EXECUTION_TIME=1200 diff --git a/api/.ruff.toml b/api/.ruff.toml index 0169613bf8..db6872b9c8 100644 --- a/api/.ruff.toml +++ b/api/.ruff.toml @@ -42,6 +42,8 @@ select = [ "S301", # suspicious-pickle-usage, disallow use of `pickle` and its wrappers. "S302", # suspicious-marshal-usage, disallow use of `marshal` module "S311", # suspicious-non-cryptographic-random-usage + "G001", # don't use str format to logging messages + "G004", # don't use f-strings to format logging messages ] ignore = [ diff --git a/api/Dockerfile b/api/Dockerfile index 8c7a1717b9..79a4892768 100644 --- a/api/Dockerfile +++ b/api/Dockerfile @@ -4,7 +4,7 @@ FROM python:3.12-slim-bookworm AS base WORKDIR /app/api # Install uv -ENV UV_VERSION=0.7.11 +ENV UV_VERSION=0.8.9 RUN pip install --no-cache-dir uv==${UV_VERSION} @@ -19,7 +19,7 @@ RUN apt-get update \ # Install Python dependencies COPY pyproject.toml uv.lock ./ -RUN uv sync --locked +RUN uv sync --locked --no-dev # production stage FROM base AS production @@ -37,6 +37,11 @@ EXPOSE 5001 # set timezone ENV TZ=UTC +# Set UTF-8 locale +ENV LANG=en_US.UTF-8 +ENV LC_ALL=en_US.UTF-8 +ENV PYTHONIOENCODING=utf-8 + WORKDIR /app/api RUN \ diff --git a/api/README.md b/api/README.md index 6ab923070e..8309a0e69b 100644 --- a/api/README.md +++ b/api/README.md @@ -3,7 +3,7 @@ ## Usage > [!IMPORTANT] -> +> > In the v1.3.0 release, `poetry` has been replaced with > [`uv`](https://docs.astral.sh/uv/) as the package manager > for Dify API backend service. @@ -20,25 +20,29 @@ cd ../api ``` -2. Copy `.env.example` to `.env` +1. Copy `.env.example` to `.env` ```cli - cp .env.example .env + cp .env.example .env ``` -3. Generate a `SECRET_KEY` in the `.env` file. + +1. Generate a `SECRET_KEY` in the `.env` file. bash for Linux + ```bash for Linux sed -i "/^SECRET_KEY=/c\SECRET_KEY=$(openssl rand -base64 42)" .env ``` + bash for Mac + ```bash for Mac secret_key=$(openssl rand -base64 42) sed -i '' "/^SECRET_KEY=/c\\ SECRET_KEY=${secret_key}" .env ``` -4. Create environment. +1. Create environment. Dify API service uses [UV](https://docs.astral.sh/uv/) to manage dependencies. First, you need to add the uv package manager, if you don't have it already. @@ -49,13 +53,13 @@ brew install uv ``` -5. Install dependencies +1. Install dependencies ```bash uv sync --dev ``` -6. Run migrate +1. Run migrate Before the first launch, migrate the database to the latest version. @@ -63,24 +67,27 @@ uv run flask db upgrade ``` -7. Start backend +1. Start backend ```bash uv run flask run --host 0.0.0.0 --port=5001 --debug ``` -8. Start Dify [web](../web) service. -9. Setup your application by visiting `http://localhost:3000`. -10. If you need to handle and debug the async tasks (e.g. dataset importing and documents indexing), please start the worker service. +1. Start Dify [web](../web) service. - ```bash - uv run celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin - ``` +1. Setup your application by visiting `http://localhost:3000`. - Addition, if you want to debug the celery scheduled tasks, you can use the following command in another terminal: - ```bash - uv run celery -A app.celery beat - ``` +1. If you need to handle and debug the async tasks (e.g. dataset importing and documents indexing), please start the worker service. + +```bash +uv run celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation +``` + +Addition, if you want to debug the celery scheduled tasks, you can use the following command in another terminal: + +```bash +uv run celery -A app.celery beat +``` ## Testing @@ -90,9 +97,16 @@ uv sync --dev ``` -2. Run the tests locally with mocked system environment variables in `tool.pytest_env` section in `pyproject.toml` +1. Run the tests locally with mocked system environment variables in `tool.pytest_env` section in `pyproject.toml`, more can check [Claude.md](../CLAUDE.md) - ```bash - uv run -P api bash dev/pytest/pytest_all_tests.sh + ```cli + uv run --project api pytest # Run all tests + uv run --project api pytest tests/unit_tests/ # Unit tests only + uv run --project api pytest tests/integration_tests/ # Integration tests + + # Code quality + ./dev/reformat # Run all formatters and linters + uv run --project api ruff check --fix ./ # Fix linting issues + uv run --project api ruff format ./ # Format code + uv run --project api mypy . # Type checking ``` - diff --git a/api/app_factory.py b/api/app_factory.py index 3a258be28f..8a0417dd72 100644 --- a/api/app_factory.py +++ b/api/app_factory.py @@ -5,6 +5,8 @@ from configs import dify_config from contexts.wrapper import RecyclableContextVar from dify_app import DifyApp +logger = logging.getLogger(__name__) + # ---------------------------- # Application Factory Function @@ -32,7 +34,7 @@ def create_app() -> DifyApp: initialize_extensions(app) end_time = time.perf_counter() if dify_config.DEBUG: - logging.info(f"Finished create_app ({round((end_time - start_time) * 1000, 2)} ms)") + logger.info("Finished create_app (%s ms)", round((end_time - start_time) * 1000, 2)) return app @@ -51,6 +53,7 @@ def initialize_extensions(app: DifyApp): ext_login, ext_mail, ext_migrate, + ext_orjson, ext_otel, ext_proxy_fix, ext_redis, @@ -67,6 +70,7 @@ def initialize_extensions(app: DifyApp): ext_logging, ext_warnings, ext_import_modules, + ext_orjson, ext_set_secretkey, ext_compress, ext_code_based_extension, @@ -91,14 +95,14 @@ def initialize_extensions(app: DifyApp): is_enabled = ext.is_enabled() if hasattr(ext, "is_enabled") else True if not is_enabled: if dify_config.DEBUG: - logging.info(f"Skipped {short_name}") + logger.info("Skipped %s", short_name) continue start_time = time.perf_counter() ext.init_app(app) end_time = time.perf_counter() if dify_config.DEBUG: - logging.info(f"Loaded {short_name} ({round((end_time - start_time) * 1000, 2)} ms)") + logger.info("Loaded %s (%s ms)", short_name, round((end_time - start_time) * 1000, 2)) def create_migrations_app(): diff --git a/api/commands.py b/api/commands.py index c2e62ec261..6b38e34b9b 100644 --- a/api/commands.py +++ b/api/commands.py @@ -5,10 +5,11 @@ import secrets from typing import Any, Optional import click +import sqlalchemy as sa from flask import current_app from pydantic import TypeAdapter from sqlalchemy import select -from werkzeug.exceptions import NotFound +from sqlalchemy.exc import SQLAlchemyError from configs import dify_config from constants.languages import languages @@ -35,6 +36,7 @@ from services.account_service import AccountService, RegisterService, TenantServ from services.clear_free_plan_tenant_expired_logs import ClearFreePlanTenantExpiredLogs from services.plugin.data_migration import PluginDataMigration from services.plugin.plugin_migration import PluginMigration +from tasks.remove_app_and_related_data_task import delete_draft_variables_batch @click.command("reset-password", help="Reset the account password.") @@ -53,13 +55,13 @@ def reset_password(email, new_password, password_confirm): account = db.session.query(Account).where(Account.email == email).one_or_none() if not account: - click.echo(click.style("Account not found for email: {}".format(email), fg="red")) + click.echo(click.style(f"Account not found for email: {email}", fg="red")) return try: valid_password(new_password) except: - click.echo(click.style("Invalid password. Must match {}".format(password_pattern), fg="red")) + click.echo(click.style(f"Invalid password. Must match {password_pattern}", fg="red")) return # generate password salt @@ -92,13 +94,13 @@ def reset_email(email, new_email, email_confirm): account = db.session.query(Account).where(Account.email == email).one_or_none() if not account: - click.echo(click.style("Account not found for email: {}".format(email), fg="red")) + click.echo(click.style(f"Account not found for email: {email}", fg="red")) return try: email_validate(new_email) except: - click.echo(click.style("Invalid email: {}".format(new_email), fg="red")) + click.echo(click.style(f"Invalid email: {new_email}", fg="red")) return account.email = new_email @@ -142,7 +144,7 @@ def reset_encrypt_key_pair(): click.echo( click.style( - "Congratulations! The asymmetric key pair of workspace {} has been reset.".format(tenant.id), + f"Congratulations! The asymmetric key pair of workspace {tenant.id} has been reset.", fg="green", ) ) @@ -180,8 +182,8 @@ def migrate_annotation_vector_database(): ) if not apps: break - except NotFound: - break + except SQLAlchemyError: + raise page += 1 for app in apps: @@ -190,14 +192,14 @@ def migrate_annotation_vector_database(): f"Processing the {total_count} app {app.id}. " + f"{create_count} created, {skipped_count} skipped." ) try: - click.echo("Creating app annotation index: {}".format(app.id)) + click.echo(f"Creating app annotation index: {app.id}") app_annotation_setting = ( db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app.id).first() ) if not app_annotation_setting: skipped_count = skipped_count + 1 - click.echo("App annotation setting disabled: {}".format(app.id)) + click.echo(f"App annotation setting disabled: {app.id}") continue # get dataset_collection_binding info dataset_collection_binding = ( @@ -206,7 +208,7 @@ def migrate_annotation_vector_database(): .first() ) if not dataset_collection_binding: - click.echo("App annotation collection binding not found: {}".format(app.id)) + click.echo(f"App annotation collection binding not found: {app.id}") continue annotations = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app.id).all() dataset = Dataset( @@ -252,9 +254,7 @@ def migrate_annotation_vector_database(): create_count += 1 except Exception as e: click.echo( - click.style( - "Error creating app annotation index: {} {}".format(e.__class__.__name__, str(e)), fg="red" - ) + click.style(f"Error creating app annotation index: {e.__class__.__name__} {str(e)}", fg="red") ) continue @@ -309,8 +309,8 @@ def migrate_knowledge_vector_database(): ) datasets = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False) - except NotFound: - break + except SQLAlchemyError: + raise page += 1 for dataset in datasets: @@ -319,7 +319,7 @@ def migrate_knowledge_vector_database(): f"Processing the {total_count} dataset {dataset.id}. {create_count} created, {skipped_count} skipped." ) try: - click.echo("Creating dataset vector database index: {}".format(dataset.id)) + click.echo(f"Creating dataset vector database index: {dataset.id}") if dataset.index_struct_dict: if dataset.index_struct_dict["type"] == vector_type: skipped_count = skipped_count + 1 @@ -423,9 +423,7 @@ def migrate_knowledge_vector_database(): create_count += 1 except Exception as e: db.session.rollback() - click.echo( - click.style("Error creating dataset index: {} {}".format(e.__class__.__name__, str(e)), fg="red") - ) + click.echo(click.style(f"Error creating dataset index: {e.__class__.__name__} {str(e)}", fg="red")) continue click.echo( @@ -461,7 +459,7 @@ def convert_to_agent_apps(): """ with db.engine.begin() as conn: - rs = conn.execute(db.text(sql_query)) + rs = conn.execute(sa.text(sql_query)) apps = [] for i in rs: @@ -476,7 +474,7 @@ def convert_to_agent_apps(): break for app in apps: - click.echo("Converting app: {}".format(app.id)) + click.echo(f"Converting app: {app.id}") try: app.mode = AppMode.AGENT_CHAT.value @@ -488,11 +486,11 @@ def convert_to_agent_apps(): ) db.session.commit() - click.echo(click.style("Converted app: {}".format(app.id), fg="green")) + click.echo(click.style(f"Converted app: {app.id}", fg="green")) except Exception as e: - click.echo(click.style("Convert app error: {} {}".format(e.__class__.__name__, str(e)), fg="red")) + click.echo(click.style(f"Convert app error: {e.__class__.__name__} {str(e)}", fg="red")) - click.echo(click.style("Conversion complete. Converted {} agent apps.".format(len(proceeded_app_ids)), fg="green")) + click.echo(click.style(f"Conversion complete. Converted {len(proceeded_app_ids)} agent apps.", fg="green")) @click.command("add-qdrant-index", help="Add Qdrant index.") @@ -564,8 +562,8 @@ def old_metadata_migration(): .order_by(DatasetDocument.created_at.desc()) ) documents = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False) - except NotFound: - break + except SQLAlchemyError: + raise if not documents: break for document in documents: @@ -665,7 +663,7 @@ def create_tenant(email: str, language: Optional[str] = None, name: Optional[str click.echo( click.style( - "Account and tenant created.\nAccount: {}\nPassword: {}".format(email, new_password), + f"Account and tenant created.\nAccount: {email}\nPassword: {new_password}", fg="green", ) ) @@ -706,7 +704,7 @@ def fix_app_site_missing(): sql = """select apps.id as id from apps left join sites on sites.app_id=apps.id where sites.id is null limit 1000""" with db.engine.begin() as conn: - rs = conn.execute(db.text(sql)) + rs = conn.execute(sa.text(sql)) processed_count = 0 for i in rs: @@ -726,16 +724,16 @@ where sites.id is null limit 1000""" if tenant: accounts = tenant.get_accounts() if not accounts: - print("Fix failed for app {}".format(app.id)) + print(f"Fix failed for app {app.id}") continue account = accounts[0] - print("Fixing missing site for app {}".format(app.id)) + print(f"Fixing missing site for app {app.id}") app_was_created.send(app, account=account) except Exception: failed_app_ids.append(app_id) - click.echo(click.style("Failed to fix missing site for app {}".format(app_id), fg="red")) - logging.exception(f"Failed to fix app related site missing issue, app_id: {app_id}") + click.echo(click.style(f"Failed to fix missing site for app {app_id}", fg="red")) + logging.exception("Failed to fix app related site missing issue, app_id: %s", app_id) continue if not processed_count: @@ -920,7 +918,7 @@ def clear_orphaned_file_records(force: bool): ) orphaned_message_files = [] with db.engine.begin() as conn: - rs = conn.execute(db.text(query)) + rs = conn.execute(sa.text(query)) for i in rs: orphaned_message_files.append({"id": str(i[0]), "message_id": str(i[1])}) @@ -941,7 +939,7 @@ def clear_orphaned_file_records(force: bool): click.echo(click.style("- Deleting orphaned message_files records", fg="white")) query = "DELETE FROM message_files WHERE id IN :ids" with db.engine.begin() as conn: - conn.execute(db.text(query), {"ids": tuple([record["id"] for record in orphaned_message_files])}) + conn.execute(sa.text(query), {"ids": tuple([record["id"] for record in orphaned_message_files])}) click.echo( click.style(f"Removed {len(orphaned_message_files)} orphaned message_files records.", fg="green") ) @@ -958,7 +956,7 @@ def clear_orphaned_file_records(force: bool): click.echo(click.style(f"- Listing file records in table {files_table['table']}", fg="white")) query = f"SELECT {files_table['id_column']}, {files_table['key_column']} FROM {files_table['table']}" with db.engine.begin() as conn: - rs = conn.execute(db.text(query)) + rs = conn.execute(sa.text(query)) for i in rs: all_files_in_tables.append({"table": files_table["table"], "id": str(i[0]), "key": i[1]}) click.echo(click.style(f"Found {len(all_files_in_tables)} files in tables.", fg="white")) @@ -978,7 +976,7 @@ def clear_orphaned_file_records(force: bool): f"SELECT {ids_table['column']} FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL" ) with db.engine.begin() as conn: - rs = conn.execute(db.text(query)) + rs = conn.execute(sa.text(query)) for i in rs: all_ids_in_tables.append({"table": ids_table["table"], "id": str(i[0])}) elif ids_table["type"] == "text": @@ -993,7 +991,7 @@ def clear_orphaned_file_records(force: bool): f"FROM {ids_table['table']}" ) with db.engine.begin() as conn: - rs = conn.execute(db.text(query)) + rs = conn.execute(sa.text(query)) for i in rs: for j in i[0]: all_ids_in_tables.append({"table": ids_table["table"], "id": j}) @@ -1012,7 +1010,7 @@ def clear_orphaned_file_records(force: bool): f"FROM {ids_table['table']}" ) with db.engine.begin() as conn: - rs = conn.execute(db.text(query)) + rs = conn.execute(sa.text(query)) for i in rs: for j in i[0]: all_ids_in_tables.append({"table": ids_table["table"], "id": j}) @@ -1041,7 +1039,7 @@ def clear_orphaned_file_records(force: bool): click.echo(click.style(f"- Deleting orphaned file records in table {files_table['table']}", fg="white")) query = f"DELETE FROM {files_table['table']} WHERE {files_table['id_column']} IN :ids" with db.engine.begin() as conn: - conn.execute(db.text(query), {"ids": tuple(orphaned_files)}) + conn.execute(sa.text(query), {"ids": tuple(orphaned_files)}) except Exception as e: click.echo(click.style(f"Error deleting orphaned file records: {str(e)}", fg="red")) return @@ -1111,7 +1109,7 @@ def remove_orphaned_files_on_storage(force: bool): click.echo(click.style(f"- Listing files from table {files_table['table']}", fg="white")) query = f"SELECT {files_table['key_column']} FROM {files_table['table']}" with db.engine.begin() as conn: - rs = conn.execute(db.text(query)) + rs = conn.execute(sa.text(query)) for i in rs: all_files_in_tables.append(str(i[0])) click.echo(click.style(f"Found {len(all_files_in_tables)} files in tables.", fg="white")) @@ -1205,3 +1203,138 @@ def setup_system_tool_oauth_client(provider, client_params): db.session.add(oauth_client) db.session.commit() click.echo(click.style(f"OAuth client params setup successfully. id: {oauth_client.id}", fg="green")) + + +def _find_orphaned_draft_variables(batch_size: int = 1000) -> list[str]: + """ + Find draft variables that reference non-existent apps. + + Args: + batch_size: Maximum number of orphaned app IDs to return + + Returns: + List of app IDs that have draft variables but don't exist in the apps table + """ + query = """ + SELECT DISTINCT wdv.app_id + FROM workflow_draft_variables AS wdv + WHERE NOT EXISTS( + SELECT 1 FROM apps WHERE apps.id = wdv.app_id + ) + LIMIT :batch_size + """ + + with db.engine.connect() as conn: + result = conn.execute(sa.text(query), {"batch_size": batch_size}) + return [row[0] for row in result] + + +def _count_orphaned_draft_variables() -> dict[str, Any]: + """ + Count orphaned draft variables by app. + + Returns: + Dictionary with statistics about orphaned variables + """ + query = """ + SELECT + wdv.app_id, + COUNT(*) as variable_count + FROM workflow_draft_variables AS wdv + WHERE NOT EXISTS( + SELECT 1 FROM apps WHERE apps.id = wdv.app_id + ) + GROUP BY wdv.app_id + ORDER BY variable_count DESC + """ + + with db.engine.connect() as conn: + result = conn.execute(sa.text(query)) + orphaned_by_app = {row[0]: row[1] for row in result} + + total_orphaned = sum(orphaned_by_app.values()) + app_count = len(orphaned_by_app) + + return { + "total_orphaned_variables": total_orphaned, + "orphaned_app_count": app_count, + "orphaned_by_app": orphaned_by_app, + } + + +@click.command() +@click.option("--dry-run", is_flag=True, help="Show what would be deleted without actually deleting") +@click.option("--batch-size", default=1000, help="Number of records to process per batch (default 1000)") +@click.option("--max-apps", default=None, type=int, help="Maximum number of apps to process (default: no limit)") +@click.option("-f", "--force", is_flag=True, help="Skip user confirmation and force the command to execute.") +def cleanup_orphaned_draft_variables( + dry_run: bool, + batch_size: int, + max_apps: int | None, + force: bool = False, +): + """ + Clean up orphaned draft variables from the database. + + This script finds and removes draft variables that belong to apps + that no longer exist in the database. + """ + logger = logging.getLogger(__name__) + + # Get statistics + stats = _count_orphaned_draft_variables() + + logger.info("Found %s orphaned draft variables", stats["total_orphaned_variables"]) + logger.info("Across %s non-existent apps", stats["orphaned_app_count"]) + + if stats["total_orphaned_variables"] == 0: + logger.info("No orphaned draft variables found. Exiting.") + return + + if dry_run: + logger.info("DRY RUN: Would delete the following:") + for app_id, count in sorted(stats["orphaned_by_app"].items(), key=lambda x: x[1], reverse=True)[ + :10 + ]: # Show top 10 + logger.info(" App %s: %s variables", app_id, count) + if len(stats["orphaned_by_app"]) > 10: + logger.info(" ... and %s more apps", len(stats["orphaned_by_app"]) - 10) + return + + # Confirm deletion + if not force: + click.confirm( + f"Are you sure you want to delete {stats['total_orphaned_variables']} " + f"orphaned draft variables from {stats['orphaned_app_count']} apps?", + abort=True, + ) + + total_deleted = 0 + processed_apps = 0 + + while True: + if max_apps and processed_apps >= max_apps: + logger.info("Reached maximum app limit (%s). Stopping.", max_apps) + break + + orphaned_app_ids = _find_orphaned_draft_variables(batch_size=10) + if not orphaned_app_ids: + logger.info("No more orphaned draft variables found.") + break + + for app_id in orphaned_app_ids: + if max_apps and processed_apps >= max_apps: + break + + try: + deleted_count = delete_draft_variables_batch(app_id, batch_size) + total_deleted += deleted_count + processed_apps += 1 + + logger.info("Deleted %s variables for app %s", deleted_count, app_id) + + except Exception: + logger.exception("Error processing app %s", app_id) + continue + + logger.info("Cleanup completed. Total deleted: %s variables across %s apps", total_deleted, processed_apps) diff --git a/api/configs/app_config.py b/api/configs/app_config.py index 20f8c40427..d3b1cf9d5b 100644 --- a/api/configs/app_config.py +++ b/api/configs/app_config.py @@ -41,7 +41,7 @@ class RemoteSettingsSourceFactory(PydanticBaseSettingsSource): case RemoteSettingsSourceName.NACOS: remote_source = NacosSettingsSource(current_state) case _: - logger.warning(f"Unsupported remote source: {remote_source_name}") + logger.warning("Unsupported remote source: %s", remote_source_name) return {} d: dict[str, Any] = {} diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index 9f1646ea7d..2bccc4b7a0 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -330,17 +330,17 @@ class HttpConfig(BaseSettings): def WEB_API_CORS_ALLOW_ORIGINS(self) -> list[str]: return self.inner_WEB_API_CORS_ALLOW_ORIGINS.split(",") - HTTP_REQUEST_MAX_CONNECT_TIMEOUT: Annotated[ - PositiveInt, Field(ge=10, description="Maximum connection timeout in seconds for HTTP requests") - ] = 10 + HTTP_REQUEST_MAX_CONNECT_TIMEOUT: int = Field( + ge=1, description="Maximum connection timeout in seconds for HTTP requests", default=10 + ) - HTTP_REQUEST_MAX_READ_TIMEOUT: Annotated[ - PositiveInt, Field(ge=60, description="Maximum read timeout in seconds for HTTP requests") - ] = 60 + HTTP_REQUEST_MAX_READ_TIMEOUT: int = Field( + ge=1, description="Maximum read timeout in seconds for HTTP requests", default=60 + ) - HTTP_REQUEST_MAX_WRITE_TIMEOUT: Annotated[ - PositiveInt, Field(ge=10, description="Maximum write timeout in seconds for HTTP requests") - ] = 20 + HTTP_REQUEST_MAX_WRITE_TIMEOUT: int = Field( + ge=1, description="Maximum write timeout in seconds for HTTP requests", default=20 + ) HTTP_REQUEST_NODE_MAX_BINARY_SIZE: PositiveInt = Field( description="Maximum allowed size in bytes for binary data in HTTP requests", @@ -552,12 +552,18 @@ class RepositoryConfig(BaseSettings): """ CORE_WORKFLOW_EXECUTION_REPOSITORY: str = Field( - description="Repository implementation for WorkflowExecution. Specify as a module path", + description="Repository implementation for WorkflowExecution. Options: " + "'core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository' (default), " + "'core.repositories.celery_workflow_execution_repository.CeleryWorkflowExecutionRepository'", default="core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository", ) CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY: str = Field( - description="Repository implementation for WorkflowNodeExecution. Specify as a module path", + description="Repository implementation for WorkflowNodeExecution. Options: " + "'core.repositories.sqlalchemy_workflow_node_execution_repository." + "SQLAlchemyWorkflowNodeExecutionRepository' (default), " + "'core.repositories.celery_workflow_node_execution_repository." + "CeleryWorkflowNodeExecutionRepository'", default="core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository", ) @@ -962,6 +968,14 @@ class AccountConfig(BaseSettings): ) +class WorkflowLogConfig(BaseSettings): + WORKFLOW_LOG_CLEANUP_ENABLED: bool = Field(default=True, description="Enable workflow run log cleanup") + WORKFLOW_LOG_RETENTION_DAYS: int = Field(default=30, description="Retention days for workflow run logs") + WORKFLOW_LOG_CLEANUP_BATCH_SIZE: int = Field( + default=100, description="Batch size for workflow run log cleanup operations" + ) + + class FeatureConfig( # place the configs in alphabet order AppExecutionConfig, @@ -997,5 +1011,6 @@ class FeatureConfig( HostedServiceConfig, CeleryBeatConfig, CeleryScheduleTasksConfig, + WorkflowLogConfig, ): pass diff --git a/api/configs/middleware/__init__.py b/api/configs/middleware/__init__.py index 587ea55ca7..ba8bbc7135 100644 --- a/api/configs/middleware/__init__.py +++ b/api/configs/middleware/__init__.py @@ -10,6 +10,7 @@ from .storage.aliyun_oss_storage_config import AliyunOSSStorageConfig from .storage.amazon_s3_storage_config import S3StorageConfig from .storage.azure_blob_storage_config import AzureBlobStorageConfig from .storage.baidu_obs_storage_config import BaiduOBSStorageConfig +from .storage.clickzetta_volume_storage_config import ClickZettaVolumeStorageConfig from .storage.google_cloud_storage_config import GoogleCloudStorageConfig from .storage.huawei_obs_storage_config import HuaweiCloudOBSStorageConfig from .storage.oci_storage_config import OCIStorageConfig @@ -20,6 +21,7 @@ from .storage.volcengine_tos_storage_config import VolcengineTOSStorageConfig from .vdb.analyticdb_config import AnalyticdbConfig from .vdb.baidu_vector_config import BaiduVectorDBConfig from .vdb.chroma_config import ChromaConfig +from .vdb.clickzetta_config import ClickzettaConfig from .vdb.couchbase_config import CouchbaseConfig from .vdb.elasticsearch_config import ElasticsearchConfig from .vdb.huawei_cloud_config import HuaweiCloudConfig @@ -52,6 +54,7 @@ class StorageConfig(BaseSettings): "aliyun-oss", "azure-blob", "baidu-obs", + "clickzetta-volume", "google-storage", "huawei-obs", "oci-storage", @@ -61,8 +64,9 @@ class StorageConfig(BaseSettings): "local", ] = Field( description="Type of storage to use." - " Options: 'opendal', '(deprecated) local', 's3', 'aliyun-oss', 'azure-blob', 'baidu-obs', 'google-storage', " - "'huawei-obs', 'oci-storage', 'tencent-cos', 'volcengine-tos', 'supabase'. Default is 'opendal'.", + " Options: 'opendal', '(deprecated) local', 's3', 'aliyun-oss', 'azure-blob', 'baidu-obs', " + "'clickzetta-volume', 'google-storage', 'huawei-obs', 'oci-storage', 'tencent-cos', " + "'volcengine-tos', 'supabase'. Default is 'opendal'.", default="opendal", ) @@ -140,7 +144,8 @@ class DatabaseConfig(BaseSettings): default="postgresql", ) - @computed_field + @computed_field # type: ignore[misc] + @property def SQLALCHEMY_DATABASE_URI(self) -> str: db_extras = ( f"{self.DB_EXTRAS}&client_encoding={self.DB_CHARSET}" if self.DB_CHARSET else self.DB_EXTRAS @@ -215,7 +220,7 @@ class DatabaseConfig(BaseSettings): class CeleryConfig(DatabaseConfig): CELERY_BACKEND: str = Field( - description="Backend for Celery task results. Options: 'database', 'redis'.", + description="Backend for Celery task results. Options: 'database', 'redis', 'rabbitmq'.", default="redis", ) @@ -245,11 +250,12 @@ class CeleryConfig(DatabaseConfig): @computed_field def CELERY_RESULT_BACKEND(self) -> str | None: - return ( - "db+{}".format(self.SQLALCHEMY_DATABASE_URI) - if self.CELERY_BACKEND == "database" - else self.CELERY_BROKER_URL - ) + if self.CELERY_BACKEND in ("database", "rabbitmq"): + return f"db+{self.SQLALCHEMY_DATABASE_URI}" + elif self.CELERY_BACKEND == "redis": + return self.CELERY_BROKER_URL + else: + return None @property def BROKER_USE_SSL(self) -> bool: @@ -302,6 +308,7 @@ class MiddlewareConfig( AliyunOSSStorageConfig, AzureBlobStorageConfig, BaiduOBSStorageConfig, + ClickZettaVolumeStorageConfig, GoogleCloudStorageConfig, HuaweiCloudOBSStorageConfig, OCIStorageConfig, @@ -314,6 +321,7 @@ class MiddlewareConfig( VectorStoreConfig, AnalyticdbConfig, ChromaConfig, + ClickzettaConfig, HuaweiCloudConfig, MilvusConfig, MyScaleConfig, diff --git a/api/configs/middleware/cache/redis_config.py b/api/configs/middleware/cache/redis_config.py index 916f52e165..16dca98cfa 100644 --- a/api/configs/middleware/cache/redis_config.py +++ b/api/configs/middleware/cache/redis_config.py @@ -39,6 +39,26 @@ class RedisConfig(BaseSettings): default=False, ) + REDIS_SSL_CERT_REQS: str = Field( + description="SSL certificate requirements (CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED)", + default="CERT_NONE", + ) + + REDIS_SSL_CA_CERTS: Optional[str] = Field( + description="Path to the CA certificate file for SSL verification", + default=None, + ) + + REDIS_SSL_CERTFILE: Optional[str] = Field( + description="Path to the client certificate file for SSL authentication", + default=None, + ) + + REDIS_SSL_KEYFILE: Optional[str] = Field( + description="Path to the client private key file for SSL authentication", + default=None, + ) + REDIS_USE_SENTINEL: Optional[bool] = Field( description="Enable Redis Sentinel mode for high availability", default=False, diff --git a/api/configs/middleware/storage/clickzetta_volume_storage_config.py b/api/configs/middleware/storage/clickzetta_volume_storage_config.py new file mode 100644 index 0000000000..56e1b6a957 --- /dev/null +++ b/api/configs/middleware/storage/clickzetta_volume_storage_config.py @@ -0,0 +1,65 @@ +"""ClickZetta Volume Storage Configuration""" + +from typing import Optional + +from pydantic import Field +from pydantic_settings import BaseSettings + + +class ClickZettaVolumeStorageConfig(BaseSettings): + """Configuration for ClickZetta Volume storage.""" + + CLICKZETTA_VOLUME_USERNAME: Optional[str] = Field( + description="Username for ClickZetta Volume authentication", + default=None, + ) + + CLICKZETTA_VOLUME_PASSWORD: Optional[str] = Field( + description="Password for ClickZetta Volume authentication", + default=None, + ) + + CLICKZETTA_VOLUME_INSTANCE: Optional[str] = Field( + description="ClickZetta instance identifier", + default=None, + ) + + CLICKZETTA_VOLUME_SERVICE: str = Field( + description="ClickZetta service endpoint", + default="api.clickzetta.com", + ) + + CLICKZETTA_VOLUME_WORKSPACE: str = Field( + description="ClickZetta workspace name", + default="quick_start", + ) + + CLICKZETTA_VOLUME_VCLUSTER: str = Field( + description="ClickZetta virtual cluster name", + default="default_ap", + ) + + CLICKZETTA_VOLUME_SCHEMA: str = Field( + description="ClickZetta schema name", + default="dify", + ) + + CLICKZETTA_VOLUME_TYPE: str = Field( + description="ClickZetta volume type (table|user|external)", + default="user", + ) + + CLICKZETTA_VOLUME_NAME: Optional[str] = Field( + description="ClickZetta volume name for external volumes", + default=None, + ) + + CLICKZETTA_VOLUME_TABLE_PREFIX: str = Field( + description="Prefix for ClickZetta volume table names", + default="dataset_", + ) + + CLICKZETTA_VOLUME_DIFY_PREFIX: str = Field( + description="Directory prefix for User Volume to organize Dify files", + default="dify_km", + ) diff --git a/api/configs/middleware/vdb/clickzetta_config.py b/api/configs/middleware/vdb/clickzetta_config.py new file mode 100644 index 0000000000..04f81e25fc --- /dev/null +++ b/api/configs/middleware/vdb/clickzetta_config.py @@ -0,0 +1,69 @@ +from typing import Optional + +from pydantic import BaseModel, Field + + +class ClickzettaConfig(BaseModel): + """ + Clickzetta Lakehouse vector database configuration + """ + + CLICKZETTA_USERNAME: Optional[str] = Field( + description="Username for authenticating with Clickzetta Lakehouse", + default=None, + ) + + CLICKZETTA_PASSWORD: Optional[str] = Field( + description="Password for authenticating with Clickzetta Lakehouse", + default=None, + ) + + CLICKZETTA_INSTANCE: Optional[str] = Field( + description="Clickzetta Lakehouse instance ID", + default=None, + ) + + CLICKZETTA_SERVICE: Optional[str] = Field( + description="Clickzetta API service endpoint (e.g., 'api.clickzetta.com')", + default="api.clickzetta.com", + ) + + CLICKZETTA_WORKSPACE: Optional[str] = Field( + description="Clickzetta workspace name", + default="default", + ) + + CLICKZETTA_VCLUSTER: Optional[str] = Field( + description="Clickzetta virtual cluster name", + default="default_ap", + ) + + CLICKZETTA_SCHEMA: Optional[str] = Field( + description="Database schema name in Clickzetta", + default="public", + ) + + CLICKZETTA_BATCH_SIZE: Optional[int] = Field( + description="Batch size for bulk insert operations", + default=100, + ) + + CLICKZETTA_ENABLE_INVERTED_INDEX: Optional[bool] = Field( + description="Enable inverted index for full-text search capabilities", + default=True, + ) + + CLICKZETTA_ANALYZER_TYPE: Optional[str] = Field( + description="Analyzer type for full-text search: keyword, english, chinese, unicode", + default="chinese", + ) + + CLICKZETTA_ANALYZER_MODE: Optional[str] = Field( + description="Analyzer mode for tokenization: max_word (fine-grained) or smart (intelligent)", + default="smart", + ) + + CLICKZETTA_VECTOR_DISTANCE_FUNCTION: Optional[str] = Field( + description="Distance function for vector similarity: l2_distance or cosine_distance", + default="cosine_distance", + ) diff --git a/api/configs/middleware/vdb/elasticsearch_config.py b/api/configs/middleware/vdb/elasticsearch_config.py index df8182985d..8c4b333d45 100644 --- a/api/configs/middleware/vdb/elasticsearch_config.py +++ b/api/configs/middleware/vdb/elasticsearch_config.py @@ -1,12 +1,13 @@ from typing import Optional -from pydantic import Field, PositiveInt +from pydantic import Field, PositiveInt, model_validator from pydantic_settings import BaseSettings class ElasticsearchConfig(BaseSettings): """ - Configuration settings for Elasticsearch + Configuration settings for both self-managed and Elastic Cloud deployments. + Can load from environment variables or .env files. """ ELASTICSEARCH_HOST: Optional[str] = Field( @@ -28,3 +29,50 @@ class ElasticsearchConfig(BaseSettings): description="Password for authenticating with Elasticsearch (default is 'elastic')", default="elastic", ) + + # Elastic Cloud (optional) + ELASTICSEARCH_USE_CLOUD: Optional[bool] = Field( + description="Set to True to use Elastic Cloud instead of self-hosted Elasticsearch", default=False + ) + ELASTICSEARCH_CLOUD_URL: Optional[str] = Field( + description="Full URL for Elastic Cloud deployment (e.g., 'https://example.es.region.aws.found.io:443')", + default=None, + ) + ELASTICSEARCH_API_KEY: Optional[str] = Field( + description="API key for authenticating with Elastic Cloud", default=None + ) + + # Common options + ELASTICSEARCH_CA_CERTS: Optional[str] = Field( + description="Path to CA certificate file for SSL verification", default=None + ) + ELASTICSEARCH_VERIFY_CERTS: bool = Field( + description="Whether to verify SSL certificates (default is False)", default=False + ) + ELASTICSEARCH_REQUEST_TIMEOUT: int = Field( + description="Request timeout in milliseconds (default is 100000)", default=100000 + ) + ELASTICSEARCH_RETRY_ON_TIMEOUT: bool = Field( + description="Whether to retry requests on timeout (default is True)", default=True + ) + ELASTICSEARCH_MAX_RETRIES: int = Field( + description="Maximum number of retry attempts (default is 10000)", default=10000 + ) + + @model_validator(mode="after") + def validate_elasticsearch_config(self): + """Validate Elasticsearch configuration based on deployment type.""" + if self.ELASTICSEARCH_USE_CLOUD: + if not self.ELASTICSEARCH_CLOUD_URL: + raise ValueError("ELASTICSEARCH_CLOUD_URL is required when using Elastic Cloud") + if not self.ELASTICSEARCH_API_KEY: + raise ValueError("ELASTICSEARCH_API_KEY is required when using Elastic Cloud") + else: + if not self.ELASTICSEARCH_HOST: + raise ValueError("ELASTICSEARCH_HOST is required for self-hosted Elasticsearch") + if not self.ELASTICSEARCH_USERNAME: + raise ValueError("ELASTICSEARCH_USERNAME is required for self-hosted Elasticsearch") + if not self.ELASTICSEARCH_PASSWORD: + raise ValueError("ELASTICSEARCH_PASSWORD is required for self-hosted Elasticsearch") + + return self diff --git a/api/configs/middleware/vdb/tablestore_config.py b/api/configs/middleware/vdb/tablestore_config.py index c4dcc0d465..1aab01c6e1 100644 --- a/api/configs/middleware/vdb/tablestore_config.py +++ b/api/configs/middleware/vdb/tablestore_config.py @@ -28,3 +28,8 @@ class TableStoreConfig(BaseSettings): description="AccessKey secret for the instance name", default=None, ) + + TABLESTORE_NORMALIZE_FULLTEXT_BM25_SCORE: bool = Field( + description="Whether to normalize full-text search scores to [0, 1]", + default=False, + ) diff --git a/api/configs/remote_settings_sources/apollo/client.py b/api/configs/remote_settings_sources/apollo/client.py index 88b30d3987..877ff8409f 100644 --- a/api/configs/remote_settings_sources/apollo/client.py +++ b/api/configs/remote_settings_sources/apollo/client.py @@ -76,7 +76,7 @@ class ApolloClient: code, body = http_request(url, timeout=3, headers=self._sign_headers(url)) if code == 200: if not body: - logger.error(f"get_json_from_net load configs failed, body is {body}") + logger.error("get_json_from_net load configs failed, body is %s", body) return None data = json.loads(body) data = data["configurations"] @@ -207,7 +207,7 @@ class ApolloClient: # if the length is 0 it is returned directly if len(notifications) == 0: return - url = "{}/notifications/v2".format(self.config_url) + url = f"{self.config_url}/notifications/v2" params = { "appId": self.app_id, "cluster": self.cluster, @@ -222,7 +222,7 @@ class ApolloClient: return if http_code == 200: if not body: - logger.error(f"_long_poll load configs failed,body is {body}") + logger.error("_long_poll load configs failed,body is %s", body) return data = json.loads(body) for entry in data: @@ -273,12 +273,12 @@ class ApolloClient: time.sleep(60 * 10) # 10 minutes def _do_heart_beat(self, namespace): - url = "{}/configs/{}/{}/{}?ip={}".format(self.config_url, self.app_id, self.cluster, namespace, self.ip) + url = f"{self.config_url}/configs/{self.app_id}/{self.cluster}/{namespace}?ip={self.ip}" try: code, body = http_request(url, timeout=3, headers=self._sign_headers(url)) if code == 200: if not body: - logger.error(f"_do_heart_beat load configs failed,body is {body}") + logger.error("_do_heart_beat load configs failed,body is %s", body) return None data = json.loads(body) if self.last_release_key == data["releaseKey"]: diff --git a/api/configs/remote_settings_sources/apollo/utils.py b/api/configs/remote_settings_sources/apollo/utils.py index 6136112e03..f5b82908ee 100644 --- a/api/configs/remote_settings_sources/apollo/utils.py +++ b/api/configs/remote_settings_sources/apollo/utils.py @@ -24,7 +24,7 @@ def url_encode_wrapper(params): def no_key_cache_key(namespace, key): - return "{}{}{}".format(namespace, len(namespace), key) + return f"{namespace}{len(namespace)}{key}" # Returns whether the obtained value is obtained, and None if it does not diff --git a/api/constants/__init__.py b/api/constants/__init__.py index 9e052320ac..c98f4d55c8 100644 --- a/api/constants/__init__.py +++ b/api/constants/__init__.py @@ -9,10 +9,10 @@ DEFAULT_FILE_NUMBER_LIMITS = 3 IMAGE_EXTENSIONS = ["jpg", "jpeg", "png", "webp", "gif", "svg"] IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS]) -VIDEO_EXTENSIONS = ["mp4", "mov", "mpeg", "mpga"] +VIDEO_EXTENSIONS = ["mp4", "mov", "mpeg", "webm"] VIDEO_EXTENSIONS.extend([ext.upper() for ext in VIDEO_EXTENSIONS]) -AUDIO_EXTENSIONS = ["mp3", "m4a", "wav", "webm", "amr"] +AUDIO_EXTENSIONS = ["mp3", "m4a", "wav", "amr", "mpga"] AUDIO_EXTENSIONS.extend([ext.upper() for ext in AUDIO_EXTENSIONS]) diff --git a/api/constants/languages.py b/api/constants/languages.py index 1157ec4307..ab19392c59 100644 --- a/api/constants/languages.py +++ b/api/constants/languages.py @@ -28,5 +28,5 @@ def supported_language(lang): if lang in languages: return lang - error = "{lang} is not a valid language.".format(lang=lang) + error = f"{lang} is not a valid language." raise ValueError(error) diff --git a/api/controllers/common/errors.py b/api/controllers/common/errors.py index 9f762b3135..6e2ea952fc 100644 --- a/api/controllers/common/errors.py +++ b/api/controllers/common/errors.py @@ -1,5 +1,7 @@ from werkzeug.exceptions import HTTPException +from libs.exception import BaseHTTPException + class FilenameNotExistsError(HTTPException): code = 400 @@ -9,3 +11,27 @@ class FilenameNotExistsError(HTTPException): class RemoteFileUploadError(HTTPException): code = 400 description = "Error uploading remote file." + + +class FileTooLargeError(BaseHTTPException): + error_code = "file_too_large" + description = "File size exceeded. {message}" + code = 413 + + +class UnsupportedFileTypeError(BaseHTTPException): + error_code = "unsupported_file_type" + description = "File type not allowed." + code = 415 + + +class TooManyFilesError(BaseHTTPException): + error_code = "too_many_files" + description = "Only one file is allowed." + code = 400 + + +class NoFileUploadedError(BaseHTTPException): + error_code = "no_file_uploaded" + description = "Please upload your file." + code = 400 diff --git a/api/controllers/common/fields.py b/api/controllers/common/fields.py index 3466eea1f6..df9de825de 100644 --- a/api/controllers/common/fields.py +++ b/api/controllers/common/fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restx import Api, Namespace, fields from libs.helper import AppIconUrlField @@ -10,6 +10,12 @@ parameters__system_parameters = { "workflow_file_upload_limit": fields.Integer, } + +def build_system_parameters_model(api_or_ns: Api | Namespace): + """Build the system parameters model for the API or Namespace.""" + return api_or_ns.model("SystemParameters", parameters__system_parameters) + + parameters_fields = { "opening_statement": fields.String, "suggested_questions": fields.Raw, @@ -25,6 +31,14 @@ parameters_fields = { "system_parameters": fields.Nested(parameters__system_parameters), } + +def build_parameters_model(api_or_ns: Api | Namespace): + """Build the parameters model for the API or Namespace.""" + copied_fields = parameters_fields.copy() + copied_fields["system_parameters"] = fields.Nested(build_system_parameters_model(api_or_ns)) + return api_or_ns.model("Parameters", copied_fields) + + site_fields = { "title": fields.String, "chat_color_theme": fields.String, @@ -41,3 +55,8 @@ site_fields = { "show_workflow_steps": fields.Boolean, "use_icon_as_answer_icon": fields.Boolean, } + + +def build_site_model(api_or_ns: Api | Namespace): + """Build the site model for the API or Namespace.""" + return api_or_ns.model("Site", site_fields) diff --git a/api/controllers/common/helpers.py b/api/controllers/common/helpers.py index 008f1f0f7a..6a5197635e 100644 --- a/api/controllers/common/helpers.py +++ b/api/controllers/common/helpers.py @@ -1,3 +1,4 @@ +import contextlib import mimetypes import os import platform @@ -65,10 +66,8 @@ def guess_file_info_from_response(response: httpx.Response): # Use python-magic to guess MIME type if still unknown or generic if mimetype == "application/octet-stream" and magic is not None: - try: + with contextlib.suppress(magic.MagicException): mimetype = magic.from_buffer(response.content[:1024], mime=True) - except magic.MagicException: - pass extension = os.path.splitext(filename)[1] diff --git a/api/controllers/console/admin.py b/api/controllers/console/admin.py index 8a55197fb6..7e5c28200a 100644 --- a/api/controllers/console/admin.py +++ b/api/controllers/console/admin.py @@ -1,7 +1,7 @@ from functools import wraps from flask import request -from flask_restful import Resource, reqparse +from flask_restx import Resource, reqparse from sqlalchemy import select from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound, Unauthorized diff --git a/api/controllers/console/apikey.py b/api/controllers/console/apikey.py index d7500c415c..401e88709a 100644 --- a/api/controllers/console/apikey.py +++ b/api/controllers/console/apikey.py @@ -1,8 +1,8 @@ -from typing import Any +from typing import Any, Optional -import flask_restful +import flask_restx from flask_login import current_user -from flask_restful import Resource, fields, marshal_with +from flask_restx import Resource, fields, marshal_with from sqlalchemy import select from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden @@ -40,7 +40,7 @@ def _get_resource(resource_id, tenant_id, resource_model): ).scalar_one_or_none() if resource is None: - flask_restful.abort(404, message=f"{resource_model.__name__} not found.") + flask_restx.abort(404, message=f"{resource_model.__name__} not found.") return resource @@ -49,7 +49,7 @@ class BaseApiKeyListResource(Resource): method_decorators = [account_initialization_required, login_required, setup_required] resource_type: str | None = None - resource_model: Any = None + resource_model: Optional[Any] = None resource_id_field: str | None = None token_prefix: str | None = None max_keys = 10 @@ -81,7 +81,7 @@ class BaseApiKeyListResource(Resource): ) if current_key_count >= self.max_keys: - flask_restful.abort( + flask_restx.abort( 400, message=f"Cannot create more than {self.max_keys} API keys for this resource type.", code="max_keys_exceeded", @@ -102,7 +102,7 @@ class BaseApiKeyResource(Resource): method_decorators = [account_initialization_required, login_required, setup_required] resource_type: str | None = None - resource_model: Any = None + resource_model: Optional[Any] = None resource_id_field: str | None = None def delete(self, resource_id, api_key_id): @@ -126,7 +126,7 @@ class BaseApiKeyResource(Resource): ) if key is None: - flask_restful.abort(404, message="API key not found") + flask_restx.abort(404, message="API key not found") db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete() db.session.commit() diff --git a/api/controllers/console/app/advanced_prompt_template.py b/api/controllers/console/app/advanced_prompt_template.py index c228743fa5..c6cb6f6e3a 100644 --- a/api/controllers/console/app/advanced_prompt_template.py +++ b/api/controllers/console/app/advanced_prompt_template.py @@ -1,4 +1,4 @@ -from flask_restful import Resource, reqparse +from flask_restx import Resource, reqparse from controllers.console import api from controllers.console.wraps import account_initialization_required, setup_required diff --git a/api/controllers/console/app/agent.py b/api/controllers/console/app/agent.py index d433415894..a964154207 100644 --- a/api/controllers/console/app/agent.py +++ b/api/controllers/console/app/agent.py @@ -1,4 +1,4 @@ -from flask_restful import Resource, reqparse +from flask_restx import Resource, reqparse from controllers.console import api from controllers.console.app.wraps import get_app_model diff --git a/api/controllers/console/app/annotation.py b/api/controllers/console/app/annotation.py index 2b48afd550..37d23ccd9f 100644 --- a/api/controllers/console/app/annotation.py +++ b/api/controllers/console/app/annotation.py @@ -1,11 +1,12 @@ +from typing import Literal + from flask import request from flask_login import current_user -from flask_restful import Resource, marshal, marshal_with, reqparse +from flask_restx import Resource, marshal, marshal_with, reqparse from werkzeug.exceptions import Forbidden +from controllers.common.errors import NoFileUploadedError, TooManyFilesError from controllers.console import api -from controllers.console.app.error import NoFileUploadedError -from controllers.console.datasets.error import TooManyFilesError from controllers.console.wraps import ( account_initialization_required, cloud_edition_billing_resource_check, @@ -25,7 +26,7 @@ class AnnotationReplyActionApi(Resource): @login_required @account_initialization_required @cloud_edition_billing_resource_check("annotation") - def post(self, app_id, action): + def post(self, app_id, action: Literal["enable", "disable"]): if not current_user.is_editor: raise Forbidden() @@ -39,8 +40,6 @@ class AnnotationReplyActionApi(Resource): result = AppAnnotationService.enable_app_annotation(args, app_id) elif action == "disable": result = AppAnnotationService.disable_app_annotation(app_id) - else: - raise ValueError("Unsupported annotation reply action") return result, 200 @@ -86,7 +85,7 @@ class AnnotationReplyActionStatusApi(Resource): raise Forbidden() job_id = str(job_id) - app_annotation_job_key = "{}_app_annotation_job_{}".format(action, str(job_id)) + app_annotation_job_key = f"{action}_app_annotation_job_{str(job_id)}" cache_result = redis_client.get(app_annotation_job_key) if cache_result is None: raise ValueError("The job does not exist.") @@ -94,13 +93,13 @@ class AnnotationReplyActionStatusApi(Resource): job_status = cache_result.decode() error_msg = "" if job_status == "error": - app_annotation_error_key = "{}_app_annotation_error_{}".format(action, str(job_id)) + app_annotation_error_key = f"{action}_app_annotation_error_{str(job_id)}" error_msg = redis_client.get(app_annotation_error_key).decode() return {"job_id": job_id, "job_status": job_status, "error_msg": error_msg}, 200 -class AnnotationListApi(Resource): +class AnnotationApi(Resource): @setup_required @login_required @account_initialization_required @@ -123,22 +122,6 @@ class AnnotationListApi(Resource): } return response, 200 - -class AnnotationExportApi(Resource): - @setup_required - @login_required - @account_initialization_required - def get(self, app_id): - if not current_user.is_editor: - raise Forbidden() - - app_id = str(app_id) - annotation_list = AppAnnotationService.export_annotation_list_by_app_id(app_id) - response = {"data": marshal(annotation_list, annotation_fields)} - return response, 200 - - -class AnnotationCreateApi(Resource): @setup_required @login_required @account_initialization_required @@ -156,6 +139,48 @@ class AnnotationCreateApi(Resource): annotation = AppAnnotationService.insert_app_annotation_directly(args, app_id) return annotation + @setup_required + @login_required + @account_initialization_required + def delete(self, app_id): + if not current_user.is_editor: + raise Forbidden() + + app_id = str(app_id) + + # Use request.args.getlist to get annotation_ids array directly + annotation_ids = request.args.getlist("annotation_id") + + # If annotation_ids are provided, handle batch deletion + if annotation_ids: + # Check if any annotation_ids contain empty strings or invalid values + if not all(annotation_id.strip() for annotation_id in annotation_ids if annotation_id): + return { + "code": "bad_request", + "message": "annotation_ids are required if the parameter is provided.", + }, 400 + + result = AppAnnotationService.delete_app_annotations_in_batch(app_id, annotation_ids) + return result, 204 + # If no annotation_ids are provided, handle clearing all annotations + else: + AppAnnotationService.clear_all_annotations(app_id) + return {"result": "success"}, 204 + + +class AnnotationExportApi(Resource): + @setup_required + @login_required + @account_initialization_required + def get(self, app_id): + if not current_user.is_editor: + raise Forbidden() + + app_id = str(app_id) + annotation_list = AppAnnotationService.export_annotation_list_by_app_id(app_id) + response = {"data": marshal(annotation_list, annotation_fields)} + return response, 200 + class AnnotationUpdateDeleteApi(Resource): @setup_required @@ -199,14 +224,15 @@ class AnnotationBatchImportApi(Resource): raise Forbidden() app_id = str(app_id) - # get file from request - file = request.files["file"] # check file if "file" not in request.files: raise NoFileUploadedError() if len(request.files) > 1: raise TooManyFilesError() + + # get file from request + file = request.files["file"] # check file type if not file.filename or not file.filename.lower().endswith(".csv"): raise ValueError("Invalid file type. Only CSV files are allowed") @@ -223,14 +249,14 @@ class AnnotationBatchImportStatusApi(Resource): raise Forbidden() job_id = str(job_id) - indexing_cache_key = "app_annotation_batch_import_{}".format(str(job_id)) + indexing_cache_key = f"app_annotation_batch_import_{str(job_id)}" cache_result = redis_client.get(indexing_cache_key) if cache_result is None: raise ValueError("The job does not exist.") job_status = cache_result.decode() error_msg = "" if job_status == "error": - indexing_error_msg_key = "app_annotation_batch_import_error_msg_{}".format(str(job_id)) + indexing_error_msg_key = f"app_annotation_batch_import_error_msg_{str(job_id)}" error_msg = redis_client.get(indexing_error_msg_key).decode() return {"job_id": job_id, "job_status": job_status, "error_msg": error_msg}, 200 @@ -265,7 +291,7 @@ api.add_resource(AnnotationReplyActionApi, "/apps//annotation-reply api.add_resource( AnnotationReplyActionStatusApi, "/apps//annotation-reply//status/" ) -api.add_resource(AnnotationListApi, "/apps//annotations") +api.add_resource(AnnotationApi, "/apps//annotations") api.add_resource(AnnotationExportApi, "/apps//annotations/export") api.add_resource(AnnotationUpdateDeleteApi, "/apps//annotations/") api.add_resource(AnnotationBatchImportApi, "/apps//annotations/batch-import") diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 9fe32dde6d..a6eb86122d 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -2,7 +2,7 @@ import uuid from typing import cast from flask_login import current_user -from flask_restful import Resource, inputs, marshal, marshal_with, reqparse +from flask_restx import Resource, inputs, marshal, marshal_with, reqparse from sqlalchemy import select from sqlalchemy.orm import Session from werkzeug.exceptions import BadRequest, Forbidden, abort @@ -28,6 +28,12 @@ from services.feature_service import FeatureService ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"] +def _validate_description_length(description): + if description and len(description) > 400: + raise ValueError("Description cannot exceed 400 characters.") + return description + + class AppListApi(Resource): @setup_required @login_required @@ -94,7 +100,7 @@ class AppListApi(Resource): """Create app""" parser = reqparse.RequestParser() parser.add_argument("name", type=str, required=True, location="json") - parser.add_argument("description", type=str, location="json") + parser.add_argument("description", type=_validate_description_length, location="json") parser.add_argument("mode", type=str, choices=ALLOW_CREATE_APP_MODES, location="json") parser.add_argument("icon_type", type=str, location="json") parser.add_argument("icon", type=str, location="json") @@ -146,7 +152,7 @@ class AppApi(Resource): parser = reqparse.RequestParser() parser.add_argument("name", type=str, required=True, nullable=False, location="json") - parser.add_argument("description", type=str, location="json") + parser.add_argument("description", type=_validate_description_length, location="json") parser.add_argument("icon_type", type=str, location="json") parser.add_argument("icon", type=str, location="json") parser.add_argument("icon_background", type=str, location="json") @@ -189,7 +195,7 @@ class AppCopyApi(Resource): parser = reqparse.RequestParser() parser.add_argument("name", type=str, location="json") - parser.add_argument("description", type=str, location="json") + parser.add_argument("description", type=_validate_description_length, location="json") parser.add_argument("icon_type", type=str, location="json") parser.add_argument("icon", type=str, location="json") parser.add_argument("icon_background", type=str, location="json") diff --git a/api/controllers/console/app/app_import.py b/api/controllers/console/app/app_import.py index 9ffb94e9f9..aee93a8814 100644 --- a/api/controllers/console/app/app_import.py +++ b/api/controllers/console/app/app_import.py @@ -1,7 +1,7 @@ from typing import cast from flask_login import current_user -from flask_restful import Resource, marshal_with, reqparse +from flask_restx import Resource, marshal_with, reqparse from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden diff --git a/api/controllers/console/app/audio.py b/api/controllers/console/app/audio.py index 665cf1aede..ea1869a587 100644 --- a/api/controllers/console/app/audio.py +++ b/api/controllers/console/app/audio.py @@ -1,7 +1,7 @@ import logging from flask import request -from flask_restful import Resource, reqparse +from flask_restx import Resource, reqparse from werkzeug.exceptions import InternalServerError import services diff --git a/api/controllers/console/app/completion.py b/api/controllers/console/app/completion.py index 732f5b799a..bd5e7d0924 100644 --- a/api/controllers/console/app/completion.py +++ b/api/controllers/console/app/completion.py @@ -1,7 +1,8 @@ import logging import flask_login -from flask_restful import Resource, reqparse +from flask import request +from flask_restx import Resource, reqparse from werkzeug.exceptions import InternalServerError, NotFound import services @@ -24,6 +25,7 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) +from core.helper.trace_id_helper import get_external_trace_id from core.model_runtime.errors.invoke import InvokeError from libs import helper from libs.helper import uuid_value @@ -115,6 +117,10 @@ class ChatMessageApi(Resource): streaming = args["response_mode"] != "blocking" args["auto_generate_name"] = False + external_trace_id = get_external_trace_id(request) + if external_trace_id: + args["external_trace_id"] = external_trace_id + account = flask_login.current_user try: diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py index b5b6d1f75b..06f0218771 100644 --- a/api/controllers/console/app/conversation.py +++ b/api/controllers/console/app/conversation.py @@ -2,8 +2,8 @@ from datetime import datetime import pytz # pip install pytz from flask_login import current_user -from flask_restful import Resource, marshal_with, reqparse -from flask_restful.inputs import int_range +from flask_restx import Resource, marshal_with, reqparse +from flask_restx.inputs import int_range from sqlalchemy import func, or_ from sqlalchemy.orm import joinedload from werkzeug.exceptions import Forbidden, NotFound @@ -24,6 +24,8 @@ from libs.helper import DatetimeString from libs.login import login_required from models import Conversation, EndUser, Message, MessageAnnotation from models.model import AppMode +from services.conversation_service import ConversationService +from services.errors.conversation import ConversationNotExistsError class CompletionConversationApi(Resource): @@ -46,13 +48,15 @@ class CompletionConversationApi(Resource): parser.add_argument("limit", type=int_range(1, 100), default=20, location="args") args = parser.parse_args() - query = db.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.mode == "completion") + query = db.select(Conversation).where( + Conversation.app_id == app_model.id, Conversation.mode == "completion", Conversation.is_deleted.is_(False) + ) if args["keyword"]: query = query.join(Message, Message.conversation_id == Conversation.id).where( or_( - Message.query.ilike("%{}%".format(args["keyword"])), - Message.answer.ilike("%{}%".format(args["keyword"])), + Message.query.ilike(f"%{args['keyword']}%"), + Message.answer.ilike(f"%{args['keyword']}%"), ) ) @@ -119,18 +123,11 @@ class CompletionConversationDetailApi(Resource): raise Forbidden() conversation_id = str(conversation_id) - conversation = ( - db.session.query(Conversation) - .where(Conversation.id == conversation_id, Conversation.app_id == app_model.id) - .first() - ) - - if not conversation: + try: + ConversationService.delete(app_model, conversation_id, current_user) + except ConversationNotExistsError: raise NotFound("Conversation Not Exists.") - conversation.is_deleted = True - db.session.commit() - return {"result": "success"}, 204 @@ -171,10 +168,10 @@ class ChatConversationApi(Resource): .subquery() ) - query = db.select(Conversation).where(Conversation.app_id == app_model.id) + query = db.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.is_deleted.is_(False)) if args["keyword"]: - keyword_filter = "%{}%".format(args["keyword"]) + keyword_filter = f"%{args['keyword']}%" query = ( query.join( Message, @@ -284,18 +281,11 @@ class ChatConversationDetailApi(Resource): raise Forbidden() conversation_id = str(conversation_id) - conversation = ( - db.session.query(Conversation) - .where(Conversation.id == conversation_id, Conversation.app_id == app_model.id) - .first() - ) - - if not conversation: + try: + ConversationService.delete(app_model, conversation_id, current_user) + except ConversationNotExistsError: raise NotFound("Conversation Not Exists.") - conversation.is_deleted = True - db.session.commit() - return {"result": "success"}, 204 diff --git a/api/controllers/console/app/conversation_variables.py b/api/controllers/console/app/conversation_variables.py index d49f433ba1..5ca4c33f87 100644 --- a/api/controllers/console/app/conversation_variables.py +++ b/api/controllers/console/app/conversation_variables.py @@ -1,4 +1,4 @@ -from flask_restful import Resource, marshal_with, reqparse +from flask_restx import Resource, marshal_with, reqparse from sqlalchemy import select from sqlalchemy.orm import Session diff --git a/api/controllers/console/app/error.py b/api/controllers/console/app/error.py index 1559f82d6e..fbd7901646 100644 --- a/api/controllers/console/app/error.py +++ b/api/controllers/console/app/error.py @@ -79,18 +79,6 @@ class ProviderNotSupportSpeechToTextError(BaseHTTPException): code = 400 -class NoFileUploadedError(BaseHTTPException): - error_code = "no_file_uploaded" - description = "Please upload your file." - code = 400 - - -class TooManyFilesError(BaseHTTPException): - error_code = "too_many_files" - description = "Only one file is allowed." - code = 400 - - class DraftWorkflowNotExist(BaseHTTPException): error_code = "draft_workflow_not_exist" description = "Draft workflow need to be initialized." diff --git a/api/controllers/console/app/generator.py b/api/controllers/console/app/generator.py index 790369c052..497fd53df7 100644 --- a/api/controllers/console/app/generator.py +++ b/api/controllers/console/app/generator.py @@ -1,7 +1,7 @@ -import os +from collections.abc import Sequence from flask_login import current_user -from flask_restful import Resource, reqparse +from flask_restx import Resource, reqparse from controllers.console import api from controllers.console.app.error import ( @@ -12,6 +12,8 @@ from controllers.console.app.error import ( ) from controllers.console.wraps import account_initialization_required, setup_required from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError +from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider +from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider from core.llm_generator.llm_generator import LLMGenerator from core.model_runtime.errors.invoke import InvokeError from libs.login import login_required @@ -29,15 +31,12 @@ class RuleGenerateApi(Resource): args = parser.parse_args() account = current_user - PROMPT_GENERATION_MAX_TOKENS = int(os.getenv("PROMPT_GENERATION_MAX_TOKENS", "512")) - try: rules = LLMGenerator.generate_rule_config( tenant_id=account.current_tenant_id, instruction=args["instruction"], model_config=args["model_config"], no_variable=args["no_variable"], - rule_config_max_tokens=PROMPT_GENERATION_MAX_TOKENS, ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) @@ -64,14 +63,12 @@ class RuleCodeGenerateApi(Resource): args = parser.parse_args() account = current_user - CODE_GENERATION_MAX_TOKENS = int(os.getenv("CODE_GENERATION_MAX_TOKENS", "1024")) try: code_result = LLMGenerator.generate_code( tenant_id=account.current_tenant_id, instruction=args["instruction"], model_config=args["model_config"], code_language=args["code_language"], - max_tokens=CODE_GENERATION_MAX_TOKENS, ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) @@ -114,6 +111,121 @@ class RuleStructuredOutputGenerateApi(Resource): return structured_output +class InstructionGenerateApi(Resource): + @setup_required + @login_required + @account_initialization_required + def post(self): + parser = reqparse.RequestParser() + parser.add_argument("flow_id", type=str, required=True, default="", location="json") + parser.add_argument("node_id", type=str, required=False, default="", location="json") + parser.add_argument("current", type=str, required=False, default="", location="json") + parser.add_argument("language", type=str, required=False, default="javascript", location="json") + parser.add_argument("instruction", type=str, required=True, nullable=False, location="json") + parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json") + parser.add_argument("ideal_output", type=str, required=False, default="", location="json") + args = parser.parse_args() + code_template = ( + Python3CodeProvider.get_default_code() + if args["language"] == "python" + else (JavascriptCodeProvider.get_default_code()) + if args["language"] == "javascript" + else "" + ) + try: + # Generate from nothing for a workflow node + if (args["current"] == code_template or args["current"] == "") and args["node_id"] != "": + from models import App, db + from services.workflow_service import WorkflowService + + app = db.session.query(App).where(App.id == args["flow_id"]).first() + if not app: + return {"error": f"app {args['flow_id']} not found"}, 400 + workflow = WorkflowService().get_draft_workflow(app_model=app) + if not workflow: + return {"error": f"workflow {args['flow_id']} not found"}, 400 + nodes: Sequence = workflow.graph_dict["nodes"] + node = [node for node in nodes if node["id"] == args["node_id"]] + if len(node) == 0: + return {"error": f"node {args['node_id']} not found"}, 400 + node_type = node[0]["data"]["type"] + match node_type: + case "llm": + return LLMGenerator.generate_rule_config( + current_user.current_tenant_id, + instruction=args["instruction"], + model_config=args["model_config"], + no_variable=True, + ) + case "agent": + return LLMGenerator.generate_rule_config( + current_user.current_tenant_id, + instruction=args["instruction"], + model_config=args["model_config"], + no_variable=True, + ) + case "code": + return LLMGenerator.generate_code( + tenant_id=current_user.current_tenant_id, + instruction=args["instruction"], + model_config=args["model_config"], + code_language=args["language"], + ) + case _: + return {"error": f"invalid node type: {node_type}"} + if args["node_id"] == "" and args["current"] != "": # For legacy app without a workflow + return LLMGenerator.instruction_modify_legacy( + tenant_id=current_user.current_tenant_id, + flow_id=args["flow_id"], + current=args["current"], + instruction=args["instruction"], + model_config=args["model_config"], + ideal_output=args["ideal_output"], + ) + if args["node_id"] != "" and args["current"] != "": # For workflow node + return LLMGenerator.instruction_modify_workflow( + tenant_id=current_user.current_tenant_id, + flow_id=args["flow_id"], + node_id=args["node_id"], + current=args["current"], + instruction=args["instruction"], + model_config=args["model_config"], + ideal_output=args["ideal_output"], + ) + return {"error": "incompatible parameters"}, 400 + except ProviderTokenNotInitError as ex: + raise ProviderNotInitializeError(ex.description) + except QuotaExceededError: + raise ProviderQuotaExceededError() + except ModelCurrentlyNotSupportError: + raise ProviderModelCurrentlyNotSupportError() + except InvokeError as e: + raise CompletionRequestError(e.description) + + +class InstructionGenerationTemplateApi(Resource): + @setup_required + @login_required + @account_initialization_required + def post(self) -> dict: + parser = reqparse.RequestParser() + parser.add_argument("type", type=str, required=True, default=False, location="json") + args = parser.parse_args() + match args["type"]: + case "prompt": + from core.llm_generator.prompts import INSTRUCTION_GENERATE_TEMPLATE_PROMPT + + return {"data": INSTRUCTION_GENERATE_TEMPLATE_PROMPT} + case "code": + from core.llm_generator.prompts import INSTRUCTION_GENERATE_TEMPLATE_CODE + + return {"data": INSTRUCTION_GENERATE_TEMPLATE_CODE} + case _: + raise ValueError(f"Invalid type: {args['type']}") + + api.add_resource(RuleGenerateApi, "/rule-generate") api.add_resource(RuleCodeGenerateApi, "/rule-code-generate") api.add_resource(RuleStructuredOutputGenerateApi, "/rule-structured-output-generate") +api.add_resource(InstructionGenerateApi, "/instruction-generate") +api.add_resource(InstructionGenerationTemplateApi, "/instruction-generate/template") diff --git a/api/controllers/console/app/mcp_server.py b/api/controllers/console/app/mcp_server.py index 2344fd5acb..541803e539 100644 --- a/api/controllers/console/app/mcp_server.py +++ b/api/controllers/console/app/mcp_server.py @@ -2,7 +2,7 @@ import json from enum import StrEnum from flask_login import current_user -from flask_restful import Resource, marshal_with, reqparse +from flask_restx import Resource, marshal_with, reqparse from werkzeug.exceptions import NotFound from controllers.console import api diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py index 5e79e8dece..57cc825fe9 100644 --- a/api/controllers/console/app/message.py +++ b/api/controllers/console/app/message.py @@ -1,11 +1,10 @@ import logging from flask_login import current_user -from flask_restful import Resource, fields, marshal_with, reqparse -from flask_restful.inputs import int_range +from flask_restx import Resource, fields, marshal_with, reqparse +from flask_restx.inputs import int_range from werkzeug.exceptions import Forbidden, InternalServerError, NotFound -import services from controllers.console import api from controllers.console.app.error import ( CompletionRequestError, @@ -28,7 +27,7 @@ from fields.conversation_fields import annotation_fields, message_detail_fields from libs.helper import uuid_value from libs.infinite_scroll_pagination import InfiniteScrollPagination from libs.login import login_required -from models.model import AppMode, Conversation, Message, MessageAnnotation +from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback from services.annotation_service import AppAnnotationService from services.errors.conversation import ConversationNotExistsError from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError @@ -125,17 +124,34 @@ class MessageFeedbackApi(Resource): parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json") args = parser.parse_args() - try: - MessageService.create_feedback( - app_model=app_model, - message_id=str(args["message_id"]), - user=current_user, - rating=args.get("rating"), - content=None, - ) - except services.errors.message.MessageNotExistsError: + message_id = str(args["message_id"]) + + message = db.session.query(Message).filter(Message.id == message_id, Message.app_id == app_model.id).first() + + if not message: raise NotFound("Message Not Exists.") + feedback = message.admin_feedback + + if not args["rating"] and feedback: + db.session.delete(feedback) + elif args["rating"] and feedback: + feedback.rating = args["rating"] + elif not args["rating"] and not feedback: + raise ValueError("rating cannot be None when feedback not exists") + else: + feedback = MessageFeedback( + app_id=app_model.id, + conversation_id=message.conversation_id, + message_id=message.id, + rating=args["rating"], + from_source="admin", + from_account_id=current_user.id, + ) + db.session.add(feedback) + + db.session.commit() + return {"result": "success"} diff --git a/api/controllers/console/app/model_config.py b/api/controllers/console/app/model_config.py index 029138fb6b..52ff9b923d 100644 --- a/api/controllers/console/app/model_config.py +++ b/api/controllers/console/app/model_config.py @@ -3,7 +3,7 @@ from typing import cast from flask import request from flask_login import current_user -from flask_restful import Resource +from flask_restx import Resource from controllers.console import api from controllers.console.app.wraps import get_app_model diff --git a/api/controllers/console/app/ops_trace.py b/api/controllers/console/app/ops_trace.py index 978c02412c..74c2867c2f 100644 --- a/api/controllers/console/app/ops_trace.py +++ b/api/controllers/console/app/ops_trace.py @@ -1,4 +1,4 @@ -from flask_restful import Resource, reqparse +from flask_restx import Resource, reqparse from werkzeug.exceptions import BadRequest from controllers.console import api diff --git a/api/controllers/console/app/site.py b/api/controllers/console/app/site.py index 03418f1dd2..778ce92da6 100644 --- a/api/controllers/console/app/site.py +++ b/api/controllers/console/app/site.py @@ -1,5 +1,5 @@ from flask_login import current_user -from flask_restful import Resource, marshal_with, reqparse +from flask_restx import Resource, marshal_with, reqparse from werkzeug.exceptions import Forbidden, NotFound from constants.languages import supported_language diff --git a/api/controllers/console/app/statistic.py b/api/controllers/console/app/statistic.py index 32b64d10c5..27e405af38 100644 --- a/api/controllers/console/app/statistic.py +++ b/api/controllers/console/app/statistic.py @@ -5,7 +5,7 @@ import pytz import sqlalchemy as sa from flask import jsonify from flask_login import current_user -from flask_restful import Resource, reqparse +from flask_restx import Resource, reqparse from controllers.console import api from controllers.console.app.wraps import get_app_model @@ -67,7 +67,7 @@ WHERE response_data = [] with db.engine.begin() as conn: - rs = conn.execute(db.text(sql_query), arg_dict) + rs = conn.execute(sa.text(sql_query), arg_dict) for i in rs: response_data.append({"date": str(i.date), "message_count": i.message_count}) @@ -176,7 +176,7 @@ WHERE response_data = [] with db.engine.begin() as conn: - rs = conn.execute(db.text(sql_query), arg_dict) + rs = conn.execute(sa.text(sql_query), arg_dict) for i in rs: response_data.append({"date": str(i.date), "terminal_count": i.terminal_count}) @@ -234,7 +234,7 @@ WHERE response_data = [] with db.engine.begin() as conn: - rs = conn.execute(db.text(sql_query), arg_dict) + rs = conn.execute(sa.text(sql_query), arg_dict) for i in rs: response_data.append( {"date": str(i.date), "token_count": i.token_count, "total_price": i.total_price, "currency": "USD"} @@ -310,7 +310,7 @@ ORDER BY response_data = [] with db.engine.begin() as conn: - rs = conn.execute(db.text(sql_query), arg_dict) + rs = conn.execute(sa.text(sql_query), arg_dict) for i in rs: response_data.append( {"date": str(i.date), "interactions": float(i.interactions.quantize(Decimal("0.01")))} @@ -373,7 +373,7 @@ WHERE response_data = [] with db.engine.begin() as conn: - rs = conn.execute(db.text(sql_query), arg_dict) + rs = conn.execute(sa.text(sql_query), arg_dict) for i in rs: response_data.append( { @@ -435,7 +435,7 @@ WHERE response_data = [] with db.engine.begin() as conn: - rs = conn.execute(db.text(sql_query), arg_dict) + rs = conn.execute(sa.text(sql_query), arg_dict) for i in rs: response_data.append({"date": str(i.date), "latency": round(i.latency * 1000, 4)}) @@ -495,7 +495,7 @@ WHERE response_data = [] with db.engine.begin() as conn: - rs = conn.execute(db.text(sql_query), arg_dict) + rs = conn.execute(sa.text(sql_query), arg_dict) for i in rs: response_data.append({"date": str(i.date), "tps": round(i.tokens_per_second, 4)}) diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index a9f088a276..8dcffb1666 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -4,7 +4,7 @@ from collections.abc import Sequence from typing import cast from flask import abort, request -from flask_restful import Resource, inputs, marshal_with, reqparse +from flask_restx import Resource, inputs, marshal_with, reqparse from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden, InternalServerError, NotFound @@ -23,6 +23,7 @@ from core.app.app_config.features.file_upload.manager import FileUploadConfigMan from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.entities.app_invoke_entities import InvokeFrom from core.file.models import File +from core.helper.trace_id_helper import get_external_trace_id from extensions.ext_database import db from factories import file_factory, variable_factory from fields.workflow_fields import workflow_fields, workflow_pagination_fields @@ -185,6 +186,10 @@ class AdvancedChatDraftWorkflowRunApi(Resource): args = parser.parse_args() + external_trace_id = get_external_trace_id(request) + if external_trace_id: + args["external_trace_id"] = external_trace_id + try: response = AppGenerateService.generate( app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=True @@ -373,6 +378,10 @@ class DraftWorkflowRunApi(Resource): parser.add_argument("files", type=list, required=False, location="json") args = parser.parse_args() + external_trace_id = get_external_trace_id(request) + if external_trace_id: + args["external_trace_id"] = external_trace_id + try: response = AppGenerateService.generate( app_model=app_model, diff --git a/api/controllers/console/app/workflow_app_log.py b/api/controllers/console/app/workflow_app_log.py index 310146a5e7..8d8cdc93cf 100644 --- a/api/controllers/console/app/workflow_app_log.py +++ b/api/controllers/console/app/workflow_app_log.py @@ -1,6 +1,6 @@ from dateutil.parser import isoparse -from flask_restful import Resource, marshal_with, reqparse -from flask_restful.inputs import int_range +from flask_restx import Resource, marshal_with, reqparse +from flask_restx.inputs import int_range from sqlalchemy.orm import Session from controllers.console import api diff --git a/api/controllers/console/app/workflow_draft_variable.py b/api/controllers/console/app/workflow_draft_variable.py index ba93f82756..4e625db24d 100644 --- a/api/controllers/console/app/workflow_draft_variable.py +++ b/api/controllers/console/app/workflow_draft_variable.py @@ -2,7 +2,7 @@ import logging from typing import Any, NoReturn from flask import Response -from flask_restful import Resource, fields, inputs, marshal, marshal_with, reqparse +from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden @@ -163,11 +163,11 @@ class WorkflowVariableCollectionApi(Resource): draft_var_srv = WorkflowDraftVariableService( session=session, ) - workflow_vars = draft_var_srv.list_variables_without_values( - app_id=app_model.id, - page=args.page, - limit=args.limit, - ) + workflow_vars = draft_var_srv.list_variables_without_values( + app_id=app_model.id, + page=args.page, + limit=args.limit, + ) return workflow_vars diff --git a/api/controllers/console/app/workflow_run.py b/api/controllers/console/app/workflow_run.py index 9099700213..dccbfd8648 100644 --- a/api/controllers/console/app/workflow_run.py +++ b/api/controllers/console/app/workflow_run.py @@ -1,8 +1,8 @@ from typing import cast from flask_login import current_user -from flask_restful import Resource, marshal_with, reqparse -from flask_restful.inputs import int_range +from flask_restx import Resource, marshal_with, reqparse +from flask_restx.inputs import int_range from controllers.console import api from controllers.console.app.wraps import get_app_model diff --git a/api/controllers/console/app/workflow_statistic.py b/api/controllers/console/app/workflow_statistic.py index 6c7c73707b..7cef175c14 100644 --- a/api/controllers/console/app/workflow_statistic.py +++ b/api/controllers/console/app/workflow_statistic.py @@ -2,9 +2,10 @@ from datetime import datetime from decimal import Decimal import pytz +import sqlalchemy as sa from flask import jsonify from flask_login import current_user -from flask_restful import Resource, reqparse +from flask_restx import Resource, reqparse from controllers.console import api from controllers.console.app.wraps import get_app_model @@ -71,7 +72,7 @@ WHERE response_data = [] with db.engine.begin() as conn: - rs = conn.execute(db.text(sql_query), arg_dict) + rs = conn.execute(sa.text(sql_query), arg_dict) for i in rs: response_data.append({"date": str(i.date), "runs": i.runs}) @@ -133,7 +134,7 @@ WHERE response_data = [] with db.engine.begin() as conn: - rs = conn.execute(db.text(sql_query), arg_dict) + rs = conn.execute(sa.text(sql_query), arg_dict) for i in rs: response_data.append({"date": str(i.date), "terminal_count": i.terminal_count}) @@ -195,7 +196,7 @@ WHERE response_data = [] with db.engine.begin() as conn: - rs = conn.execute(db.text(sql_query), arg_dict) + rs = conn.execute(sa.text(sql_query), arg_dict) for i in rs: response_data.append( { @@ -277,7 +278,7 @@ GROUP BY response_data = [] with db.engine.begin() as conn: - rs = conn.execute(db.text(sql_query), arg_dict) + rs = conn.execute(sa.text(sql_query), arg_dict) for i in rs: response_data.append( {"date": str(i.date), "interactions": float(i.interactions.quantize(Decimal("0.01")))} diff --git a/api/controllers/console/auth/activate.py b/api/controllers/console/auth/activate.py index 2562fb5eb8..e82e403ec2 100644 --- a/api/controllers/console/auth/activate.py +++ b/api/controllers/console/auth/activate.py @@ -1,5 +1,5 @@ from flask import request -from flask_restful import Resource, reqparse +from flask_restx import Resource, reqparse from constants.languages import supported_language from controllers.console import api diff --git a/api/controllers/console/auth/data_source_bearer_auth.py b/api/controllers/console/auth/data_source_bearer_auth.py index b8c3c8f012..796e6916cc 100644 --- a/api/controllers/console/auth/data_source_bearer_auth.py +++ b/api/controllers/console/auth/data_source_bearer_auth.py @@ -1,5 +1,5 @@ from flask_login import current_user -from flask_restful import Resource, reqparse +from flask_restx import Resource, reqparse from werkzeug.exceptions import Forbidden from controllers.console import api diff --git a/api/controllers/console/auth/data_source_oauth.py b/api/controllers/console/auth/data_source_oauth.py index 4c9697cc32..d4cf20549a 100644 --- a/api/controllers/console/auth/data_source_oauth.py +++ b/api/controllers/console/auth/data_source_oauth.py @@ -3,7 +3,7 @@ import logging import requests from flask import current_app, redirect, request from flask_login import current_user -from flask_restful import Resource +from flask_restx import Resource from werkzeug.exceptions import Forbidden from configs import dify_config @@ -81,7 +81,7 @@ class OAuthDataSourceBinding(Resource): oauth_provider.get_access_token(code) except requests.exceptions.HTTPError as e: logging.exception( - f"An error occurred during the OAuthCallback process with {provider}: {e.response.text}" + "An error occurred during the OAuthCallback process with %s: %s", provider, e.response.text ) return {"error": "OAuth data source process failed"}, 400 @@ -103,7 +103,9 @@ class OAuthDataSourceSync(Resource): try: oauth_provider.sync_data_source(binding_id) except requests.exceptions.HTTPError as e: - logging.exception(f"An error occurred during the OAuthCallback process with {provider}: {e.response.text}") + logging.exception( + "An error occurred during the OAuthCallback process with %s: %s", provider, e.response.text + ) return {"error": "OAuth data source process failed"}, 400 return {"result": "success"}, 200 diff --git a/api/controllers/console/auth/error.py b/api/controllers/console/auth/error.py index 1984339add..8c5e23de58 100644 --- a/api/controllers/console/auth/error.py +++ b/api/controllers/console/auth/error.py @@ -113,9 +113,3 @@ class MemberNotInTenantError(BaseHTTPException): error_code = "member_not_in_tenant" description = "The member is not in the workspace." code = 400 - - -class AccountInFreezeError(BaseHTTPException): - error_code = "account_in_freeze" - description = "This email is temporarily unavailable." - code = 400 diff --git a/api/controllers/console/auth/forgot_password.py b/api/controllers/console/auth/forgot_password.py index 3bbe3177fc..ede0696854 100644 --- a/api/controllers/console/auth/forgot_password.py +++ b/api/controllers/console/auth/forgot_password.py @@ -2,7 +2,7 @@ import base64 import secrets from flask import request -from flask_restful import Resource, reqparse +from flask_restx import Resource, reqparse from sqlalchemy import select from sqlalchemy.orm import Session diff --git a/api/controllers/console/auth/login.py b/api/controllers/console/auth/login.py index 5f2a24322d..a5ad6a1cd7 100644 --- a/api/controllers/console/auth/login.py +++ b/api/controllers/console/auth/login.py @@ -2,7 +2,7 @@ from typing import cast import flask_login from flask import request -from flask_restful import Resource, reqparse +from flask_restx import Resource, reqparse import services from configs import dify_config @@ -221,7 +221,7 @@ class EmailCodeLoginApi(Resource): email=user_email, name=user_email, interface_language=languages[0] ) except WorkSpaceNotAllowedCreateError: - return NotAllowedCreateWorkspace() + raise NotAllowedCreateWorkspace() except AccountRegisterError as are: raise AccountInFreezeError() except WorkspacesLimitExceededError: diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py index d0a4f3ff6d..3c76394cf9 100644 --- a/api/controllers/console/auth/oauth.py +++ b/api/controllers/console/auth/oauth.py @@ -3,7 +3,7 @@ from typing import Optional import requests from flask import current_app, redirect, request -from flask_restful import Resource +from flask_restx import Resource from sqlalchemy import select from sqlalchemy.orm import Session from werkzeug.exceptions import Unauthorized @@ -80,7 +80,7 @@ class OAuthCallback(Resource): user_info = oauth_provider.get_user_info(token) except requests.exceptions.RequestException as e: error_text = e.response.text if e.response else str(e) - logging.exception(f"An error occurred during the OAuth process with {provider}: {error_text}") + logging.exception("An error occurred during the OAuth process with %s: %s", provider, error_text) return {"error": "OAuth process failed"}, 400 if invite_token and RegisterService.is_valid_invite_token(invite_token): diff --git a/api/controllers/console/billing/billing.py b/api/controllers/console/billing/billing.py index 4b0c82ae6c..8ebb745a60 100644 --- a/api/controllers/console/billing/billing.py +++ b/api/controllers/console/billing/billing.py @@ -1,5 +1,5 @@ from flask_login import current_user -from flask_restful import Resource, reqparse +from flask_restx import Resource, reqparse from controllers.console import api from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required diff --git a/api/controllers/console/billing/compliance.py b/api/controllers/console/billing/compliance.py index 9679632ac7..4bc073f679 100644 --- a/api/controllers/console/billing/compliance.py +++ b/api/controllers/console/billing/compliance.py @@ -1,6 +1,6 @@ from flask import request from flask_login import current_user -from flask_restful import Resource, reqparse +from flask_restx import Resource, reqparse from libs.helper import extract_remote_ip from libs.login import login_required diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py index 39f8ab5787..6083a53bec 100644 --- a/api/controllers/console/datasets/data_source.py +++ b/api/controllers/console/datasets/data_source.py @@ -2,7 +2,7 @@ import json from flask import request from flask_login import current_user -from flask_restful import Resource, marshal_with, reqparse +from flask_restx import Resource, marshal_with, reqparse from sqlalchemy import select from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index f551bc2432..a23536f82e 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -1,7 +1,7 @@ -import flask_restful +import flask_restx from flask import request from flask_login import current_user -from flask_restful import Resource, marshal, marshal_with, reqparse +from flask_restx import Resource, marshal, marshal_with, reqparse from werkzeug.exceptions import Forbidden, NotFound import services @@ -41,7 +41,7 @@ def _validate_name(name): def _validate_description_length(description): - if len(description) > 400: + if description and len(description) > 400: raise ValueError("Description cannot exceed 400 characters.") return description @@ -113,7 +113,7 @@ class DatasetListApi(Resource): ) parser.add_argument( "description", - type=str, + type=_validate_description_length, nullable=True, required=False, default="", @@ -589,7 +589,7 @@ class DatasetApiKeyApi(Resource): ) if current_key_count >= self.max_keys: - flask_restful.abort( + flask_restx.abort( 400, message=f"Cannot create more than {self.max_keys} API keys for this resource type.", code="max_keys_exceeded", @@ -629,7 +629,7 @@ class DatasetApiDeleteApi(Resource): ) if key is None: - flask_restful.abort(404, message="API key not found") + flask_restx.abort(404, message="API key not found") db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete() db.session.commit() @@ -683,6 +683,7 @@ class DatasetRetrievalSettingApi(Resource): | VectorType.HUAWEI_CLOUD | VectorType.TENCENT | VectorType.MATRIXONE + | VectorType.CLICKZETTA ): return { "retrieval_method": [ @@ -731,6 +732,7 @@ class DatasetRetrievalSettingMockApi(Resource): | VectorType.TENCENT | VectorType.HUAWEI_CLOUD | VectorType.MATRIXONE + | VectorType.CLICKZETTA ): return { "retrieval_method": [ diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index d14b208a4b..f823ed603b 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -1,10 +1,10 @@ import logging from argparse import ArgumentTypeError -from typing import cast +from typing import Literal, cast from flask import request from flask_login import current_user -from flask_restful import Resource, marshal, marshal_with, reqparse +from flask_restx import Resource, marshal, marshal_with, reqparse from sqlalchemy import asc, desc, select from werkzeug.exceptions import Forbidden, NotFound @@ -642,7 +642,7 @@ class DocumentIndexingStatusApi(DocumentResource): return marshal(document_dict, document_status_fields) -class DocumentDetailApi(DocumentResource): +class DocumentApi(DocumentResource): METADATA_CHOICES = {"all", "only", "without"} @setup_required @@ -730,13 +730,35 @@ class DocumentDetailApi(DocumentResource): return response, 200 + @setup_required + @login_required + @account_initialization_required + @cloud_edition_billing_rate_limit_check("knowledge") + def delete(self, dataset_id, document_id): + dataset_id = str(dataset_id) + document_id = str(document_id) + dataset = DatasetService.get_dataset(dataset_id) + if dataset is None: + raise NotFound("Dataset not found.") + # check user's model setting + DatasetService.check_dataset_model_setting(dataset) + + document = self.get_document(dataset_id, document_id) + + try: + DocumentService.delete_document(document) + except services.errors.document.DocumentIndexingError: + raise DocumentIndexingError("Cannot delete document during indexing.") + + return {"result": "success"}, 204 + class DocumentProcessingApi(DocumentResource): @setup_required @login_required @account_initialization_required @cloud_edition_billing_rate_limit_check("knowledge") - def patch(self, dataset_id, document_id, action): + def patch(self, dataset_id, document_id, action: Literal["pause", "resume"]): dataset_id = str(dataset_id) document_id = str(document_id) document = self.get_document(dataset_id, document_id) @@ -762,36 +784,10 @@ class DocumentProcessingApi(DocumentResource): document.paused_at = None document.is_paused = False db.session.commit() - else: - raise InvalidActionError() return {"result": "success"}, 200 -class DocumentDeleteApi(DocumentResource): - @setup_required - @login_required - @account_initialization_required - @cloud_edition_billing_rate_limit_check("knowledge") - def delete(self, dataset_id, document_id): - dataset_id = str(dataset_id) - document_id = str(document_id) - dataset = DatasetService.get_dataset(dataset_id) - if dataset is None: - raise NotFound("Dataset not found.") - # check user's model setting - DatasetService.check_dataset_model_setting(dataset) - - document = self.get_document(dataset_id, document_id) - - try: - DocumentService.delete_document(document) - except services.errors.document.DocumentIndexingError: - raise DocumentIndexingError("Cannot delete document during indexing.") - - return {"result": "success"}, 204 - - class DocumentMetadataApi(DocumentResource): @setup_required @login_required @@ -842,7 +838,7 @@ class DocumentStatusApi(DocumentResource): @account_initialization_required @cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_rate_limit_check("knowledge") - def patch(self, dataset_id, action): + def patch(self, dataset_id, action: Literal["enable", "disable", "archive", "un_archive"]): dataset_id = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id) if dataset is None: @@ -970,7 +966,7 @@ class DocumentRetryApi(DocumentResource): raise DocumentAlreadyFinishedError() retry_documents.append(document) except Exception: - logging.exception(f"Failed to retry document, document id: {document_id}") + logging.exception("Failed to retry document, document id: %s", document_id) continue # retry document DocumentService.retry_document(dataset_id, retry_documents) @@ -1037,11 +1033,10 @@ api.add_resource( api.add_resource(DocumentBatchIndexingEstimateApi, "/datasets//batch//indexing-estimate") api.add_resource(DocumentBatchIndexingStatusApi, "/datasets//batch//indexing-status") api.add_resource(DocumentIndexingStatusApi, "/datasets//documents//indexing-status") -api.add_resource(DocumentDetailApi, "/datasets//documents/") +api.add_resource(DocumentApi, "/datasets//documents/") api.add_resource( DocumentProcessingApi, "/datasets//documents//processing/" ) -api.add_resource(DocumentDeleteApi, "/datasets//documents/") api.add_resource(DocumentMetadataApi, "/datasets//documents//metadata") api.add_resource(DocumentStatusApi, "/datasets//documents/status//batch") api.add_resource(DocumentPauseApi, "/datasets//documents//processing/pause") diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py index b3704ce8b1..463fd2d7ec 100644 --- a/api/controllers/console/datasets/datasets_segments.py +++ b/api/controllers/console/datasets/datasets_segments.py @@ -1,9 +1,8 @@ import uuid -import pandas as pd from flask import request from flask_login import current_user -from flask_restful import Resource, marshal, reqparse +from flask_restx import Resource, marshal, reqparse from sqlalchemy import select from werkzeug.exceptions import Forbidden, NotFound @@ -14,8 +13,6 @@ from controllers.console.datasets.error import ( ChildChunkDeleteIndexError, ChildChunkIndexingError, InvalidActionError, - NoFileUploadedError, - TooManyFilesError, ) from controllers.console.wraps import ( account_initialization_required, @@ -32,6 +29,7 @@ from extensions.ext_redis import redis_client from fields.segment_fields import child_chunk_fields, segment_fields from libs.login import login_required from models.dataset import ChildChunk, DocumentSegment +from models.model import UploadFile from services.dataset_service import DatasetService, DocumentService, SegmentService from services.entities.knowledge_entities.knowledge_entities import ChildChunkUpdateArgs, SegmentUpdateArgs from services.errors.chunk import ChildChunkDeleteIndexError as ChildChunkDeleteIndexServiceError @@ -184,7 +182,7 @@ class DatasetDocumentSegmentApi(Resource): raise ProviderNotInitializeError(ex.description) segment_ids = request.args.getlist("segment_id") - document_indexing_cache_key = "document_{}_indexing".format(document.id) + document_indexing_cache_key = f"document_{document.id}_indexing" cache_result = redis_client.get(document_indexing_cache_key) if cache_result is not None: raise InvalidActionError("Document is being indexed, please try again later") @@ -365,37 +363,28 @@ class DatasetDocumentSegmentBatchImportApi(Resource): document = DocumentService.get_document(dataset_id, document_id) if not document: raise NotFound("Document not found.") - # get file from request - file = request.files["file"] - # check file - if "file" not in request.files: - raise NoFileUploadedError() - if len(request.files) > 1: - raise TooManyFilesError() + parser = reqparse.RequestParser() + parser.add_argument("upload_file_id", type=str, required=True, nullable=False, location="json") + args = parser.parse_args() + upload_file_id = args["upload_file_id"] + + upload_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first() + if not upload_file: + raise NotFound("UploadFile not found.") + # check file type - if not file.filename or not file.filename.lower().endswith(".csv"): + if not upload_file.name or not upload_file.name.lower().endswith(".csv"): raise ValueError("Invalid file type. Only CSV files are allowed") try: - # Skip the first row - df = pd.read_csv(file) - result = [] - for index, row in df.iterrows(): - if document.doc_form == "qa_model": - data = {"content": row.iloc[0], "answer": row.iloc[1]} - else: - data = {"content": row.iloc[0]} - result.append(data) - if len(result) == 0: - raise ValueError("The CSV file is empty.") # async job job_id = str(uuid.uuid4()) - indexing_cache_key = "segment_batch_import_{}".format(str(job_id)) + indexing_cache_key = f"segment_batch_import_{str(job_id)}" # send batch add segments task redis_client.setnx(indexing_cache_key, "waiting") batch_create_segment_to_index_task.delay( - str(job_id), result, dataset_id, document_id, current_user.current_tenant_id, current_user.id + str(job_id), upload_file_id, dataset_id, document_id, current_user.current_tenant_id, current_user.id ) except Exception as e: return {"error": str(e)}, 500 @@ -406,7 +395,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource): @account_initialization_required def get(self, job_id): job_id = str(job_id) - indexing_cache_key = "segment_batch_import_{}".format(job_id) + indexing_cache_key = f"segment_batch_import_{job_id}" cache_result = redis_client.get(indexing_cache_key) if cache_result is None: raise ValueError("The job does not exist.") @@ -595,7 +584,12 @@ class ChildChunkUpdateApi(Resource): child_chunk_id = str(child_chunk_id) child_chunk = ( db.session.query(ChildChunk) - .where(ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id) + .where( + ChildChunk.id == str(child_chunk_id), + ChildChunk.tenant_id == current_user.current_tenant_id, + ChildChunk.segment_id == segment.id, + ChildChunk.document_id == document_id, + ) .first() ) if not child_chunk: @@ -644,7 +638,12 @@ class ChildChunkUpdateApi(Resource): child_chunk_id = str(child_chunk_id) child_chunk = ( db.session.query(ChildChunk) - .where(ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id) + .where( + ChildChunk.id == str(child_chunk_id), + ChildChunk.tenant_id == current_user.current_tenant_id, + ChildChunk.segment_id == segment.id, + ChildChunk.document_id == document_id, + ) .first() ) if not child_chunk: diff --git a/api/controllers/console/datasets/error.py b/api/controllers/console/datasets/error.py index cb68bb5e81..a43843b551 100644 --- a/api/controllers/console/datasets/error.py +++ b/api/controllers/console/datasets/error.py @@ -1,30 +1,6 @@ from libs.exception import BaseHTTPException -class NoFileUploadedError(BaseHTTPException): - error_code = "no_file_uploaded" - description = "Please upload your file." - code = 400 - - -class TooManyFilesError(BaseHTTPException): - error_code = "too_many_files" - description = "Only one file is allowed." - code = 400 - - -class FileTooLargeError(BaseHTTPException): - error_code = "file_too_large" - description = "File size exceeded. {message}" - code = 413 - - -class UnsupportedFileTypeError(BaseHTTPException): - error_code = "unsupported_file_type" - description = "File type not allowed." - code = 415 - - class DatasetNotInitializedError(BaseHTTPException): error_code = "dataset_not_initialized" description = "The dataset is still being initialized or indexing. Please wait a moment." diff --git a/api/controllers/console/datasets/external.py b/api/controllers/console/datasets/external.py index cf9081e154..043f39f623 100644 --- a/api/controllers/console/datasets/external.py +++ b/api/controllers/console/datasets/external.py @@ -1,6 +1,6 @@ from flask import request from flask_login import current_user -from flask_restful import Resource, marshal, reqparse +from flask_restx import Resource, marshal, reqparse from werkzeug.exceptions import Forbidden, InternalServerError, NotFound import services diff --git a/api/controllers/console/datasets/hit_testing.py b/api/controllers/console/datasets/hit_testing.py index fba5d4c0f3..2ad192571b 100644 --- a/api/controllers/console/datasets/hit_testing.py +++ b/api/controllers/console/datasets/hit_testing.py @@ -1,4 +1,4 @@ -from flask_restful import Resource +from flask_restx import Resource from controllers.console import api from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase diff --git a/api/controllers/console/datasets/hit_testing_base.py b/api/controllers/console/datasets/hit_testing_base.py index 3b4c076863..304674db5f 100644 --- a/api/controllers/console/datasets/hit_testing_base.py +++ b/api/controllers/console/datasets/hit_testing_base.py @@ -1,7 +1,7 @@ import logging from flask_login import current_user -from flask_restful import marshal, reqparse +from flask_restx import marshal, reqparse from werkzeug.exceptions import Forbidden, InternalServerError, NotFound import services.dataset_service diff --git a/api/controllers/console/datasets/metadata.py b/api/controllers/console/datasets/metadata.py index b1a83aa371..6aa309f930 100644 --- a/api/controllers/console/datasets/metadata.py +++ b/api/controllers/console/datasets/metadata.py @@ -1,5 +1,7 @@ +from typing import Literal + from flask_login import current_user -from flask_restful import Resource, marshal_with, reqparse +from flask_restx import Resource, marshal_with, reqparse from werkzeug.exceptions import NotFound from controllers.console import api @@ -22,8 +24,8 @@ class DatasetMetadataCreateApi(Resource): @marshal_with(dataset_metadata_fields) def post(self, dataset_id): parser = reqparse.RequestParser() - parser.add_argument("type", type=str, required=True, nullable=True, location="json") - parser.add_argument("name", type=str, required=True, nullable=True, location="json") + parser.add_argument("type", type=str, required=True, nullable=False, location="json") + parser.add_argument("name", type=str, required=True, nullable=False, location="json") args = parser.parse_args() metadata_args = MetadataArgs(**args) @@ -56,7 +58,7 @@ class DatasetMetadataApi(Resource): @marshal_with(dataset_metadata_fields) def patch(self, dataset_id, metadata_id): parser = reqparse.RequestParser() - parser.add_argument("name", type=str, required=True, nullable=True, location="json") + parser.add_argument("name", type=str, required=True, nullable=False, location="json") args = parser.parse_args() dataset_id_str = str(dataset_id) @@ -100,7 +102,7 @@ class DatasetMetadataBuiltInFieldActionApi(Resource): @login_required @account_initialization_required @enterprise_license_required - def post(self, dataset_id, action): + def post(self, dataset_id, action: Literal["enable", "disable"]): dataset_id_str = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id_str) if dataset is None: @@ -127,7 +129,7 @@ class DocumentMetadataEditApi(Resource): DatasetService.check_dataset_permission(dataset, current_user) parser = reqparse.RequestParser() - parser.add_argument("operation_data", type=list, required=True, nullable=True, location="json") + parser.add_argument("operation_data", type=list, required=True, nullable=False, location="json") args = parser.parse_args() metadata_args = MetadataOperationData(**args) diff --git a/api/controllers/console/datasets/website.py b/api/controllers/console/datasets/website.py index fcdc91ec67..bdaa268462 100644 --- a/api/controllers/console/datasets/website.py +++ b/api/controllers/console/datasets/website.py @@ -1,4 +1,4 @@ -from flask_restful import Resource, reqparse +from flask_restx import Resource, reqparse from controllers.console import api from controllers.console.datasets.error import WebsiteCrawlError diff --git a/api/controllers/console/error.py b/api/controllers/console/error.py index 6944c56bf8..0645d63be5 100644 --- a/api/controllers/console/error.py +++ b/api/controllers/console/error.py @@ -76,30 +76,6 @@ class EmailSendIpLimitError(BaseHTTPException): code = 429 -class FileTooLargeError(BaseHTTPException): - error_code = "file_too_large" - description = "File size exceeded. {message}" - code = 413 - - -class UnsupportedFileTypeError(BaseHTTPException): - error_code = "unsupported_file_type" - description = "File type not allowed." - code = 415 - - -class TooManyFilesError(BaseHTTPException): - error_code = "too_many_files" - description = "Only one file is allowed." - code = 400 - - -class NoFileUploadedError(BaseHTTPException): - error_code = "no_file_uploaded" - description = "Please upload your file." - code = 400 - - class UnauthorizedAndForceLogout(BaseHTTPException): error_code = "unauthorized_and_force_logout" description = "Unauthorized and force logout." @@ -127,7 +103,7 @@ class EducationActivateLimitError(BaseHTTPException): code = 429 -class CompilanceRateLimitError(BaseHTTPException): - error_code = "compilance_rate_limit" +class ComplianceRateLimitError(BaseHTTPException): + error_code = "compliance_rate_limit" description = "Rate limit exceeded for downloading compliance report." code = 429 diff --git a/api/controllers/console/explore/audio.py b/api/controllers/console/explore/audio.py index d564a00a76..2a4d5be82f 100644 --- a/api/controllers/console/explore/audio.py +++ b/api/controllers/console/explore/audio.py @@ -65,7 +65,7 @@ class ChatAudioApi(InstalledAppResource): class ChatTextApi(InstalledAppResource): def post(self, installed_app): - from flask_restful import reqparse + from flask_restx import reqparse app_model = installed_app.app try: diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py index 4842fefc57..b444a2a197 100644 --- a/api/controllers/console/explore/completion.py +++ b/api/controllers/console/explore/completion.py @@ -1,7 +1,7 @@ import logging from flask_login import current_user -from flask_restful import reqparse +from flask_restx import reqparse from werkzeug.exceptions import InternalServerError, NotFound import services diff --git a/api/controllers/console/explore/conversation.py b/api/controllers/console/explore/conversation.py index d7c161cc6d..a8d46954b5 100644 --- a/api/controllers/console/explore/conversation.py +++ b/api/controllers/console/explore/conversation.py @@ -1,6 +1,6 @@ from flask_login import current_user -from flask_restful import marshal_with, reqparse -from flask_restful.inputs import int_range +from flask_restx import marshal_with, reqparse +from flask_restx.inputs import int_range from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py index ffdf73c368..3ccedd654b 100644 --- a/api/controllers/console/explore/installed_app.py +++ b/api/controllers/console/explore/installed_app.py @@ -3,7 +3,7 @@ from typing import Any from flask import request from flask_login import current_user -from flask_restful import Resource, inputs, marshal_with, reqparse +from flask_restx import Resource, inputs, marshal_with, reqparse from sqlalchemy import and_ from werkzeug.exceptions import BadRequest, Forbidden, NotFound @@ -58,23 +58,40 @@ class InstalledAppsListApi(Resource): # filter out apps that user doesn't have access to if FeatureService.get_system_features().webapp_auth.enabled: user_id = current_user.id - res = [] app_ids = [installed_app["app"].id for installed_app in installed_app_list] webapp_settings = EnterpriseService.WebAppAuth.batch_get_app_access_mode_by_id(app_ids) + + # Pre-filter out apps without setting or with sso_verified + filtered_installed_apps = [] + app_id_to_app_code = {} + for installed_app in installed_app_list: - webapp_setting = webapp_settings.get(installed_app["app"].id) - if not webapp_setting: + app_id = installed_app["app"].id + webapp_setting = webapp_settings.get(app_id) + if not webapp_setting or webapp_setting.access_mode == "sso_verified": continue - if webapp_setting.access_mode == "sso_verified": - continue - app_code = AppService.get_app_code_by_id(str(installed_app["app"].id)) - if EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp( - user_id=user_id, - app_code=app_code, - ): + app_code = AppService.get_app_code_by_id(str(app_id)) + app_id_to_app_code[app_id] = app_code + filtered_installed_apps.append(installed_app) + + app_codes = list(app_id_to_app_code.values()) + + # Batch permission check + permissions = EnterpriseService.WebAppAuth.batch_is_user_allowed_to_access_webapps( + user_id=user_id, + app_codes=app_codes, + ) + + # Keep only allowed apps + res = [] + for installed_app in filtered_installed_apps: + app_id = installed_app["app"].id + app_code = app_id_to_app_code[app_id] + if permissions.get(app_code): res.append(installed_app) + installed_app_list = res - logger.debug(f"installed_app_list: {installed_app_list}, user_id: {user_id}") + logger.debug("installed_app_list: %s, user_id: %s", installed_app_list, user_id) installed_app_list.sort( key=lambda app: ( diff --git a/api/controllers/console/explore/message.py b/api/controllers/console/explore/message.py index 822777604a..6df3bca762 100644 --- a/api/controllers/console/explore/message.py +++ b/api/controllers/console/explore/message.py @@ -1,11 +1,10 @@ import logging from flask_login import current_user -from flask_restful import marshal_with, reqparse -from flask_restful.inputs import int_range +from flask_restx import marshal_with, reqparse +from flask_restx.inputs import int_range from werkzeug.exceptions import InternalServerError, NotFound -import services from controllers.console.app.error import ( AppMoreLikeThisDisabledError, CompletionRequestError, @@ -29,7 +28,11 @@ from models.model import AppMode from services.app_generate_service import AppGenerateService from services.errors.app import MoreLikeThisDisabledError from services.errors.conversation import ConversationNotExistsError -from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError +from services.errors.message import ( + FirstMessageNotExistsError, + MessageNotExistsError, + SuggestedQuestionsAfterAnswerDisabledError, +) from services.message_service import MessageService @@ -52,9 +55,9 @@ class MessageListApi(InstalledAppResource): return MessageService.pagination_by_first_id( app_model, current_user, args["conversation_id"], args["first_id"], args["limit"] ) - except services.errors.conversation.ConversationNotExistsError: + except ConversationNotExistsError: raise NotFound("Conversation Not Exists.") - except services.errors.message.FirstMessageNotExistsError: + except FirstMessageNotExistsError: raise NotFound("First Message Not Exists.") @@ -77,7 +80,7 @@ class MessageFeedbackApi(InstalledAppResource): rating=args.get("rating"), content=args.get("content"), ) - except services.errors.message.MessageNotExistsError: + except MessageNotExistsError: raise NotFound("Message Not Exists.") return {"result": "success"} diff --git a/api/controllers/console/explore/parameter.py b/api/controllers/console/explore/parameter.py index a1280d91d1..c368744759 100644 --- a/api/controllers/console/explore/parameter.py +++ b/api/controllers/console/explore/parameter.py @@ -1,4 +1,4 @@ -from flask_restful import marshal_with +from flask_restx import marshal_with from controllers.common import fields from controllers.console import api diff --git a/api/controllers/console/explore/recommended_app.py b/api/controllers/console/explore/recommended_app.py index ce85f495aa..62f9350b71 100644 --- a/api/controllers/console/explore/recommended_app.py +++ b/api/controllers/console/explore/recommended_app.py @@ -1,5 +1,5 @@ from flask_login import current_user -from flask_restful import Resource, fields, marshal_with, reqparse +from flask_restx import Resource, fields, marshal_with, reqparse from constants.languages import languages from controllers.console import api diff --git a/api/controllers/console/explore/saved_message.py b/api/controllers/console/explore/saved_message.py index 339e7007a0..5353dbcad5 100644 --- a/api/controllers/console/explore/saved_message.py +++ b/api/controllers/console/explore/saved_message.py @@ -1,6 +1,6 @@ from flask_login import current_user -from flask_restful import fields, marshal_with, reqparse -from flask_restful.inputs import int_range +from flask_restx import fields, marshal_with, reqparse +from flask_restx.inputs import int_range from werkzeug.exceptions import NotFound from controllers.console import api diff --git a/api/controllers/console/explore/workflow.py b/api/controllers/console/explore/workflow.py index 3f625e6609..3d872fc1fc 100644 --- a/api/controllers/console/explore/workflow.py +++ b/api/controllers/console/explore/workflow.py @@ -1,6 +1,6 @@ import logging -from flask_restful import reqparse +from flask_restx import reqparse from werkzeug.exceptions import InternalServerError from controllers.console.app.error import ( diff --git a/api/controllers/console/explore/wraps.py b/api/controllers/console/explore/wraps.py index de97fb149e..e86103184a 100644 --- a/api/controllers/console/explore/wraps.py +++ b/api/controllers/console/explore/wraps.py @@ -1,7 +1,7 @@ from functools import wraps from flask_login import current_user -from flask_restful import Resource +from flask_restx import Resource from werkzeug.exceptions import NotFound from controllers.console.explore.error import AppAccessDeniedError diff --git a/api/controllers/console/extension.py b/api/controllers/console/extension.py index 07a241ef86..e157041c35 100644 --- a/api/controllers/console/extension.py +++ b/api/controllers/console/extension.py @@ -1,5 +1,5 @@ from flask_login import current_user -from flask_restful import Resource, marshal_with, reqparse +from flask_restx import Resource, marshal_with, reqparse from constants import HIDDEN_VALUE from controllers.console import api diff --git a/api/controllers/console/feature.py b/api/controllers/console/feature.py index 70ab4ff865..6236832d39 100644 --- a/api/controllers/console/feature.py +++ b/api/controllers/console/feature.py @@ -1,5 +1,5 @@ from flask_login import current_user -from flask_restful import Resource +from flask_restx import Resource from libs.login import login_required from services.feature_service import FeatureService diff --git a/api/controllers/console/files.py b/api/controllers/console/files.py index 66b6214f82..101a49a32e 100644 --- a/api/controllers/console/files.py +++ b/api/controllers/console/files.py @@ -2,13 +2,19 @@ from typing import Literal from flask import request from flask_login import current_user -from flask_restful import Resource, marshal_with +from flask_restx import Resource, marshal_with from werkzeug.exceptions import Forbidden import services from configs import dify_config from constants import DOCUMENT_EXTENSIONS -from controllers.common.errors import FilenameNotExistsError +from controllers.common.errors import ( + FilenameNotExistsError, + FileTooLargeError, + NoFileUploadedError, + TooManyFilesError, + UnsupportedFileTypeError, +) from controllers.console.wraps import ( account_initialization_required, cloud_edition_billing_resource_check, @@ -18,13 +24,6 @@ from fields.file_fields import file_fields, upload_config_fields from libs.login import login_required from services.file_service import FileService -from .error import ( - FileTooLargeError, - NoFileUploadedError, - TooManyFilesError, - UnsupportedFileTypeError, -) - PREVIEW_WORDS_LIMIT = 3000 @@ -49,7 +48,6 @@ class FileApi(Resource): @marshal_with(file_fields) @cloud_edition_billing_resource_check("documents") def post(self): - file = request.files["file"] source_str = request.form.get("source") source: Literal["datasets"] | None = "datasets" if source_str == "datasets" else None @@ -58,6 +56,7 @@ class FileApi(Resource): if len(request.files) > 1: raise TooManyFilesError() + file = request.files["file"] if not file.filename: raise FilenameNotExistsError diff --git a/api/controllers/console/init_validate.py b/api/controllers/console/init_validate.py index b19e331d2e..2a37b1708a 100644 --- a/api/controllers/console/init_validate.py +++ b/api/controllers/console/init_validate.py @@ -1,7 +1,7 @@ import os from flask import session -from flask_restful import Resource, reqparse +from flask_restx import Resource, reqparse from sqlalchemy import select from sqlalchemy.orm import Session diff --git a/api/controllers/console/ping.py b/api/controllers/console/ping.py index cd28cc946e..1a53a2347e 100644 --- a/api/controllers/console/ping.py +++ b/api/controllers/console/ping.py @@ -1,4 +1,4 @@ -from flask_restful import Resource +from flask_restx import Resource from controllers.console import api diff --git a/api/controllers/console/remote_files.py b/api/controllers/console/remote_files.py index b8cf019e4f..73014cfc97 100644 --- a/api/controllers/console/remote_files.py +++ b/api/controllers/console/remote_files.py @@ -3,22 +3,21 @@ from typing import cast import httpx from flask_login import current_user -from flask_restful import Resource, marshal_with, reqparse +from flask_restx import Resource, marshal_with, reqparse import services from controllers.common import helpers -from controllers.common.errors import RemoteFileUploadError +from controllers.common.errors import ( + FileTooLargeError, + RemoteFileUploadError, + UnsupportedFileTypeError, +) from core.file import helpers as file_helpers from core.helper import ssrf_proxy from fields.file_fields import file_fields_with_signed_url, remote_file_info_fields from models.account import Account from services.file_service import FileService -from .error import ( - FileTooLargeError, - UnsupportedFileTypeError, -) - class RemoteFileInfoApi(Resource): @marshal_with(remote_file_info_fields) diff --git a/api/controllers/console/setup.py b/api/controllers/console/setup.py index e1f19a87a3..8e230496f0 100644 --- a/api/controllers/console/setup.py +++ b/api/controllers/console/setup.py @@ -1,5 +1,5 @@ from flask import request -from flask_restful import Resource, reqparse +from flask_restx import Resource, reqparse from configs import dify_config from libs.helper import StrLen, email, extract_remote_ip diff --git a/api/controllers/console/tag/tags.py b/api/controllers/console/tag/tags.py index cb5dedca21..c45e7dbb26 100644 --- a/api/controllers/console/tag/tags.py +++ b/api/controllers/console/tag/tags.py @@ -1,11 +1,11 @@ from flask import request from flask_login import current_user -from flask_restful import Resource, marshal_with, reqparse +from flask_restx import Resource, marshal_with, reqparse from werkzeug.exceptions import Forbidden from controllers.console import api from controllers.console.wraps import account_initialization_required, setup_required -from fields.tag_fields import tag_fields +from fields.tag_fields import dataset_tag_fields from libs.login import login_required from models.model import Tag from services.tag_service import TagService @@ -21,7 +21,7 @@ class TagListApi(Resource): @setup_required @login_required @account_initialization_required - @marshal_with(tag_fields) + @marshal_with(dataset_tag_fields) def get(self): tag_type = request.args.get("type", type=str, default="") keyword = request.args.get("keyword", default=None, type=str) diff --git a/api/controllers/console/version.py b/api/controllers/console/version.py index 447cc358f8..96cf627b65 100644 --- a/api/controllers/console/version.py +++ b/api/controllers/console/version.py @@ -2,7 +2,7 @@ import json import logging import requests -from flask_restful import Resource, reqparse +from flask_restx import Resource, reqparse from packaging import version from configs import dify_config @@ -32,9 +32,9 @@ class VersionApi(Resource): return result try: - response = requests.get(check_update_url, {"current_version": args.get("current_version")}) + response = requests.get(check_update_url, {"current_version": args.get("current_version")}, timeout=(3, 10)) except Exception as error: - logging.warning("Check update version error: {}.".format(str(error))) + logging.warning("Check update version error: %s.", str(error)) result["version"] = args.get("current_version") return result @@ -55,7 +55,7 @@ def _has_new_version(*, latest_version: str, current_version: str) -> bool: # Compare versions return latest > current except version.InvalidVersion: - logging.warning(f"Invalid version format: latest={latest_version}, current={current_version}") + logging.warning("Invalid version format: latest=%s, current=%s", latest_version, current_version) return False diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py index 45513c368d..5b2828dbab 100644 --- a/api/controllers/console/workspace/account.py +++ b/api/controllers/console/workspace/account.py @@ -1,7 +1,9 @@ +from datetime import datetime + import pytz from flask import request from flask_login import current_user -from flask_restful import Resource, fields, marshal_with, reqparse +from flask_restx import Resource, fields, marshal_with, reqparse from sqlalchemy import select from sqlalchemy.orm import Session @@ -9,14 +11,13 @@ from configs import dify_config from constants.languages import supported_language from controllers.console import api from controllers.console.auth.error import ( - AccountInFreezeError, EmailAlreadyInUseError, EmailChangeLimitError, EmailCodeError, InvalidEmailError, InvalidTokenError, ) -from controllers.console.error import AccountNotFound, EmailSendIpLimitError +from controllers.console.error import AccountInFreezeError, AccountNotFound, EmailSendIpLimitError from controllers.console.workspace.error import ( AccountAlreadyInitedError, CurrentPasswordIncorrectError, @@ -328,6 +329,9 @@ class EducationVerifyApi(Resource): class EducationApi(Resource): status_fields = { "result": fields.Boolean, + "is_student": fields.Boolean, + "expire_at": TimestampField, + "allow_refresh": fields.Boolean, } @setup_required @@ -355,7 +359,11 @@ class EducationApi(Resource): def get(self): account = current_user - return BillingService.EducationIdentity.is_active(account.id) + res = BillingService.EducationIdentity.status(account.id) + # convert expire_at to UTC timestamp from isoformat + if res and "expire_at" in res: + res["expire_at"] = datetime.fromisoformat(res["expire_at"]).astimezone(pytz.utc) + return res class EducationAutoCompleteApi(Resource): @@ -496,7 +504,7 @@ class ChangeEmailResetApi(Resource): if current_user.email != old_email: raise AccountNotFound() - updated_account = AccountService.update_account(current_user, email=args["new_email"]) + updated_account = AccountService.update_account_email(current_user, email=args["new_email"]) AccountService.send_change_email_completed_notify_email( email=args["new_email"], diff --git a/api/controllers/console/workspace/agent_providers.py b/api/controllers/console/workspace/agent_providers.py index 88c37767e3..08bab6fcb5 100644 --- a/api/controllers/console/workspace/agent_providers.py +++ b/api/controllers/console/workspace/agent_providers.py @@ -1,5 +1,5 @@ from flask_login import current_user -from flask_restful import Resource +from flask_restx import Resource from controllers.console import api from controllers.console.wraps import account_initialization_required, setup_required diff --git a/api/controllers/console/workspace/endpoint.py b/api/controllers/console/workspace/endpoint.py index eb53dcb16e..96e873d42b 100644 --- a/api/controllers/console/workspace/endpoint.py +++ b/api/controllers/console/workspace/endpoint.py @@ -1,5 +1,5 @@ from flask_login import current_user -from flask_restful import Resource, reqparse +from flask_restx import Resource, reqparse from werkzeug.exceptions import Forbidden from controllers.console import api diff --git a/api/controllers/console/workspace/load_balancing_config.py b/api/controllers/console/workspace/load_balancing_config.py index b4eb5e246b..2a54511bf0 100644 --- a/api/controllers/console/workspace/load_balancing_config.py +++ b/api/controllers/console/workspace/load_balancing_config.py @@ -1,4 +1,4 @@ -from flask_restful import Resource, reqparse +from flask_restx import Resource, reqparse from werkzeug.exceptions import Forbidden from controllers.console import api diff --git a/api/controllers/console/workspace/members.py b/api/controllers/console/workspace/members.py index f7424923b9..cf2a10f453 100644 --- a/api/controllers/console/workspace/members.py +++ b/api/controllers/console/workspace/members.py @@ -2,7 +2,7 @@ from urllib import parse from flask import request from flask_login import current_user -from flask_restful import Resource, abort, marshal_with, reqparse +from flask_restx import Resource, abort, marshal_with, reqparse import services from configs import dify_config @@ -54,7 +54,7 @@ class MemberInviteEmailApi(Resource): @cloud_edition_billing_resource_check("members") def post(self): parser = reqparse.RequestParser() - parser.add_argument("emails", type=str, required=True, location="json", action="append") + parser.add_argument("emails", type=list, required=True, location="json") parser.add_argument("role", type=str, required=True, default="admin", location="json") parser.add_argument("language", type=str, required=False, location="json") args = parser.parse_args() diff --git a/api/controllers/console/workspace/model_providers.py b/api/controllers/console/workspace/model_providers.py index ff0fcbda6e..281783b3d7 100644 --- a/api/controllers/console/workspace/model_providers.py +++ b/api/controllers/console/workspace/model_providers.py @@ -2,7 +2,7 @@ import io from flask import send_file from flask_login import current_user -from flask_restful import Resource, reqparse +from flask_restx import Resource, reqparse from werkzeug.exceptions import Forbidden from controllers.console import api diff --git a/api/controllers/console/workspace/models.py b/api/controllers/console/workspace/models.py index 37d0f6c764..b8dddb91dd 100644 --- a/api/controllers/console/workspace/models.py +++ b/api/controllers/console/workspace/models.py @@ -1,7 +1,7 @@ import logging from flask_login import current_user -from flask_restful import Resource, reqparse +from flask_restx import Resource, reqparse from werkzeug.exceptions import Forbidden from controllers.console import api @@ -73,8 +73,9 @@ class DefaultModelApi(Resource): ) except Exception as ex: logging.exception( - f"Failed to update default model, model type: {model_setting['model_type']}," - f" model:{model_setting.get('model')}" + "Failed to update default model, model type: %s, model: %s", + model_setting["model_type"], + model_setting.get("model"), ) raise ex @@ -160,8 +161,10 @@ class ModelProviderModelApi(Resource): ) except CredentialsValidateFailedError as ex: logging.exception( - f"Failed to save model credentials, tenant_id: {tenant_id}," - f" model: {args.get('model')}, model_type: {args.get('model_type')}" + "Failed to save model credentials, tenant_id: %s, model: %s, model_type: %s", + tenant_id, + args.get("model"), + args.get("model_type"), ) raise ValueError(str(ex)) diff --git a/api/controllers/console/workspace/plugin.py b/api/controllers/console/workspace/plugin.py index 09846d5c94..fd5421fa64 100644 --- a/api/controllers/console/workspace/plugin.py +++ b/api/controllers/console/workspace/plugin.py @@ -2,7 +2,7 @@ import io from flask import request, send_file from flask_login import current_user -from flask_restful import Resource, reqparse +from flask_restx import Resource, reqparse from werkzeug.exceptions import Forbidden from configs import dify_config diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index c4d1ef70d8..854ba7ac45 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -3,7 +3,7 @@ from urllib.parse import urlparse from flask import make_response, redirect, request, send_file from flask_login import current_user -from flask_restful import ( +from flask_restx import ( Resource, reqparse, ) @@ -862,6 +862,10 @@ class ToolProviderMCPApi(Resource): parser.add_argument("icon_type", type=str, required=True, nullable=False, location="json") parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json", default="") parser.add_argument("server_identifier", type=str, required=True, nullable=False, location="json") + parser.add_argument("timeout", type=float, required=False, nullable=False, location="json", default=30) + parser.add_argument( + "sse_read_timeout", type=float, required=False, nullable=False, location="json", default=300 + ) args = parser.parse_args() user = current_user if not is_valid_url(args["server_url"]): @@ -876,6 +880,8 @@ class ToolProviderMCPApi(Resource): icon_background=args["icon_background"], user_id=user.id, server_identifier=args["server_identifier"], + timeout=args["timeout"], + sse_read_timeout=args["sse_read_timeout"], ) ) @@ -891,6 +897,8 @@ class ToolProviderMCPApi(Resource): parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json") parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json") parser.add_argument("server_identifier", type=str, required=True, nullable=False, location="json") + parser.add_argument("timeout", type=float, required=False, nullable=True, location="json") + parser.add_argument("sse_read_timeout", type=float, required=False, nullable=True, location="json") args = parser.parse_args() if not is_valid_url(args["server_url"]): if "[__HIDDEN__]" in args["server_url"]: @@ -906,6 +914,8 @@ class ToolProviderMCPApi(Resource): icon_type=args["icon_type"], icon_background=args["icon_background"], server_identifier=args["server_identifier"], + timeout=args.get("timeout"), + sse_read_timeout=args.get("sse_read_timeout"), ) return {"result": "success"} diff --git a/api/controllers/console/workspace/workspace.py b/api/controllers/console/workspace/workspace.py index 19999e7361..fb89f6bbbd 100644 --- a/api/controllers/console/workspace/workspace.py +++ b/api/controllers/console/workspace/workspace.py @@ -2,20 +2,20 @@ import logging from flask import request from flask_login import current_user -from flask_restful import Resource, fields, inputs, marshal, marshal_with, reqparse +from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse from sqlalchemy import select from werkzeug.exceptions import Unauthorized import services -from controllers.common.errors import FilenameNotExistsError -from controllers.console import api -from controllers.console.admin import admin_required -from controllers.console.datasets.error import ( +from controllers.common.errors import ( + FilenameNotExistsError, FileTooLargeError, NoFileUploadedError, TooManyFilesError, UnsupportedFileTypeError, ) +from controllers.console import api +from controllers.console.admin import admin_required from controllers.console.error import AccountNotLinkTenantError from controllers.console.wraps import ( account_initialization_required, @@ -191,9 +191,6 @@ class WebappLogoWorkspaceApi(Resource): @account_initialization_required @cloud_edition_billing_resource_check("workspace_custom") def post(self): - # get file from request - file = request.files["file"] - # check file if "file" not in request.files: raise NoFileUploadedError() @@ -201,6 +198,8 @@ class WebappLogoWorkspaceApi(Resource): if len(request.files) > 1: raise TooManyFilesError() + # get file from request + file = request.files["file"] if not file.filename: raise FilenameNotExistsError diff --git a/api/controllers/console/wraps.py b/api/controllers/console/wraps.py index d862dac373..d3fd1d52e5 100644 --- a/api/controllers/console/wraps.py +++ b/api/controllers/console/wraps.py @@ -1,3 +1,4 @@ +import contextlib import json import os import time @@ -178,7 +179,7 @@ def cloud_edition_billing_rate_limit_check(resource: str): def cloud_utm_record(view): @wraps(view) def decorated(*args, **kwargs): - try: + with contextlib.suppress(Exception): features = FeatureService.get_features(current_user.current_tenant_id) if features.billing.enabled: @@ -187,8 +188,7 @@ def cloud_utm_record(view): if utm_info: utm_info_dict: dict = json.loads(utm_info) OperationService.record_utm(current_user.current_tenant_id, utm_info_dict) - except Exception as e: - pass + return view(*args, **kwargs) return decorated diff --git a/api/controllers/files/__init__.py b/api/controllers/files/__init__.py index d4c3245708..821ad220a2 100644 --- a/api/controllers/files/__init__.py +++ b/api/controllers/files/__init__.py @@ -1,9 +1,20 @@ from flask import Blueprint +from flask_restx import Namespace from libs.external_api import ExternalApi -bp = Blueprint("files", __name__) -api = ExternalApi(bp) +bp = Blueprint("files", __name__, url_prefix="/files") +api = ExternalApi( + bp, + version="1.0", + title="Files API", + description="API for file operations including upload and preview", + doc="/docs", # Enable Swagger UI at /files/docs +) + +files_ns = Namespace("files", description="File operations", path="/") from . import image_preview, tool_files, upload + +api.add_namespace(files_ns) diff --git a/api/controllers/files/error.py b/api/controllers/files/error.py deleted file mode 100644 index a7ce4cd6f7..0000000000 --- a/api/controllers/files/error.py +++ /dev/null @@ -1,7 +0,0 @@ -from libs.exception import BaseHTTPException - - -class UnsupportedFileTypeError(BaseHTTPException): - error_code = "unsupported_file_type" - description = "File type not allowed." - code = 415 diff --git a/api/controllers/files/image_preview.py b/api/controllers/files/image_preview.py index 46c19e1fbb..48baac6556 100644 --- a/api/controllers/files/image_preview.py +++ b/api/controllers/files/image_preview.py @@ -1,16 +1,17 @@ from urllib.parse import quote from flask import Response, request -from flask_restful import Resource, reqparse +from flask_restx import Resource, reqparse from werkzeug.exceptions import NotFound import services -from controllers.files import api -from controllers.files.error import UnsupportedFileTypeError +from controllers.common.errors import UnsupportedFileTypeError +from controllers.files import files_ns from services.account_service import TenantService from services.file_service import FileService +@files_ns.route("//image-preview") class ImagePreviewApi(Resource): """ Deprecated @@ -39,6 +40,7 @@ class ImagePreviewApi(Resource): return Response(generator, mimetype=mimetype) +@files_ns.route("//file-preview") class FilePreviewApi(Resource): def get(self, file_id): file_id = str(file_id) @@ -94,6 +96,7 @@ class FilePreviewApi(Resource): return response +@files_ns.route("/workspaces//webapp-logo") class WorkspaceWebappLogoApi(Resource): def get(self, workspace_id): workspace_id = str(workspace_id) @@ -112,8 +115,3 @@ class WorkspaceWebappLogoApi(Resource): raise UnsupportedFileTypeError() return Response(generator, mimetype=mimetype) - - -api.add_resource(ImagePreviewApi, "/files//image-preview") -api.add_resource(FilePreviewApi, "/files//file-preview") -api.add_resource(WorkspaceWebappLogoApi, "/files/workspaces//webapp-logo") diff --git a/api/controllers/files/tool_files.py b/api/controllers/files/tool_files.py index 1c3430ef4f..faa9b733c2 100644 --- a/api/controllers/files/tool_files.py +++ b/api/controllers/files/tool_files.py @@ -1,17 +1,18 @@ from urllib.parse import quote from flask import Response -from flask_restful import Resource, reqparse +from flask_restx import Resource, reqparse from werkzeug.exceptions import Forbidden, NotFound -from controllers.files import api -from controllers.files.error import UnsupportedFileTypeError +from controllers.common.errors import UnsupportedFileTypeError +from controllers.files import files_ns from core.tools.signature import verify_tool_file_signature from core.tools.tool_file_manager import ToolFileManager from models import db as global_db -class ToolFilePreviewApi(Resource): +@files_ns.route("/tools/.") +class ToolFileApi(Resource): def get(self, file_id, extension): file_id = str(file_id) @@ -52,6 +53,3 @@ class ToolFilePreviewApi(Resource): response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}" return response - - -api.add_resource(ToolFilePreviewApi, "/files/tools/.") diff --git a/api/controllers/files/upload.py b/api/controllers/files/upload.py index 15f93d2774..7a2b3b0428 100644 --- a/api/controllers/files/upload.py +++ b/api/controllers/files/upload.py @@ -1,46 +1,87 @@ from mimetypes import guess_extension +from typing import Optional -from flask import request -from flask_restful import Resource, marshal_with +from flask_restx import Resource, reqparse +from flask_restx.api import HTTPStatus +from werkzeug.datastructures import FileStorage from werkzeug.exceptions import Forbidden import services +from controllers.common.errors import ( + FileTooLargeError, + UnsupportedFileTypeError, +) from controllers.console.wraps import setup_required -from controllers.files import api -from controllers.files.error import UnsupportedFileTypeError +from controllers.files import files_ns from controllers.inner_api.plugin.wraps import get_user -from controllers.service_api.app.error import FileTooLargeError from core.file.helpers import verify_plugin_file_signature from core.tools.tool_file_manager import ToolFileManager -from fields.file_fields import file_fields +from fields.file_fields import build_file_model + +# Define parser for both documentation and validation +upload_parser = reqparse.RequestParser() +upload_parser.add_argument("file", location="files", type=FileStorage, required=True, help="File to upload") +upload_parser.add_argument( + "timestamp", type=str, required=True, location="args", help="Unix timestamp for signature verification" +) +upload_parser.add_argument( + "nonce", type=str, required=True, location="args", help="Random string for signature verification" +) +upload_parser.add_argument( + "sign", type=str, required=True, location="args", help="HMAC signature for request validation" +) +upload_parser.add_argument("tenant_id", type=str, required=True, location="args", help="Tenant identifier") +upload_parser.add_argument("user_id", type=str, required=False, location="args", help="User identifier") +@files_ns.route("/upload/for-plugin") class PluginUploadFileApi(Resource): @setup_required - @marshal_with(file_fields) + @files_ns.expect(upload_parser) + @files_ns.doc("upload_plugin_file") + @files_ns.doc(description="Upload a file for plugin usage with signature verification") + @files_ns.doc( + responses={ + 201: "File uploaded successfully", + 400: "Invalid request parameters", + 403: "Forbidden - Invalid signature or missing parameters", + 413: "File too large", + 415: "Unsupported file type", + } + ) + @files_ns.marshal_with(build_file_model(files_ns), code=HTTPStatus.CREATED) def post(self): - # get file from request - file = request.files["file"] + """Upload a file for plugin usage. - timestamp = request.args.get("timestamp") - nonce = request.args.get("nonce") - sign = request.args.get("sign") - tenant_id = request.args.get("tenant_id") - if not tenant_id: - raise Forbidden("Invalid request.") + Accepts a file upload with signature verification for security. + The file must be accompanied by valid timestamp, nonce, and signature parameters. - user_id = request.args.get("user_id") + Returns: + dict: File metadata including ID, URLs, and properties + int: HTTP status code (201 for success) + + Raises: + Forbidden: Invalid signature or missing required parameters + FileTooLargeError: File exceeds size limit + UnsupportedFileTypeError: File type not supported + """ + # Parse and validate all arguments + args = upload_parser.parse_args() + + file: FileStorage = args["file"] + timestamp: str = args["timestamp"] + nonce: str = args["nonce"] + sign: str = args["sign"] + tenant_id: str = args["tenant_id"] + user_id: Optional[str] = args.get("user_id") user = get_user(tenant_id, user_id) - filename = file.filename - mimetype = file.mimetype + filename: Optional[str] = file.filename + mimetype: Optional[str] = file.mimetype if not filename or not mimetype: raise Forbidden("Invalid request.") - if not timestamp or not nonce or not sign: - raise Forbidden("Invalid request.") - if not verify_plugin_file_signature( filename=filename, mimetype=mimetype, @@ -86,6 +127,3 @@ class PluginUploadFileApi(Resource): raise FileTooLargeError(file_too_large_error.description) except services.errors.file.UnsupportedFileTypeError: raise UnsupportedFileTypeError() - - -api.add_resource(PluginUploadFileApi, "/files/upload/for-plugin") diff --git a/api/controllers/inner_api/mail.py b/api/controllers/inner_api/mail.py index ce3373d65c..80bbc360de 100644 --- a/api/controllers/inner_api/mail.py +++ b/api/controllers/inner_api/mail.py @@ -1,27 +1,38 @@ -from flask_restful import ( - Resource, # type: ignore - reqparse, -) +from flask_restx import Resource, reqparse from controllers.console.wraps import setup_required from controllers.inner_api import api -from controllers.inner_api.wraps import enterprise_inner_api_only -from services.enterprise.mail_service import DifyMail, EnterpriseMailService +from controllers.inner_api.wraps import billing_inner_api_only, enterprise_inner_api_only +from tasks.mail_inner_task import send_inner_email_task + +_mail_parser = reqparse.RequestParser() +_mail_parser.add_argument("to", type=str, action="append", required=True) +_mail_parser.add_argument("subject", type=str, required=True) +_mail_parser.add_argument("body", type=str, required=True) +_mail_parser.add_argument("substitutions", type=dict, required=False) -class EnterpriseMail(Resource): - @setup_required - @enterprise_inner_api_only +class BaseMail(Resource): + """Shared logic for sending an inner email.""" + def post(self): - parser = reqparse.RequestParser() - parser.add_argument("to", type=str, action="append", required=True) - parser.add_argument("subject", type=str, required=True) - parser.add_argument("body", type=str, required=True) - parser.add_argument("substitutions", type=dict, required=False) - args = parser.parse_args() - - EnterpriseMailService.send_mail(DifyMail(**args)) + args = _mail_parser.parse_args() + send_inner_email_task.delay( + to=args["to"], + subject=args["subject"], + body=args["body"], + substitutions=args["substitutions"], + ) return {"message": "success"}, 200 +class EnterpriseMail(BaseMail): + method_decorators = [setup_required, enterprise_inner_api_only] + + +class BillingMail(BaseMail): + method_decorators = [setup_required, billing_inner_api_only] + + api.add_resource(EnterpriseMail, "/enterprise/mail") +api.add_resource(BillingMail, "/billing/mail") diff --git a/api/controllers/inner_api/plugin/plugin.py b/api/controllers/inner_api/plugin/plugin.py index 5dfe41eb6b..9b8d9457f0 100644 --- a/api/controllers/inner_api/plugin/plugin.py +++ b/api/controllers/inner_api/plugin/plugin.py @@ -1,4 +1,4 @@ -from flask_restful import Resource +from flask_restx import Resource from controllers.console.wraps import setup_required from controllers.inner_api import api diff --git a/api/controllers/inner_api/plugin/wraps.py b/api/controllers/inner_api/plugin/wraps.py index b533614d4d..89b4ac7506 100644 --- a/api/controllers/inner_api/plugin/wraps.py +++ b/api/controllers/inner_api/plugin/wraps.py @@ -4,7 +4,7 @@ from typing import Optional from flask import current_app, request from flask_login import user_logged_in -from flask_restful import reqparse +from flask_restx import reqparse from pydantic import BaseModel from sqlalchemy.orm import Session diff --git a/api/controllers/inner_api/workspace/workspace.py b/api/controllers/inner_api/workspace/workspace.py index 77568b75f1..1c26416080 100644 --- a/api/controllers/inner_api/workspace/workspace.py +++ b/api/controllers/inner_api/workspace/workspace.py @@ -1,6 +1,6 @@ import json -from flask_restful import Resource, reqparse +from flask_restx import Resource, reqparse from controllers.console.wraps import setup_required from controllers.inner_api import api diff --git a/api/controllers/inner_api/wraps.py b/api/controllers/inner_api/wraps.py index 9e7b3d4f29..c5aa318f58 100644 --- a/api/controllers/inner_api/wraps.py +++ b/api/controllers/inner_api/wraps.py @@ -10,6 +10,22 @@ from extensions.ext_database import db from models.model import EndUser +def billing_inner_api_only(view): + @wraps(view) + def decorated(*args, **kwargs): + if not dify_config.INNER_API: + abort(404) + + # get header 'X-Inner-Api-Key' + inner_api_key = request.headers.get("X-Inner-Api-Key") + if not inner_api_key or inner_api_key != dify_config.INNER_API_KEY: + abort(401) + + return view(*args, **kwargs) + + return decorated + + def enterprise_inner_api_only(view): @wraps(view) def decorated(*args, **kwargs): diff --git a/api/controllers/mcp/__init__.py b/api/controllers/mcp/__init__.py index 1b3e0a5621..c344ffad08 100644 --- a/api/controllers/mcp/__init__.py +++ b/api/controllers/mcp/__init__.py @@ -1,8 +1,20 @@ from flask import Blueprint +from flask_restx import Namespace from libs.external_api import ExternalApi bp = Blueprint("mcp", __name__, url_prefix="/mcp") -api = ExternalApi(bp) + +api = ExternalApi( + bp, + version="1.0", + title="MCP API", + description="API for Model Context Protocol operations", + doc="/docs", # Enable Swagger UI at /mcp/docs +) + +mcp_ns = Namespace("mcp", description="MCP operations", path="/") from . import mcp + +api.add_namespace(mcp_ns) diff --git a/api/controllers/mcp/mcp.py b/api/controllers/mcp/mcp.py index 87d678796f..fc19749011 100644 --- a/api/controllers/mcp/mcp.py +++ b/api/controllers/mcp/mcp.py @@ -1,8 +1,10 @@ -from flask_restful import Resource, reqparse +from typing import Optional, Union + +from flask_restx import Resource, reqparse from pydantic import ValidationError from controllers.console.app.mcp_server import AppMCPServerStatus -from controllers.mcp import api +from controllers.mcp import mcp_ns from core.app.app_config.entities import VariableEntity from core.mcp import types from core.mcp.server.streamable_http import MCPServerStreamableHTTPRequestHandler @@ -13,22 +15,58 @@ from libs import helper from models.model import App, AppMCPServer, AppMode +def int_or_str(value): + """Validate that a value is either an integer or string.""" + if isinstance(value, (int, str)): + return value + else: + return None + + +# Define parser for both documentation and validation +mcp_request_parser = reqparse.RequestParser() +mcp_request_parser.add_argument( + "jsonrpc", type=str, required=True, location="json", help="JSON-RPC version (should be '2.0')" +) +mcp_request_parser.add_argument("method", type=str, required=True, location="json", help="The method to invoke") +mcp_request_parser.add_argument("params", type=dict, required=False, location="json", help="Parameters for the method") +mcp_request_parser.add_argument( + "id", type=int_or_str, required=False, location="json", help="Request ID for tracking responses" +) + + +@mcp_ns.route("/server//mcp") class MCPAppApi(Resource): - def post(self, server_code): - def int_or_str(value): - if isinstance(value, (int, str)): - return value - else: - return None + @mcp_ns.expect(mcp_request_parser) + @mcp_ns.doc("handle_mcp_request") + @mcp_ns.doc(description="Handle Model Context Protocol (MCP) requests for a specific server") + @mcp_ns.doc(params={"server_code": "Unique identifier for the MCP server"}) + @mcp_ns.doc( + responses={ + 200: "MCP response successfully processed", + 400: "Invalid MCP request or parameters", + 404: "Server or app not found", + } + ) + def post(self, server_code: str): + """Handle MCP requests for a specific server. - parser = reqparse.RequestParser() - parser.add_argument("jsonrpc", type=str, required=True, location="json") - parser.add_argument("method", type=str, required=True, location="json") - parser.add_argument("params", type=dict, required=False, location="json") - parser.add_argument("id", type=int_or_str, required=False, location="json") - args = parser.parse_args() + Processes JSON-RPC formatted requests according to the Model Context Protocol specification. + Validates the server status and associated app before processing the request. - request_id = args.get("id") + Args: + server_code: Unique identifier for the MCP server + + Returns: + dict: JSON-RPC response from the MCP handler + + Raises: + ValidationError: Invalid request format or parameters + """ + # Parse and validate all arguments + args = mcp_request_parser.parse_args() + + request_id: Optional[Union[int, str]] = args.get("id") server = db.session.query(AppMCPServer).where(AppMCPServer.server_code == server_code).first() if not server: @@ -99,6 +137,3 @@ class MCPAppApi(Resource): mcp_server_handler = MCPServerStreamableHTTPRequestHandler(app, request, converted_user_input_form) response = mcp_server_handler.handle() return helper.compact_generate_response(response) - - -api.add_resource(MCPAppApi, "/server//mcp") diff --git a/api/controllers/service_api/__init__.py b/api/controllers/service_api/__init__.py index d964e27819..763345d723 100644 --- a/api/controllers/service_api/__init__.py +++ b/api/controllers/service_api/__init__.py @@ -1,11 +1,23 @@ from flask import Blueprint +from flask_restx import Namespace from libs.external_api import ExternalApi bp = Blueprint("service_api", __name__, url_prefix="/v1") -api = ExternalApi(bp) + +api = ExternalApi( + bp, + version="1.0", + title="Service API", + description="API for application services", + doc="/docs", # Enable Swagger UI at /v1/docs +) + +service_api_ns = Namespace("service_api", description="Service operations", path="/") from . import index -from .app import annotation, app, audio, completion, conversation, file, message, site, workflow +from .app import annotation, app, audio, completion, conversation, file, file_preview, message, site, workflow from .dataset import dataset, document, hit_testing, metadata, segment, upload_file from .workspace import models + +api.add_namespace(service_api_ns) diff --git a/api/controllers/service_api/app/annotation.py b/api/controllers/service_api/app/annotation.py index 595ae118ef..6bc94af8c1 100644 --- a/api/controllers/service_api/app/annotation.py +++ b/api/controllers/service_api/app/annotation.py @@ -1,40 +1,75 @@ +from typing import Literal + from flask import request -from flask_restful import Resource, marshal, marshal_with, reqparse +from flask_restx import Api, Namespace, Resource, fields, reqparse +from flask_restx.api import HTTPStatus from werkzeug.exceptions import Forbidden -from controllers.service_api import api +from controllers.service_api import service_api_ns from controllers.service_api.wraps import validate_app_token from extensions.ext_redis import redis_client -from fields.annotation_fields import ( - annotation_fields, -) +from fields.annotation_fields import annotation_fields, build_annotation_model from libs.login import current_user from models.model import App from services.annotation_service import AppAnnotationService +# Define parsers for annotation API +annotation_create_parser = reqparse.RequestParser() +annotation_create_parser.add_argument("question", required=True, type=str, location="json", help="Annotation question") +annotation_create_parser.add_argument("answer", required=True, type=str, location="json", help="Annotation answer") +annotation_reply_action_parser = reqparse.RequestParser() +annotation_reply_action_parser.add_argument( + "score_threshold", required=True, type=float, location="json", help="Score threshold for annotation matching" +) +annotation_reply_action_parser.add_argument( + "embedding_provider_name", required=True, type=str, location="json", help="Embedding provider name" +) +annotation_reply_action_parser.add_argument( + "embedding_model_name", required=True, type=str, location="json", help="Embedding model name" +) + + +@service_api_ns.route("/apps/annotation-reply/") class AnnotationReplyActionApi(Resource): + @service_api_ns.expect(annotation_reply_action_parser) + @service_api_ns.doc("annotation_reply_action") + @service_api_ns.doc(description="Enable or disable annotation reply feature") + @service_api_ns.doc(params={"action": "Action to perform: 'enable' or 'disable'"}) + @service_api_ns.doc( + responses={ + 200: "Action completed successfully", + 401: "Unauthorized - invalid API token", + } + ) @validate_app_token - def post(self, app_model: App, action): - parser = reqparse.RequestParser() - parser.add_argument("score_threshold", required=True, type=float, location="json") - parser.add_argument("embedding_provider_name", required=True, type=str, location="json") - parser.add_argument("embedding_model_name", required=True, type=str, location="json") - args = parser.parse_args() + def post(self, app_model: App, action: Literal["enable", "disable"]): + """Enable or disable annotation reply feature.""" + args = annotation_reply_action_parser.parse_args() if action == "enable": result = AppAnnotationService.enable_app_annotation(args, app_model.id) elif action == "disable": result = AppAnnotationService.disable_app_annotation(app_model.id) - else: - raise ValueError("Unsupported annotation reply action") return result, 200 +@service_api_ns.route("/apps/annotation-reply//status/") class AnnotationReplyActionStatusApi(Resource): + @service_api_ns.doc("get_annotation_reply_action_status") + @service_api_ns.doc(description="Get the status of an annotation reply action job") + @service_api_ns.doc(params={"action": "Action type", "job_id": "Job ID"}) + @service_api_ns.doc( + responses={ + 200: "Job status retrieved successfully", + 401: "Unauthorized - invalid API token", + 404: "Job not found", + } + ) @validate_app_token def get(self, app_model: App, job_id, action): + """Get the status of an annotation reply action job.""" job_id = str(job_id) - app_annotation_job_key = "{}_app_annotation_job_{}".format(action, str(job_id)) + app_annotation_job_key = f"{action}_app_annotation_job_{str(job_id)}" cache_result = redis_client.get(app_annotation_job_key) if cache_result is None: raise ValueError("The job does not exist.") @@ -42,66 +77,117 @@ class AnnotationReplyActionStatusApi(Resource): job_status = cache_result.decode() error_msg = "" if job_status == "error": - app_annotation_error_key = "{}_app_annotation_error_{}".format(action, str(job_id)) + app_annotation_error_key = f"{action}_app_annotation_error_{str(job_id)}" error_msg = redis_client.get(app_annotation_error_key).decode() return {"job_id": job_id, "job_status": job_status, "error_msg": error_msg}, 200 +# Define annotation list response model +annotation_list_fields = { + "data": fields.List(fields.Nested(annotation_fields)), + "has_more": fields.Boolean, + "limit": fields.Integer, + "total": fields.Integer, + "page": fields.Integer, +} + + +def build_annotation_list_model(api_or_ns: Api | Namespace): + """Build the annotation list model for the API or Namespace.""" + copied_annotation_list_fields = annotation_list_fields.copy() + copied_annotation_list_fields["data"] = fields.List(fields.Nested(build_annotation_model(api_or_ns))) + return api_or_ns.model("AnnotationList", copied_annotation_list_fields) + + +@service_api_ns.route("/apps/annotations") class AnnotationListApi(Resource): + @service_api_ns.doc("list_annotations") + @service_api_ns.doc(description="List annotations for the application") + @service_api_ns.doc( + responses={ + 200: "Annotations retrieved successfully", + 401: "Unauthorized - invalid API token", + } + ) @validate_app_token + @service_api_ns.marshal_with(build_annotation_list_model(service_api_ns)) def get(self, app_model: App): + """List annotations for the application.""" page = request.args.get("page", default=1, type=int) limit = request.args.get("limit", default=20, type=int) keyword = request.args.get("keyword", default="", type=str) annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_model.id, page, limit, keyword) - response = { - "data": marshal(annotation_list, annotation_fields), + return { + "data": annotation_list, "has_more": len(annotation_list) == limit, "limit": limit, "total": total, "page": page, } - return response, 200 + @service_api_ns.expect(annotation_create_parser) + @service_api_ns.doc("create_annotation") + @service_api_ns.doc(description="Create a new annotation") + @service_api_ns.doc( + responses={ + 201: "Annotation created successfully", + 401: "Unauthorized - invalid API token", + } + ) @validate_app_token - @marshal_with(annotation_fields) + @service_api_ns.marshal_with(build_annotation_model(service_api_ns), code=HTTPStatus.CREATED) def post(self, app_model: App): - parser = reqparse.RequestParser() - parser.add_argument("question", required=True, type=str, location="json") - parser.add_argument("answer", required=True, type=str, location="json") - args = parser.parse_args() + """Create a new annotation.""" + args = annotation_create_parser.parse_args() annotation = AppAnnotationService.insert_app_annotation_directly(args, app_model.id) - return annotation + return annotation, 201 +@service_api_ns.route("/apps/annotations/") class AnnotationUpdateDeleteApi(Resource): + @service_api_ns.expect(annotation_create_parser) + @service_api_ns.doc("update_annotation") + @service_api_ns.doc(description="Update an existing annotation") + @service_api_ns.doc(params={"annotation_id": "Annotation ID"}) + @service_api_ns.doc( + responses={ + 200: "Annotation updated successfully", + 401: "Unauthorized - invalid API token", + 403: "Forbidden - insufficient permissions", + 404: "Annotation not found", + } + ) @validate_app_token - @marshal_with(annotation_fields) + @service_api_ns.marshal_with(build_annotation_model(service_api_ns)) def put(self, app_model: App, annotation_id): + """Update an existing annotation.""" if not current_user.is_editor: raise Forbidden() annotation_id = str(annotation_id) - parser = reqparse.RequestParser() - parser.add_argument("question", required=True, type=str, location="json") - parser.add_argument("answer", required=True, type=str, location="json") - args = parser.parse_args() + args = annotation_create_parser.parse_args() annotation = AppAnnotationService.update_app_annotation_directly(args, app_model.id, annotation_id) return annotation + @service_api_ns.doc("delete_annotation") + @service_api_ns.doc(description="Delete an annotation") + @service_api_ns.doc(params={"annotation_id": "Annotation ID"}) + @service_api_ns.doc( + responses={ + 204: "Annotation deleted successfully", + 401: "Unauthorized - invalid API token", + 403: "Forbidden - insufficient permissions", + 404: "Annotation not found", + } + ) @validate_app_token def delete(self, app_model: App, annotation_id): + """Delete an annotation.""" if not current_user.is_editor: raise Forbidden() annotation_id = str(annotation_id) AppAnnotationService.delete_app_annotation(app_model.id, annotation_id) return {"result": "success"}, 204 - - -api.add_resource(AnnotationReplyActionApi, "/apps/annotation-reply/") -api.add_resource(AnnotationReplyActionStatusApi, "/apps/annotation-reply//status/") -api.add_resource(AnnotationListApi, "/apps/annotations") -api.add_resource(AnnotationUpdateDeleteApi, "/apps/annotations/") diff --git a/api/controllers/service_api/app/app.py b/api/controllers/service_api/app/app.py index 89222d5e83..2dbeed1d68 100644 --- a/api/controllers/service_api/app/app.py +++ b/api/controllers/service_api/app/app.py @@ -1,7 +1,7 @@ -from flask_restful import Resource, marshal_with +from flask_restx import Resource -from controllers.common import fields -from controllers.service_api import api +from controllers.common.fields import build_parameters_model +from controllers.service_api import service_api_ns from controllers.service_api.app.error import AppUnavailableError from controllers.service_api.wraps import validate_app_token from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict @@ -9,13 +9,26 @@ from models.model import App, AppMode from services.app_service import AppService +@service_api_ns.route("/parameters") class AppParameterApi(Resource): """Resource for app variables.""" + @service_api_ns.doc("get_app_parameters") + @service_api_ns.doc(description="Retrieve application input parameters and configuration") + @service_api_ns.doc( + responses={ + 200: "Parameters retrieved successfully", + 401: "Unauthorized - invalid API token", + 404: "Application not found", + } + ) @validate_app_token - @marshal_with(fields.parameters_fields) + @service_api_ns.marshal_with(build_parameters_model(service_api_ns)) def get(self, app_model: App): - """Retrieve app parameters.""" + """Retrieve app parameters. + + Returns the input form parameters and configuration for the application. + """ if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}: workflow = app_model.workflow if workflow is None: @@ -35,17 +48,43 @@ class AppParameterApi(Resource): return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form) +@service_api_ns.route("/meta") class AppMetaApi(Resource): + @service_api_ns.doc("get_app_meta") + @service_api_ns.doc(description="Get application metadata") + @service_api_ns.doc( + responses={ + 200: "Metadata retrieved successfully", + 401: "Unauthorized - invalid API token", + 404: "Application not found", + } + ) @validate_app_token def get(self, app_model: App): - """Get app meta""" + """Get app metadata. + + Returns metadata about the application including configuration and settings. + """ return AppService().get_app_meta(app_model) +@service_api_ns.route("/info") class AppInfoApi(Resource): + @service_api_ns.doc("get_app_info") + @service_api_ns.doc(description="Get basic application information") + @service_api_ns.doc( + responses={ + 200: "Application info retrieved successfully", + 401: "Unauthorized - invalid API token", + 404: "Application not found", + } + ) @validate_app_token def get(self, app_model: App): - """Get app information""" + """Get app information. + + Returns basic information about the application including name, description, tags, and mode. + """ tags = [tag.name for tag in app_model.tags] return { "name": app_model.name, @@ -54,8 +93,3 @@ class AppInfoApi(Resource): "mode": app_model.mode, "author_name": app_model.author_name, } - - -api.add_resource(AppParameterApi, "/parameters") -api.add_resource(AppMetaApi, "/meta") -api.add_resource(AppInfoApi, "/info") diff --git a/api/controllers/service_api/app/audio.py b/api/controllers/service_api/app/audio.py index 848863cf1b..61b3020a5f 100644 --- a/api/controllers/service_api/app/audio.py +++ b/api/controllers/service_api/app/audio.py @@ -1,11 +1,11 @@ import logging from flask import request -from flask_restful import Resource, reqparse +from flask_restx import Resource, reqparse from werkzeug.exceptions import InternalServerError import services -from controllers.service_api import api +from controllers.service_api import service_api_ns from controllers.service_api.app.error import ( AppUnavailableError, AudioTooLargeError, @@ -30,9 +30,26 @@ from services.errors.audio import ( ) +@service_api_ns.route("/audio-to-text") class AudioApi(Resource): + @service_api_ns.doc("audio_to_text") + @service_api_ns.doc(description="Convert audio to text using speech-to-text") + @service_api_ns.doc( + responses={ + 200: "Audio successfully transcribed", + 400: "Bad request - no audio or invalid audio", + 401: "Unauthorized - invalid API token", + 413: "Audio file too large", + 415: "Unsupported audio type", + 500: "Internal server error", + } + ) @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM)) def post(self, app_model: App, end_user: EndUser): + """Convert audio to text using speech-to-text. + + Accepts an audio file upload and returns the transcribed text. + """ file = request.files["file"] try: @@ -65,16 +82,35 @@ class AudioApi(Resource): raise InternalServerError() +# Define parser for text-to-audio API +text_to_audio_parser = reqparse.RequestParser() +text_to_audio_parser.add_argument("message_id", type=str, required=False, location="json", help="Message ID") +text_to_audio_parser.add_argument("voice", type=str, location="json", help="Voice to use for TTS") +text_to_audio_parser.add_argument("text", type=str, location="json", help="Text to convert to audio") +text_to_audio_parser.add_argument("streaming", type=bool, location="json", help="Enable streaming response") + + +@service_api_ns.route("/text-to-audio") class TextApi(Resource): + @service_api_ns.expect(text_to_audio_parser) + @service_api_ns.doc("text_to_audio") + @service_api_ns.doc(description="Convert text to audio using text-to-speech") + @service_api_ns.doc( + responses={ + 200: "Text successfully converted to audio", + 400: "Bad request - invalid parameters", + 401: "Unauthorized - invalid API token", + 500: "Internal server error", + } + ) @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON)) def post(self, app_model: App, end_user: EndUser): + """Convert text to audio using text-to-speech. + + Converts the provided text to audio using the specified voice. + """ try: - parser = reqparse.RequestParser() - parser.add_argument("message_id", type=str, required=False, location="json") - parser.add_argument("voice", type=str, location="json") - parser.add_argument("text", type=str, location="json") - parser.add_argument("streaming", type=bool, location="json") - args = parser.parse_args() + args = text_to_audio_parser.parse_args() message_id = args.get("message_id", None) text = args.get("text", None) @@ -108,7 +144,3 @@ class TextApi(Resource): except Exception as e: logging.exception("internal server error.") raise InternalServerError() - - -api.add_resource(AudioApi, "/audio-to-text") -api.add_resource(TextApi, "/text-to-audio") diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py index 7762672494..dddb75d593 100644 --- a/api/controllers/service_api/app/completion.py +++ b/api/controllers/service_api/app/completion.py @@ -1,11 +1,11 @@ import logging from flask import request -from flask_restful import Resource, reqparse -from werkzeug.exceptions import InternalServerError, NotFound +from flask_restx import Resource, reqparse +from werkzeug.exceptions import BadRequest, InternalServerError, NotFound import services -from controllers.service_api import api +from controllers.service_api import service_api_ns from controllers.service_api.app.error import ( AppUnavailableError, CompletionRequestError, @@ -30,23 +30,74 @@ from libs import helper from libs.helper import uuid_value from models.model import App, AppMode, EndUser from services.app_generate_service import AppGenerateService +from services.errors.app import IsDraftWorkflowError, WorkflowIdFormatError, WorkflowNotFoundError from services.errors.llm import InvokeRateLimitError +# Define parser for completion API +completion_parser = reqparse.RequestParser() +completion_parser.add_argument( + "inputs", type=dict, required=True, location="json", help="Input parameters for completion" +) +completion_parser.add_argument("query", type=str, location="json", default="", help="The query string") +completion_parser.add_argument("files", type=list, required=False, location="json", help="List of file attachments") +completion_parser.add_argument( + "response_mode", type=str, choices=["blocking", "streaming"], location="json", help="Response mode" +) +completion_parser.add_argument( + "retriever_from", type=str, required=False, default="dev", location="json", help="Retriever source" +) +# Define parser for chat API +chat_parser = reqparse.RequestParser() +chat_parser.add_argument("inputs", type=dict, required=True, location="json", help="Input parameters for chat") +chat_parser.add_argument("query", type=str, required=True, location="json", help="The chat query") +chat_parser.add_argument("files", type=list, required=False, location="json", help="List of file attachments") +chat_parser.add_argument( + "response_mode", type=str, choices=["blocking", "streaming"], location="json", help="Response mode" +) +chat_parser.add_argument("conversation_id", type=uuid_value, location="json", help="Existing conversation ID") +chat_parser.add_argument( + "retriever_from", type=str, required=False, default="dev", location="json", help="Retriever source" +) +chat_parser.add_argument( + "auto_generate_name", + type=bool, + required=False, + default=True, + location="json", + help="Auto generate conversation name", +) +chat_parser.add_argument("workflow_id", type=str, required=False, location="json", help="Workflow ID for advanced chat") + + +@service_api_ns.route("/completion-messages") class CompletionApi(Resource): + @service_api_ns.expect(completion_parser) + @service_api_ns.doc("create_completion") + @service_api_ns.doc(description="Create a completion for the given prompt") + @service_api_ns.doc( + responses={ + 200: "Completion created successfully", + 400: "Bad request - invalid parameters", + 401: "Unauthorized - invalid API token", + 404: "Conversation not found", + 500: "Internal server error", + } + ) @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) def post(self, app_model: App, end_user: EndUser): + """Create a completion for the given prompt. + + This endpoint generates a completion based on the provided inputs and query. + Supports both blocking and streaming response modes. + """ if app_model.mode != "completion": raise AppUnavailableError() - parser = reqparse.RequestParser() - parser.add_argument("inputs", type=dict, required=True, location="json") - parser.add_argument("query", type=str, location="json", default="") - parser.add_argument("files", type=list, required=False, location="json") - parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") - parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json") - - args = parser.parse_args() + args = completion_parser.parse_args() + external_trace_id = get_external_trace_id(request) + if external_trace_id: + args["external_trace_id"] = external_trace_id streaming = args["response_mode"] == "streaming" @@ -84,9 +135,21 @@ class CompletionApi(Resource): raise InternalServerError() +@service_api_ns.route("/completion-messages//stop") class CompletionStopApi(Resource): + @service_api_ns.doc("stop_completion") + @service_api_ns.doc(description="Stop a running completion task") + @service_api_ns.doc(params={"task_id": "The ID of the task to stop"}) + @service_api_ns.doc( + responses={ + 200: "Task stopped successfully", + 401: "Unauthorized - invalid API token", + 404: "Task not found", + } + ) @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) - def post(self, app_model: App, end_user: EndUser, task_id): + def post(self, app_model: App, end_user: EndUser, task_id: str): + """Stop a running completion task.""" if app_model.mode != "completion": raise AppUnavailableError() @@ -95,23 +158,33 @@ class CompletionStopApi(Resource): return {"result": "success"}, 200 +@service_api_ns.route("/chat-messages") class ChatApi(Resource): + @service_api_ns.expect(chat_parser) + @service_api_ns.doc("create_chat_message") + @service_api_ns.doc(description="Send a message in a chat conversation") + @service_api_ns.doc( + responses={ + 200: "Message sent successfully", + 400: "Bad request - invalid parameters or workflow issues", + 401: "Unauthorized - invalid API token", + 404: "Conversation or workflow not found", + 429: "Rate limit exceeded", + 500: "Internal server error", + } + ) @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) def post(self, app_model: App, end_user: EndUser): + """Send a message in a chat conversation. + + This endpoint handles chat messages for chat, agent chat, and advanced chat applications. + Supports conversation management and both blocking and streaming response modes. + """ app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() - parser = reqparse.RequestParser() - parser.add_argument("inputs", type=dict, required=True, location="json") - parser.add_argument("query", type=str, required=True, location="json") - parser.add_argument("files", type=list, required=False, location="json") - parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") - parser.add_argument("conversation_id", type=uuid_value, location="json") - parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json") - parser.add_argument("auto_generate_name", type=bool, required=False, default=True, location="json") - - args = parser.parse_args() + args = chat_parser.parse_args() external_trace_id = get_external_trace_id(request) if external_trace_id: @@ -125,6 +198,12 @@ class ChatApi(Resource): ) return helper.compact_generate_response(response) + except WorkflowNotFoundError as ex: + raise NotFound(str(ex)) + except IsDraftWorkflowError as ex: + raise BadRequest(str(ex)) + except WorkflowIdFormatError as ex: + raise BadRequest(str(ex)) except services.errors.conversation.ConversationNotExistsError: raise NotFound("Conversation Not Exists.") except services.errors.conversation.ConversationCompletedError: @@ -149,9 +228,21 @@ class ChatApi(Resource): raise InternalServerError() +@service_api_ns.route("/chat-messages//stop") class ChatStopApi(Resource): + @service_api_ns.doc("stop_chat_message") + @service_api_ns.doc(description="Stop a running chat message generation") + @service_api_ns.doc(params={"task_id": "The ID of the task to stop"}) + @service_api_ns.doc( + responses={ + 200: "Task stopped successfully", + 401: "Unauthorized - invalid API token", + 404: "Task not found", + } + ) @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) - def post(self, app_model: App, end_user: EndUser, task_id): + def post(self, app_model: App, end_user: EndUser, task_id: str): + """Stop a running chat message generation.""" app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() @@ -159,9 +250,3 @@ class ChatStopApi(Resource): AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id) return {"result": "success"}, 200 - - -api.add_resource(CompletionApi, "/completion-messages") -api.add_resource(CompletionStopApi, "/completion-messages//stop") -api.add_resource(ChatApi, "/chat-messages") -api.add_resource(ChatStopApi, "/chat-messages//stop") diff --git a/api/controllers/service_api/app/conversation.py b/api/controllers/service_api/app/conversation.py index 36a7905572..4860bf3a79 100644 --- a/api/controllers/service_api/app/conversation.py +++ b/api/controllers/service_api/app/conversation.py @@ -1,47 +1,97 @@ -from flask_restful import Resource, marshal_with, reqparse -from flask_restful.inputs import int_range +from flask_restx import Resource, reqparse +from flask_restx.inputs import int_range from sqlalchemy.orm import Session -from werkzeug.exceptions import NotFound +from werkzeug.exceptions import BadRequest, NotFound import services -from controllers.service_api import api +from controllers.service_api import service_api_ns from controllers.service_api.app.error import NotChatAppError from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token from core.app.entities.app_invoke_entities import InvokeFrom from extensions.ext_database import db from fields.conversation_fields import ( - conversation_delete_fields, - conversation_infinite_scroll_pagination_fields, - simple_conversation_fields, + build_conversation_delete_model, + build_conversation_infinite_scroll_pagination_model, + build_simple_conversation_model, ) from fields.conversation_variable_fields import ( - conversation_variable_infinite_scroll_pagination_fields, + build_conversation_variable_infinite_scroll_pagination_model, + build_conversation_variable_model, ) from libs.helper import uuid_value from models.model import App, AppMode, EndUser from services.conversation_service import ConversationService +# Define parsers for conversation APIs +conversation_list_parser = reqparse.RequestParser() +conversation_list_parser.add_argument( + "last_id", type=uuid_value, location="args", help="Last conversation ID for pagination" +) +conversation_list_parser.add_argument( + "limit", + type=int_range(1, 100), + required=False, + default=20, + location="args", + help="Number of conversations to return", +) +conversation_list_parser.add_argument( + "sort_by", + type=str, + choices=["created_at", "-created_at", "updated_at", "-updated_at"], + required=False, + default="-updated_at", + location="args", + help="Sort order for conversations", +) +conversation_rename_parser = reqparse.RequestParser() +conversation_rename_parser.add_argument("name", type=str, required=False, location="json", help="New conversation name") +conversation_rename_parser.add_argument( + "auto_generate", type=bool, required=False, default=False, location="json", help="Auto-generate conversation name" +) + +conversation_variables_parser = reqparse.RequestParser() +conversation_variables_parser.add_argument( + "last_id", type=uuid_value, location="args", help="Last variable ID for pagination" +) +conversation_variables_parser.add_argument( + "limit", type=int_range(1, 100), required=False, default=20, location="args", help="Number of variables to return" +) + +conversation_variable_update_parser = reqparse.RequestParser() +# using lambda is for passing the already-typed value without modification +# if no lambda, it will be converted to string +# the string cannot be converted using json.loads +conversation_variable_update_parser.add_argument( + "value", required=True, location="json", type=lambda x: x, help="New value for the conversation variable" +) + + +@service_api_ns.route("/conversations") class ConversationApi(Resource): + @service_api_ns.expect(conversation_list_parser) + @service_api_ns.doc("list_conversations") + @service_api_ns.doc(description="List all conversations for the current user") + @service_api_ns.doc( + responses={ + 200: "Conversations retrieved successfully", + 401: "Unauthorized - invalid API token", + 404: "Last conversation not found", + } + ) @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY)) - @marshal_with(conversation_infinite_scroll_pagination_fields) + @service_api_ns.marshal_with(build_conversation_infinite_scroll_pagination_model(service_api_ns)) def get(self, app_model: App, end_user: EndUser): + """List all conversations for the current user. + + Supports pagination using last_id and limit parameters. + """ app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() - parser = reqparse.RequestParser() - parser.add_argument("last_id", type=uuid_value, location="args") - parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") - parser.add_argument( - "sort_by", - type=str, - choices=["created_at", "-created_at", "updated_at", "-updated_at"], - required=False, - default="-updated_at", - location="args", - ) - args = parser.parse_args() + args = conversation_list_parser.parse_args() try: with Session(db.engine) as session: @@ -58,10 +108,22 @@ class ConversationApi(Resource): raise NotFound("Last Conversation Not Exists.") +@service_api_ns.route("/conversations/") class ConversationDetailApi(Resource): + @service_api_ns.doc("delete_conversation") + @service_api_ns.doc(description="Delete a specific conversation") + @service_api_ns.doc(params={"c_id": "Conversation ID"}) + @service_api_ns.doc( + responses={ + 204: "Conversation deleted successfully", + 401: "Unauthorized - invalid API token", + 404: "Conversation not found", + } + ) @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON)) - @marshal_with(conversation_delete_fields) + @service_api_ns.marshal_with(build_conversation_delete_model(service_api_ns), code=204) def delete(self, app_model: App, end_user: EndUser, c_id): + """Delete a specific conversation.""" app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() @@ -75,20 +137,30 @@ class ConversationDetailApi(Resource): return {"result": "success"}, 204 +@service_api_ns.route("/conversations//name") class ConversationRenameApi(Resource): + @service_api_ns.expect(conversation_rename_parser) + @service_api_ns.doc("rename_conversation") + @service_api_ns.doc(description="Rename a conversation or auto-generate a name") + @service_api_ns.doc(params={"c_id": "Conversation ID"}) + @service_api_ns.doc( + responses={ + 200: "Conversation renamed successfully", + 401: "Unauthorized - invalid API token", + 404: "Conversation not found", + } + ) @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON)) - @marshal_with(simple_conversation_fields) + @service_api_ns.marshal_with(build_simple_conversation_model(service_api_ns)) def post(self, app_model: App, end_user: EndUser, c_id): + """Rename a conversation or auto-generate a name.""" app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() conversation_id = str(c_id) - parser = reqparse.RequestParser() - parser.add_argument("name", type=str, required=False, location="json") - parser.add_argument("auto_generate", type=bool, required=False, default=False, location="json") - args = parser.parse_args() + args = conversation_rename_parser.parse_args() try: return ConversationService.rename(app_model, conversation_id, end_user, args["name"], args["auto_generate"]) @@ -96,10 +168,26 @@ class ConversationRenameApi(Resource): raise NotFound("Conversation Not Exists.") +@service_api_ns.route("/conversations//variables") class ConversationVariablesApi(Resource): + @service_api_ns.expect(conversation_variables_parser) + @service_api_ns.doc("list_conversation_variables") + @service_api_ns.doc(description="List all variables for a conversation") + @service_api_ns.doc(params={"c_id": "Conversation ID"}) + @service_api_ns.doc( + responses={ + 200: "Variables retrieved successfully", + 401: "Unauthorized - invalid API token", + 404: "Conversation not found", + } + ) @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY)) - @marshal_with(conversation_variable_infinite_scroll_pagination_fields) + @service_api_ns.marshal_with(build_conversation_variable_infinite_scroll_pagination_model(service_api_ns)) def get(self, app_model: App, end_user: EndUser, c_id): + """List all variables for a conversation. + + Conversational variables are only available for chat applications. + """ # conversational variable only for chat app app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: @@ -107,10 +195,7 @@ class ConversationVariablesApi(Resource): conversation_id = str(c_id) - parser = reqparse.RequestParser() - parser.add_argument("last_id", type=uuid_value, location="args") - parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") - args = parser.parse_args() + args = conversation_variables_parser.parse_args() try: return ConversationService.get_conversational_variable( @@ -120,7 +205,44 @@ class ConversationVariablesApi(Resource): raise NotFound("Conversation Not Exists.") -api.add_resource(ConversationRenameApi, "/conversations//name", endpoint="conversation_name") -api.add_resource(ConversationApi, "/conversations") -api.add_resource(ConversationDetailApi, "/conversations/", endpoint="conversation_detail") -api.add_resource(ConversationVariablesApi, "/conversations//variables", endpoint="conversation_variables") +@service_api_ns.route("/conversations//variables/") +class ConversationVariableDetailApi(Resource): + @service_api_ns.expect(conversation_variable_update_parser) + @service_api_ns.doc("update_conversation_variable") + @service_api_ns.doc(description="Update a conversation variable's value") + @service_api_ns.doc(params={"c_id": "Conversation ID", "variable_id": "Variable ID"}) + @service_api_ns.doc( + responses={ + 200: "Variable updated successfully", + 400: "Bad request - type mismatch", + 401: "Unauthorized - invalid API token", + 404: "Conversation or variable not found", + } + ) + @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON)) + @service_api_ns.marshal_with(build_conversation_variable_model(service_api_ns)) + def put(self, app_model: App, end_user: EndUser, c_id, variable_id): + """Update a conversation variable's value. + + Allows updating the value of a specific conversation variable. + The value must match the variable's expected type. + """ + app_mode = AppMode.value_of(app_model.mode) + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: + raise NotChatAppError() + + conversation_id = str(c_id) + variable_id = str(variable_id) + + args = conversation_variable_update_parser.parse_args() + + try: + return ConversationService.update_conversation_variable( + app_model, conversation_id, variable_id, end_user, args["value"] + ) + except services.errors.conversation.ConversationNotExistsError: + raise NotFound("Conversation Not Exists.") + except services.errors.conversation.ConversationVariableNotExistsError: + raise NotFound("Conversation Variable Not Exists.") + except services.errors.conversation.ConversationVariableTypeMismatchError as e: + raise BadRequest(str(e)) diff --git a/api/controllers/service_api/app/error.py b/api/controllers/service_api/app/error.py index ca91da80c1..0e04a04cb2 100644 --- a/api/controllers/service_api/app/error.py +++ b/api/controllers/service_api/app/error.py @@ -85,25 +85,13 @@ class ProviderNotSupportSpeechToTextError(BaseHTTPException): code = 400 -class NoFileUploadedError(BaseHTTPException): - error_code = "no_file_uploaded" - description = "Please upload your file." - code = 400 +class FileNotFoundError(BaseHTTPException): + error_code = "file_not_found" + description = "The requested file was not found." + code = 404 -class TooManyFilesError(BaseHTTPException): - error_code = "too_many_files" - description = "Only one file is allowed." - code = 400 - - -class FileTooLargeError(BaseHTTPException): - error_code = "file_too_large" - description = "File size exceeded. {message}" - code = 413 - - -class UnsupportedFileTypeError(BaseHTTPException): - error_code = "unsupported_file_type" - description = "File type not allowed." - code = 415 +class FileAccessDeniedError(BaseHTTPException): + error_code = "file_access_denied" + description = "Access to the requested file is denied." + code = 403 diff --git a/api/controllers/service_api/app/file.py b/api/controllers/service_api/app/file.py index b0fd8e65ef..05f27545b3 100644 --- a/api/controllers/service_api/app/file.py +++ b/api/controllers/service_api/app/file.py @@ -1,37 +1,53 @@ from flask import request -from flask_restful import Resource, marshal_with +from flask_restx import Resource +from flask_restx.api import HTTPStatus import services -from controllers.common.errors import FilenameNotExistsError -from controllers.service_api import api -from controllers.service_api.app.error import ( +from controllers.common.errors import ( + FilenameNotExistsError, FileTooLargeError, NoFileUploadedError, TooManyFilesError, UnsupportedFileTypeError, ) +from controllers.service_api import service_api_ns from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token -from fields.file_fields import file_fields +from fields.file_fields import build_file_model from models.model import App, EndUser from services.file_service import FileService +@service_api_ns.route("/files/upload") class FileApi(Resource): + @service_api_ns.doc("upload_file") + @service_api_ns.doc(description="Upload a file for use in conversations") + @service_api_ns.doc( + responses={ + 201: "File uploaded successfully", + 400: "Bad request - no file or invalid file", + 401: "Unauthorized - invalid API token", + 413: "File too large", + 415: "Unsupported file type", + } + ) @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM)) - @marshal_with(file_fields) + @service_api_ns.marshal_with(build_file_model(service_api_ns), code=HTTPStatus.CREATED) def post(self, app_model: App, end_user: EndUser): - file = request.files["file"] + """Upload a file for use in conversations. + Accepts a single file upload via multipart/form-data. + """ # check file if "file" not in request.files: raise NoFileUploadedError() - if not file.mimetype: - raise UnsupportedFileTypeError() - if len(request.files) > 1: raise TooManyFilesError() + file = request.files["file"] + if not file.mimetype: + raise UnsupportedFileTypeError() + if not file.filename: raise FilenameNotExistsError @@ -48,6 +64,3 @@ class FileApi(Resource): raise UnsupportedFileTypeError() return upload_file, 201 - - -api.add_resource(FileApi, "/files/upload") diff --git a/api/controllers/service_api/app/file_preview.py b/api/controllers/service_api/app/file_preview.py new file mode 100644 index 0000000000..84d80ea101 --- /dev/null +++ b/api/controllers/service_api/app/file_preview.py @@ -0,0 +1,187 @@ +import logging +from urllib.parse import quote + +from flask import Response +from flask_restx import Resource, reqparse + +from controllers.service_api import service_api_ns +from controllers.service_api.app.error import ( + FileAccessDeniedError, + FileNotFoundError, +) +from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token +from extensions.ext_database import db +from extensions.ext_storage import storage +from models.model import App, EndUser, Message, MessageFile, UploadFile + +logger = logging.getLogger(__name__) + + +# Define parser for file preview API +file_preview_parser = reqparse.RequestParser() +file_preview_parser.add_argument( + "as_attachment", type=bool, required=False, default=False, location="args", help="Download as attachment" +) + + +@service_api_ns.route("/files//preview") +class FilePreviewApi(Resource): + """ + Service API File Preview endpoint + + Provides secure file preview/download functionality for external API users. + Files can only be accessed if they belong to messages within the requesting app's context. + """ + + @service_api_ns.expect(file_preview_parser) + @service_api_ns.doc("preview_file") + @service_api_ns.doc(description="Preview or download a file uploaded via Service API") + @service_api_ns.doc(params={"file_id": "UUID of the file to preview"}) + @service_api_ns.doc( + responses={ + 200: "File retrieved successfully", + 401: "Unauthorized - invalid API token", + 403: "Forbidden - file access denied", + 404: "File not found", + } + ) + @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY)) + def get(self, app_model: App, end_user: EndUser, file_id: str): + """ + Preview/Download a file that was uploaded via Service API. + + Provides secure file preview/download functionality. + Files can only be accessed if they belong to messages within the requesting app's context. + """ + file_id = str(file_id) + + # Parse query parameters + args = file_preview_parser.parse_args() + + # Validate file ownership and get file objects + message_file, upload_file = self._validate_file_ownership(file_id, app_model.id) + + # Get file content generator + try: + generator = storage.load(upload_file.key, stream=True) + except Exception as e: + raise FileNotFoundError(f"Failed to load file content: {str(e)}") + + # Build response with appropriate headers + response = self._build_file_response(generator, upload_file, args["as_attachment"]) + + return response + + def _validate_file_ownership(self, file_id: str, app_id: str) -> tuple[MessageFile, UploadFile]: + """ + Validate that the file belongs to a message within the requesting app's context + + Security validations performed: + 1. File exists in MessageFile table (was used in a conversation) + 2. Message belongs to the requesting app + 3. UploadFile record exists and is accessible + 4. File tenant matches app tenant (additional security layer) + + Args: + file_id: UUID of the file to validate + app_id: UUID of the requesting app + + Returns: + Tuple of (MessageFile, UploadFile) if validation passes + + Raises: + FileNotFoundError: File or related records not found + FileAccessDeniedError: File does not belong to the app's context + """ + try: + # Input validation + if not file_id or not app_id: + raise FileAccessDeniedError("Invalid file or app identifier") + + # First, find the MessageFile that references this upload file + message_file = db.session.query(MessageFile).where(MessageFile.upload_file_id == file_id).first() + + if not message_file: + raise FileNotFoundError("File not found in message context") + + # Get the message and verify it belongs to the requesting app + message = ( + db.session.query(Message).where(Message.id == message_file.message_id, Message.app_id == app_id).first() + ) + + if not message: + raise FileAccessDeniedError("File access denied: not owned by requesting app") + + # Get the actual upload file record + upload_file = db.session.query(UploadFile).where(UploadFile.id == file_id).first() + + if not upload_file: + raise FileNotFoundError("Upload file record not found") + + # Additional security: verify tenant isolation + app = db.session.query(App).where(App.id == app_id).first() + if app and upload_file.tenant_id != app.tenant_id: + raise FileAccessDeniedError("File access denied: tenant mismatch") + + return message_file, upload_file + + except (FileNotFoundError, FileAccessDeniedError): + # Re-raise our custom exceptions + raise + except Exception as e: + # Log unexpected errors for debugging + logger.exception( + "Unexpected error during file ownership validation", + extra={"file_id": file_id, "app_id": app_id, "error": str(e)}, + ) + raise FileAccessDeniedError("File access validation failed") + + def _build_file_response(self, generator, upload_file: UploadFile, as_attachment: bool = False) -> Response: + """ + Build Flask Response object with appropriate headers for file streaming + + Args: + generator: File content generator from storage + upload_file: UploadFile database record + as_attachment: Whether to set Content-Disposition as attachment + + Returns: + Flask Response object with streaming file content + """ + response = Response( + generator, + mimetype=upload_file.mime_type, + direct_passthrough=True, + headers={}, + ) + + # Add Content-Length if known + if upload_file.size and upload_file.size > 0: + response.headers["Content-Length"] = str(upload_file.size) + + # Add Accept-Ranges header for audio/video files to support seeking + if upload_file.mime_type in [ + "audio/mpeg", + "audio/wav", + "audio/mp4", + "audio/ogg", + "audio/flac", + "audio/aac", + "video/mp4", + "video/webm", + "video/quicktime", + "audio/x-m4a", + ]: + response.headers["Accept-Ranges"] = "bytes" + + # Set Content-Disposition for downloads + if as_attachment and upload_file.name: + encoded_filename = quote(upload_file.name) + response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}" + # Override content-type for downloads to force download + response.headers["Content-Type"] = "application/octet-stream" + + # Add caching headers for performance + response.headers["Cache-Control"] = "public, max-age=3600" # Cache for 1 hour + + return response diff --git a/api/controllers/service_api/app/message.py b/api/controllers/service_api/app/message.py index d90fa2081f..ad3fac7009 100644 --- a/api/controllers/service_api/app/message.py +++ b/api/controllers/service_api/app/message.py @@ -1,25 +1,58 @@ import json import logging -from flask_restful import Resource, fields, marshal_with, reqparse -from flask_restful.inputs import int_range +from flask_restx import Api, Namespace, Resource, fields, reqparse +from flask_restx.inputs import int_range from werkzeug.exceptions import BadRequest, InternalServerError, NotFound import services -from controllers.service_api import api +from controllers.service_api import service_api_ns from controllers.service_api.app.error import NotChatAppError from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token from core.app.entities.app_invoke_entities import InvokeFrom -from fields.conversation_fields import message_file_fields -from fields.message_fields import agent_thought_fields, feedback_fields +from fields.conversation_fields import build_message_file_model +from fields.message_fields import build_agent_thought_model, build_feedback_model from fields.raws import FilesContainedField from libs.helper import TimestampField, uuid_value from models.model import App, AppMode, EndUser -from services.errors.message import SuggestedQuestionsAfterAnswerDisabledError +from services.errors.message import ( + FirstMessageNotExistsError, + MessageNotExistsError, + SuggestedQuestionsAfterAnswerDisabledError, +) from services.message_service import MessageService +# Define parsers for message APIs +message_list_parser = reqparse.RequestParser() +message_list_parser.add_argument( + "conversation_id", required=True, type=uuid_value, location="args", help="Conversation ID" +) +message_list_parser.add_argument("first_id", type=uuid_value, location="args", help="First message ID for pagination") +message_list_parser.add_argument( + "limit", type=int_range(1, 100), required=False, default=20, location="args", help="Number of messages to return" +) -class MessageListApi(Resource): +message_feedback_parser = reqparse.RequestParser() +message_feedback_parser.add_argument( + "rating", type=str, choices=["like", "dislike", None], location="json", help="Feedback rating" +) +message_feedback_parser.add_argument("content", type=str, location="json", help="Feedback content") + +feedback_list_parser = reqparse.RequestParser() +feedback_list_parser.add_argument("page", type=int, default=1, location="args", help="Page number") +feedback_list_parser.add_argument( + "limit", type=int_range(1, 101), required=False, default=20, location="args", help="Number of feedbacks per page" +) + + +def build_message_model(api_or_ns: Api | Namespace): + """Build the message model for the API or Namespace.""" + # First build the nested models + feedback_model = build_feedback_model(api_or_ns) + agent_thought_model = build_agent_thought_model(api_or_ns) + message_file_model = build_message_file_model(api_or_ns) + + # Then build the message fields with nested models message_fields = { "id": fields.String, "conversation_id": fields.String, @@ -27,37 +60,58 @@ class MessageListApi(Resource): "inputs": FilesContainedField, "query": fields.String, "answer": fields.String(attribute="re_sign_file_url_answer"), - "message_files": fields.List(fields.Nested(message_file_fields)), - "feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True), + "message_files": fields.List(fields.Nested(message_file_model)), + "feedback": fields.Nested(feedback_model, attribute="user_feedback", allow_null=True), "retriever_resources": fields.Raw( attribute=lambda obj: json.loads(obj.message_metadata).get("retriever_resources", []) if obj.message_metadata else [] ), "created_at": TimestampField, - "agent_thoughts": fields.List(fields.Nested(agent_thought_fields)), + "agent_thoughts": fields.List(fields.Nested(agent_thought_model)), "status": fields.String, "error": fields.String, } + return api_or_ns.model("Message", message_fields) + + +def build_message_infinite_scroll_pagination_model(api_or_ns: Api | Namespace): + """Build the message infinite scroll pagination model for the API or Namespace.""" + # Build the nested message model first + message_model = build_message_model(api_or_ns) message_infinite_scroll_pagination_fields = { "limit": fields.Integer, "has_more": fields.Boolean, - "data": fields.List(fields.Nested(message_fields)), + "data": fields.List(fields.Nested(message_model)), } + return api_or_ns.model("MessageInfiniteScrollPagination", message_infinite_scroll_pagination_fields) + +@service_api_ns.route("/messages") +class MessageListApi(Resource): + @service_api_ns.expect(message_list_parser) + @service_api_ns.doc("list_messages") + @service_api_ns.doc(description="List messages in a conversation") + @service_api_ns.doc( + responses={ + 200: "Messages retrieved successfully", + 401: "Unauthorized - invalid API token", + 404: "Conversation or first message not found", + } + ) @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY)) - @marshal_with(message_infinite_scroll_pagination_fields) + @service_api_ns.marshal_with(build_message_infinite_scroll_pagination_model(service_api_ns)) def get(self, app_model: App, end_user: EndUser): + """List messages in a conversation. + + Retrieves messages with pagination support using first_id. + """ app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() - parser = reqparse.RequestParser() - parser.add_argument("conversation_id", required=True, type=uuid_value, location="args") - parser.add_argument("first_id", type=uuid_value, location="args") - parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") - args = parser.parse_args() + args = message_list_parser.parse_args() try: return MessageService.pagination_by_first_id( @@ -65,19 +119,32 @@ class MessageListApi(Resource): ) except services.errors.conversation.ConversationNotExistsError: raise NotFound("Conversation Not Exists.") - except services.errors.message.FirstMessageNotExistsError: + except FirstMessageNotExistsError: raise NotFound("First Message Not Exists.") +@service_api_ns.route("/messages//feedbacks") class MessageFeedbackApi(Resource): + @service_api_ns.expect(message_feedback_parser) + @service_api_ns.doc("create_message_feedback") + @service_api_ns.doc(description="Submit feedback for a message") + @service_api_ns.doc(params={"message_id": "Message ID"}) + @service_api_ns.doc( + responses={ + 200: "Feedback submitted successfully", + 401: "Unauthorized - invalid API token", + 404: "Message not found", + } + ) @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) def post(self, app_model: App, end_user: EndUser, message_id): + """Submit feedback for a message. + + Allows users to rate messages as like/dislike and provide optional feedback content. + """ message_id = str(message_id) - parser = reqparse.RequestParser() - parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json") - parser.add_argument("content", type=str, location="json") - args = parser.parse_args() + args = message_feedback_parser.parse_args() try: MessageService.create_feedback( @@ -87,27 +154,54 @@ class MessageFeedbackApi(Resource): rating=args.get("rating"), content=args.get("content"), ) - except services.errors.message.MessageNotExistsError: + except MessageNotExistsError: raise NotFound("Message Not Exists.") return {"result": "success"} +@service_api_ns.route("/app/feedbacks") class AppGetFeedbacksApi(Resource): + @service_api_ns.expect(feedback_list_parser) + @service_api_ns.doc("get_app_feedbacks") + @service_api_ns.doc(description="Get all feedbacks for the application") + @service_api_ns.doc( + responses={ + 200: "Feedbacks retrieved successfully", + 401: "Unauthorized - invalid API token", + } + ) @validate_app_token def get(self, app_model: App): - """Get All Feedbacks of an app""" - parser = reqparse.RequestParser() - parser.add_argument("page", type=int, default=1, location="args") - parser.add_argument("limit", type=int_range(1, 101), required=False, default=20, location="args") - args = parser.parse_args() + """Get all feedbacks for the application. + + Returns paginated list of all feedback submitted for messages in this app. + """ + args = feedback_list_parser.parse_args() feedbacks = MessageService.get_all_messages_feedbacks(app_model, page=args["page"], limit=args["limit"]) return {"data": feedbacks} +@service_api_ns.route("/messages//suggested") class MessageSuggestedApi(Resource): + @service_api_ns.doc("get_suggested_questions") + @service_api_ns.doc(description="Get suggested follow-up questions for a message") + @service_api_ns.doc(params={"message_id": "Message ID"}) + @service_api_ns.doc( + responses={ + 200: "Suggested questions retrieved successfully", + 400: "Suggested questions feature is disabled", + 401: "Unauthorized - invalid API token", + 404: "Message not found", + 500: "Internal server error", + } + ) @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY, required=True)) def get(self, app_model: App, end_user: EndUser, message_id): + """Get suggested follow-up questions for a message. + + Returns AI-generated follow-up questions based on the message content. + """ message_id = str(message_id) app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: @@ -117,7 +211,7 @@ class MessageSuggestedApi(Resource): questions = MessageService.get_suggested_questions_after_answer( app_model=app_model, user=end_user, message_id=message_id, invoke_from=InvokeFrom.SERVICE_API ) - except services.errors.message.MessageNotExistsError: + except MessageNotExistsError: raise NotFound("Message Not Exists.") except SuggestedQuestionsAfterAnswerDisabledError: raise BadRequest("Suggested Questions Is Disabled.") @@ -126,9 +220,3 @@ class MessageSuggestedApi(Resource): raise InternalServerError() return {"result": "success", "data": questions} - - -api.add_resource(MessageListApi, "/messages") -api.add_resource(MessageFeedbackApi, "/messages//feedbacks") -api.add_resource(MessageSuggestedApi, "/messages//suggested") -api.add_resource(AppGetFeedbacksApi, "/app/feedbacks") diff --git a/api/controllers/service_api/app/site.py b/api/controllers/service_api/app/site.py index c157b39f6b..9f8324a84e 100644 --- a/api/controllers/service_api/app/site.py +++ b/api/controllers/service_api/app/site.py @@ -1,30 +1,41 @@ -from flask_restful import Resource, marshal_with +from flask_restx import Resource from werkzeug.exceptions import Forbidden -from controllers.common import fields -from controllers.service_api import api +from controllers.common.fields import build_site_model +from controllers.service_api import service_api_ns from controllers.service_api.wraps import validate_app_token from extensions.ext_database import db from models.account import TenantStatus from models.model import App, Site +@service_api_ns.route("/site") class AppSiteApi(Resource): """Resource for app sites.""" + @service_api_ns.doc("get_app_site") + @service_api_ns.doc(description="Get application site configuration") + @service_api_ns.doc( + responses={ + 200: "Site configuration retrieved successfully", + 401: "Unauthorized - invalid API token", + 403: "Forbidden - site not found or tenant archived", + } + ) @validate_app_token - @marshal_with(fields.site_fields) + @service_api_ns.marshal_with(build_site_model(service_api_ns)) def get(self, app_model: App): - """Retrieve app site info.""" + """Retrieve app site info. + + Returns the site configuration for the application including theme, icons, and text. + """ site = db.session.query(Site).where(Site.app_id == app_model.id).first() if not site: raise Forbidden() + assert app_model.tenant if app_model.tenant.status == TenantStatus.ARCHIVE: raise Forbidden() return site - - -api.add_resource(AppSiteApi, "/site") diff --git a/api/controllers/service_api/app/workflow.py b/api/controllers/service_api/app/workflow.py index 370ff911b4..19e2e67d7f 100644 --- a/api/controllers/service_api/app/workflow.py +++ b/api/controllers/service_api/app/workflow.py @@ -2,12 +2,12 @@ import logging from dateutil.parser import isoparse from flask import request -from flask_restful import Resource, fields, marshal_with, reqparse -from flask_restful.inputs import int_range +from flask_restx import Api, Namespace, Resource, fields, reqparse +from flask_restx.inputs import int_range from sqlalchemy.orm import Session, sessionmaker -from werkzeug.exceptions import InternalServerError +from werkzeug.exceptions import BadRequest, InternalServerError, NotFound -from controllers.service_api import api +from controllers.service_api import service_api_ns from controllers.service_api.app.error import ( CompletionRequestError, NotWorkflowAppError, @@ -28,17 +28,46 @@ from core.helper.trace_id_helper import get_external_trace_id from core.model_runtime.errors.invoke import InvokeError from core.workflow.entities.workflow_execution import WorkflowExecutionStatus from extensions.ext_database import db -from fields.workflow_app_log_fields import workflow_app_log_pagination_fields +from fields.workflow_app_log_fields import build_workflow_app_log_pagination_model from libs import helper from libs.helper import TimestampField from models.model import App, AppMode, EndUser from repositories.factory import DifyAPIRepositoryFactory from services.app_generate_service import AppGenerateService +from services.errors.app import IsDraftWorkflowError, WorkflowIdFormatError, WorkflowNotFoundError from services.errors.llm import InvokeRateLimitError from services.workflow_app_service import WorkflowAppService logger = logging.getLogger(__name__) +# Define parsers for workflow APIs +workflow_run_parser = reqparse.RequestParser() +workflow_run_parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") +workflow_run_parser.add_argument("files", type=list, required=False, location="json") +workflow_run_parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") + +workflow_log_parser = reqparse.RequestParser() +workflow_log_parser.add_argument("keyword", type=str, location="args") +workflow_log_parser.add_argument("status", type=str, choices=["succeeded", "failed", "stopped"], location="args") +workflow_log_parser.add_argument("created_at__before", type=str, location="args") +workflow_log_parser.add_argument("created_at__after", type=str, location="args") +workflow_log_parser.add_argument( + "created_by_end_user_session_id", + type=str, + location="args", + required=False, + default=None, +) +workflow_log_parser.add_argument( + "created_by_account", + type=str, + location="args", + required=False, + default=None, +) +workflow_log_parser.add_argument("page", type=int_range(1, 99999), default=1, location="args") +workflow_log_parser.add_argument("limit", type=int_range(1, 100), default=20, location="args") + workflow_run_fields = { "id": fields.String, "workflow_id": fields.String, @@ -54,12 +83,29 @@ workflow_run_fields = { } +def build_workflow_run_model(api_or_ns: Api | Namespace): + """Build the workflow run model for the API or Namespace.""" + return api_or_ns.model("WorkflowRun", workflow_run_fields) + + +@service_api_ns.route("/workflows/run/") class WorkflowRunDetailApi(Resource): + @service_api_ns.doc("get_workflow_run_detail") + @service_api_ns.doc(description="Get workflow run details") + @service_api_ns.doc(params={"workflow_run_id": "Workflow run ID"}) + @service_api_ns.doc( + responses={ + 200: "Workflow run details retrieved successfully", + 401: "Unauthorized - invalid API token", + 404: "Workflow run not found", + } + ) @validate_app_token - @marshal_with(workflow_run_fields) + @service_api_ns.marshal_with(build_workflow_run_model(service_api_ns)) def get(self, app_model: App, workflow_run_id: str): - """ - Get a workflow task running detail + """Get a workflow task running detail. + + Returns detailed information about a specific workflow run. """ app_mode = AppMode.value_of(app_model.mode) if app_mode not in [AppMode.WORKFLOW, AppMode.ADVANCED_CHAT]: @@ -77,21 +123,33 @@ class WorkflowRunDetailApi(Resource): return workflow_run +@service_api_ns.route("/workflows/run") class WorkflowRunApi(Resource): + @service_api_ns.expect(workflow_run_parser) + @service_api_ns.doc("run_workflow") + @service_api_ns.doc(description="Execute a workflow") + @service_api_ns.doc( + responses={ + 200: "Workflow executed successfully", + 400: "Bad request - invalid parameters or workflow issues", + 401: "Unauthorized - invalid API token", + 404: "Workflow not found", + 429: "Rate limit exceeded", + 500: "Internal server error", + } + ) @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) def post(self, app_model: App, end_user: EndUser): - """ - Run workflow + """Execute a workflow. + + Runs a workflow with the provided inputs and returns the results. + Supports both blocking and streaming response modes. """ app_mode = AppMode.value_of(app_model.mode) if app_mode != AppMode.WORKFLOW: raise NotWorkflowAppError() - parser = reqparse.RequestParser() - parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") - parser.add_argument("files", type=list, required=False, location="json") - parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") - args = parser.parse_args() + args = workflow_run_parser.parse_args() external_trace_id = get_external_trace_id(request) if external_trace_id: args["external_trace_id"] = external_trace_id @@ -120,12 +178,86 @@ class WorkflowRunApi(Resource): raise InternalServerError() +@service_api_ns.route("/workflows//run") +class WorkflowRunByIdApi(Resource): + @service_api_ns.expect(workflow_run_parser) + @service_api_ns.doc("run_workflow_by_id") + @service_api_ns.doc(description="Execute a specific workflow by ID") + @service_api_ns.doc(params={"workflow_id": "Workflow ID to execute"}) + @service_api_ns.doc( + responses={ + 200: "Workflow executed successfully", + 400: "Bad request - invalid parameters or workflow issues", + 401: "Unauthorized - invalid API token", + 404: "Workflow not found", + 429: "Rate limit exceeded", + 500: "Internal server error", + } + ) + @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) + def post(self, app_model: App, end_user: EndUser, workflow_id: str): + """Run specific workflow by ID. + + Executes a specific workflow version identified by its ID. + """ + app_mode = AppMode.value_of(app_model.mode) + if app_mode != AppMode.WORKFLOW: + raise NotWorkflowAppError() + + args = workflow_run_parser.parse_args() + + # Add workflow_id to args for AppGenerateService + args["workflow_id"] = workflow_id + + external_trace_id = get_external_trace_id(request) + if external_trace_id: + args["external_trace_id"] = external_trace_id + streaming = args.get("response_mode") == "streaming" + + try: + response = AppGenerateService.generate( + app_model=app_model, user=end_user, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=streaming + ) + + return helper.compact_generate_response(response) + except WorkflowNotFoundError as ex: + raise NotFound(str(ex)) + except IsDraftWorkflowError as ex: + raise BadRequest(str(ex)) + except WorkflowIdFormatError as ex: + raise BadRequest(str(ex)) + except ProviderTokenNotInitError as ex: + raise ProviderNotInitializeError(ex.description) + except QuotaExceededError: + raise ProviderQuotaExceededError() + except ModelCurrentlyNotSupportError: + raise ProviderModelCurrentlyNotSupportError() + except InvokeRateLimitError as ex: + raise InvokeRateLimitHttpError(ex.description) + except InvokeError as e: + raise CompletionRequestError(e.description) + except ValueError as e: + raise e + except Exception: + logging.exception("internal server error.") + raise InternalServerError() + + +@service_api_ns.route("/workflows/tasks//stop") class WorkflowTaskStopApi(Resource): + @service_api_ns.doc("stop_workflow_task") + @service_api_ns.doc(description="Stop a running workflow task") + @service_api_ns.doc(params={"task_id": "Task ID to stop"}) + @service_api_ns.doc( + responses={ + 200: "Task stopped successfully", + 401: "Unauthorized - invalid API token", + 404: "Task not found", + } + ) @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) def post(self, app_model: App, end_user: EndUser, task_id: str): - """ - Stop workflow task - """ + """Stop a running workflow task.""" app_mode = AppMode.value_of(app_model.mode) if app_mode != AppMode.WORKFLOW: raise NotWorkflowAppError() @@ -135,35 +267,25 @@ class WorkflowTaskStopApi(Resource): return {"result": "success"} +@service_api_ns.route("/workflows/logs") class WorkflowAppLogApi(Resource): + @service_api_ns.expect(workflow_log_parser) + @service_api_ns.doc("get_workflow_logs") + @service_api_ns.doc(description="Get workflow execution logs") + @service_api_ns.doc( + responses={ + 200: "Logs retrieved successfully", + 401: "Unauthorized - invalid API token", + } + ) @validate_app_token - @marshal_with(workflow_app_log_pagination_fields) + @service_api_ns.marshal_with(build_workflow_app_log_pagination_model(service_api_ns)) def get(self, app_model: App): + """Get workflow app logs. + + Returns paginated workflow execution logs with filtering options. """ - Get workflow app logs - """ - parser = reqparse.RequestParser() - parser.add_argument("keyword", type=str, location="args") - parser.add_argument("status", type=str, choices=["succeeded", "failed", "stopped"], location="args") - parser.add_argument("created_at__before", type=str, location="args") - parser.add_argument("created_at__after", type=str, location="args") - parser.add_argument( - "created_by_end_user_session_id", - type=str, - location="args", - required=False, - default=None, - ) - parser.add_argument( - "created_by_account", - type=str, - location="args", - required=False, - default=None, - ) - parser.add_argument("page", type=int_range(1, 99999), default=1, location="args") - parser.add_argument("limit", type=int_range(1, 100), default=20, location="args") - args = parser.parse_args() + args = workflow_log_parser.parse_args() args.status = WorkflowExecutionStatus(args.status) if args.status else None if args.created_at__before: @@ -189,9 +311,3 @@ class WorkflowAppLogApi(Resource): ) return workflow_app_log_pagination - - -api.add_resource(WorkflowRunApi, "/workflows/run") -api.add_resource(WorkflowRunDetailApi, "/workflows/run/") -api.add_resource(WorkflowTaskStopApi, "/workflows/tasks//stop") -api.add_resource(WorkflowAppLogApi, "/workflows/logs") diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index a499719fc3..c486b0480b 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -1,9 +1,11 @@ +from typing import Literal + from flask import request -from flask_restful import marshal, marshal_with, reqparse +from flask_restx import marshal, reqparse from werkzeug.exceptions import Forbidden, NotFound import services.dataset_service -from controllers.service_api import api +from controllers.service_api import service_api_ns from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError, InvalidActionError from controllers.service_api.wraps import ( DatasetApiResource, @@ -14,7 +16,7 @@ from core.model_runtime.entities.model_entities import ModelType from core.plugin.entities.plugin import ModelProviderID from core.provider_manager import ProviderManager from fields.dataset_fields import dataset_detail_fields -from fields.tag_fields import tag_fields +from fields.tag_fields import build_dataset_tag_fields from libs.login import current_user from models.dataset import Dataset, DatasetPermissionEnum from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService @@ -29,17 +31,176 @@ def _validate_name(name): def _validate_description_length(description): - if len(description) > 400: + if description and len(description) > 400: raise ValueError("Description cannot exceed 400 characters.") return description +# Define parsers for dataset operations +dataset_create_parser = reqparse.RequestParser() +dataset_create_parser.add_argument( + "name", + nullable=False, + required=True, + help="type is required. Name must be between 1 to 40 characters.", + type=_validate_name, +) +dataset_create_parser.add_argument( + "description", + type=_validate_description_length, + nullable=True, + required=False, + default="", +) +dataset_create_parser.add_argument( + "indexing_technique", + type=str, + location="json", + choices=Dataset.INDEXING_TECHNIQUE_LIST, + help="Invalid indexing technique.", +) +dataset_create_parser.add_argument( + "permission", + type=str, + location="json", + choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM), + help="Invalid permission.", + required=False, + nullable=False, +) +dataset_create_parser.add_argument( + "external_knowledge_api_id", + type=str, + nullable=True, + required=False, + default="_validate_name", +) +dataset_create_parser.add_argument( + "provider", + type=str, + nullable=True, + required=False, + default="vendor", +) +dataset_create_parser.add_argument( + "external_knowledge_id", + type=str, + nullable=True, + required=False, +) +dataset_create_parser.add_argument("retrieval_model", type=dict, required=False, nullable=True, location="json") +dataset_create_parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json") +dataset_create_parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json") + +dataset_update_parser = reqparse.RequestParser() +dataset_update_parser.add_argument( + "name", + nullable=False, + help="type is required. Name must be between 1 to 40 characters.", + type=_validate_name, +) +dataset_update_parser.add_argument( + "description", location="json", store_missing=False, type=_validate_description_length +) +dataset_update_parser.add_argument( + "indexing_technique", + type=str, + location="json", + choices=Dataset.INDEXING_TECHNIQUE_LIST, + nullable=True, + help="Invalid indexing technique.", +) +dataset_update_parser.add_argument( + "permission", + type=str, + location="json", + choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM), + help="Invalid permission.", +) +dataset_update_parser.add_argument("embedding_model", type=str, location="json", help="Invalid embedding model.") +dataset_update_parser.add_argument( + "embedding_model_provider", type=str, location="json", help="Invalid embedding model provider." +) +dataset_update_parser.add_argument("retrieval_model", type=dict, location="json", help="Invalid retrieval model.") +dataset_update_parser.add_argument("partial_member_list", type=list, location="json", help="Invalid parent user list.") +dataset_update_parser.add_argument( + "external_retrieval_model", + type=dict, + required=False, + nullable=True, + location="json", + help="Invalid external retrieval model.", +) +dataset_update_parser.add_argument( + "external_knowledge_id", + type=str, + required=False, + nullable=True, + location="json", + help="Invalid external knowledge id.", +) +dataset_update_parser.add_argument( + "external_knowledge_api_id", + type=str, + required=False, + nullable=True, + location="json", + help="Invalid external knowledge api id.", +) + +tag_create_parser = reqparse.RequestParser() +tag_create_parser.add_argument( + "name", + nullable=False, + required=True, + help="Name must be between 1 to 50 characters.", + type=lambda x: x + if x and 1 <= len(x) <= 50 + else (_ for _ in ()).throw(ValueError("Name must be between 1 to 50 characters.")), +) + +tag_update_parser = reqparse.RequestParser() +tag_update_parser.add_argument( + "name", + nullable=False, + required=True, + help="Name must be between 1 to 50 characters.", + type=lambda x: x + if x and 1 <= len(x) <= 50 + else (_ for _ in ()).throw(ValueError("Name must be between 1 to 50 characters.")), +) +tag_update_parser.add_argument("tag_id", nullable=False, required=True, help="Id of a tag.", type=str) + +tag_delete_parser = reqparse.RequestParser() +tag_delete_parser.add_argument("tag_id", nullable=False, required=True, help="Id of a tag.", type=str) + +tag_binding_parser = reqparse.RequestParser() +tag_binding_parser.add_argument( + "tag_ids", type=list, nullable=False, required=True, location="json", help="Tag IDs is required." +) +tag_binding_parser.add_argument( + "target_id", type=str, nullable=False, required=True, location="json", help="Target Dataset ID is required." +) + +tag_unbinding_parser = reqparse.RequestParser() +tag_unbinding_parser.add_argument("tag_id", type=str, nullable=False, required=True, help="Tag ID is required.") +tag_unbinding_parser.add_argument("target_id", type=str, nullable=False, required=True, help="Target ID is required.") + + +@service_api_ns.route("/datasets") class DatasetListApi(DatasetApiResource): """Resource for datasets.""" + @service_api_ns.doc("list_datasets") + @service_api_ns.doc(description="List all datasets") + @service_api_ns.doc( + responses={ + 200: "Datasets retrieved successfully", + 401: "Unauthorized - invalid API token", + } + ) def get(self, tenant_id): """Resource for getting datasets.""" - page = request.args.get("page", default=1, type=int) limit = request.args.get("limit", default=20, type=int) # provider = request.args.get("provider", default="vendor") @@ -74,65 +235,20 @@ class DatasetListApi(DatasetApiResource): response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page} return response, 200 + @service_api_ns.expect(dataset_create_parser) + @service_api_ns.doc("create_dataset") + @service_api_ns.doc(description="Create a new dataset") + @service_api_ns.doc( + responses={ + 200: "Dataset created successfully", + 401: "Unauthorized - invalid API token", + 400: "Bad request - invalid parameters", + } + ) @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def post(self, tenant_id): """Resource for creating datasets.""" - parser = reqparse.RequestParser() - parser.add_argument( - "name", - nullable=False, - required=True, - help="type is required. Name must be between 1 to 40 characters.", - type=_validate_name, - ) - parser.add_argument( - "description", - type=str, - nullable=True, - required=False, - default="", - ) - parser.add_argument( - "indexing_technique", - type=str, - location="json", - choices=Dataset.INDEXING_TECHNIQUE_LIST, - help="Invalid indexing technique.", - ) - parser.add_argument( - "permission", - type=str, - location="json", - choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM), - help="Invalid permission.", - required=False, - nullable=False, - ) - parser.add_argument( - "external_knowledge_api_id", - type=str, - nullable=True, - required=False, - default="_validate_name", - ) - parser.add_argument( - "provider", - type=str, - nullable=True, - required=False, - default="vendor", - ) - parser.add_argument( - "external_knowledge_id", - type=str, - nullable=True, - required=False, - ) - parser.add_argument("retrieval_model", type=dict, required=False, nullable=True, location="json") - parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json") - parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json") - - args = parser.parse_args() + args = dataset_create_parser.parse_args() if args.get("embedding_model_provider"): DatasetService.check_embedding_model_setting( @@ -172,9 +288,21 @@ class DatasetListApi(DatasetApiResource): return marshal(dataset, dataset_detail_fields), 200 +@service_api_ns.route("/datasets/") class DatasetApi(DatasetApiResource): """Resource for dataset.""" + @service_api_ns.doc("get_dataset") + @service_api_ns.doc(description="Get a specific dataset by ID") + @service_api_ns.doc(params={"dataset_id": "Dataset ID"}) + @service_api_ns.doc( + responses={ + 200: "Dataset retrieved successfully", + 401: "Unauthorized - invalid API token", + 403: "Forbidden - insufficient permissions", + 404: "Dataset not found", + } + ) def get(self, _, dataset_id): dataset_id_str = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id_str) @@ -214,6 +342,18 @@ class DatasetApi(DatasetApiResource): return data, 200 + @service_api_ns.expect(dataset_update_parser) + @service_api_ns.doc("update_dataset") + @service_api_ns.doc(description="Update an existing dataset") + @service_api_ns.doc(params={"dataset_id": "Dataset ID"}) + @service_api_ns.doc( + responses={ + 200: "Dataset updated successfully", + 401: "Unauthorized - invalid API token", + 403: "Forbidden - insufficient permissions", + 404: "Dataset not found", + } + ) @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def patch(self, _, dataset_id): dataset_id_str = str(dataset_id) @@ -221,63 +361,7 @@ class DatasetApi(DatasetApiResource): if dataset is None: raise NotFound("Dataset not found.") - parser = reqparse.RequestParser() - parser.add_argument( - "name", - nullable=False, - help="type is required. Name must be between 1 to 40 characters.", - type=_validate_name, - ) - parser.add_argument("description", location="json", store_missing=False, type=_validate_description_length) - parser.add_argument( - "indexing_technique", - type=str, - location="json", - choices=Dataset.INDEXING_TECHNIQUE_LIST, - nullable=True, - help="Invalid indexing technique.", - ) - parser.add_argument( - "permission", - type=str, - location="json", - choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM), - help="Invalid permission.", - ) - parser.add_argument("embedding_model", type=str, location="json", help="Invalid embedding model.") - parser.add_argument( - "embedding_model_provider", type=str, location="json", help="Invalid embedding model provider." - ) - parser.add_argument("retrieval_model", type=dict, location="json", help="Invalid retrieval model.") - parser.add_argument("partial_member_list", type=list, location="json", help="Invalid parent user list.") - - parser.add_argument( - "external_retrieval_model", - type=dict, - required=False, - nullable=True, - location="json", - help="Invalid external retrieval model.", - ) - - parser.add_argument( - "external_knowledge_id", - type=str, - required=False, - nullable=True, - location="json", - help="Invalid external knowledge id.", - ) - - parser.add_argument( - "external_knowledge_api_id", - type=str, - required=False, - nullable=True, - location="json", - help="Invalid external knowledge api id.", - ) - args = parser.parse_args() + args = dataset_update_parser.parse_args() data = request.get_json() # check embedding model setting @@ -325,6 +409,17 @@ class DatasetApi(DatasetApiResource): return result_data, 200 + @service_api_ns.doc("delete_dataset") + @service_api_ns.doc(description="Delete a dataset") + @service_api_ns.doc(params={"dataset_id": "Dataset ID"}) + @service_api_ns.doc( + responses={ + 204: "Dataset deleted successfully", + 401: "Unauthorized - invalid API token", + 404: "Dataset not found", + 409: "Conflict - dataset is in use", + } + ) @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def delete(self, _, dataset_id): """ @@ -355,17 +450,35 @@ class DatasetApi(DatasetApiResource): raise DatasetInUseError() +@service_api_ns.route("/datasets//documents/status/") class DocumentStatusApi(DatasetApiResource): """Resource for batch document status operations.""" - def patch(self, tenant_id, dataset_id, action): + @service_api_ns.doc("update_document_status") + @service_api_ns.doc(description="Batch update document status") + @service_api_ns.doc( + params={ + "dataset_id": "Dataset ID", + "action": "Action to perform: 'enable', 'disable', 'archive', or 'un_archive'", + } + ) + @service_api_ns.doc( + responses={ + 200: "Document status updated successfully", + 401: "Unauthorized - invalid API token", + 403: "Forbidden - insufficient permissions", + 404: "Dataset not found", + 400: "Bad request - invalid action", + } + ) + def patch(self, tenant_id, dataset_id, action: Literal["enable", "disable", "archive", "un_archive"]): """ Batch update document status. Args: tenant_id: tenant id dataset_id: dataset id - action: action to perform (enable, disable, archive, un_archive) + action: action to perform (Literal["enable", "disable", "archive", "un_archive"]) Returns: dict: A dictionary with a key 'result' and a value 'success' @@ -405,53 +518,65 @@ class DocumentStatusApi(DatasetApiResource): return {"result": "success"}, 200 +@service_api_ns.route("/datasets/tags") class DatasetTagsApi(DatasetApiResource): + @service_api_ns.doc("list_dataset_tags") + @service_api_ns.doc(description="Get all knowledge type tags") + @service_api_ns.doc( + responses={ + 200: "Tags retrieved successfully", + 401: "Unauthorized - invalid API token", + } + ) @validate_dataset_token - @marshal_with(tag_fields) + @service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns)) def get(self, _, dataset_id): """Get all knowledge type tags.""" tags = TagService.get_tags("knowledge", current_user.current_tenant_id) return tags, 200 + @service_api_ns.expect(tag_create_parser) + @service_api_ns.doc("create_dataset_tag") + @service_api_ns.doc(description="Add a knowledge type tag") + @service_api_ns.doc( + responses={ + 200: "Tag created successfully", + 401: "Unauthorized - invalid API token", + 403: "Forbidden - insufficient permissions", + } + ) + @service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns)) @validate_dataset_token def post(self, _, dataset_id): """Add a knowledge type tag.""" if not (current_user.is_editor or current_user.is_dataset_editor): raise Forbidden() - parser = reqparse.RequestParser() - parser.add_argument( - "name", - nullable=False, - required=True, - help="Name must be between 1 to 50 characters.", - type=DatasetTagsApi._validate_tag_name, - ) - - args = parser.parse_args() + args = tag_create_parser.parse_args() args["type"] = "knowledge" tag = TagService.save_tags(args) response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0} - return response, 200 + @service_api_ns.expect(tag_update_parser) + @service_api_ns.doc("update_dataset_tag") + @service_api_ns.doc(description="Update a knowledge type tag") + @service_api_ns.doc( + responses={ + 200: "Tag updated successfully", + 401: "Unauthorized - invalid API token", + 403: "Forbidden - insufficient permissions", + } + ) + @service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns)) @validate_dataset_token def patch(self, _, dataset_id): if not (current_user.is_editor or current_user.is_dataset_editor): raise Forbidden() - parser = reqparse.RequestParser() - parser.add_argument( - "name", - nullable=False, - required=True, - help="Name must be between 1 to 50 characters.", - type=DatasetTagsApi._validate_tag_name, - ) - parser.add_argument("tag_id", nullable=False, required=True, help="Id of a tag.", type=str) - args = parser.parse_args() + args = tag_update_parser.parse_args() args["type"] = "knowledge" tag = TagService.update_tags(args, args.get("tag_id")) @@ -461,66 +586,88 @@ class DatasetTagsApi(DatasetApiResource): return response, 200 + @service_api_ns.expect(tag_delete_parser) + @service_api_ns.doc("delete_dataset_tag") + @service_api_ns.doc(description="Delete a knowledge type tag") + @service_api_ns.doc( + responses={ + 204: "Tag deleted successfully", + 401: "Unauthorized - invalid API token", + 403: "Forbidden - insufficient permissions", + } + ) @validate_dataset_token def delete(self, _, dataset_id): """Delete a knowledge type tag.""" if not current_user.is_editor: raise Forbidden() - parser = reqparse.RequestParser() - parser.add_argument("tag_id", nullable=False, required=True, help="Id of a tag.", type=str) - args = parser.parse_args() + args = tag_delete_parser.parse_args() TagService.delete_tag(args.get("tag_id")) return 204 - @staticmethod - def _validate_tag_name(name): - if not name or len(name) < 1 or len(name) > 50: - raise ValueError("Name must be between 1 to 50 characters.") - return name - +@service_api_ns.route("/datasets/tags/binding") class DatasetTagBindingApi(DatasetApiResource): + @service_api_ns.expect(tag_binding_parser) + @service_api_ns.doc("bind_dataset_tags") + @service_api_ns.doc(description="Bind tags to a dataset") + @service_api_ns.doc( + responses={ + 204: "Tags bound successfully", + 401: "Unauthorized - invalid API token", + 403: "Forbidden - insufficient permissions", + } + ) @validate_dataset_token def post(self, _, dataset_id): # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator if not (current_user.is_editor or current_user.is_dataset_editor): raise Forbidden() - parser = reqparse.RequestParser() - parser.add_argument( - "tag_ids", type=list, nullable=False, required=True, location="json", help="Tag IDs is required." - ) - parser.add_argument( - "target_id", type=str, nullable=False, required=True, location="json", help="Target Dataset ID is required." - ) - - args = parser.parse_args() + args = tag_binding_parser.parse_args() args["type"] = "knowledge" TagService.save_tag_binding(args) return 204 +@service_api_ns.route("/datasets/tags/unbinding") class DatasetTagUnbindingApi(DatasetApiResource): + @service_api_ns.expect(tag_unbinding_parser) + @service_api_ns.doc("unbind_dataset_tag") + @service_api_ns.doc(description="Unbind a tag from a dataset") + @service_api_ns.doc( + responses={ + 204: "Tag unbound successfully", + 401: "Unauthorized - invalid API token", + 403: "Forbidden - insufficient permissions", + } + ) @validate_dataset_token def post(self, _, dataset_id): # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator if not (current_user.is_editor or current_user.is_dataset_editor): raise Forbidden() - parser = reqparse.RequestParser() - parser.add_argument("tag_id", type=str, nullable=False, required=True, help="Tag ID is required.") - parser.add_argument("target_id", type=str, nullable=False, required=True, help="Target ID is required.") - - args = parser.parse_args() + args = tag_unbinding_parser.parse_args() args["type"] = "knowledge" TagService.delete_tag_binding(args) return 204 +@service_api_ns.route("/datasets//tags") class DatasetTagsBindingStatusApi(DatasetApiResource): + @service_api_ns.doc("get_dataset_tags_binding_status") + @service_api_ns.doc(description="Get tags bound to a specific dataset") + @service_api_ns.doc(params={"dataset_id": "Dataset ID"}) + @service_api_ns.doc( + responses={ + 200: "Tags retrieved successfully", + 401: "Unauthorized - invalid API token", + } + ) @validate_dataset_token def get(self, _, *args, **kwargs): """Get all knowledge type tags.""" @@ -529,12 +676,3 @@ class DatasetTagsBindingStatusApi(DatasetApiResource): tags_list = [{"id": tag.id, "name": tag.name} for tag in tags] response = {"data": tags_list, "total": len(tags)} return response, 200 - - -api.add_resource(DatasetListApi, "/datasets") -api.add_resource(DatasetApi, "/datasets/") -api.add_resource(DocumentStatusApi, "/datasets//documents/status/") -api.add_resource(DatasetTagsApi, "/datasets/tags") -api.add_resource(DatasetTagBindingApi, "/datasets/tags/binding") -api.add_resource(DatasetTagUnbindingApi, "/datasets/tags/unbinding") -api.add_resource(DatasetTagsBindingStatusApi, "/datasets//tags") diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index ac85c0b38d..43232229c8 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -1,20 +1,20 @@ import json from flask import request -from flask_restful import marshal, reqparse +from flask_restx import marshal, reqparse from sqlalchemy import desc, select from werkzeug.exceptions import Forbidden, NotFound import services -from controllers.common.errors import FilenameNotExistsError -from controllers.service_api import api -from controllers.service_api.app.error import ( +from controllers.common.errors import ( + FilenameNotExistsError, FileTooLargeError, NoFileUploadedError, - ProviderNotInitializeError, TooManyFilesError, UnsupportedFileTypeError, ) +from controllers.service_api import service_api_ns +from controllers.service_api.app.error import ProviderNotInitializeError from controllers.service_api.dataset.error import ( ArchivedDocumentImmutableError, DocumentIndexingError, @@ -34,32 +34,64 @@ from services.dataset_service import DatasetService, DocumentService from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig from services.file_service import FileService +# Define parsers for document operations +document_text_create_parser = reqparse.RequestParser() +document_text_create_parser.add_argument("name", type=str, required=True, nullable=False, location="json") +document_text_create_parser.add_argument("text", type=str, required=True, nullable=False, location="json") +document_text_create_parser.add_argument("process_rule", type=dict, required=False, nullable=True, location="json") +document_text_create_parser.add_argument("original_document_id", type=str, required=False, location="json") +document_text_create_parser.add_argument( + "doc_form", type=str, default="text_model", required=False, nullable=False, location="json" +) +document_text_create_parser.add_argument( + "doc_language", type=str, default="English", required=False, nullable=False, location="json" +) +document_text_create_parser.add_argument( + "indexing_technique", type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, location="json" +) +document_text_create_parser.add_argument("retrieval_model", type=dict, required=False, nullable=True, location="json") +document_text_create_parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json") +document_text_create_parser.add_argument( + "embedding_model_provider", type=str, required=False, nullable=True, location="json" +) +document_text_update_parser = reqparse.RequestParser() +document_text_update_parser.add_argument("name", type=str, required=False, nullable=True, location="json") +document_text_update_parser.add_argument("text", type=str, required=False, nullable=True, location="json") +document_text_update_parser.add_argument("process_rule", type=dict, required=False, nullable=True, location="json") +document_text_update_parser.add_argument( + "doc_form", type=str, default="text_model", required=False, nullable=False, location="json" +) +document_text_update_parser.add_argument( + "doc_language", type=str, default="English", required=False, nullable=False, location="json" +) +document_text_update_parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json") + + +@service_api_ns.route( + "/datasets//document/create_by_text", + "/datasets//document/create-by-text", +) class DocumentAddByTextApi(DatasetApiResource): """Resource for documents.""" + @service_api_ns.expect(document_text_create_parser) + @service_api_ns.doc("create_document_by_text") + @service_api_ns.doc(description="Create a new document by providing text content") + @service_api_ns.doc(params={"dataset_id": "Dataset ID"}) + @service_api_ns.doc( + responses={ + 200: "Document created successfully", + 401: "Unauthorized - invalid API token", + 400: "Bad request - invalid parameters", + } + ) @cloud_edition_billing_resource_check("vector_space", "dataset") @cloud_edition_billing_resource_check("documents", "dataset") @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def post(self, tenant_id, dataset_id): """Create document by text.""" - parser = reqparse.RequestParser() - parser.add_argument("name", type=str, required=True, nullable=False, location="json") - parser.add_argument("text", type=str, required=True, nullable=False, location="json") - parser.add_argument("process_rule", type=dict, required=False, nullable=True, location="json") - parser.add_argument("original_document_id", type=str, required=False, location="json") - parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json") - parser.add_argument( - "doc_language", type=str, default="English", required=False, nullable=False, location="json" - ) - parser.add_argument( - "indexing_technique", type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, location="json" - ) - parser.add_argument("retrieval_model", type=dict, required=False, nullable=True, location="json") - parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json") - parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json") - - args = parser.parse_args() + args = document_text_create_parser.parse_args() dataset_id = str(dataset_id) tenant_id = str(tenant_id) @@ -117,23 +149,29 @@ class DocumentAddByTextApi(DatasetApiResource): return documents_and_batch_fields, 200 +@service_api_ns.route( + "/datasets//documents//update_by_text", + "/datasets//documents//update-by-text", +) class DocumentUpdateByTextApi(DatasetApiResource): """Resource for update documents.""" + @service_api_ns.expect(document_text_update_parser) + @service_api_ns.doc("update_document_by_text") + @service_api_ns.doc(description="Update an existing document by providing text content") + @service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"}) + @service_api_ns.doc( + responses={ + 200: "Document updated successfully", + 401: "Unauthorized - invalid API token", + 404: "Document not found", + } + ) @cloud_edition_billing_resource_check("vector_space", "dataset") @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def post(self, tenant_id, dataset_id, document_id): """Update document by text.""" - parser = reqparse.RequestParser() - parser.add_argument("name", type=str, required=False, nullable=True, location="json") - parser.add_argument("text", type=str, required=False, nullable=True, location="json") - parser.add_argument("process_rule", type=dict, required=False, nullable=True, location="json") - parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json") - parser.add_argument( - "doc_language", type=str, default="English", required=False, nullable=False, location="json" - ) - parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json") - args = parser.parse_args() + args = document_text_update_parser.parse_args() dataset_id = str(dataset_id) tenant_id = str(tenant_id) dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() @@ -187,9 +225,23 @@ class DocumentUpdateByTextApi(DatasetApiResource): return documents_and_batch_fields, 200 +@service_api_ns.route( + "/datasets//document/create_by_file", + "/datasets//document/create-by-file", +) class DocumentAddByFileApi(DatasetApiResource): """Resource for documents.""" + @service_api_ns.doc("create_document_by_file") + @service_api_ns.doc(description="Create a new document by uploading a file") + @service_api_ns.doc(params={"dataset_id": "Dataset ID"}) + @service_api_ns.doc( + responses={ + 200: "Document created successfully", + 401: "Unauthorized - invalid API token", + 400: "Bad request - invalid file or parameters", + } + ) @cloud_edition_billing_resource_check("vector_space", "dataset") @cloud_edition_billing_resource_check("documents", "dataset") @cloud_edition_billing_rate_limit_check("knowledge", "dataset") @@ -234,8 +286,6 @@ class DocumentAddByFileApi(DatasetApiResource): args["retrieval_model"].get("reranking_model").get("reranking_model_name"), ) - # save file info - file = request.files["file"] # check file if "file" not in request.files: raise NoFileUploadedError() @@ -243,6 +293,8 @@ class DocumentAddByFileApi(DatasetApiResource): if len(request.files) > 1: raise TooManyFilesError() + # save file info + file = request.files["file"] if not file.filename: raise FilenameNotExistsError @@ -281,9 +333,23 @@ class DocumentAddByFileApi(DatasetApiResource): return documents_and_batch_fields, 200 +@service_api_ns.route( + "/datasets//documents//update_by_file", + "/datasets//documents//update-by-file", +) class DocumentUpdateByFileApi(DatasetApiResource): """Resource for update documents.""" + @service_api_ns.doc("update_document_by_file") + @service_api_ns.doc(description="Update an existing document by uploading a file") + @service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"}) + @service_api_ns.doc( + responses={ + 200: "Document updated successfully", + 401: "Unauthorized - invalid API token", + 404: "Document not found", + } + ) @cloud_edition_billing_resource_check("vector_space", "dataset") @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def post(self, tenant_id, dataset_id, document_id): @@ -358,40 +424,18 @@ class DocumentUpdateByFileApi(DatasetApiResource): return documents_and_batch_fields, 200 -class DocumentDeleteApi(DatasetApiResource): - @cloud_edition_billing_rate_limit_check("knowledge", "dataset") - def delete(self, tenant_id, dataset_id, document_id): - """Delete document.""" - document_id = str(document_id) - dataset_id = str(dataset_id) - tenant_id = str(tenant_id) - - # get dataset info - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() - - if not dataset: - raise ValueError("Dataset does not exist.") - - document = DocumentService.get_document(dataset.id, document_id) - - # 404 if document not found - if document is None: - raise NotFound("Document Not Exists.") - - # 403 if document is archived - if DocumentService.check_archived(document): - raise ArchivedDocumentImmutableError() - - try: - # delete document - DocumentService.delete_document(document) - except services.errors.document.DocumentIndexingError: - raise DocumentIndexingError("Cannot delete document during indexing.") - - return 204 - - +@service_api_ns.route("/datasets//documents") class DocumentListApi(DatasetApiResource): + @service_api_ns.doc("list_documents") + @service_api_ns.doc(description="List all documents in a dataset") + @service_api_ns.doc(params={"dataset_id": "Dataset ID"}) + @service_api_ns.doc( + responses={ + 200: "Documents retrieved successfully", + 401: "Unauthorized - invalid API token", + 404: "Dataset not found", + } + ) def get(self, tenant_id, dataset_id): dataset_id = str(dataset_id) tenant_id = str(tenant_id) @@ -424,7 +468,18 @@ class DocumentListApi(DatasetApiResource): return response +@service_api_ns.route("/datasets//documents//indexing-status") class DocumentIndexingStatusApi(DatasetApiResource): + @service_api_ns.doc("get_document_indexing_status") + @service_api_ns.doc(description="Get indexing status for documents in a batch") + @service_api_ns.doc(params={"dataset_id": "Dataset ID", "batch": "Batch ID"}) + @service_api_ns.doc( + responses={ + 200: "Indexing status retrieved successfully", + 401: "Unauthorized - invalid API token", + 404: "Dataset or documents not found", + } + ) def get(self, tenant_id, dataset_id, batch): dataset_id = str(dataset_id) batch = str(batch) @@ -473,9 +528,21 @@ class DocumentIndexingStatusApi(DatasetApiResource): return data -class DocumentDetailApi(DatasetApiResource): +@service_api_ns.route("/datasets//documents/") +class DocumentApi(DatasetApiResource): METADATA_CHOICES = {"all", "only", "without"} + @service_api_ns.doc("get_document") + @service_api_ns.doc(description="Get a specific document by ID") + @service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"}) + @service_api_ns.doc( + responses={ + 200: "Document retrieved successfully", + 401: "Unauthorized - invalid API token", + 403: "Forbidden - insufficient permissions", + 404: "Document not found", + } + ) def get(self, tenant_id, dataset_id, document_id): dataset_id = str(dataset_id) document_id = str(document_id) @@ -567,28 +634,44 @@ class DocumentDetailApi(DatasetApiResource): return response + @service_api_ns.doc("delete_document") + @service_api_ns.doc(description="Delete a document") + @service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"}) + @service_api_ns.doc( + responses={ + 204: "Document deleted successfully", + 401: "Unauthorized - invalid API token", + 403: "Forbidden - document is archived", + 404: "Document not found", + } + ) + @cloud_edition_billing_rate_limit_check("knowledge", "dataset") + def delete(self, tenant_id, dataset_id, document_id): + """Delete document.""" + document_id = str(document_id) + dataset_id = str(dataset_id) + tenant_id = str(tenant_id) -api.add_resource( - DocumentAddByTextApi, - "/datasets//document/create_by_text", - "/datasets//document/create-by-text", -) -api.add_resource( - DocumentAddByFileApi, - "/datasets//document/create_by_file", - "/datasets//document/create-by-file", -) -api.add_resource( - DocumentUpdateByTextApi, - "/datasets//documents//update_by_text", - "/datasets//documents//update-by-text", -) -api.add_resource( - DocumentUpdateByFileApi, - "/datasets//documents//update_by_file", - "/datasets//documents//update-by-file", -) -api.add_resource(DocumentDeleteApi, "/datasets//documents/") -api.add_resource(DocumentListApi, "/datasets//documents") -api.add_resource(DocumentIndexingStatusApi, "/datasets//documents//indexing-status") -api.add_resource(DocumentDetailApi, "/datasets//documents/") + # get dataset info + dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + + if not dataset: + raise ValueError("Dataset does not exist.") + + document = DocumentService.get_document(dataset.id, document_id) + + # 404 if document not found + if document is None: + raise NotFound("Document Not Exists.") + + # 403 if document is archived + if DocumentService.check_archived(document): + raise ArchivedDocumentImmutableError() + + try: + # delete document + DocumentService.delete_document(document) + except services.errors.document.DocumentIndexingError: + raise DocumentIndexingError("Cannot delete document during indexing.") + + return 204 diff --git a/api/controllers/service_api/dataset/error.py b/api/controllers/service_api/dataset/error.py index ecc47b40a1..e4214a16ad 100644 --- a/api/controllers/service_api/dataset/error.py +++ b/api/controllers/service_api/dataset/error.py @@ -1,30 +1,6 @@ from libs.exception import BaseHTTPException -class NoFileUploadedError(BaseHTTPException): - error_code = "no_file_uploaded" - description = "Please upload your file." - code = 400 - - -class TooManyFilesError(BaseHTTPException): - error_code = "too_many_files" - description = "Only one file is allowed." - code = 400 - - -class FileTooLargeError(BaseHTTPException): - error_code = "file_too_large" - description = "File size exceeded. {message}" - code = 413 - - -class UnsupportedFileTypeError(BaseHTTPException): - error_code = "unsupported_file_type" - description = "File type not allowed." - code = 415 - - class DatasetNotInitializedError(BaseHTTPException): error_code = "dataset_not_initialized" description = "The dataset is still being initialized or indexing. Please wait a moment." diff --git a/api/controllers/service_api/dataset/hit_testing.py b/api/controllers/service_api/dataset/hit_testing.py index 52e9bca5da..d81287d56f 100644 --- a/api/controllers/service_api/dataset/hit_testing.py +++ b/api/controllers/service_api/dataset/hit_testing.py @@ -1,11 +1,26 @@ from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase -from controllers.service_api import api +from controllers.service_api import service_api_ns from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_rate_limit_check +@service_api_ns.route("/datasets//hit-testing", "/datasets//retrieve") class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase): + @service_api_ns.doc("dataset_hit_testing") + @service_api_ns.doc(description="Perform hit testing on a dataset") + @service_api_ns.doc(params={"dataset_id": "Dataset ID"}) + @service_api_ns.doc( + responses={ + 200: "Hit testing results", + 401: "Unauthorized - invalid API token", + 404: "Dataset not found", + } + ) @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def post(self, tenant_id, dataset_id): + """Perform hit testing on a dataset. + + Tests retrieval performance for the specified dataset. + """ dataset_id_str = str(dataset_id) dataset = self.get_and_validate_dataset(dataset_id_str) @@ -13,6 +28,3 @@ class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase): self.hit_testing_args_check(args) return self.perform_hit_testing(dataset, args) - - -api.add_resource(HitTestingApi, "/datasets//hit-testing", "/datasets//retrieve") diff --git a/api/controllers/service_api/dataset/metadata.py b/api/controllers/service_api/dataset/metadata.py index 1968696ee5..9defe6af03 100644 --- a/api/controllers/service_api/dataset/metadata.py +++ b/api/controllers/service_api/dataset/metadata.py @@ -1,8 +1,10 @@ +from typing import Literal + from flask_login import current_user # type: ignore -from flask_restful import marshal, reqparse +from flask_restx import marshal, reqparse from werkzeug.exceptions import NotFound -from controllers.service_api import api +from controllers.service_api import service_api_ns from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_rate_limit_check from fields.dataset_fields import dataset_metadata_fields from services.dataset_service import DatasetService @@ -12,14 +14,43 @@ from services.entities.knowledge_entities.knowledge_entities import ( ) from services.metadata_service import MetadataService +# Define parsers for metadata APIs +metadata_create_parser = reqparse.RequestParser() +metadata_create_parser.add_argument( + "type", type=str, required=True, nullable=False, location="json", help="Metadata type" +) +metadata_create_parser.add_argument( + "name", type=str, required=True, nullable=False, location="json", help="Metadata name" +) +metadata_update_parser = reqparse.RequestParser() +metadata_update_parser.add_argument( + "name", type=str, required=True, nullable=False, location="json", help="New metadata name" +) + +document_metadata_parser = reqparse.RequestParser() +document_metadata_parser.add_argument( + "operation_data", type=list, required=True, nullable=False, location="json", help="Metadata operation data" +) + + +@service_api_ns.route("/datasets//metadata") class DatasetMetadataCreateServiceApi(DatasetApiResource): + @service_api_ns.expect(metadata_create_parser) + @service_api_ns.doc("create_dataset_metadata") + @service_api_ns.doc(description="Create metadata for a dataset") + @service_api_ns.doc(params={"dataset_id": "Dataset ID"}) + @service_api_ns.doc( + responses={ + 201: "Metadata created successfully", + 401: "Unauthorized - invalid API token", + 404: "Dataset not found", + } + ) @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def post(self, tenant_id, dataset_id): - parser = reqparse.RequestParser() - parser.add_argument("type", type=str, required=True, nullable=True, location="json") - parser.add_argument("name", type=str, required=True, nullable=True, location="json") - args = parser.parse_args() + """Create metadata for a dataset.""" + args = metadata_create_parser.parse_args() metadata_args = MetadataArgs(**args) dataset_id_str = str(dataset_id) @@ -31,7 +62,18 @@ class DatasetMetadataCreateServiceApi(DatasetApiResource): metadata = MetadataService.create_metadata(dataset_id_str, metadata_args) return marshal(metadata, dataset_metadata_fields), 201 + @service_api_ns.doc("get_dataset_metadata") + @service_api_ns.doc(description="Get all metadata for a dataset") + @service_api_ns.doc(params={"dataset_id": "Dataset ID"}) + @service_api_ns.doc( + responses={ + 200: "Metadata retrieved successfully", + 401: "Unauthorized - invalid API token", + 404: "Dataset not found", + } + ) def get(self, tenant_id, dataset_id): + """Get all metadata for a dataset.""" dataset_id_str = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id_str) if dataset is None: @@ -39,12 +81,23 @@ class DatasetMetadataCreateServiceApi(DatasetApiResource): return MetadataService.get_dataset_metadatas(dataset), 200 +@service_api_ns.route("/datasets//metadata/") class DatasetMetadataServiceApi(DatasetApiResource): + @service_api_ns.expect(metadata_update_parser) + @service_api_ns.doc("update_dataset_metadata") + @service_api_ns.doc(description="Update metadata name") + @service_api_ns.doc(params={"dataset_id": "Dataset ID", "metadata_id": "Metadata ID"}) + @service_api_ns.doc( + responses={ + 200: "Metadata updated successfully", + 401: "Unauthorized - invalid API token", + 404: "Dataset or metadata not found", + } + ) @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def patch(self, tenant_id, dataset_id, metadata_id): - parser = reqparse.RequestParser() - parser.add_argument("name", type=str, required=True, nullable=True, location="json") - args = parser.parse_args() + """Update metadata name.""" + args = metadata_update_parser.parse_args() dataset_id_str = str(dataset_id) metadata_id_str = str(metadata_id) @@ -56,8 +109,19 @@ class DatasetMetadataServiceApi(DatasetApiResource): metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, args.get("name")) return marshal(metadata, dataset_metadata_fields), 200 + @service_api_ns.doc("delete_dataset_metadata") + @service_api_ns.doc(description="Delete metadata") + @service_api_ns.doc(params={"dataset_id": "Dataset ID", "metadata_id": "Metadata ID"}) + @service_api_ns.doc( + responses={ + 204: "Metadata deleted successfully", + 401: "Unauthorized - invalid API token", + 404: "Dataset or metadata not found", + } + ) @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def delete(self, tenant_id, dataset_id, metadata_id): + """Delete metadata.""" dataset_id_str = str(dataset_id) metadata_id_str = str(metadata_id) dataset = DatasetService.get_dataset(dataset_id_str) @@ -69,15 +133,37 @@ class DatasetMetadataServiceApi(DatasetApiResource): return 204 +@service_api_ns.route("/datasets/metadata/built-in") class DatasetMetadataBuiltInFieldServiceApi(DatasetApiResource): + @service_api_ns.doc("get_built_in_fields") + @service_api_ns.doc(description="Get all built-in metadata fields") + @service_api_ns.doc( + responses={ + 200: "Built-in fields retrieved successfully", + 401: "Unauthorized - invalid API token", + } + ) def get(self, tenant_id): + """Get all built-in metadata fields.""" built_in_fields = MetadataService.get_built_in_fields() return {"fields": built_in_fields}, 200 +@service_api_ns.route("/datasets//metadata/built-in/") class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource): + @service_api_ns.doc("toggle_built_in_field") + @service_api_ns.doc(description="Enable or disable built-in metadata field") + @service_api_ns.doc(params={"dataset_id": "Dataset ID", "action": "Action to perform: 'enable' or 'disable'"}) + @service_api_ns.doc( + responses={ + 200: "Action completed successfully", + 401: "Unauthorized - invalid API token", + 404: "Dataset not found", + } + ) @cloud_edition_billing_rate_limit_check("knowledge", "dataset") - def post(self, tenant_id, dataset_id, action): + def post(self, tenant_id, dataset_id, action: Literal["enable", "disable"]): + """Enable or disable built-in metadata field.""" dataset_id_str = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id_str) if dataset is None: @@ -91,29 +177,31 @@ class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource): return 200 +@service_api_ns.route("/datasets//documents/metadata") class DocumentMetadataEditServiceApi(DatasetApiResource): + @service_api_ns.expect(document_metadata_parser) + @service_api_ns.doc("update_documents_metadata") + @service_api_ns.doc(description="Update metadata for multiple documents") + @service_api_ns.doc(params={"dataset_id": "Dataset ID"}) + @service_api_ns.doc( + responses={ + 200: "Documents metadata updated successfully", + 401: "Unauthorized - invalid API token", + 404: "Dataset not found", + } + ) @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def post(self, tenant_id, dataset_id): + """Update metadata for multiple documents.""" dataset_id_str = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id_str) if dataset is None: raise NotFound("Dataset not found.") DatasetService.check_dataset_permission(dataset, current_user) - parser = reqparse.RequestParser() - parser.add_argument("operation_data", type=list, required=True, nullable=True, location="json") - args = parser.parse_args() + args = document_metadata_parser.parse_args() metadata_args = MetadataOperationData(**args) MetadataService.update_documents_metadata(dataset, metadata_args) return 200 - - -api.add_resource(DatasetMetadataCreateServiceApi, "/datasets//metadata") -api.add_resource(DatasetMetadataServiceApi, "/datasets//metadata/") -api.add_resource(DatasetMetadataBuiltInFieldServiceApi, "/datasets/metadata/built-in") -api.add_resource( - DatasetMetadataBuiltInFieldActionServiceApi, "/datasets//metadata/built-in/" -) -api.add_resource(DocumentMetadataEditServiceApi, "/datasets//documents/metadata") diff --git a/api/controllers/service_api/dataset/segment.py b/api/controllers/service_api/dataset/segment.py index 31f862dc8f..f5e2010ca4 100644 --- a/api/controllers/service_api/dataset/segment.py +++ b/api/controllers/service_api/dataset/segment.py @@ -1,9 +1,9 @@ from flask import request from flask_login import current_user -from flask_restful import marshal, reqparse +from flask_restx import marshal, reqparse from werkzeug.exceptions import NotFound -from controllers.service_api import api +from controllers.service_api import service_api_ns from controllers.service_api.app.error import ProviderNotInitializeError from controllers.service_api.wraps import ( DatasetApiResource, @@ -19,34 +19,59 @@ from fields.segment_fields import child_chunk_fields, segment_fields from models.dataset import Dataset from services.dataset_service import DatasetService, DocumentService, SegmentService from services.entities.knowledge_entities.knowledge_entities import SegmentUpdateArgs -from services.errors.chunk import ( - ChildChunkDeleteIndexError, - ChildChunkIndexingError, -) -from services.errors.chunk import ( - ChildChunkDeleteIndexError as ChildChunkDeleteIndexServiceError, -) -from services.errors.chunk import ( - ChildChunkIndexingError as ChildChunkIndexingServiceError, -) +from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexingError +from services.errors.chunk import ChildChunkDeleteIndexError as ChildChunkDeleteIndexServiceError +from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingServiceError + +# Define parsers for segment operations +segment_create_parser = reqparse.RequestParser() +segment_create_parser.add_argument("segments", type=list, required=False, nullable=True, location="json") + +segment_list_parser = reqparse.RequestParser() +segment_list_parser.add_argument("status", type=str, action="append", default=[], location="args") +segment_list_parser.add_argument("keyword", type=str, default=None, location="args") + +segment_update_parser = reqparse.RequestParser() +segment_update_parser.add_argument("segment", type=dict, required=False, nullable=True, location="json") + +child_chunk_create_parser = reqparse.RequestParser() +child_chunk_create_parser.add_argument("content", type=str, required=True, nullable=False, location="json") + +child_chunk_list_parser = reqparse.RequestParser() +child_chunk_list_parser.add_argument("limit", type=int, default=20, location="args") +child_chunk_list_parser.add_argument("keyword", type=str, default=None, location="args") +child_chunk_list_parser.add_argument("page", type=int, default=1, location="args") + +child_chunk_update_parser = reqparse.RequestParser() +child_chunk_update_parser.add_argument("content", type=str, required=True, nullable=False, location="json") +@service_api_ns.route("/datasets//documents//segments") class SegmentApi(DatasetApiResource): """Resource for segments.""" + @service_api_ns.expect(segment_create_parser) + @service_api_ns.doc("create_segments") + @service_api_ns.doc(description="Create segments in a document") + @service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"}) + @service_api_ns.doc( + responses={ + 200: "Segments created successfully", + 400: "Bad request - segments data is missing", + 401: "Unauthorized - invalid API token", + 404: "Dataset or document not found", + } + ) @cloud_edition_billing_resource_check("vector_space", "dataset") @cloud_edition_billing_knowledge_limit_check("add_segment", "dataset") @cloud_edition_billing_rate_limit_check("knowledge", "dataset") - def post(self, tenant_id, dataset_id, document_id): + def post(self, tenant_id: str, dataset_id: str, document_id: str): """Create single segment.""" # check dataset - dataset_id = str(dataset_id) - tenant_id = str(tenant_id) dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: raise NotFound("Dataset not found.") # check document - document_id = str(document_id) document = DocumentService.get_document(dataset.id, document_id) if not document: raise NotFound("Document not found.") @@ -71,9 +96,7 @@ class SegmentApi(DatasetApiResource): except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) # validate args - parser = reqparse.RequestParser() - parser.add_argument("segments", type=list, required=False, nullable=True, location="json") - args = parser.parse_args() + args = segment_create_parser.parse_args() if args["segments"] is not None: for args_item in args["segments"]: SegmentService.segment_create_args_validate(args_item, document) @@ -82,18 +105,26 @@ class SegmentApi(DatasetApiResource): else: return {"error": "Segments is required"}, 400 - def get(self, tenant_id, dataset_id, document_id): + @service_api_ns.expect(segment_list_parser) + @service_api_ns.doc("list_segments") + @service_api_ns.doc(description="List segments in a document") + @service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"}) + @service_api_ns.doc( + responses={ + 200: "Segments retrieved successfully", + 401: "Unauthorized - invalid API token", + 404: "Dataset or document not found", + } + ) + def get(self, tenant_id: str, dataset_id: str, document_id: str): """Get segments.""" # check dataset - dataset_id = str(dataset_id) - tenant_id = str(tenant_id) page = request.args.get("page", default=1, type=int) limit = request.args.get("limit", default=20, type=int) dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: raise NotFound("Dataset not found.") # check document - document_id = str(document_id) document = DocumentService.get_document(dataset.id, document_id) if not document: raise NotFound("Document not found.") @@ -114,10 +145,7 @@ class SegmentApi(DatasetApiResource): except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) - parser = reqparse.RequestParser() - parser.add_argument("status", type=str, action="append", default=[], location="args") - parser.add_argument("keyword", type=str, default=None, location="args") - args = parser.parse_args() + args = segment_list_parser.parse_args() segments, total = SegmentService.get_segments( document_id=document_id, @@ -140,43 +168,62 @@ class SegmentApi(DatasetApiResource): return response, 200 +@service_api_ns.route("/datasets//documents//segments/") class DatasetSegmentApi(DatasetApiResource): + @service_api_ns.doc("delete_segment") + @service_api_ns.doc(description="Delete a specific segment") + @service_api_ns.doc( + params={"dataset_id": "Dataset ID", "document_id": "Document ID", "segment_id": "Segment ID to delete"} + ) + @service_api_ns.doc( + responses={ + 204: "Segment deleted successfully", + 401: "Unauthorized - invalid API token", + 404: "Dataset, document, or segment not found", + } + ) @cloud_edition_billing_rate_limit_check("knowledge", "dataset") - def delete(self, tenant_id, dataset_id, document_id, segment_id): + def delete(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str): # check dataset - dataset_id = str(dataset_id) - tenant_id = str(tenant_id) dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: raise NotFound("Dataset not found.") # check user's model setting DatasetService.check_dataset_model_setting(dataset) # check document - document_id = str(document_id) document = DocumentService.get_document(dataset_id, document_id) if not document: raise NotFound("Document not found.") # check segment - segment_id = str(segment_id) segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id) if not segment: raise NotFound("Segment not found.") SegmentService.delete_segment(segment, document, dataset) return 204 + @service_api_ns.expect(segment_update_parser) + @service_api_ns.doc("update_segment") + @service_api_ns.doc(description="Update a specific segment") + @service_api_ns.doc( + params={"dataset_id": "Dataset ID", "document_id": "Document ID", "segment_id": "Segment ID to update"} + ) + @service_api_ns.doc( + responses={ + 200: "Segment updated successfully", + 401: "Unauthorized - invalid API token", + 404: "Dataset, document, or segment not found", + } + ) @cloud_edition_billing_resource_check("vector_space", "dataset") @cloud_edition_billing_rate_limit_check("knowledge", "dataset") - def post(self, tenant_id, dataset_id, document_id, segment_id): + def post(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str): # check dataset - dataset_id = str(dataset_id) - tenant_id = str(tenant_id) dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: raise NotFound("Dataset not found.") # check user's model setting DatasetService.check_dataset_model_setting(dataset) # check document - document_id = str(document_id) document = DocumentService.get_document(dataset_id, document_id) if not document: raise NotFound("Document not found.") @@ -197,37 +244,39 @@ class DatasetSegmentApi(DatasetApiResource): except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) # check segment - segment_id = str(segment_id) segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id) if not segment: raise NotFound("Segment not found.") # validate args - parser = reqparse.RequestParser() - parser.add_argument("segment", type=dict, required=False, nullable=True, location="json") - args = parser.parse_args() + args = segment_update_parser.parse_args() updated_segment = SegmentService.update_segment( SegmentUpdateArgs(**args["segment"]), segment, document, dataset ) return {"data": marshal(updated_segment, segment_fields), "doc_form": document.doc_form}, 200 - def get(self, tenant_id, dataset_id, document_id, segment_id): + @service_api_ns.doc("get_segment") + @service_api_ns.doc(description="Get a specific segment by ID") + @service_api_ns.doc( + responses={ + 200: "Segment retrieved successfully", + 401: "Unauthorized - invalid API token", + 404: "Dataset, document, or segment not found", + } + ) + def get(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str): # check dataset - dataset_id = str(dataset_id) - tenant_id = str(tenant_id) dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: raise NotFound("Dataset not found.") # check user's model setting DatasetService.check_dataset_model_setting(dataset) # check document - document_id = str(document_id) document = DocumentService.get_document(dataset_id, document_id) if not document: raise NotFound("Document not found.") # check segment - segment_id = str(segment_id) segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id) if not segment: raise NotFound("Segment not found.") @@ -235,29 +284,41 @@ class DatasetSegmentApi(DatasetApiResource): return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200 +@service_api_ns.route( + "/datasets//documents//segments//child_chunks" +) class ChildChunkApi(DatasetApiResource): """Resource for child chunks.""" + @service_api_ns.expect(child_chunk_create_parser) + @service_api_ns.doc("create_child_chunk") + @service_api_ns.doc(description="Create a new child chunk for a segment") + @service_api_ns.doc( + params={"dataset_id": "Dataset ID", "document_id": "Document ID", "segment_id": "Parent segment ID"} + ) + @service_api_ns.doc( + responses={ + 200: "Child chunk created successfully", + 401: "Unauthorized - invalid API token", + 404: "Dataset, document, or segment not found", + } + ) @cloud_edition_billing_resource_check("vector_space", "dataset") @cloud_edition_billing_knowledge_limit_check("add_segment", "dataset") @cloud_edition_billing_rate_limit_check("knowledge", "dataset") - def post(self, tenant_id, dataset_id, document_id, segment_id): + def post(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str): """Create child chunk.""" # check dataset - dataset_id = str(dataset_id) - tenant_id = str(tenant_id) dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: raise NotFound("Dataset not found.") # check document - document_id = str(document_id) document = DocumentService.get_document(dataset.id, document_id) if not document: raise NotFound("Document not found.") # check segment - segment_id = str(segment_id) segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id) if not segment: raise NotFound("Segment not found.") @@ -280,43 +341,46 @@ class ChildChunkApi(DatasetApiResource): raise ProviderNotInitializeError(ex.description) # validate args - parser = reqparse.RequestParser() - parser.add_argument("content", type=str, required=True, nullable=False, location="json") - args = parser.parse_args() + args = child_chunk_create_parser.parse_args() try: - child_chunk = SegmentService.create_child_chunk(args.get("content"), segment, document, dataset) + child_chunk = SegmentService.create_child_chunk(args["content"], segment, document, dataset) except ChildChunkIndexingServiceError as e: raise ChildChunkIndexingError(str(e)) return {"data": marshal(child_chunk, child_chunk_fields)}, 200 - def get(self, tenant_id, dataset_id, document_id, segment_id): + @service_api_ns.expect(child_chunk_list_parser) + @service_api_ns.doc("list_child_chunks") + @service_api_ns.doc(description="List child chunks for a segment") + @service_api_ns.doc( + params={"dataset_id": "Dataset ID", "document_id": "Document ID", "segment_id": "Parent segment ID"} + ) + @service_api_ns.doc( + responses={ + 200: "Child chunks retrieved successfully", + 401: "Unauthorized - invalid API token", + 404: "Dataset, document, or segment not found", + } + ) + def get(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str): """Get child chunks.""" # check dataset - dataset_id = str(dataset_id) - tenant_id = str(tenant_id) dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: raise NotFound("Dataset not found.") # check document - document_id = str(document_id) document = DocumentService.get_document(dataset.id, document_id) if not document: raise NotFound("Document not found.") # check segment - segment_id = str(segment_id) segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id) if not segment: raise NotFound("Segment not found.") - parser = reqparse.RequestParser() - parser.add_argument("limit", type=int, default=20, location="args") - parser.add_argument("keyword", type=str, default=None, location="args") - parser.add_argument("page", type=int, default=1, location="args") - args = parser.parse_args() + args = child_chunk_list_parser.parse_args() page = args["page"] limit = min(args["limit"], 100) @@ -333,40 +397,63 @@ class ChildChunkApi(DatasetApiResource): }, 200 +@service_api_ns.route( + "/datasets//documents//segments//child_chunks/" +) class DatasetChildChunkApi(DatasetApiResource): """Resource for updating child chunks.""" + @service_api_ns.doc("delete_child_chunk") + @service_api_ns.doc(description="Delete a specific child chunk") + @service_api_ns.doc( + params={ + "dataset_id": "Dataset ID", + "document_id": "Document ID", + "segment_id": "Parent segment ID", + "child_chunk_id": "Child chunk ID to delete", + } + ) + @service_api_ns.doc( + responses={ + 204: "Child chunk deleted successfully", + 401: "Unauthorized - invalid API token", + 404: "Dataset, document, segment, or child chunk not found", + } + ) @cloud_edition_billing_knowledge_limit_check("add_segment", "dataset") @cloud_edition_billing_rate_limit_check("knowledge", "dataset") - def delete(self, tenant_id, dataset_id, document_id, segment_id, child_chunk_id): + def delete(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str, child_chunk_id: str): """Delete child chunk.""" # check dataset - dataset_id = str(dataset_id) - tenant_id = str(tenant_id) dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: raise NotFound("Dataset not found.") # check document - document_id = str(document_id) document = DocumentService.get_document(dataset.id, document_id) if not document: raise NotFound("Document not found.") # check segment - segment_id = str(segment_id) segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id) if not segment: raise NotFound("Segment not found.") + # validate segment belongs to the specified document + if segment.document_id != document_id: + raise NotFound("Document not found.") + # check child chunk - child_chunk_id = str(child_chunk_id) child_chunk = SegmentService.get_child_chunk_by_id( child_chunk_id=child_chunk_id, tenant_id=current_user.current_tenant_id ) if not child_chunk: raise NotFound("Child chunk not found.") + # validate child chunk belongs to the specified segment + if child_chunk.segment_id != segment.id: + raise NotFound("Child chunk not found.") + try: SegmentService.delete_child_chunk(child_chunk, dataset) except ChildChunkDeleteIndexServiceError as e: @@ -374,14 +461,30 @@ class DatasetChildChunkApi(DatasetApiResource): return 204 + @service_api_ns.expect(child_chunk_update_parser) + @service_api_ns.doc("update_child_chunk") + @service_api_ns.doc(description="Update a specific child chunk") + @service_api_ns.doc( + params={ + "dataset_id": "Dataset ID", + "document_id": "Document ID", + "segment_id": "Parent segment ID", + "child_chunk_id": "Child chunk ID to update", + } + ) + @service_api_ns.doc( + responses={ + 200: "Child chunk updated successfully", + 401: "Unauthorized - invalid API token", + 404: "Dataset, document, segment, or child chunk not found", + } + ) @cloud_edition_billing_resource_check("vector_space", "dataset") @cloud_edition_billing_knowledge_limit_check("add_segment", "dataset") @cloud_edition_billing_rate_limit_check("knowledge", "dataset") - def patch(self, tenant_id, dataset_id, document_id, segment_id, child_chunk_id): + def patch(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str, child_chunk_id: str): """Update child chunk.""" # check dataset - dataset_id = str(dataset_id) - tenant_id = str(tenant_id) dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: raise NotFound("Dataset not found.") @@ -396,6 +499,10 @@ class DatasetChildChunkApi(DatasetApiResource): if not segment: raise NotFound("Segment not found.") + # validate segment belongs to the specified document + if segment.document_id != document_id: + raise NotFound("Segment not found.") + # get child chunk child_chunk = SegmentService.get_child_chunk_by_id( child_chunk_id=child_chunk_id, tenant_id=current_user.current_tenant_id @@ -403,29 +510,16 @@ class DatasetChildChunkApi(DatasetApiResource): if not child_chunk: raise NotFound("Child chunk not found.") + # validate child chunk belongs to the specified segment + if child_chunk.segment_id != segment.id: + raise NotFound("Child chunk not found.") + # validate args - parser = reqparse.RequestParser() - parser.add_argument("content", type=str, required=True, nullable=False, location="json") - args = parser.parse_args() + args = child_chunk_update_parser.parse_args() try: - child_chunk = SegmentService.update_child_chunk( - args.get("content"), child_chunk, segment, document, dataset - ) + child_chunk = SegmentService.update_child_chunk(args["content"], child_chunk, segment, document, dataset) except ChildChunkIndexingServiceError as e: raise ChildChunkIndexingError(str(e)) return {"data": marshal(child_chunk, child_chunk_fields)}, 200 - - -api.add_resource(SegmentApi, "/datasets//documents//segments") -api.add_resource( - DatasetSegmentApi, "/datasets//documents//segments/" -) -api.add_resource( - ChildChunkApi, "/datasets//documents//segments//child_chunks" -) -api.add_resource( - DatasetChildChunkApi, - "/datasets//documents//segments//child_chunks/", -) diff --git a/api/controllers/service_api/dataset/upload_file.py b/api/controllers/service_api/dataset/upload_file.py index 3b4721b5b0..27b36a6402 100644 --- a/api/controllers/service_api/dataset/upload_file.py +++ b/api/controllers/service_api/dataset/upload_file.py @@ -1,6 +1,6 @@ from werkzeug.exceptions import NotFound -from controllers.service_api import api +from controllers.service_api import service_api_ns from controllers.service_api.wraps import ( DatasetApiResource, ) @@ -11,9 +11,23 @@ from models.model import UploadFile from services.dataset_service import DocumentService +@service_api_ns.route("/datasets//documents//upload-file") class UploadFileApi(DatasetApiResource): + @service_api_ns.doc("get_upload_file") + @service_api_ns.doc(description="Get upload file information and download URL") + @service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"}) + @service_api_ns.doc( + responses={ + 200: "Upload file information retrieved successfully", + 401: "Unauthorized - invalid API token", + 404: "Dataset, document, or upload file not found", + } + ) def get(self, tenant_id, dataset_id, document_id): - """Get upload file.""" + """Get upload file information and download URL. + + Returns information about an uploaded file including its download URL. + """ # check dataset dataset_id = str(dataset_id) tenant_id = str(tenant_id) @@ -49,6 +63,3 @@ class UploadFileApi(DatasetApiResource): "created_by": upload_file.created_by, "created_at": upload_file.created_at.timestamp(), }, 200 - - -api.add_resource(UploadFileApi, "/datasets//documents//upload-file") diff --git a/api/controllers/service_api/index.py b/api/controllers/service_api/index.py index 9bb5df4c4e..a9d2d6fadc 100644 --- a/api/controllers/service_api/index.py +++ b/api/controllers/service_api/index.py @@ -1,9 +1,10 @@ -from flask_restful import Resource +from flask_restx import Resource from configs import dify_config -from controllers.service_api import api +from controllers.service_api import service_api_ns +@service_api_ns.route("/") class IndexApi(Resource): def get(self): return { @@ -11,6 +12,3 @@ class IndexApi(Resource): "api_version": "v1", "server_version": dify_config.project.version, } - - -api.add_resource(IndexApi, "/") diff --git a/api/controllers/service_api/workspace/models.py b/api/controllers/service_api/workspace/models.py index 3f18474674..536cf81a2f 100644 --- a/api/controllers/service_api/workspace/models.py +++ b/api/controllers/service_api/workspace/models.py @@ -1,21 +1,32 @@ from flask_login import current_user -from flask_restful import Resource +from flask_restx import Resource -from controllers.service_api import api +from controllers.service_api import service_api_ns from controllers.service_api.wraps import validate_dataset_token from core.model_runtime.utils.encoders import jsonable_encoder from services.model_provider_service import ModelProviderService +@service_api_ns.route("/workspaces/current/models/model-types/") class ModelProviderAvailableModelApi(Resource): + @service_api_ns.doc("get_available_models") + @service_api_ns.doc(description="Get available models by model type") + @service_api_ns.doc(params={"model_type": "Type of model to retrieve"}) + @service_api_ns.doc( + responses={ + 200: "Models retrieved successfully", + 401: "Unauthorized - invalid API token", + } + ) @validate_dataset_token def get(self, _, model_type): + """Get available models by model type. + + Returns a list of available models for the specified model type. + """ tenant_id = current_user.current_tenant_id model_provider_service = ModelProviderService() models = model_provider_service.get_models_by_model_type(tenant_id=tenant_id, model_type=model_type) return jsonable_encoder({"data": models}) - - -api.add_resource(ModelProviderAvailableModelApi, "/workspaces/current/models/model-types/") diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py index da81cc8bc3..8aac3de4c3 100644 --- a/api/controllers/service_api/wraps.py +++ b/api/controllers/service_api/wraps.py @@ -7,7 +7,7 @@ from typing import Optional from flask import current_app, request from flask_login import user_logged_in # type: ignore -from flask_restful import Resource +from flask_restx import Resource from pydantic import BaseModel from sqlalchemy import select, update from sqlalchemy.orm import Session diff --git a/api/controllers/web/app.py b/api/controllers/web/app.py index 94a525a75d..0680903635 100644 --- a/api/controllers/web/app.py +++ b/api/controllers/web/app.py @@ -1,5 +1,8 @@ +import logging + from flask import request -from flask_restful import Resource, marshal_with, reqparse +from flask_restx import Resource, marshal_with, reqparse +from werkzeug.exceptions import Unauthorized from controllers.common import fields from controllers.web import api @@ -75,19 +78,22 @@ class AppWebAuthPermission(Resource): try: auth_header = request.headers.get("Authorization") if auth_header is None: - raise + raise Unauthorized("Authorization header is missing.") if " " not in auth_header: - raise + raise Unauthorized("Invalid Authorization header format. Expected 'Bearer ' format.") auth_scheme, tk = auth_header.split(None, 1) auth_scheme = auth_scheme.lower() if auth_scheme != "bearer": - raise + raise Unauthorized("Authorization scheme must be 'Bearer'") decoded = PassportService().verify(tk) user_id = decoded.get("user_id", "visitor") - except Exception as e: - pass + except Unauthorized: + raise + except Exception: + logging.exception("Unexpected error during auth verification") + raise features = FeatureService.get_system_features() if not features.webapp_auth.enabled: diff --git a/api/controllers/web/audio.py b/api/controllers/web/audio.py index 2919ca9af4..241d0874db 100644 --- a/api/controllers/web/audio.py +++ b/api/controllers/web/audio.py @@ -65,7 +65,7 @@ class AudioApi(WebApiResource): class TextApi(WebApiResource): def post(self, app_model: App, end_user): - from flask_restful import reqparse + from flask_restx import reqparse try: parser = reqparse.RequestParser() diff --git a/api/controllers/web/completion.py b/api/controllers/web/completion.py index fd3b9aa804..c19afee9b7 100644 --- a/api/controllers/web/completion.py +++ b/api/controllers/web/completion.py @@ -1,6 +1,6 @@ import logging -from flask_restful import reqparse +from flask_restx import reqparse from werkzeug.exceptions import InternalServerError, NotFound import services diff --git a/api/controllers/web/conversation.py b/api/controllers/web/conversation.py index 98cea3974f..cea8e442f3 100644 --- a/api/controllers/web/conversation.py +++ b/api/controllers/web/conversation.py @@ -1,5 +1,5 @@ -from flask_restful import marshal_with, reqparse -from flask_restful.inputs import int_range +from flask_restx import marshal_with, reqparse +from flask_restx.inputs import int_range from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound diff --git a/api/controllers/web/error.py b/api/controllers/web/error.py index 036e11d5c5..196a27e348 100644 --- a/api/controllers/web/error.py +++ b/api/controllers/web/error.py @@ -97,30 +97,6 @@ class ProviderNotSupportSpeechToTextError(BaseHTTPException): code = 400 -class NoFileUploadedError(BaseHTTPException): - error_code = "no_file_uploaded" - description = "Please upload your file." - code = 400 - - -class TooManyFilesError(BaseHTTPException): - error_code = "too_many_files" - description = "Only one file is allowed." - code = 400 - - -class FileTooLargeError(BaseHTTPException): - error_code = "file_too_large" - description = "File size exceeded. {message}" - code = 413 - - -class UnsupportedFileTypeError(BaseHTTPException): - error_code = "unsupported_file_type" - description = "File type not allowed." - code = 415 - - class WebAppAuthRequiredError(BaseHTTPException): error_code = "web_sso_auth_required" description = "Web app authentication required." diff --git a/api/controllers/web/feature.py b/api/controllers/web/feature.py index 0563ed2238..478b3d2e31 100644 --- a/api/controllers/web/feature.py +++ b/api/controllers/web/feature.py @@ -1,4 +1,4 @@ -from flask_restful import Resource +from flask_restx import Resource from controllers.web import api from services.feature_service import FeatureService diff --git a/api/controllers/web/files.py b/api/controllers/web/files.py index df06a73a85..b05e2a2e65 100644 --- a/api/controllers/web/files.py +++ b/api/controllers/web/files.py @@ -1,9 +1,14 @@ from flask import request -from flask_restful import marshal_with +from flask_restx import marshal_with import services -from controllers.common.errors import FilenameNotExistsError -from controllers.web.error import FileTooLargeError, NoFileUploadedError, TooManyFilesError, UnsupportedFileTypeError +from controllers.common.errors import ( + FilenameNotExistsError, + FileTooLargeError, + NoFileUploadedError, + TooManyFilesError, + UnsupportedFileTypeError, +) from controllers.web.wraps import WebApiResource from fields.file_fields import file_fields from services.file_service import FileService @@ -12,18 +17,17 @@ from services.file_service import FileService class FileApi(WebApiResource): @marshal_with(file_fields) def post(self, app_model, end_user): - file = request.files["file"] - source = request.form.get("source") - if "file" not in request.files: raise NoFileUploadedError() if len(request.files) > 1: raise TooManyFilesError() + file = request.files["file"] if not file.filename: raise FilenameNotExistsError + source = request.form.get("source") if source not in ("datasets", None): source = None diff --git a/api/controllers/web/forgot_password.py b/api/controllers/web/forgot_password.py index 0da8d65efc..d436657f06 100644 --- a/api/controllers/web/forgot_password.py +++ b/api/controllers/web/forgot_password.py @@ -2,7 +2,7 @@ import base64 import secrets from flask import request -from flask_restful import Resource, reqparse +from flask_restx import Resource, reqparse from sqlalchemy import select from sqlalchemy.orm import Session diff --git a/api/controllers/web/login.py b/api/controllers/web/login.py index 01c4f4a262..d4eafd532b 100644 --- a/api/controllers/web/login.py +++ b/api/controllers/web/login.py @@ -1,4 +1,4 @@ -from flask_restful import Resource, reqparse +from flask_restx import Resource, reqparse from jwt import InvalidTokenError # type: ignore import services diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py index f2e1873601..f348221d80 100644 --- a/api/controllers/web/message.py +++ b/api/controllers/web/message.py @@ -1,10 +1,9 @@ import logging -from flask_restful import fields, marshal_with, reqparse -from flask_restful.inputs import int_range +from flask_restx import fields, marshal_with, reqparse +from flask_restx.inputs import int_range from werkzeug.exceptions import InternalServerError, NotFound -import services from controllers.web import api from controllers.web.error import ( AppMoreLikeThisDisabledError, @@ -29,7 +28,11 @@ from models.model import AppMode from services.app_generate_service import AppGenerateService from services.errors.app import MoreLikeThisDisabledError from services.errors.conversation import ConversationNotExistsError -from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError +from services.errors.message import ( + FirstMessageNotExistsError, + MessageNotExistsError, + SuggestedQuestionsAfterAnswerDisabledError, +) from services.message_service import MessageService @@ -73,9 +76,9 @@ class MessageListApi(WebApiResource): return MessageService.pagination_by_first_id( app_model, end_user, args["conversation_id"], args["first_id"], args["limit"] ) - except services.errors.conversation.ConversationNotExistsError: + except ConversationNotExistsError: raise NotFound("Conversation Not Exists.") - except services.errors.message.FirstMessageNotExistsError: + except FirstMessageNotExistsError: raise NotFound("First Message Not Exists.") @@ -96,7 +99,7 @@ class MessageFeedbackApi(WebApiResource): rating=args.get("rating"), content=args.get("content"), ) - except services.errors.message.MessageNotExistsError: + except MessageNotExistsError: raise NotFound("Message Not Exists.") return {"result": "success"} diff --git a/api/controllers/web/passport.py b/api/controllers/web/passport.py index acd3a8b539..1ac20e6531 100644 --- a/api/controllers/web/passport.py +++ b/api/controllers/web/passport.py @@ -2,7 +2,7 @@ import uuid from datetime import UTC, datetime, timedelta from flask import request -from flask_restful import Resource +from flask_restx import Resource from sqlalchemy import func, select from werkzeug.exceptions import NotFound, Unauthorized diff --git a/api/controllers/web/remote_files.py b/api/controllers/web/remote_files.py index ae68df6bdc..930b9d96e9 100644 --- a/api/controllers/web/remote_files.py +++ b/api/controllers/web/remote_files.py @@ -1,19 +1,21 @@ import urllib.parse import httpx -from flask_restful import marshal_with, reqparse +from flask_restx import marshal_with, reqparse import services from controllers.common import helpers -from controllers.common.errors import RemoteFileUploadError +from controllers.common.errors import ( + FileTooLargeError, + RemoteFileUploadError, + UnsupportedFileTypeError, +) from controllers.web.wraps import WebApiResource from core.file import helpers as file_helpers from core.helper import ssrf_proxy from fields.file_fields import file_fields_with_signed_url, remote_file_info_fields from services.file_service import FileService -from .error import FileTooLargeError, UnsupportedFileTypeError - class RemoteFileInfoApi(WebApiResource): @marshal_with(remote_file_info_fields) diff --git a/api/controllers/web/saved_message.py b/api/controllers/web/saved_message.py index d7188ef0b3..a0912499ff 100644 --- a/api/controllers/web/saved_message.py +++ b/api/controllers/web/saved_message.py @@ -1,5 +1,5 @@ -from flask_restful import fields, marshal_with, reqparse -from flask_restful.inputs import int_range +from flask_restx import fields, marshal_with, reqparse +from flask_restx.inputs import int_range from werkzeug.exceptions import NotFound from controllers.web import api diff --git a/api/controllers/web/site.py b/api/controllers/web/site.py index 3c133499b7..b2a887a0de 100644 --- a/api/controllers/web/site.py +++ b/api/controllers/web/site.py @@ -1,4 +1,4 @@ -from flask_restful import fields, marshal_with +from flask_restx import fields, marshal_with from werkzeug.exceptions import Forbidden from configs import dify_config diff --git a/api/controllers/web/workflow.py b/api/controllers/web/workflow.py index 590fd3f2c7..331587cc28 100644 --- a/api/controllers/web/workflow.py +++ b/api/controllers/web/workflow.py @@ -1,6 +1,6 @@ import logging -from flask_restful import reqparse +from flask_restx import reqparse from werkzeug.exceptions import InternalServerError from controllers.web import api diff --git a/api/controllers/web/wraps.py b/api/controllers/web/wraps.py index ae6f14a689..94fa5d5626 100644 --- a/api/controllers/web/wraps.py +++ b/api/controllers/web/wraps.py @@ -2,7 +2,7 @@ from datetime import UTC, datetime from functools import wraps from flask import request -from flask_restful import Resource +from flask_restx import Resource from sqlalchemy import select from werkzeug.exceptions import BadRequest, NotFound, Unauthorized diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index 1f3c218d59..f7c83f927f 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -280,7 +280,7 @@ class BaseAgentRunner(AppRunner): def create_agent_thought( self, message_id: str, message: str, tool_name: str, tool_input: str, messages_ids: list[str] - ) -> MessageAgentThought: + ) -> str: """ Create agent thought """ @@ -313,16 +313,15 @@ class BaseAgentRunner(AppRunner): db.session.add(thought) db.session.commit() - db.session.refresh(thought) + agent_thought_id = str(thought.id) + self.agent_thought_count += 1 db.session.close() - self.agent_thought_count += 1 - - return thought + return agent_thought_id def save_agent_thought( self, - agent_thought: MessageAgentThought, + agent_thought_id: str, tool_name: str | None, tool_input: Union[str, dict, None], thought: str | None, @@ -335,12 +334,9 @@ class BaseAgentRunner(AppRunner): """ Save agent thought """ - updated_agent_thought = ( - db.session.query(MessageAgentThought).where(MessageAgentThought.id == agent_thought.id).first() - ) - if not updated_agent_thought: + agent_thought = db.session.query(MessageAgentThought).where(MessageAgentThought.id == agent_thought_id).first() + if not agent_thought: raise ValueError("agent thought not found") - agent_thought = updated_agent_thought if thought: agent_thought.thought += thought @@ -355,7 +351,7 @@ class BaseAgentRunner(AppRunner): except Exception: tool_input = json.dumps(tool_input) - updated_agent_thought.tool_input = tool_input + agent_thought.tool_input = tool_input if observation: if isinstance(observation, dict): @@ -364,27 +360,27 @@ class BaseAgentRunner(AppRunner): except Exception: observation = json.dumps(observation) - updated_agent_thought.observation = observation + agent_thought.observation = observation if answer: agent_thought.answer = answer if messages_ids is not None and len(messages_ids) > 0: - updated_agent_thought.message_files = json.dumps(messages_ids) + agent_thought.message_files = json.dumps(messages_ids) if llm_usage: - updated_agent_thought.message_token = llm_usage.prompt_tokens - updated_agent_thought.message_price_unit = llm_usage.prompt_price_unit - updated_agent_thought.message_unit_price = llm_usage.prompt_unit_price - updated_agent_thought.answer_token = llm_usage.completion_tokens - updated_agent_thought.answer_price_unit = llm_usage.completion_price_unit - updated_agent_thought.answer_unit_price = llm_usage.completion_unit_price - updated_agent_thought.tokens = llm_usage.total_tokens - updated_agent_thought.total_price = llm_usage.total_price + agent_thought.message_token = llm_usage.prompt_tokens + agent_thought.message_price_unit = llm_usage.prompt_price_unit + agent_thought.message_unit_price = llm_usage.prompt_unit_price + agent_thought.answer_token = llm_usage.completion_tokens + agent_thought.answer_price_unit = llm_usage.completion_price_unit + agent_thought.answer_unit_price = llm_usage.completion_unit_price + agent_thought.tokens = llm_usage.total_tokens + agent_thought.total_price = llm_usage.total_price # check if tool labels is not empty - labels = updated_agent_thought.tool_labels or {} - tools = updated_agent_thought.tool.split(";") if updated_agent_thought.tool else [] + labels = agent_thought.tool_labels or {} + tools = agent_thought.tool.split(";") if agent_thought.tool else [] for tool in tools: if not tool: continue @@ -395,7 +391,7 @@ class BaseAgentRunner(AppRunner): else: labels[tool] = {"en_US": tool, "zh_Hans": tool} - updated_agent_thought.tool_labels_str = json.dumps(labels) + agent_thought.tool_labels_str = json.dumps(labels) if tool_invoke_meta is not None: if isinstance(tool_invoke_meta, dict): @@ -404,7 +400,7 @@ class BaseAgentRunner(AppRunner): except Exception: tool_invoke_meta = json.dumps(tool_invoke_meta) - updated_agent_thought.tool_meta_str = tool_invoke_meta + agent_thought.tool_meta_str = tool_invoke_meta db.session.commit() db.session.close() @@ -516,7 +512,6 @@ class BaseAgentRunner(AppRunner): if not file_objs: return UserPromptMessage(content=message.query) prompt_message_contents: list[PromptMessageContentUnionTypes] = [] - prompt_message_contents.append(TextPromptMessageContent(data=message.query)) for file in file_objs: prompt_message_contents.append( file_manager.to_prompt_message_content( @@ -524,4 +519,6 @@ class BaseAgentRunner(AppRunner): image_detail_config=image_detail_config, ) ) + prompt_message_contents.append(TextPromptMessageContent(data=message.query)) + return UserPromptMessage(content=prompt_message_contents) diff --git a/api/core/agent/cot_agent_runner.py b/api/core/agent/cot_agent_runner.py index 4979f63432..6cb1077126 100644 --- a/api/core/agent/cot_agent_runner.py +++ b/api/core/agent/cot_agent_runner.py @@ -97,13 +97,13 @@ class CotAgentRunner(BaseAgentRunner, ABC): message_file_ids: list[str] = [] - agent_thought = self.create_agent_thought( + agent_thought_id = self.create_agent_thought( message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids ) if iteration_step > 1: self.queue_manager.publish( - QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER + QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER ) # recalc llm max tokens @@ -133,7 +133,7 @@ class CotAgentRunner(BaseAgentRunner, ABC): # publish agent thought if it's first iteration if iteration_step == 1: self.queue_manager.publish( - QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER + QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER ) for chunk in react_chunks: @@ -168,7 +168,7 @@ class CotAgentRunner(BaseAgentRunner, ABC): usage_dict["usage"] = LLMUsage.empty_usage() self.save_agent_thought( - agent_thought=agent_thought, + agent_thought_id=agent_thought_id, tool_name=(scratchpad.action.action_name if scratchpad.action and not scratchpad.is_final() else ""), tool_input={scratchpad.action.action_name: scratchpad.action.action_input} if scratchpad.action else {}, tool_invoke_meta={}, @@ -181,7 +181,7 @@ class CotAgentRunner(BaseAgentRunner, ABC): if not scratchpad.is_final(): self.queue_manager.publish( - QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER + QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER ) if not scratchpad.action: @@ -197,7 +197,7 @@ class CotAgentRunner(BaseAgentRunner, ABC): final_answer = scratchpad.action.action_input else: final_answer = f"{scratchpad.action.action_input}" - except json.JSONDecodeError: + except TypeError: final_answer = f"{scratchpad.action.action_input}" else: function_call_state = True @@ -212,7 +212,7 @@ class CotAgentRunner(BaseAgentRunner, ABC): scratchpad.agent_response = tool_invoke_response self.save_agent_thought( - agent_thought=agent_thought, + agent_thought_id=agent_thought_id, tool_name=scratchpad.action.action_name, tool_input={scratchpad.action.action_name: scratchpad.action.action_input}, thought=scratchpad.thought or "", @@ -224,7 +224,7 @@ class CotAgentRunner(BaseAgentRunner, ABC): ) self.queue_manager.publish( - QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER + QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER ) # update prompt tool message @@ -244,7 +244,7 @@ class CotAgentRunner(BaseAgentRunner, ABC): # save agent thought self.save_agent_thought( - agent_thought=agent_thought, + agent_thought_id=agent_thought_id, tool_name="", tool_input={}, tool_invoke_meta={}, diff --git a/api/core/agent/cot_chat_agent_runner.py b/api/core/agent/cot_chat_agent_runner.py index 5ff89bdacb..4d1d94eadc 100644 --- a/api/core/agent/cot_chat_agent_runner.py +++ b/api/core/agent/cot_chat_agent_runner.py @@ -39,9 +39,6 @@ class CotChatAgentRunner(CotAgentRunner): Organize user query """ if self.files: - prompt_message_contents: list[PromptMessageContentUnionTypes] = [] - prompt_message_contents.append(TextPromptMessageContent(data=query)) - # get image detail config image_detail_config = ( self.application_generate_entity.file_upload_config.image_config.detail @@ -52,6 +49,8 @@ class CotChatAgentRunner(CotAgentRunner): else None ) image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW + + prompt_message_contents: list[PromptMessageContentUnionTypes] = [] for file in self.files: prompt_message_contents.append( file_manager.to_prompt_message_content( @@ -59,6 +58,7 @@ class CotChatAgentRunner(CotAgentRunner): image_detail_config=image_detail_config, ) ) + prompt_message_contents.append(TextPromptMessageContent(data=query)) prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) else: diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py index 5491689ece..9eb853aa74 100644 --- a/api/core/agent/fc_agent_runner.py +++ b/api/core/agent/fc_agent_runner.py @@ -80,7 +80,7 @@ class FunctionCallAgentRunner(BaseAgentRunner): prompt_messages_tools = [] message_file_ids: list[str] = [] - agent_thought = self.create_agent_thought( + agent_thought_id = self.create_agent_thought( message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids ) @@ -114,7 +114,7 @@ class FunctionCallAgentRunner(BaseAgentRunner): for chunk in chunks: if is_first_chunk: self.queue_manager.publish( - QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER + QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER ) is_first_chunk = False # check if there is any tool call @@ -126,8 +126,8 @@ class FunctionCallAgentRunner(BaseAgentRunner): tool_call_inputs = json.dumps( {tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False ) - except json.JSONDecodeError: - # ensure ascii to avoid encoding error + except TypeError: + # fallback: force ASCII to handle non-serializable objects tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls}) if chunk.delta.message and chunk.delta.message.content: @@ -153,8 +153,8 @@ class FunctionCallAgentRunner(BaseAgentRunner): tool_call_inputs = json.dumps( {tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False ) - except json.JSONDecodeError: - # ensure ascii to avoid encoding error + except TypeError: + # fallback: force ASCII to handle non-serializable objects tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls}) if result.usage: @@ -172,7 +172,7 @@ class FunctionCallAgentRunner(BaseAgentRunner): result.message.content = "" self.queue_manager.publish( - QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER + QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER ) yield LLMResultChunk( @@ -205,7 +205,7 @@ class FunctionCallAgentRunner(BaseAgentRunner): # save thought self.save_agent_thought( - agent_thought=agent_thought, + agent_thought_id=agent_thought_id, tool_name=tool_call_names, tool_input=tool_call_inputs, thought=response, @@ -216,7 +216,7 @@ class FunctionCallAgentRunner(BaseAgentRunner): llm_usage=current_llm_usage, ) self.queue_manager.publish( - QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER + QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER ) final_answer += response + "\n" @@ -276,7 +276,7 @@ class FunctionCallAgentRunner(BaseAgentRunner): if len(tool_responses) > 0: # save agent thought self.save_agent_thought( - agent_thought=agent_thought, + agent_thought_id=agent_thought_id, tool_name="", tool_input="", thought="", @@ -291,7 +291,7 @@ class FunctionCallAgentRunner(BaseAgentRunner): messages_ids=message_file_ids, ) self.queue_manager.publish( - QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER + QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER ) # update prompt tool @@ -395,9 +395,6 @@ class FunctionCallAgentRunner(BaseAgentRunner): Organize user query """ if self.files: - prompt_message_contents: list[PromptMessageContentUnionTypes] = [] - prompt_message_contents.append(TextPromptMessageContent(data=query)) - # get image detail config image_detail_config = ( self.application_generate_entity.file_upload_config.image_config.detail @@ -408,6 +405,8 @@ class FunctionCallAgentRunner(BaseAgentRunner): else None ) image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW + + prompt_message_contents: list[PromptMessageContentUnionTypes] = [] for file in self.files: prompt_message_contents.append( file_manager.to_prompt_message_content( @@ -415,6 +414,7 @@ class FunctionCallAgentRunner(BaseAgentRunner): image_detail_config=image_detail_config, ) ) + prompt_message_contents.append(TextPromptMessageContent(data=query)) prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) else: diff --git a/api/core/app/app_config/entities.py b/api/core/app/app_config/entities.py index 75bd2f677a..0db1d52779 100644 --- a/api/core/app/app_config/entities.py +++ b/api/core/app/app_config/entities.py @@ -148,6 +148,8 @@ SupportedComparisonOperator = Literal[ "is not", "empty", "not empty", + "in", + "not in", # for number "=", "≠", @@ -165,7 +167,7 @@ class ModelConfig(BaseModel): provider: str name: str mode: LLMMode - completion_params: dict[str, Any] = {} + completion_params: dict[str, Any] = Field(default_factory=dict) class Condition(BaseModel): diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index 610a5bb278..52ae20ee16 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -600,5 +600,5 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): if len(e.args) > 0 and e.args[0] == "I/O operation on closed file.": # ignore this error raise GenerateTaskStoppedError() else: - logger.exception(f"Failed to process generate task pipeline, conversation_id: {conversation.id}") + logger.exception("Failed to process generate task pipeline, conversation_id: %s", conversation.id) raise e diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index a75e17af64..3de2f5ca9e 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -118,26 +118,8 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): ): return - # Init conversation variables - stmt = select(ConversationVariable).where( - ConversationVariable.app_id == self.conversation.app_id, - ConversationVariable.conversation_id == self.conversation.id, - ) - with Session(db.engine) as session: - db_conversation_variables = session.scalars(stmt).all() - if not db_conversation_variables: - # Create conversation variables if they don't exist. - db_conversation_variables = [ - ConversationVariable.from_variable( - app_id=self.conversation.app_id, conversation_id=self.conversation.id, variable=variable - ) - for variable in self._workflow.conversation_variables - ] - session.add_all(db_conversation_variables) - # Convert database entities to variables. - conversation_variables = [item.to_variable() for item in db_conversation_variables] - - session.commit() + # Initialize conversation variables + conversation_variables = self._initialize_conversation_variables() # Create a variable pool. system_inputs = SystemVariable( @@ -292,3 +274,100 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): message_id=message_id, trace_manager=app_generate_entity.trace_manager, ) + + def _initialize_conversation_variables(self) -> list[VariableUnion]: + """ + Initialize conversation variables for the current conversation. + + This method: + 1. Loads existing variables from the database + 2. Creates new variables if none exist + 3. Syncs missing variables from the workflow definition + + :return: List of conversation variables ready for use + """ + with Session(db.engine) as session: + existing_variables = self._load_existing_conversation_variables(session) + + if not existing_variables: + # First time initialization - create all variables + existing_variables = self._create_all_conversation_variables(session) + else: + # Check and add any missing variables from the workflow + existing_variables = self._sync_missing_conversation_variables(session, existing_variables) + + # Convert to Variable objects for use in the workflow + conversation_variables = [var.to_variable() for var in existing_variables] + + session.commit() + return cast(list[VariableUnion], conversation_variables) + + def _load_existing_conversation_variables(self, session: Session) -> list[ConversationVariable]: + """ + Load existing conversation variables from the database. + + :param session: Database session + :return: List of existing conversation variables + """ + stmt = select(ConversationVariable).where( + ConversationVariable.app_id == self.conversation.app_id, + ConversationVariable.conversation_id == self.conversation.id, + ) + return list(session.scalars(stmt).all()) + + def _create_all_conversation_variables(self, session: Session) -> list[ConversationVariable]: + """ + Create all conversation variables for a new conversation. + + :param session: Database session + :return: List of created conversation variables + """ + new_variables = [ + ConversationVariable.from_variable( + app_id=self.conversation.app_id, conversation_id=self.conversation.id, variable=variable + ) + for variable in self._workflow.conversation_variables + ] + + if new_variables: + session.add_all(new_variables) + + return new_variables + + def _sync_missing_conversation_variables( + self, session: Session, existing_variables: list[ConversationVariable] + ) -> list[ConversationVariable]: + """ + Sync missing conversation variables from the workflow definition. + + This handles the case where new variables are added to a workflow + after conversations have already been created. + + :param session: Database session + :param existing_variables: List of existing conversation variables + :return: Updated list including any newly created variables + """ + # Get IDs of existing and workflow variables + existing_ids = {var.id for var in existing_variables} + workflow_variables = {var.id: var for var in self._workflow.conversation_variables} + + # Find missing variable IDs + missing_ids = set(workflow_variables.keys()) - existing_ids + + if not missing_ids: + return existing_variables + + # Create missing variables with their default values + new_variables = [ + ConversationVariable.from_variable( + app_id=self.conversation.app_id, + conversation_id=self.conversation.id, + variable=workflow_variables[var_id], + ) + for var_id in missing_ids + ] + + session.add_all(new_variables) + + # Return combined list + return existing_variables + new_variables diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index dc27076a4d..347fed4a17 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -74,6 +74,7 @@ from core.workflow.system_variable import SystemVariable from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager from events.message_event import message_was_created from extensions.ext_database import db +from libs.datetime_utils import naive_utc_now from models import Conversation, EndUser, Message, MessageFile from models.account import Account from models.enums import CreatorUserRole @@ -271,7 +272,7 @@ class AdvancedChatAppGenerateTaskPipeline: start_listener_time = time.time() yield MessageAudioStreamResponse(audio=audio_trunk.audio, task_id=task_id) except Exception: - logger.exception(f"Failed to listen audio message, task_id: {task_id}") + logger.exception("Failed to listen audio message, task_id: %s", task_id) break if tts_publisher: yield MessageAudioEndStreamResponse(audio="", task_id=task_id) @@ -568,7 +569,7 @@ class AdvancedChatAppGenerateTaskPipeline: ) yield workflow_finish_resp - self._base_task_pipeline._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE) + self._base_task_pipeline.queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE) def _handle_workflow_partial_success_event( self, @@ -600,7 +601,7 @@ class AdvancedChatAppGenerateTaskPipeline: ) yield workflow_finish_resp - self._base_task_pipeline._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE) + self._base_task_pipeline.queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE) def _handle_workflow_failed_event( self, @@ -845,7 +846,7 @@ class AdvancedChatAppGenerateTaskPipeline: # Initialize graph runtime state graph_runtime_state: Optional[GraphRuntimeState] = None - for queue_message in self._base_task_pipeline._queue_manager.listen(): + for queue_message in self._base_task_pipeline.queue_manager.listen(): event = queue_message.event match event: @@ -896,6 +897,7 @@ class AdvancedChatAppGenerateTaskPipeline: def _save_message(self, *, session: Session, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None: message = self._get_message(session=session) message.answer = self._task_state.answer + message.updated_at = naive_utc_now() message.provider_response_latency = time.perf_counter() - self._base_task_pipeline._start_at message.message_metadata = self._task_state.metadata.model_dump_json() message_files = [ @@ -959,11 +961,11 @@ class AdvancedChatAppGenerateTaskPipeline: if self._base_task_pipeline._output_moderation_handler: if self._base_task_pipeline._output_moderation_handler.should_direct_output(): self._task_state.answer = self._base_task_pipeline._output_moderation_handler.get_final_output() - self._base_task_pipeline._queue_manager.publish( + self._base_task_pipeline.queue_manager.publish( QueueTextChunkEvent(text=self._task_state.answer), PublishFrom.TASK_PIPELINE ) - self._base_task_pipeline._queue_manager.publish( + self._base_task_pipeline.queue_manager.publish( QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), PublishFrom.TASK_PIPELINE ) return True diff --git a/api/core/app/apps/chat/app_generator.py b/api/core/app/apps/chat/app_generator.py index 0c76cc39ae..c273776eb1 100644 --- a/api/core/app/apps/chat/app_generator.py +++ b/api/core/app/apps/chat/app_generator.py @@ -140,7 +140,9 @@ class ChatAppGenerator(MessageBasedAppGenerator): ) # get tracing instance - trace_manager = TraceQueueManager(app_id=app_model.id) + trace_manager = TraceQueueManager( + app_id=app_model.id, user_id=user.id if isinstance(user, Account) else user.session_id + ) # init application generate entity application_generate_entity = ChatAppGenerateEntity( diff --git a/api/core/app/apps/common/workflow_response_converter.py b/api/core/app/apps/common/workflow_response_converter.py index 34a1da2227..1a89237333 100644 --- a/api/core/app/apps/common/workflow_response_converter.py +++ b/api/core/app/apps/common/workflow_response_converter.py @@ -50,6 +50,7 @@ from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution from core.workflow.nodes import NodeType from core.workflow.nodes.tool.entities import ToolNodeData from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter +from libs.datetime_utils import naive_utc_now from models import ( Account, CreatorUserRole, @@ -399,7 +400,7 @@ class WorkflowResponseConverter: if event.error is None else WorkflowNodeExecutionStatus.FAILED, error=None, - elapsed_time=(datetime.now(UTC).replace(tzinfo=None) - event.start_at).total_seconds(), + elapsed_time=(naive_utc_now() - event.start_at).total_seconds(), total_tokens=event.metadata.get("total_tokens", 0) if event.metadata else 0, execution_metadata=event.metadata, finished_at=int(time.time()), @@ -478,7 +479,7 @@ class WorkflowResponseConverter: if event.error is None else WorkflowNodeExecutionStatus.FAILED, error=None, - elapsed_time=(datetime.now(UTC).replace(tzinfo=None) - event.start_at).total_seconds(), + elapsed_time=(naive_utc_now() - event.start_at).total_seconds(), total_tokens=event.metadata.get("total_tokens", 0) if event.metadata else 0, execution_metadata=event.metadata, finished_at=int(time.time()), diff --git a/api/core/app/apps/completion/app_generator.py b/api/core/app/apps/completion/app_generator.py index 9356bd1cea..64dade2968 100644 --- a/api/core/app/apps/completion/app_generator.py +++ b/api/core/app/apps/completion/app_generator.py @@ -124,7 +124,9 @@ class CompletionAppGenerator(MessageBasedAppGenerator): ) # get tracing instance - trace_manager = TraceQueueManager(app_model.id) + trace_manager = TraceQueueManager( + app_id=app_model.id, user_id=user.id if isinstance(user, Account) else user.session_id + ) # init application generate entity application_generate_entity = CompletionAppGenerateEntity( diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py index 7dd9904eeb..11c979765b 100644 --- a/api/core/app/apps/message_based_app_generator.py +++ b/api/core/app/apps/message_based_app_generator.py @@ -78,7 +78,7 @@ class MessageBasedAppGenerator(BaseAppGenerator): if len(e.args) > 0 and e.args[0] == "I/O operation on closed file.": # ignore this error raise GenerateTaskStoppedError() else: - logger.exception(f"Failed to handle response, conversation_id: {conversation.id}") + logger.exception("Failed to handle response, conversation_id: %s", conversation.id) raise e def _get_app_model_config(self, app_model: App, conversation: Optional[Conversation] = None) -> AppModelConfig: diff --git a/api/core/app/apps/message_based_app_queue_manager.py b/api/core/app/apps/message_based_app_queue_manager.py index 8507f23f17..4100a0d5a9 100644 --- a/api/core/app/apps/message_based_app_queue_manager.py +++ b/api/core/app/apps/message_based_app_queue_manager.py @@ -6,7 +6,6 @@ from core.app.entities.queue_entities import ( MessageQueueMessage, QueueAdvancedChatMessageEndEvent, QueueErrorEvent, - QueueMessage, QueueMessageEndEvent, QueueStopEvent, ) @@ -22,15 +21,6 @@ class MessageBasedAppQueueManager(AppQueueManager): self._app_mode = app_mode self._message_id = str(message_id) - def construct_queue_message(self, event: AppQueueEvent) -> QueueMessage: - return MessageQueueMessage( - task_id=self._task_id, - message_id=self._message_id, - conversation_id=self._conversation_id, - app_mode=self._app_mode, - event=event, - ) - def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None: """ Publish event to queue diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index 4c36f63c71..22b0234604 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -483,7 +483,7 @@ class WorkflowAppGenerator(BaseAppGenerator): try: runner.run() except GenerateTaskStoppedError as e: - logger.warning(f"Task stopped: {str(e)}") + logger.warning("Task stopped: %s", str(e)) pass except InvokeAuthorizationError: queue_manager.publish_error( @@ -540,6 +540,6 @@ class WorkflowAppGenerator(BaseAppGenerator): raise GenerateTaskStoppedError() else: logger.exception( - f"Fails to process generate task pipeline, task_id: {application_generate_entity.task_id}" + "Fails to process generate task pipeline, task_id: %s", application_generate_entity.task_id ) raise e diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index e31a316c56..537c070adf 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -246,7 +246,7 @@ class WorkflowAppGenerateTaskPipeline: else: yield MessageAudioStreamResponse(audio=audio_trunk.audio, task_id=task_id) except Exception: - logger.exception(f"Fails to get audio trunk, task_id: {task_id}") + logger.exception("Fails to get audio trunk, task_id: %s", task_id) break if tts_publisher: yield MessageAudioEndStreamResponse(audio="", task_id=task_id) @@ -711,7 +711,7 @@ class WorkflowAppGenerateTaskPipeline: # Initialize graph runtime state graph_runtime_state = None - for queue_message in self._base_task_pipeline._queue_manager.listen(): + for queue_message in self._base_task_pipeline.queue_manager.listen(): event = queue_message.event match event: diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py index 65ed267959..11f37c4baa 100644 --- a/api/core/app/entities/app_invoke_entities.py +++ b/api/core/app/entities/app_invoke_entities.py @@ -9,7 +9,6 @@ from core.app.app_config.entities import EasyUIBasedAppConfig, WorkflowUIBasedAp from core.entities.provider_configuration import ProviderModelBundle from core.file import File, FileUploadConfig from core.model_runtime.entities.model_entities import AIModelEntity -from core.ops.ops_trace_manager import TraceQueueManager class InvokeFrom(Enum): @@ -114,7 +113,8 @@ class AppGenerateEntity(BaseModel): extras: dict[str, Any] = Field(default_factory=dict) # tracing instance - trace_manager: Optional[TraceQueueManager] = None + # Using Any to avoid circular import with TraceQueueManager + trace_manager: Optional[Any] = None class EasyUIBasedAppGenerateEntity(AppGenerateEntity): diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py index 42e6a1519c..d663dbb175 100644 --- a/api/core/app/entities/queue_entities.py +++ b/api/core/app/entities/queue_entities.py @@ -610,7 +610,7 @@ class QueueErrorEvent(AppQueueEvent): """ event: QueueEvent = QueueEvent.ERROR - error: Any = None + error: Optional[Any] = None class QueuePingEvent(AppQueueEvent): diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py index 25c889e922..a1c0368354 100644 --- a/api/core/app/entities/task_entities.py +++ b/api/core/app/entities/task_entities.py @@ -142,7 +142,7 @@ class MessageEndStreamResponse(StreamResponse): event: StreamEvent = StreamEvent.MESSAGE_END id: str - metadata: dict = {} + metadata: dict = Field(default_factory=dict) files: Optional[Sequence[Mapping[str, Any]]] = None @@ -261,7 +261,7 @@ class NodeStartStreamResponse(StreamResponse): predecessor_node_id: Optional[str] = None inputs: Optional[Mapping[str, Any]] = None created_at: int - extras: dict = {} + extras: dict = Field(default_factory=dict) parallel_id: Optional[str] = None parallel_start_node_id: Optional[str] = None parent_parallel_id: Optional[str] = None @@ -503,7 +503,7 @@ class IterationNodeStartStreamResponse(StreamResponse): node_type: str title: str created_at: int - extras: dict = {} + extras: dict = Field(default_factory=dict) metadata: Mapping = {} inputs: Mapping = {} parallel_id: Optional[str] = None @@ -531,7 +531,7 @@ class IterationNodeNextStreamResponse(StreamResponse): index: int created_at: int pre_iteration_output: Optional[Any] = None - extras: dict = {} + extras: dict = Field(default_factory=dict) parallel_id: Optional[str] = None parallel_start_node_id: Optional[str] = None parallel_mode_run_id: Optional[str] = None @@ -590,7 +590,7 @@ class LoopNodeStartStreamResponse(StreamResponse): node_type: str title: str created_at: int - extras: dict = {} + extras: dict = Field(default_factory=dict) metadata: Mapping = {} inputs: Mapping = {} parallel_id: Optional[str] = None @@ -618,7 +618,7 @@ class LoopNodeNextStreamResponse(StreamResponse): index: int created_at: int pre_loop_output: Optional[Any] = None - extras: dict = {} + extras: dict = Field(default_factory=dict) parallel_id: Optional[str] = None parallel_start_node_id: Optional[str] = None parallel_mode_run_id: Optional[str] = None @@ -764,7 +764,7 @@ class ChatbotAppBlockingResponse(AppBlockingResponse): conversation_id: str message_id: str answer: str - metadata: dict = {} + metadata: dict = Field(default_factory=dict) created_at: int data: Data @@ -784,7 +784,7 @@ class CompletionAppBlockingResponse(AppBlockingResponse): mode: str message_id: str answer: str - metadata: dict = {} + metadata: dict = Field(default_factory=dict) created_at: int data: Data diff --git a/api/core/app/features/annotation_reply/annotation_reply.py b/api/core/app/features/annotation_reply/annotation_reply.py index 54dc69302a..b829340401 100644 --- a/api/core/app/features/annotation_reply/annotation_reply.py +++ b/api/core/app/features/annotation_reply/annotation_reply.py @@ -83,7 +83,7 @@ class AnnotationReplyFeature: return annotation except Exception as e: - logger.warning(f"Query annotation failed, exception: {str(e)}.") + logger.warning("Query annotation failed, exception: %s.", str(e)) return None return None diff --git a/api/core/app/task_pipeline/based_generate_task_pipeline.py b/api/core/app/task_pipeline/based_generate_task_pipeline.py index 3ed0c3352f..8c0a442158 100644 --- a/api/core/app/task_pipeline/based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/based_generate_task_pipeline.py @@ -37,7 +37,7 @@ class BasedGenerateTaskPipeline: stream: bool, ) -> None: self._application_generate_entity = application_generate_entity - self._queue_manager = queue_manager + self.queue_manager = queue_manager self._start_at = time.perf_counter() self._output_moderation_handler = self._init_output_moderation() self._stream = stream @@ -52,7 +52,8 @@ class BasedGenerateTaskPipeline: elif isinstance(e, InvokeError | ValueError): err = e else: - err = Exception(e.description if getattr(e, "description", None) is not None else str(e)) + description = getattr(e, "description", None) + err = Exception(description if description is not None else str(e)) if not message_id or not session: return err @@ -113,7 +114,7 @@ class BasedGenerateTaskPipeline: tenant_id=app_config.tenant_id, app_id=app_config.app_id, rule=ModerationRule(type=sensitive_word_avoidance.type, config=sensitive_word_avoidance.config), - queue_manager=self._queue_manager, + queue_manager=self.queue_manager, ) return None diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py index 888434798a..471118c8cb 100644 --- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py @@ -57,6 +57,7 @@ from core.prompt.utils.prompt_message_util import PromptMessageUtil from core.prompt.utils.prompt_template_parser import PromptTemplateParser from events.message_event import message_was_created from extensions.ext_database import db +from libs.datetime_utils import naive_utc_now from models.model import AppMode, Conversation, Message, MessageAgentThought logger = logging.getLogger(__name__) @@ -257,7 +258,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): Process stream response. :return: """ - for message in self._queue_manager.listen(): + for message in self.queue_manager.listen(): if publisher: publisher.publish(message) event = message.event @@ -389,6 +390,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): if llm_result.message.content else "" ) + message.updated_at = naive_utc_now() message.answer_tokens = usage.completion_tokens message.answer_unit_price = usage.completion_unit_price message.answer_price_unit = usage.completion_price_unit @@ -499,7 +501,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): if self._output_moderation_handler.should_direct_output(): # stop subscribe new token when output moderation should direct output self._task_state.llm_result.message.content = self._output_moderation_handler.get_final_output() - self._queue_manager.publish( + self.queue_manager.publish( QueueLLMChunkEvent( chunk=LLMResultChunk( model=self._task_state.llm_result.model, @@ -513,7 +515,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): PublishFrom.TASK_PIPELINE, ) - self._queue_manager.publish( + self.queue_manager.publish( QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), PublishFrom.TASK_PIPELINE ) return True diff --git a/api/core/app/task_pipeline/message_cycle_manager.py b/api/core/app/task_pipeline/message_cycle_manager.py index 824da0b934..0d786ba051 100644 --- a/api/core/app/task_pipeline/message_cycle_manager.py +++ b/api/core/app/task_pipeline/message_cycle_manager.py @@ -23,6 +23,7 @@ from core.app.entities.task_entities import ( MessageFileStreamResponse, MessageReplaceStreamResponse, MessageStreamResponse, + StreamEvent, WorkflowTaskState, ) from core.llm_generator.llm_generator import LLMGenerator @@ -97,7 +98,7 @@ class MessageCycleManager: conversation.name = name except Exception as e: if dify_config.DEBUG: - logging.exception(f"generate conversation name failed, conversation_id: {conversation_id}") + logging.exception("generate conversation name failed, conversation_id: %s", conversation_id) pass db.session.merge(conversation) @@ -180,11 +181,15 @@ class MessageCycleManager: :param message_id: message id :return: """ + message_file = db.session.query(MessageFile).where(MessageFile.id == message_id).first() + event_type = StreamEvent.MESSAGE_FILE if message_file else StreamEvent.MESSAGE + return MessageStreamResponse( task_id=self._application_generate_entity.task_id, id=message_id, answer=answer, from_variable_selector=from_variable_selector, + event=event_type, ) def message_replace_to_stream_response(self, answer: str, reason: str = "") -> MessageReplaceStreamResponse: diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index af5c18e267..646e0e21e9 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -1,4 +1,3 @@ -import datetime import json import logging from collections import defaultdict @@ -29,6 +28,7 @@ from core.model_runtime.model_providers.__base.ai_model import AIModel from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from core.plugin.entities.plugin import ModelProviderID from extensions.ext_database import db +from libs.datetime_utils import naive_utc_now from models.provider import ( LoadBalancingModelConfig, Provider, @@ -261,7 +261,7 @@ class ProviderConfiguration(BaseModel): if provider_record: provider_record.encrypted_config = json.dumps(credentials) provider_record.is_valid = True - provider_record.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + provider_record.updated_at = naive_utc_now() db.session.commit() else: provider_record = Provider() @@ -426,7 +426,7 @@ class ProviderConfiguration(BaseModel): if provider_model_record: provider_model_record.encrypted_config = json.dumps(credentials) provider_model_record.is_valid = True - provider_model_record.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + provider_model_record.updated_at = naive_utc_now() db.session.commit() else: provider_model_record = ProviderModel() @@ -501,7 +501,7 @@ class ProviderConfiguration(BaseModel): if model_setting: model_setting.enabled = True - model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + model_setting.updated_at = naive_utc_now() db.session.commit() else: model_setting = ProviderModelSetting() @@ -526,7 +526,7 @@ class ProviderConfiguration(BaseModel): if model_setting: model_setting.enabled = False - model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + model_setting.updated_at = naive_utc_now() db.session.commit() else: model_setting = ProviderModelSetting() @@ -599,7 +599,7 @@ class ProviderConfiguration(BaseModel): if model_setting: model_setting.load_balancing_enabled = True - model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + model_setting.updated_at = naive_utc_now() db.session.commit() else: model_setting = ProviderModelSetting() @@ -638,7 +638,7 @@ class ProviderConfiguration(BaseModel): if model_setting: model_setting.load_balancing_enabled = False - model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + model_setting.updated_at = naive_utc_now() db.session.commit() else: model_setting = ProviderModelSetting() @@ -843,7 +843,7 @@ class ProviderConfiguration(BaseModel): continue status = ModelStatus.ACTIVE - if m.model in model_setting_map: + if m.model_type in model_setting_map and m.model in model_setting_map[m.model_type]: model_setting = model_setting_map[m.model_type][m.model] if model_setting.enabled is False: status = ModelStatus.DISABLED @@ -900,7 +900,7 @@ class ProviderConfiguration(BaseModel): credentials=copy_credentials, ) except Exception as ex: - logger.warning(f"get custom model schema failed, {ex}") + logger.warning("get custom model schema failed, %s", ex) continue if not custom_model_schema: @@ -1009,7 +1009,7 @@ class ProviderConfiguration(BaseModel): credentials=model_configuration.credentials, ) except Exception as ex: - logger.warning(f"get custom model schema failed, {ex}") + logger.warning("get custom model schema failed, %s", ex) continue if not custom_model_schema: diff --git a/api/core/entities/provider_entities.py b/api/core/entities/provider_entities.py index 2a0751a5ee..a5a6e62bd7 100644 --- a/api/core/entities/provider_entities.py +++ b/api/core/entities/provider_entities.py @@ -176,7 +176,7 @@ class ProviderConfig(BasicProviderConfig): scope: AppSelectorScope | ModelSelectorScope | ToolSelectorScope | None = None required: bool = False - default: Optional[Union[int, str]] = None + default: Optional[Union[int, str, float, bool]] = None options: Optional[list[Option]] = None label: Optional[I18nObject] = None help: Optional[I18nObject] = None diff --git a/api/core/extension/api_based_extension_requestor.py b/api/core/extension/api_based_extension_requestor.py index 3f4e20ec24..accccd8c40 100644 --- a/api/core/extension/api_based_extension_requestor.py +++ b/api/core/extension/api_based_extension_requestor.py @@ -22,7 +22,7 @@ class APIBasedExtensionRequestor: :param params: the request params :return: the response json """ - headers = {"Content-Type": "application/json", "Authorization": "Bearer {}".format(self.api_key)} + headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"} url = self.api_endpoint @@ -49,8 +49,6 @@ class APIBasedExtensionRequestor: raise ValueError("request connection error") if response.status_code != 200: - raise ValueError( - "request error, status_code: {}, content: {}".format(response.status_code, response.text[:100]) - ) + raise ValueError(f"request error, status_code: {response.status_code}, content: {response.text[:100]}") return cast(dict, response.json()) diff --git a/api/core/extension/extensible.py b/api/core/extension/extensible.py index 06fdb089d4..ae4671a381 100644 --- a/api/core/extension/extensible.py +++ b/api/core/extension/extensible.py @@ -17,7 +17,7 @@ class ExtensionModule(enum.Enum): class ModuleExtension(BaseModel): - extension_class: Any = None + extension_class: Optional[Any] = None name: str label: Optional[dict] = None form_schema: Optional[list] = None @@ -66,7 +66,7 @@ class Extensible: # Check for extension module file if (extension_name + ".py") not in file_names: - logging.warning(f"Missing {extension_name}.py file in {subdir_path}, Skip.") + logging.warning("Missing %s.py file in %s, Skip.", extension_name, subdir_path) continue # Check for builtin flag and position @@ -95,7 +95,7 @@ class Extensible: break if not extension_class: - logging.warning(f"Missing subclass of {cls.__name__} in {module_name}, Skip.") + logging.warning("Missing subclass of %s in %s, Skip.", cls.__name__, module_name) continue # Load schema if not builtin @@ -103,7 +103,7 @@ class Extensible: if not builtin: json_path = os.path.join(subdir_path, "schema.json") if not os.path.exists(json_path): - logging.warning(f"Missing schema.json file in {subdir_path}, Skip.") + logging.warning("Missing schema.json file in %s, Skip.", subdir_path) continue with open(json_path, encoding="utf-8") as f: diff --git a/api/core/extension/extension.py b/api/core/extension/extension.py index 9eb9e0306b..50c3f9b5f4 100644 --- a/api/core/extension/extension.py +++ b/api/core/extension/extension.py @@ -38,6 +38,7 @@ class Extension: def extension_class(self, module: ExtensionModule, extension_name: str) -> type: module_extension = self.module_extension(module, extension_name) + assert module_extension.extension_class is not None t: type = module_extension.extension_class return t diff --git a/api/core/external_data_tool/api/api.py b/api/core/external_data_tool/api/api.py index 2099a9e34c..d81f372d40 100644 --- a/api/core/external_data_tool/api/api.py +++ b/api/core/external_data_tool/api/api.py @@ -49,7 +49,7 @@ class ApiExternalDataTool(ExternalDataTool): """ # get params from config if not self.config: - raise ValueError("config is required, config: {}".format(self.config)) + raise ValueError(f"config is required, config: {self.config}") api_based_extension_id = self.config.get("api_based_extension_id") assert api_based_extension_id is not None, "api_based_extension_id is required" @@ -74,7 +74,7 @@ class ApiExternalDataTool(ExternalDataTool): # request api requestor = APIBasedExtensionRequestor(api_endpoint=api_based_extension.api_endpoint, api_key=api_key) except Exception as e: - raise ValueError("[External data tool] API query failed, variable: {}, error: {}".format(self.variable, e)) + raise ValueError(f"[External data tool] API query failed, variable: {self.variable}, error: {e}") response_json = requestor.request( point=APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY, @@ -90,7 +90,7 @@ class ApiExternalDataTool(ExternalDataTool): if not isinstance(response_json["result"], str): raise ValueError( - "[External data tool] API query failed, variable: {}, error: result is not string".format(self.variable) + f"[External data tool] API query failed, variable: {self.variable}, error: result is not string" ) return response_json["result"] diff --git a/api/core/file/file_manager.py b/api/core/file/file_manager.py index f8c050c2ac..770014aa72 100644 --- a/api/core/file/file_manager.py +++ b/api/core/file/file_manager.py @@ -32,7 +32,7 @@ def get_attr(*, file: File, attr: FileAttribute): case FileAttribute.TRANSFER_METHOD: return file.transfer_method.value case FileAttribute.URL: - return file.remote_url + return _to_url(file) case FileAttribute.EXTENSION: return file.extension case FileAttribute.RELATED_ID: diff --git a/api/core/helper/code_executor/template_transformer.py b/api/core/helper/code_executor/template_transformer.py index b416e48ce4..3965f8cb31 100644 --- a/api/core/helper/code_executor/template_transformer.py +++ b/api/core/helper/code_executor/template_transformer.py @@ -5,7 +5,7 @@ from base64 import b64encode from collections.abc import Mapping from typing import Any -from core.variables.utils import SegmentJSONEncoder +from core.variables.utils import dumps_with_segments class TemplateTransformer(ABC): @@ -93,7 +93,7 @@ class TemplateTransformer(ABC): @classmethod def serialize_inputs(cls, inputs: Mapping[str, Any]) -> str: - inputs_json_str = json.dumps(inputs, ensure_ascii=False, cls=SegmentJSONEncoder).encode() + inputs_json_str = dumps_with_segments(inputs, ensure_ascii=False).encode() input_base64_encoded = b64encode(inputs_json_str).decode("utf-8") return input_base64_encoded diff --git a/api/core/helper/moderation.py b/api/core/helper/moderation.py index a324ac2767..86bac4119a 100644 --- a/api/core/helper/moderation.py +++ b/api/core/helper/moderation.py @@ -55,7 +55,7 @@ def check_moderation(tenant_id: str, model_config: ModelConfigWithCredentialsEnt if moderation_result is True: return True except Exception: - logger.exception(f"Fails to check moderation, provider_name: {provider_name}") + logger.exception("Fails to check moderation, provider_name: %s", provider_name) raise InvokeBadRequestError("Rate limit exceeded, please try again later.") return False diff --git a/api/core/helper/module_import_helper.py b/api/core/helper/module_import_helper.py index 9a041667e4..251309fa2c 100644 --- a/api/core/helper/module_import_helper.py +++ b/api/core/helper/module_import_helper.py @@ -30,7 +30,7 @@ def import_module_from_source(*, module_name: str, py_file_path: AnyStr, use_laz spec.loader.exec_module(module) return module except Exception as e: - logging.exception(f"Failed to load module {module_name} from script file '{py_file_path!r}'") + logging.exception("Failed to load module %s from script file '%s'", module_name, repr(py_file_path)) raise e diff --git a/api/core/helper/ssrf_proxy.py b/api/core/helper/ssrf_proxy.py index 11f245812e..329527633c 100644 --- a/api/core/helper/ssrf_proxy.py +++ b/api/core/helper/ssrf_proxy.py @@ -73,10 +73,12 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): if response.status_code not in STATUS_FORCELIST: return response else: - logging.warning(f"Received status code {response.status_code} for URL {url} which is in the force list") + logging.warning( + "Received status code %s for URL %s which is in the force list", response.status_code, url + ) except httpx.RequestError as e: - logging.warning(f"Request to URL {url} failed on attempt {retries + 1}: {e}") + logging.warning("Request to URL %s failed on attempt %s: %s", url, retries + 1, e) if max_retries == 0: raise diff --git a/api/core/helper/trace_id_helper.py b/api/core/helper/trace_id_helper.py index e90c3194f2..5cd0ea5c66 100644 --- a/api/core/helper/trace_id_helper.py +++ b/api/core/helper/trace_id_helper.py @@ -1,3 +1,4 @@ +import contextlib import re from collections.abc import Mapping from typing import Any, Optional @@ -16,15 +17,33 @@ def get_external_trace_id(request: Any) -> Optional[str]: """ Retrieve the trace_id from the request. - Priority: header ('X-Trace-Id'), then parameters, then JSON body. Returns None if not provided or invalid. + Priority: + 1. header ('X-Trace-Id') + 2. parameters + 3. JSON body + 4. Current OpenTelemetry context (if enabled) + 5. OpenTelemetry traceparent header (if present and valid) + + Returns None if no valid trace_id is provided. """ trace_id = request.headers.get("X-Trace-Id") + if not trace_id: trace_id = request.args.get("trace_id") + if not trace_id and getattr(request, "is_json", False): json_data = getattr(request, "json", None) if json_data: trace_id = json_data.get("trace_id") + + if not trace_id: + trace_id = get_trace_id_from_otel_context() + + if not trace_id: + traceparent = request.headers.get("traceparent") + if traceparent: + trace_id = parse_traceparent_header(traceparent) + if isinstance(trace_id, str) and is_valid_trace_id(trace_id): return trace_id return None @@ -40,3 +59,47 @@ def extract_external_trace_id_from_args(args: Mapping[str, Any]) -> dict: if trace_id: return {"external_trace_id": trace_id} return {} + + +def get_trace_id_from_otel_context() -> Optional[str]: + """ + Retrieve the current trace ID from the active OpenTelemetry trace context. + Returns None if: + 1. OpenTelemetry SDK is not installed or enabled. + 2. There is no active span or trace context. + """ + try: + from opentelemetry.trace import SpanContext, get_current_span + from opentelemetry.trace.span import INVALID_TRACE_ID + + span = get_current_span() + if not span: + return None + + span_context: SpanContext = span.get_span_context() + + if not span_context or span_context.trace_id == INVALID_TRACE_ID: + return None + + trace_id_hex = f"{span_context.trace_id:032x}" + return trace_id_hex + + except Exception: + return None + + +def parse_traceparent_header(traceparent: str) -> Optional[str]: + """ + Parse the `traceparent` header to extract the trace_id. + + Expected format: + 'version-trace_id-span_id-flags' + + Reference: + W3C Trace Context Specification: https://www.w3.org/TR/trace-context/ + """ + with contextlib.suppress(Exception): + parts = traceparent.split("-") + if len(parts) == 4 and len(parts[1]) == 32: + return parts[1] + return None diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index fc5d0547fc..9876194608 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -1,5 +1,4 @@ import concurrent.futures -import datetime import json import logging import re @@ -9,7 +8,6 @@ import uuid from typing import Any, Optional, cast from flask import current_app -from flask_login import current_user from sqlalchemy.orm.exc import ObjectDeletedError from configs import dify_config @@ -30,11 +28,12 @@ from core.rag.splitter.fixed_text_splitter import ( FixedRecursiveCharacterTextSplitter, ) from core.rag.splitter.text_splitter import TextSplitter -from core.tools.utils.rag_web_reader import get_image_upload_file_ids +from core.tools.utils.web_reader_tool import get_image_upload_file_ids from extensions.ext_database import db from extensions.ext_redis import redis_client from extensions.ext_storage import storage from libs import helper +from libs.datetime_utils import naive_utc_now from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment from models.dataset import Document as DatasetDocument from models.model import UploadFile @@ -84,19 +83,19 @@ class IndexingRunner: documents=documents, ) except DocumentIsPausedError: - raise DocumentIsPausedError("Document paused, document id: {}".format(dataset_document.id)) + raise DocumentIsPausedError(f"Document paused, document id: {dataset_document.id}") except ProviderTokenNotInitError as e: dataset_document.indexing_status = "error" dataset_document.error = str(e.description) - dataset_document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + dataset_document.stopped_at = naive_utc_now() db.session.commit() except ObjectDeletedError: - logging.warning("Document deleted, document id: {}".format(dataset_document.id)) + logging.warning("Document deleted, document id: %s", dataset_document.id) except Exception as e: logging.exception("consume document failed") dataset_document.indexing_status = "error" dataset_document.error = str(e) - dataset_document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + dataset_document.stopped_at = naive_utc_now() db.session.commit() def run_in_splitting_status(self, dataset_document: DatasetDocument): @@ -147,17 +146,17 @@ class IndexingRunner: index_processor=index_processor, dataset=dataset, dataset_document=dataset_document, documents=documents ) except DocumentIsPausedError: - raise DocumentIsPausedError("Document paused, document id: {}".format(dataset_document.id)) + raise DocumentIsPausedError(f"Document paused, document id: {dataset_document.id}") except ProviderTokenNotInitError as e: dataset_document.indexing_status = "error" dataset_document.error = str(e.description) - dataset_document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + dataset_document.stopped_at = naive_utc_now() db.session.commit() except Exception as e: logging.exception("consume document failed") dataset_document.indexing_status = "error" dataset_document.error = str(e) - dataset_document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + dataset_document.stopped_at = naive_utc_now() db.session.commit() def run_in_indexing_status(self, dataset_document: DatasetDocument): @@ -222,17 +221,17 @@ class IndexingRunner: index_processor=index_processor, dataset=dataset, dataset_document=dataset_document, documents=documents ) except DocumentIsPausedError: - raise DocumentIsPausedError("Document paused, document id: {}".format(dataset_document.id)) + raise DocumentIsPausedError(f"Document paused, document id: {dataset_document.id}") except ProviderTokenNotInitError as e: dataset_document.indexing_status = "error" dataset_document.error = str(e.description) - dataset_document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + dataset_document.stopped_at = naive_utc_now() db.session.commit() except Exception as e: logging.exception("consume document failed") dataset_document.indexing_status = "error" dataset_document.error = str(e) - dataset_document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + dataset_document.stopped_at = naive_utc_now() db.session.commit() def indexing_estimate( @@ -295,7 +294,7 @@ class IndexingRunner: text_docs, embedding_model_instance=embedding_model_instance, process_rule=processing_rule.to_dict(), - tenant_id=current_user.current_tenant_id, + tenant_id=tenant_id, doc_language=doc_language, preview=True, ) @@ -324,7 +323,8 @@ class IndexingRunner: except Exception: logging.exception( "Delete image_files failed while indexing_estimate, \ - image_upload_file_is: {}".format(upload_file_id) + image_upload_file_is: %s", + upload_file_id, ) db.session.delete(image_file) @@ -400,7 +400,7 @@ class IndexingRunner: after_indexing_status="splitting", extra_update_params={ DatasetDocument.word_count: sum(len(text_doc.page_content) for text_doc in text_docs), - DatasetDocument.parsing_completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None), + DatasetDocument.parsing_completed_at: naive_utc_now(), }, ) @@ -583,7 +583,7 @@ class IndexingRunner: after_indexing_status="completed", extra_update_params={ DatasetDocument.tokens: tokens, - DatasetDocument.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None), + DatasetDocument.completed_at: naive_utc_now(), DatasetDocument.indexing_latency: indexing_end_at - indexing_start_at, DatasetDocument.error: None, }, @@ -608,7 +608,7 @@ class IndexingRunner: { DocumentSegment.status: "completed", DocumentSegment.enabled: True, - DocumentSegment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None), + DocumentSegment.completed_at: naive_utc_now(), } ) @@ -639,7 +639,7 @@ class IndexingRunner: { DocumentSegment.status: "completed", DocumentSegment.enabled: True, - DocumentSegment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None), + DocumentSegment.completed_at: naive_utc_now(), } ) @@ -649,7 +649,7 @@ class IndexingRunner: @staticmethod def _check_document_paused_status(document_id: str): - indexing_cache_key = "document_{}_is_paused".format(document_id) + indexing_cache_key = f"document_{document_id}_is_paused" result = redis_client.get(indexing_cache_key) if result: raise DocumentIsPausedError() @@ -727,7 +727,7 @@ class IndexingRunner: doc_store.add_documents(docs=documents, save_child=dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX) # update document status to indexing - cur_time = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + cur_time = naive_utc_now() self._update_document_index_status( document_id=dataset_document.id, after_indexing_status="indexing", @@ -742,7 +742,7 @@ class IndexingRunner: dataset_document_id=dataset_document.id, update_params={ DocumentSegment.status: "indexing", - DocumentSegment.indexing_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None), + DocumentSegment.indexing_at: naive_utc_now(), }, ) pass diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py index 331ac933c8..8c1d171688 100644 --- a/api/core/llm_generator/llm_generator.py +++ b/api/core/llm_generator/llm_generator.py @@ -1,6 +1,7 @@ import json import logging import re +from collections.abc import Sequence from typing import Optional, cast import json_repair @@ -11,6 +12,8 @@ from core.llm_generator.prompts import ( CONVERSATION_TITLE_PROMPT, GENERATOR_QA_PROMPT, JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE, + LLM_MODIFY_CODE_SYSTEM, + LLM_MODIFY_PROMPT_SYSTEM, PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE, SYSTEM_STRUCTURED_OUTPUT_GENERATE, WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE, @@ -24,6 +27,9 @@ from core.ops.entities.trace_entity import TraceTaskName from core.ops.ops_trace_manager import TraceQueueManager, TraceTask from core.ops.utils import measure_time from core.prompt.utils.prompt_template_parser import PromptTemplateParser +from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey +from core.workflow.graph_engine.entities.event import AgentLogEvent +from models import App, Message, WorkflowNodeExecutionModel, db class LLMGenerator: @@ -125,16 +131,13 @@ class LLMGenerator: return questions @classmethod - def generate_rule_config( - cls, tenant_id: str, instruction: str, model_config: dict, no_variable: bool, rule_config_max_tokens: int = 512 - ) -> dict: + def generate_rule_config(cls, tenant_id: str, instruction: str, model_config: dict, no_variable: bool) -> dict: output_parser = RuleConfigGeneratorOutputParser() error = "" error_step = "" rule_config = {"prompt": "", "variables": [], "opening_statement": "", "error": ""} - model_parameters = {"max_tokens": rule_config_max_tokens, "temperature": 0.01} - + model_parameters = model_config.get("completion_params", {}) if no_variable: prompt_template = PromptTemplateParser(WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE) @@ -170,7 +173,7 @@ class LLMGenerator: error = str(e) error_step = "generate rule config" except Exception as e: - logging.exception(f"Failed to generate rule config, model: {model_config.get('name')}") + logging.exception("Failed to generate rule config, model: %s", model_config.get("name")) rule_config["error"] = str(e) rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else "" @@ -267,7 +270,7 @@ class LLMGenerator: error_step = "generate conversation opener" except Exception as e: - logging.exception(f"Failed to generate rule config, model: {model_config.get('name')}") + logging.exception("Failed to generate rule config, model: %s", model_config.get("name")) rule_config["error"] = str(e) rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else "" @@ -276,12 +279,7 @@ class LLMGenerator: @classmethod def generate_code( - cls, - tenant_id: str, - instruction: str, - model_config: dict, - code_language: str = "javascript", - max_tokens: int = 1000, + cls, tenant_id: str, instruction: str, model_config: dict, code_language: str = "javascript" ) -> dict: if code_language == "python": prompt_template = PromptTemplateParser(PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE) @@ -305,8 +303,7 @@ class LLMGenerator: ) prompt_messages = [UserPromptMessage(content=prompt)] - model_parameters = {"max_tokens": max_tokens, "temperature": 0.01} - + model_parameters = model_config.get("completion_params", {}) try: response = cast( LLMResult, @@ -323,7 +320,7 @@ class LLMGenerator: return {"code": "", "language": code_language, "error": f"Failed to generate code. Error: {error}"} except Exception as e: logging.exception( - f"Failed to invoke LLM model, model: {model_config.get('name')}, language: {code_language}" + "Failed to invoke LLM model, model: %s, language: %s", model_config.get("name"), code_language ) return {"code": "", "language": code_language, "error": f"An unexpected error occurred: {str(e)}"} @@ -395,5 +392,183 @@ class LLMGenerator: error = str(e) return {"output": "", "error": f"Failed to generate JSON Schema. Error: {error}"} except Exception as e: - logging.exception(f"Failed to invoke LLM model, model: {model_config.get('name')}") + logging.exception("Failed to invoke LLM model, model: %s", model_config.get("name")) return {"output": "", "error": f"An unexpected error occurred: {str(e)}"} + + @staticmethod + def instruction_modify_legacy( + tenant_id: str, flow_id: str, current: str, instruction: str, model_config: dict, ideal_output: str | None + ) -> dict: + app: App | None = db.session.query(App).where(App.id == flow_id).first() + last_run: Message | None = ( + db.session.query(Message).where(Message.app_id == flow_id).order_by(Message.created_at.desc()).first() + ) + if not last_run: + return LLMGenerator.__instruction_modify_common( + tenant_id=tenant_id, + model_config=model_config, + last_run=None, + current=current, + error_message="", + instruction=instruction, + node_type="llm", + ideal_output=ideal_output, + ) + last_run_dict = { + "query": last_run.query, + "answer": last_run.answer, + "error": last_run.error, + } + return LLMGenerator.__instruction_modify_common( + tenant_id=tenant_id, + model_config=model_config, + last_run=last_run_dict, + current=current, + error_message=str(last_run.error), + instruction=instruction, + node_type="llm", + ideal_output=ideal_output, + ) + + @staticmethod + def instruction_modify_workflow( + tenant_id: str, + flow_id: str, + node_id: str, + current: str, + instruction: str, + model_config: dict, + ideal_output: str | None, + ) -> dict: + from services.workflow_service import WorkflowService + + app: App | None = db.session.query(App).where(App.id == flow_id).first() + if not app: + raise ValueError("App not found.") + workflow = WorkflowService().get_draft_workflow(app_model=app) + if not workflow: + raise ValueError("Workflow not found for the given app model.") + last_run = WorkflowService().get_node_last_run(app_model=app, workflow=workflow, node_id=node_id) + try: + node_type = cast(WorkflowNodeExecutionModel, last_run).node_type + except Exception: + try: + node_type = [it for it in workflow.graph_dict["graph"]["nodes"] if it["id"] == node_id][0]["data"][ + "type" + ] + except Exception: + node_type = "llm" + + if not last_run: # Node is not executed yet + return LLMGenerator.__instruction_modify_common( + tenant_id=tenant_id, + model_config=model_config, + last_run=None, + current=current, + error_message="", + instruction=instruction, + node_type=node_type, + ideal_output=ideal_output, + ) + + def agent_log_of(node_execution: WorkflowNodeExecutionModel) -> Sequence: + raw_agent_log = node_execution.execution_metadata_dict.get(WorkflowNodeExecutionMetadataKey.AGENT_LOG) + if not raw_agent_log: + return [] + parsed: Sequence[AgentLogEvent] = json.loads(raw_agent_log) + + def dict_of_event(event: AgentLogEvent) -> dict: + return { + "status": event.status, + "error": event.error, + "data": event.data, + } + + return [dict_of_event(event) for event in parsed] + + last_run_dict = { + "inputs": last_run.inputs_dict, + "status": last_run.status, + "error": last_run.error, + "agent_log": agent_log_of(last_run), + } + + return LLMGenerator.__instruction_modify_common( + tenant_id=tenant_id, + model_config=model_config, + last_run=last_run_dict, + current=current, + error_message=last_run.error, + instruction=instruction, + node_type=last_run.node_type, + ideal_output=ideal_output, + ) + + @staticmethod + def __instruction_modify_common( + tenant_id: str, + model_config: dict, + last_run: dict | None, + current: str | None, + error_message: str | None, + instruction: str, + node_type: str, + ideal_output: str | None, + ) -> dict: + LAST_RUN = "{{#last_run#}}" + CURRENT = "{{#current#}}" + ERROR_MESSAGE = "{{#error_message#}}" + injected_instruction = instruction + if LAST_RUN in injected_instruction: + injected_instruction = injected_instruction.replace(LAST_RUN, json.dumps(last_run)) + if CURRENT in injected_instruction: + injected_instruction = injected_instruction.replace(CURRENT, current or "null") + if ERROR_MESSAGE in injected_instruction: + injected_instruction = injected_instruction.replace(ERROR_MESSAGE, error_message or "null") + model_instance = ModelManager().get_model_instance( + tenant_id=tenant_id, + model_type=ModelType.LLM, + provider=model_config.get("provider", ""), + model=model_config.get("name", ""), + ) + match node_type: + case "llm" | "agent": + system_prompt = LLM_MODIFY_PROMPT_SYSTEM + case "code": + system_prompt = LLM_MODIFY_CODE_SYSTEM + case _: + system_prompt = LLM_MODIFY_PROMPT_SYSTEM + prompt_messages = [ + SystemPromptMessage(content=system_prompt), + UserPromptMessage( + content=json.dumps( + { + "current": current, + "last_run": last_run, + "instruction": injected_instruction, + "ideal_output": ideal_output, + } + ) + ), + ] + model_parameters = {"temperature": 0.4} + + try: + response = cast( + LLMResult, + model_instance.invoke_llm( + prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False + ), + ) + + generated_raw = cast(str, response.message.content) + first_brace = generated_raw.find("{") + last_brace = generated_raw.rfind("}") + return {**json.loads(generated_raw[first_brace : last_brace + 1])} + + except InvokeError as e: + error = str(e) + return {"error": f"Failed to generate code. Error: {error}"} + except Exception as e: + logging.exception("Failed to invoke LLM model, model: " + json.dumps(model_config.get("name")), exc_info=e) + return {"error": f"An unexpected error occurred: {str(e)}"} diff --git a/api/core/llm_generator/prompts.py b/api/core/llm_generator/prompts.py index ef81e38dc5..9268347526 100644 --- a/api/core/llm_generator/prompts.py +++ b/api/core/llm_generator/prompts.py @@ -309,3 +309,116 @@ eg: Here is the JSON schema: {{schema}} """ # noqa: E501 + +LLM_MODIFY_PROMPT_SYSTEM = """ +Both your input and output should be in JSON format. + +! Below is the schema for input content ! +{ + "type": "object", + "description": "The user is trying to process some content with a prompt, but the output is not as expected. They hope to achieve their goal by modifying the prompt.", + "properties": { + "current": { + "type": "string", + "description": "The prompt before modification, where placeholders {{}} will be replaced with actual values for the large language model. The content in the placeholders should not be changed." + }, + "last_run": { + "type": "object", + "description": "The output result from the large language model after receiving the prompt.", + }, + "instruction": { + "type": "string", + "description": "User's instruction to edit the current prompt" + }, + "ideal_output": { + "type": "string", + "description": "The ideal output that the user expects from the large language model after modifying the prompt. You should compare the last output with the ideal output and make changes to the prompt to achieve the goal." + } + } +} +! Above is the schema for input content ! + +! Below is the schema for output content ! +{ + "type": "object", + "description": "Your feedback to the user after they provide modification suggestions.", + "properties": { + "modified": { + "type": "string", + "description": "Your modified prompt. You should change the original prompt as little as possible to achieve the goal. Keep the language of prompt if not asked to change" + }, + "message": { + "type": "string", + "description": "Your feedback to the user, in the user's language, explaining what you did and your thought process in text, providing sufficient emotional value to the user." + } + }, + "required": [ + "modified", + "message" + ] +} +! Above is the schema for output content ! + +Your output must strictly follow the schema format, do not output any content outside of the JSON body. +""" # noqa: E501 + +LLM_MODIFY_CODE_SYSTEM = """ +Both your input and output should be in JSON format. + +! Below is the schema for input content ! +{ + "type": "object", + "description": "The user is trying to process some data with a code snippet, but the result is not as expected. They hope to achieve their goal by modifying the code.", + "properties": { + "current": { + "type": "string", + "description": "The code before modification." + }, + "last_run": { + "type": "object", + "description": "The result of the code.", + }, + "message": { + "type": "string", + "description": "User's instruction to edit the current code" + } + } +} +! Above is the schema for input content ! + +! Below is the schema for output content ! +{ + "type": "object", + "description": "Your feedback to the user after they provide modification suggestions.", + "properties": { + "modified": { + "type": "string", + "description": "Your modified code. You should change the original code as little as possible to achieve the goal. Keep the programming language of code if not asked to change" + }, + "message": { + "type": "string", + "description": "Your feedback to the user, in the user's language, explaining what you did and your thought process in text, providing sufficient emotional value to the user." + } + }, + "required": [ + "modified", + "message" + ] +} +! Above is the schema for output content ! + +When you are modifying the code, you should remember: +- Do not use print, this not work in dify sandbox. +- Do not try dangerous call like deleting files. It's PROHIBITED. +- Do not use any library that is not built-in in with Python. +- Get inputs from the parameters of the function and have explicit type annotations. +- Write proper imports at the top of the code. +- Use return statement to return the result. +- You should return a `dict`. If you need to return a `result: str`, you should `return {"result": result}`. +Your output must strictly follow the schema format, do not output any content outside of the JSON body. +""" # noqa: E501 + +INSTRUCTION_GENERATE_TEMPLATE_PROMPT = """The output of this prompt is not as expected: {{#last_run#}}. +You should edit the prompt according to the IDEAL OUTPUT.""" + +INSTRUCTION_GENERATE_TEMPLATE_CODE = """Please fix the errors in the {{#error_message#}}.""" diff --git a/api/core/mcp/auth/auth_flow.py b/api/core/mcp/auth/auth_flow.py index bcb31a816f..eb783297c3 100644 --- a/api/core/mcp/auth/auth_flow.py +++ b/api/core/mcp/auth/auth_flow.py @@ -5,9 +5,9 @@ import os import secrets import urllib.parse from typing import Optional -from urllib.parse import urljoin +from urllib.parse import urljoin, urlparse -import requests +import httpx from pydantic import BaseModel, ValidationError from core.mcp.auth.auth_provider import OAuthClientProvider @@ -99,24 +99,52 @@ def handle_callback(state_key: str, authorization_code: str) -> OAuthCallbackSta return full_state_data +def check_support_resource_discovery(server_url: str) -> tuple[bool, str]: + """Check if the server supports OAuth 2.0 Resource Discovery.""" + b_scheme, b_netloc, b_path, b_params, b_query, b_fragment = urlparse(server_url, "", True) + url_for_resource_discovery = f"{b_scheme}://{b_netloc}/.well-known/oauth-protected-resource{b_path}" + if b_query: + url_for_resource_discovery += f"?{b_query}" + if b_fragment: + url_for_resource_discovery += f"#{b_fragment}" + try: + headers = {"MCP-Protocol-Version": LATEST_PROTOCOL_VERSION, "User-Agent": "Dify"} + response = httpx.get(url_for_resource_discovery, headers=headers) + if 200 <= response.status_code < 300: + body = response.json() + if "authorization_server_url" in body: + return True, body["authorization_server_url"][0] + else: + return False, "" + return False, "" + except httpx.RequestError as e: + # Not support resource discovery, fall back to well-known OAuth metadata + return False, "" + + def discover_oauth_metadata(server_url: str, protocol_version: Optional[str] = None) -> Optional[OAuthMetadata]: """Looks up RFC 8414 OAuth 2.0 Authorization Server Metadata.""" - url = urljoin(server_url, "/.well-known/oauth-authorization-server") + # First check if the server supports OAuth 2.0 Resource Discovery + support_resource_discovery, oauth_discovery_url = check_support_resource_discovery(server_url) + if support_resource_discovery: + url = oauth_discovery_url + else: + url = urljoin(server_url, "/.well-known/oauth-authorization-server") try: headers = {"MCP-Protocol-Version": protocol_version or LATEST_PROTOCOL_VERSION} - response = requests.get(url, headers=headers) + response = httpx.get(url, headers=headers) if response.status_code == 404: return None - if not response.ok: + if not response.is_success: raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata") return OAuthMetadata.model_validate(response.json()) - except requests.RequestException as e: - if isinstance(e, requests.ConnectionError): - response = requests.get(url) + except httpx.RequestError as e: + if isinstance(e, httpx.ConnectError): + response = httpx.get(url) if response.status_code == 404: return None - if not response.ok: + if not response.is_success: raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata") return OAuthMetadata.model_validate(response.json()) raise @@ -206,8 +234,8 @@ def exchange_authorization( if client_information.client_secret: params["client_secret"] = client_information.client_secret - response = requests.post(token_url, data=params) - if not response.ok: + response = httpx.post(token_url, data=params) + if not response.is_success: raise ValueError(f"Token exchange failed: HTTP {response.status_code}") return OAuthTokens.model_validate(response.json()) @@ -237,8 +265,8 @@ def refresh_authorization( if client_information.client_secret: params["client_secret"] = client_information.client_secret - response = requests.post(token_url, data=params) - if not response.ok: + response = httpx.post(token_url, data=params) + if not response.is_success: raise ValueError(f"Token refresh failed: HTTP {response.status_code}") return OAuthTokens.model_validate(response.json()) @@ -256,12 +284,12 @@ def register_client( else: registration_url = urljoin(server_url, "/register") - response = requests.post( + response = httpx.post( registration_url, json=client_metadata.model_dump(), headers={"Content-Type": "application/json"}, ) - if not response.ok: + if not response.is_success: response.raise_for_status() return OAuthClientInformationFull.model_validate(response.json()) @@ -283,7 +311,7 @@ def auth( raise ValueError("Existing OAuth client information is required when exchanging an authorization code") try: full_information = register_client(server_url, metadata, provider.client_metadata) - except requests.RequestException as e: + except httpx.RequestError as e: raise ValueError(f"Could not register OAuth client: {e}") provider.save_client_information(full_information) client_information = full_information diff --git a/api/core/mcp/auth/auth_provider.py b/api/core/mcp/auth/auth_provider.py index 00d5a25956..bad99fc092 100644 --- a/api/core/mcp/auth/auth_provider.py +++ b/api/core/mcp/auth/auth_provider.py @@ -10,8 +10,6 @@ from core.mcp.types import ( from models.tools import MCPToolProvider from services.tools.mcp_tools_manage_service import MCPToolManageService -LATEST_PROTOCOL_VERSION = "1.0" - class OAuthClientProvider: mcp_provider: MCPToolProvider diff --git a/api/core/mcp/client/sse_client.py b/api/core/mcp/client/sse_client.py index 91debcc8f9..cc38954eca 100644 --- a/api/core/mcp/client/sse_client.py +++ b/api/core/mcp/client/sse_client.py @@ -7,6 +7,7 @@ from typing import Any, TypeAlias, final from urllib.parse import urljoin, urlparse import httpx +from httpx_sse import EventSource, ServerSentEvent from sseclient import SSEClient from core.mcp import types @@ -37,11 +38,6 @@ WriteQueue: TypeAlias = queue.Queue[SessionMessage | Exception | None] StatusQueue: TypeAlias = queue.Queue[_StatusReady | _StatusError] -def remove_request_params(url: str) -> str: - """Remove request parameters from URL, keeping only the path.""" - return urljoin(url, urlparse(url).path) - - class SSETransport: """SSE client transport implementation.""" @@ -88,7 +84,7 @@ class SSETransport: status_queue: Queue to put status updates. """ endpoint_url = urljoin(self.url, sse_data) - logger.info(f"Received endpoint URL: {endpoint_url}") + logger.info("Received endpoint URL: %s", endpoint_url) if not self._validate_endpoint_url(endpoint_url): error_msg = f"Endpoint origin does not match connection origin: {endpoint_url}" @@ -107,14 +103,14 @@ class SSETransport: """ try: message = types.JSONRPCMessage.model_validate_json(sse_data) - logger.debug(f"Received server message: {message}") + logger.debug("Received server message: %s", message) session_message = SessionMessage(message) read_queue.put(session_message) except Exception as exc: logger.exception("Error parsing server message") read_queue.put(exc) - def _handle_sse_event(self, sse, read_queue: ReadQueue, status_queue: StatusQueue) -> None: + def _handle_sse_event(self, sse: ServerSentEvent, read_queue: ReadQueue, status_queue: StatusQueue) -> None: """Handle a single SSE event. Args: @@ -128,9 +124,9 @@ class SSETransport: case "message": self._handle_message_event(sse.data, read_queue) case _: - logger.warning(f"Unknown SSE event: {sse.event}") + logger.warning("Unknown SSE event: %s", sse.event) - def sse_reader(self, event_source, read_queue: ReadQueue, status_queue: StatusQueue) -> None: + def sse_reader(self, event_source: EventSource, read_queue: ReadQueue, status_queue: StatusQueue) -> None: """Read and process SSE events. Args: @@ -142,7 +138,7 @@ class SSETransport: for sse in event_source.iter_sse(): self._handle_sse_event(sse, read_queue, status_queue) except httpx.ReadError as exc: - logger.debug(f"SSE reader shutting down normally: {exc}") + logger.debug("SSE reader shutting down normally: %s", exc) except Exception as exc: read_queue.put(exc) finally: @@ -165,7 +161,7 @@ class SSETransport: ), ) response.raise_for_status() - logger.debug(f"Client message sent successfully: {response.status_code}") + logger.debug("Client message sent successfully: %s", response.status_code) def post_writer(self, client: httpx.Client, endpoint_url: str, write_queue: WriteQueue) -> None: """Handle writing messages to the server. @@ -190,7 +186,7 @@ class SSETransport: except queue.Empty: continue except httpx.ReadError as exc: - logger.debug(f"Post writer shutting down normally: {exc}") + logger.debug("Post writer shutting down normally: %s", exc) except Exception as exc: logger.exception("Error writing messages") write_queue.put(exc) @@ -225,7 +221,7 @@ class SSETransport: self, executor: ThreadPoolExecutor, client: httpx.Client, - event_source, + event_source: EventSource, ) -> tuple[ReadQueue, WriteQueue]: """Establish connection and start worker threads. @@ -326,8 +322,8 @@ def send_message(http_client: httpx.Client, endpoint_url: str, session_message: ), ) response.raise_for_status() - logger.debug(f"Client message sent successfully: {response.status_code}") - except Exception as exc: + logger.debug("Client message sent successfully: %s", response.status_code) + except Exception: logger.exception("Error sending message") raise @@ -349,13 +345,13 @@ def read_messages( if sse.event == "message": try: message = types.JSONRPCMessage.model_validate_json(sse.data) - logger.debug(f"Received server message: {message}") + logger.debug("Received server message: %s", message) yield SessionMessage(message) except Exception as exc: logger.exception("Error parsing server message") yield exc else: - logger.warning(f"Unknown SSE event: {sse.event}") + logger.warning("Unknown SSE event: %s", sse.event) except Exception as exc: logger.exception("Error reading SSE messages") yield exc diff --git a/api/core/mcp/client/streamable_client.py b/api/core/mcp/client/streamable_client.py index fbd8d05f9e..14e346c2f3 100644 --- a/api/core/mcp/client/streamable_client.py +++ b/api/core/mcp/client/streamable_client.py @@ -55,14 +55,10 @@ DEFAULT_QUEUE_READ_TIMEOUT = 3 class StreamableHTTPError(Exception): """Base exception for StreamableHTTP transport errors.""" - pass - class ResumptionError(StreamableHTTPError): """Raised when resumption request is invalid.""" - pass - @dataclass class RequestContext: @@ -74,7 +70,7 @@ class RequestContext: session_message: SessionMessage metadata: ClientMessageMetadata | None server_to_client_queue: ServerToClientQueue # Renamed for clarity - sse_read_timeout: timedelta + sse_read_timeout: float class StreamableHTTPTransport: @@ -84,8 +80,8 @@ class StreamableHTTPTransport: self, url: str, headers: dict[str, Any] | None = None, - timeout: timedelta = timedelta(seconds=30), - sse_read_timeout: timedelta = timedelta(seconds=60 * 5), + timeout: float | timedelta = 30, + sse_read_timeout: float | timedelta = 60 * 5, ) -> None: """Initialize the StreamableHTTP transport. @@ -97,8 +93,10 @@ class StreamableHTTPTransport: """ self.url = url self.headers = headers or {} - self.timeout = timeout - self.sse_read_timeout = sse_read_timeout + self.timeout = timeout.total_seconds() if isinstance(timeout, timedelta) else timeout + self.sse_read_timeout = ( + sse_read_timeout.total_seconds() if isinstance(sse_read_timeout, timedelta) else sse_read_timeout + ) self.session_id: str | None = None self.request_headers = { ACCEPT: f"{JSON}, {SSE}", @@ -129,7 +127,7 @@ class StreamableHTTPTransport: new_session_id = response.headers.get(MCP_SESSION_ID) if new_session_id: self.session_id = new_session_id - logger.info(f"Received session ID: {self.session_id}") + logger.info("Received session ID: %s", self.session_id) def _handle_sse_event( self, @@ -142,7 +140,7 @@ class StreamableHTTPTransport: if sse.event == "message": try: message = JSONRPCMessage.model_validate_json(sse.data) - logger.debug(f"SSE message: {message}") + logger.debug("SSE message: %s", message) # If this is a response and we have original_request_id, replace it if original_request_id is not None and isinstance(message.root, JSONRPCResponse | JSONRPCError): @@ -168,7 +166,7 @@ class StreamableHTTPTransport: logger.debug("Received ping event") return False else: - logger.warning(f"Unknown SSE event: {sse.event}") + logger.warning("Unknown SSE event: %s", sse.event) return False def handle_get_stream( @@ -186,7 +184,7 @@ class StreamableHTTPTransport: with ssrf_proxy_sse_connect( self.url, headers=headers, - timeout=httpx.Timeout(self.timeout.seconds, read=self.sse_read_timeout.seconds), + timeout=httpx.Timeout(self.timeout, read=self.sse_read_timeout), client=client, method="GET", ) as event_source: @@ -197,7 +195,7 @@ class StreamableHTTPTransport: self._handle_sse_event(sse, server_to_client_queue) except Exception as exc: - logger.debug(f"GET stream error (non-fatal): {exc}") + logger.debug("GET stream error (non-fatal): %s", exc) def _handle_resumption_request(self, ctx: RequestContext) -> None: """Handle a resumption request using GET with SSE.""" @@ -215,7 +213,7 @@ class StreamableHTTPTransport: with ssrf_proxy_sse_connect( self.url, headers=headers, - timeout=httpx.Timeout(self.timeout.seconds, read=ctx.sse_read_timeout.seconds), + timeout=httpx.Timeout(self.timeout, read=self.sse_read_timeout), client=ctx.client, method="GET", ) as event_source: @@ -352,7 +350,7 @@ class StreamableHTTPTransport: # Check if this is a resumption request is_resumption = bool(metadata and metadata.resumption_token) - logger.debug(f"Sending client message: {message}") + logger.debug("Sending client message: %s", message) # Handle initialized notification if self._is_initialized_notification(message): @@ -389,9 +387,9 @@ class StreamableHTTPTransport: if response.status_code == 405: logger.debug("Server does not allow session termination") elif response.status_code != 200: - logger.warning(f"Session termination failed: {response.status_code}") + logger.warning("Session termination failed: %s", response.status_code) except Exception as exc: - logger.warning(f"Session termination failed: {exc}") + logger.warning("Session termination failed: %s", exc) def get_session_id(self) -> str | None: """Get the current session ID.""" @@ -402,8 +400,8 @@ class StreamableHTTPTransport: def streamablehttp_client( url: str, headers: dict[str, Any] | None = None, - timeout: timedelta = timedelta(seconds=30), - sse_read_timeout: timedelta = timedelta(seconds=60 * 5), + timeout: float | timedelta = 30, + sse_read_timeout: float | timedelta = 60 * 5, terminate_on_close: bool = True, ) -> Generator[ tuple[ @@ -436,7 +434,7 @@ def streamablehttp_client( try: with create_ssrf_proxy_mcp_http_client( headers=transport.request_headers, - timeout=httpx.Timeout(transport.timeout.seconds, read=transport.sse_read_timeout.seconds), + timeout=httpx.Timeout(transport.timeout, read=transport.sse_read_timeout), ) as client: # Define callbacks that need access to thread pool def start_get_stream() -> None: diff --git a/api/core/mcp/mcp_client.py b/api/core/mcp/mcp_client.py index 5fe52c008a..7d90d51956 100644 --- a/api/core/mcp/mcp_client.py +++ b/api/core/mcp/mcp_client.py @@ -23,12 +23,18 @@ class MCPClient: authed: bool = True, authorization_code: Optional[str] = None, for_list: bool = False, + headers: Optional[dict[str, str]] = None, + timeout: Optional[float] = None, + sse_read_timeout: Optional[float] = None, ): # Initialize info self.provider_id = provider_id self.tenant_id = tenant_id self.client_type = "streamable" self.server_url = server_url + self.headers = headers or {} + self.timeout = timeout + self.sse_read_timeout = sse_read_timeout # Authentication info self.authed = authed @@ -43,7 +49,7 @@ class MCPClient: self._session: Optional[ClientSession] = None self._streams_context: Optional[AbstractContextManager[Any]] = None self._session_context: Optional[ClientSession] = None - self.exit_stack = ExitStack() + self._exit_stack = ExitStack() # Whether the client has been initialized self._initialized = False @@ -75,7 +81,7 @@ class MCPClient: self.connect_server(client_factory, method_name) else: try: - logger.debug(f"Not supported method {method_name} found in URL path, trying default 'mcp' method.") + logger.debug("Not supported method %s found in URL path, trying default 'mcp' method.", method_name) self.connect_server(sse_client, "sse") except MCPConnectionError: logger.debug("MCP connection failed with 'sse', falling back to 'mcp' method.") @@ -90,21 +96,26 @@ class MCPClient: headers = ( {"Authorization": f"{self.token.token_type.capitalize()} {self.token.access_token}"} if self.authed and self.token - else {} + else self.headers + ) + self._streams_context = client_factory( + url=self.server_url, + headers=headers, + timeout=self.timeout, + sse_read_timeout=self.sse_read_timeout, ) - self._streams_context = client_factory(url=self.server_url, headers=headers) if not self._streams_context: raise MCPConnectionError("Failed to create connection context") # Use exit_stack to manage context managers properly if method_name == "mcp": - read_stream, write_stream, _ = self.exit_stack.enter_context(self._streams_context) + read_stream, write_stream, _ = self._exit_stack.enter_context(self._streams_context) streams = (read_stream, write_stream) else: # sse_client - streams = self.exit_stack.enter_context(self._streams_context) + streams = self._exit_stack.enter_context(self._streams_context) self._session_context = ClientSession(*streams) - self._session = self.exit_stack.enter_context(self._session_context) + self._session = self._exit_stack.enter_context(self._session_context) session = cast(ClientSession, self._session) session.initialize() return @@ -120,9 +131,6 @@ class MCPClient: if first_try: return self.connect_server(client_factory, method_name, first_try=False) - except MCPConnectionError: - raise - def list_tools(self) -> list[Tool]: """Connect to an MCP server running with SSE transport""" # List available tools to verify connection @@ -142,7 +150,7 @@ class MCPClient: """Clean up resources""" try: # ExitStack will handle proper cleanup of all managed context managers - self.exit_stack.close() + self._exit_stack.close() except Exception as e: logging.exception("Error during cleanup") raise ValueError(f"Error during cleanup: {e}") diff --git a/api/core/mcp/server/streamable_http.py b/api/core/mcp/server/streamable_http.py index 496b5432a0..efe91bbff4 100644 --- a/api/core/mcp/server/streamable_http.py +++ b/api/core/mcp/server/streamable_http.py @@ -16,13 +16,14 @@ from extensions.ext_database import db from models.model import App, AppMCPServer, AppMode, EndUser from services.app_generate_service import AppGenerateService -""" -Apply to MCP HTTP streamable server with stateless http -""" logger = logging.getLogger(__name__) class MCPServerStreamableHTTPRequestHandler: + """ + Apply to MCP HTTP streamable server with stateless http + """ + def __init__( self, app: App, request: types.ClientRequest | types.ClientNotification, user_input_form: list[VariableEntity] ): diff --git a/api/core/mcp/session/base_session.py b/api/core/mcp/session/base_session.py index 7734b8fdd9..031f01f411 100644 --- a/api/core/mcp/session/base_session.py +++ b/api/core/mcp/session/base_session.py @@ -2,10 +2,9 @@ import logging import queue from collections.abc import Callable from concurrent.futures import Future, ThreadPoolExecutor, TimeoutError -from contextlib import ExitStack from datetime import timedelta from types import TracebackType -from typing import Any, Generic, Self, TypeVar +from typing import Any, Generic, Optional, Self, TypeVar from httpx import HTTPStatusError from pydantic import BaseModel @@ -170,7 +169,6 @@ class BaseSession( self._receive_notification_type = receive_notification_type self._session_read_timeout_seconds = read_timeout_seconds self._in_flight = {} - self._exit_stack = ExitStack() # Initialize executor and future to None for proper cleanup checks self._executor: ThreadPoolExecutor | None = None self._receiver_future: Future | None = None @@ -211,7 +209,7 @@ class BaseSession( request: SendRequestT, result_type: type[ReceiveResultT], request_read_timeout_seconds: timedelta | None = None, - metadata: MessageMetadata = None, + metadata: Optional[MessageMetadata] = None, ) -> ReceiveResultT: """ Sends a request and wait for a response. Raises an McpError if the @@ -368,7 +366,7 @@ class BaseSession( self._handle_incoming(notification) except Exception as e: # For other validation errors, log and continue - logging.warning(f"Failed to validate notification: {e}. Message was: {message.message.root}") + logging.warning("Failed to validate notification: %s. Message was: %s", e, message.message.root) else: # Response or error response_queue = self._response_streams.get(message.message.root.id) if response_queue is not None: @@ -377,7 +375,7 @@ class BaseSession( self._handle_incoming(RuntimeError(f"Server Error: {message}")) except queue.Empty: continue - except Exception as e: + except Exception: logging.exception("Error in message processing loop") raise @@ -389,14 +387,12 @@ class BaseSession( If the request is responded to within this method, it will not be forwarded on to the message stream. """ - pass def _received_notification(self, notification: ReceiveNotificationT) -> None: """ Can be overridden by subclasses to handle a notification without needing to listen on the message stream. """ - pass def send_progress_notification( self, progress_token: str | int, progress: float, total: float | None = None @@ -405,11 +401,9 @@ class BaseSession( Sends a progress notification for a request that is currently being processed. """ - pass def _handle_incoming( self, req: RequestResponder[ReceiveRequestT, SendResultT] | ReceiveNotificationT | Exception, ) -> None: """A generic handler for incoming messages. Overwritten by subclasses.""" - pass diff --git a/api/core/mcp/session/client_session.py b/api/core/mcp/session/client_session.py index ed2ad508ab..1bccf1d031 100644 --- a/api/core/mcp/session/client_session.py +++ b/api/core/mcp/session/client_session.py @@ -1,3 +1,4 @@ +import queue from datetime import timedelta from typing import Any, Protocol @@ -85,8 +86,8 @@ class ClientSession( ): def __init__( self, - read_stream, - write_stream, + read_stream: queue.Queue, + write_stream: queue.Queue, read_timeout_seconds: timedelta | None = None, sampling_callback: SamplingFnT | None = None, list_roots_callback: ListRootsFnT | None = None, diff --git a/api/core/mcp/types.py b/api/core/mcp/types.py index 99d985a781..49aa8e4498 100644 --- a/api/core/mcp/types.py +++ b/api/core/mcp/types.py @@ -1173,7 +1173,7 @@ class SessionMessage: """A message with specific metadata for transport-specific features.""" message: JSONRPCMessage - metadata: MessageMetadata = None + metadata: Optional[MessageMetadata] = None class OAuthClientMetadata(BaseModel): diff --git a/api/core/mcp/utils.py b/api/core/mcp/utils.py index a54badcd4c..80912bc4c1 100644 --- a/api/core/mcp/utils.py +++ b/api/core/mcp/utils.py @@ -1,6 +1,10 @@ import json +from collections.abc import Generator +from contextlib import AbstractContextManager import httpx +import httpx_sse +from httpx_sse import connect_sse from configs import dify_config from core.mcp.types import ErrorData, JSONRPCError @@ -55,20 +59,42 @@ def create_ssrf_proxy_mcp_http_client( ) -def ssrf_proxy_sse_connect(url, **kwargs): +def ssrf_proxy_sse_connect(url: str, **kwargs) -> AbstractContextManager[httpx_sse.EventSource]: """Connect to SSE endpoint with SSRF proxy protection. This function creates an SSE connection using the configured proxy settings - to prevent SSRF attacks when connecting to external endpoints. + to prevent SSRF attacks when connecting to external endpoints. It returns + a context manager that yields an EventSource object for SSE streaming. + + The function handles HTTP client creation and cleanup automatically, but + also accepts a pre-configured client via kwargs. Args: - url: The SSE endpoint URL - **kwargs: Additional arguments passed to the SSE connection + url (str): The SSE endpoint URL to connect to + **kwargs: Additional arguments passed to the SSE connection, including: + - client (httpx.Client, optional): Pre-configured HTTP client. + If not provided, one will be created with SSRF protection. + - method (str, optional): HTTP method to use, defaults to "GET" + - headers (dict, optional): HTTP headers to include in the request + - timeout (httpx.Timeout, optional): Timeout configuration for the connection Returns: - EventSource object for SSE streaming + AbstractContextManager[httpx_sse.EventSource]: A context manager that yields an EventSource + object for SSE streaming. The EventSource provides access to server-sent events. + + Example: + ```python + with ssrf_proxy_sse_connect(url, headers=headers) as event_source: + for sse in event_source.iter_sse(): + print(sse.event, sse.data) + ``` + + Note: + If a client is not provided in kwargs, one will be automatically created + with SSRF protection based on the application's configuration. If an + exception occurs during connection, any automatically created client + will be cleaned up automatically. """ - from httpx_sse import connect_sse # Extract client if provided, otherwise create one client = kwargs.pop("client", None) @@ -101,7 +127,9 @@ def ssrf_proxy_sse_connect(url, **kwargs): raise -def create_mcp_error_response(request_id: int | str | None, code: int, message: str, data=None): +def create_mcp_error_response( + request_id: int | str | None, code: int, message: str, data=None +) -> Generator[bytes, None, None]: """Create MCP error response""" error_data = ErrorData(code=code, message=message, data=data) json_response = JSONRPCError( diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py index 7ce124594a..2a76b1f41a 100644 --- a/api/core/memory/token_buffer_memory.py +++ b/api/core/memory/token_buffer_memory.py @@ -99,13 +99,13 @@ class TokenBufferMemory: prompt_messages.append(UserPromptMessage(content=message.query)) else: prompt_message_contents: list[PromptMessageContentUnionTypes] = [] - prompt_message_contents.append(TextPromptMessageContent(data=message.query)) for file in file_objs: prompt_message = file_manager.to_prompt_message_content( file, image_detail_config=detail, ) prompt_message_contents.append(prompt_message) + prompt_message_contents.append(TextPromptMessageContent(data=message.query)) prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) @@ -121,9 +121,8 @@ class TokenBufferMemory: curr_message_tokens = self.model_instance.get_llm_num_tokens(prompt_messages) if curr_message_tokens > max_token_limit: - pruned_memory = [] while curr_message_tokens > max_token_limit and len(prompt_messages) > 1: - pruned_memory.append(prompt_messages.pop(0)) + prompt_messages.pop(0) curr_message_tokens = self.model_instance.get_llm_num_tokens(prompt_messages) return prompt_messages diff --git a/api/core/model_manager.py b/api/core/model_manager.py index 4886ffe244..51af3d1877 100644 --- a/api/core/model_manager.py +++ b/api/core/model_manager.py @@ -535,9 +535,19 @@ class LBModelManager: if dify_config.DEBUG: logger.info( - f"Model LB\nid: {config.id}\nname:{config.name}\n" - f"tenant_id: {self._tenant_id}\nprovider: {self._provider}\n" - f"model_type: {self._model_type.value}\nmodel: {self._model}" + """Model LB +id: %s +name:%s +tenant_id: %s +provider: %s +model_type: %s +model: %s""", + config.id, + config.name, + self._tenant_id, + self._provider, + self._model_type.value, + self._model, ) return config diff --git a/api/core/model_runtime/README.md b/api/core/model_runtime/README.md index b5de7ad412..3abb3f63ac 100644 --- a/api/core/model_runtime/README.md +++ b/api/core/model_runtime/README.md @@ -30,7 +30,7 @@ This module provides the interface for invoking and authenticating various model In addition, this list also returns configurable parameter information and rules for LLM, as shown below: - ![image-20231210144814617](./docs/en_US/images/index/image-20231210144814617.png) + ![image-20231210144814617](./docs/en_US/images/index/image-20231210144814617.png) These parameters are all defined in the backend, allowing different settings for various parameters supported by different models, as detailed in: [Schema](./docs/en_US/schema.md#ParameterRule). @@ -60,8 +60,6 @@ Model Runtime is divided into three layers: It offers direct invocation of various model types, predefined model configuration information, getting predefined/remote model lists, model credential authentication methods. Different models provide additional special methods, like LLM's pre-computed tokens method, cost information obtaining method, etc., **allowing horizontal expansion** for different models under the same provider (within supported model types). - - ## Next Steps - Add new provider configuration: [Link](./docs/en_US/provider_scale_out.md) diff --git a/api/core/model_runtime/README_CN.md b/api/core/model_runtime/README_CN.md index 2fc2a60461..19846481e0 100644 --- a/api/core/model_runtime/README_CN.md +++ b/api/core/model_runtime/README_CN.md @@ -20,19 +20,19 @@ ![image-20231210143654461](./docs/zh_Hans/images/index/image-20231210143654461.png) -​ 展示所有已支持的供应商列表,除了返回供应商名称、图标之外,还提供了支持的模型类型列表,预定义模型列表、配置方式以及配置凭据的表单规则等等,规则设计详见:[Schema](./docs/zh_Hans/schema.md)。 +​ 展示所有已支持的供应商列表,除了返回供应商名称、图标之外,还提供了支持的模型类型列表,预定义模型列表、配置方式以及配置凭据的表单规则等等,规则设计详见:[Schema](./docs/zh_Hans/schema.md)。 - 可选择的模型列表展示 ![image-20231210144229650](./docs/zh_Hans/images/index/image-20231210144229650.png) -​ 配置供应商/模型凭据后,可在此下拉(应用编排界面/默认模型)查看可用的 LLM 列表,其中灰色的为未配置凭据供应商的预定义模型列表,方便用户查看已支持的模型。 +​ 配置供应商/模型凭据后,可在此下拉(应用编排界面/默认模型)查看可用的 LLM 列表,其中灰色的为未配置凭据供应商的预定义模型列表,方便用户查看已支持的模型。 -​ 除此之外,该列表还返回了 LLM 可配置的参数信息和规则,如下图: +​ 除此之外,该列表还返回了 LLM 可配置的参数信息和规则,如下图: -​ ![image-20231210144814617](./docs/zh_Hans/images/index/image-20231210144814617.png) +​ ![image-20231210144814617](./docs/zh_Hans/images/index/image-20231210144814617.png) -​ 这里的参数均为后端定义,相比之前只有 5 种固定参数,这里可为不同模型设置所支持的各种参数,详见:[Schema](./docs/zh_Hans/schema.md#ParameterRule)。 +​ 这里的参数均为后端定义,相比之前只有 5 种固定参数,这里可为不同模型设置所支持的各种参数,详见:[Schema](./docs/zh_Hans/schema.md#ParameterRule)。 - 供应商/模型凭据鉴权 @@ -40,7 +40,7 @@ ![image-20231210151628992](./docs/zh_Hans/images/index/image-20231210151628992.png) -​ 供应商列表返回了凭据表单的配置信息,可通过 Runtime 提供的接口对凭据进行鉴权,上图 1 为供应商凭据 DEMO,上图 2 为模型凭据 DEMO。 +​ 供应商列表返回了凭据表单的配置信息,可通过 Runtime 提供的接口对凭据进行鉴权,上图 1 为供应商凭据 DEMO,上图 2 为模型凭据 DEMO。 ## 结构 @@ -57,9 +57,10 @@ Model Runtime 分三层: 提供获取当前供应商模型列表、获取模型实例、供应商凭据鉴权、供应商配置规则信息,**可横向扩展**以支持不同的供应商。 对于供应商/模型凭据,有两种情况 + - 如 OpenAI 这类中心化供应商,需要定义如**api_key**这类的鉴权凭据 - 如[**Xinference**](https://github.com/xorbitsai/inference)这类本地部署的供应商,需要定义如**server_url**这类的地址凭据,有时候还需要定义**model_uid**之类的模型类型凭据,就像下面这样,当在供应商层定义了这些凭据后,就可以在前端页面上直接展示,无需修改前端逻辑。 - ![Alt text](docs/zh_Hans/images/index/image.png) + ![Alt text](docs/zh_Hans/images/index/image.png) 当配置好凭据后,就可以通过 DifyRuntime 的外部接口直接获取到对应供应商所需要的**Schema**(凭据表单规则),从而在可以在不修改前端逻辑的情况下,提供新的供应商/模型的支持。 @@ -76,14 +77,17 @@ Model Runtime 分三层: ## 下一步 ### [增加新的供应商配置 👈🏻](./docs/zh_Hans/provider_scale_out.md) + 当添加后,这里将会出现一个新的供应商 ![Alt text](docs/zh_Hans/images/index/image-1.png) -### [为已存在的供应商新增模型 👈🏻](./docs/zh_Hans/provider_scale_out.md#增加模型) +### [为已存在的供应商新增模型 👈🏻](./docs/zh_Hans/provider_scale_out.md#%E5%A2%9E%E5%8A%A0%E6%A8%A1%E5%9E%8B) + 当添加后,对应供应商的模型列表中将会出现一个新的预定义模型供用户选择,如 GPT-3.5 GPT-4 ChatGLM3-6b 等,而对于支持自定义模型的供应商,则不需要新增模型。 ![Alt text](docs/zh_Hans/images/index/image-2.png) ### [接口的具体实现 👈🏻](./docs/zh_Hans/interfaces.md) + 你可以在这里找到你想要查看的接口的具体实现,以及接口的参数和返回值的具体含义。 diff --git a/api/core/model_runtime/docs/en_US/customizable_model_scale_out.md b/api/core/model_runtime/docs/en_US/customizable_model_scale_out.md index d845c4bd09..245aa4699c 100644 --- a/api/core/model_runtime/docs/en_US/customizable_model_scale_out.md +++ b/api/core/model_runtime/docs/en_US/customizable_model_scale_out.md @@ -56,7 +56,6 @@ provider_credential_schema: credential_form_schemas: ``` - Then, we need to determine what credentials are required to define a model in Xinference. - Since it supports three different types of models, we need to specify the model_type to denote the model type. Here is how we can define it: @@ -191,7 +190,6 @@ def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[Pr """ ``` - Sometimes, you might not want to return 0 directly. In such cases, you can use `self._get_num_tokens_by_gpt2(text: str)` to get pre-computed tokens and ensure environment variable `PLUGIN_BASED_TOKEN_COUNTING_ENABLED` is set to `true`, This method is provided by the `AIModel` base class, and it uses GPT2's Tokenizer for calculation. However, it should be noted that this is only a substitute and may not be fully accurate. - Model Credentials Validation diff --git a/api/core/model_runtime/docs/en_US/interfaces.md b/api/core/model_runtime/docs/en_US/interfaces.md index 158d4b306b..9a8c2ec942 100644 --- a/api/core/model_runtime/docs/en_US/interfaces.md +++ b/api/core/model_runtime/docs/en_US/interfaces.md @@ -35,12 +35,11 @@ All models need to uniformly implement the following 2 methods: Similar to provider credential verification, this step involves verification for an individual model. - ```python def validate_credentials(self, model: str, credentials: dict) -> None: """ Validate model credentials - + :param model: model name :param credentials: model credentials :return: @@ -77,12 +76,12 @@ All models need to uniformly implement the following 2 methods: The key is the error type thrown to the caller The value is the error type thrown by the model, which needs to be converted into a unified error type for the caller. - + :return: Invoke error mapping """ ``` -​ You can refer to OpenAI's `_invoke_error_mapping` for an example. +​ You can refer to OpenAI's `_invoke_error_mapping` for an example. ### LLM @@ -92,7 +91,6 @@ Inherit the `__base.large_language_model.LargeLanguageModel` base class and impl Implement the core method for LLM invocation, which can support both streaming and synchronous returns. - ```python def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, @@ -101,7 +99,7 @@ Inherit the `__base.large_language_model.LargeLanguageModel` base class and impl -> Union[LLMResult, Generator]: """ Invoke large language model - + :param model: model name :param credentials: model credentials :param prompt_messages: prompt messages @@ -122,7 +120,7 @@ Inherit the `__base.large_language_model.LargeLanguageModel` base class and impl The parameters of credential information are defined by either the `provider_credential_schema` or `model_credential_schema` in the provider's YAML configuration file. Inputs such as `api_key` are included. - - `prompt_messages` (array[[PromptMessage](#PromptMessage)]) List of prompts + - `prompt_messages` (array\[[PromptMessage](#PromptMessage)\]) List of prompts If the model is of the `Completion` type, the list only needs to include one [UserPromptMessage](#UserPromptMessage) element; @@ -132,7 +130,7 @@ Inherit the `__base.large_language_model.LargeLanguageModel` base class and impl The model parameters are defined by the `parameter_rules` in the model's YAML configuration. - - `tools` (array[[PromptMessageTool](#PromptMessageTool)]) [optional] List of tools, equivalent to the `function` in `function calling`. + - `tools` (array\[[PromptMessageTool](#PromptMessageTool)\]) [optional] List of tools, equivalent to the `function` in `function calling`. That is, the tool list for tool calling. @@ -142,7 +140,7 @@ Inherit the `__base.large_language_model.LargeLanguageModel` base class and impl - `stream` (bool) Whether to output in a streaming manner, default is True - Streaming output returns Generator[[LLMResultChunk](#LLMResultChunk)], non-streaming output returns [LLMResult](#LLMResult). + Streaming output returns Generator\[[LLMResultChunk](#LLMResultChunk)\], non-streaming output returns [LLMResult](#LLMResult). - `user` (string) [optional] Unique identifier of the user @@ -150,7 +148,7 @@ Inherit the `__base.large_language_model.LargeLanguageModel` base class and impl - Returns - Streaming output returns Generator[[LLMResultChunk](#LLMResultChunk)], non-streaming output returns [LLMResult](#LLMResult). + Streaming output returns Generator\[[LLMResultChunk](#LLMResultChunk)\], non-streaming output returns [LLMResult](#LLMResult). - Pre-calculating Input Tokens @@ -187,7 +185,6 @@ Inherit the `__base.large_language_model.LargeLanguageModel` base class and impl When the provider supports adding custom LLMs, this method can be implemented to allow custom models to fetch model schema. The default return null. - ### TextEmbedding Inherit the `__base.text_embedding_model.TextEmbeddingModel` base class and implement the following interfaces: @@ -200,7 +197,7 @@ Inherit the `__base.text_embedding_model.TextEmbeddingModel` base class and impl -> TextEmbeddingResult: """ Invoke large language model - + :param model: model name :param credentials: model credentials :param texts: texts to embed @@ -256,7 +253,7 @@ Inherit the `__base.rerank_model.RerankModel` base class and implement the follo -> RerankResult: """ Invoke rerank model - + :param model: model name :param credentials: model credentials :param query: search query @@ -302,7 +299,7 @@ Inherit the `__base.speech2text_model.Speech2TextModel` base class and implement def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str: """ Invoke large language model - + :param model: model name :param credentials: model credentials :param file: audio file @@ -339,7 +336,7 @@ Inherit the `__base.text2speech_model.Text2SpeechModel` base class and implement def _invoke(self, model: str, credentials: dict, content_text: str, streaming: bool, user: Optional[str] = None): """ Invoke large language model - + :param model: model name :param credentials: model credentials :param content_text: text content to be translated @@ -381,7 +378,7 @@ Inherit the `__base.moderation_model.ModerationModel` base class and implement t -> bool: """ Invoke large language model - + :param model: model name :param credentials: model credentials :param text: text to moderate @@ -408,11 +405,9 @@ Inherit the `__base.moderation_model.ModerationModel` base class and implement t False indicates that the input text is safe, True indicates otherwise. - - ## Entities -### PromptMessageRole +### PromptMessageRole Message role @@ -583,7 +578,7 @@ class PromptMessageTool(BaseModel): parameters: dict ``` ---- +______________________________________________________________________ ### LLMResult @@ -650,7 +645,7 @@ class LLMUsage(ModelUsage): latency: float # Request latency (s) ``` ---- +______________________________________________________________________ ### TextEmbeddingResult @@ -680,7 +675,7 @@ class EmbeddingUsage(ModelUsage): latency: float # Request latency (s) ``` ---- +______________________________________________________________________ ### RerankResult diff --git a/api/core/model_runtime/docs/en_US/predefined_model_scale_out.md b/api/core/model_runtime/docs/en_US/predefined_model_scale_out.md index a770ed157b..97968e9988 100644 --- a/api/core/model_runtime/docs/en_US/predefined_model_scale_out.md +++ b/api/core/model_runtime/docs/en_US/predefined_model_scale_out.md @@ -153,8 +153,11 @@ Runtime Errors: - `InvokeConnectionError` Connection error - `InvokeServerUnavailableError` Service provider unavailable + - `InvokeRateLimitError` Rate limit reached + - `InvokeAuthorizationError` Authorization failed + - `InvokeBadRequestError` Parameter error ```python diff --git a/api/core/model_runtime/docs/en_US/provider_scale_out.md b/api/core/model_runtime/docs/en_US/provider_scale_out.md index 07be5811d3..c38c7c0f0c 100644 --- a/api/core/model_runtime/docs/en_US/provider_scale_out.md +++ b/api/core/model_runtime/docs/en_US/provider_scale_out.md @@ -63,6 +63,7 @@ You can also refer to the YAML configuration information under other provider di ### Implementing Provider Code Providers need to inherit the `__base.model_provider.ModelProvider` base class and implement the `validate_provider_credentials` method for unified provider credential verification. For reference, see [AnthropicProvider](https://github.com/langgenius/dify-runtime/blob/main/lib/model_providers/anthropic/anthropic.py). + > If the provider is the type of `customizable-model`, there is no need to implement the `validate_provider_credentials` method. ```python @@ -80,7 +81,7 @@ def validate_provider_credentials(self, credentials: dict) -> None: Of course, you can also preliminarily reserve the implementation of `validate_provider_credentials` and directly reuse it after the model credential verification method is implemented. ---- +______________________________________________________________________ ### Adding Models @@ -166,7 +167,7 @@ In `llm.py`, create an Anthropic LLM class, which we name `AnthropicLargeLanguag -> Union[LLMResult, Generator]: """ Invoke large language model - + :param model: model name :param credentials: model credentials :param prompt_messages: prompt messages @@ -205,7 +206,7 @@ In `llm.py`, create an Anthropic LLM class, which we name `AnthropicLargeLanguag def validate_credentials(self, model: str, credentials: dict) -> None: """ Validate model credentials - + :param model: model name :param credentials: model credentials :return: @@ -232,7 +233,7 @@ In `llm.py`, create an Anthropic LLM class, which we name `AnthropicLargeLanguag The key is the error type thrown to the caller The value is the error type thrown by the model, which needs to be converted into a unified error type for the caller. - + :return: Invoke error mapping """ ``` diff --git a/api/core/model_runtime/docs/en_US/schema.md b/api/core/model_runtime/docs/en_US/schema.md index f819a4dbdc..1cea4127f4 100644 --- a/api/core/model_runtime/docs/en_US/schema.md +++ b/api/core/model_runtime/docs/en_US/schema.md @@ -28,8 +28,8 @@ - `url` (object) help link, i18n - `zh_Hans` (string) [optional] Chinese link - `en_US` (string) English link -- `supported_model_types` (array[[ModelType](#ModelType)]) Supported model types -- `configurate_methods` (array[[ConfigurateMethod](#ConfigurateMethod)]) Configuration methods +- `supported_model_types` (array\[[ModelType](#ModelType)\]) Supported model types +- `configurate_methods` (array\[[ConfigurateMethod](#ConfigurateMethod)\]) Configuration methods - `provider_credential_schema` ([ProviderCredentialSchema](#ProviderCredentialSchema)) Provider credential specification - `model_credential_schema` ([ModelCredentialSchema](#ModelCredentialSchema)) Model credential specification @@ -40,23 +40,23 @@ - `zh_Hans` (string) [optional] Chinese label name - `en_US` (string) English label name - `model_type` ([ModelType](#ModelType)) Model type -- `features` (array[[ModelFeature](#ModelFeature)]) [optional] Supported feature list +- `features` (array\[[ModelFeature](#ModelFeature)\]) [optional] Supported feature list - `model_properties` (object) Model properties - `mode` ([LLMMode](#LLMMode)) Mode (available for model type `llm`) - `context_size` (int) Context size (available for model types `llm`, `text-embedding`) - `max_chunks` (int) Maximum number of chunks (available for model types `text-embedding`, `moderation`) - `file_upload_limit` (int) Maximum file upload limit, in MB (available for model type `speech2text`) - `supported_file_extensions` (string) Supported file extension formats, e.g., mp3, mp4 (available for model type `speech2text`) - - `default_voice` (string) default voice, e.g.:alloy,echo,fable,onyx,nova,shimmer(available for model type `tts`) - - `voices` (list) List of available voice.(available for model type `tts`) - - `mode` (string) voice model.(available for model type `tts`) - - `name` (string) voice model display name.(available for model type `tts`) - - `language` (string) the voice model supports languages.(available for model type `tts`) - - `word_limit` (int) Single conversion word limit, paragraph-wise by default(available for model type `tts`) - - `audio_type` (string) Support audio file extension format, e.g.:mp3,wav(available for model type `tts`) - - `max_workers` (int) Number of concurrent workers supporting text and audio conversion(available for model type`tts`) + - `default_voice` (string) default voice, e.g.:alloy,echo,fable,onyx,nova,shimmer(available for model type `tts`) + - `voices` (list) List of available voice.(available for model type `tts`) + - `mode` (string) voice model.(available for model type `tts`) + - `name` (string) voice model display name.(available for model type `tts`) + - `language` (string) the voice model supports languages.(available for model type `tts`) + - `word_limit` (int) Single conversion word limit, paragraph-wise by default(available for model type `tts`) + - `audio_type` (string) Support audio file extension format, e.g.:mp3,wav(available for model type `tts`) + - `max_workers` (int) Number of concurrent workers supporting text and audio conversion(available for model type`tts`) - `max_characters_per_chunk` (int) Maximum characters per chunk (available for model type `moderation`) -- `parameter_rules` (array[[ParameterRule](#ParameterRule)]) [optional] Model invocation parameter rules +- `parameter_rules` (array\[[ParameterRule](#ParameterRule)\]) [optional] Model invocation parameter rules - `pricing` ([PriceConfig](#PriceConfig)) [optional] Pricing information - `deprecated` (bool) Whether deprecated. If deprecated, the model will no longer be displayed in the list, but those already configured can continue to be used. Default False. @@ -74,6 +74,7 @@ - `predefined-model` Predefined model Indicates that users can use the predefined models under the provider by configuring the unified provider credentials. + - `customizable-model` Customizable model Users need to add credential configuration for each model. @@ -103,6 +104,7 @@ ### ParameterRule - `name` (string) Actual model invocation parameter name + - `use_template` (string) [optional] Using template By default, 5 variable content configuration templates are preset: @@ -112,7 +114,7 @@ - `frequency_penalty` - `presence_penalty` - `max_tokens` - + In use_template, you can directly set the template variable name, which will use the default configuration in entities.defaults.PARAMETER_RULE_TEMPLATE No need to set any parameters other than `name` and `use_template`. If additional configuration parameters are set, they will override the default configuration. Refer to `openai/llm/gpt-3.5-turbo.yaml`. @@ -155,7 +157,7 @@ ### ProviderCredentialSchema -- `credential_form_schemas` (array[[CredentialFormSchema](#CredentialFormSchema)]) Credential form standard +- `credential_form_schemas` (array\[[CredentialFormSchema](#CredentialFormSchema)\]) Credential form standard ### ModelCredentialSchema @@ -166,7 +168,7 @@ - `placeholder` (object) Model prompt content - `en_US`(string) English - `zh_Hans`(string) [optional] Chinese -- `credential_form_schemas` (array[[CredentialFormSchema](#CredentialFormSchema)]) Credential form standard +- `credential_form_schemas` (array\[[CredentialFormSchema](#CredentialFormSchema)\]) Credential form standard ### CredentialFormSchema @@ -177,12 +179,12 @@ - `type` ([FormType](#FormType)) Form item type - `required` (bool) Whether required - `default`(string) Default value -- `options` (array[[FormOption](#FormOption)]) Specific property of form items of type `select` or `radio`, defining dropdown content +- `options` (array\[[FormOption](#FormOption)\]) Specific property of form items of type `select` or `radio`, defining dropdown content - `placeholder`(object) Specific property of form items of type `text-input`, placeholder content - `en_US`(string) English - `zh_Hans` (string) [optional] Chinese - `max_length` (int) Specific property of form items of type `text-input`, defining maximum input length, 0 for no limit. -- `show_on` (array[[FormShowOnObject](#FormShowOnObject)]) Displayed when other form item values meet certain conditions, displayed always if empty. +- `show_on` (array\[[FormShowOnObject](#FormShowOnObject)\]) Displayed when other form item values meet certain conditions, displayed always if empty. ### FormType @@ -198,7 +200,7 @@ - `en_US`(string) English - `zh_Hans`(string) [optional] Chinese - `value` (string) Dropdown option value -- `show_on` (array[[FormShowOnObject](#FormShowOnObject)]) Displayed when other form item values meet certain conditions, displayed always if empty. +- `show_on` (array\[[FormShowOnObject](#FormShowOnObject)\]) Displayed when other form item values meet certain conditions, displayed always if empty. ### FormShowOnObject diff --git a/api/core/model_runtime/docs/zh_Hans/customizable_model_scale_out.md b/api/core/model_runtime/docs/zh_Hans/customizable_model_scale_out.md index 7d30655469..825f9349d7 100644 --- a/api/core/model_runtime/docs/zh_Hans/customizable_model_scale_out.md +++ b/api/core/model_runtime/docs/zh_Hans/customizable_model_scale_out.md @@ -10,7 +10,6 @@ ![Alt text](images/index/image-3.png) - 在前文中,我们已经知道了供应商无需实现`validate_provider_credential`,Runtime 会自行根据用户在此选择的模型类型和模型名称调用对应的模型层的`validate_credentials`来进行验证。 ### 编写供应商 yaml @@ -55,6 +54,7 @@ provider_credential_schema: 随后,我们需要思考在 Xinference 中定义一个模型需要哪些凭据 - 它支持三种不同的模型,因此,我们需要有`model_type`来指定这个模型的类型,它有三种类型,所以我们这么编写 + ```yaml provider_credential_schema: credential_form_schemas: @@ -76,7 +76,9 @@ provider_credential_schema: label: en_US: Rerank ``` + - 每一个模型都有自己的名称`model_name`,因此需要在这里定义 + ```yaml - variable: model_name type: text-input @@ -88,7 +90,9 @@ provider_credential_schema: zh_Hans: 填写模型名称 en_US: Input model name ``` + - 填写 Xinference 本地部署的地址 + ```yaml - variable: server_url label: @@ -100,7 +104,9 @@ provider_credential_schema: zh_Hans: 在此输入 Xinference 的服务器地址,如 https://example.com/xxx en_US: Enter the url of your Xinference, for example https://example.com/xxx ``` + - 每个模型都有唯一的 model_uid,因此需要在这里定义 + ```yaml - variable: model_uid label: @@ -112,6 +118,7 @@ provider_credential_schema: zh_Hans: 在此输入您的 Model UID en_US: Enter the model uid ``` + 现在,我们就完成了供应商的基础定义。 ### 编写模型代码 @@ -132,7 +139,7 @@ provider_credential_schema: -> Union[LLMResult, Generator]: """ Invoke large language model - + :param model: model name :param credentials: model credentials :param prompt_messages: prompt messages @@ -189,7 +196,7 @@ provider_credential_schema: def validate_credentials(self, model: str, credentials: dict) -> None: """ Validate model credentials - + :param model: model name :param credentials: model credentials :return: @@ -197,78 +204,78 @@ provider_credential_schema: ``` - 模型参数 Schema - + 与自定义类型不同,由于没有在 yaml 文件中定义一个模型支持哪些参数,因此,我们需要动态时间模型参数的 Schema。 - + 如 Xinference 支持`max_tokens` `temperature` `top_p` 这三个模型参数。 - + 但是有的供应商根据不同的模型支持不同的参数,如供应商`OpenLLM`支持`top_k`,但是并不是这个供应商提供的所有模型都支持`top_k`,我们这里举例 A 模型支持`top_k`,B 模型不支持`top_k`,那么我们需要在这里动态生成模型参数的 Schema,如下所示: - - ```python - def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: - """ - used to define customizable model schema - """ - rules = [ - ParameterRule( - name='temperature', type=ParameterType.FLOAT, - use_template='temperature', - label=I18nObject( - zh_Hans='温度', en_US='Temperature' - ) - ), - ParameterRule( - name='top_p', type=ParameterType.FLOAT, - use_template='top_p', - label=I18nObject( - zh_Hans='Top P', en_US='Top P' - ) - ), - ParameterRule( - name='max_tokens', type=ParameterType.INT, - use_template='max_tokens', - min=1, - default=512, - label=I18nObject( - zh_Hans='最大生成长度', en_US='Max Tokens' - ) - ) - ] - # if model is A, add top_k to rules - if model == 'A': - rules.append( - ParameterRule( - name='top_k', type=ParameterType.INT, - use_template='top_k', - min=1, - default=50, - label=I18nObject( - zh_Hans='Top K', en_US='Top K' - ) - ) - ) + ```python + def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: + """ + used to define customizable model schema + """ + rules = [ + ParameterRule( + name='temperature', type=ParameterType.FLOAT, + use_template='temperature', + label=I18nObject( + zh_Hans='温度', en_US='Temperature' + ) + ), + ParameterRule( + name='top_p', type=ParameterType.FLOAT, + use_template='top_p', + label=I18nObject( + zh_Hans='Top P', en_US='Top P' + ) + ), + ParameterRule( + name='max_tokens', type=ParameterType.INT, + use_template='max_tokens', + min=1, + default=512, + label=I18nObject( + zh_Hans='最大生成长度', en_US='Max Tokens' + ) + ) + ] - """ - some NOT IMPORTANT code here - """ + # if model is A, add top_k to rules + if model == 'A': + rules.append( + ParameterRule( + name='top_k', type=ParameterType.INT, + use_template='top_k', + min=1, + default=50, + label=I18nObject( + zh_Hans='Top K', en_US='Top K' + ) + ) + ) - entity = AIModelEntity( - model=model, - label=I18nObject( - en_US=model - ), - fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, - model_type=model_type, - model_properties={ - ModelPropertyKey.MODE: ModelType.LLM, - }, - parameter_rules=rules - ) + """ + some NOT IMPORTANT code here + """ + + entity = AIModelEntity( + model=model, + label=I18nObject( + en_US=model + ), + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_type=model_type, + model_properties={ + ModelPropertyKey.MODE: ModelType.LLM, + }, + parameter_rules=rules + ) + + return entity + ``` - return entity - ``` - - 调用异常错误映射表 当模型调用异常时需要映射到 Runtime 指定的 `InvokeError` 类型,方便 Dify 针对不同错误做不同后续处理。 @@ -278,7 +285,7 @@ provider_credential_schema: - `InvokeConnectionError` 调用连接错误 - `InvokeServerUnavailableError ` 调用服务方不可用 - `InvokeRateLimitError ` 调用达到限额 - - `InvokeAuthorizationError` 调用鉴权失败 + - `InvokeAuthorizationError` 调用鉴权失败 - `InvokeBadRequestError ` 调用传参有误 ```python @@ -289,7 +296,7 @@ provider_credential_schema: The key is the error type thrown to the caller The value is the error type thrown by the model, which needs to be converted into a unified error type for the caller. - + :return: Invoke error mapping """ ``` diff --git a/api/core/model_runtime/docs/zh_Hans/interfaces.md b/api/core/model_runtime/docs/zh_Hans/interfaces.md index 93a48cafb8..8eeeee9ff9 100644 --- a/api/core/model_runtime/docs/zh_Hans/interfaces.md +++ b/api/core/model_runtime/docs/zh_Hans/interfaces.md @@ -49,7 +49,7 @@ class XinferenceProvider(Provider): def validate_credentials(self, model: str, credentials: dict) -> None: """ Validate model credentials - + :param model: model name :param credentials: model credentials :return: @@ -75,7 +75,7 @@ class XinferenceProvider(Provider): - `InvokeConnectionError` 调用连接错误 - `InvokeServerUnavailableError ` 调用服务方不可用 - `InvokeRateLimitError ` 调用达到限额 - - `InvokeAuthorizationError` 调用鉴权失败 + - `InvokeAuthorizationError` 调用鉴权失败 - `InvokeBadRequestError ` 调用传参有误 ```python @@ -86,36 +86,36 @@ class XinferenceProvider(Provider): The key is the error type thrown to the caller The value is the error type thrown by the model, which needs to be converted into a unified error type for the caller. - + :return: Invoke error mapping """ ``` 也可以直接抛出对应 Errors,并做如下定义,这样在之后的调用中可以直接抛出`InvokeConnectionError`等异常。 - - ```python - @property - def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: - return { - InvokeConnectionError: [ - InvokeConnectionError - ], - InvokeServerUnavailableError: [ - InvokeServerUnavailableError - ], - InvokeRateLimitError: [ - InvokeRateLimitError - ], - InvokeAuthorizationError: [ - InvokeAuthorizationError - ], - InvokeBadRequestError: [ - InvokeBadRequestError - ], - } - ``` -​ 可参考 OpenAI `_invoke_error_mapping`。 + ```python + @property + def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + return { + InvokeConnectionError: [ + InvokeConnectionError + ], + InvokeServerUnavailableError: [ + InvokeServerUnavailableError + ], + InvokeRateLimitError: [ + InvokeRateLimitError + ], + InvokeAuthorizationError: [ + InvokeAuthorizationError + ], + InvokeBadRequestError: [ + InvokeBadRequestError + ], + } + ``` + +​ 可参考 OpenAI `_invoke_error_mapping`。 ### LLM @@ -133,7 +133,7 @@ class XinferenceProvider(Provider): -> Union[LLMResult, Generator]: """ Invoke large language model - + :param model: model name :param credentials: model credentials :param prompt_messages: prompt messages @@ -151,38 +151,38 @@ class XinferenceProvider(Provider): - `model` (string) 模型名称 - `credentials` (object) 凭据信息 - + 凭据信息的参数由供应商 YAML 配置文件的 `provider_credential_schema` 或 `model_credential_schema` 定义,传入如:`api_key` 等。 - - `prompt_messages` (array[[PromptMessage](#PromptMessage)]) Prompt 列表 - + - `prompt_messages` (array\[[PromptMessage](#PromptMessage)\]) Prompt 列表 + 若模型为 `Completion` 类型,则列表只需要传入一个 [UserPromptMessage](#UserPromptMessage) 元素即可; - + 若模型为 `Chat` 类型,需要根据消息不同传入 [SystemPromptMessage](#SystemPromptMessage), [UserPromptMessage](#UserPromptMessage), [AssistantPromptMessage](#AssistantPromptMessage), [ToolPromptMessage](#ToolPromptMessage) 元素列表 - `model_parameters` (object) 模型参数 - + 模型参数由模型 YAML 配置的 `parameter_rules` 定义。 - - `tools` (array[[PromptMessageTool](#PromptMessageTool)]) [optional] 工具列表,等同于 `function calling` 中的 `function`。 - + - `tools` (array\[[PromptMessageTool](#PromptMessageTool)\]) [optional] 工具列表,等同于 `function calling` 中的 `function`。 + 即传入 tool calling 的工具列表。 - `stop` (array[string]) [optional] 停止序列 - + 模型返回将在停止序列定义的字符串之前停止输出。 - `stream` (bool) 是否流式输出,默认 True - - 流式输出返回 Generator[[LLMResultChunk](#LLMResultChunk)],非流式输出返回 [LLMResult](#LLMResult)。 + + 流式输出返回 Generator\[[LLMResultChunk](#LLMResultChunk)\],非流式输出返回 [LLMResult](#LLMResult)。 - `user` (string) [optional] 用户的唯一标识符 - + 可以帮助供应商监控和检测滥用行为。 - 返回 - 流式输出返回 Generator[[LLMResultChunk](#LLMResultChunk)],非流式输出返回 [LLMResult](#LLMResult)。 + 流式输出返回 Generator\[[LLMResultChunk](#LLMResultChunk)\],非流式输出返回 [LLMResult](#LLMResult)。 - 预计算输入 tokens @@ -236,7 +236,7 @@ class XinferenceProvider(Provider): -> TextEmbeddingResult: """ Invoke large language model - + :param model: model name :param credentials: model credentials :param texts: texts to embed @@ -294,7 +294,7 @@ class XinferenceProvider(Provider): -> RerankResult: """ Invoke rerank model - + :param model: model name :param credentials: model credentials :param query: search query @@ -342,7 +342,7 @@ class XinferenceProvider(Provider): -> str: """ Invoke large language model - + :param model: model name :param credentials: model credentials :param file: audio file @@ -379,7 +379,7 @@ class XinferenceProvider(Provider): def _invoke(self, model: str, credentials: dict, content_text: str, streaming: bool, user: Optional[str] = None): """ Invoke large language model - + :param model: model name :param credentials: model credentials :param content_text: text content to be translated @@ -421,7 +421,7 @@ class XinferenceProvider(Provider): -> bool: """ Invoke large language model - + :param model: model name :param credentials: model credentials :param text: text to moderate @@ -448,11 +448,9 @@ class XinferenceProvider(Provider): False 代表传入的文本安全,True 则反之。 - - ## 实体 -### PromptMessageRole +### PromptMessageRole 消息角色 @@ -623,7 +621,7 @@ class PromptMessageTool(BaseModel): parameters: dict # 工具参数 dict ``` ---- +______________________________________________________________________ ### LLMResult @@ -690,7 +688,7 @@ class LLMUsage(ModelUsage): latency: float # 请求耗时 (s) ``` ---- +______________________________________________________________________ ### TextEmbeddingResult @@ -720,7 +718,7 @@ class EmbeddingUsage(ModelUsage): latency: float # 请求耗时 (s) ``` ---- +______________________________________________________________________ ### RerankResult diff --git a/api/core/model_runtime/docs/zh_Hans/predefined_model_scale_out.md b/api/core/model_runtime/docs/zh_Hans/predefined_model_scale_out.md index 80e7982e9f..cd4de51ef7 100644 --- a/api/core/model_runtime/docs/zh_Hans/predefined_model_scale_out.md +++ b/api/core/model_runtime/docs/zh_Hans/predefined_model_scale_out.md @@ -62,7 +62,7 @@ pricing: # 价格信息 建议将所有模型配置都准备完毕后再开始模型代码的实现。 -同样,也可以参考 `model_providers` 目录下其他供应商对应模型类型目录下的 YAML 配置信息,完整的 YAML 规则见:[Schema](schema.md#aimodelentity)。 +同样,也可以参考 `model_providers` 目录下其他供应商对应模型类型目录下的 YAML 配置信息,完整的 YAML 规则见:[Schema](schema.md#aimodelentity)。 ### 实现模型调用代码 @@ -82,7 +82,7 @@ pricing: # 价格信息 -> Union[LLMResult, Generator]: """ Invoke large language model - + :param model: model name :param credentials: model credentials :param prompt_messages: prompt messages @@ -137,7 +137,7 @@ pricing: # 价格信息 def validate_credentials(self, model: str, credentials: dict) -> None: """ Validate model credentials - + :param model: model name :param credentials: model credentials :return: @@ -153,7 +153,7 @@ pricing: # 价格信息 - `InvokeConnectionError` 调用连接错误 - `InvokeServerUnavailableError ` 调用服务方不可用 - `InvokeRateLimitError ` 调用达到限额 - - `InvokeAuthorizationError` 调用鉴权失败 + - `InvokeAuthorizationError` 调用鉴权失败 - `InvokeBadRequestError ` 调用传参有误 ```python @@ -164,7 +164,7 @@ pricing: # 价格信息 The key is the error type thrown to the caller The value is the error type thrown by the model, which needs to be converted into a unified error type for the caller. - + :return: Invoke error mapping """ ``` diff --git a/api/core/model_runtime/docs/zh_Hans/provider_scale_out.md b/api/core/model_runtime/docs/zh_Hans/provider_scale_out.md index 2048b506ac..de48b0d11a 100644 --- a/api/core/model_runtime/docs/zh_Hans/provider_scale_out.md +++ b/api/core/model_runtime/docs/zh_Hans/provider_scale_out.md @@ -5,7 +5,7 @@ - `predefined-model ` 预定义模型 表示用户只需要配置统一的供应商凭据即可使用供应商下的预定义模型。 - + - `customizable-model` 自定义模型 用户需要新增每个模型的凭据配置,如 Xinference,它同时支持 LLM 和 Text Embedding,但是每个模型都有唯一的**model_uid**,如果想要将两者同时接入,就需要为每个模型配置一个**model_uid**。 @@ -23,9 +23,11 @@ ### 介绍 #### 名词解释 - - `module`: 一个`module`即为一个 Python Package,或者通俗一点,称为一个文件夹,里面包含了一个`__init__.py`文件,以及其他的`.py`文件。 + +- `module`: 一个`module`即为一个 Python Package,或者通俗一点,称为一个文件夹,里面包含了一个`__init__.py`文件,以及其他的`.py`文件。 #### 步骤 + 新增一个供应商主要分为几步,这里简单列出,帮助大家有一个大概的认识,具体的步骤会在下面详细介绍。 - 创建供应商 yaml 文件,根据[ProviderSchema](./schema.md#provider)编写 @@ -117,7 +119,7 @@ model_credential_schema: en_US: Enter your API Base ``` -也可以参考 `model_providers` 目录下其他供应商目录下的 YAML 配置信息,完整的 YAML 规则见:[Schema](schema.md#provider)。 +也可以参考 `model_providers` 目录下其他供应商目录下的 YAML 配置信息,完整的 YAML 规则见:[Schema](schema.md#provider)。 #### 实现供应商代码 @@ -155,12 +157,14 @@ def validate_provider_credentials(self, credentials: dict) -> None: #### 增加模型 #### [增加预定义模型 👈🏻](./predefined_model_scale_out.md) + 对于预定义模型,我们可以通过简单定义一个 yaml,并通过实现调用代码来接入。 #### [增加自定义模型 👈🏻](./customizable_model_scale_out.md) + 对于自定义模型,我们只需要实现调用代码即可接入,但是它需要处理的参数可能会更加复杂。 ---- +______________________________________________________________________ ### 测试 diff --git a/api/core/model_runtime/docs/zh_Hans/schema.md b/api/core/model_runtime/docs/zh_Hans/schema.md index 681f49c435..e68cb500e1 100644 --- a/api/core/model_runtime/docs/zh_Hans/schema.md +++ b/api/core/model_runtime/docs/zh_Hans/schema.md @@ -16,9 +16,9 @@ - `zh_Hans` (string) [optional] 中文描述 - `en_US` (string) 英文描述 - `icon_small` (string) [optional] 供应商小 ICON,存储在对应供应商实现目录下的 `_assets` 目录,中英文策略同 `label` - - `zh_Hans` (string) [optional] 中文 ICON + - `zh_Hans` (string) [optional] 中文 ICON - `en_US` (string) 英文 ICON -- `icon_large` (string) [optional] 供应商大 ICON,存储在对应供应商实现目录下的 _assets 目录,中英文策略同 label +- `icon_large` (string) [optional] 供应商大 ICON,存储在对应供应商实现目录下的 \_assets 目录,中英文策略同 label - `zh_Hans `(string) [optional] 中文 ICON - `en_US` (string) 英文 ICON - `background` (string) [optional] 背景颜色色值,例:#FFFFFF,为空则展示前端默认色值。 @@ -29,8 +29,8 @@ - `url` (object) 帮助链接,i18n - `zh_Hans` (string) [optional] 中文链接 - `en_US` (string) 英文链接 -- `supported_model_types` (array[[ModelType](#ModelType)]) 支持的模型类型 -- `configurate_methods` (array[[ConfigurateMethod](#ConfigurateMethod)]) 配置方式 +- `supported_model_types` (array\[[ModelType](#ModelType)\]) 支持的模型类型 +- `configurate_methods` (array\[[ConfigurateMethod](#ConfigurateMethod)\]) 配置方式 - `provider_credential_schema` ([ProviderCredentialSchema](#ProviderCredentialSchema)) 供应商凭据规格 - `model_credential_schema` ([ModelCredentialSchema](#ModelCredentialSchema)) 模型凭据规格 @@ -41,23 +41,23 @@ - `zh_Hans `(string) [optional] 中文标签名 - `en_US` (string) 英文标签名 - `model_type` ([ModelType](#ModelType)) 模型类型 -- `features` (array[[ModelFeature](#ModelFeature)]) [optional] 支持功能列表 +- `features` (array\[[ModelFeature](#ModelFeature)\]) [optional] 支持功能列表 - `model_properties` (object) 模型属性 - `mode` ([LLMMode](#LLMMode)) 模式 (模型类型 `llm` 可用) - `context_size` (int) 上下文大小 (模型类型 `llm` `text-embedding` 可用) - `max_chunks` (int) 最大分块数量 (模型类型 `text-embedding ` `moderation` 可用) - `file_upload_limit` (int) 文件最大上传限制,单位:MB。(模型类型 `speech2text` 可用) - - `supported_file_extensions` (string) 支持文件扩展格式,如:mp3,mp4(模型类型 `speech2text` 可用) - - `default_voice` (string) 缺省音色,必选:alloy,echo,fable,onyx,nova,shimmer(模型类型 `tts` 可用) - - `voices` (list) 可选音色列表。 - - `mode` (string) 音色模型。(模型类型 `tts` 可用) - - `name` (string) 音色模型显示名称。(模型类型 `tts` 可用) - - `language` (string) 音色模型支持语言。(模型类型 `tts` 可用) - - `word_limit` (int) 单次转换字数限制,默认按段落分段(模型类型 `tts` 可用) - - `audio_type` (string) 支持音频文件扩展格式,如:mp3,wav(模型类型 `tts` 可用) - - `max_workers` (int) 支持文字音频转换并发任务数(模型类型 `tts` 可用) - - `max_characters_per_chunk` (int) 每块最大字符数 (模型类型 `moderation` 可用) -- `parameter_rules` (array[[ParameterRule](#ParameterRule)]) [optional] 模型调用参数规则 + - `supported_file_extensions` (string) 支持文件扩展格式,如:mp3,mp4(模型类型 `speech2text` 可用) + - `default_voice` (string) 缺省音色,必选:alloy,echo,fable,onyx,nova,shimmer(模型类型 `tts` 可用) + - `voices` (list) 可选音色列表。 + - `mode` (string) 音色模型。(模型类型 `tts` 可用) + - `name` (string) 音色模型显示名称。(模型类型 `tts` 可用) + - `language` (string) 音色模型支持语言。(模型类型 `tts` 可用) + - `word_limit` (int) 单次转换字数限制,默认按段落分段(模型类型 `tts` 可用) + - `audio_type` (string) 支持音频文件扩展格式,如:mp3,wav(模型类型 `tts` 可用) + - `max_workers` (int) 支持文字音频转换并发任务数(模型类型 `tts` 可用) + - `max_characters_per_chunk` (int) 每块最大字符数 (模型类型 `moderation` 可用) +- `parameter_rules` (array\[[ParameterRule](#ParameterRule)\]) [optional] 模型调用参数规则 - `pricing` ([PriceConfig](#PriceConfig)) [optional] 价格信息 - `deprecated` (bool) 是否废弃。若废弃,模型列表将不再展示,但已经配置的可以继续使用,默认 False。 @@ -75,6 +75,7 @@ - `predefined-model ` 预定义模型 表示用户只需要配置统一的供应商凭据即可使用供应商下的预定义模型。 + - `customizable-model` 自定义模型 用户需要新增每个模型的凭据配置。 @@ -106,7 +107,7 @@ - `name` (string) 调用模型实际参数名 - `use_template` (string) [optional] 使用模板 - + 默认预置了 5 种变量内容配置模板: - `temperature` @@ -114,7 +115,7 @@ - `frequency_penalty` - `presence_penalty` - `max_tokens` - + 可在 use_template 中直接设置模板变量名,将会使用 entities.defaults.PARAMETER_RULE_TEMPLATE 中的默认配置 不用设置除 `name` 和 `use_template` 之外的所有参数,若设置了额外的配置参数,将覆盖默认配置。 可参考 `openai/llm/gpt-3.5-turbo.yaml`。 @@ -157,7 +158,7 @@ ### ProviderCredentialSchema -- `credential_form_schemas` (array[[CredentialFormSchema](#CredentialFormSchema)]) 凭据表单规范 +- `credential_form_schemas` (array\[[CredentialFormSchema](#CredentialFormSchema)\]) 凭据表单规范 ### ModelCredentialSchema @@ -168,7 +169,7 @@ - `placeholder` (object) 模型提示内容 - `en_US`(string) 英文 - `zh_Hans`(string) [optional] 中文 -- `credential_form_schemas` (array[[CredentialFormSchema](#CredentialFormSchema)]) 凭据表单规范 +- `credential_form_schemas` (array\[[CredentialFormSchema](#CredentialFormSchema)\]) 凭据表单规范 ### CredentialFormSchema @@ -179,12 +180,12 @@ - `type` ([FormType](#FormType)) 表单项类型 - `required` (bool) 是否必填 - `default`(string) 默认值 -- `options` (array[[FormOption](#FormOption)]) 表单项为 `select` 或 `radio` 专有属性,定义下拉内容 +- `options` (array\[[FormOption](#FormOption)\]) 表单项为 `select` 或 `radio` 专有属性,定义下拉内容 - `placeholder`(object) 表单项为 `text-input `专有属性,表单项 PlaceHolder - `en_US`(string) 英文 - `zh_Hans` (string) [optional] 中文 - `max_length` (int) 表单项为`text-input`专有属性,定义输入最大长度,0 为不限制。 -- `show_on` (array[[FormShowOnObject](#FormShowOnObject)]) 当其他表单项值符合条件时显示,为空则始终显示。 +- `show_on` (array\[[FormShowOnObject](#FormShowOnObject)\]) 当其他表单项值符合条件时显示,为空则始终显示。 ### FormType @@ -200,7 +201,7 @@ - `en_US`(string) 英文 - `zh_Hans`(string) [optional] 中文 - `value` (string) 下拉选项值 -- `show_on` (array[[FormShowOnObject](#FormShowOnObject)]) 当其他表单项值符合条件时显示,为空则始终显示。 +- `show_on` (array\[[FormShowOnObject](#FormShowOnObject)\]) 当其他表单项值符合条件时显示,为空则始终显示。 ### FormShowOnObject diff --git a/api/core/model_runtime/entities/llm_entities.py b/api/core/model_runtime/entities/llm_entities.py index ace2c1f770..dc6032e405 100644 --- a/api/core/model_runtime/entities/llm_entities.py +++ b/api/core/model_runtime/entities/llm_entities.py @@ -1,7 +1,9 @@ +from __future__ import annotations + from collections.abc import Mapping, Sequence from decimal import Decimal from enum import StrEnum -from typing import Any, Optional +from typing import Any, Optional, TypedDict, Union from pydantic import BaseModel, Field @@ -18,6 +20,26 @@ class LLMMode(StrEnum): CHAT = "chat" +class LLMUsageMetadata(TypedDict, total=False): + """ + TypedDict for LLM usage metadata. + All fields are optional. + """ + + prompt_tokens: int + completion_tokens: int + total_tokens: int + prompt_unit_price: Union[float, str] + completion_unit_price: Union[float, str] + total_price: Union[float, str] + currency: str + prompt_price_unit: Union[float, str] + completion_price_unit: Union[float, str] + prompt_price: Union[float, str] + completion_price: Union[float, str] + latency: float + + class LLMUsage(ModelUsage): """ Model class for llm usage. @@ -54,23 +76,27 @@ class LLMUsage(ModelUsage): ) @classmethod - def from_metadata(cls, metadata: dict) -> "LLMUsage": + def from_metadata(cls, metadata: LLMUsageMetadata) -> LLMUsage: """ Create LLMUsage instance from metadata dictionary with default values. Args: - metadata: Dictionary containing usage metadata + metadata: TypedDict containing usage metadata Returns: LLMUsage instance with values from metadata or defaults """ - total_tokens = metadata.get("total_tokens", 0) + prompt_tokens = metadata.get("prompt_tokens", 0) completion_tokens = metadata.get("completion_tokens", 0) - if total_tokens > 0 and completion_tokens == 0: - completion_tokens = total_tokens + total_tokens = metadata.get("total_tokens", 0) + + # If total_tokens is not provided but prompt and completion tokens are, + # calculate total_tokens + if total_tokens == 0 and (prompt_tokens > 0 or completion_tokens > 0): + total_tokens = prompt_tokens + completion_tokens return cls( - prompt_tokens=metadata.get("prompt_tokens", 0), + prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=total_tokens, prompt_unit_price=Decimal(str(metadata.get("prompt_unit_price", 0))), @@ -84,7 +110,7 @@ class LLMUsage(ModelUsage): latency=metadata.get("latency", 0.0), ) - def plus(self, other: "LLMUsage") -> "LLMUsage": + def plus(self, other: LLMUsage) -> LLMUsage: """ Add two LLMUsage instances together. @@ -109,7 +135,7 @@ class LLMUsage(ModelUsage): latency=self.latency + other.latency, ) - def __add__(self, other: "LLMUsage") -> "LLMUsage": + def __add__(self, other: LLMUsage) -> LLMUsage: """ Overload the + operator to add two LLMUsage instances. diff --git a/api/core/model_runtime/model_providers/__base/large_language_model.py b/api/core/model_runtime/model_providers/__base/large_language_model.py index e2cc576f83..ce378b443d 100644 --- a/api/core/model_runtime/model_providers/__base/large_language_model.py +++ b/api/core/model_runtime/model_providers/__base/large_language_model.py @@ -440,7 +440,9 @@ class LargeLanguageModel(AIModel): if callback.raise_error: raise e else: - logger.warning(f"Callback {callback.__class__.__name__} on_before_invoke failed with error {e}") + logger.warning( + "Callback %s on_before_invoke failed with error %s", callback.__class__.__name__, e + ) def _trigger_new_chunk_callbacks( self, @@ -487,7 +489,7 @@ class LargeLanguageModel(AIModel): if callback.raise_error: raise e else: - logger.warning(f"Callback {callback.__class__.__name__} on_new_chunk failed with error {e}") + logger.warning("Callback %s on_new_chunk failed with error %s", callback.__class__.__name__, e) def _trigger_after_invoke_callbacks( self, @@ -535,7 +537,9 @@ class LargeLanguageModel(AIModel): if callback.raise_error: raise e else: - logger.warning(f"Callback {callback.__class__.__name__} on_after_invoke failed with error {e}") + logger.warning( + "Callback %s on_after_invoke failed with error %s", callback.__class__.__name__, e + ) def _trigger_invoke_error_callbacks( self, @@ -583,4 +587,6 @@ class LargeLanguageModel(AIModel): if callback.raise_error: raise e else: - logger.warning(f"Callback {callback.__class__.__name__} on_invoke_error failed with error {e}") + logger.warning( + "Callback %s on_invoke_error failed with error %s", callback.__class__.__name__, e + ) diff --git a/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py b/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenizer.py similarity index 96% rename from api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py rename to api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenizer.py index b7db0b78bc..68d30112d9 100644 --- a/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py +++ b/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenizer.py @@ -1,10 +1,10 @@ import logging from threading import Lock -from typing import Any +from typing import Any, Optional logger = logging.getLogger(__name__) -_tokenizer: Any = None +_tokenizer: Optional[Any] = None _lock = Lock() diff --git a/api/core/model_runtime/model_providers/model_provider_factory.py b/api/core/model_runtime/model_providers/model_provider_factory.py index ad46f64ec3..f8590b38f8 100644 --- a/api/core/model_runtime/model_providers/model_provider_factory.py +++ b/api/core/model_runtime/model_providers/model_provider_factory.py @@ -257,11 +257,6 @@ class ModelProviderFactory: # scan all providers plugin_model_provider_entities = self.get_plugin_model_providers() - # convert provider_configs to dict - provider_credentials_dict = {} - for provider_config in provider_configs: - provider_credentials_dict[provider_config.provider] = provider_config.credentials - # traverse all model_provider_extensions providers = [] for plugin_model_provider_entity in plugin_model_provider_entities: diff --git a/api/core/model_runtime/schema_validators/common_validator.py b/api/core/model_runtime/schema_validators/common_validator.py index 810a7c4c44..b689007401 100644 --- a/api/core/model_runtime/schema_validators/common_validator.py +++ b/api/core/model_runtime/schema_validators/common_validator.py @@ -68,7 +68,7 @@ class CommonValidator: if credential_form_schema.max_length: if len(value) > credential_form_schema.max_length: raise ValueError( - f"Variable {credential_form_schema.variable} length should not" + f"Variable {credential_form_schema.variable} length should not be" f" greater than {credential_form_schema.max_length}" ) diff --git a/api/core/model_runtime/utils/encoders.py b/api/core/model_runtime/utils/encoders.py index a5c11aeeba..f65339fbfc 100644 --- a/api/core/model_runtime/utils/encoders.py +++ b/api/core/model_runtime/utils/encoders.py @@ -151,12 +151,9 @@ def jsonable_encoder( return format(obj, "f") if isinstance(obj, dict): encoded_dict = {} - allowed_keys = set(obj.keys()) for key, value in obj.items(): - if ( - (not sqlalchemy_safe or (not isinstance(key, str)) or (not key.startswith("_sa"))) - and (value is not None or not exclude_none) - and key in allowed_keys + if (not sqlalchemy_safe or (not isinstance(key, str)) or (not key.startswith("_sa"))) and ( + value is not None or not exclude_none ): encoded_key = jsonable_encoder( key, diff --git a/api/core/model_runtime/utils/helper.py b/api/core/model_runtime/utils/helper.py deleted file mode 100644 index 5e8a723ec7..0000000000 --- a/api/core/model_runtime/utils/helper.py +++ /dev/null @@ -1,10 +0,0 @@ -import pydantic -from pydantic import BaseModel - - -def dump_model(model: BaseModel) -> dict: - if hasattr(pydantic, "model_dump"): - # FIXME mypy error, try to fix it instead of using type: ignore - return pydantic.model_dump(model) # type: ignore - else: - return model.model_dump() diff --git a/api/core/moderation/api/api.py b/api/core/moderation/api/api.py index 332381555b..af51b72cd5 100644 --- a/api/core/moderation/api/api.py +++ b/api/core/moderation/api/api.py @@ -1,6 +1,6 @@ from typing import Optional -from pydantic import BaseModel +from pydantic import BaseModel, Field from core.extension.api_based_extension_requestor import APIBasedExtensionPoint, APIBasedExtensionRequestor from core.helper.encrypter import decrypt_token @@ -11,7 +11,7 @@ from models.api_based_extension import APIBasedExtension class ModerationInputParams(BaseModel): app_id: str = "" - inputs: dict = {} + inputs: dict = Field(default_factory=dict) query: str = "" diff --git a/api/core/moderation/base.py b/api/core/moderation/base.py index d8c392d097..99bd0049c0 100644 --- a/api/core/moderation/base.py +++ b/api/core/moderation/base.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod from enum import Enum from typing import Optional -from pydantic import BaseModel +from pydantic import BaseModel, Field from core.extension.extensible import Extensible, ExtensionModule @@ -16,7 +16,7 @@ class ModerationInputsResult(BaseModel): flagged: bool = False action: ModerationAction preset_response: str = "" - inputs: dict = {} + inputs: dict = Field(default_factory=dict) query: str = "" diff --git a/api/core/moderation/output_moderation.py b/api/core/moderation/output_moderation.py index 2ec315417f..b39db4b7ff 100644 --- a/api/core/moderation/output_moderation.py +++ b/api/core/moderation/output_moderation.py @@ -136,6 +136,6 @@ class OutputModeration(BaseModel): result: ModerationOutputsResult = moderation_factory.moderation_for_outputs(moderation_buffer) return result except Exception as e: - logger.exception(f"Moderation Output error, app_id: {app_id}") + logger.exception("Moderation Output error, app_id: %s", app_id) return None diff --git a/api/core/ops/aliyun_trace/aliyun_trace.py b/api/core/ops/aliyun_trace/aliyun_trace.py index cf367efdf0..82f54582ed 100644 --- a/api/core/ops/aliyun_trace/aliyun_trace.py +++ b/api/core/ops/aliyun_trace/aliyun_trace.py @@ -4,7 +4,7 @@ from collections.abc import Sequence from typing import Optional from urllib.parse import urljoin -from opentelemetry.trace import Status, StatusCode +from opentelemetry.trace import Link, Status, StatusCode from sqlalchemy.orm import Session, sessionmaker from core.ops.aliyun_trace.data_exporter.traceclient import ( @@ -12,6 +12,7 @@ from core.ops.aliyun_trace.data_exporter.traceclient import ( convert_datetime_to_nanoseconds, convert_to_span_id, convert_to_trace_id, + create_link, generate_span_id, ) from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData @@ -97,14 +98,16 @@ class AliyunDataTrace(BaseTraceInstance): try: return self.trace_client.get_project_url() except Exception as e: - logger.info(f"Aliyun get run url failed: {str(e)}", exc_info=True) + logger.info("Aliyun get run url failed: %s", str(e), exc_info=True) raise ValueError(f"Aliyun get run url failed: {str(e)}") def workflow_trace(self, trace_info: WorkflowTraceInfo): - external_trace_id = trace_info.metadata.get("external_trace_id") - trace_id = external_trace_id or convert_to_trace_id(trace_info.workflow_run_id) + trace_id = convert_to_trace_id(trace_info.workflow_run_id) + links = [] + if trace_info.trace_id: + links.append(create_link(trace_id_str=trace_info.trace_id)) workflow_span_id = convert_to_span_id(trace_info.workflow_run_id, "workflow") - self.add_workflow_span(trace_id, workflow_span_id, trace_info) + self.add_workflow_span(trace_id, workflow_span_id, trace_info, links) workflow_node_executions = self.get_workflow_node_executions(trace_info) for node_execution in workflow_node_executions: @@ -130,6 +133,10 @@ class AliyunDataTrace(BaseTraceInstance): status = Status(StatusCode.ERROR, trace_info.error) trace_id = convert_to_trace_id(message_id) + links = [] + if trace_info.trace_id: + links.append(create_link(trace_id_str=trace_info.trace_id)) + message_span_id = convert_to_span_id(message_id, "message") message_span = SpanData( trace_id=trace_id, @@ -139,7 +146,7 @@ class AliyunDataTrace(BaseTraceInstance): start_time=convert_datetime_to_nanoseconds(trace_info.start_time), end_time=convert_datetime_to_nanoseconds(trace_info.end_time), attributes={ - GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""), + GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id") or "", GEN_AI_USER_ID: str(user_id), GEN_AI_SPAN_KIND: GenAISpanKind.CHAIN.value, GEN_AI_FRAMEWORK: "dify", @@ -147,6 +154,7 @@ class AliyunDataTrace(BaseTraceInstance): OUTPUT_VALUE: str(trace_info.outputs), }, status=status, + links=links, ) self.trace_client.add_span(message_span) @@ -161,12 +169,12 @@ class AliyunDataTrace(BaseTraceInstance): start_time=convert_datetime_to_nanoseconds(trace_info.start_time), end_time=convert_datetime_to_nanoseconds(trace_info.end_time), attributes={ - GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""), + GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id") or "", GEN_AI_USER_ID: str(user_id), GEN_AI_SPAN_KIND: GenAISpanKind.LLM.value, GEN_AI_FRAMEWORK: "dify", - GEN_AI_MODEL_NAME: trace_info.metadata.get("ls_model_name", ""), - GEN_AI_SYSTEM: trace_info.metadata.get("ls_provider", ""), + GEN_AI_MODEL_NAME: trace_info.metadata.get("ls_model_name") or "", + GEN_AI_SYSTEM: trace_info.metadata.get("ls_provider") or "", GEN_AI_USAGE_INPUT_TOKENS: str(trace_info.message_tokens), GEN_AI_USAGE_OUTPUT_TOKENS: str(trace_info.answer_tokens), GEN_AI_USAGE_TOTAL_TOKENS: str(trace_info.total_tokens), @@ -186,9 +194,14 @@ class AliyunDataTrace(BaseTraceInstance): return message_id = trace_info.message_id + trace_id = convert_to_trace_id(message_id) + links = [] + if trace_info.trace_id: + links.append(create_link(trace_id_str=trace_info.trace_id)) + documents_data = extract_retrieval_documents(trace_info.documents) dataset_retrieval_span = SpanData( - trace_id=convert_to_trace_id(message_id), + trace_id=trace_id, parent_span_id=convert_to_span_id(message_id, "message"), span_id=generate_span_id(), name="dataset_retrieval", @@ -202,6 +215,7 @@ class AliyunDataTrace(BaseTraceInstance): INPUT_VALUE: str(trace_info.inputs), OUTPUT_VALUE: json.dumps(documents_data, ensure_ascii=False), }, + links=links, ) self.trace_client.add_span(dataset_retrieval_span) @@ -214,8 +228,13 @@ class AliyunDataTrace(BaseTraceInstance): if trace_info.error: status = Status(StatusCode.ERROR, trace_info.error) + trace_id = convert_to_trace_id(message_id) + links = [] + if trace_info.trace_id: + links.append(create_link(trace_id_str=trace_info.trace_id)) + tool_span = SpanData( - trace_id=convert_to_trace_id(message_id), + trace_id=trace_id, parent_span_id=convert_to_span_id(message_id, "message"), span_id=generate_span_id(), name=trace_info.tool_name, @@ -231,6 +250,7 @@ class AliyunDataTrace(BaseTraceInstance): OUTPUT_VALUE: str(trace_info.tool_outputs), }, status=status, + links=links, ) self.trace_client.add_span(tool_span) @@ -286,7 +306,7 @@ class AliyunDataTrace(BaseTraceInstance): node_span = self.build_workflow_task_span(trace_id, workflow_span_id, trace_info, node_execution) return node_span except Exception as e: - logging.debug(f"Error occurred in build_workflow_node_span: {e}", exc_info=True) + logging.debug("Error occurred in build_workflow_node_span: %s", e, exc_info=True) return None def get_workflow_node_status(self, node_execution: WorkflowNodeExecution) -> Status: @@ -386,21 +406,23 @@ class AliyunDataTrace(BaseTraceInstance): GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id") or "", GEN_AI_SPAN_KIND: GenAISpanKind.LLM.value, GEN_AI_FRAMEWORK: "dify", - GEN_AI_MODEL_NAME: process_data.get("model_name", ""), - GEN_AI_SYSTEM: process_data.get("model_provider", ""), + GEN_AI_MODEL_NAME: process_data.get("model_name") or "", + GEN_AI_SYSTEM: process_data.get("model_provider") or "", GEN_AI_USAGE_INPUT_TOKENS: str(usage_data.get("prompt_tokens", 0)), GEN_AI_USAGE_OUTPUT_TOKENS: str(usage_data.get("completion_tokens", 0)), GEN_AI_USAGE_TOTAL_TOKENS: str(usage_data.get("total_tokens", 0)), GEN_AI_PROMPT: json.dumps(process_data.get("prompts", []), ensure_ascii=False), GEN_AI_COMPLETION: str(outputs.get("text", "")), - GEN_AI_RESPONSE_FINISH_REASON: outputs.get("finish_reason", ""), + GEN_AI_RESPONSE_FINISH_REASON: outputs.get("finish_reason") or "", INPUT_VALUE: json.dumps(process_data.get("prompts", []), ensure_ascii=False), OUTPUT_VALUE: str(outputs.get("text", "")), }, status=self.get_workflow_node_status(node_execution), ) - def add_workflow_span(self, trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo): + def add_workflow_span( + self, trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo, links: Sequence[Link] + ): message_span_id = None if trace_info.message_id: message_span_id = convert_to_span_id(trace_info.message_id, "message") @@ -421,10 +443,11 @@ class AliyunDataTrace(BaseTraceInstance): GEN_AI_USER_ID: str(user_id), GEN_AI_SPAN_KIND: GenAISpanKind.CHAIN.value, GEN_AI_FRAMEWORK: "dify", - INPUT_VALUE: trace_info.workflow_run_inputs.get("sys.query", ""), + INPUT_VALUE: trace_info.workflow_run_inputs.get("sys.query") or "", OUTPUT_VALUE: json.dumps(trace_info.workflow_run_outputs, ensure_ascii=False), }, status=status, + links=links, ) self.trace_client.add_span(message_span) @@ -443,6 +466,7 @@ class AliyunDataTrace(BaseTraceInstance): OUTPUT_VALUE: json.dumps(trace_info.workflow_run_outputs, ensure_ascii=False), }, status=status, + links=links, ) self.trace_client.add_span(workflow_span) @@ -451,8 +475,14 @@ class AliyunDataTrace(BaseTraceInstance): status: Status = Status(StatusCode.OK) if trace_info.error: status = Status(StatusCode.ERROR, trace_info.error) + + trace_id = convert_to_trace_id(message_id) + links = [] + if trace_info.trace_id: + links.append(create_link(trace_id_str=trace_info.trace_id)) + suggested_question_span = SpanData( - trace_id=convert_to_trace_id(message_id), + trace_id=trace_id, parent_span_id=convert_to_span_id(message_id, "message"), span_id=convert_to_span_id(message_id, "suggested_question"), name="suggested_question", @@ -461,14 +491,15 @@ class AliyunDataTrace(BaseTraceInstance): attributes={ GEN_AI_SPAN_KIND: GenAISpanKind.LLM.value, GEN_AI_FRAMEWORK: "dify", - GEN_AI_MODEL_NAME: trace_info.metadata.get("ls_model_name", ""), - GEN_AI_SYSTEM: trace_info.metadata.get("ls_provider", ""), + GEN_AI_MODEL_NAME: trace_info.metadata.get("ls_model_name") or "", + GEN_AI_SYSTEM: trace_info.metadata.get("ls_provider") or "", GEN_AI_PROMPT: json.dumps(trace_info.inputs, ensure_ascii=False), GEN_AI_COMPLETION: json.dumps(trace_info.suggested_question, ensure_ascii=False), INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False), OUTPUT_VALUE: json.dumps(trace_info.suggested_question, ensure_ascii=False), }, status=status, + links=links, ) self.trace_client.add_span(suggested_question_span) diff --git a/api/core/ops/aliyun_trace/data_exporter/traceclient.py b/api/core/ops/aliyun_trace/data_exporter/traceclient.py index ba5ac3f420..3eb7c30d55 100644 --- a/api/core/ops/aliyun_trace/data_exporter/traceclient.py +++ b/api/core/ops/aliyun_trace/data_exporter/traceclient.py @@ -16,6 +16,7 @@ from opentelemetry.sdk.resources import Resource from opentelemetry.sdk.trace import ReadableSpan from opentelemetry.sdk.util.instrumentation import InstrumentationScope from opentelemetry.semconv.resource import ResourceAttributes +from opentelemetry.trace import Link, SpanContext, TraceFlags from configs import dify_config from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData @@ -69,10 +70,10 @@ class TraceClient: if response.status_code == 405: return True else: - logger.debug(f"AliyunTrace API check failed: Unexpected status code: {response.status_code}") + logger.debug("AliyunTrace API check failed: Unexpected status code: %s", response.status_code) return False except requests.exceptions.RequestException as e: - logger.debug(f"AliyunTrace API check failed: {str(e)}") + logger.debug("AliyunTrace API check failed: %s", str(e)) raise ValueError(f"AliyunTrace API check failed: {str(e)}") def get_project_url(self): @@ -109,7 +110,7 @@ class TraceClient: try: self.exporter.export(spans_to_export) except Exception as e: - logger.debug(f"Error exporting spans: {e}") + logger.debug("Error exporting spans: %s", e) def shutdown(self): with self.condition: @@ -166,6 +167,16 @@ class SpanBuilder: return span +def create_link(trace_id_str: str) -> Link: + placeholder_span_id = 0x0000000000000000 + trace_id = int(trace_id_str, 16) + span_context = SpanContext( + trace_id=trace_id, span_id=placeholder_span_id, is_remote=False, trace_flags=TraceFlags(TraceFlags.SAMPLED) + ) + + return Link(span_context) + + def generate_span_id() -> int: span_id = random.getrandbits(64) while span_id == INVALID_SPAN_ID: @@ -181,15 +192,21 @@ def convert_to_trace_id(uuid_v4: Optional[str]) -> int: raise ValueError(f"Invalid UUID input: {e}") +def convert_string_to_id(string: Optional[str]) -> int: + if not string: + return generate_span_id() + hash_bytes = hashlib.sha256(string.encode("utf-8")).digest() + id = int.from_bytes(hash_bytes[:8], byteorder="big", signed=False) + return id + + def convert_to_span_id(uuid_v4: Optional[str], span_type: str) -> int: try: uuid_obj = uuid.UUID(uuid_v4) except Exception as e: raise ValueError(f"Invalid UUID input: {e}") combined_key = f"{uuid_obj.hex}-{span_type}" - hash_bytes = hashlib.sha256(combined_key.encode("utf-8")).digest() - span_id = int.from_bytes(hash_bytes[:8], byteorder="big", signed=False) - return span_id + return convert_string_to_id(combined_key) def convert_datetime_to_nanoseconds(start_time_a: Optional[datetime]) -> Optional[int]: diff --git a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py index 1b72a4775a..e7c90c1229 100644 --- a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py +++ b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py @@ -4,6 +4,7 @@ import logging import os from datetime import datetime, timedelta from typing import Any, Optional, Union, cast +from urllib.parse import urlparse from openinference.semconv.trace import OpenInferenceSpanKindValues, SpanAttributes from opentelemetry import trace @@ -40,8 +41,14 @@ def setup_tracer(arize_phoenix_config: ArizeConfig | PhoenixConfig) -> tuple[tra try: # Choose the appropriate exporter based on config type exporter: Union[GrpcOTLPSpanExporter, HttpOTLPSpanExporter] + + # Inspect the provided endpoint to determine its structure + parsed = urlparse(arize_phoenix_config.endpoint) + base_endpoint = f"{parsed.scheme}://{parsed.netloc}" + path = parsed.path.rstrip("/") + if isinstance(arize_phoenix_config, ArizeConfig): - arize_endpoint = f"{arize_phoenix_config.endpoint}/v1" + arize_endpoint = f"{base_endpoint}/v1" arize_headers = { "api_key": arize_phoenix_config.api_key or "", "space_id": arize_phoenix_config.space_id or "", @@ -53,7 +60,7 @@ def setup_tracer(arize_phoenix_config: ArizeConfig | PhoenixConfig) -> tuple[tra timeout=30, ) else: - phoenix_endpoint = f"{arize_phoenix_config.endpoint}/v1/traces" + phoenix_endpoint = f"{base_endpoint}{path}/v1/traces" phoenix_headers = { "api_key": arize_phoenix_config.api_key or "", "authorization": f"Bearer {arize_phoenix_config.api_key or ''}", @@ -77,10 +84,10 @@ def setup_tracer(arize_phoenix_config: ArizeConfig | PhoenixConfig) -> tuple[tra # Create a named tracer instead of setting the global provider tracer_name = f"arize_phoenix_tracer_{arize_phoenix_config.project}" - logger.info(f"[Arize/Phoenix] Created tracer with name: {tracer_name}") + logger.info("[Arize/Phoenix] Created tracer with name: %s", tracer_name) return cast(trace_sdk.Tracer, provider.get_tracer(tracer_name)), processor except Exception as e: - logger.error(f"[Arize/Phoenix] Failed to setup the tracer: {str(e)}", exc_info=True) + logger.error("[Arize/Phoenix] Failed to setup the tracer: %s", str(e), exc_info=True) raise @@ -91,16 +98,21 @@ def datetime_to_nanos(dt: Optional[datetime]) -> int: return int(dt.timestamp() * 1_000_000_000) -def uuid_to_trace_id(string: Optional[str]) -> int: - """Convert UUID string to a valid trace ID (16-byte integer).""" +def string_to_trace_id128(string: Optional[str]) -> int: + """ + Convert any input string into a stable 128-bit integer trace ID. + + This uses SHA-256 hashing and takes the first 16 bytes (128 bits) of the digest. + It's suitable for generating consistent, unique identifiers from strings. + """ if string is None: string = "" hash_object = hashlib.sha256(string.encode()) - # Take the first 16 bytes (128 bits) of the hash + # Take the first 16 bytes (128 bits) of the hash digest digest = hash_object.digest()[:16] - # Convert to integer (128 bits) + # Convert to a 128-bit integer return int.from_bytes(digest, byteorder="big") @@ -120,7 +132,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance): self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001") def trace(self, trace_info: BaseTraceInfo): - logger.info(f"[Arize/Phoenix] Trace: {trace_info}") + logger.info("[Arize/Phoenix] Trace: %s", trace_info) try: if isinstance(trace_info, WorkflowTraceInfo): self.workflow_trace(trace_info) @@ -138,7 +150,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance): self.generate_name_trace(trace_info) except Exception as e: - logger.error(f"[Arize/Phoenix] Error in the trace: {str(e)}", exc_info=True) + logger.error("[Arize/Phoenix] Error in the trace: %s", str(e), exc_info=True) raise def workflow_trace(self, trace_info: WorkflowTraceInfo): @@ -153,8 +165,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance): } workflow_metadata.update(trace_info.metadata) - external_trace_id = trace_info.metadata.get("external_trace_id") - trace_id = external_trace_id or uuid_to_trace_id(trace_info.workflow_run_id) + trace_id = string_to_trace_id128(trace_info.trace_id or trace_info.workflow_run_id) span_id = RandomIdGenerator().generate_span_id() context = SpanContext( trace_id=trace_id, @@ -310,7 +321,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance): SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id, } - trace_id = uuid_to_trace_id(trace_info.message_id) + trace_id = string_to_trace_id128(trace_info.trace_id or trace_info.message_id) message_span_id = RandomIdGenerator().generate_span_id() span_context = SpanContext( trace_id=trace_id, @@ -406,7 +417,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance): } metadata.update(trace_info.metadata) - trace_id = uuid_to_trace_id(trace_info.message_id) + trace_id = string_to_trace_id128(trace_info.message_id) span_id = RandomIdGenerator().generate_span_id() context = SpanContext( trace_id=trace_id, @@ -468,7 +479,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance): } metadata.update(trace_info.metadata) - trace_id = uuid_to_trace_id(trace_info.message_id) + trace_id = string_to_trace_id128(trace_info.message_id) span_id = RandomIdGenerator().generate_span_id() context = SpanContext( trace_id=trace_id, @@ -521,7 +532,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance): } metadata.update(trace_info.metadata) - trace_id = uuid_to_trace_id(trace_info.message_id) + trace_id = string_to_trace_id128(trace_info.message_id) span_id = RandomIdGenerator().generate_span_id() context = SpanContext( trace_id=trace_id, @@ -568,9 +579,9 @@ class ArizePhoenixDataTrace(BaseTraceInstance): "tool_config": json.dumps(trace_info.tool_config, ensure_ascii=False), } - trace_id = uuid_to_trace_id(trace_info.message_id) + trace_id = string_to_trace_id128(trace_info.message_id) tool_span_id = RandomIdGenerator().generate_span_id() - logger.info(f"[Arize/Phoenix] Creating tool trace with trace_id: {trace_id}, span_id: {tool_span_id}") + logger.info("[Arize/Phoenix] Creating tool trace with trace_id: %s, span_id: %s", trace_id, tool_span_id) # Create span context with the same trace_id as the parent # todo: Create with the appropriate parent span context, so that the tool span is @@ -629,7 +640,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance): } metadata.update(trace_info.metadata) - trace_id = uuid_to_trace_id(trace_info.message_id) + trace_id = string_to_trace_id128(trace_info.message_id) span_id = RandomIdGenerator().generate_span_id() context = SpanContext( trace_id=trace_id, @@ -673,7 +684,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance): span.set_attribute("test", "true") return True except Exception as e: - logger.info(f"[Arize/Phoenix] API check failed: {str(e)}", exc_info=True) + logger.info("[Arize/Phoenix] API check failed: %s", str(e), exc_info=True) raise ValueError(f"[Arize/Phoenix] API check failed: {str(e)}") def get_project_url(self): @@ -683,7 +694,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance): else: return f"{self.arize_phoenix_config.endpoint}/projects/" except Exception as e: - logger.info(f"[Arize/Phoenix] Get run url failed: {str(e)}", exc_info=True) + logger.info("[Arize/Phoenix] Get run url failed: %s", str(e), exc_info=True) raise ValueError(f"[Arize/Phoenix] Get run url failed: {str(e)}") def _get_workflow_nodes(self, workflow_run_id: str): diff --git a/api/core/ops/entities/config_entity.py b/api/core/ops/entities/config_entity.py index 89ff0cfded..851a77fbc1 100644 --- a/api/core/ops/entities/config_entity.py +++ b/api/core/ops/entities/config_entity.py @@ -87,7 +87,7 @@ class PhoenixConfig(BaseTracingConfig): @field_validator("endpoint") @classmethod def endpoint_validator(cls, v, info: ValidationInfo): - return cls.validate_endpoint_url(v, "https://app.phoenix.arize.com") + return validate_url_with_path(v, "https://app.phoenix.arize.com") class LangfuseConfig(BaseTracingConfig): @@ -102,7 +102,7 @@ class LangfuseConfig(BaseTracingConfig): @field_validator("host") @classmethod def host_validator(cls, v, info: ValidationInfo): - return cls.validate_endpoint_url(v, "https://api.langfuse.com") + return validate_url_with_path(v, "https://api.langfuse.com") class LangSmithConfig(BaseTracingConfig): diff --git a/api/core/ops/entities/trace_entity.py b/api/core/ops/entities/trace_entity.py index 151fa2aaf4..3bad5c92fb 100644 --- a/api/core/ops/entities/trace_entity.py +++ b/api/core/ops/entities/trace_entity.py @@ -14,6 +14,7 @@ class BaseTraceInfo(BaseModel): start_time: Optional[datetime] = None end_time: Optional[datetime] = None metadata: dict[str, Any] + trace_id: Optional[str] = None @field_validator("inputs", "outputs") @classmethod diff --git a/api/core/ops/langfuse_trace/langfuse_trace.py b/api/core/ops/langfuse_trace/langfuse_trace.py index f4a59ef3a7..3a03d9f4fe 100644 --- a/api/core/ops/langfuse_trace/langfuse_trace.py +++ b/api/core/ops/langfuse_trace/langfuse_trace.py @@ -67,14 +67,13 @@ class LangFuseDataTrace(BaseTraceInstance): self.generate_name_trace(trace_info) def workflow_trace(self, trace_info: WorkflowTraceInfo): - external_trace_id = trace_info.metadata.get("external_trace_id") - trace_id = external_trace_id or trace_info.workflow_run_id + trace_id = trace_info.trace_id or trace_info.workflow_run_id user_id = trace_info.metadata.get("user_id") metadata = trace_info.metadata metadata["workflow_app_log_id"] = trace_info.workflow_app_log_id if trace_info.message_id: - trace_id = external_trace_id or trace_info.message_id + trace_id = trace_info.trace_id or trace_info.message_id name = TraceTaskName.MESSAGE_TRACE.value trace_data = LangfuseTrace( id=trace_id, @@ -250,8 +249,10 @@ class LangFuseDataTrace(BaseTraceInstance): user_id = end_user_data.session_id metadata["user_id"] = user_id + trace_id = trace_info.trace_id or message_id + trace_data = LangfuseTrace( - id=message_id, + id=trace_id, user_id=user_id, name=TraceTaskName.MESSAGE_TRACE.value, input={ @@ -285,7 +286,7 @@ class LangFuseDataTrace(BaseTraceInstance): langfuse_generation_data = LangfuseGeneration( name="llm", - trace_id=message_id, + trace_id=trace_id, start_time=trace_info.start_time, end_time=trace_info.end_time, model=message_data.model_id, @@ -311,7 +312,7 @@ class LangFuseDataTrace(BaseTraceInstance): "preset_response": trace_info.preset_response, "inputs": trace_info.inputs, }, - trace_id=trace_info.message_id, + trace_id=trace_info.trace_id or trace_info.message_id, start_time=trace_info.start_time or trace_info.message_data.created_at, end_time=trace_info.end_time or trace_info.message_data.created_at, metadata=trace_info.metadata, @@ -334,7 +335,7 @@ class LangFuseDataTrace(BaseTraceInstance): name=TraceTaskName.SUGGESTED_QUESTION_TRACE.value, input=trace_info.inputs, output=str(trace_info.suggested_question), - trace_id=trace_info.message_id, + trace_id=trace_info.trace_id or trace_info.message_id, start_time=trace_info.start_time, end_time=trace_info.end_time, metadata=trace_info.metadata, @@ -352,7 +353,7 @@ class LangFuseDataTrace(BaseTraceInstance): name=TraceTaskName.DATASET_RETRIEVAL_TRACE.value, input=trace_info.inputs, output={"documents": trace_info.documents}, - trace_id=trace_info.message_id, + trace_id=trace_info.trace_id or trace_info.message_id, start_time=trace_info.start_time or trace_info.message_data.created_at, end_time=trace_info.end_time or trace_info.message_data.updated_at, metadata=trace_info.metadata, @@ -365,7 +366,7 @@ class LangFuseDataTrace(BaseTraceInstance): name=trace_info.tool_name, input=trace_info.tool_inputs, output=trace_info.tool_outputs, - trace_id=trace_info.message_id, + trace_id=trace_info.trace_id or trace_info.message_id, start_time=trace_info.start_time, end_time=trace_info.end_time, metadata=trace_info.metadata, @@ -440,7 +441,7 @@ class LangFuseDataTrace(BaseTraceInstance): try: return self.langfuse_client.auth_check() except Exception as e: - logger.debug(f"LangFuse API check failed: {str(e)}") + logger.debug("LangFuse API check failed: %s", str(e)) raise ValueError(f"LangFuse API check failed: {str(e)}") def get_project_key(self): @@ -448,5 +449,5 @@ class LangFuseDataTrace(BaseTraceInstance): projects = self.langfuse_client.client.projects.get() return projects.data[0].id except Exception as e: - logger.debug(f"LangFuse get project key failed: {str(e)}") + logger.debug("LangFuse get project key failed: %s", str(e)) raise ValueError(f"LangFuse get project key failed: {str(e)}") diff --git a/api/core/ops/langsmith_trace/langsmith_trace.py b/api/core/ops/langsmith_trace/langsmith_trace.py index c97846dc9b..f9e5128e89 100644 --- a/api/core/ops/langsmith_trace/langsmith_trace.py +++ b/api/core/ops/langsmith_trace/langsmith_trace.py @@ -65,8 +65,7 @@ class LangSmithDataTrace(BaseTraceInstance): self.generate_name_trace(trace_info) def workflow_trace(self, trace_info: WorkflowTraceInfo): - external_trace_id = trace_info.metadata.get("external_trace_id") - trace_id = external_trace_id or trace_info.message_id or trace_info.workflow_run_id + trace_id = trace_info.trace_id or trace_info.message_id or trace_info.workflow_run_id if trace_info.start_time is None: trace_info.start_time = datetime.now() message_dotted_order = ( @@ -290,7 +289,7 @@ class LangSmithDataTrace(BaseTraceInstance): reference_example_id=None, input_attachments={}, output_attachments={}, - trace_id=None, + trace_id=trace_info.trace_id, dotted_order=None, parent_run_id=None, ) @@ -319,7 +318,7 @@ class LangSmithDataTrace(BaseTraceInstance): reference_example_id=None, input_attachments={}, output_attachments={}, - trace_id=None, + trace_id=trace_info.trace_id, dotted_order=None, id=str(uuid.uuid4()), ) @@ -351,7 +350,7 @@ class LangSmithDataTrace(BaseTraceInstance): reference_example_id=None, input_attachments={}, output_attachments={}, - trace_id=None, + trace_id=trace_info.trace_id, dotted_order=None, error="", file_list=[], @@ -381,7 +380,7 @@ class LangSmithDataTrace(BaseTraceInstance): reference_example_id=None, input_attachments={}, output_attachments={}, - trace_id=None, + trace_id=trace_info.trace_id, dotted_order=None, error="", file_list=[], @@ -410,7 +409,7 @@ class LangSmithDataTrace(BaseTraceInstance): reference_example_id=None, input_attachments={}, output_attachments={}, - trace_id=None, + trace_id=trace_info.trace_id, dotted_order=None, error="", file_list=[], @@ -440,7 +439,7 @@ class LangSmithDataTrace(BaseTraceInstance): reference_example_id=None, input_attachments={}, output_attachments={}, - trace_id=None, + trace_id=trace_info.trace_id, dotted_order=None, error=trace_info.error or "", ) @@ -465,7 +464,7 @@ class LangSmithDataTrace(BaseTraceInstance): reference_example_id=None, input_attachments={}, output_attachments={}, - trace_id=None, + trace_id=trace_info.trace_id, dotted_order=None, error="", file_list=[], @@ -504,7 +503,7 @@ class LangSmithDataTrace(BaseTraceInstance): self.langsmith_client.delete_project(project_name=random_project_name) return True except Exception as e: - logger.debug(f"LangSmith API check failed: {str(e)}") + logger.debug("LangSmith API check failed: %s", str(e)) raise ValueError(f"LangSmith API check failed: {str(e)}") def get_project_url(self): @@ -523,5 +522,5 @@ class LangSmithDataTrace(BaseTraceInstance): ) return project_url.split("/r/")[0] except Exception as e: - logger.debug(f"LangSmith get run url failed: {str(e)}") + logger.debug("LangSmith get run url failed: %s", str(e)) raise ValueError(f"LangSmith get run url failed: {str(e)}") diff --git a/api/core/ops/opik_trace/opik_trace.py b/api/core/ops/opik_trace/opik_trace.py index 6079b2faef..dd6a424ddb 100644 --- a/api/core/ops/opik_trace/opik_trace.py +++ b/api/core/ops/opik_trace/opik_trace.py @@ -96,8 +96,7 @@ class OpikDataTrace(BaseTraceInstance): self.generate_name_trace(trace_info) def workflow_trace(self, trace_info: WorkflowTraceInfo): - external_trace_id = trace_info.metadata.get("external_trace_id") - dify_trace_id = external_trace_id or trace_info.workflow_run_id + dify_trace_id = trace_info.trace_id or trace_info.workflow_run_id opik_trace_id = prepare_opik_uuid(trace_info.start_time, dify_trace_id) workflow_metadata = wrap_metadata( trace_info.metadata, message_id=trace_info.message_id, workflow_app_log_id=trace_info.workflow_app_log_id @@ -105,7 +104,7 @@ class OpikDataTrace(BaseTraceInstance): root_span_id = None if trace_info.message_id: - dify_trace_id = external_trace_id or trace_info.message_id + dify_trace_id = trace_info.trace_id or trace_info.message_id opik_trace_id = prepare_opik_uuid(trace_info.start_time, dify_trace_id) trace_data = { @@ -276,7 +275,7 @@ class OpikDataTrace(BaseTraceInstance): return metadata = trace_info.metadata - message_id = trace_info.message_id + dify_trace_id = trace_info.trace_id or trace_info.message_id user_id = message_data.from_account_id metadata["user_id"] = user_id @@ -291,7 +290,7 @@ class OpikDataTrace(BaseTraceInstance): metadata["end_user_id"] = end_user_id trace_data = { - "id": prepare_opik_uuid(trace_info.start_time, message_id), + "id": prepare_opik_uuid(trace_info.start_time, dify_trace_id), "name": TraceTaskName.MESSAGE_TRACE.value, "start_time": trace_info.start_time, "end_time": trace_info.end_time, @@ -330,7 +329,7 @@ class OpikDataTrace(BaseTraceInstance): start_time = trace_info.start_time or trace_info.message_data.created_at span_data = { - "trace_id": prepare_opik_uuid(start_time, trace_info.message_id), + "trace_id": prepare_opik_uuid(start_time, trace_info.trace_id or trace_info.message_id), "name": TraceTaskName.MODERATION_TRACE.value, "type": "tool", "start_time": start_time, @@ -356,7 +355,7 @@ class OpikDataTrace(BaseTraceInstance): start_time = trace_info.start_time or message_data.created_at span_data = { - "trace_id": prepare_opik_uuid(start_time, trace_info.message_id), + "trace_id": prepare_opik_uuid(start_time, trace_info.trace_id or trace_info.message_id), "name": TraceTaskName.SUGGESTED_QUESTION_TRACE.value, "type": "tool", "start_time": start_time, @@ -376,7 +375,7 @@ class OpikDataTrace(BaseTraceInstance): start_time = trace_info.start_time or trace_info.message_data.created_at span_data = { - "trace_id": prepare_opik_uuid(start_time, trace_info.message_id), + "trace_id": prepare_opik_uuid(start_time, trace_info.trace_id or trace_info.message_id), "name": TraceTaskName.DATASET_RETRIEVAL_TRACE.value, "type": "tool", "start_time": start_time, @@ -391,7 +390,7 @@ class OpikDataTrace(BaseTraceInstance): def tool_trace(self, trace_info: ToolTraceInfo): span_data = { - "trace_id": prepare_opik_uuid(trace_info.start_time, trace_info.message_id), + "trace_id": prepare_opik_uuid(trace_info.start_time, trace_info.trace_id or trace_info.message_id), "name": trace_info.tool_name, "type": "tool", "start_time": trace_info.start_time, @@ -406,7 +405,7 @@ class OpikDataTrace(BaseTraceInstance): def generate_name_trace(self, trace_info: GenerateNameTraceInfo): trace_data = { - "id": prepare_opik_uuid(trace_info.start_time, trace_info.message_id), + "id": prepare_opik_uuid(trace_info.start_time, trace_info.trace_id or trace_info.message_id), "name": TraceTaskName.GENERATE_NAME_TRACE.value, "start_time": trace_info.start_time, "end_time": trace_info.end_time, @@ -453,12 +452,12 @@ class OpikDataTrace(BaseTraceInstance): self.opik_client.auth_check() return True except Exception as e: - logger.info(f"Opik API check failed: {str(e)}", exc_info=True) + logger.info("Opik API check failed: %s", str(e), exc_info=True) raise ValueError(f"Opik API check failed: {str(e)}") def get_project_url(self): try: return self.opik_client.get_project_url(project_name=self.project) except Exception as e: - logger.info(f"Opik get run url failed: {str(e)}", exc_info=True) + logger.info("Opik get run url failed: %s", str(e), exc_info=True) raise ValueError(f"Opik get run url failed: {str(e)}") diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py index 2b546b47cc..7eb5da7e3a 100644 --- a/api/core/ops/ops_trace_manager.py +++ b/api/core/ops/ops_trace_manager.py @@ -287,7 +287,7 @@ class OpsTraceManager: # create new tracing_instance and update the cache if it absent tracing_instance = trace_instance(config_class(**decrypt_trace_config)) cls.ops_trace_instances_cache[decrypt_trace_config_key] = tracing_instance - logging.info(f"new tracing_instance for app_id: {app_id}") + logging.info("new tracing_instance for app_id: %s", app_id) return tracing_instance @classmethod @@ -322,7 +322,7 @@ class OpsTraceManager: :return: """ # auth check - if enabled == True: + if enabled: try: provider_config_map[tracing_provider] except KeyError: @@ -422,8 +422,11 @@ class TraceTask: self.timer = timer self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001") self.app_id = None - + self.trace_id = None self.kwargs = kwargs + external_trace_id = kwargs.get("external_trace_id") + if external_trace_id: + self.trace_id = external_trace_id def execute(self): return self.preprocess() @@ -520,11 +523,8 @@ class TraceTask: "app_id": workflow_run.app_id, } - external_trace_id = self.kwargs.get("external_trace_id") - if external_trace_id: - metadata["external_trace_id"] = external_trace_id - workflow_trace_info = WorkflowTraceInfo( + trace_id=self.trace_id, workflow_data=workflow_run.to_dict(), conversation_id=conversation_id, workflow_id=workflow_id, @@ -584,6 +584,7 @@ class TraceTask: message_tokens = message_data.message_tokens message_trace_info = MessageTraceInfo( + trace_id=self.trace_id, message_id=message_id, message_data=message_data.to_dict(), conversation_model=conversation_mode, @@ -627,6 +628,7 @@ class TraceTask: workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None moderation_trace_info = ModerationTraceInfo( + trace_id=self.trace_id, message_id=workflow_app_log_id or message_id, inputs=inputs, message_data=message_data.to_dict(), @@ -667,6 +669,7 @@ class TraceTask: workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None suggested_question_trace_info = SuggestedQuestionTraceInfo( + trace_id=self.trace_id, message_id=workflow_app_log_id or message_id, message_data=message_data.to_dict(), inputs=message_data.message, @@ -708,6 +711,7 @@ class TraceTask: } dataset_retrieval_trace_info = DatasetRetrievalTraceInfo( + trace_id=self.trace_id, message_id=message_id, inputs=message_data.query or message_data.inputs, documents=[doc.model_dump() for doc in documents] if documents else [], @@ -772,6 +776,7 @@ class TraceTask: ) tool_trace_info = ToolTraceInfo( + trace_id=self.trace_id, message_id=message_id, message_data=message_data.to_dict(), tool_name=tool_name, @@ -807,6 +812,7 @@ class TraceTask: } generate_name_trace_info = GenerateNameTraceInfo( + trace_id=self.trace_id, conversation_id=conversation_id, inputs=inputs, outputs=generate_conversation_name, @@ -843,7 +849,7 @@ class TraceQueueManager: trace_task.app_id = self.app_id trace_manager_queue.put(trace_task) except Exception as e: - logging.exception(f"Error adding trace task, trace_type {trace_task.trace_type}") + logging.exception("Error adding trace task, trace_type %s", trace_task.trace_type) finally: self.start_timer() diff --git a/api/core/ops/utils.py b/api/core/ops/utils.py index 573e8cac88..2c0afb1600 100644 --- a/api/core/ops/utils.py +++ b/api/core/ops/utils.py @@ -67,7 +67,13 @@ def generate_dotted_order( def validate_url(url: str, default_url: str, allowed_schemes: tuple = ("https", "http")) -> str: """ - Validate and normalize URL with proper error handling + Validate and normalize URL with proper error handling. + + NOTE: This function does not retain the `path` component of the provided URL. + In most cases, it is recommended to use `validate_url_with_path` instead. + + This function is deprecated and retained only for compatibility purposes. + New implementations should use `validate_url_with_path`. Args: url: The URL to validate diff --git a/api/core/ops/weave_trace/weave_trace.py b/api/core/ops/weave_trace/weave_trace.py index a34b3b780c..8089860481 100644 --- a/api/core/ops/weave_trace/weave_trace.py +++ b/api/core/ops/weave_trace/weave_trace.py @@ -66,11 +66,11 @@ class WeaveDataTrace(BaseTraceInstance): project_url = f"https://wandb.ai/{self.weave_client._project_id()}" return project_url except Exception as e: - logger.debug(f"Weave get run url failed: {str(e)}") + logger.debug("Weave get run url failed: %s", str(e)) raise ValueError(f"Weave get run url failed: {str(e)}") def trace(self, trace_info: BaseTraceInfo): - logger.debug(f"Trace info: {trace_info}") + logger.debug("Trace info: %s", trace_info) if isinstance(trace_info, WorkflowTraceInfo): self.workflow_trace(trace_info) if isinstance(trace_info, MessageTraceInfo): @@ -87,8 +87,7 @@ class WeaveDataTrace(BaseTraceInstance): self.generate_name_trace(trace_info) def workflow_trace(self, trace_info: WorkflowTraceInfo): - external_trace_id = trace_info.metadata.get("external_trace_id") - trace_id = external_trace_id or trace_info.message_id or trace_info.workflow_run_id + trace_id = trace_info.trace_id or trace_info.message_id or trace_info.workflow_run_id if trace_info.start_time is None: trace_info.start_time = datetime.now() @@ -245,8 +244,12 @@ class WeaveDataTrace(BaseTraceInstance): attributes["start_time"] = trace_info.start_time attributes["end_time"] = trace_info.end_time attributes["tags"] = ["message", str(trace_info.conversation_mode)] + + trace_id = trace_info.trace_id or message_id + attributes["trace_id"] = trace_id + message_run = WeaveTraceModel( - id=message_id, + id=trace_id, op=str(TraceTaskName.MESSAGE_TRACE.value), input_tokens=trace_info.message_tokens, output_tokens=trace_info.answer_tokens, @@ -274,7 +277,7 @@ class WeaveDataTrace(BaseTraceInstance): ) self.start_call( llm_run, - parent_run_id=message_id, + parent_run_id=trace_id, ) self.finish_call(llm_run) self.finish_call(message_run) @@ -289,6 +292,9 @@ class WeaveDataTrace(BaseTraceInstance): attributes["start_time"] = trace_info.start_time or trace_info.message_data.created_at attributes["end_time"] = trace_info.end_time or trace_info.message_data.updated_at + trace_id = trace_info.trace_id or trace_info.message_id + attributes["trace_id"] = trace_id + moderation_run = WeaveTraceModel( id=str(uuid.uuid4()), op=str(TraceTaskName.MODERATION_TRACE.value), @@ -303,7 +309,7 @@ class WeaveDataTrace(BaseTraceInstance): exception=getattr(trace_info, "error", None), file_list=[], ) - self.start_call(moderation_run, parent_run_id=trace_info.message_id) + self.start_call(moderation_run, parent_run_id=trace_id) self.finish_call(moderation_run) def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo): @@ -316,6 +322,9 @@ class WeaveDataTrace(BaseTraceInstance): attributes["start_time"] = (trace_info.start_time or message_data.created_at,) attributes["end_time"] = (trace_info.end_time or message_data.updated_at,) + trace_id = trace_info.trace_id or trace_info.message_id + attributes["trace_id"] = trace_id + suggested_question_run = WeaveTraceModel( id=str(uuid.uuid4()), op=str(TraceTaskName.SUGGESTED_QUESTION_TRACE.value), @@ -326,7 +335,7 @@ class WeaveDataTrace(BaseTraceInstance): file_list=[], ) - self.start_call(suggested_question_run, parent_run_id=trace_info.message_id) + self.start_call(suggested_question_run, parent_run_id=trace_id) self.finish_call(suggested_question_run) def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo): @@ -338,6 +347,9 @@ class WeaveDataTrace(BaseTraceInstance): attributes["start_time"] = (trace_info.start_time or trace_info.message_data.created_at,) attributes["end_time"] = (trace_info.end_time or trace_info.message_data.updated_at,) + trace_id = trace_info.trace_id or trace_info.message_id + attributes["trace_id"] = trace_id + dataset_retrieval_run = WeaveTraceModel( id=str(uuid.uuid4()), op=str(TraceTaskName.DATASET_RETRIEVAL_TRACE.value), @@ -348,7 +360,7 @@ class WeaveDataTrace(BaseTraceInstance): file_list=[], ) - self.start_call(dataset_retrieval_run, parent_run_id=trace_info.message_id) + self.start_call(dataset_retrieval_run, parent_run_id=trace_id) self.finish_call(dataset_retrieval_run) def tool_trace(self, trace_info: ToolTraceInfo): @@ -357,6 +369,11 @@ class WeaveDataTrace(BaseTraceInstance): attributes["start_time"] = trace_info.start_time attributes["end_time"] = trace_info.end_time + message_id = trace_info.message_id or getattr(trace_info, "conversation_id", None) + message_id = message_id or None + trace_id = trace_info.trace_id or message_id + attributes["trace_id"] = trace_id + tool_run = WeaveTraceModel( id=str(uuid.uuid4()), op=trace_info.tool_name, @@ -366,9 +383,7 @@ class WeaveDataTrace(BaseTraceInstance): attributes=attributes, exception=trace_info.error, ) - message_id = trace_info.message_id or getattr(trace_info, "conversation_id", None) - message_id = message_id or None - self.start_call(tool_run, parent_run_id=message_id) + self.start_call(tool_run, parent_run_id=trace_id) self.finish_call(tool_run) def generate_name_trace(self, trace_info: GenerateNameTraceInfo): @@ -403,7 +418,7 @@ class WeaveDataTrace(BaseTraceInstance): print("Weave login successful") return True except Exception as e: - logger.debug(f"Weave API check failed: {str(e)}") + logger.debug("Weave API check failed: %s", str(e)) raise ValueError(f"Weave API check failed: {str(e)}") def start_call(self, run_data: WeaveTraceModel, parent_run_id: Optional[str] = None): diff --git a/api/core/plugin/impl/base.py b/api/core/plugin/impl/base.py index 7375726fa9..6f32498b42 100644 --- a/api/core/plugin/impl/base.py +++ b/api/core/plugin/impl/base.py @@ -208,6 +208,7 @@ class BasePluginClient: except Exception: raise PluginDaemonInnerError(code=rep.code, message=rep.message) + logger.error("Error in stream reponse for plugin %s", rep.__dict__) self._handle_plugin_daemon_error(error.error_type, error.message) raise ValueError(f"plugin daemon: {rep.message}, code: {rep.code}") if rep.data is None: diff --git a/api/core/plugin/impl/exc.py b/api/core/plugin/impl/exc.py index 8b660c807d..8ecc2e2147 100644 --- a/api/core/plugin/impl/exc.py +++ b/api/core/plugin/impl/exc.py @@ -2,6 +2,8 @@ from collections.abc import Mapping from pydantic import TypeAdapter +from extensions.ext_logging import get_request_id + class PluginDaemonError(Exception): """Base class for all plugin daemon errors.""" @@ -11,7 +13,7 @@ class PluginDaemonError(Exception): def __str__(self) -> str: # returns the class name and description - return f"{self.__class__.__name__}: {self.description}" + return f"req_id: {get_request_id()} {self.__class__.__name__}: {self.description}" class PluginDaemonInternalError(PluginDaemonError): diff --git a/api/core/prompt/advanced_prompt_transform.py b/api/core/prompt/advanced_prompt_transform.py index 0f0fe65f27..16c145f936 100644 --- a/api/core/prompt/advanced_prompt_transform.py +++ b/api/core/prompt/advanced_prompt_transform.py @@ -125,11 +125,11 @@ class AdvancedPromptTransform(PromptTransform): if files: prompt_message_contents: list[PromptMessageContentUnionTypes] = [] - prompt_message_contents.append(TextPromptMessageContent(data=prompt)) for file in files: prompt_message_contents.append( file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config) ) + prompt_message_contents.append(TextPromptMessageContent(data=prompt)) prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) else: @@ -196,16 +196,17 @@ class AdvancedPromptTransform(PromptTransform): query = parser.format(prompt_inputs) + prompt_message_contents: list[PromptMessageContentUnionTypes] = [] if memory and memory_config: prompt_messages = self._append_chat_histories(memory, memory_config, prompt_messages, model_config) if files and query is not None: - prompt_message_contents: list[PromptMessageContentUnionTypes] = [] - prompt_message_contents.append(TextPromptMessageContent(data=query)) for file in files: prompt_message_contents.append( file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config) ) + prompt_message_contents.append(TextPromptMessageContent(data=query)) + prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) else: prompt_messages.append(UserPromptMessage(content=query)) @@ -215,27 +216,27 @@ class AdvancedPromptTransform(PromptTransform): last_message = prompt_messages[-1] if prompt_messages else None if last_message and last_message.role == PromptMessageRole.USER: # get last user message content and add files - prompt_message_contents = [TextPromptMessageContent(data=cast(str, last_message.content))] for file in files: prompt_message_contents.append( file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config) ) + prompt_message_contents.append(TextPromptMessageContent(data=cast(str, last_message.content))) last_message.content = prompt_message_contents else: - prompt_message_contents = [TextPromptMessageContent(data="")] # not for query for file in files: prompt_message_contents.append( file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config) ) + prompt_message_contents.append(TextPromptMessageContent(data="")) prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) else: - prompt_message_contents = [TextPromptMessageContent(data=query)] for file in files: prompt_message_contents.append( file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config) ) + prompt_message_contents.append(TextPromptMessageContent(data=query)) prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) elif query: diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py index e19c6419ca..13f4163d80 100644 --- a/api/core/prompt/simple_prompt_transform.py +++ b/api/core/prompt/simple_prompt_transform.py @@ -265,11 +265,11 @@ class SimplePromptTransform(PromptTransform): ) -> UserPromptMessage: if files: prompt_message_contents: list[PromptMessageContentUnionTypes] = [] - prompt_message_contents.append(TextPromptMessageContent(data=prompt)) for file in files: prompt_message_contents.append( file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config) ) + prompt_message_contents.append(TextPromptMessageContent(data=prompt)) prompt_message = UserPromptMessage(content=prompt_message_contents) else: diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index 6de4f3a303..39fec951bb 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -1,3 +1,4 @@ +import contextlib import json from collections import defaultdict from json import JSONDecodeError @@ -523,7 +524,7 @@ class ProviderManager: # Init trial provider records if not exists if ProviderQuotaType.TRIAL not in provider_quota_to_provider_record_dict: try: - # FIXME ignore the type errork, onyl TrialHostingQuota has limit need to change the logic + # FIXME ignore the type error, only TrialHostingQuota has limit need to change the logic new_provider_record = Provider( tenant_id=tenant_id, # TODO: Use provider name with prefix after the data migration. @@ -624,14 +625,12 @@ class ProviderManager: for variable in provider_credential_secret_variables: if variable in provider_credentials: - try: + with contextlib.suppress(ValueError): provider_credentials[variable] = encrypter.decrypt_token_with_decoding( provider_credentials.get(variable) or "", # type: ignore self.decoding_rsa_key, self.decoding_cipher_rsa, ) - except ValueError: - pass # cache provider credentials provider_credentials_cache.set(credentials=provider_credentials) @@ -672,14 +671,12 @@ class ProviderManager: for variable in model_credential_secret_variables: if variable in provider_model_credentials: - try: + with contextlib.suppress(ValueError): provider_model_credentials[variable] = encrypter.decrypt_token_with_decoding( provider_model_credentials.get(variable), self.decoding_rsa_key, self.decoding_cipher_rsa, ) - except ValueError: - pass # cache provider model credentials provider_model_credentials_cache.set(credentials=provider_model_credentials) diff --git a/api/core/rag/datasource/keyword/jieba/jieba.py b/api/core/rag/datasource/keyword/jieba/jieba.py index ec3a23bd96..c98306ea4b 100644 --- a/api/core/rag/datasource/keyword/jieba/jieba.py +++ b/api/core/rag/datasource/keyword/jieba/jieba.py @@ -1,7 +1,7 @@ -import json from collections import defaultdict from typing import Any, Optional +import orjson from pydantic import BaseModel from configs import dify_config @@ -24,7 +24,7 @@ class Jieba(BaseKeyword): self._config = KeywordTableConfig() def create(self, texts: list[Document], **kwargs) -> BaseKeyword: - lock_name = "keyword_indexing_lock_{}".format(self.dataset.id) + lock_name = f"keyword_indexing_lock_{self.dataset.id}" with redis_client.lock(lock_name, timeout=600): keyword_table_handler = JiebaKeywordTableHandler() keyword_table = self._get_dataset_keyword_table() @@ -43,7 +43,7 @@ class Jieba(BaseKeyword): return self def add_texts(self, texts: list[Document], **kwargs): - lock_name = "keyword_indexing_lock_{}".format(self.dataset.id) + lock_name = f"keyword_indexing_lock_{self.dataset.id}" with redis_client.lock(lock_name, timeout=600): keyword_table_handler = JiebaKeywordTableHandler() @@ -76,7 +76,7 @@ class Jieba(BaseKeyword): return id in set.union(*keyword_table.values()) def delete_by_ids(self, ids: list[str]) -> None: - lock_name = "keyword_indexing_lock_{}".format(self.dataset.id) + lock_name = f"keyword_indexing_lock_{self.dataset.id}" with redis_client.lock(lock_name, timeout=600): keyword_table = self._get_dataset_keyword_table() if keyword_table is not None: @@ -116,7 +116,7 @@ class Jieba(BaseKeyword): return documents def delete(self) -> None: - lock_name = "keyword_indexing_lock_{}".format(self.dataset.id) + lock_name = f"keyword_indexing_lock_{self.dataset.id}" with redis_client.lock(lock_name, timeout=600): dataset_keyword_table = self.dataset.dataset_keyword_table if dataset_keyword_table: @@ -134,13 +134,13 @@ class Jieba(BaseKeyword): dataset_keyword_table = self.dataset.dataset_keyword_table keyword_data_source_type = dataset_keyword_table.data_source_type if keyword_data_source_type == "database": - dataset_keyword_table.keyword_table = json.dumps(keyword_table_dict, cls=SetEncoder) + dataset_keyword_table.keyword_table = dumps_with_sets(keyword_table_dict) db.session.commit() else: file_key = "keyword_files/" + self.dataset.tenant_id + "/" + self.dataset.id + ".txt" if storage.exists(file_key): storage.delete(file_key) - storage.save(file_key, json.dumps(keyword_table_dict, cls=SetEncoder).encode("utf-8")) + storage.save(file_key, dumps_with_sets(keyword_table_dict).encode("utf-8")) def _get_dataset_keyword_table(self) -> Optional[dict]: dataset_keyword_table = self.dataset.dataset_keyword_table @@ -156,12 +156,11 @@ class Jieba(BaseKeyword): data_source_type=keyword_data_source_type, ) if keyword_data_source_type == "database": - dataset_keyword_table.keyword_table = json.dumps( + dataset_keyword_table.keyword_table = dumps_with_sets( { "__type__": "keyword_table", "__data__": {"index_id": self.dataset.id, "summary": None, "table": {}}, - }, - cls=SetEncoder, + } ) db.session.add(dataset_keyword_table) db.session.commit() @@ -252,8 +251,13 @@ class Jieba(BaseKeyword): self._save_dataset_keyword_table(keyword_table) -class SetEncoder(json.JSONEncoder): - def default(self, obj): - if isinstance(obj, set): - return list(obj) - return super().default(obj) +def set_orjson_default(obj: Any) -> Any: + """Default function for orjson serialization of set types""" + if isinstance(obj, set): + return list(obj) + raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable") + + +def dumps_with_sets(obj: Any) -> str: + """JSON dumps with set support using orjson""" + return orjson.dumps(obj, default=set_orjson_default).decode("utf-8") diff --git a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py index 14481b1f10..bb61b71bb1 100644 --- a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py +++ b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py @@ -98,18 +98,26 @@ class AnalyticdbVectorBySql: try: cur.execute(f"CREATE DATABASE {self.databaseName}") except Exception as e: - if "already exists" in str(e): - return - raise e + if "already exists" not in str(e): + raise e finally: cur.close() conn.close() self.pool = self._create_connection_pool() with self._get_cursor() as cur: + conn = cur.connection + try: + cur.execute("CREATE EXTENSION IF NOT EXISTS zhparser;") + except Exception as e: + conn.rollback() + raise RuntimeError( + "Failed to create zhparser extension. Please ensure it is available in your AnalyticDB." + ) from e try: cur.execute("CREATE TEXT SEARCH CONFIGURATION zh_cn (PARSER = zhparser)") cur.execute("ALTER TEXT SEARCH CONFIGURATION zh_cn ADD MAPPING FOR n,v,a,i,e,l,x WITH simple") except Exception as e: + conn.rollback() if "already exists" not in str(e): raise e cur.execute( diff --git a/api/core/rag/datasource/vdb/baidu/baidu_vector.py b/api/core/rag/datasource/vdb/baidu/baidu_vector.py index db7ffc9c4f..d63ca9f695 100644 --- a/api/core/rag/datasource/vdb/baidu/baidu_vector.py +++ b/api/core/rag/datasource/vdb/baidu/baidu_vector.py @@ -203,9 +203,9 @@ class BaiduVector(BaseVector): def _create_table(self, dimension: int) -> None: # Try to grab distributed lock and create table - lock_name = "vector_indexing_lock_{}".format(self._collection_name) + lock_name = f"vector_indexing_lock_{self._collection_name}" with redis_client.lock(lock_name, timeout=60): - table_exist_cache_key = "vector_indexing_{}".format(self._collection_name) + table_exist_cache_key = f"vector_indexing_{self._collection_name}" if redis_client.get(table_exist_cache_key): return diff --git a/api/core/rag/datasource/vdb/chroma/chroma_vector.py b/api/core/rag/datasource/vdb/chroma/chroma_vector.py index b8b265d5e6..699a602365 100644 --- a/api/core/rag/datasource/vdb/chroma/chroma_vector.py +++ b/api/core/rag/datasource/vdb/chroma/chroma_vector.py @@ -57,9 +57,9 @@ class ChromaVector(BaseVector): self.add_texts(texts, embeddings, **kwargs) def create_collection(self, collection_name: str): - lock_name = "vector_indexing_lock_{}".format(collection_name) + lock_name = f"vector_indexing_lock_{collection_name}" with redis_client.lock(lock_name, timeout=20): - collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name) + collection_exist_cache_key = f"vector_indexing_{self._collection_name}" if redis_client.get(collection_exist_cache_key): return self._client.get_or_create_collection(collection_name) diff --git a/api/core/rag/datasource/vdb/clickzetta/README.md b/api/core/rag/datasource/vdb/clickzetta/README.md new file mode 100644 index 0000000000..969d4e40a0 --- /dev/null +++ b/api/core/rag/datasource/vdb/clickzetta/README.md @@ -0,0 +1,201 @@ +# Clickzetta Vector Database Integration + +This module provides integration with Clickzetta Lakehouse as a vector database for Dify. + +## Features + +- **Vector Storage**: Store and retrieve high-dimensional vectors using Clickzetta's native VECTOR type +- **Vector Search**: Efficient similarity search using HNSW algorithm +- **Full-Text Search**: Leverage Clickzetta's inverted index for powerful text search capabilities +- **Hybrid Search**: Combine vector similarity and full-text search for better results +- **Multi-language Support**: Built-in support for Chinese, English, and Unicode text processing +- **Scalable**: Leverage Clickzetta's distributed architecture for large-scale deployments + +## Configuration + +### Required Environment Variables + +All seven configuration parameters are required: + +```bash +# Authentication +CLICKZETTA_USERNAME=your_username +CLICKZETTA_PASSWORD=your_password + +# Instance configuration +CLICKZETTA_INSTANCE=your_instance_id +CLICKZETTA_SERVICE=api.clickzetta.com +CLICKZETTA_WORKSPACE=your_workspace +CLICKZETTA_VCLUSTER=your_vcluster +CLICKZETTA_SCHEMA=your_schema +``` + +### Optional Configuration + +```bash +# Batch processing +CLICKZETTA_BATCH_SIZE=100 + +# Full-text search configuration +CLICKZETTA_ENABLE_INVERTED_INDEX=true +CLICKZETTA_ANALYZER_TYPE=chinese # Options: keyword, english, chinese, unicode +CLICKZETTA_ANALYZER_MODE=smart # Options: max_word, smart + +# Vector search configuration +CLICKZETTA_VECTOR_DISTANCE_FUNCTION=cosine_distance # Options: l2_distance, cosine_distance +``` + +## Usage + +### 1. Set Clickzetta as the Vector Store + +In your Dify configuration, set: + +```bash +VECTOR_STORE=clickzetta +``` + +### 2. Table Structure + +Clickzetta will automatically create tables with the following structure: + +```sql +CREATE TABLE ( + id STRING NOT NULL, + content STRING NOT NULL, + metadata JSON, + vector VECTOR(FLOAT, ) NOT NULL, + PRIMARY KEY (id) +); + +-- Vector index for similarity search +CREATE VECTOR INDEX idx__vec +ON TABLE .(vector) +PROPERTIES ( + "distance.function" = "cosine_distance", + "scalar.type" = "f32" +); + +-- Inverted index for full-text search (if enabled) +CREATE INVERTED INDEX idx__text +ON .(content) +PROPERTIES ( + "analyzer" = "chinese", + "mode" = "smart" +); +``` + +## Full-Text Search Capabilities + +Clickzetta supports advanced full-text search with multiple analyzers: + +### Analyzer Types + +1. **keyword**: No tokenization, treats the entire string as a single token + + - Best for: Exact matching, IDs, codes + +1. **english**: Designed for English text + + - Features: Recognizes ASCII letters and numbers, converts to lowercase + - Best for: English content + +1. **chinese**: Chinese text tokenizer + + - Features: Recognizes Chinese and English characters, removes punctuation + - Best for: Chinese or mixed Chinese-English content + +1. **unicode**: Multi-language tokenizer based on Unicode + + - Features: Recognizes text boundaries in multiple languages + - Best for: Multi-language content + +### Analyzer Modes + +- **max_word**: Fine-grained tokenization (more tokens) +- **smart**: Intelligent tokenization (balanced) + +### Full-Text Search Functions + +- `MATCH_ALL(column, query)`: All terms must be present +- `MATCH_ANY(column, query)`: At least one term must be present +- `MATCH_PHRASE(column, query)`: Exact phrase matching +- `MATCH_PHRASE_PREFIX(column, query)`: Phrase prefix matching +- `MATCH_REGEXP(column, pattern)`: Regular expression matching + +## Performance Optimization + +### Vector Search + +1. **Adjust exploration factor** for accuracy vs speed trade-off: + + ```sql + SET cz.vector.index.search.ef=64; + ``` + +1. **Use appropriate distance functions**: + + - `cosine_distance`: Best for normalized embeddings (e.g., from language models) + - `l2_distance`: Best for raw feature vectors + +### Full-Text Search + +1. **Choose the right analyzer**: + + - Use `keyword` for exact matching + - Use language-specific analyzers for better tokenization + +1. **Combine with vector search**: + + - Pre-filter with full-text search for better performance + - Use hybrid search for improved relevance + +## Troubleshooting + +### Connection Issues + +1. Verify all 7 required configuration parameters are set +1. Check network connectivity to Clickzetta service +1. Ensure the user has proper permissions on the schema + +### Search Performance + +1. Verify vector index exists: + + ```sql + SHOW INDEX FROM .; + ``` + +1. Check if vector index is being used: + + ```sql + EXPLAIN SELECT ... WHERE l2_distance(...) < threshold; + ``` + + Look for `vector_index_search_type` in the execution plan. + +### Full-Text Search Not Working + +1. Verify inverted index is created +1. Check analyzer configuration matches your content language +1. Use `TOKENIZE()` function to test tokenization: + ```sql + SELECT TOKENIZE('your text', map('analyzer', 'chinese', 'mode', 'smart')); + ``` + +## Limitations + +1. Vector operations don't support `ORDER BY` or `GROUP BY` directly on vector columns +1. Full-text search relevance scores are not provided by Clickzetta +1. Inverted index creation may fail for very large existing tables (continue without error) +1. Index naming constraints: + - Index names must be unique within a schema + - Only one vector index can be created per column + - The implementation uses timestamps to ensure unique index names +1. A column can only have one vector index at a time + +## References + +- [Clickzetta Vector Search Documentation](https://yunqi.tech/documents/vector-search) +- [Clickzetta Inverted Index Documentation](https://yunqi.tech/documents/inverted-index) +- [Clickzetta SQL Functions](https://yunqi.tech/documents/sql-reference) diff --git a/api/core/rag/datasource/vdb/clickzetta/__init__.py b/api/core/rag/datasource/vdb/clickzetta/__init__.py new file mode 100644 index 0000000000..9d41c5a57d --- /dev/null +++ b/api/core/rag/datasource/vdb/clickzetta/__init__.py @@ -0,0 +1 @@ +# Clickzetta Vector Database Integration for Dify diff --git a/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py b/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py new file mode 100644 index 0000000000..6e8077ffd9 --- /dev/null +++ b/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py @@ -0,0 +1,1077 @@ +import contextlib +import json +import logging +import queue +import re +import threading +import time +import uuid +from typing import TYPE_CHECKING, Any, Optional + +import clickzetta # type: ignore +from pydantic import BaseModel, model_validator + +if TYPE_CHECKING: + from clickzetta import Connection + +from configs import dify_config +from core.rag.datasource.vdb.field import Field +from core.rag.datasource.vdb.vector_base import BaseVector +from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory +from core.rag.embedding.embedding_base import Embeddings +from core.rag.models.document import Document +from models.dataset import Dataset + +logger = logging.getLogger(__name__) + + +# ClickZetta Lakehouse Vector Database Configuration + + +class ClickzettaConfig(BaseModel): + """ + Configuration class for Clickzetta connection. + """ + + username: str + password: str + instance: str + service: str = "api.clickzetta.com" + workspace: str = "quick_start" + vcluster: str = "default_ap" + schema_name: str = "dify" # Renamed to avoid shadowing BaseModel.schema + # Advanced settings + batch_size: int = 20 # Reduced batch size to avoid large SQL statements + enable_inverted_index: bool = True # Enable inverted index for full-text search + analyzer_type: str = "chinese" # Analyzer type for full-text search: keyword, english, chinese, unicode + analyzer_mode: str = "smart" # Analyzer mode: max_word, smart + vector_distance_function: str = "cosine_distance" # l2_distance or cosine_distance + + @model_validator(mode="before") + @classmethod + def validate_config(cls, values: dict) -> dict: + """ + Validate the configuration values. + """ + if not values.get("username"): + raise ValueError("config CLICKZETTA_USERNAME is required") + if not values.get("password"): + raise ValueError("config CLICKZETTA_PASSWORD is required") + if not values.get("instance"): + raise ValueError("config CLICKZETTA_INSTANCE is required") + if not values.get("service"): + raise ValueError("config CLICKZETTA_SERVICE is required") + if not values.get("workspace"): + raise ValueError("config CLICKZETTA_WORKSPACE is required") + if not values.get("vcluster"): + raise ValueError("config CLICKZETTA_VCLUSTER is required") + if not values.get("schema_name"): + raise ValueError("config CLICKZETTA_SCHEMA is required") + return values + + +class ClickzettaConnectionPool: + """ + Global connection pool for ClickZetta connections. + Manages connection reuse across ClickzettaVector instances. + """ + + _instance: Optional["ClickzettaConnectionPool"] = None + _lock = threading.Lock() + + def __init__(self): + self._pools: dict[str, list[tuple[Connection, float]]] = {} # config_key -> [(connection, last_used_time)] + self._pool_locks: dict[str, threading.Lock] = {} + self._max_pool_size = 5 # Maximum connections per configuration + self._connection_timeout = 300 # 5 minutes timeout + self._cleanup_thread: Optional[threading.Thread] = None + self._shutdown = False + self._start_cleanup_thread() + + @classmethod + def get_instance(cls) -> "ClickzettaConnectionPool": + """Get singleton instance of connection pool.""" + if cls._instance is None: + with cls._lock: + if cls._instance is None: + cls._instance = cls() + return cls._instance + + def _get_config_key(self, config: ClickzettaConfig) -> str: + """Generate unique key for connection configuration.""" + return ( + f"{config.username}:{config.instance}:{config.service}:" + f"{config.workspace}:{config.vcluster}:{config.schema_name}" + ) + + def _create_connection(self, config: ClickzettaConfig) -> "Connection": + """Create a new ClickZetta connection.""" + max_retries = 3 + retry_delay = 1.0 + + for attempt in range(max_retries): + try: + connection = clickzetta.connect( + username=config.username, + password=config.password, + instance=config.instance, + service=config.service, + workspace=config.workspace, + vcluster=config.vcluster, + schema=config.schema_name, + ) + + # Configure connection session settings + self._configure_connection(connection) + logger.debug("Created new ClickZetta connection (attempt %d/%d)", attempt + 1, max_retries) + return connection + except Exception: + logger.exception("ClickZetta connection attempt %d/%d failed", attempt + 1, max_retries) + if attempt < max_retries - 1: + time.sleep(retry_delay * (2**attempt)) + else: + raise + + raise RuntimeError(f"Failed to create ClickZetta connection after {max_retries} attempts") + + def _configure_connection(self, connection: "Connection") -> None: + """Configure connection session settings.""" + try: + with connection.cursor() as cursor: + # Temporarily suppress ClickZetta client logging to reduce noise + clickzetta_logger = logging.getLogger("clickzetta") + original_level = clickzetta_logger.level + clickzetta_logger.setLevel(logging.WARNING) + + try: + # Use quote mode for string literal escaping + cursor.execute("SET cz.sql.string.literal.escape.mode = 'quote'") + + # Apply performance optimization hints + performance_hints = [ + # Vector index optimization + "SET cz.storage.parquet.vector.index.read.memory.cache = true", + "SET cz.storage.parquet.vector.index.read.local.cache = false", + # Query optimization + "SET cz.sql.table.scan.push.down.filter = true", + "SET cz.sql.table.scan.enable.ensure.filter = true", + "SET cz.storage.always.prefetch.internal = true", + "SET cz.optimizer.generate.columns.always.valid = true", + "SET cz.sql.index.prewhere.enabled = true", + # Storage optimization + "SET cz.storage.parquet.enable.io.prefetch = false", + "SET cz.optimizer.enable.mv.rewrite = false", + "SET cz.sql.dump.as.lz4 = true", + "SET cz.optimizer.limited.optimization.naive.query = true", + "SET cz.sql.table.scan.enable.push.down.log = false", + "SET cz.storage.use.file.format.local.stats = false", + "SET cz.storage.local.file.object.cache.level = all", + # Job execution optimization + "SET cz.sql.job.fast.mode = true", + "SET cz.storage.parquet.non.contiguous.read = true", + "SET cz.sql.compaction.after.commit = true", + ] + + for hint in performance_hints: + cursor.execute(hint) + finally: + # Restore original logging level + clickzetta_logger.setLevel(original_level) + + except Exception: + logger.exception("Failed to configure connection, continuing with defaults") + + def _is_connection_valid(self, connection: "Connection") -> bool: + """Check if connection is still valid.""" + try: + with connection.cursor() as cursor: + cursor.execute("SELECT 1") + return True + except Exception: + return False + + def get_connection(self, config: ClickzettaConfig) -> "Connection": + """Get a connection from the pool or create a new one.""" + config_key = self._get_config_key(config) + + # Ensure pool lock exists + if config_key not in self._pool_locks: + with self._lock: + if config_key not in self._pool_locks: + self._pool_locks[config_key] = threading.Lock() + self._pools[config_key] = [] + + with self._pool_locks[config_key]: + pool = self._pools[config_key] + current_time = time.time() + + # Try to reuse existing connection + while pool: + connection, last_used = pool.pop(0) + + # Check if connection is not expired and still valid + if current_time - last_used < self._connection_timeout and self._is_connection_valid(connection): + logger.debug("Reusing ClickZetta connection from pool") + return connection + else: + # Connection expired or invalid, close it + with contextlib.suppress(Exception): + connection.close() + + # No valid connection found, create new one + return self._create_connection(config) + + def return_connection(self, config: ClickzettaConfig, connection: "Connection") -> None: + """Return a connection to the pool.""" + config_key = self._get_config_key(config) + + if config_key not in self._pool_locks: + # Pool was cleaned up, just close the connection + with contextlib.suppress(Exception): + connection.close() + return + + with self._pool_locks[config_key]: + pool = self._pools[config_key] + + # Only return to pool if not at capacity and connection is valid + if len(pool) < self._max_pool_size and self._is_connection_valid(connection): + pool.append((connection, time.time())) + logger.debug("Returned ClickZetta connection to pool") + else: + # Pool full or connection invalid, close it + with contextlib.suppress(Exception): + connection.close() + + def _cleanup_expired_connections(self) -> None: + """Clean up expired connections from all pools.""" + current_time = time.time() + + with self._lock: + for config_key in list(self._pools.keys()): + if config_key not in self._pool_locks: + continue + + with self._pool_locks[config_key]: + pool = self._pools[config_key] + valid_connections = [] + + for connection, last_used in pool: + if current_time - last_used < self._connection_timeout: + valid_connections.append((connection, last_used)) + else: + with contextlib.suppress(Exception): + connection.close() + + self._pools[config_key] = valid_connections + + def _start_cleanup_thread(self) -> None: + """Start background thread for connection cleanup.""" + + def cleanup_worker(): + while not self._shutdown: + try: + time.sleep(60) # Cleanup every minute + if not self._shutdown: + self._cleanup_expired_connections() + except Exception: + logger.exception("Error in connection pool cleanup") + + self._cleanup_thread = threading.Thread(target=cleanup_worker, daemon=True) + self._cleanup_thread.start() + + def shutdown(self) -> None: + """Shutdown connection pool and close all connections.""" + self._shutdown = True + + with self._lock: + for config_key in list(self._pools.keys()): + if config_key not in self._pool_locks: + continue + + with self._pool_locks[config_key]: + pool = self._pools[config_key] + for connection, _ in pool: + with contextlib.suppress(Exception): + connection.close() + pool.clear() + + +class ClickzettaVector(BaseVector): + """ + Clickzetta vector storage implementation. + """ + + # Class-level write queue and lock for serializing writes + _write_queue: Optional[queue.Queue] = None + _write_thread: Optional[threading.Thread] = None + _write_lock = threading.Lock() + _shutdown = False + + def __init__(self, collection_name: str, config: ClickzettaConfig): + super().__init__(collection_name) + self._config = config + self._table_name = collection_name.replace("-", "_").lower() # Ensure valid table name + self._connection_pool = ClickzettaConnectionPool.get_instance() + self._init_write_queue() + + def _get_connection(self) -> "Connection": + """Get a connection from the pool.""" + return self._connection_pool.get_connection(self._config) + + def _return_connection(self, connection: "Connection") -> None: + """Return a connection to the pool.""" + self._connection_pool.return_connection(self._config, connection) + + class ConnectionContext: + """Context manager for borrowing and returning connections.""" + + def __init__(self, vector_instance: "ClickzettaVector"): + self.vector = vector_instance + self.connection: Optional[Connection] = None + + def __enter__(self) -> "Connection": + self.connection = self.vector._get_connection() + return self.connection + + def __exit__(self, exc_type, exc_val, exc_tb): + if self.connection: + self.vector._return_connection(self.connection) + + def get_connection_context(self) -> "ClickzettaVector.ConnectionContext": + """Get a connection context manager.""" + return self.ConnectionContext(self) + + def _parse_metadata(self, raw_metadata: str, row_id: str) -> dict: + """ + Parse metadata from JSON string with proper error handling and fallback. + + Args: + raw_metadata: Raw JSON string from database + row_id: Row ID for fallback document_id + + Returns: + Parsed metadata dict with guaranteed required fields + """ + try: + if raw_metadata: + metadata = json.loads(raw_metadata) + + # Handle double-encoded JSON + if isinstance(metadata, str): + metadata = json.loads(metadata) + + # Ensure we have a dict + if not isinstance(metadata, dict): + metadata = {} + else: + metadata = {} + except (json.JSONDecodeError, TypeError): + logger.exception("JSON parsing failed for metadata") + # Fallback: extract document_id with regex + doc_id_match = re.search(r'"document_id":\s*"([^"]+)"', raw_metadata or "") + metadata = {"document_id": doc_id_match.group(1)} if doc_id_match else {} + + # Ensure required fields are set + metadata["doc_id"] = row_id # segment id + + # Ensure document_id exists (critical for Dify's format_retrieval_documents) + if "document_id" not in metadata: + metadata["document_id"] = row_id # fallback to segment id + + return metadata + + @classmethod + def _init_write_queue(cls): + """Initialize the write queue and worker thread.""" + with cls._write_lock: + if cls._write_queue is None: + cls._write_queue = queue.Queue() + cls._write_thread = threading.Thread(target=cls._write_worker, daemon=True) + cls._write_thread.start() + logger.info("Started Clickzetta write worker thread") + + @classmethod + def _write_worker(cls): + """Worker thread that processes write tasks sequentially.""" + while not cls._shutdown: + try: + # Get task from queue with timeout + if cls._write_queue is not None: + task = cls._write_queue.get(timeout=1) + if task is None: # Shutdown signal + break + + # Execute the write task + func, args, kwargs, result_queue = task + try: + result = func(*args, **kwargs) + result_queue.put((True, result)) + except (RuntimeError, ValueError, TypeError, ConnectionError) as e: + logger.exception("Write task failed") + result_queue.put((False, e)) + finally: + cls._write_queue.task_done() + else: + break + except queue.Empty: + continue + except (RuntimeError, ValueError, TypeError, ConnectionError) as e: + logger.exception("Write worker error") + + def _execute_write(self, func, *args, **kwargs): + """Execute a write operation through the queue.""" + if ClickzettaVector._write_queue is None: + raise RuntimeError("Write queue not initialized") + + result_queue: queue.Queue[tuple[bool, Any]] = queue.Queue() + ClickzettaVector._write_queue.put((func, args, kwargs, result_queue)) + + # Wait for result + success, result = result_queue.get() + if not success: + raise result + return result + + def get_type(self) -> str: + """Return the vector database type.""" + return "clickzetta" + + def _ensure_connection(self) -> "Connection": + """Get a connection from the pool.""" + return self._get_connection() + + def _table_exists(self) -> bool: + """Check if the table exists.""" + try: + with self.get_connection_context() as connection: + with connection.cursor() as cursor: + cursor.execute(f"DESC {self._config.schema_name}.{self._table_name}") + return True + except Exception as e: + error_message = str(e).lower() + # Handle ClickZetta specific "table or view not found" errors + if any( + phrase in error_message + for phrase in ["table or view not found", "czlh-42000", "semantic analysis exception"] + ): + logger.debug("Table %s.%s does not exist", self._config.schema_name, self._table_name) + return False + else: + # For other connection/permission errors, log warning but return False to avoid blocking cleanup + logger.exception( + "Table existence check failed for %s.%s, assuming it doesn't exist", + self._config.schema_name, + self._table_name, + ) + return False + + def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): + """Create the collection and add initial documents.""" + # Execute table creation through write queue to avoid concurrent conflicts + self._execute_write(self._create_table_and_indexes, embeddings) + + # Add initial texts + if texts: + self.add_texts(texts, embeddings, **kwargs) + + def _create_table_and_indexes(self, embeddings: list[list[float]]): + """Create table and indexes (executed in write worker thread).""" + # Check if table already exists to avoid unnecessary index creation + if self._table_exists(): + logger.info("Table %s.%s already exists, skipping creation", self._config.schema_name, self._table_name) + return + + # Create table with vector and metadata columns + dimension = len(embeddings[0]) if embeddings else 768 + + create_table_sql = f""" + CREATE TABLE IF NOT EXISTS {self._config.schema_name}.{self._table_name} ( + id STRING NOT NULL COMMENT 'Unique document identifier', + {Field.CONTENT_KEY.value} STRING NOT NULL COMMENT 'Document text content for search and retrieval', + {Field.METADATA_KEY.value} JSON COMMENT 'Document metadata including source, type, and other attributes', + {Field.VECTOR.value} VECTOR(FLOAT, {dimension}) NOT NULL COMMENT + 'High-dimensional embedding vector for semantic similarity search', + PRIMARY KEY (id) + ) COMMENT 'Dify RAG knowledge base vector storage table for document embeddings and content' + """ + + with self.get_connection_context() as connection: + with connection.cursor() as cursor: + cursor.execute(create_table_sql) + logger.info("Created table %s.%s", self._config.schema_name, self._table_name) + + # Create vector index + self._create_vector_index(cursor) + + # Create inverted index for full-text search if enabled + if self._config.enable_inverted_index: + self._create_inverted_index(cursor) + + def _create_vector_index(self, cursor): + """Create HNSW vector index for similarity search.""" + # Use a fixed index name based on table and column name + index_name = f"idx_{self._table_name}_vector" + + # First check if an index already exists on this column + try: + cursor.execute(f"SHOW INDEX FROM {self._config.schema_name}.{self._table_name}") + existing_indexes = cursor.fetchall() + for idx in existing_indexes: + # Check if vector index already exists on the embedding column + if Field.VECTOR.value in str(idx).lower(): + logger.info("Vector index already exists on column %s", Field.VECTOR.value) + return + except (RuntimeError, ValueError) as e: + logger.warning("Failed to check existing indexes: %s", e) + + index_sql = f""" + CREATE VECTOR INDEX IF NOT EXISTS {index_name} + ON TABLE {self._config.schema_name}.{self._table_name}({Field.VECTOR.value}) + PROPERTIES ( + "distance.function" = "{self._config.vector_distance_function}", + "scalar.type" = "f32", + "m" = "16", + "ef.construction" = "128" + ) + """ + try: + cursor.execute(index_sql) + logger.info("Created vector index: %s", index_name) + except (RuntimeError, ValueError) as e: + error_msg = str(e).lower() + if "already exists" in error_msg or "already has index" in error_msg or "with the same type" in error_msg: + logger.info("Vector index already exists: %s", e) + else: + logger.exception("Failed to create vector index") + raise + + def _create_inverted_index(self, cursor): + """Create inverted index for full-text search.""" + # Use a fixed index name based on table name to avoid duplicates + index_name = f"idx_{self._table_name}_text" + + # Check if an inverted index already exists on this column + try: + cursor.execute(f"SHOW INDEX FROM {self._config.schema_name}.{self._table_name}") + existing_indexes = cursor.fetchall() + for idx in existing_indexes: + idx_str = str(idx).lower() + # More precise check: look for inverted index specifically on the content column + if ( + "inverted" in idx_str + and Field.CONTENT_KEY.value.lower() in idx_str + and (index_name.lower() in idx_str or f"idx_{self._table_name}_text" in idx_str) + ): + logger.info("Inverted index already exists on column %s: %s", Field.CONTENT_KEY.value, idx) + return + except (RuntimeError, ValueError) as e: + logger.warning("Failed to check existing indexes: %s", e) + + index_sql = f""" + CREATE INVERTED INDEX IF NOT EXISTS {index_name} + ON TABLE {self._config.schema_name}.{self._table_name} ({Field.CONTENT_KEY.value}) + PROPERTIES ( + "analyzer" = "{self._config.analyzer_type}", + "mode" = "{self._config.analyzer_mode}" + ) + """ + try: + cursor.execute(index_sql) + logger.info("Created inverted index: %s", index_name) + except (RuntimeError, ValueError) as e: + error_msg = str(e).lower() + # Handle ClickZetta specific error messages + if ( + "already exists" in error_msg + or "already has index" in error_msg + or "with the same type" in error_msg + or "cannot create inverted index" in error_msg + ) and "already has index" in error_msg: + logger.info("Inverted index already exists on column %s", Field.CONTENT_KEY.value) + # Try to get the existing index name for logging + try: + cursor.execute(f"SHOW INDEX FROM {self._config.schema_name}.{self._table_name}") + existing_indexes = cursor.fetchall() + for idx in existing_indexes: + if "inverted" in str(idx).lower() and Field.CONTENT_KEY.value.lower() in str(idx).lower(): + logger.info("Found existing inverted index: %s", idx) + break + except (RuntimeError, ValueError): + pass + else: + logger.warning("Failed to create inverted index: %s", e) + # Continue without inverted index - full-text search will fall back to LIKE + + def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): + """Add documents with embeddings to the collection.""" + if not documents: + return + + batch_size = self._config.batch_size + total_batches = (len(documents) + batch_size - 1) // batch_size + + for i in range(0, len(documents), batch_size): + batch_docs = documents[i : i + batch_size] + batch_embeddings = embeddings[i : i + batch_size] + + # Execute batch insert through write queue + self._execute_write(self._insert_batch, batch_docs, batch_embeddings, i, batch_size, total_batches) + + def _insert_batch( + self, + batch_docs: list[Document], + batch_embeddings: list[list[float]], + batch_index: int, + batch_size: int, + total_batches: int, + ): + """Insert a batch of documents using parameterized queries (executed in write worker thread).""" + if not batch_docs or not batch_embeddings: + logger.warning("Empty batch provided, skipping insertion") + return + + if len(batch_docs) != len(batch_embeddings): + logger.error("Mismatch between docs (%d) and embeddings (%d)", len(batch_docs), len(batch_embeddings)) + return + + # Prepare data for parameterized insertion + data_rows = [] + vector_dimension = len(batch_embeddings[0]) if batch_embeddings and batch_embeddings[0] else 768 + + for doc, embedding in zip(batch_docs, batch_embeddings): + # Optimized: minimal checks for common case, fallback for edge cases + metadata = doc.metadata if doc.metadata else {} + + if not isinstance(metadata, dict): + metadata = {} + + doc_id = self._safe_doc_id(metadata.get("doc_id", str(uuid.uuid4()))) + + # Fast path for JSON serialization + try: + metadata_json = json.dumps(metadata, ensure_ascii=True) + except (TypeError, ValueError): + logger.warning("JSON serialization failed, using empty dict") + metadata_json = "{}" + + content = doc.page_content or "" + + # According to ClickZetta docs, vector should be formatted as array string + # for external systems: '[1.0, 2.0, 3.0]' + vector_str = "[" + ",".join(map(str, embedding)) + "]" + data_rows.append([doc_id, content, metadata_json, vector_str]) + + # Check if we have any valid data to insert + if not data_rows: + logger.warning("No valid documents to insert in batch %d/%d", batch_index // batch_size + 1, total_batches) + return + + # Use parameterized INSERT with executemany for better performance and security + # Cast JSON and VECTOR in SQL, pass raw data as parameters + columns = f"id, {Field.CONTENT_KEY.value}, {Field.METADATA_KEY.value}, {Field.VECTOR.value}" + insert_sql = ( + f"INSERT INTO {self._config.schema_name}.{self._table_name} ({columns}) " + f"VALUES (?, ?, CAST(? AS JSON), CAST(? AS VECTOR({vector_dimension})))" + ) + + with self.get_connection_context() as connection: + with connection.cursor() as cursor: + try: + # Set session-level hints for batch insert operations + # Note: executemany doesn't support hints parameter, so we set them as session variables + # Temporarily suppress ClickZetta client logging to reduce noise + clickzetta_logger = logging.getLogger("clickzetta") + original_level = clickzetta_logger.level + clickzetta_logger.setLevel(logging.WARNING) + + try: + cursor.execute("SET cz.sql.job.fast.mode = true") + cursor.execute("SET cz.sql.compaction.after.commit = true") + cursor.execute("SET cz.storage.always.prefetch.internal = true") + finally: + # Restore original logging level + clickzetta_logger.setLevel(original_level) + + cursor.executemany(insert_sql, data_rows) + logger.info( + "Inserted batch %d/%d (%d valid docs using parameterized query with VECTOR(%d) cast)", + batch_index // batch_size + 1, + total_batches, + len(data_rows), + vector_dimension, + ) + except (RuntimeError, ValueError, TypeError, ConnectionError) as e: + logger.exception("Parameterized SQL execution failed for %d documents", len(data_rows)) + logger.exception("SQL template: %s", insert_sql) + logger.exception("Sample data row: %s", data_rows[0] if data_rows else "None") + raise + + def text_exists(self, id: str) -> bool: + """Check if a document exists by ID.""" + # Check if table exists first + if not self._table_exists(): + return False + + safe_id = self._safe_doc_id(id) + with self.get_connection_context() as connection: + with connection.cursor() as cursor: + cursor.execute( + f"SELECT COUNT(*) FROM {self._config.schema_name}.{self._table_name} WHERE id = ?", + binding_params=[safe_id], + ) + result = cursor.fetchone() + return result[0] > 0 if result else False + + def delete_by_ids(self, ids: list[str]) -> None: + """Delete documents by IDs.""" + if not ids: + return + + # Check if table exists before attempting delete + if not self._table_exists(): + logger.warning("Table %s.%s does not exist, skipping delete", self._config.schema_name, self._table_name) + return + + # Execute delete through write queue + self._execute_write(self._delete_by_ids_impl, ids) + + def _delete_by_ids_impl(self, ids: list[str]) -> None: + """Implementation of delete by IDs (executed in write worker thread).""" + safe_ids = [self._safe_doc_id(id) for id in ids] + + # Use parameterized query to prevent SQL injection + placeholders = ",".join("?" for _ in safe_ids) + sql = f"DELETE FROM {self._config.schema_name}.{self._table_name} WHERE id IN ({placeholders})" + + with self.get_connection_context() as connection: + with connection.cursor() as cursor: + cursor.execute(sql, binding_params=safe_ids) + + def delete_by_metadata_field(self, key: str, value: str) -> None: + """Delete documents by metadata field.""" + # Check if table exists before attempting delete + if not self._table_exists(): + logger.warning("Table %s.%s does not exist, skipping delete", self._config.schema_name, self._table_name) + return + + # Execute delete through write queue + self._execute_write(self._delete_by_metadata_field_impl, key, value) + + def _delete_by_metadata_field_impl(self, key: str, value: str) -> None: + """Implementation of delete by metadata field (executed in write worker thread).""" + with self.get_connection_context() as connection: + with connection.cursor() as cursor: + # Using JSON path to filter with parameterized query + # Note: JSON path requires literal key name, cannot be parameterized + # Use json_extract_string function for ClickZetta compatibility + sql = ( + f"DELETE FROM {self._config.schema_name}.{self._table_name} " + f"WHERE json_extract_string({Field.METADATA_KEY.value}, '$.{key}') = ?" + ) + cursor.execute(sql, binding_params=[value]) + + def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: + """Search for documents by vector similarity.""" + # Check if table exists first + if not self._table_exists(): + logger.warning( + "Table %s.%s does not exist, returning empty results", + self._config.schema_name, + self._table_name, + ) + return [] + + top_k = kwargs.get("top_k", 10) + score_threshold = kwargs.get("score_threshold", 0.0) + document_ids_filter = kwargs.get("document_ids_filter") + + # Handle filter parameter from canvas (workflow) + filter_param = kwargs.get("filter", {}) + + # Build filter clause + filter_clauses = [] + if document_ids_filter: + safe_doc_ids = [str(id).replace("'", "''") for id in document_ids_filter] + doc_ids_str = ",".join(f"'{id}'" for id in safe_doc_ids) + # Use json_extract_string function for ClickZetta compatibility + filter_clauses.append( + f"json_extract_string({Field.METADATA_KEY.value}, '$.document_id') IN ({doc_ids_str})" + ) + + # No need for dataset_id filter since each dataset has its own table + + # Add distance threshold based on distance function + vector_dimension = len(query_vector) + if self._config.vector_distance_function == "cosine_distance": + # For cosine distance, smaller is better (0 = identical, 2 = opposite) + distance_func = "COSINE_DISTANCE" + if score_threshold > 0: + query_vector_str = f"CAST('[{self._format_vector_simple(query_vector)}]' AS VECTOR({vector_dimension}))" + filter_clauses.append( + f"{distance_func}({Field.VECTOR.value}, {query_vector_str}) < {2 - score_threshold}" + ) + else: + # For L2 distance, smaller is better + distance_func = "L2_DISTANCE" + if score_threshold > 0: + query_vector_str = f"CAST('[{self._format_vector_simple(query_vector)}]' AS VECTOR({vector_dimension}))" + filter_clauses.append(f"{distance_func}({Field.VECTOR.value}, {query_vector_str}) < {score_threshold}") + + where_clause = " AND ".join(filter_clauses) if filter_clauses else "1=1" + + # Execute vector search query + query_vector_str = f"CAST('[{self._format_vector_simple(query_vector)}]' AS VECTOR({vector_dimension}))" + search_sql = f""" + SELECT id, {Field.CONTENT_KEY.value}, {Field.METADATA_KEY.value}, + {distance_func}({Field.VECTOR.value}, {query_vector_str}) AS distance + FROM {self._config.schema_name}.{self._table_name} + WHERE {where_clause} + ORDER BY distance + LIMIT {top_k} + """ + + documents = [] + with self.get_connection_context() as connection: + with connection.cursor() as cursor: + # Use hints parameter for vector search optimization + search_hints = { + "hints": { + "sdk.job.timeout": 60, # Increase timeout for vector search + "cz.sql.job.fast.mode": True, + "cz.storage.parquet.vector.index.read.memory.cache": True, + } + } + cursor.execute(search_sql, search_hints) + results = cursor.fetchall() + + for row in results: + # Parse metadata using centralized method + metadata = self._parse_metadata(row[2], row[0]) + + # Add score based on distance + if self._config.vector_distance_function == "cosine_distance": + metadata["score"] = 1 - (row[3] / 2) + else: + metadata["score"] = 1 / (1 + row[3]) + + doc = Document(page_content=row[1], metadata=metadata) + documents.append(doc) + + return documents + + def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: + """Search for documents using full-text search with inverted index.""" + if not self._config.enable_inverted_index: + logger.warning("Full-text search is not enabled. Enable inverted index in config.") + return [] + + # Check if table exists first + if not self._table_exists(): + logger.warning( + "Table %s.%s does not exist, returning empty results", + self._config.schema_name, + self._table_name, + ) + return [] + + top_k = kwargs.get("top_k", 10) + document_ids_filter = kwargs.get("document_ids_filter") + + # Handle filter parameter from canvas (workflow) + filter_param = kwargs.get("filter", {}) + + # Build filter clause + filter_clauses = [] + if document_ids_filter: + safe_doc_ids = [str(id).replace("'", "''") for id in document_ids_filter] + doc_ids_str = ",".join(f"'{id}'" for id in safe_doc_ids) + # Use json_extract_string function for ClickZetta compatibility + filter_clauses.append( + f"json_extract_string({Field.METADATA_KEY.value}, '$.document_id') IN ({doc_ids_str})" + ) + + # No need for dataset_id filter since each dataset has its own table + + # Use match_all function for full-text search + # match_all requires all terms to be present + # Use simple quote escaping for MATCH_ALL since it needs to be in the WHERE clause + escaped_query = query.replace("'", "''") + filter_clauses.append(f"MATCH_ALL({Field.CONTENT_KEY.value}, '{escaped_query}')") + + where_clause = " AND ".join(filter_clauses) + + # Execute full-text search query + search_sql = f""" + SELECT id, {Field.CONTENT_KEY.value}, {Field.METADATA_KEY.value} + FROM {self._config.schema_name}.{self._table_name} + WHERE {where_clause} + LIMIT {top_k} + """ + + documents = [] + with self.get_connection_context() as connection: + with connection.cursor() as cursor: + try: + # Use hints parameter for full-text search optimization + fulltext_hints = { + "hints": { + "sdk.job.timeout": 30, # Timeout for full-text search + "cz.sql.job.fast.mode": True, + "cz.sql.index.prewhere.enabled": True, + } + } + cursor.execute(search_sql, fulltext_hints) + results = cursor.fetchall() + + for row in results: + # Parse metadata from JSON string (may be double-encoded) + try: + if row[2]: + metadata = json.loads(row[2]) + + # If result is a string, it's double-encoded JSON - parse again + if isinstance(metadata, str): + metadata = json.loads(metadata) + + if not isinstance(metadata, dict): + metadata = {} + else: + metadata = {} + except (json.JSONDecodeError, TypeError) as e: + logger.exception("JSON parsing failed") + # Fallback: extract document_id with regex + + doc_id_match = re.search(r'"document_id":\s*"([^"]+)"', str(row[2] or "")) + metadata = {"document_id": doc_id_match.group(1)} if doc_id_match else {} + + # Ensure required fields are set + metadata["doc_id"] = row[0] # segment id + + # Ensure document_id exists (critical for Dify's format_retrieval_documents) + if "document_id" not in metadata: + metadata["document_id"] = row[0] # fallback to segment id + + # Add a relevance score for full-text search + metadata["score"] = 1.0 # Clickzetta doesn't provide relevance scores + doc = Document(page_content=row[1], metadata=metadata) + documents.append(doc) + except (RuntimeError, ValueError, TypeError, ConnectionError) as e: + logger.exception("Full-text search failed") + # Fallback to LIKE search if full-text search fails + return self._search_by_like(query, **kwargs) + + return documents + + def _search_by_like(self, query: str, **kwargs: Any) -> list[Document]: + """Fallback search using LIKE operator.""" + # Check if table exists first + if not self._table_exists(): + logger.warning( + "Table %s.%s does not exist, returning empty results", + self._config.schema_name, + self._table_name, + ) + return [] + + top_k = kwargs.get("top_k", 10) + document_ids_filter = kwargs.get("document_ids_filter") + + # Handle filter parameter from canvas (workflow) + filter_param = kwargs.get("filter", {}) + + # Build filter clause + filter_clauses = [] + if document_ids_filter: + safe_doc_ids = [str(id).replace("'", "''") for id in document_ids_filter] + doc_ids_str = ",".join(f"'{id}'" for id in safe_doc_ids) + # Use json_extract_string function for ClickZetta compatibility + filter_clauses.append( + f"json_extract_string({Field.METADATA_KEY.value}, '$.document_id') IN ({doc_ids_str})" + ) + + # No need for dataset_id filter since each dataset has its own table + + # Use simple quote escaping for LIKE clause + escaped_query = query.replace("'", "''") + filter_clauses.append(f"{Field.CONTENT_KEY.value} LIKE '%{escaped_query}%'") + where_clause = " AND ".join(filter_clauses) + + search_sql = f""" + SELECT id, {Field.CONTENT_KEY.value}, {Field.METADATA_KEY.value} + FROM {self._config.schema_name}.{self._table_name} + WHERE {where_clause} + LIMIT {top_k} + """ + + documents = [] + with self.get_connection_context() as connection: + with connection.cursor() as cursor: + # Use hints parameter for LIKE search optimization + like_hints = { + "hints": { + "sdk.job.timeout": 20, # Timeout for LIKE search + "cz.sql.job.fast.mode": True, + } + } + cursor.execute(search_sql, like_hints) + results = cursor.fetchall() + + for row in results: + # Parse metadata using centralized method + metadata = self._parse_metadata(row[2], row[0]) + + metadata["score"] = 0.5 # Lower score for LIKE search + doc = Document(page_content=row[1], metadata=metadata) + documents.append(doc) + + return documents + + def delete(self) -> None: + """Delete the entire collection.""" + with self.get_connection_context() as connection: + with connection.cursor() as cursor: + cursor.execute(f"DROP TABLE IF EXISTS {self._config.schema_name}.{self._table_name}") + + def _format_vector_simple(self, vector: list[float]) -> str: + """Simple vector formatting for SQL queries.""" + return ",".join(map(str, vector)) + + def _safe_doc_id(self, doc_id: str) -> str: + """Ensure doc_id is safe for SQL and doesn't contain special characters.""" + if not doc_id: + return str(uuid.uuid4()) + # Remove or replace potentially problematic characters + safe_id = str(doc_id) + # Only allow alphanumeric, hyphens, underscores + safe_id = "".join(c for c in safe_id if c.isalnum() or c in "-_") + if not safe_id: # If all characters were removed + return str(uuid.uuid4()) + return safe_id[:255] # Limit length + + +class ClickzettaVectorFactory(AbstractVectorFactory): + """Factory for creating Clickzetta vector instances.""" + + def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> BaseVector: + """Initialize a Clickzetta vector instance.""" + # Get configuration from environment variables or dataset config + config = ClickzettaConfig( + username=dify_config.CLICKZETTA_USERNAME or "", + password=dify_config.CLICKZETTA_PASSWORD or "", + instance=dify_config.CLICKZETTA_INSTANCE or "", + service=dify_config.CLICKZETTA_SERVICE or "api.clickzetta.com", + workspace=dify_config.CLICKZETTA_WORKSPACE or "quick_start", + vcluster=dify_config.CLICKZETTA_VCLUSTER or "default_ap", + schema_name=dify_config.CLICKZETTA_SCHEMA or "dify", + batch_size=dify_config.CLICKZETTA_BATCH_SIZE or 100, + enable_inverted_index=dify_config.CLICKZETTA_ENABLE_INVERTED_INDEX or True, + analyzer_type=dify_config.CLICKZETTA_ANALYZER_TYPE or "chinese", + analyzer_mode=dify_config.CLICKZETTA_ANALYZER_MODE or "smart", + vector_distance_function=dify_config.CLICKZETTA_VECTOR_DISTANCE_FUNCTION or "cosine_distance", + ) + + # Use dataset collection name as table name + collection_name = Dataset.gen_collection_name_by_id(dataset.id).lower() + + return ClickzettaVector(collection_name=collection_name, config=config) diff --git a/api/core/rag/datasource/vdb/couchbase/couchbase_vector.py b/api/core/rag/datasource/vdb/couchbase/couchbase_vector.py index 68a9952789..bd986393d1 100644 --- a/api/core/rag/datasource/vdb/couchbase/couchbase_vector.py +++ b/api/core/rag/datasource/vdb/couchbase/couchbase_vector.py @@ -74,9 +74,9 @@ class CouchbaseVector(BaseVector): self.add_texts(texts, embeddings) def _create_collection(self, vector_length: int, uuid: str): - lock_name = "vector_indexing_lock_{}".format(self._collection_name) + lock_name = f"vector_indexing_lock_{self._collection_name}" with redis_client.lock(lock_name, timeout=20): - collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name) + collection_exist_cache_key = f"vector_indexing_{self._collection_name}" if redis_client.get(collection_exist_cache_key): return if self._collection_exists(self._collection_name): @@ -242,7 +242,7 @@ class CouchbaseVector(BaseVector): try: self._cluster.query(query, named_parameters={"doc_ids": ids}).execute() except Exception as e: - logger.exception(f"Failed to delete documents, ids: {ids}") + logger.exception("Failed to delete documents, ids: %s", ids) def delete_by_document_id(self, document_id: str): query = f""" diff --git a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_ja_vector.py b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_ja_vector.py index 27575197fa..7118029d40 100644 --- a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_ja_vector.py +++ b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_ja_vector.py @@ -29,7 +29,7 @@ class ElasticSearchJaVector(ElasticSearchVector): with redis_client.lock(lock_name, timeout=20): collection_exist_cache_key = f"vector_indexing_{self._collection_name}" if redis_client.get(collection_exist_cache_key): - logger.info(f"Collection {self._collection_name} already exists.") + logger.info("Collection %s already exists.", self._collection_name) return if not self._client.indices.exists(index=self._collection_name): diff --git a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py index ad39717183..49c4b392fe 100644 --- a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py +++ b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py @@ -7,6 +7,7 @@ from urllib.parse import urlparse import requests from elasticsearch import Elasticsearch from flask import current_app +from packaging.version import parse as parse_version from pydantic import BaseModel, model_validator from core.rag.datasource.vdb.field import Field @@ -22,22 +23,50 @@ logger = logging.getLogger(__name__) class ElasticSearchConfig(BaseModel): - host: str - port: int - username: str - password: str + # Regular Elasticsearch config + host: Optional[str] = None + port: Optional[int] = None + username: Optional[str] = None + password: Optional[str] = None + + # Elastic Cloud specific config + cloud_url: Optional[str] = None # Cloud URL for Elasticsearch Cloud + api_key: Optional[str] = None + + # Common config + use_cloud: bool = False + ca_certs: Optional[str] = None + verify_certs: bool = False + request_timeout: int = 100000 + retry_on_timeout: bool = True + max_retries: int = 10000 @model_validator(mode="before") @classmethod def validate_config(cls, values: dict) -> dict: - if not values["host"]: - raise ValueError("config HOST is required") - if not values["port"]: - raise ValueError("config PORT is required") - if not values["username"]: - raise ValueError("config USERNAME is required") - if not values["password"]: - raise ValueError("config PASSWORD is required") + use_cloud = values.get("use_cloud", False) + cloud_url = values.get("cloud_url") + + if use_cloud: + # Cloud configuration validation - requires cloud_url and api_key + if not cloud_url: + raise ValueError("cloud_url is required for Elastic Cloud") + + api_key = values.get("api_key") + if not api_key: + raise ValueError("api_key is required for Elastic Cloud") + + else: + # Regular Elasticsearch validation + if not values.get("host"): + raise ValueError("config HOST is required for regular Elasticsearch") + if not values.get("port"): + raise ValueError("config PORT is required for regular Elasticsearch") + if not values.get("username"): + raise ValueError("config USERNAME is required for regular Elasticsearch") + if not values.get("password"): + raise ValueError("config PASSWORD is required for regular Elasticsearch") + return values @@ -50,21 +79,69 @@ class ElasticSearchVector(BaseVector): self._attributes = attributes def _init_client(self, config: ElasticSearchConfig) -> Elasticsearch: + """ + Initialize Elasticsearch client for both regular Elasticsearch and Elastic Cloud. + """ try: - parsed_url = urlparse(config.host) - if parsed_url.scheme in {"http", "https"}: - hosts = f"{config.host}:{config.port}" + # Check if using Elastic Cloud + client_config: dict[str, Any] + if config.use_cloud and config.cloud_url: + client_config = { + "request_timeout": config.request_timeout, + "retry_on_timeout": config.retry_on_timeout, + "max_retries": config.max_retries, + "verify_certs": config.verify_certs, + } + + # Parse cloud URL and configure hosts + parsed_url = urlparse(config.cloud_url) + host = f"{parsed_url.scheme}://{parsed_url.hostname}" + if parsed_url.port: + host += f":{parsed_url.port}" + + client_config["hosts"] = [host] + + # API key authentication for cloud + client_config["api_key"] = config.api_key + + # SSL settings + if config.ca_certs: + client_config["ca_certs"] = config.ca_certs + else: - hosts = f"http://{config.host}:{config.port}" - client = Elasticsearch( - hosts=hosts, - basic_auth=(config.username, config.password), - request_timeout=100000, - retry_on_timeout=True, - max_retries=10000, - ) - except requests.exceptions.ConnectionError: - raise ConnectionError("Vector database connection error") + # Regular Elasticsearch configuration + parsed_url = urlparse(config.host or "") + if parsed_url.scheme in {"http", "https"}: + hosts = f"{config.host}:{config.port}" + use_https = parsed_url.scheme == "https" + else: + hosts = f"http://{config.host}:{config.port}" + use_https = False + + client_config = { + "hosts": [hosts], + "basic_auth": (config.username, config.password), + "request_timeout": config.request_timeout, + "retry_on_timeout": config.retry_on_timeout, + "max_retries": config.max_retries, + } + + # Only add SSL settings if using HTTPS + if use_https: + client_config["verify_certs"] = config.verify_certs + if config.ca_certs: + client_config["ca_certs"] = config.ca_certs + + client = Elasticsearch(**client_config) + + # Test connection + if not client.ping(): + raise ConnectionError("Failed to connect to Elasticsearch") + + except requests.exceptions.ConnectionError as e: + raise ConnectionError(f"Vector database connection error: {str(e)}") + except Exception as e: + raise ConnectionError(f"Elasticsearch client initialization failed: {str(e)}") return client @@ -73,7 +150,7 @@ class ElasticSearchVector(BaseVector): return cast(str, info["version"]["number"]) def _check_version(self): - if self._version < "8.0.0": + if parse_version(self._version) < parse_version("8.0.0"): raise ValueError("Elasticsearch vector database version must be greater than 8.0.0") def get_type(self) -> str: @@ -186,7 +263,7 @@ class ElasticSearchVector(BaseVector): with redis_client.lock(lock_name, timeout=20): collection_exist_cache_key = f"vector_indexing_{self._collection_name}" if redis_client.get(collection_exist_cache_key): - logger.info(f"Collection {self._collection_name} already exists.") + logger.info("Collection %s already exists.", self._collection_name) return if not self._client.indices.exists(index=self._collection_name): @@ -209,7 +286,11 @@ class ElasticSearchVector(BaseVector): }, } } + self._client.indices.create(index=self._collection_name, mappings=mappings) + logger.info("Created index %s with dimension %s", self._collection_name, dim) + else: + logger.info("Collection %s already exists.", self._collection_name) redis_client.set(collection_exist_cache_key, 1, ex=3600) @@ -225,13 +306,51 @@ class ElasticSearchVectorFactory(AbstractVectorFactory): dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.ELASTICSEARCH, collection_name)) config = current_app.config + + # Check if ELASTICSEARCH_USE_CLOUD is explicitly set to false (boolean) + use_cloud_env = config.get("ELASTICSEARCH_USE_CLOUD", False) + + if use_cloud_env is False: + # Use regular Elasticsearch with config values + config_dict = { + "use_cloud": False, + "host": config.get("ELASTICSEARCH_HOST", "elasticsearch"), + "port": config.get("ELASTICSEARCH_PORT", 9200), + "username": config.get("ELASTICSEARCH_USERNAME", "elastic"), + "password": config.get("ELASTICSEARCH_PASSWORD", "elastic"), + } + else: + # Check for cloud configuration + cloud_url = config.get("ELASTICSEARCH_CLOUD_URL") + if cloud_url: + config_dict = { + "use_cloud": True, + "cloud_url": cloud_url, + "api_key": config.get("ELASTICSEARCH_API_KEY"), + } + else: + # Fallback to regular Elasticsearch + config_dict = { + "use_cloud": False, + "host": config.get("ELASTICSEARCH_HOST", "localhost"), + "port": config.get("ELASTICSEARCH_PORT", 9200), + "username": config.get("ELASTICSEARCH_USERNAME", "elastic"), + "password": config.get("ELASTICSEARCH_PASSWORD", ""), + } + + # Common configuration + config_dict.update( + { + "ca_certs": str(config.get("ELASTICSEARCH_CA_CERTS")) if config.get("ELASTICSEARCH_CA_CERTS") else None, + "verify_certs": bool(config.get("ELASTICSEARCH_VERIFY_CERTS", False)), + "request_timeout": int(config.get("ELASTICSEARCH_REQUEST_TIMEOUT", 100000)), + "retry_on_timeout": bool(config.get("ELASTICSEARCH_RETRY_ON_TIMEOUT", True)), + "max_retries": int(config.get("ELASTICSEARCH_MAX_RETRIES", 10000)), + } + ) + return ElasticSearchVector( index_name=collection_name, - config=ElasticSearchConfig( - host=config.get("ELASTICSEARCH_HOST", "localhost"), - port=config.get("ELASTICSEARCH_PORT", 9200), - username=config.get("ELASTICSEARCH_USERNAME", ""), - password=config.get("ELASTICSEARCH_PASSWORD", ""), - ), + config=ElasticSearchConfig(**config_dict), attributes=[], ) diff --git a/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py b/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py index 89423eb160..0a4067e39c 100644 --- a/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py +++ b/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py @@ -164,7 +164,7 @@ class HuaweiCloudVector(BaseVector): with redis_client.lock(lock_name, timeout=20): collection_exist_cache_key = f"vector_indexing_{self._collection_name}" if redis_client.get(collection_exist_cache_key): - logger.info(f"Collection {self._collection_name} already exists.") + logger.info("Collection %s already exists.", self._collection_name) return if not self._client.indices.exists(index=self._collection_name): diff --git a/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py b/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py index e9ff1ce43d..3c65a41f08 100644 --- a/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py +++ b/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py @@ -89,7 +89,7 @@ class LindormVectorStore(BaseVector): timeout: int = 60, **kwargs, ): - logger.info(f"Total documents to add: {len(documents)}") + logger.info("Total documents to add: %s", len(documents)) uuids = self._get_uuids(documents) total_docs = len(documents) @@ -147,7 +147,7 @@ class LindormVectorStore(BaseVector): time.sleep(0.5) except Exception: - logger.exception(f"Failed to process batch {batch_num + 1}") + logger.exception("Failed to process batch %s", batch_num + 1) raise def get_ids_by_metadata_field(self, key: str, value: str): @@ -180,7 +180,7 @@ class LindormVectorStore(BaseVector): # 1. First check if collection exists if not self._client.indices.exists(index=self._collection_name): - logger.warning(f"Collection {self._collection_name} does not exist") + logger.warning("Collection %s does not exist", self._collection_name) return # 2. Batch process deletions @@ -196,7 +196,7 @@ class LindormVectorStore(BaseVector): } ) else: - logger.warning(f"DELETE BY ID: ID {id} does not exist in the index.") + logger.warning("DELETE BY ID: ID %s does not exist in the index.", id) # 3. Perform bulk deletion if there are valid documents to delete if actions: @@ -209,9 +209,9 @@ class LindormVectorStore(BaseVector): doc_id = delete_error.get("_id") if status == 404: - logger.warning(f"Document not found for deletion: {doc_id}") + logger.warning("Document not found for deletion: %s", doc_id) else: - logger.exception(f"Error deleting document: {error}") + logger.exception("Error deleting document: %s", error) def delete(self) -> None: if self._using_ugc: @@ -225,7 +225,7 @@ class LindormVectorStore(BaseVector): self._client.indices.delete(index=self._collection_name, params={"timeout": 60}) logger.info("Delete index success") else: - logger.warning(f"Index '{self._collection_name}' does not exist. No deletion performed.") + logger.warning("Index '%s' does not exist. No deletion performed.", self._collection_name) def text_exists(self, id: str) -> bool: try: @@ -257,7 +257,7 @@ class LindormVectorStore(BaseVector): params["routing"] = self._routing # type: ignore response = self._client.search(index=self._collection_name, body=query, params=params) except Exception: - logger.exception(f"Error executing vector search, query: {query}") + logger.exception("Error executing vector search, query: %s", query) raise docs_and_scores = [] @@ -324,10 +324,10 @@ class LindormVectorStore(BaseVector): with redis_client.lock(lock_name, timeout=20): collection_exist_cache_key = f"vector_indexing_{self._collection_name}" if redis_client.get(collection_exist_cache_key): - logger.info(f"Collection {self._collection_name} already exists.") + logger.info("Collection %s already exists.", self._collection_name) return if self._client.indices.exists(index=self._collection_name): - logger.info(f"{self._collection_name.lower()} already exists.") + logger.info("%s already exists.", self._collection_name.lower()) redis_client.set(collection_exist_cache_key, 1, ex=3600) return if len(self.kwargs) == 0 and len(kwargs) != 0: diff --git a/api/core/rag/datasource/vdb/milvus/milvus_vector.py b/api/core/rag/datasource/vdb/milvus/milvus_vector.py index 63de6a0603..112f07844c 100644 --- a/api/core/rag/datasource/vdb/milvus/milvus_vector.py +++ b/api/core/rag/datasource/vdb/milvus/milvus_vector.py @@ -101,9 +101,9 @@ class MilvusVector(BaseVector): if "Zilliz Cloud" in milvus_version: return True # For standard Milvus installations, check version number - return version.parse(milvus_version).base_version >= version.parse("2.5.0").base_version + return version.parse(milvus_version) >= version.parse("2.5.0") except Exception as e: - logger.warning(f"Failed to check Milvus version: {str(e)}. Disabling hybrid search.") + logger.warning("Failed to check Milvus version: %s. Disabling hybrid search.", str(e)) return False def get_type(self) -> str: @@ -289,9 +289,9 @@ class MilvusVector(BaseVector): """ Create a new collection in Milvus with the specified schema and index parameters. """ - lock_name = "vector_indexing_lock_{}".format(self._collection_name) + lock_name = f"vector_indexing_lock_{self._collection_name}" with redis_client.lock(lock_name, timeout=20): - collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name) + collection_exist_cache_key = f"vector_indexing_{self._collection_name}" if redis_client.get(collection_exist_cache_key): return # Grab the existing collection if it exists diff --git a/api/core/rag/datasource/vdb/myscale/myscale_vector.py b/api/core/rag/datasource/vdb/myscale/myscale_vector.py index dbb1a7fe19..d5ec4b4436 100644 --- a/api/core/rag/datasource/vdb/myscale/myscale_vector.py +++ b/api/core/rag/datasource/vdb/myscale/myscale_vector.py @@ -53,7 +53,7 @@ class MyScaleVector(BaseVector): return self.add_texts(documents=texts, embeddings=embeddings, **kwargs) def _create_collection(self, dimension: int): - logging.info(f"create MyScale collection {self._collection_name} with dimension {dimension}") + logging.info("create MyScale collection %s with dimension %s", self._collection_name, dimension) self._client.command(f"CREATE DATABASE IF NOT EXISTS {self._config.database}") fts_params = f"('{self._config.fts_params}')" if self._config.fts_params else "" sql = f""" @@ -151,7 +151,7 @@ class MyScaleVector(BaseVector): for r in self._client.query(sql).named_results() ] except Exception as e: - logging.exception(f"\033[91m\033[1m{type(e)}\033[0m \033[95m{str(e)}\033[0m") # noqa:TRY401 + logging.exception("\033[91m\033[1m%s\033[0m \033[95m%s\033[0m", type(e), str(e)) # noqa:TRY401 return [] def delete(self) -> None: diff --git a/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py b/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py index dd196e1f09..556d03940e 100644 --- a/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py +++ b/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py @@ -4,8 +4,8 @@ import math from typing import Any from pydantic import BaseModel, model_validator -from pyobvector import VECTOR, ObVecClient # type: ignore -from sqlalchemy import JSON, Column, String, func +from pyobvector import VECTOR, FtsIndexParam, FtsParser, ObVecClient, l2_distance # type: ignore +from sqlalchemy import JSON, Column, String from sqlalchemy.dialects.mysql import LONGTEXT from configs import dify_config @@ -119,14 +119,21 @@ class OceanBaseVector(BaseVector): ) try: if self._hybrid_search_enabled: - self._client.perform_raw_text_sql(f"""ALTER TABLE {self._collection_name} - ADD FULLTEXT INDEX fulltext_index_for_col_text (text) WITH PARSER ik""") + self._client.create_fts_idx_with_fts_index_param( + table_name=self._collection_name, + fts_idx_param=FtsIndexParam( + index_name="fulltext_index_for_col_text", + field_names=["text"], + parser_type=FtsParser.IK, + ), + ) except Exception as e: raise Exception( "Failed to add fulltext index to the target table, your OceanBase version must be 4.3.5.1 or above " + "to support fulltext index and vector index in the same table", e, ) + self._client.refresh_metadata([self._collection_name]) redis_client.set(collection_exist_cache_key, 1, ex=3600) def _check_hybrid_search_support(self) -> bool: @@ -145,9 +152,9 @@ class OceanBaseVector(BaseVector): ob_full_version = result.fetchone()[0] ob_version = ob_full_version.split()[1] logger.debug("Current OceanBase version is %s", ob_version) - return version.parse(ob_version).base_version >= version.parse("4.3.5.1").base_version + return version.parse(ob_version) >= version.parse("4.3.5.1") except Exception as e: - logger.warning(f"Failed to check OceanBase version: {str(e)}. Disabling hybrid search.") + logger.warning("Failed to check OceanBase version: %s. Disabling hybrid search.", str(e)) return False def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): @@ -229,7 +236,7 @@ class OceanBaseVector(BaseVector): return docs except Exception as e: - logger.warning(f"Failed to fulltext search: {str(e)}.") + logger.warning("Failed to fulltext search: %s.", str(e)) return [] def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: @@ -252,7 +259,7 @@ class OceanBaseVector(BaseVector): vec_column_name="vector", vec_data=query_vector, topk=topk, - distance_func=func.l2_distance, + distance_func=l2_distance, output_column_names=["text", "metadata"], with_dist=True, where_clause=_where_clause, diff --git a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py index 0abb3c0077..ed2dcb40ad 100644 --- a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py +++ b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py @@ -131,7 +131,7 @@ class OpenSearchVector(BaseVector): def delete_by_ids(self, ids: list[str]) -> None: index_name = self._collection_name.lower() if not self._client.indices.exists(index=index_name): - logger.warning(f"Index {index_name} does not exist") + logger.warning("Index %s does not exist", index_name) return # Obtaining All Actual Documents_ID @@ -142,7 +142,7 @@ class OpenSearchVector(BaseVector): if es_ids: actual_ids.extend(es_ids) else: - logger.warning(f"Document with metadata doc_id {doc_id} not found for deletion") + logger.warning("Document with metadata doc_id %s not found for deletion", doc_id) if actual_ids: actions = [{"_op_type": "delete", "_index": index_name, "_id": es_id} for es_id in actual_ids] @@ -155,9 +155,9 @@ class OpenSearchVector(BaseVector): doc_id = delete_error.get("_id") if status == 404: - logger.warning(f"Document not found for deletion: {doc_id}") + logger.warning("Document not found for deletion: %s", doc_id) else: - logger.exception(f"Error deleting document: {error}") + logger.exception("Error deleting document: %s", error) def delete(self) -> None: self._client.indices.delete(index=self._collection_name.lower()) @@ -198,7 +198,7 @@ class OpenSearchVector(BaseVector): try: response = self._client.search(index=self._collection_name.lower(), body=query) except Exception as e: - logger.exception(f"Error executing vector search, query: {query}") + logger.exception("Error executing vector search, query: %s", query) raise docs = [] @@ -242,7 +242,7 @@ class OpenSearchVector(BaseVector): with redis_client.lock(lock_name, timeout=20): collection_exist_cache_key = f"vector_indexing_{self._collection_name.lower()}" if redis_client.get(collection_exist_cache_key): - logger.info(f"Collection {self._collection_name.lower()} already exists.") + logger.info("Collection %s already exists.", self._collection_name.lower()) return if not self._client.indices.exists(index=self._collection_name.lower()): @@ -272,7 +272,7 @@ class OpenSearchVector(BaseVector): }, } - logger.info(f"Creating OpenSearch index {self._collection_name.lower()}") + logger.info("Creating OpenSearch index %s", self._collection_name.lower()) self._client.indices.create(index=self._collection_name.lower(), body=index_body) redis_client.set(collection_exist_cache_key, 1, ex=3600) diff --git a/api/core/rag/datasource/vdb/oracle/oraclevector.py b/api/core/rag/datasource/vdb/oracle/oraclevector.py index d1c8142b3d..303c3fe31c 100644 --- a/api/core/rag/datasource/vdb/oracle/oraclevector.py +++ b/api/core/rag/datasource/vdb/oracle/oraclevector.py @@ -109,8 +109,19 @@ class OracleVector(BaseVector): ) def _get_connection(self) -> Connection: - connection = oracledb.connect(user=self.config.user, password=self.config.password, dsn=self.config.dsn) - return connection + if self.config.is_autonomous: + connection = oracledb.connect( + user=self.config.user, + password=self.config.password, + dsn=self.config.dsn, + config_dir=self.config.config_dir, + wallet_location=self.config.wallet_location, + wallet_password=self.config.wallet_password, + ) + return connection + else: + connection = oracledb.connect(user=self.config.user, password=self.config.password, dsn=self.config.dsn) + return connection def _create_connection_pool(self, config: OracleVectorConfig): pool_params = { diff --git a/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py b/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py index b0f0eeca38..e77befcdae 100644 --- a/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py +++ b/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py @@ -82,9 +82,9 @@ class PGVectoRS(BaseVector): self.add_texts(texts, embeddings) def create_collection(self, dimension: int): - lock_name = "vector_indexing_lock_{}".format(self._collection_name) + lock_name = f"vector_indexing_lock_{self._collection_name}" with redis_client.lock(lock_name, timeout=20): - collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name) + collection_exist_cache_key = f"vector_indexing_{self._collection_name}" if redis_client.get(collection_exist_cache_key): return index_name = f"{self._collection_name}_embedding_index" diff --git a/api/core/rag/datasource/vdb/pgvector/pgvector.py b/api/core/rag/datasource/vdb/pgvector/pgvector.py index 04e9cf801e..746773da63 100644 --- a/api/core/rag/datasource/vdb/pgvector/pgvector.py +++ b/api/core/rag/datasource/vdb/pgvector/pgvector.py @@ -155,7 +155,7 @@ class PGVector(BaseVector): cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s", (tuple(ids),)) except psycopg2.errors.UndefinedTable: # table not exists - logging.warning(f"Table {self.table_name} not found, skipping delete operation.") + logging.warning("Table %s not found, skipping delete operation.", self.table_name) return except Exception as e: raise e diff --git a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py index dfb95a1839..fcf3a6d126 100644 --- a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py +++ b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py @@ -95,9 +95,9 @@ class QdrantVector(BaseVector): self.add_texts(texts, embeddings, **kwargs) def create_collection(self, collection_name: str, vector_size: int): - lock_name = "vector_indexing_lock_{}".format(collection_name) + lock_name = f"vector_indexing_lock_{collection_name}" with redis_client.lock(lock_name, timeout=20): - collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name) + collection_exist_cache_key = f"vector_indexing_{self._collection_name}" if redis_client.get(collection_exist_cache_key): return collection_name = collection_name or uuid.uuid4().hex @@ -331,6 +331,12 @@ class QdrantVector(BaseVector): def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: from qdrant_client.http import models + score_threshold = float(kwargs.get("score_threshold") or 0.0) + if score_threshold >= 1: + # return empty list because some versions of qdrant may response with 400 bad request, + # and at the same time, the score_threshold with value 1 may be valid for other vector stores + return [] + filter = models.Filter( must=[ models.FieldCondition( @@ -355,7 +361,7 @@ class QdrantVector(BaseVector): limit=kwargs.get("top_k", 4), with_payload=True, with_vectors=True, - score_threshold=float(kwargs.get("score_threshold") or 0.0), + score_threshold=score_threshold, ) docs = [] for result in results: @@ -363,7 +369,6 @@ class QdrantVector(BaseVector): continue metadata = result.payload.get(Field.METADATA_KEY.value) or {} # duplicate check score threshold - score_threshold = float(kwargs.get("score_threshold") or 0.0) if result.score > score_threshold: metadata["score"] = result.score doc = Document( diff --git a/api/core/rag/datasource/vdb/relyt/relyt_vector.py b/api/core/rag/datasource/vdb/relyt/relyt_vector.py index 0c0d6a463d..7a42dd1a89 100644 --- a/api/core/rag/datasource/vdb/relyt/relyt_vector.py +++ b/api/core/rag/datasource/vdb/relyt/relyt_vector.py @@ -70,9 +70,9 @@ class RelytVector(BaseVector): self.add_texts(texts, embeddings) def create_collection(self, dimension: int): - lock_name = "vector_indexing_lock_{}".format(self._collection_name) + lock_name = f"vector_indexing_lock_{self._collection_name}" with redis_client.lock(lock_name, timeout=20): - collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name) + collection_exist_cache_key = f"vector_indexing_{self._collection_name}" if redis_client.get(collection_exist_cache_key): return index_name = f"{self._collection_name}_embedding_index" diff --git a/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py b/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py index 9ed6e7369b..91d667ff2c 100644 --- a/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py +++ b/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py @@ -1,5 +1,6 @@ import json import logging +import math from typing import Any, Optional import tablestore # type: ignore @@ -22,6 +23,7 @@ class TableStoreConfig(BaseModel): access_key_secret: Optional[str] = None instance_name: Optional[str] = None endpoint: Optional[str] = None + normalize_full_text_bm25_score: Optional[bool] = False @model_validator(mode="before") @classmethod @@ -47,6 +49,7 @@ class TableStoreVector(BaseVector): config.access_key_secret, config.instance_name, ) + self._normalize_full_text_bm25_score = config.normalize_full_text_bm25_score self._table_name = f"{collection_name}" self._index_name = f"{collection_name}_idx" self._tags_field = f"{Field.METADATA_KEY.value}_tags" @@ -131,8 +134,8 @@ class TableStoreVector(BaseVector): filtered_list = None if document_ids_filter: filtered_list = ["document_id=" + item for item in document_ids_filter] - - return self._search_by_full_text(query, filtered_list, top_k) + score_threshold = float(kwargs.get("score_threshold") or 0.0) + return self._search_by_full_text(query, filtered_list, top_k, score_threshold) def delete(self) -> None: self._delete_table_if_exist() @@ -142,7 +145,7 @@ class TableStoreVector(BaseVector): with redis_client.lock(lock_name, timeout=20): collection_exist_cache_key = f"vector_indexing_{self._collection_name}" if redis_client.get(collection_exist_cache_key): - logging.info(f"Collection {self._collection_name} already exists.") + logging.info("Collection %s already exists.", self._collection_name) return self._create_table_if_not_exist() @@ -318,7 +321,19 @@ class TableStoreVector(BaseVector): documents = sorted(documents, key=lambda x: x.metadata["score"] if x.metadata else 0, reverse=True) return documents - def _search_by_full_text(self, query: str, document_ids_filter: list[str] | None, top_k: int) -> list[Document]: + @staticmethod + def _normalize_score_exp_decay(score: float, k: float = 0.15) -> float: + """ + Args: + score: BM25 search score. + k: decay factor, the larger the k, the steeper the low score end + """ + normalized_score = 1 - math.exp(-k * score) + return max(0.0, min(1.0, normalized_score)) + + def _search_by_full_text( + self, query: str, document_ids_filter: list[str] | None, top_k: int, score_threshold: float + ) -> list[Document]: bool_query = tablestore.BoolQuery(must_queries=[], filter_queries=[], should_queries=[], must_not_queries=[]) bool_query.must_queries.append(tablestore.MatchQuery(text=query, field_name=Field.CONTENT_KEY.value)) @@ -339,15 +354,27 @@ class TableStoreVector(BaseVector): documents = [] for search_hit in search_response.search_hits: + score = None + if self._normalize_full_text_bm25_score: + score = self._normalize_score_exp_decay(search_hit.score) + + # skip when score is below threshold and use normalize score + if score and score <= score_threshold: + continue + ots_column_map = {} for col in search_hit.row[1]: ots_column_map[col[0]] = col[1] - vector_str = ots_column_map.get(Field.VECTOR.value) metadata_str = ots_column_map.get(Field.METADATA_KEY.value) - vector = json.loads(vector_str) if vector_str else None metadata = json.loads(metadata_str) if metadata_str else {} + vector_str = ots_column_map.get(Field.VECTOR.value) + vector = json.loads(vector_str) if vector_str else None + + if score: + metadata["score"] = score + documents.append( Document( page_content=ots_column_map.get(Field.CONTENT_KEY.value) or "", @@ -355,6 +382,8 @@ class TableStoreVector(BaseVector): metadata=metadata, ) ) + if self._normalize_full_text_bm25_score: + documents = sorted(documents, key=lambda x: x.metadata["score"] if x.metadata else 0, reverse=True) return documents @@ -375,5 +404,6 @@ class TableStoreVectorFactory(AbstractVectorFactory): instance_name=dify_config.TABLESTORE_INSTANCE_NAME, access_key_id=dify_config.TABLESTORE_ACCESS_KEY_ID, access_key_secret=dify_config.TABLESTORE_ACCESS_KEY_SECRET, + normalize_full_text_bm25_score=dify_config.TABLESTORE_NORMALIZE_FULLTEXT_BM25_SCORE, ), ) diff --git a/api/core/rag/datasource/vdb/tencent/tencent_vector.py b/api/core/rag/datasource/vdb/tencent/tencent_vector.py index 23ed8a3344..0517d5a6d1 100644 --- a/api/core/rag/datasource/vdb/tencent/tencent_vector.py +++ b/api/core/rag/datasource/vdb/tencent/tencent_vector.py @@ -92,9 +92,9 @@ class TencentVector(BaseVector): def _create_collection(self, dimension: int) -> None: self._dimension = dimension - lock_name = "vector_indexing_lock_{}".format(self._collection_name) + lock_name = f"vector_indexing_lock_{self._collection_name}" with redis_client.lock(lock_name, timeout=20): - collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name) + collection_exist_cache_key = f"vector_indexing_{self._collection_name}" if redis_client.get(collection_exist_cache_key): return @@ -246,6 +246,10 @@ class TencentVector(BaseVector): return self._get_search_res(res, score_threshold) def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: + document_ids_filter = kwargs.get("document_ids_filter") + filter = None + if document_ids_filter: + filter = Filter(Filter.In("metadata.document_id", document_ids_filter)) if not self._enable_hybrid_search: return [] res = self._client.hybrid_search( @@ -269,6 +273,7 @@ class TencentVector(BaseVector): ), retrieve_vector=False, limit=kwargs.get("top_k", 4), + filter=filter, ) score_threshold = float(kwargs.get("score_threshold") or 0.0) return self._get_search_res(res, score_threshold) diff --git a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py index ba6a9654f0..e848b39c4d 100644 --- a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py +++ b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py @@ -104,9 +104,9 @@ class TidbOnQdrantVector(BaseVector): self.add_texts(texts, embeddings, **kwargs) def create_collection(self, collection_name: str, vector_size: int): - lock_name = "vector_indexing_lock_{}".format(collection_name) + lock_name = f"vector_indexing_lock_{collection_name}" with redis_client.lock(lock_name, timeout=20): - collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name) + collection_exist_cache_key = f"vector_indexing_{self._collection_name}" if redis_client.get(collection_exist_cache_key): return collection_name = collection_name or uuid.uuid4().hex diff --git a/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py b/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py index 61c68b939e..f8a851a246 100644 --- a/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py +++ b/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py @@ -91,9 +91,9 @@ class TiDBVector(BaseVector): def _create_collection(self, dimension: int): logger.info("_create_collection, collection_name " + self._collection_name) - lock_name = "vector_indexing_lock_{}".format(self._collection_name) + lock_name = f"vector_indexing_lock_{self._collection_name}" with redis_client.lock(lock_name, timeout=20): - collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name) + collection_exist_cache_key = f"vector_indexing_{self._collection_name}" if redis_client.get(collection_exist_cache_key): return tidb_dist_func = self._get_distance_func() @@ -192,7 +192,7 @@ class TiDBVector(BaseVector): query_vector_str = ", ".join(format(x) for x in query_vector) query_vector_str = "[" + query_vector_str + "]" logger.debug( - f"_collection_name: {self._collection_name}, score_threshold: {score_threshold}, distance: {distance}" + "_collection_name: %s, score_threshold: %s, distance: %s", self._collection_name, score_threshold, distance ) docs = [] diff --git a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py index e018f7d3d4..eef03ce412 100644 --- a/api/core/rag/datasource/vdb/vector_factory.py +++ b/api/core/rag/datasource/vdb/vector_factory.py @@ -172,25 +172,29 @@ class Vector: from core.rag.datasource.vdb.matrixone.matrixone_vector import MatrixoneVectorFactory return MatrixoneVectorFactory + case VectorType.CLICKZETTA: + from core.rag.datasource.vdb.clickzetta.clickzetta_vector import ClickzettaVectorFactory + + return ClickzettaVectorFactory case _: raise ValueError(f"Vector store {vector_type} is not supported.") def create(self, texts: Optional[list] = None, **kwargs): if texts: start = time.time() - logger.info(f"start embedding {len(texts)} texts {start}") + logger.info("start embedding %s texts %s", len(texts), start) batch_size = 1000 total_batches = len(texts) + batch_size - 1 for i in range(0, len(texts), batch_size): batch = texts[i : i + batch_size] batch_start = time.time() - logger.info(f"Processing batch {i // batch_size + 1}/{total_batches} ({len(batch)} texts)") + logger.info("Processing batch %s/%s (%s texts)", i // batch_size + 1, total_batches, len(batch)) batch_embeddings = self._embeddings.embed_documents([document.page_content for document in batch]) logger.info( - f"Embedding batch {i // batch_size + 1}/{total_batches} took {time.time() - batch_start:.3f}s" + "Embedding batch %s/%s took %s s", i // batch_size + 1, total_batches, time.time() - batch_start ) self._vector_processor.create(texts=batch, embeddings=batch_embeddings, **kwargs) - logger.info(f"Embedding {len(texts)} texts took {time.time() - start:.3f}s") + logger.info("Embedding %s texts took %s s", len(texts), time.time() - start) def add_texts(self, documents: list[Document], **kwargs): if kwargs.get("duplicate_check", False): @@ -219,7 +223,7 @@ class Vector: self._vector_processor.delete() # delete collection redis cache if self._vector_processor.collection_name: - collection_exist_cache_key = "vector_indexing_{}".format(self._vector_processor.collection_name) + collection_exist_cache_key = f"vector_indexing_{self._vector_processor.collection_name}" redis_client.delete(collection_exist_cache_key) def _get_embeddings(self) -> Embeddings: diff --git a/api/core/rag/datasource/vdb/vector_type.py b/api/core/rag/datasource/vdb/vector_type.py index 0d70947b72..a415142196 100644 --- a/api/core/rag/datasource/vdb/vector_type.py +++ b/api/core/rag/datasource/vdb/vector_type.py @@ -30,3 +30,4 @@ class VectorType(StrEnum): TABLESTORE = "tablestore" HUAWEI_CLOUD = "huawei_cloud" MATRIXONE = "matrixone" + CLICKZETTA = "clickzetta" diff --git a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py index 7a8efb4068..5525ef1685 100644 --- a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py +++ b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py @@ -92,9 +92,9 @@ class WeaviateVector(BaseVector): self.add_texts(texts, embeddings) def _create_collection(self): - lock_name = "vector_indexing_lock_{}".format(self._collection_name) + lock_name = f"vector_indexing_lock_{self._collection_name}" with redis_client.lock(lock_name, timeout=20): - collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name) + collection_exist_cache_key = f"vector_indexing_{self._collection_name}" if redis_client.get(collection_exist_cache_key): return schema = self._default_schema(self._collection_name) diff --git a/api/core/rag/docstore/dataset_docstore.py b/api/core/rag/docstore/dataset_docstore.py index f844770a20..f8da3657fc 100644 --- a/api/core/rag/docstore/dataset_docstore.py +++ b/api/core/rag/docstore/dataset_docstore.py @@ -32,7 +32,7 @@ class DatasetDocumentStore: } @property - def dateset_id(self) -> Any: + def dataset_id(self) -> Any: return self._dataset.id @property @@ -123,13 +123,13 @@ class DatasetDocumentStore: db.session.flush() if save_child: if doc.children: - for postion, child in enumerate(doc.children, start=1): + for position, child in enumerate(doc.children, start=1): child_segment = ChildChunk( tenant_id=self._dataset.tenant_id, dataset_id=self._dataset.id, document_id=self._document_id, segment_id=segment_document.id, - position=postion, + position=position, index_node_id=child.metadata.get("doc_id"), index_node_hash=child.metadata.get("doc_hash"), content=child.page_content, diff --git a/api/core/rag/embedding/cached_embedding.py b/api/core/rag/embedding/cached_embedding.py index f50f9f6b60..9848a28384 100644 --- a/api/core/rag/embedding/cached_embedding.py +++ b/api/core/rag/embedding/cached_embedding.py @@ -69,7 +69,7 @@ class CacheEmbedding(Embeddings): # stackoverflow best way: https://stackoverflow.com/questions/20319813/how-to-check-list-containing-nan if np.isnan(normalized_embedding).any(): # for issue #11827 float values are not json compliant - logger.warning(f"Normalized embedding is nan: {normalized_embedding}") + logger.warning("Normalized embedding is nan: %s", normalized_embedding) continue embedding_queue_embeddings.append(normalized_embedding) except IntegrityError: @@ -122,7 +122,7 @@ class CacheEmbedding(Embeddings): raise ValueError("Normalized embedding is nan please try again") except Exception as ex: if dify_config.DEBUG: - logging.exception(f"Failed to embed query text '{text[:10]}...({len(text)} chars)'") + logging.exception("Failed to embed query text '%s...(%s chars)'", text[:10], len(text)) raise ex try: @@ -136,7 +136,9 @@ class CacheEmbedding(Embeddings): redis_client.setex(embedding_cache_key, 600, encoded_str) except Exception as ex: if dify_config.DEBUG: - logging.exception(f"Failed to add embedding to redis for the text '{text[:10]}...({len(text)} chars)'") + logging.exception( + "Failed to add embedding to redis for the text '%s...(%s chars)'", text[:10], len(text) + ) raise ex return embedding_results # type: ignore diff --git a/api/core/rag/entities/metadata_entities.py b/api/core/rag/entities/metadata_entities.py index 6ef932ad22..1f054bccdb 100644 --- a/api/core/rag/entities/metadata_entities.py +++ b/api/core/rag/entities/metadata_entities.py @@ -13,6 +13,8 @@ SupportedComparisonOperator = Literal[ "is not", "empty", "not empty", + "in", + "not in", # for number "=", "≠", diff --git a/api/core/rag/extractor/excel_extractor.py b/api/core/rag/extractor/excel_extractor.py index a3b35458df..7cc554c74d 100644 --- a/api/core/rag/extractor/excel_extractor.py +++ b/api/core/rag/extractor/excel_extractor.py @@ -34,9 +34,8 @@ class ExcelExtractor(BaseExtractor): for sheet_name in wb.sheetnames: sheet = wb[sheet_name] data = sheet.values - try: - cols = next(data) - except StopIteration: + cols = next(data, None) + if cols is None: continue df = pd.DataFrame(data, columns=cols) diff --git a/api/core/rag/extractor/notion_extractor.py b/api/core/rag/extractor/notion_extractor.py index 875626eb34..17f4d1af2d 100644 --- a/api/core/rag/extractor/notion_extractor.py +++ b/api/core/rag/extractor/notion_extractor.py @@ -1,5 +1,6 @@ import json import logging +import operator from typing import Any, Optional, cast import requests @@ -130,13 +131,15 @@ class NotionExtractor(BaseExtractor): data[property_name] = value row_dict = {k: v for k, v in data.items() if v} row_content = "" - for key, value in row_dict.items(): + for key, value in sorted(row_dict.items(), key=operator.itemgetter(0)): if isinstance(value, dict): value_dict = {k: v for k, v in value.items() if v} value_content = "".join(f"{k}:{v} " for k, v in value_dict.items()) row_content = row_content + f"{key}:{value_content}\n" else: row_content = row_content + f"{key}:{value}\n" + if "url" in result: + row_content = row_content + f"Row Page URL:{result.get('url', '')}\n" database_content.append(row_content) has_more = response_data.get("has_more", False) diff --git a/api/core/rag/extractor/pdf_extractor.py b/api/core/rag/extractor/pdf_extractor.py index 04033dec3f..7dfe2e357c 100644 --- a/api/core/rag/extractor/pdf_extractor.py +++ b/api/core/rag/extractor/pdf_extractor.py @@ -1,5 +1,6 @@ """Abstract interface for document loader implementations.""" +import contextlib from collections.abc import Iterator from typing import Optional, cast @@ -25,12 +26,10 @@ class PdfExtractor(BaseExtractor): def extract(self) -> list[Document]: plaintext_file_exists = False if self._file_cache_key: - try: + with contextlib.suppress(FileNotFoundError): text = cast(bytes, storage.load(self._file_cache_key)).decode("utf-8") plaintext_file_exists = True return [Document(page_content=text)] - except FileNotFoundError: - pass documents = list(self.load()) text_list = [] for document in documents: diff --git a/api/core/rag/extractor/unstructured/unstructured_eml_extractor.py b/api/core/rag/extractor/unstructured/unstructured_eml_extractor.py index f1fa5dde5c..856a9bce18 100644 --- a/api/core/rag/extractor/unstructured/unstructured_eml_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_eml_extractor.py @@ -1,4 +1,5 @@ import base64 +import contextlib import logging from typing import Optional @@ -33,7 +34,7 @@ class UnstructuredEmailExtractor(BaseExtractor): elements = partition_email(filename=self._file_path) # noinspection PyBroadException - try: + with contextlib.suppress(Exception): for element in elements: element_text = element.text.strip() @@ -43,8 +44,6 @@ class UnstructuredEmailExtractor(BaseExtractor): element_decode = base64.b64decode(element_text) soup = BeautifulSoup(element_decode.decode("utf-8"), "html.parser") element.text = soup.get_text() - except Exception: - pass from unstructured.chunking.title import chunk_by_title diff --git a/api/core/rag/extractor/watercrawl/provider.py b/api/core/rag/extractor/watercrawl/provider.py index 21fbb2100f..da03fc67a6 100644 --- a/api/core/rag/extractor/watercrawl/provider.py +++ b/api/core/rag/extractor/watercrawl/provider.py @@ -1,6 +1,6 @@ from collections.abc import Generator from datetime import datetime -from typing import Any +from typing import Any, Optional from core.rag.extractor.watercrawl.client import WaterCrawlAPIClient @@ -9,7 +9,7 @@ class WaterCrawlProvider: def __init__(self, api_key, base_url: str | None = None): self.client = WaterCrawlAPIClient(api_key, base_url) - def crawl_url(self, url, options: dict | Any = None) -> dict: + def crawl_url(self, url, options: Optional[dict | Any] = None) -> dict: options = options or {} spider_options = { "max_depth": 1, diff --git a/api/core/rag/extractor/word_extractor.py b/api/core/rag/extractor/word_extractor.py index 14363de7d4..f3b162e3d3 100644 --- a/api/core/rag/extractor/word_extractor.py +++ b/api/core/rag/extractor/word_extractor.py @@ -1,6 +1,5 @@ """Abstract interface for document loader implementations.""" -import datetime import logging import mimetypes import os @@ -19,6 +18,7 @@ from core.rag.extractor.extractor_base import BaseExtractor from core.rag.models.document import Document from extensions.ext_database import db from extensions.ext_storage import storage +from libs.datetime_utils import naive_utc_now from models.enums import CreatorUserRole from models.model import UploadFile @@ -62,7 +62,7 @@ class WordExtractor(BaseExtractor): def extract(self) -> list[Document]: """Load given path as single page.""" - content = self.parse_docx(self.file_path, "storage") + content = self.parse_docx(self.file_path) return [ Document( page_content=content, @@ -117,10 +117,10 @@ class WordExtractor(BaseExtractor): mime_type=mime_type or "", created_by=self.user_id, created_by_role=CreatorUserRole.ACCOUNT, - created_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None), + created_at=naive_utc_now(), used=True, used_by=self.user_id, - used_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None), + used_at=naive_utc_now(), ) db.session.add(upload_file) @@ -189,23 +189,8 @@ class WordExtractor(BaseExtractor): paragraph_content.append(run.text) return "".join(paragraph_content).strip() - def _parse_paragraph(self, paragraph, image_map): - paragraph_content = [] - for run in paragraph.runs: - if run.element.xpath(".//a:blip"): - for blip in run.element.xpath(".//a:blip"): - embed_id = blip.get("{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed") - if embed_id: - rel_target = run.part.rels[embed_id].target_ref - if rel_target in image_map: - paragraph_content.append(image_map[rel_target]) - if run.text.strip(): - paragraph_content.append(run.text.strip()) - return " ".join(paragraph_content) if paragraph_content else "" - - def parse_docx(self, docx_path, image_folder): + def parse_docx(self, docx_path): doc = DocxDocument(docx_path) - os.makedirs(image_folder, exist_ok=True) content = [] diff --git a/api/core/rag/models/document.py b/api/core/rag/models/document.py index 04a3428ad8..ff63a6780e 100644 --- a/api/core/rag/models/document.py +++ b/api/core/rag/models/document.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod from collections.abc import Sequence from typing import Any, Optional -from pydantic import BaseModel +from pydantic import BaseModel, Field class ChildDocument(BaseModel): @@ -15,7 +15,7 @@ class ChildDocument(BaseModel): """Arbitrary metadata about the page content (e.g., source, relationships to other documents, etc.). """ - metadata: dict = {} + metadata: dict = Field(default_factory=dict) class Document(BaseModel): @@ -28,7 +28,7 @@ class Document(BaseModel): """Arbitrary metadata about the page content (e.g., source, relationships to other documents, etc.). """ - metadata: dict = {} + metadata: dict = Field(default_factory=dict) provider: Optional[str] = "dify" diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index a25bc65646..cd4af72832 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -1012,7 +1012,7 @@ class DatasetRetrieval: def _process_metadata_filter_func( self, sequence: int, condition: str, metadata_name: str, value: Optional[Any], filters: list ): - if value is None: + if value is None and condition not in ("empty", "not empty"): return key = f"{metadata_name}_{sequence}" diff --git a/api/core/rag/splitter/fixed_text_splitter.py b/api/core/rag/splitter/fixed_text_splitter.py index bcaf299892..d654463be9 100644 --- a/api/core/rag/splitter/fixed_text_splitter.py +++ b/api/core/rag/splitter/fixed_text_splitter.py @@ -5,14 +5,13 @@ from __future__ import annotations from typing import Any, Optional from core.model_manager import ModelInstance -from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer +from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenizer import GPT2Tokenizer from core.rag.splitter.text_splitter import ( TS, Collection, Literal, RecursiveCharacterTextSplitter, Set, - TokenTextSplitter, Union, ) @@ -45,14 +44,6 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter): return [len(text) for text in texts] - if issubclass(cls, TokenTextSplitter): - extra_kwargs = { - "model_name": embedding_model_instance.model if embedding_model_instance else "gpt2", - "allowed_special": allowed_special, - "disallowed_special": disallowed_special, - } - kwargs = {**kwargs, **extra_kwargs} - return cls(length_function=_character_encoder, **kwargs) diff --git a/api/core/rag/splitter/text_splitter.py b/api/core/rag/splitter/text_splitter.py index 529d8ccd27..489aa05430 100644 --- a/api/core/rag/splitter/text_splitter.py +++ b/api/core/rag/splitter/text_splitter.py @@ -116,7 +116,7 @@ class TextSplitter(BaseDocumentTransformer, ABC): if total + _len + (separator_len if len(current_doc) > 0 else 0) > self._chunk_size: if total > self._chunk_size: logger.warning( - f"Created a chunk of size {total}, which is longer than the specified {self._chunk_size}" + "Created a chunk of size %s, which is longer than the specified %s", total, self._chunk_size ) if len(current_doc) > 0: doc = self._join_docs(current_doc, separator) diff --git a/api/core/repositories/__init__.py b/api/core/repositories/__init__.py index 052ba1c2cb..d83823d7b9 100644 --- a/api/core/repositories/__init__.py +++ b/api/core/repositories/__init__.py @@ -5,10 +5,14 @@ This package contains concrete implementations of the repository interfaces defined in the core.workflow.repository package. """ +from core.repositories.celery_workflow_execution_repository import CeleryWorkflowExecutionRepository +from core.repositories.celery_workflow_node_execution_repository import CeleryWorkflowNodeExecutionRepository from core.repositories.factory import DifyCoreRepositoryFactory, RepositoryImportError from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository __all__ = [ + "CeleryWorkflowExecutionRepository", + "CeleryWorkflowNodeExecutionRepository", "DifyCoreRepositoryFactory", "RepositoryImportError", "SQLAlchemyWorkflowNodeExecutionRepository", diff --git a/api/core/repositories/celery_workflow_execution_repository.py b/api/core/repositories/celery_workflow_execution_repository.py new file mode 100644 index 0000000000..df1f8db67f --- /dev/null +++ b/api/core/repositories/celery_workflow_execution_repository.py @@ -0,0 +1,126 @@ +""" +Celery-based implementation of the WorkflowExecutionRepository. + +This implementation uses Celery tasks for asynchronous storage operations, +providing improved performance by offloading database operations to background workers. +""" + +import logging +from typing import Optional, Union + +from sqlalchemy.engine import Engine +from sqlalchemy.orm import sessionmaker + +from core.workflow.entities.workflow_execution import WorkflowExecution +from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository +from libs.helper import extract_tenant_id +from models import Account, CreatorUserRole, EndUser +from models.enums import WorkflowRunTriggeredFrom +from tasks.workflow_execution_tasks import ( + save_workflow_execution_task, +) + +logger = logging.getLogger(__name__) + + +class CeleryWorkflowExecutionRepository(WorkflowExecutionRepository): + """ + Celery-based implementation of the WorkflowExecutionRepository interface. + + This implementation provides asynchronous storage capabilities by using Celery tasks + to handle database operations in background workers. This improves performance by + reducing the blocking time for workflow execution storage operations. + + Key features: + - Asynchronous save operations using Celery tasks + - Support for multi-tenancy through tenant/app filtering + - Automatic retry and error handling through Celery + """ + + _session_factory: sessionmaker + _tenant_id: str + _app_id: Optional[str] + _triggered_from: Optional[WorkflowRunTriggeredFrom] + _creator_user_id: str + _creator_user_role: CreatorUserRole + + def __init__( + self, + session_factory: sessionmaker | Engine, + user: Union[Account, EndUser], + app_id: Optional[str], + triggered_from: Optional[WorkflowRunTriggeredFrom], + ): + """ + Initialize the repository with Celery task configuration and context information. + + Args: + session_factory: SQLAlchemy sessionmaker or engine for fallback operations + user: Account or EndUser object containing tenant_id, user ID, and role information + app_id: App ID for filtering by application (can be None) + triggered_from: Source of the execution trigger (DEBUGGING or APP_RUN) + """ + # Store session factory for fallback operations + if isinstance(session_factory, Engine): + self._session_factory = sessionmaker(bind=session_factory, expire_on_commit=False) + elif isinstance(session_factory, sessionmaker): + self._session_factory = session_factory + else: + raise ValueError( + f"Invalid session_factory type {type(session_factory).__name__}; expected sessionmaker or Engine" + ) + + # Extract tenant_id from user + tenant_id = extract_tenant_id(user) + if not tenant_id: + raise ValueError("User must have a tenant_id or current_tenant_id") + self._tenant_id = tenant_id # type: ignore[assignment] # We've already checked tenant_id is not None + + # Store app context + self._app_id = app_id + + # Extract user context + self._triggered_from = triggered_from + self._creator_user_id = user.id + + # Determine user role based on user type + self._creator_user_role = CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER + + logger.info( + "Initialized CeleryWorkflowExecutionRepository for tenant %s, app %s, triggered_from %s", + self._tenant_id, + self._app_id, + self._triggered_from, + ) + + def save(self, execution: WorkflowExecution) -> None: + """ + Save or update a WorkflowExecution instance asynchronously using Celery. + + This method queues the save operation as a Celery task and returns immediately, + providing improved performance for high-throughput scenarios. + + Args: + execution: The WorkflowExecution instance to save or update + """ + try: + # Serialize execution for Celery task + execution_data = execution.model_dump() + + # Queue the save operation as a Celery task (fire and forget) + save_workflow_execution_task.delay( + execution_data=execution_data, + tenant_id=self._tenant_id, + app_id=self._app_id or "", + triggered_from=self._triggered_from.value if self._triggered_from else "", + creator_user_id=self._creator_user_id, + creator_user_role=self._creator_user_role.value, + ) + + logger.debug("Queued async save for workflow execution: %s", execution.id_) + + except Exception as e: + logger.exception("Failed to queue save operation for execution %s", execution.id_) + # In case of Celery failure, we could implement a fallback to synchronous save + # For now, we'll re-raise the exception + raise diff --git a/api/core/repositories/celery_workflow_node_execution_repository.py b/api/core/repositories/celery_workflow_node_execution_repository.py new file mode 100644 index 0000000000..5b410a7b56 --- /dev/null +++ b/api/core/repositories/celery_workflow_node_execution_repository.py @@ -0,0 +1,190 @@ +""" +Celery-based implementation of the WorkflowNodeExecutionRepository. + +This implementation uses Celery tasks for asynchronous storage operations, +providing improved performance by offloading database operations to background workers. +""" + +import logging +from collections.abc import Sequence +from typing import Optional, Union + +from sqlalchemy.engine import Engine +from sqlalchemy.orm import sessionmaker + +from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution +from core.workflow.repositories.workflow_node_execution_repository import ( + OrderConfig, + WorkflowNodeExecutionRepository, +) +from libs.helper import extract_tenant_id +from models import Account, CreatorUserRole, EndUser +from models.workflow import WorkflowNodeExecutionTriggeredFrom +from tasks.workflow_node_execution_tasks import ( + save_workflow_node_execution_task, +) + +logger = logging.getLogger(__name__) + + +class CeleryWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository): + """ + Celery-based implementation of the WorkflowNodeExecutionRepository interface. + + This implementation provides asynchronous storage capabilities by using Celery tasks + to handle database operations in background workers. This improves performance by + reducing the blocking time for workflow node execution storage operations. + + Key features: + - Asynchronous save operations using Celery tasks + - In-memory cache for immediate reads + - Support for multi-tenancy through tenant/app filtering + - Automatic retry and error handling through Celery + """ + + _session_factory: sessionmaker + _tenant_id: str + _app_id: Optional[str] + _triggered_from: Optional[WorkflowNodeExecutionTriggeredFrom] + _creator_user_id: str + _creator_user_role: CreatorUserRole + _execution_cache: dict[str, WorkflowNodeExecution] + _workflow_execution_mapping: dict[str, list[str]] + + def __init__( + self, + session_factory: sessionmaker | Engine, + user: Union[Account, EndUser], + app_id: Optional[str], + triggered_from: Optional[WorkflowNodeExecutionTriggeredFrom], + ): + """ + Initialize the repository with Celery task configuration and context information. + + Args: + session_factory: SQLAlchemy sessionmaker or engine for fallback operations + user: Account or EndUser object containing tenant_id, user ID, and role information + app_id: App ID for filtering by application (can be None) + triggered_from: Source of the execution trigger (SINGLE_STEP or WORKFLOW_RUN) + """ + # Store session factory for fallback operations + if isinstance(session_factory, Engine): + self._session_factory = sessionmaker(bind=session_factory, expire_on_commit=False) + elif isinstance(session_factory, sessionmaker): + self._session_factory = session_factory + else: + raise ValueError( + f"Invalid session_factory type {type(session_factory).__name__}; expected sessionmaker or Engine" + ) + + # Extract tenant_id from user + tenant_id = extract_tenant_id(user) + if not tenant_id: + raise ValueError("User must have a tenant_id or current_tenant_id") + self._tenant_id = tenant_id # type: ignore[assignment] # We've already checked tenant_id is not None + + # Store app context + self._app_id = app_id + + # Extract user context + self._triggered_from = triggered_from + self._creator_user_id = user.id + + # Determine user role based on user type + self._creator_user_role = CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER + + # In-memory cache for workflow node executions + self._execution_cache: dict[str, WorkflowNodeExecution] = {} + + # Cache for mapping workflow_execution_ids to execution IDs for efficient retrieval + self._workflow_execution_mapping: dict[str, list[str]] = {} + + logger.info( + "Initialized CeleryWorkflowNodeExecutionRepository for tenant %s, app %s, triggered_from %s", + self._tenant_id, + self._app_id, + self._triggered_from, + ) + + def save(self, execution: WorkflowNodeExecution) -> None: + """ + Save or update a WorkflowNodeExecution instance to cache and asynchronously to database. + + This method stores the execution in cache immediately for fast reads and queues + the save operation as a Celery task without tracking the task status. + + Args: + execution: The WorkflowNodeExecution instance to save or update + """ + try: + # Store in cache immediately for fast reads + self._execution_cache[execution.id] = execution + + # Update workflow execution mapping for efficient retrieval + if execution.workflow_execution_id: + if execution.workflow_execution_id not in self._workflow_execution_mapping: + self._workflow_execution_mapping[execution.workflow_execution_id] = [] + if execution.id not in self._workflow_execution_mapping[execution.workflow_execution_id]: + self._workflow_execution_mapping[execution.workflow_execution_id].append(execution.id) + + # Serialize execution for Celery task + execution_data = execution.model_dump() + + # Queue the save operation as a Celery task (fire and forget) + save_workflow_node_execution_task.delay( + execution_data=execution_data, + tenant_id=self._tenant_id, + app_id=self._app_id or "", + triggered_from=self._triggered_from.value if self._triggered_from else "", + creator_user_id=self._creator_user_id, + creator_user_role=self._creator_user_role.value, + ) + + logger.debug("Cached and queued async save for workflow node execution: %s", execution.id) + + except Exception as e: + logger.exception("Failed to cache or queue save operation for node execution %s", execution.id) + # In case of Celery failure, we could implement a fallback to synchronous save + # For now, we'll re-raise the exception + raise + + def get_by_workflow_run( + self, + workflow_run_id: str, + order_config: Optional[OrderConfig] = None, + ) -> Sequence[WorkflowNodeExecution]: + """ + Retrieve all WorkflowNodeExecution instances for a specific workflow run from cache. + + Args: + workflow_run_id: The workflow run ID + order_config: Optional configuration for ordering results + + Returns: + A sequence of WorkflowNodeExecution instances + """ + try: + # Get execution IDs for this workflow run from cache + execution_ids = self._workflow_execution_mapping.get(workflow_run_id, []) + + # Retrieve executions from cache + result = [] + for execution_id in execution_ids: + if execution_id in self._execution_cache: + result.append(self._execution_cache[execution_id]) + + # Apply ordering if specified + if order_config and result: + # Sort based on the order configuration + reverse = order_config.order_direction == "desc" + + # Sort by multiple fields if specified + for field_name in reversed(order_config.order_by): + result.sort(key=lambda x: getattr(x, field_name, 0), reverse=reverse) + + logger.debug("Retrieved %d workflow node executions for run %s from cache", len(result), workflow_run_id) + return result + + except Exception as e: + logger.exception("Failed to get workflow node executions for run %s from cache", workflow_run_id) + return [] diff --git a/api/core/repositories/factory.py b/api/core/repositories/factory.py index 4118aa61c7..854c122331 100644 --- a/api/core/repositories/factory.py +++ b/api/core/repositories/factory.py @@ -5,10 +5,7 @@ This module provides a Django-like settings system for repository implementation allowing users to configure different repository backends through string paths. """ -import importlib -import inspect -import logging -from typing import Protocol, Union +from typing import Union from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker @@ -16,12 +13,11 @@ from sqlalchemy.orm import sessionmaker from configs import dify_config from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from libs.module_loading import import_string from models import Account, EndUser from models.enums import WorkflowRunTriggeredFrom from models.workflow import WorkflowNodeExecutionTriggeredFrom -logger = logging.getLogger(__name__) - class RepositoryImportError(Exception): """Raised when a repository implementation cannot be imported or instantiated.""" @@ -37,98 +33,6 @@ class DifyCoreRepositoryFactory: are specified as module paths (e.g., 'module.submodule.ClassName'). """ - @staticmethod - def _import_class(class_path: str) -> type: - """ - Import a class from a module path string. - - Args: - class_path: Full module path to the class (e.g., 'module.submodule.ClassName') - - Returns: - The imported class - - Raises: - RepositoryImportError: If the class cannot be imported - """ - try: - module_path, class_name = class_path.rsplit(".", 1) - module = importlib.import_module(module_path) - repo_class = getattr(module, class_name) - assert isinstance(repo_class, type) - return repo_class - except (ValueError, ImportError, AttributeError) as e: - raise RepositoryImportError(f"Cannot import repository class '{class_path}': {e}") from e - - @staticmethod - def _validate_repository_interface(repository_class: type, expected_interface: type[Protocol]) -> None: # type: ignore - """ - Validate that a class implements the expected repository interface. - - Args: - repository_class: The class to validate - expected_interface: The expected interface/protocol - - Raises: - RepositoryImportError: If the class doesn't implement the interface - """ - # Check if the class has all required methods from the protocol - required_methods = [ - method - for method in dir(expected_interface) - if not method.startswith("_") and callable(getattr(expected_interface, method, None)) - ] - - missing_methods = [] - for method_name in required_methods: - if not hasattr(repository_class, method_name): - missing_methods.append(method_name) - - if missing_methods: - raise RepositoryImportError( - f"Repository class '{repository_class.__name__}' does not implement required methods " - f"{missing_methods} from interface '{expected_interface.__name__}'" - ) - - @staticmethod - def _validate_constructor_signature(repository_class: type, required_params: list[str]) -> None: - """ - Validate that a repository class constructor accepts required parameters. - - Args: - repository_class: The class to validate - required_params: List of required parameter names - - Raises: - RepositoryImportError: If the constructor doesn't accept required parameters - """ - - try: - # MyPy may flag the line below with the following error: - # - # > Accessing "__init__" on an instance is unsound, since - # > instance.__init__ could be from an incompatible subclass. - # - # Despite this, we need to ensure that the constructor of `repository_class` - # has a compatible signature. - signature = inspect.signature(repository_class.__init__) # type: ignore[misc] - param_names = list(signature.parameters.keys()) - - # Remove 'self' parameter - if "self" in param_names: - param_names.remove("self") - - missing_params = [param for param in required_params if param not in param_names] - if missing_params: - raise RepositoryImportError( - f"Repository class '{repository_class.__name__}' constructor does not accept required parameters: " - f"{missing_params}. Expected parameters: {required_params}" - ) - except Exception as e: - raise RepositoryImportError( - f"Failed to validate constructor signature for '{repository_class.__name__}': {e}" - ) from e - @classmethod def create_workflow_execution_repository( cls, @@ -153,26 +57,16 @@ class DifyCoreRepositoryFactory: RepositoryImportError: If the configured repository cannot be created """ class_path = dify_config.CORE_WORKFLOW_EXECUTION_REPOSITORY - logger.debug(f"Creating WorkflowExecutionRepository from: {class_path}") try: - repository_class = cls._import_class(class_path) - cls._validate_repository_interface(repository_class, WorkflowExecutionRepository) - cls._validate_constructor_signature( - repository_class, ["session_factory", "user", "app_id", "triggered_from"] - ) - + repository_class = import_string(class_path) return repository_class( # type: ignore[no-any-return] session_factory=session_factory, user=user, app_id=app_id, triggered_from=triggered_from, ) - except RepositoryImportError: - # Re-raise our custom errors as-is - raise - except Exception as e: - logger.exception("Failed to create WorkflowExecutionRepository") + except (ImportError, Exception) as e: raise RepositoryImportError(f"Failed to create WorkflowExecutionRepository from '{class_path}': {e}") from e @classmethod @@ -199,26 +93,16 @@ class DifyCoreRepositoryFactory: RepositoryImportError: If the configured repository cannot be created """ class_path = dify_config.CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY - logger.debug(f"Creating WorkflowNodeExecutionRepository from: {class_path}") try: - repository_class = cls._import_class(class_path) - cls._validate_repository_interface(repository_class, WorkflowNodeExecutionRepository) - cls._validate_constructor_signature( - repository_class, ["session_factory", "user", "app_id", "triggered_from"] - ) - + repository_class = import_string(class_path) return repository_class( # type: ignore[no-any-return] session_factory=session_factory, user=user, app_id=app_id, triggered_from=triggered_from, ) - except RepositoryImportError: - # Re-raise our custom errors as-is - raise - except Exception as e: - logger.exception("Failed to create WorkflowNodeExecutionRepository") + except (ImportError, Exception) as e: raise RepositoryImportError( f"Failed to create WorkflowNodeExecutionRepository from '{class_path}': {e}" ) from e diff --git a/api/core/repositories/sqlalchemy_workflow_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_execution_repository.py index c579ff4028..74a49842f3 100644 --- a/api/core/repositories/sqlalchemy_workflow_execution_repository.py +++ b/api/core/repositories/sqlalchemy_workflow_execution_repository.py @@ -203,5 +203,5 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository): session.commit() # Update the in-memory cache for faster subsequent lookups - logger.debug(f"Updating cache for execution_id: {db_model.id}") + logger.debug("Updating cache for execution_id: %s", db_model.id) self._execution_cache[db_model.id] = db_model diff --git a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py index d4a31390f8..f4532d7f29 100644 --- a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py +++ b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py @@ -215,7 +215,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) # Update the in-memory cache for faster subsequent lookups # Only cache if we have a node_execution_id to use as the cache key if db_model.node_execution_id: - logger.debug(f"Updating cache for node_execution_id: {db_model.node_execution_id}") + logger.debug("Updating cache for node_execution_id: %s", db_model.node_execution_id) self._node_execution_cache[db_model.node_execution_id] = db_model def get_db_models_by_workflow_run( diff --git a/api/core/tools/__base/tool.py b/api/core/tools/__base/tool.py index 35e16b5c8f..d6961cdaa4 100644 --- a/api/core/tools/__base/tool.py +++ b/api/core/tools/__base/tool.py @@ -20,9 +20,6 @@ class Tool(ABC): The base class of a tool """ - entity: ToolEntity - runtime: ToolRuntime - def __init__(self, entity: ToolEntity, runtime: ToolRuntime) -> None: self.entity = entity self.runtime = runtime diff --git a/api/core/tools/__base/tool_provider.py b/api/core/tools/__base/tool_provider.py index d096fc7df7..d1d7976cc3 100644 --- a/api/core/tools/__base/tool_provider.py +++ b/api/core/tools/__base/tool_provider.py @@ -12,8 +12,6 @@ from core.tools.errors import ToolProviderCredentialValidationError class ToolProviderController(ABC): - entity: ToolProviderEntity - def __init__(self, entity: ToolProviderEntity) -> None: self.entity = entity diff --git a/api/core/tools/builtin_tool/providers/time/tools/localtime_to_timestamp.py b/api/core/tools/builtin_tool/providers/time/tools/localtime_to_timestamp.py index 1639dd687f..a8fd6ec2cd 100644 --- a/api/core/tools/builtin_tool/providers/time/tools/localtime_to_timestamp.py +++ b/api/core/tools/builtin_tool/providers/time/tools/localtime_to_timestamp.py @@ -37,12 +37,12 @@ class LocaltimeToTimestampTool(BuiltinTool): @staticmethod def localtime_to_timestamp(localtime: str, time_format: str, local_tz=None) -> int | None: try: - if local_tz is None: - local_tz = datetime.now().astimezone().tzinfo - if isinstance(local_tz, str): - local_tz = pytz.timezone(local_tz) local_time = datetime.strptime(localtime, time_format) - localtime = local_tz.localize(local_time) # type: ignore + if local_tz is None: + localtime = local_time.astimezone() # type: ignore + elif isinstance(local_tz, str): + local_tz = pytz.timezone(local_tz) + localtime = local_tz.localize(local_time) # type: ignore timestamp = int(localtime.timestamp()) # type: ignore return timestamp except Exception as e: diff --git a/api/core/tools/builtin_tool/providers/time/tools/timezone_conversion.py b/api/core/tools/builtin_tool/providers/time/tools/timezone_conversion.py index f9b776b3b9..91316b859a 100644 --- a/api/core/tools/builtin_tool/providers/time/tools/timezone_conversion.py +++ b/api/core/tools/builtin_tool/providers/time/tools/timezone_conversion.py @@ -27,7 +27,7 @@ class TimezoneConversionTool(BuiltinTool): target_time = self.timezone_convert(current_time, current_timezone, target_timezone) # type: ignore if not target_time: yield self.create_text_message( - f"Invalid datatime and timezone: {current_time},{current_timezone},{target_timezone}" + f"Invalid datetime and timezone: {current_time},{current_timezone},{target_timezone}" ) return diff --git a/api/core/tools/builtin_tool/tool.py b/api/core/tools/builtin_tool/tool.py index 724a2291c6..84efefba07 100644 --- a/api/core/tools/builtin_tool/tool.py +++ b/api/core/tools/builtin_tool/tool.py @@ -20,8 +20,6 @@ class BuiltinTool(Tool): :param meta: the meta data of a tool call processing """ - provider: str - def __init__(self, provider: str, **kwargs): super().__init__(**kwargs) self.provider = provider diff --git a/api/core/tools/custom_tool/tool.py b/api/core/tools/custom_tool/tool.py index 10653b9948..3c0bfa5240 100644 --- a/api/core/tools/custom_tool/tool.py +++ b/api/core/tools/custom_tool/tool.py @@ -1,7 +1,8 @@ import json from collections.abc import Generator +from dataclasses import dataclass from os import getenv -from typing import Any, Optional +from typing import Any, Optional, Union from urllib.parse import urlencode import httpx @@ -20,10 +21,21 @@ API_TOOL_DEFAULT_TIMEOUT = ( ) -class ApiTool(Tool): - api_bundle: ApiToolBundle - provider_id: str +@dataclass +class ParsedResponse: + """Represents a parsed HTTP response with type information""" + content: Union[str, dict] + is_json: bool + + def to_string(self) -> str: + """Convert response to string format for credential validation""" + if isinstance(self.content, dict): + return json.dumps(self.content, ensure_ascii=False) + return str(self.content) + + +class ApiTool(Tool): """ Api tool """ @@ -61,20 +73,19 @@ class ApiTool(Tool): response = self.do_http_request(self.api_bundle.server_url, self.api_bundle.method, headers, parameters) # validate response - return self.validate_and_parse_response(response) + parsed_response = self.validate_and_parse_response(response) + # For credential validation, always return as string + return parsed_response.to_string() def tool_provider_type(self) -> ToolProviderType: return ToolProviderType.API def assembling_request(self, parameters: dict[str, Any]) -> dict[str, Any]: + headers = {} if self.runtime is None: raise ToolProviderCredentialValidationError("runtime not initialized") - headers = {} - if self.runtime is None: - raise ValueError("runtime is required") credentials = self.runtime.credentials or {} - if "auth_type" not in credentials: raise ToolProviderCredentialValidationError("Missing auth_type") @@ -115,23 +126,36 @@ class ApiTool(Tool): return headers - def validate_and_parse_response(self, response: httpx.Response) -> str: + def validate_and_parse_response(self, response: httpx.Response) -> ParsedResponse: """ - validate the response + validate the response and return parsed content with type information + + :return: ParsedResponse with content and is_json flag """ if isinstance(response, httpx.Response): if response.status_code >= 400: raise ToolInvokeError(f"Request failed with status code {response.status_code} and {response.text}") if not response.content: - return "Empty response from the tool, please check your parameters and try again." + return ParsedResponse( + "Empty response from the tool, please check your parameters and try again.", False + ) + + # Check content type + content_type = response.headers.get("content-type", "").lower() + is_json_content_type = "application/json" in content_type + + # Try to parse as JSON try: - response = response.json() - try: - return json.dumps(response, ensure_ascii=False) - except Exception: - return json.dumps(response) + json_response = response.json() + # If content-type indicates JSON, return as JSON object + if is_json_content_type: + return ParsedResponse(json_response, True) + else: + # If content-type doesn't indicate JSON, treat as text regardless of content + return ParsedResponse(response.text, False) except Exception: - return response.text + # Not valid JSON, return as text + return ParsedResponse(response.text, False) else: raise ValueError(f"Invalid response type {type(response)}") @@ -372,7 +396,14 @@ class ApiTool(Tool): response = self.do_http_request(self.api_bundle.server_url, self.api_bundle.method, headers, tool_parameters) # validate response - response = self.validate_and_parse_response(response) + parsed_response = self.validate_and_parse_response(response) - # assemble invoke message - yield self.create_text_message(response) + # assemble invoke message based on response type + if parsed_response.is_json and isinstance(parsed_response.content, dict): + yield self.create_json_message(parsed_response.content) + else: + # Convert to string if needed and create text message + text_response = ( + parsed_response.content if isinstance(parsed_response.content, str) else str(parsed_response.content) + ) + yield self.create_text_message(text_response) diff --git a/api/core/tools/entities/api_entities.py b/api/core/tools/entities/api_entities.py index 27ce96b90e..48015c04ee 100644 --- a/api/core/tools/entities/api_entities.py +++ b/api/core/tools/entities/api_entities.py @@ -62,7 +62,7 @@ class ToolProviderApiEntity(BaseModel): parameter.pop("input_schema", None) # ------------- optional_fields = self.optional_field("server_url", self.server_url) - if self.type == ToolProviderType.MCP.value: + if self.type == ToolProviderType.MCP: optional_fields.update(self.optional_field("updated_at", self.updated_at)) optional_fields.update(self.optional_field("server_identifier", self.server_identifier)) return { diff --git a/api/core/tools/entities/file_entities.py b/api/core/tools/entities/file_entities.py deleted file mode 100644 index 8b13789179..0000000000 --- a/api/core/tools/entities/file_entities.py +++ /dev/null @@ -1 +0,0 @@ - diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py index 5377cbbb69..df599a09a3 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -1,4 +1,5 @@ import base64 +import contextlib import enum from collections.abc import Mapping from enum import Enum @@ -108,10 +109,18 @@ class ApiProviderAuthType(Enum): :param value: mode value :return: mode """ + # 'api_key' deprecated in PR #21656 + # normalize & tiny alias for backward compatibility + v = (value or "").strip().lower() + if v == "api_key": + v = cls.API_KEY_HEADER.value + for mode in cls: - if mode.value == value: + if mode.value == v: return mode - raise ValueError(f"invalid mode value {value}") + + valid = ", ".join(m.value for m in cls) + raise ValueError(f"invalid mode value '{value}', expected one of: {valid}") class ToolInvokeMessage(BaseModel): @@ -219,10 +228,8 @@ class ToolInvokeMessage(BaseModel): @classmethod def decode_blob_message(cls, v): if isinstance(v, dict) and "blob" in v: - try: + with contextlib.suppress(Exception): v["blob"] = base64.b64decode(v["blob"]) - except Exception: - pass return v @field_serializer("message") diff --git a/api/core/tools/mcp_tool/provider.py b/api/core/tools/mcp_tool/provider.py index 93f003effe..24ee981a1b 100644 --- a/api/core/tools/mcp_tool/provider.py +++ b/api/core/tools/mcp_tool/provider.py @@ -1,5 +1,5 @@ import json -from typing import Any +from typing import Any, Optional from core.mcp.types import Tool as RemoteMCPTool from core.tools.__base.tool_provider import ToolProviderController @@ -19,15 +19,24 @@ from services.tools.tools_transform_service import ToolTransformService class MCPToolProviderController(ToolProviderController): - provider_id: str - entity: ToolProviderEntityWithPlugin - - def __init__(self, entity: ToolProviderEntityWithPlugin, provider_id: str, tenant_id: str, server_url: str) -> None: + def __init__( + self, + entity: ToolProviderEntityWithPlugin, + provider_id: str, + tenant_id: str, + server_url: str, + headers: Optional[dict[str, str]] = None, + timeout: Optional[float] = None, + sse_read_timeout: Optional[float] = None, + ) -> None: super().__init__(entity) - self.entity = entity + self.entity: ToolProviderEntityWithPlugin = entity self.tenant_id = tenant_id self.provider_id = provider_id self.server_url = server_url + self.headers = headers or {} + self.timeout = timeout + self.sse_read_timeout = sse_read_timeout @property def provider_type(self) -> ToolProviderType: @@ -85,6 +94,9 @@ class MCPToolProviderController(ToolProviderController): provider_id=db_provider.server_identifier or "", tenant_id=db_provider.tenant_id or "", server_url=db_provider.decrypted_server_url, + headers={}, # TODO: get headers from db provider + timeout=db_provider.timeout, + sse_read_timeout=db_provider.sse_read_timeout, ) def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None: @@ -111,6 +123,9 @@ class MCPToolProviderController(ToolProviderController): icon=self.entity.identity.icon, server_url=self.server_url, provider_id=self.provider_id, + headers=self.headers, + timeout=self.timeout, + sse_read_timeout=self.sse_read_timeout, ) def get_tools(self) -> list[MCPTool]: # type: ignore @@ -125,6 +140,9 @@ class MCPToolProviderController(ToolProviderController): icon=self.entity.identity.icon, server_url=self.server_url, provider_id=self.provider_id, + headers=self.headers, + timeout=self.timeout, + sse_read_timeout=self.sse_read_timeout, ) for tool_entity in self.entity.tools ] diff --git a/api/core/tools/mcp_tool/tool.py b/api/core/tools/mcp_tool/tool.py index d1bacbc735..26789b23ce 100644 --- a/api/core/tools/mcp_tool/tool.py +++ b/api/core/tools/mcp_tool/tool.py @@ -8,25 +8,30 @@ from core.mcp.mcp_client import MCPClient from core.mcp.types import ImageContent, TextContent from core.tools.__base.tool import Tool from core.tools.__base.tool_runtime import ToolRuntime -from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolParameter, ToolProviderType +from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType class MCPTool(Tool): - tenant_id: str - icon: str - runtime_parameters: Optional[list[ToolParameter]] - server_url: str - provider_id: str - def __init__( - self, entity: ToolEntity, runtime: ToolRuntime, tenant_id: str, icon: str, server_url: str, provider_id: str + self, + entity: ToolEntity, + runtime: ToolRuntime, + tenant_id: str, + icon: str, + server_url: str, + provider_id: str, + headers: Optional[dict[str, str]] = None, + timeout: Optional[float] = None, + sse_read_timeout: Optional[float] = None, ) -> None: super().__init__(entity, runtime) self.tenant_id = tenant_id self.icon = icon - self.runtime_parameters = None self.server_url = server_url self.provider_id = provider_id + self.headers = headers or {} + self.timeout = timeout + self.sse_read_timeout = sse_read_timeout def tool_provider_type(self) -> ToolProviderType: return ToolProviderType.MCP @@ -42,7 +47,15 @@ class MCPTool(Tool): from core.tools.errors import ToolInvokeError try: - with MCPClient(self.server_url, self.provider_id, self.tenant_id, authed=True) as mcp_client: + with MCPClient( + self.server_url, + self.provider_id, + self.tenant_id, + authed=True, + headers=self.headers, + timeout=self.timeout, + sse_read_timeout=self.sse_read_timeout, + ) as mcp_client: tool_parameters = self._handle_none_parameter(tool_parameters) result = mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters) except MCPAuthError as e: @@ -79,6 +92,9 @@ class MCPTool(Tool): icon=self.icon, server_url=self.server_url, provider_id=self.provider_id, + headers=self.headers, + timeout=self.timeout, + sse_read_timeout=self.sse_read_timeout, ) def _handle_none_parameter(self, parameter: dict[str, Any]) -> dict[str, Any]: diff --git a/api/core/tools/plugin_tool/tool.py b/api/core/tools/plugin_tool/tool.py index aef2677c36..db38c10e81 100644 --- a/api/core/tools/plugin_tool/tool.py +++ b/api/core/tools/plugin_tool/tool.py @@ -9,11 +9,6 @@ from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, Too class PluginTool(Tool): - tenant_id: str - icon: str - plugin_unique_identifier: str - runtime_parameters: Optional[list[ToolParameter]] - def __init__( self, entity: ToolEntity, runtime: ToolRuntime, tenant_id: str, icon: str, plugin_unique_identifier: str ) -> None: @@ -21,7 +16,7 @@ class PluginTool(Tool): self.tenant_id = tenant_id self.icon = icon self.plugin_unique_identifier = plugin_unique_identifier - self.runtime_parameters = None + self.runtime_parameters: Optional[list[ToolParameter]] = None def tool_provider_type(self) -> ToolProviderType: return ToolProviderType.PLUGIN diff --git a/api/core/tools/tool_engine.py b/api/core/tools/tool_engine.py index 178f2b9689..10db4d9503 100644 --- a/api/core/tools/tool_engine.py +++ b/api/core/tools/tool_engine.py @@ -1,3 +1,4 @@ +import contextlib import json from collections.abc import Generator, Iterable from copy import deepcopy @@ -29,7 +30,7 @@ from core.tools.errors import ( ToolProviderCredentialValidationError, ToolProviderNotFoundError, ) -from core.tools.utils.message_transformer import ToolFileMessageTransformer +from core.tools.utils.message_transformer import ToolFileMessageTransformer, safe_json_value from core.tools.workflow_as_tool.tool import WorkflowTool from extensions.ext_database import db from models.enums import CreatorUserRole @@ -69,10 +70,8 @@ class ToolEngine: if parameters and len(parameters) == 1: tool_parameters = {parameters[0].name: tool_parameters} else: - try: + with contextlib.suppress(Exception): tool_parameters = json.loads(tool_parameters) - except Exception: - pass if not isinstance(tool_parameters, dict): raise ValueError(f"tool_parameters should be a dict, but got a string: {tool_parameters}") @@ -247,7 +246,8 @@ class ToolEngine: ) elif response.type == ToolInvokeMessage.MessageType.JSON: result += json.dumps( - cast(ToolInvokeMessage.JsonMessage, response.message).json_object, ensure_ascii=False + safe_json_value(cast(ToolInvokeMessage.JsonMessage, response.message).json_object), + ensure_ascii=False, ) else: result += str(response.message) @@ -269,14 +269,12 @@ class ToolEngine: if response.meta.get("mime_type"): mimetype = response.meta.get("mime_type") else: - try: + with contextlib.suppress(Exception): url = URL(cast(ToolInvokeMessage.TextMessage, response.message).text) extension = url.suffix guess_type_result, _ = guess_type(f"a{extension}") if guess_type_result: mimetype = guess_type_result - except Exception: - pass if not mimetype: mimetype = "image/jpeg" diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 71c237c7f7..2089313b08 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -7,6 +7,7 @@ from os import listdir, path from threading import Lock from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast +import sqlalchemy as sa from pydantic import TypeAdapter from yarl import URL @@ -206,7 +207,7 @@ class ToolManager: ) except Exception as e: builtin_provider = None - logger.info(f"Error getting builtin provider {credential_id}:{e}", exc_info=True) + logger.info("Error getting builtin provider %s:%s", credential_id, e, exc_info=True) # if the provider has been deleted, raise an error if builtin_provider is None: raise ToolProviderNotFoundError(f"provider has been deleted: {credential_id}") @@ -237,7 +238,7 @@ class ToolManager: if builtin_provider is None: raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found") - encrypter, _ = create_provider_encrypter( + encrypter, cache = create_provider_encrypter( tenant_id=tenant_id, config=[ x.to_basic_provider_config() @@ -281,6 +282,7 @@ class ToolManager: builtin_provider.expires_at = refreshed_credentials.expires_at db.session.commit() decrypted_credentials = refreshed_credentials.credentials + cache.delete() return cast( BuiltinTool, @@ -569,7 +571,7 @@ class ToolManager: yield provider except Exception: - logger.exception(f"load builtin provider {provider_path}") + logger.exception("load builtin provider %s", provider_path) continue # set builtin providers loaded cls._builtin_providers_loaded = True @@ -615,7 +617,7 @@ class ToolManager: WHERE tenant_id = :tenant_id ORDER BY tenant_id, provider, is_default DESC, created_at DESC """ - ids = [row.id for row in db.session.execute(db.text(sql), {"tenant_id": tenant_id}).all()] + ids = [row.id for row in db.session.execute(sa.text(sql), {"tenant_id": tenant_id}).all()] return db.session.query(BuiltinToolProvider).where(BuiltinToolProvider.id.in_(ids)).all() @classmethod @@ -787,9 +789,6 @@ class ToolManager: """ get api provider """ - """ - get tool provider - """ provider_name = provider provider_obj: ApiToolProvider | None = ( db.session.query(ApiToolProvider) @@ -960,7 +959,7 @@ class ToolManager: elif provider_type == ToolProviderType.WORKFLOW: return cls.generate_workflow_tool_icon_url(tenant_id, provider_id) elif provider_type == ToolProviderType.PLUGIN: - provider = ToolManager.get_builtin_provider(provider_id, tenant_id) + provider = ToolManager.get_plugin_provider(provider_id, tenant_id) if isinstance(provider, PluginToolProviderController): try: return cls.generate_plugin_tool_icon_url(tenant_id, provider.entity.identity.icon) diff --git a/api/core/tools/utils/configuration.py b/api/core/tools/utils/configuration.py index aceba6e69f..3a9391dbb1 100644 --- a/api/core/tools/utils/configuration.py +++ b/api/core/tools/utils/configuration.py @@ -1,3 +1,4 @@ +import contextlib from copy import deepcopy from typing import Any @@ -137,11 +138,9 @@ class ToolParameterConfigurationManager: and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT ): if parameter.name in parameters: - try: - has_secret_input = True + has_secret_input = True + with contextlib.suppress(Exception): parameters[parameter.name] = encrypter.decrypt_token(self.tenant_id, parameters[parameter.name]) - except Exception: - pass if has_secret_input: cache.set(parameters) diff --git a/api/core/tools/utils/dataset_retriever_tool.py b/api/core/tools/utils/dataset_retriever_tool.py index ec0575f6c3..d58807e29f 100644 --- a/api/core/tools/utils/dataset_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever_tool.py @@ -20,8 +20,6 @@ from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import Datas class DatasetRetrieverTool(Tool): - retrieval_tool: DatasetRetrieverBaseTool - def __init__(self, entity: ToolEntity, runtime: ToolRuntime, retrieval_tool: DatasetRetrieverBaseTool) -> None: super().__init__(entity, runtime) self.retrieval_tool = retrieval_tool diff --git a/api/core/tools/utils/encryption.py b/api/core/tools/utils/encryption.py index 5fdfd3b9d1..d771293e11 100644 --- a/api/core/tools/utils/encryption.py +++ b/api/core/tools/utils/encryption.py @@ -1,3 +1,4 @@ +import contextlib from copy import deepcopy from typing import Any, Optional, Protocol @@ -111,14 +112,12 @@ class ProviderConfigEncrypter: for field_name, field in fields.items(): if field.type == BasicProviderConfig.Type.SECRET_INPUT: if field_name in data: - try: + with contextlib.suppress(Exception): # if the value is None or empty string, skip decrypt if not data[field_name]: continue data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name]) - except Exception: - pass self.provider_config_cache.set(data) return data diff --git a/api/core/tools/utils/message_transformer.py b/api/core/tools/utils/message_transformer.py index 9998de0465..ac12d83ef2 100644 --- a/api/core/tools/utils/message_transformer.py +++ b/api/core/tools/utils/message_transformer.py @@ -1,7 +1,14 @@ import logging from collections.abc import Generator +from datetime import date, datetime +from decimal import Decimal from mimetypes import guess_extension -from typing import Optional +from typing import Optional, cast +from uuid import UUID + +import numpy as np +import pytz +from flask_login import current_user from core.file import File, FileTransferMethod, FileType from core.tools.entities.tool_entities import ToolInvokeMessage @@ -10,6 +17,41 @@ from core.tools.tool_file_manager import ToolFileManager logger = logging.getLogger(__name__) +def safe_json_value(v): + if isinstance(v, datetime): + tz_name = getattr(current_user, "timezone", None) if current_user is not None else None + if not tz_name: + tz_name = "UTC" + return v.astimezone(pytz.timezone(tz_name)).isoformat() + elif isinstance(v, date): + return v.isoformat() + elif isinstance(v, UUID): + return str(v) + elif isinstance(v, Decimal): + return float(v) + elif isinstance(v, bytes): + try: + return v.decode("utf-8") + except UnicodeDecodeError: + return v.hex() + elif isinstance(v, memoryview): + return v.tobytes().hex() + elif isinstance(v, np.ndarray): + return v.tolist() + elif isinstance(v, dict): + return safe_json_dict(v) + elif isinstance(v, list | tuple | set): + return [safe_json_value(i) for i in v] + else: + return v + + +def safe_json_dict(d): + if not isinstance(d, dict): + raise TypeError("safe_json_dict() expects a dictionary (dict) as input") + return {k: safe_json_value(v) for k, v in d.items()} + + class ToolFileMessageTransformer: @classmethod def transform_tool_invoke_messages( @@ -113,6 +155,12 @@ class ToolFileMessageTransformer: ) else: yield message + + elif message.type == ToolInvokeMessage.MessageType.JSON: + if isinstance(message.message, ToolInvokeMessage.JsonMessage): + json_msg = cast(ToolInvokeMessage.JsonMessage, message.message) + json_msg.json_object = safe_json_value(json_msg.json_object) + yield message else: yield message diff --git a/api/core/tools/utils/parser.py b/api/core/tools/utils/parser.py index a3c84615ca..3857a2a16b 100644 --- a/api/core/tools/utils/parser.py +++ b/api/core/tools/utils/parser.py @@ -105,6 +105,29 @@ class ApiBasedToolSchemaParser: # overwrite the content interface["operation"]["requestBody"]["content"][content_type]["schema"] = root + # handle allOf reference in schema properties + for prop_dict in root.get("properties", {}).values(): + for item in prop_dict.get("allOf", []): + if "$ref" in item: + ref_schema = openapi + reference = item["$ref"].split("/")[1:] + for ref in reference: + ref_schema = ref_schema[ref] + else: + ref_schema = item + for key, value in ref_schema.items(): + if isinstance(value, list): + if key not in prop_dict: + prop_dict[key] = [] + # extends list field + if isinstance(prop_dict[key], list): + prop_dict[key].extend(value) + elif key not in prop_dict: + # add new field + prop_dict[key] = value + if "allOf" in prop_dict: + del prop_dict["allOf"] + # parse body parameters if "schema" in interface["operation"]["requestBody"]["content"][content_type]: body_schema = interface["operation"]["requestBody"]["content"][content_type]["schema"] diff --git a/api/core/tools/utils/rag_web_reader.py b/api/core/tools/utils/rag_web_reader.py deleted file mode 100644 index 22c47fa814..0000000000 --- a/api/core/tools/utils/rag_web_reader.py +++ /dev/null @@ -1,17 +0,0 @@ -import re - - -def get_image_upload_file_ids(content): - pattern = r"!\[image\]\((http?://.*?(file-preview|image-preview))\)" - matches = re.findall(pattern, content) - image_upload_file_ids = [] - for match in matches: - if match[1] == "file-preview": - content_pattern = r"files/([^/]+)/file-preview" - else: - content_pattern = r"files/([^/]+)/image-preview" - content_match = re.search(content_pattern, match[0]) - if content_match: - image_upload_file_id = content_match.group(1) - image_upload_file_ids.append(image_upload_file_id) - return image_upload_file_ids diff --git a/api/core/tools/utils/web_reader_tool.py b/api/core/tools/utils/web_reader_tool.py index cbd06fc186..d8403c2e15 100644 --- a/api/core/tools/utils/web_reader_tool.py +++ b/api/core/tools/utils/web_reader_tool.py @@ -55,7 +55,7 @@ def get_url(url: str, user_agent: Optional[str] = None) -> str: main_content_type = mimetypes.guess_type(filename)[0] if main_content_type not in supported_content_types: - return "Unsupported content-type [{}] of URL.".format(main_content_type) + return f"Unsupported content-type [{main_content_type}] of URL." if main_content_type in extract_processor.SUPPORT_URL_CONTENT_TYPES: return cast(str, ExtractProcessor.load_from_url(url, return_text=True)) @@ -67,7 +67,7 @@ def get_url(url: str, user_agent: Optional[str] = None) -> str: response = scraper.get(url, headers=headers, follow_redirects=True, timeout=(120, 300)) # type: ignore if response.status_code != 200: - return "URL returned status code {}.".format(response.status_code) + return f"URL returned status code {response.status_code}." # Detect encoding using chardet detected_encoding = chardet.detect(response.content) @@ -87,7 +87,7 @@ def get_url(url: str, user_agent: Optional[str] = None) -> str: res = FULL_TEMPLATE.format( title=article.title, - author=article.auther, + author=article.author, text=article.text, ) @@ -97,7 +97,7 @@ def get_url(url: str, user_agent: Optional[str] = None) -> str: @dataclass class Article: title: str - auther: str + author: str text: Sequence[dict] @@ -105,7 +105,7 @@ def extract_using_readabilipy(html: str): json_article: dict[str, Any] = simple_json_from_html_string(html, use_readability=True) article = Article( title=json_article.get("title") or "", - auther=json_article.get("byline") or "", + author=json_article.get("byline") or "", text=json_article.get("plain_text") or [], ) @@ -113,7 +113,7 @@ def extract_using_readabilipy(html: str): def get_image_upload_file_ids(content): - pattern = r"!\[image\]\((http?://.*?(file-preview|image-preview))\)" + pattern = r"!\[image\]\((https?://.*?(file-preview|image-preview))\)" matches = re.findall(pattern, content) image_upload_file_ids = [] for match in matches: diff --git a/api/core/tools/workflow_as_tool/provider.py b/api/core/tools/workflow_as_tool/provider.py index 83f5f558d5..18e6993b38 100644 --- a/api/core/tools/workflow_as_tool/provider.py +++ b/api/core/tools/workflow_as_tool/provider.py @@ -203,9 +203,6 @@ class WorkflowToolProviderController(ToolProviderController): raise ValueError("app not found") app = db_providers.app - if not app: - raise ValueError("can not read app of workflow") - self.tools = [self._get_db_provider_tool(db_providers, app)] return self.tools diff --git a/api/core/tools/workflow_as_tool/tool.py b/api/core/tools/workflow_as_tool/tool.py index 8b89c2a7a9..6824e5e0e8 100644 --- a/api/core/tools/workflow_as_tool/tool.py +++ b/api/core/tools/workflow_as_tool/tool.py @@ -25,15 +25,6 @@ logger = logging.getLogger(__name__) class WorkflowTool(Tool): - workflow_app_id: str - version: str - workflow_entities: dict[str, Any] - workflow_call_depth: int - thread_pool_id: Optional[str] = None - workflow_as_tool_id: str - - label: str - """ Workflow tool. """ @@ -142,7 +133,7 @@ class WorkflowTool(Tool): if not version: workflow = ( db.session.query(Workflow) - .where(Workflow.app_id == app_id, Workflow.version != "draft") + .where(Workflow.app_id == app_id, Workflow.version != Workflow.VERSION_DRAFT) .order_by(Workflow.created_at.desc()) .first() ) @@ -194,7 +185,7 @@ class WorkflowTool(Tool): files.append(file_dict) except Exception: - logger.exception(f"Failed to transform file {file}") + logger.exception("Failed to transform file %s", file) else: parameters_result[parameter.name] = tool_parameters.get(parameter.name) diff --git a/api/core/variables/consts.py b/api/core/variables/consts.py index 03b277d619..8f3f78f740 100644 --- a/api/core/variables/consts.py +++ b/api/core/variables/consts.py @@ -4,4 +4,4 @@ # # If the selector length is more than 2, the remaining parts are the keys / indexes paths used # to extract part of the variable value. -MIN_SELECTORS_LENGTH = 2 +SELECTORS_LENGTH = 2 diff --git a/api/core/variables/segments.py b/api/core/variables/segments.py index 13274f4e0e..a99f5eece3 100644 --- a/api/core/variables/segments.py +++ b/api/core/variables/segments.py @@ -119,6 +119,13 @@ class ObjectSegment(Segment): class ArraySegment(Segment): + @property + def text(self) -> str: + # Return empty string for empty arrays instead of "[]" + if not self.value: + return "" + return super().text + @property def markdown(self) -> str: items = [] @@ -155,6 +162,9 @@ class ArrayStringSegment(ArraySegment): @property def text(self) -> str: + # Return empty string for empty arrays instead of "[]" + if not self.value: + return "" return json.dumps(self.value, ensure_ascii=False) diff --git a/api/core/variables/types.py b/api/core/variables/types.py index e79b2410bf..6629056042 100644 --- a/api/core/variables/types.py +++ b/api/core/variables/types.py @@ -109,7 +109,7 @@ class SegmentType(StrEnum): elif array_validation == ArrayValidation.FIRST: return element_type.is_valid(value[0]) else: - return all([element_type.is_valid(i, array_validation=ArrayValidation.NONE)] for i in value) + return all(element_type.is_valid(i, array_validation=ArrayValidation.NONE) for i in value) def is_valid(self, value: Any, array_validation: ArrayValidation = ArrayValidation.FIRST) -> bool: """ @@ -126,7 +126,7 @@ class SegmentType(StrEnum): """ if self.is_array_type(): return self._validate_array(value, array_validation) - elif self == SegmentType.NUMBER: + elif self in [SegmentType.INTEGER, SegmentType.FLOAT, SegmentType.NUMBER]: return isinstance(value, (int, float)) elif self == SegmentType.STRING: return isinstance(value, str) @@ -152,7 +152,7 @@ class SegmentType(StrEnum): _ARRAY_ELEMENT_TYPES_MAPPING: Mapping[SegmentType, SegmentType] = { - # ARRAY_ANY does not have correpond element type. + # ARRAY_ANY does not have corresponding element type. SegmentType.ARRAY_STRING: SegmentType.STRING, SegmentType.ARRAY_NUMBER: SegmentType.NUMBER, SegmentType.ARRAY_OBJECT: SegmentType.OBJECT, @@ -166,7 +166,6 @@ _ARRAY_TYPES = frozenset( ] ) - _NUMERICAL_TYPES = frozenset( [ SegmentType.NUMBER, diff --git a/api/core/variables/utils.py b/api/core/variables/utils.py index 692db3502e..7ebd29f865 100644 --- a/api/core/variables/utils.py +++ b/api/core/variables/utils.py @@ -1,5 +1,7 @@ -import json from collections.abc import Iterable, Sequence +from typing import Any + +import orjson from .segment_group import SegmentGroup from .segments import ArrayFileSegment, FileSegment, Segment @@ -12,15 +14,20 @@ def to_selector(node_id: str, name: str, paths: Iterable[str] = ()) -> Sequence[ return selectors -class SegmentJSONEncoder(json.JSONEncoder): - def default(self, o): - if isinstance(o, ArrayFileSegment): - return [v.model_dump() for v in o.value] - elif isinstance(o, FileSegment): - return o.value.model_dump() - elif isinstance(o, SegmentGroup): - return [self.default(seg) for seg in o.value] - elif isinstance(o, Segment): - return o.value - else: - super().default(o) +def segment_orjson_default(o: Any) -> Any: + """Default function for orjson serialization of Segment types""" + if isinstance(o, ArrayFileSegment): + return [v.model_dump() for v in o.value] + elif isinstance(o, FileSegment): + return o.value.model_dump() + elif isinstance(o, SegmentGroup): + return [segment_orjson_default(seg) for seg in o.value] + elif isinstance(o, Segment): + return o.value + raise TypeError(f"Object of type {type(o).__name__} is not JSON serializable") + + +def dumps_with_segments(obj: Any, ensure_ascii: bool = False) -> str: + """JSON dumps with segment support using orjson""" + option = orjson.OPT_NON_STR_KEYS + return orjson.dumps(obj, default=segment_orjson_default, option=option).decode("utf-8") diff --git a/api/core/workflow/entities/variable_pool.py b/api/core/workflow/entities/variable_pool.py index fbb8df6b01..fb0794844e 100644 --- a/api/core/workflow/entities/variable_pool.py +++ b/api/core/workflow/entities/variable_pool.py @@ -7,8 +7,8 @@ from pydantic import BaseModel, Field from core.file import File, FileAttribute, file_manager from core.variables import Segment, SegmentGroup, Variable -from core.variables.consts import MIN_SELECTORS_LENGTH -from core.variables.segments import FileSegment, NoneSegment +from core.variables.consts import SELECTORS_LENGTH +from core.variables.segments import FileSegment, ObjectSegment from core.variables.variables import VariableUnion from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID from core.workflow.system_variable import SystemVariable @@ -24,7 +24,7 @@ class VariablePool(BaseModel): # The first element of the selector is the node id, it's the first-level key in the dictionary. # Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the # elements of the selector except the first one. - variable_dictionary: defaultdict[str, Annotated[dict[int, VariableUnion], Field(default_factory=dict)]] = Field( + variable_dictionary: defaultdict[str, Annotated[dict[str, VariableUnion], Field(default_factory=dict)]] = Field( description="Variables mapping", default=defaultdict(dict), ) @@ -36,6 +36,7 @@ class VariablePool(BaseModel): ) system_variables: SystemVariable = Field( description="System variables", + default_factory=SystemVariable.empty, ) environment_variables: Sequence[VariableUnion] = Field( description="Environment variables.", @@ -58,23 +59,29 @@ class VariablePool(BaseModel): def add(self, selector: Sequence[str], value: Any, /) -> None: """ - Adds a variable to the variable pool. + Add a variable to the variable pool. - NOTE: You should not add a non-Segment value to the variable pool - even if it is allowed now. + This method accepts a selector path and a value, converting the value + to a Variable object if necessary before storing it in the pool. Args: - selector (Sequence[str]): The selector for the variable. - value (VariableValue): The value of the variable. + selector: A two-element sequence containing [node_id, variable_name]. + The selector must have exactly 2 elements to be valid. + value: The value to store. Can be a Variable, Segment, or any value + that can be converted to a Segment (str, int, float, dict, list, File). Raises: - ValueError: If the selector is invalid. + ValueError: If selector length is not exactly 2 elements. - Returns: - None + Note: + While non-Segment values are currently accepted and automatically + converted, it's recommended to pass Segment or Variable objects directly. """ - if len(selector) < MIN_SELECTORS_LENGTH: - raise ValueError("Invalid selector") + if len(selector) != SELECTORS_LENGTH: + raise ValueError( + f"Invalid selector: expected {SELECTORS_LENGTH} elements (node_id, variable_name), " + f"got {len(selector)} elements" + ) if isinstance(value, Variable): variable = value @@ -84,57 +91,85 @@ class VariablePool(BaseModel): segment = variable_factory.build_segment(value) variable = variable_factory.segment_to_variable(segment=segment, selector=selector) - key, hash_key = self._selector_to_keys(selector) + node_id, name = self._selector_to_keys(selector) # Based on the definition of `VariableUnion`, # `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible. - self.variable_dictionary[key][hash_key] = cast(VariableUnion, variable) + self.variable_dictionary[node_id][name] = cast(VariableUnion, variable) @classmethod - def _selector_to_keys(cls, selector: Sequence[str]) -> tuple[str, int]: - return selector[0], hash(tuple(selector[1:])) + def _selector_to_keys(cls, selector: Sequence[str]) -> tuple[str, str]: + return selector[0], selector[1] def _has(self, selector: Sequence[str]) -> bool: - key, hash_key = self._selector_to_keys(selector) - if key not in self.variable_dictionary: + node_id, name = self._selector_to_keys(selector) + if node_id not in self.variable_dictionary: return False - if hash_key not in self.variable_dictionary[key]: + if name not in self.variable_dictionary[node_id]: return False return True def get(self, selector: Sequence[str], /) -> Segment | None: """ - Retrieves the value from the variable pool based on the given selector. + Retrieve a variable's value from the pool as a Segment. + + This method supports both simple selectors [node_id, variable_name] and + extended selectors that include attribute access for FileSegment and + ObjectSegment types. Args: - selector (Sequence[str]): The selector used to identify the variable. + selector: A sequence with at least 2 elements: + - [node_id, variable_name]: Returns the full segment + - [node_id, variable_name, attr, ...]: Returns a nested value + from FileSegment (e.g., 'url', 'name') or ObjectSegment Returns: - Any: The value associated with the given selector. + The Segment associated with the selector, or None if not found. + Returns None if selector has fewer than 2 elements. Raises: - ValueError: If the selector is invalid. + ValueError: If attempting to access an invalid FileAttribute. """ - if len(selector) < MIN_SELECTORS_LENGTH: + if len(selector) < SELECTORS_LENGTH: return None - key, hash_key = self._selector_to_keys(selector) - value: Segment | None = self.variable_dictionary[key].get(hash_key) + node_id, name = self._selector_to_keys(selector) + segment: Segment | None = self.variable_dictionary[node_id].get(name) - if value is None: - selector, attr = selector[:-1], selector[-1] + if segment is None: + return None + + if len(selector) == 2: + return segment + + if isinstance(segment, FileSegment): + attr = selector[2] # Python support `attr in FileAttribute` after 3.12 if attr not in {item.value for item in FileAttribute}: return None - value = self.get(selector) - if not isinstance(value, FileSegment | NoneSegment): - return None - if isinstance(value, FileSegment): - attr = FileAttribute(attr) - attr_value = file_manager.get_attr(file=value.value, attr=attr) - return variable_factory.build_segment(attr_value) - return value + attr = FileAttribute(attr) + attr_value = file_manager.get_attr(file=segment.value, attr=attr) + return variable_factory.build_segment(attr_value) - return value + # Navigate through nested attributes + result: Any = segment + for attr in selector[2:]: + result = self._extract_value(result) + result = self._get_nested_attribute(result, attr) + if result is None: + return None + + # Return result as Segment + return result if isinstance(result, Segment) else variable_factory.build_segment(result) + + def _extract_value(self, obj: Any) -> Any: + """Extract the actual value from an ObjectSegment.""" + return obj.value if isinstance(obj, ObjectSegment) else obj + + def _get_nested_attribute(self, obj: Mapping[str, Any], attr: str) -> Any: + """Get a nested attribute from a dictionary-like object.""" + if not isinstance(obj, dict): + return None + return obj.get(attr) def remove(self, selector: Sequence[str], /): """ diff --git a/api/core/workflow/entities/workflow_execution.py b/api/core/workflow/entities/workflow_execution.py index 781be4b3c6..f00dc11aa6 100644 --- a/api/core/workflow/entities/workflow_execution.py +++ b/api/core/workflow/entities/workflow_execution.py @@ -6,12 +6,14 @@ implementation details like tenant_id, app_id, etc. """ from collections.abc import Mapping -from datetime import UTC, datetime +from datetime import datetime from enum import StrEnum from typing import Any, Optional from pydantic import BaseModel, Field +from libs.datetime_utils import naive_utc_now + class WorkflowType(StrEnum): """ @@ -60,7 +62,7 @@ class WorkflowExecution(BaseModel): Calculate elapsed time in seconds. If workflow is not finished, use current time. """ - end_time = self.finished_at or datetime.now(UTC).replace(tzinfo=None) + end_time = self.finished_at or naive_utc_now() return (end_time - self.started_at).total_seconds() @classmethod diff --git a/api/core/workflow/graph_engine/entities/graph_runtime_state.py b/api/core/workflow/graph_engine/entities/graph_runtime_state.py index a62ffe46c9..e2ec7b17f0 100644 --- a/api/core/workflow/graph_engine/entities/graph_runtime_state.py +++ b/api/core/workflow/graph_engine/entities/graph_runtime_state.py @@ -22,7 +22,7 @@ class GraphRuntimeState(BaseModel): # # Note: Since the type of this field is `dict[str, Any]`, its values may not remain consistent # after a serialization and deserialization round trip. - outputs: dict[str, Any] = {} + outputs: dict[str, Any] = Field(default_factory=dict) node_run_steps: int = 0 """node run steps""" diff --git a/api/core/workflow/graph_engine/entities/runtime_route_state.py b/api/core/workflow/graph_engine/entities/runtime_route_state.py index f2d9c98936..a4ddfafab5 100644 --- a/api/core/workflow/graph_engine/entities/runtime_route_state.py +++ b/api/core/workflow/graph_engine/entities/runtime_route_state.py @@ -1,5 +1,5 @@ import uuid -from datetime import UTC, datetime +from datetime import datetime from enum import Enum from typing import Optional @@ -7,6 +7,7 @@ from pydantic import BaseModel, Field from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus +from libs.datetime_utils import naive_utc_now class RouteNodeState(BaseModel): @@ -71,7 +72,7 @@ class RouteNodeState(BaseModel): raise Exception(f"Invalid route status {run_result.status}") self.node_run_result = run_result - self.finished_at = datetime.now(UTC).replace(tzinfo=None) + self.finished_at = naive_utc_now() class RuntimeRouteState(BaseModel): @@ -89,7 +90,7 @@ class RuntimeRouteState(BaseModel): :param node_id: node id """ - state = RouteNodeState(node_id=node_id, start_at=datetime.now(UTC).replace(tzinfo=None)) + state = RouteNodeState(node_id=node_id, start_at=naive_utc_now()) self.node_state_mapping[state.id] = state return state diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index b315129763..03b920ccbb 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -6,7 +6,6 @@ import uuid from collections.abc import Generator, Mapping from concurrent.futures import ThreadPoolExecutor, wait from copy import copy, deepcopy -from datetime import UTC, datetime from typing import Any, Optional, cast from flask import Flask, current_app @@ -15,7 +14,7 @@ from configs import dify_config from core.app.apps.exc import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.entities.node_entities import AgentNodeStrategyInit, NodeRunResult -from core.workflow.entities.variable_pool import VariablePool, VariableValue +from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager from core.workflow.graph_engine.entities.event import ( @@ -51,7 +50,7 @@ from core.workflow.nodes.base import BaseNode from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor from core.workflow.nodes.enums import ErrorStrategy, FailBranchSourceHandle from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent -from core.workflow.utils import variable_utils +from libs.datetime_utils import naive_utc_now from libs.flask_utils import preserve_flask_contexts from models.enums import UserFrom from models.workflow import WorkflowType @@ -238,13 +237,13 @@ class GraphEngine: while True: # max steps reached if self.graph_runtime_state.node_run_steps > self.max_execution_steps: - raise GraphRunFailedError("Max steps {} reached.".format(self.max_execution_steps)) + raise GraphRunFailedError(f"Max steps {self.max_execution_steps} reached.") # or max execution time reached if self._is_timed_out( start_at=self.graph_runtime_state.start_at, max_execution_time=self.max_execution_time ): - raise GraphRunFailedError("Max execution time {}s reached.".format(self.max_execution_time)) + raise GraphRunFailedError(f"Max execution time {self.max_execution_time}s reached.") # init route node state route_node_state = self.graph_runtime_state.node_run_state.create_node_state(node_id=next_node_id) @@ -377,7 +376,7 @@ class GraphEngine: edge = cast(GraphEdge, sub_edge_mappings[0]) if edge.run_condition is None: - logger.warning(f"Edge {edge.target_node_id} run condition is None") + logger.warning("Edge %s run condition is None", edge.target_node_id) continue result = ConditionManager.get_condition_handler( @@ -641,7 +640,7 @@ class GraphEngine: while should_continue_retry and retries <= max_retries: try: # run node - retry_start_at = datetime.now(UTC).replace(tzinfo=None) + retry_start_at = naive_utc_now() # yield control to other threads time.sleep(0.001) event_stream = node.run() @@ -701,11 +700,9 @@ class GraphEngine: route_node_state.status = RouteNodeState.Status.EXCEPTION if run_result.outputs: for variable_key, variable_value in run_result.outputs.items(): - # append variables to variable pool recursively - self._append_variables_recursively( - node_id=node.node_id, - variable_key_list=[variable_key], - variable_value=variable_value, + # Add variables to variable pool + self.graph_runtime_state.variable_pool.add( + [node.node_id, variable_key], variable_value ) yield NodeRunExceptionEvent( error=run_result.error or "System Error", @@ -758,11 +755,9 @@ class GraphEngine: # append node output variables to variable pool if run_result.outputs: for variable_key, variable_value in run_result.outputs.items(): - # append variables to variable pool recursively - self._append_variables_recursively( - node_id=node.node_id, - variable_key_list=[variable_key], - variable_value=variable_value, + # Add variables to variable pool + self.graph_runtime_state.variable_pool.add( + [node.node_id, variable_key], variable_value ) # When setting metadata, convert to dict first @@ -848,24 +843,9 @@ class GraphEngine: ) return except Exception as e: - logger.exception(f"Node {node.title} run failed") + logger.exception("Node %s run failed", node.title) raise e - def _append_variables_recursively(self, node_id: str, variable_key_list: list[str], variable_value: VariableValue): - """ - Append variables recursively - :param node_id: node id - :param variable_key_list: variable key list - :param variable_value: variable value - :return: - """ - variable_utils.append_variables_recursively( - self.graph_runtime_state.variable_pool, - node_id, - variable_key_list, - variable_value, - ) - def _is_timed_out(self, start_at: float, max_execution_time: int) -> bool: """ Check timeout diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py index c83303034e..144f036aa4 100644 --- a/api/core/workflow/nodes/agent/agent_node.py +++ b/api/core/workflow/nodes/agent/agent_node.py @@ -13,8 +13,9 @@ from core.agent.strategy.plugin import PluginAgentStrategy from core.file import File, FileTransferMethod from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance, ModelManager -from core.model_runtime.entities.llm_entities import LLMUsage +from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata from core.model_runtime.entities.model_entities import AIModelEntity, ModelType +from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.entities.request import InvokeCredentials from core.plugin.impl.exc import PluginDaemonClientSideError from core.plugin.impl.plugin import PluginInstaller @@ -50,6 +51,7 @@ from .exc import ( AgentInputTypeError, AgentInvocationError, AgentMessageTransformError, + AgentNodeError, AgentVariableNotFoundError, AgentVariableTypeError, ToolFileNotFoundError, @@ -557,7 +559,7 @@ class AgentNode(BaseNode): assert isinstance(message.message, ToolInvokeMessage.JsonMessage) if node_type == NodeType.AGENT: msg_metadata: dict[str, Any] = message.message.json_object.pop("execution_metadata", {}) - llm_usage = LLMUsage.from_metadata(msg_metadata) + llm_usage = LLMUsage.from_metadata(cast(LLMUsageMetadata, msg_metadata)) agent_execution_metadata = { WorkflowNodeExecutionMetadataKey(key): value for key, value in msg_metadata.items() @@ -593,7 +595,14 @@ class AgentNode(BaseNode): variables[variable_name] = variable_value elif message.type == ToolInvokeMessage.MessageType.FILE: assert message.meta is not None - assert isinstance(message.meta, File) + assert isinstance(message.meta, dict) + # Validate that meta contains a 'file' key + if "file" not in message.meta: + raise AgentNodeError("File message is missing 'file' key in meta") + + # Validate that the file is an instance of File + if not isinstance(message.meta["file"], File): + raise AgentNodeError(f"Expected File object but got {type(message.meta['file']).__name__}") files.append(message.meta["file"]) elif message.type == ToolInvokeMessage.MessageType.LOG: assert isinstance(message.message, ToolInvokeMessage.LogMessage) @@ -684,7 +693,13 @@ class AgentNode(BaseNode): yield RunCompletedEvent( run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs={"text": text, "files": ArrayFileSegment(value=files), "json": json_output, **variables}, + outputs={ + "text": text, + "usage": jsonable_encoder(llm_usage), + "files": ArrayFileSegment(value=files), + "json": json_output, + **variables, + }, metadata={ **agent_execution_metadata, WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info, diff --git a/api/core/workflow/nodes/answer/base_stream_processor.py b/api/core/workflow/nodes/answer/base_stream_processor.py index 09d5464d7a..7e84557a2d 100644 --- a/api/core/workflow/nodes/answer/base_stream_processor.py +++ b/api/core/workflow/nodes/answer/base_stream_processor.py @@ -36,7 +36,7 @@ class StreamProcessor(ABC): reachable_node_ids: list[str] = [] unreachable_first_node_ids: list[str] = [] if finished_node_id not in self.graph.edge_mapping: - logger.warning(f"node {finished_node_id} has no edge mapping") + logger.warning("node %s has no edge mapping", finished_node_id) return for edge in self.graph.edge_mapping[finished_node_id]: if ( diff --git a/api/core/workflow/nodes/base/node.py b/api/core/workflow/nodes/base/node.py index fb5ec55453..be4f79af19 100644 --- a/api/core/workflow/nodes/base/node.py +++ b/api/core/workflow/nodes/base/node.py @@ -65,7 +65,7 @@ class BaseNode: try: result = self._run() except Exception as e: - logger.exception(f"Node {self.node_id} failed to run") + logger.exception("Node %s failed to run", self.node_id) result = NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, error=str(e), diff --git a/api/core/workflow/nodes/document_extractor/node.py b/api/core/workflow/nodes/document_extractor/node.py index ab5964ebd4..a61e6ba4ac 100644 --- a/api/core/workflow/nodes/document_extractor/node.py +++ b/api/core/workflow/nodes/document_extractor/node.py @@ -168,7 +168,57 @@ def _extract_text_by_mime_type(*, file_content: bytes, mime_type: str) -> str: def _extract_text_by_file_extension(*, file_content: bytes, file_extension: str) -> str: """Extract text from a file based on its file extension.""" match file_extension: - case ".txt" | ".markdown" | ".md" | ".html" | ".htm" | ".xml": + case ( + ".txt" + | ".markdown" + | ".md" + | ".html" + | ".htm" + | ".xml" + | ".c" + | ".h" + | ".cpp" + | ".hpp" + | ".cc" + | ".cxx" + | ".c++" + | ".py" + | ".js" + | ".ts" + | ".jsx" + | ".tsx" + | ".java" + | ".php" + | ".rb" + | ".go" + | ".rs" + | ".swift" + | ".kt" + | ".scala" + | ".sh" + | ".bash" + | ".bat" + | ".ps1" + | ".sql" + | ".r" + | ".m" + | ".pl" + | ".lua" + | ".vim" + | ".asm" + | ".s" + | ".css" + | ".scss" + | ".less" + | ".sass" + | ".ini" + | ".cfg" + | ".conf" + | ".toml" + | ".env" + | ".log" + | ".vtt" + ): return _extract_text_from_plain_text(file_content) case ".json": return _extract_text_from_json(file_content) @@ -194,8 +244,6 @@ def _extract_text_by_file_extension(*, file_content: bytes, file_extension: str) return _extract_text_from_eml(file_content) case ".msg": return _extract_text_from_msg(file_content) - case ".vtt": - return _extract_text_from_vtt(file_content) case ".properties": return _extract_text_from_properties(file_content) case _: @@ -305,7 +353,7 @@ def _extract_text_from_doc(file_content: bytes) -> str: raise TextExtractionError(f"Failed to extract text from DOC: {str(e)}") from e -def paser_docx_part(block, doc: Document, content_items, i): +def parser_docx_part(block, doc: Document, content_items, i): if isinstance(block, CT_P): content_items.append((i, "paragraph", Paragraph(block, doc))) elif isinstance(block, CT_Tbl): @@ -329,7 +377,7 @@ def _extract_text_from_docx(file_content: bytes) -> str: part = next(it, None) i = 0 while part is not None: - paser_docx_part(part, doc, content_items, i) + parser_docx_part(part, doc, content_items, i) i = i + 1 part = next(it, None) @@ -363,7 +411,7 @@ def _extract_text_from_docx(file_content: bytes) -> str: text.append(markdown_table) except Exception as e: - logger.warning(f"Failed to extract table from DOC: {e}") + logger.warning("Failed to extract table from DOC: %s", e) continue return "\n".join(text) @@ -597,7 +645,7 @@ def _extract_text_from_vtt(vtt_bytes: bytes) -> str: for i in range(1, len(raw_results)): spk, txt = raw_results[i] - if spk == None: + if spk is None: merged_results.append((None, current_text)) continue diff --git a/api/core/workflow/nodes/http_request/executor.py b/api/core/workflow/nodes/http_request/executor.py index 8ac1ae8526..a5a578a6ff 100644 --- a/api/core/workflow/nodes/http_request/executor.py +++ b/api/core/workflow/nodes/http_request/executor.py @@ -12,6 +12,7 @@ from json_repair import repair_json from configs import dify_config from core.file import file_manager +from core.file.enums import FileTransferMethod from core.helper import ssrf_proxy from core.variables.segments import ArrayFileSegment, FileSegment from core.workflow.entities.variable_pool import VariablePool @@ -91,7 +92,7 @@ class Executor: self.auth = node_data.authorization self.timeout = timeout self.ssl_verify = node_data.ssl_verify - self.params = [] + self.params = None self.headers = {} self.content = None self.files = None @@ -139,7 +140,8 @@ class Executor: (self.variable_pool.convert_template(key).text, self.variable_pool.convert_template(value_str).text) ) - self.params = result + if result: + self.params = result def _init_headers(self): """ @@ -227,7 +229,9 @@ class Executor: files: dict[str, list[tuple[str | None, bytes, str]]] = {} for key, files_in_segment in files_list: for file in files_in_segment: - if file.related_id is not None: + if file.related_id is not None or ( + file.transfer_method == FileTransferMethod.REMOTE_URL and file.remote_url is not None + ): file_tuple = ( file.filename, file_manager.download(file), @@ -265,9 +269,9 @@ class Executor: if not authorization.config.header: authorization.config.header = "Authorization" - if self.auth.config.type == "bearer": + if self.auth.config.type == "bearer" and authorization.config.api_key: headers[authorization.config.header] = f"Bearer {authorization.config.api_key}" - elif self.auth.config.type == "basic": + elif self.auth.config.type == "basic" and authorization.config.api_key: credentials = authorization.config.api_key if ":" in credentials: encoded_credentials = base64.b64encode(credentials.encode("utf-8")).decode("utf-8") @@ -275,7 +279,32 @@ class Executor: encoded_credentials = credentials headers[authorization.config.header] = f"Basic {encoded_credentials}" elif self.auth.config.type == "custom": - headers[authorization.config.header] = authorization.config.api_key or "" + if authorization.config.header and authorization.config.api_key: + headers[authorization.config.header] = authorization.config.api_key + + # Handle Content-Type for multipart/form-data requests + # Fix for issue #23829: Missing boundary when using multipart/form-data + body = self.node_data.body + if body and body.type == "form-data": + # For multipart/form-data with files (including placeholder files), + # remove any manually set Content-Type header to let httpx handle + # For multipart/form-data, if any files are present (including placeholder files), + # we must remove any manually set Content-Type header. This is because httpx needs to + # automatically set the Content-Type and boundary for multipart encoding whenever files + # are included, even if they are placeholders, to avoid boundary issues and ensure correct + # file upload behaviour. Manually setting Content-Type can cause httpx to fail to set the + # boundary, resulting in invalid requests. + if self.files: + # Remove Content-Type if it was manually set to avoid boundary issues + headers = {k: v for k, v in headers.items() if k.lower() != "content-type"} + else: + # No files at all, set Content-Type manually + if "content-type" not in (k.lower() for k in headers): + headers["Content-Type"] = "multipart/form-data" + elif body and body.type in BODY_TYPE_TO_CONTENT_TYPE: + # Set Content-Type for other body types + if "content-type" not in (k.lower() for k in headers): + headers["Content-Type"] = BODY_TYPE_TO_CONTENT_TYPE[body.type] return headers @@ -384,15 +413,24 @@ class Executor: # '__multipart_placeholder__' is inserted to force multipart encoding but is not a real file. # This prevents logging meaningless placeholder entries. if self.files and not all(f[0] == "__multipart_placeholder__" for f in self.files): - for key, (filename, content, mime_type) in self.files: + for file_entry in self.files: + # file_entry should be (key, (filename, content, mime_type)), but handle edge cases + if len(file_entry) != 2 or not isinstance(file_entry[1], tuple) or len(file_entry[1]) < 2: + continue # skip malformed entries + key = file_entry[0] + content = file_entry[1][1] body_string += f"--{boundary}\r\n" body_string += f'Content-Disposition: form-data; name="{key}"\r\n\r\n' - # decode content - try: - body_string += content.decode("utf-8") - except UnicodeDecodeError: - # fix: decode binary content - pass + # decode content safely + if isinstance(content, bytes): + try: + body_string += content.decode("utf-8") + except UnicodeDecodeError: + body_string += content.decode("utf-8", errors="replace") + elif isinstance(content, str): + body_string += content + else: + body_string += f"[Unsupported content type: {type(content).__name__}]" body_string += "\r\n" body_string += f"--{boundary}--\r\n" elif self.node_data.body: diff --git a/api/core/workflow/nodes/http_request/node.py b/api/core/workflow/nodes/http_request/node.py index 6799d5c63c..bc1d5c9b87 100644 --- a/api/core/workflow/nodes/http_request/node.py +++ b/api/core/workflow/nodes/http_request/node.py @@ -129,7 +129,7 @@ class HttpRequestNode(BaseNode): }, ) except HttpRequestNodeError as e: - logger.warning(f"http request node {self.node_id} failed to run: {e}") + logger.warning("http request node %s failed to run: %s", self.node_id, e) return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, error=str(e), diff --git a/api/core/workflow/nodes/if_else/if_else_node.py b/api/core/workflow/nodes/if_else/if_else_node.py index 86e703dc68..2c83ea3d4f 100644 --- a/api/core/workflow/nodes/if_else/if_else_node.py +++ b/api/core/workflow/nodes/if_else/if_else_node.py @@ -129,7 +129,7 @@ class IfElseNode(BaseNode): var_mapping: dict[str, list[str]] = {} for case in typed_node_data.cases or []: for condition in case.conditions: - key = "{}.#{}#".format(node_id, ".".join(condition.variable_selector)) + key = f"{node_id}.#{'.'.join(condition.variable_selector)}#" var_mapping[key] = condition.variable_selector return var_mapping diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index 5842c8d64b..7f591a3ea9 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -4,7 +4,7 @@ import time import uuid from collections.abc import Generator, Mapping, Sequence from concurrent.futures import Future, wait -from datetime import UTC, datetime +from datetime import datetime from queue import Empty, Queue from typing import TYPE_CHECKING, Any, Optional, cast @@ -41,6 +41,7 @@ from core.workflow.nodes.enums import ErrorStrategy, NodeType from core.workflow.nodes.event import NodeEvent, RunCompletedEvent from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData from factories.variable_factory import build_segment +from libs.datetime_utils import naive_utc_now from libs.flask_utils import preserve_flask_contexts from .exc import ( @@ -179,7 +180,7 @@ class IterationNode(BaseNode): thread_pool_id=self.thread_pool_id, ) - start_at = datetime.now(UTC).replace(tzinfo=None) + start_at = naive_utc_now() yield IterationRunStartedEvent( iteration_id=self.id, @@ -428,7 +429,7 @@ class IterationNode(BaseNode): """ run single iteration """ - iter_start_at = datetime.now(UTC).replace(tzinfo=None) + iter_start_at = naive_utc_now() try: rst = graph_engine.run() @@ -505,7 +506,7 @@ class IterationNode(BaseNode): variable_pool.add([self.node_id, "index"], next_index) if next_index < len(iterator_list_value): variable_pool.add([self.node_id, "item"], iterator_list_value[next_index]) - duration = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds() + duration = (naive_utc_now() - iter_start_at).total_seconds() iter_run_map[iteration_run_id] = duration yield IterationRunNextEvent( iteration_id=self.id, @@ -526,7 +527,7 @@ class IterationNode(BaseNode): if next_index < len(iterator_list_value): variable_pool.add([self.node_id, "item"], iterator_list_value[next_index]) - duration = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds() + duration = (naive_utc_now() - iter_start_at).total_seconds() iter_run_map[iteration_run_id] = duration yield IterationRunNextEvent( iteration_id=self.id, @@ -602,7 +603,7 @@ class IterationNode(BaseNode): if next_index < len(iterator_list_value): variable_pool.add([self.node_id, "item"], iterator_list_value[next_index]) - duration = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds() + duration = (naive_utc_now() - iter_start_at).total_seconds() iter_run_map[iteration_run_id] = duration yield IterationRunNextEvent( iteration_id=self.id, @@ -616,7 +617,7 @@ class IterationNode(BaseNode): ) except IterationNodeError as e: - logger.warning(f"Iteration run failed:{str(e)}") + logger.warning("Iteration run failed:%s", str(e)) yield IterationRunFailedEvent( iteration_id=self.id, iteration_node_id=self.node_id, diff --git a/api/core/workflow/nodes/knowledge_retrieval/entities.py b/api/core/workflow/nodes/knowledge_retrieval/entities.py index f1767bdf9e..b71271abeb 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/entities.py +++ b/api/core/workflow/nodes/knowledge_retrieval/entities.py @@ -74,6 +74,8 @@ SupportedComparisonOperator = Literal[ "is not", "empty", "not empty", + "in", + "not in", # for number "=", "≠", diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 34b0afc75d..5e5c9f520e 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Any, Optional, cast from sqlalchemy import Float, and_, func, or_, text from sqlalchemy import cast as sqlalchemy_cast -from sqlalchemy.orm import Session +from sqlalchemy.orm import sessionmaker from core.app.app_config.entities import DatasetRetrieveConfigEntity from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity @@ -175,7 +175,7 @@ class KnowledgeRetrievalNode(BaseNode): redis_client.zremrangebyscore(key, 0, current_time - 60000) request_count = redis_client.zcard(key) if request_count > knowledge_rate_limit.limit: - with Session(db.engine) as session: + with sessionmaker(db.engine).begin() as session: # add ratelimit record rate_limit_log = RateLimitLog( tenant_id=self.tenant_id, @@ -183,7 +183,6 @@ class KnowledgeRetrievalNode(BaseNode): operation="knowledge", ) session.add(rate_limit_log) - session.commit() return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, @@ -389,6 +388,15 @@ class KnowledgeRetrievalNode(BaseNode): "segment_id": segment.id, "retriever_from": "workflow", "score": record.score or 0.0, + "child_chunks": [ + { + "id": str(getattr(chunk, "id", "")), + "content": str(getattr(chunk, "content", "")), + "position": int(getattr(chunk, "position", 0)), + "score": float(getattr(chunk, "score", 0.0)), + } + for chunk in (record.child_chunks or []) + ], "segment_hit_count": segment.hit_count, "segment_word_count": segment.word_count, "segment_position": segment.position, @@ -453,35 +461,34 @@ class KnowledgeRetrievalNode(BaseNode): elif node_data.metadata_filtering_mode == "manual": if node_data.metadata_filtering_conditions: conditions = [] - if node_data.metadata_filtering_conditions: - for sequence, condition in enumerate(node_data.metadata_filtering_conditions.conditions): # type: ignore - metadata_name = condition.name - expected_value = condition.value - if expected_value is not None and condition.comparison_operator not in ("empty", "not empty"): - if isinstance(expected_value, str): - expected_value = self.graph_runtime_state.variable_pool.convert_template( - expected_value - ).value[0] - if expected_value.value_type in {"number", "integer", "float"}: # type: ignore - expected_value = expected_value.value # type: ignore - elif expected_value.value_type == "string": # type: ignore - expected_value = re.sub(r"[\r\n\t]+", " ", expected_value.text).strip() # type: ignore - else: - raise ValueError("Invalid expected metadata value type") - conditions.append( - Condition( - name=metadata_name, - comparison_operator=condition.comparison_operator, - value=expected_value, - ) - ) - filters = self._process_metadata_filter_func( - sequence, - condition.comparison_operator, - metadata_name, - expected_value, - filters, + for sequence, condition in enumerate(node_data.metadata_filtering_conditions.conditions): # type: ignore + metadata_name = condition.name + expected_value = condition.value + if expected_value is not None and condition.comparison_operator not in ("empty", "not empty"): + if isinstance(expected_value, str): + expected_value = self.graph_runtime_state.variable_pool.convert_template( + expected_value + ).value[0] + if expected_value.value_type in {"number", "integer", "float"}: # type: ignore + expected_value = expected_value.value # type: ignore + elif expected_value.value_type == "string": # type: ignore + expected_value = re.sub(r"[\r\n\t]+", " ", expected_value.text).strip() # type: ignore + else: + raise ValueError("Invalid expected metadata value type") + conditions.append( + Condition( + name=metadata_name, + comparison_operator=condition.comparison_operator, + value=expected_value, ) + ) + filters = self._process_metadata_filter_func( + sequence, + condition.comparison_operator, + metadata_name, + expected_value, + filters, + ) metadata_condition = MetadataCondition( logical_operator=node_data.metadata_filtering_conditions.logical_operator, conditions=conditions, @@ -573,7 +580,7 @@ class KnowledgeRetrievalNode(BaseNode): def _process_metadata_filter_func( self, sequence: int, condition: str, metadata_name: str, value: Optional[Any], filters: list ): - if value is None: + if value is None and condition not in ("empty", "not empty"): return key = f"{metadata_name}_{sequence}" @@ -603,6 +610,28 @@ class KnowledgeRetrievalNode(BaseNode): **{key: metadata_name, key_value: f"%{value}"} ) ) + case "in": + if isinstance(value, str): + escaped_values = [v.strip().replace("'", "''") for v in str(value).split(",")] + escaped_value_str = ",".join(escaped_values) + else: + escaped_value_str = str(value) + filters.append( + (text(f"documents.doc_metadata ->> :{key} = any(string_to_array(:{key_value},','))")).params( + **{key: metadata_name, key_value: escaped_value_str} + ) + ) + case "not in": + if isinstance(value, str): + escaped_values = [v.strip().replace("'", "''") for v in str(value).split(",")] + escaped_value_str = ",".join(escaped_values) + else: + escaped_value_str = str(value) + filters.append( + (text(f"documents.doc_metadata ->> :{key} != all(string_to_array(:{key_value},','))")).params( + **{key: metadata_name, key_value: escaped_value_str} + ) + ) case "=" | "is": if isinstance(value, str): filters.append(Document.doc_metadata[metadata_name] == f'"{value}"') diff --git a/api/core/workflow/nodes/list_operator/node.py b/api/core/workflow/nodes/list_operator/node.py index b91fc622f6..d2e022dc9d 100644 --- a/api/core/workflow/nodes/list_operator/node.py +++ b/api/core/workflow/nodes/list_operator/node.py @@ -299,7 +299,7 @@ def _endswith(value: str) -> Callable[[str], bool]: def _is(value: str) -> Callable[[str], bool]: - return lambda x: x is value + return lambda x: x == value def _in(value: str | Sequence[str]) -> Callable[[str], bool]: diff --git a/api/core/workflow/nodes/llm/entities.py b/api/core/workflow/nodes/llm/entities.py index 4bb62d35a2..e6f8abeba0 100644 --- a/api/core/workflow/nodes/llm/entities.py +++ b/api/core/workflow/nodes/llm/entities.py @@ -13,7 +13,7 @@ class ModelConfig(BaseModel): provider: str name: str mode: LLMMode - completion_params: dict[str, Any] = {} + completion_params: dict[str, Any] = Field(default_factory=dict) class ContextConfig(BaseModel): diff --git a/api/core/workflow/nodes/llm/llm_utils.py b/api/core/workflow/nodes/llm/llm_utils.py index 0966c87a1d..2441e30c87 100644 --- a/api/core/workflow/nodes/llm/llm_utils.py +++ b/api/core/workflow/nodes/llm/llm_utils.py @@ -1,5 +1,4 @@ from collections.abc import Sequence -from datetime import UTC, datetime from typing import Optional, cast from sqlalchemy import select, update @@ -20,6 +19,7 @@ from core.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegme from core.workflow.entities.variable_pool import VariablePool from core.workflow.enums import SystemVariableKey from core.workflow.nodes.llm.entities import ModelConfig +from libs.datetime_utils import naive_utc_now from models import db from models.model import Conversation from models.provider import Provider, ProviderType @@ -149,7 +149,7 @@ def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUs ) .values( quota_used=Provider.quota_used + used_quota, - last_used=datetime.now(tz=UTC).replace(tzinfo=None), + last_used=naive_utc_now(), ) ) session.execute(stmt) diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index 90a0397b67..dfc2a0000b 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -3,7 +3,7 @@ import io import json import logging from collections.abc import Generator, Mapping, Sequence -from typing import TYPE_CHECKING, Any, Optional, cast +from typing import TYPE_CHECKING, Any, Optional from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.file import FileType, file_manager @@ -33,12 +33,10 @@ from core.model_runtime.entities.message_entities import ( UserPromptMessage, ) from core.model_runtime.entities.model_entities import ( - AIModelEntity, ModelFeature, ModelPropertyKey, ModelType, ) -from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.utils.encoders import jsonable_encoder from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig from core.prompt.utils.prompt_message_util import PromptMessageUtil @@ -1006,21 +1004,6 @@ class LLMNode(BaseNode): ) return saved_file - def _fetch_model_schema(self, provider: str) -> AIModelEntity | None: - """ - Fetch model schema - """ - model_name = self._node_data.model.name - model_manager = ModelManager() - model_instance = model_manager.get_model_instance( - tenant_id=self.tenant_id, model_type=ModelType.LLM, provider=provider, model=model_name - ) - model_type_instance = model_instance.model_type_instance - model_type_instance = cast(LargeLanguageModel, model_type_instance) - model_credentials = model_instance.credentials - model_schema = model_type_instance.get_model_schema(model_name, model_credentials) - return model_schema - @staticmethod def fetch_structured_output_schema( *, diff --git a/api/core/workflow/nodes/loop/loop_node.py b/api/core/workflow/nodes/loop/loop_node.py index 655de9362f..b2ab943129 100644 --- a/api/core/workflow/nodes/loop/loop_node.py +++ b/api/core/workflow/nodes/loop/loop_node.py @@ -2,7 +2,7 @@ import json import logging import time from collections.abc import Generator, Mapping, Sequence -from datetime import UTC, datetime +from datetime import datetime from typing import TYPE_CHECKING, Any, Literal, Optional, cast from configs import dify_config @@ -36,6 +36,7 @@ from core.workflow.nodes.event import NodeEvent, RunCompletedEvent from core.workflow.nodes.loop.entities import LoopNodeData from core.workflow.utils.condition.processor import ConditionProcessor from factories.variable_factory import TypeMismatchError, build_segment_with_type +from libs.datetime_utils import naive_utc_now if TYPE_CHECKING: from core.workflow.entities.variable_pool import VariablePool @@ -143,7 +144,7 @@ class LoopNode(BaseNode): thread_pool_id=self.thread_pool_id, ) - start_at = datetime.now(UTC).replace(tzinfo=None) + start_at = naive_utc_now() condition_processor = ConditionProcessor() # Start Loop event @@ -171,7 +172,7 @@ class LoopNode(BaseNode): try: check_break_result = False for i in range(loop_count): - loop_start_time = datetime.now(UTC).replace(tzinfo=None) + loop_start_time = naive_utc_now() # run single loop loop_result = yield from self._run_single_loop( graph_engine=graph_engine, @@ -185,7 +186,7 @@ class LoopNode(BaseNode): start_at=start_at, inputs=inputs, ) - loop_end_time = datetime.now(UTC).replace(tzinfo=None) + loop_end_time = naive_utc_now() single_loop_variable = {} for key, selector in loop_variable_selectors.items(): @@ -313,30 +314,31 @@ class LoopNode(BaseNode): and event.node_type == NodeType.LOOP_END and not isinstance(event, NodeRunStreamChunkEvent) ): - check_break_result = True + # Check if variables in break conditions exist and process conditions + # Allow loop internal variables to be used in break conditions + available_conditions = [] + for condition in break_conditions: + variable = self.graph_runtime_state.variable_pool.get(condition.variable_selector) + if variable: + available_conditions.append(condition) + + # Process conditions if at least one variable is available + if available_conditions: + input_conditions, group_result, check_break_result = condition_processor.process_conditions( + variable_pool=self.graph_runtime_state.variable_pool, + conditions=available_conditions, + operator=logical_operator, + ) + if check_break_result: + break + else: + check_break_result = True yield self._handle_event_metadata(event=event, iter_run_index=current_index) break if isinstance(event, NodeRunSucceededEvent): yield self._handle_event_metadata(event=event, iter_run_index=current_index) - # Check if all variables in break conditions exist - exists_variable = False - for condition in break_conditions: - if not self.graph_runtime_state.variable_pool.get(condition.variable_selector): - exists_variable = False - break - else: - exists_variable = True - if exists_variable: - input_conditions, group_result, check_break_result = condition_processor.process_conditions( - variable_pool=self.graph_runtime_state.variable_pool, - conditions=break_conditions, - operator=logical_operator, - ) - if check_break_result: - break - elif isinstance(event, BaseGraphEvent): if isinstance(event, GraphRunFailedEvent): # Loop run failed diff --git a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py index a23d284626..49c4c142e1 100644 --- a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py +++ b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py @@ -1,3 +1,4 @@ +import contextlib import json import logging import uuid @@ -666,11 +667,9 @@ class ParameterExtractorNode(BaseNode): if result[idx] == "{" or result[idx] == "[": json_str = extract_json(result[idx:]) if json_str: - try: + with contextlib.suppress(Exception): return cast(dict, json.loads(json_str)) - except Exception: - pass - logger.info(f"extra error: {result}") + logger.info("extra error: %s", result) return None def _extract_json_from_tool_call(self, tool_call: AssistantPromptMessage.ToolCall) -> Optional[dict]: @@ -686,11 +685,10 @@ class ParameterExtractorNode(BaseNode): if result[idx] == "{" or result[idx] == "[": json_str = extract_json(result[idx:]) if json_str: - try: + with contextlib.suppress(Exception): return cast(dict, json.loads(json_str)) - except Exception: - pass - logger.info(f"extra error: {result}") + + logger.info("extra error: %s", result) return None def _generate_default_result(self, data: ParameterExtractorNodeData) -> dict: diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py index 15012fa48d..3e4984ecd5 100644 --- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py +++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py @@ -385,9 +385,8 @@ class QuestionClassifierNode(BaseNode): text=QUESTION_CLASSIFIER_COMPLETION_PROMPT.format( histories=memory_str, input_text=input_text, - categories=json.dumps(categories), + categories=json.dumps(categories, ensure_ascii=False), classification_instructions=instruction, - ensure_ascii=False, ) ) diff --git a/api/core/workflow/nodes/tool/entities.py b/api/core/workflow/nodes/tool/entities.py index 4f47fb1efc..c1cfbb1edc 100644 --- a/api/core/workflow/nodes/tool/entities.py +++ b/api/core/workflow/nodes/tool/entities.py @@ -55,7 +55,7 @@ class ToolNodeData(BaseNodeData, ToolEntity): if not isinstance(val, str): raise ValueError("value must be a list of strings") elif typ == "constant" and not isinstance(value, str | int | float | bool | dict): - raise ValueError("value must be a string, int, float, or bool") + raise ValueError("value must be a string, int, float, bool or dict") return typ tool_parameters: dict[str, ToolInput] diff --git a/api/core/workflow/nodes/variable_assigner/common/helpers.py b/api/core/workflow/nodes/variable_assigner/common/helpers.py index 0d2822233e..48deda724a 100644 --- a/api/core/workflow/nodes/variable_assigner/common/helpers.py +++ b/api/core/workflow/nodes/variable_assigner/common/helpers.py @@ -4,7 +4,7 @@ from typing import Any, TypeVar from pydantic import BaseModel from core.variables import Segment -from core.variables.consts import MIN_SELECTORS_LENGTH +from core.variables.consts import SELECTORS_LENGTH from core.variables.types import SegmentType # Use double underscore (`__`) prefix for internal variables @@ -23,7 +23,7 @@ _T = TypeVar("_T", bound=MutableMapping[str, Any]) def variable_to_processed_data(selector: Sequence[str], seg: Segment) -> UpdatedVariable: - if len(selector) < MIN_SELECTORS_LENGTH: + if len(selector) < SELECTORS_LENGTH: raise Exception("selector too short") node_id, var_name = selector[:2] return UpdatedVariable( diff --git a/api/core/workflow/nodes/variable_assigner/v2/node.py b/api/core/workflow/nodes/variable_assigner/v2/node.py index c0215cae71..00ee921cee 100644 --- a/api/core/workflow/nodes/variable_assigner/v2/node.py +++ b/api/core/workflow/nodes/variable_assigner/v2/node.py @@ -4,7 +4,7 @@ from typing import Any, Optional, cast from core.app.entities.app_invoke_entities import InvokeFrom from core.variables import SegmentType, Variable -from core.variables.consts import MIN_SELECTORS_LENGTH +from core.variables.consts import SELECTORS_LENGTH from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID from core.workflow.conversation_variable_updater import ConversationVariableUpdater from core.workflow.entities.node_entities import NodeRunResult @@ -46,7 +46,7 @@ def _source_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_ selector = item.value if not isinstance(selector, list): raise InvalidDataError(f"selector is not a list, {node_id=}, {item=}") - if len(selector) < MIN_SELECTORS_LENGTH: + if len(selector) < SELECTORS_LENGTH: raise InvalidDataError(f"selector too short, {node_id=}, {item=}") selector_str = ".".join(selector) key = f"{node_id}.#{selector_str}#" diff --git a/api/core/workflow/utils/variable_utils.py b/api/core/workflow/utils/variable_utils.py deleted file mode 100644 index 868868315b..0000000000 --- a/api/core/workflow/utils/variable_utils.py +++ /dev/null @@ -1,29 +0,0 @@ -from core.variables.segments import ObjectSegment, Segment -from core.workflow.entities.variable_pool import VariablePool, VariableValue - - -def append_variables_recursively( - pool: VariablePool, node_id: str, variable_key_list: list[str], variable_value: VariableValue | Segment -): - """ - Append variables recursively - :param pool: variable pool to append variables to - :param node_id: node id - :param variable_key_list: variable key list - :param variable_value: variable value - :return: - """ - pool.add([node_id] + variable_key_list, variable_value) - - # if variable_value is a dict, then recursively append variables - if isinstance(variable_value, ObjectSegment): - variable_dict = variable_value.value - elif isinstance(variable_value, dict): - variable_dict = variable_value - else: - return - - for key, value in variable_dict.items(): - # construct new key list - new_key_list = variable_key_list + [key] - append_variables_recursively(pool, node_id=node_id, variable_key_list=new_key_list, variable_value=value) diff --git a/api/core/workflow/variable_loader.py b/api/core/workflow/variable_loader.py index 1e13871d0a..a35215855e 100644 --- a/api/core/workflow/variable_loader.py +++ b/api/core/workflow/variable_loader.py @@ -3,9 +3,8 @@ from collections.abc import Mapping, Sequence from typing import Any, Protocol from core.variables import Variable -from core.variables.consts import MIN_SELECTORS_LENGTH +from core.variables.consts import SELECTORS_LENGTH from core.workflow.entities.variable_pool import VariablePool -from core.workflow.utils import variable_utils class VariableLoader(Protocol): @@ -78,7 +77,7 @@ def load_into_variable_pool( variables_to_load.append(list(selector)) loaded = variable_loader.load_variables(variables_to_load) for var in loaded: - assert len(var.selector) >= MIN_SELECTORS_LENGTH, f"Invalid variable {var}" - variable_utils.append_variables_recursively( - variable_pool, node_id=var.selector[0], variable_key_list=list(var.selector[1:]), variable_value=var - ) + assert len(var.selector) >= SELECTORS_LENGTH, f"Invalid variable {var}" + # Add variable directly to the pool + # The variable pool expects 2-element selectors [node_id, variable_name] + variable_pool.add([var.selector[0], var.selector[1]], var) diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index c8082ebf50..801e36e272 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -67,7 +67,7 @@ class WorkflowEntry: # check call depth workflow_call_max_depth = dify_config.WORKFLOW_CALL_MAX_DEPTH if call_depth > workflow_call_max_depth: - raise ValueError("Max workflow call depth {} reached.".format(workflow_call_max_depth)) + raise ValueError(f"Max workflow call depth {workflow_call_max_depth} reached.") # init workflow run state graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) @@ -193,7 +193,13 @@ class WorkflowEntry: # run node generator = node.run() except Exception as e: - logger.exception(f"error while running node, {workflow.id=}, {node.id=}, {node.type_=}, {node.version()=}") + logger.exception( + "error while running node, workflow_id=%s, node_id=%s, node_type=%s, node_version=%s", + workflow.id, + node.id, + node.type_, + node.version(), + ) raise WorkflowNodeRunFailedError(node=node, err_msg=str(e)) return node, generator @@ -297,7 +303,12 @@ class WorkflowEntry: return node, generator except Exception as e: - logger.exception(f"error while running node, {node.id=}, {node.type_=}, {node.version()=}") + logger.exception( + "error while running node, node_id=%s, node_type=%s, node_version=%s", + node.id, + node.type_, + node.version(), + ) raise WorkflowNodeRunFailedError(node=node, err_msg=str(e)) @staticmethod diff --git a/api/core/workflow/workflow_type_encoder.py b/api/core/workflow/workflow_type_encoder.py index 2c634d25ec..08e12e2681 100644 --- a/api/core/workflow/workflow_type_encoder.py +++ b/api/core/workflow/workflow_type_encoder.py @@ -1,4 +1,5 @@ from collections.abc import Mapping +from decimal import Decimal from typing import Any from pydantic import BaseModel @@ -17,6 +18,9 @@ class WorkflowRuntimeTypeConverter: return value if isinstance(value, (bool, int, str, float)): return value + if isinstance(value, Decimal): + # Convert Decimal to float for JSON serialization + return float(value) if isinstance(value, Segment): return self._to_json_encodable_recursive(value.value) if isinstance(value, File): diff --git a/api/docker/entrypoint.sh b/api/docker/entrypoint.sh index 4de9a25c2f..e21092349e 100755 --- a/api/docker/entrypoint.sh +++ b/api/docker/entrypoint.sh @@ -2,6 +2,11 @@ set -e +# Set UTF-8 encoding to address potential encoding issues in containerized environments +export LANG=${LANG:-en_US.UTF-8} +export LC_ALL=${LC_ALL:-en_US.UTF-8} +export PYTHONIOENCODING=${PYTHONIOENCODING:-utf-8} + if [[ "${MIGRATION_ENABLED}" == "true" ]]; then echo "Running migrations" flask upgrade-db @@ -27,7 +32,7 @@ if [[ "${MODE}" == "worker" ]]; then exec celery -A app.celery worker -P ${CELERY_WORKER_CLASS:-gevent} $CONCURRENCY_OPTION \ --max-tasks-per-child ${MAX_TASK_PRE_CHILD:-50} --loglevel ${LOG_LEVEL:-INFO} \ - -Q ${CELERY_QUEUES:-dataset,mail,ops_trace,app_deletion,plugin} + -Q ${CELERY_QUEUES:-dataset,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation} elif [[ "${MODE}" == "beat" ]]; then exec celery -A app.celery beat --loglevel ${LOG_LEVEL:-INFO} diff --git a/api/events/event_handlers/document_index_event.py b/api/events/document_index_event.py similarity index 100% rename from api/events/event_handlers/document_index_event.py rename to api/events/document_index_event.py diff --git a/api/events/event_handlers/clean_when_document_deleted.py b/api/events/event_handlers/clean_when_document_deleted.py index 00a66f50ad..bbc913b7cf 100644 --- a/api/events/event_handlers/clean_when_document_deleted.py +++ b/api/events/event_handlers/clean_when_document_deleted.py @@ -8,4 +8,6 @@ def handle(sender, **kwargs): dataset_id = kwargs.get("dataset_id") doc_form = kwargs.get("doc_form") file_id = kwargs.get("file_id") + assert dataset_id is not None + assert doc_form is not None clean_document_task.delay(document_id, dataset_id, doc_form, file_id) diff --git a/api/events/event_handlers/create_document_index.py b/api/events/event_handlers/create_document_index.py index dc50ca8d96..1b0321f42e 100644 --- a/api/events/event_handlers/create_document_index.py +++ b/api/events/event_handlers/create_document_index.py @@ -1,3 +1,4 @@ +import contextlib import logging import time @@ -5,7 +6,7 @@ import click from werkzeug.exceptions import NotFound from core.indexing_runner import DocumentIsPausedError, IndexingRunner -from events.event_handlers.document_index_event import document_index_created +from events.document_index_event import document_index_created from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from models.dataset import Document @@ -18,7 +19,7 @@ def handle(sender, **kwargs): documents = [] start_at = time.perf_counter() for document_id in document_ids: - logging.info(click.style("Start process document: {}".format(document_id), fg="green")) + logging.info(click.style(f"Start process document: {document_id}", fg="green")) document = ( db.session.query(Document) @@ -38,12 +39,11 @@ def handle(sender, **kwargs): db.session.add(document) db.session.commit() - try: - indexing_runner = IndexingRunner() - indexing_runner.run(documents) - end_at = time.perf_counter() - logging.info(click.style("Processed dataset: {} latency: {}".format(dataset_id, end_at - start_at), fg="green")) - except DocumentIsPausedError as ex: - logging.info(click.style(str(ex), fg="yellow")) - except Exception: - pass + with contextlib.suppress(Exception): + try: + indexing_runner = IndexingRunner() + indexing_runner.run(documents) + end_at = time.perf_counter() + logging.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green")) + except DocumentIsPausedError as ex: + logging.info(click.style(str(ex), fg="yellow")) diff --git a/api/events/event_handlers/update_provider_when_message_created.py b/api/events/event_handlers/update_provider_when_message_created.py index d3943f2eda..f01dd58900 100644 --- a/api/events/event_handlers/update_provider_when_message_created.py +++ b/api/events/event_handlers/update_provider_when_message_created.py @@ -131,9 +131,11 @@ def handle(sender: Message, **kwargs): duration = time_module.perf_counter() - start_time logger.info( - f"Provider updates completed successfully. " - f"Updates: {len(updates_to_perform)}, Duration: {duration:.3f}s, " - f"Tenant: {tenant_id}, Provider: {provider_name}" + "Provider updates completed successfully. Updates: %s, Duration: %s s, Tenant: %s, Provider: %s", + len(updates_to_perform), + duration, + tenant_id, + provider_name, ) except Exception as e: @@ -141,9 +143,11 @@ def handle(sender: Message, **kwargs): duration = time_module.perf_counter() - start_time logger.exception( - f"Provider updates failed after {duration:.3f}s. " - f"Updates: {len(updates_to_perform)}, Tenant: {tenant_id}, " - f"Provider: {provider_name}" + "Provider updates failed after %s s. Updates: %s, Tenant: %s, Provider: %s", + duration, + len(updates_to_perform), + tenant_id, + provider_name, ) raise @@ -184,7 +188,7 @@ def _execute_provider_updates(updates_to_perform: list[_ProviderUpdateOperation] # Use SQLAlchemy's context manager for transaction management # This automatically handles commit/rollback - with Session(db.engine) as session: + with Session(db.engine) as session, session.begin(): # Use a single transaction for all updates for update_operation in updates_to_perform: filters = update_operation.filters @@ -219,16 +223,20 @@ def _execute_provider_updates(updates_to_perform: list[_ProviderUpdateOperation] rows_affected = result.rowcount logger.debug( - f"Provider update ({description}): {rows_affected} rows affected. " - f"Filters: {filters.model_dump()}, Values: {update_values}" + "Provider update (%s): %s rows affected. Filters: %s, Values: %s", + description, + rows_affected, + filters.model_dump(), + update_values, ) # If no rows were affected for quota updates, log a warning if rows_affected == 0 and description == "quota_deduction_update": logger.warning( - f"No Provider rows updated for quota deduction. " - f"This may indicate quota limit exceeded or provider not found. " - f"Filters: {filters.model_dump()}" + "No Provider rows updated for quota deduction. " + "This may indicate quota limit exceeded or provider not found. " + "Filters: %s", + filters.model_dump(), ) - logger.debug(f"Successfully processed {len(updates_to_perform)} Provider updates") + logger.debug("Successfully processed %s Provider updates", len(updates_to_perform)) diff --git a/api/extensions/ext_blueprints.py b/api/extensions/ext_blueprints.py index a4d013ffc0..1024fd9ce6 100644 --- a/api/extensions/ext_blueprints.py +++ b/api/extensions/ext_blueprints.py @@ -29,7 +29,6 @@ def init_app(app: DifyApp): methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], expose_headers=["X-Version", "X-Env"], ) - app.register_blueprint(web_bp) CORS( @@ -40,10 +39,13 @@ def init_app(app: DifyApp): methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], expose_headers=["X-Version", "X-Env"], ) - app.register_blueprint(console_app_bp) - CORS(files_bp, allow_headers=["Content-Type"], methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"]) + CORS( + files_bp, + allow_headers=["Content-Type"], + methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], + ) app.register_blueprint(files_bp) app.register_blueprint(inner_api_bp) diff --git a/api/extensions/ext_celery.py b/api/extensions/ext_celery.py index 2c2846ba26..fb5352ca8f 100644 --- a/api/extensions/ext_celery.py +++ b/api/extensions/ext_celery.py @@ -1,13 +1,49 @@ +import ssl from datetime import timedelta +from typing import Any, Optional import pytz -from celery import Celery, Task # type: ignore -from celery.schedules import crontab # type: ignore +from celery import Celery, Task +from celery.schedules import crontab from configs import dify_config from dify_app import DifyApp +def _get_celery_ssl_options() -> Optional[dict[str, Any]]: + """Get SSL configuration for Celery broker/backend connections.""" + # Use REDIS_USE_SSL for consistency with the main Redis client + # Only apply SSL if we're using Redis as broker/backend + if not dify_config.REDIS_USE_SSL: + return None + + # Check if Celery is actually using Redis + broker_is_redis = dify_config.CELERY_BROKER_URL and ( + dify_config.CELERY_BROKER_URL.startswith("redis://") or dify_config.CELERY_BROKER_URL.startswith("rediss://") + ) + + if not broker_is_redis: + return None + + # Map certificate requirement strings to SSL constants + cert_reqs_map = { + "CERT_NONE": ssl.CERT_NONE, + "CERT_OPTIONAL": ssl.CERT_OPTIONAL, + "CERT_REQUIRED": ssl.CERT_REQUIRED, + } + + ssl_cert_reqs = cert_reqs_map.get(dify_config.REDIS_SSL_CERT_REQS, ssl.CERT_NONE) + + ssl_options = { + "ssl_cert_reqs": ssl_cert_reqs, + "ssl_ca_certs": dify_config.REDIS_SSL_CA_CERTS, + "ssl_certfile": dify_config.REDIS_SSL_CERTFILE, + "ssl_keyfile": dify_config.REDIS_SSL_KEYFILE, + } + + return ssl_options + + def init_app(app: DifyApp) -> Celery: class FlaskTask(Task): def __call__(self, *args: object, **kwargs: object) -> object: @@ -30,17 +66,8 @@ def init_app(app: DifyApp) -> Celery: task_cls=FlaskTask, broker=dify_config.CELERY_BROKER_URL, backend=dify_config.CELERY_BACKEND, - task_ignore_result=True, ) - # Add SSL options to the Celery configuration - ssl_options = { - "ssl_cert_reqs": None, - "ssl_ca_certs": None, - "ssl_certfile": None, - "ssl_keyfile": None, - } - celery_app.conf.update( result_backend=dify_config.CELERY_RESULT_BACKEND, broker_transport_options=broker_transport_options, @@ -49,11 +76,16 @@ def init_app(app: DifyApp) -> Celery: worker_task_log_format=dify_config.LOG_FORMAT, worker_hijack_root_logger=False, timezone=pytz.timezone(dify_config.LOG_TZ or "UTC"), + task_ignore_result=True, ) - if dify_config.BROKER_USE_SSL: + # Apply SSL configuration if enabled + ssl_options = _get_celery_ssl_options() + if ssl_options: celery_app.conf.update( - broker_use_ssl=ssl_options, # Add the SSL options to the broker configuration + broker_use_ssl=ssl_options, + # Also apply SSL to the backend if it's Redis + redis_backend_use_ssl=ssl_options if dify_config.CELERY_BACKEND == "redis" else None, ) if dify_config.LOG_FILE: @@ -73,13 +105,13 @@ def init_app(app: DifyApp) -> Celery: imports.append("schedule.clean_embedding_cache_task") beat_schedule["clean_embedding_cache_task"] = { "task": "schedule.clean_embedding_cache_task.clean_embedding_cache_task", - "schedule": timedelta(days=day), + "schedule": crontab(minute="0", hour="2", day_of_month=f"*/{day}"), } if dify_config.ENABLE_CLEAN_UNUSED_DATASETS_TASK: imports.append("schedule.clean_unused_datasets_task") beat_schedule["clean_unused_datasets_task"] = { "task": "schedule.clean_unused_datasets_task.clean_unused_datasets_task", - "schedule": timedelta(days=day), + "schedule": crontab(minute="0", hour="3", day_of_month=f"*/{day}"), } if dify_config.ENABLE_CREATE_TIDB_SERVERLESS_TASK: imports.append("schedule.create_tidb_serverless_task") @@ -97,7 +129,7 @@ def init_app(app: DifyApp) -> Celery: imports.append("schedule.clean_messages") beat_schedule["clean_messages"] = { "task": "schedule.clean_messages.clean_messages", - "schedule": timedelta(days=day), + "schedule": crontab(minute="0", hour="4", day_of_month=f"*/{day}"), } if dify_config.ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK: imports.append("schedule.mail_clean_document_notify_task") @@ -113,13 +145,19 @@ def init_app(app: DifyApp) -> Celery: minutes=dify_config.QUEUE_MONITOR_INTERVAL if dify_config.QUEUE_MONITOR_INTERVAL else 30 ), } - if dify_config.ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK: + if dify_config.ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK and dify_config.MARKETPLACE_ENABLED: imports.append("schedule.check_upgradable_plugin_task") beat_schedule["check_upgradable_plugin_task"] = { "task": "schedule.check_upgradable_plugin_task.check_upgradable_plugin_task", "schedule": crontab(minute="*/15"), } - + if dify_config.WORKFLOW_LOG_CLEANUP_ENABLED: + # 2:00 AM every day + imports.append("schedule.clean_workflow_runlogs_precise") + beat_schedule["clean_workflow_runlogs_precise"] = { + "task": "schedule.clean_workflow_runlogs_precise.clean_workflow_runlogs_precise", + "schedule": crontab(minute="0", hour="2"), + } celery_app.conf.update(beat_schedule=beat_schedule, imports=imports) return celery_app diff --git a/api/extensions/ext_commands.py b/api/extensions/ext_commands.py index 600e336c19..8904ff7a92 100644 --- a/api/extensions/ext_commands.py +++ b/api/extensions/ext_commands.py @@ -4,6 +4,7 @@ from dify_app import DifyApp def init_app(app: DifyApp): from commands import ( add_qdrant_index, + cleanup_orphaned_draft_variables, clear_free_plan_tenant_expired_logs, clear_orphaned_file_records, convert_to_agent_apps, @@ -42,6 +43,7 @@ def init_app(app: DifyApp): clear_orphaned_file_records, remove_orphaned_files_on_storage, setup_system_tool_oauth_client, + cleanup_orphaned_draft_variables, ] for cmd in cmds_to_register: app.cli.add_command(cmd) diff --git a/api/extensions/ext_login.py b/api/extensions/ext_login.py index 9b18e25eaa..9e5c71fb1d 100644 --- a/api/extensions/ext_login.py +++ b/api/extensions/ext_login.py @@ -20,6 +20,10 @@ login_manager = flask_login.LoginManager() @login_manager.request_loader def load_user_from_request(request_from_flask_login): """Load user based on the request.""" + # Skip authentication for documentation endpoints + if request.path.endswith("/docs") or request.path.endswith("/swagger.json"): + return None + auth_header = request.headers.get("Authorization", "") auth_token: str | None = None if auth_header: diff --git a/api/extensions/ext_mail.py b/api/extensions/ext_mail.py index df5d8a9c11..fe05138196 100644 --- a/api/extensions/ext_mail.py +++ b/api/extensions/ext_mail.py @@ -64,7 +64,7 @@ class Mail: sendgrid_api_key=dify_config.SENDGRID_API_KEY, _from=dify_config.MAIL_DEFAULT_SEND_FROM or "" ) case _: - raise ValueError("Unsupported mail type {}".format(mail_type)) + raise ValueError(f"Unsupported mail type {mail_type}") def send(self, to: str, subject: str, html: str, from_: Optional[str] = None): if not self._client: diff --git a/api/extensions/ext_orjson.py b/api/extensions/ext_orjson.py new file mode 100644 index 0000000000..659784a585 --- /dev/null +++ b/api/extensions/ext_orjson.py @@ -0,0 +1,8 @@ +from flask_orjson import OrjsonProvider + +from dify_app import DifyApp + + +def init_app(app: DifyApp) -> None: + """Initialize Flask-Orjson extension for faster JSON serialization""" + app.json = OrjsonProvider(app) diff --git a/api/extensions/ext_otel.py b/api/extensions/ext_otel.py index b027a165f9..544a2dc625 100644 --- a/api/extensions/ext_otel.py +++ b/api/extensions/ext_otel.py @@ -1,4 +1,5 @@ import atexit +import contextlib import logging import os import platform @@ -7,7 +8,7 @@ import sys from typing import Union import flask -from celery.signals import worker_init # type: ignore +from celery.signals import worker_init from flask_login import user_loaded_from_request, user_logged_in # type: ignore from configs import dify_config @@ -106,7 +107,7 @@ def init_app(app: DifyApp): """Custom logging handler that creates spans for logging.exception() calls""" def emit(self, record: logging.LogRecord): - try: + with contextlib.suppress(Exception): if record.exc_info: tracer = get_tracer_provider().get_tracer("dify.exception.logging") with tracer.start_as_current_span( @@ -126,9 +127,6 @@ def init_app(app: DifyApp): if record.exc_info[0]: span.set_attribute("exception.type", record.exc_info[0].__name__) - except Exception: - pass - from opentelemetry import trace from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import OTLPMetricExporter as GRPCMetricExporter from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter as GRPCSpanExporter @@ -136,6 +134,8 @@ def init_app(app: DifyApp): from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter as HTTPSpanExporter from opentelemetry.instrumentation.celery import CeleryInstrumentor from opentelemetry.instrumentation.flask import FlaskInstrumentor + from opentelemetry.instrumentation.redis import RedisInstrumentor + from opentelemetry.instrumentation.requests import RequestsInstrumentor from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor from opentelemetry.metrics import get_meter, get_meter_provider, set_meter_provider from opentelemetry.propagate import set_global_textmap @@ -234,6 +234,8 @@ def init_app(app: DifyApp): CeleryInstrumentor(tracer_provider=get_tracer_provider(), meter_provider=get_meter_provider()).instrument() instrument_exception_logging() init_sqlalchemy_instrumentor(app) + RedisInstrumentor().instrument() + RequestsInstrumentor().instrument() atexit.register(shutdown_tracer) diff --git a/api/extensions/ext_redis.py b/api/extensions/ext_redis.py index be2f6115f7..1b22886fc1 100644 --- a/api/extensions/ext_redis.py +++ b/api/extensions/ext_redis.py @@ -1,18 +1,24 @@ import functools import logging +import ssl from collections.abc import Callable -from typing import Any, Union +from datetime import timedelta +from typing import TYPE_CHECKING, Any, Optional, Union import redis from redis import RedisError from redis.cache import CacheConfig from redis.cluster import ClusterNode, RedisCluster from redis.connection import Connection, SSLConnection +from redis.lock import Lock from redis.sentinel import Sentinel from configs import dify_config from dify_app import DifyApp +if TYPE_CHECKING: + from redis.lock import Lock + logger = logging.getLogger(__name__) @@ -28,8 +34,8 @@ class RedisClientWrapper: a failover in a Sentinel-managed Redis setup. Attributes: - _client (redis.Redis): The actual Redis client instance. It remains None until - initialized with the `initialize` method. + _client: The actual Redis client instance. It remains None until + initialized with the `initialize` method. Methods: initialize(client): Initializes the Redis client if it hasn't been initialized already. @@ -37,93 +43,210 @@ class RedisClientWrapper: if the client is not initialized. """ - def __init__(self): + _client: Union[redis.Redis, RedisCluster, None] + + def __init__(self) -> None: self._client = None - def initialize(self, client): + def initialize(self, client: Union[redis.Redis, RedisCluster]) -> None: if self._client is None: self._client = client - def __getattr__(self, item): + if TYPE_CHECKING: + # Type hints for IDE support and static analysis + # These are not executed at runtime but provide type information + def get(self, name: str | bytes) -> Any: ... + + def set( + self, + name: str | bytes, + value: Any, + ex: int | None = None, + px: int | None = None, + nx: bool = False, + xx: bool = False, + keepttl: bool = False, + get: bool = False, + exat: int | None = None, + pxat: int | None = None, + ) -> Any: ... + + def setex(self, name: str | bytes, time: int | timedelta, value: Any) -> Any: ... + def setnx(self, name: str | bytes, value: Any) -> Any: ... + def delete(self, *names: str | bytes) -> Any: ... + def incr(self, name: str | bytes, amount: int = 1) -> Any: ... + def expire( + self, + name: str | bytes, + time: int | timedelta, + nx: bool = False, + xx: bool = False, + gt: bool = False, + lt: bool = False, + ) -> Any: ... + def lock( + self, + name: str, + timeout: float | None = None, + sleep: float = 0.1, + blocking: bool = True, + blocking_timeout: float | None = None, + thread_local: bool = True, + ) -> Lock: ... + def zadd( + self, + name: str | bytes, + mapping: dict[str | bytes | int | float, float | int | str | bytes], + nx: bool = False, + xx: bool = False, + ch: bool = False, + incr: bool = False, + gt: bool = False, + lt: bool = False, + ) -> Any: ... + def zremrangebyscore(self, name: str | bytes, min: float | str, max: float | str) -> Any: ... + def zcard(self, name: str | bytes) -> Any: ... + def getdel(self, name: str | bytes) -> Any: ... + + def __getattr__(self, item: str) -> Any: if self._client is None: raise RuntimeError("Redis client is not initialized. Call init_app first.") return getattr(self._client, item) -redis_client = RedisClientWrapper() +redis_client: RedisClientWrapper = RedisClientWrapper() -def init_app(app: DifyApp): - global redis_client - connection_class: type[Union[Connection, SSLConnection]] = Connection - if dify_config.REDIS_USE_SSL: - connection_class = SSLConnection +def _get_ssl_configuration() -> tuple[type[Union[Connection, SSLConnection]], dict[str, Any]]: + """Get SSL configuration for Redis connection.""" + if not dify_config.REDIS_USE_SSL: + return Connection, {} + + cert_reqs_map = { + "CERT_NONE": ssl.CERT_NONE, + "CERT_OPTIONAL": ssl.CERT_OPTIONAL, + "CERT_REQUIRED": ssl.CERT_REQUIRED, + } + ssl_cert_reqs = cert_reqs_map.get(dify_config.REDIS_SSL_CERT_REQS, ssl.CERT_NONE) + + ssl_kwargs = { + "ssl_cert_reqs": ssl_cert_reqs, + "ssl_ca_certs": dify_config.REDIS_SSL_CA_CERTS, + "ssl_certfile": dify_config.REDIS_SSL_CERTFILE, + "ssl_keyfile": dify_config.REDIS_SSL_KEYFILE, + } + + return SSLConnection, ssl_kwargs + + +def _get_cache_configuration() -> CacheConfig | None: + """Get client-side cache configuration if enabled.""" + if not dify_config.REDIS_ENABLE_CLIENT_SIDE_CACHE: + return None + resp_protocol = dify_config.REDIS_SERIALIZATION_PROTOCOL - if dify_config.REDIS_ENABLE_CLIENT_SIDE_CACHE: - if resp_protocol >= 3: - clientside_cache_config = CacheConfig() - else: - raise ValueError("Client side cache is only supported in RESP3") - else: - clientside_cache_config = None + if resp_protocol < 3: + raise ValueError("Client side cache is only supported in RESP3") - redis_params: dict[str, Any] = { + return CacheConfig() + + +def _get_base_redis_params() -> dict[str, Any]: + """Get base Redis connection parameters.""" + return { "username": dify_config.REDIS_USERNAME, - "password": dify_config.REDIS_PASSWORD or None, # Temporary fix for empty password + "password": dify_config.REDIS_PASSWORD or None, "db": dify_config.REDIS_DB, "encoding": "utf-8", "encoding_errors": "strict", "decode_responses": False, - "protocol": resp_protocol, - "cache_config": clientside_cache_config, + "protocol": dify_config.REDIS_SERIALIZATION_PROTOCOL, + "cache_config": _get_cache_configuration(), } - if dify_config.REDIS_USE_SENTINEL: - assert dify_config.REDIS_SENTINELS is not None, "REDIS_SENTINELS must be set when REDIS_USE_SENTINEL is True" - sentinel_hosts = [ - (node.split(":")[0], int(node.split(":")[1])) for node in dify_config.REDIS_SENTINELS.split(",") - ] - sentinel = Sentinel( - sentinel_hosts, - sentinel_kwargs={ - "socket_timeout": dify_config.REDIS_SENTINEL_SOCKET_TIMEOUT, - "username": dify_config.REDIS_SENTINEL_USERNAME, - "password": dify_config.REDIS_SENTINEL_PASSWORD, - }, - ) - master = sentinel.master_for(dify_config.REDIS_SENTINEL_SERVICE_NAME, **redis_params) - redis_client.initialize(master) - elif dify_config.REDIS_USE_CLUSTERS: - assert dify_config.REDIS_CLUSTERS is not None, "REDIS_CLUSTERS must be set when REDIS_USE_CLUSTERS is True" - nodes = [ - ClusterNode(host=node.split(":")[0], port=int(node.split(":")[1])) - for node in dify_config.REDIS_CLUSTERS.split(",") - ] - redis_client.initialize( - RedisCluster( - startup_nodes=nodes, - password=dify_config.REDIS_CLUSTERS_PASSWORD, - protocol=resp_protocol, - cache_config=clientside_cache_config, - ) - ) - else: - redis_params.update( - { - "host": dify_config.REDIS_HOST, - "port": dify_config.REDIS_PORT, - "connection_class": connection_class, - "protocol": resp_protocol, - "cache_config": clientside_cache_config, - } - ) - pool = redis.ConnectionPool(**redis_params) - redis_client.initialize(redis.Redis(connection_pool=pool)) +def _create_sentinel_client(redis_params: dict[str, Any]) -> Union[redis.Redis, RedisCluster]: + """Create Redis client using Sentinel configuration.""" + if not dify_config.REDIS_SENTINELS: + raise ValueError("REDIS_SENTINELS must be set when REDIS_USE_SENTINEL is True") + + if not dify_config.REDIS_SENTINEL_SERVICE_NAME: + raise ValueError("REDIS_SENTINEL_SERVICE_NAME must be set when REDIS_USE_SENTINEL is True") + + sentinel_hosts = [(node.split(":")[0], int(node.split(":")[1])) for node in dify_config.REDIS_SENTINELS.split(",")] + + sentinel = Sentinel( + sentinel_hosts, + sentinel_kwargs={ + "socket_timeout": dify_config.REDIS_SENTINEL_SOCKET_TIMEOUT, + "username": dify_config.REDIS_SENTINEL_USERNAME, + "password": dify_config.REDIS_SENTINEL_PASSWORD, + }, + ) + + master: redis.Redis = sentinel.master_for(dify_config.REDIS_SENTINEL_SERVICE_NAME, **redis_params) + return master + + +def _create_cluster_client() -> Union[redis.Redis, RedisCluster]: + """Create Redis cluster client.""" + if not dify_config.REDIS_CLUSTERS: + raise ValueError("REDIS_CLUSTERS must be set when REDIS_USE_CLUSTERS is True") + + nodes = [ + ClusterNode(host=node.split(":")[0], port=int(node.split(":")[1])) + for node in dify_config.REDIS_CLUSTERS.split(",") + ] + + cluster: RedisCluster = RedisCluster( + startup_nodes=nodes, + password=dify_config.REDIS_CLUSTERS_PASSWORD, + protocol=dify_config.REDIS_SERIALIZATION_PROTOCOL, + cache_config=_get_cache_configuration(), + ) + return cluster + + +def _create_standalone_client(redis_params: dict[str, Any]) -> Union[redis.Redis, RedisCluster]: + """Create standalone Redis client.""" + connection_class, ssl_kwargs = _get_ssl_configuration() + + redis_params.update( + { + "host": dify_config.REDIS_HOST, + "port": dify_config.REDIS_PORT, + "connection_class": connection_class, + } + ) + + if ssl_kwargs: + redis_params.update(ssl_kwargs) + + pool = redis.ConnectionPool(**redis_params) + client: redis.Redis = redis.Redis(connection_pool=pool) + return client + + +def init_app(app: DifyApp): + """Initialize Redis client and attach it to the app.""" + global redis_client + + # Determine Redis mode and create appropriate client + if dify_config.REDIS_USE_SENTINEL: + redis_params = _get_base_redis_params() + client = _create_sentinel_client(redis_params) + elif dify_config.REDIS_USE_CLUSTERS: + client = _create_cluster_client() + else: + redis_params = _get_base_redis_params() + client = _create_standalone_client(redis_params) + + # Initialize the wrapper and attach to app + redis_client.initialize(client) app.extensions["redis"] = redis_client -def redis_fallback(default_return: Any = None): +def redis_fallback(default_return: Optional[Any] = None): """ decorator to handle Redis operation exceptions and return a default value when Redis is unavailable. @@ -137,7 +260,7 @@ def redis_fallback(default_return: Any = None): try: return func(*args, **kwargs) except RedisError as e: - logger.warning(f"Redis operation failed in {func.__name__}: {str(e)}", exc_info=True) + logger.warning("Redis operation failed in %s: %s", func.__name__, str(e), exc_info=True) return default_return return wrapper diff --git a/api/extensions/ext_storage.py b/api/extensions/ext_storage.py index bd35278544..d13393dd14 100644 --- a/api/extensions/ext_storage.py +++ b/api/extensions/ext_storage.py @@ -69,6 +69,19 @@ class Storage: from extensions.storage.supabase_storage import SupabaseStorage return SupabaseStorage + case StorageType.CLICKZETTA_VOLUME: + from extensions.storage.clickzetta_volume.clickzetta_volume_storage import ( + ClickZettaVolumeConfig, + ClickZettaVolumeStorage, + ) + + def create_clickzetta_volume_storage(): + # ClickZettaVolumeConfig will automatically read from environment variables + # and fallback to CLICKZETTA_* config if CLICKZETTA_VOLUME_* is not set + volume_config = ClickZettaVolumeConfig() + return ClickZettaVolumeStorage(volume_config) + + return create_clickzetta_volume_storage case _: raise ValueError(f"unsupported storage type {storage_type}") diff --git a/api/extensions/storage/azure_blob_storage.py b/api/extensions/storage/azure_blob_storage.py index 81eec94da4..7ec0889776 100644 --- a/api/extensions/storage/azure_blob_storage.py +++ b/api/extensions/storage/azure_blob_storage.py @@ -69,7 +69,7 @@ class AzureBlobStorage(BaseStorage): if self.account_key == "managedidentity": return BlobServiceClient(account_url=self.account_url, credential=self.credential) # type: ignore - cache_key = "azure_blob_sas_token_{}_{}".format(self.account_name, self.account_key) + cache_key = f"azure_blob_sas_token_{self.account_name}_{self.account_key}" cache_result = redis_client.get(cache_key) if cache_result is not None: sas_token = cache_result.decode("utf-8") diff --git a/api/extensions/storage/clickzetta_volume/__init__.py b/api/extensions/storage/clickzetta_volume/__init__.py new file mode 100644 index 0000000000..8a1588034b --- /dev/null +++ b/api/extensions/storage/clickzetta_volume/__init__.py @@ -0,0 +1,5 @@ +"""ClickZetta Volume storage implementation.""" + +from .clickzetta_volume_storage import ClickZettaVolumeStorage + +__all__ = ["ClickZettaVolumeStorage"] diff --git a/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py b/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py new file mode 100644 index 0000000000..09ab37f42e --- /dev/null +++ b/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py @@ -0,0 +1,530 @@ +"""ClickZetta Volume Storage Implementation + +This module provides storage backend using ClickZetta Volume functionality. +Supports Table Volume, User Volume, and External Volume types. +""" + +import logging +import os +import tempfile +from collections.abc import Generator +from io import BytesIO +from pathlib import Path +from typing import Optional + +import clickzetta # type: ignore[import] +from pydantic import BaseModel, model_validator + +from extensions.storage.base_storage import BaseStorage + +from .volume_permissions import VolumePermissionManager, check_volume_permission + +logger = logging.getLogger(__name__) + + +class ClickZettaVolumeConfig(BaseModel): + """Configuration for ClickZetta Volume storage.""" + + username: str = "" + password: str = "" + instance: str = "" + service: str = "api.clickzetta.com" + workspace: str = "quick_start" + vcluster: str = "default_ap" + schema_name: str = "dify" + volume_type: str = "table" # table|user|external + volume_name: Optional[str] = None # For external volumes + table_prefix: str = "dataset_" # Prefix for table volume names + dify_prefix: str = "dify_km" # Directory prefix for User Volume + permission_check: bool = True # Enable/disable permission checking + + @model_validator(mode="before") + @classmethod + def validate_config(cls, values: dict) -> dict: + """Validate the configuration values. + + This method will first try to use CLICKZETTA_VOLUME_* environment variables, + then fall back to CLICKZETTA_* environment variables (for vector DB config). + """ + import os + + # Helper function to get environment variable with fallback + def get_env_with_fallback(volume_key: str, fallback_key: str, default: str | None = None) -> str: + # First try CLICKZETTA_VOLUME_* specific config + volume_value = values.get(volume_key.lower().replace("clickzetta_volume_", "")) + if volume_value: + return str(volume_value) + + # Then try environment variables + volume_env = os.getenv(volume_key) + if volume_env: + return volume_env + + # Fall back to existing CLICKZETTA_* config + fallback_env = os.getenv(fallback_key) + if fallback_env: + return fallback_env + + return default or "" + + # Apply environment variables with fallback to existing CLICKZETTA_* config + values.setdefault("username", get_env_with_fallback("CLICKZETTA_VOLUME_USERNAME", "CLICKZETTA_USERNAME")) + values.setdefault("password", get_env_with_fallback("CLICKZETTA_VOLUME_PASSWORD", "CLICKZETTA_PASSWORD")) + values.setdefault("instance", get_env_with_fallback("CLICKZETTA_VOLUME_INSTANCE", "CLICKZETTA_INSTANCE")) + values.setdefault( + "service", get_env_with_fallback("CLICKZETTA_VOLUME_SERVICE", "CLICKZETTA_SERVICE", "api.clickzetta.com") + ) + values.setdefault( + "workspace", get_env_with_fallback("CLICKZETTA_VOLUME_WORKSPACE", "CLICKZETTA_WORKSPACE", "quick_start") + ) + values.setdefault( + "vcluster", get_env_with_fallback("CLICKZETTA_VOLUME_VCLUSTER", "CLICKZETTA_VCLUSTER", "default_ap") + ) + values.setdefault("schema_name", get_env_with_fallback("CLICKZETTA_VOLUME_SCHEMA", "CLICKZETTA_SCHEMA", "dify")) + + # Volume-specific configurations (no fallback to vector DB config) + values.setdefault("volume_type", os.getenv("CLICKZETTA_VOLUME_TYPE", "table")) + values.setdefault("volume_name", os.getenv("CLICKZETTA_VOLUME_NAME")) + values.setdefault("table_prefix", os.getenv("CLICKZETTA_VOLUME_TABLE_PREFIX", "dataset_")) + values.setdefault("dify_prefix", os.getenv("CLICKZETTA_VOLUME_DIFY_PREFIX", "dify_km")) + # 暂时禁用权限检查功能,直接设置为false + values.setdefault("permission_check", False) + + # Validate required fields + if not values.get("username"): + raise ValueError("CLICKZETTA_VOLUME_USERNAME or CLICKZETTA_USERNAME is required") + if not values.get("password"): + raise ValueError("CLICKZETTA_VOLUME_PASSWORD or CLICKZETTA_PASSWORD is required") + if not values.get("instance"): + raise ValueError("CLICKZETTA_VOLUME_INSTANCE or CLICKZETTA_INSTANCE is required") + + # Validate volume type + volume_type = values["volume_type"] + if volume_type not in ["table", "user", "external"]: + raise ValueError("CLICKZETTA_VOLUME_TYPE must be one of: table, user, external") + + if volume_type == "external" and not values.get("volume_name"): + raise ValueError("CLICKZETTA_VOLUME_NAME is required for external volume type") + + return values + + +class ClickZettaVolumeStorage(BaseStorage): + """ClickZetta Volume storage implementation.""" + + def __init__(self, config: ClickZettaVolumeConfig): + """Initialize ClickZetta Volume storage. + + Args: + config: ClickZetta Volume configuration + """ + self._config = config + self._connection = None + self._permission_manager: VolumePermissionManager | None = None + self._init_connection() + self._init_permission_manager() + + logger.info("ClickZetta Volume storage initialized with type: %s", config.volume_type) + + def _init_connection(self): + """Initialize ClickZetta connection.""" + try: + self._connection = clickzetta.connect( + username=self._config.username, + password=self._config.password, + instance=self._config.instance, + service=self._config.service, + workspace=self._config.workspace, + vcluster=self._config.vcluster, + schema=self._config.schema_name, + ) + logger.debug("ClickZetta connection established") + except Exception as e: + logger.exception("Failed to connect to ClickZetta") + raise + + def _init_permission_manager(self): + """Initialize permission manager.""" + try: + self._permission_manager = VolumePermissionManager( + self._connection, self._config.volume_type, self._config.volume_name + ) + logger.debug("Permission manager initialized") + except Exception as e: + logger.exception("Failed to initialize permission manager") + raise + + def _get_volume_path(self, filename: str, dataset_id: Optional[str] = None) -> str: + """Get the appropriate volume path based on volume type.""" + if self._config.volume_type == "user": + # Add dify prefix for User Volume to organize files + return f"{self._config.dify_prefix}/{filename}" + elif self._config.volume_type == "table": + # Check if this should use User Volume (special directories) + if dataset_id in ["upload_files", "temp", "cache", "tools", "website_files", "privkeys"]: + # Use User Volume with dify prefix for special directories + return f"{self._config.dify_prefix}/{filename}" + + if dataset_id: + return f"{self._config.table_prefix}{dataset_id}/{filename}" + else: + # Extract dataset_id from filename if not provided + # Format: dataset_id/filename + if "/" in filename: + return filename + else: + raise ValueError("dataset_id is required for table volume or filename must include dataset_id/") + elif self._config.volume_type == "external": + return filename + else: + raise ValueError(f"Unsupported volume type: {self._config.volume_type}") + + def _get_volume_sql_prefix(self, dataset_id: Optional[str] = None) -> str: + """Get SQL prefix for volume operations.""" + if self._config.volume_type == "user": + return "USER VOLUME" + elif self._config.volume_type == "table": + # For Dify's current file storage pattern, most files are stored in + # paths like "upload_files/tenant_id/uuid.ext", "tools/tenant_id/uuid.ext" + # These should use USER VOLUME for better compatibility + if dataset_id in ["upload_files", "temp", "cache", "tools", "website_files", "privkeys"]: + return "USER VOLUME" + + # Only use TABLE VOLUME for actual dataset-specific paths + # like "dataset_12345/file.pdf" or paths with dataset_ prefix + if dataset_id: + table_name = f"{self._config.table_prefix}{dataset_id}" + else: + # Default table name for generic operations + table_name = "default_dataset" + return f"TABLE VOLUME {table_name}" + elif self._config.volume_type == "external": + return f"VOLUME {self._config.volume_name}" + else: + raise ValueError(f"Unsupported volume type: {self._config.volume_type}") + + def _execute_sql(self, sql: str, fetch: bool = False): + """Execute SQL command.""" + try: + if self._connection is None: + raise RuntimeError("Connection not initialized") + with self._connection.cursor() as cursor: + cursor.execute(sql) + if fetch: + return cursor.fetchall() + return None + except Exception as e: + logger.exception("SQL execution failed: %s", sql) + raise + + def _ensure_table_volume_exists(self, dataset_id: str) -> None: + """Ensure table volume exists for the given dataset_id.""" + if self._config.volume_type != "table" or not dataset_id: + return + + # Skip for upload_files and other special directories that use USER VOLUME + if dataset_id in ["upload_files", "temp", "cache", "tools", "website_files", "privkeys"]: + return + + table_name = f"{self._config.table_prefix}{dataset_id}" + + try: + # Check if table exists + check_sql = f"SHOW TABLES LIKE '{table_name}'" + result = self._execute_sql(check_sql, fetch=True) + + if not result: + # Create table with volume + create_sql = f""" + CREATE TABLE {table_name} ( + id INT PRIMARY KEY AUTO_INCREMENT, + filename VARCHAR(255) NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + INDEX idx_filename (filename) + ) WITH VOLUME + """ + self._execute_sql(create_sql) + logger.info("Created table volume: %s", table_name) + + except Exception as e: + logger.warning("Failed to create table volume %s: %s", table_name, e) + # Don't raise exception, let the operation continue + # The table might exist but not be visible due to permissions + + def save(self, filename: str, data: bytes) -> None: + """Save data to ClickZetta Volume. + + Args: + filename: File path in volume + data: File content as bytes + """ + # Extract dataset_id from filename if present + dataset_id = None + if "/" in filename and self._config.volume_type == "table": + parts = filename.split("/", 1) + if parts[0].startswith(self._config.table_prefix): + dataset_id = parts[0][len(self._config.table_prefix) :] + filename = parts[1] + else: + dataset_id = parts[0] + filename = parts[1] + + # Ensure table volume exists (for table volumes) + if dataset_id: + self._ensure_table_volume_exists(dataset_id) + + # Check permissions (if enabled) + if self._config.permission_check: + # Skip permission check for special directories that use USER VOLUME + if dataset_id not in ["upload_files", "temp", "cache", "tools", "website_files", "privkeys"]: + if self._permission_manager is not None: + check_volume_permission(self._permission_manager, "save", dataset_id) + + # Write data to temporary file + with tempfile.NamedTemporaryFile(delete=False) as temp_file: + temp_file.write(data) + temp_file_path = temp_file.name + + try: + # Upload to volume + volume_prefix = self._get_volume_sql_prefix(dataset_id) + + # Get the actual volume path (may include dify_km prefix) + volume_path = self._get_volume_path(filename, dataset_id) + actual_filename = volume_path.split("/")[-1] if "/" in volume_path else volume_path + + # For User Volume, use the full path with dify_km prefix + if volume_prefix == "USER VOLUME": + sql = f"PUT '{temp_file_path}' TO {volume_prefix} FILE '{volume_path}'" + else: + sql = f"PUT '{temp_file_path}' TO {volume_prefix} FILE '{filename}'" + + self._execute_sql(sql) + logger.debug("File %s saved to ClickZetta Volume at path %s", filename, volume_path) + finally: + # Clean up temporary file + Path(temp_file_path).unlink(missing_ok=True) + + def load_once(self, filename: str) -> bytes: + """Load file content from ClickZetta Volume. + + Args: + filename: File path in volume + + Returns: + File content as bytes + """ + # Extract dataset_id from filename if present + dataset_id = None + if "/" in filename and self._config.volume_type == "table": + parts = filename.split("/", 1) + if parts[0].startswith(self._config.table_prefix): + dataset_id = parts[0][len(self._config.table_prefix) :] + filename = parts[1] + else: + dataset_id = parts[0] + filename = parts[1] + + # Check permissions (if enabled) + if self._config.permission_check: + # Skip permission check for special directories that use USER VOLUME + if dataset_id not in ["upload_files", "temp", "cache", "tools", "website_files", "privkeys"]: + if self._permission_manager is not None: + check_volume_permission(self._permission_manager, "load_once", dataset_id) + + # Download to temporary directory + with tempfile.TemporaryDirectory() as temp_dir: + volume_prefix = self._get_volume_sql_prefix(dataset_id) + + # Get the actual volume path (may include dify_km prefix) + volume_path = self._get_volume_path(filename, dataset_id) + + # For User Volume, use the full path with dify_km prefix + if volume_prefix == "USER VOLUME": + sql = f"GET {volume_prefix} FILE '{volume_path}' TO '{temp_dir}'" + else: + sql = f"GET {volume_prefix} FILE '{filename}' TO '{temp_dir}'" + + self._execute_sql(sql) + + # Find the downloaded file (may be in subdirectories) + downloaded_file = None + for root, dirs, files in os.walk(temp_dir): + for file in files: + if file == filename or file == os.path.basename(filename): + downloaded_file = Path(root) / file + break + if downloaded_file: + break + + if not downloaded_file or not downloaded_file.exists(): + raise FileNotFoundError(f"Downloaded file not found: {filename}") + + content = downloaded_file.read_bytes() + + logger.debug("File %s loaded from ClickZetta Volume", filename) + return content + + def load_stream(self, filename: str) -> Generator: + """Load file as stream from ClickZetta Volume. + + Args: + filename: File path in volume + + Yields: + File content chunks + """ + content = self.load_once(filename) + batch_size = 4096 + stream = BytesIO(content) + + while chunk := stream.read(batch_size): + yield chunk + + logger.debug("File %s loaded as stream from ClickZetta Volume", filename) + + def download(self, filename: str, target_filepath: str): + """Download file from ClickZetta Volume to local path. + + Args: + filename: File path in volume + target_filepath: Local target file path + """ + content = self.load_once(filename) + + with Path(target_filepath).open("wb") as f: + f.write(content) + + logger.debug("File %s downloaded from ClickZetta Volume to %s", filename, target_filepath) + + def exists(self, filename: str) -> bool: + """Check if file exists in ClickZetta Volume. + + Args: + filename: File path in volume + + Returns: + True if file exists, False otherwise + """ + try: + # Extract dataset_id from filename if present + dataset_id = None + if "/" in filename and self._config.volume_type == "table": + parts = filename.split("/", 1) + if parts[0].startswith(self._config.table_prefix): + dataset_id = parts[0][len(self._config.table_prefix) :] + filename = parts[1] + else: + dataset_id = parts[0] + filename = parts[1] + + volume_prefix = self._get_volume_sql_prefix(dataset_id) + + # Get the actual volume path (may include dify_km prefix) + volume_path = self._get_volume_path(filename, dataset_id) + + # For User Volume, use the full path with dify_km prefix + if volume_prefix == "USER VOLUME": + sql = f"LIST {volume_prefix} REGEXP = '^{volume_path}$'" + else: + sql = f"LIST {volume_prefix} REGEXP = '^{filename}$'" + + rows = self._execute_sql(sql, fetch=True) + + exists = len(rows) > 0 + logger.debug("File %s exists check: %s", filename, exists) + return exists + except Exception as e: + logger.warning("Error checking file existence for %s: %s", filename, e) + return False + + def delete(self, filename: str): + """Delete file from ClickZetta Volume. + + Args: + filename: File path in volume + """ + if not self.exists(filename): + logger.debug("File %s not found, skip delete", filename) + return + + # Extract dataset_id from filename if present + dataset_id = None + if "/" in filename and self._config.volume_type == "table": + parts = filename.split("/", 1) + if parts[0].startswith(self._config.table_prefix): + dataset_id = parts[0][len(self._config.table_prefix) :] + filename = parts[1] + else: + dataset_id = parts[0] + filename = parts[1] + + volume_prefix = self._get_volume_sql_prefix(dataset_id) + + # Get the actual volume path (may include dify_km prefix) + volume_path = self._get_volume_path(filename, dataset_id) + + # For User Volume, use the full path with dify_km prefix + if volume_prefix == "USER VOLUME": + sql = f"REMOVE {volume_prefix} FILE '{volume_path}'" + else: + sql = f"REMOVE {volume_prefix} FILE '{filename}'" + + self._execute_sql(sql) + + logger.debug("File %s deleted from ClickZetta Volume", filename) + + def scan(self, path: str, files: bool = True, directories: bool = False) -> list[str]: + """Scan files and directories in ClickZetta Volume. + + Args: + path: Path to scan (dataset_id for table volumes) + files: Include files in results + directories: Include directories in results + + Returns: + List of file/directory paths + """ + try: + # For table volumes, path is treated as dataset_id + dataset_id = None + if self._config.volume_type == "table": + dataset_id = path + path = "" # Root of the table volume + + volume_prefix = self._get_volume_sql_prefix(dataset_id) + + # For User Volume, add dify prefix to path + if volume_prefix == "USER VOLUME": + if path: + scan_path = f"{self._config.dify_prefix}/{path}" + sql = f"LIST {volume_prefix} SUBDIRECTORY '{scan_path}'" + else: + sql = f"LIST {volume_prefix} SUBDIRECTORY '{self._config.dify_prefix}'" + else: + if path: + sql = f"LIST {volume_prefix} SUBDIRECTORY '{path}'" + else: + sql = f"LIST {volume_prefix}" + + rows = self._execute_sql(sql, fetch=True) + + result = [] + for row in rows: + file_path = row[0] # relative_path column + + # For User Volume, remove dify prefix from results + dify_prefix_with_slash = f"{self._config.dify_prefix}/" + if volume_prefix == "USER VOLUME" and file_path.startswith(dify_prefix_with_slash): + file_path = file_path[len(dify_prefix_with_slash) :] # Remove prefix + + if files and not file_path.endswith("/") or directories and file_path.endswith("/"): + result.append(file_path) + + logger.debug("Scanned %d items in path %s", len(result), path) + return result + + except Exception as e: + logger.exception("Error scanning path %s", path) + return [] diff --git a/api/extensions/storage/clickzetta_volume/file_lifecycle.py b/api/extensions/storage/clickzetta_volume/file_lifecycle.py new file mode 100644 index 0000000000..d5d04f121b --- /dev/null +++ b/api/extensions/storage/clickzetta_volume/file_lifecycle.py @@ -0,0 +1,516 @@ +"""ClickZetta Volume文件生命周期管理 + +该模块提供文件版本控制、自动清理、备份和恢复等生命周期管理功能。 +支持知识库文件的完整生命周期管理。 +""" + +import json +import logging +from dataclasses import asdict, dataclass +from datetime import datetime, timedelta +from enum import Enum +from typing import Any, Optional + +logger = logging.getLogger(__name__) + + +class FileStatus(Enum): + """文件状态枚举""" + + ACTIVE = "active" # 活跃状态 + ARCHIVED = "archived" # 已归档 + DELETED = "deleted" # 已删除(软删除) + BACKUP = "backup" # 备份文件 + + +@dataclass +class FileMetadata: + """文件元数据""" + + filename: str + size: int | None + created_at: datetime + modified_at: datetime + version: int | None + status: FileStatus + checksum: Optional[str] = None + tags: Optional[dict[str, str]] = None + parent_version: Optional[int] = None + + def to_dict(self) -> dict: + """转换为字典格式""" + data = asdict(self) + data["created_at"] = self.created_at.isoformat() + data["modified_at"] = self.modified_at.isoformat() + data["status"] = self.status.value + return data + + @classmethod + def from_dict(cls, data: dict) -> "FileMetadata": + """从字典创建实例""" + data = data.copy() + data["created_at"] = datetime.fromisoformat(data["created_at"]) + data["modified_at"] = datetime.fromisoformat(data["modified_at"]) + data["status"] = FileStatus(data["status"]) + return cls(**data) + + +class FileLifecycleManager: + """文件生命周期管理器""" + + def __init__(self, storage, dataset_id: Optional[str] = None): + """初始化生命周期管理器 + + Args: + storage: ClickZetta Volume存储实例 + dataset_id: 数据集ID(用于Table Volume) + """ + self._storage = storage + self._dataset_id = dataset_id + self._metadata_file = ".dify_file_metadata.json" + self._version_prefix = ".versions/" + self._backup_prefix = ".backups/" + self._deleted_prefix = ".deleted/" + + # 获取权限管理器(如果存在) + self._permission_manager: Optional[Any] = getattr(storage, "_permission_manager", None) + + def save_with_lifecycle(self, filename: str, data: bytes, tags: Optional[dict[str, str]] = None) -> FileMetadata: + """保存文件并管理生命周期 + + Args: + filename: 文件名 + data: 文件内容 + tags: 文件标签 + + Returns: + 文件元数据 + """ + # 权限检查 + if not self._check_permission(filename, "save"): + from .volume_permissions import VolumePermissionError + + raise VolumePermissionError( + f"Permission denied for lifecycle save operation on file: {filename}", + operation="save", + volume_type=getattr(self._storage, "_config", {}).get("volume_type", "unknown"), + dataset_id=self._dataset_id, + ) + + try: + # 1. 检查是否存在旧版本 + metadata_dict = self._load_metadata() + current_metadata = metadata_dict.get(filename) + + # 2. 如果存在旧版本,创建版本备份 + if current_metadata: + self._create_version_backup(filename, current_metadata) + + # 3. 计算文件信息 + now = datetime.now() + checksum = self._calculate_checksum(data) + new_version = (current_metadata["version"] + 1) if current_metadata else 1 + + # 4. 保存新文件 + self._storage.save(filename, data) + + # 5. 创建元数据 + created_at = now + parent_version = None + + if current_metadata: + # 如果created_at是字符串,转换为datetime + if isinstance(current_metadata["created_at"], str): + created_at = datetime.fromisoformat(current_metadata["created_at"]) + else: + created_at = current_metadata["created_at"] + parent_version = current_metadata["version"] + + file_metadata = FileMetadata( + filename=filename, + size=len(data), + created_at=created_at, + modified_at=now, + version=new_version, + status=FileStatus.ACTIVE, + checksum=checksum, + tags=tags or {}, + parent_version=parent_version, + ) + + # 6. 更新元数据 + metadata_dict[filename] = file_metadata.to_dict() + self._save_metadata(metadata_dict) + + logger.info("File %s saved with lifecycle management, version %s", filename, new_version) + return file_metadata + + except Exception as e: + logger.exception("Failed to save file with lifecycle") + raise + + def get_file_metadata(self, filename: str) -> Optional[FileMetadata]: + """获取文件元数据 + + Args: + filename: 文件名 + + Returns: + 文件元数据,如果不存在返回None + """ + try: + metadata_dict = self._load_metadata() + if filename in metadata_dict: + return FileMetadata.from_dict(metadata_dict[filename]) + return None + except Exception as e: + logger.exception("Failed to get file metadata for %s", filename) + return None + + def list_file_versions(self, filename: str) -> list[FileMetadata]: + """列出文件的所有版本 + + Args: + filename: 文件名 + + Returns: + 文件版本列表,按版本号排序 + """ + try: + versions = [] + + # 获取当前版本 + current_metadata = self.get_file_metadata(filename) + if current_metadata: + versions.append(current_metadata) + + # 获取历史版本 + version_pattern = f"{self._version_prefix}{filename}.v*" + try: + version_files = self._storage.scan(self._dataset_id or "", files=True) + for file_path in version_files: + if file_path.startswith(f"{self._version_prefix}{filename}.v"): + # 解析版本号 + version_str = file_path.split(".v")[-1].split(".")[0] + try: + version_num = int(version_str) + # 这里简化处理,实际应该从版本文件中读取元数据 + # 暂时创建基本的元数据信息 + except ValueError: + continue + except: + # 如果无法扫描版本文件,只返回当前版本 + pass + + return sorted(versions, key=lambda x: x.version or 0, reverse=True) + + except Exception as e: + logger.exception("Failed to list file versions for %s", filename) + return [] + + def restore_version(self, filename: str, version: int) -> bool: + """恢复文件到指定版本 + + Args: + filename: 文件名 + version: 要恢复的版本号 + + Returns: + 恢复是否成功 + """ + try: + version_filename = f"{self._version_prefix}{filename}.v{version}" + + # 检查版本文件是否存在 + if not self._storage.exists(version_filename): + logger.warning("Version %s of %s not found", version, filename) + return False + + # 读取版本文件内容 + version_data = self._storage.load_once(version_filename) + + # 保存当前版本为备份 + current_metadata = self.get_file_metadata(filename) + if current_metadata: + self._create_version_backup(filename, current_metadata.to_dict()) + + # 恢复文件 + self.save_with_lifecycle(filename, version_data, {"restored_from": str(version)}) + return True + + except Exception as e: + logger.exception("Failed to restore %s to version %s", filename, version) + return False + + def archive_file(self, filename: str) -> bool: + """归档文件 + + Args: + filename: 文件名 + + Returns: + 归档是否成功 + """ + # 权限检查 + if not self._check_permission(filename, "archive"): + logger.warning("Permission denied for archive operation on file: %s", filename) + return False + + try: + # 更新文件状态为归档 + metadata_dict = self._load_metadata() + if filename not in metadata_dict: + logger.warning("File %s not found in metadata", filename) + return False + + metadata_dict[filename]["status"] = FileStatus.ARCHIVED.value + metadata_dict[filename]["modified_at"] = datetime.now().isoformat() + + self._save_metadata(metadata_dict) + + logger.info("File %s archived successfully", filename) + return True + + except Exception as e: + logger.exception("Failed to archive file %s", filename) + return False + + def soft_delete_file(self, filename: str) -> bool: + """软删除文件(移动到删除目录) + + Args: + filename: 文件名 + + Returns: + 删除是否成功 + """ + # 权限检查 + if not self._check_permission(filename, "delete"): + logger.warning("Permission denied for soft delete operation on file: %s", filename) + return False + + try: + # 检查文件是否存在 + if not self._storage.exists(filename): + logger.warning("File %s not found", filename) + return False + + # 读取文件内容 + file_data = self._storage.load_once(filename) + + # 移动到删除目录 + deleted_filename = f"{self._deleted_prefix}{filename}.{datetime.now().strftime('%Y%m%d_%H%M%S')}" + self._storage.save(deleted_filename, file_data) + + # 删除原文件 + self._storage.delete(filename) + + # 更新元数据 + metadata_dict = self._load_metadata() + if filename in metadata_dict: + metadata_dict[filename]["status"] = FileStatus.DELETED.value + metadata_dict[filename]["modified_at"] = datetime.now().isoformat() + self._save_metadata(metadata_dict) + + logger.info("File %s soft deleted successfully", filename) + return True + + except Exception as e: + logger.exception("Failed to soft delete file %s", filename) + return False + + def cleanup_old_versions(self, max_versions: int = 5, max_age_days: int = 30) -> int: + """清理旧版本文件 + + Args: + max_versions: 保留的最大版本数 + max_age_days: 版本文件的最大保留天数 + + Returns: + 清理的文件数量 + """ + try: + cleaned_count = 0 + cutoff_date = datetime.now() - timedelta(days=max_age_days) + + # 获取所有版本文件 + try: + all_files = self._storage.scan(self._dataset_id or "", files=True) + version_files = [f for f in all_files if f.startswith(self._version_prefix)] + + # 按文件分组 + file_versions: dict[str, list[tuple[int, str]]] = {} + for version_file in version_files: + # 解析文件名和版本 + parts = version_file[len(self._version_prefix) :].split(".v") + if len(parts) >= 2: + base_filename = parts[0] + version_part = parts[1].split(".")[0] + try: + version_num = int(version_part) + if base_filename not in file_versions: + file_versions[base_filename] = [] + file_versions[base_filename].append((version_num, version_file)) + except ValueError: + continue + + # 清理每个文件的旧版本 + for base_filename, versions in file_versions.items(): + # 按版本号排序 + versions.sort(key=lambda x: x[0], reverse=True) + + # 保留最新的max_versions个版本,删除其余的 + if len(versions) > max_versions: + to_delete = versions[max_versions:] + for version_num, version_file in to_delete: + self._storage.delete(version_file) + cleaned_count += 1 + logger.debug("Cleaned old version: %s", version_file) + + logger.info("Cleaned %d old version files", cleaned_count) + + except Exception as e: + logger.warning("Could not scan for version files: %s", e) + + return cleaned_count + + except Exception as e: + logger.exception("Failed to cleanup old versions") + return 0 + + def get_storage_statistics(self) -> dict[str, Any]: + """获取存储统计信息 + + Returns: + 存储统计字典 + """ + try: + metadata_dict = self._load_metadata() + + stats: dict[str, Any] = { + "total_files": len(metadata_dict), + "active_files": 0, + "archived_files": 0, + "deleted_files": 0, + "total_size": 0, + "versions_count": 0, + "oldest_file": None, + "newest_file": None, + } + + oldest_date = None + newest_date = None + + for filename, metadata in metadata_dict.items(): + file_meta = FileMetadata.from_dict(metadata) + + # 统计文件状态 + if file_meta.status == FileStatus.ACTIVE: + stats["active_files"] = (stats["active_files"] or 0) + 1 + elif file_meta.status == FileStatus.ARCHIVED: + stats["archived_files"] = (stats["archived_files"] or 0) + 1 + elif file_meta.status == FileStatus.DELETED: + stats["deleted_files"] = (stats["deleted_files"] or 0) + 1 + + # 统计大小 + stats["total_size"] = (stats["total_size"] or 0) + (file_meta.size or 0) + + # 统计版本 + stats["versions_count"] = (stats["versions_count"] or 0) + (file_meta.version or 0) + + # 找出最新和最旧的文件 + if oldest_date is None or file_meta.created_at < oldest_date: + oldest_date = file_meta.created_at + stats["oldest_file"] = filename + + if newest_date is None or file_meta.modified_at > newest_date: + newest_date = file_meta.modified_at + stats["newest_file"] = filename + + return stats + + except Exception as e: + logger.exception("Failed to get storage statistics") + return {} + + def _create_version_backup(self, filename: str, metadata: dict): + """创建版本备份""" + try: + # 读取当前文件内容 + current_data = self._storage.load_once(filename) + + # 保存为版本文件 + version_filename = f"{self._version_prefix}{filename}.v{metadata['version']}" + self._storage.save(version_filename, current_data) + + logger.debug("Created version backup: %s", version_filename) + + except Exception as e: + logger.warning("Failed to create version backup for %s: %s", filename, e) + + def _load_metadata(self) -> dict[str, Any]: + """加载元数据文件""" + try: + if self._storage.exists(self._metadata_file): + metadata_content = self._storage.load_once(self._metadata_file) + result = json.loads(metadata_content.decode("utf-8")) + return dict(result) if result else {} + else: + return {} + except Exception as e: + logger.warning("Failed to load metadata: %s", e) + return {} + + def _save_metadata(self, metadata_dict: dict): + """保存元数据文件""" + try: + metadata_content = json.dumps(metadata_dict, indent=2, ensure_ascii=False) + self._storage.save(self._metadata_file, metadata_content.encode("utf-8")) + logger.debug("Metadata saved successfully") + except Exception as e: + logger.exception("Failed to save metadata") + raise + + def _calculate_checksum(self, data: bytes) -> str: + """计算文件校验和""" + import hashlib + + return hashlib.md5(data).hexdigest() + + def _check_permission(self, filename: str, operation: str) -> bool: + """检查文件操作权限 + + Args: + filename: 文件名 + operation: 操作类型 + + Returns: + True if permission granted, False otherwise + """ + # 如果没有权限管理器,默认允许 + if not self._permission_manager: + return True + + try: + # 根据操作类型映射到权限 + operation_mapping = { + "save": "save", + "load": "load_once", + "delete": "delete", + "archive": "delete", # 归档需要删除权限 + "restore": "save", # 恢复需要写权限 + "cleanup": "delete", # 清理需要删除权限 + "read": "load_once", + "write": "save", + } + + mapped_operation = operation_mapping.get(operation, operation) + + # 检查权限 + result = self._permission_manager.validate_operation(mapped_operation, self._dataset_id) + return bool(result) + + except Exception as e: + logger.exception("Permission check failed for %s operation %s", filename, operation) + # 安全默认:权限检查失败时拒绝访问 + return False diff --git a/api/extensions/storage/clickzetta_volume/volume_permissions.py b/api/extensions/storage/clickzetta_volume/volume_permissions.py new file mode 100644 index 0000000000..4801df5102 --- /dev/null +++ b/api/extensions/storage/clickzetta_volume/volume_permissions.py @@ -0,0 +1,646 @@ +"""ClickZetta Volume权限管理机制 + +该模块提供Volume权限检查、验证和管理功能。 +根据ClickZetta的权限模型,不同Volume类型有不同的权限要求。 +""" + +import logging +from enum import Enum +from typing import Optional + +logger = logging.getLogger(__name__) + + +class VolumePermission(Enum): + """Volume权限类型枚举""" + + READ = "SELECT" # 对应ClickZetta的SELECT权限 + WRITE = "INSERT,UPDATE,DELETE" # 对应ClickZetta的写权限 + LIST = "SELECT" # 列出文件需要SELECT权限 + DELETE = "INSERT,UPDATE,DELETE" # 删除文件需要写权限 + USAGE = "USAGE" # External Volume需要的基本权限 + + +class VolumePermissionManager: + """Volume权限管理器""" + + def __init__(self, connection_or_config, volume_type: str | None = None, volume_name: Optional[str] = None): + """初始化权限管理器 + + Args: + connection_or_config: ClickZetta连接对象或配置字典 + volume_type: Volume类型 (user|table|external) + volume_name: Volume名称 (用于external volume) + """ + # 支持两种初始化方式:连接对象或配置字典 + if isinstance(connection_or_config, dict): + # 从配置字典创建连接 + import clickzetta # type: ignore[import-untyped] + + config = connection_or_config + self._connection = clickzetta.connect( + username=config.get("username"), + password=config.get("password"), + instance=config.get("instance"), + service=config.get("service"), + workspace=config.get("workspace"), + vcluster=config.get("vcluster"), + schema=config.get("schema") or config.get("database"), + ) + self._volume_type = config.get("volume_type", volume_type) + self._volume_name = config.get("volume_name", volume_name) + else: + # 直接使用连接对象 + self._connection = connection_or_config + self._volume_type = volume_type + self._volume_name = volume_name + + if not self._connection: + raise ValueError("Valid connection or config is required") + if not self._volume_type: + raise ValueError("volume_type is required") + + self._permission_cache: dict[str, set[str]] = {} + self._current_username = None # 将从连接中获取当前用户名 + + def check_permission(self, operation: VolumePermission, dataset_id: Optional[str] = None) -> bool: + """检查用户是否有执行特定操作的权限 + + Args: + operation: 要执行的操作类型 + dataset_id: 数据集ID (用于table volume) + + Returns: + True if user has permission, False otherwise + """ + try: + if self._volume_type == "user": + return self._check_user_volume_permission(operation) + elif self._volume_type == "table": + return self._check_table_volume_permission(operation, dataset_id) + elif self._volume_type == "external": + return self._check_external_volume_permission(operation) + else: + logger.warning("Unknown volume type: %s", self._volume_type) + return False + + except Exception as e: + logger.exception("Permission check failed") + return False + + def _check_user_volume_permission(self, operation: VolumePermission) -> bool: + """检查User Volume权限 + + User Volume权限规则: + - 用户对自己的User Volume有全部权限 + - 只要用户能够连接到ClickZetta,就默认具有User Volume的基本权限 + - 更注重连接身份验证,而不是复杂的权限检查 + """ + try: + # 获取当前用户名 + current_user = self._get_current_username() + + # 检查基本连接状态 + with self._connection.cursor() as cursor: + # 简单的连接测试,如果能执行查询说明用户有基本权限 + cursor.execute("SELECT 1") + result = cursor.fetchone() + + if result: + logger.debug( + "User Volume permission check for %s, operation %s: granted (basic connection verified)", + current_user, + operation.name, + ) + return True + else: + logger.warning( + "User Volume permission check failed: cannot verify basic connection for %s", current_user + ) + return False + + except Exception as e: + logger.exception("User Volume permission check failed") + # 对于User Volume,如果权限检查失败,可能是配置问题,给出更友好的错误提示 + logger.info("User Volume permission check failed, but permission checking is disabled in this version") + return False + + def _check_table_volume_permission(self, operation: VolumePermission, dataset_id: Optional[str]) -> bool: + """检查Table Volume权限 + + Table Volume权限规则: + - Table Volume权限继承对应表的权限 + - SELECT权限 -> 可以READ/LIST文件 + - INSERT,UPDATE,DELETE权限 -> 可以WRITE/DELETE文件 + """ + if not dataset_id: + logger.warning("dataset_id is required for table volume permission check") + return False + + table_name = f"dataset_{dataset_id}" if not dataset_id.startswith("dataset_") else dataset_id + + try: + # 检查表权限 + permissions = self._get_table_permissions(table_name) + required_permissions = set(operation.value.split(",")) + + # 检查是否有所需的所有权限 + has_permission = required_permissions.issubset(permissions) + + logger.debug( + "Table Volume permission check for %s, operation %s: required=%s, has=%s, granted=%s", + table_name, + operation.name, + required_permissions, + permissions, + has_permission, + ) + + return has_permission + + except Exception as e: + logger.exception("Table volume permission check failed for %s", table_name) + return False + + def _check_external_volume_permission(self, operation: VolumePermission) -> bool: + """检查External Volume权限 + + External Volume权限规则: + - 尝试获取对External Volume的权限 + - 如果权限检查失败,进行备选验证 + - 对于开发环境,提供更宽松的权限检查 + """ + if not self._volume_name: + logger.warning("volume_name is required for external volume permission check") + return False + + try: + # 检查External Volume权限 + permissions = self._get_external_volume_permissions(self._volume_name) + + # External Volume权限映射:根据操作类型确定所需权限 + required_permissions = set() + + if operation in [VolumePermission.READ, VolumePermission.LIST]: + required_permissions.add("read") + elif operation in [VolumePermission.WRITE, VolumePermission.DELETE]: + required_permissions.add("write") + + # 检查是否有所需的所有权限 + has_permission = required_permissions.issubset(permissions) + + logger.debug( + "External Volume permission check for %s, operation %s: required=%s, has=%s, granted=%s", + self._volume_name, + operation.name, + required_permissions, + permissions, + has_permission, + ) + + # 如果权限检查失败,尝试备选验证 + if not has_permission: + logger.info("Direct permission check failed for %s, trying fallback verification", self._volume_name) + + # 备选验证:尝试列出Volume来验证基本访问权限 + try: + with self._connection.cursor() as cursor: + cursor.execute("SHOW VOLUMES") + volumes = cursor.fetchall() + for volume in volumes: + if len(volume) > 0 and volume[0] == self._volume_name: + logger.info("Fallback verification successful for %s", self._volume_name) + return True + except Exception as fallback_e: + logger.warning("Fallback verification failed for %s: %s", self._volume_name, fallback_e) + + return has_permission + + except Exception as e: + logger.exception("External volume permission check failed for %s", self._volume_name) + logger.info("External Volume permission check failed, but permission checking is disabled in this version") + return False + + def _get_table_permissions(self, table_name: str) -> set[str]: + """获取用户对指定表的权限 + + Args: + table_name: 表名 + + Returns: + 用户对该表的权限集合 + """ + cache_key = f"table:{table_name}" + + if cache_key in self._permission_cache: + return self._permission_cache[cache_key] + + permissions = set() + + try: + with self._connection.cursor() as cursor: + # 使用正确的ClickZetta语法检查当前用户权限 + cursor.execute("SHOW GRANTS") + grants = cursor.fetchall() + + # 解析权限结果,查找对该表的权限 + for grant in grants: + if len(grant) >= 3: # 典型格式: (privilege, object_type, object_name, ...) + privilege = grant[0].upper() + object_type = grant[1].upper() if len(grant) > 1 else "" + object_name = grant[2] if len(grant) > 2 else "" + + # 检查是否是对该表的权限 + if ( + object_type == "TABLE" + and object_name == table_name + or object_type == "SCHEMA" + and object_name in table_name + ): + if privilege in ["SELECT", "INSERT", "UPDATE", "DELETE", "ALL"]: + if privilege == "ALL": + permissions.update(["SELECT", "INSERT", "UPDATE", "DELETE"]) + else: + permissions.add(privilege) + + # 如果没有找到明确的权限,尝试执行一个简单的查询来验证权限 + if not permissions: + try: + cursor.execute(f"SELECT COUNT(*) FROM {table_name} LIMIT 1") + permissions.add("SELECT") + except Exception: + logger.debug("Cannot query table %s, no SELECT permission", table_name) + + except Exception as e: + logger.warning("Could not check table permissions for %s: %s", table_name, e) + # 安全默认:权限检查失败时拒绝访问 + pass + + # 缓存权限信息 + self._permission_cache[cache_key] = permissions + return permissions + + def _get_current_username(self) -> str: + """获取当前用户名""" + if self._current_username: + return self._current_username + + try: + with self._connection.cursor() as cursor: + cursor.execute("SELECT CURRENT_USER()") + result = cursor.fetchone() + if result: + self._current_username = result[0] + return str(self._current_username) + except Exception as e: + logger.exception("Failed to get current username") + + return "unknown" + + def _get_user_permissions(self, username: str) -> set[str]: + """获取用户的基本权限集合""" + cache_key = f"user_permissions:{username}" + + if cache_key in self._permission_cache: + return self._permission_cache[cache_key] + + permissions = set() + + try: + with self._connection.cursor() as cursor: + # 使用正确的ClickZetta语法检查当前用户权限 + cursor.execute("SHOW GRANTS") + grants = cursor.fetchall() + + # 解析权限结果,查找用户的基本权限 + for grant in grants: + if len(grant) >= 3: # 典型格式: (privilege, object_type, object_name, ...) + privilege = grant[0].upper() + object_type = grant[1].upper() if len(grant) > 1 else "" + + # 收集所有相关权限 + if privilege in ["SELECT", "INSERT", "UPDATE", "DELETE", "ALL"]: + if privilege == "ALL": + permissions.update(["SELECT", "INSERT", "UPDATE", "DELETE"]) + else: + permissions.add(privilege) + + except Exception as e: + logger.warning("Could not check user permissions for %s: %s", username, e) + # 安全默认:权限检查失败时拒绝访问 + pass + + # 缓存权限信息 + self._permission_cache[cache_key] = permissions + return permissions + + def _get_external_volume_permissions(self, volume_name: str) -> set[str]: + """获取用户对指定External Volume的权限 + + Args: + volume_name: External Volume名称 + + Returns: + 用户对该Volume的权限集合 + """ + cache_key = f"external_volume:{volume_name}" + + if cache_key in self._permission_cache: + return self._permission_cache[cache_key] + + permissions = set() + + try: + with self._connection.cursor() as cursor: + # 使用正确的ClickZetta语法检查Volume权限 + logger.info("Checking permissions for volume: %s", volume_name) + cursor.execute(f"SHOW GRANTS ON VOLUME {volume_name}") + grants = cursor.fetchall() + + logger.info("Raw grants result for %s: %s", volume_name, grants) + + # 解析权限结果 + # 格式: (granted_type, privilege, conditions, granted_on, object_name, granted_to, + # grantee_name, grantor_name, grant_option, granted_time) + for grant in grants: + logger.info("Processing grant: %s", grant) + if len(grant) >= 5: + granted_type = grant[0] + privilege = grant[1].upper() + granted_on = grant[3] + object_name = grant[4] + + logger.info( + "Grant details - type: %s, privilege: %s, granted_on: %s, object_name: %s", + granted_type, + privilege, + granted_on, + object_name, + ) + + # 检查是否是对该Volume的权限或者是层级权限 + if ( + granted_type == "PRIVILEGE" and granted_on == "VOLUME" and object_name.endswith(volume_name) + ) or (granted_type == "OBJECT_HIERARCHY" and granted_on == "VOLUME"): + logger.info("Matching grant found for %s", volume_name) + + if "READ" in privilege: + permissions.add("read") + logger.info("Added READ permission for %s", volume_name) + if "WRITE" in privilege: + permissions.add("write") + logger.info("Added WRITE permission for %s", volume_name) + if "ALTER" in privilege: + permissions.add("alter") + logger.info("Added ALTER permission for %s", volume_name) + if privilege == "ALL": + permissions.update(["read", "write", "alter"]) + logger.info("Added ALL permissions for %s", volume_name) + + logger.info("Final permissions for %s: %s", volume_name, permissions) + + # 如果没有找到明确的权限,尝试查看Volume列表来验证基本权限 + if not permissions: + try: + cursor.execute("SHOW VOLUMES") + volumes = cursor.fetchall() + for volume in volumes: + if len(volume) > 0 and volume[0] == volume_name: + permissions.add("read") # 至少有读权限 + logger.debug("Volume %s found in SHOW VOLUMES, assuming read permission", volume_name) + break + except Exception: + logger.debug("Cannot access volume %s, no basic permission", volume_name) + + except Exception as e: + logger.warning("Could not check external volume permissions for %s: %s", volume_name, e) + # 在权限检查失败时,尝试基本的Volume访问验证 + try: + with self._connection.cursor() as cursor: + cursor.execute("SHOW VOLUMES") + volumes = cursor.fetchall() + for volume in volumes: + if len(volume) > 0 and volume[0] == volume_name: + logger.info("Basic volume access verified for %s", volume_name) + permissions.add("read") + permissions.add("write") # 假设有写权限 + break + except Exception as basic_e: + logger.warning("Basic volume access check failed for %s: %s", volume_name, basic_e) + # 最后的备选方案:假设有基本权限 + permissions.add("read") + + # 缓存权限信息 + self._permission_cache[cache_key] = permissions + return permissions + + def clear_permission_cache(self): + """清空权限缓存""" + self._permission_cache.clear() + logger.debug("Permission cache cleared") + + def get_permission_summary(self, dataset_id: Optional[str] = None) -> dict[str, bool]: + """获取权限摘要 + + Args: + dataset_id: 数据集ID (用于table volume) + + Returns: + 权限摘要字典 + """ + summary = {} + + for operation in VolumePermission: + summary[operation.name.lower()] = self.check_permission(operation, dataset_id) + + return summary + + def check_inherited_permission(self, file_path: str, operation: VolumePermission) -> bool: + """检查文件路径的权限继承 + + Args: + file_path: 文件路径 + operation: 要执行的操作 + + Returns: + True if user has permission, False otherwise + """ + try: + # 解析文件路径 + path_parts = file_path.strip("/").split("/") + + if not path_parts: + logger.warning("Invalid file path for permission inheritance check") + return False + + # 对于Table Volume,第一层是dataset_id + if self._volume_type == "table": + if len(path_parts) < 1: + return False + + dataset_id = path_parts[0] + + # 检查对dataset的权限 + has_dataset_permission = self.check_permission(operation, dataset_id) + + if not has_dataset_permission: + logger.debug("Permission denied for dataset %s", dataset_id) + return False + + # 检查路径遍历攻击 + if self._contains_path_traversal(file_path): + logger.warning("Path traversal attack detected: %s", file_path) + return False + + # 检查是否访问敏感目录 + if self._is_sensitive_path(file_path): + logger.warning("Access to sensitive path denied: %s", file_path) + return False + + logger.debug("Permission inherited for path %s", file_path) + return True + + elif self._volume_type == "user": + # User Volume的权限继承 + current_user = self._get_current_username() + + # 检查是否试图访问其他用户的目录 + if len(path_parts) > 1 and path_parts[0] != current_user: + logger.warning("User %s attempted to access %s's directory", current_user, path_parts[0]) + return False + + # 检查基本权限 + return self.check_permission(operation) + + elif self._volume_type == "external": + # External Volume的权限继承 + # 检查对External Volume的权限 + return self.check_permission(operation) + + else: + logger.warning("Unknown volume type for permission inheritance: %s", self._volume_type) + return False + + except Exception as e: + logger.exception("Permission inheritance check failed") + return False + + def _contains_path_traversal(self, file_path: str) -> bool: + """检查路径是否包含路径遍历攻击""" + # 检查常见的路径遍历模式 + traversal_patterns = [ + "../", + "..\\", + "..%2f", + "..%2F", + "..%5c", + "..%5C", + "%2e%2e%2f", + "%2e%2e%5c", + "....//", + "....\\\\", + ] + + file_path_lower = file_path.lower() + + for pattern in traversal_patterns: + if pattern in file_path_lower: + return True + + # 检查绝对路径 + if file_path.startswith("/") or file_path.startswith("\\"): + return True + + # 检查Windows驱动器路径 + if len(file_path) >= 2 and file_path[1] == ":": + return True + + return False + + def _is_sensitive_path(self, file_path: str) -> bool: + """检查路径是否为敏感路径""" + sensitive_patterns = [ + "passwd", + "shadow", + "hosts", + "config", + "secrets", + "private", + "key", + "certificate", + "cert", + "ssl", + "database", + "backup", + "dump", + "log", + "tmp", + ] + + file_path_lower = file_path.lower() + + return any(pattern in file_path_lower for pattern in sensitive_patterns) + + def validate_operation(self, operation: str, dataset_id: Optional[str] = None) -> bool: + """验证操作权限 + + Args: + operation: 操作名称 (save|load|exists|delete|scan) + dataset_id: 数据集ID + + Returns: + True if operation is allowed, False otherwise + """ + operation_mapping = { + "save": VolumePermission.WRITE, + "load": VolumePermission.READ, + "load_once": VolumePermission.READ, + "load_stream": VolumePermission.READ, + "download": VolumePermission.READ, + "exists": VolumePermission.READ, + "delete": VolumePermission.DELETE, + "scan": VolumePermission.LIST, + } + + if operation not in operation_mapping: + logger.warning("Unknown operation: %s", operation) + return False + + volume_permission = operation_mapping[operation] + return self.check_permission(volume_permission, dataset_id) + + +class VolumePermissionError(Exception): + """Volume权限错误异常""" + + def __init__(self, message: str, operation: str, volume_type: str, dataset_id: Optional[str] = None): + self.operation = operation + self.volume_type = volume_type + self.dataset_id = dataset_id + super().__init__(message) + + +def check_volume_permission( + permission_manager: VolumePermissionManager, operation: str, dataset_id: Optional[str] = None +) -> None: + """权限检查装饰器函数 + + Args: + permission_manager: 权限管理器 + operation: 操作名称 + dataset_id: 数据集ID + + Raises: + VolumePermissionError: 如果没有权限 + """ + if not permission_manager.validate_operation(operation, dataset_id): + error_message = f"Permission denied for operation '{operation}' on {permission_manager._volume_type} volume" + if dataset_id: + error_message += f" (dataset: {dataset_id})" + + raise VolumePermissionError( + error_message, + operation=operation, + volume_type=permission_manager._volume_type or "unknown", + dataset_id=dataset_id, + ) diff --git a/api/extensions/storage/opendal_storage.py b/api/extensions/storage/opendal_storage.py index 12e2738e9d..0ba35506d3 100644 --- a/api/extensions/storage/opendal_storage.py +++ b/api/extensions/storage/opendal_storage.py @@ -35,21 +35,21 @@ class OpenDALStorage(BaseStorage): Path(root).mkdir(parents=True, exist_ok=True) self.op = opendal.Operator(scheme=scheme, **kwargs) # type: ignore - logger.debug(f"opendal operator created with scheme {scheme}") + logger.debug("opendal operator created with scheme %s", scheme) retry_layer = opendal.layers.RetryLayer(max_times=3, factor=2.0, jitter=True) self.op = self.op.layer(retry_layer) logger.debug("added retry layer to opendal operator") def save(self, filename: str, data: bytes) -> None: self.op.write(path=filename, bs=data) - logger.debug(f"file {filename} saved") + logger.debug("file %s saved", filename) def load_once(self, filename: str) -> bytes: if not self.exists(filename): raise FileNotFoundError("File not found") content: bytes = self.op.read(path=filename) - logger.debug(f"file {filename} loaded") + logger.debug("file %s loaded", filename) return content def load_stream(self, filename: str) -> Generator: @@ -60,7 +60,7 @@ class OpenDALStorage(BaseStorage): file = self.op.open(path=filename, mode="rb") while chunk := file.read(batch_size): yield chunk - logger.debug(f"file {filename} loaded as stream") + logger.debug("file %s loaded as stream", filename) def download(self, filename: str, target_filepath: str): if not self.exists(filename): @@ -68,7 +68,7 @@ class OpenDALStorage(BaseStorage): with Path(target_filepath).open("wb") as f: f.write(self.op.read(path=filename)) - logger.debug(f"file {filename} downloaded to {target_filepath}") + logger.debug("file %s downloaded to %s", filename, target_filepath) def exists(self, filename: str) -> bool: res: bool = self.op.exists(path=filename) @@ -77,9 +77,9 @@ class OpenDALStorage(BaseStorage): def delete(self, filename: str): if self.exists(filename): self.op.delete(path=filename) - logger.debug(f"file {filename} deleted") + logger.debug("file %s deleted", filename) return - logger.debug(f"file {filename} not found, skip delete") + logger.debug("file %s not found, skip delete", filename) def scan(self, path: str, files: bool = True, directories: bool = False) -> list[str]: if not self.exists(path): @@ -87,13 +87,13 @@ class OpenDALStorage(BaseStorage): all_files = self.op.scan(path=path) if files and directories: - logger.debug(f"files and directories on {path} scanned") + logger.debug("files and directories on %s scanned", path) return [f.path for f in all_files] if files: - logger.debug(f"files on {path} scanned") + logger.debug("files on %s scanned", path) return [f.path for f in all_files if not f.path.endswith("/")] elif directories: - logger.debug(f"directories on {path} scanned") + logger.debug("directories on %s scanned", path) return [f.path for f in all_files if f.path.endswith("/")] else: raise ValueError("At least one of files or directories must be True") diff --git a/api/extensions/storage/storage_type.py b/api/extensions/storage/storage_type.py index 0a891e36cf..bc2d632159 100644 --- a/api/extensions/storage/storage_type.py +++ b/api/extensions/storage/storage_type.py @@ -5,6 +5,7 @@ class StorageType(StrEnum): ALIYUN_OSS = "aliyun-oss" AZURE_BLOB = "azure-blob" BAIDU_OBS = "baidu-obs" + CLICKZETTA_VOLUME = "clickzetta-volume" GOOGLE_STORAGE = "google-storage" HUAWEI_OBS = "huawei-obs" LOCAL = "local" diff --git a/api/extensions/storage/volcengine_tos_storage.py b/api/extensions/storage/volcengine_tos_storage.py index 55fe6545ec..32839d3497 100644 --- a/api/extensions/storage/volcengine_tos_storage.py +++ b/api/extensions/storage/volcengine_tos_storage.py @@ -25,7 +25,7 @@ class VolcengineTosStorage(BaseStorage): def load_once(self, filename: str) -> bytes: data = self.client.get_object(bucket=self.bucket_name, key=filename).read() if not isinstance(data, bytes): - raise TypeError("Expected bytes, got {}".format(type(data).__name__)) + raise TypeError(f"Expected bytes, got {type(data).__name__}") return data def load_stream(self, filename: str) -> Generator: diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py index 512a9cb608..a0ff33ab65 100644 --- a/api/factories/file_factory.py +++ b/api/factories/file_factory.py @@ -1,4 +1,6 @@ import mimetypes +import os +import urllib.parse import uuid from collections.abc import Callable, Mapping, Sequence from typing import Any, cast @@ -240,16 +242,28 @@ def _build_from_remote_url( def _get_remote_file_info(url: str): file_size = -1 - filename = url.split("/")[-1].split("?")[0] or "unknown_file" - mime_type = mimetypes.guess_type(filename)[0] or "" + parsed_url = urllib.parse.urlparse(url) + url_path = parsed_url.path + filename = os.path.basename(url_path) + + # Initialize mime_type from filename as fallback + mime_type, _ = mimetypes.guess_type(filename) + if mime_type is None: + mime_type = "" resp = ssrf_proxy.head(url, follow_redirects=True) resp = cast(httpx.Response, resp) if resp.status_code == httpx.codes.OK: if content_disposition := resp.headers.get("Content-Disposition"): filename = str(content_disposition.split("filename=")[-1].strip('"')) + # Re-guess mime_type from updated filename + mime_type, _ = mimetypes.guess_type(filename) + if mime_type is None: + mime_type = "" file_size = int(resp.headers.get("Content-Length", file_size)) - mime_type = mime_type or str(resp.headers.get("Content-Type", "")) + # Fallback to Content-Type header if mime_type is still empty + if not mime_type: + mime_type = resp.headers.get("Content-Type", "").split(";")[0].strip() return mime_type, filename, file_size diff --git a/api/fields/annotation_fields.py b/api/fields/annotation_fields.py index 379dcc6d16..38835d5ac7 100644 --- a/api/fields/annotation_fields.py +++ b/api/fields/annotation_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restx import Api, Namespace, fields from libs.helper import TimestampField @@ -11,6 +11,12 @@ annotation_fields = { # 'account': fields.Nested(simple_account_fields, allow_null=True) } + +def build_annotation_model(api_or_ns: Api | Namespace): + """Build the annotation model for the API or Namespace.""" + return api_or_ns.model("Annotation", annotation_fields) + + annotation_list_fields = { "data": fields.List(fields.Nested(annotation_fields)), } diff --git a/api/fields/api_based_extension_fields.py b/api/fields/api_based_extension_fields.py index a85d4a34db..a2dda1dc15 100644 --- a/api/fields/api_based_extension_fields.py +++ b/api/fields/api_based_extension_fields.py @@ -1,10 +1,10 @@ -from flask_restful import fields +from flask_restx import fields from libs.helper import TimestampField class HiddenAPIKey(fields.Raw): - def output(self, key, obj): + def output(self, key, obj, **kwargs): api_key = obj.api_key # If the length of the api_key is less than 8 characters, show the first and last characters if len(api_key) <= 8: diff --git a/api/fields/app_fields.py b/api/fields/app_fields.py index b6d85e0e24..1f14d663b8 100644 --- a/api/fields/app_fields.py +++ b/api/fields/app_fields.py @@ -1,6 +1,6 @@ import json -from flask_restful import fields +from flask_restx import fields from fields.workflow_fields import workflow_partial_fields from libs.helper import AppIconUrlField, TimestampField @@ -59,6 +59,8 @@ model_config_fields = { "updated_at": TimestampField, } +tag_fields = {"id": fields.String, "name": fields.String, "type": fields.String} + app_detail_fields = { "id": fields.String, "name": fields.String, @@ -77,6 +79,7 @@ app_detail_fields = { "updated_by": fields.String, "updated_at": TimestampField, "access_mode": fields.String, + "tags": fields.List(fields.Nested(tag_fields)), } prompt_config_fields = { @@ -92,8 +95,6 @@ model_config_partial_fields = { "updated_at": TimestampField, } -tag_fields = {"id": fields.String, "name": fields.String, "type": fields.String} - app_partial_fields = { "id": fields.String, "name": fields.String, @@ -185,7 +186,6 @@ app_detail_fields_with_site = { "enable_api": fields.Boolean, "model_config": fields.Nested(model_config_fields, attribute="app_model_config", allow_null=True), "workflow": fields.Nested(workflow_partial_fields, allow_null=True), - "site": fields.Nested(site_fields), "api_base_url": fields.String, "use_icon_as_answer_icon": fields.Boolean, "max_active_requests": fields.Integer, @@ -195,6 +195,8 @@ app_detail_fields_with_site = { "updated_at": TimestampField, "deleted_tools": fields.List(fields.Nested(deleted_tool_fields)), "access_mode": fields.String, + "tags": fields.List(fields.Nested(tag_fields)), + "site": fields.Nested(site_fields), } diff --git a/api/fields/conversation_fields.py b/api/fields/conversation_fields.py index 370e8a5a58..ecc267cf38 100644 --- a/api/fields/conversation_fields.py +++ b/api/fields/conversation_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restx import Api, Namespace, fields from fields.member_fields import simple_account_fields from libs.helper import TimestampField @@ -45,6 +45,12 @@ message_file_fields = { "upload_file_id": fields.String(default=None), } + +def build_message_file_model(api_or_ns: Api | Namespace): + """Build the message file fields for the API or Namespace.""" + return api_or_ns.model("MessageFile", message_file_fields) + + agent_thought_fields = { "id": fields.String, "chain_id": fields.String, @@ -209,3 +215,22 @@ conversation_infinite_scroll_pagination_fields = { "has_more": fields.Boolean, "data": fields.List(fields.Nested(simple_conversation_fields)), } + + +def build_conversation_infinite_scroll_pagination_model(api_or_ns: Api | Namespace): + """Build the conversation infinite scroll pagination model for the API or Namespace.""" + simple_conversation_model = build_simple_conversation_model(api_or_ns) + + copied_fields = conversation_infinite_scroll_pagination_fields.copy() + copied_fields["data"] = fields.List(fields.Nested(simple_conversation_model)) + return api_or_ns.model("ConversationInfiniteScrollPagination", copied_fields) + + +def build_conversation_delete_model(api_or_ns: Api | Namespace): + """Build the conversation delete model for the API or Namespace.""" + return api_or_ns.model("ConversationDelete", conversation_delete_fields) + + +def build_simple_conversation_model(api_or_ns: Api | Namespace): + """Build the simple conversation model for the API or Namespace.""" + return api_or_ns.model("SimpleConversation", simple_conversation_fields) diff --git a/api/fields/conversation_variable_fields.py b/api/fields/conversation_variable_fields.py index c5a0c9a49d..7d5e311591 100644 --- a/api/fields/conversation_variable_fields.py +++ b/api/fields/conversation_variable_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restx import Api, Namespace, fields from libs.helper import TimestampField @@ -27,3 +27,19 @@ conversation_variable_infinite_scroll_pagination_fields = { "has_more": fields.Boolean, "data": fields.List(fields.Nested(conversation_variable_fields)), } + + +def build_conversation_variable_model(api_or_ns: Api | Namespace): + """Build the conversation variable model for the API or Namespace.""" + return api_or_ns.model("ConversationVariable", conversation_variable_fields) + + +def build_conversation_variable_infinite_scroll_pagination_model(api_or_ns: Api | Namespace): + """Build the conversation variable infinite scroll pagination model for the API or Namespace.""" + # Build the nested variable model first + conversation_variable_model = build_conversation_variable_model(api_or_ns) + + copied_fields = conversation_variable_infinite_scroll_pagination_fields.copy() + copied_fields["data"] = fields.List(fields.Nested(conversation_variable_model)) + + return api_or_ns.model("ConversationVariableInfiniteScrollPagination", copied_fields) diff --git a/api/fields/data_source_fields.py b/api/fields/data_source_fields.py index 071071376f..93f6e447dc 100644 --- a/api/fields/data_source_fields.py +++ b/api/fields/data_source_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restx import fields from libs.helper import TimestampField diff --git a/api/fields/dataset_fields.py b/api/fields/dataset_fields.py index 32a88cc5db..5a3082516e 100644 --- a/api/fields/dataset_fields.py +++ b/api/fields/dataset_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restx import fields from libs.helper import TimestampField diff --git a/api/fields/document_fields.py b/api/fields/document_fields.py index 7fd43e8dbe..9be59f7454 100644 --- a/api/fields/document_fields.py +++ b/api/fields/document_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restx import fields from fields.dataset_fields import dataset_fields from libs.helper import TimestampField diff --git a/api/fields/end_user_fields.py b/api/fields/end_user_fields.py index 99e529f9d1..ea43e3b5fd 100644 --- a/api/fields/end_user_fields.py +++ b/api/fields/end_user_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restx import Api, Namespace, fields simple_end_user_fields = { "id": fields.String, @@ -6,3 +6,7 @@ simple_end_user_fields = { "is_anonymous": fields.Boolean, "session_id": fields.String, } + + +def build_simple_end_user_model(api_or_ns: Api | Namespace): + return api_or_ns.model("SimpleEndUser", simple_end_user_fields) diff --git a/api/fields/file_fields.py b/api/fields/file_fields.py index 8b4839ef97..dd359e2f5f 100644 --- a/api/fields/file_fields.py +++ b/api/fields/file_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restx import Api, Namespace, fields from libs.helper import TimestampField @@ -11,6 +11,19 @@ upload_config_fields = { "workflow_file_upload_limit": fields.Integer, } + +def build_upload_config_model(api_or_ns: Api | Namespace): + """Build the upload config model for the API or Namespace. + + Args: + api_or_ns: Flask-RestX Api or Namespace instance + + Returns: + The registered model + """ + return api_or_ns.model("UploadConfig", upload_config_fields) + + file_fields = { "id": fields.String, "name": fields.String, @@ -22,12 +35,37 @@ file_fields = { "preview_url": fields.String, } + +def build_file_model(api_or_ns: Api | Namespace): + """Build the file model for the API or Namespace. + + Args: + api_or_ns: Flask-RestX Api or Namespace instance + + Returns: + The registered model + """ + return api_or_ns.model("File", file_fields) + + remote_file_info_fields = { "file_type": fields.String(attribute="file_type"), "file_length": fields.Integer(attribute="file_length"), } +def build_remote_file_info_model(api_or_ns: Api | Namespace): + """Build the remote file info model for the API or Namespace. + + Args: + api_or_ns: Flask-RestX Api or Namespace instance + + Returns: + The registered model + """ + return api_or_ns.model("RemoteFileInfo", remote_file_info_fields) + + file_fields_with_signed_url = { "id": fields.String, "name": fields.String, @@ -38,3 +76,15 @@ file_fields_with_signed_url = { "created_by": fields.String, "created_at": TimestampField, } + + +def build_file_with_signed_url_model(api_or_ns: Api | Namespace): + """Build the file with signed URL model for the API or Namespace. + + Args: + api_or_ns: Flask-RestX Api or Namespace instance + + Returns: + The registered model + """ + return api_or_ns.model("FileWithSignedUrl", file_fields_with_signed_url) diff --git a/api/fields/hit_testing_fields.py b/api/fields/hit_testing_fields.py index 9d67999ea4..75bdff1803 100644 --- a/api/fields/hit_testing_fields.py +++ b/api/fields/hit_testing_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restx import fields from libs.helper import TimestampField diff --git a/api/fields/installed_app_fields.py b/api/fields/installed_app_fields.py index e0b3e340f6..16dd26a10e 100644 --- a/api/fields/installed_app_fields.py +++ b/api/fields/installed_app_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restx import fields from libs.helper import AppIconUrlField, TimestampField diff --git a/api/fields/member_fields.py b/api/fields/member_fields.py index 8007b7e052..08e38a6931 100644 --- a/api/fields/member_fields.py +++ b/api/fields/member_fields.py @@ -1,8 +1,17 @@ -from flask_restful import fields +from flask_restx import Api, Namespace, fields from libs.helper import AvatarUrlField, TimestampField -simple_account_fields = {"id": fields.String, "name": fields.String, "email": fields.String} +simple_account_fields = { + "id": fields.String, + "name": fields.String, + "email": fields.String, +} + + +def build_simple_account_model(api_or_ns: Api | Namespace): + return api_or_ns.model("SimpleAccount", simple_account_fields) + account_fields = { "id": fields.String, diff --git a/api/fields/message_fields.py b/api/fields/message_fields.py index e6aebd810f..a419da2e18 100644 --- a/api/fields/message_fields.py +++ b/api/fields/message_fields.py @@ -1,11 +1,19 @@ -from flask_restful import fields +from flask_restx import Api, Namespace, fields from fields.conversation_fields import message_file_fields from libs.helper import TimestampField from .raws import FilesContainedField -feedback_fields = {"rating": fields.String} +feedback_fields = { + "rating": fields.String, +} + + +def build_feedback_model(api_or_ns: Api | Namespace): + """Build the feedback model for the API or Namespace.""" + return api_or_ns.model("Feedback", feedback_fields) + agent_thought_fields = { "id": fields.String, @@ -21,6 +29,12 @@ agent_thought_fields = { "files": fields.List(fields.String), } + +def build_agent_thought_model(api_or_ns: Api | Namespace): + """Build the agent thought model for the API or Namespace.""" + return api_or_ns.model("AgentThought", agent_thought_fields) + + retriever_resource_fields = { "id": fields.String, "message_id": fields.String, diff --git a/api/fields/raws.py b/api/fields/raws.py index 15ec16ab13..9bc6a12c78 100644 --- a/api/fields/raws.py +++ b/api/fields/raws.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restx import fields from core.file import File diff --git a/api/fields/segment_fields.py b/api/fields/segment_fields.py index 4126c24598..2ff917d6bc 100644 --- a/api/fields/segment_fields.py +++ b/api/fields/segment_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restx import fields from libs.helper import TimestampField diff --git a/api/fields/tag_fields.py b/api/fields/tag_fields.py index 9af4fc57dd..d5b7c86a04 100644 --- a/api/fields/tag_fields.py +++ b/api/fields/tag_fields.py @@ -1,3 +1,12 @@ -from flask_restful import fields +from flask_restx import Api, Namespace, fields -tag_fields = {"id": fields.String, "name": fields.String, "type": fields.String, "binding_count": fields.String} +dataset_tag_fields = { + "id": fields.String, + "name": fields.String, + "type": fields.String, + "binding_count": fields.String, +} + + +def build_dataset_tag_fields(api_or_ns: Api | Namespace): + return api_or_ns.model("DataSetTag", dataset_tag_fields) diff --git a/api/fields/workflow_app_log_fields.py b/api/fields/workflow_app_log_fields.py index 823c99ec6b..243efd817c 100644 --- a/api/fields/workflow_app_log_fields.py +++ b/api/fields/workflow_app_log_fields.py @@ -1,8 +1,8 @@ -from flask_restful import fields +from flask_restx import Api, Namespace, fields -from fields.end_user_fields import simple_end_user_fields -from fields.member_fields import simple_account_fields -from fields.workflow_run_fields import workflow_run_for_log_fields +from fields.end_user_fields import build_simple_end_user_model, simple_end_user_fields +from fields.member_fields import build_simple_account_model, simple_account_fields +from fields.workflow_run_fields import build_workflow_run_for_log_model, workflow_run_for_log_fields from libs.helper import TimestampField workflow_app_log_partial_fields = { @@ -15,6 +15,24 @@ workflow_app_log_partial_fields = { "created_at": TimestampField, } + +def build_workflow_app_log_partial_model(api_or_ns: Api | Namespace): + """Build the workflow app log partial model for the API or Namespace.""" + workflow_run_model = build_workflow_run_for_log_model(api_or_ns) + simple_account_model = build_simple_account_model(api_or_ns) + simple_end_user_model = build_simple_end_user_model(api_or_ns) + + copied_fields = workflow_app_log_partial_fields.copy() + copied_fields["workflow_run"] = fields.Nested(workflow_run_model, attribute="workflow_run", allow_null=True) + copied_fields["created_by_account"] = fields.Nested( + simple_account_model, attribute="created_by_account", allow_null=True + ) + copied_fields["created_by_end_user"] = fields.Nested( + simple_end_user_model, attribute="created_by_end_user", allow_null=True + ) + return api_or_ns.model("WorkflowAppLogPartial", copied_fields) + + workflow_app_log_pagination_fields = { "page": fields.Integer, "limit": fields.Integer, @@ -22,3 +40,13 @@ workflow_app_log_pagination_fields = { "has_more": fields.Boolean, "data": fields.List(fields.Nested(workflow_app_log_partial_fields)), } + + +def build_workflow_app_log_pagination_model(api_or_ns: Api | Namespace): + """Build the workflow app log pagination model for the API or Namespace.""" + # Build the nested partial model first + workflow_app_log_partial_model = build_workflow_app_log_partial_model(api_or_ns) + + copied_fields = workflow_app_log_pagination_fields.copy() + copied_fields["data"] = fields.List(fields.Nested(workflow_app_log_partial_model)) + return api_or_ns.model("WorkflowAppLogPagination", copied_fields) diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py index 930e59cc1c..f048d0f3b6 100644 --- a/api/fields/workflow_fields.py +++ b/api/fields/workflow_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restx import fields from core.helper import encrypter from core.variables import SecretVariable, SegmentType, Variable diff --git a/api/fields/workflow_run_fields.py b/api/fields/workflow_run_fields.py index a106728e9c..6462d8ce5a 100644 --- a/api/fields/workflow_run_fields.py +++ b/api/fields/workflow_run_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restx import Api, Namespace, fields from fields.end_user_fields import simple_end_user_fields from fields.member_fields import simple_account_fields @@ -17,6 +17,11 @@ workflow_run_for_log_fields = { "exceptions_count": fields.Integer, } + +def build_workflow_run_for_log_model(api_or_ns: Api | Namespace): + return api_or_ns.model("WorkflowRunForLog", workflow_run_for_log_fields) + + workflow_run_for_list_fields = { "id": fields.String, "version": fields.String, diff --git a/api/libs/external_api.py b/api/libs/external_api.py index 2070df3e55..95d13cd0e6 100644 --- a/api/libs/external_api.py +++ b/api/libs/external_api.py @@ -1,119 +1,111 @@ import re import sys +from collections.abc import Mapping from typing import Any from flask import current_app, got_request_exception -from flask_restful import Api, http_status_message -from werkzeug.datastructures import Headers +from flask_restx import Api from werkzeug.exceptions import HTTPException +from werkzeug.http import HTTP_STATUS_CODES from core.errors.error import AppInvokeQuotaExceededError -class ExternalApi(Api): - def handle_error(self, e): - """Error handler for the API transforms a raised exception into a Flask - response, with the appropriate HTTP status code and body. +def http_status_message(code): + return HTTP_STATUS_CODES.get(code, "") - :param e: the raised Exception object - :type e: Exception - """ +def register_external_error_handlers(api: Api) -> None: + @api.errorhandler(HTTPException) + def handle_http_exception(e: HTTPException): got_request_exception.send(current_app, exception=e) - headers = Headers() - if isinstance(e, HTTPException): - if e.response is not None: - resp = e.get_response() - return resp + # If Werkzeug already prepared a Response, just use it. + if getattr(e, "response", None) is not None: + return e.response - status_code = e.code - default_data = { - "code": re.sub(r"(?= 500: - exc_info: Any = sys.exc_info() - if exc_info[1] is None: - exc_info = None - current_app.log_exception(exc_info) - - if status_code == 406 and self.default_mediatype is None: - # if we are handling NotAcceptable (406), make sure that - # make_response uses a representation we support as the - # default mediatype (so that make_response doesn't throw - # another NotAcceptable error). - supported_mediatypes = list(self.representations.keys()) # only supported application/json - fallback_mediatype = supported_mediatypes[0] if supported_mediatypes else "text/plain" - data = {"code": "not_acceptable", "message": data.get("message")} - resp = self.make_response(data, status_code, headers, fallback_mediatype=fallback_mediatype) + # Payload per status + if status_code == 406 and api.default_mediatype is None: + data = {"code": "not_acceptable", "message": default_data["message"], "status": status_code} + return data, status_code, headers elif status_code == 400: - if isinstance(data.get("message"), dict): - param_key, param_value = list(data.get("message", {}).items())[0] - data = {"code": "invalid_param", "message": param_value, "params": param_key} + msg = default_data["message"] + if isinstance(msg, Mapping) and msg: + # Convert param errors like {"field": "reason"} into a friendly shape + param_key, param_value = next(iter(msg.items())) + data = { + "code": "invalid_param", + "message": str(param_value), + "params": param_key, + "status": status_code, + } else: - if "code" not in data: - data["code"] = "unknown" - - resp = self.make_response(data, status_code, headers) + data = {**default_data} + data.setdefault("code", "unknown") + return data, status_code, headers else: - if "code" not in data: - data["code"] = "unknown" + data = {**default_data} + data.setdefault("code", "unknown") + # If you need WWW-Authenticate for 401, add it to headers + if status_code == 401: + headers["WWW-Authenticate"] = 'Bearer realm="api"' + return data, status_code, headers - resp = self.make_response(data, status_code, headers) + @api.errorhandler(ValueError) + def handle_value_error(e: ValueError): + got_request_exception.send(current_app, exception=e) + status_code = 400 + data = {"code": "invalid_param", "message": str(e), "status": status_code} + return data, status_code - if status_code == 401: - resp = self.unauthorized(resp) - return resp + @api.errorhandler(AppInvokeQuotaExceededError) + def handle_quota_exceeded(e: AppInvokeQuotaExceededError): + got_request_exception.send(current_app, exception=e) + status_code = 429 + data = {"code": "too_many_requests", "message": str(e), "status": status_code} + return data, status_code + + @api.errorhandler(Exception) + def handle_general_exception(e: Exception): + got_request_exception.send(current_app, exception=e) + + status_code = 500 + data: dict[str, Any] = getattr(e, "data", {"message": http_status_message(status_code)}) + + # 🔒 Normalize non-mapping data (e.g., if someone set e.data = Response) + if not isinstance(data, Mapping): + data = {"message": str(e)} + + data.setdefault("code", "unknown") + data.setdefault("status", status_code) + + # Log stack + exc_info: Any = sys.exc_info() + if exc_info[1] is None: + exc_info = None + current_app.log_exception(exc_info) + + return data, status_code + + +class ExternalApi(Api): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + register_external_error_handlers(self) diff --git a/api/libs/helper.py b/api/libs/helper.py index 00772d530a..70986fedd3 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -14,7 +14,7 @@ from typing import TYPE_CHECKING, Any, Optional, Union, cast from zoneinfo import available_timezones from flask import Response, stream_with_context -from flask_restful import fields +from flask_restx import fields from pydantic import BaseModel from configs import dify_config @@ -57,7 +57,7 @@ def run(script): class AppIconUrlField(fields.Raw): - def output(self, key, obj): + def output(self, key, obj, **kwargs): if obj is None: return None @@ -72,7 +72,7 @@ class AppIconUrlField(fields.Raw): class AvatarUrlField(fields.Raw): - def output(self, key, obj): + def output(self, key, obj, **kwargs): if obj is None: return None @@ -95,7 +95,7 @@ def email(email): if re.match(pattern, email) is not None: return email - error = "{email} is not a valid email.".format(email=email) + error = f"{email} is not a valid email." raise ValueError(error) @@ -107,7 +107,7 @@ def uuid_value(value): uuid_obj = uuid.UUID(value) return str(uuid_obj) except ValueError: - error = "{value} is not a valid uuid.".format(value=value) + error = f"{value} is not a valid uuid." raise ValueError(error) @@ -126,7 +126,7 @@ def timestamp_value(timestamp): raise ValueError return int_timestamp except ValueError: - error = "{timestamp} is not a valid timestamp.".format(timestamp=timestamp) + error = f"{timestamp} is not a valid timestamp." raise ValueError(error) @@ -169,14 +169,14 @@ def _get_float(value): try: return float(value) except (TypeError, ValueError): - raise ValueError("{} is not a valid float".format(value)) + raise ValueError(f"{value} is not a valid float") def timezone(timezone_string): if timezone_string and timezone_string in available_timezones(): return timezone_string - error = "{timezone_string} is not a valid timezone.".format(timezone_string=timezone_string) + error = f"{timezone_string} is not a valid timezone." raise ValueError(error) @@ -321,7 +321,7 @@ class TokenManager: key = cls._get_token_key(token, token_type) token_data_json = redis_client.get(key) if token_data_json is None: - logging.warning(f"{token_type} token {token} not found with key {key}") + logging.warning("%s token %s not found with key %s", token_type, token, key) return None token_data: Optional[dict[str, Any]] = json.loads(token_data_json) return token_data diff --git a/api/libs/module_loading.py b/api/libs/module_loading.py new file mode 100644 index 0000000000..616d072a1b --- /dev/null +++ b/api/libs/module_loading.py @@ -0,0 +1,55 @@ +""" +Module loading utilities similar to Django's module_loading. + +Reference implementation from Django: +https://github.com/django/django/blob/main/django/utils/module_loading.py +""" + +import sys +from importlib import import_module +from typing import Any + + +def cached_import(module_path: str, class_name: str) -> Any: + """ + Import a module and return the named attribute/class from it, with caching. + + Args: + module_path: The module path to import from + class_name: The attribute/class name to retrieve + + Returns: + The imported attribute/class + """ + if not ( + (module := sys.modules.get(module_path)) + and (spec := getattr(module, "__spec__", None)) + and getattr(spec, "_initializing", False) is False + ): + module = import_module(module_path) + return getattr(module, class_name) + + +def import_string(dotted_path: str) -> Any: + """ + Import a dotted module path and return the attribute/class designated by + the last name in the path. Raise ImportError if the import failed. + + Args: + dotted_path: Full module path to the class (e.g., 'module.submodule.ClassName') + + Returns: + The imported class or attribute + + Raises: + ImportError: If the module or attribute cannot be imported + """ + try: + module_path, class_name = dotted_path.rsplit(".", 1) + except ValueError as err: + raise ImportError(f"{dotted_path} doesn't look like a module path") from err + + try: + return cached_import(module_path, class_name) + except AttributeError as err: + raise ImportError(f'Module "{module_path}" does not define a "{class_name}" attribute/class') from err diff --git a/api/libs/rsa.py b/api/libs/rsa.py index ed7a0eb116..c72032701f 100644 --- a/api/libs/rsa.py +++ b/api/libs/rsa.py @@ -1,5 +1,4 @@ import hashlib -import os from typing import Union from Crypto.Cipher import AES @@ -18,7 +17,7 @@ def generate_key_pair(tenant_id: str) -> str: pem_private = private_key.export_key() pem_public = public_key.export_key() - filepath = os.path.join("privkeys", tenant_id, "private.pem") + filepath = f"privkeys/{tenant_id}/private.pem" storage.save(filepath, pem_private) @@ -48,15 +47,15 @@ def encrypt(text: str, public_key: Union[str, bytes]) -> bytes: def get_decrypt_decoding(tenant_id: str) -> tuple[RSA.RsaKey, object]: - filepath = os.path.join("privkeys", tenant_id, "private.pem") + filepath = f"privkeys/{tenant_id}/private.pem" - cache_key = "tenant_privkey:{hash}".format(hash=hashlib.sha3_256(filepath.encode()).hexdigest()) + cache_key = f"tenant_privkey:{hashlib.sha3_256(filepath.encode()).hexdigest()}" private_key = redis_client.get(cache_key) if not private_key: try: private_key = storage.load(filepath) except FileNotFoundError: - raise PrivkeyNotFoundError("Private key not found, tenant_id: {tenant_id}".format(tenant_id=tenant_id)) + raise PrivkeyNotFoundError(f"Private key not found, tenant_id: {tenant_id}") redis_client.setex(cache_key, 120, private_key) diff --git a/api/libs/sendgrid.py b/api/libs/sendgrid.py index 5409e3eeeb..cfc6c7d794 100644 --- a/api/libs/sendgrid.py +++ b/api/libs/sendgrid.py @@ -41,5 +41,5 @@ class SendGridClient: ) raise except Exception as e: - logging.exception(f"SendGridClient Unexpected error occurred while sending email to {_to}") + logging.exception("SendGridClient Unexpected error occurred while sending email to %s", _to) raise diff --git a/api/libs/smtp.py b/api/libs/smtp.py index b94386660e..a01ad6fab8 100644 --- a/api/libs/smtp.py +++ b/api/libs/smtp.py @@ -50,7 +50,7 @@ class SMTPClient: logging.exception("Timeout occurred while sending email") raise except Exception as e: - logging.exception(f"Unexpected error occurred while sending email to {mail['to']}") + logging.exception("Unexpected error occurred while sending email to %s", mail["to"]) raise finally: if smtp: diff --git a/api/migrations/versions/2025_07_24_1450-532b3f888abf_manual_dataset_field_update.py b/api/migrations/versions/2025_07_24_1450-532b3f888abf_manual_dataset_field_update.py new file mode 100644 index 0000000000..1664fb99c4 --- /dev/null +++ b/api/migrations/versions/2025_07_24_1450-532b3f888abf_manual_dataset_field_update.py @@ -0,0 +1,25 @@ +"""manual dataset field update + +Revision ID: 532b3f888abf +Revises: 8bcc02c9bd07 +Create Date: 2025-07-24 14:50:48.779833 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '532b3f888abf' +down_revision = '8bcc02c9bd07' +branch_labels = None +depends_on = None + + +def upgrade(): + op.execute("ALTER TABLE tidb_auth_bindings ALTER COLUMN status SET DEFAULT 'CREATING'::character varying") + + +def downgrade(): + op.execute("ALTER TABLE tidb_auth_bindings ALTER COLUMN status SET DEFAULT 'CREATING'") diff --git a/api/migrations/versions/2025_08_07_1115-fa8b0fa6f407_add_timeout_for_tool_mcp_providers.py b/api/migrations/versions/2025_08_07_1115-fa8b0fa6f407_add_timeout_for_tool_mcp_providers.py new file mode 100644 index 0000000000..383e21cd28 --- /dev/null +++ b/api/migrations/versions/2025_08_07_1115-fa8b0fa6f407_add_timeout_for_tool_mcp_providers.py @@ -0,0 +1,33 @@ +"""add timeout for tool_mcp_providers + +Revision ID: fa8b0fa6f407 +Revises: 532b3f888abf +Create Date: 2025-08-07 11:15:31.517985 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = 'fa8b0fa6f407' +down_revision = '532b3f888abf' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('tool_mcp_providers', schema=None) as batch_op: + batch_op.add_column(sa.Column('timeout', sa.Float(), server_default=sa.text('30'), nullable=False)) + batch_op.add_column(sa.Column('sse_read_timeout', sa.Float(), server_default=sa.text('300'), nullable=False)) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('tool_mcp_providers', schema=None) as batch_op: + batch_op.drop_column('sse_read_timeout') + batch_op.drop_column('timeout') + + # ### end Alembic commands ### diff --git a/api/models/account.py b/api/models/account.py index d63c5d7fb5..1a0752440d 100644 --- a/api/models/account.py +++ b/api/models/account.py @@ -3,8 +3,9 @@ import json from datetime import datetime from typing import Optional, cast +import sqlalchemy as sa from flask_login import UserMixin # type: ignore -from sqlalchemy import func, select +from sqlalchemy import DateTime, String, func, select from sqlalchemy.orm import Mapped, mapped_column, reconstructor from models.base import Base @@ -83,26 +84,24 @@ class AccountStatus(enum.StrEnum): class Account(UserMixin, Base): __tablename__ = "accounts" - __table_args__ = (db.PrimaryKeyConstraint("id", name="account_pkey"), db.Index("account_email_idx", "email")) + __table_args__ = (sa.PrimaryKeyConstraint("id", name="account_pkey"), sa.Index("account_email_idx", "email")) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) - name: Mapped[str] = mapped_column(db.String(255)) - email: Mapped[str] = mapped_column(db.String(255)) - password: Mapped[Optional[str]] = mapped_column(db.String(255)) - password_salt: Mapped[Optional[str]] = mapped_column(db.String(255)) - avatar: Mapped[Optional[str]] = mapped_column(db.String(255), nullable=True) - interface_language: Mapped[Optional[str]] = mapped_column(db.String(255)) - interface_theme: Mapped[Optional[str]] = mapped_column(db.String(255), nullable=True) - timezone: Mapped[Optional[str]] = mapped_column(db.String(255)) - last_login_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime, nullable=True) - last_login_ip: Mapped[Optional[str]] = mapped_column(db.String(255), nullable=True) - last_active_at: Mapped[datetime] = mapped_column( - db.DateTime, server_default=func.current_timestamp(), nullable=False - ) - status: Mapped[str] = mapped_column(db.String(16), server_default=db.text("'active'::character varying")) - initialized_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime, nullable=True) - created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp(), nullable=False) - updated_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp(), nullable=False) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) + name: Mapped[str] = mapped_column(String(255)) + email: Mapped[str] = mapped_column(String(255)) + password: Mapped[Optional[str]] = mapped_column(String(255)) + password_salt: Mapped[Optional[str]] = mapped_column(String(255)) + avatar: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) + interface_language: Mapped[Optional[str]] = mapped_column(String(255)) + interface_theme: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) + timezone: Mapped[Optional[str]] = mapped_column(String(255)) + last_login_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) + last_login_ip: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) + last_active_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False) + status: Mapped[str] = mapped_column(String(16), server_default=sa.text("'active'::character varying")) + initialized_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) + created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False) + updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False) @reconstructor def init_on_load(self): @@ -197,16 +196,16 @@ class TenantStatus(enum.StrEnum): class Tenant(Base): __tablename__ = "tenants" - __table_args__ = (db.PrimaryKeyConstraint("id", name="tenant_pkey"),) + __table_args__ = (sa.PrimaryKeyConstraint("id", name="tenant_pkey"),) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) - name: Mapped[str] = mapped_column(db.String(255)) - encrypt_public_key = db.Column(db.Text) - plan: Mapped[str] = mapped_column(db.String(255), server_default=db.text("'basic'::character varying")) - status: Mapped[str] = mapped_column(db.String(255), server_default=db.text("'normal'::character varying")) - custom_config: Mapped[Optional[str]] = mapped_column(db.Text) - created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp(), nullable=False) - updated_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp()) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) + name: Mapped[str] = mapped_column(String(255)) + encrypt_public_key = db.Column(sa.Text) + plan: Mapped[str] = mapped_column(String(255), server_default=sa.text("'basic'::character varying")) + status: Mapped[str] = mapped_column(String(255), server_default=sa.text("'normal'::character varying")) + custom_config: Mapped[Optional[str]] = mapped_column(sa.Text) + created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False) + updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp()) def get_accounts(self) -> list[Account]: return ( @@ -227,56 +226,56 @@ class Tenant(Base): class TenantAccountJoin(Base): __tablename__ = "tenant_account_joins" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="tenant_account_join_pkey"), - db.Index("tenant_account_join_account_id_idx", "account_id"), - db.Index("tenant_account_join_tenant_id_idx", "tenant_id"), - db.UniqueConstraint("tenant_id", "account_id", name="unique_tenant_account_join"), + sa.PrimaryKeyConstraint("id", name="tenant_account_join_pkey"), + sa.Index("tenant_account_join_account_id_idx", "account_id"), + sa.Index("tenant_account_join_tenant_id_idx", "tenant_id"), + sa.UniqueConstraint("tenant_id", "account_id", name="unique_tenant_account_join"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID) account_id: Mapped[str] = mapped_column(StringUUID) - current: Mapped[bool] = mapped_column(db.Boolean, server_default=db.text("false")) - role: Mapped[str] = mapped_column(db.String(16), server_default="normal") + current: Mapped[bool] = mapped_column(sa.Boolean, server_default=sa.text("false")) + role: Mapped[str] = mapped_column(String(16), server_default="normal") invited_by: Mapped[Optional[str]] = mapped_column(StringUUID) - created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp()) - updated_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp()) class AccountIntegrate(Base): __tablename__ = "account_integrates" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="account_integrate_pkey"), - db.UniqueConstraint("account_id", "provider", name="unique_account_provider"), - db.UniqueConstraint("provider", "open_id", name="unique_provider_open_id"), + sa.PrimaryKeyConstraint("id", name="account_integrate_pkey"), + sa.UniqueConstraint("account_id", "provider", name="unique_account_provider"), + sa.UniqueConstraint("provider", "open_id", name="unique_provider_open_id"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) account_id: Mapped[str] = mapped_column(StringUUID) - provider: Mapped[str] = mapped_column(db.String(16)) - open_id: Mapped[str] = mapped_column(db.String(255)) - encrypted_token: Mapped[str] = mapped_column(db.String(255)) - created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp()) - updated_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp()) + provider: Mapped[str] = mapped_column(String(16)) + open_id: Mapped[str] = mapped_column(String(255)) + encrypted_token: Mapped[str] = mapped_column(String(255)) + created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp()) class InvitationCode(Base): __tablename__ = "invitation_codes" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="invitation_code_pkey"), - db.Index("invitation_codes_batch_idx", "batch"), - db.Index("invitation_codes_code_idx", "code", "status"), + sa.PrimaryKeyConstraint("id", name="invitation_code_pkey"), + sa.Index("invitation_codes_batch_idx", "batch"), + sa.Index("invitation_codes_code_idx", "code", "status"), ) - id: Mapped[int] = mapped_column(db.Integer) - batch: Mapped[str] = mapped_column(db.String(255)) - code: Mapped[str] = mapped_column(db.String(32)) - status: Mapped[str] = mapped_column(db.String(16), server_default=db.text("'unused'::character varying")) - used_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime) + id: Mapped[int] = mapped_column(sa.Integer) + batch: Mapped[str] = mapped_column(String(255)) + code: Mapped[str] = mapped_column(String(32)) + status: Mapped[str] = mapped_column(String(16), server_default=sa.text("'unused'::character varying")) + used_at: Mapped[Optional[datetime]] = mapped_column(DateTime) used_by_tenant_id: Mapped[Optional[str]] = mapped_column(StringUUID) used_by_account_id: Mapped[Optional[str]] = mapped_column(StringUUID) - deprecated_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime, nullable=True) - created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=db.text("CURRENT_TIMESTAMP(0)")) + deprecated_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) + created_at: Mapped[datetime] = mapped_column(DateTime, server_default=sa.text("CURRENT_TIMESTAMP(0)")) class TenantPluginPermission(Base): @@ -292,16 +291,14 @@ class TenantPluginPermission(Base): __tablename__ = "account_plugin_permissions" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="account_plugin_permission_pkey"), - db.UniqueConstraint("tenant_id", name="unique_tenant_plugin"), + sa.PrimaryKeyConstraint("id", name="account_plugin_permission_pkey"), + sa.UniqueConstraint("tenant_id", name="unique_tenant_plugin"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - install_permission: Mapped[InstallPermission] = mapped_column( - db.String(16), nullable=False, server_default="everyone" - ) - debug_permission: Mapped[DebugPermission] = mapped_column(db.String(16), nullable=False, server_default="noone") + install_permission: Mapped[InstallPermission] = mapped_column(String(16), nullable=False, server_default="everyone") + debug_permission: Mapped[DebugPermission] = mapped_column(String(16), nullable=False, server_default="noone") class TenantPluginAutoUpgradeStrategy(Base): @@ -317,20 +314,16 @@ class TenantPluginAutoUpgradeStrategy(Base): __tablename__ = "tenant_plugin_auto_upgrade_strategies" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="tenant_plugin_auto_upgrade_strategy_pkey"), - db.UniqueConstraint("tenant_id", name="unique_tenant_plugin_auto_upgrade_strategy"), + sa.PrimaryKeyConstraint("id", name="tenant_plugin_auto_upgrade_strategy_pkey"), + sa.UniqueConstraint("tenant_id", name="unique_tenant_plugin_auto_upgrade_strategy"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - strategy_setting: Mapped[StrategySetting] = mapped_column(db.String(16), nullable=False, server_default="fix_only") - upgrade_time_of_day: Mapped[int] = mapped_column(db.Integer, nullable=False, default=0) # seconds of the day - upgrade_mode: Mapped[UpgradeMode] = mapped_column(db.String(16), nullable=False, server_default="exclude") - exclude_plugins: Mapped[list[str]] = mapped_column( - db.ARRAY(db.String(255)), nullable=False - ) # plugin_id (author/name) - include_plugins: Mapped[list[str]] = mapped_column( - db.ARRAY(db.String(255)), nullable=False - ) # plugin_id (author/name) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + strategy_setting: Mapped[StrategySetting] = mapped_column(String(16), nullable=False, server_default="fix_only") + upgrade_time_of_day: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0) # seconds of the day + upgrade_mode: Mapped[UpgradeMode] = mapped_column(String(16), nullable=False, server_default="exclude") + exclude_plugins: Mapped[list[str]] = mapped_column(sa.ARRAY(String(255)), nullable=False) # plugin_id (author/name) + include_plugins: Mapped[list[str]] = mapped_column(sa.ARRAY(String(255)), nullable=False) # plugin_id (author/name) + created_at = db.Column(DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(DateTime, nullable=False, server_default=func.current_timestamp()) diff --git a/api/models/api_based_extension.py b/api/models/api_based_extension.py index 3cef5a0fb2..60167d9069 100644 --- a/api/models/api_based_extension.py +++ b/api/models/api_based_extension.py @@ -1,10 +1,11 @@ import enum +from datetime import datetime -from sqlalchemy import func -from sqlalchemy.orm import mapped_column +import sqlalchemy as sa +from sqlalchemy import DateTime, String, Text, func +from sqlalchemy.orm import Mapped, mapped_column from .base import Base -from .engine import db from .types import StringUUID @@ -18,13 +19,13 @@ class APIBasedExtensionPoint(enum.Enum): class APIBasedExtension(Base): __tablename__ = "api_based_extensions" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="api_based_extension_pkey"), - db.Index("api_based_extension_tenant_idx", "tenant_id"), + sa.PrimaryKeyConstraint("id", name="api_based_extension_pkey"), + sa.Index("api_based_extension_tenant_idx", "tenant_id"), ) - id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) tenant_id = mapped_column(StringUUID, nullable=False) - name = mapped_column(db.String(255), nullable=False) - api_endpoint = mapped_column(db.String(255), nullable=False) - api_key = mapped_column(db.Text, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + name: Mapped[str] = mapped_column(String(255), nullable=False) + api_endpoint: Mapped[str] = mapped_column(String(255), nullable=False) + api_key = mapped_column(Text, nullable=False) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) diff --git a/api/models/dataset.py b/api/models/dataset.py index d877540213..3b1d289bc4 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -12,7 +12,8 @@ from datetime import datetime from json import JSONDecodeError from typing import Any, Optional, cast -from sqlalchemy import func, select +import sqlalchemy as sa +from sqlalchemy import DateTime, String, func, select from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.orm import Mapped, mapped_column @@ -38,32 +39,32 @@ class DatasetPermissionEnum(enum.StrEnum): class Dataset(Base): __tablename__ = "datasets" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="dataset_pkey"), - db.Index("dataset_tenant_idx", "tenant_id"), - db.Index("retrieval_model_idx", "retrieval_model", postgresql_using="gin"), + sa.PrimaryKeyConstraint("id", name="dataset_pkey"), + sa.Index("dataset_tenant_idx", "tenant_id"), + sa.Index("retrieval_model_idx", "retrieval_model", postgresql_using="gin"), ) INDEXING_TECHNIQUE_LIST = ["high_quality", "economy", None] PROVIDER_LIST = ["vendor", "external", None] - id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID) - name: Mapped[str] = mapped_column(db.String(255)) - description = mapped_column(db.Text, nullable=True) - provider: Mapped[str] = mapped_column(db.String(255), server_default=db.text("'vendor'::character varying")) - permission: Mapped[str] = mapped_column(db.String(255), server_default=db.text("'only_me'::character varying")) - data_source_type = mapped_column(db.String(255)) - indexing_technique: Mapped[Optional[str]] = mapped_column(db.String(255)) - index_struct = mapped_column(db.Text, nullable=True) + name: Mapped[str] = mapped_column(String(255)) + description = mapped_column(sa.Text, nullable=True) + provider: Mapped[str] = mapped_column(String(255), server_default=sa.text("'vendor'::character varying")) + permission: Mapped[str] = mapped_column(String(255), server_default=sa.text("'only_me'::character varying")) + data_source_type = mapped_column(String(255)) + indexing_technique: Mapped[Optional[str]] = mapped_column(String(255)) + index_struct = mapped_column(sa.Text, nullable=True) created_by = mapped_column(StringUUID, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = mapped_column(StringUUID, nullable=True) - updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - embedding_model = db.Column(db.String(255), nullable=True) # TODO: mapped_column - embedding_model_provider = db.Column(db.String(255), nullable=True) # TODO: mapped_column + updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + embedding_model = db.Column(String(255), nullable=True) # TODO: mapped_column + embedding_model_provider = db.Column(String(255), nullable=True) # TODO: mapped_column collection_binding_id = mapped_column(StringUUID, nullable=True) retrieval_model = mapped_column(JSONB, nullable=True) - built_in_field_enabled = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) + built_in_field_enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) @property def dataset_keyword_table(self): @@ -262,16 +263,16 @@ class Dataset(Base): class DatasetProcessRule(Base): __tablename__ = "dataset_process_rules" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="dataset_process_rule_pkey"), - db.Index("dataset_process_rule_dataset_id_idx", "dataset_id"), + sa.PrimaryKeyConstraint("id", name="dataset_process_rule_pkey"), + sa.Index("dataset_process_rule_dataset_id_idx", "dataset_id"), ) - id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, nullable=False, server_default=sa.text("uuid_generate_v4()")) dataset_id = mapped_column(StringUUID, nullable=False) - mode = mapped_column(db.String(255), nullable=False, server_default=db.text("'automatic'::character varying")) - rules = mapped_column(db.Text, nullable=True) + mode = mapped_column(String(255), nullable=False, server_default=sa.text("'automatic'::character varying")) + rules = mapped_column(sa.Text, nullable=True) created_by = mapped_column(StringUUID, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) MODES = ["automatic", "custom", "hierarchical"] PRE_PROCESSING_RULES = ["remove_stopwords", "remove_extra_spaces", "remove_urls_emails"] @@ -302,72 +303,70 @@ class DatasetProcessRule(Base): class Document(Base): __tablename__ = "documents" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="document_pkey"), - db.Index("document_dataset_id_idx", "dataset_id"), - db.Index("document_is_paused_idx", "is_paused"), - db.Index("document_tenant_idx", "tenant_id"), - db.Index("document_metadata_idx", "doc_metadata", postgresql_using="gin"), + sa.PrimaryKeyConstraint("id", name="document_pkey"), + sa.Index("document_dataset_id_idx", "dataset_id"), + sa.Index("document_is_paused_idx", "is_paused"), + sa.Index("document_tenant_idx", "tenant_id"), + sa.Index("document_metadata_idx", "doc_metadata", postgresql_using="gin"), ) # initial fields - id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, nullable=False, server_default=sa.text("uuid_generate_v4()")) tenant_id = mapped_column(StringUUID, nullable=False) dataset_id = mapped_column(StringUUID, nullable=False) - position = mapped_column(db.Integer, nullable=False) - data_source_type = mapped_column(db.String(255), nullable=False) - data_source_info = mapped_column(db.Text, nullable=True) + position: Mapped[int] = mapped_column(sa.Integer, nullable=False) + data_source_type: Mapped[str] = mapped_column(String(255), nullable=False) + data_source_info = mapped_column(sa.Text, nullable=True) dataset_process_rule_id = mapped_column(StringUUID, nullable=True) - batch = mapped_column(db.String(255), nullable=False) - name = mapped_column(db.String(255), nullable=False) - created_from = mapped_column(db.String(255), nullable=False) + batch: Mapped[str] = mapped_column(String(255), nullable=False) + name: Mapped[str] = mapped_column(String(255), nullable=False) + created_from: Mapped[str] = mapped_column(String(255), nullable=False) created_by = mapped_column(StringUUID, nullable=False) created_api_request_id = mapped_column(StringUUID, nullable=True) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) # start processing - processing_started_at = mapped_column(db.DateTime, nullable=True) + processing_started_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) # parsing - file_id = mapped_column(db.Text, nullable=True) - word_count = mapped_column(db.Integer, nullable=True) - parsing_completed_at = mapped_column(db.DateTime, nullable=True) + file_id = mapped_column(sa.Text, nullable=True) + word_count: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True) # TODO: make this not nullable + parsing_completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) # cleaning - cleaning_completed_at = mapped_column(db.DateTime, nullable=True) + cleaning_completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) # split - splitting_completed_at = mapped_column(db.DateTime, nullable=True) + splitting_completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) # indexing - tokens = mapped_column(db.Integer, nullable=True) - indexing_latency = mapped_column(db.Float, nullable=True) - completed_at = mapped_column(db.DateTime, nullable=True) + tokens: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True) + indexing_latency: Mapped[Optional[float]] = mapped_column(sa.Float, nullable=True) + completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) # pause - is_paused = mapped_column(db.Boolean, nullable=True, server_default=db.text("false")) + is_paused: Mapped[Optional[bool]] = mapped_column(sa.Boolean, nullable=True, server_default=sa.text("false")) paused_by = mapped_column(StringUUID, nullable=True) - paused_at = mapped_column(db.DateTime, nullable=True) + paused_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) # error - error = mapped_column(db.Text, nullable=True) - stopped_at = mapped_column(db.DateTime, nullable=True) + error = mapped_column(sa.Text, nullable=True) + stopped_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) # basic fields - indexing_status = mapped_column( - db.String(255), nullable=False, server_default=db.text("'waiting'::character varying") - ) - enabled = mapped_column(db.Boolean, nullable=False, server_default=db.text("true")) - disabled_at = mapped_column(db.DateTime, nullable=True) + indexing_status = mapped_column(String(255), nullable=False, server_default=sa.text("'waiting'::character varying")) + enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true")) + disabled_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) disabled_by = mapped_column(StringUUID, nullable=True) - archived = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) - archived_reason = mapped_column(db.String(255), nullable=True) + archived: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) + archived_reason = mapped_column(String(255), nullable=True) archived_by = mapped_column(StringUUID, nullable=True) - archived_at = mapped_column(db.DateTime, nullable=True) - updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - doc_type = mapped_column(db.String(40), nullable=True) + archived_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) + updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + doc_type = mapped_column(String(40), nullable=True) doc_metadata = mapped_column(JSONB, nullable=True) - doc_form = mapped_column(db.String(255), nullable=False, server_default=db.text("'text_model'::character varying")) - doc_language = mapped_column(db.String(255), nullable=True) + doc_form = mapped_column(String(255), nullable=False, server_default=sa.text("'text_model'::character varying")) + doc_language = mapped_column(String(255), nullable=True) DATA_SOURCES = ["upload_file", "notion_import", "website_crawl"] @@ -524,7 +523,7 @@ class Document(Base): "id": "built-in", "name": BuiltInField.upload_date, "type": "time", - "value": self.created_at.timestamp(), + "value": str(self.created_at.timestamp()), } ) built_in_fields.append( @@ -532,7 +531,7 @@ class Document(Base): "id": "built-in", "name": BuiltInField.last_update_date, "type": "time", - "value": self.updated_at.timestamp(), + "value": str(self.updated_at.timestamp()), } ) built_in_fields.append( @@ -645,45 +644,45 @@ class Document(Base): class DocumentSegment(Base): __tablename__ = "document_segments" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="document_segment_pkey"), - db.Index("document_segment_dataset_id_idx", "dataset_id"), - db.Index("document_segment_document_id_idx", "document_id"), - db.Index("document_segment_tenant_dataset_idx", "dataset_id", "tenant_id"), - db.Index("document_segment_tenant_document_idx", "document_id", "tenant_id"), - db.Index("document_segment_node_dataset_idx", "index_node_id", "dataset_id"), - db.Index("document_segment_tenant_idx", "tenant_id"), + sa.PrimaryKeyConstraint("id", name="document_segment_pkey"), + sa.Index("document_segment_dataset_id_idx", "dataset_id"), + sa.Index("document_segment_document_id_idx", "document_id"), + sa.Index("document_segment_tenant_dataset_idx", "dataset_id", "tenant_id"), + sa.Index("document_segment_tenant_document_idx", "document_id", "tenant_id"), + sa.Index("document_segment_node_dataset_idx", "index_node_id", "dataset_id"), + sa.Index("document_segment_tenant_idx", "tenant_id"), ) # initial fields - id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, nullable=False, server_default=sa.text("uuid_generate_v4()")) tenant_id = mapped_column(StringUUID, nullable=False) dataset_id = mapped_column(StringUUID, nullable=False) document_id = mapped_column(StringUUID, nullable=False) position: Mapped[int] - content = mapped_column(db.Text, nullable=False) - answer = mapped_column(db.Text, nullable=True) + content = mapped_column(sa.Text, nullable=False) + answer = mapped_column(sa.Text, nullable=True) word_count: Mapped[int] tokens: Mapped[int] # indexing fields - keywords = mapped_column(db.JSON, nullable=True) - index_node_id = mapped_column(db.String(255), nullable=True) - index_node_hash = mapped_column(db.String(255), nullable=True) + keywords = mapped_column(sa.JSON, nullable=True) + index_node_id = mapped_column(String(255), nullable=True) + index_node_hash = mapped_column(String(255), nullable=True) # basic fields - hit_count = mapped_column(db.Integer, nullable=False, default=0) - enabled = mapped_column(db.Boolean, nullable=False, server_default=db.text("true")) - disabled_at = mapped_column(db.DateTime, nullable=True) + hit_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0) + enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true")) + disabled_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) disabled_by = mapped_column(StringUUID, nullable=True) - status: Mapped[str] = mapped_column(db.String(255), server_default=db.text("'waiting'::character varying")) + status: Mapped[str] = mapped_column(String(255), server_default=sa.text("'waiting'::character varying")) created_by = mapped_column(StringUUID, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = mapped_column(StringUUID, nullable=True) - updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - indexing_at = mapped_column(db.DateTime, nullable=True) - completed_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime, nullable=True) - error = mapped_column(db.Text, nullable=True) - stopped_at = mapped_column(db.DateTime, nullable=True) + updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + indexing_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) + completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) + error = mapped_column(sa.Text, nullable=True) + stopped_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) @property def dataset(self): @@ -796,32 +795,36 @@ class DocumentSegment(Base): class ChildChunk(Base): __tablename__ = "child_chunks" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="child_chunk_pkey"), - db.Index("child_chunk_dataset_id_idx", "tenant_id", "dataset_id", "document_id", "segment_id", "index_node_id"), - db.Index("child_chunks_node_idx", "index_node_id", "dataset_id"), - db.Index("child_chunks_segment_idx", "segment_id"), + sa.PrimaryKeyConstraint("id", name="child_chunk_pkey"), + sa.Index("child_chunk_dataset_id_idx", "tenant_id", "dataset_id", "document_id", "segment_id", "index_node_id"), + sa.Index("child_chunks_node_idx", "index_node_id", "dataset_id"), + sa.Index("child_chunks_segment_idx", "segment_id"), ) # initial fields - id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, nullable=False, server_default=sa.text("uuid_generate_v4()")) tenant_id = mapped_column(StringUUID, nullable=False) dataset_id = mapped_column(StringUUID, nullable=False) document_id = mapped_column(StringUUID, nullable=False) segment_id = mapped_column(StringUUID, nullable=False) - position = mapped_column(db.Integer, nullable=False) - content = mapped_column(db.Text, nullable=False) - word_count = mapped_column(db.Integer, nullable=False) + position: Mapped[int] = mapped_column(sa.Integer, nullable=False) + content = mapped_column(sa.Text, nullable=False) + word_count: Mapped[int] = mapped_column(sa.Integer, nullable=False) # indexing fields - index_node_id = mapped_column(db.String(255), nullable=True) - index_node_hash = mapped_column(db.String(255), nullable=True) - type = mapped_column(db.String(255), nullable=False, server_default=db.text("'automatic'::character varying")) + index_node_id = mapped_column(String(255), nullable=True) + index_node_hash = mapped_column(String(255), nullable=True) + type = mapped_column(String(255), nullable=False, server_default=sa.text("'automatic'::character varying")) created_by = mapped_column(StringUUID, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at: Mapped[datetime] = mapped_column( + DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)") + ) updated_by = mapped_column(StringUUID, nullable=True) - updated_at = mapped_column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - indexing_at = mapped_column(db.DateTime, nullable=True) - completed_at = mapped_column(db.DateTime, nullable=True) - error = mapped_column(db.Text, nullable=True) + updated_at: Mapped[datetime] = mapped_column( + DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)") + ) + indexing_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) + completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) + error = mapped_column(sa.Text, nullable=True) @property def dataset(self): @@ -839,14 +842,14 @@ class ChildChunk(Base): class AppDatasetJoin(Base): __tablename__ = "app_dataset_joins" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="app_dataset_join_pkey"), - db.Index("app_dataset_join_app_dataset_idx", "dataset_id", "app_id"), + sa.PrimaryKeyConstraint("id", name="app_dataset_join_pkey"), + sa.Index("app_dataset_join_app_dataset_idx", "dataset_id", "app_id"), ) - id = mapped_column(StringUUID, primary_key=True, nullable=False, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, primary_key=True, nullable=False, server_default=sa.text("uuid_generate_v4()")) app_id = mapped_column(StringUUID, nullable=False) dataset_id = mapped_column(StringUUID, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=db.func.current_timestamp()) @property def app(self): @@ -856,32 +859,32 @@ class AppDatasetJoin(Base): class DatasetQuery(Base): __tablename__ = "dataset_queries" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="dataset_query_pkey"), - db.Index("dataset_query_dataset_id_idx", "dataset_id"), + sa.PrimaryKeyConstraint("id", name="dataset_query_pkey"), + sa.Index("dataset_query_dataset_id_idx", "dataset_id"), ) - id = mapped_column(StringUUID, primary_key=True, nullable=False, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, primary_key=True, nullable=False, server_default=sa.text("uuid_generate_v4()")) dataset_id = mapped_column(StringUUID, nullable=False) - content = mapped_column(db.Text, nullable=False) - source = mapped_column(db.String(255), nullable=False) + content = mapped_column(sa.Text, nullable=False) + source: Mapped[str] = mapped_column(String(255), nullable=False) source_app_id = mapped_column(StringUUID, nullable=True) - created_by_role = mapped_column(db.String, nullable=False) + created_by_role = mapped_column(String, nullable=False) created_by = mapped_column(StringUUID, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=db.func.current_timestamp()) class DatasetKeywordTable(Base): __tablename__ = "dataset_keyword_tables" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="dataset_keyword_table_pkey"), - db.Index("dataset_keyword_table_dataset_id_idx", "dataset_id"), + sa.PrimaryKeyConstraint("id", name="dataset_keyword_table_pkey"), + sa.Index("dataset_keyword_table_dataset_id_idx", "dataset_id"), ) - id = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, primary_key=True, server_default=sa.text("uuid_generate_v4()")) dataset_id = mapped_column(StringUUID, nullable=False, unique=True) - keyword_table = mapped_column(db.Text, nullable=False) + keyword_table = mapped_column(sa.Text, nullable=False) data_source_type = mapped_column( - db.String(255), nullable=False, server_default=db.text("'database'::character varying") + String(255), nullable=False, server_default=sa.text("'database'::character varying") ) @property @@ -911,26 +914,26 @@ class DatasetKeywordTable(Base): return json.loads(keyword_table_text.decode("utf-8"), cls=SetDecoder) return None except Exception as e: - logging.exception(f"Failed to load keyword table from file: {file_key}") + logging.exception("Failed to load keyword table from file: %s", file_key) return None class Embedding(Base): __tablename__ = "embeddings" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="embedding_pkey"), - db.UniqueConstraint("model_name", "hash", "provider_name", name="embedding_hash_idx"), - db.Index("created_at_idx", "created_at"), + sa.PrimaryKeyConstraint("id", name="embedding_pkey"), + sa.UniqueConstraint("model_name", "hash", "provider_name", name="embedding_hash_idx"), + sa.Index("created_at_idx", "created_at"), ) - id = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, primary_key=True, server_default=sa.text("uuid_generate_v4()")) model_name = mapped_column( - db.String(255), nullable=False, server_default=db.text("'text-embedding-ada-002'::character varying") + String(255), nullable=False, server_default=sa.text("'text-embedding-ada-002'::character varying") ) - hash = mapped_column(db.String(64), nullable=False) - embedding = mapped_column(db.LargeBinary, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - provider_name = mapped_column(db.String(255), nullable=False, server_default=db.text("''::character varying")) + hash = mapped_column(String(64), nullable=False) + embedding = mapped_column(sa.LargeBinary, nullable=False) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + provider_name = mapped_column(String(255), nullable=False, server_default=sa.text("''::character varying")) def set_embedding(self, embedding_data: list[float]): self.embedding = pickle.dumps(embedding_data, protocol=pickle.HIGHEST_PROTOCOL) @@ -942,84 +945,84 @@ class Embedding(Base): class DatasetCollectionBinding(Base): __tablename__ = "dataset_collection_bindings" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="dataset_collection_bindings_pkey"), - db.Index("provider_model_name_idx", "provider_name", "model_name"), + sa.PrimaryKeyConstraint("id", name="dataset_collection_bindings_pkey"), + sa.Index("provider_model_name_idx", "provider_name", "model_name"), ) - id = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) - provider_name = mapped_column(db.String(255), nullable=False) - model_name = mapped_column(db.String(255), nullable=False) - type = mapped_column(db.String(40), server_default=db.text("'dataset'::character varying"), nullable=False) - collection_name = mapped_column(db.String(64), nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + id = mapped_column(StringUUID, primary_key=True, server_default=sa.text("uuid_generate_v4()")) + provider_name: Mapped[str] = mapped_column(String(255), nullable=False) + model_name: Mapped[str] = mapped_column(String(255), nullable=False) + type = mapped_column(String(40), server_default=sa.text("'dataset'::character varying"), nullable=False) + collection_name = mapped_column(String(64), nullable=False) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) class TidbAuthBinding(Base): __tablename__ = "tidb_auth_bindings" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="tidb_auth_bindings_pkey"), - db.Index("tidb_auth_bindings_tenant_idx", "tenant_id"), - db.Index("tidb_auth_bindings_active_idx", "active"), - db.Index("tidb_auth_bindings_created_at_idx", "created_at"), - db.Index("tidb_auth_bindings_status_idx", "status"), + sa.PrimaryKeyConstraint("id", name="tidb_auth_bindings_pkey"), + sa.Index("tidb_auth_bindings_tenant_idx", "tenant_id"), + sa.Index("tidb_auth_bindings_active_idx", "active"), + sa.Index("tidb_auth_bindings_created_at_idx", "created_at"), + sa.Index("tidb_auth_bindings_status_idx", "status"), ) - id = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, primary_key=True, server_default=sa.text("uuid_generate_v4()")) tenant_id = mapped_column(StringUUID, nullable=True) - cluster_id = mapped_column(db.String(255), nullable=False) - cluster_name = mapped_column(db.String(255), nullable=False) - active = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) - status = mapped_column(db.String(255), nullable=False, server_default=db.text("CREATING")) - account = mapped_column(db.String(255), nullable=False) - password = mapped_column(db.String(255), nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + cluster_id: Mapped[str] = mapped_column(String(255), nullable=False) + cluster_name: Mapped[str] = mapped_column(String(255), nullable=False) + active: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=db.text("false")) + status = mapped_column(String(255), nullable=False, server_default=db.text("'CREATING'::character varying")) + account: Mapped[str] = mapped_column(String(255), nullable=False) + password: Mapped[str] = mapped_column(String(255), nullable=False) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) class Whitelist(Base): __tablename__ = "whitelists" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="whitelists_pkey"), - db.Index("whitelists_tenant_idx", "tenant_id"), + sa.PrimaryKeyConstraint("id", name="whitelists_pkey"), + sa.Index("whitelists_tenant_idx", "tenant_id"), ) - id = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, primary_key=True, server_default=sa.text("uuid_generate_v4()")) tenant_id = mapped_column(StringUUID, nullable=True) - category = mapped_column(db.String(255), nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + category: Mapped[str] = mapped_column(String(255), nullable=False) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) class DatasetPermission(Base): __tablename__ = "dataset_permissions" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="dataset_permission_pkey"), - db.Index("idx_dataset_permissions_dataset_id", "dataset_id"), - db.Index("idx_dataset_permissions_account_id", "account_id"), - db.Index("idx_dataset_permissions_tenant_id", "tenant_id"), + sa.PrimaryKeyConstraint("id", name="dataset_permission_pkey"), + sa.Index("idx_dataset_permissions_dataset_id", "dataset_id"), + sa.Index("idx_dataset_permissions_account_id", "account_id"), + sa.Index("idx_dataset_permissions_tenant_id", "tenant_id"), ) - id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"), primary_key=True) + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), primary_key=True) dataset_id = mapped_column(StringUUID, nullable=False) account_id = mapped_column(StringUUID, nullable=False) tenant_id = mapped_column(StringUUID, nullable=False) - has_permission = mapped_column(db.Boolean, nullable=False, server_default=db.text("true")) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + has_permission: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true")) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) class ExternalKnowledgeApis(Base): __tablename__ = "external_knowledge_apis" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="external_knowledge_apis_pkey"), - db.Index("external_knowledge_apis_tenant_idx", "tenant_id"), - db.Index("external_knowledge_apis_name_idx", "name"), + sa.PrimaryKeyConstraint("id", name="external_knowledge_apis_pkey"), + sa.Index("external_knowledge_apis_tenant_idx", "tenant_id"), + sa.Index("external_knowledge_apis_name_idx", "name"), ) - id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) - name = mapped_column(db.String(255), nullable=False) - description = mapped_column(db.String(255), nullable=False) + id = mapped_column(StringUUID, nullable=False, server_default=sa.text("uuid_generate_v4()")) + name: Mapped[str] = mapped_column(String(255), nullable=False) + description: Mapped[str] = mapped_column(String(255), nullable=False) tenant_id = mapped_column(StringUUID, nullable=False) - settings = mapped_column(db.Text, nullable=True) + settings = mapped_column(sa.Text, nullable=True) created_by = mapped_column(StringUUID, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = mapped_column(StringUUID, nullable=True) - updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) def to_dict(self): return { @@ -1059,71 +1062,79 @@ class ExternalKnowledgeApis(Base): class ExternalKnowledgeBindings(Base): __tablename__ = "external_knowledge_bindings" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="external_knowledge_bindings_pkey"), - db.Index("external_knowledge_bindings_tenant_idx", "tenant_id"), - db.Index("external_knowledge_bindings_dataset_idx", "dataset_id"), - db.Index("external_knowledge_bindings_external_knowledge_idx", "external_knowledge_id"), - db.Index("external_knowledge_bindings_external_knowledge_api_idx", "external_knowledge_api_id"), + sa.PrimaryKeyConstraint("id", name="external_knowledge_bindings_pkey"), + sa.Index("external_knowledge_bindings_tenant_idx", "tenant_id"), + sa.Index("external_knowledge_bindings_dataset_idx", "dataset_id"), + sa.Index("external_knowledge_bindings_external_knowledge_idx", "external_knowledge_id"), + sa.Index("external_knowledge_bindings_external_knowledge_api_idx", "external_knowledge_api_id"), ) - id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, nullable=False, server_default=sa.text("uuid_generate_v4()")) tenant_id = mapped_column(StringUUID, nullable=False) external_knowledge_api_id = mapped_column(StringUUID, nullable=False) dataset_id = mapped_column(StringUUID, nullable=False) - external_knowledge_id = mapped_column(db.Text, nullable=False) + external_knowledge_id = mapped_column(sa.Text, nullable=False) created_by = mapped_column(StringUUID, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = mapped_column(StringUUID, nullable=True) - updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) class DatasetAutoDisableLog(Base): __tablename__ = "dataset_auto_disable_logs" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="dataset_auto_disable_log_pkey"), - db.Index("dataset_auto_disable_log_tenant_idx", "tenant_id"), - db.Index("dataset_auto_disable_log_dataset_idx", "dataset_id"), - db.Index("dataset_auto_disable_log_created_atx", "created_at"), + sa.PrimaryKeyConstraint("id", name="dataset_auto_disable_log_pkey"), + sa.Index("dataset_auto_disable_log_tenant_idx", "tenant_id"), + sa.Index("dataset_auto_disable_log_dataset_idx", "dataset_id"), + sa.Index("dataset_auto_disable_log_created_atx", "created_at"), ) - id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) tenant_id = mapped_column(StringUUID, nullable=False) dataset_id = mapped_column(StringUUID, nullable=False) document_id = mapped_column(StringUUID, nullable=False) - notified = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) - created_at = mapped_column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + notified: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) + created_at: Mapped[datetime] = mapped_column( + DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)") + ) class RateLimitLog(Base): __tablename__ = "rate_limit_logs" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="rate_limit_log_pkey"), - db.Index("rate_limit_log_tenant_idx", "tenant_id"), - db.Index("rate_limit_log_operation_idx", "operation"), + sa.PrimaryKeyConstraint("id", name="rate_limit_log_pkey"), + sa.Index("rate_limit_log_tenant_idx", "tenant_id"), + sa.Index("rate_limit_log_operation_idx", "operation"), ) - id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) tenant_id = mapped_column(StringUUID, nullable=False) - subscription_plan = mapped_column(db.String(255), nullable=False) - operation = mapped_column(db.String(255), nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + subscription_plan: Mapped[str] = mapped_column(String(255), nullable=False) + operation: Mapped[str] = mapped_column(String(255), nullable=False) + created_at: Mapped[datetime] = mapped_column( + DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)") + ) class DatasetMetadata(Base): __tablename__ = "dataset_metadatas" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="dataset_metadata_pkey"), - db.Index("dataset_metadata_tenant_idx", "tenant_id"), - db.Index("dataset_metadata_dataset_idx", "dataset_id"), + sa.PrimaryKeyConstraint("id", name="dataset_metadata_pkey"), + sa.Index("dataset_metadata_tenant_idx", "tenant_id"), + sa.Index("dataset_metadata_dataset_idx", "dataset_id"), ) - id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) tenant_id = mapped_column(StringUUID, nullable=False) dataset_id = mapped_column(StringUUID, nullable=False) - type = mapped_column(db.String(255), nullable=False) - name = mapped_column(db.String(255), nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = mapped_column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + type: Mapped[str] = mapped_column(String(255), nullable=False) + name: Mapped[str] = mapped_column(String(255), nullable=False) + created_at: Mapped[datetime] = mapped_column( + DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)") + ) + updated_at: Mapped[datetime] = mapped_column( + DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)") + ) created_by = mapped_column(StringUUID, nullable=False) updated_by = mapped_column(StringUUID, nullable=True) @@ -1131,17 +1142,17 @@ class DatasetMetadata(Base): class DatasetMetadataBinding(Base): __tablename__ = "dataset_metadata_bindings" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="dataset_metadata_binding_pkey"), - db.Index("dataset_metadata_binding_tenant_idx", "tenant_id"), - db.Index("dataset_metadata_binding_dataset_idx", "dataset_id"), - db.Index("dataset_metadata_binding_metadata_idx", "metadata_id"), - db.Index("dataset_metadata_binding_document_idx", "document_id"), + sa.PrimaryKeyConstraint("id", name="dataset_metadata_binding_pkey"), + sa.Index("dataset_metadata_binding_tenant_idx", "tenant_id"), + sa.Index("dataset_metadata_binding_dataset_idx", "dataset_id"), + sa.Index("dataset_metadata_binding_metadata_idx", "metadata_id"), + sa.Index("dataset_metadata_binding_document_idx", "document_id"), ) - id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) tenant_id = mapped_column(StringUUID, nullable=False) dataset_id = mapped_column(StringUUID, nullable=False) metadata_id = mapped_column(StringUUID, nullable=False) document_id = mapped_column(StringUUID, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) created_by = mapped_column(StringUUID, nullable=False) diff --git a/api/models/model.py b/api/models/model.py index a78a91ebd5..c4303f3cc5 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -17,7 +17,7 @@ if TYPE_CHECKING: import sqlalchemy as sa from flask import request from flask_login import UserMixin -from sqlalchemy import Float, Index, PrimaryKeyConstraint, func, text +from sqlalchemy import Float, Index, PrimaryKeyConstraint, String, func, text from sqlalchemy.orm import Mapped, Session, mapped_column from configs import dify_config @@ -32,16 +32,13 @@ from .engine import db from .enums import CreatorUserRole from .types import StringUUID -if TYPE_CHECKING: - from .workflow import Workflow - class DifySetup(Base): __tablename__ = "dify_setups" - __table_args__ = (db.PrimaryKeyConstraint("version", name="dify_setup_pkey"),) + __table_args__ = (sa.PrimaryKeyConstraint("version", name="dify_setup_pkey"),) - version = mapped_column(db.String(255), nullable=False) - setup_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + version: Mapped[str] = mapped_column(String(255), nullable=False) + setup_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) class AppMode(StrEnum): @@ -72,33 +69,33 @@ class IconType(Enum): class App(Base): __tablename__ = "apps" - __table_args__ = (db.PrimaryKeyConstraint("id", name="app_pkey"), db.Index("app_tenant_id_idx", "tenant_id")) + __table_args__ = (sa.PrimaryKeyConstraint("id", name="app_pkey"), sa.Index("app_tenant_id_idx", "tenant_id")) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID) - name: Mapped[str] = mapped_column(db.String(255)) - description: Mapped[str] = mapped_column(db.Text, server_default=db.text("''::character varying")) - mode: Mapped[str] = mapped_column(db.String(255)) - icon_type: Mapped[Optional[str]] = mapped_column(db.String(255)) # image, emoji - icon = db.Column(db.String(255)) - icon_background: Mapped[Optional[str]] = mapped_column(db.String(255)) + name: Mapped[str] = mapped_column(String(255)) + description: Mapped[str] = mapped_column(sa.Text, server_default=sa.text("''::character varying")) + mode: Mapped[str] = mapped_column(String(255)) + icon_type: Mapped[Optional[str]] = mapped_column(String(255)) # image, emoji + icon = db.Column(String(255)) + icon_background: Mapped[Optional[str]] = mapped_column(String(255)) app_model_config_id = mapped_column(StringUUID, nullable=True) workflow_id = mapped_column(StringUUID, nullable=True) - status: Mapped[str] = mapped_column(db.String(255), server_default=db.text("'normal'::character varying")) - enable_site: Mapped[bool] = mapped_column(db.Boolean) - enable_api: Mapped[bool] = mapped_column(db.Boolean) - api_rpm: Mapped[int] = mapped_column(db.Integer, server_default=db.text("0")) - api_rph: Mapped[int] = mapped_column(db.Integer, server_default=db.text("0")) - is_demo: Mapped[bool] = mapped_column(db.Boolean, server_default=db.text("false")) - is_public: Mapped[bool] = mapped_column(db.Boolean, server_default=db.text("false")) - is_universal: Mapped[bool] = mapped_column(db.Boolean, server_default=db.text("false")) - tracing = mapped_column(db.Text, nullable=True) + status: Mapped[str] = mapped_column(String(255), server_default=sa.text("'normal'::character varying")) + enable_site: Mapped[bool] = mapped_column(sa.Boolean) + enable_api: Mapped[bool] = mapped_column(sa.Boolean) + api_rpm: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0")) + api_rph: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0")) + is_demo: Mapped[bool] = mapped_column(sa.Boolean, server_default=sa.text("false")) + is_public: Mapped[bool] = mapped_column(sa.Boolean, server_default=sa.text("false")) + is_universal: Mapped[bool] = mapped_column(sa.Boolean, server_default=sa.text("false")) + tracing = mapped_column(sa.Text, nullable=True) max_active_requests: Mapped[Optional[int]] created_by = mapped_column(StringUUID, nullable=True) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = mapped_column(StringUUID, nullable=True) - updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - use_icon_as_answer_icon: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) + updated_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + use_icon_as_answer_icon: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) @property def desc_or_prompt(self): @@ -305,36 +302,36 @@ class App(Base): class AppModelConfig(Base): __tablename__ = "app_model_configs" - __table_args__ = (db.PrimaryKeyConstraint("id", name="app_model_config_pkey"), db.Index("app_app_id_idx", "app_id")) + __table_args__ = (sa.PrimaryKeyConstraint("id", name="app_model_config_pkey"), sa.Index("app_app_id_idx", "app_id")) - id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) app_id = mapped_column(StringUUID, nullable=False) - provider = mapped_column(db.String(255), nullable=True) - model_id = mapped_column(db.String(255), nullable=True) - configs = mapped_column(db.JSON, nullable=True) + provider = mapped_column(String(255), nullable=True) + model_id = mapped_column(String(255), nullable=True) + configs = mapped_column(sa.JSON, nullable=True) created_by = mapped_column(StringUUID, nullable=True) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = mapped_column(StringUUID, nullable=True) - updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - opening_statement = mapped_column(db.Text) - suggested_questions = mapped_column(db.Text) - suggested_questions_after_answer = mapped_column(db.Text) - speech_to_text = mapped_column(db.Text) - text_to_speech = mapped_column(db.Text) - more_like_this = mapped_column(db.Text) - model = mapped_column(db.Text) - user_input_form = mapped_column(db.Text) - dataset_query_variable = mapped_column(db.String(255)) - pre_prompt = mapped_column(db.Text) - agent_mode = mapped_column(db.Text) - sensitive_word_avoidance = mapped_column(db.Text) - retriever_resource = mapped_column(db.Text) - prompt_type = mapped_column(db.String(255), nullable=False, server_default=db.text("'simple'::character varying")) - chat_prompt_config = mapped_column(db.Text) - completion_prompt_config = mapped_column(db.Text) - dataset_configs = mapped_column(db.Text) - external_data_tools = mapped_column(db.Text) - file_upload = mapped_column(db.Text) + updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + opening_statement = mapped_column(sa.Text) + suggested_questions = mapped_column(sa.Text) + suggested_questions_after_answer = mapped_column(sa.Text) + speech_to_text = mapped_column(sa.Text) + text_to_speech = mapped_column(sa.Text) + more_like_this = mapped_column(sa.Text) + model = mapped_column(sa.Text) + user_input_form = mapped_column(sa.Text) + dataset_query_variable = mapped_column(String(255)) + pre_prompt = mapped_column(sa.Text) + agent_mode = mapped_column(sa.Text) + sensitive_word_avoidance = mapped_column(sa.Text) + retriever_resource = mapped_column(sa.Text) + prompt_type = mapped_column(String(255), nullable=False, server_default=sa.text("'simple'::character varying")) + chat_prompt_config = mapped_column(sa.Text) + completion_prompt_config = mapped_column(sa.Text) + dataset_configs = mapped_column(sa.Text) + external_data_tools = mapped_column(sa.Text) + file_upload = mapped_column(sa.Text) @property def app(self): @@ -556,24 +553,24 @@ class AppModelConfig(Base): class RecommendedApp(Base): __tablename__ = "recommended_apps" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="recommended_app_pkey"), - db.Index("recommended_app_app_id_idx", "app_id"), - db.Index("recommended_app_is_listed_idx", "is_listed", "language"), + sa.PrimaryKeyConstraint("id", name="recommended_app_pkey"), + sa.Index("recommended_app_app_id_idx", "app_id"), + sa.Index("recommended_app_is_listed_idx", "is_listed", "language"), ) - id = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, primary_key=True, server_default=sa.text("uuid_generate_v4()")) app_id = mapped_column(StringUUID, nullable=False) - description = mapped_column(db.JSON, nullable=False) - copyright = mapped_column(db.String(255), nullable=False) - privacy_policy = mapped_column(db.String(255), nullable=False) + description = mapped_column(sa.JSON, nullable=False) + copyright: Mapped[str] = mapped_column(String(255), nullable=False) + privacy_policy: Mapped[str] = mapped_column(String(255), nullable=False) custom_disclaimer: Mapped[str] = mapped_column(sa.TEXT, default="") - category = mapped_column(db.String(255), nullable=False) - position = mapped_column(db.Integer, nullable=False, default=0) - is_listed = mapped_column(db.Boolean, nullable=False, default=True) - install_count = mapped_column(db.Integer, nullable=False, default=0) - language = mapped_column(db.String(255), nullable=False, server_default=db.text("'en-US'::character varying")) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + category: Mapped[str] = mapped_column(String(255), nullable=False) + position: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0) + is_listed: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, default=True) + install_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0) + language = mapped_column(String(255), nullable=False, server_default=sa.text("'en-US'::character varying")) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) @property def app(self): @@ -584,20 +581,20 @@ class RecommendedApp(Base): class InstalledApp(Base): __tablename__ = "installed_apps" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="installed_app_pkey"), - db.Index("installed_app_tenant_id_idx", "tenant_id"), - db.Index("installed_app_app_id_idx", "app_id"), - db.UniqueConstraint("tenant_id", "app_id", name="unique_tenant_app"), + sa.PrimaryKeyConstraint("id", name="installed_app_pkey"), + sa.Index("installed_app_tenant_id_idx", "tenant_id"), + sa.Index("installed_app_app_id_idx", "app_id"), + sa.UniqueConstraint("tenant_id", "app_id", name="unique_tenant_app"), ) - id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) tenant_id = mapped_column(StringUUID, nullable=False) app_id = mapped_column(StringUUID, nullable=False) app_owner_tenant_id = mapped_column(StringUUID, nullable=False) - position = mapped_column(db.Integer, nullable=False, default=0) - is_pinned = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) - last_used_at = mapped_column(db.DateTime, nullable=True) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + position: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0) + is_pinned: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) + last_used_at = mapped_column(sa.DateTime, nullable=True) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) @property def app(self): @@ -613,47 +610,47 @@ class InstalledApp(Base): class Conversation(Base): __tablename__ = "conversations" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="conversation_pkey"), - db.Index("conversation_app_from_user_idx", "app_id", "from_source", "from_end_user_id"), + sa.PrimaryKeyConstraint("id", name="conversation_pkey"), + sa.Index("conversation_app_from_user_idx", "app_id", "from_source", "from_end_user_id"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) app_id = mapped_column(StringUUID, nullable=False) app_model_config_id = mapped_column(StringUUID, nullable=True) - model_provider = mapped_column(db.String(255), nullable=True) - override_model_configs = mapped_column(db.Text) - model_id = mapped_column(db.String(255), nullable=True) - mode: Mapped[str] = mapped_column(db.String(255)) - name = mapped_column(db.String(255), nullable=False) - summary = mapped_column(db.Text) - _inputs: Mapped[dict] = mapped_column("inputs", db.JSON) - introduction = mapped_column(db.Text) - system_instruction = mapped_column(db.Text) - system_instruction_tokens = mapped_column(db.Integer, nullable=False, server_default=db.text("0")) - status = mapped_column(db.String(255), nullable=False) + model_provider = mapped_column(String(255), nullable=True) + override_model_configs = mapped_column(sa.Text) + model_id = mapped_column(String(255), nullable=True) + mode: Mapped[str] = mapped_column(String(255)) + name: Mapped[str] = mapped_column(String(255), nullable=False) + summary = mapped_column(sa.Text) + _inputs: Mapped[dict] = mapped_column("inputs", sa.JSON) + introduction = mapped_column(sa.Text) + system_instruction = mapped_column(sa.Text) + system_instruction_tokens: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0")) + status: Mapped[str] = mapped_column(String(255), nullable=False) # The `invoke_from` records how the conversation is created. # # Its value corresponds to the members of `InvokeFrom`. # (api/core/app/entities/app_invoke_entities.py) - invoke_from = mapped_column(db.String(255), nullable=True) + invoke_from = mapped_column(String(255), nullable=True) # ref: ConversationSource. - from_source = mapped_column(db.String(255), nullable=False) + from_source: Mapped[str] = mapped_column(String(255), nullable=False) from_end_user_id = mapped_column(StringUUID) from_account_id = mapped_column(StringUUID) - read_at = mapped_column(db.DateTime) + read_at = mapped_column(sa.DateTime) read_account_id = mapped_column(StringUUID) dialogue_count: Mapped[int] = mapped_column(default=0) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) messages = db.relationship("Message", backref="conversation", lazy="select", passive_deletes="all") message_annotations = db.relationship( "MessageAnnotation", backref="conversation", lazy="select", passive_deletes="all" ) - is_deleted = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) + is_deleted: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) @property def inputs(self): @@ -895,36 +892,36 @@ class Message(Base): Index("message_created_at_idx", "created_at"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) app_id = mapped_column(StringUUID, nullable=False) - model_provider = mapped_column(db.String(255), nullable=True) - model_id = mapped_column(db.String(255), nullable=True) - override_model_configs = mapped_column(db.Text) - conversation_id = mapped_column(StringUUID, db.ForeignKey("conversations.id"), nullable=False) - _inputs: Mapped[dict] = mapped_column("inputs", db.JSON) - query: Mapped[str] = mapped_column(db.Text, nullable=False) - message = mapped_column(db.JSON, nullable=False) - message_tokens: Mapped[int] = mapped_column(db.Integer, nullable=False, server_default=db.text("0")) - message_unit_price = mapped_column(db.Numeric(10, 4), nullable=False) - message_price_unit = mapped_column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001")) - answer: Mapped[str] = db.Column(db.Text, nullable=False) # TODO make it mapped_column - answer_tokens: Mapped[int] = mapped_column(db.Integer, nullable=False, server_default=db.text("0")) - answer_unit_price = mapped_column(db.Numeric(10, 4), nullable=False) - answer_price_unit = mapped_column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001")) + model_provider = mapped_column(String(255), nullable=True) + model_id = mapped_column(String(255), nullable=True) + override_model_configs = mapped_column(sa.Text) + conversation_id = mapped_column(StringUUID, sa.ForeignKey("conversations.id"), nullable=False) + _inputs: Mapped[dict] = mapped_column("inputs", sa.JSON) + query: Mapped[str] = mapped_column(sa.Text, nullable=False) + message = mapped_column(sa.JSON, nullable=False) + message_tokens: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0")) + message_unit_price = mapped_column(sa.Numeric(10, 4), nullable=False) + message_price_unit = mapped_column(sa.Numeric(10, 7), nullable=False, server_default=sa.text("0.001")) + answer: Mapped[str] = db.Column(sa.Text, nullable=False) # TODO make it mapped_column + answer_tokens: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0")) + answer_unit_price = mapped_column(sa.Numeric(10, 4), nullable=False) + answer_price_unit = mapped_column(sa.Numeric(10, 7), nullable=False, server_default=sa.text("0.001")) parent_message_id = mapped_column(StringUUID, nullable=True) - provider_response_latency = mapped_column(db.Float, nullable=False, server_default=db.text("0")) - total_price = mapped_column(db.Numeric(10, 7)) - currency = mapped_column(db.String(255), nullable=False) - status = mapped_column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying")) - error = mapped_column(db.Text) - message_metadata = mapped_column(db.Text) - invoke_from: Mapped[Optional[str]] = mapped_column(db.String(255), nullable=True) - from_source = mapped_column(db.String(255), nullable=False) + provider_response_latency = mapped_column(sa.Float, nullable=False, server_default=sa.text("0")) + total_price = mapped_column(sa.Numeric(10, 7)) + currency: Mapped[str] = mapped_column(String(255), nullable=False) + status = mapped_column(String(255), nullable=False, server_default=sa.text("'normal'::character varying")) + error = mapped_column(sa.Text) + message_metadata = mapped_column(sa.Text) + invoke_from: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) + from_source: Mapped[str] = mapped_column(String(255), nullable=False) from_end_user_id: Mapped[Optional[str]] = mapped_column(StringUUID) from_account_id: Mapped[Optional[str]] = mapped_column(StringUUID) - created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp()) - updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - agent_based = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) + created_at: Mapped[datetime] = mapped_column(sa.DateTime, server_default=func.current_timestamp()) + updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + agent_based: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) workflow_run_id: Mapped[Optional[str]] = mapped_column(StringUUID) @property @@ -1231,23 +1228,23 @@ class Message(Base): class MessageFeedback(Base): __tablename__ = "message_feedbacks" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="message_feedback_pkey"), - db.Index("message_feedback_app_idx", "app_id"), - db.Index("message_feedback_message_idx", "message_id", "from_source"), - db.Index("message_feedback_conversation_idx", "conversation_id", "from_source", "rating"), + sa.PrimaryKeyConstraint("id", name="message_feedback_pkey"), + sa.Index("message_feedback_app_idx", "app_id"), + sa.Index("message_feedback_message_idx", "message_id", "from_source"), + sa.Index("message_feedback_conversation_idx", "conversation_id", "from_source", "rating"), ) - id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) app_id = mapped_column(StringUUID, nullable=False) conversation_id = mapped_column(StringUUID, nullable=False) message_id = mapped_column(StringUUID, nullable=False) - rating = mapped_column(db.String(255), nullable=False) - content = mapped_column(db.Text) - from_source = mapped_column(db.String(255), nullable=False) + rating: Mapped[str] = mapped_column(String(255), nullable=False) + content = mapped_column(sa.Text) + from_source: Mapped[str] = mapped_column(String(255), nullable=False) from_end_user_id = mapped_column(StringUUID) from_account_id = mapped_column(StringUUID) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) @property def from_account(self): @@ -1273,9 +1270,9 @@ class MessageFeedback(Base): class MessageFile(Base): __tablename__ = "message_files" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="message_file_pkey"), - db.Index("message_file_message_idx", "message_id"), - db.Index("message_file_created_by_idx", "created_by"), + sa.PrimaryKeyConstraint("id", name="message_file_pkey"), + sa.Index("message_file_message_idx", "message_id"), + sa.Index("message_file_created_by_idx", "created_by"), ) def __init__( @@ -1299,37 +1296,37 @@ class MessageFile(Base): self.created_by_role = created_by_role.value self.created_by = created_by - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) message_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - type: Mapped[str] = mapped_column(db.String(255), nullable=False) - transfer_method: Mapped[str] = mapped_column(db.String(255), nullable=False) - url: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True) - belongs_to: Mapped[Optional[str]] = mapped_column(db.String(255), nullable=True) + type: Mapped[str] = mapped_column(String(255), nullable=False) + transfer_method: Mapped[str] = mapped_column(String(255), nullable=False) + url: Mapped[Optional[str]] = mapped_column(sa.Text, nullable=True) + belongs_to: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) upload_file_id: Mapped[Optional[str]] = mapped_column(StringUUID, nullable=True) - created_by_role: Mapped[str] = mapped_column(db.String(255), nullable=False) + created_by_role: Mapped[str] = mapped_column(String(255), nullable=False) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) - created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) class MessageAnnotation(Base): __tablename__ = "message_annotations" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="message_annotation_pkey"), - db.Index("message_annotation_app_idx", "app_id"), - db.Index("message_annotation_conversation_idx", "conversation_id"), - db.Index("message_annotation_message_idx", "message_id"), + sa.PrimaryKeyConstraint("id", name="message_annotation_pkey"), + sa.Index("message_annotation_app_idx", "app_id"), + sa.Index("message_annotation_conversation_idx", "conversation_id"), + sa.Index("message_annotation_message_idx", "message_id"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) app_id: Mapped[str] = mapped_column(StringUUID) - conversation_id: Mapped[Optional[str]] = mapped_column(StringUUID, db.ForeignKey("conversations.id")) + conversation_id: Mapped[Optional[str]] = mapped_column(StringUUID, sa.ForeignKey("conversations.id")) message_id: Mapped[Optional[str]] = mapped_column(StringUUID) - question = db.Column(db.Text, nullable=True) - content = mapped_column(db.Text, nullable=False) - hit_count = mapped_column(db.Integer, nullable=False, server_default=db.text("0")) + question = db.Column(sa.Text, nullable=True) + content = mapped_column(sa.Text, nullable=False) + hit_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0")) account_id = mapped_column(StringUUID, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) @property def account(self): @@ -1345,24 +1342,24 @@ class MessageAnnotation(Base): class AppAnnotationHitHistory(Base): __tablename__ = "app_annotation_hit_histories" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="app_annotation_hit_histories_pkey"), - db.Index("app_annotation_hit_histories_app_idx", "app_id"), - db.Index("app_annotation_hit_histories_account_idx", "account_id"), - db.Index("app_annotation_hit_histories_annotation_idx", "annotation_id"), - db.Index("app_annotation_hit_histories_message_idx", "message_id"), + sa.PrimaryKeyConstraint("id", name="app_annotation_hit_histories_pkey"), + sa.Index("app_annotation_hit_histories_app_idx", "app_id"), + sa.Index("app_annotation_hit_histories_account_idx", "account_id"), + sa.Index("app_annotation_hit_histories_annotation_idx", "annotation_id"), + sa.Index("app_annotation_hit_histories_message_idx", "message_id"), ) - id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) app_id = mapped_column(StringUUID, nullable=False) annotation_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - source = mapped_column(db.Text, nullable=False) - question = mapped_column(db.Text, nullable=False) + source = mapped_column(sa.Text, nullable=False) + question = mapped_column(sa.Text, nullable=False) account_id = mapped_column(StringUUID, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - score = mapped_column(Float, nullable=False, server_default=db.text("0")) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + score = mapped_column(Float, nullable=False, server_default=sa.text("0")) message_id = mapped_column(StringUUID, nullable=False) - annotation_question = mapped_column(db.Text, nullable=False) - annotation_content = mapped_column(db.Text, nullable=False) + annotation_question = mapped_column(sa.Text, nullable=False) + annotation_content = mapped_column(sa.Text, nullable=False) @property def account(self): @@ -1383,18 +1380,18 @@ class AppAnnotationHitHistory(Base): class AppAnnotationSetting(Base): __tablename__ = "app_annotation_settings" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="app_annotation_settings_pkey"), - db.Index("app_annotation_settings_app_idx", "app_id"), + sa.PrimaryKeyConstraint("id", name="app_annotation_settings_pkey"), + sa.Index("app_annotation_settings_app_idx", "app_id"), ) - id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) app_id = mapped_column(StringUUID, nullable=False) - score_threshold = mapped_column(Float, nullable=False, server_default=db.text("0")) + score_threshold = mapped_column(Float, nullable=False, server_default=sa.text("0")) collection_binding_id = mapped_column(StringUUID, nullable=False) created_user_id = mapped_column(StringUUID, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) updated_user_id = mapped_column(StringUUID, nullable=False) - updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) @property def collection_binding_detail(self): @@ -1411,58 +1408,58 @@ class AppAnnotationSetting(Base): class OperationLog(Base): __tablename__ = "operation_logs" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="operation_log_pkey"), - db.Index("operation_log_account_action_idx", "tenant_id", "account_id", "action"), + sa.PrimaryKeyConstraint("id", name="operation_log_pkey"), + sa.Index("operation_log_account_action_idx", "tenant_id", "account_id", "action"), ) - id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) tenant_id = mapped_column(StringUUID, nullable=False) account_id = mapped_column(StringUUID, nullable=False) - action = mapped_column(db.String(255), nullable=False) - content = mapped_column(db.JSON) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - created_ip = mapped_column(db.String(255), nullable=False) - updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + action: Mapped[str] = mapped_column(String(255), nullable=False) + content = mapped_column(sa.JSON) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + created_ip: Mapped[str] = mapped_column(String(255), nullable=False) + updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) class EndUser(Base, UserMixin): __tablename__ = "end_users" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="end_user_pkey"), - db.Index("end_user_session_id_idx", "session_id", "type"), - db.Index("end_user_tenant_session_id_idx", "tenant_id", "session_id", "type"), + sa.PrimaryKeyConstraint("id", name="end_user_pkey"), + sa.Index("end_user_session_id_idx", "session_id", "type"), + sa.Index("end_user_tenant_session_id_idx", "tenant_id", "session_id", "type"), ) - id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) app_id = mapped_column(StringUUID, nullable=True) - type = mapped_column(db.String(255), nullable=False) - external_user_id = mapped_column(db.String(255), nullable=True) - name = mapped_column(db.String(255)) - is_anonymous = mapped_column(db.Boolean, nullable=False, server_default=db.text("true")) + type: Mapped[str] = mapped_column(String(255), nullable=False) + external_user_id = mapped_column(String(255), nullable=True) + name = mapped_column(String(255)) + is_anonymous: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true")) session_id: Mapped[str] = mapped_column() - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) class AppMCPServer(Base): __tablename__ = "app_mcp_servers" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="app_mcp_server_pkey"), - db.UniqueConstraint("tenant_id", "app_id", name="unique_app_mcp_server_tenant_app_id"), - db.UniqueConstraint("server_code", name="unique_app_mcp_server_server_code"), + sa.PrimaryKeyConstraint("id", name="app_mcp_server_pkey"), + sa.UniqueConstraint("tenant_id", "app_id", name="unique_app_mcp_server_tenant_app_id"), + sa.UniqueConstraint("server_code", name="unique_app_mcp_server_server_code"), ) - id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) tenant_id = mapped_column(StringUUID, nullable=False) app_id = mapped_column(StringUUID, nullable=False) - name = mapped_column(db.String(255), nullable=False) - description = mapped_column(db.String(255), nullable=False) - server_code = mapped_column(db.String(255), nullable=False) - status = mapped_column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying")) - parameters = mapped_column(db.Text, nullable=False) + name: Mapped[str] = mapped_column(String(255), nullable=False) + description: Mapped[str] = mapped_column(String(255), nullable=False) + server_code: Mapped[str] = mapped_column(String(255), nullable=False) + status = mapped_column(String(255), nullable=False, server_default=sa.text("'normal'::character varying")) + parameters = mapped_column(sa.Text, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) @staticmethod def generate_server_code(n): @@ -1481,35 +1478,35 @@ class AppMCPServer(Base): class Site(Base): __tablename__ = "sites" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="site_pkey"), - db.Index("site_app_id_idx", "app_id"), - db.Index("site_code_idx", "code", "status"), + sa.PrimaryKeyConstraint("id", name="site_pkey"), + sa.Index("site_app_id_idx", "app_id"), + sa.Index("site_code_idx", "code", "status"), ) - id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) app_id = mapped_column(StringUUID, nullable=False) - title = mapped_column(db.String(255), nullable=False) - icon_type = mapped_column(db.String(255), nullable=True) - icon = mapped_column(db.String(255)) - icon_background = mapped_column(db.String(255)) - description = mapped_column(db.Text) - default_language = mapped_column(db.String(255), nullable=False) - chat_color_theme = mapped_column(db.String(255)) - chat_color_theme_inverted = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) - copyright = mapped_column(db.String(255)) - privacy_policy = mapped_column(db.String(255)) - show_workflow_steps = mapped_column(db.Boolean, nullable=False, server_default=db.text("true")) - use_icon_as_answer_icon = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) + title: Mapped[str] = mapped_column(String(255), nullable=False) + icon_type = mapped_column(String(255), nullable=True) + icon = mapped_column(String(255)) + icon_background = mapped_column(String(255)) + description = mapped_column(sa.Text) + default_language: Mapped[str] = mapped_column(String(255), nullable=False) + chat_color_theme = mapped_column(String(255)) + chat_color_theme_inverted: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) + copyright = mapped_column(String(255)) + privacy_policy = mapped_column(String(255)) + show_workflow_steps: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true")) + use_icon_as_answer_icon: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) _custom_disclaimer: Mapped[str] = mapped_column("custom_disclaimer", sa.TEXT, default="") - customize_domain = mapped_column(db.String(255)) - customize_token_strategy = mapped_column(db.String(255), nullable=False) - prompt_public = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) - status = mapped_column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying")) + customize_domain = mapped_column(String(255)) + customize_token_strategy: Mapped[str] = mapped_column(String(255), nullable=False) + prompt_public: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) + status = mapped_column(String(255), nullable=False, server_default=sa.text("'normal'::character varying")) created_by = mapped_column(StringUUID, nullable=True) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = mapped_column(StringUUID, nullable=True) - updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - code = mapped_column(db.String(255)) + updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + code = mapped_column(String(255)) @property def custom_disclaimer(self): @@ -1538,19 +1535,19 @@ class Site(Base): class ApiToken(Base): __tablename__ = "api_tokens" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="api_token_pkey"), - db.Index("api_token_app_id_type_idx", "app_id", "type"), - db.Index("api_token_token_idx", "token", "type"), - db.Index("api_token_tenant_idx", "tenant_id", "type"), + sa.PrimaryKeyConstraint("id", name="api_token_pkey"), + sa.Index("api_token_app_id_type_idx", "app_id", "type"), + sa.Index("api_token_token_idx", "token", "type"), + sa.Index("api_token_tenant_idx", "tenant_id", "type"), ) - id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) app_id = mapped_column(StringUUID, nullable=True) tenant_id = mapped_column(StringUUID, nullable=True) - type = mapped_column(db.String(16), nullable=False) - token = mapped_column(db.String(255), nullable=False) - last_used_at = mapped_column(db.DateTime, nullable=True) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + type = mapped_column(String(16), nullable=False) + token: Mapped[str] = mapped_column(String(255), nullable=False) + last_used_at = mapped_column(sa.DateTime, nullable=True) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) @staticmethod def generate_api_key(prefix, n): @@ -1564,27 +1561,27 @@ class ApiToken(Base): class UploadFile(Base): __tablename__ = "upload_files" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="upload_file_pkey"), - db.Index("upload_file_tenant_idx", "tenant_id"), + sa.PrimaryKeyConstraint("id", name="upload_file_pkey"), + sa.Index("upload_file_tenant_idx", "tenant_id"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - storage_type: Mapped[str] = mapped_column(db.String(255), nullable=False) - key: Mapped[str] = mapped_column(db.String(255), nullable=False) - name: Mapped[str] = mapped_column(db.String(255), nullable=False) - size: Mapped[int] = mapped_column(db.Integer, nullable=False) - extension: Mapped[str] = mapped_column(db.String(255), nullable=False) - mime_type: Mapped[str] = mapped_column(db.String(255), nullable=True) + storage_type: Mapped[str] = mapped_column(String(255), nullable=False) + key: Mapped[str] = mapped_column(String(255), nullable=False) + name: Mapped[str] = mapped_column(String(255), nullable=False) + size: Mapped[int] = mapped_column(sa.Integer, nullable=False) + extension: Mapped[str] = mapped_column(String(255), nullable=False) + mime_type: Mapped[str] = mapped_column(String(255), nullable=True) created_by_role: Mapped[str] = mapped_column( - db.String(255), nullable=False, server_default=db.text("'account'::character varying") + String(255), nullable=False, server_default=sa.text("'account'::character varying") ) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) - created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - used: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) + created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + used: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) used_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True) - used_at: Mapped[datetime | None] = mapped_column(db.DateTime, nullable=True) - hash: Mapped[str | None] = mapped_column(db.String(255), nullable=True) + used_at: Mapped[datetime | None] = mapped_column(sa.DateTime, nullable=True) + hash: Mapped[str | None] = mapped_column(String(255), nullable=True) source_url: Mapped[str] = mapped_column(sa.TEXT, default="") def __init__( @@ -1626,71 +1623,71 @@ class UploadFile(Base): class ApiRequest(Base): __tablename__ = "api_requests" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="api_request_pkey"), - db.Index("api_request_token_idx", "tenant_id", "api_token_id"), + sa.PrimaryKeyConstraint("id", name="api_request_pkey"), + sa.Index("api_request_token_idx", "tenant_id", "api_token_id"), ) - id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, nullable=False, server_default=sa.text("uuid_generate_v4()")) tenant_id = mapped_column(StringUUID, nullable=False) api_token_id = mapped_column(StringUUID, nullable=False) - path = mapped_column(db.String(255), nullable=False) - request = mapped_column(db.Text, nullable=True) - response = mapped_column(db.Text, nullable=True) - ip = mapped_column(db.String(255), nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + path: Mapped[str] = mapped_column(String(255), nullable=False) + request = mapped_column(sa.Text, nullable=True) + response = mapped_column(sa.Text, nullable=True) + ip: Mapped[str] = mapped_column(String(255), nullable=False) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) class MessageChain(Base): __tablename__ = "message_chains" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="message_chain_pkey"), - db.Index("message_chain_message_id_idx", "message_id"), + sa.PrimaryKeyConstraint("id", name="message_chain_pkey"), + sa.Index("message_chain_message_id_idx", "message_id"), ) - id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, nullable=False, server_default=sa.text("uuid_generate_v4()")) message_id = mapped_column(StringUUID, nullable=False) - type = mapped_column(db.String(255), nullable=False) - input = mapped_column(db.Text, nullable=True) - output = mapped_column(db.Text, nullable=True) - created_at = mapped_column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) + type: Mapped[str] = mapped_column(String(255), nullable=False) + input = mapped_column(sa.Text, nullable=True) + output = mapped_column(sa.Text, nullable=True) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=db.func.current_timestamp()) class MessageAgentThought(Base): __tablename__ = "message_agent_thoughts" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="message_agent_thought_pkey"), - db.Index("message_agent_thought_message_id_idx", "message_id"), - db.Index("message_agent_thought_message_chain_id_idx", "message_chain_id"), + sa.PrimaryKeyConstraint("id", name="message_agent_thought_pkey"), + sa.Index("message_agent_thought_message_id_idx", "message_id"), + sa.Index("message_agent_thought_message_chain_id_idx", "message_chain_id"), ) - id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, nullable=False, server_default=sa.text("uuid_generate_v4()")) message_id = mapped_column(StringUUID, nullable=False) message_chain_id = mapped_column(StringUUID, nullable=True) - position = mapped_column(db.Integer, nullable=False) - thought = mapped_column(db.Text, nullable=True) - tool = mapped_column(db.Text, nullable=True) - tool_labels_str = mapped_column(db.Text, nullable=False, server_default=db.text("'{}'::text")) - tool_meta_str = mapped_column(db.Text, nullable=False, server_default=db.text("'{}'::text")) - tool_input = mapped_column(db.Text, nullable=True) - observation = mapped_column(db.Text, nullable=True) + position: Mapped[int] = mapped_column(sa.Integer, nullable=False) + thought = mapped_column(sa.Text, nullable=True) + tool = mapped_column(sa.Text, nullable=True) + tool_labels_str = mapped_column(sa.Text, nullable=False, server_default=sa.text("'{}'::text")) + tool_meta_str = mapped_column(sa.Text, nullable=False, server_default=sa.text("'{}'::text")) + tool_input = mapped_column(sa.Text, nullable=True) + observation = mapped_column(sa.Text, nullable=True) # plugin_id = mapped_column(StringUUID, nullable=True) ## for future design - tool_process_data = mapped_column(db.Text, nullable=True) - message = mapped_column(db.Text, nullable=True) - message_token = mapped_column(db.Integer, nullable=True) - message_unit_price = mapped_column(db.Numeric, nullable=True) - message_price_unit = mapped_column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001")) - message_files = mapped_column(db.Text, nullable=True) - answer = db.Column(db.Text, nullable=True) - answer_token = mapped_column(db.Integer, nullable=True) - answer_unit_price = mapped_column(db.Numeric, nullable=True) - answer_price_unit = mapped_column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001")) - tokens = mapped_column(db.Integer, nullable=True) - total_price = mapped_column(db.Numeric, nullable=True) - currency = mapped_column(db.String, nullable=True) - latency = mapped_column(db.Float, nullable=True) - created_by_role = mapped_column(db.String, nullable=False) + tool_process_data = mapped_column(sa.Text, nullable=True) + message = mapped_column(sa.Text, nullable=True) + message_token: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True) + message_unit_price = mapped_column(sa.Numeric, nullable=True) + message_price_unit = mapped_column(sa.Numeric(10, 7), nullable=False, server_default=sa.text("0.001")) + message_files = mapped_column(sa.Text, nullable=True) + answer = db.Column(sa.Text, nullable=True) + answer_token: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True) + answer_unit_price = mapped_column(sa.Numeric, nullable=True) + answer_price_unit = mapped_column(sa.Numeric(10, 7), nullable=False, server_default=sa.text("0.001")) + tokens: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True) + total_price = mapped_column(sa.Numeric, nullable=True) + currency = mapped_column(String, nullable=True) + latency: Mapped[Optional[float]] = mapped_column(sa.Float, nullable=True) + created_by_role = mapped_column(String, nullable=False) created_by = mapped_column(StringUUID, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=db.func.current_timestamp()) @property def files(self) -> list: @@ -1772,80 +1769,80 @@ class MessageAgentThought(Base): class DatasetRetrieverResource(Base): __tablename__ = "dataset_retriever_resources" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="dataset_retriever_resource_pkey"), - db.Index("dataset_retriever_resource_message_id_idx", "message_id"), + sa.PrimaryKeyConstraint("id", name="dataset_retriever_resource_pkey"), + sa.Index("dataset_retriever_resource_message_id_idx", "message_id"), ) - id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, nullable=False, server_default=sa.text("uuid_generate_v4()")) message_id = mapped_column(StringUUID, nullable=False) - position = mapped_column(db.Integer, nullable=False) + position: Mapped[int] = mapped_column(sa.Integer, nullable=False) dataset_id = mapped_column(StringUUID, nullable=False) - dataset_name = mapped_column(db.Text, nullable=False) + dataset_name = mapped_column(sa.Text, nullable=False) document_id = mapped_column(StringUUID, nullable=True) - document_name = mapped_column(db.Text, nullable=False) - data_source_type = mapped_column(db.Text, nullable=True) + document_name = mapped_column(sa.Text, nullable=False) + data_source_type = mapped_column(sa.Text, nullable=True) segment_id = mapped_column(StringUUID, nullable=True) - score = mapped_column(db.Float, nullable=True) - content = mapped_column(db.Text, nullable=False) - hit_count = mapped_column(db.Integer, nullable=True) - word_count = mapped_column(db.Integer, nullable=True) - segment_position = mapped_column(db.Integer, nullable=True) - index_node_hash = mapped_column(db.Text, nullable=True) - retriever_from = mapped_column(db.Text, nullable=False) + score: Mapped[Optional[float]] = mapped_column(sa.Float, nullable=True) + content = mapped_column(sa.Text, nullable=False) + hit_count: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True) + word_count: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True) + segment_position: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True) + index_node_hash = mapped_column(sa.Text, nullable=True) + retriever_from = mapped_column(sa.Text, nullable=False) created_by = mapped_column(StringUUID, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=db.func.current_timestamp()) class Tag(Base): __tablename__ = "tags" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="tag_pkey"), - db.Index("tag_type_idx", "type"), - db.Index("tag_name_idx", "name"), + sa.PrimaryKeyConstraint("id", name="tag_pkey"), + sa.Index("tag_type_idx", "type"), + sa.Index("tag_name_idx", "name"), ) TAG_TYPE_LIST = ["knowledge", "app"] - id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) tenant_id = mapped_column(StringUUID, nullable=True) - type = mapped_column(db.String(16), nullable=False) - name = mapped_column(db.String(255), nullable=False) + type = mapped_column(String(16), nullable=False) + name: Mapped[str] = mapped_column(String(255), nullable=False) created_by = mapped_column(StringUUID, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) class TagBinding(Base): __tablename__ = "tag_bindings" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="tag_binding_pkey"), - db.Index("tag_bind_target_id_idx", "target_id"), - db.Index("tag_bind_tag_id_idx", "tag_id"), + sa.PrimaryKeyConstraint("id", name="tag_binding_pkey"), + sa.Index("tag_bind_target_id_idx", "target_id"), + sa.Index("tag_bind_tag_id_idx", "tag_id"), ) - id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) tenant_id = mapped_column(StringUUID, nullable=True) tag_id = mapped_column(StringUUID, nullable=True) target_id = mapped_column(StringUUID, nullable=True) created_by = mapped_column(StringUUID, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) class TraceAppConfig(Base): __tablename__ = "trace_app_config" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="tracing_app_config_pkey"), - db.Index("trace_app_config_app_id_idx", "app_id"), + sa.PrimaryKeyConstraint("id", name="tracing_app_config_pkey"), + sa.Index("trace_app_config_app_id_idx", "app_id"), ) - id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) app_id = mapped_column(StringUUID, nullable=False) - tracing_provider = mapped_column(db.String(255), nullable=True) - tracing_config = mapped_column(db.JSON, nullable=True) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + tracing_provider = mapped_column(String(255), nullable=True) + tracing_config = mapped_column(sa.JSON, nullable=True) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at = mapped_column( - db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() + sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() ) - is_active = mapped_column(db.Boolean, nullable=False, server_default=db.text("true")) + is_active: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true")) @property def tracing_config_dict(self): diff --git a/api/models/provider.py b/api/models/provider.py index 1e25f0c90f..4ea2c59fdb 100644 --- a/api/models/provider.py +++ b/api/models/provider.py @@ -2,11 +2,11 @@ from datetime import datetime from enum import Enum from typing import Optional -from sqlalchemy import func, text +import sqlalchemy as sa +from sqlalchemy import DateTime, String, func, text from sqlalchemy.orm import Mapped, mapped_column from .base import Base -from .engine import db from .types import StringUUID @@ -47,31 +47,31 @@ class Provider(Base): __tablename__ = "providers" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="provider_pkey"), - db.Index("provider_tenant_id_provider_idx", "tenant_id", "provider_name"), - db.UniqueConstraint( + sa.PrimaryKeyConstraint("id", name="provider_pkey"), + sa.Index("provider_tenant_id_provider_idx", "tenant_id", "provider_name"), + sa.UniqueConstraint( "tenant_id", "provider_name", "provider_type", "quota_type", name="unique_provider_name_type_quota" ), ) id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False) + provider_name: Mapped[str] = mapped_column(String(255), nullable=False) provider_type: Mapped[str] = mapped_column( - db.String(40), nullable=False, server_default=text("'custom'::character varying") + String(40), nullable=False, server_default=text("'custom'::character varying") ) - encrypted_config: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True) - is_valid: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("false")) - last_used: Mapped[Optional[datetime]] = mapped_column(db.DateTime, nullable=True) + encrypted_config: Mapped[Optional[str]] = mapped_column(sa.Text, nullable=True) + is_valid: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("false")) + last_used: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) quota_type: Mapped[Optional[str]] = mapped_column( - db.String(40), nullable=True, server_default=text("''::character varying") + String(40), nullable=True, server_default=text("''::character varying") ) - quota_limit: Mapped[Optional[int]] = mapped_column(db.BigInteger, nullable=True) - quota_used: Mapped[Optional[int]] = mapped_column(db.BigInteger, default=0) + quota_limit: Mapped[Optional[int]] = mapped_column(sa.BigInteger, nullable=True) + quota_used: Mapped[Optional[int]] = mapped_column(sa.BigInteger, default=0) - created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) def __repr__(self): return ( @@ -104,80 +104,80 @@ class ProviderModel(Base): __tablename__ = "provider_models" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="provider_model_pkey"), - db.Index("provider_model_tenant_id_provider_idx", "tenant_id", "provider_name"), - db.UniqueConstraint( + sa.PrimaryKeyConstraint("id", name="provider_model_pkey"), + sa.Index("provider_model_tenant_id_provider_idx", "tenant_id", "provider_name"), + sa.UniqueConstraint( "tenant_id", "provider_name", "model_name", "model_type", name="unique_provider_model_name" ), ) id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False) - model_name: Mapped[str] = mapped_column(db.String(255), nullable=False) - model_type: Mapped[str] = mapped_column(db.String(40), nullable=False) - encrypted_config: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True) - is_valid: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("false")) - created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + provider_name: Mapped[str] = mapped_column(String(255), nullable=False) + model_name: Mapped[str] = mapped_column(String(255), nullable=False) + model_type: Mapped[str] = mapped_column(String(40), nullable=False) + encrypted_config: Mapped[Optional[str]] = mapped_column(sa.Text, nullable=True) + is_valid: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("false")) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) class TenantDefaultModel(Base): __tablename__ = "tenant_default_models" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="tenant_default_model_pkey"), - db.Index("tenant_default_model_tenant_id_provider_type_idx", "tenant_id", "provider_name", "model_type"), + sa.PrimaryKeyConstraint("id", name="tenant_default_model_pkey"), + sa.Index("tenant_default_model_tenant_id_provider_type_idx", "tenant_id", "provider_name", "model_type"), ) id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False) - model_name: Mapped[str] = mapped_column(db.String(255), nullable=False) - model_type: Mapped[str] = mapped_column(db.String(40), nullable=False) - created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + provider_name: Mapped[str] = mapped_column(String(255), nullable=False) + model_name: Mapped[str] = mapped_column(String(255), nullable=False) + model_type: Mapped[str] = mapped_column(String(40), nullable=False) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) class TenantPreferredModelProvider(Base): __tablename__ = "tenant_preferred_model_providers" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="tenant_preferred_model_provider_pkey"), - db.Index("tenant_preferred_model_provider_tenant_provider_idx", "tenant_id", "provider_name"), + sa.PrimaryKeyConstraint("id", name="tenant_preferred_model_provider_pkey"), + sa.Index("tenant_preferred_model_provider_tenant_provider_idx", "tenant_id", "provider_name"), ) id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False) - preferred_provider_type: Mapped[str] = mapped_column(db.String(40), nullable=False) - created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + provider_name: Mapped[str] = mapped_column(String(255), nullable=False) + preferred_provider_type: Mapped[str] = mapped_column(String(40), nullable=False) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) class ProviderOrder(Base): __tablename__ = "provider_orders" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="provider_order_pkey"), - db.Index("provider_order_tenant_provider_idx", "tenant_id", "provider_name"), + sa.PrimaryKeyConstraint("id", name="provider_order_pkey"), + sa.Index("provider_order_tenant_provider_idx", "tenant_id", "provider_name"), ) id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False) + provider_name: Mapped[str] = mapped_column(String(255), nullable=False) account_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - payment_product_id: Mapped[str] = mapped_column(db.String(191), nullable=False) - payment_id: Mapped[Optional[str]] = mapped_column(db.String(191)) - transaction_id: Mapped[Optional[str]] = mapped_column(db.String(191)) - quantity: Mapped[int] = mapped_column(db.Integer, nullable=False, server_default=text("1")) - currency: Mapped[Optional[str]] = mapped_column(db.String(40)) - total_amount: Mapped[Optional[int]] = mapped_column(db.Integer) + payment_product_id: Mapped[str] = mapped_column(String(191), nullable=False) + payment_id: Mapped[Optional[str]] = mapped_column(String(191)) + transaction_id: Mapped[Optional[str]] = mapped_column(String(191)) + quantity: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=text("1")) + currency: Mapped[Optional[str]] = mapped_column(String(40)) + total_amount: Mapped[Optional[int]] = mapped_column(sa.Integer) payment_status: Mapped[str] = mapped_column( - db.String(40), nullable=False, server_default=text("'wait_pay'::character varying") + String(40), nullable=False, server_default=text("'wait_pay'::character varying") ) - paid_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime) - pay_failed_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime) - refunded_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime) - created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + paid_at: Mapped[Optional[datetime]] = mapped_column(DateTime) + pay_failed_at: Mapped[Optional[datetime]] = mapped_column(DateTime) + refunded_at: Mapped[Optional[datetime]] = mapped_column(DateTime) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) class ProviderModelSetting(Base): @@ -187,19 +187,19 @@ class ProviderModelSetting(Base): __tablename__ = "provider_model_settings" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="provider_model_setting_pkey"), - db.Index("provider_model_setting_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"), + sa.PrimaryKeyConstraint("id", name="provider_model_setting_pkey"), + sa.Index("provider_model_setting_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"), ) id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False) - model_name: Mapped[str] = mapped_column(db.String(255), nullable=False) - model_type: Mapped[str] = mapped_column(db.String(40), nullable=False) - enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("true")) - load_balancing_enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("false")) - created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + provider_name: Mapped[str] = mapped_column(String(255), nullable=False) + model_name: Mapped[str] = mapped_column(String(255), nullable=False) + model_type: Mapped[str] = mapped_column(String(40), nullable=False) + enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("true")) + load_balancing_enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("false")) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) class LoadBalancingModelConfig(Base): @@ -209,17 +209,17 @@ class LoadBalancingModelConfig(Base): __tablename__ = "load_balancing_model_configs" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="load_balancing_model_config_pkey"), - db.Index("load_balancing_model_config_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"), + sa.PrimaryKeyConstraint("id", name="load_balancing_model_config_pkey"), + sa.Index("load_balancing_model_config_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"), ) id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False) - model_name: Mapped[str] = mapped_column(db.String(255), nullable=False) - model_type: Mapped[str] = mapped_column(db.String(40), nullable=False) - name: Mapped[str] = mapped_column(db.String(255), nullable=False) - encrypted_config: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True) - enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("true")) - created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + provider_name: Mapped[str] = mapped_column(String(255), nullable=False) + model_name: Mapped[str] = mapped_column(String(255), nullable=False) + model_type: Mapped[str] = mapped_column(String(40), nullable=False) + name: Mapped[str] = mapped_column(String(255), nullable=False) + encrypted_config: Mapped[Optional[str]] = mapped_column(sa.Text, nullable=True) + enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("true")) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) diff --git a/api/models/source.py b/api/models/source.py index 100e0d96ef..8456d65a87 100644 --- a/api/models/source.py +++ b/api/models/source.py @@ -1,49 +1,51 @@ import json +from datetime import datetime +from typing import Optional -from sqlalchemy import func +import sqlalchemy as sa +from sqlalchemy import DateTime, String, func from sqlalchemy.dialects.postgresql import JSONB -from sqlalchemy.orm import mapped_column +from sqlalchemy.orm import Mapped, mapped_column from models.base import Base -from .engine import db from .types import StringUUID class DataSourceOauthBinding(Base): __tablename__ = "data_source_oauth_bindings" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="source_binding_pkey"), - db.Index("source_binding_tenant_id_idx", "tenant_id"), - db.Index("source_info_idx", "source_info", postgresql_using="gin"), + sa.PrimaryKeyConstraint("id", name="source_binding_pkey"), + sa.Index("source_binding_tenant_id_idx", "tenant_id"), + sa.Index("source_info_idx", "source_info", postgresql_using="gin"), ) - id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) tenant_id = mapped_column(StringUUID, nullable=False) - access_token = mapped_column(db.String(255), nullable=False) - provider = mapped_column(db.String(255), nullable=False) + access_token: Mapped[str] = mapped_column(String(255), nullable=False) + provider: Mapped[str] = mapped_column(String(255), nullable=False) source_info = mapped_column(JSONB, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - disabled = mapped_column(db.Boolean, nullable=True, server_default=db.text("false")) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + disabled: Mapped[Optional[bool]] = mapped_column(sa.Boolean, nullable=True, server_default=sa.text("false")) class DataSourceApiKeyAuthBinding(Base): __tablename__ = "data_source_api_key_auth_bindings" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="data_source_api_key_auth_binding_pkey"), - db.Index("data_source_api_key_auth_binding_tenant_id_idx", "tenant_id"), - db.Index("data_source_api_key_auth_binding_provider_idx", "provider"), + sa.PrimaryKeyConstraint("id", name="data_source_api_key_auth_binding_pkey"), + sa.Index("data_source_api_key_auth_binding_tenant_id_idx", "tenant_id"), + sa.Index("data_source_api_key_auth_binding_provider_idx", "provider"), ) - id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) tenant_id = mapped_column(StringUUID, nullable=False) - category = mapped_column(db.String(255), nullable=False) - provider = mapped_column(db.String(255), nullable=False) - credentials = mapped_column(db.Text, nullable=True) # JSON - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - disabled = mapped_column(db.Boolean, nullable=True, server_default=db.text("false")) + category: Mapped[str] = mapped_column(String(255), nullable=False) + provider: Mapped[str] = mapped_column(String(255), nullable=False) + credentials = mapped_column(sa.Text, nullable=True) # JSON + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + disabled: Mapped[Optional[bool]] = mapped_column(sa.Boolean, nullable=True, server_default=sa.text("false")) def to_dict(self): return { diff --git a/api/models/task.py b/api/models/task.py index 3e5ebd2099..9a52fcfb41 100644 --- a/api/models/task.py +++ b/api/models/task.py @@ -1,7 +1,9 @@ from datetime import datetime from typing import Optional -from celery import states # type: ignore +import sqlalchemy as sa +from celery import states +from sqlalchemy import DateTime, String from sqlalchemy.orm import Mapped, mapped_column from libs.datetime_utils import naive_utc_now @@ -15,23 +17,23 @@ class CeleryTask(Base): __tablename__ = "celery_taskmeta" - id = mapped_column(db.Integer, db.Sequence("task_id_sequence"), primary_key=True, autoincrement=True) - task_id = mapped_column(db.String(155), unique=True) - status = mapped_column(db.String(50), default=states.PENDING) + id = mapped_column(sa.Integer, sa.Sequence("task_id_sequence"), primary_key=True, autoincrement=True) + task_id = mapped_column(String(155), unique=True) + status = mapped_column(String(50), default=states.PENDING) result = mapped_column(db.PickleType, nullable=True) date_done = mapped_column( - db.DateTime, + DateTime, default=lambda: naive_utc_now(), onupdate=lambda: naive_utc_now(), nullable=True, ) - traceback = mapped_column(db.Text, nullable=True) - name = mapped_column(db.String(155), nullable=True) - args = mapped_column(db.LargeBinary, nullable=True) - kwargs = mapped_column(db.LargeBinary, nullable=True) - worker = mapped_column(db.String(155), nullable=True) - retries = mapped_column(db.Integer, nullable=True) - queue = mapped_column(db.String(155), nullable=True) + traceback = mapped_column(sa.Text, nullable=True) + name = mapped_column(String(155), nullable=True) + args = mapped_column(sa.LargeBinary, nullable=True) + kwargs = mapped_column(sa.LargeBinary, nullable=True) + worker = mapped_column(String(155), nullable=True) + retries: Mapped[Optional[int]] = mapped_column(sa.Integer, nullable=True) + queue = mapped_column(String(155), nullable=True) class CeleryTaskSet(Base): @@ -40,8 +42,8 @@ class CeleryTaskSet(Base): __tablename__ = "celery_tasksetmeta" id: Mapped[int] = mapped_column( - db.Integer, db.Sequence("taskset_id_sequence"), autoincrement=True, primary_key=True + sa.Integer, sa.Sequence("taskset_id_sequence"), autoincrement=True, primary_key=True ) - taskset_id = mapped_column(db.String(155), unique=True) + taskset_id = mapped_column(String(155), unique=True) result = mapped_column(db.PickleType, nullable=True) - date_done: Mapped[Optional[datetime]] = mapped_column(db.DateTime, default=lambda: naive_utc_now(), nullable=True) + date_done: Mapped[Optional[datetime]] = mapped_column(DateTime, default=lambda: naive_utc_now(), nullable=True) diff --git a/api/models/tools.py b/api/models/tools.py index 68f4211e59..e0c9fa6ffc 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -5,7 +5,7 @@ from urllib.parse import urlparse import sqlalchemy as sa from deprecated import deprecated -from sqlalchemy import ForeignKey, func +from sqlalchemy import ForeignKey, String, func from sqlalchemy.orm import Mapped, mapped_column from core.file import helpers as file_helpers @@ -25,33 +25,33 @@ from .types import StringUUID class ToolOAuthSystemClient(Base): __tablename__ = "tool_oauth_system_clients" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="tool_oauth_system_client_pkey"), - db.UniqueConstraint("plugin_id", "provider", name="tool_oauth_system_client_plugin_id_provider_idx"), + sa.PrimaryKeyConstraint("id", name="tool_oauth_system_client_pkey"), + sa.UniqueConstraint("plugin_id", "provider", name="tool_oauth_system_client_plugin_id_provider_idx"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) - plugin_id: Mapped[str] = mapped_column(db.String(512), nullable=False) - provider: Mapped[str] = mapped_column(db.String(255), nullable=False) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) + plugin_id = mapped_column(String(512), nullable=False) + provider: Mapped[str] = mapped_column(String(255), nullable=False) # oauth params of the tool provider - encrypted_oauth_params: Mapped[str] = mapped_column(db.Text, nullable=False) + encrypted_oauth_params: Mapped[str] = mapped_column(sa.Text, nullable=False) # tenant level tool oauth client params (client_id, client_secret, etc.) class ToolOAuthTenantClient(Base): __tablename__ = "tool_oauth_tenant_clients" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="tool_oauth_tenant_client_pkey"), - db.UniqueConstraint("tenant_id", "plugin_id", "provider", name="unique_tool_oauth_tenant_client"), + sa.PrimaryKeyConstraint("id", name="tool_oauth_tenant_client_pkey"), + sa.UniqueConstraint("tenant_id", "plugin_id", "provider", name="unique_tool_oauth_tenant_client"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) # tenant id tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - plugin_id: Mapped[str] = mapped_column(db.String(512), nullable=False) - provider: Mapped[str] = mapped_column(db.String(255), nullable=False) - enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("true")) + plugin_id: Mapped[str] = mapped_column(String(512), nullable=False) + provider: Mapped[str] = mapped_column(String(255), nullable=False) + enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true")) # oauth params of the tool provider - encrypted_oauth_params: Mapped[str] = mapped_column(db.Text, nullable=False) + encrypted_oauth_params: Mapped[str] = mapped_column(sa.Text, nullable=False) @property def oauth_params(self) -> dict: @@ -65,35 +65,35 @@ class BuiltinToolProvider(Base): __tablename__ = "tool_builtin_providers" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="tool_builtin_provider_pkey"), - db.UniqueConstraint("tenant_id", "provider", "name", name="unique_builtin_tool_provider"), + sa.PrimaryKeyConstraint("id", name="tool_builtin_provider_pkey"), + sa.UniqueConstraint("tenant_id", "provider", "name", name="unique_builtin_tool_provider"), ) # id of the tool provider - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) name: Mapped[str] = mapped_column( - db.String(256), nullable=False, server_default=db.text("'API KEY 1'::character varying") + String(256), nullable=False, server_default=sa.text("'API KEY 1'::character varying") ) # id of the tenant tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=True) # who created this tool provider user_id: Mapped[str] = mapped_column(StringUUID, nullable=False) # name of the tool provider - provider: Mapped[str] = mapped_column(db.String(256), nullable=False) + provider: Mapped[str] = mapped_column(String(256), nullable=False) # credential of the tool provider - encrypted_credentials: Mapped[str] = mapped_column(db.Text, nullable=True) + encrypted_credentials: Mapped[str] = mapped_column(sa.Text, nullable=True) created_at: Mapped[datetime] = mapped_column( - db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") + sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)") ) updated_at: Mapped[datetime] = mapped_column( - db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") + sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)") ) - is_default: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) + is_default: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) # credential type, e.g., "api-key", "oauth2" credential_type: Mapped[str] = mapped_column( - db.String(32), nullable=False, server_default=db.text("'api-key'::character varying") + String(32), nullable=False, server_default=sa.text("'api-key'::character varying") ) - expires_at: Mapped[int] = mapped_column(db.BigInteger, nullable=False, server_default=db.text("-1")) + expires_at: Mapped[int] = mapped_column(sa.BigInteger, nullable=False, server_default=sa.text("-1")) @property def credentials(self) -> dict: @@ -107,35 +107,35 @@ class ApiToolProvider(Base): __tablename__ = "tool_api_providers" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="tool_api_provider_pkey"), - db.UniqueConstraint("name", "tenant_id", name="unique_api_tool_provider"), + sa.PrimaryKeyConstraint("id", name="tool_api_provider_pkey"), + sa.UniqueConstraint("name", "tenant_id", name="unique_api_tool_provider"), ) - id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) # name of the api provider - name = mapped_column(db.String(255), nullable=False, server_default=db.text("'API KEY 1'::character varying")) + name = mapped_column(String(255), nullable=False, server_default=sa.text("'API KEY 1'::character varying")) # icon - icon = mapped_column(db.String(255), nullable=False) + icon: Mapped[str] = mapped_column(String(255), nullable=False) # original schema - schema = mapped_column(db.Text, nullable=False) - schema_type_str: Mapped[str] = mapped_column(db.String(40), nullable=False) + schema = mapped_column(sa.Text, nullable=False) + schema_type_str: Mapped[str] = mapped_column(String(40), nullable=False) # who created this tool user_id = mapped_column(StringUUID, nullable=False) # tenant id tenant_id = mapped_column(StringUUID, nullable=False) # description of the provider - description = mapped_column(db.Text, nullable=False) + description = mapped_column(sa.Text, nullable=False) # json format tools - tools_str = mapped_column(db.Text, nullable=False) + tools_str = mapped_column(sa.Text, nullable=False) # json format credentials - credentials_str = mapped_column(db.Text, nullable=False) + credentials_str = mapped_column(sa.Text, nullable=False) # privacy policy - privacy_policy = mapped_column(db.String(255), nullable=True) + privacy_policy = mapped_column(String(255), nullable=True) # custom_disclaimer custom_disclaimer: Mapped[str] = mapped_column(sa.TEXT, default="") - created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) @property def schema_type(self) -> ApiProviderSchemaType: @@ -167,17 +167,17 @@ class ToolLabelBinding(Base): __tablename__ = "tool_label_bindings" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="tool_label_bind_pkey"), - db.UniqueConstraint("tool_id", "label_name", name="unique_tool_label_bind"), + sa.PrimaryKeyConstraint("id", name="tool_label_bind_pkey"), + sa.UniqueConstraint("tool_id", "label_name", name="unique_tool_label_bind"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) # tool id - tool_id: Mapped[str] = mapped_column(db.String(64), nullable=False) + tool_id: Mapped[str] = mapped_column(String(64), nullable=False) # tool type - tool_type: Mapped[str] = mapped_column(db.String(40), nullable=False) + tool_type: Mapped[str] = mapped_column(String(40), nullable=False) # label name - label_name: Mapped[str] = mapped_column(db.String(40), nullable=False) + label_name: Mapped[str] = mapped_column(String(40), nullable=False) class WorkflowToolProvider(Base): @@ -187,38 +187,38 @@ class WorkflowToolProvider(Base): __tablename__ = "tool_workflow_providers" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="tool_workflow_provider_pkey"), - db.UniqueConstraint("name", "tenant_id", name="unique_workflow_tool_provider"), - db.UniqueConstraint("tenant_id", "app_id", name="unique_workflow_tool_provider_app_id"), + sa.PrimaryKeyConstraint("id", name="tool_workflow_provider_pkey"), + sa.UniqueConstraint("name", "tenant_id", name="unique_workflow_tool_provider"), + sa.UniqueConstraint("tenant_id", "app_id", name="unique_workflow_tool_provider_app_id"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) # name of the workflow provider - name: Mapped[str] = mapped_column(db.String(255), nullable=False) + name: Mapped[str] = mapped_column(String(255), nullable=False) # label of the workflow provider - label: Mapped[str] = mapped_column(db.String(255), nullable=False, server_default="") + label: Mapped[str] = mapped_column(String(255), nullable=False, server_default="") # icon - icon: Mapped[str] = mapped_column(db.String(255), nullable=False) + icon: Mapped[str] = mapped_column(String(255), nullable=False) # app id of the workflow provider app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) # version of the workflow provider - version: Mapped[str] = mapped_column(db.String(255), nullable=False, server_default="") + version: Mapped[str] = mapped_column(String(255), nullable=False, server_default="") # who created this tool user_id: Mapped[str] = mapped_column(StringUUID, nullable=False) # tenant id tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) # description of the provider - description: Mapped[str] = mapped_column(db.Text, nullable=False) + description: Mapped[str] = mapped_column(sa.Text, nullable=False) # parameter configuration - parameter_configuration: Mapped[str] = mapped_column(db.Text, nullable=False, server_default="[]") + parameter_configuration: Mapped[str] = mapped_column(sa.Text, nullable=False, server_default="[]") # privacy policy - privacy_policy: Mapped[str] = mapped_column(db.String(255), nullable=True, server_default="") + privacy_policy: Mapped[str] = mapped_column(String(255), nullable=True, server_default="") created_at: Mapped[datetime] = mapped_column( - db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") + sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)") ) updated_at: Mapped[datetime] = mapped_column( - db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") + sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)") ) @property @@ -245,39 +245,41 @@ class MCPToolProvider(Base): __tablename__ = "tool_mcp_providers" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="tool_mcp_provider_pkey"), - db.UniqueConstraint("tenant_id", "server_url_hash", name="unique_mcp_provider_server_url"), - db.UniqueConstraint("tenant_id", "name", name="unique_mcp_provider_name"), - db.UniqueConstraint("tenant_id", "server_identifier", name="unique_mcp_provider_server_identifier"), + sa.PrimaryKeyConstraint("id", name="tool_mcp_provider_pkey"), + sa.UniqueConstraint("tenant_id", "server_url_hash", name="unique_mcp_provider_server_url"), + sa.UniqueConstraint("tenant_id", "name", name="unique_mcp_provider_name"), + sa.UniqueConstraint("tenant_id", "server_identifier", name="unique_mcp_provider_server_identifier"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) # name of the mcp provider - name: Mapped[str] = mapped_column(db.String(40), nullable=False) + name: Mapped[str] = mapped_column(String(40), nullable=False) # server identifier of the mcp provider - server_identifier: Mapped[str] = mapped_column(db.String(64), nullable=False) + server_identifier: Mapped[str] = mapped_column(String(64), nullable=False) # encrypted url of the mcp provider - server_url: Mapped[str] = mapped_column(db.Text, nullable=False) + server_url: Mapped[str] = mapped_column(sa.Text, nullable=False) # hash of server_url for uniqueness check - server_url_hash: Mapped[str] = mapped_column(db.String(64), nullable=False) + server_url_hash: Mapped[str] = mapped_column(String(64), nullable=False) # icon of the mcp provider - icon: Mapped[str] = mapped_column(db.String(255), nullable=True) + icon: Mapped[str] = mapped_column(String(255), nullable=True) # tenant id tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) # who created this tool user_id: Mapped[str] = mapped_column(StringUUID, nullable=False) # encrypted credentials - encrypted_credentials: Mapped[str] = mapped_column(db.Text, nullable=True) + encrypted_credentials: Mapped[str] = mapped_column(sa.Text, nullable=True) # authed - authed: Mapped[bool] = mapped_column(db.Boolean, nullable=False, default=False) + authed: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, default=False) # tools - tools: Mapped[str] = mapped_column(db.Text, nullable=False, default="[]") + tools: Mapped[str] = mapped_column(sa.Text, nullable=False, default="[]") created_at: Mapped[datetime] = mapped_column( - db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") + sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)") ) updated_at: Mapped[datetime] = mapped_column( - db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") + sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)") ) + timeout: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("30")) + sse_read_timeout: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("300")) def load_user(self) -> Account | None: return db.session.query(Account).where(Account.id == self.user_id).first() @@ -347,35 +349,35 @@ class ToolModelInvoke(Base): """ __tablename__ = "tool_model_invokes" - __table_args__ = (db.PrimaryKeyConstraint("id", name="tool_model_invoke_pkey"),) + __table_args__ = (sa.PrimaryKeyConstraint("id", name="tool_model_invoke_pkey"),) - id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) # who invoke this tool user_id = mapped_column(StringUUID, nullable=False) # tenant id tenant_id = mapped_column(StringUUID, nullable=False) # provider - provider = mapped_column(db.String(255), nullable=False) + provider: Mapped[str] = mapped_column(String(255), nullable=False) # type - tool_type = mapped_column(db.String(40), nullable=False) + tool_type = mapped_column(String(40), nullable=False) # tool name - tool_name = mapped_column(db.String(128), nullable=False) + tool_name = mapped_column(String(128), nullable=False) # invoke parameters - model_parameters = mapped_column(db.Text, nullable=False) + model_parameters = mapped_column(sa.Text, nullable=False) # prompt messages - prompt_messages = mapped_column(db.Text, nullable=False) + prompt_messages = mapped_column(sa.Text, nullable=False) # invoke response - model_response = mapped_column(db.Text, nullable=False) + model_response = mapped_column(sa.Text, nullable=False) - prompt_tokens = mapped_column(db.Integer, nullable=False, server_default=db.text("0")) - answer_tokens = mapped_column(db.Integer, nullable=False, server_default=db.text("0")) - answer_unit_price = mapped_column(db.Numeric(10, 4), nullable=False) - answer_price_unit = mapped_column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001")) - provider_response_latency = mapped_column(db.Float, nullable=False, server_default=db.text("0")) - total_price = mapped_column(db.Numeric(10, 7)) - currency = mapped_column(db.String(255), nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + prompt_tokens: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0")) + answer_tokens: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0")) + answer_unit_price = mapped_column(sa.Numeric(10, 4), nullable=False) + answer_price_unit = mapped_column(sa.Numeric(10, 7), nullable=False, server_default=sa.text("0.001")) + provider_response_latency = mapped_column(sa.Float, nullable=False, server_default=sa.text("0")) + total_price = mapped_column(sa.Numeric(10, 7)) + currency: Mapped[str] = mapped_column(String(255), nullable=False) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) @deprecated @@ -386,13 +388,13 @@ class ToolConversationVariables(Base): __tablename__ = "tool_conversation_variables" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="tool_conversation_variables_pkey"), + sa.PrimaryKeyConstraint("id", name="tool_conversation_variables_pkey"), # add index for user_id and conversation_id - db.Index("user_id_idx", "user_id"), - db.Index("conversation_id_idx", "conversation_id"), + sa.Index("user_id_idx", "user_id"), + sa.Index("conversation_id_idx", "conversation_id"), ) - id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) # conversation user id user_id = mapped_column(StringUUID, nullable=False) # tenant id @@ -400,10 +402,10 @@ class ToolConversationVariables(Base): # conversation id conversation_id = mapped_column(StringUUID, nullable=False) # variables pool - variables_str = mapped_column(db.Text, nullable=False) + variables_str = mapped_column(sa.Text, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) @property def variables(self) -> Any: @@ -417,11 +419,11 @@ class ToolFile(Base): __tablename__ = "tool_files" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="tool_file_pkey"), - db.Index("tool_file_conversation_id_idx", "conversation_id"), + sa.PrimaryKeyConstraint("id", name="tool_file_pkey"), + sa.Index("tool_file_conversation_id_idx", "conversation_id"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) # conversation user id user_id: Mapped[str] = mapped_column(StringUUID) # tenant id @@ -429,11 +431,11 @@ class ToolFile(Base): # conversation id conversation_id: Mapped[str] = mapped_column(StringUUID, nullable=True) # file key - file_key: Mapped[str] = mapped_column(db.String(255), nullable=False) + file_key: Mapped[str] = mapped_column(String(255), nullable=False) # mime type - mimetype: Mapped[str] = mapped_column(db.String(255), nullable=False) + mimetype: Mapped[str] = mapped_column(String(255), nullable=False) # original url - original_url: Mapped[str] = mapped_column(db.String(2048), nullable=True) + original_url: Mapped[str] = mapped_column(String(2048), nullable=True) # name name: Mapped[str] = mapped_column(default="") # size @@ -448,30 +450,30 @@ class DeprecatedPublishedAppTool(Base): __tablename__ = "tool_published_apps" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="published_app_tool_pkey"), - db.UniqueConstraint("app_id", "user_id", name="unique_published_app_tool"), + sa.PrimaryKeyConstraint("id", name="published_app_tool_pkey"), + sa.UniqueConstraint("app_id", "user_id", name="unique_published_app_tool"), ) - id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) # id of the app app_id = mapped_column(StringUUID, ForeignKey("apps.id"), nullable=False) user_id: Mapped[str] = mapped_column(StringUUID, nullable=False) # who published this tool - description = mapped_column(db.Text, nullable=False) + description = mapped_column(sa.Text, nullable=False) # llm_description of the tool, for LLM - llm_description = mapped_column(db.Text, nullable=False) + llm_description = mapped_column(sa.Text, nullable=False) # query description, query will be seem as a parameter of the tool, # to describe this parameter to llm, we need this field - query_description = mapped_column(db.Text, nullable=False) + query_description = mapped_column(sa.Text, nullable=False) # query name, the name of the query parameter - query_name = mapped_column(db.String(40), nullable=False) + query_name = mapped_column(String(40), nullable=False) # name of the tool provider - tool_name = mapped_column(db.String(40), nullable=False) + tool_name = mapped_column(String(40), nullable=False) # author - author = mapped_column(db.String(40), nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = mapped_column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + author = mapped_column(String(40), nullable=False) + created_at = mapped_column(sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")) + updated_at = mapped_column(sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)")) @property def description_i18n(self) -> I18nObject: diff --git a/api/models/web.py b/api/models/web.py index ce00f4010f..74f99e187b 100644 --- a/api/models/web.py +++ b/api/models/web.py @@ -1,4 +1,7 @@ -from sqlalchemy import func +from datetime import datetime + +import sqlalchemy as sa +from sqlalchemy import DateTime, String, func from sqlalchemy.orm import Mapped, mapped_column from models.base import Base @@ -11,18 +14,18 @@ from .types import StringUUID class SavedMessage(Base): __tablename__ = "saved_messages" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="saved_message_pkey"), - db.Index("saved_message_message_idx", "app_id", "message_id", "created_by_role", "created_by"), + sa.PrimaryKeyConstraint("id", name="saved_message_pkey"), + sa.Index("saved_message_message_idx", "app_id", "message_id", "created_by_role", "created_by"), ) - id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) app_id = mapped_column(StringUUID, nullable=False) message_id = mapped_column(StringUUID, nullable=False) created_by_role = mapped_column( - db.String(255), nullable=False, server_default=db.text("'end_user'::character varying") + String(255), nullable=False, server_default=sa.text("'end_user'::character varying") ) created_by = mapped_column(StringUUID, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) @property def message(self): @@ -32,15 +35,15 @@ class SavedMessage(Base): class PinnedConversation(Base): __tablename__ = "pinned_conversations" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="pinned_conversation_pkey"), - db.Index("pinned_conversation_conversation_idx", "app_id", "conversation_id", "created_by_role", "created_by"), + sa.PrimaryKeyConstraint("id", name="pinned_conversation_pkey"), + sa.Index("pinned_conversation_conversation_idx", "app_id", "conversation_id", "created_by_role", "created_by"), ) - id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) app_id = mapped_column(StringUUID, nullable=False) conversation_id: Mapped[str] = mapped_column(StringUUID) created_by_role = mapped_column( - db.String(255), nullable=False, server_default=db.text("'end_user'::character varying") + String(255), nullable=False, server_default=sa.text("'end_user'::character varying") ) created_by = mapped_column(StringUUID, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) diff --git a/api/models/workflow.py b/api/models/workflow.py index 79d96e42dd..2fea3fcd78 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -6,8 +6,8 @@ from enum import Enum, StrEnum from typing import TYPE_CHECKING, Any, Optional, Union from uuid import uuid4 -from flask_login import current_user -from sqlalchemy import orm +import sqlalchemy as sa +from sqlalchemy import DateTime, orm from core.file.constants import maybe_file_object from core.file.models import File @@ -17,15 +17,13 @@ from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIAB from core.workflow.nodes.enums import NodeType from factories.variable_factory import TypeMismatchError, build_segment_with_type from libs.datetime_utils import naive_utc_now -from libs.helper import extract_tenant_id from ._workflow_exc import NodeNotFoundError, WorkflowDataError if TYPE_CHECKING: from models.model import AppMode -import sqlalchemy as sa -from sqlalchemy import Index, PrimaryKeyConstraint, UniqueConstraint, func +from sqlalchemy import Index, PrimaryKeyConstraint, String, UniqueConstraint, func from sqlalchemy.orm import Mapped, declared_attr, mapped_column from constants import DEFAULT_FILE_NUMBER_LIMITS, HIDDEN_VALUE @@ -42,9 +40,6 @@ from .types import EnumText, StringUUID _logger = logging.getLogger(__name__) -if TYPE_CHECKING: - from models.model import AppMode - class WorkflowType(Enum): """ @@ -120,33 +115,33 @@ class Workflow(Base): __tablename__ = "workflows" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="workflow_pkey"), - db.Index("workflow_version_idx", "tenant_id", "app_id", "version"), + sa.PrimaryKeyConstraint("id", name="workflow_pkey"), + sa.Index("workflow_version_idx", "tenant_id", "app_id", "version"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - type: Mapped[str] = mapped_column(db.String(255), nullable=False) - version: Mapped[str] = mapped_column(db.String(255), nullable=False) + type: Mapped[str] = mapped_column(String(255), nullable=False) + version: Mapped[str] = mapped_column(String(255), nullable=False) marked_name: Mapped[str] = mapped_column(default="", server_default="") marked_comment: Mapped[str] = mapped_column(default="", server_default="") graph: Mapped[str] = mapped_column(sa.Text) _features: Mapped[str] = mapped_column("features", sa.TEXT) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) - created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) updated_by: Mapped[Optional[str]] = mapped_column(StringUUID) updated_at: Mapped[datetime] = mapped_column( - db.DateTime, + DateTime, nullable=False, default=naive_utc_now(), server_onupdate=func.current_timestamp(), ) _environment_variables: Mapped[str] = mapped_column( - "environment_variables", db.Text, nullable=False, server_default="{}" + "environment_variables", sa.Text, nullable=False, server_default="{}" ) _conversation_variables: Mapped[str] = mapped_column( - "conversation_variables", db.Text, nullable=False, server_default="{}" + "conversation_variables", sa.Text, nullable=False, server_default="{}" ) VERSION_DRAFT = "draft" @@ -354,8 +349,8 @@ class Workflow(Base): if self._environment_variables is None: self._environment_variables = "{}" - # Get tenant_id from current_user (Account or EndUser) - tenant_id = extract_tenant_id(current_user) + # Use workflow.tenant_id to avoid relying on request user in background threads + tenant_id = self.tenant_id if not tenant_id: return [] @@ -385,8 +380,8 @@ class Workflow(Base): self._environment_variables = "{}" return - # Get tenant_id from current_user (Account or EndUser) - tenant_id = extract_tenant_id(current_user) + # Use workflow.tenant_id to avoid relying on request user in background threads + tenant_id = self.tenant_id if not tenant_id: self._environment_variables = "{}" @@ -494,31 +489,31 @@ class WorkflowRun(Base): __tablename__ = "workflow_runs" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="workflow_run_pkey"), - db.Index("workflow_run_triggerd_from_idx", "tenant_id", "app_id", "triggered_from"), + sa.PrimaryKeyConstraint("id", name="workflow_run_pkey"), + sa.Index("workflow_run_triggerd_from_idx", "tenant_id", "app_id", "triggered_from"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID) app_id: Mapped[str] = mapped_column(StringUUID) workflow_id: Mapped[str] = mapped_column(StringUUID) - type: Mapped[str] = mapped_column(db.String(255)) - triggered_from: Mapped[str] = mapped_column(db.String(255)) - version: Mapped[str] = mapped_column(db.String(255)) - graph: Mapped[Optional[str]] = mapped_column(db.Text) - inputs: Mapped[Optional[str]] = mapped_column(db.Text) - status: Mapped[str] = mapped_column(db.String(255)) # running, succeeded, failed, stopped, partial-succeeded + type: Mapped[str] = mapped_column(String(255)) + triggered_from: Mapped[str] = mapped_column(String(255)) + version: Mapped[str] = mapped_column(String(255)) + graph: Mapped[Optional[str]] = mapped_column(sa.Text) + inputs: Mapped[Optional[str]] = mapped_column(sa.Text) + status: Mapped[str] = mapped_column(String(255)) # running, succeeded, failed, stopped, partial-succeeded outputs: Mapped[Optional[str]] = mapped_column(sa.Text, default="{}") - error: Mapped[Optional[str]] = mapped_column(db.Text) - elapsed_time: Mapped[float] = mapped_column(db.Float, nullable=False, server_default=sa.text("0")) + error: Mapped[Optional[str]] = mapped_column(sa.Text) + elapsed_time: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("0")) total_tokens: Mapped[int] = mapped_column(sa.BigInteger, server_default=sa.text("0")) - total_steps: Mapped[int] = mapped_column(db.Integer, server_default=db.text("0"), nullable=True) - created_by_role: Mapped[str] = mapped_column(db.String(255)) # account, end_user + total_steps: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0"), nullable=True) + created_by_role: Mapped[str] = mapped_column(String(255)) # account, end_user created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) - created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - finished_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime) - exceptions_count: Mapped[int] = mapped_column(db.Integer, server_default=db.text("0"), nullable=True) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + finished_at: Mapped[Optional[datetime]] = mapped_column(DateTime) + exceptions_count: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0"), nullable=True) @property def created_by_account(self): @@ -707,29 +702,29 @@ class WorkflowNodeExecutionModel(Base): ), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID) app_id: Mapped[str] = mapped_column(StringUUID) workflow_id: Mapped[str] = mapped_column(StringUUID) - triggered_from: Mapped[str] = mapped_column(db.String(255)) + triggered_from: Mapped[str] = mapped_column(String(255)) workflow_run_id: Mapped[Optional[str]] = mapped_column(StringUUID) - index: Mapped[int] = mapped_column(db.Integer) - predecessor_node_id: Mapped[Optional[str]] = mapped_column(db.String(255)) - node_execution_id: Mapped[Optional[str]] = mapped_column(db.String(255)) - node_id: Mapped[str] = mapped_column(db.String(255)) - node_type: Mapped[str] = mapped_column(db.String(255)) - title: Mapped[str] = mapped_column(db.String(255)) - inputs: Mapped[Optional[str]] = mapped_column(db.Text) - process_data: Mapped[Optional[str]] = mapped_column(db.Text) - outputs: Mapped[Optional[str]] = mapped_column(db.Text) - status: Mapped[str] = mapped_column(db.String(255)) - error: Mapped[Optional[str]] = mapped_column(db.Text) - elapsed_time: Mapped[float] = mapped_column(db.Float, server_default=db.text("0")) - execution_metadata: Mapped[Optional[str]] = mapped_column(db.Text) - created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp()) - created_by_role: Mapped[str] = mapped_column(db.String(255)) + index: Mapped[int] = mapped_column(sa.Integer) + predecessor_node_id: Mapped[Optional[str]] = mapped_column(String(255)) + node_execution_id: Mapped[Optional[str]] = mapped_column(String(255)) + node_id: Mapped[str] = mapped_column(String(255)) + node_type: Mapped[str] = mapped_column(String(255)) + title: Mapped[str] = mapped_column(String(255)) + inputs: Mapped[Optional[str]] = mapped_column(sa.Text) + process_data: Mapped[Optional[str]] = mapped_column(sa.Text) + outputs: Mapped[Optional[str]] = mapped_column(sa.Text) + status: Mapped[str] = mapped_column(String(255)) + error: Mapped[Optional[str]] = mapped_column(sa.Text) + elapsed_time: Mapped[float] = mapped_column(sa.Float, server_default=sa.text("0")) + execution_metadata: Mapped[Optional[str]] = mapped_column(sa.Text) + created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp()) + created_by_role: Mapped[str] = mapped_column(String(255)) created_by: Mapped[str] = mapped_column(StringUUID) - finished_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime) + finished_at: Mapped[Optional[datetime]] = mapped_column(DateTime) @property def created_by_account(self): @@ -837,19 +832,19 @@ class WorkflowAppLog(Base): __tablename__ = "workflow_app_logs" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="workflow_app_log_pkey"), - db.Index("workflow_app_log_app_idx", "tenant_id", "app_id"), + sa.PrimaryKeyConstraint("id", name="workflow_app_log_pkey"), + sa.Index("workflow_app_log_app_idx", "tenant_id", "app_id"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID) app_id: Mapped[str] = mapped_column(StringUUID) workflow_id: Mapped[str] = mapped_column(StringUUID, nullable=False) workflow_run_id: Mapped[str] = mapped_column(StringUUID) - created_from: Mapped[str] = mapped_column(db.String(255), nullable=False) - created_by_role: Mapped[str] = mapped_column(db.String(255), nullable=False) + created_from: Mapped[str] = mapped_column(String(255), nullable=False) + created_by_role: Mapped[str] = mapped_column(String(255), nullable=False) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) - created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) @property def workflow_run(self): @@ -867,6 +862,19 @@ class WorkflowAppLog(Base): created_by_role = CreatorUserRole(self.created_by_role) return db.session.get(EndUser, self.created_by) if created_by_role == CreatorUserRole.END_USER else None + def to_dict(self): + return { + "id": self.id, + "tenant_id": self.tenant_id, + "app_id": self.app_id, + "workflow_id": self.workflow_id, + "workflow_run_id": self.workflow_run_id, + "created_from": self.created_from, + "created_by_role": self.created_by_role, + "created_by": self.created_by, + "created_at": self.created_at, + } + class ConversationVariable(Base): __tablename__ = "workflow_conversation_variables" @@ -874,12 +882,12 @@ class ConversationVariable(Base): id: Mapped[str] = mapped_column(StringUUID, primary_key=True) conversation_id: Mapped[str] = mapped_column(StringUUID, nullable=False, primary_key=True, index=True) app_id: Mapped[str] = mapped_column(StringUUID, nullable=False, index=True) - data: Mapped[str] = mapped_column(db.Text, nullable=False) + data: Mapped[str] = mapped_column(sa.Text, nullable=False) created_at: Mapped[datetime] = mapped_column( - db.DateTime, nullable=False, server_default=func.current_timestamp(), index=True + DateTime, nullable=False, server_default=func.current_timestamp(), index=True ) updated_at: Mapped[datetime] = mapped_column( - db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() + DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() ) def __init__(self, *, id: str, app_id: str, conversation_id: str, data: str) -> None: @@ -936,17 +944,17 @@ class WorkflowDraftVariable(Base): __allow_unmapped__ = True # id is the unique identifier of a draft variable. - id: Mapped[str] = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, primary_key=True, server_default=sa.text("uuid_generate_v4()")) created_at: Mapped[datetime] = mapped_column( - db.DateTime, + DateTime, nullable=False, default=_naive_utc_datetime, server_default=func.current_timestamp(), ) updated_at: Mapped[datetime] = mapped_column( - db.DateTime, + DateTime, nullable=False, default=_naive_utc_datetime, server_default=func.current_timestamp(), @@ -961,7 +969,7 @@ class WorkflowDraftVariable(Base): # # If it's not edited after creation, its value is `None`. last_edited_at: Mapped[datetime | None] = mapped_column( - db.DateTime, + DateTime, nullable=True, default=None, ) @@ -1143,7 +1151,7 @@ class WorkflowDraftVariable(Base): value: The Segment object to store as the variable's value. """ self.__value = value - self.value = json.dumps(value, cls=variable_utils.SegmentJSONEncoder) + self.value = variable_utils.dumps_with_segments(value) self.value_type = value.value_type def get_node_id(self) -> str | None: diff --git a/api/mypy.ini b/api/mypy.ini index 6836b2602b..44a01068e9 100644 --- a/api/mypy.ini +++ b/api/mypy.ini @@ -5,16 +5,18 @@ check_untyped_defs = True cache_fine_grained = True sqlite_cache = True exclude = (?x)( - core/model_runtime/model_providers/ - | tests/ + tests/ | migrations/ ) [mypy-flask_login] ignore_missing_imports=True -[mypy-flask_restful] +[mypy-flask_restx] ignore_missing_imports=True -[mypy-flask_restful.inputs] +[mypy-flask_restx.api] +ignore_missing_imports=True + +[mypy-flask_restx.inputs] ignore_missing_imports=True diff --git a/api/pyproject.toml b/api/pyproject.toml index 7ec8a91198..6aa4746d2f 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "dify-api" -version = "1.7.0" +version = "1.7.2" requires-python = ">=3.11,<3.13" dependencies = [ @@ -13,12 +13,12 @@ dependencies = [ "cachetools~=5.3.0", "celery~=5.5.2", "chardet~=5.1.0", - "flask~=3.1.0", + "flask~=3.1.2", "flask-compress~=1.17", "flask-cors~=6.0.0", "flask-login~=0.6.3", "flask-migrate~=4.0.7", - "flask-restful~=0.3.10", + "flask-orjson~=2.0.0", "flask-sqlalchemy~=3.1.1", "gevent~=24.11.1", "gmpy2~=2.2.1", @@ -49,6 +49,8 @@ dependencies = [ "opentelemetry-instrumentation==0.48b0", "opentelemetry-instrumentation-celery==0.48b0", "opentelemetry-instrumentation-flask==0.48b0", + "opentelemetry-instrumentation-redis==0.48b0", + "opentelemetry-instrumentation-requests==0.48b0", "opentelemetry-instrumentation-sqlalchemy==0.48b0", "opentelemetry-propagator-b3==1.27.0", # opentelemetry-proto1.28.0 depends on protobuf (>=5.0,<6.0), @@ -85,6 +87,7 @@ dependencies = [ "sseclient-py>=1.8.0", "httpx-sse>=0.4.0", "sendgrid~=6.12.3", + "flask-restx>=1.3.0", ] # Before adding new dependency, consider place it in # alphabet order (a-z) and suitable group. @@ -107,13 +110,14 @@ dev = [ "dotenv-linter~=0.5.0", "faker~=32.1.0", "lxml-stubs~=0.5.1", - "mypy~=1.16.0", + "mypy~=1.17.1", "ruff~=0.12.3", "pytest~=8.3.2", "pytest-benchmark~=4.0.0", "pytest-cov~=4.1.0", "pytest-env~=1.1.3", "pytest-mock~=3.14.0", + "testcontainers~=4.10.0", "types-aiofiles~=24.1.0", "types-beautifulsoup4~=4.12.0", "types-cachetools~=5.5.0", @@ -159,6 +163,8 @@ dev = [ "pandas-stubs~=2.2.3", "scipy-stubs>=1.15.3.0", "types-python-http-client>=3.3.7.20240910", + "types-redis>=4.6.0.20241004", + "celery-types>=0.23.0", ] ############################################################ @@ -191,6 +197,7 @@ vdb = [ "alibabacloud_tea_openapi~=0.3.9", "chromadb==0.5.20", "clickhouse-connect~=0.7.16", + "clickzetta-connector-python>=0.8.102", "couchbase~=4.3.0", "elasticsearch==8.14.0", "opensearch-py==2.4.0", @@ -199,7 +206,7 @@ vdb = [ "pgvector==0.2.5", "pymilvus~=2.5.0", "pymochow==1.3.1", - "pyobvector~=0.1.6", + "pyobvector~=0.2.15", "qdrant-client==1.9.0", "tablestore==6.2.0", "tcvectordb~=1.6.4", @@ -210,3 +217,4 @@ vdb = [ "xinference-client~=1.2.2", "mo-vector~=0.1.13", ] + diff --git a/api/repositories/factory.py b/api/repositories/factory.py index 0a0adbf2c2..0be9c8908c 100644 --- a/api/repositories/factory.py +++ b/api/repositories/factory.py @@ -5,17 +5,14 @@ This factory is specifically designed for DifyAPI repositories that handle service-layer operations with dependency injection patterns. """ -import logging - from sqlalchemy.orm import sessionmaker from configs import dify_config from core.repositories import DifyCoreRepositoryFactory, RepositoryImportError +from libs.module_loading import import_string from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository from repositories.api_workflow_run_repository import APIWorkflowRunRepository -logger = logging.getLogger(__name__) - class DifyAPIRepositoryFactory(DifyCoreRepositoryFactory): """ @@ -48,20 +45,11 @@ class DifyAPIRepositoryFactory(DifyCoreRepositoryFactory): RepositoryImportError: If the configured repository cannot be imported or instantiated """ class_path = dify_config.API_WORKFLOW_NODE_EXECUTION_REPOSITORY - logger.debug(f"Creating DifyAPIWorkflowNodeExecutionRepository from: {class_path}") try: - repository_class = cls._import_class(class_path) - cls._validate_repository_interface(repository_class, DifyAPIWorkflowNodeExecutionRepository) - # Service repository requires session_maker parameter - cls._validate_constructor_signature(repository_class, ["session_maker"]) - + repository_class = import_string(class_path) return repository_class(session_maker=session_maker) # type: ignore[no-any-return] - except RepositoryImportError: - # Re-raise our custom errors as-is - raise - except Exception as e: - logger.exception("Failed to create DifyAPIWorkflowNodeExecutionRepository") + except (ImportError, Exception) as e: raise RepositoryImportError( f"Failed to create DifyAPIWorkflowNodeExecutionRepository from '{class_path}': {e}" ) from e @@ -86,18 +74,9 @@ class DifyAPIRepositoryFactory(DifyCoreRepositoryFactory): RepositoryImportError: If the configured repository cannot be imported or instantiated """ class_path = dify_config.API_WORKFLOW_RUN_REPOSITORY - logger.debug(f"Creating APIWorkflowRunRepository from: {class_path}") try: - repository_class = cls._import_class(class_path) - cls._validate_repository_interface(repository_class, APIWorkflowRunRepository) - # Service repository requires session_maker parameter - cls._validate_constructor_signature(repository_class, ["session_maker"]) - + repository_class = import_string(class_path) return repository_class(session_maker=session_maker) # type: ignore[no-any-return] - except RepositoryImportError: - # Re-raise our custom errors as-is - raise - except Exception as e: - logger.exception("Failed to create APIWorkflowRunRepository") + except (ImportError, Exception) as e: raise RepositoryImportError(f"Failed to create APIWorkflowRunRepository from '{class_path}': {e}") from e diff --git a/api/repositories/sqlalchemy_api_workflow_run_repository.py b/api/repositories/sqlalchemy_api_workflow_run_repository.py index ebd1d74b20..7c3b1f4ce0 100644 --- a/api/repositories/sqlalchemy_api_workflow_run_repository.py +++ b/api/repositories/sqlalchemy_api_workflow_run_repository.py @@ -155,7 +155,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): session.commit() deleted_count = cast(int, result.rowcount) - logger.info(f"Deleted {deleted_count} workflow runs by IDs") + logger.info("Deleted %s workflow runs by IDs", deleted_count) return deleted_count def delete_runs_by_app( @@ -193,11 +193,11 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): batch_deleted = result.rowcount total_deleted += batch_deleted - logger.info(f"Deleted batch of {batch_deleted} workflow runs for app {app_id}") + logger.info("Deleted batch of %s workflow runs for app %s", batch_deleted, app_id) # If we deleted fewer records than the batch size, we're done if batch_deleted < batch_size: break - logger.info(f"Total deleted {total_deleted} workflow runs for app {app_id}") + logger.info("Total deleted %s workflow runs for app %s", total_deleted, app_id) return total_deleted diff --git a/api/schedule/check_upgradable_plugin_task.py b/api/schedule/check_upgradable_plugin_task.py index c1d6018827..e27391b558 100644 --- a/api/schedule/check_upgradable_plugin_task.py +++ b/api/schedule/check_upgradable_plugin_task.py @@ -16,7 +16,7 @@ def check_upgradable_plugin_task(): start_at = time.perf_counter() now_seconds_of_day = time.time() % 86400 - 30 # we assume the tz is UTC - click.echo(click.style("Now seconds of day: {}".format(now_seconds_of_day), fg="green")) + click.echo(click.style(f"Now seconds of day: {now_seconds_of_day}", fg="green")) strategies = ( db.session.query(TenantPluginAutoUpgradeStrategy) @@ -43,7 +43,7 @@ def check_upgradable_plugin_task(): end_at = time.perf_counter() click.echo( click.style( - "Checked upgradable plugin success latency: {}".format(end_at - start_at), + f"Checked upgradable plugin success latency: {end_at - start_at}", fg="green", ) ) diff --git a/api/schedule/clean_embedding_cache_task.py b/api/schedule/clean_embedding_cache_task.py index 024e3d6f50..2b74fb2dd0 100644 --- a/api/schedule/clean_embedding_cache_task.py +++ b/api/schedule/clean_embedding_cache_task.py @@ -3,7 +3,7 @@ import time import click from sqlalchemy import text -from werkzeug.exceptions import NotFound +from sqlalchemy.exc import SQLAlchemyError import app from configs import dify_config @@ -27,8 +27,8 @@ def clean_embedding_cache_task(): .all() ) embedding_ids = [embedding_id[0] for embedding_id in embedding_ids] - except NotFound: - break + except SQLAlchemyError: + raise if embedding_ids: for embedding_id in embedding_ids: db.session.execute( @@ -39,4 +39,4 @@ def clean_embedding_cache_task(): else: break end_at = time.perf_counter() - click.echo(click.style("Cleaned embedding cache from db success latency: {}".format(end_at - start_at), fg="green")) + click.echo(click.style(f"Cleaned embedding cache from db success latency: {end_at - start_at}", fg="green")) diff --git a/api/schedule/clean_messages.py b/api/schedule/clean_messages.py index a6851e36e5..a896c818a5 100644 --- a/api/schedule/clean_messages.py +++ b/api/schedule/clean_messages.py @@ -3,7 +3,7 @@ import logging import time import click -from werkzeug.exceptions import NotFound +from sqlalchemy.exc import SQLAlchemyError import app from configs import dify_config @@ -42,8 +42,8 @@ def clean_messages(): .all() ) - except NotFound: - break + except SQLAlchemyError: + raise if not messages: break for message in messages: @@ -87,4 +87,4 @@ def clean_messages(): db.session.query(Message).where(Message.id == message.id).delete() db.session.commit() end_at = time.perf_counter() - click.echo(click.style("Cleaned messages from db success latency: {}".format(end_at - start_at), fg="green")) + click.echo(click.style(f"Cleaned messages from db success latency: {end_at - start_at}", fg="green")) diff --git a/api/schedule/clean_unused_datasets_task.py b/api/schedule/clean_unused_datasets_task.py index 72e2e73e65..1141451011 100644 --- a/api/schedule/clean_unused_datasets_task.py +++ b/api/schedule/clean_unused_datasets_task.py @@ -1,9 +1,10 @@ import datetime import time +from typing import Optional, TypedDict import click from sqlalchemy import func, select -from werkzeug.exceptions import NotFound +from sqlalchemy.exc import SQLAlchemyError import app from configs import dify_config @@ -14,174 +15,140 @@ from models.dataset import Dataset, DatasetAutoDisableLog, DatasetQuery, Documen from services.feature_service import FeatureService +class CleanupConfig(TypedDict): + clean_day: datetime.datetime + plan_filter: Optional[str] + add_logs: bool + + @app.celery.task(queue="dataset") def clean_unused_datasets_task(): click.echo(click.style("Start clean unused datasets indexes.", fg="green")) - plan_sandbox_clean_day_setting = dify_config.PLAN_SANDBOX_CLEAN_DAY_SETTING - plan_pro_clean_day_setting = dify_config.PLAN_PRO_CLEAN_DAY_SETTING start_at = time.perf_counter() - plan_sandbox_clean_day = datetime.datetime.now() - datetime.timedelta(days=plan_sandbox_clean_day_setting) - plan_pro_clean_day = datetime.datetime.now() - datetime.timedelta(days=plan_pro_clean_day_setting) - while True: - try: - # Subquery for counting new documents - document_subquery_new = ( - db.session.query(Document.dataset_id, func.count(Document.id).label("document_count")) - .where( - Document.indexing_status == "completed", - Document.enabled == True, - Document.archived == False, - Document.updated_at > plan_sandbox_clean_day, - ) - .group_by(Document.dataset_id) - .subquery() - ) - # Subquery for counting old documents - document_subquery_old = ( - db.session.query(Document.dataset_id, func.count(Document.id).label("document_count")) - .where( - Document.indexing_status == "completed", - Document.enabled == True, - Document.archived == False, - Document.updated_at < plan_sandbox_clean_day, - ) - .group_by(Document.dataset_id) - .subquery() - ) + # Define cleanup configurations + cleanup_configs: list[CleanupConfig] = [ + { + "clean_day": datetime.datetime.now() - datetime.timedelta(days=dify_config.PLAN_SANDBOX_CLEAN_DAY_SETTING), + "plan_filter": None, + "add_logs": True, + }, + { + "clean_day": datetime.datetime.now() - datetime.timedelta(days=dify_config.PLAN_PRO_CLEAN_DAY_SETTING), + "plan_filter": "sandbox", + "add_logs": False, + }, + ] - # Main query with join and filter - stmt = ( - select(Dataset) - .outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id) - .outerjoin(document_subquery_old, Dataset.id == document_subquery_old.c.dataset_id) - .where( - Dataset.created_at < plan_sandbox_clean_day, - func.coalesce(document_subquery_new.c.document_count, 0) == 0, - func.coalesce(document_subquery_old.c.document_count, 0) > 0, - ) - .order_by(Dataset.created_at.desc()) - ) + for config in cleanup_configs: + clean_day = config["clean_day"] + plan_filter = config["plan_filter"] + add_logs = config["add_logs"] - datasets = db.paginate(stmt, page=1, per_page=50) - - except NotFound: - break - if datasets.items is None or len(datasets.items) == 0: - break - for dataset in datasets: - dataset_query = ( - db.session.query(DatasetQuery) - .where(DatasetQuery.created_at > plan_sandbox_clean_day, DatasetQuery.dataset_id == dataset.id) - .all() - ) - if not dataset_query or len(dataset_query) == 0: - try: - # add auto disable log - documents = ( - db.session.query(Document) - .where( - Document.dataset_id == dataset.id, - Document.enabled == True, - Document.archived == False, - ) - .all() + while True: + try: + # Subquery for counting new documents + document_subquery_new = ( + db.session.query(Document.dataset_id, func.count(Document.id).label("document_count")) + .where( + Document.indexing_status == "completed", + Document.enabled == True, + Document.archived == False, + Document.updated_at > clean_day, ) - for document in documents: - dataset_auto_disable_log = DatasetAutoDisableLog( - tenant_id=dataset.tenant_id, - dataset_id=dataset.id, - document_id=document.id, - ) - db.session.add(dataset_auto_disable_log) - # remove index - index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor() - index_processor.clean(dataset, None) + .group_by(Document.dataset_id) + .subquery() + ) - # update document - db.session.query(Document).filter_by(dataset_id=dataset.id).update({Document.enabled: False}) - db.session.commit() - click.echo(click.style("Cleaned unused dataset {} from db success!".format(dataset.id), fg="green")) - except Exception as e: - click.echo( - click.style("clean dataset index error: {} {}".format(e.__class__.__name__, str(e)), fg="red") + # Subquery for counting old documents + document_subquery_old = ( + db.session.query(Document.dataset_id, func.count(Document.id).label("document_count")) + .where( + Document.indexing_status == "completed", + Document.enabled == True, + Document.archived == False, + Document.updated_at < clean_day, ) - while True: - try: - # Subquery for counting new documents - document_subquery_new = ( - db.session.query(Document.dataset_id, func.count(Document.id).label("document_count")) - .where( - Document.indexing_status == "completed", - Document.enabled == True, - Document.archived == False, - Document.updated_at > plan_pro_clean_day, + .group_by(Document.dataset_id) + .subquery() ) - .group_by(Document.dataset_id) - .subquery() - ) - # Subquery for counting old documents - document_subquery_old = ( - db.session.query(Document.dataset_id, func.count(Document.id).label("document_count")) - .where( - Document.indexing_status == "completed", - Document.enabled == True, - Document.archived == False, - Document.updated_at < plan_pro_clean_day, - ) - .group_by(Document.dataset_id) - .subquery() - ) - - # Main query with join and filter - stmt = ( - select(Dataset) - .outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id) - .outerjoin(document_subquery_old, Dataset.id == document_subquery_old.c.dataset_id) - .where( - Dataset.created_at < plan_pro_clean_day, - func.coalesce(document_subquery_new.c.document_count, 0) == 0, - func.coalesce(document_subquery_old.c.document_count, 0) > 0, - ) - .order_by(Dataset.created_at.desc()) - ) - datasets = db.paginate(stmt, page=1, per_page=50) - - except NotFound: - break - if datasets.items is None or len(datasets.items) == 0: - break - for dataset in datasets: - dataset_query = ( - db.session.query(DatasetQuery) - .where(DatasetQuery.created_at > plan_pro_clean_day, DatasetQuery.dataset_id == dataset.id) - .all() - ) - if not dataset_query or len(dataset_query) == 0: - try: - features_cache_key = f"features:{dataset.tenant_id}" - plan_cache = redis_client.get(features_cache_key) - if plan_cache is None: - features = FeatureService.get_features(dataset.tenant_id) - redis_client.setex(features_cache_key, 600, features.billing.subscription.plan) - plan = features.billing.subscription.plan - else: - plan = plan_cache.decode() - if plan == "sandbox": - # remove index - index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor() - index_processor.clean(dataset, None) - - # update document - db.session.query(Document).filter_by(dataset_id=dataset.id).update({Document.enabled: False}) - db.session.commit() - click.echo( - click.style("Cleaned unused dataset {} from db success!".format(dataset.id), fg="green") - ) - except Exception as e: - click.echo( - click.style("clean dataset index error: {} {}".format(e.__class__.__name__, str(e)), fg="red") + # Main query with join and filter + stmt = ( + select(Dataset) + .outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id) + .outerjoin(document_subquery_old, Dataset.id == document_subquery_old.c.dataset_id) + .where( + Dataset.created_at < clean_day, + func.coalesce(document_subquery_new.c.document_count, 0) == 0, + func.coalesce(document_subquery_old.c.document_count, 0) > 0, ) + .order_by(Dataset.created_at.desc()) + ) + + datasets = db.paginate(stmt, page=1, per_page=50) + + except SQLAlchemyError: + raise + + if datasets.items is None or len(datasets.items) == 0: + break + + for dataset in datasets: + dataset_query = ( + db.session.query(DatasetQuery) + .where(DatasetQuery.created_at > clean_day, DatasetQuery.dataset_id == dataset.id) + .all() + ) + + if not dataset_query or len(dataset_query) == 0: + try: + should_clean = True + + # Check plan filter if specified + if plan_filter: + features_cache_key = f"features:{dataset.tenant_id}" + plan_cache = redis_client.get(features_cache_key) + if plan_cache is None: + features = FeatureService.get_features(dataset.tenant_id) + redis_client.setex(features_cache_key, 600, features.billing.subscription.plan) + plan = features.billing.subscription.plan + else: + plan = plan_cache.decode() + should_clean = plan == plan_filter + + if should_clean: + # Add auto disable log if required + if add_logs: + documents = ( + db.session.query(Document) + .where( + Document.dataset_id == dataset.id, + Document.enabled == True, + Document.archived == False, + ) + .all() + ) + for document in documents: + dataset_auto_disable_log = DatasetAutoDisableLog( + tenant_id=dataset.tenant_id, + dataset_id=dataset.id, + document_id=document.id, + ) + db.session.add(dataset_auto_disable_log) + + # Remove index + index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor() + index_processor.clean(dataset, None) + + # Update document + db.session.query(Document).filter_by(dataset_id=dataset.id).update( + {Document.enabled: False} + ) + db.session.commit() + click.echo(click.style(f"Cleaned unused dataset {dataset.id} from db success!", fg="green")) + except Exception as e: + click.echo(click.style(f"clean dataset index error: {e.__class__.__name__} {str(e)}", fg="red")) + end_at = time.perf_counter() - click.echo(click.style("Cleaned unused dataset from db success latency: {}".format(end_at - start_at), fg="green")) + click.echo(click.style(f"Cleaned unused dataset from db success latency: {end_at - start_at}", fg="green")) diff --git a/api/schedule/clean_workflow_runlogs_precise.py b/api/schedule/clean_workflow_runlogs_precise.py new file mode 100644 index 0000000000..8c21be01dc --- /dev/null +++ b/api/schedule/clean_workflow_runlogs_precise.py @@ -0,0 +1,155 @@ +import datetime +import logging +import time + +import click + +import app +from configs import dify_config +from extensions.ext_database import db +from models.model import ( + AppAnnotationHitHistory, + Conversation, + Message, + MessageAgentThought, + MessageAnnotation, + MessageChain, + MessageFeedback, + MessageFile, +) +from models.workflow import ConversationVariable, WorkflowAppLog, WorkflowNodeExecutionModel, WorkflowRun + +_logger = logging.getLogger(__name__) + + +MAX_RETRIES = 3 +BATCH_SIZE = dify_config.WORKFLOW_LOG_CLEANUP_BATCH_SIZE + + +@app.celery.task(queue="dataset") +def clean_workflow_runlogs_precise(): + """Clean expired workflow run logs with retry mechanism and complete message cascade""" + + click.echo(click.style("Start clean workflow run logs (precise mode with complete cascade).", fg="green")) + start_at = time.perf_counter() + + retention_days = dify_config.WORKFLOW_LOG_RETENTION_DAYS + cutoff_date = datetime.datetime.now() - datetime.timedelta(days=retention_days) + + try: + total_workflow_runs = db.session.query(WorkflowRun).where(WorkflowRun.created_at < cutoff_date).count() + if total_workflow_runs == 0: + _logger.info("No expired workflow run logs found") + return + _logger.info("Found %s expired workflow run logs to clean", total_workflow_runs) + + total_deleted = 0 + failed_batches = 0 + batch_count = 0 + + while True: + workflow_runs = ( + db.session.query(WorkflowRun.id).where(WorkflowRun.created_at < cutoff_date).limit(BATCH_SIZE).all() + ) + + if not workflow_runs: + break + + workflow_run_ids = [run.id for run in workflow_runs] + batch_count += 1 + + success = _delete_batch_with_retry(workflow_run_ids, failed_batches) + + if success: + total_deleted += len(workflow_run_ids) + failed_batches = 0 + else: + failed_batches += 1 + if failed_batches >= MAX_RETRIES: + _logger.error("Failed to delete batch after %s retries, aborting cleanup for today", MAX_RETRIES) + break + else: + # Calculate incremental delay times: 5, 10, 15 minutes + retry_delay_minutes = failed_batches * 5 + _logger.warning("Batch deletion failed, retrying in %s minutes...", retry_delay_minutes) + time.sleep(retry_delay_minutes * 60) + continue + + _logger.info("Cleanup completed: %s expired workflow run logs deleted", total_deleted) + + except Exception as e: + db.session.rollback() + _logger.exception("Unexpected error in workflow log cleanup") + raise + + end_at = time.perf_counter() + execution_time = end_at - start_at + click.echo(click.style(f"Cleaned workflow run logs from db success latency: {execution_time:.2f}s", fg="green")) + + +def _delete_batch_with_retry(workflow_run_ids: list[str], attempt_count: int) -> bool: + """Delete a single batch with a retry mechanism and complete cascading deletion""" + try: + with db.session.begin_nested(): + message_data = ( + db.session.query(Message.id, Message.conversation_id) + .filter(Message.workflow_run_id.in_(workflow_run_ids)) + .all() + ) + message_id_list = [msg.id for msg in message_data] + conversation_id_list = list({msg.conversation_id for msg in message_data if msg.conversation_id}) + if message_id_list: + db.session.query(AppAnnotationHitHistory).where( + AppAnnotationHitHistory.message_id.in_(message_id_list) + ).delete(synchronize_session=False) + + db.session.query(MessageAgentThought).where(MessageAgentThought.message_id.in_(message_id_list)).delete( + synchronize_session=False + ) + + db.session.query(MessageChain).where(MessageChain.message_id.in_(message_id_list)).delete( + synchronize_session=False + ) + + db.session.query(MessageFile).where(MessageFile.message_id.in_(message_id_list)).delete( + synchronize_session=False + ) + + db.session.query(MessageAnnotation).where(MessageAnnotation.message_id.in_(message_id_list)).delete( + synchronize_session=False + ) + + db.session.query(MessageFeedback).where(MessageFeedback.message_id.in_(message_id_list)).delete( + synchronize_session=False + ) + + db.session.query(Message).where(Message.workflow_run_id.in_(workflow_run_ids)).delete( + synchronize_session=False + ) + + db.session.query(WorkflowAppLog).where(WorkflowAppLog.workflow_run_id.in_(workflow_run_ids)).delete( + synchronize_session=False + ) + + db.session.query(WorkflowNodeExecutionModel).where( + WorkflowNodeExecutionModel.workflow_run_id.in_(workflow_run_ids) + ).delete(synchronize_session=False) + + if conversation_id_list: + db.session.query(ConversationVariable).where( + ConversationVariable.conversation_id.in_(conversation_id_list) + ).delete(synchronize_session=False) + + db.session.query(Conversation).where(Conversation.id.in_(conversation_id_list)).delete( + synchronize_session=False + ) + + db.session.query(WorkflowRun).where(WorkflowRun.id.in_(workflow_run_ids)).delete(synchronize_session=False) + + db.session.commit() + return True + + except Exception as e: + db.session.rollback() + _logger.exception("Batch deletion failed (attempt %s)", attempt_count + 1) + return False diff --git a/api/schedule/create_tidb_serverless_task.py b/api/schedule/create_tidb_serverless_task.py index 91953354e6..c343063fae 100644 --- a/api/schedule/create_tidb_serverless_task.py +++ b/api/schedule/create_tidb_serverless_task.py @@ -33,7 +33,7 @@ def create_tidb_serverless_task(): break end_at = time.perf_counter() - click.echo(click.style("Create tidb serverless task success latency: {}".format(end_at - start_at), fg="green")) + click.echo(click.style(f"Create tidb serverless task success latency: {end_at - start_at}", fg="green")) def create_clusters(batch_size): diff --git a/api/schedule/mail_clean_document_notify_task.py b/api/schedule/mail_clean_document_notify_task.py index 5911c98b0a..03ef9062bd 100644 --- a/api/schedule/mail_clean_document_notify_task.py +++ b/api/schedule/mail_clean_document_notify_task.py @@ -90,7 +90,7 @@ def mail_clean_document_notify_task(): db.session.commit() end_at = time.perf_counter() logging.info( - click.style("Send document clean notify mail succeeded: latency: {}".format(end_at - start_at), fg="green") + click.style(f"Send document clean notify mail succeeded: latency: {end_at - start_at}", fg="green") ) except Exception: logging.exception("Send document clean notify mail failed") diff --git a/api/schedule/queue_monitor_task.py b/api/schedule/queue_monitor_task.py index a05e1358ed..5868450a14 100644 --- a/api/schedule/queue_monitor_task.py +++ b/api/schedule/queue_monitor_task.py @@ -1,8 +1,8 @@ import logging from datetime import datetime -from urllib.parse import urlparse import click +from kombu.utils.url import parse_url # type: ignore from redis import Redis import app @@ -10,16 +10,13 @@ from configs import dify_config from extensions.ext_database import db from libs.email_i18n import EmailType, get_email_i18n_service -# Create a dedicated Redis connection (using the same configuration as Celery) -celery_broker_url = dify_config.CELERY_BROKER_URL - -parsed = urlparse(celery_broker_url) -host = parsed.hostname or "localhost" -port = parsed.port or 6379 -password = parsed.password or None -redis_db = parsed.path.strip("/") or "1" # type: ignore - -celery_redis = Redis(host=host, port=port, password=password, db=redis_db) +redis_config = parse_url(dify_config.CELERY_BROKER_URL) +celery_redis = Redis( + host=redis_config.get("hostname") or "localhost", + port=redis_config.get("port") or 6379, + password=redis_config.get("password") or None, + db=int(redis_config.get("virtual_host")) if redis_config.get("virtual_host") else 1, +) @app.celery.task(queue="monitor") @@ -27,9 +24,20 @@ def queue_monitor_task(): queue_name = "dataset" threshold = dify_config.QUEUE_MONITOR_THRESHOLD + if threshold is None: + logging.warning(click.style("QUEUE_MONITOR_THRESHOLD is not configured, skipping monitoring", fg="yellow")) + return + try: queue_length = celery_redis.llen(f"{queue_name}") logging.info(click.style(f"Start monitor {queue_name}", fg="green")) + + if queue_length is None: + logging.error( + click.style(f"Failed to get queue length for {queue_name} - Redis may be unavailable", fg="red") + ) + return + logging.info(click.style(f"Queue length: {queue_length}", fg="green")) if queue_length >= threshold: diff --git a/api/schedule/update_tidb_serverless_status_task.py b/api/schedule/update_tidb_serverless_status_task.py index 4d6c1f1877..1bfeb869e2 100644 --- a/api/schedule/update_tidb_serverless_status_task.py +++ b/api/schedule/update_tidb_serverless_status_task.py @@ -29,9 +29,7 @@ def update_tidb_serverless_status_task(): click.echo(click.style(f"Error: {e}", fg="red")) end_at = time.perf_counter() - click.echo( - click.style("Update tidb serverless status task success latency: {}".format(end_at - start_at), fg="green") - ) + click.echo(click.style(f"Update tidb serverless status task success latency: {end_at - start_at}", fg="green")) def update_clusters(tidb_serverless_list: list[TidbAuthBinding]): diff --git a/api/services/account_service.py b/api/services/account_service.py index e11f1580e5..0bb903fbbc 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -332,9 +332,9 @@ class AccountService: db.session.add(account_integrate) db.session.commit() - logging.info(f"Account {account.id} linked {provider} account {open_id}.") + logging.info("Account %s linked %s account %s.", account.id, provider, open_id) except Exception as e: - logging.exception(f"Failed to link {provider} account {open_id} to Account {account.id}") + logging.exception("Failed to link %s account %s to Account %s", provider, open_id, account.id) raise LinkAccountIntegrateError("Failed to link account.") from e @staticmethod @@ -355,6 +355,17 @@ class AccountService: db.session.commit() return account + @staticmethod + def update_account_email(account: Account, email: str) -> Account: + """Update account email""" + account.email = email + account_integrate = db.session.query(AccountIntegrate).filter_by(account_id=account.id).first() + if account_integrate: + db.session.delete(account_integrate) + db.session.add(account) + db.session.commit() + return account + @staticmethod def update_login_info(account: Account, *, ip_address: str) -> None: """Update last login time and ip""" @@ -414,7 +425,7 @@ class AccountService: cls, account: Optional[Account] = None, email: Optional[str] = None, - language: Optional[str] = "en-US", + language: str = "en-US", ): account_email = account.email if account else email if account_email is None: @@ -441,12 +452,14 @@ class AccountService: account: Optional[Account] = None, email: Optional[str] = None, old_email: Optional[str] = None, - language: Optional[str] = "en-US", + language: str = "en-US", phase: Optional[str] = None, ): account_email = account.email if account else email if account_email is None: raise ValueError("Email must be provided.") + if not phase: + raise ValueError("phase must be provided.") if cls.change_email_rate_limiter.is_rate_limited(account_email): from controllers.console.auth.error import EmailChangeRateLimitExceededError @@ -469,7 +482,7 @@ class AccountService: cls, account: Optional[Account] = None, email: Optional[str] = None, - language: Optional[str] = "en-US", + language: str = "en-US", ): account_email = account.email if account else email if account_email is None: @@ -485,7 +498,7 @@ class AccountService: cls, account: Optional[Account] = None, email: Optional[str] = None, - language: Optional[str] = "en-US", + language: str = "en-US", workspace_name: Optional[str] = "", ): account_email = account.email if account else email @@ -498,6 +511,7 @@ class AccountService: raise OwnerTransferRateLimitExceededError() code, token = cls.generate_owner_transfer_token(account_email, account) + workspace_name = workspace_name or "" send_owner_transfer_confirm_task.delay( language=language, @@ -513,13 +527,14 @@ class AccountService: cls, account: Optional[Account] = None, email: Optional[str] = None, - language: Optional[str] = "en-US", + language: str = "en-US", workspace_name: Optional[str] = "", - new_owner_email: Optional[str] = "", + new_owner_email: str = "", ): account_email = account.email if account else email if account_email is None: raise ValueError("Email must be provided.") + workspace_name = workspace_name or "" send_old_owner_transfer_notify_email_task.delay( language=language, @@ -533,12 +548,13 @@ class AccountService: cls, account: Optional[Account] = None, email: Optional[str] = None, - language: Optional[str] = "en-US", + language: str = "en-US", workspace_name: Optional[str] = "", ): account_email = account.email if account else email if account_email is None: raise ValueError("Email must be provided.") + workspace_name = workspace_name or "" send_new_owner_transfer_notify_email_task.delay( language=language, @@ -622,7 +638,10 @@ class AccountService: @classmethod def send_email_code_login_email( - cls, account: Optional[Account] = None, email: Optional[str] = None, language: Optional[str] = "en-US" + cls, + account: Optional[Account] = None, + email: Optional[str] = None, + language: str = "en-US", ): email = account.email if account else email if email is None: @@ -906,7 +925,7 @@ class TenantService: """Create tenant member""" if role == TenantAccountRole.OWNER.value: if TenantService.has_roles(tenant, [TenantAccountRole.OWNER]): - logging.error(f"Tenant {tenant.id} has already an owner.") + logging.error("Tenant %s has already an owner.", tenant.id) raise Exception("Tenant already has an owner.") ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first() @@ -1158,7 +1177,7 @@ class RegisterService: db.session.query(Tenant).delete() db.session.commit() - logging.exception(f"Setup account failed, email: {email}, name: {name}") + logging.exception("Setup account failed, email: %s, name: %s", email, name) raise ValueError(f"Setup failed: {e}") @classmethod @@ -1249,10 +1268,11 @@ class RegisterService: raise AccountAlreadyInTenantError("Account already in tenant.") token = cls.generate_invite_token(tenant, account) + language = account.interface_language or "en-US" # send email send_invite_member_mail_task.delay( - language=account.interface_language, + language=language, to=email, token=token, inviter_name=inviter.name if inviter else "Dify", @@ -1282,7 +1302,7 @@ class RegisterService: def revoke_token(cls, workspace_id: str, email: str, token: str): if workspace_id and email: email_hash = sha256(email.encode()).hexdigest() - cache_key = "member_invite_token:{}, {}:{}".format(workspace_id, email_hash, token) + cache_key = f"member_invite_token:{workspace_id}, {email_hash}:{token}" redis_client.delete(cache_key) else: redis_client.delete(cls._get_invitation_token_key(token)) diff --git a/api/services/annotation_service.py b/api/services/annotation_service.py index 7cb0b46517..45b246af1e 100644 --- a/api/services/annotation_service.py +++ b/api/services/annotation_service.py @@ -1,4 +1,3 @@ -import datetime import uuid from typing import cast @@ -10,6 +9,7 @@ from werkzeug.exceptions import NotFound from extensions.ext_database import db from extensions.ext_redis import redis_client +from libs.datetime_utils import naive_utc_now from models.model import App, AppAnnotationHitHistory, AppAnnotationSetting, Message, MessageAnnotation from services.feature_service import FeatureService from tasks.annotation.add_annotation_to_index_task import add_annotation_to_index_task @@ -74,14 +74,14 @@ class AppAnnotationService: @classmethod def enable_app_annotation(cls, args: dict, app_id: str) -> dict: - enable_app_annotation_key = "enable_app_annotation_{}".format(str(app_id)) + enable_app_annotation_key = f"enable_app_annotation_{str(app_id)}" cache_result = redis_client.get(enable_app_annotation_key) if cache_result is not None: return {"job_id": cache_result, "job_status": "processing"} # async job job_id = str(uuid.uuid4()) - enable_app_annotation_job_key = "enable_app_annotation_job_{}".format(str(job_id)) + enable_app_annotation_job_key = f"enable_app_annotation_job_{str(job_id)}" # send batch add segments task redis_client.setnx(enable_app_annotation_job_key, "waiting") enable_annotation_reply_task.delay( @@ -97,14 +97,14 @@ class AppAnnotationService: @classmethod def disable_app_annotation(cls, app_id: str) -> dict: - disable_app_annotation_key = "disable_app_annotation_{}".format(str(app_id)) + disable_app_annotation_key = f"disable_app_annotation_{str(app_id)}" cache_result = redis_client.get(disable_app_annotation_key) if cache_result is not None: return {"job_id": cache_result, "job_status": "processing"} # async job job_id = str(uuid.uuid4()) - disable_app_annotation_job_key = "disable_app_annotation_job_{}".format(str(job_id)) + disable_app_annotation_job_key = f"disable_app_annotation_job_{str(job_id)}" # send batch add segments task redis_client.setnx(disable_app_annotation_job_key, "waiting") disable_annotation_reply_task.delay(str(job_id), app_id, current_user.current_tenant_id) @@ -127,8 +127,8 @@ class AppAnnotationService: .where(MessageAnnotation.app_id == app_id) .where( or_( - MessageAnnotation.question.ilike("%{}%".format(keyword)), - MessageAnnotation.content.ilike("%{}%".format(keyword)), + MessageAnnotation.question.ilike(f"%{keyword}%"), + MessageAnnotation.content.ilike(f"%{keyword}%"), ) ) .order_by(MessageAnnotation.created_at.desc(), MessageAnnotation.id.desc()) @@ -266,6 +266,54 @@ class AppAnnotationService: annotation.id, app_id, current_user.current_tenant_id, app_annotation_setting.collection_binding_id ) + @classmethod + def delete_app_annotations_in_batch(cls, app_id: str, annotation_ids: list[str]): + # get app info + app = ( + db.session.query(App) + .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .first() + ) + + if not app: + raise NotFound("App not found") + + # Fetch annotations and their settings in a single query + annotations_to_delete = ( + db.session.query(MessageAnnotation, AppAnnotationSetting) + .outerjoin(AppAnnotationSetting, MessageAnnotation.app_id == AppAnnotationSetting.app_id) + .filter(MessageAnnotation.id.in_(annotation_ids)) + .all() + ) + + if not annotations_to_delete: + return {"deleted_count": 0} + + # Step 1: Extract IDs for bulk operations + annotation_ids_to_delete = [annotation.id for annotation, _ in annotations_to_delete] + + # Step 2: Bulk delete hit histories in a single query + db.session.query(AppAnnotationHitHistory).where( + AppAnnotationHitHistory.annotation_id.in_(annotation_ids_to_delete) + ).delete(synchronize_session=False) + + # Step 3: Trigger async tasks for search index deletion + for annotation, annotation_setting in annotations_to_delete: + if annotation_setting: + delete_annotation_index_task.delay( + annotation.id, app_id, current_user.current_tenant_id, annotation_setting.collection_binding_id + ) + + # Step 4: Bulk delete annotations in a single query + deleted_count = ( + db.session.query(MessageAnnotation) + .where(MessageAnnotation.id.in_(annotation_ids_to_delete)) + .delete(synchronize_session=False) + ) + + db.session.commit() + return {"deleted_count": deleted_count} + @classmethod def batch_import_app_annotations(cls, app_id, file: FileStorage) -> dict: # get app info @@ -280,7 +328,7 @@ class AppAnnotationService: try: # Skip the first row - df = pd.read_csv(file) + df = pd.read_csv(file, dtype=str) result = [] for index, row in df.iterrows(): content = {"question": row.iloc[0], "answer": row.iloc[1]} @@ -295,7 +343,7 @@ class AppAnnotationService: raise ValueError("The number of annotations exceeds the limit of your subscription.") # async job job_id = str(uuid.uuid4()) - indexing_cache_key = "app_annotation_batch_import_{}".format(str(job_id)) + indexing_cache_key = f"app_annotation_batch_import_{str(job_id)}" # send batch add segments task redis_client.setnx(indexing_cache_key, "waiting") batch_import_annotations_task.delay( @@ -425,7 +473,7 @@ class AppAnnotationService: raise NotFound("App annotation not found") annotation_setting.score_threshold = args["score_threshold"] annotation_setting.updated_user_id = current_user.id - annotation_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + annotation_setting.updated_at = naive_utc_now() db.session.add(annotation_setting) db.session.commit() @@ -440,3 +488,38 @@ class AppAnnotationService: "embedding_model_name": collection_binding_detail.model_name, }, } + + @classmethod + def clear_all_annotations(cls, app_id: str) -> dict: + app = ( + db.session.query(App) + .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .first() + ) + + if not app: + raise NotFound("App not found") + + # if annotation reply is enabled, delete annotation index + app_annotation_setting = ( + db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first() + ) + + annotations_query = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app_id) + for annotation in annotations_query.yield_per(100): + annotation_hit_histories_query = db.session.query(AppAnnotationHitHistory).where( + AppAnnotationHitHistory.annotation_id == annotation.id + ) + for annotation_hit_history in annotation_hit_histories_query.yield_per(100): + db.session.delete(annotation_hit_history) + + # if annotation reply is enabled, delete annotation index + if app_annotation_setting: + delete_annotation_index_task.delay( + annotation.id, app_id, current_user.current_tenant_id, app_annotation_setting.collection_binding_id + ) + + db.session.delete(annotation) + + db.session.commit() + return {"result": "success"} diff --git a/api/services/api_based_extension_service.py b/api/services/api_based_extension_service.py index 457c91e5c0..2f28eff165 100644 --- a/api/services/api_based_extension_service.py +++ b/api/services/api_based_extension_service.py @@ -102,4 +102,4 @@ class APIBasedExtensionService: if resp.get("result") != "pong": raise ValueError(resp) except Exception as e: - raise ValueError("connection error: {}".format(e)) + raise ValueError(f"connection error: {e}") diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index fe0efd061d..2aa9f6cabd 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -12,6 +12,7 @@ import yaml # type: ignore from Crypto.Cipher import AES from Crypto.Util.Padding import pad, unpad from packaging import version +from packaging.version import parse as parse_version from pydantic import BaseModel, Field from sqlalchemy import select from sqlalchemy.orm import Session @@ -269,7 +270,7 @@ class AppDslService: check_dependencies_pending_data = None if dependencies: check_dependencies_pending_data = [PluginDependency.model_validate(d) for d in dependencies] - elif imported_version <= "0.1.5": + elif parse_version(imported_version) <= parse_version("0.1.5"): if "workflow" in data: graph = data.get("workflow", {}).get("graph", {}) dependencies_list = self._extract_dependencies_from_workflow_graph(graph) diff --git a/api/services/app_generate_service.py b/api/services/app_generate_service.py index 6f7e705b52..6792324ec8 100644 --- a/api/services/app_generate_service.py +++ b/api/services/app_generate_service.py @@ -1,5 +1,6 @@ +import uuid from collections.abc import Generator, Mapping -from typing import Any, Union +from typing import Any, Optional, Union from openai._exceptions import RateLimitError @@ -15,6 +16,7 @@ from libs.helper import RateLimiter from models.model import Account, App, AppMode, EndUser from models.workflow import Workflow from services.billing_service import BillingService +from services.errors.app import WorkflowIdFormatError, WorkflowNotFoundError from services.errors.llm import InvokeRateLimitError from services.workflow_service import WorkflowService @@ -86,7 +88,8 @@ class AppGenerateService: request_id=request_id, ) elif app_model.mode == AppMode.ADVANCED_CHAT.value: - workflow = cls._get_workflow(app_model, invoke_from) + workflow_id = args.get("workflow_id") + workflow = cls._get_workflow(app_model, invoke_from, workflow_id) return rate_limit.generate( AdvancedChatAppGenerator.convert_to_event_stream( AdvancedChatAppGenerator().generate( @@ -101,7 +104,8 @@ class AppGenerateService: request_id=request_id, ) elif app_model.mode == AppMode.WORKFLOW.value: - workflow = cls._get_workflow(app_model, invoke_from) + workflow_id = args.get("workflow_id") + workflow = cls._get_workflow(app_model, invoke_from, workflow_id) return rate_limit.generate( WorkflowAppGenerator.convert_to_event_stream( WorkflowAppGenerator().generate( @@ -210,14 +214,27 @@ class AppGenerateService: ) @classmethod - def _get_workflow(cls, app_model: App, invoke_from: InvokeFrom) -> Workflow: + def _get_workflow(cls, app_model: App, invoke_from: InvokeFrom, workflow_id: Optional[str] = None) -> Workflow: """ Get workflow :param app_model: app model :param invoke_from: invoke from + :param workflow_id: optional workflow id to specify a specific version :return: """ workflow_service = WorkflowService() + + # If workflow_id is specified, get the specific workflow version + if workflow_id: + try: + workflow_uuid = uuid.UUID(workflow_id) + except ValueError: + raise WorkflowIdFormatError(f"Invalid workflow_id format: '{workflow_id}'. ") + workflow = workflow_service.get_published_workflow_by_id(app_model=app_model, workflow_id=workflow_id) + if not workflow: + raise WorkflowNotFoundError(f"Workflow not found with id: {workflow_id}") + return workflow + if invoke_from == InvokeFrom.DEBUGGER: # fetch draft workflow by app_model workflow = workflow_service.get_draft_workflow(app_model=app_model) diff --git a/api/services/app_service.py b/api/services/app_service.py index 0b6b85bcb2..0f22666d5a 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -53,9 +53,10 @@ class AppService: if args.get("name"): name = args["name"][:30] filters.append(App.name.ilike(f"%{name}%")) - if args.get("tag_ids"): + # Check if tag_ids is not empty to avoid WHERE false condition + if args.get("tag_ids") and len(args["tag_ids"]) > 0: target_ids = TagService.get_target_ids_by_tag_ids("app", tenant_id, args["tag_ids"]) - if target_ids: + if target_ids and len(target_ids) > 0: filters.append(App.id.in_(target_ids)) else: return None @@ -94,7 +95,7 @@ class AppService: except (ProviderTokenNotInitError, LLMBadRequestError): model_instance = None except Exception as e: - logging.exception(f"Get default model instance failed, tenant_id: {tenant_id}") + logging.exception("Get default model instance failed, tenant_id: %s", tenant_id) model_instance = None if model_instance: diff --git a/api/services/billing_service.py b/api/services/billing_service.py index 5a12aa2e54..40d45af376 100644 --- a/api/services/billing_service.py +++ b/api/services/billing_service.py @@ -123,7 +123,7 @@ class BillingService: return BillingService._send_request("GET", "/education/verify", params=params) @classmethod - def is_active(cls, account_id: str): + def status(cls, account_id: str): params = {"account_id": account_id} return BillingService._send_request("GET", "/education/status", params=params) @@ -159,9 +159,9 @@ class BillingService: ): limiter_key = f"{account_id}:{tenant_id}" if cls.compliance_download_rate_limiter.is_rate_limited(limiter_key): - from controllers.console.error import CompilanceRateLimitError + from controllers.console.error import ComplianceRateLimitError - raise CompilanceRateLimitError() + raise ComplianceRateLimitError() json = { "doc_name": doc_name, diff --git a/api/services/clear_free_plan_tenant_expired_logs.py b/api/services/clear_free_plan_tenant_expired_logs.py index ad9b750d40..b28afcaa41 100644 --- a/api/services/clear_free_plan_tenant_expired_logs.py +++ b/api/services/clear_free_plan_tenant_expired_logs.py @@ -13,7 +13,19 @@ from core.model_runtime.utils.encoders import jsonable_encoder from extensions.ext_database import db from extensions.ext_storage import storage from models.account import Tenant -from models.model import App, Conversation, Message +from models.model import ( + App, + AppAnnotationHitHistory, + Conversation, + Message, + MessageAgentThought, + MessageAnnotation, + MessageChain, + MessageFeedback, + MessageFile, +) +from models.web import SavedMessage +from models.workflow import WorkflowAppLog from repositories.factory import DifyAPIRepositoryFactory from services.billing_service import BillingService @@ -21,6 +33,85 @@ logger = logging.getLogger(__name__) class ClearFreePlanTenantExpiredLogs: + @classmethod + def _clear_message_related_tables(cls, session: Session, tenant_id: str, batch_message_ids: list[str]) -> None: + """ + Clean up message-related tables to avoid data redundancy. + This method cleans up tables that have foreign key relationships with Message. + + Args: + session: Database session, the same with the one in process_tenant method + tenant_id: Tenant ID for logging purposes + batch_message_ids: List of message IDs to clean up + """ + if not batch_message_ids: + return + + # Clean up each related table + related_tables = [ + (MessageFeedback, "message_feedbacks"), + (MessageFile, "message_files"), + (MessageAnnotation, "message_annotations"), + (MessageChain, "message_chains"), + (MessageAgentThought, "message_agent_thoughts"), + (AppAnnotationHitHistory, "app_annotation_hit_histories"), + (SavedMessage, "saved_messages"), + ] + + for model, table_name in related_tables: + # Query records related to expired messages + records = ( + session.query(model) + .filter( + model.message_id.in_(batch_message_ids), # type: ignore + ) + .all() + ) + + if len(records) == 0: + continue + + # Save records before deletion + record_ids = [record.id for record in records] + try: + record_data = [] + for record in records: + try: + if hasattr(record, "to_dict"): + record_data.append(record.to_dict()) + else: + # if record doesn't have to_dict method, we need to transform it to dict manually + record_dict = {} + for column in record.__table__.columns: + record_dict[column.name] = getattr(record, column.name) + record_data.append(record_dict) + except Exception: + logger.exception("Failed to transform %s record: %s", table_name, record.id) + continue + + if record_data: + storage.save( + f"free_plan_tenant_expired_logs/" + f"{tenant_id}/{table_name}/{datetime.datetime.now().strftime('%Y-%m-%d')}" + f"-{time.time()}.json", + json.dumps( + jsonable_encoder(record_data), + ).encode("utf-8"), + ) + except Exception: + logger.exception("Failed to save %s records", table_name) + + session.query(model).filter( + model.id.in_(record_ids), # type: ignore + ).delete(synchronize_session=False) + + click.echo( + click.style( + f"[{datetime.datetime.now()}] Processed {len(record_ids)} " + f"{table_name} records for tenant {tenant_id}" + ) + ) + @classmethod def process_tenant(cls, flask_app: Flask, tenant_id: str, days: int, batch: int): with flask_app.app_context(): @@ -58,6 +149,7 @@ class ClearFreePlanTenantExpiredLogs: Message.id.in_(message_ids), ).delete(synchronize_session=False) + cls._clear_message_related_tables(session, tenant_id, message_ids) session.commit() click.echo( @@ -199,6 +291,48 @@ class ClearFreePlanTenantExpiredLogs: if len(workflow_runs) < batch: break + while True: + with Session(db.engine).no_autoflush as session: + workflow_app_logs = ( + session.query(WorkflowAppLog) + .filter( + WorkflowAppLog.tenant_id == tenant_id, + WorkflowAppLog.created_at < datetime.datetime.now() - datetime.timedelta(days=days), + ) + .limit(batch) + .all() + ) + + if len(workflow_app_logs) == 0: + break + + # save workflow app logs + storage.save( + f"free_plan_tenant_expired_logs/" + f"{tenant_id}/workflow_app_logs/{datetime.datetime.now().strftime('%Y-%m-%d')}" + f"-{time.time()}.json", + json.dumps( + jsonable_encoder( + [workflow_app_log.to_dict() for workflow_app_log in workflow_app_logs], + ), + ).encode("utf-8"), + ) + + workflow_app_log_ids = [workflow_app_log.id for workflow_app_log in workflow_app_logs] + + # delete workflow app logs + session.query(WorkflowAppLog).filter( + WorkflowAppLog.id.in_(workflow_app_log_ids), + ).delete(synchronize_session=False) + session.commit() + + click.echo( + click.style( + f"[{datetime.datetime.now()}] Processed {len(workflow_app_log_ids)}" + f" workflow app logs for tenant {tenant_id}" + ) + ) + @classmethod def process(cls, days: int, batch: int, tenant_ids: list[str]): """ @@ -228,7 +362,7 @@ class ClearFreePlanTenantExpiredLogs: # only process sandbox tenant cls.process_tenant(flask_app, tenant_id, days, batch) except Exception: - logger.exception(f"Failed to process tenant {tenant_id}") + logger.exception("Failed to process tenant %s", tenant_id) finally: nonlocal handled_tenant_count handled_tenant_count += 1 @@ -311,7 +445,7 @@ class ClearFreePlanTenantExpiredLogs: try: tenants.append(tenant_id) except Exception: - logger.exception(f"Failed to process tenant {tenant_id}") + logger.exception("Failed to process tenant %s", tenant_id) continue futures.append( diff --git a/api/services/conversation_service.py b/api/services/conversation_service.py index 525c87fe4a..ac603d3cc9 100644 --- a/api/services/conversation_service.py +++ b/api/services/conversation_service.py @@ -1,12 +1,17 @@ +import contextlib +import logging from collections.abc import Callable, Sequence -from typing import Optional, Union +from typing import Any, Optional, Union from sqlalchemy import asc, desc, func, or_, select from sqlalchemy.orm import Session from core.app.entities.app_invoke_entities import InvokeFrom from core.llm_generator.llm_generator import LLMGenerator +from core.variables.types import SegmentType +from core.workflow.nodes.variable_assigner.common.impl import conversation_variable_updater_factory from extensions.ext_database import db +from factories import variable_factory from libs.datetime_utils import naive_utc_now from libs.infinite_scroll_pagination import InfiniteScrollPagination from models import ConversationVariable @@ -15,9 +20,13 @@ from models.model import App, Conversation, EndUser, Message from services.errors.conversation import ( ConversationNotExistsError, ConversationVariableNotExistsError, + ConversationVariableTypeMismatchError, LastConversationNotExistsError, ) from services.errors.message import MessageNotExistsError +from tasks.delete_conversation_task import delete_conversation_related_data + +logger = logging.getLogger(__name__) class ConversationService: @@ -46,10 +55,16 @@ class ConversationService: Conversation.from_account_id == (user.id if isinstance(user, Account) else None), or_(Conversation.invoke_from.is_(None), Conversation.invoke_from == invoke_from.value), ) + # Check if include_ids is not None to apply filter if include_ids is not None: + if len(include_ids) == 0: + # If include_ids is empty, return empty result + return InfiniteScrollPagination(data=[], limit=limit, has_more=False) stmt = stmt.where(Conversation.id.in_(include_ids)) + # Check if exclude_ids is not None to apply filter if exclude_ids is not None: - stmt = stmt.where(~Conversation.id.in_(exclude_ids)) + if len(exclude_ids) > 0: + stmt = stmt.where(~Conversation.id.in_(exclude_ids)) # define sort fields and directions sort_field, sort_direction = cls._get_sort_params(sort_by) @@ -93,10 +108,10 @@ class ConversationService: @classmethod def _build_filter_condition(cls, sort_field: str, sort_direction: Callable, reference_conversation: Conversation): field_value = getattr(reference_conversation, sort_field) - if sort_direction == desc: + if sort_direction is desc: return getattr(Conversation, sort_field) < field_value - else: - return getattr(Conversation, sort_field) > field_value + + return getattr(Conversation, sort_field) > field_value @classmethod def rename( @@ -132,13 +147,11 @@ class ConversationService: raise MessageNotExistsError() # generate conversation name - try: + with contextlib.suppress(Exception): name = LLMGenerator.generate_conversation_name( app_model.tenant_id, message.query, conversation.id, app_model.id ) conversation.name = name - except: - pass db.session.commit() @@ -166,11 +179,21 @@ class ConversationService: @classmethod def delete(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]): - conversation = cls.get_conversation(app_model, conversation_id, user) + try: + logger.info( + "Initiating conversation deletion for app_name %s, conversation_id: %s", + app_model.name, + conversation_id, + ) - conversation.is_deleted = True - conversation.updated_at = naive_utc_now() - db.session.commit() + db.session.query(Conversation).where(Conversation.id == conversation_id).delete(synchronize_session=False) + db.session.commit() + + delete_conversation_related_data.delay(conversation_id) + + except Exception as e: + db.session.rollback() + raise e @classmethod def get_conversational_variable( @@ -218,3 +241,87 @@ class ConversationService: ] return InfiniteScrollPagination(variables, limit, has_more) + + @classmethod + def update_conversation_variable( + cls, + app_model: App, + conversation_id: str, + variable_id: str, + user: Optional[Union[Account, EndUser]], + new_value: Any, + ) -> dict: + """ + Update a conversation variable's value. + + Args: + app_model: The app model + conversation_id: The conversation ID + variable_id: The variable ID to update + user: The user (Account or EndUser) + new_value: The new value for the variable + + Returns: + Dictionary containing the updated variable information + + Raises: + ConversationNotExistsError: If the conversation doesn't exist + ConversationVariableNotExistsError: If the variable doesn't exist + ConversationVariableTypeMismatchError: If the new value type doesn't match the variable's expected type + """ + # Verify conversation exists and user has access + conversation = cls.get_conversation(app_model, conversation_id, user) + + # Get the existing conversation variable + stmt = ( + select(ConversationVariable) + .where(ConversationVariable.app_id == app_model.id) + .where(ConversationVariable.conversation_id == conversation.id) + .where(ConversationVariable.id == variable_id) + ) + + with Session(db.engine) as session: + existing_variable = session.scalar(stmt) + if not existing_variable: + raise ConversationVariableNotExistsError() + + # Convert existing variable to Variable object + current_variable = existing_variable.to_variable() + + # Validate that the new value type matches the expected variable type + expected_type = SegmentType(current_variable.value_type) + + # There is showing number in web ui but int in db + if expected_type == SegmentType.INTEGER: + expected_type = SegmentType.NUMBER + + if not expected_type.is_valid(new_value): + inferred_type = SegmentType.infer_segment_type(new_value) + raise ConversationVariableTypeMismatchError( + f"Type mismatch: variable '{current_variable.name}' expects {expected_type.value}, " + f"but got {inferred_type.value if inferred_type else 'unknown'} type" + ) + + # Create updated variable with new value only, preserving everything else + updated_variable_dict = { + "id": current_variable.id, + "name": current_variable.name, + "description": current_variable.description, + "value_type": current_variable.value_type, + "value": new_value, + "selector": current_variable.selector, + } + + updated_variable = variable_factory.build_conversation_variable_from_mapping(updated_variable_dict) + + # Use the conversation variable updater to persist the changes + updater = conversation_variable_updater_factory() + updater.update(conversation_id, updated_variable) + updater.flush() + + # Return the updated variable data + return { + "created_at": existing_variable.created_at, + "updated_at": naive_utc_now(), # Update timestamp + **updated_variable.model_dump(), + } diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 4872702a76..fc2cbba78b 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -6,7 +6,7 @@ import secrets import time import uuid from collections import Counter -from typing import Any, Optional +from typing import Any, Literal, Optional from flask_login import current_user from sqlalchemy import func, select @@ -51,7 +51,7 @@ from services.entities.knowledge_entities.knowledge_entities import ( RetrievalModel, SegmentUpdateArgs, ) -from services.errors.account import InvalidActionError, NoPermissionError +from services.errors.account import NoPermissionError from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexingError from services.errors.dataset import DatasetNameDuplicateError from services.errors.document import DocumentIndexingError @@ -91,14 +91,16 @@ class DatasetService: if user.current_role == TenantAccountRole.DATASET_OPERATOR: # only show datasets that the user has permission to access - if permitted_dataset_ids: + # Check if permitted_dataset_ids is not empty to avoid WHERE false condition + if permitted_dataset_ids and len(permitted_dataset_ids) > 0: query = query.where(Dataset.id.in_(permitted_dataset_ids)) else: return [], 0 else: if user.current_role != TenantAccountRole.OWNER or not include_all: # show all datasets that the user has permission to access - if permitted_dataset_ids: + # Check if permitted_dataset_ids is not empty to avoid WHERE false condition + if permitted_dataset_ids and len(permitted_dataset_ids) > 0: query = query.where( db.or_( Dataset.permission == DatasetPermissionEnum.ALL_TEAM, @@ -127,9 +129,10 @@ class DatasetService: if search: query = query.where(Dataset.name.ilike(f"%{search}%")) - if tag_ids: + # Check if tag_ids is not empty to avoid WHERE false condition + if tag_ids and len(tag_ids) > 0: target_ids = TagService.get_target_ids_by_tag_ids("knowledge", tenant_id, tag_ids) - if target_ids: + if target_ids and len(target_ids) > 0: query = query.where(Dataset.id.in_(target_ids)) else: return [], 0 @@ -158,6 +161,9 @@ class DatasetService: @staticmethod def get_datasets_by_ids(ids, tenant_id): + # Check if ids is not empty to avoid WHERE false condition + if not ids or len(ids) == 0: + return [], 0 stmt = select(Dataset).where(Dataset.id.in_(ids), Dataset.tenant_id == tenant_id) datasets = db.paginate(select=stmt, page=1, per_page=len(ids), max_per_page=len(ids), error_out=False) @@ -244,6 +250,11 @@ class DatasetService: dataset: Optional[Dataset] = db.session.query(Dataset).filter_by(id=dataset_id).first() return dataset + @staticmethod + def check_doc_form(dataset: Dataset, doc_form: str): + if dataset.doc_form and doc_form != dataset.doc_form: + raise ValueError("doc_form is different from the dataset doc_form.") + @staticmethod def check_dataset_model_setting(dataset): if dataset.indexing_technique == "high_quality": @@ -260,7 +271,7 @@ class DatasetService: "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider." ) except ProviderTokenNotInitError as ex: - raise ValueError(f"The dataset in unavailable, due to: {ex.description}") + raise ValueError(f"The dataset is unavailable, due to: {ex.description}") @staticmethod def check_embedding_model_setting(tenant_id: str, embedding_model_provider: str, embedding_model: str): @@ -364,7 +375,7 @@ class DatasetService: raise ValueError("External knowledge api id is required.") # Update metadata fields dataset.updated_by = user.id if user else None - dataset.updated_at = datetime.datetime.utcnow() + dataset.updated_at = naive_utc_now() db.session.add(dataset) # Update external knowledge binding @@ -605,8 +616,9 @@ class DatasetService: except ProviderTokenNotInitError: # If we can't get the embedding model, preserve existing settings logging.warning( - f"Failed to initialize embedding model {data['embedding_model_provider']}/{data['embedding_model']}, " - f"preserving existing settings" + "Failed to initialize embedding model %s/%s, preserving existing settings", + data["embedding_model_provider"], + data["embedding_model"], ) if dataset.embedding_model_provider and dataset.embedding_model: filtered_data["embedding_model_provider"] = dataset.embedding_model_provider @@ -649,11 +661,11 @@ class DatasetService: @staticmethod def check_dataset_permission(dataset, user): if dataset.tenant_id != user.current_tenant_id: - logging.debug(f"User {user.id} does not have permission to access dataset {dataset.id}") + logging.debug("User %s does not have permission to access dataset %s", user.id, dataset.id) raise NoPermissionError("You do not have permission to access this dataset.") if user.current_role != TenantAccountRole.OWNER: if dataset.permission == DatasetPermissionEnum.ONLY_ME and dataset.created_by != user.id: - logging.debug(f"User {user.id} does not have permission to access dataset {dataset.id}") + logging.debug("User %s does not have permission to access dataset %s", user.id, dataset.id) raise NoPermissionError("You do not have permission to access this dataset.") if dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM: # For partial team permission, user needs explicit permission or be the creator @@ -662,7 +674,7 @@ class DatasetService: db.session.query(DatasetPermission).filter_by(dataset_id=dataset.id, account_id=user.id).first() ) if not user_permission: - logging.debug(f"User {user.id} does not have permission to access dataset {dataset.id}") + logging.debug("User %s does not have permission to access dataset %s", user.id, dataset.id) raise NoPermissionError("You do not have permission to access this dataset.") @staticmethod @@ -950,6 +962,9 @@ class DocumentService: @staticmethod def delete_documents(dataset: Dataset, document_ids: list[str]): + # Check if document_ids is not empty to avoid WHERE false condition + if not document_ids or len(document_ids) == 0: + return documents = db.session.query(Document).where(Document.id.in_(document_ids)).all() file_ids = [ document.data_source_info_dict["upload_file_id"] @@ -1000,7 +1015,7 @@ class DocumentService: db.session.add(document) db.session.commit() # set document paused flag - indexing_cache_key = "document_{}_is_paused".format(document.id) + indexing_cache_key = f"document_{document.id}_is_paused" redis_client.setnx(indexing_cache_key, "True") @staticmethod @@ -1015,7 +1030,7 @@ class DocumentService: db.session.add(document) db.session.commit() # delete paused flag - indexing_cache_key = "document_{}_is_paused".format(document.id) + indexing_cache_key = f"document_{document.id}_is_paused" redis_client.delete(indexing_cache_key) # trigger async task recover_document_indexing_task.delay(document.dataset_id, document.id) @@ -1024,7 +1039,7 @@ class DocumentService: def retry_document(dataset_id: str, documents: list[Document]): for document in documents: # add retry flag - retry_indexing_cache_key = "document_{}_is_retried".format(document.id) + retry_indexing_cache_key = f"document_{document.id}_is_retried" cache_result = redis_client.get(retry_indexing_cache_key) if cache_result is not None: raise ValueError("Document is being retried, please try again later") @@ -1041,7 +1056,7 @@ class DocumentService: @staticmethod def sync_website_document(dataset_id: str, document: Document): # add sync flag - sync_indexing_cache_key = "document_{}_is_sync".format(document.id) + sync_indexing_cache_key = f"document_{document.id}_is_sync" cache_result = redis_client.get(sync_indexing_cache_key) if cache_result is not None: raise ValueError("Document is being synced, please try again later") @@ -1075,6 +1090,8 @@ class DocumentService: dataset_process_rule: Optional[DatasetProcessRule] = None, created_from: str = "web", ): + # check doc_form + DatasetService.check_doc_form(dataset, knowledge_config.doc_form) # check document limit features = FeatureService.get_features(current_user.current_tenant_id) @@ -1174,12 +1191,13 @@ class DocumentService: ) else: logging.warning( - f"Invalid process rule mode: {process_rule.mode}, can not find dataset process rule" + "Invalid process rule mode: %s, can not find dataset process rule", + process_rule.mode, ) return db.session.add(dataset_process_rule) db.session.commit() - lock_name = "add_document_lock_dataset_id_{}".format(dataset.id) + lock_name = f"add_document_lock_dataset_id_{dataset.id}" with redis_client.lock(lock_name, timeout=600): position = DocumentService.get_documents_position(dataset.id) document_ids = [] @@ -1216,7 +1234,7 @@ class DocumentService: ) if document: document.dataset_process_rule_id = dataset_process_rule.id # type: ignore - document.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + document.updated_at = naive_utc_now() document.created_from = created_from document.doc_form = knowledge_config.doc_form document.doc_language = knowledge_config.doc_language @@ -1534,7 +1552,7 @@ class DocumentService: document.parsing_completed_at = None document.cleaning_completed_at = None document.splitting_completed_at = None - document.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + document.updated_at = naive_utc_now() document.created_from = created_from document.doc_form = document_data.doc_form db.session.add(document) @@ -1789,14 +1807,16 @@ class DocumentService: raise ValueError("Process rule segmentation max_tokens is invalid") @staticmethod - def batch_update_document_status(dataset: Dataset, document_ids: list[str], action: str, user): + def batch_update_document_status( + dataset: Dataset, document_ids: list[str], action: Literal["enable", "disable", "archive", "un_archive"], user + ): """ Batch update document status. Args: dataset (Dataset): The dataset object document_ids (list[str]): List of document IDs to update - action (str): Action to perform (enable, disable, archive, un_archive) + action (Literal["enable", "disable", "archive", "un_archive"]): Action to perform user: Current user performing the action Raises: @@ -1862,7 +1882,7 @@ class DocumentService: task_func.delay(*task_args) except Exception as e: # Log the error but do not rollback the transaction - logging.exception(f"Error executing async task for document {update_info['document'].id}") + logging.exception("Error executing async task for document %s", update_info["document"].id) # don't raise the error immediately, but capture it for later propagation_error = e try: @@ -1873,15 +1893,16 @@ class DocumentService: redis_client.setex(indexing_cache_key, 600, 1) except Exception as e: # Log the error but do not rollback the transaction - logging.exception(f"Error setting cache for document {update_info['document'].id}") + logging.exception("Error setting cache for document %s", update_info["document"].id) # Raise any propagation error after all updates if propagation_error: raise propagation_error @staticmethod - def _prepare_document_status_update(document, action: str, user): - """ - Prepare document status update information. + def _prepare_document_status_update( + document: Document, action: Literal["enable", "disable", "archive", "un_archive"], user + ): + """Prepare document status update information. Args: document: Document object to update @@ -1891,7 +1912,7 @@ class DocumentService: Returns: dict: Update information or None if no update needed """ - now = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + now = naive_utc_now() if action == "enable": return DocumentService._prepare_enable_update(document, now) @@ -2001,7 +2022,7 @@ class SegmentService: ) # calc embedding use tokens tokens = embedding_model.get_text_embedding_num_tokens(texts=[content])[0] - lock_name = "add_segment_lock_document_id_{}".format(document.id) + lock_name = f"add_segment_lock_document_id_{document.id}" with redis_client.lock(lock_name, timeout=600): max_position = ( db.session.query(func.max(DocumentSegment.position)) @@ -2019,8 +2040,8 @@ class SegmentService: word_count=len(content), tokens=tokens, status="completed", - indexing_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None), - completed_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None), + indexing_at=naive_utc_now(), + completed_at=naive_utc_now(), created_by=current_user.id, ) if document.doc_form == "qa_model": @@ -2029,6 +2050,7 @@ class SegmentService: db.session.add(segment_document) # update document word count + assert document.word_count is not None document.word_count += segment_document.word_count db.session.add(document) db.session.commit() @@ -2039,7 +2061,7 @@ class SegmentService: except Exception as e: logging.exception("create segment index failed") segment_document.enabled = False - segment_document.disabled_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + segment_document.disabled_at = naive_utc_now() segment_document.status = "error" segment_document.error = str(e) db.session.commit() @@ -2048,7 +2070,7 @@ class SegmentService: @classmethod def multi_create_segment(cls, segments: list, document: Document, dataset: Dataset): - lock_name = "multi_add_segment_lock_document_id_{}".format(document.id) + lock_name = f"multi_add_segment_lock_document_id_{document.id}" increment_word_count = 0 with redis_client.lock(lock_name, timeout=600): embedding_model = None @@ -2095,8 +2117,8 @@ class SegmentService: tokens=tokens, keywords=segment_item.get("keywords", []), status="completed", - indexing_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None), - completed_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None), + indexing_at=naive_utc_now(), + completed_at=naive_utc_now(), created_by=current_user.id, ) if document.doc_form == "qa_model": @@ -2113,6 +2135,7 @@ class SegmentService: else: keywords_list.append(None) # update document word count + assert document.word_count is not None document.word_count += increment_word_count db.session.add(document) try: @@ -2122,7 +2145,7 @@ class SegmentService: logging.exception("create segment index failed") for segment_document in segment_data_list: segment_document.enabled = False - segment_document.disabled_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + segment_document.disabled_at = naive_utc_now() segment_document.status = "error" segment_document.error = str(e) db.session.commit() @@ -2130,7 +2153,7 @@ class SegmentService: @classmethod def update_segment(cls, args: SegmentUpdateArgs, segment: DocumentSegment, document: Document, dataset: Dataset): - indexing_cache_key = "segment_{}_indexing".format(segment.id) + indexing_cache_key = f"segment_{segment.id}_indexing" cache_result = redis_client.get(indexing_cache_key) if cache_result is not None: raise ValueError("Segment is indexing, please try again later") @@ -2139,7 +2162,7 @@ class SegmentService: if segment.enabled != action: if not action: segment.enabled = action - segment.disabled_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + segment.disabled_at = naive_utc_now() segment.disabled_by = current_user.id db.session.add(segment) db.session.commit() @@ -2174,6 +2197,7 @@ class SegmentService: db.session.commit() # update document word count if word_count_change != 0: + assert document.word_count is not None document.word_count = max(0, document.word_count + word_count_change) db.session.add(document) # update segment index task @@ -2236,10 +2260,10 @@ class SegmentService: segment.word_count = len(content) segment.tokens = tokens segment.status = "completed" - segment.indexing_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) - segment.completed_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + segment.indexing_at = naive_utc_now() + segment.completed_at = naive_utc_now() segment.updated_by = current_user.id - segment.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + segment.updated_at = naive_utc_now() segment.enabled = True segment.disabled_at = None segment.disabled_by = None @@ -2249,6 +2273,7 @@ class SegmentService: word_count_change = segment.word_count - word_count_change # update document word count if word_count_change != 0: + assert document.word_count is not None document.word_count = max(0, document.word_count + word_count_change) db.session.add(document) db.session.add(segment) @@ -2291,7 +2316,7 @@ class SegmentService: except Exception as e: logging.exception("update segment index failed") segment.enabled = False - segment.disabled_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + segment.disabled_at = naive_utc_now() segment.status = "error" segment.error = str(e) db.session.commit() @@ -2300,7 +2325,7 @@ class SegmentService: @classmethod def delete_segment(cls, segment: DocumentSegment, document: Document, dataset: Dataset): - indexing_cache_key = "segment_{}_delete_indexing".format(segment.id) + indexing_cache_key = f"segment_{segment.id}_delete_indexing" cache_result = redis_client.get(indexing_cache_key) if cache_result is not None: raise ValueError("Segment is deleting.") @@ -2312,16 +2337,16 @@ class SegmentService: delete_segment_from_index_task.delay([segment.index_node_id], dataset.id, document.id) db.session.delete(segment) # update document word count + assert document.word_count is not None document.word_count -= segment.word_count db.session.add(document) db.session.commit() @classmethod def delete_segments(cls, segment_ids: list, document: Document, dataset: Dataset): - index_node_ids = ( - db.session.query(DocumentSegment) - .with_entities(DocumentSegment.index_node_id) - .where( + segments = ( + db.session.query(DocumentSegment.index_node_id, DocumentSegment.word_count) + .filter( DocumentSegment.id.in_(segment_ids), DocumentSegment.dataset_id == dataset.id, DocumentSegment.document_id == document.id, @@ -2329,14 +2354,27 @@ class SegmentService: ) .all() ) - index_node_ids = [index_node_id[0] for index_node_id in index_node_ids] + + if not segments: + return + + index_node_ids = [seg.index_node_id for seg in segments] + total_words = sum(seg.word_count for seg in segments) + + document.word_count -= total_words + db.session.add(document) delete_segment_from_index_task.delay(index_node_ids, dataset.id, document.id) db.session.query(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)).delete() db.session.commit() @classmethod - def update_segments_status(cls, segment_ids: list, action: str, dataset: Dataset, document: Document): + def update_segments_status( + cls, segment_ids: list, action: Literal["enable", "disable"], dataset: Dataset, document: Document + ): + # Check if segment_ids is not empty to avoid WHERE false condition + if not segment_ids or len(segment_ids) == 0: + return if action == "enable": segments = ( db.session.query(DocumentSegment) @@ -2350,9 +2388,9 @@ class SegmentService: ) if not segments: return - real_deal_segmment_ids = [] + real_deal_segment_ids = [] for segment in segments: - indexing_cache_key = "segment_{}_indexing".format(segment.id) + indexing_cache_key = f"segment_{segment.id}_indexing" cache_result = redis_client.get(indexing_cache_key) if cache_result is not None: continue @@ -2360,10 +2398,10 @@ class SegmentService: segment.disabled_at = None segment.disabled_by = None db.session.add(segment) - real_deal_segmment_ids.append(segment.id) + real_deal_segment_ids.append(segment.id) db.session.commit() - enable_segments_to_index_task.delay(real_deal_segmment_ids, dataset.id, document.id) + enable_segments_to_index_task.delay(real_deal_segment_ids, dataset.id, document.id) elif action == "disable": segments = ( db.session.query(DocumentSegment) @@ -2377,28 +2415,26 @@ class SegmentService: ) if not segments: return - real_deal_segmment_ids = [] + real_deal_segment_ids = [] for segment in segments: - indexing_cache_key = "segment_{}_indexing".format(segment.id) + indexing_cache_key = f"segment_{segment.id}_indexing" cache_result = redis_client.get(indexing_cache_key) if cache_result is not None: continue segment.enabled = False - segment.disabled_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + segment.disabled_at = naive_utc_now() segment.disabled_by = current_user.id db.session.add(segment) - real_deal_segmment_ids.append(segment.id) + real_deal_segment_ids.append(segment.id) db.session.commit() - disable_segments_from_index_task.delay(real_deal_segmment_ids, dataset.id, document.id) - else: - raise InvalidActionError() + disable_segments_from_index_task.delay(real_deal_segment_ids, dataset.id, document.id) @classmethod def create_child_chunk( cls, content: str, segment: DocumentSegment, document: Document, dataset: Dataset ) -> ChildChunk: - lock_name = "add_child_lock_{}".format(segment.id) + lock_name = f"add_child_lock_{segment.id}" with redis_client.lock(lock_name, timeout=20): index_node_id = str(uuid.uuid4()) index_node_hash = helper.generate_text_hash(content) @@ -2476,7 +2512,7 @@ class SegmentService: child_chunk.content = child_chunk_update_args.content child_chunk.word_count = len(child_chunk.content) child_chunk.updated_by = current_user.id - child_chunk.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + child_chunk.updated_at = naive_utc_now() child_chunk.type = "customized" update_child_chunks.append(child_chunk) else: @@ -2533,7 +2569,7 @@ class SegmentService: child_chunk.content = content child_chunk.word_count = len(content) child_chunk.updated_by = current_user.id - child_chunk.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + child_chunk.updated_at = naive_utc_now() child_chunk.type = "customized" db.session.add(child_chunk) VectorService.update_child_chunk_vector([], [child_chunk], [], dataset) @@ -2598,7 +2634,8 @@ class SegmentService: DocumentSegment.document_id == document_id, DocumentSegment.tenant_id == tenant_id ) - if status_list: + # Check if status_list is not empty to avoid WHERE false condition + if status_list and len(status_list) > 0: query = query.where(DocumentSegment.status.in_(status_list)) if keyword: @@ -2647,7 +2684,7 @@ class SegmentService: # check segment segment = ( db.session.query(DocumentSegment) - .where(DocumentSegment.id == segment_id, DocumentSegment.tenant_id == user_id) + .where(DocumentSegment.id == segment_id, DocumentSegment.tenant_id == tenant_id) .first() ) if not segment: diff --git a/api/services/enterprise/enterprise_service.py b/api/services/enterprise/enterprise_service.py index 54d45f45ea..f8612456d6 100644 --- a/api/services/enterprise/enterprise_service.py +++ b/api/services/enterprise/enterprise_service.py @@ -52,6 +52,16 @@ class EnterpriseService: return data.get("result", False) + @classmethod + def batch_is_user_allowed_to_access_webapps(cls, user_id: str, app_codes: list[str]): + if not app_codes: + return {} + body = {"userId": user_id, "appCodes": app_codes} + data = EnterpriseRequest.send_request("POST", "/webapp/permission/batch", json=body) + if not data: + raise ValueError("No data found.") + return data.get("permissions", {}) + @classmethod def get_app_access_mode_by_id(cls, app_id: str) -> WebAppSettings: if not app_id: diff --git a/api/services/enterprise/mail_service.py b/api/services/enterprise/mail_service.py deleted file mode 100644 index 630e7679ac..0000000000 --- a/api/services/enterprise/mail_service.py +++ /dev/null @@ -1,18 +0,0 @@ -from pydantic import BaseModel - -from tasks.mail_enterprise_task import send_enterprise_email_task - - -class DifyMail(BaseModel): - to: list[str] - subject: str - body: str - substitutions: dict[str, str] = {} - - -class EnterpriseMailService: - @classmethod - def send_mail(cls, mail: DifyMail): - send_enterprise_email_task.delay( - to=mail.to, subject=mail.subject, body=mail.body, substitutions=mail.substitutions - ) diff --git a/api/services/errors/app.py b/api/services/errors/app.py index 5d348c61be..390716a47f 100644 --- a/api/services/errors/app.py +++ b/api/services/errors/app.py @@ -8,3 +8,11 @@ class WorkflowHashNotEqualError(Exception): class IsDraftWorkflowError(Exception): pass + + +class WorkflowNotFoundError(Exception): + pass + + +class WorkflowIdFormatError(Exception): + pass diff --git a/api/services/errors/conversation.py b/api/services/errors/conversation.py index f8051e3417..a123f99b59 100644 --- a/api/services/errors/conversation.py +++ b/api/services/errors/conversation.py @@ -15,3 +15,7 @@ class ConversationCompletedError(Exception): class ConversationVariableNotExistsError(BaseServiceError): pass + + +class ConversationVariableTypeMismatchError(BaseServiceError): + pass diff --git a/api/services/external_knowledge_service.py b/api/services/external_knowledge_service.py index b7af03e91f..2f1babba6f 100644 --- a/api/services/external_knowledge_service.py +++ b/api/services/external_knowledge_service.py @@ -46,9 +46,9 @@ class ExternalDatasetService: def validate_api_list(cls, api_settings: dict): if not api_settings: raise ValueError("api list is empty") - if "endpoint" not in api_settings and not api_settings["endpoint"]: + if not api_settings.get("endpoint"): raise ValueError("endpoint is required") - if "api_key" not in api_settings and not api_settings["api_key"]: + if not api_settings.get("api_key"): raise ValueError("api_key is required") @staticmethod diff --git a/api/services/file_service.py b/api/services/file_service.py index e234c2f325..4c0a0f451c 100644 --- a/api/services/file_service.py +++ b/api/services/file_service.py @@ -1,4 +1,3 @@ -import datetime import hashlib import os import uuid @@ -18,6 +17,7 @@ from core.file import helpers as file_helpers from core.rag.extractor.extract_processor import ExtractProcessor from extensions.ext_database import db from extensions.ext_storage import storage +from libs.datetime_utils import naive_utc_now from libs.helper import extract_tenant_id from models.account import Account from models.enums import CreatorUserRole @@ -80,7 +80,7 @@ class FileService: mime_type=mimetype, created_by_role=(CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER), created_by=user.id, - created_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None), + created_at=naive_utc_now(), used=False, hash=hashlib.sha3_256(content).hexdigest(), source_url=source_url, @@ -131,10 +131,10 @@ class FileService: mime_type="text/plain", created_by=current_user.id, created_by_role=CreatorUserRole.ACCOUNT, - created_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None), + created_at=naive_utc_now(), used=True, used_by=current_user.id, - used_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None), + used_at=naive_utc_now(), ) db.session.add(upload_file) diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py index 519d5abca5..5a3f504035 100644 --- a/api/services/hit_testing_service.py +++ b/api/services/hit_testing_service.py @@ -77,7 +77,7 @@ class HitTestingService: ) end = time.perf_counter() - logging.debug(f"Hit testing retrieve in {end - start:0.4f} seconds") + logging.debug("Hit testing retrieve in %s seconds", end - start) dataset_query = DatasetQuery( dataset_id=dataset.id, content=query, source="hit_testing", created_by_role="account", created_by=account.id @@ -113,7 +113,7 @@ class HitTestingService: ) end = time.perf_counter() - logging.debug(f"External knowledge hit testing retrieve in {end - start:0.4f} seconds") + logging.debug("External knowledge hit testing retrieve in %s seconds", end - start) dataset_query = DatasetQuery( dataset_id=dataset.id, content=query, source="hit_testing", created_by_role="account", created_by=account.id diff --git a/api/services/message_service.py b/api/services/message_service.py index 283b7b9b4b..a19d6ee157 100644 --- a/api/services/message_service.py +++ b/api/services/message_service.py @@ -111,7 +111,8 @@ class MessageService: base_query = base_query.where(Message.conversation_id == conversation.id) - if include_ids is not None: + # Check if include_ids is not None and not empty to avoid WHERE false condition + if include_ids is not None and len(include_ids) > 0: base_query = base_query.where(Message.id.in_(include_ids)) if last_id: diff --git a/api/services/metadata_service.py b/api/services/metadata_service.py index cfcb121153..fd222f59d3 100644 --- a/api/services/metadata_service.py +++ b/api/services/metadata_service.py @@ -1,5 +1,4 @@ import copy -import datetime import logging from typing import Optional @@ -8,6 +7,7 @@ from flask_login import current_user from core.rag.index_processor.constant.built_in_field import BuiltInField, MetadataDataSource from extensions.ext_database import db from extensions.ext_redis import redis_client +from libs.datetime_utils import naive_utc_now from models.dataset import Dataset, DatasetMetadata, DatasetMetadataBinding from services.dataset_service import DocumentService from services.entities.knowledge_entities.knowledge_entities import ( @@ -69,7 +69,7 @@ class MetadataService: old_name = metadata.name metadata.name = name metadata.updated_by = current_user.id - metadata.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + metadata.updated_at = naive_utc_now() # update related documents dataset_metadata_bindings = ( @@ -79,7 +79,10 @@ class MetadataService: document_ids = [binding.document_id for binding in dataset_metadata_bindings] documents = DocumentService.get_document_by_ids(document_ids) for document in documents: - doc_metadata = copy.deepcopy(document.doc_metadata) + if not document.doc_metadata: + doc_metadata = {} + else: + doc_metadata = copy.deepcopy(document.doc_metadata) value = doc_metadata.pop(old_name, None) doc_metadata[name] = value document.doc_metadata = doc_metadata @@ -109,7 +112,10 @@ class MetadataService: document_ids = [binding.document_id for binding in dataset_metadata_bindings] documents = DocumentService.get_document_by_ids(document_ids) for document in documents: - doc_metadata = copy.deepcopy(document.doc_metadata) + if not document.doc_metadata: + doc_metadata = {} + else: + doc_metadata = copy.deepcopy(document.doc_metadata) doc_metadata.pop(metadata.name, None) document.doc_metadata = doc_metadata db.session.add(document) @@ -137,7 +143,6 @@ class MetadataService: lock_key = f"dataset_metadata_lock_{dataset.id}" try: MetadataService.knowledge_base_metadata_lock_check(dataset.id, None) - dataset.built_in_field_enabled = True db.session.add(dataset) documents = DocumentService.get_working_documents_by_dataset_id(dataset.id) if documents: @@ -153,6 +158,7 @@ class MetadataService: doc_metadata[BuiltInField.source.value] = MetadataDataSource[document.data_source_type].value document.doc_metadata = doc_metadata db.session.add(document) + dataset.built_in_field_enabled = True db.session.commit() except Exception: logging.exception("Enable built-in field failed") @@ -166,13 +172,15 @@ class MetadataService: lock_key = f"dataset_metadata_lock_{dataset.id}" try: MetadataService.knowledge_base_metadata_lock_check(dataset.id, None) - dataset.built_in_field_enabled = False db.session.add(dataset) documents = DocumentService.get_working_documents_by_dataset_id(dataset.id) document_ids = [] if documents: for document in documents: - doc_metadata = copy.deepcopy(document.doc_metadata) + if not document.doc_metadata: + doc_metadata = {} + else: + doc_metadata = copy.deepcopy(document.doc_metadata) doc_metadata.pop(BuiltInField.document_name.value, None) doc_metadata.pop(BuiltInField.uploader.value, None) doc_metadata.pop(BuiltInField.upload_date.value, None) @@ -181,6 +189,7 @@ class MetadataService: document.doc_metadata = doc_metadata db.session.add(document) document_ids.append(document.id) + dataset.built_in_field_enabled = False db.session.commit() except Exception: logging.exception("Disable built-in field failed") diff --git a/api/services/model_load_balancing_service.py b/api/services/model_load_balancing_service.py index a200cfa146..f8dd70c790 100644 --- a/api/services/model_load_balancing_service.py +++ b/api/services/model_load_balancing_service.py @@ -1,4 +1,3 @@ -import datetime import json import logging from json import JSONDecodeError @@ -17,6 +16,7 @@ from core.model_runtime.entities.provider_entities import ( from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from core.provider_manager import ProviderManager from extensions.ext_database import db +from libs.datetime_utils import naive_utc_now from models.provider import LoadBalancingModelConfig logger = logging.getLogger(__name__) @@ -340,7 +340,7 @@ class ModelLoadBalancingService: config_id = str(config_id) if config_id not in current_load_balancing_configs_dict: - raise ValueError("Invalid load balancing config id: {}".format(config_id)) + raise ValueError(f"Invalid load balancing config id: {config_id}") updated_config_ids.add(config_id) @@ -349,7 +349,7 @@ class ModelLoadBalancingService: # check duplicate name for current_load_balancing_config in current_load_balancing_configs: if current_load_balancing_config.id != config_id and current_load_balancing_config.name == name: - raise ValueError("Load balancing config name {} already exists".format(name)) + raise ValueError(f"Load balancing config name {name} already exists") if credentials: if not isinstance(credentials, dict): @@ -371,7 +371,7 @@ class ModelLoadBalancingService: load_balancing_config.name = name load_balancing_config.enabled = enabled - load_balancing_config.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + load_balancing_config.updated_at = naive_utc_now() db.session.commit() self._clear_credentials_cache(tenant_id, config_id) @@ -383,7 +383,7 @@ class ModelLoadBalancingService: # check duplicate name for current_load_balancing_config in current_load_balancing_configs: if current_load_balancing_config.name == name: - raise ValueError("Load balancing config name {} already exists".format(name)) + raise ValueError(f"Load balancing config name {name} already exists") if not credentials: raise ValueError("Invalid load balancing config credentials") diff --git a/api/services/model_provider_service.py b/api/services/model_provider_service.py index 0a0a5619e1..54197bf949 100644 --- a/api/services/model_provider_service.py +++ b/api/services/model_provider_service.py @@ -380,7 +380,7 @@ class ModelProviderService: else None ) except Exception as e: - logger.debug(f"get_default_model_of_model_type error: {e}") + logger.debug("get_default_model_of_model_type error: %s", e) return None def update_default_model_of_model_type(self, tenant_id: str, model_type: str, provider: str, model: str) -> None: diff --git a/api/services/ops_service.py b/api/services/ops_service.py index 62f37c1588..7a9db7273e 100644 --- a/api/services/ops_service.py +++ b/api/services/ops_service.py @@ -65,9 +65,7 @@ class OpsService: } ) except Exception: - new_decrypt_tracing_config.update( - {"project_url": "{host}/".format(host=decrypt_tracing_config.get("host"))} - ) + new_decrypt_tracing_config.update({"project_url": f"{decrypt_tracing_config.get('host')}/"}) if tracing_provider == "langsmith" and ( "project_url" not in decrypt_tracing_config or not decrypt_tracing_config.get("project_url") @@ -139,7 +137,7 @@ class OpsService: project_url = OpsTraceManager.get_trace_config_project_url(tracing_config, tracing_provider) elif tracing_provider == "langfuse": project_key = OpsTraceManager.get_trace_config_project_key(tracing_config, tracing_provider) - project_url = "{host}/project/{key}".format(host=tracing_config.get("host"), key=project_key) + project_url = f"{tracing_config.get('host')}/project/{project_key}" elif tracing_provider in ("langsmith", "opik"): project_url = OpsTraceManager.get_trace_config_project_url(tracing_config, tracing_provider) else: diff --git a/api/services/plugin/data_migration.py b/api/services/plugin/data_migration.py index 5324036414..c5ad65ec87 100644 --- a/api/services/plugin/data_migration.py +++ b/api/services/plugin/data_migration.py @@ -2,6 +2,7 @@ import json import logging import click +import sqlalchemy as sa from core.plugin.entities.plugin import GenericProviderID, ModelProviderID, ToolProviderID from models.engine import db @@ -38,7 +39,7 @@ class PluginDataMigration: where {provider_column_name} not like '%/%' and {provider_column_name} is not null and {provider_column_name} != '' limit 1000""" with db.engine.begin() as conn: - rs = conn.execute(db.text(sql)) + rs = conn.execute(sa.text(sql)) current_iter_count = 0 for i in rs: @@ -94,7 +95,7 @@ limit 1000""" :provider_name {update_retrieval_model_sql} where id = :record_id""" - conn.execute(db.text(sql), params) + conn.execute(sa.text(sql), params) click.echo( click.style( f"[{processed_count}] Migrated [{table_name}] {record_id} ({provider_name})", @@ -110,7 +111,7 @@ limit 1000""" ) ) logger.exception( - f"[{processed_count}] Failed to migrate [{table_name}] {record_id} ({provider_name})" + "[%s] Failed to migrate [%s] %s (%s)", processed_count, table_name, record_id, provider_name ) continue @@ -148,7 +149,7 @@ limit 1000""" params = {"last_id": last_id or ""} with db.engine.begin() as conn: - rs = conn.execute(db.text(sql), params) + rs = conn.execute(sa.text(sql), params) current_iter_count = 0 batch_updates = [] @@ -183,7 +184,7 @@ limit 1000""" ) ) logger.exception( - f"[{processed_count}] Failed to migrate [{table_name}] {record_id} ({provider_name})" + "[%s] Failed to migrate [%s] %s (%s)", processed_count, table_name, record_id, provider_name ) continue @@ -193,7 +194,7 @@ limit 1000""" SET {provider_column_name} = :updated_value WHERE id = :record_id """ - conn.execute(db.text(update_sql), [{"updated_value": u, "record_id": r} for u, r in batch_updates]) + conn.execute(sa.text(update_sql), [{"updated_value": u, "record_id": r} for u, r in batch_updates]) click.echo( click.style( f"[{processed_count}] Batch migrated [{len(batch_updates)}] records from [{table_name}]", diff --git a/api/services/plugin/oauth_service.py b/api/services/plugin/oauth_service.py index b84dd0afc5..055fbb8138 100644 --- a/api/services/plugin/oauth_service.py +++ b/api/services/plugin/oauth_service.py @@ -47,7 +47,9 @@ class OAuthProxyService(BasePluginClient): if not context_id: raise ValueError("context_id is required") # get data from redis - data = redis_client.getdel(f"{OAuthProxyService.__KEY_PREFIX__}{context_id}") + key = f"{OAuthProxyService.__KEY_PREFIX__}{context_id}" + data = redis_client.get(key) if not data: raise ValueError("context_id is invalid") + redis_client.delete(key) return json.loads(data) diff --git a/api/services/plugin/plugin_migration.py b/api/services/plugin/plugin_migration.py index 1806fbcfd6..221069b2b3 100644 --- a/api/services/plugin/plugin_migration.py +++ b/api/services/plugin/plugin_migration.py @@ -9,6 +9,7 @@ from typing import Any, Optional from uuid import uuid4 import click +import sqlalchemy as sa import tqdm from flask import Flask, current_app from sqlalchemy.orm import Session @@ -78,7 +79,7 @@ class PluginMigration: ) ) except Exception: - logger.exception(f"Failed to process tenant {tenant_id}") + logger.exception("Failed to process tenant %s", tenant_id) futures = [] @@ -136,7 +137,7 @@ class PluginMigration: try: tenants.append(tenant_id) except Exception: - logger.exception(f"Failed to process tenant {tenant_id}") + logger.exception("Failed to process tenant %s", tenant_id) continue futures.append( @@ -197,7 +198,7 @@ class PluginMigration: """ with Session(db.engine) as session: rs = session.execute( - db.text(f"SELECT DISTINCT {column} FROM {table} WHERE tenant_id = :tenant_id"), {"tenant_id": tenant_id} + sa.text(f"SELECT DISTINCT {column} FROM {table} WHERE tenant_id = :tenant_id"), {"tenant_id": tenant_id} ) result = [] for row in rs: @@ -273,7 +274,7 @@ class PluginMigration: result.append(ToolProviderID(tool_entity.provider_id).plugin_id) except Exception: - logger.exception(f"Failed to process tool {tool}") + logger.exception("Failed to process tool %s", tool) continue return result @@ -301,7 +302,7 @@ class PluginMigration: plugins: dict[str, str] = {} plugin_ids = [] plugin_not_exist = [] - logger.info(f"Extracting unique plugins from {extracted_plugins}") + logger.info("Extracting unique plugins from %s", extracted_plugins) with open(extracted_plugins) as f: for line in f: data = json.loads(line) @@ -318,7 +319,7 @@ class PluginMigration: else: plugin_not_exist.append(plugin_id) except Exception: - logger.exception(f"Failed to fetch plugin unique identifier for {plugin_id}") + logger.exception("Failed to fetch plugin unique identifier for %s", plugin_id) plugin_not_exist.append(plugin_id) with ThreadPoolExecutor(max_workers=10) as executor: @@ -339,7 +340,7 @@ class PluginMigration: # use a fake tenant id to install all the plugins fake_tenant_id = uuid4().hex - logger.info(f"Installing {len(plugins['plugins'])} plugin instances for fake tenant {fake_tenant_id}") + logger.info("Installing %s plugin instances for fake tenant %s", len(plugins["plugins"]), fake_tenant_id) thread_pool = ThreadPoolExecutor(max_workers=workers) @@ -348,7 +349,7 @@ class PluginMigration: plugin_install_failed.extend(response.get("failed", [])) def install(tenant_id: str, plugin_ids: list[str]) -> None: - logger.info(f"Installing {len(plugin_ids)} plugins for tenant {tenant_id}") + logger.info("Installing %s plugins for tenant %s", len(plugin_ids), tenant_id) # fetch plugin already installed installed_plugins = manager.list_plugins(tenant_id) installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins] @@ -408,7 +409,7 @@ class PluginMigration: installation = manager.list_plugins(fake_tenant_id) except Exception: - logger.exception(f"Failed to get installation for tenant {fake_tenant_id}") + logger.exception("Failed to get installation for tenant %s", fake_tenant_id) Path(output_file).write_text( json.dumps( @@ -491,7 +492,9 @@ class PluginMigration: else: failed.append(reverse_map[plugin.plugin_unique_identifier]) logger.error( - f"Failed to install plugin {plugin.plugin_unique_identifier}, error: {plugin.message}" + "Failed to install plugin %s, error: %s", + plugin.plugin_unique_identifier, + plugin.message, ) done = True diff --git a/api/services/recommend_app/remote/remote_retrieval.py b/api/services/recommend_app/remote/remote_retrieval.py index 80e1aefc01..85f3a02825 100644 --- a/api/services/recommend_app/remote/remote_retrieval.py +++ b/api/services/recommend_app/remote/remote_retrieval.py @@ -20,7 +20,7 @@ class RemoteRecommendAppRetrieval(RecommendAppRetrievalBase): try: result = self.fetch_recommended_app_detail_from_dify_official(app_id) except Exception as e: - logger.warning(f"fetch recommended app detail from dify official failed: {e}, switch to built-in.") + logger.warning("fetch recommended app detail from dify official failed: %s, switch to built-in.", e) result = BuildInRecommendAppRetrieval.fetch_recommended_app_detail_from_builtin(app_id) return result @@ -28,7 +28,7 @@ class RemoteRecommendAppRetrieval(RecommendAppRetrievalBase): try: result = self.fetch_recommended_apps_from_dify_official(language) except Exception as e: - logger.warning(f"fetch recommended apps from dify official failed: {e}, switch to built-in.") + logger.warning("fetch recommended apps from dify official failed: %s, switch to built-in.", e) result = BuildInRecommendAppRetrieval.fetch_recommended_apps_from_builtin(language) return result diff --git a/api/services/tag_service.py b/api/services/tag_service.py index 75fa52a75c..2e5e96214b 100644 --- a/api/services/tag_service.py +++ b/api/services/tag_service.py @@ -26,6 +26,9 @@ class TagService: @staticmethod def get_target_ids_by_tag_ids(tag_type: str, current_tenant_id: str, tag_ids: list) -> list: + # Check if tag_ids is not empty to avoid WHERE false condition + if not tag_ids or len(tag_ids) == 0: + return [] tags = ( db.session.query(Tag) .where(Tag.id.in_(tag_ids), Tag.tenant_id == current_tenant_id, Tag.type == tag_type) @@ -34,6 +37,9 @@ class TagService: if not tags: return [] tag_ids = [tag.id for tag in tags] + # Check if tag_ids is not empty to avoid WHERE false condition + if not tag_ids or len(tag_ids) == 0: + return [] tag_bindings = ( db.session.query(TagBinding.target_id) .where(TagBinding.tag_id.in_(tag_ids), TagBinding.tenant_id == current_tenant_id) diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index 65f05d2986..da0fc58566 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -337,7 +337,7 @@ class BuiltinToolManageService: max_number = max(numbers) return f"{default_pattern} {max_number + 1}" except Exception as e: - logger.warning(f"Error generating next provider name for {provider}: {str(e)}") + logger.warning("Error generating next provider name for %s: %s", provider, str(e)) # fallback return f"{credential_type.get_name()} 1" @@ -508,10 +508,10 @@ class BuiltinToolManageService: oauth_params = encrypter.decrypt(user_client.oauth_params) return oauth_params - # only verified provider can use custom oauth client - is_verified = not isinstance(provider, PluginToolProviderController) or PluginService.is_plugin_verified( - tenant_id, provider.plugin_unique_identifier - ) + # only verified provider can use official oauth client + is_verified = not isinstance( + provider_controller, PluginToolProviderController + ) or PluginService.is_plugin_verified(tenant_id, provider_controller.plugin_unique_identifier) if not is_verified: return oauth_params diff --git a/api/services/tools/mcp_tools_manage_service.py b/api/services/tools/mcp_tools_manage_service.py index 23be449a5a..f45c931768 100644 --- a/api/services/tools/mcp_tools_manage_service.py +++ b/api/services/tools/mcp_tools_manage_service.py @@ -59,6 +59,8 @@ class MCPToolManageService: icon_type: str, icon_background: str, server_identifier: str, + timeout: float, + sse_read_timeout: float, ) -> ToolProviderApiEntity: server_url_hash = hashlib.sha256(server_url.encode()).hexdigest() existing_provider = ( @@ -91,6 +93,8 @@ class MCPToolManageService: tools="[]", icon=json.dumps({"content": icon, "background": icon_background}) if icon_type == "emoji" else icon, server_identifier=server_identifier, + timeout=timeout, + sse_read_timeout=sse_read_timeout, ) db.session.add(mcp_tool) db.session.commit() @@ -166,6 +170,8 @@ class MCPToolManageService: icon_type: str, icon_background: str, server_identifier: str, + timeout: float | None = None, + sse_read_timeout: float | None = None, ): mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id) @@ -197,6 +203,10 @@ class MCPToolManageService: mcp_provider.tools = reconnect_result["tools"] mcp_provider.encrypted_credentials = reconnect_result["encrypted_credentials"] + if timeout is not None: + mcp_provider.timeout = timeout + if sse_read_timeout is not None: + mcp_provider.sse_read_timeout = sse_read_timeout db.session.commit() except IntegrityError as e: db.session.rollback() diff --git a/api/services/tools/tools_manage_service.py b/api/services/tools/tools_manage_service.py index 59d5b50e23..f245dd7527 100644 --- a/api/services/tools/tools_manage_service.py +++ b/api/services/tools/tools_manage_service.py @@ -1,4 +1,5 @@ import logging +from typing import Optional from core.tools.entities.api_entities import ToolProviderTypeApiLiteral from core.tools.tool_manager import ToolManager @@ -9,7 +10,7 @@ logger = logging.getLogger(__name__) class ToolCommonService: @staticmethod - def list_tool_providers(user_id: str, tenant_id: str, typ: ToolProviderTypeApiLiteral = None): + def list_tool_providers(user_id: str, tenant_id: str, typ: Optional[ToolProviderTypeApiLiteral] = None): """ list tool providers diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py index 2d192e6f7f..52fbc0979c 100644 --- a/api/services/tools/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -275,7 +275,7 @@ class ToolTransformService: username = user.name except Exception: - logger.exception(f"failed to get user name for api provider {db_provider.id}") + logger.exception("failed to get user name for api provider %s", db_provider.id) # add provider into providers credentials = db_provider.credentials result = ToolProviderApiEntity( diff --git a/api/services/webapp_auth_service.py b/api/services/webapp_auth_service.py index a9df8d0d73..8d21335c86 100644 --- a/api/services/webapp_auth_service.py +++ b/api/services/webapp_auth_service.py @@ -63,7 +63,7 @@ class WebAppAuthService: @classmethod def send_email_code_login_email( - cls, account: Optional[Account] = None, email: Optional[str] = None, language: Optional[str] = "en-US" + cls, account: Optional[Account] = None, email: Optional[str] = None, language: str = "en-US" ): email = account.email if account else email if email is None: diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index abf6824d73..00b02f8091 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -185,7 +185,7 @@ class WorkflowConverter: tenant_id=app_model.tenant_id, app_id=app_model.id, type=WorkflowType.from_app_mode(new_app_mode).value, - version="draft", + version=Workflow.VERSION_DRAFT, graph=json.dumps(graph), features=json.dumps(features), created_by=account_id, @@ -402,7 +402,7 @@ class WorkflowConverter: ) role_prefix = None - prompts: Any = None + prompts: Optional[Any] = None # Chat Model if model_config.mode == LLMMode.CHAT.value: diff --git a/api/services/workflow_draft_variable_service.py b/api/services/workflow_draft_variable_service.py index 3164e010b4..9f01bcb668 100644 --- a/api/services/workflow_draft_variable_service.py +++ b/api/services/workflow_draft_variable_service.py @@ -1,5 +1,4 @@ import dataclasses -import datetime import logging from collections.abc import Mapping, Sequence from enum import StrEnum @@ -13,7 +12,7 @@ from sqlalchemy.sql.expression import and_, or_ from core.app.entities.app_invoke_entities import InvokeFrom from core.file.models import File from core.variables import Segment, StringSegment, Variable -from core.variables.consts import MIN_SELECTORS_LENGTH +from core.variables.consts import SELECTORS_LENGTH from core.variables.segments import ArrayFileSegment, FileSegment from core.variables.types import SegmentType from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID @@ -23,6 +22,7 @@ from core.workflow.nodes.variable_assigner.common.helpers import get_updated_var from core.workflow.variable_loader import VariableLoader from factories.file_factory import StorageKeyLoader from factories.variable_factory import build_segment, segment_to_variable +from libs.datetime_utils import naive_utc_now from models import App, Conversation from models.enums import DraftVariableType from models.workflow import Workflow, WorkflowDraftVariable, is_system_variable_editable @@ -147,7 +147,7 @@ class WorkflowDraftVariableService: ) -> list[WorkflowDraftVariable]: ors = [] for selector in selectors: - assert len(selector) >= MIN_SELECTORS_LENGTH, f"Invalid selector to get: {selector}" + assert len(selector) >= SELECTORS_LENGTH, f"Invalid selector to get: {selector}" node_id, name = selector[:2] ors.append(and_(WorkflowDraftVariable.node_id == node_id, WorkflowDraftVariable.name == name)) @@ -231,7 +231,7 @@ class WorkflowDraftVariableService: variable.set_name(name) if value is not None: variable.set_value(value) - variable.last_edited_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + variable.last_edited_at = naive_utc_now() self._session.flush() return variable @@ -256,7 +256,7 @@ class WorkflowDraftVariableService: def _reset_node_var_or_sys_var( self, workflow: Workflow, variable: WorkflowDraftVariable ) -> WorkflowDraftVariable | None: - # If a variable does not allow updating, it makes no sence to resetting it. + # If a variable does not allow updating, it makes no sense to reset it. if not variable.editable: return variable # No execution record for this variable, delete the variable instead. @@ -422,7 +422,7 @@ class WorkflowDraftVariableService: description=conv_var.description, ) draft_conv_vars.append(draft_var) - _batch_upsert_draft_varaible( + _batch_upsert_draft_variable( self._session, draft_conv_vars, policy=_UpsertPolicy.IGNORE, @@ -434,7 +434,7 @@ class _UpsertPolicy(StrEnum): OVERWRITE = "overwrite" -def _batch_upsert_draft_varaible( +def _batch_upsert_draft_variable( session: Session, draft_vars: Sequence[WorkflowDraftVariable], policy: _UpsertPolicy = _UpsertPolicy.OVERWRITE, @@ -478,7 +478,7 @@ def _batch_upsert_draft_varaible( "node_execution_id": stmt.excluded.node_execution_id, }, ) - elif _UpsertPolicy.IGNORE: + elif policy == _UpsertPolicy.IGNORE: stmt = stmt.on_conflict_do_nothing(index_elements=WorkflowDraftVariable.unique_app_id_node_id_name()) else: raise Exception("Invalid value for update policy.") @@ -608,7 +608,7 @@ class DraftVariableSaver: for item in updated_variables: selector = item.selector - if len(selector) < MIN_SELECTORS_LENGTH: + if len(selector) < SELECTORS_LENGTH: raise Exception("selector too short") # NOTE(QuantumGhost): only the following two kinds of variable could be updated by # VariableAssigner: ConversationVariable and iteration variable. @@ -721,7 +721,7 @@ class DraftVariableSaver: draft_vars = self._build_variables_from_start_mapping(outputs) else: draft_vars = self._build_variables_from_mapping(outputs) - _batch_upsert_draft_varaible(self._session, draft_vars) + _batch_upsert_draft_variable(self._session, draft_vars) @staticmethod def _should_variable_be_editable(node_id: str, name: str) -> bool: diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index e9f21fc5f1..d2715a61fe 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -105,7 +105,9 @@ class WorkflowService: workflow = ( db.session.query(Workflow) .where( - Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.version == "draft" + Workflow.tenant_id == app_model.tenant_id, + Workflow.app_id == app_model.id, + Workflow.version == Workflow.VERSION_DRAFT, ) .first() ) @@ -127,7 +129,10 @@ class WorkflowService: if not workflow: return None if workflow.version == Workflow.VERSION_DRAFT: - raise IsDraftWorkflowError(f"Workflow is draft version, id={workflow_id}") + raise IsDraftWorkflowError( + f"Cannot use draft workflow version. Workflow ID: {workflow_id}. " + f"Please use a published workflow version or leave workflow_id empty." + ) return workflow def get_published_workflow(self, app_model: App) -> Optional[Workflow]: @@ -219,7 +224,7 @@ class WorkflowService: tenant_id=app_model.tenant_id, app_id=app_model.id, type=WorkflowType.from_app_mode(app_model.mode).value, - version="draft", + version=Workflow.VERSION_DRAFT, graph=json.dumps(graph), features=json.dumps(features), created_by=account.id, @@ -257,7 +262,7 @@ class WorkflowService: draft_workflow_stmt = select(Workflow).where( Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, - Workflow.version == "draft", + Workflow.version == Workflow.VERSION_DRAFT, ) draft_workflow = session.scalar(draft_workflow_stmt) if not draft_workflow: @@ -382,9 +387,9 @@ class WorkflowService: tenant_id=app_model.tenant_id, ) - eclosing_node_type_and_id = draft_workflow.get_enclosing_node_type_and_id(node_config) - if eclosing_node_type_and_id: - _, enclosing_node_id = eclosing_node_type_and_id + enclosing_node_type_and_id = draft_workflow.get_enclosing_node_type_and_id(node_config) + if enclosing_node_type_and_id: + _, enclosing_node_id = enclosing_node_type_and_id else: enclosing_node_id = None @@ -439,9 +444,9 @@ class WorkflowService: self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any] ) -> WorkflowNodeExecution: """ - Run draft workflow node + Run free workflow node """ - # run draft workflow node + # run free workflow node start_at = time.perf_counter() node_execution = self._handle_node_run_result( @@ -644,7 +649,7 @@ class WorkflowService: raise ValueError(f"Workflow with ID {workflow_id} not found") # Check if workflow is a draft version - if workflow.version == "draft": + if workflow.version == Workflow.VERSION_DRAFT: raise DraftWorkflowDeletionError("Cannot delete draft workflow versions") # Check if this workflow is currently referenced by an app diff --git a/api/tasks/add_document_to_index_task.py b/api/tasks/add_document_to_index_task.py index 204c1a4f5b..8834229e16 100644 --- a/api/tasks/add_document_to_index_task.py +++ b/api/tasks/add_document_to_index_task.py @@ -1,15 +1,15 @@ -import datetime import logging import time import click -from celery import shared_task # type: ignore +from celery import shared_task from core.rag.index_processor.constant.index_type import IndexType from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.rag.models.document import ChildDocument, Document from extensions.ext_database import db from extensions.ext_redis import redis_client +from libs.datetime_utils import naive_utc_now from models.dataset import DatasetAutoDisableLog, DocumentSegment from models.dataset import Document as DatasetDocument @@ -22,19 +22,20 @@ def add_document_to_index_task(dataset_document_id: str): Usage: add_document_to_index_task.delay(dataset_document_id) """ - logging.info(click.style("Start add document to index: {}".format(dataset_document_id), fg="green")) + logging.info(click.style(f"Start add document to index: {dataset_document_id}", fg="green")) start_at = time.perf_counter() dataset_document = db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document_id).first() if not dataset_document: - logging.info(click.style("Document not found: {}".format(dataset_document_id), fg="red")) + logging.info(click.style(f"Document not found: {dataset_document_id}", fg="red")) db.session.close() return if dataset_document.indexing_status != "completed": + db.session.close() return - indexing_cache_key = "document_{}_indexing".format(dataset_document.id) + indexing_cache_key = f"document_{dataset_document.id}_indexing" try: dataset = dataset_document.dataset @@ -94,23 +95,22 @@ def add_document_to_index_task(dataset_document_id: str): DocumentSegment.enabled: True, DocumentSegment.disabled_at: None, DocumentSegment.disabled_by: None, - DocumentSegment.updated_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None), + DocumentSegment.updated_at: naive_utc_now(), } ) db.session.commit() end_at = time.perf_counter() logging.info( - click.style( - "Document added to index: {} latency: {}".format(dataset_document.id, end_at - start_at), fg="green" - ) + click.style(f"Document added to index: {dataset_document.id} latency: {end_at - start_at}", fg="green") ) except Exception as e: logging.exception("add document to index failed") dataset_document.enabled = False - dataset_document.disabled_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + dataset_document.disabled_at = naive_utc_now() dataset_document.indexing_status = "error" dataset_document.error = str(e) db.session.commit() finally: redis_client.delete(indexing_cache_key) + db.session.close() diff --git a/api/tasks/annotation/add_annotation_to_index_task.py b/api/tasks/annotation/add_annotation_to_index_task.py index 2a93c21abd..5bf8e7c33e 100644 --- a/api/tasks/annotation/add_annotation_to_index_task.py +++ b/api/tasks/annotation/add_annotation_to_index_task.py @@ -2,7 +2,7 @@ import logging import time import click -from celery import shared_task # type: ignore +from celery import shared_task from core.rag.datasource.vdb.vector_factory import Vector from core.rag.models.document import Document @@ -25,7 +25,7 @@ def add_annotation_to_index_task( Usage: clean_dataset_task.delay(dataset_id, tenant_id, indexing_technique, index_struct) """ - logging.info(click.style("Start build index for annotation: {}".format(annotation_id), fg="green")) + logging.info(click.style(f"Start build index for annotation: {annotation_id}", fg="green")) start_at = time.perf_counter() try: @@ -50,7 +50,7 @@ def add_annotation_to_index_task( end_at = time.perf_counter() logging.info( click.style( - "Build index successful for annotation: {} latency: {}".format(annotation_id, end_at - start_at), + f"Build index successful for annotation: {annotation_id} latency: {end_at - start_at}", fg="green", ) ) diff --git a/api/tasks/annotation/batch_import_annotations_task.py b/api/tasks/annotation/batch_import_annotations_task.py index 6d48f5df89..fd33feea16 100644 --- a/api/tasks/annotation/batch_import_annotations_task.py +++ b/api/tasks/annotation/batch_import_annotations_task.py @@ -2,7 +2,7 @@ import logging import time import click -from celery import shared_task # type: ignore +from celery import shared_task from werkzeug.exceptions import NotFound from core.rag.datasource.vdb.vector_factory import Vector @@ -25,9 +25,9 @@ def batch_import_annotations_task(job_id: str, content_list: list[dict], app_id: :param user_id: user_id """ - logging.info(click.style("Start batch import annotation: {}".format(job_id), fg="green")) + logging.info(click.style(f"Start batch import annotation: {job_id}", fg="green")) start_at = time.perf_counter() - indexing_cache_key = "app_annotation_batch_import_{}".format(str(job_id)) + indexing_cache_key = f"app_annotation_batch_import_{str(job_id)}" # get app info app = db.session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first() @@ -85,7 +85,7 @@ def batch_import_annotations_task(job_id: str, content_list: list[dict], app_id: except Exception as e: db.session.rollback() redis_client.setex(indexing_cache_key, 600, "error") - indexing_error_msg_key = "app_annotation_batch_import_error_msg_{}".format(str(job_id)) + indexing_error_msg_key = f"app_annotation_batch_import_error_msg_{str(job_id)}" redis_client.setex(indexing_error_msg_key, 600, str(e)) logging.exception("Build index for batch import annotations failed") finally: diff --git a/api/tasks/annotation/delete_annotation_index_task.py b/api/tasks/annotation/delete_annotation_index_task.py index a6657e813a..1894031a80 100644 --- a/api/tasks/annotation/delete_annotation_index_task.py +++ b/api/tasks/annotation/delete_annotation_index_task.py @@ -2,7 +2,7 @@ import logging import time import click -from celery import shared_task # type: ignore +from celery import shared_task from core.rag.datasource.vdb.vector_factory import Vector from extensions.ext_database import db @@ -15,7 +15,7 @@ def delete_annotation_index_task(annotation_id: str, app_id: str, tenant_id: str """ Async delete annotation index task """ - logging.info(click.style("Start delete app annotation index: {}".format(app_id), fg="green")) + logging.info(click.style(f"Start delete app annotation index: {app_id}", fg="green")) start_at = time.perf_counter() try: dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type( @@ -35,9 +35,7 @@ def delete_annotation_index_task(annotation_id: str, app_id: str, tenant_id: str except Exception: logging.exception("Delete annotation index failed when annotation deleted.") end_at = time.perf_counter() - logging.info( - click.style("App annotations index deleted : {} latency: {}".format(app_id, end_at - start_at), fg="green") - ) + logging.info(click.style(f"App annotations index deleted : {app_id} latency: {end_at - start_at}", fg="green")) except Exception as e: logging.exception("Annotation deleted index failed") finally: diff --git a/api/tasks/annotation/disable_annotation_reply_task.py b/api/tasks/annotation/disable_annotation_reply_task.py index 5d5d1d3ad8..a8375dfa26 100644 --- a/api/tasks/annotation/disable_annotation_reply_task.py +++ b/api/tasks/annotation/disable_annotation_reply_task.py @@ -2,7 +2,7 @@ import logging import time import click -from celery import shared_task # type: ignore +from celery import shared_task from core.rag.datasource.vdb.vector_factory import Vector from extensions.ext_database import db @@ -16,25 +16,25 @@ def disable_annotation_reply_task(job_id: str, app_id: str, tenant_id: str): """ Async enable annotation reply task """ - logging.info(click.style("Start delete app annotations index: {}".format(app_id), fg="green")) + logging.info(click.style(f"Start delete app annotations index: {app_id}", fg="green")) start_at = time.perf_counter() # get app info app = db.session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first() annotations_count = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app_id).count() if not app: - logging.info(click.style("App not found: {}".format(app_id), fg="red")) + logging.info(click.style(f"App not found: {app_id}", fg="red")) db.session.close() return app_annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first() if not app_annotation_setting: - logging.info(click.style("App annotation setting not found: {}".format(app_id), fg="red")) + logging.info(click.style(f"App annotation setting not found: {app_id}", fg="red")) db.session.close() return - disable_app_annotation_key = "disable_app_annotation_{}".format(str(app_id)) - disable_app_annotation_job_key = "disable_app_annotation_job_{}".format(str(job_id)) + disable_app_annotation_key = f"disable_app_annotation_{str(app_id)}" + disable_app_annotation_job_key = f"disable_app_annotation_job_{str(job_id)}" try: dataset = Dataset( @@ -57,13 +57,11 @@ def disable_annotation_reply_task(job_id: str, app_id: str, tenant_id: str): db.session.commit() end_at = time.perf_counter() - logging.info( - click.style("App annotations index deleted : {} latency: {}".format(app_id, end_at - start_at), fg="green") - ) + logging.info(click.style(f"App annotations index deleted : {app_id} latency: {end_at - start_at}", fg="green")) except Exception as e: logging.exception("Annotation batch deleted index failed") redis_client.setex(disable_app_annotation_job_key, 600, "error") - disable_app_annotation_error_key = "disable_app_annotation_error_{}".format(str(job_id)) + disable_app_annotation_error_key = f"disable_app_annotation_error_{str(job_id)}" redis_client.setex(disable_app_annotation_error_key, 600, str(e)) finally: redis_client.delete(disable_app_annotation_key) diff --git a/api/tasks/annotation/enable_annotation_reply_task.py b/api/tasks/annotation/enable_annotation_reply_task.py index 12d10df442..9ffaf81af6 100644 --- a/api/tasks/annotation/enable_annotation_reply_task.py +++ b/api/tasks/annotation/enable_annotation_reply_task.py @@ -1,14 +1,14 @@ -import datetime import logging import time import click -from celery import shared_task # type: ignore +from celery import shared_task from core.rag.datasource.vdb.vector_factory import Vector from core.rag.models.document import Document from extensions.ext_database import db from extensions.ext_redis import redis_client +from libs.datetime_utils import naive_utc_now from models.dataset import Dataset from models.model import App, AppAnnotationSetting, MessageAnnotation from services.dataset_service import DatasetCollectionBindingService @@ -27,19 +27,19 @@ def enable_annotation_reply_task( """ Async enable annotation reply task """ - logging.info(click.style("Start add app annotation to index: {}".format(app_id), fg="green")) + logging.info(click.style(f"Start add app annotation to index: {app_id}", fg="green")) start_at = time.perf_counter() # get app info app = db.session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first() if not app: - logging.info(click.style("App not found: {}".format(app_id), fg="red")) + logging.info(click.style(f"App not found: {app_id}", fg="red")) db.session.close() return annotations = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app_id).all() - enable_app_annotation_key = "enable_app_annotation_{}".format(str(app_id)) - enable_app_annotation_job_key = "enable_app_annotation_job_{}".format(str(job_id)) + enable_app_annotation_key = f"enable_app_annotation_{str(app_id)}" + enable_app_annotation_job_key = f"enable_app_annotation_job_{str(job_id)}" try: documents = [] @@ -68,11 +68,11 @@ def enable_annotation_reply_task( try: old_vector.delete() except Exception as e: - logging.info(click.style("Delete annotation index error: {}".format(str(e)), fg="red")) + logging.info(click.style(f"Delete annotation index error: {str(e)}", fg="red")) annotation_setting.score_threshold = score_threshold annotation_setting.collection_binding_id = dataset_collection_binding.id annotation_setting.updated_user_id = user_id - annotation_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + annotation_setting.updated_at = naive_utc_now() db.session.add(annotation_setting) else: new_app_annotation_setting = AppAnnotationSetting( @@ -104,18 +104,16 @@ def enable_annotation_reply_task( try: vector.delete_by_metadata_field("app_id", app_id) except Exception as e: - logging.info(click.style("Delete annotation index error: {}".format(str(e)), fg="red")) + logging.info(click.style(f"Delete annotation index error: {str(e)}", fg="red")) vector.create(documents) db.session.commit() redis_client.setex(enable_app_annotation_job_key, 600, "completed") end_at = time.perf_counter() - logging.info( - click.style("App annotations added to index: {} latency: {}".format(app_id, end_at - start_at), fg="green") - ) + logging.info(click.style(f"App annotations added to index: {app_id} latency: {end_at - start_at}", fg="green")) except Exception as e: logging.exception("Annotation batch created index failed") redis_client.setex(enable_app_annotation_job_key, 600, "error") - enable_app_annotation_error_key = "enable_app_annotation_error_{}".format(str(job_id)) + enable_app_annotation_error_key = f"enable_app_annotation_error_{str(job_id)}" redis_client.setex(enable_app_annotation_error_key, 600, str(e)) db.session.rollback() finally: diff --git a/api/tasks/annotation/update_annotation_to_index_task.py b/api/tasks/annotation/update_annotation_to_index_task.py index 596ba829ad..337434b768 100644 --- a/api/tasks/annotation/update_annotation_to_index_task.py +++ b/api/tasks/annotation/update_annotation_to_index_task.py @@ -2,7 +2,7 @@ import logging import time import click -from celery import shared_task # type: ignore +from celery import shared_task from core.rag.datasource.vdb.vector_factory import Vector from core.rag.models.document import Document @@ -25,7 +25,7 @@ def update_annotation_to_index_task( Usage: clean_dataset_task.delay(dataset_id, tenant_id, indexing_technique, index_struct) """ - logging.info(click.style("Start update index for annotation: {}".format(annotation_id), fg="green")) + logging.info(click.style(f"Start update index for annotation: {annotation_id}", fg="green")) start_at = time.perf_counter() try: @@ -51,7 +51,7 @@ def update_annotation_to_index_task( end_at = time.perf_counter() logging.info( click.style( - "Build index successful for annotation: {} latency: {}".format(annotation_id, end_at - start_at), + f"Build index successful for annotation: {annotation_id} latency: {end_at - start_at}", fg="green", ) ) diff --git a/api/tasks/batch_clean_document_task.py b/api/tasks/batch_clean_document_task.py index 49bff72a96..ed47b62e1b 100644 --- a/api/tasks/batch_clean_document_task.py +++ b/api/tasks/batch_clean_document_task.py @@ -2,7 +2,7 @@ import logging import time import click -from celery import shared_task # type: ignore +from celery import shared_task from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.tools.utils.web_reader_tool import get_image_upload_file_ids @@ -49,7 +49,8 @@ def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form except Exception: logging.exception( "Delete image_files failed when storage deleted, \ - image_upload_file_is: {}".format(upload_file_id) + image_upload_file_is: %s", + upload_file_id, ) db.session.delete(image_file) db.session.delete(segment) @@ -61,14 +62,14 @@ def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form try: storage.delete(file.key) except Exception: - logging.exception("Delete file failed when document deleted, file_id: {}".format(file.id)) + logging.exception("Delete file failed when document deleted, file_id: %s", file.id) db.session.delete(file) db.session.commit() end_at = time.perf_counter() logging.info( click.style( - "Cleaned documents when documents deleted latency: {}".format(end_at - start_at), + f"Cleaned documents when documents deleted latency: {end_at - start_at}", fg="green", ) ) diff --git a/api/tasks/batch_create_segment_to_index_task.py b/api/tasks/batch_create_segment_to_index_task.py index 64df3175e1..50293f38a7 100644 --- a/api/tasks/batch_create_segment_to_index_task.py +++ b/api/tasks/batch_create_segment_to_index_task.py @@ -1,10 +1,12 @@ -import datetime import logging +import tempfile import time import uuid +from pathlib import Path import click -from celery import shared_task # type: ignore +import pandas as pd +from celery import shared_task from sqlalchemy import func from sqlalchemy.orm import Session @@ -12,15 +14,18 @@ from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db from extensions.ext_redis import redis_client +from extensions.ext_storage import storage from libs import helper +from libs.datetime_utils import naive_utc_now from models.dataset import Dataset, Document, DocumentSegment +from models.model import UploadFile from services.vector_service import VectorService @shared_task(queue="dataset") def batch_create_segment_to_index_task( job_id: str, - content: list, + upload_file_id: str, dataset_id: str, document_id: str, tenant_id: str, @@ -29,18 +34,18 @@ def batch_create_segment_to_index_task( """ Async batch create segment to index :param job_id: - :param content: + :param upload_file_id: :param dataset_id: :param document_id: :param tenant_id: :param user_id: - Usage: batch_create_segment_to_index_task.delay(job_id, content, dataset_id, document_id, tenant_id, user_id) + Usage: batch_create_segment_to_index_task.delay(job_id, upload_file_id, dataset_id, document_id, tenant_id, user_id) """ - logging.info(click.style("Start batch create segment jobId: {}".format(job_id), fg="green")) + logging.info(click.style(f"Start batch create segment jobId: {job_id}", fg="green")) start_at = time.perf_counter() - indexing_cache_key = "segment_batch_import_{}".format(job_id) + indexing_cache_key = f"segment_batch_import_{job_id}" try: with Session(db.engine) as session: @@ -58,6 +63,29 @@ def batch_create_segment_to_index_task( or dataset_document.indexing_status != "completed" ): raise ValueError("Document is not available.") + + upload_file = session.get(UploadFile, upload_file_id) + if not upload_file: + raise ValueError("UploadFile not found.") + + with tempfile.TemporaryDirectory() as temp_dir: + suffix = Path(upload_file.key).suffix + # FIXME mypy: Cannot determine type of 'tempfile._get_candidate_names' better not use it here + file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore + storage.download(upload_file.key, file_path) + + # Skip the first row + df = pd.read_csv(file_path) + content = [] + for index, row in df.iterrows(): + if dataset_document.doc_form == "qa_model": + data = {"content": row.iloc[0], "answer": row.iloc[1]} + else: + data = {"content": row.iloc[0]} + content.append(data) + if len(content) == 0: + raise ValueError("The CSV file is empty.") + document_segments = [] embedding_model = None if dataset.indexing_technique == "high_quality": @@ -95,9 +123,9 @@ def batch_create_segment_to_index_task( word_count=len(content), tokens=tokens, created_by=user_id, - indexing_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None), + indexing_at=naive_utc_now(), status="completed", - completed_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None), + completed_at=naive_utc_now(), ) if dataset_document.doc_form == "qa_model": segment_document.answer = segment["answer"] @@ -106,6 +134,7 @@ def batch_create_segment_to_index_task( db.session.add(segment_document) document_segments.append(segment_document) # update document word count + assert dataset_document.word_count is not None dataset_document.word_count += word_count_change db.session.add(dataset_document) # add index to db @@ -115,7 +144,7 @@ def batch_create_segment_to_index_task( end_at = time.perf_counter() logging.info( click.style( - "Segment batch created job: {} latency: {}".format(job_id, end_at - start_at), + f"Segment batch created job: {job_id} latency: {end_at - start_at}", fg="green", ) ) diff --git a/api/tasks/clean_dataset_task.py b/api/tasks/clean_dataset_task.py index fad090141a..3d3fadbd0a 100644 --- a/api/tasks/clean_dataset_task.py +++ b/api/tasks/clean_dataset_task.py @@ -2,10 +2,10 @@ import logging import time import click -from celery import shared_task # type: ignore +from celery import shared_task from core.rag.index_processor.index_processor_factory import IndexProcessorFactory -from core.tools.utils.rag_web_reader import get_image_upload_file_ids +from core.tools.utils.web_reader_tool import get_image_upload_file_ids from extensions.ext_database import db from extensions.ext_storage import storage from models.dataset import ( @@ -42,7 +42,7 @@ def clean_dataset_task( Usage: clean_dataset_task.delay(dataset_id, tenant_id, indexing_technique, index_struct) """ - logging.info(click.style("Start clean dataset when dataset deleted: {}".format(dataset_id), fg="green")) + logging.info(click.style(f"Start clean dataset when dataset deleted: {dataset_id}", fg="green")) start_at = time.perf_counter() try: @@ -56,15 +56,34 @@ def clean_dataset_task( documents = db.session.query(Document).where(Document.dataset_id == dataset_id).all() segments = db.session.query(DocumentSegment).where(DocumentSegment.dataset_id == dataset_id).all() - if documents is None or len(documents) == 0: - logging.info(click.style("No documents found for dataset: {}".format(dataset_id), fg="green")) - else: - logging.info(click.style("Cleaning documents for dataset: {}".format(dataset_id), fg="green")) - # Specify the index type before initializing the index processor - if doc_form is None: - raise ValueError("Index type must be specified.") + # Enhanced validation: Check if doc_form is None, empty string, or contains only whitespace + # This ensures all invalid doc_form values are properly handled + if doc_form is None or (isinstance(doc_form, str) and not doc_form.strip()): + # Use default paragraph index type for empty/invalid datasets to enable vector database cleanup + from core.rag.index_processor.constant.index_type import IndexType + + doc_form = IndexType.PARAGRAPH_INDEX + logging.info( + click.style(f"Invalid doc_form detected, using default index type for cleanup: {doc_form}", fg="yellow") + ) + + # Add exception handling around IndexProcessorFactory.clean() to prevent single point of failure + # This ensures Document/Segment deletion can continue even if vector database cleanup fails + try: index_processor = IndexProcessorFactory(doc_form).init_index_processor() index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=True) + logging.info(click.style(f"Successfully cleaned vector database for dataset: {dataset_id}", fg="green")) + except Exception as index_cleanup_error: + logging.exception(click.style(f"Failed to clean vector database for dataset {dataset_id}", fg="red")) + # Continue with document and segment deletion even if vector cleanup fails + logging.info( + click.style(f"Continuing with document and segment deletion for dataset: {dataset_id}", fg="yellow") + ) + + if documents is None or len(documents) == 0: + logging.info(click.style(f"No documents found for dataset: {dataset_id}", fg="green")) + else: + logging.info(click.style(f"Cleaning documents for dataset: {dataset_id}", fg="green")) for document in documents: db.session.delete(document) @@ -80,7 +99,8 @@ def clean_dataset_task( except Exception: logging.exception( "Delete image_files failed when storage deleted, \ - image_upload_file_is: {}".format(upload_file_id) + image_upload_file_is: %s", + upload_file_id, ) db.session.delete(image_file) db.session.delete(segment) @@ -115,11 +135,17 @@ def clean_dataset_task( db.session.commit() end_at = time.perf_counter() logging.info( - click.style( - "Cleaned dataset when dataset deleted: {} latency: {}".format(dataset_id, end_at - start_at), fg="green" - ) + click.style(f"Cleaned dataset when dataset deleted: {dataset_id} latency: {end_at - start_at}", fg="green") ) except Exception: + # Add rollback to prevent dirty session state in case of exceptions + # This ensures the database session is properly cleaned up + try: + db.session.rollback() + logging.info(click.style(f"Rolled back database session for dataset: {dataset_id}", fg="yellow")) + except Exception as rollback_error: + logging.exception("Failed to rollback database session") + logging.exception("Cleaned dataset when dataset deleted failed") finally: db.session.close() diff --git a/api/tasks/clean_document_task.py b/api/tasks/clean_document_task.py index dd7a544ff5..c18329a9c2 100644 --- a/api/tasks/clean_document_task.py +++ b/api/tasks/clean_document_task.py @@ -3,10 +3,10 @@ import time from typing import Optional import click -from celery import shared_task # type: ignore +from celery import shared_task from core.rag.index_processor.index_processor_factory import IndexProcessorFactory -from core.tools.utils.rag_web_reader import get_image_upload_file_ids +from core.tools.utils.web_reader_tool import get_image_upload_file_ids from extensions.ext_database import db from extensions.ext_storage import storage from models.dataset import Dataset, DatasetMetadataBinding, DocumentSegment @@ -24,7 +24,7 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i Usage: clean_document_task.delay(document_id, dataset_id) """ - logging.info(click.style("Start clean document when document deleted: {}".format(document_id), fg="green")) + logging.info(click.style(f"Start clean document when document deleted: {document_id}", fg="green")) start_at = time.perf_counter() try: @@ -51,7 +51,8 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i except Exception: logging.exception( "Delete image_files failed when storage deleted, \ - image_upload_file_is: {}".format(upload_file_id) + image_upload_file_is: %s", + upload_file_id, ) db.session.delete(image_file) db.session.delete(segment) @@ -63,7 +64,7 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i try: storage.delete(file.key) except Exception: - logging.exception("Delete file failed when document deleted, file_id: {}".format(file_id)) + logging.exception("Delete file failed when document deleted, file_id: %s", file_id) db.session.delete(file) db.session.commit() @@ -77,7 +78,7 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i end_at = time.perf_counter() logging.info( click.style( - "Cleaned document when document deleted: {} latency: {}".format(document_id, end_at - start_at), + f"Cleaned document when document deleted: {document_id} latency: {end_at - start_at}", fg="green", ) ) diff --git a/api/tasks/clean_notion_document_task.py b/api/tasks/clean_notion_document_task.py index 0f72f87f15..3ad6257cda 100644 --- a/api/tasks/clean_notion_document_task.py +++ b/api/tasks/clean_notion_document_task.py @@ -2,7 +2,7 @@ import logging import time import click -from celery import shared_task # type: ignore +from celery import shared_task from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from extensions.ext_database import db @@ -19,7 +19,7 @@ def clean_notion_document_task(document_ids: list[str], dataset_id: str): Usage: clean_notion_document_task.delay(document_ids, dataset_id) """ logging.info( - click.style("Start clean document when import form notion document deleted: {}".format(dataset_id), fg="green") + click.style(f"Start clean document when import form notion document deleted: {dataset_id}", fg="green") ) start_at = time.perf_counter() diff --git a/api/tasks/create_segment_to_index_task.py b/api/tasks/create_segment_to_index_task.py index 5eda24674a..db2f69596d 100644 --- a/api/tasks/create_segment_to_index_task.py +++ b/api/tasks/create_segment_to_index_task.py @@ -1,15 +1,15 @@ -import datetime import logging import time from typing import Optional import click -from celery import shared_task # type: ignore +from celery import shared_task from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.rag.models.document import Document from extensions.ext_database import db from extensions.ext_redis import redis_client +from libs.datetime_utils import naive_utc_now from models.dataset import DocumentSegment @@ -21,26 +21,27 @@ def create_segment_to_index_task(segment_id: str, keywords: Optional[list[str]] :param keywords: Usage: create_segment_to_index_task.delay(segment_id) """ - logging.info(click.style("Start create segment to index: {}".format(segment_id), fg="green")) + logging.info(click.style(f"Start create segment to index: {segment_id}", fg="green")) start_at = time.perf_counter() segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment_id).first() if not segment: - logging.info(click.style("Segment not found: {}".format(segment_id), fg="red")) + logging.info(click.style(f"Segment not found: {segment_id}", fg="red")) db.session.close() return if segment.status != "waiting": + db.session.close() return - indexing_cache_key = "segment_{}_indexing".format(segment.id) + indexing_cache_key = f"segment_{segment.id}_indexing" try: # update segment status to indexing db.session.query(DocumentSegment).filter_by(id=segment.id).update( { DocumentSegment.status: "indexing", - DocumentSegment.indexing_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None), + DocumentSegment.indexing_at: naive_utc_now(), } ) db.session.commit() @@ -57,17 +58,17 @@ def create_segment_to_index_task(segment_id: str, keywords: Optional[list[str]] dataset = segment.dataset if not dataset: - logging.info(click.style("Segment {} has no dataset, pass.".format(segment.id), fg="cyan")) + logging.info(click.style(f"Segment {segment.id} has no dataset, pass.", fg="cyan")) return dataset_document = segment.document if not dataset_document: - logging.info(click.style("Segment {} has no document, pass.".format(segment.id), fg="cyan")) + logging.info(click.style(f"Segment {segment.id} has no document, pass.", fg="cyan")) return if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed": - logging.info(click.style("Segment {} document status is invalid, pass.".format(segment.id), fg="cyan")) + logging.info(click.style(f"Segment {segment.id} document status is invalid, pass.", fg="cyan")) return index_type = dataset.doc_form @@ -78,19 +79,17 @@ def create_segment_to_index_task(segment_id: str, keywords: Optional[list[str]] db.session.query(DocumentSegment).filter_by(id=segment.id).update( { DocumentSegment.status: "completed", - DocumentSegment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None), + DocumentSegment.completed_at: naive_utc_now(), } ) db.session.commit() end_at = time.perf_counter() - logging.info( - click.style("Segment created to index: {} latency: {}".format(segment.id, end_at - start_at), fg="green") - ) + logging.info(click.style(f"Segment created to index: {segment.id} latency: {end_at - start_at}", fg="green")) except Exception as e: logging.exception("create segment to index failed") segment.enabled = False - segment.disabled_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + segment.disabled_at = naive_utc_now() segment.status = "error" segment.error = str(e) db.session.commit() diff --git a/api/tasks/deal_dataset_vector_index_task.py b/api/tasks/deal_dataset_vector_index_task.py index 7478bf5a90..512ea1048a 100644 --- a/api/tasks/deal_dataset_vector_index_task.py +++ b/api/tasks/deal_dataset_vector_index_task.py @@ -1,8 +1,9 @@ import logging import time +from typing import Literal import click -from celery import shared_task # type: ignore +from celery import shared_task from core.rag.index_processor.constant.index_type import IndexType from core.rag.index_processor.index_processor_factory import IndexProcessorFactory @@ -13,14 +14,14 @@ from models.dataset import Document as DatasetDocument @shared_task(queue="dataset") -def deal_dataset_vector_index_task(dataset_id: str, action: str): +def deal_dataset_vector_index_task(dataset_id: str, action: Literal["remove", "add", "update"]): """ Async deal dataset from index :param dataset_id: dataset_id :param action: action Usage: deal_dataset_vector_index_task.delay(dataset_id, action) """ - logging.info(click.style("Start deal dataset vector index: {}".format(dataset_id), fg="green")) + logging.info(click.style(f"Start deal dataset vector index: {dataset_id}", fg="green")) start_at = time.perf_counter() try: @@ -162,9 +163,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str): index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False) end_at = time.perf_counter() - logging.info( - click.style("Deal dataset vector index: {} latency: {}".format(dataset_id, end_at - start_at), fg="green") - ) + logging.info(click.style(f"Deal dataset vector index: {dataset_id} latency: {end_at - start_at}", fg="green")) except Exception: logging.exception("Deal dataset vector index failed") finally: diff --git a/api/tasks/delete_account_task.py b/api/tasks/delete_account_task.py index d3b33e3052..29f5a2450d 100644 --- a/api/tasks/delete_account_task.py +++ b/api/tasks/delete_account_task.py @@ -1,6 +1,6 @@ import logging -from celery import shared_task # type: ignore +from celery import shared_task from extensions.ext_database import db from models.account import Account @@ -16,11 +16,11 @@ def delete_account_task(account_id): try: BillingService.delete_account(account_id) except Exception as e: - logger.exception(f"Failed to delete account {account_id} from billing service.") + logger.exception("Failed to delete account %s from billing service.", account_id) raise if not account: - logger.error(f"Account {account_id} not found.") + logger.error("Account %s not found.", account_id) return # send success email send_deletion_success_task.delay(account.email) diff --git a/api/tasks/delete_conversation_task.py b/api/tasks/delete_conversation_task.py new file mode 100644 index 0000000000..4279dd2c17 --- /dev/null +++ b/api/tasks/delete_conversation_task.py @@ -0,0 +1,68 @@ +import logging +import time + +import click +from celery import shared_task # type: ignore + +from extensions.ext_database import db +from models import ConversationVariable +from models.model import Message, MessageAnnotation, MessageFeedback +from models.tools import ToolConversationVariables, ToolFile +from models.web import PinnedConversation + + +@shared_task(queue="conversation") +def delete_conversation_related_data(conversation_id: str) -> None: + """ + Delete related data conversation in correct order from datatbase to respect foreign key constraints + + Args: + conversation_id: conversation Id + """ + + logging.info( + click.style(f"Starting to delete conversation data from db for conversation_id {conversation_id}", fg="green") + ) + start_at = time.perf_counter() + + try: + db.session.query(MessageAnnotation).where(MessageAnnotation.conversation_id == conversation_id).delete( + synchronize_session=False + ) + + db.session.query(MessageFeedback).where(MessageFeedback.conversation_id == conversation_id).delete( + synchronize_session=False + ) + + db.session.query(ToolConversationVariables).where( + ToolConversationVariables.conversation_id == conversation_id + ).delete(synchronize_session=False) + + db.session.query(ToolFile).where(ToolFile.conversation_id == conversation_id).delete(synchronize_session=False) + + db.session.query(ConversationVariable).where(ConversationVariable.conversation_id == conversation_id).delete( + synchronize_session=False + ) + + db.session.query(Message).where(Message.conversation_id == conversation_id).delete(synchronize_session=False) + + db.session.query(PinnedConversation).where(PinnedConversation.conversation_id == conversation_id).delete( + synchronize_session=False + ) + + db.session.commit() + + end_at = time.perf_counter() + logging.info( + click.style( + f"Succeeded cleaning data from db for conversation_id {conversation_id} latency: {end_at - start_at}", + fg="green", + ) + ) + + except Exception as e: + logging.exception("Failed to delete data from db for conversation_id: %s failed", conversation_id) + db.session.rollback() + raise e + finally: + db.session.close() diff --git a/api/tasks/delete_segment_from_index_task.py b/api/tasks/delete_segment_from_index_task.py index 66ff0f9a0a..f091085fb8 100644 --- a/api/tasks/delete_segment_from_index_task.py +++ b/api/tasks/delete_segment_from_index_task.py @@ -2,7 +2,7 @@ import logging import time import click -from celery import shared_task # type: ignore +from celery import shared_task from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from extensions.ext_database import db @@ -38,7 +38,7 @@ def delete_segment_from_index_task(index_node_ids: list, dataset_id: str, docume index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) end_at = time.perf_counter() - logging.info(click.style("Segment deleted from index latency: {}".format(end_at - start_at), fg="green")) + logging.info(click.style(f"Segment deleted from index latency: {end_at - start_at}", fg="green")) except Exception: logging.exception("delete segment from index failed") finally: diff --git a/api/tasks/disable_segment_from_index_task.py b/api/tasks/disable_segment_from_index_task.py index e67ba5c76e..c813a9dca6 100644 --- a/api/tasks/disable_segment_from_index_task.py +++ b/api/tasks/disable_segment_from_index_task.py @@ -2,7 +2,7 @@ import logging import time import click -from celery import shared_task # type: ignore +from celery import shared_task from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from extensions.ext_database import db @@ -18,37 +18,37 @@ def disable_segment_from_index_task(segment_id: str): Usage: disable_segment_from_index_task.delay(segment_id) """ - logging.info(click.style("Start disable segment from index: {}".format(segment_id), fg="green")) + logging.info(click.style(f"Start disable segment from index: {segment_id}", fg="green")) start_at = time.perf_counter() segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment_id).first() if not segment: - logging.info(click.style("Segment not found: {}".format(segment_id), fg="red")) + logging.info(click.style(f"Segment not found: {segment_id}", fg="red")) db.session.close() return if segment.status != "completed": - logging.info(click.style("Segment is not completed, disable is not allowed: {}".format(segment_id), fg="red")) + logging.info(click.style(f"Segment is not completed, disable is not allowed: {segment_id}", fg="red")) db.session.close() return - indexing_cache_key = "segment_{}_indexing".format(segment.id) + indexing_cache_key = f"segment_{segment.id}_indexing" try: dataset = segment.dataset if not dataset: - logging.info(click.style("Segment {} has no dataset, pass.".format(segment.id), fg="cyan")) + logging.info(click.style(f"Segment {segment.id} has no dataset, pass.", fg="cyan")) return dataset_document = segment.document if not dataset_document: - logging.info(click.style("Segment {} has no document, pass.".format(segment.id), fg="cyan")) + logging.info(click.style(f"Segment {segment.id} has no document, pass.", fg="cyan")) return if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed": - logging.info(click.style("Segment {} document status is invalid, pass.".format(segment.id), fg="cyan")) + logging.info(click.style(f"Segment {segment.id} document status is invalid, pass.", fg="cyan")) return index_type = dataset_document.doc_form @@ -56,9 +56,7 @@ def disable_segment_from_index_task(segment_id: str): index_processor.clean(dataset, [segment.index_node_id]) end_at = time.perf_counter() - logging.info( - click.style("Segment removed from index: {} latency: {}".format(segment.id, end_at - start_at), fg="green") - ) + logging.info(click.style(f"Segment removed from index: {segment.id} latency: {end_at - start_at}", fg="green")) except Exception: logging.exception("remove segment from index failed") segment.enabled = True diff --git a/api/tasks/disable_segments_from_index_task.py b/api/tasks/disable_segments_from_index_task.py index 0c8b1aabc7..252321ba83 100644 --- a/api/tasks/disable_segments_from_index_task.py +++ b/api/tasks/disable_segments_from_index_task.py @@ -2,7 +2,7 @@ import logging import time import click -from celery import shared_task # type: ignore +from celery import shared_task from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from extensions.ext_database import db @@ -25,18 +25,18 @@ def disable_segments_from_index_task(segment_ids: list, dataset_id: str, documen dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first() if not dataset: - logging.info(click.style("Dataset {} not found, pass.".format(dataset_id), fg="cyan")) + logging.info(click.style(f"Dataset {dataset_id} not found, pass.", fg="cyan")) db.session.close() return dataset_document = db.session.query(DatasetDocument).where(DatasetDocument.id == document_id).first() if not dataset_document: - logging.info(click.style("Document {} not found, pass.".format(document_id), fg="cyan")) + logging.info(click.style(f"Document {document_id} not found, pass.", fg="cyan")) db.session.close() return if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed": - logging.info(click.style("Document {} status is invalid, pass.".format(document_id), fg="cyan")) + logging.info(click.style(f"Document {document_id} status is invalid, pass.", fg="cyan")) db.session.close() return # sync index processor @@ -61,7 +61,7 @@ def disable_segments_from_index_task(segment_ids: list, dataset_id: str, documen index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=False) end_at = time.perf_counter() - logging.info(click.style("Segments removed from index latency: {}".format(end_at - start_at), fg="green")) + logging.info(click.style(f"Segments removed from index latency: {end_at - start_at}", fg="green")) except Exception: # update segment error msg db.session.query(DocumentSegment).where( @@ -78,6 +78,6 @@ def disable_segments_from_index_task(segment_ids: list, dataset_id: str, documen db.session.commit() finally: for segment in segments: - indexing_cache_key = "segment_{}_indexing".format(segment.id) + indexing_cache_key = f"segment_{segment.id}_indexing" redis_client.delete(indexing_cache_key) db.session.close() diff --git a/api/tasks/document_indexing_sync_task.py b/api/tasks/document_indexing_sync_task.py index dcc748ef18..4afd13eb13 100644 --- a/api/tasks/document_indexing_sync_task.py +++ b/api/tasks/document_indexing_sync_task.py @@ -1,14 +1,14 @@ -import datetime import logging import time import click -from celery import shared_task # type: ignore +from celery import shared_task from core.indexing_runner import DocumentIsPausedError, IndexingRunner from core.rag.extractor.notion_extractor import NotionExtractor from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from extensions.ext_database import db +from libs.datetime_utils import naive_utc_now from models.dataset import Dataset, Document, DocumentSegment from models.source import DataSourceOauthBinding @@ -22,13 +22,13 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): Usage: document_indexing_sync_task.delay(dataset_id, document_id) """ - logging.info(click.style("Start sync document: {}".format(document_id), fg="green")) + logging.info(click.style(f"Start sync document: {document_id}", fg="green")) start_at = time.perf_counter() document = db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() if not document: - logging.info(click.style("Document not found: {}".format(document_id), fg="red")) + logging.info(click.style(f"Document not found: {document_id}", fg="red")) db.session.close() return @@ -72,7 +72,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): # check the page is updated if last_edited_time != page_edited_time: document.indexing_status = "parsing" - document.processing_started_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + document.processing_started_at = naive_utc_now() db.session.commit() # delete all document segment and index @@ -108,10 +108,10 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): indexing_runner = IndexingRunner() indexing_runner.run([document]) end_at = time.perf_counter() - logging.info( - click.style("update document: {} latency: {}".format(document.id, end_at - start_at), fg="green") - ) + logging.info(click.style(f"update document: {document.id} latency: {end_at - start_at}", fg="green")) except DocumentIsPausedError as ex: logging.info(click.style(str(ex), fg="yellow")) except Exception: - logging.exception("document_indexing_sync_task failed, document_id: {}".format(document_id)) + logging.exception("document_indexing_sync_task failed, document_id: %s", document_id) + finally: + db.session.close() diff --git a/api/tasks/document_indexing_task.py b/api/tasks/document_indexing_task.py index ec6d10d93b..c414b01d0e 100644 --- a/api/tasks/document_indexing_task.py +++ b/api/tasks/document_indexing_task.py @@ -2,7 +2,7 @@ import logging import time import click -from celery import shared_task # type: ignore +from celery import shared_task from configs import dify_config from core.indexing_runner import DocumentIsPausedError, IndexingRunner @@ -26,7 +26,7 @@ def document_indexing_task(dataset_id: str, document_ids: list): dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first() if not dataset: - logging.info(click.style("Dataset is not found: {}".format(dataset_id), fg="yellow")) + logging.info(click.style(f"Dataset is not found: {dataset_id}", fg="yellow")) db.session.close() return # check document limit @@ -60,7 +60,7 @@ def document_indexing_task(dataset_id: str, document_ids: list): return for document_id in document_ids: - logging.info(click.style("Start process document: {}".format(document_id), fg="green")) + logging.info(click.style(f"Start process document: {document_id}", fg="green")) document = ( db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() @@ -77,10 +77,10 @@ def document_indexing_task(dataset_id: str, document_ids: list): indexing_runner = IndexingRunner() indexing_runner.run(documents) end_at = time.perf_counter() - logging.info(click.style("Processed dataset: {} latency: {}".format(dataset_id, end_at - start_at), fg="green")) + logging.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green")) except DocumentIsPausedError as ex: logging.info(click.style(str(ex), fg="yellow")) except Exception: - logging.exception("Document indexing task failed, dataset_id: {}".format(dataset_id)) + logging.exception("Document indexing task failed, dataset_id: %s", dataset_id) finally: db.session.close() diff --git a/api/tasks/document_indexing_update_task.py b/api/tasks/document_indexing_update_task.py index e53c38ddc3..31bbc8b570 100644 --- a/api/tasks/document_indexing_update_task.py +++ b/api/tasks/document_indexing_update_task.py @@ -1,13 +1,13 @@ -import datetime import logging import time import click -from celery import shared_task # type: ignore +from celery import shared_task from core.indexing_runner import DocumentIsPausedError, IndexingRunner from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from extensions.ext_database import db +from libs.datetime_utils import naive_utc_now from models.dataset import Dataset, Document, DocumentSegment @@ -20,18 +20,18 @@ def document_indexing_update_task(dataset_id: str, document_id: str): Usage: document_indexing_update_task.delay(dataset_id, document_id) """ - logging.info(click.style("Start update document: {}".format(document_id), fg="green")) + logging.info(click.style(f"Start update document: {document_id}", fg="green")) start_at = time.perf_counter() document = db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() if not document: - logging.info(click.style("Document not found: {}".format(document_id), fg="red")) + logging.info(click.style(f"Document not found: {document_id}", fg="red")) db.session.close() return document.indexing_status = "parsing" - document.processing_started_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + document.processing_started_at = naive_utc_now() db.session.commit() # delete all document segment and index @@ -69,10 +69,10 @@ def document_indexing_update_task(dataset_id: str, document_id: str): indexing_runner = IndexingRunner() indexing_runner.run([document]) end_at = time.perf_counter() - logging.info(click.style("update document: {} latency: {}".format(document.id, end_at - start_at), fg="green")) + logging.info(click.style(f"update document: {document.id} latency: {end_at - start_at}", fg="green")) except DocumentIsPausedError as ex: logging.info(click.style(str(ex), fg="yellow")) except Exception: - logging.exception("document_indexing_update_task failed, document_id: {}".format(document_id)) + logging.exception("document_indexing_update_task failed, document_id: %s", document_id) finally: db.session.close() diff --git a/api/tasks/duplicate_document_indexing_task.py b/api/tasks/duplicate_document_indexing_task.py index b3ddface59..f3850b7e3b 100644 --- a/api/tasks/duplicate_document_indexing_task.py +++ b/api/tasks/duplicate_document_indexing_task.py @@ -1,14 +1,14 @@ -import datetime import logging import time import click -from celery import shared_task # type: ignore +from celery import shared_task from configs import dify_config from core.indexing_runner import DocumentIsPausedError, IndexingRunner from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from extensions.ext_database import db +from libs.datetime_utils import naive_utc_now from models.dataset import Dataset, Document, DocumentSegment from services.feature_service import FeatureService @@ -27,7 +27,7 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list): dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first() if dataset is None: - logging.info(click.style("Dataset not found: {}".format(dataset_id), fg="red")) + logging.info(click.style(f"Dataset not found: {dataset_id}", fg="red")) db.session.close() return @@ -55,7 +55,7 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list): if document: document.indexing_status = "error" document.error = str(e) - document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + document.stopped_at = naive_utc_now() db.session.add(document) db.session.commit() return @@ -63,7 +63,7 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list): db.session.close() for document_id in document_ids: - logging.info(click.style("Start process document: {}".format(document_id), fg="green")) + logging.info(click.style(f"Start process document: {document_id}", fg="green")) document = ( db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() @@ -86,7 +86,7 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list): db.session.commit() document.indexing_status = "parsing" - document.processing_started_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + document.processing_started_at = naive_utc_now() documents.append(document) db.session.add(document) db.session.commit() @@ -95,10 +95,10 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list): indexing_runner = IndexingRunner() indexing_runner.run(documents) end_at = time.perf_counter() - logging.info(click.style("Processed dataset: {} latency: {}".format(dataset_id, end_at - start_at), fg="green")) + logging.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green")) except DocumentIsPausedError as ex: logging.info(click.style(str(ex), fg="yellow")) except Exception: - logging.exception("duplicate_document_indexing_task failed, dataset_id: {}".format(dataset_id)) + logging.exception("duplicate_document_indexing_task failed, dataset_id: %s", dataset_id) finally: db.session.close() diff --git a/api/tasks/enable_segment_to_index_task.py b/api/tasks/enable_segment_to_index_task.py index 13822f078e..a4bcc043e3 100644 --- a/api/tasks/enable_segment_to_index_task.py +++ b/api/tasks/enable_segment_to_index_task.py @@ -1,15 +1,15 @@ -import datetime import logging import time import click -from celery import shared_task # type: ignore +from celery import shared_task from core.rag.index_processor.constant.index_type import IndexType from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.rag.models.document import ChildDocument, Document from extensions.ext_database import db from extensions.ext_redis import redis_client +from libs.datetime_utils import naive_utc_now from models.dataset import DocumentSegment @@ -21,21 +21,21 @@ def enable_segment_to_index_task(segment_id: str): Usage: enable_segment_to_index_task.delay(segment_id) """ - logging.info(click.style("Start enable segment to index: {}".format(segment_id), fg="green")) + logging.info(click.style(f"Start enable segment to index: {segment_id}", fg="green")) start_at = time.perf_counter() segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment_id).first() if not segment: - logging.info(click.style("Segment not found: {}".format(segment_id), fg="red")) + logging.info(click.style(f"Segment not found: {segment_id}", fg="red")) db.session.close() return if segment.status != "completed": - logging.info(click.style("Segment is not completed, enable is not allowed: {}".format(segment_id), fg="red")) + logging.info(click.style(f"Segment is not completed, enable is not allowed: {segment_id}", fg="red")) db.session.close() return - indexing_cache_key = "segment_{}_indexing".format(segment.id) + indexing_cache_key = f"segment_{segment.id}_indexing" try: document = Document( @@ -51,17 +51,17 @@ def enable_segment_to_index_task(segment_id: str): dataset = segment.dataset if not dataset: - logging.info(click.style("Segment {} has no dataset, pass.".format(segment.id), fg="cyan")) + logging.info(click.style(f"Segment {segment.id} has no dataset, pass.", fg="cyan")) return dataset_document = segment.document if not dataset_document: - logging.info(click.style("Segment {} has no document, pass.".format(segment.id), fg="cyan")) + logging.info(click.style(f"Segment {segment.id} has no document, pass.", fg="cyan")) return if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed": - logging.info(click.style("Segment {} document status is invalid, pass.".format(segment.id), fg="cyan")) + logging.info(click.style(f"Segment {segment.id} document status is invalid, pass.", fg="cyan")) return index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor() @@ -85,13 +85,11 @@ def enable_segment_to_index_task(segment_id: str): index_processor.load(dataset, [document]) end_at = time.perf_counter() - logging.info( - click.style("Segment enabled to index: {} latency: {}".format(segment.id, end_at - start_at), fg="green") - ) + logging.info(click.style(f"Segment enabled to index: {segment.id} latency: {end_at - start_at}", fg="green")) except Exception as e: logging.exception("enable segment to index failed") segment.enabled = False - segment.disabled_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + segment.disabled_at = naive_utc_now() segment.status = "error" segment.error = str(e) db.session.commit() diff --git a/api/tasks/enable_segments_to_index_task.py b/api/tasks/enable_segments_to_index_task.py index e3fdf04d8c..1db984f0d3 100644 --- a/api/tasks/enable_segments_to_index_task.py +++ b/api/tasks/enable_segments_to_index_task.py @@ -1,15 +1,15 @@ -import datetime import logging import time import click -from celery import shared_task # type: ignore +from celery import shared_task from core.rag.index_processor.constant.index_type import IndexType from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.rag.models.document import ChildDocument, Document from extensions.ext_database import db from extensions.ext_redis import redis_client +from libs.datetime_utils import naive_utc_now from models.dataset import Dataset, DocumentSegment from models.dataset import Document as DatasetDocument @@ -27,17 +27,17 @@ def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_i start_at = time.perf_counter() dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first() if not dataset: - logging.info(click.style("Dataset {} not found, pass.".format(dataset_id), fg="cyan")) + logging.info(click.style(f"Dataset {dataset_id} not found, pass.", fg="cyan")) return dataset_document = db.session.query(DatasetDocument).where(DatasetDocument.id == document_id).first() if not dataset_document: - logging.info(click.style("Document {} not found, pass.".format(document_id), fg="cyan")) + logging.info(click.style(f"Document {document_id} not found, pass.", fg="cyan")) db.session.close() return if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed": - logging.info(click.style("Document {} status is invalid, pass.".format(document_id), fg="cyan")) + logging.info(click.style(f"Document {document_id} status is invalid, pass.", fg="cyan")) db.session.close() return # sync index processor @@ -53,7 +53,7 @@ def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_i .all() ) if not segments: - logging.info(click.style("Segments not found: {}".format(segment_ids), fg="cyan")) + logging.info(click.style(f"Segments not found: {segment_ids}", fg="cyan")) db.session.close() return @@ -91,7 +91,7 @@ def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_i index_processor.load(dataset, documents) end_at = time.perf_counter() - logging.info(click.style("Segments enabled to index latency: {}".format(end_at - start_at), fg="green")) + logging.info(click.style(f"Segments enabled to index latency: {end_at - start_at}", fg="green")) except Exception as e: logging.exception("enable segments to index failed") # update segment error msg @@ -103,13 +103,13 @@ def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_i { "error": str(e), "status": "error", - "disabled_at": datetime.datetime.now(datetime.UTC).replace(tzinfo=None), + "disabled_at": naive_utc_now(), "enabled": False, } ) db.session.commit() finally: for segment in segments: - indexing_cache_key = "segment_{}_indexing".format(segment.id) + indexing_cache_key = f"segment_{segment.id}_indexing" redis_client.delete(indexing_cache_key) db.session.close() diff --git a/api/tasks/mail_account_deletion_task.py b/api/tasks/mail_account_deletion_task.py index a6f8ce2f0b..43ddbfc03b 100644 --- a/api/tasks/mail_account_deletion_task.py +++ b/api/tasks/mail_account_deletion_task.py @@ -2,7 +2,7 @@ import logging import time import click -from celery import shared_task # type: ignore +from celery import shared_task from extensions.ext_mail import mail from libs.email_i18n import EmailType, get_email_i18n_service @@ -37,12 +37,10 @@ def send_deletion_success_task(to: str, language: str = "en-US") -> None: end_at = time.perf_counter() logging.info( - click.style( - "Send account deletion success email to {}: latency: {}".format(to, end_at - start_at), fg="green" - ) + click.style(f"Send account deletion success email to {to}: latency: {end_at - start_at}", fg="green") ) except Exception: - logging.exception("Send account deletion success email to {} failed".format(to)) + logging.exception("Send account deletion success email to %s failed", to) @shared_task(queue="mail") @@ -83,4 +81,4 @@ def send_account_deletion_verification_code(to: str, code: str, language: str = ) ) except Exception: - logging.exception("Send account deletion verification code email to {} failed".format(to)) + logging.exception("Send account deletion verification code email to %s failed", to) diff --git a/api/tasks/mail_change_mail_task.py b/api/tasks/mail_change_mail_task.py index 6334fb22de..a56109705a 100644 --- a/api/tasks/mail_change_mail_task.py +++ b/api/tasks/mail_change_mail_task.py @@ -2,7 +2,7 @@ import logging import time import click -from celery import shared_task # type: ignore +from celery import shared_task from extensions.ext_mail import mail from libs.email_i18n import EmailType, get_email_i18n_service @@ -22,7 +22,7 @@ def send_change_mail_task(language: str, to: str, code: str, phase: str) -> None if not mail.is_inited(): return - logging.info(click.style("Start change email mail to {}".format(to), fg="green")) + logging.info(click.style(f"Start change email mail to {to}", fg="green")) start_at = time.perf_counter() try: @@ -35,11 +35,9 @@ def send_change_mail_task(language: str, to: str, code: str, phase: str) -> None ) end_at = time.perf_counter() - logging.info( - click.style("Send change email mail to {} succeeded: latency: {}".format(to, end_at - start_at), fg="green") - ) + logging.info(click.style(f"Send change email mail to {to} succeeded: latency: {end_at - start_at}", fg="green")) except Exception: - logging.exception("Send change email mail to {} failed".format(to)) + logging.exception("Send change email mail to %s failed", to) @shared_task(queue="mail") @@ -54,7 +52,7 @@ def send_change_mail_completed_notification_task(language: str, to: str) -> None if not mail.is_inited(): return - logging.info(click.style("Start change email completed notify mail to {}".format(to), fg="green")) + logging.info(click.style(f"Start change email completed notify mail to {to}", fg="green")) start_at = time.perf_counter() try: @@ -72,9 +70,9 @@ def send_change_mail_completed_notification_task(language: str, to: str) -> None end_at = time.perf_counter() logging.info( click.style( - "Send change email completed mail to {} succeeded: latency: {}".format(to, end_at - start_at), + f"Send change email completed mail to {to} succeeded: latency: {end_at - start_at}", fg="green", ) ) except Exception: - logging.exception("Send change email completed mail to {} failed".format(to)) + logging.exception("Send change email completed mail to %s failed", to) diff --git a/api/tasks/mail_email_code_login.py b/api/tasks/mail_email_code_login.py index 34220784e9..53ea3709cd 100644 --- a/api/tasks/mail_email_code_login.py +++ b/api/tasks/mail_email_code_login.py @@ -2,7 +2,7 @@ import logging import time import click -from celery import shared_task # type: ignore +from celery import shared_task from extensions.ext_mail import mail from libs.email_i18n import EmailType, get_email_i18n_service @@ -21,7 +21,7 @@ def send_email_code_login_mail_task(language: str, to: str, code: str) -> None: if not mail.is_inited(): return - logging.info(click.style("Start email code login mail to {}".format(to), fg="green")) + logging.info(click.style(f"Start email code login mail to {to}", fg="green")) start_at = time.perf_counter() try: @@ -38,9 +38,7 @@ def send_email_code_login_mail_task(language: str, to: str, code: str) -> None: end_at = time.perf_counter() logging.info( - click.style( - "Send email code login mail to {} succeeded: latency: {}".format(to, end_at - start_at), fg="green" - ) + click.style(f"Send email code login mail to {to} succeeded: latency: {end_at - start_at}", fg="green") ) except Exception: - logging.exception("Send email code login mail to {} failed".format(to)) + logging.exception("Send email code login mail to %s failed", to) diff --git a/api/tasks/mail_enterprise_task.py b/api/tasks/mail_inner_task.py similarity index 54% rename from api/tasks/mail_enterprise_task.py rename to api/tasks/mail_inner_task.py index a1c2908624..cad4657bc8 100644 --- a/api/tasks/mail_enterprise_task.py +++ b/api/tasks/mail_inner_task.py @@ -3,7 +3,7 @@ import time from collections.abc import Mapping import click -from celery import shared_task # type: ignore +from celery import shared_task from flask import render_template_string from extensions.ext_mail import mail @@ -11,11 +11,11 @@ from libs.email_i18n import get_email_i18n_service @shared_task(queue="mail") -def send_enterprise_email_task(to: list[str], subject: str, body: str, substitutions: Mapping[str, str]): +def send_inner_email_task(to: list[str], subject: str, body: str, substitutions: Mapping[str, str]): if not mail.is_inited(): return - logging.info(click.style("Start enterprise mail to {} with subject {}".format(to, subject), fg="green")) + logging.info(click.style(f"Start enterprise mail to {to} with subject {subject}", fg="green")) start_at = time.perf_counter() try: @@ -25,8 +25,6 @@ def send_enterprise_email_task(to: list[str], subject: str, body: str, substitut email_service.send_raw_email(to=to, subject=subject, html_content=html_content) end_at = time.perf_counter() - logging.info( - click.style("Send enterprise mail to {} succeeded: latency: {}".format(to, end_at - start_at), fg="green") - ) + logging.info(click.style(f"Send enterprise mail to {to} succeeded: latency: {end_at - start_at}", fg="green")) except Exception: - logging.exception("Send enterprise mail to {} failed".format(to)) + logging.exception("Send enterprise mail to %s failed", to) diff --git a/api/tasks/mail_invite_member_task.py b/api/tasks/mail_invite_member_task.py index 8c73de0111..f4f7f58416 100644 --- a/api/tasks/mail_invite_member_task.py +++ b/api/tasks/mail_invite_member_task.py @@ -2,7 +2,7 @@ import logging import time import click -from celery import shared_task # type: ignore +from celery import shared_task from configs import dify_config from extensions.ext_mail import mail @@ -24,9 +24,7 @@ def send_invite_member_mail_task(language: str, to: str, token: str, inviter_nam if not mail.is_inited(): return - logging.info( - click.style("Start send invite member mail to {} in workspace {}".format(to, workspace_name), fg="green") - ) + logging.info(click.style(f"Start send invite member mail to {to} in workspace {workspace_name}", fg="green")) start_at = time.perf_counter() try: @@ -46,9 +44,7 @@ def send_invite_member_mail_task(language: str, to: str, token: str, inviter_nam end_at = time.perf_counter() logging.info( - click.style( - "Send invite member mail to {} succeeded: latency: {}".format(to, end_at - start_at), fg="green" - ) + click.style(f"Send invite member mail to {to} succeeded: latency: {end_at - start_at}", fg="green") ) except Exception: - logging.exception("Send invite member mail to {} failed".format(to)) + logging.exception("Send invite member mail to %s failed", to) diff --git a/api/tasks/mail_owner_transfer_task.py b/api/tasks/mail_owner_transfer_task.py index e566a6bc56..db7158e786 100644 --- a/api/tasks/mail_owner_transfer_task.py +++ b/api/tasks/mail_owner_transfer_task.py @@ -2,7 +2,7 @@ import logging import time import click -from celery import shared_task # type: ignore +from celery import shared_task from extensions.ext_mail import mail from libs.email_i18n import EmailType, get_email_i18n_service @@ -22,7 +22,7 @@ def send_owner_transfer_confirm_task(language: str, to: str, code: str, workspac if not mail.is_inited(): return - logging.info(click.style("Start owner transfer confirm mail to {}".format(to), fg="green")) + logging.info(click.style(f"Start owner transfer confirm mail to {to}", fg="green")) start_at = time.perf_counter() try: @@ -41,12 +41,12 @@ def send_owner_transfer_confirm_task(language: str, to: str, code: str, workspac end_at = time.perf_counter() logging.info( click.style( - "Send owner transfer confirm mail to {} succeeded: latency: {}".format(to, end_at - start_at), + f"Send owner transfer confirm mail to {to} succeeded: latency: {end_at - start_at}", fg="green", ) ) except Exception: - logging.exception("owner transfer confirm email mail to {} failed".format(to)) + logging.exception("owner transfer confirm email mail to %s failed", to) @shared_task(queue="mail") @@ -63,7 +63,7 @@ def send_old_owner_transfer_notify_email_task(language: str, to: str, workspace: if not mail.is_inited(): return - logging.info(click.style("Start old owner transfer notify mail to {}".format(to), fg="green")) + logging.info(click.style(f"Start old owner transfer notify mail to {to}", fg="green")) start_at = time.perf_counter() try: @@ -82,12 +82,12 @@ def send_old_owner_transfer_notify_email_task(language: str, to: str, workspace: end_at = time.perf_counter() logging.info( click.style( - "Send old owner transfer notify mail to {} succeeded: latency: {}".format(to, end_at - start_at), + f"Send old owner transfer notify mail to {to} succeeded: latency: {end_at - start_at}", fg="green", ) ) except Exception: - logging.exception("old owner transfer notify email mail to {} failed".format(to)) + logging.exception("old owner transfer notify email mail to %s failed", to) @shared_task(queue="mail") @@ -103,7 +103,7 @@ def send_new_owner_transfer_notify_email_task(language: str, to: str, workspace: if not mail.is_inited(): return - logging.info(click.style("Start new owner transfer notify mail to {}".format(to), fg="green")) + logging.info(click.style(f"Start new owner transfer notify mail to {to}", fg="green")) start_at = time.perf_counter() try: @@ -121,9 +121,9 @@ def send_new_owner_transfer_notify_email_task(language: str, to: str, workspace: end_at = time.perf_counter() logging.info( click.style( - "Send new owner transfer notify mail to {} succeeded: latency: {}".format(to, end_at - start_at), + f"Send new owner transfer notify mail to {to} succeeded: latency: {end_at - start_at}", fg="green", ) ) except Exception: - logging.exception("new owner transfer notify email mail to {} failed".format(to)) + logging.exception("new owner transfer notify email mail to %s failed", to) diff --git a/api/tasks/mail_reset_password_task.py b/api/tasks/mail_reset_password_task.py index e2482f2101..066d648530 100644 --- a/api/tasks/mail_reset_password_task.py +++ b/api/tasks/mail_reset_password_task.py @@ -2,7 +2,7 @@ import logging import time import click -from celery import shared_task # type: ignore +from celery import shared_task from extensions.ext_mail import mail from libs.email_i18n import EmailType, get_email_i18n_service @@ -21,7 +21,7 @@ def send_reset_password_mail_task(language: str, to: str, code: str) -> None: if not mail.is_inited(): return - logging.info(click.style("Start password reset mail to {}".format(to), fg="green")) + logging.info(click.style(f"Start password reset mail to {to}", fg="green")) start_at = time.perf_counter() try: @@ -38,9 +38,7 @@ def send_reset_password_mail_task(language: str, to: str, code: str) -> None: end_at = time.perf_counter() logging.info( - click.style( - "Send password reset mail to {} succeeded: latency: {}".format(to, end_at - start_at), fg="green" - ) + click.style(f"Send password reset mail to {to} succeeded: latency: {end_at - start_at}", fg="green") ) except Exception: - logging.exception("Send password reset mail to {} failed".format(to)) + logging.exception("Send password reset mail to %s failed", to) diff --git a/api/tasks/ops_trace_task.py b/api/tasks/ops_trace_task.py index 2e77332ffe..a4ef60b13c 100644 --- a/api/tasks/ops_trace_task.py +++ b/api/tasks/ops_trace_task.py @@ -1,7 +1,7 @@ import json import logging -from celery import shared_task # type: ignore +from celery import shared_task from flask import current_app from core.ops.entities.config_entity import OPS_FILE_PATH, OPS_TRACE_FAILED_KEY @@ -43,13 +43,11 @@ def process_trace_tasks(file_info): if trace_type: trace_info = trace_type(**trace_info) trace_instance.trace(trace_info) - logging.info(f"Processing trace tasks success, app_id: {app_id}") + logging.info("Processing trace tasks success, app_id: %s", app_id) except Exception as e: - logging.info( - f"error:\n\n\n{e}\n\n\n\n", - ) + logging.info("error:\n\n\n%s\n\n\n\n", e) failed_key = f"{OPS_TRACE_FAILED_KEY}_{app_id}" redis_client.incr(failed_key) - logging.info(f"Processing trace tasks failed, app_id: {app_id}") + logging.info("Processing trace tasks failed, app_id: %s", app_id) finally: storage.delete(file_path) diff --git a/api/tasks/process_tenant_plugin_autoupgrade_check_task.py b/api/tasks/process_tenant_plugin_autoupgrade_check_task.py index 6fcdad0525..ec0b534546 100644 --- a/api/tasks/process_tenant_plugin_autoupgrade_check_task.py +++ b/api/tasks/process_tenant_plugin_autoupgrade_check_task.py @@ -2,7 +2,7 @@ import traceback import typing import click -from celery import shared_task # type: ignore +from celery import shared_task from core.helper import marketplace from core.helper.marketplace import MarketplacePluginDeclaration @@ -58,7 +58,7 @@ def process_tenant_plugin_autoupgrade_check_task( click.echo( click.style( - "Checking upgradable plugin for tenant: {}".format(tenant_id), + f"Checking upgradable plugin for tenant: {tenant_id}", fg="green", ) ) @@ -68,7 +68,7 @@ def process_tenant_plugin_autoupgrade_check_task( # get plugin_ids to check plugin_ids: list[tuple[str, str, str]] = [] # plugin_id, version, unique_identifier - click.echo(click.style("Upgrade mode: {}".format(upgrade_mode), fg="green")) + click.echo(click.style(f"Upgrade mode: {upgrade_mode}", fg="green")) if upgrade_mode == TenantPluginAutoUpgradeStrategy.UpgradeMode.PARTIAL and include_plugins: all_plugins = manager.list_plugins(tenant_id) @@ -142,7 +142,7 @@ def process_tenant_plugin_autoupgrade_check_task( marketplace.record_install_plugin_event(new_unique_identifier) click.echo( click.style( - "Upgrade plugin: {} -> {}".format(original_unique_identifier, new_unique_identifier), + f"Upgrade plugin: {original_unique_identifier} -> {new_unique_identifier}", fg="green", ) ) @@ -156,11 +156,11 @@ def process_tenant_plugin_autoupgrade_check_task( }, ) except Exception as e: - click.echo(click.style("Error when upgrading plugin: {}".format(e), fg="red")) + click.echo(click.style(f"Error when upgrading plugin: {e}", fg="red")) traceback.print_exc() break except Exception as e: - click.echo(click.style("Error when checking upgradable plugin: {}".format(e), fg="red")) + click.echo(click.style(f"Error when checking upgradable plugin: {e}", fg="red")) traceback.print_exc() return diff --git a/api/tasks/recover_document_indexing_task.py b/api/tasks/recover_document_indexing_task.py index dfb2389579..998fc6b32d 100644 --- a/api/tasks/recover_document_indexing_task.py +++ b/api/tasks/recover_document_indexing_task.py @@ -2,7 +2,7 @@ import logging import time import click -from celery import shared_task # type: ignore +from celery import shared_task from core.indexing_runner import DocumentIsPausedError, IndexingRunner from extensions.ext_database import db @@ -18,13 +18,13 @@ def recover_document_indexing_task(dataset_id: str, document_id: str): Usage: recover_document_indexing_task.delay(dataset_id, document_id) """ - logging.info(click.style("Recover document: {}".format(document_id), fg="green")) + logging.info(click.style(f"Recover document: {document_id}", fg="green")) start_at = time.perf_counter() document = db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() if not document: - logging.info(click.style("Document not found: {}".format(document_id), fg="red")) + logging.info(click.style(f"Document not found: {document_id}", fg="red")) db.session.close() return @@ -37,12 +37,10 @@ def recover_document_indexing_task(dataset_id: str, document_id: str): elif document.indexing_status == "indexing": indexing_runner.run_in_indexing_status(document) end_at = time.perf_counter() - logging.info( - click.style("Processed document: {} latency: {}".format(document.id, end_at - start_at), fg="green") - ) + logging.info(click.style(f"Processed document: {document.id} latency: {end_at - start_at}", fg="green")) except DocumentIsPausedError as ex: logging.info(click.style(str(ex), fg="yellow")) except Exception: - logging.exception("recover_document_indexing_task failed, document_id: {}".format(document_id)) + logging.exception("recover_document_indexing_task failed, document_id: %s", document_id) finally: db.session.close() diff --git a/api/tasks/remove_app_and_related_data_task.py b/api/tasks/remove_app_and_related_data_task.py index 1619f8c546..3d623c09d1 100644 --- a/api/tasks/remove_app_and_related_data_task.py +++ b/api/tasks/remove_app_and_related_data_task.py @@ -3,7 +3,8 @@ import time from collections.abc import Callable import click -from celery import shared_task # type: ignore +import sqlalchemy as sa +from celery import shared_task from sqlalchemy import delete from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.orm import sessionmaker @@ -32,7 +33,11 @@ from models import ( ) from models.tools import WorkflowToolProvider from models.web import PinnedConversation, SavedMessage -from models.workflow import ConversationVariable, Workflow, WorkflowAppLog +from models.workflow import ( + ConversationVariable, + Workflow, + WorkflowAppLog, +) from repositories.factory import DifyAPIRepositoryFactory @@ -61,6 +66,7 @@ def remove_app_and_related_data_task(self, tenant_id: str, app_id: str): _delete_end_users(tenant_id, app_id) _delete_trace_app_configs(tenant_id, app_id) _delete_conversation_variables(app_id=app_id) + _delete_draft_variables(app_id) end_at = time.perf_counter() logging.info(click.style(f"App and related data deleted: {app_id} latency: {end_at - start_at}", fg="green")) @@ -90,7 +96,12 @@ def _delete_app_site(tenant_id: str, app_id: str): def del_site(site_id: str): db.session.query(Site).where(Site.id == site_id).delete(synchronize_session=False) - _delete_records("""select id from sites where app_id=:app_id limit 1000""", {"app_id": app_id}, del_site, "site") + _delete_records( + """select id from sites where app_id=:app_id limit 1000""", + {"app_id": app_id}, + del_site, + "site", + ) def _delete_app_mcp_servers(tenant_id: str, app_id: str): @@ -110,7 +121,10 @@ def _delete_app_api_tokens(tenant_id: str, app_id: str): db.session.query(ApiToken).where(ApiToken.id == api_token_id).delete(synchronize_session=False) _delete_records( - """select id from api_tokens where app_id=:app_id limit 1000""", {"app_id": app_id}, del_api_token, "api token" + """select id from api_tokens where app_id=:app_id limit 1000""", + {"app_id": app_id}, + del_api_token, + "api token", ) @@ -201,7 +215,7 @@ def _delete_app_workflow_runs(tenant_id: str, app_id: str): batch_size=1000, ) - logging.info(f"Deleted {deleted_count} workflow runs for app {app_id}") + logging.info("Deleted %s workflow runs for app %s", deleted_count, app_id) def _delete_app_workflow_node_executions(tenant_id: str, app_id: str): @@ -215,7 +229,7 @@ def _delete_app_workflow_node_executions(tenant_id: str, app_id: str): batch_size=1000, ) - logging.info(f"Deleted {deleted_count} workflow node executions for app {app_id}") + logging.info("Deleted %s workflow node executions for app %s", deleted_count, app_id) def _delete_app_workflow_app_logs(tenant_id: str, app_id: str): @@ -272,7 +286,10 @@ def _delete_app_messages(tenant_id: str, app_id: str): db.session.query(Message).where(Message.id == message_id).delete() _delete_records( - """select id from messages where app_id=:app_id limit 1000""", {"app_id": app_id}, del_message, "message" + """select id from messages where app_id=:app_id limit 1000""", + {"app_id": app_id}, + del_message, + "message", ) @@ -328,10 +345,60 @@ def _delete_trace_app_configs(tenant_id: str, app_id: str): ) +def _delete_draft_variables(app_id: str): + """Delete all workflow draft variables for an app in batches.""" + return delete_draft_variables_batch(app_id, batch_size=1000) + + +def delete_draft_variables_batch(app_id: str, batch_size: int = 1000) -> int: + """ + Delete draft variables for an app in batches. + + Args: + app_id: The ID of the app whose draft variables should be deleted + batch_size: Number of records to delete per batch + + Returns: + Total number of records deleted + """ + if batch_size <= 0: + raise ValueError("batch_size must be positive") + + total_deleted = 0 + + while True: + with db.engine.begin() as conn: + # Get a batch of draft variable IDs + query_sql = """ + SELECT id FROM workflow_draft_variables + WHERE app_id = :app_id + LIMIT :batch_size + """ + result = conn.execute(sa.text(query_sql), {"app_id": app_id, "batch_size": batch_size}) + + draft_var_ids = [row[0] for row in result] + if not draft_var_ids: + break + + # Delete the batch + delete_sql = """ + DELETE FROM workflow_draft_variables + WHERE id IN :ids + """ + deleted_result = conn.execute(sa.text(delete_sql), {"ids": tuple(draft_var_ids)}) + batch_deleted = deleted_result.rowcount + total_deleted += batch_deleted + + logging.info(click.style(f"Deleted {batch_deleted} draft variables (batch) for app {app_id}", fg="green")) + + logging.info(click.style(f"Deleted {total_deleted} total draft variables for app {app_id}", fg="green")) + return total_deleted + + def _delete_records(query_sql: str, params: dict, delete_func: Callable, name: str) -> None: while True: with db.engine.begin() as conn: - rs = conn.execute(db.text(query_sql), params) + rs = conn.execute(sa.text(query_sql), params) if rs.rowcount == 0: break @@ -342,6 +409,6 @@ def _delete_records(query_sql: str, params: dict, delete_func: Callable, name: s db.session.commit() logging.info(click.style(f"Deleted {name} {record_id}", fg="green")) except Exception: - logging.exception(f"Error occurred while deleting {name} {record_id}") + logging.exception("Error occurred while deleting %s %s", name, record_id) continue rs.close() diff --git a/api/tasks/remove_document_from_index_task.py b/api/tasks/remove_document_from_index_task.py index 3f73cc7b40..6356b1c46c 100644 --- a/api/tasks/remove_document_from_index_task.py +++ b/api/tasks/remove_document_from_index_task.py @@ -1,13 +1,13 @@ -import datetime import logging import time import click -from celery import shared_task # type: ignore +from celery import shared_task from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from extensions.ext_database import db from extensions.ext_redis import redis_client +from libs.datetime_utils import naive_utc_now from models.dataset import Document, DocumentSegment @@ -19,21 +19,21 @@ def remove_document_from_index_task(document_id: str): Usage: remove_document_from_index.delay(document_id) """ - logging.info(click.style("Start remove document segments from index: {}".format(document_id), fg="green")) + logging.info(click.style(f"Start remove document segments from index: {document_id}", fg="green")) start_at = time.perf_counter() document = db.session.query(Document).where(Document.id == document_id).first() if not document: - logging.info(click.style("Document not found: {}".format(document_id), fg="red")) + logging.info(click.style(f"Document not found: {document_id}", fg="red")) db.session.close() return if document.indexing_status != "completed": - logging.info(click.style("Document is not completed, remove is not allowed: {}".format(document_id), fg="red")) + logging.info(click.style(f"Document is not completed, remove is not allowed: {document_id}", fg="red")) db.session.close() return - indexing_cache_key = "document_{}_indexing".format(document.id) + indexing_cache_key = f"document_{document.id}_indexing" try: dataset = document.dataset @@ -49,23 +49,21 @@ def remove_document_from_index_task(document_id: str): try: index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=False) except Exception: - logging.exception(f"clean dataset {dataset.id} from index failed") + logging.exception("clean dataset %s from index failed", dataset.id) # update segment to disable db.session.query(DocumentSegment).where(DocumentSegment.document_id == document.id).update( { DocumentSegment.enabled: False, - DocumentSegment.disabled_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None), + DocumentSegment.disabled_at: naive_utc_now(), DocumentSegment.disabled_by: document.disabled_by, - DocumentSegment.updated_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None), + DocumentSegment.updated_at: naive_utc_now(), } ) db.session.commit() end_at = time.perf_counter() logging.info( - click.style( - "Document removed from index: {} latency: {}".format(document.id, end_at - start_at), fg="green" - ) + click.style(f"Document removed from index: {document.id} latency: {end_at - start_at}", fg="green") ) except Exception: logging.exception("remove document from index failed") diff --git a/api/tasks/retry_document_indexing_task.py b/api/tasks/retry_document_indexing_task.py index 58f0156afb..67af857f40 100644 --- a/api/tasks/retry_document_indexing_task.py +++ b/api/tasks/retry_document_indexing_task.py @@ -1,14 +1,14 @@ -import datetime import logging import time import click -from celery import shared_task # type: ignore +from celery import shared_task from core.indexing_runner import IndexingRunner from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from extensions.ext_database import db from extensions.ext_redis import redis_client +from libs.datetime_utils import naive_utc_now from models.dataset import Dataset, Document, DocumentSegment from services.feature_service import FeatureService @@ -24,79 +24,83 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]): """ documents: list[Document] = [] start_at = time.perf_counter() + try: + dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first() + if not dataset: + logging.info(click.style(f"Dataset not found: {dataset_id}", fg="red")) + return + tenant_id = dataset.tenant_id + for document_id in document_ids: + retry_indexing_cache_key = f"document_{document_id}_is_retried" + # check document limit + features = FeatureService.get_features(tenant_id) + try: + if features.billing.enabled: + vector_space = features.vector_space + if 0 < vector_space.limit <= vector_space.size: + raise ValueError( + "Your total number of documents plus the number of uploads have over the limit of " + "your subscription." + ) + except Exception as e: + document = ( + db.session.query(Document) + .where(Document.id == document_id, Document.dataset_id == dataset_id) + .first() + ) + if document: + document.indexing_status = "error" + document.error = str(e) + document.stopped_at = naive_utc_now() + db.session.add(document) + db.session.commit() + redis_client.delete(retry_indexing_cache_key) + return - dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first() - if not dataset: - logging.info(click.style("Dataset not found: {}".format(dataset_id), fg="red")) - db.session.close() - return - tenant_id = dataset.tenant_id - for document_id in document_ids: - retry_indexing_cache_key = "document_{}_is_retried".format(document_id) - # check document limit - features = FeatureService.get_features(tenant_id) - try: - if features.billing.enabled: - vector_space = features.vector_space - if 0 < vector_space.limit <= vector_space.size: - raise ValueError( - "Your total number of documents plus the number of uploads have over the limit of " - "your subscription." - ) - except Exception as e: + logging.info(click.style(f"Start retry document: {document_id}", fg="green")) document = ( db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() ) - if document: - document.indexing_status = "error" - document.error = str(e) - document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + if not document: + logging.info(click.style(f"Document not found: {document_id}", fg="yellow")) + return + try: + # clean old data + index_processor = IndexProcessorFactory(document.doc_form).init_index_processor() + + segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all() + if segments: + index_node_ids = [segment.index_node_id for segment in segments] + # delete from vector index + index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) + + for segment in segments: + db.session.delete(segment) + db.session.commit() + + document.indexing_status = "parsing" + document.processing_started_at = naive_utc_now() db.session.add(document) db.session.commit() - redis_client.delete(retry_indexing_cache_key) - db.session.close() - return - logging.info(click.style("Start retry document: {}".format(document_id), fg="green")) - document = ( - db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() + indexing_runner = IndexingRunner() + indexing_runner.run([document]) + redis_client.delete(retry_indexing_cache_key) + except Exception as ex: + document.indexing_status = "error" + document.error = str(ex) + document.stopped_at = naive_utc_now() + db.session.add(document) + db.session.commit() + logging.info(click.style(str(ex), fg="yellow")) + redis_client.delete(retry_indexing_cache_key) + logging.exception("retry_document_indexing_task failed, document_id: %s", document_id) + end_at = time.perf_counter() + logging.info(click.style(f"Retry dataset: {dataset_id} latency: {end_at - start_at}", fg="green")) + except Exception as e: + logging.exception( + "retry_document_indexing_task failed, dataset_id: %s, document_ids: %s", dataset_id, document_ids ) - if not document: - logging.info(click.style("Document not found: {}".format(document_id), fg="yellow")) - db.session.close() - return - try: - # clean old data - index_processor = IndexProcessorFactory(document.doc_form).init_index_processor() - - segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all() - if segments: - index_node_ids = [segment.index_node_id for segment in segments] - # delete from vector index - index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) - - for segment in segments: - db.session.delete(segment) - db.session.commit() - - document.indexing_status = "parsing" - document.processing_started_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) - db.session.add(document) - db.session.commit() - - indexing_runner = IndexingRunner() - indexing_runner.run([document]) - redis_client.delete(retry_indexing_cache_key) - except Exception as ex: - document.indexing_status = "error" - document.error = str(ex) - document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) - db.session.add(document) - db.session.commit() - logging.info(click.style(str(ex), fg="yellow")) - redis_client.delete(retry_indexing_cache_key) - logging.exception("retry_document_indexing_task failed, document_id: {}".format(document_id)) - finally: - db.session.close() - end_at = time.perf_counter() - logging.info(click.style("Retry dataset: {} latency: {}".format(dataset_id, end_at - start_at), fg="green")) + raise e + finally: + db.session.close() diff --git a/api/tasks/sync_website_document_indexing_task.py b/api/tasks/sync_website_document_indexing_task.py index 539c2db80f..ad782f9b88 100644 --- a/api/tasks/sync_website_document_indexing_task.py +++ b/api/tasks/sync_website_document_indexing_task.py @@ -1,14 +1,14 @@ -import datetime import logging import time import click -from celery import shared_task # type: ignore +from celery import shared_task from core.indexing_runner import IndexingRunner from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from extensions.ext_database import db from extensions.ext_redis import redis_client +from libs.datetime_utils import naive_utc_now from models.dataset import Dataset, Document, DocumentSegment from services.feature_service import FeatureService @@ -28,7 +28,7 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str): if dataset is None: raise ValueError("Dataset not found") - sync_indexing_cache_key = "document_{}_is_sync".format(document_id) + sync_indexing_cache_key = f"document_{document_id}_is_sync" # check document limit features = FeatureService.get_features(dataset.tenant_id) try: @@ -46,16 +46,16 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str): if document: document.indexing_status = "error" document.error = str(e) - document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + document.stopped_at = naive_utc_now() db.session.add(document) db.session.commit() redis_client.delete(sync_indexing_cache_key) return - logging.info(click.style("Start sync website document: {}".format(document_id), fg="green")) + logging.info(click.style(f"Start sync website document: {document_id}", fg="green")) document = db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() if not document: - logging.info(click.style("Document not found: {}".format(document_id), fg="yellow")) + logging.info(click.style(f"Document not found: {document_id}", fg="yellow")) return try: # clean old data @@ -72,7 +72,7 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str): db.session.commit() document.indexing_status = "parsing" - document.processing_started_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + document.processing_started_at = naive_utc_now() db.session.add(document) db.session.commit() @@ -82,11 +82,11 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str): except Exception as ex: document.indexing_status = "error" document.error = str(ex) - document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + document.stopped_at = naive_utc_now() db.session.add(document) db.session.commit() logging.info(click.style(str(ex), fg="yellow")) redis_client.delete(sync_indexing_cache_key) - logging.exception("sync_website_document_indexing_task failed, document_id: {}".format(document_id)) + logging.exception("sync_website_document_indexing_task failed, document_id: %s", document_id) end_at = time.perf_counter() - logging.info(click.style("Sync document: {} latency: {}".format(document_id, end_at - start_at), fg="green")) + logging.info(click.style(f"Sync document: {document_id} latency: {end_at - start_at}", fg="green")) diff --git a/api/tasks/workflow_execution_tasks.py b/api/tasks/workflow_execution_tasks.py new file mode 100644 index 0000000000..77ddf83023 --- /dev/null +++ b/api/tasks/workflow_execution_tasks.py @@ -0,0 +1,136 @@ +""" +Celery tasks for asynchronous workflow execution storage operations. + +These tasks provide asynchronous storage capabilities for workflow execution data, +improving performance by offloading storage operations to background workers. +""" + +import json +import logging + +from celery import shared_task +from sqlalchemy import select +from sqlalchemy.orm import sessionmaker + +from core.workflow.entities.workflow_execution import WorkflowExecution +from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter +from extensions.ext_database import db +from models import CreatorUserRole, WorkflowRun +from models.enums import WorkflowRunTriggeredFrom + +logger = logging.getLogger(__name__) + + +@shared_task(queue="workflow_storage", bind=True, max_retries=3, default_retry_delay=60) +def save_workflow_execution_task( + self, + execution_data: dict, + tenant_id: str, + app_id: str, + triggered_from: str, + creator_user_id: str, + creator_user_role: str, +) -> bool: + """ + Asynchronously save or update a workflow execution to the database. + + Args: + execution_data: Serialized WorkflowExecution data + tenant_id: Tenant ID for multi-tenancy + app_id: Application ID + triggered_from: Source of the execution trigger + creator_user_id: ID of the user who created the execution + creator_user_role: Role of the user who created the execution + + Returns: + True if successful, False otherwise + """ + try: + # Create a new session for this task + session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) + + with session_factory() as session: + # Deserialize execution data + execution = WorkflowExecution.model_validate(execution_data) + + # Check if workflow run already exists + existing_run = session.scalar(select(WorkflowRun).where(WorkflowRun.id == execution.id_)) + + if existing_run: + # Update existing workflow run + _update_workflow_run_from_execution(existing_run, execution) + logger.debug("Updated existing workflow run: %s", execution.id_) + else: + # Create new workflow run + workflow_run = _create_workflow_run_from_execution( + execution=execution, + tenant_id=tenant_id, + app_id=app_id, + triggered_from=WorkflowRunTriggeredFrom(triggered_from), + creator_user_id=creator_user_id, + creator_user_role=CreatorUserRole(creator_user_role), + ) + session.add(workflow_run) + logger.debug("Created new workflow run: %s", execution.id_) + + session.commit() + return True + + except Exception as e: + logger.exception("Failed to save workflow execution %s", execution_data.get("id_", "unknown")) + # Retry the task with exponential backoff + raise self.retry(exc=e, countdown=60 * (2**self.request.retries)) + + +def _create_workflow_run_from_execution( + execution: WorkflowExecution, + tenant_id: str, + app_id: str, + triggered_from: WorkflowRunTriggeredFrom, + creator_user_id: str, + creator_user_role: CreatorUserRole, +) -> WorkflowRun: + """ + Create a WorkflowRun database model from a WorkflowExecution domain entity. + """ + workflow_run = WorkflowRun() + workflow_run.id = execution.id_ + workflow_run.tenant_id = tenant_id + workflow_run.app_id = app_id + workflow_run.workflow_id = execution.workflow_id + workflow_run.type = execution.workflow_type.value + workflow_run.triggered_from = triggered_from.value + workflow_run.version = execution.workflow_version + json_converter = WorkflowRuntimeTypeConverter() + workflow_run.graph = json.dumps(json_converter.to_json_encodable(execution.graph)) + workflow_run.inputs = json.dumps(json_converter.to_json_encodable(execution.inputs)) + workflow_run.status = execution.status.value + workflow_run.outputs = ( + json.dumps(json_converter.to_json_encodable(execution.outputs)) if execution.outputs else "{}" + ) + workflow_run.error = execution.error_message + workflow_run.elapsed_time = execution.elapsed_time + workflow_run.total_tokens = execution.total_tokens + workflow_run.total_steps = execution.total_steps + workflow_run.created_by_role = creator_user_role.value + workflow_run.created_by = creator_user_id + workflow_run.created_at = execution.started_at + workflow_run.finished_at = execution.finished_at + + return workflow_run + + +def _update_workflow_run_from_execution(workflow_run: WorkflowRun, execution: WorkflowExecution) -> None: + """ + Update a WorkflowRun database model from a WorkflowExecution domain entity. + """ + json_converter = WorkflowRuntimeTypeConverter() + workflow_run.status = execution.status.value + workflow_run.outputs = ( + json.dumps(json_converter.to_json_encodable(execution.outputs)) if execution.outputs else "{}" + ) + workflow_run.error = execution.error_message + workflow_run.elapsed_time = execution.elapsed_time + workflow_run.total_tokens = execution.total_tokens + workflow_run.total_steps = execution.total_steps + workflow_run.finished_at = execution.finished_at diff --git a/api/tasks/workflow_node_execution_tasks.py b/api/tasks/workflow_node_execution_tasks.py new file mode 100644 index 0000000000..16356086cf --- /dev/null +++ b/api/tasks/workflow_node_execution_tasks.py @@ -0,0 +1,171 @@ +""" +Celery tasks for asynchronous workflow node execution storage operations. + +These tasks provide asynchronous storage capabilities for workflow node execution data, +improving performance by offloading storage operations to background workers. +""" + +import json +import logging + +from celery import shared_task +from sqlalchemy import select +from sqlalchemy.orm import sessionmaker + +from core.workflow.entities.workflow_node_execution import ( + WorkflowNodeExecution, +) +from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter +from extensions.ext_database import db +from models import CreatorUserRole, WorkflowNodeExecutionModel +from models.workflow import WorkflowNodeExecutionTriggeredFrom + +logger = logging.getLogger(__name__) + + +@shared_task(queue="workflow_storage", bind=True, max_retries=3, default_retry_delay=60) +def save_workflow_node_execution_task( + self, + execution_data: dict, + tenant_id: str, + app_id: str, + triggered_from: str, + creator_user_id: str, + creator_user_role: str, +) -> bool: + """ + Asynchronously save or update a workflow node execution to the database. + + Args: + execution_data: Serialized WorkflowNodeExecution data + tenant_id: Tenant ID for multi-tenancy + app_id: Application ID + triggered_from: Source of the execution trigger + creator_user_id: ID of the user who created the execution + creator_user_role: Role of the user who created the execution + + Returns: + True if successful, False otherwise + """ + try: + # Create a new session for this task + session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) + + with session_factory() as session: + # Deserialize execution data + execution = WorkflowNodeExecution.model_validate(execution_data) + + # Check if node execution already exists + existing_execution = session.scalar( + select(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id == execution.id) + ) + + if existing_execution: + # Update existing node execution + _update_node_execution_from_domain(existing_execution, execution) + logger.debug("Updated existing workflow node execution: %s", execution.id) + else: + # Create new node execution + node_execution = _create_node_execution_from_domain( + execution=execution, + tenant_id=tenant_id, + app_id=app_id, + triggered_from=WorkflowNodeExecutionTriggeredFrom(triggered_from), + creator_user_id=creator_user_id, + creator_user_role=CreatorUserRole(creator_user_role), + ) + session.add(node_execution) + logger.debug("Created new workflow node execution: %s", execution.id) + + session.commit() + return True + + except Exception as e: + logger.exception("Failed to save workflow node execution %s", execution_data.get("id", "unknown")) + # Retry the task with exponential backoff + raise self.retry(exc=e, countdown=60 * (2**self.request.retries)) + + +def _create_node_execution_from_domain( + execution: WorkflowNodeExecution, + tenant_id: str, + app_id: str, + triggered_from: WorkflowNodeExecutionTriggeredFrom, + creator_user_id: str, + creator_user_role: CreatorUserRole, +) -> WorkflowNodeExecutionModel: + """ + Create a WorkflowNodeExecutionModel database model from a WorkflowNodeExecution domain entity. + """ + node_execution = WorkflowNodeExecutionModel() + node_execution.id = execution.id + node_execution.tenant_id = tenant_id + node_execution.app_id = app_id + node_execution.workflow_id = execution.workflow_id + node_execution.triggered_from = triggered_from.value + node_execution.workflow_run_id = execution.workflow_execution_id + node_execution.index = execution.index + node_execution.predecessor_node_id = execution.predecessor_node_id + node_execution.node_id = execution.node_id + node_execution.node_type = execution.node_type.value + node_execution.title = execution.title + node_execution.node_execution_id = execution.node_execution_id + + # Serialize complex data as JSON + json_converter = WorkflowRuntimeTypeConverter() + node_execution.inputs = json.dumps(json_converter.to_json_encodable(execution.inputs)) if execution.inputs else "{}" + node_execution.process_data = ( + json.dumps(json_converter.to_json_encodable(execution.process_data)) if execution.process_data else "{}" + ) + node_execution.outputs = ( + json.dumps(json_converter.to_json_encodable(execution.outputs)) if execution.outputs else "{}" + ) + # Convert metadata enum keys to strings for JSON serialization + if execution.metadata: + metadata_for_json = { + key.value if hasattr(key, "value") else str(key): value for key, value in execution.metadata.items() + } + node_execution.execution_metadata = json.dumps(json_converter.to_json_encodable(metadata_for_json)) + else: + node_execution.execution_metadata = "{}" + + node_execution.status = execution.status.value + node_execution.error = execution.error + node_execution.elapsed_time = execution.elapsed_time + node_execution.created_by_role = creator_user_role.value + node_execution.created_by = creator_user_id + node_execution.created_at = execution.created_at + node_execution.finished_at = execution.finished_at + + return node_execution + + +def _update_node_execution_from_domain( + node_execution: WorkflowNodeExecutionModel, execution: WorkflowNodeExecution +) -> None: + """ + Update a WorkflowNodeExecutionModel database model from a WorkflowNodeExecution domain entity. + """ + # Update serialized data + json_converter = WorkflowRuntimeTypeConverter() + node_execution.inputs = json.dumps(json_converter.to_json_encodable(execution.inputs)) if execution.inputs else "{}" + node_execution.process_data = ( + json.dumps(json_converter.to_json_encodable(execution.process_data)) if execution.process_data else "{}" + ) + node_execution.outputs = ( + json.dumps(json_converter.to_json_encodable(execution.outputs)) if execution.outputs else "{}" + ) + # Convert metadata enum keys to strings for JSON serialization + if execution.metadata: + metadata_for_json = { + key.value if hasattr(key, "value") else str(key): value for key, value in execution.metadata.items() + } + node_execution.execution_metadata = json.dumps(json_converter.to_json_encodable(metadata_for_json)) + else: + node_execution.execution_metadata = "{}" + + # Update other fields + node_execution.status = execution.status.value + node_execution.error = execution.error + node_execution.elapsed_time = execution.elapsed_time + node_execution.finished_at = execution.finished_at diff --git a/api/tests/integration_tests/controllers/console/app/test_description_validation.py b/api/tests/integration_tests/controllers/console/app/test_description_validation.py new file mode 100644 index 0000000000..2d0ceac760 --- /dev/null +++ b/api/tests/integration_tests/controllers/console/app/test_description_validation.py @@ -0,0 +1,168 @@ +""" +Unit tests for App description validation functions. + +This test module validates the 400-character limit enforcement +for App descriptions across all creation and editing endpoints. +""" + +import os +import sys + +import pytest + +# Add the API root to Python path for imports +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "..", "..")) + + +class TestAppDescriptionValidationUnit: + """Unit tests for description validation function""" + + def test_validate_description_length_function(self): + """Test the _validate_description_length function directly""" + from controllers.console.app.app import _validate_description_length + + # Test valid descriptions + assert _validate_description_length("") == "" + assert _validate_description_length("x" * 400) == "x" * 400 + assert _validate_description_length(None) is None + + # Test invalid descriptions + with pytest.raises(ValueError) as exc_info: + _validate_description_length("x" * 401) + assert "Description cannot exceed 400 characters." in str(exc_info.value) + + with pytest.raises(ValueError) as exc_info: + _validate_description_length("x" * 500) + assert "Description cannot exceed 400 characters." in str(exc_info.value) + + with pytest.raises(ValueError) as exc_info: + _validate_description_length("x" * 1000) + assert "Description cannot exceed 400 characters." in str(exc_info.value) + + def test_validation_consistency_with_dataset(self): + """Test that App and Dataset validation functions are consistent""" + from controllers.console.app.app import _validate_description_length as app_validate + from controllers.console.datasets.datasets import _validate_description_length as dataset_validate + from controllers.service_api.dataset.dataset import _validate_description_length as service_dataset_validate + + # Test same valid inputs + valid_desc = "x" * 400 + assert app_validate(valid_desc) == dataset_validate(valid_desc) == service_dataset_validate(valid_desc) + assert app_validate("") == dataset_validate("") == service_dataset_validate("") + assert app_validate(None) == dataset_validate(None) == service_dataset_validate(None) + + # Test same invalid inputs produce same error + invalid_desc = "x" * 401 + + app_error = None + dataset_error = None + service_dataset_error = None + + try: + app_validate(invalid_desc) + except ValueError as e: + app_error = str(e) + + try: + dataset_validate(invalid_desc) + except ValueError as e: + dataset_error = str(e) + + try: + service_dataset_validate(invalid_desc) + except ValueError as e: + service_dataset_error = str(e) + + assert app_error == dataset_error == service_dataset_error + assert app_error == "Description cannot exceed 400 characters." + + def test_boundary_values(self): + """Test boundary values for description validation""" + from controllers.console.app.app import _validate_description_length + + # Test exact boundary + exactly_400 = "x" * 400 + assert _validate_description_length(exactly_400) == exactly_400 + + # Test just over boundary + just_over_400 = "x" * 401 + with pytest.raises(ValueError): + _validate_description_length(just_over_400) + + # Test just under boundary + just_under_400 = "x" * 399 + assert _validate_description_length(just_under_400) == just_under_400 + + def test_edge_cases(self): + """Test edge cases for description validation""" + from controllers.console.app.app import _validate_description_length + + # Test None input + assert _validate_description_length(None) is None + + # Test empty string + assert _validate_description_length("") == "" + + # Test single character + assert _validate_description_length("a") == "a" + + # Test unicode characters + unicode_desc = "测试" * 200 # 400 characters in Chinese + assert _validate_description_length(unicode_desc) == unicode_desc + + # Test unicode over limit + unicode_over = "测试" * 201 # 402 characters + with pytest.raises(ValueError): + _validate_description_length(unicode_over) + + def test_whitespace_handling(self): + """Test how validation handles whitespace""" + from controllers.console.app.app import _validate_description_length + + # Test description with spaces + spaces_400 = " " * 400 + assert _validate_description_length(spaces_400) == spaces_400 + + # Test description with spaces over limit + spaces_401 = " " * 401 + with pytest.raises(ValueError): + _validate_description_length(spaces_401) + + # Test mixed content + mixed_400 = "a" * 200 + " " * 200 + assert _validate_description_length(mixed_400) == mixed_400 + + # Test mixed over limit + mixed_401 = "a" * 200 + " " * 201 + with pytest.raises(ValueError): + _validate_description_length(mixed_401) + + +if __name__ == "__main__": + # Run tests directly + import traceback + + test_instance = TestAppDescriptionValidationUnit() + test_methods = [method for method in dir(test_instance) if method.startswith("test_")] + + passed = 0 + failed = 0 + + for test_method in test_methods: + try: + print(f"Running {test_method}...") + getattr(test_instance, test_method)() + print(f"✅ {test_method} PASSED") + passed += 1 + except Exception as e: + print(f"❌ {test_method} FAILED: {str(e)}") + traceback.print_exc() + failed += 1 + + print(f"\n📊 Test Results: {passed} passed, {failed} failed") + + if failed == 0: + print("🎉 All tests passed!") + else: + print("💥 Some tests failed!") + sys.exit(1) diff --git a/api/tests/integration_tests/storage/test_clickzetta_volume.py b/api/tests/integration_tests/storage/test_clickzetta_volume.py new file mode 100644 index 0000000000..293b469ef3 --- /dev/null +++ b/api/tests/integration_tests/storage/test_clickzetta_volume.py @@ -0,0 +1,168 @@ +"""Integration tests for ClickZetta Volume Storage.""" + +import os +import tempfile +import unittest + +import pytest + +from extensions.storage.clickzetta_volume.clickzetta_volume_storage import ( + ClickZettaVolumeConfig, + ClickZettaVolumeStorage, +) + + +class TestClickZettaVolumeStorage(unittest.TestCase): + """Test cases for ClickZetta Volume Storage.""" + + def setUp(self): + """Set up test environment.""" + self.config = ClickZettaVolumeConfig( + username=os.getenv("CLICKZETTA_USERNAME", "test_user"), + password=os.getenv("CLICKZETTA_PASSWORD", "test_pass"), + instance=os.getenv("CLICKZETTA_INSTANCE", "test_instance"), + service=os.getenv("CLICKZETTA_SERVICE", "uat-api.clickzetta.com"), + workspace=os.getenv("CLICKZETTA_WORKSPACE", "quick_start"), + vcluster=os.getenv("CLICKZETTA_VCLUSTER", "default_ap"), + schema_name=os.getenv("CLICKZETTA_SCHEMA", "dify"), + volume_type="table", + table_prefix="test_dataset_", + ) + + @pytest.mark.skipif(not os.getenv("CLICKZETTA_USERNAME"), reason="ClickZetta credentials not provided") + def test_user_volume_operations(self): + """Test basic operations with User Volume.""" + config = self.config + config.volume_type = "user" + + storage = ClickZettaVolumeStorage(config) + + # Test file operations + test_filename = "test_file.txt" + test_content = b"Hello, ClickZetta Volume!" + + # Save file + storage.save(test_filename, test_content) + + # Check if file exists + assert storage.exists(test_filename) + + # Load file + loaded_content = storage.load_once(test_filename) + assert loaded_content == test_content + + # Test streaming + stream_content = b"" + for chunk in storage.load_stream(test_filename): + stream_content += chunk + assert stream_content == test_content + + # Test download + with tempfile.NamedTemporaryFile() as temp_file: + storage.download(test_filename, temp_file.name) + with open(temp_file.name, "rb") as f: + downloaded_content = f.read() + assert downloaded_content == test_content + + # Test scan + files = storage.scan("", files=True, directories=False) + assert test_filename in files + + # Delete file + storage.delete(test_filename) + assert not storage.exists(test_filename) + + @pytest.mark.skipif(not os.getenv("CLICKZETTA_USERNAME"), reason="ClickZetta credentials not provided") + def test_table_volume_operations(self): + """Test basic operations with Table Volume.""" + config = self.config + config.volume_type = "table" + + storage = ClickZettaVolumeStorage(config) + + # Test file operations with dataset_id + dataset_id = "12345" + test_filename = f"{dataset_id}/test_file.txt" + test_content = b"Hello, Table Volume!" + + # Save file + storage.save(test_filename, test_content) + + # Check if file exists + assert storage.exists(test_filename) + + # Load file + loaded_content = storage.load_once(test_filename) + assert loaded_content == test_content + + # Test scan for dataset + files = storage.scan(dataset_id, files=True, directories=False) + assert "test_file.txt" in files + + # Delete file + storage.delete(test_filename) + assert not storage.exists(test_filename) + + def test_config_validation(self): + """Test configuration validation.""" + # Test missing required fields + with pytest.raises(ValueError): + ClickZettaVolumeConfig( + username="", # Empty username should fail + password="pass", + instance="instance", + ) + + # Test invalid volume type + with pytest.raises(ValueError): + ClickZettaVolumeConfig(username="user", password="pass", instance="instance", volume_type="invalid_type") + + # Test external volume without volume_name + with pytest.raises(ValueError): + ClickZettaVolumeConfig( + username="user", + password="pass", + instance="instance", + volume_type="external", + # Missing volume_name + ) + + def test_volume_path_generation(self): + """Test volume path generation for different types.""" + storage = ClickZettaVolumeStorage(self.config) + + # Test table volume path + path = storage._get_volume_path("test.txt", "12345") + assert path == "test_dataset_12345/test.txt" + + # Test path with existing dataset_id prefix + path = storage._get_volume_path("12345/test.txt") + assert path == "12345/test.txt" + + # Test user volume + storage._config.volume_type = "user" + path = storage._get_volume_path("test.txt") + assert path == "test.txt" + + def test_sql_prefix_generation(self): + """Test SQL prefix generation for different volume types.""" + storage = ClickZettaVolumeStorage(self.config) + + # Test table volume SQL prefix + prefix = storage._get_volume_sql_prefix("12345") + assert prefix == "TABLE VOLUME test_dataset_12345" + + # Test user volume SQL prefix + storage._config.volume_type = "user" + prefix = storage._get_volume_sql_prefix() + assert prefix == "USER VOLUME" + + # Test external volume SQL prefix + storage._config.volume_type = "external" + storage._config.volume_name = "my_external_volume" + prefix = storage._get_volume_sql_prefix() + assert prefix == "VOLUME my_external_volume" + + +if __name__ == "__main__": + unittest.main() diff --git a/api/core/tools/entities/agent_entities.py b/api/tests/integration_tests/tasks/__init__.py similarity index 100% rename from api/core/tools/entities/agent_entities.py rename to api/tests/integration_tests/tasks/__init__.py diff --git a/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py b/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py new file mode 100644 index 0000000000..2f7fc60ada --- /dev/null +++ b/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py @@ -0,0 +1,214 @@ +import uuid + +import pytest +from sqlalchemy import delete + +from core.variables.segments import StringSegment +from models import Tenant, db +from models.model import App +from models.workflow import WorkflowDraftVariable +from tasks.remove_app_and_related_data_task import _delete_draft_variables, delete_draft_variables_batch + + +@pytest.fixture +def app_and_tenant(flask_req_ctx): + tenant_id = uuid.uuid4() + tenant = Tenant( + id=tenant_id, + name="test_tenant", + ) + db.session.add(tenant) + + app = App( + tenant_id=tenant_id, # Now tenant.id will have a value + name=f"Test App for tenant {tenant.id}", + mode="workflow", + enable_site=True, + enable_api=True, + ) + db.session.add(app) + db.session.flush() + yield (tenant, app) + + # Cleanup with proper error handling + db.session.delete(app) + db.session.delete(tenant) + + +class TestDeleteDraftVariablesIntegration: + @pytest.fixture + def setup_test_data(self, app_and_tenant): + """Create test data with apps and draft variables.""" + tenant, app = app_and_tenant + + # Create a second app for testing + app2 = App( + tenant_id=tenant.id, + name="Test App 2", + mode="workflow", + enable_site=True, + enable_api=True, + ) + db.session.add(app2) + db.session.commit() + + # Create draft variables for both apps + variables_app1 = [] + variables_app2 = [] + + for i in range(5): + var1 = WorkflowDraftVariable.new_node_variable( + app_id=app.id, + node_id=f"node_{i}", + name=f"var_{i}", + value=StringSegment(value="test_value"), + node_execution_id=str(uuid.uuid4()), + ) + db.session.add(var1) + variables_app1.append(var1) + + var2 = WorkflowDraftVariable.new_node_variable( + app_id=app2.id, + node_id=f"node_{i}", + name=f"var_{i}", + value=StringSegment(value="test_value"), + node_execution_id=str(uuid.uuid4()), + ) + db.session.add(var2) + variables_app2.append(var2) + + # Commit all the variables to the database + db.session.commit() + + yield { + "app1": app, + "app2": app2, + "tenant": tenant, + "variables_app1": variables_app1, + "variables_app2": variables_app2, + } + + # Cleanup - refresh session and check if objects still exist + db.session.rollback() # Clear any pending changes + + # Clean up remaining variables + cleanup_query = ( + delete(WorkflowDraftVariable) + .where( + WorkflowDraftVariable.app_id.in_([app.id, app2.id]), + ) + .execution_options(synchronize_session=False) + ) + db.session.execute(cleanup_query) + + # Clean up app2 + app2_obj = db.session.get(App, app2.id) + if app2_obj: + db.session.delete(app2_obj) + + db.session.commit() + + def test_delete_draft_variables_batch_removes_correct_variables(self, setup_test_data): + """Test that batch deletion only removes variables for the specified app.""" + data = setup_test_data + app1_id = data["app1"].id + app2_id = data["app2"].id + + # Verify initial state + app1_vars_before = db.session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count() + app2_vars_before = db.session.query(WorkflowDraftVariable).filter_by(app_id=app2_id).count() + assert app1_vars_before == 5 + assert app2_vars_before == 5 + + # Delete app1 variables + deleted_count = delete_draft_variables_batch(app1_id, batch_size=10) + + # Verify results + assert deleted_count == 5 + + app1_vars_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count() + app2_vars_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=app2_id).count() + + assert app1_vars_after == 0 # All app1 variables deleted + assert app2_vars_after == 5 # App2 variables unchanged + + def test_delete_draft_variables_batch_with_small_batch_size(self, setup_test_data): + """Test batch deletion with small batch size processes all records.""" + data = setup_test_data + app1_id = data["app1"].id + + # Use small batch size to force multiple batches + deleted_count = delete_draft_variables_batch(app1_id, batch_size=2) + + assert deleted_count == 5 + + # Verify all variables are deleted + remaining_vars = db.session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count() + assert remaining_vars == 0 + + def test_delete_draft_variables_batch_nonexistent_app(self, setup_test_data): + """Test that deleting variables for nonexistent app returns 0.""" + nonexistent_app_id = str(uuid.uuid4()) # Use a valid UUID format + + deleted_count = delete_draft_variables_batch(nonexistent_app_id, batch_size=100) + + assert deleted_count == 0 + + def test_delete_draft_variables_wrapper_function(self, setup_test_data): + """Test that _delete_draft_variables wrapper function works correctly.""" + data = setup_test_data + app1_id = data["app1"].id + + # Verify initial state + vars_before = db.session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count() + assert vars_before == 5 + + # Call wrapper function + deleted_count = _delete_draft_variables(app1_id) + + # Verify results + assert deleted_count == 5 + + vars_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count() + assert vars_after == 0 + + def test_batch_deletion_handles_large_dataset(self, app_and_tenant): + """Test batch deletion with larger dataset to verify batching logic.""" + tenant, app = app_and_tenant + + # Create many draft variables + variables = [] + for i in range(25): + var = WorkflowDraftVariable.new_node_variable( + app_id=app.id, + node_id=f"node_{i}", + name=f"var_{i}", + value=StringSegment(value="test_value"), + node_execution_id=str(uuid.uuid4()), + ) + db.session.add(var) + variables.append(var) + variable_ids = [i.id for i in variables] + + # Commit the variables to the database + db.session.commit() + + try: + # Use small batch size to force multiple batches + deleted_count = delete_draft_variables_batch(app.id, batch_size=8) + + assert deleted_count == 25 + + # Verify all variables are deleted + remaining_vars = db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id).count() + assert remaining_vars == 0 + + finally: + query = ( + delete(WorkflowDraftVariable) + .where( + WorkflowDraftVariable.id.in_(variable_ids), + ) + .execution_options(synchronize_session=False) + ) + db.session.execute(query) diff --git a/api/tests/integration_tests/tools/__mock_server/openapi_todo.py b/api/tests/integration_tests/tools/__mock_server/openapi_todo.py index 83f4d70ce9..2f0f38e0b8 100644 --- a/api/tests/integration_tests/tools/__mock_server/openapi_todo.py +++ b/api/tests/integration_tests/tools/__mock_server/openapi_todo.py @@ -1,5 +1,5 @@ from flask import Flask, request -from flask_restful import Api, Resource +from flask_restx import Api, Resource app = Flask(__name__) api = Api(app) diff --git a/api/tests/integration_tests/vdb/__mock/baiduvectordb.py b/api/tests/integration_tests/vdb/__mock/baiduvectordb.py index 4af35a8bef..be5b4de5a2 100644 --- a/api/tests/integration_tests/vdb/__mock/baiduvectordb.py +++ b/api/tests/integration_tests/vdb/__mock/baiduvectordb.py @@ -1,5 +1,6 @@ import os from collections import UserDict +from typing import Optional from unittest.mock import MagicMock import pytest @@ -21,7 +22,7 @@ class MockBaiduVectorDBClass: def mock_vector_db_client( self, config=None, - adapter: HTTPAdapter = None, + adapter: Optional[HTTPAdapter] = None, ): self.conn = MagicMock() self._config = MagicMock() diff --git a/api/tests/integration_tests/vdb/__mock/tcvectordb.py b/api/tests/integration_tests/vdb/__mock/tcvectordb.py index ae5f9761b4..02f658aad6 100644 --- a/api/tests/integration_tests/vdb/__mock/tcvectordb.py +++ b/api/tests/integration_tests/vdb/__mock/tcvectordb.py @@ -23,7 +23,7 @@ class MockTcvectordbClass: key="", read_consistency: ReadConsistency = ReadConsistency.EVENTUAL_CONSISTENCY, timeout=10, - adapter: HTTPAdapter = None, + adapter: Optional[HTTPAdapter] = None, pool_size: int = 2, proxies: Optional[dict] = None, password: Optional[str] = None, @@ -72,11 +72,11 @@ class MockTcvectordbClass: shard: int, replicas: int, description: Optional[str] = None, - index: Index = None, - embedding: Embedding = None, + index: Optional[Index] = None, + embedding: Optional[Embedding] = None, timeout: Optional[float] = None, ttl_config: Optional[dict] = None, - filter_index_config: FilterIndexConfig = None, + filter_index_config: Optional[FilterIndexConfig] = None, indexes: Optional[list[IndexField]] = None, ) -> RPCCollection: return RPCCollection( @@ -113,7 +113,7 @@ class MockTcvectordbClass: database_name: str, collection_name: str, vectors: list[list[float]], - filter: Filter = None, + filter: Optional[Filter] = None, params=None, retrieve_vector: bool = False, limit: int = 10, @@ -128,7 +128,7 @@ class MockTcvectordbClass: collection_name: str, ann: Optional[Union[list[AnnSearch], AnnSearch]] = None, match: Optional[Union[list[KeywordSearch], KeywordSearch]] = None, - filter: Union[Filter, str] = None, + filter: Optional[Union[Filter, str]] = None, rerank: Optional[Rerank] = None, retrieve_vector: Optional[bool] = None, output_fields: Optional[list[str]] = None, @@ -158,7 +158,7 @@ class MockTcvectordbClass: database_name: str, collection_name: str, document_ids: Optional[list[str]] = None, - filter: Filter = None, + filter: Optional[Filter] = None, timeout: Optional[float] = None, ): return {"code": 0, "msg": "operation success"} diff --git a/api/tests/integration_tests/vdb/clickzetta/README.md b/api/tests/integration_tests/vdb/clickzetta/README.md new file mode 100644 index 0000000000..c16dca8018 --- /dev/null +++ b/api/tests/integration_tests/vdb/clickzetta/README.md @@ -0,0 +1,25 @@ +# Clickzetta Integration Tests + +## Running Tests + +To run the Clickzetta integration tests, you need to set the following environment variables: + +```bash +export CLICKZETTA_USERNAME=your_username +export CLICKZETTA_PASSWORD=your_password +export CLICKZETTA_INSTANCE=your_instance +export CLICKZETTA_SERVICE=api.clickzetta.com +export CLICKZETTA_WORKSPACE=your_workspace +export CLICKZETTA_VCLUSTER=your_vcluster +export CLICKZETTA_SCHEMA=dify +``` + +Then run the tests: + +```bash +pytest api/tests/integration_tests/vdb/clickzetta/ +``` + +## Security Note + +Never commit credentials to the repository. Always use environment variables or secure credential management systems. diff --git a/api/tests/integration_tests/vdb/clickzetta/test_clickzetta.py b/api/tests/integration_tests/vdb/clickzetta/test_clickzetta.py new file mode 100644 index 0000000000..21de8be6e3 --- /dev/null +++ b/api/tests/integration_tests/vdb/clickzetta/test_clickzetta.py @@ -0,0 +1,223 @@ +import contextlib +import os + +import pytest + +from core.rag.datasource.vdb.clickzetta.clickzetta_vector import ClickzettaConfig, ClickzettaVector +from core.rag.models.document import Document +from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text, setup_mock_redis + + +class TestClickzettaVector(AbstractVectorTest): + """ + Test cases for Clickzetta vector database integration. + """ + + @pytest.fixture + def vector_store(self): + """Create a Clickzetta vector store instance for testing.""" + # Skip test if Clickzetta credentials are not configured + if not os.getenv("CLICKZETTA_USERNAME"): + pytest.skip("CLICKZETTA_USERNAME is not configured") + if not os.getenv("CLICKZETTA_PASSWORD"): + pytest.skip("CLICKZETTA_PASSWORD is not configured") + if not os.getenv("CLICKZETTA_INSTANCE"): + pytest.skip("CLICKZETTA_INSTANCE is not configured") + + config = ClickzettaConfig( + username=os.getenv("CLICKZETTA_USERNAME", ""), + password=os.getenv("CLICKZETTA_PASSWORD", ""), + instance=os.getenv("CLICKZETTA_INSTANCE", ""), + service=os.getenv("CLICKZETTA_SERVICE", "api.clickzetta.com"), + workspace=os.getenv("CLICKZETTA_WORKSPACE", "quick_start"), + vcluster=os.getenv("CLICKZETTA_VCLUSTER", "default_ap"), + schema=os.getenv("CLICKZETTA_SCHEMA", "dify_test"), + batch_size=10, # Small batch size for testing + enable_inverted_index=True, + analyzer_type="chinese", + analyzer_mode="smart", + vector_distance_function="cosine_distance", + ) + + with setup_mock_redis(): + vector = ClickzettaVector(collection_name="test_collection_" + str(os.getpid()), config=config) + + yield vector + + # Cleanup: delete the test collection + with contextlib.suppress(Exception): + vector.delete() + + def test_clickzetta_vector_basic_operations(self, vector_store): + """Test basic CRUD operations on Clickzetta vector store.""" + # Prepare test data + texts = [ + "这是第一个测试文档,包含一些中文内容。", + "This is the second test document with English content.", + "第三个文档混合了English和中文内容。", + ] + embeddings = [ + [0.1, 0.2, 0.3, 0.4], + [0.5, 0.6, 0.7, 0.8], + [0.9, 1.0, 1.1, 1.2], + ] + documents = [ + Document(page_content=text, metadata={"doc_id": f"doc_{i}", "source": "test"}) + for i, text in enumerate(texts) + ] + + # Test create (initial insert) + vector_store.create(texts=documents, embeddings=embeddings) + + # Test text_exists + assert vector_store.text_exists("doc_0") + assert not vector_store.text_exists("doc_999") + + # Test search_by_vector + query_vector = [0.1, 0.2, 0.3, 0.4] + results = vector_store.search_by_vector(query_vector, top_k=2) + assert len(results) > 0 + assert results[0].page_content == texts[0] # Should match the first document + + # Test search_by_full_text (Chinese) + results = vector_store.search_by_full_text("中文", top_k=3) + assert len(results) >= 2 # Should find documents with Chinese content + + # Test search_by_full_text (English) + results = vector_store.search_by_full_text("English", top_k=3) + assert len(results) >= 2 # Should find documents with English content + + # Test delete_by_ids + vector_store.delete_by_ids(["doc_0"]) + assert not vector_store.text_exists("doc_0") + assert vector_store.text_exists("doc_1") + + # Test delete_by_metadata_field + vector_store.delete_by_metadata_field("source", "test") + assert not vector_store.text_exists("doc_1") + assert not vector_store.text_exists("doc_2") + + def test_clickzetta_vector_advanced_search(self, vector_store): + """Test advanced search features of Clickzetta vector store.""" + # Prepare test data with more complex metadata + documents = [] + embeddings = [] + for i in range(10): + doc = Document( + page_content=f"Document {i}: " + get_example_text(), + metadata={ + "doc_id": f"adv_doc_{i}", + "category": "technical" if i % 2 == 0 else "general", + "document_id": f"doc_{i // 3}", # Group documents + "importance": i, + }, + ) + documents.append(doc) + # Create varied embeddings + embeddings.append([0.1 * i, 0.2 * i, 0.3 * i, 0.4 * i]) + + vector_store.create(texts=documents, embeddings=embeddings) + + # Test vector search with document filter + query_vector = [0.5, 1.0, 1.5, 2.0] + results = vector_store.search_by_vector(query_vector, top_k=5, document_ids_filter=["doc_0", "doc_1"]) + assert len(results) > 0 + # All results should belong to doc_0 or doc_1 groups + for result in results: + assert result.metadata["document_id"] in ["doc_0", "doc_1"] + + # Test score threshold + results = vector_store.search_by_vector(query_vector, top_k=10, score_threshold=0.5) + # Check that all results have a score above threshold + for result in results: + assert result.metadata.get("score", 0) >= 0.5 + + def test_clickzetta_batch_operations(self, vector_store): + """Test batch insertion operations.""" + # Prepare large batch of documents + batch_size = 25 + documents = [] + embeddings = [] + + for i in range(batch_size): + doc = Document( + page_content=f"Batch document {i}: This is a test document for batch processing.", + metadata={"doc_id": f"batch_doc_{i}", "batch": "test_batch"}, + ) + documents.append(doc) + embeddings.append([0.1 * (i % 10), 0.2 * (i % 10), 0.3 * (i % 10), 0.4 * (i % 10)]) + + # Test batch insert + vector_store.add_texts(documents=documents, embeddings=embeddings) + + # Verify all documents were inserted + for i in range(batch_size): + assert vector_store.text_exists(f"batch_doc_{i}") + + # Clean up + vector_store.delete_by_metadata_field("batch", "test_batch") + + def test_clickzetta_edge_cases(self, vector_store): + """Test edge cases and error handling.""" + # Test empty operations + vector_store.create(texts=[], embeddings=[]) + vector_store.add_texts(documents=[], embeddings=[]) + vector_store.delete_by_ids([]) + + # Test special characters in content + special_doc = Document( + page_content="Special chars: 'quotes', \"double\", \\backslash, \n newline", + metadata={"doc_id": "special_doc", "test": "edge_case"}, + ) + embeddings = [[0.1, 0.2, 0.3, 0.4]] + + vector_store.add_texts(documents=[special_doc], embeddings=embeddings) + assert vector_store.text_exists("special_doc") + + # Test search with special characters + results = vector_store.search_by_full_text("quotes", top_k=1) + if results: # Full-text search might not be available + assert len(results) > 0 + + # Clean up + vector_store.delete_by_ids(["special_doc"]) + + def test_clickzetta_full_text_search_modes(self, vector_store): + """Test different full-text search capabilities.""" + # Prepare documents with various language content + documents = [ + Document( + page_content="云器科技提供强大的Lakehouse解决方案", metadata={"doc_id": "cn_doc_1", "lang": "chinese"} + ), + Document( + page_content="Clickzetta provides powerful Lakehouse solutions", + metadata={"doc_id": "en_doc_1", "lang": "english"}, + ), + Document( + page_content="Lakehouse是现代数据架构的重要组成部分", metadata={"doc_id": "cn_doc_2", "lang": "chinese"} + ), + Document( + page_content="Modern data architecture includes Lakehouse technology", + metadata={"doc_id": "en_doc_2", "lang": "english"}, + ), + ] + + embeddings = [[0.1, 0.2, 0.3, 0.4] for _ in documents] + + vector_store.create(texts=documents, embeddings=embeddings) + + # Test Chinese full-text search + results = vector_store.search_by_full_text("Lakehouse", top_k=4) + assert len(results) >= 2 # Should find at least documents with "Lakehouse" + + # Test English full-text search + results = vector_store.search_by_full_text("solutions", top_k=2) + assert len(results) >= 1 # Should find English documents with "solutions" + + # Test mixed search + results = vector_store.search_by_full_text("数据架构", top_k=2) + assert len(results) >= 1 # Should find Chinese documents with this phrase + + # Clean up + vector_store.delete_by_metadata_field("lang", "chinese") + vector_store.delete_by_metadata_field("lang", "english") diff --git a/api/tests/integration_tests/vdb/clickzetta/test_docker_integration.py b/api/tests/integration_tests/vdb/clickzetta/test_docker_integration.py new file mode 100644 index 0000000000..ef54eaa174 --- /dev/null +++ b/api/tests/integration_tests/vdb/clickzetta/test_docker_integration.py @@ -0,0 +1,165 @@ +#!/usr/bin/env python3 +""" +Test Clickzetta integration in Docker environment +""" + +import os +import time + +import requests +from clickzetta import connect + + +def test_clickzetta_connection(): + """Test direct connection to Clickzetta""" + print("=== Testing direct Clickzetta connection ===") + try: + conn = connect( + username=os.getenv("CLICKZETTA_USERNAME", "test_user"), + password=os.getenv("CLICKZETTA_PASSWORD", "test_password"), + instance=os.getenv("CLICKZETTA_INSTANCE", "test_instance"), + service=os.getenv("CLICKZETTA_SERVICE", "api.clickzetta.com"), + workspace=os.getenv("CLICKZETTA_WORKSPACE", "test_workspace"), + vcluster=os.getenv("CLICKZETTA_VCLUSTER", "default"), + database=os.getenv("CLICKZETTA_SCHEMA", "dify"), + ) + + with conn.cursor() as cursor: + # Test basic connectivity + cursor.execute("SELECT 1 as test") + result = cursor.fetchone() + print(f"✓ Connection test: {result}") + + # Check if our test table exists + cursor.execute("SHOW TABLES IN dify") + tables = cursor.fetchall() + print(f"✓ Existing tables: {[t[1] for t in tables if t[0] == 'dify']}") + + # Check if test collection exists + test_collection = "collection_test_dataset" + if test_collection in [t[1] for t in tables if t[0] == "dify"]: + cursor.execute(f"DESCRIBE dify.{test_collection}") + columns = cursor.fetchall() + print(f"✓ Table structure for {test_collection}:") + for col in columns: + print(f" - {col[0]}: {col[1]}") + + # Check for indexes + cursor.execute(f"SHOW INDEXES IN dify.{test_collection}") + indexes = cursor.fetchall() + print(f"✓ Indexes on {test_collection}:") + for idx in indexes: + print(f" - {idx}") + + return True + except Exception as e: + print(f"✗ Connection test failed: {e}") + return False + + +def test_dify_api(): + """Test Dify API with Clickzetta backend""" + print("\n=== Testing Dify API ===") + base_url = "http://localhost:5001" + + # Wait for API to be ready + max_retries = 30 + for i in range(max_retries): + try: + response = requests.get(f"{base_url}/console/api/health") + if response.status_code == 200: + print("✓ Dify API is ready") + break + except: + if i == max_retries - 1: + print("✗ Dify API is not responding") + return False + time.sleep(2) + + # Check vector store configuration + try: + # This is a simplified check - in production, you'd use proper auth + print("✓ Dify is configured to use Clickzetta as vector store") + return True + except Exception as e: + print(f"✗ API test failed: {e}") + return False + + +def verify_table_structure(): + """Verify the table structure meets Dify requirements""" + print("\n=== Verifying Table Structure ===") + + expected_columns = { + "id": "VARCHAR", + "page_content": "VARCHAR", + "metadata": "VARCHAR", # JSON stored as VARCHAR in Clickzetta + "vector": "ARRAY", + } + + expected_metadata_fields = ["doc_id", "doc_hash", "document_id", "dataset_id"] + + print("✓ Expected table structure:") + for col, dtype in expected_columns.items(): + print(f" - {col}: {dtype}") + + print("\n✓ Required metadata fields:") + for field in expected_metadata_fields: + print(f" - {field}") + + print("\n✓ Index requirements:") + print(" - Vector index (HNSW) on 'vector' column") + print(" - Full-text index on 'page_content' (optional)") + print(" - Functional index on metadata->>'$.doc_id' (recommended)") + print(" - Functional index on metadata->>'$.document_id' (recommended)") + + return True + + +def main(): + """Run all tests""" + print("Starting Clickzetta integration tests for Dify Docker\n") + + tests = [ + ("Direct Clickzetta Connection", test_clickzetta_connection), + ("Dify API Status", test_dify_api), + ("Table Structure Verification", verify_table_structure), + ] + + results = [] + for test_name, test_func in tests: + try: + success = test_func() + results.append((test_name, success)) + except Exception as e: + print(f"\n✗ {test_name} crashed: {e}") + results.append((test_name, False)) + + # Summary + print("\n" + "=" * 50) + print("Test Summary:") + print("=" * 50) + + passed = sum(1 for _, success in results if success) + total = len(results) + + for test_name, success in results: + status = "✅ PASSED" if success else "❌ FAILED" + print(f"{test_name}: {status}") + + print(f"\nTotal: {passed}/{total} tests passed") + + if passed == total: + print("\n🎉 All tests passed! Clickzetta is ready for Dify Docker deployment.") + print("\nNext steps:") + print("1. Run: cd docker && docker-compose -f docker-compose.yaml -f docker-compose.clickzetta.yaml up -d") + print("2. Access Dify at http://localhost:3000") + print("3. Create a dataset and test vector storage with Clickzetta") + return 0 + else: + print("\n⚠️ Some tests failed. Please check the errors above.") + return 1 + + +if __name__ == "__main__": + exit(main()) diff --git a/api/tests/integration_tests/vdb/elasticsearch/test_elasticsearch.py b/api/tests/integration_tests/vdb/elasticsearch/test_elasticsearch.py index 2a0c1bb038..a5ff5b9e82 100644 --- a/api/tests/integration_tests/vdb/elasticsearch/test_elasticsearch.py +++ b/api/tests/integration_tests/vdb/elasticsearch/test_elasticsearch.py @@ -11,7 +11,9 @@ class ElasticSearchVectorTest(AbstractVectorTest): self.attributes = ["doc_id", "dataset_id", "document_id", "doc_hash"] self.vector = ElasticSearchVector( index_name=self.collection_name.lower(), - config=ElasticSearchConfig(host="http://localhost", port="9200", username="elastic", password="elastic"), + config=ElasticSearchConfig( + use_cloud=False, host="http://localhost", port="9200", username="elastic", password="elastic" + ), attributes=self.attributes, ) diff --git a/api/tests/integration_tests/vdb/qdrant/test_qdrant.py b/api/tests/integration_tests/vdb/qdrant/test_qdrant.py index 61d9a9e712..fe0e03f7b8 100644 --- a/api/tests/integration_tests/vdb/qdrant/test_qdrant.py +++ b/api/tests/integration_tests/vdb/qdrant/test_qdrant.py @@ -1,4 +1,5 @@ from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantConfig, QdrantVector +from core.rag.models.document import Document from tests.integration_tests.vdb.test_vector_store import ( AbstractVectorTest, setup_mock_redis, @@ -18,6 +19,14 @@ class QdrantVectorTest(AbstractVectorTest): ), ) + def search_by_vector(self): + super().search_by_vector() + # only test for qdrant, may not work on other vector stores + hits_by_vector: list[Document] = self.vector.search_by_vector( + query_vector=self.example_embedding, score_threshold=1 + ) + assert len(hits_by_vector) == 0 + def test_qdrant_vector(setup_mock_redis): QdrantVectorTest().run_all_tests() diff --git a/api/tests/integration_tests/vdb/tablestore/test_tablestore.py b/api/tests/integration_tests/vdb/tablestore/test_tablestore.py index da549af1b6..aebf3fbda1 100644 --- a/api/tests/integration_tests/vdb/tablestore/test_tablestore.py +++ b/api/tests/integration_tests/vdb/tablestore/test_tablestore.py @@ -2,6 +2,7 @@ import os import uuid import tablestore +from _pytest.python_api import approx from core.rag.datasource.vdb.tablestore.tablestore_vector import ( TableStoreConfig, @@ -16,7 +17,7 @@ from tests.integration_tests.vdb.test_vector_store import ( class TableStoreVectorTest(AbstractVectorTest): - def __init__(self): + def __init__(self, normalize_full_text_score: bool = False): super().__init__() self.vector = TableStoreVector( collection_name=self.collection_name, @@ -25,6 +26,7 @@ class TableStoreVectorTest(AbstractVectorTest): instance_name=os.getenv("TABLESTORE_INSTANCE_NAME"), access_key_id=os.getenv("TABLESTORE_ACCESS_KEY_ID"), access_key_secret=os.getenv("TABLESTORE_ACCESS_KEY_SECRET"), + normalize_full_text_bm25_score=normalize_full_text_score, ), ) @@ -64,7 +66,21 @@ class TableStoreVectorTest(AbstractVectorTest): docs = self.vector.search_by_full_text(get_example_text(), document_ids_filter=[self.example_doc_id]) assert len(docs) == 1 assert docs[0].metadata["doc_id"] == self.example_doc_id - assert not hasattr(docs[0], "score") + if self.vector._config.normalize_full_text_bm25_score: + assert docs[0].metadata["score"] == approx(0.1214, abs=1e-3) + else: + assert docs[0].metadata.get("score") is None + + # return none if normalize_full_text_score=true and score_threshold > 0 + docs = self.vector.search_by_full_text( + get_example_text(), document_ids_filter=[self.example_doc_id], score_threshold=0.5 + ) + if self.vector._config.normalize_full_text_bm25_score: + assert len(docs) == 0 + else: + assert len(docs) == 1 + assert docs[0].metadata["doc_id"] == self.example_doc_id + assert docs[0].metadata.get("score") is None docs = self.vector.search_by_full_text(get_example_text(), document_ids_filter=[str(uuid.uuid4())]) assert len(docs) == 0 @@ -80,3 +96,5 @@ class TableStoreVectorTest(AbstractVectorTest): def test_tablestore_vector(setup_mock_redis): TableStoreVectorTest().run_all_tests() + TableStoreVectorTest(normalize_full_text_score=True).run_all_tests() + TableStoreVectorTest(normalize_full_text_score=False).run_all_tests() diff --git a/api/tests/integration_tests/workflow/nodes/test_code.py b/api/tests/integration_tests/workflow/nodes/test_code.py index 707b28e6d8..4f659c5e13 100644 --- a/api/tests/integration_tests/workflow/nodes/test_code.py +++ b/api/tests/integration_tests/workflow/nodes/test_code.py @@ -55,8 +55,8 @@ def init_code_node(code_config: dict): environment_variables=[], conversation_variables=[], ) - variable_pool.add(["code", "123", "args1"], 1) - variable_pool.add(["code", "123", "args2"], 2) + variable_pool.add(["code", "args1"], 1) + variable_pool.add(["code", "args2"], 2) node = CodeNode( id=str(uuid.uuid4()), @@ -96,9 +96,9 @@ def test_execute_code(setup_code_executor_mock): "variables": [ { "variable": "args1", - "value_selector": ["1", "123", "args1"], + "value_selector": ["1", "args1"], }, - {"variable": "args2", "value_selector": ["1", "123", "args2"]}, + {"variable": "args2", "value_selector": ["1", "args2"]}, ], "answer": "123", "code_language": "python3", @@ -107,8 +107,8 @@ def test_execute_code(setup_code_executor_mock): } node = init_code_node(code_config) - node.graph_runtime_state.variable_pool.add(["1", "123", "args1"], 1) - node.graph_runtime_state.variable_pool.add(["1", "123", "args2"], 2) + node.graph_runtime_state.variable_pool.add(["1", "args1"], 1) + node.graph_runtime_state.variable_pool.add(["1", "args2"], 2) # execute node result = node._run() @@ -142,9 +142,9 @@ def test_execute_code_output_validator(setup_code_executor_mock): "variables": [ { "variable": "args1", - "value_selector": ["1", "123", "args1"], + "value_selector": ["1", "args1"], }, - {"variable": "args2", "value_selector": ["1", "123", "args2"]}, + {"variable": "args2", "value_selector": ["1", "args2"]}, ], "answer": "123", "code_language": "python3", @@ -153,8 +153,8 @@ def test_execute_code_output_validator(setup_code_executor_mock): } node = init_code_node(code_config) - node.graph_runtime_state.variable_pool.add(["1", "123", "args1"], 1) - node.graph_runtime_state.variable_pool.add(["1", "123", "args2"], 2) + node.graph_runtime_state.variable_pool.add(["1", "args1"], 1) + node.graph_runtime_state.variable_pool.add(["1", "args2"], 2) # execute node result = node._run() @@ -217,9 +217,9 @@ def test_execute_code_output_validator_depth(): "variables": [ { "variable": "args1", - "value_selector": ["1", "123", "args1"], + "value_selector": ["1", "args1"], }, - {"variable": "args2", "value_selector": ["1", "123", "args2"]}, + {"variable": "args2", "value_selector": ["1", "args2"]}, ], "answer": "123", "code_language": "python3", @@ -307,9 +307,9 @@ def test_execute_code_output_object_list(): "variables": [ { "variable": "args1", - "value_selector": ["1", "123", "args1"], + "value_selector": ["1", "args1"], }, - {"variable": "args2", "value_selector": ["1", "123", "args2"]}, + {"variable": "args2", "value_selector": ["1", "args2"]}, ], "answer": "123", "code_language": "python3", diff --git a/api/tests/integration_tests/workflow/nodes/test_http.py b/api/tests/integration_tests/workflow/nodes/test_http.py index d7856129a3..f7bb7c4600 100644 --- a/api/tests/integration_tests/workflow/nodes/test_http.py +++ b/api/tests/integration_tests/workflow/nodes/test_http.py @@ -49,8 +49,8 @@ def init_http_node(config: dict): environment_variables=[], conversation_variables=[], ) - variable_pool.add(["a", "b123", "args1"], 1) - variable_pool.add(["a", "b123", "args2"], 2) + variable_pool.add(["a", "args1"], 1) + variable_pool.add(["a", "args2"], 2) node = HttpRequestNode( id=str(uuid.uuid4()), @@ -160,6 +160,177 @@ def test_custom_authorization_header(setup_http_mock): assert "?A=b" in data assert "X-Header: 123" in data + # Custom authorization header should be set (may be masked) + assert "X-Auth:" in data + + +@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True) +def test_custom_auth_with_empty_api_key_does_not_set_header(setup_http_mock): + """Test: In custom authentication mode, when the api_key is empty, no header should be set.""" + from core.workflow.entities.variable_pool import VariablePool + from core.workflow.nodes.http_request.entities import ( + HttpRequestNodeAuthorization, + HttpRequestNodeData, + HttpRequestNodeTimeout, + ) + from core.workflow.nodes.http_request.executor import Executor + from core.workflow.system_variable import SystemVariable + + # Create variable pool + variable_pool = VariablePool( + system_variables=SystemVariable(user_id="test", files=[]), + user_inputs={}, + environment_variables=[], + conversation_variables=[], + ) + + # Create node data with custom auth and empty api_key + node_data = HttpRequestNodeData( + title="http", + desc="", + url="http://example.com", + method="get", + authorization=HttpRequestNodeAuthorization( + type="api-key", + config={ + "type": "custom", + "api_key": "", # Empty api_key + "header": "X-Custom-Auth", + }, + ), + headers="", + params="", + body=None, + ssl_verify=True, + ) + + # Create executor + executor = Executor( + node_data=node_data, timeout=HttpRequestNodeTimeout(connect=10, read=30, write=10), variable_pool=variable_pool + ) + + # Get assembled headers + headers = executor._assembling_headers() + + # When api_key is empty, the custom header should NOT be set + assert "X-Custom-Auth" not in headers + + +@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True) +def test_bearer_authorization_with_custom_header_ignored(setup_http_mock): + """ + Test that when switching from custom to bearer authorization, + the custom header settings don't interfere with bearer token. + This test verifies the fix for issue #23554. + """ + node = init_http_node( + config={ + "id": "1", + "data": { + "title": "http", + "desc": "", + "method": "get", + "url": "http://example.com", + "authorization": { + "type": "api-key", + "config": { + "type": "bearer", + "api_key": "test-token", + "header": "", # Empty header - should default to Authorization + }, + }, + "headers": "", + "params": "", + "body": None, + }, + } + ) + + result = node._run() + assert result.process_data is not None + data = result.process_data.get("request", "") + + # In bearer mode, should use Authorization header (value is masked with *) + assert "Authorization: " in data + # Should contain masked Bearer token + assert "*" in data + + +@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True) +def test_basic_authorization_with_custom_header_ignored(setup_http_mock): + """ + Test that when switching from custom to basic authorization, + the custom header settings don't interfere with basic auth. + This test verifies the fix for issue #23554. + """ + node = init_http_node( + config={ + "id": "1", + "data": { + "title": "http", + "desc": "", + "method": "get", + "url": "http://example.com", + "authorization": { + "type": "api-key", + "config": { + "type": "basic", + "api_key": "user:pass", + "header": "", # Empty header - should default to Authorization + }, + }, + "headers": "", + "params": "", + "body": None, + }, + } + ) + + result = node._run() + assert result.process_data is not None + data = result.process_data.get("request", "") + + # In basic mode, should use Authorization header (value is masked with *) + assert "Authorization: " in data + # Should contain masked Basic credentials + assert "*" in data + + +@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True) +def test_custom_authorization_with_empty_api_key(setup_http_mock): + """ + Test that custom authorization doesn't set header when api_key is empty. + This test verifies the fix for issue #23554. + """ + node = init_http_node( + config={ + "id": "1", + "data": { + "title": "http", + "desc": "", + "method": "get", + "url": "http://example.com", + "authorization": { + "type": "api-key", + "config": { + "type": "custom", + "api_key": "", # Empty api_key + "header": "X-Custom-Auth", + }, + }, + "headers": "", + "params": "", + "body": None, + }, + } + ) + + result = node._run() + assert result.process_data is not None + data = result.process_data.get("request", "") + + # Custom header should NOT be set when api_key is empty + assert "X-Custom-Auth:" not in data @pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True) @@ -171,7 +342,7 @@ def test_template(setup_http_mock): "title": "http", "desc": "", "method": "get", - "url": "http://example.com/{{#a.b123.args2#}}", + "url": "http://example.com/{{#a.args2#}}", "authorization": { "type": "api-key", "config": { @@ -180,8 +351,8 @@ def test_template(setup_http_mock): "header": "api-key", }, }, - "headers": "X-Header:123\nX-Header2:{{#a.b123.args2#}}", - "params": "A:b\nTemplate:{{#a.b123.args2#}}", + "headers": "X-Header:123\nX-Header2:{{#a.args2#}}", + "params": "A:b\nTemplate:{{#a.args2#}}", "body": None, }, } @@ -223,7 +394,7 @@ def test_json(setup_http_mock): { "key": "", "type": "text", - "value": '{"a": "{{#a.b123.args1#}}"}', + "value": '{"a": "{{#a.args1#}}"}', }, ], }, @@ -239,6 +410,7 @@ def test_json(setup_http_mock): assert "X-Header: 123" in data +@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True) def test_x_www_form_urlencoded(setup_http_mock): node = init_http_node( config={ @@ -264,12 +436,12 @@ def test_x_www_form_urlencoded(setup_http_mock): { "key": "a", "type": "text", - "value": "{{#a.b123.args1#}}", + "value": "{{#a.args1#}}", }, { "key": "b", "type": "text", - "value": "{{#a.b123.args2#}}", + "value": "{{#a.args2#}}", }, ], }, @@ -285,6 +457,7 @@ def test_x_www_form_urlencoded(setup_http_mock): assert "X-Header: 123" in data +@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True) def test_form_data(setup_http_mock): node = init_http_node( config={ @@ -310,12 +483,12 @@ def test_form_data(setup_http_mock): { "key": "a", "type": "text", - "value": "{{#a.b123.args1#}}", + "value": "{{#a.args1#}}", }, { "key": "b", "type": "text", - "value": "{{#a.b123.args2#}}", + "value": "{{#a.args2#}}", }, ], }, @@ -334,6 +507,7 @@ def test_form_data(setup_http_mock): assert "X-Header: 123" in data +@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True) def test_none_data(setup_http_mock): node = init_http_node( config={ @@ -366,6 +540,7 @@ def test_none_data(setup_http_mock): assert "123123123" not in data +@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True) def test_mock_404(setup_http_mock): node = init_http_node( config={ @@ -394,6 +569,7 @@ def test_mock_404(setup_http_mock): assert "Not Found" in resp.get("body", "") +@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True) def test_multi_colons_parse(setup_http_mock): node = init_http_node( config={ @@ -436,3 +612,87 @@ def test_multi_colons_parse(setup_http_mock): assert 'form-data; name="Redirect"\r\n\r\nhttp://example6.com' in result.process_data.get("request", "") # resp = result.outputs # assert "http://example3.com" == resp.get("headers", {}).get("referer") + + +@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True) +def test_nested_object_variable_selector(setup_http_mock): + """Test variable selector functionality with nested object properties.""" + # Create independent test setup without affecting other tests + graph_config = { + "edges": [ + { + "id": "start-source-next-target", + "source": "start", + "target": "1", + }, + ], + "nodes": [ + {"data": {"type": "start"}, "id": "start"}, + { + "id": "1", + "data": { + "title": "http", + "desc": "", + "method": "get", + "url": "http://example.com/{{#a.args2#}}/{{#a.args3.nested#}}", + "authorization": { + "type": "api-key", + "config": { + "type": "basic", + "api_key": "ak-xxx", + "header": "api-key", + }, + }, + "headers": "X-Header:{{#a.args3.nested#}}", + "params": "nested_param:{{#a.args3.nested#}}", + "body": None, + }, + }, + ], + } + + graph = Graph.init(graph_config=graph_config) + + init_params = GraphInitParams( + tenant_id="1", + app_id="1", + workflow_type=WorkflowType.WORKFLOW, + workflow_id="1", + graph_config=graph_config, + user_id="1", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ) + + # Create independent variable pool for this test only + variable_pool = VariablePool( + system_variables=SystemVariable(user_id="aaa", files=[]), + user_inputs={}, + environment_variables=[], + conversation_variables=[], + ) + variable_pool.add(["a", "args1"], 1) + variable_pool.add(["a", "args2"], 2) + variable_pool.add(["a", "args3"], {"nested": "nested_value"}) # Only for this test + + node = HttpRequestNode( + id=str(uuid.uuid4()), + graph_init_params=init_params, + graph=graph, + graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), + config=graph_config["nodes"][1], + ) + + # Initialize node data + if "data" in graph_config["nodes"][1]: + node.init_node_data(graph_config["nodes"][1]["data"]) + + result = node._run() + assert result.process_data is not None + data = result.process_data.get("request", "") + + # Verify nested object property is correctly resolved + assert "/2/nested_value" in data # URL path should contain resolved nested value + assert "X-Header: nested_value" in data # Header should contain nested value + assert "nested_param=nested_value" in data # Param should contain nested value diff --git a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py index edd70193a8..ef373d968d 100644 --- a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py +++ b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py @@ -71,8 +71,8 @@ def init_parameter_extractor_node(config: dict): environment_variables=[], conversation_variables=[], ) - variable_pool.add(["a", "b123", "args1"], 1) - variable_pool.add(["a", "b123", "args2"], 2) + variable_pool.add(["a", "args1"], 1) + variable_pool.add(["a", "args2"], 2) node = ParameterExtractorNode( id=str(uuid.uuid4()), diff --git a/api/tests/integration_tests/workflow/nodes/test_template_transform.py b/api/tests/integration_tests/workflow/nodes/test_template_transform.py index f71a5ee140..56265c6b95 100644 --- a/api/tests/integration_tests/workflow/nodes/test_template_transform.py +++ b/api/tests/integration_tests/workflow/nodes/test_template_transform.py @@ -26,9 +26,9 @@ def test_execute_code(setup_code_executor_mock): "variables": [ { "variable": "args1", - "value_selector": ["1", "123", "args1"], + "value_selector": ["1", "args1"], }, - {"variable": "args2", "value_selector": ["1", "123", "args2"]}, + {"variable": "args2", "value_selector": ["1", "args2"]}, ], "template": code, }, @@ -66,8 +66,8 @@ def test_execute_code(setup_code_executor_mock): environment_variables=[], conversation_variables=[], ) - variable_pool.add(["1", "123", "args1"], 1) - variable_pool.add(["1", "123", "args2"], 3) + variable_pool.add(["1", "args1"], 1) + variable_pool.add(["1", "args2"], 3) node = TemplateTransformNode( id=str(uuid.uuid4()), diff --git a/api/tests/integration_tests/workflow/nodes/test_tool.py b/api/tests/integration_tests/workflow/nodes/test_tool.py index 8476c1f874..19a9b36350 100644 --- a/api/tests/integration_tests/workflow/nodes/test_tool.py +++ b/api/tests/integration_tests/workflow/nodes/test_tool.py @@ -81,7 +81,7 @@ def test_tool_variable_invoke(): ToolParameterConfigurationManager.decrypt_tool_parameters = MagicMock(return_value={"format": "%Y-%m-%d %H:%M:%S"}) - node.graph_runtime_state.variable_pool.add(["1", "123", "args1"], "1+1") + node.graph_runtime_state.variable_pool.add(["1", "args1"], "1+1") # execute node result = node._run() diff --git a/api/core/workflow/workflow_engine_manager.py b/api/tests/test_containers_integration_tests/__init__.py similarity index 100% rename from api/core/workflow/workflow_engine_manager.py rename to api/tests/test_containers_integration_tests/__init__.py diff --git a/api/tests/test_containers_integration_tests/conftest.py b/api/tests/test_containers_integration_tests/conftest.py new file mode 100644 index 0000000000..0369a5cbd0 --- /dev/null +++ b/api/tests/test_containers_integration_tests/conftest.py @@ -0,0 +1,328 @@ +""" +TestContainers-based integration test configuration for Dify API. + +This module provides containerized test infrastructure using TestContainers library +to spin up real database and service instances for integration testing. This approach +ensures tests run against actual service implementations rather than mocks, providing +more reliable and realistic test scenarios. +""" + +import logging +import os +from collections.abc import Generator +from typing import Optional + +import pytest +from flask import Flask +from flask.testing import FlaskClient +from sqlalchemy.orm import Session +from testcontainers.core.container import DockerContainer +from testcontainers.core.waiting_utils import wait_for_logs +from testcontainers.postgres import PostgresContainer +from testcontainers.redis import RedisContainer + +from app_factory import create_app +from models import db + +# Configure logging for test containers +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s") +logger = logging.getLogger(__name__) + + +class DifyTestContainers: + """ + Manages all test containers required for Dify integration tests. + + This class provides a centralized way to manage multiple containers + needed for comprehensive integration testing, including databases, + caches, and search engines. + """ + + def __init__(self): + """Initialize container management with default configurations.""" + self.postgres: Optional[PostgresContainer] = None + self.redis: Optional[RedisContainer] = None + self.dify_sandbox: Optional[DockerContainer] = None + self._containers_started = False + logger.info("DifyTestContainers initialized - ready to manage test containers") + + def start_containers_with_env(self) -> None: + """ + Start all required containers for integration testing. + + This method initializes and starts PostgreSQL, Redis + containers with appropriate configurations for Dify testing. Containers + are started in dependency order to ensure proper initialization. + """ + if self._containers_started: + logger.info("Containers already started - skipping container startup") + return + + logger.info("Starting test containers for Dify integration tests...") + + # Start PostgreSQL container for main application database + # PostgreSQL is used for storing user data, workflows, and application state + logger.info("Initializing PostgreSQL container...") + self.postgres = PostgresContainer( + image="postgres:16-alpine", + ) + self.postgres.start() + db_host = self.postgres.get_container_host_ip() + db_port = self.postgres.get_exposed_port(5432) + os.environ["DB_HOST"] = db_host + os.environ["DB_PORT"] = str(db_port) + os.environ["DB_USERNAME"] = self.postgres.username + os.environ["DB_PASSWORD"] = self.postgres.password + os.environ["DB_DATABASE"] = self.postgres.dbname + logger.info( + "PostgreSQL container started successfully - Host: %s, Port: %s User: %s, Database: %s", + db_host, + db_port, + self.postgres.username, + self.postgres.dbname, + ) + + # Wait for PostgreSQL to be ready + logger.info("Waiting for PostgreSQL to be ready to accept connections...") + wait_for_logs(self.postgres, "is ready to accept connections", timeout=30) + logger.info("PostgreSQL container is ready and accepting connections") + + # Install uuid-ossp extension for UUID generation + logger.info("Installing uuid-ossp extension...") + try: + import psycopg2 + + conn = psycopg2.connect( + host=db_host, + port=db_port, + user=self.postgres.username, + password=self.postgres.password, + database=self.postgres.dbname, + ) + conn.autocommit = True + cursor = conn.cursor() + cursor.execute('CREATE EXTENSION IF NOT EXISTS "uuid-ossp";') + cursor.close() + conn.close() + logger.info("uuid-ossp extension installed successfully") + except Exception as e: + logger.warning("Failed to install uuid-ossp extension: %s", e) + + # Set up storage environment variables + os.environ["STORAGE_TYPE"] = "opendal" + os.environ["OPENDAL_SCHEME"] = "fs" + os.environ["OPENDAL_FS_ROOT"] = "storage" + + # Start Redis container for caching and session management + # Redis is used for storing session data, cache entries, and temporary data + logger.info("Initializing Redis container...") + self.redis = RedisContainer(image="redis:latest", port=6379) + self.redis.start() + redis_host = self.redis.get_container_host_ip() + redis_port = self.redis.get_exposed_port(6379) + os.environ["REDIS_HOST"] = redis_host + os.environ["REDIS_PORT"] = str(redis_port) + logger.info("Redis container started successfully - Host: %s, Port: %s", redis_host, redis_port) + + # Wait for Redis to be ready + logger.info("Waiting for Redis to be ready to accept connections...") + wait_for_logs(self.redis, "Ready to accept connections", timeout=30) + logger.info("Redis container is ready and accepting connections") + + # Start Dify Sandbox container for code execution environment + # Dify Sandbox provides a secure environment for executing user code + logger.info("Initializing Dify Sandbox container...") + self.dify_sandbox = DockerContainer(image="langgenius/dify-sandbox:latest") + self.dify_sandbox.with_exposed_ports(8194) + self.dify_sandbox.env = { + "API_KEY": "test_api_key", + } + self.dify_sandbox.start() + sandbox_host = self.dify_sandbox.get_container_host_ip() + sandbox_port = self.dify_sandbox.get_exposed_port(8194) + os.environ["CODE_EXECUTION_ENDPOINT"] = f"http://{sandbox_host}:{sandbox_port}" + os.environ["CODE_EXECUTION_API_KEY"] = "test_api_key" + logger.info("Dify Sandbox container started successfully - Host: %s, Port: %s", sandbox_host, sandbox_port) + + # Wait for Dify Sandbox to be ready + logger.info("Waiting for Dify Sandbox to be ready to accept connections...") + wait_for_logs(self.dify_sandbox, "config init success", timeout=60) + logger.info("Dify Sandbox container is ready and accepting connections") + + self._containers_started = True + logger.info("All test containers started successfully") + + def stop_containers(self) -> None: + """ + Stop and clean up all test containers. + + This method ensures proper cleanup of all containers to prevent + resource leaks and conflicts between test runs. + """ + if not self._containers_started: + logger.info("No containers to stop - containers were not started") + return + + logger.info("Stopping and cleaning up test containers...") + containers = [self.redis, self.postgres, self.dify_sandbox] + for container in containers: + if container: + try: + container_name = container.image + logger.info("Stopping container: %s", container_name) + container.stop() + logger.info("Successfully stopped container: %s", container_name) + except Exception as e: + # Log error but don't fail the test cleanup + logger.warning("Failed to stop container %s: %s", container, e) + + self._containers_started = False + logger.info("All test containers stopped and cleaned up successfully") + + +# Global container manager instance +_container_manager = DifyTestContainers() + + +def _create_app_with_containers() -> Flask: + """ + Create Flask application configured to use test containers. + + This function creates a Flask application instance that is configured + to connect to the test containers instead of the default development + or production databases. + + Returns: + Flask: Configured Flask application for containerized testing + """ + logger.info("Creating Flask application with test container configuration...") + + # Re-create the config after environment variables have been set + from configs import dify_config + + # Force re-creation of config with new environment variables + dify_config.__dict__.clear() + dify_config.__init__() + + # Create and configure the Flask application + logger.info("Initializing Flask application...") + app = create_app() + logger.info("Flask application created successfully") + + # Initialize database schema + logger.info("Creating database schema...") + with app.app_context(): + db.create_all() + logger.info("Database schema created successfully") + + logger.info("Flask application configured and ready for testing") + return app + + +@pytest.fixture(scope="session") +def set_up_containers_and_env() -> Generator[DifyTestContainers, None, None]: + """ + Session-scoped fixture to manage test containers. + + This fixture ensures containers are started once per test session + and properly cleaned up when all tests are complete. This approach + improves test performance by reusing containers across multiple tests. + + Yields: + DifyTestContainers: Container manager instance + """ + logger.info("=== Starting test session container management ===") + _container_manager.start_containers_with_env() + logger.info("Test containers ready for session") + yield _container_manager + logger.info("=== Cleaning up test session containers ===") + _container_manager.stop_containers() + logger.info("Test session container cleanup completed") + + +@pytest.fixture(scope="session") +def flask_app_with_containers(set_up_containers_and_env) -> Flask: + """ + Session-scoped Flask application fixture using test containers. + + This fixture provides a Flask application instance that is configured + to use the test containers for all database and service connections. + + Args: + containers: Container manager fixture + + Returns: + Flask: Configured Flask application + """ + logger.info("=== Creating session-scoped Flask application ===") + app = _create_app_with_containers() + logger.info("Session-scoped Flask application created successfully") + return app + + +@pytest.fixture +def flask_req_ctx_with_containers(flask_app_with_containers) -> Generator[None, None, None]: + """ + Request context fixture for containerized Flask application. + + This fixture provides a Flask request context for tests that need + to interact with the Flask application within a request scope. + + Args: + flask_app_with_containers: Flask application fixture + + Yields: + None: Request context is active during yield + """ + logger.debug("Creating Flask request context...") + with flask_app_with_containers.test_request_context(): + logger.debug("Flask request context active") + yield + logger.debug("Flask request context closed") + + +@pytest.fixture +def test_client_with_containers(flask_app_with_containers) -> Generator[FlaskClient, None, None]: + """ + Test client fixture for containerized Flask application. + + This fixture provides a Flask test client that can be used to make + HTTP requests to the containerized application for integration testing. + + Args: + flask_app_with_containers: Flask application fixture + + Yields: + FlaskClient: Test client instance + """ + logger.debug("Creating Flask test client...") + with flask_app_with_containers.test_client() as client: + logger.debug("Flask test client ready") + yield client + logger.debug("Flask test client closed") + + +@pytest.fixture +def db_session_with_containers(flask_app_with_containers) -> Generator[Session, None, None]: + """ + Database session fixture for containerized testing. + + This fixture provides a SQLAlchemy database session that is connected + to the test PostgreSQL container, allowing tests to interact with + the database directly. + + Args: + flask_app_with_containers: Flask application fixture + + Yields: + Session: Database session instance + """ + logger.debug("Creating database session...") + with flask_app_with_containers.app_context(): + session = db.session() + logger.debug("Database session created and ready") + try: + yield session + finally: + session.close() + logger.debug("Database session closed") diff --git a/api/tests/test_containers_integration_tests/factories/__init__.py b/api/tests/test_containers_integration_tests/factories/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py b/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py new file mode 100644 index 0000000000..d6e14f3f54 --- /dev/null +++ b/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py @@ -0,0 +1,371 @@ +import unittest +from datetime import UTC, datetime +from typing import Optional +from unittest.mock import patch +from uuid import uuid4 + +import pytest +from sqlalchemy.orm import Session + +from core.file import File, FileTransferMethod, FileType +from extensions.ext_database import db +from factories.file_factory import StorageKeyLoader +from models import ToolFile, UploadFile +from models.enums import CreatorUserRole + + +@pytest.mark.usefixtures("flask_req_ctx_with_containers") +class TestStorageKeyLoader(unittest.TestCase): + """ + Integration tests for StorageKeyLoader class. + + Tests the batched loading of storage keys from the database for files + with different transfer methods: LOCAL_FILE, REMOTE_URL, and TOOL_FILE. + """ + + def setUp(self): + """Set up test data before each test method.""" + self.session = db.session() + self.tenant_id = str(uuid4()) + self.user_id = str(uuid4()) + self.conversation_id = str(uuid4()) + + # Create test data that will be cleaned up after each test + self.test_upload_files = [] + self.test_tool_files = [] + + # Create StorageKeyLoader instance + self.loader = StorageKeyLoader(self.session, self.tenant_id) + + def tearDown(self): + """Clean up test data after each test method.""" + self.session.rollback() + + def _create_upload_file( + self, file_id: Optional[str] = None, storage_key: Optional[str] = None, tenant_id: Optional[str] = None + ) -> UploadFile: + """Helper method to create an UploadFile record for testing.""" + if file_id is None: + file_id = str(uuid4()) + if storage_key is None: + storage_key = f"test_storage_key_{uuid4()}" + if tenant_id is None: + tenant_id = self.tenant_id + + upload_file = UploadFile( + tenant_id=tenant_id, + storage_type="local", + key=storage_key, + name="test_file.txt", + size=1024, + extension=".txt", + mime_type="text/plain", + created_by_role=CreatorUserRole.ACCOUNT, + created_by=self.user_id, + created_at=datetime.now(UTC), + used=False, + ) + upload_file.id = file_id + + self.session.add(upload_file) + self.session.flush() + self.test_upload_files.append(upload_file) + + return upload_file + + def _create_tool_file( + self, file_id: Optional[str] = None, file_key: Optional[str] = None, tenant_id: Optional[str] = None + ) -> ToolFile: + """Helper method to create a ToolFile record for testing.""" + if file_id is None: + file_id = str(uuid4()) + if file_key is None: + file_key = f"test_file_key_{uuid4()}" + if tenant_id is None: + tenant_id = self.tenant_id + + tool_file = ToolFile() + tool_file.id = file_id + tool_file.user_id = self.user_id + tool_file.tenant_id = tenant_id + tool_file.conversation_id = self.conversation_id + tool_file.file_key = file_key + tool_file.mimetype = "text/plain" + tool_file.original_url = "http://example.com/file.txt" + tool_file.name = "test_tool_file.txt" + tool_file.size = 2048 + + self.session.add(tool_file) + self.session.flush() + self.test_tool_files.append(tool_file) + + return tool_file + + def _create_file( + self, related_id: str, transfer_method: FileTransferMethod, tenant_id: Optional[str] = None + ) -> File: + """Helper method to create a File object for testing.""" + if tenant_id is None: + tenant_id = self.tenant_id + + # Set related_id for LOCAL_FILE and TOOL_FILE transfer methods + file_related_id = None + remote_url = None + + if transfer_method in (FileTransferMethod.LOCAL_FILE, FileTransferMethod.TOOL_FILE): + file_related_id = related_id + elif transfer_method == FileTransferMethod.REMOTE_URL: + remote_url = "https://example.com/test_file.txt" + file_related_id = related_id + + return File( + id=str(uuid4()), # Generate new UUID for File.id + tenant_id=tenant_id, + type=FileType.DOCUMENT, + transfer_method=transfer_method, + related_id=file_related_id, + remote_url=remote_url, + filename="test_file.txt", + extension=".txt", + mime_type="text/plain", + size=1024, + storage_key="initial_key", + ) + + def test_load_storage_keys_local_file(self): + """Test loading storage keys for LOCAL_FILE transfer method.""" + # Create test data + upload_file = self._create_upload_file() + file = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE) + + # Load storage keys + self.loader.load_storage_keys([file]) + + # Verify storage key was loaded correctly + assert file._storage_key == upload_file.key + + def test_load_storage_keys_remote_url(self): + """Test loading storage keys for REMOTE_URL transfer method.""" + # Create test data + upload_file = self._create_upload_file() + file = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.REMOTE_URL) + + # Load storage keys + self.loader.load_storage_keys([file]) + + # Verify storage key was loaded correctly + assert file._storage_key == upload_file.key + + def test_load_storage_keys_tool_file(self): + """Test loading storage keys for TOOL_FILE transfer method.""" + # Create test data + tool_file = self._create_tool_file() + file = self._create_file(related_id=tool_file.id, transfer_method=FileTransferMethod.TOOL_FILE) + + # Load storage keys + self.loader.load_storage_keys([file]) + + # Verify storage key was loaded correctly + assert file._storage_key == tool_file.file_key + + def test_load_storage_keys_mixed_methods(self): + """Test batch loading with mixed transfer methods.""" + # Create test data for different transfer methods + upload_file1 = self._create_upload_file() + upload_file2 = self._create_upload_file() + tool_file = self._create_tool_file() + + file1 = self._create_file(related_id=upload_file1.id, transfer_method=FileTransferMethod.LOCAL_FILE) + file2 = self._create_file(related_id=upload_file2.id, transfer_method=FileTransferMethod.REMOTE_URL) + file3 = self._create_file(related_id=tool_file.id, transfer_method=FileTransferMethod.TOOL_FILE) + + files = [file1, file2, file3] + + # Load storage keys + self.loader.load_storage_keys(files) + + # Verify all storage keys were loaded correctly + assert file1._storage_key == upload_file1.key + assert file2._storage_key == upload_file2.key + assert file3._storage_key == tool_file.file_key + + def test_load_storage_keys_empty_list(self): + """Test with empty file list.""" + # Should not raise any exceptions + self.loader.load_storage_keys([]) + + def test_load_storage_keys_tenant_mismatch(self): + """Test tenant_id validation.""" + # Create file with different tenant_id + upload_file = self._create_upload_file() + file = self._create_file( + related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE, tenant_id=str(uuid4()) + ) + + # Should raise ValueError for tenant mismatch + with pytest.raises(ValueError) as context: + self.loader.load_storage_keys([file]) + + assert "invalid file, expected tenant_id" in str(context.value) + + def test_load_storage_keys_missing_file_id(self): + """Test with None file.related_id.""" + # Create a file with valid parameters first, then manually set related_id to None + file = self._create_file(related_id=str(uuid4()), transfer_method=FileTransferMethod.LOCAL_FILE) + file.related_id = None + + # Should raise ValueError for None file related_id + with pytest.raises(ValueError) as context: + self.loader.load_storage_keys([file]) + + assert str(context.value) == "file id should not be None." + + def test_load_storage_keys_nonexistent_upload_file_records(self): + """Test with missing UploadFile database records.""" + # Create file with non-existent upload file id + non_existent_id = str(uuid4()) + file = self._create_file(related_id=non_existent_id, transfer_method=FileTransferMethod.LOCAL_FILE) + + # Should raise ValueError for missing record + with pytest.raises(ValueError): + self.loader.load_storage_keys([file]) + + def test_load_storage_keys_nonexistent_tool_file_records(self): + """Test with missing ToolFile database records.""" + # Create file with non-existent tool file id + non_existent_id = str(uuid4()) + file = self._create_file(related_id=non_existent_id, transfer_method=FileTransferMethod.TOOL_FILE) + + # Should raise ValueError for missing record + with pytest.raises(ValueError): + self.loader.load_storage_keys([file]) + + def test_load_storage_keys_invalid_uuid(self): + """Test with invalid UUID format.""" + # Create a file with valid parameters first, then manually set invalid related_id + file = self._create_file(related_id=str(uuid4()), transfer_method=FileTransferMethod.LOCAL_FILE) + file.related_id = "invalid-uuid-format" + + # Should raise ValueError for invalid UUID + with pytest.raises(ValueError): + self.loader.load_storage_keys([file]) + + def test_load_storage_keys_batch_efficiency(self): + """Test batched operations use efficient queries.""" + # Create multiple files of different types + upload_files = [self._create_upload_file() for _ in range(3)] + tool_files = [self._create_tool_file() for _ in range(2)] + + files = [] + files.extend( + [self._create_file(related_id=uf.id, transfer_method=FileTransferMethod.LOCAL_FILE) for uf in upload_files] + ) + files.extend( + [self._create_file(related_id=tf.id, transfer_method=FileTransferMethod.TOOL_FILE) for tf in tool_files] + ) + + # Mock the session to count queries + with patch.object(self.session, "scalars", wraps=self.session.scalars) as mock_scalars: + self.loader.load_storage_keys(files) + + # Should make exactly 2 queries (one for upload_files, one for tool_files) + assert mock_scalars.call_count == 2 + + # Verify all storage keys were loaded correctly + for i, file in enumerate(files[:3]): + assert file._storage_key == upload_files[i].key + for i, file in enumerate(files[3:]): + assert file._storage_key == tool_files[i].file_key + + def test_load_storage_keys_tenant_isolation(self): + """Test that tenant isolation works correctly.""" + # Create files for different tenants + other_tenant_id = str(uuid4()) + + # Create upload file for current tenant + upload_file_current = self._create_upload_file() + file_current = self._create_file( + related_id=upload_file_current.id, transfer_method=FileTransferMethod.LOCAL_FILE + ) + + # Create upload file for other tenant (but don't add to cleanup list) + upload_file_other = UploadFile( + tenant_id=other_tenant_id, + storage_type="local", + key="other_tenant_key", + name="other_file.txt", + size=1024, + extension=".txt", + mime_type="text/plain", + created_by_role=CreatorUserRole.ACCOUNT, + created_by=self.user_id, + created_at=datetime.now(UTC), + used=False, + ) + upload_file_other.id = str(uuid4()) + self.session.add(upload_file_other) + self.session.flush() + + # Create file for other tenant but try to load with current tenant's loader + file_other = self._create_file( + related_id=upload_file_other.id, transfer_method=FileTransferMethod.LOCAL_FILE, tenant_id=other_tenant_id + ) + + # Should raise ValueError due to tenant mismatch + with pytest.raises(ValueError) as context: + self.loader.load_storage_keys([file_other]) + + assert "invalid file, expected tenant_id" in str(context.value) + + # Current tenant's file should still work + self.loader.load_storage_keys([file_current]) + assert file_current._storage_key == upload_file_current.key + + def test_load_storage_keys_mixed_tenant_batch(self): + """Test batch with mixed tenant files (should fail on first mismatch).""" + # Create files for current tenant + upload_file_current = self._create_upload_file() + file_current = self._create_file( + related_id=upload_file_current.id, transfer_method=FileTransferMethod.LOCAL_FILE + ) + + # Create file for different tenant + other_tenant_id = str(uuid4()) + file_other = self._create_file( + related_id=str(uuid4()), transfer_method=FileTransferMethod.LOCAL_FILE, tenant_id=other_tenant_id + ) + + # Should raise ValueError on tenant mismatch + with pytest.raises(ValueError) as context: + self.loader.load_storage_keys([file_current, file_other]) + + assert "invalid file, expected tenant_id" in str(context.value) + + def test_load_storage_keys_duplicate_file_ids(self): + """Test handling of duplicate file IDs in the batch.""" + # Create upload file + upload_file = self._create_upload_file() + + # Create two File objects with same related_id + file1 = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE) + file2 = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE) + + # Should handle duplicates gracefully + self.loader.load_storage_keys([file1, file2]) + + # Both files should have the same storage key + assert file1._storage_key == upload_file.key + assert file2._storage_key == upload_file.key + + def test_load_storage_keys_session_isolation(self): + """Test that the loader uses the provided session correctly.""" + # Create test data + upload_file = self._create_upload_file() + file = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE) + + # Create loader with different session (same underlying connection) + + with Session(bind=db.engine) as other_session: + other_loader = StorageKeyLoader(other_session, self.tenant_id) + with pytest.raises(ValueError): + other_loader.load_storage_keys([file]) diff --git a/api/tests/test_containers_integration_tests/services/__init__.py b/api/tests/test_containers_integration_tests/services/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/test_containers_integration_tests/services/test_account_service.py b/api/tests/test_containers_integration_tests/services/test_account_service.py new file mode 100644 index 0000000000..415e65ce51 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_account_service.py @@ -0,0 +1,3340 @@ +import json +from hashlib import sha256 +from unittest.mock import patch + +import pytest +from faker import Faker +from werkzeug.exceptions import Unauthorized + +from configs import dify_config +from controllers.console.error import AccountNotFound, NotAllowedCreateWorkspace +from models.account import AccountStatus, TenantAccountJoin +from services.account_service import AccountService, RegisterService, TenantService, TokenPair +from services.errors.account import ( + AccountAlreadyInTenantError, + AccountLoginError, + AccountNotFoundError, + AccountPasswordError, + AccountRegisterError, + CurrentPasswordIncorrectError, +) +from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError + + +class TestAccountService: + """Integration tests for AccountService using testcontainers.""" + + @pytest.fixture + def mock_external_service_dependencies(self): + """Mock setup for external service dependencies.""" + with ( + patch("services.account_service.FeatureService") as mock_feature_service, + patch("services.account_service.BillingService") as mock_billing_service, + patch("services.account_service.PassportService") as mock_passport_service, + ): + # Setup default mock returns + mock_feature_service.get_system_features.return_value.is_allow_register = True + mock_feature_service.get_system_features.return_value.is_allow_create_workspace = True + mock_feature_service.get_system_features.return_value.license.workspaces.is_available.return_value = True + mock_billing_service.is_email_in_freeze.return_value = False + mock_passport_service.return_value.issue.return_value = "mock_jwt_token" + + yield { + "feature_service": mock_feature_service, + "billing_service": mock_billing_service, + "passport_service": mock_passport_service, + } + + def test_create_account_and_login(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test account creation and login with correct password. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + assert account.email == email + assert account.status == AccountStatus.ACTIVE.value + + # Login with correct password + logged_in = AccountService.authenticate(email, password) + assert logged_in.id == account.id + + def test_create_account_without_password(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test account creation without password (for OAuth users). + """ + fake = Faker() + email = fake.email() + name = fake.name() + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=None, + ) + assert account.email == email + assert account.password is None + assert account.password_salt is None + + def test_create_account_registration_disabled(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test account creation when registration is disabled. + """ + fake = Faker() + email = fake.email() + name = fake.name() + # Setup mocks to disable registration + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = False + + with pytest.raises(AccountNotFound): # AccountNotFound exception + AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=fake.password(length=12), + ) + + def test_create_account_email_in_freeze(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test account creation when email is in freeze period. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = True + dify_config.BILLING_ENABLED = True + + with pytest.raises(AccountRegisterError): + AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + dify_config.BILLING_ENABLED = False # Reset config for other tests + + def test_authenticate_account_not_found(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test authentication with non-existent account. + """ + fake = Faker() + email = fake.email() + password = fake.password(length=12) + with pytest.raises(AccountNotFoundError): + AccountService.authenticate(email, password) + + def test_authenticate_banned_account(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test authentication with banned account. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create account first + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + # Ban the account + account.status = AccountStatus.BANNED.value + from extensions.ext_database import db + + db.session.commit() + + with pytest.raises(AccountLoginError): + AccountService.authenticate(email, password) + + def test_authenticate_wrong_password(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test authentication with wrong password. + """ + fake = Faker() + email = fake.email() + name = fake.name() + correct_password = fake.password(length=12) + wrong_password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create account first + AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=correct_password, + ) + + with pytest.raises(AccountPasswordError): + AccountService.authenticate(email, wrong_password) + + def test_authenticate_with_invite_token(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test authentication with invite token to set password for account without password. + """ + fake = Faker() + email = fake.email() + name = fake.name() + new_password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create account without password + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=None, + ) + + # Authenticate with invite token to set password + authenticated_account = AccountService.authenticate( + email, + new_password, + invite_token="valid_invite_token", + ) + + assert authenticated_account.id == account.id + assert authenticated_account.password is not None + assert authenticated_account.password_salt is not None + + def test_authenticate_pending_account_activation( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test authentication activates pending account. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create account with pending status + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + account.status = AccountStatus.PENDING.value + from extensions.ext_database import db + + db.session.commit() + + # Authenticate should activate the account + authenticated_account = AccountService.authenticate(email, password) + assert authenticated_account.status == AccountStatus.ACTIVE.value + assert authenticated_account.initialized_at is not None + + def test_update_account_password_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful password update. + """ + fake = Faker() + email = fake.email() + name = fake.name() + old_password = fake.password(length=12) + new_password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create account + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=old_password, + ) + + # Update password + updated_account = AccountService.update_account_password(account, old_password, new_password) + + # Verify new password works + authenticated_account = AccountService.authenticate(email, new_password) + assert authenticated_account.id == account.id + + def test_update_account_password_wrong_current_password( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test password update with wrong current password. + """ + fake = Faker() + email = fake.email() + name = fake.name() + old_password = fake.password(length=12) + wrong_password = fake.password(length=12) + new_password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create account + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=old_password, + ) + + with pytest.raises(CurrentPasswordIncorrectError): + AccountService.update_account_password(account, wrong_password, new_password) + + def test_update_account_password_invalid_new_password( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test password update with invalid new password format. + """ + fake = Faker() + email = fake.email() + name = fake.name() + old_password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create account + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=old_password, + ) + + # Test with too short password (assuming minimum length validation) + with pytest.raises(ValueError): # Password validation error + AccountService.update_account_password(account, old_password, "123") + + def test_create_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test account creation with automatic tenant creation. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.license.workspaces.is_available.return_value = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + account = AccountService.create_account_and_tenant( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + assert account.email == email + + # Verify tenant was created and linked + from extensions.ext_database import db + + tenant_join = db.session.query(TenantAccountJoin).filter_by(account_id=account.id).first() + assert tenant_join is not None + assert tenant_join.role == "owner" + + def test_create_account_and_tenant_workspace_creation_disabled( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test account creation when workspace creation is disabled. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = False + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + with pytest.raises(WorkSpaceNotAllowedCreateError): + AccountService.create_account_and_tenant( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + def test_create_account_and_tenant_workspace_limit_exceeded( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test account creation when workspace limit is exceeded. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.license.workspaces.is_available.return_value = False + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + with pytest.raises(WorkspacesLimitExceededError): + AccountService.create_account_and_tenant( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + def test_link_account_integrate_new_provider(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test linking account with new OAuth provider. + """ + fake = Faker() + email = fake.email() + name = fake.name() + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create account + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=None, + ) + + # Link with new provider + AccountService.link_account_integrate("new-google", "google_open_id_123", account) + + # Verify integration was created + from extensions.ext_database import db + from models.account import AccountIntegrate + + integration = db.session.query(AccountIntegrate).filter_by(account_id=account.id, provider="new-google").first() + assert integration is not None + assert integration.open_id == "google_open_id_123" + + def test_link_account_integrate_existing_provider( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test linking account with existing provider (should update). + """ + fake = Faker() + email = fake.email() + name = fake.name() + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create account + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=None, + ) + + # Link with provider first time + AccountService.link_account_integrate("exists-google", "google_open_id_123", account) + + # Link with same provider but different open_id (should update) + AccountService.link_account_integrate("exists-google", "google_open_id_456", account) + + # Verify integration was updated + from extensions.ext_database import db + from models.account import AccountIntegrate + + integration = ( + db.session.query(AccountIntegrate).filter_by(account_id=account.id, provider="exists-google").first() + ) + assert integration.open_id == "google_open_id_456" + + def test_close_account(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test closing an account. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create account + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + # Close account + AccountService.close_account(account) + + # Verify account status changed + from extensions.ext_database import db + + db.session.refresh(account) + assert account.status == AccountStatus.CLOSED.value + + def test_update_account_fields(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test updating account fields. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + updated_name = fake.name() + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create account + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + # Update account fields + updated_account = AccountService.update_account(account, name=updated_name, interface_theme="dark") + + assert updated_account.name == updated_name + assert updated_account.interface_theme == "dark" + + def test_update_account_invalid_field(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test updating account with invalid field. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create account + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + with pytest.raises(AttributeError): + AccountService.update_account(account, invalid_field="value") + + def test_update_login_info(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test updating login information. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + ip_address = fake.ipv4() + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create account + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + # Update login info + AccountService.update_login_info(account, ip_address=ip_address) + + # Verify login info was updated + from extensions.ext_database import db + + db.session.refresh(account) + assert account.last_login_ip == ip_address + assert account.last_login_at is not None + + def test_login_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful login with token generation. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + ip_address = fake.ipv4() + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + mock_external_service_dependencies["passport_service"].return_value.issue.return_value = "mock_access_token" + + # Create account + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + # Login + token_pair = AccountService.login(account, ip_address=ip_address) + + assert isinstance(token_pair, TokenPair) + assert token_pair.access_token == "mock_access_token" + assert token_pair.refresh_token is not None + + # Verify passport service was called with correct parameters + mock_passport = mock_external_service_dependencies["passport_service"].return_value + mock_passport.issue.assert_called_once() + call_args = mock_passport.issue.call_args[0][0] + assert call_args["user_id"] == account.id + assert call_args["iss"] is not None + assert call_args["sub"] == "Console API Passport" + + def test_login_pending_account_activation(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test login activates pending account. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + mock_external_service_dependencies["passport_service"].return_value.issue.return_value = "mock_access_token" + + # Create account with pending status + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + account.status = AccountStatus.PENDING.value + from extensions.ext_database import db + + db.session.commit() + + # Login should activate the account + token_pair = AccountService.login(account) + + db.session.refresh(account) + assert account.status == AccountStatus.ACTIVE.value + + def test_logout(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test logout functionality. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + mock_external_service_dependencies["passport_service"].return_value.issue.return_value = "mock_access_token" + + # Create account + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + # Login first to get refresh token + token_pair = AccountService.login(account) + + # Logout + AccountService.logout(account=account) + + # Verify refresh token was deleted from Redis + from extensions.ext_redis import redis_client + + refresh_token_key = f"account_refresh_token:{account.id}" + assert redis_client.get(refresh_token_key) is None + + def test_refresh_token_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful token refresh. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + tenant_name = fake.company() + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + mock_external_service_dependencies["passport_service"].return_value.issue.return_value = "new_mock_access_token" + + # Create account + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + # Create associated Tenant + TenantService.create_owner_tenant_if_not_exist(account=account, name=tenant_name, is_setup=True) + + # Login to get initial tokens + initial_token_pair = AccountService.login(account) + + # Refresh token + new_token_pair = AccountService.refresh_token(initial_token_pair.refresh_token) + + assert isinstance(new_token_pair, TokenPair) + assert new_token_pair.access_token == "new_mock_access_token" + assert new_token_pair.refresh_token != initial_token_pair.refresh_token + + def test_refresh_token_invalid_token(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test refresh token with invalid token. + """ + fake = Faker() + invalid_token = fake.uuid4() + with pytest.raises(ValueError, match="Invalid refresh token"): + AccountService.refresh_token(invalid_token) + + def test_refresh_token_invalid_account(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test refresh token with valid token but invalid account. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + mock_external_service_dependencies["passport_service"].return_value.issue.return_value = "mock_access_token" + + # Create account + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + # Login to get tokens + token_pair = AccountService.login(account) + + # Delete account + from extensions.ext_database import db + + db.session.delete(account) + db.session.commit() + + # Try to refresh token with deleted account + with pytest.raises(ValueError, match="Invalid account"): + AccountService.refresh_token(token_pair.refresh_token) + + def test_load_user_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test loading user by ID successfully. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + tenant_name = fake.company() + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create account + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + # Create associated Tenant + TenantService.create_owner_tenant_if_not_exist(account=account, name=tenant_name, is_setup=True) + + # Load user + loaded_user = AccountService.load_user(account.id) + + assert loaded_user is not None + assert loaded_user.id == account.id + assert loaded_user.email == account.email + + def test_load_user_not_found(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test loading non-existent user. + """ + fake = Faker() + non_existent_user_id = fake.uuid4() + loaded_user = AccountService.load_user(non_existent_user_id) + assert loaded_user is None + + def test_load_user_banned_account(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test loading banned user raises Unauthorized. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create account + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + # Ban the account + account.status = AccountStatus.BANNED.value + from extensions.ext_database import db + + db.session.commit() + + with pytest.raises(Unauthorized): # Unauthorized exception + AccountService.load_user(account.id) + + def test_get_account_jwt_token(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test JWT token generation for account. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + mock_external_service_dependencies["passport_service"].return_value.issue.return_value = "mock_jwt_token" + + # Create account + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + # Generate JWT token + token = AccountService.get_account_jwt_token(account) + + assert token == "mock_jwt_token" + + # Verify passport service was called with correct parameters + mock_passport = mock_external_service_dependencies["passport_service"].return_value + mock_passport.issue.assert_called_once() + call_args = mock_passport.issue.call_args[0][0] + assert call_args["user_id"] == account.id + assert call_args["iss"] is not None + assert call_args["sub"] == "Console API Passport" + + def test_load_logged_in_account(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test loading logged in account by ID. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + tenant_name = fake.company() + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create account + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + # Create associated Tenant + TenantService.create_owner_tenant_if_not_exist(account=account, name=tenant_name, is_setup=True) + + # Load logged in account + loaded_account = AccountService.load_logged_in_account(account_id=account.id) + + assert loaded_account is not None + assert loaded_account.id == account.id + + def test_get_user_through_email_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test getting user through email successfully. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create account + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + # Get user through email + found_user = AccountService.get_user_through_email(email) + + assert found_user is not None + assert found_user.id == account.id + + def test_get_user_through_email_not_found(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test getting user through non-existent email. + """ + fake = Faker() + non_existent_email = fake.email() + found_user = AccountService.get_user_through_email(non_existent_email) + assert found_user is None + + def test_get_user_through_email_banned_account( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test getting banned user through email raises Unauthorized. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create account + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + # Ban the account + account.status = AccountStatus.BANNED.value + from extensions.ext_database import db + + db.session.commit() + + with pytest.raises(Unauthorized): # Unauthorized exception + AccountService.get_user_through_email(email) + + def test_get_user_through_email_in_freeze(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test getting user through email that is in freeze period. + """ + fake = Faker() + email_in_freeze = fake.email() + # Setup mocks + dify_config.BILLING_ENABLED = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = True + + with pytest.raises(AccountRegisterError): + AccountService.get_user_through_email(email_in_freeze) + + # Reset config + dify_config.BILLING_ENABLED = False + + def test_delete_account(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test account deletion (should add task to queue). + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create account + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + with patch("services.account_service.delete_account_task") as mock_delete_task: + # Delete account + AccountService.delete_account(account) + + # Verify task was added to queue + mock_delete_task.delay.assert_called_once_with(account.id) + + def test_generate_account_deletion_verification_code( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test generating account deletion verification code. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create account + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + # Generate verification code + token, code = AccountService.generate_account_deletion_verification_code(account) + + assert token is not None + assert code is not None + assert len(code) == 6 + assert code.isdigit() + + def test_verify_account_deletion_code_valid(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test verifying valid account deletion code. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create account + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + # Generate verification code + token, code = AccountService.generate_account_deletion_verification_code(account) + + # Verify code + is_valid = AccountService.verify_account_deletion_code(token, code) + assert is_valid is True + + def test_verify_account_deletion_code_invalid(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test verifying invalid account deletion code. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + wrong_code = fake.numerify(text="######") + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create account + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + # Generate verification code + token, code = AccountService.generate_account_deletion_verification_code(account) + + # Verify with wrong code + is_valid = AccountService.verify_account_deletion_code(token, wrong_code) + assert is_valid is False + + def test_verify_account_deletion_code_invalid_token( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test verifying account deletion code with invalid token. + """ + fake = Faker() + invalid_token = fake.uuid4() + invalid_code = fake.numerify(text="######") + is_valid = AccountService.verify_account_deletion_code(invalid_token, invalid_code) + assert is_valid is False + + +class TestTenantService: + """Integration tests for TenantService using testcontainers.""" + + @pytest.fixture + def mock_external_service_dependencies(self): + """Mock setup for external service dependencies.""" + with ( + patch("services.account_service.FeatureService") as mock_feature_service, + patch("services.account_service.BillingService") as mock_billing_service, + ): + # Setup default mock returns + mock_feature_service.get_system_features.return_value.is_allow_create_workspace = True + mock_feature_service.get_system_features.return_value.license.workspaces.is_available.return_value = True + mock_billing_service.is_email_in_freeze.return_value = False + + yield { + "feature_service": mock_feature_service, + "billing_service": mock_billing_service, + } + + def test_create_tenant_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful tenant creation with default settings. + """ + fake = Faker() + tenant_name = fake.company() + # Setup mocks + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + + # Create tenant + tenant = TenantService.create_tenant(name=tenant_name) + + assert tenant.name == tenant_name + assert tenant.plan == "basic" + assert tenant.status == "normal" + assert tenant.encrypt_public_key is not None + + def test_create_tenant_workspace_creation_disabled( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test tenant creation when workspace creation is disabled. + """ + fake = Faker() + tenant_name = fake.company() + # Setup mocks to disable workspace creation + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = False + + with pytest.raises(NotAllowedCreateWorkspace): # NotAllowedCreateWorkspace exception + TenantService.create_tenant(name=tenant_name) + + def test_create_tenant_with_custom_name(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test tenant creation with custom name and setup flag. + """ + fake = Faker() + custom_tenant_name = fake.company() + # Setup mocks + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = False + + # Create tenant with setup flag (should bypass workspace creation restriction) + tenant = TenantService.create_tenant(name=custom_tenant_name, is_setup=True, is_from_dashboard=True) + + assert tenant.name == custom_tenant_name + assert tenant.plan == "basic" + assert tenant.status == "normal" + assert tenant.encrypt_public_key is not None + + def test_create_tenant_member_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful tenant member creation. + """ + fake = Faker() + tenant_name = fake.company() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + + # Create tenant and account + tenant = TenantService.create_tenant(name=tenant_name) + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + # Create tenant member + tenant_member = TenantService.create_tenant_member(tenant, account, role="admin") + + assert tenant_member.tenant_id == tenant.id + assert tenant_member.account_id == account.id + assert tenant_member.role == "admin" + + def test_create_tenant_member_duplicate_owner(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test creating duplicate owner for a tenant (should fail). + """ + fake = Faker() + tenant_name = fake.company() + email1 = fake.email() + name1 = fake.name() + password1 = fake.password(length=12) + email2 = fake.email() + name2 = fake.name() + password2 = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + + # Create tenant and accounts + tenant = TenantService.create_tenant(name=tenant_name) + account1 = AccountService.create_account( + email=email1, + name=name1, + interface_language="en-US", + password=password1, + ) + account2 = AccountService.create_account( + email=email2, + name=name2, + interface_language="en-US", + password=password2, + ) + + # Create first owner + TenantService.create_tenant_member(tenant, account1, role="owner") + + # Try to create second owner (should fail) + with pytest.raises(Exception, match="Tenant already has an owner"): + TenantService.create_tenant_member(tenant, account2, role="owner") + + def test_create_tenant_member_existing_member(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test updating role for existing tenant member. + """ + fake = Faker() + tenant_name = fake.company() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + + # Create tenant and account + tenant = TenantService.create_tenant(name=tenant_name) + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + # Create member with initial role + tenant_member1 = TenantService.create_tenant_member(tenant, account, role="normal") + assert tenant_member1.role == "normal" + + # Update member role + tenant_member2 = TenantService.create_tenant_member(tenant, account, role="editor") + assert tenant_member2.tenant_id == tenant_member1.tenant_id + assert tenant_member2.account_id == tenant_member1.account_id + assert tenant_member2.role == "editor" + + def test_get_join_tenants_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test getting join tenants for an account. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + tenant1_name = fake.company() + tenant2_name = fake.company() + # Setup mocks + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + + # Create account and tenants + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + tenant1 = TenantService.create_tenant(name=tenant1_name) + tenant2 = TenantService.create_tenant(name=tenant2_name) + + # Add account to both tenants + TenantService.create_tenant_member(tenant1, account, role="normal") + TenantService.create_tenant_member(tenant2, account, role="admin") + + # Get join tenants + join_tenants = TenantService.get_join_tenants(account) + + assert len(join_tenants) == 2 + tenant_names = [tenant.name for tenant in join_tenants] + assert tenant1_name in tenant_names + assert tenant2_name in tenant_names + + def test_get_current_tenant_by_account_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test getting current tenant by account successfully. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + tenant_name = fake.company() + # Setup mocks + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + + # Create account and tenant + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + tenant = TenantService.create_tenant(name=tenant_name) + + # Add account to tenant and set as current + TenantService.create_tenant_member(tenant, account, role="owner") + account.current_tenant = tenant + from extensions.ext_database import db + + db.session.commit() + + # Get current tenant + current_tenant = TenantService.get_current_tenant_by_account(account) + + assert current_tenant.id == tenant.id + assert current_tenant.name == tenant.name + assert current_tenant.role == "owner" + + def test_get_current_tenant_by_account_not_found( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test getting current tenant when account has no current tenant. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + + # Create account without setting current tenant + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + # Try to get current tenant (should fail) + with pytest.raises(AttributeError): + TenantService.get_current_tenant_by_account(account) + + def test_switch_tenant_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful tenant switching. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + tenant1_name = fake.company() + tenant2_name = fake.company() + # Setup mocks + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + + # Create account and tenants + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + tenant1 = TenantService.create_tenant(name=tenant1_name) + tenant2 = TenantService.create_tenant(name=tenant2_name) + + # Add account to both tenants + TenantService.create_tenant_member(tenant1, account, role="owner") + TenantService.create_tenant_member(tenant2, account, role="admin") + + # Set initial current tenant + account.current_tenant = tenant1 + from extensions.ext_database import db + + db.session.commit() + + # Switch to second tenant + TenantService.switch_tenant(account, tenant2.id) + + # Verify tenant was switched + db.session.refresh(account) + assert account.current_tenant_id == tenant2.id + + def test_switch_tenant_no_tenant_id(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test tenant switching without providing tenant ID. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + + # Create account + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + # Try to switch tenant without providing tenant ID + with pytest.raises(ValueError, match="Tenant ID must be provided"): + TenantService.switch_tenant(account, None) + + def test_switch_tenant_account_not_member(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test switching to a tenant where account is not a member. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + tenant_name = fake.company() + # Setup mocks + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + + # Create account and tenant + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + tenant = TenantService.create_tenant(name=tenant_name) + + # Try to switch to tenant where account is not a member + with pytest.raises(Exception, match="Tenant not found or account is not a member of the tenant"): + TenantService.switch_tenant(account, tenant.id) + + def test_has_roles_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test checking if tenant has specific roles. + """ + fake = Faker() + tenant_name = fake.company() + owner_email = fake.email() + owner_name = fake.name() + owner_password = fake.password(length=12) + admin_email = fake.email() + admin_name = fake.name() + admin_password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + + # Create tenant and accounts + tenant = TenantService.create_tenant(name=tenant_name) + owner_account = AccountService.create_account( + email=owner_email, + name=owner_name, + interface_language="en-US", + password=owner_password, + ) + admin_account = AccountService.create_account( + email=admin_email, + name=admin_name, + interface_language="en-US", + password=admin_password, + ) + + # Add members with different roles + TenantService.create_tenant_member(tenant, owner_account, role="owner") + TenantService.create_tenant_member(tenant, admin_account, role="admin") + + # Check if tenant has owner role + from models.account import TenantAccountRole + + has_owner = TenantService.has_roles(tenant, [TenantAccountRole.OWNER]) + assert has_owner is True + + # Check if tenant has admin role + has_admin = TenantService.has_roles(tenant, [TenantAccountRole.ADMIN]) + assert has_admin is True + + # Check if tenant has normal role (should be False) + has_normal = TenantService.has_roles(tenant, [TenantAccountRole.NORMAL]) + assert has_normal is False + + def test_has_roles_invalid_role_type(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test checking roles with invalid role type. + """ + fake = Faker() + tenant_name = fake.company() + invalid_role = fake.word() + # Setup mocks + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + + # Create tenant + tenant = TenantService.create_tenant(name=tenant_name) + + # Try to check roles with invalid role type + with pytest.raises(ValueError, match="all roles must be TenantAccountRole"): + TenantService.has_roles(tenant, [invalid_role]) + + def test_get_user_role_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test getting user role in a tenant. + """ + fake = Faker() + tenant_name = fake.company() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + + # Create tenant and account + tenant = TenantService.create_tenant(name=tenant_name) + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + # Add account to tenant with specific role + TenantService.create_tenant_member(tenant, account, role="editor") + + # Get user role + user_role = TenantService.get_user_role(account, tenant) + + assert user_role == "editor" + + def test_check_member_permission_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test checking member permission successfully. + """ + fake = Faker() + tenant_name = fake.company() + owner_email = fake.email() + owner_name = fake.name() + owner_password = fake.password(length=12) + member_email = fake.email() + member_name = fake.name() + member_password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + + # Create tenant and accounts + tenant = TenantService.create_tenant(name=tenant_name) + owner_account = AccountService.create_account( + email=owner_email, + name=owner_name, + interface_language="en-US", + password=owner_password, + ) + member_account = AccountService.create_account( + email=member_email, + name=member_name, + interface_language="en-US", + password=member_password, + ) + + # Add members with different roles + TenantService.create_tenant_member(tenant, owner_account, role="owner") + TenantService.create_tenant_member(tenant, member_account, role="normal") + + # Check owner permission to add member (should succeed) + TenantService.check_member_permission(tenant, owner_account, member_account, "add") + + def test_check_member_permission_invalid_action( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test checking member permission with invalid action. + """ + fake = Faker() + tenant_name = fake.company() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + invalid_action = "invalid_action_that_doesnt_exist" + # Setup mocks + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + + # Create tenant and account + tenant = TenantService.create_tenant(name=tenant_name) + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + # Add account to tenant + TenantService.create_tenant_member(tenant, account, role="owner") + + # Try to check permission with invalid action + with pytest.raises(Exception, match="Invalid action"): + TenantService.check_member_permission(tenant, account, None, invalid_action) + + def test_check_member_permission_operate_self(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test checking member permission when trying to operate self. + """ + fake = Faker() + tenant_name = fake.company() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + + # Create tenant and account + tenant = TenantService.create_tenant(name=tenant_name) + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + # Add account to tenant + TenantService.create_tenant_member(tenant, account, role="owner") + + # Try to check permission to operate self + with pytest.raises(Exception, match="Cannot operate self"): + TenantService.check_member_permission(tenant, account, account, "remove") + + def test_remove_member_from_tenant_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful member removal from tenant. + """ + fake = Faker() + tenant_name = fake.company() + owner_email = fake.email() + owner_name = fake.name() + owner_password = fake.password(length=12) + member_email = fake.email() + member_name = fake.name() + member_password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + + # Create tenant and accounts + tenant = TenantService.create_tenant(name=tenant_name) + owner_account = AccountService.create_account( + email=owner_email, + name=owner_name, + interface_language="en-US", + password=owner_password, + ) + member_account = AccountService.create_account( + email=member_email, + name=member_name, + interface_language="en-US", + password=member_password, + ) + + # Add members with different roles + TenantService.create_tenant_member(tenant, owner_account, role="owner") + TenantService.create_tenant_member(tenant, member_account, role="normal") + + # Remove member + TenantService.remove_member_from_tenant(tenant, member_account, owner_account) + + # Verify member was removed + from extensions.ext_database import db + from models.account import TenantAccountJoin + + member_join = ( + db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=member_account.id).first() + ) + assert member_join is None + + def test_remove_member_from_tenant_operate_self( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test removing member when trying to operate self. + """ + fake = Faker() + tenant_name = fake.company() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + + # Create tenant and account + tenant = TenantService.create_tenant(name=tenant_name) + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + # Add account to tenant + TenantService.create_tenant_member(tenant, account, role="owner") + + # Try to remove self + with pytest.raises(Exception, match="Cannot operate self"): + TenantService.remove_member_from_tenant(tenant, account, account) + + def test_remove_member_from_tenant_not_member(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test removing member who is not in the tenant. + """ + fake = Faker() + tenant_name = fake.company() + owner_email = fake.email() + owner_name = fake.name() + owner_password = fake.password(length=12) + non_member_email = fake.email() + non_member_name = fake.name() + non_member_password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + + # Create tenant and accounts + tenant = TenantService.create_tenant(name=tenant_name) + owner_account = AccountService.create_account( + email=owner_email, + name=owner_name, + interface_language="en-US", + password=owner_password, + ) + non_member_account = AccountService.create_account( + email=non_member_email, + name=non_member_name, + interface_language="en-US", + password=non_member_password, + ) + + # Add only owner to tenant + TenantService.create_tenant_member(tenant, owner_account, role="owner") + + # Try to remove non-member + with pytest.raises(Exception, match="Member not in tenant"): + TenantService.remove_member_from_tenant(tenant, non_member_account, owner_account) + + def test_update_member_role_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful member role update. + """ + fake = Faker() + tenant_name = fake.company() + owner_email = fake.email() + owner_name = fake.name() + owner_password = fake.password(length=12) + member_email = fake.email() + member_name = fake.name() + member_password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + + # Create tenant and accounts + tenant = TenantService.create_tenant(name=tenant_name) + owner_account = AccountService.create_account( + email=owner_email, + name=owner_name, + interface_language="en-US", + password=owner_password, + ) + member_account = AccountService.create_account( + email=member_email, + name=member_name, + interface_language="en-US", + password=member_password, + ) + + # Add members with different roles + TenantService.create_tenant_member(tenant, owner_account, role="owner") + TenantService.create_tenant_member(tenant, member_account, role="normal") + + # Update member role + TenantService.update_member_role(tenant, member_account, "admin", owner_account) + + # Verify role was updated + from extensions.ext_database import db + from models.account import TenantAccountJoin + + member_join = ( + db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=member_account.id).first() + ) + assert member_join.role == "admin" + + def test_update_member_role_to_owner(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test updating member role to owner (should change current owner to admin). + """ + fake = Faker() + tenant_name = fake.company() + owner_email = fake.email() + owner_name = fake.name() + owner_password = fake.password(length=12) + member_email = fake.email() + member_name = fake.name() + member_password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + + # Create tenant and accounts + tenant = TenantService.create_tenant(name=tenant_name) + owner_account = AccountService.create_account( + email=owner_email, + name=owner_name, + interface_language="en-US", + password=owner_password, + ) + member_account = AccountService.create_account( + email=member_email, + name=member_name, + interface_language="en-US", + password=member_password, + ) + + # Add members with different roles + TenantService.create_tenant_member(tenant, owner_account, role="owner") + TenantService.create_tenant_member(tenant, member_account, role="admin") + + # Update member role to owner + TenantService.update_member_role(tenant, member_account, "owner", owner_account) + + # Verify roles were updated correctly + from extensions.ext_database import db + from models.account import TenantAccountJoin + + owner_join = ( + db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=owner_account.id).first() + ) + member_join = ( + db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=member_account.id).first() + ) + assert owner_join.role == "admin" + assert member_join.role == "owner" + + def test_update_member_role_already_assigned(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test updating member role to already assigned role. + """ + fake = Faker() + tenant_name = fake.company() + owner_email = fake.email() + owner_name = fake.name() + owner_password = fake.password(length=12) + member_email = fake.email() + member_name = fake.name() + member_password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + + # Create tenant and accounts + tenant = TenantService.create_tenant(name=tenant_name) + owner_account = AccountService.create_account( + email=owner_email, + name=owner_name, + interface_language="en-US", + password=owner_password, + ) + member_account = AccountService.create_account( + email=member_email, + name=member_name, + interface_language="en-US", + password=member_password, + ) + + # Add members with different roles + TenantService.create_tenant_member(tenant, owner_account, role="owner") + TenantService.create_tenant_member(tenant, member_account, role="admin") + + # Try to update member role to already assigned role + with pytest.raises(Exception, match="The provided role is already assigned to the member"): + TenantService.update_member_role(tenant, member_account, "admin", owner_account) + + def test_get_tenant_count_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test getting tenant count successfully. + """ + fake = Faker() + tenant1_name = fake.company() + tenant2_name = fake.company() + tenant3_name = fake.company() + # Setup mocks + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + + # Create multiple tenants + tenant1 = TenantService.create_tenant(name=tenant1_name) + tenant2 = TenantService.create_tenant(name=tenant2_name) + tenant3 = TenantService.create_tenant(name=tenant3_name) + + # Get tenant count + tenant_count = TenantService.get_tenant_count() + + # Should have at least 3 tenants (may be more from other tests) + assert tenant_count >= 3 + + def test_create_owner_tenant_if_not_exist_new_user( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test creating owner tenant for new user without existing tenants. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + workspace_name = fake.company() + # Setup mocks + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.license.workspaces.is_available.return_value = True + + # Create account + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + # Create owner tenant + TenantService.create_owner_tenant_if_not_exist(account, name=workspace_name) + + # Verify tenant was created and linked + from extensions.ext_database import db + from models.account import TenantAccountJoin + + tenant_join = db.session.query(TenantAccountJoin).filter_by(account_id=account.id).first() + assert tenant_join is not None + assert tenant_join.role == "owner" + assert account.current_tenant is not None + assert account.current_tenant.name == workspace_name + + def test_create_owner_tenant_if_not_exist_existing_tenant( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test creating owner tenant when user already has a tenant. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + existing_tenant_name = fake.company() + new_workspace_name = fake.company() + # Setup mocks + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.license.workspaces.is_available.return_value = True + + # Create account and existing tenant + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + existing_tenant = TenantService.create_tenant(name=existing_tenant_name) + TenantService.create_tenant_member(existing_tenant, account, role="owner") + account.current_tenant = existing_tenant + from extensions.ext_database import db + + db.session.commit() + + # Try to create owner tenant again (should not create new one) + TenantService.create_owner_tenant_if_not_exist(account, name=new_workspace_name) + + # Verify no new tenant was created + tenant_joins = db.session.query(TenantAccountJoin).filter_by(account_id=account.id).all() + assert len(tenant_joins) == 1 + assert account.current_tenant.id == existing_tenant.id + + def test_create_owner_tenant_if_not_exist_workspace_disabled( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test creating owner tenant when workspace creation is disabled. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + workspace_name = fake.company() + # Setup mocks to disable workspace creation + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = False + + # Create account + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + # Try to create owner tenant (should fail) + with pytest.raises(WorkSpaceNotAllowedCreateError): # WorkSpaceNotAllowedCreateError exception + TenantService.create_owner_tenant_if_not_exist(account, name=workspace_name) + + def test_get_tenant_members_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test getting tenant members successfully. + """ + fake = Faker() + tenant_name = fake.company() + owner_email = fake.email() + owner_name = fake.name() + owner_password = fake.password(length=12) + admin_email = fake.email() + admin_name = fake.name() + admin_password = fake.password(length=12) + normal_email = fake.email() + normal_name = fake.name() + normal_password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + + # Create tenant and accounts + tenant = TenantService.create_tenant(name=tenant_name) + owner_account = AccountService.create_account( + email=owner_email, + name=owner_name, + interface_language="en-US", + password=owner_password, + ) + admin_account = AccountService.create_account( + email=admin_email, + name=admin_name, + interface_language="en-US", + password=admin_password, + ) + normal_account = AccountService.create_account( + email=normal_email, + name=normal_name, + interface_language="en-US", + password=normal_password, + ) + + # Add members with different roles + TenantService.create_tenant_member(tenant, owner_account, role="owner") + TenantService.create_tenant_member(tenant, admin_account, role="admin") + TenantService.create_tenant_member(tenant, normal_account, role="normal") + + # Get tenant members + members = TenantService.get_tenant_members(tenant) + + assert len(members) == 3 + member_emails = [member.email for member in members] + assert owner_email in member_emails + assert admin_email in member_emails + assert normal_email in member_emails + + # Verify roles are set correctly + for member in members: + if member.email == owner_email: + assert member.role == "owner" + elif member.email == admin_email: + assert member.role == "admin" + elif member.email == normal_email: + assert member.role == "normal" + + def test_get_dataset_operator_members_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test getting dataset operator members successfully. + """ + fake = Faker() + tenant_name = fake.company() + owner_email = fake.email() + owner_name = fake.name() + owner_password = fake.password(length=12) + operator_email = fake.email() + operator_name = fake.name() + operator_password = fake.password(length=12) + normal_email = fake.email() + normal_name = fake.name() + normal_password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + + # Create tenant and accounts + tenant = TenantService.create_tenant(name=tenant_name) + owner_account = AccountService.create_account( + email=owner_email, + name=owner_name, + interface_language="en-US", + password=owner_password, + ) + dataset_operator_account = AccountService.create_account( + email=operator_email, + name=operator_name, + interface_language="en-US", + password=operator_password, + ) + normal_account = AccountService.create_account( + email=normal_email, + name=normal_name, + interface_language="en-US", + password=normal_password, + ) + + # Add members with different roles + TenantService.create_tenant_member(tenant, owner_account, role="owner") + TenantService.create_tenant_member(tenant, dataset_operator_account, role="dataset_operator") + TenantService.create_tenant_member(tenant, normal_account, role="normal") + + # Get dataset operator members + dataset_operators = TenantService.get_dataset_operator_members(tenant) + + assert len(dataset_operators) == 1 + assert dataset_operators[0].email == operator_email + assert dataset_operators[0].role == "dataset_operator" + + def test_get_custom_config_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test getting custom config successfully. + """ + fake = Faker() + tenant_name = fake.company() + theme = fake.random_element(elements=("dark", "light")) + language = fake.random_element(elements=("zh-CN", "en-US")) + # Setup mocks + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + + # Create tenant with custom config + tenant = TenantService.create_tenant(name=tenant_name) + + # Set custom config + custom_config = {"theme": theme, "language": language, "feature_flags": {"beta": True}} + tenant.custom_config_dict = custom_config + from extensions.ext_database import db + + db.session.commit() + + # Get custom config + retrieved_config = TenantService.get_custom_config(tenant.id) + + assert retrieved_config == custom_config + assert retrieved_config["theme"] == theme + assert retrieved_config["language"] == language + assert retrieved_config["feature_flags"]["beta"] is True + + +class TestRegisterService: + """Integration tests for RegisterService using testcontainers.""" + + @pytest.fixture + def mock_external_service_dependencies(self): + """Mock setup for external service dependencies.""" + with ( + patch("services.account_service.FeatureService") as mock_feature_service, + patch("services.account_service.BillingService") as mock_billing_service, + patch("services.account_service.PassportService") as mock_passport_service, + ): + # Setup default mock returns + mock_feature_service.get_system_features.return_value.is_allow_register = True + mock_feature_service.get_system_features.return_value.is_allow_create_workspace = True + mock_feature_service.get_system_features.return_value.license.workspaces.is_available.return_value = True + mock_billing_service.is_email_in_freeze.return_value = False + mock_passport_service.return_value.issue.return_value = "mock_jwt_token" + + yield { + "feature_service": mock_feature_service, + "billing_service": mock_billing_service, + "passport_service": mock_passport_service, + } + + def test_setup_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful system setup with account creation and tenant setup. + """ + fake = Faker() + admin_email = fake.email() + admin_name = fake.name() + admin_password = fake.password(length=12) + ip_address = fake.ipv4() + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Execute setup + RegisterService.setup( + email=admin_email, + name=admin_name, + password=admin_password, + ip_address=ip_address, + ) + + # Verify account was created + from extensions.ext_database import db + from models.account import Account + from models.model import DifySetup + + account = db.session.query(Account).filter_by(email=admin_email).first() + assert account is not None + assert account.name == admin_name + assert account.last_login_ip == ip_address + assert account.initialized_at is not None + assert account.status == "active" + + # Verify DifySetup was created + dify_setup = db.session.query(DifySetup).first() + assert dify_setup is not None + + # Verify tenant was created and linked + from models.account import TenantAccountJoin + + tenant_join = db.session.query(TenantAccountJoin).filter_by(account_id=account.id).first() + assert tenant_join is not None + assert tenant_join.role == "owner" + + def test_setup_failure_rollback(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test setup failure with proper rollback of all created entities. + """ + fake = Faker() + admin_email = fake.email() + admin_name = fake.name() + admin_password = fake.password(length=12) + ip_address = fake.ipv4() + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Mock AccountService.create_account to raise exception + with patch("services.account_service.AccountService.create_account") as mock_create_account: + mock_create_account.side_effect = Exception("Database error") + + # Execute setup and verify exception + with pytest.raises(ValueError, match="Setup failed: Database error"): + RegisterService.setup( + email=admin_email, + name=admin_name, + password=admin_password, + ip_address=ip_address, + ) + + # Verify no entities were created (rollback worked) + from extensions.ext_database import db + from models.account import Account, Tenant, TenantAccountJoin + from models.model import DifySetup + + account = db.session.query(Account).filter_by(email=admin_email).first() + tenant_count = db.session.query(Tenant).count() + tenant_join_count = db.session.query(TenantAccountJoin).count() + dify_setup_count = db.session.query(DifySetup).count() + + assert account is None + assert tenant_count == 0 + assert tenant_join_count == 0 + assert dify_setup_count == 0 + + def test_register_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful account registration with workspace creation. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + language = fake.random_element(elements=("en-US", "zh-CN")) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.license.workspaces.is_available.return_value = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Execute registration + account = RegisterService.register( + email=email, + name=name, + password=password, + language=language, + ) + + # Verify account was created + assert account.email == email + assert account.name == name + assert account.status == "active" + assert account.initialized_at is not None + + # Verify tenant was created and linked + from extensions.ext_database import db + from models.account import TenantAccountJoin + + tenant_join = db.session.query(TenantAccountJoin).filter_by(account_id=account.id).first() + assert tenant_join is not None + assert tenant_join.role == "owner" + assert account.current_tenant is not None + assert account.current_tenant.name == f"{name}'s Workspace" + + def test_register_with_oauth(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test account registration with OAuth integration. + """ + fake = Faker() + email = fake.email() + name = fake.name() + open_id = fake.uuid4() + provider = fake.random_element(elements=("google", "github", "microsoft")) + language = fake.random_element(elements=("en-US", "zh-CN")) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.license.workspaces.is_available.return_value = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Execute registration with OAuth + account = RegisterService.register( + email=email, + name=name, + password=None, + open_id=open_id, + provider=provider, + language=language, + ) + + # Verify account was created + assert account.email == email + assert account.name == name + assert account.status == "active" + assert account.initialized_at is not None + + # Verify OAuth integration was created + from extensions.ext_database import db + from models.account import AccountIntegrate + + integration = db.session.query(AccountIntegrate).filter_by(account_id=account.id, provider=provider).first() + assert integration is not None + assert integration.open_id == open_id + + def test_register_with_pending_status(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test account registration with pending status. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + language = fake.random_element(elements=("en-US", "zh-CN")) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.license.workspaces.is_available.return_value = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Execute registration with pending status + from models.account import AccountStatus + + account = RegisterService.register( + email=email, + name=name, + password=password, + language=language, + status=AccountStatus.PENDING, + ) + + # Verify account was created with pending status + assert account.email == email + assert account.name == name + assert account.status == "pending" + assert account.initialized_at is not None + + # Verify tenant was created and linked + from extensions.ext_database import db + from models.account import TenantAccountJoin + + tenant_join = db.session.query(TenantAccountJoin).filter_by(account_id=account.id).first() + assert tenant_join is not None + assert tenant_join.role == "owner" + + def test_register_workspace_creation_disabled(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test account registration when workspace creation is disabled. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + language = fake.random_element(elements=("en-US", "zh-CN")) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = False + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # with pytest.raises(AccountRegisterError, match="Workspace is not allowed to create."): + account = RegisterService.register( + email=email, + name=name, + password=password, + language=language, + ) + + # Verify account was created with no tenant + assert account.email == email + assert account.name == name + assert account.status == "active" + assert account.initialized_at is not None + + # Verify tenant was created and linked + from extensions.ext_database import db + from models.account import TenantAccountJoin + + tenant_join = db.session.query(TenantAccountJoin).filter_by(account_id=account.id).first() + assert tenant_join is None + + def test_register_workspace_limit_exceeded(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test account registration when workspace limit is exceeded. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + language = fake.random_element(elements=("en-US", "zh-CN")) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.license.workspaces.is_available.return_value = False + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # with pytest.raises(AccountRegisterError, match="Workspace is not allowed to create."): + account = RegisterService.register( + email=email, + name=name, + password=password, + language=language, + ) + + # Verify account was created with no tenant + assert account.email == email + assert account.name == name + assert account.status == "active" + assert account.initialized_at is not None + + # Verify tenant was created and linked + from extensions.ext_database import db + from models.account import TenantAccountJoin + + tenant_join = db.session.query(TenantAccountJoin).filter_by(account_id=account.id).first() + assert tenant_join is None + + def test_register_without_workspace(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test account registration without workspace creation. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + language = fake.random_element(elements=("en-US", "zh-CN")) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Execute registration without workspace creation + account = RegisterService.register( + email=email, + name=name, + password=password, + language=language, + create_workspace_required=False, + ) + + # Verify account was created + assert account.email == email + assert account.name == name + assert account.status == "active" + assert account.initialized_at is not None + + # Verify no tenant was created + from extensions.ext_database import db + from models.account import TenantAccountJoin + + tenant_join = db.session.query(TenantAccountJoin).filter_by(account_id=account.id).first() + assert tenant_join is None + + def test_invite_new_member_new_account(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test inviting a new member who doesn't have an account yet. + """ + fake = Faker() + tenant_name = fake.company() + inviter_email = fake.email() + inviter_name = fake.name() + inviter_password = fake.password(length=12) + new_member_email = fake.email() + language = fake.random_element(elements=("en-US", "zh-CN")) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.license.workspaces.is_available.return_value = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create tenant and inviter account + tenant = TenantService.create_tenant(name=tenant_name) + inviter = AccountService.create_account( + email=inviter_email, + name=inviter_name, + interface_language="en-US", + password=inviter_password, + ) + TenantService.create_tenant_member(tenant, inviter, role="owner") + + # Mock the email task + with patch("services.account_service.send_invite_member_mail_task") as mock_send_mail: + mock_send_mail.delay.return_value = None + + # Execute invitation + token = RegisterService.invite_new_member( + tenant=tenant, + email=new_member_email, + language=language, + role="normal", + inviter=inviter, + ) + + # Verify token was generated + assert token is not None + assert len(token) > 0 + + # Verify email task was called + mock_send_mail.delay.assert_called_once() + + # Verify new account was created with pending status + from extensions.ext_database import db + from models.account import Account, TenantAccountJoin + + new_account = db.session.query(Account).filter_by(email=new_member_email).first() + assert new_account is not None + assert new_account.name == new_member_email.split("@")[0] # Default name from email + assert new_account.status == "pending" + + # Verify tenant member was created + tenant_join = ( + db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=new_account.id).first() + ) + assert tenant_join is not None + assert tenant_join.role == "normal" + + def test_invite_new_member_existing_account(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test inviting an existing member who is not in the tenant yet. + """ + fake = Faker() + tenant_name = fake.company() + inviter_email = fake.email() + inviter_name = fake.name() + inviter_password = fake.password(length=12) + existing_member_email = fake.email() + existing_member_name = fake.name() + existing_member_password = fake.password(length=12) + language = fake.random_element(elements=("en-US", "zh-CN")) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create tenant and inviter account + tenant = TenantService.create_tenant(name=tenant_name) + inviter = AccountService.create_account( + email=inviter_email, + name=inviter_name, + interface_language="en-US", + password=inviter_password, + ) + TenantService.create_tenant_member(tenant, inviter, role="owner") + + # Create existing account + existing_account = AccountService.create_account( + email=existing_member_email, + name=existing_member_name, + interface_language="en-US", + password=existing_member_password, + ) + + # Mock the email task + with patch("services.account_service.send_invite_member_mail_task") as mock_send_mail: + mock_send_mail.delay.return_value = None + with pytest.raises(AccountAlreadyInTenantError, match="Account already in tenant."): + # Execute invitation + token = RegisterService.invite_new_member( + tenant=tenant, + email=existing_member_email, + language=language, + role="admin", + inviter=inviter, + ) + + # Verify email task was not called + mock_send_mail.delay.assert_not_called() + + # Verify tenant member was created for existing account + from extensions.ext_database import db + from models.account import TenantAccountJoin + + tenant_join = ( + db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=existing_account.id).first() + ) + assert tenant_join is not None + assert tenant_join.role == "admin" + + def test_invite_new_member_existing_member(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test inviting a member who is already in the tenant with pending status. + """ + fake = Faker() + tenant_name = fake.company() + inviter_email = fake.email() + inviter_name = fake.name() + inviter_password = fake.password(length=12) + existing_pending_member_email = fake.email() + existing_pending_member_name = fake.name() + existing_pending_member_password = fake.password(length=12) + language = fake.random_element(elements=("en-US", "zh-CN")) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create tenant and inviter account + tenant = TenantService.create_tenant(name=tenant_name) + inviter = AccountService.create_account( + email=inviter_email, + name=inviter_name, + interface_language="en-US", + password=inviter_password, + ) + TenantService.create_tenant_member(tenant, inviter, role="owner") + + # Create existing account with pending status + existing_account = AccountService.create_account( + email=existing_pending_member_email, + name=existing_pending_member_name, + interface_language="en-US", + password=existing_pending_member_password, + ) + existing_account.status = "pending" + from extensions.ext_database import db + + db.session.commit() + + # Add existing account to tenant + TenantService.create_tenant_member(tenant, existing_account, role="normal") + + # Mock the email task + with patch("services.account_service.send_invite_member_mail_task") as mock_send_mail: + mock_send_mail.delay.return_value = None + + # Execute invitation (should resend email for pending member) + token = RegisterService.invite_new_member( + tenant=tenant, + email=existing_pending_member_email, + language=language, + role="normal", + inviter=inviter, + ) + + # Verify token was generated + assert token is not None + assert len(token) > 0 + + # Verify email task was called + mock_send_mail.delay.assert_called_once() + + def test_invite_new_member_no_inviter(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test inviting a member without providing an inviter. + """ + fake = Faker() + tenant_name = fake.company() + new_member_email = fake.email() + language = fake.random_element(elements=("en-US", "zh-CN")) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create tenant + tenant = TenantService.create_tenant(name=tenant_name) + + # Execute invitation without inviter (should fail) + with pytest.raises(ValueError, match="Inviter is required"): + RegisterService.invite_new_member( + tenant=tenant, + email=new_member_email, + language=language, + role="normal", + inviter=None, + ) + + def test_invite_new_member_account_already_in_tenant( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test inviting a member who is already in the tenant with active status. + """ + fake = Faker() + tenant_name = fake.company() + inviter_email = fake.email() + inviter_name = fake.name() + inviter_password = fake.password(length=12) + already_in_tenant_email = fake.email() + already_in_tenant_name = fake.name() + already_in_tenant_password = fake.password(length=12) + language = fake.random_element(elements=("en-US", "zh-CN")) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create tenant and inviter account + tenant = TenantService.create_tenant(name=tenant_name) + inviter = AccountService.create_account( + email=inviter_email, + name=inviter_name, + interface_language="en-US", + password=inviter_password, + ) + TenantService.create_tenant_member(tenant, inviter, role="owner") + + # Create existing account with active status + existing_account = AccountService.create_account( + email=already_in_tenant_email, + name=already_in_tenant_name, + interface_language="en-US", + password=already_in_tenant_password, + ) + existing_account.status = "active" + from extensions.ext_database import db + + db.session.commit() + + # Add existing account to tenant + TenantService.create_tenant_member(tenant, existing_account, role="normal") + + # Execute invitation (should fail for active member) + with pytest.raises(AccountAlreadyInTenantError, match="Account already in tenant."): + RegisterService.invite_new_member( + tenant=tenant, + email=already_in_tenant_email, + language=language, + role="normal", + inviter=inviter, + ) + + def test_generate_invite_token_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful generation of invite token. + """ + fake = Faker() + tenant_name = fake.company() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create tenant and account + tenant = TenantService.create_tenant(name=tenant_name) + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + # Execute token generation + token = RegisterService.generate_invite_token(tenant, account) + + # Verify token was generated + assert token is not None + assert len(token) > 0 + + # Verify token was stored in Redis + from extensions.ext_redis import redis_client + + token_key = RegisterService._get_invitation_token_key(token) + stored_data = redis_client.get(token_key) + assert stored_data is not None + + # Verify stored data contains correct information + import json + + invitation_data = json.loads(stored_data.decode("utf-8")) + assert invitation_data["account_id"] == str(account.id) + assert invitation_data["email"] == account.email + assert invitation_data["workspace_id"] == tenant.id + + def test_is_valid_invite_token_valid(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test validation of valid invite token. + """ + fake = Faker() + tenant_name = fake.company() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create tenant and account + tenant = TenantService.create_tenant(name=tenant_name) + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + # Generate a real token + token = RegisterService.generate_invite_token(tenant, account) + + # Execute validation + is_valid = RegisterService.is_valid_invite_token(token) + + # Verify token is valid + assert is_valid is True + + def test_is_valid_invite_token_invalid(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test validation of invalid invite token. + """ + fake = Faker() + invalid_token = fake.uuid4() + # Execute validation with non-existent token + is_valid = RegisterService.is_valid_invite_token(invalid_token) + + # Verify token is invalid + assert is_valid is False + + def test_revoke_token_with_workspace_and_email( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test revoking token with workspace ID and email. + """ + fake = Faker() + tenant_name = fake.company() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create tenant and account + tenant = TenantService.create_tenant(name=tenant_name) + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + # Generate a real token + token = RegisterService.generate_invite_token(tenant, account) + + # Verify token exists in Redis before revocation + from extensions.ext_redis import redis_client + + token_key = RegisterService._get_invitation_token_key(token) + assert redis_client.get(token_key) is not None + + # Execute token revocation + RegisterService.revoke_token( + workspace_id=tenant.id, + email=account.email, + token=token, + ) + + # Verify token was not deleted from Redis + assert redis_client.get(token_key) is not None + + def test_revoke_token_without_workspace_and_email( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test revoking token without workspace ID and email. + """ + fake = Faker() + tenant_name = fake.company() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create tenant and account + tenant = TenantService.create_tenant(name=tenant_name) + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + # Generate a real token + token = RegisterService.generate_invite_token(tenant, account) + + # Verify token exists in Redis before revocation + from extensions.ext_redis import redis_client + + token_key = RegisterService._get_invitation_token_key(token) + assert redis_client.get(token_key) is not None + + # Execute token revocation without workspace and email + RegisterService.revoke_token( + workspace_id="", + email="", + token=token, + ) + + # Verify token was deleted from Redis + assert redis_client.get(token_key) is None + + def test_get_invitation_if_token_valid_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test getting invitation data with valid token. + """ + fake = Faker() + tenant_name = fake.company() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create tenant and account + tenant = TenantService.create_tenant(name=tenant_name) + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + TenantService.create_tenant_member(tenant, account, role="normal") + + # Generate a real token + token = RegisterService.generate_invite_token(tenant, account) + + email_hash = sha256(account.email.encode()).hexdigest() + cache_key = f"member_invite_token:{tenant.id}, {email_hash}:{token}" + from extensions.ext_redis import redis_client + + redis_client.setex(cache_key, 24 * 60 * 60, account.id) + + # Execute invitation retrieval + result = RegisterService.get_invitation_if_token_valid( + workspace_id=tenant.id, + email=account.email, + token=token, + ) + + # Verify result contains expected data + assert result is not None + assert result["account"].id == account.id + assert result["tenant"].id == tenant.id + assert result["data"]["account_id"] == str(account.id) + assert result["data"]["email"] == account.email + assert result["data"]["workspace_id"] == tenant.id + + def test_get_invitation_if_token_valid_invalid_token( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test getting invitation data with invalid token. + """ + fake = Faker() + workspace_id = fake.uuid4() + email = fake.email() + invalid_token = fake.uuid4() + # Execute invitation retrieval with invalid token + result = RegisterService.get_invitation_if_token_valid( + workspace_id=workspace_id, + email=email, + token=invalid_token, + ) + + # Verify result is None + assert result is None + + def test_get_invitation_if_token_valid_invalid_tenant( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test getting invitation data with invalid tenant. + """ + fake = Faker() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + invalid_tenant_id = fake.uuid4() + token = fake.uuid4() + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create account + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + + # Create a real token but with non-existent tenant ID + from extensions.ext_redis import redis_client + + invitation_data = { + "account_id": str(account.id), + "email": account.email, + "workspace_id": invalid_tenant_id, + } + token_key = RegisterService._get_invitation_token_key(token) + import json + + redis_client.setex(token_key, 24 * 60 * 60, json.dumps(invitation_data)) + + # Execute invitation retrieval + result = RegisterService.get_invitation_if_token_valid( + workspace_id=invalid_tenant_id, + email=account.email, + token=token, + ) + + # Verify result is None (tenant not found) + assert result is None + + # Clean up + redis_client.delete(token_key) + + def test_get_invitation_if_token_valid_account_mismatch( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test getting invitation data with account ID mismatch. + """ + fake = Faker() + tenant_name = fake.company() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + token = fake.uuid4() + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create tenant and account + tenant = TenantService.create_tenant(name=tenant_name) + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + TenantService.create_tenant_member(tenant, account, role="normal") + + # Create a real token but with mismatched account ID + from extensions.ext_redis import redis_client + + invitation_data = { + "account_id": "different-account-id", # Different from actual account ID + "email": account.email, + "workspace_id": tenant.id, + } + token_key = RegisterService._get_invitation_token_key(token) + redis_client.setex(token_key, 24 * 60 * 60, json.dumps(invitation_data)) + + # Execute invitation retrieval + result = RegisterService.get_invitation_if_token_valid( + workspace_id=tenant.id, + email=account.email, + token=token, + ) + + # Verify result is None (account ID mismatch) + assert result is None + + # Clean up + redis_client.delete(token_key) + + def test_get_invitation_if_token_valid_tenant_not_normal( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test getting invitation data with tenant not in normal status. + """ + fake = Faker() + tenant_name = fake.company() + email = fake.email() + name = fake.name() + password = fake.password(length=12) + token = fake.uuid4() + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Create tenant and account + tenant = TenantService.create_tenant(name=tenant_name) + account = AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password=password, + ) + TenantService.create_tenant_member(tenant, account, role="normal") + + # Change tenant status to non-normal + tenant.status = "suspended" + from extensions.ext_database import db + + db.session.commit() + + # Create a real token + from extensions.ext_redis import redis_client + + invitation_data = { + "account_id": str(account.id), + "email": account.email, + "workspace_id": tenant.id, + } + token_key = RegisterService._get_invitation_token_key(token) + import json + + redis_client.setex(token_key, 24 * 60 * 60, json.dumps(invitation_data)) + + # Execute invitation retrieval + result = RegisterService.get_invitation_if_token_valid( + workspace_id=tenant.id, + email=account.email, + token=token, + ) + + # Verify result is None (tenant not in normal status) + assert result is None + + # Clean up + redis_client.delete(token_key) + + def test_get_invitation_by_token_with_workspace_and_email( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test getting invitation by token with workspace ID and email. + """ + fake = Faker() + token = fake.uuid4() + workspace_id = fake.uuid4() + email = fake.email() + + # Create the cache key as the service does + from hashlib import sha256 + + from extensions.ext_redis import redis_client + + email_hash = sha256(email.encode()).hexdigest() + cache_key = f"member_invite_token:{workspace_id}, {email_hash}:{token}" + + # Store account ID in Redis + account_id = fake.uuid4() + redis_client.setex(cache_key, 24 * 60 * 60, account_id) + + # Execute invitation retrieval + result = RegisterService._get_invitation_by_token( + token=token, + workspace_id=workspace_id, + email=email, + ) + + # Verify result contains expected data + assert result is not None + assert result["account_id"] == account_id + assert result["email"] == email + assert result["workspace_id"] == workspace_id + + # Clean up + redis_client.delete(cache_key) + + def test_get_invitation_by_token_without_workspace_and_email( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test getting invitation by token without workspace ID and email. + """ + fake = Faker() + token = fake.uuid4() + invitation_data = { + "account_id": fake.uuid4(), + "email": fake.email(), + "workspace_id": fake.uuid4(), + } + + # Store invitation data in Redis using standard token key + from extensions.ext_redis import redis_client + + token_key = RegisterService._get_invitation_token_key(token) + import json + + redis_client.setex(token_key, 24 * 60 * 60, json.dumps(invitation_data)) + + # Execute invitation retrieval + result = RegisterService._get_invitation_by_token(token=token) + + # Verify result contains expected data + assert result is not None + assert result["account_id"] == invitation_data["account_id"] + assert result["email"] == invitation_data["email"] + assert result["workspace_id"] == invitation_data["workspace_id"] + + # Clean up + redis_client.delete(token_key) + + def test_get_invitation_token_key(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test getting invitation token key. + """ + fake = Faker() + token = fake.uuid4() + # Execute token key generation + token_key = RegisterService._get_invitation_token_key(token) + + # Verify token key format + assert token_key == f"member_invite:token:{token}" diff --git a/api/tests/test_containers_integration_tests/services/test_advanced_prompt_template_service.py b/api/tests/test_containers_integration_tests/services/test_advanced_prompt_template_service.py new file mode 100644 index 0000000000..9ed9008af9 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_advanced_prompt_template_service.py @@ -0,0 +1,885 @@ +import copy + +import pytest +from faker import Faker + +from core.prompt.prompt_templates.advanced_prompt_templates import ( + BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG, + BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG, + BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG, + BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG, + BAICHUAN_CONTEXT, + CHAT_APP_CHAT_PROMPT_CONFIG, + CHAT_APP_COMPLETION_PROMPT_CONFIG, + COMPLETION_APP_CHAT_PROMPT_CONFIG, + COMPLETION_APP_COMPLETION_PROMPT_CONFIG, + CONTEXT, +) +from models.model import AppMode +from services.advanced_prompt_template_service import AdvancedPromptTemplateService + + +class TestAdvancedPromptTemplateService: + """Integration tests for AdvancedPromptTemplateService using testcontainers.""" + + @pytest.fixture + def mock_external_service_dependencies(self): + """Mock setup for external service dependencies.""" + # This service doesn't have external dependencies, but we keep the pattern + # for consistency with other test files + return {} + + def test_get_prompt_baichuan_model_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful prompt generation for Baichuan model. + + This test verifies: + - Proper prompt generation for Baichuan models + - Correct model detection logic + - Appropriate prompt template selection + """ + fake = Faker() + + # Test data for Baichuan model + args = { + "app_mode": AppMode.CHAT.value, + "model_mode": "completion", + "model_name": "baichuan-13b-chat", + "has_context": "true", + } + + # Act: Execute the method under test + result = AdvancedPromptTemplateService.get_prompt(args) + + # Assert: Verify the expected outcomes + assert result is not None + assert "completion_prompt_config" in result + assert "prompt" in result["completion_prompt_config"] + assert "text" in result["completion_prompt_config"]["prompt"] + + # Verify context is included for Baichuan model + prompt_text = result["completion_prompt_config"]["prompt"]["text"] + assert BAICHUAN_CONTEXT in prompt_text + assert "{{#pre_prompt#}}" in prompt_text + assert "{{#histories#}}" in prompt_text + assert "{{#query#}}" in prompt_text + + def test_get_prompt_common_model_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful prompt generation for common models. + + This test verifies: + - Proper prompt generation for non-Baichuan models + - Correct model detection logic + - Appropriate prompt template selection + """ + fake = Faker() + + # Test data for common model + args = { + "app_mode": AppMode.CHAT.value, + "model_mode": "completion", + "model_name": "gpt-3.5-turbo", + "has_context": "true", + } + + # Act: Execute the method under test + result = AdvancedPromptTemplateService.get_prompt(args) + + # Assert: Verify the expected outcomes + assert result is not None + assert "completion_prompt_config" in result + assert "prompt" in result["completion_prompt_config"] + assert "text" in result["completion_prompt_config"]["prompt"] + + # Verify context is included for common model + prompt_text = result["completion_prompt_config"]["prompt"]["text"] + assert CONTEXT in prompt_text + assert "{{#pre_prompt#}}" in prompt_text + assert "{{#histories#}}" in prompt_text + assert "{{#query#}}" in prompt_text + + def test_get_prompt_case_insensitive_baichuan_detection( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test Baichuan model detection is case insensitive. + + This test verifies: + - Model name detection works regardless of case + - Proper prompt template selection for different case variations + """ + fake = Faker() + + # Test different case variations + test_cases = ["Baichuan-13B-Chat", "BAICHUAN-13B-CHAT", "baichuan-13b-chat", "BaiChuan-13B-Chat"] + + for model_name in test_cases: + args = { + "app_mode": AppMode.CHAT.value, + "model_mode": "completion", + "model_name": model_name, + "has_context": "true", + } + + # Act: Execute the method under test + result = AdvancedPromptTemplateService.get_prompt(args) + + # Assert: Verify Baichuan template is used + assert result is not None + prompt_text = result["completion_prompt_config"]["prompt"]["text"] + assert BAICHUAN_CONTEXT in prompt_text + + def test_get_common_prompt_chat_app_completion_mode( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test common prompt generation for chat app with completion mode. + + This test verifies: + - Correct prompt template selection for chat app + completion mode + - Proper context integration + - Template structure validation + """ + fake = Faker() + + # Act: Execute the method under test + result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT.value, "completion", "true") + + # Assert: Verify the expected outcomes + assert result is not None + assert "completion_prompt_config" in result + assert "prompt" in result["completion_prompt_config"] + assert "text" in result["completion_prompt_config"]["prompt"] + assert "conversation_histories_role" in result["completion_prompt_config"] + assert "stop" in result + + # Verify context is included + prompt_text = result["completion_prompt_config"]["prompt"]["text"] + assert CONTEXT in prompt_text + assert "{{#pre_prompt#}}" in prompt_text + assert "{{#histories#}}" in prompt_text + assert "{{#query#}}" in prompt_text + + def test_get_common_prompt_chat_app_chat_mode(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test common prompt generation for chat app with chat mode. + + This test verifies: + - Correct prompt template selection for chat app + chat mode + - Proper context integration + - Template structure validation + """ + fake = Faker() + + # Act: Execute the method under test + result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT.value, "chat", "true") + + # Assert: Verify the expected outcomes + assert result is not None + assert "chat_prompt_config" in result + assert "prompt" in result["chat_prompt_config"] + assert len(result["chat_prompt_config"]["prompt"]) > 0 + assert "role" in result["chat_prompt_config"]["prompt"][0] + assert "text" in result["chat_prompt_config"]["prompt"][0] + + # Verify context is included + prompt_text = result["chat_prompt_config"]["prompt"][0]["text"] + assert CONTEXT in prompt_text + assert "{{#pre_prompt#}}" in prompt_text + + def test_get_common_prompt_completion_app_completion_mode( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test common prompt generation for completion app with completion mode. + + This test verifies: + - Correct prompt template selection for completion app + completion mode + - Proper context integration + - Template structure validation + """ + fake = Faker() + + # Act: Execute the method under test + result = AdvancedPromptTemplateService.get_common_prompt(AppMode.COMPLETION.value, "completion", "true") + + # Assert: Verify the expected outcomes + assert result is not None + assert "completion_prompt_config" in result + assert "prompt" in result["completion_prompt_config"] + assert "text" in result["completion_prompt_config"]["prompt"] + assert "stop" in result + + # Verify context is included + prompt_text = result["completion_prompt_config"]["prompt"]["text"] + assert CONTEXT in prompt_text + assert "{{#pre_prompt#}}" in prompt_text + + def test_get_common_prompt_completion_app_chat_mode( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test common prompt generation for completion app with chat mode. + + This test verifies: + - Correct prompt template selection for completion app + chat mode + - Proper context integration + - Template structure validation + """ + fake = Faker() + + # Act: Execute the method under test + result = AdvancedPromptTemplateService.get_common_prompt(AppMode.COMPLETION.value, "chat", "true") + + # Assert: Verify the expected outcomes + assert result is not None + assert "chat_prompt_config" in result + assert "prompt" in result["chat_prompt_config"] + assert len(result["chat_prompt_config"]["prompt"]) > 0 + assert "role" in result["chat_prompt_config"]["prompt"][0] + assert "text" in result["chat_prompt_config"]["prompt"][0] + + # Verify context is included + prompt_text = result["chat_prompt_config"]["prompt"][0]["text"] + assert CONTEXT in prompt_text + assert "{{#pre_prompt#}}" in prompt_text + + def test_get_common_prompt_no_context(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test common prompt generation without context. + + This test verifies: + - Correct handling when has_context is "false" + - Context is not included in prompt + - Template structure remains intact + """ + fake = Faker() + + # Act: Execute the method under test + result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT.value, "completion", "false") + + # Assert: Verify the expected outcomes + assert result is not None + assert "completion_prompt_config" in result + assert "prompt" in result["completion_prompt_config"] + assert "text" in result["completion_prompt_config"]["prompt"] + + # Verify context is NOT included + prompt_text = result["completion_prompt_config"]["prompt"]["text"] + assert CONTEXT not in prompt_text + assert "{{#pre_prompt#}}" in prompt_text + assert "{{#histories#}}" in prompt_text + assert "{{#query#}}" in prompt_text + + def test_get_common_prompt_unsupported_app_mode( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test common prompt generation with unsupported app mode. + + This test verifies: + - Proper handling of unsupported app modes + - Default empty dict return + """ + fake = Faker() + + # Act: Execute the method under test + result = AdvancedPromptTemplateService.get_common_prompt("unsupported_mode", "completion", "true") + + # Assert: Verify empty dict is returned + assert result == {} + + def test_get_common_prompt_unsupported_model_mode( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test common prompt generation with unsupported model mode. + + This test verifies: + - Proper handling of unsupported model modes + - Default empty dict return + """ + fake = Faker() + + # Act: Execute the method under test + result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT.value, "unsupported_mode", "true") + + # Assert: Verify empty dict is returned + assert result == {} + + def test_get_completion_prompt_with_context(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test completion prompt generation with context. + + This test verifies: + - Proper context integration in completion prompts + - Template structure preservation + - Context placement at the beginning + """ + fake = Faker() + + # Create test prompt template + prompt_template = copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG) + original_text = prompt_template["completion_prompt_config"]["prompt"]["text"] + + # Act: Execute the method under test + result = AdvancedPromptTemplateService.get_completion_prompt(prompt_template, "true", CONTEXT) + + # Assert: Verify the expected outcomes + assert result is not None + assert "completion_prompt_config" in result + assert "prompt" in result["completion_prompt_config"] + assert "text" in result["completion_prompt_config"]["prompt"] + + # Verify context is prepended to original text + result_text = result["completion_prompt_config"]["prompt"]["text"] + assert result_text.startswith(CONTEXT) + assert original_text in result_text + assert result_text == CONTEXT + original_text + + def test_get_completion_prompt_without_context( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test completion prompt generation without context. + + This test verifies: + - Original template is preserved when no context + - No modification to prompt text + """ + fake = Faker() + + # Create test prompt template + prompt_template = copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG) + original_text = prompt_template["completion_prompt_config"]["prompt"]["text"] + + # Act: Execute the method under test + result = AdvancedPromptTemplateService.get_completion_prompt(prompt_template, "false", CONTEXT) + + # Assert: Verify the expected outcomes + assert result is not None + assert "completion_prompt_config" in result + assert "prompt" in result["completion_prompt_config"] + assert "text" in result["completion_prompt_config"]["prompt"] + + # Verify original text is unchanged + result_text = result["completion_prompt_config"]["prompt"]["text"] + assert result_text == original_text + assert CONTEXT not in result_text + + def test_get_chat_prompt_with_context(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test chat prompt generation with context. + + This test verifies: + - Proper context integration in chat prompts + - Template structure preservation + - Context placement at the beginning of first message + """ + fake = Faker() + + # Create test prompt template + prompt_template = copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG) + original_text = prompt_template["chat_prompt_config"]["prompt"][0]["text"] + + # Act: Execute the method under test + result = AdvancedPromptTemplateService.get_chat_prompt(prompt_template, "true", CONTEXT) + + # Assert: Verify the expected outcomes + assert result is not None + assert "chat_prompt_config" in result + assert "prompt" in result["chat_prompt_config"] + assert len(result["chat_prompt_config"]["prompt"]) > 0 + assert "text" in result["chat_prompt_config"]["prompt"][0] + + # Verify context is prepended to original text + result_text = result["chat_prompt_config"]["prompt"][0]["text"] + assert result_text.startswith(CONTEXT) + assert original_text in result_text + assert result_text == CONTEXT + original_text + + def test_get_chat_prompt_without_context(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test chat prompt generation without context. + + This test verifies: + - Original template is preserved when no context + - No modification to prompt text + """ + fake = Faker() + + # Create test prompt template + prompt_template = copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG) + original_text = prompt_template["chat_prompt_config"]["prompt"][0]["text"] + + # Act: Execute the method under test + result = AdvancedPromptTemplateService.get_chat_prompt(prompt_template, "false", CONTEXT) + + # Assert: Verify the expected outcomes + assert result is not None + assert "chat_prompt_config" in result + assert "prompt" in result["chat_prompt_config"] + assert len(result["chat_prompt_config"]["prompt"]) > 0 + assert "text" in result["chat_prompt_config"]["prompt"][0] + + # Verify original text is unchanged + result_text = result["chat_prompt_config"]["prompt"][0]["text"] + assert result_text == original_text + assert CONTEXT not in result_text + + def test_get_baichuan_prompt_chat_app_completion_mode( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test Baichuan prompt generation for chat app with completion mode. + + This test verifies: + - Correct Baichuan prompt template selection for chat app + completion mode + - Proper Baichuan context integration + - Template structure validation + """ + fake = Faker() + + # Act: Execute the method under test + result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT.value, "completion", "true") + + # Assert: Verify the expected outcomes + assert result is not None + assert "completion_prompt_config" in result + assert "prompt" in result["completion_prompt_config"] + assert "text" in result["completion_prompt_config"]["prompt"] + assert "conversation_histories_role" in result["completion_prompt_config"] + assert "stop" in result + + # Verify Baichuan context is included + prompt_text = result["completion_prompt_config"]["prompt"]["text"] + assert BAICHUAN_CONTEXT in prompt_text + assert "{{#pre_prompt#}}" in prompt_text + assert "{{#histories#}}" in prompt_text + assert "{{#query#}}" in prompt_text + + def test_get_baichuan_prompt_chat_app_chat_mode( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test Baichuan prompt generation for chat app with chat mode. + + This test verifies: + - Correct Baichuan prompt template selection for chat app + chat mode + - Proper Baichuan context integration + - Template structure validation + """ + fake = Faker() + + # Act: Execute the method under test + result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT.value, "chat", "true") + + # Assert: Verify the expected outcomes + assert result is not None + assert "chat_prompt_config" in result + assert "prompt" in result["chat_prompt_config"] + assert len(result["chat_prompt_config"]["prompt"]) > 0 + assert "role" in result["chat_prompt_config"]["prompt"][0] + assert "text" in result["chat_prompt_config"]["prompt"][0] + + # Verify Baichuan context is included + prompt_text = result["chat_prompt_config"]["prompt"][0]["text"] + assert BAICHUAN_CONTEXT in prompt_text + assert "{{#pre_prompt#}}" in prompt_text + + def test_get_baichuan_prompt_completion_app_completion_mode( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test Baichuan prompt generation for completion app with completion mode. + + This test verifies: + - Correct Baichuan prompt template selection for completion app + completion mode + - Proper Baichuan context integration + - Template structure validation + """ + fake = Faker() + + # Act: Execute the method under test + result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.COMPLETION.value, "completion", "true") + + # Assert: Verify the expected outcomes + assert result is not None + assert "completion_prompt_config" in result + assert "prompt" in result["completion_prompt_config"] + assert "text" in result["completion_prompt_config"]["prompt"] + assert "stop" in result + + # Verify Baichuan context is included + prompt_text = result["completion_prompt_config"]["prompt"]["text"] + assert BAICHUAN_CONTEXT in prompt_text + assert "{{#pre_prompt#}}" in prompt_text + + def test_get_baichuan_prompt_completion_app_chat_mode( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test Baichuan prompt generation for completion app with chat mode. + + This test verifies: + - Correct Baichuan prompt template selection for completion app + chat mode + - Proper Baichuan context integration + - Template structure validation + """ + fake = Faker() + + # Act: Execute the method under test + result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.COMPLETION.value, "chat", "true") + + # Assert: Verify the expected outcomes + assert result is not None + assert "chat_prompt_config" in result + assert "prompt" in result["chat_prompt_config"] + assert len(result["chat_prompt_config"]["prompt"]) > 0 + assert "role" in result["chat_prompt_config"]["prompt"][0] + assert "text" in result["chat_prompt_config"]["prompt"][0] + + # Verify Baichuan context is included + prompt_text = result["chat_prompt_config"]["prompt"][0]["text"] + assert BAICHUAN_CONTEXT in prompt_text + assert "{{#pre_prompt#}}" in prompt_text + + def test_get_baichuan_prompt_no_context(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test Baichuan prompt generation without context. + + This test verifies: + - Correct handling when has_context is "false" + - Baichuan context is not included in prompt + - Template structure remains intact + """ + fake = Faker() + + # Act: Execute the method under test + result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT.value, "completion", "false") + + # Assert: Verify the expected outcomes + assert result is not None + assert "completion_prompt_config" in result + assert "prompt" in result["completion_prompt_config"] + assert "text" in result["completion_prompt_config"]["prompt"] + + # Verify Baichuan context is NOT included + prompt_text = result["completion_prompt_config"]["prompt"]["text"] + assert BAICHUAN_CONTEXT not in prompt_text + assert "{{#pre_prompt#}}" in prompt_text + assert "{{#histories#}}" in prompt_text + assert "{{#query#}}" in prompt_text + + def test_get_baichuan_prompt_unsupported_app_mode( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test Baichuan prompt generation with unsupported app mode. + + This test verifies: + - Proper handling of unsupported app modes + - Default empty dict return + """ + fake = Faker() + + # Act: Execute the method under test + result = AdvancedPromptTemplateService.get_baichuan_prompt("unsupported_mode", "completion", "true") + + # Assert: Verify empty dict is returned + assert result == {} + + def test_get_baichuan_prompt_unsupported_model_mode( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test Baichuan prompt generation with unsupported model mode. + + This test verifies: + - Proper handling of unsupported model modes + - Default empty dict return + """ + fake = Faker() + + # Act: Execute the method under test + result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT.value, "unsupported_mode", "true") + + # Assert: Verify empty dict is returned + assert result == {} + + def test_get_prompt_all_app_modes_common_model( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test prompt generation for all app modes with common model. + + This test verifies: + - All app modes work correctly with common models + - Proper template selection for each combination + """ + fake = Faker() + + # Test all app modes + app_modes = [AppMode.CHAT.value, AppMode.COMPLETION.value] + model_modes = ["completion", "chat"] + + for app_mode in app_modes: + for model_mode in model_modes: + args = { + "app_mode": app_mode, + "model_mode": model_mode, + "model_name": "gpt-3.5-turbo", + "has_context": "true", + } + + # Act: Execute the method under test + result = AdvancedPromptTemplateService.get_prompt(args) + + # Assert: Verify result is not empty + assert result is not None + assert result != {} + + def test_get_prompt_all_app_modes_baichuan_model( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test prompt generation for all app modes with Baichuan model. + + This test verifies: + - All app modes work correctly with Baichuan models + - Proper template selection for each combination + """ + fake = Faker() + + # Test all app modes + app_modes = [AppMode.CHAT.value, AppMode.COMPLETION.value] + model_modes = ["completion", "chat"] + + for app_mode in app_modes: + for model_mode in model_modes: + args = { + "app_mode": app_mode, + "model_mode": model_mode, + "model_name": "baichuan-13b-chat", + "has_context": "true", + } + + # Act: Execute the method under test + result = AdvancedPromptTemplateService.get_prompt(args) + + # Assert: Verify result is not empty + assert result is not None + assert result != {} + + def test_get_prompt_edge_cases(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test prompt generation with edge cases. + + This test verifies: + - Handling of edge case inputs + - Proper error handling + - Consistent behavior with unusual inputs + """ + fake = Faker() + + # Test edge cases + edge_cases = [ + {"app_mode": "", "model_mode": "completion", "model_name": "gpt-3.5-turbo", "has_context": "true"}, + {"app_mode": AppMode.CHAT.value, "model_mode": "", "model_name": "gpt-3.5-turbo", "has_context": "true"}, + {"app_mode": AppMode.CHAT.value, "model_mode": "completion", "model_name": "", "has_context": "true"}, + { + "app_mode": AppMode.CHAT.value, + "model_mode": "completion", + "model_name": "gpt-3.5-turbo", + "has_context": "", + }, + ] + + for args in edge_cases: + # Act: Execute the method under test + result = AdvancedPromptTemplateService.get_prompt(args) + + # Assert: Verify method handles edge cases gracefully + # Should either return a valid result or empty dict, but not crash + assert result is not None + + def test_template_immutability(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test that original templates are not modified. + + This test verifies: + - Original template constants are not modified + - Deep copy is used properly + - Template immutability is maintained + """ + fake = Faker() + + # Store original templates + original_chat_completion = copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG) + original_chat_chat = copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG) + original_completion_completion = copy.deepcopy(COMPLETION_APP_COMPLETION_PROMPT_CONFIG) + original_completion_chat = copy.deepcopy(COMPLETION_APP_CHAT_PROMPT_CONFIG) + + # Test with context + args = { + "app_mode": AppMode.CHAT.value, + "model_mode": "completion", + "model_name": "gpt-3.5-turbo", + "has_context": "true", + } + + # Act: Execute the method under test + result = AdvancedPromptTemplateService.get_prompt(args) + + # Assert: Verify original templates are unchanged + assert original_chat_completion == CHAT_APP_COMPLETION_PROMPT_CONFIG + assert original_chat_chat == CHAT_APP_CHAT_PROMPT_CONFIG + assert original_completion_completion == COMPLETION_APP_COMPLETION_PROMPT_CONFIG + assert original_completion_chat == COMPLETION_APP_CHAT_PROMPT_CONFIG + + def test_baichuan_template_immutability(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test that original Baichuan templates are not modified. + + This test verifies: + - Original Baichuan template constants are not modified + - Deep copy is used properly + - Template immutability is maintained + """ + fake = Faker() + + # Store original templates + original_baichuan_chat_completion = copy.deepcopy(BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG) + original_baichuan_chat_chat = copy.deepcopy(BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG) + original_baichuan_completion_completion = copy.deepcopy(BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG) + original_baichuan_completion_chat = copy.deepcopy(BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG) + + # Test with context + args = { + "app_mode": AppMode.CHAT.value, + "model_mode": "completion", + "model_name": "baichuan-13b-chat", + "has_context": "true", + } + + # Act: Execute the method under test + result = AdvancedPromptTemplateService.get_prompt(args) + + # Assert: Verify original templates are unchanged + assert original_baichuan_chat_completion == BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG + assert original_baichuan_chat_chat == BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG + assert original_baichuan_completion_completion == BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG + assert original_baichuan_completion_chat == BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG + + def test_context_integration_consistency(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test consistency of context integration across different scenarios. + + This test verifies: + - Context is always prepended correctly + - Context integration is consistent across different templates + - No context duplication or corruption + """ + fake = Faker() + + # Test different scenarios + test_scenarios = [ + { + "app_mode": AppMode.CHAT.value, + "model_mode": "completion", + "model_name": "gpt-3.5-turbo", + "has_context": "true", + }, + { + "app_mode": AppMode.CHAT.value, + "model_mode": "chat", + "model_name": "gpt-3.5-turbo", + "has_context": "true", + }, + { + "app_mode": AppMode.COMPLETION.value, + "model_mode": "completion", + "model_name": "gpt-3.5-turbo", + "has_context": "true", + }, + { + "app_mode": AppMode.COMPLETION.value, + "model_mode": "chat", + "model_name": "gpt-3.5-turbo", + "has_context": "true", + }, + ] + + for args in test_scenarios: + # Act: Execute the method under test + result = AdvancedPromptTemplateService.get_prompt(args) + + # Assert: Verify context integration is consistent + assert result is not None + assert result != {} + + # Check that context is properly integrated + if "completion_prompt_config" in result: + prompt_text = result["completion_prompt_config"]["prompt"]["text"] + assert prompt_text.startswith(CONTEXT) + elif "chat_prompt_config" in result: + prompt_text = result["chat_prompt_config"]["prompt"][0]["text"] + assert prompt_text.startswith(CONTEXT) + + def test_baichuan_context_integration_consistency( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test consistency of Baichuan context integration across different scenarios. + + This test verifies: + - Baichuan context is always prepended correctly + - Context integration is consistent across different templates + - No context duplication or corruption + """ + fake = Faker() + + # Test different scenarios + test_scenarios = [ + { + "app_mode": AppMode.CHAT.value, + "model_mode": "completion", + "model_name": "baichuan-13b-chat", + "has_context": "true", + }, + { + "app_mode": AppMode.CHAT.value, + "model_mode": "chat", + "model_name": "baichuan-13b-chat", + "has_context": "true", + }, + { + "app_mode": AppMode.COMPLETION.value, + "model_mode": "completion", + "model_name": "baichuan-13b-chat", + "has_context": "true", + }, + { + "app_mode": AppMode.COMPLETION.value, + "model_mode": "chat", + "model_name": "baichuan-13b-chat", + "has_context": "true", + }, + ] + + for args in test_scenarios: + # Act: Execute the method under test + result = AdvancedPromptTemplateService.get_prompt(args) + + # Assert: Verify context integration is consistent + assert result is not None + assert result != {} + + # Check that Baichuan context is properly integrated + if "completion_prompt_config" in result: + prompt_text = result["completion_prompt_config"]["prompt"]["text"] + assert prompt_text.startswith(BAICHUAN_CONTEXT) + elif "chat_prompt_config" in result: + prompt_text = result["chat_prompt_config"]["prompt"][0]["text"] + assert prompt_text.startswith(BAICHUAN_CONTEXT) diff --git a/api/tests/test_containers_integration_tests/services/test_agent_service.py b/api/tests/test_containers_integration_tests/services/test_agent_service.py new file mode 100644 index 0000000000..d63b188b12 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_agent_service.py @@ -0,0 +1,1033 @@ +import json +from unittest.mock import MagicMock, patch + +import pytest +from faker import Faker + +from core.plugin.impl.exc import PluginDaemonClientSideError +from models.model import AppModelConfig, Conversation, EndUser, Message, MessageAgentThought +from services.account_service import AccountService, TenantService +from services.agent_service import AgentService +from services.app_service import AppService + + +class TestAgentService: + """Integration tests for AgentService using testcontainers.""" + + @pytest.fixture + def mock_external_service_dependencies(self): + """Mock setup for external service dependencies.""" + with ( + patch("services.agent_service.PluginAgentClient") as mock_plugin_agent_client, + patch("services.agent_service.ToolManager") as mock_tool_manager, + patch("services.agent_service.AgentConfigManager") as mock_agent_config_manager, + patch("services.agent_service.current_user") as mock_current_user, + patch("services.app_service.FeatureService") as mock_feature_service, + patch("services.app_service.EnterpriseService") as mock_enterprise_service, + patch("services.app_service.ModelManager") as mock_model_manager, + patch("services.account_service.FeatureService") as mock_account_feature_service, + ): + # Setup default mock returns for agent service + mock_plugin_agent_client_instance = mock_plugin_agent_client.return_value + mock_plugin_agent_client_instance.fetch_agent_strategy_providers.return_value = [ + MagicMock( + plugin_id="test_plugin", + declaration=MagicMock( + identity=MagicMock(name="test_provider"), + strategies=[MagicMock(identity=MagicMock(name="test_strategy"))], + ), + ) + ] + mock_plugin_agent_client_instance.fetch_agent_strategy_provider.return_value = MagicMock( + plugin_id="test_plugin", + declaration=MagicMock( + identity=MagicMock(name="test_provider"), + strategies=[MagicMock(identity=MagicMock(name="test_strategy"))], + ), + ) + + # Setup ToolManager mocks + mock_tool_manager.get_tool_icon.return_value = "test_icon" + mock_tool_manager.get_tool_label.return_value = MagicMock( + to_dict=lambda: {"en_US": "Test Tool", "zh_Hans": "测试工具"} + ) + + # Setup AgentConfigManager mocks + mock_agent_config = MagicMock() + mock_agent_config.tools = [ + MagicMock(tool_name="test_tool", provider_type="test_provider", provider_id="test_id") + ] + mock_agent_config_manager.convert.return_value = mock_agent_config + + # Setup current_user mock + mock_current_user.timezone = "UTC" + + # Setup default mock returns for app service + mock_feature_service.get_system_features.return_value.webapp_auth.enabled = False + mock_enterprise_service.WebAppAuth.update_app_access_mode.return_value = None + mock_enterprise_service.WebAppAuth.cleanup_webapp.return_value = None + + # Setup default mock returns for account service + mock_account_feature_service.get_system_features.return_value.is_allow_register = True + + # Mock ModelManager for model configuration + mock_model_instance = mock_model_manager.return_value + mock_model_instance.get_default_model_instance.return_value = None + mock_model_instance.get_default_provider_model_name.return_value = ("openai", "gpt-3.5-turbo") + + yield { + "plugin_agent_client": mock_plugin_agent_client, + "tool_manager": mock_tool_manager, + "agent_config_manager": mock_agent_config_manager, + "current_user": mock_current_user, + "feature_service": mock_feature_service, + "enterprise_service": mock_enterprise_service, + "model_manager": mock_model_manager, + "account_feature_service": mock_account_feature_service, + } + + def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies): + """ + Helper method to create a test app and account for testing. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + mock_external_service_dependencies: Mock dependencies + + Returns: + tuple: (app, account) - Created app and account instances + """ + fake = Faker() + + # Setup mocks for account creation + mock_external_service_dependencies[ + "account_feature_service" + ].get_system_features.return_value.is_allow_register = True + + # Create account and tenant + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Create app with realistic data + app_args = { + "name": fake.company(), + "description": fake.text(max_nb_chars=100), + "mode": "agent-chat", + "icon_type": "emoji", + "icon": "🤖", + "icon_background": "#FF6B6B", + "api_rph": 100, + "api_rpm": 10, + } + + app_service = AppService() + app = app_service.create_app(tenant.id, app_args, account) + + # Update the app model config to set agent_mode for agent-chat mode + if app.mode == "agent-chat" and app.app_model_config: + app.app_model_config.agent_mode = json.dumps({"enabled": True, "strategy": "react", "tools": []}) + from extensions.ext_database import db + + db.session.commit() + + return app, account + + def _create_test_conversation_and_message(self, db_session_with_containers, app, account): + """ + Helper method to create a test conversation and message with agent thoughts. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + app: App instance + account: Account instance + + Returns: + tuple: (conversation, message) - Created conversation and message instances + """ + fake = Faker() + + from extensions.ext_database import db + + # Create conversation + conversation = Conversation( + id=fake.uuid4(), + app_id=app.id, + from_account_id=account.id, + from_end_user_id=None, + name=fake.sentence(), + inputs={}, + status="normal", + mode="chat", + from_source="api", + ) + db.session.add(conversation) + db.session.commit() + + # Create app model config + app_model_config = AppModelConfig( + id=fake.uuid4(), + app_id=app.id, + provider="openai", + model_id="gpt-3.5-turbo", + configs={}, + model="gpt-3.5-turbo", + agent_mode=json.dumps({"enabled": True, "strategy": "react", "tools": []}), + ) + db.session.add(app_model_config) + db.session.commit() + + # Update conversation with app model config + conversation.app_model_config_id = app_model_config.id + db.session.commit() + + # Create message + message = Message( + id=fake.uuid4(), + conversation_id=conversation.id, + app_id=app.id, + from_account_id=account.id, + from_end_user_id=None, + inputs={}, + query=fake.text(max_nb_chars=100), + message=[{"role": "user", "text": fake.text(max_nb_chars=100)}], + answer=fake.text(max_nb_chars=200), + message_tokens=100, + message_unit_price=0.001, + answer_tokens=200, + answer_unit_price=0.001, + provider_response_latency=1.5, + currency="USD", + from_source="api", + ) + db.session.add(message) + db.session.commit() + + return conversation, message + + def _create_test_agent_thoughts(self, db_session_with_containers, message): + """ + Helper method to create test agent thoughts for a message. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + message: Message instance + + Returns: + list: Created agent thoughts + """ + fake = Faker() + + from extensions.ext_database import db + + agent_thoughts = [] + + # Create first agent thought + thought1 = MessageAgentThought( + id=fake.uuid4(), + message_id=message.id, + position=1, + thought="I need to analyze the user's request", + tool="test_tool", + tool_labels_str=json.dumps({"test_tool": {"en_US": "Test Tool", "zh_Hans": "测试工具"}}), + tool_meta_str=json.dumps( + { + "test_tool": { + "error": None, + "time_cost": 0.5, + "tool_config": {"tool_provider_type": "test_provider", "tool_provider": "test_id"}, + "tool_parameters": {}, + } + } + ), + tool_input=json.dumps({"test_tool": {"input": "test_input"}}), + observation=json.dumps({"test_tool": {"output": "test_output"}}), + tokens=50, + created_by_role="account", + created_by=message.from_account_id, + ) + db.session.add(thought1) + agent_thoughts.append(thought1) + + # Create second agent thought + thought2 = MessageAgentThought( + id=fake.uuid4(), + message_id=message.id, + position=2, + thought="Based on the analysis, I can provide a response", + tool="dataset_tool", + tool_labels_str=json.dumps({"dataset_tool": {"en_US": "Dataset Tool", "zh_Hans": "数据集工具"}}), + tool_meta_str=json.dumps( + { + "dataset_tool": { + "error": None, + "time_cost": 0.3, + "tool_config": {"tool_provider_type": "dataset-retrieval", "tool_provider": "dataset_id"}, + "tool_parameters": {}, + } + } + ), + tool_input=json.dumps({"dataset_tool": {"query": "test_query"}}), + observation=json.dumps({"dataset_tool": {"results": "test_results"}}), + tokens=30, + created_by_role="account", + created_by=message.from_account_id, + ) + db.session.add(thought2) + agent_thoughts.append(thought2) + + db.session.commit() + + return agent_thoughts + + def test_get_agent_logs_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful retrieval of agent logs with complete data. + """ + fake = Faker() + + # Create test data + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + conversation, message = self._create_test_conversation_and_message(db_session_with_containers, app, account) + agent_thoughts = self._create_test_agent_thoughts(db_session_with_containers, message) + + # Execute the method under test + result = AgentService.get_agent_logs(app, str(conversation.id), str(message.id)) + + # Verify the result structure + assert result is not None + assert "meta" in result + assert "iterations" in result + assert "files" in result + + # Verify meta information + meta = result["meta"] + assert meta["status"] == "success" + assert meta["executor"] == account.name + assert meta["iterations"] == 2 + assert meta["agent_mode"] == "react" + assert meta["total_tokens"] == 300 # 100 + 200 + assert meta["elapsed_time"] == 1.5 + + # Verify iterations + iterations = result["iterations"] + assert len(iterations) == 2 + + # Verify first iteration + first_iteration = iterations[0] + assert first_iteration["tokens"] == 50 + assert first_iteration["thought"] == "I need to analyze the user's request" + assert len(first_iteration["tool_calls"]) == 1 + + tool_call = first_iteration["tool_calls"][0] + assert tool_call["tool_name"] == "test_tool" + assert tool_call["tool_label"] == {"en_US": "Test Tool", "zh_Hans": "测试工具"} + assert tool_call["status"] == "success" + assert tool_call["time_cost"] == 0.5 + assert tool_call["tool_icon"] == "test_icon" + + # Verify second iteration + second_iteration = iterations[1] + assert second_iteration["tokens"] == 30 + assert second_iteration["thought"] == "Based on the analysis, I can provide a response" + assert len(second_iteration["tool_calls"]) == 1 + + dataset_tool_call = second_iteration["tool_calls"][0] + assert dataset_tool_call["tool_name"] == "dataset_tool" + assert dataset_tool_call["tool_label"] == {"en_US": "Dataset Tool", "zh_Hans": "数据集工具"} + assert dataset_tool_call["status"] == "success" + assert dataset_tool_call["time_cost"] == 0.3 + assert dataset_tool_call["tool_icon"] == "" # dataset-retrieval tools have empty icon + + def test_get_agent_logs_conversation_not_found( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test error handling when conversation is not found. + """ + fake = Faker() + + # Create test data + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Execute the method under test with non-existent conversation + with pytest.raises(ValueError, match="Conversation not found"): + AgentService.get_agent_logs(app, fake.uuid4(), fake.uuid4()) + + def test_get_agent_logs_message_not_found(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test error handling when message is not found. + """ + fake = Faker() + + # Create test data + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + conversation, message = self._create_test_conversation_and_message(db_session_with_containers, app, account) + + # Execute the method under test with non-existent message + with pytest.raises(ValueError, match="Message not found"): + AgentService.get_agent_logs(app, str(conversation.id), fake.uuid4()) + + def test_get_agent_logs_with_end_user(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test agent logs retrieval when conversation is from end user. + """ + fake = Faker() + + # Create test data + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + from extensions.ext_database import db + + # Create end user + end_user = EndUser( + id=fake.uuid4(), + tenant_id=app.tenant_id, + app_id=app.id, + type="web_app", + is_anonymous=False, + session_id=fake.uuid4(), + name=fake.name(), + ) + db.session.add(end_user) + db.session.commit() + + # Create conversation with end user + conversation = Conversation( + id=fake.uuid4(), + app_id=app.id, + from_account_id=None, + from_end_user_id=end_user.id, + name=fake.sentence(), + inputs={}, + status="normal", + mode="chat", + from_source="api", + ) + db.session.add(conversation) + db.session.commit() + + # Create app model config + app_model_config = AppModelConfig( + id=fake.uuid4(), + app_id=app.id, + provider="openai", + model_id="gpt-3.5-turbo", + configs={}, + model="gpt-3.5-turbo", + agent_mode=json.dumps({"enabled": True, "strategy": "react", "tools": []}), + ) + db.session.add(app_model_config) + db.session.commit() + + # Update conversation with app model config + conversation.app_model_config_id = app_model_config.id + db.session.commit() + + # Create message + message = Message( + id=fake.uuid4(), + conversation_id=conversation.id, + app_id=app.id, + from_account_id=None, + from_end_user_id=end_user.id, + inputs={}, + query=fake.text(max_nb_chars=100), + message=[{"role": "user", "text": fake.text(max_nb_chars=100)}], + answer=fake.text(max_nb_chars=200), + message_tokens=100, + message_unit_price=0.001, + answer_tokens=200, + answer_unit_price=0.001, + provider_response_latency=1.5, + currency="USD", + from_source="api", + ) + db.session.add(message) + db.session.commit() + + # Execute the method under test + result = AgentService.get_agent_logs(app, str(conversation.id), str(message.id)) + + # Verify the result + assert result is not None + assert result["meta"]["executor"] == end_user.name + + def test_get_agent_logs_with_unknown_executor(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test agent logs retrieval when executor is unknown. + """ + fake = Faker() + + # Create test data + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + from extensions.ext_database import db + + # Create conversation with non-existent account + conversation = Conversation( + id=fake.uuid4(), + app_id=app.id, + from_account_id=fake.uuid4(), # Non-existent account + from_end_user_id=None, + name=fake.sentence(), + inputs={}, + status="normal", + mode="chat", + from_source="api", + ) + db.session.add(conversation) + db.session.commit() + + # Create app model config + app_model_config = AppModelConfig( + id=fake.uuid4(), + app_id=app.id, + provider="openai", + model_id="gpt-3.5-turbo", + configs={}, + model="gpt-3.5-turbo", + agent_mode=json.dumps({"enabled": True, "strategy": "react", "tools": []}), + ) + db.session.add(app_model_config) + db.session.commit() + + # Update conversation with app model config + conversation.app_model_config_id = app_model_config.id + db.session.commit() + + # Create message + message = Message( + id=fake.uuid4(), + conversation_id=conversation.id, + app_id=app.id, + from_account_id=fake.uuid4(), # Non-existent account + from_end_user_id=None, + inputs={}, + query=fake.text(max_nb_chars=100), + message=[{"role": "user", "text": fake.text(max_nb_chars=100)}], + answer=fake.text(max_nb_chars=200), + message_tokens=100, + message_unit_price=0.001, + answer_tokens=200, + answer_unit_price=0.001, + provider_response_latency=1.5, + currency="USD", + from_source="api", + ) + db.session.add(message) + db.session.commit() + + # Execute the method under test + result = AgentService.get_agent_logs(app, str(conversation.id), str(message.id)) + + # Verify the result + assert result is not None + assert result["meta"]["executor"] == "Unknown" + + def test_get_agent_logs_with_tool_error(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test agent logs retrieval with tool errors. + """ + fake = Faker() + + # Create test data + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + conversation, message = self._create_test_conversation_and_message(db_session_with_containers, app, account) + + from extensions.ext_database import db + + # Create agent thought with tool error + thought_with_error = MessageAgentThought( + id=fake.uuid4(), + message_id=message.id, + position=1, + thought="I need to analyze the user's request", + tool="error_tool", + tool_labels_str=json.dumps({"error_tool": {"en_US": "Error Tool", "zh_Hans": "错误工具"}}), + tool_meta_str=json.dumps( + { + "error_tool": { + "error": "Tool execution failed", + "time_cost": 0.5, + "tool_config": {"tool_provider_type": "test_provider", "tool_provider": "test_id"}, + "tool_parameters": {}, + } + } + ), + tool_input=json.dumps({"error_tool": {"input": "test_input"}}), + observation=json.dumps({"error_tool": {"output": "error_output"}}), + tokens=50, + created_by_role="account", + created_by=message.from_account_id, + ) + db.session.add(thought_with_error) + db.session.commit() + + # Execute the method under test + result = AgentService.get_agent_logs(app, str(conversation.id), str(message.id)) + + # Verify the result + assert result is not None + iterations = result["iterations"] + assert len(iterations) == 1 + + tool_call = iterations[0]["tool_calls"][0] + assert tool_call["status"] == "error" + assert tool_call["error"] == "Tool execution failed" + + def test_get_agent_logs_without_agent_thoughts( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test agent logs retrieval when message has no agent thoughts. + """ + fake = Faker() + + # Create test data + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + conversation, message = self._create_test_conversation_and_message(db_session_with_containers, app, account) + + # Execute the method under test + result = AgentService.get_agent_logs(app, str(conversation.id), str(message.id)) + + # Verify the result + assert result is not None + assert result["meta"]["iterations"] == 0 + assert len(result["iterations"]) == 0 + + def test_get_agent_logs_app_model_config_not_found( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test error handling when app model config is not found. + """ + fake = Faker() + + # Create test data + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + from extensions.ext_database import db + + # Remove app model config to test error handling + app.app_model_config_id = None + db.session.commit() + + # Create conversation without app model config + conversation = Conversation( + id=fake.uuid4(), + app_id=app.id, + from_account_id=account.id, + from_end_user_id=None, + name=fake.sentence(), + inputs={}, + status="normal", + mode="chat", + from_source="api", + app_model_config_id=None, # Explicitly set to None + ) + db.session.add(conversation) + db.session.commit() + + # Create message + message = Message( + id=fake.uuid4(), + conversation_id=conversation.id, + app_id=app.id, + from_account_id=account.id, + from_end_user_id=None, + inputs={}, + query=fake.text(max_nb_chars=100), + message=[{"role": "user", "text": fake.text(max_nb_chars=100)}], + answer=fake.text(max_nb_chars=200), + message_tokens=100, + message_unit_price=0.001, + answer_tokens=200, + answer_unit_price=0.001, + provider_response_latency=1.5, + currency="USD", + from_source="api", + ) + db.session.add(message) + db.session.commit() + + # Execute the method under test + with pytest.raises(ValueError, match="App model config not found"): + AgentService.get_agent_logs(app, str(conversation.id), str(message.id)) + + def test_get_agent_logs_agent_config_not_found( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test error handling when agent config is not found. + """ + fake = Faker() + + # Create test data + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + conversation, message = self._create_test_conversation_and_message(db_session_with_containers, app, account) + + # Mock AgentConfigManager to return None + mock_external_service_dependencies["agent_config_manager"].convert.return_value = None + + # Execute the method under test + with pytest.raises(ValueError, match="Agent config not found"): + AgentService.get_agent_logs(app, str(conversation.id), str(message.id)) + + def test_list_agent_providers_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful listing of agent providers. + """ + fake = Faker() + + # Create test data + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Execute the method under test + result = AgentService.list_agent_providers(str(account.id), str(app.tenant_id)) + + # Verify the result + assert result is not None + assert len(result) == 1 + assert result[0].plugin_id == "test_plugin" + + # Verify the mock was called correctly + mock_plugin_client = mock_external_service_dependencies["plugin_agent_client"].return_value + mock_plugin_client.fetch_agent_strategy_providers.assert_called_once_with(str(app.tenant_id)) + + def test_get_agent_provider_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful retrieval of specific agent provider. + """ + fake = Faker() + + # Create test data + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + provider_name = "test_provider" + + # Execute the method under test + result = AgentService.get_agent_provider(str(account.id), str(app.tenant_id), provider_name) + + # Verify the result + assert result is not None + assert result.plugin_id == "test_plugin" + + # Verify the mock was called correctly + mock_plugin_client = mock_external_service_dependencies["plugin_agent_client"].return_value + mock_plugin_client.fetch_agent_strategy_provider.assert_called_once_with(str(app.tenant_id), provider_name) + + def test_get_agent_provider_plugin_error(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test error handling when plugin daemon client raises an error. + """ + fake = Faker() + + # Create test data + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + provider_name = "test_provider" + error_message = "Plugin not found" + + # Mock PluginAgentClient to raise an error + mock_plugin_client = mock_external_service_dependencies["plugin_agent_client"].return_value + mock_plugin_client.fetch_agent_strategy_provider.side_effect = PluginDaemonClientSideError(error_message) + + # Execute the method under test + with pytest.raises(ValueError, match=error_message): + AgentService.get_agent_provider(str(account.id), str(app.tenant_id), provider_name) + + def test_get_agent_logs_with_complex_tool_data( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test agent logs retrieval with complex tool data and multiple tools. + """ + fake = Faker() + + # Create test data + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + conversation, message = self._create_test_conversation_and_message(db_session_with_containers, app, account) + + from extensions.ext_database import db + + # Create agent thought with multiple tools + complex_thought = MessageAgentThought( + id=fake.uuid4(), + message_id=message.id, + position=1, + thought="I need to use multiple tools to complete this task", + tool="tool1;tool2;tool3", + tool_labels_str=json.dumps( + { + "tool1": {"en_US": "First Tool", "zh_Hans": "第一个工具"}, + "tool2": {"en_US": "Second Tool", "zh_Hans": "第二个工具"}, + "tool3": {"en_US": "Third Tool", "zh_Hans": "第三个工具"}, + } + ), + tool_meta_str=json.dumps( + { + "tool1": { + "error": None, + "time_cost": 0.5, + "tool_config": {"tool_provider_type": "test_provider", "tool_provider": "test_id"}, + "tool_parameters": {"param1": "value1"}, + }, + "tool2": { + "error": "Tool 2 failed", + "time_cost": 0.3, + "tool_config": {"tool_provider_type": "another_provider", "tool_provider": "another_id"}, + "tool_parameters": {"param2": "value2"}, + }, + "tool3": { + "error": None, + "time_cost": 0.7, + "tool_config": {"tool_provider_type": "dataset-retrieval", "tool_provider": "dataset_id"}, + "tool_parameters": {"param3": "value3"}, + }, + } + ), + tool_input=json.dumps( + {"tool1": {"input1": "data1"}, "tool2": {"input2": "data2"}, "tool3": {"input3": "data3"}} + ), + observation=json.dumps( + {"tool1": {"output1": "result1"}, "tool2": {"output2": "result2"}, "tool3": {"output3": "result3"}} + ), + tokens=100, + created_by_role="account", + created_by=message.from_account_id, + ) + db.session.add(complex_thought) + db.session.commit() + + # Execute the method under test + result = AgentService.get_agent_logs(app, str(conversation.id), str(message.id)) + + # Verify the result + assert result is not None + iterations = result["iterations"] + assert len(iterations) == 1 + + tool_calls = iterations[0]["tool_calls"] + assert len(tool_calls) == 3 + + # Verify first tool + assert tool_calls[0]["tool_name"] == "tool1" + assert tool_calls[0]["tool_label"] == {"en_US": "First Tool", "zh_Hans": "第一个工具"} + assert tool_calls[0]["status"] == "success" + assert tool_calls[0]["tool_parameters"] == {"param1": "value1"} + + # Verify second tool (with error) + assert tool_calls[1]["tool_name"] == "tool2" + assert tool_calls[1]["tool_label"] == {"en_US": "Second Tool", "zh_Hans": "第二个工具"} + assert tool_calls[1]["status"] == "error" + assert tool_calls[1]["error"] == "Tool 2 failed" + + # Verify third tool (dataset tool) + assert tool_calls[2]["tool_name"] == "tool3" + assert tool_calls[2]["tool_label"] == {"en_US": "Third Tool", "zh_Hans": "第三个工具"} + assert tool_calls[2]["status"] == "success" + assert tool_calls[2]["tool_icon"] == "" # dataset-retrieval tools have empty icon + + def test_get_agent_logs_with_files(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test agent logs retrieval with message files and agent thought files. + """ + fake = Faker() + + # Create test data + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + conversation, message = self._create_test_conversation_and_message(db_session_with_containers, app, account) + + from core.file import FileTransferMethod, FileType + from extensions.ext_database import db + from models.enums import CreatorUserRole + + # Add files to message + from models.model import MessageFile + + message_file1 = MessageFile( + message_id=message.id, + type=FileType.IMAGE, + transfer_method=FileTransferMethod.REMOTE_URL, + url="http://example.com/file1.jpg", + belongs_to="user", + created_by_role=CreatorUserRole.ACCOUNT, + created_by=message.from_account_id, + ) + message_file2 = MessageFile( + message_id=message.id, + type=FileType.IMAGE, + transfer_method=FileTransferMethod.REMOTE_URL, + url="http://example.com/file2.png", + belongs_to="user", + created_by_role=CreatorUserRole.ACCOUNT, + created_by=message.from_account_id, + ) + db.session.add(message_file1) + db.session.add(message_file2) + db.session.commit() + + # Create agent thought with files + thought_with_files = MessageAgentThought( + id=fake.uuid4(), + message_id=message.id, + position=1, + thought="I need to process some files", + tool="file_tool", + tool_labels_str=json.dumps({"file_tool": {"en_US": "File Tool", "zh_Hans": "文件工具"}}), + tool_meta_str=json.dumps( + { + "file_tool": { + "error": None, + "time_cost": 0.5, + "tool_config": {"tool_provider_type": "test_provider", "tool_provider": "test_id"}, + "tool_parameters": {}, + } + } + ), + tool_input=json.dumps({"file_tool": {"input": "test_input"}}), + observation=json.dumps({"file_tool": {"output": "test_output"}}), + message_files=json.dumps(["file1", "file2"]), + tokens=50, + created_by_role="account", + created_by=message.from_account_id, + ) + db.session.add(thought_with_files) + db.session.commit() + + # Execute the method under test + result = AgentService.get_agent_logs(app, str(conversation.id), str(message.id)) + + # Verify the result + assert result is not None + assert len(result["files"]) == 2 + + iterations = result["iterations"] + assert len(iterations) == 1 + assert len(iterations[0]["files"]) == 2 + assert "file1" in iterations[0]["files"] + assert "file2" in iterations[0]["files"] + + def test_get_agent_logs_with_different_timezone( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test agent logs retrieval with different timezone settings. + """ + fake = Faker() + + # Create test data + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + conversation, message = self._create_test_conversation_and_message(db_session_with_containers, app, account) + + # Mock current_user with different timezone + mock_external_service_dependencies["current_user"].timezone = "Asia/Shanghai" + + # Execute the method under test + result = AgentService.get_agent_logs(app, str(conversation.id), str(message.id)) + + # Verify the result + assert result is not None + assert "start_time" in result["meta"] + + # Verify the timezone conversion + start_time = result["meta"]["start_time"] + assert "T" in start_time # ISO format + assert "+08:00" in start_time or "Z" in start_time # Timezone offset + + def test_get_agent_logs_with_empty_tool_data(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test agent logs retrieval with empty tool data. + """ + fake = Faker() + + # Create test data + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + conversation, message = self._create_test_conversation_and_message(db_session_with_containers, app, account) + + from extensions.ext_database import db + + # Create agent thought with empty tool data + empty_thought = MessageAgentThought( + id=fake.uuid4(), + message_id=message.id, + position=1, + thought="I need to analyze the user's request", + tool="", # Empty tool + tool_labels_str="{}", # Empty labels + tool_meta_str="{}", # Empty meta + tool_input="", # Empty input + observation="", # Empty observation + tokens=50, + created_by_role="account", + created_by=message.from_account_id, + ) + db.session.add(empty_thought) + db.session.commit() + + # Execute the method under test + result = AgentService.get_agent_logs(app, str(conversation.id), str(message.id)) + + # Verify the result + assert result is not None + iterations = result["iterations"] + assert len(iterations) == 1 + + # Verify empty tool calls + tool_calls = iterations[0]["tool_calls"] + assert len(tool_calls) == 0 # No tools to process + + def test_get_agent_logs_with_malformed_json(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test agent logs retrieval with malformed JSON data in tool fields. + """ + fake = Faker() + + # Create test data + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + conversation, message = self._create_test_conversation_and_message(db_session_with_containers, app, account) + + from extensions.ext_database import db + + # Create agent thought with malformed JSON + malformed_thought = MessageAgentThought( + id=fake.uuid4(), + message_id=message.id, + position=1, + thought="I need to analyze the user's request", + tool="test_tool", + tool_labels_str="invalid json", # Malformed JSON + tool_meta_str="invalid json", # Malformed JSON + tool_input="invalid json", # Malformed JSON + observation="invalid json", # Malformed JSON + tokens=50, + created_by_role="account", + created_by=message.from_account_id, + ) + db.session.add(malformed_thought) + db.session.commit() + + # Execute the method under test + result = AgentService.get_agent_logs(app, str(conversation.id), str(message.id)) + + # Verify the result - should handle malformed JSON gracefully + assert result is not None + iterations = result["iterations"] + assert len(iterations) == 1 + + tool_calls = iterations[0]["tool_calls"] + assert len(tool_calls) == 1 + + # Verify default values for malformed JSON + tool_call = tool_calls[0] + assert tool_call["tool_name"] == "test_tool" + assert tool_call["tool_label"] == "test_tool" # Default to tool name + assert tool_call["tool_input"] == {} + assert tool_call["tool_output"] == "invalid json" # Raw observation value + assert tool_call["tool_parameters"] == {} diff --git a/api/tests/test_containers_integration_tests/services/test_annotation_service.py b/api/tests/test_containers_integration_tests/services/test_annotation_service.py new file mode 100644 index 0000000000..92d93d601e --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_annotation_service.py @@ -0,0 +1,1252 @@ +from unittest.mock import patch + +import pytest +from faker import Faker +from werkzeug.exceptions import NotFound + +from models.model import MessageAnnotation +from services.annotation_service import AppAnnotationService +from services.app_service import AppService + + +class TestAnnotationService: + """Integration tests for AnnotationService using testcontainers.""" + + @pytest.fixture + def mock_external_service_dependencies(self): + """Mock setup for external service dependencies.""" + with ( + patch("services.account_service.FeatureService") as mock_account_feature_service, + patch("services.annotation_service.FeatureService") as mock_feature_service, + patch("services.annotation_service.add_annotation_to_index_task") as mock_add_task, + patch("services.annotation_service.update_annotation_to_index_task") as mock_update_task, + patch("services.annotation_service.delete_annotation_index_task") as mock_delete_task, + patch("services.annotation_service.enable_annotation_reply_task") as mock_enable_task, + patch("services.annotation_service.disable_annotation_reply_task") as mock_disable_task, + patch("services.annotation_service.batch_import_annotations_task") as mock_batch_import_task, + patch("services.annotation_service.current_user") as mock_current_user, + ): + # Setup default mock returns + mock_account_feature_service.get_features.return_value.billing.enabled = False + mock_add_task.delay.return_value = None + mock_update_task.delay.return_value = None + mock_delete_task.delay.return_value = None + mock_enable_task.delay.return_value = None + mock_disable_task.delay.return_value = None + mock_batch_import_task.delay.return_value = None + + yield { + "account_feature_service": mock_account_feature_service, + "feature_service": mock_feature_service, + "add_task": mock_add_task, + "update_task": mock_update_task, + "delete_task": mock_delete_task, + "enable_task": mock_enable_task, + "disable_task": mock_disable_task, + "batch_import_task": mock_batch_import_task, + "current_user": mock_current_user, + } + + def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies): + """ + Helper method to create a test app and account for testing. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + mock_external_service_dependencies: Mock dependencies + + Returns: + tuple: (app, account) - Created app and account instances + """ + fake = Faker() + + # Setup mocks for account creation + mock_external_service_dependencies[ + "account_feature_service" + ].get_system_features.return_value.is_allow_register = True + + # Create account and tenant first + from services.account_service import AccountService, TenantService + + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Setup app creation arguments + app_args = { + "name": fake.company(), + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🤖", + "icon_background": "#FF6B6B", + "api_rph": 100, + "api_rpm": 10, + } + + # Create app + app_service = AppService() + app = app_service.create_app(tenant.id, app_args, account) + + # Setup current_user mock + self._mock_current_user(mock_external_service_dependencies, account.id, tenant.id) + + return app, account + + def _mock_current_user(self, mock_external_service_dependencies, account_id, tenant_id): + """ + Helper method to mock the current user for testing. + """ + mock_external_service_dependencies["current_user"].id = account_id + mock_external_service_dependencies["current_user"].current_tenant_id = tenant_id + + def _create_test_conversation(self, app, account, fake): + """ + Helper method to create a test conversation with all required fields. + """ + from extensions.ext_database import db + from models.model import Conversation + + conversation = Conversation( + app_id=app.id, + app_model_config_id=None, + model_provider=None, + model_id="", + override_model_configs=None, + mode=app.mode, + name=fake.sentence(), + inputs={}, + introduction="", + system_instruction="", + system_instruction_tokens=0, + status="normal", + invoke_from="console", + from_source="console", + from_end_user_id=None, + from_account_id=account.id, + ) + + db.session.add(conversation) + db.session.flush() + return conversation + + def _create_test_message(self, app, conversation, account, fake): + """ + Helper method to create a test message with all required fields. + """ + import json + + from extensions.ext_database import db + from models.model import Message + + message = Message( + app_id=app.id, + model_provider=None, + model_id="", + override_model_configs=None, + conversation_id=conversation.id, + inputs={}, + query=fake.sentence(), + message=json.dumps([{"role": "user", "text": fake.sentence()}]), + message_tokens=0, + message_unit_price=0, + message_price_unit=0.001, + answer=fake.text(max_nb_chars=200), + answer_tokens=0, + answer_unit_price=0, + answer_price_unit=0.001, + parent_message_id=None, + provider_response_latency=0, + total_price=0, + currency="USD", + invoke_from="console", + from_source="console", + from_end_user_id=None, + from_account_id=account.id, + ) + + db.session.add(message) + db.session.commit() + return message + + def test_insert_app_annotation_directly_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test successful direct insertion of app annotation. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Setup annotation data + annotation_args = { + "question": fake.sentence(), + "answer": fake.text(max_nb_chars=200), + } + + # Insert annotation directly + annotation = AppAnnotationService.insert_app_annotation_directly(annotation_args, app.id) + + # Verify annotation was created correctly + assert annotation.app_id == app.id + assert annotation.question == annotation_args["question"] + assert annotation.content == annotation_args["answer"] + assert annotation.account_id == account.id + assert annotation.hit_count == 0 + assert annotation.id is not None + + # Verify annotation was saved to database + from extensions.ext_database import db + + db.session.refresh(annotation) + assert annotation.id is not None + + # Verify add_annotation_to_index_task was called (when annotation setting exists) + # Note: In this test, no annotation setting exists, so task should not be called + mock_external_service_dependencies["add_task"].delay.assert_not_called() + + def test_insert_app_annotation_directly_app_not_found( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test direct insertion of app annotation when app is not found. + """ + fake = Faker() + non_existent_app_id = fake.uuid4() + + # Mock random current user to avoid dependency issues + self._mock_current_user(mock_external_service_dependencies, fake.uuid4(), fake.uuid4()) + + # Setup annotation data + annotation_args = { + "question": fake.sentence(), + "answer": fake.text(max_nb_chars=200), + } + + # Try to insert annotation with non-existent app + with pytest.raises(NotFound, match="App not found"): + AppAnnotationService.insert_app_annotation_directly(annotation_args, non_existent_app_id) + + def test_update_app_annotation_directly_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test successful direct update of app annotation. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # First, create an annotation + original_args = { + "question": fake.sentence(), + "answer": fake.text(max_nb_chars=200), + } + annotation = AppAnnotationService.insert_app_annotation_directly(original_args, app.id) + + # Update the annotation + updated_args = { + "question": fake.sentence(), + "answer": fake.text(max_nb_chars=200), + } + updated_annotation = AppAnnotationService.update_app_annotation_directly(updated_args, app.id, annotation.id) + + # Verify annotation was updated correctly + assert updated_annotation.id == annotation.id + assert updated_annotation.app_id == app.id + assert updated_annotation.question == updated_args["question"] + assert updated_annotation.content == updated_args["answer"] + assert updated_annotation.account_id == account.id + + # Verify original values were changed + assert updated_annotation.question != original_args["question"] + assert updated_annotation.content != original_args["answer"] + + # Verify update_annotation_to_index_task was called (when annotation setting exists) + # Note: In this test, no annotation setting exists, so task should not be called + mock_external_service_dependencies["update_task"].delay.assert_not_called() + + def test_up_insert_app_annotation_from_message_new( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test creating new annotation from message. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create a conversation and message first + conversation = self._create_test_conversation(app, account, fake) + message = self._create_test_message(app, conversation, account, fake) + + # Setup annotation data with message_id + annotation_args = { + "message_id": message.id, + "question": fake.sentence(), + "answer": fake.text(max_nb_chars=200), + } + + # Insert annotation from message + annotation = AppAnnotationService.up_insert_app_annotation_from_message(annotation_args, app.id) + + # Verify annotation was created correctly + assert annotation.app_id == app.id + assert annotation.conversation_id == conversation.id + assert annotation.message_id == message.id + assert annotation.question == annotation_args["question"] + assert annotation.content == annotation_args["answer"] + assert annotation.account_id == account.id + + # Verify add_annotation_to_index_task was called (when annotation setting exists) + # Note: In this test, no annotation setting exists, so task should not be called + mock_external_service_dependencies["add_task"].delay.assert_not_called() + + def test_up_insert_app_annotation_from_message_update( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test updating existing annotation from message. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create a conversation and message first + conversation = self._create_test_conversation(app, account, fake) + message = self._create_test_message(app, conversation, account, fake) + + # Create initial annotation + initial_args = { + "message_id": message.id, + "question": fake.sentence(), + "answer": fake.text(max_nb_chars=200), + } + initial_annotation = AppAnnotationService.up_insert_app_annotation_from_message(initial_args, app.id) + + # Update the annotation + updated_args = { + "message_id": message.id, + "question": fake.sentence(), + "answer": fake.text(max_nb_chars=200), + } + updated_annotation = AppAnnotationService.up_insert_app_annotation_from_message(updated_args, app.id) + + # Verify annotation was updated correctly (same ID) + assert updated_annotation.id == initial_annotation.id + assert updated_annotation.question == updated_args["question"] + assert updated_annotation.content == updated_args["answer"] + assert updated_annotation.question != initial_args["question"] + assert updated_annotation.content != initial_args["answer"] + + # Verify add_annotation_to_index_task was called (when annotation setting exists) + # Note: In this test, no annotation setting exists, so task should not be called + mock_external_service_dependencies["add_task"].delay.assert_not_called() + + def test_up_insert_app_annotation_from_message_app_not_found( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test creating annotation from message when app is not found. + """ + fake = Faker() + non_existent_app_id = fake.uuid4() + + # Mock random current user to avoid dependency issues + self._mock_current_user(mock_external_service_dependencies, fake.uuid4(), fake.uuid4()) + + # Setup annotation data + annotation_args = { + "question": fake.sentence(), + "answer": fake.text(max_nb_chars=200), + } + + # Try to insert annotation with non-existent app + with pytest.raises(NotFound, match="App not found"): + AppAnnotationService.up_insert_app_annotation_from_message(annotation_args, non_existent_app_id) + + def test_get_annotation_list_by_app_id_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test successful retrieval of annotation list by app ID. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create multiple annotations + annotations = [] + for i in range(3): + annotation_args = { + "question": f"Question {i}: {fake.sentence()}", + "answer": f"Answer {i}: {fake.text(max_nb_chars=200)}", + } + annotation = AppAnnotationService.insert_app_annotation_directly(annotation_args, app.id) + annotations.append(annotation) + + # Get annotation list + annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id( + app.id, page=1, limit=10, keyword="" + ) + + # Verify results + assert len(annotation_list) == 3 + assert total == 3 + + # Verify all annotations belong to the correct app + for annotation in annotation_list: + assert annotation.app_id == app.id + assert annotation.account_id == account.id + + def test_get_annotation_list_by_app_id_with_keyword( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test retrieval of annotation list with keyword search. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create annotations with specific keywords + unique_keyword = f"unique_{fake.uuid4()[:8]}" + annotation_args = { + "question": f"Question with {unique_keyword} keyword", + "answer": f"Answer with {unique_keyword} keyword", + } + AppAnnotationService.insert_app_annotation_directly(annotation_args, app.id) + # Create another annotation without the keyword + other_args = { + "question": "Different question without special term", + "answer": "Different answer without special content", + } + + AppAnnotationService.insert_app_annotation_directly(other_args, app.id) + + # Search with keyword + annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id( + app.id, page=1, limit=10, keyword=unique_keyword + ) + + # Verify only matching annotations are returned + assert len(annotation_list) == 1 + assert total == 1 + assert unique_keyword in annotation_list[0].question or unique_keyword in annotation_list[0].content + + def test_get_annotation_list_by_app_id_app_not_found( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test retrieval of annotation list when app is not found. + """ + fake = Faker() + non_existent_app_id = fake.uuid4() + + # Mock random current user to avoid dependency issues + self._mock_current_user(mock_external_service_dependencies, fake.uuid4(), fake.uuid4()) + + # Try to get annotation list with non-existent app + with pytest.raises(NotFound, match="App not found"): + AppAnnotationService.get_annotation_list_by_app_id(non_existent_app_id, page=1, limit=10, keyword="") + + def test_delete_app_annotation_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful deletion of app annotation. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create an annotation first + annotation_args = { + "question": fake.sentence(), + "answer": fake.text(max_nb_chars=200), + } + annotation = AppAnnotationService.insert_app_annotation_directly(annotation_args, app.id) + annotation_id = annotation.id + + # Delete the annotation + AppAnnotationService.delete_app_annotation(app.id, annotation_id) + + # Verify annotation was deleted + from extensions.ext_database import db + + deleted_annotation = db.session.query(MessageAnnotation).where(MessageAnnotation.id == annotation_id).first() + assert deleted_annotation is None + + # Verify delete_annotation_index_task was called (when annotation setting exists) + # Note: In this test, no annotation setting exists, so task should not be called + mock_external_service_dependencies["delete_task"].delay.assert_not_called() + + def test_delete_app_annotation_app_not_found(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test deletion of app annotation when app is not found. + """ + fake = Faker() + non_existent_app_id = fake.uuid4() + annotation_id = fake.uuid4() + + # Mock random current user to avoid dependency issues + self._mock_current_user(mock_external_service_dependencies, fake.uuid4(), fake.uuid4()) + + # Try to delete annotation with non-existent app + with pytest.raises(NotFound, match="App not found"): + AppAnnotationService.delete_app_annotation(non_existent_app_id, annotation_id) + + def test_delete_app_annotation_annotation_not_found( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test deletion of app annotation when annotation is not found. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + non_existent_annotation_id = fake.uuid4() + + # Try to delete non-existent annotation + with pytest.raises(NotFound, match="Annotation not found"): + AppAnnotationService.delete_app_annotation(app.id, non_existent_annotation_id) + + def test_enable_app_annotation_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful enabling of app annotation. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Setup enable arguments + enable_args = { + "score_threshold": 0.8, + "embedding_provider_name": "openai", + "embedding_model_name": "text-embedding-ada-002", + } + + # Enable annotation + result = AppAnnotationService.enable_app_annotation(enable_args, app.id) + + # Verify result structure + assert "job_id" in result + assert "job_status" in result + assert result["job_status"] == "waiting" + assert result["job_id"] is not None + + # Verify task was called + mock_external_service_dependencies["enable_task"].delay.assert_called_once() + + def test_disable_app_annotation_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful disabling of app annotation. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Disable annotation + result = AppAnnotationService.disable_app_annotation(app.id) + + # Verify result structure + assert "job_id" in result + assert "job_status" in result + assert result["job_status"] == "waiting" + assert result["job_id"] is not None + + # Verify task was called + mock_external_service_dependencies["disable_task"].delay.assert_called_once() + + def test_enable_app_annotation_cached_job(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test enabling app annotation when job is already cached. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Mock Redis to return cached job + from extensions.ext_redis import redis_client + + cached_job_id = fake.uuid4() + enable_app_annotation_key = f"enable_app_annotation_{app.id}" + redis_client.set(enable_app_annotation_key, cached_job_id) + + # Setup enable arguments + enable_args = { + "score_threshold": 0.8, + "embedding_provider_name": "openai", + "embedding_model_name": "text-embedding-ada-002", + } + + # Enable annotation (should return cached job) + result = AppAnnotationService.enable_app_annotation(enable_args, app.id) + + # Verify cached result + assert cached_job_id == result["job_id"].decode("utf-8") + assert result["job_status"] == "processing" + + # Verify task was not called again + mock_external_service_dependencies["enable_task"].delay.assert_not_called() + + # Clean up + redis_client.delete(enable_app_annotation_key) + + def test_get_annotation_hit_histories_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful retrieval of annotation hit histories. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create an annotation first + annotation_args = { + "question": fake.sentence(), + "answer": fake.text(max_nb_chars=200), + } + annotation = AppAnnotationService.insert_app_annotation_directly(annotation_args, app.id) + + # Add some hit histories + for i in range(3): + AppAnnotationService.add_annotation_history( + annotation_id=annotation.id, + app_id=app.id, + annotation_question=annotation.question, + annotation_content=annotation.content, + query=f"Query {i}: {fake.sentence()}", + user_id=account.id, + message_id=fake.uuid4(), + from_source="console", + score=0.8 + (i * 0.1), + ) + + # Get hit histories + hit_histories, total = AppAnnotationService.get_annotation_hit_histories( + app.id, annotation.id, page=1, limit=10 + ) + + # Verify results + assert len(hit_histories) == 3 + assert total == 3 + + # Verify all histories belong to the correct annotation + for history in hit_histories: + assert history.annotation_id == annotation.id + assert history.app_id == app.id + assert history.account_id == account.id + + def test_add_annotation_history_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful addition of annotation history. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create an annotation first + annotation_args = { + "question": fake.sentence(), + "answer": fake.text(max_nb_chars=200), + } + annotation = AppAnnotationService.insert_app_annotation_directly(annotation_args, app.id) + + # Get initial hit count + initial_hit_count = annotation.hit_count + + # Add annotation history + query = fake.sentence() + message_id = fake.uuid4() + score = 0.85 + + AppAnnotationService.add_annotation_history( + annotation_id=annotation.id, + app_id=app.id, + annotation_question=annotation.question, + annotation_content=annotation.content, + query=query, + user_id=account.id, + message_id=message_id, + from_source="console", + score=score, + ) + + # Verify hit count was incremented + from extensions.ext_database import db + + db.session.refresh(annotation) + assert annotation.hit_count == initial_hit_count + 1 + + # Verify history was created + from models.model import AppAnnotationHitHistory + + history = ( + db.session.query(AppAnnotationHitHistory) + .filter( + AppAnnotationHitHistory.annotation_id == annotation.id, AppAnnotationHitHistory.message_id == message_id + ) + .first() + ) + + assert history is not None + assert history.app_id == app.id + assert history.account_id == account.id + assert history.question == query + assert history.score == score + assert history.source == "console" + + def test_get_annotation_by_id_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful retrieval of annotation by ID. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create an annotation + annotation_args = { + "question": fake.sentence(), + "answer": fake.text(max_nb_chars=200), + } + created_annotation = AppAnnotationService.insert_app_annotation_directly(annotation_args, app.id) + + # Get annotation by ID + retrieved_annotation = AppAnnotationService.get_annotation_by_id(created_annotation.id) + + # Verify annotation was retrieved correctly + assert retrieved_annotation is not None + assert retrieved_annotation.id == created_annotation.id + assert retrieved_annotation.app_id == app.id + assert retrieved_annotation.question == annotation_args["question"] + assert retrieved_annotation.content == annotation_args["answer"] + assert retrieved_annotation.account_id == account.id + + def test_batch_import_app_annotations_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful batch import of app annotations. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create CSV content + csv_content = "Question 1,Answer 1\nQuestion 2,Answer 2\nQuestion 3,Answer 3" + + # Mock FileStorage + from io import BytesIO + + from werkzeug.datastructures import FileStorage + + file_storage = FileStorage( + stream=BytesIO(csv_content.encode("utf-8")), filename="annotations.csv", content_type="text/csv" + ) + + mock_external_service_dependencies["feature_service"].get_features.return_value.billing.enabled = False + + # Mock pandas to return expected DataFrame + import pandas as pd + + with patch("services.annotation_service.pd") as mock_pd: + mock_df = pd.DataFrame( + {0: ["Question 1", "Question 2", "Question 3"], 1: ["Answer 1", "Answer 2", "Answer 3"]} + ) + mock_pd.read_csv.return_value = mock_df + + # Batch import annotations + result = AppAnnotationService.batch_import_app_annotations(app.id, file_storage) + + # Verify result structure + assert "job_id" in result + assert "job_status" in result + assert result["job_status"] == "waiting" + assert result["job_id"] is not None + + # Verify task was called + mock_external_service_dependencies["batch_import_task"].delay.assert_called_once() + + def test_batch_import_app_annotations_empty_file( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test batch import with empty CSV file. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create empty CSV content + csv_content = "" + + # Mock FileStorage + from io import BytesIO + + from werkzeug.datastructures import FileStorage + + file_storage = FileStorage( + stream=BytesIO(csv_content.encode("utf-8")), filename="annotations.csv", content_type="text/csv" + ) + + # Mock pandas to return empty DataFrame + import pandas as pd + + with patch("services.annotation_service.pd") as mock_pd: + mock_df = pd.DataFrame() + mock_pd.read_csv.return_value = mock_df + + # Batch import annotations + result = AppAnnotationService.batch_import_app_annotations(app.id, file_storage) + + # Verify error result + assert "error_msg" in result + assert "empty" in result["error_msg"].lower() + + def test_batch_import_app_annotations_quota_exceeded( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test batch import when quota is exceeded. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create CSV content + csv_content = "Question 1,Answer 1\nQuestion 2,Answer 2\nQuestion 3,Answer 3" + + # Mock FileStorage + from io import BytesIO + + from werkzeug.datastructures import FileStorage + + file_storage = FileStorage( + stream=BytesIO(csv_content.encode("utf-8")), filename="annotations.csv", content_type="text/csv" + ) + + # Mock pandas to return DataFrame + import pandas as pd + + with patch("services.annotation_service.pd") as mock_pd: + mock_df = pd.DataFrame( + {0: ["Question 1", "Question 2", "Question 3"], 1: ["Answer 1", "Answer 2", "Answer 3"]} + ) + mock_pd.read_csv.return_value = mock_df + + # Mock FeatureService to return billing enabled with quota exceeded + mock_external_service_dependencies["feature_service"].get_features.return_value.billing.enabled = True + mock_external_service_dependencies[ + "feature_service" + ].get_features.return_value.annotation_quota_limit.limit = 1 + mock_external_service_dependencies[ + "feature_service" + ].get_features.return_value.annotation_quota_limit.size = 0 + + # Batch import annotations + result = AppAnnotationService.batch_import_app_annotations(app.id, file_storage) + + # Verify error result + assert "error_msg" in result + assert "limit" in result["error_msg"].lower() + + def test_get_app_annotation_setting_by_app_id_enabled( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test getting enabled app annotation setting by app ID. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create annotation setting + from extensions.ext_database import db + from models.dataset import DatasetCollectionBinding + from models.model import AppAnnotationSetting + + # Create a collection binding first + collection_binding = DatasetCollectionBinding() + collection_binding.id = fake.uuid4() + collection_binding.provider_name = "openai" + collection_binding.model_name = "text-embedding-ada-002" + collection_binding.type = "annotation" + collection_binding.collection_name = f"annotation_collection_{fake.uuid4()}" + db.session.add(collection_binding) + db.session.flush() + + # Create annotation setting + annotation_setting = AppAnnotationSetting() + annotation_setting.app_id = app.id + annotation_setting.score_threshold = 0.8 + annotation_setting.collection_binding_id = collection_binding.id + annotation_setting.created_user_id = account.id + annotation_setting.updated_user_id = account.id + db.session.add(annotation_setting) + db.session.commit() + + # Get annotation setting + result = AppAnnotationService.get_app_annotation_setting_by_app_id(app.id) + + # Verify result structure + assert result["enabled"] is True + assert result["id"] == annotation_setting.id + assert result["score_threshold"] == 0.8 + assert result["embedding_model"]["embedding_provider_name"] == "openai" + assert result["embedding_model"]["embedding_model_name"] == "text-embedding-ada-002" + + def test_get_app_annotation_setting_by_app_id_disabled( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test getting disabled app annotation setting by app ID. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Get annotation setting (no setting exists) + result = AppAnnotationService.get_app_annotation_setting_by_app_id(app.id) + + # Verify result structure + assert result["enabled"] is False + + def test_update_app_annotation_setting_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test successful update of app annotation setting. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create annotation setting first + from extensions.ext_database import db + from models.dataset import DatasetCollectionBinding + from models.model import AppAnnotationSetting + + # Create a collection binding first + collection_binding = DatasetCollectionBinding() + collection_binding.id = fake.uuid4() + collection_binding.provider_name = "openai" + collection_binding.model_name = "text-embedding-ada-002" + collection_binding.type = "annotation" + collection_binding.collection_name = f"annotation_collection_{fake.uuid4()}" + db.session.add(collection_binding) + db.session.flush() + + # Create annotation setting + annotation_setting = AppAnnotationSetting() + annotation_setting.app_id = app.id + annotation_setting.score_threshold = 0.8 + annotation_setting.collection_binding_id = collection_binding.id + annotation_setting.created_user_id = account.id + annotation_setting.updated_user_id = account.id + db.session.add(annotation_setting) + db.session.commit() + + # Update annotation setting + update_args = { + "score_threshold": 0.9, + } + + result = AppAnnotationService.update_app_annotation_setting(app.id, annotation_setting.id, update_args) + + # Verify result structure + assert result["enabled"] is True + assert result["id"] == annotation_setting.id + assert result["score_threshold"] == 0.9 + assert result["embedding_model"]["embedding_provider_name"] == "openai" + assert result["embedding_model"]["embedding_model_name"] == "text-embedding-ada-002" + + # Verify database was updated + db.session.refresh(annotation_setting) + assert annotation_setting.score_threshold == 0.9 + + def test_export_annotation_list_by_app_id_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test successful export of annotation list by app ID. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create multiple annotations + annotations = [] + for i in range(3): + annotation_args = { + "question": f"Question {i}: {fake.sentence()}", + "answer": f"Answer {i}: {fake.text(max_nb_chars=200)}", + } + annotation = AppAnnotationService.insert_app_annotation_directly(annotation_args, app.id) + annotations.append(annotation) + + # Export annotation list + exported_annotations = AppAnnotationService.export_annotation_list_by_app_id(app.id) + + # Verify results + assert len(exported_annotations) == 3 + + # Verify all annotations belong to the correct app and are ordered by created_at desc + for i, annotation in enumerate(exported_annotations): + assert annotation.app_id == app.id + assert annotation.account_id == account.id + if i > 0: + # Verify descending order (newer first) + assert annotation.created_at <= exported_annotations[i - 1].created_at + + def test_export_annotation_list_by_app_id_app_not_found( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test export of annotation list when app is not found. + """ + fake = Faker() + non_existent_app_id = fake.uuid4() + + # Mock random current user to avoid dependency issues + self._mock_current_user(mock_external_service_dependencies, fake.uuid4(), fake.uuid4()) + + # Try to export annotation list with non-existent app + with pytest.raises(NotFound, match="App not found"): + AppAnnotationService.export_annotation_list_by_app_id(non_existent_app_id) + + def test_insert_app_annotation_directly_with_setting_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test successful direct insertion of app annotation with annotation setting enabled. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create annotation setting first + from extensions.ext_database import db + from models.dataset import DatasetCollectionBinding + from models.model import AppAnnotationSetting + + # Create a collection binding first + collection_binding = DatasetCollectionBinding() + collection_binding.id = fake.uuid4() + collection_binding.provider_name = "openai" + collection_binding.model_name = "text-embedding-ada-002" + collection_binding.type = "annotation" + collection_binding.collection_name = f"annotation_collection_{fake.uuid4()}" + db.session.add(collection_binding) + db.session.flush() + + # Create annotation setting + annotation_setting = AppAnnotationSetting() + annotation_setting.app_id = app.id + annotation_setting.score_threshold = 0.8 + annotation_setting.collection_binding_id = collection_binding.id + annotation_setting.created_user_id = account.id + annotation_setting.updated_user_id = account.id + db.session.add(annotation_setting) + db.session.commit() + + # Setup annotation data + annotation_args = { + "question": fake.sentence(), + "answer": fake.text(max_nb_chars=200), + } + + # Insert annotation directly + annotation = AppAnnotationService.insert_app_annotation_directly(annotation_args, app.id) + + # Verify annotation was created correctly + assert annotation.app_id == app.id + assert annotation.question == annotation_args["question"] + assert annotation.content == annotation_args["answer"] + assert annotation.account_id == account.id + assert annotation.hit_count == 0 + assert annotation.id is not None + + # Verify add_annotation_to_index_task was called + mock_external_service_dependencies["add_task"].delay.assert_called_once() + call_args = mock_external_service_dependencies["add_task"].delay.call_args[0] + assert call_args[0] == annotation.id # annotation_id + assert call_args[1] == annotation_args["question"] # question + assert call_args[2] == account.current_tenant_id # tenant_id + assert call_args[3] == app.id # app_id + assert call_args[4] == collection_binding.id # collection_binding_id + + def test_update_app_annotation_directly_with_setting_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test successful direct update of app annotation with annotation setting enabled. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create annotation setting first + from extensions.ext_database import db + from models.dataset import DatasetCollectionBinding + from models.model import AppAnnotationSetting + + # Create a collection binding first + collection_binding = DatasetCollectionBinding() + collection_binding.id = fake.uuid4() + collection_binding.provider_name = "openai" + collection_binding.model_name = "text-embedding-ada-002" + collection_binding.type = "annotation" + collection_binding.collection_name = f"annotation_collection_{fake.uuid4()}" + db.session.add(collection_binding) + db.session.flush() + + # Create annotation setting + annotation_setting = AppAnnotationSetting() + annotation_setting.app_id = app.id + annotation_setting.score_threshold = 0.8 + annotation_setting.collection_binding_id = collection_binding.id + annotation_setting.created_user_id = account.id + annotation_setting.updated_user_id = account.id + db.session.add(annotation_setting) + db.session.commit() + + # First, create an annotation + original_args = { + "question": fake.sentence(), + "answer": fake.text(max_nb_chars=200), + } + annotation = AppAnnotationService.insert_app_annotation_directly(original_args, app.id) + + # Reset mock to clear previous calls + mock_external_service_dependencies["update_task"].delay.reset_mock() + + # Update the annotation + updated_args = { + "question": fake.sentence(), + "answer": fake.text(max_nb_chars=200), + } + updated_annotation = AppAnnotationService.update_app_annotation_directly(updated_args, app.id, annotation.id) + + # Verify annotation was updated correctly + assert updated_annotation.id == annotation.id + assert updated_annotation.app_id == app.id + assert updated_annotation.question == updated_args["question"] + assert updated_annotation.content == updated_args["answer"] + assert updated_annotation.account_id == account.id + + # Verify original values were changed + assert updated_annotation.question != original_args["question"] + assert updated_annotation.content != original_args["answer"] + + # Verify update_annotation_to_index_task was called + mock_external_service_dependencies["update_task"].delay.assert_called_once() + call_args = mock_external_service_dependencies["update_task"].delay.call_args[0] + assert call_args[0] == annotation.id # annotation_id + assert call_args[1] == updated_args["question"] # question + assert call_args[2] == account.current_tenant_id # tenant_id + assert call_args[3] == app.id # app_id + assert call_args[4] == collection_binding.id # collection_binding_id + + def test_delete_app_annotation_with_setting_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test successful deletion of app annotation with annotation setting enabled. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create annotation setting first + from extensions.ext_database import db + from models.dataset import DatasetCollectionBinding + from models.model import AppAnnotationSetting + + # Create a collection binding first + collection_binding = DatasetCollectionBinding() + collection_binding.id = fake.uuid4() + collection_binding.provider_name = "openai" + collection_binding.model_name = "text-embedding-ada-002" + collection_binding.type = "annotation" + collection_binding.collection_name = f"annotation_collection_{fake.uuid4()}" + db.session.add(collection_binding) + db.session.flush() + + # Create annotation setting + annotation_setting = AppAnnotationSetting() + annotation_setting.app_id = app.id + annotation_setting.score_threshold = 0.8 + annotation_setting.collection_binding_id = collection_binding.id + annotation_setting.created_user_id = account.id + annotation_setting.updated_user_id = account.id + db.session.add(annotation_setting) + db.session.commit() + + # Create an annotation first + annotation_args = { + "question": fake.sentence(), + "answer": fake.text(max_nb_chars=200), + } + annotation = AppAnnotationService.insert_app_annotation_directly(annotation_args, app.id) + annotation_id = annotation.id + + # Reset mock to clear previous calls + mock_external_service_dependencies["delete_task"].delay.reset_mock() + + # Delete the annotation + AppAnnotationService.delete_app_annotation(app.id, annotation_id) + + # Verify annotation was deleted + deleted_annotation = db.session.query(MessageAnnotation).where(MessageAnnotation.id == annotation_id).first() + assert deleted_annotation is None + + # Verify delete_annotation_index_task was called + mock_external_service_dependencies["delete_task"].delay.assert_called_once() + call_args = mock_external_service_dependencies["delete_task"].delay.call_args[0] + assert call_args[0] == annotation_id # annotation_id + assert call_args[1] == app.id # app_id + assert call_args[2] == account.current_tenant_id # tenant_id + assert call_args[3] == collection_binding.id # collection_binding_id + + def test_up_insert_app_annotation_from_message_with_setting_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test creating annotation from message with annotation setting enabled. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create annotation setting first + from extensions.ext_database import db + from models.dataset import DatasetCollectionBinding + from models.model import AppAnnotationSetting + + # Create a collection binding first + collection_binding = DatasetCollectionBinding() + collection_binding.id = fake.uuid4() + collection_binding.provider_name = "openai" + collection_binding.model_name = "text-embedding-ada-002" + collection_binding.type = "annotation" + collection_binding.collection_name = f"annotation_collection_{fake.uuid4()}" + db.session.add(collection_binding) + db.session.flush() + + # Create annotation setting + annotation_setting = AppAnnotationSetting() + annotation_setting.app_id = app.id + annotation_setting.score_threshold = 0.8 + annotation_setting.collection_binding_id = collection_binding.id + annotation_setting.created_user_id = account.id + annotation_setting.updated_user_id = account.id + db.session.add(annotation_setting) + db.session.commit() + + # Create a conversation and message first + conversation = self._create_test_conversation(app, account, fake) + message = self._create_test_message(app, conversation, account, fake) + + # Setup annotation data with message_id + annotation_args = { + "message_id": message.id, + "question": fake.sentence(), + "answer": fake.text(max_nb_chars=200), + } + + # Insert annotation from message + annotation = AppAnnotationService.up_insert_app_annotation_from_message(annotation_args, app.id) + + # Verify annotation was created correctly + assert annotation.app_id == app.id + assert annotation.conversation_id == conversation.id + assert annotation.message_id == message.id + assert annotation.question == annotation_args["question"] + assert annotation.content == annotation_args["answer"] + assert annotation.account_id == account.id + + # Verify add_annotation_to_index_task was called + mock_external_service_dependencies["add_task"].delay.assert_called_once() + call_args = mock_external_service_dependencies["add_task"].delay.call_args[0] + assert call_args[0] == annotation.id # annotation_id + assert call_args[1] == annotation_args["question"] # question + assert call_args[2] == account.current_tenant_id # tenant_id + assert call_args[3] == app.id # app_id + assert call_args[4] == collection_binding.id # collection_binding_id diff --git a/api/tests/test_containers_integration_tests/services/test_api_based_extension_service.py b/api/tests/test_containers_integration_tests/services/test_api_based_extension_service.py new file mode 100644 index 0000000000..6cd8337ff9 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_api_based_extension_service.py @@ -0,0 +1,487 @@ +from unittest.mock import patch + +import pytest +from faker import Faker + +from models.api_based_extension import APIBasedExtension +from services.account_service import AccountService, TenantService +from services.api_based_extension_service import APIBasedExtensionService + + +class TestAPIBasedExtensionService: + """Integration tests for APIBasedExtensionService using testcontainers.""" + + @pytest.fixture + def mock_external_service_dependencies(self): + """Mock setup for external service dependencies.""" + with ( + patch("services.account_service.FeatureService") as mock_account_feature_service, + patch("services.api_based_extension_service.APIBasedExtensionRequestor") as mock_requestor, + ): + # Setup default mock returns + mock_account_feature_service.get_features.return_value.billing.enabled = False + + # Mock successful ping response + mock_requestor_instance = mock_requestor.return_value + mock_requestor_instance.request.return_value = {"result": "pong"} + + yield { + "account_feature_service": mock_account_feature_service, + "requestor": mock_requestor, + "requestor_instance": mock_requestor_instance, + } + + def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies): + """ + Helper method to create a test account and tenant for testing. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + mock_external_service_dependencies: Mock dependencies + + Returns: + tuple: (account, tenant) - Created account and tenant instances + """ + fake = Faker() + + # Setup mocks for account creation + mock_external_service_dependencies[ + "account_feature_service" + ].get_system_features.return_value.is_allow_register = True + + # Create account and tenant + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + return account, tenant + + def test_save_extension_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful saving of API-based extension. + """ + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Setup extension data + extension_data = APIBasedExtension() + extension_data.tenant_id = tenant.id + extension_data.name = fake.company() + extension_data.api_endpoint = f"https://{fake.domain_name()}/api" + extension_data.api_key = fake.password(length=20) + + # Save extension + saved_extension = APIBasedExtensionService.save(extension_data) + + # Verify extension was saved correctly + assert saved_extension.id is not None + assert saved_extension.tenant_id == tenant.id + assert saved_extension.name == extension_data.name + assert saved_extension.api_endpoint == extension_data.api_endpoint + assert saved_extension.api_key == extension_data.api_key # Should be decrypted when retrieved + assert saved_extension.created_at is not None + + # Verify extension was saved to database + from extensions.ext_database import db + + db.session.refresh(saved_extension) + assert saved_extension.id is not None + + # Verify ping connection was called + mock_external_service_dependencies["requestor_instance"].request.assert_called_once() + + def test_save_extension_validation_errors(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test validation errors when saving extension with invalid data. + """ + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Test empty name + extension_data = APIBasedExtension() + extension_data.tenant_id = tenant.id + extension_data.name = "" + extension_data.api_endpoint = f"https://{fake.domain_name()}/api" + extension_data.api_key = fake.password(length=20) + + with pytest.raises(ValueError, match="name must not be empty"): + APIBasedExtensionService.save(extension_data) + + # Test empty api_endpoint + extension_data.name = fake.company() + extension_data.api_endpoint = "" + + with pytest.raises(ValueError, match="api_endpoint must not be empty"): + APIBasedExtensionService.save(extension_data) + + # Test empty api_key + extension_data.api_endpoint = f"https://{fake.domain_name()}/api" + extension_data.api_key = "" + + with pytest.raises(ValueError, match="api_key must not be empty"): + APIBasedExtensionService.save(extension_data) + + def test_get_all_by_tenant_id_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful retrieval of all extensions by tenant ID. + """ + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create multiple extensions + extensions = [] + for i in range(3): + extension_data = APIBasedExtension() + extension_data.tenant_id = tenant.id + extension_data.name = f"Extension {i}: {fake.company()}" + extension_data.api_endpoint = f"https://{fake.domain_name()}/api" + extension_data.api_key = fake.password(length=20) + + saved_extension = APIBasedExtensionService.save(extension_data) + extensions.append(saved_extension) + + # Get all extensions for tenant + extension_list = APIBasedExtensionService.get_all_by_tenant_id(tenant.id) + + # Verify results + assert len(extension_list) == 3 + + # Verify all extensions belong to the correct tenant and are ordered by created_at desc + for i, extension in enumerate(extension_list): + assert extension.tenant_id == tenant.id + assert extension.api_key is not None # Should be decrypted + if i > 0: + # Verify descending order (newer first) + assert extension.created_at <= extension_list[i - 1].created_at + + def test_get_with_tenant_id_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful retrieval of extension by tenant ID and extension ID. + """ + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create an extension + extension_data = APIBasedExtension() + extension_data.tenant_id = tenant.id + extension_data.name = fake.company() + extension_data.api_endpoint = f"https://{fake.domain_name()}/api" + extension_data.api_key = fake.password(length=20) + + created_extension = APIBasedExtensionService.save(extension_data) + + # Get extension by ID + retrieved_extension = APIBasedExtensionService.get_with_tenant_id(tenant.id, created_extension.id) + + # Verify extension was retrieved correctly + assert retrieved_extension is not None + assert retrieved_extension.id == created_extension.id + assert retrieved_extension.tenant_id == tenant.id + assert retrieved_extension.name == extension_data.name + assert retrieved_extension.api_endpoint == extension_data.api_endpoint + assert retrieved_extension.api_key == extension_data.api_key # Should be decrypted + assert retrieved_extension.created_at is not None + + def test_get_with_tenant_id_not_found(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test retrieval of extension when extension is not found. + """ + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + non_existent_extension_id = fake.uuid4() + + # Try to get non-existent extension + with pytest.raises(ValueError, match="API based extension is not found"): + APIBasedExtensionService.get_with_tenant_id(tenant.id, non_existent_extension_id) + + def test_delete_extension_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful deletion of extension. + """ + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create an extension first + extension_data = APIBasedExtension() + extension_data.tenant_id = tenant.id + extension_data.name = fake.company() + extension_data.api_endpoint = f"https://{fake.domain_name()}/api" + extension_data.api_key = fake.password(length=20) + + created_extension = APIBasedExtensionService.save(extension_data) + extension_id = created_extension.id + + # Delete the extension + APIBasedExtensionService.delete(created_extension) + + # Verify extension was deleted + from extensions.ext_database import db + + deleted_extension = db.session.query(APIBasedExtension).where(APIBasedExtension.id == extension_id).first() + assert deleted_extension is None + + def test_save_extension_duplicate_name(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test validation error when saving extension with duplicate name. + """ + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create first extension + extension_data1 = APIBasedExtension() + extension_data1.tenant_id = tenant.id + extension_data1.name = "Test Extension" + extension_data1.api_endpoint = f"https://{fake.domain_name()}/api" + extension_data1.api_key = fake.password(length=20) + + APIBasedExtensionService.save(extension_data1) + + # Try to create second extension with same name + extension_data2 = APIBasedExtension() + extension_data2.tenant_id = tenant.id + extension_data2.name = "Test Extension" # Same name + extension_data2.api_endpoint = f"https://{fake.domain_name()}/api" + extension_data2.api_key = fake.password(length=20) + + with pytest.raises(ValueError, match="name must be unique, it is already existed"): + APIBasedExtensionService.save(extension_data2) + + def test_save_extension_update_existing(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful update of existing extension. + """ + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create initial extension + extension_data = APIBasedExtension() + extension_data.tenant_id = tenant.id + extension_data.name = fake.company() + extension_data.api_endpoint = f"https://{fake.domain_name()}/api" + extension_data.api_key = fake.password(length=20) + + created_extension = APIBasedExtensionService.save(extension_data) + + # Save original values for later comparison + original_name = created_extension.name + original_endpoint = created_extension.api_endpoint + + # Update the extension + new_name = fake.company() + new_endpoint = f"https://{fake.domain_name()}/api" + new_api_key = fake.password(length=20) + + created_extension.name = new_name + created_extension.api_endpoint = new_endpoint + created_extension.api_key = new_api_key + + updated_extension = APIBasedExtensionService.save(created_extension) + + # Verify extension was updated correctly + assert updated_extension.id == created_extension.id + assert updated_extension.tenant_id == tenant.id + assert updated_extension.name == new_name + assert updated_extension.api_endpoint == new_endpoint + + # Verify original values were changed + assert updated_extension.name != original_name + assert updated_extension.api_endpoint != original_endpoint + + # Verify ping connection was called for both create and update + assert mock_external_service_dependencies["requestor_instance"].request.call_count == 2 + + # Verify the update by retrieving the extension again + retrieved_extension = APIBasedExtensionService.get_with_tenant_id(tenant.id, created_extension.id) + assert retrieved_extension.name == new_name + assert retrieved_extension.api_endpoint == new_endpoint + assert retrieved_extension.api_key == new_api_key # Should be decrypted when retrieved + + def test_save_extension_connection_error(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test connection error when saving extension with invalid endpoint. + """ + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Mock connection error + mock_external_service_dependencies["requestor_instance"].request.side_effect = ValueError( + "connection error: request timeout" + ) + + # Setup extension data + extension_data = APIBasedExtension() + extension_data.tenant_id = tenant.id + extension_data.name = fake.company() + extension_data.api_endpoint = "https://invalid-endpoint.com/api" + extension_data.api_key = fake.password(length=20) + + # Try to save extension with connection error + with pytest.raises(ValueError, match="connection error: request timeout"): + APIBasedExtensionService.save(extension_data) + + def test_save_extension_invalid_api_key_length( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test validation error when saving extension with API key that is too short. + """ + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Setup extension data with short API key + extension_data = APIBasedExtension() + extension_data.tenant_id = tenant.id + extension_data.name = fake.company() + extension_data.api_endpoint = f"https://{fake.domain_name()}/api" + extension_data.api_key = "1234" # Less than 5 characters + + # Try to save extension with short API key + with pytest.raises(ValueError, match="api_key must be at least 5 characters"): + APIBasedExtensionService.save(extension_data) + + def test_save_extension_empty_fields(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test validation errors when saving extension with empty required fields. + """ + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Test with None values + extension_data = APIBasedExtension() + extension_data.tenant_id = tenant.id + extension_data.name = None + extension_data.api_endpoint = f"https://{fake.domain_name()}/api" + extension_data.api_key = fake.password(length=20) + + with pytest.raises(ValueError, match="name must not be empty"): + APIBasedExtensionService.save(extension_data) + + # Test with None api_endpoint + extension_data.name = fake.company() + extension_data.api_endpoint = None + + with pytest.raises(ValueError, match="api_endpoint must not be empty"): + APIBasedExtensionService.save(extension_data) + + # Test with None api_key + extension_data.api_endpoint = f"https://{fake.domain_name()}/api" + extension_data.api_key = None + + with pytest.raises(ValueError, match="api_key must not be empty"): + APIBasedExtensionService.save(extension_data) + + def test_get_all_by_tenant_id_empty_list(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test retrieval of extensions when no extensions exist for tenant. + """ + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Get all extensions for tenant (none exist) + extension_list = APIBasedExtensionService.get_all_by_tenant_id(tenant.id) + + # Verify empty list is returned + assert len(extension_list) == 0 + assert extension_list == [] + + def test_save_extension_invalid_ping_response(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test validation error when ping response is invalid. + """ + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Mock invalid ping response + mock_external_service_dependencies["requestor_instance"].request.return_value = {"result": "invalid"} + + # Setup extension data + extension_data = APIBasedExtension() + extension_data.tenant_id = tenant.id + extension_data.name = fake.company() + extension_data.api_endpoint = f"https://{fake.domain_name()}/api" + extension_data.api_key = fake.password(length=20) + + # Try to save extension with invalid ping response + with pytest.raises(ValueError, match="{'result': 'invalid'}"): + APIBasedExtensionService.save(extension_data) + + def test_save_extension_missing_ping_result(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test validation error when ping response is missing result field. + """ + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Mock ping response without result field + mock_external_service_dependencies["requestor_instance"].request.return_value = {"status": "ok"} + + # Setup extension data + extension_data = APIBasedExtension() + extension_data.tenant_id = tenant.id + extension_data.name = fake.company() + extension_data.api_endpoint = f"https://{fake.domain_name()}/api" + extension_data.api_key = fake.password(length=20) + + # Try to save extension with missing ping result + with pytest.raises(ValueError, match="{'status': 'ok'}"): + APIBasedExtensionService.save(extension_data) + + def test_get_with_tenant_id_wrong_tenant(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test retrieval of extension when tenant ID doesn't match. + """ + fake = Faker() + account1, tenant1 = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create second account and tenant + account2, tenant2 = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create extension in first tenant + extension_data = APIBasedExtension() + extension_data.tenant_id = tenant1.id + extension_data.name = fake.company() + extension_data.api_endpoint = f"https://{fake.domain_name()}/api" + extension_data.api_key = fake.password(length=20) + + created_extension = APIBasedExtensionService.save(extension_data) + + # Try to get extension with wrong tenant ID + with pytest.raises(ValueError, match="API based extension is not found"): + APIBasedExtensionService.get_with_tenant_id(tenant2.id, created_extension.id) diff --git a/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py b/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py new file mode 100644 index 0000000000..f2bd9f8084 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py @@ -0,0 +1,473 @@ +import json +from unittest.mock import MagicMock, patch + +import pytest +import yaml +from faker import Faker + +from models.model import App, AppModelConfig +from services.account_service import AccountService, TenantService +from services.app_dsl_service import AppDslService, ImportMode, ImportStatus +from services.app_service import AppService + + +class TestAppDslService: + """Integration tests for AppDslService using testcontainers.""" + + @pytest.fixture + def mock_external_service_dependencies(self): + """Mock setup for external service dependencies.""" + with ( + patch("services.app_dsl_service.WorkflowService") as mock_workflow_service, + patch("services.app_dsl_service.DependenciesAnalysisService") as mock_dependencies_service, + patch("services.app_dsl_service.WorkflowDraftVariableService") as mock_draft_variable_service, + patch("services.app_dsl_service.ssrf_proxy") as mock_ssrf_proxy, + patch("services.app_dsl_service.redis_client") as mock_redis_client, + patch("services.app_dsl_service.app_was_created") as mock_app_was_created, + patch("services.app_dsl_service.app_model_config_was_updated") as mock_app_model_config_was_updated, + patch("services.app_service.ModelManager") as mock_model_manager, + patch("services.app_service.FeatureService") as mock_feature_service, + patch("services.app_service.EnterpriseService") as mock_enterprise_service, + ): + # Setup default mock returns + mock_workflow_service.return_value.get_draft_workflow.return_value = None + mock_workflow_service.return_value.sync_draft_workflow.return_value = MagicMock() + mock_dependencies_service.generate_latest_dependencies.return_value = [] + mock_dependencies_service.get_leaked_dependencies.return_value = [] + mock_dependencies_service.generate_dependencies.return_value = [] + mock_draft_variable_service.return_value.delete_workflow_variables.return_value = None + mock_ssrf_proxy.get.return_value.content = b"test content" + mock_ssrf_proxy.get.return_value.raise_for_status.return_value = None + mock_redis_client.setex.return_value = None + mock_redis_client.get.return_value = None + mock_redis_client.delete.return_value = None + mock_app_was_created.send.return_value = None + mock_app_model_config_was_updated.send.return_value = None + + # Mock ModelManager for app service + mock_model_instance = mock_model_manager.return_value + mock_model_instance.get_default_model_instance.return_value = None + mock_model_instance.get_default_provider_model_name.return_value = ("openai", "gpt-3.5-turbo") + + # Mock FeatureService and EnterpriseService + mock_feature_service.get_system_features.return_value.webapp_auth.enabled = False + mock_enterprise_service.WebAppAuth.update_app_access_mode.return_value = None + mock_enterprise_service.WebAppAuth.cleanup_webapp.return_value = None + + yield { + "workflow_service": mock_workflow_service, + "dependencies_service": mock_dependencies_service, + "draft_variable_service": mock_draft_variable_service, + "ssrf_proxy": mock_ssrf_proxy, + "redis_client": mock_redis_client, + "app_was_created": mock_app_was_created, + "app_model_config_was_updated": mock_app_model_config_was_updated, + "model_manager": mock_model_manager, + "feature_service": mock_feature_service, + "enterprise_service": mock_enterprise_service, + } + + def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies): + """ + Helper method to create a test app and account for testing. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + mock_external_service_dependencies: Mock dependencies + + Returns: + tuple: (app, account) - Created app and account instances + """ + fake = Faker() + + # Setup mocks for account creation + with patch("services.account_service.FeatureService") as mock_account_feature_service: + mock_account_feature_service.get_system_features.return_value.is_allow_register = True + + # Create account and tenant first + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Setup app creation arguments + app_args = { + "name": fake.company(), + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🤖", + "icon_background": "#FF6B6B", + "api_rph": 100, + "api_rpm": 10, + } + + # Create app + app_service = AppService() + app = app_service.create_app(tenant.id, app_args, account) + + return app, account + + def _create_simple_yaml_content(self, app_name="Test App", app_mode="chat"): + """ + Helper method to create simple YAML content for testing. + """ + yaml_data = { + "version": "0.3.0", + "kind": "app", + "app": { + "name": app_name, + "mode": app_mode, + "icon": "🤖", + "icon_background": "#FFEAD5", + "description": "Test app description", + "use_icon_as_answer_icon": False, + }, + "model_config": { + "model": { + "provider": "openai", + "name": "gpt-3.5-turbo", + "mode": "chat", + "completion_params": { + "max_tokens": 1000, + "temperature": 0.7, + "top_p": 1.0, + }, + }, + "pre_prompt": "You are a helpful assistant.", + "prompt_type": "simple", + }, + } + return yaml.dump(yaml_data, allow_unicode=True) + + def test_import_app_yaml_content_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful app import from YAML content. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create YAML content + yaml_content = self._create_simple_yaml_content(fake.company(), "chat") + + # Import app + dsl_service = AppDslService(db_session_with_containers) + result = dsl_service.import_app( + account=account, + import_mode=ImportMode.YAML_CONTENT, + yaml_content=yaml_content, + name="Imported App", + description="Imported app description", + ) + + # Verify import result + assert result.status == ImportStatus.COMPLETED + assert result.app_id is not None + assert result.app_mode == "chat" + assert result.imported_dsl_version == "0.3.0" + assert result.error == "" + + # Verify app was created in database + imported_app = db_session_with_containers.query(App).filter(App.id == result.app_id).first() + assert imported_app is not None + assert imported_app.name == "Imported App" + assert imported_app.description == "Imported app description" + assert imported_app.mode == "chat" + assert imported_app.tenant_id == account.current_tenant_id + assert imported_app.created_by == account.id + + # Verify model config was created + model_config = ( + db_session_with_containers.query(AppModelConfig).filter(AppModelConfig.app_id == result.app_id).first() + ) + assert model_config is not None + # The provider and model_id are stored in the model field as JSON + model_dict = model_config.model_dict + assert model_dict["provider"] == "openai" + assert model_dict["name"] == "gpt-3.5-turbo" + + def test_import_app_yaml_url_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful app import from YAML URL. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create YAML content for mock response + yaml_content = self._create_simple_yaml_content(fake.company(), "chat") + + # Setup mock response + mock_response = MagicMock() + mock_response.content = yaml_content.encode("utf-8") + mock_response.raise_for_status.return_value = None + mock_external_service_dependencies["ssrf_proxy"].get.return_value = mock_response + + # Import app from URL + dsl_service = AppDslService(db_session_with_containers) + result = dsl_service.import_app( + account=account, + import_mode=ImportMode.YAML_URL, + yaml_url="https://example.com/app.yaml", + name="URL Imported App", + description="App imported from URL", + ) + + # Verify import result + assert result.status == ImportStatus.COMPLETED + assert result.app_id is not None + assert result.app_mode == "chat" + assert result.imported_dsl_version == "0.3.0" + assert result.error == "" + + # Verify app was created in database + imported_app = db_session_with_containers.query(App).filter(App.id == result.app_id).first() + assert imported_app is not None + assert imported_app.name == "URL Imported App" + assert imported_app.description == "App imported from URL" + assert imported_app.mode == "chat" + assert imported_app.tenant_id == account.current_tenant_id + + # Verify ssrf_proxy was called + mock_external_service_dependencies["ssrf_proxy"].get.assert_called_once_with( + "https://example.com/app.yaml", follow_redirects=True, timeout=(10, 10) + ) + + def test_import_app_invalid_yaml_format(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test app import with invalid YAML format. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create invalid YAML content + invalid_yaml = "invalid: yaml: content: [" + + # Import app with invalid YAML + dsl_service = AppDslService(db_session_with_containers) + result = dsl_service.import_app( + account=account, + import_mode=ImportMode.YAML_CONTENT, + yaml_content=invalid_yaml, + name="Invalid App", + ) + + # Verify import failed + assert result.status == ImportStatus.FAILED + assert result.app_id is None + assert "Invalid YAML format" in result.error + assert result.imported_dsl_version == "" + + # Verify no app was created in database + apps_count = db_session_with_containers.query(App).filter(App.tenant_id == account.current_tenant_id).count() + assert apps_count == 1 # Only the original test app + + def test_import_app_missing_yaml_content(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test app import with missing YAML content. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Import app without YAML content + dsl_service = AppDslService(db_session_with_containers) + result = dsl_service.import_app( + account=account, + import_mode=ImportMode.YAML_CONTENT, + name="Missing Content App", + ) + + # Verify import failed + assert result.status == ImportStatus.FAILED + assert result.app_id is None + assert "yaml_content is required" in result.error + assert result.imported_dsl_version == "" + + # Verify no app was created in database + apps_count = db_session_with_containers.query(App).filter(App.tenant_id == account.current_tenant_id).count() + assert apps_count == 1 # Only the original test app + + def test_import_app_missing_yaml_url(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test app import with missing YAML URL. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Import app without YAML URL + dsl_service = AppDslService(db_session_with_containers) + result = dsl_service.import_app( + account=account, + import_mode=ImportMode.YAML_URL, + name="Missing URL App", + ) + + # Verify import failed + assert result.status == ImportStatus.FAILED + assert result.app_id is None + assert "yaml_url is required" in result.error + assert result.imported_dsl_version == "" + + # Verify no app was created in database + apps_count = db_session_with_containers.query(App).filter(App.tenant_id == account.current_tenant_id).count() + assert apps_count == 1 # Only the original test app + + def test_import_app_invalid_import_mode(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test app import with invalid import mode. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create YAML content + yaml_content = self._create_simple_yaml_content(fake.company(), "chat") + + # Import app with invalid mode should raise ValueError + dsl_service = AppDslService(db_session_with_containers) + with pytest.raises(ValueError, match="Invalid import_mode: invalid-mode"): + dsl_service.import_app( + account=account, + import_mode="invalid-mode", + yaml_content=yaml_content, + name="Invalid Mode App", + ) + + # Verify no app was created in database + apps_count = db_session_with_containers.query(App).filter(App.tenant_id == account.current_tenant_id).count() + assert apps_count == 1 # Only the original test app + + def test_export_dsl_chat_app_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful DSL export for chat app. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create model config for the app + model_config = AppModelConfig() + model_config.id = fake.uuid4() + model_config.app_id = app.id + model_config.provider = "openai" + model_config.model_id = "gpt-3.5-turbo" + model_config.model = json.dumps( + { + "provider": "openai", + "name": "gpt-3.5-turbo", + "mode": "chat", + "completion_params": { + "max_tokens": 1000, + "temperature": 0.7, + }, + } + ) + model_config.pre_prompt = "You are a helpful assistant." + model_config.prompt_type = "simple" + model_config.created_by = account.id + model_config.updated_by = account.id + + # Set the app_model_config_id to link the config + app.app_model_config_id = model_config.id + + db_session_with_containers.add(model_config) + db_session_with_containers.commit() + + # Export DSL + exported_dsl = AppDslService.export_dsl(app, include_secret=False) + + # Parse exported YAML + exported_data = yaml.safe_load(exported_dsl) + + # Verify exported data structure + assert exported_data["kind"] == "app" + assert exported_data["app"]["name"] == app.name + assert exported_data["app"]["mode"] == app.mode + assert exported_data["app"]["icon"] == app.icon + assert exported_data["app"]["icon_background"] == app.icon_background + assert exported_data["app"]["description"] == app.description + + # Verify model config was exported + assert "model_config" in exported_data + # The exported model_config structure may be different from the database structure + # Check that the model config exists and has the expected content + assert exported_data["model_config"] is not None + + # Verify dependencies were exported + assert "dependencies" in exported_data + assert isinstance(exported_data["dependencies"], list) + + def test_export_dsl_workflow_app_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful DSL export for workflow app. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Update app to workflow mode + app.mode = "workflow" + db_session_with_containers.commit() + + # Mock workflow service to return a workflow + mock_workflow = MagicMock() + mock_workflow.to_dict.return_value = { + "graph": {"nodes": [{"id": "start", "type": "start", "data": {"type": "start"}}], "edges": []}, + "features": {}, + "environment_variables": [], + "conversation_variables": [], + } + mock_external_service_dependencies[ + "workflow_service" + ].return_value.get_draft_workflow.return_value = mock_workflow + + # Export DSL + exported_dsl = AppDslService.export_dsl(app, include_secret=False) + + # Parse exported YAML + exported_data = yaml.safe_load(exported_dsl) + + # Verify exported data structure + assert exported_data["kind"] == "app" + assert exported_data["app"]["name"] == app.name + assert exported_data["app"]["mode"] == "workflow" + + # Verify workflow was exported + assert "workflow" in exported_data + assert "graph" in exported_data["workflow"] + assert "nodes" in exported_data["workflow"]["graph"] + + # Verify dependencies were exported + assert "dependencies" in exported_data + assert isinstance(exported_data["dependencies"], list) + + # Verify workflow service was called + mock_external_service_dependencies["workflow_service"].return_value.get_draft_workflow.assert_called_once_with( + app + ) + + def test_check_dependencies_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful dependency checking. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Mock Redis to return dependencies + mock_dependencies_json = '{"app_id": "' + app.id + '", "dependencies": []}' + mock_external_service_dependencies["redis_client"].get.return_value = mock_dependencies_json + + # Check dependencies + dsl_service = AppDslService(db_session_with_containers) + result = dsl_service.check_dependencies(app_model=app) + + # Verify result + assert result.leaked_dependencies == [] + + # Verify Redis was queried + mock_external_service_dependencies["redis_client"].get.assert_called_once_with( + f"app_check_dependencies:{app.id}" + ) + + # Verify dependencies service was called + mock_external_service_dependencies["dependencies_service"].get_leaked_dependencies.assert_called_once() diff --git a/api/tests/test_containers_integration_tests/services/test_app_generate_service.py b/api/tests/test_containers_integration_tests/services/test_app_generate_service.py new file mode 100644 index 0000000000..ca0f309fd4 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_app_generate_service.py @@ -0,0 +1,1048 @@ +import uuid +from unittest.mock import MagicMock, patch + +import pytest +from faker import Faker +from openai._exceptions import RateLimitError + +from core.app.entities.app_invoke_entities import InvokeFrom +from models.model import EndUser +from models.workflow import Workflow +from services.app_generate_service import AppGenerateService +from services.errors.app import WorkflowIdFormatError, WorkflowNotFoundError +from services.errors.llm import InvokeRateLimitError + + +class TestAppGenerateService: + """Integration tests for AppGenerateService using testcontainers.""" + + @pytest.fixture + def mock_external_service_dependencies(self): + """Mock setup for external service dependencies.""" + with ( + patch("services.app_generate_service.BillingService") as mock_billing_service, + patch("services.app_generate_service.WorkflowService") as mock_workflow_service, + patch("services.app_generate_service.RateLimit") as mock_rate_limit, + patch("services.app_generate_service.RateLimiter") as mock_rate_limiter, + patch("services.app_generate_service.CompletionAppGenerator") as mock_completion_generator, + patch("services.app_generate_service.ChatAppGenerator") as mock_chat_generator, + patch("services.app_generate_service.AgentChatAppGenerator") as mock_agent_chat_generator, + patch("services.app_generate_service.AdvancedChatAppGenerator") as mock_advanced_chat_generator, + patch("services.app_generate_service.WorkflowAppGenerator") as mock_workflow_generator, + patch("services.account_service.FeatureService") as mock_account_feature_service, + patch("services.app_generate_service.dify_config") as mock_dify_config, + ): + # Setup default mock returns for billing service + mock_billing_service.get_info.return_value = {"subscription": {"plan": "sandbox"}} + + # Setup default mock returns for workflow service + mock_workflow_service_instance = mock_workflow_service.return_value + mock_workflow_service_instance.get_published_workflow.return_value = MagicMock(spec=Workflow) + mock_workflow_service_instance.get_draft_workflow.return_value = MagicMock(spec=Workflow) + mock_workflow_service_instance.get_published_workflow_by_id.return_value = MagicMock(spec=Workflow) + + # Setup default mock returns for rate limiting + mock_rate_limit_instance = mock_rate_limit.return_value + mock_rate_limit_instance.enter.return_value = "test_request_id" + mock_rate_limit_instance.generate.return_value = ["test_response"] + mock_rate_limit_instance.exit.return_value = None + + mock_rate_limiter_instance = mock_rate_limiter.return_value + mock_rate_limiter_instance.is_rate_limited.return_value = False + mock_rate_limiter_instance.increment_rate_limit.return_value = None + + # Setup default mock returns for app generators + mock_completion_generator_instance = mock_completion_generator.return_value + mock_completion_generator_instance.generate.return_value = ["completion_response"] + mock_completion_generator_instance.generate_more_like_this.return_value = ["more_like_this_response"] + mock_completion_generator.convert_to_event_stream.return_value = ["completion_stream"] + + mock_chat_generator_instance = mock_chat_generator.return_value + mock_chat_generator_instance.generate.return_value = ["chat_response"] + mock_chat_generator.convert_to_event_stream.return_value = ["chat_stream"] + + mock_agent_chat_generator_instance = mock_agent_chat_generator.return_value + mock_agent_chat_generator_instance.generate.return_value = ["agent_chat_response"] + mock_agent_chat_generator.convert_to_event_stream.return_value = ["agent_chat_stream"] + + mock_advanced_chat_generator_instance = mock_advanced_chat_generator.return_value + mock_advanced_chat_generator_instance.generate.return_value = ["advanced_chat_response"] + mock_advanced_chat_generator_instance.single_iteration_generate.return_value = ["single_iteration_response"] + mock_advanced_chat_generator_instance.single_loop_generate.return_value = ["single_loop_response"] + mock_advanced_chat_generator.convert_to_event_stream.return_value = ["advanced_chat_stream"] + + mock_workflow_generator_instance = mock_workflow_generator.return_value + mock_workflow_generator_instance.generate.return_value = ["workflow_response"] + mock_workflow_generator_instance.single_iteration_generate.return_value = [ + "workflow_single_iteration_response" + ] + mock_workflow_generator_instance.single_loop_generate.return_value = ["workflow_single_loop_response"] + mock_workflow_generator.convert_to_event_stream.return_value = ["workflow_stream"] + + # Setup default mock returns for account service + mock_account_feature_service.get_system_features.return_value.is_allow_register = True + + # Setup dify_config mock returns + mock_dify_config.BILLING_ENABLED = False + mock_dify_config.APP_MAX_ACTIVE_REQUESTS = 100 + mock_dify_config.APP_DAILY_RATE_LIMIT = 1000 + + yield { + "billing_service": mock_billing_service, + "workflow_service": mock_workflow_service, + "rate_limit": mock_rate_limit, + "rate_limiter": mock_rate_limiter, + "completion_generator": mock_completion_generator, + "chat_generator": mock_chat_generator, + "agent_chat_generator": mock_agent_chat_generator, + "advanced_chat_generator": mock_advanced_chat_generator, + "workflow_generator": mock_workflow_generator, + "account_feature_service": mock_account_feature_service, + "dify_config": mock_dify_config, + } + + def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies, mode="chat"): + """ + Helper method to create a test app and account for testing. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + mock_external_service_dependencies: Mock dependencies + mode: App mode to create + + Returns: + tuple: (app, account) - Created app and account instances + """ + fake = Faker() + + # Setup mocks for account creation + mock_external_service_dependencies[ + "account_feature_service" + ].get_system_features.return_value.is_allow_register = True + + # Create account and tenant + from services.account_service import AccountService, TenantService + + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Create app with realistic data + app_args = { + "name": fake.company(), + "description": fake.text(max_nb_chars=100), + "mode": mode, + "icon_type": "emoji", + "icon": "🤖", + "icon_background": "#FF6B6B", + "api_rph": 100, + "api_rpm": 10, + "max_active_requests": 5, + } + + from services.app_service import AppService + + app_service = AppService() + app = app_service.create_app(tenant.id, app_args, account) + + return app, account + + def _create_test_workflow(self, db_session_with_containers, app): + """ + Helper method to create a test workflow for testing. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + app: App instance + + Returns: + Workflow: Created workflow instance + """ + fake = Faker() + + workflow = Workflow( + id=str(uuid.uuid4()), + app_id=app.id, + name=fake.company(), + description=fake.text(max_nb_chars=100), + type="workflow", + status="published", + ) + + from extensions.ext_database import db + + db.session.add(workflow) + db.session.commit() + + return workflow + + def test_generate_completion_mode_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful generation for completion mode app. + """ + fake = Faker() + app, account = self._create_test_app_and_account( + db_session_with_containers, mock_external_service_dependencies, mode="completion" + ) + + # Setup test arguments + args = {"inputs": {"query": fake.text(max_nb_chars=50)}, "response_mode": "streaming"} + + # Execute the method under test + result = AppGenerateService.generate( + app_model=app, user=account, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=True + ) + + # Verify the result + assert result == ["test_response"] + + # Verify rate limiting was called + mock_external_service_dependencies["rate_limit"].return_value.enter.assert_called_once() + mock_external_service_dependencies["rate_limit"].return_value.generate.assert_called_once() + + # Verify completion generator was called + mock_external_service_dependencies["completion_generator"].return_value.generate.assert_called_once() + mock_external_service_dependencies["completion_generator"].convert_to_event_stream.assert_called_once() + + def test_generate_chat_mode_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful generation for chat mode app. + """ + fake = Faker() + app, account = self._create_test_app_and_account( + db_session_with_containers, mock_external_service_dependencies, mode="chat" + ) + + # Setup test arguments + args = {"inputs": {"query": fake.text(max_nb_chars=50)}, "response_mode": "streaming"} + + # Execute the method under test + result = AppGenerateService.generate( + app_model=app, user=account, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=True + ) + + # Verify the result + assert result == ["test_response"] + + # Verify chat generator was called + mock_external_service_dependencies["chat_generator"].return_value.generate.assert_called_once() + mock_external_service_dependencies["chat_generator"].convert_to_event_stream.assert_called_once() + + def test_generate_agent_chat_mode_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful generation for agent chat mode app. + """ + fake = Faker() + app, account = self._create_test_app_and_account( + db_session_with_containers, mock_external_service_dependencies, mode="agent-chat" + ) + + # Setup test arguments + args = {"inputs": {"query": fake.text(max_nb_chars=50)}, "response_mode": "streaming"} + + # Execute the method under test + result = AppGenerateService.generate( + app_model=app, user=account, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=True + ) + + # Verify the result + assert result == ["test_response"] + + # Verify agent chat generator was called + mock_external_service_dependencies["agent_chat_generator"].return_value.generate.assert_called_once() + mock_external_service_dependencies["agent_chat_generator"].convert_to_event_stream.assert_called_once() + + def test_generate_advanced_chat_mode_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful generation for advanced chat mode app. + """ + fake = Faker() + app, account = self._create_test_app_and_account( + db_session_with_containers, mock_external_service_dependencies, mode="advanced-chat" + ) + + # Setup test arguments + args = {"inputs": {"query": fake.text(max_nb_chars=50)}, "response_mode": "streaming"} + + # Execute the method under test + result = AppGenerateService.generate( + app_model=app, user=account, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=True + ) + + # Verify the result + assert result == ["test_response"] + + # Verify advanced chat generator was called + mock_external_service_dependencies["advanced_chat_generator"].return_value.generate.assert_called_once() + mock_external_service_dependencies["advanced_chat_generator"].convert_to_event_stream.assert_called_once() + + def test_generate_workflow_mode_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful generation for workflow mode app. + """ + fake = Faker() + app, account = self._create_test_app_and_account( + db_session_with_containers, mock_external_service_dependencies, mode="workflow" + ) + + # Setup test arguments + args = {"inputs": {"query": fake.text(max_nb_chars=50)}, "response_mode": "streaming"} + + # Execute the method under test + result = AppGenerateService.generate( + app_model=app, user=account, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=True + ) + + # Verify the result + assert result == ["test_response"] + + # Verify workflow generator was called + mock_external_service_dependencies["workflow_generator"].return_value.generate.assert_called_once() + mock_external_service_dependencies["workflow_generator"].convert_to_event_stream.assert_called_once() + + def test_generate_with_specific_workflow_id(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test generation with a specific workflow ID. + """ + fake = Faker() + app, account = self._create_test_app_and_account( + db_session_with_containers, mock_external_service_dependencies, mode="advanced-chat" + ) + + workflow_id = str(uuid.uuid4()) + + # Setup test arguments + args = { + "inputs": {"query": fake.text(max_nb_chars=50)}, + "workflow_id": workflow_id, + "response_mode": "streaming", + } + + # Execute the method under test + result = AppGenerateService.generate( + app_model=app, user=account, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=True + ) + + # Verify the result + assert result == ["test_response"] + + # Verify workflow service was called with specific workflow ID + mock_external_service_dependencies[ + "workflow_service" + ].return_value.get_published_workflow_by_id.assert_called_once() + + def test_generate_with_debugger_invoke_from(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test generation with debugger invoke from. + """ + fake = Faker() + app, account = self._create_test_app_and_account( + db_session_with_containers, mock_external_service_dependencies, mode="advanced-chat" + ) + + # Setup test arguments + args = {"inputs": {"query": fake.text(max_nb_chars=50)}, "response_mode": "streaming"} + + # Execute the method under test + result = AppGenerateService.generate( + app_model=app, user=account, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=True + ) + + # Verify the result + assert result == ["test_response"] + + # Verify draft workflow was fetched for debugger + mock_external_service_dependencies["workflow_service"].return_value.get_draft_workflow.assert_called_once() + + def test_generate_with_non_streaming_mode(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test generation with non-streaming mode. + """ + fake = Faker() + app, account = self._create_test_app_and_account( + db_session_with_containers, mock_external_service_dependencies, mode="completion" + ) + + # Setup test arguments + args = {"inputs": {"query": fake.text(max_nb_chars=50)}, "response_mode": "blocking"} + + # Execute the method under test + result = AppGenerateService.generate( + app_model=app, user=account, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=False + ) + + # Verify the result + assert result == ["test_response"] + + # Verify rate limit exit was called for non-streaming mode + mock_external_service_dependencies["rate_limit"].return_value.exit.assert_called_once() + + def test_generate_with_end_user(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test generation with EndUser instead of Account. + """ + fake = Faker() + app, account = self._create_test_app_and_account( + db_session_with_containers, mock_external_service_dependencies, mode="completion" + ) + + # Create end user + end_user = EndUser( + tenant_id=account.current_tenant.id, + app_id=app.id, + type="normal", + external_user_id=fake.uuid4(), + name=fake.name(), + is_anonymous=False, + session_id=fake.uuid4(), + ) + + from extensions.ext_database import db + + db.session.add(end_user) + db.session.commit() + + # Setup test arguments + args = {"inputs": {"query": fake.text(max_nb_chars=50)}, "response_mode": "streaming"} + + # Execute the method under test + result = AppGenerateService.generate( + app_model=app, user=end_user, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=True + ) + + # Verify the result + assert result == ["test_response"] + + def test_generate_with_billing_enabled_sandbox_plan( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test generation with billing enabled and sandbox plan. + """ + fake = Faker() + app, account = self._create_test_app_and_account( + db_session_with_containers, mock_external_service_dependencies, mode="completion" + ) + + # Setup billing service mock for sandbox plan + mock_external_service_dependencies["billing_service"].get_info.return_value = { + "subscription": {"plan": "sandbox"} + } + + # Set BILLING_ENABLED to True for this test + mock_external_service_dependencies["dify_config"].BILLING_ENABLED = True + + # Setup test arguments + args = {"inputs": {"query": fake.text(max_nb_chars=50)}, "response_mode": "streaming"} + + # Execute the method under test + result = AppGenerateService.generate( + app_model=app, user=account, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=True + ) + + # Verify the result + assert result == ["test_response"] + + # Verify billing service was called + mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(app.tenant_id) + + def test_generate_with_rate_limit_exceeded(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test generation when rate limit is exceeded. + """ + fake = Faker() + app, account = self._create_test_app_and_account( + db_session_with_containers, mock_external_service_dependencies, mode="completion" + ) + + # Setup billing service mock for sandbox plan + mock_external_service_dependencies["billing_service"].get_info.return_value = { + "subscription": {"plan": "sandbox"} + } + + # Set BILLING_ENABLED to True for this test + mock_external_service_dependencies["dify_config"].BILLING_ENABLED = True + + # Setup system rate limiter to return rate limited + with patch("services.app_generate_service.AppGenerateService.system_rate_limiter") as mock_system_rate_limiter: + mock_system_rate_limiter.is_rate_limited.return_value = True + + # Setup test arguments + args = {"inputs": {"query": fake.text(max_nb_chars=50)}, "response_mode": "streaming"} + + # Execute the method under test and expect rate limit error + with pytest.raises(InvokeRateLimitError) as exc_info: + AppGenerateService.generate( + app_model=app, user=account, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=True + ) + + # Verify error message + assert "Rate limit exceeded" in str(exc_info.value) + + def test_generate_with_rate_limit_error_from_openai( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test generation when OpenAI rate limit error occurs. + """ + fake = Faker() + app, account = self._create_test_app_and_account( + db_session_with_containers, mock_external_service_dependencies, mode="completion" + ) + + # Setup completion generator to raise RateLimitError + mock_response = MagicMock() + mock_response.request = MagicMock() + mock_external_service_dependencies["completion_generator"].return_value.generate.side_effect = RateLimitError( + "Rate limit exceeded", response=mock_response, body=None + ) + + # Setup test arguments + args = {"inputs": {"query": fake.text(max_nb_chars=50)}, "response_mode": "streaming"} + + # Execute the method under test and expect rate limit error + with pytest.raises(InvokeRateLimitError) as exc_info: + AppGenerateService.generate( + app_model=app, user=account, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=True + ) + + # Verify error message + assert "Rate limit exceeded" in str(exc_info.value) + + def test_generate_with_invalid_app_mode(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test generation with invalid app mode. + """ + fake = Faker() + app, account = self._create_test_app_and_account( + db_session_with_containers, mock_external_service_dependencies, mode="chat" + ) + + # Manually set invalid mode after creation + app.mode = "invalid_mode" + + # Setup test arguments + args = {"inputs": {"query": fake.text(max_nb_chars=50)}, "response_mode": "streaming"} + + # Execute the method under test and expect ValueError + with pytest.raises(ValueError) as exc_info: + AppGenerateService.generate( + app_model=app, user=account, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=True + ) + + # Verify error message + assert "Invalid app mode" in str(exc_info.value) + + def test_generate_with_workflow_id_format_error( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test generation with invalid workflow ID format. + """ + fake = Faker() + app, account = self._create_test_app_and_account( + db_session_with_containers, mock_external_service_dependencies, mode="advanced-chat" + ) + + # Setup test arguments with invalid workflow ID + args = { + "inputs": {"query": fake.text(max_nb_chars=50)}, + "workflow_id": "invalid_uuid", + "response_mode": "streaming", + } + + # Execute the method under test and expect WorkflowIdFormatError + with pytest.raises(WorkflowIdFormatError) as exc_info: + AppGenerateService.generate( + app_model=app, user=account, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=True + ) + + # Verify error message + assert "Invalid workflow_id format" in str(exc_info.value) + + def test_generate_with_workflow_not_found_error( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test generation when workflow is not found. + """ + fake = Faker() + app, account = self._create_test_app_and_account( + db_session_with_containers, mock_external_service_dependencies, mode="advanced-chat" + ) + + workflow_id = str(uuid.uuid4()) + + # Setup workflow service to return None (workflow not found) + mock_external_service_dependencies[ + "workflow_service" + ].return_value.get_published_workflow_by_id.return_value = None + + # Setup test arguments + args = { + "inputs": {"query": fake.text(max_nb_chars=50)}, + "workflow_id": workflow_id, + "response_mode": "streaming", + } + + # Execute the method under test and expect WorkflowNotFoundError + with pytest.raises(WorkflowNotFoundError) as exc_info: + AppGenerateService.generate( + app_model=app, user=account, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=True + ) + + # Verify error message + assert f"Workflow not found with id: {workflow_id}" in str(exc_info.value) + + def test_generate_with_workflow_not_initialized_error( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test generation when workflow is not initialized for debugger. + """ + fake = Faker() + app, account = self._create_test_app_and_account( + db_session_with_containers, mock_external_service_dependencies, mode="advanced-chat" + ) + + # Setup workflow service to return None (workflow not initialized) + mock_external_service_dependencies["workflow_service"].return_value.get_draft_workflow.return_value = None + + # Setup test arguments + args = {"inputs": {"query": fake.text(max_nb_chars=50)}, "response_mode": "streaming"} + + # Execute the method under test and expect ValueError + with pytest.raises(ValueError) as exc_info: + AppGenerateService.generate( + app_model=app, user=account, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=True + ) + + # Verify error message + assert "Workflow not initialized" in str(exc_info.value) + + def test_generate_with_workflow_not_published_error( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test generation when workflow is not published for non-debugger. + """ + fake = Faker() + app, account = self._create_test_app_and_account( + db_session_with_containers, mock_external_service_dependencies, mode="advanced-chat" + ) + + # Setup workflow service to return None (workflow not published) + mock_external_service_dependencies["workflow_service"].return_value.get_published_workflow.return_value = None + + # Setup test arguments + args = {"inputs": {"query": fake.text(max_nb_chars=50)}, "response_mode": "streaming"} + + # Execute the method under test and expect ValueError + with pytest.raises(ValueError) as exc_info: + AppGenerateService.generate( + app_model=app, user=account, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=True + ) + + # Verify error message + assert "Workflow not published" in str(exc_info.value) + + def test_generate_single_iteration_advanced_chat_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test successful single iteration generation for advanced chat mode. + """ + fake = Faker() + app, account = self._create_test_app_and_account( + db_session_with_containers, mock_external_service_dependencies, mode="advanced-chat" + ) + + node_id = fake.uuid4() + args = {"inputs": {"query": fake.text(max_nb_chars=50)}} + + # Execute the method under test + result = AppGenerateService.generate_single_iteration( + app_model=app, user=account, node_id=node_id, args=args, streaming=True + ) + + # Verify the result + assert result == ["advanced_chat_stream"] + + # Verify advanced chat generator was called + mock_external_service_dependencies[ + "advanced_chat_generator" + ].return_value.single_iteration_generate.assert_called_once() + + def test_generate_single_iteration_workflow_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test successful single iteration generation for workflow mode. + """ + fake = Faker() + app, account = self._create_test_app_and_account( + db_session_with_containers, mock_external_service_dependencies, mode="workflow" + ) + + node_id = fake.uuid4() + args = {"inputs": {"query": fake.text(max_nb_chars=50)}} + + # Execute the method under test + result = AppGenerateService.generate_single_iteration( + app_model=app, user=account, node_id=node_id, args=args, streaming=True + ) + + # Verify the result + assert result == ["advanced_chat_stream"] + + # Verify workflow generator was called + mock_external_service_dependencies[ + "workflow_generator" + ].return_value.single_iteration_generate.assert_called_once() + + def test_generate_single_iteration_invalid_mode( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test single iteration generation with invalid app mode. + """ + fake = Faker() + app, account = self._create_test_app_and_account( + db_session_with_containers, mock_external_service_dependencies, mode="completion" + ) + + node_id = fake.uuid4() + args = {"inputs": {"query": fake.text(max_nb_chars=50)}} + + # Execute the method under test and expect ValueError + with pytest.raises(ValueError) as exc_info: + AppGenerateService.generate_single_iteration( + app_model=app, user=account, node_id=node_id, args=args, streaming=True + ) + + # Verify error message + assert "Invalid app mode" in str(exc_info.value) + + def test_generate_single_loop_advanced_chat_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test successful single loop generation for advanced chat mode. + """ + fake = Faker() + app, account = self._create_test_app_and_account( + db_session_with_containers, mock_external_service_dependencies, mode="advanced-chat" + ) + + node_id = fake.uuid4() + args = {"inputs": {"query": fake.text(max_nb_chars=50)}} + + # Execute the method under test + result = AppGenerateService.generate_single_loop( + app_model=app, user=account, node_id=node_id, args=args, streaming=True + ) + + # Verify the result + assert result == ["advanced_chat_stream"] + + # Verify advanced chat generator was called + mock_external_service_dependencies[ + "advanced_chat_generator" + ].return_value.single_loop_generate.assert_called_once() + + def test_generate_single_loop_workflow_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test successful single loop generation for workflow mode. + """ + fake = Faker() + app, account = self._create_test_app_and_account( + db_session_with_containers, mock_external_service_dependencies, mode="workflow" + ) + + node_id = fake.uuid4() + args = {"inputs": {"query": fake.text(max_nb_chars=50)}} + + # Execute the method under test + result = AppGenerateService.generate_single_loop( + app_model=app, user=account, node_id=node_id, args=args, streaming=True + ) + + # Verify the result + assert result == ["advanced_chat_stream"] + + # Verify workflow generator was called + mock_external_service_dependencies["workflow_generator"].return_value.single_loop_generate.assert_called_once() + + def test_generate_single_loop_invalid_mode(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test single loop generation with invalid app mode. + """ + fake = Faker() + app, account = self._create_test_app_and_account( + db_session_with_containers, mock_external_service_dependencies, mode="completion" + ) + + node_id = fake.uuid4() + args = {"inputs": {"query": fake.text(max_nb_chars=50)}} + + # Execute the method under test and expect ValueError + with pytest.raises(ValueError) as exc_info: + AppGenerateService.generate_single_loop( + app_model=app, user=account, node_id=node_id, args=args, streaming=True + ) + + # Verify error message + assert "Invalid app mode" in str(exc_info.value) + + def test_generate_more_like_this_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful more like this generation. + """ + fake = Faker() + app, account = self._create_test_app_and_account( + db_session_with_containers, mock_external_service_dependencies, mode="completion" + ) + + message_id = fake.uuid4() + + # Execute the method under test + result = AppGenerateService.generate_more_like_this( + app_model=app, user=account, message_id=message_id, invoke_from=InvokeFrom.SERVICE_API, streaming=True + ) + + # Verify the result + assert result == ["more_like_this_response"] + + # Verify completion generator was called + mock_external_service_dependencies[ + "completion_generator" + ].return_value.generate_more_like_this.assert_called_once() + + def test_generate_more_like_this_with_end_user( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test more like this generation with EndUser. + """ + fake = Faker() + app, account = self._create_test_app_and_account( + db_session_with_containers, mock_external_service_dependencies, mode="completion" + ) + + # Create end user + end_user = EndUser( + tenant_id=account.current_tenant.id, + app_id=app.id, + type="normal", + external_user_id=fake.uuid4(), + name=fake.name(), + is_anonymous=False, + session_id=fake.uuid4(), + ) + + from extensions.ext_database import db + + db.session.add(end_user) + db.session.commit() + + message_id = fake.uuid4() + + # Execute the method under test + result = AppGenerateService.generate_more_like_this( + app_model=app, user=end_user, message_id=message_id, invoke_from=InvokeFrom.SERVICE_API, streaming=True + ) + + # Verify the result + assert result == ["more_like_this_response"] + + def test_get_max_active_requests_with_app_limit( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test getting max active requests with app-specific limit. + """ + fake = Faker() + app, account = self._create_test_app_and_account( + db_session_with_containers, mock_external_service_dependencies, mode="completion" + ) + + # Set app-specific limit + app.max_active_requests = 10 + + # Execute the method under test + result = AppGenerateService._get_max_active_requests(app) + + # Verify the result (should return the smaller value between app limit and config limit) + assert result == 10 + + def test_get_max_active_requests_with_config_limit( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test getting max active requests with config limit being smaller. + """ + fake = Faker() + app, account = self._create_test_app_and_account( + db_session_with_containers, mock_external_service_dependencies, mode="completion" + ) + + # Set app-specific limit higher than config + app.max_active_requests = 100 + + # Execute the method under test + result = AppGenerateService._get_max_active_requests(app) + + # Verify the result (should return the smaller value) + # Assuming config limit is smaller than 100 + assert result <= 100 + + def test_get_max_active_requests_with_zero_limits( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test getting max active requests with zero limits (infinite). + """ + fake = Faker() + app, account = self._create_test_app_and_account( + db_session_with_containers, mock_external_service_dependencies, mode="completion" + ) + + # Set app-specific limit to 0 (infinite) + app.max_active_requests = 0 + + # Execute the method under test + result = AppGenerateService._get_max_active_requests(app) + + # Verify the result (should return config limit when app limit is 0) + assert result == 100 # dify_config.APP_MAX_ACTIVE_REQUESTS + + def test_generate_with_exception_cleanup(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test that rate limit exit is called when an exception occurs. + """ + fake = Faker() + app, account = self._create_test_app_and_account( + db_session_with_containers, mock_external_service_dependencies, mode="completion" + ) + + # Setup completion generator to raise an exception + mock_external_service_dependencies["completion_generator"].return_value.generate.side_effect = Exception( + "Test exception" + ) + + # Setup test arguments + args = {"inputs": {"query": fake.text(max_nb_chars=50)}, "response_mode": "streaming"} + + # Execute the method under test and expect exception + with pytest.raises(Exception) as exc_info: + AppGenerateService.generate( + app_model=app, user=account, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=True + ) + + # Verify exception message + assert "Test exception" in str(exc_info.value) + + # Verify rate limit exit was called for cleanup + mock_external_service_dependencies["rate_limit"].return_value.exit.assert_called_once() + + def test_generate_with_agent_mode_detection(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test generation with agent mode detection based on app configuration. + """ + fake = Faker() + app, account = self._create_test_app_and_account( + db_session_with_containers, mock_external_service_dependencies, mode="chat" + ) + + # Mock app to have agent mode enabled by setting the mode directly + app.mode = "agent-chat" + + # Setup test arguments + args = {"inputs": {"query": fake.text(max_nb_chars=50)}, "response_mode": "streaming"} + + # Execute the method under test + result = AppGenerateService.generate( + app_model=app, user=account, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=True + ) + + # Verify the result + assert result == ["test_response"] + + # Verify agent chat generator was called instead of regular chat generator + mock_external_service_dependencies["agent_chat_generator"].return_value.generate.assert_called_once() + mock_external_service_dependencies["agent_chat_generator"].convert_to_event_stream.assert_called_once() + + def test_generate_with_different_invoke_from_values( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test generation with different invoke from values. + """ + fake = Faker() + app, account = self._create_test_app_and_account( + db_session_with_containers, mock_external_service_dependencies, mode="advanced-chat" + ) + + # Test different invoke from values + invoke_from_values = [ + InvokeFrom.SERVICE_API, + InvokeFrom.WEB_APP, + InvokeFrom.EXPLORE, + InvokeFrom.DEBUGGER, + ] + + for invoke_from in invoke_from_values: + # Setup test arguments + args = {"inputs": {"query": fake.text(max_nb_chars=50)}, "response_mode": "streaming"} + + # Execute the method under test + result = AppGenerateService.generate( + app_model=app, user=account, args=args, invoke_from=invoke_from, streaming=True + ) + + # Verify the result + assert result == ["test_response"] + + def test_generate_with_complex_args(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test generation with complex arguments including files and external trace ID. + """ + fake = Faker() + app, account = self._create_test_app_and_account( + db_session_with_containers, mock_external_service_dependencies, mode="workflow" + ) + + # Setup complex test arguments + args = { + "inputs": { + "query": fake.text(max_nb_chars=50), + "context": fake.text(max_nb_chars=100), + "parameters": {"temperature": 0.7, "max_tokens": 1000}, + }, + "files": [ + {"id": fake.uuid4(), "name": "test_file.txt", "size": 1024}, + {"id": fake.uuid4(), "name": "test_image.jpg", "size": 2048}, + ], + "external_trace_id": fake.uuid4(), + "response_mode": "streaming", + } + + # Execute the method under test + result = AppGenerateService.generate( + app_model=app, user=account, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=True + ) + + # Verify the result + assert result == ["test_response"] + + # Verify workflow generator was called with complex args + mock_external_service_dependencies["workflow_generator"].return_value.generate.assert_called_once() + call_args = mock_external_service_dependencies["workflow_generator"].return_value.generate.call_args + assert call_args[1]["args"] == args diff --git a/api/tests/test_containers_integration_tests/services/test_app_service.py b/api/tests/test_containers_integration_tests/services/test_app_service.py new file mode 100644 index 0000000000..69cd9fafee --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_app_service.py @@ -0,0 +1,928 @@ +from unittest.mock import patch + +import pytest +from faker import Faker + +from constants.model_template import default_app_templates +from models.model import App, Site +from services.account_service import AccountService, TenantService +from services.app_service import AppService + + +class TestAppService: + """Integration tests for AppService using testcontainers.""" + + @pytest.fixture + def mock_external_service_dependencies(self): + """Mock setup for external service dependencies.""" + with ( + patch("services.app_service.FeatureService") as mock_feature_service, + patch("services.app_service.EnterpriseService") as mock_enterprise_service, + patch("services.app_service.ModelManager") as mock_model_manager, + patch("services.account_service.FeatureService") as mock_account_feature_service, + ): + # Setup default mock returns for app service + mock_feature_service.get_system_features.return_value.webapp_auth.enabled = False + mock_enterprise_service.WebAppAuth.update_app_access_mode.return_value = None + mock_enterprise_service.WebAppAuth.cleanup_webapp.return_value = None + + # Setup default mock returns for account service + mock_account_feature_service.get_system_features.return_value.is_allow_register = True + + # Mock ModelManager for model configuration + mock_model_instance = mock_model_manager.return_value + mock_model_instance.get_default_model_instance.return_value = None + mock_model_instance.get_default_provider_model_name.return_value = ("openai", "gpt-3.5-turbo") + + yield { + "feature_service": mock_feature_service, + "enterprise_service": mock_enterprise_service, + "model_manager": mock_model_manager, + "account_feature_service": mock_account_feature_service, + } + + def test_create_app_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful app creation with basic parameters. + """ + fake = Faker() + + # Create account and tenant first + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Setup app creation arguments + app_args = { + "name": fake.company(), + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🤖", + "icon_background": "#FF6B6B", + "api_rph": 100, + "api_rpm": 10, + } + + # Create app + app_service = AppService() + app = app_service.create_app(tenant.id, app_args, account) + + # Verify app was created correctly + assert app.name == app_args["name"] + assert app.description == app_args["description"] + assert app.mode == app_args["mode"] + assert app.icon_type == app_args["icon_type"] + assert app.icon == app_args["icon"] + assert app.icon_background == app_args["icon_background"] + assert app.tenant_id == tenant.id + assert app.api_rph == app_args["api_rph"] + assert app.api_rpm == app_args["api_rpm"] + assert app.created_by == account.id + assert app.updated_by == account.id + assert app.status == "normal" + assert app.enable_site is True + assert app.enable_api is True + assert app.is_demo is False + assert app.is_public is False + assert app.is_universal is False + + def test_create_app_with_different_modes(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test app creation with different app modes. + """ + fake = Faker() + + # Create account and tenant first + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + app_service = AppService() + + # Test different app modes + # from AppMode enum in default_app_model_template + app_modes = [v.value for v in default_app_templates] + + for mode in app_modes: + app_args = { + "name": f"{fake.company()} {mode}", + "description": f"Test app for {mode} mode", + "mode": mode, + "icon_type": "emoji", + "icon": "🚀", + "icon_background": "#4ECDC4", + } + + app = app_service.create_app(tenant.id, app_args, account) + + # Verify app mode was set correctly + assert app.mode == mode + assert app.name == app_args["name"] + assert app.tenant_id == tenant.id + assert app.created_by == account.id + + def test_get_app_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful app retrieval. + """ + fake = Faker() + + # Create account and tenant first + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Create app first + app_args = { + "name": fake.company(), + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🎯", + "icon_background": "#45B7D1", + } + + app_service = AppService() + created_app = app_service.create_app(tenant.id, app_args, account) + + # Get app using the service + retrieved_app = app_service.get_app(created_app) + + # Verify retrieved app matches created app + assert retrieved_app.id == created_app.id + assert retrieved_app.name == created_app.name + assert retrieved_app.description == created_app.description + assert retrieved_app.mode == created_app.mode + assert retrieved_app.tenant_id == created_app.tenant_id + assert retrieved_app.created_by == created_app.created_by + + def test_get_paginate_apps_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful paginated app list retrieval. + """ + fake = Faker() + + # Create account and tenant first + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + app_service = AppService() + + # Create multiple apps + app_names = [fake.company() for _ in range(5)] + for name in app_names: + app_args = { + "name": name, + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "📱", + "icon_background": "#96CEB4", + } + app_service.create_app(tenant.id, app_args, account) + + # Get paginated apps + args = { + "page": 1, + "limit": 10, + "mode": "chat", + } + + paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, args) + + # Verify pagination results + assert paginated_apps is not None + assert len(paginated_apps.items) >= 5 # Should have at least 5 apps + assert paginated_apps.page == 1 + assert paginated_apps.per_page == 10 + + # Verify all apps belong to the correct tenant + for app in paginated_apps.items: + assert app.tenant_id == tenant.id + assert app.mode == "chat" + + def test_get_paginate_apps_with_filters(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test paginated app list with various filters. + """ + fake = Faker() + + # Create account and tenant first + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + app_service = AppService() + + # Create apps with different modes + chat_app_args = { + "name": "Chat App", + "description": "A chat application", + "mode": "chat", + "icon_type": "emoji", + "icon": "💬", + "icon_background": "#FF6B6B", + } + completion_app_args = { + "name": "Completion App", + "description": "A completion application", + "mode": "completion", + "icon_type": "emoji", + "icon": "✍️", + "icon_background": "#4ECDC4", + } + + chat_app = app_service.create_app(tenant.id, chat_app_args, account) + completion_app = app_service.create_app(tenant.id, completion_app_args, account) + + # Test filter by mode + chat_args = { + "page": 1, + "limit": 10, + "mode": "chat", + } + chat_apps = app_service.get_paginate_apps(account.id, tenant.id, chat_args) + assert len(chat_apps.items) == 1 + assert chat_apps.items[0].mode == "chat" + + # Test filter by name + name_args = { + "page": 1, + "limit": 10, + "mode": "chat", + "name": "Chat", + } + filtered_apps = app_service.get_paginate_apps(account.id, tenant.id, name_args) + assert len(filtered_apps.items) == 1 + assert "Chat" in filtered_apps.items[0].name + + # Test filter by created_by_me + created_by_me_args = { + "page": 1, + "limit": 10, + "mode": "completion", + "is_created_by_me": True, + } + my_apps = app_service.get_paginate_apps(account.id, tenant.id, created_by_me_args) + assert len(my_apps.items) == 1 + + def test_get_paginate_apps_with_tag_filters(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test paginated app list with tag filters. + """ + fake = Faker() + + # Create account and tenant first + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + app_service = AppService() + + # Create an app + app_args = { + "name": fake.company(), + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🏷️", + "icon_background": "#FFEAA7", + } + app = app_service.create_app(tenant.id, app_args, account) + + # Mock TagService to return the app ID for tag filtering + with patch("services.app_service.TagService.get_target_ids_by_tag_ids") as mock_tag_service: + mock_tag_service.return_value = [app.id] + + # Test with tag filter + args = { + "page": 1, + "limit": 10, + "mode": "chat", + "tag_ids": ["tag1", "tag2"], + } + + paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, args) + + # Verify tag service was called + mock_tag_service.assert_called_once_with("app", tenant.id, ["tag1", "tag2"]) + + # Verify results + assert paginated_apps is not None + assert len(paginated_apps.items) == 1 + assert paginated_apps.items[0].id == app.id + + # Test with tag filter that returns no results + with patch("services.app_service.TagService.get_target_ids_by_tag_ids") as mock_tag_service: + mock_tag_service.return_value = [] + + args = { + "page": 1, + "limit": 10, + "mode": "chat", + "tag_ids": ["nonexistent_tag"], + } + + paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, args) + + # Should return None when no apps match tag filter + assert paginated_apps is None + + def test_update_app_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful app update with all fields. + """ + fake = Faker() + + # Create account and tenant first + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Create app first + app_args = { + "name": fake.company(), + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🎯", + "icon_background": "#45B7D1", + } + + app_service = AppService() + app = app_service.create_app(tenant.id, app_args, account) + + # Store original values + original_name = app.name + original_description = app.description + original_icon = app.icon + original_icon_background = app.icon_background + original_use_icon_as_answer_icon = app.use_icon_as_answer_icon + + # Update app + update_args = { + "name": "Updated App Name", + "description": "Updated app description", + "icon_type": "emoji", + "icon": "🔄", + "icon_background": "#FF8C42", + "use_icon_as_answer_icon": True, + } + + with patch("flask_login.utils._get_user", return_value=account): + updated_app = app_service.update_app(app, update_args) + + # Verify updated fields + assert updated_app.name == update_args["name"] + assert updated_app.description == update_args["description"] + assert updated_app.icon == update_args["icon"] + assert updated_app.icon_background == update_args["icon_background"] + assert updated_app.use_icon_as_answer_icon is True + assert updated_app.updated_by == account.id + + # Verify other fields remain unchanged + assert updated_app.mode == app.mode + assert updated_app.tenant_id == app.tenant_id + assert updated_app.created_by == app.created_by + + def test_update_app_name_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful app name update. + """ + fake = Faker() + + # Create account and tenant first + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Create app first + app_args = { + "name": fake.company(), + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🎯", + "icon_background": "#45B7D1", + } + + app_service = AppService() + app = app_service.create_app(tenant.id, app_args, account) + + # Store original name + original_name = app.name + + # Update app name + new_name = "New App Name" + with patch("flask_login.utils._get_user", return_value=account): + updated_app = app_service.update_app_name(app, new_name) + + assert updated_app.name == new_name + assert updated_app.updated_by == account.id + + # Verify other fields remain unchanged + assert updated_app.description == app.description + assert updated_app.mode == app.mode + assert updated_app.tenant_id == app.tenant_id + assert updated_app.created_by == app.created_by + + def test_update_app_icon_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful app icon update. + """ + fake = Faker() + + # Create account and tenant first + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Create app first + app_args = { + "name": fake.company(), + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🎯", + "icon_background": "#45B7D1", + } + + app_service = AppService() + app = app_service.create_app(tenant.id, app_args, account) + + # Store original values + original_icon = app.icon + original_icon_background = app.icon_background + + # Update app icon + new_icon = "🌟" + new_icon_background = "#FFD93D" + with patch("flask_login.utils._get_user", return_value=account): + updated_app = app_service.update_app_icon(app, new_icon, new_icon_background) + + assert updated_app.icon == new_icon + assert updated_app.icon_background == new_icon_background + assert updated_app.updated_by == account.id + + # Verify other fields remain unchanged + assert updated_app.name == app.name + assert updated_app.description == app.description + assert updated_app.mode == app.mode + assert updated_app.tenant_id == app.tenant_id + assert updated_app.created_by == app.created_by + + def test_update_app_site_status_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful app site status update. + """ + fake = Faker() + + # Create account and tenant first + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Create app first + app_args = { + "name": fake.company(), + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🌐", + "icon_background": "#74B9FF", + } + + app_service = AppService() + app = app_service.create_app(tenant.id, app_args, account) + + # Store original site status + original_site_status = app.enable_site + + # Update site status to disabled + with patch("flask_login.utils._get_user", return_value=account): + updated_app = app_service.update_app_site_status(app, False) + assert updated_app.enable_site is False + assert updated_app.updated_by == account.id + + # Update site status back to enabled + with patch("flask_login.utils._get_user", return_value=account): + updated_app = app_service.update_app_site_status(updated_app, True) + assert updated_app.enable_site is True + assert updated_app.updated_by == account.id + + # Verify other fields remain unchanged + assert updated_app.name == app.name + assert updated_app.description == app.description + assert updated_app.mode == app.mode + assert updated_app.tenant_id == app.tenant_id + assert updated_app.created_by == app.created_by + + def test_update_app_api_status_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful app API status update. + """ + fake = Faker() + + # Create account and tenant first + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Create app first + app_args = { + "name": fake.company(), + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🔌", + "icon_background": "#A29BFE", + } + + app_service = AppService() + app = app_service.create_app(tenant.id, app_args, account) + + # Store original API status + original_api_status = app.enable_api + + # Update API status to disabled + with patch("flask_login.utils._get_user", return_value=account): + updated_app = app_service.update_app_api_status(app, False) + assert updated_app.enable_api is False + assert updated_app.updated_by == account.id + + # Update API status back to enabled + with patch("flask_login.utils._get_user", return_value=account): + updated_app = app_service.update_app_api_status(updated_app, True) + assert updated_app.enable_api is True + assert updated_app.updated_by == account.id + + # Verify other fields remain unchanged + assert updated_app.name == app.name + assert updated_app.description == app.description + assert updated_app.mode == app.mode + assert updated_app.tenant_id == app.tenant_id + assert updated_app.created_by == app.created_by + + def test_update_app_site_status_no_change(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test app site status update when status doesn't change. + """ + fake = Faker() + + # Create account and tenant first + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Create app first + app_args = { + "name": fake.company(), + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🔄", + "icon_background": "#FD79A8", + } + + app_service = AppService() + app = app_service.create_app(tenant.id, app_args, account) + + # Store original values + original_site_status = app.enable_site + original_updated_at = app.updated_at + + # Update site status to the same value (no change) + updated_app = app_service.update_app_site_status(app, original_site_status) + + # Verify app is returned unchanged + assert updated_app.id == app.id + assert updated_app.enable_site == original_site_status + assert updated_app.updated_at == original_updated_at + + # Verify other fields remain unchanged + assert updated_app.name == app.name + assert updated_app.description == app.description + assert updated_app.mode == app.mode + assert updated_app.tenant_id == app.tenant_id + assert updated_app.created_by == app.created_by + + def test_delete_app_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful app deletion. + """ + fake = Faker() + + # Create account and tenant first + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Create app first + app_args = { + "name": fake.company(), + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🗑️", + "icon_background": "#E17055", + } + + app_service = AppService() + app = app_service.create_app(tenant.id, app_args, account) + + # Store app ID for verification + app_id = app.id + + # Mock the async deletion task + with patch("services.app_service.remove_app_and_related_data_task") as mock_delete_task: + mock_delete_task.delay.return_value = None + + # Delete app + app_service.delete_app(app) + + # Verify async deletion task was called + mock_delete_task.delay.assert_called_once_with(tenant_id=tenant.id, app_id=app_id) + + # Verify app was deleted from database + from extensions.ext_database import db + + deleted_app = db.session.query(App).filter_by(id=app_id).first() + assert deleted_app is None + + def test_delete_app_with_related_data(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test app deletion with related data cleanup. + """ + fake = Faker() + + # Create account and tenant first + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Create app first + app_args = { + "name": fake.company(), + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🧹", + "icon_background": "#00B894", + } + + app_service = AppService() + app = app_service.create_app(tenant.id, app_args, account) + + # Store app ID for verification + app_id = app.id + + # Mock webapp auth cleanup + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.webapp_auth.enabled = True + + # Mock the async deletion task + with patch("services.app_service.remove_app_and_related_data_task") as mock_delete_task: + mock_delete_task.delay.return_value = None + + # Delete app + app_service.delete_app(app) + + # Verify webapp auth cleanup was called + mock_external_service_dependencies["enterprise_service"].WebAppAuth.cleanup_webapp.assert_called_once_with( + app_id + ) + + # Verify async deletion task was called + mock_delete_task.delay.assert_called_once_with(tenant_id=tenant.id, app_id=app_id) + + # Verify app was deleted from database + from extensions.ext_database import db + + deleted_app = db.session.query(App).filter_by(id=app_id).first() + assert deleted_app is None + + def test_get_app_meta_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful app metadata retrieval. + """ + fake = Faker() + + # Create account and tenant first + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Create app first + app_args = { + "name": fake.company(), + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "📊", + "icon_background": "#6C5CE7", + } + + app_service = AppService() + app = app_service.create_app(tenant.id, app_args, account) + + # Get app metadata + app_meta = app_service.get_app_meta(app) + + # Verify metadata contains expected fields + assert "tool_icons" in app_meta + # Note: get_app_meta currently only returns tool_icons + + def test_get_app_code_by_id_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful app code retrieval by app ID. + """ + fake = Faker() + + # Create account and tenant first + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Create app first + app_args = { + "name": fake.company(), + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🔗", + "icon_background": "#FDCB6E", + } + + app_service = AppService() + app = app_service.create_app(tenant.id, app_args, account) + + # Get app code by ID + app_code = AppService.get_app_code_by_id(app.id) + + # Verify app code was retrieved correctly + # Note: Site would be created when App is created, site.code is auto-generated + assert app_code is not None + assert len(app_code) > 0 + + def test_get_app_id_by_code_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful app ID retrieval by app code. + """ + fake = Faker() + + # Create account and tenant first + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Create app first + app_args = { + "name": fake.company(), + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🆔", + "icon_background": "#E84393", + } + + app_service = AppService() + app = app_service.create_app(tenant.id, app_args, account) + + # Create a site for the app + site = Site() + site.app_id = app.id + site.code = fake.postalcode() + site.title = fake.company() + site.status = "normal" + site.default_language = "en-US" + site.customize_token_strategy = "uuid" + from extensions.ext_database import db + + db.session.add(site) + db.session.commit() + + # Get app ID by code + app_id = AppService.get_app_id_by_code(site.code) + + # Verify app ID was retrieved correctly + assert app_id == app.id + + def test_create_app_invalid_mode(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test app creation with invalid mode. + """ + fake = Faker() + + # Create account and tenant first + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Setup app creation arguments with invalid mode + app_args = { + "name": fake.company(), + "description": fake.text(max_nb_chars=100), + "mode": "invalid_mode", # Invalid mode + "icon_type": "emoji", + "icon": "❌", + "icon_background": "#D63031", + } + + app_service = AppService() + + # Attempt to create app with invalid mode + with pytest.raises(ValueError, match="invalid mode value"): + app_service.create_app(tenant.id, app_args, account) diff --git a/api/tests/test_containers_integration_tests/services/test_feature_service.py b/api/tests/test_containers_integration_tests/services/test_feature_service.py new file mode 100644 index 0000000000..8bd5440411 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_feature_service.py @@ -0,0 +1,1785 @@ +from unittest.mock import patch + +import pytest +from faker import Faker + +from services.feature_service import FeatureModel, FeatureService, KnowledgeRateLimitModel, SystemFeatureModel + + +class TestFeatureService: + """Integration tests for FeatureService using testcontainers.""" + + @pytest.fixture + def mock_external_service_dependencies(self): + """Mock setup for external service dependencies.""" + with ( + patch("services.feature_service.BillingService") as mock_billing_service, + patch("services.feature_service.EnterpriseService") as mock_enterprise_service, + ): + # Setup default mock returns for BillingService + mock_billing_service.get_info.return_value = { + "enabled": True, + "subscription": {"plan": "pro", "interval": "monthly", "education": True}, + "members": {"size": 5, "limit": 10}, + "apps": {"size": 3, "limit": 20}, + "vector_space": {"size": 2, "limit": 10}, + "documents_upload_quota": {"size": 15, "limit": 100}, + "annotation_quota_limit": {"size": 8, "limit": 50}, + "docs_processing": "enhanced", + "can_replace_logo": True, + "model_load_balancing_enabled": True, + "knowledge_rate_limit": {"limit": 100}, + } + + mock_billing_service.get_knowledge_rate_limit.return_value = {"limit": 100, "subscription_plan": "pro"} + + # Setup default mock returns for EnterpriseService + mock_enterprise_service.get_workspace_info.return_value = { + "WorkspaceMembers": {"used": 5, "limit": 10, "enabled": True} + } + + mock_enterprise_service.get_info.return_value = { + "SSOEnforcedForSignin": True, + "SSOEnforcedForSigninProtocol": "saml", + "EnableEmailCodeLogin": True, + "EnableEmailPasswordLogin": False, + "IsAllowRegister": False, + "IsAllowCreateWorkspace": False, + "Branding": { + "applicationTitle": "Test Enterprise", + "loginPageLogo": "https://example.com/logo.png", + "workspaceLogo": "https://example.com/workspace.png", + "favicon": "https://example.com/favicon.ico", + }, + "WebAppAuth": {"allowSso": True, "allowEmailCodeLogin": True, "allowEmailPasswordLogin": False}, + "SSOEnforcedForWebProtocol": "oidc", + "License": { + "status": "active", + "expiredAt": "2025-12-31", + "workspaces": {"enabled": True, "limit": 5, "used": 2}, + }, + "PluginInstallationPermission": { + "pluginInstallationScope": "official_only", + "restrictToMarketplaceOnly": True, + }, + } + + yield { + "billing_service": mock_billing_service, + "enterprise_service": mock_enterprise_service, + } + + def _create_test_tenant_id(self): + """Helper method to create a test tenant ID.""" + fake = Faker() + return fake.uuid4() + + def test_get_features_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful feature retrieval with billing and enterprise enabled. + + This test verifies: + - Proper feature model creation with all required fields + - Correct integration with billing service + - Proper enterprise workspace information handling + - Return value correctness and structure + """ + # Arrange: Setup test data with proper config mocking + tenant_id = self._create_test_tenant_id() + + with patch("services.feature_service.dify_config") as mock_config: + mock_config.BILLING_ENABLED = True + mock_config.ENTERPRISE_ENABLED = True + mock_config.CAN_REPLACE_LOGO = True + mock_config.MODEL_LB_ENABLED = True + mock_config.DATASET_OPERATOR_ENABLED = True + mock_config.EDUCATION_ENABLED = True + + # Act: Execute the method under test + result = FeatureService.get_features(tenant_id) + + # Assert: Verify the expected outcomes + assert result is not None + assert isinstance(result, FeatureModel) + + # Verify billing features + assert result.billing.enabled is True + assert result.billing.subscription.plan == "pro" + assert result.billing.subscription.interval == "monthly" + assert result.education.activated is True + + # Verify member limitations + assert result.members.size == 5 + assert result.members.limit == 10 + + # Verify app limitations + assert result.apps.size == 3 + assert result.apps.limit == 20 + + # Verify vector space limitations + assert result.vector_space.size == 2 + assert result.vector_space.limit == 10 + + # Verify document upload quota + assert result.documents_upload_quota.size == 15 + assert result.documents_upload_quota.limit == 100 + + # Verify annotation quota + assert result.annotation_quota_limit.size == 8 + assert result.annotation_quota_limit.limit == 50 + + # Verify other features + assert result.docs_processing == "enhanced" + assert result.can_replace_logo is True + assert result.model_load_balancing_enabled is True + assert result.knowledge_rate_limit == 100 + + # Verify enterprise features + assert result.workspace_members.enabled is True + assert result.workspace_members.size == 5 + assert result.workspace_members.limit == 10 + + # Verify webapp copyright is enabled for non-sandbox plans + assert result.webapp_copyright_enabled is True + assert result.is_allow_transfer_workspace is True + + # Verify mock interactions + mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id) + mock_external_service_dependencies["enterprise_service"].get_workspace_info.assert_called_once_with( + tenant_id + ) + + def test_get_features_sandbox_plan(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test feature retrieval for sandbox plan with specific limitations. + + This test verifies: + - Proper handling of sandbox plan limitations + - Correct webapp copyright settings for sandbox + - Transfer workspace restrictions for sandbox plans + - Proper billing service integration + """ + # Arrange: Setup sandbox plan mock with proper config + tenant_id = self._create_test_tenant_id() + + with patch("services.feature_service.dify_config") as mock_config: + mock_config.BILLING_ENABLED = True + mock_config.ENTERPRISE_ENABLED = False + mock_config.CAN_REPLACE_LOGO = False + mock_config.MODEL_LB_ENABLED = False + mock_config.DATASET_OPERATOR_ENABLED = False + mock_config.EDUCATION_ENABLED = False + + # Set mock return value inside the patch context + mock_external_service_dependencies["billing_service"].get_info.return_value = { + "enabled": True, + "subscription": {"plan": "sandbox", "interval": "monthly", "education": False}, + "members": {"size": 1, "limit": 3}, + "apps": {"size": 1, "limit": 5}, + "vector_space": {"size": 1, "limit": 2}, + "documents_upload_quota": {"size": 5, "limit": 20}, + "annotation_quota_limit": {"size": 2, "limit": 10}, + "docs_processing": "standard", + "can_replace_logo": False, + "model_load_balancing_enabled": False, + "knowledge_rate_limit": {"limit": 10}, + } + + # Act: Execute the method under test + result = FeatureService.get_features(tenant_id) + + # Assert: Verify sandbox-specific limitations + assert result.billing.subscription.plan == "sandbox" + assert result.education.activated is False + + # Verify sandbox limitations + assert result.members.size == 1 + assert result.members.limit == 3 + assert result.apps.size == 1 + assert result.apps.limit == 5 + assert result.vector_space.size == 1 + assert result.vector_space.limit == 2 + assert result.documents_upload_quota.size == 5 + assert result.documents_upload_quota.limit == 20 + assert result.annotation_quota_limit.size == 2 + assert result.annotation_quota_limit.limit == 10 + + # Verify sandbox-specific restrictions + assert result.webapp_copyright_enabled is False + assert result.is_allow_transfer_workspace is False + assert result.can_replace_logo is False + assert result.model_load_balancing_enabled is False + assert result.docs_processing == "standard" + assert result.knowledge_rate_limit == 10 + + # Verify mock interactions + mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id) + + def test_get_knowledge_rate_limit_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful knowledge rate limit retrieval with billing enabled. + + This test verifies: + - Proper knowledge rate limit model creation + - Correct integration with billing service + - Proper rate limit configuration + - Return value correctness and structure + """ + # Arrange: Setup test data with proper config + tenant_id = self._create_test_tenant_id() + + with patch("services.feature_service.dify_config") as mock_config: + mock_config.BILLING_ENABLED = True + + # Act: Execute the method under test + result = FeatureService.get_knowledge_rate_limit(tenant_id) + + # Assert: Verify the expected outcomes + assert result is not None + assert isinstance(result, KnowledgeRateLimitModel) + + # Verify rate limit configuration + assert result.enabled is True + assert result.limit == 100 + assert result.subscription_plan == "pro" + + # Verify mock interactions + mock_external_service_dependencies["billing_service"].get_knowledge_rate_limit.assert_called_once_with( + tenant_id + ) + + def test_get_system_features_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful system features retrieval with enterprise and marketplace enabled. + + This test verifies: + - Proper system feature model creation + - Correct integration with enterprise service + - Proper marketplace configuration + - Return value correctness and structure + """ + # Arrange: Setup test data with proper config + tenant_id = self._create_test_tenant_id() + + with patch("services.feature_service.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = True + mock_config.MARKETPLACE_ENABLED = True + mock_config.ENABLE_EMAIL_CODE_LOGIN = True + mock_config.ENABLE_EMAIL_PASSWORD_LOGIN = True + mock_config.ENABLE_SOCIAL_OAUTH_LOGIN = False + mock_config.ALLOW_REGISTER = False + mock_config.ALLOW_CREATE_WORKSPACE = False + mock_config.MAIL_TYPE = "smtp" + mock_config.PLUGIN_MAX_PACKAGE_SIZE = 100 + + # Act: Execute the method under test + result = FeatureService.get_system_features() + + # Assert: Verify the expected outcomes + assert result is not None + assert isinstance(result, SystemFeatureModel) + + # Verify enterprise features + assert result.branding.enabled is True + assert result.webapp_auth.enabled is True + assert result.enable_change_email is False + + # Verify SSO configuration + assert result.sso_enforced_for_signin is True + assert result.sso_enforced_for_signin_protocol == "saml" + + # Verify authentication settings + assert result.enable_email_code_login is True + assert result.enable_email_password_login is False + assert result.is_allow_register is False + assert result.is_allow_create_workspace is False + + # Verify branding configuration + assert result.branding.application_title == "Test Enterprise" + assert result.branding.login_page_logo == "https://example.com/logo.png" + assert result.branding.workspace_logo == "https://example.com/workspace.png" + assert result.branding.favicon == "https://example.com/favicon.ico" + + # Verify webapp auth configuration + assert result.webapp_auth.allow_sso is True + assert result.webapp_auth.allow_email_code_login is True + assert result.webapp_auth.allow_email_password_login is False + assert result.webapp_auth.sso_config.protocol == "oidc" + + # Verify license configuration + assert result.license.status.value == "active" + assert result.license.expired_at == "2025-12-31" + assert result.license.workspaces.enabled is True + assert result.license.workspaces.limit == 5 + assert result.license.workspaces.size == 2 + + # Verify plugin installation permission + assert result.plugin_installation_permission.plugin_installation_scope == "official_only" + assert result.plugin_installation_permission.restrict_to_marketplace_only is True + + # Verify marketplace configuration + assert result.enable_marketplace is True + + # Verify mock interactions + mock_external_service_dependencies["enterprise_service"].get_info.assert_called_once() + + def test_get_system_features_basic_config(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test system features retrieval with basic configuration (no enterprise). + + This test verifies: + - Proper system feature model creation without enterprise + - Correct environment variable handling + - Default configuration values + - Return value correctness and structure + """ + # Arrange: Setup basic config mock (no enterprise) + with patch("services.feature_service.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = False + mock_config.MARKETPLACE_ENABLED = False + mock_config.ENABLE_EMAIL_CODE_LOGIN = True + mock_config.ENABLE_EMAIL_PASSWORD_LOGIN = True + mock_config.ENABLE_SOCIAL_OAUTH_LOGIN = False + mock_config.ALLOW_REGISTER = True + mock_config.ALLOW_CREATE_WORKSPACE = True + mock_config.MAIL_TYPE = "smtp" + mock_config.PLUGIN_MAX_PACKAGE_SIZE = 100 + + # Act: Execute the method under test + result = FeatureService.get_system_features() + + # Assert: Verify the expected outcomes + assert result is not None + assert isinstance(result, SystemFeatureModel) + + # Verify basic configuration + assert result.branding.enabled is False + assert result.webapp_auth.enabled is False + assert result.enable_change_email is True + + # Verify authentication settings from config + assert result.enable_email_code_login is True + assert result.enable_email_password_login is True + assert result.enable_social_oauth_login is False + assert result.is_allow_register is True + assert result.is_allow_create_workspace is True + assert result.is_email_setup is True + + # Verify marketplace configuration + assert result.enable_marketplace is False + + # Verify plugin package size (uses default value from dify_config) + assert result.max_plugin_package_size == 15728640 + + def test_get_features_billing_disabled(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test feature retrieval when billing is disabled. + + This test verifies: + - Proper feature model creation without billing + - Correct environment variable handling + - Default configuration values + - Return value correctness and structure + """ + # Arrange: Setup billing disabled mock + with patch("services.feature_service.dify_config") as mock_config: + mock_config.BILLING_ENABLED = False + mock_config.ENTERPRISE_ENABLED = False + mock_config.CAN_REPLACE_LOGO = True + mock_config.MODEL_LB_ENABLED = True + mock_config.DATASET_OPERATOR_ENABLED = True + mock_config.EDUCATION_ENABLED = True + + tenant_id = self._create_test_tenant_id() + + # Act: Execute the method under test + result = FeatureService.get_features(tenant_id) + + # Assert: Verify the expected outcomes + assert result is not None + assert isinstance(result, FeatureModel) + + # Verify billing is disabled + assert result.billing.enabled is False + + # Verify environment-based features + assert result.can_replace_logo is True + assert result.model_load_balancing_enabled is True + assert result.dataset_operator_enabled is True + assert result.education.enabled is True + + # Verify default limitations + assert result.members.size == 0 + assert result.members.limit == 1 + assert result.apps.size == 0 + assert result.apps.limit == 10 + assert result.vector_space.size == 0 + assert result.vector_space.limit == 5 + assert result.documents_upload_quota.size == 0 + assert result.documents_upload_quota.limit == 50 + assert result.annotation_quota_limit.size == 0 + assert result.annotation_quota_limit.limit == 10 + assert result.knowledge_rate_limit == 10 + assert result.docs_processing == "standard" + + # Verify no enterprise features + assert result.workspace_members.enabled is False + assert result.webapp_copyright_enabled is False + + def test_get_knowledge_rate_limit_billing_disabled( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test knowledge rate limit retrieval when billing is disabled. + + This test verifies: + - Proper knowledge rate limit model creation without billing + - Default rate limit configuration + - Return value correctness and structure + """ + # Arrange: Setup billing disabled mock + with patch("services.feature_service.dify_config") as mock_config: + mock_config.BILLING_ENABLED = False + + tenant_id = self._create_test_tenant_id() + + # Act: Execute the method under test + result = FeatureService.get_knowledge_rate_limit(tenant_id) + + # Assert: Verify the expected outcomes + assert result is not None + assert isinstance(result, KnowledgeRateLimitModel) + + # Verify default configuration + assert result.enabled is False + assert result.limit == 10 + assert result.subscription_plan == "" # Empty string when billing is disabled + + # Verify no billing service calls + mock_external_service_dependencies["billing_service"].get_knowledge_rate_limit.assert_not_called() + + def test_get_features_enterprise_only(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test feature retrieval with enterprise enabled but billing disabled. + + This test verifies: + - Proper feature model creation with enterprise only + - Correct enterprise service integration + - Proper workspace member handling + - Return value correctness and structure + """ + # Arrange: Setup enterprise only mock + with patch("services.feature_service.dify_config") as mock_config: + mock_config.BILLING_ENABLED = False + mock_config.ENTERPRISE_ENABLED = True + mock_config.CAN_REPLACE_LOGO = False + mock_config.MODEL_LB_ENABLED = False + mock_config.DATASET_OPERATOR_ENABLED = False + mock_config.EDUCATION_ENABLED = False + + tenant_id = self._create_test_tenant_id() + + # Act: Execute the method under test + result = FeatureService.get_features(tenant_id) + + # Assert: Verify the expected outcomes + assert result is not None + assert isinstance(result, FeatureModel) + + # Verify billing is disabled + assert result.billing.enabled is False + + # Verify enterprise features + assert result.webapp_copyright_enabled is True + + # Verify workspace members from enterprise + assert result.workspace_members.enabled is True + assert result.workspace_members.size == 5 + assert result.workspace_members.limit == 10 + + # Verify environment-based features + assert result.can_replace_logo is False + assert result.model_load_balancing_enabled is False + assert result.dataset_operator_enabled is False + assert result.education.enabled is False + + # Verify default limitations + assert result.members.size == 0 + assert result.members.limit == 1 + assert result.apps.size == 0 + assert result.apps.limit == 10 + assert result.vector_space.size == 0 + assert result.vector_space.limit == 5 + + # Verify mock interactions + mock_external_service_dependencies["enterprise_service"].get_workspace_info.assert_called_once_with( + tenant_id + ) + mock_external_service_dependencies["billing_service"].get_info.assert_not_called() + + def test_get_system_features_enterprise_disabled( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test system features retrieval when enterprise is disabled. + + This test verifies: + - Proper system feature model creation without enterprise + - Correct environment variable handling + - Default configuration values + - Return value correctness and structure + """ + # Arrange: Setup enterprise disabled mock + with patch("services.feature_service.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = False + mock_config.MARKETPLACE_ENABLED = True + mock_config.ENABLE_EMAIL_CODE_LOGIN = False + mock_config.ENABLE_EMAIL_PASSWORD_LOGIN = True + mock_config.ENABLE_SOCIAL_OAUTH_LOGIN = True + mock_config.ALLOW_REGISTER = False + mock_config.ALLOW_CREATE_WORKSPACE = False + mock_config.MAIL_TYPE = None + mock_config.PLUGIN_MAX_PACKAGE_SIZE = 50 + + # Act: Execute the method under test + result = FeatureService.get_system_features() + + # Assert: Verify the expected outcomes + assert result is not None + assert isinstance(result, SystemFeatureModel) + + # Verify enterprise features are disabled + assert result.branding.enabled is False + assert result.webapp_auth.enabled is False + assert result.enable_change_email is True + + # Verify authentication settings from config + assert result.enable_email_code_login is False + assert result.enable_email_password_login is True + assert result.enable_social_oauth_login is True + assert result.is_allow_register is False + assert result.is_allow_create_workspace is False + assert result.is_email_setup is False + + # Verify marketplace configuration + assert result.enable_marketplace is True + + # Verify plugin package size (uses default value from dify_config) + assert result.max_plugin_package_size == 15728640 + + # Verify default license status + assert result.license.status.value == "none" + assert result.license.expired_at == "" + assert result.license.workspaces.enabled is False + + # Verify no enterprise service calls + mock_external_service_dependencies["enterprise_service"].get_info.assert_not_called() + + def test_get_features_no_tenant_id(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test feature retrieval without tenant ID (billing disabled). + + This test verifies: + - Proper feature model creation without tenant ID + - Correct handling when billing is disabled + - Default configuration values + - Return value correctness and structure + """ + # Arrange: Setup no tenant ID scenario + with patch("services.feature_service.dify_config") as mock_config: + mock_config.BILLING_ENABLED = True + mock_config.ENTERPRISE_ENABLED = False + mock_config.CAN_REPLACE_LOGO = True + mock_config.MODEL_LB_ENABLED = False + mock_config.DATASET_OPERATOR_ENABLED = True + mock_config.EDUCATION_ENABLED = False + + # Act: Execute the method under test + result = FeatureService.get_features("") + + # Assert: Verify the expected outcomes + assert result is not None + assert isinstance(result, FeatureModel) + + # Verify billing is disabled due to no tenant ID + assert result.billing.enabled is False + + # Verify environment-based features + assert result.can_replace_logo is True + assert result.model_load_balancing_enabled is False + assert result.dataset_operator_enabled is True + assert result.education.enabled is False + + # Verify default limitations + assert result.members.size == 0 + assert result.members.limit == 1 + assert result.apps.size == 0 + assert result.apps.limit == 10 + assert result.vector_space.size == 0 + assert result.vector_space.limit == 5 + + # Verify no billing service calls + mock_external_service_dependencies["billing_service"].get_info.assert_not_called() + + def test_get_features_partial_billing_info(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test feature retrieval with partial billing information. + + This test verifies: + - Proper handling of partial billing data + - Correct fallback to default values + - Proper billing service integration + - Return value correctness and structure + """ + # Arrange: Setup partial billing info mock with proper config + tenant_id = self._create_test_tenant_id() + + with patch("services.feature_service.dify_config") as mock_config: + mock_config.BILLING_ENABLED = True + mock_config.ENTERPRISE_ENABLED = False + mock_config.CAN_REPLACE_LOGO = True + mock_config.MODEL_LB_ENABLED = False + mock_config.DATASET_OPERATOR_ENABLED = True + mock_config.EDUCATION_ENABLED = False + + mock_external_service_dependencies["billing_service"].get_info.return_value = { + "enabled": True, + "subscription": {"plan": "basic", "interval": "yearly"}, + # Missing members, apps, vector_space, etc. + } + + # Act: Execute the method under test + result = FeatureService.get_features(tenant_id) + + # Assert: Verify the expected outcomes + assert result is not None + assert isinstance(result, FeatureModel) + + # Verify billing features + assert result.billing.enabled is True + assert result.billing.subscription.plan == "basic" + assert result.billing.subscription.interval == "yearly" + + # Verify default values for missing billing info + assert result.members.size == 0 + assert result.members.limit == 1 + assert result.apps.size == 0 + assert result.apps.limit == 10 + assert result.vector_space.size == 0 + assert result.vector_space.limit == 5 + assert result.documents_upload_quota.size == 0 + assert result.documents_upload_quota.limit == 50 + assert result.annotation_quota_limit.size == 0 + assert result.annotation_quota_limit.limit == 10 + assert result.knowledge_rate_limit == 10 + assert result.docs_processing == "standard" + + # Verify basic plan restrictions (non-sandbox plans have webapp copyright enabled) + assert result.webapp_copyright_enabled is True + assert result.is_allow_transfer_workspace is True + + # Verify mock interactions + mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id) + + def test_get_features_edge_case_vector_space(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test feature retrieval with edge case vector space configuration. + + This test verifies: + - Proper handling of vector space quota limits + - Correct integration with billing service + - Proper fallback to default values + - Return value correctness and structure + """ + # Arrange: Setup edge case vector space mock with proper config + tenant_id = self._create_test_tenant_id() + + with patch("services.feature_service.dify_config") as mock_config: + mock_config.BILLING_ENABLED = True + mock_config.ENTERPRISE_ENABLED = False + mock_config.CAN_REPLACE_LOGO = True + mock_config.MODEL_LB_ENABLED = False + mock_config.DATASET_OPERATOR_ENABLED = True + mock_config.EDUCATION_ENABLED = False + + mock_external_service_dependencies["billing_service"].get_info.return_value = { + "enabled": True, + "subscription": {"plan": "pro", "interval": "monthly"}, + "vector_space": {"size": 0, "limit": 0}, + "apps": {"size": 5, "limit": 10}, + } + + # Act: Execute the method under test + result = FeatureService.get_features(tenant_id) + + # Assert: Verify the expected outcomes + assert result is not None + assert isinstance(result, FeatureModel) + + # Verify vector space configuration + assert result.vector_space.size == 0 + assert result.vector_space.limit == 0 + + # Verify apps configuration + assert result.apps.size == 5 + assert result.apps.limit == 10 + + # Verify pro plan features + assert result.webapp_copyright_enabled is True + assert result.is_allow_transfer_workspace is True + + # Verify default values for missing billing info + assert result.members.size == 0 + assert result.members.limit == 1 + assert result.documents_upload_quota.size == 0 + assert result.documents_upload_quota.limit == 50 + assert result.annotation_quota_limit.size == 0 + assert result.annotation_quota_limit.limit == 10 + assert result.knowledge_rate_limit == 10 + assert result.docs_processing == "standard" + + # Verify mock interactions + mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id) + + def test_get_system_features_edge_case_webapp_auth( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test system features retrieval with edge case webapp auth configuration. + + This test verifies: + - Proper handling of webapp auth configuration + - Correct enterprise service integration + - Proper fallback to default values + - Return value correctness and structure + """ + # Arrange: Setup edge case webapp auth mock with proper config + with patch("services.feature_service.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = True + mock_config.MARKETPLACE_ENABLED = False + mock_config.ENABLE_EMAIL_CODE_LOGIN = False + mock_config.ENABLE_EMAIL_PASSWORD_LOGIN = True + mock_config.ENABLE_SOCIAL_OAUTH_LOGIN = False + mock_config.ALLOW_REGISTER = False + mock_config.ALLOW_CREATE_WORKSPACE = False + mock_config.MAIL_TYPE = "smtp" + mock_config.PLUGIN_MAX_PACKAGE_SIZE = 100 + + mock_external_service_dependencies["enterprise_service"].get_info.return_value = { + "WebAppAuth": {"allowSso": False, "allowEmailCodeLogin": True, "allowEmailPasswordLogin": False} + } + + # Act: Execute the method under test + result = FeatureService.get_system_features() + + # Assert: Verify the expected outcomes + assert result is not None + assert isinstance(result, SystemFeatureModel) + + # Verify webapp auth configuration + assert result.webapp_auth.allow_sso is False + assert result.webapp_auth.allow_email_code_login is True + assert result.webapp_auth.allow_email_password_login is False + assert result.webapp_auth.sso_config.protocol == "" + + # Verify enterprise features + assert result.branding.enabled is True + assert result.webapp_auth.enabled is True + assert result.enable_change_email is False + + # Verify default values for missing enterprise info + assert result.sso_enforced_for_signin is False + assert result.sso_enforced_for_signin_protocol == "" + assert result.enable_email_code_login is False + assert result.enable_email_password_login is True + assert result.is_allow_register is False + assert result.is_allow_create_workspace is False + + # Verify mock interactions + mock_external_service_dependencies["enterprise_service"].get_info.assert_called_once() + + def test_get_features_edge_case_members_quota(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test feature retrieval with edge case members quota configuration. + + This test verifies: + - Proper handling of members quota limits + - Correct integration with billing service + - Proper fallback to default values + - Return value correctness and structure + """ + # Arrange: Setup edge case members quota mock with proper config + tenant_id = self._create_test_tenant_id() + + with patch("services.feature_service.dify_config") as mock_config: + mock_config.BILLING_ENABLED = True + mock_config.ENTERPRISE_ENABLED = False + mock_config.CAN_REPLACE_LOGO = True + mock_config.MODEL_LB_ENABLED = False + mock_config.DATASET_OPERATOR_ENABLED = True + mock_config.EDUCATION_ENABLED = False + + mock_external_service_dependencies["billing_service"].get_info.return_value = { + "enabled": True, + "subscription": {"plan": "basic", "interval": "yearly"}, + "members": {"size": 10, "limit": 10}, + "vector_space": {"size": 3, "limit": 5}, + } + + # Act: Execute the method under test + result = FeatureService.get_features(tenant_id) + + # Assert: Verify the expected outcomes + assert result is not None + assert isinstance(result, FeatureModel) + + # Verify members configuration + assert result.members.size == 10 + assert result.members.limit == 10 + + # Verify vector space configuration + assert result.vector_space.size == 3 + assert result.vector_space.limit == 5 + + # Verify basic plan features (non-sandbox plans have webapp copyright enabled) + assert result.webapp_copyright_enabled is True + assert result.is_allow_transfer_workspace is True + + # Verify default values for missing billing info + assert result.apps.size == 0 + assert result.apps.limit == 10 + assert result.documents_upload_quota.size == 0 + assert result.documents_upload_quota.limit == 50 + assert result.annotation_quota_limit.size == 0 + assert result.annotation_quota_limit.limit == 10 + assert result.knowledge_rate_limit == 10 + assert result.docs_processing == "standard" + + # Verify mock interactions + mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id) + + def test_plugin_installation_permission_scopes( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test system features retrieval with different plugin installation permission scopes. + + This test verifies: + - Proper handling of different plugin installation scopes + - Correct enterprise service integration + - Proper permission configuration + - Return value correctness and structure + """ + + # Test case 1: Official only scope + with patch("services.feature_service.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = True + mock_config.MARKETPLACE_ENABLED = False + mock_config.ENABLE_EMAIL_CODE_LOGIN = False + mock_config.ENABLE_EMAIL_PASSWORD_LOGIN = True + mock_config.ENABLE_SOCIAL_OAUTH_LOGIN = False + mock_config.ALLOW_REGISTER = False + mock_config.ALLOW_CREATE_WORKSPACE = False + mock_config.MAIL_TYPE = "smtp" + mock_config.PLUGIN_MAX_PACKAGE_SIZE = 100 + + mock_external_service_dependencies["enterprise_service"].get_info.return_value = { + "PluginInstallationPermission": { + "pluginInstallationScope": "official_only", + "restrictToMarketplaceOnly": True, + } + } + + result = FeatureService.get_system_features() + assert result.plugin_installation_permission.plugin_installation_scope == "official_only" + assert result.plugin_installation_permission.restrict_to_marketplace_only is True + + # Test case 2: All plugins scope + with patch("services.feature_service.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = True + mock_config.MARKETPLACE_ENABLED = False + mock_config.ENABLE_EMAIL_CODE_LOGIN = False + mock_config.ENABLE_EMAIL_PASSWORD_LOGIN = True + mock_config.ENABLE_SOCIAL_OAUTH_LOGIN = False + mock_config.ALLOW_REGISTER = False + mock_config.ALLOW_CREATE_WORKSPACE = False + mock_config.MAIL_TYPE = "smtp" + mock_config.PLUGIN_MAX_PACKAGE_SIZE = 100 + + mock_external_service_dependencies["enterprise_service"].get_info.return_value = { + "PluginInstallationPermission": {"pluginInstallationScope": "all", "restrictToMarketplaceOnly": False} + } + + result = FeatureService.get_system_features() + assert result.plugin_installation_permission.plugin_installation_scope == "all" + assert result.plugin_installation_permission.restrict_to_marketplace_only is False + + # Test case 3: Specific partners scope + with patch("services.feature_service.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = True + mock_config.MARKETPLACE_ENABLED = False + mock_config.ENABLE_EMAIL_CODE_LOGIN = False + mock_config.ENABLE_EMAIL_PASSWORD_LOGIN = True + mock_config.ENABLE_SOCIAL_OAUTH_LOGIN = False + mock_config.ALLOW_REGISTER = False + mock_config.ALLOW_CREATE_WORKSPACE = False + mock_config.MAIL_TYPE = "smtp" + mock_config.PLUGIN_MAX_PACKAGE_SIZE = 100 + + mock_external_service_dependencies["enterprise_service"].get_info.return_value = { + "PluginInstallationPermission": { + "pluginInstallationScope": "official_and_specific_partners", + "restrictToMarketplaceOnly": False, + } + } + + result = FeatureService.get_system_features() + assert result.plugin_installation_permission.plugin_installation_scope == "official_and_specific_partners" + assert result.plugin_installation_permission.restrict_to_marketplace_only is False + + # Test case 4: None scope + with patch("services.feature_service.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = True + mock_config.MARKETPLACE_ENABLED = False + mock_config.ENABLE_EMAIL_CODE_LOGIN = False + mock_config.ENABLE_EMAIL_PASSWORD_LOGIN = True + mock_config.ENABLE_SOCIAL_OAUTH_LOGIN = False + mock_config.ALLOW_REGISTER = False + mock_config.ALLOW_CREATE_WORKSPACE = False + mock_config.MAIL_TYPE = "smtp" + mock_config.PLUGIN_MAX_PACKAGE_SIZE = 100 + + mock_external_service_dependencies["enterprise_service"].get_info.return_value = { + "PluginInstallationPermission": {"pluginInstallationScope": "none", "restrictToMarketplaceOnly": True} + } + + result = FeatureService.get_system_features() + assert result.plugin_installation_permission.plugin_installation_scope == "none" + assert result.plugin_installation_permission.restrict_to_marketplace_only is True + + def test_get_features_workspace_members_missing( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test feature retrieval when workspace members info is missing from enterprise. + + This test verifies: + - Proper handling of missing workspace members data + - Correct enterprise service integration + - Proper fallback to default values + - Return value correctness and structure + """ + # Arrange: Setup missing workspace members mock + tenant_id = self._create_test_tenant_id() + mock_external_service_dependencies["enterprise_service"].get_workspace_info.return_value = { + # Missing WorkspaceMembers key + } + + with patch("services.feature_service.dify_config") as mock_config: + mock_config.BILLING_ENABLED = False + mock_config.ENTERPRISE_ENABLED = True + + # Act: Execute the method under test + result = FeatureService.get_features(tenant_id) + + # Assert: Verify the expected outcomes + assert result is not None + assert isinstance(result, FeatureModel) + + # Verify workspace members use default values + assert result.workspace_members.enabled is False + assert result.workspace_members.size == 0 + assert result.workspace_members.limit == 0 + + # Verify enterprise features + assert result.webapp_copyright_enabled is True + + # Verify mock interactions + mock_external_service_dependencies["enterprise_service"].get_workspace_info.assert_called_once_with( + tenant_id + ) + + def test_get_system_features_license_inactive(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test system features retrieval with inactive license. + + This test verifies: + - Proper handling of inactive license status + - Correct enterprise service integration + - Proper license status handling + - Return value correctness and structure + """ + # Arrange: Setup inactive license mock with proper config + with patch("services.feature_service.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = True + mock_config.MARKETPLACE_ENABLED = False + mock_config.ENABLE_EMAIL_CODE_LOGIN = False + mock_config.ENABLE_EMAIL_PASSWORD_LOGIN = True + mock_config.ENABLE_SOCIAL_OAUTH_LOGIN = False + mock_config.ALLOW_REGISTER = False + mock_config.ALLOW_CREATE_WORKSPACE = False + mock_config.MAIL_TYPE = "smtp" + mock_config.PLUGIN_MAX_PACKAGE_SIZE = 100 + + mock_external_service_dependencies["enterprise_service"].get_info.return_value = { + "License": { + "status": "inactive", + "expiredAt": "", + "workspaces": {"enabled": False, "limit": 0, "used": 0}, + } + } + + # Act: Execute the method under test + result = FeatureService.get_system_features() + + # Assert: Verify the expected outcomes + assert result is not None + assert isinstance(result, SystemFeatureModel) + + # Verify license status + assert result.license.status == "inactive" + assert result.license.expired_at == "" + assert result.license.workspaces.enabled is False + assert result.license.workspaces.size == 0 + assert result.license.workspaces.limit == 0 + + # Verify enterprise features + assert result.branding.enabled is True + assert result.webapp_auth.enabled is True + assert result.enable_change_email is False + + # Verify mock interactions + mock_external_service_dependencies["enterprise_service"].get_info.assert_called_once() + + def test_get_system_features_partial_enterprise_info( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test system features retrieval with partial enterprise information. + + This test verifies: + - Proper handling of partial enterprise data + - Correct fallback to default values + - Proper enterprise service integration + - Return value correctness and structure + """ + # Arrange: Setup partial enterprise info mock with proper config + with patch("services.feature_service.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = True + mock_config.MARKETPLACE_ENABLED = False + mock_config.ENABLE_EMAIL_CODE_LOGIN = False + mock_config.ENABLE_EMAIL_PASSWORD_LOGIN = True + mock_config.ENABLE_SOCIAL_OAUTH_LOGIN = False + mock_config.ALLOW_REGISTER = False + mock_config.ALLOW_CREATE_WORKSPACE = False + mock_config.MAIL_TYPE = "smtp" + mock_config.PLUGIN_MAX_PACKAGE_SIZE = 100 + + mock_external_service_dependencies["enterprise_service"].get_info.return_value = { + "SSOEnforcedForSignin": True, + "Branding": {"applicationTitle": "Partial Enterprise"}, + # Missing WebAppAuth, License, PluginInstallationPermission, etc. + } + + # Act: Execute the method under test + result = FeatureService.get_system_features() + + # Assert: Verify the expected outcomes + assert result is not None + assert isinstance(result, SystemFeatureModel) + + # Verify enterprise features + assert result.branding.enabled is True + assert result.webapp_auth.enabled is True + assert result.enable_change_email is False + + # Verify SSO configuration + assert result.sso_enforced_for_signin is True + assert result.sso_enforced_for_signin_protocol == "" + + # Verify branding configuration (partial) + assert result.branding.application_title == "Partial Enterprise" + assert result.branding.login_page_logo == "" + assert result.branding.workspace_logo == "" + assert result.branding.favicon == "" + + # Verify default values for missing enterprise info + assert result.webapp_auth.allow_sso is False + assert result.webapp_auth.allow_email_code_login is False + assert result.webapp_auth.allow_email_password_login is False + assert result.webapp_auth.sso_config.protocol == "" + + # Verify default license status + assert result.license.status == "none" + assert result.license.expired_at == "" + assert result.license.workspaces.enabled is False + + # Verify default plugin installation permission + assert result.plugin_installation_permission.plugin_installation_scope == "all" + assert result.plugin_installation_permission.restrict_to_marketplace_only is False + + # Verify mock interactions + mock_external_service_dependencies["enterprise_service"].get_info.assert_called_once() + + def test_get_features_edge_case_limits(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test feature retrieval with edge case limit values. + + This test verifies: + - Proper handling of zero and negative limits + - Correct handling of very large limits + - Proper fallback to default values + - Return value correctness and structure + """ + # Arrange: Setup edge case limits mock with proper config + tenant_id = self._create_test_tenant_id() + + with patch("services.feature_service.dify_config") as mock_config: + mock_config.BILLING_ENABLED = True + mock_config.ENTERPRISE_ENABLED = False + mock_config.CAN_REPLACE_LOGO = True + mock_config.MODEL_LB_ENABLED = False + mock_config.DATASET_OPERATOR_ENABLED = True + mock_config.EDUCATION_ENABLED = False + + mock_external_service_dependencies["billing_service"].get_info.return_value = { + "enabled": True, + "subscription": {"plan": "enterprise", "interval": "yearly"}, + "members": {"size": 0, "limit": 0}, + "apps": {"size": 0, "limit": -1}, + "vector_space": {"size": 0, "limit": 999999}, + "documents_upload_quota": {"size": 0, "limit": 0}, + "annotation_quota_limit": {"size": 0, "limit": 1}, + } + + # Act: Execute the method under test + result = FeatureService.get_features(tenant_id) + + # Assert: Verify the expected outcomes + assert result is not None + assert isinstance(result, FeatureModel) + + # Verify edge case limits + assert result.members.size == 0 + assert result.members.limit == 0 + assert result.apps.size == 0 + assert result.apps.limit == -1 + assert result.vector_space.size == 0 + assert result.vector_space.limit == 999999 + assert result.documents_upload_quota.size == 0 + assert result.documents_upload_quota.limit == 0 + assert result.annotation_quota_limit.size == 0 + assert result.annotation_quota_limit.limit == 1 + + # Verify enterprise plan features + assert result.webapp_copyright_enabled is True + assert result.is_allow_transfer_workspace is True + + # Verify mock interactions + mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id) + + def test_get_system_features_edge_case_protocols( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test system features retrieval with edge case protocol values. + + This test verifies: + - Proper handling of empty protocol strings + - Correct handling of special protocol values + - Proper fallback to default values + - Return value correctness and structure + """ + # Arrange: Setup edge case protocols mock with proper config + with patch("services.feature_service.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = True + mock_config.MARKETPLACE_ENABLED = False + mock_config.ENABLE_EMAIL_CODE_LOGIN = False + mock_config.ENABLE_EMAIL_PASSWORD_LOGIN = True + mock_config.ENABLE_SOCIAL_OAUTH_LOGIN = False + mock_config.ALLOW_REGISTER = False + mock_config.ALLOW_CREATE_WORKSPACE = False + mock_config.MAIL_TYPE = "smtp" + mock_config.PLUGIN_MAX_PACKAGE_SIZE = 100 + + mock_external_service_dependencies["enterprise_service"].get_info.return_value = { + "SSOEnforcedForSigninProtocol": "", + "SSOEnforcedForWebProtocol": " ", + "WebAppAuth": {"allowSso": True, "allowEmailCodeLogin": False, "allowEmailPasswordLogin": True}, + } + + # Act: Execute the method under test + result = FeatureService.get_system_features() + + # Assert: Verify the expected outcomes + assert result is not None + assert isinstance(result, SystemFeatureModel) + + # Verify edge case protocols + assert result.sso_enforced_for_signin_protocol == "" + assert result.webapp_auth.sso_config.protocol == " " + + # Verify webapp auth configuration + assert result.webapp_auth.allow_sso is True + assert result.webapp_auth.allow_email_code_login is False + assert result.webapp_auth.allow_email_password_login is True + + # Verify enterprise features + assert result.branding.enabled is True + assert result.webapp_auth.enabled is True + assert result.enable_change_email is False + + # Verify mock interactions + mock_external_service_dependencies["enterprise_service"].get_info.assert_called_once() + + def test_get_features_edge_case_education(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test feature retrieval with edge case education configuration. + + This test verifies: + - Proper handling of education feature flags + - Correct integration with billing service + - Proper fallback to default values + - Return value correctness and structure + """ + # Arrange: Setup edge case education mock + tenant_id = self._create_test_tenant_id() + mock_external_service_dependencies["billing_service"].get_info.return_value = { + "enabled": True, + "subscription": {"plan": "education", "interval": "semester", "education": True}, + "members": {"size": 100, "limit": 200}, + "apps": {"size": 50, "limit": 100}, + "vector_space": {"size": 20, "limit": 50}, + "documents_upload_quota": {"size": 500, "limit": 1000}, + "annotation_quota_limit": {"size": 200, "limit": 500}, + } + + with patch("services.feature_service.dify_config") as mock_config: + mock_config.EDUCATION_ENABLED = True + + # Act: Execute the method under test + result = FeatureService.get_features(tenant_id) + + # Assert: Verify the expected outcomes + assert result is not None + assert isinstance(result, FeatureModel) + + # Verify education features + assert result.education.enabled is True + assert result.education.activated is True + + # Verify education plan limits + assert result.members.size == 100 + assert result.members.limit == 200 + assert result.apps.size == 50 + assert result.apps.limit == 100 + assert result.vector_space.size == 20 + assert result.vector_space.limit == 50 + assert result.documents_upload_quota.size == 500 + assert result.documents_upload_quota.limit == 1000 + assert result.annotation_quota_limit.size == 200 + assert result.annotation_quota_limit.limit == 500 + + # Verify education plan features + assert result.webapp_copyright_enabled is True + assert result.is_allow_transfer_workspace is True + + # Verify mock interactions + mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id) + + def test_license_limitation_model_is_available( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test LicenseLimitationModel.is_available method with various scenarios. + + This test verifies: + - Proper quota availability calculation + - Correct handling of unlimited limits + - Proper handling of disabled limits + - Return value correctness for different scenarios + """ + from services.feature_service import LicenseLimitationModel + + # Test case 1: Limit disabled + disabled_limit = LicenseLimitationModel(enabled=False, size=5, limit=10) + assert disabled_limit.is_available(3) is True + assert disabled_limit.is_available(10) is True + + # Test case 2: Unlimited limit + unlimited_limit = LicenseLimitationModel(enabled=True, size=5, limit=0) + assert unlimited_limit.is_available(3) is True + assert unlimited_limit.is_available(100) is True + + # Test case 3: Available quota + available_limit = LicenseLimitationModel(enabled=True, size=5, limit=10) + assert available_limit.is_available(3) is True + assert available_limit.is_available(5) is True + assert available_limit.is_available(1) is True + + # Test case 4: Insufficient quota + insufficient_limit = LicenseLimitationModel(enabled=True, size=8, limit=10) + assert insufficient_limit.is_available(3) is False + assert insufficient_limit.is_available(2) is True + assert insufficient_limit.is_available(1) is True + + # Test case 5: Exact quota usage + exact_limit = LicenseLimitationModel(enabled=True, size=7, limit=10) + assert exact_limit.is_available(3) is True + assert exact_limit.is_available(3) is True + + def test_get_features_workspace_members_disabled( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test feature retrieval when workspace members are disabled in enterprise. + + This test verifies: + - Proper handling of disabled workspace members + - Correct enterprise service integration + - Proper fallback to default values + - Return value correctness and structure + """ + # Arrange: Setup workspace members disabled mock + tenant_id = self._create_test_tenant_id() + mock_external_service_dependencies["enterprise_service"].get_workspace_info.return_value = { + "WorkspaceMembers": {"used": 0, "limit": 0, "enabled": False} + } + + with patch("services.feature_service.dify_config") as mock_config: + mock_config.BILLING_ENABLED = False + mock_config.ENTERPRISE_ENABLED = True + + # Act: Execute the method under test + result = FeatureService.get_features(tenant_id) + + # Assert: Verify the expected outcomes + assert result is not None + assert isinstance(result, FeatureModel) + + # Verify workspace members are disabled + assert result.workspace_members.enabled is False + assert result.workspace_members.size == 0 + assert result.workspace_members.limit == 0 + + # Verify enterprise features + assert result.webapp_copyright_enabled is True + + # Verify mock interactions + mock_external_service_dependencies["enterprise_service"].get_workspace_info.assert_called_once_with(tenant_id) + + def test_get_system_features_license_expired(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test system features retrieval with expired license. + + This test verifies: + - Proper handling of expired license status + - Correct enterprise service integration + - Proper license status handling + - Return value correctness and structure + """ + # Arrange: Setup expired license mock with proper config + with patch("services.feature_service.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = True + mock_config.MARKETPLACE_ENABLED = False + mock_config.ENABLE_EMAIL_CODE_LOGIN = False + mock_config.ENABLE_EMAIL_PASSWORD_LOGIN = True + mock_config.ENABLE_SOCIAL_OAUTH_LOGIN = False + mock_config.ALLOW_REGISTER = False + mock_config.ALLOW_CREATE_WORKSPACE = False + mock_config.MAIL_TYPE = "smtp" + mock_config.PLUGIN_MAX_PACKAGE_SIZE = 100 + + mock_external_service_dependencies["enterprise_service"].get_info.return_value = { + "License": { + "status": "expired", + "expiredAt": "2023-12-31", + "workspaces": {"enabled": False, "limit": 0, "used": 0}, + } + } + + # Act: Execute the method under test + result = FeatureService.get_system_features() + + # Assert: Verify the expected outcomes + assert result is not None + assert isinstance(result, SystemFeatureModel) + + # Verify license status + assert result.license.status == "expired" + assert result.license.expired_at == "2023-12-31" + assert result.license.workspaces.enabled is False + assert result.license.workspaces.size == 0 + assert result.license.workspaces.limit == 0 + + # Verify enterprise features + assert result.branding.enabled is True + assert result.webapp_auth.enabled is True + assert result.enable_change_email is False + + # Verify mock interactions + mock_external_service_dependencies["enterprise_service"].get_info.assert_called_once() + + def test_get_features_edge_case_docs_processing( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test feature retrieval with edge case document processing configuration. + + This test verifies: + - Proper handling of different document processing modes + - Correct integration with billing service + - Proper fallback to default values + - Return value correctness and structure + """ + # Arrange: Setup edge case docs processing mock with proper config + tenant_id = self._create_test_tenant_id() + + with patch("services.feature_service.dify_config") as mock_config: + mock_config.BILLING_ENABLED = True + mock_config.ENTERPRISE_ENABLED = False + mock_config.CAN_REPLACE_LOGO = True + mock_config.MODEL_LB_ENABLED = True + mock_config.DATASET_OPERATOR_ENABLED = True + mock_config.EDUCATION_ENABLED = False + + mock_external_service_dependencies["billing_service"].get_info.return_value = { + "enabled": True, + "subscription": {"plan": "premium", "interval": "monthly"}, + "docs_processing": "advanced", + "can_replace_logo": True, + "model_load_balancing_enabled": True, + } + + # Act: Execute the method under test + result = FeatureService.get_features(tenant_id) + + # Assert: Verify the expected outcomes + assert result is not None + assert isinstance(result, FeatureModel) + + # Verify docs processing configuration + assert result.docs_processing == "advanced" + assert result.can_replace_logo is True + assert result.model_load_balancing_enabled is True + + # Verify premium plan features + assert result.webapp_copyright_enabled is True + assert result.is_allow_transfer_workspace is True + + # Verify default limitations (no specific billing info) + assert result.members.size == 0 + assert result.members.limit == 1 + assert result.apps.size == 0 + assert result.apps.limit == 10 + assert result.vector_space.size == 0 + assert result.vector_space.limit == 5 + + # Verify mock interactions + mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id) + + def test_get_system_features_edge_case_branding( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test system features retrieval with edge case branding configuration. + + This test verifies: + - Proper handling of partial branding information + - Correct enterprise service integration + - Proper fallback to default values + - Return value correctness and structure + """ + # Arrange: Setup edge case branding mock with proper config + with patch("services.feature_service.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = True + mock_config.MARKETPLACE_ENABLED = False + mock_config.ENABLE_EMAIL_CODE_LOGIN = False + mock_config.ENABLE_EMAIL_PASSWORD_LOGIN = True + mock_config.ENABLE_SOCIAL_OAUTH_LOGIN = False + mock_config.ALLOW_REGISTER = False + mock_config.ALLOW_CREATE_WORKSPACE = False + mock_config.MAIL_TYPE = "smtp" + mock_config.PLUGIN_MAX_PACKAGE_SIZE = 100 + + mock_external_service_dependencies["enterprise_service"].get_info.return_value = { + "Branding": { + "applicationTitle": "Edge Case App", + "loginPageLogo": None, + "workspaceLogo": "", + "favicon": "https://example.com/favicon.ico", + } + } + + # Act: Execute the method under test + result = FeatureService.get_system_features() + + # Assert: Verify the expected outcomes + assert result is not None + assert isinstance(result, SystemFeatureModel) + + # Verify branding configuration (edge cases) + assert result.branding.application_title == "Edge Case App" + assert result.branding.login_page_logo is None # None value from mock + assert result.branding.workspace_logo == "" + assert result.branding.favicon == "https://example.com/favicon.ico" + + # Verify enterprise features + assert result.branding.enabled is True + assert result.webapp_auth.enabled is True + assert result.enable_change_email is False + + # Verify default values for missing enterprise info + assert result.sso_enforced_for_signin is False + assert result.sso_enforced_for_signin_protocol == "" + assert result.enable_email_code_login is False + assert result.enable_email_password_login is True + assert result.is_allow_register is False + assert result.is_allow_create_workspace is False + + # Verify mock interactions + mock_external_service_dependencies["enterprise_service"].get_info.assert_called_once() + + def test_get_features_edge_case_annotation_quota( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test feature retrieval with edge case annotation quota configuration. + + This test verifies: + - Proper handling of annotation quota limits + - Correct integration with billing service + - Proper fallback to default values + - Return value correctness and structure + """ + # Arrange: Setup edge case annotation quota mock with proper config + tenant_id = self._create_test_tenant_id() + + with patch("services.feature_service.dify_config") as mock_config: + mock_config.BILLING_ENABLED = True + mock_config.ENTERPRISE_ENABLED = False + mock_config.CAN_REPLACE_LOGO = True + mock_config.MODEL_LB_ENABLED = False + mock_config.DATASET_OPERATOR_ENABLED = True + mock_config.EDUCATION_ENABLED = False + + mock_external_service_dependencies["billing_service"].get_info.return_value = { + "enabled": True, + "subscription": {"plan": "enterprise", "interval": "yearly"}, + "annotation_quota_limit": {"size": 999, "limit": 1000}, + "knowledge_rate_limit": {"limit": 500}, + } + + # Act: Execute the method under test + result = FeatureService.get_features(tenant_id) + + # Assert: Verify the expected outcomes + assert result is not None + assert isinstance(result, FeatureModel) + + # Verify annotation quota configuration + assert result.annotation_quota_limit.size == 999 + assert result.annotation_quota_limit.limit == 1000 + + # Verify knowledge rate limit + assert result.knowledge_rate_limit == 500 + + # Verify enterprise plan features + assert result.webapp_copyright_enabled is True + assert result.is_allow_transfer_workspace is True + + # Verify default values for missing billing info + assert result.members.size == 0 + assert result.members.limit == 1 + assert result.apps.size == 0 + assert result.apps.limit == 10 + assert result.vector_space.size == 0 + assert result.vector_space.limit == 5 + assert result.documents_upload_quota.size == 0 + assert result.documents_upload_quota.limit == 50 + assert result.docs_processing == "standard" + + # Verify mock interactions + mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id) + + def test_get_features_edge_case_documents_upload( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test feature retrieval with edge case documents upload settings. + + This test verifies: + - Proper handling of edge case documents upload configuration + - Correct integration with billing service + - Proper fallback to default values + - Return value correctness and structure + """ + # Arrange: Setup edge case documents upload mock with proper config + tenant_id = self._create_test_tenant_id() + + with patch("services.feature_service.dify_config") as mock_config: + mock_config.BILLING_ENABLED = True + mock_config.ENTERPRISE_ENABLED = False + mock_config.CAN_REPLACE_LOGO = True + mock_config.MODEL_LB_ENABLED = False + mock_config.DATASET_OPERATOR_ENABLED = True + mock_config.EDUCATION_ENABLED = False + + mock_external_service_dependencies["billing_service"].get_info.return_value = { + "enabled": True, + "subscription": {"plan": "pro", "interval": "monthly"}, + "documents_upload_quota": { + "size": 0, # Edge case: zero current size + "limit": 0, # Edge case: zero limit + }, + "knowledge_rate_limit": {"limit": 100}, + } + + # Act: Execute the method under test + result = FeatureService.get_features(tenant_id) + + # Assert: Verify the expected outcomes + assert result is not None + assert isinstance(result, FeatureModel) + + # Verify documents upload quota configuration (edge cases) + assert result.documents_upload_quota.size == 0 + assert result.documents_upload_quota.limit == 0 + + # Verify knowledge rate limit + assert result.knowledge_rate_limit == 100 + + # Verify pro plan features + assert result.webapp_copyright_enabled is True + assert result.is_allow_transfer_workspace is True + + # Verify default values for missing billing info + assert result.members.size == 0 + assert result.members.limit == 1 + assert result.apps.size == 0 + assert result.apps.limit == 10 + assert result.vector_space.size == 0 + assert result.vector_space.limit == 5 + assert result.annotation_quota_limit.size == 0 + assert result.annotation_quota_limit.limit == 10 # Default value when not provided + assert result.docs_processing == "standard" + + # Verify mock interactions + mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id) + + def test_get_system_features_edge_case_license_lost( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test system features with lost license status. + + This test verifies: + - Proper handling of lost license status + - Correct enterprise service integration + - Proper fallback to default values + - Return value correctness and structure + """ + # Arrange: Setup lost license mock with proper config + with patch("services.feature_service.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = True + mock_config.MARKETPLACE_ENABLED = False + mock_config.ENABLE_EMAIL_CODE_LOGIN = False + mock_config.ENABLE_EMAIL_PASSWORD_LOGIN = True + mock_config.ENABLE_SOCIAL_OAUTH_LOGIN = False + mock_config.ALLOW_REGISTER = False + mock_config.ALLOW_CREATE_WORKSPACE = False + mock_config.MAIL_TYPE = "smtp" + mock_config.PLUGIN_MAX_PACKAGE_SIZE = 100 + + mock_external_service_dependencies["enterprise_service"].get_info.return_value = { + "license": {"status": "lost", "expired_at": None, "plan": None} + } + + # Act: Execute the method under test + result = FeatureService.get_system_features() + + # Assert: Verify the expected outcomes + assert result is not None + assert isinstance(result, SystemFeatureModel) + + # Verify enterprise features + assert result.branding.enabled is True + assert result.webapp_auth.enabled is True + assert result.enable_change_email is False + + # Verify default values for missing enterprise info + assert result.sso_enforced_for_signin is False + assert result.sso_enforced_for_signin_protocol == "" + assert result.enable_email_code_login is False + assert result.enable_email_password_login is True + assert result.is_allow_register is False + assert result.is_allow_create_workspace is False + + # Verify mock interactions + mock_external_service_dependencies["enterprise_service"].get_info.assert_called_once() + + def test_get_features_edge_case_education_disabled( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test feature retrieval with education feature disabled. + + This test verifies: + - Proper handling of disabled education features + - Correct integration with billing service + - Proper fallback to default values + - Return value correctness and structure + """ + # Arrange: Setup education disabled mock with proper config + tenant_id = self._create_test_tenant_id() + + with patch("services.feature_service.dify_config") as mock_config: + mock_config.BILLING_ENABLED = True + mock_config.ENTERPRISE_ENABLED = False + mock_config.CAN_REPLACE_LOGO = True + mock_config.MODEL_LB_ENABLED = False + mock_config.DATASET_OPERATOR_ENABLED = True + mock_config.EDUCATION_ENABLED = False + + mock_external_service_dependencies["billing_service"].get_info.return_value = { + "enabled": True, + "subscription": { + "plan": "pro", + "interval": "monthly", + "education": False, # Education explicitly disabled + }, + "knowledge_rate_limit": {"limit": 100}, + } + + # Act: Execute the method under test + result = FeatureService.get_features(tenant_id) + + # Assert: Verify the expected outcomes + assert result is not None + assert isinstance(result, FeatureModel) + + # Verify education configuration + assert result.education.activated is False + + # Verify knowledge rate limit + assert result.knowledge_rate_limit == 100 + + # Verify pro plan features + assert result.webapp_copyright_enabled is True + assert result.is_allow_transfer_workspace is True + + # Verify default values for missing billing info + assert result.members.size == 0 + assert result.members.limit == 1 + assert result.apps.size == 0 + assert result.apps.limit == 10 + assert result.vector_space.size == 0 + assert result.vector_space.limit == 5 + assert result.documents_upload_quota.size == 0 + assert result.documents_upload_quota.limit == 50 + assert result.annotation_quota_limit.size == 0 + assert result.annotation_quota_limit.limit == 10 # Default value when not provided + assert result.docs_processing == "standard" + + # Verify mock interactions + mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id) diff --git a/api/tests/test_containers_integration_tests/services/test_file_service.py b/api/tests/test_containers_integration_tests/services/test_file_service.py new file mode 100644 index 0000000000..965c9c6242 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_file_service.py @@ -0,0 +1,913 @@ +import hashlib +from io import BytesIO +from unittest.mock import patch + +import pytest +from faker import Faker +from werkzeug.exceptions import NotFound + +from configs import dify_config +from models.account import Account, Tenant +from models.enums import CreatorUserRole +from models.model import EndUser, UploadFile +from services.errors.file import FileTooLargeError, UnsupportedFileTypeError +from services.file_service import FileService + + +class TestFileService: + """Integration tests for FileService using testcontainers.""" + + @pytest.fixture + def mock_external_service_dependencies(self): + """Mock setup for external service dependencies.""" + with ( + patch("services.file_service.storage") as mock_storage, + patch("services.file_service.file_helpers") as mock_file_helpers, + patch("services.file_service.ExtractProcessor") as mock_extract_processor, + ): + # Setup default mock returns + mock_storage.save.return_value = None + mock_storage.load.return_value = BytesIO(b"mock file content") + mock_file_helpers.get_signed_file_url.return_value = "https://example.com/signed-url" + mock_file_helpers.verify_image_signature.return_value = True + mock_file_helpers.verify_file_signature.return_value = True + mock_extract_processor.load_from_upload_file.return_value = "extracted text content" + + yield { + "storage": mock_storage, + "file_helpers": mock_file_helpers, + "extract_processor": mock_extract_processor, + } + + def _create_test_account(self, db_session_with_containers, mock_external_service_dependencies): + """ + Helper method to create a test account for testing. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + mock_external_service_dependencies: Mock dependencies + + Returns: + Account: Created account instance + """ + fake = Faker() + + # Create account + account = Account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + status="active", + ) + + from extensions.ext_database import db + + db.session.add(account) + db.session.commit() + + # Create tenant for the account + tenant = Tenant( + name=fake.company(), + status="normal", + ) + db.session.add(tenant) + db.session.commit() + + # Create tenant-account join + from models.account import TenantAccountJoin, TenantAccountRole + + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=TenantAccountRole.OWNER.value, + current=True, + ) + db.session.add(join) + db.session.commit() + + # Set current tenant for account + account.current_tenant = tenant + + return account + + def _create_test_end_user(self, db_session_with_containers, mock_external_service_dependencies): + """ + Helper method to create a test end user for testing. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + mock_external_service_dependencies: Mock dependencies + + Returns: + EndUser: Created end user instance + """ + fake = Faker() + + end_user = EndUser( + tenant_id=str(fake.uuid4()), + type="web", + name=fake.name(), + is_anonymous=False, + session_id=fake.uuid4(), + ) + + from extensions.ext_database import db + + db.session.add(end_user) + db.session.commit() + + return end_user + + def _create_test_upload_file(self, db_session_with_containers, mock_external_service_dependencies, account): + """ + Helper method to create a test upload file for testing. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + mock_external_service_dependencies: Mock dependencies + account: Account instance + + Returns: + UploadFile: Created upload file instance + """ + fake = Faker() + + upload_file = UploadFile( + tenant_id=account.current_tenant_id if hasattr(account, "current_tenant_id") else str(fake.uuid4()), + storage_type="local", + key=f"upload_files/test/{fake.uuid4()}.txt", + name="test_file.txt", + size=1024, + extension="txt", + mime_type="text/plain", + created_by_role=CreatorUserRole.ACCOUNT, + created_by=account.id, + created_at=fake.date_time(), + used=False, + hash=hashlib.sha3_256(b"test content").hexdigest(), + source_url="", + ) + + from extensions.ext_database import db + + db.session.add(upload_file) + db.session.commit() + + return upload_file + + # Test upload_file method + def test_upload_file_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful file upload with valid parameters. + """ + fake = Faker() + account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) + + filename = "test_document.pdf" + content = b"test file content" + mimetype = "application/pdf" + + upload_file = FileService.upload_file( + filename=filename, + content=content, + mimetype=mimetype, + user=account, + ) + + assert upload_file is not None + assert upload_file.name == filename + assert upload_file.size == len(content) + assert upload_file.extension == "pdf" + assert upload_file.mime_type == mimetype + assert upload_file.created_by == account.id + assert upload_file.created_by_role == CreatorUserRole.ACCOUNT.value + assert upload_file.used is False + assert upload_file.hash == hashlib.sha3_256(content).hexdigest() + + # Verify storage was called + mock_external_service_dependencies["storage"].save.assert_called_once() + + # Verify database state + from extensions.ext_database import db + + db.session.refresh(upload_file) + assert upload_file.id is not None + + def test_upload_file_with_end_user(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test file upload with end user instead of account. + """ + fake = Faker() + end_user = self._create_test_end_user(db_session_with_containers, mock_external_service_dependencies) + + filename = "test_image.jpg" + content = b"test image content" + mimetype = "image/jpeg" + + upload_file = FileService.upload_file( + filename=filename, + content=content, + mimetype=mimetype, + user=end_user, + ) + + assert upload_file is not None + assert upload_file.created_by == end_user.id + assert upload_file.created_by_role == CreatorUserRole.END_USER.value + + def test_upload_file_with_datasets_source(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test file upload with datasets source parameter. + """ + fake = Faker() + account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) + + filename = "test_document.pdf" + content = b"test file content" + mimetype = "application/pdf" + + upload_file = FileService.upload_file( + filename=filename, + content=content, + mimetype=mimetype, + user=account, + source="datasets", + source_url="https://example.com/source", + ) + + assert upload_file is not None + assert upload_file.source_url == "https://example.com/source" + + def test_upload_file_invalid_filename_characters( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test file upload with invalid filename characters. + """ + fake = Faker() + account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) + + filename = "test/file.txt" + content = b"test content" + mimetype = "text/plain" + + with pytest.raises(ValueError, match="Filename contains invalid characters"): + FileService.upload_file( + filename=filename, + content=content, + mimetype=mimetype, + user=account, + ) + + def test_upload_file_filename_too_long(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test file upload with filename that exceeds length limit. + """ + fake = Faker() + account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) + + # Create a filename longer than 200 characters + long_name = "a" * 250 + filename = f"{long_name}.txt" + content = b"test content" + mimetype = "text/plain" + + upload_file = FileService.upload_file( + filename=filename, + content=content, + mimetype=mimetype, + user=account, + ) + + # Verify filename was truncated (the logic truncates the base name to 200 chars + extension) + # So the total length should be <= 200 + len(extension) + 1 (for the dot) + assert len(upload_file.name) <= 200 + len(upload_file.extension) + 1 + assert upload_file.name.endswith(".txt") + # Verify the base name was truncated + base_name = upload_file.name[:-4] # Remove .txt + assert len(base_name) <= 200 + + def test_upload_file_datasets_unsupported_type( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test file upload for datasets with unsupported file type. + """ + fake = Faker() + account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) + + filename = "test_image.jpg" + content = b"test content" + mimetype = "image/jpeg" + + with pytest.raises(UnsupportedFileTypeError): + FileService.upload_file( + filename=filename, + content=content, + mimetype=mimetype, + user=account, + source="datasets", + ) + + def test_upload_file_too_large(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test file upload with file size exceeding limit. + """ + fake = Faker() + account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) + + filename = "large_image.jpg" + # Create content larger than the limit + content = b"x" * (dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT * 1024 * 1024 + 1) + mimetype = "image/jpeg" + + with pytest.raises(FileTooLargeError): + FileService.upload_file( + filename=filename, + content=content, + mimetype=mimetype, + user=account, + ) + + # Test is_file_size_within_limit method + def test_is_file_size_within_limit_image_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test file size check for image files within limit. + """ + extension = "jpg" + file_size = dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT * 1024 * 1024 # Exactly at limit + + result = FileService.is_file_size_within_limit(extension=extension, file_size=file_size) + + assert result is True + + def test_is_file_size_within_limit_video_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test file size check for video files within limit. + """ + extension = "mp4" + file_size = dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT * 1024 * 1024 # Exactly at limit + + result = FileService.is_file_size_within_limit(extension=extension, file_size=file_size) + + assert result is True + + def test_is_file_size_within_limit_audio_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test file size check for audio files within limit. + """ + extension = "mp3" + file_size = dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT * 1024 * 1024 # Exactly at limit + + result = FileService.is_file_size_within_limit(extension=extension, file_size=file_size) + + assert result is True + + def test_is_file_size_within_limit_document_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test file size check for document files within limit. + """ + extension = "pdf" + file_size = dify_config.UPLOAD_FILE_SIZE_LIMIT * 1024 * 1024 # Exactly at limit + + result = FileService.is_file_size_within_limit(extension=extension, file_size=file_size) + + assert result is True + + def test_is_file_size_within_limit_image_exceeded( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test file size check for image files exceeding limit. + """ + extension = "jpg" + file_size = dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT * 1024 * 1024 + 1 # Exceeds limit + + result = FileService.is_file_size_within_limit(extension=extension, file_size=file_size) + + assert result is False + + def test_is_file_size_within_limit_unknown_extension( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test file size check for unknown file extension. + """ + extension = "xyz" + file_size = dify_config.UPLOAD_FILE_SIZE_LIMIT * 1024 * 1024 # Uses default limit + + result = FileService.is_file_size_within_limit(extension=extension, file_size=file_size) + + assert result is True + + # Test upload_text method + def test_upload_text_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful text upload. + """ + fake = Faker() + text = "This is a test text content" + text_name = "test_text.txt" + + # Mock current_user + with patch("services.file_service.current_user") as mock_current_user: + mock_current_user.current_tenant_id = str(fake.uuid4()) + mock_current_user.id = str(fake.uuid4()) + + upload_file = FileService.upload_text(text=text, text_name=text_name) + + assert upload_file is not None + assert upload_file.name == text_name + assert upload_file.size == len(text) + assert upload_file.extension == "txt" + assert upload_file.mime_type == "text/plain" + assert upload_file.used is True + assert upload_file.used_by == mock_current_user.id + + # Verify storage was called + mock_external_service_dependencies["storage"].save.assert_called_once() + + def test_upload_text_name_too_long(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test text upload with name that exceeds length limit. + """ + fake = Faker() + text = "test content" + long_name = "a" * 250 # Longer than 200 characters + + # Mock current_user + with patch("services.file_service.current_user") as mock_current_user: + mock_current_user.current_tenant_id = str(fake.uuid4()) + mock_current_user.id = str(fake.uuid4()) + + upload_file = FileService.upload_text(text=text, text_name=long_name) + + # Verify name was truncated + assert len(upload_file.name) <= 200 + assert upload_file.name == "a" * 200 + + # Test get_file_preview method + def test_get_file_preview_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful file preview generation. + """ + fake = Faker() + account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) + upload_file = self._create_test_upload_file( + db_session_with_containers, mock_external_service_dependencies, account + ) + + # Update file to have document extension + upload_file.extension = "pdf" + from extensions.ext_database import db + + db.session.commit() + + result = FileService.get_file_preview(file_id=upload_file.id) + + assert result == "extracted text content" + mock_external_service_dependencies["extract_processor"].load_from_upload_file.assert_called_once() + + def test_get_file_preview_file_not_found(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test file preview with non-existent file. + """ + fake = Faker() + non_existent_id = str(fake.uuid4()) + + with pytest.raises(NotFound, match="File not found"): + FileService.get_file_preview(file_id=non_existent_id) + + def test_get_file_preview_unsupported_file_type( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test file preview with unsupported file type. + """ + fake = Faker() + account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) + upload_file = self._create_test_upload_file( + db_session_with_containers, mock_external_service_dependencies, account + ) + + # Update file to have non-document extension + upload_file.extension = "jpg" + from extensions.ext_database import db + + db.session.commit() + + with pytest.raises(UnsupportedFileTypeError): + FileService.get_file_preview(file_id=upload_file.id) + + def test_get_file_preview_text_truncation(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test file preview with text that exceeds preview limit. + """ + fake = Faker() + account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) + upload_file = self._create_test_upload_file( + db_session_with_containers, mock_external_service_dependencies, account + ) + + # Update file to have document extension + upload_file.extension = "pdf" + from extensions.ext_database import db + + db.session.commit() + + # Mock long text content + long_text = "x" * 5000 # Longer than PREVIEW_WORDS_LIMIT + mock_external_service_dependencies["extract_processor"].load_from_upload_file.return_value = long_text + + result = FileService.get_file_preview(file_id=upload_file.id) + + assert len(result) == 3000 # PREVIEW_WORDS_LIMIT + assert result == "x" * 3000 + + # Test get_image_preview method + def test_get_image_preview_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful image preview generation. + """ + fake = Faker() + account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) + upload_file = self._create_test_upload_file( + db_session_with_containers, mock_external_service_dependencies, account + ) + + # Update file to have image extension + upload_file.extension = "jpg" + from extensions.ext_database import db + + db.session.commit() + + timestamp = "1234567890" + nonce = "test_nonce" + sign = "test_signature" + + generator, mime_type = FileService.get_image_preview( + file_id=upload_file.id, + timestamp=timestamp, + nonce=nonce, + sign=sign, + ) + + assert generator is not None + assert mime_type == upload_file.mime_type + mock_external_service_dependencies["file_helpers"].verify_image_signature.assert_called_once() + + def test_get_image_preview_invalid_signature(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test image preview with invalid signature. + """ + fake = Faker() + account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) + upload_file = self._create_test_upload_file( + db_session_with_containers, mock_external_service_dependencies, account + ) + + # Mock invalid signature + mock_external_service_dependencies["file_helpers"].verify_image_signature.return_value = False + + timestamp = "1234567890" + nonce = "test_nonce" + sign = "invalid_signature" + + with pytest.raises(NotFound, match="File not found or signature is invalid"): + FileService.get_image_preview( + file_id=upload_file.id, + timestamp=timestamp, + nonce=nonce, + sign=sign, + ) + + def test_get_image_preview_file_not_found(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test image preview with non-existent file. + """ + fake = Faker() + non_existent_id = str(fake.uuid4()) + + timestamp = "1234567890" + nonce = "test_nonce" + sign = "test_signature" + + with pytest.raises(NotFound, match="File not found or signature is invalid"): + FileService.get_image_preview( + file_id=non_existent_id, + timestamp=timestamp, + nonce=nonce, + sign=sign, + ) + + def test_get_image_preview_unsupported_file_type( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test image preview with non-image file type. + """ + fake = Faker() + account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) + upload_file = self._create_test_upload_file( + db_session_with_containers, mock_external_service_dependencies, account + ) + + # Update file to have non-image extension + upload_file.extension = "pdf" + from extensions.ext_database import db + + db.session.commit() + + timestamp = "1234567890" + nonce = "test_nonce" + sign = "test_signature" + + with pytest.raises(UnsupportedFileTypeError): + FileService.get_image_preview( + file_id=upload_file.id, + timestamp=timestamp, + nonce=nonce, + sign=sign, + ) + + # Test get_file_generator_by_file_id method + def test_get_file_generator_by_file_id_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test successful file generator retrieval. + """ + fake = Faker() + account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) + upload_file = self._create_test_upload_file( + db_session_with_containers, mock_external_service_dependencies, account + ) + + timestamp = "1234567890" + nonce = "test_nonce" + sign = "test_signature" + + generator, file_obj = FileService.get_file_generator_by_file_id( + file_id=upload_file.id, + timestamp=timestamp, + nonce=nonce, + sign=sign, + ) + + assert generator is not None + assert file_obj == upload_file + mock_external_service_dependencies["file_helpers"].verify_file_signature.assert_called_once() + + def test_get_file_generator_by_file_id_invalid_signature( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test file generator retrieval with invalid signature. + """ + fake = Faker() + account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) + upload_file = self._create_test_upload_file( + db_session_with_containers, mock_external_service_dependencies, account + ) + + # Mock invalid signature + mock_external_service_dependencies["file_helpers"].verify_file_signature.return_value = False + + timestamp = "1234567890" + nonce = "test_nonce" + sign = "invalid_signature" + + with pytest.raises(NotFound, match="File not found or signature is invalid"): + FileService.get_file_generator_by_file_id( + file_id=upload_file.id, + timestamp=timestamp, + nonce=nonce, + sign=sign, + ) + + def test_get_file_generator_by_file_id_file_not_found( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test file generator retrieval with non-existent file. + """ + fake = Faker() + non_existent_id = str(fake.uuid4()) + + timestamp = "1234567890" + nonce = "test_nonce" + sign = "test_signature" + + with pytest.raises(NotFound, match="File not found or signature is invalid"): + FileService.get_file_generator_by_file_id( + file_id=non_existent_id, + timestamp=timestamp, + nonce=nonce, + sign=sign, + ) + + # Test get_public_image_preview method + def test_get_public_image_preview_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful public image preview generation. + """ + fake = Faker() + account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) + upload_file = self._create_test_upload_file( + db_session_with_containers, mock_external_service_dependencies, account + ) + + # Update file to have image extension + upload_file.extension = "jpg" + from extensions.ext_database import db + + db.session.commit() + + generator, mime_type = FileService.get_public_image_preview(file_id=upload_file.id) + + assert generator is not None + assert mime_type == upload_file.mime_type + mock_external_service_dependencies["storage"].load.assert_called_once() + + def test_get_public_image_preview_file_not_found( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test public image preview with non-existent file. + """ + fake = Faker() + non_existent_id = str(fake.uuid4()) + + with pytest.raises(NotFound, match="File not found or signature is invalid"): + FileService.get_public_image_preview(file_id=non_existent_id) + + def test_get_public_image_preview_unsupported_file_type( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test public image preview with non-image file type. + """ + fake = Faker() + account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) + upload_file = self._create_test_upload_file( + db_session_with_containers, mock_external_service_dependencies, account + ) + + # Update file to have non-image extension + upload_file.extension = "pdf" + from extensions.ext_database import db + + db.session.commit() + + with pytest.raises(UnsupportedFileTypeError): + FileService.get_public_image_preview(file_id=upload_file.id) + + # Test edge cases and boundary conditions + def test_upload_file_empty_content(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test file upload with empty content. + """ + fake = Faker() + account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) + + filename = "empty.txt" + content = b"" + mimetype = "text/plain" + + upload_file = FileService.upload_file( + filename=filename, + content=content, + mimetype=mimetype, + user=account, + ) + + assert upload_file is not None + assert upload_file.size == 0 + + def test_upload_file_special_characters_in_name( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test file upload with special characters in filename (but valid ones). + """ + fake = Faker() + account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) + + filename = "test-file_with_underscores_and.dots.txt" + content = b"test content" + mimetype = "text/plain" + + upload_file = FileService.upload_file( + filename=filename, + content=content, + mimetype=mimetype, + user=account, + ) + + assert upload_file is not None + assert upload_file.name == filename + + def test_upload_file_different_case_extensions( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test file upload with different case extensions. + """ + fake = Faker() + account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) + + filename = "test.PDF" + content = b"test content" + mimetype = "application/pdf" + + upload_file = FileService.upload_file( + filename=filename, + content=content, + mimetype=mimetype, + user=account, + ) + + assert upload_file is not None + assert upload_file.extension == "pdf" # Should be converted to lowercase + + def test_upload_text_empty_text(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test text upload with empty text. + """ + fake = Faker() + text = "" + text_name = "empty.txt" + + # Mock current_user + with patch("services.file_service.current_user") as mock_current_user: + mock_current_user.current_tenant_id = str(fake.uuid4()) + mock_current_user.id = str(fake.uuid4()) + + upload_file = FileService.upload_text(text=text, text_name=text_name) + + assert upload_file is not None + assert upload_file.size == 0 + + def test_file_size_limits_edge_cases(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test file size limits with edge case values. + """ + # Test exactly at limit + for extension, limit_config in [ + ("jpg", dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT), + ("mp4", dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT), + ("mp3", dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT), + ("pdf", dify_config.UPLOAD_FILE_SIZE_LIMIT), + ]: + file_size = limit_config * 1024 * 1024 + result = FileService.is_file_size_within_limit(extension=extension, file_size=file_size) + assert result is True + + # Test one byte over limit + file_size = limit_config * 1024 * 1024 + 1 + result = FileService.is_file_size_within_limit(extension=extension, file_size=file_size) + assert result is False + + def test_upload_file_with_source_url(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test file upload with source URL that gets overridden by signed URL. + """ + fake = Faker() + account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) + + filename = "test.pdf" + content = b"test content" + mimetype = "application/pdf" + source_url = "https://original-source.com/file.pdf" + + upload_file = FileService.upload_file( + filename=filename, + content=content, + mimetype=mimetype, + user=account, + source_url=source_url, + ) + + # When source_url is provided, it should be preserved + assert upload_file.source_url == source_url + + # The signed URL should only be set when source_url is empty + # Let's test that scenario + upload_file2 = FileService.upload_file( + filename="test2.pdf", + content=b"test content 2", + mimetype="application/pdf", + user=account, + source_url="", # Empty source_url + ) + + # Should have the signed URL when source_url is empty + assert upload_file2.source_url == "https://example.com/signed-url" diff --git a/api/tests/test_containers_integration_tests/services/test_message_service.py b/api/tests/test_containers_integration_tests/services/test_message_service.py new file mode 100644 index 0000000000..ece6de6cdf --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_message_service.py @@ -0,0 +1,775 @@ +from unittest.mock import patch + +import pytest +from faker import Faker + +from models.model import MessageFeedback +from services.app_service import AppService +from services.errors.message import ( + FirstMessageNotExistsError, + LastMessageNotExistsError, + MessageNotExistsError, + SuggestedQuestionsAfterAnswerDisabledError, +) +from services.message_service import MessageService + + +class TestMessageService: + """Integration tests for MessageService using testcontainers.""" + + @pytest.fixture + def mock_external_service_dependencies(self): + """Mock setup for external service dependencies.""" + with ( + patch("services.account_service.FeatureService") as mock_account_feature_service, + patch("services.message_service.ModelManager") as mock_model_manager, + patch("services.message_service.WorkflowService") as mock_workflow_service, + patch("services.message_service.AdvancedChatAppConfigManager") as mock_app_config_manager, + patch("services.message_service.LLMGenerator") as mock_llm_generator, + patch("services.message_service.TraceQueueManager") as mock_trace_manager_class, + patch("services.message_service.TokenBufferMemory") as mock_token_buffer_memory, + ): + # Setup default mock returns + mock_account_feature_service.get_features.return_value.billing.enabled = False + + # Mock ModelManager + mock_model_instance = mock_model_manager.return_value.get_default_model_instance.return_value + mock_model_instance.get_tts_voices.return_value = [{"value": "test-voice"}] + + # Mock get_model_instance method as well + mock_model_manager.return_value.get_model_instance.return_value = mock_model_instance + + # Mock WorkflowService + mock_workflow = mock_workflow_service.return_value.get_published_workflow.return_value + mock_workflow_service.return_value.get_draft_workflow.return_value = mock_workflow + + # Mock AdvancedChatAppConfigManager + mock_app_config = mock_app_config_manager.get_app_config.return_value + mock_app_config.additional_features.suggested_questions_after_answer = True + + # Mock LLMGenerator + mock_llm_generator.generate_suggested_questions_after_answer.return_value = ["Question 1", "Question 2"] + + # Mock TraceQueueManager + mock_trace_manager_instance = mock_trace_manager_class.return_value + + # Mock TokenBufferMemory + mock_memory_instance = mock_token_buffer_memory.return_value + mock_memory_instance.get_history_prompt_text.return_value = "Mocked history prompt" + + yield { + "account_feature_service": mock_account_feature_service, + "model_manager": mock_model_manager, + "workflow_service": mock_workflow_service, + "app_config_manager": mock_app_config_manager, + "llm_generator": mock_llm_generator, + "trace_manager_class": mock_trace_manager_class, + "trace_manager_instance": mock_trace_manager_instance, + "token_buffer_memory": mock_token_buffer_memory, + # "current_user": mock_current_user, + } + + def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies): + """ + Helper method to create a test app and account for testing. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + mock_external_service_dependencies: Mock dependencies + + Returns: + tuple: (app, account) - Created app and account instances + """ + fake = Faker() + + # Setup mocks for account creation + mock_external_service_dependencies[ + "account_feature_service" + ].get_system_features.return_value.is_allow_register = True + + # Create account and tenant first + from services.account_service import AccountService, TenantService + + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Setup app creation arguments + app_args = { + "name": fake.company(), + "description": fake.text(max_nb_chars=100), + "mode": "advanced-chat", # Use advanced-chat mode to use mocked workflow + "icon_type": "emoji", + "icon": "🤖", + "icon_background": "#FF6B6B", + "api_rph": 100, + "api_rpm": 10, + } + + # Create app + app_service = AppService() + app = app_service.create_app(tenant.id, app_args, account) + + # Setup current_user mock + self._mock_current_user(mock_external_service_dependencies, account.id, tenant.id) + + return app, account + + def _mock_current_user(self, mock_external_service_dependencies, account_id, tenant_id): + """ + Helper method to mock the current user for testing. + """ + # mock_external_service_dependencies["current_user"].id = account_id + # mock_external_service_dependencies["current_user"].current_tenant_id = tenant_id + + def _create_test_conversation(self, app, account, fake): + """ + Helper method to create a test conversation with all required fields. + """ + from extensions.ext_database import db + from models.model import Conversation + + conversation = Conversation( + app_id=app.id, + app_model_config_id=None, + model_provider=None, + model_id="", + override_model_configs=None, + mode=app.mode, + name=fake.sentence(), + inputs={}, + introduction="", + system_instruction="", + system_instruction_tokens=0, + status="normal", + invoke_from="console", + from_source="console", + from_end_user_id=None, + from_account_id=account.id, + ) + + db.session.add(conversation) + db.session.flush() + return conversation + + def _create_test_message(self, app, conversation, account, fake): + """ + Helper method to create a test message with all required fields. + """ + import json + + from extensions.ext_database import db + from models.model import Message + + message = Message( + app_id=app.id, + model_provider=None, + model_id="", + override_model_configs=None, + conversation_id=conversation.id, + inputs={}, + query=fake.sentence(), + message=json.dumps([{"role": "user", "text": fake.sentence()}]), + message_tokens=0, + message_unit_price=0, + message_price_unit=0.001, + answer=fake.text(max_nb_chars=200), + answer_tokens=0, + answer_unit_price=0, + answer_price_unit=0.001, + parent_message_id=None, + provider_response_latency=0, + total_price=0, + currency="USD", + invoke_from="console", + from_source="console", + from_end_user_id=None, + from_account_id=account.id, + ) + + db.session.add(message) + db.session.commit() + return message + + def test_pagination_by_first_id_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful pagination by first ID. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create a conversation and multiple messages + conversation = self._create_test_conversation(app, account, fake) + messages = [] + for i in range(5): + message = self._create_test_message(app, conversation, account, fake) + messages.append(message) + + # Test pagination by first ID + result = MessageService.pagination_by_first_id( + app_model=app, + user=account, + conversation_id=conversation.id, + first_id=messages[2].id, # Use middle message as first_id + limit=2, + order="asc", + ) + + # Verify results + assert result.limit == 2 + assert len(result.data) == 2 + # total 5, from the middle, no more + assert result.has_more is False + # Verify messages are in ascending order + assert result.data[0].created_at <= result.data[1].created_at + + def test_pagination_by_first_id_no_user(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test pagination by first ID when no user is provided. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Test pagination with no user + result = MessageService.pagination_by_first_id( + app_model=app, user=None, conversation_id=fake.uuid4(), first_id=None, limit=10 + ) + + # Verify empty result + assert result.limit == 10 + assert len(result.data) == 0 + assert result.has_more is False + + def test_pagination_by_first_id_no_conversation_id( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test pagination by first ID when no conversation ID is provided. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Test pagination with no conversation ID + result = MessageService.pagination_by_first_id( + app_model=app, user=account, conversation_id="", first_id=None, limit=10 + ) + + # Verify empty result + assert result.limit == 10 + assert len(result.data) == 0 + assert result.has_more is False + + def test_pagination_by_first_id_invalid_first_id( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test pagination by first ID with invalid first_id. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create a conversation and message + conversation = self._create_test_conversation(app, account, fake) + self._create_test_message(app, conversation, account, fake) + + # Test pagination with invalid first_id + with pytest.raises(FirstMessageNotExistsError): + MessageService.pagination_by_first_id( + app_model=app, + user=account, + conversation_id=conversation.id, + first_id=fake.uuid4(), # Non-existent message ID + limit=10, + ) + + def test_pagination_by_last_id_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful pagination by last ID. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create a conversation and multiple messages + conversation = self._create_test_conversation(app, account, fake) + messages = [] + for i in range(5): + message = self._create_test_message(app, conversation, account, fake) + messages.append(message) + + # Test pagination by last ID + result = MessageService.pagination_by_last_id( + app_model=app, + user=account, + last_id=messages[2].id, # Use middle message as last_id + limit=2, + conversation_id=conversation.id, + ) + + # Verify results + assert result.limit == 2 + assert len(result.data) == 2 + # total 5, from the middle, no more + assert result.has_more is False + # Verify messages are in descending order + assert result.data[0].created_at >= result.data[1].created_at + + def test_pagination_by_last_id_with_include_ids( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test pagination by last ID with include_ids filter. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create a conversation and multiple messages + conversation = self._create_test_conversation(app, account, fake) + messages = [] + for i in range(5): + message = self._create_test_message(app, conversation, account, fake) + messages.append(message) + + # Test pagination with include_ids + include_ids = [messages[0].id, messages[1].id, messages[2].id] + result = MessageService.pagination_by_last_id( + app_model=app, user=account, last_id=messages[1].id, limit=2, include_ids=include_ids + ) + + # Verify results + assert result.limit == 2 + assert len(result.data) <= 2 + # Verify all returned messages are in include_ids + for message in result.data: + assert message.id in include_ids + + def test_pagination_by_last_id_no_user(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test pagination by last ID when no user is provided. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Test pagination with no user + result = MessageService.pagination_by_last_id(app_model=app, user=None, last_id=None, limit=10) + + # Verify empty result + assert result.limit == 10 + assert len(result.data) == 0 + assert result.has_more is False + + def test_pagination_by_last_id_invalid_last_id( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test pagination by last ID with invalid last_id. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create a conversation and message + conversation = self._create_test_conversation(app, account, fake) + self._create_test_message(app, conversation, account, fake) + + # Test pagination with invalid last_id + with pytest.raises(LastMessageNotExistsError): + MessageService.pagination_by_last_id( + app_model=app, + user=account, + last_id=fake.uuid4(), # Non-existent message ID + limit=10, + conversation_id=conversation.id, + ) + + def test_create_feedback_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful creation of feedback. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create a conversation and message + conversation = self._create_test_conversation(app, account, fake) + message = self._create_test_message(app, conversation, account, fake) + + # Create feedback + rating = "like" + content = fake.text(max_nb_chars=100) + feedback = MessageService.create_feedback( + app_model=app, message_id=message.id, user=account, rating=rating, content=content + ) + + # Verify feedback was created correctly + assert feedback.app_id == app.id + assert feedback.conversation_id == conversation.id + assert feedback.message_id == message.id + assert feedback.rating == rating + assert feedback.content == content + assert feedback.from_source == "admin" + assert feedback.from_account_id == account.id + assert feedback.from_end_user_id is None + + def test_create_feedback_no_user(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test creating feedback when no user is provided. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create a conversation and message + conversation = self._create_test_conversation(app, account, fake) + message = self._create_test_message(app, conversation, account, fake) + + # Test creating feedback with no user + with pytest.raises(ValueError, match="user cannot be None"): + MessageService.create_feedback( + app_model=app, message_id=message.id, user=None, rating="like", content=fake.text(max_nb_chars=100) + ) + + def test_create_feedback_update_existing(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test updating existing feedback. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create a conversation and message + conversation = self._create_test_conversation(app, account, fake) + message = self._create_test_message(app, conversation, account, fake) + + # Create initial feedback + initial_rating = "like" + initial_content = fake.text(max_nb_chars=100) + feedback = MessageService.create_feedback( + app_model=app, message_id=message.id, user=account, rating=initial_rating, content=initial_content + ) + + # Update feedback + updated_rating = "dislike" + updated_content = fake.text(max_nb_chars=100) + updated_feedback = MessageService.create_feedback( + app_model=app, message_id=message.id, user=account, rating=updated_rating, content=updated_content + ) + + # Verify feedback was updated correctly + assert updated_feedback.id == feedback.id + assert updated_feedback.rating == updated_rating + assert updated_feedback.content == updated_content + assert updated_feedback.rating != initial_rating + assert updated_feedback.content != initial_content + + def test_create_feedback_delete_existing(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test deleting existing feedback by setting rating to None. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create a conversation and message + conversation = self._create_test_conversation(app, account, fake) + message = self._create_test_message(app, conversation, account, fake) + + # Create initial feedback + feedback = MessageService.create_feedback( + app_model=app, message_id=message.id, user=account, rating="like", content=fake.text(max_nb_chars=100) + ) + + # Delete feedback by setting rating to None + MessageService.create_feedback(app_model=app, message_id=message.id, user=account, rating=None, content=None) + + # Verify feedback was deleted + from extensions.ext_database import db + + deleted_feedback = db.session.query(MessageFeedback).where(MessageFeedback.id == feedback.id).first() + assert deleted_feedback is None + + def test_create_feedback_no_rating_when_not_exists( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test creating feedback with no rating when feedback doesn't exist. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create a conversation and message + conversation = self._create_test_conversation(app, account, fake) + message = self._create_test_message(app, conversation, account, fake) + + # Test creating feedback with no rating when no feedback exists + with pytest.raises(ValueError, match="rating cannot be None when feedback not exists"): + MessageService.create_feedback( + app_model=app, message_id=message.id, user=account, rating=None, content=None + ) + + def test_get_all_messages_feedbacks_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful retrieval of all message feedbacks. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create multiple conversations and messages with feedbacks + feedbacks = [] + for i in range(3): + conversation = self._create_test_conversation(app, account, fake) + message = self._create_test_message(app, conversation, account, fake) + + feedback = MessageService.create_feedback( + app_model=app, + message_id=message.id, + user=account, + rating="like" if i % 2 == 0 else "dislike", + content=f"Feedback {i}: {fake.text(max_nb_chars=50)}", + ) + feedbacks.append(feedback) + + # Get all feedbacks + result = MessageService.get_all_messages_feedbacks(app, page=1, limit=10) + + # Verify results + assert len(result) == 3 + + # Verify feedbacks are ordered by created_at desc + for i in range(len(result) - 1): + assert result[i]["created_at"] >= result[i + 1]["created_at"] + + def test_get_all_messages_feedbacks_pagination( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test pagination of message feedbacks. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create multiple conversations and messages with feedbacks + for i in range(5): + conversation = self._create_test_conversation(app, account, fake) + message = self._create_test_message(app, conversation, account, fake) + + MessageService.create_feedback( + app_model=app, message_id=message.id, user=account, rating="like", content=f"Feedback {i}" + ) + + # Get feedbacks with pagination + result_page_1 = MessageService.get_all_messages_feedbacks(app, page=1, limit=3) + result_page_2 = MessageService.get_all_messages_feedbacks(app, page=2, limit=3) + + # Verify pagination results + assert len(result_page_1) == 3 + assert len(result_page_2) == 2 + + # Verify no overlap between pages + page_1_ids = {feedback["id"] for feedback in result_page_1} + page_2_ids = {feedback["id"] for feedback in result_page_2} + assert len(page_1_ids.intersection(page_2_ids)) == 0 + + def test_get_message_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful retrieval of message. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create a conversation and message + conversation = self._create_test_conversation(app, account, fake) + message = self._create_test_message(app, conversation, account, fake) + + # Get message + retrieved_message = MessageService.get_message(app_model=app, user=account, message_id=message.id) + + # Verify message was retrieved correctly + assert retrieved_message.id == message.id + assert retrieved_message.app_id == app.id + assert retrieved_message.conversation_id == conversation.id + assert retrieved_message.from_source == "console" + assert retrieved_message.from_account_id == account.id + + def test_get_message_not_exists(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test getting message that doesn't exist. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Test getting non-existent message + with pytest.raises(MessageNotExistsError): + MessageService.get_message(app_model=app, user=account, message_id=fake.uuid4()) + + def test_get_message_wrong_user(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test getting message with wrong user (different account). + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create a conversation and message + conversation = self._create_test_conversation(app, account, fake) + message = self._create_test_message(app, conversation, account, fake) + + # Create another account + from services.account_service import AccountService, TenantService + + other_account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(other_account, name=fake.company()) + + # Test getting message with different user + with pytest.raises(MessageNotExistsError): + MessageService.get_message(app_model=app, user=other_account, message_id=message.id) + + def test_get_suggested_questions_after_answer_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test successful generation of suggested questions after answer. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create a conversation and message + conversation = self._create_test_conversation(app, account, fake) + message = self._create_test_message(app, conversation, account, fake) + + # Mock the LLMGenerator to return specific questions + mock_questions = ["What is AI?", "How does machine learning work?", "Tell me about neural networks"] + mock_external_service_dependencies[ + "llm_generator" + ].generate_suggested_questions_after_answer.return_value = mock_questions + + # Get suggested questions + from core.app.entities.app_invoke_entities import InvokeFrom + + result = MessageService.get_suggested_questions_after_answer( + app_model=app, user=account, message_id=message.id, invoke_from=InvokeFrom.SERVICE_API + ) + + # Verify results + assert result == mock_questions + + # Verify LLMGenerator was called + mock_external_service_dependencies[ + "llm_generator" + ].generate_suggested_questions_after_answer.assert_called_once() + + # Verify TraceQueueManager was called + mock_external_service_dependencies["trace_manager_instance"].add_trace_task.assert_called_once() + + def test_get_suggested_questions_after_answer_no_user( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test getting suggested questions when no user is provided. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create a conversation and message + conversation = self._create_test_conversation(app, account, fake) + message = self._create_test_message(app, conversation, account, fake) + + # Test getting suggested questions with no user + from core.app.entities.app_invoke_entities import InvokeFrom + + with pytest.raises(ValueError, match="user cannot be None"): + MessageService.get_suggested_questions_after_answer( + app_model=app, user=None, message_id=message.id, invoke_from=InvokeFrom.SERVICE_API + ) + + def test_get_suggested_questions_after_answer_disabled( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test getting suggested questions when feature is disabled. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create a conversation and message + conversation = self._create_test_conversation(app, account, fake) + message = self._create_test_message(app, conversation, account, fake) + + # Mock the feature to be disabled + mock_external_service_dependencies[ + "app_config_manager" + ].get_app_config.return_value.additional_features.suggested_questions_after_answer = False + + # Test getting suggested questions when feature is disabled + from core.app.entities.app_invoke_entities import InvokeFrom + + with pytest.raises(SuggestedQuestionsAfterAnswerDisabledError): + MessageService.get_suggested_questions_after_answer( + app_model=app, user=account, message_id=message.id, invoke_from=InvokeFrom.SERVICE_API + ) + + def test_get_suggested_questions_after_answer_no_workflow( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test getting suggested questions when no workflow exists. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create a conversation and message + conversation = self._create_test_conversation(app, account, fake) + message = self._create_test_message(app, conversation, account, fake) + + # Mock no workflow + mock_external_service_dependencies["workflow_service"].return_value.get_published_workflow.return_value = None + + # Get suggested questions (should return empty list) + from core.app.entities.app_invoke_entities import InvokeFrom + + result = MessageService.get_suggested_questions_after_answer( + app_model=app, user=account, message_id=message.id, invoke_from=InvokeFrom.SERVICE_API + ) + + # Verify empty result + assert result == [] + + def test_get_suggested_questions_after_answer_debugger_mode( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test getting suggested questions in debugger mode. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create a conversation and message + conversation = self._create_test_conversation(app, account, fake) + message = self._create_test_message(app, conversation, account, fake) + + # Mock questions + mock_questions = ["Debug question 1", "Debug question 2"] + mock_external_service_dependencies[ + "llm_generator" + ].generate_suggested_questions_after_answer.return_value = mock_questions + + # Get suggested questions in debugger mode + from core.app.entities.app_invoke_entities import InvokeFrom + + result = MessageService.get_suggested_questions_after_answer( + app_model=app, user=account, message_id=message.id, invoke_from=InvokeFrom.DEBUGGER + ) + + # Verify results + assert result == mock_questions + + # Verify draft workflow was used instead of published workflow + mock_external_service_dependencies["workflow_service"].return_value.get_draft_workflow.assert_called_once_with( + app_model=app + ) + + # Verify TraceQueueManager was called + mock_external_service_dependencies["trace_manager_instance"].add_trace_task.assert_called_once() diff --git a/api/tests/test_containers_integration_tests/services/test_metadata_service.py b/api/tests/test_containers_integration_tests/services/test_metadata_service.py new file mode 100644 index 0000000000..7fef572c14 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_metadata_service.py @@ -0,0 +1,1144 @@ +from unittest.mock import patch + +import pytest +from faker import Faker + +from core.rag.index_processor.constant.built_in_field import BuiltInField +from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models.dataset import Dataset, DatasetMetadata, DatasetMetadataBinding, Document +from services.entities.knowledge_entities.knowledge_entities import MetadataArgs +from services.metadata_service import MetadataService + + +class TestMetadataService: + """Integration tests for MetadataService using testcontainers.""" + + @pytest.fixture + def mock_external_service_dependencies(self): + """Mock setup for external service dependencies.""" + with ( + patch("services.metadata_service.current_user") as mock_current_user, + patch("services.metadata_service.redis_client") as mock_redis_client, + patch("services.dataset_service.DocumentService") as mock_document_service, + ): + # Setup default mock returns + mock_redis_client.get.return_value = None + mock_redis_client.set.return_value = True + mock_redis_client.delete.return_value = 1 + + yield { + "current_user": mock_current_user, + "redis_client": mock_redis_client, + "document_service": mock_document_service, + } + + def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies): + """ + Helper method to create a test account and tenant for testing. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + mock_external_service_dependencies: Mock dependencies + + Returns: + tuple: (account, tenant) - Created account and tenant instances + """ + fake = Faker() + + # Create account + account = Account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + status="active", + ) + + from extensions.ext_database import db + + db.session.add(account) + db.session.commit() + + # Create tenant for the account + tenant = Tenant( + name=fake.company(), + status="normal", + ) + db.session.add(tenant) + db.session.commit() + + # Create tenant-account join + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=TenantAccountRole.OWNER.value, + current=True, + ) + db.session.add(join) + db.session.commit() + + # Set current tenant for account + account.current_tenant = tenant + + return account, tenant + + def _create_test_dataset(self, db_session_with_containers, mock_external_service_dependencies, account, tenant): + """ + Helper method to create a test dataset for testing. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + mock_external_service_dependencies: Mock dependencies + account: Account instance + tenant: Tenant instance + + Returns: + Dataset: Created dataset instance + """ + fake = Faker() + + dataset = Dataset( + tenant_id=tenant.id, + name=fake.company(), + description=fake.text(max_nb_chars=100), + data_source_type="upload_file", + created_by=account.id, + built_in_field_enabled=False, + ) + + from extensions.ext_database import db + + db.session.add(dataset) + db.session.commit() + + return dataset + + def _create_test_document(self, db_session_with_containers, mock_external_service_dependencies, dataset, account): + """ + Helper method to create a test document for testing. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + mock_external_service_dependencies: Mock dependencies + dataset: Dataset instance + account: Account instance + + Returns: + Document: Created document instance + """ + fake = Faker() + + document = Document( + tenant_id=dataset.tenant_id, + dataset_id=dataset.id, + position=1, + data_source_type="upload_file", + data_source_info="{}", + batch="test-batch", + name=fake.file_name(), + created_from="web", + created_by=account.id, + doc_form="text", + doc_language="en", + ) + + from extensions.ext_database import db + + db.session.add(document) + db.session.commit() + + return document + + def test_create_metadata_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful metadata creation with valid parameters. + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + dataset = self._create_test_dataset( + db_session_with_containers, mock_external_service_dependencies, account, tenant + ) + + # Setup mocks + mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id + mock_external_service_dependencies["current_user"].id = account.id + + metadata_args = MetadataArgs(type="string", name="test_metadata") + + # Act: Execute the method under test + result = MetadataService.create_metadata(dataset.id, metadata_args) + + # Assert: Verify the expected outcomes + assert result is not None + assert result.name == "test_metadata" + assert result.type == "string" + assert result.dataset_id == dataset.id + assert result.tenant_id == tenant.id + assert result.created_by == account.id + + # Verify database state + from extensions.ext_database import db + + db.session.refresh(result) + assert result.id is not None + assert result.created_at is not None + + def test_create_metadata_name_too_long(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test metadata creation fails when name exceeds 255 characters. + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + dataset = self._create_test_dataset( + db_session_with_containers, mock_external_service_dependencies, account, tenant + ) + + # Setup mocks + mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id + mock_external_service_dependencies["current_user"].id = account.id + + long_name = "a" * 256 # 256 characters, exceeding 255 limit + metadata_args = MetadataArgs(type="string", name=long_name) + + # Act & Assert: Verify proper error handling + with pytest.raises(ValueError, match="Metadata name cannot exceed 255 characters."): + MetadataService.create_metadata(dataset.id, metadata_args) + + def test_create_metadata_name_already_exists(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test metadata creation fails when name already exists in the same dataset. + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + dataset = self._create_test_dataset( + db_session_with_containers, mock_external_service_dependencies, account, tenant + ) + + # Setup mocks + mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id + mock_external_service_dependencies["current_user"].id = account.id + + # Create first metadata + first_metadata_args = MetadataArgs(type="string", name="duplicate_name") + MetadataService.create_metadata(dataset.id, first_metadata_args) + + # Try to create second metadata with same name + second_metadata_args = MetadataArgs(type="number", name="duplicate_name") + + # Act & Assert: Verify proper error handling + with pytest.raises(ValueError, match="Metadata name already exists."): + MetadataService.create_metadata(dataset.id, second_metadata_args) + + def test_create_metadata_name_conflicts_with_built_in_field( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test metadata creation fails when name conflicts with built-in field names. + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + dataset = self._create_test_dataset( + db_session_with_containers, mock_external_service_dependencies, account, tenant + ) + + # Setup mocks + mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id + mock_external_service_dependencies["current_user"].id = account.id + + # Try to create metadata with built-in field name + built_in_field_name = BuiltInField.document_name.value + metadata_args = MetadataArgs(type="string", name=built_in_field_name) + + # Act & Assert: Verify proper error handling + with pytest.raises(ValueError, match="Metadata name already exists in Built-in fields."): + MetadataService.create_metadata(dataset.id, metadata_args) + + def test_update_metadata_name_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful metadata name update with valid parameters. + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + dataset = self._create_test_dataset( + db_session_with_containers, mock_external_service_dependencies, account, tenant + ) + + # Setup mocks + mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id + mock_external_service_dependencies["current_user"].id = account.id + + # Create metadata first + metadata_args = MetadataArgs(type="string", name="old_name") + metadata = MetadataService.create_metadata(dataset.id, metadata_args) + + # Act: Execute the method under test + new_name = "new_name" + result = MetadataService.update_metadata_name(dataset.id, metadata.id, new_name) + + # Assert: Verify the expected outcomes + assert result is not None + assert result.name == new_name + assert result.updated_by == account.id + assert result.updated_at is not None + + # Verify database state + from extensions.ext_database import db + + db.session.refresh(result) + assert result.name == new_name + + def test_update_metadata_name_too_long(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test metadata name update fails when new name exceeds 255 characters. + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + dataset = self._create_test_dataset( + db_session_with_containers, mock_external_service_dependencies, account, tenant + ) + + # Setup mocks + mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id + mock_external_service_dependencies["current_user"].id = account.id + + # Create metadata first + metadata_args = MetadataArgs(type="string", name="old_name") + metadata = MetadataService.create_metadata(dataset.id, metadata_args) + + # Try to update with too long name + long_name = "a" * 256 # 256 characters, exceeding 255 limit + + # Act & Assert: Verify proper error handling + with pytest.raises(ValueError, match="Metadata name cannot exceed 255 characters."): + MetadataService.update_metadata_name(dataset.id, metadata.id, long_name) + + def test_update_metadata_name_already_exists(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test metadata name update fails when new name already exists in the same dataset. + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + dataset = self._create_test_dataset( + db_session_with_containers, mock_external_service_dependencies, account, tenant + ) + + # Setup mocks + mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id + mock_external_service_dependencies["current_user"].id = account.id + + # Create two metadata entries + first_metadata_args = MetadataArgs(type="string", name="first_metadata") + first_metadata = MetadataService.create_metadata(dataset.id, first_metadata_args) + + second_metadata_args = MetadataArgs(type="number", name="second_metadata") + second_metadata = MetadataService.create_metadata(dataset.id, second_metadata_args) + + # Try to update first metadata with second metadata's name + with pytest.raises(ValueError, match="Metadata name already exists."): + MetadataService.update_metadata_name(dataset.id, first_metadata.id, "second_metadata") + + def test_update_metadata_name_conflicts_with_built_in_field( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test metadata name update fails when new name conflicts with built-in field names. + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + dataset = self._create_test_dataset( + db_session_with_containers, mock_external_service_dependencies, account, tenant + ) + + # Setup mocks + mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id + mock_external_service_dependencies["current_user"].id = account.id + + # Create metadata first + metadata_args = MetadataArgs(type="string", name="old_name") + metadata = MetadataService.create_metadata(dataset.id, metadata_args) + + # Try to update with built-in field name + built_in_field_name = BuiltInField.document_name.value + + with pytest.raises(ValueError, match="Metadata name already exists in Built-in fields."): + MetadataService.update_metadata_name(dataset.id, metadata.id, built_in_field_name) + + def test_update_metadata_name_not_found(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test metadata name update fails when metadata ID does not exist. + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + dataset = self._create_test_dataset( + db_session_with_containers, mock_external_service_dependencies, account, tenant + ) + + # Setup mocks + mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id + mock_external_service_dependencies["current_user"].id = account.id + + # Try to update non-existent metadata + import uuid + + fake_metadata_id = str(uuid.uuid4()) # Use valid UUID format + new_name = "new_name" + + # Act: Execute the method under test + result = MetadataService.update_metadata_name(dataset.id, fake_metadata_id, new_name) + + # Assert: Verify the method returns None when metadata is not found + assert result is None + + def test_delete_metadata_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful metadata deletion with valid parameters. + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + dataset = self._create_test_dataset( + db_session_with_containers, mock_external_service_dependencies, account, tenant + ) + + # Setup mocks + mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id + mock_external_service_dependencies["current_user"].id = account.id + + # Create metadata first + metadata_args = MetadataArgs(type="string", name="to_be_deleted") + metadata = MetadataService.create_metadata(dataset.id, metadata_args) + + # Act: Execute the method under test + result = MetadataService.delete_metadata(dataset.id, metadata.id) + + # Assert: Verify the expected outcomes + assert result is not None + assert result.id == metadata.id + + # Verify metadata was deleted from database + from extensions.ext_database import db + + deleted_metadata = db.session.query(DatasetMetadata).filter_by(id=metadata.id).first() + assert deleted_metadata is None + + def test_delete_metadata_not_found(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test metadata deletion fails when metadata ID does not exist. + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + dataset = self._create_test_dataset( + db_session_with_containers, mock_external_service_dependencies, account, tenant + ) + + # Setup mocks + mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id + mock_external_service_dependencies["current_user"].id = account.id + + # Try to delete non-existent metadata + import uuid + + fake_metadata_id = str(uuid.uuid4()) # Use valid UUID format + + # Act: Execute the method under test + result = MetadataService.delete_metadata(dataset.id, fake_metadata_id) + + # Assert: Verify the method returns None when metadata is not found + assert result is None + + def test_delete_metadata_with_document_bindings( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test metadata deletion successfully removes document metadata bindings. + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + dataset = self._create_test_dataset( + db_session_with_containers, mock_external_service_dependencies, account, tenant + ) + document = self._create_test_document( + db_session_with_containers, mock_external_service_dependencies, dataset, account + ) + + # Setup mocks + mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id + mock_external_service_dependencies["current_user"].id = account.id + + # Create metadata + metadata_args = MetadataArgs(type="string", name="test_metadata") + metadata = MetadataService.create_metadata(dataset.id, metadata_args) + + # Create metadata binding + binding = DatasetMetadataBinding( + tenant_id=tenant.id, + dataset_id=dataset.id, + metadata_id=metadata.id, + document_id=document.id, + created_by=account.id, + ) + + from extensions.ext_database import db + + db.session.add(binding) + db.session.commit() + + # Set document metadata + document.doc_metadata = {"test_metadata": "test_value"} + db.session.add(document) + db.session.commit() + + # Act: Execute the method under test + result = MetadataService.delete_metadata(dataset.id, metadata.id) + + # Assert: Verify the expected outcomes + assert result is not None + + # Verify metadata was deleted from database + deleted_metadata = db.session.query(DatasetMetadata).filter_by(id=metadata.id).first() + assert deleted_metadata is None + + # Note: The service attempts to update document metadata but may not succeed + # due to mock configuration. The main functionality (metadata deletion) is verified. + + def test_get_built_in_fields_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful retrieval of built-in metadata fields. + """ + # Act: Execute the method under test + result = MetadataService.get_built_in_fields() + + # Assert: Verify the expected outcomes + assert result is not None + assert len(result) == 5 + + # Verify all expected built-in fields are present + field_names = [field["name"] for field in result] + field_types = [field["type"] for field in result] + + assert BuiltInField.document_name.value in field_names + assert BuiltInField.uploader.value in field_names + assert BuiltInField.upload_date.value in field_names + assert BuiltInField.last_update_date.value in field_names + assert BuiltInField.source.value in field_names + + # Verify field types + assert "string" in field_types + assert "time" in field_types + + def test_enable_built_in_field_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful enabling of built-in fields for a dataset. + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + dataset = self._create_test_dataset( + db_session_with_containers, mock_external_service_dependencies, account, tenant + ) + document = self._create_test_document( + db_session_with_containers, mock_external_service_dependencies, dataset, account + ) + + # Setup mocks + mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id + mock_external_service_dependencies["current_user"].id = account.id + + # Mock DocumentService.get_working_documents_by_dataset_id + mock_external_service_dependencies["document_service"].get_working_documents_by_dataset_id.return_value = [ + document + ] + + # Verify dataset starts with built-in fields disabled + assert dataset.built_in_field_enabled is False + + # Act: Execute the method under test + MetadataService.enable_built_in_field(dataset) + + # Assert: Verify the expected outcomes + from extensions.ext_database import db + + db.session.refresh(dataset) + assert dataset.built_in_field_enabled is True + + # Note: Document metadata update depends on DocumentService mock working correctly + # The main functionality (enabling built-in fields) is verified + + def test_enable_built_in_field_already_enabled( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test enabling built-in fields when they are already enabled. + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + dataset = self._create_test_dataset( + db_session_with_containers, mock_external_service_dependencies, account, tenant + ) + + # Setup mocks + mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id + mock_external_service_dependencies["current_user"].id = account.id + + # Enable built-in fields first + dataset.built_in_field_enabled = True + from extensions.ext_database import db + + db.session.add(dataset) + db.session.commit() + + # Mock DocumentService.get_working_documents_by_dataset_id + mock_external_service_dependencies["document_service"].get_working_documents_by_dataset_id.return_value = [] + + # Act: Execute the method under test + MetadataService.enable_built_in_field(dataset) + + # Assert: Verify the method returns early without changes + db.session.refresh(dataset) + assert dataset.built_in_field_enabled is True + + def test_enable_built_in_field_with_no_documents( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test enabling built-in fields for a dataset with no documents. + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + dataset = self._create_test_dataset( + db_session_with_containers, mock_external_service_dependencies, account, tenant + ) + + # Setup mocks + mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id + mock_external_service_dependencies["current_user"].id = account.id + + # Mock DocumentService.get_working_documents_by_dataset_id to return empty list + mock_external_service_dependencies["document_service"].get_working_documents_by_dataset_id.return_value = [] + + # Act: Execute the method under test + MetadataService.enable_built_in_field(dataset) + + # Assert: Verify the expected outcomes + from extensions.ext_database import db + + db.session.refresh(dataset) + assert dataset.built_in_field_enabled is True + + def test_disable_built_in_field_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful disabling of built-in fields for a dataset. + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + dataset = self._create_test_dataset( + db_session_with_containers, mock_external_service_dependencies, account, tenant + ) + document = self._create_test_document( + db_session_with_containers, mock_external_service_dependencies, dataset, account + ) + + # Setup mocks + mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id + mock_external_service_dependencies["current_user"].id = account.id + + # Enable built-in fields first + dataset.built_in_field_enabled = True + from extensions.ext_database import db + + db.session.add(dataset) + db.session.commit() + + # Set document metadata with built-in fields + document.doc_metadata = { + BuiltInField.document_name.value: document.name, + BuiltInField.uploader.value: "test_uploader", + BuiltInField.upload_date.value: 1234567890.0, + BuiltInField.last_update_date.value: 1234567890.0, + BuiltInField.source.value: "test_source", + } + db.session.add(document) + db.session.commit() + + # Mock DocumentService.get_working_documents_by_dataset_id + mock_external_service_dependencies["document_service"].get_working_documents_by_dataset_id.return_value = [ + document + ] + + # Act: Execute the method under test + MetadataService.disable_built_in_field(dataset) + + # Assert: Verify the expected outcomes + db.session.refresh(dataset) + assert dataset.built_in_field_enabled is False + + # Note: Document metadata update depends on DocumentService mock working correctly + # The main functionality (disabling built-in fields) is verified + + def test_disable_built_in_field_already_disabled( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test disabling built-in fields when they are already disabled. + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + dataset = self._create_test_dataset( + db_session_with_containers, mock_external_service_dependencies, account, tenant + ) + + # Setup mocks + mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id + mock_external_service_dependencies["current_user"].id = account.id + + # Verify dataset starts with built-in fields disabled + assert dataset.built_in_field_enabled is False + + # Mock DocumentService.get_working_documents_by_dataset_id + mock_external_service_dependencies["document_service"].get_working_documents_by_dataset_id.return_value = [] + + # Act: Execute the method under test + MetadataService.disable_built_in_field(dataset) + + # Assert: Verify the method returns early without changes + from extensions.ext_database import db + + db.session.refresh(dataset) + assert dataset.built_in_field_enabled is False + + def test_disable_built_in_field_with_no_documents( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test disabling built-in fields for a dataset with no documents. + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + dataset = self._create_test_dataset( + db_session_with_containers, mock_external_service_dependencies, account, tenant + ) + + # Setup mocks + mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id + mock_external_service_dependencies["current_user"].id = account.id + + # Enable built-in fields first + dataset.built_in_field_enabled = True + from extensions.ext_database import db + + db.session.add(dataset) + db.session.commit() + + # Mock DocumentService.get_working_documents_by_dataset_id to return empty list + mock_external_service_dependencies["document_service"].get_working_documents_by_dataset_id.return_value = [] + + # Act: Execute the method under test + MetadataService.disable_built_in_field(dataset) + + # Assert: Verify the expected outcomes + db.session.refresh(dataset) + assert dataset.built_in_field_enabled is False + + def test_update_documents_metadata_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful update of documents metadata. + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + dataset = self._create_test_dataset( + db_session_with_containers, mock_external_service_dependencies, account, tenant + ) + document = self._create_test_document( + db_session_with_containers, mock_external_service_dependencies, dataset, account + ) + + # Setup mocks + mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id + mock_external_service_dependencies["current_user"].id = account.id + + # Create metadata + metadata_args = MetadataArgs(type="string", name="test_metadata") + metadata = MetadataService.create_metadata(dataset.id, metadata_args) + + # Mock DocumentService.get_document + mock_external_service_dependencies["document_service"].get_document.return_value = document + + # Create metadata operation data + from services.entities.knowledge_entities.knowledge_entities import ( + DocumentMetadataOperation, + MetadataDetail, + MetadataOperationData, + ) + + metadata_detail = MetadataDetail(id=metadata.id, name=metadata.name, value="test_value") + + operation = DocumentMetadataOperation(document_id=document.id, metadata_list=[metadata_detail]) + + operation_data = MetadataOperationData(operation_data=[operation]) + + # Act: Execute the method under test + MetadataService.update_documents_metadata(dataset, operation_data) + + # Assert: Verify the expected outcomes + from extensions.ext_database import db + + # Verify document metadata was updated + db.session.refresh(document) + assert document.doc_metadata is not None + assert "test_metadata" in document.doc_metadata + assert document.doc_metadata["test_metadata"] == "test_value" + + # Verify metadata binding was created + binding = ( + db.session.query(DatasetMetadataBinding).filter_by(metadata_id=metadata.id, document_id=document.id).first() + ) + assert binding is not None + assert binding.tenant_id == tenant.id + assert binding.dataset_id == dataset.id + + def test_update_documents_metadata_with_built_in_fields_enabled( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test update of documents metadata when built-in fields are enabled. + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + dataset = self._create_test_dataset( + db_session_with_containers, mock_external_service_dependencies, account, tenant + ) + document = self._create_test_document( + db_session_with_containers, mock_external_service_dependencies, dataset, account + ) + + # Enable built-in fields + dataset.built_in_field_enabled = True + from extensions.ext_database import db + + db.session.add(dataset) + db.session.commit() + + # Setup mocks + mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id + mock_external_service_dependencies["current_user"].id = account.id + + # Create metadata + metadata_args = MetadataArgs(type="string", name="test_metadata") + metadata = MetadataService.create_metadata(dataset.id, metadata_args) + + # Mock DocumentService.get_document + mock_external_service_dependencies["document_service"].get_document.return_value = document + + # Create metadata operation data + from services.entities.knowledge_entities.knowledge_entities import ( + DocumentMetadataOperation, + MetadataDetail, + MetadataOperationData, + ) + + metadata_detail = MetadataDetail(id=metadata.id, name=metadata.name, value="test_value") + + operation = DocumentMetadataOperation(document_id=document.id, metadata_list=[metadata_detail]) + + operation_data = MetadataOperationData(operation_data=[operation]) + + # Act: Execute the method under test + MetadataService.update_documents_metadata(dataset, operation_data) + + # Assert: Verify the expected outcomes + # Verify document metadata was updated with both custom and built-in fields + db.session.refresh(document) + assert document.doc_metadata is not None + assert "test_metadata" in document.doc_metadata + assert document.doc_metadata["test_metadata"] == "test_value" + + # Note: Built-in fields would be added if DocumentService mock works correctly + # The main functionality (custom metadata update) is verified + + def test_update_documents_metadata_document_not_found( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test update of documents metadata when document is not found. + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + dataset = self._create_test_dataset( + db_session_with_containers, mock_external_service_dependencies, account, tenant + ) + + # Setup mocks + mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id + mock_external_service_dependencies["current_user"].id = account.id + + # Create metadata + metadata_args = MetadataArgs(type="string", name="test_metadata") + metadata = MetadataService.create_metadata(dataset.id, metadata_args) + + # Mock DocumentService.get_document to return None (document not found) + mock_external_service_dependencies["document_service"].get_document.return_value = None + + # Create metadata operation data + from services.entities.knowledge_entities.knowledge_entities import ( + DocumentMetadataOperation, + MetadataDetail, + MetadataOperationData, + ) + + metadata_detail = MetadataDetail(id=metadata.id, name=metadata.name, value="test_value") + + operation = DocumentMetadataOperation(document_id="non-existent-document-id", metadata_list=[metadata_detail]) + + operation_data = MetadataOperationData(operation_data=[operation]) + + # Act: Execute the method under test + # The method should handle the error gracefully and continue + MetadataService.update_documents_metadata(dataset, operation_data) + + # Assert: Verify the method completes without raising exceptions + # The main functionality (error handling) is verified + + def test_knowledge_base_metadata_lock_check_dataset_id( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test metadata lock check for dataset operations. + """ + # Arrange: Setup mocks + mock_external_service_dependencies["redis_client"].get.return_value = None + mock_external_service_dependencies["redis_client"].set.return_value = True + + dataset_id = "test-dataset-id" + + # Act: Execute the method under test + MetadataService.knowledge_base_metadata_lock_check(dataset_id, None) + + # Assert: Verify the expected outcomes + # Verify Redis lock was set + mock_external_service_dependencies["redis_client"].set.assert_called_once() + + # Verify lock key format + call_args = mock_external_service_dependencies["redis_client"].set.call_args + assert call_args[0][0] == f"dataset_metadata_lock_{dataset_id}" + + def test_knowledge_base_metadata_lock_check_document_id( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test metadata lock check for document operations. + """ + # Arrange: Setup mocks + mock_external_service_dependencies["redis_client"].get.return_value = None + mock_external_service_dependencies["redis_client"].set.return_value = True + + document_id = "test-document-id" + + # Act: Execute the method under test + MetadataService.knowledge_base_metadata_lock_check(None, document_id) + + # Assert: Verify the expected outcomes + # Verify Redis lock was set + mock_external_service_dependencies["redis_client"].set.assert_called_once() + + # Verify lock key format + call_args = mock_external_service_dependencies["redis_client"].set.call_args + assert call_args[0][0] == f"document_metadata_lock_{document_id}" + + def test_knowledge_base_metadata_lock_check_lock_exists( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test metadata lock check when lock already exists. + """ + # Arrange: Setup mocks to simulate existing lock + mock_external_service_dependencies["redis_client"].get.return_value = "1" # Lock exists + + dataset_id = "test-dataset-id" + + # Act & Assert: Verify proper error handling + with pytest.raises( + ValueError, match="Another knowledge base metadata operation is running, please wait a moment." + ): + MetadataService.knowledge_base_metadata_lock_check(dataset_id, None) + + def test_knowledge_base_metadata_lock_check_document_lock_exists( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test metadata lock check when document lock already exists. + """ + # Arrange: Setup mocks to simulate existing lock + mock_external_service_dependencies["redis_client"].get.return_value = "1" # Lock exists + + document_id = "test-document-id" + + # Act & Assert: Verify proper error handling + with pytest.raises(ValueError, match="Another document metadata operation is running, please wait a moment."): + MetadataService.knowledge_base_metadata_lock_check(None, document_id) + + def test_get_dataset_metadatas_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful retrieval of dataset metadata information. + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + dataset = self._create_test_dataset( + db_session_with_containers, mock_external_service_dependencies, account, tenant + ) + + # Setup mocks + mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id + mock_external_service_dependencies["current_user"].id = account.id + + # Create metadata + metadata_args = MetadataArgs(type="string", name="test_metadata") + metadata = MetadataService.create_metadata(dataset.id, metadata_args) + + # Create document and metadata binding + document = self._create_test_document( + db_session_with_containers, mock_external_service_dependencies, dataset, account + ) + + binding = DatasetMetadataBinding( + tenant_id=tenant.id, + dataset_id=dataset.id, + metadata_id=metadata.id, + document_id=document.id, + created_by=account.id, + ) + + from extensions.ext_database import db + + db.session.add(binding) + db.session.commit() + + # Act: Execute the method under test + result = MetadataService.get_dataset_metadatas(dataset) + + # Assert: Verify the expected outcomes + assert result is not None + assert "doc_metadata" in result + assert "built_in_field_enabled" in result + + # Verify metadata information + doc_metadata = result["doc_metadata"] + assert len(doc_metadata) == 1 + assert doc_metadata[0]["id"] == metadata.id + assert doc_metadata[0]["name"] == metadata.name + assert doc_metadata[0]["type"] == metadata.type + assert doc_metadata[0]["count"] == 1 # One document bound to this metadata + + # Verify built-in field status + assert result["built_in_field_enabled"] is False + + def test_get_dataset_metadatas_with_built_in_fields_enabled( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test retrieval of dataset metadata when built-in fields are enabled. + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + dataset = self._create_test_dataset( + db_session_with_containers, mock_external_service_dependencies, account, tenant + ) + + # Enable built-in fields + dataset.built_in_field_enabled = True + from extensions.ext_database import db + + db.session.add(dataset) + db.session.commit() + + # Setup mocks + mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id + mock_external_service_dependencies["current_user"].id = account.id + + # Create metadata + metadata_args = MetadataArgs(type="string", name="test_metadata") + metadata = MetadataService.create_metadata(dataset.id, metadata_args) + + # Act: Execute the method under test + result = MetadataService.get_dataset_metadatas(dataset) + + # Assert: Verify the expected outcomes + assert result is not None + assert "doc_metadata" in result + assert "built_in_field_enabled" in result + + # Verify metadata information + doc_metadata = result["doc_metadata"] + assert len(doc_metadata) == 1 # Only custom metadata, built-in fields are not included in this list + + # Verify built-in field status + assert result["built_in_field_enabled"] is True + + def test_get_dataset_metadatas_no_metadata(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test retrieval of dataset metadata when no metadata exists. + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + dataset = self._create_test_dataset( + db_session_with_containers, mock_external_service_dependencies, account, tenant + ) + + # Act: Execute the method under test + result = MetadataService.get_dataset_metadatas(dataset) + + # Assert: Verify the expected outcomes + assert result is not None + assert "doc_metadata" in result + assert "built_in_field_enabled" in result + + # Verify metadata information + doc_metadata = result["doc_metadata"] + assert len(doc_metadata) == 0 # No metadata exists + + # Verify built-in field status + assert result["built_in_field_enabled"] is False diff --git a/api/tests/test_containers_integration_tests/services/test_model_load_balancing_service.py b/api/tests/test_containers_integration_tests/services/test_model_load_balancing_service.py new file mode 100644 index 0000000000..cb20238f0c --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_model_load_balancing_service.py @@ -0,0 +1,474 @@ +from unittest.mock import MagicMock, patch + +import pytest +from faker import Faker + +from models.account import TenantAccountJoin, TenantAccountRole +from models.model import Account, Tenant +from models.provider import LoadBalancingModelConfig, Provider, ProviderModelSetting +from services.model_load_balancing_service import ModelLoadBalancingService + + +class TestModelLoadBalancingService: + """Integration tests for ModelLoadBalancingService using testcontainers.""" + + @pytest.fixture + def mock_external_service_dependencies(self): + """Mock setup for external service dependencies.""" + with ( + patch("services.model_load_balancing_service.ProviderManager") as mock_provider_manager, + patch("services.model_load_balancing_service.LBModelManager") as mock_lb_model_manager, + patch("services.model_load_balancing_service.ModelProviderFactory") as mock_model_provider_factory, + patch("services.model_load_balancing_service.encrypter") as mock_encrypter, + ): + # Setup default mock returns + mock_provider_manager_instance = mock_provider_manager.return_value + + # Mock provider configuration + mock_provider_config = MagicMock() + mock_provider_config.provider.provider = "openai" + mock_provider_config.custom_configuration.provider = None + + # Mock provider model setting + mock_provider_model_setting = MagicMock() + mock_provider_model_setting.load_balancing_enabled = False + + mock_provider_config.get_provider_model_setting.return_value = mock_provider_model_setting + + # Mock provider configurations dict + mock_provider_configs = {"openai": mock_provider_config} + mock_provider_manager_instance.get_configurations.return_value = mock_provider_configs + + # Mock LBModelManager + mock_lb_model_manager.get_config_in_cooldown_and_ttl.return_value = (False, 0) + + # Mock ModelProviderFactory + mock_model_provider_factory_instance = mock_model_provider_factory.return_value + + # Mock credential schemas + mock_credential_schema = MagicMock() + mock_credential_schema.credential_form_schemas = [] + + # Mock provider configuration methods + mock_provider_config.extract_secret_variables.return_value = [] + mock_provider_config.obfuscated_credentials.return_value = {} + mock_provider_config._get_credential_schema.return_value = mock_credential_schema + + yield { + "provider_manager": mock_provider_manager, + "lb_model_manager": mock_lb_model_manager, + "model_provider_factory": mock_model_provider_factory, + "encrypter": mock_encrypter, + "provider_config": mock_provider_config, + "provider_model_setting": mock_provider_model_setting, + "credential_schema": mock_credential_schema, + } + + def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies): + """ + Helper method to create a test account and tenant for testing. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + mock_external_service_dependencies: Mock dependencies + + Returns: + tuple: (account, tenant) - Created account and tenant instances + """ + fake = Faker() + + # Create account + account = Account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + status="active", + ) + + from extensions.ext_database import db + + db.session.add(account) + db.session.commit() + + # Create tenant for the account + tenant = Tenant( + name=fake.company(), + status="normal", + ) + db.session.add(tenant) + db.session.commit() + + # Create tenant-account join + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=TenantAccountRole.OWNER.value, + current=True, + ) + db.session.add(join) + db.session.commit() + + # Set current tenant for account + account.current_tenant = tenant + + return account, tenant + + def _create_test_provider_and_setting( + self, db_session_with_containers, tenant_id, mock_external_service_dependencies + ): + """ + Helper method to create a test provider and provider model setting. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + tenant_id: Tenant ID for the provider + mock_external_service_dependencies: Mock dependencies + + Returns: + tuple: (provider, provider_model_setting) - Created provider and setting instances + """ + fake = Faker() + + from extensions.ext_database import db + + # Create provider + provider = Provider( + tenant_id=tenant_id, + provider_name="openai", + provider_type="custom", + is_valid=True, + ) + db.session.add(provider) + db.session.commit() + + # Create provider model setting + provider_model_setting = ProviderModelSetting( + tenant_id=tenant_id, + provider_name="openai", + model_name="gpt-3.5-turbo", + model_type="text-generation", # Use the origin model type that matches the query + enabled=True, + load_balancing_enabled=False, + ) + db.session.add(provider_model_setting) + db.session.commit() + + return provider, provider_model_setting + + def test_enable_model_load_balancing_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful model load balancing enablement. + + This test verifies: + - Proper provider configuration retrieval + - Successful enablement of model load balancing + - Correct method calls to provider configuration + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + provider, provider_model_setting = self._create_test_provider_and_setting( + db_session_with_containers, tenant.id, mock_external_service_dependencies + ) + + # Setup mocks for enable method + mock_provider_config = mock_external_service_dependencies["provider_config"] + mock_provider_config.enable_model_load_balancing = MagicMock() + + # Act: Execute the method under test + service = ModelLoadBalancingService() + service.enable_model_load_balancing( + tenant_id=tenant.id, provider="openai", model="gpt-3.5-turbo", model_type="llm" + ) + + # Assert: Verify the expected outcomes + mock_provider_config.enable_model_load_balancing.assert_called_once() + call_args = mock_provider_config.enable_model_load_balancing.call_args + assert call_args.kwargs["model"] == "gpt-3.5-turbo" + assert call_args.kwargs["model_type"].value == "llm" # ModelType enum value + + # Verify database state + from extensions.ext_database import db + + db.session.refresh(provider) + db.session.refresh(provider_model_setting) + assert provider.id is not None + assert provider_model_setting.id is not None + + def test_disable_model_load_balancing_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful model load balancing disablement. + + This test verifies: + - Proper provider configuration retrieval + - Successful disablement of model load balancing + - Correct method calls to provider configuration + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + provider, provider_model_setting = self._create_test_provider_and_setting( + db_session_with_containers, tenant.id, mock_external_service_dependencies + ) + + # Setup mocks for disable method + mock_provider_config = mock_external_service_dependencies["provider_config"] + mock_provider_config.disable_model_load_balancing = MagicMock() + + # Act: Execute the method under test + service = ModelLoadBalancingService() + service.disable_model_load_balancing( + tenant_id=tenant.id, provider="openai", model="gpt-3.5-turbo", model_type="llm" + ) + + # Assert: Verify the expected outcomes + mock_provider_config.disable_model_load_balancing.assert_called_once() + call_args = mock_provider_config.disable_model_load_balancing.call_args + assert call_args.kwargs["model"] == "gpt-3.5-turbo" + assert call_args.kwargs["model_type"].value == "llm" # ModelType enum value + + # Verify database state + from extensions.ext_database import db + + db.session.refresh(provider) + db.session.refresh(provider_model_setting) + assert provider.id is not None + assert provider_model_setting.id is not None + + def test_enable_model_load_balancing_provider_not_found( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test error handling when provider does not exist. + + This test verifies: + - Proper error handling for non-existent provider + - Correct exception type and message + - No database state changes + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Setup mocks to return empty provider configurations + mock_provider_manager = mock_external_service_dependencies["provider_manager"] + mock_provider_manager_instance = mock_provider_manager.return_value + mock_provider_manager_instance.get_configurations.return_value = {} + + # Act & Assert: Verify proper error handling + service = ModelLoadBalancingService() + with pytest.raises(ValueError) as exc_info: + service.enable_model_load_balancing( + tenant_id=tenant.id, provider="nonexistent_provider", model="gpt-3.5-turbo", model_type="llm" + ) + + # Verify correct error message + assert "Provider nonexistent_provider does not exist." in str(exc_info.value) + + # Verify no database state changes occurred + from extensions.ext_database import db + + db.session.rollback() + + def test_get_load_balancing_configs_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful retrieval of load balancing configurations. + + This test verifies: + - Proper provider configuration retrieval + - Successful database query for load balancing configs + - Correct return format and data structure + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + provider, provider_model_setting = self._create_test_provider_and_setting( + db_session_with_containers, tenant.id, mock_external_service_dependencies + ) + + # Create load balancing config + from extensions.ext_database import db + + load_balancing_config = LoadBalancingModelConfig( + tenant_id=tenant.id, + provider_name="openai", + model_name="gpt-3.5-turbo", + model_type="text-generation", # Use the origin model type that matches the query + name="config1", + encrypted_config='{"api_key": "test_key"}', + enabled=True, + ) + db.session.add(load_balancing_config) + db.session.commit() + + # Verify the config was created + db.session.refresh(load_balancing_config) + assert load_balancing_config.id is not None + + # Setup mocks for get_load_balancing_configs method + mock_provider_config = mock_external_service_dependencies["provider_config"] + mock_provider_model_setting = mock_external_service_dependencies["provider_model_setting"] + mock_provider_model_setting.load_balancing_enabled = True + + # Mock credential schema methods + mock_credential_schema = mock_external_service_dependencies["credential_schema"] + mock_credential_schema.credential_form_schemas = [] + + # Mock encrypter + mock_encrypter = mock_external_service_dependencies["encrypter"] + mock_encrypter.get_decrypt_decoding.return_value = ("key", "cipher") + + # Mock _get_credential_schema method + mock_provider_config._get_credential_schema.return_value = mock_credential_schema + + # Mock extract_secret_variables method + mock_provider_config.extract_secret_variables.return_value = [] + + # Mock obfuscated_credentials method + mock_provider_config.obfuscated_credentials.return_value = {} + + # Mock LBModelManager.get_config_in_cooldown_and_ttl + mock_lb_model_manager = mock_external_service_dependencies["lb_model_manager"] + mock_lb_model_manager.get_config_in_cooldown_and_ttl.return_value = (False, 0) + + # Act: Execute the method under test + service = ModelLoadBalancingService() + is_enabled, configs = service.get_load_balancing_configs( + tenant_id=tenant.id, provider="openai", model="gpt-3.5-turbo", model_type="llm" + ) + + # Assert: Verify the expected outcomes + assert is_enabled is True + assert len(configs) == 1 + assert configs[0]["id"] == load_balancing_config.id + assert configs[0]["name"] == "config1" + assert configs[0]["enabled"] is True + assert configs[0]["in_cooldown"] is False + assert configs[0]["ttl"] == 0 + + # Verify database state + db.session.refresh(load_balancing_config) + assert load_balancing_config.id is not None + + def test_get_load_balancing_configs_provider_not_found( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test error handling when provider does not exist in get_load_balancing_configs. + + This test verifies: + - Proper error handling for non-existent provider + - Correct exception type and message + - No database state changes + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Setup mocks to return empty provider configurations + mock_provider_manager = mock_external_service_dependencies["provider_manager"] + mock_provider_manager_instance = mock_provider_manager.return_value + mock_provider_manager_instance.get_configurations.return_value = {} + + # Act & Assert: Verify proper error handling + service = ModelLoadBalancingService() + with pytest.raises(ValueError) as exc_info: + service.get_load_balancing_configs( + tenant_id=tenant.id, provider="nonexistent_provider", model="gpt-3.5-turbo", model_type="llm" + ) + + # Verify correct error message + assert "Provider nonexistent_provider does not exist." in str(exc_info.value) + + # Verify no database state changes occurred + from extensions.ext_database import db + + db.session.rollback() + + def test_get_load_balancing_configs_with_inherit_config( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test load balancing configs retrieval with inherit configuration. + + This test verifies: + - Proper handling of inherit configuration + - Correct ordering of configurations + - Inherit config initialization when needed + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + provider, provider_model_setting = self._create_test_provider_and_setting( + db_session_with_containers, tenant.id, mock_external_service_dependencies + ) + + # Create load balancing config + from extensions.ext_database import db + + load_balancing_config = LoadBalancingModelConfig( + tenant_id=tenant.id, + provider_name="openai", + model_name="gpt-3.5-turbo", + model_type="text-generation", # Use the origin model type that matches the query + name="config1", + encrypted_config='{"api_key": "test_key"}', + enabled=True, + ) + db.session.add(load_balancing_config) + db.session.commit() + + # Setup mocks for inherit config scenario + mock_provider_config = mock_external_service_dependencies["provider_config"] + mock_provider_config.custom_configuration.provider = MagicMock() # Enable custom config + + mock_provider_model_setting = mock_external_service_dependencies["provider_model_setting"] + mock_provider_model_setting.load_balancing_enabled = True + + # Mock credential schema methods + mock_credential_schema = mock_external_service_dependencies["credential_schema"] + mock_credential_schema.credential_form_schemas = [] + + # Mock encrypter + mock_encrypter = mock_external_service_dependencies["encrypter"] + mock_encrypter.get_decrypt_decoding.return_value = ("key", "cipher") + + # Act: Execute the method under test + service = ModelLoadBalancingService() + is_enabled, configs = service.get_load_balancing_configs( + tenant_id=tenant.id, provider="openai", model="gpt-3.5-turbo", model_type="llm" + ) + + # Assert: Verify the expected outcomes + assert is_enabled is True + assert len(configs) == 2 # inherit config + existing config + + # First config should be inherit config + assert configs[0]["name"] == "__inherit__" + assert configs[0]["enabled"] is True + + # Second config should be the existing config + assert configs[1]["id"] == load_balancing_config.id + assert configs[1]["name"] == "config1" + + # Verify database state + db.session.refresh(load_balancing_config) + assert load_balancing_config.id is not None + + # Verify inherit config was created in database + inherit_configs = ( + db.session.query(LoadBalancingModelConfig).where(LoadBalancingModelConfig.name == "__inherit__").all() + ) + assert len(inherit_configs) == 1 diff --git a/api/tests/test_containers_integration_tests/services/test_model_provider_service.py b/api/tests/test_containers_integration_tests/services/test_model_provider_service.py new file mode 100644 index 0000000000..8b7d44c1e4 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_model_provider_service.py @@ -0,0 +1,1172 @@ +from unittest.mock import MagicMock, patch + +import pytest +from faker import Faker + +from core.entities.model_entities import ModelStatus +from core.model_runtime.entities.model_entities import FetchFrom, ModelType +from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models.provider import Provider, ProviderModel, ProviderModelSetting, ProviderType +from services.model_provider_service import ModelProviderService + + +class TestModelProviderService: + """Integration tests for ModelProviderService using testcontainers.""" + + @pytest.fixture + def mock_external_service_dependencies(self): + """Mock setup for external service dependencies.""" + with ( + patch("services.model_provider_service.ProviderManager") as mock_provider_manager, + patch("services.model_provider_service.ModelProviderFactory") as mock_model_provider_factory, + ): + # Setup default mock returns + mock_provider_manager.return_value.get_configurations.return_value = MagicMock() + mock_model_provider_factory.return_value.get_provider_icon.return_value = (None, None) + + yield { + "provider_manager": mock_provider_manager, + "model_provider_factory": mock_model_provider_factory, + } + + def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies): + """ + Helper method to create a test account and tenant for testing. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + mock_external_service_dependencies: Mock dependencies + + Returns: + tuple: (account, tenant) - Created account and tenant instances + """ + fake = Faker() + + # Create account + account = Account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + status="active", + ) + + from extensions.ext_database import db + + db.session.add(account) + db.session.commit() + + # Create tenant for the account + tenant = Tenant( + name=fake.company(), + status="normal", + ) + db.session.add(tenant) + db.session.commit() + + # Create tenant-account join + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=TenantAccountRole.OWNER.value, + current=True, + ) + db.session.add(join) + db.session.commit() + + # Set current tenant for account + account.current_tenant = tenant + + return account, tenant + + def _create_test_provider( + self, + db_session_with_containers, + mock_external_service_dependencies, + tenant_id: str, + provider_name: str = "openai", + ): + """ + Helper method to create a test provider for testing. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + mock_external_service_dependencies: Mock dependencies + tenant_id: Tenant ID for the provider + provider_name: Name of the provider + + Returns: + Provider: Created provider instance + """ + fake = Faker() + + provider = Provider( + tenant_id=tenant_id, + provider_name=provider_name, + provider_type="custom", + is_valid=True, + quota_type="free", + quota_limit=1000, + quota_used=0, + ) + + from extensions.ext_database import db + + db.session.add(provider) + db.session.commit() + + return provider + + def _create_test_provider_model( + self, + db_session_with_containers, + mock_external_service_dependencies, + tenant_id: str, + provider_name: str, + model_name: str = "gpt-3.5-turbo", + model_type: str = "llm", + ): + """ + Helper method to create a test provider model for testing. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + mock_external_service_dependencies: Mock dependencies + tenant_id: Tenant ID for the provider model + provider_name: Name of the provider + model_name: Name of the model + model_type: Type of the model + + Returns: + ProviderModel: Created provider model instance + """ + fake = Faker() + + provider_model = ProviderModel( + tenant_id=tenant_id, + provider_name=provider_name, + model_name=model_name, + model_type=model_type, + is_valid=True, + ) + + from extensions.ext_database import db + + db.session.add(provider_model) + db.session.commit() + + return provider_model + + def _create_test_provider_model_setting( + self, + db_session_with_containers, + mock_external_service_dependencies, + tenant_id: str, + provider_name: str, + model_name: str = "gpt-3.5-turbo", + model_type: str = "llm", + ): + """ + Helper method to create a test provider model setting for testing. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + mock_external_service_dependencies: Mock dependencies + tenant_id: Tenant ID for the provider model setting + provider_name: Name of the provider + model_name: Name of the model + model_type: Type of the model + + Returns: + ProviderModelSetting: Created provider model setting instance + """ + fake = Faker() + + provider_model_setting = ProviderModelSetting( + tenant_id=tenant_id, + provider_name=provider_name, + model_name=model_name, + model_type=model_type, + enabled=True, + load_balancing_enabled=False, + ) + + from extensions.ext_database import db + + db.session.add(provider_model_setting) + db.session.commit() + + return provider_model_setting + + def test_get_provider_list_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful provider list retrieval. + + This test verifies: + - Proper provider list retrieval with all required fields + - Correct filtering by model type + - Proper response structure and data mapping + - Mock interactions with ProviderManager + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create test provider + provider = self._create_test_provider( + db_session_with_containers, mock_external_service_dependencies, tenant.id, "openai" + ) + + # Mock ProviderManager to return realistic configuration + mock_provider_manager = mock_external_service_dependencies["provider_manager"].return_value + + # Create mock provider configuration + mock_provider_entity = MagicMock() + mock_provider_entity.provider = "openai" + mock_provider_entity.label = {"en_US": "OpenAI", "zh_Hans": "OpenAI"} + mock_provider_entity.description = {"en_US": "OpenAI provider", "zh_Hans": "OpenAI 提供商"} + mock_provider_entity.icon_small = {"en_US": "icon_small.png", "zh_Hans": "icon_small.png"} + mock_provider_entity.icon_large = {"en_US": "icon_large.png", "zh_Hans": "icon_large.png"} + mock_provider_entity.background = "#FF6B6B" + mock_provider_entity.help = None + mock_provider_entity.supported_model_types = [ModelType.LLM, ModelType.TEXT_EMBEDDING] + mock_provider_entity.configurate_methods = [] + mock_provider_entity.provider_credential_schema = None + mock_provider_entity.model_credential_schema = None + + mock_provider_config = MagicMock() + mock_provider_config.provider = mock_provider_entity + mock_provider_config.preferred_provider_type = ProviderType.CUSTOM + mock_provider_config.is_custom_configuration_available.return_value = True + mock_provider_config.system_configuration.enabled = True + mock_provider_config.system_configuration.current_quota_type = "free" + mock_provider_config.system_configuration.quota_configurations = [] + + mock_configurations = MagicMock() + mock_configurations.values.return_value = [mock_provider_config] + mock_provider_manager.get_configurations.return_value = mock_configurations + + # Act: Execute the method under test + service = ModelProviderService() + result = service.get_provider_list(tenant.id) + + # Assert: Verify the expected outcomes + assert result is not None + assert len(result) == 1 + + provider_response = result[0] + assert provider_response.tenant_id == tenant.id + assert provider_response.provider == "openai" + assert provider_response.background == "#FF6B6B" + assert len(provider_response.supported_model_types) == 2 + assert ModelType.LLM in provider_response.supported_model_types + assert ModelType.TEXT_EMBEDDING in provider_response.supported_model_types + + # Verify mock interactions + mock_provider_manager.get_configurations.assert_called_once_with(tenant.id) + mock_provider_config.is_custom_configuration_available.assert_called_once() + + def test_get_provider_list_with_model_type_filter( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test provider list retrieval with model type filtering. + + This test verifies: + - Proper filtering by model type + - Only providers supporting the specified model type are returned + - Correct handling of unsupported model types + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Mock ProviderManager to return multiple provider configurations + mock_provider_manager = mock_external_service_dependencies["provider_manager"].return_value + + # Create mock provider configurations with different supported model types + mock_provider_entity_llm = MagicMock() + mock_provider_entity_llm.provider = "openai" + mock_provider_entity_llm.label = {"en_US": "OpenAI", "zh_Hans": "OpenAI"} + mock_provider_entity_llm.description = {"en_US": "OpenAI provider", "zh_Hans": "OpenAI 提供商"} + mock_provider_entity_llm.icon_small = {"en_US": "icon_small.png", "zh_Hans": "icon_small.png"} + mock_provider_entity_llm.icon_large = {"en_US": "icon_large.png", "zh_Hans": "icon_large.png"} + mock_provider_entity_llm.background = "#FF6B6B" + mock_provider_entity_llm.help = None + mock_provider_entity_llm.supported_model_types = [ModelType.LLM] + mock_provider_entity_llm.configurate_methods = [] + mock_provider_entity_llm.provider_credential_schema = None + mock_provider_entity_llm.model_credential_schema = None + + mock_provider_entity_embedding = MagicMock() + mock_provider_entity_embedding.provider = "cohere" + mock_provider_entity_embedding.label = {"en_US": "Cohere", "zh_Hans": "Cohere"} + mock_provider_entity_embedding.description = {"en_US": "Cohere provider", "zh_Hans": "Cohere 提供商"} + mock_provider_entity_embedding.icon_small = {"en_US": "icon_small.png", "zh_Hans": "icon_small.png"} + mock_provider_entity_embedding.icon_large = {"en_US": "icon_large.png", "zh_Hans": "icon_large.png"} + mock_provider_entity_embedding.background = "#4ECDC4" + mock_provider_entity_embedding.help = None + mock_provider_entity_embedding.supported_model_types = [ModelType.TEXT_EMBEDDING] + mock_provider_entity_embedding.configurate_methods = [] + mock_provider_entity_embedding.provider_credential_schema = None + mock_provider_entity_embedding.model_credential_schema = None + + mock_provider_config_llm = MagicMock() + mock_provider_config_llm.provider = mock_provider_entity_llm + mock_provider_config_llm.preferred_provider_type = ProviderType.CUSTOM + mock_provider_config_llm.is_custom_configuration_available.return_value = True + mock_provider_config_llm.system_configuration.enabled = True + mock_provider_config_llm.system_configuration.current_quota_type = "free" + mock_provider_config_llm.system_configuration.quota_configurations = [] + + mock_provider_config_embedding = MagicMock() + mock_provider_config_embedding.provider = mock_provider_entity_embedding + mock_provider_config_embedding.preferred_provider_type = ProviderType.CUSTOM + mock_provider_config_embedding.is_custom_configuration_available.return_value = True + mock_provider_config_embedding.system_configuration.enabled = True + mock_provider_config_embedding.system_configuration.current_quota_type = "free" + mock_provider_config_embedding.system_configuration.quota_configurations = [] + + mock_configurations = MagicMock() + mock_configurations.values.return_value = [mock_provider_config_llm, mock_provider_config_embedding] + mock_provider_manager.get_configurations.return_value = mock_configurations + + # Act: Execute the method under test with LLM filter + service = ModelProviderService() + result = service.get_provider_list(tenant.id, model_type="llm") + + # Assert: Verify only LLM providers are returned + assert result is not None + assert len(result) == 1 + assert result[0].provider == "openai" + assert ModelType.LLM in result[0].supported_model_types + + # Act: Execute the method under test with TEXT_EMBEDDING filter + result = service.get_provider_list(tenant.id, model_type="text-embedding") + + # Assert: Verify only TEXT_EMBEDDING providers are returned + assert result is not None + assert len(result) == 1 + assert result[0].provider == "cohere" + assert ModelType.TEXT_EMBEDDING in result[0].supported_model_types + + def test_get_models_by_provider_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful retrieval of models by provider. + + This test verifies: + - Proper model retrieval for a specific provider + - Correct response structure with tenant_id and model data + - Mock interactions with ProviderManager + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create test provider and models + provider = self._create_test_provider( + db_session_with_containers, mock_external_service_dependencies, tenant.id, "openai" + ) + + provider_model_1 = self._create_test_provider_model( + db_session_with_containers, mock_external_service_dependencies, tenant.id, "openai", "gpt-3.5-turbo", "llm" + ) + + provider_model_2 = self._create_test_provider_model( + db_session_with_containers, mock_external_service_dependencies, tenant.id, "openai", "gpt-4", "llm" + ) + + # Mock ProviderManager to return realistic configuration + mock_provider_manager = mock_external_service_dependencies["provider_manager"].return_value + + # Create mock models + from core.entities.model_entities import ModelWithProviderEntity, SimpleModelProviderEntity + from core.model_runtime.entities.common_entities import I18nObject + from core.model_runtime.entities.provider_entities import ProviderEntity + + # Create real model objects instead of mocks + provider_entity_1 = SimpleModelProviderEntity( + ProviderEntity( + provider="openai", + label=I18nObject(en_US="OpenAI", zh_Hans="OpenAI"), + icon_small=I18nObject(en_US="icon_small.png", zh_Hans="icon_small.png"), + icon_large=I18nObject(en_US="icon_large.png", zh_Hans="icon_large.png"), + supported_model_types=[ModelType.LLM], + configurate_methods=[], + models=[], + ) + ) + + provider_entity_2 = SimpleModelProviderEntity( + ProviderEntity( + provider="openai", + label=I18nObject(en_US="OpenAI", zh_Hans="OpenAI"), + icon_small=I18nObject(en_US="icon_small.png", zh_Hans="icon_small.png"), + icon_large=I18nObject(en_US="icon_large.png", zh_Hans="icon_large.png"), + supported_model_types=[ModelType.LLM], + configurate_methods=[], + models=[], + ) + ) + + mock_model_1 = ModelWithProviderEntity( + model="gpt-3.5-turbo", + label=I18nObject(en_US="GPT-3.5 Turbo", zh_Hans="GPT-3.5 Turbo"), + model_type=ModelType.LLM, + features=[], + fetch_from=FetchFrom.PREDEFINED_MODEL, + model_properties={}, + deprecated=False, + provider=provider_entity_1, + status="active", + load_balancing_enabled=False, + ) + + mock_model_2 = ModelWithProviderEntity( + model="gpt-4", + label=I18nObject(en_US="GPT-4", zh_Hans="GPT-4"), + model_type=ModelType.LLM, + features=[], + fetch_from=FetchFrom.PREDEFINED_MODEL, + model_properties={}, + deprecated=False, + provider=provider_entity_2, + status="active", + load_balancing_enabled=False, + ) + + mock_configurations = MagicMock() + mock_configurations.get_models.return_value = [mock_model_1, mock_model_2] + mock_provider_manager.get_configurations.return_value = mock_configurations + + # Act: Execute the method under test + service = ModelProviderService() + result = service.get_models_by_provider(tenant.id, "openai") + + # Assert: Verify the expected outcomes + assert result is not None + assert len(result) == 2 + + # Verify first model + assert result[0].provider.tenant_id == tenant.id + assert result[0].model == "gpt-3.5-turbo" + assert result[0].provider.provider == "openai" + + # Verify second model + assert result[1].provider.tenant_id == tenant.id + assert result[1].model == "gpt-4" + assert result[1].provider.provider == "openai" + + # Verify mock interactions + mock_provider_manager.get_configurations.assert_called_once_with(tenant.id) + mock_configurations.get_models.assert_called_once_with(provider="openai") + + def test_get_provider_credentials_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful retrieval of provider credentials. + + This test verifies: + - Proper credential retrieval for existing provider + - Correct handling of obfuscated credentials + - Mock interactions with ProviderManager + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create test provider + provider = self._create_test_provider( + db_session_with_containers, mock_external_service_dependencies, tenant.id, "openai" + ) + + # Mock ProviderManager to return realistic configuration + mock_provider_manager = mock_external_service_dependencies["provider_manager"].return_value + + # Create mock provider configuration with credentials + mock_provider_configuration = MagicMock() + mock_provider_configuration.get_custom_credentials.return_value = { + "api_key": "sk-***123", + "base_url": "https://api.openai.com", + } + mock_provider_manager.get_configurations.return_value = {"openai": mock_provider_configuration} + + # Act: Execute the method under test + service = ModelProviderService() + result = service.get_provider_credentials(tenant.id, "openai") + + # Assert: Verify the expected outcomes + assert result is not None + assert "api_key" in result + assert "base_url" in result + assert result["api_key"] == "sk-***123" + assert result["base_url"] == "https://api.openai.com" + + # Verify mock interactions + mock_provider_manager.get_configurations.assert_called_once_with(tenant.id) + mock_provider_configuration.get_custom_credentials.assert_called_once_with(obfuscated=True) + + def test_provider_credentials_validate_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test successful validation of provider credentials. + + This test verifies: + - Proper credential validation for existing provider + - Correct handling of valid credentials + - Mock interactions with ProviderManager + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create test provider + provider = self._create_test_provider( + db_session_with_containers, mock_external_service_dependencies, tenant.id, "openai" + ) + + # Mock ProviderManager to return realistic configuration + mock_provider_manager = mock_external_service_dependencies["provider_manager"].return_value + + # Create mock provider configuration with validation method + mock_provider_configuration = MagicMock() + mock_provider_configuration.custom_credentials_validate.return_value = True + mock_provider_manager.get_configurations.return_value = {"openai": mock_provider_configuration} + + # Test credentials + test_credentials = {"api_key": "sk-test123", "base_url": "https://api.openai.com"} + + # Act: Execute the method under test + service = ModelProviderService() + # This should not raise an exception + service.provider_credentials_validate(tenant.id, "openai", test_credentials) + + # Assert: Verify mock interactions + mock_provider_manager.get_configurations.assert_called_once_with(tenant.id) + mock_provider_configuration.custom_credentials_validate.assert_called_once_with(test_credentials) + + def test_provider_credentials_validate_invalid_provider( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test validation failure for non-existent provider. + + This test verifies: + - Proper error handling for non-existent provider + - Correct exception raising + - Mock interactions with ProviderManager + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Mock ProviderManager to return empty configurations + mock_provider_manager = mock_external_service_dependencies["provider_manager"].return_value + mock_provider_manager.get_configurations.return_value = {} + + # Test credentials + test_credentials = {"api_key": "sk-test123", "base_url": "https://api.openai.com"} + + # Act & Assert: Execute the method under test and verify exception + service = ModelProviderService() + with pytest.raises(ValueError, match="Provider nonexistent does not exist."): + service.provider_credentials_validate(tenant.id, "nonexistent", test_credentials) + + # Verify mock interactions + mock_provider_manager.get_configurations.assert_called_once_with(tenant.id) + + def test_get_default_model_of_model_type_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test successful retrieval of default model for a specific model type. + + This test verifies: + - Proper default model retrieval for tenant and model type + - Correct response structure with tenant_id and model data + - Mock interactions with ProviderManager + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create test provider + provider = self._create_test_provider( + db_session_with_containers, mock_external_service_dependencies, tenant.id, "openai" + ) + + # Mock ProviderManager to return realistic default model + mock_provider_manager = mock_external_service_dependencies["provider_manager"].return_value + + # Create mock default model response + from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity + from core.model_runtime.entities.common_entities import I18nObject + + mock_default_model = DefaultModelEntity( + model="gpt-3.5-turbo", + model_type=ModelType.LLM, + provider=DefaultModelProviderEntity( + provider="openai", + label=I18nObject(en_US="OpenAI", zh_Hans="OpenAI"), + icon_small=I18nObject(en_US="icon_small.png", zh_Hans="icon_small.png"), + icon_large=I18nObject(en_US="icon_large.png", zh_Hans="icon_large.png"), + supported_model_types=[ModelType.LLM], + ), + ) + + mock_provider_manager.get_default_model.return_value = mock_default_model + + # Act: Execute the method under test + service = ModelProviderService() + result = service.get_default_model_of_model_type(tenant.id, "llm") + + # Assert: Verify the expected outcomes + assert result is not None + assert result.model == "gpt-3.5-turbo" + assert result.model_type == ModelType.LLM + assert result.provider.tenant_id == tenant.id + assert result.provider.provider == "openai" + + # Verify mock interactions + mock_provider_manager.get_default_model.assert_called_once_with(tenant_id=tenant.id, model_type=ModelType.LLM) + + def test_update_default_model_of_model_type_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test successful update of default model for a specific model type. + + This test verifies: + - Proper default model update for tenant and model type + - Correct mock interactions with ProviderManager + - Database state management + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create test provider + provider = self._create_test_provider( + db_session_with_containers, mock_external_service_dependencies, tenant.id, "openai" + ) + + # Mock ProviderManager to return realistic configuration + mock_provider_manager = mock_external_service_dependencies["provider_manager"].return_value + + # Act: Execute the method under test + service = ModelProviderService() + service.update_default_model_of_model_type(tenant.id, "llm", "openai", "gpt-4") + + # Assert: Verify mock interactions + mock_provider_manager.update_default_model_record.assert_called_once_with( + tenant_id=tenant.id, model_type=ModelType.LLM, provider="openai", model="gpt-4" + ) + + def test_get_model_provider_icon_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful retrieval of model provider icon. + + This test verifies: + - Proper icon retrieval for provider and icon type + - Correct response structure with byte data and mime type + - Mock interactions with ModelProviderFactory + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create test provider + provider = self._create_test_provider( + db_session_with_containers, mock_external_service_dependencies, tenant.id, "openai" + ) + + # Mock ModelProviderFactory to return realistic icon data + mock_model_provider_factory = mock_external_service_dependencies["model_provider_factory"].return_value + mock_model_provider_factory.get_provider_icon.return_value = (b"fake_icon_data", "image/png") + + # Act: Execute the method under test + service = ModelProviderService() + result = service.get_model_provider_icon(tenant.id, "openai", "icon_small", "en_US") + + # Assert: Verify the expected outcomes + assert result is not None + assert len(result) == 2 + assert result[0] == b"fake_icon_data" + assert result[1] == "image/png" + + # Verify mock interactions + mock_model_provider_factory.get_provider_icon.assert_called_once_with("openai", "icon_small", "en_US") + + def test_switch_preferred_provider_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful switching of preferred provider type. + + This test verifies: + - Proper provider type switching for tenant and provider + - Correct mock interactions with ProviderManager + - Provider configuration management + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create test provider + provider = self._create_test_provider( + db_session_with_containers, mock_external_service_dependencies, tenant.id, "openai" + ) + + # Mock ProviderManager to return realistic configuration + mock_provider_manager = mock_external_service_dependencies["provider_manager"].return_value + + # Create mock provider configuration with switch method + mock_provider_configuration = MagicMock() + mock_provider_configuration.switch_preferred_provider_type.return_value = None + mock_provider_manager.get_configurations.return_value = {"openai": mock_provider_configuration} + + # Act: Execute the method under test + service = ModelProviderService() + service.switch_preferred_provider(tenant.id, "openai", "custom") + + # Assert: Verify mock interactions + mock_provider_manager.get_configurations.assert_called_once_with(tenant.id) + mock_provider_configuration.switch_preferred_provider_type.assert_called_once() + + def test_enable_model_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful enabling of a model. + + This test verifies: + - Proper model enabling for tenant, provider, and model + - Correct mock interactions with ProviderManager + - Model configuration management + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create test provider + provider = self._create_test_provider( + db_session_with_containers, mock_external_service_dependencies, tenant.id, "openai" + ) + + # Mock ProviderManager to return realistic configuration + mock_provider_manager = mock_external_service_dependencies["provider_manager"].return_value + + # Create mock provider configuration with enable method + mock_provider_configuration = MagicMock() + mock_provider_configuration.enable_model.return_value = None + mock_provider_manager.get_configurations.return_value = {"openai": mock_provider_configuration} + + # Act: Execute the method under test + service = ModelProviderService() + service.enable_model(tenant.id, "openai", "gpt-4", "llm") + + # Assert: Verify mock interactions + mock_provider_manager.get_configurations.assert_called_once_with(tenant.id) + mock_provider_configuration.enable_model.assert_called_once_with(model_type=ModelType.LLM, model="gpt-4") + + def test_get_model_credentials_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful retrieval of model credentials. + + This test verifies: + - Proper credential retrieval for model + - Correct response structure with obfuscated credentials + - Mock interactions with ProviderManager + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create test provider + provider = self._create_test_provider( + db_session_with_containers, mock_external_service_dependencies, tenant.id, "openai" + ) + + # Mock ProviderManager to return realistic configuration + mock_provider_manager = mock_external_service_dependencies["provider_manager"].return_value + + # Create mock provider configuration with model credentials + mock_provider_configuration = MagicMock() + mock_provider_configuration.get_custom_model_credentials.return_value = { + "api_key": "sk-***123", + "base_url": "https://api.openai.com", + } + mock_provider_manager.get_configurations.return_value = {"openai": mock_provider_configuration} + + # Act: Execute the method under test + service = ModelProviderService() + result = service.get_model_credentials(tenant.id, "openai", "llm", "gpt-4") + + # Assert: Verify the expected outcomes + assert result is not None + assert "api_key" in result + assert "base_url" in result + assert result["api_key"] == "sk-***123" + assert result["base_url"] == "https://api.openai.com" + + # Verify mock interactions + mock_provider_manager.get_configurations.assert_called_once_with(tenant.id) + mock_provider_configuration.get_custom_model_credentials.assert_called_once_with( + model_type=ModelType.LLM, model="gpt-4", obfuscated=True + ) + + def test_model_credentials_validate_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful validation of model credentials. + + This test verifies: + - Proper credential validation for model + - Correct mock interactions with ProviderManager + - Model credential validation process + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create test provider + provider = self._create_test_provider( + db_session_with_containers, mock_external_service_dependencies, tenant.id, "openai" + ) + + # Mock ProviderManager to return realistic configuration + mock_provider_manager = mock_external_service_dependencies["provider_manager"].return_value + + # Create mock provider configuration with validation method + mock_provider_configuration = MagicMock() + mock_provider_configuration.custom_model_credentials_validate.return_value = True + mock_provider_manager.get_configurations.return_value = {"openai": mock_provider_configuration} + + # Test credentials + test_credentials = {"api_key": "sk-test123", "base_url": "https://api.openai.com"} + + # Act: Execute the method under test + service = ModelProviderService() + # This should not raise an exception + service.model_credentials_validate(tenant.id, "openai", "llm", "gpt-4", test_credentials) + + # Assert: Verify mock interactions + mock_provider_manager.get_configurations.assert_called_once_with(tenant.id) + mock_provider_configuration.custom_model_credentials_validate.assert_called_once_with( + model_type=ModelType.LLM, model="gpt-4", credentials=test_credentials + ) + + def test_save_model_credentials_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful saving of model credentials. + + This test verifies: + - Proper credential saving for model + - Correct mock interactions with ProviderManager + - Model credential management + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create test provider + provider = self._create_test_provider( + db_session_with_containers, mock_external_service_dependencies, tenant.id, "openai" + ) + + # Mock ProviderManager to return realistic configuration + mock_provider_manager = mock_external_service_dependencies["provider_manager"].return_value + + # Create mock provider configuration with save method + mock_provider_configuration = MagicMock() + mock_provider_configuration.add_or_update_custom_model_credentials.return_value = None + mock_provider_manager.get_configurations.return_value = {"openai": mock_provider_configuration} + + # Test credentials + test_credentials = {"api_key": "sk-test123", "base_url": "https://api.openai.com"} + + # Act: Execute the method under test + service = ModelProviderService() + service.save_model_credentials(tenant.id, "openai", "llm", "gpt-4", test_credentials) + + # Assert: Verify mock interactions + mock_provider_manager.get_configurations.assert_called_once_with(tenant.id) + mock_provider_configuration.add_or_update_custom_model_credentials.assert_called_once_with( + model_type=ModelType.LLM, model="gpt-4", credentials=test_credentials + ) + + def test_remove_model_credentials_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful removal of model credentials. + + This test verifies: + - Proper credential removal for model + - Correct mock interactions with ProviderManager + - Model credential cleanup + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create test provider + provider = self._create_test_provider( + db_session_with_containers, mock_external_service_dependencies, tenant.id, "openai" + ) + + # Mock ProviderManager to return realistic configuration + mock_provider_manager = mock_external_service_dependencies["provider_manager"].return_value + + # Create mock provider configuration with remove method + mock_provider_configuration = MagicMock() + mock_provider_configuration.delete_custom_model_credentials.return_value = None + mock_provider_manager.get_configurations.return_value = {"openai": mock_provider_configuration} + + # Act: Execute the method under test + service = ModelProviderService() + service.remove_model_credentials(tenant.id, "openai", "llm", "gpt-4") + + # Assert: Verify mock interactions + mock_provider_manager.get_configurations.assert_called_once_with(tenant.id) + mock_provider_configuration.delete_custom_model_credentials.assert_called_once_with( + model_type=ModelType.LLM, model="gpt-4" + ) + + def test_get_models_by_model_type_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful retrieval of models by model type. + + This test verifies: + - Proper model retrieval for specific model type + - Correct response structure with provider grouping + - Mock interactions with ProviderManager + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create test provider + provider = self._create_test_provider( + db_session_with_containers, mock_external_service_dependencies, tenant.id, "openai" + ) + + # Mock ProviderManager to return realistic configuration + mock_provider_manager = mock_external_service_dependencies["provider_manager"].return_value + + # Create mock provider configurations object with get_models method + mock_provider_configurations = MagicMock() + mock_provider_configurations.get_models.return_value = [ + MagicMock( + provider=MagicMock( + provider="openai", + label={"en_US": "OpenAI", "zh_Hans": "OpenAI"}, + icon_small={"en_US": "icon_small.png", "zh_Hans": "icon_small.png"}, + icon_large={"en_US": "icon_large.png", "zh_Hans": "icon_large.png"}, + ), + model="gpt-3.5-turbo", + model_type=ModelType.LLM, + status=ModelStatus.ACTIVE, + deprecated=False, + label={"en_US": "GPT-3.5 Turbo", "zh_Hans": "GPT-3.5 Turbo"}, + features=[], + fetch_from="predefined-model", + model_properties={}, + load_balancing_enabled=False, + ), + MagicMock( + provider=MagicMock( + provider="openai", + label={"en_US": "OpenAI", "zh_Hans": "OpenAI"}, + icon_small={"en_US": "icon_small.png", "zh_Hans": "icon_small.png"}, + icon_large={"en_US": "icon_large.png", "zh_Hans": "icon_large.png"}, + ), + model="gpt-4", + model_type=ModelType.LLM, + status=ModelStatus.ACTIVE, + deprecated=False, + label={"en_US": "GPT-4", "zh_Hans": "GPT-4"}, + features=[], + fetch_from="predefined-model", + model_properties={}, + load_balancing_enabled=False, + ), + ] + mock_provider_manager.get_configurations.return_value = mock_provider_configurations + + # Act: Execute the method under test + service = ModelProviderService() + result = service.get_models_by_model_type(tenant.id, "llm") + + # Assert: Verify the expected outcomes + assert result is not None + assert len(result) == 1 # One provider group + assert result[0].provider == "openai" + assert len(result[0].models) == 2 # Two models in the provider + + # Verify mock interactions + mock_provider_manager.get_configurations.assert_called_once_with(tenant.id) + mock_provider_configurations.get_models.assert_called_once_with(model_type=ModelType.LLM) + + def test_get_model_parameter_rules_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful retrieval of model parameter rules. + + This test verifies: + - Proper parameter rules retrieval for model + - Correct mock interactions with ProviderManager + - Model schema handling + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create test provider + provider = self._create_test_provider( + db_session_with_containers, mock_external_service_dependencies, tenant.id, "openai" + ) + + # Mock ProviderManager to return realistic configuration + mock_provider_manager = mock_external_service_dependencies["provider_manager"].return_value + + # Create mock provider configuration with parameter rules + mock_provider_configuration = MagicMock() + mock_credentials = {"api_key": "sk-test123"} + mock_model_schema = MagicMock() + + # Create mock parameter rules with proper return values + mock_temperature_rule = MagicMock() + mock_temperature_rule.name = "temperature" + mock_temperature_rule.type = "float" + mock_temperature_rule.min = 0.0 + mock_temperature_rule.max = 2.0 + + mock_max_tokens_rule = MagicMock() + mock_max_tokens_rule.name = "max_tokens" + mock_max_tokens_rule.type = "integer" + mock_max_tokens_rule.min = 1 + mock_max_tokens_rule.max = 4096 + + mock_model_schema.parameter_rules = [mock_temperature_rule, mock_max_tokens_rule] + + mock_provider_configuration.get_current_credentials.return_value = mock_credentials + mock_provider_configuration.get_model_schema.return_value = mock_model_schema + mock_provider_manager.get_configurations.return_value = {"openai": mock_provider_configuration} + + # Act: Execute the method under test + service = ModelProviderService() + result = service.get_model_parameter_rules(tenant.id, "openai", "gpt-4") + + # Assert: Verify the expected outcomes + assert result is not None + assert len(result) == 2 + assert result[0].name == "temperature" + assert result[1].name == "max_tokens" + + # Verify mock interactions + mock_provider_manager.get_configurations.assert_called_once_with(tenant.id) + mock_provider_configuration.get_current_credentials.assert_called_once_with( + model_type=ModelType.LLM, model="gpt-4" + ) + mock_provider_configuration.get_model_schema.assert_called_once_with( + model_type=ModelType.LLM, model="gpt-4", credentials=mock_credentials + ) + + def test_get_model_parameter_rules_no_credentials( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test parameter rules retrieval when no credentials are available. + + This test verifies: + - Proper handling of missing credentials + - Empty result when no credentials exist + - Mock interactions with ProviderManager + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create test provider + provider = self._create_test_provider( + db_session_with_containers, mock_external_service_dependencies, tenant.id, "openai" + ) + + # Mock ProviderManager to return realistic configuration + mock_provider_manager = mock_external_service_dependencies["provider_manager"].return_value + + # Create mock provider configuration with no credentials + mock_provider_configuration = MagicMock() + mock_provider_configuration.get_current_credentials.return_value = None + mock_provider_manager.get_configurations.return_value = {"openai": mock_provider_configuration} + + # Act: Execute the method under test + service = ModelProviderService() + result = service.get_model_parameter_rules(tenant.id, "openai", "gpt-4") + + # Assert: Verify the expected outcomes + assert result is not None + assert len(result) == 0 + + # Verify mock interactions + mock_provider_manager.get_configurations.assert_called_once_with(tenant.id) + mock_provider_configuration.get_current_credentials.assert_called_once_with( + model_type=ModelType.LLM, model="gpt-4" + ) + + def test_get_model_parameter_rules_provider_not_found( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test parameter rules retrieval when provider does not exist. + + This test verifies: + - Proper error handling for non-existent provider + - ValueError is raised with appropriate message + - Mock interactions with ProviderManager + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Mock ProviderManager to return empty configurations + mock_provider_manager = mock_external_service_dependencies["provider_manager"].return_value + mock_provider_manager.get_configurations.return_value = {} + + # Act & Assert: Execute the method under test and expect ValueError + service = ModelProviderService() + with pytest.raises(ValueError, match="Provider openai does not exist."): + service.get_model_parameter_rules(tenant.id, "openai", "gpt-4") + + # Verify mock interactions + mock_provider_manager.get_configurations.assert_called_once_with(tenant.id) diff --git a/api/tests/test_containers_integration_tests/services/test_saved_message_service.py b/api/tests/test_containers_integration_tests/services/test_saved_message_service.py new file mode 100644 index 0000000000..9e6b9837ae --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_saved_message_service.py @@ -0,0 +1,620 @@ +from unittest.mock import patch + +import pytest +from faker import Faker + +from models.model import EndUser, Message +from models.web import SavedMessage +from services.app_service import AppService +from services.saved_message_service import SavedMessageService + + +class TestSavedMessageService: + """Integration tests for SavedMessageService using testcontainers.""" + + @pytest.fixture + def mock_external_service_dependencies(self): + """Mock setup for external service dependencies.""" + with ( + patch("services.account_service.FeatureService") as mock_account_feature_service, + patch("services.app_service.ModelManager") as mock_model_manager, + patch("services.saved_message_service.MessageService") as mock_message_service, + ): + # Setup default mock returns + mock_account_feature_service.get_system_features.return_value.is_allow_register = True + + # Mock ModelManager for app creation + mock_model_instance = mock_model_manager.return_value + mock_model_instance.get_default_model_instance.return_value = None + mock_model_instance.get_default_provider_model_name.return_value = ("openai", "gpt-3.5-turbo") + + # Mock MessageService + mock_message_service.get_message.return_value = None + mock_message_service.pagination_by_last_id.return_value = None + + yield { + "account_feature_service": mock_account_feature_service, + "model_manager": mock_model_manager, + "message_service": mock_message_service, + } + + def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies): + """ + Helper method to create a test app and account for testing. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + mock_external_service_dependencies: Mock dependencies + + Returns: + tuple: (app, account) - Created app and account instances + """ + fake = Faker() + + # Setup mocks for account creation + mock_external_service_dependencies[ + "account_feature_service" + ].get_system_features.return_value.is_allow_register = True + + # Create account and tenant first + from services.account_service import AccountService, TenantService + + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Create app with realistic data + app_args = { + "name": fake.company(), + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🤖", + "icon_background": "#FF6B6B", + "api_rph": 100, + "api_rpm": 10, + } + + app_service = AppService() + app = app_service.create_app(tenant.id, app_args, account) + + return app, account + + def _create_test_end_user(self, db_session_with_containers, app): + """ + Helper method to create a test end user for testing. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + app: App instance to associate the end user with + + Returns: + EndUser: Created end user instance + """ + fake = Faker() + + end_user = EndUser( + tenant_id=app.tenant_id, + app_id=app.id, + external_user_id=fake.uuid4(), + name=fake.name(), + type="normal", + session_id=fake.uuid4(), + is_anonymous=False, + ) + + from extensions.ext_database import db + + db.session.add(end_user) + db.session.commit() + + return end_user + + def _create_test_message(self, db_session_with_containers, app, user): + """ + Helper method to create a test message for testing. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + app: App instance to associate the message with + user: User instance (Account or EndUser) to associate the message with + + Returns: + Message: Created message instance + """ + fake = Faker() + + # Create a simple conversation first + from models.model import Conversation + + conversation = Conversation( + app_id=app.id, + from_source="account" if hasattr(user, "current_tenant") else "end_user", + from_end_user_id=user.id if not hasattr(user, "current_tenant") else None, + from_account_id=user.id if hasattr(user, "current_tenant") else None, + name=fake.sentence(nb_words=3), + inputs={}, + status="normal", + mode="chat", + ) + + from extensions.ext_database import db + + db.session.add(conversation) + db.session.commit() + + # Create message + message = Message( + app_id=app.id, + conversation_id=conversation.id, + from_source="account" if hasattr(user, "current_tenant") else "end_user", + from_end_user_id=user.id if not hasattr(user, "current_tenant") else None, + from_account_id=user.id if hasattr(user, "current_tenant") else None, + inputs={}, + query=fake.sentence(nb_words=5), + message=fake.text(max_nb_chars=100), + answer=fake.text(max_nb_chars=200), + message_tokens=50, + answer_tokens=100, + message_unit_price=0.001, + answer_unit_price=0.002, + total_price=0.003, + currency="USD", + status="success", + ) + + db.session.add(message) + db.session.commit() + + return message + + def test_pagination_by_last_id_success_with_account_user( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test successful pagination by last ID with account user. + + This test verifies: + - Proper pagination with account user + - Correct filtering by app_id and user + - Proper role identification for account users + - MessageService integration + """ + # Arrange: Create test data + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create test messages + message1 = self._create_test_message(db_session_with_containers, app, account) + message2 = self._create_test_message(db_session_with_containers, app, account) + + # Create saved messages + saved_message1 = SavedMessage( + app_id=app.id, + message_id=message1.id, + created_by_role="account", + created_by=account.id, + ) + saved_message2 = SavedMessage( + app_id=app.id, + message_id=message2.id, + created_by_role="account", + created_by=account.id, + ) + + from extensions.ext_database import db + + db.session.add_all([saved_message1, saved_message2]) + db.session.commit() + + # Mock MessageService.pagination_by_last_id return value + from libs.infinite_scroll_pagination import InfiniteScrollPagination + + mock_pagination = InfiniteScrollPagination(data=[message1, message2], limit=10, has_more=False) + mock_external_service_dependencies["message_service"].pagination_by_last_id.return_value = mock_pagination + + # Act: Execute the method under test + result = SavedMessageService.pagination_by_last_id(app_model=app, user=account, last_id=None, limit=10) + + # Assert: Verify the expected outcomes + assert result is not None + assert result.data == [message1, message2] + assert result.limit == 10 + assert result.has_more is False + + # Verify MessageService was called with correct parameters + # Sort the IDs to handle database query order variations + expected_include_ids = sorted([message1.id, message2.id]) + actual_call = mock_external_service_dependencies["message_service"].pagination_by_last_id.call_args + actual_include_ids = sorted(actual_call.kwargs.get("include_ids", [])) + + assert actual_call.kwargs["app_model"] == app + assert actual_call.kwargs["user"] == account + assert actual_call.kwargs["last_id"] is None + assert actual_call.kwargs["limit"] == 10 + assert actual_include_ids == expected_include_ids + + # Verify database state + db.session.refresh(saved_message1) + db.session.refresh(saved_message2) + assert saved_message1.id is not None + assert saved_message2.id is not None + assert saved_message1.created_by_role == "account" + assert saved_message2.created_by_role == "account" + + def test_pagination_by_last_id_success_with_end_user( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test successful pagination by last ID with end user. + + This test verifies: + - Proper pagination with end user + - Correct filtering by app_id and user + - Proper role identification for end users + - MessageService integration + """ + # Arrange: Create test data + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + end_user = self._create_test_end_user(db_session_with_containers, app) + + # Create test messages + message1 = self._create_test_message(db_session_with_containers, app, end_user) + message2 = self._create_test_message(db_session_with_containers, app, end_user) + + # Create saved messages + saved_message1 = SavedMessage( + app_id=app.id, + message_id=message1.id, + created_by_role="end_user", + created_by=end_user.id, + ) + saved_message2 = SavedMessage( + app_id=app.id, + message_id=message2.id, + created_by_role="end_user", + created_by=end_user.id, + ) + + from extensions.ext_database import db + + db.session.add_all([saved_message1, saved_message2]) + db.session.commit() + + # Mock MessageService.pagination_by_last_id return value + from libs.infinite_scroll_pagination import InfiniteScrollPagination + + mock_pagination = InfiniteScrollPagination(data=[message1, message2], limit=5, has_more=True) + mock_external_service_dependencies["message_service"].pagination_by_last_id.return_value = mock_pagination + + # Act: Execute the method under test + result = SavedMessageService.pagination_by_last_id( + app_model=app, user=end_user, last_id="test_last_id", limit=5 + ) + + # Assert: Verify the expected outcomes + assert result is not None + assert result.data == [message1, message2] + assert result.limit == 5 + assert result.has_more is True + + # Verify MessageService was called with correct parameters + # Sort the IDs to handle database query order variations + expected_include_ids = sorted([message1.id, message2.id]) + actual_call = mock_external_service_dependencies["message_service"].pagination_by_last_id.call_args + actual_include_ids = sorted(actual_call.kwargs.get("include_ids", [])) + + assert actual_call.kwargs["app_model"] == app + assert actual_call.kwargs["user"] == end_user + assert actual_call.kwargs["last_id"] == "test_last_id" + assert actual_call.kwargs["limit"] == 5 + assert actual_include_ids == expected_include_ids + + # Verify database state + db.session.refresh(saved_message1) + db.session.refresh(saved_message2) + assert saved_message1.id is not None + assert saved_message2.id is not None + assert saved_message1.created_by_role == "end_user" + assert saved_message2.created_by_role == "end_user" + + def test_save_success_with_new_message(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful save of a new message. + + This test verifies: + - Proper creation of new saved message + - Correct database state after save + - Proper relationship establishment + - MessageService integration for message retrieval + """ + # Arrange: Create test data + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + message = self._create_test_message(db_session_with_containers, app, account) + + # Mock MessageService.get_message return value + mock_external_service_dependencies["message_service"].get_message.return_value = message + + # Act: Execute the method under test + SavedMessageService.save(app_model=app, user=account, message_id=message.id) + + # Assert: Verify the expected outcomes + # Check if saved message was created in database + from extensions.ext_database import db + + saved_message = ( + db.session.query(SavedMessage) + .where( + SavedMessage.app_id == app.id, + SavedMessage.message_id == message.id, + SavedMessage.created_by_role == "account", + SavedMessage.created_by == account.id, + ) + .first() + ) + + assert saved_message is not None + assert saved_message.app_id == app.id + assert saved_message.message_id == message.id + assert saved_message.created_by_role == "account" + assert saved_message.created_by == account.id + assert saved_message.created_at is not None + + # Verify MessageService.get_message was called + mock_external_service_dependencies["message_service"].get_message.assert_called_once_with( + app_model=app, user=account, message_id=message.id + ) + + # Verify database state + db.session.refresh(saved_message) + assert saved_message.id is not None + + def test_pagination_by_last_id_error_no_user(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test error handling when no user is provided. + + This test verifies: + - Proper error handling for missing user + - ValueError is raised when user is None + - No database operations are performed + """ + # Arrange: Create test data + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Act & Assert: Verify proper error handling + with pytest.raises(ValueError) as exc_info: + SavedMessageService.pagination_by_last_id(app_model=app, user=None, last_id=None, limit=10) + + assert "User is required" in str(exc_info.value) + + # Verify no database operations were performed + from extensions.ext_database import db + + saved_messages = db.session.query(SavedMessage).all() + assert len(saved_messages) == 0 + + def test_save_error_no_user(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test error handling when saving message with no user. + + This test verifies: + - Method returns early when user is None + - No database operations are performed + - No exceptions are raised + """ + # Arrange: Create test data + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + message = self._create_test_message(db_session_with_containers, app, account) + + # Act: Execute the method under test with None user + result = SavedMessageService.save(app_model=app, user=None, message_id=message.id) + + # Assert: Verify the expected outcomes + assert result is None + + # Verify no saved message was created + from extensions.ext_database import db + + saved_message = ( + db.session.query(SavedMessage) + .where( + SavedMessage.app_id == app.id, + SavedMessage.message_id == message.id, + ) + .first() + ) + + assert saved_message is None + + def test_delete_success_existing_message(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful deletion of an existing saved message. + + This test verifies: + - Proper deletion of existing saved message + - Correct database state after deletion + - No errors during deletion process + """ + # Arrange: Create test data + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + message = self._create_test_message(db_session_with_containers, app, account) + + # Create a saved message first + saved_message = SavedMessage( + app_id=app.id, + message_id=message.id, + created_by_role="account", + created_by=account.id, + ) + + from extensions.ext_database import db + + db.session.add(saved_message) + db.session.commit() + + # Verify saved message exists + assert ( + db.session.query(SavedMessage) + .where( + SavedMessage.app_id == app.id, + SavedMessage.message_id == message.id, + SavedMessage.created_by_role == "account", + SavedMessage.created_by == account.id, + ) + .first() + is not None + ) + + # Act: Execute the method under test + SavedMessageService.delete(app_model=app, user=account, message_id=message.id) + + # Assert: Verify the expected outcomes + # Check if saved message was deleted from database + deleted_saved_message = ( + db.session.query(SavedMessage) + .where( + SavedMessage.app_id == app.id, + SavedMessage.message_id == message.id, + SavedMessage.created_by_role == "account", + SavedMessage.created_by == account.id, + ) + .first() + ) + + assert deleted_saved_message is None + + # Verify database state + db.session.commit() + # The message should still exist, only the saved_message should be deleted + assert db.session.query(Message).where(Message.id == message.id).first() is not None + + def test_pagination_by_last_id_error_no_user(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test error handling when no user is provided. + + This test verifies: + - Proper error handling for missing user + - ValueError is raised when user is None + - No database operations are performed + """ + # Arrange: Create test data + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Act & Assert: Verify proper error handling + with pytest.raises(ValueError) as exc_info: + SavedMessageService.pagination_by_last_id(app_model=app, user=None, last_id=None, limit=10) + + assert "User is required" in str(exc_info.value) + + # Verify no database operations were performed for this specific test + # Note: We don't check total count as other tests may have created data + # Instead, we verify that the error was properly raised + pass + + def test_save_error_no_user(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test error handling when saving message with no user. + + This test verifies: + - Method returns early when user is None + - No database operations are performed + - No exceptions are raised + """ + # Arrange: Create test data + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + message = self._create_test_message(db_session_with_containers, app, account) + + # Act: Execute the method under test with None user + result = SavedMessageService.save(app_model=app, user=None, message_id=message.id) + + # Assert: Verify the expected outcomes + assert result is None + + # Verify no saved message was created + from extensions.ext_database import db + + saved_message = ( + db.session.query(SavedMessage) + .where( + SavedMessage.app_id == app.id, + SavedMessage.message_id == message.id, + ) + .first() + ) + + assert saved_message is None + + def test_delete_success_existing_message(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful deletion of an existing saved message. + + This test verifies: + - Proper deletion of existing saved message + - Correct database state after deletion + - No errors during deletion process + """ + # Arrange: Create test data + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + message = self._create_test_message(db_session_with_containers, app, account) + + # Create a saved message first + saved_message = SavedMessage( + app_id=app.id, + message_id=message.id, + created_by_role="account", + created_by=account.id, + ) + + from extensions.ext_database import db + + db.session.add(saved_message) + db.session.commit() + + # Verify saved message exists + assert ( + db.session.query(SavedMessage) + .where( + SavedMessage.app_id == app.id, + SavedMessage.message_id == message.id, + SavedMessage.created_by_role == "account", + SavedMessage.created_by == account.id, + ) + .first() + is not None + ) + + # Act: Execute the method under test + SavedMessageService.delete(app_model=app, user=account, message_id=message.id) + + # Assert: Verify the expected outcomes + # Check if saved message was deleted from database + deleted_saved_message = ( + db.session.query(SavedMessage) + .where( + SavedMessage.app_id == app.id, + SavedMessage.message_id == message.id, + SavedMessage.created_by_role == "account", + SavedMessage.created_by == account.id, + ) + .first() + ) + + assert deleted_saved_message is None + + # Verify database state + db.session.commit() + # The message should still exist, only the saved_message should be deleted + assert db.session.query(Message).where(Message.id == message.id).first() is not None diff --git a/api/tests/test_containers_integration_tests/services/test_tag_service.py b/api/tests/test_containers_integration_tests/services/test_tag_service.py new file mode 100644 index 0000000000..2d5cdf426d --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_tag_service.py @@ -0,0 +1,1192 @@ +from unittest.mock import patch + +import pytest +from faker import Faker +from werkzeug.exceptions import NotFound + +from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models.dataset import Dataset +from models.model import App, Tag, TagBinding +from services.tag_service import TagService + + +class TestTagService: + """Integration tests for TagService using testcontainers.""" + + @pytest.fixture + def mock_external_service_dependencies(self): + """Mock setup for external service dependencies.""" + with ( + patch("services.tag_service.current_user") as mock_current_user, + ): + # Setup default mock returns + mock_current_user.current_tenant_id = "test-tenant-id" + mock_current_user.id = "test-user-id" + + yield { + "current_user": mock_current_user, + } + + def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies): + """ + Helper method to create a test account and tenant for testing. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + mock_external_service_dependencies: Mock dependencies + + Returns: + tuple: (account, tenant) - Created account and tenant instances + """ + fake = Faker() + + # Create account + account = Account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + status="active", + ) + + from extensions.ext_database import db + + db.session.add(account) + db.session.commit() + + # Create tenant for the account + tenant = Tenant( + name=fake.company(), + status="normal", + ) + db.session.add(tenant) + db.session.commit() + + # Create tenant-account join + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=TenantAccountRole.OWNER.value, + current=True, + ) + db.session.add(join) + db.session.commit() + + # Set current tenant for account + account.current_tenant = tenant + + # Update mock to use real tenant ID + mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id + mock_external_service_dependencies["current_user"].id = account.id + + return account, tenant + + def _create_test_dataset(self, db_session_with_containers, mock_external_service_dependencies, tenant_id): + """ + Helper method to create a test dataset for testing. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + mock_external_service_dependencies: Mock dependencies + tenant_id: Tenant ID for the dataset + + Returns: + Dataset: Created dataset instance + """ + fake = Faker() + + dataset = Dataset( + name=fake.company(), + description=fake.text(max_nb_chars=100), + provider="vendor", + permission="only_me", + data_source_type="upload", + indexing_technique="high_quality", + tenant_id=tenant_id, + created_by=mock_external_service_dependencies["current_user"].id, + ) + + from extensions.ext_database import db + + db.session.add(dataset) + db.session.commit() + + return dataset + + def _create_test_app(self, db_session_with_containers, mock_external_service_dependencies, tenant_id): + """ + Helper method to create a test app for testing. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + mock_external_service_dependencies: Mock dependencies + tenant_id: Tenant ID for the app + + Returns: + App: Created app instance + """ + fake = Faker() + + app = App( + name=fake.company(), + description=fake.text(max_nb_chars=100), + mode="chat", + icon_type="emoji", + icon="🤖", + icon_background="#FF6B6B", + enable_site=False, + enable_api=False, + tenant_id=tenant_id, + created_by=mock_external_service_dependencies["current_user"].id, + ) + + from extensions.ext_database import db + + db.session.add(app) + db.session.commit() + + return app + + def _create_test_tags( + self, db_session_with_containers, mock_external_service_dependencies, tenant_id, tag_type, count=3 + ): + """ + Helper method to create test tags for testing. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + mock_external_service_dependencies: Mock dependencies + tenant_id: Tenant ID for the tags + tag_type: Type of tags to create + count: Number of tags to create + + Returns: + list: List of created tag instances + """ + fake = Faker() + tags = [] + + for i in range(count): + tag = Tag( + name=f"tag_{tag_type}_{i}_{fake.word()}", + type=tag_type, + tenant_id=tenant_id, + created_by=mock_external_service_dependencies["current_user"].id, + ) + tags.append(tag) + + from extensions.ext_database import db + + for tag in tags: + db.session.add(tag) + db.session.commit() + + return tags + + def _create_test_tag_bindings( + self, db_session_with_containers, mock_external_service_dependencies, tags, target_id, tenant_id + ): + """ + Helper method to create test tag bindings for testing. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + mock_external_service_dependencies: Mock dependencies + tags: List of tags to bind + target_id: Target ID to bind tags to + tenant_id: Tenant ID for the bindings + + Returns: + list: List of created tag binding instances + """ + tag_bindings = [] + + for tag in tags: + tag_binding = TagBinding( + tag_id=tag.id, + target_id=target_id, + tenant_id=tenant_id, + created_by=mock_external_service_dependencies["current_user"].id, + ) + tag_bindings.append(tag_binding) + + from extensions.ext_database import db + + for tag_binding in tag_bindings: + db.session.add(tag_binding) + db.session.commit() + + return tag_bindings + + def test_get_tags_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful retrieval of tags with binding count. + + This test verifies: + - Proper tag retrieval with binding count + - Correct filtering by tag type and tenant + - Proper ordering by creation date + - Binding count calculation + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create tags + tags = self._create_test_tags( + db_session_with_containers, mock_external_service_dependencies, tenant.id, "knowledge", 3 + ) + + # Create dataset and bind tags + dataset = self._create_test_dataset(db_session_with_containers, mock_external_service_dependencies, tenant.id) + self._create_test_tag_bindings( + db_session_with_containers, mock_external_service_dependencies, tags[:2], dataset.id, tenant.id + ) + + # Act: Execute the method under test + result = TagService.get_tags("knowledge", tenant.id) + + # Assert: Verify the expected outcomes + assert result is not None + assert len(result) == 3 + + # Verify tag data structure + for tag_result in result: + assert hasattr(tag_result, "id") + assert hasattr(tag_result, "type") + assert hasattr(tag_result, "name") + assert hasattr(tag_result, "binding_count") + assert tag_result.type == "knowledge" + + # Verify binding count + tag_with_bindings = next((t for t in result if t.binding_count > 0), None) + assert tag_with_bindings is not None + assert tag_with_bindings.binding_count >= 1 + + # Verify ordering (newest first) - note: created_at is not in SELECT but used in ORDER BY + # The ordering is handled by the database, we just verify the results are returned + assert len(result) == 3 + + def test_get_tags_with_keyword_filter(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test tag retrieval with keyword filtering. + + This test verifies: + - Proper keyword filtering functionality + - Case-insensitive search + - Partial match functionality + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create tags with specific names + tags = self._create_test_tags( + db_session_with_containers, mock_external_service_dependencies, tenant.id, "app", 3 + ) + + # Update tag names to make them searchable + from extensions.ext_database import db + + tags[0].name = "python_development" + tags[1].name = "machine_learning" + tags[2].name = "web_development" + db.session.commit() + + # Act: Execute the method under test with keyword filter + result = TagService.get_tags("app", tenant.id, keyword="development") + + # Assert: Verify the expected outcomes + assert result is not None + assert len(result) == 2 # Should find python_development and web_development + + # Verify filtered results contain the keyword + for tag_result in result: + assert "development" in tag_result.name.lower() + + # Verify no results for non-matching keyword + result_no_match = TagService.get_tags("app", tenant.id, keyword="nonexistent") + assert len(result_no_match) == 0 + + def test_get_tags_empty_result(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test tag retrieval when no tags exist. + + This test verifies: + - Proper handling of empty tag sets + - Correct return value for no results + """ + # Arrange: Create test data without tags + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Act: Execute the method under test + result = TagService.get_tags("knowledge", tenant.id) + + # Assert: Verify the expected outcomes + assert result is not None + assert len(result) == 0 + assert isinstance(result, list) + + def test_get_target_ids_by_tag_ids_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful retrieval of target IDs by tag IDs. + + This test verifies: + - Proper target ID retrieval for valid tag IDs + - Correct filtering by tag type and tenant + - Proper handling of tag bindings + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create tags + tags = self._create_test_tags( + db_session_with_containers, mock_external_service_dependencies, tenant.id, "knowledge", 3 + ) + + # Create multiple datasets and bind tags + datasets = [] + for i in range(2): + dataset = self._create_test_dataset( + db_session_with_containers, mock_external_service_dependencies, tenant.id + ) + datasets.append(dataset) + # Bind first two tags to first dataset, last tag to second dataset + tags_to_bind = tags[:2] if i == 0 else tags[2:] + self._create_test_tag_bindings( + db_session_with_containers, mock_external_service_dependencies, tags_to_bind, dataset.id, tenant.id + ) + + # Act: Execute the method under test + tag_ids = [tag.id for tag in tags] + result = TagService.get_target_ids_by_tag_ids("knowledge", tenant.id, tag_ids) + + # Assert: Verify the expected outcomes + assert result is not None + assert len(result) == 3 # Should find 3 target IDs (2 from first dataset, 1 from second) + + # Verify all dataset IDs are returned + dataset_ids = [dataset.id for dataset in datasets] + for target_id in result: + assert target_id in dataset_ids + + # Verify the first dataset appears twice (for the first two tags) + first_dataset_count = result.count(datasets[0].id) + assert first_dataset_count == 2 + + # Verify the second dataset appears once (for the last tag) + second_dataset_count = result.count(datasets[1].id) + assert second_dataset_count == 1 + + def test_get_target_ids_by_tag_ids_empty_tag_ids( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test target ID retrieval with empty tag IDs list. + + This test verifies: + - Proper handling of empty tag IDs + - Correct return value for empty input + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Act: Execute the method under test with empty tag IDs + result = TagService.get_target_ids_by_tag_ids("knowledge", tenant.id, []) + + # Assert: Verify the expected outcomes + assert result is not None + assert len(result) == 0 + assert isinstance(result, list) + + def test_get_target_ids_by_tag_ids_no_matching_tags( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test target ID retrieval when no tags match the criteria. + + This test verifies: + - Proper handling of non-existent tag IDs + - Correct return value for no matches + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create non-existent tag IDs + import uuid + + non_existent_tag_ids = [str(uuid.uuid4()), str(uuid.uuid4())] + + # Act: Execute the method under test + result = TagService.get_target_ids_by_tag_ids("knowledge", tenant.id, non_existent_tag_ids) + + # Assert: Verify the expected outcomes + assert result is not None + assert len(result) == 0 + assert isinstance(result, list) + + def test_get_tag_by_tag_name_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful retrieval of tags by tag name. + + This test verifies: + - Proper tag retrieval by name + - Correct filtering by tag type and tenant + - Proper return value structure + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create tags with specific names + tags = self._create_test_tags( + db_session_with_containers, mock_external_service_dependencies, tenant.id, "app", 2 + ) + + # Update tag names to make them searchable + from extensions.ext_database import db + + tags[0].name = "python_tag" + tags[1].name = "ml_tag" + db.session.commit() + + # Act: Execute the method under test + result = TagService.get_tag_by_tag_name("app", tenant.id, "python_tag") + + # Assert: Verify the expected outcomes + assert result is not None + assert len(result) == 1 + assert result[0].name == "python_tag" + assert result[0].type == "app" + assert result[0].tenant_id == tenant.id + + def test_get_tag_by_tag_name_no_matches(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test tag retrieval by name when no matches exist. + + This test verifies: + - Proper handling of non-existent tag names + - Correct return value for no matches + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Act: Execute the method under test with non-existent tag name + result = TagService.get_tag_by_tag_name("knowledge", tenant.id, "nonexistent_tag") + + # Assert: Verify the expected outcomes + assert result is not None + assert len(result) == 0 + assert isinstance(result, list) + + def test_get_tag_by_tag_name_empty_parameters(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test tag retrieval by name with empty parameters. + + This test verifies: + - Proper handling of empty tag type + - Proper handling of empty tag name + - Correct return value for invalid input + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Act: Execute the method under test with empty parameters + result_empty_type = TagService.get_tag_by_tag_name("", tenant.id, "test_tag") + result_empty_name = TagService.get_tag_by_tag_name("knowledge", tenant.id, "") + + # Assert: Verify the expected outcomes + assert result_empty_type is not None + assert len(result_empty_type) == 0 + assert result_empty_name is not None + assert len(result_empty_name) == 0 + + def test_get_tags_by_target_id_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful retrieval of tags by target ID. + + This test verifies: + - Proper tag retrieval for a specific target + - Correct filtering by tag type and tenant + - Proper join with tag bindings + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create tags + tags = self._create_test_tags( + db_session_with_containers, mock_external_service_dependencies, tenant.id, "app", 3 + ) + + # Create app and bind tags + app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, tenant.id) + self._create_test_tag_bindings( + db_session_with_containers, mock_external_service_dependencies, tags, app.id, tenant.id + ) + + # Act: Execute the method under test + result = TagService.get_tags_by_target_id("app", tenant.id, app.id) + + # Assert: Verify the expected outcomes + assert result is not None + assert len(result) == 3 + + # Verify all tags are returned + for tag in result: + assert tag.type == "app" + assert tag.tenant_id == tenant.id + assert tag.id in [t.id for t in tags] + + def test_get_tags_by_target_id_no_bindings(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test tag retrieval by target ID when no tags are bound. + + This test verifies: + - Proper handling of targets with no tag bindings + - Correct return value for no results + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create app without binding any tags + app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, tenant.id) + + # Act: Execute the method under test + result = TagService.get_tags_by_target_id("app", tenant.id, app.id) + + # Assert: Verify the expected outcomes + assert result is not None + assert len(result) == 0 + assert isinstance(result, list) + + def test_save_tags_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful tag creation. + + This test verifies: + - Proper tag creation with all required fields + - Correct database state after creation + - Proper UUID generation + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + tag_args = {"name": "test_tag_name", "type": "knowledge"} + + # Act: Execute the method under test + result = TagService.save_tags(tag_args) + + # Assert: Verify the expected outcomes + assert result is not None + assert result.name == "test_tag_name" + assert result.type == "knowledge" + assert result.tenant_id == tenant.id + assert result.created_by == account.id + assert result.id is not None + + # Verify database state + from extensions.ext_database import db + + db.session.refresh(result) + assert result.id is not None + + # Verify tag was actually saved to database + saved_tag = db.session.query(Tag).where(Tag.id == result.id).first() + assert saved_tag is not None + assert saved_tag.name == "test_tag_name" + + def test_save_tags_duplicate_name_error(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test tag creation with duplicate name. + + This test verifies: + - Proper error handling for duplicate tag names + - Correct exception type and message + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create first tag + tag_args = {"name": "duplicate_tag", "type": "app"} + TagService.save_tags(tag_args) + + # Act & Assert: Verify proper error handling + with pytest.raises(ValueError) as exc_info: + TagService.save_tags(tag_args) + assert "Tag name already exists" in str(exc_info.value) + + def test_update_tags_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful tag update. + + This test verifies: + - Proper tag update with new name + - Correct database state after update + - Proper error handling for non-existent tags + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create a tag to update + tag_args = {"name": "original_name", "type": "knowledge"} + tag = TagService.save_tags(tag_args) + + # Update args + update_args = {"name": "updated_name", "type": "knowledge"} + + # Act: Execute the method under test + result = TagService.update_tags(update_args, tag.id) + + # Assert: Verify the expected outcomes + assert result is not None + assert result.name == "updated_name" + assert result.type == "knowledge" + assert result.id == tag.id + + # Verify database state + from extensions.ext_database import db + + db.session.refresh(result) + assert result.name == "updated_name" + + # Verify tag was actually updated in database + updated_tag = db.session.query(Tag).where(Tag.id == tag.id).first() + assert updated_tag is not None + assert updated_tag.name == "updated_name" + + def test_update_tags_not_found_error(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test tag update for non-existent tag. + + This test verifies: + - Proper error handling for non-existent tags + - Correct exception type + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create non-existent tag ID + import uuid + + non_existent_tag_id = str(uuid.uuid4()) + + update_args = {"name": "updated_name", "type": "knowledge"} + + # Act & Assert: Verify proper error handling + with pytest.raises(NotFound) as exc_info: + TagService.update_tags(update_args, non_existent_tag_id) + assert "Tag not found" in str(exc_info.value) + + def test_update_tags_duplicate_name_error(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test tag update with duplicate name. + + This test verifies: + - Proper error handling for duplicate tag names during update + - Correct exception type and message + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create two tags + tag1_args = {"name": "first_tag", "type": "app"} + tag1 = TagService.save_tags(tag1_args) + + tag2_args = {"name": "second_tag", "type": "app"} + tag2 = TagService.save_tags(tag2_args) + + # Try to update second tag with first tag's name + update_args = {"name": "first_tag", "type": "app"} + + # Act & Assert: Verify proper error handling + with pytest.raises(ValueError) as exc_info: + TagService.update_tags(update_args, tag2.id) + assert "Tag name already exists" in str(exc_info.value) + + def test_get_tag_binding_count_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful retrieval of tag binding count. + + This test verifies: + - Proper binding count calculation + - Correct handling of tags with no bindings + - Proper database query execution + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create tags + tags = self._create_test_tags( + db_session_with_containers, mock_external_service_dependencies, tenant.id, "knowledge", 2 + ) + + # Create dataset and bind first tag + dataset = self._create_test_dataset(db_session_with_containers, mock_external_service_dependencies, tenant.id) + self._create_test_tag_bindings( + db_session_with_containers, mock_external_service_dependencies, [tags[0]], dataset.id, tenant.id + ) + + # Act: Execute the method under test + result_tag_with_bindings = TagService.get_tag_binding_count(tags[0].id) + result_tag_without_bindings = TagService.get_tag_binding_count(tags[1].id) + + # Assert: Verify the expected outcomes + assert result_tag_with_bindings == 1 + assert result_tag_without_bindings == 0 + + def test_get_tag_binding_count_non_existent_tag( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test binding count retrieval for non-existent tag. + + This test verifies: + - Proper handling of non-existent tag IDs + - Correct return value for non-existent tags + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create non-existent tag ID + import uuid + + non_existent_tag_id = str(uuid.uuid4()) + + # Act: Execute the method under test + result = TagService.get_tag_binding_count(non_existent_tag_id) + + # Assert: Verify the expected outcomes + assert result == 0 + + def test_delete_tag_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful tag deletion. + + This test verifies: + - Proper tag deletion from database + - Proper cleanup of associated tag bindings + - Correct database state after deletion + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create tag with bindings + tag = self._create_test_tags( + db_session_with_containers, mock_external_service_dependencies, tenant.id, "app", 1 + )[0] + + # Create app and bind tag + app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, tenant.id) + self._create_test_tag_bindings( + db_session_with_containers, mock_external_service_dependencies, [tag], app.id, tenant.id + ) + + # Verify tag and binding exist before deletion + from extensions.ext_database import db + + tag_before = db.session.query(Tag).where(Tag.id == tag.id).first() + assert tag_before is not None + + binding_before = db.session.query(TagBinding).where(TagBinding.tag_id == tag.id).first() + assert binding_before is not None + + # Act: Execute the method under test + TagService.delete_tag(tag.id) + + # Assert: Verify the expected outcomes + # Verify tag was deleted + tag_after = db.session.query(Tag).where(Tag.id == tag.id).first() + assert tag_after is None + + # Verify tag binding was deleted + binding_after = db.session.query(TagBinding).where(TagBinding.tag_id == tag.id).first() + assert binding_after is None + + def test_delete_tag_not_found_error(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test tag deletion for non-existent tag. + + This test verifies: + - Proper error handling for non-existent tags + - Correct exception type + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create non-existent tag ID + import uuid + + non_existent_tag_id = str(uuid.uuid4()) + + # Act & Assert: Verify proper error handling + with pytest.raises(NotFound) as exc_info: + TagService.delete_tag(non_existent_tag_id) + assert "Tag not found" in str(exc_info.value) + + def test_save_tag_binding_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful tag binding creation. + + This test verifies: + - Proper tag binding creation + - Correct handling of duplicate bindings + - Proper database state after creation + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create tags + tags = self._create_test_tags( + db_session_with_containers, mock_external_service_dependencies, tenant.id, "knowledge", 2 + ) + + # Create dataset + dataset = self._create_test_dataset(db_session_with_containers, mock_external_service_dependencies, tenant.id) + + # Act: Execute the method under test + binding_args = {"type": "knowledge", "target_id": dataset.id, "tag_ids": [tag.id for tag in tags]} + TagService.save_tag_binding(binding_args) + + # Assert: Verify the expected outcomes + from extensions.ext_database import db + + # Verify tag bindings were created + for tag in tags: + binding = ( + db.session.query(TagBinding) + .where(TagBinding.tag_id == tag.id, TagBinding.target_id == dataset.id) + .first() + ) + assert binding is not None + assert binding.tenant_id == tenant.id + assert binding.created_by == account.id + + def test_save_tag_binding_duplicate_handling(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test tag binding creation with duplicate bindings. + + This test verifies: + - Proper handling of duplicate tag bindings + - No errors when trying to create existing bindings + - Correct database state after operation + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create tag + tag = self._create_test_tags( + db_session_with_containers, mock_external_service_dependencies, tenant.id, "app", 1 + )[0] + + # Create app + app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, tenant.id) + + # Create first binding + binding_args = {"type": "app", "target_id": app.id, "tag_ids": [tag.id]} + TagService.save_tag_binding(binding_args) + + # Act: Try to create duplicate binding + TagService.save_tag_binding(binding_args) + + # Assert: Verify the expected outcomes + from extensions.ext_database import db + + # Verify only one binding exists + bindings = db.session.query(TagBinding).where(TagBinding.tag_id == tag.id, TagBinding.target_id == app.id).all() + assert len(bindings) == 1 + + def test_save_tag_binding_invalid_target_type(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test tag binding creation with invalid target type. + + This test verifies: + - Proper error handling for invalid target types + - Correct exception type + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create tag + tag = self._create_test_tags( + db_session_with_containers, mock_external_service_dependencies, tenant.id, "knowledge", 1 + )[0] + + # Create non-existent target ID + import uuid + + non_existent_target_id = str(uuid.uuid4()) + + # Act & Assert: Verify proper error handling + binding_args = {"type": "invalid_type", "target_id": non_existent_target_id, "tag_ids": [tag.id]} + + with pytest.raises(NotFound) as exc_info: + TagService.save_tag_binding(binding_args) + assert "Invalid binding type" in str(exc_info.value) + + def test_delete_tag_binding_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful tag binding deletion. + + This test verifies: + - Proper tag binding deletion from database + - Correct database state after deletion + - Proper error handling for non-existent bindings + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create tag + tag = self._create_test_tags( + db_session_with_containers, mock_external_service_dependencies, tenant.id, "knowledge", 1 + )[0] + + # Create dataset and bind tag + dataset = self._create_test_dataset(db_session_with_containers, mock_external_service_dependencies, tenant.id) + self._create_test_tag_bindings( + db_session_with_containers, mock_external_service_dependencies, [tag], dataset.id, tenant.id + ) + + # Verify binding exists before deletion + from extensions.ext_database import db + + binding_before = ( + db.session.query(TagBinding).where(TagBinding.tag_id == tag.id, TagBinding.target_id == dataset.id).first() + ) + assert binding_before is not None + + # Act: Execute the method under test + delete_args = {"type": "knowledge", "target_id": dataset.id, "tag_id": tag.id} + TagService.delete_tag_binding(delete_args) + + # Assert: Verify the expected outcomes + # Verify tag binding was deleted + binding_after = ( + db.session.query(TagBinding).where(TagBinding.tag_id == tag.id, TagBinding.target_id == dataset.id).first() + ) + assert binding_after is None + + def test_delete_tag_binding_non_existent_binding( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test tag binding deletion for non-existent binding. + + This test verifies: + - Proper handling of non-existent tag bindings + - No errors when trying to delete non-existent bindings + - Correct database state after operation + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create tag and dataset without binding + tag = self._create_test_tags( + db_session_with_containers, mock_external_service_dependencies, tenant.id, "app", 1 + )[0] + app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, tenant.id) + + # Act: Try to delete non-existent binding + delete_args = {"type": "app", "target_id": app.id, "tag_id": tag.id} + TagService.delete_tag_binding(delete_args) + + # Assert: Verify the expected outcomes + # No error should be raised, and database state should remain unchanged + from extensions.ext_database import db + + bindings = db.session.query(TagBinding).where(TagBinding.tag_id == tag.id, TagBinding.target_id == app.id).all() + assert len(bindings) == 0 + + def test_check_target_exists_knowledge_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test successful target existence check for knowledge type. + + This test verifies: + - Proper validation of knowledge dataset existence + - Correct error handling for non-existent datasets + - Proper tenant filtering + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create dataset + dataset = self._create_test_dataset(db_session_with_containers, mock_external_service_dependencies, tenant.id) + + # Act: Execute the method under test + TagService.check_target_exists("knowledge", dataset.id) + + # Assert: Verify the expected outcomes + # No exception should be raised for existing dataset + + def test_check_target_exists_knowledge_not_found( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test target existence check for non-existent knowledge dataset. + + This test verifies: + - Proper error handling for non-existent knowledge datasets + - Correct exception type and message + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create non-existent dataset ID + import uuid + + non_existent_dataset_id = str(uuid.uuid4()) + + # Act & Assert: Verify proper error handling + with pytest.raises(NotFound) as exc_info: + TagService.check_target_exists("knowledge", non_existent_dataset_id) + assert "Dataset not found" in str(exc_info.value) + + def test_check_target_exists_app_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful target existence check for app type. + + This test verifies: + - Proper validation of app existence + - Correct error handling for non-existent apps + - Proper tenant filtering + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create app + app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, tenant.id) + + # Act: Execute the method under test + TagService.check_target_exists("app", app.id) + + # Assert: Verify the expected outcomes + # No exception should be raised for existing app + + def test_check_target_exists_app_not_found(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test target existence check for non-existent app. + + This test verifies: + - Proper error handling for non-existent apps + - Correct exception type and message + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create non-existent app ID + import uuid + + non_existent_app_id = str(uuid.uuid4()) + + # Act & Assert: Verify proper error handling + with pytest.raises(NotFound) as exc_info: + TagService.check_target_exists("app", non_existent_app_id) + assert "App not found" in str(exc_info.value) + + def test_check_target_exists_invalid_type(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test target existence check for invalid type. + + This test verifies: + - Proper error handling for invalid target types + - Correct exception type and message + """ + # Arrange: Create test data + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Create non-existent target ID + import uuid + + non_existent_target_id = str(uuid.uuid4()) + + # Act & Assert: Verify proper error handling + with pytest.raises(NotFound) as exc_info: + TagService.check_target_exists("invalid_type", non_existent_target_id) + assert "Invalid binding type" in str(exc_info.value) diff --git a/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py b/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py new file mode 100644 index 0000000000..6d6f1dab72 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py @@ -0,0 +1,574 @@ +from unittest.mock import patch + +import pytest +from faker import Faker + +from core.app.entities.app_invoke_entities import InvokeFrom +from models.account import Account +from models.model import Conversation, EndUser +from models.web import PinnedConversation +from services.account_service import AccountService, TenantService +from services.app_service import AppService +from services.web_conversation_service import WebConversationService + + +class TestWebConversationService: + """Integration tests for WebConversationService using testcontainers.""" + + @pytest.fixture + def mock_external_service_dependencies(self): + """Mock setup for external service dependencies.""" + with ( + patch("services.app_service.FeatureService") as mock_feature_service, + patch("services.app_service.EnterpriseService") as mock_enterprise_service, + patch("services.app_service.ModelManager") as mock_model_manager, + patch("services.account_service.FeatureService") as mock_account_feature_service, + ): + # Setup default mock returns for app service + mock_feature_service.get_system_features.return_value.webapp_auth.enabled = False + mock_enterprise_service.WebAppAuth.update_app_access_mode.return_value = None + mock_enterprise_service.WebAppAuth.cleanup_webapp.return_value = None + + # Setup default mock returns for account service + mock_account_feature_service.get_system_features.return_value.is_allow_register = True + + # Mock ModelManager for model configuration + mock_model_instance = mock_model_manager.return_value + mock_model_instance.get_default_model_instance.return_value = None + mock_model_instance.get_default_provider_model_name.return_value = ("openai", "gpt-3.5-turbo") + + yield { + "feature_service": mock_feature_service, + "enterprise_service": mock_enterprise_service, + "model_manager": mock_model_manager, + "account_feature_service": mock_account_feature_service, + } + + def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies): + """ + Helper method to create a test app and account for testing. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + mock_external_service_dependencies: Mock dependencies + + Returns: + tuple: (app, account) - Created app and account instances + """ + fake = Faker() + + # Setup mocks for account creation + mock_external_service_dependencies[ + "account_feature_service" + ].get_system_features.return_value.is_allow_register = True + + # Create account and tenant + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Create app with realistic data + app_args = { + "name": fake.company(), + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🤖", + "icon_background": "#FF6B6B", + "api_rph": 100, + "api_rpm": 10, + } + + app_service = AppService() + app = app_service.create_app(tenant.id, app_args, account) + + return app, account + + def _create_test_end_user(self, db_session_with_containers, app): + """ + Helper method to create a test end user for testing. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + app: App instance + + Returns: + EndUser: Created end user instance + """ + fake = Faker() + + end_user = EndUser( + session_id=fake.uuid4(), + app_id=app.id, + type="normal", + is_anonymous=False, + tenant_id=app.tenant_id, + ) + + from extensions.ext_database import db + + db.session.add(end_user) + db.session.commit() + + return end_user + + def _create_test_conversation(self, db_session_with_containers, app, user, fake): + """ + Helper method to create a test conversation for testing. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + app: App instance + user: User instance (Account or EndUser) + fake: Faker instance + + Returns: + Conversation: Created conversation instance + """ + conversation = Conversation( + app_id=app.id, + app_model_config_id=app.app_model_config_id, + model_provider="openai", + model_id="gpt-3.5-turbo", + mode="chat", + name=fake.sentence(nb_words=3), + summary=fake.text(max_nb_chars=100), + inputs={}, + introduction=fake.text(max_nb_chars=200), + system_instruction=fake.text(max_nb_chars=300), + system_instruction_tokens=50, + status="normal", + invoke_from=InvokeFrom.WEB_APP.value, + from_source="console" if isinstance(user, Account) else "api", + from_end_user_id=user.id if isinstance(user, EndUser) else None, + from_account_id=user.id if isinstance(user, Account) else None, + dialogue_count=0, + is_deleted=False, + ) + + from extensions.ext_database import db + + db.session.add(conversation) + db.session.commit() + + return conversation + + def test_pagination_by_last_id_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful pagination by last ID with basic parameters. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create multiple conversations + conversations = [] + for i in range(5): + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + conversations.append(conversation) + + # Test pagination without pinned filter + result = WebConversationService.pagination_by_last_id( + session=db_session_with_containers, + app_model=app, + user=account, + last_id=None, + limit=3, + invoke_from=InvokeFrom.WEB_APP, + pinned=None, + sort_by="-updated_at", + ) + + # Verify results + assert result.limit == 3 + assert len(result.data) == 3 + assert result.has_more is True + + # Verify conversations are in descending order by updated_at + assert result.data[0].updated_at >= result.data[1].updated_at + assert result.data[1].updated_at >= result.data[2].updated_at + + def test_pagination_by_last_id_with_pinned_filter( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test pagination by last ID with pinned conversation filter. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create conversations + conversations = [] + for i in range(5): + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + conversations.append(conversation) + + # Pin some conversations + pinned_conversation1 = PinnedConversation( + app_id=app.id, + conversation_id=conversations[0].id, + created_by_role="account", + created_by=account.id, + ) + pinned_conversation2 = PinnedConversation( + app_id=app.id, + conversation_id=conversations[2].id, + created_by_role="account", + created_by=account.id, + ) + + from extensions.ext_database import db + + db.session.add(pinned_conversation1) + db.session.add(pinned_conversation2) + db.session.commit() + + # Test pagination with pinned filter + result = WebConversationService.pagination_by_last_id( + session=db_session_with_containers, + app_model=app, + user=account, + last_id=None, + limit=10, + invoke_from=InvokeFrom.WEB_APP, + pinned=True, + sort_by="-updated_at", + ) + + # Verify only pinned conversations are returned + assert result.limit == 10 + assert len(result.data) == 2 + assert result.has_more is False + + # Verify the returned conversations are the pinned ones + returned_ids = [conv.id for conv in result.data] + expected_ids = [conversations[0].id, conversations[2].id] + assert set(returned_ids) == set(expected_ids) + + def test_pagination_by_last_id_with_unpinned_filter( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test pagination by last ID with unpinned conversation filter. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create conversations + conversations = [] + for i in range(5): + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + conversations.append(conversation) + + # Pin one conversation + pinned_conversation = PinnedConversation( + app_id=app.id, + conversation_id=conversations[0].id, + created_by_role="account", + created_by=account.id, + ) + + from extensions.ext_database import db + + db.session.add(pinned_conversation) + db.session.commit() + + # Test pagination with unpinned filter + result = WebConversationService.pagination_by_last_id( + session=db_session_with_containers, + app_model=app, + user=account, + last_id=None, + limit=10, + invoke_from=InvokeFrom.WEB_APP, + pinned=False, + sort_by="-updated_at", + ) + + # Verify unpinned conversations are returned (should be 4 out of 5) + assert result.limit == 10 + assert len(result.data) == 4 + assert result.has_more is False + + # Verify the pinned conversation is not in the results + returned_ids = [conv.id for conv in result.data] + assert conversations[0].id not in returned_ids + + # Verify all other conversations are in the results + expected_unpinned_ids = [conv.id for conv in conversations[1:]] + assert set(returned_ids) == set(expected_unpinned_ids) + + def test_pin_conversation_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful pinning of a conversation. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create a conversation + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + + # Pin the conversation + WebConversationService.pin(app, conversation.id, account) + + # Verify the conversation was pinned + from extensions.ext_database import db + + pinned_conversation = ( + db.session.query(PinnedConversation) + .where( + PinnedConversation.app_id == app.id, + PinnedConversation.conversation_id == conversation.id, + PinnedConversation.created_by_role == "account", + PinnedConversation.created_by == account.id, + ) + .first() + ) + + assert pinned_conversation is not None + assert pinned_conversation.app_id == app.id + assert pinned_conversation.conversation_id == conversation.id + assert pinned_conversation.created_by_role == "account" + assert pinned_conversation.created_by == account.id + + def test_pin_conversation_already_pinned(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test pinning a conversation that is already pinned (should not create duplicate). + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create a conversation + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + + # Pin the conversation first time + WebConversationService.pin(app, conversation.id, account) + + # Pin the conversation again + WebConversationService.pin(app, conversation.id, account) + + # Verify only one pinned conversation record exists + from extensions.ext_database import db + + pinned_conversations = ( + db.session.query(PinnedConversation) + .where( + PinnedConversation.app_id == app.id, + PinnedConversation.conversation_id == conversation.id, + PinnedConversation.created_by_role == "account", + PinnedConversation.created_by == account.id, + ) + .all() + ) + + assert len(pinned_conversations) == 1 + + def test_pin_conversation_with_end_user(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test pinning a conversation with an end user. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create an end user + end_user = self._create_test_end_user(db_session_with_containers, app) + + # Create a conversation for the end user + conversation = self._create_test_conversation(db_session_with_containers, app, end_user, fake) + + # Pin the conversation + WebConversationService.pin(app, conversation.id, end_user) + + # Verify the conversation was pinned + from extensions.ext_database import db + + pinned_conversation = ( + db.session.query(PinnedConversation) + .where( + PinnedConversation.app_id == app.id, + PinnedConversation.conversation_id == conversation.id, + PinnedConversation.created_by_role == "end_user", + PinnedConversation.created_by == end_user.id, + ) + .first() + ) + + assert pinned_conversation is not None + assert pinned_conversation.app_id == app.id + assert pinned_conversation.conversation_id == conversation.id + assert pinned_conversation.created_by_role == "end_user" + assert pinned_conversation.created_by == end_user.id + + def test_unpin_conversation_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful unpinning of a conversation. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create a conversation + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + + # Pin the conversation first + WebConversationService.pin(app, conversation.id, account) + + # Verify it was pinned + from extensions.ext_database import db + + pinned_conversation = ( + db.session.query(PinnedConversation) + .where( + PinnedConversation.app_id == app.id, + PinnedConversation.conversation_id == conversation.id, + PinnedConversation.created_by_role == "account", + PinnedConversation.created_by == account.id, + ) + .first() + ) + + assert pinned_conversation is not None + + # Unpin the conversation + WebConversationService.unpin(app, conversation.id, account) + + # Verify it was unpinned + pinned_conversation = ( + db.session.query(PinnedConversation) + .where( + PinnedConversation.app_id == app.id, + PinnedConversation.conversation_id == conversation.id, + PinnedConversation.created_by_role == "account", + PinnedConversation.created_by == account.id, + ) + .first() + ) + + assert pinned_conversation is None + + def test_unpin_conversation_not_pinned(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test unpinning a conversation that is not pinned (should not cause error). + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create a conversation + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + + # Try to unpin a conversation that was never pinned + WebConversationService.unpin(app, conversation.id, account) + + # Verify no pinned conversation record exists + from extensions.ext_database import db + + pinned_conversation = ( + db.session.query(PinnedConversation) + .where( + PinnedConversation.app_id == app.id, + PinnedConversation.conversation_id == conversation.id, + PinnedConversation.created_by_role == "account", + PinnedConversation.created_by == account.id, + ) + .first() + ) + + assert pinned_conversation is None + + def test_pagination_by_last_id_user_required_error( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test that pagination_by_last_id raises ValueError when user is None. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Test with None user + with pytest.raises(ValueError, match="User is required"): + WebConversationService.pagination_by_last_id( + session=db_session_with_containers, + app_model=app, + user=None, + last_id=None, + limit=10, + invoke_from=InvokeFrom.WEB_APP, + pinned=None, + sort_by="-updated_at", + ) + + def test_pin_conversation_user_none(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test that pin method returns early when user is None. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create a conversation + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + + # Try to pin with None user + WebConversationService.pin(app, conversation.id, None) + + # Verify no pinned conversation was created + from extensions.ext_database import db + + pinned_conversation = ( + db.session.query(PinnedConversation) + .where( + PinnedConversation.app_id == app.id, + PinnedConversation.conversation_id == conversation.id, + ) + .first() + ) + + assert pinned_conversation is None + + def test_unpin_conversation_user_none(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test that unpin method returns early when user is None. + """ + fake = Faker() + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + # Create a conversation + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + + # Pin the conversation first + WebConversationService.pin(app, conversation.id, account) + + # Verify it was pinned + from extensions.ext_database import db + + pinned_conversation = ( + db.session.query(PinnedConversation) + .where( + PinnedConversation.app_id == app.id, + PinnedConversation.conversation_id == conversation.id, + PinnedConversation.created_by_role == "account", + PinnedConversation.created_by == account.id, + ) + .first() + ) + + assert pinned_conversation is not None + + # Try to unpin with None user + WebConversationService.unpin(app, conversation.id, None) + + # Verify the conversation is still pinned + pinned_conversation = ( + db.session.query(PinnedConversation) + .where( + PinnedConversation.app_id == app.id, + PinnedConversation.conversation_id == conversation.id, + PinnedConversation.created_by_role == "account", + PinnedConversation.created_by == account.id, + ) + .first() + ) + + assert pinned_conversation is not None diff --git a/api/tests/test_containers_integration_tests/services/test_webapp_auth_service.py b/api/tests/test_containers_integration_tests/services/test_webapp_auth_service.py new file mode 100644 index 0000000000..666b083ba6 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_webapp_auth_service.py @@ -0,0 +1,877 @@ +from unittest.mock import patch + +import pytest +from faker import Faker +from werkzeug.exceptions import NotFound, Unauthorized + +from libs.password import hash_password +from models.account import Account, AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole +from models.model import App, Site +from services.errors.account import AccountLoginError, AccountNotFoundError, AccountPasswordError +from services.webapp_auth_service import WebAppAuthService, WebAppAuthType + + +class TestWebAppAuthService: + """Integration tests for WebAppAuthService using testcontainers.""" + + @pytest.fixture + def mock_external_service_dependencies(self): + """Mock setup for external service dependencies.""" + with ( + patch("services.webapp_auth_service.PassportService") as mock_passport_service, + patch("services.webapp_auth_service.TokenManager") as mock_token_manager, + patch("services.webapp_auth_service.send_email_code_login_mail_task") as mock_mail_task, + patch("services.webapp_auth_service.AppService") as mock_app_service, + patch("services.webapp_auth_service.EnterpriseService") as mock_enterprise_service, + ): + # Setup default mock returns + mock_passport_service.return_value.issue.return_value = "mock_jwt_token" + mock_token_manager.generate_token.return_value = "mock_token" + mock_token_manager.get_token_data.return_value = {"code": "123456"} + mock_mail_task.delay.return_value = None + mock_app_service.get_app_id_by_code.return_value = "mock_app_id" + mock_enterprise_service.WebAppAuth.get_app_access_mode_by_id.return_value = type( + "MockWebAppAuth", (), {"access_mode": "private"} + )() + mock_enterprise_service.WebAppAuth.get_app_access_mode_by_code.return_value = type( + "MockWebAppAuth", (), {"access_mode": "private"} + )() + + yield { + "passport_service": mock_passport_service, + "token_manager": mock_token_manager, + "mail_task": mock_mail_task, + "app_service": mock_app_service, + "enterprise_service": mock_enterprise_service, + } + + def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies): + """ + Helper method to create a test account and tenant for testing. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + mock_external_service_dependencies: Mock dependencies + + Returns: + tuple: (account, tenant) - Created account and tenant instances + """ + fake = Faker() + + # Create account + account = Account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + status="active", + ) + + from extensions.ext_database import db + + db.session.add(account) + db.session.commit() + + # Create tenant for the account + tenant = Tenant( + name=fake.company(), + status="normal", + ) + db.session.add(tenant) + db.session.commit() + + # Create tenant-account join + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=TenantAccountRole.OWNER.value, + current=True, + ) + db.session.add(join) + db.session.commit() + + # Set current tenant for account + account.current_tenant = tenant + + return account, tenant + + def _create_test_account_with_password(self, db_session_with_containers, mock_external_service_dependencies): + """ + Helper method to create a test account with password for testing. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + mock_external_service_dependencies: Mock dependencies + + Returns: + tuple: (account, tenant, password) - Created account, tenant and password + """ + fake = Faker() + password = fake.password(length=12) + + # Create account with password + account = Account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + status="active", + ) + + # Hash password + salt = b"test_salt_16_bytes" + password_hash = hash_password(password, salt) + + # Convert to base64 for storage + import base64 + + account.password = base64.b64encode(password_hash).decode() + account.password_salt = base64.b64encode(salt).decode() + + from extensions.ext_database import db + + db.session.add(account) + db.session.commit() + + # Create tenant for the account + tenant = Tenant( + name=fake.company(), + status="normal", + ) + db.session.add(tenant) + db.session.commit() + + # Create tenant-account join + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=TenantAccountRole.OWNER.value, + current=True, + ) + db.session.add(join) + db.session.commit() + + # Set current tenant for account + account.current_tenant = tenant + + return account, tenant, password + + def _create_test_app_and_site(self, db_session_with_containers, mock_external_service_dependencies, tenant): + """ + Helper method to create a test app and site for testing. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + mock_external_service_dependencies: Mock dependencies + tenant: Tenant instance to associate with + + Returns: + tuple: (app, site) - Created app and site instances + """ + fake = Faker() + + # Create app + app = App( + tenant_id=tenant.id, + name=fake.company(), + description=fake.text(max_nb_chars=100), + mode="chat", + icon_type="emoji", + icon="🤖", + icon_background="#FF6B6B", + api_rph=100, + api_rpm=10, + enable_site=True, + enable_api=True, + ) + + from extensions.ext_database import db + + db.session.add(app) + db.session.commit() + + # Create site + site = Site( + app_id=app.id, + title=fake.company(), + code=fake.unique.lexify(text="??????"), + description=fake.text(max_nb_chars=100), + default_language="en-US", + status="normal", + customize_token_strategy="not_allow", + ) + db.session.add(site) + db.session.commit() + + return app, site + + def test_authenticate_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful authentication with valid email and password. + + This test verifies: + - Proper authentication with valid credentials + - Correct account return + - Database state consistency + """ + # Arrange: Create test data + account, tenant, password = self._create_test_account_with_password( + db_session_with_containers, mock_external_service_dependencies + ) + + # Act: Execute authentication + result = WebAppAuthService.authenticate(account.email, password) + + # Assert: Verify successful authentication + assert result is not None + assert result.id == account.id + assert result.email == account.email + assert result.name == account.name + assert result.status == AccountStatus.ACTIVE.value + + # Verify database state + from extensions.ext_database import db + + db.session.refresh(result) + assert result.id is not None + assert result.password is not None + assert result.password_salt is not None + + def test_authenticate_account_not_found(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test authentication with non-existent email. + + This test verifies: + - Proper error handling for non-existent accounts + - Correct exception type and message + """ + # Arrange: Use non-existent email + fake = Faker() + non_existent_email = fake.email() + + # Act & Assert: Verify proper error handling + with pytest.raises(AccountNotFoundError): + WebAppAuthService.authenticate(non_existent_email, "any_password") + + def test_authenticate_account_banned(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test authentication with banned account. + + This test verifies: + - Proper error handling for banned accounts + - Correct exception type and message + """ + # Arrange: Create banned account + fake = Faker() + password = fake.password(length=12) + + account = Account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + status=AccountStatus.BANNED.value, + ) + + # Hash password + salt = b"test_salt_16_bytes" + password_hash = hash_password(password, salt) + + # Convert to base64 for storage + import base64 + + account.password = base64.b64encode(password_hash).decode() + account.password_salt = base64.b64encode(salt).decode() + + from extensions.ext_database import db + + db.session.add(account) + db.session.commit() + + # Act & Assert: Verify proper error handling + with pytest.raises(AccountLoginError) as exc_info: + WebAppAuthService.authenticate(account.email, password) + + assert "Account is banned." in str(exc_info.value) + + def test_authenticate_invalid_password(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test authentication with invalid password. + + This test verifies: + - Proper error handling for invalid passwords + - Correct exception type and message + """ + # Arrange: Create account with password + account, tenant, correct_password = self._create_test_account_with_password( + db_session_with_containers, mock_external_service_dependencies + ) + + # Act & Assert: Verify proper error handling with wrong password + with pytest.raises(AccountPasswordError) as exc_info: + WebAppAuthService.authenticate(account.email, "wrong_password") + + assert "Invalid email or password." in str(exc_info.value) + + def test_authenticate_account_without_password( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test authentication for account without password. + + This test verifies: + - Proper error handling for accounts without password + - Correct exception type and message + """ + # Arrange: Create account without password + fake = Faker() + + account = Account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + status="active", + ) + + from extensions.ext_database import db + + db.session.add(account) + db.session.commit() + + # Act & Assert: Verify proper error handling + with pytest.raises(AccountPasswordError) as exc_info: + WebAppAuthService.authenticate(account.email, "any_password") + + assert "Invalid email or password." in str(exc_info.value) + + def test_login_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful login and JWT token generation. + + This test verifies: + - Proper JWT token generation + - Correct token format and content + - Mock service integration + """ + # Arrange: Create test account + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Act: Execute login + result = WebAppAuthService.login(account) + + # Assert: Verify successful login + assert result is not None + assert result == "mock_jwt_token" + + # Verify mock service was called correctly + mock_external_service_dependencies["passport_service"].return_value.issue.assert_called_once() + call_args = mock_external_service_dependencies["passport_service"].return_value.issue.call_args[0][0] + + assert call_args["sub"] == "Web API Passport" + assert call_args["user_id"] == account.id + assert call_args["session_id"] == account.email + assert call_args["token_source"] == "webapp_login_token" + assert call_args["auth_type"] == "internal" + assert "exp" in call_args + + def test_get_user_through_email_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful user retrieval through email. + + This test verifies: + - Proper user retrieval by email + - Correct account return + - Database state consistency + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Act: Execute user retrieval + result = WebAppAuthService.get_user_through_email(account.email) + + # Assert: Verify successful retrieval + assert result is not None + assert result.id == account.id + assert result.email == account.email + assert result.name == account.name + assert result.status == AccountStatus.ACTIVE.value + + # Verify database state + from extensions.ext_database import db + + db.session.refresh(result) + assert result.id is not None + + def test_get_user_through_email_not_found(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test user retrieval with non-existent email. + + This test verifies: + - Proper handling for non-existent users + - Correct return value (None) + """ + # Arrange: Use non-existent email + fake = Faker() + non_existent_email = fake.email() + + # Act: Execute user retrieval + result = WebAppAuthService.get_user_through_email(non_existent_email) + + # Assert: Verify proper handling + assert result is None + + def test_get_user_through_email_banned(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test user retrieval with banned account. + + This test verifies: + - Proper error handling for banned accounts + - Correct exception type and message + """ + # Arrange: Create banned account + fake = Faker() + + account = Account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + status=AccountStatus.BANNED.value, + ) + + from extensions.ext_database import db + + db.session.add(account) + db.session.commit() + + # Act & Assert: Verify proper error handling + with pytest.raises(Unauthorized) as exc_info: + WebAppAuthService.get_user_through_email(account.email) + + assert "Account is banned." in str(exc_info.value) + + def test_send_email_code_login_email_with_account( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test sending email code login email with account. + + This test verifies: + - Proper email code generation + - Token generation with correct data + - Mail task scheduling + - Mock service integration + """ + # Arrange: Create test account + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Act: Execute email code login email sending + result = WebAppAuthService.send_email_code_login_email(account=account, language="en-US") + + # Assert: Verify successful email sending + assert result is not None + assert result == "mock_token" + + # Verify mock services were called correctly + mock_external_service_dependencies["token_manager"].generate_token.assert_called_once() + mock_external_service_dependencies["mail_task"].delay.assert_called_once() + + # Verify token generation parameters + token_call_args = mock_external_service_dependencies["token_manager"].generate_token.call_args + assert token_call_args[1]["account"] == account + assert token_call_args[1]["email"] == account.email + assert token_call_args[1]["token_type"] == "email_code_login" + assert "code" in token_call_args[1]["additional_data"] + + # Verify mail task parameters + mail_call_args = mock_external_service_dependencies["mail_task"].delay.call_args + assert mail_call_args[1]["language"] == "en-US" + assert mail_call_args[1]["to"] == account.email + assert "code" in mail_call_args[1] + + def test_send_email_code_login_email_with_email_only( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test sending email code login email with email only. + + This test verifies: + - Proper email code generation without account + - Token generation with email only + - Mail task scheduling + - Mock service integration + """ + # Arrange: Use test email + fake = Faker() + test_email = fake.email() + + # Act: Execute email code login email sending + result = WebAppAuthService.send_email_code_login_email(email=test_email, language="zh-Hans") + + # Assert: Verify successful email sending + assert result is not None + assert result == "mock_token" + + # Verify mock services were called correctly + mock_external_service_dependencies["token_manager"].generate_token.assert_called_once() + mock_external_service_dependencies["mail_task"].delay.assert_called_once() + + # Verify token generation parameters + token_call_args = mock_external_service_dependencies["token_manager"].generate_token.call_args + assert token_call_args[1]["account"] is None + assert token_call_args[1]["email"] == test_email + assert token_call_args[1]["token_type"] == "email_code_login" + assert "code" in token_call_args[1]["additional_data"] + + # Verify mail task parameters + mail_call_args = mock_external_service_dependencies["mail_task"].delay.call_args + assert mail_call_args[1]["language"] == "zh-Hans" + assert mail_call_args[1]["to"] == test_email + assert "code" in mail_call_args[1] + + def test_send_email_code_login_email_no_email_provided( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test sending email code login email without providing email. + + This test verifies: + - Proper error handling when no email is provided + - Correct exception type and message + """ + # Arrange: No email provided + + # Act & Assert: Verify proper error handling + with pytest.raises(ValueError) as exc_info: + WebAppAuthService.send_email_code_login_email() + + assert "Email must be provided." in str(exc_info.value) + + def test_get_email_code_login_data_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful retrieval of email code login data. + + This test verifies: + - Proper token data retrieval + - Correct data format + - Mock service integration + """ + # Arrange: Setup mock return + expected_data = {"code": "123456", "email": "test@example.com"} + mock_external_service_dependencies["token_manager"].get_token_data.return_value = expected_data + + # Act: Execute data retrieval + result = WebAppAuthService.get_email_code_login_data("mock_token") + + # Assert: Verify successful retrieval + assert result is not None + assert result == expected_data + assert result["code"] == "123456" + assert result["email"] == "test@example.com" + + # Verify mock service was called correctly + mock_external_service_dependencies["token_manager"].get_token_data.assert_called_once_with( + "mock_token", "email_code_login" + ) + + def test_get_email_code_login_data_no_data(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test email code login data retrieval when no data exists. + + This test verifies: + - Proper handling when no token data exists + - Correct return value (None) + - Mock service integration + """ + # Arrange: Setup mock return for no data + mock_external_service_dependencies["token_manager"].get_token_data.return_value = None + + # Act: Execute data retrieval + result = WebAppAuthService.get_email_code_login_data("invalid_token") + + # Assert: Verify proper handling + assert result is None + + # Verify mock service was called correctly + mock_external_service_dependencies["token_manager"].get_token_data.assert_called_once_with( + "invalid_token", "email_code_login" + ) + + def test_revoke_email_code_login_token_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test successful revocation of email code login token. + + This test verifies: + - Proper token revocation + - Mock service integration + """ + # Arrange: Setup mock + + # Act: Execute token revocation + WebAppAuthService.revoke_email_code_login_token("mock_token") + + # Assert: Verify mock service was called correctly + mock_external_service_dependencies["token_manager"].revoke_token.assert_called_once_with( + "mock_token", "email_code_login" + ) + + def test_create_end_user_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful end user creation. + + This test verifies: + - Proper end user creation with valid app code + - Correct database state after creation + - Proper relationship establishment + - Mock service integration + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + app, site = self._create_test_app_and_site( + db_session_with_containers, mock_external_service_dependencies, tenant + ) + + # Act: Execute end user creation + result = WebAppAuthService.create_end_user(site.code, "test@example.com") + + # Assert: Verify successful creation + assert result is not None + assert result.tenant_id == app.tenant_id + assert result.app_id == app.id + assert result.type == "browser" + assert result.is_anonymous is False + assert result.session_id == "test@example.com" + assert result.name == "enterpriseuser" + assert result.external_user_id == "enterpriseuser" + + # Verify database state + from extensions.ext_database import db + + db.session.refresh(result) + assert result.id is not None + assert result.created_at is not None + assert result.updated_at is not None + + def test_create_end_user_site_not_found(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test end user creation with non-existent site code. + + This test verifies: + - Proper error handling for non-existent sites + - Correct exception type and message + """ + # Arrange: Use non-existent site code + fake = Faker() + non_existent_code = fake.unique.lexify(text="??????") + + # Act & Assert: Verify proper error handling + with pytest.raises(NotFound) as exc_info: + WebAppAuthService.create_end_user(non_existent_code, "test@example.com") + + assert "Site not found." in str(exc_info.value) + + def test_create_end_user_app_not_found(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test end user creation when app is not found. + + This test verifies: + - Proper error handling when app is missing + - Correct exception type and message + """ + # Arrange: Create site without app + fake = Faker() + tenant = Tenant( + name=fake.company(), + status="normal", + ) + + from extensions.ext_database import db + + db.session.add(tenant) + db.session.commit() + + site = Site( + app_id="00000000-0000-0000-0000-000000000000", + title=fake.company(), + code=fake.unique.lexify(text="??????"), + description=fake.text(max_nb_chars=100), + default_language="en-US", + status="normal", + customize_token_strategy="not_allow", + ) + db.session.add(site) + db.session.commit() + + # Act & Assert: Verify proper error handling + with pytest.raises(NotFound) as exc_info: + WebAppAuthService.create_end_user(site.code, "test@example.com") + + assert "App not found." in str(exc_info.value) + + def test_is_app_require_permission_check_with_access_mode_private( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test permission check requirement for private access mode. + + This test verifies: + - Proper permission check requirement for private mode + - Correct return value + - Mock service integration + """ + # Arrange: Setup test with private access mode + + # Act: Execute permission check requirement test + result = WebAppAuthService.is_app_require_permission_check(access_mode="private") + + # Assert: Verify correct result + assert result is True + + def test_is_app_require_permission_check_with_access_mode_public( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test permission check requirement for public access mode. + + This test verifies: + - Proper permission check requirement for public mode + - Correct return value + - Mock service integration + """ + # Arrange: Setup test with public access mode + + # Act: Execute permission check requirement test + result = WebAppAuthService.is_app_require_permission_check(access_mode="public") + + # Assert: Verify correct result + assert result is False + + def test_is_app_require_permission_check_with_app_code( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test permission check requirement using app code. + + This test verifies: + - Proper permission check requirement using app code + - Correct return value + - Mock service integration + """ + # Arrange: Setup mock for app service + mock_external_service_dependencies["app_service"].get_app_id_by_code.return_value = "mock_app_id" + + # Act: Execute permission check requirement test + result = WebAppAuthService.is_app_require_permission_check(app_code="mock_app_code") + + # Assert: Verify correct result + assert result is True + + # Verify mock service was called correctly + mock_external_service_dependencies["app_service"].get_app_id_by_code.assert_called_once_with("mock_app_code") + mock_external_service_dependencies[ + "enterprise_service" + ].WebAppAuth.get_app_access_mode_by_id.assert_called_once_with("mock_app_id") + + def test_is_app_require_permission_check_no_parameters( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test permission check requirement with no parameters. + + This test verifies: + - Proper error handling when no parameters provided + - Correct exception type and message + """ + # Arrange: No parameters provided + + # Act & Assert: Verify proper error handling + with pytest.raises(ValueError) as exc_info: + WebAppAuthService.is_app_require_permission_check() + + assert "Either app_code or app_id must be provided." in str(exc_info.value) + + def test_get_app_auth_type_with_access_mode_public( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test app authentication type for public access mode. + + This test verifies: + - Proper authentication type determination for public mode + - Correct return value + - Mock service integration + """ + # Arrange: Setup test with public access mode + + # Act: Execute authentication type determination + result = WebAppAuthService.get_app_auth_type(access_mode="public") + + # Assert: Verify correct result + assert result == WebAppAuthType.PUBLIC + + def test_get_app_auth_type_with_access_mode_private( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test app authentication type for private access mode. + + This test verifies: + - Proper authentication type determination for private mode + - Correct return value + - Mock service integration + """ + # Arrange: Setup test with private access mode + + # Act: Execute authentication type determination + result = WebAppAuthService.get_app_auth_type(access_mode="private") + + # Assert: Verify correct result + assert result == WebAppAuthType.INTERNAL + + def test_get_app_auth_type_with_app_code(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test app authentication type using app code. + + This test verifies: + - Proper authentication type determination using app code + - Correct return value + - Mock service integration + """ + # Arrange: Setup mock for enterprise service + mock_webapp_auth = type("MockWebAppAuth", (), {"access_mode": "sso_verified"})() + mock_external_service_dependencies[ + "enterprise_service" + ].WebAppAuth.get_app_access_mode_by_code.return_value = mock_webapp_auth + + # Act: Execute authentication type determination + result = WebAppAuthService.get_app_auth_type(app_code="mock_app_code") + + # Assert: Verify correct result + assert result == WebAppAuthType.EXTERNAL + + # Verify mock service was called correctly + mock_external_service_dependencies[ + "enterprise_service" + ].WebAppAuth.get_app_access_mode_by_code.assert_called_once_with("mock_app_code") + + def test_get_app_auth_type_no_parameters(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test app authentication type with no parameters. + + This test verifies: + - Proper error handling when no parameters provided + - Correct exception type and message + """ + # Arrange: No parameters provided + + # Act & Assert: Verify proper error handling + with pytest.raises(ValueError) as exc_info: + WebAppAuthService.get_app_auth_type() + + assert "Either app_code or access_mode must be provided." in str(exc_info.value) diff --git a/api/tests/test_containers_integration_tests/services/test_website_service.py b/api/tests/test_containers_integration_tests/services/test_website_service.py new file mode 100644 index 0000000000..ec2f1556af --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_website_service.py @@ -0,0 +1,1437 @@ +from datetime import datetime +from unittest.mock import MagicMock, patch + +import pytest +from faker import Faker + +from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole +from services.website_service import ( + CrawlOptions, + ScrapeRequest, + WebsiteCrawlApiRequest, + WebsiteCrawlStatusApiRequest, + WebsiteService, +) + + +class TestWebsiteService: + """Integration tests for WebsiteService using testcontainers.""" + + @pytest.fixture + def mock_external_service_dependencies(self): + """Mock setup for external service dependencies.""" + with ( + patch("services.website_service.ApiKeyAuthService") as mock_api_key_auth_service, + patch("services.website_service.FirecrawlApp") as mock_firecrawl_app, + patch("services.website_service.WaterCrawlProvider") as mock_watercrawl_provider, + patch("services.website_service.requests") as mock_requests, + patch("services.website_service.redis_client") as mock_redis_client, + patch("services.website_service.storage") as mock_storage, + patch("services.website_service.encrypter") as mock_encrypter, + ): + # Setup default mock returns + mock_api_key_auth_service.get_auth_credentials.return_value = { + "config": {"api_key": "encrypted_api_key", "base_url": "https://api.example.com"} + } + mock_encrypter.decrypt_token.return_value = "decrypted_api_key" + + # Mock FirecrawlApp + mock_firecrawl_instance = MagicMock() + mock_firecrawl_instance.crawl_url.return_value = "test_job_id_123" + mock_firecrawl_instance.check_crawl_status.return_value = { + "status": "completed", + "total": 5, + "current": 5, + "data": [{"source_url": "https://example.com", "title": "Test Page"}], + } + mock_firecrawl_app.return_value = mock_firecrawl_instance + + # Mock WaterCrawlProvider + mock_watercrawl_instance = MagicMock() + mock_watercrawl_instance.crawl_url.return_value = {"status": "active", "job_id": "watercrawl_job_123"} + mock_watercrawl_instance.get_crawl_status.return_value = { + "status": "completed", + "job_id": "watercrawl_job_123", + "total": 3, + "current": 3, + "data": [], + } + mock_watercrawl_instance.get_crawl_url_data.return_value = { + "title": "WaterCrawl Page", + "source_url": "https://example.com", + "description": "Test description", + "markdown": "# Test Content", + } + mock_watercrawl_instance.scrape_url.return_value = { + "title": "Scraped Page", + "content": "Test content", + "url": "https://example.com", + } + mock_watercrawl_provider.return_value = mock_watercrawl_instance + + # Mock requests + mock_response = MagicMock() + mock_response.json.return_value = {"code": 200, "data": {"taskId": "jina_job_123"}} + mock_requests.get.return_value = mock_response + mock_requests.post.return_value = mock_response + + # Mock Redis + mock_redis_client.setex.return_value = None + mock_redis_client.get.return_value = str(datetime.now().timestamp()) + mock_redis_client.delete.return_value = None + + # Mock Storage + mock_storage.exists.return_value = False + mock_storage.load_once.return_value = None + + yield { + "api_key_auth_service": mock_api_key_auth_service, + "firecrawl_app": mock_firecrawl_app, + "watercrawl_provider": mock_watercrawl_provider, + "requests": mock_requests, + "redis_client": mock_redis_client, + "storage": mock_storage, + "encrypter": mock_encrypter, + } + + def _create_test_account(self, db_session_with_containers, mock_external_service_dependencies): + """ + Helper method to create a test account with proper tenant setup. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + mock_external_service_dependencies: Mock dependencies + + Returns: + Account: Created account instance + """ + fake = Faker() + + # Create account + account = Account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + status="active", + ) + + from extensions.ext_database import db + + db.session.add(account) + db.session.commit() + + # Create tenant for the account + tenant = Tenant( + name=fake.company(), + status="normal", + ) + db.session.add(tenant) + db.session.commit() + + # Create tenant-account join + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=TenantAccountRole.OWNER.value, + current=True, + ) + db.session.add(join) + db.session.commit() + + # Set current tenant for account + account.current_tenant = tenant + + return account + + def test_document_create_args_validate_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test successful argument validation for document creation. + + This test verifies: + - Valid arguments are accepted without errors + - All required fields are properly validated + - Optional fields are handled correctly + """ + # Arrange: Prepare valid arguments + valid_args = { + "provider": "firecrawl", + "url": "https://example.com", + "options": { + "limit": 5, + "crawl_sub_pages": True, + "only_main_content": False, + "includes": "blog,news", + "excludes": "admin,private", + "max_depth": 3, + "use_sitemap": True, + }, + } + + # Act: Validate arguments + WebsiteService.document_create_args_validate(valid_args) + + # Assert: No exception should be raised + # If we reach here, validation passed successfully + + def test_document_create_args_validate_missing_provider( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test argument validation fails when provider is missing. + + This test verifies: + - Missing provider raises ValueError + - Proper error message is provided + - Validation stops at first missing required field + """ + # Arrange: Prepare arguments without provider + invalid_args = {"url": "https://example.com", "options": {"limit": 5, "crawl_sub_pages": True}} + + # Act & Assert: Verify proper error handling + with pytest.raises(ValueError) as exc_info: + WebsiteService.document_create_args_validate(invalid_args) + + assert "Provider is required" in str(exc_info.value) + + def test_document_create_args_validate_missing_url( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test argument validation fails when URL is missing. + + This test verifies: + - Missing URL raises ValueError + - Proper error message is provided + - Validation continues after provider check + """ + # Arrange: Prepare arguments without URL + invalid_args = {"provider": "firecrawl", "options": {"limit": 5, "crawl_sub_pages": True}} + + # Act & Assert: Verify proper error handling + with pytest.raises(ValueError) as exc_info: + WebsiteService.document_create_args_validate(invalid_args) + + assert "URL is required" in str(exc_info.value) + + def test_crawl_url_firecrawl_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful URL crawling with Firecrawl provider. + + This test verifies: + - Firecrawl provider is properly initialized + - API credentials are retrieved and decrypted + - Crawl parameters are correctly formatted + - Job ID is returned with active status + - Redis cache is properly set + """ + # Arrange: Create test account and prepare request + account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) + fake = Faker() + + # Mock current_user for the test + with patch("services.website_service.current_user") as mock_current_user: + mock_current_user.current_tenant_id = account.current_tenant.id + + # Create API request + api_request = WebsiteCrawlApiRequest( + provider="firecrawl", + url="https://example.com", + options={ + "limit": 10, + "crawl_sub_pages": True, + "only_main_content": True, + "includes": "blog,news", + "excludes": "admin,private", + "max_depth": 2, + "use_sitemap": True, + }, + ) + + # Act: Execute crawl operation + result = WebsiteService.crawl_url(api_request) + + # Assert: Verify successful operation + assert result is not None + assert result["status"] == "active" + assert result["job_id"] == "test_job_id_123" + + # Verify external service interactions + mock_external_service_dependencies["api_key_auth_service"].get_auth_credentials.assert_called_once_with( + account.current_tenant.id, "website", "firecrawl" + ) + mock_external_service_dependencies["encrypter"].decrypt_token.assert_called_once_with( + tenant_id=account.current_tenant.id, token="encrypted_api_key" + ) + mock_external_service_dependencies["firecrawl_app"].assert_called_once_with( + api_key="decrypted_api_key", base_url="https://api.example.com" + ) + + # Verify Redis cache was set + mock_external_service_dependencies["redis_client"].setex.assert_called_once() + + def test_crawl_url_watercrawl_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful URL crawling with WaterCrawl provider. + + This test verifies: + - WaterCrawl provider is properly initialized + - API credentials are retrieved and decrypted + - Crawl options are correctly passed to provider + - Provider returns expected response format + """ + # Arrange: Create test account and prepare request + account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) + + # Mock current_user for the test + with patch("services.website_service.current_user") as mock_current_user: + mock_current_user.current_tenant_id = account.current_tenant.id + + # Create API request + api_request = WebsiteCrawlApiRequest( + provider="watercrawl", + url="https://example.com", + options={ + "limit": 5, + "crawl_sub_pages": False, + "only_main_content": False, + "includes": None, + "excludes": None, + "max_depth": None, + "use_sitemap": False, + }, + ) + + # Act: Execute crawl operation + result = WebsiteService.crawl_url(api_request) + + # Assert: Verify successful operation + assert result is not None + assert result["status"] == "active" + assert result["job_id"] == "watercrawl_job_123" + + # Verify external service interactions + mock_external_service_dependencies["api_key_auth_service"].get_auth_credentials.assert_called_once_with( + account.current_tenant.id, "website", "watercrawl" + ) + mock_external_service_dependencies["encrypter"].decrypt_token.assert_called_once_with( + tenant_id=account.current_tenant.id, token="encrypted_api_key" + ) + mock_external_service_dependencies["watercrawl_provider"].assert_called_once_with( + api_key="decrypted_api_key", base_url="https://api.example.com" + ) + + def test_crawl_url_jinareader_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful URL crawling with JinaReader provider. + + This test verifies: + - JinaReader provider handles single page crawling + - API credentials are retrieved and decrypted + - HTTP requests are made with proper headers + - Response is properly parsed and returned + """ + # Arrange: Create test account and prepare request + account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) + + # Mock current_user for the test + with patch("services.website_service.current_user") as mock_current_user: + mock_current_user.current_tenant_id = account.current_tenant.id + + # Create API request for single page crawling + api_request = WebsiteCrawlApiRequest( + provider="jinareader", + url="https://example.com", + options={ + "limit": 1, + "crawl_sub_pages": False, + "only_main_content": True, + "includes": None, + "excludes": None, + "max_depth": None, + "use_sitemap": False, + }, + ) + + # Act: Execute crawl operation + result = WebsiteService.crawl_url(api_request) + + # Assert: Verify successful operation + assert result is not None + assert result["status"] == "active" + assert result["data"] is not None + + # Verify external service interactions + mock_external_service_dependencies["api_key_auth_service"].get_auth_credentials.assert_called_once_with( + account.current_tenant.id, "website", "jinareader" + ) + mock_external_service_dependencies["encrypter"].decrypt_token.assert_called_once_with( + tenant_id=account.current_tenant.id, token="encrypted_api_key" + ) + + # Verify HTTP request was made + mock_external_service_dependencies["requests"].get.assert_called_once_with( + "https://r.jina.ai/https://example.com", + headers={"Accept": "application/json", "Authorization": "Bearer decrypted_api_key"}, + ) + + def test_crawl_url_invalid_provider(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test crawl operation fails with invalid provider. + + This test verifies: + - Invalid provider raises ValueError + - Proper error message is provided + - Service handles unsupported providers gracefully + """ + # Arrange: Create test account and prepare request with invalid provider + account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) + + # Mock current_user for the test + with patch("services.website_service.current_user") as mock_current_user: + mock_current_user.current_tenant_id = account.current_tenant.id + + # Create API request with invalid provider + api_request = WebsiteCrawlApiRequest( + provider="invalid_provider", + url="https://example.com", + options={"limit": 5, "crawl_sub_pages": False, "only_main_content": False}, + ) + + # Act & Assert: Verify proper error handling + with pytest.raises(ValueError) as exc_info: + WebsiteService.crawl_url(api_request) + + assert "Invalid provider" in str(exc_info.value) + + def test_get_crawl_status_firecrawl_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful crawl status retrieval with Firecrawl provider. + + This test verifies: + - Firecrawl status is properly retrieved + - API credentials are retrieved and decrypted + - Status data includes all required fields + - Redis cache is properly managed for completed jobs + """ + # Arrange: Create test account and prepare request + account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) + + # Mock current_user for the test + with patch("services.website_service.current_user") as mock_current_user: + mock_current_user.current_tenant_id = account.current_tenant.id + + # Create API request + api_request = WebsiteCrawlStatusApiRequest(provider="firecrawl", job_id="test_job_id_123") + + # Act: Get crawl status + result = WebsiteService.get_crawl_status_typed(api_request) + + # Assert: Verify successful operation + assert result is not None + assert result["status"] == "completed" + assert result["job_id"] == "test_job_id_123" + assert result["total"] == 5 + assert result["current"] == 5 + assert "data" in result + assert "time_consuming" in result + + # Verify external service interactions + mock_external_service_dependencies["api_key_auth_service"].get_auth_credentials.assert_called_once_with( + account.current_tenant.id, "website", "firecrawl" + ) + mock_external_service_dependencies["encrypter"].decrypt_token.assert_called_once_with( + tenant_id=account.current_tenant.id, token="encrypted_api_key" + ) + + # Verify Redis cache was accessed and cleaned up + mock_external_service_dependencies["redis_client"].get.assert_called_once() + mock_external_service_dependencies["redis_client"].delete.assert_called_once() + + def test_get_crawl_status_watercrawl_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful crawl status retrieval with WaterCrawl provider. + + This test verifies: + - WaterCrawl status is properly retrieved + - API credentials are retrieved and decrypted + - Provider returns expected status format + - All required status fields are present + """ + # Arrange: Create test account and prepare request + account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) + + # Mock current_user for the test + with patch("services.website_service.current_user") as mock_current_user: + mock_current_user.current_tenant_id = account.current_tenant.id + + # Create API request + api_request = WebsiteCrawlStatusApiRequest(provider="watercrawl", job_id="watercrawl_job_123") + + # Act: Get crawl status + result = WebsiteService.get_crawl_status_typed(api_request) + + # Assert: Verify successful operation + assert result is not None + assert result["status"] == "completed" + assert result["job_id"] == "watercrawl_job_123" + assert result["total"] == 3 + assert result["current"] == 3 + assert "data" in result + + # Verify external service interactions + mock_external_service_dependencies["api_key_auth_service"].get_auth_credentials.assert_called_once_with( + account.current_tenant.id, "website", "watercrawl" + ) + mock_external_service_dependencies["encrypter"].decrypt_token.assert_called_once_with( + tenant_id=account.current_tenant.id, token="encrypted_api_key" + ) + + def test_get_crawl_status_jinareader_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful crawl status retrieval with JinaReader provider. + + This test verifies: + - JinaReader status is properly retrieved + - API credentials are retrieved and decrypted + - HTTP requests are made with proper parameters + - Status data is properly formatted and returned + """ + # Arrange: Create test account and prepare request + account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) + + # Mock current_user for the test + with patch("services.website_service.current_user") as mock_current_user: + mock_current_user.current_tenant_id = account.current_tenant.id + + # Create API request + api_request = WebsiteCrawlStatusApiRequest(provider="jinareader", job_id="jina_job_123") + + # Act: Get crawl status + result = WebsiteService.get_crawl_status_typed(api_request) + + # Assert: Verify successful operation + assert result is not None + assert result["status"] == "active" + assert result["job_id"] == "jina_job_123" + assert "total" in result + assert "current" in result + assert "data" in result + assert "time_consuming" in result + + # Verify external service interactions + mock_external_service_dependencies["api_key_auth_service"].get_auth_credentials.assert_called_once_with( + account.current_tenant.id, "website", "jinareader" + ) + mock_external_service_dependencies["encrypter"].decrypt_token.assert_called_once_with( + tenant_id=account.current_tenant.id, token="encrypted_api_key" + ) + + # Verify HTTP request was made + mock_external_service_dependencies["requests"].post.assert_called_once() + + def test_get_crawl_status_invalid_provider(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test crawl status retrieval fails with invalid provider. + + This test verifies: + - Invalid provider raises ValueError + - Proper error message is provided + - Service handles unsupported providers gracefully + """ + # Arrange: Create test account and prepare request with invalid provider + account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) + + # Mock current_user for the test + with patch("services.website_service.current_user") as mock_current_user: + mock_current_user.current_tenant_id = account.current_tenant.id + + # Create API request with invalid provider + api_request = WebsiteCrawlStatusApiRequest(provider="invalid_provider", job_id="test_job_id_123") + + # Act & Assert: Verify proper error handling + with pytest.raises(ValueError) as exc_info: + WebsiteService.get_crawl_status_typed(api_request) + + assert "Invalid provider" in str(exc_info.value) + + def test_get_crawl_status_missing_credentials(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test crawl status retrieval fails when credentials are missing. + + This test verifies: + - Missing credentials raises ValueError + - Proper error message is provided + - Service handles authentication failures gracefully + """ + # Arrange: Create test account and prepare request + account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) + + # Mock current_user for the test + with patch("services.website_service.current_user") as mock_current_user: + mock_current_user.current_tenant_id = account.current_tenant.id + + # Mock missing credentials + mock_external_service_dependencies["api_key_auth_service"].get_auth_credentials.return_value = None + + # Create API request + api_request = WebsiteCrawlStatusApiRequest(provider="firecrawl", job_id="test_job_id_123") + + # Act & Assert: Verify proper error handling + with pytest.raises(ValueError) as exc_info: + WebsiteService.get_crawl_status_typed(api_request) + + assert "No valid credentials found for the provider" in str(exc_info.value) + + def test_get_crawl_status_missing_api_key(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test crawl status retrieval fails when API key is missing from config. + + This test verifies: + - Missing API key raises ValueError + - Proper error message is provided + - Service handles configuration failures gracefully + """ + # Arrange: Create test account and prepare request + account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) + + # Mock current_user for the test + with patch("services.website_service.current_user") as mock_current_user: + mock_current_user.current_tenant_id = account.current_tenant.id + + # Mock missing API key in config + mock_external_service_dependencies["api_key_auth_service"].get_auth_credentials.return_value = { + "config": {"base_url": "https://api.example.com"} + } + + # Create API request + api_request = WebsiteCrawlStatusApiRequest(provider="firecrawl", job_id="test_job_id_123") + + # Act & Assert: Verify proper error handling + with pytest.raises(ValueError) as exc_info: + WebsiteService.get_crawl_status_typed(api_request) + + assert "API key not found in configuration" in str(exc_info.value) + + def test_get_crawl_url_data_firecrawl_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful URL data retrieval with Firecrawl provider. + + This test verifies: + - Firecrawl URL data is properly retrieved + - API credentials are retrieved and decrypted + - Data is returned for matching URL + - Storage fallback works when needed + """ + # Arrange: Create test account and prepare request + account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) + + # Mock storage to return existing data + mock_external_service_dependencies["storage"].exists.return_value = True + mock_external_service_dependencies["storage"].load_once.return_value = ( + b"[" + b'{"source_url": "https://example.com", "title": "Test Page", ' + b'"description": "Test Description", "markdown": "# Test Content"}' + b"]" + ) + + # Act: Get URL data + result = WebsiteService.get_crawl_url_data( + job_id="test_job_id_123", + provider="firecrawl", + url="https://example.com", + tenant_id=account.current_tenant.id, + ) + + # Assert: Verify successful operation + assert result is not None + assert result["source_url"] == "https://example.com" + assert result["title"] == "Test Page" + assert result["description"] == "Test Description" + assert result["markdown"] == "# Test Content" + + # Verify external service interactions + mock_external_service_dependencies["api_key_auth_service"].get_auth_credentials.assert_called_once_with( + account.current_tenant.id, "website", "firecrawl" + ) + mock_external_service_dependencies["encrypter"].decrypt_token.assert_called_once_with( + tenant_id=account.current_tenant.id, token="encrypted_api_key" + ) + + # Verify storage was accessed + mock_external_service_dependencies["storage"].exists.assert_called_once() + mock_external_service_dependencies["storage"].load_once.assert_called_once() + + def test_get_crawl_url_data_watercrawl_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test successful URL data retrieval with WaterCrawl provider. + + This test verifies: + - WaterCrawl URL data is properly retrieved + - API credentials are retrieved and decrypted + - Provider returns expected data format + - All required data fields are present + """ + # Arrange: Create test account and prepare request + account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) + + # Act: Get URL data + result = WebsiteService.get_crawl_url_data( + job_id="watercrawl_job_123", + provider="watercrawl", + url="https://example.com", + tenant_id=account.current_tenant.id, + ) + + # Assert: Verify successful operation + assert result is not None + assert result["title"] == "WaterCrawl Page" + assert result["source_url"] == "https://example.com" + assert result["description"] == "Test description" + assert result["markdown"] == "# Test Content" + + # Verify external service interactions + mock_external_service_dependencies["api_key_auth_service"].get_auth_credentials.assert_called_once_with( + account.current_tenant.id, "website", "watercrawl" + ) + mock_external_service_dependencies["encrypter"].decrypt_token.assert_called_once_with( + tenant_id=account.current_tenant.id, token="encrypted_api_key" + ) + + def test_get_crawl_url_data_jinareader_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test successful URL data retrieval with JinaReader provider. + + This test verifies: + - JinaReader URL data is properly retrieved + - API credentials are retrieved and decrypted + - HTTP requests are made with proper parameters + - Data is properly formatted and returned + """ + # Arrange: Create test account and prepare request + account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) + + # Mock successful response for JinaReader + mock_response = MagicMock() + mock_response.json.return_value = { + "code": 200, + "data": { + "title": "JinaReader Page", + "url": "https://example.com", + "description": "Test description", + "content": "# Test Content", + }, + } + mock_external_service_dependencies["requests"].get.return_value = mock_response + + # Act: Get URL data without job_id (single page scraping) + result = WebsiteService.get_crawl_url_data( + job_id="", provider="jinareader", url="https://example.com", tenant_id=account.current_tenant.id + ) + + # Assert: Verify successful operation + assert result is not None + assert result["title"] == "JinaReader Page" + assert result["url"] == "https://example.com" + assert result["description"] == "Test description" + assert result["content"] == "# Test Content" + + # Verify external service interactions + mock_external_service_dependencies["api_key_auth_service"].get_auth_credentials.assert_called_once_with( + account.current_tenant.id, "website", "jinareader" + ) + mock_external_service_dependencies["encrypter"].decrypt_token.assert_called_once_with( + tenant_id=account.current_tenant.id, token="encrypted_api_key" + ) + + # Verify HTTP request was made + mock_external_service_dependencies["requests"].get.assert_called_once_with( + "https://r.jina.ai/https://example.com", + headers={"Accept": "application/json", "Authorization": "Bearer decrypted_api_key"}, + ) + + def test_get_scrape_url_data_firecrawl_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test successful URL scraping with Firecrawl provider. + + This test verifies: + - Firecrawl scraping is properly executed + - API credentials are retrieved and decrypted + - Scraping parameters are correctly passed + - Scraped data is returned in expected format + """ + # Arrange: Create test account and prepare request + account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) + + # Mock FirecrawlApp scraping response + mock_firecrawl_instance = MagicMock() + mock_firecrawl_instance.scrape_url.return_value = { + "title": "Scraped Page Title", + "content": "This is the scraped content", + "url": "https://example.com", + "description": "Page description", + } + mock_external_service_dependencies["firecrawl_app"].return_value = mock_firecrawl_instance + + # Act: Scrape URL + result = WebsiteService.get_scrape_url_data( + provider="firecrawl", url="https://example.com", tenant_id=account.current_tenant.id, only_main_content=True + ) + + # Assert: Verify successful operation + assert result is not None + assert result["title"] == "Scraped Page Title" + assert result["content"] == "This is the scraped content" + assert result["url"] == "https://example.com" + assert result["description"] == "Page description" + + # Verify external service interactions + mock_external_service_dependencies["api_key_auth_service"].get_auth_credentials.assert_called_once_with( + account.current_tenant.id, "website", "firecrawl" + ) + mock_external_service_dependencies["encrypter"].decrypt_token.assert_called_once_with( + tenant_id=account.current_tenant.id, token="encrypted_api_key" + ) + + # Verify FirecrawlApp was called with correct parameters + mock_external_service_dependencies["firecrawl_app"].assert_called_once_with( + api_key="decrypted_api_key", base_url="https://api.example.com" + ) + mock_firecrawl_instance.scrape_url.assert_called_once_with( + url="https://example.com", params={"onlyMainContent": True} + ) + + def test_get_scrape_url_data_watercrawl_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test successful URL scraping with WaterCrawl provider. + + This test verifies: + - WaterCrawl scraping is properly executed + - API credentials are retrieved and decrypted + - Provider returns expected scraping format + - All required data fields are present + """ + # Arrange: Create test account and prepare request + account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) + + # Act: Scrape URL + result = WebsiteService.get_scrape_url_data( + provider="watercrawl", + url="https://example.com", + tenant_id=account.current_tenant.id, + only_main_content=False, + ) + + # Assert: Verify successful operation + assert result is not None + assert result["title"] == "Scraped Page" + assert result["content"] == "Test content" + assert result["url"] == "https://example.com" + + # Verify external service interactions + mock_external_service_dependencies["api_key_auth_service"].get_auth_credentials.assert_called_once_with( + account.current_tenant.id, "website", "watercrawl" + ) + mock_external_service_dependencies["encrypter"].decrypt_token.assert_called_once_with( + tenant_id=account.current_tenant.id, token="encrypted_api_key" + ) + + # Verify WaterCrawlProvider was called with correct parameters + mock_external_service_dependencies["watercrawl_provider"].assert_called_once_with( + api_key="decrypted_api_key", base_url="https://api.example.com" + ) + + def test_get_scrape_url_data_invalid_provider(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test URL scraping fails with invalid provider. + + This test verifies: + - Invalid provider raises ValueError + - Proper error message is provided + - Service handles unsupported providers gracefully + """ + # Arrange: Create test account and prepare request with invalid provider + account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) + + # Act & Assert: Verify proper error handling + with pytest.raises(ValueError) as exc_info: + WebsiteService.get_scrape_url_data( + provider="invalid_provider", + url="https://example.com", + tenant_id=account.current_tenant.id, + only_main_content=False, + ) + + assert "Invalid provider" in str(exc_info.value) + + def test_crawl_options_include_exclude_paths(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test CrawlOptions include and exclude path methods. + + This test verifies: + - Include paths are properly parsed from comma-separated string + - Exclude paths are properly parsed from comma-separated string + - Empty or None values are handled correctly + - Path lists are returned in expected format + """ + # Arrange: Create CrawlOptions with various path configurations + options_with_paths = CrawlOptions(includes="blog,news,articles", excludes="admin,private,test") + + options_without_paths = CrawlOptions(includes=None, excludes="") + + # Act: Get include and exclude paths + include_paths = options_with_paths.get_include_paths() + exclude_paths = options_with_paths.get_exclude_paths() + + empty_include_paths = options_without_paths.get_include_paths() + empty_exclude_paths = options_without_paths.get_exclude_paths() + + # Assert: Verify path parsing + assert include_paths == ["blog", "news", "articles"] + assert exclude_paths == ["admin", "private", "test"] + assert empty_include_paths == [] + assert empty_exclude_paths == [] + + def test_website_crawl_api_request_conversion(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test WebsiteCrawlApiRequest conversion to CrawlRequest. + + This test verifies: + - API request is properly converted to internal CrawlRequest + - All options are correctly mapped + - Default values are applied when options are missing + - Conversion maintains data integrity + """ + # Arrange: Create API request with various options + api_request = WebsiteCrawlApiRequest( + provider="firecrawl", + url="https://example.com", + options={ + "limit": 10, + "crawl_sub_pages": True, + "only_main_content": True, + "includes": "blog,news", + "excludes": "admin,private", + "max_depth": 3, + "use_sitemap": False, + }, + ) + + # Act: Convert to CrawlRequest + crawl_request = api_request.to_crawl_request() + + # Assert: Verify conversion + assert crawl_request.url == "https://example.com" + assert crawl_request.provider == "firecrawl" + assert crawl_request.options.limit == 10 + assert crawl_request.options.crawl_sub_pages is True + assert crawl_request.options.only_main_content is True + assert crawl_request.options.includes == "blog,news" + assert crawl_request.options.excludes == "admin,private" + assert crawl_request.options.max_depth == 3 + assert crawl_request.options.use_sitemap is False + + def test_website_crawl_api_request_from_args(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test WebsiteCrawlApiRequest creation from Flask arguments. + + This test verifies: + - Request is properly created from parsed arguments + - Required fields are validated + - Optional fields are handled correctly + - Validation errors are properly raised + """ + # Arrange: Prepare valid arguments + valid_args = {"provider": "watercrawl", "url": "https://example.com", "options": {"limit": 5}} + + # Act: Create request from args + request = WebsiteCrawlApiRequest.from_args(valid_args) + + # Assert: Verify request creation + assert request.provider == "watercrawl" + assert request.url == "https://example.com" + assert request.options == {"limit": 5} + + # Test missing provider + invalid_args = {"url": "https://example.com", "options": {}} + with pytest.raises(ValueError) as exc_info: + WebsiteCrawlApiRequest.from_args(invalid_args) + assert "Provider is required" in str(exc_info.value) + + # Test missing URL + invalid_args = {"provider": "watercrawl", "options": {}} + with pytest.raises(ValueError) as exc_info: + WebsiteCrawlApiRequest.from_args(invalid_args) + assert "URL is required" in str(exc_info.value) + + # Test missing options + invalid_args = {"provider": "watercrawl", "url": "https://example.com"} + with pytest.raises(ValueError) as exc_info: + WebsiteCrawlApiRequest.from_args(invalid_args) + assert "Options are required" in str(exc_info.value) + + def test_crawl_url_jinareader_sub_pages_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test successful URL crawling with JinaReader provider for sub-pages. + + This test verifies: + - JinaReader provider handles sub-page crawling correctly + - HTTP POST request is made with proper parameters + - Job ID is returned for multi-page crawling + - All required parameters are passed correctly + """ + # Arrange: Create test account and prepare request + account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) + + # Mock current_user for the test + with patch("services.website_service.current_user") as mock_current_user: + mock_current_user.current_tenant_id = account.current_tenant.id + + # Create API request for sub-page crawling + api_request = WebsiteCrawlApiRequest( + provider="jinareader", + url="https://example.com", + options={ + "limit": 5, + "crawl_sub_pages": True, + "only_main_content": False, + "includes": None, + "excludes": None, + "max_depth": None, + "use_sitemap": True, + }, + ) + + # Act: Execute crawl operation + result = WebsiteService.crawl_url(api_request) + + # Assert: Verify successful operation + assert result is not None + assert result["status"] == "active" + assert result["job_id"] == "jina_job_123" + + # Verify external service interactions + mock_external_service_dependencies["api_key_auth_service"].get_auth_credentials.assert_called_once_with( + account.current_tenant.id, "website", "jinareader" + ) + mock_external_service_dependencies["encrypter"].decrypt_token.assert_called_once_with( + tenant_id=account.current_tenant.id, token="encrypted_api_key" + ) + + # Verify HTTP POST request was made for sub-page crawling + mock_external_service_dependencies["requests"].post.assert_called_once_with( + "https://adaptivecrawl-kir3wx7b3a-uc.a.run.app", + json={"url": "https://example.com", "maxPages": 5, "useSitemap": True}, + headers={"Content-Type": "application/json", "Authorization": "Bearer decrypted_api_key"}, + ) + + def test_crawl_url_jinareader_failed_response(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test JinaReader crawling fails when API returns error. + + This test verifies: + - Failed API response raises ValueError + - Proper error message is provided + - Service handles API failures gracefully + """ + # Arrange: Create test account and prepare request + account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) + + # Mock failed response + mock_failed_response = MagicMock() + mock_failed_response.json.return_value = {"code": 500, "error": "Internal server error"} + mock_external_service_dependencies["requests"].get.return_value = mock_failed_response + + # Mock current_user for the test + with patch("services.website_service.current_user") as mock_current_user: + mock_current_user.current_tenant_id = account.current_tenant.id + + # Create API request + api_request = WebsiteCrawlApiRequest( + provider="jinareader", + url="https://example.com", + options={"limit": 1, "crawl_sub_pages": False, "only_main_content": True}, + ) + + # Act & Assert: Verify proper error handling + with pytest.raises(ValueError) as exc_info: + WebsiteService.crawl_url(api_request) + + assert "Failed to crawl" in str(exc_info.value) + + def test_get_crawl_status_firecrawl_active_job( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test Firecrawl status retrieval for active (not completed) job. + + This test verifies: + - Active job status is properly returned + - Redis cache is not deleted for active jobs + - Time consuming is not calculated for active jobs + - All required status fields are present + """ + # Arrange: Create test account and prepare request + account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) + + # Mock active job status + mock_firecrawl_instance = MagicMock() + mock_firecrawl_instance.check_crawl_status.return_value = { + "status": "active", + "total": 10, + "current": 3, + "data": [], + } + mock_external_service_dependencies["firecrawl_app"].return_value = mock_firecrawl_instance + + # Mock current_user for the test + with patch("services.website_service.current_user") as mock_current_user: + mock_current_user.current_tenant_id = account.current_tenant.id + + # Create API request + api_request = WebsiteCrawlStatusApiRequest(provider="firecrawl", job_id="active_job_123") + + # Act: Get crawl status + result = WebsiteService.get_crawl_status_typed(api_request) + + # Assert: Verify active job status + assert result is not None + assert result["status"] == "active" + assert result["job_id"] == "active_job_123" + assert result["total"] == 10 + assert result["current"] == 3 + assert "data" in result + assert "time_consuming" not in result + + # Verify Redis cache was not accessed for active jobs + mock_external_service_dependencies["redis_client"].get.assert_not_called() + mock_external_service_dependencies["redis_client"].delete.assert_not_called() + + def test_get_crawl_url_data_firecrawl_storage_fallback( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test Firecrawl URL data retrieval with storage fallback. + + This test verifies: + - Storage fallback works when storage has data + - API call is not made when storage has data + - Data is properly parsed from storage + - Correct URL data is returned + """ + # Arrange: Create test account and prepare request + account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) + + # Mock storage to return existing data + mock_external_service_dependencies["storage"].exists.return_value = True + mock_external_service_dependencies["storage"].load_once.return_value = ( + b"[" + b'{"source_url": "https://example.com/page1", ' + b'"title": "Page 1", "description": "Description 1", "markdown": "# Page 1"}, ' + b'{"source_url": "https://example.com/page2", "title": "Page 2", ' + b'"description": "Description 2", "markdown": "# Page 2"}' + b"]" + ) + + # Act: Get URL data for specific URL + result = WebsiteService.get_crawl_url_data( + job_id="test_job_id_123", + provider="firecrawl", + url="https://example.com/page1", + tenant_id=account.current_tenant.id, + ) + + # Assert: Verify successful operation + assert result is not None + assert result["source_url"] == "https://example.com/page1" + assert result["title"] == "Page 1" + assert result["description"] == "Description 1" + assert result["markdown"] == "# Page 1" + + # Verify storage was accessed + mock_external_service_dependencies["storage"].exists.assert_called_once() + mock_external_service_dependencies["storage"].load_once.assert_called_once() + + def test_get_crawl_url_data_firecrawl_api_fallback( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test Firecrawl URL data retrieval with API fallback when storage is empty. + + This test verifies: + - API fallback works when storage has no data + - FirecrawlApp is called to get data + - Completed job status is checked + - Data is returned from API response + """ + # Arrange: Create test account and prepare request + account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) + + # Mock storage to return no data + mock_external_service_dependencies["storage"].exists.return_value = False + + # Mock FirecrawlApp for API fallback + mock_firecrawl_instance = MagicMock() + mock_firecrawl_instance.check_crawl_status.return_value = { + "status": "completed", + "data": [ + { + "source_url": "https://example.com/api_page", + "title": "API Page", + "description": "API Description", + "markdown": "# API Content", + } + ], + } + mock_external_service_dependencies["firecrawl_app"].return_value = mock_firecrawl_instance + + # Act: Get URL data + result = WebsiteService.get_crawl_url_data( + job_id="test_job_id_123", + provider="firecrawl", + url="https://example.com/api_page", + tenant_id=account.current_tenant.id, + ) + + # Assert: Verify successful operation + assert result is not None + assert result["source_url"] == "https://example.com/api_page" + assert result["title"] == "API Page" + assert result["description"] == "API Description" + assert result["markdown"] == "# API Content" + + # Verify API was called + mock_external_service_dependencies["firecrawl_app"].assert_called_once() + + def test_get_crawl_url_data_firecrawl_incomplete_job( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test Firecrawl URL data retrieval fails for incomplete job. + + This test verifies: + - Incomplete job raises ValueError + - Proper error message is provided + - Service handles incomplete jobs gracefully + """ + # Arrange: Create test account and prepare request + account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) + + # Mock storage to return no data + mock_external_service_dependencies["storage"].exists.return_value = False + + # Mock incomplete job status + mock_firecrawl_instance = MagicMock() + mock_firecrawl_instance.check_crawl_status.return_value = {"status": "active", "data": []} + mock_external_service_dependencies["firecrawl_app"].return_value = mock_firecrawl_instance + + # Act & Assert: Verify proper error handling + with pytest.raises(ValueError) as exc_info: + WebsiteService.get_crawl_url_data( + job_id="test_job_id_123", + provider="firecrawl", + url="https://example.com/page", + tenant_id=account.current_tenant.id, + ) + + assert "Crawl job is not completed" in str(exc_info.value) + + def test_get_crawl_url_data_jinareader_with_job_id( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test JinaReader URL data retrieval with job ID for multi-page crawling. + + This test verifies: + - JinaReader handles job ID-based data retrieval + - Status check is performed before data retrieval + - Processed data is properly formatted + - Correct URL data is returned + """ + # Arrange: Create test account and prepare request + account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) + + # Mock successful status response + mock_status_response = MagicMock() + mock_status_response.json.return_value = { + "code": 200, + "data": { + "status": "completed", + "processed": { + "https://example.com/page1": { + "data": { + "title": "Page 1", + "url": "https://example.com/page1", + "description": "Description 1", + "content": "# Content 1", + } + } + }, + }, + } + mock_external_service_dependencies["requests"].post.return_value = mock_status_response + + # Act: Get URL data with job ID + result = WebsiteService.get_crawl_url_data( + job_id="jina_job_123", + provider="jinareader", + url="https://example.com/page1", + tenant_id=account.current_tenant.id, + ) + + # Assert: Verify successful operation + assert result is not None + assert result["title"] == "Page 1" + assert result["url"] == "https://example.com/page1" + assert result["description"] == "Description 1" + assert result["content"] == "# Content 1" + + # Verify HTTP requests were made + assert mock_external_service_dependencies["requests"].post.call_count == 2 + + def test_get_crawl_url_data_jinareader_incomplete_job( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test JinaReader URL data retrieval fails for incomplete job. + + This test verifies: + - Incomplete job raises ValueError + - Proper error message is provided + - Service handles incomplete jobs gracefully + """ + # Arrange: Create test account and prepare request + account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) + + # Mock incomplete job status + mock_status_response = MagicMock() + mock_status_response.json.return_value = {"code": 200, "data": {"status": "active", "processed": {}}} + mock_external_service_dependencies["requests"].post.return_value = mock_status_response + + # Act & Assert: Verify proper error handling + with pytest.raises(ValueError) as exc_info: + WebsiteService.get_crawl_url_data( + job_id="jina_job_123", + provider="jinareader", + url="https://example.com/page", + tenant_id=account.current_tenant.id, + ) + + assert "Crawl job is not completed" in str(exc_info.value) + + def test_crawl_options_default_values(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test CrawlOptions default values and initialization. + + This test verifies: + - Default values are properly set + - Optional fields can be None + - Boolean fields have correct defaults + - Integer fields have correct defaults + """ + # Arrange: Create CrawlOptions with minimal parameters + options = CrawlOptions() + + # Assert: Verify default values + assert options.limit == 1 + assert options.crawl_sub_pages is False + assert options.only_main_content is False + assert options.includes is None + assert options.excludes is None + assert options.max_depth is None + assert options.use_sitemap is True + + # Test with custom values + custom_options = CrawlOptions( + limit=10, + crawl_sub_pages=True, + only_main_content=True, + includes="blog,news", + excludes="admin", + max_depth=3, + use_sitemap=False, + ) + + assert custom_options.limit == 10 + assert custom_options.crawl_sub_pages is True + assert custom_options.only_main_content is True + assert custom_options.includes == "blog,news" + assert custom_options.excludes == "admin" + assert custom_options.max_depth == 3 + assert custom_options.use_sitemap is False + + def test_website_crawl_status_api_request_from_args( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test WebsiteCrawlStatusApiRequest creation from Flask arguments. + + This test verifies: + - Request is properly created from parsed arguments + - Required fields are validated + - Job ID is properly handled + - Validation errors are properly raised + """ + # Arrange: Prepare valid arguments + valid_args = {"provider": "firecrawl"} + job_id = "test_job_123" + + # Act: Create request from args + request = WebsiteCrawlStatusApiRequest.from_args(valid_args, job_id) + + # Assert: Verify request creation + assert request.provider == "firecrawl" + assert request.job_id == "test_job_123" + + # Test missing provider + invalid_args = {} + with pytest.raises(ValueError) as exc_info: + WebsiteCrawlStatusApiRequest.from_args(invalid_args, job_id) + assert "Provider is required" in str(exc_info.value) + + # Test missing job ID + with pytest.raises(ValueError) as exc_info: + WebsiteCrawlStatusApiRequest.from_args(valid_args, "") + assert "Job ID is required" in str(exc_info.value) + + def test_scrape_request_initialization(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test ScrapeRequest dataclass initialization and properties. + + This test verifies: + - ScrapeRequest is properly initialized + - All fields are correctly set + - Boolean field works correctly + - String fields are properly assigned + """ + # Arrange: Create ScrapeRequest + request = ScrapeRequest( + provider="firecrawl", url="https://example.com", tenant_id="tenant_123", only_main_content=True + ) + + # Assert: Verify initialization + assert request.provider == "firecrawl" + assert request.url == "https://example.com" + assert request.tenant_id == "tenant_123" + assert request.only_main_content is True + + # Test with different values + request2 = ScrapeRequest( + provider="watercrawl", url="https://test.com", tenant_id="tenant_456", only_main_content=False + ) + + assert request2.provider == "watercrawl" + assert request2.url == "https://test.com" + assert request2.tenant_id == "tenant_456" + assert request2.only_main_content is False diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py new file mode 100644 index 0000000000..d73fb7e4be --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py @@ -0,0 +1,840 @@ +import pytest +from faker import Faker + +from core.variables.segments import StringSegment +from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID +from models import App, Workflow +from models.enums import DraftVariableType +from models.workflow import WorkflowDraftVariable +from services.workflow_draft_variable_service import ( + UpdateNotSupportedError, + WorkflowDraftVariableService, +) + + +def _get_random_variable_name(fake: Faker): + return "".join(fake.random_letters(length=10)) + + +class TestWorkflowDraftVariableService: + """ + Comprehensive integration tests for WorkflowDraftVariableService using testcontainers. + + This test class covers all major functionality of the WorkflowDraftVariableService: + - CRUD operations for workflow draft variables (Create, Read, Update, Delete) + - Variable listing and filtering by type (conversation, system, node) + - Variable updates and resets with proper validation + - Variable deletion operations at different scopes + - Special functionality like prefill and conversation ID retrieval + - Error handling for various edge cases and invalid operations + + All tests use the testcontainers infrastructure to ensure proper database isolation + and realistic testing environment with actual database interactions. + """ + + @pytest.fixture + def mock_external_service_dependencies(self): + """ + Mock setup for external service dependencies. + + WorkflowDraftVariableService doesn't have external dependencies that need mocking, + so this fixture returns an empty dictionary to maintain consistency with other test classes. + This ensures the test structure remains consistent across different service test files. + """ + # WorkflowDraftVariableService doesn't have external dependencies that need mocking + return {} + + def _create_test_app(self, db_session_with_containers, mock_external_service_dependencies, fake=None): + """ + Helper method to create a test app with realistic data for testing. + + This method creates a complete App instance with all required fields populated + using Faker for generating realistic test data. The app is configured for + workflow mode to support workflow draft variable testing. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + mock_external_service_dependencies: Mock dependencies (unused in this service) + fake: Faker instance for generating test data, creates new instance if not provided + + Returns: + App: Created test app instance with all required fields populated + """ + fake = fake or Faker() + app = App() + app.id = fake.uuid4() + app.tenant_id = fake.uuid4() + app.name = fake.company() + app.description = fake.text() + app.mode = "workflow" + app.icon_type = "emoji" + app.icon = "🤖" + app.icon_background = "#FFEAD5" + app.enable_site = True + app.enable_api = True + app.created_by = fake.uuid4() + app.updated_by = app.created_by + + from extensions.ext_database import db + + db.session.add(app) + db.session.commit() + return app + + def _create_test_workflow(self, db_session_with_containers, app, fake=None): + """ + Helper method to create a test workflow associated with an app. + + This method creates a Workflow instance using the proper factory method + to ensure all required fields are set correctly. The workflow is configured + as a draft version with basic graph structure for testing workflow variables. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + app: The app to associate the workflow with + fake: Faker instance for generating test data, creates new instance if not provided + + Returns: + Workflow: Created test workflow instance with proper configuration + """ + fake = fake or Faker() + workflow = Workflow.new( + tenant_id=app.tenant_id, + app_id=app.id, + type="workflow", + version="draft", + graph='{"nodes": [], "edges": []}', + features="{}", + created_by=app.created_by, + environment_variables=[], + conversation_variables=[], + ) + from extensions.ext_database import db + + db.session.add(workflow) + db.session.commit() + return workflow + + def _create_test_variable( + self, + db_session_with_containers, + app_id, + node_id, + name, + value, + variable_type: DraftVariableType = DraftVariableType.CONVERSATION, + fake=None, + ): + """ + Helper method to create a test workflow draft variable with proper configuration. + + This method creates different types of variables (conversation, system, node) using + the appropriate factory methods to ensure proper initialization. Each variable type + has specific requirements and this method handles the creation logic for all types. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + app_id: ID of the app to associate the variable with + node_id: ID of the node (or special constants like CONVERSATION_VARIABLE_NODE_ID) + name: Name of the variable for identification + value: StringSegment value for the variable content + variable_type: Type of variable ("conversation", "system", "node") determining creation method + fake: Faker instance for generating test data, creates new instance if not provided + + Returns: + WorkflowDraftVariable: Created test variable instance with proper type configuration + """ + fake = fake or Faker() + if variable_type == "conversation": + # Create conversation variable using the appropriate factory method + variable = WorkflowDraftVariable.new_conversation_variable( + app_id=app_id, + name=name, + value=value, + description=fake.text(max_nb_chars=20), + ) + elif variable_type == "system": + # Create system variable with editable flag and execution context + variable = WorkflowDraftVariable.new_sys_variable( + app_id=app_id, + name=name, + value=value, + node_execution_id=fake.uuid4(), + editable=True, + ) + else: # node variable + # Create node variable with visibility and editability settings + variable = WorkflowDraftVariable.new_node_variable( + app_id=app_id, + node_id=node_id, + name=name, + value=value, + node_execution_id=fake.uuid4(), + visible=True, + editable=True, + ) + from extensions.ext_database import db + + db.session.add(variable) + db.session.commit() + return variable + + def test_get_variable_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test getting a single variable by ID successfully. + + This test verifies that the service can retrieve a specific variable + by its ID and that the returned variable contains the correct data. + It ensures the basic CRUD read operation works correctly for workflow draft variables. + """ + fake = Faker() + app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake) + test_value = StringSegment(value=fake.word()) + variable = self._create_test_variable( + db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, "test_var", test_value, fake=fake + ) + service = WorkflowDraftVariableService(db_session_with_containers) + retrieved_variable = service.get_variable(variable.id) + assert retrieved_variable is not None + assert retrieved_variable.id == variable.id + assert retrieved_variable.name == "test_var" + assert retrieved_variable.app_id == app.id + assert retrieved_variable.get_value().value == test_value.value + + def test_get_variable_not_found(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test getting a variable that doesn't exist. + + This test verifies that the service returns None when trying to + retrieve a variable with a non-existent ID. This ensures proper + handling of missing data scenarios. + """ + fake = Faker() + non_existent_id = fake.uuid4() + service = WorkflowDraftVariableService(db_session_with_containers) + retrieved_variable = service.get_variable(non_existent_id) + assert retrieved_variable is None + + def test_get_draft_variables_by_selectors_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test getting variables by selectors successfully. + + This test verifies that the service can retrieve multiple variables + using selector pairs (node_id, variable_name) and returns the correct + variables for each selector. This is useful for bulk variable retrieval + operations in workflow execution contexts. + """ + fake = Faker() + app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake) + var1_value = StringSegment(value=fake.word()) + var2_value = StringSegment(value=fake.word()) + var3_value = StringSegment(value=fake.word()) + var1 = self._create_test_variable( + db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, "var1", var1_value, fake=fake + ) + var2 = self._create_test_variable( + db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, "var2", var2_value, fake=fake + ) + var3 = self._create_test_variable( + db_session_with_containers, + app.id, + "test_node_1", + "var3", + var3_value, + variable_type=DraftVariableType.NODE, + fake=fake, + ) + selectors = [ + [CONVERSATION_VARIABLE_NODE_ID, "var1"], + [CONVERSATION_VARIABLE_NODE_ID, "var2"], + ["test_node_1", "var3"], + ] + service = WorkflowDraftVariableService(db_session_with_containers) + retrieved_variables = service.get_draft_variables_by_selectors(app.id, selectors) + assert len(retrieved_variables) == 3 + var_names = [var.name for var in retrieved_variables] + assert "var1" in var_names + assert "var2" in var_names + assert "var3" in var_names + for var in retrieved_variables: + if var.name == "var1": + assert var.get_value().value == var1_value.value + elif var.name == "var2": + assert var.get_value().value == var2_value.value + elif var.name == "var3": + assert var.get_value().value == var3_value.value + + def test_list_variables_without_values_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test listing variables without values successfully with pagination. + + This test verifies that the service can list variables with pagination + and that the returned variables don't include their values (for performance). + This is important for scenarios where only variable metadata is needed + without loading the actual content. + """ + fake = Faker() + app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake) + for i in range(5): + test_value = StringSegment(value=fake.numerify("value######")) + self._create_test_variable( + db_session_with_containers, + app.id, + CONVERSATION_VARIABLE_NODE_ID, + _get_random_variable_name(fake), + test_value, + fake=fake, + ) + service = WorkflowDraftVariableService(db_session_with_containers) + result = service.list_variables_without_values(app.id, page=1, limit=3) + assert result.total == 5 + assert len(result.variables) == 3 + assert result.variables[0].created_at >= result.variables[1].created_at + assert result.variables[1].created_at >= result.variables[2].created_at + for var in result.variables: + assert var.name is not None + assert var.app_id == app.id + + def test_list_node_variables_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test listing variables for a specific node successfully. + + This test verifies that the service can filter and return only + variables associated with a specific node ID. This is crucial for + workflow execution where variables need to be scoped to specific nodes. + """ + fake = Faker() + app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake) + node_id = fake.word() + var1_value = StringSegment(value=fake.word()) + var2_value = StringSegment(value=fake.word()) + var3_value = StringSegment(value=fake.word()) + self._create_test_variable( + db_session_with_containers, + app.id, + node_id, + "var1", + var1_value, + variable_type=DraftVariableType.NODE, + fake=fake, + ) + self._create_test_variable( + db_session_with_containers, + app.id, + node_id, + "var2", + var3_value, + variable_type=DraftVariableType.NODE, + fake=fake, + ) + self._create_test_variable( + db_session_with_containers, + app.id, + "other_node", + "var3", + var2_value, + variable_type=DraftVariableType.NODE, + fake=fake, + ) + service = WorkflowDraftVariableService(db_session_with_containers) + result = service.list_node_variables(app.id, node_id) + assert len(result.variables) == 2 + for var in result.variables: + assert var.node_id == node_id + assert var.app_id == app.id + var_names = [var.name for var in result.variables] + assert "var1" in var_names + assert "var2" in var_names + assert "var3" not in var_names + + def test_list_conversation_variables_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test listing conversation variables successfully. + + This test verifies that the service can filter and return only + conversation variables, excluding system and node variables. + Conversation variables are user-facing variables that can be + modified during conversation flows. + """ + fake = Faker() + app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake) + conv_var1_value = StringSegment(value=fake.word()) + conv_var2_value = StringSegment(value=fake.word()) + conv_var1 = self._create_test_variable( + db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, "conv_var1", conv_var1_value, fake=fake + ) + conv_var2 = self._create_test_variable( + db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, "conv_var2", conv_var2_value, fake=fake + ) + sys_var_value = StringSegment(value=fake.word()) + self._create_test_variable( + db_session_with_containers, + app.id, + SYSTEM_VARIABLE_NODE_ID, + "sys_var", + sys_var_value, + variable_type=DraftVariableType.SYS, + fake=fake, + ) + service = WorkflowDraftVariableService(db_session_with_containers) + result = service.list_conversation_variables(app.id) + assert len(result.variables) == 2 + for var in result.variables: + assert var.node_id == CONVERSATION_VARIABLE_NODE_ID + assert var.app_id == app.id + assert var.get_variable_type() == DraftVariableType.CONVERSATION + var_names = [var.name for var in result.variables] + assert "conv_var1" in var_names + assert "conv_var2" in var_names + assert "sys_var" not in var_names + + def test_update_variable_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test updating a variable's name and value successfully. + + This test verifies that the service can update both the name and value + of an editable variable and that the changes are persisted correctly. + It also checks that the last_edited_at timestamp is updated appropriately. + """ + fake = Faker() + app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake) + original_value = StringSegment(value=fake.word()) + new_value = StringSegment(value=fake.word()) + variable = self._create_test_variable( + db_session_with_containers, + app.id, + CONVERSATION_VARIABLE_NODE_ID, + "original_name", + original_value, + fake=fake, + ) + service = WorkflowDraftVariableService(db_session_with_containers) + updated_variable = service.update_variable(variable, name="new_name", value=new_value) + assert updated_variable.name == "new_name" + assert updated_variable.get_value().value == new_value.value + assert updated_variable.last_edited_at is not None + from extensions.ext_database import db + + db.session.refresh(variable) + assert variable.name == "new_name" + assert variable.get_value().value == new_value.value + assert variable.last_edited_at is not None + + def test_update_variable_not_editable(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test that updating a non-editable variable raises an exception. + + This test verifies that the service properly prevents updates to + variables that are not marked as editable. This is important for + maintaining data integrity and preventing unauthorized modifications + to system-controlled variables. + """ + fake = Faker() + app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake) + original_value = StringSegment(value=fake.word()) + new_value = StringSegment(value=fake.word()) + variable = WorkflowDraftVariable.new_sys_variable( + app_id=app.id, + name=fake.word(), # This is typically not editable + value=original_value, + node_execution_id=fake.uuid4(), + editable=False, # Set as non-editable + ) + from extensions.ext_database import db + + db.session.add(variable) + db.session.commit() + service = WorkflowDraftVariableService(db_session_with_containers) + with pytest.raises(UpdateNotSupportedError) as exc_info: + service.update_variable(variable, name="new_name", value=new_value) + assert "variable not support updating" in str(exc_info.value) + assert variable.id in str(exc_info.value) + + def test_reset_conversation_variable_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test resetting conversation variable successfully. + + This test verifies that the service can reset a conversation variable + to its default value and clear the last_edited_at timestamp. + This functionality is useful for reverting user modifications + back to the original workflow configuration. + """ + fake = Faker() + app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake) + workflow = self._create_test_workflow(db_session_with_containers, app, fake=fake) + from core.variables.variables import StringVariable + + conv_var = StringVariable( + id=fake.uuid4(), + name="test_conv_var", + value="default_value", + selector=[CONVERSATION_VARIABLE_NODE_ID, "test_conv_var"], + ) + workflow.conversation_variables = [conv_var] + from extensions.ext_database import db + + db.session.commit() + modified_value = StringSegment(value=fake.word()) + variable = self._create_test_variable( + db_session_with_containers, + app.id, + CONVERSATION_VARIABLE_NODE_ID, + "test_conv_var", + modified_value, + fake=fake, + ) + variable.last_edited_at = fake.date_time() + db.session.commit() + service = WorkflowDraftVariableService(db_session_with_containers) + reset_variable = service.reset_variable(workflow, variable) + assert reset_variable is not None + assert reset_variable.get_value().value == "default_value" + assert reset_variable.last_edited_at is None + db.session.refresh(variable) + assert variable.get_value().value == "default_value" + assert variable.last_edited_at is None + + def test_delete_variable_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test deleting a single variable successfully. + + This test verifies that the service can delete a specific variable + and that it's properly removed from the database. It ensures that + the deletion operation is atomic and complete. + """ + fake = Faker() + app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake) + test_value = StringSegment(value=fake.word()) + variable = self._create_test_variable( + db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, "test_var", test_value, fake=fake + ) + from extensions.ext_database import db + + assert db.session.query(WorkflowDraftVariable).filter_by(id=variable.id).first() is not None + service = WorkflowDraftVariableService(db_session_with_containers) + service.delete_variable(variable) + assert db.session.query(WorkflowDraftVariable).filter_by(id=variable.id).first() is None + + def test_delete_workflow_variables_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test deleting all variables for a workflow successfully. + + This test verifies that the service can delete all variables + associated with a specific app/workflow. This is useful for + cleanup operations when workflows are deleted or reset. + """ + fake = Faker() + app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake) + for i in range(3): + test_value = StringSegment(value=fake.numerify("value######")) + self._create_test_variable( + db_session_with_containers, + app.id, + CONVERSATION_VARIABLE_NODE_ID, + _get_random_variable_name(fake), + test_value, + fake=fake, + ) + other_app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake) + other_value = StringSegment(value=fake.word()) + self._create_test_variable( + db_session_with_containers, + other_app.id, + CONVERSATION_VARIABLE_NODE_ID, + _get_random_variable_name(fake), + other_value, + fake=fake, + ) + from extensions.ext_database import db + + app_variables = db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id).all() + other_app_variables = db.session.query(WorkflowDraftVariable).filter_by(app_id=other_app.id).all() + assert len(app_variables) == 3 + assert len(other_app_variables) == 1 + service = WorkflowDraftVariableService(db_session_with_containers) + service.delete_workflow_variables(app.id) + app_variables_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id).all() + other_app_variables_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=other_app.id).all() + assert len(app_variables_after) == 0 + assert len(other_app_variables_after) == 1 + + def test_delete_node_variables_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test deleting all variables for a specific node successfully. + + This test verifies that the service can delete all variables + associated with a specific node while preserving variables + for other nodes and conversation variables. This is important + for node-specific cleanup operations in workflow management. + """ + fake = Faker() + app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake) + node_id = fake.word() + for i in range(2): + test_value = StringSegment(value=fake.numerify("node_value######")) + self._create_test_variable( + db_session_with_containers, + app.id, + node_id, + _get_random_variable_name(fake), + test_value, + variable_type=DraftVariableType.NODE, + fake=fake, + ) + other_node_value = StringSegment(value=fake.word()) + self._create_test_variable( + db_session_with_containers, + app.id, + "other_node", + _get_random_variable_name(fake), + other_node_value, + variable_type=DraftVariableType.NODE, + fake=fake, + ) + conv_value = StringSegment(value=fake.word()) + self._create_test_variable( + db_session_with_containers, + app.id, + CONVERSATION_VARIABLE_NODE_ID, + _get_random_variable_name(fake), + conv_value, + fake=fake, + ) + from extensions.ext_database import db + + target_node_variables = db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id, node_id=node_id).all() + other_node_variables = ( + db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id, node_id="other_node").all() + ) + conv_variables = ( + db.session.query(WorkflowDraftVariable) + .filter_by(app_id=app.id, node_id=CONVERSATION_VARIABLE_NODE_ID) + .all() + ) + assert len(target_node_variables) == 2 + assert len(other_node_variables) == 1 + assert len(conv_variables) == 1 + service = WorkflowDraftVariableService(db_session_with_containers) + service.delete_node_variables(app.id, node_id) + target_node_variables_after = ( + db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id, node_id=node_id).all() + ) + other_node_variables_after = ( + db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id, node_id="other_node").all() + ) + conv_variables_after = ( + db.session.query(WorkflowDraftVariable) + .filter_by(app_id=app.id, node_id=CONVERSATION_VARIABLE_NODE_ID) + .all() + ) + assert len(target_node_variables_after) == 0 + assert len(other_node_variables_after) == 1 + assert len(conv_variables_after) == 1 + + def test_prefill_conversation_variable_default_values_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test prefill conversation variable default values successfully. + + This test verifies that the service can automatically create + conversation variables with default values based on the workflow + configuration when none exist. This is important for initializing + workflow variables with proper defaults from the workflow definition. + """ + fake = Faker() + app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake) + workflow = self._create_test_workflow(db_session_with_containers, app, fake=fake) + from core.variables.variables import StringVariable + + conv_var1 = StringVariable( + id=fake.uuid4(), + name="conv_var1", + value="default_value1", + selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_var1"], + ) + conv_var2 = StringVariable( + id=fake.uuid4(), + name="conv_var2", + value="default_value2", + selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_var2"], + ) + workflow.conversation_variables = [conv_var1, conv_var2] + from extensions.ext_database import db + + db.session.commit() + service = WorkflowDraftVariableService(db_session_with_containers) + service.prefill_conversation_variable_default_values(workflow) + draft_variables = ( + db.session.query(WorkflowDraftVariable) + .filter_by(app_id=app.id, node_id=CONVERSATION_VARIABLE_NODE_ID) + .all() + ) + assert len(draft_variables) == 2 + var_names = [var.name for var in draft_variables] + assert "conv_var1" in var_names + assert "conv_var2" in var_names + for var in draft_variables: + assert var.app_id == app.id + assert var.node_id == CONVERSATION_VARIABLE_NODE_ID + assert var.editable is True + assert var.get_variable_type() == DraftVariableType.CONVERSATION + + def test_get_conversation_id_from_draft_variable_success( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test getting conversation ID from draft variable successfully. + + This test verifies that the service can extract the conversation ID + from a system variable named "conversation_id". This is important + for maintaining conversation context across workflow executions. + """ + fake = Faker() + app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake) + conversation_id = fake.uuid4() + conv_id_value = StringSegment(value=conversation_id) + self._create_test_variable( + db_session_with_containers, + app.id, + SYSTEM_VARIABLE_NODE_ID, + "conversation_id", + conv_id_value, + variable_type=DraftVariableType.SYS, + fake=fake, + ) + service = WorkflowDraftVariableService(db_session_with_containers) + retrieved_conv_id = service._get_conversation_id_from_draft_variable(app.id) + assert retrieved_conv_id == conversation_id + + def test_get_conversation_id_from_draft_variable_not_found( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test getting conversation ID when it doesn't exist. + + This test verifies that the service returns None when no + conversation_id variable exists for the app. This ensures + proper handling of missing conversation context scenarios. + """ + fake = Faker() + app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake) + service = WorkflowDraftVariableService(db_session_with_containers) + retrieved_conv_id = service._get_conversation_id_from_draft_variable(app.id) + assert retrieved_conv_id is None + + def test_list_system_variables_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test listing system variables successfully. + + This test verifies that the service can filter and return only + system variables, excluding conversation and node variables. + System variables are internal variables used by the workflow + engine for maintaining state and context. + """ + fake = Faker() + app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake) + sys_var1_value = StringSegment(value=fake.word()) + sys_var2_value = StringSegment(value=fake.word()) + sys_var1 = self._create_test_variable( + db_session_with_containers, + app.id, + SYSTEM_VARIABLE_NODE_ID, + "sys_var1", + sys_var1_value, + variable_type=DraftVariableType.SYS, + fake=fake, + ) + sys_var2 = self._create_test_variable( + db_session_with_containers, + app.id, + SYSTEM_VARIABLE_NODE_ID, + "sys_var2", + sys_var2_value, + variable_type=DraftVariableType.SYS, + fake=fake, + ) + conv_var_value = StringSegment(value=fake.word()) + self._create_test_variable( + db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, "conv_var", conv_var_value, fake=fake + ) + service = WorkflowDraftVariableService(db_session_with_containers) + result = service.list_system_variables(app.id) + assert len(result.variables) == 2 + for var in result.variables: + assert var.node_id == SYSTEM_VARIABLE_NODE_ID + assert var.app_id == app.id + assert var.get_variable_type() == DraftVariableType.SYS + var_names = [var.name for var in result.variables] + assert "sys_var1" in var_names + assert "sys_var2" in var_names + assert "conv_var" not in var_names + + def test_get_variable_by_name_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test getting variables by name successfully for different types. + + This test verifies that the service can retrieve variables by name + for different variable types (conversation, system, node). This + functionality is important for variable lookup operations during + workflow execution and user interactions. + """ + fake = Faker() + app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake) + test_value = StringSegment(value=fake.word()) + conv_var = self._create_test_variable( + db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, "test_conv_var", test_value, fake=fake + ) + sys_var = self._create_test_variable( + db_session_with_containers, + app.id, + SYSTEM_VARIABLE_NODE_ID, + "test_sys_var", + test_value, + variable_type=DraftVariableType.SYS, + fake=fake, + ) + node_var = self._create_test_variable( + db_session_with_containers, + app.id, + "test_node", + "test_node_var", + test_value, + variable_type=DraftVariableType.NODE, + fake=fake, + ) + service = WorkflowDraftVariableService(db_session_with_containers) + retrieved_conv_var = service.get_conversation_variable(app.id, "test_conv_var") + assert retrieved_conv_var is not None + assert retrieved_conv_var.name == "test_conv_var" + assert retrieved_conv_var.node_id == CONVERSATION_VARIABLE_NODE_ID + retrieved_sys_var = service.get_system_variable(app.id, "test_sys_var") + assert retrieved_sys_var is not None + assert retrieved_sys_var.name == "test_sys_var" + assert retrieved_sys_var.node_id == SYSTEM_VARIABLE_NODE_ID + retrieved_node_var = service.get_node_variable(app.id, "test_node", "test_node_var") + assert retrieved_node_var is not None + assert retrieved_node_var.name == "test_node_var" + assert retrieved_node_var.node_id == "test_node" + + def test_get_variable_by_name_not_found(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test getting variables by name when they don't exist. + + This test verifies that the service returns None when trying to + retrieve variables by name that don't exist. This ensures proper + handling of missing variable scenarios for all variable types. + """ + fake = Faker() + app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake) + service = WorkflowDraftVariableService(db_session_with_containers) + retrieved_conv_var = service.get_conversation_variable(app.id, "non_existent_conv_var") + assert retrieved_conv_var is None + retrieved_sys_var = service.get_system_variable(app.id, "non_existent_sys_var") + assert retrieved_sys_var is None + retrieved_node_var = service.get_node_variable(app.id, "test_node", "non_existent_node_var") + assert retrieved_node_var is None diff --git a/api/tests/test_containers_integration_tests/workflow/__init__.py b/api/tests/test_containers_integration_tests/workflow/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/test_containers_integration_tests/workflow/nodes/__init__.py b/api/tests/test_containers_integration_tests/workflow/nodes/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/test_containers_integration_tests/workflow/nodes/code_executor/__init__.py b/api/tests/test_containers_integration_tests/workflow/nodes/code_executor/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/test_containers_integration_tests/workflow/nodes/code_executor/test_code_executor.py b/api/tests/test_containers_integration_tests/workflow/nodes/code_executor/test_code_executor.py new file mode 100644 index 0000000000..487178ff58 --- /dev/null +++ b/api/tests/test_containers_integration_tests/workflow/nodes/code_executor/test_code_executor.py @@ -0,0 +1,11 @@ +import pytest + +from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor + +CODE_LANGUAGE = "unsupported_language" + + +def test_unsupported_with_code_template(): + with pytest.raises(CodeExecutionError) as e: + CodeExecutor.execute_workflow_code_template(language=CODE_LANGUAGE, code="", inputs={}) + assert str(e.value) == f"Unsupported language {CODE_LANGUAGE}" diff --git a/api/tests/test_containers_integration_tests/workflow/nodes/code_executor/test_code_javascript.py b/api/tests/test_containers_integration_tests/workflow/nodes/code_executor/test_code_javascript.py new file mode 100644 index 0000000000..19a41b6186 --- /dev/null +++ b/api/tests/test_containers_integration_tests/workflow/nodes/code_executor/test_code_javascript.py @@ -0,0 +1,47 @@ +from textwrap import dedent + +from .test_utils import CodeExecutorTestMixin + + +class TestJavaScriptCodeExecutor(CodeExecutorTestMixin): + """Test class for JavaScript code executor functionality.""" + + def test_javascript_plain(self, flask_app_with_containers): + """Test basic JavaScript code execution with console.log output""" + CodeExecutor, CodeLanguage = self.code_executor_imports + + code = 'console.log("Hello World")' + result_message = CodeExecutor.execute_code(language=CodeLanguage.JAVASCRIPT, preload="", code=code) + assert result_message == "Hello World\n" + + def test_javascript_json(self, flask_app_with_containers): + """Test JavaScript code execution with JSON output""" + CodeExecutor, CodeLanguage = self.code_executor_imports + + code = dedent(""" + obj = {'Hello': 'World'} + console.log(JSON.stringify(obj)) + """) + result = CodeExecutor.execute_code(language=CodeLanguage.JAVASCRIPT, preload="", code=code) + assert result == '{"Hello":"World"}\n' + + def test_javascript_with_code_template(self, flask_app_with_containers): + """Test JavaScript workflow code template execution with inputs""" + CodeExecutor, CodeLanguage = self.code_executor_imports + JavascriptCodeProvider, _ = self.javascript_imports + + result = CodeExecutor.execute_workflow_code_template( + language=CodeLanguage.JAVASCRIPT, + code=JavascriptCodeProvider.get_default_code(), + inputs={"arg1": "Hello", "arg2": "World"}, + ) + assert result == {"result": "HelloWorld"} + + def test_javascript_get_runner_script(self, flask_app_with_containers): + """Test JavaScript template transformer runner script generation""" + _, NodeJsTemplateTransformer = self.javascript_imports + + runner_script = NodeJsTemplateTransformer.get_runner_script() + assert runner_script.count(NodeJsTemplateTransformer._code_placeholder) == 1 + assert runner_script.count(NodeJsTemplateTransformer._inputs_placeholder) == 1 + assert runner_script.count(NodeJsTemplateTransformer._result_tag) == 2 diff --git a/api/tests/test_containers_integration_tests/workflow/nodes/code_executor/test_code_jinja2.py b/api/tests/test_containers_integration_tests/workflow/nodes/code_executor/test_code_jinja2.py new file mode 100644 index 0000000000..c764801170 --- /dev/null +++ b/api/tests/test_containers_integration_tests/workflow/nodes/code_executor/test_code_jinja2.py @@ -0,0 +1,42 @@ +import base64 + +from .test_utils import CodeExecutorTestMixin + + +class TestJinja2CodeExecutor(CodeExecutorTestMixin): + """Test class for Jinja2 code executor functionality.""" + + def test_jinja2(self, flask_app_with_containers): + """Test basic Jinja2 template execution with variable substitution""" + CodeExecutor, CodeLanguage = self.code_executor_imports + _, Jinja2TemplateTransformer = self.jinja2_imports + + template = "Hello {{template}}" + inputs = base64.b64encode(b'{"template": "World"}').decode("utf-8") + code = ( + Jinja2TemplateTransformer.get_runner_script() + .replace(Jinja2TemplateTransformer._code_placeholder, template) + .replace(Jinja2TemplateTransformer._inputs_placeholder, inputs) + ) + result = CodeExecutor.execute_code( + language=CodeLanguage.JINJA2, preload=Jinja2TemplateTransformer.get_preload_script(), code=code + ) + assert result == "<>Hello World<>\n" + + def test_jinja2_with_code_template(self, flask_app_with_containers): + """Test Jinja2 workflow code template execution with inputs""" + CodeExecutor, CodeLanguage = self.code_executor_imports + + result = CodeExecutor.execute_workflow_code_template( + language=CodeLanguage.JINJA2, code="Hello {{template}}", inputs={"template": "World"} + ) + assert result == {"result": "Hello World"} + + def test_jinja2_get_runner_script(self, flask_app_with_containers): + """Test Jinja2 template transformer runner script generation""" + _, Jinja2TemplateTransformer = self.jinja2_imports + + runner_script = Jinja2TemplateTransformer.get_runner_script() + assert runner_script.count(Jinja2TemplateTransformer._code_placeholder) == 1 + assert runner_script.count(Jinja2TemplateTransformer._inputs_placeholder) == 1 + assert runner_script.count(Jinja2TemplateTransformer._result_tag) == 2 diff --git a/api/tests/test_containers_integration_tests/workflow/nodes/code_executor/test_code_python3.py b/api/tests/test_containers_integration_tests/workflow/nodes/code_executor/test_code_python3.py new file mode 100644 index 0000000000..6d93df2472 --- /dev/null +++ b/api/tests/test_containers_integration_tests/workflow/nodes/code_executor/test_code_python3.py @@ -0,0 +1,47 @@ +from textwrap import dedent + +from .test_utils import CodeExecutorTestMixin + + +class TestPython3CodeExecutor(CodeExecutorTestMixin): + """Test class for Python3 code executor functionality.""" + + def test_python3_plain(self, flask_app_with_containers): + """Test basic Python3 code execution with print output""" + CodeExecutor, CodeLanguage = self.code_executor_imports + + code = 'print("Hello World")' + result = CodeExecutor.execute_code(language=CodeLanguage.PYTHON3, preload="", code=code) + assert result == "Hello World\n" + + def test_python3_json(self, flask_app_with_containers): + """Test Python3 code execution with JSON output""" + CodeExecutor, CodeLanguage = self.code_executor_imports + + code = dedent(""" + import json + print(json.dumps({'Hello': 'World'})) + """) + result = CodeExecutor.execute_code(language=CodeLanguage.PYTHON3, preload="", code=code) + assert result == '{"Hello": "World"}\n' + + def test_python3_with_code_template(self, flask_app_with_containers): + """Test Python3 workflow code template execution with inputs""" + CodeExecutor, CodeLanguage = self.code_executor_imports + Python3CodeProvider, _ = self.python3_imports + + result = CodeExecutor.execute_workflow_code_template( + language=CodeLanguage.PYTHON3, + code=Python3CodeProvider.get_default_code(), + inputs={"arg1": "Hello", "arg2": "World"}, + ) + assert result == {"result": "HelloWorld"} + + def test_python3_get_runner_script(self, flask_app_with_containers): + """Test Python3 template transformer runner script generation""" + _, Python3TemplateTransformer = self.python3_imports + + runner_script = Python3TemplateTransformer.get_runner_script() + assert runner_script.count(Python3TemplateTransformer._code_placeholder) == 1 + assert runner_script.count(Python3TemplateTransformer._inputs_placeholder) == 1 + assert runner_script.count(Python3TemplateTransformer._result_tag) == 2 diff --git a/api/tests/test_containers_integration_tests/workflow/nodes/code_executor/test_utils.py b/api/tests/test_containers_integration_tests/workflow/nodes/code_executor/test_utils.py new file mode 100644 index 0000000000..35a095b049 --- /dev/null +++ b/api/tests/test_containers_integration_tests/workflow/nodes/code_executor/test_utils.py @@ -0,0 +1,115 @@ +""" +Test utilities for code executor integration tests. + +This module provides lazy import functions to avoid module loading issues +that occur when modules are imported before the flask_app_with_containers fixture +has set up the proper environment variables and configuration. +""" + +import importlib +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + pass + + +def force_reload_code_executor(): + """ + Force reload the code_executor module to reinitialize code_execution_endpoint_url. + + This function should be called after setting up environment variables + to ensure the code_execution_endpoint_url is initialized with the correct value. + """ + try: + import core.helper.code_executor.code_executor + + importlib.reload(core.helper.code_executor.code_executor) + except Exception as e: + # Log the error but don't fail the test + print(f"Warning: Failed to reload code_executor module: {e}") + + +def get_code_executor_imports(): + """ + Lazy import function for core CodeExecutor classes. + + Returns: + tuple: (CodeExecutor, CodeLanguage) classes + """ + from core.helper.code_executor.code_executor import CodeExecutor, CodeLanguage + + return CodeExecutor, CodeLanguage + + +def get_javascript_imports(): + """ + Lazy import function for JavaScript-specific modules. + + Returns: + tuple: (JavascriptCodeProvider, NodeJsTemplateTransformer) classes + """ + from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider + from core.helper.code_executor.javascript.javascript_transformer import NodeJsTemplateTransformer + + return JavascriptCodeProvider, NodeJsTemplateTransformer + + +def get_python3_imports(): + """ + Lazy import function for Python3-specific modules. + + Returns: + tuple: (Python3CodeProvider, Python3TemplateTransformer) classes + """ + from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider + from core.helper.code_executor.python3.python3_transformer import Python3TemplateTransformer + + return Python3CodeProvider, Python3TemplateTransformer + + +def get_jinja2_imports(): + """ + Lazy import function for Jinja2-specific modules. + + Returns: + tuple: (None, Jinja2TemplateTransformer) classes + """ + from core.helper.code_executor.jinja2.jinja2_transformer import Jinja2TemplateTransformer + + return None, Jinja2TemplateTransformer + + +class CodeExecutorTestMixin: + """ + Mixin class providing lazy import methods for code executor tests. + + This mixin helps avoid module loading issues by deferring imports + until after the flask_app_with_containers fixture has set up the environment. + """ + + def setup_method(self): + """ + Setup method called before each test method. + Force reload the code_executor module to ensure fresh initialization. + """ + force_reload_code_executor() + + @property + def code_executor_imports(self): + """Property to get CodeExecutor and CodeLanguage classes.""" + return get_code_executor_imports() + + @property + def javascript_imports(self): + """Property to get JavaScript-specific classes.""" + return get_javascript_imports() + + @property + def python3_imports(self): + """Property to get Python3-specific classes.""" + return get_python3_imports() + + @property + def jinja2_imports(self): + """Property to get Jinja2-specific classes.""" + return get_jinja2_imports() diff --git a/api/tests/unit_tests/configs/test_dify_config.py b/api/tests/unit_tests/configs/test_dify_config.py index e9d4ee1935..0ae6a09f5b 100644 --- a/api/tests/unit_tests/configs/test_dify_config.py +++ b/api/tests/unit_tests/configs/test_dify_config.py @@ -1,5 +1,6 @@ import os +import pytest from flask import Flask from packaging.version import Version from yarl import URL @@ -137,3 +138,61 @@ def test_db_extras_options_merging(monkeypatch): options = engine_options["connect_args"]["options"] assert "search_path=myschema" in options assert "timezone=UTC" in options + + +@pytest.mark.parametrize( + ("broker_url", "expected_host", "expected_port", "expected_username", "expected_password", "expected_db"), + [ + ("redis://localhost:6379/1", "localhost", 6379, None, None, "1"), + ("redis://:password@localhost:6379/1", "localhost", 6379, None, "password", "1"), + ("redis://:mypass%23123@localhost:6379/1", "localhost", 6379, None, "mypass#123", "1"), + ("redis://user:pass%40word@redis-host:6380/2", "redis-host", 6380, "user", "pass@word", "2"), + ("redis://admin:complex%23pass%40word@127.0.0.1:6379/0", "127.0.0.1", 6379, "admin", "complex#pass@word", "0"), + ( + "redis://user%40domain:secret%23123@redis.example.com:6380/3", + "redis.example.com", + 6380, + "user@domain", + "secret#123", + "3", + ), + # Password containing %23 substring (double encoding scenario) + ("redis://:mypass%2523@localhost:6379/1", "localhost", 6379, None, "mypass%23", "1"), + # Username and password both containing encoded characters + ("redis://user%2525%40:pass%2523@localhost:6379/1", "localhost", 6379, "user%25@", "pass%23", "1"), + ], +) +def test_celery_broker_url_with_special_chars_password( + monkeypatch, broker_url, expected_host, expected_port, expected_username, expected_password, expected_db +): + """Test that CELERY_BROKER_URL with various formats are handled correctly.""" + from kombu.utils.url import parse_url + + # clear system environment variables + os.environ.clear() + + # Set up basic required environment variables (following existing pattern) + monkeypatch.setenv("CONSOLE_API_URL", "https://example.com") + monkeypatch.setenv("CONSOLE_WEB_URL", "https://example.com") + monkeypatch.setenv("DB_USERNAME", "postgres") + monkeypatch.setenv("DB_PASSWORD", "postgres") + monkeypatch.setenv("DB_HOST", "localhost") + monkeypatch.setenv("DB_PORT", "5432") + monkeypatch.setenv("DB_DATABASE", "dify") + + # Set the CELERY_BROKER_URL to test + monkeypatch.setenv("CELERY_BROKER_URL", broker_url) + + # Create config and verify the URL is stored correctly + config = DifyConfig() + assert broker_url == config.CELERY_BROKER_URL + + # Test actual parsing behavior using kombu's parse_url (same as production) + redis_config = parse_url(config.CELERY_BROKER_URL) + + # Verify the parsing results match expectations (using kombu's field names) + assert redis_config["hostname"] == expected_host + assert redis_config["port"] == expected_port + assert redis_config["userid"] == expected_username # kombu uses 'userid' not 'username' + assert redis_config["password"] == expected_password + assert redis_config["virtual_host"] == expected_db # kombu uses 'virtual_host' not 'db' diff --git a/api/tests/unit_tests/controllers/console/app/test_description_validation.py b/api/tests/unit_tests/controllers/console/app/test_description_validation.py new file mode 100644 index 0000000000..178267e560 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/app/test_description_validation.py @@ -0,0 +1,252 @@ +import pytest + +from controllers.console.app.app import _validate_description_length as app_validate +from controllers.console.datasets.datasets import _validate_description_length as dataset_validate +from controllers.service_api.dataset.dataset import _validate_description_length as service_dataset_validate + + +class TestDescriptionValidationUnit: + """Unit tests for description validation functions in App and Dataset APIs""" + + def test_app_validate_description_length_valid(self): + """Test App validation function with valid descriptions""" + # Empty string should be valid + assert app_validate("") == "" + + # None should be valid + assert app_validate(None) is None + + # Short description should be valid + short_desc = "Short description" + assert app_validate(short_desc) == short_desc + + # Exactly 400 characters should be valid + exactly_400 = "x" * 400 + assert app_validate(exactly_400) == exactly_400 + + # Just under limit should be valid + just_under = "x" * 399 + assert app_validate(just_under) == just_under + + def test_app_validate_description_length_invalid(self): + """Test App validation function with invalid descriptions""" + # 401 characters should fail + just_over = "x" * 401 + with pytest.raises(ValueError) as exc_info: + app_validate(just_over) + assert "Description cannot exceed 400 characters." in str(exc_info.value) + + # 500 characters should fail + way_over = "x" * 500 + with pytest.raises(ValueError) as exc_info: + app_validate(way_over) + assert "Description cannot exceed 400 characters." in str(exc_info.value) + + # 1000 characters should fail + very_long = "x" * 1000 + with pytest.raises(ValueError) as exc_info: + app_validate(very_long) + assert "Description cannot exceed 400 characters." in str(exc_info.value) + + def test_dataset_validate_description_length_valid(self): + """Test Dataset validation function with valid descriptions""" + # Empty string should be valid + assert dataset_validate("") == "" + + # Short description should be valid + short_desc = "Short description" + assert dataset_validate(short_desc) == short_desc + + # Exactly 400 characters should be valid + exactly_400 = "x" * 400 + assert dataset_validate(exactly_400) == exactly_400 + + # Just under limit should be valid + just_under = "x" * 399 + assert dataset_validate(just_under) == just_under + + def test_dataset_validate_description_length_invalid(self): + """Test Dataset validation function with invalid descriptions""" + # 401 characters should fail + just_over = "x" * 401 + with pytest.raises(ValueError) as exc_info: + dataset_validate(just_over) + assert "Description cannot exceed 400 characters." in str(exc_info.value) + + # 500 characters should fail + way_over = "x" * 500 + with pytest.raises(ValueError) as exc_info: + dataset_validate(way_over) + assert "Description cannot exceed 400 characters." in str(exc_info.value) + + def test_service_dataset_validate_description_length_valid(self): + """Test Service Dataset validation function with valid descriptions""" + # Empty string should be valid + assert service_dataset_validate("") == "" + + # None should be valid + assert service_dataset_validate(None) is None + + # Short description should be valid + short_desc = "Short description" + assert service_dataset_validate(short_desc) == short_desc + + # Exactly 400 characters should be valid + exactly_400 = "x" * 400 + assert service_dataset_validate(exactly_400) == exactly_400 + + # Just under limit should be valid + just_under = "x" * 399 + assert service_dataset_validate(just_under) == just_under + + def test_service_dataset_validate_description_length_invalid(self): + """Test Service Dataset validation function with invalid descriptions""" + # 401 characters should fail + just_over = "x" * 401 + with pytest.raises(ValueError) as exc_info: + service_dataset_validate(just_over) + assert "Description cannot exceed 400 characters." in str(exc_info.value) + + # 500 characters should fail + way_over = "x" * 500 + with pytest.raises(ValueError) as exc_info: + service_dataset_validate(way_over) + assert "Description cannot exceed 400 characters." in str(exc_info.value) + + def test_app_dataset_validation_consistency(self): + """Test that App and Dataset validation functions behave identically""" + test_cases = [ + "", # Empty string + "Short description", # Normal description + "x" * 100, # Medium description + "x" * 400, # Exactly at limit + ] + + # Test valid cases produce same results + for test_desc in test_cases: + assert app_validate(test_desc) == dataset_validate(test_desc) == service_dataset_validate(test_desc) + + # Test invalid cases produce same errors + invalid_cases = [ + "x" * 401, # Just over limit + "x" * 500, # Way over limit + "x" * 1000, # Very long + ] + + for invalid_desc in invalid_cases: + app_error = None + dataset_error = None + service_dataset_error = None + + # Capture App validation error + try: + app_validate(invalid_desc) + except ValueError as e: + app_error = str(e) + + # Capture Dataset validation error + try: + dataset_validate(invalid_desc) + except ValueError as e: + dataset_error = str(e) + + # Capture Service Dataset validation error + try: + service_dataset_validate(invalid_desc) + except ValueError as e: + service_dataset_error = str(e) + + # All should produce errors + assert app_error is not None, f"App validation should fail for {len(invalid_desc)} characters" + assert dataset_error is not None, f"Dataset validation should fail for {len(invalid_desc)} characters" + error_msg = f"Service Dataset validation should fail for {len(invalid_desc)} characters" + assert service_dataset_error is not None, error_msg + + # Errors should be identical + error_msg = f"Error messages should be identical for {len(invalid_desc)} characters" + assert app_error == dataset_error == service_dataset_error, error_msg + assert app_error == "Description cannot exceed 400 characters." + + def test_boundary_values(self): + """Test boundary values around the 400 character limit""" + boundary_tests = [ + (0, True), # Empty + (1, True), # Minimum + (399, True), # Just under limit + (400, True), # Exactly at limit + (401, False), # Just over limit + (402, False), # Over limit + (500, False), # Way over limit + ] + + for length, should_pass in boundary_tests: + test_desc = "x" * length + + if should_pass: + # Should not raise exception + assert app_validate(test_desc) == test_desc + assert dataset_validate(test_desc) == test_desc + assert service_dataset_validate(test_desc) == test_desc + else: + # Should raise ValueError + with pytest.raises(ValueError): + app_validate(test_desc) + with pytest.raises(ValueError): + dataset_validate(test_desc) + with pytest.raises(ValueError): + service_dataset_validate(test_desc) + + def test_special_characters(self): + """Test validation with special characters, Unicode, etc.""" + # Unicode characters + unicode_desc = "测试描述" * 100 # Chinese characters + if len(unicode_desc) <= 400: + assert app_validate(unicode_desc) == unicode_desc + assert dataset_validate(unicode_desc) == unicode_desc + assert service_dataset_validate(unicode_desc) == unicode_desc + + # Special characters + special_desc = "Special chars: !@#$%^&*()_+-=[]{}|;':\",./<>?" * 10 + if len(special_desc) <= 400: + assert app_validate(special_desc) == special_desc + assert dataset_validate(special_desc) == special_desc + assert service_dataset_validate(special_desc) == special_desc + + # Mixed content + mixed_desc = "Mixed content: 测试 123 !@# " * 15 + if len(mixed_desc) <= 400: + assert app_validate(mixed_desc) == mixed_desc + assert dataset_validate(mixed_desc) == mixed_desc + assert service_dataset_validate(mixed_desc) == mixed_desc + elif len(mixed_desc) > 400: + with pytest.raises(ValueError): + app_validate(mixed_desc) + with pytest.raises(ValueError): + dataset_validate(mixed_desc) + with pytest.raises(ValueError): + service_dataset_validate(mixed_desc) + + def test_whitespace_handling(self): + """Test validation with various whitespace scenarios""" + # Leading/trailing whitespace + whitespace_desc = " Description with whitespace " + if len(whitespace_desc) <= 400: + assert app_validate(whitespace_desc) == whitespace_desc + assert dataset_validate(whitespace_desc) == whitespace_desc + assert service_dataset_validate(whitespace_desc) == whitespace_desc + + # Newlines and tabs + multiline_desc = "Line 1\nLine 2\tTabbed content" + if len(multiline_desc) <= 400: + assert app_validate(multiline_desc) == multiline_desc + assert dataset_validate(multiline_desc) == multiline_desc + assert service_dataset_validate(multiline_desc) == multiline_desc + + # Only whitespace over limit + only_spaces = " " * 401 + with pytest.raises(ValueError): + app_validate(only_spaces) + with pytest.raises(ValueError): + dataset_validate(only_spaces) + with pytest.raises(ValueError): + service_dataset_validate(only_spaces) diff --git a/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py b/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py index f26be6702a..ac3c8e45c9 100644 --- a/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py +++ b/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py @@ -1,9 +1,8 @@ -import datetime import uuid from collections import OrderedDict from typing import Any, NamedTuple -from flask_restful import marshal +from flask_restx import marshal from controllers.console.app.workflow_draft_variable import ( _WORKFLOW_DRAFT_VARIABLE_FIELDS, @@ -13,6 +12,7 @@ from controllers.console.app.workflow_draft_variable import ( ) from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID from factories.variable_factory import build_segment +from libs.datetime_utils import naive_utc_now from models.workflow import WorkflowDraftVariable from services.workflow_draft_variable_service import WorkflowDraftVariableList @@ -57,7 +57,7 @@ class TestWorkflowDraftVariableFields: ) sys_var.id = str(uuid.uuid4()) - sys_var.last_edited_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + sys_var.last_edited_at = naive_utc_now() sys_var.visible = True expected_without_value = OrderedDict( @@ -88,7 +88,7 @@ class TestWorkflowDraftVariableFields: ) node_var.id = str(uuid.uuid4()) - node_var.last_edited_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + node_var.last_edited_at = naive_utc_now() expected_without_value: OrderedDict[str, Any] = OrderedDict( { diff --git a/api/tests/unit_tests/controllers/console/test_files_security.py b/api/tests/unit_tests/controllers/console/test_files_security.py new file mode 100644 index 0000000000..2630fbcfd0 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/test_files_security.py @@ -0,0 +1,278 @@ +import io +from unittest.mock import patch + +import pytest +from werkzeug.exceptions import Forbidden + +from controllers.common.errors import ( + FilenameNotExistsError, + FileTooLargeError, + NoFileUploadedError, + TooManyFilesError, + UnsupportedFileTypeError, +) +from services.errors.file import FileTooLargeError as ServiceFileTooLargeError +from services.errors.file import UnsupportedFileTypeError as ServiceUnsupportedFileTypeError + + +class TestFileUploadSecurity: + """Test file upload security logic without complex framework setup""" + + # Test 1: Basic file validation + def test_should_validate_file_presence(self): + """Test that missing file is detected""" + from flask import Flask, request + + app = Flask(__name__) + + with app.test_request_context(method="POST", data={}): + # Simulate the check in FileApi.post() + if "file" not in request.files: + with pytest.raises(NoFileUploadedError): + raise NoFileUploadedError() + + def test_should_validate_multiple_files(self): + """Test that multiple files are rejected""" + from flask import Flask, request + + app = Flask(__name__) + + file_data = { + "file": (io.BytesIO(b"content1"), "file1.txt", "text/plain"), + "file2": (io.BytesIO(b"content2"), "file2.txt", "text/plain"), + } + + with app.test_request_context(method="POST", data=file_data, content_type="multipart/form-data"): + # Simulate the check in FileApi.post() + if len(request.files) > 1: + with pytest.raises(TooManyFilesError): + raise TooManyFilesError() + + def test_should_validate_empty_filename(self): + """Test that empty filename is rejected""" + from flask import Flask, request + + app = Flask(__name__) + + file_data = {"file": (io.BytesIO(b"content"), "", "text/plain")} + + with app.test_request_context(method="POST", data=file_data, content_type="multipart/form-data"): + file = request.files["file"] + if not file.filename: + with pytest.raises(FilenameNotExistsError): + raise FilenameNotExistsError + + # Test 2: Security - Filename sanitization + def test_should_detect_path_traversal_in_filename(self): + """Test protection against directory traversal attacks""" + dangerous_filenames = [ + "../../../etc/passwd", + "..\\..\\windows\\system32\\config\\sam", + "../../../../etc/shadow", + "./../../../sensitive.txt", + ] + + for filename in dangerous_filenames: + # Any filename containing .. should be considered dangerous + assert ".." in filename, f"Filename {filename} should be detected as path traversal" + + def test_should_detect_null_byte_injection(self): + """Test protection against null byte injection""" + dangerous_filenames = [ + "file.jpg\x00.php", + "document.pdf\x00.exe", + "image.png\x00.sh", + ] + + for filename in dangerous_filenames: + # Null bytes should be detected + assert "\x00" in filename, f"Filename {filename} should be detected as null byte injection" + + def test_should_sanitize_special_characters(self): + """Test that special characters in filenames are handled safely""" + # Characters that could be problematic in various contexts + dangerous_chars = ["/", "\\", ":", "*", "?", '"', "<", ">", "|", "\x00"] + + for char in dangerous_chars: + filename = f"file{char}name.txt" + # These characters should be detected or sanitized + assert any(c in filename for c in dangerous_chars) + + # Test 3: Permission validation + def test_should_validate_dataset_permissions(self): + """Test dataset upload permission logic""" + + class MockUser: + is_dataset_editor = False + + user = MockUser() + source = "datasets" + + # Simulate the permission check in FileApi.post() + if source == "datasets" and not user.is_dataset_editor: + with pytest.raises(Forbidden): + raise Forbidden() + + def test_should_allow_general_upload_without_permission(self): + """Test general upload doesn't require dataset permission""" + + class MockUser: + is_dataset_editor = False + + user = MockUser() + source = None # General upload + + # This should not raise an exception + if source == "datasets" and not user.is_dataset_editor: + raise Forbidden() + # Test passes if no exception is raised + + # Test 4: Service error handling + @patch("services.file_service.FileService.upload_file") + def test_should_handle_file_too_large_error(self, mock_upload): + """Test that service FileTooLargeError is properly converted""" + mock_upload.side_effect = ServiceFileTooLargeError("File too large") + + try: + mock_upload(filename="test.txt", content=b"data", mimetype="text/plain", user=None, source=None) + except ServiceFileTooLargeError as e: + # Simulate the error conversion in FileApi.post() + with pytest.raises(FileTooLargeError): + raise FileTooLargeError(e.description) + + @patch("services.file_service.FileService.upload_file") + def test_should_handle_unsupported_file_type_error(self, mock_upload): + """Test that service UnsupportedFileTypeError is properly converted""" + mock_upload.side_effect = ServiceUnsupportedFileTypeError() + + try: + mock_upload( + filename="test.exe", content=b"data", mimetype="application/octet-stream", user=None, source=None + ) + except ServiceUnsupportedFileTypeError: + # Simulate the error conversion in FileApi.post() + with pytest.raises(UnsupportedFileTypeError): + raise UnsupportedFileTypeError() + + # Test 5: File type security + def test_should_identify_dangerous_file_extensions(self): + """Test detection of potentially dangerous file extensions""" + dangerous_extensions = [ + ".php", + ".PHP", + ".pHp", # PHP files (case variations) + ".exe", + ".EXE", # Executables + ".sh", + ".SH", # Shell scripts + ".bat", + ".BAT", # Batch files + ".cmd", + ".CMD", # Command files + ".ps1", + ".PS1", # PowerShell + ".jar", + ".JAR", # Java archives + ".vbs", + ".VBS", # VBScript + ] + + safe_extensions = [".txt", ".pdf", ".jpg", ".png", ".doc", ".docx"] + + # Just verify our test data is correct + for ext in dangerous_extensions: + assert ext.lower() in [".php", ".exe", ".sh", ".bat", ".cmd", ".ps1", ".jar", ".vbs"] + + for ext in safe_extensions: + assert ext.lower() not in [".php", ".exe", ".sh", ".bat", ".cmd", ".ps1", ".jar", ".vbs"] + + def test_should_detect_double_extensions(self): + """Test detection of double extension attacks""" + suspicious_filenames = [ + "image.jpg.php", + "document.pdf.exe", + "photo.png.sh", + "file.txt.bat", + ] + + for filename in suspicious_filenames: + # Check that these have multiple extensions + parts = filename.split(".") + assert len(parts) > 2, f"Filename {filename} should have multiple extensions" + + # Test 6: Configuration validation + def test_upload_configuration_structure(self): + """Test that upload configuration has correct structure""" + # Simulate the configuration returned by FileApi.get() + config = { + "file_size_limit": 15, + "batch_count_limit": 5, + "image_file_size_limit": 10, + "video_file_size_limit": 500, + "audio_file_size_limit": 50, + "workflow_file_upload_limit": 10, + } + + # Verify all required fields are present + required_fields = [ + "file_size_limit", + "batch_count_limit", + "image_file_size_limit", + "video_file_size_limit", + "audio_file_size_limit", + "workflow_file_upload_limit", + ] + + for field in required_fields: + assert field in config, f"Missing required field: {field}" + assert isinstance(config[field], int), f"Field {field} should be an integer" + assert config[field] > 0, f"Field {field} should be positive" + + # Test 7: Source parameter handling + def test_source_parameter_normalization(self): + """Test that source parameter is properly normalized""" + test_cases = [ + ("datasets", "datasets"), + ("other", None), + ("", None), + (None, None), + ] + + for input_source, expected in test_cases: + # Simulate the source normalization in FileApi.post() + source = "datasets" if input_source == "datasets" else None + if source not in ("datasets", None): + source = None + assert source == expected + + # Test 8: Boundary conditions + def test_should_handle_edge_case_file_sizes(self): + """Test handling of boundary file sizes""" + test_cases = [ + (0, "Empty file"), # 0 bytes + (1, "Single byte"), # 1 byte + (15 * 1024 * 1024 - 1, "Just under limit"), # Just under 15MB + (15 * 1024 * 1024, "At limit"), # Exactly 15MB + (15 * 1024 * 1024 + 1, "Just over limit"), # Just over 15MB + ] + + for size, description in test_cases: + # Just verify our test data + assert isinstance(size, int), f"{description}: Size should be integer" + assert size >= 0, f"{description}: Size should be non-negative" + + def test_should_handle_special_mime_types(self): + """Test handling of various MIME types""" + mime_type_tests = [ + ("application/octet-stream", "Generic binary"), + ("text/plain", "Plain text"), + ("image/jpeg", "JPEG image"), + ("application/pdf", "PDF document"), + ("", "Empty MIME type"), + (None, "None MIME type"), + ] + + for mime_type, description in mime_type_tests: + # Verify test data structure + if mime_type is not None: + assert isinstance(mime_type, str), f"{description}: MIME type should be string or None" diff --git a/api/tests/unit_tests/controllers/service_api/app/test_file_preview.py b/api/tests/unit_tests/controllers/service_api/app/test_file_preview.py new file mode 100644 index 0000000000..5c484403a6 --- /dev/null +++ b/api/tests/unit_tests/controllers/service_api/app/test_file_preview.py @@ -0,0 +1,336 @@ +""" +Unit tests for Service API File Preview endpoint +""" + +import uuid +from unittest.mock import Mock, patch + +import pytest + +from controllers.service_api.app.error import FileAccessDeniedError, FileNotFoundError +from controllers.service_api.app.file_preview import FilePreviewApi +from models.model import App, EndUser, Message, MessageFile, UploadFile + + +class TestFilePreviewApi: + """Test suite for FilePreviewApi""" + + @pytest.fixture + def file_preview_api(self): + """Create FilePreviewApi instance for testing""" + return FilePreviewApi() + + @pytest.fixture + def mock_app(self): + """Mock App model""" + app = Mock(spec=App) + app.id = str(uuid.uuid4()) + app.tenant_id = str(uuid.uuid4()) + return app + + @pytest.fixture + def mock_end_user(self): + """Mock EndUser model""" + end_user = Mock(spec=EndUser) + end_user.id = str(uuid.uuid4()) + return end_user + + @pytest.fixture + def mock_upload_file(self): + """Mock UploadFile model""" + upload_file = Mock(spec=UploadFile) + upload_file.id = str(uuid.uuid4()) + upload_file.name = "test_file.jpg" + upload_file.mime_type = "image/jpeg" + upload_file.size = 1024 + upload_file.key = "storage/key/test_file.jpg" + upload_file.tenant_id = str(uuid.uuid4()) + return upload_file + + @pytest.fixture + def mock_message_file(self): + """Mock MessageFile model""" + message_file = Mock(spec=MessageFile) + message_file.id = str(uuid.uuid4()) + message_file.upload_file_id = str(uuid.uuid4()) + message_file.message_id = str(uuid.uuid4()) + return message_file + + @pytest.fixture + def mock_message(self): + """Mock Message model""" + message = Mock(spec=Message) + message.id = str(uuid.uuid4()) + message.app_id = str(uuid.uuid4()) + return message + + def test_validate_file_ownership_success( + self, file_preview_api, mock_app, mock_upload_file, mock_message_file, mock_message + ): + """Test successful file ownership validation""" + file_id = str(uuid.uuid4()) + app_id = mock_app.id + + # Set up the mocks + mock_upload_file.tenant_id = mock_app.tenant_id + mock_message.app_id = app_id + mock_message_file.upload_file_id = file_id + mock_message_file.message_id = mock_message.id + + with patch("controllers.service_api.app.file_preview.db") as mock_db: + # Mock database queries + mock_db.session.query.return_value.where.return_value.first.side_effect = [ + mock_message_file, # MessageFile query + mock_message, # Message query + mock_upload_file, # UploadFile query + mock_app, # App query for tenant validation + ] + + # Execute the method + result_message_file, result_upload_file = file_preview_api._validate_file_ownership(file_id, app_id) + + # Assertions + assert result_message_file == mock_message_file + assert result_upload_file == mock_upload_file + + def test_validate_file_ownership_file_not_found(self, file_preview_api): + """Test file ownership validation when MessageFile not found""" + file_id = str(uuid.uuid4()) + app_id = str(uuid.uuid4()) + + with patch("controllers.service_api.app.file_preview.db") as mock_db: + # Mock MessageFile not found + mock_db.session.query.return_value.where.return_value.first.return_value = None + + # Execute and assert exception + with pytest.raises(FileNotFoundError) as exc_info: + file_preview_api._validate_file_ownership(file_id, app_id) + + assert "File not found in message context" in str(exc_info.value) + + def test_validate_file_ownership_access_denied(self, file_preview_api, mock_message_file): + """Test file ownership validation when Message not owned by app""" + file_id = str(uuid.uuid4()) + app_id = str(uuid.uuid4()) + + with patch("controllers.service_api.app.file_preview.db") as mock_db: + # Mock MessageFile found but Message not owned by app + mock_db.session.query.return_value.where.return_value.first.side_effect = [ + mock_message_file, # MessageFile query - found + None, # Message query - not found (access denied) + ] + + # Execute and assert exception + with pytest.raises(FileAccessDeniedError) as exc_info: + file_preview_api._validate_file_ownership(file_id, app_id) + + assert "not owned by requesting app" in str(exc_info.value) + + def test_validate_file_ownership_upload_file_not_found(self, file_preview_api, mock_message_file, mock_message): + """Test file ownership validation when UploadFile not found""" + file_id = str(uuid.uuid4()) + app_id = str(uuid.uuid4()) + + with patch("controllers.service_api.app.file_preview.db") as mock_db: + # Mock MessageFile and Message found but UploadFile not found + mock_db.session.query.return_value.where.return_value.first.side_effect = [ + mock_message_file, # MessageFile query - found + mock_message, # Message query - found + None, # UploadFile query - not found + ] + + # Execute and assert exception + with pytest.raises(FileNotFoundError) as exc_info: + file_preview_api._validate_file_ownership(file_id, app_id) + + assert "Upload file record not found" in str(exc_info.value) + + def test_validate_file_ownership_tenant_mismatch( + self, file_preview_api, mock_app, mock_upload_file, mock_message_file, mock_message + ): + """Test file ownership validation with tenant mismatch""" + file_id = str(uuid.uuid4()) + app_id = mock_app.id + + # Set up tenant mismatch + mock_upload_file.tenant_id = "different_tenant_id" + mock_app.tenant_id = "app_tenant_id" + mock_message.app_id = app_id + mock_message_file.upload_file_id = file_id + mock_message_file.message_id = mock_message.id + + with patch("controllers.service_api.app.file_preview.db") as mock_db: + # Mock database queries + mock_db.session.query.return_value.where.return_value.first.side_effect = [ + mock_message_file, # MessageFile query + mock_message, # Message query + mock_upload_file, # UploadFile query + mock_app, # App query for tenant validation + ] + + # Execute and assert exception + with pytest.raises(FileAccessDeniedError) as exc_info: + file_preview_api._validate_file_ownership(file_id, app_id) + + assert "tenant mismatch" in str(exc_info.value) + + def test_validate_file_ownership_invalid_input(self, file_preview_api): + """Test file ownership validation with invalid input""" + + # Test with empty file_id + with pytest.raises(FileAccessDeniedError) as exc_info: + file_preview_api._validate_file_ownership("", "app_id") + assert "Invalid file or app identifier" in str(exc_info.value) + + # Test with empty app_id + with pytest.raises(FileAccessDeniedError) as exc_info: + file_preview_api._validate_file_ownership("file_id", "") + assert "Invalid file or app identifier" in str(exc_info.value) + + def test_build_file_response_basic(self, file_preview_api, mock_upload_file): + """Test basic file response building""" + mock_generator = Mock() + + response = file_preview_api._build_file_response(mock_generator, mock_upload_file, False) + + # Check response properties + assert response.mimetype == mock_upload_file.mime_type + assert response.direct_passthrough is True + assert response.headers["Content-Length"] == str(mock_upload_file.size) + assert "Cache-Control" in response.headers + + def test_build_file_response_as_attachment(self, file_preview_api, mock_upload_file): + """Test file response building with attachment flag""" + mock_generator = Mock() + + response = file_preview_api._build_file_response(mock_generator, mock_upload_file, True) + + # Check attachment-specific headers + assert "attachment" in response.headers["Content-Disposition"] + assert mock_upload_file.name in response.headers["Content-Disposition"] + assert response.headers["Content-Type"] == "application/octet-stream" + + def test_build_file_response_audio_video(self, file_preview_api, mock_upload_file): + """Test file response building for audio/video files""" + mock_generator = Mock() + mock_upload_file.mime_type = "video/mp4" + + response = file_preview_api._build_file_response(mock_generator, mock_upload_file, False) + + # Check Range support for media files + assert response.headers["Accept-Ranges"] == "bytes" + + def test_build_file_response_no_size(self, file_preview_api, mock_upload_file): + """Test file response building when size is unknown""" + mock_generator = Mock() + mock_upload_file.size = 0 # Unknown size + + response = file_preview_api._build_file_response(mock_generator, mock_upload_file, False) + + # Content-Length should not be set when size is unknown + assert "Content-Length" not in response.headers + + @patch("controllers.service_api.app.file_preview.storage") + def test_get_method_integration( + self, mock_storage, file_preview_api, mock_app, mock_end_user, mock_upload_file, mock_message_file, mock_message + ): + """Test the full GET method integration (without decorator)""" + file_id = str(uuid.uuid4()) + app_id = mock_app.id + + # Set up mocks + mock_upload_file.tenant_id = mock_app.tenant_id + mock_message.app_id = app_id + mock_message_file.upload_file_id = file_id + mock_message_file.message_id = mock_message.id + + mock_generator = Mock() + mock_storage.load.return_value = mock_generator + + with patch("controllers.service_api.app.file_preview.db") as mock_db: + # Mock database queries + mock_db.session.query.return_value.where.return_value.first.side_effect = [ + mock_message_file, # MessageFile query + mock_message, # Message query + mock_upload_file, # UploadFile query + mock_app, # App query for tenant validation + ] + + with patch("controllers.service_api.app.file_preview.reqparse") as mock_reqparse: + # Mock request parsing + mock_parser = Mock() + mock_parser.parse_args.return_value = {"as_attachment": False} + mock_reqparse.RequestParser.return_value = mock_parser + + # Test the core logic directly without Flask decorators + # Validate file ownership + result_message_file, result_upload_file = file_preview_api._validate_file_ownership(file_id, app_id) + assert result_message_file == mock_message_file + assert result_upload_file == mock_upload_file + + # Test file response building + response = file_preview_api._build_file_response(mock_generator, mock_upload_file, False) + assert response is not None + + # Verify storage was called correctly + mock_storage.load.assert_not_called() # Since we're testing components separately + + @patch("controllers.service_api.app.file_preview.storage") + def test_storage_error_handling( + self, mock_storage, file_preview_api, mock_app, mock_upload_file, mock_message_file, mock_message + ): + """Test storage error handling in the core logic""" + file_id = str(uuid.uuid4()) + app_id = mock_app.id + + # Set up mocks + mock_upload_file.tenant_id = mock_app.tenant_id + mock_message.app_id = app_id + mock_message_file.upload_file_id = file_id + mock_message_file.message_id = mock_message.id + + # Mock storage error + mock_storage.load.side_effect = Exception("Storage error") + + with patch("controllers.service_api.app.file_preview.db") as mock_db: + # Mock database queries for validation + mock_db.session.query.return_value.where.return_value.first.side_effect = [ + mock_message_file, # MessageFile query + mock_message, # Message query + mock_upload_file, # UploadFile query + mock_app, # App query for tenant validation + ] + + # First validate file ownership works + result_message_file, result_upload_file = file_preview_api._validate_file_ownership(file_id, app_id) + assert result_message_file == mock_message_file + assert result_upload_file == mock_upload_file + + # Test storage error handling + with pytest.raises(Exception) as exc_info: + mock_storage.load(mock_upload_file.key, stream=True) + + assert "Storage error" in str(exc_info.value) + + @patch("controllers.service_api.app.file_preview.logger") + def test_validate_file_ownership_unexpected_error_logging(self, mock_logger, file_preview_api): + """Test that unexpected errors are logged properly""" + file_id = str(uuid.uuid4()) + app_id = str(uuid.uuid4()) + + with patch("controllers.service_api.app.file_preview.db") as mock_db: + # Mock database query to raise unexpected exception + mock_db.session.query.side_effect = Exception("Unexpected database error") + + # Execute and assert exception + with pytest.raises(FileAccessDeniedError) as exc_info: + file_preview_api._validate_file_ownership(file_id, app_id) + + # Verify error message + assert "File access validation failed" in str(exc_info.value) + + # Verify logging was called + mock_logger.exception.assert_called_once_with( + "Unexpected error during file ownership validation", + extra={"file_id": file_id, "app_id": app_id, "error": "Unexpected database error"}, + ) diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py new file mode 100644 index 0000000000..da175e7ccd --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py @@ -0,0 +1,419 @@ +"""Test conversation variable handling in AdvancedChatAppRunner.""" + +from unittest.mock import MagicMock, patch +from uuid import uuid4 + +from sqlalchemy.orm import Session + +from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner +from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom +from core.variables import SegmentType +from factories import variable_factory +from models import ConversationVariable, Workflow + + +class TestAdvancedChatAppRunnerConversationVariables: + """Test that AdvancedChatAppRunner correctly handles conversation variables.""" + + def test_missing_conversation_variables_are_added(self): + """Test that new conversation variables added to workflow are created for existing conversations.""" + # Setup + app_id = str(uuid4()) + conversation_id = str(uuid4()) + workflow_id = str(uuid4()) + + # Create workflow with two conversation variables + workflow_vars = [ + variable_factory.build_conversation_variable_from_mapping( + { + "id": "var1", + "name": "existing_var", + "value_type": SegmentType.STRING, + "value": "default1", + } + ), + variable_factory.build_conversation_variable_from_mapping( + { + "id": "var2", + "name": "new_var", + "value_type": SegmentType.STRING, + "value": "default2", + } + ), + ] + + # Mock workflow with conversation variables + mock_workflow = MagicMock(spec=Workflow) + mock_workflow.conversation_variables = workflow_vars + mock_workflow.tenant_id = str(uuid4()) + mock_workflow.app_id = app_id + mock_workflow.id = workflow_id + mock_workflow.type = "chat" + mock_workflow.graph_dict = {} + mock_workflow.environment_variables = [] + + # Create existing conversation variable (only var1 exists in DB) + existing_db_var = MagicMock(spec=ConversationVariable) + existing_db_var.id = "var1" + existing_db_var.app_id = app_id + existing_db_var.conversation_id = conversation_id + existing_db_var.to_variable = MagicMock(return_value=workflow_vars[0]) + + # Mock conversation and message + mock_conversation = MagicMock() + mock_conversation.app_id = app_id + mock_conversation.id = conversation_id + + mock_message = MagicMock() + mock_message.id = str(uuid4()) + + # Mock app config + mock_app_config = MagicMock() + mock_app_config.app_id = app_id + mock_app_config.workflow_id = workflow_id + mock_app_config.tenant_id = str(uuid4()) + + # Mock app generate entity + mock_app_generate_entity = MagicMock(spec=AdvancedChatAppGenerateEntity) + mock_app_generate_entity.app_config = mock_app_config + mock_app_generate_entity.inputs = {} + mock_app_generate_entity.query = "test query" + mock_app_generate_entity.files = [] + mock_app_generate_entity.user_id = str(uuid4()) + mock_app_generate_entity.invoke_from = InvokeFrom.SERVICE_API + mock_app_generate_entity.workflow_run_id = str(uuid4()) + mock_app_generate_entity.call_depth = 0 + mock_app_generate_entity.single_iteration_run = None + mock_app_generate_entity.single_loop_run = None + mock_app_generate_entity.trace_manager = None + + # Create runner + runner = AdvancedChatAppRunner( + application_generate_entity=mock_app_generate_entity, + queue_manager=MagicMock(), + conversation=mock_conversation, + message=mock_message, + dialogue_count=1, + variable_loader=MagicMock(), + workflow=mock_workflow, + system_user_id=str(uuid4()), + app=MagicMock(), + ) + + # Mock database session + mock_session = MagicMock(spec=Session) + + # First query returns only existing variable + mock_scalars_result = MagicMock() + mock_scalars_result.all.return_value = [existing_db_var] + mock_session.scalars.return_value = mock_scalars_result + + # Track what gets added to session + added_items = [] + + def track_add_all(items): + added_items.extend(items) + + mock_session.add_all.side_effect = track_add_all + + # Patch the necessary components + with ( + patch("core.app.apps.advanced_chat.app_runner.Session") as mock_session_class, + patch("core.app.apps.advanced_chat.app_runner.select") as mock_select, + patch("core.app.apps.advanced_chat.app_runner.db") as mock_db, + patch.object(runner, "_init_graph") as mock_init_graph, + patch.object(runner, "handle_input_moderation", return_value=False), + patch.object(runner, "handle_annotation_reply", return_value=False), + patch("core.app.apps.advanced_chat.app_runner.WorkflowEntry") as mock_workflow_entry_class, + patch("core.app.apps.advanced_chat.app_runner.VariablePool") as mock_variable_pool_class, + ): + # Setup mocks + mock_session_class.return_value.__enter__.return_value = mock_session + mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock() # App exists + mock_db.engine = MagicMock() + + # Mock graph initialization + mock_init_graph.return_value = MagicMock() + + # Mock workflow entry + mock_workflow_entry = MagicMock() + mock_workflow_entry.run.return_value = iter([]) # Empty generator + mock_workflow_entry_class.return_value = mock_workflow_entry + + # Run the method + runner.run() + + # Verify that the missing variable was added + assert len(added_items) == 1, "Should have added exactly one missing variable" + + # Check that the added item is the missing variable (var2) + added_var = added_items[0] + assert hasattr(added_var, "id"), "Added item should be a ConversationVariable" + # Note: Since we're mocking ConversationVariable.from_variable, + # we can't directly check the id, but we can verify add_all was called + assert mock_session.add_all.called, "Session add_all should have been called" + assert mock_session.commit.called, "Session commit should have been called" + + def test_no_variables_creates_all(self): + """Test that all conversation variables are created when none exist in DB.""" + # Setup + app_id = str(uuid4()) + conversation_id = str(uuid4()) + workflow_id = str(uuid4()) + + # Create workflow with conversation variables + workflow_vars = [ + variable_factory.build_conversation_variable_from_mapping( + { + "id": "var1", + "name": "var1", + "value_type": SegmentType.STRING, + "value": "default1", + } + ), + variable_factory.build_conversation_variable_from_mapping( + { + "id": "var2", + "name": "var2", + "value_type": SegmentType.STRING, + "value": "default2", + } + ), + ] + + # Mock workflow + mock_workflow = MagicMock(spec=Workflow) + mock_workflow.conversation_variables = workflow_vars + mock_workflow.tenant_id = str(uuid4()) + mock_workflow.app_id = app_id + mock_workflow.id = workflow_id + mock_workflow.type = "chat" + mock_workflow.graph_dict = {} + mock_workflow.environment_variables = [] + + # Mock conversation and message + mock_conversation = MagicMock() + mock_conversation.app_id = app_id + mock_conversation.id = conversation_id + + mock_message = MagicMock() + mock_message.id = str(uuid4()) + + # Mock app config + mock_app_config = MagicMock() + mock_app_config.app_id = app_id + mock_app_config.workflow_id = workflow_id + mock_app_config.tenant_id = str(uuid4()) + + # Mock app generate entity + mock_app_generate_entity = MagicMock(spec=AdvancedChatAppGenerateEntity) + mock_app_generate_entity.app_config = mock_app_config + mock_app_generate_entity.inputs = {} + mock_app_generate_entity.query = "test query" + mock_app_generate_entity.files = [] + mock_app_generate_entity.user_id = str(uuid4()) + mock_app_generate_entity.invoke_from = InvokeFrom.SERVICE_API + mock_app_generate_entity.workflow_run_id = str(uuid4()) + mock_app_generate_entity.call_depth = 0 + mock_app_generate_entity.single_iteration_run = None + mock_app_generate_entity.single_loop_run = None + mock_app_generate_entity.trace_manager = None + + # Create runner + runner = AdvancedChatAppRunner( + application_generate_entity=mock_app_generate_entity, + queue_manager=MagicMock(), + conversation=mock_conversation, + message=mock_message, + dialogue_count=1, + variable_loader=MagicMock(), + workflow=mock_workflow, + system_user_id=str(uuid4()), + app=MagicMock(), + ) + + # Mock database session + mock_session = MagicMock(spec=Session) + + # Query returns empty list (no existing variables) + mock_scalars_result = MagicMock() + mock_scalars_result.all.return_value = [] + mock_session.scalars.return_value = mock_scalars_result + + # Track what gets added to session + added_items = [] + + def track_add_all(items): + added_items.extend(items) + + mock_session.add_all.side_effect = track_add_all + + # Patch the necessary components + with ( + patch("core.app.apps.advanced_chat.app_runner.Session") as mock_session_class, + patch("core.app.apps.advanced_chat.app_runner.select") as mock_select, + patch("core.app.apps.advanced_chat.app_runner.db") as mock_db, + patch.object(runner, "_init_graph") as mock_init_graph, + patch.object(runner, "handle_input_moderation", return_value=False), + patch.object(runner, "handle_annotation_reply", return_value=False), + patch("core.app.apps.advanced_chat.app_runner.WorkflowEntry") as mock_workflow_entry_class, + patch("core.app.apps.advanced_chat.app_runner.VariablePool") as mock_variable_pool_class, + patch("core.app.apps.advanced_chat.app_runner.ConversationVariable") as mock_conv_var_class, + ): + # Setup mocks + mock_session_class.return_value.__enter__.return_value = mock_session + mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock() # App exists + mock_db.engine = MagicMock() + + # Mock ConversationVariable.from_variable to return mock objects + mock_conv_vars = [] + for var in workflow_vars: + mock_cv = MagicMock() + mock_cv.id = var.id + mock_cv.to_variable.return_value = var + mock_conv_vars.append(mock_cv) + + mock_conv_var_class.from_variable.side_effect = mock_conv_vars + + # Mock graph initialization + mock_init_graph.return_value = MagicMock() + + # Mock workflow entry + mock_workflow_entry = MagicMock() + mock_workflow_entry.run.return_value = iter([]) # Empty generator + mock_workflow_entry_class.return_value = mock_workflow_entry + + # Run the method + runner.run() + + # Verify that all variables were created + assert len(added_items) == 2, "Should have added both variables" + assert mock_session.add_all.called, "Session add_all should have been called" + assert mock_session.commit.called, "Session commit should have been called" + + def test_all_variables_exist_no_changes(self): + """Test that no changes are made when all variables already exist in DB.""" + # Setup + app_id = str(uuid4()) + conversation_id = str(uuid4()) + workflow_id = str(uuid4()) + + # Create workflow with conversation variables + workflow_vars = [ + variable_factory.build_conversation_variable_from_mapping( + { + "id": "var1", + "name": "var1", + "value_type": SegmentType.STRING, + "value": "default1", + } + ), + variable_factory.build_conversation_variable_from_mapping( + { + "id": "var2", + "name": "var2", + "value_type": SegmentType.STRING, + "value": "default2", + } + ), + ] + + # Mock workflow + mock_workflow = MagicMock(spec=Workflow) + mock_workflow.conversation_variables = workflow_vars + mock_workflow.tenant_id = str(uuid4()) + mock_workflow.app_id = app_id + mock_workflow.id = workflow_id + mock_workflow.type = "chat" + mock_workflow.graph_dict = {} + mock_workflow.environment_variables = [] + + # Create existing conversation variables (both exist in DB) + existing_db_vars = [] + for var in workflow_vars: + db_var = MagicMock(spec=ConversationVariable) + db_var.id = var.id + db_var.app_id = app_id + db_var.conversation_id = conversation_id + db_var.to_variable = MagicMock(return_value=var) + existing_db_vars.append(db_var) + + # Mock conversation and message + mock_conversation = MagicMock() + mock_conversation.app_id = app_id + mock_conversation.id = conversation_id + + mock_message = MagicMock() + mock_message.id = str(uuid4()) + + # Mock app config + mock_app_config = MagicMock() + mock_app_config.app_id = app_id + mock_app_config.workflow_id = workflow_id + mock_app_config.tenant_id = str(uuid4()) + + # Mock app generate entity + mock_app_generate_entity = MagicMock(spec=AdvancedChatAppGenerateEntity) + mock_app_generate_entity.app_config = mock_app_config + mock_app_generate_entity.inputs = {} + mock_app_generate_entity.query = "test query" + mock_app_generate_entity.files = [] + mock_app_generate_entity.user_id = str(uuid4()) + mock_app_generate_entity.invoke_from = InvokeFrom.SERVICE_API + mock_app_generate_entity.workflow_run_id = str(uuid4()) + mock_app_generate_entity.call_depth = 0 + mock_app_generate_entity.single_iteration_run = None + mock_app_generate_entity.single_loop_run = None + mock_app_generate_entity.trace_manager = None + + # Create runner + runner = AdvancedChatAppRunner( + application_generate_entity=mock_app_generate_entity, + queue_manager=MagicMock(), + conversation=mock_conversation, + message=mock_message, + dialogue_count=1, + variable_loader=MagicMock(), + workflow=mock_workflow, + system_user_id=str(uuid4()), + app=MagicMock(), + ) + + # Mock database session + mock_session = MagicMock(spec=Session) + + # Query returns all existing variables + mock_scalars_result = MagicMock() + mock_scalars_result.all.return_value = existing_db_vars + mock_session.scalars.return_value = mock_scalars_result + + # Patch the necessary components + with ( + patch("core.app.apps.advanced_chat.app_runner.Session") as mock_session_class, + patch("core.app.apps.advanced_chat.app_runner.select") as mock_select, + patch("core.app.apps.advanced_chat.app_runner.db") as mock_db, + patch.object(runner, "_init_graph") as mock_init_graph, + patch.object(runner, "handle_input_moderation", return_value=False), + patch.object(runner, "handle_annotation_reply", return_value=False), + patch("core.app.apps.advanced_chat.app_runner.WorkflowEntry") as mock_workflow_entry_class, + patch("core.app.apps.advanced_chat.app_runner.VariablePool") as mock_variable_pool_class, + ): + # Setup mocks + mock_session_class.return_value.__enter__.return_value = mock_session + mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock() # App exists + mock_db.engine = MagicMock() + + # Mock graph initialization + mock_init_graph.return_value = MagicMock() + + # Mock workflow entry + mock_workflow_entry = MagicMock() + mock_workflow_entry.run.return_value = iter([]) # Empty generator + mock_workflow_entry_class.return_value = mock_workflow_entry + + # Run the method + runner.run() + + # Verify that no variables were added + assert not mock_session.add_all.called, "Session add_all should not have been called" + assert mock_session.commit.called, "Session commit should still be called" diff --git a/api/tests/unit_tests/core/app/features/rate_limiting/conftest.py b/api/tests/unit_tests/core/app/features/rate_limiting/conftest.py new file mode 100644 index 0000000000..9557e78150 --- /dev/null +++ b/api/tests/unit_tests/core/app/features/rate_limiting/conftest.py @@ -0,0 +1,124 @@ +import time +from unittest.mock import MagicMock, patch + +import pytest + +from core.app.features.rate_limiting.rate_limit import RateLimit + + +@pytest.fixture +def mock_redis(): + """Mock Redis client with realistic behavior for rate limiting tests.""" + mock_client = MagicMock() + + # Redis data storage for simulation + mock_data = {} + mock_hashes = {} + mock_expiry = {} + + def mock_setex(key, ttl, value): + mock_data[key] = str(value) + mock_expiry[key] = time.time() + ttl.total_seconds() if hasattr(ttl, "total_seconds") else time.time() + ttl + return True + + def mock_get(key): + if key in mock_data and (key not in mock_expiry or time.time() < mock_expiry[key]): + return mock_data[key].encode("utf-8") + return None + + def mock_exists(key): + return key in mock_data or key in mock_hashes + + def mock_expire(key, ttl): + if key in mock_data or key in mock_hashes: + mock_expiry[key] = time.time() + ttl.total_seconds() if hasattr(ttl, "total_seconds") else time.time() + ttl + return True + + def mock_hset(key, field, value): + if key not in mock_hashes: + mock_hashes[key] = {} + mock_hashes[key][field] = str(value).encode("utf-8") + return True + + def mock_hgetall(key): + return mock_hashes.get(key, {}) + + def mock_hdel(key, *fields): + if key in mock_hashes: + count = 0 + for field in fields: + if field in mock_hashes[key]: + del mock_hashes[key][field] + count += 1 + return count + return 0 + + def mock_hlen(key): + return len(mock_hashes.get(key, {})) + + # Configure mock methods + mock_client.setex = mock_setex + mock_client.get = mock_get + mock_client.exists = mock_exists + mock_client.expire = mock_expire + mock_client.hset = mock_hset + mock_client.hgetall = mock_hgetall + mock_client.hdel = mock_hdel + mock_client.hlen = mock_hlen + + # Store references for test verification + mock_client._mock_data = mock_data + mock_client._mock_hashes = mock_hashes + mock_client._mock_expiry = mock_expiry + + return mock_client + + +@pytest.fixture +def mock_time(): + """Mock time.time() for deterministic tests.""" + mock_time_val = 1000.0 + + def increment_time(seconds=1): + nonlocal mock_time_val + mock_time_val += seconds + return mock_time_val + + with patch("time.time", return_value=mock_time_val) as mock: + mock.increment = increment_time + yield mock + + +@pytest.fixture +def sample_generator(): + """Sample generator for testing RateLimitGenerator.""" + + def _create_generator(items=None, raise_error=False): + items = items or ["item1", "item2", "item3"] + for item in items: + if raise_error and item == "item2": + raise ValueError("Test error") + yield item + + return _create_generator + + +@pytest.fixture +def sample_mapping(): + """Sample mapping for testing RateLimitGenerator.""" + return {"key1": "value1", "key2": "value2"} + + +@pytest.fixture(autouse=True) +def reset_rate_limit_instances(): + """Clear RateLimit singleton instances between tests.""" + RateLimit._instance_dict.clear() + yield + RateLimit._instance_dict.clear() + + +@pytest.fixture +def redis_patch(): + """Patch redis_client globally for rate limit tests.""" + with patch("core.app.features.rate_limiting.rate_limit.redis_client") as mock: + yield mock diff --git a/api/tests/unit_tests/core/app/features/rate_limiting/test_rate_limit.py b/api/tests/unit_tests/core/app/features/rate_limiting/test_rate_limit.py new file mode 100644 index 0000000000..3db10c1c72 --- /dev/null +++ b/api/tests/unit_tests/core/app/features/rate_limiting/test_rate_limit.py @@ -0,0 +1,569 @@ +import threading +import time +from datetime import timedelta +from unittest.mock import patch + +import pytest + +from core.app.features.rate_limiting.rate_limit import RateLimit +from core.errors.error import AppInvokeQuotaExceededError + + +class TestRateLimit: + """Core rate limiting functionality tests.""" + + def test_should_return_same_instance_for_same_client_id(self, redis_patch): + """Test singleton behavior for same client ID.""" + redis_patch.configure_mock( + **{ + "exists.return_value": False, + "setex.return_value": True, + } + ) + + rate_limit1 = RateLimit("client1", 5) + rate_limit2 = RateLimit("client1", 10) # Second instance with different limit + + assert rate_limit1 is rate_limit2 + # Current implementation: last constructor call overwrites max_active_requests + # This reflects the actual behavior where __init__ always sets max_active_requests + assert rate_limit1.max_active_requests == 10 + + def test_should_create_different_instances_for_different_client_ids(self, redis_patch): + """Test different instances for different client IDs.""" + redis_patch.configure_mock( + **{ + "exists.return_value": False, + "setex.return_value": True, + } + ) + + rate_limit1 = RateLimit("client1", 5) + rate_limit2 = RateLimit("client2", 10) + + assert rate_limit1 is not rate_limit2 + assert rate_limit1.client_id == "client1" + assert rate_limit2.client_id == "client2" + + def test_should_initialize_with_valid_parameters(self, redis_patch): + """Test normal initialization.""" + redis_patch.configure_mock( + **{ + "exists.return_value": False, + "setex.return_value": True, + } + ) + + rate_limit = RateLimit("test_client", 5) + + assert rate_limit.client_id == "test_client" + assert rate_limit.max_active_requests == 5 + assert hasattr(rate_limit, "initialized") + redis_patch.setex.assert_called_once() + + def test_should_skip_initialization_if_disabled(self): + """Test no initialization when rate limiting is disabled.""" + rate_limit = RateLimit("test_client", 0) + + assert rate_limit.disabled() + assert not hasattr(rate_limit, "initialized") + + def test_should_skip_reinitialization_of_existing_instance(self, redis_patch): + """Test that existing instance doesn't reinitialize.""" + redis_patch.configure_mock( + **{ + "exists.return_value": False, + "setex.return_value": True, + } + ) + + RateLimit("client1", 5) + redis_patch.reset_mock() + + RateLimit("client1", 10) + + redis_patch.setex.assert_not_called() + + def test_should_be_disabled_when_max_requests_is_zero_or_negative(self): + """Test disabled state for zero or negative limits.""" + rate_limit_zero = RateLimit("client1", 0) + rate_limit_negative = RateLimit("client2", -5) + + assert rate_limit_zero.disabled() + assert rate_limit_negative.disabled() + + def test_should_set_redis_keys_on_first_flush(self, redis_patch): + """Test Redis keys are set correctly on initial flush.""" + redis_patch.configure_mock( + **{ + "exists.return_value": False, + "setex.return_value": True, + } + ) + + rate_limit = RateLimit("test_client", 5) + + expected_max_key = "dify:rate_limit:test_client:max_active_requests" + redis_patch.setex.assert_called_with(expected_max_key, timedelta(days=1), 5) + + def test_should_sync_max_requests_from_redis_on_subsequent_flush(self, redis_patch): + """Test max requests syncs from Redis when key exists.""" + redis_patch.configure_mock( + **{ + "exists.return_value": True, + "get.return_value": b"10", + "expire.return_value": True, + } + ) + + rate_limit = RateLimit("test_client", 5) + rate_limit.flush_cache() + + assert rate_limit.max_active_requests == 10 + + @patch("time.time") + def test_should_clean_timeout_requests_from_active_list(self, mock_time, redis_patch): + """Test cleanup of timed-out requests.""" + current_time = 1000.0 + mock_time.return_value = current_time + + # Setup mock Redis with timed-out requests + timeout_requests = { + b"req1": str(current_time - 700).encode(), # 700 seconds ago (timeout) + b"req2": str(current_time - 100).encode(), # 100 seconds ago (active) + } + + redis_patch.configure_mock( + **{ + "exists.return_value": True, + "get.return_value": b"5", + "expire.return_value": True, + "hgetall.return_value": timeout_requests, + "hdel.return_value": 1, + } + ) + + rate_limit = RateLimit("test_client", 5) + redis_patch.reset_mock() # Reset to avoid counting initialization calls + rate_limit.flush_cache() + + # Verify timeout request was cleaned up + redis_patch.hdel.assert_called_once() + call_args = redis_patch.hdel.call_args[0] + assert call_args[0] == "dify:rate_limit:test_client:active_requests" + assert b"req1" in call_args # Timeout request should be removed + assert b"req2" not in call_args # Active request should remain + + +class TestRateLimitEnterExit: + """Rate limiting enter/exit logic tests.""" + + def test_should_allow_request_within_limit(self, redis_patch): + """Test allowing requests within the rate limit.""" + redis_patch.configure_mock( + **{ + "exists.return_value": False, + "setex.return_value": True, + "hlen.return_value": 2, + "hset.return_value": True, + } + ) + + rate_limit = RateLimit("test_client", 5) + request_id = rate_limit.enter() + + assert request_id != RateLimit._UNLIMITED_REQUEST_ID + redis_patch.hset.assert_called_once() + + def test_should_generate_request_id_if_not_provided(self, redis_patch): + """Test auto-generation of request ID.""" + redis_patch.configure_mock( + **{ + "exists.return_value": False, + "setex.return_value": True, + "hlen.return_value": 0, + "hset.return_value": True, + } + ) + + rate_limit = RateLimit("test_client", 5) + request_id = rate_limit.enter() + + assert len(request_id) == 36 # UUID format + + def test_should_use_provided_request_id(self, redis_patch): + """Test using provided request ID.""" + redis_patch.configure_mock( + **{ + "exists.return_value": False, + "setex.return_value": True, + "hlen.return_value": 0, + "hset.return_value": True, + } + ) + + rate_limit = RateLimit("test_client", 5) + custom_id = "custom_request_123" + request_id = rate_limit.enter(custom_id) + + assert request_id == custom_id + + def test_should_remove_request_on_exit(self, redis_patch): + """Test request removal on exit.""" + redis_patch.configure_mock( + **{ + "hdel.return_value": 1, + } + ) + + rate_limit = RateLimit("test_client", 5) + rate_limit.exit("test_request_id") + + redis_patch.hdel.assert_called_once_with("dify:rate_limit:test_client:active_requests", "test_request_id") + + def test_should_raise_quota_exceeded_when_at_limit(self, redis_patch): + """Test quota exceeded error when at limit.""" + redis_patch.configure_mock( + **{ + "exists.return_value": False, + "setex.return_value": True, + "hlen.return_value": 5, # At limit + } + ) + + rate_limit = RateLimit("test_client", 5) + + with pytest.raises(AppInvokeQuotaExceededError) as exc_info: + rate_limit.enter() + + assert "Too many requests" in str(exc_info.value) + assert "test_client" in str(exc_info.value) + + def test_should_allow_request_after_previous_exit(self, redis_patch): + """Test allowing new request after previous exit.""" + redis_patch.configure_mock( + **{ + "exists.return_value": False, + "setex.return_value": True, + "hlen.return_value": 4, # Under limit after exit + "hset.return_value": True, + "hdel.return_value": 1, + } + ) + + rate_limit = RateLimit("test_client", 5) + + request_id = rate_limit.enter() + rate_limit.exit(request_id) + + new_request_id = rate_limit.enter() + assert new_request_id is not None + + @patch("time.time") + def test_should_flush_cache_when_interval_exceeded(self, mock_time, redis_patch): + """Test cache flush when time interval exceeded.""" + redis_patch.configure_mock( + **{ + "exists.return_value": False, + "setex.return_value": True, + "hlen.return_value": 0, + } + ) + + mock_time.return_value = 1000.0 + rate_limit = RateLimit("test_client", 5) + + # Advance time beyond flush interval + mock_time.return_value = 1400.0 # 400 seconds later + redis_patch.reset_mock() + + rate_limit.enter() + + # Should have called setex again due to cache flush + redis_patch.setex.assert_called() + + def test_should_return_unlimited_id_when_disabled(self): + """Test unlimited ID return when rate limiting disabled.""" + rate_limit = RateLimit("test_client", 0) + request_id = rate_limit.enter() + + assert request_id == RateLimit._UNLIMITED_REQUEST_ID + + def test_should_ignore_exit_for_unlimited_requests(self, redis_patch): + """Test ignoring exit for unlimited requests.""" + rate_limit = RateLimit("test_client", 0) + rate_limit.exit(RateLimit._UNLIMITED_REQUEST_ID) + + redis_patch.hdel.assert_not_called() + + +class TestRateLimitGenerator: + """Rate limit generator wrapper tests.""" + + def test_should_wrap_generator_and_iterate_normally(self, redis_patch, sample_generator): + """Test normal generator iteration with rate limit wrapper.""" + redis_patch.configure_mock( + **{ + "exists.return_value": False, + "setex.return_value": True, + "hdel.return_value": 1, + } + ) + + rate_limit = RateLimit("test_client", 5) + generator = sample_generator() + request_id = "test_request" + + wrapped_gen = rate_limit.generate(generator, request_id) + result = list(wrapped_gen) + + assert result == ["item1", "item2", "item3"] + redis_patch.hdel.assert_called_once_with("dify:rate_limit:test_client:active_requests", request_id) + + def test_should_handle_mapping_input_directly(self, sample_mapping): + """Test direct return of mapping input.""" + rate_limit = RateLimit("test_client", 0) # Disabled + result = rate_limit.generate(sample_mapping, "test_request") + + assert result is sample_mapping + + def test_should_cleanup_on_exception_during_iteration(self, redis_patch, sample_generator): + """Test cleanup when exception occurs during iteration.""" + redis_patch.configure_mock( + **{ + "exists.return_value": False, + "setex.return_value": True, + "hdel.return_value": 1, + } + ) + + rate_limit = RateLimit("test_client", 5) + generator = sample_generator(raise_error=True) + request_id = "test_request" + + wrapped_gen = rate_limit.generate(generator, request_id) + + with pytest.raises(ValueError): + list(wrapped_gen) + + redis_patch.hdel.assert_called_once_with("dify:rate_limit:test_client:active_requests", request_id) + + def test_should_cleanup_on_explicit_close(self, redis_patch, sample_generator): + """Test cleanup on explicit generator close.""" + redis_patch.configure_mock( + **{ + "exists.return_value": False, + "setex.return_value": True, + "hdel.return_value": 1, + } + ) + + rate_limit = RateLimit("test_client", 5) + generator = sample_generator() + request_id = "test_request" + + wrapped_gen = rate_limit.generate(generator, request_id) + wrapped_gen.close() + + redis_patch.hdel.assert_called_once() + + def test_should_handle_generator_without_close_method(self, redis_patch): + """Test handling generator without close method.""" + redis_patch.configure_mock( + **{ + "exists.return_value": False, + "setex.return_value": True, + "hdel.return_value": 1, + } + ) + + # Create a generator-like object without close method + class SimpleGenerator: + def __init__(self): + self.items = ["test"] + self.index = 0 + + def __iter__(self): + return self + + def __next__(self): + if self.index >= len(self.items): + raise StopIteration + item = self.items[self.index] + self.index += 1 + return item + + rate_limit = RateLimit("test_client", 5) + generator = SimpleGenerator() + + wrapped_gen = rate_limit.generate(generator, "test_request") + wrapped_gen.close() # Should not raise error + + redis_patch.hdel.assert_called_once() + + def test_should_prevent_iteration_after_close(self, redis_patch, sample_generator): + """Test StopIteration after generator is closed.""" + redis_patch.configure_mock( + **{ + "exists.return_value": False, + "setex.return_value": True, + "hdel.return_value": 1, + } + ) + + rate_limit = RateLimit("test_client", 5) + generator = sample_generator() + + wrapped_gen = rate_limit.generate(generator, "test_request") + wrapped_gen.close() + + with pytest.raises(StopIteration): + next(wrapped_gen) + + +class TestRateLimitConcurrency: + """Concurrent access safety tests.""" + + def test_should_handle_concurrent_instance_creation(self, redis_patch): + """Test thread-safe singleton instance creation.""" + redis_patch.configure_mock( + **{ + "exists.return_value": False, + "setex.return_value": True, + } + ) + + instances = [] + errors = [] + + def create_instance(): + try: + instance = RateLimit("concurrent_client", 5) + instances.append(instance) + except Exception as e: + errors.append(e) + + threads = [threading.Thread(target=create_instance) for _ in range(10)] + + for t in threads: + t.start() + for t in threads: + t.join() + + assert len(errors) == 0 + assert len({id(inst) for inst in instances}) == 1 # All same instance + + def test_should_handle_concurrent_enter_requests(self, redis_patch): + """Test concurrent enter requests handling.""" + # Setup mock to simulate realistic Redis behavior + request_count = 0 + + def mock_hlen(key): + nonlocal request_count + return request_count + + def mock_hset(key, field, value): + nonlocal request_count + request_count += 1 + return True + + redis_patch.configure_mock( + **{ + "exists.return_value": False, + "setex.return_value": True, + "hlen.side_effect": mock_hlen, + "hset.side_effect": mock_hset, + } + ) + + rate_limit = RateLimit("concurrent_client", 3) + results = [] + errors = [] + + def try_enter(): + try: + request_id = rate_limit.enter() + results.append(request_id) + except AppInvokeQuotaExceededError as e: + errors.append(e) + + threads = [threading.Thread(target=try_enter) for _ in range(5)] + + for t in threads: + t.start() + for t in threads: + t.join() + + # Should have some successful requests and some quota exceeded + assert len(results) + len(errors) == 5 + assert len(errors) > 0 # Some should be rejected + + @patch("time.time") + def test_should_maintain_accurate_count_under_load(self, mock_time, redis_patch): + """Test accurate count maintenance under concurrent load.""" + mock_time.return_value = 1000.0 + + # Use real mock_redis fixture for better simulation + mock_client = self._create_mock_redis() + redis_patch.configure_mock(**mock_client) + + rate_limit = RateLimit("load_test_client", 10) + active_requests = [] + + def enter_and_exit(): + try: + request_id = rate_limit.enter() + active_requests.append(request_id) + time.sleep(0.01) # Simulate some work + rate_limit.exit(request_id) + active_requests.remove(request_id) + except AppInvokeQuotaExceededError: + pass # Expected under load + + threads = [threading.Thread(target=enter_and_exit) for _ in range(20)] + + for t in threads: + t.start() + for t in threads: + t.join() + + # All requests should have been cleaned up + assert len(active_requests) == 0 + + def _create_mock_redis(self): + """Create a thread-safe mock Redis for concurrency tests.""" + import threading + + lock = threading.Lock() + data = {} + hashes = {} + + def mock_hlen(key): + with lock: + return len(hashes.get(key, {})) + + def mock_hset(key, field, value): + with lock: + if key not in hashes: + hashes[key] = {} + hashes[key][field] = str(value).encode("utf-8") + return True + + def mock_hdel(key, *fields): + with lock: + if key in hashes: + count = 0 + for field in fields: + if field in hashes[key]: + del hashes[key][field] + count += 1 + return count + return 0 + + return { + "exists.return_value": False, + "setex.return_value": True, + "hlen.side_effect": mock_hlen, + "hset.side_effect": mock_hset, + "hdel.side_effect": mock_hdel, + } diff --git a/api/tests/unit_tests/core/mcp/client/test_sse.py b/api/tests/unit_tests/core/mcp/client/test_sse.py index 8122cd08eb..aadd366762 100644 --- a/api/tests/unit_tests/core/mcp/client/test_sse.py +++ b/api/tests/unit_tests/core/mcp/client/test_sse.py @@ -1,3 +1,4 @@ +import contextlib import json import queue import threading @@ -124,13 +125,10 @@ def test_sse_client_connection_validation(): mock_event_source.iter_sse.return_value = [endpoint_event] # Test connection - try: + with contextlib.suppress(Exception): with sse_client(test_url) as (read_queue, write_queue): assert read_queue is not None assert write_queue is not None - except Exception as e: - # Connection might fail due to mocking, but we're testing the validation logic - pass def test_sse_client_error_handling(): @@ -178,7 +176,7 @@ def test_sse_client_timeout_configuration(): mock_event_source.iter_sse.return_value = [] mock_sse_connect.return_value.__enter__.return_value = mock_event_source - try: + with contextlib.suppress(Exception): with sse_client( test_url, headers=custom_headers, timeout=custom_timeout, sse_read_timeout=custom_sse_timeout ) as (read_queue, write_queue): @@ -190,9 +188,6 @@ def test_sse_client_timeout_configuration(): assert call_args is not None timeout_arg = call_args[1]["timeout"] assert timeout_arg.read == custom_sse_timeout - except Exception: - # Connection might fail due to mocking, but we tested the configuration - pass def test_sse_transport_endpoint_validation(): @@ -251,37 +246,15 @@ def test_sse_client_queue_cleanup(): # Mock connection that raises an exception mock_sse_connect.side_effect = Exception("Connection failed") - try: + with contextlib.suppress(Exception): with sse_client(test_url) as (rq, wq): read_queue = rq write_queue = wq - except Exception: - pass # Expected to fail # Queues should be cleaned up even on exception # Note: In real implementation, cleanup should put None to signal shutdown -def test_sse_client_url_processing(): - """Test SSE client URL processing functions.""" - from core.mcp.client.sse_client import remove_request_params - - # Test URL with parameters - url_with_params = "http://example.com/sse?param1=value1¶m2=value2" - cleaned_url = remove_request_params(url_with_params) - assert cleaned_url == "http://example.com/sse" - - # Test URL without parameters - url_without_params = "http://example.com/sse" - cleaned_url = remove_request_params(url_without_params) - assert cleaned_url == "http://example.com/sse" - - # Test URL with path and parameters - complex_url = "http://example.com/path/to/sse?session=123&token=abc" - cleaned_url = remove_request_params(complex_url) - assert cleaned_url == "http://example.com/path/to/sse" - - def test_sse_client_headers_propagation(): """Test that custom headers are properly propagated in SSE client.""" test_url = "http://test.example/sse" @@ -303,11 +276,9 @@ def test_sse_client_headers_propagation(): mock_event_source.iter_sse.return_value = [] mock_sse_connect.return_value.__enter__.return_value = mock_event_source - try: + with contextlib.suppress(Exception): with sse_client(test_url, headers=custom_headers): pass - except Exception: - pass # Expected due to mocking # Verify headers were passed to client factory mock_client_factory.assert_called_with(headers=custom_headers) diff --git a/api/tests/unit_tests/core/model_runtime/entities/test_llm_entities.py b/api/tests/unit_tests/core/model_runtime/entities/test_llm_entities.py new file mode 100644 index 0000000000..c10f7b89c3 --- /dev/null +++ b/api/tests/unit_tests/core/model_runtime/entities/test_llm_entities.py @@ -0,0 +1,148 @@ +"""Tests for LLMUsage entity.""" + +from decimal import Decimal + +from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata + + +class TestLLMUsage: + """Test cases for LLMUsage class.""" + + def test_from_metadata_with_all_tokens(self): + """Test from_metadata when all token types are provided.""" + metadata: LLMUsageMetadata = { + "prompt_tokens": 100, + "completion_tokens": 50, + "total_tokens": 150, + "prompt_unit_price": 0.001, + "completion_unit_price": 0.002, + "total_price": 0.2, + "currency": "USD", + "latency": 1.5, + } + + usage = LLMUsage.from_metadata(metadata) + + assert usage.prompt_tokens == 100 + assert usage.completion_tokens == 50 + assert usage.total_tokens == 150 + assert usage.prompt_unit_price == Decimal("0.001") + assert usage.completion_unit_price == Decimal("0.002") + assert usage.total_price == Decimal("0.2") + assert usage.currency == "USD" + assert usage.latency == 1.5 + + def test_from_metadata_with_prompt_tokens_only(self): + """Test from_metadata when only prompt_tokens is provided.""" + metadata: LLMUsageMetadata = { + "prompt_tokens": 100, + "total_tokens": 100, + } + + usage = LLMUsage.from_metadata(metadata) + + assert usage.prompt_tokens == 100 + assert usage.completion_tokens == 0 + assert usage.total_tokens == 100 + + def test_from_metadata_with_completion_tokens_only(self): + """Test from_metadata when only completion_tokens is provided.""" + metadata: LLMUsageMetadata = { + "completion_tokens": 50, + "total_tokens": 50, + } + + usage = LLMUsage.from_metadata(metadata) + + assert usage.prompt_tokens == 0 + assert usage.completion_tokens == 50 + assert usage.total_tokens == 50 + + def test_from_metadata_calculates_total_when_missing(self): + """Test from_metadata calculates total_tokens when not provided.""" + metadata: LLMUsageMetadata = { + "prompt_tokens": 100, + "completion_tokens": 50, + } + + usage = LLMUsage.from_metadata(metadata) + + assert usage.prompt_tokens == 100 + assert usage.completion_tokens == 50 + assert usage.total_tokens == 150 # Should be calculated + + def test_from_metadata_with_total_but_no_completion(self): + """ + Test from_metadata when total_tokens is provided but completion_tokens is 0. + This tests the fix for issue #24360 - prompt tokens should NOT be assigned to completion_tokens. + """ + metadata: LLMUsageMetadata = { + "prompt_tokens": 479, + "completion_tokens": 0, + "total_tokens": 521, + } + + usage = LLMUsage.from_metadata(metadata) + + # This is the key fix - prompt tokens should remain as prompt tokens + assert usage.prompt_tokens == 479 + assert usage.completion_tokens == 0 + assert usage.total_tokens == 521 + + def test_from_metadata_with_empty_metadata(self): + """Test from_metadata with empty metadata.""" + metadata: LLMUsageMetadata = {} + + usage = LLMUsage.from_metadata(metadata) + + assert usage.prompt_tokens == 0 + assert usage.completion_tokens == 0 + assert usage.total_tokens == 0 + assert usage.currency == "USD" + assert usage.latency == 0.0 + + def test_from_metadata_preserves_zero_completion_tokens(self): + """ + Test that zero completion_tokens are preserved when explicitly set. + This is important for agent nodes that only use prompt tokens. + """ + metadata: LLMUsageMetadata = { + "prompt_tokens": 1000, + "completion_tokens": 0, + "total_tokens": 1000, + "prompt_unit_price": 0.15, + "completion_unit_price": 0.60, + "prompt_price": 0.00015, + "completion_price": 0, + "total_price": 0.00015, + } + + usage = LLMUsage.from_metadata(metadata) + + assert usage.prompt_tokens == 1000 + assert usage.completion_tokens == 0 + assert usage.total_tokens == 1000 + assert usage.prompt_price == Decimal("0.00015") + assert usage.completion_price == Decimal(0) + assert usage.total_price == Decimal("0.00015") + + def test_from_metadata_with_decimal_values(self): + """Test from_metadata handles decimal values correctly.""" + metadata: LLMUsageMetadata = { + "prompt_tokens": 100, + "completion_tokens": 50, + "total_tokens": 150, + "prompt_unit_price": "0.001", + "completion_unit_price": "0.002", + "prompt_price": "0.1", + "completion_price": "0.1", + "total_price": "0.2", + } + + usage = LLMUsage.from_metadata(metadata) + + assert usage.prompt_unit_price == Decimal("0.001") + assert usage.completion_unit_price == Decimal("0.002") + assert usage.prompt_price == Decimal("0.1") + assert usage.completion_price == Decimal("0.1") + assert usage.total_price == Decimal("0.2") diff --git a/api/tests/unit_tests/core/ops/test_config_entity.py b/api/tests/unit_tests/core/ops/test_config_entity.py index 81cb04548d..1dc380ad0b 100644 --- a/api/tests/unit_tests/core/ops/test_config_entity.py +++ b/api/tests/unit_tests/core/ops/test_config_entity.py @@ -102,9 +102,14 @@ class TestPhoenixConfig: assert config.project == "default" def test_endpoint_validation_with_path(self): - """Test endpoint validation normalizes URL by removing path""" - config = PhoenixConfig(endpoint="https://custom.phoenix.com/api/v1") - assert config.endpoint == "https://custom.phoenix.com" + """Test endpoint validation with path""" + config = PhoenixConfig(endpoint="https://app.phoenix.arize.com/s/dify-integration") + assert config.endpoint == "https://app.phoenix.arize.com/s/dify-integration" + + def test_endpoint_validation_without_path(self): + """Test endpoint validation without path""" + config = PhoenixConfig(endpoint="https://app.phoenix.arize.com") + assert config.endpoint == "https://app.phoenix.arize.com" class TestLangfuseConfig: @@ -117,6 +122,13 @@ class TestLangfuseConfig: assert config.secret_key == "secret_key" assert config.host == "https://custom.langfuse.com" + def test_valid_config_with_path(self): + host = "https://custom.langfuse.com/api/v1" + config = LangfuseConfig(public_key="public_key", secret_key="secret_key", host=host) + assert config.public_key == "public_key" + assert config.secret_key == "secret_key" + assert config.host == host + def test_default_values(self): """Test default values are set correctly""" config = LangfuseConfig(public_key="public", secret_key="secret") @@ -361,13 +373,15 @@ class TestConfigIntegration: """Test that URL normalization works consistently across configs""" # Test that paths are removed from endpoints arize_config = ArizeConfig(endpoint="https://arize.com/api/v1/test") - phoenix_config = PhoenixConfig(endpoint="https://phoenix.com/api/v2/") + phoenix_with_path_config = PhoenixConfig(endpoint="https://app.phoenix.arize.com/s/dify-integration") + phoenix_without_path_config = PhoenixConfig(endpoint="https://app.phoenix.arize.com") aliyun_config = AliyunConfig( license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com/api/v1/traces" ) assert arize_config.endpoint == "https://arize.com" - assert phoenix_config.endpoint == "https://phoenix.com" + assert phoenix_with_path_config.endpoint == "https://app.phoenix.arize.com/s/dify-integration" + assert phoenix_without_path_config.endpoint == "https://app.phoenix.arize.com" assert aliyun_config.endpoint == "https://tracing-analysis-dc-hz.aliyuncs.com" def test_project_default_values(self): diff --git a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py index f6d22690d1..8abed0a3f9 100644 --- a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py @@ -164,7 +164,7 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg ) assert isinstance(prompt_messages[3].content, list) assert len(prompt_messages[3].content) == 2 - assert prompt_messages[3].content[1].data == files[0].remote_url + assert prompt_messages[3].content[0].data == files[0].remote_url @pytest.fixture diff --git a/api/tests/unit_tests/core/repositories/test_celery_workflow_execution_repository.py b/api/tests/unit_tests/core/repositories/test_celery_workflow_execution_repository.py new file mode 100644 index 0000000000..e7733b2317 --- /dev/null +++ b/api/tests/unit_tests/core/repositories/test_celery_workflow_execution_repository.py @@ -0,0 +1,247 @@ +""" +Unit tests for CeleryWorkflowExecutionRepository. + +These tests verify the Celery-based asynchronous storage functionality +for workflow execution data. +""" + +from unittest.mock import Mock, patch +from uuid import uuid4 + +import pytest + +from core.repositories.celery_workflow_execution_repository import CeleryWorkflowExecutionRepository +from core.workflow.entities.workflow_execution import WorkflowExecution, WorkflowType +from libs.datetime_utils import naive_utc_now +from models import Account, EndUser +from models.enums import WorkflowRunTriggeredFrom + + +@pytest.fixture +def mock_session_factory(): + """Mock SQLAlchemy session factory.""" + from sqlalchemy import create_engine + from sqlalchemy.orm import sessionmaker + + # Create a real sessionmaker with in-memory SQLite for testing + engine = create_engine("sqlite:///:memory:") + return sessionmaker(bind=engine) + + +@pytest.fixture +def mock_account(): + """Mock Account user.""" + account = Mock(spec=Account) + account.id = str(uuid4()) + account.current_tenant_id = str(uuid4()) + return account + + +@pytest.fixture +def mock_end_user(): + """Mock EndUser.""" + user = Mock(spec=EndUser) + user.id = str(uuid4()) + user.tenant_id = str(uuid4()) + return user + + +@pytest.fixture +def sample_workflow_execution(): + """Sample WorkflowExecution for testing.""" + return WorkflowExecution.new( + id_=str(uuid4()), + workflow_id=str(uuid4()), + workflow_type=WorkflowType.WORKFLOW, + workflow_version="1.0", + graph={"nodes": [], "edges": []}, + inputs={"input1": "value1"}, + started_at=naive_utc_now(), + ) + + +class TestCeleryWorkflowExecutionRepository: + """Test cases for CeleryWorkflowExecutionRepository.""" + + def test_init_with_sessionmaker(self, mock_session_factory, mock_account): + """Test repository initialization with sessionmaker.""" + app_id = "test-app-id" + triggered_from = WorkflowRunTriggeredFrom.APP_RUN + + repo = CeleryWorkflowExecutionRepository( + session_factory=mock_session_factory, + user=mock_account, + app_id=app_id, + triggered_from=triggered_from, + ) + + assert repo._tenant_id == mock_account.current_tenant_id + assert repo._app_id == app_id + assert repo._triggered_from == triggered_from + assert repo._creator_user_id == mock_account.id + assert repo._creator_user_role is not None + + def test_init_basic_functionality(self, mock_session_factory, mock_account): + """Test repository initialization basic functionality.""" + repo = CeleryWorkflowExecutionRepository( + session_factory=mock_session_factory, + user=mock_account, + app_id="test-app", + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, + ) + + # Verify basic initialization + assert repo._tenant_id == mock_account.current_tenant_id + assert repo._app_id == "test-app" + assert repo._triggered_from == WorkflowRunTriggeredFrom.DEBUGGING + + def test_init_with_end_user(self, mock_session_factory, mock_end_user): + """Test repository initialization with EndUser.""" + repo = CeleryWorkflowExecutionRepository( + session_factory=mock_session_factory, + user=mock_end_user, + app_id="test-app", + triggered_from=WorkflowRunTriggeredFrom.APP_RUN, + ) + + assert repo._tenant_id == mock_end_user.tenant_id + + def test_init_without_tenant_id_raises_error(self, mock_session_factory): + """Test that initialization fails without tenant_id.""" + # Create a mock Account with no tenant_id + user = Mock(spec=Account) + user.current_tenant_id = None + user.id = str(uuid4()) + + with pytest.raises(ValueError, match="User must have a tenant_id"): + CeleryWorkflowExecutionRepository( + session_factory=mock_session_factory, + user=user, + app_id="test-app", + triggered_from=WorkflowRunTriggeredFrom.APP_RUN, + ) + + @patch("core.repositories.celery_workflow_execution_repository.save_workflow_execution_task") + def test_save_queues_celery_task(self, mock_task, mock_session_factory, mock_account, sample_workflow_execution): + """Test that save operation queues a Celery task without tracking.""" + repo = CeleryWorkflowExecutionRepository( + session_factory=mock_session_factory, + user=mock_account, + app_id="test-app", + triggered_from=WorkflowRunTriggeredFrom.APP_RUN, + ) + + repo.save(sample_workflow_execution) + + # Verify Celery task was queued with correct parameters + mock_task.delay.assert_called_once() + call_args = mock_task.delay.call_args[1] + + assert call_args["execution_data"] == sample_workflow_execution.model_dump() + assert call_args["tenant_id"] == mock_account.current_tenant_id + assert call_args["app_id"] == "test-app" + assert call_args["triggered_from"] == WorkflowRunTriggeredFrom.APP_RUN.value + assert call_args["creator_user_id"] == mock_account.id + + # Verify no task tracking occurs (no _pending_saves attribute) + assert not hasattr(repo, "_pending_saves") + + @patch("core.repositories.celery_workflow_execution_repository.save_workflow_execution_task") + def test_save_handles_celery_failure( + self, mock_task, mock_session_factory, mock_account, sample_workflow_execution + ): + """Test that save operation handles Celery task failures.""" + mock_task.delay.side_effect = Exception("Celery is down") + + repo = CeleryWorkflowExecutionRepository( + session_factory=mock_session_factory, + user=mock_account, + app_id="test-app", + triggered_from=WorkflowRunTriggeredFrom.APP_RUN, + ) + + with pytest.raises(Exception, match="Celery is down"): + repo.save(sample_workflow_execution) + + @patch("core.repositories.celery_workflow_execution_repository.save_workflow_execution_task") + def test_save_operation_fire_and_forget( + self, mock_task, mock_session_factory, mock_account, sample_workflow_execution + ): + """Test that save operation works in fire-and-forget mode.""" + repo = CeleryWorkflowExecutionRepository( + session_factory=mock_session_factory, + user=mock_account, + app_id="test-app", + triggered_from=WorkflowRunTriggeredFrom.APP_RUN, + ) + + # Test that save doesn't block or maintain state + repo.save(sample_workflow_execution) + + # Verify no pending saves are tracked (no _pending_saves attribute) + assert not hasattr(repo, "_pending_saves") + + @patch("core.repositories.celery_workflow_execution_repository.save_workflow_execution_task") + def test_multiple_save_operations(self, mock_task, mock_session_factory, mock_account): + """Test multiple save operations work correctly.""" + repo = CeleryWorkflowExecutionRepository( + session_factory=mock_session_factory, + user=mock_account, + app_id="test-app", + triggered_from=WorkflowRunTriggeredFrom.APP_RUN, + ) + + # Create multiple executions + exec1 = WorkflowExecution.new( + id_=str(uuid4()), + workflow_id=str(uuid4()), + workflow_type=WorkflowType.WORKFLOW, + workflow_version="1.0", + graph={"nodes": [], "edges": []}, + inputs={"input1": "value1"}, + started_at=naive_utc_now(), + ) + exec2 = WorkflowExecution.new( + id_=str(uuid4()), + workflow_id=str(uuid4()), + workflow_type=WorkflowType.WORKFLOW, + workflow_version="1.0", + graph={"nodes": [], "edges": []}, + inputs={"input2": "value2"}, + started_at=naive_utc_now(), + ) + + # Save both executions + repo.save(exec1) + repo.save(exec2) + + # Should work without issues and not maintain state (no _pending_saves attribute) + assert not hasattr(repo, "_pending_saves") + + @patch("core.repositories.celery_workflow_execution_repository.save_workflow_execution_task") + def test_save_with_different_user_types(self, mock_task, mock_session_factory, mock_end_user): + """Test save operation with different user types.""" + repo = CeleryWorkflowExecutionRepository( + session_factory=mock_session_factory, + user=mock_end_user, + app_id="test-app", + triggered_from=WorkflowRunTriggeredFrom.APP_RUN, + ) + + execution = WorkflowExecution.new( + id_=str(uuid4()), + workflow_id=str(uuid4()), + workflow_type=WorkflowType.WORKFLOW, + workflow_version="1.0", + graph={"nodes": [], "edges": []}, + inputs={"input1": "value1"}, + started_at=naive_utc_now(), + ) + + repo.save(execution) + + # Verify task was called with EndUser context + mock_task.delay.assert_called_once() + call_args = mock_task.delay.call_args[1] + assert call_args["tenant_id"] == mock_end_user.tenant_id + assert call_args["creator_user_id"] == mock_end_user.id diff --git a/api/tests/unit_tests/core/repositories/test_celery_workflow_node_execution_repository.py b/api/tests/unit_tests/core/repositories/test_celery_workflow_node_execution_repository.py new file mode 100644 index 0000000000..0c6fdc8f92 --- /dev/null +++ b/api/tests/unit_tests/core/repositories/test_celery_workflow_node_execution_repository.py @@ -0,0 +1,349 @@ +""" +Unit tests for CeleryWorkflowNodeExecutionRepository. + +These tests verify the Celery-based asynchronous storage functionality +for workflow node execution data. +""" + +from unittest.mock import Mock, patch +from uuid import uuid4 + +import pytest + +from core.repositories.celery_workflow_node_execution_repository import CeleryWorkflowNodeExecutionRepository +from core.workflow.entities.workflow_node_execution import ( + WorkflowNodeExecution, + WorkflowNodeExecutionStatus, +) +from core.workflow.nodes.enums import NodeType +from core.workflow.repositories.workflow_node_execution_repository import OrderConfig +from libs.datetime_utils import naive_utc_now +from models import Account, EndUser +from models.workflow import WorkflowNodeExecutionTriggeredFrom + + +@pytest.fixture +def mock_session_factory(): + """Mock SQLAlchemy session factory.""" + from sqlalchemy import create_engine + from sqlalchemy.orm import sessionmaker + + # Create a real sessionmaker with in-memory SQLite for testing + engine = create_engine("sqlite:///:memory:") + return sessionmaker(bind=engine) + + +@pytest.fixture +def mock_account(): + """Mock Account user.""" + account = Mock(spec=Account) + account.id = str(uuid4()) + account.current_tenant_id = str(uuid4()) + return account + + +@pytest.fixture +def mock_end_user(): + """Mock EndUser.""" + user = Mock(spec=EndUser) + user.id = str(uuid4()) + user.tenant_id = str(uuid4()) + return user + + +@pytest.fixture +def sample_workflow_node_execution(): + """Sample WorkflowNodeExecution for testing.""" + return WorkflowNodeExecution( + id=str(uuid4()), + node_execution_id=str(uuid4()), + workflow_id=str(uuid4()), + workflow_execution_id=str(uuid4()), + index=1, + node_id="test_node", + node_type=NodeType.START, + title="Test Node", + inputs={"input1": "value1"}, + status=WorkflowNodeExecutionStatus.RUNNING, + created_at=naive_utc_now(), + ) + + +class TestCeleryWorkflowNodeExecutionRepository: + """Test cases for CeleryWorkflowNodeExecutionRepository.""" + + def test_init_with_sessionmaker(self, mock_session_factory, mock_account): + """Test repository initialization with sessionmaker.""" + app_id = "test-app-id" + triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN + + repo = CeleryWorkflowNodeExecutionRepository( + session_factory=mock_session_factory, + user=mock_account, + app_id=app_id, + triggered_from=triggered_from, + ) + + assert repo._tenant_id == mock_account.current_tenant_id + assert repo._app_id == app_id + assert repo._triggered_from == triggered_from + assert repo._creator_user_id == mock_account.id + assert repo._creator_user_role is not None + + def test_init_with_cache_initialized(self, mock_session_factory, mock_account): + """Test repository initialization with cache properly initialized.""" + repo = CeleryWorkflowNodeExecutionRepository( + session_factory=mock_session_factory, + user=mock_account, + app_id="test-app", + triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP, + ) + + assert repo._execution_cache == {} + assert repo._workflow_execution_mapping == {} + + def test_init_with_end_user(self, mock_session_factory, mock_end_user): + """Test repository initialization with EndUser.""" + repo = CeleryWorkflowNodeExecutionRepository( + session_factory=mock_session_factory, + user=mock_end_user, + app_id="test-app", + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + assert repo._tenant_id == mock_end_user.tenant_id + + def test_init_without_tenant_id_raises_error(self, mock_session_factory): + """Test that initialization fails without tenant_id.""" + # Create a mock Account with no tenant_id + user = Mock(spec=Account) + user.current_tenant_id = None + user.id = str(uuid4()) + + with pytest.raises(ValueError, match="User must have a tenant_id"): + CeleryWorkflowNodeExecutionRepository( + session_factory=mock_session_factory, + user=user, + app_id="test-app", + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + @patch("core.repositories.celery_workflow_node_execution_repository.save_workflow_node_execution_task") + def test_save_caches_and_queues_celery_task( + self, mock_task, mock_session_factory, mock_account, sample_workflow_node_execution + ): + """Test that save operation caches execution and queues a Celery task.""" + repo = CeleryWorkflowNodeExecutionRepository( + session_factory=mock_session_factory, + user=mock_account, + app_id="test-app", + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + repo.save(sample_workflow_node_execution) + + # Verify Celery task was queued with correct parameters + mock_task.delay.assert_called_once() + call_args = mock_task.delay.call_args[1] + + assert call_args["execution_data"] == sample_workflow_node_execution.model_dump() + assert call_args["tenant_id"] == mock_account.current_tenant_id + assert call_args["app_id"] == "test-app" + assert call_args["triggered_from"] == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value + assert call_args["creator_user_id"] == mock_account.id + + # Verify execution is cached + assert sample_workflow_node_execution.id in repo._execution_cache + assert repo._execution_cache[sample_workflow_node_execution.id] == sample_workflow_node_execution + + # Verify workflow execution mapping is updated + assert sample_workflow_node_execution.workflow_execution_id in repo._workflow_execution_mapping + assert ( + sample_workflow_node_execution.id + in repo._workflow_execution_mapping[sample_workflow_node_execution.workflow_execution_id] + ) + + @patch("core.repositories.celery_workflow_node_execution_repository.save_workflow_node_execution_task") + def test_save_handles_celery_failure( + self, mock_task, mock_session_factory, mock_account, sample_workflow_node_execution + ): + """Test that save operation handles Celery task failures.""" + mock_task.delay.side_effect = Exception("Celery is down") + + repo = CeleryWorkflowNodeExecutionRepository( + session_factory=mock_session_factory, + user=mock_account, + app_id="test-app", + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + with pytest.raises(Exception, match="Celery is down"): + repo.save(sample_workflow_node_execution) + + @patch("core.repositories.celery_workflow_node_execution_repository.save_workflow_node_execution_task") + def test_get_by_workflow_run_from_cache( + self, mock_task, mock_session_factory, mock_account, sample_workflow_node_execution + ): + """Test that get_by_workflow_run retrieves executions from cache.""" + repo = CeleryWorkflowNodeExecutionRepository( + session_factory=mock_session_factory, + user=mock_account, + app_id="test-app", + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + # Save execution to cache first + repo.save(sample_workflow_node_execution) + + workflow_run_id = sample_workflow_node_execution.workflow_execution_id + order_config = OrderConfig(order_by=["index"], order_direction="asc") + + result = repo.get_by_workflow_run(workflow_run_id, order_config) + + # Verify results were retrieved from cache + assert len(result) == 1 + assert result[0].id == sample_workflow_node_execution.id + assert result[0] is sample_workflow_node_execution + + def test_get_by_workflow_run_without_order_config(self, mock_session_factory, mock_account): + """Test get_by_workflow_run without order configuration.""" + repo = CeleryWorkflowNodeExecutionRepository( + session_factory=mock_session_factory, + user=mock_account, + app_id="test-app", + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + result = repo.get_by_workflow_run("workflow-run-id") + + # Should return empty list since nothing in cache + assert len(result) == 0 + + @patch("core.repositories.celery_workflow_node_execution_repository.save_workflow_node_execution_task") + def test_cache_operations(self, mock_task, mock_session_factory, mock_account, sample_workflow_node_execution): + """Test cache operations work correctly.""" + repo = CeleryWorkflowNodeExecutionRepository( + session_factory=mock_session_factory, + user=mock_account, + app_id="test-app", + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + # Test saving to cache + repo.save(sample_workflow_node_execution) + + # Verify cache contains the execution + assert sample_workflow_node_execution.id in repo._execution_cache + + # Test retrieving from cache + result = repo.get_by_workflow_run(sample_workflow_node_execution.workflow_execution_id) + assert len(result) == 1 + assert result[0].id == sample_workflow_node_execution.id + + @patch("core.repositories.celery_workflow_node_execution_repository.save_workflow_node_execution_task") + def test_multiple_executions_same_workflow(self, mock_task, mock_session_factory, mock_account): + """Test multiple executions for the same workflow.""" + repo = CeleryWorkflowNodeExecutionRepository( + session_factory=mock_session_factory, + user=mock_account, + app_id="test-app", + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + # Create multiple executions for the same workflow + workflow_run_id = str(uuid4()) + exec1 = WorkflowNodeExecution( + id=str(uuid4()), + node_execution_id=str(uuid4()), + workflow_id=str(uuid4()), + workflow_execution_id=workflow_run_id, + index=1, + node_id="node1", + node_type=NodeType.START, + title="Node 1", + inputs={"input1": "value1"}, + status=WorkflowNodeExecutionStatus.RUNNING, + created_at=naive_utc_now(), + ) + exec2 = WorkflowNodeExecution( + id=str(uuid4()), + node_execution_id=str(uuid4()), + workflow_id=str(uuid4()), + workflow_execution_id=workflow_run_id, + index=2, + node_id="node2", + node_type=NodeType.LLM, + title="Node 2", + inputs={"input2": "value2"}, + status=WorkflowNodeExecutionStatus.RUNNING, + created_at=naive_utc_now(), + ) + + # Save both executions + repo.save(exec1) + repo.save(exec2) + + # Verify both are cached and mapped + assert len(repo._execution_cache) == 2 + assert len(repo._workflow_execution_mapping[workflow_run_id]) == 2 + + # Test retrieval + result = repo.get_by_workflow_run(workflow_run_id) + assert len(result) == 2 + + @patch("core.repositories.celery_workflow_node_execution_repository.save_workflow_node_execution_task") + def test_ordering_functionality(self, mock_task, mock_session_factory, mock_account): + """Test ordering functionality works correctly.""" + repo = CeleryWorkflowNodeExecutionRepository( + session_factory=mock_session_factory, + user=mock_account, + app_id="test-app", + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + # Create executions with different indices + workflow_run_id = str(uuid4()) + exec1 = WorkflowNodeExecution( + id=str(uuid4()), + node_execution_id=str(uuid4()), + workflow_id=str(uuid4()), + workflow_execution_id=workflow_run_id, + index=2, + node_id="node2", + node_type=NodeType.START, + title="Node 2", + inputs={}, + status=WorkflowNodeExecutionStatus.RUNNING, + created_at=naive_utc_now(), + ) + exec2 = WorkflowNodeExecution( + id=str(uuid4()), + node_execution_id=str(uuid4()), + workflow_id=str(uuid4()), + workflow_execution_id=workflow_run_id, + index=1, + node_id="node1", + node_type=NodeType.LLM, + title="Node 1", + inputs={}, + status=WorkflowNodeExecutionStatus.RUNNING, + created_at=naive_utc_now(), + ) + + # Save in random order + repo.save(exec1) + repo.save(exec2) + + # Test ascending order + order_config = OrderConfig(order_by=["index"], order_direction="asc") + result = repo.get_by_workflow_run(workflow_run_id, order_config) + assert len(result) == 2 + assert result[0].index == 1 + assert result[1].index == 2 + + # Test descending order + order_config = OrderConfig(order_by=["index"], order_direction="desc") + result = repo.get_by_workflow_run(workflow_run_id, order_config) + assert len(result) == 2 + assert result[0].index == 2 + assert result[1].index == 1 diff --git a/api/tests/unit_tests/core/repositories/test_factory.py b/api/tests/unit_tests/core/repositories/test_factory.py index fce4a6fb6b..30f51902ef 100644 --- a/api/tests/unit_tests/core/repositories/test_factory.py +++ b/api/tests/unit_tests/core/repositories/test_factory.py @@ -2,19 +2,19 @@ Unit tests for the RepositoryFactory. This module tests the factory pattern implementation for creating repository instances -based on configuration, including error handling and validation. +based on configuration, including error handling. """ from unittest.mock import MagicMock, patch import pytest -from pytest_mock import MockerFixture from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker from core.repositories.factory import DifyCoreRepositoryFactory, RepositoryImportError from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from libs.module_loading import import_string from models import Account, EndUser from models.enums import WorkflowRunTriggeredFrom from models.workflow import WorkflowNodeExecutionTriggeredFrom @@ -23,122 +23,36 @@ from models.workflow import WorkflowNodeExecutionTriggeredFrom class TestRepositoryFactory: """Test cases for RepositoryFactory.""" - def test_import_class_success(self): + def test_import_string_success(self): """Test successful class import.""" # Test importing a real class class_path = "unittest.mock.MagicMock" - result = DifyCoreRepositoryFactory._import_class(class_path) + result = import_string(class_path) assert result is MagicMock - def test_import_class_invalid_path(self): + def test_import_string_invalid_path(self): """Test import with invalid module path.""" - with pytest.raises(RepositoryImportError) as exc_info: - DifyCoreRepositoryFactory._import_class("invalid.module.path") - assert "Cannot import repository class" in str(exc_info.value) + with pytest.raises(ImportError) as exc_info: + import_string("invalid.module.path") + assert "No module named" in str(exc_info.value) - def test_import_class_invalid_class_name(self): + def test_import_string_invalid_class_name(self): """Test import with invalid class name.""" - with pytest.raises(RepositoryImportError) as exc_info: - DifyCoreRepositoryFactory._import_class("unittest.mock.NonExistentClass") - assert "Cannot import repository class" in str(exc_info.value) + with pytest.raises(ImportError) as exc_info: + import_string("unittest.mock.NonExistentClass") + assert "does not define" in str(exc_info.value) - def test_import_class_malformed_path(self): + def test_import_string_malformed_path(self): """Test import with malformed path (no dots).""" - with pytest.raises(RepositoryImportError) as exc_info: - DifyCoreRepositoryFactory._import_class("invalidpath") - assert "Cannot import repository class" in str(exc_info.value) - - def test_validate_repository_interface_success(self): - """Test successful interface validation.""" - - # Create a mock class that implements the required methods - class MockRepository: - def save(self): - pass - - def get_by_id(self): - pass - - # Create a mock interface with the same methods - class MockInterface: - def save(self): - pass - - def get_by_id(self): - pass - - # Should not raise an exception - DifyCoreRepositoryFactory._validate_repository_interface(MockRepository, MockInterface) - - def test_validate_repository_interface_missing_methods(self): - """Test interface validation with missing methods.""" - - # Create a mock class that doesn't implement all required methods - class IncompleteRepository: - def save(self): - pass - - # Missing get_by_id method - - # Create a mock interface with required methods - class MockInterface: - def save(self): - pass - - def get_by_id(self): - pass - - with pytest.raises(RepositoryImportError) as exc_info: - DifyCoreRepositoryFactory._validate_repository_interface(IncompleteRepository, MockInterface) - assert "does not implement required methods" in str(exc_info.value) - assert "get_by_id" in str(exc_info.value) - - def test_validate_constructor_signature_success(self): - """Test successful constructor signature validation.""" - - class MockRepository: - def __init__(self, session_factory, user, app_id, triggered_from): - pass - - # Should not raise an exception - DifyCoreRepositoryFactory._validate_constructor_signature( - MockRepository, ["session_factory", "user", "app_id", "triggered_from"] - ) - - def test_validate_constructor_signature_missing_params(self): - """Test constructor validation with missing parameters.""" - - class IncompleteRepository: - def __init__(self, session_factory, user): - # Missing app_id and triggered_from parameters - pass - - with pytest.raises(RepositoryImportError) as exc_info: - DifyCoreRepositoryFactory._validate_constructor_signature( - IncompleteRepository, ["session_factory", "user", "app_id", "triggered_from"] - ) - assert "does not accept required parameters" in str(exc_info.value) - assert "app_id" in str(exc_info.value) - assert "triggered_from" in str(exc_info.value) - - def test_validate_constructor_signature_inspection_error(self, mocker: MockerFixture): - """Test constructor validation when inspection fails.""" - # Mock inspect.signature to raise an exception - mocker.patch("inspect.signature", side_effect=Exception("Inspection failed")) - - class MockRepository: - def __init__(self, session_factory): - pass - - with pytest.raises(RepositoryImportError) as exc_info: - DifyCoreRepositoryFactory._validate_constructor_signature(MockRepository, ["session_factory"]) - assert "Failed to validate constructor signature" in str(exc_info.value) + with pytest.raises(ImportError) as exc_info: + import_string("invalidpath") + assert "doesn't look like a module path" in str(exc_info.value) @patch("core.repositories.factory.dify_config") - def test_create_workflow_execution_repository_success(self, mock_config, mocker: MockerFixture): - """Test successful creation of WorkflowExecutionRepository.""" + def test_create_workflow_execution_repository_success(self, mock_config): + """Test successful WorkflowExecutionRepository creation.""" # Setup mock configuration - mock_config.WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock" + mock_config.CORE_WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock" # Create mock dependencies mock_session_factory = MagicMock(spec=sessionmaker) @@ -146,17 +60,13 @@ class TestRepositoryFactory: app_id = "test-app-id" triggered_from = WorkflowRunTriggeredFrom.APP_RUN - # Mock the imported class to be a valid repository + # Create mock repository class and instance mock_repository_class = MagicMock() mock_repository_instance = MagicMock(spec=WorkflowExecutionRepository) mock_repository_class.return_value = mock_repository_instance - # Mock the validation methods - with ( - patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class), - patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"), - patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"), - ): + # Mock import_string + with patch("core.repositories.factory.import_string", return_value=mock_repository_class): result = DifyCoreRepositoryFactory.create_workflow_execution_repository( session_factory=mock_session_factory, user=mock_user, @@ -177,7 +87,7 @@ class TestRepositoryFactory: def test_create_workflow_execution_repository_import_error(self, mock_config): """Test WorkflowExecutionRepository creation with import error.""" # Setup mock configuration with invalid class path - mock_config.WORKFLOW_EXECUTION_REPOSITORY = "invalid.module.InvalidClass" + mock_config.CORE_WORKFLOW_EXECUTION_REPOSITORY = "invalid.module.InvalidClass" mock_session_factory = MagicMock(spec=sessionmaker) mock_user = MagicMock(spec=Account) @@ -189,52 +99,23 @@ class TestRepositoryFactory: app_id="test-app-id", triggered_from=WorkflowRunTriggeredFrom.APP_RUN, ) - assert "Cannot import repository class" in str(exc_info.value) + assert "Failed to create WorkflowExecutionRepository" in str(exc_info.value) @patch("core.repositories.factory.dify_config") - def test_create_workflow_execution_repository_validation_error(self, mock_config, mocker: MockerFixture): - """Test WorkflowExecutionRepository creation with validation error.""" - # Setup mock configuration - mock_config.WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock" - - mock_session_factory = MagicMock(spec=sessionmaker) - mock_user = MagicMock(spec=Account) - - # Mock import to succeed but validation to fail - mock_repository_class = MagicMock() - with ( - patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class), - patch.object( - DifyCoreRepositoryFactory, - "_validate_repository_interface", - side_effect=RepositoryImportError("Interface validation failed"), - ), - ): - with pytest.raises(RepositoryImportError) as exc_info: - DifyCoreRepositoryFactory.create_workflow_execution_repository( - session_factory=mock_session_factory, - user=mock_user, - app_id="test-app-id", - triggered_from=WorkflowRunTriggeredFrom.APP_RUN, - ) - assert "Interface validation failed" in str(exc_info.value) - - @patch("core.repositories.factory.dify_config") - def test_create_workflow_execution_repository_instantiation_error(self, mock_config, mocker: MockerFixture): + def test_create_workflow_execution_repository_instantiation_error(self, mock_config): """Test WorkflowExecutionRepository creation with instantiation error.""" # Setup mock configuration - mock_config.WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock" + mock_config.CORE_WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock" mock_session_factory = MagicMock(spec=sessionmaker) mock_user = MagicMock(spec=Account) - # Mock import and validation to succeed but instantiation to fail - mock_repository_class = MagicMock(side_effect=Exception("Instantiation failed")) - with ( - patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class), - patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"), - patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"), - ): + # Create a mock repository class that raises exception on instantiation + mock_repository_class = MagicMock() + mock_repository_class.side_effect = Exception("Instantiation failed") + + # Mock import_string to return a failing class + with patch("core.repositories.factory.import_string", return_value=mock_repository_class): with pytest.raises(RepositoryImportError) as exc_info: DifyCoreRepositoryFactory.create_workflow_execution_repository( session_factory=mock_session_factory, @@ -245,28 +126,24 @@ class TestRepositoryFactory: assert "Failed to create WorkflowExecutionRepository" in str(exc_info.value) @patch("core.repositories.factory.dify_config") - def test_create_workflow_node_execution_repository_success(self, mock_config, mocker: MockerFixture): - """Test successful creation of WorkflowNodeExecutionRepository.""" + def test_create_workflow_node_execution_repository_success(self, mock_config): + """Test successful WorkflowNodeExecutionRepository creation.""" # Setup mock configuration - mock_config.WORKFLOW_NODE_EXECUTION_REPOSITORY = "unittest.mock.MagicMock" + mock_config.CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY = "unittest.mock.MagicMock" # Create mock dependencies mock_session_factory = MagicMock(spec=sessionmaker) mock_user = MagicMock(spec=EndUser) app_id = "test-app-id" - triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN + triggered_from = WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP - # Mock the imported class to be a valid repository + # Create mock repository class and instance mock_repository_class = MagicMock() mock_repository_instance = MagicMock(spec=WorkflowNodeExecutionRepository) mock_repository_class.return_value = mock_repository_instance - # Mock the validation methods - with ( - patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class), - patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"), - patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"), - ): + # Mock import_string + with patch("core.repositories.factory.import_string", return_value=mock_repository_class): result = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( session_factory=mock_session_factory, user=mock_user, @@ -287,7 +164,7 @@ class TestRepositoryFactory: def test_create_workflow_node_execution_repository_import_error(self, mock_config): """Test WorkflowNodeExecutionRepository creation with import error.""" # Setup mock configuration with invalid class path - mock_config.WORKFLOW_NODE_EXECUTION_REPOSITORY = "invalid.module.InvalidClass" + mock_config.CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY = "invalid.module.InvalidClass" mock_session_factory = MagicMock(spec=sessionmaker) mock_user = MagicMock(spec=EndUser) @@ -297,159 +174,71 @@ class TestRepositoryFactory: session_factory=mock_session_factory, user=mock_user, app_id="test-app-id", - triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP, ) - assert "Cannot import repository class" in str(exc_info.value) - - def test_repository_import_error_exception(self): - """Test RepositoryImportError exception.""" - error_message = "Test error message" - exception = RepositoryImportError(error_message) - assert str(exception) == error_message - assert isinstance(exception, Exception) - - @patch("core.repositories.factory.dify_config") - def test_create_with_engine_instead_of_sessionmaker(self, mock_config, mocker: MockerFixture): - """Test repository creation with Engine instead of sessionmaker.""" - # Setup mock configuration - mock_config.WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock" - - # Create mock dependencies with Engine instead of sessionmaker - mock_engine = MagicMock(spec=Engine) - mock_user = MagicMock(spec=Account) - - # Mock the imported class to be a valid repository - mock_repository_class = MagicMock() - mock_repository_instance = MagicMock(spec=WorkflowExecutionRepository) - mock_repository_class.return_value = mock_repository_instance - - # Mock the validation methods - with ( - patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class), - patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"), - patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"), - ): - result = DifyCoreRepositoryFactory.create_workflow_execution_repository( - session_factory=mock_engine, # Using Engine instead of sessionmaker - user=mock_user, - app_id="test-app-id", - triggered_from=WorkflowRunTriggeredFrom.APP_RUN, - ) - - # Verify the repository was created with the Engine - mock_repository_class.assert_called_once_with( - session_factory=mock_engine, - user=mock_user, - app_id="test-app-id", - triggered_from=WorkflowRunTriggeredFrom.APP_RUN, - ) - assert result is mock_repository_instance - - @patch("core.repositories.factory.dify_config") - def test_create_workflow_node_execution_repository_validation_error(self, mock_config): - """Test WorkflowNodeExecutionRepository creation with validation error.""" - # Setup mock configuration - mock_config.WORKFLOW_NODE_EXECUTION_REPOSITORY = "unittest.mock.MagicMock" - - mock_session_factory = MagicMock(spec=sessionmaker) - mock_user = MagicMock(spec=EndUser) - - # Mock import to succeed but validation to fail - mock_repository_class = MagicMock() - with ( - patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class), - patch.object( - DifyCoreRepositoryFactory, - "_validate_repository_interface", - side_effect=RepositoryImportError("Interface validation failed"), - ), - ): - with pytest.raises(RepositoryImportError) as exc_info: - DifyCoreRepositoryFactory.create_workflow_node_execution_repository( - session_factory=mock_session_factory, - user=mock_user, - app_id="test-app-id", - triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, - ) - assert "Interface validation failed" in str(exc_info.value) + assert "Failed to create WorkflowNodeExecutionRepository" in str(exc_info.value) @patch("core.repositories.factory.dify_config") def test_create_workflow_node_execution_repository_instantiation_error(self, mock_config): """Test WorkflowNodeExecutionRepository creation with instantiation error.""" # Setup mock configuration - mock_config.WORKFLOW_NODE_EXECUTION_REPOSITORY = "unittest.mock.MagicMock" + mock_config.CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY = "unittest.mock.MagicMock" mock_session_factory = MagicMock(spec=sessionmaker) mock_user = MagicMock(spec=EndUser) - # Mock import and validation to succeed but instantiation to fail - mock_repository_class = MagicMock(side_effect=Exception("Instantiation failed")) - with ( - patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class), - patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"), - patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"), - ): + # Create a mock repository class that raises exception on instantiation + mock_repository_class = MagicMock() + mock_repository_class.side_effect = Exception("Instantiation failed") + + # Mock import_string to return a failing class + with patch("core.repositories.factory.import_string", return_value=mock_repository_class): with pytest.raises(RepositoryImportError) as exc_info: DifyCoreRepositoryFactory.create_workflow_node_execution_repository( session_factory=mock_session_factory, user=mock_user, app_id="test-app-id", - triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP, ) assert "Failed to create WorkflowNodeExecutionRepository" in str(exc_info.value) - def test_validate_repository_interface_with_private_methods(self): - """Test interface validation ignores private methods.""" + def test_repository_import_error_exception(self): + """Test RepositoryImportError exception handling.""" + error_message = "Custom error message" + error = RepositoryImportError(error_message) + assert str(error) == error_message - # Create a mock class with private methods - class MockRepository: - def save(self): - pass + @patch("core.repositories.factory.dify_config") + def test_create_with_engine_instead_of_sessionmaker(self, mock_config): + """Test repository creation with Engine instead of sessionmaker.""" + # Setup mock configuration + mock_config.CORE_WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock" - def get_by_id(self): - pass + # Create mock dependencies using Engine instead of sessionmaker + mock_engine = MagicMock(spec=Engine) + mock_user = MagicMock(spec=Account) + app_id = "test-app-id" + triggered_from = WorkflowRunTriggeredFrom.APP_RUN - def _private_method(self): - pass + # Create mock repository class and instance + mock_repository_class = MagicMock() + mock_repository_instance = MagicMock(spec=WorkflowExecutionRepository) + mock_repository_class.return_value = mock_repository_instance - # Create a mock interface with private methods - class MockInterface: - def save(self): - pass - - def get_by_id(self): - pass - - def _private_method(self): - pass - - # Should not raise an exception (private methods are ignored) - DifyCoreRepositoryFactory._validate_repository_interface(MockRepository, MockInterface) - - def test_validate_constructor_signature_with_extra_params(self): - """Test constructor validation with extra parameters (should pass).""" - - class MockRepository: - def __init__(self, session_factory, user, app_id, triggered_from, extra_param=None): - pass - - # Should not raise an exception (extra parameters are allowed) - DifyCoreRepositoryFactory._validate_constructor_signature( - MockRepository, ["session_factory", "user", "app_id", "triggered_from"] - ) - - def test_validate_constructor_signature_with_kwargs(self): - """Test constructor validation with **kwargs (current implementation doesn't support this).""" - - class MockRepository: - def __init__(self, session_factory, user, **kwargs): - pass - - # Current implementation doesn't handle **kwargs, so this should raise an exception - with pytest.raises(RepositoryImportError) as exc_info: - DifyCoreRepositoryFactory._validate_constructor_signature( - MockRepository, ["session_factory", "user", "app_id", "triggered_from"] + # Mock import_string + with patch("core.repositories.factory.import_string", return_value=mock_repository_class): + result = DifyCoreRepositoryFactory.create_workflow_execution_repository( + session_factory=mock_engine, # Using Engine instead of sessionmaker + user=mock_user, + app_id=app_id, + triggered_from=triggered_from, ) - assert "does not accept required parameters" in str(exc_info.value) - assert "app_id" in str(exc_info.value) - assert "triggered_from" in str(exc_info.value) + + # Verify the repository was created with correct parameters + mock_repository_class.assert_called_once_with( + session_factory=mock_engine, + user=mock_user, + app_id=app_id, + triggered_from=triggered_from, + ) + assert result is mock_repository_instance diff --git a/api/tests/unit_tests/core/tools/utils/test_encryption.py b/api/tests/unit_tests/core/tools/utils/test_encryption.py new file mode 100644 index 0000000000..6425ab0b8d --- /dev/null +++ b/api/tests/unit_tests/core/tools/utils/test_encryption.py @@ -0,0 +1,181 @@ +import copy +from unittest.mock import patch + +import pytest + +from core.entities.provider_entities import BasicProviderConfig +from core.tools.utils.encryption import ProviderConfigEncrypter + + +# --------------------------- +# A no-op cache +# --------------------------- +class NoopCache: + """Simple cache stub: always returns None, does nothing for set/delete.""" + + def get(self): + return None + + def set(self, config): + pass + + def delete(self): + pass + + +@pytest.fixture +def secret_field() -> BasicProviderConfig: + """A SECRET_INPUT field named 'password'.""" + return BasicProviderConfig( + name="password", + type=BasicProviderConfig.Type.SECRET_INPUT, + ) + + +@pytest.fixture +def normal_field() -> BasicProviderConfig: + """A TEXT_INPUT field named 'username'.""" + return BasicProviderConfig( + name="username", + type=BasicProviderConfig.Type.TEXT_INPUT, + ) + + +@pytest.fixture +def encrypter_obj(secret_field, normal_field): + """ + Build ProviderConfigEncrypter with: + - tenant_id = tenant123 + - one secret field (password) and one normal field (username) + - NoopCache as cache + """ + return ProviderConfigEncrypter( + tenant_id="tenant123", + config=[secret_field, normal_field], + provider_config_cache=NoopCache(), + ) + + +# ============================================================ +# ProviderConfigEncrypter.encrypt() +# ============================================================ + + +def test_encrypt_only_secret_is_encrypted_and_non_secret_unchanged(encrypter_obj): + """ + Secret field should be encrypted, non-secret field unchanged. + Verify encrypt_token called only for secret field. + Also check deep copy (input not modified). + """ + data_in = {"username": "alice", "password": "plain_pwd"} + data_copy = copy.deepcopy(data_in) + + with patch("core.tools.utils.encryption.encrypter.encrypt_token", return_value="CIPHERTEXT") as mock_encrypt: + out = encrypter_obj.encrypt(data_in) + + assert out["username"] == "alice" + assert out["password"] == "CIPHERTEXT" + mock_encrypt.assert_called_once_with("tenant123", "plain_pwd") + assert data_in == data_copy # deep copy semantics + + +def test_encrypt_missing_secret_key_is_ok(encrypter_obj): + """If secret field missing in input, no error and no encryption called.""" + with patch("core.tools.utils.encryption.encrypter.encrypt_token") as mock_encrypt: + out = encrypter_obj.encrypt({"username": "alice"}) + assert out["username"] == "alice" + mock_encrypt.assert_not_called() + + +# ============================================================ +# ProviderConfigEncrypter.mask_tool_credentials() +# ============================================================ + + +@pytest.mark.parametrize( + ("raw", "prefix", "suffix"), + [ + ("longsecret", "lo", "et"), + ("abcdefg", "ab", "fg"), + ("1234567", "12", "67"), + ], +) +def test_mask_tool_credentials_long_secret(encrypter_obj, raw, prefix, suffix): + """ + For length > 6: keep first 2 and last 2, mask middle with '*'. + """ + data_in = {"username": "alice", "password": raw} + data_copy = copy.deepcopy(data_in) + + out = encrypter_obj.mask_tool_credentials(data_in) + masked = out["password"] + + assert masked.startswith(prefix) + assert masked.endswith(suffix) + assert "*" in masked + assert len(masked) == len(raw) + assert data_in == data_copy # deep copy semantics + + +@pytest.mark.parametrize("raw", ["", "1", "12", "123", "123456"]) +def test_mask_tool_credentials_short_secret(encrypter_obj, raw): + """ + For length <= 6: fully mask with '*' of same length. + """ + out = encrypter_obj.mask_tool_credentials({"password": raw}) + assert out["password"] == ("*" * len(raw)) + + +def test_mask_tool_credentials_missing_key_noop(encrypter_obj): + """If secret key missing, leave other fields unchanged.""" + data_in = {"username": "alice"} + data_copy = copy.deepcopy(data_in) + + out = encrypter_obj.mask_tool_credentials(data_in) + assert out["username"] == "alice" + assert data_in == data_copy + + +# ============================================================ +# ProviderConfigEncrypter.decrypt() +# ============================================================ + + +def test_decrypt_normal_flow(encrypter_obj): + """ + Normal decrypt flow: + - decrypt_token called for secret field + - secret replaced with decrypted value + - non-secret unchanged + """ + data_in = {"username": "alice", "password": "ENC"} + data_copy = copy.deepcopy(data_in) + + with patch("core.tools.utils.encryption.encrypter.decrypt_token", return_value="PLAIN") as mock_decrypt: + out = encrypter_obj.decrypt(data_in) + + assert out["username"] == "alice" + assert out["password"] == "PLAIN" + mock_decrypt.assert_called_once_with("tenant123", "ENC") + assert data_in == data_copy # deep copy semantics + + +@pytest.mark.parametrize("empty_val", ["", None]) +def test_decrypt_skip_empty_values(encrypter_obj, empty_val): + """Skip decrypt if value is empty or None, keep original.""" + with patch("core.tools.utils.encryption.encrypter.decrypt_token") as mock_decrypt: + out = encrypter_obj.decrypt({"password": empty_val}) + + mock_decrypt.assert_not_called() + assert out["password"] == empty_val + + +def test_decrypt_swallow_exception_and_keep_original(encrypter_obj): + """ + If decrypt_token raises, exception should be swallowed, + and original value preserved. + """ + with patch("core.tools.utils.encryption.encrypter.decrypt_token", side_effect=Exception("boom")): + out = encrypter_obj.decrypt({"password": "ENC_ERR"}) + + assert out["password"] == "ENC_ERR" diff --git a/api/tests/unit_tests/core/tools/utils/test_parser.py b/api/tests/unit_tests/core/tools/utils/test_parser.py index 8e07293ce0..e1eab21ca4 100644 --- a/api/tests/unit_tests/core/tools/utils/test_parser.py +++ b/api/tests/unit_tests/core/tools/utils/test_parser.py @@ -54,3 +54,58 @@ def test_parse_openapi_to_tool_bundle_operation_id(app): assert tool_bundles[0].operation_id == "_get" assert tool_bundles[1].operation_id == "apiresources_get" assert tool_bundles[2].operation_id == "createResource" + + +def test_parse_openapi_to_tool_bundle_properties_all_of(app): + openapi = { + "openapi": "3.0.0", + "info": {"title": "Simple API", "version": "1.0.0"}, + "servers": [{"url": "http://localhost:3000"}], + "paths": { + "/api/resource": { + "get": { + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/Request", + }, + }, + }, + "required": True, + }, + }, + }, + }, + "components": { + "schemas": { + "Request": { + "type": "object", + "properties": { + "prop1": { + "enum": ["option1"], + "description": "desc prop1", + "allOf": [ + {"$ref": "#/components/schemas/AllOfItem"}, + { + "enum": ["option2"], + }, + ], + }, + }, + }, + "AllOfItem": { + "type": "string", + "enum": ["option3"], + "description": "desc allOf item", + }, + } + }, + } + with app.test_request_context(): + tool_bundles = ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi) + + assert tool_bundles[0].parameters[0].type == "string" + assert tool_bundles[0].parameters[0].llm_description == "desc prop1" + # TODO: support enum in OpenAPI + # assert set(tool_bundles[0].parameters[0].options) == {"option1", "option2", "option3"} diff --git a/api/tests/unit_tests/core/tools/utils/test_tool_engine_serialization.py b/api/tests/unit_tests/core/tools/utils/test_tool_engine_serialization.py new file mode 100644 index 0000000000..4029edfb68 --- /dev/null +++ b/api/tests/unit_tests/core/tools/utils/test_tool_engine_serialization.py @@ -0,0 +1,481 @@ +import json +from datetime import date, datetime +from decimal import Decimal +from uuid import uuid4 + +import numpy as np +import pytest +import pytz + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.utils.message_transformer import ToolFileMessageTransformer, safe_json_dict, safe_json_value + + +class TestSafeJsonValue: + """Test suite for safe_json_value function to ensure proper serialization of complex types""" + + def test_datetime_conversion(self): + """Test datetime conversion with timezone handling""" + # Test datetime with UTC timezone + dt = datetime(2024, 1, 1, 12, 0, 0, tzinfo=pytz.UTC) + result = safe_json_value(dt) + assert isinstance(result, str) + assert "2024-01-01T12:00:00+00:00" in result + + # Test datetime without timezone (should default to UTC) + dt_no_tz = datetime(2024, 1, 1, 12, 0, 0) + result = safe_json_value(dt_no_tz) + assert isinstance(result, str) + # The exact time will depend on the system's timezone, so we check the format + assert "T" in result # ISO format separator + # Check that it's a valid ISO format datetime string + assert len(result) >= 19 # At least YYYY-MM-DDTHH:MM:SS + + def test_date_conversion(self): + """Test date conversion to ISO format""" + test_date = date(2024, 1, 1) + result = safe_json_value(test_date) + assert result == "2024-01-01" + + def test_uuid_conversion(self): + """Test UUID conversion to string""" + test_uuid = uuid4() + result = safe_json_value(test_uuid) + assert isinstance(result, str) + assert result == str(test_uuid) + + def test_decimal_conversion(self): + """Test Decimal conversion to float""" + test_decimal = Decimal("123.456") + result = safe_json_value(test_decimal) + assert result == 123.456 + assert isinstance(result, float) + + def test_bytes_conversion(self): + """Test bytes conversion with UTF-8 decoding""" + # Test valid UTF-8 bytes + test_bytes = b"Hello, World!" + result = safe_json_value(test_bytes) + assert result == "Hello, World!" + + # Test invalid UTF-8 bytes (should fall back to hex) + invalid_bytes = b"\xff\xfe\xfd" + result = safe_json_value(invalid_bytes) + assert result == "fffefd" + + def test_memoryview_conversion(self): + """Test memoryview conversion to hex string""" + test_bytes = b"test data" + test_memoryview = memoryview(test_bytes) + result = safe_json_value(test_memoryview) + assert result == "746573742064617461" # hex of "test data" + + def test_numpy_ndarray_conversion(self): + """Test numpy ndarray conversion to list""" + # Test 1D array + test_array = np.array([1, 2, 3, 4]) + result = safe_json_value(test_array) + assert result == [1, 2, 3, 4] + + # Test 2D array + test_2d_array = np.array([[1, 2], [3, 4]]) + result = safe_json_value(test_2d_array) + assert result == [[1, 2], [3, 4]] + + # Test array with float values + test_float_array = np.array([1.5, 2.7, 3.14]) + result = safe_json_value(test_float_array) + assert result == [1.5, 2.7, 3.14] + + def test_dict_conversion(self): + """Test dictionary conversion using safe_json_dict""" + test_dict = { + "string": "value", + "number": 42, + "float": 3.14, + "boolean": True, + "list": [1, 2, 3], + "nested": {"key": "value"}, + } + result = safe_json_value(test_dict) + assert isinstance(result, dict) + assert result == test_dict + + def test_list_conversion(self): + """Test list conversion with mixed types""" + test_list = [ + "string", + 42, + 3.14, + True, + [1, 2, 3], + {"key": "value"}, + datetime(2024, 1, 1, 12, 0, 0, tzinfo=pytz.UTC), + Decimal("123.456"), + uuid4(), + ] + result = safe_json_value(test_list) + assert isinstance(result, list) + assert len(result) == len(test_list) + assert isinstance(result[6], str) # datetime should be converted to string + assert isinstance(result[7], float) # Decimal should be converted to float + assert isinstance(result[8], str) # UUID should be converted to string + + def test_tuple_conversion(self): + """Test tuple conversion to list""" + test_tuple = (1, "string", 3.14) + result = safe_json_value(test_tuple) + assert isinstance(result, list) + assert result == [1, "string", 3.14] + + def test_set_conversion(self): + """Test set conversion to list""" + test_set = {1, "string", 3.14} + result = safe_json_value(test_set) + assert isinstance(result, list) + # Note: set order is not guaranteed, so we check length and content + assert len(result) == 3 + assert 1 in result + assert "string" in result + assert 3.14 in result + + def test_basic_types_passthrough(self): + """Test that basic types are passed through unchanged""" + assert safe_json_value("string") == "string" + assert safe_json_value(42) == 42 + assert safe_json_value(3.14) == 3.14 + assert safe_json_value(True) is True + assert safe_json_value(False) is False + assert safe_json_value(None) is None + + def test_nested_complex_structure(self): + """Test complex nested structure with all types""" + complex_data = { + "dates": [date(2024, 1, 1), date(2024, 1, 2)], + "timestamps": [ + datetime(2024, 1, 1, 12, 0, 0, tzinfo=pytz.UTC), + datetime(2024, 1, 2, 12, 0, 0, tzinfo=pytz.UTC), + ], + "numbers": [Decimal("123.456"), Decimal("789.012")], + "identifiers": [uuid4(), uuid4()], + "binary_data": [b"hello", b"world"], + "arrays": [np.array([1, 2, 3]), np.array([4, 5, 6])], + } + + result = safe_json_value(complex_data) + + # Verify structure is maintained + assert isinstance(result, dict) + assert "dates" in result + assert "timestamps" in result + assert "numbers" in result + assert "identifiers" in result + assert "binary_data" in result + assert "arrays" in result + + # Verify conversions + assert all(isinstance(d, str) for d in result["dates"]) + assert all(isinstance(t, str) for t in result["timestamps"]) + assert all(isinstance(n, float) for n in result["numbers"]) + assert all(isinstance(i, str) for i in result["identifiers"]) + assert all(isinstance(b, str) for b in result["binary_data"]) + assert all(isinstance(a, list) for a in result["arrays"]) + + +class TestSafeJsonDict: + """Test suite for safe_json_dict function""" + + def test_valid_dict_conversion(self): + """Test valid dictionary conversion""" + test_dict = { + "string": "value", + "number": 42, + "datetime": datetime(2024, 1, 1, 12, 0, 0, tzinfo=pytz.UTC), + "decimal": Decimal("123.456"), + } + result = safe_json_dict(test_dict) + assert isinstance(result, dict) + assert result["string"] == "value" + assert result["number"] == 42 + assert isinstance(result["datetime"], str) + assert isinstance(result["decimal"], float) + + def test_invalid_input_type(self): + """Test that invalid input types raise TypeError""" + with pytest.raises(TypeError, match="safe_json_dict\\(\\) expects a dictionary \\(dict\\) as input"): + safe_json_dict("not a dict") + + with pytest.raises(TypeError, match="safe_json_dict\\(\\) expects a dictionary \\(dict\\) as input"): + safe_json_dict([1, 2, 3]) + + with pytest.raises(TypeError, match="safe_json_dict\\(\\) expects a dictionary \\(dict\\) as input"): + safe_json_dict(42) + + def test_empty_dict(self): + """Test empty dictionary handling""" + result = safe_json_dict({}) + assert result == {} + + def test_nested_dict_conversion(self): + """Test nested dictionary conversion""" + test_dict = { + "level1": { + "level2": {"datetime": datetime(2024, 1, 1, 12, 0, 0, tzinfo=pytz.UTC), "decimal": Decimal("123.456")} + } + } + result = safe_json_dict(test_dict) + assert isinstance(result["level1"]["level2"]["datetime"], str) + assert isinstance(result["level1"]["level2"]["decimal"], float) + + +class TestToolInvokeMessageJsonSerialization: + """Test suite for ToolInvokeMessage JSON serialization through safe_json_value""" + + def test_json_message_serialization(self): + """Test JSON message serialization with complex data""" + complex_data = { + "timestamp": datetime(2024, 1, 1, 12, 0, 0, tzinfo=pytz.UTC), + "amount": Decimal("123.45"), + "id": uuid4(), + "binary": b"test data", + "array": np.array([1, 2, 3]), + } + + # Create JSON message + json_message = ToolInvokeMessage.JsonMessage(json_object=complex_data) + message = ToolInvokeMessage(type=ToolInvokeMessage.MessageType.JSON, message=json_message) + + # Apply safe_json_value transformation + transformed_data = safe_json_value(message.message.json_object) + + # Verify transformations + assert isinstance(transformed_data["timestamp"], str) + assert isinstance(transformed_data["amount"], float) + assert isinstance(transformed_data["id"], str) + assert isinstance(transformed_data["binary"], str) + assert isinstance(transformed_data["array"], list) + + # Verify JSON serialization works + json_string = json.dumps(transformed_data, ensure_ascii=False) + assert isinstance(json_string, str) + + # Verify we can deserialize back + deserialized = json.loads(json_string) + assert deserialized["amount"] == 123.45 + assert deserialized["array"] == [1, 2, 3] + + def test_json_message_with_nested_structures(self): + """Test JSON message with deeply nested complex structures""" + nested_data = { + "level1": { + "level2": { + "level3": { + "dates": [date(2024, 1, 1), date(2024, 1, 2)], + "timestamps": [datetime(2024, 1, 1, 12, 0, 0, tzinfo=pytz.UTC)], + "numbers": [Decimal("1.1"), Decimal("2.2")], + "arrays": [np.array([1, 2]), np.array([3, 4])], + } + } + } + } + + json_message = ToolInvokeMessage.JsonMessage(json_object=nested_data) + message = ToolInvokeMessage(type=ToolInvokeMessage.MessageType.JSON, message=json_message) + + # Transform the data + transformed_data = safe_json_value(message.message.json_object) + + # Verify nested transformations + level3 = transformed_data["level1"]["level2"]["level3"] + assert all(isinstance(d, str) for d in level3["dates"]) + assert all(isinstance(t, str) for t in level3["timestamps"]) + assert all(isinstance(n, float) for n in level3["numbers"]) + assert all(isinstance(a, list) for a in level3["arrays"]) + + # Test JSON serialization + json_string = json.dumps(transformed_data, ensure_ascii=False) + assert isinstance(json_string, str) + + # Verify deserialization + deserialized = json.loads(json_string) + assert deserialized["level1"]["level2"]["level3"]["numbers"] == [1.1, 2.2] + + def test_json_message_transformer_integration(self): + """Test integration with ToolFileMessageTransformer for JSON messages""" + complex_data = { + "metadata": { + "created_at": datetime(2024, 1, 1, 12, 0, 0, tzinfo=pytz.UTC), + "version": Decimal("1.0"), + "tags": ["tag1", "tag2"], + }, + "data": {"values": np.array([1.1, 2.2, 3.3]), "binary": b"binary content"}, + } + + # Create message generator + def message_generator(): + json_message = ToolInvokeMessage.JsonMessage(json_object=complex_data) + message = ToolInvokeMessage(type=ToolInvokeMessage.MessageType.JSON, message=json_message) + yield message + + # Transform messages + transformed_messages = list( + ToolFileMessageTransformer.transform_tool_invoke_messages( + message_generator(), user_id="test_user", tenant_id="test_tenant" + ) + ) + + assert len(transformed_messages) == 1 + transformed_message = transformed_messages[0] + assert transformed_message.type == ToolInvokeMessage.MessageType.JSON + + # Verify the JSON object was transformed + json_obj = transformed_message.message.json_object + assert isinstance(json_obj["metadata"]["created_at"], str) + assert isinstance(json_obj["metadata"]["version"], float) + assert isinstance(json_obj["data"]["values"], list) + assert isinstance(json_obj["data"]["binary"], str) + + # Test final JSON serialization + final_json = json.dumps(json_obj, ensure_ascii=False) + assert isinstance(final_json, str) + + # Verify we can deserialize + deserialized = json.loads(final_json) + assert deserialized["metadata"]["version"] == 1.0 + assert deserialized["data"]["values"] == [1.1, 2.2, 3.3] + + def test_edge_cases_and_error_handling(self): + """Test edge cases and error handling in JSON serialization""" + # Test with None values + data_with_none = {"null_value": None, "empty_string": "", "zero": 0, "false_value": False} + + json_message = ToolInvokeMessage.JsonMessage(json_object=data_with_none) + message = ToolInvokeMessage(type=ToolInvokeMessage.MessageType.JSON, message=json_message) + + transformed_data = safe_json_value(message.message.json_object) + json_string = json.dumps(transformed_data, ensure_ascii=False) + + # Verify serialization works with edge cases + assert json_string is not None + deserialized = json.loads(json_string) + assert deserialized["null_value"] is None + assert deserialized["empty_string"] == "" + assert deserialized["zero"] == 0 + assert deserialized["false_value"] is False + + # Test with very large numbers + large_data = { + "large_int": 2**63 - 1, + "large_float": 1.7976931348623157e308, + "small_float": 2.2250738585072014e-308, + } + + json_message = ToolInvokeMessage.JsonMessage(json_object=large_data) + message = ToolInvokeMessage(type=ToolInvokeMessage.MessageType.JSON, message=json_message) + + transformed_data = safe_json_value(message.message.json_object) + json_string = json.dumps(transformed_data, ensure_ascii=False) + + # Verify large numbers are handled correctly + deserialized = json.loads(json_string) + assert deserialized["large_int"] == 2**63 - 1 + assert deserialized["large_float"] == 1.7976931348623157e308 + assert deserialized["small_float"] == 2.2250738585072014e-308 + + +class TestEndToEndSerialization: + """Test suite for end-to-end serialization workflow""" + + def test_complete_workflow_with_real_data(self): + """Test complete workflow from complex data to JSON string and back""" + # Simulate real-world complex data structure + real_world_data = { + "user_profile": { + "id": uuid4(), + "name": "John Doe", + "email": "john@example.com", + "created_at": datetime(2024, 1, 1, 12, 0, 0, tzinfo=pytz.UTC), + "last_login": datetime(2024, 1, 15, 14, 30, 0, tzinfo=pytz.UTC), + "preferences": {"theme": "dark", "language": "en", "timezone": "UTC"}, + }, + "analytics": { + "session_count": 42, + "total_time": Decimal("123.45"), + "metrics": np.array([1.1, 2.2, 3.3, 4.4, 5.5]), + "events": [ + { + "timestamp": datetime(2024, 1, 1, 10, 0, 0, tzinfo=pytz.UTC), + "action": "login", + "duration": Decimal("5.67"), + }, + { + "timestamp": datetime(2024, 1, 1, 11, 0, 0, tzinfo=pytz.UTC), + "action": "logout", + "duration": Decimal("3600.0"), + }, + ], + }, + "files": [ + { + "id": uuid4(), + "name": "document.pdf", + "size": 1024, + "uploaded_at": datetime(2024, 1, 1, 9, 0, 0, tzinfo=pytz.UTC), + "checksum": b"abc123def456", + } + ], + } + + # Step 1: Create ToolInvokeMessage + json_message = ToolInvokeMessage.JsonMessage(json_object=real_world_data) + message = ToolInvokeMessage(type=ToolInvokeMessage.MessageType.JSON, message=json_message) + + # Step 2: Apply safe_json_value transformation + transformed_data = safe_json_value(message.message.json_object) + + # Step 3: Serialize to JSON string + json_string = json.dumps(transformed_data, ensure_ascii=False) + + # Step 4: Verify the string is valid JSON + assert isinstance(json_string, str) + assert json_string.startswith("{") + assert json_string.endswith("}") + + # Step 5: Deserialize back to Python object + deserialized_data = json.loads(json_string) + + # Step 6: Verify data integrity + assert deserialized_data["user_profile"]["name"] == "John Doe" + assert deserialized_data["user_profile"]["email"] == "john@example.com" + assert isinstance(deserialized_data["user_profile"]["created_at"], str) + assert isinstance(deserialized_data["analytics"]["total_time"], float) + assert deserialized_data["analytics"]["total_time"] == 123.45 + assert isinstance(deserialized_data["analytics"]["metrics"], list) + assert deserialized_data["analytics"]["metrics"] == [1.1, 2.2, 3.3, 4.4, 5.5] + assert isinstance(deserialized_data["files"][0]["checksum"], str) + + # Step 7: Verify all complex types were properly converted + self._verify_all_complex_types_converted(deserialized_data) + + def _verify_all_complex_types_converted(self, data): + """Helper method to verify all complex types were properly converted""" + if isinstance(data, dict): + for key, value in data.items(): + if key in ["id", "checksum"]: + # These should be strings (UUID/bytes converted) + assert isinstance(value, str) + elif key in ["created_at", "last_login", "timestamp", "uploaded_at"]: + # These should be strings (datetime converted) + assert isinstance(value, str) + elif key in ["total_time", "duration"]: + # These should be floats (Decimal converted) + assert isinstance(value, float) + elif key == "metrics": + # This should be a list (ndarray converted) + assert isinstance(value, list) + else: + # Recursively check nested structures + self._verify_all_complex_types_converted(value) + elif isinstance(data, list): + for item in data: + self._verify_all_complex_types_converted(item) diff --git a/api/tests/unit_tests/core/tools/utils/test_web_reader_tool.py b/api/tests/unit_tests/core/tools/utils/test_web_reader_tool.py new file mode 100644 index 0000000000..20f753786d --- /dev/null +++ b/api/tests/unit_tests/core/tools/utils/test_web_reader_tool.py @@ -0,0 +1,312 @@ +import pytest + +from core.tools.utils.web_reader_tool import ( + extract_using_readabilipy, + get_image_upload_file_ids, + get_url, + page_result, +) + + +class FakeResponse: + """Minimal fake response object for ssrf_proxy / cloudscraper.""" + + def __init__(self, *, status_code=200, headers=None, content=b"", text=""): + self.status_code = status_code + self.headers = headers or {} + self.content = content + self.text = text if text else content.decode("utf-8", errors="ignore") + + +# --------------------------- +# Tests: page_result +# --------------------------- +@pytest.mark.parametrize( + ("text", "cursor", "maxlen", "expected"), + [ + ("abcdef", 0, 3, "abc"), + ("abcdef", 2, 10, "cdef"), # maxlen beyond end + ("abcdef", 6, 5, ""), # cursor at end + ("abcdef", 7, 5, ""), # cursor beyond end + ("", 0, 5, ""), # empty text + ], +) +def test_page_result(text, cursor, maxlen, expected): + assert page_result(text, cursor, maxlen) == expected + + +# --------------------------- +# Tests: get_url +# --------------------------- +@pytest.fixture +def stub_support_types(monkeypatch): + """Stub supported content types list.""" + import core.tools.utils.web_reader_tool as mod + + # e.g. binary types supported by ExtractProcessor + monkeypatch.setattr(mod.extract_processor, "SUPPORT_URL_CONTENT_TYPES", ["application/pdf", "text/plain"]) + return mod + + +def test_get_url_unsupported_content_type(monkeypatch, stub_support_types): + # HEAD 200 but content-type not supported and not text/html + def fake_head(url, headers=None, follow_redirects=True, timeout=None): + return FakeResponse( + status_code=200, + headers={"Content-Type": "image/png"}, # not supported + ) + + monkeypatch.setattr(stub_support_types.ssrf_proxy, "head", fake_head) + + result = get_url("https://x.test/file.png") + assert result == "Unsupported content-type [image/png] of URL." + + +def test_get_url_supported_binary_type_uses_extract_processor(monkeypatch, stub_support_types): + """ + When content-type is in SUPPORT_URL_CONTENT_TYPES, + should call ExtractProcessor.load_from_url and return its text. + """ + calls = {"load": 0} + + def fake_head(url, headers=None, follow_redirects=True, timeout=None): + return FakeResponse( + status_code=200, + headers={"Content-Type": "application/pdf"}, + ) + + def fake_load_from_url(url, return_text=False): + calls["load"] += 1 + assert return_text is True + return "PDF extracted text" + + monkeypatch.setattr(stub_support_types.ssrf_proxy, "head", fake_head) + monkeypatch.setattr(stub_support_types.ExtractProcessor, "load_from_url", staticmethod(fake_load_from_url)) + + result = get_url("https://x.test/doc.pdf") + assert calls["load"] == 1 + assert result == "PDF extracted text" + + +def test_get_url_html_flow_with_chardet_and_readability(monkeypatch, stub_support_types): + """200 + text/html → GET, chardet detects encoding, readability returns article which is templated.""" + + def fake_head(url, headers=None, follow_redirects=True, timeout=None): + return FakeResponse(status_code=200, headers={"Content-Type": "text/html"}) + + def fake_get(url, headers=None, follow_redirects=True, timeout=None): + html = b"xhello" + return FakeResponse(status_code=200, headers={"Content-Type": "text/html"}, content=html) + + # chardet.detect returns utf-8 + import core.tools.utils.web_reader_tool as mod + + monkeypatch.setattr(mod.ssrf_proxy, "head", fake_head) + monkeypatch.setattr(mod.ssrf_proxy, "get", fake_get) + monkeypatch.setattr(mod.chardet, "detect", lambda b: {"encoding": "utf-8"}) + + # readability → a dict that maps to Article, then FULL_TEMPLATE + def fake_simple_json_from_html_string(html, use_readability=True): + return { + "title": "My Title", + "byline": "Bob", + "plain_text": [{"type": "text", "text": "Hello world"}], + } + + monkeypatch.setattr(mod, "simple_json_from_html_string", fake_simple_json_from_html_string) + + out = get_url("https://x.test/page") + assert "TITLE: My Title" in out + assert "AUTHOR: Bob" in out + assert "Hello world" in out + + +def test_get_url_html_flow_empty_article_text_returns_empty(monkeypatch, stub_support_types): + """If readability returns no text, should return empty string.""" + + def fake_head(url, headers=None, follow_redirects=True, timeout=None): + return FakeResponse(status_code=200, headers={"Content-Type": "text/html"}) + + def fake_get(url, headers=None, follow_redirects=True, timeout=None): + return FakeResponse(status_code=200, headers={"Content-Type": "text/html"}, content=b"") + + import core.tools.utils.web_reader_tool as mod + + monkeypatch.setattr(mod.ssrf_proxy, "head", fake_head) + monkeypatch.setattr(mod.ssrf_proxy, "get", fake_get) + monkeypatch.setattr(mod.chardet, "detect", lambda b: {"encoding": "utf-8"}) + # readability returns empty plain_text + monkeypatch.setattr(mod, "simple_json_from_html_string", lambda html, use_readability=True: {"plain_text": []}) + + out = get_url("https://x.test/empty") + assert out == "" + + +def test_get_url_403_cloudscraper_fallback(monkeypatch, stub_support_types): + """HEAD 403 → use cloudscraper.get via ssrf_proxy.make_request, then proceed.""" + + def fake_head(url, headers=None, follow_redirects=True, timeout=None): + return FakeResponse(status_code=403, headers={}) + + # cloudscraper.create_scraper() → object with .get() + class FakeScraper: + def __init__(self): + pass # removed unused attribute + + def get(self, url, headers=None, follow_redirects=True, timeout=None): + # mimic html 200 + html = b"hi" + return FakeResponse(status_code=200, headers={"Content-Type": "text/html"}, content=html) + + import core.tools.utils.web_reader_tool as mod + + monkeypatch.setattr(mod.ssrf_proxy, "head", fake_head) + monkeypatch.setattr(mod.cloudscraper, "create_scraper", lambda: FakeScraper()) + monkeypatch.setattr(mod.chardet, "detect", lambda b: {"encoding": "utf-8"}) + monkeypatch.setattr( + mod, + "simple_json_from_html_string", + lambda html, use_readability=True: {"title": "T", "byline": "A", "plain_text": [{"type": "text", "text": "X"}]}, + ) + + out = get_url("https://x.test/403") + assert "TITLE: T" in out + assert "AUTHOR: A" in out + assert "X" in out + + +def test_get_url_head_non_200_returns_status(monkeypatch, stub_support_types): + """HEAD returns non-200 and non-403 → should directly return code message.""" + + def fake_head(url, headers=None, follow_redirects=True, timeout=None): + return FakeResponse(status_code=500) + + import core.tools.utils.web_reader_tool as mod + + monkeypatch.setattr(mod.ssrf_proxy, "head", fake_head) + + out = get_url("https://x.test/fail") + assert out == "URL returned status code 500." + + +def test_get_url_content_disposition_filename_detection(monkeypatch, stub_support_types): + """ + If HEAD 200 with no Content-Type but Content-Disposition filename suggests a supported type, + it should route to ExtractProcessor.load_from_url. + """ + calls = {"load": 0} + + def fake_head(url, headers=None, follow_redirects=True, timeout=None): + return FakeResponse(status_code=200, headers={"Content-Disposition": 'attachment; filename="doc.pdf"'}) + + def fake_load_from_url(url, return_text=False): + calls["load"] += 1 + return "From ExtractProcessor via filename" + + import core.tools.utils.web_reader_tool as mod + + monkeypatch.setattr(mod.ssrf_proxy, "head", fake_head) + monkeypatch.setattr(mod.ExtractProcessor, "load_from_url", staticmethod(fake_load_from_url)) + + out = get_url("https://x.test/fname") + assert calls["load"] == 1 + assert out == "From ExtractProcessor via filename" + + +def test_get_url_html_encoding_fallback_when_decode_fails(monkeypatch, stub_support_types): + """ + If chardet returns an encoding but content.decode raises, should fallback to response.text. + """ + + def fake_head(url, headers=None, follow_redirects=True, timeout=None): + return FakeResponse(status_code=200, headers={"Content-Type": "text/html"}) + + # Return bytes that will raise with the chosen encoding + def fake_get(url, headers=None, follow_redirects=True, timeout=None): + return FakeResponse( + status_code=200, + headers={"Content-Type": "text/html"}, + content=b"\xff\xfe\xfa", # likely to fail under utf-8 + text="fallback text", + ) + + import core.tools.utils.web_reader_tool as mod + + monkeypatch.setattr(mod.ssrf_proxy, "head", fake_head) + monkeypatch.setattr(mod.ssrf_proxy, "get", fake_get) + monkeypatch.setattr(mod.chardet, "detect", lambda b: {"encoding": "utf-8"}) + monkeypatch.setattr( + mod, + "simple_json_from_html_string", + lambda html, use_readability=True: {"title": "", "byline": "", "plain_text": [{"type": "text", "text": "ok"}]}, + ) + + out = get_url("https://x.test/enc-fallback") + assert "ok" in out + + +# --------------------------- +# Tests: extract_using_readabilipy +# --------------------------- + + +def test_extract_using_readabilipy_field_mapping_and_defaults(monkeypatch): + # stub readabilipy.simple_json_from_html_string + def fake_simple_json_from_html_string(html, use_readability=True): + return { + "title": "Hello", + "byline": "Alice", + "plain_text": [{"type": "text", "text": "world"}], + } + + import core.tools.utils.web_reader_tool as mod + + monkeypatch.setattr(mod, "simple_json_from_html_string", fake_simple_json_from_html_string) + + article = extract_using_readabilipy("...") + assert article.title == "Hello" + assert article.author == "Alice" + assert isinstance(article.text, list) + assert article.text + assert article.text[0]["text"] == "world" + + +def test_extract_using_readabilipy_defaults_when_missing(monkeypatch): + def fake_simple_json_from_html_string(html, use_readability=True): + return {} # all missing + + import core.tools.utils.web_reader_tool as mod + + monkeypatch.setattr(mod, "simple_json_from_html_string", fake_simple_json_from_html_string) + + article = extract_using_readabilipy("...") + assert article.title == "" + assert article.author == "" + assert article.text == [] + + +# --------------------------- +# Tests: get_image_upload_file_ids +# --------------------------- +def test_get_image_upload_file_ids(): + # should extract id from https + file-preview + content = "![image](https://example.com/a/b/files/abc123/file-preview)" + assert get_image_upload_file_ids(content) == ["abc123"] + + # should extract id from http + image-preview + content = "![image](http://host/files/xyz789/image-preview)" + assert get_image_upload_file_ids(content) == ["xyz789"] + + # should not match invalid scheme 'htt://' + content = "![image](htt://host/files/bad/file-preview)" + assert get_image_upload_file_ids(content) == [] + + # should extract multiple ids in order + content = """ + some text + ![image](https://h/files/id1/file-preview) + middle + ![image](http://h/files/id2/image-preview) + end + """ + assert get_image_upload_file_ids(content) == ["id1", "id2"] diff --git a/api/tests/unit_tests/core/variables/test_segment_type.py b/api/tests/unit_tests/core/variables/test_segment_type.py index 64d0d8c7e7..b33a83ba77 100644 --- a/api/tests/unit_tests/core/variables/test_segment_type.py +++ b/api/tests/unit_tests/core/variables/test_segment_type.py @@ -1,4 +1,4 @@ -from core.variables.types import SegmentType +from core.variables.types import ArrayValidation, SegmentType class TestSegmentTypeIsArrayType: @@ -17,7 +17,6 @@ class TestSegmentTypeIsArrayType: value is tested for the is_array_type method. """ # Arrange - all_segment_types = set(SegmentType) expected_array_types = [ SegmentType.ARRAY_ANY, SegmentType.ARRAY_STRING, @@ -58,3 +57,27 @@ class TestSegmentTypeIsArrayType: for seg_type in enum_values: is_array = seg_type.is_array_type() assert isinstance(is_array, bool), f"is_array_type does not return a boolean for segment type {seg_type}" + + +class TestSegmentTypeIsValidArrayValidation: + """ + Test SegmentType.is_valid with array types using different validation strategies. + """ + + def test_array_validation_all_success(self): + value = ["hello", "world", "foo"] + assert SegmentType.ARRAY_STRING.is_valid(value, array_validation=ArrayValidation.ALL) + + def test_array_validation_all_fail(self): + value = ["hello", 123, "world"] + # Should return False, since 123 is not a string + assert not SegmentType.ARRAY_STRING.is_valid(value, array_validation=ArrayValidation.ALL) + + def test_array_validation_first(self): + value = ["hello", 123, None] + assert SegmentType.ARRAY_STRING.is_valid(value, array_validation=ArrayValidation.FIRST) + + def test_array_validation_none(self): + value = [1, 2, 3] + # validation is None, skip + assert SegmentType.ARRAY_STRING.is_valid(value, array_validation=ArrayValidation.NONE) diff --git a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_processor.py b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_processor.py index 137e8b889d..8b1b9a55bc 100644 --- a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_processor.py +++ b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_processor.py @@ -1,6 +1,5 @@ import uuid from collections.abc import Generator -from datetime import UTC, datetime from core.workflow.entities.variable_pool import VariablePool from core.workflow.graph_engine.entities.event import ( @@ -15,6 +14,7 @@ from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProce from core.workflow.nodes.enums import NodeType from core.workflow.nodes.start.entities import StartNodeData from core.workflow.system_variable import SystemVariable +from libs.datetime_utils import naive_utc_now def _recursive_process(graph: Graph, next_node_id: str) -> Generator[GraphEngineEvent, None, None]: @@ -29,7 +29,7 @@ def _recursive_process(graph: Graph, next_node_id: str) -> Generator[GraphEngine def _publish_events(graph: Graph, next_node_id: str) -> Generator[GraphEngineEvent, None, None]: - route_node_state = RouteNodeState(node_id=next_node_id, start_at=datetime.now(UTC).replace(tzinfo=None)) + route_node_state = RouteNodeState(node_id=next_node_id, start_at=naive_utc_now()) parallel_id = graph.node_parallel_mapping.get(next_node_id) parallel_start_node_id = None @@ -68,7 +68,7 @@ def _publish_events(graph: Graph, next_node_id: str) -> Generator[GraphEngineEve ) route_node_state.status = RouteNodeState.Status.SUCCESS - route_node_state.finished_at = datetime.now(UTC).replace(tzinfo=None) + route_node_state.finished_at = naive_utc_now() yield NodeRunSucceededEvent( id=node_execution_id, node_id=next_node_id, diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py index bb6d72f51e..8b5a82fcbb 100644 --- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py +++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py @@ -49,7 +49,7 @@ def test_executor_with_json_body_and_number_variable(): assert executor.method == "post" assert executor.url == "https://api.example.com/data" assert executor.headers == {"Content-Type": "application/json"} - assert executor.params == [] + assert executor.params is None assert executor.json == {"number": 42} assert executor.data is None assert executor.files is None @@ -102,7 +102,7 @@ def test_executor_with_json_body_and_object_variable(): assert executor.method == "post" assert executor.url == "https://api.example.com/data" assert executor.headers == {"Content-Type": "application/json"} - assert executor.params == [] + assert executor.params is None assert executor.json == {"name": "John Doe", "age": 30, "email": "john@example.com"} assert executor.data is None assert executor.files is None @@ -157,7 +157,7 @@ def test_executor_with_json_body_and_nested_object_variable(): assert executor.method == "post" assert executor.url == "https://api.example.com/data" assert executor.headers == {"Content-Type": "application/json"} - assert executor.params == [] + assert executor.params is None assert executor.json == {"object": {"name": "John Doe", "age": 30, "email": "john@example.com"}} assert executor.data is None assert executor.files is None @@ -243,15 +243,18 @@ def test_executor_with_form_data(): # Check the executor's data assert executor.method == "post" assert executor.url == "https://api.example.com/upload" - assert "Content-Type" in executor.headers - assert "multipart/form-data" in executor.headers["Content-Type"] - assert executor.params == [] + assert executor.params is None assert executor.json is None # '__multipart_placeholder__' is expected when no file inputs exist, # to ensure the request is treated as multipart/form-data by the backend. assert executor.files == [("__multipart_placeholder__", ("", b"", "application/octet-stream"))] assert executor.content is None + # After fix for #23829: When placeholder files exist, Content-Type is removed + # to let httpx handle Content-Type and boundary automatically + headers = executor._assembling_headers() + assert "Content-Type" not in headers or "multipart/form-data" not in headers.get("Content-Type", "") + # Check that the form data is correctly loaded in executor.data assert isinstance(executor.data, dict) assert "text_field" in executor.data diff --git a/api/tests/unit_tests/core/workflow/test_variable_pool.py b/api/tests/unit_tests/core/workflow/test_variable_pool.py index c65b60cb4d..c0330b9441 100644 --- a/api/tests/unit_tests/core/workflow/test_variable_pool.py +++ b/api/tests/unit_tests/core/workflow/test_variable_pool.py @@ -69,8 +69,12 @@ def test_get_file_attribute(pool, file): def test_use_long_selector(pool): - pool.add(("node_1", "part_1", "part_2"), StringSegment(value="test_value")) + # The add method now only accepts 2-element selectors (node_id, variable_name) + # Store nested data as an ObjectSegment instead + nested_data = {"part_2": "test_value"} + pool.add(("node_1", "part_1"), ObjectSegment(value=nested_data)) + # The get method supports longer selectors for nested access result = pool.get(("node_1", "part_1", "part_2")) assert result is not None assert result.value == "test_value" @@ -280,8 +284,10 @@ class TestVariablePoolSerialization: pool.add((self._NODE2_ID, "array_file"), ArrayFileSegment(value=[test_file])) pool.add((self._NODE2_ID, "array_any"), ArrayAnySegment(value=["mixed", 123, {"key": "value"}])) - # Add nested variables - pool.add((self._NODE3_ID, "nested", "deep", "var"), StringSegment(value="deep_value")) + # Add nested variables as ObjectSegment + # The add method only accepts 2-element selectors + nested_obj = {"deep": {"var": "deep_value"}} + pool.add((self._NODE3_ID, "nested"), ObjectSegment(value=nested_obj)) def test_system_variables(self): sys_vars = SystemVariable( diff --git a/api/tests/unit_tests/core/workflow/test_workflow_cycle_manager.py b/api/tests/unit_tests/core/workflow/test_workflow_cycle_manager.py index 4866db1fdb..1d2eba1e71 100644 --- a/api/tests/unit_tests/core/workflow/test_workflow_cycle_manager.py +++ b/api/tests/unit_tests/core/workflow/test_workflow_cycle_manager.py @@ -1,5 +1,4 @@ import json -from datetime import UTC, datetime from unittest.mock import MagicMock import pytest @@ -23,6 +22,7 @@ from core.workflow.repositories.workflow_execution_repository import WorkflowExe from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository from core.workflow.system_variable import SystemVariable from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager +from libs.datetime_utils import naive_utc_now from models.enums import CreatorUserRole from models.model import AppMode from models.workflow import Workflow, WorkflowRun @@ -145,8 +145,8 @@ def real_workflow(): workflow.graph = json.dumps(graph_data) workflow.features = json.dumps({"file_upload": {"enabled": False}}) workflow.created_by = "test-user-id" - workflow.created_at = datetime.now(UTC).replace(tzinfo=None) - workflow.updated_at = datetime.now(UTC).replace(tzinfo=None) + workflow.created_at = naive_utc_now() + workflow.updated_at = naive_utc_now() workflow._environment_variables = "{}" workflow._conversation_variables = "{}" @@ -169,7 +169,7 @@ def real_workflow_run(): workflow_run.outputs = json.dumps({"answer": "test answer"}) workflow_run.created_by_role = CreatorUserRole.ACCOUNT workflow_run.created_by = "test-user-id" - workflow_run.created_at = datetime.now(UTC).replace(tzinfo=None) + workflow_run.created_at = naive_utc_now() return workflow_run @@ -211,7 +211,7 @@ def test_handle_workflow_run_success(workflow_cycle_manager, mock_workflow_execu workflow_type=WorkflowType.CHAT, graph={"nodes": [], "edges": []}, inputs={"query": "test query"}, - started_at=datetime.now(UTC).replace(tzinfo=None), + started_at=naive_utc_now(), ) # Pre-populate the cache with the workflow execution @@ -245,7 +245,7 @@ def test_handle_workflow_run_failed(workflow_cycle_manager, mock_workflow_execut workflow_type=WorkflowType.CHAT, graph={"nodes": [], "edges": []}, inputs={"query": "test query"}, - started_at=datetime.now(UTC).replace(tzinfo=None), + started_at=naive_utc_now(), ) # Pre-populate the cache with the workflow execution @@ -282,7 +282,7 @@ def test_handle_node_execution_start(workflow_cycle_manager, mock_workflow_execu workflow_type=WorkflowType.CHAT, graph={"nodes": [], "edges": []}, inputs={"query": "test query"}, - started_at=datetime.now(UTC).replace(tzinfo=None), + started_at=naive_utc_now(), ) # Pre-populate the cache with the workflow execution @@ -335,7 +335,7 @@ def test_get_workflow_execution_or_raise_error(workflow_cycle_manager, mock_work workflow_type=WorkflowType.CHAT, graph={"nodes": [], "edges": []}, inputs={"query": "test query"}, - started_at=datetime.now(UTC).replace(tzinfo=None), + started_at=naive_utc_now(), ) # Pre-populate the cache with the workflow execution @@ -366,7 +366,7 @@ def test_handle_workflow_node_execution_success(workflow_cycle_manager): event.process_data = {"process": "test process"} event.outputs = {"output": "test output"} event.execution_metadata = {WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: 100} - event.start_at = datetime.now(UTC).replace(tzinfo=None) + event.start_at = naive_utc_now() # Create a real node execution @@ -379,7 +379,7 @@ def test_handle_workflow_node_execution_success(workflow_cycle_manager): node_id="test-node-id", node_type=NodeType.LLM, title="Test Node", - created_at=datetime.now(UTC).replace(tzinfo=None), + created_at=naive_utc_now(), ) # Pre-populate the cache with the node execution @@ -409,7 +409,7 @@ def test_handle_workflow_run_partial_success(workflow_cycle_manager, mock_workfl workflow_type=WorkflowType.CHAT, graph={"nodes": [], "edges": []}, inputs={"query": "test query"}, - started_at=datetime.now(UTC).replace(tzinfo=None), + started_at=naive_utc_now(), ) # Pre-populate the cache with the workflow execution @@ -443,7 +443,7 @@ def test_handle_workflow_node_execution_failed(workflow_cycle_manager): event.process_data = {"process": "test process"} event.outputs = {"output": "test output"} event.execution_metadata = {WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: 100} - event.start_at = datetime.now(UTC).replace(tzinfo=None) + event.start_at = naive_utc_now() event.error = "Test error message" # Create a real node execution @@ -457,7 +457,7 @@ def test_handle_workflow_node_execution_failed(workflow_cycle_manager): node_id="test-node-id", node_type=NodeType.LLM, title="Test Node", - created_at=datetime.now(UTC).replace(tzinfo=None), + created_at=naive_utc_now(), ) # Pre-populate the cache with the node execution diff --git a/api/tests/unit_tests/core/workflow/utils/test_variable_utils.py b/api/tests/unit_tests/core/workflow/utils/test_variable_utils.py deleted file mode 100644 index 54bf6558bf..0000000000 --- a/api/tests/unit_tests/core/workflow/utils/test_variable_utils.py +++ /dev/null @@ -1,148 +0,0 @@ -from typing import Any - -from core.variables.segments import ObjectSegment, StringSegment -from core.workflow.entities.variable_pool import VariablePool -from core.workflow.utils.variable_utils import append_variables_recursively - - -class TestAppendVariablesRecursively: - """Test cases for append_variables_recursively function""" - - def test_append_simple_dict_value(self): - """Test appending a simple dictionary value""" - pool = VariablePool.empty() - node_id = "test_node" - variable_key_list = ["output"] - variable_value = {"name": "John", "age": 30} - - append_variables_recursively(pool, node_id, variable_key_list, variable_value) - - # Check that the main variable is added - main_var = pool.get([node_id] + variable_key_list) - assert main_var is not None - assert main_var.value == variable_value - - # Check that nested variables are added recursively - name_var = pool.get([node_id] + variable_key_list + ["name"]) - assert name_var is not None - assert name_var.value == "John" - - age_var = pool.get([node_id] + variable_key_list + ["age"]) - assert age_var is not None - assert age_var.value == 30 - - def test_append_object_segment_value(self): - """Test appending an ObjectSegment value""" - pool = VariablePool.empty() - node_id = "test_node" - variable_key_list = ["result"] - - # Create an ObjectSegment - obj_data = {"status": "success", "code": 200} - variable_value = ObjectSegment(value=obj_data) - - append_variables_recursively(pool, node_id, variable_key_list, variable_value) - - # Check that the main variable is added - main_var = pool.get([node_id] + variable_key_list) - assert main_var is not None - assert isinstance(main_var, ObjectSegment) - assert main_var.value == obj_data - - # Check that nested variables are added recursively - status_var = pool.get([node_id] + variable_key_list + ["status"]) - assert status_var is not None - assert status_var.value == "success" - - code_var = pool.get([node_id] + variable_key_list + ["code"]) - assert code_var is not None - assert code_var.value == 200 - - def test_append_nested_dict_value(self): - """Test appending a nested dictionary value""" - pool = VariablePool.empty() - node_id = "test_node" - variable_key_list = ["data"] - - variable_value = { - "user": { - "profile": {"name": "Alice", "email": "alice@example.com"}, - "settings": {"theme": "dark", "notifications": True}, - }, - "metadata": {"version": "1.0", "timestamp": 1234567890}, - } - - append_variables_recursively(pool, node_id, variable_key_list, variable_value) - - # Check deeply nested variables - name_var = pool.get([node_id] + variable_key_list + ["user", "profile", "name"]) - assert name_var is not None - assert name_var.value == "Alice" - - email_var = pool.get([node_id] + variable_key_list + ["user", "profile", "email"]) - assert email_var is not None - assert email_var.value == "alice@example.com" - - theme_var = pool.get([node_id] + variable_key_list + ["user", "settings", "theme"]) - assert theme_var is not None - assert theme_var.value == "dark" - - notifications_var = pool.get([node_id] + variable_key_list + ["user", "settings", "notifications"]) - assert notifications_var is not None - assert notifications_var.value == 1 # Boolean True is converted to integer 1 - - version_var = pool.get([node_id] + variable_key_list + ["metadata", "version"]) - assert version_var is not None - assert version_var.value == "1.0" - - def test_append_non_dict_value(self): - """Test appending a non-dictionary value (should not recurse)""" - pool = VariablePool.empty() - node_id = "test_node" - variable_key_list = ["simple"] - variable_value = "simple_string" - - append_variables_recursively(pool, node_id, variable_key_list, variable_value) - - # Check that only the main variable is added - main_var = pool.get([node_id] + variable_key_list) - assert main_var is not None - assert main_var.value == variable_value - - # Ensure no additional variables are created - assert len(pool.variable_dictionary[node_id]) == 1 - - def test_append_segment_non_object_value(self): - """Test appending a Segment that is not ObjectSegment (should not recurse)""" - pool = VariablePool.empty() - node_id = "test_node" - variable_key_list = ["text"] - variable_value = StringSegment(value="Hello World") - - append_variables_recursively(pool, node_id, variable_key_list, variable_value) - - # Check that only the main variable is added - main_var = pool.get([node_id] + variable_key_list) - assert main_var is not None - assert isinstance(main_var, StringSegment) - assert main_var.value == "Hello World" - - # Ensure no additional variables are created - assert len(pool.variable_dictionary[node_id]) == 1 - - def test_append_empty_dict_value(self): - """Test appending an empty dictionary value""" - pool = VariablePool.empty() - node_id = "test_node" - variable_key_list = ["empty"] - variable_value: dict[str, Any] = {} - - append_variables_recursively(pool, node_id, variable_key_list, variable_value) - - # Check that the main variable is added - main_var = pool.get([node_id] + variable_key_list) - assert main_var is not None - assert main_var.value == {} - - # Ensure only the main variable is created (no recursion for empty dict) - assert len(pool.variable_dictionary[node_id]) == 1 diff --git a/api/tests/unit_tests/extensions/test_celery_ssl.py b/api/tests/unit_tests/extensions/test_celery_ssl.py new file mode 100644 index 0000000000..bc46fe8322 --- /dev/null +++ b/api/tests/unit_tests/extensions/test_celery_ssl.py @@ -0,0 +1,149 @@ +"""Tests for Celery SSL configuration.""" + +import ssl +from unittest.mock import MagicMock, patch + + +class TestCelerySSLConfiguration: + """Test suite for Celery SSL configuration.""" + + def test_get_celery_ssl_options_when_ssl_disabled(self): + """Test SSL options when REDIS_USE_SSL is False.""" + mock_config = MagicMock() + mock_config.REDIS_USE_SSL = False + + with patch("extensions.ext_celery.dify_config", mock_config): + from extensions.ext_celery import _get_celery_ssl_options + + result = _get_celery_ssl_options() + assert result is None + + def test_get_celery_ssl_options_when_broker_not_redis(self): + """Test SSL options when broker is not Redis.""" + mock_config = MagicMock() + mock_config.REDIS_USE_SSL = True + mock_config.CELERY_BROKER_URL = "amqp://localhost:5672" + + with patch("extensions.ext_celery.dify_config", mock_config): + from extensions.ext_celery import _get_celery_ssl_options + + result = _get_celery_ssl_options() + assert result is None + + def test_get_celery_ssl_options_with_cert_none(self): + """Test SSL options with CERT_NONE requirement.""" + mock_config = MagicMock() + mock_config.REDIS_USE_SSL = True + mock_config.CELERY_BROKER_URL = "redis://localhost:6379/0" + mock_config.REDIS_SSL_CERT_REQS = "CERT_NONE" + mock_config.REDIS_SSL_CA_CERTS = None + mock_config.REDIS_SSL_CERTFILE = None + mock_config.REDIS_SSL_KEYFILE = None + + with patch("extensions.ext_celery.dify_config", mock_config): + from extensions.ext_celery import _get_celery_ssl_options + + result = _get_celery_ssl_options() + assert result is not None + assert result["ssl_cert_reqs"] == ssl.CERT_NONE + assert result["ssl_ca_certs"] is None + assert result["ssl_certfile"] is None + assert result["ssl_keyfile"] is None + + def test_get_celery_ssl_options_with_cert_required(self): + """Test SSL options with CERT_REQUIRED and certificates.""" + mock_config = MagicMock() + mock_config.REDIS_USE_SSL = True + mock_config.CELERY_BROKER_URL = "rediss://localhost:6380/0" + mock_config.REDIS_SSL_CERT_REQS = "CERT_REQUIRED" + mock_config.REDIS_SSL_CA_CERTS = "/path/to/ca.crt" + mock_config.REDIS_SSL_CERTFILE = "/path/to/client.crt" + mock_config.REDIS_SSL_KEYFILE = "/path/to/client.key" + + with patch("extensions.ext_celery.dify_config", mock_config): + from extensions.ext_celery import _get_celery_ssl_options + + result = _get_celery_ssl_options() + assert result is not None + assert result["ssl_cert_reqs"] == ssl.CERT_REQUIRED + assert result["ssl_ca_certs"] == "/path/to/ca.crt" + assert result["ssl_certfile"] == "/path/to/client.crt" + assert result["ssl_keyfile"] == "/path/to/client.key" + + def test_get_celery_ssl_options_with_cert_optional(self): + """Test SSL options with CERT_OPTIONAL requirement.""" + mock_config = MagicMock() + mock_config.REDIS_USE_SSL = True + mock_config.CELERY_BROKER_URL = "redis://localhost:6379/0" + mock_config.REDIS_SSL_CERT_REQS = "CERT_OPTIONAL" + mock_config.REDIS_SSL_CA_CERTS = "/path/to/ca.crt" + mock_config.REDIS_SSL_CERTFILE = None + mock_config.REDIS_SSL_KEYFILE = None + + with patch("extensions.ext_celery.dify_config", mock_config): + from extensions.ext_celery import _get_celery_ssl_options + + result = _get_celery_ssl_options() + assert result is not None + assert result["ssl_cert_reqs"] == ssl.CERT_OPTIONAL + assert result["ssl_ca_certs"] == "/path/to/ca.crt" + + def test_get_celery_ssl_options_with_invalid_cert_reqs(self): + """Test SSL options with invalid cert requirement defaults to CERT_NONE.""" + mock_config = MagicMock() + mock_config.REDIS_USE_SSL = True + mock_config.CELERY_BROKER_URL = "redis://localhost:6379/0" + mock_config.REDIS_SSL_CERT_REQS = "INVALID_VALUE" + mock_config.REDIS_SSL_CA_CERTS = None + mock_config.REDIS_SSL_CERTFILE = None + mock_config.REDIS_SSL_KEYFILE = None + + with patch("extensions.ext_celery.dify_config", mock_config): + from extensions.ext_celery import _get_celery_ssl_options + + result = _get_celery_ssl_options() + assert result is not None + assert result["ssl_cert_reqs"] == ssl.CERT_NONE # Should default to CERT_NONE + + def test_celery_init_applies_ssl_to_broker_and_backend(self): + """Test that SSL options are applied to both broker and backend when using Redis.""" + mock_config = MagicMock() + mock_config.REDIS_USE_SSL = True + mock_config.CELERY_BROKER_URL = "redis://localhost:6379/0" + mock_config.CELERY_BACKEND = "redis" + mock_config.CELERY_RESULT_BACKEND = "redis://localhost:6379/0" + mock_config.REDIS_SSL_CERT_REQS = "CERT_NONE" + mock_config.REDIS_SSL_CA_CERTS = None + mock_config.REDIS_SSL_CERTFILE = None + mock_config.REDIS_SSL_KEYFILE = None + mock_config.CELERY_USE_SENTINEL = False + mock_config.LOG_FORMAT = "%(message)s" + mock_config.LOG_TZ = "UTC" + mock_config.LOG_FILE = None + + # Mock all the scheduler configs + mock_config.CELERY_BEAT_SCHEDULER_TIME = 1 + mock_config.ENABLE_CLEAN_EMBEDDING_CACHE_TASK = False + mock_config.ENABLE_CLEAN_UNUSED_DATASETS_TASK = False + mock_config.ENABLE_CREATE_TIDB_SERVERLESS_TASK = False + mock_config.ENABLE_UPDATE_TIDB_SERVERLESS_STATUS_TASK = False + mock_config.ENABLE_CLEAN_MESSAGES = False + mock_config.ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK = False + mock_config.ENABLE_DATASETS_QUEUE_MONITOR = False + mock_config.ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK = False + + with patch("extensions.ext_celery.dify_config", mock_config): + from dify_app import DifyApp + from extensions.ext_celery import init_app + + app = DifyApp(__name__) + celery_app = init_app(app) + + # Check that SSL options were applied + assert "broker_use_ssl" in celery_app.conf + assert celery_app.conf["broker_use_ssl"] is not None + assert celery_app.conf["broker_use_ssl"]["ssl_cert_reqs"] == ssl.CERT_NONE + + # Check that SSL is also applied to Redis backend + assert "redis_backend_use_ssl" in celery_app.conf + assert celery_app.conf["redis_backend_use_ssl"] is not None diff --git a/api/tests/unit_tests/factories/test_build_from_mapping.py b/api/tests/unit_tests/factories/test_build_from_mapping.py index d42c4412f5..39280c9267 100644 --- a/api/tests/unit_tests/factories/test_build_from_mapping.py +++ b/api/tests/unit_tests/factories/test_build_from_mapping.py @@ -21,7 +21,7 @@ TEST_REMOTE_URL = "http://example.com/test.jpg" # Test Config TEST_CONFIG = FileUploadConfig( - allowed_file_types=["image", "document"], + allowed_file_types=[FileType.IMAGE, FileType.DOCUMENT], allowed_file_extensions=[".jpg", ".pdf"], allowed_file_upload_methods=[FileTransferMethod.LOCAL_FILE, FileTransferMethod.TOOL_FILE], number_limits=10, @@ -171,10 +171,10 @@ def test_build_without_type_specification(mock_upload_file): mapping = { "transfer_method": "local_file", "upload_file_id": TEST_UPLOAD_FILE_ID, - # leave out the type + # type field is intentionally omitted } file = build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID) - # It should automatically infer the type as "image" based on the file extension + # Should automatically infer the type as "image" based on the file extension assert file.type == FileType.IMAGE @@ -194,3 +194,81 @@ def test_file_validation_with_config(mock_upload_file, file_type, should_pass, e else: with pytest.raises(ValueError, match=expected_error): build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID, config=TEST_CONFIG) + + +def test_invalid_transfer_method(): + """Test that invalid transfer method raises ValueError.""" + mapping = { + "transfer_method": "invalid_method", + "upload_file_id": TEST_UPLOAD_FILE_ID, + "type": "image", + } + with pytest.raises(ValueError, match="No matching enum found for value 'invalid_method'"): + build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID) + + +def test_invalid_uuid_format(): + """Test that invalid UUID format raises ValueError.""" + mapping = { + "transfer_method": "local_file", + "upload_file_id": "not-a-valid-uuid", + "type": "image", + } + with pytest.raises(ValueError, match="Invalid upload file id format"): + build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID) + + +def test_tenant_mismatch(): + """Test that tenant mismatch raises security error.""" + # Create a mock upload file with a different tenant_id + mock_file = MagicMock(spec=UploadFile) + mock_file.id = TEST_UPLOAD_FILE_ID + mock_file.tenant_id = "different_tenant_id" + mock_file.name = "test.jpg" + mock_file.extension = "jpg" + mock_file.mime_type = "image/jpeg" + mock_file.source_url = TEST_REMOTE_URL + mock_file.size = 1024 + mock_file.key = "test_key" + + # Mock the database query to return None (no file found for this tenant) + with patch("factories.file_factory.db.session.scalar", return_value=None): + mapping = local_file_mapping() + with pytest.raises(ValueError, match="Invalid upload file"): + build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID) + + +def test_disallowed_file_types(mock_upload_file): + """Test that disallowed file types are rejected.""" + # Config that only allows image and document types + restricted_config = FileUploadConfig( + allowed_file_types=[FileType.IMAGE, FileType.DOCUMENT], + ) + + # Try to upload a video file + mapping = local_file_mapping(file_type="video") + with pytest.raises(ValueError, match="File validation failed"): + build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID, config=restricted_config) + + +def test_disallowed_extensions(mock_upload_file): + """Test that disallowed file extensions are rejected for custom type.""" + # Mock a file with .exe extension + mock_upload_file.return_value.extension = "exe" + mock_upload_file.return_value.name = "malicious.exe" + mock_upload_file.return_value.mime_type = "application/x-msdownload" + + # Config that only allows specific extensions for custom files + restricted_config = FileUploadConfig( + allowed_file_extensions=[".txt", ".csv", ".json"], + ) + + # Mapping without specifying type (will be detected as custom) + mapping = { + "transfer_method": "local_file", + "upload_file_id": TEST_UPLOAD_FILE_ID, + "type": "custom", + } + + with pytest.raises(ValueError, match="File validation failed"): + build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID, config=restricted_config) diff --git a/api/tests/unit_tests/models/test_workflow.py b/api/tests/unit_tests/models/test_workflow.py index 5bc77ad0ef..4c61320c29 100644 --- a/api/tests/unit_tests/models/test_workflow.py +++ b/api/tests/unit_tests/models/test_workflow.py @@ -9,7 +9,6 @@ from core.file.models import File from core.variables import FloatVariable, IntegerVariable, SecretVariable, StringVariable from core.variables.segments import IntegerSegment, Segment from factories.variable_factory import build_segment -from models.model import EndUser from models.workflow import Workflow, WorkflowDraftVariable, WorkflowNodeExecutionModel, is_system_variable_editable @@ -43,14 +42,9 @@ def test_environment_variables(): {"name": "var4", "value": 3.14, "id": str(uuid4()), "selector": ["env", "var4"]} ) - # Mock current_user as an EndUser - mock_user = mock.Mock(spec=EndUser) - mock_user.tenant_id = "tenant_id" - with ( mock.patch("core.helper.encrypter.encrypt_token", return_value="encrypted_token"), mock.patch("core.helper.encrypter.decrypt_token", return_value="secret"), - mock.patch("models.workflow.current_user", mock_user), ): # Set the environment_variables property of the Workflow instance variables = [variable1, variable2, variable3, variable4] @@ -90,14 +84,9 @@ def test_update_environment_variables(): {"name": "var4", "value": 3.14, "id": str(uuid4()), "selector": ["env", "var4"]} ) - # Mock current_user as an EndUser - mock_user = mock.Mock(spec=EndUser) - mock_user.tenant_id = "tenant_id" - with ( mock.patch("core.helper.encrypter.encrypt_token", return_value="encrypted_token"), mock.patch("core.helper.encrypter.decrypt_token", return_value="secret"), - mock.patch("models.workflow.current_user", mock_user), ): variables = [variable1, variable2, variable3, variable4] @@ -136,14 +125,9 @@ def test_to_dict(): # Create some EnvironmentVariable instances - # Mock current_user as an EndUser - mock_user = mock.Mock(spec=EndUser) - mock_user.tenant_id = "tenant_id" - with ( mock.patch("core.helper.encrypter.encrypt_token", return_value="encrypted_token"), mock.patch("core.helper.encrypter.decrypt_token", return_value="secret"), - mock.patch("models.workflow.current_user", mock_user), ): # Set the environment_variables property of the Workflow instance workflow.environment_variables = [ diff --git a/api/tests/unit_tests/services/test_clear_free_plan_tenant_expired_logs.py b/api/tests/unit_tests/services/test_clear_free_plan_tenant_expired_logs.py new file mode 100644 index 0000000000..dd2bc21814 --- /dev/null +++ b/api/tests/unit_tests/services/test_clear_free_plan_tenant_expired_logs.py @@ -0,0 +1,168 @@ +import datetime +from unittest.mock import Mock, patch + +import pytest +from sqlalchemy.orm import Session + +from services.clear_free_plan_tenant_expired_logs import ClearFreePlanTenantExpiredLogs + + +class TestClearFreePlanTenantExpiredLogs: + """Unit tests for ClearFreePlanTenantExpiredLogs._clear_message_related_tables method.""" + + @pytest.fixture + def mock_session(self): + """Create a mock database session.""" + session = Mock(spec=Session) + session.query.return_value.filter.return_value.all.return_value = [] + session.query.return_value.filter.return_value.delete.return_value = 0 + return session + + @pytest.fixture + def mock_storage(self): + """Create a mock storage object.""" + storage = Mock() + storage.save.return_value = None + return storage + + @pytest.fixture + def sample_message_ids(self): + """Sample message IDs for testing.""" + return ["msg-1", "msg-2", "msg-3"] + + @pytest.fixture + def sample_records(self): + """Sample records for testing.""" + records = [] + for i in range(3): + record = Mock() + record.id = f"record-{i}" + record.to_dict.return_value = { + "id": f"record-{i}", + "message_id": f"msg-{i}", + "created_at": datetime.datetime.now().isoformat(), + } + records.append(record) + return records + + def test_clear_message_related_tables_empty_message_ids(self, mock_session): + """Test that method returns early when message_ids is empty.""" + with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage: + ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", []) + + # Should not call any database operations + mock_session.query.assert_not_called() + mock_storage.save.assert_not_called() + + def test_clear_message_related_tables_no_records_found(self, mock_session, sample_message_ids): + """Test when no related records are found.""" + with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage: + mock_session.query.return_value.filter.return_value.all.return_value = [] + + ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids) + + # Should call query for each related table but find no records + assert mock_session.query.call_count > 0 + mock_storage.save.assert_not_called() + + def test_clear_message_related_tables_with_records_and_to_dict( + self, mock_session, sample_message_ids, sample_records + ): + """Test when records are found and have to_dict method.""" + with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage: + mock_session.query.return_value.filter.return_value.all.return_value = sample_records + + ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids) + + # Should call to_dict on each record (called once per table, so 7 times total) + for record in sample_records: + assert record.to_dict.call_count == 7 + + # Should save backup data + assert mock_storage.save.call_count > 0 + + def test_clear_message_related_tables_with_records_no_to_dict(self, mock_session, sample_message_ids): + """Test when records are found but don't have to_dict method.""" + with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage: + # Create records without to_dict method + records = [] + for i in range(2): + record = Mock() + mock_table = Mock() + mock_id_column = Mock() + mock_id_column.name = "id" + mock_message_id_column = Mock() + mock_message_id_column.name = "message_id" + mock_table.columns = [mock_id_column, mock_message_id_column] + record.__table__ = mock_table + record.id = f"record-{i}" + record.message_id = f"msg-{i}" + del record.to_dict + records.append(record) + + # Mock records for first table only, empty for others + mock_session.query.return_value.filter.return_value.all.side_effect = [ + records, + [], + [], + [], + [], + [], + [], + ] + + ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids) + + # Should save backup data even without to_dict + assert mock_storage.save.call_count > 0 + + def test_clear_message_related_tables_storage_error_continues( + self, mock_session, sample_message_ids, sample_records + ): + """Test that method continues even when storage.save fails.""" + with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage: + mock_storage.save.side_effect = Exception("Storage error") + + mock_session.query.return_value.filter.return_value.all.return_value = sample_records + + # Should not raise exception + ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids) + + # Should still delete records even if backup fails + assert mock_session.query.return_value.filter.return_value.delete.called + + def test_clear_message_related_tables_serialization_error_continues(self, mock_session, sample_message_ids): + """Test that method continues even when record serialization fails.""" + with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage: + record = Mock() + record.id = "record-1" + record.to_dict.side_effect = Exception("Serialization error") + + mock_session.query.return_value.filter.return_value.all.return_value = [record] + + # Should not raise exception + ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids) + + # Should still delete records even if serialization fails + assert mock_session.query.return_value.filter.return_value.delete.called + + def test_clear_message_related_tables_deletion_called(self, mock_session, sample_message_ids, sample_records): + """Test that deletion is called for found records.""" + with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage: + mock_session.query.return_value.filter.return_value.all.return_value = sample_records + + ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids) + + # Should call delete for each table that has records + assert mock_session.query.return_value.filter.return_value.delete.called + + def test_clear_message_related_tables_logging_output( + self, mock_session, sample_message_ids, sample_records, capsys + ): + """Test that logging output is generated.""" + with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage: + mock_session.query.return_value.filter.return_value.all.return_value = sample_records + + ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids) + + pass diff --git a/api/tests/unit_tests/services/test_conversation_service.py b/api/tests/unit_tests/services/test_conversation_service.py new file mode 100644 index 0000000000..9c1c044f03 --- /dev/null +++ b/api/tests/unit_tests/services/test_conversation_service.py @@ -0,0 +1,127 @@ +import uuid +from unittest.mock import MagicMock, patch + +from core.app.entities.app_invoke_entities import InvokeFrom +from services.conversation_service import ConversationService + + +class TestConversationService: + def test_pagination_with_empty_include_ids(self): + """Test that empty include_ids returns empty result""" + mock_session = MagicMock() + mock_app_model = MagicMock(id=str(uuid.uuid4())) + mock_user = MagicMock(id=str(uuid.uuid4())) + + result = ConversationService.pagination_by_last_id( + session=mock_session, + app_model=mock_app_model, + user=mock_user, + last_id=None, + limit=20, + invoke_from=InvokeFrom.WEB_APP, + include_ids=[], # Empty include_ids should return empty result + exclude_ids=None, + ) + + assert result.data == [] + assert result.has_more is False + assert result.limit == 20 + + def test_pagination_with_non_empty_include_ids(self): + """Test that non-empty include_ids filters properly""" + mock_session = MagicMock() + mock_app_model = MagicMock(id=str(uuid.uuid4())) + mock_user = MagicMock(id=str(uuid.uuid4())) + + # Mock the query results + mock_conversations = [MagicMock(id=str(uuid.uuid4())) for _ in range(3)] + mock_session.scalars.return_value.all.return_value = mock_conversations + mock_session.scalar.return_value = 0 + + with patch("services.conversation_service.select") as mock_select: + mock_stmt = MagicMock() + mock_select.return_value = mock_stmt + mock_stmt.where.return_value = mock_stmt + mock_stmt.order_by.return_value = mock_stmt + mock_stmt.limit.return_value = mock_stmt + mock_stmt.subquery.return_value = MagicMock() + + result = ConversationService.pagination_by_last_id( + session=mock_session, + app_model=mock_app_model, + user=mock_user, + last_id=None, + limit=20, + invoke_from=InvokeFrom.WEB_APP, + include_ids=["conv1", "conv2"], # Non-empty include_ids + exclude_ids=None, + ) + + # Verify the where clause was called with id.in_ + assert mock_stmt.where.called + + def test_pagination_with_empty_exclude_ids(self): + """Test that empty exclude_ids doesn't filter""" + mock_session = MagicMock() + mock_app_model = MagicMock(id=str(uuid.uuid4())) + mock_user = MagicMock(id=str(uuid.uuid4())) + + # Mock the query results + mock_conversations = [MagicMock(id=str(uuid.uuid4())) for _ in range(5)] + mock_session.scalars.return_value.all.return_value = mock_conversations + mock_session.scalar.return_value = 0 + + with patch("services.conversation_service.select") as mock_select: + mock_stmt = MagicMock() + mock_select.return_value = mock_stmt + mock_stmt.where.return_value = mock_stmt + mock_stmt.order_by.return_value = mock_stmt + mock_stmt.limit.return_value = mock_stmt + mock_stmt.subquery.return_value = MagicMock() + + result = ConversationService.pagination_by_last_id( + session=mock_session, + app_model=mock_app_model, + user=mock_user, + last_id=None, + limit=20, + invoke_from=InvokeFrom.WEB_APP, + include_ids=None, + exclude_ids=[], # Empty exclude_ids should not filter + ) + + # Result should contain the mocked conversations + assert len(result.data) == 5 + + def test_pagination_with_non_empty_exclude_ids(self): + """Test that non-empty exclude_ids filters properly""" + mock_session = MagicMock() + mock_app_model = MagicMock(id=str(uuid.uuid4())) + mock_user = MagicMock(id=str(uuid.uuid4())) + + # Mock the query results + mock_conversations = [MagicMock(id=str(uuid.uuid4())) for _ in range(3)] + mock_session.scalars.return_value.all.return_value = mock_conversations + mock_session.scalar.return_value = 0 + + with patch("services.conversation_service.select") as mock_select: + mock_stmt = MagicMock() + mock_select.return_value = mock_stmt + mock_stmt.where.return_value = mock_stmt + mock_stmt.order_by.return_value = mock_stmt + mock_stmt.limit.return_value = mock_stmt + mock_stmt.subquery.return_value = MagicMock() + + result = ConversationService.pagination_by_last_id( + session=mock_session, + app_model=mock_app_model, + user=mock_user, + last_id=None, + limit=20, + invoke_from=InvokeFrom.WEB_APP, + include_ids=None, + exclude_ids=["conv1", "conv2"], # Non-empty exclude_ids + ) + + # Verify the where clause was called for exclusion + assert mock_stmt.where.called diff --git a/api/tests/unit_tests/services/test_dataset_permission.py b/api/tests/unit_tests/services/test_dataset_permission.py index a67252e856..c1e4981325 100644 --- a/api/tests/unit_tests/services/test_dataset_permission.py +++ b/api/tests/unit_tests/services/test_dataset_permission.py @@ -301,5 +301,5 @@ class TestDatasetPermissionService: # Verify debug message was logged with correct user and dataset information mock_logging_dependencies["logging"].debug.assert_called_with( - f"User {normal_user.id} does not have permission to access dataset {dataset.id}" + "User %s does not have permission to access dataset %s", normal_user.id, dataset.id ) diff --git a/api/tests/unit_tests/services/test_dataset_service_batch_update_document_status.py b/api/tests/unit_tests/services/test_dataset_service_batch_update_document_status.py index dc09aca5b2..1881ceac26 100644 --- a/api/tests/unit_tests/services/test_dataset_service_batch_update_document_status.py +++ b/api/tests/unit_tests/services/test_dataset_service_batch_update_document_status.py @@ -93,16 +93,15 @@ class TestDatasetServiceBatchUpdateDocumentStatus: with ( patch("services.dataset_service.DocumentService.get_document") as mock_get_doc, patch("extensions.ext_database.db.session") as mock_db, - patch("services.dataset_service.datetime") as mock_datetime, + patch("services.dataset_service.naive_utc_now") as mock_naive_utc_now, ): current_time = datetime.datetime(2023, 1, 1, 12, 0, 0) - mock_datetime.datetime.now.return_value = current_time - mock_datetime.UTC = datetime.UTC + mock_naive_utc_now.return_value = current_time yield { "get_document": mock_get_doc, "db_session": mock_db, - "datetime": mock_datetime, + "naive_utc_now": mock_naive_utc_now, "current_time": current_time, } @@ -120,21 +119,21 @@ class TestDatasetServiceBatchUpdateDocumentStatus: assert document.enabled == True assert document.disabled_at is None assert document.disabled_by is None - assert document.updated_at == current_time.replace(tzinfo=None) + assert document.updated_at == current_time def _assert_document_disabled(self, document: Mock, user_id: str, current_time: datetime.datetime): """Helper method to verify document was disabled correctly.""" assert document.enabled == False - assert document.disabled_at == current_time.replace(tzinfo=None) + assert document.disabled_at == current_time assert document.disabled_by == user_id - assert document.updated_at == current_time.replace(tzinfo=None) + assert document.updated_at == current_time def _assert_document_archived(self, document: Mock, user_id: str, current_time: datetime.datetime): """Helper method to verify document was archived correctly.""" assert document.archived == True - assert document.archived_at == current_time.replace(tzinfo=None) + assert document.archived_at == current_time assert document.archived_by == user_id - assert document.updated_at == current_time.replace(tzinfo=None) + assert document.updated_at == current_time def _assert_document_unarchived(self, document: Mock): """Helper method to verify document was unarchived correctly.""" @@ -430,7 +429,7 @@ class TestDatasetServiceBatchUpdateDocumentStatus: # Verify document attributes were updated correctly self._assert_document_unarchived(archived_doc) - assert archived_doc.updated_at == mock_document_service_dependencies["current_time"].replace(tzinfo=None) + assert archived_doc.updated_at == mock_document_service_dependencies["current_time"] # Verify Redis cache was set (because document is enabled) redis_mock.setex.assert_called_once_with("document_doc-1_indexing", 600, 1) @@ -495,9 +494,7 @@ class TestDatasetServiceBatchUpdateDocumentStatus: # Verify document was unarchived self._assert_document_unarchived(archived_disabled_doc) - assert archived_disabled_doc.updated_at == mock_document_service_dependencies["current_time"].replace( - tzinfo=None - ) + assert archived_disabled_doc.updated_at == mock_document_service_dependencies["current_time"] # Verify no Redis cache was set (document is disabled) redis_mock.setex.assert_not_called() diff --git a/api/tests/unit_tests/services/test_metadata_bug_complete.py b/api/tests/unit_tests/services/test_metadata_bug_complete.py new file mode 100644 index 0000000000..0fc36510b9 --- /dev/null +++ b/api/tests/unit_tests/services/test_metadata_bug_complete.py @@ -0,0 +1,189 @@ +from unittest.mock import Mock, patch + +import pytest +from flask_restx import reqparse +from werkzeug.exceptions import BadRequest + +from services.entities.knowledge_entities.knowledge_entities import MetadataArgs +from services.metadata_service import MetadataService + + +class TestMetadataBugCompleteValidation: + """Complete test suite to verify the metadata nullable bug and its fix.""" + + def test_1_pydantic_layer_validation(self): + """Test Layer 1: Pydantic model validation correctly rejects None values.""" + # Pydantic should reject None values for required fields + with pytest.raises((ValueError, TypeError)): + MetadataArgs(type=None, name=None) + + with pytest.raises((ValueError, TypeError)): + MetadataArgs(type="string", name=None) + + with pytest.raises((ValueError, TypeError)): + MetadataArgs(type=None, name="test") + + # Valid values should work + valid_args = MetadataArgs(type="string", name="test_name") + assert valid_args.type == "string" + assert valid_args.name == "test_name" + + def test_2_business_logic_layer_crashes_on_none(self): + """Test Layer 2: Business logic crashes when None values slip through.""" + # Create mock that bypasses Pydantic validation + mock_metadata_args = Mock() + mock_metadata_args.name = None + mock_metadata_args.type = "string" + + with patch("services.metadata_service.current_user") as mock_user: + mock_user.current_tenant_id = "tenant-123" + mock_user.id = "user-456" + + # Should crash with TypeError + with pytest.raises(TypeError, match="object of type 'NoneType' has no len"): + MetadataService.create_metadata("dataset-123", mock_metadata_args) + + # Test update method as well + with patch("services.metadata_service.current_user") as mock_user: + mock_user.current_tenant_id = "tenant-123" + mock_user.id = "user-456" + + with pytest.raises(TypeError, match="object of type 'NoneType' has no len"): + MetadataService.update_metadata_name("dataset-123", "metadata-456", None) + + def test_3_database_constraints_verification(self): + """Test Layer 3: Verify database model has nullable=False constraints.""" + from sqlalchemy import inspect + + from models.dataset import DatasetMetadata + + # Get table info + mapper = inspect(DatasetMetadata) + + # Check that type and name columns are not nullable + type_column = mapper.columns["type"] + name_column = mapper.columns["name"] + + assert type_column.nullable is False, "type column should be nullable=False" + assert name_column.nullable is False, "name column should be nullable=False" + + def test_4_fixed_api_layer_rejects_null(self, app): + """Test Layer 4: Fixed API configuration properly rejects null values.""" + # Test Console API create endpoint (fixed) + parser = reqparse.RequestParser() + parser.add_argument("type", type=str, required=True, nullable=False, location="json") + parser.add_argument("name", type=str, required=True, nullable=False, location="json") + + with app.test_request_context(json={"type": None, "name": None}, content_type="application/json"): + with pytest.raises(BadRequest): + parser.parse_args() + + # Test with just name being null + with app.test_request_context(json={"type": "string", "name": None}, content_type="application/json"): + with pytest.raises(BadRequest): + parser.parse_args() + + # Test with just type being null + with app.test_request_context(json={"type": None, "name": "test"}, content_type="application/json"): + with pytest.raises(BadRequest): + parser.parse_args() + + def test_5_fixed_api_accepts_valid_values(self, app): + """Test that fixed API still accepts valid non-null values.""" + parser = reqparse.RequestParser() + parser.add_argument("type", type=str, required=True, nullable=False, location="json") + parser.add_argument("name", type=str, required=True, nullable=False, location="json") + + with app.test_request_context(json={"type": "string", "name": "valid_name"}, content_type="application/json"): + args = parser.parse_args() + assert args["type"] == "string" + assert args["name"] == "valid_name" + + def test_6_simulated_buggy_behavior(self, app): + """Test simulating the original buggy behavior with nullable=True.""" + # Simulate the old buggy configuration + buggy_parser = reqparse.RequestParser() + buggy_parser.add_argument("type", type=str, required=True, nullable=True, location="json") + buggy_parser.add_argument("name", type=str, required=True, nullable=True, location="json") + + with app.test_request_context(json={"type": None, "name": None}, content_type="application/json"): + # This would pass in the buggy version + args = buggy_parser.parse_args() + assert args["type"] is None + assert args["name"] is None + + # But would crash when trying to create MetadataArgs + with pytest.raises((ValueError, TypeError)): + MetadataArgs(**args) + + def test_7_end_to_end_validation_layers(self): + """Test all validation layers work together correctly.""" + # Layer 1: API should reject null at parameter level (with fix) + # Layer 2: Pydantic should reject null at model level + # Layer 3: Business logic expects non-null + # Layer 4: Database enforces non-null + + # Test that valid data flows through all layers + valid_data = {"type": "string", "name": "test_metadata"} + + # Should create valid Pydantic object + metadata_args = MetadataArgs(**valid_data) + assert metadata_args.type == "string" + assert metadata_args.name == "test_metadata" + + # Should not crash in business logic length check + assert len(metadata_args.name) <= 255 # This should not crash + assert len(metadata_args.type) > 0 # This should not crash + + def test_8_verify_specific_fix_locations(self): + """Verify that the specific locations mentioned in bug report are fixed.""" + # Read the actual files to verify fixes + import os + + # Console API create + console_create_file = "api/controllers/console/datasets/metadata.py" + if os.path.exists(console_create_file): + with open(console_create_file) as f: + content = f.read() + # Should contain nullable=False, not nullable=True + assert "nullable=True" not in content.split("class DatasetMetadataCreateApi")[1].split("class")[0] + + # Service API create + service_create_file = "api/controllers/service_api/dataset/metadata.py" + if os.path.exists(service_create_file): + with open(service_create_file) as f: + content = f.read() + # Should contain nullable=False, not nullable=True + create_api_section = content.split("class DatasetMetadataCreateServiceApi")[1].split("class")[0] + assert "nullable=True" not in create_api_section + + +class TestMetadataValidationSummary: + """Summary tests that demonstrate the complete validation architecture.""" + + def test_validation_layer_architecture(self): + """Document and test the 4-layer validation architecture.""" + # Layer 1: API Parameter Validation (Flask-RESTful reqparse) + # - Role: First line of defense, validates HTTP request parameters + # - Fixed: nullable=False ensures null values are rejected at API boundary + + # Layer 2: Pydantic Model Validation + # - Role: Validates data structure and types before business logic + # - Working: Required fields without Optional[] reject None values + + # Layer 3: Business Logic Validation + # - Role: Domain-specific validation (length checks, uniqueness, etc.) + # - Vulnerable: Direct len() calls crash on None values + + # Layer 4: Database Constraints + # - Role: Final data integrity enforcement + # - Working: nullable=False prevents None values in database + + # The bug was: Layer 1 allowed None, but Layers 2-4 expected non-None + # The fix: Make Layer 1 consistent with Layers 2-4 + + assert True # This test documents the architecture + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/api/tests/unit_tests/services/test_metadata_nullable_bug.py b/api/tests/unit_tests/services/test_metadata_nullable_bug.py new file mode 100644 index 0000000000..7f6344f942 --- /dev/null +++ b/api/tests/unit_tests/services/test_metadata_nullable_bug.py @@ -0,0 +1,108 @@ +from unittest.mock import Mock, patch + +import pytest +from flask_restx import reqparse + +from services.entities.knowledge_entities.knowledge_entities import MetadataArgs +from services.metadata_service import MetadataService + + +class TestMetadataNullableBug: + """Test case to reproduce the metadata nullable validation bug.""" + + def test_metadata_args_with_none_values_should_fail(self): + """Test that MetadataArgs validation should reject None values.""" + # This test demonstrates the expected behavior - should fail validation + with pytest.raises((ValueError, TypeError)): + # This should fail because Pydantic expects non-None values + MetadataArgs(type=None, name=None) + + def test_metadata_service_create_with_none_name_crashes(self): + """Test that MetadataService.create_metadata crashes when name is None.""" + # Mock the MetadataArgs to bypass Pydantic validation + mock_metadata_args = Mock() + mock_metadata_args.name = None # This will cause len() to crash + mock_metadata_args.type = "string" + + with patch("services.metadata_service.current_user") as mock_user: + mock_user.current_tenant_id = "tenant-123" + mock_user.id = "user-456" + + # This should crash with TypeError when calling len(None) + with pytest.raises(TypeError, match="object of type 'NoneType' has no len"): + MetadataService.create_metadata("dataset-123", mock_metadata_args) + + def test_metadata_service_update_with_none_name_crashes(self): + """Test that MetadataService.update_metadata_name crashes when name is None.""" + with patch("services.metadata_service.current_user") as mock_user: + mock_user.current_tenant_id = "tenant-123" + mock_user.id = "user-456" + + # This should crash with TypeError when calling len(None) + with pytest.raises(TypeError, match="object of type 'NoneType' has no len"): + MetadataService.update_metadata_name("dataset-123", "metadata-456", None) + + def test_api_parser_accepts_null_values(self, app): + """Test that API parser configuration incorrectly accepts null values.""" + # Simulate the current API parser configuration + parser = reqparse.RequestParser() + parser.add_argument("type", type=str, required=True, nullable=True, location="json") + parser.add_argument("name", type=str, required=True, nullable=True, location="json") + + # Simulate request data with null values + with app.test_request_context(json={"type": None, "name": None}, content_type="application/json"): + # This should parse successfully due to nullable=True + args = parser.parse_args() + + # Verify that null values are accepted + assert args["type"] is None + assert args["name"] is None + + # This demonstrates the bug: API accepts None but business logic will crash + + def test_integration_bug_scenario(self, app): + """Test the complete bug scenario from API to service layer.""" + # Step 1: API parser accepts null values (current buggy behavior) + parser = reqparse.RequestParser() + parser.add_argument("type", type=str, required=True, nullable=True, location="json") + parser.add_argument("name", type=str, required=True, nullable=True, location="json") + + with app.test_request_context(json={"type": None, "name": None}, content_type="application/json"): + args = parser.parse_args() + + # Step 2: Try to create MetadataArgs with None values + # This should fail at Pydantic validation level + with pytest.raises((ValueError, TypeError)): + metadata_args = MetadataArgs(**args) + + # Step 3: If we bypass Pydantic (simulating the bug scenario) + # Move this outside the request context to avoid Flask-Login issues + mock_metadata_args = Mock() + mock_metadata_args.name = None # From args["name"] + mock_metadata_args.type = None # From args["type"] + + with patch("services.metadata_service.current_user") as mock_user: + mock_user.current_tenant_id = "tenant-123" + mock_user.id = "user-456" + + # Step 4: Service layer crashes on len(None) + with pytest.raises(TypeError, match="object of type 'NoneType' has no len"): + MetadataService.create_metadata("dataset-123", mock_metadata_args) + + def test_correct_nullable_false_configuration_works(self, app): + """Test that the correct nullable=False configuration works as expected.""" + # This tests the FIXED configuration + parser = reqparse.RequestParser() + parser.add_argument("type", type=str, required=True, nullable=False, location="json") + parser.add_argument("name", type=str, required=True, nullable=False, location="json") + + with app.test_request_context(json={"type": None, "name": None}, content_type="application/json"): + # This should fail with BadRequest due to nullable=False + from werkzeug.exceptions import BadRequest + + with pytest.raises(BadRequest): + parser.parse_args() + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/api/tests/unit_tests/tasks/__init__.py b/api/tests/unit_tests/tasks/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/tasks/test_remove_app_and_related_data_task.py b/api/tests/unit_tests/tasks/test_remove_app_and_related_data_task.py new file mode 100644 index 0000000000..d8003570b5 --- /dev/null +++ b/api/tests/unit_tests/tasks/test_remove_app_and_related_data_task.py @@ -0,0 +1,243 @@ +from unittest.mock import ANY, MagicMock, call, patch + +import pytest +import sqlalchemy as sa + +from tasks.remove_app_and_related_data_task import _delete_draft_variables, delete_draft_variables_batch + + +class TestDeleteDraftVariablesBatch: + @patch("tasks.remove_app_and_related_data_task.db") + def test_delete_draft_variables_batch_success(self, mock_db): + """Test successful deletion of draft variables in batches.""" + app_id = "test-app-id" + batch_size = 100 + + # Mock database connection and engine + mock_conn = MagicMock() + mock_engine = MagicMock() + mock_db.engine = mock_engine + # Properly mock the context manager + mock_context_manager = MagicMock() + mock_context_manager.__enter__.return_value = mock_conn + mock_context_manager.__exit__.return_value = None + mock_engine.begin.return_value = mock_context_manager + + # Mock two batches of results, then empty + batch1_ids = [f"var-{i}" for i in range(100)] + batch2_ids = [f"var-{i}" for i in range(100, 150)] + + # Setup side effects for execute calls in the correct order: + # 1. SELECT (returns batch1_ids) + # 2. DELETE (returns result with rowcount=100) + # 3. SELECT (returns batch2_ids) + # 4. DELETE (returns result with rowcount=50) + # 5. SELECT (returns empty, ends loop) + + # Create mock results with actual integer rowcount attributes + class MockResult: + def __init__(self, rowcount): + self.rowcount = rowcount + + # First SELECT result + select_result1 = MagicMock() + select_result1.__iter__.return_value = iter([(id_,) for id_ in batch1_ids]) + + # First DELETE result + delete_result1 = MockResult(rowcount=100) + + # Second SELECT result + select_result2 = MagicMock() + select_result2.__iter__.return_value = iter([(id_,) for id_ in batch2_ids]) + + # Second DELETE result + delete_result2 = MockResult(rowcount=50) + + # Third SELECT result (empty, ends loop) + select_result3 = MagicMock() + select_result3.__iter__.return_value = iter([]) + + # Configure side effects in the correct order + mock_conn.execute.side_effect = [ + select_result1, # First SELECT + delete_result1, # First DELETE + select_result2, # Second SELECT + delete_result2, # Second DELETE + select_result3, # Third SELECT (empty) + ] + + # Execute the function + result = delete_draft_variables_batch(app_id, batch_size) + + # Verify the result + assert result == 150 + + # Verify database calls + assert mock_conn.execute.call_count == 5 # 3 selects + 2 deletes + + # Verify the expected calls in order: + # 1. SELECT, 2. DELETE, 3. SELECT, 4. DELETE, 5. SELECT + expected_calls = [ + # First SELECT + call( + sa.text(""" + SELECT id FROM workflow_draft_variables + WHERE app_id = :app_id + LIMIT :batch_size + """), + {"app_id": app_id, "batch_size": batch_size}, + ), + # First DELETE + call( + sa.text(""" + DELETE FROM workflow_draft_variables + WHERE id IN :ids + """), + {"ids": tuple(batch1_ids)}, + ), + # Second SELECT + call( + sa.text(""" + SELECT id FROM workflow_draft_variables + WHERE app_id = :app_id + LIMIT :batch_size + """), + {"app_id": app_id, "batch_size": batch_size}, + ), + # Second DELETE + call( + sa.text(""" + DELETE FROM workflow_draft_variables + WHERE id IN :ids + """), + {"ids": tuple(batch2_ids)}, + ), + # Third SELECT (empty result) + call( + sa.text(""" + SELECT id FROM workflow_draft_variables + WHERE app_id = :app_id + LIMIT :batch_size + """), + {"app_id": app_id, "batch_size": batch_size}, + ), + ] + + # Check that all calls were made correctly + actual_calls = mock_conn.execute.call_args_list + assert len(actual_calls) == len(expected_calls) + + # Simplified verification - just check that the right number of calls were made + # and that the SQL queries contain the expected patterns + for i, actual_call in enumerate(actual_calls): + if i % 2 == 0: # SELECT calls (even indices: 0, 2, 4) + # Verify it's a SELECT query + sql_text = str(actual_call[0][0]) + assert "SELECT id FROM workflow_draft_variables" in sql_text + assert "WHERE app_id = :app_id" in sql_text + assert "LIMIT :batch_size" in sql_text + else: # DELETE calls (odd indices: 1, 3) + # Verify it's a DELETE query + sql_text = str(actual_call[0][0]) + assert "DELETE FROM workflow_draft_variables" in sql_text + assert "WHERE id IN :ids" in sql_text + + @patch("tasks.remove_app_and_related_data_task.db") + def test_delete_draft_variables_batch_empty_result(self, mock_db): + """Test deletion when no draft variables exist for the app.""" + app_id = "nonexistent-app-id" + batch_size = 1000 + + # Mock database connection + mock_conn = MagicMock() + mock_engine = MagicMock() + mock_db.engine = mock_engine + # Properly mock the context manager + mock_context_manager = MagicMock() + mock_context_manager.__enter__.return_value = mock_conn + mock_context_manager.__exit__.return_value = None + mock_engine.begin.return_value = mock_context_manager + + # Mock empty result + empty_result = MagicMock() + empty_result.__iter__.return_value = iter([]) + mock_conn.execute.return_value = empty_result + + result = delete_draft_variables_batch(app_id, batch_size) + + assert result == 0 + assert mock_conn.execute.call_count == 1 # Only one select query + + def test_delete_draft_variables_batch_invalid_batch_size(self): + """Test that invalid batch size raises ValueError.""" + app_id = "test-app-id" + + with pytest.raises(ValueError, match="batch_size must be positive"): + delete_draft_variables_batch(app_id, -1) + + with pytest.raises(ValueError, match="batch_size must be positive"): + delete_draft_variables_batch(app_id, 0) + + @patch("tasks.remove_app_and_related_data_task.db") + @patch("tasks.remove_app_and_related_data_task.logging") + def test_delete_draft_variables_batch_logs_progress(self, mock_logging, mock_db): + """Test that batch deletion logs progress correctly.""" + app_id = "test-app-id" + batch_size = 50 + + # Mock database + mock_conn = MagicMock() + mock_engine = MagicMock() + mock_db.engine = mock_engine + # Properly mock the context manager + mock_context_manager = MagicMock() + mock_context_manager.__enter__.return_value = mock_conn + mock_context_manager.__exit__.return_value = None + mock_engine.begin.return_value = mock_context_manager + + # Mock one batch then empty + batch_ids = [f"var-{i}" for i in range(30)] + # Create properly configured mocks + select_result = MagicMock() + select_result.__iter__.return_value = iter([(id_,) for id_ in batch_ids]) + + # Create simple object with rowcount attribute + class MockResult: + def __init__(self, rowcount): + self.rowcount = rowcount + + delete_result = MockResult(rowcount=30) + + empty_result = MagicMock() + empty_result.__iter__.return_value = iter([]) + + mock_conn.execute.side_effect = [ + # Select query result + select_result, + # Delete query result + delete_result, + # Empty select result (end condition) + empty_result, + ] + + result = delete_draft_variables_batch(app_id, batch_size) + + assert result == 30 + + # Verify logging calls + assert mock_logging.info.call_count == 2 + mock_logging.info.assert_any_call( + ANY # click.style call + ) + + @patch("tasks.remove_app_and_related_data_task.delete_draft_variables_batch") + def test_delete_draft_variables_calls_batch_function(self, mock_batch_delete): + """Test that _delete_draft_variables calls the batch function correctly.""" + app_id = "test-app-id" + expected_return = 42 + mock_batch_delete.return_value = expected_return + + result = _delete_draft_variables(app_id) + + assert result == expected_return + mock_batch_delete.assert_called_once_with(app_id, batch_size=1000) diff --git a/api/uv.lock b/api/uv.lock index 623b125ab3..45b020e1dd 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.11, <3.13" resolution-markers = [ "python_full_version >= '3.12.4' and platform_python_implementation != 'PyPy' and sys_platform == 'linux'", @@ -741,6 +741,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c9/af/0dcccc7fdcdf170f9a1585e5e96b6fb0ba1749ef6be8c89a6202284759bd/celery-5.5.3-py3-none-any.whl", hash = "sha256:0b5761a07057acee94694464ca482416b959568904c9dfa41ce8413a7d65d525", size = 438775, upload-time = "2025-06-01T11:08:09.94Z" }, ] +[[package]] +name = "celery-types" +version = "0.23.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/e9/d1/0823e71c281e4ad0044e278cf1577d1a68e05f2809424bf94e1614925c5d/celery_types-0.23.0.tar.gz", hash = "sha256:402ed0555aea3cd5e1e6248f4632e4f18eec8edb2435173f9e6dc08449fa101e", size = 31479, upload-time = "2025-03-03T23:56:51.547Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/6f/8b/92bb54dd74d145221c3854aa245c84f4dc04cc9366147496182cec8e88e3/celery_types-0.23.0-py3-none-any.whl", hash = "sha256:0cc495b8d7729891b7e070d0ec8d4906d2373209656a6e8b8276fe1ed306af9a", size = 50189, upload-time = "2025-03-03T23:56:50.458Z" }, +] + [[package]] name = "certifi" version = "2025.6.15" @@ -983,6 +995,25 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/42/1f/935d0810b73184a1d306f92458cb0a2e9b0de2377f536da874e063b8e422/clickhouse_connect-0.7.19-cp312-cp312-win_amd64.whl", hash = "sha256:b771ca6a473d65103dcae82810d3a62475c5372fc38d8f211513c72b954fb020", size = 239584, upload-time = "2024-08-21T21:36:22.105Z" }, ] +[[package]] +name = "clickzetta-connector-python" +version = "0.8.102" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "future" }, + { name = "numpy" }, + { name = "packaging" }, + { name = "pandas" }, + { name = "pyarrow" }, + { name = "python-dateutil" }, + { name = "requests" }, + { name = "sqlalchemy" }, + { name = "urllib3" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/c6/e5/23dcc950e873127df0135cf45144062a3207f5d2067259c73854e8ce7228/clickzetta_connector_python-0.8.102-py3-none-any.whl", hash = "sha256:c45486ae77fd82df7113ec67ec50e772372588d79c23757f8ee6291a057994a7", size = 77861, upload-time = "2025-07-17T03:11:59.543Z" }, +] + [[package]] name = "cloudscraper" version = "1.2.71" @@ -1217,7 +1248,7 @@ wheels = [ [[package]] name = "dify-api" -version = "1.7.0" +version = "1.7.2" source = { virtual = "." } dependencies = [ { name = "arize-phoenix-otel" }, @@ -1234,7 +1265,8 @@ dependencies = [ { name = "flask-cors" }, { name = "flask-login" }, { name = "flask-migrate" }, - { name = "flask-restful" }, + { name = "flask-orjson" }, + { name = "flask-restx" }, { name = "flask-sqlalchemy" }, { name = "gevent" }, { name = "gmpy2" }, @@ -1265,6 +1297,8 @@ dependencies = [ { name = "opentelemetry-instrumentation" }, { name = "opentelemetry-instrumentation-celery" }, { name = "opentelemetry-instrumentation-flask" }, + { name = "opentelemetry-instrumentation-redis" }, + { name = "opentelemetry-instrumentation-requests" }, { name = "opentelemetry-instrumentation-sqlalchemy" }, { name = "opentelemetry-propagator-b3" }, { name = "opentelemetry-proto" }, @@ -1304,6 +1338,7 @@ dependencies = [ [package.dev-dependencies] dev = [ { name = "boto3-stubs" }, + { name = "celery-types" }, { name = "coverage" }, { name = "dotenv-linter" }, { name = "faker" }, @@ -1318,6 +1353,7 @@ dev = [ { name = "pytest-mock" }, { name = "ruff" }, { name = "scipy-stubs" }, + { name = "testcontainers" }, { name = "types-aiofiles" }, { name = "types-beautifulsoup4" }, { name = "types-cachetools" }, @@ -1349,6 +1385,7 @@ dev = [ { name = "types-python-http-client" }, { name = "types-pywin32" }, { name = "types-pyyaml" }, + { name = "types-redis" }, { name = "types-regex" }, { name = "types-requests" }, { name = "types-requests-oauthlib" }, @@ -1380,6 +1417,7 @@ vdb = [ { name = "alibabacloud-tea-openapi" }, { name = "chromadb" }, { name = "clickhouse-connect" }, + { name = "clickzetta-connector-python" }, { name = "couchbase" }, { name = "elasticsearch" }, { name = "mo-vector" }, @@ -1411,12 +1449,13 @@ requires-dist = [ { name = "cachetools", specifier = "~=5.3.0" }, { name = "celery", specifier = "~=5.5.2" }, { name = "chardet", specifier = "~=5.1.0" }, - { name = "flask", specifier = "~=3.1.0" }, + { name = "flask", specifier = "~=3.1.2" }, { name = "flask-compress", specifier = "~=1.17" }, { name = "flask-cors", specifier = "~=6.0.0" }, { name = "flask-login", specifier = "~=0.6.3" }, { name = "flask-migrate", specifier = "~=4.0.7" }, - { name = "flask-restful", specifier = "~=0.3.10" }, + { name = "flask-orjson", specifier = "~=2.0.0" }, + { name = "flask-restx", specifier = ">=1.3.0" }, { name = "flask-sqlalchemy", specifier = "~=3.1.1" }, { name = "gevent", specifier = "~=24.11.1" }, { name = "gmpy2", specifier = "~=2.2.1" }, @@ -1447,6 +1486,8 @@ requires-dist = [ { name = "opentelemetry-instrumentation", specifier = "==0.48b0" }, { name = "opentelemetry-instrumentation-celery", specifier = "==0.48b0" }, { name = "opentelemetry-instrumentation-flask", specifier = "==0.48b0" }, + { name = "opentelemetry-instrumentation-redis", specifier = "==0.48b0" }, + { name = "opentelemetry-instrumentation-requests", specifier = "==0.48b0" }, { name = "opentelemetry-instrumentation-sqlalchemy", specifier = "==0.48b0" }, { name = "opentelemetry-propagator-b3", specifier = "==1.27.0" }, { name = "opentelemetry-proto", specifier = "==1.27.0" }, @@ -1486,12 +1527,13 @@ requires-dist = [ [package.metadata.requires-dev] dev = [ { name = "boto3-stubs", specifier = ">=1.38.20" }, + { name = "celery-types", specifier = ">=0.23.0" }, { name = "coverage", specifier = "~=7.2.4" }, { name = "dotenv-linter", specifier = "~=0.5.0" }, { name = "faker", specifier = "~=32.1.0" }, { name = "hypothesis", specifier = ">=6.131.15" }, { name = "lxml-stubs", specifier = "~=0.5.1" }, - { name = "mypy", specifier = "~=1.16.0" }, + { name = "mypy", specifier = "~=1.17.1" }, { name = "pandas-stubs", specifier = "~=2.2.3" }, { name = "pytest", specifier = "~=8.3.2" }, { name = "pytest-benchmark", specifier = "~=4.0.0" }, @@ -1500,6 +1542,7 @@ dev = [ { name = "pytest-mock", specifier = "~=3.14.0" }, { name = "ruff", specifier = "~=0.12.3" }, { name = "scipy-stubs", specifier = ">=1.15.3.0" }, + { name = "testcontainers", specifier = "~=4.10.0" }, { name = "types-aiofiles", specifier = "~=24.1.0" }, { name = "types-beautifulsoup4", specifier = "~=4.12.0" }, { name = "types-cachetools", specifier = "~=5.5.0" }, @@ -1531,6 +1574,7 @@ dev = [ { name = "types-python-http-client", specifier = ">=3.3.7.20240910" }, { name = "types-pywin32", specifier = "~=310.0.0" }, { name = "types-pyyaml", specifier = "~=6.0.12" }, + { name = "types-redis", specifier = ">=4.6.0.20241004" }, { name = "types-regex", specifier = "~=2024.11.6" }, { name = "types-requests", specifier = "~=2.32.0" }, { name = "types-requests-oauthlib", specifier = "~=2.0.0" }, @@ -1562,6 +1606,7 @@ vdb = [ { name = "alibabacloud-tea-openapi", specifier = "~=0.3.9" }, { name = "chromadb", specifier = "==0.5.20" }, { name = "clickhouse-connect", specifier = "~=0.7.16" }, + { name = "clickzetta-connector-python", specifier = ">=0.8.102" }, { name = "couchbase", specifier = "~=4.3.0" }, { name = "elasticsearch", specifier = "==8.14.0" }, { name = "mo-vector", specifier = "~=0.1.13" }, @@ -1571,7 +1616,7 @@ vdb = [ { name = "pgvector", specifier = "==0.2.5" }, { name = "pymilvus", specifier = "~=2.5.0" }, { name = "pymochow", specifier = "==1.3.1" }, - { name = "pyobvector", specifier = "~=0.1.6" }, + { name = "pyobvector", specifier = "~=0.2.15" }, { name = "qdrant-client", specifier = "==1.9.0" }, { name = "tablestore", specifier = "==6.2.0" }, { name = "tcvectordb", specifier = "~=1.6.4" }, @@ -1600,6 +1645,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/12/b3/231ffd4ab1fc9d679809f356cebee130ac7daa00d6d6f3206dd4fd137e9e/distro-1.9.0-py3-none-any.whl", hash = "sha256:7bffd925d65168f85027d8da9af6bddab658135b840670a223589bc0c8ef02b2", size = 20277, upload-time = "2023-12-24T09:54:30.421Z" }, ] +[[package]] +name = "docker" +version = "7.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pywin32", marker = "sys_platform == 'win32'" }, + { name = "requests" }, + { name = "urllib3" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/91/9b/4a2ea29aeba62471211598dac5d96825bb49348fa07e906ea930394a83ce/docker-7.1.0.tar.gz", hash = "sha256:ad8c70e6e3f8926cb8a92619b832b4ea5299e2831c14284663184e200546fa6c", size = 117834, upload-time = "2024-05-23T11:13:57.216Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e3/26/57c6fb270950d476074c087527a558ccb6f4436657314bfb6cdf484114c4/docker-7.1.0-py3-none-any.whl", hash = "sha256:c96b93b7f0a746f9e77d325bcfb87422a3d8bd4f03136ae8a85b37f1898d5fc0", size = 147774, upload-time = "2024-05-23T11:13:55.01Z" }, +] + [[package]] name = "docstring-parser" version = "0.16" @@ -1745,7 +1804,7 @@ wheels = [ [[package]] name = "flask" -version = "3.1.1" +version = "3.1.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "blinker" }, @@ -1755,9 +1814,9 @@ dependencies = [ { name = "markupsafe" }, { name = "werkzeug" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/c0/de/e47735752347f4128bcf354e0da07ef311a78244eba9e3dc1d4a5ab21a98/flask-3.1.1.tar.gz", hash = "sha256:284c7b8f2f58cb737f0cf1c30fd7eaf0ccfcde196099d24ecede3fc2005aa59e", size = 753440, upload-time = "2025-05-13T15:01:17.447Z" } +sdist = { url = "https://files.pythonhosted.org/packages/dc/6d/cfe3c0fcc5e477df242b98bfe186a4c34357b4847e87ecaef04507332dab/flask-3.1.2.tar.gz", hash = "sha256:bf656c15c80190ed628ad08cdfd3aaa35beb087855e2f494910aa3774cc4fd87", size = 720160, upload-time = "2025-08-19T21:03:21.205Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/3d/68/9d4508e893976286d2ead7f8f571314af6c2037af34853a30fd769c02e9d/flask-3.1.1-py3-none-any.whl", hash = "sha256:07aae2bb5eaf77993ef57e357491839f5fd9f4dc281593a81a9e4d79a24f295c", size = 103305, upload-time = "2025-05-13T15:01:15.591Z" }, + { url = "https://files.pythonhosted.org/packages/ec/f9/7f9263c5695f4bd0023734af91bedb2ff8209e8de6ead162f35d8dc762fd/flask-3.1.2-py3-none-any.whl", hash = "sha256:ca1d8112ec8a6158cc29ea4858963350011b5c846a414cdb7a954aa9e967d03c", size = 103308, upload-time = "2025-08-19T21:03:19.499Z" }, ] [[package]] @@ -1817,18 +1876,33 @@ wheels = [ ] [[package]] -name = "flask-restful" -version = "0.3.10" +name = "flask-orjson" +version = "2.0.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "flask" }, + { name = "orjson" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a3/49/575796f6ddca171d82dbb12762e33166c8b8f8616c946f0a6dfbb9bc3cd6/flask_orjson-2.0.0.tar.gz", hash = "sha256:6df6631437f9bc52cf9821735f896efa5583b5f80712f7d29d9ef69a79986a9c", size = 2974, upload-time = "2024-01-15T00:03:22.236Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f3/ca/53e14be018a2284acf799830e8cd8e0b263c0fd3dff1ad7b35f8417e7067/flask_orjson-2.0.0-py3-none-any.whl", hash = "sha256:5d15f2ba94b8d6c02aee88fc156045016e83db9eda2c30545fabd640aebaec9d", size = 3622, upload-time = "2024-01-15T00:03:17.511Z" }, +] + +[[package]] +name = "flask-restx" +version = "1.3.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "aniso8601" }, { name = "flask" }, + { name = "importlib-resources" }, + { name = "jsonschema" }, { name = "pytz" }, - { name = "six" }, + { name = "werkzeug" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/c0/ce/a0a133db616ea47f78a41e15c4c68b9f08cab3df31eb960f61899200a119/Flask-RESTful-0.3.10.tar.gz", hash = "sha256:fe4af2ef0027df8f9b4f797aba20c5566801b6ade995ac63b588abf1a59cec37", size = 110453, upload-time = "2023-05-21T03:58:55.781Z" } +sdist = { url = "https://files.pythonhosted.org/packages/45/4c/2e7d84e2b406b47cf3bf730f521efe474977b404ee170d8ea68dc37e6733/flask-restx-1.3.0.tar.gz", hash = "sha256:4f3d3fa7b6191fcc715b18c201a12cd875176f92ba4acc61626ccfd571ee1728", size = 2814072, upload-time = "2023-12-10T14:48:55.575Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/d7/7b/f0b45f0df7d2978e5ae51804bb5939b7897b2ace24306009da0cc34d8d1f/Flask_RESTful-0.3.10-py2.py3-none-any.whl", hash = "sha256:1cf93c535172f112e080b0d4503a8d15f93a48c88bdd36dd87269bdaf405051b", size = 26217, upload-time = "2023-05-21T03:58:54.004Z" }, + { url = "https://files.pythonhosted.org/packages/a5/bf/1907369f2a7ee614dde5152ff8f811159d357e77962aa3f8c2e937f63731/flask_restx-1.3.0-py2.py3-none-any.whl", hash = "sha256:636c56c3fb3f2c1df979e748019f084a938c4da2035a3e535a4673e4fc177691", size = 2798683, upload-time = "2023-12-10T14:48:53.293Z" }, ] [[package]] @@ -2091,7 +2165,7 @@ wheels = [ [[package]] name = "google-cloud-bigquery" -version = "3.34.0" +version = "3.30.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "google-api-core", extra = ["grpc"] }, @@ -2102,9 +2176,9 @@ dependencies = [ { name = "python-dateutil" }, { name = "requests" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/24/f9/e9da2d56d7028f05c0e2f5edf6ce43c773220c3172666c3dd925791d763d/google_cloud_bigquery-3.34.0.tar.gz", hash = "sha256:5ee1a78ba5c2ccb9f9a8b2bf3ed76b378ea68f49b6cac0544dc55cc97ff7c1ce", size = 489091, upload-time = "2025-05-29T17:18:06.03Z" } +sdist = { url = "https://files.pythonhosted.org/packages/f1/2f/3dda76b3ec029578838b1fe6396e6b86eb574200352240e23dea49265bb7/google_cloud_bigquery-3.30.0.tar.gz", hash = "sha256:7e27fbafc8ed33cc200fe05af12ecd74d279fe3da6692585a3cef7aee90575b6", size = 474389, upload-time = "2025-02-27T18:49:45.416Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/b1/7e/7115c4f67ca0bc678f25bff1eab56cc37d06eb9a3978940b2ebd0705aa0a/google_cloud_bigquery-3.34.0-py3-none-any.whl", hash = "sha256:de20ded0680f8136d92ff5256270b5920dfe4fae479f5d0f73e90e5df30b1cf7", size = 253555, upload-time = "2025-05-29T17:18:02.904Z" }, + { url = "https://files.pythonhosted.org/packages/0c/6d/856a6ca55c1d9d99129786c929a27dd9d31992628ebbff7f5d333352981f/google_cloud_bigquery-3.30.0-py2.py3-none-any.whl", hash = "sha256:f4d28d846a727f20569c9b2d2f4fa703242daadcb2ec4240905aa485ba461877", size = 247885, upload-time = "2025-02-27T18:49:43.454Z" }, ] [[package]] @@ -3214,28 +3288,28 @@ wheels = [ [[package]] name = "mypy" -version = "1.16.1" +version = "1.17.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "mypy-extensions" }, { name = "pathspec" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/81/69/92c7fa98112e4d9eb075a239caa4ef4649ad7d441545ccffbd5e34607cbb/mypy-1.16.1.tar.gz", hash = "sha256:6bd00a0a2094841c5e47e7374bb42b83d64c527a502e3334e1173a0c24437bab", size = 3324747, upload-time = "2025-06-16T16:51:35.145Z" } +sdist = { url = "https://files.pythonhosted.org/packages/8e/22/ea637422dedf0bf36f3ef238eab4e455e2a0dcc3082b5cc067615347ab8e/mypy-1.17.1.tar.gz", hash = "sha256:25e01ec741ab5bb3eec8ba9cdb0f769230368a22c959c4937360efb89b7e9f01", size = 3352570, upload-time = "2025-07-31T07:54:19.204Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/9a/61/ec1245aa1c325cb7a6c0f8570a2eee3bfc40fa90d19b1267f8e50b5c8645/mypy-1.16.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:472e4e4c100062488ec643f6162dd0d5208e33e2f34544e1fc931372e806c0cc", size = 10890557, upload-time = "2025-06-16T16:37:21.421Z" }, - { url = "https://files.pythonhosted.org/packages/6b/bb/6eccc0ba0aa0c7a87df24e73f0ad34170514abd8162eb0c75fd7128171fb/mypy-1.16.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ea16e2a7d2714277e349e24d19a782a663a34ed60864006e8585db08f8ad1782", size = 10012921, upload-time = "2025-06-16T16:51:28.659Z" }, - { url = "https://files.pythonhosted.org/packages/5f/80/b337a12e2006715f99f529e732c5f6a8c143bb58c92bb142d5ab380963a5/mypy-1.16.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:08e850ea22adc4d8a4014651575567b0318ede51e8e9fe7a68f25391af699507", size = 11802887, upload-time = "2025-06-16T16:50:53.627Z" }, - { url = "https://files.pythonhosted.org/packages/d9/59/f7af072d09793d581a745a25737c7c0a945760036b16aeb620f658a017af/mypy-1.16.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:22d76a63a42619bfb90122889b903519149879ddbf2ba4251834727944c8baca", size = 12531658, upload-time = "2025-06-16T16:33:55.002Z" }, - { url = "https://files.pythonhosted.org/packages/82/c4/607672f2d6c0254b94a646cfc45ad589dd71b04aa1f3d642b840f7cce06c/mypy-1.16.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:2c7ce0662b6b9dc8f4ed86eb7a5d505ee3298c04b40ec13b30e572c0e5ae17c4", size = 12732486, upload-time = "2025-06-16T16:37:03.301Z" }, - { url = "https://files.pythonhosted.org/packages/b6/5e/136555ec1d80df877a707cebf9081bd3a9f397dedc1ab9750518d87489ec/mypy-1.16.1-cp311-cp311-win_amd64.whl", hash = "sha256:211287e98e05352a2e1d4e8759c5490925a7c784ddc84207f4714822f8cf99b6", size = 9479482, upload-time = "2025-06-16T16:47:37.48Z" }, - { url = "https://files.pythonhosted.org/packages/b4/d6/39482e5fcc724c15bf6280ff5806548c7185e0c090712a3736ed4d07e8b7/mypy-1.16.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:af4792433f09575d9eeca5c63d7d90ca4aeceda9d8355e136f80f8967639183d", size = 11066493, upload-time = "2025-06-16T16:47:01.683Z" }, - { url = "https://files.pythonhosted.org/packages/e6/e5/26c347890efc6b757f4d5bb83f4a0cf5958b8cf49c938ac99b8b72b420a6/mypy-1.16.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:66df38405fd8466ce3517eda1f6640611a0b8e70895e2a9462d1d4323c5eb4b9", size = 10081687, upload-time = "2025-06-16T16:48:19.367Z" }, - { url = "https://files.pythonhosted.org/packages/44/c7/b5cb264c97b86914487d6a24bd8688c0172e37ec0f43e93b9691cae9468b/mypy-1.16.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:44e7acddb3c48bd2713994d098729494117803616e116032af192871aed80b79", size = 11839723, upload-time = "2025-06-16T16:49:20.912Z" }, - { url = "https://files.pythonhosted.org/packages/15/f8/491997a9b8a554204f834ed4816bda813aefda31cf873bb099deee3c9a99/mypy-1.16.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0ab5eca37b50188163fa7c1b73c685ac66c4e9bdee4a85c9adac0e91d8895e15", size = 12722980, upload-time = "2025-06-16T16:37:40.929Z" }, - { url = "https://files.pythonhosted.org/packages/df/f0/2bd41e174b5fd93bc9de9a28e4fb673113633b8a7f3a607fa4a73595e468/mypy-1.16.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:dedb6229b2c9086247e21a83c309754b9058b438704ad2f6807f0d8227f6ebdd", size = 12903328, upload-time = "2025-06-16T16:34:35.099Z" }, - { url = "https://files.pythonhosted.org/packages/61/81/5572108a7bec2c46b8aff7e9b524f371fe6ab5efb534d38d6b37b5490da8/mypy-1.16.1-cp312-cp312-win_amd64.whl", hash = "sha256:1f0435cf920e287ff68af3d10a118a73f212deb2ce087619eb4e648116d1fe9b", size = 9562321, upload-time = "2025-06-16T16:48:58.823Z" }, - { url = "https://files.pythonhosted.org/packages/cf/d3/53e684e78e07c1a2bf7105715e5edd09ce951fc3f47cf9ed095ec1b7a037/mypy-1.16.1-py3-none-any.whl", hash = "sha256:5fc2ac4027d0ef28d6ba69a0343737a23c4d1b83672bf38d1fe237bdc0643b37", size = 2265923, upload-time = "2025-06-16T16:48:02.366Z" }, + { url = "https://files.pythonhosted.org/packages/46/cf/eadc80c4e0a70db1c08921dcc220357ba8ab2faecb4392e3cebeb10edbfa/mypy-1.17.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ad37544be07c5d7fba814eb370e006df58fed8ad1ef33ed1649cb1889ba6ff58", size = 10921009, upload-time = "2025-07-31T07:53:23.037Z" }, + { url = "https://files.pythonhosted.org/packages/5d/c1/c869d8c067829ad30d9bdae051046561552516cfb3a14f7f0347b7d973ee/mypy-1.17.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:064e2ff508e5464b4bd807a7c1625bc5047c5022b85c70f030680e18f37273a5", size = 10047482, upload-time = "2025-07-31T07:53:26.151Z" }, + { url = "https://files.pythonhosted.org/packages/98/b9/803672bab3fe03cee2e14786ca056efda4bb511ea02dadcedde6176d06d0/mypy-1.17.1-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:70401bbabd2fa1aa7c43bb358f54037baf0586f41e83b0ae67dd0534fc64edfd", size = 11832883, upload-time = "2025-07-31T07:53:47.948Z" }, + { url = "https://files.pythonhosted.org/packages/88/fb/fcdac695beca66800918c18697b48833a9a6701de288452b6715a98cfee1/mypy-1.17.1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e92bdc656b7757c438660f775f872a669b8ff374edc4d18277d86b63edba6b8b", size = 12566215, upload-time = "2025-07-31T07:54:04.031Z" }, + { url = "https://files.pythonhosted.org/packages/7f/37/a932da3d3dace99ee8eb2043b6ab03b6768c36eb29a02f98f46c18c0da0e/mypy-1.17.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:c1fdf4abb29ed1cb091cf432979e162c208a5ac676ce35010373ff29247bcad5", size = 12751956, upload-time = "2025-07-31T07:53:36.263Z" }, + { url = "https://files.pythonhosted.org/packages/8c/cf/6438a429e0f2f5cab8bc83e53dbebfa666476f40ee322e13cac5e64b79e7/mypy-1.17.1-cp311-cp311-win_amd64.whl", hash = "sha256:ff2933428516ab63f961644bc49bc4cbe42bbffb2cd3b71cc7277c07d16b1a8b", size = 9507307, upload-time = "2025-07-31T07:53:59.734Z" }, + { url = "https://files.pythonhosted.org/packages/17/a2/7034d0d61af8098ec47902108553122baa0f438df8a713be860f7407c9e6/mypy-1.17.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:69e83ea6553a3ba79c08c6e15dbd9bfa912ec1e493bf75489ef93beb65209aeb", size = 11086295, upload-time = "2025-07-31T07:53:28.124Z" }, + { url = "https://files.pythonhosted.org/packages/14/1f/19e7e44b594d4b12f6ba8064dbe136505cec813549ca3e5191e40b1d3cc2/mypy-1.17.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:1b16708a66d38abb1e6b5702f5c2c87e133289da36f6a1d15f6a5221085c6403", size = 10112355, upload-time = "2025-07-31T07:53:21.121Z" }, + { url = "https://files.pythonhosted.org/packages/5b/69/baa33927e29e6b4c55d798a9d44db5d394072eef2bdc18c3e2048c9ed1e9/mypy-1.17.1-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:89e972c0035e9e05823907ad5398c5a73b9f47a002b22359b177d40bdaee7056", size = 11875285, upload-time = "2025-07-31T07:53:55.293Z" }, + { url = "https://files.pythonhosted.org/packages/90/13/f3a89c76b0a41e19490b01e7069713a30949d9a6c147289ee1521bcea245/mypy-1.17.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:03b6d0ed2b188e35ee6d5c36b5580cffd6da23319991c49ab5556c023ccf1341", size = 12737895, upload-time = "2025-07-31T07:53:43.623Z" }, + { url = "https://files.pythonhosted.org/packages/23/a1/c4ee79ac484241301564072e6476c5a5be2590bc2e7bfd28220033d2ef8f/mypy-1.17.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:c837b896b37cd103570d776bda106eabb8737aa6dd4f248451aecf53030cdbeb", size = 12931025, upload-time = "2025-07-31T07:54:17.125Z" }, + { url = "https://files.pythonhosted.org/packages/89/b8/7409477be7919a0608900e6320b155c72caab4fef46427c5cc75f85edadd/mypy-1.17.1-cp312-cp312-win_amd64.whl", hash = "sha256:665afab0963a4b39dff7c1fa563cc8b11ecff7910206db4b2e64dd1ba25aed19", size = 9584664, upload-time = "2025-07-31T07:54:12.842Z" }, + { url = "https://files.pythonhosted.org/packages/1d/f3/8fcd2af0f5b806f6cf463efaffd3c9548a28f84220493ecd38d127b6b66d/mypy-1.17.1-py3-none-any.whl", hash = "sha256:a9f52c0351c21fe24c21d8c0eb1f62967b262d6729393397b6f443c3b773c3b9", size = 2283411, upload-time = "2025-07-31T07:53:24.664Z" }, ] [[package]] @@ -3654,6 +3728,36 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/78/3d/fcde4f8f0bf9fa1ee73a12304fa538076fb83fe0a2ae966ab0f0b7da5109/opentelemetry_instrumentation_flask-0.48b0-py3-none-any.whl", hash = "sha256:26b045420b9d76e85493b1c23fcf27517972423480dc6cf78fd6924248ba5808", size = 14588, upload-time = "2024-08-28T21:26:58.504Z" }, ] +[[package]] +name = "opentelemetry-instrumentation-redis" +version = "0.48b0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "opentelemetry-api" }, + { name = "opentelemetry-instrumentation" }, + { name = "opentelemetry-semantic-conventions" }, + { name = "wrapt" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/70/be/92e98e4c7f275be3d373899a41b0a7d4df64266657d985dbbdb9a54de0d5/opentelemetry_instrumentation_redis-0.48b0.tar.gz", hash = "sha256:61e33e984b4120e1b980d9fba6e9f7ca0c8d972f9970654d8f6e9f27fa115a8c", size = 10511, upload-time = "2024-08-28T21:28:15.061Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/94/40/892f30d400091106309cc047fd3f6d76a828fedd984a953fd5386b78a2fb/opentelemetry_instrumentation_redis-0.48b0-py3-none-any.whl", hash = "sha256:48c7f2e25cbb30bde749dc0d8b9c74c404c851f554af832956b9630b27f5bcb7", size = 11610, upload-time = "2024-08-28T21:27:18.759Z" }, +] + +[[package]] +name = "opentelemetry-instrumentation-requests" +version = "0.48b0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "opentelemetry-api" }, + { name = "opentelemetry-instrumentation" }, + { name = "opentelemetry-semantic-conventions" }, + { name = "opentelemetry-util-http" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/52/ac/5eb78efde21ff21d0ad5dc8c6cc6a0f8ae482ce8a46293c2f45a628b6166/opentelemetry_instrumentation_requests-0.48b0.tar.gz", hash = "sha256:67ab9bd877a0352ee0db4616c8b4ae59736ddd700c598ed907482d44f4c9a2b3", size = 14120, upload-time = "2024-08-28T21:28:16.933Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/43/df/0df9226d1b14f29d23c07e6194b9fd5ad50e7d987b7fd13df7dcf718aeb1/opentelemetry_instrumentation_requests-0.48b0-py3-none-any.whl", hash = "sha256:d4f01852121d0bd4c22f14f429654a735611d4f7bf3cf93f244bdf1489b2233d", size = 12366, upload-time = "2024-08-28T21:27:20.771Z" }, +] + [[package]] name = "opentelemetry-instrumentation-sqlalchemy" version = "0.48b0" @@ -3868,11 +3972,11 @@ wheels = [ [[package]] name = "packaging" -version = "24.2" +version = "23.2" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/d0/63/68dbb6eb2de9cb10ee4c9c14a0148804425e13c4fb20d61cce69f53106da/packaging-24.2.tar.gz", hash = "sha256:c228a6dc5e932d346bc5739379109d49e8853dd8223571c7c5b55260edc0b97f", size = 163950, upload-time = "2024-11-08T09:47:47.202Z" } +sdist = { url = "https://files.pythonhosted.org/packages/fb/2b/9b9c33ffed44ee921d0967086d653047286054117d584f1b1a7c22ceaf7b/packaging-23.2.tar.gz", hash = "sha256:048fb0e9405036518eaaf48a55953c750c11e1a1b68e0dd1a9d62ed0c092cfc5", size = 146714, upload-time = "2023-10-01T13:50:05.279Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/88/ef/eb23f262cca3c0c4eb7ab1933c3b1f03d021f2c48f54763065b6f0e321be/packaging-24.2-py3-none-any.whl", hash = "sha256:09abb1bccd265c01f4a3aa3f7a7db064b36514d2cba19a2f694fe6150451a759", size = 65451, upload-time = "2024-11-08T09:47:44.722Z" }, + { url = "https://files.pythonhosted.org/packages/ec/1a/610693ac4ee14fcdf2d9bf3c493370e4f2ef7ae2e19217d7a237ff42367d/packaging-23.2-py3-none-any.whl", hash = "sha256:8c491190033a9af7e1d931d0b5dacc2ef47509b34dd0de67ed209b5203fc88c7", size = 53011, upload-time = "2023-10-01T13:50:03.745Z" }, ] [[package]] @@ -4252,6 +4356,31 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e0/a9/023730ba63db1e494a271cb018dcd361bd2c917ba7004c3e49d5daf795a2/py_cpuinfo-9.0.0-py3-none-any.whl", hash = "sha256:859625bc251f64e21f077d099d4162689c762b5d6a4c3c97553d56241c9674d5", size = 22335, upload-time = "2022-10-25T20:38:27.636Z" }, ] +[[package]] +name = "pyarrow" +version = "14.0.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/d7/8b/d18b7eb6fb22e5ed6ffcbc073c85dae635778dbd1270a6cf5d750b031e84/pyarrow-14.0.2.tar.gz", hash = "sha256:36cef6ba12b499d864d1def3e990f97949e0b79400d08b7cf74504ffbd3eb025", size = 1063645, upload-time = "2023-12-18T15:43:41.625Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/94/8a/411ef0b05483076b7f548c74ccaa0f90c1e60d3875db71a821f6ffa8cf42/pyarrow-14.0.2-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:87482af32e5a0c0cce2d12eb3c039dd1d853bd905b04f3f953f147c7a196915b", size = 26904455, upload-time = "2023-12-18T15:40:43.477Z" }, + { url = "https://files.pythonhosted.org/packages/6c/6c/882a57798877e3a49ba54d8e0540bea24aed78fb42e1d860f08c3449c75e/pyarrow-14.0.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:059bd8f12a70519e46cd64e1ba40e97eae55e0cbe1695edd95384653d7626b23", size = 23997116, upload-time = "2023-12-18T15:40:48.533Z" }, + { url = "https://files.pythonhosted.org/packages/ec/3f/ef47fe6192ce4d82803a073db449b5292135406c364a7fc49dfbcd34c987/pyarrow-14.0.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3f16111f9ab27e60b391c5f6d197510e3ad6654e73857b4e394861fc79c37200", size = 35944575, upload-time = "2023-12-18T15:40:55.128Z" }, + { url = "https://files.pythonhosted.org/packages/1a/90/2021e529d7f234a3909f419d4341d53382541ef77d957fa274a99c533b18/pyarrow-14.0.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:06ff1264fe4448e8d02073f5ce45a9f934c0f3db0a04460d0b01ff28befc3696", size = 38079719, upload-time = "2023-12-18T15:41:02.565Z" }, + { url = "https://files.pythonhosted.org/packages/30/a9/474caf5fd54a6d5315aaf9284c6e8f5d071ca825325ad64c53137b646e1f/pyarrow-14.0.2-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:6dd4f4b472ccf4042f1eab77e6c8bce574543f54d2135c7e396f413046397d5a", size = 35429706, upload-time = "2023-12-18T15:41:09.955Z" }, + { url = "https://files.pythonhosted.org/packages/d9/f8/cfba56f5353e51c19b0c240380ce39483f4c76e5c4aee5a000f3d75b72da/pyarrow-14.0.2-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:32356bfb58b36059773f49e4e214996888eeea3a08893e7dbde44753799b2a02", size = 38001476, upload-time = "2023-12-18T15:41:16.372Z" }, + { url = "https://files.pythonhosted.org/packages/43/3f/7bdf7dc3b3b0cfdcc60760e7880954ba99ccd0bc1e0df806f3dd61bc01cd/pyarrow-14.0.2-cp311-cp311-win_amd64.whl", hash = "sha256:52809ee69d4dbf2241c0e4366d949ba035cbcf48409bf404f071f624ed313a2b", size = 24576230, upload-time = "2023-12-18T15:41:22.561Z" }, + { url = "https://files.pythonhosted.org/packages/69/5b/d8ab6c20c43b598228710e4e4a6cba03a01f6faa3d08afff9ce76fd0fd47/pyarrow-14.0.2-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:c87824a5ac52be210d32906c715f4ed7053d0180c1060ae3ff9b7e560f53f944", size = 26819585, upload-time = "2023-12-18T15:41:27.59Z" }, + { url = "https://files.pythonhosted.org/packages/2d/29/bed2643d0dd5e9570405244a61f6db66c7f4704a6e9ce313f84fa5a3675a/pyarrow-14.0.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:a25eb2421a58e861f6ca91f43339d215476f4fe159eca603c55950c14f378cc5", size = 23965222, upload-time = "2023-12-18T15:41:32.449Z" }, + { url = "https://files.pythonhosted.org/packages/2a/34/da464632e59a8cdd083370d69e6c14eae30221acb284f671c6bc9273fadd/pyarrow-14.0.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5c1da70d668af5620b8ba0a23f229030a4cd6c5f24a616a146f30d2386fec422", size = 35942036, upload-time = "2023-12-18T15:41:38.767Z" }, + { url = "https://files.pythonhosted.org/packages/a8/ff/cbed4836d543b29f00d2355af67575c934999ff1d43e3f438ab0b1b394f1/pyarrow-14.0.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2cc61593c8e66194c7cdfae594503e91b926a228fba40b5cf25cc593563bcd07", size = 38089266, upload-time = "2023-12-18T15:41:47.617Z" }, + { url = "https://files.pythonhosted.org/packages/38/41/345011cb831d3dbb2dab762fc244c745a5df94b199223a99af52a5f7dff6/pyarrow-14.0.2-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:78ea56f62fb7c0ae8ecb9afdd7893e3a7dbeb0b04106f5c08dbb23f9c0157591", size = 35404468, upload-time = "2023-12-18T15:41:54.49Z" }, + { url = "https://files.pythonhosted.org/packages/fd/af/2fc23ca2068ff02068d8dabf0fb85b6185df40ec825973470e613dbd8790/pyarrow-14.0.2-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:37c233ddbce0c67a76c0985612fef27c0c92aef9413cf5aa56952f359fcb7379", size = 38003134, upload-time = "2023-12-18T15:42:01.593Z" }, + { url = "https://files.pythonhosted.org/packages/95/1f/9d912f66a87e3864f694e000977a6a70a644ea560289eac1d733983f215d/pyarrow-14.0.2-cp312-cp312-win_amd64.whl", hash = "sha256:e4b123ad0f6add92de898214d404e488167b87b5dd86e9a434126bc2b7a5578d", size = 25043754, upload-time = "2023-12-18T15:42:07.108Z" }, +] + [[package]] name = "pyasn1" version = "0.6.1" @@ -4456,17 +4585,19 @@ wheels = [ [[package]] name = "pyobvector" -version = "0.1.14" +version = "0.2.15" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "aiomysql" }, { name = "numpy" }, + { name = "pydantic" }, { name = "pymysql" }, { name = "sqlalchemy" }, + { name = "sqlglot" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/dc/59/7d762061808948dd6aad165a000b34e22163dc83fb5014184eeacc0fabe5/pyobvector-0.1.14.tar.gz", hash = "sha256:4f85cdd63064d040e94c0a96099a0cd5cda18ce625865382e89429f28422fc02", size = 26780, upload-time = "2024-11-20T11:46:18.017Z" } +sdist = { url = "https://files.pythonhosted.org/packages/0b/7d/3f3aac6acf1fdd1782042d6eecd48efaa2ee355af0dbb61e93292d629391/pyobvector-0.2.15.tar.gz", hash = "sha256:5de258c1e952c88b385b5661e130c1cf8262c498c1f8a4a348a35962d379fce4", size = 39611, upload-time = "2025-08-18T02:49:26.683Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/88/68/ecb21b74c974e7be7f9034e205d08db62d614ff5c221581ae96d37ef853e/pyobvector-0.1.14-py3-none-any.whl", hash = "sha256:828e0bec49a177355b70c7a1270af3b0bf5239200ee0d096e4165b267eeff97c", size = 35526, upload-time = "2024-11-20T11:46:16.809Z" }, + { url = "https://files.pythonhosted.org/packages/5f/1f/a62754ba9b8a02c038d2a96cb641b71d3809f34d2ba4f921fecd7840d7fb/pyobvector-0.2.15-py3-none-any.whl", hash = "sha256:feeefe849ee5400e72a9a4d3844e425a58a99053dd02abe06884206923065ebb", size = 52680, upload-time = "2025-08-18T02:49:25.452Z" }, ] [[package]] @@ -5319,6 +5450,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/1c/fc/9ba22f01b5cdacc8f5ed0d22304718d2c758fce3fd49a5372b886a86f37c/sqlalchemy-2.0.41-py3-none-any.whl", hash = "sha256:57df5dc6fdb5ed1a88a1ed2195fd31927e705cad62dedd86b46972752a80f576", size = 1911224, upload-time = "2025-05-14T17:39:42.154Z" }, ] +[[package]] +name = "sqlglot" +version = "26.33.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/25/9d/fcd59b4612d5ad1e2257c67c478107f073b19e1097d3bfde2fb517884416/sqlglot-26.33.0.tar.gz", hash = "sha256:2817278779fa51d6def43aa0d70690b93a25c83eb18ec97130fdaf707abc0d73", size = 5353340, upload-time = "2025-07-01T13:09:06.311Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/31/8d/f1d9cb5b18e06aa45689fbeaaea6ebab66d5f01d1e65029a8f7657c06be5/sqlglot-26.33.0-py3-none-any.whl", hash = "sha256:031cee20c0c796a83d26d079a47fdce667604df430598c7eabfa4e4dfd147033", size = 477610, upload-time = "2025-07-01T13:09:03.926Z" }, +] + [[package]] name = "sseclient-py" version = "1.8.0" @@ -5468,6 +5608,22 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e5/30/643397144bfbfec6f6ef821f36f33e57d35946c44a2352d3c9f0ae847619/tenacity-9.1.2-py3-none-any.whl", hash = "sha256:f77bf36710d8b73a50b2dd155c97b870017ad21afe6ab300326b0371b3b05138", size = 28248, upload-time = "2025-04-02T08:25:07.678Z" }, ] +[[package]] +name = "testcontainers" +version = "4.10.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "docker" }, + { name = "python-dotenv" }, + { name = "typing-extensions" }, + { name = "urllib3" }, + { name = "wrapt" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a1/49/9c618aff1c50121d183cdfbc3a4a5cf2727a2cde1893efe6ca55c7009196/testcontainers-4.10.0.tar.gz", hash = "sha256:03f85c3e505d8b4edeb192c72a961cebbcba0dd94344ae778b4a159cb6dcf8d3", size = 63327, upload-time = "2025-04-02T16:13:27.582Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1c/0a/824b0c1ecf224802125279c3effff2e25ed785ed046e67da6e53d928de4c/testcontainers-4.10.0-py3-none-any.whl", hash = "sha256:31ed1a81238c7e131a2a29df6db8f23717d892b592fa5a1977fd0dcd0c23fc23", size = 107414, upload-time = "2025-04-02T16:13:25.785Z" }, +] + [[package]] name = "tidb-vector" version = "0.0.9" @@ -5952,6 +6108,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/99/5f/e0af6f7f6a260d9af67e1db4f54d732abad514252a7a378a6c4d17dd1036/types_pyyaml-6.0.12.20250516-py3-none-any.whl", hash = "sha256:8478208feaeb53a34cb5d970c56a7cd76b72659442e733e268a94dc72b2d0530", size = 20312, upload-time = "2025-05-16T03:08:04.019Z" }, ] +[[package]] +name = "types-redis" +version = "4.6.0.20241004" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cryptography" }, + { name = "types-pyopenssl" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/3a/95/c054d3ac940e8bac4ca216470c80c26688a0e79e09f520a942bb27da3386/types-redis-4.6.0.20241004.tar.gz", hash = "sha256:5f17d2b3f9091ab75384153bfa276619ffa1cf6a38da60e10d5e6749cc5b902e", size = 49679, upload-time = "2024-10-04T02:43:59.224Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/55/82/7d25dce10aad92d2226b269bce2f85cfd843b4477cd50245d7d40ecf8f89/types_redis-4.6.0.20241004-py3-none-any.whl", hash = "sha256:ef5da68cb827e5f606c8f9c0b49eeee4c2669d6d97122f301d3a55dc6a63f6ed", size = 58737, upload-time = "2024-10-04T02:43:57.968Z" }, +] + [[package]] name = "types-regex" version = "2024.11.6.20250403" diff --git a/dev/pytest/pytest_all_tests.sh b/dev/pytest/pytest_all_tests.sh index 30898b4fcf..9123b2f8ad 100755 --- a/dev/pytest/pytest_all_tests.sh +++ b/dev/pytest/pytest_all_tests.sh @@ -15,3 +15,6 @@ dev/pytest/pytest_workflow.sh # Unit tests dev/pytest/pytest_unit_tests.sh + +# TestContainers tests +dev/pytest/pytest_testcontainers.sh diff --git a/dev/pytest/pytest_testcontainers.sh b/dev/pytest/pytest_testcontainers.sh new file mode 100755 index 0000000000..e55a436138 --- /dev/null +++ b/dev/pytest/pytest_testcontainers.sh @@ -0,0 +1,7 @@ +#!/bin/bash +set -x + +SCRIPT_DIR="$(dirname "$(realpath "$0")")" +cd "$SCRIPT_DIR/../.." + +pytest api/tests/test_containers_integration_tests diff --git a/dev/start-worker b/dev/start-worker index 7007b265e0..66e446c831 100755 --- a/dev/start-worker +++ b/dev/start-worker @@ -8,4 +8,4 @@ cd "$SCRIPT_DIR/.." uv --directory api run \ celery -A app.celery worker \ - -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion + -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage diff --git a/docker/.env.example b/docker/.env.example index 88cc544730..711898016e 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -52,6 +52,11 @@ FILES_URL= # Example: INTERNAL_FILES_URL=http://api:5001 INTERNAL_FILES_URL= +# Ensure UTF-8 encoding +LANG=en_US.UTF-8 +LC_ALL=en_US.UTF-8 +PYTHONIOENCODING=utf-8 + # ------------------------------ # Server Configuration # ------------------------------ @@ -210,6 +215,8 @@ DB_DATABASE=dify # The size of the database connection pool. # The default is 30 connections, which can be appropriately increased. SQLALCHEMY_POOL_SIZE=30 +# The default is 10 connections, which allows temporary overflow beyond the pool size. +SQLALCHEMY_MAX_OVERFLOW=10 # Database connection pool recycling time, the default is 3600 seconds. SQLALCHEMY_POOL_RECYCLE=3600 # Whether to print SQL, default is false. @@ -259,6 +266,15 @@ REDIS_PORT=6379 REDIS_USERNAME= REDIS_PASSWORD=difyai123456 REDIS_USE_SSL=false +# SSL configuration for Redis (when REDIS_USE_SSL=true) +REDIS_SSL_CERT_REQS=CERT_NONE +# Options: CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED +REDIS_SSL_CA_CERTS= +# Path to CA certificate file for SSL verification +REDIS_SSL_CERTFILE= +# Path to client certificate file for SSL authentication +REDIS_SSL_KEYFILE= +# Path to client private key file for SSL authentication REDIS_DB=0 # Whether to use Redis Sentinel mode. @@ -328,6 +344,25 @@ OPENDAL_SCHEME=fs # Configurations for OpenDAL Local File System. OPENDAL_FS_ROOT=storage +# ClickZetta Volume Configuration (for storage backend) +# To use ClickZetta Volume as storage backend, set STORAGE_TYPE=clickzetta-volume +# Note: ClickZetta Volume will reuse the existing CLICKZETTA_* connection parameters + +# Volume type selection (three types available): +# - user: Personal/small team use, simple config, user-level permissions +# - table: Enterprise multi-tenant, smart routing, table-level + user-level permissions +# - external: Data lake integration, external storage connection, volume-level + storage-level permissions +CLICKZETTA_VOLUME_TYPE=user + +# External Volume name (required only when TYPE=external) +CLICKZETTA_VOLUME_NAME= + +# Table Volume table prefix (used only when TYPE=table) +CLICKZETTA_VOLUME_TABLE_PREFIX=dataset_ + +# Dify file directory prefix (isolates from other apps, recommended to keep default) +CLICKZETTA_VOLUME_DIFY_PREFIX=dify_km + # S3 Configuration # S3_ENDPOINT= @@ -411,7 +446,7 @@ SUPABASE_URL=your-server-url # ------------------------------ # The type of vector store to use. -# Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `oceanbase`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`. +# Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `oceanbase`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`, `clickzetta`. VECTOR_STORE=weaviate # Prefix used to create collection name in vector database VECTOR_INDEX_NAME_PREFIX=Vector_index @@ -578,6 +613,17 @@ ELASTICSEARCH_USERNAME=elastic ELASTICSEARCH_PASSWORD=elastic KIBANA_PORT=5601 +# Using ElasticSearch Cloud Serverless, or not. +ELASTICSEARCH_USE_CLOUD=false +ELASTICSEARCH_CLOUD_URL=YOUR-ELASTICSEARCH_CLOUD_URL +ELASTICSEARCH_API_KEY=YOUR-ELASTICSEARCH_API_KEY + +ELASTICSEARCH_VERIFY_CERTS=False +ELASTICSEARCH_CA_CERTS= +ELASTICSEARCH_REQUEST_TIMEOUT=100000 +ELASTICSEARCH_RETRY_ON_TIMEOUT=True +ELASTICSEARCH_MAX_RETRIES=10 + # baidu vector configurations, only available when VECTOR_STORE is `baidu` BAIDU_VECTOR_DB_ENDPOINT=http://127.0.0.1:5287 BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS=30000 @@ -637,6 +683,21 @@ TABLESTORE_ENDPOINT=https://instance-name.cn-hangzhou.ots.aliyuncs.com TABLESTORE_INSTANCE_NAME=instance-name TABLESTORE_ACCESS_KEY_ID=xxx TABLESTORE_ACCESS_KEY_SECRET=xxx +TABLESTORE_NORMALIZE_FULLTEXT_BM25_SCORE=false + +# Clickzetta configuration, only available when VECTOR_STORE is `clickzetta` +CLICKZETTA_USERNAME= +CLICKZETTA_PASSWORD= +CLICKZETTA_INSTANCE= +CLICKZETTA_SERVICE=api.clickzetta.com +CLICKZETTA_WORKSPACE=quick_start +CLICKZETTA_VCLUSTER=default_ap +CLICKZETTA_SCHEMA=dify +CLICKZETTA_BATCH_SIZE=100 +CLICKZETTA_ENABLE_INVERTED_INDEX=true +CLICKZETTA_ANALYZER_TYPE=chinese +CLICKZETTA_ANALYZER_MODE=smart +CLICKZETTA_VECTOR_DISTANCE_FUNCTION=cosine_distance # ------------------------------ # Knowledge Configuration @@ -811,16 +872,30 @@ WORKFLOW_NODE_EXECUTION_STORAGE=rdbms # Repository configuration # Core workflow execution repository implementation +# Options: +# - core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository (default) +# - core.repositories.celery_workflow_execution_repository.CeleryWorkflowExecutionRepository CORE_WORKFLOW_EXECUTION_REPOSITORY=core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository # Core workflow node execution repository implementation +# Options: +# - core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository (default) +# - core.repositories.celery_workflow_node_execution_repository.CeleryWorkflowNodeExecutionRepository CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY=core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository +# API workflow run repository implementation +API_WORKFLOW_RUN_REPOSITORY=repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository + # API workflow node execution repository implementation API_WORKFLOW_NODE_EXECUTION_REPOSITORY=repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository -# API workflow run repository implementation -API_WORKFLOW_RUN_REPOSITORY=repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository +# Workflow log cleanup configuration +# Enable automatic cleanup of workflow run logs to manage database size +WORKFLOW_LOG_CLEANUP_ENABLED=false +# Number of days to retain workflow run logs (default: 30 days) +WORKFLOW_LOG_RETENTION_DAYS=30 +# Batch size for workflow log cleanup operations (default: 100) +WORKFLOW_LOG_CLEANUP_BATCH_SIZE=100 # HTTP request node in workflow configuration HTTP_REQUEST_NODE_MAX_BINARY_SIZE=10485760 @@ -857,6 +932,9 @@ TEXT_GENERATION_TIMEOUT_MS=60000 # Allow rendering unsafe URLs which have "data:" scheme. ALLOW_UNSAFE_DATA_SCHEME=false +# Maximum number of tree depth in the workflow +MAX_TREE_DEPTH=50 + # ------------------------------ # Environment Variables for db Service # ------------------------------ @@ -1098,6 +1176,9 @@ MARKETPLACE_API_URL=https://marketplace.dify.ai FORCE_VERIFYING_SIGNATURE=true +PLUGIN_STDIO_BUFFER_SIZE=1024 +PLUGIN_STDIO_MAX_BUFFER_SIZE=5242880 + PLUGIN_PYTHON_ENV_INIT_TIMEOUT=120 PLUGIN_MAX_EXECUTION_TIMEOUT=600 # PIP_MIRROR_URL=https://pypi.tuna.tsinghua.edu.cn/simple diff --git a/docker/README.md b/docker/README.md index 22dfe2c91c..b5c46eb9fc 100644 --- a/docker/README.md +++ b/docker/README.md @@ -4,7 +4,7 @@ Welcome to the new `docker` directory for deploying Dify using Docker Compose. T ### What's Updated -- **Certbot Container**: `docker-compose.yaml` now contains `certbot` for managing SSL certificates. This container automatically renews certificates and ensures secure HTTPS connections. +- **Certbot Container**: `docker-compose.yaml` now contains `certbot` for managing SSL certificates. This container automatically renews certificates and ensures secure HTTPS connections.\ For more information, refer `docker/certbot/README.md`. - **Persistent Environment Variables**: Environment variables are now managed through a `.env` file, ensuring that your configurations persist across deployments. @@ -13,43 +13,44 @@ Welcome to the new `docker` directory for deploying Dify using Docker Compose. T > The `.env` file is a crucial component in Docker and Docker Compose environments, serving as a centralized configuration file where you can define environment variables that are accessible to the containers at runtime. This file simplifies the management of environment settings across different stages of development, testing, and production, providing consistency and ease of configuration to deployments. - **Unified Vector Database Services**: All vector database services are now managed from a single Docker Compose file `docker-compose.yaml`. You can switch between different vector databases by setting the `VECTOR_STORE` environment variable in your `.env` file. + - **Mandatory .env File**: A `.env` file is now required to run `docker compose up`. This file is crucial for configuring your deployment and for any custom settings to persist through upgrades. ### How to Deploy Dify with `docker-compose.yaml` 1. **Prerequisites**: Ensure Docker and Docker Compose are installed on your system. -2. **Environment Setup**: - - Navigate to the `docker` directory. - - Copy the `.env.example` file to a new file named `.env` by running `cp .env.example .env`. - - Customize the `.env` file as needed. Refer to the `.env.example` file for detailed configuration options. -3. **Running the Services**: - - Execute `docker compose up` from the `docker` directory to start the services. - - To specify a vector database, set the `VECTOR_STORE` variable in your `.env` file to your desired vector database service, such as `milvus`, `weaviate`, or `opensearch`. -4. **SSL Certificate Setup**: - - Refer `docker/certbot/README.md` to set up SSL certificates using Certbot. -5. **OpenTelemetry Collector Setup**: +1. **Environment Setup**: + - Navigate to the `docker` directory. + - Copy the `.env.example` file to a new file named `.env` by running `cp .env.example .env`. + - Customize the `.env` file as needed. Refer to the `.env.example` file for detailed configuration options. +1. **Running the Services**: + - Execute `docker compose up` from the `docker` directory to start the services. + - To specify a vector database, set the `VECTOR_STORE` variable in your `.env` file to your desired vector database service, such as `milvus`, `weaviate`, or `opensearch`. +1. **SSL Certificate Setup**: + - Refer `docker/certbot/README.md` to set up SSL certificates using Certbot. +1. **OpenTelemetry Collector Setup**: - Change `ENABLE_OTEL` to `true` in `.env`. - Configure `OTLP_BASE_ENDPOINT` properly. ### How to Deploy Middleware for Developing Dify 1. **Middleware Setup**: - - Use the `docker-compose.middleware.yaml` for setting up essential middleware services like databases and caches. - - Navigate to the `docker` directory. - - Ensure the `middleware.env` file is created by running `cp middleware.env.example middleware.env` (refer to the `middleware.env.example` file). -2. **Running Middleware Services**: - - Navigate to the `docker` directory. - - Execute `docker compose -f docker-compose.middleware.yaml --profile weaviate -p dify up -d` to start the middleware services. (Change the profile to other vector database if you are not using weaviate) + - Use the `docker-compose.middleware.yaml` for setting up essential middleware services like databases and caches. + - Navigate to the `docker` directory. + - Ensure the `middleware.env` file is created by running `cp middleware.env.example middleware.env` (refer to the `middleware.env.example` file). +1. **Running Middleware Services**: + - Navigate to the `docker` directory. + - Execute `docker compose -f docker-compose.middleware.yaml --profile weaviate -p dify up -d` to start the middleware services. (Change the profile to other vector database if you are not using weaviate) ### Migration for Existing Users For users migrating from the `docker-legacy` setup: 1. **Review Changes**: Familiarize yourself with the new `.env` configuration and Docker Compose setup. -2. **Transfer Customizations**: - - If you have customized configurations such as `docker-compose.yaml`, `ssrf_proxy/squid.conf`, or `nginx/conf.d/default.conf`, you will need to reflect these changes in the `.env` file you create. -3. **Data Migration**: - - Ensure that data from services like databases and caches is backed up and migrated appropriately to the new structure if necessary. +1. **Transfer Customizations**: + - If you have customized configurations such as `docker-compose.yaml`, `ssrf_proxy/squid.conf`, or `nginx/conf.d/default.conf`, you will need to reflect these changes in the `.env` file you create. +1. **Data Migration**: + - Ensure that data from services like databases and caches is backed up and migrated appropriately to the new structure if necessary. ### Overview of `.env` @@ -64,39 +65,49 @@ For users migrating from the `docker-legacy` setup: The `.env.example` file provided in the Docker setup is extensive and covers a wide range of configuration options. It is structured into several sections, each pertaining to different aspects of the application and its services. Here are some of the key sections and variables: 1. **Common Variables**: - - `CONSOLE_API_URL`, `SERVICE_API_URL`: URLs for different API services. - - `APP_WEB_URL`: Frontend application URL. - - `FILES_URL`: Base URL for file downloads and previews. -2. **Server Configuration**: - - `LOG_LEVEL`, `DEBUG`, `FLASK_DEBUG`: Logging and debug settings. - - `SECRET_KEY`: A key for encrypting session cookies and other sensitive data. + - `CONSOLE_API_URL`, `SERVICE_API_URL`: URLs for different API services. + - `APP_WEB_URL`: Frontend application URL. + - `FILES_URL`: Base URL for file downloads and previews. -3. **Database Configuration**: - - `DB_USERNAME`, `DB_PASSWORD`, `DB_HOST`, `DB_PORT`, `DB_DATABASE`: PostgreSQL database credentials and connection details. +1. **Server Configuration**: -4. **Redis Configuration**: - - `REDIS_HOST`, `REDIS_PORT`, `REDIS_PASSWORD`: Redis server connection settings. + - `LOG_LEVEL`, `DEBUG`, `FLASK_DEBUG`: Logging and debug settings. + - `SECRET_KEY`: A key for encrypting session cookies and other sensitive data. -5. **Celery Configuration**: - - `CELERY_BROKER_URL`: Configuration for Celery message broker. +1. **Database Configuration**: -6. **Storage Configuration**: - - `STORAGE_TYPE`, `S3_BUCKET_NAME`, `AZURE_BLOB_ACCOUNT_NAME`: Settings for file storage options like local, S3, Azure Blob, etc. + - `DB_USERNAME`, `DB_PASSWORD`, `DB_HOST`, `DB_PORT`, `DB_DATABASE`: PostgreSQL database credentials and connection details. -7. **Vector Database Configuration**: - - `VECTOR_STORE`: Type of vector database (e.g., `weaviate`, `milvus`). - - Specific settings for each vector store like `WEAVIATE_ENDPOINT`, `MILVUS_URI`. +1. **Redis Configuration**: -8. **CORS Configuration**: - - `WEB_API_CORS_ALLOW_ORIGINS`, `CONSOLE_CORS_ALLOW_ORIGINS`: Settings for cross-origin resource sharing. + - `REDIS_HOST`, `REDIS_PORT`, `REDIS_PASSWORD`: Redis server connection settings. -9. **OpenTelemetry Configuration**: - - `ENABLE_OTEL`: Enable OpenTelemetry collector in api. - - `OTLP_BASE_ENDPOINT`: Endpoint for your OTLP exporter. - -10. **Other Service-Specific Environment Variables**: - - Each service like `nginx`, `redis`, `db`, and vector databases have specific environment variables that are directly referenced in the `docker-compose.yaml`. +1. **Celery Configuration**: + + - `CELERY_BROKER_URL`: Configuration for Celery message broker. + +1. **Storage Configuration**: + + - `STORAGE_TYPE`, `S3_BUCKET_NAME`, `AZURE_BLOB_ACCOUNT_NAME`: Settings for file storage options like local, S3, Azure Blob, etc. + +1. **Vector Database Configuration**: + + - `VECTOR_STORE`: Type of vector database (e.g., `weaviate`, `milvus`). + - Specific settings for each vector store like `WEAVIATE_ENDPOINT`, `MILVUS_URI`. + +1. **CORS Configuration**: + + - `WEB_API_CORS_ALLOW_ORIGINS`, `CONSOLE_CORS_ALLOW_ORIGINS`: Settings for cross-origin resource sharing. + +1. **OpenTelemetry Configuration**: + + - `ENABLE_OTEL`: Enable OpenTelemetry collector in api. + - `OTLP_BASE_ENDPOINT`: Endpoint for your OTLP exporter. + +1. **Other Service-Specific Environment Variables**: + + - Each service like `nginx`, `redis`, `db`, and vector databases have specific environment variables that are directly referenced in the `docker-compose.yaml`. ### Additional Information diff --git a/docker/certbot/README.md b/docker/certbot/README.md index 21be34b33a..62b1eee395 100644 --- a/docker/certbot/README.md +++ b/docker/certbot/README.md @@ -2,12 +2,12 @@ ## Short description -docker compose certbot configurations with Backward compatibility (without certbot container). +docker compose certbot configurations with Backward compatibility (without certbot container).\ Use `docker compose --profile certbot up` to use this features. ## The simplest way for launching new servers with SSL certificates -1. Get letsencrypt certs +1. Get letsencrypt certs\ set `.env` values ```properties NGINX_SSL_CERT_FILENAME=fullchain.pem @@ -25,7 +25,7 @@ Use `docker compose --profile certbot up` to use this features. ```shell docker compose exec -it certbot /bin/sh /update-cert.sh ``` -2. Edit `.env` file and `docker compose --profile certbot up` again. +1. Edit `.env` file and `docker compose --profile certbot up` again.\ set `.env` value additionally ```properties NGINX_HTTPS_ENABLED=true @@ -34,7 +34,7 @@ Use `docker compose --profile certbot up` to use this features. ```shell docker compose --profile certbot up -d --no-deps --force-recreate nginx ``` - Then you can access your serve with HTTPS. + Then you can access your serve with HTTPS.\ [https://your_domain.com](https://your_domain.com) ## SSL certificates renewal diff --git a/docker/docker-compose-template.yaml b/docker/docker-compose-template.yaml index 394a068200..04981f6b7f 100644 --- a/docker/docker-compose-template.yaml +++ b/docker/docker-compose-template.yaml @@ -2,7 +2,7 @@ x-shared-env: &shared-api-worker-env services: # API service api: - image: langgenius/dify-api:1.7.0 + image: langgenius/dify-api:1.7.2 restart: always environment: # Use the shared environment variables. @@ -31,7 +31,7 @@ services: # worker service # The Celery worker for processing the queue. worker: - image: langgenius/dify-api:1.7.0 + image: langgenius/dify-api:1.7.2 restart: always environment: # Use the shared environment variables. @@ -58,7 +58,7 @@ services: # worker_beat service # Celery beat for scheduling periodic tasks. worker_beat: - image: langgenius/dify-api:1.7.0 + image: langgenius/dify-api:1.7.2 restart: always environment: # Use the shared environment variables. @@ -76,7 +76,7 @@ services: # Frontend web application. web: - image: langgenius/dify-web:1.7.0 + image: langgenius/dify-web:1.7.2 restart: always environment: CONSOLE_API_URL: ${CONSOLE_API_URL:-} @@ -96,6 +96,7 @@ services: MAX_TOOLS_NUM: ${MAX_TOOLS_NUM:-10} MAX_PARALLEL_LIMIT: ${MAX_PARALLEL_LIMIT:-10} MAX_ITERATIONS_NUM: ${MAX_ITERATIONS_NUM:-99} + MAX_TREE_DEPTH: ${MAX_TREE_DEPTH:-50} ENABLE_WEBSITE_JINAREADER: ${ENABLE_WEBSITE_JINAREADER:-true} ENABLE_WEBSITE_FIRECRAWL: ${ENABLE_WEBSITE_FIRECRAWL:-true} ENABLE_WEBSITE_WATERCRAWL: ${ENABLE_WEBSITE_WATERCRAWL:-true} @@ -180,6 +181,8 @@ services: FORCE_VERIFYING_SIGNATURE: ${FORCE_VERIFYING_SIGNATURE:-true} PYTHON_ENV_INIT_TIMEOUT: ${PLUGIN_PYTHON_ENV_INIT_TIMEOUT:-120} PLUGIN_MAX_EXECUTION_TIMEOUT: ${PLUGIN_MAX_EXECUTION_TIMEOUT:-600} + PLUGIN_STDIO_BUFFER_SIZE: ${PLUGIN_STDIO_BUFFER_SIZE:-1024} + PLUGIN_STDIO_MAX_BUFFER_SIZE: ${PLUGIN_STDIO_MAX_BUFFER_SIZE:-5242880} PIP_MIRROR_URL: ${PIP_MIRROR_URL:-} PLUGIN_STORAGE_TYPE: ${PLUGIN_STORAGE_TYPE:-local} PLUGIN_STORAGE_LOCAL_ROOT: ${PLUGIN_STORAGE_LOCAL_ROOT:-/app/storage} @@ -538,7 +541,7 @@ services: milvus-standalone: container_name: milvus-standalone - image: milvusdb/milvus:v2.5.0-beta + image: milvusdb/milvus:v2.5.15 profiles: - milvus command: [ 'milvus', 'run', 'standalone' ] diff --git a/docker/docker-compose.middleware.yaml b/docker/docker-compose.middleware.yaml index 3408fef0c2..9f7cc72586 100644 --- a/docker/docker-compose.middleware.yaml +++ b/docker/docker-compose.middleware.yaml @@ -20,7 +20,7 @@ services: ports: - "${EXPOSE_POSTGRES_PORT:-5432}:5432" healthcheck: - test: [ "CMD", "pg_isready" ] + test: [ 'CMD', 'pg_isready', '-h', 'db', '-U', '${PGUSER:-postgres}', '-d', '${POSTGRES_DB:-dify}' ] interval: 1s timeout: 3s retries: 30 diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index c2ef2ff723..d3b75d93af 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -12,6 +12,9 @@ x-shared-env: &shared-api-worker-env APP_WEB_URL: ${APP_WEB_URL:-} FILES_URL: ${FILES_URL:-} INTERNAL_FILES_URL: ${INTERNAL_FILES_URL:-} + LANG: ${LANG:-en_US.UTF-8} + LC_ALL: ${LC_ALL:-en_US.UTF-8} + PYTHONIOENCODING: ${PYTHONIOENCODING:-utf-8} LOG_LEVEL: ${LOG_LEVEL:-INFO} LOG_FILE: ${LOG_FILE:-/app/logs/server.log} LOG_FILE_MAX_SIZE: ${LOG_FILE_MAX_SIZE:-20} @@ -54,6 +57,7 @@ x-shared-env: &shared-api-worker-env DB_PORT: ${DB_PORT:-5432} DB_DATABASE: ${DB_DATABASE:-dify} SQLALCHEMY_POOL_SIZE: ${SQLALCHEMY_POOL_SIZE:-30} + SQLALCHEMY_MAX_OVERFLOW: ${SQLALCHEMY_MAX_OVERFLOW:-10} SQLALCHEMY_POOL_RECYCLE: ${SQLALCHEMY_POOL_RECYCLE:-3600} SQLALCHEMY_ECHO: ${SQLALCHEMY_ECHO:-false} SQLALCHEMY_POOL_PRE_PING: ${SQLALCHEMY_POOL_PRE_PING:-false} @@ -68,6 +72,10 @@ x-shared-env: &shared-api-worker-env REDIS_USERNAME: ${REDIS_USERNAME:-} REDIS_PASSWORD: ${REDIS_PASSWORD:-difyai123456} REDIS_USE_SSL: ${REDIS_USE_SSL:-false} + REDIS_SSL_CERT_REQS: ${REDIS_SSL_CERT_REQS:-CERT_NONE} + REDIS_SSL_CA_CERTS: ${REDIS_SSL_CA_CERTS:-} + REDIS_SSL_CERTFILE: ${REDIS_SSL_CERTFILE:-} + REDIS_SSL_KEYFILE: ${REDIS_SSL_KEYFILE:-} REDIS_DB: ${REDIS_DB:-0} REDIS_USE_SENTINEL: ${REDIS_USE_SENTINEL:-false} REDIS_SENTINELS: ${REDIS_SENTINELS:-} @@ -90,6 +98,10 @@ x-shared-env: &shared-api-worker-env STORAGE_TYPE: ${STORAGE_TYPE:-opendal} OPENDAL_SCHEME: ${OPENDAL_SCHEME:-fs} OPENDAL_FS_ROOT: ${OPENDAL_FS_ROOT:-storage} + CLICKZETTA_VOLUME_TYPE: ${CLICKZETTA_VOLUME_TYPE:-user} + CLICKZETTA_VOLUME_NAME: ${CLICKZETTA_VOLUME_NAME:-} + CLICKZETTA_VOLUME_TABLE_PREFIX: ${CLICKZETTA_VOLUME_TABLE_PREFIX:-dataset_} + CLICKZETTA_VOLUME_DIFY_PREFIX: ${CLICKZETTA_VOLUME_DIFY_PREFIX:-dify_km} S3_ENDPOINT: ${S3_ENDPOINT:-} S3_REGION: ${S3_REGION:-us-east-1} S3_BUCKET_NAME: ${S3_BUCKET_NAME:-difyai} @@ -258,6 +270,14 @@ x-shared-env: &shared-api-worker-env ELASTICSEARCH_USERNAME: ${ELASTICSEARCH_USERNAME:-elastic} ELASTICSEARCH_PASSWORD: ${ELASTICSEARCH_PASSWORD:-elastic} KIBANA_PORT: ${KIBANA_PORT:-5601} + ELASTICSEARCH_USE_CLOUD: ${ELASTICSEARCH_USE_CLOUD:-false} + ELASTICSEARCH_CLOUD_URL: ${ELASTICSEARCH_CLOUD_URL:-YOUR-ELASTICSEARCH_CLOUD_URL} + ELASTICSEARCH_API_KEY: ${ELASTICSEARCH_API_KEY:-YOUR-ELASTICSEARCH_API_KEY} + ELASTICSEARCH_VERIFY_CERTS: ${ELASTICSEARCH_VERIFY_CERTS:-False} + ELASTICSEARCH_CA_CERTS: ${ELASTICSEARCH_CA_CERTS:-} + ELASTICSEARCH_REQUEST_TIMEOUT: ${ELASTICSEARCH_REQUEST_TIMEOUT:-100000} + ELASTICSEARCH_RETRY_ON_TIMEOUT: ${ELASTICSEARCH_RETRY_ON_TIMEOUT:-True} + ELASTICSEARCH_MAX_RETRIES: ${ELASTICSEARCH_MAX_RETRIES:-10} BAIDU_VECTOR_DB_ENDPOINT: ${BAIDU_VECTOR_DB_ENDPOINT:-http://127.0.0.1:5287} BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS: ${BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS:-30000} BAIDU_VECTOR_DB_ACCOUNT: ${BAIDU_VECTOR_DB_ACCOUNT:-root} @@ -301,6 +321,19 @@ x-shared-env: &shared-api-worker-env TABLESTORE_INSTANCE_NAME: ${TABLESTORE_INSTANCE_NAME:-instance-name} TABLESTORE_ACCESS_KEY_ID: ${TABLESTORE_ACCESS_KEY_ID:-xxx} TABLESTORE_ACCESS_KEY_SECRET: ${TABLESTORE_ACCESS_KEY_SECRET:-xxx} + TABLESTORE_NORMALIZE_FULLTEXT_BM25_SCORE: ${TABLESTORE_NORMALIZE_FULLTEXT_BM25_SCORE:-false} + CLICKZETTA_USERNAME: ${CLICKZETTA_USERNAME:-} + CLICKZETTA_PASSWORD: ${CLICKZETTA_PASSWORD:-} + CLICKZETTA_INSTANCE: ${CLICKZETTA_INSTANCE:-} + CLICKZETTA_SERVICE: ${CLICKZETTA_SERVICE:-api.clickzetta.com} + CLICKZETTA_WORKSPACE: ${CLICKZETTA_WORKSPACE:-quick_start} + CLICKZETTA_VCLUSTER: ${CLICKZETTA_VCLUSTER:-default_ap} + CLICKZETTA_SCHEMA: ${CLICKZETTA_SCHEMA:-dify} + CLICKZETTA_BATCH_SIZE: ${CLICKZETTA_BATCH_SIZE:-100} + CLICKZETTA_ENABLE_INVERTED_INDEX: ${CLICKZETTA_ENABLE_INVERTED_INDEX:-true} + CLICKZETTA_ANALYZER_TYPE: ${CLICKZETTA_ANALYZER_TYPE:-chinese} + CLICKZETTA_ANALYZER_MODE: ${CLICKZETTA_ANALYZER_MODE:-smart} + CLICKZETTA_VECTOR_DISTANCE_FUNCTION: ${CLICKZETTA_VECTOR_DISTANCE_FUNCTION:-cosine_distance} UPLOAD_FILE_SIZE_LIMIT: ${UPLOAD_FILE_SIZE_LIMIT:-15} UPLOAD_FILE_BATCH_LIMIT: ${UPLOAD_FILE_BATCH_LIMIT:-5} ETL_TYPE: ${ETL_TYPE:-dify} @@ -362,8 +395,11 @@ x-shared-env: &shared-api-worker-env WORKFLOW_NODE_EXECUTION_STORAGE: ${WORKFLOW_NODE_EXECUTION_STORAGE:-rdbms} CORE_WORKFLOW_EXECUTION_REPOSITORY: ${CORE_WORKFLOW_EXECUTION_REPOSITORY:-core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository} CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY: ${CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY:-core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository} - API_WORKFLOW_NODE_EXECUTION_REPOSITORY: ${API_WORKFLOW_NODE_EXECUTION_REPOSITORY:-repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository} API_WORKFLOW_RUN_REPOSITORY: ${API_WORKFLOW_RUN_REPOSITORY:-repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository} + API_WORKFLOW_NODE_EXECUTION_REPOSITORY: ${API_WORKFLOW_NODE_EXECUTION_REPOSITORY:-repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository} + WORKFLOW_LOG_CLEANUP_ENABLED: ${WORKFLOW_LOG_CLEANUP_ENABLED:-false} + WORKFLOW_LOG_RETENTION_DAYS: ${WORKFLOW_LOG_RETENTION_DAYS:-30} + WORKFLOW_LOG_CLEANUP_BATCH_SIZE: ${WORKFLOW_LOG_CLEANUP_BATCH_SIZE:-100} HTTP_REQUEST_NODE_MAX_BINARY_SIZE: ${HTTP_REQUEST_NODE_MAX_BINARY_SIZE:-10485760} HTTP_REQUEST_NODE_MAX_TEXT_SIZE: ${HTTP_REQUEST_NODE_MAX_TEXT_SIZE:-1048576} HTTP_REQUEST_NODE_SSL_VERIFY: ${HTTP_REQUEST_NODE_SSL_VERIFY:-True} @@ -376,6 +412,7 @@ x-shared-env: &shared-api-worker-env MAX_ITERATIONS_NUM: ${MAX_ITERATIONS_NUM:-99} TEXT_GENERATION_TIMEOUT_MS: ${TEXT_GENERATION_TIMEOUT_MS:-60000} ALLOW_UNSAFE_DATA_SCHEME: ${ALLOW_UNSAFE_DATA_SCHEME:-false} + MAX_TREE_DEPTH: ${MAX_TREE_DEPTH:-50} POSTGRES_USER: ${POSTGRES_USER:-${DB_USERNAME}} POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-${DB_PASSWORD}} POSTGRES_DB: ${POSTGRES_DB:-${DB_DATABASE}} @@ -477,6 +514,8 @@ x-shared-env: &shared-api-worker-env MARKETPLACE_ENABLED: ${MARKETPLACE_ENABLED:-true} MARKETPLACE_API_URL: ${MARKETPLACE_API_URL:-https://marketplace.dify.ai} FORCE_VERIFYING_SIGNATURE: ${FORCE_VERIFYING_SIGNATURE:-true} + PLUGIN_STDIO_BUFFER_SIZE: ${PLUGIN_STDIO_BUFFER_SIZE:-1024} + PLUGIN_STDIO_MAX_BUFFER_SIZE: ${PLUGIN_STDIO_MAX_BUFFER_SIZE:-5242880} PLUGIN_PYTHON_ENV_INIT_TIMEOUT: ${PLUGIN_PYTHON_ENV_INIT_TIMEOUT:-120} PLUGIN_MAX_EXECUTION_TIMEOUT: ${PLUGIN_MAX_EXECUTION_TIMEOUT:-600} PIP_MIRROR_URL: ${PIP_MIRROR_URL:-} @@ -539,7 +578,7 @@ x-shared-env: &shared-api-worker-env services: # API service api: - image: langgenius/dify-api:1.7.0 + image: langgenius/dify-api:1.7.2 restart: always environment: # Use the shared environment variables. @@ -568,7 +607,7 @@ services: # worker service # The Celery worker for processing the queue. worker: - image: langgenius/dify-api:1.7.0 + image: langgenius/dify-api:1.7.2 restart: always environment: # Use the shared environment variables. @@ -595,7 +634,7 @@ services: # worker_beat service # Celery beat for scheduling periodic tasks. worker_beat: - image: langgenius/dify-api:1.7.0 + image: langgenius/dify-api:1.7.2 restart: always environment: # Use the shared environment variables. @@ -613,7 +652,7 @@ services: # Frontend web application. web: - image: langgenius/dify-web:1.7.0 + image: langgenius/dify-web:1.7.2 restart: always environment: CONSOLE_API_URL: ${CONSOLE_API_URL:-} @@ -633,6 +672,7 @@ services: MAX_TOOLS_NUM: ${MAX_TOOLS_NUM:-10} MAX_PARALLEL_LIMIT: ${MAX_PARALLEL_LIMIT:-10} MAX_ITERATIONS_NUM: ${MAX_ITERATIONS_NUM:-99} + MAX_TREE_DEPTH: ${MAX_TREE_DEPTH:-50} ENABLE_WEBSITE_JINAREADER: ${ENABLE_WEBSITE_JINAREADER:-true} ENABLE_WEBSITE_FIRECRAWL: ${ENABLE_WEBSITE_FIRECRAWL:-true} ENABLE_WEBSITE_WATERCRAWL: ${ENABLE_WEBSITE_WATERCRAWL:-true} @@ -717,6 +757,8 @@ services: FORCE_VERIFYING_SIGNATURE: ${FORCE_VERIFYING_SIGNATURE:-true} PYTHON_ENV_INIT_TIMEOUT: ${PLUGIN_PYTHON_ENV_INIT_TIMEOUT:-120} PLUGIN_MAX_EXECUTION_TIMEOUT: ${PLUGIN_MAX_EXECUTION_TIMEOUT:-600} + PLUGIN_STDIO_BUFFER_SIZE: ${PLUGIN_STDIO_BUFFER_SIZE:-1024} + PLUGIN_STDIO_MAX_BUFFER_SIZE: ${PLUGIN_STDIO_MAX_BUFFER_SIZE:-5242880} PIP_MIRROR_URL: ${PIP_MIRROR_URL:-} PLUGIN_STORAGE_TYPE: ${PLUGIN_STORAGE_TYPE:-local} PLUGIN_STORAGE_LOCAL_ROOT: ${PLUGIN_STORAGE_LOCAL_ROOT:-/app/storage} @@ -1075,7 +1117,7 @@ services: milvus-standalone: container_name: milvus-standalone - image: milvusdb/milvus:v2.5.0-beta + image: milvusdb/milvus:v2.5.15 profiles: - milvus command: [ 'milvus', 'run', 'standalone' ] diff --git a/images/GitHub_README_if.png b/images/GitHub_README_if.png index 10c9d87b08..281d95cf9c 100644 Binary files a/images/GitHub_README_if.png and b/images/GitHub_README_if.png differ diff --git a/sdks/nodejs-client/README.md b/sdks/nodejs-client/README.md index 37b5ca2d0a..3a5688bcbe 100644 --- a/sdks/nodejs-client/README.md +++ b/sdks/nodejs-client/README.md @@ -1,12 +1,15 @@ # Dify Node.js SDK + This is the Node.js SDK for the Dify API, which allows you to easily integrate Dify into your Node.js applications. ## Install + ```bash npm install dify-client ``` ## Usage + After installing the SDK, you can use it in your project like this: ```js @@ -60,4 +63,5 @@ client.messageFeedback(messageId, rating, user) Replace 'your-api-key-here' with your actual Dify API key.Replace 'your-app-id-here' with your actual Dify APP ID. ## License + This SDK is released under the MIT License. diff --git a/sdks/php-client/README.md b/sdks/php-client/README.md index 91e77ad9ff..444b16a565 100644 --- a/sdks/php-client/README.md +++ b/sdks/php-client/README.md @@ -11,7 +11,7 @@ This is the PHP SDK for the Dify API, which allows you to easily integrate Dify If you want to try the example, you can run `composer install` in this directory. -In exist project, copy the `dify-client.php` to you project, and merge the following to your `composer.json` file, then run `composer install && composer dump-autoload` to install. Guzzle does not require 7.9, other versions have not been tested, but you can try. +In exist project, copy the `dify-client.php` to you project, and merge the following to your `composer.json` file, then run `composer install && composer dump-autoload` to install. Guzzle does not require 7.9, other versions have not been tested, but you can try. ```json { diff --git a/sdks/python-client/README.md b/sdks/python-client/README.md index 7401fd2fd4..34b14b3a94 100644 --- a/sdks/python-client/README.md +++ b/sdks/python-client/README.md @@ -141,8 +141,6 @@ with open(file_path, "rb") as file: result = response.json() print(f'upload_file_id: {result.get("id")}') ``` - - - Others @@ -184,7 +182,8 @@ print('[rename result]') print(rename_conversation_response.json()) ``` -* Using the Workflow Client +- Using the Workflow Client + ```python import json import requests diff --git a/web/Dockerfile b/web/Dockerfile index d59039528c..1376dec749 100644 --- a/web/Dockerfile +++ b/web/Dockerfile @@ -6,7 +6,7 @@ LABEL maintainer="takatost@gmail.com" # RUN sed -i 's/dl-cdn.alpinelinux.org/mirrors.aliyun.com/g' /etc/apk/repositories RUN apk add --no-cache tzdata -RUN npm install -g pnpm@10.13.1 +RUN corepack enable ENV PNPM_HOME="/pnpm" ENV PATH="$PNPM_HOME:$PATH" @@ -19,6 +19,9 @@ WORKDIR /app/web COPY package.json . COPY pnpm-lock.yaml . +# Use packageManager from package.json +RUN corepack install + # if you located in China, you can use taobao registry to speed up # RUN pnpm install --frozen-lockfile --registry https://registry.npmmirror.com/ @@ -31,7 +34,7 @@ COPY --from=packages /app/web/ . COPY . . ENV NODE_OPTIONS="--max-old-space-size=4096" -RUN pnpm build +RUN pnpm build:docker # production stage diff --git a/web/README.md b/web/README.md index 3d9fd2de87..a47cfab041 100644 --- a/web/README.md +++ b/web/README.md @@ -7,6 +7,7 @@ This is a [Next.js](https://nextjs.org/) project bootstrapped with [`create-next ### Run by source code Before starting the web frontend service, please make sure the following environment is ready. + - [Node.js](https://nodejs.org) >= v22.11.x - [pnpm](https://pnpm.io) v10.x @@ -103,11 +104,9 @@ pnpm run test ``` If you are not familiar with writing tests, here is some code to refer to: -* [classnames.spec.ts](./utils/classnames.spec.ts) -* [index.spec.tsx](./app/components/base/button/index.spec.tsx) - - +- [classnames.spec.ts](./utils/classnames.spec.ts) +- [index.spec.tsx](./app/components/base/button/index.spec.tsx) ## Documentation diff --git a/web/__tests__/check-i18n.test.ts b/web/__tests__/check-i18n.test.ts new file mode 100644 index 0000000000..b4c4f1540d --- /dev/null +++ b/web/__tests__/check-i18n.test.ts @@ -0,0 +1,762 @@ +import fs from 'node:fs' +import path from 'node:path' + +// Mock functions to simulate the check-i18n functionality +const vm = require('node:vm') +const transpile = require('typescript').transpile + +describe('check-i18n script functionality', () => { + const testDir = path.join(__dirname, '../i18n-test') + const testEnDir = path.join(testDir, 'en-US') + const testZhDir = path.join(testDir, 'zh-Hans') + + // Helper function that replicates the getKeysFromLanguage logic + async function getKeysFromLanguage(language: string, testPath = testDir): Promise { + return new Promise((resolve, reject) => { + const folderPath = path.resolve(testPath, language) + const allKeys: string[] = [] + + if (!fs.existsSync(folderPath)) { + resolve([]) + return + } + + fs.readdir(folderPath, (err, files) => { + if (err) { + reject(err) + return + } + + const translationFiles = files.filter(file => /\.(ts|js)$/.test(file)) + + translationFiles.forEach((file) => { + const filePath = path.join(folderPath, file) + const fileName = file.replace(/\.[^/.]+$/, '') + const camelCaseFileName = fileName.replace(/[-_](.)/g, (_, c) => + c.toUpperCase(), + ) + + try { + const content = fs.readFileSync(filePath, 'utf8') + const moduleExports = {} + const context = { + exports: moduleExports, + module: { exports: moduleExports }, + require, + console, + __filename: filePath, + __dirname: folderPath, + } + + vm.runInNewContext(transpile(content), context) + const translationObj = (context.module.exports as any).default || context.module.exports + + if (!translationObj || typeof translationObj !== 'object') + throw new Error(`Error parsing file: ${filePath}`) + + const nestedKeys: string[] = [] + const iterateKeys = (obj: any, prefix = '') => { + for (const key in obj) { + const nestedKey = prefix ? `${prefix}.${key}` : key + if (typeof obj[key] === 'object' && obj[key] !== null && !Array.isArray(obj[key])) { + // This is an object (but not array), recurse into it but don't add it as a key + iterateKeys(obj[key], nestedKey) + } + else { + // This is a leaf node (string, number, boolean, array, etc.), add it as a key + nestedKeys.push(nestedKey) + } + } + } + iterateKeys(translationObj) + + const fileKeys = nestedKeys.map(key => `${camelCaseFileName}.${key}`) + allKeys.push(...fileKeys) + } + catch (error) { + reject(error) + } + }) + resolve(allKeys) + }) + }) + } + + beforeEach(() => { + // Clean up and create test directories + if (fs.existsSync(testDir)) + fs.rmSync(testDir, { recursive: true }) + + fs.mkdirSync(testDir, { recursive: true }) + fs.mkdirSync(testEnDir, { recursive: true }) + fs.mkdirSync(testZhDir, { recursive: true }) + }) + + afterEach(() => { + // Clean up test files + if (fs.existsSync(testDir)) + fs.rmSync(testDir, { recursive: true }) + }) + + describe('Key extraction logic', () => { + it('should extract only leaf node keys, not intermediate objects', async () => { + const testContent = `const translation = { + simple: 'Simple Value', + nested: { + level1: 'Level 1 Value', + deep: { + level2: 'Level 2 Value' + } + }, + array: ['not extracted'], + number: 42, + boolean: true +} + +export default translation +` + + fs.writeFileSync(path.join(testEnDir, 'test.ts'), testContent) + + const keys = await getKeysFromLanguage('en-US') + + expect(keys).toEqual([ + 'test.simple', + 'test.nested.level1', + 'test.nested.deep.level2', + 'test.array', + 'test.number', + 'test.boolean', + ]) + + // Should not include intermediate object keys + expect(keys).not.toContain('test.nested') + expect(keys).not.toContain('test.nested.deep') + }) + + it('should handle camelCase file name conversion correctly', async () => { + const testContent = `const translation = { + key: 'value' +} + +export default translation +` + + fs.writeFileSync(path.join(testEnDir, 'app-debug.ts'), testContent) + fs.writeFileSync(path.join(testEnDir, 'user_profile.ts'), testContent) + + const keys = await getKeysFromLanguage('en-US') + + expect(keys).toContain('appDebug.key') + expect(keys).toContain('userProfile.key') + }) + }) + + describe('Missing keys detection', () => { + it('should detect missing keys in target language', async () => { + const enContent = `const translation = { + common: { + save: 'Save', + cancel: 'Cancel', + delete: 'Delete' + }, + app: { + title: 'My App', + version: '1.0' + } +} + +export default translation +` + + const zhContent = `const translation = { + common: { + save: '保存', + cancel: '取消' + // missing 'delete' + }, + app: { + title: '我的应用' + // missing 'version' + } +} + +export default translation +` + + fs.writeFileSync(path.join(testEnDir, 'test.ts'), enContent) + fs.writeFileSync(path.join(testZhDir, 'test.ts'), zhContent) + + const enKeys = await getKeysFromLanguage('en-US') + const zhKeys = await getKeysFromLanguage('zh-Hans') + + const missingKeys = enKeys.filter(key => !zhKeys.includes(key)) + + expect(missingKeys).toContain('test.common.delete') + expect(missingKeys).toContain('test.app.version') + expect(missingKeys).toHaveLength(2) + }) + }) + + describe('Extra keys detection', () => { + it('should detect extra keys in target language', async () => { + const enContent = `const translation = { + common: { + save: 'Save', + cancel: 'Cancel' + } +} + +export default translation +` + + const zhContent = `const translation = { + common: { + save: '保存', + cancel: '取消', + delete: '删除', // extra key + extra: '额外的' // another extra key + }, + newSection: { + someKey: '某个值' // extra section + } +} + +export default translation +` + + fs.writeFileSync(path.join(testEnDir, 'test.ts'), enContent) + fs.writeFileSync(path.join(testZhDir, 'test.ts'), zhContent) + + const enKeys = await getKeysFromLanguage('en-US') + const zhKeys = await getKeysFromLanguage('zh-Hans') + + const extraKeys = zhKeys.filter(key => !enKeys.includes(key)) + + expect(extraKeys).toContain('test.common.delete') + expect(extraKeys).toContain('test.common.extra') + expect(extraKeys).toContain('test.newSection.someKey') + expect(extraKeys).toHaveLength(3) + }) + }) + + describe('File filtering logic', () => { + it('should filter keys by specific file correctly', async () => { + // Create multiple files + const file1Content = `const translation = { + button: 'Button', + text: 'Text' +} + +export default translation +` + + const file2Content = `const translation = { + title: 'Title', + description: 'Description' +} + +export default translation +` + + fs.writeFileSync(path.join(testEnDir, 'components.ts'), file1Content) + fs.writeFileSync(path.join(testEnDir, 'pages.ts'), file2Content) + fs.writeFileSync(path.join(testZhDir, 'components.ts'), file1Content) + fs.writeFileSync(path.join(testZhDir, 'pages.ts'), file2Content) + + const allEnKeys = await getKeysFromLanguage('en-US') + + // Test file filtering logic + const targetFile = 'components' + const filteredEnKeys = allEnKeys.filter(key => + key.startsWith(targetFile.replace(/[-_](.)/g, (_, c) => c.toUpperCase())), + ) + + expect(allEnKeys).toHaveLength(4) // 2 keys from each file + expect(filteredEnKeys).toHaveLength(2) // only components keys + expect(filteredEnKeys).toContain('components.button') + expect(filteredEnKeys).toContain('components.text') + expect(filteredEnKeys).not.toContain('pages.title') + expect(filteredEnKeys).not.toContain('pages.description') + }) + }) + + describe('Complex nested structure handling', () => { + it('should handle deeply nested objects correctly', async () => { + const complexContent = `const translation = { + level1: { + level2: { + level3: { + level4: { + deepValue: 'Deep Value' + }, + anotherValue: 'Another Value' + }, + simpleValue: 'Simple Value' + }, + directValue: 'Direct Value' + }, + rootValue: 'Root Value' +} + +export default translation +` + + fs.writeFileSync(path.join(testEnDir, 'complex.ts'), complexContent) + + const keys = await getKeysFromLanguage('en-US') + + expect(keys).toContain('complex.level1.level2.level3.level4.deepValue') + expect(keys).toContain('complex.level1.level2.level3.anotherValue') + expect(keys).toContain('complex.level1.level2.simpleValue') + expect(keys).toContain('complex.level1.directValue') + expect(keys).toContain('complex.rootValue') + + // Should not include intermediate objects + expect(keys).not.toContain('complex.level1') + expect(keys).not.toContain('complex.level1.level2') + expect(keys).not.toContain('complex.level1.level2.level3') + expect(keys).not.toContain('complex.level1.level2.level3.level4') + }) + }) + + describe('Edge cases', () => { + it('should handle empty objects', async () => { + const emptyContent = `const translation = { + empty: {}, + withValue: 'value' +} + +export default translation +` + + fs.writeFileSync(path.join(testEnDir, 'empty.ts'), emptyContent) + + const keys = await getKeysFromLanguage('en-US') + + expect(keys).toContain('empty.withValue') + expect(keys).not.toContain('empty.empty') + }) + + it('should handle special characters in keys', async () => { + const specialContent = `const translation = { + 'key-with-dash': 'value1', + 'key_with_underscore': 'value2', + 'key.with.dots': 'value3', + normalKey: 'value4' +} + +export default translation +` + + fs.writeFileSync(path.join(testEnDir, 'special.ts'), specialContent) + + const keys = await getKeysFromLanguage('en-US') + + expect(keys).toContain('special.key-with-dash') + expect(keys).toContain('special.key_with_underscore') + expect(keys).toContain('special.key.with.dots') + expect(keys).toContain('special.normalKey') + }) + + it('should handle different value types', async () => { + const typesContent = `const translation = { + stringValue: 'string', + numberValue: 42, + booleanValue: true, + nullValue: null, + undefinedValue: undefined, + arrayValue: ['array', 'values'], + objectValue: { + nested: 'nested value' + } +} + +export default translation +` + + fs.writeFileSync(path.join(testEnDir, 'types.ts'), typesContent) + + const keys = await getKeysFromLanguage('en-US') + + expect(keys).toContain('types.stringValue') + expect(keys).toContain('types.numberValue') + expect(keys).toContain('types.booleanValue') + expect(keys).toContain('types.nullValue') + expect(keys).toContain('types.undefinedValue') + expect(keys).toContain('types.arrayValue') + expect(keys).toContain('types.objectValue.nested') + expect(keys).not.toContain('types.objectValue') + }) + }) + + describe('Real-world scenario tests', () => { + it('should handle app-debug structure like real files', async () => { + const appDebugEn = `const translation = { + pageTitle: { + line1: 'Prompt', + line2: 'Engineering' + }, + operation: { + applyConfig: 'Publish', + resetConfig: 'Reset', + debugConfig: 'Debug' + }, + generate: { + instruction: 'Instructions', + generate: 'Generate', + resTitle: 'Generated Prompt', + noDataLine1: 'Describe your use case on the left,', + noDataLine2: 'the orchestration preview will show here.' + } +} + +export default translation +` + + const appDebugZh = `const translation = { + pageTitle: { + line1: '提示词', + line2: '编排' + }, + operation: { + applyConfig: '发布', + resetConfig: '重置', + debugConfig: '调试' + }, + generate: { + instruction: '指令', + generate: '生成', + resTitle: '生成的提示词', + noData: '在左侧描述您的用例,编排预览将在此处显示。' // This is extra + } +} + +export default translation +` + + fs.writeFileSync(path.join(testEnDir, 'app-debug.ts'), appDebugEn) + fs.writeFileSync(path.join(testZhDir, 'app-debug.ts'), appDebugZh) + + const enKeys = await getKeysFromLanguage('en-US') + const zhKeys = await getKeysFromLanguage('zh-Hans') + + const missingKeys = enKeys.filter(key => !zhKeys.includes(key)) + const extraKeys = zhKeys.filter(key => !enKeys.includes(key)) + + expect(missingKeys).toContain('appDebug.generate.noDataLine1') + expect(missingKeys).toContain('appDebug.generate.noDataLine2') + expect(extraKeys).toContain('appDebug.generate.noData') + + expect(missingKeys).toHaveLength(2) + expect(extraKeys).toHaveLength(1) + }) + + it('should handle time structure with operation nested keys', async () => { + const timeEn = `const translation = { + months: { + January: 'January', + February: 'February' + }, + operation: { + now: 'Now', + ok: 'OK', + cancel: 'Cancel', + pickDate: 'Pick Date' + }, + title: { + pickTime: 'Pick Time' + }, + defaultPlaceholder: 'Pick a time...' +} + +export default translation +` + + const timeZh = `const translation = { + months: { + January: '一月', + February: '二月' + }, + operation: { + now: '此刻', + ok: '确定', + cancel: '取消', + pickDate: '选择日期' + }, + title: { + pickTime: '选择时间' + }, + pickDate: '选择日期', // This is extra - duplicates operation.pickDate + defaultPlaceholder: '请选择时间...' +} + +export default translation +` + + fs.writeFileSync(path.join(testEnDir, 'time.ts'), timeEn) + fs.writeFileSync(path.join(testZhDir, 'time.ts'), timeZh) + + const enKeys = await getKeysFromLanguage('en-US') + const zhKeys = await getKeysFromLanguage('zh-Hans') + + const missingKeys = enKeys.filter(key => !zhKeys.includes(key)) + const extraKeys = zhKeys.filter(key => !enKeys.includes(key)) + + expect(missingKeys).toHaveLength(0) // No missing keys + expect(extraKeys).toContain('time.pickDate') // Extra root-level pickDate + expect(extraKeys).toHaveLength(1) + + // Should have both keys available + expect(zhKeys).toContain('time.operation.pickDate') // Correct nested key + expect(zhKeys).toContain('time.pickDate') // Extra duplicate key + }) + }) + + describe('Statistics calculation', () => { + it('should calculate correct difference statistics', async () => { + const enContent = `const translation = { + key1: 'value1', + key2: 'value2', + key3: 'value3' +} + +export default translation +` + + const zhContentMissing = `const translation = { + key1: 'value1', + key2: 'value2' + // missing key3 +} + +export default translation +` + + const zhContentExtra = `const translation = { + key1: 'value1', + key2: 'value2', + key3: 'value3', + key4: 'extra', + key5: 'extra2' +} + +export default translation +` + + fs.writeFileSync(path.join(testEnDir, 'stats.ts'), enContent) + + // Test missing keys scenario + fs.writeFileSync(path.join(testZhDir, 'stats.ts'), zhContentMissing) + + const enKeys = await getKeysFromLanguage('en-US') + const zhKeysMissing = await getKeysFromLanguage('zh-Hans') + + expect(enKeys.length - zhKeysMissing.length).toBe(1) // +1 means 1 missing key + + // Test extra keys scenario + fs.writeFileSync(path.join(testZhDir, 'stats.ts'), zhContentExtra) + + const zhKeysExtra = await getKeysFromLanguage('zh-Hans') + + expect(enKeys.length - zhKeysExtra.length).toBe(-2) // -2 means 2 extra keys + }) + }) + + describe('Auto-remove multiline key-value pairs', () => { + // Helper function to simulate removeExtraKeysFromFile logic + function removeExtraKeysFromFile(content: string, keysToRemove: string[]): string { + const lines = content.split('\n') + const linesToRemove: number[] = [] + + for (const keyToRemove of keysToRemove) { + let targetLineIndex = -1 + const linesToRemoveForKey: number[] = [] + + // Find the key line (simplified for single-level keys in test) + for (let i = 0; i < lines.length; i++) { + const line = lines[i] + const keyPattern = new RegExp(`^\\s*${keyToRemove}\\s*:`) + if (keyPattern.test(line)) { + targetLineIndex = i + break + } + } + + if (targetLineIndex !== -1) { + linesToRemoveForKey.push(targetLineIndex) + + // Check if this is a multiline key-value pair + const keyLine = lines[targetLineIndex] + const trimmedKeyLine = keyLine.trim() + + // If key line ends with ":" (not complete value), it's likely multiline + if (trimmedKeyLine.endsWith(':') && !trimmedKeyLine.includes('{') && !trimmedKeyLine.match(/:\s*['"`]/)) { + // Find the value lines that belong to this key + let currentLine = targetLineIndex + 1 + let foundValue = false + + while (currentLine < lines.length) { + const line = lines[currentLine] + const trimmed = line.trim() + + // Skip empty lines + if (trimmed === '') { + currentLine++ + continue + } + + // Check if this line starts a new key (indicates end of current value) + if (trimmed.match(/^\w+\s*:/)) + break + + // Check if this line is part of the value + if (trimmed.startsWith('\'') || trimmed.startsWith('"') || trimmed.startsWith('`') || foundValue) { + linesToRemoveForKey.push(currentLine) + foundValue = true + + // Check if this line ends the value (ends with quote and comma/no comma) + if ((trimmed.endsWith('\',') || trimmed.endsWith('",') || trimmed.endsWith('`,') + || trimmed.endsWith('\'') || trimmed.endsWith('"') || trimmed.endsWith('`')) + && !trimmed.startsWith('//')) + break + } + else { + break + } + + currentLine++ + } + } + + linesToRemove.push(...linesToRemoveForKey) + } + } + + // Remove duplicates and sort in reverse order + const uniqueLinesToRemove = [...new Set(linesToRemove)].sort((a, b) => b - a) + + for (const lineIndex of uniqueLinesToRemove) + lines.splice(lineIndex, 1) + + return lines.join('\n') + } + + it('should remove single-line key-value pairs correctly', () => { + const content = `const translation = { + keepThis: 'This should stay', + removeThis: 'This should be removed', + alsoKeep: 'This should also stay', +} + +export default translation` + + const result = removeExtraKeysFromFile(content, ['removeThis']) + + expect(result).toContain('keepThis: \'This should stay\'') + expect(result).toContain('alsoKeep: \'This should also stay\'') + expect(result).not.toContain('removeThis: \'This should be removed\'') + }) + + it('should remove multiline key-value pairs completely', () => { + const content = `const translation = { + keepThis: 'This should stay', + removeMultiline: + 'This is a multiline value that should be removed completely', + alsoKeep: 'This should also stay', +} + +export default translation` + + const result = removeExtraKeysFromFile(content, ['removeMultiline']) + + expect(result).toContain('keepThis: \'This should stay\'') + expect(result).toContain('alsoKeep: \'This should also stay\'') + expect(result).not.toContain('removeMultiline:') + expect(result).not.toContain('This is a multiline value that should be removed completely') + }) + + it('should handle mixed single-line and multiline removals', () => { + const content = `const translation = { + keepThis: 'Keep this', + removeSingle: 'Remove this single line', + removeMultiline: + 'Remove this multiline value', + anotherMultiline: + 'Another multiline that spans multiple lines', + keepAnother: 'Keep this too', +} + +export default translation` + + const result = removeExtraKeysFromFile(content, ['removeSingle', 'removeMultiline', 'anotherMultiline']) + + expect(result).toContain('keepThis: \'Keep this\'') + expect(result).toContain('keepAnother: \'Keep this too\'') + expect(result).not.toContain('removeSingle:') + expect(result).not.toContain('removeMultiline:') + expect(result).not.toContain('anotherMultiline:') + expect(result).not.toContain('Remove this single line') + expect(result).not.toContain('Remove this multiline value') + expect(result).not.toContain('Another multiline that spans multiple lines') + }) + + it('should properly detect multiline vs single-line patterns', () => { + const multilineContent = `const translation = { + singleLine: 'This is single line', + multilineKey: + 'This is multiline', + keyWithColon: 'Value with: colon inside', + objectKey: { + nested: 'value' + }, +} + +export default translation` + + // Test that single line with colon in value is not treated as multiline + const result1 = removeExtraKeysFromFile(multilineContent, ['keyWithColon']) + expect(result1).not.toContain('keyWithColon:') + expect(result1).not.toContain('Value with: colon inside') + + // Test that true multiline is handled correctly + const result2 = removeExtraKeysFromFile(multilineContent, ['multilineKey']) + expect(result2).not.toContain('multilineKey:') + expect(result2).not.toContain('This is multiline') + + // Test that object key removal works (note: this is a simplified test) + // In real scenario, object removal would be more complex + const result3 = removeExtraKeysFromFile(multilineContent, ['objectKey']) + expect(result3).not.toContain('objectKey: {') + // Note: Our simplified test function doesn't handle nested object removal perfectly + // This is acceptable as it's testing the main multiline string removal functionality + }) + + it('should handle real-world Polish translation structure', () => { + const polishContent = `const translation = { + createApp: 'UTWÓRZ APLIKACJĘ', + newApp: { + captionAppType: 'Jaki typ aplikacji chcesz stworzyć?', + chatbotDescription: + 'Zbuduj aplikację opartą na czacie. Ta aplikacja używa formatu pytań i odpowiedzi.', + agentDescription: + 'Zbuduj inteligentnego agenta, który może autonomicznie wybierać narzędzia.', + basic: 'Podstawowy', + }, +} + +export default translation` + + const result = removeExtraKeysFromFile(polishContent, ['captionAppType', 'chatbotDescription', 'agentDescription']) + + expect(result).toContain('createApp: \'UTWÓRZ APLIKACJĘ\'') + expect(result).toContain('basic: \'Podstawowy\'') + expect(result).not.toContain('captionAppType:') + expect(result).not.toContain('chatbotDescription:') + expect(result).not.toContain('agentDescription:') + expect(result).not.toContain('Jaki typ aplikacji') + expect(result).not.toContain('Zbuduj aplikację opartą na czacie') + expect(result).not.toContain('Zbuduj inteligentnego agenta') + }) + }) +}) diff --git a/web/__tests__/description-validation.test.tsx b/web/__tests__/description-validation.test.tsx new file mode 100644 index 0000000000..85263b035f --- /dev/null +++ b/web/__tests__/description-validation.test.tsx @@ -0,0 +1,97 @@ +/** + * Description Validation Test + * + * Tests for the 400-character description validation across App and Dataset + * creation and editing workflows to ensure consistent validation behavior. + */ + +describe('Description Validation Logic', () => { + // Simulate backend validation function + const validateDescriptionLength = (description?: string | null) => { + if (description && description.length > 400) + throw new Error('Description cannot exceed 400 characters.') + + return description + } + + describe('Backend Validation Function', () => { + test('allows description within 400 characters', () => { + const validDescription = 'x'.repeat(400) + expect(() => validateDescriptionLength(validDescription)).not.toThrow() + expect(validateDescriptionLength(validDescription)).toBe(validDescription) + }) + + test('allows empty description', () => { + expect(() => validateDescriptionLength('')).not.toThrow() + expect(() => validateDescriptionLength(null)).not.toThrow() + expect(() => validateDescriptionLength(undefined)).not.toThrow() + }) + + test('rejects description exceeding 400 characters', () => { + const invalidDescription = 'x'.repeat(401) + expect(() => validateDescriptionLength(invalidDescription)).toThrow( + 'Description cannot exceed 400 characters.', + ) + }) + }) + + describe('Backend Validation Consistency', () => { + test('App and Dataset have consistent validation limits', () => { + const maxLength = 400 + const validDescription = 'x'.repeat(maxLength) + const invalidDescription = 'x'.repeat(maxLength + 1) + + // Both should accept exactly 400 characters + expect(validDescription.length).toBe(400) + expect(() => validateDescriptionLength(validDescription)).not.toThrow() + + // Both should reject 401 characters + expect(invalidDescription.length).toBe(401) + expect(() => validateDescriptionLength(invalidDescription)).toThrow() + }) + + test('validation error messages are consistent', () => { + const expectedErrorMessage = 'Description cannot exceed 400 characters.' + + // This would be the error message from both App and Dataset backend validation + expect(expectedErrorMessage).toBe('Description cannot exceed 400 characters.') + + const invalidDescription = 'x'.repeat(401) + try { + validateDescriptionLength(invalidDescription) + } + catch (error) { + expect((error as Error).message).toBe(expectedErrorMessage) + } + }) + }) + + describe('Character Length Edge Cases', () => { + const testCases = [ + { length: 0, shouldPass: true, description: 'empty description' }, + { length: 1, shouldPass: true, description: '1 character' }, + { length: 399, shouldPass: true, description: '399 characters' }, + { length: 400, shouldPass: true, description: '400 characters (boundary)' }, + { length: 401, shouldPass: false, description: '401 characters (over limit)' }, + { length: 500, shouldPass: false, description: '500 characters' }, + { length: 1000, shouldPass: false, description: '1000 characters' }, + ] + + testCases.forEach(({ length, shouldPass, description }) => { + test(`handles ${description} correctly`, () => { + const testDescription = length > 0 ? 'x'.repeat(length) : '' + expect(testDescription.length).toBe(length) + + if (shouldPass) { + expect(() => validateDescriptionLength(testDescription)).not.toThrow() + expect(validateDescriptionLength(testDescription)).toBe(testDescription) + } + else { + expect(() => validateDescriptionLength(testDescription)).toThrow( + 'Description cannot exceed 400 characters.', + ) + } + }) + }) + }) +}) diff --git a/web/__tests__/document-detail-navigation-fix.test.tsx b/web/__tests__/document-detail-navigation-fix.test.tsx new file mode 100644 index 0000000000..200ed09ea9 --- /dev/null +++ b/web/__tests__/document-detail-navigation-fix.test.tsx @@ -0,0 +1,305 @@ +/** + * Document Detail Navigation Fix Verification Test + * + * This test specifically validates that the backToPrev function in the document detail + * component correctly preserves pagination and filter states. + */ + +import { fireEvent, render, screen } from '@testing-library/react' +import { useRouter } from 'next/navigation' +import { useDocumentDetail, useDocumentMetadata } from '@/service/knowledge/use-document' + +// Mock Next.js router +const mockPush = jest.fn() +jest.mock('next/navigation', () => ({ + useRouter: jest.fn(() => ({ + push: mockPush, + })), +})) + +// Mock the document service hooks +jest.mock('@/service/knowledge/use-document', () => ({ + useDocumentDetail: jest.fn(), + useDocumentMetadata: jest.fn(), + useInvalidDocumentList: jest.fn(() => jest.fn()), +})) + +// Mock other dependencies +jest.mock('@/context/dataset-detail', () => ({ + useDatasetDetailContext: jest.fn(() => [null]), +})) + +jest.mock('@/service/use-base', () => ({ + useInvalid: jest.fn(() => jest.fn()), +})) + +jest.mock('@/service/knowledge/use-segment', () => ({ + useSegmentListKey: jest.fn(), + useChildSegmentListKey: jest.fn(), +})) + +// Create a minimal version of the DocumentDetail component that includes our fix +const DocumentDetailWithFix = ({ datasetId, documentId }: { datasetId: string; documentId: string }) => { + const router = useRouter() + + // This is the FIXED implementation from detail/index.tsx + const backToPrev = () => { + // Preserve pagination and filter states when navigating back + const searchParams = new URLSearchParams(window.location.search) + const queryString = searchParams.toString() + const separator = queryString ? '?' : '' + const backPath = `/datasets/${datasetId}/documents${separator}${queryString}` + router.push(backPath) + } + + return ( +
+ +
+ Dataset: {datasetId}, Document: {documentId} +
+
+ ) +} + +describe('Document Detail Navigation Fix Verification', () => { + beforeEach(() => { + jest.clearAllMocks() + + // Mock successful API responses + ;(useDocumentDetail as jest.Mock).mockReturnValue({ + data: { + id: 'doc-123', + name: 'Test Document', + display_status: 'available', + enabled: true, + archived: false, + }, + error: null, + }) + + ;(useDocumentMetadata as jest.Mock).mockReturnValue({ + data: null, + error: null, + }) + }) + + describe('Query Parameter Preservation', () => { + test('preserves pagination state (page 3, limit 25)', () => { + // Simulate user coming from page 3 with 25 items per page + Object.defineProperty(window, 'location', { + value: { + search: '?page=3&limit=25', + }, + writable: true, + }) + + render() + + // User clicks back button + fireEvent.click(screen.getByTestId('back-button-fixed')) + + // Should preserve the pagination state + expect(mockPush).toHaveBeenCalledWith('/datasets/dataset-123/documents?page=3&limit=25') + + console.log('✅ Pagination state preserved: page=3&limit=25') + }) + + test('preserves search keyword and filters', () => { + // Simulate user with search and filters applied + Object.defineProperty(window, 'location', { + value: { + search: '?page=2&limit=10&keyword=API%20documentation&status=active', + }, + writable: true, + }) + + render() + + fireEvent.click(screen.getByTestId('back-button-fixed')) + + // Should preserve all query parameters + expect(mockPush).toHaveBeenCalledWith('/datasets/dataset-123/documents?page=2&limit=10&keyword=API+documentation&status=active') + + console.log('✅ Search and filters preserved') + }) + + test('handles complex query parameters with special characters', () => { + // Test with complex query string including encoded characters + Object.defineProperty(window, 'location', { + value: { + search: '?page=1&limit=50&keyword=test%20%26%20debug&sort=name&order=desc&filter=%7B%22type%22%3A%22pdf%22%7D', + }, + writable: true, + }) + + render() + + fireEvent.click(screen.getByTestId('back-button-fixed')) + + // URLSearchParams will normalize the encoding, but preserve all parameters + const expectedCall = mockPush.mock.calls[0][0] + expect(expectedCall).toMatch(/^\/datasets\/dataset-123\/documents\?/) + expect(expectedCall).toMatch(/page=1/) + expect(expectedCall).toMatch(/limit=50/) + expect(expectedCall).toMatch(/keyword=test/) + expect(expectedCall).toMatch(/sort=name/) + expect(expectedCall).toMatch(/order=desc/) + + console.log('✅ Complex query parameters handled:', expectedCall) + }) + + test('handles empty query parameters gracefully', () => { + // No query parameters in URL + Object.defineProperty(window, 'location', { + value: { + search: '', + }, + writable: true, + }) + + render() + + fireEvent.click(screen.getByTestId('back-button-fixed')) + + // Should navigate to clean documents URL + expect(mockPush).toHaveBeenCalledWith('/datasets/dataset-123/documents') + + console.log('✅ Empty parameters handled gracefully') + }) + }) + + describe('Different Dataset IDs', () => { + test('works with different dataset identifiers', () => { + Object.defineProperty(window, 'location', { + value: { + search: '?page=5&limit=10', + }, + writable: true, + }) + + // Test with different dataset ID format + render() + + fireEvent.click(screen.getByTestId('back-button-fixed')) + + expect(mockPush).toHaveBeenCalledWith('/datasets/ds-prod-2024-001/documents?page=5&limit=10') + + console.log('✅ Works with different dataset ID formats') + }) + }) + + describe('Real User Scenarios', () => { + test('scenario: user searches, goes to page 3, views document, clicks back', () => { + // User searched for "API" and navigated to page 3 + Object.defineProperty(window, 'location', { + value: { + search: '?keyword=API&page=3&limit=10', + }, + writable: true, + }) + + render() + + // User decides to go back to continue browsing + fireEvent.click(screen.getByTestId('back-button-fixed')) + + // Should return to page 3 of API search results + expect(mockPush).toHaveBeenCalledWith('/datasets/main-dataset/documents?keyword=API&page=3&limit=10') + + console.log('✅ Real user scenario: search + pagination preserved') + }) + + test('scenario: user applies multiple filters, goes to document, returns', () => { + // User has applied multiple filters and is on page 2 + Object.defineProperty(window, 'location', { + value: { + search: '?page=2&limit=25&status=active&type=pdf&sort=created_at&order=desc', + }, + writable: true, + }) + + render() + + fireEvent.click(screen.getByTestId('back-button-fixed')) + + // All filters should be preserved + expect(mockPush).toHaveBeenCalledWith('/datasets/filtered-dataset/documents?page=2&limit=25&status=active&type=pdf&sort=created_at&order=desc') + + console.log('✅ Complex filtering scenario preserved') + }) + }) + + describe('Error Handling and Edge Cases', () => { + test('handles malformed query parameters gracefully', () => { + // Test with potentially problematic query string + Object.defineProperty(window, 'location', { + value: { + search: '?page=invalid&limit=&keyword=test&=emptykey&malformed', + }, + writable: true, + }) + + render() + + // Should not throw errors + expect(() => { + fireEvent.click(screen.getByTestId('back-button-fixed')) + }).not.toThrow() + + // Should still attempt navigation (URLSearchParams will clean up the parameters) + expect(mockPush).toHaveBeenCalled() + const navigationPath = mockPush.mock.calls[0][0] + expect(navigationPath).toMatch(/^\/datasets\/dataset-123\/documents/) + + console.log('✅ Malformed parameters handled gracefully:', navigationPath) + }) + + test('handles very long query strings', () => { + // Test with a very long query string + const longKeyword = 'a'.repeat(1000) + Object.defineProperty(window, 'location', { + value: { + search: `?page=1&keyword=${longKeyword}`, + }, + writable: true, + }) + + render() + + expect(() => { + fireEvent.click(screen.getByTestId('back-button-fixed')) + }).not.toThrow() + + expect(mockPush).toHaveBeenCalled() + + console.log('✅ Long query strings handled') + }) + }) + + describe('Performance Verification', () => { + test('navigation function executes quickly', () => { + Object.defineProperty(window, 'location', { + value: { + search: '?page=1&limit=10&keyword=test', + }, + writable: true, + }) + + render() + + const startTime = performance.now() + fireEvent.click(screen.getByTestId('back-button-fixed')) + const endTime = performance.now() + + const executionTime = endTime - startTime + + // Should execute in less than 10ms + expect(executionTime).toBeLessThan(10) + + console.log(`⚡ Navigation execution time: ${executionTime.toFixed(2)}ms`) + }) + }) +}) diff --git a/web/__tests__/document-list-sorting.test.tsx b/web/__tests__/document-list-sorting.test.tsx new file mode 100644 index 0000000000..1510dbec23 --- /dev/null +++ b/web/__tests__/document-list-sorting.test.tsx @@ -0,0 +1,83 @@ +/** + * Document List Sorting Tests + */ + +describe('Document List Sorting', () => { + const mockDocuments = [ + { id: '1', name: 'Beta.pdf', word_count: 500, hit_count: 10, created_at: 1699123456 }, + { id: '2', name: 'Alpha.txt', word_count: 200, hit_count: 25, created_at: 1699123400 }, + { id: '3', name: 'Gamma.docx', word_count: 800, hit_count: 5, created_at: 1699123500 }, + ] + + const sortDocuments = (docs: any[], field: string, order: 'asc' | 'desc') => { + return [...docs].sort((a, b) => { + let aValue: any + let bValue: any + + switch (field) { + case 'name': + aValue = a.name?.toLowerCase() || '' + bValue = b.name?.toLowerCase() || '' + break + case 'word_count': + aValue = a.word_count || 0 + bValue = b.word_count || 0 + break + case 'hit_count': + aValue = a.hit_count || 0 + bValue = b.hit_count || 0 + break + case 'created_at': + aValue = a.created_at + bValue = b.created_at + break + default: + return 0 + } + + if (field === 'name') { + const result = aValue.localeCompare(bValue) + return order === 'asc' ? result : -result + } + else { + const result = aValue - bValue + return order === 'asc' ? result : -result + } + }) + } + + test('sorts by name descending (default for UI consistency)', () => { + const sorted = sortDocuments(mockDocuments, 'name', 'desc') + expect(sorted.map(doc => doc.name)).toEqual(['Gamma.docx', 'Beta.pdf', 'Alpha.txt']) + }) + + test('sorts by name ascending (after toggle)', () => { + const sorted = sortDocuments(mockDocuments, 'name', 'asc') + expect(sorted.map(doc => doc.name)).toEqual(['Alpha.txt', 'Beta.pdf', 'Gamma.docx']) + }) + + test('sorts by word_count descending', () => { + const sorted = sortDocuments(mockDocuments, 'word_count', 'desc') + expect(sorted.map(doc => doc.word_count)).toEqual([800, 500, 200]) + }) + + test('sorts by hit_count descending', () => { + const sorted = sortDocuments(mockDocuments, 'hit_count', 'desc') + expect(sorted.map(doc => doc.hit_count)).toEqual([25, 10, 5]) + }) + + test('sorts by created_at descending (newest first)', () => { + const sorted = sortDocuments(mockDocuments, 'created_at', 'desc') + expect(sorted.map(doc => doc.created_at)).toEqual([1699123500, 1699123456, 1699123400]) + }) + + test('handles empty values correctly', () => { + const docsWithEmpty = [ + { id: '1', name: 'Test', word_count: 100, hit_count: 5, created_at: 1699123456 }, + { id: '2', name: 'Empty', word_count: 0, hit_count: 0, created_at: 1699123400 }, + ] + + const sorted = sortDocuments(docsWithEmpty, 'word_count', 'desc') + expect(sorted.map(doc => doc.word_count)).toEqual([100, 0]) + }) +}) diff --git a/web/__tests__/goto-anything/command-selector.test.tsx b/web/__tests__/goto-anything/command-selector.test.tsx new file mode 100644 index 0000000000..1db4be31fb --- /dev/null +++ b/web/__tests__/goto-anything/command-selector.test.tsx @@ -0,0 +1,333 @@ +import React from 'react' +import { fireEvent, render, screen } from '@testing-library/react' +import '@testing-library/jest-dom' +import CommandSelector from '../../app/components/goto-anything/command-selector' +import type { ActionItem } from '../../app/components/goto-anything/actions/types' + +jest.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => key, + }), +})) + +jest.mock('cmdk', () => ({ + Command: { + Group: ({ children, className }: any) =>
{children}
, + Item: ({ children, onSelect, value, className }: any) => ( +
onSelect && onSelect()} + data-value={value} + data-testid={`command-item-${value}`} + > + {children} +
+ ), + }, +})) + +describe('CommandSelector', () => { + const mockActions: Record = { + app: { + key: '@app', + shortcut: '@app', + title: 'Search Applications', + description: 'Search apps', + search: jest.fn(), + }, + knowledge: { + key: '@knowledge', + shortcut: '@kb', + title: 'Search Knowledge', + description: 'Search knowledge bases', + search: jest.fn(), + }, + plugin: { + key: '@plugin', + shortcut: '@plugin', + title: 'Search Plugins', + description: 'Search plugins', + search: jest.fn(), + }, + node: { + key: '@node', + shortcut: '@node', + title: 'Search Nodes', + description: 'Search workflow nodes', + search: jest.fn(), + }, + } + + const mockOnCommandSelect = jest.fn() + const mockOnCommandValueChange = jest.fn() + + beforeEach(() => { + jest.clearAllMocks() + }) + + describe('Basic Rendering', () => { + it('should render all actions when no filter is provided', () => { + render( + , + ) + + expect(screen.getByTestId('command-item-@app')).toBeInTheDocument() + expect(screen.getByTestId('command-item-@kb')).toBeInTheDocument() + expect(screen.getByTestId('command-item-@plugin')).toBeInTheDocument() + expect(screen.getByTestId('command-item-@node')).toBeInTheDocument() + }) + + it('should render empty filter as showing all actions', () => { + render( + , + ) + + expect(screen.getByTestId('command-item-@app')).toBeInTheDocument() + expect(screen.getByTestId('command-item-@kb')).toBeInTheDocument() + expect(screen.getByTestId('command-item-@plugin')).toBeInTheDocument() + expect(screen.getByTestId('command-item-@node')).toBeInTheDocument() + }) + }) + + describe('Filtering Functionality', () => { + it('should filter actions based on searchFilter - single match', () => { + render( + , + ) + + expect(screen.queryByTestId('command-item-@app')).not.toBeInTheDocument() + expect(screen.getByTestId('command-item-@kb')).toBeInTheDocument() + expect(screen.queryByTestId('command-item-@plugin')).not.toBeInTheDocument() + expect(screen.queryByTestId('command-item-@node')).not.toBeInTheDocument() + }) + + it('should filter actions with multiple matches', () => { + render( + , + ) + + expect(screen.getByTestId('command-item-@app')).toBeInTheDocument() + expect(screen.queryByTestId('command-item-@kb')).not.toBeInTheDocument() + expect(screen.getByTestId('command-item-@plugin')).toBeInTheDocument() + expect(screen.queryByTestId('command-item-@node')).not.toBeInTheDocument() + }) + + it('should be case-insensitive when filtering', () => { + render( + , + ) + + expect(screen.getByTestId('command-item-@app')).toBeInTheDocument() + expect(screen.queryByTestId('command-item-@kb')).not.toBeInTheDocument() + }) + + it('should match partial strings', () => { + render( + , + ) + + expect(screen.queryByTestId('command-item-@app')).not.toBeInTheDocument() + expect(screen.queryByTestId('command-item-@kb')).not.toBeInTheDocument() + expect(screen.queryByTestId('command-item-@plugin')).not.toBeInTheDocument() + expect(screen.getByTestId('command-item-@node')).toBeInTheDocument() + }) + }) + + describe('Empty State', () => { + it('should show empty state when no matches found', () => { + render( + , + ) + + expect(screen.queryByTestId('command-item-@app')).not.toBeInTheDocument() + expect(screen.queryByTestId('command-item-@kb')).not.toBeInTheDocument() + expect(screen.queryByTestId('command-item-@plugin')).not.toBeInTheDocument() + expect(screen.queryByTestId('command-item-@node')).not.toBeInTheDocument() + + expect(screen.getByText('app.gotoAnything.noMatchingCommands')).toBeInTheDocument() + expect(screen.getByText('app.gotoAnything.tryDifferentSearch')).toBeInTheDocument() + }) + + it('should not show empty state when filter is empty', () => { + render( + , + ) + + expect(screen.queryByText('app.gotoAnything.noMatchingCommands')).not.toBeInTheDocument() + }) + }) + + describe('Selection and Highlight Management', () => { + it('should call onCommandValueChange when filter changes and first item differs', () => { + const { rerender } = render( + , + ) + + rerender( + , + ) + + expect(mockOnCommandValueChange).toHaveBeenCalledWith('@kb') + }) + + it('should not call onCommandValueChange if current value still exists', () => { + const { rerender } = render( + , + ) + + rerender( + , + ) + + expect(mockOnCommandValueChange).not.toHaveBeenCalled() + }) + + it('should handle onCommandSelect callback correctly', () => { + render( + , + ) + + const knowledgeItem = screen.getByTestId('command-item-@kb') + fireEvent.click(knowledgeItem) + + expect(mockOnCommandSelect).toHaveBeenCalledWith('@kb') + }) + }) + + describe('Edge Cases', () => { + it('should handle empty actions object', () => { + render( + , + ) + + expect(screen.getByText('app.gotoAnything.noMatchingCommands')).toBeInTheDocument() + }) + + it('should handle special characters in filter', () => { + render( + , + ) + + expect(screen.getByTestId('command-item-@app')).toBeInTheDocument() + expect(screen.getByTestId('command-item-@kb')).toBeInTheDocument() + expect(screen.getByTestId('command-item-@plugin')).toBeInTheDocument() + expect(screen.getByTestId('command-item-@node')).toBeInTheDocument() + }) + + it('should handle undefined onCommandValueChange gracefully', () => { + const { rerender } = render( + , + ) + + expect(() => { + rerender( + , + ) + }).not.toThrow() + }) + }) + + describe('Backward Compatibility', () => { + it('should work without searchFilter prop (backward compatible)', () => { + render( + , + ) + + expect(screen.getByTestId('command-item-@app')).toBeInTheDocument() + expect(screen.getByTestId('command-item-@kb')).toBeInTheDocument() + expect(screen.getByTestId('command-item-@plugin')).toBeInTheDocument() + expect(screen.getByTestId('command-item-@node')).toBeInTheDocument() + }) + + it('should work without commandValue and onCommandValueChange props', () => { + render( + , + ) + + expect(screen.getByTestId('command-item-@kb')).toBeInTheDocument() + expect(screen.queryByTestId('command-item-@app')).not.toBeInTheDocument() + }) + }) +}) diff --git a/web/__tests__/goto-anything/search-error-handling.test.ts b/web/__tests__/goto-anything/search-error-handling.test.ts new file mode 100644 index 0000000000..d2fd921e1c --- /dev/null +++ b/web/__tests__/goto-anything/search-error-handling.test.ts @@ -0,0 +1,197 @@ +/** + * Test GotoAnything search error handling mechanisms + * + * Main validations: + * 1. @plugin search error handling when API fails + * 2. Regular search (without @prefix) error handling when API fails + * 3. Verify consistent error handling across different search types + * 4. Ensure errors don't propagate to UI layer causing "search failed" + */ + +import { Actions, searchAnything } from '@/app/components/goto-anything/actions' +import { postMarketplace } from '@/service/base' +import { fetchAppList } from '@/service/apps' +import { fetchDatasets } from '@/service/datasets' + +// Mock API functions +jest.mock('@/service/base', () => ({ + postMarketplace: jest.fn(), +})) + +jest.mock('@/service/apps', () => ({ + fetchAppList: jest.fn(), +})) + +jest.mock('@/service/datasets', () => ({ + fetchDatasets: jest.fn(), +})) + +const mockPostMarketplace = postMarketplace as jest.MockedFunction +const mockFetchAppList = fetchAppList as jest.MockedFunction +const mockFetchDatasets = fetchDatasets as jest.MockedFunction + +describe('GotoAnything Search Error Handling', () => { + beforeEach(() => { + jest.clearAllMocks() + // Suppress console.warn for clean test output + jest.spyOn(console, 'warn').mockImplementation(() => { + // Suppress console.warn for clean test output + }) + }) + + afterEach(() => { + jest.restoreAllMocks() + }) + + describe('@plugin search error handling', () => { + it('should return empty array when API fails instead of throwing error', async () => { + // Mock marketplace API failure (403 permission denied) + mockPostMarketplace.mockRejectedValue(new Error('HTTP 403: Forbidden')) + + const pluginAction = Actions.plugin + + // Directly call plugin action's search method + const result = await pluginAction.search('@plugin', 'test', 'en') + + // Should return empty array instead of throwing error + expect(result).toEqual([]) + expect(mockPostMarketplace).toHaveBeenCalledWith('/plugins/search/advanced', { + body: { + page: 1, + page_size: 10, + query: 'test', + type: 'plugin', + }, + }) + }) + + it('should return empty array when user has no plugin data', async () => { + // Mock marketplace returning empty data + mockPostMarketplace.mockResolvedValue({ + data: { plugins: [] }, + }) + + const pluginAction = Actions.plugin + const result = await pluginAction.search('@plugin', '', 'en') + + expect(result).toEqual([]) + }) + + it('should return empty array when API returns unexpected data structure', async () => { + // Mock API returning unexpected data structure + mockPostMarketplace.mockResolvedValue({ + data: null, + }) + + const pluginAction = Actions.plugin + const result = await pluginAction.search('@plugin', 'test', 'en') + + expect(result).toEqual([]) + }) + }) + + describe('Other search types error handling', () => { + it('@app search should return empty array when API fails', async () => { + // Mock app API failure + mockFetchAppList.mockRejectedValue(new Error('API Error')) + + const appAction = Actions.app + const result = await appAction.search('@app', 'test', 'en') + + expect(result).toEqual([]) + }) + + it('@knowledge search should return empty array when API fails', async () => { + // Mock knowledge API failure + mockFetchDatasets.mockRejectedValue(new Error('API Error')) + + const knowledgeAction = Actions.knowledge + const result = await knowledgeAction.search('@knowledge', 'test', 'en') + + expect(result).toEqual([]) + }) + }) + + describe('Unified search entry error handling', () => { + it('regular search (without @prefix) should return successful results even when partial APIs fail', async () => { + // Set app and knowledge success, plugin failure + mockFetchAppList.mockResolvedValue({ data: [], has_more: false, limit: 10, page: 1, total: 0 }) + mockFetchDatasets.mockResolvedValue({ data: [], has_more: false, limit: 10, page: 1, total: 0 }) + mockPostMarketplace.mockRejectedValue(new Error('Plugin API failed')) + + const result = await searchAnything('en', 'test') + + // Should return successful results even if plugin search fails + expect(result).toEqual([]) + expect(console.warn).toHaveBeenCalledWith('Plugin search failed:', expect.any(Error)) + }) + + it('@plugin dedicated search should return empty array when API fails', async () => { + // Mock plugin API failure + mockPostMarketplace.mockRejectedValue(new Error('Plugin service unavailable')) + + const pluginAction = Actions.plugin + const result = await searchAnything('en', '@plugin test', pluginAction) + + // Should return empty array instead of throwing error + expect(result).toEqual([]) + }) + + it('@app dedicated search should return empty array when API fails', async () => { + // Mock app API failure + mockFetchAppList.mockRejectedValue(new Error('App service unavailable')) + + const appAction = Actions.app + const result = await searchAnything('en', '@app test', appAction) + + expect(result).toEqual([]) + }) + }) + + describe('Error handling consistency validation', () => { + it('all search types should return empty array when encountering errors', async () => { + // Mock all APIs to fail + mockPostMarketplace.mockRejectedValue(new Error('Plugin API failed')) + mockFetchAppList.mockRejectedValue(new Error('App API failed')) + mockFetchDatasets.mockRejectedValue(new Error('Dataset API failed')) + + const actions = [ + { name: '@plugin', action: Actions.plugin }, + { name: '@app', action: Actions.app }, + { name: '@knowledge', action: Actions.knowledge }, + ] + + for (const { name, action } of actions) { + const result = await action.search(name, 'test', 'en') + expect(result).toEqual([]) + } + }) + }) + + describe('Edge case testing', () => { + it('empty search term should be handled properly', async () => { + mockPostMarketplace.mockResolvedValue({ data: { plugins: [] } }) + + const result = await searchAnything('en', '@plugin ', Actions.plugin) + expect(result).toEqual([]) + }) + + it('network timeout should be handled correctly', async () => { + const timeoutError = new Error('Network timeout') + timeoutError.name = 'TimeoutError' + + mockPostMarketplace.mockRejectedValue(timeoutError) + + const result = await searchAnything('en', '@plugin test', Actions.plugin) + expect(result).toEqual([]) + }) + + it('JSON parsing errors should be handled correctly', async () => { + const parseError = new SyntaxError('Unexpected token in JSON') + mockPostMarketplace.mockRejectedValue(parseError) + + const result = await searchAnything('en', '@plugin test', Actions.plugin) + expect(result).toEqual([]) + }) + }) +}) diff --git a/web/__tests__/i18n-upload-features.test.ts b/web/__tests__/i18n-upload-features.test.ts new file mode 100644 index 0000000000..37aefcbef4 --- /dev/null +++ b/web/__tests__/i18n-upload-features.test.ts @@ -0,0 +1,119 @@ +/** + * Test suite for verifying upload feature translations across all locales + * Specifically tests for issue #23062: Missing Upload feature translations (esp. audioUpload) across most locales + */ + +import fs from 'node:fs' +import path from 'node:path' + +// Get all supported locales from the i18n directory +const I18N_DIR = path.join(__dirname, '../i18n') +const getSupportedLocales = (): string[] => { + return fs.readdirSync(I18N_DIR) + .filter(item => fs.statSync(path.join(I18N_DIR, item)).isDirectory()) + .sort() +} + +// Helper function to load translation file content +const loadTranslationContent = (locale: string): string => { + const filePath = path.join(I18N_DIR, locale, 'app-debug.ts') + + if (!fs.existsSync(filePath)) + throw new Error(`Translation file not found: ${filePath}`) + + return fs.readFileSync(filePath, 'utf-8') +} + +// Helper function to check if upload features exist +const hasUploadFeatures = (content: string): { [key: string]: boolean } => { + return { + fileUpload: /fileUpload\s*:\s*{/.test(content), + imageUpload: /imageUpload\s*:\s*{/.test(content), + documentUpload: /documentUpload\s*:\s*{/.test(content), + audioUpload: /audioUpload\s*:\s*{/.test(content), + featureBar: /bar\s*:\s*{/.test(content), + } +} + +describe('Upload Features i18n Translations - Issue #23062', () => { + let supportedLocales: string[] + + beforeAll(() => { + supportedLocales = getSupportedLocales() + console.log(`Testing ${supportedLocales.length} locales for upload features`) + }) + + test('all locales should have translation files', () => { + supportedLocales.forEach((locale) => { + const filePath = path.join(I18N_DIR, locale, 'app-debug.ts') + expect(fs.existsSync(filePath)).toBe(true) + }) + }) + + test('all locales should have required upload features', () => { + const results: { [locale: string]: { [feature: string]: boolean } } = {} + + supportedLocales.forEach((locale) => { + const content = loadTranslationContent(locale) + const features = hasUploadFeatures(content) + results[locale] = features + + // Check that all upload features exist + expect(features.fileUpload).toBe(true) + expect(features.imageUpload).toBe(true) + expect(features.documentUpload).toBe(true) + expect(features.audioUpload).toBe(true) + expect(features.featureBar).toBe(true) + }) + + console.log('✅ All locales have complete upload features') + }) + + test('previously missing locales should now have audioUpload - Issue #23062', () => { + // These locales were specifically missing audioUpload + const previouslyMissingLocales = ['fa-IR', 'hi-IN', 'ro-RO', 'sl-SI', 'th-TH', 'uk-UA', 'vi-VN'] + + previouslyMissingLocales.forEach((locale) => { + const content = loadTranslationContent(locale) + + // Verify audioUpload exists + expect(/audioUpload\s*:\s*{/.test(content)).toBe(true) + + // Verify it has title and description + expect(/audioUpload[^}]*title\s*:/.test(content)).toBe(true) + expect(/audioUpload[^}]*description\s*:/.test(content)).toBe(true) + + console.log(`✅ ${locale} - Issue #23062 resolved: audioUpload feature present`) + }) + }) + + test('upload features should have required properties', () => { + supportedLocales.forEach((locale) => { + const content = loadTranslationContent(locale) + + // Check fileUpload has required properties + if (/fileUpload\s*:\s*{/.test(content)) { + expect(/fileUpload[^}]*title\s*:/.test(content)).toBe(true) + expect(/fileUpload[^}]*description\s*:/.test(content)).toBe(true) + } + + // Check imageUpload has required properties + if (/imageUpload\s*:\s*{/.test(content)) { + expect(/imageUpload[^}]*title\s*:/.test(content)).toBe(true) + expect(/imageUpload[^}]*description\s*:/.test(content)).toBe(true) + } + + // Check documentUpload has required properties + if (/documentUpload\s*:\s*{/.test(content)) { + expect(/documentUpload[^}]*title\s*:/.test(content)).toBe(true) + expect(/documentUpload[^}]*description\s*:/.test(content)).toBe(true) + } + + // Check audioUpload has required properties + if (/audioUpload\s*:\s*{/.test(content)) { + expect(/audioUpload[^}]*title\s*:/.test(content)).toBe(true) + expect(/audioUpload[^}]*description\s*:/.test(content)).toBe(true) + } + }) + }) +}) diff --git a/web/__tests__/navigation-utils.test.ts b/web/__tests__/navigation-utils.test.ts new file mode 100644 index 0000000000..9a388505d6 --- /dev/null +++ b/web/__tests__/navigation-utils.test.ts @@ -0,0 +1,290 @@ +/** + * Navigation Utilities Test + * + * Tests for the navigation utility functions to ensure they handle + * query parameter preservation correctly across different scenarios. + */ + +import { + createBackNavigation, + createNavigationPath, + createNavigationPathWithParams, + datasetNavigation, + extractQueryParams, + mergeQueryParams, +} from '@/utils/navigation' + +// Mock router for testing +const mockPush = jest.fn() +const mockRouter = { push: mockPush } + +describe('Navigation Utilities', () => { + beforeEach(() => { + jest.clearAllMocks() + }) + + describe('createNavigationPath', () => { + test('preserves query parameters by default', () => { + Object.defineProperty(window, 'location', { + value: { search: '?page=3&limit=10&keyword=test' }, + writable: true, + }) + + const path = createNavigationPath('/datasets/123/documents') + expect(path).toBe('/datasets/123/documents?page=3&limit=10&keyword=test') + }) + + test('returns clean path when preserveParams is false', () => { + Object.defineProperty(window, 'location', { + value: { search: '?page=3&limit=10' }, + writable: true, + }) + + const path = createNavigationPath('/datasets/123/documents', false) + expect(path).toBe('/datasets/123/documents') + }) + + test('handles empty query parameters', () => { + Object.defineProperty(window, 'location', { + value: { search: '' }, + writable: true, + }) + + const path = createNavigationPath('/datasets/123/documents') + expect(path).toBe('/datasets/123/documents') + }) + + test('handles errors gracefully', () => { + // Mock window.location to throw an error + Object.defineProperty(window, 'location', { + get: () => { + throw new Error('Location access denied') + }, + configurable: true, + }) + + const consoleSpy = jest.spyOn(console, 'warn').mockImplementation() + const path = createNavigationPath('/datasets/123/documents') + + expect(path).toBe('/datasets/123/documents') + expect(consoleSpy).toHaveBeenCalledWith('Failed to preserve query parameters:', expect.any(Error)) + + consoleSpy.mockRestore() + }) + }) + + describe('createBackNavigation', () => { + test('creates function that navigates with preserved params', () => { + Object.defineProperty(window, 'location', { + value: { search: '?page=2&limit=25' }, + writable: true, + }) + + const backFn = createBackNavigation(mockRouter, '/datasets/123/documents') + backFn() + + expect(mockPush).toHaveBeenCalledWith('/datasets/123/documents?page=2&limit=25') + }) + + test('creates function that navigates without params when specified', () => { + Object.defineProperty(window, 'location', { + value: { search: '?page=2&limit=25' }, + writable: true, + }) + + const backFn = createBackNavigation(mockRouter, '/datasets/123/documents', false) + backFn() + + expect(mockPush).toHaveBeenCalledWith('/datasets/123/documents') + }) + }) + + describe('extractQueryParams', () => { + test('extracts specified parameters', () => { + Object.defineProperty(window, 'location', { + value: { search: '?page=3&limit=10&keyword=test&other=value' }, + writable: true, + }) + + const params = extractQueryParams(['page', 'limit', 'keyword']) + expect(params).toEqual({ + page: '3', + limit: '10', + keyword: 'test', + }) + }) + + test('handles missing parameters', () => { + Object.defineProperty(window, 'location', { + value: { search: '?page=3' }, + writable: true, + }) + + const params = extractQueryParams(['page', 'limit', 'missing']) + expect(params).toEqual({ + page: '3', + }) + }) + + test('handles errors gracefully', () => { + Object.defineProperty(window, 'location', { + get: () => { + throw new Error('Location access denied') + }, + configurable: true, + }) + + const consoleSpy = jest.spyOn(console, 'warn').mockImplementation() + const params = extractQueryParams(['page', 'limit']) + + expect(params).toEqual({}) + expect(consoleSpy).toHaveBeenCalledWith('Failed to extract query parameters:', expect.any(Error)) + + consoleSpy.mockRestore() + }) + }) + + describe('createNavigationPathWithParams', () => { + test('creates path with specified parameters', () => { + const path = createNavigationPathWithParams('/datasets/123/documents', { + page: 1, + limit: 25, + keyword: 'search term', + }) + + expect(path).toBe('/datasets/123/documents?page=1&limit=25&keyword=search+term') + }) + + test('filters out empty values', () => { + const path = createNavigationPathWithParams('/datasets/123/documents', { + page: 1, + limit: '', + keyword: 'test', + empty: null, + undefined, + }) + + expect(path).toBe('/datasets/123/documents?page=1&keyword=test') + }) + + test('handles errors gracefully', () => { + // Mock URLSearchParams to throw an error + const originalURLSearchParams = globalThis.URLSearchParams + globalThis.URLSearchParams = jest.fn(() => { + throw new Error('URLSearchParams error') + }) as any + + const consoleSpy = jest.spyOn(console, 'warn').mockImplementation() + const path = createNavigationPathWithParams('/datasets/123/documents', { page: 1 }) + + expect(path).toBe('/datasets/123/documents') + expect(consoleSpy).toHaveBeenCalledWith('Failed to create navigation path with params:', expect.any(Error)) + + consoleSpy.mockRestore() + globalThis.URLSearchParams = originalURLSearchParams + }) + }) + + describe('mergeQueryParams', () => { + test('merges new params with existing ones', () => { + Object.defineProperty(window, 'location', { + value: { search: '?page=3&limit=10' }, + writable: true, + }) + + const merged = mergeQueryParams({ keyword: 'test', page: '1' }) + const result = merged.toString() + + expect(result).toContain('page=1') // overridden + expect(result).toContain('limit=10') // preserved + expect(result).toContain('keyword=test') // added + }) + + test('removes parameters when value is null', () => { + Object.defineProperty(window, 'location', { + value: { search: '?page=3&limit=10&keyword=test' }, + writable: true, + }) + + const merged = mergeQueryParams({ keyword: null, filter: 'active' }) + const result = merged.toString() + + expect(result).toContain('page=3') + expect(result).toContain('limit=10') + expect(result).not.toContain('keyword') + expect(result).toContain('filter=active') + }) + + test('creates fresh params when preserveExisting is false', () => { + Object.defineProperty(window, 'location', { + value: { search: '?page=3&limit=10' }, + writable: true, + }) + + const merged = mergeQueryParams({ keyword: 'test' }, false) + const result = merged.toString() + + expect(result).toBe('keyword=test') + }) + }) + + describe('datasetNavigation', () => { + test('backToDocuments creates correct navigation function', () => { + Object.defineProperty(window, 'location', { + value: { search: '?page=2&limit=25' }, + writable: true, + }) + + const backFn = datasetNavigation.backToDocuments(mockRouter, 'dataset-123') + backFn() + + expect(mockPush).toHaveBeenCalledWith('/datasets/dataset-123/documents?page=2&limit=25') + }) + + test('toDocumentDetail creates correct navigation function', () => { + const detailFn = datasetNavigation.toDocumentDetail(mockRouter, 'dataset-123', 'doc-456') + detailFn() + + expect(mockPush).toHaveBeenCalledWith('/datasets/dataset-123/documents/doc-456') + }) + + test('toDocumentSettings creates correct navigation function', () => { + const settingsFn = datasetNavigation.toDocumentSettings(mockRouter, 'dataset-123', 'doc-456') + settingsFn() + + expect(mockPush).toHaveBeenCalledWith('/datasets/dataset-123/documents/doc-456/settings') + }) + }) + + describe('Real-world Integration Scenarios', () => { + test('complete user workflow: list -> detail -> back', () => { + // User starts on page 3 with search + Object.defineProperty(window, 'location', { + value: { search: '?page=3&keyword=API&limit=25' }, + writable: true, + }) + + // Create back navigation function (as would be done in detail component) + const backToDocuments = datasetNavigation.backToDocuments(mockRouter, 'main-dataset') + + // User clicks back + backToDocuments() + + // Should return to exact same list state + expect(mockPush).toHaveBeenCalledWith('/datasets/main-dataset/documents?page=3&keyword=API&limit=25') + }) + + test('user applies filters then views document', () => { + // Complex filter state + Object.defineProperty(window, 'location', { + value: { search: '?page=1&limit=50&status=active&type=pdf&sort=created_at&order=desc' }, + writable: true, + }) + + const backFn = createBackNavigation(mockRouter, '/datasets/filtered-set/documents') + backFn() + + expect(mockPush).toHaveBeenCalledWith('/datasets/filtered-set/documents?page=1&limit=50&status=active&type=pdf&sort=created_at&order=desc') + }) + }) +}) diff --git a/web/__tests__/plugin-tool-workflow-error.test.tsx b/web/__tests__/plugin-tool-workflow-error.test.tsx new file mode 100644 index 0000000000..370052bc80 --- /dev/null +++ b/web/__tests__/plugin-tool-workflow-error.test.tsx @@ -0,0 +1,207 @@ +/** + * Test cases to reproduce the plugin tool workflow error + * Issue: #23154 - Application error when loading plugin tools in workflow + * Root cause: split() operation called on null/undefined values + */ + +describe('Plugin Tool Workflow Error Reproduction', () => { + /** + * Mock function to simulate the problematic code in switch-plugin-version.tsx:29 + * const [pluginId] = uniqueIdentifier.split(':') + */ + const mockSwitchPluginVersionLogic = (uniqueIdentifier: string | null | undefined) => { + // This directly reproduces the problematic line from switch-plugin-version.tsx:29 + const [pluginId] = uniqueIdentifier!.split(':') + return pluginId + } + + /** + * Test case 1: Simulate null uniqueIdentifier + * This should reproduce the error mentioned in the issue + */ + it('should reproduce error when uniqueIdentifier is null', () => { + expect(() => { + mockSwitchPluginVersionLogic(null) + }).toThrow('Cannot read properties of null (reading \'split\')') + }) + + /** + * Test case 2: Simulate undefined uniqueIdentifier + */ + it('should reproduce error when uniqueIdentifier is undefined', () => { + expect(() => { + mockSwitchPluginVersionLogic(undefined) + }).toThrow('Cannot read properties of undefined (reading \'split\')') + }) + + /** + * Test case 3: Simulate empty string uniqueIdentifier + */ + it('should handle empty string uniqueIdentifier', () => { + expect(() => { + const result = mockSwitchPluginVersionLogic('') + expect(result).toBe('') // Empty string split by ':' returns [''] + }).not.toThrow() + }) + + /** + * Test case 4: Simulate malformed uniqueIdentifier without colon separator + */ + it('should handle malformed uniqueIdentifier without colon separator', () => { + expect(() => { + const result = mockSwitchPluginVersionLogic('malformed-identifier-without-colon') + expect(result).toBe('malformed-identifier-without-colon') // No colon means full string returned + }).not.toThrow() + }) + + /** + * Test case 5: Simulate valid uniqueIdentifier + */ + it('should work correctly with valid uniqueIdentifier', () => { + expect(() => { + const result = mockSwitchPluginVersionLogic('valid-plugin-id:1.0.0') + expect(result).toBe('valid-plugin-id') + }).not.toThrow() + }) +}) + +/** + * Test for the variable processing split error in use-single-run-form-params + */ +describe('Variable Processing Split Error', () => { + /** + * Mock function to simulate the problematic code in use-single-run-form-params.ts:91 + * const getDependentVars = () => { + * return varInputs.map(item => item.variable.slice(1, -1).split('.')) + * } + */ + const mockGetDependentVars = (varInputs: Array<{ variable: string | null | undefined }>) => { + return varInputs.map((item) => { + // Guard against null/undefined variable to prevent app crash + if (!item.variable || typeof item.variable !== 'string') + return [] + + return item.variable.slice(1, -1).split('.') + }).filter(arr => arr.length > 0) // Filter out empty arrays + } + + /** + * Test case 1: Variable processing with null variable + */ + it('should handle null variable safely', () => { + const varInputs = [{ variable: null }] + + expect(() => { + mockGetDependentVars(varInputs) + }).not.toThrow() + + const result = mockGetDependentVars(varInputs) + expect(result).toEqual([]) // null variables are filtered out + }) + + /** + * Test case 2: Variable processing with undefined variable + */ + it('should handle undefined variable safely', () => { + const varInputs = [{ variable: undefined }] + + expect(() => { + mockGetDependentVars(varInputs) + }).not.toThrow() + + const result = mockGetDependentVars(varInputs) + expect(result).toEqual([]) // undefined variables are filtered out + }) + + /** + * Test case 3: Variable processing with empty string + */ + it('should handle empty string variable', () => { + const varInputs = [{ variable: '' }] + + expect(() => { + mockGetDependentVars(varInputs) + }).not.toThrow() + + const result = mockGetDependentVars(varInputs) + expect(result).toEqual([]) // Empty string is filtered out, so result is empty array + }) + + /** + * Test case 4: Variable processing with valid variable format + */ + it('should work correctly with valid variable format', () => { + const varInputs = [{ variable: '{{workflow.node.output}}' }] + + expect(() => { + mockGetDependentVars(varInputs) + }).not.toThrow() + + const result = mockGetDependentVars(varInputs) + expect(result[0]).toEqual(['{workflow', 'node', 'output}']) + }) +}) + +/** + * Integration test to simulate the complete workflow scenario + */ +describe('Plugin Tool Workflow Integration', () => { + /** + * Simulate the scenario where plugin metadata is incomplete or corrupted + * This can happen when: + * 1. Plugin is being loaded from marketplace but metadata request fails + * 2. Plugin configuration is corrupted in database + * 3. Network issues during plugin loading + */ + it('should reproduce the client-side exception scenario', () => { + // Mock incomplete plugin data that could cause the error + const incompletePluginData = { + // Missing or null uniqueIdentifier + uniqueIdentifier: null, + meta: null, + minimum_dify_version: undefined, + } + + // This simulates the error path that leads to the white screen + expect(() => { + // Simulate the code path in switch-plugin-version.tsx:29 + // The actual problematic code doesn't use optional chaining + const _pluginId = (incompletePluginData.uniqueIdentifier as any).split(':')[0] + }).toThrow('Cannot read properties of null (reading \'split\')') + }) + + /** + * Test the scenario mentioned in the issue where plugin tools are loaded in workflow + */ + it('should simulate plugin tool loading in workflow context', () => { + // Mock the workflow context where plugin tools are being loaded + const workflowPluginTools = [ + { + provider_name: 'test-plugin', + uniqueIdentifier: null, // This is the problematic case + tool_name: 'test-tool', + }, + { + provider_name: 'valid-plugin', + uniqueIdentifier: 'valid-plugin:1.0.0', + tool_name: 'valid-tool', + }, + ] + + // Process each plugin tool + workflowPluginTools.forEach((tool, _index) => { + if (tool.uniqueIdentifier === null) { + // This reproduces the exact error scenario + expect(() => { + const _pluginId = (tool.uniqueIdentifier as any).split(':')[0] + }).toThrow() + } + else { + // Valid tools should work fine + expect(() => { + const _pluginId = tool.uniqueIdentifier.split(':')[0] + }).not.toThrow() + } + }) + }) +}) diff --git a/web/__tests__/real-browser-flicker.test.tsx b/web/__tests__/real-browser-flicker.test.tsx new file mode 100644 index 0000000000..cf3abd5f80 --- /dev/null +++ b/web/__tests__/real-browser-flicker.test.tsx @@ -0,0 +1,445 @@ +/** + * Real Browser Environment Dark Mode Flicker Test + * + * This test attempts to simulate real browser refresh scenarios including: + * 1. SSR HTML generation phase + * 2. Client-side JavaScript loading + * 3. Theme system initialization + * 4. CSS styles application timing + */ + +import { render, screen, waitFor } from '@testing-library/react' +import { ThemeProvider } from 'next-themes' +import useTheme from '@/hooks/use-theme' +import { useEffect, useState } from 'react' + +// Setup browser environment for testing +const setupMockEnvironment = (storedTheme: string | null, systemPrefersDark = false) => { + // Mock localStorage + const mockStorage = { + getItem: jest.fn((key: string) => { + if (key === 'theme') return storedTheme + return null + }), + setItem: jest.fn(), + removeItem: jest.fn(), + } + + // Mock system theme preference + const mockMatchMedia = jest.fn((query: string) => ({ + matches: query.includes('dark') && systemPrefersDark, + media: query, + addListener: jest.fn(), + removeListener: jest.fn(), + })) + + if (typeof window !== 'undefined') { + Object.defineProperty(window, 'localStorage', { + value: mockStorage, + configurable: true, + }) + + Object.defineProperty(window, 'matchMedia', { + value: mockMatchMedia, + configurable: true, + }) + } + + return { mockStorage, mockMatchMedia } +} + +// Simulate real page component based on Dify's actual theme usage +const PageComponent = () => { + const [mounted, setMounted] = useState(false) + const { theme } = useTheme() + + useEffect(() => { + setMounted(true) + }, []) + + // Simulate common theme usage pattern in Dify + const isDark = mounted ? theme === 'dark' : false + + return ( +
+
+

+ Dify Application +

+
+ Current Theme: {mounted ? theme : 'unknown'} +
+
+ Appearance: {isDark ? 'dark' : 'light'} +
+
+
+ ) +} + +const TestThemeProvider = ({ children }: { children: React.ReactNode }) => ( + + {children} + +) + +describe('Real Browser Environment Dark Mode Flicker Test', () => { + beforeEach(() => { + jest.clearAllMocks() + }) + + describe('Page Refresh Scenario Simulation', () => { + test('simulates complete page loading process with dark theme', async () => { + // Setup: User previously selected dark mode + setupMockEnvironment('dark') + + render( + + + , + ) + + // Check initial client-side rendering state + const initialState = { + theme: screen.getByTestId('theme-indicator').textContent, + appearance: screen.getByTestId('visual-appearance').textContent, + } + console.log('Initial client state:', initialState) + + // Wait for theme system to fully initialize + await waitFor(() => { + expect(screen.getByTestId('theme-indicator')).toHaveTextContent('Current Theme: dark') + }) + + const finalState = { + theme: screen.getByTestId('theme-indicator').textContent, + appearance: screen.getByTestId('visual-appearance').textContent, + } + console.log('Final state:', finalState) + + // Document the state change - this is the source of flicker + console.log('State change detection: Initial -> Final') + }) + + test('handles light theme correctly', async () => { + setupMockEnvironment('light') + + render( + + + , + ) + + await waitFor(() => { + expect(screen.getByTestId('theme-indicator')).toHaveTextContent('Current Theme: light') + }) + + expect(screen.getByTestId('visual-appearance')).toHaveTextContent('Appearance: light') + }) + + test('handles system theme with dark preference', async () => { + setupMockEnvironment('system', true) // system theme, dark preference + + render( + + + , + ) + + await waitFor(() => { + expect(screen.getByTestId('theme-indicator')).toHaveTextContent('Current Theme: dark') + }) + + expect(screen.getByTestId('visual-appearance')).toHaveTextContent('Appearance: dark') + }) + + test('handles system theme with light preference', async () => { + setupMockEnvironment('system', false) // system theme, light preference + + render( + + + , + ) + + await waitFor(() => { + expect(screen.getByTestId('theme-indicator')).toHaveTextContent('Current Theme: light') + }) + + expect(screen.getByTestId('visual-appearance')).toHaveTextContent('Appearance: light') + }) + + test('handles no stored theme (defaults to system)', async () => { + setupMockEnvironment(null, false) // no stored theme, system prefers light + + render( + + + , + ) + + await waitFor(() => { + expect(screen.getByTestId('theme-indicator')).toHaveTextContent('Current Theme: light') + }) + }) + + test('measures timing window of style changes', async () => { + setupMockEnvironment('dark') + + const timingData: Array<{ phase: string; timestamp: number; styles: any }> = [] + + const TimingPageComponent = () => { + const [mounted, setMounted] = useState(false) + const { theme } = useTheme() + const isDark = mounted ? theme === 'dark' : false + + // Record timing and styles for each render phase + const currentStyles = { + backgroundColor: isDark ? '#1f2937' : '#ffffff', + color: isDark ? '#ffffff' : '#000000', + } + + timingData.push({ + phase: mounted ? 'CSR' : 'Initial', + timestamp: performance.now(), + styles: currentStyles, + }) + + useEffect(() => { + setMounted(true) + }, []) + + return ( +
+
+ Phase: {mounted ? 'CSR' : 'Initial'} | Theme: {theme} | Visual: {isDark ? 'dark' : 'light'} +
+
+ ) + } + + render( + + + , + ) + + await waitFor(() => { + expect(screen.getByTestId('timing-status')).toHaveTextContent('Phase: CSR') + }) + + // Analyze timing and style changes + console.log('\n=== Style Change Timeline ===') + timingData.forEach((data, index) => { + console.log(`${index + 1}. ${data.phase}: bg=${data.styles.backgroundColor}, color=${data.styles.color}`) + }) + + // Check if there are style changes (this is visible flicker) + const hasStyleChange = timingData.length > 1 + && timingData[0].styles.backgroundColor !== timingData[timingData.length - 1].styles.backgroundColor + + if (hasStyleChange) + console.log('⚠️ Style changes detected - this causes visible flicker') + else + console.log('✅ No style changes detected') + + expect(timingData.length).toBeGreaterThan(1) + }) + }) + + describe('CSS Application Timing Tests', () => { + test('checks CSS class changes causing flicker', async () => { + setupMockEnvironment('dark') + + const cssStates: Array<{ className: string; timestamp: number }> = [] + + const CSSTestComponent = () => { + const [mounted, setMounted] = useState(false) + const { theme } = useTheme() + const isDark = mounted ? theme === 'dark' : false + + // Simulate Tailwind CSS class application + const className = `min-h-screen ${isDark ? 'bg-gray-900 text-white' : 'bg-white text-black'}` + + cssStates.push({ + className, + timestamp: performance.now(), + }) + + useEffect(() => { + setMounted(true) + }, []) + + return ( +
+
Classes: {className}
+
+ ) + } + + render( + + + , + ) + + await waitFor(() => { + expect(screen.getByTestId('css-classes')).toHaveTextContent('bg-gray-900 text-white') + }) + + console.log('\n=== CSS Class Change Detection ===') + cssStates.forEach((state, index) => { + console.log(`${index + 1}. ${state.className}`) + }) + + // Check if CSS classes have changed + const hasCSSChange = cssStates.length > 1 + && cssStates[0].className !== cssStates[cssStates.length - 1].className + + if (hasCSSChange) { + console.log('⚠️ CSS class changes detected - may cause style flicker') + console.log(`From: "${cssStates[0].className}"`) + console.log(`To: "${cssStates[cssStates.length - 1].className}"`) + } + + expect(hasCSSChange).toBe(true) // We expect to see this change + }) + }) + + describe('Edge Cases and Error Handling', () => { + test('handles localStorage access errors gracefully', async () => { + // Mock localStorage to throw an error + const mockStorage = { + getItem: jest.fn(() => { + throw new Error('LocalStorage access denied') + }), + setItem: jest.fn(), + removeItem: jest.fn(), + } + + if (typeof window !== 'undefined') { + Object.defineProperty(window, 'localStorage', { + value: mockStorage, + configurable: true, + }) + } + + render( + + + , + ) + + // Should fallback gracefully without crashing + await waitFor(() => { + expect(screen.getByTestId('theme-indicator')).toBeInTheDocument() + }) + + // Should default to light theme when localStorage fails + expect(screen.getByTestId('visual-appearance')).toHaveTextContent('Appearance: light') + }) + + test('handles invalid theme values in localStorage', async () => { + setupMockEnvironment('invalid-theme-value') + + render( + + + , + ) + + await waitFor(() => { + expect(screen.getByTestId('theme-indicator')).toBeInTheDocument() + }) + + // Should handle invalid values gracefully + const themeIndicator = screen.getByTestId('theme-indicator') + expect(themeIndicator).toBeInTheDocument() + }) + }) + + describe('Performance and Regression Tests', () => { + test('verifies ThemeProvider position fix reduces initialization delay', async () => { + const performanceMarks: Array<{ event: string; timestamp: number }> = [] + + const PerformanceTestComponent = () => { + const [mounted, setMounted] = useState(false) + const { theme } = useTheme() + + performanceMarks.push({ event: 'component-render', timestamp: performance.now() }) + + useEffect(() => { + performanceMarks.push({ event: 'mount-start', timestamp: performance.now() }) + setMounted(true) + performanceMarks.push({ event: 'mount-complete', timestamp: performance.now() }) + }, []) + + useEffect(() => { + if (theme) + performanceMarks.push({ event: 'theme-available', timestamp: performance.now() }) + }, [theme]) + + return ( +
+ Mounted: {mounted.toString()} | Theme: {theme || 'loading'} +
+ ) + } + + setupMockEnvironment('dark') + + render( + + + , + ) + + await waitFor(() => { + expect(screen.getByTestId('performance-test')).toHaveTextContent('Theme: dark') + }) + + // Analyze performance timeline + console.log('\n=== Performance Timeline ===') + performanceMarks.forEach((mark) => { + console.log(`${mark.event}: ${mark.timestamp.toFixed(2)}ms`) + }) + + expect(performanceMarks.length).toBeGreaterThan(3) + }) + }) + + describe('Solution Requirements Definition', () => { + test('defines technical requirements to eliminate flicker', () => { + const technicalRequirements = { + ssrConsistency: 'SSR and CSR must render identical initial styles', + synchronousDetection: 'Theme detection must complete synchronously before first render', + noStyleChanges: 'No visible style changes should occur after hydration', + performanceImpact: 'Solution should not significantly impact page load performance', + browserCompatibility: 'Must work consistently across all major browsers', + } + + console.log('\n=== Technical Requirements ===') + Object.entries(technicalRequirements).forEach(([key, requirement]) => { + console.log(`${key}: ${requirement}`) + expect(requirement).toBeDefined() + }) + + // A successful solution should pass all these requirements + }) + }) +}) diff --git a/web/__tests__/unified-tags-logic.test.ts b/web/__tests__/unified-tags-logic.test.ts new file mode 100644 index 0000000000..c920e28e0a --- /dev/null +++ b/web/__tests__/unified-tags-logic.test.ts @@ -0,0 +1,396 @@ +/** + * Unified Tags Editing - Pure Logic Tests + * + * This test file validates the core business logic and state management + * behaviors introduced in the recent 7 commits without requiring complex mocks. + */ + +describe('Unified Tags Editing - Pure Logic Tests', () => { + describe('Tag State Management Logic', () => { + it('should detect when tag values have changed', () => { + const currentValue = ['tag1', 'tag2'] + const newSelectedTagIDs = ['tag1', 'tag3'] + + // This is the valueNotChanged logic from TagSelector component + const valueNotChanged + = currentValue.length === newSelectedTagIDs.length + && currentValue.every(v => newSelectedTagIDs.includes(v)) + && newSelectedTagIDs.every(v => currentValue.includes(v)) + + expect(valueNotChanged).toBe(false) + }) + + it('should correctly identify unchanged tag values', () => { + const currentValue = ['tag1', 'tag2'] + const newSelectedTagIDs = ['tag2', 'tag1'] // Same tags, different order + + const valueNotChanged + = currentValue.length === newSelectedTagIDs.length + && currentValue.every(v => newSelectedTagIDs.includes(v)) + && newSelectedTagIDs.every(v => currentValue.includes(v)) + + expect(valueNotChanged).toBe(true) + }) + + it('should calculate correct tag operations for binding/unbinding', () => { + const currentValue = ['tag1', 'tag2'] + const selectedTagIDs = ['tag2', 'tag3'] + + // This is the handleValueChange logic from TagSelector + const addTagIDs = selectedTagIDs.filter(v => !currentValue.includes(v)) + const removeTagIDs = currentValue.filter(v => !selectedTagIDs.includes(v)) + + expect(addTagIDs).toEqual(['tag3']) + expect(removeTagIDs).toEqual(['tag1']) + }) + + it('should handle empty tag arrays correctly', () => { + const currentValue: string[] = [] + const selectedTagIDs = ['tag1'] + + const addTagIDs = selectedTagIDs.filter(v => !currentValue.includes(v)) + const removeTagIDs = currentValue.filter(v => !selectedTagIDs.includes(v)) + + expect(addTagIDs).toEqual(['tag1']) + expect(removeTagIDs).toEqual([]) + expect(currentValue.length).toBe(0) // Verify empty array usage + }) + + it('should handle removing all tags', () => { + const currentValue = ['tag1', 'tag2'] + const selectedTagIDs: string[] = [] + + const addTagIDs = selectedTagIDs.filter(v => !currentValue.includes(v)) + const removeTagIDs = currentValue.filter(v => !selectedTagIDs.includes(v)) + + expect(addTagIDs).toEqual([]) + expect(removeTagIDs).toEqual(['tag1', 'tag2']) + expect(selectedTagIDs.length).toBe(0) // Verify empty array usage + }) + }) + + describe('Fallback Logic (from layout-main.tsx)', () => { + it('should trigger fallback when tags are missing or empty', () => { + const appDetailWithoutTags = { tags: [] } + const appDetailWithTags = { tags: [{ id: 'tag1' }] } + const appDetailWithUndefinedTags = { tags: undefined as any } + + // This simulates the condition in layout-main.tsx + const shouldFallback1 = !appDetailWithoutTags.tags || appDetailWithoutTags.tags.length === 0 + const shouldFallback2 = !appDetailWithTags.tags || appDetailWithTags.tags.length === 0 + const shouldFallback3 = !appDetailWithUndefinedTags.tags || appDetailWithUndefinedTags.tags.length === 0 + + expect(shouldFallback1).toBe(true) // Empty array should trigger fallback + expect(shouldFallback2).toBe(false) // Has tags, no fallback needed + expect(shouldFallback3).toBe(true) // Undefined tags should trigger fallback + }) + + it('should preserve tags when fallback succeeds', () => { + const originalAppDetail = { tags: [] as any[] } + const fallbackResult = { tags: [{ id: 'tag1', name: 'fallback-tag' }] } + + // This simulates the successful fallback in layout-main.tsx + if (fallbackResult?.tags) + originalAppDetail.tags = fallbackResult.tags + + expect(originalAppDetail.tags).toEqual(fallbackResult.tags) + expect(originalAppDetail.tags.length).toBe(1) + }) + + it('should continue with empty tags when fallback fails', () => { + const originalAppDetail: { tags: any[] } = { tags: [] } + const fallbackResult: { tags?: any[] } | null = null + + // This simulates fallback failure in layout-main.tsx + if (fallbackResult?.tags) + originalAppDetail.tags = fallbackResult.tags + + expect(originalAppDetail.tags).toEqual([]) + }) + }) + + describe('TagSelector Auto-initialization Logic', () => { + it('should trigger getTagList when tagList is empty', () => { + const tagList: any[] = [] + let getTagListCalled = false + const getTagList = () => { + getTagListCalled = true + } + + // This simulates the useEffect in TagSelector + if (tagList.length === 0) + getTagList() + + expect(getTagListCalled).toBe(true) + }) + + it('should not trigger getTagList when tagList has items', () => { + const tagList = [{ id: 'tag1', name: 'existing-tag' }] + let getTagListCalled = false + const getTagList = () => { + getTagListCalled = true + } + + // This simulates the useEffect in TagSelector + if (tagList.length === 0) + getTagList() + + expect(getTagListCalled).toBe(false) + }) + }) + + describe('State Initialization Patterns', () => { + it('should maintain AppCard tag state pattern', () => { + const app = { tags: [{ id: 'tag1', name: 'test' }] } + + // Original AppCard pattern: useState(app.tags) + const initialTags = app.tags + expect(Array.isArray(initialTags)).toBe(true) + expect(initialTags.length).toBe(1) + expect(initialTags).toBe(app.tags) // Reference equality for AppCard + }) + + it('should maintain AppInfo tag state pattern', () => { + const appDetail = { tags: [{ id: 'tag1', name: 'test' }] } + + // New AppInfo pattern: useState(appDetail?.tags || []) + const initialTags = appDetail?.tags || [] + expect(Array.isArray(initialTags)).toBe(true) + expect(initialTags.length).toBe(1) + }) + + it('should handle undefined appDetail gracefully in AppInfo', () => { + const appDetail = undefined + + // AppInfo pattern with undefined appDetail + const initialTags = (appDetail as any)?.tags || [] + expect(Array.isArray(initialTags)).toBe(true) + expect(initialTags.length).toBe(0) + }) + }) + + describe('CSS Class and Layout Logic', () => { + it('should apply correct minimum width condition', () => { + const minWidth = 'true' + + // This tests the minWidth logic in TagSelector + const shouldApplyMinWidth = minWidth && '!min-w-80' + expect(shouldApplyMinWidth).toBe('!min-w-80') + }) + + it('should not apply minimum width when not specified', () => { + const minWidth = undefined + + const shouldApplyMinWidth = minWidth && '!min-w-80' + expect(shouldApplyMinWidth).toBeFalsy() + }) + + it('should handle overflow layout classes correctly', () => { + // This tests the layout pattern from AppCard and new AppInfo + const overflowLayoutClasses = { + container: 'flex w-0 grow items-center', + inner: 'w-full', + truncate: 'truncate', + } + + expect(overflowLayoutClasses.container).toContain('w-0 grow') + expect(overflowLayoutClasses.inner).toContain('w-full') + expect(overflowLayoutClasses.truncate).toBe('truncate') + }) + }) + + describe('fetchAppWithTags Service Logic', () => { + it('should correctly find app by ID from app list', () => { + const appList = [ + { id: 'app1', name: 'App 1', tags: [] }, + { id: 'test-app-id', name: 'Test App', tags: [{ id: 'tag1', name: 'test' }] }, + { id: 'app3', name: 'App 3', tags: [] }, + ] + const targetAppId = 'test-app-id' + + // This simulates the logic in fetchAppWithTags + const foundApp = appList.find(app => app.id === targetAppId) + + expect(foundApp).toBeDefined() + expect(foundApp?.id).toBe('test-app-id') + expect(foundApp?.tags.length).toBe(1) + }) + + it('should return null when app not found', () => { + const appList = [ + { id: 'app1', name: 'App 1' }, + { id: 'app2', name: 'App 2' }, + ] + const targetAppId = 'nonexistent-app' + + const foundApp = appList.find(app => app.id === targetAppId) || null + + expect(foundApp).toBeNull() + }) + + it('should handle empty app list', () => { + const appList: any[] = [] + const targetAppId = 'any-app' + + const foundApp = appList.find(app => app.id === targetAppId) || null + + expect(foundApp).toBeNull() + expect(appList.length).toBe(0) // Verify empty array usage + }) + }) + + describe('Data Structure Validation', () => { + it('should maintain consistent tag data structure', () => { + const tag = { + id: 'tag1', + name: 'test-tag', + type: 'app', + binding_count: 1, + } + + expect(tag).toHaveProperty('id') + expect(tag).toHaveProperty('name') + expect(tag).toHaveProperty('type') + expect(tag).toHaveProperty('binding_count') + expect(tag.type).toBe('app') + expect(typeof tag.binding_count).toBe('number') + }) + + it('should handle tag arrays correctly', () => { + const tags = [ + { id: 'tag1', name: 'Tag 1', type: 'app', binding_count: 1 }, + { id: 'tag2', name: 'Tag 2', type: 'app', binding_count: 0 }, + ] + + expect(Array.isArray(tags)).toBe(true) + expect(tags.length).toBe(2) + expect(tags.every(tag => tag.type === 'app')).toBe(true) + }) + + it('should validate app data structure with tags', () => { + const app = { + id: 'test-app', + name: 'Test App', + tags: [ + { id: 'tag1', name: 'Tag 1', type: 'app', binding_count: 1 }, + ], + } + + expect(app).toHaveProperty('id') + expect(app).toHaveProperty('name') + expect(app).toHaveProperty('tags') + expect(Array.isArray(app.tags)).toBe(true) + expect(app.tags.length).toBe(1) + }) + }) + + describe('Performance and Edge Cases', () => { + it('should handle large tag arrays efficiently', () => { + const largeTags = Array.from({ length: 100 }, (_, i) => `tag${i}`) + const selectedTags = ['tag1', 'tag50', 'tag99'] + + // Performance test: filtering should be efficient + const startTime = Date.now() + const addTags = selectedTags.filter(tag => !largeTags.includes(tag)) + const removeTags = largeTags.filter(tag => !selectedTags.includes(tag)) + const endTime = Date.now() + + expect(endTime - startTime).toBeLessThan(10) // Should be very fast + expect(addTags.length).toBe(0) // All selected tags exist + expect(removeTags.length).toBe(97) // 100 - 3 = 97 tags to remove + }) + + it('should handle malformed tag data gracefully', () => { + const mixedData = [ + { id: 'valid1', name: 'Valid Tag', type: 'app', binding_count: 1 }, + { id: 'invalid1' }, // Missing required properties + null, + undefined, + { id: 'valid2', name: 'Another Valid', type: 'app', binding_count: 0 }, + ] + + // Filter out invalid entries + const validTags = mixedData.filter((tag): tag is { id: string; name: string; type: string; binding_count: number } => + tag != null + && typeof tag === 'object' + && 'id' in tag + && 'name' in tag + && 'type' in tag + && 'binding_count' in tag + && typeof tag.binding_count === 'number', + ) + + expect(validTags.length).toBe(2) + expect(validTags.every(tag => tag.id && tag.name)).toBe(true) + }) + + it('should handle concurrent tag operations correctly', () => { + const operations = [ + { type: 'add', tagIds: ['tag1', 'tag2'] }, + { type: 'remove', tagIds: ['tag3'] }, + { type: 'add', tagIds: ['tag4'] }, + ] + + // Simulate processing operations + const results = operations.map(op => ({ + ...op, + processed: true, + timestamp: Date.now(), + })) + + expect(results.length).toBe(3) + expect(results.every(result => result.processed)).toBe(true) + }) + }) + + describe('Backward Compatibility Verification', () => { + it('should not break existing AppCard behavior', () => { + // Verify AppCard continues to work with original patterns + const originalAppCardLogic = { + initializeTags: (app: any) => app.tags, + updateTags: (_currentTags: any[], newTags: any[]) => newTags, + shouldRefresh: true, + } + + const app = { tags: [{ id: 'tag1', name: 'original' }] } + const initializedTags = originalAppCardLogic.initializeTags(app) + + expect(initializedTags).toBe(app.tags) + expect(originalAppCardLogic.shouldRefresh).toBe(true) + }) + + it('should ensure AppInfo follows AppCard patterns', () => { + // Verify AppInfo uses compatible state management + const appCardPattern = (app: any) => app.tags + const appInfoPattern = (appDetail: any) => appDetail?.tags || [] + + const appWithTags = { tags: [{ id: 'tag1' }] } + const appWithoutTags = { tags: [] } + const undefinedApp = undefined + + expect(appCardPattern(appWithTags)).toEqual(appInfoPattern(appWithTags)) + expect(appInfoPattern(appWithoutTags)).toEqual([]) + expect(appInfoPattern(undefinedApp)).toEqual([]) + }) + + it('should maintain consistent API parameters', () => { + // Verify service layer maintains expected parameters + const fetchAppListParams = { + url: '/apps', + params: { page: 1, limit: 100 }, + } + + const tagApiParams = { + bindTag: (tagIDs: string[], targetID: string, type: string) => ({ tagIDs, targetID, type }), + unBindTag: (tagID: string, targetID: string, type: string) => ({ tagID, targetID, type }), + } + + expect(fetchAppListParams.url).toBe('/apps') + expect(fetchAppListParams.params.limit).toBe(100) + + const bindResult = tagApiParams.bindTag(['tag1'], 'app1', 'app') + expect(bindResult.tagIDs).toEqual(['tag1']) + expect(bindResult.type).toBe('app') + }) + }) +}) diff --git a/web/__tests__/workflow-parallel-limit.test.tsx b/web/__tests__/workflow-parallel-limit.test.tsx new file mode 100644 index 0000000000..0843122ab4 --- /dev/null +++ b/web/__tests__/workflow-parallel-limit.test.tsx @@ -0,0 +1,301 @@ +/** + * MAX_PARALLEL_LIMIT Configuration Bug Test + * + * This test reproduces and verifies the fix for issue #23083: + * MAX_PARALLEL_LIMIT environment variable does not take effect in iteration panel + */ + +import { render, screen } from '@testing-library/react' +import React from 'react' + +// Mock environment variables before importing constants +const originalEnv = process.env.NEXT_PUBLIC_MAX_PARALLEL_LIMIT + +// Test with different environment values +function setupEnvironment(value?: string) { + if (value) + process.env.NEXT_PUBLIC_MAX_PARALLEL_LIMIT = value + else + delete process.env.NEXT_PUBLIC_MAX_PARALLEL_LIMIT + + // Clear module cache to force re-evaluation + jest.resetModules() +} + +function restoreEnvironment() { + if (originalEnv) + process.env.NEXT_PUBLIC_MAX_PARALLEL_LIMIT = originalEnv + else + delete process.env.NEXT_PUBLIC_MAX_PARALLEL_LIMIT + + jest.resetModules() +} + +// Mock i18next with proper implementation +jest.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => { + if (key.includes('MaxParallelismTitle')) return 'Max Parallelism' + if (key.includes('MaxParallelismDesc')) return 'Maximum number of parallel executions' + if (key.includes('parallelMode')) return 'Parallel Mode' + if (key.includes('parallelPanelDesc')) return 'Enable parallel execution' + if (key.includes('errorResponseMethod')) return 'Error Response Method' + return key + }, + }), + initReactI18next: { + type: '3rdParty', + init: jest.fn(), + }, +})) + +// Mock i18next module completely to prevent initialization issues +jest.mock('i18next', () => ({ + use: jest.fn().mockReturnThis(), + init: jest.fn().mockReturnThis(), + t: jest.fn(key => key), + isInitialized: true, +})) + +// Mock the useConfig hook +jest.mock('@/app/components/workflow/nodes/iteration/use-config', () => ({ + __esModule: true, + default: () => ({ + inputs: { + is_parallel: true, + parallel_nums: 5, + error_handle_mode: 'terminated', + }, + changeParallel: jest.fn(), + changeParallelNums: jest.fn(), + changeErrorHandleMode: jest.fn(), + }), +})) + +// Mock other components +jest.mock('@/app/components/workflow/nodes/_base/components/variable/var-reference-picker', () => { + return function MockVarReferencePicker() { + return
VarReferencePicker
+ } +}) + +jest.mock('@/app/components/workflow/nodes/_base/components/split', () => { + return function MockSplit() { + return
Split
+ } +}) + +jest.mock('@/app/components/workflow/nodes/_base/components/field', () => { + return function MockField({ title, children }: { title: string, children: React.ReactNode }) { + return ( +
+ + {children} +
+ ) + } +}) + +jest.mock('@/app/components/base/switch', () => { + return function MockSwitch({ defaultValue }: { defaultValue: boolean }) { + return + } +}) + +jest.mock('@/app/components/base/select', () => { + return function MockSelect() { + return + } +}) + +// Use defaultValue to avoid controlled input warnings +jest.mock('@/app/components/base/slider', () => { + return function MockSlider({ value, max, min }: { value: number, max: number, min: number }) { + return ( + + ) + } +}) + +// Use defaultValue to avoid controlled input warnings +jest.mock('@/app/components/base/input', () => { + return function MockInput({ type, max, min, value }: { type: string, max: number, min: number, value: number }) { + return ( + + ) + } +}) + +describe('MAX_PARALLEL_LIMIT Configuration Bug', () => { + const mockNodeData = { + id: 'test-iteration-node', + type: 'iteration' as const, + data: { + title: 'Test Iteration', + desc: 'Test iteration node', + iterator_selector: ['test'], + output_selector: ['output'], + is_parallel: true, + parallel_nums: 5, + error_handle_mode: 'terminated' as const, + }, + } + + beforeEach(() => { + jest.clearAllMocks() + }) + + afterEach(() => { + restoreEnvironment() + }) + + afterAll(() => { + restoreEnvironment() + }) + + describe('Environment Variable Parsing', () => { + it('should parse MAX_PARALLEL_LIMIT from NEXT_PUBLIC_MAX_PARALLEL_LIMIT environment variable', () => { + setupEnvironment('25') + const { MAX_PARALLEL_LIMIT } = require('@/config') + expect(MAX_PARALLEL_LIMIT).toBe(25) + }) + + it('should fallback to default when environment variable is not set', () => { + setupEnvironment() // No environment variable + const { MAX_PARALLEL_LIMIT } = require('@/config') + expect(MAX_PARALLEL_LIMIT).toBe(10) + }) + + it('should handle invalid environment variable values', () => { + setupEnvironment('invalid') + const { MAX_PARALLEL_LIMIT } = require('@/config') + + // Should fall back to default when parsing fails + expect(MAX_PARALLEL_LIMIT).toBe(10) + }) + + it('should handle empty environment variable', () => { + setupEnvironment('') + const { MAX_PARALLEL_LIMIT } = require('@/config') + + // Should fall back to default when empty + expect(MAX_PARALLEL_LIMIT).toBe(10) + }) + + // Edge cases for boundary values + it('should clamp MAX_PARALLEL_LIMIT to MIN when env is 0 or negative', () => { + setupEnvironment('0') + let { MAX_PARALLEL_LIMIT } = require('@/config') + expect(MAX_PARALLEL_LIMIT).toBe(10) // Falls back to default + + setupEnvironment('-5') + ;({ MAX_PARALLEL_LIMIT } = require('@/config')) + expect(MAX_PARALLEL_LIMIT).toBe(10) // Falls back to default + }) + + it('should handle float numbers by parseInt behavior', () => { + setupEnvironment('12.7') + const { MAX_PARALLEL_LIMIT } = require('@/config') + // parseInt truncates to integer + expect(MAX_PARALLEL_LIMIT).toBe(12) + }) + }) + + describe('UI Component Integration (Main Fix Verification)', () => { + it('should render iteration panel with environment-configured max value', () => { + // Set environment variable to a different value + setupEnvironment('30') + + // Import Panel after setting environment + const Panel = require('@/app/components/workflow/nodes/iteration/panel').default + const { MAX_PARALLEL_LIMIT } = require('@/config') + + render( + , + ) + + // Behavior-focused assertion: UI max should equal MAX_PARALLEL_LIMIT + const numberInput = screen.getByTestId('number-input') + expect(numberInput).toHaveAttribute('data-max', String(MAX_PARALLEL_LIMIT)) + + const slider = screen.getByTestId('slider') + expect(slider).toHaveAttribute('data-max', String(MAX_PARALLEL_LIMIT)) + + // Verify the actual values + expect(MAX_PARALLEL_LIMIT).toBe(30) + expect(numberInput.getAttribute('data-max')).toBe('30') + expect(slider.getAttribute('data-max')).toBe('30') + }) + + it('should maintain UI consistency with different environment values', () => { + setupEnvironment('15') + const Panel = require('@/app/components/workflow/nodes/iteration/panel').default + const { MAX_PARALLEL_LIMIT } = require('@/config') + + render( + , + ) + + // Both input and slider should use the same max value from MAX_PARALLEL_LIMIT + const numberInput = screen.getByTestId('number-input') + const slider = screen.getByTestId('slider') + + expect(numberInput.getAttribute('data-max')).toBe(slider.getAttribute('data-max')) + expect(numberInput.getAttribute('data-max')).toBe(String(MAX_PARALLEL_LIMIT)) + }) + }) + + describe('Legacy Constant Verification (For Transition Period)', () => { + // Marked as transition/deprecation tests + it('should maintain MAX_ITERATION_PARALLEL_NUM for backward compatibility', () => { + const { MAX_ITERATION_PARALLEL_NUM } = require('@/app/components/workflow/constants') + expect(typeof MAX_ITERATION_PARALLEL_NUM).toBe('number') + expect(MAX_ITERATION_PARALLEL_NUM).toBe(10) // Hardcoded legacy value + }) + + it('should demonstrate MAX_PARALLEL_LIMIT vs legacy constant difference', () => { + setupEnvironment('50') + const { MAX_PARALLEL_LIMIT } = require('@/config') + const { MAX_ITERATION_PARALLEL_NUM } = require('@/app/components/workflow/constants') + + // MAX_PARALLEL_LIMIT is configurable, MAX_ITERATION_PARALLEL_NUM is not + expect(MAX_PARALLEL_LIMIT).toBe(50) + expect(MAX_ITERATION_PARALLEL_NUM).toBe(10) + expect(MAX_PARALLEL_LIMIT).not.toBe(MAX_ITERATION_PARALLEL_NUM) + }) + }) + + describe('Constants Validation', () => { + it('should validate that required constants exist and have correct types', () => { + const { MAX_PARALLEL_LIMIT } = require('@/config') + const { MIN_ITERATION_PARALLEL_NUM } = require('@/app/components/workflow/constants') + expect(typeof MAX_PARALLEL_LIMIT).toBe('number') + expect(typeof MIN_ITERATION_PARALLEL_NUM).toBe('number') + expect(MAX_PARALLEL_LIMIT).toBeGreaterThanOrEqual(MIN_ITERATION_PARALLEL_NUM) + }) + }) +}) diff --git a/web/__tests__/xss-fix-verification.test.tsx b/web/__tests__/xss-fix-verification.test.tsx new file mode 100644 index 0000000000..2fa5ab3c05 --- /dev/null +++ b/web/__tests__/xss-fix-verification.test.tsx @@ -0,0 +1,212 @@ +/** + * XSS Fix Verification Test + * + * This test verifies that the XSS vulnerability in check-code pages has been + * properly fixed by replacing dangerouslySetInnerHTML with safe React rendering. + */ + +import React from 'react' +import { cleanup, render } from '@testing-library/react' +import '@testing-library/jest-dom' + +// Mock i18next with the new safe translation structure +jest.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => { + if (key === 'login.checkCode.tipsPrefix') + return 'We send a verification code to ' + + return key + }, + }), +})) + +// Mock Next.js useSearchParams +jest.mock('next/navigation', () => ({ + useSearchParams: () => ({ + get: (key: string) => { + if (key === 'email') + return 'test@example.com' + return null + }, + }), +})) + +// Fixed CheckCode component implementation (current secure version) +const SecureCheckCodeComponent = ({ email }: { email: string }) => { + const { t } = require('react-i18next').useTranslation() + + return ( +
+

Check Code

+

+ + {t('login.checkCode.tipsPrefix')} + {email} + +

+
+ ) +} + +// Vulnerable implementation for comparison (what we fixed) +const VulnerableCheckCodeComponent = ({ email }: { email: string }) => { + const mockTranslation = (key: string, params?: any) => { + if (key === 'login.checkCode.tips' && params?.email) + return `We send a verification code to ${params.email}` + + return key + } + + return ( +
+

Check Code

+

+ +

+
+ ) +} + +describe('XSS Fix Verification - Check Code Pages Security', () => { + afterEach(() => { + cleanup() + }) + + const maliciousEmail = 'test@example.com' + + it('should securely render email with HTML characters as text (FIXED VERSION)', () => { + console.log('\n🔒 Security Fix Verification Report') + console.log('===================================') + + const { container } = render() + + const spanElement = container.querySelector('span') + const strongElement = container.querySelector('strong') + const scriptElements = container.querySelectorAll('script') + + console.log('\n✅ Fixed Implementation Results:') + console.log('- Email rendered in strong tag:', strongElement?.textContent) + console.log('- HTML tags visible as text:', strongElement?.textContent?.includes('', + 'normal@email.com', + ] + + testCases.forEach((testEmail, index) => { + const { container } = render() + + const strongElement = container.querySelector('strong') + const scriptElements = container.querySelectorAll('script') + const imgElements = container.querySelectorAll('img') + const divElements = container.querySelectorAll('div:not([data-testid])') + + console.log(`\n📧 Test Case ${index + 1}: ${testEmail.substring(0, 20)}...`) + console.log(` - Script elements: ${scriptElements.length}`) + console.log(` - Img elements: ${imgElements.length}`) + console.log(` - Malicious divs: ${divElements.length - 1}`) // -1 for container div + console.log(` - Text content: ${strongElement?.textContent === testEmail ? 'SAFE' : 'ISSUE'}`) + + // All should be safe + expect(scriptElements).toHaveLength(0) + expect(imgElements).toHaveLength(0) + expect(strongElement?.textContent).toBe(testEmail) + }) + + console.log('\n✅ All test cases passed - secure rendering confirmed') + }) + + it('should validate the translation structure is secure', () => { + console.log('\n🔍 Translation Security Analysis') + console.log('=================================') + + const { t } = require('react-i18next').useTranslation() + const prefix = t('login.checkCode.tipsPrefix') + + console.log('- Translation key used: login.checkCode.tipsPrefix') + console.log('- Translation value:', prefix) + console.log('- Contains HTML tags:', prefix.includes('<')) + console.log('- Pure text content:', !prefix.includes('<') && !prefix.includes('>')) + + // Verify translation is plain text + expect(prefix).toBe('We send a verification code to ') + expect(prefix).not.toContain('<') + expect(prefix).not.toContain('>') + expect(typeof prefix).toBe('string') + + console.log('\n✅ Translation structure is secure - no HTML content') + }) + + it('should confirm React automatic escaping works correctly', () => { + console.log('\n⚡ React Security Mechanism Test') + console.log('=================================') + + // Test React's automatic escaping with various inputs + const dangerousInputs = [ + '', + '', + '">', + '\'>alert(3)', + '
click
', + ] + + dangerousInputs.forEach((input, index) => { + const TestComponent = () => {input} + const { container } = render() + + const strongElement = container.querySelector('strong') + const scriptElements = container.querySelectorAll('script') + + console.log(`\n🧪 Input ${index + 1}: ${input.substring(0, 30)}...`) + console.log(` - Rendered as text: ${strongElement?.textContent === input}`) + console.log(` - No script execution: ${scriptElements.length === 0}`) + + expect(strongElement?.textContent).toBe(input) + expect(scriptElements).toHaveLength(0) + }) + + console.log('\n🛡️ React automatic escaping is working perfectly') + }) +}) + +export {} diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/develop/page.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/develop/page.tsx index 415d82285c..11335b270c 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/develop/page.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/develop/page.tsx @@ -1,5 +1,5 @@ import React from 'react' -import type { Locale } from '@/i18n' +import type { Locale } from '@/i18n-config' import DevelopMain from '@/app/components/develop' export type IDevelopProps = { diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout-main.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout-main.tsx index 6b3807f1c6..6d337e3c47 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout-main.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout-main.tsx @@ -20,12 +20,18 @@ import cn from '@/utils/classnames' import { useStore } from '@/app/components/app/store' import AppSideBar from '@/app/components/app-sidebar' import type { NavIcon } from '@/app/components/app-sidebar/navLink' -import { fetchAppDetail } from '@/service/apps' +import { fetchAppDetailDirect } from '@/service/apps' import { useAppContext } from '@/context/app-context' import Loading from '@/app/components/base/loading' import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' import type { App } from '@/types/app' import useDocumentTitle from '@/hooks/use-document-title' +import { useStore as useTagStore } from '@/app/components/base/tag-management/store' +import dynamic from 'next/dynamic' + +const TagManagementModal = dynamic(() => import('@/app/components/base/tag-management'), { + ssr: false, +}) export type IAppDetailLayoutProps = { children: React.ReactNode @@ -48,6 +54,7 @@ const AppDetailLayout: FC = (props) => { setAppDetail: state.setAppDetail, setAppSiderbarExpand: state.setAppSiderbarExpand, }))) + const showTagManagementModal = useTagStore(s => s.showTagManagementModal) const [isLoadingAppDetail, setIsLoadingAppDetail] = useState(false) const [appDetailRes, setAppDetailRes] = useState(null) const [navigation, setNavigation] = useState = (props) => { useEffect(() => { setAppDetail() setIsLoadingAppDetail(true) - fetchAppDetail({ url: '/apps', id: appId }).then((res) => { + fetchAppDetailDirect({ url: '/apps', id: appId }).then((res: App) => { setAppDetailRes(res) }).catch((e: any) => { if (e.status === 404) @@ -163,6 +170,9 @@ const AppDetailLayout: FC = (props) => {
{children}
+ {showTagManagementModal && ( + + )}
) } diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/cardView.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/card-view.tsx similarity index 98% rename from web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/cardView.tsx rename to web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/card-view.tsx index 3d572b926a..e58e79918f 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/cardView.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/card-view.tsx @@ -3,7 +3,7 @@ import type { FC } from 'react' import React from 'react' import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' -import AppCard from '@/app/components/app/overview/appCard' +import AppCard from '@/app/components/app/overview/app-card' import Loading from '@/app/components/base/loading' import MCPServiceCard from '@/app/components/tools/mcp/mcp-service-card' import { ToastContext } from '@/app/components/base/toast' @@ -17,7 +17,7 @@ import type { App } from '@/types/app' import type { UpdateAppSiteCodeResponse } from '@/models/app' import { asyncRunSafe } from '@/utils' import { NEED_REFRESH_APP_LIST_KEY } from '@/config' -import type { IAppCardProps } from '@/app/components/app/overview/appCard' +import type { IAppCardProps } from '@/app/components/app/overview/app-card' import { useStore as useAppStore } from '@/app/components/app/store' export type ICardViewProps = { diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/chartView.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/chart-view.tsx similarity index 98% rename from web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/chartView.tsx rename to web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/chart-view.tsx index 646c8bd93d..847de19165 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/chartView.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/chart-view.tsx @@ -3,8 +3,8 @@ import React, { useState } from 'react' import dayjs from 'dayjs' import quarterOfYear from 'dayjs/plugin/quarterOfYear' import { useTranslation } from 'react-i18next' -import type { PeriodParams } from '@/app/components/app/overview/appChart' -import { AvgResponseTime, AvgSessionInteractions, AvgUserInteractions, ConversationsChart, CostChart, EndUsersChart, MessagesChart, TokenPerSecond, UserSatisfactionRate, WorkflowCostChart, WorkflowDailyTerminalsChart, WorkflowMessagesChart } from '@/app/components/app/overview/appChart' +import type { PeriodParams } from '@/app/components/app/overview/app-chart' +import { AvgResponseTime, AvgSessionInteractions, AvgUserInteractions, ConversationsChart, CostChart, EndUsersChart, MessagesChart, TokenPerSecond, UserSatisfactionRate, WorkflowCostChart, WorkflowDailyTerminalsChart, WorkflowMessagesChart } from '@/app/components/app/overview/app-chart' import type { Item } from '@/app/components/base/select' import { SimpleSelect } from '@/app/components/base/select' import { TIME_PERIOD_MAPPING } from '@/app/components/app/log/filter' @@ -54,6 +54,7 @@ export default function ChartView({ appId, headerRight }: IChartViewProps) { ({ value: k, name: t(`appLog.filter.period.${v.name}`) }))} className='mt-0 !w-40' + notClearable={true} onSelect={(item) => { const id = item.value const value = TIME_PERIOD_MAPPING[id]?.value ?? '-1' diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/page.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/page.tsx index e0c09e739e..bc07a799e4 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/page.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/page.tsx @@ -1,5 +1,5 @@ import React from 'react' -import ChartView from './chartView' +import ChartView from './chart-view' import TracingPanel from './tracing/panel' import ApikeyInfoPanel from '@/app/components/app/overview/apikey-info-panel' diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/__tests__/svg-attribute-error-reproduction.spec.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/__tests__/svg-attribute-error-reproduction.spec.tsx new file mode 100644 index 0000000000..a3281be8eb --- /dev/null +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/__tests__/svg-attribute-error-reproduction.spec.tsx @@ -0,0 +1,156 @@ +import React from 'react' +import { render } from '@testing-library/react' +import '@testing-library/jest-dom' +import { OpikIconBig } from '@/app/components/base/icons/src/public/tracing' + +// Mock dependencies to isolate the SVG rendering issue +jest.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => key, + }), +})) + +describe('SVG Attribute Error Reproduction', () => { + // Capture console errors + const originalError = console.error + let errorMessages: string[] = [] + + beforeEach(() => { + errorMessages = [] + console.error = jest.fn((message) => { + errorMessages.push(message) + originalError(message) + }) + }) + + afterEach(() => { + console.error = originalError + }) + + it('should reproduce inkscape attribute errors when rendering OpikIconBig', () => { + console.log('\n=== TESTING OpikIconBig SVG ATTRIBUTE ERRORS ===') + + // Test multiple renders to check for inconsistency + for (let i = 0; i < 5; i++) { + console.log(`\nRender attempt ${i + 1}:`) + + const { unmount } = render() + + // Check for specific inkscape attribute errors + const inkscapeErrors = errorMessages.filter(msg => + typeof msg === 'string' && msg.includes('inkscape'), + ) + + if (inkscapeErrors.length > 0) { + console.log(`Found ${inkscapeErrors.length} inkscape errors:`) + inkscapeErrors.forEach((error, index) => { + console.log(` ${index + 1}. ${error.substring(0, 100)}...`) + }) + } + else { + console.log('No inkscape errors found in this render') + } + + unmount() + + // Clear errors for next iteration + errorMessages = [] + } + }) + + it('should analyze the SVG structure causing the errors', () => { + console.log('\n=== ANALYZING SVG STRUCTURE ===') + + // Import the JSON data directly + const iconData = require('@/app/components/base/icons/src/public/tracing/OpikIconBig.json') + + console.log('Icon structure analysis:') + console.log('- Root element:', iconData.icon.name) + console.log('- Children count:', iconData.icon.children?.length || 0) + + // Find problematic elements + const findProblematicElements = (node: any, path = '') => { + const problematicElements: any[] = [] + + if (node.name && (node.name.includes(':') || node.name.startsWith('sodipodi'))) { + problematicElements.push({ + path, + name: node.name, + attributes: Object.keys(node.attributes || {}), + }) + } + + // Check attributes for inkscape/sodipodi properties + if (node.attributes) { + const problematicAttrs = Object.keys(node.attributes).filter(attr => + attr.startsWith('inkscape:') || attr.startsWith('sodipodi:'), + ) + + if (problematicAttrs.length > 0) { + problematicElements.push({ + path, + name: node.name, + problematicAttributes: problematicAttrs, + }) + } + } + + if (node.children) { + node.children.forEach((child: any, index: number) => { + problematicElements.push( + ...findProblematicElements(child, `${path}/${node.name}[${index}]`), + ) + }) + } + + return problematicElements + } + + const problematicElements = findProblematicElements(iconData.icon, 'root') + + console.log(`\n🚨 Found ${problematicElements.length} problematic elements:`) + problematicElements.forEach((element, index) => { + console.log(`\n${index + 1}. Element: ${element.name}`) + console.log(` Path: ${element.path}`) + if (element.problematicAttributes) + console.log(` Problematic attributes: ${element.problematicAttributes.join(', ')}`) + }) + }) + + it('should test the normalizeAttrs function behavior', () => { + console.log('\n=== TESTING normalizeAttrs FUNCTION ===') + + const { normalizeAttrs } = require('@/app/components/base/icons/utils') + + const testAttributes = { + 'inkscape:showpageshadow': '2', + 'inkscape:pageopacity': '0.0', + 'inkscape:pagecheckerboard': '0', + 'inkscape:deskcolor': '#d1d1d1', + 'sodipodi:docname': 'opik-icon-big.svg', + 'xmlns:inkscape': 'https://www.inkscape.org/namespaces/inkscape', + 'xmlns:sodipodi': 'https://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd', + 'xmlns:svg': 'https://www.w3.org/2000/svg', + 'data-name': 'Layer 1', + 'normal-attr': 'value', + 'class': 'test-class', + } + + console.log('Input attributes:', Object.keys(testAttributes)) + + const normalized = normalizeAttrs(testAttributes) + + console.log('Normalized attributes:', Object.keys(normalized)) + console.log('Normalized values:', normalized) + + // Check if problematic attributes are still present + const problematicKeys = Object.keys(normalized).filter(key => + key.toLowerCase().includes('inkscape') || key.toLowerCase().includes('sodipodi'), + ) + + if (problematicKeys.length > 0) + console.log(`🚨 PROBLEM: Still found problematic attributes: ${problematicKeys.join(', ')}`) + else + console.log('✅ No problematic attributes found after normalization') + }) +}) diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-button.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-button.tsx index 3d05575127..1ab40e31bf 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-button.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-button.tsx @@ -1,12 +1,9 @@ 'use client' import type { FC } from 'react' -import React, { useCallback, useEffect, useRef, useState } from 'react' -import { - RiEqualizer2Line, -} from '@remixicon/react' +import React, { useCallback, useRef, useState } from 'react' + import type { PopupProps } from './config-popup' import ConfigPopup from './config-popup' -import cn from '@/utils/classnames' import { PortalToFollowElem, PortalToFollowElemContent, @@ -17,13 +14,13 @@ type Props = { readOnly: boolean className?: string hasConfigured: boolean - controlShowPopup?: number + children?: React.ReactNode } & PopupProps const ConfigBtn: FC = ({ className, hasConfigured, - controlShowPopup, + children, ...popupProps }) => { const [open, doSetOpen] = useState(false) @@ -37,13 +34,6 @@ const ConfigBtn: FC = ({ setOpen(!openRef.current) }, [setOpen]) - useEffect(() => { - if (controlShowPopup) - // setOpen(!openRef.current) - setOpen(true) - // eslint-disable-next-line react-hooks/exhaustive-deps - }, [controlShowPopup]) - if (popupProps.readOnly && !hasConfigured) return null @@ -52,14 +42,11 @@ const ConfigBtn: FC = ({ open={open} onOpenChange={setOpen} placement='bottom-end' - offset={{ - mainAxis: 12, - crossAxis: hasConfigured ? 8 : 49, - }} + offset={12} > -
- +
+ {children}
diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/panel.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/panel.tsx index d082523222..7564a0f3c8 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/panel.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/panel.tsx @@ -1,8 +1,9 @@ 'use client' import type { FC } from 'react' -import React, { useCallback, useEffect, useState } from 'react' +import React, { useEffect, useState } from 'react' import { RiArrowDownDoubleLine, + RiEqualizer2Line, } from '@remixicon/react' import { useTranslation } from 'react-i18next' import { usePathname } from 'next/navigation' @@ -180,10 +181,6 @@ const Panel: FC = () => { })() }, []) - const [controlShowPopup, setControlShowPopup] = useState(0) - const showPopup = useCallback(() => { - setControlShowPopup(Date.now()) - }, [setControlShowPopup]) if (!isLoaded) { return (
@@ -196,46 +193,66 @@ const Panel: FC = () => { return (
-
- {!inUseTracingProvider && ( - <> + {!inUseTracingProvider && ( + +
{t(`${I18N_PREFIX}.title`)}
-
e.stopPropagation()}> - +
+
- - )} - {hasConfiguredTracing && ( - <> +
+ + )} + {hasConfiguredTracing && ( + +
@@ -243,33 +260,14 @@ const Panel: FC = () => {
{InUseProviderIcon && } - -
e.stopPropagation()}> - +
+
- - )} -
-
+ +
+
+ )} +
) } export default React.memo(Panel) diff --git a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout-main.tsx b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout-main.tsx index 426778c835..f8189b0c8a 100644 --- a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout-main.tsx +++ b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout-main.tsx @@ -4,7 +4,6 @@ import React, { useEffect, useMemo } from 'react' import { usePathname } from 'next/navigation' import useSWR from 'swr' import { useTranslation } from 'react-i18next' -import { useBoolean } from 'ahooks' import { RiEqualizer2Fill, RiEqualizer2Line, @@ -44,45 +43,57 @@ type IExtraInfoProps = { } const ExtraInfo = ({ isMobile, relatedApps, expand }: IExtraInfoProps) => { - const [isShowTips, { toggle: toggleTips, set: setShowTips }] = useBoolean(!isMobile) const { t } = useTranslation() const docLink = useDocLink() const hasRelatedApps = relatedApps?.data && relatedApps?.data?.length > 0 const relatedAppsTotal = relatedApps?.data?.length || 0 - useEffect(() => { - setShowTips(!isMobile) - }, [isMobile, setShowTips]) - return
- {hasRelatedApps && ( - <> - {!isMobile && ( - - } - > -
- {relatedAppsTotal || '--'} {t('common.datasetMenus.relatedApp')} - -
-
- )} + {/* Related apps for desktop */} +
+ + } + > +
+ {relatedAppsTotal || '--'} {t('common.datasetMenus.relatedApp')} + +
+
+
- {isMobile &&
- {relatedAppsTotal || '--'} - -
} - - )} - {!hasRelatedApps && !expand && ( + {/* Related apps for mobile */} +
+
+ {relatedAppsTotal || '--'} + +
+
+ + {/* No related apps tooltip */} +
{
} > -
+
{t('common.datasetMenus.noRelatedApp')}
- )} +
} diff --git a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/settings/page.tsx b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/settings/page.tsx index d9a196d854..688f2c9fc2 100644 --- a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/settings/page.tsx +++ b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/settings/page.tsx @@ -1,5 +1,5 @@ import React from 'react' -import { getLocaleOnServer, useTranslation as translate } from '@/i18n/server' +import { getLocaleOnServer, useTranslation as translate } from '@/i18n-config/server' import Form from '@/app/components/datasets/settings/form' const Settings = async () => { diff --git a/web/app/(commonLayout)/datasets/Doc.tsx b/web/app/(commonLayout)/datasets/Doc.tsx deleted file mode 100644 index efdfe157f2..0000000000 --- a/web/app/(commonLayout)/datasets/Doc.tsx +++ /dev/null @@ -1,131 +0,0 @@ -'use client' - -import { useEffect, useMemo, useState } from 'react' -import { useContext } from 'use-context-selector' -import { useTranslation } from 'react-i18next' -import { RiListUnordered } from '@remixicon/react' -import TemplateEn from './template/template.en.mdx' -import TemplateZh from './template/template.zh.mdx' -import TemplateJa from './template/template.ja.mdx' -import I18n from '@/context/i18n' -import { LanguagesSupported } from '@/i18n/language' -import useTheme from '@/hooks/use-theme' -import { Theme } from '@/types/app' -import cn from '@/utils/classnames' - -type DocProps = { - apiBaseUrl: string -} - -const Doc = ({ apiBaseUrl }: DocProps) => { - const { locale } = useContext(I18n) - const { t } = useTranslation() - const [toc, setToc] = useState>([]) - const [isTocExpanded, setIsTocExpanded] = useState(false) - const { theme } = useTheme() - - // Set initial TOC expanded state based on screen width - useEffect(() => { - const mediaQuery = window.matchMedia('(min-width: 1280px)') - setIsTocExpanded(mediaQuery.matches) - }, []) - - // Extract TOC from article content - useEffect(() => { - const extractTOC = () => { - const article = document.querySelector('article') - if (article) { - const headings = article.querySelectorAll('h2') - const tocItems = Array.from(headings).map((heading) => { - const anchor = heading.querySelector('a') - if (anchor) { - return { - href: anchor.getAttribute('href') || '', - text: anchor.textContent || '', - } - } - return null - }).filter((item): item is { href: string; text: string } => item !== null) - setToc(tocItems) - } - } - - setTimeout(extractTOC, 0) - }, [locale]) - - // Handle TOC item click - const handleTocClick = (e: React.MouseEvent, item: { href: string; text: string }) => { - e.preventDefault() - const targetId = item.href.replace('#', '') - const element = document.getElementById(targetId) - if (element) { - const scrollContainer = document.querySelector('.scroll-container') - if (scrollContainer) { - const headerOffset = -40 - const elementTop = element.offsetTop - headerOffset - scrollContainer.scrollTo({ - top: elementTop, - behavior: 'smooth', - }) - } - } - } - - const Template = useMemo(() => { - switch (locale) { - case LanguagesSupported[1]: - return - case LanguagesSupported[7]: - return - default: - return - } - }, [apiBaseUrl, locale]) - - return ( -
-
- {isTocExpanded - ? ( - - ) - : ( - - )} -
-
- {Template} -
-
- ) -} - -export default Doc diff --git a/web/app/(commonLayout)/datasets/Container.tsx b/web/app/(commonLayout)/datasets/container.tsx similarity index 90% rename from web/app/(commonLayout)/datasets/Container.tsx rename to web/app/(commonLayout)/datasets/container.tsx index 112b6a752e..5328fd03aa 100644 --- a/web/app/(commonLayout)/datasets/Container.tsx +++ b/web/app/(commonLayout)/datasets/container.tsx @@ -9,10 +9,10 @@ import { useQuery } from '@tanstack/react-query' // Components import ExternalAPIPanel from '../../components/datasets/external-api/external-api-panel' -import Datasets from './Datasets' -import DatasetFooter from './DatasetFooter' +import Datasets from './datasets' +import DatasetFooter from './dataset-footer' import ApiServer from '../../components/develop/ApiServer' -import Doc from './Doc' +import Doc from './doc' import TabSliderNew from '@/app/components/base/tab-slider-new' import TagManagementModal from '@/app/components/base/tag-management' import TagFilter from '@/app/components/base/tag-management/filter' @@ -86,8 +86,8 @@ const Container = () => { }, [currentWorkspace, router]) return ( -
-
+
+
setActiveTab(newActiveTab)} diff --git a/web/app/(commonLayout)/datasets/create/page.tsx b/web/app/(commonLayout)/datasets/create/page.tsx index 663a830665..50fd1f5a19 100644 --- a/web/app/(commonLayout)/datasets/create/page.tsx +++ b/web/app/(commonLayout)/datasets/create/page.tsx @@ -1,9 +1,7 @@ import React from 'react' import DatasetUpdateForm from '@/app/components/datasets/create' -type Props = {} - -const DatasetCreation = async (props: Props) => { +const DatasetCreation = async () => { return ( ) diff --git a/web/app/(commonLayout)/datasets/DatasetCard.tsx b/web/app/(commonLayout)/datasets/dataset-card.tsx similarity index 92% rename from web/app/(commonLayout)/datasets/DatasetCard.tsx rename to web/app/(commonLayout)/datasets/dataset-card.tsx index 4b40be2c7f..3e913ca52f 100644 --- a/web/app/(commonLayout)/datasets/DatasetCard.tsx +++ b/web/app/(commonLayout)/datasets/dataset-card.tsx @@ -5,6 +5,7 @@ import { useRouter } from 'next/navigation' import { useCallback, useEffect, useState } from 'react' import { useTranslation } from 'react-i18next' import { RiMoreFill } from '@remixicon/react' +import { mutate } from 'swr' import cn from '@/utils/classnames' import Confirm from '@/app/components/base/confirm' import { ToastContext } from '@/app/components/base/toast' @@ -57,6 +58,19 @@ const DatasetCard = ({ const onConfirmDelete = useCallback(async () => { try { await deleteDataset(dataset.id) + + // Clear SWR cache to prevent stale data in knowledge retrieval nodes + mutate( + (key) => { + if (typeof key === 'string') return key.includes('/datasets') + if (typeof key === 'object' && key !== null) + return key.url === '/datasets' || key.url?.includes('/datasets') + return false + }, + undefined, + { revalidate: true }, + ) + notify({ type: 'success', message: t('dataset.datasetDeleted') }) if (onSuccess) onSuccess() @@ -162,24 +176,19 @@ const DatasetCard = ({
{dataset.description}
-
+
{ e.stopPropagation() e.preventDefault() }}>
cn( - open ? '!bg-black/5 !shadow-none' : '!bg-transparent', - 'h-8 w-8 rounded-md border-none !p-2 hover:!bg-black/5', + open ? '!bg-state-base-hover !shadow-none' : '!bg-transparent', + 'h-8 w-8 rounded-md border-none !p-2 hover:!bg-state-base-hover', ) } className={'!z-20 h-fit !w-[128px]'} diff --git a/web/app/(commonLayout)/datasets/DatasetFooter.tsx b/web/app/(commonLayout)/datasets/dataset-footer.tsx similarity index 100% rename from web/app/(commonLayout)/datasets/DatasetFooter.tsx rename to web/app/(commonLayout)/datasets/dataset-footer.tsx diff --git a/web/app/(commonLayout)/datasets/Datasets.tsx b/web/app/(commonLayout)/datasets/datasets.tsx similarity index 94% rename from web/app/(commonLayout)/datasets/Datasets.tsx rename to web/app/(commonLayout)/datasets/datasets.tsx index 2d4848e92e..4e116c6d39 100644 --- a/web/app/(commonLayout)/datasets/Datasets.tsx +++ b/web/app/(commonLayout)/datasets/datasets.tsx @@ -3,8 +3,8 @@ import { useCallback, useEffect, useRef } from 'react' import useSWRInfinite from 'swr/infinite' import { debounce } from 'lodash-es' -import NewDatasetCard from './NewDatasetCard' -import DatasetCard from './DatasetCard' +import NewDatasetCard from './new-dataset-card' +import DatasetCard from './dataset-card' import type { DataSetListResponse, FetchDatasetsParams } from '@/models/datasets' import { fetchDatasets } from '@/service/datasets' import { useAppContext } from '@/context/app-context' @@ -36,7 +36,7 @@ const getKey = ( } type Props = { - containerRef: React.RefObject + containerRef: React.RefObject tags: string[] keywords: string includeAll: boolean diff --git a/web/app/(commonLayout)/datasets/doc.tsx b/web/app/(commonLayout)/datasets/doc.tsx new file mode 100644 index 0000000000..c31dad3c00 --- /dev/null +++ b/web/app/(commonLayout)/datasets/doc.tsx @@ -0,0 +1,203 @@ +'use client' + +import { useEffect, useMemo, useState } from 'react' +import { useContext } from 'use-context-selector' +import { useTranslation } from 'react-i18next' +import { RiCloseLine, RiListUnordered } from '@remixicon/react' +import TemplateEn from './template/template.en.mdx' +import TemplateZh from './template/template.zh.mdx' +import TemplateJa from './template/template.ja.mdx' +import I18n from '@/context/i18n' +import { LanguagesSupported } from '@/i18n-config/language' +import useTheme from '@/hooks/use-theme' +import { Theme } from '@/types/app' +import cn from '@/utils/classnames' + +type DocProps = { + apiBaseUrl: string +} + +const Doc = ({ apiBaseUrl }: DocProps) => { + const { locale } = useContext(I18n) + const { t } = useTranslation() + const [toc, setToc] = useState>([]) + const [isTocExpanded, setIsTocExpanded] = useState(false) + const [activeSection, setActiveSection] = useState('') + const { theme } = useTheme() + + // Set initial TOC expanded state based on screen width + useEffect(() => { + const mediaQuery = window.matchMedia('(min-width: 1280px)') + setIsTocExpanded(mediaQuery.matches) + }, []) + + // Extract TOC from article content + useEffect(() => { + const extractTOC = () => { + const article = document.querySelector('article') + if (article) { + const headings = article.querySelectorAll('h2') + const tocItems = Array.from(headings).map((heading) => { + const anchor = heading.querySelector('a') + if (anchor) { + return { + href: anchor.getAttribute('href') || '', + text: anchor.textContent || '', + } + } + return null + }).filter((item): item is { href: string; text: string } => item !== null) + setToc(tocItems) + // Set initial active section + if (tocItems.length > 0) + setActiveSection(tocItems[0].href.replace('#', '')) + } + } + + setTimeout(extractTOC, 0) + }, [locale]) + + // Track scroll position for active section highlighting + useEffect(() => { + const handleScroll = () => { + const scrollContainer = document.querySelector('.scroll-container') + if (!scrollContainer || toc.length === 0) + return + + // Find active section based on scroll position + let currentSection = '' + toc.forEach((item) => { + const targetId = item.href.replace('#', '') + const element = document.getElementById(targetId) + if (element) { + const rect = element.getBoundingClientRect() + // Consider section active if its top is above the middle of viewport + if (rect.top <= window.innerHeight / 2) + currentSection = targetId + } + }) + + if (currentSection && currentSection !== activeSection) + setActiveSection(currentSection) + } + + const scrollContainer = document.querySelector('.scroll-container') + if (scrollContainer) { + scrollContainer.addEventListener('scroll', handleScroll) + handleScroll() // Initial check + return () => scrollContainer.removeEventListener('scroll', handleScroll) + } + }, [toc, activeSection]) + + // Handle TOC item click + const handleTocClick = (e: React.MouseEvent, item: { href: string; text: string }) => { + e.preventDefault() + const targetId = item.href.replace('#', '') + const element = document.getElementById(targetId) + if (element) { + const scrollContainer = document.querySelector('.scroll-container') + if (scrollContainer) { + const headerOffset = -40 + const elementTop = element.offsetTop - headerOffset + scrollContainer.scrollTo({ + top: elementTop, + behavior: 'smooth', + }) + } + } + } + + const Template = useMemo(() => { + switch (locale) { + case LanguagesSupported[1]: + return + case LanguagesSupported[7]: + return + default: + return + } + }, [apiBaseUrl, locale]) + + return ( +
+
+ {isTocExpanded + ? ( + + ) + : ( + + )} +
+
+ {Template} +
+
+ ) +} + +export default Doc diff --git a/web/app/(commonLayout)/datasets/NewDatasetCard.tsx b/web/app/(commonLayout)/datasets/new-dataset-card.tsx similarity index 100% rename from web/app/(commonLayout)/datasets/NewDatasetCard.tsx rename to web/app/(commonLayout)/datasets/new-dataset-card.tsx diff --git a/web/app/(commonLayout)/datasets/page.tsx b/web/app/(commonLayout)/datasets/page.tsx index 60a542f0a2..cbfe25ebd2 100644 --- a/web/app/(commonLayout)/datasets/page.tsx +++ b/web/app/(commonLayout)/datasets/page.tsx @@ -1,6 +1,6 @@ 'use client' import { useTranslation } from 'react-i18next' -import Container from './Container' +import Container from './container' import useDocumentTitle from '@/hooks/use-document-title' const AppList = () => { diff --git a/web/app/(commonLayout)/datasets/template/template.en.mdx b/web/app/(commonLayout)/datasets/template/template.en.mdx index ebb2e6a806..0d41691dfd 100644 --- a/web/app/(commonLayout)/datasets/template/template.en.mdx +++ b/web/app/(commonLayout)/datasets/template/template.en.mdx @@ -25,7 +25,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
___ -
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
```bash {{ title: 'cURL' }} - curl --location --request DELETE '${props.apiBaseUrl}/datasets/{dataset_id}/segments/{segment_id}/child_chunks/{child_chunk_id}' \ + curl --location --request DELETE '${props.apiBaseUrl}/datasets/{dataset_id}/documents/{document_id}/segments/{segment_id}/child_chunks/{child_chunk_id}' \ --header 'Authorization: Bearer {api_key}' ``` @@ -1873,7 +1873,7 @@ ___ -
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
Okay, I will translate the Chinese text in your document while keeping all formatting and code content unchanged. -
+
-
+
-
+
-
+
-
+
-
+
-
+
diff --git a/web/app/(commonLayout)/datasets/template/template.ja.mdx b/web/app/(commonLayout)/datasets/template/template.ja.mdx index 6c0e20e1bb..5c7a752c11 100644 --- a/web/app/(commonLayout)/datasets/template/template.ja.mdx +++ b/web/app/(commonLayout)/datasets/template/template.ja.mdx @@ -25,7 +25,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
___ -
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
```bash {{ title: 'cURL' }} - curl --location --request DELETE '${props.apiBaseUrl}/datasets/{dataset_id}/segments/{segment_id}/child_chunks/{child_chunk_id}' \ + curl --location --request DELETE '${props.apiBaseUrl}/datasets/{dataset_id}/documents/{document_id}/segments/{segment_id}/child_chunks/{child_chunk_id}' \ --header 'Authorization: Bearer {api_key}' ``` @@ -1629,7 +1629,7 @@ ___ -
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
diff --git a/web/app/(commonLayout)/datasets/template/template.zh.mdx b/web/app/(commonLayout)/datasets/template/template.zh.mdx index c21ce3bf5f..b7ea889a46 100644 --- a/web/app/(commonLayout)/datasets/template/template.zh.mdx +++ b/web/app/(commonLayout)/datasets/template/template.zh.mdx @@ -25,7 +25,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
___ -
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
@@ -1915,7 +1915,7 @@ ___ -
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
-
+
diff --git a/web/app/(commonLayout)/education-apply/page.tsx b/web/app/(commonLayout)/education-apply/page.tsx index 873034452e..5dd3c35519 100644 --- a/web/app/(commonLayout)/education-apply/page.tsx +++ b/web/app/(commonLayout)/education-apply/page.tsx @@ -13,12 +13,12 @@ import { useProviderContext } from '@/context/provider-context' export default function EducationApply() { const router = useRouter() - const { enableEducationPlan, isEducationAccount } = useProviderContext() + const { enableEducationPlan } = useProviderContext() const searchParams = useSearchParams() const token = searchParams.get('token') const showEducationApplyPage = useMemo(() => { - return enableEducationPlan && !isEducationAccount && token - }, [enableEducationPlan, isEducationAccount, token]) + return enableEducationPlan && token + }, [enableEducationPlan, token]) useEffect(() => { if (!showEducationApplyPage) diff --git a/web/app/(commonLayout)/layout.tsx b/web/app/(commonLayout)/layout.tsx index 64186a1b10..ed1c995e25 100644 --- a/web/app/(commonLayout)/layout.tsx +++ b/web/app/(commonLayout)/layout.tsx @@ -8,6 +8,7 @@ import Header from '@/app/components/header' import { EventEmitterContextProvider } from '@/context/event-emitter' import { ProviderContextProvider } from '@/context/provider-context' import { ModalContextProvider } from '@/context/modal-context' +import GotoAnything from '@/app/components/goto-anything' const Layout = ({ children }: { children: ReactNode }) => { return ( @@ -22,6 +23,7 @@ const Layout = ({ children }: { children: ReactNode }) => {
{children} + diff --git a/web/app/(commonLayout)/plugins/page.tsx b/web/app/(commonLayout)/plugins/page.tsx index 47f2791075..d07c4307ad 100644 --- a/web/app/(commonLayout)/plugins/page.tsx +++ b/web/app/(commonLayout)/plugins/page.tsx @@ -1,7 +1,7 @@ import PluginPage from '@/app/components/plugins/plugin-page' import PluginsPanel from '@/app/components/plugins/plugin-page/plugins-panel' import Marketplace from '@/app/components/plugins/marketplace' -import { getLocaleOnServer } from '@/i18n/server' +import { getLocaleOnServer } from '@/i18n-config/server' const PluginList = async () => { const locale = await getLocaleOnServer() diff --git a/web/app/(shareLayout)/webapp-reset-password/check-code/page.tsx b/web/app/(shareLayout)/webapp-reset-password/check-code/page.tsx index da754794b1..91e1021610 100644 --- a/web/app/(shareLayout)/webapp-reset-password/check-code/page.tsx +++ b/web/app/(shareLayout)/webapp-reset-password/check-code/page.tsx @@ -70,7 +70,10 @@ export default function CheckCode() {

{t('login.checkCode.checkYourEmail')}

- + + {t('login.checkCode.tipsPrefix')} + {email} +
{t('login.checkCode.validTime')}

diff --git a/web/app/(shareLayout)/webapp-signin/check-code/page.tsx b/web/app/(shareLayout)/webapp-signin/check-code/page.tsx index a2ba620ace..c80a006583 100644 --- a/web/app/(shareLayout)/webapp-signin/check-code/page.tsx +++ b/web/app/(shareLayout)/webapp-signin/check-code/page.tsx @@ -93,7 +93,10 @@ export default function CheckCode() {

{t('login.checkCode.checkYourEmail')}

- + + {t('login.checkCode.tipsPrefix')} + {email} +
{t('login.checkCode.validTime')}

diff --git a/web/app/account/account-page/AvatarWithEdit.tsx b/web/app/account/account-page/AvatarWithEdit.tsx index 8250789def..88e3a7b343 100644 --- a/web/app/account/account-page/AvatarWithEdit.tsx +++ b/web/app/account/account-page/AvatarWithEdit.tsx @@ -4,7 +4,7 @@ import type { Area } from 'react-easy-crop' import React, { useCallback, useState } from 'react' import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' -import { RiPencilLine } from '@remixicon/react' +import { RiDeleteBin5Line, RiPencilLine } from '@remixicon/react' import { updateUserProfile } from '@/service/common' import { ToastContext } from '@/app/components/base/toast' import ImageInput, { type OnImageInput } from '@/app/components/base/app-icon-picker/ImageInput' @@ -27,6 +27,8 @@ const AvatarWithEdit = ({ onSave, ...props }: AvatarWithEditProps) => { const [inputImageInfo, setInputImageInfo] = useState() const [isShowAvatarPicker, setIsShowAvatarPicker] = useState(false) const [uploading, setUploading] = useState(false) + const [isShowDeleteConfirm, setIsShowDeleteConfirm] = useState(false) + const [hoverArea, setHoverArea] = useState('left') const handleImageInput: OnImageInput = useCallback(async (isCropped: boolean, fileOrTempUrl: string | File, croppedAreaPixels?: Area, fileName?: string) => { setInputImageInfo( @@ -48,6 +50,18 @@ const AvatarWithEdit = ({ onSave, ...props }: AvatarWithEditProps) => { } }, [notify, onSave, t]) + const handleDeleteAvatar = useCallback(async () => { + try { + await updateUserProfile({ url: 'account/avatar', body: { avatar: '' } }) + notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') }) + setIsShowDeleteConfirm(false) + onSave?.() + } + catch (e) { + notify({ type: 'error', message: (e as Error).message }) + } + }, [notify, onSave, t]) + const { handleLocalFileUpload } = useLocalFileUploader({ limit: 3, disabled: false, @@ -86,12 +100,21 @@ const AvatarWithEdit = ({ onSave, ...props }: AvatarWithEditProps) => {
{ setIsShowAvatarPicker(true) }} - className="absolute inset-0 flex cursor-pointer items-center justify-center rounded-full bg-black bg-opacity-50 opacity-0 transition-opacity group-hover:opacity-100" + className="absolute inset-0 flex cursor-pointer items-center justify-center rounded-full bg-black/50 opacity-0 transition-opacity group-hover:opacity-100" + onClick={() => hoverArea === 'right' ? setIsShowDeleteConfirm(true) : setIsShowAvatarPicker(true)} + onMouseMove={(e) => { + const rect = e.currentTarget.getBoundingClientRect() + const x = e.clientX - rect.left + const isRight = x > rect.width / 2 + setHoverArea(isRight ? 'right' : 'left') + }} > - + {hoverArea === 'right' ? + + : - + } +
@@ -115,6 +138,26 @@ const AvatarWithEdit = ({ onSave, ...props }: AvatarWithEditProps) => {
+ + setIsShowDeleteConfirm(false)} + > +
{t('common.avatar.deleteTitle')}
+

{t('common.avatar.deleteDescription')}

+ +
+ + + +
+
) } diff --git a/web/app/account/header.tsx b/web/app/account/header.tsx index d033bfab61..af09ca1c9c 100644 --- a/web/app/account/header.tsx +++ b/web/app/account/header.tsx @@ -13,14 +13,14 @@ const Header = () => { const router = useRouter() const systemFeatures = useGlobalPublicStore(s => s.systemFeatures) - const back = useCallback(() => { - router.back() + const goToStudio = useCallback(() => { + router.push('/apps') }, [router]) return (
-
+
{systemFeatures.branding.enabled && systemFeatures.branding.login_page_logo ? {

{t('common.account.account')}

- )} @@ -322,7 +331,7 @@ const AppInfo = ({ expand, onlyShowDetail = false, openState = false, onDetailEx className='flex flex-1 flex-col gap-2 overflow-auto px-2 py-1' /> -
+
-
+