diff --git a/.github/dependabot.yml b/.github/dependabot.yml index a183f0b58c..266fa17c29 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -1,106 +1,6 @@ version: 2 updates: - - package-ecosystem: "pip" - directory: "/api" - open-pull-requests-limit: 10 - schedule: - interval: "weekly" - groups: - flask: - patterns: - - "flask" - - "flask-*" - - "werkzeug" - - "gunicorn" - google: - patterns: - - "google-*" - - "googleapis-*" - opentelemetry: - patterns: - - "opentelemetry-*" - pydantic: - patterns: - - "pydantic" - - "pydantic-*" - llm: - patterns: - - "langfuse" - - "langsmith" - - "litellm" - - "mlflow*" - - "opik" - - "weave*" - - "arize*" - - "tiktoken" - - "transformers" - database: - patterns: - - "sqlalchemy" - - "psycopg2*" - - "psycogreen" - - "redis*" - - "alembic*" - storage: - patterns: - - "boto3*" - - "botocore*" - - "azure-*" - - "bce-*" - - "cos-python-*" - - "esdk-obs-*" - - "google-cloud-storage" - - "opendal" - - "oss2" - - "supabase*" - - "tos*" - vdb: - patterns: - - "alibabacloud*" - - "chromadb" - - "clickhouse-*" - - "clickzetta-*" - - "couchbase" - - "elasticsearch" - - "opensearch-py" - - "oracledb" - - "pgvect*" - - "pymilvus" - - "pymochow" - - "pyobvector" - - "qdrant-client" - - "intersystems-*" - - "tablestore" - - "tcvectordb" - - "tidb-vector" - - "upstash-*" - - "volcengine-*" - - "weaviate-*" - - "xinference-*" - - "mo-vector" - - "mysql-connector-*" - dev: - patterns: - - "coverage" - - "dotenv-linter" - - "faker" - - "lxml-stubs" - - "basedpyright" - - "ruff" - - "pytest*" - - "types-*" - - "boto3-stubs" - - "hypothesis" - - "pandas-stubs" - - "scipy-stubs" - - "import-linter" - - "celery-types" - - "mypy*" - - "pyrefly" - python-packages: - patterns: - - "*" - package-ecosystem: "uv" directory: "/api" open-pull-requests-limit: 10 diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index a069b6cbc7..1e848612ec 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -7,6 +7,7 @@ ## Summary + ## Screenshots @@ -17,7 +18,7 @@ ## Checklist - [ ] This change requires a documentation update, included: [Dify Document](https://github.com/langgenius/dify-docs) -- [x] I understand that this PR may be closed in case there was no previous discussion or issues. (This doesn't apply to typos!) -- [x] I've added a test for each change that was introduced, and I tried as much as possible to make a single atomic change. -- [x] I've updated the documentation accordingly. -- [x] I ran `make lint` and `make type-check` (backend) and `cd web && pnpm exec vp staged` (frontend) to appease the lint gods +- [ ] I understand that this PR may be closed in case there was no previous discussion or issues. (This doesn't apply to typos!) +- [ ] I've added a test for each change that was introduced, and I tried as much as possible to make a single atomic change. +- [ ] I've updated the documentation accordingly. +- [ ] I ran `make lint && make type-check` (backend) and `cd web && pnpm exec vp staged` (frontend) to appease the lint gods diff --git a/.github/workflows/api-tests.yml b/.github/workflows/api-tests.yml index cd967b76cf..fd910531db 100644 --- a/.github/workflows/api-tests.yml +++ b/.github/workflows/api-tests.yml @@ -54,7 +54,7 @@ jobs: run: uv run --project api bash dev/pytest/pytest_unit_tests.sh - name: Upload unit coverage data - uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0 + uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1 with: name: api-coverage-unit path: coverage-unit @@ -129,7 +129,7 @@ jobs: api/tests/test_containers_integration_tests - name: Upload integration coverage data - uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0 + uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1 with: name: api-coverage-integration path: coverage-integration diff --git a/.github/workflows/build-push.yml b/.github/workflows/build-push.yml index 79ecdb5938..5f16fc6927 100644 --- a/.github/workflows/build-push.yml +++ b/.github/workflows/build-push.yml @@ -81,7 +81,7 @@ jobs: - name: Build Docker image id: build - uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # v7.0.0 + uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # v7.1.0 with: context: ${{ matrix.build_context }} file: ${{ matrix.file }} @@ -101,7 +101,7 @@ jobs: touch "/tmp/digests/${sanitized_digest}" - name: Upload digest - uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0 + uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1 with: name: digests-${{ matrix.artifact_context }}-${{ env.PLATFORM_PAIR }} path: /tmp/digests/* diff --git a/.github/workflows/docker-build.yml b/.github/workflows/docker-build.yml index cd9d69d871..6a132a5931 100644 --- a/.github/workflows/docker-build.yml +++ b/.github/workflows/docker-build.yml @@ -50,7 +50,7 @@ jobs: uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4.0.0 - name: Build Docker Image - uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # v7.0.0 + uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # v7.1.0 with: push: false context: ${{ matrix.context }} diff --git a/.github/workflows/pyrefly-diff-comment.yml b/.github/workflows/pyrefly-diff-comment.yml index 0278e1e0d3..eefb1ebbb9 100644 --- a/.github/workflows/pyrefly-diff-comment.yml +++ b/.github/workflows/pyrefly-diff-comment.yml @@ -21,7 +21,7 @@ jobs: if: ${{ github.event.workflow_run.conclusion == 'success' && github.event.workflow_run.pull_requests[0].head.repo.full_name != github.repository }} steps: - name: Download pyrefly diff artifact - uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0 + uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0 with: github-token: ${{ secrets.GITHUB_TOKEN }} script: | @@ -49,7 +49,7 @@ jobs: run: unzip -o pyrefly_diff.zip - name: Post comment - uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0 + uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0 with: github-token: ${{ secrets.GITHUB_TOKEN }} script: | diff --git a/.github/workflows/pyrefly-diff.yml b/.github/workflows/pyrefly-diff.yml index 8623d35b04..ac3732579c 100644 --- a/.github/workflows/pyrefly-diff.yml +++ b/.github/workflows/pyrefly-diff.yml @@ -66,7 +66,7 @@ jobs: echo ${{ github.event.pull_request.number }} > pr_number.txt - name: Upload pyrefly diff - uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0 + uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1 with: name: pyrefly_diff path: | @@ -75,7 +75,7 @@ jobs: - name: Comment PR with pyrefly diff if: ${{ github.event.pull_request.head.repo.full_name == github.repository && steps.line_count_check.outputs.same == 'false' }} - uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0 + uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0 with: github-token: ${{ secrets.GITHUB_TOKEN }} script: | diff --git a/.github/workflows/pyrefly-type-coverage-comment.yml b/.github/workflows/pyrefly-type-coverage-comment.yml new file mode 100644 index 0000000000..51f3ca54b6 --- /dev/null +++ b/.github/workflows/pyrefly-type-coverage-comment.yml @@ -0,0 +1,118 @@ +name: Comment with Pyrefly Type Coverage + +on: + workflow_run: + workflows: + - Pyrefly Type Coverage + types: + - completed + +permissions: {} + +jobs: + comment: + name: Comment PR with type coverage + runs-on: ubuntu-latest + permissions: + actions: read + contents: read + issues: write + pull-requests: write + if: ${{ github.event.workflow_run.conclusion == 'success' && github.event.workflow_run.pull_requests[0].head.repo.full_name != github.repository }} + steps: + - name: Checkout default branch (trusted code) + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + + - name: Setup Python & UV + uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0 + with: + enable-cache: true + + - name: Install dependencies + run: uv sync --project api --dev + + - name: Download type coverage artifact + uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0 + with: + github-token: ${{ secrets.GITHUB_TOKEN }} + script: | + const fs = require('fs'); + const artifacts = await github.rest.actions.listWorkflowRunArtifacts({ + owner: context.repo.owner, + repo: context.repo.repo, + run_id: ${{ github.event.workflow_run.id }}, + }); + const match = artifacts.data.artifacts.find((artifact) => + artifact.name === 'pyrefly_type_coverage' + ); + if (!match) { + throw new Error('pyrefly_type_coverage artifact not found'); + } + const download = await github.rest.actions.downloadArtifact({ + owner: context.repo.owner, + repo: context.repo.repo, + artifact_id: match.id, + archive_format: 'zip', + }); + fs.writeFileSync('pyrefly_type_coverage.zip', Buffer.from(download.data)); + + - name: Unzip artifact + run: unzip -o pyrefly_type_coverage.zip + + - name: Render coverage markdown from structured data + id: render + run: | + comment_body="$(uv run --directory api python api/libs/pyrefly_type_coverage.py \ + --base base_report.json \ + < pr_report.json)" + + { + echo "### Pyrefly Type Coverage" + echo "" + echo "$comment_body" + } > /tmp/type_coverage_comment.md + + - name: Post comment + uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0 + with: + github-token: ${{ secrets.GITHUB_TOKEN }} + script: | + const fs = require('fs'); + const body = fs.readFileSync('/tmp/type_coverage_comment.md', { encoding: 'utf8' }); + let prNumber = null; + try { + prNumber = parseInt(fs.readFileSync('pr_number.txt', { encoding: 'utf8' }), 10); + } catch (err) { + const prs = context.payload.workflow_run.pull_requests || []; + if (prs.length > 0 && prs[0].number) { + prNumber = prs[0].number; + } + } + if (!prNumber) { + throw new Error('PR number not found in artifact or workflow_run payload'); + } + + // Update existing comment if one exists, otherwise create new + const { data: comments } = await github.rest.issues.listComments({ + issue_number: prNumber, + owner: context.repo.owner, + repo: context.repo.repo, + }); + const marker = '### Pyrefly Type Coverage'; + const existing = comments.find(c => c.body.startsWith(marker)); + + if (existing) { + await github.rest.issues.updateComment({ + comment_id: existing.id, + owner: context.repo.owner, + repo: context.repo.repo, + body, + }); + } else { + await github.rest.issues.createComment({ + issue_number: prNumber, + owner: context.repo.owner, + repo: context.repo.repo, + body, + }); + } diff --git a/.github/workflows/pyrefly-type-coverage.yml b/.github/workflows/pyrefly-type-coverage.yml new file mode 100644 index 0000000000..c795c32e31 --- /dev/null +++ b/.github/workflows/pyrefly-type-coverage.yml @@ -0,0 +1,120 @@ +name: Pyrefly Type Coverage + +on: + pull_request: + paths: + - 'api/**/*.py' + +permissions: + contents: read + +jobs: + pyrefly-type-coverage: + runs-on: ubuntu-latest + permissions: + contents: read + issues: write + pull-requests: write + steps: + - name: Checkout PR branch + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + with: + fetch-depth: 0 + + - name: Setup Python & UV + uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0 + with: + enable-cache: true + + - name: Install dependencies + run: uv sync --project api --dev + + - name: Run pyrefly report on PR branch + run: | + uv run --directory api --dev pyrefly report 2>/dev/null > /tmp/pyrefly_report_pr.tmp && \ + mv /tmp/pyrefly_report_pr.tmp /tmp/pyrefly_report_pr.json || \ + echo '{}' > /tmp/pyrefly_report_pr.json + + - name: Save helper script from base branch + run: | + git show ${{ github.event.pull_request.base.sha }}:api/libs/pyrefly_type_coverage.py > /tmp/pyrefly_type_coverage.py 2>/dev/null \ + || cp api/libs/pyrefly_type_coverage.py /tmp/pyrefly_type_coverage.py + + - name: Checkout base branch + run: git checkout ${{ github.base_ref }} + + - name: Run pyrefly report on base branch + run: | + uv run --directory api --dev pyrefly report 2>/dev/null > /tmp/pyrefly_report_base.tmp && \ + mv /tmp/pyrefly_report_base.tmp /tmp/pyrefly_report_base.json || \ + echo '{}' > /tmp/pyrefly_report_base.json + + - name: Generate coverage comparison + id: coverage + run: | + comment_body="$(uv run --directory api python /tmp/pyrefly_type_coverage.py \ + --base /tmp/pyrefly_report_base.json \ + < /tmp/pyrefly_report_pr.json)" + + { + echo "### Pyrefly Type Coverage" + echo "" + echo "$comment_body" + } | tee -a "$GITHUB_STEP_SUMMARY" > /tmp/type_coverage_comment.md + + # Save structured data for the fork-PR comment workflow + cp /tmp/pyrefly_report_pr.json pr_report.json + cp /tmp/pyrefly_report_base.json base_report.json + + - name: Save PR number + run: | + echo ${{ github.event.pull_request.number }} > pr_number.txt + + - name: Upload type coverage artifact + uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1 + with: + name: pyrefly_type_coverage + path: | + pr_report.json + base_report.json + pr_number.txt + + - name: Comment PR with type coverage + if: ${{ github.event.pull_request.head.repo.full_name == github.repository }} + uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0 + with: + github-token: ${{ secrets.GITHUB_TOKEN }} + script: | + const fs = require('fs'); + const marker = '### Pyrefly Type Coverage'; + let body; + try { + body = fs.readFileSync('/tmp/type_coverage_comment.md', { encoding: 'utf8' }); + } catch { + body = `${marker}\n\n_Coverage report unavailable._`; + } + const prNumber = context.payload.pull_request.number; + + // Update existing comment if one exists, otherwise create new + const { data: comments } = await github.rest.issues.listComments({ + issue_number: prNumber, + owner: context.repo.owner, + repo: context.repo.repo, + }); + const existing = comments.find(c => c.body.startsWith(marker)); + + if (existing) { + await github.rest.issues.updateComment({ + comment_id: existing.id, + owner: context.repo.owner, + repo: context.repo.repo, + body, + }); + } else { + await github.rest.issues.createComment({ + issue_number: prNumber, + owner: context.repo.owner, + repo: context.repo.repo, + body, + }); + } diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml index 5cf52daed2..c74f4a670a 100644 --- a/.github/workflows/stale.yml +++ b/.github/workflows/stale.yml @@ -23,8 +23,8 @@ jobs: days-before-issue-stale: 15 days-before-issue-close: 3 repo-token: ${{ secrets.GITHUB_TOKEN }} - stale-issue-message: "Close due to it's no longer active, if you have any questions, you can reopen it." - stale-pr-message: "Close due to it's no longer active, if you have any questions, you can reopen it." + stale-issue-message: "Closed due to inactivity. If you have any questions, you can reopen it." + stale-pr-message: "Closed due to inactivity. If you have any questions, you can reopen it." stale-issue-label: 'no-issue-activity' stale-pr-label: 'no-pr-activity' - any-of-labels: 'duplicate,question,invalid,wontfix,no-issue-activity,no-pr-activity,enhancement,cant-reproduce,help-wanted' + any-of-labels: '🌚 invalid,🙋‍♂️ question,wont-fix,no-issue-activity,no-pr-activity,💪 enhancement,🤔 cant-reproduce,🙏 help wanted' diff --git a/.github/workflows/translate-i18n-claude.yml b/.github/workflows/translate-i18n-claude.yml index e001f4d677..541200293d 100644 --- a/.github/workflows/translate-i18n-claude.yml +++ b/.github/workflows/translate-i18n-claude.yml @@ -158,7 +158,7 @@ jobs: - name: Run Claude Code for Translation Sync if: steps.context.outputs.CHANGED_FILES != '' - uses: anthropics/claude-code-action@6e2bd52842c65e914eba5c8badd17560bd26b5de # v1.0.89 + uses: anthropics/claude-code-action@b47fd721da662d48c5680e154ad16a73ed74d2e0 # v1.0.93 with: anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }} github_token: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/trigger-i18n-sync.yml b/.github/workflows/trigger-i18n-sync.yml index 9a11d3e8df..790ea9126d 100644 --- a/.github/workflows/trigger-i18n-sync.yml +++ b/.github/workflows/trigger-i18n-sync.yml @@ -56,7 +56,7 @@ jobs: - name: Trigger i18n sync workflow if: steps.detect.outputs.has_changes == 'true' - uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0 + uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0 env: BASE_SHA: ${{ steps.detect.outputs.base_sha }} HEAD_SHA: ${{ steps.detect.outputs.head_sha }} diff --git a/.github/workflows/web-e2e.yml b/.github/workflows/web-e2e.yml index eb752619be..10dc31bde8 100644 --- a/.github/workflows/web-e2e.yml +++ b/.github/workflows/web-e2e.yml @@ -53,7 +53,7 @@ jobs: - name: Upload Cucumber report if: ${{ !cancelled() }} - uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0 + uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1 with: name: cucumber-report path: e2e/cucumber-report @@ -61,7 +61,7 @@ jobs: - name: Upload E2E logs if: ${{ !cancelled() }} - uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0 + uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1 with: name: e2e-logs path: e2e/.logs diff --git a/.github/workflows/web-tests.yml b/.github/workflows/web-tests.yml index 3c36335e79..f3ab4c62c7 100644 --- a/.github/workflows/web-tests.yml +++ b/.github/workflows/web-tests.yml @@ -43,7 +43,7 @@ jobs: - name: Upload blob report if: ${{ !cancelled() }} - uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0 + uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1 with: name: blob-report-${{ matrix.shardIndex }} path: web/.vitest-reports/* diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 775401bfa5..d7f007af67 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -97,8 +97,3 @@ Feel free to reach out if you encounter any issues during the setup process. ## Getting Help If you ever get stuck or get a burning question while contributing, simply shoot your queries our way via the related GitHub issue, or hop onto our [Discord](https://discord.gg/8Tpq4AcN9c) for a quick chat. - -## Automated Agent Contributions - -> [!NOTE] -> If you are an automated agent, just add 🤖🤖🤖 to the end of the PR title to opt-in. diff --git a/api/commands/account.py b/api/commands/account.py index 84af7a5ae6..6a2a2e0428 100644 --- a/api/commands/account.py +++ b/api/commands/account.py @@ -2,7 +2,6 @@ import base64 import secrets import click -from sqlalchemy.orm import sessionmaker from constants.languages import languages from extensions.ext_database import db @@ -25,30 +24,31 @@ def reset_password(email, new_password, password_confirm): return normalized_email = email.strip().lower() - with sessionmaker(db.engine, expire_on_commit=False).begin() as session: - account = AccountService.get_account_by_email_with_case_fallback(email.strip(), session=session) + account = AccountService.get_account_by_email_with_case_fallback(email.strip()) - if not account: - click.echo(click.style(f"Account not found for email: {email}", fg="red")) - return + if not account: + click.echo(click.style(f"Account not found for email: {email}", fg="red")) + return - try: - valid_password(new_password) - except: - click.echo(click.style(f"Invalid password. Must match {password_pattern}", fg="red")) - return + try: + valid_password(new_password) + except: + click.echo(click.style(f"Invalid password. Must match {password_pattern}", fg="red")) + return - # generate password salt - salt = secrets.token_bytes(16) - base64_salt = base64.b64encode(salt).decode() + # generate password salt + salt = secrets.token_bytes(16) + base64_salt = base64.b64encode(salt).decode() - # encrypt password with salt - password_hashed = hash_password(new_password, salt) - base64_password_hashed = base64.b64encode(password_hashed).decode() - account.password = base64_password_hashed - account.password_salt = base64_salt - AccountService.reset_login_error_rate_limit(normalized_email) - click.echo(click.style("Password reset successfully.", fg="green")) + # encrypt password with salt + password_hashed = hash_password(new_password, salt) + base64_password_hashed = base64.b64encode(password_hashed).decode() + account = db.session.merge(account) + account.password = base64_password_hashed + account.password_salt = base64_salt + db.session.commit() + AccountService.reset_login_error_rate_limit(normalized_email) + click.echo(click.style("Password reset successfully.", fg="green")) @click.command("reset-email", help="Reset the account email.") @@ -65,21 +65,22 @@ def reset_email(email, new_email, email_confirm): return normalized_new_email = new_email.strip().lower() - with sessionmaker(db.engine, expire_on_commit=False).begin() as session: - account = AccountService.get_account_by_email_with_case_fallback(email.strip(), session=session) + account = AccountService.get_account_by_email_with_case_fallback(email.strip()) - if not account: - click.echo(click.style(f"Account not found for email: {email}", fg="red")) - return + if not account: + click.echo(click.style(f"Account not found for email: {email}", fg="red")) + return - try: - email_validate(normalized_new_email) - except: - click.echo(click.style(f"Invalid email: {new_email}", fg="red")) - return + try: + email_validate(normalized_new_email) + except: + click.echo(click.style(f"Invalid email: {new_email}", fg="red")) + return - account.email = normalized_new_email - click.echo(click.style("Email updated successfully.", fg="green")) + account = db.session.merge(account) + account.email = normalized_new_email + db.session.commit() + click.echo(click.style("Email updated successfully.", fg="green")) @click.command("create-tenant", help="Create account and tenant.") diff --git a/api/configs/middleware/__init__.py b/api/configs/middleware/__init__.py index 15ac8bf0bf..817284d26f 100644 --- a/api/configs/middleware/__init__.py +++ b/api/configs/middleware/__init__.py @@ -1,5 +1,5 @@ import os -from typing import Any, Literal +from typing import Any, Literal, TypedDict from urllib.parse import parse_qsl, quote_plus from pydantic import Field, NonNegativeFloat, NonNegativeInt, PositiveFloat, PositiveInt, computed_field @@ -107,6 +107,17 @@ class KeywordStoreConfig(BaseSettings): ) +class SQLAlchemyEngineOptionsDict(TypedDict): + pool_size: int + max_overflow: int + pool_recycle: int + pool_pre_ping: bool + connect_args: dict[str, str] + pool_use_lifo: bool + pool_reset_on_return: None + pool_timeout: int + + class DatabaseConfig(BaseSettings): # Database type selector DB_TYPE: Literal["postgresql", "mysql", "oceanbase", "seekdb"] = Field( @@ -209,11 +220,11 @@ class DatabaseConfig(BaseSettings): @computed_field # type: ignore[prop-decorator] @property - def SQLALCHEMY_ENGINE_OPTIONS(self) -> dict[str, Any]: + def SQLALCHEMY_ENGINE_OPTIONS(self) -> SQLAlchemyEngineOptionsDict: # Parse DB_EXTRAS for 'options' db_extras_dict = dict(parse_qsl(self.DB_EXTRAS)) options = db_extras_dict.get("options", "") - connect_args = {} + connect_args: dict[str, str] = {} # Use the dynamic SQLALCHEMY_DATABASE_URI_SCHEME property if self.SQLALCHEMY_DATABASE_URI_SCHEME.startswith("postgresql"): timezone_opt = "-c timezone=UTC" @@ -223,7 +234,7 @@ class DatabaseConfig(BaseSettings): merged_options = timezone_opt connect_args = {"options": merged_options} - return { + result: SQLAlchemyEngineOptionsDict = { "pool_size": self.SQLALCHEMY_POOL_SIZE, "max_overflow": self.SQLALCHEMY_MAX_OVERFLOW, "pool_recycle": self.SQLALCHEMY_POOL_RECYCLE, @@ -233,6 +244,7 @@ class DatabaseConfig(BaseSettings): "pool_reset_on_return": None, "pool_timeout": self.SQLALCHEMY_POOL_TIMEOUT, } + return result class CeleryConfig(DatabaseConfig): diff --git a/api/controllers/console/app/annotation.py b/api/controllers/console/app/annotation.py index 9931bb5dd7..528785931e 100644 --- a/api/controllers/console/app/annotation.py +++ b/api/controllers/console/app/annotation.py @@ -25,7 +25,13 @@ from fields.annotation_fields import ( ) from libs.helper import uuid_value from libs.login import login_required -from services.annotation_service import AppAnnotationService +from services.annotation_service import ( + AppAnnotationService, + EnableAnnotationArgs, + UpdateAnnotationArgs, + UpdateAnnotationSettingArgs, + UpsertAnnotationArgs, +) DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" @@ -120,7 +126,12 @@ class AnnotationReplyActionApi(Resource): args = AnnotationReplyPayload.model_validate(console_ns.payload) match action: case "enable": - result = AppAnnotationService.enable_app_annotation(args.model_dump(), app_id) + enable_args: EnableAnnotationArgs = { + "score_threshold": args.score_threshold, + "embedding_provider_name": args.embedding_provider_name, + "embedding_model_name": args.embedding_model_name, + } + result = AppAnnotationService.enable_app_annotation(enable_args, app_id) case "disable": result = AppAnnotationService.disable_app_annotation(app_id) return result, 200 @@ -161,7 +172,8 @@ class AppAnnotationSettingUpdateApi(Resource): args = AnnotationSettingUpdatePayload.model_validate(console_ns.payload) - result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, args.model_dump()) + setting_args: UpdateAnnotationSettingArgs = {"score_threshold": args.score_threshold} + result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, setting_args) return result, 200 @@ -237,8 +249,16 @@ class AnnotationApi(Resource): def post(self, app_id): app_id = str(app_id) args = CreateAnnotationPayload.model_validate(console_ns.payload) - data = args.model_dump(exclude_none=True) - annotation = AppAnnotationService.up_insert_app_annotation_from_message(data, app_id) + upsert_args: UpsertAnnotationArgs = {} + if args.answer is not None: + upsert_args["answer"] = args.answer + if args.content is not None: + upsert_args["content"] = args.content + if args.message_id is not None: + upsert_args["message_id"] = args.message_id + if args.question is not None: + upsert_args["question"] = args.question + annotation = AppAnnotationService.up_insert_app_annotation_from_message(upsert_args, app_id) return Annotation.model_validate(annotation, from_attributes=True).model_dump(mode="json") @setup_required @@ -315,9 +335,12 @@ class AnnotationUpdateDeleteApi(Resource): app_id = str(app_id) annotation_id = str(annotation_id) args = UpdateAnnotationPayload.model_validate(console_ns.payload) - annotation = AppAnnotationService.update_app_annotation_directly( - args.model_dump(exclude_none=True), app_id, annotation_id - ) + update_args: UpdateAnnotationArgs = {} + if args.answer is not None: + update_args["answer"] = args.answer + if args.question is not None: + update_args["question"] = args.question + annotation = AppAnnotationService.update_app_annotation_directly(update_args, app_id, annotation_id) return Annotation.model_validate(annotation, from_attributes=True).model_dump(mode="json") @setup_required diff --git a/api/controllers/console/app/model_config.py b/api/controllers/console/app/model_config.py index 8bb5aa2c1b..1869cbf5f6 100644 --- a/api/controllers/console/app/model_config.py +++ b/api/controllers/console/app/model_config.py @@ -1,9 +1,11 @@ import json -from typing import cast +from typing import Any, cast from flask import request -from flask_restx import Resource, fields +from flask_restx import Resource +from pydantic import BaseModel, Field +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required @@ -18,30 +20,30 @@ from models.model import AppMode, AppModelConfig from services.app_model_config_service import AppModelConfigService +class ModelConfigRequest(BaseModel): + provider: str | None = Field(default=None, description="Model provider") + model: str | None = Field(default=None, description="Model name") + configs: dict[str, Any] | None = Field(default=None, description="Model configuration parameters") + opening_statement: str | None = Field(default=None, description="Opening statement") + suggested_questions: list[str] | None = Field(default=None, description="Suggested questions") + more_like_this: dict[str, Any] | None = Field(default=None, description="More like this configuration") + speech_to_text: dict[str, Any] | None = Field(default=None, description="Speech to text configuration") + text_to_speech: dict[str, Any] | None = Field(default=None, description="Text to speech configuration") + retrieval_model: dict[str, Any] | None = Field(default=None, description="Retrieval model configuration") + tools: list[dict[str, Any]] | None = Field(default=None, description="Available tools") + dataset_configs: dict[str, Any] | None = Field(default=None, description="Dataset configurations") + agent_mode: dict[str, Any] | None = Field(default=None, description="Agent mode configuration") + + +register_schema_models(console_ns, ModelConfigRequest) + + @console_ns.route("/apps//model-config") class ModelConfigResource(Resource): @console_ns.doc("update_app_model_config") @console_ns.doc(description="Update application model configuration") @console_ns.doc(params={"app_id": "Application ID"}) - @console_ns.expect( - console_ns.model( - "ModelConfigRequest", - { - "provider": fields.String(description="Model provider"), - "model": fields.String(description="Model name"), - "configs": fields.Raw(description="Model configuration parameters"), - "opening_statement": fields.String(description="Opening statement"), - "suggested_questions": fields.List(fields.String(), description="Suggested questions"), - "more_like_this": fields.Raw(description="More like this configuration"), - "speech_to_text": fields.Raw(description="Speech to text configuration"), - "text_to_speech": fields.Raw(description="Text to speech configuration"), - "retrieval_model": fields.Raw(description="Retrieval model configuration"), - "tools": fields.List(fields.Raw(), description="Available tools"), - "dataset_configs": fields.Raw(description="Dataset configurations"), - "agent_mode": fields.Raw(description="Agent mode configuration"), - }, - ) - ) + @console_ns.expect(console_ns.models[ModelConfigRequest.__name__]) @console_ns.response(200, "Model configuration updated successfully") @console_ns.response(400, "Invalid configuration") @console_ns.response(404, "App not found") diff --git a/api/controllers/console/app/workflow_draft_variable.py b/api/controllers/console/app/workflow_draft_variable.py index 9771d6f1e5..657e794ac4 100644 --- a/api/controllers/console/app/workflow_draft_variable.py +++ b/api/controllers/console/app/workflow_draft_variable.py @@ -1,7 +1,7 @@ import logging from collections.abc import Callable from functools import wraps -from typing import Any +from typing import Any, TypedDict from flask import Response, request from flask_restx import Resource, fields, marshal, marshal_with @@ -86,7 +86,14 @@ def _serialize_variable_type(workflow_draft_var: WorkflowDraftVariable) -> str: return value_type.exposed_type().value -def _serialize_full_content(variable: WorkflowDraftVariable) -> dict | None: +class FullContentDict(TypedDict): + size_bytes: int | None + value_type: str + length: int | None + download_url: str + + +def _serialize_full_content(variable: WorkflowDraftVariable) -> FullContentDict | None: """Serialize full_content information for large variables.""" if not variable.is_truncated(): return None @@ -94,12 +101,13 @@ def _serialize_full_content(variable: WorkflowDraftVariable) -> dict | None: variable_file = variable.variable_file assert variable_file is not None - return { + result: FullContentDict = { "size_bytes": variable_file.size, "value_type": variable_file.value_type.exposed_type().value, "length": variable_file.length, "download_url": file_helpers.get_signed_file_url(variable_file.upload_file_id, as_attachment=True), } + return result def _ensure_variable_access( diff --git a/api/controllers/console/auth/email_register.py b/api/controllers/console/auth/email_register.py index 9e7faa09c5..1fd781b4fc 100644 --- a/api/controllers/console/auth/email_register.py +++ b/api/controllers/console/auth/email_register.py @@ -1,7 +1,6 @@ from flask import request from flask_restx import Resource from pydantic import BaseModel, Field, field_validator -from sqlalchemy.orm import sessionmaker from configs import dify_config from constants.languages import languages @@ -14,7 +13,6 @@ from controllers.console.auth.error import ( InvalidTokenError, PasswordMismatchError, ) -from extensions.ext_database import db from libs.helper import EmailStr, extract_remote_ip from libs.password import valid_password from models import Account @@ -73,8 +71,7 @@ class EmailRegisterSendEmailApi(Resource): if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email): raise AccountInFreezeError() - with sessionmaker(db.engine).begin() as session: - account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session) + account = AccountService.get_account_by_email_with_case_fallback(args.email) token = AccountService.send_email_register_email(email=normalized_email, account=account, language=language) return {"result": "success", "data": token} @@ -145,17 +142,16 @@ class EmailRegisterResetApi(Resource): email = register_data.get("email", "") normalized_email = email.lower() - with sessionmaker(db.engine).begin() as session: - account = AccountService.get_account_by_email_with_case_fallback(email, session=session) + account = AccountService.get_account_by_email_with_case_fallback(email) - if account: - raise EmailAlreadyInUseError() - else: - account = self._create_new_account(normalized_email, args.password_confirm) - if not account: - raise AccountNotFoundError() - token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request)) - AccountService.reset_login_error_rate_limit(normalized_email) + if account: + raise EmailAlreadyInUseError() + else: + account = self._create_new_account(normalized_email, args.password_confirm) + if not account: + raise AccountNotFoundError() + token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request)) + AccountService.reset_login_error_rate_limit(normalized_email) return {"result": "success", "data": token_pair.model_dump()} diff --git a/api/controllers/console/auth/forgot_password.py b/api/controllers/console/auth/forgot_password.py index 63bc98b53f..ed390a5f89 100644 --- a/api/controllers/console/auth/forgot_password.py +++ b/api/controllers/console/auth/forgot_password.py @@ -4,7 +4,6 @@ import secrets from flask import request from flask_restx import Resource from pydantic import BaseModel, Field -from sqlalchemy.orm import sessionmaker from controllers.common.schema import register_schema_models from controllers.console import console_ns @@ -85,8 +84,7 @@ class ForgotPasswordSendEmailApi(Resource): else: language = "en-US" - with sessionmaker(db.engine).begin() as session: - account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session) + account = AccountService.get_account_by_email_with_case_fallback(args.email) token = AccountService.send_reset_password_email( account=account, @@ -184,17 +182,18 @@ class ForgotPasswordResetApi(Resource): password_hashed = hash_password(args.new_password, salt) email = reset_data.get("email", "") - with sessionmaker(db.engine).begin() as session: - account = AccountService.get_account_by_email_with_case_fallback(email, session=session) + account = AccountService.get_account_by_email_with_case_fallback(email) - if account: - self._update_existing_account(account, password_hashed, salt, session) - else: - raise AccountNotFound() + if account: + account = db.session.merge(account) + self._update_existing_account(account, password_hashed, salt) + db.session.commit() + else: + raise AccountNotFound() return {"result": "success"} - def _update_existing_account(self, account, password_hashed, salt, session): + def _update_existing_account(self, account, password_hashed, salt): # Update existing account credentials account.password = base64.b64encode(password_hashed).decode() account.password_salt = base64.b64encode(salt).decode() diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py index 5c7011fd22..d31fb4a46c 100644 --- a/api/controllers/console/auth/oauth.py +++ b/api/controllers/console/auth/oauth.py @@ -4,7 +4,6 @@ import urllib.parse import httpx from flask import current_app, redirect, request from flask_restx import Resource -from sqlalchemy.orm import sessionmaker from werkzeug.exceptions import Unauthorized from configs import dify_config @@ -180,8 +179,7 @@ def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> account: Account | None = Account.get_by_openid(provider, user_info.id) if not account: - with sessionmaker(db.engine).begin() as session: - account = AccountService.get_account_by_email_with_case_fallback(user_info.email, session=session) + account = AccountService.get_account_by_email_with_case_fallback(user_info.email) return account diff --git a/api/controllers/console/billing/billing.py b/api/controllers/console/billing/billing.py index 51ddf1b292..ce2870c82e 100644 --- a/api/controllers/console/billing/billing.py +++ b/api/controllers/console/billing/billing.py @@ -4,10 +4,11 @@ from datetime import UTC, datetime, timedelta from typing import Literal from flask import request -from flask_restx import Resource, fields +from flask_restx import Resource from pydantic import BaseModel, Field from werkzeug.exceptions import BadRequest +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required from enums.cloud_plan import CloudPlan @@ -15,8 +16,6 @@ from extensions.ext_redis import redis_client from libs.login import current_account_with_tenant, login_required from services.billing_service import BillingService -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" - class SubscriptionQuery(BaseModel): plan: Literal[CloudPlan.PROFESSIONAL, CloudPlan.TEAM] = Field(..., description="Subscription plan") @@ -27,8 +26,7 @@ class PartnerTenantsPayload(BaseModel): click_id: str = Field(..., description="Click Id from partner referral link") -for model in (SubscriptionQuery, PartnerTenantsPayload): - console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) +register_schema_models(console_ns, SubscriptionQuery, PartnerTenantsPayload) @console_ns.route("/billing/subscription") @@ -61,12 +59,7 @@ class PartnerTenants(Resource): @console_ns.doc("sync_partner_tenants_bindings") @console_ns.doc(description="Sync partner tenants bindings") @console_ns.doc(params={"partner_key": "Partner key"}) - @console_ns.expect( - console_ns.model( - "SyncPartnerTenantsBindingsRequest", - {"click_id": fields.String(required=True, description="Click Id from partner referral link")}, - ) - ) + @console_ns.expect(console_ns.models[PartnerTenantsPayload.__name__]) @console_ns.response(200, "Tenants synced to partner successfully") @console_ns.response(400, "Invalid partner information") @setup_required diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py index e623722b23..ed3c1a59d4 100644 --- a/api/controllers/console/datasets/data_source.py +++ b/api/controllers/console/datasets/data_source.py @@ -162,7 +162,9 @@ class DataSourceApi(Resource): binding_id = str(binding_id) with sessionmaker(db.engine, expire_on_commit=False).begin() as session: data_source_binding = session.execute( - select(DataSourceOauthBinding).filter_by(id=binding_id, tenant_id=current_tenant_id) + select(DataSourceOauthBinding).where( + DataSourceOauthBinding.id == binding_id, DataSourceOauthBinding.tenant_id == current_tenant_id + ) ).scalar_one_or_none() if data_source_binding is None: raise NotFound("Data source binding not found.") @@ -222,11 +224,11 @@ class DataSourceNotionListApi(Resource): raise ValueError("Dataset is not notion type.") documents = session.scalars( - select(Document).filter_by( - dataset_id=query.dataset_id, - tenant_id=current_tenant_id, - data_source_type="notion_import", - enabled=True, + select(Document).where( + Document.dataset_id == query.dataset_id, + Document.tenant_id == current_tenant_id, + Document.data_source_type == "notion_import", + Document.enabled.is_(True), ) ).all() if documents: diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index ab367d8483..b7584f1f00 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -280,7 +280,7 @@ class DatasetDocumentListApi(Resource): except services.errors.account.NoPermissionError as e: raise Forbidden(str(e)) - query = select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=current_tenant_id) + query = select(Document).where(Document.dataset_id == str(dataset_id), Document.tenant_id == current_tenant_id) if status: query = DocumentService.apply_display_status_filter(query, status) diff --git a/api/controllers/console/datasets/external.py b/api/controllers/console/datasets/external.py index f3866f6aef..e513e8c8f9 100644 --- a/api/controllers/console/datasets/external.py +++ b/api/controllers/console/datasets/external.py @@ -227,10 +227,11 @@ class ExternalApiUseCheckApi(Resource): @login_required @account_initialization_required def get(self, external_knowledge_api_id): + _, current_tenant_id = current_account_with_tenant() external_knowledge_api_id = str(external_knowledge_api_id) external_knowledge_api_is_using, count = ExternalDatasetService.external_knowledge_api_use_check( - external_knowledge_api_id + external_knowledge_api_id, current_tenant_id ) return {"is_using": external_knowledge_api_is_using, "count": count}, 200 diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index 6c02646c22..a8077d9eb0 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -346,89 +346,6 @@ class PublishedRagPipelineRunApi(Resource): raise InvokeRateLimitHttpError(ex.description) -# class RagPipelinePublishedDatasourceNodeRunStatusApi(Resource): -# @setup_required -# @login_required -# @account_initialization_required -# @get_rag_pipeline -# def post(self, pipeline: Pipeline, node_id: str): -# """ -# Run rag pipeline datasource -# """ -# # The role of the current user in the ta table must be admin, owner, or editor -# if not current_user.has_edit_permission: -# raise Forbidden() -# -# if not isinstance(current_user, Account): -# raise Forbidden() -# -# parser = (reqparse.RequestParser() -# .add_argument("job_id", type=str, required=True, nullable=False, location="json") -# .add_argument("datasource_type", type=str, required=True, location="json") -# ) -# args = parser.parse_args() -# -# job_id = args.get("job_id") -# if job_id == None: -# raise ValueError("missing job_id") -# datasource_type = args.get("datasource_type") -# if datasource_type == None: -# raise ValueError("missing datasource_type") -# -# rag_pipeline_service = RagPipelineService() -# result = rag_pipeline_service.run_datasource_workflow_node_status( -# pipeline=pipeline, -# node_id=node_id, -# job_id=job_id, -# account=current_user, -# datasource_type=datasource_type, -# is_published=True -# ) -# -# return result - - -# class RagPipelineDraftDatasourceNodeRunStatusApi(Resource): -# @setup_required -# @login_required -# @account_initialization_required -# @get_rag_pipeline -# def post(self, pipeline: Pipeline, node_id: str): -# """ -# Run rag pipeline datasource -# """ -# # The role of the current user in the ta table must be admin, owner, or editor -# if not current_user.has_edit_permission: -# raise Forbidden() -# -# if not isinstance(current_user, Account): -# raise Forbidden() -# -# parser = (reqparse.RequestParser() -# .add_argument("job_id", type=str, required=True, nullable=False, location="json") -# .add_argument("datasource_type", type=str, required=True, location="json") -# ) -# args = parser.parse_args() -# -# job_id = args.get("job_id") -# if job_id == None: -# raise ValueError("missing job_id") -# datasource_type = args.get("datasource_type") -# if datasource_type == None: -# raise ValueError("missing datasource_type") -# -# rag_pipeline_service = RagPipelineService() -# result = rag_pipeline_service.run_datasource_workflow_node_status( -# pipeline=pipeline, -# node_id=node_id, -# job_id=job_id, -# account=current_user, -# datasource_type=datasource_type, -# is_published=False -# ) -# -# return result -# @console_ns.route("/rag/pipelines//workflows/published/datasource/nodes//run") class RagPipelinePublishedDatasourceNodeRunApi(Resource): @console_ns.expect(console_ns.models[DatasourceNodeRunPayload.__name__]) diff --git a/api/controllers/console/human_input_form.py b/api/controllers/console/human_input_form.py index 5d79e1b5e9..845af37365 100644 --- a/api/controllers/console/human_input_form.py +++ b/api/controllers/console/human_input_form.py @@ -7,7 +7,8 @@ import logging from collections.abc import Generator from flask import Response, jsonify, request -from flask_restx import Resource, reqparse +from flask_restx import Resource +from pydantic import BaseModel from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker @@ -33,6 +34,11 @@ from services.workflow_event_snapshot_service import build_workflow_event_stream logger = logging.getLogger(__name__) +class HumanInputFormSubmitPayload(BaseModel): + inputs: dict + action: str + + def _jsonify_form_definition(form: Form) -> Response: payload = form.get_definition().model_dump() payload["expiration_time"] = int(form.expiration_time.timestamp()) @@ -84,10 +90,7 @@ class ConsoleHumanInputFormApi(Resource): "action": "Approve" } """ - parser = reqparse.RequestParser() - parser.add_argument("inputs", type=dict, required=True, location="json") - parser.add_argument("action", type=str, required=True, location="json") - args = parser.parse_args() + payload = HumanInputFormSubmitPayload.model_validate(request.get_json()) current_user, _ = current_account_with_tenant() service = HumanInputService(db.engine) @@ -107,8 +110,8 @@ class ConsoleHumanInputFormApi(Resource): service.submit_form_by_token( recipient_type=recipient_type, form_token=form_token, - selected_action_id=args["action"], - form_data=args["inputs"], + selected_action_id=payload.action, + form_data=payload.inputs, submission_user_id=current_user.id, ) diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py index 626d330e9d..af25669ae0 100644 --- a/api/controllers/console/workspace/account.py +++ b/api/controllers/console/workspace/account.py @@ -8,7 +8,6 @@ from flask import request from flask_restx import Resource, fields, marshal_with from pydantic import BaseModel, Field, field_validator, model_validator from sqlalchemy import select -from sqlalchemy.orm import sessionmaker from configs import dify_config from constants.languages import supported_language @@ -562,8 +561,7 @@ class ChangeEmailSendEmailApi(Resource): user_email = current_user.email else: - with sessionmaker(db.engine).begin() as session: - account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session) + account = AccountService.get_account_by_email_with_case_fallback(args.email) if account is None: raise AccountNotFound() email_for_sending = account.email diff --git a/api/controllers/inner_api/plugin/wraps.py b/api/controllers/inner_api/plugin/wraps.py index 1d378c754c..a5846e2815 100644 --- a/api/controllers/inner_api/plugin/wraps.py +++ b/api/controllers/inner_api/plugin/wraps.py @@ -94,10 +94,9 @@ def get_user_tenant[**P, R](view_func: Callable[P, R]) -> Callable[P, R]: def plugin_data[**P, R]( - view: Callable[P, R] | None = None, *, payload_type: type[BaseModel], -) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]: +) -> Callable[[Callable[P, R]], Callable[P, R]]: def decorator(view_func: Callable[P, R]) -> Callable[P, R]: @wraps(view_func) def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R: @@ -116,7 +115,4 @@ def plugin_data[**P, R]( return decorated_view - if view is None: - return decorator - else: - return decorator(view) + return decorator diff --git a/api/controllers/service_api/app/annotation.py b/api/controllers/service_api/app/annotation.py index c22190cbc9..00bb9aa463 100644 --- a/api/controllers/service_api/app/annotation.py +++ b/api/controllers/service_api/app/annotation.py @@ -12,7 +12,12 @@ from controllers.service_api.wraps import validate_app_token from extensions.ext_redis import redis_client from fields.annotation_fields import Annotation, AnnotationList from models.model import App -from services.annotation_service import AppAnnotationService +from services.annotation_service import ( + AppAnnotationService, + EnableAnnotationArgs, + InsertAnnotationArgs, + UpdateAnnotationArgs, +) class AnnotationCreatePayload(BaseModel): @@ -46,10 +51,15 @@ class AnnotationReplyActionApi(Resource): @validate_app_token def post(self, app_model: App, action: Literal["enable", "disable"]): """Enable or disable annotation reply feature.""" - args = AnnotationReplyActionPayload.model_validate(service_api_ns.payload or {}).model_dump() + payload = AnnotationReplyActionPayload.model_validate(service_api_ns.payload or {}) match action: case "enable": - result = AppAnnotationService.enable_app_annotation(args, app_model.id) + enable_args: EnableAnnotationArgs = { + "score_threshold": payload.score_threshold, + "embedding_provider_name": payload.embedding_provider_name, + "embedding_model_name": payload.embedding_model_name, + } + result = AppAnnotationService.enable_app_annotation(enable_args, app_model.id) case "disable": result = AppAnnotationService.disable_app_annotation(app_model.id) return result, 200 @@ -135,8 +145,9 @@ class AnnotationListApi(Resource): @validate_app_token def post(self, app_model: App): """Create a new annotation.""" - args = AnnotationCreatePayload.model_validate(service_api_ns.payload or {}).model_dump() - annotation = AppAnnotationService.insert_app_annotation_directly(args, app_model.id) + payload = AnnotationCreatePayload.model_validate(service_api_ns.payload or {}) + insert_args: InsertAnnotationArgs = {"question": payload.question, "answer": payload.answer} + annotation = AppAnnotationService.insert_app_annotation_directly(insert_args, app_model.id) response = Annotation.model_validate(annotation, from_attributes=True) return response.model_dump(mode="json"), HTTPStatus.CREATED @@ -164,8 +175,9 @@ class AnnotationUpdateDeleteApi(Resource): @edit_permission_required def put(self, app_model: App, annotation_id: str): """Update an existing annotation.""" - args = AnnotationCreatePayload.model_validate(service_api_ns.payload or {}).model_dump() - annotation = AppAnnotationService.update_app_annotation_directly(args, app_model.id, annotation_id) + payload = AnnotationCreatePayload.model_validate(service_api_ns.payload or {}) + update_args: UpdateAnnotationArgs = {"question": payload.question, "answer": payload.answer} + annotation = AppAnnotationService.update_app_annotation_directly(update_args, app_model.id, annotation_id) response = Annotation.model_validate(annotation, from_attributes=True) return response.model_dump(mode="json") diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index 9f1ce17ed9..db34aa408e 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -527,7 +527,7 @@ class DocumentListApi(DatasetApiResource): if not dataset: raise NotFound("Dataset not found.") - query = select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=tenant_id) + query = select(Document).where(Document.dataset_id == dataset_id, Document.tenant_id == tenant_id) if query_params.status: query = DocumentService.apply_display_status_filter(query, query_params.status) diff --git a/api/controllers/web/forgot_password.py b/api/controllers/web/forgot_password.py index 80c3289fb4..61fd794c22 100644 --- a/api/controllers/web/forgot_password.py +++ b/api/controllers/web/forgot_password.py @@ -3,7 +3,6 @@ import secrets from flask import request from flask_restx import Resource -from sqlalchemy.orm import sessionmaker from controllers.common.schema import register_schema_models from controllers.console.auth.error import ( @@ -62,9 +61,7 @@ class ForgotPasswordSendEmailApi(Resource): else: language = "en-US" - with sessionmaker(db.engine).begin() as session: - account = AccountService.get_account_by_email_with_case_fallback(request_email, session=session) - token = None + account = AccountService.get_account_by_email_with_case_fallback(request_email) if account is None: raise AuthenticationFailedError() else: @@ -161,13 +158,14 @@ class ForgotPasswordResetApi(Resource): email = reset_data.get("email", "") - with sessionmaker(db.engine).begin() as session: - account = AccountService.get_account_by_email_with_case_fallback(email, session=session) + account = AccountService.get_account_by_email_with_case_fallback(email) - if account: - self._update_existing_account(account, password_hashed, salt) - else: - raise AuthenticationFailedError() + if account: + account = db.session.merge(account) + self._update_existing_account(account, password_hashed, salt) + db.session.commit() + else: + raise AuthenticationFailedError() return {"result": "success"} diff --git a/api/controllers/web/human_input_form.py b/api/controllers/web/human_input_form.py index 36728a47d1..aff0b42d95 100644 --- a/api/controllers/web/human_input_form.py +++ b/api/controllers/web/human_input_form.py @@ -7,7 +7,8 @@ import logging from datetime import datetime from flask import Response, request -from flask_restx import Resource, reqparse +from flask_restx import Resource +from pydantic import BaseModel from sqlalchemy import select from werkzeug.exceptions import Forbidden @@ -23,6 +24,12 @@ from services.human_input_service import Form, FormNotFoundError, HumanInputServ logger = logging.getLogger(__name__) + +class HumanInputFormSubmitPayload(BaseModel): + inputs: dict + action: str + + _FORM_SUBMIT_RATE_LIMITER = RateLimiter( prefix="web_form_submit_rate_limit", max_attempts=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_MAX_ATTEMPTS, @@ -112,10 +119,7 @@ class HumanInputFormApi(Resource): "action": "Approve" } """ - parser = reqparse.RequestParser() - parser.add_argument("inputs", type=dict, required=True, location="json") - parser.add_argument("action", type=str, required=True, location="json") - args = parser.parse_args() + payload = HumanInputFormSubmitPayload.model_validate(request.get_json()) ip_address = extract_remote_ip(request) if _FORM_SUBMIT_RATE_LIMITER.is_rate_limited(ip_address): @@ -135,8 +139,8 @@ class HumanInputFormApi(Resource): service.submit_form_by_token( recipient_type=recipient_type, form_token=form_token, - selected_action_id=args["action"], - form_data=args["inputs"], + selected_action_id=payload.action, + form_data=payload.inputs, submission_end_user_id=None, # submission_end_user_id=_end_user.id, ) diff --git a/api/core/agent/cot_agent_runner.py b/api/core/agent/cot_agent_runner.py index 11e2aa062d..f07ac64498 100644 --- a/api/core/agent/cot_agent_runner.py +++ b/api/core/agent/cot_agent_runner.py @@ -2,7 +2,7 @@ import json import logging from abc import ABC, abstractmethod from collections.abc import Generator, Mapping, Sequence -from typing import Any +from typing import Any, TypedDict from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage from graphon.model_runtime.entities.message_entities import ( @@ -29,6 +29,13 @@ from models.model import Message logger = logging.getLogger(__name__) +class ActionDict(TypedDict): + """Shape produced by AgentScratchpadUnit.Action.to_dict().""" + + action: str + action_input: dict[str, Any] | str + + class CotAgentRunner(BaseAgentRunner, ABC): _is_first_iteration = True _ignore_observation_providers = ["wenxin"] @@ -331,7 +338,7 @@ class CotAgentRunner(BaseAgentRunner, ABC): return tool_invoke_response, tool_invoke_meta - def _convert_dict_to_action(self, action: dict) -> AgentScratchpadUnit.Action: + def _convert_dict_to_action(self, action: ActionDict) -> AgentScratchpadUnit.Action: """ convert dict to action """ diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py index 0cdbb5f50a..a3fb7b4c5d 100644 --- a/api/core/app/entities/app_invoke_entities.py +++ b/api/core/app/entities/app_invoke_entities.py @@ -1,6 +1,6 @@ from collections.abc import Mapping, Sequence from enum import StrEnum -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any from graphon.file import File, FileUploadConfig from graphon.model_runtime.entities.model_entities import AIModelEntity @@ -131,7 +131,7 @@ class AppGenerateEntity(BaseModel): extras: dict[str, Any] = Field(default_factory=dict) # tracing instance - trace_manager: Optional["TraceQueueManager"] = Field(default=None, exclude=True, repr=False) + trace_manager: "TraceQueueManager | None" = Field(default=None, exclude=True, repr=False) class EasyUIBasedAppGenerateEntity(AppGenerateEntity): diff --git a/api/core/datasource/entities/api_entities.py b/api/core/datasource/entities/api_entities.py index 14d1af2e8b..890f1ca319 100644 --- a/api/core/datasource/entities/api_entities.py +++ b/api/core/datasource/entities/api_entities.py @@ -1,10 +1,10 @@ -from typing import Literal, Optional +from typing import Any, Literal, TypedDict from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, Field, field_validator from core.datasource.entities.datasource_entities import DatasourceParameter -from core.tools.entities.common_entities import I18nObject +from core.tools.entities.common_entities import I18nObject, I18nObjectDict class DatasourceApiEntity(BaseModel): @@ -17,7 +17,24 @@ class DatasourceApiEntity(BaseModel): output_schema: dict | None = None -ToolProviderTypeApiLiteral = Optional[Literal["builtin", "api", "workflow"]] +ToolProviderTypeApiLiteral = Literal["builtin", "api", "workflow"] | None + + +class DatasourceProviderApiEntityDict(TypedDict): + id: str + author: str + name: str + plugin_id: str | None + plugin_unique_identifier: str | None + description: I18nObjectDict + icon: str | dict + label: I18nObjectDict + type: str + team_credentials: dict | None + is_team_authorization: bool + allow_delete: bool + datasources: list[Any] + labels: list[str] class DatasourceProviderApiEntity(BaseModel): @@ -42,7 +59,7 @@ class DatasourceProviderApiEntity(BaseModel): def convert_none_to_empty_list(cls, v): return v if v is not None else [] - def to_dict(self) -> dict: + def to_dict(self) -> DatasourceProviderApiEntityDict: # ------------- # overwrite datasource parameter types for temp fix datasources = jsonable_encoder(self.datasources) @@ -53,7 +70,7 @@ class DatasourceProviderApiEntity(BaseModel): parameter["type"] = "files" # ------------- - return { + result: DatasourceProviderApiEntityDict = { "id": self.id, "author": self.author, "name": self.name, @@ -69,3 +86,4 @@ class DatasourceProviderApiEntity(BaseModel): "datasources": datasources, "labels": self.labels, } + return result diff --git a/api/core/datasource/entities/datasource_entities.py b/api/core/datasource/entities/datasource_entities.py index f20bab53f0..01f87b67f8 100644 --- a/api/core/datasource/entities/datasource_entities.py +++ b/api/core/datasource/entities/datasource_entities.py @@ -2,7 +2,7 @@ from __future__ import annotations import enum from enum import StrEnum -from typing import Any +from typing import Any, TypedDict from pydantic import BaseModel, Field, ValidationInfo, field_validator from yarl import URL @@ -179,6 +179,12 @@ class DatasourceProviderEntityWithPlugin(DatasourceProviderEntity): datasources: list[DatasourceEntity] = Field(default_factory=list) +class DatasourceInvokeMetaDict(TypedDict): + time_cost: float + error: str | None + tool_config: dict[str, Any] | None + + class DatasourceInvokeMeta(BaseModel): """ Datasource invoke meta @@ -202,12 +208,13 @@ class DatasourceInvokeMeta(BaseModel): """ return cls(time_cost=0.0, error=error, tool_config={}) - def to_dict(self) -> dict: - return { + def to_dict(self) -> DatasourceInvokeMetaDict: + result: DatasourceInvokeMetaDict = { "time_cost": self.time_cost, "error": self.error, "tool_config": self.tool_config, } + return result class DatasourceLabel(BaseModel): diff --git a/api/core/datasource/utils/message_transformer.py b/api/core/datasource/utils/message_transformer.py index 04f15dee31..c012e128f4 100644 --- a/api/core/datasource/utils/message_transformer.py +++ b/api/core/datasource/utils/message_transformer.py @@ -71,8 +71,8 @@ class DatasourceFileMessageTransformer: if not isinstance(message.message, DatasourceMessage.BlobMessage): raise ValueError("unexpected message type") - # FIXME: should do a type check here. - assert isinstance(message.message.blob, bytes) + if not isinstance(message.message.blob, bytes): + raise TypeError(f"Expected blob to be bytes, got {type(message.message.blob).__name__}") tool_file_manager = ToolFileManager() blob_tool_file: ToolFile | None = tool_file_manager.create_file_by_raw( user_id=user_id, diff --git a/api/core/extension/extensible.py b/api/core/extension/extensible.py index c2789a7a35..b79dbeb7e0 100644 --- a/api/core/extension/extensible.py +++ b/api/core/extension/extensible.py @@ -32,9 +32,9 @@ class Extensible: name: str tenant_id: str - config: dict | None = None + config: dict[str, Any] | None = None - def __init__(self, tenant_id: str, config: dict | None = None): + def __init__(self, tenant_id: str, config: dict[str, Any] | None = None): self.tenant_id = tenant_id self.config = config diff --git a/api/core/external_data_tool/api/api.py b/api/core/external_data_tool/api/api.py index 564801f189..8ce068cfbb 100644 --- a/api/core/external_data_tool/api/api.py +++ b/api/core/external_data_tool/api/api.py @@ -1,3 +1,6 @@ +from collections.abc import Mapping +from typing import Any, TypedDict + from sqlalchemy import select from core.extension.api_based_extension_requestor import APIBasedExtensionRequestor @@ -7,6 +10,16 @@ from extensions.ext_database import db from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint +class ApiToolConfig(TypedDict, total=False): + """Expected config shape for ApiExternalDataTool. + + Not used directly in method signatures (base class accepts dict[str, Any]); + kept here to document the keys this tool reads from config. + """ + + api_based_extension_id: str + + class ApiExternalDataTool(ExternalDataTool): """ The api external data tool. @@ -16,7 +29,7 @@ class ApiExternalDataTool(ExternalDataTool): """the unique name of external data tool""" @classmethod - def validate_config(cls, tenant_id: str, config: dict): + def validate_config(cls, tenant_id: str, config: dict[str, Any]): """ Validate the incoming form config data. @@ -37,7 +50,7 @@ class ApiExternalDataTool(ExternalDataTool): if not api_based_extension: raise ValueError("api_based_extension_id is invalid") - def query(self, inputs: dict, query: str | None = None) -> str: + def query(self, inputs: Mapping[str, Any], query: str | None = None) -> str: """ Query the external data tool. diff --git a/api/core/external_data_tool/base.py b/api/core/external_data_tool/base.py index cbec2e4e42..12bea4e9e5 100644 --- a/api/core/external_data_tool/base.py +++ b/api/core/external_data_tool/base.py @@ -1,4 +1,6 @@ from abc import ABC, abstractmethod +from collections.abc import Mapping +from typing import Any from core.extension.extensible import Extensible, ExtensionModule @@ -15,14 +17,14 @@ class ExternalDataTool(Extensible, ABC): variable: str """the tool variable name of app tool""" - def __init__(self, tenant_id: str, app_id: str, variable: str, config: dict | None = None): + def __init__(self, tenant_id: str, app_id: str, variable: str, config: dict[str, Any] | None = None): super().__init__(tenant_id, config) self.app_id = app_id self.variable = variable @classmethod @abstractmethod - def validate_config(cls, tenant_id: str, config: dict): + def validate_config(cls, tenant_id: str, config: dict[str, Any]): """ Validate the incoming form config data. @@ -33,7 +35,7 @@ class ExternalDataTool(Extensible, ABC): raise NotImplementedError @abstractmethod - def query(self, inputs: dict, query: str | None = None) -> str: + def query(self, inputs: Mapping[str, Any], query: str | None = None) -> str: """ Query the external data tool. diff --git a/api/core/logging/structured_formatter.py b/api/core/logging/structured_formatter.py index 9baf6c4682..ae7be91c17 100644 --- a/api/core/logging/structured_formatter.py +++ b/api/core/logging/structured_formatter.py @@ -3,7 +3,7 @@ import logging import traceback from datetime import UTC, datetime -from typing import Any, TypedDict +from typing import Any, NotRequired, TypedDict import orjson @@ -16,6 +16,19 @@ class IdentityDict(TypedDict, total=False): user_type: str +class LogDict(TypedDict): + ts: str + severity: str + service: str + caller: str + message: str + trace_id: NotRequired[str] + span_id: NotRequired[str] + identity: NotRequired[IdentityDict] + attributes: NotRequired[dict[str, Any]] + stack_trace: NotRequired[str] + + class StructuredJSONFormatter(logging.Formatter): """ JSON log formatter following the specified schema: @@ -55,9 +68,9 @@ class StructuredJSONFormatter(logging.Formatter): return json.dumps(log_dict, default=str, ensure_ascii=False) - def _build_log_dict(self, record: logging.LogRecord) -> dict[str, Any]: + def _build_log_dict(self, record: logging.LogRecord) -> LogDict: # Core fields - log_dict: dict[str, Any] = { + log_dict: LogDict = { "ts": datetime.now(UTC).isoformat(timespec="milliseconds").replace("+00:00", "Z"), "severity": self.SEVERITY_MAP.get(record.levelno, "INFO"), "service": self._service_name, diff --git a/api/core/mcp/auth/auth_flow.py b/api/core/mcp/auth/auth_flow.py index d015769b54..1d8356acf6 100644 --- a/api/core/mcp/auth/auth_flow.py +++ b/api/core/mcp/auth/auth_flow.py @@ -146,7 +146,7 @@ def discover_protected_resource_metadata( return ProtectedResourceMetadata.model_validate(response.json()) elif response.status_code == 404: continue # Try next URL - except (RequestError, ValidationError): + except (RequestError, ValidationError, json.JSONDecodeError): continue # Try next URL return None @@ -166,7 +166,7 @@ def discover_oauth_authorization_server_metadata( return OAuthMetadata.model_validate(response.json()) elif response.status_code == 404: continue # Try next URL - except (RequestError, ValidationError): + except (RequestError, ValidationError, json.JSONDecodeError): continue # Try next URL return None @@ -276,7 +276,7 @@ def check_support_resource_discovery(server_url: str) -> tuple[bool, str]: else: return False, "" return False, "" - except RequestError: + except (RequestError, json.JSONDecodeError, IndexError): # Not support resource discovery, fall back to well-known OAuth metadata return False, "" diff --git a/api/core/mcp/auth_client.py b/api/core/mcp/auth_client.py index d8724b8de5..173913196e 100644 --- a/api/core/mcp/auth_client.py +++ b/api/core/mcp/auth_client.py @@ -122,7 +122,7 @@ class MCPClientWithAuthRetry(MCPClient): logger.exception("Authentication retry failed") raise MCPAuthError(f"Authentication retry failed: {e}") from e - def _execute_with_retry(self, func: Callable[..., Any], *args, **kwargs) -> Any: + def _execute_with_retry[**P, R](self, func: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R: """ Execute a function with authentication retry logic. diff --git a/api/core/mcp/entities.py b/api/core/mcp/entities.py index d6d3a677c6..21edc86a57 100644 --- a/api/core/mcp/entities.py +++ b/api/core/mcp/entities.py @@ -1,6 +1,6 @@ from dataclasses import dataclass from enum import StrEnum -from typing import Any, TypeVar +from typing import Any from pydantic import BaseModel @@ -9,12 +9,9 @@ from core.mcp.types import LATEST_PROTOCOL_VERSION, OAuthClientInformation, OAut SUPPORTED_PROTOCOL_VERSIONS: list[str] = ["2024-11-05", "2025-03-26", LATEST_PROTOCOL_VERSION] -SessionT = TypeVar("SessionT", bound=BaseSession[Any, Any, Any, Any, Any]) -LifespanContextT = TypeVar("LifespanContextT") - @dataclass -class RequestContext[SessionT: BaseSession[Any, Any, Any, Any, Any], LifespanContextT]: +class RequestContext[SessionT: BaseSession, LifespanContextT]: request_id: RequestId meta: RequestParams.Meta | None session: SessionT diff --git a/api/core/mcp/session/base_session.py b/api/core/mcp/session/base_session.py index 0b3aa79838..70d45b15c4 100644 --- a/api/core/mcp/session/base_session.py +++ b/api/core/mcp/session/base_session.py @@ -55,7 +55,7 @@ class RequestResponder[ReceiveRequestT: ClientRequest | ServerRequest, SendResul request: ReceiveRequestT _session: "BaseSession[Any, Any, SendResultT, ReceiveRequestT, Any]" - _on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any] + _on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], object] def __init__( self, @@ -63,7 +63,7 @@ class RequestResponder[ReceiveRequestT: ClientRequest | ServerRequest, SendResul request_meta: RequestParams.Meta | None, request: ReceiveRequestT, session: "BaseSession[Any, Any, SendResultT, ReceiveRequestT, Any]", - on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any], + on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], object], ): self.request_id = request_id self.request_meta = request_meta diff --git a/api/core/mcp/types.py b/api/core/mcp/types.py index 2653d20a7d..10e3082aa3 100644 --- a/api/core/mcp/types.py +++ b/api/core/mcp/types.py @@ -31,7 +31,6 @@ ProgressToken = str | int Cursor = str Role = Literal["user", "assistant"] RequestId = Annotated[int | str, Field(union_mode="left_to_right")] -type AnyFunction = Callable[..., Any] class RequestParams(BaseModel): diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py index 09c84538a9..5809d6f74a 100644 --- a/api/core/memory/token_buffer_memory.py +++ b/api/core/memory/token_buffer_memory.py @@ -61,27 +61,28 @@ class TokenBufferMemory: :param is_user_message: whether this is a user message :return: PromptMessage """ - if self.conversation.mode in {AppMode.AGENT_CHAT, AppMode.COMPLETION, AppMode.CHAT}: - file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config) - elif self.conversation.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: - app = self.conversation.app - if not app: - raise ValueError("App not found for conversation") + match self.conversation.mode: + case AppMode.AGENT_CHAT | AppMode.COMPLETION | AppMode.CHAT: + file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config) + case AppMode.ADVANCED_CHAT | AppMode.WORKFLOW: + app = self.conversation.app + if not app: + raise ValueError("App not found for conversation") - if not message.workflow_run_id: - raise ValueError("Workflow run ID not found") + if not message.workflow_run_id: + raise ValueError("Workflow run ID not found") - workflow_run = self.workflow_run_repo.get_workflow_run_by_id( - tenant_id=app.tenant_id, app_id=app.id, run_id=message.workflow_run_id - ) - if not workflow_run: - raise ValueError(f"Workflow run not found: {message.workflow_run_id}") - workflow = db.session.scalar(select(Workflow).where(Workflow.id == workflow_run.workflow_id)) - if not workflow: - raise ValueError(f"Workflow not found: {workflow_run.workflow_id}") - file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False) - else: - raise AssertionError(f"Invalid app mode: {self.conversation.mode}") + workflow_run = self.workflow_run_repo.get_workflow_run_by_id( + tenant_id=app.tenant_id, app_id=app.id, run_id=message.workflow_run_id + ) + if not workflow_run: + raise ValueError(f"Workflow run not found: {message.workflow_run_id}") + workflow = db.session.scalar(select(Workflow).where(Workflow.id == workflow_run.workflow_id)) + if not workflow: + raise ValueError(f"Workflow not found: {workflow_run.workflow_id}") + file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False) + case _: + raise AssertionError(f"Invalid app mode: {self.conversation.mode}") detail = ImagePromptMessageContent.DETAIL.HIGH if file_extra_config and app_record: diff --git a/api/core/model_manager.py b/api/core/model_manager.py index 7a214777bc..86d042de3e 100644 --- a/api/core/model_manager.py +++ b/api/core/model_manager.py @@ -6,7 +6,7 @@ from graphon.model_runtime.callbacks.base_callback import Callback from graphon.model_runtime.entities.llm_entities import LLMResult from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelFeature, ModelType -from graphon.model_runtime.entities.rerank_entities import RerankResult +from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeConnectionError, InvokeRateLimitError from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel @@ -172,10 +172,10 @@ class ModelInstance: function=self.model_type_instance.invoke, model=self.model_name, credentials=self.credentials, - prompt_messages=prompt_messages, + prompt_messages=list(prompt_messages), model_parameters=model_parameters, - tools=tools, - stop=stop, + tools=list(tools) if tools else None, + stop=list(stop) if stop else None, stream=stream, callbacks=callbacks, ), @@ -193,15 +193,12 @@ class ModelInstance: """ if not isinstance(self.model_type_instance, LargeLanguageModel): raise Exception("Model type instance is not LargeLanguageModel") - return cast( - int, - self._round_robin_invoke( - function=self.model_type_instance.get_num_tokens, - model=self.model_name, - credentials=self.credentials, - prompt_messages=prompt_messages, - tools=tools, - ), + return self._round_robin_invoke( + function=self.model_type_instance.get_num_tokens, + model=self.model_name, + credentials=self.credentials, + prompt_messages=list(prompt_messages), + tools=list(tools) if tools else None, ) def invoke_text_embedding( @@ -216,15 +213,12 @@ class ModelInstance: """ if not isinstance(self.model_type_instance, TextEmbeddingModel): raise Exception("Model type instance is not TextEmbeddingModel") - return cast( - EmbeddingResult, - self._round_robin_invoke( - function=self.model_type_instance.invoke, - model=self.model_name, - credentials=self.credentials, - texts=texts, - input_type=input_type, - ), + return self._round_robin_invoke( + function=self.model_type_instance.invoke, + model=self.model_name, + credentials=self.credentials, + texts=texts, + input_type=input_type, ) def invoke_multimodal_embedding( @@ -241,15 +235,12 @@ class ModelInstance: """ if not isinstance(self.model_type_instance, TextEmbeddingModel): raise Exception("Model type instance is not TextEmbeddingModel") - return cast( - EmbeddingResult, - self._round_robin_invoke( - function=self.model_type_instance.invoke, - model=self.model_name, - credentials=self.credentials, - multimodel_documents=multimodel_documents, - input_type=input_type, - ), + return self._round_robin_invoke( + function=self.model_type_instance.invoke, + model=self.model_name, + credentials=self.credentials, + multimodel_documents=multimodel_documents, + input_type=input_type, ) def get_text_embedding_num_tokens(self, texts: list[str]) -> list[int]: @@ -261,14 +252,11 @@ class ModelInstance: """ if not isinstance(self.model_type_instance, TextEmbeddingModel): raise Exception("Model type instance is not TextEmbeddingModel") - return cast( - list[int], - self._round_robin_invoke( - function=self.model_type_instance.get_num_tokens, - model=self.model_name, - credentials=self.credentials, - texts=texts, - ), + return self._round_robin_invoke( + function=self.model_type_instance.get_num_tokens, + model=self.model_name, + credentials=self.credentials, + texts=texts, ) def invoke_rerank( @@ -289,23 +277,20 @@ class ModelInstance: """ if not isinstance(self.model_type_instance, RerankModel): raise Exception("Model type instance is not RerankModel") - return cast( - RerankResult, - self._round_robin_invoke( - function=self.model_type_instance.invoke, - model=self.model_name, - credentials=self.credentials, - query=query, - docs=docs, - score_threshold=score_threshold, - top_n=top_n, - ), + return self._round_robin_invoke( + function=self.model_type_instance.invoke, + model=self.model_name, + credentials=self.credentials, + query=query, + docs=docs, + score_threshold=score_threshold, + top_n=top_n, ) def invoke_multimodal_rerank( self, - query: dict, - docs: list[dict], + query: MultimodalRerankInput, + docs: list[MultimodalRerankInput], score_threshold: float | None = None, top_n: int | None = None, ) -> RerankResult: @@ -320,17 +305,14 @@ class ModelInstance: """ if not isinstance(self.model_type_instance, RerankModel): raise Exception("Model type instance is not RerankModel") - return cast( - RerankResult, - self._round_robin_invoke( - function=self.model_type_instance.invoke_multimodal_rerank, - model=self.model_name, - credentials=self.credentials, - query=query, - docs=docs, - score_threshold=score_threshold, - top_n=top_n, - ), + return self._round_robin_invoke( + function=self.model_type_instance.invoke_multimodal_rerank, + model=self.model_name, + credentials=self.credentials, + query=query, + docs=docs, + score_threshold=score_threshold, + top_n=top_n, ) def invoke_moderation(self, text: str) -> bool: @@ -342,14 +324,11 @@ class ModelInstance: """ if not isinstance(self.model_type_instance, ModerationModel): raise Exception("Model type instance is not ModerationModel") - return cast( - bool, - self._round_robin_invoke( - function=self.model_type_instance.invoke, - model=self.model_name, - credentials=self.credentials, - text=text, - ), + return self._round_robin_invoke( + function=self.model_type_instance.invoke, + model=self.model_name, + credentials=self.credentials, + text=text, ) def invoke_speech2text(self, file: IO[bytes]) -> str: @@ -361,14 +340,11 @@ class ModelInstance: """ if not isinstance(self.model_type_instance, Speech2TextModel): raise Exception("Model type instance is not Speech2TextModel") - return cast( - str, - self._round_robin_invoke( - function=self.model_type_instance.invoke, - model=self.model_name, - credentials=self.credentials, - file=file, - ), + return self._round_robin_invoke( + function=self.model_type_instance.invoke, + model=self.model_name, + credentials=self.credentials, + file=file, ) def invoke_tts(self, content_text: str, voice: str = "") -> Iterable[bytes]: @@ -381,18 +357,15 @@ class ModelInstance: """ if not isinstance(self.model_type_instance, TTSModel): raise Exception("Model type instance is not TTSModel") - return cast( - Iterable[bytes], - self._round_robin_invoke( - function=self.model_type_instance.invoke, - model=self.model_name, - credentials=self.credentials, - content_text=content_text, - voice=voice, - ), + return self._round_robin_invoke( + function=self.model_type_instance.invoke, + model=self.model_name, + credentials=self.credentials, + content_text=content_text, + voice=voice, ) - def _round_robin_invoke(self, function: Callable[..., Any], *args, **kwargs): + def _round_robin_invoke[**P, R](self, function: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R: """ Round-robin invoke :param function: function to invoke @@ -430,9 +403,8 @@ class ModelInstance: continue try: - if "credentials" in kwargs: - del kwargs["credentials"] - return function(*args, **kwargs, credentials=lb_config.credentials) + kwargs["credentials"] = lb_config.credentials + return function(*args, **kwargs) except InvokeRateLimitError as e: # expire in 60 seconds self.load_balancing_manager.cooldown(lb_config, expire=60) diff --git a/api/core/plugin/backwards_invocation/app.py b/api/core/plugin/backwards_invocation/app.py index e2d2be92cb..c76cb865c3 100644 --- a/api/core/plugin/backwards_invocation/app.py +++ b/api/core/plugin/backwards_invocation/app.py @@ -1,6 +1,6 @@ import uuid from collections.abc import Generator, Mapping -from typing import Any, Union, cast +from typing import Any, cast from sqlalchemy import select from sqlalchemy.orm import Session @@ -207,7 +207,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation): ) @classmethod - def _get_user(cls, user_id: str) -> Union[EndUser, Account]: + def _get_user(cls, user_id: str) -> EndUser | Account: """ get the user by user id """ diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py index c706353ffe..36fca60db3 100644 --- a/api/core/prompt/simple_prompt_transform.py +++ b/api/core/prompt/simple_prompt_transform.py @@ -2,7 +2,7 @@ import json import os from collections.abc import Mapping, Sequence from enum import StrEnum, auto -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any, TypedDict, cast from graphon.file import file_manager from graphon.model_runtime.entities.message_entities import ( @@ -34,6 +34,13 @@ class ModelMode(StrEnum): prompt_file_contents: dict[str, Any] = {} +class PromptTemplateConfigDict(TypedDict): + prompt_template: PromptTemplateParser + custom_variable_keys: list[str] + special_variable_keys: list[str] + prompt_rules: dict[str, Any] + + class SimplePromptTransform(PromptTransform): """ Simple Prompt Transform for Chatbot App Basic Mode. @@ -105,18 +112,13 @@ class SimplePromptTransform(PromptTransform): with_memory_prompt=histories is not None, ) - custom_variable_keys_obj = prompt_template_config["custom_variable_keys"] - special_variable_keys_obj = prompt_template_config["special_variable_keys"] + custom_variable_keys = prompt_template_config["custom_variable_keys"] + if not isinstance(custom_variable_keys, list): + raise TypeError(f"Expected list for custom_variable_keys, got {type(custom_variable_keys)}") - # Type check for custom_variable_keys - if not isinstance(custom_variable_keys_obj, list): - raise TypeError(f"Expected list for custom_variable_keys, got {type(custom_variable_keys_obj)}") - custom_variable_keys = cast(list[str], custom_variable_keys_obj) - - # Type check for special_variable_keys - if not isinstance(special_variable_keys_obj, list): - raise TypeError(f"Expected list for special_variable_keys, got {type(special_variable_keys_obj)}") - special_variable_keys = cast(list[str], special_variable_keys_obj) + special_variable_keys = prompt_template_config["special_variable_keys"] + if not isinstance(special_variable_keys, list): + raise TypeError(f"Expected list for special_variable_keys, got {type(special_variable_keys)}") variables = {k: inputs[k] for k in custom_variable_keys if k in inputs} @@ -150,7 +152,7 @@ class SimplePromptTransform(PromptTransform): has_context: bool, query_in_prompt: bool, with_memory_prompt: bool = False, - ) -> dict[str, object]: + ) -> PromptTemplateConfigDict: prompt_rules = self._get_prompt_rule(app_mode=app_mode, provider=provider, model=model) custom_variable_keys: list[str] = [] @@ -173,12 +175,13 @@ class SimplePromptTransform(PromptTransform): prompt += prompt_rules.get("query_prompt", "{{#query#}}") special_variable_keys.append("#query#") - return { + result: PromptTemplateConfigDict = { "prompt_template": PromptTemplateParser(template=prompt), "custom_variable_keys": custom_variable_keys, "special_variable_keys": special_variable_keys, "prompt_rules": prompt_rules, } + return result def _get_chat_model_prompt_messages( self, diff --git a/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py b/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py index df02c584ed..90d6d98c63 100644 --- a/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py +++ b/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py @@ -5,6 +5,7 @@ from typing import Any from elasticsearch import Elasticsearch from pydantic import BaseModel, model_validator +from typing_extensions import TypedDict from configs import dify_config from core.rag.datasource.vdb.field import Field @@ -19,6 +20,16 @@ from models.dataset import Dataset logger = logging.getLogger(__name__) +class HuaweiElasticsearchParamsDict(TypedDict, total=False): + hosts: list[str] + verify_certs: bool + ssl_show_warn: bool + request_timeout: int + retry_on_timeout: bool + max_retries: int + basic_auth: tuple[str, str] + + def create_ssl_context() -> ssl.SSLContext: ssl_context = ssl.create_default_context() ssl_context.check_hostname = False @@ -38,15 +49,15 @@ class HuaweiCloudVectorConfig(BaseModel): raise ValueError("config HOSTS is required") return values - def to_elasticsearch_params(self) -> dict[str, Any]: - params = { - "hosts": self.hosts.split(","), - "verify_certs": False, - "ssl_show_warn": False, - "request_timeout": 30000, - "retry_on_timeout": True, - "max_retries": 10, - } + def to_elasticsearch_params(self) -> HuaweiElasticsearchParamsDict: + params = HuaweiElasticsearchParamsDict( + hosts=self.hosts.split(","), + verify_certs=False, + ssl_show_warn=False, + request_timeout=30000, + retry_on_timeout=True, + max_retries=10, + ) if self.username and self.password: params["basic_auth"] = (self.username, self.password) return params diff --git a/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py b/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py index bfcb620618..fbe0bcad02 100644 --- a/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py +++ b/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py @@ -7,6 +7,7 @@ from opensearchpy import OpenSearch, helpers from opensearchpy.helpers import BulkIndexError from pydantic import BaseModel, model_validator from tenacity import retry, stop_after_attempt, wait_exponential +from typing_extensions import TypedDict from configs import dify_config from core.rag.datasource.vdb.field import Field @@ -26,6 +27,14 @@ ROUTING_FIELD = "routing_field" UGC_INDEX_PREFIX = "ugc_index" +class LindormOpenSearchParamsDict(TypedDict, total=False): + hosts: str | None + use_ssl: bool + pool_maxsize: int + timeout: int + http_auth: tuple[str, str] + + class LindormVectorStoreConfig(BaseModel): hosts: str | None username: str | None = None @@ -44,13 +53,13 @@ class LindormVectorStoreConfig(BaseModel): raise ValueError("config PASSWORD is required") return values - def to_opensearch_params(self) -> dict[str, Any]: - params: dict[str, Any] = { - "hosts": self.hosts, - "use_ssl": False, - "pool_maxsize": 128, - "timeout": 30, - } + def to_opensearch_params(self) -> LindormOpenSearchParamsDict: + params = LindormOpenSearchParamsDict( + hosts=self.hosts, + use_ssl=False, + pool_maxsize=128, + timeout=30, + ) if self.username and self.password: params["http_auth"] = (self.username, self.password) return params diff --git a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py index 2f77776807..50d18cdc4c 100644 --- a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py +++ b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py @@ -6,6 +6,7 @@ from uuid import uuid4 from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, helpers from opensearchpy.helpers import BulkIndexError from pydantic import BaseModel, model_validator +from typing_extensions import TypedDict from configs import dify_config from configs.middleware.vdb.opensearch_config import AuthMethod @@ -21,6 +22,20 @@ from models.dataset import Dataset logger = logging.getLogger(__name__) +class _OpenSearchHostDict(TypedDict): + host: str + port: int + + +class OpenSearchParamsDict(TypedDict, total=False): + hosts: list[_OpenSearchHostDict] + use_ssl: bool + verify_certs: bool + connection_class: type + pool_maxsize: int + http_auth: tuple[str | None, str | None] | Urllib3AWSV4SignerAuth + + class OpenSearchConfig(BaseModel): host: str port: int @@ -57,14 +72,14 @@ class OpenSearchConfig(BaseModel): service=self.aws_service, # type: ignore[arg-type] ) - def to_opensearch_params(self) -> dict[str, Any]: - params = { - "hosts": [{"host": self.host, "port": self.port}], - "use_ssl": self.secure, - "verify_certs": self.verify_certs, - "connection_class": Urllib3HttpConnection, - "pool_maxsize": 20, - } + def to_opensearch_params(self) -> OpenSearchParamsDict: + params = OpenSearchParamsDict( + hosts=[{"host": self.host, "port": self.port}], + use_ssl=self.secure, + verify_certs=self.verify_certs, + connection_class=Urllib3HttpConnection, + pool_maxsize=20, + ) if self.auth_method == "basic": logger.info("Using basic authentication for OpenSearch Vector DB") diff --git a/api/core/rag/datasource/vdb/relyt/relyt_vector.py b/api/core/rag/datasource/vdb/relyt/relyt_vector.py index 3ecc9867fa..64b45bf28b 100644 --- a/api/core/rag/datasource/vdb/relyt/relyt_vector.py +++ b/api/core/rag/datasource/vdb/relyt/relyt_vector.py @@ -7,7 +7,7 @@ from pydantic import BaseModel, model_validator from sqlalchemy import Column, String, Table, create_engine, insert from sqlalchemy import text as sql_text from sqlalchemy.dialects.postgresql import JSON, TEXT -from sqlalchemy.orm import Session +from sqlalchemy.orm import Session, sessionmaker from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType @@ -79,7 +79,7 @@ class RelytVector(BaseVector): if redis_client.get(collection_exist_cache_key): return index_name = f"{self._collection_name}_embedding_index" - with Session(self.client) as session: + with sessionmaker(bind=self.client).begin() as session: drop_statement = sql_text(f"""DROP TABLE IF EXISTS "{self._collection_name}"; """) session.execute(drop_statement) create_statement = sql_text(f""" @@ -104,7 +104,6 @@ class RelytVector(BaseVector): $$); """) session.execute(index_statement) - session.commit() redis_client.set(collection_exist_cache_key, 1, ex=3600) def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): @@ -208,9 +207,8 @@ class RelytVector(BaseVector): self.delete_by_uuids(ids) def delete(self): - with Session(self.client) as session: + with sessionmaker(bind=self.client).begin() as session: session.execute(sql_text(f"""DROP TABLE IF EXISTS "{self._collection_name}";""")) - session.commit() def text_exists(self, id: str) -> bool: with Session(self.client) as session: diff --git a/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py b/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py index c948917374..e321681093 100644 --- a/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py +++ b/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py @@ -6,7 +6,7 @@ import sqlalchemy from pydantic import BaseModel, model_validator from sqlalchemy import JSON, TEXT, Column, DateTime, String, Table, create_engine, insert from sqlalchemy import text as sql_text -from sqlalchemy.orm import Session, declarative_base +from sqlalchemy.orm import Session, declarative_base, sessionmaker from configs import dify_config from core.rag.datasource.vdb.field import Field, parse_metadata_json @@ -97,8 +97,7 @@ class TiDBVector(BaseVector): if redis_client.get(collection_exist_cache_key): return tidb_dist_func = self._get_distance_func() - with Session(self._engine) as session: - session.begin() + with sessionmaker(bind=self._engine).begin() as session: create_statement = sql_text(f""" CREATE TABLE IF NOT EXISTS {self._collection_name} ( id CHAR(36) PRIMARY KEY, @@ -115,7 +114,6 @@ class TiDBVector(BaseVector): ); """) session.execute(create_statement) - session.commit() redis_client.set(collection_exist_cache_key, 1, ex=3600) def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): @@ -238,9 +236,8 @@ class TiDBVector(BaseVector): return [] def delete(self): - with Session(self._engine) as session: + with sessionmaker(bind=self._engine).begin() as session: session.execute(sql_text(f"""DROP TABLE IF EXISTS {self._collection_name};""")) - session.commit() def _get_distance_func(self) -> str: match self._distance_func: diff --git a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py index 0ef88e1010..5d879ac3ca 100644 --- a/api/core/rag/datasource/vdb/vector_factory.py +++ b/api/core/rag/datasource/vdb/vector_factory.py @@ -41,7 +41,23 @@ class AbstractVectorFactory(ABC): class Vector: def __init__(self, dataset: Dataset, attributes: list | None = None): if attributes is None: - attributes = ["doc_id", "dataset_id", "document_id", "doc_hash", "doc_type"] + # `is_summary` and `original_chunk_id` are stored on summary vectors + # by `SummaryIndexService` and read back by `RetrievalService` to + # route summary hits through their original parent chunks. They + # must be listed here so vector backends that use this list as an + # explicit return-properties projection (notably Weaviate) actually + # return those fields; without them, summary hits silently + # collapse into `is_summary = False` branches and the summary + # retrieval path is a no-op. See #34884. + attributes = [ + "doc_id", + "dataset_id", + "document_id", + "doc_hash", + "doc_type", + "is_summary", + "original_chunk_id", + ] self._dataset = dataset self._embeddings = self._get_embeddings() self._attributes = attributes diff --git a/api/core/rag/index_processor/index_processor.py b/api/core/rag/index_processor/index_processor.py index 813a84cbbd..aded5315bd 100644 --- a/api/core/rag/index_processor/index_processor.py +++ b/api/core/rag/index_processor/index_processor.py @@ -6,7 +6,7 @@ from collections.abc import Mapping from typing import Any from flask import current_app -from sqlalchemy import delete, func, select +from sqlalchemy import delete, func, select, update from core.db.session_factory import session_factory from core.rag.index_processor.constant.index_type import IndexTechniqueType @@ -63,11 +63,11 @@ class IndexProcessor: summary_index_setting: SummaryIndexSettingDict | None = None, ) -> IndexingResultDict: with session_factory.create_session() as session: - document = session.query(Document).filter_by(id=document_id).first() + document = session.scalar(select(Document).where(Document.id == document_id).limit(1)) if not document: raise KnowledgeIndexNodeError(f"Document {document_id} not found.") - dataset = session.query(Dataset).filter_by(id=dataset_id).first() + dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1)) if not dataset: raise KnowledgeIndexNodeError(f"Dataset {dataset_id} not found.") @@ -104,12 +104,12 @@ class IndexProcessor: document.indexing_status = "completed" document.completed_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) document.word_count = ( - session.query(func.sum(DocumentSegment.word_count)) - .where( - DocumentSegment.document_id == document_id, - DocumentSegment.dataset_id == dataset_id, + session.scalar( + select(func.sum(DocumentSegment.word_count)).where( + DocumentSegment.document_id == document_id, + DocumentSegment.dataset_id == dataset_id, + ) ) - .scalar() ) or 0 # Update need_summary based on dataset's summary_index_setting if summary_index_setting and summary_index_setting.get("enable") is True: @@ -118,15 +118,17 @@ class IndexProcessor: document.need_summary = False session.add(document) # update document segment status - session.query(DocumentSegment).where( - DocumentSegment.document_id == document_id, - DocumentSegment.dataset_id == dataset_id, - ).update( - { - DocumentSegment.status: "completed", - DocumentSegment.enabled: True, - DocumentSegment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None), - } + session.execute( + update(DocumentSegment) + .where( + DocumentSegment.document_id == document_id, + DocumentSegment.dataset_id == dataset_id, + ) + .values( + status="completed", + enabled=True, + completed_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None), + ) ) result: IndexingResultDict = { @@ -151,11 +153,11 @@ class IndexProcessor: doc_language = None with session_factory.create_session() as session: if document_id: - document = session.query(Document).filter_by(id=document_id).first() + document = session.scalar(select(Document).where(Document.id == document_id).limit(1)) else: document = None - dataset = session.query(Dataset).filter_by(id=dataset_id).first() + dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1)) if not dataset: raise KnowledgeIndexNodeError(f"Dataset {dataset_id} not found.") diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py index 4a731bf277..a487c49053 100644 --- a/api/core/rag/index_processor/processor/paragraph_index_processor.py +++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py @@ -3,8 +3,7 @@ import logging import re import uuid -from collections.abc import Mapping -from typing import Any, cast +from typing import Any, TypedDict, cast logger = logging.getLogger(__name__) @@ -55,6 +54,12 @@ from services.summary_index_service import SummaryIndexService _file_access_controller = DatabaseFileAccessController() +class ParagraphFormatPreviewDict(TypedDict): + chunk_structure: str + preview: list[dict[str, Any]] + total_segments: int + + class ParagraphIndexProcessor(BaseIndexProcessor): def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]: text_docs = ExtractProcessor.extract( @@ -266,16 +271,17 @@ class ParagraphIndexProcessor(BaseIndexProcessor): keyword = Keyword(dataset) keyword.add_texts(documents) - def format_preview(self, chunks: Any) -> Mapping[str, Any]: + def format_preview(self, chunks: Any) -> ParagraphFormatPreviewDict: if isinstance(chunks, list): preview = [] for content in chunks: preview.append({"content": content}) - return { + result: ParagraphFormatPreviewDict = { "chunk_structure": IndexStructureType.PARAGRAPH_INDEX, "preview": preview, "total_segments": len(chunks), } + return result else: raise ValueError("Chunks is not a list") diff --git a/api/core/rag/index_processor/processor/parent_child_index_processor.py b/api/core/rag/index_processor/processor/parent_child_index_processor.py index 53596b5de8..ba277d5018 100644 --- a/api/core/rag/index_processor/processor/parent_child_index_processor.py +++ b/api/core/rag/index_processor/processor/parent_child_index_processor.py @@ -3,8 +3,7 @@ import json import logging import uuid -from collections.abc import Mapping -from typing import Any +from typing import Any, TypedDict from sqlalchemy import delete, select @@ -36,6 +35,13 @@ from services.summary_index_service import SummaryIndexService logger = logging.getLogger(__name__) +class ParentChildFormatPreviewDict(TypedDict): + chunk_structure: str + parent_mode: str + preview: list[dict[str, Any]] + total_segments: int + + class ParentChildIndexProcessor(BaseIndexProcessor): def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]: text_docs = ExtractProcessor.extract( @@ -153,14 +159,12 @@ class ParentChildIndexProcessor(BaseIndexProcessor): if node_ids: # Find segments by index_node_id with session_factory.create_session() as session: - segments = ( - session.query(DocumentSegment) - .filter( + segments = session.scalars( + select(DocumentSegment).where( DocumentSegment.dataset_id == dataset.id, DocumentSegment.index_node_id.in_(node_ids), ) - .all() - ) + ).all() segment_ids = [segment.id for segment in segments] if segment_ids: SummaryIndexService.delete_summaries_for_segments(dataset, segment_ids) @@ -351,17 +355,18 @@ class ParentChildIndexProcessor(BaseIndexProcessor): if all_multimodal_documents and dataset.is_multimodal: vector.create_multimodal(all_multimodal_documents) - def format_preview(self, chunks: Any) -> Mapping[str, Any]: + def format_preview(self, chunks: Any) -> ParentChildFormatPreviewDict: parent_childs = ParentChildStructureChunk.model_validate(chunks) preview = [] for parent_child in parent_childs.parent_child_chunks: preview.append({"content": parent_child.parent_content, "child_chunks": parent_child.child_contents}) - return { + result: ParentChildFormatPreviewDict = { "chunk_structure": IndexStructureType.PARENT_CHILD_INDEX, "parent_mode": parent_childs.parent_mode, "preview": preview, "total_segments": len(parent_childs.parent_child_chunks), } + return result def generate_summary_preview( self, diff --git a/api/core/rag/index_processor/processor/qa_index_processor.py b/api/core/rag/index_processor/processor/qa_index_processor.py index 273ea0f852..d3f311b08e 100644 --- a/api/core/rag/index_processor/processor/qa_index_processor.py +++ b/api/core/rag/index_processor/processor/qa_index_processor.py @@ -4,11 +4,11 @@ import logging import re import threading import uuid -from collections.abc import Mapping -from typing import Any +from typing import Any, TypedDict import pandas as pd from flask import Flask, current_app +from sqlalchemy import select from werkzeug.datastructures import FileStorage from core.db.session_factory import session_factory @@ -36,6 +36,12 @@ from services.summary_index_service import SummaryIndexService logger = logging.getLogger(__name__) +class QAFormatPreviewDict(TypedDict): + chunk_structure: str + qa_preview: list[dict[str, Any]] + total_segments: int + + class QAIndexProcessor(BaseIndexProcessor): def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]: text_docs = ExtractProcessor.extract( @@ -158,14 +164,12 @@ class QAIndexProcessor(BaseIndexProcessor): if node_ids: # Find segments by index_node_id with session_factory.create_session() as session: - segments = ( - session.query(DocumentSegment) - .filter( + segments = session.scalars( + select(DocumentSegment).where( DocumentSegment.dataset_id == dataset.id, DocumentSegment.index_node_id.in_(node_ids), ) - .all() - ) + ).all() segment_ids = [segment.id for segment in segments] if segment_ids: SummaryIndexService.delete_summaries_for_segments(dataset, segment_ids) @@ -230,16 +234,17 @@ class QAIndexProcessor(BaseIndexProcessor): else: raise ValueError("Indexing technique must be high quality.") - def format_preview(self, chunks: Any) -> Mapping[str, Any]: + def format_preview(self, chunks: Any) -> QAFormatPreviewDict: qa_chunks = QAStructureChunk.model_validate(chunks) preview = [] for qa_chunk in qa_chunks.qa_chunks: preview.append({"question": qa_chunk.question, "answer": qa_chunk.answer}) - return { + result: QAFormatPreviewDict = { "chunk_structure": IndexStructureType.QA_INDEX, "qa_preview": preview, "total_segments": len(qa_chunks.qa_chunks), } + return result def generate_summary_preview( self, diff --git a/api/core/rag/rerank/rerank_model.py b/api/core/rag/rerank/rerank_model.py index 8283be19f9..a8d37845a5 100644 --- a/api/core/rag/rerank/rerank_model.py +++ b/api/core/rag/rerank/rerank_model.py @@ -1,7 +1,7 @@ import base64 from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.entities.rerank_entities import RerankResult +from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult from core.model_manager import ModelInstance, ModelManager from core.rag.index_processor.constant.doc_type import DocType @@ -123,7 +123,7 @@ class RerankModelRunner(BaseRerankRunner): :param query_type: query type :return: rerank result """ - docs = [] + docs: list[MultimodalRerankInput] = [] doc_ids = set() unique_documents = [] for document in documents: @@ -138,26 +138,28 @@ class RerankModelRunner(BaseRerankRunner): if upload_file: blob = storage.load_once(upload_file.key) document_file_base64 = base64.b64encode(blob).decode() - document_file_dict = { - "content": document_file_base64, - "content_type": document.metadata["doc_type"], - } - docs.append(document_file_dict) + docs.append( + MultimodalRerankInput( + content=document_file_base64, + content_type=document.metadata["doc_type"], + ) + ) else: - document_text_dict = { - "content": document.page_content, - "content_type": document.metadata.get("doc_type") or DocType.TEXT, - } - docs.append(document_text_dict) + docs.append( + MultimodalRerankInput( + content=document.page_content, + content_type=document.metadata.get("doc_type") or DocType.TEXT, + ) + ) doc_ids.add(document.metadata["doc_id"]) unique_documents.append(document) elif document.provider == "external": if document not in unique_documents: docs.append( - { - "content": document.page_content, - "content_type": document.metadata.get("doc_type") or DocType.TEXT, - } + MultimodalRerankInput( + content=document.page_content, + content_type=document.metadata.get("doc_type") or DocType.TEXT, + ) ) unique_documents.append(document) @@ -171,12 +173,12 @@ class RerankModelRunner(BaseRerankRunner): if upload_file: blob = storage.load_once(upload_file.key) file_query = base64.b64encode(blob).decode() - file_query_dict = { - "content": file_query, - "content_type": DocType.IMAGE, - } + file_query_input = MultimodalRerankInput( + content=file_query, + content_type=DocType.IMAGE, + ) rerank_result = self.rerank_model_instance.invoke_multimodal_rerank( - query=file_query_dict, docs=docs, score_threshold=score_threshold, top_n=top_n + query=file_query_input, docs=docs, score_threshold=score_threshold, top_n=top_n ) return rerank_result, unique_documents else: diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 0f3351fd68..b681ff5db1 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -14,7 +14,7 @@ from graphon.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMU from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from sqlalchemy import and_, func, literal, or_, select +from sqlalchemy import and_, func, literal, or_, select, update from sqlalchemy.orm import sessionmaker from core.app.app_config.entities import ( @@ -276,8 +276,8 @@ class DatasetRetrieval: document_ids = [i.segment.document_id for i in records] with session_factory.create_session() as session: - datasets = session.query(Dataset).where(Dataset.id.in_(dataset_ids)).all() - documents = session.query(DatasetDocument).where(DatasetDocument.id.in_(document_ids)).all() + datasets = session.scalars(select(Dataset).where(Dataset.id.in_(dataset_ids))).all() + documents = session.scalars(select(DatasetDocument).where(DatasetDocument.id.in_(document_ids))).all() dataset_map = {i.id: i for i in datasets} document_map = {i.id: i for i in documents} @@ -971,9 +971,11 @@ class DatasetRetrieval: # Batch update hit_count for all segments if segment_ids_to_update: - session.query(DocumentSegment).where(DocumentSegment.id.in_(segment_ids_to_update)).update( - {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, - synchronize_session=False, + session.execute( + update(DocumentSegment) + .where(DocumentSegment.id.in_(segment_ids_to_update)) + .values(hit_count=DocumentSegment.hit_count + 1) + .execution_options(synchronize_session=False) ) self._send_trace_task(message_id, documents, timer) @@ -1822,7 +1824,7 @@ class DatasetRetrieval: def _get_available_datasets(self, tenant_id: str, dataset_ids: list[str]) -> list[Dataset]: with session_factory.create_session() as session: subquery = ( - session.query(DocumentModel.dataset_id, func.count(DocumentModel.id).label("available_document_count")) + select(DocumentModel.dataset_id, func.count(DocumentModel.id).label("available_document_count")) .where( DocumentModel.indexing_status == "completed", DocumentModel.enabled == True, @@ -1834,13 +1836,12 @@ class DatasetRetrieval: .subquery() ) - results = ( - session.query(Dataset) + results = session.scalars( + select(Dataset) .outerjoin(subquery, Dataset.id == subquery.c.dataset_id) .where(Dataset.tenant_id == tenant_id, Dataset.id.in_(dataset_ids)) .where((subquery.c.available_document_count > 0) | (Dataset.provider == "external")) - .all() - ) + ).all() available_datasets = [] for dataset in results: diff --git a/api/core/rag/summary_index/summary_index.py b/api/core/rag/summary_index/summary_index.py index 6f120bd471..bff5f85dec 100644 --- a/api/core/rag/summary_index/summary_index.py +++ b/api/core/rag/summary_index/summary_index.py @@ -1,6 +1,8 @@ import concurrent.futures import logging +from sqlalchemy import select + from core.db.session_factory import session_factory from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict @@ -21,7 +23,7 @@ class SummaryIndex: ) -> None: if is_preview: with session_factory.create_session() as session: - dataset = session.query(Dataset).filter_by(id=dataset_id).first() + dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1)) if not dataset or dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY: return @@ -34,32 +36,31 @@ class SummaryIndex: if not document_id: return - document = session.query(Document).filter_by(id=document_id).first() + document = session.scalar(select(Document).where(Document.id == document_id).limit(1)) # Skip qa_model documents if document is None or document.doc_form == "qa_model": return - query = session.query(DocumentSegment).filter_by( - dataset_id=dataset_id, - document_id=document_id, - status="completed", - enabled=True, - ) - segments = query.all() + segments = session.scalars( + select(DocumentSegment).where( + DocumentSegment.dataset_id == dataset_id, + DocumentSegment.document_id == document_id, + DocumentSegment.status == "completed", + DocumentSegment.enabled == True, + ) + ).all() segment_ids = [segment.id for segment in segments] if not segment_ids: return - existing_summaries = ( - session.query(DocumentSegmentSummary) - .filter( + existing_summaries = session.scalars( + select(DocumentSegmentSummary).where( DocumentSegmentSummary.chunk_id.in_(segment_ids), DocumentSegmentSummary.dataset_id == dataset_id, DocumentSegmentSummary.status == "completed", ) - .all() - ) + ).all() completed_summary_segment_ids = {i.chunk_id for i in existing_summaries} # Preview mode should process segments that are MISSING completed summaries pending_segment_ids = [sid for sid in segment_ids if sid not in completed_summary_segment_ids] @@ -73,7 +74,7 @@ class SummaryIndex: def process_segment(segment_id: str) -> None: """Process a single segment in a thread with a fresh DB session.""" with session_factory.create_session() as session: - segment = session.query(DocumentSegment).filter_by(id=segment_id).first() + segment = session.scalar(select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1)) if segment is None: return try: diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py index 31e879add2..b4253652f9 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -450,6 +450,12 @@ class WorkflowToolParameterConfiguration(BaseModel): form: ToolParameter.ToolParameterForm = Field(..., description="The form of the parameter") +class ToolInvokeMetaDict(TypedDict): + time_cost: float + error: str | None + tool_config: dict[str, Any] | None + + class ToolInvokeMeta(BaseModel): """ Tool invoke meta @@ -473,12 +479,13 @@ class ToolInvokeMeta(BaseModel): """ return cls(time_cost=0.0, error=error, tool_config={}) - def to_dict(self): - return { + def to_dict(self) -> ToolInvokeMetaDict: + result: ToolInvokeMetaDict = { "time_cost": self.time_cost, "error": self.error, "tool_config": self.tool_config, } + return result class ToolLabel(BaseModel): diff --git a/api/core/tools/tool_engine.py b/api/core/tools/tool_engine.py index 685d687d8c..d1e333f502 100644 --- a/api/core/tools/tool_engine.py +++ b/api/core/tools/tool_engine.py @@ -262,6 +262,8 @@ class ToolEngine: ensure_ascii=False, ) ) + elif response.type == ToolInvokeMessage.MessageType.VARIABLE: + continue else: parts.append(str(response.message)) diff --git a/api/core/tools/tool_file_manager.py b/api/core/tools/tool_file_manager.py index a59d167a0a..d8674b3af9 100644 --- a/api/core/tools/tool_file_manager.py +++ b/api/core/tools/tool_file_manager.py @@ -6,7 +6,6 @@ import os import time from collections.abc import Generator from mimetypes import guess_extension, guess_type -from typing import Union from uuid import uuid4 import httpx @@ -158,7 +157,7 @@ class ToolFileManager: return tool_file - def get_file_binary(self, id: str) -> Union[tuple[bytes, str], None]: + def get_file_binary(self, id: str) -> tuple[bytes, str] | None: """ get file binary @@ -176,7 +175,7 @@ class ToolFileManager: return blob, tool_file.mimetype - def get_file_binary_by_message_file_id(self, id: str) -> Union[tuple[bytes, str], None]: + def get_file_binary_by_message_file_id(self, id: str) -> tuple[bytes, str] | None: """ get file binary diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 2593e381cf..be13d40f3e 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -5,7 +5,7 @@ import time from collections.abc import Generator, Mapping from os import listdir, path from threading import Lock -from typing import TYPE_CHECKING, Any, Literal, Optional, Protocol, Union, cast +from typing import TYPE_CHECKING, Any, Literal, Protocol, cast import sqlalchemy as sa from graphon.runtime import VariablePool @@ -100,7 +100,7 @@ class ToolManager: _builtin_provider_lock = Lock() _hardcoded_providers: dict[str, BuiltinToolProviderController] = {} _builtin_providers_loaded = False - _builtin_tools_labels: dict[str, Union[I18nObject, None]] = {} + _builtin_tools_labels: dict[str, I18nObject | None] = {} @classmethod def get_hardcoded_provider(cls, provider: str) -> BuiltinToolProviderController: @@ -190,7 +190,7 @@ class ToolManager: invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT, credential_id: str | None = None, - ) -> Union[BuiltinTool, PluginTool, ApiTool, WorkflowTool, MCPTool]: + ) -> BuiltinTool | PluginTool | ApiTool | WorkflowTool | MCPTool: """ get the tool runtime @@ -398,7 +398,7 @@ class ToolManager: agent_tool: AgentToolEntity, user_id: str | None = None, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, - variable_pool: Optional["VariablePool"] = None, + variable_pool: "VariablePool | None" = None, ) -> Tool: """ get the agent tool runtime @@ -442,7 +442,7 @@ class ToolManager: workflow_tool: WorkflowToolRuntimeSpec, user_id: str | None = None, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, - variable_pool: Optional["VariablePool"] = None, + variable_pool: "VariablePool | None" = None, ) -> Tool: """ get the workflow tool runtime @@ -634,7 +634,7 @@ class ToolManager: cls._builtin_providers_loaded = False @classmethod - def get_tool_label(cls, tool_name: str) -> Union[I18nObject, None]: + def get_tool_label(cls, tool_name: str) -> I18nObject | None: """ get the tool label @@ -682,7 +682,7 @@ class ToolManager: with Session(db.engine, autoflush=False) as session: ids = [row.id for row in session.execute(sa.text(sql), {"tenant_id": tenant_id}).all()] - return session.query(BuiltinToolProvider).where(BuiltinToolProvider.id.in_(ids)).all() + return list(session.scalars(select(BuiltinToolProvider).where(BuiltinToolProvider.id.in_(ids)))) @classmethod def list_providers_from_api( @@ -993,7 +993,7 @@ class ToolManager: return {"background": "#252525", "content": "\ud83d\ude01"} @classmethod - def generate_mcp_tool_icon_url(cls, tenant_id: str, provider_id: str) -> EmojiIconDict | dict[str, str] | str: + def generate_mcp_tool_icon_url(cls, tenant_id: str, provider_id: str) -> EmojiIconDict | str: try: with Session(db.engine) as session: mcp_service = MCPToolManageService(session=session) @@ -1001,7 +1001,7 @@ class ToolManager: mcp_provider = mcp_service.get_provider_entity( provider_id=provider_id, tenant_id=tenant_id, by_server_id=True ) - return mcp_provider.provider_icon + return cast(EmojiIconDict | str, mcp_provider.provider_icon) except ValueError: raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found") except Exception: @@ -1013,7 +1013,7 @@ class ToolManager: tenant_id: str, provider_type: ToolProviderType, provider_id: str, - ) -> str | EmojiIconDict | dict[str, str]: + ) -> str | EmojiIconDict: """ get the tool icon @@ -1052,7 +1052,7 @@ class ToolManager: def _convert_tool_parameters_type( cls, parameters: list[ToolParameter], - variable_pool: Optional["VariablePool"], + variable_pool: "VariablePool | None", tool_configurations: Mapping[str, Any], typ: Literal["agent", "workflow", "tool"] = "workflow", ) -> dict[str, Any]: diff --git a/api/core/tools/utils/message_transformer.py b/api/core/tools/utils/message_transformer.py index bb5b3ba76e..2264981abd 100644 --- a/api/core/tools/utils/message_transformer.py +++ b/api/core/tools/utils/message_transformer.py @@ -118,7 +118,8 @@ class ToolFileMessageTransformer: if not isinstance(message.message, ToolInvokeMessage.BlobMessage): raise ValueError("unexpected message type") - assert isinstance(message.message.blob, bytes) + if not isinstance(message.message.blob, bytes): + raise TypeError(f"Expected blob to be bytes, got {type(message.message.blob).__name__}") tool_file_manager = ToolFileManager() tool_file = tool_file_manager.create_file_by_raw( user_id=user_id, diff --git a/api/core/tools/utils/workflow_configuration_sync.py b/api/core/tools/utils/workflow_configuration_sync.py index c4b7d57449..2159eb8638 100644 --- a/api/core/tools/utils/workflow_configuration_sync.py +++ b/api/core/tools/utils/workflow_configuration_sync.py @@ -17,10 +17,8 @@ class WorkflowToolConfigurationUtils: """ nodes = graph.get("nodes", []) start_node = next(filter(lambda x: x.get("data", {}).get("type") == "start", nodes), None) - if not start_node: return [] - return [VariableEntity.model_validate(variable) for variable in start_node.get("data", {}).get("variables", [])] @classmethod diff --git a/api/extensions/ext_celery.py b/api/extensions/ext_celery.py index 1b3ccd1207..86b0550187 100644 --- a/api/extensions/ext_celery.py +++ b/api/extensions/ext_celery.py @@ -5,12 +5,30 @@ from typing import Any import pytz # type: ignore[import-untyped] from celery import Celery, Task from celery.schedules import crontab +from typing_extensions import TypedDict from configs import dify_config from dify_app import DifyApp -def get_celery_ssl_options() -> dict[str, Any] | None: +class _CelerySentinelKwargsDict(TypedDict): + socket_timeout: float | None + password: str | None + + +class CelerySentinelTransportDict(TypedDict): + master_name: str | None + sentinel_kwargs: _CelerySentinelKwargsDict + + +class CelerySSLOptionsDict(TypedDict): + ssl_cert_reqs: int + ssl_ca_certs: str | None + ssl_certfile: str | None + ssl_keyfile: str | None + + +def get_celery_ssl_options() -> CelerySSLOptionsDict | None: """Get SSL configuration for Celery broker/backend connections.""" # Only apply SSL if we're using Redis as broker/backend if not dify_config.BROKER_USE_SSL: @@ -33,26 +51,24 @@ def get_celery_ssl_options() -> dict[str, Any] | None: ssl_cert_reqs = cert_reqs_map.get(dify_config.REDIS_SSL_CERT_REQS, ssl.CERT_NONE) - ssl_options = { - "ssl_cert_reqs": ssl_cert_reqs, - "ssl_ca_certs": dify_config.REDIS_SSL_CA_CERTS, - "ssl_certfile": dify_config.REDIS_SSL_CERTFILE, - "ssl_keyfile": dify_config.REDIS_SSL_KEYFILE, - } - - return ssl_options + return CelerySSLOptionsDict( + ssl_cert_reqs=ssl_cert_reqs, + ssl_ca_certs=dify_config.REDIS_SSL_CA_CERTS, + ssl_certfile=dify_config.REDIS_SSL_CERTFILE, + ssl_keyfile=dify_config.REDIS_SSL_KEYFILE, + ) -def get_celery_broker_transport_options() -> dict[str, Any]: +def get_celery_broker_transport_options() -> CelerySentinelTransportDict | dict[str, Any]: """Get broker transport options (e.g. Redis Sentinel) for Celery connections.""" if dify_config.CELERY_USE_SENTINEL: - return { - "master_name": dify_config.CELERY_SENTINEL_MASTER_NAME, - "sentinel_kwargs": { - "socket_timeout": dify_config.CELERY_SENTINEL_SOCKET_TIMEOUT, - "password": dify_config.CELERY_SENTINEL_PASSWORD, - }, - } + return CelerySentinelTransportDict( + master_name=dify_config.CELERY_SENTINEL_MASTER_NAME, + sentinel_kwargs=_CelerySentinelKwargsDict( + socket_timeout=dify_config.CELERY_SENTINEL_SOCKET_TIMEOUT, + password=dify_config.CELERY_SENTINEL_PASSWORD, + ), + ) return {} diff --git a/api/extensions/ext_redis.py b/api/extensions/ext_redis.py index b9e592cadb..20f05b8b9e 100644 --- a/api/extensions/ext_redis.py +++ b/api/extensions/ext_redis.py @@ -14,6 +14,7 @@ from redis.cluster import ClusterNode, RedisCluster from redis.connection import Connection, SSLConnection from redis.retry import Retry from redis.sentinel import Sentinel +from typing_extensions import TypedDict from configs import dify_config from dify_app import DifyApp @@ -126,6 +127,41 @@ redis_client: RedisClientWrapper = RedisClientWrapper() _pubsub_redis_client: redis.Redis | RedisCluster | None = None +class RedisSSLParamsDict(TypedDict): + ssl_cert_reqs: int + ssl_ca_certs: str | None + ssl_certfile: str | None + ssl_keyfile: str | None + + +class RedisHealthParamsDict(TypedDict): + retry: Retry + socket_timeout: float | None + socket_connect_timeout: float | None + health_check_interval: int | None + + +class RedisClusterHealthParamsDict(TypedDict): + retry: Retry + socket_timeout: float | None + socket_connect_timeout: float | None + + +class RedisBaseParamsDict(TypedDict): + username: str | None + password: str | None + db: int + encoding: str + encoding_errors: str + decode_responses: bool + protocol: int + cache_config: CacheConfig | None + retry: Retry + socket_timeout: float | None + socket_connect_timeout: float | None + health_check_interval: int | None + + def _get_ssl_configuration() -> tuple[type[Union[Connection, SSLConnection]], dict[str, Any]]: """Get SSL configuration for Redis connection.""" if not dify_config.REDIS_USE_SSL: @@ -171,17 +207,17 @@ def _get_retry_policy() -> Retry: ) -def _get_connection_health_params() -> dict[str, Any]: +def _get_connection_health_params() -> RedisHealthParamsDict: """Get connection health and retry parameters for standalone and Sentinel Redis clients.""" - return { - "retry": _get_retry_policy(), - "socket_timeout": dify_config.REDIS_SOCKET_TIMEOUT, - "socket_connect_timeout": dify_config.REDIS_SOCKET_CONNECT_TIMEOUT, - "health_check_interval": dify_config.REDIS_HEALTH_CHECK_INTERVAL, - } + return RedisHealthParamsDict( + retry=_get_retry_policy(), + socket_timeout=dify_config.REDIS_SOCKET_TIMEOUT, + socket_connect_timeout=dify_config.REDIS_SOCKET_CONNECT_TIMEOUT, + health_check_interval=dify_config.REDIS_HEALTH_CHECK_INTERVAL, + ) -def _get_cluster_connection_health_params() -> dict[str, Any]: +def _get_cluster_connection_health_params() -> RedisClusterHealthParamsDict: """Get retry and timeout parameters for Redis Cluster clients. RedisCluster does not support ``health_check_interval`` as a constructor @@ -189,26 +225,31 @@ def _get_cluster_connection_health_params() -> dict[str, Any]: here. Only ``retry``, ``socket_timeout``, and ``socket_connect_timeout`` are passed through. """ - params = _get_connection_health_params() - return {k: v for k, v in params.items() if k != "health_check_interval"} - - -def _get_base_redis_params() -> dict[str, Any]: - """Get base Redis connection parameters including retry and health policy.""" - return { - "username": dify_config.REDIS_USERNAME, - "password": dify_config.REDIS_PASSWORD or None, - "db": dify_config.REDIS_DB, - "encoding": "utf-8", - "encoding_errors": "strict", - "decode_responses": False, - "protocol": dify_config.REDIS_SERIALIZATION_PROTOCOL, - "cache_config": _get_cache_configuration(), - **_get_connection_health_params(), + health_params = _get_connection_health_params() + result: RedisClusterHealthParamsDict = { + "retry": health_params["retry"], + "socket_timeout": health_params["socket_timeout"], + "socket_connect_timeout": health_params["socket_connect_timeout"], } + return result -def _create_sentinel_client(redis_params: dict[str, Any]) -> Union[redis.Redis, RedisCluster]: +def _get_base_redis_params() -> RedisBaseParamsDict: + """Get base Redis connection parameters including retry and health policy.""" + return RedisBaseParamsDict( + username=dify_config.REDIS_USERNAME, + password=dify_config.REDIS_PASSWORD or None, + db=dify_config.REDIS_DB, + encoding="utf-8", + encoding_errors="strict", + decode_responses=False, + protocol=dify_config.REDIS_SERIALIZATION_PROTOCOL, + cache_config=_get_cache_configuration(), + **_get_connection_health_params(), + ) + + +def _create_sentinel_client(redis_params: RedisBaseParamsDict) -> Union[redis.Redis, RedisCluster]: """Create Redis client using Sentinel configuration.""" if not dify_config.REDIS_SENTINELS: raise ValueError("REDIS_SENTINELS must be set when REDIS_USE_SENTINEL is True") @@ -232,7 +273,8 @@ def _create_sentinel_client(redis_params: dict[str, Any]) -> Union[redis.Redis, sentinel_kwargs=sentinel_kwargs, ) - master: redis.Redis = sentinel.master_for(dify_config.REDIS_SENTINEL_SERVICE_NAME, **redis_params) + params: dict[str, Any] = {**redis_params} + master: redis.Redis = sentinel.master_for(dify_config.REDIS_SENTINEL_SERVICE_NAME, **params) return master @@ -259,18 +301,16 @@ def _create_cluster_client() -> Union[redis.Redis, RedisCluster]: return cluster -def _create_standalone_client(redis_params: dict[str, Any]) -> Union[redis.Redis, RedisCluster]: +def _create_standalone_client(redis_params: RedisBaseParamsDict) -> Union[redis.Redis, RedisCluster]: """Create standalone Redis client.""" connection_class, ssl_kwargs = _get_ssl_configuration() - params = {**redis_params} - params.update( - { - "host": dify_config.REDIS_HOST, - "port": dify_config.REDIS_PORT, - "connection_class": connection_class, - } - ) + params: dict[str, Any] = { + **redis_params, + "host": dify_config.REDIS_HOST, + "port": dify_config.REDIS_PORT, + "connection_class": connection_class, + } if dify_config.REDIS_MAX_CONNECTIONS: params["max_connections"] = dify_config.REDIS_MAX_CONNECTIONS @@ -293,8 +333,8 @@ def _create_pubsub_client(pubsub_url: str, use_clusters: bool) -> redis.Redis | kwargs["max_connections"] = max_conns return RedisCluster.from_url(pubsub_url, **kwargs) - health_params = _get_connection_health_params() - kwargs = {**health_params} + standalone_health_params: dict[str, Any] = dict(_get_connection_health_params()) + kwargs = {**standalone_health_params} if max_conns: kwargs["max_connections"] = max_conns return redis.Redis.from_url(pubsub_url, **kwargs) diff --git a/api/extensions/otel/decorators/base.py b/api/extensions/otel/decorators/base.py index 1dd92caeae..ad83826427 100644 --- a/api/extensions/otel/decorators/base.py +++ b/api/extensions/otel/decorators/base.py @@ -37,12 +37,7 @@ def trace_span[**P, R](handler_class: type[SpanHandler] | None = None) -> Callab handler = _get_handler_instance(handler_class or SpanHandler) tracer = get_tracer(__name__) - return handler.wrapper( - tracer=tracer, - wrapped=func, - args=args, - kwargs=kwargs, - ) + return handler.wrapper(tracer, func, *args, **kwargs) return cast(Callable[P, R], wrapper) diff --git a/api/extensions/otel/decorators/handler.py b/api/extensions/otel/decorators/handler.py index e465a615a6..b0d9fa7af6 100644 --- a/api/extensions/otel/decorators/handler.py +++ b/api/extensions/otel/decorators/handler.py @@ -1,8 +1,8 @@ import inspect -from collections.abc import Callable, Mapping +from collections.abc import Callable from typing import Any -from opentelemetry.trace import SpanKind, Status, StatusCode +from opentelemetry.trace import SpanKind, Status, StatusCode, Tracer class SpanHandler: @@ -16,9 +16,9 @@ class SpanHandler: exceptions. Handlers can override the wrapper method to customize behavior. """ - _signature_cache: dict[Callable[..., Any], inspect.Signature] = {} + _signature_cache: dict[Callable[..., object], inspect.Signature] = {} - def _build_span_name(self, wrapped: Callable[..., Any]) -> str: + def _build_span_name[**P, R](self, wrapped: Callable[P, R]) -> str: """ Build the span name from the wrapped function. @@ -29,11 +29,11 @@ class SpanHandler: """ return f"{wrapped.__module__}.{wrapped.__qualname__}" - def _extract_arguments[T]( + def _extract_arguments[**P, R]( self, - wrapped: Callable[..., T], - args: tuple[object, ...], - kwargs: Mapping[str, object], + wrapped: Callable[P, R], + *args: P.args, + **kwargs: P.kwargs, ) -> dict[str, Any] | None: """ Extract function arguments using inspect.signature. @@ -59,13 +59,13 @@ class SpanHandler: except Exception: return None - def wrapper[T]( + def wrapper[**P, R]( self, - tracer: Any, - wrapped: Callable[..., T], - args: tuple[object, ...], - kwargs: Mapping[str, object], - ) -> T: + tracer: Tracer, + wrapped: Callable[P, R], + *args: P.args, + **kwargs: P.kwargs, + ) -> R: """ Fully control the wrapper behavior. diff --git a/api/extensions/otel/decorators/handlers/generate_handler.py b/api/extensions/otel/decorators/handlers/generate_handler.py index cc6c75304f..df5142c310 100644 --- a/api/extensions/otel/decorators/handlers/generate_handler.py +++ b/api/extensions/otel/decorators/handlers/generate_handler.py @@ -1,8 +1,7 @@ import logging -from collections.abc import Callable, Mapping -from typing import Any +from collections.abc import Callable -from opentelemetry.trace import SpanKind, Status, StatusCode +from opentelemetry.trace import SpanKind, Status, StatusCode, Tracer from opentelemetry.util.types import AttributeValue from extensions.otel.decorators.handler import SpanHandler @@ -15,15 +14,15 @@ logger = logging.getLogger(__name__) class AppGenerateHandler(SpanHandler): """Span handler for ``AppGenerateService.generate``.""" - def wrapper[T]( + def wrapper[**P, R]( self, - tracer: Any, - wrapped: Callable[..., T], - args: tuple[object, ...], - kwargs: Mapping[str, object], - ) -> T: + tracer: Tracer, + wrapped: Callable[P, R], + *args: P.args, + **kwargs: P.kwargs, + ) -> R: try: - arguments = self._extract_arguments(wrapped, args, kwargs) + arguments = self._extract_arguments(wrapped, *args, **kwargs) if not arguments: return wrapped(*args, **kwargs) diff --git a/api/extensions/otel/decorators/handlers/workflow_app_runner_handler.py b/api/extensions/otel/decorators/handlers/workflow_app_runner_handler.py index 8abd60197c..6b2112ceb2 100644 --- a/api/extensions/otel/decorators/handlers/workflow_app_runner_handler.py +++ b/api/extensions/otel/decorators/handlers/workflow_app_runner_handler.py @@ -1,8 +1,7 @@ import logging -from collections.abc import Callable, Mapping -from typing import Any +from collections.abc import Callable -from opentelemetry.trace import SpanKind, Status, StatusCode +from opentelemetry.trace import SpanKind, Status, StatusCode, Tracer from opentelemetry.util.types import AttributeValue from extensions.otel.decorators.handler import SpanHandler @@ -14,15 +13,15 @@ logger = logging.getLogger(__name__) class WorkflowAppRunnerHandler(SpanHandler): """Span handler for ``WorkflowAppRunner.run``.""" - def wrapper( + def wrapper[**P, R]( self, - tracer: Any, - wrapped: Callable[..., Any], - args: tuple[Any, ...], - kwargs: Mapping[str, Any], - ) -> Any: + tracer: Tracer, + wrapped: Callable[P, R], + *args: P.args, + **kwargs: P.kwargs, + ) -> R: try: - arguments = self._extract_arguments(wrapped, args, kwargs) + arguments = self._extract_arguments(wrapped, *args, **kwargs) if not arguments: return wrapped(*args, **kwargs) diff --git a/api/libs/db_migration_lock.py b/api/libs/db_migration_lock.py index 1d3a81e0a2..ca8956e397 100644 --- a/api/libs/db_migration_lock.py +++ b/api/libs/db_migration_lock.py @@ -14,9 +14,15 @@ from __future__ import annotations import logging import threading -from typing import Any +from typing import TYPE_CHECKING, Any +import redis +from redis.cluster import RedisCluster from redis.exceptions import LockNotOwnedError, RedisError +from redis.lock import Lock + +if TYPE_CHECKING: + from extensions.ext_redis import RedisClientWrapper logger = logging.getLogger(__name__) @@ -38,21 +44,21 @@ class DbMigrationAutoRenewLock: primary error/exit code. """ - _redis_client: Any + _redis_client: redis.Redis | RedisCluster | RedisClientWrapper _name: str _ttl_seconds: float _renew_interval_seconds: float _log_context: str | None _logger: logging.Logger - _lock: Any + _lock: Lock | None _stop_event: threading.Event | None _thread: threading.Thread | None _acquired: bool def __init__( self, - redis_client: Any, + redis_client: redis.Redis | RedisCluster | RedisClientWrapper, name: str, ttl_seconds: float = 60, renew_interval_seconds: float | None = None, @@ -127,7 +133,7 @@ class DbMigrationAutoRenewLock: ) self._thread.start() - def _heartbeat_loop(self, lock: Any, stop_event: threading.Event) -> None: + def _heartbeat_loop(self, lock: Lock, stop_event: threading.Event) -> None: while not stop_event.wait(self._renew_interval_seconds): try: lock.reacquire() diff --git a/api/libs/external_api.py b/api/libs/external_api.py index e8592407c3..f907d17750 100644 --- a/api/libs/external_api.py +++ b/api/libs/external_api.py @@ -17,7 +17,6 @@ def http_status_message(code): def register_external_error_handlers(api: Api): - @api.errorhandler(HTTPException) def handle_http_exception(e: HTTPException): got_request_exception.send(current_app, exception=e) @@ -74,27 +73,18 @@ def register_external_error_handlers(api: Api): headers["Set-Cookie"] = build_force_logout_cookie_headers() return data, status_code, headers - _ = handle_http_exception - - @api.errorhandler(ValueError) def handle_value_error(e: ValueError): got_request_exception.send(current_app, exception=e) status_code = 400 data = {"code": "invalid_param", "message": str(e), "status": status_code} return data, status_code - _ = handle_value_error - - @api.errorhandler(AppInvokeQuotaExceededError) def handle_quota_exceeded(e: AppInvokeQuotaExceededError): got_request_exception.send(current_app, exception=e) status_code = 429 data = {"code": "too_many_requests", "message": str(e), "status": status_code} return data, status_code - _ = handle_quota_exceeded - - @api.errorhandler(Exception) def handle_general_exception(e: Exception): got_request_exception.send(current_app, exception=e) @@ -113,7 +103,10 @@ def register_external_error_handlers(api: Api): return data, status_code - _ = handle_general_exception + api.errorhandler(HTTPException)(handle_http_exception) + api.errorhandler(ValueError)(handle_value_error) + api.errorhandler(AppInvokeQuotaExceededError)(handle_quota_exceeded) + api.errorhandler(Exception)(handle_general_exception) class ExternalApi(Api): diff --git a/api/libs/helper.py b/api/libs/helper.py index ece53e8806..e7decd43b3 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -10,7 +10,7 @@ import uuid from collections.abc import Callable, Generator, Mapping from datetime import datetime from hashlib import sha256 -from typing import TYPE_CHECKING, Annotated, Any, Optional, Protocol, Union, cast +from typing import TYPE_CHECKING, Annotated, Any, Protocol, cast from uuid import UUID from zoneinfo import available_timezones @@ -81,7 +81,7 @@ def escape_like_pattern(pattern: str) -> str: return pattern.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_") -def extract_tenant_id(user: Union["Account", "EndUser"]) -> str | None: +def extract_tenant_id(user: "Account | EndUser") -> str | None: """ Extract tenant_id from Account or EndUser object. @@ -164,7 +164,10 @@ def email(email): EmailStr = Annotated[str, AfterValidator(email)] -def uuid_value(value: Any) -> str: +def uuid_value(value: str | UUID) -> str: + if isinstance(value, UUID): + return str(value) + if value == "": return str(value) @@ -405,7 +408,7 @@ class TokenManager: def generate_token( cls, token_type: str, - account: Optional["Account"] = None, + account: "Account | None" = None, email: str | None = None, additional_data: dict | None = None, ) -> str: @@ -465,9 +468,7 @@ class TokenManager: return current_token @classmethod - def _set_current_token_for_account( - cls, account_id: str, token: str, token_type: str, expiry_minutes: Union[int, float] - ): + def _set_current_token_for_account(cls, account_id: str, token: str, token_type: str, expiry_minutes: int | float): key = cls._get_account_token_key(account_id, token_type) expiry_seconds = int(expiry_minutes * 60) redis_client.setex(key, expiry_seconds, token) diff --git a/api/libs/pyrefly_type_coverage.py b/api/libs/pyrefly_type_coverage.py new file mode 100644 index 0000000000..369b8dff3c --- /dev/null +++ b/api/libs/pyrefly_type_coverage.py @@ -0,0 +1,145 @@ +"""Helpers for generating type-coverage summaries from pyrefly report output.""" + +from __future__ import annotations + +import json +import sys +from pathlib import Path +from typing import TypedDict + + +class CoverageSummary(TypedDict): + n_modules: int + n_typable: int + n_typed: int + n_any: int + n_untyped: int + coverage: float + strict_coverage: float + + +_REQUIRED_KEYS = frozenset(CoverageSummary.__annotations__) + +_EMPTY_SUMMARY: CoverageSummary = { + "n_modules": 0, + "n_typable": 0, + "n_typed": 0, + "n_any": 0, + "n_untyped": 0, + "coverage": 0.0, + "strict_coverage": 0.0, +} + + +def parse_summary(report_json: str) -> CoverageSummary: + """Extract the summary section from ``pyrefly report`` JSON output. + + Returns an empty summary when *report_json* is empty or malformed so that + the CI workflow can degrade gracefully instead of crashing. + """ + if not report_json or not report_json.strip(): + return _EMPTY_SUMMARY.copy() + + try: + data = json.loads(report_json) + except json.JSONDecodeError: + return _EMPTY_SUMMARY.copy() + + summary = data.get("summary") + if not isinstance(summary, dict) or not _REQUIRED_KEYS.issubset(summary): + return _EMPTY_SUMMARY.copy() + + return { + "n_modules": summary["n_modules"], + "n_typable": summary["n_typable"], + "n_typed": summary["n_typed"], + "n_any": summary["n_any"], + "n_untyped": summary["n_untyped"], + "coverage": summary["coverage"], + "strict_coverage": summary["strict_coverage"], + } + + +def format_summary_markdown(summary: CoverageSummary) -> str: + """Format a single coverage summary as a Markdown table.""" + + return ( + "| Metric | Value |\n" + "| --- | ---: |\n" + f"| Modules | {summary['n_modules']} |\n" + f"| Typable symbols | {summary['n_typable']:,} |\n" + f"| Typed symbols | {summary['n_typed']:,} |\n" + f"| Untyped symbols | {summary['n_untyped']:,} |\n" + f"| Any symbols | {summary['n_any']:,} |\n" + f"| **Type coverage** | **{summary['coverage']:.2f}%** |\n" + f"| Strict coverage | {summary['strict_coverage']:.2f}% |" + ) + + +def format_comparison_markdown( + base: CoverageSummary, + pr: CoverageSummary, +) -> str: + """Format a comparison between base and PR coverage as Markdown.""" + + coverage_delta = pr["coverage"] - base["coverage"] + strict_delta = pr["strict_coverage"] - base["strict_coverage"] + typed_delta = pr["n_typed"] - base["n_typed"] + untyped_delta = pr["n_untyped"] - base["n_untyped"] + + def _fmt_delta(value: float, fmt: str = ".2f") -> str: + sign = "+" if value > 0 else "" + return f"{sign}{value:{fmt}}" + + lines = [ + "| Metric | Base | PR | Delta |", + "| --- | ---: | ---: | ---: |", + (f"| **Type coverage** | {base['coverage']:.2f}% | {pr['coverage']:.2f}% | {_fmt_delta(coverage_delta)}% |"), + ( + f"| Strict coverage | {base['strict_coverage']:.2f}% " + f"| {pr['strict_coverage']:.2f}% " + f"| {_fmt_delta(strict_delta)}% |" + ), + (f"| Typed symbols | {base['n_typed']:,} | {pr['n_typed']:,} | {_fmt_delta(typed_delta, ',')} |"), + (f"| Untyped symbols | {base['n_untyped']:,} | {pr['n_untyped']:,} | {_fmt_delta(untyped_delta, ',')} |"), + ( + f"| Modules | {base['n_modules']} " + f"| {pr['n_modules']} " + f"| {_fmt_delta(pr['n_modules'] - base['n_modules'], ',')} |" + ), + ] + return "\n".join(lines) + + +def main() -> int: + """Read pyrefly report JSON from stdin and print a Markdown summary. + + Accepts an optional ``--base `` argument. When provided, the output + includes a base-vs-PR comparison table. + """ + + args = sys.argv[1:] + + base_file: str | None = None + if "--base" in args: + idx = args.index("--base") + if idx + 1 >= len(args): + sys.stderr.write("error: --base requires a file path\n") + return 1 + base_file = args[idx + 1] + + pr_report = sys.stdin.read() + pr_summary = parse_summary(pr_report) + + if base_file is not None: + base_text = Path(base_file).read_text() if Path(base_file).exists() else "" + base_summary = parse_summary(base_text) + sys.stdout.write(format_comparison_markdown(base_summary, pr_summary) + "\n") + else: + sys.stdout.write(format_summary_markdown(pr_summary) + "\n") + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/api/models/base.py b/api/models/base.py index b7023b9c8b..5acdf184f4 100644 --- a/api/models/base.py +++ b/api/models/base.py @@ -24,6 +24,8 @@ class TypeBase(MappedAsDataclass, DeclarativeBase): class DefaultFieldsMixin: + """Mixin for models that inherit from Base (non-dataclass).""" + id: Mapped[str] = mapped_column( StringUUID, primary_key=True, @@ -53,6 +55,42 @@ class DefaultFieldsMixin: return f"<{self.__class__.__name__}(id={self.id})>" +class DefaultFieldsDCMixin(MappedAsDataclass): + """Mixin for models that inherit from TypeBase (MappedAsDataclass).""" + + __abstract__ = True + + id: Mapped[str] = mapped_column( + StringUUID, + primary_key=True, + insert_default=lambda: str(uuidv7()), + default_factory=lambda: str(uuidv7()), + init=False, + ) + + created_at: Mapped[datetime] = mapped_column( + DateTime, + nullable=False, + insert_default=naive_utc_now, + default_factory=naive_utc_now, + init=False, + server_default=func.current_timestamp(), + ) + + updated_at: Mapped[datetime] = mapped_column( + DateTime, + nullable=False, + insert_default=naive_utc_now, + default_factory=naive_utc_now, + init=False, + server_default=func.current_timestamp(), + onupdate=func.current_timestamp(), + ) + + def __repr__(self) -> str: + return f"<{self.__class__.__name__}(id={self.id})>" + + def gen_uuidv4_string() -> str: """gen_uuidv4_string generate a UUIDv4 string. diff --git a/api/models/dataset.py b/api/models/dataset.py index 97604848af..a8ed821c3a 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -108,6 +108,56 @@ class ExternalKnowledgeApiDict(TypedDict): created_at: str +class DocumentDict(TypedDict): + id: str + tenant_id: str + dataset_id: str + position: int + data_source_type: str + data_source_info: str | None + dataset_process_rule_id: str | None + batch: str + name: str + created_from: str + created_by: str + created_api_request_id: str | None + created_at: datetime + processing_started_at: datetime | None + file_id: str | None + word_count: int | None + parsing_completed_at: datetime | None + cleaning_completed_at: datetime | None + splitting_completed_at: datetime | None + tokens: int | None + indexing_latency: float | None + completed_at: datetime | None + is_paused: bool | None + paused_by: str | None + paused_at: datetime | None + error: str | None + stopped_at: datetime | None + indexing_status: str + enabled: bool + disabled_at: datetime | None + disabled_by: str | None + archived: bool + archived_reason: str | None + archived_by: str | None + archived_at: datetime | None + updated_at: datetime + doc_type: str | None + doc_metadata: Any + doc_form: IndexStructureType + doc_language: str | None + display_status: str | None + data_source_info_dict: dict[str, Any] + average_segment_length: int + dataset_process_rule: ProcessRuleDict | None + dataset: None + segment_count: int | None + hit_count: int | None + + class DatasetPermissionEnum(enum.StrEnum): ONLY_ME = "only_me" ALL_TEAM = "all_team_members" @@ -303,13 +353,17 @@ class Dataset(Base): if self.provider != "external": return None external_knowledge_binding = db.session.scalar( - select(ExternalKnowledgeBindings).where(ExternalKnowledgeBindings.dataset_id == self.id) + select(ExternalKnowledgeBindings).where( + ExternalKnowledgeBindings.dataset_id == self.id, + ExternalKnowledgeBindings.tenant_id == self.tenant_id, + ) ) if not external_knowledge_binding: return None external_knowledge_api = db.session.scalar( select(ExternalKnowledgeApis).where( - ExternalKnowledgeApis.id == external_knowledge_binding.external_knowledge_api_id + ExternalKnowledgeApis.id == external_knowledge_binding.external_knowledge_api_id, + ExternalKnowledgeApis.tenant_id == self.tenant_id, ) ) if external_knowledge_api is None or external_knowledge_api.settings is None: @@ -675,8 +729,8 @@ class Document(Base): ) return built_in_fields - def to_dict(self) -> dict[str, Any]: - return { + def to_dict(self) -> DocumentDict: + result: DocumentDict = { "id": self.id, "tenant_id": self.tenant_id, "dataset_id": self.dataset_id, @@ -721,10 +775,11 @@ class Document(Base): "data_source_info_dict": self.data_source_info_dict, "average_segment_length": self.average_segment_length, "dataset_process_rule": self.dataset_process_rule.to_dict() if self.dataset_process_rule else None, - "dataset": None, # Dataset class doesn't have a to_dict method + "dataset": None, "segment_count": self.segment_count, "hit_count": self.hit_count, } + return result @classmethod def from_dict(cls, data: dict[str, Any]): diff --git a/api/models/model.py b/api/models/model.py index ece3ff8b87..0ea2259a19 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -674,28 +674,24 @@ class AppModelConfig(TypeBase): def suggested_questions_list(self) -> list[str]: return json.loads(self.suggested_questions) if self.suggested_questions else [] + def _get_enabled_config(self, value: str | None, *, default_enabled: bool = False) -> EnabledConfig: + return cast(EnabledConfig, json.loads(value) if value else {"enabled": default_enabled}) + @property def suggested_questions_after_answer_dict(self) -> EnabledConfig: - return cast( - EnabledConfig, - json.loads(self.suggested_questions_after_answer) - if self.suggested_questions_after_answer - else {"enabled": False}, - ) + return self._get_enabled_config(self.suggested_questions_after_answer) @property def speech_to_text_dict(self) -> EnabledConfig: - return cast(EnabledConfig, json.loads(self.speech_to_text) if self.speech_to_text else {"enabled": False}) + return self._get_enabled_config(self.speech_to_text) @property def text_to_speech_dict(self) -> EnabledConfig: - return cast(EnabledConfig, json.loads(self.text_to_speech) if self.text_to_speech else {"enabled": False}) + return self._get_enabled_config(self.text_to_speech) @property def retriever_resource_dict(self) -> EnabledConfig: - return cast( - EnabledConfig, json.loads(self.retriever_resource) if self.retriever_resource else {"enabled": True} - ) + return self._get_enabled_config(self.retriever_resource, default_enabled=True) @property def annotation_reply_dict(self) -> AnnotationReplyConfig: @@ -722,7 +718,7 @@ class AppModelConfig(TypeBase): @property def more_like_this_dict(self) -> EnabledConfig: - return cast(EnabledConfig, json.loads(self.more_like_this) if self.more_like_this else {"enabled": False}) + return self._get_enabled_config(self.more_like_this) @property def sensitive_word_avoidance_dict(self) -> SensitiveWordAvoidanceConfig: @@ -902,7 +898,7 @@ class InstalledApp(TypeBase): return db.session.scalar(select(Tenant).where(Tenant.id == self.tenant_id)) -class TrialApp(Base): +class TrialApp(TypeBase): __tablename__ = "trial_apps" __table_args__ = ( sa.PrimaryKeyConstraint("id", name="trial_app_pkey"), @@ -911,18 +907,22 @@ class TrialApp(Base): sa.UniqueConstraint("app_id", name="unique_trail_app_id"), ) - id = mapped_column(StringUUID, default=gen_uuidv4_string) - app_id = mapped_column(StringUUID, nullable=False) - tenant_id = mapped_column(StringUUID, nullable=False) - created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) - trial_limit = mapped_column(sa.Integer, nullable=False, default=3) + id: Mapped[str] = mapped_column( + StringUUID, insert_default=gen_uuidv4_string, default_factory=gen_uuidv4_string, init=False + ) + app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + created_at: Mapped[datetime] = mapped_column( + sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False + ) + trial_limit: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=3) @property def app(self) -> App | None: return db.session.scalar(select(App).where(App.id == self.app_id)) -class AccountTrialAppRecord(Base): +class AccountTrialAppRecord(TypeBase): __tablename__ = "account_trial_app_records" __table_args__ = ( sa.PrimaryKeyConstraint("id", name="user_trial_app_pkey"), @@ -930,11 +930,15 @@ class AccountTrialAppRecord(Base): sa.Index("account_trial_app_record_app_id_idx", "app_id"), sa.UniqueConstraint("account_id", "app_id", name="unique_account_trial_app_record"), ) - id = mapped_column(StringUUID, default=gen_uuidv4_string) - account_id = mapped_column(StringUUID, nullable=False) - app_id = mapped_column(StringUUID, nullable=False) - count = mapped_column(sa.Integer, nullable=False, default=0) - created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + id: Mapped[str] = mapped_column( + StringUUID, insert_default=gen_uuidv4_string, default_factory=gen_uuidv4_string, init=False + ) + account_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + count: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0) + created_at: Mapped[datetime] = mapped_column( + sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False + ) @property def app(self) -> App | None: diff --git a/api/models/source.py b/api/models/source.py index a8addbe342..8078b32f8c 100644 --- a/api/models/source.py +++ b/api/models/source.py @@ -1,5 +1,6 @@ import json from datetime import datetime +from typing import Any, TypedDict from uuid import uuid4 import sqlalchemy as sa @@ -38,6 +39,17 @@ class DataSourceOauthBinding(TypeBase): disabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=True, server_default=sa.text("false"), default=False) +class DataSourceApiKeyAuthBindingDict(TypedDict): + id: str + tenant_id: str + category: str + provider: str + credentials: Any + created_at: float + updated_at: float + disabled: bool + + class DataSourceApiKeyAuthBinding(TypeBase): __tablename__ = "data_source_api_key_auth_bindings" __table_args__ = ( @@ -65,8 +77,8 @@ class DataSourceApiKeyAuthBinding(TypeBase): ) disabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=True, server_default=sa.text("false"), default=False) - def to_dict(self): - return { + def to_dict(self) -> DataSourceApiKeyAuthBindingDict: + result: DataSourceApiKeyAuthBindingDict = { "id": self.id, "tenant_id": self.tenant_id, "category": self.category, @@ -76,3 +88,4 @@ class DataSourceApiKeyAuthBinding(TypeBase): "updated_at": self.updated_at.timestamp(), "disabled": self.disabled, } + return result diff --git a/api/models/utils/file_input_compat.py b/api/models/utils/file_input_compat.py index f71583c1cd..8b767779ce 100644 --- a/api/models/utils/file_input_compat.py +++ b/api/models/utils/file_input_compat.py @@ -66,12 +66,15 @@ def build_file_from_stored_mapping( record_id = resolve_file_record_id(mapping) transfer_method = FileTransferMethod.value_of(mapping["transfer_method"]) - if transfer_method == FileTransferMethod.TOOL_FILE and record_id: - mapping["tool_file_id"] = record_id - elif transfer_method in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL] and record_id: - mapping["upload_file_id"] = record_id - elif transfer_method == FileTransferMethod.DATASOURCE_FILE and record_id: - mapping["datasource_file_id"] = record_id + match transfer_method: + case FileTransferMethod.TOOL_FILE if record_id: + mapping["tool_file_id"] = record_id + case FileTransferMethod.LOCAL_FILE | FileTransferMethod.REMOTE_URL if record_id: + mapping["upload_file_id"] = record_id + case FileTransferMethod.DATASOURCE_FILE if record_id: + mapping["datasource_file_id"] = record_id + case _: + pass if transfer_method == FileTransferMethod.REMOTE_URL and record_id is None: remote_url = mapping.get("remote_url") diff --git a/api/models/workflow.py b/api/models/workflow.py index 8e8d2e6fd9..bb4d6a7ec9 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -4,7 +4,7 @@ import logging from collections.abc import Generator, Mapping, Sequence from datetime import datetime from enum import StrEnum -from typing import TYPE_CHECKING, Any, Optional, TypedDict, Union, cast +from typing import TYPE_CHECKING, Any, Optional, TypedDict, cast from uuid import uuid4 import sqlalchemy as sa @@ -121,7 +121,7 @@ class WorkflowType(StrEnum): raise ValueError(f"invalid workflow type value {value}") @classmethod - def from_app_mode(cls, app_mode: Union[str, "AppMode"]) -> "WorkflowType": + def from_app_mode(cls, app_mode: "str | AppMode") -> "WorkflowType": """ Get workflow type from app mode. @@ -1051,7 +1051,7 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo ) return extras - def _get_offload_by_type(self, type_: ExecutionOffLoadType) -> Optional["WorkflowNodeExecutionOffload"]: + def _get_offload_by_type(self, type_: ExecutionOffLoadType) -> "WorkflowNodeExecutionOffload | None": return next(iter([i for i in self.offload_data if i.type_ == type_]), None) @property diff --git a/api/pyproject.toml b/api/pyproject.toml index 086ce5bb72..3f70ec7bc7 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -4,76 +4,76 @@ version = "1.13.3" requires-python = "~=3.12.0" dependencies = [ - "aliyun-log-python-sdk~=0.9.37", + "aliyun-log-python-sdk~=0.9.44", "arize-phoenix-otel~=0.15.0", "azure-identity==1.25.3", "beautifulsoup4==4.14.3", - "boto3==1.42.83", + "boto3==1.42.88", "bs4~=0.0.1", - "cachetools~=5.3.0", - "celery~=5.6.2", - "charset-normalizer>=3.4.4", - "flask~=3.1.2", - "flask-compress>=1.17,<1.25", - "flask-cors~=6.0.0", + "cachetools~=7.0.5", + "celery~=5.6.3", + "charset-normalizer>=3.4.7", + "flask~=3.1.3", + "flask-compress>=1.24,<1.25", + "flask-cors~=6.0.2", "flask-login~=0.6.3", "flask-migrate~=4.1.0", "flask-orjson~=2.0.0", "flask-sqlalchemy~=3.1.1", - "gevent~=25.9.1", + "gevent~=26.4.0", "gmpy2~=2.3.0", - "google-api-core>=2.19.1", - "google-api-python-client==2.193.0", - "google-auth>=2.47.0", + "google-api-core>=2.30.3", + "google-api-python-client==2.194.0", + "google-auth>=2.49.2", "google-auth-httplib2==0.3.1", - "google-cloud-aiplatform>=1.123.0", - "googleapis-common-protos>=1.65.0", + "google-cloud-aiplatform>=1.147.0", + "googleapis-common-protos>=1.74.0", "graphon>=0.1.2", "gunicorn~=25.3.0", - "httpx[socks]~=0.28.0", + "httpx[socks]~=0.28.1", "jieba==0.42.1", - "json-repair>=0.55.1", - "langfuse>=3.0.0,<5.0.0", - "langsmith~=0.7.16", + "json-repair>=0.59.2", + "langfuse>=4.2.0,<5.0.0", + "langsmith~=0.7.30", "markdown~=3.10.2", - "mlflow-skinny>=3.0.0", - "numpy~=1.26.4", + "mlflow-skinny>=3.11.1", + "numpy~=2.4.4", "openpyxl~=3.1.5", - "opik~=1.10.37", + "opik~=1.11.2", "litellm==1.83.0", # Pinned to avoid madoka dependency issue - "opentelemetry-api==1.40.0", - "opentelemetry-distro==0.61b0", - "opentelemetry-exporter-otlp==1.40.0", - "opentelemetry-exporter-otlp-proto-common==1.40.0", - "opentelemetry-exporter-otlp-proto-grpc==1.40.0", - "opentelemetry-exporter-otlp-proto-http==1.40.0", - "opentelemetry-instrumentation==0.61b0", - "opentelemetry-instrumentation-celery==0.61b0", - "opentelemetry-instrumentation-flask==0.61b0", - "opentelemetry-instrumentation-httpx==0.61b0", - "opentelemetry-instrumentation-redis==0.61b0", - "opentelemetry-instrumentation-sqlalchemy==0.61b0", - "opentelemetry-propagator-b3==1.40.0", - "opentelemetry-proto==1.40.0", - "opentelemetry-sdk==1.40.0", - "opentelemetry-semantic-conventions==0.61b0", - "opentelemetry-util-http==0.61b0", - "pandas[excel,output-formatting,performance]~=3.0.1", + "opentelemetry-api==1.41.0", + "opentelemetry-distro==0.62b0", + "opentelemetry-exporter-otlp==1.41.0", + "opentelemetry-exporter-otlp-proto-common==1.41.0", + "opentelemetry-exporter-otlp-proto-grpc==1.41.0", + "opentelemetry-exporter-otlp-proto-http==1.41.0", + "opentelemetry-instrumentation==0.62b0", + "opentelemetry-instrumentation-celery==0.62b0", + "opentelemetry-instrumentation-flask==0.62b0", + "opentelemetry-instrumentation-httpx==0.62b0", + "opentelemetry-instrumentation-redis==0.62b0", + "opentelemetry-instrumentation-sqlalchemy==0.62b0", + "opentelemetry-propagator-b3==1.41.0", + "opentelemetry-proto==1.41.0", + "opentelemetry-sdk==1.41.0", + "opentelemetry-semantic-conventions==0.62b0", + "opentelemetry-util-http==0.62b0", + "pandas[excel,output-formatting,performance]~=3.0.2", "psycogreen~=1.0.2", - "psycopg2-binary~=2.9.6", + "psycopg2-binary~=2.9.11", "pycryptodome==3.23.0", "pydantic~=2.12.5", "pydantic-settings~=2.13.1", - "pyjwt~=2.12.0", + "pyjwt~=2.12.1", "pypdfium2==5.6.0", "python-docx~=1.2.0", "python-dotenv==1.2.2", "pyyaml~=6.0.1", "readabilipy~=0.3.0", "redis[hiredis]~=7.4.0", - "resend~=2.26.0", - "sentry-sdk[flask]~=2.55.0", - "sqlalchemy~=2.0.29", + "resend~=2.27.0", + "sentry-sdk[flask]~=2.57.0", + "sqlalchemy~=2.0.49", "starlette==1.0.0", "tiktoken~=0.12.0", "transformers~=5.3.0", @@ -82,12 +82,12 @@ dependencies = [ "yarl~=1.23.0", "sseclient-py~=1.9.0", "httpx-sse~=0.4.0", - "sendgrid~=6.12.3", + "sendgrid~=6.12.5", "flask-restx~=1.3.2", - "packaging~=23.2", - "croniter>=6.0.0", - "weaviate-client==4.20.4", - "apscheduler>=3.11.0", + "packaging~=26.0", + "croniter>=6.2.2", + "weaviate-client==4.20.5", + "apscheduler>=3.11.2", "weave>=0.52.16", "fastopenapi[flask]>=0.7.0", "bleach~=6.3.0", @@ -111,16 +111,16 @@ package = false dev = [ "coverage~=7.13.4", "dotenv-linter~=0.7.0", - "faker~=40.12.0", + "faker~=40.13.0", "lxml-stubs~=0.5.1", "basedpyright~=1.39.0", - "ruff~=0.15.5", - "pytest~=9.0.2", + "ruff~=0.15.10", + "pytest~=9.0.3", "pytest-benchmark~=5.2.3", "pytest-cov~=7.1.0", "pytest-env~=1.6.0", "pytest-mock~=3.15.1", - "testcontainers~=4.14.1", + "testcontainers~=4.14.2", "types-aiofiles~=25.1.0", "types-beautifulsoup4~=4.12.0", "types-cachetools~=6.2.0", @@ -130,8 +130,8 @@ dev = [ "types-docutils~=0.22.3", "types-flask-cors~=6.0.0", "types-flask-migrate~=4.1.0", - "types-gevent~=25.9.0", - "types-greenlet~=3.3.0", + "types-gevent~=26.4.0", + "types-greenlet~=3.4.0", "types-html5lib~=1.1.11", "types-markdown~=3.10.2", "types-oauthlib~=3.3.0", @@ -149,20 +149,20 @@ dev = [ "types-pyyaml~=6.0.12", "types-regex~=2026.4.4", "types-shapely~=2.1.0", - "types-simplejson>=3.20.0", - "types-six>=1.17.0", - "types-tensorflow>=2.18.0", - "types-tqdm>=4.67.0", + "types-simplejson>=3.20.0.20260408", + "types-six>=1.17.0.20260408", + "types-tensorflow>=2.18.0.20260408", + "types-tqdm>=4.67.3.20260408", "types-ujson>=5.10.0", - "boto3-stubs>=1.38.20", - "types-jmespath>=1.0.2.20240106", - "hypothesis>=6.131.15", + "boto3-stubs>=1.42.88", + "types-jmespath>=1.1.0.20260408", + "hypothesis>=6.151.12", "types_pyOpenSSL>=24.1.0", - "types_cffi>=1.17.0", - "types_setuptools>=80.9.0", + "types_cffi>=2.0.0.20260408", + "types_setuptools>=82.0.0.20260408", "pandas-stubs~=3.0.0", "scipy-stubs>=1.15.3.0", - "types-python-http-client>=3.3.7.20240910", + "types-python-http-client>=3.3.7.20260408", "import-linter>=2.3", "types-redis>=4.6.0.20241004", "celery-types>=0.23.0", @@ -180,10 +180,10 @@ dev = [ ############################################################ storage = [ "azure-storage-blob==12.28.0", - "bce-python-sdk~=0.9.23", + "bce-python-sdk~=0.9.69", "cos-python-sdk-v5==1.9.41", "esdk-obs-python==3.26.2", - "google-cloud-storage>=3.0.0", + "google-cloud-storage>=3.10.1", "opendal~=0.46.0", "oss2==2.19.1", "supabase~=2.18.1", @@ -209,23 +209,23 @@ vdb = [ "elasticsearch==8.14.0", "opensearch-py==3.1.0", "oracledb==3.4.2", - "pgvecto-rs[sqlalchemy]~=0.2.1", + "pgvecto-rs[sqlalchemy]~=0.2.2", "pgvector==0.4.2", - "pymilvus~=2.6.10", + "pymilvus~=2.6.12", "pymochow==2.4.0", "pyobvector~=0.2.17", "qdrant-client==1.9.0", "intersystems-irispython>=5.1.0", - "tablestore==6.4.3", + "tablestore==6.4.4", "tcvectordb~=2.1.0", "tidb-vector==0.0.15", "upstash-vector==0.8.0", "volcengine-compat~=1.0.0", - "weaviate-client==4.20.4", + "weaviate-client==4.20.5", "xinference-client~=2.4.0", "mo-vector~=0.1.13", "mysql-connector-python>=9.3.0", - "holo-search-sdk>=0.4.1", + "holo-search-sdk>=0.4.2", ] [tool.pyrefly] diff --git a/api/pyrightconfig.json b/api/pyrightconfig.json index a8b884ea81..424563bc11 100644 --- a/api/pyrightconfig.json +++ b/api/pyrightconfig.json @@ -47,7 +47,6 @@ "reportMissingTypeArgument": "hint", "reportUnnecessaryComparison": "hint", "reportUnnecessaryIsInstance": "hint", - "reportUntypedFunctionDecorator": "hint", "reportUnnecessaryTypeIgnoreComment": "hint", "reportAttributeAccessIssue": "hint", "pythonVersion": "3.12", diff --git a/api/services/account_service.py b/api/services/account_service.py index 4b58b3b697..ccc4a7c1fa 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -9,7 +9,8 @@ from typing import Any, TypedDict, cast from pydantic import BaseModel, TypeAdapter from sqlalchemy import delete, func, select, update -from sqlalchemy.orm import Session, sessionmaker + +from core.db.session_factory import session_factory class InvitationData(TypedDict): @@ -800,19 +801,19 @@ class AccountService: return token @staticmethod - def get_account_by_email_with_case_fallback(email: str, session: Session | None = None) -> Account | None: + def get_account_by_email_with_case_fallback(email: str) -> Account | None: """ Retrieve an account by email and fall back to the lowercase email if the original lookup fails. This keeps backward compatibility for older records that stored uppercase emails while the rest of the system gradually normalizes new inputs. """ - query_session = session or db.session - account = query_session.execute(select(Account).filter_by(email=email)).scalar_one_or_none() - if account or email == email.lower(): - return account + with session_factory.create_session() as session: + account = session.execute(select(Account).where(Account.email == email)).scalar_one_or_none() + if account or email == email.lower(): + return account - return query_session.execute(select(Account).filter_by(email=email.lower())).scalar_one_or_none() + return session.execute(select(Account).where(Account.email == email.lower())).scalar_one_or_none() @classmethod def get_email_code_login_data(cls, token: str) -> dict[str, Any] | None: @@ -1516,8 +1517,7 @@ class RegisterService: check_workspace_member_invite_permission(tenant.id) - with sessionmaker(db.engine, expire_on_commit=False).begin() as session: - account = AccountService.get_account_by_email_with_case_fallback(email, session=session) + account = AccountService.get_account_by_email_with_case_fallback(email) if not account: TenantService.check_member_permission(tenant, inviter, None, "add") diff --git a/api/services/annotation_service.py b/api/services/annotation_service.py index ae5facbec0..ff0882ad5c 100644 --- a/api/services/annotation_service.py +++ b/api/services/annotation_service.py @@ -1,11 +1,8 @@ import logging import uuid - -import pandas as pd - -logger = logging.getLogger(__name__) from typing import TypedDict +import pandas as pd from sqlalchemy import delete, or_, select, update from werkzeug.datastructures import FileStorage from werkzeug.exceptions import NotFound @@ -24,6 +21,8 @@ from tasks.annotation.disable_annotation_reply_task import disable_annotation_re from tasks.annotation.enable_annotation_reply_task import enable_annotation_reply_task from tasks.annotation.update_annotation_to_index_task import update_annotation_to_index_task +logger = logging.getLogger(__name__) + class AnnotationJobStatusDict(TypedDict): job_id: str @@ -46,9 +45,50 @@ class AnnotationSettingDisabledDict(TypedDict): enabled: bool +class EnableAnnotationArgs(TypedDict): + """Expected shape of the args dict passed to enable_app_annotation.""" + + score_threshold: float + embedding_provider_name: str + embedding_model_name: str + + +class UpsertAnnotationArgs(TypedDict, total=False): + """Expected shape of the args dict passed to up_insert_app_annotation_from_message.""" + + answer: str + content: str + message_id: str + question: str + + +class InsertAnnotationArgs(TypedDict): + """Expected shape of the args dict passed to insert_app_annotation_directly.""" + + question: str + answer: str + + +class UpdateAnnotationArgs(TypedDict, total=False): + """Expected shape of the args dict passed to update_app_annotation_directly. + + Both fields are optional at the type level; the service validates at runtime + and raises ValueError if either is missing. + """ + + answer: str + question: str + + +class UpdateAnnotationSettingArgs(TypedDict): + """Expected shape of the args dict passed to update_app_annotation_setting.""" + + score_threshold: float + + class AppAnnotationService: @classmethod - def up_insert_app_annotation_from_message(cls, args: dict, app_id: str) -> MessageAnnotation: + def up_insert_app_annotation_from_message(cls, args: UpsertAnnotationArgs, app_id: str) -> MessageAnnotation: # get app info current_user, current_tenant_id = current_account_with_tenant() app = db.session.scalar( @@ -62,8 +102,9 @@ class AppAnnotationService: if answer is None: raise ValueError("Either 'answer' or 'content' must be provided") - if args.get("message_id"): - message_id = str(args["message_id"]) + raw_message_id = args.get("message_id") + if raw_message_id: + message_id = str(raw_message_id) message = db.session.scalar( select(Message).where(Message.id == message_id, Message.app_id == app.id).limit(1) ) @@ -87,9 +128,10 @@ class AppAnnotationService: account_id=current_user.id, ) else: - question = args.get("question") - if not question: + maybe_question = args.get("question") + if not maybe_question: raise ValueError("'question' is required when 'message_id' is not provided") + question = maybe_question annotation = MessageAnnotation(app_id=app.id, content=answer, question=question, account_id=current_user.id) db.session.add(annotation) @@ -110,7 +152,7 @@ class AppAnnotationService: return annotation @classmethod - def enable_app_annotation(cls, args: dict, app_id: str) -> AnnotationJobStatusDict: + def enable_app_annotation(cls, args: EnableAnnotationArgs, app_id: str) -> AnnotationJobStatusDict: enable_app_annotation_key = f"enable_app_annotation_{str(app_id)}" cache_result = redis_client.get(enable_app_annotation_key) if cache_result is not None: @@ -217,7 +259,7 @@ class AppAnnotationService: return annotations @classmethod - def insert_app_annotation_directly(cls, args: dict, app_id: str) -> MessageAnnotation: + def insert_app_annotation_directly(cls, args: InsertAnnotationArgs, app_id: str) -> MessageAnnotation: # get app info current_user, current_tenant_id = current_account_with_tenant() app = db.session.scalar( @@ -251,7 +293,7 @@ class AppAnnotationService: return annotation @classmethod - def update_app_annotation_directly(cls, args: dict, app_id: str, annotation_id: str): + def update_app_annotation_directly(cls, args: UpdateAnnotationArgs, app_id: str, annotation_id: str): # get app info _, current_tenant_id = current_account_with_tenant() app = db.session.scalar( @@ -270,7 +312,11 @@ class AppAnnotationService: if question is None: raise ValueError("'question' is required") - annotation.content = args["answer"] + answer = args.get("answer") + if answer is None: + raise ValueError("'answer' is required") + + annotation.content = answer annotation.question = question db.session.commit() @@ -613,7 +659,7 @@ class AppAnnotationService: @classmethod def update_app_annotation_setting( - cls, app_id: str, annotation_setting_id: str, args: dict + cls, app_id: str, annotation_setting_id: str, args: UpdateAnnotationSettingArgs ) -> AnnotationSettingDict: current_user, current_tenant_id = current_account_with_tenant() # get app info diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index c6c8a15109..40e1e5f8ab 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -467,61 +467,67 @@ class AppDslService: ) # Initialize app based on mode - if app_mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: - workflow_data = data.get("workflow") - if not workflow_data or not isinstance(workflow_data, dict): - raise ValueError("Missing workflow data for workflow/advanced chat app") + match app_mode: + case AppMode.ADVANCED_CHAT | AppMode.WORKFLOW: + workflow_data = data.get("workflow") + if not workflow_data or not isinstance(workflow_data, dict): + raise ValueError("Missing workflow data for workflow/advanced chat app") - environment_variables_list = workflow_data.get("environment_variables", []) - environment_variables = [ - variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list - ] - conversation_variables_list = workflow_data.get("conversation_variables", []) - conversation_variables = [ - variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list - ] + environment_variables_list = workflow_data.get("environment_variables", []) + environment_variables = [ + variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list + ] + conversation_variables_list = workflow_data.get("conversation_variables", []) + conversation_variables = [ + variable_factory.build_conversation_variable_from_mapping(obj) + for obj in conversation_variables_list + ] - workflow_service = WorkflowService() - current_draft_workflow = workflow_service.get_draft_workflow(app_model=app) - if current_draft_workflow: - unique_hash = current_draft_workflow.unique_hash - else: - unique_hash = None - graph = workflow_data.get("graph", {}) - for node in graph.get("nodes", []): - if node.get("data", {}).get("type", "") == BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL: - dataset_ids = node["data"].get("dataset_ids", []) - node["data"]["dataset_ids"] = [ - decrypted_id - for dataset_id in dataset_ids - if (decrypted_id := self.decrypt_dataset_id(encrypted_data=dataset_id, tenant_id=app.tenant_id)) - ] - workflow_service.sync_draft_workflow( - app_model=app, - graph=workflow_data.get("graph", {}), - features=workflow_data.get("features", {}), - unique_hash=unique_hash, - account=account, - environment_variables=environment_variables, - conversation_variables=conversation_variables, - ) - elif app_mode in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.COMPLETION}: - # Initialize model config - model_config = data.get("model_config") - if not model_config or not isinstance(model_config, dict): - raise ValueError("Missing model_config for chat/agent-chat/completion app") - # Initialize or update model config - if not app.app_model_config: - app_model_config = AppModelConfig( - app_id=app.id, created_by=account.id, updated_by=account.id - ).from_model_config_dict(cast(AppModelConfigDict, model_config)) - app_model_config.id = str(uuid4()) - app.app_model_config_id = app_model_config.id + workflow_service = WorkflowService() + current_draft_workflow = workflow_service.get_draft_workflow(app_model=app) + if current_draft_workflow: + unique_hash = current_draft_workflow.unique_hash + else: + unique_hash = None + graph = workflow_data.get("graph", {}) + for node in graph.get("nodes", []): + if node.get("data", {}).get("type", "") == BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL: + dataset_ids = node["data"].get("dataset_ids", []) + node["data"]["dataset_ids"] = [ + decrypted_id + for dataset_id in dataset_ids + if ( + decrypted_id := self.decrypt_dataset_id( + encrypted_data=dataset_id, tenant_id=app.tenant_id + ) + ) + ] + workflow_service.sync_draft_workflow( + app_model=app, + graph=workflow_data.get("graph", {}), + features=workflow_data.get("features", {}), + unique_hash=unique_hash, + account=account, + environment_variables=environment_variables, + conversation_variables=conversation_variables, + ) + case AppMode.CHAT | AppMode.AGENT_CHAT | AppMode.COMPLETION: + # Initialize model config + model_config = data.get("model_config") + if not model_config or not isinstance(model_config, dict): + raise ValueError("Missing model_config for chat/agent-chat/completion app") + # Initialize or update model config + if not app.app_model_config: + app_model_config = AppModelConfig( + app_id=app.id, created_by=account.id, updated_by=account.id + ).from_model_config_dict(cast(AppModelConfigDict, model_config)) + app_model_config.id = str(uuid4()) + app.app_model_config_id = app_model_config.id - self._session.add(app_model_config) - app_model_config_was_updated.send(app, app_model_config=app_model_config) - else: - raise ValueError("Invalid app mode") + self._session.add(app_model_config) + app_model_config_was_updated.send(app, app_model_config=app_model_config) + case _: + raise ValueError("Invalid app mode") return app @classmethod diff --git a/api/services/app_generate_service.py b/api/services/app_generate_service.py index a6639dc780..2c9d815b64 100644 --- a/api/services/app_generate_service.py +++ b/api/services/app_generate_service.py @@ -4,7 +4,7 @@ import logging import threading import uuid from collections.abc import Callable, Generator, Mapping -from typing import TYPE_CHECKING, Any, Union +from typing import TYPE_CHECKING, Any from configs import dify_config from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator @@ -89,7 +89,7 @@ class AppGenerateService: def generate( cls, app_model: App, - user: Union[Account, EndUser], + user: Account | EndUser, args: Mapping[str, Any], invoke_from: InvokeFrom, streaming: bool = True, @@ -358,11 +358,11 @@ class AppGenerateService: def generate_more_like_this( cls, app_model: App, - user: Union[Account, EndUser], + user: Account | EndUser, message_id: str, invoke_from: InvokeFrom, streaming: bool = True, - ) -> Union[Mapping, Generator]: + ) -> Mapping | Generator: """ Generate more like this :param app_model: app model diff --git a/api/services/async_workflow_service.py b/api/services/async_workflow_service.py index b4471f51d8..8b39d63385 100644 --- a/api/services/async_workflow_service.py +++ b/api/services/async_workflow_service.py @@ -7,7 +7,7 @@ with support for different subscription tiers, rate limiting, and execution trac import json from datetime import UTC, datetime -from typing import Any, Union +from typing import Any from celery.result import AsyncResult from sqlalchemy import select @@ -51,7 +51,7 @@ class AsyncWorkflowService: @classmethod def trigger_workflow_async( - cls, session: Session, user: Union[Account, EndUser], trigger_data: TriggerData + cls, session: Session, user: Account | EndUser, trigger_data: TriggerData ) -> AsyncTriggerResponse: """ Universal entry point for async workflow execution - THIS METHOD WILL NOT BLOCK @@ -184,7 +184,7 @@ class AsyncWorkflowService: @classmethod def reinvoke_trigger( - cls, session: Session, user: Union[Account, EndUser], workflow_trigger_log_id: str + cls, session: Session, user: Account | EndUser, workflow_trigger_log_id: str ) -> AsyncTriggerResponse: """ Re-invoke a previously failed or rate-limited trigger - THIS METHOD WILL NOT BLOCK diff --git a/api/services/clear_free_plan_tenant_expired_logs.py b/api/services/clear_free_plan_tenant_expired_logs.py index b0f7efaccd..ea12e40420 100644 --- a/api/services/clear_free_plan_tenant_expired_logs.py +++ b/api/services/clear_free_plan_tenant_expired_logs.py @@ -7,7 +7,7 @@ from concurrent.futures import ThreadPoolExecutor import click from flask import Flask, current_app from graphon.model_runtime.utils.encoders import jsonable_encoder -from sqlalchemy import select +from sqlalchemy import delete, func, select from sqlalchemy.orm import Session, sessionmaker from configs import dify_config @@ -62,13 +62,11 @@ class ClearFreePlanTenantExpiredLogs: for model, table_name in related_tables: # Query records related to expired messages - records = ( - session.query(model) - .where( + records = session.scalars( + select(model).where( model.message_id.in_(batch_message_ids), # type: ignore ) - .all() - ) + ).all() if len(records) == 0: continue @@ -103,9 +101,13 @@ class ClearFreePlanTenantExpiredLogs: except Exception: logger.exception("Failed to save %s records", table_name) - session.query(model).where( - model.id.in_(record_ids), # type: ignore - ).delete(synchronize_session=False) + session.execute( + delete(model) + .where( + model.id.in_(record_ids), # type: ignore + ) + .execution_options(synchronize_session=False) + ) click.echo( click.style( @@ -121,15 +123,14 @@ class ClearFreePlanTenantExpiredLogs: app_ids = [app.id for app in apps] while True: with sessionmaker(bind=db.engine, autoflush=False).begin() as session: - messages = ( - session.query(Message) + messages = session.scalars( + select(Message) .where( Message.app_id.in_(app_ids), Message.created_at < datetime.datetime.now() - datetime.timedelta(days=days), ) .limit(batch) - .all() - ) + ).all() if len(messages) == 0: break @@ -147,9 +148,9 @@ class ClearFreePlanTenantExpiredLogs: message_ids = [message.id for message in messages] # delete messages - session.query(Message).where( - Message.id.in_(message_ids), - ).delete(synchronize_session=False) + session.execute( + delete(Message).where(Message.id.in_(message_ids)).execution_options(synchronize_session=False) + ) cls._clear_message_related_tables(session, tenant_id, message_ids) @@ -161,15 +162,14 @@ class ClearFreePlanTenantExpiredLogs: while True: with sessionmaker(bind=db.engine, autoflush=False).begin() as session: - conversations = ( - session.query(Conversation) + conversations = session.scalars( + select(Conversation) .where( Conversation.app_id.in_(app_ids), Conversation.updated_at < datetime.datetime.now() - datetime.timedelta(days=days), ) .limit(batch) - .all() - ) + ).all() if len(conversations) == 0: break @@ -186,9 +186,11 @@ class ClearFreePlanTenantExpiredLogs: ) conversation_ids = [conversation.id for conversation in conversations] - session.query(Conversation).where( - Conversation.id.in_(conversation_ids), - ).delete(synchronize_session=False) + session.execute( + delete(Conversation) + .where(Conversation.id.in_(conversation_ids)) + .execution_options(synchronize_session=False) + ) click.echo( click.style( @@ -293,15 +295,14 @@ class ClearFreePlanTenantExpiredLogs: while True: with sessionmaker(bind=db.engine, autoflush=False).begin() as session: - workflow_app_logs = ( - session.query(WorkflowAppLog) + workflow_app_logs = session.scalars( + select(WorkflowAppLog) .where( WorkflowAppLog.tenant_id == tenant_id, WorkflowAppLog.created_at < datetime.datetime.now() - datetime.timedelta(days=days), ) .limit(batch) - .all() - ) + ).all() if len(workflow_app_logs) == 0: break @@ -321,8 +322,10 @@ class ClearFreePlanTenantExpiredLogs: workflow_app_log_ids = [workflow_app_log.id for workflow_app_log in workflow_app_logs] # delete workflow app logs - session.query(WorkflowAppLog).where(WorkflowAppLog.id.in_(workflow_app_log_ids)).delete( - synchronize_session=False + session.execute( + delete(WorkflowAppLog) + .where(WorkflowAppLog.id.in_(workflow_app_log_ids)) + .execution_options(synchronize_session=False) ) click.echo( @@ -344,7 +347,7 @@ class ClearFreePlanTenantExpiredLogs: current_time = started_at with sessionmaker(db.engine).begin() as session: - total_tenant_count = session.query(Tenant.id).count() + total_tenant_count = session.scalar(select(func.count(Tenant.id))) or 0 click.echo(click.style(f"Total tenant count: {total_tenant_count}", fg="white")) @@ -409,9 +412,12 @@ class ClearFreePlanTenantExpiredLogs: tenant_count = 0 for test_interval in test_intervals: tenant_count = ( - session.query(Tenant.id) - .where(Tenant.created_at.between(current_time, current_time + test_interval)) - .count() + session.scalar( + select(func.count(Tenant.id)).where( + Tenant.created_at.between(current_time, current_time + test_interval) + ) + ) + or 0 ) if tenant_count <= 100: interval = test_interval @@ -433,8 +439,8 @@ class ClearFreePlanTenantExpiredLogs: batch_end = min(current_time + interval, ended_at) - rs = ( - session.query(Tenant.id) + rs = session.execute( + select(Tenant.id) .where(Tenant.created_at.between(current_time, batch_end)) .order_by(Tenant.created_at) ) diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 3e952059ac..e07e01ad42 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -528,6 +528,8 @@ class DatasetService: raise ValueError("External knowledge id is required.") if not external_knowledge_api_id: raise ValueError("External knowledge api id is required.") + # Ensure the referenced external API template exists and belongs to the dataset tenant. + ExternalDatasetService.get_external_knowledge_api(external_knowledge_api_id, dataset.tenant_id) # Update metadata fields dataset.updated_by = user.id if user else None dataset.updated_at = naive_utc_now() @@ -552,8 +554,8 @@ class DatasetService: external_knowledge_api_id: External knowledge API identifier """ with sessionmaker(db.engine).begin() as session: - external_knowledge_binding = ( - session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id).first() + external_knowledge_binding = session.scalar( + select(ExternalKnowledgeBindings).where(ExternalKnowledgeBindings.dataset_id == dataset_id).limit(1) ) if not external_knowledge_binding: @@ -1454,15 +1456,17 @@ class DocumentService: document_id_list: list[str] = [str(document_id) for document_id in document_ids] with session_factory.create_session() as session: - updated_count = ( - session.query(Document) - .filter( + result = session.execute( + update(Document) + .where( Document.id.in_(document_id_list), Document.dataset_id == dataset_id, Document.doc_form != IndexStructureType.QA_INDEX, # Skip qa_model documents ) - .update({Document.need_summary: need_summary}, synchronize_session=False) + .values(need_summary=need_summary) + .execution_options(synchronize_session=False) ) + updated_count = result.rowcount # type: ignore[union-attr,attr-defined] session.commit() logger.info( "Updated need_summary to %s for %d documents in dataset %s", @@ -2822,6 +2826,10 @@ class DocumentService: knowledge_config.process_rule.rules.pre_processing_rules = list(unique_pre_processing_rule_dicts.values()) + if knowledge_config.process_rule.mode == ProcessRuleMode.HIERARCHICAL: + if not knowledge_config.process_rule.rules.parent_mode: + knowledge_config.process_rule.rules.parent_mode = "paragraph" + if not knowledge_config.process_rule.rules.segmentation: raise ValueError("Process rule segmentation is required") diff --git a/api/services/datasource_provider_service.py b/api/services/datasource_provider_service.py index d5f8cd30bd..9e7de36593 100644 --- a/api/services/datasource_provider_service.py +++ b/api/services/datasource_provider_service.py @@ -4,7 +4,7 @@ from collections.abc import Mapping from typing import Any from graphon.model_runtime.entities.provider_entities import FormType -from sqlalchemy import func, select +from sqlalchemy import delete, func, select, update from sqlalchemy.orm import Session, sessionmaker from configs import dify_config @@ -54,11 +54,13 @@ class DatasourceProviderService: remove oauth custom client params """ with sessionmaker(bind=db.engine).begin() as session: - session.query(DatasourceOauthTenantParamConfig).filter_by( - tenant_id=tenant_id, - provider=datasource_provider_id.provider_name, - plugin_id=datasource_provider_id.plugin_id, - ).delete() + session.execute( + delete(DatasourceOauthTenantParamConfig).where( + DatasourceOauthTenantParamConfig.tenant_id == tenant_id, + DatasourceOauthTenantParamConfig.provider == datasource_provider_id.provider_name, + DatasourceOauthTenantParamConfig.plugin_id == datasource_provider_id.plugin_id, + ) + ) def decrypt_datasource_provider_credentials( self, @@ -110,15 +112,21 @@ class DatasourceProviderService: """ with sessionmaker(bind=db.engine).begin() as session: if credential_id: - datasource_provider = ( - session.query(DatasourceProvider).filter_by(tenant_id=tenant_id, id=credential_id).first() + datasource_provider = session.scalar( + select(DatasourceProvider) + .where(DatasourceProvider.tenant_id == tenant_id, DatasourceProvider.id == credential_id) + .limit(1) ) else: - datasource_provider = ( - session.query(DatasourceProvider) - .filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id) + datasource_provider = session.scalar( + select(DatasourceProvider) + .where( + DatasourceProvider.tenant_id == tenant_id, + DatasourceProvider.provider == provider, + DatasourceProvider.plugin_id == plugin_id, + ) .order_by(DatasourceProvider.is_default.desc(), DatasourceProvider.created_at.asc()) - .first() + .limit(1) ) if not datasource_provider: return {} @@ -173,12 +181,15 @@ class DatasourceProviderService: get all datasource credentials by provider """ with sessionmaker(bind=db.engine).begin() as session: - datasource_providers = ( - session.query(DatasourceProvider) - .filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id) + datasource_providers = session.scalars( + select(DatasourceProvider) + .where( + DatasourceProvider.tenant_id == tenant_id, + DatasourceProvider.provider == provider, + DatasourceProvider.plugin_id == plugin_id, + ) .order_by(DatasourceProvider.is_default.desc(), DatasourceProvider.created_at.asc()) - .all() - ) + ).all() if not datasource_providers: return [] current_user = get_current_user() @@ -232,15 +243,15 @@ class DatasourceProviderService: update datasource provider name """ with sessionmaker(bind=db.engine).begin() as session: - target_provider = ( - session.query(DatasourceProvider) - .filter_by( - tenant_id=tenant_id, - id=credential_id, - provider=datasource_provider_id.provider_name, - plugin_id=datasource_provider_id.plugin_id, + target_provider = session.scalar( + select(DatasourceProvider) + .where( + DatasourceProvider.tenant_id == tenant_id, + DatasourceProvider.id == credential_id, + DatasourceProvider.provider == datasource_provider_id.provider_name, + DatasourceProvider.plugin_id == datasource_provider_id.plugin_id, ) - .first() + .limit(1) ) if target_provider is None: raise ValueError("provider not found") @@ -250,16 +261,16 @@ class DatasourceProviderService: # check name is exist if ( - session.query(DatasourceProvider) - .filter_by( - tenant_id=tenant_id, - name=name, - provider=datasource_provider_id.provider_name, - plugin_id=datasource_provider_id.plugin_id, + session.scalar( + select(func.count(DatasourceProvider.id)).where( + DatasourceProvider.tenant_id == tenant_id, + DatasourceProvider.name == name, + DatasourceProvider.provider == datasource_provider_id.provider_name, + DatasourceProvider.plugin_id == datasource_provider_id.plugin_id, + ) ) - .count() - > 0 - ): + or 0 + ) > 0: raise ValueError("Authorization name is already exists") target_provider.name = name @@ -273,26 +284,31 @@ class DatasourceProviderService: """ with sessionmaker(bind=db.engine).begin() as session: # get provider - target_provider = ( - session.query(DatasourceProvider) - .filter_by( - tenant_id=tenant_id, - id=credential_id, - provider=datasource_provider_id.provider_name, - plugin_id=datasource_provider_id.plugin_id, + target_provider = session.scalar( + select(DatasourceProvider) + .where( + DatasourceProvider.tenant_id == tenant_id, + DatasourceProvider.id == credential_id, + DatasourceProvider.provider == datasource_provider_id.provider_name, + DatasourceProvider.plugin_id == datasource_provider_id.plugin_id, ) - .first() + .limit(1) ) if target_provider is None: raise ValueError("provider not found") # clear default provider - session.query(DatasourceProvider).filter_by( - tenant_id=tenant_id, - provider=target_provider.provider, - plugin_id=target_provider.plugin_id, - is_default=True, - ).update({"is_default": False}) + session.execute( + update(DatasourceProvider) + .where( + DatasourceProvider.tenant_id == tenant_id, + DatasourceProvider.provider == target_provider.provider, + DatasourceProvider.plugin_id == target_provider.plugin_id, + DatasourceProvider.is_default.is_(True), + ) + .values(is_default=False) + .execution_options(synchronize_session=False) + ) # set new default provider target_provider.is_default = True @@ -311,14 +327,14 @@ class DatasourceProviderService: if client_params is None and enabled is None: return with sessionmaker(bind=db.engine).begin() as session: - tenant_oauth_client_params = ( - session.query(DatasourceOauthTenantParamConfig) - .filter_by( - tenant_id=tenant_id, - provider=datasource_provider_id.provider_name, - plugin_id=datasource_provider_id.plugin_id, + tenant_oauth_client_params = session.scalar( + select(DatasourceOauthTenantParamConfig) + .where( + DatasourceOauthTenantParamConfig.tenant_id == tenant_id, + DatasourceOauthTenantParamConfig.provider == datasource_provider_id.provider_name, + DatasourceOauthTenantParamConfig.plugin_id == datasource_provider_id.plugin_id, ) - .first() + .limit(1) ) if not tenant_oauth_client_params: @@ -351,9 +367,14 @@ class DatasourceProviderService: """ with Session(db.engine).no_autoflush as session: return ( - session.query(DatasourceOauthParamConfig) - .filter_by(provider=datasource_provider_id.provider_name, plugin_id=datasource_provider_id.plugin_id) - .first() + session.scalar( + select(DatasourceOauthParamConfig) + .where( + DatasourceOauthParamConfig.provider == datasource_provider_id.provider_name, + DatasourceOauthParamConfig.plugin_id == datasource_provider_id.plugin_id, + ) + .limit(1) + ) is not None ) @@ -423,15 +444,15 @@ class DatasourceProviderService: plugin_id = datasource_provider_id.plugin_id with Session(db.engine).no_autoflush as session: # get tenant oauth client params - tenant_oauth_client_params = ( - session.query(DatasourceOauthTenantParamConfig) - .filter_by( - tenant_id=tenant_id, - provider=provider, - plugin_id=plugin_id, - enabled=True, + tenant_oauth_client_params = session.scalar( + select(DatasourceOauthTenantParamConfig) + .where( + DatasourceOauthTenantParamConfig.tenant_id == tenant_id, + DatasourceOauthTenantParamConfig.provider == provider, + DatasourceOauthTenantParamConfig.plugin_id == plugin_id, + DatasourceOauthTenantParamConfig.enabled.is_(True), ) - .first() + .limit(1) ) if tenant_oauth_client_params: encrypter, _ = self.get_oauth_encrypter(tenant_id, datasource_provider_id) @@ -443,8 +464,13 @@ class DatasourceProviderService: is_verified = PluginService.is_plugin_verified(tenant_id, provider_controller.plugin_unique_identifier) if is_verified: # fallback to system oauth client params - oauth_client_params = ( - session.query(DatasourceOauthParamConfig).filter_by(provider=provider, plugin_id=plugin_id).first() + oauth_client_params = session.scalar( + select(DatasourceOauthParamConfig) + .where( + DatasourceOauthParamConfig.provider == provider, + DatasourceOauthParamConfig.plugin_id == plugin_id, + ) + .limit(1) ) if oauth_client_params: return oauth_client_params.system_credentials @@ -455,15 +481,13 @@ class DatasourceProviderService: def generate_next_datasource_provider_name( session: Session, tenant_id: str, provider_id: DatasourceProviderID, credential_type: CredentialType ) -> str: - db_providers = ( - session.query(DatasourceProvider) - .filter_by( - tenant_id=tenant_id, - provider=provider_id.provider_name, - plugin_id=provider_id.plugin_id, + db_providers = session.scalars( + select(DatasourceProvider).where( + DatasourceProvider.tenant_id == tenant_id, + DatasourceProvider.provider == provider_id.provider_name, + DatasourceProvider.plugin_id == provider_id.plugin_id, ) - .all() - ) + ).all() return generate_incremental_name( [provider.name for provider in db_providers], f"{credential_type.get_name()}", @@ -485,8 +509,10 @@ class DatasourceProviderService: with sessionmaker(bind=db.engine).begin() as session: lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{CredentialType.OAUTH2.value}" with redis_client.lock(lock, timeout=20): - target_provider = ( - session.query(DatasourceProvider).filter_by(id=credential_id, tenant_id=tenant_id).first() + target_provider = session.scalar( + select(DatasourceProvider) + .where(DatasourceProvider.id == credential_id, DatasourceProvider.tenant_id == tenant_id) + .limit(1) ) if target_provider is None: raise ValueError("provider not found") @@ -496,25 +522,28 @@ class DatasourceProviderService: db_provider_name = target_provider.name else: name_conflict = ( - session.query(DatasourceProvider) - .filter_by( - tenant_id=tenant_id, - name=db_provider_name, - provider=provider_id.provider_name, - plugin_id=provider_id.plugin_id, - auth_type=CredentialType.OAUTH2.value, + session.scalar( + select(func.count(DatasourceProvider.id)).where( + DatasourceProvider.tenant_id == tenant_id, + DatasourceProvider.name == db_provider_name, + DatasourceProvider.provider == provider_id.provider_name, + DatasourceProvider.plugin_id == provider_id.plugin_id, + DatasourceProvider.auth_type == CredentialType.OAUTH2.value, + ) ) - .count() + or 0 ) if name_conflict > 0: db_provider_name = generate_incremental_name( [ provider.name - for provider in session.query(DatasourceProvider).filter_by( - tenant_id=tenant_id, - provider=provider_id.provider_name, - plugin_id=provider_id.plugin_id, - ) + for provider in session.scalars( + select(DatasourceProvider).where( + DatasourceProvider.tenant_id == tenant_id, + DatasourceProvider.provider == provider_id.provider_name, + DatasourceProvider.plugin_id == provider_id.plugin_id, + ) + ).all() ], db_provider_name, ) @@ -556,25 +585,27 @@ class DatasourceProviderService: ) else: if ( - session.query(DatasourceProvider) - .filter_by( - tenant_id=tenant_id, - name=db_provider_name, - provider=provider_id.provider_name, - plugin_id=provider_id.plugin_id, - auth_type=credential_type.value, + session.scalar( + select(func.count(DatasourceProvider.id)).where( + DatasourceProvider.tenant_id == tenant_id, + DatasourceProvider.name == db_provider_name, + DatasourceProvider.provider == provider_id.provider_name, + DatasourceProvider.plugin_id == provider_id.plugin_id, + DatasourceProvider.auth_type == credential_type.value, + ) ) - .count() - > 0 - ): + or 0 + ) > 0: db_provider_name = generate_incremental_name( [ provider.name - for provider in session.query(DatasourceProvider).filter_by( - tenant_id=tenant_id, - provider=provider_id.provider_name, - plugin_id=provider_id.plugin_id, - ) + for provider in session.scalars( + select(DatasourceProvider).where( + DatasourceProvider.tenant_id == tenant_id, + DatasourceProvider.provider == provider_id.provider_name, + DatasourceProvider.plugin_id == provider_id.plugin_id, + ) + ).all() ], db_provider_name, ) @@ -627,11 +658,16 @@ class DatasourceProviderService: # check name is exist if ( - session.query(DatasourceProvider) - .filter_by(tenant_id=tenant_id, plugin_id=plugin_id, provider=provider_name, name=db_provider_name) - .count() - > 0 - ): + session.scalar( + select(func.count(DatasourceProvider.id)).where( + DatasourceProvider.tenant_id == tenant_id, + DatasourceProvider.plugin_id == plugin_id, + DatasourceProvider.provider == provider_name, + DatasourceProvider.name == db_provider_name, + ) + ) + or 0 + ) > 0: raise ValueError("Authorization name is already exists") try: @@ -918,21 +954,31 @@ class DatasourceProviderService: """ with sessionmaker(bind=db.engine).begin() as session: - datasource_provider = ( - session.query(DatasourceProvider) - .filter_by(tenant_id=tenant_id, id=auth_id, provider=provider, plugin_id=plugin_id) - .first() + datasource_provider = session.scalar( + select(DatasourceProvider) + .where( + DatasourceProvider.tenant_id == tenant_id, + DatasourceProvider.id == auth_id, + DatasourceProvider.provider == provider, + DatasourceProvider.plugin_id == plugin_id, + ) + .limit(1) ) if not datasource_provider: raise ValueError("Datasource provider not found") # update name if name and name != datasource_provider.name: if ( - session.query(DatasourceProvider) - .filter_by(tenant_id=tenant_id, name=name, provider=provider, plugin_id=plugin_id) - .count() - > 0 - ): + session.scalar( + select(func.count(DatasourceProvider.id)).where( + DatasourceProvider.tenant_id == tenant_id, + DatasourceProvider.name == name, + DatasourceProvider.provider == provider, + DatasourceProvider.plugin_id == plugin_id, + ) + ) + or 0 + ) > 0: raise ValueError("Authorization name is already exists") datasource_provider.name = name diff --git a/api/services/external_knowledge_service.py b/api/services/external_knowledge_service.py index d30ec940f5..96db644d44 100644 --- a/api/services/external_knowledge_service.py +++ b/api/services/external_knowledge_service.py @@ -1,6 +1,6 @@ import json from copy import deepcopy -from typing import Any, Union, cast +from typing import Any, cast from urllib.parse import urlparse import httpx @@ -148,18 +148,23 @@ class ExternalDatasetService: db.session.commit() @staticmethod - def external_knowledge_api_use_check(external_knowledge_api_id: str) -> tuple[bool, int]: + def external_knowledge_api_use_check(external_knowledge_api_id: str, tenant_id: str) -> tuple[bool, int]: + """ + Return usage for an external knowledge API within a single tenant. + + The caller already scopes access by tenant, so this query must do the + same; otherwise the endpoint becomes a cross-tenant UUID oracle. + """ count = ( db.session.scalar( select(func.count(ExternalKnowledgeBindings.id)).where( - ExternalKnowledgeBindings.external_knowledge_api_id == external_knowledge_api_id + ExternalKnowledgeBindings.external_knowledge_api_id == external_knowledge_api_id, + ExternalKnowledgeBindings.tenant_id == tenant_id, ) ) or 0 ) - if count > 0: - return True, count - return False, 0 + return count > 0, count @staticmethod def get_external_knowledge_binding_with_dataset_id(tenant_id: str, dataset_id: str) -> ExternalKnowledgeBindings: @@ -190,9 +195,7 @@ class ExternalDatasetService: raise ValueError(f"{parameter.get('name')} is required") @staticmethod - def process_external_api( - settings: ExternalKnowledgeApiSetting, files: Union[None, dict[str, Any]] - ) -> httpx.Response: + def process_external_api(settings: ExternalKnowledgeApiSetting, files: dict[str, Any] | None) -> httpx.Response: """ do http request depending on api bundle """ @@ -314,7 +317,10 @@ class ExternalDatasetService: external_knowledge_api = db.session.scalar( select(ExternalKnowledgeApis) - .where(ExternalKnowledgeApis.id == external_knowledge_binding.external_knowledge_api_id) + .where( + ExternalKnowledgeApis.id == external_knowledge_binding.external_knowledge_api_id, + ExternalKnowledgeApis.tenant_id == tenant_id, + ) .limit(1) ) if external_knowledge_api is None or external_knowledge_api.settings is None: diff --git a/api/services/file_service.py b/api/services/file_service.py index 50a326d813..79a935de4b 100644 --- a/api/services/file_service.py +++ b/api/services/file_service.py @@ -5,7 +5,7 @@ import uuid from collections.abc import Iterator, Sequence from contextlib import contextmanager, suppress from tempfile import NamedTemporaryFile -from typing import Literal, Union +from typing import Literal from zipfile import ZIP_DEFLATED, ZipFile from graphon.file import helpers as file_helpers @@ -52,7 +52,7 @@ class FileService: filename: str, content: bytes, mimetype: str, - user: Union[Account, EndUser], + user: Account | EndUser, source: Literal["datasets"] | None = None, source_url: str = "", ) -> UploadFile: @@ -132,8 +132,8 @@ class FileService: return file_size <= file_size_limit def get_file_base64(self, file_id: str) -> str: - upload_file = ( - self._session_maker(expire_on_commit=False).query(UploadFile).where(UploadFile.id == file_id).first() + upload_file = self._session_maker(expire_on_commit=False).scalar( + select(UploadFile).where(UploadFile.id == file_id).limit(1) ) if not upload_file: raise NotFound("File not found") @@ -178,7 +178,7 @@ class FileService: Return a short text preview extracted from a document file. """ with self._session_maker(expire_on_commit=False) as session: - upload_file = session.query(UploadFile).where(UploadFile.id == file_id).first() + upload_file = session.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1)) if not upload_file: raise NotFound("File not found") @@ -200,7 +200,7 @@ class FileService: if not result: raise NotFound("File not found or signature is invalid") with self._session_maker(expire_on_commit=False) as session: - upload_file = session.query(UploadFile).where(UploadFile.id == file_id).first() + upload_file = session.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1)) if not upload_file: raise NotFound("File not found or signature is invalid") @@ -220,7 +220,7 @@ class FileService: raise NotFound("File not found or signature is invalid") with self._session_maker(expire_on_commit=False) as session: - upload_file = session.query(UploadFile).where(UploadFile.id == file_id).first() + upload_file = session.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1)) if not upload_file: raise NotFound("File not found or signature is invalid") @@ -231,7 +231,7 @@ class FileService: def get_public_image_preview(self, file_id: str): with self._session_maker(expire_on_commit=False) as session: - upload_file = session.query(UploadFile).where(UploadFile.id == file_id).first() + upload_file = session.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1)) if not upload_file: raise NotFound("File not found or signature is invalid") @@ -247,7 +247,7 @@ class FileService: def get_file_content(self, file_id: str) -> str: with self._session_maker(expire_on_commit=False) as session: - upload_file: UploadFile | None = session.query(UploadFile).where(UploadFile.id == file_id).first() + upload_file: UploadFile | None = session.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1)) if not upload_file: raise NotFound("File not found") diff --git a/api/services/model_load_balancing_service.py b/api/services/model_load_balancing_service.py index 3cce83a975..41b6b885b2 100644 --- a/api/services/model_load_balancing_service.py +++ b/api/services/model_load_balancing_service.py @@ -1,6 +1,6 @@ import json import logging -from typing import Any, TypedDict, Union +from typing import Any, TypedDict from graphon.model_runtime.entities.model_entities import ModelType from graphon.model_runtime.entities.provider_entities import ( @@ -626,7 +626,7 @@ class ModelLoadBalancingService: def _get_credential_schema( self, provider_configuration: ProviderConfiguration - ) -> Union[ModelCredentialSchema, ProviderCredentialSchema]: + ) -> ModelCredentialSchema | ProviderCredentialSchema: """Get form schemas.""" if provider_configuration.provider.model_credential_schema: return provider_configuration.provider.model_credential_schema diff --git a/api/services/plugin/plugin_auto_upgrade_service.py b/api/services/plugin/plugin_auto_upgrade_service.py index a58bede8db..9bb0ab6ae2 100644 --- a/api/services/plugin/plugin_auto_upgrade_service.py +++ b/api/services/plugin/plugin_auto_upgrade_service.py @@ -1,14 +1,13 @@ from sqlalchemy import select -from sqlalchemy.orm import sessionmaker -from extensions.ext_database import db +from core.db.session_factory import session_factory from models.account import TenantPluginAutoUpgradeStrategy class PluginAutoUpgradeService: @staticmethod def get_strategy(tenant_id: str) -> TenantPluginAutoUpgradeStrategy | None: - with sessionmaker(bind=db.engine).begin() as session: + with session_factory.create_session() as session: return session.scalar( select(TenantPluginAutoUpgradeStrategy) .where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id) @@ -24,7 +23,7 @@ class PluginAutoUpgradeService: exclude_plugins: list[str], include_plugins: list[str], ) -> bool: - with sessionmaker(bind=db.engine).begin() as session: + with session_factory.create_session() as session: exist_strategy = session.scalar( select(TenantPluginAutoUpgradeStrategy) .where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id) @@ -51,7 +50,7 @@ class PluginAutoUpgradeService: @staticmethod def exclude_plugin(tenant_id: str, plugin_id: str) -> bool: - with sessionmaker(bind=db.engine).begin() as session: + with session_factory.create_session() as session: exist_strategy = session.scalar( select(TenantPluginAutoUpgradeStrategy) .where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id) diff --git a/api/services/plugin/plugin_migration.py b/api/services/plugin/plugin_migration.py index d6f6ee8086..43a726b100 100644 --- a/api/services/plugin/plugin_migration.py +++ b/api/services/plugin/plugin_migration.py @@ -13,6 +13,7 @@ import sqlalchemy as sa import tqdm from flask import Flask, current_app from pydantic import TypeAdapter +from sqlalchemy import func, select from sqlalchemy.orm import Session from core.agent.entities import AgentToolEntity @@ -66,7 +67,7 @@ class PluginMigration: current_time = started_at with Session(db.engine) as session: - total_tenant_count = session.query(Tenant.id).count() + total_tenant_count = session.scalar(select(func.count(Tenant.id))) or 0 click.echo(click.style(f"Total tenant count: {total_tenant_count}", fg="white")) @@ -123,9 +124,12 @@ class PluginMigration: tenant_count = 0 for test_interval in test_intervals: tenant_count = ( - session.query(Tenant.id) - .where(Tenant.created_at.between(current_time, current_time + test_interval)) - .count() + session.scalar( + select(func.count(Tenant.id)).where( + Tenant.created_at.between(current_time, current_time + test_interval) + ) + ) + or 0 ) if tenant_count <= 100: interval = test_interval @@ -147,8 +151,8 @@ class PluginMigration: batch_end = min(current_time + interval, ended_at) - rs = ( - session.query(Tenant.id) + rs = session.execute( + select(Tenant.id) .where(Tenant.created_at.between(current_time, batch_end)) .order_by(Tenant.created_at) ) @@ -235,7 +239,7 @@ class PluginMigration: Extract tool tables. """ with Session(db.engine) as session: - rs = session.query(BuiltinToolProvider).where(BuiltinToolProvider.tenant_id == tenant_id).all() + rs = session.scalars(select(BuiltinToolProvider).where(BuiltinToolProvider.tenant_id == tenant_id)).all() result = [] for row in rs: result.append(ToolProviderID(row.provider).plugin_id) @@ -249,7 +253,7 @@ class PluginMigration: """ with Session(db.engine) as session: - rs = session.query(Workflow).where(Workflow.tenant_id == tenant_id).all() + rs = session.scalars(select(Workflow).where(Workflow.tenant_id == tenant_id)).all() result = [] for row in rs: graph = row.graph_dict @@ -272,7 +276,7 @@ class PluginMigration: Extract app tables. """ with Session(db.engine) as session: - apps = session.query(App).where(App.tenant_id == tenant_id).all() + apps = session.scalars(select(App).where(App.tenant_id == tenant_id)).all() if not apps: return [] @@ -280,7 +284,7 @@ class PluginMigration: app.app_model_config_id for app in apps if app.is_agent or app.mode == AppMode.AGENT_CHAT ] - rs = session.query(AppModelConfig).where(AppModelConfig.id.in_(agent_app_model_config_ids)).all() + rs = session.scalars(select(AppModelConfig).where(AppModelConfig.id.in_(agent_app_model_config_ids))).all() result = [] for row in rs: agent_config = row.agent_mode_dict diff --git a/api/services/plugin/plugin_permission_service.py b/api/services/plugin/plugin_permission_service.py index 0d2a70acbd..3cca4268d0 100644 --- a/api/services/plugin/plugin_permission_service.py +++ b/api/services/plugin/plugin_permission_service.py @@ -1,14 +1,13 @@ from sqlalchemy import select -from sqlalchemy.orm import sessionmaker -from extensions.ext_database import db +from core.db.session_factory import session_factory from models.account import TenantPluginPermission class PluginPermissionService: @staticmethod def get_permission(tenant_id: str) -> TenantPluginPermission | None: - with sessionmaker(bind=db.engine).begin() as session: + with session_factory.create_session() as session: return session.scalar( select(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id).limit(1) ) @@ -19,7 +18,7 @@ class PluginPermissionService: install_permission: TenantPluginPermission.InstallPermission, debug_permission: TenantPluginPermission.DebugPermission, ): - with sessionmaker(bind=db.engine).begin() as session: + with session_factory.create_session() as session, session.begin(): permission = session.scalar( select(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id).limit(1) ) diff --git a/api/services/rag_pipeline/pipeline_generate_service.py b/api/services/rag_pipeline/pipeline_generate_service.py index 10e89b1dba..56bc785958 100644 --- a/api/services/rag_pipeline/pipeline_generate_service.py +++ b/api/services/rag_pipeline/pipeline_generate_service.py @@ -1,5 +1,5 @@ from collections.abc import Mapping -from typing import Any, Union +from typing import Any from configs import dify_config from core.app.apps.pipeline.pipeline_generator import PipelineGenerator @@ -17,7 +17,7 @@ class PipelineGenerateService: def generate( cls, pipeline: Pipeline, - user: Union[Account, EndUser], + user: Account | EndUser, args: Mapping[str, Any], invoke_from: InvokeFrom, streaming: bool = True, diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index f6d80f9a6e..5fc5b412b3 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -5,7 +5,7 @@ import threading import time from collections.abc import Callable, Generator, Mapping, Sequence from datetime import UTC, datetime -from typing import Any, Union, cast +from typing import Any, cast from uuid import uuid4 from flask_login import current_user @@ -1387,7 +1387,7 @@ class RagPipelineService: "uninstalled_recommended_plugins": uninstalled_plugin_list, } - def retry_error_document(self, dataset: Dataset, document: Document, user: Union[Account, EndUser]): + def retry_error_document(self, dataset: Dataset, document: Document, user: Account | EndUser): """ Retry error document """ diff --git a/api/services/rag_pipeline/rag_pipeline_dsl_service.py b/api/services/rag_pipeline/rag_pipeline_dsl_service.py index c24bf3d649..65bdf43af5 100644 --- a/api/services/rag_pipeline/rag_pipeline_dsl_service.py +++ b/api/services/rag_pipeline/rag_pipeline_dsl_service.py @@ -283,7 +283,9 @@ class RagPipelineDslService: ): raise ValueError("Chunk structure is not compatible with the published pipeline") if not dataset: - datasets = self._session.query(Dataset).filter_by(tenant_id=account.current_tenant_id).all() + datasets = self._session.scalars( + select(Dataset).where(Dataset.tenant_id == account.current_tenant_id) + ).all() names = [dataset.name for dataset in datasets] generate_name = generate_incremental_name(names, name) dataset = Dataset( @@ -303,8 +305,8 @@ class RagPipelineDslService: chunk_structure=knowledge_configuration.chunk_structure, ) if knowledge_configuration.indexing_technique == IndexTechniqueType.HIGH_QUALITY: - dataset_collection_binding = ( - self._session.query(DatasetCollectionBinding) + dataset_collection_binding = self._session.scalar( + select(DatasetCollectionBinding) .where( DatasetCollectionBinding.provider_name == knowledge_configuration.embedding_model_provider, @@ -312,7 +314,7 @@ class RagPipelineDslService: DatasetCollectionBinding.type == CollectionBindingType.DATASET, ) .order_by(DatasetCollectionBinding.created_at) - .first() + .limit(1) ) if not dataset_collection_binding: @@ -440,8 +442,8 @@ class RagPipelineDslService: dataset.runtime_mode = DatasetRuntimeMode.RAG_PIPELINE dataset.chunk_structure = knowledge_configuration.chunk_structure if knowledge_configuration.indexing_technique == IndexTechniqueType.HIGH_QUALITY: - dataset_collection_binding = ( - self._session.query(DatasetCollectionBinding) + dataset_collection_binding = self._session.scalar( + select(DatasetCollectionBinding) .where( DatasetCollectionBinding.provider_name == knowledge_configuration.embedding_model_provider, @@ -449,7 +451,7 @@ class RagPipelineDslService: DatasetCollectionBinding.type == CollectionBindingType.DATASET, ) .order_by(DatasetCollectionBinding.created_at) - .first() + .limit(1) ) if not dataset_collection_binding: @@ -591,14 +593,14 @@ class RagPipelineDslService: IMPORT_INFO_REDIS_EXPIRY, CheckDependenciesPendingData(pipeline_id=pipeline.id, dependencies=dependencies).model_dump_json(), ) - workflow = ( - self._session.query(Workflow) + workflow = self._session.scalar( + select(Workflow) .where( Workflow.tenant_id == pipeline.tenant_id, Workflow.app_id == pipeline.id, Workflow.version == "draft", ) - .first() + .limit(1) ) # create draft workflow if not found @@ -665,14 +667,12 @@ class RagPipelineDslService: :param pipeline: Pipeline instance """ - workflow = ( - self._session.query(Workflow) - .where( + workflow = self._session.scalar( + select(Workflow).where( Workflow.tenant_id == pipeline.tenant_id, Workflow.app_id == pipeline.id, Workflow.version == "draft", ) - .first() ) if not workflow: raise ValueError("Missing draft workflow configuration, please check.") @@ -904,15 +904,16 @@ class RagPipelineDslService: ): if rag_pipeline_dataset_create_entity.name: # check if dataset name already exists - if ( - self._session.query(Dataset) - .filter_by(name=rag_pipeline_dataset_create_entity.name, tenant_id=tenant_id) - .first() + if self._session.scalar( + select(Dataset).where( + Dataset.name == rag_pipeline_dataset_create_entity.name, + Dataset.tenant_id == tenant_id, + ) ): raise ValueError(f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists.") else: # generate a random name as Untitled 1 2 3 ... - datasets = self._session.query(Dataset).filter_by(tenant_id=tenant_id).all() + datasets = self._session.scalars(select(Dataset).where(Dataset.tenant_id == tenant_id)).all() names = [dataset.name for dataset in datasets] rag_pipeline_dataset_create_entity.name = generate_incremental_name( names, diff --git a/api/services/recommend_app/database/database_retrieval.py b/api/services/recommend_app/database/database_retrieval.py index 6fb90d356d..1df5fd13b6 100644 --- a/api/services/recommend_app/database/database_retrieval.py +++ b/api/services/recommend_app/database/database_retrieval.py @@ -1,3 +1,5 @@ +from typing import Any, TypedDict + from sqlalchemy import select from constants.languages import languages @@ -8,16 +10,43 @@ from services.recommend_app.recommend_app_base import RecommendAppRetrievalBase from services.recommend_app.recommend_app_type import RecommendAppType +class RecommendedAppItemDict(TypedDict): + id: str + app: App | None + app_id: str + description: Any + copyright: Any + privacy_policy: Any + custom_disclaimer: str + category: str + position: int + is_listed: bool + + +class RecommendedAppsResultDict(TypedDict): + recommended_apps: list[RecommendedAppItemDict] + categories: list[str] + + +class RecommendedAppDetailDict(TypedDict): + id: str + name: str + icon: Any + icon_background: str | None + mode: str + export_data: str + + class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase): """ Retrieval recommended app from database """ - def get_recommended_apps_and_categories(self, language: str): + def get_recommended_apps_and_categories(self, language: str) -> RecommendedAppsResultDict: result = self.fetch_recommended_apps_from_db(language) return result - def get_recommend_app_detail(self, app_id: str): + def get_recommend_app_detail(self, app_id: str) -> RecommendedAppDetailDict | None: result = self.fetch_recommended_app_detail_from_db(app_id) return result @@ -25,7 +54,7 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase): return RecommendAppType.DATABASE @classmethod - def fetch_recommended_apps_from_db(cls, language: str): + def fetch_recommended_apps_from_db(cls, language: str) -> RecommendedAppsResultDict: """ Fetch recommended apps from db. :param language: language @@ -41,7 +70,7 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase): ).all() categories = set() - recommended_apps_result = [] + recommended_apps_result: list[RecommendedAppItemDict] = [] for recommended_app in recommended_apps: app = recommended_app.app if not app or not app.is_public: @@ -51,7 +80,7 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase): if not site: continue - recommended_app_result = { + recommended_app_result: RecommendedAppItemDict = { "id": recommended_app.id, "app": recommended_app.app, "app_id": recommended_app.app_id, @@ -67,10 +96,10 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase): categories.add(recommended_app.category) - return {"recommended_apps": recommended_apps_result, "categories": sorted(categories)} + return RecommendedAppsResultDict(recommended_apps=recommended_apps_result, categories=sorted(categories)) @classmethod - def fetch_recommended_app_detail_from_db(cls, app_id: str) -> dict | None: + def fetch_recommended_app_detail_from_db(cls, app_id: str) -> RecommendedAppDetailDict | None: """ Fetch recommended app detail from db. :param app_id: App ID @@ -89,11 +118,11 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase): if not app_model or not app_model.is_public: return None - return { - "id": app_model.id, - "name": app_model.name, - "icon": app_model.icon, - "icon_background": app_model.icon_background, - "mode": app_model.mode, - "export_data": AppDslService.export_dsl(app_model=app_model), - } + return RecommendedAppDetailDict( + id=app_model.id, + name=app_model.name, + icon=app_model.icon, + icon_background=app_model.icon_background, + mode=app_model.mode, + export_data=AppDslService.export_dsl(app_model=app_model), + ) diff --git a/api/services/summary_index_service.py b/api/services/summary_index_service.py index 8760d60de0..c906e3bca3 100644 --- a/api/services/summary_index_service.py +++ b/api/services/summary_index_service.py @@ -8,6 +8,7 @@ from typing import TypedDict, cast from graphon.model_runtime.entities.llm_entities import LLMUsage from graphon.model_runtime.entities.model_entities import ModelType +from sqlalchemy import select from sqlalchemy.orm import Session from core.db.session_factory import session_factory @@ -109,8 +110,13 @@ class SummaryIndexService: """ with session_factory.create_session() as session: # Check if summary record already exists - existing_summary = ( - session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first() + existing_summary = session.scalar( + select(DocumentSegmentSummary) + .where( + DocumentSegmentSummary.chunk_id == segment.id, + DocumentSegmentSummary.dataset_id == dataset.id, + ) + .limit(1) ) if existing_summary: @@ -309,8 +315,10 @@ class SummaryIndexService: summary_record_id, segment.id, ) - summary_record_in_session = ( - session.query(DocumentSegmentSummary).filter_by(id=summary_record_id).first() + summary_record_in_session = session.scalar( + select(DocumentSegmentSummary) + .where(DocumentSegmentSummary.id == summary_record_id) + .limit(1) ) if not summary_record_in_session: @@ -323,10 +331,13 @@ class SummaryIndexService: dataset.id, segment.id, ) - summary_record_in_session = ( - session.query(DocumentSegmentSummary) - .filter_by(chunk_id=segment.id, dataset_id=dataset.id) - .first() + summary_record_in_session = session.scalar( + select(DocumentSegmentSummary) + .where( + DocumentSegmentSummary.chunk_id == segment.id, + DocumentSegmentSummary.dataset_id == dataset.id, + ) + .limit(1) ) if not summary_record_in_session: @@ -487,8 +498,10 @@ class SummaryIndexService: with session_factory.create_session() as error_session: # Try to find the record by id first # Note: Using assignment only (no type annotation) to avoid redeclaration error - summary_record_in_session = ( - error_session.query(DocumentSegmentSummary).filter_by(id=summary_record_id).first() + summary_record_in_session = error_session.scalar( + select(DocumentSegmentSummary) + .where(DocumentSegmentSummary.id == summary_record_id) + .limit(1) ) if not summary_record_in_session: # Try to find by chunk_id and dataset_id @@ -500,10 +513,13 @@ class SummaryIndexService: dataset.id, segment.id, ) - summary_record_in_session = ( - error_session.query(DocumentSegmentSummary) - .filter_by(chunk_id=segment.id, dataset_id=dataset.id) - .first() + summary_record_in_session = error_session.scalar( + select(DocumentSegmentSummary) + .where( + DocumentSegmentSummary.chunk_id == segment.id, + DocumentSegmentSummary.dataset_id == dataset.id, + ) + .limit(1) ) if summary_record_in_session: @@ -551,14 +567,12 @@ class SummaryIndexService: with session_factory.create_session() as session: # Query existing summary records - existing_summaries = ( - session.query(DocumentSegmentSummary) - .filter( + existing_summaries = session.scalars( + select(DocumentSegmentSummary).where( DocumentSegmentSummary.chunk_id.in_(segment_ids), DocumentSegmentSummary.dataset_id == dataset.id, ) - .all() - ) + ).all() existing_summary_map = {summary.chunk_id: summary for summary in existing_summaries} # Create or update records @@ -603,8 +617,13 @@ class SummaryIndexService: error: Error message """ with session_factory.create_session() as session: - summary_record = ( - session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first() + summary_record = session.scalar( + select(DocumentSegmentSummary) + .where( + DocumentSegmentSummary.chunk_id == segment.id, + DocumentSegmentSummary.dataset_id == dataset.id, + ) + .limit(1) ) if summary_record: @@ -639,8 +658,13 @@ class SummaryIndexService: with session_factory.create_session() as session: try: # Get or refresh summary record in this session - summary_record_in_session = ( - session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first() + summary_record_in_session = session.scalar( + select(DocumentSegmentSummary) + .where( + DocumentSegmentSummary.chunk_id == segment.id, + DocumentSegmentSummary.dataset_id == dataset.id, + ) + .limit(1) ) if not summary_record_in_session: @@ -710,8 +734,13 @@ class SummaryIndexService: except Exception as e: logger.exception("Failed to generate summary for segment %s", segment.id) # Update summary record with error status - summary_record_in_session = ( - session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first() + summary_record_in_session = session.scalar( + select(DocumentSegmentSummary) + .where( + DocumentSegmentSummary.chunk_id == segment.id, + DocumentSegmentSummary.dataset_id == dataset.id, + ) + .limit(1) ) if summary_record_in_session: summary_record_in_session.status = SummaryStatus.ERROR @@ -769,17 +798,17 @@ class SummaryIndexService: with session_factory.create_session() as session: # Query segments (only enabled segments) - query = session.query(DocumentSegment).filter_by( - dataset_id=dataset.id, - document_id=document.id, - status="completed", - enabled=True, # Only generate summaries for enabled segments + stmt = select(DocumentSegment).where( + DocumentSegment.dataset_id == dataset.id, + DocumentSegment.document_id == document.id, + DocumentSegment.status == "completed", + DocumentSegment.enabled.is_(True), # Only generate summaries for enabled segments ) if segment_ids: - query = query.filter(DocumentSegment.id.in_(segment_ids)) + stmt = stmt.where(DocumentSegment.id.in_(segment_ids)) - segments = query.all() + segments = list(session.scalars(stmt).all()) if not segments: logger.info("No segments found for document %s", document.id) @@ -848,15 +877,15 @@ class SummaryIndexService: from libs.datetime_utils import naive_utc_now with session_factory.create_session() as session: - query = session.query(DocumentSegmentSummary).filter_by( - dataset_id=dataset.id, - enabled=True, # Only disable enabled summaries + stmt = select(DocumentSegmentSummary).where( + DocumentSegmentSummary.dataset_id == dataset.id, + DocumentSegmentSummary.enabled.is_(True), # Only disable enabled summaries ) if segment_ids: - query = query.filter(DocumentSegmentSummary.chunk_id.in_(segment_ids)) + stmt = stmt.where(DocumentSegmentSummary.chunk_id.in_(segment_ids)) - summaries = query.all() + summaries = session.scalars(stmt).all() if not summaries: return @@ -911,15 +940,15 @@ class SummaryIndexService: return with session_factory.create_session() as session: - query = session.query(DocumentSegmentSummary).filter_by( - dataset_id=dataset.id, - enabled=False, # Only enable disabled summaries + stmt = select(DocumentSegmentSummary).where( + DocumentSegmentSummary.dataset_id == dataset.id, + DocumentSegmentSummary.enabled.is_(False), # Only enable disabled summaries ) if segment_ids: - query = query.filter(DocumentSegmentSummary.chunk_id.in_(segment_ids)) + stmt = stmt.where(DocumentSegmentSummary.chunk_id.in_(segment_ids)) - summaries = query.all() + summaries = session.scalars(stmt).all() if not summaries: return @@ -935,13 +964,13 @@ class SummaryIndexService: enabled_count = 0 for summary in summaries: # Get the original segment - segment = ( - session.query(DocumentSegment) - .filter_by( - id=summary.chunk_id, - dataset_id=dataset.id, + segment = session.scalar( + select(DocumentSegment) + .where( + DocumentSegment.id == summary.chunk_id, + DocumentSegment.dataset_id == dataset.id, ) - .first() + .limit(1) ) # Summary.enabled stays in sync with chunk.enabled, @@ -988,12 +1017,12 @@ class SummaryIndexService: segment_ids: List of segment IDs to delete summaries for. If None, delete all. """ with session_factory.create_session() as session: - query = session.query(DocumentSegmentSummary).filter_by(dataset_id=dataset.id) + stmt = select(DocumentSegmentSummary).where(DocumentSegmentSummary.dataset_id == dataset.id) if segment_ids: - query = query.filter(DocumentSegmentSummary.chunk_id.in_(segment_ids)) + stmt = stmt.where(DocumentSegmentSummary.chunk_id.in_(segment_ids)) - summaries = query.all() + summaries = session.scalars(stmt).all() if not summaries: return @@ -1046,10 +1075,13 @@ class SummaryIndexService: # Check if summary_content is empty (whitespace-only strings are considered empty) if not summary_content or not summary_content.strip(): # If summary is empty, only delete existing summary vector and record - summary_record = ( - session.query(DocumentSegmentSummary) - .filter_by(chunk_id=segment.id, dataset_id=dataset.id) - .first() + summary_record = session.scalar( + select(DocumentSegmentSummary) + .where( + DocumentSegmentSummary.chunk_id == segment.id, + DocumentSegmentSummary.dataset_id == dataset.id, + ) + .limit(1) ) if summary_record: @@ -1077,8 +1109,13 @@ class SummaryIndexService: return None # Find existing summary record - summary_record = ( - session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first() + summary_record = session.scalar( + select(DocumentSegmentSummary) + .where( + DocumentSegmentSummary.chunk_id == segment.id, + DocumentSegmentSummary.dataset_id == dataset.id, + ) + .limit(1) ) if summary_record: @@ -1162,8 +1199,13 @@ class SummaryIndexService: except Exception as e: logger.exception("Failed to update summary for segment %s", segment.id) # Update summary record with error status if it exists - summary_record = ( - session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first() + summary_record = session.scalar( + select(DocumentSegmentSummary) + .where( + DocumentSegmentSummary.chunk_id == segment.id, + DocumentSegmentSummary.dataset_id == dataset.id, + ) + .limit(1) ) if summary_record: summary_record.status = SummaryStatus.ERROR @@ -1185,14 +1227,14 @@ class SummaryIndexService: DocumentSegmentSummary instance if found, None otherwise """ with session_factory.create_session() as session: - return ( - session.query(DocumentSegmentSummary) + return session.scalar( + select(DocumentSegmentSummary) .where( DocumentSegmentSummary.chunk_id == segment_id, DocumentSegmentSummary.dataset_id == dataset_id, - DocumentSegmentSummary.enabled == True, # Only return enabled summaries + DocumentSegmentSummary.enabled.is_(True), # Only return enabled summaries ) - .first() + .limit(1) ) @staticmethod @@ -1211,15 +1253,13 @@ class SummaryIndexService: return {} with session_factory.create_session() as session: - summary_records = ( - session.query(DocumentSegmentSummary) - .where( + summary_records = session.scalars( + select(DocumentSegmentSummary).where( DocumentSegmentSummary.chunk_id.in_(segment_ids), DocumentSegmentSummary.dataset_id == dataset_id, - DocumentSegmentSummary.enabled == True, # Only return enabled summaries + DocumentSegmentSummary.enabled.is_(True), # Only return enabled summaries ) - .all() - ) + ).all() return {summary.chunk_id: summary for summary in summary_records} @@ -1239,16 +1279,16 @@ class SummaryIndexService: List of DocumentSegmentSummary instances (only enabled summaries) """ with session_factory.create_session() as session: - query = session.query(DocumentSegmentSummary).filter( + stmt = select(DocumentSegmentSummary).where( DocumentSegmentSummary.document_id == document_id, DocumentSegmentSummary.dataset_id == dataset_id, - DocumentSegmentSummary.enabled == True, # Only return enabled summaries + DocumentSegmentSummary.enabled.is_(True), # Only return enabled summaries ) if segment_ids: - query = query.filter(DocumentSegmentSummary.chunk_id.in_(segment_ids)) + stmt = stmt.where(DocumentSegmentSummary.chunk_id.in_(segment_ids)) - return query.all() + return list(session.scalars(stmt).all()) @staticmethod def get_document_summary_index_status(document_id: str, dataset_id: str, tenant_id: str) -> str | None: @@ -1265,16 +1305,15 @@ class SummaryIndexService: """ # Get all segments for this document (excluding qa_model and re_segment) with session_factory.create_session() as session: - segments = ( - session.query(DocumentSegment.id) - .where( - DocumentSegment.document_id == document_id, - DocumentSegment.status != "re_segment", - DocumentSegment.tenant_id == tenant_id, - ) - .all() + segment_ids = list( + session.scalars( + select(DocumentSegment.id).where( + DocumentSegment.document_id == document_id, + DocumentSegment.status != "re_segment", + DocumentSegment.tenant_id == tenant_id, + ) + ).all() ) - segment_ids = [seg.id for seg in segments] if not segment_ids: return None @@ -1312,15 +1351,13 @@ class SummaryIndexService: # Get all segments for these documents (excluding qa_model and re_segment) with session_factory.create_session() as session: - segments = ( - session.query(DocumentSegment.id, DocumentSegment.document_id) - .where( + segments = session.execute( + select(DocumentSegment.id, DocumentSegment.document_id).where( DocumentSegment.document_id.in_(document_ids), DocumentSegment.status != "re_segment", DocumentSegment.tenant_id == tenant_id, ) - .all() - ) + ).all() # Group segments by document_id document_segments_map: dict[str, list[str]] = {} diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index 3daaf9a263..202432007a 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -4,7 +4,7 @@ from collections.abc import Mapping from pathlib import Path from typing import Any -from sqlalchemy import exists, select +from sqlalchemy import delete, exists, func, select, update from sqlalchemy.orm import Session, sessionmaker from configs import dify_config @@ -47,11 +47,15 @@ class BuiltinToolManageService: """ tool_provider = ToolProviderID(provider) with sessionmaker(bind=db.engine).begin() as session: - session.query(ToolOAuthTenantClient).filter_by( - tenant_id=tenant_id, - provider=tool_provider.provider_name, - plugin_id=tool_provider.plugin_id, - ).delete() + session.execute( + delete(ToolOAuthTenantClient) + .where( + ToolOAuthTenantClient.tenant_id == tenant_id, + ToolOAuthTenantClient.provider == tool_provider.provider_name, + ToolOAuthTenantClient.plugin_id == tool_provider.plugin_id, + ) + .execution_options(synchronize_session=False) + ) return {"result": "success"} @staticmethod @@ -151,13 +155,13 @@ class BuiltinToolManageService: """ with sessionmaker(bind=db.engine).begin() as session: # get if the provider exists - db_provider = ( - session.query(BuiltinToolProvider) + db_provider = session.scalar( + select(BuiltinToolProvider) .where( BuiltinToolProvider.tenant_id == tenant_id, BuiltinToolProvider.id == credential_id, ) - .first() + .limit(1) ) if db_provider is None: raise ValueError(f"you have not added provider {provider}") @@ -228,7 +232,13 @@ class BuiltinToolManageService: raise ValueError(f"provider {provider} does not need credentials") provider_count = ( - session.query(BuiltinToolProvider).filter_by(tenant_id=tenant_id, provider=provider).count() + session.scalar( + select(func.count(BuiltinToolProvider.id)).where( + BuiltinToolProvider.tenant_id == tenant_id, + BuiltinToolProvider.provider == provider, + ) + ) + or 0 ) # check if the provider count is reached the limit @@ -304,16 +314,15 @@ class BuiltinToolManageService: def generate_builtin_tool_provider_name( session: Session, tenant_id: str, provider: str, credential_type: CredentialType ) -> str: - db_providers = ( - session.query(BuiltinToolProvider) - .filter_by( - tenant_id=tenant_id, - provider=provider, - credential_type=credential_type, + db_providers = session.scalars( + select(BuiltinToolProvider) + .where( + BuiltinToolProvider.tenant_id == tenant_id, + BuiltinToolProvider.provider == provider, + BuiltinToolProvider.credential_type == credential_type, ) .order_by(BuiltinToolProvider.created_at.desc()) - .all() - ) + ).all() return generate_incremental_name( [provider.name for provider in db_providers], f"{credential_type.get_name()}", @@ -375,13 +384,13 @@ class BuiltinToolManageService: delete tool provider """ with sessionmaker(bind=db.engine).begin() as session: - db_provider = ( - session.query(BuiltinToolProvider) + db_provider = session.scalar( + select(BuiltinToolProvider) .where( BuiltinToolProvider.tenant_id == tenant_id, BuiltinToolProvider.id == credential_id, ) - .first() + .limit(1) ) if db_provider is None: @@ -405,14 +414,26 @@ class BuiltinToolManageService: """ with sessionmaker(bind=db.engine).begin() as session: # get provider - target_provider = session.query(BuiltinToolProvider).filter_by(id=id, tenant_id=tenant_id).first() + target_provider = session.scalar( + select(BuiltinToolProvider) + .where(BuiltinToolProvider.id == id, BuiltinToolProvider.tenant_id == tenant_id) + .limit(1) + ) if target_provider is None: raise ValueError("provider not found") # clear default provider - session.query(BuiltinToolProvider).filter_by( - tenant_id=tenant_id, user_id=user_id, provider=provider, is_default=True - ).update({"is_default": False}) + session.execute( + update(BuiltinToolProvider) + .where( + BuiltinToolProvider.tenant_id == tenant_id, + BuiltinToolProvider.user_id == user_id, + BuiltinToolProvider.provider == provider, + BuiltinToolProvider.is_default.is_(True), + ) + .values(is_default=False) + .execution_options(synchronize_session=False) + ) # set new default provider target_provider.is_default = True @@ -426,10 +447,13 @@ class BuiltinToolManageService: """ tool_provider = ToolProviderID(provider_name) with Session(db.engine, autoflush=False) as session: - system_client: ToolOAuthSystemClient | None = ( - session.query(ToolOAuthSystemClient) - .filter_by(plugin_id=tool_provider.plugin_id, provider=tool_provider.provider_name) - .first() + system_client = session.scalar( + select(ToolOAuthSystemClient) + .where( + ToolOAuthSystemClient.plugin_id == tool_provider.plugin_id, + ToolOAuthSystemClient.provider == tool_provider.provider_name, + ) + .limit(1) ) return system_client is not None @@ -440,15 +464,15 @@ class BuiltinToolManageService: """ tool_provider = ToolProviderID(provider) with Session(db.engine, autoflush=False) as session: - user_client: ToolOAuthTenantClient | None = ( - session.query(ToolOAuthTenantClient) - .filter_by( - tenant_id=tenant_id, - provider=tool_provider.provider_name, - plugin_id=tool_provider.plugin_id, - enabled=True, + user_client = session.scalar( + select(ToolOAuthTenantClient) + .where( + ToolOAuthTenantClient.tenant_id == tenant_id, + ToolOAuthTenantClient.provider == tool_provider.provider_name, + ToolOAuthTenantClient.plugin_id == tool_provider.plugin_id, + ToolOAuthTenantClient.enabled.is_(True), ) - .first() + .limit(1) ) return user_client is not None and user_client.enabled @@ -465,15 +489,15 @@ class BuiltinToolManageService: cache=NoOpProviderCredentialCache(), ) with Session(db.engine, autoflush=False) as session: - user_client: ToolOAuthTenantClient | None = ( - session.query(ToolOAuthTenantClient) - .filter_by( - tenant_id=tenant_id, - provider=tool_provider.provider_name, - plugin_id=tool_provider.plugin_id, - enabled=True, + user_client = session.scalar( + select(ToolOAuthTenantClient) + .where( + ToolOAuthTenantClient.tenant_id == tenant_id, + ToolOAuthTenantClient.provider == tool_provider.provider_name, + ToolOAuthTenantClient.plugin_id == tool_provider.plugin_id, + ToolOAuthTenantClient.enabled.is_(True), ) - .first() + .limit(1) ) oauth_params: Mapping[str, Any] | None = None if user_client: @@ -487,10 +511,13 @@ class BuiltinToolManageService: if not is_verified: return oauth_params - system_client: ToolOAuthSystemClient | None = ( - session.query(ToolOAuthSystemClient) - .filter_by(plugin_id=tool_provider.plugin_id, provider=tool_provider.provider_name) - .first() + system_client = session.scalar( + select(ToolOAuthSystemClient) + .where( + ToolOAuthSystemClient.plugin_id == tool_provider.plugin_id, + ToolOAuthSystemClient.provider == tool_provider.provider_name, + ) + .limit(1) ) if system_client: try: @@ -582,8 +609,8 @@ class BuiltinToolManageService: provider_name = provider_id_entity.provider_name if provider_id_entity.organization != "langgenius": - provider = ( - session.query(BuiltinToolProvider) + provider = session.scalar( + select(BuiltinToolProvider) .where( BuiltinToolProvider.tenant_id == tenant_id, BuiltinToolProvider.provider == full_provider_name, @@ -592,11 +619,11 @@ class BuiltinToolManageService: BuiltinToolProvider.is_default.desc(), # default=True first BuiltinToolProvider.created_at.asc(), # oldest first ) - .first() + .limit(1) ) else: - provider = ( - session.query(BuiltinToolProvider) + provider = session.scalar( + select(BuiltinToolProvider) .where( BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_name) @@ -606,7 +633,7 @@ class BuiltinToolManageService: BuiltinToolProvider.is_default.desc(), # default=True first BuiltinToolProvider.created_at.asc(), # oldest first ) - .first() + .limit(1) ) if provider is None: @@ -616,14 +643,14 @@ class BuiltinToolManageService: return provider except Exception: # it's an old provider without organization - return ( - session.query(BuiltinToolProvider) + return session.scalar( + select(BuiltinToolProvider) .where(BuiltinToolProvider.tenant_id == tenant_id, BuiltinToolProvider.provider == provider_name) .order_by( BuiltinToolProvider.is_default.desc(), # default=True first BuiltinToolProvider.created_at.asc(), # oldest first ) - .first() + .limit(1) ) @staticmethod @@ -648,14 +675,14 @@ class BuiltinToolManageService: raise ValueError(f"Provider {provider} is not a builtin or plugin provider") with sessionmaker(bind=db.engine).begin() as session: - custom_client_params = ( - session.query(ToolOAuthTenantClient) - .filter_by( - tenant_id=tenant_id, - plugin_id=tool_provider.plugin_id, - provider=tool_provider.provider_name, + custom_client_params = session.scalar( + select(ToolOAuthTenantClient) + .where( + ToolOAuthTenantClient.tenant_id == tenant_id, + ToolOAuthTenantClient.plugin_id == tool_provider.plugin_id, + ToolOAuthTenantClient.provider == tool_provider.provider_name, ) - .first() + .limit(1) ) # if the record does not exist, create a basic record @@ -692,14 +719,14 @@ class BuiltinToolManageService: """ with Session(db.engine) as session: tool_provider = ToolProviderID(provider) - custom_oauth_client_params: ToolOAuthTenantClient | None = ( - session.query(ToolOAuthTenantClient) - .filter_by( - tenant_id=tenant_id, - plugin_id=tool_provider.plugin_id, - provider=tool_provider.provider_name, + custom_oauth_client_params = session.scalar( + select(ToolOAuthTenantClient) + .where( + ToolOAuthTenantClient.tenant_id == tenant_id, + ToolOAuthTenantClient.plugin_id == tool_provider.plugin_id, + ToolOAuthTenantClient.provider == tool_provider.provider_name, ) - .first() + .limit(1) ) if custom_oauth_client_params is None: return {} diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py index 4fd2ea1628..72954a3102 100644 --- a/api/services/tools/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -1,6 +1,6 @@ import logging from collections.abc import Mapping -from typing import Any, Union +from typing import Any from pydantic import TypeAdapter, ValidationError from yarl import URL @@ -69,7 +69,7 @@ class ToolTransformService: return "" @staticmethod - def repack_provider(tenant_id: str, provider: Union[dict, ToolProviderApiEntity, PluginDatasourceProviderEntity]): + def repack_provider(tenant_id: str, provider: dict | ToolProviderApiEntity | PluginDatasourceProviderEntity): """ repack provider diff --git a/api/services/tools/workflow_tools_manage_service.py b/api/services/tools/workflow_tools_manage_service.py index 8f5144c866..f7c35fa64e 100644 --- a/api/services/tools/workflow_tools_manage_service.py +++ b/api/services/tools/workflow_tools_manage_service.py @@ -4,7 +4,7 @@ from datetime import datetime from graphon.model_runtime.utils.encoders import jsonable_encoder from sqlalchemy import delete, or_, select -from sqlalchemy.orm import Session +from sqlalchemy.orm import sessionmaker from core.tools.__base.tool_provider import ToolProviderController from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity @@ -42,32 +42,43 @@ class WorkflowToolManageService: labels: list[str] | None = None, ): # check if the name is unique - existing_workflow_tool_provider = db.session.scalar( - select(WorkflowToolProvider) - .where( - WorkflowToolProvider.tenant_id == tenant_id, - # name or app_id - or_(WorkflowToolProvider.name == name, WorkflowToolProvider.app_id == workflow_app_id), + existing_workflow_tool_provider: WorkflowToolProvider | None = None + with sessionmaker(db.engine, expire_on_commit=False).begin() as _session: + # query if the name or app_id exists + existing_workflow_tool_provider = _session.scalar( + select(WorkflowToolProvider) + .where( + WorkflowToolProvider.tenant_id == tenant_id, + # name or app_id + or_(WorkflowToolProvider.name == name, WorkflowToolProvider.app_id == workflow_app_id), + ) + .limit(1) ) - .limit(1) - ) + # if the name or app_id exists raise error if existing_workflow_tool_provider is not None: raise ValueError(f"Tool with name {name} or app_id {workflow_app_id} already exists") - app: App | None = db.session.scalar( - select(App).where(App.id == workflow_app_id, App.tenant_id == tenant_id).limit(1) - ) + # query the app + app: App | None = None + with sessionmaker(db.engine, expire_on_commit=False).begin() as _session: + app = _session.scalar(select(App).where(App.id == workflow_app_id, App.tenant_id == tenant_id).limit(1)) + # if not found raise error if app is None: raise ValueError(f"App {workflow_app_id} not found") + # query the workflow workflow: Workflow | None = app.workflow + + # if not found raise error if workflow is None: raise ValueError(f"Workflow not found for app {workflow_app_id}") + # check if workflow configuration is synced WorkflowToolConfigurationUtils.ensure_no_human_input_nodes(workflow.graph_dict) + # create workflow tool provider workflow_tool_provider = WorkflowToolProvider( tenant_id=tenant_id, user_id=user_id, @@ -84,15 +95,18 @@ class WorkflowToolManageService: try: WorkflowToolProviderController.from_db(workflow_tool_provider) except Exception as e: + logger.warning(e, exc_info=True) raise ValueError(str(e)) - with Session(db.engine, expire_on_commit=False) as session, session.begin(): - session.add(workflow_tool_provider) + with sessionmaker(db.engine, expire_on_commit=False).begin() as _session: + _session.add(workflow_tool_provider) + # keep the session open to make orm instances in the same session if labels is not None: ToolLabelManager.update_tool_labels( ToolTransformService.workflow_provider_to_controller(workflow_tool_provider), labels ) + return {"result": "success"} @classmethod @@ -111,6 +125,7 @@ class WorkflowToolManageService: ): """ Update a workflow tool. + :param user_id: the user id :param tenant_id: the tenant id :param workflow_tool_id: workflow tool id @@ -186,28 +201,32 @@ class WorkflowToolManageService: def list_tenant_workflow_tools(cls, user_id: str, tenant_id: str) -> list[ToolProviderApiEntity]: """ List workflow tools. + :param user_id: the user id :param tenant_id: the tenant id :return: the list of tools """ - db_tools = db.session.scalars( - select(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id) - ).all() + + providers: list[WorkflowToolProvider] = [] + with sessionmaker(db.engine, expire_on_commit=False).begin() as _session: + providers = list( + _session.scalars(select(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id)).all() + ) # Create a mapping from provider_id to app_id - provider_id_to_app_id = {provider.id: provider.app_id for provider in db_tools} + provider_id_to_app_id = {provider.id: provider.app_id for provider in providers} tools: list[WorkflowToolProviderController] = [] - for provider in db_tools: + for provider in providers: try: tools.append(ToolTransformService.workflow_provider_to_controller(provider)) except Exception: # skip deleted tools logger.exception("Failed to load workflow tool provider %s", provider.id) - labels = ToolLabelManager.get_tools_labels([t for t in tools if isinstance(t, ToolProviderController)]) + labels = ToolLabelManager.get_tools_labels([tool for tool in tools if isinstance(tool, ToolProviderController)]) - result = [] + result: list[ToolProviderApiEntity] = [] for tool in tools: workflow_app_id = provider_id_to_app_id.get(tool.provider_id) @@ -232,17 +251,18 @@ class WorkflowToolManageService: def delete_workflow_tool(cls, user_id: str, tenant_id: str, workflow_tool_id: str): """ Delete a workflow tool. + :param user_id: the user id :param tenant_id: the tenant id :param workflow_tool_id: the workflow tool id """ - db.session.execute( - delete(WorkflowToolProvider).where( - WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id - ) - ) - db.session.commit() + with sessionmaker(db.engine).begin() as _session: + _ = _session.execute( + delete(WorkflowToolProvider).where( + WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id + ) + ) return {"result": "success"} @@ -250,47 +270,59 @@ class WorkflowToolManageService: def get_workflow_tool_by_tool_id(cls, user_id: str, tenant_id: str, workflow_tool_id: str): """ Get a workflow tool. + :param user_id: the user id :param tenant_id: the tenant id :param workflow_tool_id: the workflow tool id :return: the tool """ - db_tool: WorkflowToolProvider | None = db.session.scalar( - select(WorkflowToolProvider) - .where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id) - .limit(1) - ) - return cls._get_workflow_tool(tenant_id, db_tool) + + tool_provider: WorkflowToolProvider | None = None + with sessionmaker(db.engine, expire_on_commit=False).begin() as _session: + tool_provider = _session.scalar( + select(WorkflowToolProvider) + .where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id) + .limit(1) + ) + + return cls._get_workflow_tool(tenant_id, tool_provider) @classmethod def get_workflow_tool_by_app_id(cls, user_id: str, tenant_id: str, workflow_app_id: str): """ Get a workflow tool. + :param user_id: the user id :param tenant_id: the tenant id :param workflow_app_id: the workflow app id :return: the tool """ - db_tool: WorkflowToolProvider | None = db.session.scalar( - select(WorkflowToolProvider) - .where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.app_id == workflow_app_id) - .limit(1) - ) - return cls._get_workflow_tool(tenant_id, db_tool) + + with sessionmaker(db.engine, expire_on_commit=False).begin() as _session: + tool_provider: WorkflowToolProvider | None = _session.scalar( + select(WorkflowToolProvider) + .where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.app_id == workflow_app_id) + .limit(1) + ) + + return cls._get_workflow_tool(tenant_id, tool_provider) @classmethod def _get_workflow_tool(cls, tenant_id: str, db_tool: WorkflowToolProvider | None): """ Get a workflow tool. + :db_tool: the database tool :return: the tool """ if db_tool is None: raise ValueError("Tool not found") - workflow_app: App | None = db.session.scalar( - select(App).where(App.id == db_tool.app_id, App.tenant_id == db_tool.tenant_id).limit(1) - ) + workflow_app: App | None = None + with sessionmaker(db.engine, expire_on_commit=False).begin() as _session: + workflow_app = _session.scalar( + select(App).where(App.id == db_tool.app_id, App.tenant_id == db_tool.tenant_id).limit(1) + ) if workflow_app is None: raise ValueError(f"App {db_tool.app_id} not found") @@ -330,28 +362,32 @@ class WorkflowToolManageService: def list_single_workflow_tools(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> list[ToolApiEntity]: """ List workflow tool provider tools. + :param user_id: the user id :param tenant_id: the tenant id :param workflow_tool_id: the workflow tool id :return: the list of tools """ - db_tool: WorkflowToolProvider | None = db.session.scalar( - select(WorkflowToolProvider) - .where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id) - .limit(1) - ) - if db_tool is None: + provider: WorkflowToolProvider | None = None + with sessionmaker(db.engine, expire_on_commit=False).begin() as _session: + provider = _session.scalar( + select(WorkflowToolProvider) + .where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id) + .limit(1) + ) + + if provider is None: raise ValueError(f"Tool {workflow_tool_id} not found") - tool = ToolTransformService.workflow_provider_to_controller(db_tool) + tool = ToolTransformService.workflow_provider_to_controller(provider) workflow_tools: list[WorkflowTool] = tool.get_tools(tenant_id) if len(workflow_tools) == 0: raise ValueError(f"Tool {workflow_tool_id} not found") return [ ToolTransformService.convert_tool_entity_to_api_entity( - tool=tool.get_tools(db_tool.tenant_id)[0], + tool=tool.get_tools(provider.tenant_id)[0], labels=ToolLabelManager.get_tool_labels(tool), tenant_id=tenant_id, ) diff --git a/api/services/trigger/trigger_provider_service.py b/api/services/trigger/trigger_provider_service.py index ae74f7a8cd..6e14d996ea 100644 --- a/api/services/trigger/trigger_provider_service.py +++ b/api/services/trigger/trigger_provider_service.py @@ -3,9 +3,9 @@ import logging import time as _time import uuid from collections.abc import Mapping -from typing import Any +from typing import Any, TypedDict -from sqlalchemy import desc, func +from sqlalchemy import delete, desc, func, select from sqlalchemy.orm import Session, sessionmaker from configs import dify_config @@ -42,6 +42,10 @@ from services.plugin.plugin_service import PluginService logger = logging.getLogger(__name__) +class VerifyCredentialsResult(TypedDict): + verified: bool + + class TriggerProviderService: """Service for managing trigger providers and credentials""" @@ -69,27 +73,28 @@ class TriggerProviderService: workflows_in_use_map: dict[str, int] = {} with Session(db.engine, expire_on_commit=False) as session: # Get all subscriptions - subscriptions_db = ( - session.query(TriggerSubscription) - .filter_by(tenant_id=tenant_id, provider_id=str(provider_id)) + subscriptions_db = session.scalars( + select(TriggerSubscription) + .where( + TriggerSubscription.tenant_id == tenant_id, + TriggerSubscription.provider_id == str(provider_id), + ) .order_by(desc(TriggerSubscription.created_at)) - .all() - ) + ).all() subscriptions = [subscription.to_api_entity() for subscription in subscriptions_db] if not subscriptions: return [] - usage_counts = ( - session.query( + usage_counts = session.execute( + select( WorkflowPluginTrigger.subscription_id, func.count(func.distinct(WorkflowPluginTrigger.app_id)).label("app_count"), ) - .filter( + .where( WorkflowPluginTrigger.tenant_id == tenant_id, WorkflowPluginTrigger.subscription_id.in_([s.id for s in subscriptions]), ) .group_by(WorkflowPluginTrigger.subscription_id) - .all() - ) + ).all() workflows_in_use_map = {str(row.subscription_id): int(row.app_count) for row in usage_counts} provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id) @@ -152,9 +157,13 @@ class TriggerProviderService: with redis_client.lock(lock_key, timeout=20): # Check provider count limit provider_count = ( - session.query(TriggerSubscription) - .filter_by(tenant_id=tenant_id, provider_id=str(provider_id)) - .count() + session.scalar( + select(func.count(TriggerSubscription.id)).where( + TriggerSubscription.tenant_id == tenant_id, + TriggerSubscription.provider_id == str(provider_id), + ) + ) + or 0 ) if provider_count >= cls.__MAX_TRIGGER_PROVIDER_COUNT__: @@ -164,10 +173,14 @@ class TriggerProviderService: ) # Check if name already exists - existing = ( - session.query(TriggerSubscription) - .filter_by(tenant_id=tenant_id, provider_id=str(provider_id), name=name) - .first() + existing = session.scalar( + select(TriggerSubscription) + .where( + TriggerSubscription.tenant_id == tenant_id, + TriggerSubscription.provider_id == str(provider_id), + TriggerSubscription.name == name, + ) + .limit(1) ) if existing: raise ValueError(f"Credential name '{name}' already exists for this provider") @@ -244,8 +257,13 @@ class TriggerProviderService: # Use distributed lock to prevent race conditions on the same subscription lock_key = f"trigger_subscription_update_lock:{tenant_id}_{subscription_id}" with redis_client.lock(lock_key, timeout=20): - subscription: TriggerSubscription | None = ( - session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first() + subscription = session.scalar( + select(TriggerSubscription) + .where( + TriggerSubscription.tenant_id == tenant_id, + TriggerSubscription.id == subscription_id, + ) + .limit(1) ) if not subscription: raise ValueError(f"Trigger subscription {subscription_id} not found") @@ -255,10 +273,14 @@ class TriggerProviderService: # Check for name uniqueness if name is being updated if name is not None and name != subscription.name: - existing = ( - session.query(TriggerSubscription) - .filter_by(tenant_id=tenant_id, provider_id=str(provider_id), name=name) - .first() + existing = session.scalar( + select(TriggerSubscription) + .where( + TriggerSubscription.tenant_id == tenant_id, + TriggerSubscription.provider_id == str(provider_id), + TriggerSubscription.name == name, + ) + .limit(1) ) if existing: raise ValueError(f"Subscription name '{name}' already exists for this provider") @@ -316,11 +338,18 @@ class TriggerProviderService: with Session(db.engine, expire_on_commit=False) as session: subscription: TriggerSubscription | None = None if subscription_id: - subscription = ( - session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first() + subscription = session.scalar( + select(TriggerSubscription) + .where( + TriggerSubscription.tenant_id == tenant_id, + TriggerSubscription.id == subscription_id, + ) + .limit(1) ) else: - subscription = session.query(TriggerSubscription).filter_by(tenant_id=tenant_id).first() + subscription = session.scalar( + select(TriggerSubscription).where(TriggerSubscription.tenant_id == tenant_id).limit(1) + ) if subscription: provider_controller = TriggerManager.get_trigger_provider( tenant_id, TriggerProviderID(subscription.provider_id) @@ -349,8 +378,13 @@ class TriggerProviderService: :param subscription_id: Subscription instance ID :return: Success response """ - subscription: TriggerSubscription | None = ( - session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first() + subscription = session.scalar( + select(TriggerSubscription) + .where( + TriggerSubscription.tenant_id == tenant_id, + TriggerSubscription.id == subscription_id, + ) + .limit(1) ) if not subscription: raise ValueError(f"Trigger provider subscription {subscription_id} not found") @@ -402,7 +436,14 @@ class TriggerProviderService: :return: New token info """ with sessionmaker(bind=db.engine).begin() as session: - subscription = session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first() + subscription = session.scalar( + select(TriggerSubscription) + .where( + TriggerSubscription.tenant_id == tenant_id, + TriggerSubscription.id == subscription_id, + ) + .limit(1) + ) if not subscription: raise ValueError(f"Trigger provider subscription {subscription_id} not found") @@ -475,8 +516,13 @@ class TriggerProviderService: now_ts: int = int(now if now is not None else _time.time()) with sessionmaker(bind=db.engine).begin() as session: - subscription: TriggerSubscription | None = ( - session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first() + subscription = session.scalar( + select(TriggerSubscription) + .where( + TriggerSubscription.tenant_id == tenant_id, + TriggerSubscription.id == subscription_id, + ) + .limit(1) ) if subscription is None: raise ValueError(f"Trigger provider subscription {subscription_id} not found") @@ -552,15 +598,15 @@ class TriggerProviderService: tenant_id=tenant_id, provider_id=provider_id ) with Session(db.engine, expire_on_commit=False) as session: - tenant_client: TriggerOAuthTenantClient | None = ( - session.query(TriggerOAuthTenantClient) - .filter_by( - tenant_id=tenant_id, - provider=provider_id.provider_name, - plugin_id=provider_id.plugin_id, - enabled=True, + tenant_client = session.scalar( + select(TriggerOAuthTenantClient) + .where( + TriggerOAuthTenantClient.tenant_id == tenant_id, + TriggerOAuthTenantClient.provider == provider_id.provider_name, + TriggerOAuthTenantClient.plugin_id == provider_id.plugin_id, + TriggerOAuthTenantClient.enabled.is_(True), ) - .first() + .limit(1) ) oauth_params: Mapping[str, Any] | None = None @@ -578,10 +624,13 @@ class TriggerProviderService: return None # Check for system-level OAuth client - system_client: TriggerOAuthSystemClient | None = ( - session.query(TriggerOAuthSystemClient) - .filter_by(plugin_id=provider_id.plugin_id, provider=provider_id.provider_name) - .first() + system_client = session.scalar( + select(TriggerOAuthSystemClient) + .where( + TriggerOAuthSystemClient.plugin_id == provider_id.plugin_id, + TriggerOAuthSystemClient.provider == provider_id.provider_name, + ) + .limit(1) ) if system_client: @@ -602,10 +651,13 @@ class TriggerProviderService: if not is_verified: return False with Session(db.engine, expire_on_commit=False) as session: - system_client: TriggerOAuthSystemClient | None = ( - session.query(TriggerOAuthSystemClient) - .filter_by(plugin_id=provider_id.plugin_id, provider=provider_id.provider_name) - .first() + system_client = session.scalar( + select(TriggerOAuthSystemClient) + .where( + TriggerOAuthSystemClient.plugin_id == provider_id.plugin_id, + TriggerOAuthSystemClient.provider == provider_id.provider_name, + ) + .limit(1) ) return system_client is not None @@ -636,14 +688,14 @@ class TriggerProviderService: with sessionmaker(bind=db.engine).begin() as session: # Find existing custom client params - custom_client = ( - session.query(TriggerOAuthTenantClient) - .filter_by( - tenant_id=tenant_id, - plugin_id=provider_id.plugin_id, - provider=provider_id.provider_name, + custom_client = session.scalar( + select(TriggerOAuthTenantClient) + .where( + TriggerOAuthTenantClient.tenant_id == tenant_id, + TriggerOAuthTenantClient.plugin_id == provider_id.plugin_id, + TriggerOAuthTenantClient.provider == provider_id.provider_name, ) - .first() + .limit(1) ) # Create new record if doesn't exist @@ -690,14 +742,14 @@ class TriggerProviderService: :return: Masked OAuth client parameters """ with Session(db.engine) as session: - custom_client = ( - session.query(TriggerOAuthTenantClient) - .filter_by( - tenant_id=tenant_id, - plugin_id=provider_id.plugin_id, - provider=provider_id.provider_name, + custom_client = session.scalar( + select(TriggerOAuthTenantClient) + .where( + TriggerOAuthTenantClient.tenant_id == tenant_id, + TriggerOAuthTenantClient.plugin_id == provider_id.plugin_id, + TriggerOAuthTenantClient.provider == provider_id.provider_name, ) - .first() + .limit(1) ) if custom_client is None: @@ -727,11 +779,15 @@ class TriggerProviderService: :return: Success response """ with sessionmaker(bind=db.engine).begin() as session: - session.query(TriggerOAuthTenantClient).filter_by( - tenant_id=tenant_id, - provider=provider_id.provider_name, - plugin_id=provider_id.plugin_id, - ).delete() + session.execute( + delete(TriggerOAuthTenantClient) + .where( + TriggerOAuthTenantClient.tenant_id == tenant_id, + TriggerOAuthTenantClient.provider == provider_id.provider_name, + TriggerOAuthTenantClient.plugin_id == provider_id.plugin_id, + ) + .execution_options(synchronize_session=False) + ) return {"result": "success"} @@ -745,15 +801,15 @@ class TriggerProviderService: :return: True if enabled, False otherwise """ with Session(db.engine, expire_on_commit=False) as session: - custom_client = ( - session.query(TriggerOAuthTenantClient) - .filter_by( - tenant_id=tenant_id, - plugin_id=provider_id.plugin_id, - provider=provider_id.provider_name, - enabled=True, + custom_client = session.scalar( + select(TriggerOAuthTenantClient) + .where( + TriggerOAuthTenantClient.tenant_id == tenant_id, + TriggerOAuthTenantClient.plugin_id == provider_id.plugin_id, + TriggerOAuthTenantClient.provider == provider_id.provider_name, + TriggerOAuthTenantClient.enabled.is_(True), ) - .first() + .limit(1) ) return custom_client is not None @@ -763,7 +819,9 @@ class TriggerProviderService: Get a trigger subscription by the endpoint ID. """ with Session(db.engine, expire_on_commit=False) as session: - subscription = session.query(TriggerSubscription).filter_by(endpoint_id=endpoint_id).first() + subscription = session.scalar( + select(TriggerSubscription).where(TriggerSubscription.endpoint_id == endpoint_id).limit(1) + ) if not subscription: return None provider_controller: PluginTriggerProviderController = TriggerManager.get_trigger_provider( @@ -792,7 +850,7 @@ class TriggerProviderService: provider_id: TriggerProviderID, subscription_id: str, credentials: dict[str, Any], - ) -> dict[str, Any]: + ) -> VerifyCredentialsResult: """ Verify credentials for an existing subscription without updating it. diff --git a/api/services/trigger/webhook_service.py b/api/services/trigger/webhook_service.py index 23858de9f6..c782bffad4 100644 --- a/api/services/trigger/webhook_service.py +++ b/api/services/trigger/webhook_service.py @@ -105,32 +105,32 @@ class WebhookService: """ with Session(db.engine) as session: # Get webhook trigger - webhook_trigger = ( - session.query(WorkflowWebhookTrigger).where(WorkflowWebhookTrigger.webhook_id == webhook_id).first() + webhook_trigger = session.scalar( + select(WorkflowWebhookTrigger).where(WorkflowWebhookTrigger.webhook_id == webhook_id).limit(1) ) if not webhook_trigger: raise ValueError(f"Webhook not found: {webhook_id}") if is_debug: - workflow = ( - session.query(Workflow) - .filter( + workflow = session.scalar( + select(Workflow) + .where( Workflow.app_id == webhook_trigger.app_id, Workflow.version == Workflow.VERSION_DRAFT, ) .order_by(Workflow.created_at.desc()) - .first() + .limit(1) ) else: # Check if the corresponding AppTrigger exists - app_trigger = ( - session.query(AppTrigger) - .filter( + app_trigger = session.scalar( + select(AppTrigger) + .where( AppTrigger.app_id == webhook_trigger.app_id, AppTrigger.node_id == webhook_trigger.node_id, AppTrigger.trigger_type == AppTriggerType.TRIGGER_WEBHOOK, ) - .first() + .limit(1) ) if not app_trigger: @@ -147,14 +147,14 @@ class WebhookService: raise ValueError(f"Webhook trigger is disabled for webhook {webhook_id}") # Get workflow - workflow = ( - session.query(Workflow) - .filter( + workflow = session.scalar( + select(Workflow) + .where( Workflow.app_id == webhook_trigger.app_id, Workflow.version != Workflow.VERSION_DRAFT, ) .order_by(Workflow.created_at.desc()) - .first() + .limit(1) ) if not workflow: raise ValueError(f"Workflow not found for app {webhook_trigger.app_id}") diff --git a/api/services/workflow_draft_variable_service.py b/api/services/workflow_draft_variable_service.py index 1c1b94ae9d..2cc6e21574 100644 --- a/api/services/workflow_draft_variable_service.py +++ b/api/services/workflow_draft_variable_service.py @@ -3,8 +3,9 @@ import json import logging from collections.abc import Mapping, Sequence from concurrent.futures import ThreadPoolExecutor +from datetime import datetime from enum import StrEnum -from typing import Any, ClassVar +from typing import Any, ClassVar, NotRequired, TypedDict from graphon.enums import NodeType from graphon.file import File @@ -19,7 +20,7 @@ from graphon.variables.segments import ( ) from graphon.variables.types import SegmentType from graphon.variables.utils import dumps_with_segments -from sqlalchemy import Engine, orm, select +from sqlalchemy import Engine, delete, orm, select from sqlalchemy.dialects.mysql import insert as mysql_insert from sqlalchemy.dialects.postgresql import insert as pg_insert from sqlalchemy.orm import Session, sessionmaker @@ -222,11 +223,10 @@ class WorkflowDraftVariableService: ) def get_variable(self, variable_id: str) -> WorkflowDraftVariable | None: - return ( - self._session.query(WorkflowDraftVariable) + return self._session.scalar( + select(WorkflowDraftVariable) .options(orm.selectinload(WorkflowDraftVariable.variable_file)) .where(WorkflowDraftVariable.id == variable_id) - .first() ) def get_draft_variables_by_selectors( @@ -254,20 +254,21 @@ class WorkflowDraftVariableService: # Alternatively, a `SELECT` statement could be constructed for each selector and # combined using `UNION` to fetch all rows. # Benchmarking indicates that both approaches yield comparable performance. - query = ( - self._session.query(WorkflowDraftVariable) - .options( - orm.selectinload(WorkflowDraftVariable.variable_file).selectinload( - WorkflowDraftVariableFile.upload_file + return list( + self._session.scalars( + select(WorkflowDraftVariable) + .options( + orm.selectinload(WorkflowDraftVariable.variable_file).selectinload( + WorkflowDraftVariableFile.upload_file + ) + ) + .where( + WorkflowDraftVariable.app_id == app_id, + WorkflowDraftVariable.user_id == user_id, + or_(*ors), ) ) - .where( - WorkflowDraftVariable.app_id == app_id, - WorkflowDraftVariable.user_id == user_id, - or_(*ors), - ) ) - return query.all() def list_variables_without_values( self, app_id: str, page: int, limit: int, user_id: str @@ -277,18 +278,21 @@ class WorkflowDraftVariableService: WorkflowDraftVariable.user_id == user_id, ] total = None - query = self._session.query(WorkflowDraftVariable).where(*criteria) + base_stmt = select(WorkflowDraftVariable).where(*criteria) if page == 1: - total = query.count() - variables = ( - # Do not load the `value` field - query.options( - orm.defer(WorkflowDraftVariable.value, raiseload=True), + from sqlalchemy import func as sa_func + + total = self._session.scalar(select(sa_func.count()).select_from(base_stmt.subquery())) + variables = list( + self._session.scalars( + # Do not load the `value` field + base_stmt.options( + orm.defer(WorkflowDraftVariable.value, raiseload=True), + ) + .order_by(WorkflowDraftVariable.created_at.desc()) + .limit(limit) + .offset((page - 1) * limit) ) - .order_by(WorkflowDraftVariable.created_at.desc()) - .limit(limit) - .offset((page - 1) * limit) - .all() ) return WorkflowDraftVariableList(variables=variables, total=total) @@ -299,11 +303,13 @@ class WorkflowDraftVariableService: WorkflowDraftVariable.node_id == node_id, WorkflowDraftVariable.user_id == user_id, ] - query = self._session.query(WorkflowDraftVariable).where(*criteria) - variables = ( - query.options(orm.selectinload(WorkflowDraftVariable.variable_file)) - .order_by(WorkflowDraftVariable.created_at.desc()) - .all() + variables = list( + self._session.scalars( + select(WorkflowDraftVariable) + .options(orm.selectinload(WorkflowDraftVariable.variable_file)) + .where(*criteria) + .order_by(WorkflowDraftVariable.created_at.desc()) + ) ) return WorkflowDraftVariableList(variables=variables) @@ -326,8 +332,8 @@ class WorkflowDraftVariableService: return self._get_variable(app_id, node_id, name, user_id=user_id) def _get_variable(self, app_id: str, node_id: str, name: str, user_id: str) -> WorkflowDraftVariable | None: - return ( - self._session.query(WorkflowDraftVariable) + return self._session.scalar( + select(WorkflowDraftVariable) .options(orm.selectinload(WorkflowDraftVariable.variable_file)) .where( WorkflowDraftVariable.app_id == app_id, @@ -335,7 +341,6 @@ class WorkflowDraftVariableService: WorkflowDraftVariable.name == name, WorkflowDraftVariable.user_id == user_id, ) - .first() ) def update_variable( @@ -488,20 +493,20 @@ class WorkflowDraftVariableService: self._session.delete(variable) def delete_user_workflow_variables(self, app_id: str, user_id: str): - ( - self._session.query(WorkflowDraftVariable) + self._session.execute( + delete(WorkflowDraftVariable) .where( WorkflowDraftVariable.app_id == app_id, WorkflowDraftVariable.user_id == user_id, ) - .delete(synchronize_session=False) + .execution_options(synchronize_session=False) ) def delete_app_workflow_variables(self, app_id: str): - ( - self._session.query(WorkflowDraftVariable) + self._session.execute( + delete(WorkflowDraftVariable) .where(WorkflowDraftVariable.app_id == app_id) - .delete(synchronize_session=False) + .execution_options(synchronize_session=False) ) def delete_workflow_draft_variable_file(self, deletions: list[DraftVarFileDeletion]): @@ -540,14 +545,14 @@ class WorkflowDraftVariableService: return self._delete_node_variables(app_id, node_id, user_id=user_id) def _delete_node_variables(self, app_id: str, node_id: str, user_id: str): - ( - self._session.query(WorkflowDraftVariable) + self._session.execute( + delete(WorkflowDraftVariable) .where( WorkflowDraftVariable.app_id == app_id, WorkflowDraftVariable.node_id == node_id, WorkflowDraftVariable.user_id == user_id, ) - .delete(synchronize_session=False) + .execution_options(synchronize_session=False) ) def _get_conversation_id_from_draft_variable(self, app_id: str, user_id: str) -> str | None: @@ -588,13 +593,11 @@ class WorkflowDraftVariableService: conv_id = self._get_conversation_id_from_draft_variable(workflow.app_id, account_id) if conv_id is not None: - conversation = ( - self._session.query(Conversation) - .where( + conversation = self._session.scalar( + select(Conversation).where( Conversation.id == conv_id, Conversation.app_id == workflow.app_id, ) - .first() ) # Only return the conversation ID if it exists and is valid (has a correspond conversation record in DB). if conversation is not None: @@ -723,8 +726,27 @@ def _batch_upsert_draft_variable( session.execute(stmt) -def _model_to_insertion_dict(model: WorkflowDraftVariable) -> dict[str, Any]: - d: dict[str, Any] = { +class _InsertionDict(TypedDict): + id: str + app_id: str + user_id: str | None + last_edited_at: datetime | None + node_id: str + name: str + selector: str + value_type: SegmentType + value: str + node_execution_id: str | None + file_id: str | None + visible: NotRequired[bool] + editable: NotRequired[bool] + created_at: NotRequired[datetime] + updated_at: NotRequired[datetime] + description: NotRequired[str] + + +def _model_to_insertion_dict(model: WorkflowDraftVariable) -> _InsertionDict: + d: _InsertionDict = { "id": model.id, "app_id": model.app_id, "user_id": model.user_id, diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index c28704e83b..839b9e3319 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -1512,14 +1512,12 @@ class WorkflowService: # Don't use workflow.tool_published as it's not accurate for specific workflow versions # Check if there's a tool provider using this specific workflow version - tool_provider = ( - session.query(WorkflowToolProvider) - .where( + tool_provider = session.scalar( + select(WorkflowToolProvider).where( WorkflowToolProvider.tenant_id == workflow.tenant_id, WorkflowToolProvider.app_id == workflow.app_id, WorkflowToolProvider.version == workflow.version, ) - .first() ) if tool_provider: diff --git a/api/tasks/async_workflow_tasks.py b/api/tasks/async_workflow_tasks.py index 0a73c91279..45e1f80e35 100644 --- a/api/tasks/async_workflow_tasks.py +++ b/api/tasks/async_workflow_tasks.py @@ -7,15 +7,16 @@ with appropriate retry policies and error handling. import logging from datetime import UTC, datetime -from typing import Any +from typing import Any, NotRequired from celery import shared_task from graphon.runtime import GraphRuntimeState from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker +from typing_extensions import TypedDict from configs import dify_config -from core.app.apps.workflow.app_generator import SKIP_PREPARE_USER_INPUTS_KEY, WorkflowAppGenerator +from core.app.apps.workflow.app_generator import WorkflowAppGenerator from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, WorkflowResumptionContext from core.app.layers.timeslice_layer import TimeSliceLayer @@ -42,6 +43,13 @@ from tasks.workflow_cfs_scheduler.entities import AsyncWorkflowQueue, AsyncWorkf logger = logging.getLogger(__name__) +class WorkflowGeneratorArgsDict(TypedDict): + inputs: dict[str, Any] + files: list[Any] + _skip_prepare_user_inputs: bool + workflow_id: NotRequired[str] + + @shared_task(queue=AsyncWorkflowQueue.PROFESSIONAL_QUEUE) def execute_workflow_professional(task_data_dict: dict[str, Any]): """Execute workflow for professional tier with highest priority""" @@ -90,15 +98,13 @@ def execute_workflow_sandbox(task_data_dict: dict[str, Any]): ) -def _build_generator_args(trigger_data: TriggerData) -> dict[str, Any]: +def _build_generator_args(trigger_data: TriggerData) -> WorkflowGeneratorArgsDict: """Build args passed into WorkflowAppGenerator.generate for Celery executions.""" - - args: dict[str, Any] = { + return { "inputs": dict(trigger_data.inputs), "files": list(trigger_data.files), - SKIP_PREPARE_USER_INPUTS_KEY: True, + "_skip_prepare_user_inputs": True, } - return args def _execute_workflow_common( diff --git a/api/tasks/deal_dataset_index_update_task.py b/api/tasks/deal_dataset_index_update_task.py index fa844a8647..c9b5121a08 100644 --- a/api/tasks/deal_dataset_index_update_task.py +++ b/api/tasks/deal_dataset_index_update_task.py @@ -3,6 +3,7 @@ import time import click from celery import shared_task # type: ignore +from sqlalchemy import select, update from core.db.session_factory import session_factory from core.rag.index_processor.constant.doc_type import DocType @@ -26,43 +27,42 @@ def deal_dataset_index_update_task(dataset_id: str, action: str): with session_factory.create_session() as session: try: - dataset = session.query(Dataset).filter_by(id=dataset_id).first() + dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1)) if not dataset: raise Exception("Dataset not found") index_type = dataset.doc_form or IndexStructureType.PARAGRAPH_INDEX index_processor = IndexProcessorFactory(index_type).init_index_processor() if action == "upgrade": - dataset_documents = ( - session.query(DatasetDocument) - .where( + dataset_documents = session.scalars( + select(DatasetDocument).where( DatasetDocument.dataset_id == dataset_id, DatasetDocument.indexing_status == "completed", DatasetDocument.enabled == True, DatasetDocument.archived == False, ) - .all() - ) + ).all() if dataset_documents: dataset_documents_ids = [doc.id for doc in dataset_documents] - session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update( - {"indexing_status": "indexing"}, synchronize_session=False + session.execute( + update(DatasetDocument) + .where(DatasetDocument.id.in_(dataset_documents_ids)) + .values(indexing_status="indexing") ) session.commit() for dataset_document in dataset_documents: try: # add from vector index - segments = ( - session.query(DocumentSegment) + segments = session.scalars( + select(DocumentSegment) .where( DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True, ) .order_by(DocumentSegment.position.asc()) - .all() - ) + ).all() if segments: documents = [] for segment in segments: @@ -81,32 +81,36 @@ def deal_dataset_index_update_task(dataset_id: str, action: str): # clean keywords index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=False) index_processor.load(dataset, documents, with_keywords=False) - session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update( - {"indexing_status": "completed"}, synchronize_session=False + session.execute( + update(DatasetDocument) + .where(DatasetDocument.id == dataset_document.id) + .values(indexing_status="completed") ) session.commit() except Exception as e: - session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update( - {"indexing_status": "error", "error": str(e)}, synchronize_session=False + session.execute( + update(DatasetDocument) + .where(DatasetDocument.id == dataset_document.id) + .values(indexing_status="error", error=str(e)) ) session.commit() elif action == "update": - dataset_documents = ( - session.query(DatasetDocument) - .where( + dataset_documents = session.scalars( + select(DatasetDocument).where( DatasetDocument.dataset_id == dataset_id, DatasetDocument.indexing_status == "completed", DatasetDocument.enabled == True, DatasetDocument.archived == False, ) - .all() - ) + ).all() # add new index if dataset_documents: # update document status dataset_documents_ids = [doc.id for doc in dataset_documents] - session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update( - {"indexing_status": "indexing"}, synchronize_session=False + session.execute( + update(DatasetDocument) + .where(DatasetDocument.id.in_(dataset_documents_ids)) + .values(indexing_status="indexing") ) session.commit() @@ -116,15 +120,14 @@ def deal_dataset_index_update_task(dataset_id: str, action: str): for dataset_document in dataset_documents: # update from vector index try: - segments = ( - session.query(DocumentSegment) + segments = session.scalars( + select(DocumentSegment) .where( DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True, ) .order_by(DocumentSegment.position.asc()) - .all() - ) + ).all() if segments: documents = [] multimodal_documents = [] @@ -173,13 +176,17 @@ def deal_dataset_index_update_task(dataset_id: str, action: str): index_processor.load( dataset, documents, multimodal_documents=multimodal_documents, with_keywords=False ) - session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update( - {"indexing_status": "completed"}, synchronize_session=False + session.execute( + update(DatasetDocument) + .where(DatasetDocument.id == dataset_document.id) + .values(indexing_status="completed") ) session.commit() except Exception as e: - session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update( - {"indexing_status": "error", "error": str(e)}, synchronize_session=False + session.execute( + update(DatasetDocument) + .where(DatasetDocument.id == dataset_document.id) + .values(indexing_status="error", error=str(e)) ) session.commit() else: diff --git a/api/tasks/document_indexing_task.py b/api/tasks/document_indexing_task.py index 23a80fa106..31dad7937c 100644 --- a/api/tasks/document_indexing_task.py +++ b/api/tasks/document_indexing_task.py @@ -5,6 +5,7 @@ from typing import Any, Protocol import click from celery import current_app, shared_task +from sqlalchemy import select from configs import dify_config from core.db.session_factory import session_factory @@ -53,11 +54,10 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]): Usage: _document_indexing(dataset_id, document_ids) """ - documents = [] start_at = time.perf_counter() with session_factory.create_session() as session: - dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() + dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1)) if not dataset: logger.info(click.style(f"Dataset is not found: {dataset_id}", fg="yellow")) return @@ -79,8 +79,8 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]): ) except Exception as e: for document_id in document_ids: - document = ( - session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() + document = session.scalar( + select(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).limit(1) ) if document: document.indexing_status = IndexingStatus.ERROR @@ -92,8 +92,10 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]): # Phase 1: Update status to parsing (short transaction) with session_factory.create_session() as session, session.begin(): - documents = ( - session.query(Document).where(Document.id.in_(document_ids), Document.dataset_id == dataset_id).all() + documents: list[Document] = list( + session.scalars( + select(Document).where(Document.id.in_(document_ids), Document.dataset_id == dataset_id) + ).all() ) for document in documents: @@ -122,7 +124,7 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]): # Trigger summary index generation for completed documents if enabled # Only generate for high_quality indexing technique and when summary_index_setting is enabled # Re-query dataset to get latest summary_index_setting (in case it was updated) - dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() + dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1)) if not dataset: logger.warning("Dataset %s not found after indexing", dataset_id) return @@ -134,10 +136,10 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]): session.expire_all() # Check each document's indexing status and trigger summary generation if completed - documents = ( - session.query(Document) - .where(Document.id.in_(document_ids), Document.dataset_id == dataset_id) - .all() + documents = list( + session.scalars( + select(Document).where(Document.id.in_(document_ids), Document.dataset_id == dataset_id) + ).all() ) for document in documents: diff --git a/api/tasks/remove_app_and_related_data_task.py b/api/tasks/remove_app_and_related_data_task.py index b1840662ff..72d824b8c1 100644 --- a/api/tasks/remove_app_and_related_data_task.py +++ b/api/tasks/remove_app_and_related_data_task.py @@ -6,7 +6,7 @@ from typing import Any, cast import click import sqlalchemy as sa from celery import shared_task -from sqlalchemy import delete +from sqlalchemy import delete, select from sqlalchemy.engine import CursorResult from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.orm import sessionmaker @@ -99,7 +99,11 @@ def remove_app_and_related_data_task(self, tenant_id: str, app_id: str): def _delete_app_model_configs(tenant_id: str, app_id: str): def del_model_config(session, model_config_id: str): - session.query(AppModelConfig).where(AppModelConfig.id == model_config_id).delete(synchronize_session=False) + session.execute( + delete(AppModelConfig) + .where(AppModelConfig.id == model_config_id) + .execution_options(synchronize_session=False) + ) _delete_records( """select id from app_model_configs where app_id=:app_id limit 1000""", @@ -111,7 +115,7 @@ def _delete_app_model_configs(tenant_id: str, app_id: str): def _delete_app_site(tenant_id: str, app_id: str): def del_site(session, site_id: str): - session.query(Site).where(Site.id == site_id).delete(synchronize_session=False) + session.execute(delete(Site).where(Site.id == site_id).execution_options(synchronize_session=False)) _delete_records( """select id from sites where app_id=:app_id limit 1000""", @@ -123,7 +127,9 @@ def _delete_app_site(tenant_id: str, app_id: str): def _delete_app_mcp_servers(tenant_id: str, app_id: str): def del_mcp_server(session, mcp_server_id: str): - session.query(AppMCPServer).where(AppMCPServer.id == mcp_server_id).delete(synchronize_session=False) + session.execute( + delete(AppMCPServer).where(AppMCPServer.id == mcp_server_id).execution_options(synchronize_session=False) + ) _delete_records( """select id from app_mcp_servers where app_id=:app_id limit 1000""", @@ -136,12 +142,14 @@ def _delete_app_mcp_servers(tenant_id: str, app_id: str): def _delete_app_api_tokens(tenant_id: str, app_id: str): def del_api_token(session, api_token_id: str): # Fetch token details for cache invalidation - token_obj = session.query(ApiToken).where(ApiToken.id == api_token_id).first() + token_obj = session.scalar(select(ApiToken).where(ApiToken.id == api_token_id).limit(1)) if token_obj: # Invalidate cache before deletion ApiTokenCache.delete(token_obj.token, token_obj.type) - session.query(ApiToken).where(ApiToken.id == api_token_id).delete(synchronize_session=False) + session.execute( + delete(ApiToken).where(ApiToken.id == api_token_id).execution_options(synchronize_session=False) + ) _delete_records( """select id from api_tokens where app_id=:app_id limit 1000""", @@ -153,7 +161,9 @@ def _delete_app_api_tokens(tenant_id: str, app_id: str): def _delete_installed_apps(tenant_id: str, app_id: str): def del_installed_app(session, installed_app_id: str): - session.query(InstalledApp).where(InstalledApp.id == installed_app_id).delete(synchronize_session=False) + session.execute( + delete(InstalledApp).where(InstalledApp.id == installed_app_id).execution_options(synchronize_session=False) + ) _delete_records( """select id from installed_apps where tenant_id=:tenant_id and app_id=:app_id limit 1000""", @@ -165,7 +175,11 @@ def _delete_installed_apps(tenant_id: str, app_id: str): def _delete_recommended_apps(tenant_id: str, app_id: str): def del_recommended_app(session, recommended_app_id: str): - session.query(RecommendedApp).where(RecommendedApp.id == recommended_app_id).delete(synchronize_session=False) + session.execute( + delete(RecommendedApp) + .where(RecommendedApp.id == recommended_app_id) + .execution_options(synchronize_session=False) + ) _delete_records( """select id from recommended_apps where app_id=:app_id limit 1000""", @@ -177,8 +191,10 @@ def _delete_recommended_apps(tenant_id: str, app_id: str): def _delete_app_annotation_data(tenant_id: str, app_id: str): def del_annotation_hit_history(session, annotation_hit_history_id: str): - session.query(AppAnnotationHitHistory).where(AppAnnotationHitHistory.id == annotation_hit_history_id).delete( - synchronize_session=False + session.execute( + delete(AppAnnotationHitHistory) + .where(AppAnnotationHitHistory.id == annotation_hit_history_id) + .execution_options(synchronize_session=False) ) _delete_records( @@ -189,8 +205,10 @@ def _delete_app_annotation_data(tenant_id: str, app_id: str): ) def del_annotation_setting(session, annotation_setting_id: str): - session.query(AppAnnotationSetting).where(AppAnnotationSetting.id == annotation_setting_id).delete( - synchronize_session=False + session.execute( + delete(AppAnnotationSetting) + .where(AppAnnotationSetting.id == annotation_setting_id) + .execution_options(synchronize_session=False) ) _delete_records( @@ -203,7 +221,11 @@ def _delete_app_annotation_data(tenant_id: str, app_id: str): def _delete_app_dataset_joins(tenant_id: str, app_id: str): def del_dataset_join(session, dataset_join_id: str): - session.query(AppDatasetJoin).where(AppDatasetJoin.id == dataset_join_id).delete(synchronize_session=False) + session.execute( + delete(AppDatasetJoin) + .where(AppDatasetJoin.id == dataset_join_id) + .execution_options(synchronize_session=False) + ) _delete_records( """select id from app_dataset_joins where app_id=:app_id limit 1000""", @@ -215,7 +237,7 @@ def _delete_app_dataset_joins(tenant_id: str, app_id: str): def _delete_app_workflows(tenant_id: str, app_id: str): def del_workflow(session, workflow_id: str): - session.query(Workflow).where(Workflow.id == workflow_id).delete(synchronize_session=False) + session.execute(delete(Workflow).where(Workflow.id == workflow_id).execution_options(synchronize_session=False)) _delete_records( """select id from workflows where tenant_id=:tenant_id and app_id=:app_id limit 1000""", @@ -255,7 +277,11 @@ def _delete_app_workflow_node_executions(tenant_id: str, app_id: str): def _delete_app_workflow_app_logs(tenant_id: str, app_id: str): def del_workflow_app_log(session, workflow_app_log_id: str): - session.query(WorkflowAppLog).where(WorkflowAppLog.id == workflow_app_log_id).delete(synchronize_session=False) + session.execute( + delete(WorkflowAppLog) + .where(WorkflowAppLog.id == workflow_app_log_id) + .execution_options(synchronize_session=False) + ) _delete_records( """select id from workflow_app_logs where tenant_id=:tenant_id and app_id=:app_id limit 1000""", @@ -267,8 +293,10 @@ def _delete_app_workflow_app_logs(tenant_id: str, app_id: str): def _delete_app_workflow_archive_logs(tenant_id: str, app_id: str): def del_workflow_archive_log(session, workflow_archive_log_id: str): - session.query(WorkflowArchiveLog).where(WorkflowArchiveLog.id == workflow_archive_log_id).delete( - synchronize_session=False + session.execute( + delete(WorkflowArchiveLog) + .where(WorkflowArchiveLog.id == workflow_archive_log_id) + .execution_options(synchronize_session=False) ) _delete_records( @@ -306,10 +334,14 @@ def _delete_archived_workflow_run_files(tenant_id: str, app_id: str): def _delete_app_conversations(tenant_id: str, app_id: str): def del_conversation(session, conversation_id: str): - session.query(PinnedConversation).where(PinnedConversation.conversation_id == conversation_id).delete( - synchronize_session=False + session.execute( + delete(PinnedConversation) + .where(PinnedConversation.conversation_id == conversation_id) + .execution_options(synchronize_session=False) + ) + session.execute( + delete(Conversation).where(Conversation.id == conversation_id).execution_options(synchronize_session=False) ) - session.query(Conversation).where(Conversation.id == conversation_id).delete(synchronize_session=False) _delete_records( """select id from conversations where app_id=:app_id limit 1000""", @@ -329,17 +361,35 @@ def _delete_conversation_variables(*, app_id: str): def _delete_app_messages(tenant_id: str, app_id: str): def del_message(session, message_id: str): - session.query(MessageFeedback).where(MessageFeedback.message_id == message_id).delete(synchronize_session=False) - session.query(MessageAnnotation).where(MessageAnnotation.message_id == message_id).delete( - synchronize_session=False + session.execute( + delete(MessageFeedback) + .where(MessageFeedback.message_id == message_id) + .execution_options(synchronize_session=False) ) - session.query(MessageChain).where(MessageChain.message_id == message_id).delete(synchronize_session=False) - session.query(MessageAgentThought).where(MessageAgentThought.message_id == message_id).delete( - synchronize_session=False + session.execute( + delete(MessageAnnotation) + .where(MessageAnnotation.message_id == message_id) + .execution_options(synchronize_session=False) ) - session.query(MessageFile).where(MessageFile.message_id == message_id).delete(synchronize_session=False) - session.query(SavedMessage).where(SavedMessage.message_id == message_id).delete(synchronize_session=False) - session.query(Message).where(Message.id == message_id).delete() + session.execute( + delete(MessageChain) + .where(MessageChain.message_id == message_id) + .execution_options(synchronize_session=False) + ) + session.execute( + delete(MessageAgentThought) + .where(MessageAgentThought.message_id == message_id) + .execution_options(synchronize_session=False) + ) + session.execute( + delete(MessageFile).where(MessageFile.message_id == message_id).execution_options(synchronize_session=False) + ) + session.execute( + delete(SavedMessage) + .where(SavedMessage.message_id == message_id) + .execution_options(synchronize_session=False) + ) + session.execute(delete(Message).where(Message.id == message_id).execution_options(synchronize_session=False)) _delete_records( """select id from messages where app_id=:app_id limit 1000""", @@ -351,8 +401,10 @@ def _delete_app_messages(tenant_id: str, app_id: str): def _delete_workflow_tool_providers(tenant_id: str, app_id: str): def del_tool_provider(session, tool_provider_id: str): - session.query(WorkflowToolProvider).where(WorkflowToolProvider.id == tool_provider_id).delete( - synchronize_session=False + session.execute( + delete(WorkflowToolProvider) + .where(WorkflowToolProvider.id == tool_provider_id) + .execution_options(synchronize_session=False) ) _delete_records( @@ -365,7 +417,9 @@ def _delete_workflow_tool_providers(tenant_id: str, app_id: str): def _delete_app_tag_bindings(tenant_id: str, app_id: str): def del_tag_binding(session, tag_binding_id: str): - session.query(TagBinding).where(TagBinding.id == tag_binding_id).delete(synchronize_session=False) + session.execute( + delete(TagBinding).where(TagBinding.id == tag_binding_id).execution_options(synchronize_session=False) + ) _delete_records( """select id from tag_bindings where tenant_id=:tenant_id and target_id=:app_id limit 1000""", @@ -377,7 +431,7 @@ def _delete_app_tag_bindings(tenant_id: str, app_id: str): def _delete_end_users(tenant_id: str, app_id: str): def del_end_user(session, end_user_id: str): - session.query(EndUser).where(EndUser.id == end_user_id).delete(synchronize_session=False) + session.execute(delete(EndUser).where(EndUser.id == end_user_id).execution_options(synchronize_session=False)) _delete_records( """select id from end_users where tenant_id=:tenant_id and app_id=:app_id limit 1000""", @@ -389,7 +443,11 @@ def _delete_end_users(tenant_id: str, app_id: str): def _delete_trace_app_configs(tenant_id: str, app_id: str): def del_trace_app_config(session, trace_app_config_id: str): - session.query(TraceAppConfig).where(TraceAppConfig.id == trace_app_config_id).delete(synchronize_session=False) + session.execute( + delete(TraceAppConfig) + .where(TraceAppConfig.id == trace_app_config_id) + .execution_options(synchronize_session=False) + ) _delete_records( """select id from trace_app_config where app_id=:app_id limit 1000""", @@ -545,7 +603,9 @@ def _delete_draft_variable_offload_data(session, file_ids: list[str]) -> int: def _delete_app_triggers(tenant_id: str, app_id: str): def del_app_trigger(session, trigger_id: str): - session.query(AppTrigger).where(AppTrigger.id == trigger_id).delete(synchronize_session=False) + session.execute( + delete(AppTrigger).where(AppTrigger.id == trigger_id).execution_options(synchronize_session=False) + ) _delete_records( """select id from app_triggers where tenant_id=:tenant_id and app_id=:app_id limit 1000""", @@ -557,8 +617,10 @@ def _delete_app_triggers(tenant_id: str, app_id: str): def _delete_workflow_plugin_triggers(tenant_id: str, app_id: str): def del_plugin_trigger(session, trigger_id: str): - session.query(WorkflowPluginTrigger).where(WorkflowPluginTrigger.id == trigger_id).delete( - synchronize_session=False + session.execute( + delete(WorkflowPluginTrigger) + .where(WorkflowPluginTrigger.id == trigger_id) + .execution_options(synchronize_session=False) ) _delete_records( @@ -571,8 +633,10 @@ def _delete_workflow_plugin_triggers(tenant_id: str, app_id: str): def _delete_workflow_webhook_triggers(tenant_id: str, app_id: str): def del_webhook_trigger(session, trigger_id: str): - session.query(WorkflowWebhookTrigger).where(WorkflowWebhookTrigger.id == trigger_id).delete( - synchronize_session=False + session.execute( + delete(WorkflowWebhookTrigger) + .where(WorkflowWebhookTrigger.id == trigger_id) + .execution_options(synchronize_session=False) ) _delete_records( @@ -585,7 +649,11 @@ def _delete_workflow_webhook_triggers(tenant_id: str, app_id: str): def _delete_workflow_schedule_plans(tenant_id: str, app_id: str): def del_schedule_plan(session, plan_id: str): - session.query(WorkflowSchedulePlan).where(WorkflowSchedulePlan.id == plan_id).delete(synchronize_session=False) + session.execute( + delete(WorkflowSchedulePlan) + .where(WorkflowSchedulePlan.id == plan_id) + .execution_options(synchronize_session=False) + ) _delete_records( """select id from workflow_schedule_plans where tenant_id=:tenant_id and app_id=:app_id limit 1000""", @@ -597,7 +665,11 @@ def _delete_workflow_schedule_plans(tenant_id: str, app_id: str): def _delete_workflow_trigger_logs(tenant_id: str, app_id: str): def del_trigger_log(session, log_id: str): - session.query(WorkflowTriggerLog).where(WorkflowTriggerLog.id == log_id).delete(synchronize_session=False) + session.execute( + delete(WorkflowTriggerLog) + .where(WorkflowTriggerLog.id == log_id) + .execution_options(synchronize_session=False) + ) _delete_records( """select id from workflow_trigger_logs where tenant_id=:tenant_id and app_id=:app_id limit 1000""", diff --git a/api/tests/integration_tests/conftest.py b/api/tests/integration_tests/conftest.py index 44adadeaa5..b2e8dda443 100644 --- a/api/tests/integration_tests/conftest.py +++ b/api/tests/integration_tests/conftest.py @@ -8,6 +8,7 @@ from collections.abc import Generator import pytest from flask import Flask from flask.testing import FlaskClient +from sqlalchemy import delete, select from sqlalchemy.orm import Session from app_factory import create_app @@ -83,15 +84,15 @@ def setup_account(request) -> Generator[Account, None, None]: with _CACHED_APP.test_request_context(): with Session(bind=db.engine, expire_on_commit=False) as session: - account = session.query(Account).filter_by(email=email).one() + account = session.scalars(select(Account).filter_by(email=email)).one() yield account with _CACHED_APP.test_request_context(): - db.session.query(DifySetup).delete() - db.session.query(TenantAccountJoin).delete() - db.session.query(Account).delete() - db.session.query(Tenant).delete() + db.session.execute(delete(DifySetup)) + db.session.execute(delete(TenantAccountJoin)) + db.session.execute(delete(Account)) + db.session.execute(delete(Tenant)) db.session.commit() diff --git a/api/tests/integration_tests/services/plugin/test_plugin_lifecycle.py b/api/tests/integration_tests/services/plugin/test_plugin_lifecycle.py index 951a5ab4b4..0a19debc39 100644 --- a/api/tests/integration_tests/services/plugin/test_plugin_lifecycle.py +++ b/api/tests/integration_tests/services/plugin/test_plugin_lifecycle.py @@ -1,5 +1,5 @@ import pytest -from sqlalchemy import delete +from sqlalchemy import delete, func, select from core.db.session_factory import session_factory from models import Tenant @@ -61,7 +61,11 @@ class TestPluginPermissionLifecycle: assert perm.debug_permission == TenantPluginPermission.DebugPermission.ADMINS with session_factory.create_session() as session: - count = session.query(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant).count() + count = session.scalar( + select(func.count()) + .select_from(TenantPluginPermission) + .where(TenantPluginPermission.tenant_id == tenant) + ) assert count == 1 diff --git a/api/tests/integration_tests/services/retention/test_messages_clean_service.py b/api/tests/integration_tests/services/retention/test_messages_clean_service.py index 348bb0af4a..352960bcc2 100644 --- a/api/tests/integration_tests/services/retention/test_messages_clean_service.py +++ b/api/tests/integration_tests/services/retention/test_messages_clean_service.py @@ -3,7 +3,7 @@ import math import uuid import pytest -from sqlalchemy import delete +from sqlalchemy import delete, func, select from core.db.session_factory import session_factory from models import Tenant @@ -210,7 +210,7 @@ class TestMessagesCleanServiceIntegration: assert stats["total_deleted"] == 0 with session_factory.create_session() as session: - remaining = session.query(Message).where(Message.id.in_(all_ids)).count() + remaining = session.scalar(select(func.count()).select_from(Message).where(Message.id.in_(all_ids))) assert remaining == len(all_ids) def test_billing_disabled_deletes_all_in_range(self, seed_messages): @@ -231,7 +231,7 @@ class TestMessagesCleanServiceIntegration: assert stats["total_deleted"] == len(all_ids) with session_factory.create_session() as session: - remaining = session.query(Message).where(Message.id.in_(all_ids)).count() + remaining = session.scalar(select(func.count()).select_from(Message).where(Message.id.in_(all_ids))) assert remaining == 0 def test_start_from_filters_correctly(self, seed_messages): @@ -254,7 +254,7 @@ class TestMessagesCleanServiceIntegration: with session_factory.create_session() as session: all_ids = list(msg_ids.values()) - remaining_ids = {r[0] for r in session.query(Message.id).where(Message.id.in_(all_ids)).all()} + remaining_ids = set(session.scalars(select(Message.id).where(Message.id.in_(all_ids))).all()) assert msg_ids["old"] not in remaining_ids assert msg_ids["very_old"] in remaining_ids @@ -282,7 +282,7 @@ class TestMessagesCleanServiceIntegration: assert stats["batches"] >= expected_batches with session_factory.create_session() as session: - remaining = session.query(Message).where(Message.id.in_(msg_ids)).count() + remaining = session.scalar(select(func.count()).select_from(Message).where(Message.id.in_(msg_ids))) assert remaining == 0 def test_no_messages_in_range_returns_empty_stats(self, seed_messages): @@ -319,9 +319,17 @@ class TestMessagesCleanServiceIntegration: assert stats["total_deleted"] == 1 with session_factory.create_session() as session: - assert session.query(Message).where(Message.id == msg_id).count() == 0 - assert session.query(MessageFeedback).where(MessageFeedback.id == fb_id).count() == 0 - assert session.query(MessageAnnotation).where(MessageAnnotation.id == ann_id).count() == 0 + assert session.scalar(select(func.count()).select_from(Message).where(Message.id == msg_id)) == 0 + assert ( + session.scalar(select(func.count()).select_from(MessageFeedback).where(MessageFeedback.id == fb_id)) + == 0 + ) + assert ( + session.scalar( + select(func.count()).select_from(MessageAnnotation).where(MessageAnnotation.id == ann_id) + ) + == 0 + ) def test_factory_from_time_range_validation(self): with pytest.raises(ValueError, match="start_from"): diff --git a/api/tests/integration_tests/services/test_workflow_draft_variable_service.py b/api/tests/integration_tests/services/test_workflow_draft_variable_service.py index 5c6636f31e..c7bb90f019 100644 --- a/api/tests/integration_tests/services/test_workflow_draft_variable_service.py +++ b/api/tests/integration_tests/services/test_workflow_draft_variable_service.py @@ -7,7 +7,7 @@ from graphon.nodes import BuiltinNodeTypes from graphon.variables.segments import StringSegment from graphon.variables.types import SegmentType from graphon.variables.variables import StringVariable -from sqlalchemy import delete +from sqlalchemy import delete, func, select from sqlalchemy.orm import Session from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID @@ -38,21 +38,25 @@ class TestWorkflowDraftVariableService(unittest.TestCase): def setUp(self): self._test_app_id = str(uuid.uuid4()) + self._test_user_id = str(uuid.uuid4()) self._session: Session = db.session() sys_var = WorkflowDraftVariable.new_sys_variable( app_id=self._test_app_id, + user_id=self._test_user_id, name="sys_var", value=build_segment("sys_value"), node_execution_id=self._node_exec_id, ) conv_var = WorkflowDraftVariable.new_conversation_variable( app_id=self._test_app_id, + user_id=self._test_user_id, name="conv_var", value=build_segment("conv_value"), ) node2_vars = [ WorkflowDraftVariable.new_node_variable( app_id=self._test_app_id, + user_id=self._test_user_id, node_id=self._node2_id, name="int_var", value=build_segment(1), @@ -61,6 +65,7 @@ class TestWorkflowDraftVariableService(unittest.TestCase): ), WorkflowDraftVariable.new_node_variable( app_id=self._test_app_id, + user_id=self._test_user_id, node_id=self._node2_id, name="str_var", value=build_segment("str_value"), @@ -70,6 +75,7 @@ class TestWorkflowDraftVariableService(unittest.TestCase): ] node1_var = WorkflowDraftVariable.new_node_variable( app_id=self._test_app_id, + user_id=self._test_user_id, node_id=self._node1_id, name="str_var", value=build_segment("str_value"), @@ -141,24 +147,27 @@ class TestWorkflowDraftVariableService(unittest.TestCase): def test_delete_node_variables(self): srv = self._get_test_srv() srv.delete_node_variables(self._test_app_id, self._node2_id, user_id=self._test_user_id) - node2_var_count = ( - self._session.query(WorkflowDraftVariable) + node2_var_count = self._session.scalar( + select(func.count()) + .select_from(WorkflowDraftVariable) .where( WorkflowDraftVariable.app_id == self._test_app_id, WorkflowDraftVariable.node_id == self._node2_id, + WorkflowDraftVariable.user_id == self._test_user_id, ) - .count() ) assert node2_var_count == 0 def test_delete_variable(self): srv = self._get_test_srv() - node_1_var = ( - self._session.query(WorkflowDraftVariable).where(WorkflowDraftVariable.id == self._node1_str_var_id).one() - ) + node_1_var = self._session.scalars( + select(WorkflowDraftVariable).where(WorkflowDraftVariable.id == self._node1_str_var_id) + ).one() srv.delete_variable(node_1_var) exists = bool( - self._session.query(WorkflowDraftVariable).where(WorkflowDraftVariable.id == self._node1_str_var_id).first() + self._session.scalars( + select(WorkflowDraftVariable).where(WorkflowDraftVariable.id == self._node1_str_var_id) + ).first() ) assert exists is False @@ -248,9 +257,7 @@ class TestDraftVariableLoader(unittest.TestCase): def tearDown(self): with Session(bind=db.engine, expire_on_commit=False) as session: - session.query(WorkflowDraftVariable).where(WorkflowDraftVariable.app_id == self._test_app_id).delete( - synchronize_session=False - ) + session.execute(delete(WorkflowDraftVariable).where(WorkflowDraftVariable.app_id == self._test_app_id)) session.commit() def test_variable_loader_with_empty_selector(self): @@ -431,9 +438,11 @@ class TestDraftVariableLoader(unittest.TestCase): # Clean up with Session(bind=db.engine) as session: # Query and delete by ID to ensure they're tracked in this session - session.query(WorkflowDraftVariable).filter_by(id=offloaded_var.id).delete() - session.query(WorkflowDraftVariableFile).filter_by(id=variable_file.id).delete() - session.query(UploadFile).filter_by(id=upload_file.id).delete() + session.execute(delete(WorkflowDraftVariable).where(WorkflowDraftVariable.id == offloaded_var.id)) + session.execute( + delete(WorkflowDraftVariableFile).where(WorkflowDraftVariableFile.id == variable_file.id) + ) + session.execute(delete(UploadFile).where(UploadFile.id == upload_file.id)) session.commit() # Clean up storage try: @@ -534,9 +543,11 @@ class TestDraftVariableLoader(unittest.TestCase): # Clean up with Session(bind=db.engine) as session: # Query and delete by ID to ensure they're tracked in this session - session.query(WorkflowDraftVariable).filter_by(id=offloaded_var.id).delete() - session.query(WorkflowDraftVariableFile).filter_by(id=variable_file.id).delete() - session.query(UploadFile).filter_by(id=upload_file.id).delete() + session.execute(delete(WorkflowDraftVariable).where(WorkflowDraftVariable.id == offloaded_var.id)) + session.execute( + delete(WorkflowDraftVariableFile).where(WorkflowDraftVariableFile.id == variable_file.id) + ) + session.execute(delete(UploadFile).where(UploadFile.id == upload_file.id)) session.commit() # Clean up storage try: diff --git a/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py b/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py index 38dc8bbb28..3dfedd811d 100644 --- a/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py +++ b/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py @@ -3,7 +3,7 @@ from unittest.mock import patch import pytest from graphon.variables.segments import StringSegment -from sqlalchemy import delete +from sqlalchemy import delete, func, select from core.db.session_factory import session_factory from extensions.storage.storage_type import StorageType @@ -108,8 +108,12 @@ class TestDeleteDraftVariablesIntegration: app2_id = data["app2"].id with session_factory.create_session() as session: - app1_vars_before = session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count() - app2_vars_before = session.query(WorkflowDraftVariable).filter_by(app_id=app2_id).count() + app1_vars_before = session.scalar( + select(func.count()).select_from(WorkflowDraftVariable).filter_by(app_id=app1_id) + ) + app2_vars_before = session.scalar( + select(func.count()).select_from(WorkflowDraftVariable).filter_by(app_id=app2_id) + ) assert app1_vars_before == 5 assert app2_vars_before == 5 @@ -117,8 +121,12 @@ class TestDeleteDraftVariablesIntegration: assert deleted_count == 5 with session_factory.create_session() as session: - app1_vars_after = session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count() - app2_vars_after = session.query(WorkflowDraftVariable).filter_by(app_id=app2_id).count() + app1_vars_after = session.scalar( + select(func.count()).select_from(WorkflowDraftVariable).filter_by(app_id=app1_id) + ) + app2_vars_after = session.scalar( + select(func.count()).select_from(WorkflowDraftVariable).filter_by(app_id=app2_id) + ) assert app1_vars_after == 0 assert app2_vars_after == 5 @@ -130,7 +138,9 @@ class TestDeleteDraftVariablesIntegration: assert deleted_count == 5 with session_factory.create_session() as session: - remaining_vars = session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count() + remaining_vars = session.scalar( + select(func.count()).select_from(WorkflowDraftVariable).filter_by(app_id=app1_id) + ) assert remaining_vars == 0 def test_delete_draft_variables_batch_nonexistent_app(self, setup_test_data): @@ -143,14 +153,18 @@ class TestDeleteDraftVariablesIntegration: app1_id = data["app1"].id with session_factory.create_session() as session: - vars_before = session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count() + vars_before = session.scalar( + select(func.count()).select_from(WorkflowDraftVariable).filter_by(app_id=app1_id) + ) assert vars_before == 5 deleted_count = _delete_draft_variables(app1_id) assert deleted_count == 5 with session_factory.create_session() as session: - vars_after = session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count() + vars_after = session.scalar( + select(func.count()).select_from(WorkflowDraftVariable).filter_by(app_id=app1_id) + ) assert vars_after == 0 def test_batch_deletion_handles_large_dataset(self, app_and_tenant): @@ -175,7 +189,9 @@ class TestDeleteDraftVariablesIntegration: deleted_count = delete_draft_variables_batch(app.id, batch_size=8) assert deleted_count == 25 with session_factory.create_session() as session: - remaining = session.query(WorkflowDraftVariable).filter_by(app_id=app.id).count() + remaining = session.scalar( + select(func.count()).select_from(WorkflowDraftVariable).filter_by(app_id=app.id) + ) assert remaining == 0 finally: with session_factory.create_session() as session: @@ -307,13 +323,17 @@ class TestDeleteDraftVariablesWithOffloadIntegration: mock_storage.delete.return_value = None with session_factory.create_session() as session: - draft_vars_before = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count() - var_files_before = ( - session.query(WorkflowDraftVariableFile) - .where(WorkflowDraftVariableFile.id.in_(variable_file_ids)) - .count() + draft_vars_before = session.scalar( + select(func.count()).select_from(WorkflowDraftVariable).filter_by(app_id=app_id) + ) + var_files_before = session.scalar( + select(func.count()) + .select_from(WorkflowDraftVariableFile) + .where(WorkflowDraftVariableFile.id.in_(variable_file_ids)) + ) + upload_files_before = session.scalar( + select(func.count()).select_from(UploadFile).where(UploadFile.id.in_(upload_file_ids)) ) - upload_files_before = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count() assert draft_vars_before == 3 assert var_files_before == 2 assert upload_files_before == 2 @@ -322,16 +342,20 @@ class TestDeleteDraftVariablesWithOffloadIntegration: assert deleted_count == 3 with session_factory.create_session() as session: - draft_vars_after = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count() + draft_vars_after = session.scalar( + select(func.count()).select_from(WorkflowDraftVariable).filter_by(app_id=app_id) + ) assert draft_vars_after == 0 with session_factory.create_session() as session: - var_files_after = ( - session.query(WorkflowDraftVariableFile) + var_files_after = session.scalar( + select(func.count()) + .select_from(WorkflowDraftVariableFile) .where(WorkflowDraftVariableFile.id.in_(variable_file_ids)) - .count() ) - upload_files_after = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count() + upload_files_after = session.scalar( + select(func.count()).select_from(UploadFile).where(UploadFile.id.in_(upload_file_ids)) + ) assert var_files_after == 0 assert upload_files_after == 0 @@ -352,16 +376,20 @@ class TestDeleteDraftVariablesWithOffloadIntegration: assert deleted_count == 3 with session_factory.create_session() as session: - draft_vars_after = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count() + draft_vars_after = session.scalar( + select(func.count()).select_from(WorkflowDraftVariable).filter_by(app_id=app_id) + ) assert draft_vars_after == 0 with session_factory.create_session() as session: - var_files_after = ( - session.query(WorkflowDraftVariableFile) + var_files_after = session.scalar( + select(func.count()) + .select_from(WorkflowDraftVariableFile) .where(WorkflowDraftVariableFile.id.in_(variable_file_ids)) - .count() ) - upload_files_after = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count() + upload_files_after = session.scalar( + select(func.count()).select_from(UploadFile).where(UploadFile.id.in_(upload_file_ids)) + ) assert var_files_after == 0 assert upload_files_after == 0 @@ -579,7 +607,9 @@ class TestDeleteDraftVariablesSessionCommit: # Verify all data was deleted (proves transaction was committed) with session_factory.create_session() as session: - remaining_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count() + remaining_count = session.scalar( + select(func.count()).select_from(WorkflowDraftVariable).filter_by(app_id=app_id) + ) assert deleted_count == 10 assert remaining_count == 0 @@ -592,7 +622,9 @@ class TestDeleteDraftVariablesSessionCommit: # Verify initial state with session_factory.create_session() as session: - initial_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count() + initial_count = session.scalar( + select(func.count()).select_from(WorkflowDraftVariable).filter_by(app_id=app_id) + ) assert initial_count == 10 # Perform deletion with small batch size to force multiple commits @@ -602,13 +634,17 @@ class TestDeleteDraftVariablesSessionCommit: # Verify all data is deleted in a new session (proves commits worked) with session_factory.create_session() as session: - final_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count() + final_count = session.scalar( + select(func.count()).select_from(WorkflowDraftVariable).filter_by(app_id=app_id) + ) assert final_count == 0 # Verify specific IDs are deleted with session_factory.create_session() as session: - remaining_vars = ( - session.query(WorkflowDraftVariable).where(WorkflowDraftVariable.id.in_(variable_ids)).count() + remaining_vars = session.scalar( + select(func.count()) + .select_from(WorkflowDraftVariable) + .where(WorkflowDraftVariable.id.in_(variable_ids)) ) assert remaining_vars == 0 @@ -626,7 +662,9 @@ class TestDeleteDraftVariablesSessionCommit: app_id = data["app"].id with session_factory.create_session() as session: - initial_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count() + initial_count = session.scalar( + select(func.count()).select_from(WorkflowDraftVariable).filter_by(app_id=app_id) + ) assert initial_count == 10 # Delete all in a single batch @@ -635,7 +673,9 @@ class TestDeleteDraftVariablesSessionCommit: # Verify data is persisted with session_factory.create_session() as session: - final_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count() + final_count = session.scalar( + select(func.count()).select_from(WorkflowDraftVariable).filter_by(app_id=app_id) + ) assert final_count == 0 def test_invalid_batch_size_raises_error(self, setup_commit_test_data): @@ -659,13 +699,17 @@ class TestDeleteDraftVariablesSessionCommit: # Verify initial state with session_factory.create_session() as session: - draft_vars_before = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count() - var_files_before = ( - session.query(WorkflowDraftVariableFile) - .where(WorkflowDraftVariableFile.id.in_([vf.id for vf in data["variable_files"]])) - .count() + draft_vars_before = session.scalar( + select(func.count()).select_from(WorkflowDraftVariable).filter_by(app_id=app_id) + ) + var_files_before = session.scalar( + select(func.count()) + .select_from(WorkflowDraftVariableFile) + .where(WorkflowDraftVariableFile.id.in_([vf.id for vf in data["variable_files"]])) + ) + upload_files_before = session.scalar( + select(func.count()).select_from(UploadFile).where(UploadFile.id.in_(upload_file_ids)) ) - upload_files_before = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count() assert draft_vars_before == 3 assert var_files_before == 2 assert upload_files_before == 2 @@ -676,13 +720,17 @@ class TestDeleteDraftVariablesSessionCommit: # Verify all data is persisted (deleted) in new session with session_factory.create_session() as session: - draft_vars_after = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count() - var_files_after = ( - session.query(WorkflowDraftVariableFile) - .where(WorkflowDraftVariableFile.id.in_([vf.id for vf in data["variable_files"]])) - .count() + draft_vars_after = session.scalar( + select(func.count()).select_from(WorkflowDraftVariable).filter_by(app_id=app_id) + ) + var_files_after = session.scalar( + select(func.count()) + .select_from(WorkflowDraftVariableFile) + .where(WorkflowDraftVariableFile.id.in_([vf.id for vf in data["variable_files"]])) + ) + upload_files_after = session.scalar( + select(func.count()).select_from(UploadFile).where(UploadFile.id.in_(upload_file_ids)) ) - upload_files_after = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count() assert draft_vars_after == 0 assert var_files_after == 0 assert upload_files_after == 0 diff --git a/api/tests/test_containers_integration_tests/controllers/console/auth/test_email_register.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_email_register.py index 879c337319..320da85b60 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/auth/test_email_register.py +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_email_register.py @@ -158,7 +158,10 @@ def test_get_account_by_email_with_case_fallback_falls_back_to_lowercase(): second_result.scalar_one_or_none.return_value = expected_account mock_session.execute.side_effect = [first_result, second_result] - result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com", session=mock_session) + with patch("services.account_service.session_factory") as mock_factory: + mock_factory.create_session.return_value.__enter__ = MagicMock(return_value=mock_session) + mock_factory.create_session.return_value.__exit__ = MagicMock(return_value=False) + result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com") assert result is expected_account assert mock_session.execute.call_count == 2 diff --git a/api/tests/test_containers_integration_tests/controllers/console/auth/test_forgot_password.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_forgot_password.py index 7b7393dade..d2703ed5cc 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/auth/test_forgot_password.py +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_forgot_password.py @@ -113,12 +113,14 @@ class TestForgotPasswordCheckApi: class TestForgotPasswordResetApi: @patch("controllers.console.auth.forgot_password.ForgotPasswordResetApi._update_existing_account") @patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback") + @patch("controllers.console.auth.forgot_password.db") @patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") def test_reset_fetches_account_with_original_email( self, mock_get_reset_data, mock_revoke_token, + mock_db, mock_get_account, mock_update_account, app, @@ -126,6 +128,7 @@ class TestForgotPasswordResetApi: mock_get_reset_data.return_value = {"phase": "reset", "email": "User@Example.com"} mock_account = MagicMock() mock_get_account.return_value = mock_account + mock_db.session.merge.return_value = mock_account wraps_features = SimpleNamespace(enable_email_password_login=True) with ( @@ -161,7 +164,10 @@ def test_get_account_by_email_with_case_fallback_falls_back_to_lowercase(): second_result.scalar_one_or_none.return_value = expected_account mock_session.execute.side_effect = [first_result, second_result] - result = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com", session=mock_session) + with patch("services.account_service.session_factory") as mock_factory: + mock_factory.create_session.return_value.__enter__ = MagicMock(return_value=mock_session) + mock_factory.create_session.return_value.__exit__ = MagicMock(return_value=False) + result = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com") assert result is expected_account assert mock_session.execute.call_count == 2 diff --git a/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py index a2f1328579..1eabb45422 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py @@ -437,7 +437,10 @@ class TestAccountGeneration: second_result.scalar_one_or_none.return_value = expected_account mock_session.execute.side_effect = [first_result, second_result] - result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com", session=mock_session) + with patch("services.account_service.session_factory") as mock_factory: + mock_factory.create_session.return_value.__enter__ = MagicMock(return_value=mock_session) + mock_factory.create_session.return_value.__exit__ = MagicMock(return_value=False) + result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com") assert result is expected_account assert mock_session.execute.call_count == 2 diff --git a/api/tests/test_containers_integration_tests/controllers/console/auth/test_password_reset.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_password_reset.py index 8f9db287e3..50249bcd74 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/auth/test_password_reset.py +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_password_reset.py @@ -335,10 +335,12 @@ class TestForgotPasswordResetApi: @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") @patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token") @patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback") + @patch("controllers.console.auth.forgot_password.db") @patch("controllers.console.auth.forgot_password.TenantService.get_join_tenants") def test_reset_password_success( self, mock_get_tenants, + mock_db, mock_get_account, mock_revoke_token, mock_get_data, @@ -356,6 +358,7 @@ class TestForgotPasswordResetApi: # Arrange mock_get_data.return_value = {"email": "test@example.com", "phase": "reset"} mock_get_account.return_value = mock_account + mock_db.session.merge.return_value = mock_account mock_get_tenants.return_value = [MagicMock()] # Act diff --git a/api/tests/test_containers_integration_tests/controllers/web/test_web_forgot_password.py b/api/tests/test_containers_integration_tests/controllers/web/test_web_forgot_password.py index 04ad143103..f14b2c0ae5 100644 --- a/api/tests/test_containers_integration_tests/controllers/web/test_web_forgot_password.py +++ b/api/tests/test_containers_integration_tests/controllers/web/test_web_forgot_password.py @@ -37,10 +37,8 @@ class TestForgotPasswordSendEmailApi: @patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback") @patch("controllers.web.forgot_password.AccountService.is_email_send_ip_limit", return_value=False) @patch("controllers.web.forgot_password.extract_remote_ip", return_value="127.0.0.1") - @patch("controllers.web.forgot_password.sessionmaker") def test_should_normalize_email_before_sending( self, - mock_session_cls, mock_extract_ip, mock_rate_limit, mock_get_account, @@ -50,19 +48,16 @@ class TestForgotPasswordSendEmailApi: mock_account = MagicMock() mock_get_account.return_value = mock_account mock_send_mail.return_value = "token-123" - mock_session = MagicMock() - mock_session_cls.return_value.begin.return_value.__enter__.return_value = mock_session - with patch("controllers.web.forgot_password.db", SimpleNamespace(engine="engine")): - with app.test_request_context( - "/web/forgot-password", - method="POST", - json={"email": "User@Example.com", "language": "zh-Hans"}, - ): - response = ForgotPasswordSendEmailApi().post() + with app.test_request_context( + "/web/forgot-password", + method="POST", + json={"email": "User@Example.com", "language": "zh-Hans"}, + ): + response = ForgotPasswordSendEmailApi().post() assert response == {"result": "success", "data": "token-123"} - mock_get_account.assert_called_once_with("User@Example.com", session=mock_session) + mock_get_account.assert_called_once_with("User@Example.com") mock_send_mail.assert_called_once_with(account=mock_account, email="user@example.com", language="zh-Hans") mock_extract_ip.assert_called_once() mock_rate_limit.assert_called_once_with("127.0.0.1") @@ -153,14 +148,14 @@ class TestForgotPasswordResetApi: @patch("controllers.web.forgot_password.ForgotPasswordResetApi._update_existing_account") @patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback") - @patch("controllers.web.forgot_password.sessionmaker") + @patch("controllers.web.forgot_password.db") @patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token") @patch("controllers.web.forgot_password.AccountService.get_reset_password_data") def test_should_fetch_account_with_fallback( self, mock_get_reset_data, mock_revoke_token, - mock_session_cls, + mock_db, mock_get_account, mock_update_account, app, @@ -168,29 +163,27 @@ class TestForgotPasswordResetApi: mock_get_reset_data.return_value = {"phase": "reset", "email": "User@Example.com", "code": "1234"} mock_account = MagicMock() mock_get_account.return_value = mock_account - mock_session = MagicMock() - mock_session_cls.return_value.begin.return_value.__enter__.return_value = mock_session + mock_db.session.merge.return_value = mock_account - with patch("controllers.web.forgot_password.db", SimpleNamespace(engine="engine")): - with app.test_request_context( - "/web/forgot-password/resets", - method="POST", - json={ - "token": "token-123", - "new_password": "ValidPass123!", - "password_confirm": "ValidPass123!", - }, - ): - response = ForgotPasswordResetApi().post() + with app.test_request_context( + "/web/forgot-password/resets", + method="POST", + json={ + "token": "token-123", + "new_password": "ValidPass123!", + "password_confirm": "ValidPass123!", + }, + ): + response = ForgotPasswordResetApi().post() assert response == {"result": "success"} - mock_get_account.assert_called_once_with("User@Example.com", session=mock_session) + mock_get_account.assert_called_once_with("User@Example.com") mock_update_account.assert_called_once() mock_revoke_token.assert_called_once_with("token-123") @patch("controllers.web.forgot_password.hash_password", return_value=b"hashed-value") @patch("controllers.web.forgot_password.secrets.token_bytes", return_value=b"0123456789abcdef") - @patch("controllers.web.forgot_password.sessionmaker") + @patch("controllers.web.forgot_password.db") @patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token") @patch("controllers.web.forgot_password.AccountService.get_reset_password_data") @patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback") @@ -199,7 +192,7 @@ class TestForgotPasswordResetApi: mock_get_account, mock_get_reset_data, mock_revoke_token, - mock_session_cls, + mock_db, mock_token_bytes, mock_hash_password, app, @@ -207,20 +200,18 @@ class TestForgotPasswordResetApi: mock_get_reset_data.return_value = {"phase": "reset", "email": "user@example.com"} account = MagicMock() mock_get_account.return_value = account - mock_session = MagicMock() - mock_session_cls.return_value.begin.return_value.__enter__.return_value = mock_session + mock_db.session.merge.return_value = account - with patch("controllers.web.forgot_password.db", SimpleNamespace(engine="engine")): - with app.test_request_context( - "/web/forgot-password/resets", - method="POST", - json={ - "token": "reset-token", - "new_password": "StrongPass123!", - "password_confirm": "StrongPass123!", - }, - ): - response = ForgotPasswordResetApi().post() + with app.test_request_context( + "/web/forgot-password/resets", + method="POST", + json={ + "token": "reset-token", + "new_password": "StrongPass123!", + "password_confirm": "StrongPass123!", + }, + ): + response = ForgotPasswordResetApi().post() assert response == {"result": "success"} mock_get_reset_data.assert_called_once_with("reset-token") diff --git a/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py b/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py index 2b4c1b59ab..c9ee67863d 100644 --- a/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py +++ b/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py @@ -557,11 +557,9 @@ class TestPauseStatePersistenceLayerTestContainers: self.session.refresh(self.test_workflow_run) assert self.test_workflow_run.status == WorkflowExecutionStatus.RUNNING - pause_states = ( - self.session.query(WorkflowPauseModel) - .filter(WorkflowPauseModel.workflow_run_id == self.test_workflow_run_id) - .all() - ) + pause_states = self.session.scalars( + select(WorkflowPauseModel).where(WorkflowPauseModel.workflow_run_id == self.test_workflow_run_id) + ).all() assert len(pause_states) == 0 def test_layer_requires_initialization(self, db_session_with_containers): diff --git a/api/tests/test_containers_integration_tests/models/test_account.py b/api/tests/test_containers_integration_tests/models/test_account.py index 078dc0e8de..6fd6716cbb 100644 --- a/api/tests/test_containers_integration_tests/models/test_account.py +++ b/api/tests/test_containers_integration_tests/models/test_account.py @@ -1,79 +1,202 @@ -# import secrets +""" +Integration tests for Account and Tenant model methods that interact with the database. -# import pytest -# from sqlalchemy import select -# from sqlalchemy.orm import Session -# from sqlalchemy.orm.exc import DetachedInstanceError +Migrated from unit_tests/models/test_account_models.py, replacing +@patch("models.account.db") mock patches with real PostgreSQL operations. -# from libs.datetime_utils import naive_utc_now -# from models.account import Account, Tenant, TenantAccountJoin +Covers: +- Account.current_tenant setter (sets _current_tenant and role from TenantAccountJoin) +- Account.set_tenant_id (resolves tenant + role from real join row) +- Account.get_by_openid (AccountIntegrate lookup then Account fetch) +- Tenant.get_accounts (returns accounts linked via TenantAccountJoin) +""" + +from collections.abc import Generator +from uuid import uuid4 + +import pytest +from sqlalchemy import delete +from sqlalchemy.orm import Session + +from models.account import Account, AccountIntegrate, Tenant, TenantAccountJoin, TenantAccountRole -# @pytest.fixture -# def session(db_session_with_containers): -# with Session(db_session_with_containers.get_bind()) as session: -# yield session +def _cleanup_tracked_rows(db_session: Session, tracked: list) -> None: + """Delete rows tracked during the test so committed state does not leak into the DB. + + Rolls back any pending (uncommitted) session state first, then issues DELETE + statements by primary key for each tracked entity (in reverse creation order) + and commits. This cleans up rows created via either flush() or commit(). + """ + db_session.rollback() + for entity in reversed(tracked): + db_session.execute(delete(type(entity)).where(type(entity).id == entity.id)) + db_session.commit() -# @pytest.fixture -# def account(session): -# account = Account( -# name="test account", -# email=f"test_{secrets.token_hex(8)}@example.com", -# ) -# session.add(account) -# session.commit() -# return account +def _build_tenant() -> Tenant: + return Tenant(name=f"Tenant {uuid4()}") -# @pytest.fixture -# def tenant(session): -# tenant = Tenant(name="test tenant") -# session.add(tenant) -# session.commit() -# return tenant +def _build_account(email_prefix: str = "account") -> Account: + return Account( + name=f"Account {uuid4()}", + email=f"{email_prefix}_{uuid4()}@example.com", + password="hashed-password", + password_salt="salt", + interface_language="en-US", + timezone="UTC", + ) -# @pytest.fixture -# def tenant_account_join(session, account, tenant): -# tenant_join = TenantAccountJoin(account_id=account.id, tenant_id=tenant.id) -# session.add(tenant_join) -# session.commit() -# yield tenant_join -# session.delete(tenant_join) -# session.commit() +class _DBTrackingTestBase: + """Base class providing a tracker list and shared row factories for account/tenant tests.""" + + _tracked: list + + @pytest.fixture(autouse=True) + def _setup_cleanup(self, db_session_with_containers: Session) -> Generator[None, None, None]: + self._tracked = [] + yield + _cleanup_tracked_rows(db_session_with_containers, self._tracked) + + def _create_tenant(self, db_session: Session) -> Tenant: + tenant = _build_tenant() + db_session.add(tenant) + db_session.flush() + self._tracked.append(tenant) + return tenant + + def _create_account(self, db_session: Session, email_prefix: str = "account") -> Account: + account = _build_account(email_prefix) + db_session.add(account) + db_session.flush() + self._tracked.append(account) + return account + + def _create_join( + self, db_session: Session, tenant_id: str, account_id: str, role: TenantAccountRole, current: bool = True + ) -> TenantAccountJoin: + join = TenantAccountJoin(tenant_id=tenant_id, account_id=account_id, role=role, current=current) + db_session.add(join) + db_session.flush() + self._tracked.append(join) + return join -# class TestAccountTenant: -# def test_set_current_tenant_should_reload_tenant( -# self, -# db_session_with_containers, -# account, -# tenant, -# tenant_account_join, -# ): -# with Session(db_session_with_containers.get_bind(), expire_on_commit=True) as session: -# scoped_tenant = session.scalars(select(Tenant).where(Tenant.id == tenant.id)).one() -# account.current_tenant = scoped_tenant -# scoped_tenant.created_at = naive_utc_now() -# # session.commit() +class TestAccountCurrentTenantSetter(_DBTrackingTestBase): + """Integration tests for Account.current_tenant property setter.""" -# # Ensure the tenant used in assignment is detached. -# with pytest.raises(DetachedInstanceError): -# _ = scoped_tenant.name + def test_current_tenant_property_returns_cached_tenant(self, db_session_with_containers: Session) -> None: + """current_tenant getter returns the in-memory _current_tenant without DB access.""" + account = self._create_account(db_session_with_containers) + tenant = self._create_tenant(db_session_with_containers) + account._current_tenant = tenant -# assert account._current_tenant.id == tenant.id -# assert account._current_tenant.id == tenant.id + assert account.current_tenant is tenant -# def test_set_tenant_id_should_load_tenant_as_not_expire( -# self, -# flask_app_with_containers, -# account, -# tenant, -# tenant_account_join, -# ): -# with flask_app_with_containers.test_request_context(): -# account.set_tenant_id(tenant.id) + def test_current_tenant_setter_sets_tenant_and_role_when_join_exists( + self, db_session_with_containers: Session + ) -> None: + """Setting current_tenant loads the join row and assigns role when relationship exists.""" + tenant = self._create_tenant(db_session_with_containers) + account = self._create_account(db_session_with_containers) + self._create_join(db_session_with_containers, tenant.id, account.id, TenantAccountRole.OWNER) + db_session_with_containers.commit() -# assert account._current_tenant.id == tenant.id -# assert account._current_tenant.id == tenant.id + account.current_tenant = tenant + + assert account._current_tenant is not None + assert account._current_tenant.id == tenant.id + assert account.role == TenantAccountRole.OWNER + + def test_current_tenant_setter_sets_none_when_no_join_exists(self, db_session_with_containers: Session) -> None: + """Setting current_tenant results in _current_tenant=None when no join row exists.""" + tenant = self._create_tenant(db_session_with_containers) + account = self._create_account(db_session_with_containers) + db_session_with_containers.commit() + + account.current_tenant = tenant + + assert account._current_tenant is None + + +class TestAccountSetTenantId(_DBTrackingTestBase): + """Integration tests for Account.set_tenant_id method.""" + + def test_set_tenant_id_sets_tenant_and_role_when_relationship_exists( + self, db_session_with_containers: Session + ) -> None: + """set_tenant_id loads the tenant and assigns role when a join row exists.""" + tenant = self._create_tenant(db_session_with_containers) + account = self._create_account(db_session_with_containers) + self._create_join(db_session_with_containers, tenant.id, account.id, TenantAccountRole.ADMIN) + db_session_with_containers.commit() + + account.set_tenant_id(tenant.id) + + assert account._current_tenant is not None + assert account._current_tenant.id == tenant.id + assert account.role == TenantAccountRole.ADMIN + + def test_set_tenant_id_does_not_set_tenant_when_no_relationship_exists( + self, db_session_with_containers: Session + ) -> None: + """set_tenant_id does nothing when no join row matches the tenant.""" + tenant = self._create_tenant(db_session_with_containers) + account = self._create_account(db_session_with_containers) + db_session_with_containers.commit() + + account.set_tenant_id(tenant.id) + + assert account._current_tenant is None + + +class TestAccountGetByOpenId(_DBTrackingTestBase): + """Integration tests for Account.get_by_openid class method.""" + + def test_get_by_openid_returns_account_when_integrate_exists(self, db_session_with_containers: Session) -> None: + """get_by_openid returns the Account when a matching AccountIntegrate row exists.""" + account = self._create_account(db_session_with_containers, email_prefix="openid") + provider = "google" + open_id = f"google_{uuid4()}" + + integrate = AccountIntegrate( + account_id=account.id, + provider=provider, + open_id=open_id, + encrypted_token="token", + ) + db_session_with_containers.add(integrate) + db_session_with_containers.flush() + self._tracked.append(integrate) + + result = Account.get_by_openid(provider, open_id) + + assert result is not None + assert result.id == account.id + + def test_get_by_openid_returns_none_when_no_integrate_exists(self, db_session_with_containers: Session) -> None: + """get_by_openid returns None when no AccountIntegrate row matches.""" + result = Account.get_by_openid("github", f"github_{uuid4()}") + + assert result is None + + +class TestTenantGetAccounts(_DBTrackingTestBase): + """Integration tests for Tenant.get_accounts method.""" + + def test_get_accounts_returns_linked_accounts(self, db_session_with_containers: Session) -> None: + """get_accounts returns all accounts linked to the tenant via TenantAccountJoin.""" + tenant = self._create_tenant(db_session_with_containers) + account1 = self._create_account(db_session_with_containers, email_prefix="tenant_member") + account2 = self._create_account(db_session_with_containers, email_prefix="tenant_member") + self._create_join(db_session_with_containers, tenant.id, account1.id, TenantAccountRole.OWNER, current=False) + self._create_join(db_session_with_containers, tenant.id, account2.id, TenantAccountRole.NORMAL, current=False) + + accounts = tenant.get_accounts() + + assert len(accounts) == 2 + account_ids = {a.id for a in accounts} + assert account1.id in account_ids + assert account2.id in account_ids diff --git a/api/tests/test_containers_integration_tests/models/test_conversation_message_inputs.py b/api/tests/test_containers_integration_tests/models/test_conversation_message_inputs.py new file mode 100644 index 0000000000..e922c19a5a --- /dev/null +++ b/api/tests/test_containers_integration_tests/models/test_conversation_message_inputs.py @@ -0,0 +1,149 @@ +""" +Integration tests for Conversation.inputs and Message.inputs tenant resolution. + +Migrated from unit_tests/models/test_model.py, replacing db.session.scalar monkeypatching +with a real App in PostgreSQL so the _resolve_app_tenant_id lookup executes against the DB. +""" + +from collections.abc import Generator +from unittest.mock import patch +from uuid import uuid4 + +import pytest +from graphon.file import FILE_MODEL_IDENTITY, FileTransferMethod +from sqlalchemy.orm import Session + +from core.workflow.file_reference import build_file_reference +from models.model import App, AppMode, Conversation, Message + + +def _build_local_file_mapping(record_id: str, *, tenant_id: str | None = None) -> dict: + mapping: dict = { + "dify_model_identity": FILE_MODEL_IDENTITY, + "transfer_method": FileTransferMethod.LOCAL_FILE, + "reference": build_file_reference(record_id=record_id), + "type": "document", + "filename": "example.txt", + "extension": ".txt", + "mime_type": "text/plain", + "size": 1, + } + if tenant_id is not None: + mapping["tenant_id"] = tenant_id + return mapping + + +class TestConversationMessageInputsTenantResolution: + """Integration tests for Conversation/Message.inputs tenant resolution via real DB lookup.""" + + @pytest.fixture(autouse=True) + def _auto_rollback(self, db_session_with_containers: Session) -> Generator[None, None, None]: + """Automatically rollback session changes after each test.""" + yield + db_session_with_containers.rollback() + + def _create_app(self, db_session: Session) -> App: + tenant_id = str(uuid4()) + app = App( + tenant_id=tenant_id, + name=f"App {uuid4()}", + mode=AppMode.CHAT, + enable_site=False, + enable_api=True, + is_demo=False, + is_public=False, + is_universal=False, + created_by=str(uuid4()), + updated_by=str(uuid4()), + ) + db_session.add(app) + db_session.flush() + return app + + @pytest.mark.parametrize("owner_cls", [Conversation, Message]) + def test_inputs_resolves_tenant_via_db_for_local_file( + self, + db_session_with_containers: Session, + owner_cls: type, + ) -> None: + """Inputs resolves tenant_id from real App row when file mapping has no tenant_id.""" + app = self._create_app(db_session_with_containers) + build_calls: list[tuple[dict, str]] = [] + + def fake_build_from_mapping( + *, mapping, tenant_id, config=None, strict_type_validation=False, access_controller + ): + build_calls.append((dict(mapping), tenant_id)) + return {"tenant_id": tenant_id, "upload_file_id": mapping.get("upload_file_id")} + + with patch("factories.file_factory.build_from_mapping", fake_build_from_mapping): + owner = owner_cls(app_id=app.id) + owner.inputs = {"file": _build_local_file_mapping("upload-1")} + + restored_inputs = owner.inputs + + # The tenant_id should come from the real App row in the DB + assert restored_inputs["file"] == {"tenant_id": app.tenant_id, "upload_file_id": "upload-1"} + assert len(build_calls) == 1 + assert build_calls[0][1] == app.tenant_id + + @pytest.mark.parametrize("owner_cls", [Conversation, Message]) + def test_inputs_uses_serialized_tenant_id_skipping_db_lookup( + self, + db_session_with_containers: Session, + owner_cls: type, + ) -> None: + """Inputs uses tenant_id from the file mapping payload without hitting the DB.""" + app = self._create_app(db_session_with_containers) + payload_tenant_id = "tenant-from-payload" + build_calls: list[tuple[dict, str]] = [] + + def fake_build_from_mapping( + *, mapping, tenant_id, config=None, strict_type_validation=False, access_controller + ): + build_calls.append((dict(mapping), tenant_id)) + return {"tenant_id": tenant_id, "upload_file_id": mapping.get("upload_file_id")} + + with patch("factories.file_factory.build_from_mapping", fake_build_from_mapping): + owner = owner_cls(app_id=app.id) + owner.inputs = {"file": _build_local_file_mapping("upload-1", tenant_id=payload_tenant_id)} + + restored_inputs = owner.inputs + + assert restored_inputs["file"] == {"tenant_id": payload_tenant_id, "upload_file_id": "upload-1"} + assert len(build_calls) == 1 + assert build_calls[0][1] == payload_tenant_id + + @pytest.mark.parametrize("owner_cls", [Conversation, Message]) + def test_inputs_resolves_tenant_for_file_list( + self, + db_session_with_containers: Session, + owner_cls: type, + ) -> None: + """Inputs resolves tenant_id for a list of file mappings.""" + app = self._create_app(db_session_with_containers) + build_calls: list[tuple[dict, str]] = [] + + def fake_build_from_mapping( + *, mapping, tenant_id, config=None, strict_type_validation=False, access_controller + ): + build_calls.append((dict(mapping), tenant_id)) + return {"tenant_id": tenant_id, "upload_file_id": mapping.get("upload_file_id")} + + with patch("factories.file_factory.build_from_mapping", fake_build_from_mapping): + owner = owner_cls(app_id=app.id) + owner.inputs = { + "files": [ + _build_local_file_mapping("upload-1"), + _build_local_file_mapping("upload-2"), + ] + } + + restored_inputs = owner.inputs + + assert len(build_calls) == 2 + assert all(call[1] == app.tenant_id for call in build_calls) + assert restored_inputs["files"] == [ + {"tenant_id": app.tenant_id, "upload_file_id": "upload-1"}, + {"tenant_id": app.tenant_id, "upload_file_id": "upload-2"}, + ] diff --git a/api/tests/test_containers_integration_tests/models/test_conversation_status_count.py b/api/tests/test_containers_integration_tests/models/test_conversation_status_count.py new file mode 100644 index 0000000000..4ca87de52d --- /dev/null +++ b/api/tests/test_containers_integration_tests/models/test_conversation_status_count.py @@ -0,0 +1,314 @@ +""" +Integration tests for Conversation.status_count and Site.generate_code model properties. + +Migrated from unit_tests/models/test_app_models.py TestConversationStatusCount and +test_site_generate_code, replacing db.session.scalars mocks with real PostgreSQL queries. +""" + +from collections.abc import Generator +from uuid import uuid4 + +import pytest +from graphon.enums import WorkflowExecutionStatus +from sqlalchemy.orm import Session + +from models.enums import ConversationFromSource, InvokeFrom +from models.model import App, AppMode, Conversation, Message, Site +from models.workflow import Workflow, WorkflowRun, WorkflowRunTriggeredFrom, WorkflowType + + +class TestConversationStatusCount: + """Integration tests for Conversation.status_count property.""" + + @pytest.fixture(autouse=True) + def _auto_rollback(self, db_session_with_containers: Session) -> Generator[None, None, None]: + """Automatically rollback session changes after each test.""" + yield + db_session_with_containers.rollback() + + def _create_app(self, db_session: Session, tenant_id: str, created_by: str) -> App: + app = App( + tenant_id=tenant_id, + name=f"App {uuid4()}", + mode=AppMode.ADVANCED_CHAT, + enable_site=False, + enable_api=True, + is_demo=False, + is_public=False, + is_universal=False, + created_by=created_by, + updated_by=created_by, + ) + db_session.add(app) + db_session.flush() + return app + + def _create_conversation(self, db_session: Session, app: App) -> Conversation: + conversation = Conversation( + app_id=app.id, + mode=app.mode, + name=f"Conversation {uuid4()}", + summary="", + inputs={}, + introduction="", + system_instruction="", + system_instruction_tokens=0, + status="normal", + invoke_from=InvokeFrom.WEB_APP, + from_source=ConversationFromSource.API, + dialogue_count=0, + is_deleted=False, + ) + conversation.inputs = {} + db_session.add(conversation) + db_session.flush() + return conversation + + def _create_workflow(self, db_session: Session, app: App, created_by: str) -> Workflow: + workflow = Workflow( + tenant_id=app.tenant_id, + app_id=app.id, + type=WorkflowType.CHAT, + version="draft", + graph="{}", + created_by=created_by, + ) + workflow._features = "{}" + db_session.add(workflow) + db_session.flush() + return workflow + + def _create_workflow_run( + self, db_session: Session, app: App, workflow: Workflow, status: WorkflowExecutionStatus, created_by: str + ) -> WorkflowRun: + run = WorkflowRun( + tenant_id=app.tenant_id, + app_id=app.id, + workflow_id=workflow.id, + type=WorkflowType.CHAT, + triggered_from=WorkflowRunTriggeredFrom.APP_RUN, + version="draft", + status=status, + created_by_role="account", + created_by=created_by, + ) + db_session.add(run) + db_session.flush() + return run + + def _create_message( + self, db_session: Session, app: App, conversation: Conversation, workflow_run_id: str | None = None + ) -> Message: + message = Message( + app_id=app.id, + conversation_id=conversation.id, + _inputs={}, + query="Test query", + message={"role": "user", "content": "Test query"}, + answer="Test answer", + model_provider="openai", + model_id="gpt-4", + message_tokens=10, + message_unit_price=0, + answer_tokens=10, + answer_unit_price=0, + total_price=0, + currency="USD", + from_source=ConversationFromSource.API, + invoke_from=InvokeFrom.WEB_APP, + workflow_run_id=workflow_run_id, + ) + db_session.add(message) + db_session.flush() + return message + + def test_status_count_returns_none_when_no_messages(self, db_session_with_containers: Session) -> None: + """status_count returns None when conversation has no messages with workflow_run_id.""" + tenant_id = str(uuid4()) + created_by = str(uuid4()) + + app = self._create_app(db_session_with_containers, tenant_id, created_by) + conversation = self._create_conversation(db_session_with_containers, app) + + result = conversation.status_count + + assert result is None + + def test_status_count_returns_none_when_messages_have_no_workflow_run_id( + self, db_session_with_containers: Session + ) -> None: + """status_count returns None when messages exist but none have workflow_run_id.""" + tenant_id = str(uuid4()) + created_by = str(uuid4()) + + app = self._create_app(db_session_with_containers, tenant_id, created_by) + conversation = self._create_conversation(db_session_with_containers, app) + self._create_message(db_session_with_containers, app, conversation, workflow_run_id=None) + + result = conversation.status_count + + assert result is None + + def test_status_count_counts_succeeded_workflow_run(self, db_session_with_containers: Session) -> None: + """status_count correctly counts succeeded workflow runs.""" + tenant_id = str(uuid4()) + created_by = str(uuid4()) + + app = self._create_app(db_session_with_containers, tenant_id, created_by) + conversation = self._create_conversation(db_session_with_containers, app) + workflow = self._create_workflow(db_session_with_containers, app, created_by) + run = self._create_workflow_run( + db_session_with_containers, app, workflow, WorkflowExecutionStatus.SUCCEEDED, created_by + ) + self._create_message(db_session_with_containers, app, conversation, workflow_run_id=run.id) + + result = conversation.status_count + + assert result is not None + assert result["success"] == 1 + assert result["failed"] == 0 + assert result["partial_success"] == 0 + assert result["paused"] == 0 + + def test_status_count_counts_failed_workflow_run(self, db_session_with_containers: Session) -> None: + """status_count correctly counts failed workflow runs.""" + tenant_id = str(uuid4()) + created_by = str(uuid4()) + + app = self._create_app(db_session_with_containers, tenant_id, created_by) + conversation = self._create_conversation(db_session_with_containers, app) + workflow = self._create_workflow(db_session_with_containers, app, created_by) + run = self._create_workflow_run( + db_session_with_containers, app, workflow, WorkflowExecutionStatus.FAILED, created_by + ) + self._create_message(db_session_with_containers, app, conversation, workflow_run_id=run.id) + + result = conversation.status_count + + assert result is not None + assert result["success"] == 0 + assert result["failed"] == 1 + assert result["partial_success"] == 0 + assert result["paused"] == 0 + + def test_status_count_counts_paused_workflow_run(self, db_session_with_containers: Session) -> None: + """status_count correctly counts paused workflow runs.""" + tenant_id = str(uuid4()) + created_by = str(uuid4()) + + app = self._create_app(db_session_with_containers, tenant_id, created_by) + conversation = self._create_conversation(db_session_with_containers, app) + workflow = self._create_workflow(db_session_with_containers, app, created_by) + run = self._create_workflow_run( + db_session_with_containers, app, workflow, WorkflowExecutionStatus.PAUSED, created_by + ) + self._create_message(db_session_with_containers, app, conversation, workflow_run_id=run.id) + + result = conversation.status_count + + assert result is not None + assert result["success"] == 0 + assert result["failed"] == 0 + assert result["partial_success"] == 0 + assert result["paused"] == 1 + + def test_status_count_multiple_statuses(self, db_session_with_containers: Session) -> None: + """status_count counts multiple workflow runs with different statuses.""" + tenant_id = str(uuid4()) + created_by = str(uuid4()) + + app = self._create_app(db_session_with_containers, tenant_id, created_by) + conversation = self._create_conversation(db_session_with_containers, app) + workflow = self._create_workflow(db_session_with_containers, app, created_by) + + for status in [ + WorkflowExecutionStatus.SUCCEEDED, + WorkflowExecutionStatus.FAILED, + WorkflowExecutionStatus.PARTIAL_SUCCEEDED, + WorkflowExecutionStatus.PAUSED, + ]: + run = self._create_workflow_run(db_session_with_containers, app, workflow, status, created_by) + self._create_message(db_session_with_containers, app, conversation, workflow_run_id=run.id) + + result = conversation.status_count + + assert result is not None + assert result["success"] == 1 + assert result["failed"] == 1 + assert result["partial_success"] == 1 + assert result["paused"] == 1 + + def test_status_count_filters_workflow_runs_by_app_id(self, db_session_with_containers: Session) -> None: + """status_count excludes workflow runs belonging to a different app.""" + tenant_id = str(uuid4()) + created_by = str(uuid4()) + + app = self._create_app(db_session_with_containers, tenant_id, created_by) + other_app = self._create_app(db_session_with_containers, tenant_id, created_by) + conversation = self._create_conversation(db_session_with_containers, app) + workflow = self._create_workflow(db_session_with_containers, other_app, created_by) + + # Workflow run belongs to other_app, not app + other_run = self._create_workflow_run( + db_session_with_containers, other_app, workflow, WorkflowExecutionStatus.SUCCEEDED, created_by + ) + # Message references that run but is in a conversation under app + self._create_message(db_session_with_containers, app, conversation, workflow_run_id=other_run.id) + + result = conversation.status_count + + # The run should be excluded because app_id filter doesn't match + assert result is not None + assert result["success"] == 0 + + +class TestSiteGenerateCode: + """Integration tests for Site.generate_code static method.""" + + @pytest.fixture(autouse=True) + def _auto_rollback(self, db_session_with_containers: Session) -> Generator[None, None, None]: + """Automatically rollback session changes after each test.""" + yield + db_session_with_containers.rollback() + + def test_generate_code_returns_string_of_correct_length(self, db_session_with_containers: Session) -> None: + """Site.generate_code returns a code string of the requested length.""" + code = Site.generate_code(8) + + assert isinstance(code, str) + assert len(code) == 8 + + def test_generate_code_avoids_duplicates(self, db_session_with_containers: Session) -> None: + """Site.generate_code returns a code not already in use.""" + tenant_id = str(uuid4()) + app = App( + tenant_id=tenant_id, + name="Test App", + mode=AppMode.CHAT, + enable_site=True, + enable_api=False, + is_demo=False, + is_public=False, + is_universal=False, + created_by=str(uuid4()), + updated_by=str(uuid4()), + ) + db_session_with_containers.add(app) + db_session_with_containers.flush() + + site = Site( + app_id=app.id, + title="Test Site", + default_language="en-US", + customize_token_strategy="not_allow", + ) + # Set an explicit code so generate_code must avoid it + site.code = "AAAAAAAA" + db_session_with_containers.add(site) + db_session_with_containers.flush() + + code = Site.generate_code(8) + + assert isinstance(code, str) + assert len(code) == 8 + assert code != site.code diff --git a/api/tests/test_containers_integration_tests/models/test_types_enum_text.py b/api/tests/test_containers_integration_tests/models/test_types_enum_text.py index 8aec6b6acc..957b7145d3 100644 --- a/api/tests/test_containers_integration_tests/models/test_types_enum_text.py +++ b/api/tests/test_containers_integration_tests/models/test_types_enum_text.py @@ -6,7 +6,7 @@ import pytest import sqlalchemy as sa from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy import exc as sa_exc -from sqlalchemy import insert +from sqlalchemy import insert, select from sqlalchemy.engine import Connection, Engine from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column from sqlalchemy.sql.sqltypes import VARCHAR @@ -137,12 +137,12 @@ class TestEnumText: session.commit() with Session(engine_with_containers) as session: - user = session.query(_User).where(_User.id == admin_user_id).first() + user = session.scalar(select(_User).where(_User.id == admin_user_id).limit(1)) assert user.user_type == _UserType.admin assert user.user_type_nullable is None with Session(engine_with_containers) as session: - user = session.query(_User).where(_User.id == normal_user_id).first() + user = session.scalar(select(_User).where(_User.id == normal_user_id).limit(1)) assert user.user_type == _UserType.normal assert user.user_type_nullable == _UserType.normal @@ -206,7 +206,7 @@ class TestEnumText: with pytest.raises(ValueError) as exc: with Session(engine_with_containers) as session: - _user = session.query(_User).where(_User.id == 1).first() + _user = session.scalar(select(_User).where(_User.id == 1).limit(1)) assert str(exc.value) == "'invalid' is not a valid _UserType" @@ -222,7 +222,7 @@ class TestEnumText: session.commit() with Session(engine_with_containers) as session: - records = session.query(_LegacyModelTypeRecord).order_by(_LegacyModelTypeRecord.id).all() + records = session.scalars(select(_LegacyModelTypeRecord).order_by(_LegacyModelTypeRecord.id)).all() assert [record.model_type for record in records] == [ ModelType.LLM, diff --git a/api/tests/test_containers_integration_tests/models/test_workflow_node_execution_model.py b/api/tests/test_containers_integration_tests/models/test_workflow_node_execution_model.py new file mode 100644 index 0000000000..14c2263110 --- /dev/null +++ b/api/tests/test_containers_integration_tests/models/test_workflow_node_execution_model.py @@ -0,0 +1,170 @@ +""" +Integration tests for WorkflowNodeExecutionModel.created_by_account and .created_by_end_user. + +Migrated from unit_tests/models/test_workflow_trigger_log.py, replacing +monkeypatch.setattr(db.session, "scalar", ...) with real Account/EndUser rows +persisted in PostgreSQL so the db.session.get() call executes against the DB. +""" + +from collections.abc import Generator +from uuid import uuid4 + +import pytest +from sqlalchemy.orm import Session + +from models.account import Account +from models.enums import CreatorUserRole +from models.model import App, AppMode, EndUser +from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom + + +class TestWorkflowNodeExecutionModelCreatedBy: + """Integration tests for WorkflowNodeExecutionModel creator lookup properties.""" + + @pytest.fixture(autouse=True) + def _auto_rollback(self, db_session_with_containers: Session) -> Generator[None, None, None]: + """Automatically rollback session changes after each test.""" + yield + db_session_with_containers.rollback() + + def _create_account(self, db_session: Session) -> Account: + account = Account( + name="Test Account", + email=f"test_{uuid4()}@example.com", + password="hashed-password", + password_salt="salt", + interface_language="en-US", + timezone="UTC", + ) + db_session.add(account) + db_session.flush() + return account + + def _create_end_user(self, db_session: Session, tenant_id: str, app_id: str) -> EndUser: + end_user = EndUser( + tenant_id=tenant_id, + app_id=app_id, + type="service_api", + external_user_id=f"ext-{uuid4()}", + name="End User", + session_id=f"session-{uuid4()}", + ) + end_user.is_anonymous = False + db_session.add(end_user) + db_session.flush() + return end_user + + def _create_app(self, db_session: Session, tenant_id: str, created_by: str) -> App: + app = App( + tenant_id=tenant_id, + name=f"App {uuid4()}", + mode=AppMode.WORKFLOW, + enable_site=False, + enable_api=True, + is_demo=False, + is_public=False, + is_universal=False, + created_by=created_by, + updated_by=created_by, + ) + db_session.add(app) + db_session.flush() + return app + + def _make_execution( + self, tenant_id: str, app_id: str, created_by_role: str, created_by: str + ) -> WorkflowNodeExecutionModel: + return WorkflowNodeExecutionModel( + tenant_id=tenant_id, + app_id=app_id, + workflow_id=str(uuid4()), + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + workflow_run_id=None, + index=1, + predecessor_node_id=None, + node_execution_id=None, + node_id="n1", + node_type="start", + title="Start", + inputs=None, + process_data=None, + outputs=None, + status="succeeded", + error=None, + elapsed_time=0.0, + execution_metadata=None, + created_by_role=created_by_role, + created_by=created_by, + ) + + def test_created_by_account_returns_account_when_role_is_account(self, db_session_with_containers: Session) -> None: + """created_by_account returns the Account row when role is ACCOUNT.""" + account = self._create_account(db_session_with_containers) + app = self._create_app(db_session_with_containers, str(uuid4()), account.id) + + execution = self._make_execution( + tenant_id=app.tenant_id, + app_id=app.id, + created_by_role=CreatorUserRole.ACCOUNT.value, + created_by=account.id, + ) + + result = execution.created_by_account + + assert result is not None + assert result.id == account.id + + def test_created_by_account_returns_none_when_role_is_end_user(self, db_session_with_containers: Session) -> None: + """created_by_account returns None when role is END_USER, even if an Account exists.""" + account = self._create_account(db_session_with_containers) + app = self._create_app(db_session_with_containers, str(uuid4()), account.id) + + execution = self._make_execution( + tenant_id=app.tenant_id, + app_id=app.id, + created_by_role=CreatorUserRole.END_USER.value, + created_by=account.id, + ) + + result = execution.created_by_account + + assert result is None + + def test_created_by_end_user_returns_end_user_when_role_is_end_user( + self, db_session_with_containers: Session + ) -> None: + """created_by_end_user returns the EndUser row when role is END_USER.""" + account = self._create_account(db_session_with_containers) + tenant_id = str(uuid4()) + app = self._create_app(db_session_with_containers, tenant_id, account.id) + end_user = self._create_end_user(db_session_with_containers, tenant_id, app.id) + + execution = self._make_execution( + tenant_id=tenant_id, + app_id=app.id, + created_by_role=CreatorUserRole.END_USER.value, + created_by=end_user.id, + ) + + result = execution.created_by_end_user + + assert result is not None + assert result.id == end_user.id + + def test_created_by_end_user_returns_none_when_role_is_account(self, db_session_with_containers: Session) -> None: + """created_by_end_user returns None when role is ACCOUNT, even if an EndUser exists.""" + account = self._create_account(db_session_with_containers) + tenant_id = str(uuid4()) + app = self._create_app(db_session_with_containers, tenant_id, account.id) + end_user = self._create_end_user(db_session_with_containers, tenant_id, app.id) + + execution = self._make_execution( + tenant_id=tenant_id, + app_id=app.id, + created_by_role=CreatorUserRole.ACCOUNT.value, + created_by=end_user.id, + ) + + result = execution.created_by_end_user + + assert result is None diff --git a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_workflow_node_execution_repository.py b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_workflow_node_execution_repository.py new file mode 100644 index 0000000000..22e0aa34ff --- /dev/null +++ b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_workflow_node_execution_repository.py @@ -0,0 +1,395 @@ +"""Testcontainers integration tests for SQLAlchemyWorkflowNodeExecutionRepository.""" + +from __future__ import annotations + +import json +from datetime import datetime +from decimal import Decimal +from uuid import uuid4 + +from graphon.entities import WorkflowNodeExecution +from graphon.enums import ( + BuiltinNodeTypes, + WorkflowNodeExecutionMetadataKey, + WorkflowNodeExecutionStatus, +) +from graphon.model_runtime.utils.encoders import jsonable_encoder +from sqlalchemy import Engine +from sqlalchemy.orm import Session, sessionmaker + +from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository +from core.repositories.factory import OrderConfig +from models.account import Account, Tenant +from models.enums import CreatorUserRole +from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom + + +def _create_account_with_tenant(session: Session) -> Account: + tenant = Tenant(name="Test Workspace") + session.add(tenant) + session.flush() + + account = Account(name="test", email=f"test-{uuid4()}@example.com") + session.add(account) + session.flush() + + account._current_tenant = tenant + return account + + +def _make_repo(session: Session, account: Account, app_id: str) -> SQLAlchemyWorkflowNodeExecutionRepository: + engine = session.get_bind() + assert isinstance(engine, Engine) + return SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=sessionmaker(bind=engine, expire_on_commit=False), + user=account, + app_id=app_id, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + +def _create_node_execution_model( + session: Session, + *, + tenant_id: str, + app_id: str, + workflow_id: str, + workflow_run_id: str, + index: int = 1, + status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING, +) -> WorkflowNodeExecutionModel: + model = WorkflowNodeExecutionModel( + id=str(uuid4()), + tenant_id=tenant_id, + app_id=app_id, + workflow_id=workflow_id, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + workflow_run_id=workflow_run_id, + index=index, + predecessor_node_id=None, + node_execution_id=str(uuid4()), + node_id=f"node-{index}", + node_type=BuiltinNodeTypes.START, + title=f"Test Node {index}", + inputs='{"input_key": "input_value"}', + process_data='{"process_key": "process_value"}', + outputs='{"output_key": "output_value"}', + status=status, + error=None, + elapsed_time=1.5, + execution_metadata="{}", + created_at=datetime.now(), + created_by_role=CreatorUserRole.ACCOUNT, + created_by=str(uuid4()), + finished_at=None, + ) + session.add(model) + session.flush() + return model + + +class TestSave: + def test_save_new_record(self, db_session_with_containers: Session) -> None: + account = _create_account_with_tenant(db_session_with_containers) + app_id = str(uuid4()) + repo = _make_repo(db_session_with_containers, account, app_id) + + execution = WorkflowNodeExecution( + id=str(uuid4()), + workflow_id=str(uuid4()), + node_execution_id=str(uuid4()), + workflow_execution_id=str(uuid4()), + index=1, + predecessor_node_id=None, + node_id="node-1", + node_type=BuiltinNodeTypes.START, + title="Test Node", + inputs={"input_key": "input_value"}, + process_data={"process_key": "process_value"}, + outputs={"result": "success"}, + status=WorkflowNodeExecutionStatus.RUNNING, + error=None, + elapsed_time=1.5, + metadata={WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: 100}, + created_at=datetime.now(), + finished_at=None, + ) + + repo.save(execution) + + engine = db_session_with_containers.get_bind() + assert isinstance(engine, Engine) + with sessionmaker(bind=engine, expire_on_commit=False)() as verify_session: + saved = verify_session.get(WorkflowNodeExecutionModel, execution.id) + assert saved is not None + assert saved.tenant_id == account.current_tenant_id + assert saved.app_id == app_id + assert saved.node_id == "node-1" + assert saved.status == WorkflowNodeExecutionStatus.RUNNING + + def test_save_updates_existing_record(self, db_session_with_containers: Session) -> None: + account = _create_account_with_tenant(db_session_with_containers) + repo = _make_repo(db_session_with_containers, account, str(uuid4())) + + execution = WorkflowNodeExecution( + id=str(uuid4()), + workflow_id=str(uuid4()), + node_execution_id=str(uuid4()), + workflow_execution_id=str(uuid4()), + index=1, + predecessor_node_id=None, + node_id="node-1", + node_type=BuiltinNodeTypes.START, + title="Test Node", + inputs=None, + process_data=None, + outputs=None, + status=WorkflowNodeExecutionStatus.RUNNING, + error=None, + elapsed_time=0.0, + metadata=None, + created_at=datetime.now(), + finished_at=None, + ) + + repo.save(execution) + + execution.status = WorkflowNodeExecutionStatus.SUCCEEDED + execution.elapsed_time = 2.5 + repo.save(execution) + + engine = db_session_with_containers.get_bind() + assert isinstance(engine, Engine) + with sessionmaker(bind=engine, expire_on_commit=False)() as verify_session: + saved = verify_session.get(WorkflowNodeExecutionModel, execution.id) + assert saved is not None + assert saved.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert saved.elapsed_time == 2.5 + + +class TestGetByWorkflowExecution: + def test_returns_executions_ordered(self, db_session_with_containers: Session) -> None: + account = _create_account_with_tenant(db_session_with_containers) + tenant_id = account.current_tenant_id + app_id = str(uuid4()) + workflow_id = str(uuid4()) + workflow_run_id = str(uuid4()) + repo = _make_repo(db_session_with_containers, account, app_id) + + _create_node_execution_model( + db_session_with_containers, + tenant_id=tenant_id, + app_id=app_id, + workflow_id=workflow_id, + workflow_run_id=workflow_run_id, + index=1, + status=WorkflowNodeExecutionStatus.SUCCEEDED, + ) + _create_node_execution_model( + db_session_with_containers, + tenant_id=tenant_id, + app_id=app_id, + workflow_id=workflow_id, + workflow_run_id=workflow_run_id, + index=2, + status=WorkflowNodeExecutionStatus.SUCCEEDED, + ) + db_session_with_containers.commit() + + order_config = OrderConfig(order_by=["index"], order_direction="desc") + result = repo.get_by_workflow_execution( + workflow_execution_id=workflow_run_id, + order_config=order_config, + ) + + assert len(result) == 2 + assert result[0].index == 2 + assert result[1].index == 1 + assert all(isinstance(r, WorkflowNodeExecution) for r in result) + + def test_excludes_paused_executions(self, db_session_with_containers: Session) -> None: + account = _create_account_with_tenant(db_session_with_containers) + tenant_id = account.current_tenant_id + app_id = str(uuid4()) + workflow_id = str(uuid4()) + workflow_run_id = str(uuid4()) + repo = _make_repo(db_session_with_containers, account, app_id) + + _create_node_execution_model( + db_session_with_containers, + tenant_id=tenant_id, + app_id=app_id, + workflow_id=workflow_id, + workflow_run_id=workflow_run_id, + index=1, + status=WorkflowNodeExecutionStatus.RUNNING, + ) + _create_node_execution_model( + db_session_with_containers, + tenant_id=tenant_id, + app_id=app_id, + workflow_id=workflow_id, + workflow_run_id=workflow_run_id, + index=2, + status=WorkflowNodeExecutionStatus.PAUSED, + ) + db_session_with_containers.commit() + + result = repo.get_by_workflow_execution(workflow_execution_id=workflow_run_id) + + assert len(result) == 1 + assert result[0].index == 1 + + +class TestToDbModel: + def test_converts_domain_to_db_model(self, db_session_with_containers: Session) -> None: + account = _create_account_with_tenant(db_session_with_containers) + app_id = str(uuid4()) + repo = _make_repo(db_session_with_containers, account, app_id) + + domain_model = WorkflowNodeExecution( + id="test-id", + workflow_id="test-workflow-id", + node_execution_id="test-node-execution-id", + workflow_execution_id="test-workflow-run-id", + index=1, + predecessor_node_id="test-predecessor-id", + node_id="test-node-id", + node_type=BuiltinNodeTypes.START, + title="Test Node", + inputs={"input_key": "input_value"}, + process_data={"process_key": "process_value"}, + outputs={"output_key": "output_value"}, + status=WorkflowNodeExecutionStatus.RUNNING, + error=None, + elapsed_time=1.5, + metadata={ + WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: 100, + WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: Decimal("0.0"), + }, + created_at=datetime.now(), + finished_at=None, + ) + + db_model = repo._to_db_model(domain_model) + + assert isinstance(db_model, WorkflowNodeExecutionModel) + assert db_model.id == domain_model.id + assert db_model.tenant_id == account.current_tenant_id + assert db_model.app_id == app_id + assert db_model.workflow_id == domain_model.workflow_id + assert db_model.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN + assert db_model.workflow_run_id == domain_model.workflow_execution_id + assert db_model.index == domain_model.index + assert db_model.predecessor_node_id == domain_model.predecessor_node_id + assert db_model.node_execution_id == domain_model.node_execution_id + assert db_model.node_id == domain_model.node_id + assert db_model.node_type == domain_model.node_type + assert db_model.title == domain_model.title + assert db_model.inputs_dict == domain_model.inputs + assert db_model.process_data_dict == domain_model.process_data + assert db_model.outputs_dict == domain_model.outputs + assert db_model.execution_metadata_dict == jsonable_encoder(domain_model.metadata) + assert db_model.status == domain_model.status + assert db_model.error == domain_model.error + assert db_model.elapsed_time == domain_model.elapsed_time + assert db_model.created_at == domain_model.created_at + assert db_model.created_by_role == CreatorUserRole.ACCOUNT + assert db_model.created_by == account.id + assert db_model.finished_at == domain_model.finished_at + + +class TestToDomainModel: + def test_converts_db_to_domain_model(self, db_session_with_containers: Session) -> None: + account = _create_account_with_tenant(db_session_with_containers) + app_id = str(uuid4()) + repo = _make_repo(db_session_with_containers, account, app_id) + + inputs_dict = {"input_key": "input_value"} + process_data_dict = {"process_key": "process_value"} + outputs_dict = {"output_key": "output_value"} + metadata_dict = {str(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS): 100} + now = datetime.now() + + db_model = WorkflowNodeExecutionModel() + db_model.id = "test-id" + db_model.tenant_id = account.current_tenant_id + db_model.app_id = app_id + db_model.workflow_id = "test-workflow-id" + db_model.triggered_from = "workflow-run" + db_model.workflow_run_id = "test-workflow-run-id" + db_model.index = 1 + db_model.predecessor_node_id = "test-predecessor-id" + db_model.node_execution_id = "test-node-execution-id" + db_model.node_id = "test-node-id" + db_model.node_type = BuiltinNodeTypes.START + db_model.title = "Test Node" + db_model.inputs = json.dumps(inputs_dict) + db_model.process_data = json.dumps(process_data_dict) + db_model.outputs = json.dumps(outputs_dict) + db_model.status = WorkflowNodeExecutionStatus.RUNNING + db_model.error = None + db_model.elapsed_time = 1.5 + db_model.execution_metadata = json.dumps(metadata_dict) + db_model.created_at = now + db_model.created_by_role = "account" + db_model.created_by = account.id + db_model.finished_at = None + + domain_model = repo._to_domain_model(db_model) + + assert isinstance(domain_model, WorkflowNodeExecution) + assert domain_model.id == "test-id" + assert domain_model.workflow_id == "test-workflow-id" + assert domain_model.workflow_execution_id == "test-workflow-run-id" + assert domain_model.index == 1 + assert domain_model.predecessor_node_id == "test-predecessor-id" + assert domain_model.node_execution_id == "test-node-execution-id" + assert domain_model.node_id == "test-node-id" + assert domain_model.node_type == BuiltinNodeTypes.START + assert domain_model.title == "Test Node" + assert domain_model.inputs == inputs_dict + assert domain_model.process_data == process_data_dict + assert domain_model.outputs == outputs_dict + assert domain_model.status == WorkflowNodeExecutionStatus.RUNNING + assert domain_model.error is None + assert domain_model.elapsed_time == 1.5 + assert domain_model.metadata == {WorkflowNodeExecutionMetadataKey(k): v for k, v in metadata_dict.items()} + assert domain_model.created_at == now + assert domain_model.finished_at is None + + def test_domain_model_without_offload_data(self, db_session_with_containers: Session) -> None: + account = _create_account_with_tenant(db_session_with_containers) + repo = _make_repo(db_session_with_containers, account, str(uuid4())) + + process_data = {"normal": "data"} + db_model = WorkflowNodeExecutionModel() + db_model.id = str(uuid4()) + db_model.tenant_id = account.current_tenant_id + db_model.app_id = str(uuid4()) + db_model.workflow_id = str(uuid4()) + db_model.triggered_from = "workflow-run" + db_model.workflow_run_id = None + db_model.index = 1 + db_model.predecessor_node_id = None + db_model.node_execution_id = str(uuid4()) + db_model.node_id = "test-node-id" + db_model.node_type = "llm" + db_model.title = "Test Node" + db_model.inputs = None + db_model.process_data = json.dumps(process_data) + db_model.outputs = None + db_model.status = "succeeded" + db_model.error = None + db_model.elapsed_time = 1.5 + db_model.execution_metadata = "{}" + db_model.created_at = datetime.now() + db_model.created_by_role = "account" + db_model.created_by = account.id + db_model.finished_at = None + + domain_model = repo._to_domain_model(db_model) + + assert domain_model.process_data == process_data + assert domain_model.process_data_truncated is False + assert domain_model.get_truncated_process_data() is None diff --git a/api/tests/test_containers_integration_tests/services/rag_pipeline/__init__.py b/api/tests/test_containers_integration_tests/services/rag_pipeline/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/test_containers_integration_tests/services/rag_pipeline/test_rag_pipeline_service_db.py b/api/tests/test_containers_integration_tests/services/rag_pipeline/test_rag_pipeline_service_db.py new file mode 100644 index 0000000000..8fc1809a46 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/rag_pipeline/test_rag_pipeline_service_db.py @@ -0,0 +1,255 @@ +""" +Integration tests for RagPipelineService methods that interact with the database. + +Migrated from unit_tests/services/rag_pipeline/test_rag_pipeline_service.py, replacing +db.session.scalar/commit/delete mocker patches with real PostgreSQL operations. + +Covers: +- get_pipeline: Dataset and Pipeline lookups +- update_customized_pipeline_template: find + unique-name check + commit +- delete_customized_pipeline_template: find + delete + commit +""" + +from collections.abc import Generator +from types import SimpleNamespace +from unittest.mock import patch +from uuid import uuid4 + +import pytest +from sqlalchemy.orm import Session, sessionmaker + +from models.dataset import Dataset, Pipeline, PipelineCustomizedTemplate +from models.enums import DataSourceType +from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, PipelineTemplateInfoEntity +from services.rag_pipeline.rag_pipeline import RagPipelineService + + +class TestRagPipelineServiceGetPipeline: + """Integration tests for RagPipelineService.get_pipeline.""" + + @pytest.fixture(autouse=True) + def _auto_rollback(self, db_session_with_containers: Session) -> Generator[None, None, None]: + yield + db_session_with_containers.rollback() + + def _make_service(self, flask_app_with_containers) -> RagPipelineService: + with ( + patch( + "services.rag_pipeline.rag_pipeline.DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository", + return_value=None, + ), + patch( + "services.rag_pipeline.rag_pipeline.DifyAPIRepositoryFactory.create_api_workflow_run_repository", + return_value=None, + ), + ): + session_factory = sessionmaker(bind=flask_app_with_containers.extensions["sqlalchemy"].engine) + return RagPipelineService(session_maker=session_factory) + + def _create_pipeline(self, db_session: Session, tenant_id: str, created_by: str) -> Pipeline: + pipeline = Pipeline( + tenant_id=tenant_id, + name=f"Pipeline {uuid4()}", + description="", + created_by=created_by, + ) + db_session.add(pipeline) + db_session.flush() + return pipeline + + def _create_dataset( + self, db_session: Session, tenant_id: str, created_by: str, pipeline_id: str | None = None + ) -> Dataset: + dataset = Dataset( + tenant_id=tenant_id, + name=f"Dataset {uuid4()}", + data_source_type=DataSourceType.UPLOAD_FILE, + created_by=created_by, + pipeline_id=pipeline_id, + ) + db_session.add(dataset) + db_session.flush() + return dataset + + def test_get_pipeline_raises_when_dataset_not_found( + self, db_session_with_containers: Session, flask_app_with_containers + ) -> None: + """get_pipeline raises ValueError when dataset does not exist.""" + service = self._make_service(flask_app_with_containers) + + with pytest.raises(ValueError, match="Dataset not found"): + service.get_pipeline(tenant_id=str(uuid4()), dataset_id=str(uuid4())) + + def test_get_pipeline_raises_when_pipeline_not_found( + self, db_session_with_containers: Session, flask_app_with_containers + ) -> None: + """get_pipeline raises ValueError when dataset exists but has no linked pipeline.""" + tenant_id = str(uuid4()) + created_by = str(uuid4()) + dataset = self._create_dataset(db_session_with_containers, tenant_id, created_by, pipeline_id=None) + db_session_with_containers.flush() + + service = self._make_service(flask_app_with_containers) + + with pytest.raises(ValueError, match="(Dataset not found|Pipeline not found)"): + service.get_pipeline(tenant_id=tenant_id, dataset_id=dataset.id) + + def test_get_pipeline_returns_pipeline_when_found( + self, db_session_with_containers: Session, flask_app_with_containers + ) -> None: + """get_pipeline returns the Pipeline when both Dataset and Pipeline exist.""" + tenant_id = str(uuid4()) + created_by = str(uuid4()) + + pipeline = self._create_pipeline(db_session_with_containers, tenant_id, created_by) + dataset = self._create_dataset(db_session_with_containers, tenant_id, created_by, pipeline_id=pipeline.id) + db_session_with_containers.flush() + + service = self._make_service(flask_app_with_containers) + + result = service.get_pipeline(tenant_id=tenant_id, dataset_id=dataset.id) + + assert result.id == pipeline.id + + +class TestUpdateCustomizedPipelineTemplate: + """Integration tests for RagPipelineService.update_customized_pipeline_template.""" + + @pytest.fixture(autouse=True) + def _auto_rollback(self, db_session_with_containers: Session) -> Generator[None, None, None]: + yield + db_session_with_containers.rollback() + + def _create_template( + self, db_session: Session, tenant_id: str, created_by: str, name: str = "Template" + ) -> PipelineCustomizedTemplate: + template = PipelineCustomizedTemplate( + tenant_id=tenant_id, + name=name, + description="Original description", + chunk_structure="fixed_size", + icon={"type": "emoji", "value": "📄"}, + position=1, + yaml_content="{}", + install_count=0, + language="en-US", + created_by=created_by, + ) + db_session.add(template) + db_session.flush() + return template + + def test_update_template_succeeds(self, db_session_with_containers: Session, flask_app_with_containers) -> None: + """update_customized_pipeline_template updates name and description.""" + tenant_id = str(uuid4()) + created_by = str(uuid4()) + template = self._create_template(db_session_with_containers, tenant_id, created_by) + db_session_with_containers.flush() + + fake_user = SimpleNamespace(id=created_by, current_tenant_id=tenant_id) + + with patch("services.rag_pipeline.rag_pipeline.current_user", fake_user): + info = PipelineTemplateInfoEntity( + name="Updated Name", + description="Updated description", + icon_info=IconInfo(icon="🔥"), + ) + result = RagPipelineService.update_customized_pipeline_template(template.id, info) + + assert result.name == "Updated Name" + assert result.description == "Updated description" + + def test_update_template_raises_when_not_found( + self, db_session_with_containers: Session, flask_app_with_containers + ) -> None: + """update_customized_pipeline_template raises ValueError when template doesn't exist.""" + fake_user = SimpleNamespace(id=str(uuid4()), current_tenant_id=str(uuid4())) + + with patch("services.rag_pipeline.rag_pipeline.current_user", fake_user): + info = PipelineTemplateInfoEntity( + name="New Name", + description="desc", + icon_info=IconInfo(icon="📄"), + ) + with pytest.raises(ValueError, match="Customized pipeline template not found"): + RagPipelineService.update_customized_pipeline_template(str(uuid4()), info) + + def test_update_template_raises_on_duplicate_name( + self, db_session_with_containers: Session, flask_app_with_containers + ) -> None: + """update_customized_pipeline_template raises ValueError when new name already exists.""" + tenant_id = str(uuid4()) + created_by = str(uuid4()) + template1 = self._create_template(db_session_with_containers, tenant_id, created_by, name="Original") + self._create_template(db_session_with_containers, tenant_id, created_by, name="Duplicate") + db_session_with_containers.flush() + + fake_user = SimpleNamespace(id=created_by, current_tenant_id=tenant_id) + + with patch("services.rag_pipeline.rag_pipeline.current_user", fake_user): + info = PipelineTemplateInfoEntity( + name="Duplicate", + description="desc", + icon_info=IconInfo(icon="📄"), + ) + with pytest.raises(ValueError, match="Template name is already exists"): + RagPipelineService.update_customized_pipeline_template(template1.id, info) + + +class TestDeleteCustomizedPipelineTemplate: + """Integration tests for RagPipelineService.delete_customized_pipeline_template.""" + + @pytest.fixture(autouse=True) + def _auto_rollback(self, db_session_with_containers: Session) -> Generator[None, None, None]: + yield + db_session_with_containers.rollback() + + def _create_template(self, db_session: Session, tenant_id: str, created_by: str) -> PipelineCustomizedTemplate: + template = PipelineCustomizedTemplate( + tenant_id=tenant_id, + name=f"Template {uuid4()}", + description="Description", + chunk_structure="fixed_size", + icon={"type": "emoji", "value": "📄"}, + position=1, + yaml_content="{}", + install_count=0, + language="en-US", + created_by=created_by, + ) + db_session.add(template) + db_session.flush() + return template + + def test_delete_template_succeeds(self, db_session_with_containers: Session, flask_app_with_containers) -> None: + """delete_customized_pipeline_template removes the template from the DB.""" + tenant_id = str(uuid4()) + created_by = str(uuid4()) + template = self._create_template(db_session_with_containers, tenant_id, created_by) + template_id = template.id + db_session_with_containers.flush() + + fake_user = SimpleNamespace(id=created_by, current_tenant_id=tenant_id) + + with patch("services.rag_pipeline.rag_pipeline.current_user", fake_user): + RagPipelineService.delete_customized_pipeline_template(template_id) + + # Verify the record is deleted within the same context + from sqlalchemy import select + + from extensions.ext_database import db as ext_db + + remaining = ext_db.session.scalar( + select(PipelineCustomizedTemplate).where(PipelineCustomizedTemplate.id == template_id) + ) + assert remaining is None + + def test_delete_template_raises_when_not_found( + self, db_session_with_containers: Session, flask_app_with_containers + ) -> None: + """delete_customized_pipeline_template raises ValueError when template doesn't exist.""" + fake_user = SimpleNamespace(id=str(uuid4()), current_tenant_id=str(uuid4())) + + with patch("services.rag_pipeline.rag_pipeline.current_user", fake_user): + with pytest.raises(ValueError, match="Customized pipeline template not found"): + RagPipelineService.delete_customized_pipeline_template(str(uuid4())) diff --git a/api/tests/test_containers_integration_tests/services/test_billing_service.py b/api/tests/test_containers_integration_tests/services/test_billing_service.py index 76708b36b1..8092c7ad75 100644 --- a/api/tests/test_containers_integration_tests/services/test_billing_service.py +++ b/api/tests/test_containers_integration_tests/services/test_billing_service.py @@ -1,9 +1,13 @@ import json +from collections.abc import Generator from unittest.mock import patch +from uuid import uuid4 import pytest +from sqlalchemy.orm import Session from extensions.ext_redis import redis_client +from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from services.billing_service import BillingService @@ -363,3 +367,62 @@ class TestBillingServiceGetPlanBulkWithCache: assert ttl_1_new <= 600 assert ttl_2 > 0 assert ttl_2 <= 600 + + +class TestBillingServiceIsTenantOwnerOrAdmin: + """ + Integration tests for BillingService.is_tenant_owner_or_admin. + + Verifies that non-privileged roles (EDITOR, DATASET_OPERATOR) raise ValueError + when checked against real TenantAccountJoin rows in PostgreSQL. + """ + + @pytest.fixture(autouse=True) + def _auto_rollback(self, db_session_with_containers: Session) -> Generator[None, None, None]: + yield + db_session_with_containers.rollback() + + def _create_account_with_tenant_role(self, db_session: Session, role: TenantAccountRole) -> tuple[Account, Tenant]: + tenant = Tenant(name=f"Tenant {uuid4()}") + db_session.add(tenant) + db_session.flush() + + account = Account( + name=f"Account {uuid4()}", + email=f"billing_{uuid4()}@example.com", + password="hashed-password", + password_salt="salt", + interface_language="en-US", + timezone="UTC", + ) + db_session.add(account) + db_session.flush() + + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=role, + current=True, + ) + db_session.add(join) + db_session.flush() + + # Wire up in-memory reference so current_tenant_id resolves + account._current_tenant = tenant + return account, tenant + + def test_is_tenant_owner_or_admin_editor_role_raises_error(self, db_session_with_containers: Session) -> None: + """is_tenant_owner_or_admin raises ValueError for EDITOR role.""" + account, _ = self._create_account_with_tenant_role(db_session_with_containers, TenantAccountRole.EDITOR) + + with pytest.raises(ValueError, match="Only team owner or team admin can perform this action"): + BillingService.is_tenant_owner_or_admin(account) + + def test_is_tenant_owner_or_admin_dataset_operator_raises_error(self, db_session_with_containers: Session) -> None: + """is_tenant_owner_or_admin raises ValueError for DATASET_OPERATOR role.""" + account, _ = self._create_account_with_tenant_role( + db_session_with_containers, TenantAccountRole.DATASET_OPERATOR + ) + + with pytest.raises(ValueError, match="Only team owner or team admin can perform this action"): + BillingService.is_tenant_owner_or_admin(account) diff --git a/api/tests/test_containers_integration_tests/services/test_conversation_service.py b/api/tests/test_containers_integration_tests/services/test_conversation_service.py index 6180d98b1e..98c38f2b5f 100644 --- a/api/tests/test_containers_integration_tests/services/test_conversation_service.py +++ b/api/tests/test_containers_integration_tests/services/test_conversation_service.py @@ -637,6 +637,40 @@ class TestConversationServiceSummarization: assert conversation.name == new_name assert conversation.updated_at == mock_time + @patch("services.conversation_service.LLMGenerator.generate_conversation_name") + def test_rename_with_auto_generate(self, mock_llm_generator, db_session_with_containers): + """ + Test rename delegates to auto_generate_name when auto_generate is True. + + When auto_generate is True, the service should call auto_generate_name + which uses an LLM to create a descriptive conversation title. + """ + # Arrange + app_model, user = ConversationServiceIntegrationTestDataFactory.create_app_and_account( + db_session_with_containers + ) + conversation = ConversationServiceIntegrationTestDataFactory.create_conversation( + db_session_with_containers, app_model, user + ) + ConversationServiceIntegrationTestDataFactory.create_message( + db_session_with_containers, app_model, conversation, user + ) + generated_name = "Auto Generated Name" + mock_llm_generator.return_value = generated_name + + # Act + result = ConversationService.rename( + app_model=app_model, + conversation_id=conversation.id, + user=user, + name=None, + auto_generate=True, + ) + + # Assert + assert result == conversation + assert conversation.name == generated_name + class TestConversationServiceMessageAnnotation: """ @@ -1066,3 +1100,32 @@ class TestConversationServiceExport: not_deleted = db_session_with_containers.scalar(select(Conversation).where(Conversation.id == conversation.id)) assert not_deleted is not None mock_delete_task.delay.assert_not_called() + + @patch("services.conversation_service.delete_conversation_related_data") + def test_delete_handles_exception_and_rollback(self, mock_delete_task, db_session_with_containers): + """ + Test that delete propagates exceptions and does not trigger the cleanup task. + + When a DB error occurs during deletion, the service must rollback the + transaction and re-raise the exception without scheduling async cleanup. + """ + # Arrange + app_model, user = ConversationServiceIntegrationTestDataFactory.create_app_and_account( + db_session_with_containers + ) + conversation = ConversationServiceIntegrationTestDataFactory.create_conversation( + db_session_with_containers, app_model, user + ) + conversation_id = conversation.id + + # Act — force an error during the delete to exercise the rollback path + with patch("services.conversation_service.db.session.delete", side_effect=Exception("DB error")): + with pytest.raises(Exception, match="DB error"): + ConversationService.delete(app_model=app_model, conversation_id=conversation_id, user=user) + + # Assert — async cleanup must NOT have been scheduled + mock_delete_task.delay.assert_not_called() + + # Conversation is still present because the deletion was never committed + still_there = db_session_with_containers.scalar(select(Conversation).where(Conversation.id == conversation_id)) + assert still_there is not None diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py index a814466e14..2a2d86a8a6 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py @@ -1,3 +1,4 @@ +import json from unittest.mock import Mock, patch from uuid import uuid4 @@ -7,7 +8,7 @@ from sqlalchemy.orm import Session from core.rag.index_processor.constant.index_type import IndexTechniqueType from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole -from models.dataset import Dataset, ExternalKnowledgeBindings +from models.dataset import Dataset, ExternalKnowledgeApis, ExternalKnowledgeBindings from models.enums import DataSourceType from services.dataset_service import DatasetService from services.errors.account import NoPermissionError @@ -103,6 +104,34 @@ class DatasetUpdateTestDataFactory: db_session_with_containers.commit() return binding + @staticmethod + def create_external_knowledge_api( + db_session_with_containers: Session, + tenant_id: str, + created_by: str, + api_id: str | None = None, + name: str = "test-api", + ) -> ExternalKnowledgeApis: + """Create a real external knowledge API template for tenant-scoped update validation.""" + external_api = ExternalKnowledgeApis( + tenant_id=tenant_id, + created_by=created_by, + updated_by=created_by, + name=name, + description="test description", + settings=json.dumps( + { + "endpoint": "https://example.com", + "api_key": "test-api-key", + } + ), + ) + if api_id is not None: + external_api.id = api_id + db_session_with_containers.add(external_api) + db_session_with_containers.commit() + return external_api + class TestDatasetServiceUpdateDataset: """ @@ -138,6 +167,11 @@ class TestDatasetServiceUpdateDataset: ) binding_id = binding.id db_session_with_containers.expunge(binding) + external_api = DatasetUpdateTestDataFactory.create_external_knowledge_api( + db_session_with_containers, + tenant_id=tenant.id, + created_by=user.id, + ) update_data = { "name": "new_name", @@ -145,7 +179,7 @@ class TestDatasetServiceUpdateDataset: "external_retrieval_model": "new_model", "permission": "only_me", "external_knowledge_id": "new_knowledge_id", - "external_knowledge_api_id": str(uuid4()), + "external_knowledge_api_id": external_api.id, } result = DatasetService.update_dataset(dataset.id, update_data, user) @@ -218,11 +252,16 @@ class TestDatasetServiceUpdateDataset: created_by=user.id, provider="external", ) + external_api = DatasetUpdateTestDataFactory.create_external_knowledge_api( + db_session_with_containers, + tenant_id=tenant.id, + created_by=user.id, + ) update_data = { "name": "new_name", "external_knowledge_id": "knowledge_id", - "external_knowledge_api_id": str(uuid4()), + "external_knowledge_api_id": external_api.id, } with pytest.raises(ValueError) as context: diff --git a/api/tests/unit_tests/services/test_hit_testing_service.py b/api/tests/test_containers_integration_tests/services/test_hit_testing_service.py similarity index 51% rename from api/tests/unit_tests/services/test_hit_testing_service.py rename to api/tests/test_containers_integration_tests/services/test_hit_testing_service.py index 80e9729f5b..f332ba05ec 100644 --- a/api/tests/unit_tests/services/test_hit_testing_service.py +++ b/api/tests/test_containers_integration_tests/services/test_hit_testing_service.py @@ -1,239 +1,193 @@ +from __future__ import annotations + import json from typing import Any, cast from unittest.mock import ANY, MagicMock, patch +from uuid import uuid4 import pytest +from sqlalchemy import func, select +from sqlalchemy.orm import Session from core.rag.models.document import Document -from models.dataset import Dataset +from models.dataset import Dataset, DatasetQuery from services.hit_testing_service import HitTestingService -class TestHitTestingService: - """Test suite for HitTestingService""" +def _create_dataset(db_session: Session, *, provider: str = "vendor", **kwargs: Any) -> Dataset: + tenant_id = str(uuid4()) + created_by = str(uuid4()) + ds = Dataset( + tenant_id=kwargs.get("tenant_id", tenant_id), + name=kwargs.get("name", "test-dataset"), + created_by=kwargs.get("created_by", created_by), + provider=provider, + ) + db_session.add(ds) + db_session.commit() + db_session.refresh(ds) + return ds - # ===== Utility Method Tests ===== + +class TestHitTestingService: + # ── Utility methods (pure logic, no DB) ──────────────────────────── def test_escape_query_for_search_should_escape_double_quotes(self): - """Test that escape_query_for_search escapes double quotes correctly""" - # Arrange query = 'test "query" with quotes' - expected = 'test \\"query\\" with quotes' - - # Act result = HitTestingService.escape_query_for_search(query) - - # Assert - assert result == expected + assert result == 'test \\"query\\" with quotes' def test_hit_testing_args_check_should_pass_with_valid_query(self): - """Test that hit_testing_args_check passes with a valid query""" - # Arrange - args = {"query": "valid query"} - - # Act & Assert (should not raise) - HitTestingService.hit_testing_args_check(args) + HitTestingService.hit_testing_args_check({"query": "valid query"}) def test_hit_testing_args_check_should_pass_with_valid_attachments(self): - """Test that hit_testing_args_check passes with valid attachment_ids""" - # Arrange - args = {"attachment_ids": ["id1", "id2"]} - - # Act & Assert (should not raise) - HitTestingService.hit_testing_args_check(args) + HitTestingService.hit_testing_args_check({"attachment_ids": ["id1", "id2"]}) def test_hit_testing_args_check_should_raise_error_when_no_query_or_attachments(self): - """Test that hit_testing_args_check raises ValueError if both query and attachment_ids are missing""" - # Arrange - args = {} - - # Act & Assert - with pytest.raises(ValueError) as exc_info: - HitTestingService.hit_testing_args_check(args) - assert "Query or attachment_ids is required" in str(exc_info.value) + with pytest.raises(ValueError, match="Query or attachment_ids is required"): + HitTestingService.hit_testing_args_check({}) def test_hit_testing_args_check_should_raise_error_when_query_too_long(self): - """Test that hit_testing_args_check raises ValueError if query exceeds 250 characters""" - # Arrange - args = {"query": "a" * 251} - - # Act & Assert - with pytest.raises(ValueError) as exc_info: - HitTestingService.hit_testing_args_check(args) - assert "Query cannot exceed 250 characters" in str(exc_info.value) + with pytest.raises(ValueError, match="Query cannot exceed 250 characters"): + HitTestingService.hit_testing_args_check({"query": "a" * 251}) def test_hit_testing_args_check_should_raise_error_when_attachments_not_list(self): - """Test that hit_testing_args_check raises ValueError if attachment_ids is not a list""" - # Arrange - args = {"attachment_ids": "not a list"} + with pytest.raises(ValueError, match="Attachment_ids must be a list"): + HitTestingService.hit_testing_args_check({"attachment_ids": "not a list"}) - # Act & Assert - with pytest.raises(ValueError) as exc_info: - HitTestingService.hit_testing_args_check(args) - assert "Attachment_ids must be a list" in str(exc_info.value) - - # ===== Response Formatting Tests ===== + # ── Response formatting ──────────────────────────────────────────── @patch("core.rag.datasource.retrieval_service.RetrievalService.format_retrieval_documents") def test_compact_retrieve_response_should_format_correctly(self, mock_format): - """Test that compact_retrieve_response formats the response correctly""" - # Arrange query = "test query" mock_doc = MagicMock(spec=Document) - documents = [mock_doc] mock_record = MagicMock() mock_record.model_dump.return_value = {"content": "formatted content"} mock_format.return_value = [mock_record] - # Act - result = cast(dict[str, Any], HitTestingService.compact_retrieve_response(query, documents)) + result = cast(dict[str, Any], HitTestingService.compact_retrieve_response(query, [mock_doc])) - # Assert assert cast(dict[str, Any], result["query"])["content"] == query assert len(result["records"]) == 1 assert cast(dict[str, Any], result["records"][0])["content"] == "formatted content" - mock_format.assert_called_once_with(documents) + mock_format.assert_called_once_with([mock_doc]) - def test_compact_external_retrieve_response_should_return_records_for_external_provider(self): - """Test that compact_external_retrieve_response returns records when dataset provider is external""" - # Arrange - dataset = MagicMock(spec=Dataset) - dataset.provider = "external" - query = "test query" + def test_compact_external_retrieve_response_should_return_records_for_external_provider( + self, db_session_with_containers: Session + ): + dataset = _create_dataset(db_session_with_containers, provider="external") documents = [ {"content": "c1", "title": "t1", "score": 0.9, "metadata": {"m1": "v1"}}, {"content": "c2", "title": "t2", "score": 0.8, "metadata": {"m2": "v2"}}, ] - # Act - result = cast(dict[str, Any], HitTestingService.compact_external_retrieve_response(dataset, query, documents)) + result = cast( + dict[str, Any], HitTestingService.compact_external_retrieve_response(dataset, "test query", documents) + ) - # Assert - assert cast(dict[str, Any], result["query"])["content"] == query + assert cast(dict[str, Any], result["query"])["content"] == "test query" assert len(result["records"]) == 2 assert cast(dict[str, Any], result["records"][0])["content"] == "c1" assert cast(dict[str, Any], result["records"][1])["title"] == "t2" - def test_compact_external_retrieve_response_should_return_empty_for_non_external_provider(self): - """Test that compact_external_retrieve_response returns empty records for non-external provider""" - # Arrange - dataset = MagicMock(spec=Dataset) - dataset.provider = "not_external" - query = "test query" - documents = [{"content": "c1"}] + def test_compact_external_retrieve_response_should_return_empty_for_non_external_provider( + self, db_session_with_containers: Session + ): + dataset = _create_dataset(db_session_with_containers, provider="vendor") - # Act - result = cast(dict[str, Any], HitTestingService.compact_external_retrieve_response(dataset, query, documents)) + result = cast( + dict[str, Any], + HitTestingService.compact_external_retrieve_response(dataset, "test query", [{"content": "c1"}]), + ) - # Assert - assert cast(dict[str, Any], result["query"])["content"] == query + assert cast(dict[str, Any], result["query"])["content"] == "test query" assert result["records"] == [] - # ===== External Retrieve Tests ===== + # ── External retrieve (real DB) ──────────────────────────────────── @patch("core.rag.datasource.retrieval_service.RetrievalService.external_retrieve") - @patch("extensions.ext_database.db.session.add") - @patch("extensions.ext_database.db.session.commit") - def test_external_retrieve_should_succeed_for_external_provider(self, mock_commit, mock_add, mock_ext_retrieve): - """Test that external_retrieve successfully retrieves from external provider and commits query""" - # Arrange - dataset = MagicMock(spec=Dataset) - dataset.id = "dataset_id" - dataset.provider = "external" - query = 'test "query"' + def test_external_retrieve_should_succeed_for_external_provider( + self, mock_ext_retrieve, db_session_with_containers: Session + ): + dataset = _create_dataset(db_session_with_containers, provider="external") + account_id = str(uuid4()) account = MagicMock() - account.id = "account_id" - + account.id = account_id mock_ext_retrieve.return_value = [{"content": "ext content", "score": 1.0}] - # Act + before_count = db_session_with_containers.scalar(select(func.count()).select_from(DatasetQuery)) or 0 + result = cast( dict[str, Any], HitTestingService.external_retrieve( dataset=dataset, - query=query, + query='test "query"', account=account, external_retrieval_model={"model": "test"}, metadata_filtering_conditions={"key": "val"}, ), ) - # Assert - assert cast(dict[str, Any], result["query"])["content"] == query + assert cast(dict[str, Any], result["query"])["content"] == 'test "query"' assert cast(dict[str, Any], result["records"][0])["content"] == "ext content" - - # Verify call to RetrievalService.external_retrieve with escaped query mock_ext_retrieve.assert_called_once_with( - dataset_id="dataset_id", + dataset_id=dataset.id, query='test \\"query\\"', external_retrieval_model={"model": "test"}, metadata_filtering_conditions={"key": "val"}, ) - # Verify DatasetQuery record was added and committed - mock_add.assert_called_once() - mock_commit.assert_called_once() + db_session_with_containers.expire_all() + after_count = db_session_with_containers.scalar(select(func.count()).select_from(DatasetQuery)) or 0 + assert after_count == before_count + 1 - def test_external_retrieve_should_return_empty_for_non_external_provider(self): - """Test that external_retrieve returns empty results immediately if provider is not external""" - # Arrange - dataset = MagicMock(spec=Dataset) - dataset.provider = "not_external" - query = "test query" + def test_external_retrieve_should_return_empty_for_non_external_provider(self, db_session_with_containers: Session): + dataset = _create_dataset(db_session_with_containers, provider="vendor") account = MagicMock() - # Act - result = cast(dict[str, Any], HitTestingService.external_retrieve(dataset, query, account)) + result = cast(dict[str, Any], HitTestingService.external_retrieve(dataset, "test query", account)) - # Assert - assert cast(dict[str, Any], result["query"])["content"] == query + assert cast(dict[str, Any], result["query"])["content"] == "test query" assert result["records"] == [] - # ===== Retrieve Tests ===== + # ── Retrieve (real DB) ───────────────────────────────────────────── @patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve") - @patch("extensions.ext_database.db.session.add") - @patch("extensions.ext_database.db.session.commit") - def test_retrieve_should_use_default_model_when_none_provided(self, mock_commit, mock_add, mock_retrieve): - """Test that retrieve uses default model when retrieval_model is not provided""" - # Arrange - dataset = MagicMock(spec=Dataset) - dataset.id = "dataset_id" + def test_retrieve_should_use_default_model_when_none_provided( + self, mock_retrieve, db_session_with_containers: Session + ): + dataset = _create_dataset(db_session_with_containers) dataset.retrieval_model = None - query = "test query" account = MagicMock() - account.id = "account_id" - + account.id = str(uuid4()) mock_retrieve.return_value = [] - # Act + before_count = db_session_with_containers.scalar(select(func.count()).select_from(DatasetQuery)) or 0 + result = cast( dict[str, Any], HitTestingService.retrieve( - dataset=dataset, query=query, account=account, retrieval_model=None, external_retrieval_model={} + dataset=dataset, query="test query", account=account, retrieval_model=None, external_retrieval_model={} ), ) - # Assert - assert cast(dict[str, Any], result["query"])["content"] == query + assert cast(dict[str, Any], result["query"])["content"] == "test query" mock_retrieve.assert_called_once() - # Verify top_k from default_retrieval_model (4) assert mock_retrieve.call_args.kwargs["top_k"] == 4 - mock_commit.assert_called_once() + + db_session_with_containers.expire_all() + after_count = db_session_with_containers.scalar(select(func.count()).select_from(DatasetQuery)) or 0 + assert after_count == before_count + 1 @patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve") @patch("core.rag.retrieval.dataset_retrieval.DatasetRetrieval.get_metadata_filter_condition") - @patch("extensions.ext_database.db.session.add") - @patch("extensions.ext_database.db.session.commit") - def test_retrieve_should_handle_metadata_filtering(self, mock_commit, mock_add, mock_get_meta, mock_retrieve): - """Test that retrieve correctly calls metadata filtering when conditions are present""" - # Arrange - dataset = MagicMock(spec=Dataset) - dataset.id = "dataset_id" - query = "test query" + def test_retrieve_should_handle_metadata_filtering( + self, mock_get_meta, mock_retrieve, db_session_with_containers: Session + ): + dataset = _create_dataset(db_session_with_containers) account = MagicMock() - account.id = "account_id" + account.id = str(uuid4()) retrieval_model = { "search_method": "semantic_search", @@ -242,29 +196,27 @@ class TestHitTestingService: "reranking_enable": False, "score_threshold_enabled": False, } - - # Mock metadata filtering response - mock_get_meta.return_value = ({"dataset_id": ["doc_id1"]}, "condition_string") + mock_get_meta.return_value = ({dataset.id: ["doc_id1"]}, "condition_string") mock_retrieve.return_value = [] - # Act HitTestingService.retrieve( - dataset=dataset, query=query, account=account, retrieval_model=retrieval_model, external_retrieval_model={} + dataset=dataset, + query="test query", + account=account, + retrieval_model=retrieval_model, + external_retrieval_model={}, ) - # Assert mock_get_meta.assert_called_once() mock_retrieve.assert_called_once() assert mock_retrieve.call_args.kwargs["document_ids_filter"] == ["doc_id1"] @patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve") @patch("core.rag.retrieval.dataset_retrieval.DatasetRetrieval.get_metadata_filter_condition") - def test_retrieve_should_return_empty_if_metadata_filtering_fails(self, mock_get_meta, mock_retrieve): - """Test that retrieve returns empty response if metadata filtering returns condition but no document IDs""" - # Arrange - dataset = MagicMock(spec=Dataset) - dataset.id = "dataset_id" - query = "test query" + def test_retrieve_should_return_empty_if_metadata_filtering_fails( + self, mock_get_meta, mock_retrieve, db_session_with_containers: Session + ): + dataset = _create_dataset(db_session_with_containers) account = MagicMock() retrieval_model = { @@ -274,37 +226,27 @@ class TestHitTestingService: "reranking_enable": False, "score_threshold_enabled": False, } - - # Mock metadata filtering response: condition returned but no IDs mock_get_meta.return_value = ({}, "condition_string") - # Act result = cast( dict[str, Any], HitTestingService.retrieve( dataset=dataset, - query=query, + query="test query", account=account, retrieval_model=retrieval_model, external_retrieval_model={}, ), ) - # Assert assert result["records"] == [] mock_retrieve.assert_not_called() @patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve") - @patch("extensions.ext_database.db.session.add") - @patch("extensions.ext_database.db.session.commit") - def test_retrieve_should_handle_attachments(self, mock_commit, mock_add, mock_retrieve): - """Test that retrieve handles attachment_ids and adds them to DatasetQuery""" - # Arrange - dataset = MagicMock(spec=Dataset) - dataset.id = "dataset_id" - query = "test query" + def test_retrieve_should_handle_attachments(self, mock_retrieve, db_session_with_containers: Session): + dataset = _create_dataset(db_session_with_containers) account = MagicMock() - account.id = "account_id" + account.id = str(uuid4()) attachment_ids = ["att1", "att2"] retrieval_model = { @@ -315,21 +257,19 @@ class TestHitTestingService: } mock_retrieve.return_value = [] - # Act HitTestingService.retrieve( dataset=dataset, - query=query, + query="test query", account=account, retrieval_model=retrieval_model, external_retrieval_model={}, attachment_ids=attachment_ids, ) - # Assert mock_retrieve.assert_called_once_with( retrieval_method=ANY, - dataset_id="dataset_id", - query=query, + dataset_id=dataset.id, + query="test query", attachment_ids=attachment_ids, top_k=4, score_threshold=0.0, @@ -338,26 +278,27 @@ class TestHitTestingService: weights=None, document_ids_filter=None, ) - # Verify DatasetQuery record (there should be 2 queries: 1 text, 2 images) - # The content is json.dumps([{"content_type": "text_query", ...}, {"content_type": "image_query", ...}]) - called_query = mock_add.call_args[0][0] - query_content = json.loads(called_query.content) + + # Verify DatasetQuery was persisted with correct content structure + db_session_with_containers.expire_all() + latest = db_session_with_containers.scalar( + select(DatasetQuery) + .where(DatasetQuery.dataset_id == dataset.id) + .order_by(DatasetQuery.created_at.desc()) + .limit(1) + ) + assert latest is not None + query_content = json.loads(latest.content) assert len(query_content) == 3 # 1 text + 2 images assert query_content[0]["content_type"] == "text_query" assert query_content[1]["content_type"] == "image_query" assert query_content[1]["content"] == "att1" @patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve") - @patch("extensions.ext_database.db.session.add") - @patch("extensions.ext_database.db.session.commit") - def test_retrieve_should_handle_reranking_and_threshold(self, mock_commit, mock_add, mock_retrieve): - """Test that retrieve passes reranking and threshold parameters correctly""" - # Arrange - dataset = MagicMock(spec=Dataset) - dataset.id = "dataset_id" - query = "test query" + def test_retrieve_should_handle_reranking_and_threshold(self, mock_retrieve, db_session_with_containers: Session): + dataset = _create_dataset(db_session_with_containers) account = MagicMock() - account.id = "account_id" + account.id = str(uuid4()) retrieval_model = { "search_method": "hybrid_search", @@ -371,12 +312,14 @@ class TestHitTestingService: } mock_retrieve.return_value = [] - # Act HitTestingService.retrieve( - dataset=dataset, query=query, account=account, retrieval_model=retrieval_model, external_retrieval_model={} + dataset=dataset, + query="test query", + account=account, + retrieval_model=retrieval_model, + external_retrieval_model={}, ) - # Assert mock_retrieve.assert_called_once() kwargs = mock_retrieve.call_args.kwargs assert kwargs["score_threshold"] == 0.5 diff --git a/api/tests/test_containers_integration_tests/services/test_ops_service.py b/api/tests/test_containers_integration_tests/services/test_ops_service.py new file mode 100644 index 0000000000..e2e1a228b2 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_ops_service.py @@ -0,0 +1,363 @@ +from __future__ import annotations + +import uuid +from unittest.mock import patch + +import pytest +from faker import Faker +from sqlalchemy import select +from sqlalchemy.orm import Session + +from core.ops.entities.config_entity import TracingProviderEnum +from models.model import TraceAppConfig +from services.account_service import AccountService, TenantService +from services.app_service import AppService +from services.ops_service import OpsService +from tests.test_containers_integration_tests.helpers import generate_valid_password + + +class TestOpsService: + @pytest.fixture + def mock_external_service_dependencies(self): + with ( + patch("services.app_service.FeatureService") as mock_feature_service, + patch("services.app_service.EnterpriseService") as mock_enterprise_service, + patch("services.app_service.ModelManager.for_tenant") as mock_model_manager, + patch("services.account_service.FeatureService") as mock_account_feature_service, + ): + mock_feature_service.get_system_features.return_value.webapp_auth.enabled = False + mock_enterprise_service.WebAppAuth.update_app_access_mode.return_value = None + mock_enterprise_service.WebAppAuth.cleanup_webapp.return_value = None + mock_account_feature_service.get_system_features.return_value.is_allow_register = True + mock_model_instance = mock_model_manager.return_value + mock_model_instance.get_default_model_instance.return_value = None + mock_model_instance.get_default_provider_model_name.return_value = ("openai", "gpt-3.5-turbo") + yield { + "feature_service": mock_feature_service, + "enterprise_service": mock_enterprise_service, + "model_manager": mock_model_manager, + "account_feature_service": mock_account_feature_service, + } + + @pytest.fixture + def mock_ops_trace_manager(self): + with patch("services.ops_service.OpsTraceManager") as mock: + yield mock + + def _create_app(self, db_session_with_containers: Session, mock_external_service_dependencies): + fake = Faker() + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=generate_valid_password(fake), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + app_service = AppService() + app = app_service.create_app( + tenant.id, + { + "name": fake.company(), + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🤖", + "icon_background": "#FF6B6B", + }, + account, + ) + return app, account + + _SENTINEL = object() + + def _insert_trace_config( + self, + db_session: Session, + app_id: str, + provider: str, + tracing_config: dict | None | object = _SENTINEL, + ) -> TraceAppConfig: + trace_config = TraceAppConfig( + app_id=app_id, + tracing_provider=provider, + tracing_config=tracing_config if tracing_config is not self._SENTINEL else {"some": "config"}, + ) + db_session.add(trace_config) + db_session.commit() + return trace_config + + # ── get_tracing_app_config ───────────────────────────────────────── + + def test_get_tracing_app_config_no_config(self, db_session_with_containers: Session, mock_ops_trace_manager): + result = OpsService.get_tracing_app_config(str(uuid.uuid4()), "arize") + assert result is None + + def test_get_tracing_app_config_no_app(self, db_session_with_containers: Session, mock_ops_trace_manager): + fake_app_id = str(uuid.uuid4()) + self._insert_trace_config(db_session_with_containers, fake_app_id, "arize") + result = OpsService.get_tracing_app_config(fake_app_id, "arize") + assert result is None + + def test_get_tracing_app_config_none_config( + self, db_session_with_containers: Session, mock_external_service_dependencies, mock_ops_trace_manager + ): + app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies) + self._insert_trace_config(db_session_with_containers, app.id, "arize", tracing_config=None) + + with pytest.raises(ValueError, match="Tracing config cannot be None."): + OpsService.get_tracing_app_config(app.id, "arize") + + @pytest.mark.parametrize( + ("provider", "default_url"), + [ + ("arize", "https://app.arize.com/"), + ("phoenix", "https://app.phoenix.arize.com/projects/"), + ("langsmith", "https://smith.langchain.com/"), + ("opik", "https://www.comet.com/opik/"), + ("weave", "https://wandb.ai/"), + ("aliyun", "https://arms.console.aliyun.com/"), + ("tencent", "https://console.cloud.tencent.com/apm"), + ("mlflow", "http://localhost:5000/"), + ("databricks", "https://www.databricks.com/"), + ], + ) + def test_get_tracing_app_config_providers_exception( + self, db_session_with_containers: Session, mock_external_service_dependencies, provider, default_url + ): + with patch("services.ops_service.OpsTraceManager") as mock_otm: + mock_otm.decrypt_tracing_config.return_value = {} + mock_otm.obfuscated_decrypt_token.return_value = {} + mock_otm.get_trace_config_project_url.side_effect = Exception("error") + mock_otm.get_trace_config_project_key.side_effect = Exception("error") + + app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies) + self._insert_trace_config(db_session_with_containers, app.id, provider) + + result = OpsService.get_tracing_app_config(app.id, provider) + + assert result is not None + assert result["tracing_config"]["project_url"] == default_url + + @pytest.mark.parametrize( + "provider", + ["arize", "phoenix", "langsmith", "opik", "weave", "aliyun", "tencent", "mlflow", "databricks"], + ) + def test_get_tracing_app_config_providers_success( + self, db_session_with_containers: Session, mock_external_service_dependencies, provider + ): + with patch("services.ops_service.OpsTraceManager") as mock_otm: + mock_otm.decrypt_tracing_config.return_value = {} + mock_otm.obfuscated_decrypt_token.return_value = {"project_url": "success_url"} + mock_otm.get_trace_config_project_url.return_value = "success_url" + + app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies) + self._insert_trace_config(db_session_with_containers, app.id, provider) + + result = OpsService.get_tracing_app_config(app.id, provider) + + assert result is not None + assert result["tracing_config"]["project_url"] == "success_url" + + def test_get_tracing_app_config_langfuse_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + with patch("services.ops_service.OpsTraceManager") as mock_otm: + mock_otm.decrypt_tracing_config.return_value = {"host": "https://api.langfuse.com"} + mock_otm.obfuscated_decrypt_token.return_value = {"host": "https://api.langfuse.com"} + mock_otm.get_trace_config_project_key.return_value = "key" + + app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies) + self._insert_trace_config(db_session_with_containers, app.id, "langfuse") + + result = OpsService.get_tracing_app_config(app.id, "langfuse") + + assert result is not None + assert result["tracing_config"]["project_url"] == "https://api.langfuse.com/project/key" + + def test_get_tracing_app_config_langfuse_exception( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + with patch("services.ops_service.OpsTraceManager") as mock_otm: + mock_otm.decrypt_tracing_config.return_value = {"host": "https://api.langfuse.com"} + mock_otm.obfuscated_decrypt_token.return_value = {"host": "https://api.langfuse.com"} + mock_otm.get_trace_config_project_key.side_effect = Exception("error") + + app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies) + self._insert_trace_config(db_session_with_containers, app.id, "langfuse") + + result = OpsService.get_tracing_app_config(app.id, "langfuse") + + assert result is not None + assert result["tracing_config"]["project_url"] == "https://api.langfuse.com/" + + # ── create_tracing_app_config ────────────────────────────────────── + + def test_create_tracing_app_config_invalid_provider(self, db_session_with_containers: Session): + result = OpsService.create_tracing_app_config(str(uuid.uuid4()), "invalid_provider", {}) + assert result == {"error": "Invalid tracing provider: invalid_provider"} + + def test_create_tracing_app_config_invalid_credentials( + self, db_session_with_containers: Session, mock_ops_trace_manager + ): + mock_ops_trace_manager.check_trace_config_is_effective.return_value = False + result = OpsService.create_tracing_app_config( + str(uuid.uuid4()), TracingProviderEnum.LANGFUSE, {"public_key": "p", "secret_key": "s"} + ) + assert result == {"error": "Invalid Credentials"} + + @pytest.mark.parametrize( + ("provider", "config"), + [ + (TracingProviderEnum.ARIZE, {}), + (TracingProviderEnum.LANGFUSE, {"public_key": "p", "secret_key": "s"}), + (TracingProviderEnum.LANGSMITH, {"api_key": "k", "project": "p"}), + (TracingProviderEnum.ALIYUN, {"license_key": "k", "endpoint": "https://aliyun.com"}), + ], + ) + def test_create_tracing_app_config_project_url_exception( + self, db_session_with_containers: Session, mock_external_service_dependencies, provider, config + ): + # Existing config causes the service to return None before reaching the DB insert + with patch("services.ops_service.OpsTraceManager") as mock_otm: + mock_otm.check_trace_config_is_effective.return_value = True + mock_otm.get_trace_config_project_url.side_effect = Exception("error") + mock_otm.get_trace_config_project_key.side_effect = Exception("error") + + app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies) + self._insert_trace_config(db_session_with_containers, app.id, str(provider)) + + result = OpsService.create_tracing_app_config(app.id, provider, config) + + assert result is None + + def test_create_tracing_app_config_langfuse_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + with patch("services.ops_service.OpsTraceManager") as mock_otm: + mock_otm.check_trace_config_is_effective.return_value = True + mock_otm.get_trace_config_project_key.return_value = "key" + mock_otm.encrypt_tracing_config.return_value = {} + + app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies) + result = OpsService.create_tracing_app_config( + app.id, + TracingProviderEnum.LANGFUSE, + {"public_key": "p", "secret_key": "s", "host": "https://api.langfuse.com"}, + ) + + assert result == {"result": "success"} + + def test_create_tracing_app_config_already_exists( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + with patch("services.ops_service.OpsTraceManager") as mock_otm: + mock_otm.check_trace_config_is_effective.return_value = True + + app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies) + self._insert_trace_config(db_session_with_containers, app.id, str(TracingProviderEnum.ARIZE)) + + result = OpsService.create_tracing_app_config(app.id, TracingProviderEnum.ARIZE, {}) + + assert result is None + + def test_create_tracing_app_config_no_app(self, db_session_with_containers: Session, mock_ops_trace_manager): + mock_ops_trace_manager.check_trace_config_is_effective.return_value = True + result = OpsService.create_tracing_app_config(str(uuid.uuid4()), TracingProviderEnum.ARIZE, {}) + assert result is None + + def test_create_tracing_app_config_with_empty_other_keys( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + # "project" is in other_keys for Arize; providing "" triggers default substitution + with patch("services.ops_service.OpsTraceManager") as mock_otm: + mock_otm.check_trace_config_is_effective.return_value = True + mock_otm.get_trace_config_project_url.side_effect = Exception("no url") + mock_otm.encrypt_tracing_config.return_value = {} + + app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies) + result = OpsService.create_tracing_app_config(app.id, TracingProviderEnum.ARIZE, {"project": ""}) + + assert result == {"result": "success"} + + def test_create_tracing_app_config_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + with patch("services.ops_service.OpsTraceManager") as mock_otm: + mock_otm.check_trace_config_is_effective.return_value = True + mock_otm.get_trace_config_project_url.return_value = "http://project_url" + mock_otm.encrypt_tracing_config.return_value = {"encrypted": "config"} + + app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies) + result = OpsService.create_tracing_app_config(app.id, TracingProviderEnum.ARIZE, {}) + + assert result == {"result": "success"} + + # ── update_tracing_app_config ────────────────────────────────────── + + def test_update_tracing_app_config_invalid_provider(self, db_session_with_containers: Session): + with pytest.raises(ValueError, match="Invalid tracing provider: invalid_provider"): + OpsService.update_tracing_app_config(str(uuid.uuid4()), "invalid_provider", {}) + + def test_update_tracing_app_config_no_config(self, db_session_with_containers: Session, mock_ops_trace_manager): + result = OpsService.update_tracing_app_config(str(uuid.uuid4()), TracingProviderEnum.ARIZE, {}) + assert result is None + + def test_update_tracing_app_config_no_app(self, db_session_with_containers: Session, mock_ops_trace_manager): + fake_app_id = str(uuid.uuid4()) + self._insert_trace_config(db_session_with_containers, fake_app_id, str(TracingProviderEnum.ARIZE)) + mock_ops_trace_manager.encrypt_tracing_config.return_value = {} + result = OpsService.update_tracing_app_config(fake_app_id, TracingProviderEnum.ARIZE, {}) + assert result is None + + def test_update_tracing_app_config_invalid_credentials( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + with patch("services.ops_service.OpsTraceManager") as mock_otm: + mock_otm.encrypt_tracing_config.return_value = {} + mock_otm.decrypt_tracing_config.return_value = {} + mock_otm.check_trace_config_is_effective.return_value = False + + app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies) + self._insert_trace_config(db_session_with_containers, app.id, str(TracingProviderEnum.ARIZE)) + + with pytest.raises(ValueError, match="Invalid Credentials"): + OpsService.update_tracing_app_config(app.id, TracingProviderEnum.ARIZE, {}) + + def test_update_tracing_app_config_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + with patch("services.ops_service.OpsTraceManager") as mock_otm: + mock_otm.encrypt_tracing_config.return_value = {"updated": "config"} + mock_otm.decrypt_tracing_config.return_value = {} + mock_otm.check_trace_config_is_effective.return_value = True + + app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies) + self._insert_trace_config(db_session_with_containers, app.id, str(TracingProviderEnum.ARIZE)) + + result = OpsService.update_tracing_app_config(app.id, TracingProviderEnum.ARIZE, {}) + + assert result is not None + assert result["app_id"] == app.id + + # ── delete_tracing_app_config ────────────────────────────────────── + + def test_delete_tracing_app_config_no_config(self, db_session_with_containers: Session): + result = OpsService.delete_tracing_app_config(str(uuid.uuid4()), "arize") + assert result is None + + def test_delete_tracing_app_config_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies) + self._insert_trace_config(db_session_with_containers, app.id, "arize") + + result = OpsService.delete_tracing_app_config(app.id, "arize") + + assert result is True + remaining = db_session_with_containers.scalar( + select(TraceAppConfig) + .where(TraceAppConfig.app_id == app.id, TraceAppConfig.tracing_provider == "arize") + .limit(1) + ) + assert remaining is None diff --git a/api/tests/test_containers_integration_tests/services/test_webapp_auth_service.py b/api/tests/test_containers_integration_tests/services/test_webapp_auth_service.py index 4fe65d5803..7825f502f7 100644 --- a/api/tests/test_containers_integration_tests/services/test_webapp_auth_service.py +++ b/api/tests/test_containers_integration_tests/services/test_webapp_auth_service.py @@ -233,11 +233,10 @@ class TestWebAppAuthService: assert result.status == AccountStatus.ACTIVE # Verify database state - - db_session_with_containers.refresh(result) - assert result.id is not None - assert result.password is not None - assert result.password_salt is not None + refreshed = db_session_with_containers.get(Account, result.id) + assert refreshed is not None + assert refreshed.password is not None + assert refreshed.password_salt is not None def test_authenticate_account_not_found( self, db_session_with_containers: Session, mock_external_service_dependencies @@ -414,9 +413,8 @@ class TestWebAppAuthService: assert result.status == AccountStatus.ACTIVE # Verify database state - - db_session_with_containers.refresh(result) - assert result.id is not None + refreshed = db_session_with_containers.get(Account, result.id) + assert refreshed is not None def test_get_user_through_email_not_found( self, db_session_with_containers: Session, mock_external_service_dependencies diff --git a/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py b/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py index 159ab51304..4bc022c415 100644 --- a/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py +++ b/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py @@ -26,7 +26,7 @@ from datetime import timedelta import pytest from graphon.entities import WorkflowExecution from graphon.enums import WorkflowExecutionStatus -from sqlalchemy import delete, select +from sqlalchemy import delete, func, select from sqlalchemy.orm import Session, selectinload, sessionmaker from extensions.ext_storage import storage @@ -679,9 +679,12 @@ class TestWorkflowPauseIntegration: # Verify only 3 were deleted remaining_count = ( - self.session.query(WorkflowPauseModel) - .filter(WorkflowPauseModel.id.in_([pe.id for pe in pause_entities])) - .count() + self.session.scalar( + select(func.count(WorkflowPauseModel.id)).where( + WorkflowPauseModel.id.in_([pe.id for pe in pause_entities]) + ) + ) + or 0 ) assert remaining_count == 2 diff --git a/api/tests/unit_tests/controllers/console/datasets/test_external.py b/api/tests/unit_tests/controllers/console/datasets/test_external.py index 161d0c41e8..514bbbe040 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_external.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_external.py @@ -1,3 +1,4 @@ +from importlib import import_module from unittest.mock import MagicMock, PropertyMock, patch import pytest @@ -11,6 +12,7 @@ from controllers.console.datasets.external import ( BedrockRetrievalApi, ExternalApiTemplateApi, ExternalApiTemplateListApi, + ExternalApiUseCheckApi, ExternalDatasetCreateApi, ExternalKnowledgeHitTestingApi, ) @@ -19,6 +21,8 @@ from services.external_knowledge_service import ExternalDatasetService from services.hit_testing_service import HitTestingService from services.knowledge_service import ExternalDatasetTestService +external_controller = import_module("controllers.console.datasets.external") + def unwrap(func): while hasattr(func, "__wrapped__"): @@ -44,10 +48,11 @@ def current_user(): @pytest.fixture(autouse=True) -def mock_auth(mocker, current_user): - mocker.patch( - "controllers.console.datasets.external.current_account_with_tenant", - return_value=(current_user, "tenant-1"), +def mock_auth(monkeypatch, current_user): + monkeypatch.setattr( + external_controller, + "current_account_with_tenant", + lambda: (current_user, "tenant-1"), ) @@ -136,6 +141,26 @@ class TestExternalApiTemplateApi: method(api, "api-id") +class TestExternalApiUseCheckApi: + def test_get_scopes_usage_check_to_current_tenant(self, app): + api = ExternalApiUseCheckApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch.object( + ExternalDatasetService, + "external_knowledge_api_use_check", + return_value=(True, 2), + ) as mock_use_check, + ): + response, status = method(api, "api-id") + + assert status == 200 + assert response == {"is_using": True, "count": 2} + mock_use_check.assert_called_once_with("api-id", "tenant-1") + + class TestExternalDatasetCreateApi: def test_create_success(self, app): api = ExternalDatasetCreateApi() diff --git a/api/tests/unit_tests/controllers/console/test_workspace_account.py b/api/tests/unit_tests/controllers/console/test_workspace_account.py index 7f9fe9cbf9..dd643faac9 100644 --- a/api/tests/unit_tests/controllers/console/test_workspace_account.py +++ b/api/tests/unit_tests/controllers/console/test_workspace_account.py @@ -233,15 +233,20 @@ class TestCheckEmailUnique: def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup(): - session = MagicMock() + mock_session = MagicMock() first = MagicMock() first.scalar_one_or_none.return_value = None second = MagicMock() expected_account = MagicMock() second.scalar_one_or_none.return_value = expected_account - session.execute.side_effect = [first, second] + mock_session.execute.side_effect = [first, second] - result = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com", session=session) + mock_factory = MagicMock() + mock_factory.create_session.return_value.__enter__ = MagicMock(return_value=mock_session) + mock_factory.create_session.return_value.__exit__ = MagicMock(return_value=False) + + with patch("services.account_service.session_factory", mock_factory): + result = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com") assert result is expected_account - assert session.execute.call_count == 2 + assert mock_session.execute.call_count == 2 diff --git a/api/tests/unit_tests/core/mcp/auth/test_auth_flow.py b/api/tests/unit_tests/core/mcp/auth/test_auth_flow.py index fe533e62af..1f5fdd2657 100644 --- a/api/tests/unit_tests/core/mcp/auth/test_auth_flow.py +++ b/api/tests/unit_tests/core/mcp/auth/test_auth_flow.py @@ -862,6 +862,15 @@ class TestAuthOrchestration: result = discover_protected_resource_metadata(None, "https://api.example.com") assert result is None + # JSONDecodeError (non-JSON 200 response) + mock_get.side_effect = None + bad_json_response = Mock() + bad_json_response.status_code = 200 + bad_json_response.json.side_effect = json.JSONDecodeError("Expecting value", "", 0) + mock_get.return_value = bad_json_response + result = discover_protected_resource_metadata(None, "https://api.example.com") + assert result is None + @patch("core.helper.ssrf_proxy.get") def test_discover_oauth_authorization_server_metadata(self, mock_get): # Success @@ -892,6 +901,14 @@ class TestAuthOrchestration: result = discover_oauth_authorization_server_metadata(None, "https://api.example.com") assert result is None + # JSONDecodeError (non-JSON 200 response) + bad_json_response = Mock() + bad_json_response.status_code = 200 + bad_json_response.json.side_effect = json.JSONDecodeError("Expecting value", "", 0) + mock_get.return_value = bad_json_response + result = discover_oauth_authorization_server_metadata(None, "https://api.example.com") + assert result is None + def test_get_effective_scope(self): prm = ProtectedResourceMetadata( resource="https://api.example.com", @@ -997,6 +1014,24 @@ class TestAuthOrchestration: supported, url = check_support_resource_discovery("https://api") assert supported is False + # Case 6: JSONDecodeError (non-JSON 200 response) + mock_get.side_effect = None + bad_json_res = Mock() + bad_json_res.status_code = 200 + bad_json_res.json.side_effect = json.JSONDecodeError("Expecting value", "", 0) + mock_get.return_value = bad_json_res + supported, url = check_support_resource_discovery("https://api") + assert supported is False + assert url == "" + + # Case 7: Empty authorization_servers array (IndexError) + empty_res = Mock() + empty_res.status_code = 200 + empty_res.json.return_value = {"authorization_servers": []} + mock_get.return_value = empty_res + supported, url = check_support_resource_discovery("https://api") + assert supported is False + def test_discover_oauth_metadata(self): with patch("core.mcp.auth.auth_flow.discover_protected_resource_metadata") as mock_prm: with patch("core.mcp.auth.auth_flow.discover_oauth_authorization_server_metadata") as mock_asm: diff --git a/api/tests/unit_tests/core/mcp/test_entities.py b/api/tests/unit_tests/core/mcp/test_entities.py index 3fede55916..e99c38285c 100644 --- a/api/tests/unit_tests/core/mcp/test_entities.py +++ b/api/tests/unit_tests/core/mcp/test_entities.py @@ -4,9 +4,7 @@ from unittest.mock import Mock from core.mcp.entities import ( SUPPORTED_PROTOCOL_VERSIONS, - LifespanContextT, RequestContext, - SessionT, ) from core.mcp.session.base_session import BaseSession from core.mcp.types import LATEST_PROTOCOL_VERSION, RequestParams @@ -198,42 +196,3 @@ class TestRequestContext: assert "RequestContext" in repr_str assert "test-123" in repr_str assert "MockSession" in repr_str - - -class TestTypeVariables: - """Test type variables defined in the module.""" - - def test_session_type_var(self): - """Test SessionT type variable.""" - - # Create a custom session class - class CustomSession(BaseSession): - pass - - # Use in generic context - def process_session(session: SessionT) -> SessionT: - return session - - mock_session = Mock(spec=CustomSession) - result = process_session(mock_session) - assert result == mock_session - - def test_lifespan_context_type_var(self): - """Test LifespanContextT type variable.""" - - # Use in generic context - def process_lifespan(context: LifespanContextT) -> LifespanContextT: - return context - - # Test with different types - str_context = "string-context" - assert process_lifespan(str_context) == str_context - - dict_context = {"key": "value"} - assert process_lifespan(dict_context) == dict_context - - class CustomContext: - pass - - custom_context = CustomContext() - assert process_lifespan(custom_context) == custom_context diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/relyt/test_relyt_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/relyt/test_relyt_vector.py index ca8cd5e514..43cdb4948d 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/relyt/test_relyt_vector.py +++ b/api/tests/unit_tests/core/rag/datasource/vdb/relyt/test_relyt_vector.py @@ -39,6 +39,25 @@ class _FakeSession: return None +class _FakeBeginContext: + def __init__(self, session): + self._session = session + + def __enter__(self): + return self._session + + def __exit__(self, exc_type, exc, tb): + return None + + +def _patch_both(monkeypatch, module, session): + """Patch both Session and sessionmaker on the module.""" + monkeypatch.setattr(module, "Session", lambda _client: session) + monkeypatch.setattr( + module, "sessionmaker", lambda **kwargs: MagicMock(begin=MagicMock(return_value=_FakeBeginContext(session))) + ) + + @pytest.fixture def relyt_module(monkeypatch): for name, module in _build_fake_relyt_modules().items(): @@ -108,13 +127,13 @@ def test_create_collection_cache_and_sql_execution(relyt_module, monkeypatch): monkeypatch.setattr(relyt_module.redis_client, "get", MagicMock(return_value=1)) session = _FakeSession() - monkeypatch.setattr(relyt_module, "Session", lambda _client: session) + _patch_both(monkeypatch, relyt_module, session) vector.create_collection(3) session.execute.assert_not_called() monkeypatch.setattr(relyt_module.redis_client, "get", MagicMock(return_value=None)) session = _FakeSession() - monkeypatch.setattr(relyt_module, "Session", lambda _client: session) + _patch_both(monkeypatch, relyt_module, session) vector.create_collection(3) executed_sql = [str(call.args[0]) for call in session.execute.call_args_list] assert any("DROP TABLE IF EXISTS" in sql for sql in executed_sql) @@ -265,15 +284,15 @@ def test_search_by_vector_filters_by_score_and_ids(relyt_module): # 8. delete commits session -def test_delete_commits_session(relyt_module, monkeypatch): +def test_delete_drops_table(relyt_module, monkeypatch): vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector) vector._collection_name = "collection_1" vector.client = MagicMock() vector.embedding_dimension = 3 session = _FakeSession() - monkeypatch.setattr(relyt_module, "Session", lambda _client: session) + _patch_both(monkeypatch, relyt_module, session) vector.delete() - session.commit.assert_called_once() + session.execute.assert_called_once() def test_relyt_factory_existing_and_generated_collection(relyt_module, monkeypatch): diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/test_vector_factory.py b/api/tests/unit_tests/core/rag/datasource/vdb/test_vector_factory.py index 4e9ceddda9..5a0e4dcd75 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/test_vector_factory.py +++ b/api/tests/unit_tests/core/rag/datasource/vdb/test_vector_factory.py @@ -121,7 +121,18 @@ def test_vector_init_uses_default_and_custom_attributes(vector_factory_module): default_vector = vector_factory_module.Vector(dataset) custom_vector = vector_factory_module.Vector(dataset, attributes=["doc_id"]) - assert default_vector._attributes == ["doc_id", "dataset_id", "document_id", "doc_hash", "doc_type"] + # `is_summary` and `original_chunk_id` must be in the default return-properties + # projection so summary index retrieval works on backends that honor the list + # as an explicit projection (e.g. Weaviate). See #34884. + assert default_vector._attributes == [ + "doc_id", + "dataset_id", + "document_id", + "doc_hash", + "doc_type", + "is_summary", + "original_chunk_id", + ] assert custom_vector._attributes == ["doc_id"] assert default_vector._embeddings == "embeddings" assert default_vector._vector_processor == "processor" diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/tidb_vector/test_tidb_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/tidb_vector/test_tidb_vector.py index 951a920f3b..8e19a59af8 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/tidb_vector/test_tidb_vector.py +++ b/api/tests/unit_tests/core/rag/datasource/vdb/tidb_vector/test_tidb_vector.py @@ -137,14 +137,15 @@ def test_create_collection_executes_create_sql_and_sets_cache(tidb_module, monke session = MagicMock() - class _SessionCtx: + class _BeginCtx: def __enter__(self): return session def __exit__(self, exc_type, exc, tb): return False - monkeypatch.setattr(tidb_module, "Session", lambda _engine: _SessionCtx()) + mock_sm = MagicMock(begin=MagicMock(return_value=_BeginCtx())) + monkeypatch.setattr(tidb_module, "sessionmaker", lambda **kwargs: mock_sm) vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector) vector._collection_name = "collection_1" @@ -153,11 +154,9 @@ def test_create_collection_executes_create_sql_and_sets_cache(tidb_module, monke vector._create_collection(3) - session.begin.assert_called_once() sql = str(session.execute.call_args.args[0]) assert "VECTOR(3)" in sql assert "VEC_L2_DISTANCE" in sql - session.commit.assert_called_once() tidb_module.redis_client.set.assert_called_once() @@ -396,23 +395,22 @@ def test_search_by_vector_filters_and_scores(tidb_module, monkeypatch): def test_delete_drops_table(tidb_module, monkeypatch): session = MagicMock() session.execute.return_value = None - session.commit = MagicMock() - class _SessionCtx: + class _BeginCtx: def __enter__(self): return session def __exit__(self, exc_type, exc, tb): return False - monkeypatch.setattr(tidb_module, "Session", lambda _engine: _SessionCtx()) + mock_sm = MagicMock(begin=MagicMock(return_value=_BeginCtx())) + monkeypatch.setattr(tidb_module, "sessionmaker", lambda **kwargs: mock_sm) vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector) vector._collection_name = "collection_1" vector._engine = MagicMock() vector.delete() drop_sql = str(session.execute.call_args.args[0]) assert "DROP TABLE IF EXISTS collection_1" in drop_sql - session.commit.assert_called_once() def test_tidb_factory_uses_existing_or_generated_collection(tidb_module, monkeypatch): diff --git a/api/tests/unit_tests/core/rag/indexing/processor/test_parent_child_index_processor.py b/api/tests/unit_tests/core/rag/indexing/processor/test_parent_child_index_processor.py index c241b44d52..8ef0e046ef 100644 --- a/api/tests/unit_tests/core/rag/indexing/processor/test_parent_child_index_processor.py +++ b/api/tests/unit_tests/core/rag/indexing/processor/test_parent_child_index_processor.py @@ -258,10 +258,10 @@ class TestParentChildIndexProcessor: session.commit.assert_called_once() def test_clean_deletes_summaries_when_requested(self, processor: ParentChildIndexProcessor, dataset: Mock) -> None: - segment_query = Mock() - segment_query.filter.return_value.all.return_value = [SimpleNamespace(id="seg-1")] + scalars_result = Mock() + scalars_result.all.return_value = [SimpleNamespace(id="seg-1")] session = Mock() - session.query.return_value = segment_query + session.scalars.return_value = scalars_result session_ctx = MagicMock() session_ctx.__enter__.return_value = session session_ctx.__exit__.return_value = False diff --git a/api/tests/unit_tests/core/rag/indexing/processor/test_qa_index_processor.py b/api/tests/unit_tests/core/rag/indexing/processor/test_qa_index_processor.py index 98c47bec8f..b1b1835a52 100644 --- a/api/tests/unit_tests/core/rag/indexing/processor/test_qa_index_processor.py +++ b/api/tests/unit_tests/core/rag/indexing/processor/test_qa_index_processor.py @@ -220,10 +220,10 @@ class TestQAIndexProcessor: self, processor: QAIndexProcessor, dataset: Mock ) -> None: mock_segment = SimpleNamespace(id="seg-1") - mock_query = Mock() - mock_query.filter.return_value.all.return_value = [mock_segment] + scalars_result = Mock() + scalars_result.all.return_value = [mock_segment] mock_session = Mock() - mock_session.query.return_value = mock_query + mock_session.scalars.return_value = scalars_result session_context = MagicMock() session_context.__enter__.return_value = mock_session session_context.__exit__.return_value = False diff --git a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py index b98fec3854..1b17cbc368 100644 --- a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py +++ b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py @@ -8,7 +8,6 @@ import pytest from flask import Flask, current_app from graphon.model_runtime.entities.llm_entities import LLMUsage from graphon.model_runtime.entities.model_entities import ModelFeature -from sqlalchemy import column from core.app.app_config.entities import ( DatasetEntity, @@ -4039,21 +4038,9 @@ class TestDatasetRetrievalAdditionalHelpers: def test_get_available_datasets(self, retrieval: DatasetRetrieval) -> None: session = Mock() - subquery_query = Mock() - subquery_query.where.return_value = subquery_query - subquery_query.group_by.return_value = subquery_query - subquery_query.having.return_value = subquery_query - subquery_query.subquery.return_value = SimpleNamespace( - c=SimpleNamespace( - dataset_id=column("dataset_id"), available_document_count=column("available_document_count") - ) - ) - - dataset_query = Mock() - dataset_query.outerjoin.return_value = dataset_query - dataset_query.where.return_value = dataset_query - dataset_query.all.return_value = [SimpleNamespace(id="d1"), None, SimpleNamespace(id="d2")] - session.query.side_effect = [subquery_query, dataset_query] + scalars_result = Mock() + scalars_result.all.return_value = [SimpleNamespace(id="d1"), None, SimpleNamespace(id="d2")] + session.scalars.return_value = scalars_result session_ctx = MagicMock() session_ctx.__enter__.return_value = session @@ -4902,9 +4889,6 @@ class TestInternalHooksCoverage: _scalars(segments), _scalars(bindings), ] - query = Mock() - query.where.return_value = query - session.query.return_value = query session_ctx = MagicMock() session_ctx.__enter__.return_value = session session_ctx.__exit__.return_value = False @@ -4919,7 +4903,7 @@ class TestInternalHooksCoverage: ): retrieval._on_retrieval_end(flask_app=app, documents=docs, message_id="m1", timer={"cost": 1}) - query.update.assert_called_once() + session.execute.assert_called_once() mock_trace.assert_called_once() def test_retriever_variants(self, retrieval: DatasetRetrieval) -> None: diff --git a/api/tests/unit_tests/core/tools/test_tool_engine.py b/api/tests/unit_tests/core/tools/test_tool_engine.py index 40c107667c..cd16557ef6 100644 --- a/api/tests/unit_tests/core/tools/test_tool_engine.py +++ b/api/tests/unit_tests/core/tools/test_tool_engine.py @@ -260,6 +260,28 @@ def test_agent_invoke_engine_meta_error(): assert error_meta.error == "meta failure" +def test_convert_tool_response_excludes_variable_messages(): + """Regression test for issue #34723. + + WorkflowTool._invoke yields VARIABLE, TEXT, and suppressed-JSON messages. + _convert_tool_response_to_str must skip VARIABLE messages so that the + returned string contains only the TEXT representation and not a + duplicated, garbled Pydantic repr of the same data. + """ + tool = _build_tool() + outputs = {"reports": "hello"} + messages = [ + tool.create_variable_message(variable_name="reports", variable_value="hello"), + tool.create_text_message('{"reports": "hello"}'), + tool.create_json_message(outputs, suppress_output=True), + ] + + result = ToolEngine._convert_tool_response_to_str(messages) + + assert result == '{"reports": "hello"}' + assert "variable_name" not in result + + def test_agent_invoke_tool_invoke_error(): tool = _build_tool(with_llm_parameter=True) callback = Mock() diff --git a/api/tests/unit_tests/core/tools/test_tool_manager.py b/api/tests/unit_tests/core/tools/test_tool_manager.py index 31b68f0b3f..9ebaa0417b 100644 --- a/api/tests/unit_tests/core/tools/test_tool_manager.py +++ b/api/tests/unit_tests/core/tools/test_tool_manager.py @@ -637,7 +637,7 @@ def test_list_default_builtin_providers_for_postgres_and_mysql(): for scheme in ("postgresql", "mysql"): session = Mock() session.execute.return_value.all.return_value = [SimpleNamespace(id="id-1"), SimpleNamespace(id="id-2")] - session.query.return_value.where.return_value.all.return_value = provider_records + session.scalars.return_value = iter(provider_records) with patch("core.tools.tool_manager.dify_config", SimpleNamespace(SQLALCHEMY_DATABASE_URI_SCHEME=scheme)): with patch("core.tools.tool_manager.db") as mock_db: diff --git a/api/tests/unit_tests/extensions/otel/decorators/handlers/test_generate_handler.py b/api/tests/unit_tests/extensions/otel/decorators/handlers/test_generate_handler.py index f7475f2239..12e91f190f 100644 --- a/api/tests/unit_tests/extensions/otel/decorators/handlers/test_generate_handler.py +++ b/api/tests/unit_tests/extensions/otel/decorators/handlers/test_generate_handler.py @@ -39,7 +39,7 @@ class TestAppGenerateHandler: "root_node_id": None, } - arguments = handler._extract_arguments(AppGenerateService.generate, (), kwargs) + arguments = handler._extract_arguments(AppGenerateService.generate, **kwargs) assert arguments is not None, "Failed to extract arguments from AppGenerateService.generate" assert "app_model" in arguments, "Handler uses app_model but parameter is missing" @@ -70,14 +70,11 @@ class TestAppGenerateHandler: handler.wrapper( tracer, dummy_func, - (), - { - "app_model": mock_app_model, - "user": mock_account_user, - "args": {"workflow_id": test_workflow_id}, - "invoke_from": InvokeFrom.DEBUGGER, - "streaming": False, - }, + app_model=mock_app_model, + user=mock_account_user, + args={"workflow_id": test_workflow_id}, + invoke_from=InvokeFrom.DEBUGGER, + streaming=False, ) spans = memory_span_exporter.get_finished_spans() diff --git a/api/tests/unit_tests/extensions/otel/decorators/handlers/test_workflow_app_runner_handler.py b/api/tests/unit_tests/extensions/otel/decorators/handlers/test_workflow_app_runner_handler.py index 500f80fc3c..842e7f55e2 100644 --- a/api/tests/unit_tests/extensions/otel/decorators/handlers/test_workflow_app_runner_handler.py +++ b/api/tests/unit_tests/extensions/otel/decorators/handlers/test_workflow_app_runner_handler.py @@ -63,7 +63,7 @@ class TestWorkflowAppRunnerHandler: def runner_run(self): return "result" - handler.wrapper(tracer, runner_run, (mock_workflow_runner,), {}) + handler.wrapper(tracer, runner_run, mock_workflow_runner) spans = memory_span_exporter.get_finished_spans() assert len(spans) == 1 diff --git a/api/tests/unit_tests/extensions/otel/decorators/test_handler.py b/api/tests/unit_tests/extensions/otel/decorators/test_handler.py index 44788bab9a..bf861e3ef7 100644 --- a/api/tests/unit_tests/extensions/otel/decorators/test_handler.py +++ b/api/tests/unit_tests/extensions/otel/decorators/test_handler.py @@ -28,7 +28,7 @@ class TestSpanHandlerExtractArguments: args = (1, 2, 3) kwargs = {} - result = handler._extract_arguments(func, args, kwargs) + result = handler._extract_arguments(func, *args, **kwargs) assert result is not None assert result["a"] == 1 @@ -44,7 +44,7 @@ class TestSpanHandlerExtractArguments: args = () kwargs = {"a": 1, "b": 2, "c": 3} - result = handler._extract_arguments(func, args, kwargs) + result = handler._extract_arguments(func, *args, **kwargs) assert result is not None assert result["a"] == 1 @@ -60,7 +60,7 @@ class TestSpanHandlerExtractArguments: args = (1,) kwargs = {"b": 2, "c": 3} - result = handler._extract_arguments(func, args, kwargs) + result = handler._extract_arguments(func, *args, **kwargs) assert result is not None assert result["a"] == 1 @@ -76,7 +76,7 @@ class TestSpanHandlerExtractArguments: args = (1,) kwargs = {} - result = handler._extract_arguments(func, args, kwargs) + result = handler._extract_arguments(func, *args, **kwargs) assert result is not None assert result["a"] == 1 @@ -94,7 +94,7 @@ class TestSpanHandlerExtractArguments: instance = MyClass() args = (1, 2) kwargs = {} - result = handler._extract_arguments(instance.method, args, kwargs) + result = handler._extract_arguments(instance.method, *args, **kwargs) assert result is not None assert result["a"] == 1 @@ -109,7 +109,7 @@ class TestSpanHandlerExtractArguments: args = (1,) kwargs = {} - result = handler._extract_arguments(func, args, kwargs) + result = handler._extract_arguments(func, *args, **kwargs) assert result is None @@ -122,11 +122,11 @@ class TestSpanHandlerExtractArguments: assert func not in handler._signature_cache - handler._extract_arguments(func, (1, 2), {}) + handler._extract_arguments(func, 1, 2) assert func in handler._signature_cache cached_sig = handler._signature_cache[func] - handler._extract_arguments(func, (3, 4), {}) + handler._extract_arguments(func, 3, 4) assert handler._signature_cache[func] is cached_sig @@ -142,7 +142,7 @@ class TestSpanHandlerWrapper: def test_func(): return "result" - result = handler.wrapper(tracer, test_func, (), {}) + result = handler.wrapper(tracer, test_func) assert result == "result" spans = memory_span_exporter.get_finished_spans() @@ -159,7 +159,7 @@ class TestSpanHandlerWrapper: def test_func(): return "result" - handler.wrapper(tracer, test_func, (), {}) + handler.wrapper(tracer, test_func) spans = memory_span_exporter.get_finished_spans() assert len(spans) == 1 @@ -174,7 +174,7 @@ class TestSpanHandlerWrapper: def test_func(): return "result" - handler.wrapper(tracer, test_func, (), {}) + handler.wrapper(tracer, test_func) spans = memory_span_exporter.get_finished_spans() assert len(spans) == 1 @@ -190,7 +190,7 @@ class TestSpanHandlerWrapper: raise ValueError("test error") with pytest.raises(ValueError, match="test error"): - handler.wrapper(tracer, test_func, (), {}) + handler.wrapper(tracer, test_func) spans = memory_span_exporter.get_finished_spans() assert len(spans) == 1 @@ -208,7 +208,7 @@ class TestSpanHandlerWrapper: raise ValueError("test error") with pytest.raises(ValueError): - handler.wrapper(tracer, test_func, (), {}) + handler.wrapper(tracer, test_func) spans = memory_span_exporter.get_finished_spans() assert len(spans) == 1 @@ -225,7 +225,7 @@ class TestSpanHandlerWrapper: raise ValueError("test error") with pytest.raises(ValueError, match="test error"): - handler.wrapper(tracer, test_func, (), {}) + handler.wrapper(tracer, test_func) @patch("extensions.otel.decorators.base.dify_config.ENABLE_OTEL", True) def test_wrapper_passes_arguments_correctly(self, tracer_provider_with_memory_exporter, memory_span_exporter): @@ -236,7 +236,7 @@ class TestSpanHandlerWrapper: def test_func(a, b, c=10): return a + b + c - result = handler.wrapper(tracer, test_func, (1, 2), {"c": 3}) + result = handler.wrapper(tracer, test_func, 1, 2, c=3) assert result == 6 @@ -249,7 +249,7 @@ class TestSpanHandlerWrapper: def my_function(x): return x * 2 - result = handler.wrapper(tracer, my_function, (5,), {}) + result = handler.wrapper(tracer, my_function, 5) assert result == 10 spans = memory_span_exporter.get_finished_spans() diff --git a/api/tests/unit_tests/libs/test_pyrefly_type_coverage.py b/api/tests/unit_tests/libs/test_pyrefly_type_coverage.py new file mode 100644 index 0000000000..7087490845 --- /dev/null +++ b/api/tests/unit_tests/libs/test_pyrefly_type_coverage.py @@ -0,0 +1,138 @@ +import json + +from libs.pyrefly_type_coverage import ( + CoverageSummary, + format_comparison_markdown, + format_summary_markdown, + parse_summary, +) + + +def _make_report(summary: dict) -> str: + return json.dumps({"module_reports": [], "summary": summary}) + + +_SAMPLE_SUMMARY: dict = { + "n_modules": 100, + "n_typable": 1000, + "n_typed": 400, + "n_any": 50, + "n_untyped": 550, + "coverage": 45.0, + "strict_coverage": 40.0, + "n_functions": 200, + "n_methods": 300, + "n_function_params": 150, + "n_method_params": 250, + "n_classes": 80, + "n_attrs": 40, + "n_properties": 20, + "n_type_ignores": 10, +} + + +def _make_summary( + *, + n_modules: int = 100, + n_typable: int = 1000, + n_typed: int = 400, + n_any: int = 50, + n_untyped: int = 550, + coverage: float = 45.0, + strict_coverage: float = 40.0, +) -> CoverageSummary: + return { + "n_modules": n_modules, + "n_typable": n_typable, + "n_typed": n_typed, + "n_any": n_any, + "n_untyped": n_untyped, + "coverage": coverage, + "strict_coverage": strict_coverage, + } + + +def test_parse_summary_extracts_fields() -> None: + report_json = _make_report(_SAMPLE_SUMMARY) + + result = parse_summary(report_json) + + assert result["n_modules"] == 100 + assert result["n_typable"] == 1000 + assert result["n_typed"] == 400 + assert result["n_any"] == 50 + assert result["n_untyped"] == 550 + assert result["coverage"] == 45.0 + assert result["strict_coverage"] == 40.0 + + +def test_parse_summary_handles_empty_input() -> None: + assert parse_summary("")["n_modules"] == 0 + assert parse_summary(" ")["n_modules"] == 0 + + +def test_parse_summary_handles_invalid_json() -> None: + assert parse_summary("not json")["n_modules"] == 0 + + +def test_parse_summary_handles_missing_summary_key() -> None: + assert parse_summary(json.dumps({"other": 1}))["n_modules"] == 0 + + +def test_parse_summary_handles_incomplete_summary() -> None: + partial = json.dumps({"summary": {"n_modules": 5}}) + assert parse_summary(partial)["n_modules"] == 0 + + +def test_format_summary_markdown_contains_key_metrics() -> None: + summary = _make_summary() + + result = format_summary_markdown(summary) + + assert "**Type coverage**" in result + assert "45.00%" in result + assert "40.00%" in result + assert "| Modules | 100 |" in result + + +def test_format_comparison_markdown_shows_positive_delta() -> None: + base = _make_summary() + pr = _make_summary( + n_modules=101, + n_typable=1010, + n_typed=420, + n_untyped=540, + coverage=46.53, + strict_coverage=41.58, + ) + + result = format_comparison_markdown(base, pr) + + assert "| Base | PR | Delta |" in result + assert "+1.53%" in result + assert "+1.58%" in result + assert "+20" in result + + +def test_format_comparison_markdown_shows_negative_delta() -> None: + base = _make_summary() + pr = _make_summary( + n_typed=390, + n_any=60, + coverage=44.0, + strict_coverage=39.0, + ) + + result = format_comparison_markdown(base, pr) + + assert "-1.00%" in result + assert "-10" in result + + +def test_format_comparison_markdown_shows_zero_delta() -> None: + summary = _make_summary() + + result = format_comparison_markdown(summary, summary) + + assert "0.00%" in result + assert "| 0 |" in result diff --git a/api/tests/unit_tests/models/test_account_models.py b/api/tests/unit_tests/models/test_account_models.py index f48db77bb5..25933dd15b 100644 --- a/api/tests/unit_tests/models/test_account_models.py +++ b/api/tests/unit_tests/models/test_account_models.py @@ -12,7 +12,6 @@ This test suite covers: import base64 import secrets from datetime import UTC, datetime -from unittest.mock import MagicMock, patch from uuid import uuid4 import pytest @@ -310,90 +309,6 @@ class TestAccountStatusTransitions: class TestTenantRelationshipIntegrity: """Test suite for tenant relationship integrity.""" - @patch("models.account.db") - def test_account_current_tenant_property(self, mock_db): - """Test the current_tenant property getter.""" - # Arrange - account = Account( - name="Test User", - email="test@example.com", - ) - account.id = str(uuid4()) - - tenant = Tenant(name="Test Tenant") - tenant.id = str(uuid4()) - - account._current_tenant = tenant - - # Act - result = account.current_tenant - - # Assert - assert result == tenant - - @patch("models.account.Session") - @patch("models.account.db") - def test_account_current_tenant_setter_with_valid_tenant(self, mock_db, mock_session_class): - """Test setting current_tenant with a valid tenant relationship.""" - # Arrange - account = Account( - name="Test User", - email="test@example.com", - ) - account.id = str(uuid4()) - - tenant = Tenant(name="Test Tenant") - tenant.id = str(uuid4()) - - # Mock the session and queries - mock_session = MagicMock() - mock_session_class.return_value.__enter__.return_value = mock_session - - # Mock TenantAccountJoin query result - tenant_join = TenantAccountJoin( - tenant_id=tenant.id, - account_id=account.id, - role=TenantAccountRole.OWNER, - ) - mock_session.scalar.return_value = tenant_join - - # Mock Tenant query result - mock_session.scalars.return_value.one.return_value = tenant - - # Act - account.current_tenant = tenant - - # Assert - assert account._current_tenant == tenant - assert account.role == TenantAccountRole.OWNER - - @patch("models.account.Session") - @patch("models.account.db") - def test_account_current_tenant_setter_without_relationship(self, mock_db, mock_session_class): - """Test setting current_tenant when no relationship exists.""" - # Arrange - account = Account( - name="Test User", - email="test@example.com", - ) - account.id = str(uuid4()) - - tenant = Tenant(name="Test Tenant") - tenant.id = str(uuid4()) - - # Mock the session and queries - mock_session = MagicMock() - mock_session_class.return_value.__enter__.return_value = mock_session - - # Mock no TenantAccountJoin found - mock_session.scalar.return_value = None - - # Act - account.current_tenant = tenant - - # Assert - assert account._current_tenant is None - def test_account_current_tenant_id_property(self): """Test the current_tenant_id property.""" # Arrange @@ -418,61 +333,6 @@ class TestTenantRelationshipIntegrity: # Assert assert tenant_id_none is None - @patch("models.account.Session") - @patch("models.account.db") - def test_account_set_tenant_id_method(self, mock_db, mock_session_class): - """Test the set_tenant_id method.""" - # Arrange - account = Account( - name="Test User", - email="test@example.com", - ) - account.id = str(uuid4()) - - tenant = Tenant(name="Test Tenant") - tenant.id = str(uuid4()) - - tenant_join = TenantAccountJoin( - tenant_id=tenant.id, - account_id=account.id, - role=TenantAccountRole.ADMIN, - ) - - # Mock the session and queries - mock_session = MagicMock() - mock_session_class.return_value.__enter__.return_value = mock_session - mock_session.execute.return_value.first.return_value = (tenant, tenant_join) - - # Act - account.set_tenant_id(tenant.id) - - # Assert - assert account._current_tenant == tenant - assert account.role == TenantAccountRole.ADMIN - - @patch("models.account.Session") - @patch("models.account.db") - def test_account_set_tenant_id_with_no_relationship(self, mock_db, mock_session_class): - """Test set_tenant_id when no relationship exists.""" - # Arrange - account = Account( - name="Test User", - email="test@example.com", - ) - account.id = str(uuid4()) - tenant_id = str(uuid4()) - - # Mock the session and queries - mock_session = MagicMock() - mock_session_class.return_value.__enter__.return_value = mock_session - mock_session.execute.return_value.first.return_value = None - - # Act - account.set_tenant_id(tenant_id) - - # Assert - should not set tenant when no relationship exists - # The method returns early without setting _current_tenant - class TestAccountRolePermissions: """Test suite for account role permissions.""" @@ -605,51 +465,6 @@ class TestAccountRolePermissions: assert current_role == TenantAccountRole.EDITOR -class TestAccountGetByOpenId: - """Test suite for get_by_openid class method.""" - - @patch("models.account.db") - def test_get_by_openid_success(self, mock_db): - """Test successful retrieval of account by OpenID.""" - # Arrange - provider = "google" - open_id = "google_user_123" - account_id = str(uuid4()) - - mock_account_integrate = MagicMock() - mock_account_integrate.account_id = account_id - - mock_account = Account(name="Test User", email="test@example.com") - mock_account.id = account_id - - # Mock db.session.execute().scalar_one_or_none() for AccountIntegrate lookup - mock_db.session.execute.return_value.scalar_one_or_none.return_value = mock_account_integrate - # Mock db.session.scalar() for Account lookup - mock_db.session.scalar.return_value = mock_account - - # Act - result = Account.get_by_openid(provider, open_id) - - # Assert - assert result == mock_account - - @patch("models.account.db") - def test_get_by_openid_not_found(self, mock_db): - """Test get_by_openid when account integrate doesn't exist.""" - # Arrange - provider = "github" - open_id = "github_user_456" - - # Mock db.session.execute().scalar_one_or_none() to return None - mock_db.session.execute.return_value.scalar_one_or_none.return_value = None - - # Act - result = Account.get_by_openid(provider, open_id) - - # Assert - assert result is None - - class TestTenantAccountJoinModel: """Test suite for TenantAccountJoin model.""" @@ -760,31 +575,6 @@ class TestTenantModel: # Assert assert tenant.custom_config == '{"feature1": true, "feature2": "value"}' - @patch("models.account.db") - def test_tenant_get_accounts(self, mock_db): - """Test getting accounts associated with a tenant.""" - # Arrange - tenant = Tenant(name="Test Workspace") - tenant.id = str(uuid4()) - - account1 = Account(name="User 1", email="user1@example.com") - account1.id = str(uuid4()) - account2 = Account(name="User 2", email="user2@example.com") - account2.id = str(uuid4()) - - # Mock the query chain - mock_scalars = MagicMock() - mock_scalars.all.return_value = [account1, account2] - mock_db.session.scalars.return_value = mock_scalars - - # Act - accounts = tenant.get_accounts() - - # Assert - assert len(accounts) == 2 - assert account1 in accounts - assert account2 in accounts - class TestTenantStatusEnum: """Test suite for TenantStatus enum.""" diff --git a/api/tests/unit_tests/models/test_app_models.py b/api/tests/unit_tests/models/test_app_models.py index 59597fb8cd..4e46cf9654 100644 --- a/api/tests/unit_tests/models/test_app_models.py +++ b/api/tests/unit_tests/models/test_app_models.py @@ -291,24 +291,6 @@ class TestAppModelConfig: # Assert assert result == questions - def test_app_model_config_annotation_reply_dict_disabled(self): - """Test annotation_reply_dict when annotation is disabled.""" - # Arrange - config = AppModelConfig( - app_id=str(uuid4()), - provider="openai", - model_id="gpt-4", - created_by=str(uuid4()), - ) - - # Mock database scalar to return None (no annotation setting found) - with patch("models.model.db.session.scalar", return_value=None): - # Act - result = config.annotation_reply_dict - - # Assert - assert result == {"enabled": False} - class TestConversationModel: """Test suite for Conversation model integrity.""" @@ -948,17 +930,6 @@ class TestSiteModel: with pytest.raises(ValueError, match="Custom disclaimer cannot exceed 512 characters"): site.custom_disclaimer = long_disclaimer - def test_site_generate_code(self): - """Test Site.generate_code static method.""" - # Mock database scalar to return 0 (no existing codes) - with patch("models.model.db.session.scalar", return_value=0): - # Act - code = Site.generate_code(8) - - # Assert - assert isinstance(code, str) - assert len(code) == 8 - class TestModelIntegration: """Test suite for model integration scenarios.""" @@ -1146,314 +1117,3 @@ class TestModelIntegration: # Assert assert site.app_id == app.id assert app.enable_site is True - - -class TestConversationStatusCount: - """Test suite for Conversation.status_count property N+1 query fix.""" - - def test_status_count_no_messages(self): - """Test status_count returns None when conversation has no messages.""" - # Arrange - conversation = Conversation( - app_id=str(uuid4()), - mode=AppMode.CHAT, - name="Test Conversation", - status="normal", - from_source=ConversationFromSource.API, - ) - conversation.id = str(uuid4()) - - # Mock the database query to return no messages - with patch("models.model.db.session.scalars") as mock_scalars: - mock_scalars.return_value.all.return_value = [] - - # Act - result = conversation.status_count - - # Assert - assert result is None - - def test_status_count_messages_without_workflow_runs(self): - """Test status_count when messages have no workflow_run_id.""" - # Arrange - app_id = str(uuid4()) - conversation_id = str(uuid4()) - - conversation = Conversation( - app_id=app_id, - mode=AppMode.CHAT, - name="Test Conversation", - status="normal", - from_source=ConversationFromSource.API, - ) - conversation.id = conversation_id - - # Mock the database query to return no messages with workflow_run_id - with patch("models.model.db.session.scalars") as mock_scalars: - mock_scalars.return_value.all.return_value = [] - - # Act - result = conversation.status_count - - # Assert - assert result is None - - def test_status_count_batch_loading_implementation(self): - """Test that status_count uses batch loading instead of N+1 queries.""" - # Arrange - from graphon.enums import WorkflowExecutionStatus - - app_id = str(uuid4()) - conversation_id = str(uuid4()) - - # Create workflow run IDs - workflow_run_id_1 = str(uuid4()) - workflow_run_id_2 = str(uuid4()) - workflow_run_id_3 = str(uuid4()) - - conversation = Conversation( - app_id=app_id, - mode=AppMode.CHAT, - name="Test Conversation", - status="normal", - from_source=ConversationFromSource.API, - ) - conversation.id = conversation_id - - # Mock messages with workflow_run_id - mock_messages = [ - MagicMock( - conversation_id=conversation_id, - workflow_run_id=workflow_run_id_1, - ), - MagicMock( - conversation_id=conversation_id, - workflow_run_id=workflow_run_id_2, - ), - MagicMock( - conversation_id=conversation_id, - workflow_run_id=workflow_run_id_3, - ), - ] - - # Mock workflow runs with different statuses - mock_workflow_runs = [ - MagicMock( - id=workflow_run_id_1, - status=WorkflowExecutionStatus.SUCCEEDED.value, - app_id=app_id, - ), - MagicMock( - id=workflow_run_id_2, - status=WorkflowExecutionStatus.FAILED.value, - app_id=app_id, - ), - MagicMock( - id=workflow_run_id_3, - status=WorkflowExecutionStatus.PARTIAL_SUCCEEDED.value, - app_id=app_id, - ), - ] - - # Track database calls - calls_made = [] - - def mock_scalars(query): - calls_made.append(str(query)) - mock_result = MagicMock() - - # Return messages for the first query (messages with workflow_run_id) - if "messages" in str(query) and "conversation_id" in str(query): - mock_result.all.return_value = mock_messages - # Return workflow runs for the batch query - elif "workflow_runs" in str(query): - mock_result.all.return_value = mock_workflow_runs - else: - mock_result.all.return_value = [] - - return mock_result - - # Act & Assert - with patch("models.model.db.session.scalars", side_effect=mock_scalars): - result = conversation.status_count - - # Verify only 2 database queries were made (not N+1) - assert len(calls_made) == 2, f"Expected 2 queries, got {len(calls_made)}: {calls_made}" - - # Verify the first query gets messages - assert "messages" in calls_made[0] - assert "conversation_id" in calls_made[0] - - # Verify the second query batch loads workflow runs with proper filtering - assert "workflow_runs" in calls_made[1] - assert "app_id" in calls_made[1] # Security filter applied - assert "IN" in calls_made[1] # Batch loading with IN clause - - # Verify correct status counts - assert result["success"] == 1 # One SUCCEEDED - assert result["failed"] == 1 # One FAILED - assert result["partial_success"] == 1 # One PARTIAL_SUCCEEDED - assert result["paused"] == 0 - - def test_status_count_app_id_filtering(self): - """Test that status_count filters workflow runs by app_id for security.""" - # Arrange - app_id = str(uuid4()) - other_app_id = str(uuid4()) - conversation_id = str(uuid4()) - workflow_run_id = str(uuid4()) - - conversation = Conversation( - app_id=app_id, - mode=AppMode.CHAT, - name="Test Conversation", - status="normal", - from_source=ConversationFromSource.API, - ) - conversation.id = conversation_id - - # Mock message with workflow_run_id - mock_messages = [ - MagicMock( - conversation_id=conversation_id, - workflow_run_id=workflow_run_id, - ), - ] - - calls_made = [] - - def mock_scalars(query): - calls_made.append(str(query)) - mock_result = MagicMock() - - if "messages" in str(query): - mock_result.all.return_value = mock_messages - elif "workflow_runs" in str(query): - # Return empty list because no workflow run matches the correct app_id - mock_result.all.return_value = [] # Workflow run filtered out by app_id - else: - mock_result.all.return_value = [] - - return mock_result - - # Act - with patch("models.model.db.session.scalars", side_effect=mock_scalars): - result = conversation.status_count - - # Assert - query should include app_id filter - workflow_query = calls_made[1] - assert "app_id" in workflow_query - - # Since workflow run has wrong app_id, it shouldn't be included in counts - assert result["success"] == 0 - assert result["failed"] == 0 - assert result["partial_success"] == 0 - assert result["paused"] == 0 - - def test_status_count_handles_invalid_workflow_status(self): - """Test that status_count gracefully handles invalid workflow status values.""" - # Arrange - app_id = str(uuid4()) - conversation_id = str(uuid4()) - workflow_run_id = str(uuid4()) - - conversation = Conversation( - app_id=app_id, - mode=AppMode.CHAT, - name="Test Conversation", - status="normal", - from_source=ConversationFromSource.API, - ) - conversation.id = conversation_id - - mock_messages = [ - MagicMock( - conversation_id=conversation_id, - workflow_run_id=workflow_run_id, - ), - ] - - # Mock workflow run with invalid status - mock_workflow_runs = [ - MagicMock( - id=workflow_run_id, - status="invalid_status", # Invalid status that should raise ValueError - app_id=app_id, - ), - ] - - with patch("models.model.db.session.scalars") as mock_scalars: - # Mock the messages query - def mock_scalars_side_effect(query): - mock_result = MagicMock() - if "messages" in str(query): - mock_result.all.return_value = mock_messages - elif "workflow_runs" in str(query): - mock_result.all.return_value = mock_workflow_runs - else: - mock_result.all.return_value = [] - return mock_result - - mock_scalars.side_effect = mock_scalars_side_effect - - # Act - should not raise exception - result = conversation.status_count - - # Assert - should handle invalid status gracefully - assert result["success"] == 0 - assert result["failed"] == 0 - assert result["partial_success"] == 0 - assert result["paused"] == 0 - - def test_status_count_paused(self): - """Test status_count includes paused workflow runs.""" - # Arrange - from graphon.enums import WorkflowExecutionStatus - - app_id = str(uuid4()) - conversation_id = str(uuid4()) - workflow_run_id = str(uuid4()) - - conversation = Conversation( - app_id=app_id, - mode=AppMode.CHAT, - name="Test Conversation", - status="normal", - from_source=ConversationFromSource.API, - ) - conversation.id = conversation_id - - mock_messages = [ - MagicMock( - conversation_id=conversation_id, - workflow_run_id=workflow_run_id, - ), - ] - - mock_workflow_runs = [ - MagicMock( - id=workflow_run_id, - status=WorkflowExecutionStatus.PAUSED.value, - app_id=app_id, - ), - ] - - with patch("models.model.db.session.scalars") as mock_scalars: - - def mock_scalars_side_effect(query): - mock_result = MagicMock() - if "messages" in str(query): - mock_result.all.return_value = mock_messages - elif "workflow_runs" in str(query): - mock_result.all.return_value = mock_workflow_runs - else: - mock_result.all.return_value = [] - return mock_result - - mock_scalars.side_effect = mock_scalars_side_effect - - # Act - result = conversation.status_count - - # Assert - assert result["paused"] == 1 diff --git a/api/tests/unit_tests/models/test_dataset_models.py b/api/tests/unit_tests/models/test_dataset_models.py index 6c8a91129b..51d95c4239 100644 --- a/api/tests/unit_tests/models/test_dataset_models.py +++ b/api/tests/unit_tests/models/test_dataset_models.py @@ -12,7 +12,7 @@ This test suite covers: import json import pickle from datetime import UTC, datetime -from unittest.mock import patch +from unittest.mock import Mock, patch from uuid import uuid4 from core.rag.index_processor.constant.index_type import IndexTechniqueType @@ -25,6 +25,7 @@ from models.dataset import ( Document, DocumentSegment, Embedding, + ExternalKnowledgeBindings, ) from models.enums import ( DataSourceType, @@ -180,6 +181,24 @@ class TestDatasetModelValidation: assert result["top_k"] == 2 assert result["score_threshold"] == 0.0 + def test_dataset_external_knowledge_info_returns_none_for_cross_tenant_template(self): + """Test external datasets fail closed when the bound template is outside the tenant.""" + dataset = Dataset( + tenant_id=str(uuid4()), + name="External Dataset", + data_source_type=DataSourceType.UPLOAD_FILE, + created_by=str(uuid4()), + provider="external", + ) + binding = Mock(spec=ExternalKnowledgeBindings) + binding.external_knowledge_id = "knowledge-1" + binding.external_knowledge_api_id = str(uuid4()) + + with patch("models.dataset.db") as mock_db: + mock_db.session.scalar.side_effect = [binding, None] + + assert dataset.external_knowledge_info is None + def test_dataset_retrieval_model_dict_property(self): """Test retrieval_model_dict property with default values.""" # Arrange diff --git a/api/tests/unit_tests/models/test_model.py b/api/tests/unit_tests/models/test_model.py index a5909f60a8..3f6d6bfbe3 100644 --- a/api/tests/unit_tests/models/test_model.py +++ b/api/tests/unit_tests/models/test_model.py @@ -101,118 +101,6 @@ def _build_local_file_mapping(record_id: str, *, tenant_id: str | None = None) - return mapping -@pytest.mark.parametrize("owner_cls", [Conversation, Message]) -def test_inputs_resolve_owner_tenant_for_single_file_mapping( - monkeypatch: pytest.MonkeyPatch, - owner_cls: type[Conversation] | type[Message], -): - model_module = importlib.import_module("models.model") - build_calls: list[tuple[dict[str, object], str]] = [] - - monkeypatch.setattr(model_module.db.session, "scalar", lambda _: "tenant-from-app") - - def fake_build_from_mapping(*, mapping, tenant_id, config=None, strict_type_validation=False, access_controller): - _ = config, strict_type_validation, access_controller - build_calls.append((dict(mapping), tenant_id)) - return {"tenant_id": tenant_id, "upload_file_id": mapping.get("upload_file_id")} - - monkeypatch.setattr("factories.file_factory.build_from_mapping", fake_build_from_mapping) - - owner = owner_cls(app_id="app-1") - owner.inputs = {"file": _build_local_file_mapping("upload-1")} - - restored_inputs = owner.inputs - - assert restored_inputs["file"] == {"tenant_id": "tenant-from-app", "upload_file_id": "upload-1"} - assert build_calls == [ - ( - { - **_build_local_file_mapping("upload-1"), - "upload_file_id": "upload-1", - }, - "tenant-from-app", - ) - ] - - -@pytest.mark.parametrize("owner_cls", [Conversation, Message]) -def test_inputs_resolve_owner_tenant_for_file_list_mapping( - monkeypatch: pytest.MonkeyPatch, - owner_cls: type[Conversation] | type[Message], -): - model_module = importlib.import_module("models.model") - build_calls: list[tuple[dict[str, object], str]] = [] - - monkeypatch.setattr(model_module.db.session, "scalar", lambda _: "tenant-from-app") - - def fake_build_from_mapping(*, mapping, tenant_id, config=None, strict_type_validation=False, access_controller): - _ = config, strict_type_validation, access_controller - build_calls.append((dict(mapping), tenant_id)) - return {"tenant_id": tenant_id, "upload_file_id": mapping.get("upload_file_id")} - - monkeypatch.setattr("factories.file_factory.build_from_mapping", fake_build_from_mapping) - - owner = owner_cls(app_id="app-1") - owner.inputs = { - "files": [ - _build_local_file_mapping("upload-1"), - _build_local_file_mapping("upload-2"), - ] - } - - restored_inputs = owner.inputs - - assert restored_inputs["files"] == [ - {"tenant_id": "tenant-from-app", "upload_file_id": "upload-1"}, - {"tenant_id": "tenant-from-app", "upload_file_id": "upload-2"}, - ] - assert build_calls == [ - ( - { - **_build_local_file_mapping("upload-1"), - "upload_file_id": "upload-1", - }, - "tenant-from-app", - ), - ( - { - **_build_local_file_mapping("upload-2"), - "upload_file_id": "upload-2", - }, - "tenant-from-app", - ), - ] - - -@pytest.mark.parametrize("owner_cls", [Conversation, Message]) -def test_inputs_prefer_serialized_tenant_id_when_present( - monkeypatch: pytest.MonkeyPatch, - owner_cls: type[Conversation] | type[Message], -): - model_module = importlib.import_module("models.model") - - def fail_if_called(_): - raise AssertionError("App tenant lookup should not run when tenant_id exists in the file mapping") - - monkeypatch.setattr(model_module.db.session, "scalar", fail_if_called) - - def fake_build_from_mapping(*, mapping, tenant_id, config=None, strict_type_validation=False, access_controller): - _ = config, strict_type_validation, access_controller - return {"tenant_id": tenant_id, "upload_file_id": mapping.get("upload_file_id")} - - monkeypatch.setattr("factories.file_factory.build_from_mapping", fake_build_from_mapping) - - owner = owner_cls(app_id="app-1") - owner.inputs = {"file": _build_local_file_mapping("upload-1", tenant_id="tenant-from-payload")} - - restored_inputs = owner.inputs - - assert restored_inputs["file"] == { - "tenant_id": "tenant-from-payload", - "upload_file_id": "upload-1", - } - - @pytest.mark.parametrize("owner_cls", [Conversation, Message]) def test_inputs_restore_external_remote_url_file_mappings(owner_cls: type[Conversation] | type[Message]) -> None: owner = owner_cls(app_id="app-1") diff --git a/api/tests/unit_tests/models/test_workflow_trigger_log.py b/api/tests/unit_tests/models/test_workflow_trigger_log.py deleted file mode 100644 index 7fdad92fb6..0000000000 --- a/api/tests/unit_tests/models/test_workflow_trigger_log.py +++ /dev/null @@ -1,188 +0,0 @@ -import types - -import pytest - -from models.engine import db -from models.enums import CreatorUserRole -from models.workflow import WorkflowNodeExecutionModel - - -@pytest.fixture -def fake_db_scalar(monkeypatch): - """Provide a controllable fake for db.session.scalar (SQLAlchemy 2.0 style).""" - calls = [] - - def _install(side_effect): - def _fake_scalar(statement): - calls.append(statement) - return side_effect(statement) - - # Patch the modern API used by the model implementation - monkeypatch.setattr(db.session, "scalar", _fake_scalar) - - # Backward-compatibility: if the implementation still uses db.session.get, - # make it delegate to the same side_effect so tests remain valid on older code. - if hasattr(db.session, "get"): - - def _fake_get(*_args, **_kwargs): - return side_effect(None) - - monkeypatch.setattr(db.session, "get", _fake_get) - - return calls - - return _install - - -def make_account(id_: str = "acc-1"): - # Use a simple object to avoid constructing a full SQLAlchemy model instance - # Python 3.12 forbids reassigning __class__ for SimpleNamespace; not needed here. - obj = types.SimpleNamespace() - obj.id = id_ - return obj - - -def make_end_user(id_: str = "user-1"): - # Lightweight stand-in object; no need to spoof class identity. - obj = types.SimpleNamespace() - obj.id = id_ - return obj - - -def test_created_by_account_returns_account_when_role_account(fake_db_scalar): - account = make_account("acc-1") - - # The implementation uses db.session.scalar(select(Account)...). We only need to - # return the expected object when called; the exact SQL is irrelevant for this unit test. - def side_effect(_statement): - return account - - fake_db_scalar(side_effect) - - log = WorkflowNodeExecutionModel( - tenant_id="t1", - app_id="a1", - workflow_id="w1", - triggered_from="workflow-run", - workflow_run_id=None, - index=1, - predecessor_node_id=None, - node_execution_id=None, - node_id="n1", - node_type="start", - title="Start", - inputs=None, - process_data=None, - outputs=None, - status="succeeded", - error=None, - elapsed_time=0.0, - execution_metadata=None, - created_by_role=CreatorUserRole.ACCOUNT.value, - created_by="acc-1", - ) - - assert log.created_by_account is account - - -def test_created_by_account_returns_none_when_role_not_account(fake_db_scalar): - # Even if an Account with matching id exists, property should return None when role is END_USER - account = make_account("acc-1") - - def side_effect(_statement): - return account - - fake_db_scalar(side_effect) - - log = WorkflowNodeExecutionModel( - tenant_id="t1", - app_id="a1", - workflow_id="w1", - triggered_from="workflow-run", - workflow_run_id=None, - index=1, - predecessor_node_id=None, - node_execution_id=None, - node_id="n1", - node_type="start", - title="Start", - inputs=None, - process_data=None, - outputs=None, - status="succeeded", - error=None, - elapsed_time=0.0, - execution_metadata=None, - created_by_role=CreatorUserRole.END_USER.value, - created_by="acc-1", - ) - - assert log.created_by_account is None - - -def test_created_by_end_user_returns_end_user_when_role_end_user(fake_db_scalar): - end_user = make_end_user("user-1") - - def side_effect(_statement): - return end_user - - fake_db_scalar(side_effect) - - log = WorkflowNodeExecutionModel( - tenant_id="t1", - app_id="a1", - workflow_id="w1", - triggered_from="workflow-run", - workflow_run_id=None, - index=1, - predecessor_node_id=None, - node_execution_id=None, - node_id="n1", - node_type="start", - title="Start", - inputs=None, - process_data=None, - outputs=None, - status="succeeded", - error=None, - elapsed_time=0.0, - execution_metadata=None, - created_by_role=CreatorUserRole.END_USER.value, - created_by="user-1", - ) - - assert log.created_by_end_user is end_user - - -def test_created_by_end_user_returns_none_when_role_not_end_user(fake_db_scalar): - end_user = make_end_user("user-1") - - def side_effect(_statement): - return end_user - - fake_db_scalar(side_effect) - - log = WorkflowNodeExecutionModel( - tenant_id="t1", - app_id="a1", - workflow_id="w1", - triggered_from="workflow-run", - workflow_run_id=None, - index=1, - predecessor_node_id=None, - node_execution_id=None, - node_id="n1", - node_type="start", - title="Start", - inputs=None, - process_data=None, - outputs=None, - status="succeeded", - error=None, - elapsed_time=0.0, - execution_metadata=None, - created_by_role=CreatorUserRole.ACCOUNT.value, - created_by="user-1", - ) - - assert log.created_by_end_user is None diff --git a/api/tests/unit_tests/repositories/workflow_node_execution/__init__.py b/api/tests/unit_tests/repositories/workflow_node_execution/__init__.py deleted file mode 100644 index 78815a8d1a..0000000000 --- a/api/tests/unit_tests/repositories/workflow_node_execution/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -""" -Unit tests for workflow_node_execution repositories. -""" diff --git a/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_repository.py b/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_repository.py deleted file mode 100644 index 10850970d8..0000000000 --- a/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_repository.py +++ /dev/null @@ -1,340 +0,0 @@ -""" -Unit tests for the SQLAlchemy implementation of WorkflowNodeExecutionRepository. -""" - -import json -import uuid -from datetime import datetime -from decimal import Decimal -from unittest.mock import MagicMock, PropertyMock - -import pytest -from graphon.entities import ( - WorkflowNodeExecution, -) -from graphon.enums import ( - BuiltinNodeTypes, - WorkflowNodeExecutionMetadataKey, - WorkflowNodeExecutionStatus, -) -from graphon.model_runtime.utils.encoders import jsonable_encoder -from pytest_mock import MockerFixture -from sqlalchemy.orm import Session, sessionmaker - -from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository -from core.repositories.factory import OrderConfig -from models.account import Account, Tenant -from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom - - -def configure_mock_execution(mock_execution): - """Configure a mock execution with proper JSON serializable values.""" - # Configure inputs, outputs, process_data, and execution_metadata to return JSON serializable values - type(mock_execution).inputs = PropertyMock(return_value='{"key": "value"}') - type(mock_execution).outputs = PropertyMock(return_value='{"result": "success"}') - type(mock_execution).process_data = PropertyMock(return_value='{"process": "data"}') - type(mock_execution).execution_metadata = PropertyMock(return_value='{"metadata": "info"}') - - # Configure status and triggered_from to be valid enum values - mock_execution.status = "running" - mock_execution.triggered_from = "workflow-run" - - return mock_execution - - -@pytest.fixture -def session(): - """Create a mock SQLAlchemy session.""" - session = MagicMock(spec=Session) - # Configure the session to be used as a context manager - session.__enter__ = MagicMock(return_value=session) - session.__exit__ = MagicMock(return_value=None) - - # Configure the session factory to return the session - session_factory = MagicMock(spec=sessionmaker) - session_factory.return_value = session - return session, session_factory - - -@pytest.fixture -def mock_user(): - """Create a user instance for testing.""" - user = Account(name="test", email="test@example.com") - user.id = "test-user-id" - - tenant = Tenant(name="Test Workspace") - tenant.id = "test-tenant" - user._current_tenant = MagicMock() - user._current_tenant.id = "test-tenant" - - return user - - -@pytest.fixture -def repository(session, mock_user): - """Create a repository instance with test data.""" - _, session_factory = session - app_id = "test-app" - return SQLAlchemyWorkflowNodeExecutionRepository( - session_factory=session_factory, - user=mock_user, - app_id=app_id, - triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, - ) - - -def test_save(repository, session): - """Test save method.""" - session_obj, _ = session - # Create a mock execution - execution = MagicMock(spec=WorkflowNodeExecution) - execution.id = "test-id" - execution.node_execution_id = "test-node-execution-id" - execution.tenant_id = None - execution.app_id = None - execution.inputs = None - execution.process_data = None - execution.outputs = None - execution.metadata = None - execution.workflow_id = str(uuid.uuid4()) - - # Mock the to_db_model method to return the execution itself - # This simulates the behavior of setting tenant_id and app_id - db_model = MagicMock(spec=WorkflowNodeExecutionModel) - db_model.id = "test-id" - db_model.node_execution_id = "test-node-execution-id" - repository._to_db_model = MagicMock(return_value=db_model) - - # Mock session.get to return None (no existing record) - session_obj.get.return_value = None - - # Call save method - repository.save(execution) - - # Assert to_db_model was called with the execution - repository._to_db_model.assert_called_once_with(execution) - - # Assert session.get was called to check for existing record - session_obj.get.assert_called_once_with(WorkflowNodeExecutionModel, db_model.id) - - # Assert session.add was called for new record - session_obj.add.assert_called_once_with(db_model) - - # Assert session.commit was called - session_obj.commit.assert_called_once() - - -def test_save_with_existing_tenant_id(repository, session): - """Test save method with existing tenant_id.""" - session_obj, _ = session - # Create a mock execution with existing tenant_id - execution = MagicMock(spec=WorkflowNodeExecutionModel) - execution.id = "existing-id" - execution.node_execution_id = "existing-node-execution-id" - execution.tenant_id = "existing-tenant" - execution.app_id = None - execution.inputs = None - execution.process_data = None - execution.outputs = None - execution.metadata = None - - # Create a modified execution that will be returned by _to_db_model - modified_execution = MagicMock(spec=WorkflowNodeExecutionModel) - modified_execution.id = "existing-id" - modified_execution.node_execution_id = "existing-node-execution-id" - modified_execution.tenant_id = "existing-tenant" # Tenant ID should not change - modified_execution.app_id = repository._app_id # App ID should be set - # Create a dictionary to simulate __dict__ for updating attributes - modified_execution.__dict__ = { - "id": "existing-id", - "node_execution_id": "existing-node-execution-id", - "tenant_id": "existing-tenant", - "app_id": repository._app_id, - } - - # Mock the to_db_model method to return the modified execution - repository._to_db_model = MagicMock(return_value=modified_execution) - - # Mock session.get to return an existing record - existing_model = MagicMock(spec=WorkflowNodeExecutionModel) - session_obj.get.return_value = existing_model - - # Call save method - repository.save(execution) - - # Assert to_db_model was called with the execution - repository._to_db_model.assert_called_once_with(execution) - - # Assert session.get was called to check for existing record - session_obj.get.assert_called_once_with(WorkflowNodeExecutionModel, modified_execution.id) - - # Assert session.add was NOT called since we're updating existing - session_obj.add.assert_not_called() - - # Assert session.commit was called - session_obj.commit.assert_called_once() - - -def test_get_by_workflow_execution(repository, session, mocker: MockerFixture): - """Test get_by_workflow_execution method.""" - session_obj, _ = session - # Set up mock - mock_select = mocker.patch("core.repositories.sqlalchemy_workflow_node_execution_repository.select") - mock_asc = mocker.patch("core.repositories.sqlalchemy_workflow_node_execution_repository.asc") - mock_desc = mocker.patch("core.repositories.sqlalchemy_workflow_node_execution_repository.desc") - - mock_WorkflowNodeExecutionModel = mocker.patch( - "core.repositories.sqlalchemy_workflow_node_execution_repository.WorkflowNodeExecutionModel" - ) - mock_stmt = mocker.MagicMock() - mock_select.return_value = mock_stmt - mock_stmt.where.return_value = mock_stmt - mock_stmt.order_by.return_value = mock_stmt - mock_asc.return_value = mock_stmt - mock_desc.return_value = mock_stmt - mock_WorkflowNodeExecutionModel.preload_offload_data_and_files.return_value = mock_stmt - - # Create a properly configured mock execution - mock_execution = mocker.MagicMock(spec=WorkflowNodeExecutionModel) - configure_mock_execution(mock_execution) - session_obj.scalars.return_value.all.return_value = [mock_execution] - - # Create a mock domain model to be returned by _to_domain_model - mock_domain_model = mocker.MagicMock() - # Mock the _to_domain_model method to return our mock domain model - repository._to_domain_model = mocker.MagicMock(return_value=mock_domain_model) - - # Call method - order_config = OrderConfig(order_by=["index"], order_direction="desc") - result = repository.get_by_workflow_execution( - workflow_execution_id="test-workflow-run-id", - order_config=order_config, - ) - - # Assert select was called with correct parameters - mock_select.assert_called_once() - session_obj.scalars.assert_called_once_with(mock_stmt) - mock_WorkflowNodeExecutionModel.preload_offload_data_and_files.assert_called_once_with(mock_stmt) - # Assert _to_domain_model was called with the mock execution - repository._to_domain_model.assert_called_once_with(mock_execution) - # Assert the result contains our mock domain model - assert len(result) == 1 - assert result[0] is mock_domain_model - - -def test_to_db_model(repository): - """Test to_db_model method.""" - # Create a domain model - domain_model = WorkflowNodeExecution( - id="test-id", - workflow_id="test-workflow-id", - node_execution_id="test-node-execution-id", - workflow_execution_id="test-workflow-run-id", - index=1, - predecessor_node_id="test-predecessor-id", - node_id="test-node-id", - node_type=BuiltinNodeTypes.START, - title="Test Node", - inputs={"input_key": "input_value"}, - process_data={"process_key": "process_value"}, - outputs={"output_key": "output_value"}, - status=WorkflowNodeExecutionStatus.RUNNING, - error=None, - elapsed_time=1.5, - metadata={ - WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: 100, - WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: Decimal("0.0"), - }, - created_at=datetime.now(), - finished_at=None, - ) - - # Convert to DB model - db_model = repository._to_db_model(domain_model) - - # Assert DB model has correct values - assert isinstance(db_model, WorkflowNodeExecutionModel) - assert db_model.id == domain_model.id - assert db_model.tenant_id == repository._tenant_id - assert db_model.app_id == repository._app_id - assert db_model.workflow_id == domain_model.workflow_id - assert db_model.triggered_from == repository._triggered_from - assert db_model.workflow_run_id == domain_model.workflow_execution_id - assert db_model.index == domain_model.index - assert db_model.predecessor_node_id == domain_model.predecessor_node_id - assert db_model.node_execution_id == domain_model.node_execution_id - assert db_model.node_id == domain_model.node_id - assert db_model.node_type == domain_model.node_type - assert db_model.title == domain_model.title - - assert db_model.inputs_dict == domain_model.inputs - assert db_model.process_data_dict == domain_model.process_data - assert db_model.outputs_dict == domain_model.outputs - assert db_model.execution_metadata_dict == jsonable_encoder(domain_model.metadata) - - assert db_model.status == domain_model.status - assert db_model.error == domain_model.error - assert db_model.elapsed_time == domain_model.elapsed_time - assert db_model.created_at == domain_model.created_at - assert db_model.created_by_role == repository._creator_user_role - assert db_model.created_by == repository._creator_user_id - assert db_model.finished_at == domain_model.finished_at - - -def test_to_domain_model(repository): - """Test _to_domain_model method.""" - # Create input dictionaries - inputs_dict = {"input_key": "input_value"} - process_data_dict = {"process_key": "process_value"} - outputs_dict = {"output_key": "output_value"} - metadata_dict = {str(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS): 100} - - # Create a DB model using our custom subclass - db_model = WorkflowNodeExecutionModel() - db_model.id = "test-id" - db_model.tenant_id = "test-tenant-id" - db_model.app_id = "test-app-id" - db_model.workflow_id = "test-workflow-id" - db_model.triggered_from = "workflow-run" - db_model.workflow_run_id = "test-workflow-run-id" - db_model.index = 1 - db_model.predecessor_node_id = "test-predecessor-id" - db_model.node_execution_id = "test-node-execution-id" - db_model.node_id = "test-node-id" - db_model.node_type = BuiltinNodeTypes.START - db_model.title = "Test Node" - db_model.inputs = json.dumps(inputs_dict) - db_model.process_data = json.dumps(process_data_dict) - db_model.outputs = json.dumps(outputs_dict) - db_model.status = WorkflowNodeExecutionStatus.RUNNING - db_model.error = None - db_model.elapsed_time = 1.5 - db_model.execution_metadata = json.dumps(metadata_dict) - db_model.created_at = datetime.now() - db_model.created_by_role = "account" - db_model.created_by = "test-user-id" - db_model.finished_at = None - - # Convert to domain model - domain_model = repository._to_domain_model(db_model) - - # Assert domain model has correct values - assert isinstance(domain_model, WorkflowNodeExecution) - assert domain_model.id == db_model.id - assert domain_model.workflow_id == db_model.workflow_id - assert domain_model.workflow_execution_id == db_model.workflow_run_id - assert domain_model.index == db_model.index - assert domain_model.predecessor_node_id == db_model.predecessor_node_id - assert domain_model.node_execution_id == db_model.node_execution_id - assert domain_model.node_id == db_model.node_id - assert domain_model.node_type == db_model.node_type - assert domain_model.title == db_model.title - assert domain_model.inputs == inputs_dict - assert domain_model.process_data == process_data_dict - assert domain_model.outputs == outputs_dict - assert domain_model.status == WorkflowNodeExecutionStatus(db_model.status) - assert domain_model.error == db_model.error - assert domain_model.elapsed_time == db_model.elapsed_time - assert domain_model.metadata == metadata_dict - assert domain_model.created_at == db_model.created_at - assert domain_model.finished_at == db_model.finished_at diff --git a/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_workflow_node_execution_repository.py b/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_workflow_node_execution_repository.py deleted file mode 100644 index 2322be9e80..0000000000 --- a/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_workflow_node_execution_repository.py +++ /dev/null @@ -1,106 +0,0 @@ -""" -Unit tests for SQLAlchemyWorkflowNodeExecutionRepository, focusing on process_data truncation functionality. -""" - -from datetime import datetime -from typing import Any -from unittest.mock import MagicMock, Mock - -from graphon.entities import WorkflowNodeExecution -from graphon.enums import BuiltinNodeTypes -from sqlalchemy.orm import sessionmaker - -from core.repositories.sqlalchemy_workflow_node_execution_repository import ( - SQLAlchemyWorkflowNodeExecutionRepository, -) -from models import Account, WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom - - -class TestSQLAlchemyWorkflowNodeExecutionRepositoryProcessData: - """Test process_data truncation functionality in SQLAlchemyWorkflowNodeExecutionRepository.""" - - def create_mock_account(self) -> Account: - """Create a mock Account for testing.""" - account = Mock(spec=Account) - account.id = "test-user-id" - account.tenant_id = "test-tenant-id" - return account - - def create_mock_session_factory(self) -> sessionmaker: - """Create a mock session factory for testing.""" - mock_session = MagicMock() - mock_session_factory = MagicMock(spec=sessionmaker) - mock_session_factory.return_value.__enter__.return_value = mock_session - mock_session_factory.return_value.__exit__.return_value = None - return mock_session_factory - - def create_repository(self, mock_file_service=None) -> SQLAlchemyWorkflowNodeExecutionRepository: - """Create a repository instance for testing.""" - mock_account = self.create_mock_account() - mock_session_factory = self.create_mock_session_factory() - - repository = SQLAlchemyWorkflowNodeExecutionRepository( - session_factory=mock_session_factory, - user=mock_account, - app_id="test-app-id", - triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, - ) - - if mock_file_service: - repository._file_service = mock_file_service - - return repository - - def create_workflow_node_execution( - self, - process_data: dict[str, Any] | None = None, - execution_id: str = "test-execution-id", - ) -> WorkflowNodeExecution: - """Create a WorkflowNodeExecution instance for testing.""" - return WorkflowNodeExecution( - id=execution_id, - workflow_id="test-workflow-id", - index=1, - node_id="test-node-id", - node_type=BuiltinNodeTypes.LLM, - title="Test Node", - process_data=process_data, - created_at=datetime.now(), - ) - - def test_to_domain_model_without_offload_data(self): - """Test _to_domain_model without offload data.""" - repository = self.create_repository() - - # Create mock database model without offload data - db_model = Mock(spec=WorkflowNodeExecutionModel) - db_model.id = "test-execution-id" - db_model.node_execution_id = "test-node-execution-id" - db_model.workflow_id = "test-workflow-id" - db_model.workflow_run_id = None - db_model.index = 1 - db_model.predecessor_node_id = None - db_model.node_id = "test-node-id" - db_model.node_type = "llm" - db_model.title = "Test Node" - db_model.status = "succeeded" - db_model.error = None - db_model.elapsed_time = 1.5 - db_model.created_at = datetime.now() - db_model.finished_at = None - - process_data = {"normal": "data"} - db_model.process_data_dict = process_data - db_model.inputs_dict = None - db_model.outputs_dict = None - db_model.execution_metadata_dict = {} - db_model.offload_data = None - - domain_model = repository._to_domain_model(db_model) - - # Domain model should have the data from database - assert domain_model.process_data == process_data - - # Should not be truncated - assert domain_model.process_data_truncated is False - assert domain_model.get_truncated_process_data() is None diff --git a/api/tests/unit_tests/services/dataset_metadata.py b/api/tests/unit_tests/services/dataset_metadata.py deleted file mode 100644 index b825a8686a..0000000000 --- a/api/tests/unit_tests/services/dataset_metadata.py +++ /dev/null @@ -1,1014 +0,0 @@ -""" -Comprehensive unit tests for MetadataService. - -This module contains extensive unit tests for the MetadataService class, -which handles dataset metadata CRUD operations and filtering/querying functionality. - -The MetadataService provides methods for: -- Creating, reading, updating, and deleting metadata fields -- Managing built-in metadata fields -- Updating document metadata values -- Metadata filtering and querying operations -- Lock management for concurrent metadata operations - -Metadata in Dify allows users to add custom fields to datasets and documents, -enabling rich filtering and search capabilities. Metadata can be of various -types (string, number, date, boolean, etc.) and can be used to categorize -and filter documents within a dataset. - -This test suite ensures: -- Correct creation of metadata fields with validation -- Proper updating of metadata names and values -- Accurate deletion of metadata fields -- Built-in field management (enable/disable) -- Document metadata updates (partial and full) -- Lock management for concurrent operations -- Metadata querying and filtering functionality - -================================================================================ -ARCHITECTURE OVERVIEW -================================================================================ - -The MetadataService is a critical component in the Dify platform's metadata -management system. It serves as the primary interface for all metadata-related -operations, including field definitions and document-level metadata values. - -Key Concepts: -1. DatasetMetadata: Defines a metadata field for a dataset. Each metadata - field has a name, type, and is associated with a specific dataset. - -2. DatasetMetadataBinding: Links metadata fields to documents. This allows - tracking which documents have which metadata fields assigned. - -3. Document Metadata: The actual metadata values stored on documents. This - is stored as a JSON object in the document's doc_metadata field. - -4. Built-in Fields: System-defined metadata fields that are automatically - available when enabled (document_name, uploader, upload_date, etc.). - -5. Lock Management: Redis-based locking to prevent concurrent metadata - operations that could cause data corruption. - -================================================================================ -TESTING STRATEGY -================================================================================ - -This test suite follows a comprehensive testing strategy that covers: - -1. CRUD Operations: - - Creating metadata fields with validation - - Reading/retrieving metadata fields - - Updating metadata field names - - Deleting metadata fields - -2. Built-in Field Management: - - Enabling built-in fields - - Disabling built-in fields - - Getting built-in field definitions - -3. Document Metadata Operations: - - Updating document metadata (partial and full) - - Managing metadata bindings - - Handling built-in field updates - -4. Lock Management: - - Acquiring locks for dataset operations - - Acquiring locks for document operations - - Handling lock conflicts - -5. Error Handling: - - Validation errors (name length, duplicates) - - Not found errors - - Lock conflict errors - -================================================================================ -""" - -from unittest.mock import Mock, patch - -import pytest - -from core.rag.index_processor.constant.built_in_field import BuiltInField -from models.dataset import Dataset, DatasetMetadata, DatasetMetadataBinding -from services.entities.knowledge_entities.knowledge_entities import ( - MetadataArgs, - MetadataValue, -) -from services.metadata_service import MetadataService - -# ============================================================================ -# Test Data Factory -# ============================================================================ -# The Test Data Factory pattern is used here to centralize the creation of -# test objects and mock instances. This approach provides several benefits: -# -# 1. Consistency: All test objects are created using the same factory methods, -# ensuring consistent structure across all tests. -# -# 2. Maintainability: If the structure of models changes, we only need to -# update the factory methods rather than every individual test. -# -# 3. Reusability: Factory methods can be reused across multiple test classes, -# reducing code duplication. -# -# 4. Readability: Tests become more readable when they use descriptive factory -# method calls instead of complex object construction logic. -# -# ============================================================================ - - -class MetadataTestDataFactory: - """ - Factory class for creating test data and mock objects for metadata service tests. - - This factory provides static methods to create mock objects for: - - DatasetMetadata instances - - DatasetMetadataBinding instances - - Dataset instances - - Document instances - - MetadataArgs and MetadataOperationData entities - - User and tenant context - - The factory methods help maintain consistency across tests and reduce - code duplication when setting up test scenarios. - """ - - @staticmethod - def create_metadata_mock( - metadata_id: str = "metadata-123", - dataset_id: str = "dataset-123", - tenant_id: str = "tenant-123", - name: str = "category", - metadata_type: str = "string", - created_by: str = "user-123", - **kwargs, - ) -> Mock: - """ - Create a mock DatasetMetadata with specified attributes. - - Args: - metadata_id: Unique identifier for the metadata field - dataset_id: ID of the dataset this metadata belongs to - tenant_id: Tenant identifier - name: Name of the metadata field - metadata_type: Type of metadata (string, number, date, etc.) - created_by: ID of the user who created the metadata - **kwargs: Additional attributes to set on the mock - - Returns: - Mock object configured as a DatasetMetadata instance - """ - metadata = Mock(spec=DatasetMetadata) - metadata.id = metadata_id - metadata.dataset_id = dataset_id - metadata.tenant_id = tenant_id - metadata.name = name - metadata.type = metadata_type - metadata.created_by = created_by - metadata.updated_by = None - metadata.updated_at = None - for key, value in kwargs.items(): - setattr(metadata, key, value) - return metadata - - @staticmethod - def create_metadata_binding_mock( - binding_id: str = "binding-123", - dataset_id: str = "dataset-123", - tenant_id: str = "tenant-123", - metadata_id: str = "metadata-123", - document_id: str = "document-123", - created_by: str = "user-123", - **kwargs, - ) -> Mock: - """ - Create a mock DatasetMetadataBinding with specified attributes. - - Args: - binding_id: Unique identifier for the binding - dataset_id: ID of the dataset - tenant_id: Tenant identifier - metadata_id: ID of the metadata field - document_id: ID of the document - created_by: ID of the user who created the binding - **kwargs: Additional attributes to set on the mock - - Returns: - Mock object configured as a DatasetMetadataBinding instance - """ - binding = Mock(spec=DatasetMetadataBinding) - binding.id = binding_id - binding.dataset_id = dataset_id - binding.tenant_id = tenant_id - binding.metadata_id = metadata_id - binding.document_id = document_id - binding.created_by = created_by - for key, value in kwargs.items(): - setattr(binding, key, value) - return binding - - @staticmethod - def create_dataset_mock( - dataset_id: str = "dataset-123", - tenant_id: str = "tenant-123", - built_in_field_enabled: bool = False, - doc_metadata: list | None = None, - **kwargs, - ) -> Mock: - """ - Create a mock Dataset with specified attributes. - - Args: - dataset_id: Unique identifier for the dataset - tenant_id: Tenant identifier - built_in_field_enabled: Whether built-in fields are enabled - doc_metadata: List of metadata field definitions - **kwargs: Additional attributes to set on the mock - - Returns: - Mock object configured as a Dataset instance - """ - dataset = Mock(spec=Dataset) - dataset.id = dataset_id - dataset.tenant_id = tenant_id - dataset.built_in_field_enabled = built_in_field_enabled - dataset.doc_metadata = doc_metadata or [] - for key, value in kwargs.items(): - setattr(dataset, key, value) - return dataset - - @staticmethod - def create_document_mock( - document_id: str = "document-123", - dataset_id: str = "dataset-123", - name: str = "Test Document", - doc_metadata: dict | None = None, - uploader: str = "user-123", - data_source_type: str = "upload_file", - **kwargs, - ) -> Mock: - """ - Create a mock Document with specified attributes. - - Args: - document_id: Unique identifier for the document - dataset_id: ID of the dataset this document belongs to - name: Name of the document - doc_metadata: Dictionary of metadata values - uploader: ID of the user who uploaded the document - data_source_type: Type of data source - **kwargs: Additional attributes to set on the mock - - Returns: - Mock object configured as a Document instance - """ - document = Mock() - document.id = document_id - document.dataset_id = dataset_id - document.name = name - document.doc_metadata = doc_metadata or {} - document.uploader = uploader - document.data_source_type = data_source_type - - # Mock datetime objects for upload_date and last_update_date - - document.upload_date = Mock() - document.upload_date.timestamp.return_value = 1234567890.0 - document.last_update_date = Mock() - document.last_update_date.timestamp.return_value = 1234567890.0 - - for key, value in kwargs.items(): - setattr(document, key, value) - return document - - @staticmethod - def create_metadata_args_mock( - name: str = "category", - metadata_type: str = "string", - ) -> Mock: - """ - Create a mock MetadataArgs entity. - - Args: - name: Name of the metadata field - metadata_type: Type of metadata - - Returns: - Mock object configured as a MetadataArgs instance - """ - metadata_args = Mock(spec=MetadataArgs) - metadata_args.name = name - metadata_args.type = metadata_type - return metadata_args - - @staticmethod - def create_metadata_value_mock( - metadata_id: str = "metadata-123", - name: str = "category", - value: str = "test", - ) -> Mock: - """ - Create a mock MetadataValue entity. - - Args: - metadata_id: ID of the metadata field - name: Name of the metadata field - value: Value of the metadata - - Returns: - Mock object configured as a MetadataValue instance - """ - metadata_value = Mock(spec=MetadataValue) - metadata_value.id = metadata_id - metadata_value.name = name - metadata_value.value = value - return metadata_value - - -# ============================================================================ -# Tests for create_metadata -# ============================================================================ - - -class TestMetadataServiceCreateMetadata: - """ - Comprehensive unit tests for MetadataService.create_metadata method. - - This test class covers the metadata field creation functionality, - including validation, duplicate checking, and database operations. - - The create_metadata method: - 1. Validates metadata name length (max 255 characters) - 2. Checks for duplicate metadata names within the dataset - 3. Checks for conflicts with built-in field names - 4. Creates a new DatasetMetadata instance - 5. Adds it to the database session and commits - 6. Returns the created metadata - - Test scenarios include: - - Successful creation with valid data - - Name length validation - - Duplicate name detection - - Built-in field name conflicts - - Database transaction handling - """ - - @pytest.fixture - def mock_db_session(self): - """ - Mock database session for testing database operations. - - Provides a mocked database session that can be used to verify: - - Query construction and execution - - Add operations for new metadata - - Commit operations for transaction completion - """ - with patch("services.metadata_service.db.session") as mock_db: - yield mock_db - - @pytest.fixture - def mock_current_user(self): - """ - Mock current user and tenant context. - - Provides mocked current_account_with_tenant function that returns - a user and tenant ID for testing authentication and authorization. - """ - with patch("services.metadata_service.current_account_with_tenant") as mock_get_user: - mock_user = Mock() - mock_user.id = "user-123" - mock_tenant_id = "tenant-123" - mock_get_user.return_value = (mock_user, mock_tenant_id) - yield mock_get_user - - def test_create_metadata_success(self, mock_db_session, mock_current_user): - """ - Test successful creation of a metadata field. - - Verifies that when all validation passes, a new metadata field - is created and persisted to the database. - - This test ensures: - - Metadata name validation passes - - No duplicate name exists - - No built-in field conflict - - New metadata is added to database - - Transaction is committed - - Created metadata is returned - """ - # Arrange - dataset_id = "dataset-123" - metadata_args = MetadataTestDataFactory.create_metadata_args_mock(name="category", metadata_type="string") - - # Mock query to return None (no existing metadata with same name) - mock_db_session.scalar.return_value = None - - # Mock BuiltInField enum iteration - with patch("services.metadata_service.BuiltInField") as mock_builtin: - mock_builtin.__iter__ = Mock(return_value=iter([])) - - # Act - result = MetadataService.create_metadata(dataset_id, metadata_args) - - # Assert - assert result is not None - assert isinstance(result, DatasetMetadata) - - # Verify metadata was added and committed - mock_db_session.add.assert_called_once() - mock_db_session.commit.assert_called_once() - - def test_create_metadata_name_too_long_error(self, mock_db_session, mock_current_user): - """ - Test error handling when metadata name exceeds 255 characters. - - Verifies that when a metadata name is longer than 255 characters, - a ValueError is raised with an appropriate message. - - This test ensures: - - Name length validation is enforced - - Error message is clear and descriptive - - No database operations are performed - """ - # Arrange - dataset_id = "dataset-123" - long_name = "a" * 256 # 256 characters (exceeds limit) - metadata_args = MetadataTestDataFactory.create_metadata_args_mock(name=long_name, metadata_type="string") - - # Act & Assert - with pytest.raises(ValueError, match="Metadata name cannot exceed 255 characters"): - MetadataService.create_metadata(dataset_id, metadata_args) - - # Verify no database operations were performed - mock_db_session.add.assert_not_called() - mock_db_session.commit.assert_not_called() - - def test_create_metadata_duplicate_name_error(self, mock_db_session, mock_current_user): - """ - Test error handling when metadata name already exists. - - Verifies that when a metadata field with the same name already exists - in the dataset, a ValueError is raised. - - This test ensures: - - Duplicate name detection works correctly - - Error message is clear - - No new metadata is created - """ - # Arrange - dataset_id = "dataset-123" - metadata_args = MetadataTestDataFactory.create_metadata_args_mock(name="category", metadata_type="string") - - # Mock existing metadata with same name - existing_metadata = MetadataTestDataFactory.create_metadata_mock(name="category") - mock_db_session.scalar.return_value = existing_metadata - - # Act & Assert - with pytest.raises(ValueError, match="Metadata name already exists"): - MetadataService.create_metadata(dataset_id, metadata_args) - - # Verify no new metadata was added - mock_db_session.add.assert_not_called() - mock_db_session.commit.assert_not_called() - - def test_create_metadata_builtin_field_conflict_error(self, mock_db_session, mock_current_user): - """ - Test error handling when metadata name conflicts with built-in field. - - Verifies that when a metadata name matches a built-in field name, - a ValueError is raised. - - This test ensures: - - Built-in field name conflicts are detected - - Error message is clear - - No new metadata is created - """ - # Arrange - dataset_id = "dataset-123" - metadata_args = MetadataTestDataFactory.create_metadata_args_mock( - name=BuiltInField.document_name, metadata_type="string" - ) - - # Mock query to return None (no duplicate in database) - mock_db_session.scalar.return_value = None - - # Mock BuiltInField to include the conflicting name - with patch("services.metadata_service.BuiltInField") as mock_builtin: - mock_field = Mock() - mock_field.value = BuiltInField.document_name - mock_builtin.__iter__ = Mock(return_value=iter([mock_field])) - - # Act & Assert - with pytest.raises(ValueError, match="Metadata name already exists in Built-in fields"): - MetadataService.create_metadata(dataset_id, metadata_args) - - # Verify no new metadata was added - mock_db_session.add.assert_not_called() - mock_db_session.commit.assert_not_called() - - -# ============================================================================ -# Tests for update_metadata_name -# ============================================================================ - - -class TestMetadataServiceUpdateMetadataName: - """ - Comprehensive unit tests for MetadataService.update_metadata_name method. - - This test class covers the metadata field name update functionality, - including validation, duplicate checking, and document metadata updates. - - The update_metadata_name method: - 1. Validates new name length (max 255 characters) - 2. Checks for duplicate names - 3. Checks for built-in field conflicts - 4. Acquires a lock for the dataset - 5. Updates the metadata name - 6. Updates all related document metadata - 7. Releases the lock - 8. Returns the updated metadata - - Test scenarios include: - - Successful name update - - Name length validation - - Duplicate name detection - - Built-in field conflicts - - Lock management - - Document metadata updates - """ - - @pytest.fixture - def mock_db_session(self): - """Mock database session for testing.""" - with patch("services.metadata_service.db.session") as mock_db: - yield mock_db - - @pytest.fixture - def mock_current_user(self): - """Mock current user and tenant context.""" - with patch("services.metadata_service.current_account_with_tenant") as mock_get_user: - mock_user = Mock() - mock_user.id = "user-123" - mock_tenant_id = "tenant-123" - mock_get_user.return_value = (mock_user, mock_tenant_id) - yield mock_get_user - - @pytest.fixture - def mock_redis_client(self): - """Mock Redis client for lock management.""" - with patch("services.metadata_service.redis_client") as mock_redis: - mock_redis.get.return_value = None # No existing lock - mock_redis.set.return_value = True - mock_redis.delete.return_value = True - yield mock_redis - - def test_update_metadata_name_success(self, mock_db_session, mock_current_user, mock_redis_client): - """ - Test successful update of metadata field name. - - Verifies that when all validation passes, the metadata name is - updated and all related document metadata is updated accordingly. - - This test ensures: - - Name validation passes - - Lock is acquired and released - - Metadata name is updated - - Related document metadata is updated - - Transaction is committed - """ - # Arrange - dataset_id = "dataset-123" - metadata_id = "metadata-123" - new_name = "updated_category" - - existing_metadata = MetadataTestDataFactory.create_metadata_mock(metadata_id=metadata_id, name="category") - - # Mock scalar calls: first for duplicate check (None), second for metadata retrieval - mock_db_session.scalar.side_effect = [None, existing_metadata] - - # Mock no metadata bindings (no documents to update) - mock_db_session.scalars.return_value.all.return_value = [] - - # Mock BuiltInField enum - with patch("services.metadata_service.BuiltInField") as mock_builtin: - mock_builtin.__iter__ = Mock(return_value=iter([])) - - # Act - result = MetadataService.update_metadata_name(dataset_id, metadata_id, new_name) - - # Assert - assert result is not None - assert result.name == new_name - - # Verify lock was acquired and released - mock_redis_client.get.assert_called() - mock_redis_client.set.assert_called() - mock_redis_client.delete.assert_called() - - # Verify metadata was updated and committed - mock_db_session.commit.assert_called() - - def test_update_metadata_name_not_found_error(self, mock_db_session, mock_current_user, mock_redis_client): - """ - Test error handling when metadata is not found. - - Verifies that when the metadata ID doesn't exist, a ValueError - is raised with an appropriate message. - - This test ensures: - - Not found error is handled correctly - - Lock is properly released even on error - - No updates are committed - """ - # Arrange - dataset_id = "dataset-123" - metadata_id = "non-existent-metadata" - new_name = "updated_category" - - # Mock scalar calls: first for duplicate check (None), second for metadata retrieval (None = not found) - mock_db_session.scalar.side_effect = [None, None] - - # Mock BuiltInField enum - with patch("services.metadata_service.BuiltInField") as mock_builtin: - mock_builtin.__iter__ = Mock(return_value=iter([])) - - # Act & Assert - with pytest.raises(ValueError, match="Metadata not found"): - MetadataService.update_metadata_name(dataset_id, metadata_id, new_name) - - # Verify lock was released - mock_redis_client.delete.assert_called() - - -# ============================================================================ -# Tests for delete_metadata -# ============================================================================ - - -class TestMetadataServiceDeleteMetadata: - """ - Comprehensive unit tests for MetadataService.delete_metadata method. - - This test class covers the metadata field deletion functionality, - including document metadata cleanup and lock management. - - The delete_metadata method: - 1. Acquires a lock for the dataset - 2. Retrieves the metadata to delete - 3. Deletes the metadata from the database - 4. Removes metadata from all related documents - 5. Releases the lock - 6. Returns the deleted metadata - - Test scenarios include: - - Successful deletion - - Not found error handling - - Document metadata cleanup - - Lock management - """ - - @pytest.fixture - def mock_db_session(self): - """Mock database session for testing.""" - with patch("services.metadata_service.db.session") as mock_db: - yield mock_db - - @pytest.fixture - def mock_redis_client(self): - """Mock Redis client for lock management.""" - with patch("services.metadata_service.redis_client") as mock_redis: - mock_redis.get.return_value = None - mock_redis.set.return_value = True - mock_redis.delete.return_value = True - yield mock_redis - - def test_delete_metadata_success(self, mock_db_session, mock_redis_client): - """ - Test successful deletion of a metadata field. - - Verifies that when the metadata exists, it is deleted and all - related document metadata is cleaned up. - - This test ensures: - - Lock is acquired and released - - Metadata is deleted from database - - Related document metadata is removed - - Transaction is committed - """ - # Arrange - dataset_id = "dataset-123" - metadata_id = "metadata-123" - - existing_metadata = MetadataTestDataFactory.create_metadata_mock(metadata_id=metadata_id, name="category") - - # Mock metadata retrieval - mock_db_session.scalar.return_value = existing_metadata - - # Mock no metadata bindings (no documents to update) - mock_db_session.scalars.return_value.all.return_value = [] - - # Act - result = MetadataService.delete_metadata(dataset_id, metadata_id) - - # Assert - assert result == existing_metadata - - # Verify lock was acquired and released - mock_redis_client.get.assert_called() - mock_redis_client.set.assert_called() - mock_redis_client.delete.assert_called() - - # Verify metadata was deleted and committed - mock_db_session.delete.assert_called_once_with(existing_metadata) - mock_db_session.commit.assert_called() - - def test_delete_metadata_not_found_error(self, mock_db_session, mock_redis_client): - """ - Test error handling when metadata is not found. - - Verifies that when the metadata ID doesn't exist, a ValueError - is raised and the lock is properly released. - - This test ensures: - - Not found error is handled correctly - - Lock is released even on error - - No deletion is performed - """ - # Arrange - dataset_id = "dataset-123" - metadata_id = "non-existent-metadata" - - # Mock metadata retrieval to return None - mock_db_session.scalar.return_value = None - - # Act & Assert - with pytest.raises(ValueError, match="Metadata not found"): - MetadataService.delete_metadata(dataset_id, metadata_id) - - # Verify lock was released - mock_redis_client.delete.assert_called() - - # Verify no deletion was performed - mock_db_session.delete.assert_not_called() - - -# ============================================================================ -# Tests for get_built_in_fields -# ============================================================================ - - -class TestMetadataServiceGetBuiltInFields: - """ - Comprehensive unit tests for MetadataService.get_built_in_fields method. - - This test class covers the built-in field retrieval functionality. - - The get_built_in_fields method: - 1. Returns a list of built-in field definitions - 2. Each definition includes name and type - - Test scenarios include: - - Successful retrieval of built-in fields - - Correct field definitions - """ - - def test_get_built_in_fields_success(self): - """ - Test successful retrieval of built-in fields. - - Verifies that the method returns the correct list of built-in - field definitions with proper structure. - - This test ensures: - - All built-in fields are returned - - Each field has name and type - - Field definitions are correct - """ - # Act - result = MetadataService.get_built_in_fields() - - # Assert - assert isinstance(result, list) - assert len(result) > 0 - - # Verify each field has required properties - for field in result: - assert "name" in field - assert "type" in field - assert isinstance(field["name"], str) - assert isinstance(field["type"], str) - - # Verify specific built-in fields are present - field_names = [field["name"] for field in result] - assert BuiltInField.document_name in field_names - assert BuiltInField.uploader in field_names - - -# ============================================================================ -# Tests for knowledge_base_metadata_lock_check -# ============================================================================ - - -class TestMetadataServiceLockCheck: - """ - Comprehensive unit tests for MetadataService.knowledge_base_metadata_lock_check method. - - This test class covers the lock management functionality for preventing - concurrent metadata operations. - - The knowledge_base_metadata_lock_check method: - 1. Checks if a lock exists for the dataset or document - 2. Raises ValueError if lock exists (operation in progress) - 3. Sets a lock with expiration time (3600 seconds) - 4. Supports both dataset-level and document-level locks - - Test scenarios include: - - Successful lock acquisition - - Lock conflict detection - - Dataset-level locks - - Document-level locks - """ - - @pytest.fixture - def mock_redis_client(self): - """Mock Redis client for lock management.""" - with patch("services.metadata_service.redis_client") as mock_redis: - yield mock_redis - - def test_lock_check_dataset_success(self, mock_redis_client): - """ - Test successful lock acquisition for dataset operations. - - Verifies that when no lock exists, a new lock is acquired - for the dataset. - - This test ensures: - - Lock check passes when no lock exists - - Lock is set with correct key and expiration - - No error is raised - """ - # Arrange - dataset_id = "dataset-123" - mock_redis_client.get.return_value = None # No existing lock - - # Act (should not raise) - MetadataService.knowledge_base_metadata_lock_check(dataset_id, None) - - # Assert - mock_redis_client.get.assert_called_once_with(f"dataset_metadata_lock_{dataset_id}") - mock_redis_client.set.assert_called_once_with(f"dataset_metadata_lock_{dataset_id}", 1, ex=3600) - - def test_lock_check_dataset_conflict_error(self, mock_redis_client): - """ - Test error handling when dataset lock already exists. - - Verifies that when a lock exists for the dataset, a ValueError - is raised with an appropriate message. - - This test ensures: - - Lock conflict is detected - - Error message is clear - - No new lock is set - """ - # Arrange - dataset_id = "dataset-123" - mock_redis_client.get.return_value = "1" # Lock exists - - # Act & Assert - with pytest.raises(ValueError, match="Another knowledge base metadata operation is running"): - MetadataService.knowledge_base_metadata_lock_check(dataset_id, None) - - # Verify lock was checked but not set - mock_redis_client.get.assert_called_once() - mock_redis_client.set.assert_not_called() - - def test_lock_check_document_success(self, mock_redis_client): - """ - Test successful lock acquisition for document operations. - - Verifies that when no lock exists, a new lock is acquired - for the document. - - This test ensures: - - Lock check passes when no lock exists - - Lock is set with correct key and expiration - - No error is raised - """ - # Arrange - document_id = "document-123" - mock_redis_client.get.return_value = None # No existing lock - - # Act (should not raise) - MetadataService.knowledge_base_metadata_lock_check(None, document_id) - - # Assert - mock_redis_client.get.assert_called_once_with(f"document_metadata_lock_{document_id}") - mock_redis_client.set.assert_called_once_with(f"document_metadata_lock_{document_id}", 1, ex=3600) - - -# ============================================================================ -# Tests for get_dataset_metadatas -# ============================================================================ - - -class TestMetadataServiceGetDatasetMetadatas: - """ - Comprehensive unit tests for MetadataService.get_dataset_metadatas method. - - This test class covers the metadata retrieval functionality for datasets. - - The get_dataset_metadatas method: - 1. Retrieves all metadata fields for a dataset - 2. Excludes built-in fields from the list - 3. Includes usage count for each metadata field - 4. Returns built-in field enabled status - - Test scenarios include: - - Successful retrieval with metadata fields - - Empty metadata list - - Built-in field filtering - - Usage count calculation - """ - - @pytest.fixture - def mock_db_session(self): - """Mock database session for testing.""" - with patch("services.metadata_service.db.session") as mock_db: - yield mock_db - - def test_get_dataset_metadatas_success(self, mock_db_session): - """ - Test successful retrieval of dataset metadata fields. - - Verifies that all metadata fields are returned with correct - structure and usage counts. - - This test ensures: - - All metadata fields are included - - Built-in fields are excluded - - Usage counts are calculated correctly - - Built-in field status is included - """ - # Arrange - dataset = MetadataTestDataFactory.create_dataset_mock( - dataset_id="dataset-123", - built_in_field_enabled=True, - doc_metadata=[ - {"id": "metadata-1", "name": "category", "type": "string"}, - {"id": "metadata-2", "name": "priority", "type": "number"}, - {"id": "built-in", "name": "document_name", "type": "string"}, - ], - ) - - # Mock usage count queries - mock_db_session.scalar.return_value = 5 # 5 documents use this metadata - - # Act - result = MetadataService.get_dataset_metadatas(dataset) - - # Assert - assert "doc_metadata" in result - assert "built_in_field_enabled" in result - assert result["built_in_field_enabled"] is True - - # Verify built-in fields are excluded - metadata_ids = [meta["id"] for meta in result["doc_metadata"]] - assert "built-in" not in metadata_ids - - # Verify all custom metadata fields are included - assert len(result["doc_metadata"]) == 2 - - # Verify usage counts are included - for meta in result["doc_metadata"]: - assert "count" in meta - assert meta["count"] == 5 - - -# ============================================================================ -# Additional Documentation and Notes -# ============================================================================ -# -# This test suite covers the core metadata CRUD operations and basic -# filtering functionality. Additional test scenarios that could be added: -# -# 1. enable_built_in_field / disable_built_in_field: -# - Testing built-in field enablement -# - Testing built-in field disablement -# - Testing document metadata updates when enabling/disabling -# -# 2. update_documents_metadata: -# - Testing partial updates -# - Testing full updates -# - Testing metadata binding creation -# - Testing built-in field updates -# -# 3. Metadata Filtering and Querying: -# - Testing metadata-based document filtering -# - Testing complex metadata queries -# - Testing metadata value retrieval -# -# These scenarios are not currently implemented but could be added if needed -# based on real-world usage patterns or discovered edge cases. -# -# ============================================================================ diff --git a/api/tests/unit_tests/services/dataset_permission_service.py b/api/tests/unit_tests/services/dataset_permission_service.py deleted file mode 100644 index e098e90455..0000000000 --- a/api/tests/unit_tests/services/dataset_permission_service.py +++ /dev/null @@ -1,825 +0,0 @@ -""" -Comprehensive unit tests for DatasetPermissionService and DatasetService permission methods. - -This module contains extensive unit tests for dataset permission management, -including partial member list operations, permission validation, and permission -enum handling. - -The DatasetPermissionService provides methods for: -- Retrieving partial member permissions (get_dataset_partial_member_list) -- Updating partial member lists (update_partial_member_list) -- Validating permissions before operations (check_permission) -- Clearing partial member lists (clear_partial_member_list) - -The DatasetService provides permission checking methods: -- check_dataset_permission - validates user access to dataset -- check_dataset_operator_permission - validates operator permissions - -These operations are critical for dataset access control and security, ensuring -that users can only access datasets they have permission to view or modify. - -This test suite ensures: -- Correct retrieval of partial member lists -- Proper update of partial member permissions -- Accurate permission validation logic -- Proper handling of permission enums (only_me, all_team_members, partial_members) -- Security boundaries are maintained -- Error conditions are handled correctly - -================================================================================ -ARCHITECTURE OVERVIEW -================================================================================ - -The Dataset permission system is a multi-layered access control mechanism -that provides fine-grained control over who can access and modify datasets. - -1. Permission Levels: - - only_me: Only the dataset creator can access - - all_team_members: All members of the tenant can access - - partial_members: Only specific users listed in DatasetPermission can access - -2. Permission Storage: - - Dataset.permission: Stores the permission level enum - - DatasetPermission: Stores individual user permissions for partial_members - - Each DatasetPermission record links a dataset to a user account - -3. Permission Validation: - - Tenant-level checks: Users must be in the same tenant - - Role-based checks: OWNER role bypasses some restrictions - - Explicit permission checks: For partial_members, explicit DatasetPermission - records are required - -4. Permission Operations: - - Partial member list management: Add/remove users from partial access - - Permission validation: Check before allowing operations - - Permission clearing: Remove all partial members when changing permission level - -================================================================================ -TESTING STRATEGY -================================================================================ - -This test suite follows a comprehensive testing strategy that covers: - -1. Partial Member List Operations: - - Retrieving member lists - - Adding new members - - Updating existing members - - Removing members - - Empty list handling - -2. Permission Validation: - - Dataset editor permissions - - Dataset operator restrictions - - Permission enum validation - - Partial member list validation - - Tenant isolation - -3. Permission Enum Handling: - - only_me permission behavior - - all_team_members permission behavior - - partial_members permission behavior - - Permission transitions - - Edge cases for each enum value - -4. Security and Access Control: - - Tenant boundary enforcement - - Role-based access control - - Creator privilege validation - - Explicit permission requirement - -5. Error Handling: - - Invalid permission changes - - Missing required data - - Database transaction failures - - Permission denial scenarios - -================================================================================ -""" - -from unittest.mock import Mock, create_autospec, patch - -import pytest - -from models import Account, TenantAccountRole -from models.dataset import ( - Dataset, - DatasetPermission, - DatasetPermissionEnum, -) -from services.dataset_service import DatasetPermissionService, DatasetService -from services.errors.account import NoPermissionError - -# ============================================================================ -# Test Data Factory -# ============================================================================ -# The Test Data Factory pattern is used here to centralize the creation of -# test objects and mock instances. This approach provides several benefits: -# -# 1. Consistency: All test objects are created using the same factory methods, -# ensuring consistent structure across all tests. -# -# 2. Maintainability: If the structure of models or services changes, we only -# need to update the factory methods rather than every individual test. -# -# 3. Reusability: Factory methods can be reused across multiple test classes, -# reducing code duplication. -# -# 4. Readability: Tests become more readable when they use descriptive factory -# method calls instead of complex object construction logic. -# -# ============================================================================ - - -class DatasetPermissionTestDataFactory: - """ - Factory class for creating test data and mock objects for dataset permission tests. - - This factory provides static methods to create mock objects for: - - Dataset instances with various permission configurations - - User/Account instances with different roles and permissions - - DatasetPermission instances - - Permission enum values - - Database query results - - The factory methods help maintain consistency across tests and reduce - code duplication when setting up test scenarios. - """ - - @staticmethod - def create_dataset_mock( - dataset_id: str = "dataset-123", - tenant_id: str = "tenant-123", - permission: DatasetPermissionEnum = DatasetPermissionEnum.ONLY_ME, - created_by: str = "user-123", - name: str = "Test Dataset", - **kwargs, - ) -> Mock: - """ - Create a mock Dataset with specified attributes. - - Args: - dataset_id: Unique identifier for the dataset - tenant_id: Tenant identifier - permission: Permission level enum - created_by: ID of user who created the dataset - name: Dataset name - **kwargs: Additional attributes to set on the mock - - Returns: - Mock object configured as a Dataset instance - """ - dataset = Mock(spec=Dataset) - dataset.id = dataset_id - dataset.tenant_id = tenant_id - dataset.permission = permission - dataset.created_by = created_by - dataset.name = name - for key, value in kwargs.items(): - setattr(dataset, key, value) - return dataset - - @staticmethod - def create_user_mock( - user_id: str = "user-123", - tenant_id: str = "tenant-123", - role: TenantAccountRole = TenantAccountRole.NORMAL, - is_dataset_editor: bool = True, - is_dataset_operator: bool = False, - **kwargs, - ) -> Mock: - """ - Create a mock user (Account) with specified attributes. - - Args: - user_id: Unique identifier for the user - tenant_id: Tenant identifier - role: User role (OWNER, ADMIN, NORMAL, DATASET_OPERATOR, etc.) - is_dataset_editor: Whether user has dataset editor permissions - is_dataset_operator: Whether user is a dataset operator - **kwargs: Additional attributes to set on the mock - - Returns: - Mock object configured as an Account instance - """ - user = create_autospec(Account, instance=True) - user.id = user_id - user.current_tenant_id = tenant_id - user.current_role = role - user.is_dataset_editor = is_dataset_editor - user.is_dataset_operator = is_dataset_operator - for key, value in kwargs.items(): - setattr(user, key, value) - return user - - @staticmethod - def create_dataset_permission_mock( - permission_id: str = "permission-123", - dataset_id: str = "dataset-123", - account_id: str = "user-456", - tenant_id: str = "tenant-123", - has_permission: bool = True, - **kwargs, - ) -> Mock: - """ - Create a mock DatasetPermission instance. - - Args: - permission_id: Unique identifier for the permission - dataset_id: Dataset ID - account_id: User account ID - tenant_id: Tenant identifier - has_permission: Whether permission is granted - **kwargs: Additional attributes to set on the mock - - Returns: - Mock object configured as a DatasetPermission instance - """ - permission = Mock(spec=DatasetPermission) - permission.id = permission_id - permission.dataset_id = dataset_id - permission.account_id = account_id - permission.tenant_id = tenant_id - permission.has_permission = has_permission - for key, value in kwargs.items(): - setattr(permission, key, value) - return permission - - @staticmethod - def create_user_list_mock(user_ids: list[str]) -> list[dict[str, str]]: - """ - Create a list of user dictionaries for partial member list operations. - - Args: - user_ids: List of user IDs to include - - Returns: - List of user dictionaries with "user_id" keys - """ - return [{"user_id": user_id} for user_id in user_ids] - - -# ============================================================================ -# Tests for check_permission -# ============================================================================ - - -class TestDatasetPermissionServiceCheckPermission: - """ - Comprehensive unit tests for DatasetPermissionService.check_permission method. - - This test class covers the permission validation logic that ensures - users have the appropriate permissions to modify dataset permissions. - - The check_permission method: - 1. Validates user is a dataset editor - 2. Checks if dataset operator is trying to change permissions - 3. Validates partial member list when setting to partial_members - 4. Ensures dataset operators cannot change permission levels - 5. Ensures dataset operators cannot modify partial member lists - - Test scenarios include: - - Valid permission changes by dataset editors - - Dataset operator restrictions - - Partial member list validation - - Missing dataset editor permissions - - Invalid permission changes - """ - - @pytest.fixture - def mock_get_partial_member_list(self): - """ - Mock get_dataset_partial_member_list method. - - Provides a mocked version of the get_dataset_partial_member_list - method for testing permission validation logic. - """ - with patch.object(DatasetPermissionService, "get_dataset_partial_member_list") as mock_get_list: - yield mock_get_list - - def test_check_permission_dataset_editor_success(self, mock_get_partial_member_list): - """ - Test successful permission check for dataset editor. - - Verifies that when a dataset editor (not operator) tries to - change permissions, the check passes. - - This test ensures: - - Dataset editors can change permissions - - No errors are raised for valid changes - - Partial member list validation is skipped for non-operators - """ - # Arrange - user = DatasetPermissionTestDataFactory.create_user_mock(is_dataset_editor=True, is_dataset_operator=False) - dataset = DatasetPermissionTestDataFactory.create_dataset_mock(permission=DatasetPermissionEnum.ONLY_ME) - requested_permission = DatasetPermissionEnum.ALL_TEAM - requested_partial_member_list = None - - # Act (should not raise) - DatasetPermissionService.check_permission(user, dataset, requested_permission, requested_partial_member_list) - - # Assert - # Verify get_partial_member_list was not called (not needed for non-operators) - mock_get_partial_member_list.assert_not_called() - - def test_check_permission_not_dataset_editor_error(self): - """ - Test error when user is not a dataset editor. - - Verifies that when a user without dataset editor permissions - tries to change permissions, a NoPermissionError is raised. - - This test ensures: - - Non-editors cannot change permissions - - Error message is clear - - Error type is correct - """ - # Arrange - user = DatasetPermissionTestDataFactory.create_user_mock(is_dataset_editor=False) - dataset = DatasetPermissionTestDataFactory.create_dataset_mock() - requested_permission = DatasetPermissionEnum.ALL_TEAM - requested_partial_member_list = None - - # Act & Assert - with pytest.raises(NoPermissionError, match="User does not have permission to edit this dataset"): - DatasetPermissionService.check_permission( - user, dataset, requested_permission, requested_partial_member_list - ) - - def test_check_permission_operator_cannot_change_permission_error(self): - """ - Test error when dataset operator tries to change permission level. - - Verifies that when a dataset operator tries to change the permission - level, a NoPermissionError is raised. - - This test ensures: - - Dataset operators cannot change permission levels - - Error message is clear - - Current permission is preserved - """ - # Arrange - user = DatasetPermissionTestDataFactory.create_user_mock(is_dataset_editor=True, is_dataset_operator=True) - dataset = DatasetPermissionTestDataFactory.create_dataset_mock(permission=DatasetPermissionEnum.ONLY_ME) - requested_permission = DatasetPermissionEnum.ALL_TEAM # Trying to change - requested_partial_member_list = None - - # Act & Assert - with pytest.raises(NoPermissionError, match="Dataset operators cannot change the dataset permissions"): - DatasetPermissionService.check_permission( - user, dataset, requested_permission, requested_partial_member_list - ) - - def test_check_permission_operator_partial_members_missing_list_error(self, mock_get_partial_member_list): - """ - Test error when operator sets partial_members without providing list. - - Verifies that when a dataset operator tries to set permission to - partial_members without providing a member list, a ValueError is raised. - - This test ensures: - - Partial member list is required for partial_members permission - - Error message is clear - - Error type is correct - """ - # Arrange - user = DatasetPermissionTestDataFactory.create_user_mock(is_dataset_editor=True, is_dataset_operator=True) - dataset = DatasetPermissionTestDataFactory.create_dataset_mock(permission=DatasetPermissionEnum.PARTIAL_TEAM) - requested_permission = "partial_members" - requested_partial_member_list = None # Missing list - - # Act & Assert - with pytest.raises(ValueError, match="Partial member list is required when setting to partial members"): - DatasetPermissionService.check_permission( - user, dataset, requested_permission, requested_partial_member_list - ) - - def test_check_permission_operator_cannot_modify_partial_list_error(self, mock_get_partial_member_list): - """ - Test error when operator tries to modify partial member list. - - Verifies that when a dataset operator tries to change the partial - member list, a ValueError is raised. - - This test ensures: - - Dataset operators cannot modify partial member lists - - Error message is clear - - Current member list is preserved - """ - # Arrange - user = DatasetPermissionTestDataFactory.create_user_mock(is_dataset_editor=True, is_dataset_operator=True) - dataset = DatasetPermissionTestDataFactory.create_dataset_mock(permission=DatasetPermissionEnum.PARTIAL_TEAM) - requested_permission = "partial_members" - - # Current member list - current_member_list = ["user-456", "user-789"] - mock_get_partial_member_list.return_value = current_member_list - - # Requested member list (different from current) - requested_partial_member_list = DatasetPermissionTestDataFactory.create_user_list_mock( - ["user-456", "user-999"] # Different list - ) - - # Act & Assert - with pytest.raises(ValueError, match="Dataset operators cannot change the dataset permissions"): - DatasetPermissionService.check_permission( - user, dataset, requested_permission, requested_partial_member_list - ) - - def test_check_permission_operator_can_keep_same_partial_list(self, mock_get_partial_member_list): - """ - Test that operator can keep the same partial member list. - - Verifies that when a dataset operator keeps the same partial member - list, the check passes. - - This test ensures: - - Operators can keep existing partial member lists - - No errors are raised for unchanged lists - - Permission validation works correctly - """ - # Arrange - user = DatasetPermissionTestDataFactory.create_user_mock(is_dataset_editor=True, is_dataset_operator=True) - dataset = DatasetPermissionTestDataFactory.create_dataset_mock(permission=DatasetPermissionEnum.PARTIAL_TEAM) - requested_permission = "partial_members" - - # Current member list - current_member_list = ["user-456", "user-789"] - mock_get_partial_member_list.return_value = current_member_list - - # Requested member list (same as current) - requested_partial_member_list = DatasetPermissionTestDataFactory.create_user_list_mock( - ["user-456", "user-789"] # Same list - ) - - # Act (should not raise) - DatasetPermissionService.check_permission(user, dataset, requested_permission, requested_partial_member_list) - - # Assert - # Verify get_partial_member_list was called to compare lists - mock_get_partial_member_list.assert_called_once_with(dataset.id) - - -# ============================================================================ -# Tests for DatasetService.check_dataset_permission -# ============================================================================ - - -class TestDatasetServiceCheckDatasetPermission: - """ - Comprehensive unit tests for DatasetService.check_dataset_permission method. - - This test class covers the dataset permission checking logic that validates - whether a user has access to a dataset based on permission enums. - - The check_dataset_permission method: - 1. Validates tenant match - 2. Checks OWNER role (bypasses some restrictions) - 3. Validates only_me permission (creator only) - 4. Validates partial_members permission (explicit permission required) - 5. Validates all_team_members permission (all tenant members) - - Test scenarios include: - - Tenant boundary enforcement - - OWNER role bypass - - only_me permission validation - - partial_members permission validation - - all_team_members permission validation - - Permission denial scenarios - """ - - @pytest.fixture - def mock_db_session(self): - """ - Mock database session for testing. - - Provides a mocked database session that can be used to verify - database queries for permission checks. - """ - with patch("services.dataset_service.db.session") as mock_db: - yield mock_db - - def test_check_dataset_permission_owner_bypass(self, mock_db_session): - """ - Test that OWNER role bypasses permission checks. - - Verifies that when a user has OWNER role, they can access any - dataset in their tenant regardless of permission level. - - This test ensures: - - OWNER role bypasses permission restrictions - - No database queries are needed for OWNER - - Access is granted automatically - """ - # Arrange - user = DatasetPermissionTestDataFactory.create_user_mock(role=TenantAccountRole.OWNER, tenant_id="tenant-123") - dataset = DatasetPermissionTestDataFactory.create_dataset_mock( - tenant_id="tenant-123", - permission=DatasetPermissionEnum.ONLY_ME, - created_by="other-user-123", # Not the current user - ) - - # Act (should not raise) - DatasetService.check_dataset_permission(dataset, user) - - # Assert - # Verify no permission queries were made (OWNER bypasses) - mock_db_session.query.assert_not_called() - - def test_check_dataset_permission_tenant_mismatch_error(self): - """ - Test error when user and dataset are in different tenants. - - Verifies that when a user tries to access a dataset from a different - tenant, a NoPermissionError is raised. - - This test ensures: - - Tenant boundary is enforced - - Error message is clear - - Error type is correct - """ - # Arrange - user = DatasetPermissionTestDataFactory.create_user_mock(tenant_id="tenant-123") - dataset = DatasetPermissionTestDataFactory.create_dataset_mock(tenant_id="tenant-456") # Different tenant - - # Act & Assert - with pytest.raises(NoPermissionError, match="You do not have permission to access this dataset"): - DatasetService.check_dataset_permission(dataset, user) - - def test_check_dataset_permission_only_me_creator_success(self): - """ - Test that creator can access only_me dataset. - - Verifies that when a user is the creator of an only_me dataset, - they can access it successfully. - - This test ensures: - - Creators can access their own only_me datasets - - No explicit permission record is needed - - Access is granted correctly - """ - # Arrange - user = DatasetPermissionTestDataFactory.create_user_mock(user_id="user-123", role=TenantAccountRole.NORMAL) - dataset = DatasetPermissionTestDataFactory.create_dataset_mock( - tenant_id="tenant-123", - permission=DatasetPermissionEnum.ONLY_ME, - created_by="user-123", # User is the creator - ) - - # Act (should not raise) - DatasetService.check_dataset_permission(dataset, user) - - def test_check_dataset_permission_only_me_non_creator_error(self): - """ - Test error when non-creator tries to access only_me dataset. - - Verifies that when a user who is not the creator tries to access - an only_me dataset, a NoPermissionError is raised. - - This test ensures: - - Non-creators cannot access only_me datasets - - Error message is clear - - Error type is correct - """ - # Arrange - user = DatasetPermissionTestDataFactory.create_user_mock(user_id="user-123", role=TenantAccountRole.NORMAL) - dataset = DatasetPermissionTestDataFactory.create_dataset_mock( - tenant_id="tenant-123", - permission=DatasetPermissionEnum.ONLY_ME, - created_by="other-user-456", # Different creator - ) - - # Act & Assert - with pytest.raises(NoPermissionError, match="You do not have permission to access this dataset"): - DatasetService.check_dataset_permission(dataset, user) - - def test_check_dataset_permission_partial_members_creator_success(self, mock_db_session): - """ - Test that creator can access partial_members dataset without explicit permission. - - Verifies that when a user is the creator of a partial_members dataset, - they can access it even without an explicit DatasetPermission record. - - This test ensures: - - Creators can access their own datasets - - No explicit permission record is needed for creators - - Access is granted correctly - """ - # Arrange - user = DatasetPermissionTestDataFactory.create_user_mock(user_id="user-123", role=TenantAccountRole.NORMAL) - dataset = DatasetPermissionTestDataFactory.create_dataset_mock( - tenant_id="tenant-123", - permission=DatasetPermissionEnum.PARTIAL_TEAM, - created_by="user-123", # User is the creator - ) - - # Act (should not raise) - DatasetService.check_dataset_permission(dataset, user) - - # Assert - # Verify permission query was not executed (creator bypasses) - mock_db_session.query.assert_not_called() - - def test_check_dataset_permission_all_team_members_success(self): - """ - Test that any tenant member can access all_team_members dataset. - - Verifies that when a dataset has all_team_members permission, any - user in the same tenant can access it. - - This test ensures: - - All team members can access - - No explicit permission record is needed - - Access is granted correctly - """ - # Arrange - user = DatasetPermissionTestDataFactory.create_user_mock(user_id="user-123", role=TenantAccountRole.NORMAL) - dataset = DatasetPermissionTestDataFactory.create_dataset_mock( - tenant_id="tenant-123", - permission=DatasetPermissionEnum.ALL_TEAM, - created_by="other-user-456", # Not the creator - ) - - # Act (should not raise) - DatasetService.check_dataset_permission(dataset, user) - - -# ============================================================================ -# Tests for DatasetService.check_dataset_operator_permission -# ============================================================================ - - -class TestDatasetServiceCheckDatasetOperatorPermission: - """ - Comprehensive unit tests for DatasetService.check_dataset_operator_permission method. - - This test class covers the dataset operator permission checking logic, - which validates whether a dataset operator has access to a dataset. - - The check_dataset_operator_permission method: - 1. Validates dataset exists - 2. Validates user exists - 3. Checks OWNER role (bypasses restrictions) - 4. Validates only_me permission (creator only) - 5. Validates partial_members permission (explicit permission required) - - Test scenarios include: - - Dataset not found error - - User not found error - - OWNER role bypass - - only_me permission validation - - partial_members permission validation - - Permission denial scenarios - """ - - @pytest.fixture - def mock_db_session(self): - """ - Mock database session for testing. - - Provides a mocked database session that can be used to verify - database queries for permission checks. - """ - with patch("services.dataset_service.db.session") as mock_db: - yield mock_db - - def test_check_dataset_operator_permission_dataset_not_found_error(self): - """ - Test error when dataset is None. - - Verifies that when dataset is None, a ValueError is raised. - - This test ensures: - - Dataset existence is validated - - Error message is clear - - Error type is correct - """ - # Arrange - user = DatasetPermissionTestDataFactory.create_user_mock() - dataset = None - - # Act & Assert - with pytest.raises(ValueError, match="Dataset not found"): - DatasetService.check_dataset_operator_permission(user=user, dataset=dataset) - - def test_check_dataset_operator_permission_user_not_found_error(self): - """ - Test error when user is None. - - Verifies that when user is None, a ValueError is raised. - - This test ensures: - - User existence is validated - - Error message is clear - - Error type is correct - """ - # Arrange - user = None - dataset = DatasetPermissionTestDataFactory.create_dataset_mock() - - # Act & Assert - with pytest.raises(ValueError, match="User not found"): - DatasetService.check_dataset_operator_permission(user=user, dataset=dataset) - - def test_check_dataset_operator_permission_owner_bypass(self): - """ - Test that OWNER role bypasses permission checks. - - Verifies that when a user has OWNER role, they can access any - dataset in their tenant regardless of permission level. - - This test ensures: - - OWNER role bypasses permission restrictions - - No database queries are needed for OWNER - - Access is granted automatically - """ - # Arrange - user = DatasetPermissionTestDataFactory.create_user_mock(role=TenantAccountRole.OWNER, tenant_id="tenant-123") - dataset = DatasetPermissionTestDataFactory.create_dataset_mock( - tenant_id="tenant-123", - permission=DatasetPermissionEnum.ONLY_ME, - created_by="other-user-123", # Not the current user - ) - - # Act (should not raise) - DatasetService.check_dataset_operator_permission(user=user, dataset=dataset) - - def test_check_dataset_operator_permission_only_me_creator_success(self): - """ - Test that creator can access only_me dataset. - - Verifies that when a user is the creator of an only_me dataset, - they can access it successfully. - - This test ensures: - - Creators can access their own only_me datasets - - No explicit permission record is needed - - Access is granted correctly - """ - # Arrange - user = DatasetPermissionTestDataFactory.create_user_mock(user_id="user-123", role=TenantAccountRole.NORMAL) - dataset = DatasetPermissionTestDataFactory.create_dataset_mock( - tenant_id="tenant-123", - permission=DatasetPermissionEnum.ONLY_ME, - created_by="user-123", # User is the creator - ) - - # Act (should not raise) - DatasetService.check_dataset_operator_permission(user=user, dataset=dataset) - - def test_check_dataset_operator_permission_only_me_non_creator_error(self): - """ - Test error when non-creator tries to access only_me dataset. - - Verifies that when a user who is not the creator tries to access - an only_me dataset, a NoPermissionError is raised. - - This test ensures: - - Non-creators cannot access only_me datasets - - Error message is clear - - Error type is correct - """ - # Arrange - user = DatasetPermissionTestDataFactory.create_user_mock(user_id="user-123", role=TenantAccountRole.NORMAL) - dataset = DatasetPermissionTestDataFactory.create_dataset_mock( - tenant_id="tenant-123", - permission=DatasetPermissionEnum.ONLY_ME, - created_by="other-user-456", # Different creator - ) - - # Act & Assert - with pytest.raises(NoPermissionError, match="You do not have permission to access this dataset"): - DatasetService.check_dataset_operator_permission(user=user, dataset=dataset) - - -# ============================================================================ -# Additional Documentation and Notes -# ============================================================================ -# -# This test suite covers the core permission management operations for datasets. -# Additional test scenarios that could be added: -# -# 1. Permission Enum Transitions: -# - Testing transitions between permission levels -# - Testing validation during transitions -# - Testing partial member list updates during transitions -# -# 2. Bulk Operations: -# - Testing bulk permission updates -# - Testing bulk partial member list updates -# - Testing performance with large member lists -# -# 3. Edge Cases: -# - Testing with very large partial member lists -# - Testing with special characters in user IDs -# - Testing with deleted users -# - Testing with inactive permissions -# -# 4. Integration Scenarios: -# - Testing permission changes followed by access attempts -# - Testing concurrent permission updates -# - Testing permission inheritance -# -# These scenarios are not currently implemented but could be added if needed -# based on real-world usage patterns or discovered edge cases. -# -# ============================================================================ diff --git a/api/tests/unit_tests/services/dataset_service_update_delete.py b/api/tests/unit_tests/services/dataset_service_update_delete.py deleted file mode 100644 index 62c39f96d3..0000000000 --- a/api/tests/unit_tests/services/dataset_service_update_delete.py +++ /dev/null @@ -1,818 +0,0 @@ -""" -Comprehensive unit tests for DatasetService update and delete operations. - -This module contains extensive unit tests for the DatasetService class, -specifically focusing on update and delete operations for datasets. - -The DatasetService provides methods for: -- Updating dataset configuration and settings (update_dataset) -- Deleting datasets with proper cleanup (delete_dataset) -- Updating RAG pipeline dataset settings (update_rag_pipeline_dataset_settings) -- Checking if dataset is in use (dataset_use_check) -- Updating dataset API access status (update_dataset_api_status) - -These operations are critical for dataset lifecycle management and require -careful handling of permissions, dependencies, and data integrity. - -This test suite ensures: -- Correct update of dataset properties -- Proper permission validation before updates/deletes -- Cascade deletion handling -- Event signaling for cleanup operations -- RAG pipeline dataset configuration updates -- API status management -- Use check validation - -================================================================================ -ARCHITECTURE OVERVIEW -================================================================================ - -The DatasetService update and delete operations are part of the dataset -lifecycle management system. These operations interact with multiple -components: - -1. Permission System: All update/delete operations require proper - permission validation to ensure users can only modify datasets they - have access to. - -2. Event System: Dataset deletion triggers the dataset_was_deleted event, - which notifies other components to clean up related data (documents, - segments, vector indices, etc.). - -3. Dependency Checking: Before deletion, the system checks if the dataset - is in use by any applications (via AppDatasetJoin). - -4. RAG Pipeline Integration: RAG pipeline datasets have special update - logic that handles chunk structure, indexing techniques, and embedding - model configuration. - -5. API Status Management: Datasets can have their API access enabled or - disabled, which affects whether they can be accessed via the API. - -================================================================================ -TESTING STRATEGY -================================================================================ - -This test suite follows a comprehensive testing strategy that covers: - -1. Update Operations: - - Internal dataset updates - - External dataset updates - - RAG pipeline dataset updates - - Permission validation - - Name duplicate checking - - Configuration validation - -2. Delete Operations: - - Successful deletion - - Permission validation - - Event signaling - - Database cleanup - - Not found handling - -3. Use Check Operations: - - Dataset in use detection - - Dataset not in use detection - - AppDatasetJoin query validation - -4. API Status Operations: - - Enable API access - - Disable API access - - Permission validation - - Current user validation - -5. RAG Pipeline Operations: - - Unpublished dataset updates - - Published dataset updates - - Chunk structure validation - - Indexing technique changes - - Embedding model configuration - -================================================================================ -""" - -import datetime -from unittest.mock import Mock, create_autospec, patch - -import pytest -from sqlalchemy.orm import Session - -from core.rag.index_processor.constant.index_type import IndexTechniqueType -from models import Account, TenantAccountRole -from models.dataset import ( - AppDatasetJoin, - Dataset, - DatasetPermissionEnum, -) -from services.dataset_service import DatasetService -from services.errors.account import NoPermissionError - -# ============================================================================ -# Test Data Factory -# ============================================================================ -# The Test Data Factory pattern is used here to centralize the creation of -# test objects and mock instances. This approach provides several benefits: -# -# 1. Consistency: All test objects are created using the same factory methods, -# ensuring consistent structure across all tests. -# -# 2. Maintainability: If the structure of models or services changes, we only -# need to update the factory methods rather than every individual test. -# -# 3. Reusability: Factory methods can be reused across multiple test classes, -# reducing code duplication. -# -# 4. Readability: Tests become more readable when they use descriptive factory -# method calls instead of complex object construction logic. -# -# ============================================================================ - - -class DatasetUpdateDeleteTestDataFactory: - """ - Factory class for creating test data and mock objects for dataset update/delete tests. - - This factory provides static methods to create mock objects for: - - Dataset instances with various configurations - - User/Account instances with different roles - - Knowledge configuration objects - - Database session mocks - - Event signal mocks - - The factory methods help maintain consistency across tests and reduce - code duplication when setting up test scenarios. - """ - - @staticmethod - def create_dataset_mock( - dataset_id: str = "dataset-123", - provider: str = "vendor", - name: str = "Test Dataset", - description: str = "Test description", - tenant_id: str = "tenant-123", - indexing_technique: str = IndexTechniqueType.HIGH_QUALITY, - embedding_model_provider: str | None = "openai", - embedding_model: str | None = "text-embedding-ada-002", - collection_binding_id: str | None = "binding-123", - enable_api: bool = True, - permission: DatasetPermissionEnum = DatasetPermissionEnum.ONLY_ME, - created_by: str = "user-123", - chunk_structure: str | None = None, - runtime_mode: str = "general", - **kwargs, - ) -> Mock: - """ - Create a mock Dataset with specified attributes. - - Args: - dataset_id: Unique identifier for the dataset - provider: Dataset provider (vendor, external) - name: Dataset name - description: Dataset description - tenant_id: Tenant identifier - indexing_technique: Indexing technique (high_quality, economy) - embedding_model_provider: Embedding model provider - embedding_model: Embedding model name - collection_binding_id: Collection binding ID - enable_api: Whether API access is enabled - permission: Dataset permission level - created_by: ID of user who created the dataset - chunk_structure: Chunk structure for RAG pipeline datasets - runtime_mode: Runtime mode (general, rag_pipeline) - **kwargs: Additional attributes to set on the mock - - Returns: - Mock object configured as a Dataset instance - """ - dataset = Mock(spec=Dataset) - dataset.id = dataset_id - dataset.provider = provider - dataset.name = name - dataset.description = description - dataset.tenant_id = tenant_id - dataset.indexing_technique = indexing_technique - dataset.embedding_model_provider = embedding_model_provider - dataset.embedding_model = embedding_model - dataset.collection_binding_id = collection_binding_id - dataset.enable_api = enable_api - dataset.permission = permission - dataset.created_by = created_by - dataset.chunk_structure = chunk_structure - dataset.runtime_mode = runtime_mode - dataset.retrieval_model = {} - dataset.keyword_number = 10 - for key, value in kwargs.items(): - setattr(dataset, key, value) - return dataset - - @staticmethod - def create_user_mock( - user_id: str = "user-123", - tenant_id: str = "tenant-123", - role: TenantAccountRole = TenantAccountRole.NORMAL, - is_dataset_editor: bool = True, - **kwargs, - ) -> Mock: - """ - Create a mock user (Account) with specified attributes. - - Args: - user_id: Unique identifier for the user - tenant_id: Tenant identifier - role: User role (OWNER, ADMIN, NORMAL, etc.) - is_dataset_editor: Whether user has dataset editor permissions - **kwargs: Additional attributes to set on the mock - - Returns: - Mock object configured as an Account instance - """ - user = create_autospec(Account, instance=True) - user.id = user_id - user.current_tenant_id = tenant_id - user.current_role = role - user.is_dataset_editor = is_dataset_editor - for key, value in kwargs.items(): - setattr(user, key, value) - return user - - @staticmethod - def create_knowledge_configuration_mock( - chunk_structure: str = "tree", - indexing_technique: str = IndexTechniqueType.HIGH_QUALITY, - embedding_model_provider: str = "openai", - embedding_model: str = "text-embedding-ada-002", - keyword_number: int = 10, - retrieval_model: dict | None = None, - **kwargs, - ) -> Mock: - """ - Create a mock KnowledgeConfiguration entity. - - Args: - chunk_structure: Chunk structure type - indexing_technique: Indexing technique - embedding_model_provider: Embedding model provider - embedding_model: Embedding model name - keyword_number: Keyword number for economy indexing - retrieval_model: Retrieval model configuration - **kwargs: Additional attributes to set on the mock - - Returns: - Mock object configured as a KnowledgeConfiguration instance - """ - config = Mock() - config.chunk_structure = chunk_structure - config.indexing_technique = indexing_technique - config.embedding_model_provider = embedding_model_provider - config.embedding_model = embedding_model - config.keyword_number = keyword_number - config.retrieval_model = Mock() - config.retrieval_model.model_dump.return_value = retrieval_model or { - "search_method": "semantic_search", - "top_k": 2, - } - for key, value in kwargs.items(): - setattr(config, key, value) - return config - - @staticmethod - def create_app_dataset_join_mock( - app_id: str = "app-123", - dataset_id: str = "dataset-123", - **kwargs, - ) -> Mock: - """ - Create a mock AppDatasetJoin instance. - - Args: - app_id: Application ID - dataset_id: Dataset ID - **kwargs: Additional attributes to set on the mock - - Returns: - Mock object configured as an AppDatasetJoin instance - """ - join = Mock(spec=AppDatasetJoin) - join.app_id = app_id - join.dataset_id = dataset_id - for key, value in kwargs.items(): - setattr(join, key, value) - return join - - -# ============================================================================ -# Tests for update_dataset -# ============================================================================ - - -class TestDatasetServiceUpdateDataset: - """ - Comprehensive unit tests for DatasetService.update_dataset method. - - This test class covers the dataset update functionality, including - internal and external dataset updates, permission validation, and - name duplicate checking. - - The update_dataset method: - 1. Retrieves the dataset by ID - 2. Validates dataset exists - 3. Checks for duplicate names - 4. Validates user permissions - 5. Routes to appropriate update handler (internal or external) - 6. Returns the updated dataset - - Test scenarios include: - - Successful internal dataset updates - - Successful external dataset updates - - Permission validation - - Duplicate name detection - - Dataset not found errors - """ - - @pytest.fixture - def mock_dataset_service_dependencies(self): - """ - Mock dataset service dependencies for testing. - - Provides mocked dependencies including: - - get_dataset method - - check_dataset_permission method - - _has_dataset_same_name method - - Database session - - Current time utilities - """ - with ( - patch("services.dataset_service.DatasetService.get_dataset") as mock_get_dataset, - patch("services.dataset_service.DatasetService.check_dataset_permission") as mock_check_perm, - patch("services.dataset_service.DatasetService._has_dataset_same_name") as mock_has_same_name, - patch("extensions.ext_database.db.session") as mock_db, - patch("services.dataset_service.naive_utc_now") as mock_naive_utc_now, - ): - current_time = datetime.datetime(2023, 1, 1, 12, 0, 0) - mock_naive_utc_now.return_value = current_time - - yield { - "get_dataset": mock_get_dataset, - "check_permission": mock_check_perm, - "has_same_name": mock_has_same_name, - "db_session": mock_db, - "naive_utc_now": mock_naive_utc_now, - "current_time": current_time, - } - - def test_update_dataset_internal_success(self, mock_dataset_service_dependencies): - """ - Test successful update of an internal dataset. - - Verifies that when all validation passes, an internal dataset - is updated correctly through the _update_internal_dataset method. - - This test ensures: - - Dataset is retrieved correctly - - Permission is checked - - Name duplicate check is performed - - Internal update handler is called - - Updated dataset is returned - """ - # Arrange - dataset_id = "dataset-123" - dataset = DatasetUpdateDeleteTestDataFactory.create_dataset_mock( - dataset_id=dataset_id, provider="vendor", name="Old Name" - ) - user = DatasetUpdateDeleteTestDataFactory.create_user_mock() - - update_data = { - "name": "New Name", - "description": "New Description", - } - - mock_dataset_service_dependencies["get_dataset"].return_value = dataset - mock_dataset_service_dependencies["has_same_name"].return_value = False - - with patch("services.dataset_service.DatasetService._update_internal_dataset") as mock_update_internal: - mock_update_internal.return_value = dataset - - # Act - result = DatasetService.update_dataset(dataset_id, update_data, user) - - # Assert - assert result == dataset - - # Verify dataset was retrieved - mock_dataset_service_dependencies["get_dataset"].assert_called_once_with(dataset_id) - - # Verify permission was checked - mock_dataset_service_dependencies["check_permission"].assert_called_once_with(dataset, user) - - # Verify name duplicate check was performed - mock_dataset_service_dependencies["has_same_name"].assert_called_once() - - # Verify internal update handler was called - mock_update_internal.assert_called_once() - - def test_update_dataset_external_success(self, mock_dataset_service_dependencies): - """ - Test successful update of an external dataset. - - Verifies that when all validation passes, an external dataset - is updated correctly through the _update_external_dataset method. - - This test ensures: - - Dataset is retrieved correctly - - Permission is checked - - Name duplicate check is performed - - External update handler is called - - Updated dataset is returned - """ - # Arrange - dataset_id = "dataset-123" - dataset = DatasetUpdateDeleteTestDataFactory.create_dataset_mock( - dataset_id=dataset_id, provider="external", name="Old Name" - ) - user = DatasetUpdateDeleteTestDataFactory.create_user_mock() - - update_data = { - "name": "New Name", - "external_knowledge_id": "new-knowledge-id", - } - - mock_dataset_service_dependencies["get_dataset"].return_value = dataset - mock_dataset_service_dependencies["has_same_name"].return_value = False - - with patch("services.dataset_service.DatasetService._update_external_dataset") as mock_update_external: - mock_update_external.return_value = dataset - - # Act - result = DatasetService.update_dataset(dataset_id, update_data, user) - - # Assert - assert result == dataset - - # Verify external update handler was called - mock_update_external.assert_called_once() - - def test_update_dataset_not_found_error(self, mock_dataset_service_dependencies): - """ - Test error handling when dataset is not found. - - Verifies that when the dataset ID doesn't exist, a ValueError - is raised with an appropriate message. - - This test ensures: - - Dataset not found error is handled correctly - - No update operations are performed - - Error message is clear - """ - # Arrange - dataset_id = "non-existent-dataset" - user = DatasetUpdateDeleteTestDataFactory.create_user_mock() - - update_data = {"name": "New Name"} - - mock_dataset_service_dependencies["get_dataset"].return_value = None - - # Act & Assert - with pytest.raises(ValueError, match="Dataset not found"): - DatasetService.update_dataset(dataset_id, update_data, user) - - # Verify no update operations were attempted - mock_dataset_service_dependencies["check_permission"].assert_not_called() - mock_dataset_service_dependencies["has_same_name"].assert_not_called() - - def test_update_dataset_duplicate_name_error(self, mock_dataset_service_dependencies): - """ - Test error handling when dataset name already exists. - - Verifies that when a dataset with the same name already exists - in the tenant, a ValueError is raised. - - This test ensures: - - Duplicate name detection works correctly - - Error message is clear - - No update operations are performed - """ - # Arrange - dataset_id = "dataset-123" - dataset = DatasetUpdateDeleteTestDataFactory.create_dataset_mock(dataset_id=dataset_id) - user = DatasetUpdateDeleteTestDataFactory.create_user_mock() - - update_data = {"name": "Existing Name"} - - mock_dataset_service_dependencies["get_dataset"].return_value = dataset - mock_dataset_service_dependencies["has_same_name"].return_value = True # Duplicate exists - - # Act & Assert - with pytest.raises(ValueError, match="Dataset name already exists"): - DatasetService.update_dataset(dataset_id, update_data, user) - - # Verify permission check was not called (fails before that) - mock_dataset_service_dependencies["check_permission"].assert_not_called() - - def test_update_dataset_permission_denied_error(self, mock_dataset_service_dependencies): - """ - Test error handling when user lacks permission. - - Verifies that when the user doesn't have permission to update - the dataset, a NoPermissionError is raised. - - This test ensures: - - Permission validation works correctly - - Error is raised before any updates - - Error type is correct - """ - # Arrange - dataset_id = "dataset-123" - dataset = DatasetUpdateDeleteTestDataFactory.create_dataset_mock(dataset_id=dataset_id) - user = DatasetUpdateDeleteTestDataFactory.create_user_mock() - - update_data = {"name": "New Name"} - - mock_dataset_service_dependencies["get_dataset"].return_value = dataset - mock_dataset_service_dependencies["has_same_name"].return_value = False - mock_dataset_service_dependencies["check_permission"].side_effect = NoPermissionError("No permission") - - # Act & Assert - with pytest.raises(NoPermissionError): - DatasetService.update_dataset(dataset_id, update_data, user) - - -# ============================================================================ -# Tests for update_rag_pipeline_dataset_settings -# ============================================================================ - - -class TestDatasetServiceUpdateRagPipelineDatasetSettings: - """ - Comprehensive unit tests for DatasetService.update_rag_pipeline_dataset_settings method. - - This test class covers the RAG pipeline dataset settings update functionality, - including chunk structure, indexing technique, and embedding model configuration. - - The update_rag_pipeline_dataset_settings method: - 1. Validates current_user and tenant - 2. Merges dataset into session - 3. Handles unpublished vs published datasets differently - 4. Updates chunk structure, indexing technique, and retrieval model - 5. Configures embedding model for high_quality indexing - 6. Updates keyword_number for economy indexing - 7. Commits transaction - 8. Triggers index update tasks if needed - - Test scenarios include: - - Unpublished dataset updates - - Published dataset updates - - Chunk structure validation - - Indexing technique changes - - Embedding model configuration - - Error handling - """ - - @pytest.fixture - def mock_session(self): - """ - Mock database session for testing. - - Provides a mocked SQLAlchemy session for testing session operations. - """ - return Mock(spec=Session) - - @pytest.fixture - def mock_dataset_service_dependencies(self): - """ - Mock dataset service dependencies for testing. - - Provides mocked dependencies including: - - current_user context - - ModelManager - - DatasetCollectionBindingService - - Database session operations - - Task scheduling - """ - with ( - patch( - "services.dataset_service.current_user", create_autospec(Account, instance=True) - ) as mock_current_user, - patch("services.dataset_service.ModelManager.for_tenant") as mock_model_manager, - patch( - "services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding" - ) as mock_get_binding, - patch("services.dataset_service.deal_dataset_index_update_task") as mock_task, - ): - mock_current_user.current_tenant_id = "tenant-123" - mock_current_user.id = "user-123" - - yield { - "current_user": mock_current_user, - "model_manager": mock_model_manager, - "get_binding": mock_get_binding, - "task": mock_task, - } - - def test_update_rag_pipeline_dataset_settings_unpublished_success( - self, mock_session, mock_dataset_service_dependencies - ): - """ - Test successful update of unpublished RAG pipeline dataset. - - Verifies that when a dataset is not published, all settings can - be updated including chunk structure and indexing technique. - - This test ensures: - - Current user validation passes - - Dataset is merged into session - - Chunk structure is updated - - Indexing technique is updated - - Embedding model is configured for high_quality - - Retrieval model is updated - - Dataset is added to session - """ - # Arrange - dataset = DatasetUpdateDeleteTestDataFactory.create_dataset_mock( - dataset_id="dataset-123", - runtime_mode="rag_pipeline", - chunk_structure="tree", - indexing_technique=IndexTechniqueType.HIGH_QUALITY, - ) - - knowledge_config = DatasetUpdateDeleteTestDataFactory.create_knowledge_configuration_mock( - chunk_structure="list", - indexing_technique=IndexTechniqueType.HIGH_QUALITY, - embedding_model_provider="openai", - embedding_model="text-embedding-ada-002", - ) - - # Mock embedding model - mock_embedding_model = Mock() - mock_embedding_model.model_name = "text-embedding-ada-002" - mock_embedding_model.provider = "openai" - mock_embedding_model.credentials = {} - - mock_model_schema = Mock() - mock_model_schema.features = [] - - mock_text_embedding_model = Mock() - mock_text_embedding_model.get_model_schema.return_value = mock_model_schema - mock_embedding_model.model_type_instance = mock_text_embedding_model - - mock_model_instance = Mock() - mock_model_instance.get_model_instance.return_value = mock_embedding_model - mock_dataset_service_dependencies["model_manager"].return_value = mock_model_instance - - # Mock collection binding - mock_binding = Mock() - mock_binding.id = "binding-123" - mock_dataset_service_dependencies["get_binding"].return_value = mock_binding - - mock_session.merge.return_value = dataset - - # Act - DatasetService.update_rag_pipeline_dataset_settings( - mock_session, dataset, knowledge_config, has_published=False - ) - - # Assert - assert dataset.chunk_structure == "list" - assert dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY - assert dataset.embedding_model == "text-embedding-ada-002" - assert dataset.embedding_model_provider == "openai" - assert dataset.collection_binding_id == "binding-123" - - # Verify dataset was added to session - mock_session.add.assert_called_once_with(dataset) - - def test_update_rag_pipeline_dataset_settings_published_chunk_structure_error( - self, mock_session, mock_dataset_service_dependencies - ): - """ - Test error handling when trying to update chunk structure of published dataset. - - Verifies that when a dataset is published and has an existing chunk structure, - attempting to change it raises a ValueError. - - This test ensures: - - Chunk structure change is detected - - ValueError is raised with appropriate message - - No updates are committed - """ - # Arrange - dataset = DatasetUpdateDeleteTestDataFactory.create_dataset_mock( - dataset_id="dataset-123", - runtime_mode="rag_pipeline", - chunk_structure="tree", # Existing structure - indexing_technique=IndexTechniqueType.HIGH_QUALITY, - ) - - knowledge_config = DatasetUpdateDeleteTestDataFactory.create_knowledge_configuration_mock( - chunk_structure="list", # Different structure - indexing_technique=IndexTechniqueType.HIGH_QUALITY, - ) - - mock_session.merge.return_value = dataset - - # Act & Assert - with pytest.raises(ValueError, match="Chunk structure is not allowed to be updated"): - DatasetService.update_rag_pipeline_dataset_settings( - mock_session, dataset, knowledge_config, has_published=True - ) - - # Verify no commit was attempted - mock_session.commit.assert_not_called() - - def test_update_rag_pipeline_dataset_settings_published_economy_error( - self, mock_session, mock_dataset_service_dependencies - ): - """ - Test error handling when trying to change to economy indexing on published dataset. - - Verifies that when a dataset is published, changing indexing technique to - economy is not allowed and raises a ValueError. - - This test ensures: - - Economy indexing change is detected - - ValueError is raised with appropriate message - - No updates are committed - """ - # Arrange - dataset = DatasetUpdateDeleteTestDataFactory.create_dataset_mock( - dataset_id="dataset-123", - runtime_mode="rag_pipeline", - indexing_technique=IndexTechniqueType.HIGH_QUALITY, # Current technique - ) - - knowledge_config = DatasetUpdateDeleteTestDataFactory.create_knowledge_configuration_mock( - indexing_technique=IndexTechniqueType.ECONOMY, # Trying to change to economy - ) - - mock_session.merge.return_value = dataset - - # Act & Assert - with pytest.raises( - ValueError, match="Knowledge base indexing technique is not allowed to be updated to economy" - ): - DatasetService.update_rag_pipeline_dataset_settings( - mock_session, dataset, knowledge_config, has_published=True - ) - - def test_update_rag_pipeline_dataset_settings_missing_current_user_error( - self, mock_session, mock_dataset_service_dependencies - ): - """ - Test error handling when current_user is missing. - - Verifies that when current_user is None or has no tenant ID, a ValueError - is raised. - - This test ensures: - - Current user validation works correctly - - Error message is clear - - No updates are performed - """ - # Arrange - dataset = DatasetUpdateDeleteTestDataFactory.create_dataset_mock() - knowledge_config = DatasetUpdateDeleteTestDataFactory.create_knowledge_configuration_mock() - - mock_dataset_service_dependencies["current_user"].current_tenant_id = None # Missing tenant - - # Act & Assert - with pytest.raises(ValueError, match="Current user or current tenant not found"): - DatasetService.update_rag_pipeline_dataset_settings( - mock_session, dataset, knowledge_config, has_published=False - ) - - -# ============================================================================ -# Additional Documentation and Notes -# ============================================================================ -# -# This test suite covers the core update and delete operations for datasets. -# Additional test scenarios that could be added: -# -# 1. Update Operations: -# - Testing with different indexing techniques -# - Testing embedding model provider changes -# - Testing retrieval model updates -# - Testing icon_info updates -# - Testing partial_member_list updates -# -# 2. Delete Operations: -# - Testing cascade deletion of related data -# - Testing event handler execution -# - Testing with datasets that have documents -# - Testing with datasets that have segments -# -# 3. RAG Pipeline Operations: -# - Testing economy indexing technique updates -# - Testing embedding model provider errors -# - Testing keyword_number updates -# - Testing index update task triggering -# -# 4. Integration Scenarios: -# - Testing update followed by delete -# - Testing multiple updates in sequence -# - Testing concurrent update attempts -# - Testing with different user roles -# -# These scenarios are not currently implemented but could be added if needed -# based on real-world usage patterns or discovered edge cases. -# -# ============================================================================ diff --git a/api/tests/unit_tests/services/external_dataset_service.py b/api/tests/unit_tests/services/external_dataset_service.py index 5848603ab8..dd41c0c97e 100644 --- a/api/tests/unit_tests/services/external_dataset_service.py +++ b/api/tests/unit_tests/services/external_dataset_service.py @@ -396,10 +396,11 @@ class TestExternalDatasetServiceUsageAndBindings: mock_db_session.scalar.return_value = 3 - in_use, count = ExternalDatasetService.external_knowledge_api_use_check("api-1") + in_use, count = ExternalDatasetService.external_knowledge_api_use_check("api-1", "tenant-1") assert in_use is True assert count == 3 + assert "tenant_id" in str(mock_db_session.scalar.call_args.args[0]) def test_external_knowledge_api_use_check_not_in_use(self, mock_db_session: MagicMock): """ @@ -408,7 +409,7 @@ class TestExternalDatasetServiceUsageAndBindings: mock_db_session.scalar.return_value = 0 - in_use, count = ExternalDatasetService.external_knowledge_api_use_check("api-1") + in_use, count = ExternalDatasetService.external_knowledge_api_use_check("api-1", "tenant-1") assert in_use is False assert count == 0 diff --git a/api/tests/unit_tests/services/plugin/test_plugin_auto_upgrade_service.py b/api/tests/unit_tests/services/plugin/test_plugin_auto_upgrade_service.py index bc2f1c6ecc..021bebceff 100644 --- a/api/tests/unit_tests/services/plugin/test_plugin_auto_upgrade_service.py +++ b/api/tests/unit_tests/services/plugin/test_plugin_auto_upgrade_service.py @@ -6,23 +6,23 @@ MODULE = "services.plugin.plugin_auto_upgrade_service" def _patched_session(): - """Patch sessionmaker(bind=db.engine).begin() to return a mock session as context manager.""" + """Patch session_factory.create_session() to return a mock session as context manager.""" session = MagicMock() - mock_sessionmaker = MagicMock() - mock_sessionmaker.return_value.begin.return_value.__enter__ = MagicMock(return_value=session) - mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False) - patcher = patch(f"{MODULE}.sessionmaker", mock_sessionmaker) - db_patcher = patch(f"{MODULE}.db") - return patcher, db_patcher, session + session.__enter__ = MagicMock(return_value=session) + session.__exit__ = MagicMock(return_value=False) + mock_factory = MagicMock() + mock_factory.create_session.return_value = session + patcher = patch(f"{MODULE}.session_factory", mock_factory) + return patcher, session class TestGetStrategy: def test_returns_strategy_when_found(self): - p1, p2, session = _patched_session() + p1, session = _patched_session() strategy = MagicMock() session.scalar.return_value = strategy - with p1, p2: + with p1: from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService result = PluginAutoUpgradeService.get_strategy("t1") @@ -30,10 +30,10 @@ class TestGetStrategy: assert result is strategy def test_returns_none_when_not_found(self): - p1, p2, session = _patched_session() + p1, session = _patched_session() session.scalar.return_value = None - with p1, p2: + with p1: from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService result = PluginAutoUpgradeService.get_strategy("t1") @@ -43,10 +43,10 @@ class TestGetStrategy: class TestChangeStrategy: def test_creates_new_strategy(self): - p1, p2, session = _patched_session() + p1, session = _patched_session() session.scalar.return_value = None - with p1, p2, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls: + with p1, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls: strat_cls.return_value = MagicMock() from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService @@ -63,11 +63,11 @@ class TestChangeStrategy: session.add.assert_called_once() def test_updates_existing_strategy(self): - p1, p2, session = _patched_session() + p1, session = _patched_session() existing = MagicMock() session.scalar.return_value = existing - with p1, p2: + with p1: from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService result = PluginAutoUpgradeService.change_strategy( @@ -89,12 +89,11 @@ class TestChangeStrategy: class TestExcludePlugin: def test_creates_default_strategy_when_none_exists(self): - p1, p2, session = _patched_session() + p1, session = _patched_session() session.scalar.return_value = None with ( p1, - p2, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls, patch(f"{MODULE}.PluginAutoUpgradeService.change_strategy") as cs, @@ -110,13 +109,13 @@ class TestExcludePlugin: cs.assert_called_once() def test_appends_to_exclude_list_in_exclude_mode(self): - p1, p2, session = _patched_session() + p1, session = _patched_session() existing = MagicMock() existing.upgrade_mode = "exclude" existing.exclude_plugins = ["p-existing"] session.scalar.return_value = existing - with p1, p2, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls: + with p1, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls: strat_cls.UpgradeMode.EXCLUDE = "exclude" strat_cls.UpgradeMode.PARTIAL = "partial" strat_cls.UpgradeMode.ALL = "all" @@ -128,13 +127,13 @@ class TestExcludePlugin: assert existing.exclude_plugins == ["p-existing", "p-new"] def test_removes_from_include_list_in_partial_mode(self): - p1, p2, session = _patched_session() + p1, session = _patched_session() existing = MagicMock() existing.upgrade_mode = "partial" existing.include_plugins = ["p1", "p2"] session.scalar.return_value = existing - with p1, p2, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls: + with p1, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls: strat_cls.UpgradeMode.EXCLUDE = "exclude" strat_cls.UpgradeMode.PARTIAL = "partial" strat_cls.UpgradeMode.ALL = "all" @@ -146,12 +145,12 @@ class TestExcludePlugin: assert existing.include_plugins == ["p2"] def test_switches_to_exclude_mode_from_all(self): - p1, p2, session = _patched_session() + p1, session = _patched_session() existing = MagicMock() existing.upgrade_mode = "all" session.scalar.return_value = existing - with p1, p2, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls: + with p1, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls: strat_cls.UpgradeMode.EXCLUDE = "exclude" strat_cls.UpgradeMode.PARTIAL = "partial" strat_cls.UpgradeMode.ALL = "all" @@ -164,13 +163,13 @@ class TestExcludePlugin: assert existing.exclude_plugins == ["p1"] def test_no_duplicate_in_exclude_list(self): - p1, p2, session = _patched_session() + p1, session = _patched_session() existing = MagicMock() existing.upgrade_mode = "exclude" existing.exclude_plugins = ["p1"] session.scalar.return_value = existing - with p1, p2, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls: + with p1, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls: strat_cls.UpgradeMode.EXCLUDE = "exclude" strat_cls.UpgradeMode.PARTIAL = "partial" strat_cls.UpgradeMode.ALL = "all" diff --git a/api/tests/unit_tests/services/plugin/test_plugin_permission_service.py b/api/tests/unit_tests/services/plugin/test_plugin_permission_service.py index 20f132c015..53a9e6210c 100644 --- a/api/tests/unit_tests/services/plugin/test_plugin_permission_service.py +++ b/api/tests/unit_tests/services/plugin/test_plugin_permission_service.py @@ -6,23 +6,25 @@ MODULE = "services.plugin.plugin_permission_service" def _patched_session(): - """Patch sessionmaker(bind=db.engine).begin() to return a mock session as context manager.""" + """Patch session_factory.create_session() to return a mock session as context manager.""" session = MagicMock() - mock_sessionmaker = MagicMock() - mock_sessionmaker.return_value.begin.return_value.__enter__ = MagicMock(return_value=session) - mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False) - patcher = patch(f"{MODULE}.sessionmaker", mock_sessionmaker) - db_patcher = patch(f"{MODULE}.db") - return patcher, db_patcher, session + session.__enter__ = MagicMock(return_value=session) + session.__exit__ = MagicMock(return_value=False) + session.begin.return_value.__enter__ = MagicMock(return_value=session) + session.begin.return_value.__exit__ = MagicMock(return_value=False) + mock_factory = MagicMock() + mock_factory.create_session.return_value = session + patcher = patch(f"{MODULE}.session_factory", mock_factory) + return patcher, session class TestGetPermission: def test_returns_permission_when_found(self): - p1, p2, session = _patched_session() + p1, session = _patched_session() permission = MagicMock() session.scalar.return_value = permission - with p1, p2: + with p1: from services.plugin.plugin_permission_service import PluginPermissionService result = PluginPermissionService.get_permission("t1") @@ -30,10 +32,10 @@ class TestGetPermission: assert result is permission def test_returns_none_when_not_found(self): - p1, p2, session = _patched_session() + p1, session = _patched_session() session.scalar.return_value = None - with p1, p2: + with p1: from services.plugin.plugin_permission_service import PluginPermissionService result = PluginPermissionService.get_permission("t1") @@ -43,10 +45,10 @@ class TestGetPermission: class TestChangePermission: def test_creates_new_permission_when_not_exists(self): - p1, p2, session = _patched_session() + p1, session = _patched_session() session.scalar.return_value = None - with p1, p2, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginPermission") as perm_cls: + with p1, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginPermission") as perm_cls: perm_cls.return_value = MagicMock() from services.plugin.plugin_permission_service import PluginPermissionService @@ -54,20 +56,24 @@ class TestChangePermission: "t1", TenantPluginPermission.InstallPermission.EVERYONE, TenantPluginPermission.DebugPermission.EVERYONE ) + assert result is True + session.begin.assert_called_once() session.add.assert_called_once() def test_updates_existing_permission(self): - p1, p2, session = _patched_session() + p1, session = _patched_session() existing = MagicMock() session.scalar.return_value = existing - with p1, p2: + with p1: from services.plugin.plugin_permission_service import PluginPermissionService result = PluginPermissionService.change_permission( "t1", TenantPluginPermission.InstallPermission.ADMINS, TenantPluginPermission.DebugPermission.ADMINS ) + assert result is True + session.begin.assert_called_once() assert existing.install_permission == TenantPluginPermission.InstallPermission.ADMINS assert existing.debug_permission == TenantPluginPermission.DebugPermission.ADMINS session.add.assert_not_called() diff --git a/api/tests/unit_tests/services/rag_pipeline/test_rag_pipeline_dsl_service.py b/api/tests/unit_tests/services/rag_pipeline/test_rag_pipeline_dsl_service.py index f4fdac5f9f..6813a1bf2a 100644 --- a/api/tests/unit_tests/services/rag_pipeline/test_rag_pipeline_dsl_service.py +++ b/api/tests/unit_tests/services/rag_pipeline/test_rag_pipeline_dsl_service.py @@ -247,10 +247,11 @@ workflow: dataset_mock = Mock() dataset_mock.id = "d1" mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.Dataset", return_value=dataset_mock) + mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.select", return_value=MagicMock()) session = cast(MagicMock, Mock()) service = RagPipelineDslService(session=cast(Session, session)) - session.query.return_value.filter_by.return_value.all.return_value = [] + session.scalars.return_value.all.return_value = [] account = Mock(current_tenant_id="t1") result = service.import_rag_pipeline(account=account, import_mode="yaml-content", yaml_content=yaml_content) @@ -320,6 +321,7 @@ workflow: dataset_mock.id = "d1" mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.Dataset", return_value=dataset_mock) mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.DatasetCollectionBinding", return_value=Mock(id="b1")) + mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.select", return_value=MagicMock()) service = RagPipelineDslService(session=Mock()) # Mocking self._session.scalar for the pipeline lookup @@ -406,12 +408,14 @@ def test_create_or_update_pipeline_create_new(mocker) -> None: mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.current_user", SimpleNamespace(id="u1")) mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.Workflow", return_value=Mock()) + mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.select", return_value=MagicMock()) pipeline_cls = mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.Pipeline") pipeline_instance = pipeline_cls.return_value pipeline_instance.tenant_id = "t1" pipeline_instance.id = "p1" pipeline_instance.name = "P" pipeline_instance.is_published = False + session.scalar.return_value = None result = service._create_or_update_pipeline(pipeline=None, data=data, account=account, dependencies=[]) @@ -447,8 +451,7 @@ def test_export_rag_pipeline_dsl_with_workflow(mocker) -> None: workflow.rag_pipeline_variables = [] workflow.to_dict.return_value = {"graph": {"nodes": []}} - # Mocking single .where() call - session.query.return_value.where.return_value.first.return_value = workflow + session.scalar.return_value = workflow mocker.patch( "services.rag_pipeline.rag_pipeline_dsl_service.DependenciesAnalysisService.generate_dependencies", return_value=[], @@ -550,7 +553,7 @@ def test_append_workflow_export_data_filters_credentials(mocker) -> None: ] } } - session.query.return_value.where.return_value.first.return_value = workflow + session.scalar.return_value = workflow mocker.patch( "services.rag_pipeline.rag_pipeline_dsl_service.DependenciesAnalysisService.generate_dependencies", return_value=[], @@ -568,7 +571,7 @@ def test_append_workflow_export_data_filters_credentials(mocker) -> None: def test_create_rag_pipeline_dataset_raises_when_name_conflicts(mocker) -> None: session = cast(MagicMock, Mock()) service = RagPipelineDslService(session=cast(Session, session)) - session.query.return_value.filter_by.return_value.first.return_value = Mock() + session.scalar.return_value = Mock() create_entity = RagPipelineDatasetCreateEntity( name="Existing Name", description="", @@ -584,8 +587,8 @@ def test_create_rag_pipeline_dataset_raises_when_name_conflicts(mocker) -> None: def test_create_rag_pipeline_dataset_generates_name_when_missing(mocker) -> None: session = cast(MagicMock, Mock()) service = RagPipelineDslService(session=cast(Session, session)) - session.query.return_value.filter_by.return_value.first.return_value = None - session.query.return_value.filter_by.return_value.all.return_value = [Mock(name="Untitled")] + session.scalar.return_value = None + session.scalars.return_value.all.return_value = [Mock(name="Untitled")] mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.generate_incremental_name", return_value="Untitled 2") mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.current_user", Mock(id="u1", current_tenant_id="t1")) mocker.patch.object( @@ -632,7 +635,7 @@ def test_append_workflow_export_data_encrypts_knowledge_retrieval_dataset_ids(mo ] } } - session.query.return_value.where.return_value.first.return_value = workflow + session.scalar.return_value = workflow mocker.patch.object(service, "encrypt_dataset_id", side_effect=lambda dataset_id, tenant_id: f"enc-{dataset_id}") mocker.patch( "services.rag_pipeline.rag_pipeline_dsl_service.DependenciesAnalysisService.generate_dependencies", @@ -727,7 +730,7 @@ def test_create_or_update_pipeline_decrypts_knowledge_retrieval_dataset_ids(mock }, } draft_workflow = Mock(id="wf1") - session.query.return_value.where.return_value.first.return_value = draft_workflow + session.scalar.return_value = draft_workflow mocker.patch.object(service, "decrypt_dataset_id", side_effect=["d1", None]) result = service._create_or_update_pipeline(pipeline=pipeline, data=data, account=account) @@ -743,7 +746,8 @@ def test_create_or_update_pipeline_creates_draft_when_missing(mocker) -> None: account = Mock(id="u1", current_tenant_id="t1") pipeline = Mock(id="p1", tenant_id="t1", name="N", description="D") data = {"rag_pipeline": {"name": "N2", "description": "D2"}, "workflow": {"graph": {"nodes": []}}} - session.query.return_value.where.return_value.first.return_value = None + session.scalar.return_value = None + mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.select", return_value=MagicMock()) workflow_cls = mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.Workflow") workflow_cls.return_value.id = "wf-new" @@ -817,7 +821,7 @@ def test_import_rag_pipeline_fails_for_non_string_version_type() -> None: def test_append_workflow_export_data_raises_when_draft_workflow_missing() -> None: session = cast(MagicMock, Mock()) service = RagPipelineDslService(session=cast(Session, session)) - session.query.return_value.where.return_value.first.return_value = None + session.scalar.return_value = None with pytest.raises(ValueError, match="Missing draft workflow configuration"): service._append_workflow_export_data(export_data={}, pipeline=Mock(tenant_id="t1"), include_secret=False) @@ -841,7 +845,7 @@ def test_append_workflow_export_data_keeps_secret_fields_when_include_secret_tru ] } } - session.query.return_value.where.return_value.first.return_value = workflow + session.scalar.return_value = workflow mocker.patch( "services.rag_pipeline.rag_pipeline_dsl_service.DependenciesAnalysisService.generate_dependencies", return_value=[], @@ -1003,7 +1007,8 @@ def test_import_rag_pipeline_sets_default_version_and_kind(mocker) -> None: ) dataset = Mock(id="d1") mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.Dataset", return_value=dataset) - session.query.return_value.filter_by.return_value.all.return_value = [] + mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.select", return_value=MagicMock()) + session.scalars.return_value.all.return_value = [] mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.generate_incremental_name", return_value="P") result = service.import_rag_pipeline( @@ -1061,7 +1066,7 @@ def test_append_workflow_export_data_skips_empty_node_data(mocker) -> None: workflow = Mock() workflow.graph_dict = {"nodes": []} workflow.to_dict.return_value = {"graph": {"nodes": [{"data": {}}, {}]}} - session.query.return_value.where.return_value.first.return_value = workflow + session.scalar.return_value = workflow mocker.patch( "services.rag_pipeline.rag_pipeline_dsl_service.DependenciesAnalysisService.generate_dependencies", return_value=[], @@ -1246,11 +1251,12 @@ def test_create_or_update_pipeline_saves_dependencies_to_redis(mocker) -> None: account = Mock(id="u1", current_tenant_id="t1") mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.current_user", SimpleNamespace(id="u1")) mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.Workflow", return_value=Mock(id="wf-1")) + mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.select", return_value=MagicMock()) pipeline_cls = mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.Pipeline") pipeline = pipeline_cls.return_value pipeline.tenant_id = "t1" pipeline.id = "p1" - session.query.return_value.where.return_value.first.return_value = None + session.scalar.return_value = None setex = mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.redis_client.setex") dependency = PluginDependency( type=PluginDependency.Type.Marketplace, diff --git a/api/tests/unit_tests/services/rag_pipeline/test_rag_pipeline_service.py b/api/tests/unit_tests/services/rag_pipeline/test_rag_pipeline_service.py index f270ee0fde..941a665308 100644 --- a/api/tests/unit_tests/services/rag_pipeline/test_rag_pipeline_service.py +++ b/api/tests/unit_tests/services/rag_pipeline/test_rag_pipeline_service.py @@ -116,81 +116,6 @@ def test_get_all_published_workflow_applies_limit_and_has_more(rag_pipeline_serv assert has_more is True -def test_get_pipeline_raises_when_dataset_not_found(mocker, rag_pipeline_service) -> None: - mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=None) - - with pytest.raises(ValueError, match="Dataset not found"): - rag_pipeline_service.get_pipeline("tenant-1", "dataset-1") - - -# --- update_customized_pipeline_template --- - - -def test_update_customized_pipeline_template_success(mocker) -> None: - template = SimpleNamespace(name="old", description="old", icon={}, updated_by=None) - - # First scalar finds the template, second scalar (duplicate check) returns None - mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[template, None]) - mocker.patch("services.rag_pipeline.rag_pipeline.db.session.commit") - mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1")) - - info = PipelineTemplateInfoEntity( - name="new", - description="new desc", - icon_info=IconInfo(icon="🔥"), - ) - result = RagPipelineService.update_customized_pipeline_template("tpl-1", info) - - assert result.name == "new" - assert result.description == "new desc" - - -def test_update_customized_pipeline_template_not_found(mocker) -> None: - mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=None) - mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1")) - - info = PipelineTemplateInfoEntity(name="x", description="d", icon_info=IconInfo(icon="i")) - with pytest.raises(ValueError, match="Customized pipeline template not found"): - RagPipelineService.update_customized_pipeline_template("tpl-missing", info) - - -def test_update_customized_pipeline_template_duplicate_name(mocker) -> None: - template = SimpleNamespace(name="old", description="old", icon={}, updated_by=None) - duplicate = SimpleNamespace(name="dup") - - mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[template, duplicate]) - mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1")) - - info = PipelineTemplateInfoEntity(name="dup", description="d", icon_info=IconInfo(icon="i")) - with pytest.raises(ValueError, match="Template name is already exists"): - RagPipelineService.update_customized_pipeline_template("tpl-1", info) - - -# --- delete_customized_pipeline_template --- - - -def test_delete_customized_pipeline_template_success(mocker) -> None: - template = SimpleNamespace(id="tpl-1") - mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=template) - delete_mock = mocker.patch("services.rag_pipeline.rag_pipeline.db.session.delete") - commit_mock = mocker.patch("services.rag_pipeline.rag_pipeline.db.session.commit") - - mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1")) - - RagPipelineService.delete_customized_pipeline_template("tpl-1") - - delete_mock.assert_called_once_with(template) - commit_mock.assert_called_once() - - -def test_delete_customized_pipeline_template_not_found(mocker) -> None: - mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=None) - mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1")) - - with pytest.raises(ValueError, match="Customized pipeline template not found"): - RagPipelineService.delete_customized_pipeline_template("tpl-missing") - - # --- sync_draft_workflow --- diff --git a/api/tests/unit_tests/services/test_account_service.py b/api/tests/unit_tests/services/test_account_service.py index d15074e7a6..eeb5d178ec 100644 --- a/api/tests/unit_tests/services/test_account_service.py +++ b/api/tests/unit_tests/services/test_account_service.py @@ -1427,16 +1427,7 @@ class TestRegisterService: mock_tenant.name = "Test Workspace" mock_inviter = TestAccountAssociatedDataFactory.create_account_mock(account_id="inviter-123", name="Inviter") - # Mock database queries - need to mock the sessionmaker query - mock_session = MagicMock() - mock_session.query.return_value.filter_by.return_value.first.return_value = None # No existing account - - mock_sessionmaker = MagicMock() - mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session - mock_sessionmaker.return_value.begin.return_value.__exit__.return_value = None - with ( - patch("services.account_service.sessionmaker", mock_sessionmaker), patch("services.account_service.AccountService.get_account_by_email_with_case_fallback") as mock_lookup, ): mock_lookup.return_value = None @@ -1475,7 +1466,7 @@ class TestRegisterService: status=AccountStatus.PENDING, is_setup=True, ) - mock_lookup.assert_called_once_with("newuser@example.com", session=mock_session) + mock_lookup.assert_called_once_with("newuser@example.com") def test_invite_new_member_normalizes_new_account_email( self, mock_db_dependencies, mock_redis_dependencies, mock_task_dependencies @@ -1486,13 +1477,7 @@ class TestRegisterService: mock_inviter = TestAccountAssociatedDataFactory.create_account_mock(account_id="inviter-123", name="Inviter") mixed_email = "Invitee@Example.com" - mock_session = MagicMock() - mock_sessionmaker = MagicMock() - mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session - mock_sessionmaker.return_value.begin.return_value.__exit__.return_value = None - with ( - patch("services.account_service.sessionmaker", mock_sessionmaker), patch("services.account_service.AccountService.get_account_by_email_with_case_fallback") as mock_lookup, ): mock_lookup.return_value = None @@ -1525,7 +1510,7 @@ class TestRegisterService: status=AccountStatus.PENDING, is_setup=True, ) - mock_lookup.assert_called_once_with(mixed_email, session=mock_session) + mock_lookup.assert_called_once_with(mixed_email) mock_check_permission.assert_called_once_with(mock_tenant, mock_inviter, None, "add") mock_create_member.assert_called_once_with(mock_tenant, mock_new_account, "normal") mock_switch_tenant.assert_called_once_with(mock_new_account, mock_tenant.id) @@ -1545,16 +1530,7 @@ class TestRegisterService: account_id="existing-user-456", email="existing@example.com", status="pending" ) - # Mock database queries - need to mock the sessionmaker query - mock_session = MagicMock() - mock_session.query.return_value.filter_by.return_value.first.return_value = mock_existing_account - - mock_sessionmaker = MagicMock() - mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session - mock_sessionmaker.return_value.begin.return_value.__exit__.return_value = None - with ( - patch("services.account_service.sessionmaker", mock_sessionmaker), patch("services.account_service.AccountService.get_account_by_email_with_case_fallback") as mock_lookup, ): mock_lookup.return_value = mock_existing_account @@ -1584,7 +1560,7 @@ class TestRegisterService: mock_create_member.assert_called_once_with(mock_tenant, mock_existing_account, "normal") mock_generate_token.assert_called_once_with(mock_tenant, mock_existing_account) mock_task_dependencies.delay.assert_called_once() - mock_lookup.assert_called_once_with("existing@example.com", session=mock_session) + mock_lookup.assert_called_once_with("existing@example.com") def test_invite_new_member_already_in_tenant(self, mock_db_dependencies, mock_redis_dependencies): """Test inviting a member who is already in the tenant.""" diff --git a/api/tests/unit_tests/services/test_billing_service.py b/api/tests/unit_tests/services/test_billing_service.py index 34f718ba02..36592196c6 100644 --- a/api/tests/unit_tests/services/test_billing_service.py +++ b/api/tests/unit_tests/services/test_billing_service.py @@ -1275,42 +1275,6 @@ class TestBillingServiceEdgeCases: # Assert assert result["history_id"] == history_id - def test_is_tenant_owner_or_admin_editor_role_raises_error(self): - """Test tenant owner/admin check raises error for editor role.""" - # Arrange - current_user = MagicMock(spec=Account) - current_user.id = "account-123" - current_user.current_tenant_id = "tenant-456" - - mock_join = MagicMock(spec=TenantAccountJoin) - mock_join.role = TenantAccountRole.EDITOR # Editor is not privileged - - with patch("services.billing_service.db.session") as mock_session: - mock_session.scalar.return_value = mock_join - - # Act & Assert - with pytest.raises(ValueError) as exc_info: - BillingService.is_tenant_owner_or_admin(current_user) - assert "Only team owner or team admin can perform this action" in str(exc_info.value) - - def test_is_tenant_owner_or_admin_dataset_operator_raises_error(self): - """Test tenant owner/admin check raises error for dataset operator role.""" - # Arrange - current_user = MagicMock(spec=Account) - current_user.id = "account-123" - current_user.current_tenant_id = "tenant-456" - - mock_join = MagicMock(spec=TenantAccountJoin) - mock_join.role = TenantAccountRole.DATASET_OPERATOR # Dataset operator is not privileged - - with patch("services.billing_service.db.session") as mock_session: - mock_session.scalar.return_value = mock_join - - # Act & Assert - with pytest.raises(ValueError) as exc_info: - BillingService.is_tenant_owner_or_admin(current_user) - assert "Only team owner or team admin can perform this action" in str(exc_info.value) - class TestBillingServiceSubscriptionOperations: """Unit tests for subscription operations in BillingService. diff --git a/api/tests/unit_tests/services/test_clear_free_plan_tenant_expired_logs.py b/api/tests/unit_tests/services/test_clear_free_plan_tenant_expired_logs.py index 3e989c55a3..1bbd214110 100644 --- a/api/tests/unit_tests/services/test_clear_free_plan_tenant_expired_logs.py +++ b/api/tests/unit_tests/services/test_clear_free_plan_tenant_expired_logs.py @@ -17,8 +17,7 @@ class TestClearFreePlanTenantExpiredLogs: def mock_session(self): """Create a mock database session.""" session = Mock(spec=Session) - session.query.return_value.filter.return_value.all.return_value = [] - session.query.return_value.filter.return_value.delete.return_value = 0 + session.scalars.return_value.all.return_value = [] return session @pytest.fixture @@ -54,18 +53,18 @@ class TestClearFreePlanTenantExpiredLogs: ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", []) # Should not call any database operations - mock_session.query.assert_not_called() + mock_session.scalars.assert_not_called() mock_storage.save.assert_not_called() def test_clear_message_related_tables_no_records_found(self, mock_session, sample_message_ids): """Test when no related records are found.""" with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage: - mock_session.query.return_value.where.return_value.all.return_value = [] + mock_session.scalars.return_value.all.return_value = [] ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids) - # Should call query for each related table but find no records - assert mock_session.query.call_count > 0 + # Should call scalars for each related table but find no records + assert mock_session.scalars.call_count > 0 mock_storage.save.assert_not_called() def test_clear_message_related_tables_with_records_and_to_dict( @@ -73,7 +72,7 @@ class TestClearFreePlanTenantExpiredLogs: ): """Test when records are found and have to_dict method.""" with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage: - mock_session.query.return_value.where.return_value.all.return_value = sample_records + mock_session.scalars.return_value.all.return_value = sample_records ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids) @@ -104,7 +103,7 @@ class TestClearFreePlanTenantExpiredLogs: records.append(record) # Mock records for first table only, empty for others - mock_session.query.return_value.where.return_value.all.side_effect = [ + mock_session.scalars.return_value.all.side_effect = [ records, [], [], @@ -126,13 +125,13 @@ class TestClearFreePlanTenantExpiredLogs: with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage: mock_storage.save.side_effect = Exception("Storage error") - mock_session.query.return_value.where.return_value.all.return_value = sample_records + mock_session.scalars.return_value.all.return_value = sample_records # Should not raise exception ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids) # Should still delete records even if backup fails - assert mock_session.query.return_value.where.return_value.delete.called + assert mock_session.execute.called def test_clear_message_related_tables_serialization_error_continues(self, mock_session, sample_message_ids): """Test that method continues even when record serialization fails.""" @@ -141,23 +140,23 @@ class TestClearFreePlanTenantExpiredLogs: record.id = "record-1" record.to_dict.side_effect = Exception("Serialization error") - mock_session.query.return_value.where.return_value.all.return_value = [record] + mock_session.scalars.return_value.all.return_value = [record] # Should not raise exception ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids) # Should still delete records even if serialization fails - assert mock_session.query.return_value.where.return_value.delete.called + assert mock_session.execute.called def test_clear_message_related_tables_deletion_called(self, mock_session, sample_message_ids, sample_records): """Test that deletion is called for found records.""" with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage: - mock_session.query.return_value.where.return_value.all.return_value = sample_records + mock_session.scalars.return_value.all.return_value = sample_records ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids) - # Should call delete for each table that has records - assert mock_session.query.return_value.where.return_value.delete.called + # Should call execute(delete(...)) for each table that has records + assert mock_session.execute.called def test_clear_message_related_tables_all_serialization_fails_skips_backup_but_deletes( self, mock_session, sample_message_ids @@ -167,12 +166,12 @@ class TestClearFreePlanTenantExpiredLogs: record.to_dict.side_effect = Exception("Serialization error") with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage: - mock_session.query.return_value.where.return_value.all.return_value = [record] + mock_session.scalars.return_value.all.return_value = [record] ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids) mock_storage.save.assert_not_called() - assert mock_session.query.return_value.where.return_value.delete.called + assert mock_session.execute.called class _ImmediateFuture: @@ -263,42 +262,23 @@ def test_process_tenant_processes_all_batches(monkeypatch: pytest.MonkeyPatch) - conv1 = SimpleNamespace(id="c1", to_dict=lambda: {"id": "c1"}) log1 = SimpleNamespace(id="l1", to_dict=lambda: {"id": "l1"}) - def make_query_with_batches(batches: list[list[object]]): - q = MagicMock() - q.where.return_value = q - q.limit.return_value = q - q.all.side_effect = batches - q.delete.return_value = 1 - return q - msg_session_1 = MagicMock() - msg_session_1.query.side_effect = lambda model: ( - make_query_with_batches([[msg1], []]) if model == service_module.Message else MagicMock() - ) + msg_session_1.scalars.return_value.all.return_value = [msg1] + msg_session_2 = MagicMock() - msg_session_2.query.side_effect = lambda model: ( - make_query_with_batches([[]]) if model == service_module.Message else MagicMock() - ) + msg_session_2.scalars.return_value.all.return_value = [] conv_session_1 = MagicMock() - conv_session_1.query.side_effect = lambda model: ( - make_query_with_batches([[conv1], []]) if model == service_module.Conversation else MagicMock() - ) + conv_session_1.scalars.return_value.all.return_value = [conv1] conv_session_2 = MagicMock() - conv_session_2.query.side_effect = lambda model: ( - make_query_with_batches([[]]) if model == service_module.Conversation else MagicMock() - ) + conv_session_2.scalars.return_value.all.return_value = [] wal_session_1 = MagicMock() - wal_session_1.query.side_effect = lambda model: ( - make_query_with_batches([[log1], []]) if model == service_module.WorkflowAppLog else MagicMock() - ) + wal_session_1.scalars.return_value.all.return_value = [log1] wal_session_2 = MagicMock() - wal_session_2.query.side_effect = lambda model: ( - make_query_with_batches([[]]) if model == service_module.WorkflowAppLog else MagicMock() - ) + wal_session_2.scalars.return_value.all.return_value = [] session_wrappers = [ _sessionmaker_wrapper_for_begin(msg_session_1), @@ -354,9 +334,7 @@ def test_process_with_tenant_ids_filters_by_plan_and_logs_errors(monkeypatch: py # Total tenant count query count_session = MagicMock() - count_query = MagicMock() - count_query.count.return_value = 2 - count_session.query.return_value = count_query + count_session.scalar.return_value = 2 monkeypatch.setattr(service_module, "sessionmaker", lambda _engine: _sessionmaker_wrapper_for_begin(count_session)) @@ -421,32 +399,15 @@ def test_process_without_tenant_ids_batches_and_scales_interval(monkeypatch: pyt # Sessions used: # 1) total tenant count - # 2) per-batch tenant scan (count + tenant list) + # 2) per-batch tenant scan (interval counts + tenant list) total_session = MagicMock() - total_query = MagicMock() - total_query.count.return_value = 250 - total_session.query.return_value = total_query - - batch_session = MagicMock() - q1 = MagicMock() - q1.where.return_value = q1 - q1.count.return_value = 200 - q2 = MagicMock() - q2.where.return_value = q2 - q2.count.return_value = 200 - q3 = MagicMock() - q3.where.return_value = q3 - q3.count.return_value = 200 - q4 = MagicMock() - q4.where.return_value = q4 - q4.count.return_value = 50 # choose this interval, then scale it + total_session.scalar.return_value = 250 rows = [SimpleNamespace(id="tenant-a"), SimpleNamespace(id="tenant-b")] - q_rs = MagicMock() - q_rs.where.return_value = q_rs - q_rs.order_by.return_value = rows - - batch_session.query.side_effect = [q1, q2, q3, q4, q_rs] + batch_session = MagicMock() + # 4 test intervals queried: 200, 200, 200, 50 — breaks on 50 <= 100 (4th interval = 3h) + batch_session.scalar.side_effect = [200, 200, 200, 50] + batch_session.execute.return_value = rows sessions = [_sessionmaker_wrapper_for_begin(total_session), _sessionmaker_wrapper_for_begin(batch_session)] monkeypatch.setattr(service_module, "sessionmaker", lambda _engine: sessions.pop(0)) @@ -464,9 +425,7 @@ def test_process_with_tenant_ids_emits_progress_every_100(monkeypatch: pytest.Mo monkeypatch.setattr(service_module, "db", SimpleNamespace(engine=object())) count_session = MagicMock() - count_query = MagicMock() - count_query.count.return_value = 100 - count_session.query.return_value = count_query + count_session.scalar.return_value = 100 monkeypatch.setattr(service_module, "sessionmaker", lambda _engine: _sessionmaker_wrapper_for_begin(count_session)) flask_app = service_module.Flask("test-app") @@ -513,25 +472,13 @@ def test_process_without_tenant_ids_all_intervals_too_many_uses_min_interval(mon monkeypatch.setattr(service_module.click, "echo", lambda *_args, **_kwargs: None) total_session = MagicMock() - total_query = MagicMock() - total_query.count.return_value = 250 - total_session.query.return_value = total_query - - batch_session = MagicMock() - # Count results for all 5 intervals, all > 100 => take the for-else path. - count_queries = [] - for _ in range(5): - q = MagicMock() - q.where.return_value = q - q.count.return_value = 200 - count_queries.append(q) + total_session.scalar.return_value = 250 rows = [SimpleNamespace(id="tenant-a")] - q_rs = MagicMock() - q_rs.where.return_value = q_rs - q_rs.order_by.return_value = rows - - batch_session.query.side_effect = [*count_queries, q_rs] + batch_session = MagicMock() + # All 5 intervals have > 100 tenants => for-else falls through to min interval (1h) + batch_session.scalar.side_effect = [200, 200, 200, 200, 200] + batch_session.execute.return_value = rows sessions = [_sessionmaker_wrapper_for_begin(total_session), _sessionmaker_wrapper_for_begin(batch_session)] monkeypatch.setattr(service_module, "sessionmaker", lambda _engine: sessions.pop(0)) @@ -542,8 +489,7 @@ def test_process_without_tenant_ids_all_intervals_too_many_uses_min_interval(mon ClearFreePlanTenantExpiredLogs.process(days=7, batch=10, tenant_ids=[]) assert process_tenant_mock.call_count == 1 - assert len(count_queries) == 5 - assert batch_session.query.call_count >= 6 + assert batch_session.scalar.call_count == 5 def test_process_tenant_repo_loops_break_on_empty_second_batch(monkeypatch: pytest.MonkeyPatch) -> None: @@ -565,11 +511,7 @@ def test_process_tenant_repo_loops_break_on_empty_second_batch(monkeypatch: pyte # Make message/conversation/workflow_app_log loops no-op (empty immediately) empty_session = MagicMock() - q_empty = MagicMock() - q_empty.where.return_value = q_empty - q_empty.limit.return_value = q_empty - q_empty.all.return_value = [] - empty_session.query.return_value = q_empty + empty_session.scalars.return_value.all.return_value = [] session_wrappers = [ _sessionmaker_wrapper_for_begin(empty_session), _sessionmaker_wrapper_for_begin(empty_session), diff --git a/api/tests/unit_tests/services/test_conversation_service.py b/api/tests/unit_tests/services/test_conversation_service.py index a4359f00b8..68f4c51afe 100644 --- a/api/tests/unit_tests/services/test_conversation_service.py +++ b/api/tests/unit_tests/services/test_conversation_service.py @@ -435,36 +435,6 @@ class TestConversationServiceRename: assert conversation.name == "New Name" mock_db_session.commit.assert_called_once() - @patch("services.conversation_service.db.session") - @patch("services.conversation_service.ConversationService.get_conversation") - @patch("services.conversation_service.ConversationService.auto_generate_name") - def test_rename_with_auto_generate(self, mock_auto_generate, mock_get_conversation, mock_db_session): - """ - Test renaming conversation with auto-generation. - - Should call auto_generate_name when auto_generate is True. - """ - # Arrange - app_model = ConversationServiceTestDataFactory.create_app_mock() - user = ConversationServiceTestDataFactory.create_account_mock() - conversation = ConversationServiceTestDataFactory.create_conversation_mock() - - mock_get_conversation.return_value = conversation - mock_auto_generate.return_value = conversation - - # Act - result = ConversationService.rename( - app_model=app_model, - conversation_id="conv-123", - user=user, - name=None, - auto_generate=True, - ) - - # Assert - assert result == conversation - mock_auto_generate.assert_called_once_with(app_model, conversation) - class TestConversationServiceAutoGenerateName: """Test conversation auto-name generation operations.""" @@ -576,29 +546,6 @@ class TestConversationServiceDelete: mock_db_session.commit.assert_called_once() mock_delete_task.delay.assert_called_once_with(conversation.id) - @patch("services.conversation_service.db.session") - @patch("services.conversation_service.ConversationService.get_conversation") - def test_delete_handles_exception_and_rollback(self, mock_get_conversation, mock_db_session): - """ - Test deletion handles exceptions and rolls back transaction. - - Should rollback database changes when deletion fails. - """ - # Arrange - app_model = ConversationServiceTestDataFactory.create_app_mock() - user = ConversationServiceTestDataFactory.create_account_mock() - conversation = ConversationServiceTestDataFactory.create_conversation_mock() - - mock_get_conversation.return_value = conversation - mock_db_session.delete.side_effect = Exception("Database Error") - - # Act & Assert - with pytest.raises(Exception, match="Database Error"): - ConversationService.delete(app_model, "conv-123", user) - - # Assert rollback was called - mock_db_session.rollback.assert_called_once() - class TestConversationServiceConversationalVariable: """Test conversational variable operations.""" diff --git a/api/tests/unit_tests/services/test_dataset_service_dataset.py b/api/tests/unit_tests/services/test_dataset_service_dataset.py index b2c40763ea..2913ae20fe 100644 --- a/api/tests/unit_tests/services/test_dataset_service_dataset.py +++ b/api/tests/unit_tests/services/test_dataset_service_dataset.py @@ -532,6 +532,9 @@ class TestDatasetServiceCreationAndUpdate: with ( patch.object(DatasetService, "_update_external_knowledge_binding") as update_binding, + patch( + "services.dataset_service.ExternalDatasetService.get_external_knowledge_api", return_value=object() + ) as get_external_knowledge_api, patch("services.dataset_service.naive_utc_now", return_value=now), patch("services.dataset_service.db") as mock_db, ): @@ -557,6 +560,7 @@ class TestDatasetServiceCreationAndUpdate: assert dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM assert dataset.updated_by == "user-1" assert dataset.updated_at is now + get_external_knowledge_api.assert_called_once_with("api-1", dataset.tenant_id) update_binding.assert_called_once_with("dataset-1", "knowledge-1", "api-1") mock_db.session.add.assert_called_once_with(dataset) mock_db.session.commit.assert_called_once() @@ -574,10 +578,35 @@ class TestDatasetServiceCreationAndUpdate: with pytest.raises(ValueError, match=message): DatasetService._update_external_dataset(dataset, payload, SimpleNamespace(id="user-1")) + def test_update_external_dataset_rejects_cross_tenant_external_api_id(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1") + + with ( + patch( + "services.dataset_service.ExternalDatasetService.get_external_knowledge_api", + side_effect=ValueError("api template not found"), + ) as get_external_knowledge_api, + patch.object(DatasetService, "_update_external_knowledge_binding") as update_binding, + patch("services.dataset_service.db") as mock_db, + ): + with pytest.raises(ValueError, match="api template not found"): + DatasetService._update_external_dataset( + dataset, + { + "external_knowledge_id": "knowledge-1", + "external_knowledge_api_id": "foreign-api", + }, + SimpleNamespace(id="user-1"), + ) + + get_external_knowledge_api.assert_called_once_with("foreign-api", dataset.tenant_id) + update_binding.assert_not_called() + mock_db.session.commit.assert_not_called() + def test_update_external_knowledge_binding_updates_changed_binding_values(self): binding = SimpleNamespace(external_knowledge_id="old-knowledge", external_knowledge_api_id="old-api") session = MagicMock() - session.query.return_value.filter_by.return_value.first.return_value = binding + session.scalar.return_value = binding session.add = MagicMock() session_context = _make_session_context(session) @@ -596,7 +625,7 @@ class TestDatasetServiceCreationAndUpdate: def test_update_external_knowledge_binding_raises_for_missing_binding(self): session = MagicMock() - session.query.return_value.filter_by.return_value.first.return_value = None + session.scalar.return_value = None session_context = _make_session_context(session) mock_sessionmaker = MagicMock() diff --git a/api/tests/unit_tests/services/test_dataset_service_document.py b/api/tests/unit_tests/services/test_dataset_service_document.py index e5a2541da7..3f9386e704 100644 --- a/api/tests/unit_tests/services/test_dataset_service_document.py +++ b/api/tests/unit_tests/services/test_dataset_service_document.py @@ -129,7 +129,7 @@ class TestDocumentServiceQueryAndDownloadHelpers: def test_update_documents_need_summary_updates_matching_documents_and_commits(self): session = MagicMock() - session.query.return_value.filter.return_value.update.return_value = 2 + session.execute.return_value.rowcount = 2 with patch("services.dataset_service.session_factory") as session_factory_mock: session_factory_mock.create_session.return_value = _make_session_context(session) @@ -1069,6 +1069,33 @@ class TestDocumentServiceCreateValidation: assert len(knowledge_config.process_rule.rules.pre_processing_rules) == 1 assert knowledge_config.process_rule.rules.pre_processing_rules[0].enabled is False + def test_process_rule_args_validate_hierarchical_defaults_parent_mode_to_paragraph(self): + knowledge_config = KnowledgeConfig( + indexing_technique="economy", + data_source=DataSource( + info_list=InfoList( + data_source_type="upload_file", + file_info_list=FileInfo(file_ids=["file-1"]), + ) + ), + process_rule=ProcessRule( + mode="hierarchical", + rules=Rule( + pre_processing_rules=[ + PreProcessingRule(id="remove_extra_spaces", enabled=True), + ], + segmentation=Segmentation(separator="\n", max_tokens=1024), + subchunk_segmentation=Segmentation(separator="\n", max_tokens=512), + ), + ), + ) + + DocumentService.process_rule_args_validate(knowledge_config) + + assert knowledge_config.process_rule is not None + assert knowledge_config.process_rule.rules is not None + assert knowledge_config.process_rule.rules.parent_mode == "paragraph" + class TestDocumentServiceSaveDocumentWithDatasetId: """Unit tests for non-SQL validation branches in save_document_with_dataset_id.""" diff --git a/api/tests/unit_tests/services/test_datasource_provider_service.py b/api/tests/unit_tests/services/test_datasource_provider_service.py index 70ecc158d6..c00a4938bb 100644 --- a/api/tests/unit_tests/services/test_datasource_provider_service.py +++ b/api/tests/unit_tests/services/test_datasource_provider_service.py @@ -179,11 +179,11 @@ class TestDatasourceProviderService: # ----------------------------------------------------------------------- def test_should_return_true_when_system_oauth_params_exist(self, service, mock_db_session): - mock_db_session.query().first.return_value = MagicMock() + mock_db_session.scalar.return_value = MagicMock() assert service.is_system_oauth_params_exist(make_id()) is True def test_should_return_false_when_system_oauth_params_missing(self, service, mock_db_session): - mock_db_session.query().first.return_value = None + mock_db_session.scalar.return_value = None assert service.is_system_oauth_params_exist(make_id()) is False # ----------------------------------------------------------------------- @@ -205,7 +205,7 @@ class TestDatasourceProviderService: def test_should_delete_tenant_config_when_removing_oauth_params(self, service, mock_db_session): service.remove_oauth_custom_client_params("t1", make_id()) - mock_db_session.query().delete.assert_called_once() + mock_db_session.execute.assert_called_once() # ----------------------------------------------------------------------- # setup_oauth_custom_client_params (315-351) @@ -217,14 +217,14 @@ class TestDatasourceProviderService: mock_db_session.add.assert_not_called() def test_should_create_new_config_when_none_exists(self, service, mock_db_session): - mock_db_session.query().first.return_value = None + mock_db_session.scalar.return_value = None with patch.object(service, "get_oauth_encrypter", return_value=(self._enc, None)): service.setup_oauth_custom_client_params("t1", make_id(), {"k": "v"}, True) mock_db_session.add.assert_called_once() def test_should_update_existing_config_when_record_found(self, service, mock_db_session): existing = MagicMock() - mock_db_session.query().first.return_value = existing + mock_db_session.scalar.return_value = existing with patch.object(service, "get_oauth_encrypter", return_value=(self._enc, None)): service.setup_oauth_custom_client_params("t1", make_id(), {"k": "v"}, False) mock_db_session.add.assert_not_called() # update in place, no add @@ -255,7 +255,7 @@ class TestDatasourceProviderService: def test_should_return_empty_dict_when_credential_not_found(self, service, mock_db_session, mock_user): with patch("services.datasource_provider_service.get_current_user", return_value=mock_user): - mock_db_session.query().first.return_value = None + mock_db_session.scalar.return_value = None assert service.get_datasource_credentials("t1", "prov", "org/plug") == {} def test_should_refresh_oauth_tokens_when_expired(self, service, mock_db_session, mock_user): @@ -264,7 +264,7 @@ class TestDatasourceProviderService: p.auth_type = "oauth2" p.expires_at = 0 # expired p.encrypted_credentials = {"tok": "x"} - mock_db_session.query().first.return_value = p + mock_db_session.scalar.return_value = p with ( patch("services.datasource_provider_service.get_current_user", return_value=mock_user), patch.object(service, "get_oauth_client", return_value={"oc": "v"}), @@ -278,7 +278,7 @@ class TestDatasourceProviderService: p.auth_type = "api_key" p.expires_at = -1 # sentinel: never expires p.encrypted_credentials = {"k": "v"} - mock_db_session.query().first.return_value = p + mock_db_session.scalar.return_value = p with ( patch("services.datasource_provider_service.get_current_user", return_value=mock_user), patch.object(service, "decrypt_datasource_provider_credentials", return_value={"k": "plain"}), @@ -292,7 +292,7 @@ class TestDatasourceProviderService: p.auth_type = "api_key" p.expires_at = -1 p.encrypted_credentials = {} - mock_db_session.query().first.return_value = p + mock_db_session.scalar.return_value = p with ( patch("services.datasource_provider_service.get_current_user", return_value=mock_user), patch.object(service, "decrypt_datasource_provider_credentials", return_value={"k": "v"}), @@ -306,7 +306,7 @@ class TestDatasourceProviderService: def test_should_return_empty_list_when_no_provider_credentials_exist(self, service, mock_db_session, mock_user): with patch("services.datasource_provider_service.get_current_user", return_value=mock_user): - mock_db_session.query().all.return_value = [] + mock_db_session.scalars.return_value.all.return_value = [] assert service.get_all_datasource_credentials_by_provider("t1", "prov", "org/plug") == [] def test_should_refresh_and_return_credentials_when_oauth_expired(self, service, mock_db_session, mock_user): @@ -314,7 +314,7 @@ class TestDatasourceProviderService: p.auth_type = "oauth2" p.expires_at = 0 p.encrypted_credentials = {"t": "x"} - mock_db_session.query().all.return_value = [p] + mock_db_session.scalars.return_value.all.return_value = [p] with ( patch("services.datasource_provider_service.get_current_user", return_value=mock_user), patch.object(service, "get_oauth_client", return_value={"oc": "v"}), @@ -328,22 +328,21 @@ class TestDatasourceProviderService: # ----------------------------------------------------------------------- def test_should_raise_value_error_when_provider_not_found_on_name_update(self, service, mock_db_session): - mock_db_session.query().first.return_value = None + mock_db_session.scalar.return_value = None with pytest.raises(ValueError, match="not found"): service.update_datasource_provider_name("t1", make_id(), "new", "cred-id") def test_should_return_early_when_new_name_matches_current(self, service, mock_db_session): p = MagicMock(spec=DatasourceProvider) p.name = "same" - mock_db_session.query().first.return_value = p + mock_db_session.scalar.return_value = p service.update_datasource_provider_name("t1", make_id(), "same", "cred-id") def test_should_raise_value_error_when_name_already_exists(self, service, mock_db_session): p = MagicMock(spec=DatasourceProvider) p.name = "old_name" p.is_default = False - mock_db_session.query().first.return_value = p - mock_db_session.query().count.return_value = 1 # conflict + mock_db_session.scalar.side_effect = [p, 1] # first: fetch provider, second: name conflict count with pytest.raises(ValueError, match="already exists"): service.update_datasource_provider_name("t1", make_id(), "new_name", "some-id") @@ -351,8 +350,7 @@ class TestDatasourceProviderService: p = MagicMock(spec=DatasourceProvider) p.name = "old_name" p.is_default = False - mock_db_session.query().first.return_value = p - mock_db_session.query().count.return_value = 0 + mock_db_session.scalar.side_effect = [p, 0] # first: fetch provider, second: name conflict count service.update_datasource_provider_name("t1", make_id(), "new_name", "some-id") assert p.name == "new_name" @@ -361,7 +359,7 @@ class TestDatasourceProviderService: # ----------------------------------------------------------------------- def test_should_raise_value_error_when_target_provider_not_found(self, service, mock_db_session): - mock_db_session.query().first.return_value = None + mock_db_session.scalar.return_value = None with pytest.raises(ValueError, match="not found"): service.set_default_datasource_provider("t1", make_id(), "bad-id") @@ -369,7 +367,7 @@ class TestDatasourceProviderService: target = MagicMock(spec=DatasourceProvider) target.provider = "provider" target.plugin_id = "org/plug" - mock_db_session.query().first.return_value = target + mock_db_session.scalar.return_value = target service.set_default_datasource_provider("t1", make_id(), "new-id") assert target.is_default is True @@ -428,13 +426,13 @@ class TestDatasourceProviderService: # ----------------------------------------------------------------------- def test_should_use_tenant_config_when_available(self, service, mock_db_session): - mock_db_session.query().first.return_value = MagicMock(client_params={"k": "v"}) + mock_db_session.scalar.return_value = MagicMock(client_params={"k": "v"}) with patch.object(service, "get_oauth_encrypter", return_value=(self._enc, None)): result = service.get_oauth_client("t1", make_id()) assert result == {"k": "dec"} def test_should_fallback_to_system_credentials_when_tenant_config_missing(self, service, mock_db_session): - mock_db_session.query().first.side_effect = [None, MagicMock(system_credentials={"k": "sys"})] + mock_db_session.scalar.side_effect = [None, MagicMock(system_credentials={"k": "sys"})] with ( patch.object(service.provider_manager, "fetch_datasource_provider"), patch("services.datasource_provider_service.PluginService.is_plugin_verified", return_value=True), @@ -444,7 +442,7 @@ class TestDatasourceProviderService: def test_should_raise_value_error_when_no_oauth_config_available(self, service, mock_db_session): """Neither tenant nor system credentials → raises ValueError.""" - mock_db_session.query().first.side_effect = [None, None] + mock_db_session.scalar.side_effect = [None, None] with ( patch.object(service.provider_manager, "fetch_datasource_provider"), patch("services.datasource_provider_service.PluginService.is_plugin_verified", return_value=False), @@ -457,15 +455,14 @@ class TestDatasourceProviderService: # ----------------------------------------------------------------------- def test_should_add_oauth_provider_successfully_when_name_is_unique(self, service, mock_db_session): - mock_db_session.query().count.return_value = 0 + mock_db_session.scalar.return_value = 0 with patch.object(service, "extract_secret_variables", return_value=[]): service.add_datasource_oauth_provider("new", "t1", make_id(), "http://cb", 9999, {}) mock_db_session.add.assert_called_once() def test_should_auto_rename_when_oauth_provider_name_conflicts(self, service, mock_db_session): """Conflict on name results in auto-incremented name, not an error.""" - mock_db_session.query().count.return_value = 1 # conflict first, then auto-named - mock_db_session.query().all.return_value = [] + mock_db_session.scalar.return_value = 1 # conflict first, then auto-named with ( patch.object(service, "extract_secret_variables", return_value=[]), patch.object(service, "generate_next_datasource_provider_name", return_value="new_gen"), @@ -475,8 +472,7 @@ class TestDatasourceProviderService: def test_should_auto_generate_name_when_none_provided_for_oauth(self, service, mock_db_session): """name=None causes auto-generation via generate_next_datasource_provider_name.""" - mock_db_session.query().count.return_value = 0 - mock_db_session.query().all.return_value = [] + mock_db_session.scalar.return_value = 0 with ( patch.object(service, "extract_secret_variables", return_value=[]), patch.object(service, "generate_next_datasource_provider_name", return_value="auto"), @@ -485,13 +481,13 @@ class TestDatasourceProviderService: mock_db_session.add.assert_called_once() def test_should_encrypt_secret_fields_when_adding_oauth_provider(self, service, mock_db_session): - mock_db_session.query().count.return_value = 0 + mock_db_session.scalar.return_value = 0 with patch.object(service, "extract_secret_variables", return_value=["secret_key"]): service.add_datasource_oauth_provider("nm", "t1", make_id(), "http://cb", 9999, {"secret_key": "value"}) self._enc.encrypt_token.assert_called() def test_should_acquire_redis_lock_when_adding_oauth_provider(self, service, mock_db_session): - mock_db_session.query().count.return_value = 0 + mock_db_session.scalar.return_value = 0 with patch.object(service, "extract_secret_variables", return_value=[]): service.add_datasource_oauth_provider("nm", "t1", make_id(), "http://cb", 9999, {}) self._redis.lock.assert_called() @@ -501,23 +497,21 @@ class TestDatasourceProviderService: # ----------------------------------------------------------------------- def test_should_raise_value_error_when_credential_id_not_found_on_reauth(self, service, mock_db_session): - mock_db_session.query().first.return_value = None + mock_db_session.scalar.return_value = None with patch.object(service, "extract_secret_variables", return_value=[]): with pytest.raises(ValueError, match="not found"): service.reauthorize_datasource_oauth_provider("n", "t1", make_id(), "u", 1, {}, "bad-id") def test_should_reauthorize_and_commit_when_credential_found(self, service, mock_db_session): p = MagicMock(spec=DatasourceProvider) - mock_db_session.query().first.return_value = p - mock_db_session.query().count.return_value = 0 + mock_db_session.scalar.side_effect = [p, 0] # first: fetch provider, second: name conflict count with patch.object(service, "extract_secret_variables", return_value=[]): service.reauthorize_datasource_oauth_provider("n", "t1", make_id(), "u", 1, {}, "oid") def test_should_auto_rename_when_reauth_name_conflicts(self, service, mock_db_session): p = MagicMock(spec=DatasourceProvider) - mock_db_session.query().first.return_value = p - mock_db_session.query().count.return_value = 1 # conflict - mock_db_session.query().all.return_value = [] + mock_db_session.scalar.side_effect = [p, 1] # first: fetch provider, second: name conflict count + mock_db_session.scalars.return_value.all.return_value = [] with patch.object(service, "extract_secret_variables", return_value=["tok"]): service.reauthorize_datasource_oauth_provider( "conflict_name", "t1", make_id(), "u", 9999, {"tok": "v"}, "cred-id" @@ -525,16 +519,14 @@ class TestDatasourceProviderService: def test_should_encrypt_secret_fields_when_reauthorizing(self, service, mock_db_session): p = MagicMock(spec=DatasourceProvider) - mock_db_session.query().first.return_value = p - mock_db_session.query().count.return_value = 0 + mock_db_session.scalar.side_effect = [p, 0] # first: fetch provider, second: name conflict count with patch.object(service, "extract_secret_variables", return_value=["tok"]): service.reauthorize_datasource_oauth_provider(None, "t1", make_id(), "u", 9999, {"tok": "val"}, "cred-id") self._enc.encrypt_token.assert_called() def test_should_acquire_redis_lock_when_reauthorizing(self, service, mock_db_session): p = MagicMock(spec=DatasourceProvider) - mock_db_session.query().first.return_value = p - mock_db_session.query().count.return_value = 0 + mock_db_session.scalar.side_effect = [p, 0] # first: fetch provider, second: name conflict count with patch.object(service, "extract_secret_variables", return_value=[]): service.reauthorize_datasource_oauth_provider("n", "t1", make_id(), "u", 1, {}, "oid") self._redis.lock.assert_called() @@ -545,13 +537,13 @@ class TestDatasourceProviderService: def test_should_raise_value_error_when_api_key_name_already_exists(self, service, mock_db_session, mock_user): """explicit name supplied + conflict → raises ValueError immediately.""" - mock_db_session.query().count.return_value = 1 + mock_db_session.scalar.return_value = 1 with patch("services.datasource_provider_service.get_current_user", return_value=mock_user): with pytest.raises(ValueError, match="already exists"): service.add_datasource_api_key_provider("clash", "t1", make_id(), {"sk": "v"}) def test_should_raise_value_error_when_credentials_validation_fails(self, service, mock_db_session, mock_user): - mock_db_session.query().count.return_value = 0 + mock_db_session.scalar.return_value = 0 with ( patch("services.datasource_provider_service.get_current_user", return_value=mock_user), patch.object(service.provider_manager, "validate_provider_credentials", side_effect=Exception("bad cred")), @@ -561,7 +553,7 @@ class TestDatasourceProviderService: service.add_datasource_api_key_provider("nm", "t1", make_id(), {"k": "v"}) def test_should_add_api_key_provider_and_commit_when_valid(self, service, mock_db_session, mock_user): - mock_db_session.query().count.return_value = 0 + mock_db_session.scalar.return_value = 0 with ( patch("services.datasource_provider_service.get_current_user", return_value=mock_user), patch.object(service.provider_manager, "validate_provider_credentials"), @@ -571,7 +563,7 @@ class TestDatasourceProviderService: mock_db_session.add.assert_called_once() def test_should_acquire_redis_lock_when_adding_api_key_provider(self, service, mock_db_session, mock_user): - mock_db_session.query().count.return_value = 0 + mock_db_session.scalar.return_value = 0 with ( patch("services.datasource_provider_service.get_current_user", return_value=mock_user), patch.object(service.provider_manager, "validate_provider_credentials"), @@ -694,7 +686,7 @@ class TestDatasourceProviderService: # ----------------------------------------------------------------------- def test_should_raise_value_error_when_credential_not_found_on_update(self, service, mock_db_session, mock_user): - mock_db_session.query().first.return_value = None + mock_db_session.scalar.return_value = None with patch("services.datasource_provider_service.get_current_user", return_value=mock_user): with pytest.raises(ValueError, match="not found"): service.update_datasource_credentials("t1", "id", "prov", "org/plug", {}, "name") @@ -704,8 +696,7 @@ class TestDatasourceProviderService: p.name = "old_name" p.auth_type = "api_key" p.encrypted_credentials = {"sk": "e"} - mock_db_session.query().first.return_value = p - mock_db_session.query().count.return_value = 1 + mock_db_session.scalar.side_effect = [p, 1] # first: fetch provider, second: name conflict count with patch("services.datasource_provider_service.get_current_user", return_value=mock_user): with pytest.raises(ValueError, match="already exists"): service.update_datasource_credentials("t1", "id", "prov", "org/plug", {}, "new_name") @@ -717,8 +708,7 @@ class TestDatasourceProviderService: p.name = "old_name" p.auth_type = "api_key" p.encrypted_credentials = {"sk": "e"} - mock_db_session.query().first.return_value = p - mock_db_session.query().count.return_value = 0 + mock_db_session.scalar.side_effect = [p, 0] # first: fetch provider, second: name conflict count with ( patch("services.datasource_provider_service.get_current_user", return_value=mock_user), patch.object(service, "extract_secret_variables", return_value=["sk"]), @@ -733,8 +723,7 @@ class TestDatasourceProviderService: p.name = "old_name" p.auth_type = "api_key" p.encrypted_credentials = {"sk": "old_enc"} - mock_db_session.query().first.return_value = p - mock_db_session.query().count.return_value = 0 + mock_db_session.scalar.side_effect = [p, 0] # first: fetch provider, second: name conflict count with ( patch("services.datasource_provider_service.get_current_user", return_value=mock_user), patch.object(service, "extract_secret_variables", return_value=["sk"]), diff --git a/api/tests/unit_tests/services/test_external_dataset_service.py b/api/tests/unit_tests/services/test_external_dataset_service.py index 7c8dab5029..b802f6931f 100644 --- a/api/tests/unit_tests/services/test_external_dataset_service.py +++ b/api/tests/unit_tests/services/test_external_dataset_service.py @@ -974,26 +974,29 @@ class TestExternalDatasetServiceAPIUseCheck: """Test API use check when API has one binding.""" # Arrange api_id = "api-123" + tenant_id = "tenant-123" mock_db.session.scalar.return_value = 1 # Act - in_use, count = ExternalDatasetService.external_knowledge_api_use_check(api_id) + in_use, count = ExternalDatasetService.external_knowledge_api_use_check(api_id, tenant_id) # Assert assert in_use is True assert count == 1 + assert "tenant_id" in str(mock_db.session.scalar.call_args.args[0]) @patch("services.external_knowledge_service.db") def test_external_knowledge_api_use_check_in_use_multiple(self, mock_db, factory): """Test API use check with multiple bindings.""" # Arrange api_id = "api-123" + tenant_id = "tenant-123" mock_db.session.scalar.return_value = 10 # Act - in_use, count = ExternalDatasetService.external_knowledge_api_use_check(api_id) + in_use, count = ExternalDatasetService.external_knowledge_api_use_check(api_id, tenant_id) # Assert assert in_use is True @@ -1004,11 +1007,12 @@ class TestExternalDatasetServiceAPIUseCheck: """Test API use check when API is not in use.""" # Arrange api_id = "api-123" + tenant_id = "tenant-123" mock_db.session.scalar.return_value = 0 # Act - in_use, count = ExternalDatasetService.external_knowledge_api_use_check(api_id) + in_use, count = ExternalDatasetService.external_knowledge_api_use_check(api_id, tenant_id) # Assert assert in_use is False @@ -1556,6 +1560,17 @@ class TestExternalDatasetServiceFetchRetrieval: with pytest.raises(ValueError, match="external knowledge binding not found"): ExternalDatasetService.fetch_external_knowledge_retrieval("tenant-123", "dataset-123", "query", {}) + @patch("services.external_knowledge_service.db") + def test_fetch_external_knowledge_retrieval_cross_tenant_api_template_error(self, mock_db, factory): + """Test error when a binding points to an API template outside the dataset tenant.""" + # Arrange + binding = factory.create_external_knowledge_binding_mock() + mock_db.session.scalar.side_effect = [binding, None] + + # Act & Assert + with pytest.raises(ValueError, match="external api template not found"): + ExternalDatasetService.fetch_external_knowledge_retrieval("tenant-123", "dataset-123", "query", {}) + @patch("services.external_knowledge_service.ExternalDatasetService.process_external_api") @patch("services.external_knowledge_service.db") def test_fetch_external_knowledge_retrieval_empty_results(self, mock_db, mock_process, factory): diff --git a/api/tests/unit_tests/services/test_file_service.py b/api/tests/unit_tests/services/test_file_service.py index b7259c3e82..8e1b22886b 100644 --- a/api/tests/unit_tests/services/test_file_service.py +++ b/api/tests/unit_tests/services/test_file_service.py @@ -165,7 +165,7 @@ class TestFileService: upload_file = MagicMock(spec=UploadFile) upload_file.id = "file_id" upload_file.key = "test_key" - mock_db_session.query().where().first.return_value = upload_file + mock_db_session.scalar.return_value = upload_file with patch("services.file_service.storage") as mock_storage: mock_storage.load_once.return_value = b"test content" @@ -178,7 +178,7 @@ class TestFileService: mock_storage.load_once.assert_called_once_with("test_key") def test_get_file_base64_not_found(self, file_service, mock_db_session): - mock_db_session.query().where().first.return_value = None + mock_db_session.scalar.return_value = None with pytest.raises(NotFound, match="File not found"): file_service.get_file_base64("non_existent") @@ -215,7 +215,7 @@ class TestFileService: upload_file = MagicMock(spec=UploadFile) upload_file.id = "file_id" upload_file.extension = "pdf" - mock_db_session.query().where().first.return_value = upload_file + mock_db_session.scalar.return_value = upload_file with patch("services.file_service.ExtractProcessor.load_from_upload_file") as mock_extract: mock_extract.return_value = "Extracted text content" @@ -227,7 +227,7 @@ class TestFileService: assert result == "Extracted text content" def test_get_file_preview_not_found(self, file_service, mock_db_session): - mock_db_session.query().where().first.return_value = None + mock_db_session.scalar.return_value = None with pytest.raises(NotFound, match="File not found"): file_service.get_file_preview("non_existent") @@ -235,7 +235,7 @@ class TestFileService: upload_file = MagicMock(spec=UploadFile) upload_file.id = "file_id" upload_file.extension = "exe" - mock_db_session.query().where().first.return_value = upload_file + mock_db_session.scalar.return_value = upload_file with pytest.raises(UnsupportedFileTypeError): file_service.get_file_preview("file_id") @@ -246,7 +246,7 @@ class TestFileService: upload_file.extension = "jpg" upload_file.mime_type = "image/jpeg" upload_file.key = "key" - mock_db_session.query().where().first.return_value = upload_file + mock_db_session.scalar.return_value = upload_file with ( patch("services.file_service.file_helpers.verify_image_signature") as mock_verify, @@ -269,7 +269,7 @@ class TestFileService: file_service.get_image_preview("file_id", "ts", "nonce", "sign") def test_get_image_preview_not_found(self, file_service, mock_db_session): - mock_db_session.query().where().first.return_value = None + mock_db_session.scalar.return_value = None with patch("services.file_service.file_helpers.verify_image_signature") as mock_verify: mock_verify.return_value = True with pytest.raises(NotFound, match="File not found or signature is invalid"): @@ -279,7 +279,7 @@ class TestFileService: upload_file = MagicMock(spec=UploadFile) upload_file.id = "file_id" upload_file.extension = "txt" - mock_db_session.query().where().first.return_value = upload_file + mock_db_session.scalar.return_value = upload_file with patch("services.file_service.file_helpers.verify_image_signature") as mock_verify: mock_verify.return_value = True with pytest.raises(UnsupportedFileTypeError): @@ -289,7 +289,7 @@ class TestFileService: upload_file = MagicMock(spec=UploadFile) upload_file.id = "file_id" upload_file.key = "key" - mock_db_session.query().where().first.return_value = upload_file + mock_db_session.scalar.return_value = upload_file with ( patch("services.file_service.file_helpers.verify_file_signature") as mock_verify, @@ -309,7 +309,7 @@ class TestFileService: file_service.get_file_generator_by_file_id("file_id", "ts", "nonce", "sign") def test_get_file_generator_by_file_id_not_found(self, file_service, mock_db_session): - mock_db_session.query().where().first.return_value = None + mock_db_session.scalar.return_value = None with patch("services.file_service.file_helpers.verify_file_signature") as mock_verify: mock_verify.return_value = True with pytest.raises(NotFound, match="File not found or signature is invalid"): @@ -321,7 +321,7 @@ class TestFileService: upload_file.extension = "png" upload_file.mime_type = "image/png" upload_file.key = "key" - mock_db_session.query().where().first.return_value = upload_file + mock_db_session.scalar.return_value = upload_file with patch("services.file_service.storage") as mock_storage: mock_storage.load.return_value = b"image content" @@ -330,7 +330,7 @@ class TestFileService: assert mime == "image/png" def test_get_public_image_preview_not_found(self, file_service, mock_db_session): - mock_db_session.query().where().first.return_value = None + mock_db_session.scalar.return_value = None with pytest.raises(NotFound, match="File not found or signature is invalid"): file_service.get_public_image_preview("file_id") @@ -338,7 +338,7 @@ class TestFileService: upload_file = MagicMock(spec=UploadFile) upload_file.id = "file_id" upload_file.extension = "txt" - mock_db_session.query().where().first.return_value = upload_file + mock_db_session.scalar.return_value = upload_file with pytest.raises(UnsupportedFileTypeError): file_service.get_public_image_preview("file_id") @@ -346,7 +346,7 @@ class TestFileService: upload_file = MagicMock(spec=UploadFile) upload_file.id = "file_id" upload_file.key = "key" - mock_db_session.query().where().first.return_value = upload_file + mock_db_session.scalar.return_value = upload_file with patch("services.file_service.storage") as mock_storage: mock_storage.load.return_value = b"hello world" @@ -354,7 +354,7 @@ class TestFileService: assert result == "hello world" def test_get_file_content_not_found(self, file_service, mock_db_session): - mock_db_session.query().where().first.return_value = None + mock_db_session.scalar.return_value = None with pytest.raises(NotFound, match="File not found"): file_service.get_file_content("file_id") diff --git a/api/tests/unit_tests/services/test_ops_service.py b/api/tests/unit_tests/services/test_ops_service.py deleted file mode 100644 index 7067e3b3dd..0000000000 --- a/api/tests/unit_tests/services/test_ops_service.py +++ /dev/null @@ -1,392 +0,0 @@ -from unittest.mock import MagicMock, patch - -import pytest - -from core.ops.entities.config_entity import TracingProviderEnum -from models.model import App, TraceAppConfig -from services.ops_service import OpsService - - -class TestOpsService: - @patch("services.ops_service.db") - @patch("services.ops_service.OpsTraceManager") - def test_get_tracing_app_config_no_config(self, mock_ops_trace_manager, mock_db): - # Arrange - mock_db.session.scalar.return_value = None - - # Act - result = OpsService.get_tracing_app_config("app_id", "arize") - - # Assert - assert result is None - - @patch("services.ops_service.db") - @patch("services.ops_service.OpsTraceManager") - def test_get_tracing_app_config_no_app(self, mock_ops_trace_manager, mock_db): - # Arrange - trace_config = MagicMock(spec=TraceAppConfig) - mock_db.session.scalar.return_value = trace_config - mock_db.session.get.return_value = None - - # Act - result = OpsService.get_tracing_app_config("app_id", "arize") - - # Assert - assert result is None - - @patch("services.ops_service.db") - @patch("services.ops_service.OpsTraceManager") - def test_get_tracing_app_config_none_config(self, mock_ops_trace_manager, mock_db): - # Arrange - trace_config = MagicMock(spec=TraceAppConfig) - trace_config.tracing_config = None - app = MagicMock(spec=App) - app.tenant_id = "tenant_id" - mock_db.session.scalar.return_value = trace_config - mock_db.session.get.return_value = app - - # Act & Assert - with pytest.raises(ValueError, match="Tracing config cannot be None."): - OpsService.get_tracing_app_config("app_id", "arize") - - @patch("services.ops_service.db") - @patch("services.ops_service.OpsTraceManager") - @pytest.mark.parametrize( - ("provider", "default_url"), - [ - ("arize", "https://app.arize.com/"), - ("phoenix", "https://app.phoenix.arize.com/projects/"), - ("langsmith", "https://smith.langchain.com/"), - ("opik", "https://www.comet.com/opik/"), - ("weave", "https://wandb.ai/"), - ("aliyun", "https://arms.console.aliyun.com/"), - ("tencent", "https://console.cloud.tencent.com/apm"), - ("mlflow", "http://localhost:5000/"), - ("databricks", "https://www.databricks.com/"), - ], - ) - def test_get_tracing_app_config_providers_exception(self, mock_ops_trace_manager, mock_db, provider, default_url): - # Arrange - trace_config = MagicMock(spec=TraceAppConfig) - trace_config.tracing_config = {"some": "config"} - trace_config.to_dict.return_value = {"tracing_config": {"project_url": default_url}} - app = MagicMock(spec=App) - app.tenant_id = "tenant_id" - mock_db.session.scalar.return_value = trace_config - mock_db.session.get.return_value = app - - mock_ops_trace_manager.decrypt_tracing_config.return_value = {} - mock_ops_trace_manager.obfuscated_decrypt_token.return_value = {} - mock_ops_trace_manager.get_trace_config_project_url.side_effect = Exception("error") - mock_ops_trace_manager.get_trace_config_project_key.side_effect = Exception("error") - - # Act - result = OpsService.get_tracing_app_config("app_id", provider) - - # Assert - assert result["tracing_config"]["project_url"] == default_url - - @patch("services.ops_service.db") - @patch("services.ops_service.OpsTraceManager") - @pytest.mark.parametrize( - "provider", ["arize", "phoenix", "langsmith", "opik", "weave", "aliyun", "tencent", "mlflow", "databricks"] - ) - def test_get_tracing_app_config_providers_success(self, mock_ops_trace_manager, mock_db, provider): - # Arrange - trace_config = MagicMock(spec=TraceAppConfig) - trace_config.tracing_config = {"some": "config"} - trace_config.to_dict.return_value = {"tracing_config": {"project_url": "success_url"}} - app = MagicMock(spec=App) - app.tenant_id = "tenant_id" - mock_db.session.scalar.return_value = trace_config - mock_db.session.get.return_value = app - - mock_ops_trace_manager.decrypt_tracing_config.return_value = {} - mock_ops_trace_manager.obfuscated_decrypt_token.return_value = {} - mock_ops_trace_manager.get_trace_config_project_url.return_value = "success_url" - - # Act - result = OpsService.get_tracing_app_config("app_id", provider) - - # Assert - assert result["tracing_config"]["project_url"] == "success_url" - - @patch("services.ops_service.db") - @patch("services.ops_service.OpsTraceManager") - def test_get_tracing_app_config_langfuse_success(self, mock_ops_trace_manager, mock_db): - # Arrange - trace_config = MagicMock(spec=TraceAppConfig) - trace_config.tracing_config = {"some": "config"} - trace_config.to_dict.return_value = {"tracing_config": {"project_url": "https://api.langfuse.com/project/key"}} - app = MagicMock(spec=App) - app.tenant_id = "tenant_id" - mock_db.session.scalar.return_value = trace_config - mock_db.session.get.return_value = app - - mock_ops_trace_manager.decrypt_tracing_config.return_value = {"host": "https://api.langfuse.com"} - mock_ops_trace_manager.obfuscated_decrypt_token.return_value = {"host": "https://api.langfuse.com"} - mock_ops_trace_manager.get_trace_config_project_key.return_value = "key" - - # Act - result = OpsService.get_tracing_app_config("app_id", "langfuse") - - # Assert - assert result["tracing_config"]["project_url"] == "https://api.langfuse.com/project/key" - - @patch("services.ops_service.db") - @patch("services.ops_service.OpsTraceManager") - def test_get_tracing_app_config_langfuse_exception(self, mock_ops_trace_manager, mock_db): - # Arrange - trace_config = MagicMock(spec=TraceAppConfig) - trace_config.tracing_config = {"some": "config"} - trace_config.to_dict.return_value = {"tracing_config": {"project_url": "https://api.langfuse.com/"}} - app = MagicMock(spec=App) - app.tenant_id = "tenant_id" - mock_db.session.scalar.return_value = trace_config - mock_db.session.get.return_value = app - - mock_ops_trace_manager.decrypt_tracing_config.return_value = {"host": "https://api.langfuse.com"} - mock_ops_trace_manager.obfuscated_decrypt_token.return_value = {"host": "https://api.langfuse.com"} - mock_ops_trace_manager.get_trace_config_project_key.side_effect = Exception("error") - - # Act - result = OpsService.get_tracing_app_config("app_id", "langfuse") - - # Assert - assert result["tracing_config"]["project_url"] == "https://api.langfuse.com/" - - @patch("services.ops_service.db") - @patch("services.ops_service.OpsTraceManager") - def test_create_tracing_app_config_invalid_provider(self, mock_ops_trace_manager, mock_db): - # Act - result = OpsService.create_tracing_app_config("app_id", "invalid_provider", {}) - - # Assert - assert result == {"error": "Invalid tracing provider: invalid_provider"} - - @patch("services.ops_service.db") - @patch("services.ops_service.OpsTraceManager") - def test_create_tracing_app_config_invalid_credentials(self, mock_ops_trace_manager, mock_db): - # Arrange - provider = TracingProviderEnum.LANGFUSE - mock_ops_trace_manager.check_trace_config_is_effective.return_value = False - - # Act - result = OpsService.create_tracing_app_config("app_id", provider, {"public_key": "p", "secret_key": "s"}) - - # Assert - assert result == {"error": "Invalid Credentials"} - - @patch("services.ops_service.db") - @patch("services.ops_service.OpsTraceManager") - @pytest.mark.parametrize( - ("provider", "config"), - [ - (TracingProviderEnum.ARIZE, {}), - (TracingProviderEnum.LANGFUSE, {"public_key": "p", "secret_key": "s"}), - (TracingProviderEnum.LANGSMITH, {"api_key": "k", "project": "p"}), - (TracingProviderEnum.ALIYUN, {"license_key": "k", "endpoint": "https://aliyun.com"}), - ], - ) - def test_create_tracing_app_config_project_url_exception(self, mock_ops_trace_manager, mock_db, provider, config): - # Arrange - mock_ops_trace_manager.check_trace_config_is_effective.return_value = True - mock_ops_trace_manager.get_trace_config_project_url.side_effect = Exception("error") - mock_ops_trace_manager.get_trace_config_project_key.side_effect = Exception("error") - mock_db.session.scalar.return_value = MagicMock(spec=TraceAppConfig) - - # Act - result = OpsService.create_tracing_app_config("app_id", provider, config) - - # Assert - assert result is None - - @patch("services.ops_service.db") - @patch("services.ops_service.OpsTraceManager") - def test_create_tracing_app_config_langfuse_success(self, mock_ops_trace_manager, mock_db): - # Arrange - provider = TracingProviderEnum.LANGFUSE - mock_ops_trace_manager.check_trace_config_is_effective.return_value = True - mock_ops_trace_manager.get_trace_config_project_key.return_value = "key" - app = MagicMock(spec=App) - app.tenant_id = "tenant_id" - mock_db.session.scalar.return_value = None - mock_db.session.get.return_value = app - mock_ops_trace_manager.encrypt_tracing_config.return_value = {} - - # Act - result = OpsService.create_tracing_app_config( - "app_id", provider, {"public_key": "p", "secret_key": "s", "host": "https://api.langfuse.com"} - ) - - # Assert - assert result == {"result": "success"} - - @patch("services.ops_service.db") - @patch("services.ops_service.OpsTraceManager") - def test_create_tracing_app_config_already_exists(self, mock_ops_trace_manager, mock_db): - # Arrange - provider = TracingProviderEnum.ARIZE - mock_ops_trace_manager.check_trace_config_is_effective.return_value = True - mock_db.session.scalar.return_value = MagicMock(spec=TraceAppConfig) - - # Act - result = OpsService.create_tracing_app_config("app_id", provider, {}) - - # Assert - assert result is None - - @patch("services.ops_service.db") - @patch("services.ops_service.OpsTraceManager") - def test_create_tracing_app_config_no_app(self, mock_ops_trace_manager, mock_db): - # Arrange - provider = TracingProviderEnum.ARIZE - mock_ops_trace_manager.check_trace_config_is_effective.return_value = True - mock_db.session.scalar.return_value = None - mock_db.session.get.return_value = None - - # Act - result = OpsService.create_tracing_app_config("app_id", provider, {}) - - # Assert - assert result is None - - @patch("services.ops_service.db") - @patch("services.ops_service.OpsTraceManager") - def test_create_tracing_app_config_with_empty_other_keys(self, mock_ops_trace_manager, mock_db): - # Arrange - provider = TracingProviderEnum.ARIZE - mock_ops_trace_manager.check_trace_config_is_effective.return_value = True - app = MagicMock(spec=App) - app.tenant_id = "tenant_id" - mock_db.session.scalar.return_value = None - mock_db.session.get.return_value = app - mock_ops_trace_manager.encrypt_tracing_config.return_value = {} - - # Act - # 'project' is in other_keys for Arize - # provide an empty string for the project in the tracing_config - # create_tracing_app_config will replace it with the default from the model - result = OpsService.create_tracing_app_config("app_id", provider, {"project": ""}) - - # Assert - assert result == {"result": "success"} - - @patch("services.ops_service.db") - @patch("services.ops_service.OpsTraceManager") - def test_create_tracing_app_config_success(self, mock_ops_trace_manager, mock_db): - # Arrange - provider = TracingProviderEnum.ARIZE - mock_ops_trace_manager.check_trace_config_is_effective.return_value = True - mock_ops_trace_manager.get_trace_config_project_url.return_value = "http://project_url" - app = MagicMock(spec=App) - app.tenant_id = "tenant_id" - mock_db.session.scalar.return_value = None - mock_db.session.get.return_value = app - mock_ops_trace_manager.encrypt_tracing_config.return_value = {"encrypted": "config"} - - # Act - result = OpsService.create_tracing_app_config("app_id", provider, {}) - - # Assert - assert result == {"result": "success"} - mock_db.session.add.assert_called() - mock_db.session.commit.assert_called() - - @patch("services.ops_service.db") - @patch("services.ops_service.OpsTraceManager") - def test_update_tracing_app_config_invalid_provider(self, mock_ops_trace_manager, mock_db): - # Act & Assert - with pytest.raises(ValueError, match="Invalid tracing provider: invalid_provider"): - OpsService.update_tracing_app_config("app_id", "invalid_provider", {}) - - @patch("services.ops_service.db") - @patch("services.ops_service.OpsTraceManager") - def test_update_tracing_app_config_no_config(self, mock_ops_trace_manager, mock_db): - # Arrange - provider = TracingProviderEnum.ARIZE - mock_db.session.scalar.return_value = None - - # Act - result = OpsService.update_tracing_app_config("app_id", provider, {}) - - # Assert - assert result is None - - @patch("services.ops_service.db") - @patch("services.ops_service.OpsTraceManager") - def test_update_tracing_app_config_no_app(self, mock_ops_trace_manager, mock_db): - # Arrange - provider = TracingProviderEnum.ARIZE - current_config = MagicMock(spec=TraceAppConfig) - mock_db.session.scalar.return_value = current_config - mock_db.session.get.return_value = None - - # Act - result = OpsService.update_tracing_app_config("app_id", provider, {}) - - # Assert - assert result is None - - @patch("services.ops_service.db") - @patch("services.ops_service.OpsTraceManager") - def test_update_tracing_app_config_invalid_credentials(self, mock_ops_trace_manager, mock_db): - # Arrange - provider = TracingProviderEnum.ARIZE - current_config = MagicMock(spec=TraceAppConfig) - app = MagicMock(spec=App) - app.tenant_id = "tenant_id" - mock_db.session.scalar.return_value = current_config - mock_db.session.get.return_value = app - mock_ops_trace_manager.decrypt_tracing_config.return_value = {} - mock_ops_trace_manager.check_trace_config_is_effective.return_value = False - - # Act & Assert - with pytest.raises(ValueError, match="Invalid Credentials"): - OpsService.update_tracing_app_config("app_id", provider, {}) - - @patch("services.ops_service.db") - @patch("services.ops_service.OpsTraceManager") - def test_update_tracing_app_config_success(self, mock_ops_trace_manager, mock_db): - # Arrange - provider = TracingProviderEnum.ARIZE - current_config = MagicMock(spec=TraceAppConfig) - current_config.to_dict.return_value = {"some": "data"} - app = MagicMock(spec=App) - app.tenant_id = "tenant_id" - mock_db.session.scalar.return_value = current_config - mock_db.session.get.return_value = app - mock_ops_trace_manager.decrypt_tracing_config.return_value = {} - mock_ops_trace_manager.check_trace_config_is_effective.return_value = True - - # Act - result = OpsService.update_tracing_app_config("app_id", provider, {}) - - # Assert - assert result == {"some": "data"} - mock_db.session.commit.assert_called_once() - - @patch("services.ops_service.db") - def test_delete_tracing_app_config_no_config(self, mock_db): - # Arrange - mock_db.session.scalar.return_value = None - - # Act - result = OpsService.delete_tracing_app_config("app_id", "arize") - - # Assert - assert result is None - - @patch("services.ops_service.db") - def test_delete_tracing_app_config_success(self, mock_db): - # Arrange - trace_config = MagicMock(spec=TraceAppConfig) - mock_db.session.scalar.return_value = trace_config - - # Act - result = OpsService.delete_tracing_app_config("app_id", "arize") - - # Assert - assert result is True - mock_db.session.delete.assert_called_with(trace_config) - mock_db.session.commit.assert_called_once() diff --git a/api/tests/unit_tests/services/test_summary_index_service.py b/api/tests/unit_tests/services/test_summary_index_service.py index cbf3e121d8..e17d4134ac 100644 --- a/api/tests/unit_tests/services/test_summary_index_service.py +++ b/api/tests/unit_tests/services/test_summary_index_service.py @@ -124,10 +124,7 @@ def test_create_summary_record_updates_existing_and_reenables(monkeypatch: pytes existing.disabled_by = "u" session = MagicMock(name="session") - query = MagicMock() - query.filter_by.return_value = query - query.first.return_value = existing - session.query.return_value = query + session.scalar.return_value = existing create_session_mock = MagicMock(return_value=_SessionContext(session)) monkeypatch.setattr(summary_module, "session_factory", SimpleNamespace(create_session=create_session_mock)) @@ -149,10 +146,7 @@ def test_create_summary_record_updates_existing_and_reenables(monkeypatch: pytes def test_create_summary_record_creates_new(monkeypatch: pytest.MonkeyPatch) -> None: session = MagicMock(name="session") - query = MagicMock() - query.filter_by.return_value = query - query.first.return_value = None - session.query.return_value = query + session.scalar.return_value = None create_session_mock = MagicMock(return_value=_SessionContext(session)) monkeypatch.setattr(summary_module, "session_factory", SimpleNamespace(create_session=create_session_mock)) @@ -234,10 +228,7 @@ def test_vectorize_summary_without_session_creates_record_when_missing(monkeypat # New session used after vectorization succeeds (record not found by id nor chunk_id). session = MagicMock(name="session") - q1 = MagicMock() - q1.filter_by.return_value = q1 - q1.first.side_effect = [None, None] - session.query.return_value = q1 + session.scalar.side_effect = [None, None] create_session_mock = MagicMock(return_value=_SessionContext(session)) monkeypatch.setattr(summary_module, "session_factory", SimpleNamespace(create_session=create_session_mock)) @@ -267,10 +258,7 @@ def test_vectorize_summary_final_failure_updates_error_status(monkeypatch: pytes # error_session should find record and commit status update error_session = MagicMock(name="error_session") - q = MagicMock() - q.filter_by.return_value = q - q.first.return_value = summary - error_session.query.return_value = q + error_session.scalar.return_value = summary create_session_mock = MagicMock(return_value=_SessionContext(error_session)) monkeypatch.setattr(summary_module, "session_factory", SimpleNamespace(create_session=create_session_mock)) @@ -302,10 +290,7 @@ def test_batch_create_summary_records_creates_and_updates(monkeypatch: pytest.Mo existing.enabled = False session = MagicMock() - query = MagicMock() - query.filter.return_value = query - query.all.return_value = [existing] - session.query.return_value = query + session.scalars.return_value.all.return_value = [existing] monkeypatch.setattr( summary_module, @@ -324,10 +309,7 @@ def test_update_summary_record_error_updates_when_exists(monkeypatch: pytest.Mon record = _summary_record() session = MagicMock() - query = MagicMock() - query.filter_by.return_value = query - query.first.return_value = record - session.query.return_value = query + session.scalar.return_value = record monkeypatch.setattr( summary_module, "session_factory", @@ -346,10 +328,7 @@ def test_generate_and_vectorize_summary_success(monkeypatch: pytest.MonkeyPatch) record = _summary_record(summary_content="") session = MagicMock() - query = MagicMock() - query.filter_by.return_value = query - query.first.return_value = record - session.query.return_value = query + session.scalar.return_value = record monkeypatch.setattr( summary_module, @@ -373,10 +352,7 @@ def test_generate_and_vectorize_summary_vectorize_failure_sets_error(monkeypatch record = _summary_record(summary_content="") session = MagicMock() - query = MagicMock() - query.filter_by.return_value = query - query.first.return_value = record - session.query.return_value = query + session.scalar.return_value = record monkeypatch.setattr( summary_module, @@ -415,10 +391,7 @@ def test_vectorize_summary_updates_existing_record_found_by_chunk_id(monkeypatch existing = _summary_record(summary_content="old", node_id="old-node") existing.id = "other-id" session = MagicMock(name="session") - q = MagicMock() - q.filter_by.return_value = q - q.first.side_effect = [None, existing] # miss by id, hit by chunk_id - session.query.return_value = q + session.scalar.side_effect = [None, existing] # miss by id, hit by chunk_id monkeypatch.setattr( summary_module, "session_factory", @@ -448,10 +421,7 @@ def test_vectorize_summary_updates_existing_record_found_by_id(monkeypatch: pyte existing = _summary_record(summary_content="old", node_id="old-node") session = MagicMock(name="session") - q = MagicMock() - q.filter_by.return_value = q - q.first.return_value = existing # hit by id - session.query.return_value = q + session.scalar.return_value = existing # hit by id monkeypatch.setattr( summary_module, "session_factory", @@ -487,10 +457,7 @@ def test_vectorize_summary_session_enter_returns_none_triggers_runtime_error(mon return None error_session = MagicMock() - q = MagicMock() - q.filter_by.return_value = q - q.first.return_value = summary - error_session.query.return_value = q + error_session.scalar.return_value = summary create_session_mock = MagicMock(side_effect=[_BadContext(), _SessionContext(error_session)]) monkeypatch.setattr(summary_module, "session_factory", SimpleNamespace(create_session=create_session_mock)) @@ -516,21 +483,17 @@ def test_vectorize_summary_created_record_becomes_none_triggers_guard(monkeypatc ) session = MagicMock() - q = MagicMock() - q.filter_by.return_value = q - q.first.side_effect = [None, None] # miss by id and chunk_id - session.query.return_value = q + session.scalar.side_effect = [None, None] # miss by id and chunk_id error_session = MagicMock() - eq = MagicMock() - eq.filter_by.return_value = eq - eq.first.return_value = summary - error_session.query.return_value = eq + error_session.scalar.return_value = summary create_session_mock = MagicMock(side_effect=[_SessionContext(session), _SessionContext(error_session)]) monkeypatch.setattr(summary_module, "session_factory", SimpleNamespace(create_session=create_session_mock)) # Force the created record to be None so the "should not be None" guard triggers. + # Also mock select() so SQLAlchemy doesn't validate the mocked DocumentSegmentSummary as a real column clause. + monkeypatch.setattr(summary_module, "select", MagicMock(return_value=MagicMock())) monkeypatch.setattr(summary_module, "DocumentSegmentSummary", MagicMock(return_value=None)) with pytest.raises(RuntimeError, match="summary_record_in_session should not be None"): @@ -554,10 +517,7 @@ def test_vectorize_summary_error_handler_tries_chunk_id_lookup_and_can_warn_not_ ) error_session = MagicMock(name="error_session") - q = MagicMock() - q.filter_by.return_value = q - q.first.side_effect = [None, None] # not found by id, not found by chunk_id - error_session.query.return_value = q + error_session.scalar.side_effect = [None, None] # not found by id, not found by chunk_id monkeypatch.setattr( summary_module, @@ -577,10 +537,7 @@ def test_update_summary_record_error_warns_when_missing(monkeypatch: pytest.Monk segment = _segment() session = MagicMock() - query = MagicMock() - query.filter_by.return_value = query - query.first.return_value = None - session.query.return_value = query + session.scalar.return_value = None monkeypatch.setattr( summary_module, "session_factory", @@ -599,10 +556,7 @@ def test_generate_and_vectorize_summary_creates_missing_record_and_logs_usage(mo segment = _segment() session = MagicMock() - query = MagicMock() - query.filter_by.return_value = query - query.first.return_value = None - session.query.return_value = query + session.scalar.return_value = None monkeypatch.setattr( summary_module, "session_factory", @@ -646,11 +600,7 @@ def test_generate_summaries_for_document_runs_and_handles_errors(monkeypatch: py seg2.id = "seg-2" session = MagicMock() - query = MagicMock() - query.filter_by.return_value = query - query.filter.return_value = query - query.all.return_value = [seg1, seg2] - session.query.return_value = query + session.scalars.return_value.all.return_value = [seg1, seg2] monkeypatch.setattr( summary_module, @@ -678,11 +628,7 @@ def test_generate_summaries_for_document_no_segments_returns_empty(monkeypatch: document.doc_form = IndexStructureType.PARAGRAPH_INDEX session = MagicMock() - query = MagicMock() - query.filter_by.return_value = query - query.filter.return_value = query - query.all.return_value = [] - session.query.return_value = query + session.scalars.return_value.all.return_value = [] monkeypatch.setattr( summary_module, "session_factory", @@ -702,11 +648,7 @@ def test_generate_summaries_for_document_applies_segment_ids_and_only_parent_chu seg = _segment() session = MagicMock() - query = MagicMock() - query.filter_by.return_value = query - query.filter.return_value = query - query.all.return_value = [seg] - session.query.return_value = query + session.scalars.return_value.all.return_value = [seg] monkeypatch.setattr( summary_module, "session_factory", @@ -723,7 +665,7 @@ def test_generate_summaries_for_document_applies_segment_ids_and_only_parent_chu segment_ids=[seg.id], only_parent_chunks=True, ) - query.filter.assert_called() + session.scalars.assert_called() def test_disable_summaries_for_segments_handles_vector_delete_error(monkeypatch: pytest.MonkeyPatch) -> None: @@ -732,11 +674,7 @@ def test_disable_summaries_for_segments_handles_vector_delete_error(monkeypatch: summary2 = _summary_record(summary_content="s", node_id=None) session = MagicMock() - query = MagicMock() - query.filter_by.return_value = query - query.filter.return_value = query - query.all.return_value = [summary1, summary2] - session.query.return_value = query + session.scalars.return_value.all.return_value = [summary1, summary2] monkeypatch.setattr( summary_module, @@ -761,11 +699,7 @@ def test_disable_summaries_for_segments_handles_vector_delete_error(monkeypatch: def test_disable_summaries_for_segments_no_summaries_noop(monkeypatch: pytest.MonkeyPatch) -> None: dataset = _dataset() session = MagicMock() - query = MagicMock() - query.filter_by.return_value = query - query.filter.return_value = query - query.all.return_value = [] - session.query.return_value = query + session.scalars.return_value.all.return_value = [] monkeypatch.setattr( summary_module, "session_factory", @@ -793,21 +727,8 @@ def test_enable_summaries_for_segments_revectorizes_and_enables(monkeypatch: pyt segment.status = SegmentStatus.COMPLETED session = MagicMock() - summary_query = MagicMock() - summary_query.filter_by.return_value = summary_query - summary_query.filter.return_value = summary_query - summary_query.all.return_value = [summary] - - seg_query = MagicMock() - seg_query.filter_by.return_value = seg_query - seg_query.first.return_value = segment - - def query_side_effect(model: object) -> MagicMock: - if model is summary_module.DocumentSegmentSummary: - return summary_query - return seg_query - - session.query.side_effect = query_side_effect + session.scalars.return_value.all.return_value = [summary] + session.scalar.return_value = segment monkeypatch.setattr( summary_module, @@ -826,11 +747,7 @@ def test_enable_summaries_for_segments_revectorizes_and_enables(monkeypatch: pyt def test_enable_summaries_for_segments_no_summaries_noop(monkeypatch: pytest.MonkeyPatch) -> None: dataset = _dataset() session = MagicMock() - query = MagicMock() - query.filter_by.return_value = query - query.filter.return_value = query - query.all.return_value = [] - session.query.return_value = query + session.scalars.return_value.all.return_value = [] monkeypatch.setattr( summary_module, "session_factory", @@ -860,21 +777,9 @@ def test_enable_summaries_for_segments_skips_segment_or_content_and_handles_vect good_segment.status = SegmentStatus.COMPLETED session = MagicMock() - summary_query = MagicMock() - summary_query.filter_by.return_value = summary_query - summary_query.filter.return_value = summary_query - summary_query.all.return_value = [summary1, summary2, summary3] + session.scalars.return_value.all.return_value = [summary1, summary2, summary3] + session.scalar.side_effect = [bad_segment, good_segment, good_segment] - seg_query = MagicMock() - seg_query.filter_by.return_value = seg_query - seg_query.first.side_effect = [bad_segment, good_segment, good_segment] - - def query_side_effect(model: object) -> MagicMock: - if model is summary_module.DocumentSegmentSummary: - return summary_query - return seg_query - - session.query.side_effect = query_side_effect monkeypatch.setattr( summary_module, "session_factory", @@ -895,11 +800,7 @@ def test_delete_summaries_for_segments_deletes_vectors_and_records(monkeypatch: summary = _summary_record(summary_content="sum", node_id="n1") session = MagicMock() - query = MagicMock() - query.filter_by.return_value = query - query.filter.return_value = query - query.all.return_value = [summary] - session.query.return_value = query + session.scalars.return_value.all.return_value = [summary] vector_instance = MagicMock() monkeypatch.setattr(summary_module, "Vector", MagicMock(return_value=vector_instance)) @@ -918,11 +819,7 @@ def test_delete_summaries_for_segments_deletes_vectors_and_records(monkeypatch: def test_delete_summaries_for_segments_no_summaries_noop(monkeypatch: pytest.MonkeyPatch) -> None: dataset = _dataset() session = MagicMock() - query = MagicMock() - query.filter_by.return_value = query - query.filter.return_value = query - query.all.return_value = [] - session.query.return_value = query + session.scalars.return_value.all.return_value = [] monkeypatch.setattr( summary_module, "session_factory", @@ -946,10 +843,7 @@ def test_update_summary_for_segment_empty_content_deletes_existing(monkeypatch: record = _summary_record(summary_content="old", node_id="n1") session = MagicMock() - query = MagicMock() - query.filter_by.return_value = query - query.first.return_value = record - session.query.return_value = query + session.scalar.return_value = record vector_instance = MagicMock() monkeypatch.setattr(summary_module, "Vector", MagicMock(return_value=vector_instance)) @@ -971,10 +865,7 @@ def test_update_summary_for_segment_empty_content_delete_vector_warns(monkeypatc record = _summary_record(summary_content="old", node_id="n1") session = MagicMock() - query = MagicMock() - query.filter_by.return_value = query - query.first.return_value = record - session.query.return_value = query + session.scalar.return_value = record monkeypatch.setattr( summary_module, "session_factory", @@ -996,10 +887,7 @@ def test_update_summary_for_segment_empty_content_no_record_noop(monkeypatch: py segment = _segment() session = MagicMock() - query = MagicMock() - query.filter_by.return_value = query - query.first.return_value = None - session.query.return_value = query + session.scalar.return_value = None monkeypatch.setattr( summary_module, "session_factory", @@ -1015,10 +903,7 @@ def test_update_summary_for_segment_updates_existing_and_vectorizes(monkeypatch: record = _summary_record(summary_content="old", node_id="n1") session = MagicMock() - query = MagicMock() - query.filter_by.return_value = query - query.first.return_value = record - session.query.return_value = query + session.scalar.return_value = record vector_instance = MagicMock() monkeypatch.setattr(summary_module, "Vector", MagicMock(return_value=vector_instance)) @@ -1044,10 +929,7 @@ def test_update_summary_for_segment_existing_vector_delete_warns(monkeypatch: py record = _summary_record(summary_content="old", node_id="n1") session = MagicMock() - query = MagicMock() - query.filter_by.return_value = query - query.first.return_value = record - session.query.return_value = query + session.scalar.return_value = record monkeypatch.setattr( summary_module, "session_factory", @@ -1073,10 +955,7 @@ def test_update_summary_for_segment_existing_vectorize_failure_returns_error_rec record = _summary_record(summary_content="old", node_id="n1") session = MagicMock() - query = MagicMock() - query.filter_by.return_value = query - query.first.return_value = record - session.query.return_value = query + session.scalar.return_value = record monkeypatch.setattr( summary_module, "session_factory", @@ -1095,10 +974,7 @@ def test_update_summary_for_segment_new_record_success(monkeypatch: pytest.Monke segment = _segment() session = MagicMock() - query = MagicMock() - query.filter_by.return_value = query - query.first.return_value = None - session.query.return_value = query + session.scalar.return_value = None monkeypatch.setattr( summary_module, "session_factory", @@ -1122,10 +998,7 @@ def test_update_summary_for_segment_outer_exception_sets_error_and_reraises(monk record = _summary_record(summary_content="old", node_id="n1") session = MagicMock() - query = MagicMock() - query.filter_by.return_value = query - query.first.return_value = record - session.query.return_value = query + session.scalar.return_value = record session.flush.side_effect = RuntimeError("flush boom") monkeypatch.setattr( summary_module, @@ -1143,25 +1016,9 @@ def test_update_summary_for_segment_outer_exception_sets_error_and_reraises(monk def test_get_segment_summary_and_document_summaries(monkeypatch: pytest.MonkeyPatch) -> None: record = _summary_record(summary_content="sum", node_id="n1") session = MagicMock() + session.scalar.return_value = record + session.scalars.return_value.all.return_value = [record] - q1 = MagicMock() - q1.where.return_value = q1 - q1.first.return_value = record - - q2 = MagicMock() - q2.filter.return_value = q2 - q2.all.return_value = [record] - - def query_side_effect(model: object) -> MagicMock: - if model is summary_module.DocumentSegmentSummary: - # first call used by get_segment_summary, second by get_document_summaries - if not hasattr(query_side_effect, "_called"): - query_side_effect._called = True # type: ignore[attr-defined] - return q1 - return q2 - return MagicMock() - - session.query.side_effect = query_side_effect monkeypatch.setattr( summary_module, "session_factory", @@ -1178,10 +1035,7 @@ def test_get_segments_summaries_non_empty(monkeypatch: pytest.MonkeyPatch) -> No record2 = _summary_record() record2.chunk_id = "seg-2" session = MagicMock() - q = MagicMock() - q.where.return_value = q - q.all.return_value = [record1, record2] - session.query.return_value = q + session.scalars.return_value.all.return_value = [record1, record2] monkeypatch.setattr( summary_module, "session_factory", @@ -1194,10 +1048,7 @@ def test_get_segments_summaries_non_empty(monkeypatch: pytest.MonkeyPatch) -> No def test_get_document_summary_index_status_no_segments_returns_none(monkeypatch: pytest.MonkeyPatch) -> None: session = MagicMock() - q = MagicMock() - q.where.return_value = q - q.all.return_value = [] - session.query.return_value = q + session.scalars.return_value.all.return_value = [] monkeypatch.setattr( summary_module, "session_factory", @@ -1212,10 +1063,7 @@ def test_get_documents_summary_index_status_empty_input(monkeypatch: pytest.Monk def test_get_documents_summary_index_status_no_pending_sets_none(monkeypatch: pytest.MonkeyPatch) -> None: session = MagicMock() - q = MagicMock() - q.where.return_value = q - q.all.return_value = [SimpleNamespace(id="seg-1", document_id="doc-1")] - session.query.return_value = q + session.execute.return_value.all.return_value = [SimpleNamespace(id="seg-1", document_id="doc-1")] monkeypatch.setattr( summary_module, "session_factory", @@ -1237,10 +1085,7 @@ def test_update_summary_for_segment_creates_new_and_vectorize_fails_returns_erro segment = _segment() session = MagicMock() - query = MagicMock() - query.filter_by.return_value = query - query.first.return_value = None - session.query.return_value = query + session.scalar.return_value = None monkeypatch.setattr( summary_module, @@ -1267,10 +1112,7 @@ def test_get_segments_summaries_empty_list() -> None: def test_get_document_summary_index_status_and_documents_status(monkeypatch: pytest.MonkeyPatch) -> None: seg_row = SimpleNamespace(id="seg-1", document_id="doc-1") session = MagicMock() - query = MagicMock() - query.where.return_value = query - query.all.return_value = [SimpleNamespace(id="seg-1")] - session.query.return_value = query + session.scalars.return_value.all.return_value = ["seg-1"] # get_document_summary_index_status returns IDs create_session_mock = MagicMock(return_value=_SessionContext(session)) monkeypatch.setattr(summary_module, "session_factory", SimpleNamespace(create_session=create_session_mock)) @@ -1283,11 +1125,8 @@ def test_get_document_summary_index_status_and_documents_status(monkeypatch: pyt assert SummaryIndexService.get_document_summary_index_status("doc-1", "dataset-1", "tenant-1") == "SUMMARIZING" # Multiple docs - query2 = MagicMock() - query2.where.return_value = query2 - query2.all.return_value = [seg_row] session2 = MagicMock() - session2.query.return_value = query2 + session2.execute.return_value.all.return_value = [seg_row] # get_documents_summary_index_status uses execute monkeypatch.setattr( summary_module, "session_factory", diff --git a/api/tests/unit_tests/services/test_trigger_provider_service.py b/api/tests/unit_tests/services/test_trigger_provider_service.py index 350ff718c1..bd2e936b62 100644 --- a/api/tests/unit_tests/services/test_trigger_provider_service.py +++ b/api/tests/unit_tests/services/test_trigger_provider_service.py @@ -124,9 +124,7 @@ def test_list_trigger_provider_subscriptions_should_return_empty_list_when_no_su provider_id: TriggerProviderID, ) -> None: # Arrange - query = MagicMock() - query.filter_by.return_value.order_by.return_value.all.return_value = [] - mock_session.query.return_value = query + mock_session.scalars.return_value.all.return_value = [] # Act result = TriggerProviderService.list_trigger_provider_subscriptions("tenant-1", provider_id) @@ -152,11 +150,8 @@ def test_list_trigger_provider_subscriptions_should_mask_fields_and_attach_workf db_sub = SimpleNamespace(to_api_entity=lambda: api_sub) usage_row = SimpleNamespace(subscription_id="sub-1", app_count=2) - query_subs = MagicMock() - query_subs.filter_by.return_value.order_by.return_value.all.return_value = [db_sub] - query_usage = MagicMock() - query_usage.filter.return_value.group_by.return_value.all.return_value = [usage_row] - mock_session.query.side_effect = [query_subs, query_usage] + mock_session.scalars.return_value.all.return_value = [db_sub] + mock_session.execute.return_value.all.return_value = [usage_row] _mock_get_trigger_provider(mocker, provider_controller) cred_enc = _encrypter_mock(decrypted={"token": "plain"}, masked={"token": "****"}) @@ -188,11 +183,7 @@ def test_add_trigger_subscription_should_create_subscription_successfully_for_ap ) -> None: # Arrange _patch_redis_lock(mocker) - query_count = MagicMock() - query_count.filter_by.return_value.count.return_value = 0 - query_existing = MagicMock() - query_existing.filter_by.return_value.first.return_value = None - mock_session.query.side_effect = [query_count, query_existing] + mock_session.scalar.side_effect = [0, None] # count=0, no existing name _mock_get_trigger_provider(mocker, provider_controller) cred_enc = _encrypter_mock(encrypted={"api_key": "enc"}) @@ -228,11 +219,7 @@ def test_add_trigger_subscription_should_store_empty_credentials_for_unauthorize ) -> None: # Arrange _patch_redis_lock(mocker) - query_count = MagicMock() - query_count.filter_by.return_value.count.return_value = 0 - query_existing = MagicMock() - query_existing.filter_by.return_value.first.return_value = None - mock_session.query.side_effect = [query_count, query_existing] + mock_session.scalar.side_effect = [0, None] # count=0, no existing name _mock_get_trigger_provider(mocker, provider_controller) prop_enc = _encrypter_mock(encrypted={"p": "enc"}) @@ -267,9 +254,7 @@ def test_add_trigger_subscription_should_raise_error_when_provider_limit_reached ) -> None: # Arrange _patch_redis_lock(mocker) - query_count = MagicMock() - query_count.filter_by.return_value.count.return_value = TriggerProviderService.__MAX_TRIGGER_PROVIDER_COUNT__ - mock_session.query.return_value = query_count + mock_session.scalar.return_value = TriggerProviderService.__MAX_TRIGGER_PROVIDER_COUNT__ _mock_get_trigger_provider(mocker, provider_controller) mock_logger = mocker.patch("services.trigger.trigger_provider_service.logger") @@ -297,11 +282,7 @@ def test_add_trigger_subscription_should_raise_error_when_name_exists( ) -> None: # Arrange _patch_redis_lock(mocker) - query_count = MagicMock() - query_count.filter_by.return_value.count.return_value = 0 - query_existing = MagicMock() - query_existing.filter_by.return_value.first.return_value = object() - mock_session.query.side_effect = [query_count, query_existing] + mock_session.scalar.side_effect = [0, object()] # count=0, existing name conflict _mock_get_trigger_provider(mocker, provider_controller) # Act + Assert @@ -325,9 +306,7 @@ def test_update_trigger_subscription_should_raise_error_when_subscription_not_fo ) -> None: # Arrange _patch_redis_lock(mocker) - query_sub = MagicMock() - query_sub.filter_by.return_value.first.return_value = None - mock_session.query.return_value = query_sub + mock_session.scalar.return_value = None # Act + Assert with pytest.raises(ValueError, match="not found"): @@ -347,11 +326,7 @@ def test_update_trigger_subscription_should_raise_error_when_name_conflicts( provider_id="langgenius/github/github", credential_type=CredentialType.API_KEY.value, ) - query_sub = MagicMock() - query_sub.filter_by.return_value.first.return_value = subscription - query_existing = MagicMock() - query_existing.filter_by.return_value.first.return_value = object() - mock_session.query.side_effect = [query_sub, query_existing] + mock_session.scalar.side_effect = [subscription, object()] # found sub, name conflict _mock_get_trigger_provider(mocker, provider_controller) # Act + Assert @@ -378,11 +353,7 @@ def test_update_trigger_subscription_should_update_fields_and_clear_cache( credential_expires_at=0, expires_at=0, ) - query_sub = MagicMock() - query_sub.filter_by.return_value.first.return_value = subscription - query_existing = MagicMock() - query_existing.filter_by.return_value.first.return_value = None - mock_session.query.side_effect = [query_sub, query_existing] + mock_session.scalar.side_effect = [subscription, None] # found sub, no name conflict _mock_get_trigger_provider(mocker, provider_controller) prop_enc = _encrypter_mock(decrypted={"project": "old-value"}, encrypted={"project": "new-value"}) @@ -417,7 +388,7 @@ def test_update_trigger_subscription_should_update_fields_and_clear_cache( def test_get_subscription_by_id_should_return_none_when_missing(mocker: MockerFixture, mock_session: MagicMock) -> None: # Arrange - mock_session.query.return_value.filter_by.return_value.first.return_value = None + mock_session.scalar.return_value = None # Act result = TriggerProviderService.get_subscription_by_id("tenant-1", "sub-1") @@ -439,7 +410,7 @@ def test_get_subscription_by_id_should_decrypt_credentials_and_properties( credentials={"token": "enc"}, properties={"project": "enc"}, ) - mock_session.query.return_value.filter_by.return_value.first.return_value = subscription + mock_session.scalar.return_value = subscription _mock_get_trigger_provider(mocker, provider_controller) cred_enc = _encrypter_mock(decrypted={"token": "plain"}) prop_enc = _encrypter_mock(decrypted={"project": "plain"}) @@ -466,7 +437,7 @@ def test_delete_trigger_provider_should_raise_error_when_subscription_missing( mock_session: MagicMock, ) -> None: # Arrange - mock_session.query.return_value.filter_by.return_value.first.return_value = None + mock_session.scalar.return_value = None # Act + Assert with pytest.raises(ValueError, match="not found"): @@ -488,7 +459,7 @@ def test_delete_trigger_provider_should_delete_and_clear_cache_even_if_unsubscri credentials={"token": "enc"}, to_entity=lambda: SimpleNamespace(id="sub-1"), ) - mock_session.query.return_value.filter_by.return_value.first.return_value = subscription + mock_session.scalar.return_value = subscription _mock_get_trigger_provider(mocker, provider_controller) cred_enc = _encrypter_mock(decrypted={"token": "plain"}) mocker.patch( @@ -524,7 +495,7 @@ def test_delete_trigger_provider_should_skip_unsubscribe_for_unauthorized( credentials={}, to_entity=lambda: SimpleNamespace(id="sub-2"), ) - mock_session.query.return_value.filter_by.return_value.first.return_value = subscription + mock_session.scalar.return_value = subscription _mock_get_trigger_provider(mocker, provider_controller) mock_unsubscribe = mocker.patch("services.trigger.trigger_provider_service.TriggerManager.unsubscribe_trigger") mocker.patch( @@ -544,7 +515,7 @@ def test_refresh_oauth_token_should_raise_error_when_subscription_missing( mocker: MockerFixture, mock_session: MagicMock ) -> None: # Arrange - mock_session.query.return_value.filter_by.return_value.first.return_value = None + mock_session.scalar.return_value = None # Act + Assert with pytest.raises(ValueError, match="not found"): @@ -556,7 +527,7 @@ def test_refresh_oauth_token_should_raise_error_for_non_oauth_credentials( ) -> None: # Arrange subscription = SimpleNamespace(credential_type=CredentialType.API_KEY.value) - mock_session.query.return_value.filter_by.return_value.first.return_value = subscription + mock_session.scalar.return_value = subscription # Act + Assert with pytest.raises(ValueError, match="Only OAuth credentials can be refreshed"): @@ -577,7 +548,7 @@ def test_refresh_oauth_token_should_refresh_and_persist_new_credentials( credentials={"access_token": "enc"}, credential_expires_at=0, ) - mock_session.query.return_value.filter_by.return_value.first.return_value = subscription + mock_session.scalar.return_value = subscription _mock_get_trigger_provider(mocker, provider_controller) cache = MagicMock() cred_enc = _encrypter_mock(decrypted={"access_token": "old"}, encrypted={"access_token": "new"}) @@ -606,7 +577,7 @@ def test_refresh_subscription_should_raise_error_when_subscription_missing( mocker: MockerFixture, mock_session: MagicMock ) -> None: # Arrange - mock_session.query.return_value.filter_by.return_value.first.return_value = None + mock_session.scalar.return_value = None # Act + Assert with pytest.raises(ValueError, match="not found"): @@ -616,7 +587,7 @@ def test_refresh_subscription_should_raise_error_when_subscription_missing( def test_refresh_subscription_should_skip_when_not_due(mocker: MockerFixture, mock_session: MagicMock) -> None: # Arrange subscription = SimpleNamespace(expires_at=200) - mock_session.query.return_value.filter_by.return_value.first.return_value = subscription + mock_session.scalar.return_value = subscription # Act result = TriggerProviderService.refresh_subscription("tenant-1", "sub-1", now=100) @@ -643,7 +614,7 @@ def test_refresh_subscription_should_refresh_and_persist_properties( credentials={"c": "enc"}, credential_type=CredentialType.API_KEY.value, ) - mock_session.query.return_value.filter_by.return_value.first.return_value = subscription + mock_session.scalar.return_value = subscription _mock_get_trigger_provider(mocker, provider_controller) cred_enc = _encrypter_mock(decrypted={"c": "plain"}) prop_cache = MagicMock() @@ -681,10 +652,7 @@ def test_get_oauth_client_should_return_tenant_client_when_available( ) -> None: # Arrange tenant_client = SimpleNamespace(oauth_params={"client_id": "enc"}) - system_client = None - query_tenant = MagicMock() - query_tenant.filter_by.return_value.first.return_value = tenant_client - mock_session.query.return_value = query_tenant + mock_session.scalar.return_value = tenant_client _mock_get_trigger_provider(mocker, provider_controller) enc = _encrypter_mock(decrypted={"client_id": "plain"}) mocker.patch("services.trigger.trigger_provider_service.create_provider_encrypter", return_value=(enc, MagicMock())) @@ -703,11 +671,7 @@ def test_get_oauth_client_should_return_none_when_plugin_not_verified( provider_controller: MagicMock, ) -> None: # Arrange - query_tenant = MagicMock() - query_tenant.filter_by.return_value.first.return_value = None - query_system = MagicMock() - query_system.filter_by.return_value.first.return_value = None - mock_session.query.side_effect = [query_tenant, query_system] + mock_session.scalar.return_value = None # no tenant client; plugin not verified → early return _mock_get_trigger_provider(mocker, provider_controller) mocker.patch("services.trigger.trigger_provider_service.PluginService.is_plugin_verified", return_value=False) @@ -725,11 +689,7 @@ def test_get_oauth_client_should_return_decrypted_system_client_when_verified( provider_controller: MagicMock, ) -> None: # Arrange - query_tenant = MagicMock() - query_tenant.filter_by.return_value.first.return_value = None - query_system = MagicMock() - query_system.filter_by.return_value.first.return_value = SimpleNamespace(encrypted_oauth_params="enc") - mock_session.query.side_effect = [query_tenant, query_system] + mock_session.scalar.side_effect = [None, SimpleNamespace(encrypted_oauth_params="enc")] _mock_get_trigger_provider(mocker, provider_controller) mocker.patch("services.trigger.trigger_provider_service.PluginService.is_plugin_verified", return_value=True) mocker.patch( @@ -751,11 +711,7 @@ def test_get_oauth_client_should_raise_error_when_system_decryption_fails( provider_controller: MagicMock, ) -> None: # Arrange - query_tenant = MagicMock() - query_tenant.filter_by.return_value.first.return_value = None - query_system = MagicMock() - query_system.filter_by.return_value.first.return_value = SimpleNamespace(encrypted_oauth_params="enc") - mock_session.query.side_effect = [query_tenant, query_system] + mock_session.scalar.side_effect = [None, SimpleNamespace(encrypted_oauth_params="enc")] _mock_get_trigger_provider(mocker, provider_controller) mocker.patch("services.trigger.trigger_provider_service.PluginService.is_plugin_verified", return_value=True) mocker.patch( @@ -794,7 +750,7 @@ def test_is_oauth_system_client_exists_should_reflect_database_record( provider_controller: MagicMock, ) -> None: # Arrange - mock_session.query.return_value.filter_by.return_value.first.return_value = object() if has_client else None + mock_session.scalar.return_value = object() if has_client else None _mock_get_trigger_provider(mocker, provider_controller) mocker.patch("services.trigger.trigger_provider_service.PluginService.is_plugin_verified", return_value=True) @@ -823,11 +779,11 @@ def test_save_custom_oauth_client_params_should_create_record_and_clear_params_w provider_controller: MagicMock, ) -> None: # Arrange - query = MagicMock() - query.filter_by.return_value.first.return_value = None - mock_session.query.return_value = query + mock_session.scalar.return_value = None _mock_get_trigger_provider(mocker, provider_controller) fake_model = SimpleNamespace(encrypted_oauth_params="", enabled=False, oauth_params={}) + # Also mock select() so SQLAlchemy doesn't validate the patched TriggerOAuthTenantClient. + mocker.patch("services.trigger.trigger_provider_service.select", MagicMock(return_value=MagicMock())) mocker.patch("services.trigger.trigger_provider_service.TriggerOAuthTenantClient", return_value=fake_model) # Act @@ -853,7 +809,7 @@ def test_save_custom_oauth_client_params_should_merge_hidden_values_and_delete_c ) -> None: # Arrange custom_client = SimpleNamespace(oauth_params={"client_id": "enc-old"}, enabled=False) - mock_session.query.return_value.filter_by.return_value.first.return_value = custom_client + mock_session.scalar.return_value = custom_client _mock_get_trigger_provider(mocker, provider_controller) cache = MagicMock() enc = _encrypter_mock(decrypted={"client_id": "old-id"}, encrypted={"client_id": "new-id"}) @@ -882,7 +838,7 @@ def test_get_custom_oauth_client_params_should_return_empty_when_record_missing( provider_id: TriggerProviderID, ) -> None: # Arrange - mock_session.query.return_value.filter_by.return_value.first.return_value = None + mock_session.scalar.return_value = None # Act result = TriggerProviderService.get_custom_oauth_client_params("tenant-1", provider_id) @@ -899,7 +855,7 @@ def test_get_custom_oauth_client_params_should_return_masked_decrypted_values( ) -> None: # Arrange custom_client = SimpleNamespace(oauth_params={"client_id": "enc"}) - mock_session.query.return_value.filter_by.return_value.first.return_value = custom_client + mock_session.scalar.return_value = custom_client _mock_get_trigger_provider(mocker, provider_controller) enc = _encrypter_mock(decrypted={"client_id": "plain"}, masked={"client_id": "pl***id"}) mocker.patch("services.trigger.trigger_provider_service.create_provider_encrypter", return_value=(enc, MagicMock())) @@ -916,9 +872,6 @@ def test_delete_custom_oauth_client_params_should_delete_record_and_commit( mock_session: MagicMock, provider_id: TriggerProviderID, ) -> None: - # Arrange - mock_session.query.return_value.filter_by.return_value.delete.return_value = 1 - # Act result = TriggerProviderService.delete_custom_oauth_client_params("tenant-1", provider_id) @@ -934,7 +887,7 @@ def test_is_oauth_custom_client_enabled_should_return_expected_boolean( provider_id: TriggerProviderID, ) -> None: # Arrange - mock_session.query.return_value.filter_by.return_value.first.return_value = object() if exists else None + mock_session.scalar.return_value = object() if exists else None # Act result = TriggerProviderService.is_oauth_custom_client_enabled("tenant-1", provider_id) @@ -947,7 +900,7 @@ def test_get_subscription_by_endpoint_should_return_none_when_not_found( mocker: MockerFixture, mock_session: MagicMock ) -> None: # Arrange - mock_session.query.return_value.filter_by.return_value.first.return_value = None + mock_session.scalar.return_value = None # Act result = TriggerProviderService.get_subscription_by_endpoint("endpoint-1") @@ -968,7 +921,7 @@ def test_get_subscription_by_endpoint_should_decrypt_credentials_and_properties( credentials={"token": "enc"}, properties={"hook": "enc"}, ) - mock_session.query.return_value.filter_by.return_value.first.return_value = subscription + mock_session.scalar.return_value = subscription _mock_get_trigger_provider(mocker, provider_controller) mocker.patch( "services.trigger.trigger_provider_service.create_trigger_provider_encrypter_for_subscription", diff --git a/api/tests/unit_tests/services/test_webhook_service.py b/api/tests/unit_tests/services/test_webhook_service.py index 08299e6680..0eefdf7209 100644 --- a/api/tests/unit_tests/services/test_webhook_service.py +++ b/api/tests/unit_tests/services/test_webhook_service.py @@ -657,7 +657,7 @@ def _app(**kwargs: Any) -> App: def test_get_webhook_trigger_and_workflow_should_raise_when_webhook_not_found(monkeypatch: pytest.MonkeyPatch) -> None: # Arrange fake_session = MagicMock() - fake_session.query.return_value = _FakeQuery(None) + fake_session.scalar.return_value = None _patch_session(monkeypatch, fake_session) # Act / Assert @@ -671,7 +671,7 @@ def test_get_webhook_trigger_and_workflow_should_raise_when_app_trigger_not_foun # Arrange webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1") fake_session = MagicMock() - fake_session.query.side_effect = [_FakeQuery(webhook_trigger), _FakeQuery(None)] + fake_session.scalar.side_effect = [webhook_trigger, None] _patch_session(monkeypatch, fake_session) # Act / Assert @@ -686,7 +686,7 @@ def test_get_webhook_trigger_and_workflow_should_raise_when_app_trigger_rate_lim webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1") app_trigger = SimpleNamespace(status=AppTriggerStatus.RATE_LIMITED) fake_session = MagicMock() - fake_session.query.side_effect = [_FakeQuery(webhook_trigger), _FakeQuery(app_trigger)] + fake_session.scalar.side_effect = [webhook_trigger, app_trigger] _patch_session(monkeypatch, fake_session) # Act / Assert @@ -701,7 +701,7 @@ def test_get_webhook_trigger_and_workflow_should_raise_when_app_trigger_disabled webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1") app_trigger = SimpleNamespace(status=AppTriggerStatus.DISABLED) fake_session = MagicMock() - fake_session.query.side_effect = [_FakeQuery(webhook_trigger), _FakeQuery(app_trigger)] + fake_session.scalar.side_effect = [webhook_trigger, app_trigger] _patch_session(monkeypatch, fake_session) # Act / Assert @@ -714,7 +714,7 @@ def test_get_webhook_trigger_and_workflow_should_raise_when_workflow_not_found(m webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1") app_trigger = SimpleNamespace(status=AppTriggerStatus.ENABLED) fake_session = MagicMock() - fake_session.query.side_effect = [_FakeQuery(webhook_trigger), _FakeQuery(app_trigger), _FakeQuery(None)] + fake_session.scalar.side_effect = [webhook_trigger, app_trigger, None] _patch_session(monkeypatch, fake_session) # Act / Assert @@ -732,7 +732,7 @@ def test_get_webhook_trigger_and_workflow_should_return_values_for_non_debug_mod workflow.get_node_config_by_id.return_value = {"data": {"key": "value"}} fake_session = MagicMock() - fake_session.query.side_effect = [_FakeQuery(webhook_trigger), _FakeQuery(app_trigger), _FakeQuery(workflow)] + fake_session.scalar.side_effect = [webhook_trigger, app_trigger, workflow] _patch_session(monkeypatch, fake_session) # Act @@ -751,7 +751,7 @@ def test_get_webhook_trigger_and_workflow_should_return_values_for_debug_mode(mo workflow.get_node_config_by_id.return_value = {"data": {"mode": "debug"}} fake_session = MagicMock() - fake_session.query.side_effect = [_FakeQuery(webhook_trigger), _FakeQuery(workflow)] + fake_session.scalar.side_effect = [webhook_trigger, workflow] _patch_session(monkeypatch, fake_session) # Act diff --git a/api/tests/unit_tests/services/test_workflow_service.py b/api/tests/unit_tests/services/test_workflow_service.py index 76fcb19ab2..406b4fb9d0 100644 --- a/api/tests/unit_tests/services/test_workflow_service.py +++ b/api/tests/unit_tests/services/test_workflow_service.py @@ -969,8 +969,7 @@ class TestWorkflowService: # 1. Workflow exists # 2. No app is currently using it # 3. Not published as a tool - mock_session.scalar.side_effect = [mock_workflow, None] # workflow exists, no app using it - mock_session.query.return_value.where.return_value.first.return_value = None # no tool provider + mock_session.scalar.side_effect = [mock_workflow, None, None] # workflow, no app using it, no tool provider with patch("services.workflow_service.select") as mock_select: mock_stmt = MagicMock() @@ -1045,8 +1044,7 @@ class TestWorkflowService: mock_tool_provider = MagicMock() mock_session = MagicMock() - mock_session.scalar.side_effect = [mock_workflow, None] # workflow exists, no app using it - mock_session.query.return_value.where.return_value.first.return_value = mock_tool_provider + mock_session.scalar.side_effect = [mock_workflow, None, mock_tool_provider] # workflow, no app, tool provider with patch("services.workflow_service.select") as mock_select: mock_stmt = MagicMock() diff --git a/api/tests/unit_tests/services/tools/test_builtin_tools_manage_service.py b/api/tests/unit_tests/services/tools/test_builtin_tools_manage_service.py index e80c306854..79a2d30f57 100644 --- a/api/tests/unit_tests/services/tools/test_builtin_tools_manage_service.py +++ b/api/tests/unit_tests/services/tools/test_builtin_tools_manage_service.py @@ -32,7 +32,7 @@ class TestDeleteCustomOauthClientParams: result = BuiltinToolManageService.delete_custom_oauth_client_params("tenant-1", "google") assert result == {"result": "success"} - session.query.return_value.filter_by.return_value.delete.assert_called_once() + session.execute.assert_called_once() class TestListBuiltinToolProviderTools: @@ -111,7 +111,7 @@ class TestIsOauthSystemClientExists: @patch(f"{MODULE}.db") def test_true_when_exists(self, mock_db, mock_session_cls): session = _mock_session(mock_session_cls) - session.query.return_value.filter_by.return_value.first.return_value = MagicMock() + session.scalar.return_value = MagicMock() assert BuiltinToolManageService.is_oauth_system_client_exists("google") is True @@ -119,7 +119,7 @@ class TestIsOauthSystemClientExists: @patch(f"{MODULE}.db") def test_false_when_missing(self, mock_db, mock_session_cls): session = _mock_session(mock_session_cls) - session.query.return_value.filter_by.return_value.first.return_value = None + session.scalar.return_value = None assert BuiltinToolManageService.is_oauth_system_client_exists("google") is False @@ -129,7 +129,7 @@ class TestIsOauthCustomClientEnabled: @patch(f"{MODULE}.db") def test_true_when_enabled(self, mock_db, mock_session_cls): session = _mock_session(mock_session_cls) - session.query.return_value.filter_by.return_value.first.return_value = MagicMock(enabled=True) + session.scalar.return_value = MagicMock(enabled=True) assert BuiltinToolManageService.is_oauth_custom_client_enabled("t", "g") is True @@ -137,7 +137,7 @@ class TestIsOauthCustomClientEnabled: @patch(f"{MODULE}.db") def test_false_when_none(self, mock_db, mock_session_cls): session = _mock_session(mock_session_cls) - session.query.return_value.filter_by.return_value.first.return_value = None + session.scalar.return_value = None assert BuiltinToolManageService.is_oauth_custom_client_enabled("t", "g") is False @@ -149,7 +149,7 @@ class TestDeleteBuiltinToolProvider: @patch(f"{MODULE}.db") def test_raises_when_not_found(self, mock_db, mock_sm_cls, mock_tm, mock_enc): session = _mock_sessionmaker(mock_sm_cls) - session.query.return_value.where.return_value.first.return_value = None + session.scalar.return_value = None with pytest.raises(ValueError, match="you have not added provider"): BuiltinToolManageService.delete_builtin_tool_provider("t", "p", "id") @@ -161,7 +161,7 @@ class TestDeleteBuiltinToolProvider: def test_deletes_provider_and_clears_cache(self, mock_db, mock_sm_cls, mock_tm, mock_enc): session = _mock_sessionmaker(mock_sm_cls) db_provider = MagicMock() - session.query.return_value.where.return_value.first.return_value = db_provider + session.scalar.return_value = db_provider mock_cache = MagicMock() mock_enc.return_value = (MagicMock(), mock_cache) @@ -177,7 +177,7 @@ class TestSetDefaultProvider: @patch(f"{MODULE}.db") def test_raises_when_not_found(self, mock_db, mock_sm_cls): session = _mock_sessionmaker(mock_sm_cls) - session.query.return_value.filter_by.return_value.first.return_value = None + session.scalar.return_value = None with pytest.raises(ValueError, match="provider not found"): BuiltinToolManageService.set_default_provider("t", "u", "p", "id") @@ -187,7 +187,7 @@ class TestSetDefaultProvider: def test_sets_default_and_clears_old(self, mock_db, mock_sm_cls): session = _mock_sessionmaker(mock_sm_cls) target = MagicMock() - session.query.return_value.filter_by.return_value.first.return_value = target + session.scalar.return_value = target result = BuiltinToolManageService.set_default_provider("t", "u", "p", "id") @@ -200,7 +200,7 @@ class TestUpdateBuiltinToolProvider: @patch(f"{MODULE}.db") def test_raises_when_provider_not_exists(self, mock_db, mock_sm_cls): session = _mock_sessionmaker(mock_sm_cls) - session.query.return_value.where.return_value.first.return_value = None + session.scalar.return_value = None with pytest.raises(ValueError, match="you have not added provider"): BuiltinToolManageService.update_builtin_tool_provider("u", "t", "p", "c") @@ -213,7 +213,7 @@ class TestUpdateBuiltinToolProvider: def test_updates_credentials_and_commits(self, mock_db, mock_sm_cls, mock_tm, mock_cred_type, mock_enc): session = _mock_sessionmaker(mock_sm_cls) db_provider = MagicMock(credential_type="api_key", credentials="{}") - session.query.return_value.where.return_value.first.return_value = db_provider + session.scalar.return_value = db_provider mock_cred_instance = MagicMock() mock_cred_instance.is_editable.return_value = True @@ -274,7 +274,7 @@ class TestGetOauthClient: mock_create_enc.return_value = (mock_encrypter, MagicMock()) user_client = MagicMock(oauth_params='{"encrypted": "data"}') - session.query.return_value.filter_by.return_value.first.return_value = user_client + session.scalar.return_value = user_client result = BuiltinToolManageService.get_oauth_client("t", "google") @@ -297,10 +297,7 @@ class TestGetOauthClient: mock_create_enc.return_value = (MagicMock(), MagicMock()) system_client = MagicMock(encrypted_oauth_params="enc") - session.query.return_value.filter_by.return_value.first.side_effect = [ - None, # user client - system_client, # system client - ] + session.scalar.side_effect = [None, system_client] result = BuiltinToolManageService.get_oauth_client("t", "google") @@ -325,7 +322,7 @@ class TestGetCustomOauthClientParams: @patch(f"{MODULE}.db") def test_returns_empty_when_none(self, mock_db, mock_session_cls): session = _mock_session(mock_session_cls) - session.query.return_value.filter_by.return_value.first.return_value = None + session.scalar.return_value = None result = BuiltinToolManageService.get_custom_oauth_client_params("t", "p") @@ -391,7 +388,7 @@ class TestGetBuiltinProvider: session = _mock_session(mock_session_cls) mock_prov_id.return_value.provider_name = "google" mock_prov_id.return_value.organization = "langgenius" - session.query.return_value.where.return_value.order_by.return_value.first.return_value = None + session.scalar.return_value = None result = BuiltinToolManageService.get_builtin_provider("google", "t") @@ -417,7 +414,7 @@ class TestGetBuiltinProvider: return m mock_prov_id.side_effect = prov_id_side_effect - session.query.return_value.where.return_value.order_by.return_value.first.return_value = db_provider + session.scalar.return_value = db_provider result = BuiltinToolManageService.get_builtin_provider("google", "t") @@ -439,7 +436,7 @@ class TestGetBuiltinProvider: mock_prov_id.side_effect = prov_id_side_effect db_provider = MagicMock(provider="third-party/custom/custom-tool") - session.query.return_value.where.return_value.order_by.return_value.first.return_value = db_provider + session.scalar.return_value = db_provider result = BuiltinToolManageService.get_builtin_provider("third-party/custom/custom-tool", "t") @@ -452,7 +449,7 @@ class TestGetBuiltinProvider: session = _mock_session(mock_session_cls) mock_prov_id.side_effect = Exception("parse error") fallback = MagicMock() - session.query.return_value.where.return_value.order_by.return_value.first.return_value = fallback + session.scalar.return_value = fallback result = BuiltinToolManageService.get_builtin_provider("old-provider", "t") diff --git a/api/tests/unit_tests/tasks/test_dataset_indexing_task.py b/api/tests/unit_tests/tasks/test_dataset_indexing_task.py index 34e474c921..5dad58b8f1 100644 --- a/api/tests/unit_tests/tasks/test_dataset_indexing_task.py +++ b/api/tests/unit_tests/tasks/test_dataset_indexing_task.py @@ -82,8 +82,8 @@ def mock_db_session(): """Mock session_factory.create_session() to return a session whose queries use shared test data. Tests set session._shared_data = {"dataset": , "documents": [, ...]} - This fixture makes session.query(Dataset).first() return the shared dataset, - and session.query(Document).all()/first() return from the shared documents. + This fixture makes session.scalar(select(Dataset)...) return the shared dataset, + and session.scalars(select(Document)...).all() return the shared documents. """ with patch("tasks.document_indexing_task.session_factory") as mock_sf: session = MagicMock() @@ -92,93 +92,68 @@ def mock_db_session(): # Keep a pointer so repeated Document.first() calls iterate across provided docs session._doc_first_idx = 0 - def _query_side_effect(model): - q = MagicMock() + def _get_entity(stmt) -> type | None: + """Extract the mapped entity class from a SQLAlchemy select statement.""" + try: + descs = stmt.column_descriptions + if descs: + return descs[0].get("entity") + except (AttributeError, TypeError): + pass + return None - # Capture filters passed via where(...) so first()/all() can honor them. - q._filters = {} + def _extract_id_from_where(stmt) -> str | None: + """Return the value bound to the 'id' column in the WHERE clause, if present.""" + try: + where = stmt.whereclause + if where is None: + return None + # Both single-clause and AND-clause-list cases + clauses = list(getattr(where, "clauses", [where])) + for clause in clauses: + left = getattr(clause, "left", None) + right = getattr(clause, "right", None) + if left is not None and right is not None: + if getattr(left, "key", None) == "id": + return getattr(right, "value", None) + except Exception: + pass + return None - def _extract_filters(*conds, **kw): - # Support both SQLAlchemy expressions (BinaryExpression) and kwargs - # We only need the simple fields used by production code: id, dataset_id, and id.in_(...) - for cond in conds: - left = getattr(cond, "left", None) - right = getattr(cond, "right", None) - key = None - if left is not None: - key = getattr(left, "key", None) or getattr(left, "name", None) - if not key: - continue - # Right side might be a BindParameter with .value, or a raw value/sequence - val = getattr(right, "value", right) - q._filters[key] = val - # Also accept kwargs (e.g., where(id=...)) just in case - for k, v in kw.items(): - q._filters[k] = v - - def _where_side_effect(*conds, **kw): - _extract_filters(*conds, **kw) - return q - - q.where.side_effect = _where_side_effect - - # Dataset queries - if model.__name__ == "Dataset": - - def _dataset_first(): - ds = session._shared_data.get("dataset") - if not ds: - return None - if "id" in q._filters: - val = q._filters["id"] - if isinstance(val, (list, tuple, set)): - return ds if ds.id in val else None - return ds if ds.id == val else None - return ds - - def _dataset_all(): - ds = session._shared_data.get("dataset") - if not ds: - return [] - first = _dataset_first() - return [first] if first else [] - - q.first.side_effect = _dataset_first - q.all.side_effect = _dataset_all - return q - - # Document queries - if model.__name__ == "Document": - - def _apply_doc_filters(docs): - result = list(docs) - for key in ("id", "dataset_id"): - if key in q._filters: - val = q._filters[key] - if isinstance(val, (list, tuple, set)): - result = [d for d in result if getattr(d, key, None) in val] - else: - result = [d for d in result if getattr(d, key, None) == val] - return result - - def _docs_all(): + def _scalar_side_effect(stmt): + entity = _get_entity(stmt) + if entity is not None: + if entity.__name__ == "Dataset": + return session._shared_data.get("dataset") + elif entity.__name__ == "Document": docs = session._shared_data.get("documents", []) - return _apply_doc_filters(docs) + if not docs: + return None + # When the WHERE clause filters by id, return the matching document + queried_id = _extract_id_from_where(stmt) + if queried_id: + doc_map = {d.id: d for d in docs} + return doc_map.get(queried_id, docs[0]) + return docs[0] + return None - def _docs_first(): - docs = _docs_all() - return docs[0] if docs else None + def _scalars_side_effect(stmt): + entity = _get_entity(stmt) + result = MagicMock() + if entity is not None: + if entity.__name__ == "Document": + result.all.return_value = list(session._shared_data.get("documents", [])) + elif entity.__name__ == "Dataset": + ds = session._shared_data.get("dataset") + result.all.return_value = [ds] if ds else [] + else: + result.all.return_value = [] + else: + result.all.return_value = [] + return result - q.all.side_effect = _docs_all - q.first.side_effect = _docs_first - return q - - # Default fallback - q.first.return_value = None - q.all.return_value = [] - return q - - session.query.side_effect = _query_side_effect + session.scalar.side_effect = _scalar_side_effect + session.scalars.side_effect = _scalars_side_effect # Implement session.begin() context manager that commits on exit session.commit = MagicMock() @@ -638,8 +613,6 @@ class TestProgressTracking: wrapper = TaskWrapper(data=next_task_data) mock_redis.rpop.return_value = wrapper.serialize() - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: mock_features.return_value.billing.enabled = False @@ -662,7 +635,6 @@ class TestProgressTracking: """ # Arrange mock_redis.rpop.return_value = None # No more tasks - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: mock_features.return_value.billing.enabled = False @@ -780,8 +752,7 @@ class TestErrorHandling: If the dataset doesn't exist, the task should exit gracefully. """ - # Arrange - mock_db_session.query.return_value.where.return_value.first.return_value = None + # Arrange - dataset is not in _shared_data (None by default), so scalar() returns None # Act _document_indexing(dataset_id, document_ids) @@ -806,8 +777,6 @@ class TestErrorHandling: # Set up rpop to return task once for concurrency check mock_redis.rpop.side_effect = [wrapper.serialize(), None] - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - # Make _document_indexing raise an error with patch("tasks.document_indexing_task._document_indexing") as mock_indexing: mock_indexing.side_effect = Exception("Processing failed") @@ -844,7 +813,7 @@ class TestErrorHandling: # Mock rpop to return tasks one by one mock_redis.rpop.side_effect = tasks[:concurrency_limit] + [None] - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db_session._shared_data["dataset"] = mock_dataset with patch("tasks.document_indexing_task.dify_config.TENANT_ISOLATED_TASK_CONCURRENCY", concurrency_limit): with patch("tasks.document_indexing_task.normal_document_indexing_task") as mock_task: @@ -977,7 +946,7 @@ class TestAdvancedScenarios: # Mock rpop to return tasks up to concurrency limit mock_redis.rpop.side_effect = waiting_tasks[:concurrency_limit] + [None] - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db_session._shared_data["dataset"] = mock_dataset with patch("tasks.document_indexing_task.dify_config.TENANT_ISOLATED_TASK_CONCURRENCY", concurrency_limit): with patch("tasks.document_indexing_task.normal_document_indexing_task") as mock_task: @@ -1070,7 +1039,7 @@ class TestAdvancedScenarios: # Mock rpop to return tasks in FIFO order mock_redis.rpop.side_effect = tasks + [None] - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db_session._shared_data["dataset"] = mock_dataset with patch("tasks.document_indexing_task.dify_config.TENANT_ISOLATED_TASK_CONCURRENCY", 3): with patch("tasks.document_indexing_task.normal_document_indexing_task") as mock_task: @@ -1108,7 +1077,7 @@ class TestAdvancedScenarios: """ # Arrange mock_redis.rpop.return_value = None # Empty queue - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db_session._shared_data["dataset"] = mock_dataset with patch("tasks.document_indexing_task.normal_document_indexing_task") as mock_task: # Act @@ -1276,7 +1245,7 @@ class TestIntegration: # First call returns task 2, second call returns None mock_redis.rpop.side_effect = [wrapper.serialize(), None] - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db_session._shared_data["dataset"] = mock_dataset with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: mock_features.return_value.billing.enabled = False @@ -1433,7 +1402,7 @@ class TestPerformanceScenarios: # Mock rpop to return tasks up to concurrency limit mock_redis.rpop.side_effect = waiting_tasks[:concurrency_limit] + [None] - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db_session._shared_data["dataset"] = mock_dataset with patch("tasks.document_indexing_task.dify_config.TENANT_ISOLATED_TASK_CONCURRENCY", concurrency_limit): with patch("tasks.document_indexing_task.normal_document_indexing_task") as mock_task: @@ -1536,10 +1505,8 @@ class TestDocumentIndexingTaskSummaryFlow: """Test early return when dataset does not exist.""" # Arrange session = MagicMock() - dataset_query = MagicMock() - dataset_query.where.return_value = dataset_query - dataset_query.first.return_value = None - session.query.side_effect = lambda model: dataset_query + session = MagicMock() + session.scalar.return_value = None # dataset not found create_session_mock = MagicMock(return_value=_SessionContext(session)) monkeypatch.setattr("tasks.document_indexing_task.session_factory.create_session", create_session_mock) @@ -1560,16 +1527,15 @@ class TestDocumentIndexingTaskSummaryFlow: dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1") document = SimpleNamespace(id="doc-1", indexing_status=None, error=None, stopped_at=None) - dataset_query = MagicMock() - dataset_query.where.return_value = dataset_query - dataset_query.first.return_value = dataset - - document_query = MagicMock() - document_query.where.return_value = document_query - document_query.first.return_value = document - session = MagicMock() - session.query.side_effect = lambda model: dataset_query if model is Dataset else document_query + + def _scalar_se(stmt): + entity = stmt.column_descriptions[0].get("entity") + if entity is Dataset: + return dataset + return document + + session.scalar.side_effect = _scalar_se monkeypatch.setattr( "tasks.document_indexing_task.session_factory.create_session", @@ -1643,9 +1609,12 @@ class TestDocumentIndexingTaskSummaryFlow: session2.begin.return_value = nullcontext() session3 = MagicMock() - session1.query.side_effect = lambda model: dataset_query - session2.query.side_effect = lambda model: phase1_document_query - session3.query.side_effect = lambda model: summary_document_query if model is Document else dataset_query + session1.scalar.return_value = dataset + session2.scalars.return_value = MagicMock(all=MagicMock(return_value=phase1_docs)) + session3.scalar.return_value = dataset + session3.scalars.return_value = MagicMock( + all=MagicMock(return_value=[doc_eligible, doc_skip_form, doc_skip_status]) + ) create_session_mock = MagicMock( side_effect=[_SessionContext(session1), _SessionContext(session2), _SessionContext(session3)] @@ -1704,9 +1673,11 @@ class TestDocumentIndexingTaskSummaryFlow: session2 = MagicMock() session2.begin.return_value = nullcontext() session3 = MagicMock() - session1.query.side_effect = lambda model: dataset_query - session2.query.side_effect = lambda model: phase1_query - session3.query.side_effect = lambda model: summary_query if model is Document else dataset_query + + session1.scalar.return_value = dataset + session2.scalars.return_value = MagicMock(all=MagicMock(return_value=[SimpleNamespace(id="doc-1")])) + session3.scalar.return_value = dataset + session3.scalars.return_value = MagicMock(all=MagicMock(return_value=[doc_eligible])) monkeypatch.setattr( "tasks.document_indexing_task.session_factory.create_session", @@ -1736,21 +1707,14 @@ class TestDocumentIndexingTaskSummaryFlow: """Test early return when dataset is missing after indexing.""" # Arrange dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1") - dataset_query = MagicMock() - dataset_query.where.return_value = dataset_query - dataset_query.first.side_effect = [dataset, None] - - document_query = MagicMock() - document_query.where.return_value = document_query - document_query.all.return_value = [SimpleNamespace(id="doc-1")] session1 = MagicMock() session2 = MagicMock() session2.begin.return_value = nullcontext() session3 = MagicMock() - session1.query.side_effect = lambda model: dataset_query - session2.query.side_effect = lambda model: document_query - session3.query.side_effect = lambda model: dataset_query + session1.scalar.return_value = dataset + session2.scalars.return_value = MagicMock(all=MagicMock(return_value=[SimpleNamespace(id="doc-1")])) + session3.scalar.return_value = None # dataset not found on second query monkeypatch.setattr( "tasks.document_indexing_task.session_factory.create_session", @@ -1770,7 +1734,7 @@ class TestDocumentIndexingTaskSummaryFlow: _document_indexing("dataset-1", ["doc-1"]) # Assert - session3.query.assert_called() + session3.scalar.assert_called() def test_should_skip_summary_when_not_high_quality(self, monkeypatch: pytest.MonkeyPatch) -> None: """Test summary generation skipped when indexing_technique is not high_quality.""" @@ -1781,21 +1745,14 @@ class TestDocumentIndexingTaskSummaryFlow: indexing_technique="economy", summary_index_setting={"enable": True}, ) - dataset_query = MagicMock() - dataset_query.where.return_value = dataset_query - dataset_query.first.return_value = dataset - - document_query = MagicMock() - document_query.where.return_value = document_query - document_query.all.return_value = [SimpleNamespace(id="doc-1")] - session1 = MagicMock() session2 = MagicMock() session2.begin.return_value = nullcontext() session3 = MagicMock() - session1.query.side_effect = lambda model: dataset_query - session2.query.side_effect = lambda model: document_query - session3.query.side_effect = lambda model: dataset_query + + session1.scalar.return_value = dataset + session2.scalars.return_value = MagicMock(all=MagicMock(return_value=[SimpleNamespace(id="doc-1")])) + session3.scalar.return_value = dataset monkeypatch.setattr( "tasks.document_indexing_task.session_factory.create_session", @@ -1824,19 +1781,12 @@ class TestDocumentIndexingTaskSummaryFlow: """Test summary generation is skipped when indexing is paused.""" # Arrange dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1") - dataset_query = MagicMock() - dataset_query.where.return_value = dataset_query - dataset_query.first.return_value = dataset - - document_query = MagicMock() - document_query.where.return_value = document_query - document_query.all.return_value = [SimpleNamespace(id="doc-1")] session1 = MagicMock() session2 = MagicMock() session2.begin.return_value = nullcontext() - session1.query.side_effect = lambda model: dataset_query - session2.query.side_effect = lambda model: document_query + session1.scalar.return_value = dataset + session2.scalars.return_value = MagicMock(all=MagicMock(return_value=[SimpleNamespace(id="doc-1")])) create_session_mock = MagicMock(side_effect=[_SessionContext(session1), _SessionContext(session2)]) monkeypatch.setattr("tasks.document_indexing_task.session_factory.create_session", create_session_mock) @@ -1865,19 +1815,12 @@ class TestDocumentIndexingTaskSummaryFlow: """Test generic indexing runner exception is handled.""" # Arrange dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1") - dataset_query = MagicMock() - dataset_query.where.return_value = dataset_query - dataset_query.first.return_value = dataset - - document_query = MagicMock() - document_query.where.return_value = document_query - document_query.all.return_value = [SimpleNamespace(id="doc-1")] session1 = MagicMock() session2 = MagicMock() session2.begin.return_value = nullcontext() - session1.query.side_effect = lambda model: dataset_query - session2.query.side_effect = lambda model: document_query + session1.scalar.return_value = dataset + session2.scalars.return_value = MagicMock(all=MagicMock(return_value=[SimpleNamespace(id="doc-1")])) monkeypatch.setattr( "tasks.document_indexing_task.session_factory.create_session", @@ -1922,25 +1865,15 @@ class TestDocumentIndexingTaskSummaryFlow: indexing_technique="high_quality", summary_index_setting={"enable": True}, ) - dataset_query = MagicMock() - dataset_query.where.return_value = dataset_query - dataset_query.first.return_value = dataset - - phase1_query = MagicMock() - phase1_query.where.return_value = phase1_query - phase1_query.all.return_value = [SimpleNamespace(id="doc-1")] - - summary_query = MagicMock() - summary_query.where.return_value = summary_query - summary_query.all.return_value = [_FalseyDocument("missing-doc")] - session1 = MagicMock() session2 = MagicMock() session2.begin.return_value = nullcontext() session3 = MagicMock() - session1.query.side_effect = lambda model: dataset_query - session2.query.side_effect = lambda model: phase1_query - session3.query.side_effect = lambda model: summary_query if model is Document else dataset_query + + session1.scalar.return_value = dataset + session2.scalars.return_value = MagicMock(all=MagicMock(return_value=[SimpleNamespace(id="doc-1")])) + session3.scalar.return_value = dataset + session3.scalars.return_value = MagicMock(all=MagicMock(return_value=[_FalseyDocument("missing-doc")])) monkeypatch.setattr( "tasks.document_indexing_task.session_factory.create_session", diff --git a/api/tests/unit_tests/tasks/test_remove_app_and_related_data_task.py b/api/tests/unit_tests/tasks/test_remove_app_and_related_data_task.py index 0ed4ca05fa..626d1ee0a8 100644 --- a/api/tests/unit_tests/tasks/test_remove_app_and_related_data_task.py +++ b/api/tests/unit_tests/tasks/test_remove_app_and_related_data_task.py @@ -3,7 +3,6 @@ from unittest.mock import MagicMock, call, patch import pytest from libs.archive_storage import ArchiveStorageNotConfiguredError -from models.workflow import WorkflowArchiveLog from tasks.remove_app_and_related_data_task import ( _delete_app_workflow_archive_logs, _delete_archived_workflow_run_files, @@ -83,16 +82,11 @@ class TestDeleteWorkflowArchiveLogs: assert params == {"tenant_id": tenant_id, "app_id": app_id} assert name == "workflow archive log" - mock_query = MagicMock() - mock_delete_query = MagicMock() - mock_query.where.return_value = mock_delete_query - mock_db.session.query.return_value = mock_query + mock_session = MagicMock() - delete_func(mock_db.session, "log-1") + delete_func(mock_session, "log-1") - mock_db.session.query.assert_called_once_with(WorkflowArchiveLog) - mock_query.where.assert_called_once() - mock_delete_query.delete.assert_called_once_with(synchronize_session=False) + mock_session.execute.assert_called_once() class TestDeleteArchivedWorkflowRunFiles: diff --git a/api/uv.lock b/api/uv.lock index b67646cb71..b5d0404693 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -195,7 +195,7 @@ wheels = [ [[package]] name = "aliyun-log-python-sdk" -version = "0.9.37" +version = "0.9.44" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "dateparser" }, @@ -207,7 +207,7 @@ dependencies = [ { name = "requests" }, { name = "six" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/90/70/291d494619bb7b0cbcc00689ad995945737c2c9e0bff2733e0aa7dbaee14/aliyun_log_python_sdk-0.9.37.tar.gz", hash = "sha256:ea65c9cca3a7377cef87d568e897820338328a53a7acb1b02f1383910e103f68", size = 152549, upload-time = "2025-11-27T07:56:06.098Z" } +sdist = { url = "https://files.pythonhosted.org/packages/2d/5c/f4076b129fe9168f5424f9d89afc587baf36a844f4ae7b619a951a97c76c/aliyun_log_python_sdk-0.9.44.tar.gz", hash = "sha256:891d0ba91cdce8e5e6b430a50512e092751621680bc9f0b7c7325aaa7c1944f1", size = 154147, upload-time = "2026-03-30T08:40:59.04Z" } [[package]] name = "aliyun-python-sdk-core" @@ -286,14 +286,14 @@ wheels = [ [[package]] name = "apscheduler" -version = "3.11.1" +version = "3.11.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "tzlocal" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/d0/81/192db4f8471de5bc1f0d098783decffb1e6e69c4f8b4bc6711094691950b/apscheduler-3.11.1.tar.gz", hash = "sha256:0db77af6400c84d1747fe98a04b8b58f0080c77d11d338c4f507a9752880f221", size = 108044, upload-time = "2025-10-31T18:55:42.819Z" } +sdist = { url = "https://files.pythonhosted.org/packages/07/12/3e4389e5920b4c1763390c6d371162f3784f86f85cd6d6c1bfe68eef14e2/apscheduler-3.11.2.tar.gz", hash = "sha256:2a9966b052ec805f020c8c4c3ae6e6a06e24b1bf19f2e11d91d8cca0473eef41", size = 108683, upload-time = "2025-12-22T00:39:34.884Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/58/9f/d3c76f76c73fcc959d28e9def45b8b1cc3d7722660c5003b19c1022fd7f4/apscheduler-3.11.1-py3-none-any.whl", hash = "sha256:6162cb5683cb09923654fa9bdd3130c4be4bfda6ad8990971c9597ecd52965d2", size = 64278, upload-time = "2025-10-31T18:55:41.186Z" }, + { url = "https://files.pythonhosted.org/packages/9f/64/2e54428beba8d9992aa478bb8f6de9e4ecaa5f8f513bcfd567ed7fb0262d/apscheduler-3.11.2-py3-none-any.whl", hash = "sha256:ce005177f741409db4e4dd40a7431b76feb856b9dd69d57e0da49d6715bfd26d", size = 64439, upload-time = "2025-12-22T00:39:33.303Z" }, ] [[package]] @@ -437,7 +437,7 @@ wheels = [ [[package]] name = "bce-python-sdk" -version = "0.9.68" +version = "0.9.69" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "crc32c" }, @@ -445,9 +445,9 @@ dependencies = [ { name = "pycryptodome" }, { name = "six" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/ca/7c/8b4d9128e571f898f9f177dc9f41e31692d8ddb08a963b0c576f219d1304/bce_python_sdk-0.9.68.tar.gz", hash = "sha256:adf182868ed25e53cc3c1573dad9a2b1e9b72ed1ffd0d3ef326f5fa93da7cfa6", size = 296349, upload-time = "2026-03-30T02:57:32.948Z" } +sdist = { url = "https://files.pythonhosted.org/packages/07/9c/8fdaf7f9259002b5aa9101bb88252f6d05f65c4535bbca567457da84d765/bce_python_sdk-0.9.69.tar.gz", hash = "sha256:2aaa9f4fc118b3efb720a66d7a541789b7d838a1ddacb9f3c6faa6b75e1c7d23", size = 300008, upload-time = "2026-04-10T08:13:29.769Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/fa/4e/eaaba9264667d675c3de76485dc511f0f233c31bada8752411f7fc5170be/bce_python_sdk-0.9.68-py3-none-any.whl", hash = "sha256:fcb484db4a54aa2c4675834c10bc6c37d42929fd138faaf6c01f933d8fa927ed", size = 411932, upload-time = "2026-03-30T02:57:27.847Z" }, + { url = "https://files.pythonhosted.org/packages/ca/3b/41c2985d1b3b3bd5cdf103b4156b08320268ee7a0617f2a40c34fdd377e9/bce_python_sdk-0.9.69-py3-none-any.whl", hash = "sha256:50fb94833b5f4931255296396081b85143101bd9a7a894efbf20d1f759779de5", size = 415659, upload-time = "2026-04-10T08:13:27.958Z" }, ] [[package]] @@ -551,29 +551,29 @@ wheels = [ [[package]] name = "boto3" -version = "1.42.83" +version = "1.42.88" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "botocore" }, { name = "jmespath" }, { name = "s3transfer" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/9f/87/1ed88eaa1e814841a37e71fee74c2b74341d14b791c0c6038b7ba914bea1/boto3-1.42.83.tar.gz", hash = "sha256:cc5621e603982cb3145b7f6c9970e02e297a1a0eb94637cc7f7b69d3017640ee", size = 112719, upload-time = "2026-04-03T19:34:21.254Z" } +sdist = { url = "https://files.pythonhosted.org/packages/da/bb/7d4435cca6fccf235dd40c891c731bcb9078e815917b57ebadd1e0ffabaf/boto3-1.42.88.tar.gz", hash = "sha256:2d22c70de5726918676a06f1a03acfb4d5d9ea92fc759354800b67b22aaeef19", size = 113238, upload-time = "2026-04-10T19:41:06.912Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/c1/b1/8a066bc8f02937d49783c0b3948ab951d8284e6fde436cab9f359dbd4d93/boto3-1.42.83-py3-none-any.whl", hash = "sha256:544846fdb10585bb7837e409868e8e04c6b372fa04479ba1597ce82cf1242076", size = 140555, upload-time = "2026-04-03T19:34:17.935Z" }, + { url = "https://files.pythonhosted.org/packages/0a/2b/8bfddb39a19f5fbc16a869f1a394771e6223f07160dbc0ff6b38e05ea0ae/boto3-1.42.88-py3-none-any.whl", hash = "sha256:2d0f52c971503377e4370d2a83edee6f077ddb8e684366ff38df4f13581d9cfc", size = 140557, upload-time = "2026-04-10T19:41:05.309Z" }, ] [[package]] name = "boto3-stubs" -version = "1.42.83" +version = "1.42.88" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "botocore-stubs" }, { name = "types-s3transfer" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/2d/fe/6c43a048074d8567db38befe51bf0b770e8456aa2b91ce8fe6758f29ec3d/boto3_stubs-1.42.83.tar.gz", hash = "sha256:1ecbd88f4ae35764b9ea3579ca1e851b67ea0a73a442cb406de277fc1478daeb", size = 102188, upload-time = "2026-04-03T19:54:20.613Z" } +sdist = { url = "https://files.pythonhosted.org/packages/c2/c7/d4dfbb4757cd72fd350ba666902ec3ac19e04d6be639e96cdad4543d4726/boto3_stubs-1.42.88.tar.gz", hash = "sha256:85215fb4938a94d1cf83cd8632f46ae7728b5ec88187d83468f393bbe64236d6", size = 102495, upload-time = "2026-04-10T19:55:57.526Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/9c/4d/eee0444fd466ebe69fdb61cc1f24b97d8e21e9e545865f7c1d846294a413/boto3_stubs-1.42.83-py3-none-any.whl", hash = "sha256:06185ca5f11a1edc880286f5f33779a2b08be356bf270bf1ec128d0819782a20", size = 70448, upload-time = "2026-04-03T19:54:16.315Z" }, + { url = "https://files.pythonhosted.org/packages/b4/6f/3befd72080aedbb4ad26b353a6e364645668664930ce49668fd0bab8f2b5/boto3_stubs-1.42.88-py3-none-any.whl", hash = "sha256:9e74350715ca8ccd63fc250f8eca9fa3161b3d1704339554344d72e4e21c5ed1", size = 70603, upload-time = "2026-04-10T19:55:49.921Z" }, ] [package.optional-dependencies] @@ -583,16 +583,16 @@ bedrock-runtime = [ [[package]] name = "botocore" -version = "1.42.83" +version = "1.42.88" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "jmespath" }, { name = "python-dateutil" }, { name = "urllib3" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/4e/01/b46a3f8b6e9362258f78f1890db1a96d4ed73214d6a36420dc158dcfd221/botocore-1.42.83.tar.gz", hash = "sha256:34bc8cb64b17ac17f8901f073fe4fc9572a5cac9393a37b2b3ea372a83b87f4a", size = 15140337, upload-time = "2026-04-03T19:34:08.779Z" } +sdist = { url = "https://files.pythonhosted.org/packages/93/50/87966238f7aa3f7e5f87081185d5a407a95ede8b551e11bbe134ca3306dc/botocore-1.42.88.tar.gz", hash = "sha256:cbb59ee464662039b0c2c95a520cdf85b1e8ce00b72375ab9cd9f842cc001301", size = 15195331, upload-time = "2026-04-10T19:40:57.012Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/a3/97/0d6f50822dc8c1df7f3eadb0bc6822fc0f98f02287c4efc7c7c88fde129a/botocore-1.42.83-py3-none-any.whl", hash = "sha256:ec0c3ecb3772936ed22a3bdda09883b34858933f71004686d460d829bab39d8e", size = 14818388, upload-time = "2026-04-03T19:34:03.333Z" }, + { url = "https://files.pythonhosted.org/packages/2a/46/ad14e41245adb8b0c83663ba13e822b68a0df08999dd250e75b0750fdf6c/botocore-1.42.88-py3-none-any.whl", hash = "sha256:032375b213305b6b81eedb269eaeefdf96f674620799bbf96117dca86052cc1a", size = 14876640, upload-time = "2026-04-10T19:40:53.663Z" }, ] [[package]] @@ -687,11 +687,11 @@ wheels = [ [[package]] name = "cachetools" -version = "5.3.3" +version = "7.0.5" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/b3/4d/27a3e6dd09011649ad5210bdf963765bc8fa81a0827a4fc01bafd2705c5b/cachetools-5.3.3.tar.gz", hash = "sha256:ba29e2dfa0b8b556606f097407ed1aa62080ee108ab0dc5ec9d6a723a007d105", size = 26522, upload-time = "2024-02-26T20:33:23.386Z" } +sdist = { url = "https://files.pythonhosted.org/packages/af/dd/57fe3fdb6e65b25a5987fd2cdc7e22db0aef508b91634d2e57d22928d41b/cachetools-7.0.5.tar.gz", hash = "sha256:0cd042c24377200c1dcd225f8b7b12b0ca53cc2c961b43757e774ebe190fd990", size = 37367, upload-time = "2026-03-09T20:51:29.451Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/fb/2b/a64c2d25a37aeb921fddb929111413049fc5f8b9a4c1aefaffaafe768d54/cachetools-5.3.3-py3-none-any.whl", hash = "sha256:0abad1021d3f8325b2fc1d2e9c8b9c9d57b04c3932657a72465447332c24d945", size = 9325, upload-time = "2024-02-26T20:33:20.308Z" }, + { url = "https://files.pythonhosted.org/packages/06/f3/39cf3367b8107baa44f861dc802cbf16263c945b62d8265d36034fc07bea/cachetools-7.0.5-py3-none-any.whl", hash = "sha256:46bc8ebefbe485407621d0a4264b23c080cedd913921bad7ac3ed2f26c183114", size = 13918, upload-time = "2026-03-09T20:51:27.33Z" }, ] [[package]] @@ -705,7 +705,7 @@ wheels = [ [[package]] name = "celery" -version = "5.6.2" +version = "5.6.3" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "billiard" }, @@ -718,9 +718,9 @@ dependencies = [ { name = "tzlocal" }, { name = "vine" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/8f/9d/3d13596519cfa7207a6f9834f4b082554845eb3cd2684b5f8535d50c7c44/celery-5.6.2.tar.gz", hash = "sha256:4a8921c3fcf2ad76317d3b29020772103581ed2454c4c042cc55dcc43585009b", size = 1718802, upload-time = "2026-01-04T12:35:58.012Z" } +sdist = { url = "https://files.pythonhosted.org/packages/e8/b4/a1233943ab5c8ea05fb877a88a0a0622bf47444b99e4991a8045ac37ea1d/celery-5.6.3.tar.gz", hash = "sha256:177006bd2054b882e9f01be59abd8529e88879ef50d7918a7050c5a9f4e12912", size = 1742243, upload-time = "2026-03-26T12:14:51.76Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/dd/bd/9ecd619e456ae4ba73b6583cc313f26152afae13e9a82ac4fe7f8856bfd1/celery-5.6.2-py3-none-any.whl", hash = "sha256:3ffafacbe056951b629c7abcf9064c4a2366de0bdfc9fdba421b97ebb68619a5", size = 445502, upload-time = "2026-01-04T12:35:55.894Z" }, + { url = "https://files.pythonhosted.org/packages/cf/c9/6eccdda96e098f7ae843162db2d3c149c6931a24fda69fe4ab84d0027eb5/celery-5.6.3-py3-none-any.whl", hash = "sha256:0808f42f80909c4d5833202360ffafb2a4f83f4d8e23e1285d926610e9a7afa6", size = 451235, upload-time = "2026-03-26T12:14:49.491Z" }, ] [[package]] @@ -778,27 +778,27 @@ wheels = [ [[package]] name = "charset-normalizer" -version = "3.4.4" +version = "3.4.7" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/13/69/33ddede1939fdd074bce5434295f38fae7136463422fe4fd3e0e89b98062/charset_normalizer-3.4.4.tar.gz", hash = "sha256:94537985111c35f28720e43603b8e7b43a6ecfb2ce1d3058bbe955b73404e21a", size = 129418, upload-time = "2025-10-14T04:42:32.879Z" } +sdist = { url = "https://files.pythonhosted.org/packages/e7/a1/67fe25fac3c7642725500a3f6cfe5821ad557c3abb11c9d20d12c7008d3e/charset_normalizer-3.4.7.tar.gz", hash = "sha256:ae89db9e5f98a11a4bf50407d4363e7b09b31e55bc117b4f7d80aab97ba009e5", size = 144271, upload-time = "2026-04-02T09:28:39.342Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/f3/85/1637cd4af66fa687396e757dec650f28025f2a2f5a5531a3208dc0ec43f2/charset_normalizer-3.4.4-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:0a98e6759f854bd25a58a73fa88833fba3b7c491169f86ce1180c948ab3fd394", size = 208425, upload-time = "2025-10-14T04:40:53.353Z" }, - { url = "https://files.pythonhosted.org/packages/9d/6a/04130023fef2a0d9c62d0bae2649b69f7b7d8d24ea5536feef50551029df/charset_normalizer-3.4.4-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b5b290ccc2a263e8d185130284f8501e3e36c5e02750fc6b6bdeb2e9e96f1e25", size = 148162, upload-time = "2025-10-14T04:40:54.558Z" }, - { url = "https://files.pythonhosted.org/packages/78/29/62328d79aa60da22c9e0b9a66539feae06ca0f5a4171ac4f7dc285b83688/charset_normalizer-3.4.4-cp312-cp312-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:74bb723680f9f7a6234dcf67aea57e708ec1fbdf5699fb91dfd6f511b0a320ef", size = 144558, upload-time = "2025-10-14T04:40:55.677Z" }, - { url = "https://files.pythonhosted.org/packages/86/bb/b32194a4bf15b88403537c2e120b817c61cd4ecffa9b6876e941c3ee38fe/charset_normalizer-3.4.4-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:f1e34719c6ed0b92f418c7c780480b26b5d9c50349e9a9af7d76bf757530350d", size = 161497, upload-time = "2025-10-14T04:40:57.217Z" }, - { url = "https://files.pythonhosted.org/packages/19/89/a54c82b253d5b9b111dc74aca196ba5ccfcca8242d0fb64146d4d3183ff1/charset_normalizer-3.4.4-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:2437418e20515acec67d86e12bf70056a33abdacb5cb1655042f6538d6b085a8", size = 159240, upload-time = "2025-10-14T04:40:58.358Z" }, - { url = "https://files.pythonhosted.org/packages/c0/10/d20b513afe03acc89ec33948320a5544d31f21b05368436d580dec4e234d/charset_normalizer-3.4.4-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:11d694519d7f29d6cd09f6ac70028dba10f92f6cdd059096db198c283794ac86", size = 153471, upload-time = "2025-10-14T04:40:59.468Z" }, - { url = "https://files.pythonhosted.org/packages/61/fa/fbf177b55bdd727010f9c0a3c49eefa1d10f960e5f09d1d887bf93c2e698/charset_normalizer-3.4.4-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:ac1c4a689edcc530fc9d9aa11f5774b9e2f33f9a0c6a57864e90908f5208d30a", size = 150864, upload-time = "2025-10-14T04:41:00.623Z" }, - { url = "https://files.pythonhosted.org/packages/05/12/9fbc6a4d39c0198adeebbde20b619790e9236557ca59fc40e0e3cebe6f40/charset_normalizer-3.4.4-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:21d142cc6c0ec30d2efee5068ca36c128a30b0f2c53c1c07bd78cb6bc1d3be5f", size = 150647, upload-time = "2025-10-14T04:41:01.754Z" }, - { url = "https://files.pythonhosted.org/packages/ad/1f/6a9a593d52e3e8c5d2b167daf8c6b968808efb57ef4c210acb907c365bc4/charset_normalizer-3.4.4-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:5dbe56a36425d26d6cfb40ce79c314a2e4dd6211d51d6d2191c00bed34f354cc", size = 145110, upload-time = "2025-10-14T04:41:03.231Z" }, - { url = "https://files.pythonhosted.org/packages/30/42/9a52c609e72471b0fc54386dc63c3781a387bb4fe61c20231a4ebcd58bdd/charset_normalizer-3.4.4-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:5bfbb1b9acf3334612667b61bd3002196fe2a1eb4dd74d247e0f2a4d50ec9bbf", size = 162839, upload-time = "2025-10-14T04:41:04.715Z" }, - { url = "https://files.pythonhosted.org/packages/c4/5b/c0682bbf9f11597073052628ddd38344a3d673fda35a36773f7d19344b23/charset_normalizer-3.4.4-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:d055ec1e26e441f6187acf818b73564e6e6282709e9bcb5b63f5b23068356a15", size = 150667, upload-time = "2025-10-14T04:41:05.827Z" }, - { url = "https://files.pythonhosted.org/packages/e4/24/a41afeab6f990cf2daf6cb8c67419b63b48cf518e4f56022230840c9bfb2/charset_normalizer-3.4.4-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:af2d8c67d8e573d6de5bc30cdb27e9b95e49115cd9baad5ddbd1a6207aaa82a9", size = 160535, upload-time = "2025-10-14T04:41:06.938Z" }, - { url = "https://files.pythonhosted.org/packages/2a/e5/6a4ce77ed243c4a50a1fecca6aaaab419628c818a49434be428fe24c9957/charset_normalizer-3.4.4-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:780236ac706e66881f3b7f2f32dfe90507a09e67d1d454c762cf642e6e1586e0", size = 154816, upload-time = "2025-10-14T04:41:08.101Z" }, - { url = "https://files.pythonhosted.org/packages/a8/ef/89297262b8092b312d29cdb2517cb1237e51db8ecef2e9af5edbe7b683b1/charset_normalizer-3.4.4-cp312-cp312-win32.whl", hash = "sha256:5833d2c39d8896e4e19b689ffc198f08ea58116bee26dea51e362ecc7cd3ed26", size = 99694, upload-time = "2025-10-14T04:41:09.23Z" }, - { url = "https://files.pythonhosted.org/packages/3d/2d/1e5ed9dd3b3803994c155cd9aacb60c82c331bad84daf75bcb9c91b3295e/charset_normalizer-3.4.4-cp312-cp312-win_amd64.whl", hash = "sha256:a79cfe37875f822425b89a82333404539ae63dbdddf97f84dcbc3d339aae9525", size = 107131, upload-time = "2025-10-14T04:41:10.467Z" }, - { url = "https://files.pythonhosted.org/packages/d0/d9/0ed4c7098a861482a7b6a95603edce4c0d9db2311af23da1fb2b75ec26fc/charset_normalizer-3.4.4-cp312-cp312-win_arm64.whl", hash = "sha256:376bec83a63b8021bb5c8ea75e21c4ccb86e7e45ca4eb81146091b56599b80c3", size = 100390, upload-time = "2025-10-14T04:41:11.915Z" }, - { url = "https://files.pythonhosted.org/packages/0a/4c/925909008ed5a988ccbb72dcc897407e5d6d3bd72410d69e051fc0c14647/charset_normalizer-3.4.4-py3-none-any.whl", hash = "sha256:7a32c560861a02ff789ad905a2fe94e3f840803362c84fecf1851cb4cf3dc37f", size = 53402, upload-time = "2025-10-14T04:42:31.76Z" }, + { url = "https://files.pythonhosted.org/packages/0c/eb/4fc8d0a7110eb5fc9cc161723a34a8a6c200ce3b4fbf681bc86feee22308/charset_normalizer-3.4.7-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:eca9705049ad3c7345d574e3510665cb2cf844c2f2dcfe675332677f081cbd46", size = 311328, upload-time = "2026-04-02T09:26:24.331Z" }, + { url = "https://files.pythonhosted.org/packages/f8/e3/0fadc706008ac9d7b9b5be6dc767c05f9d3e5df51744ce4cc9605de7b9f4/charset_normalizer-3.4.7-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6178f72c5508bfc5fd446a5905e698c6212932f25bcdd4b47a757a50605a90e2", size = 208061, upload-time = "2026-04-02T09:26:25.568Z" }, + { url = "https://files.pythonhosted.org/packages/42/f0/3dd1045c47f4a4604df85ec18ad093912ae1344ac706993aff91d38773a2/charset_normalizer-3.4.7-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:e1421b502d83040e6d7fb2fb18dff63957f720da3d77b2fbd3187ceb63755d7b", size = 229031, upload-time = "2026-04-02T09:26:26.865Z" }, + { url = "https://files.pythonhosted.org/packages/dc/67/675a46eb016118a2fbde5a277a5d15f4f69d5f3f5f338e5ee2f8948fcf43/charset_normalizer-3.4.7-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:edac0f1ab77644605be2cbba52e6b7f630731fc42b34cb0f634be1a6eface56a", size = 225239, upload-time = "2026-04-02T09:26:28.044Z" }, + { url = "https://files.pythonhosted.org/packages/4b/f8/d0118a2f5f23b02cd166fa385c60f9b0d4f9194f574e2b31cef350ad7223/charset_normalizer-3.4.7-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:5649fd1c7bade02f320a462fdefd0b4bd3ce036065836d4f42e0de958038e116", size = 216589, upload-time = "2026-04-02T09:26:29.239Z" }, + { url = "https://files.pythonhosted.org/packages/b1/f1/6d2b0b261b6c4ceef0fcb0d17a01cc5bc53586c2d4796fa04b5c540bc13d/charset_normalizer-3.4.7-cp312-cp312-manylinux_2_31_armv7l.whl", hash = "sha256:203104ed3e428044fd943bc4bf45fa73c0730391f9621e37fe39ecf477b128cb", size = 202733, upload-time = "2026-04-02T09:26:30.5Z" }, + { url = "https://files.pythonhosted.org/packages/6f/c0/7b1f943f7e87cc3db9626ba17807d042c38645f0a1d4415c7a14afb5591f/charset_normalizer-3.4.7-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:298930cec56029e05497a76988377cbd7457ba864beeea92ad7e844fe74cd1f1", size = 212652, upload-time = "2026-04-02T09:26:31.709Z" }, + { url = "https://files.pythonhosted.org/packages/38/dd/5a9ab159fe45c6e72079398f277b7d2b523e7f716acc489726115a910097/charset_normalizer-3.4.7-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:708838739abf24b2ceb208d0e22403dd018faeef86ddac04319a62ae884c4f15", size = 211229, upload-time = "2026-04-02T09:26:33.282Z" }, + { url = "https://files.pythonhosted.org/packages/d5/ff/531a1cad5ca855d1c1a8b69cb71abfd6d85c0291580146fda7c82857caa1/charset_normalizer-3.4.7-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:0f7eb884681e3938906ed0434f20c63046eacd0111c4ba96f27b76084cd679f5", size = 203552, upload-time = "2026-04-02T09:26:34.845Z" }, + { url = "https://files.pythonhosted.org/packages/c1/4c/a5fb52d528a8ca41f7598cb619409ece30a169fbdf9cdce592e53b46c3a6/charset_normalizer-3.4.7-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:4dc1e73c36828f982bfe79fadf5919923f8a6f4df2860804db9a98c48824ce8d", size = 230806, upload-time = "2026-04-02T09:26:36.152Z" }, + { url = "https://files.pythonhosted.org/packages/59/7a/071feed8124111a32b316b33ae4de83d36923039ef8cf48120266844285b/charset_normalizer-3.4.7-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:aed52fea0513bac0ccde438c188c8a471c4e0f457c2dd20cdbf6ea7a450046c7", size = 212316, upload-time = "2026-04-02T09:26:37.672Z" }, + { url = "https://files.pythonhosted.org/packages/fd/35/f7dba3994312d7ba508e041eaac39a36b120f32d4c8662b8814dab876431/charset_normalizer-3.4.7-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:fea24543955a6a729c45a73fe90e08c743f0b3334bbf3201e6c4bc1b0c7fa464", size = 227274, upload-time = "2026-04-02T09:26:38.93Z" }, + { url = "https://files.pythonhosted.org/packages/8a/2d/a572df5c9204ab7688ec1edc895a73ebded3b023bb07364710b05dd1c9be/charset_normalizer-3.4.7-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:bb6d88045545b26da47aa879dd4a89a71d1dce0f0e549b1abcb31dfe4a8eac49", size = 218468, upload-time = "2026-04-02T09:26:40.17Z" }, + { url = "https://files.pythonhosted.org/packages/86/eb/890922a8b03a568ca2f336c36585a4713c55d4d67bf0f0c78924be6315ca/charset_normalizer-3.4.7-cp312-cp312-win32.whl", hash = "sha256:2257141f39fe65a3fdf38aeccae4b953e5f3b3324f4ff0daf9f15b8518666a2c", size = 148460, upload-time = "2026-04-02T09:26:41.416Z" }, + { url = "https://files.pythonhosted.org/packages/35/d9/0e7dffa06c5ab081f75b1b786f0aefc88365825dfcd0ac544bdb7b2b6853/charset_normalizer-3.4.7-cp312-cp312-win_amd64.whl", hash = "sha256:5ed6ab538499c8644b8a3e18debabcd7ce684f3fa91cf867521a7a0279cab2d6", size = 159330, upload-time = "2026-04-02T09:26:42.554Z" }, + { url = "https://files.pythonhosted.org/packages/9e/5d/481bcc2a7c88ea6b0878c299547843b2521ccbc40980cb406267088bc701/charset_normalizer-3.4.7-cp312-cp312-win_arm64.whl", hash = "sha256:56be790f86bfb2c98fb742ce566dfb4816e5a83384616ab59c49e0604d49c51d", size = 147828, upload-time = "2026-04-02T09:26:44.075Z" }, + { url = "https://files.pythonhosted.org/packages/db/8f/61959034484a4a7c527811f4721e75d02d653a35afb0b6054474d8185d4c/charset_normalizer-3.4.7-py3-none-any.whl", hash = "sha256:3dce51d0f5e7951f8bb4900c257dad282f49190fdbebecd4ba99bcc41fef404d", size = 61958, upload-time = "2026-04-02T09:28:37.794Z" }, ] [[package]] @@ -950,7 +950,7 @@ wheels = [ [[package]] name = "clickzetta-connector-python" -version = "0.8.106" +version = "0.8.104" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "future" }, @@ -964,7 +964,7 @@ dependencies = [ { name = "urllib3" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/23/38/749c708619f402d4d582dfa73fbeb64ade77b1f250a93bd064d2a1aa3776/clickzetta_connector_python-0.8.106-py3-none-any.whl", hash = "sha256:120d6700051d97609dbd6655c002ab3bc260b7c8e67d39dfc7191e749563f7b4", size = 78121, upload-time = "2025-10-29T02:38:15.014Z" }, + { url = "https://files.pythonhosted.org/packages/8f/94/c7eee2224bdab39d16dfe5bb7687f5525c7ed345b7fe8812e18a2d9a6335/clickzetta_connector_python-0.8.104-py3-none-any.whl", hash = "sha256:ae3e466d990677f96c769ec1c29318237df80c80fe9c1e21ba1eaf42bdef0207", size = 79382, upload-time = "2025-09-10T08:46:39.731Z" }, ] [[package]] @@ -1115,15 +1115,14 @@ sdist = { url = "https://files.pythonhosted.org/packages/6b/b0/e595ce2a2527e169c [[package]] name = "croniter" -version = "6.0.0" +version = "6.2.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "python-dateutil" }, - { name = "pytz" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/ad/2f/44d1ae153a0e27be56be43465e5cb39b9650c781e001e7864389deb25090/croniter-6.0.0.tar.gz", hash = "sha256:37c504b313956114a983ece2c2b07790b1f1094fe9d81cc94739214748255577", size = 64481, upload-time = "2024-12-17T17:17:47.32Z" } +sdist = { url = "https://files.pythonhosted.org/packages/df/de/5832661ed55107b8a09af3f0a2e71e0957226a59eb1dcf0a445cce6daf20/croniter-6.2.2.tar.gz", hash = "sha256:ba60832a5ec8e12e51b8691c3309a113d1cf6526bdf1a48150ce8ec7a532d0ab", size = 113762, upload-time = "2026-03-15T08:43:48.112Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/07/4b/290b4c3efd6417a8b0c284896de19b1d5855e6dbdb97d2a35e68fa42de85/croniter-6.0.0-py2.py3-none-any.whl", hash = "sha256:2f878c3856f17896979b2a4379ba1f09c83e374931ea15cc835c5dd2eee9b368", size = 25468, upload-time = "2024-12-17T17:17:45.359Z" }, + { url = "https://files.pythonhosted.org/packages/d0/39/783980e78cb92c2d7bdb1fc7dbc86e94ccc6d58224d76a7f1f51b6c51e30/croniter-6.2.2-py3-none-any.whl", hash = "sha256:a5d17b1060974d36251ea4faf388233eca8acf0d09cbd92d35f4c4ac8f279960", size = 45422, upload-time = "2026-03-15T08:43:46.626Z" }, ] [[package]] @@ -1470,74 +1469,74 @@ vdb = [ [package.metadata] requires-dist = [ - { name = "aliyun-log-python-sdk", specifier = "~=0.9.37" }, - { name = "apscheduler", specifier = ">=3.11.0" }, + { name = "aliyun-log-python-sdk", specifier = "~=0.9.44" }, + { name = "apscheduler", specifier = ">=3.11.2" }, { name = "arize-phoenix-otel", specifier = "~=0.15.0" }, { name = "azure-identity", specifier = "==1.25.3" }, { name = "beautifulsoup4", specifier = "==4.14.3" }, { name = "bleach", specifier = "~=6.3.0" }, - { name = "boto3", specifier = "==1.42.83" }, + { name = "boto3", specifier = "==1.42.88" }, { name = "bs4", specifier = "~=0.0.1" }, - { name = "cachetools", specifier = "~=5.3.0" }, - { name = "celery", specifier = "~=5.6.2" }, - { name = "charset-normalizer", specifier = ">=3.4.4" }, - { name = "croniter", specifier = ">=6.0.0" }, + { name = "cachetools", specifier = "~=7.0.5" }, + { name = "celery", specifier = "~=5.6.3" }, + { name = "charset-normalizer", specifier = ">=3.4.7" }, + { name = "croniter", specifier = ">=6.2.2" }, { name = "fastopenapi", extras = ["flask"], specifier = ">=0.7.0" }, - { name = "flask", specifier = "~=3.1.2" }, - { name = "flask-compress", specifier = ">=1.17,<1.25" }, - { name = "flask-cors", specifier = "~=6.0.0" }, + { name = "flask", specifier = "~=3.1.3" }, + { name = "flask-compress", specifier = ">=1.24,<1.25" }, + { name = "flask-cors", specifier = "~=6.0.2" }, { name = "flask-login", specifier = "~=0.6.3" }, { name = "flask-migrate", specifier = "~=4.1.0" }, { name = "flask-orjson", specifier = "~=2.0.0" }, { name = "flask-restx", specifier = "~=1.3.2" }, { name = "flask-sqlalchemy", specifier = "~=3.1.1" }, - { name = "gevent", specifier = "~=25.9.1" }, + { name = "gevent", specifier = "~=26.4.0" }, { name = "gmpy2", specifier = "~=2.3.0" }, - { name = "google-api-core", specifier = ">=2.19.1" }, - { name = "google-api-python-client", specifier = "==2.193.0" }, - { name = "google-auth", specifier = ">=2.47.0" }, + { name = "google-api-core", specifier = ">=2.30.3" }, + { name = "google-api-python-client", specifier = "==2.194.0" }, + { name = "google-auth", specifier = ">=2.49.2" }, { name = "google-auth-httplib2", specifier = "==0.3.1" }, - { name = "google-cloud-aiplatform", specifier = ">=1.123.0" }, - { name = "googleapis-common-protos", specifier = ">=1.65.0" }, + { name = "google-cloud-aiplatform", specifier = ">=1.147.0" }, + { name = "googleapis-common-protos", specifier = ">=1.74.0" }, { name = "graphon", specifier = ">=0.1.2" }, { name = "gunicorn", specifier = "~=25.3.0" }, - { name = "httpx", extras = ["socks"], specifier = "~=0.28.0" }, + { name = "httpx", extras = ["socks"], specifier = "~=0.28.1" }, { name = "httpx-sse", specifier = "~=0.4.0" }, { name = "jieba", specifier = "==0.42.1" }, - { name = "json-repair", specifier = ">=0.55.1" }, - { name = "langfuse", specifier = ">=3.0.0,<5.0.0" }, - { name = "langsmith", specifier = "~=0.7.16" }, + { name = "json-repair", specifier = ">=0.59.2" }, + { name = "langfuse", specifier = ">=4.2.0,<5.0.0" }, + { name = "langsmith", specifier = "~=0.7.30" }, { name = "litellm", specifier = "==1.83.0" }, { name = "markdown", specifier = "~=3.10.2" }, - { name = "mlflow-skinny", specifier = ">=3.0.0" }, - { name = "numpy", specifier = "~=1.26.4" }, + { name = "mlflow-skinny", specifier = ">=3.11.1" }, + { name = "numpy", specifier = "~=2.4.4" }, { name = "openpyxl", specifier = "~=3.1.5" }, - { name = "opentelemetry-api", specifier = "==1.40.0" }, - { name = "opentelemetry-distro", specifier = "==0.61b0" }, - { name = "opentelemetry-exporter-otlp", specifier = "==1.40.0" }, - { name = "opentelemetry-exporter-otlp-proto-common", specifier = "==1.40.0" }, - { name = "opentelemetry-exporter-otlp-proto-grpc", specifier = "==1.40.0" }, - { name = "opentelemetry-exporter-otlp-proto-http", specifier = "==1.40.0" }, - { name = "opentelemetry-instrumentation", specifier = "==0.61b0" }, - { name = "opentelemetry-instrumentation-celery", specifier = "==0.61b0" }, - { name = "opentelemetry-instrumentation-flask", specifier = "==0.61b0" }, - { name = "opentelemetry-instrumentation-httpx", specifier = "==0.61b0" }, - { name = "opentelemetry-instrumentation-redis", specifier = "==0.61b0" }, - { name = "opentelemetry-instrumentation-sqlalchemy", specifier = "==0.61b0" }, - { name = "opentelemetry-propagator-b3", specifier = "==1.40.0" }, - { name = "opentelemetry-proto", specifier = "==1.40.0" }, - { name = "opentelemetry-sdk", specifier = "==1.40.0" }, - { name = "opentelemetry-semantic-conventions", specifier = "==0.61b0" }, - { name = "opentelemetry-util-http", specifier = "==0.61b0" }, - { name = "opik", specifier = "~=1.10.37" }, - { name = "packaging", specifier = "~=23.2" }, - { name = "pandas", extras = ["excel", "output-formatting", "performance"], specifier = "~=3.0.1" }, + { name = "opentelemetry-api", specifier = "==1.41.0" }, + { name = "opentelemetry-distro", specifier = "==0.62b0" }, + { name = "opentelemetry-exporter-otlp", specifier = "==1.41.0" }, + { name = "opentelemetry-exporter-otlp-proto-common", specifier = "==1.41.0" }, + { name = "opentelemetry-exporter-otlp-proto-grpc", specifier = "==1.41.0" }, + { name = "opentelemetry-exporter-otlp-proto-http", specifier = "==1.41.0" }, + { name = "opentelemetry-instrumentation", specifier = "==0.62b0" }, + { name = "opentelemetry-instrumentation-celery", specifier = "==0.62b0" }, + { name = "opentelemetry-instrumentation-flask", specifier = "==0.62b0" }, + { name = "opentelemetry-instrumentation-httpx", specifier = "==0.62b0" }, + { name = "opentelemetry-instrumentation-redis", specifier = "==0.62b0" }, + { name = "opentelemetry-instrumentation-sqlalchemy", specifier = "==0.62b0" }, + { name = "opentelemetry-propagator-b3", specifier = "==1.41.0" }, + { name = "opentelemetry-proto", specifier = "==1.41.0" }, + { name = "opentelemetry-sdk", specifier = "==1.41.0" }, + { name = "opentelemetry-semantic-conventions", specifier = "==0.62b0" }, + { name = "opentelemetry-util-http", specifier = "==0.62b0" }, + { name = "opik", specifier = "~=1.11.2" }, + { name = "packaging", specifier = "~=26.0" }, + { name = "pandas", extras = ["excel", "output-formatting", "performance"], specifier = "~=3.0.2" }, { name = "psycogreen", specifier = "~=1.0.2" }, - { name = "psycopg2-binary", specifier = "~=2.9.6" }, + { name = "psycopg2-binary", specifier = "~=2.9.11" }, { name = "pycryptodome", specifier = "==3.23.0" }, { name = "pydantic", specifier = "~=2.12.5" }, { name = "pydantic-settings", specifier = "~=2.13.1" }, - { name = "pyjwt", specifier = "~=2.12.0" }, + { name = "pyjwt", specifier = "~=2.12.1" }, { name = "pypandoc", specifier = "~=1.13" }, { name = "pypdfium2", specifier = "==5.6.0" }, { name = "python-docx", specifier = "~=1.2.0" }, @@ -1545,59 +1544,59 @@ requires-dist = [ { name = "pyyaml", specifier = "~=6.0.1" }, { name = "readabilipy", specifier = "~=0.3.0" }, { name = "redis", extras = ["hiredis"], specifier = "~=7.4.0" }, - { name = "resend", specifier = "~=2.26.0" }, - { name = "sendgrid", specifier = "~=6.12.3" }, - { name = "sentry-sdk", extras = ["flask"], specifier = "~=2.55.0" }, - { name = "sqlalchemy", specifier = "~=2.0.29" }, + { name = "resend", specifier = "~=2.27.0" }, + { name = "sendgrid", specifier = "~=6.12.5" }, + { name = "sentry-sdk", extras = ["flask"], specifier = "~=2.57.0" }, + { name = "sqlalchemy", specifier = "~=2.0.49" }, { name = "sseclient-py", specifier = "~=1.9.0" }, { name = "starlette", specifier = "==1.0.0" }, { name = "tiktoken", specifier = "~=0.12.0" }, { name = "transformers", specifier = "~=5.3.0" }, { name = "unstructured", extras = ["docx", "epub", "md", "ppt", "pptx"], specifier = "~=0.21.5" }, { name = "weave", specifier = ">=0.52.16" }, - { name = "weaviate-client", specifier = "==4.20.4" }, + { name = "weaviate-client", specifier = "==4.20.5" }, { name = "yarl", specifier = "~=1.23.0" }, ] [package.metadata.requires-dev] dev = [ { name = "basedpyright", specifier = "~=1.39.0" }, - { name = "boto3-stubs", specifier = ">=1.38.20" }, + { name = "boto3-stubs", specifier = ">=1.42.88" }, { name = "celery-types", specifier = ">=0.23.0" }, { name = "coverage", specifier = "~=7.13.4" }, { name = "dotenv-linter", specifier = "~=0.7.0" }, - { name = "faker", specifier = "~=40.12.0" }, - { name = "hypothesis", specifier = ">=6.131.15" }, + { name = "faker", specifier = "~=40.13.0" }, + { name = "hypothesis", specifier = ">=6.151.12" }, { name = "import-linter", specifier = ">=2.3" }, { name = "lxml-stubs", specifier = "~=0.5.1" }, { name = "mypy", specifier = "~=1.20.0" }, { name = "pandas-stubs", specifier = "~=3.0.0" }, { name = "pyrefly", specifier = ">=0.60.0" }, - { name = "pytest", specifier = "~=9.0.2" }, + { name = "pytest", specifier = "~=9.0.3" }, { name = "pytest-benchmark", specifier = "~=5.2.3" }, { name = "pytest-cov", specifier = "~=7.1.0" }, { name = "pytest-env", specifier = "~=1.6.0" }, { name = "pytest-mock", specifier = "~=3.15.1" }, { name = "pytest-timeout", specifier = ">=2.4.0" }, { name = "pytest-xdist", specifier = ">=3.8.0" }, - { name = "ruff", specifier = "~=0.15.5" }, + { name = "ruff", specifier = "~=0.15.10" }, { name = "scipy-stubs", specifier = ">=1.15.3.0" }, { name = "sseclient-py", specifier = ">=1.8.0" }, - { name = "testcontainers", specifier = "~=4.14.1" }, + { name = "testcontainers", specifier = "~=4.14.2" }, { name = "types-aiofiles", specifier = "~=25.1.0" }, { name = "types-beautifulsoup4", specifier = "~=4.12.0" }, { name = "types-cachetools", specifier = "~=6.2.0" }, - { name = "types-cffi", specifier = ">=1.17.0" }, + { name = "types-cffi", specifier = ">=2.0.0.20260408" }, { name = "types-colorama", specifier = "~=0.4.15" }, { name = "types-defusedxml", specifier = "~=0.7.0" }, { name = "types-deprecated", specifier = "~=1.3.1" }, { name = "types-docutils", specifier = "~=0.22.3" }, { name = "types-flask-cors", specifier = "~=6.0.0" }, { name = "types-flask-migrate", specifier = "~=4.1.0" }, - { name = "types-gevent", specifier = "~=25.9.0" }, - { name = "types-greenlet", specifier = "~=3.3.0" }, + { name = "types-gevent", specifier = "~=26.4.0" }, + { name = "types-greenlet", specifier = "~=3.4.0" }, { name = "types-html5lib", specifier = "~=1.1.11" }, - { name = "types-jmespath", specifier = ">=1.0.2.20240106" }, + { name = "types-jmespath", specifier = ">=1.1.0.20260408" }, { name = "types-markdown", specifier = "~=3.10.2" }, { name = "types-oauthlib", specifier = "~=3.3.0" }, { name = "types-objgraph", specifier = "~=3.6.0" }, @@ -1611,25 +1610,25 @@ dev = [ { name = "types-pymysql", specifier = "~=1.1.0" }, { name = "types-pyopenssl", specifier = ">=24.1.0" }, { name = "types-python-dateutil", specifier = "~=2.9.0" }, - { name = "types-python-http-client", specifier = ">=3.3.7.20240910" }, + { name = "types-python-http-client", specifier = ">=3.3.7.20260408" }, { name = "types-pywin32", specifier = "~=311.0.0" }, { name = "types-pyyaml", specifier = "~=6.0.12" }, { name = "types-redis", specifier = ">=4.6.0.20241004" }, { name = "types-regex", specifier = "~=2026.4.4" }, - { name = "types-setuptools", specifier = ">=80.9.0" }, + { name = "types-setuptools", specifier = ">=82.0.0.20260408" }, { name = "types-shapely", specifier = "~=2.1.0" }, - { name = "types-simplejson", specifier = ">=3.20.0" }, - { name = "types-six", specifier = ">=1.17.0" }, - { name = "types-tensorflow", specifier = ">=2.18.0" }, - { name = "types-tqdm", specifier = ">=4.67.0" }, + { name = "types-simplejson", specifier = ">=3.20.0.20260408" }, + { name = "types-six", specifier = ">=1.17.0.20260408" }, + { name = "types-tensorflow", specifier = ">=2.18.0.20260408" }, + { name = "types-tqdm", specifier = ">=4.67.3.20260408" }, { name = "types-ujson", specifier = ">=5.10.0" }, ] storage = [ { name = "azure-storage-blob", specifier = "==12.28.0" }, - { name = "bce-python-sdk", specifier = "~=0.9.23" }, + { name = "bce-python-sdk", specifier = "~=0.9.69" }, { name = "cos-python-sdk-v5", specifier = "==1.9.41" }, { name = "esdk-obs-python", specifier = "==3.26.2" }, - { name = "google-cloud-storage", specifier = ">=3.0.0" }, + { name = "google-cloud-storage", specifier = ">=3.10.1" }, { name = "opendal", specifier = "~=0.46.0" }, { name = "oss2", specifier = "==2.19.1" }, { name = "supabase", specifier = "~=2.18.1" }, @@ -1647,24 +1646,24 @@ vdb = [ { name = "clickzetta-connector-python", specifier = ">=0.8.102" }, { name = "couchbase", specifier = "~=4.6.0" }, { name = "elasticsearch", specifier = "==8.14.0" }, - { name = "holo-search-sdk", specifier = ">=0.4.1" }, + { name = "holo-search-sdk", specifier = ">=0.4.2" }, { name = "intersystems-irispython", specifier = ">=5.1.0" }, { name = "mo-vector", specifier = "~=0.1.13" }, { name = "mysql-connector-python", specifier = ">=9.3.0" }, { name = "opensearch-py", specifier = "==3.1.0" }, { name = "oracledb", specifier = "==3.4.2" }, - { name = "pgvecto-rs", extras = ["sqlalchemy"], specifier = "~=0.2.1" }, + { name = "pgvecto-rs", extras = ["sqlalchemy"], specifier = "~=0.2.2" }, { name = "pgvector", specifier = "==0.4.2" }, - { name = "pymilvus", specifier = "~=2.6.10" }, + { name = "pymilvus", specifier = "~=2.6.12" }, { name = "pymochow", specifier = "==2.4.0" }, { name = "pyobvector", specifier = "~=0.2.17" }, { name = "qdrant-client", specifier = "==1.9.0" }, - { name = "tablestore", specifier = "==6.4.3" }, + { name = "tablestore", specifier = "==6.4.4" }, { name = "tcvectordb", specifier = "~=2.1.0" }, { name = "tidb-vector", specifier = "==0.0.15" }, { name = "upstash-vector", specifier = "==0.8.0" }, { name = "volcengine-compat", specifier = "~=1.0.0" }, - { name = "weaviate-client", specifier = "==4.20.4" }, + { name = "weaviate-client", specifier = "==4.20.5" }, { name = "xinference-client", specifier = "~=2.4.0" }, ] @@ -1734,18 +1733,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b0/0d/9feae160378a3553fa9a339b0e9c1a048e147a4127210e286ef18b730f03/durationpy-0.10-py3-none-any.whl", hash = "sha256:3b41e1b601234296b4fb368338fdcd3e13e0b4fb5b67345948f4f2bf9868b286", size = 3922, upload-time = "2025-05-17T13:52:36.463Z" }, ] -[[package]] -name = "ecdsa" -version = "0.19.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "six" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/c0/1f/924e3caae75f471eae4b26bd13b698f6af2c44279f67af317439c2f4c46a/ecdsa-0.19.1.tar.gz", hash = "sha256:478cba7b62555866fcb3bb3fe985e06decbdb68ef55713c4e5ab98c57d508e61", size = 201793, upload-time = "2025-03-13T11:52:43.25Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/cb/a3/460c57f094a4a165c84a1341c373b0a4f5ec6ac244b998d5021aade89b77/ecdsa-0.19.1-py2.py3-none-any.whl", hash = "sha256:30638e27cf77b7e15c4c4cc1973720149e1033827cfd00661ca5c8cc0cdb24c3", size = 150607, upload-time = "2025-03-13T11:52:41.757Z" }, -] - [[package]] name = "elastic-transport" version = "8.17.1" @@ -1828,14 +1815,14 @@ wheels = [ [[package]] name = "faker" -version = "40.12.0" +version = "40.13.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "tzdata", marker = "sys_platform == 'win32'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/66/c1/f8224fe97fea2f98d455c22438c1b09b10e14ef2cb95ae4f7cec9aa59659/faker-40.12.0.tar.gz", hash = "sha256:58b5a9054c367bd5fb2e948634105364cc570e78a98a8e5161a74691c45f158f", size = 1962003, upload-time = "2026-03-30T18:00:56.596Z" } +sdist = { url = "https://files.pythonhosted.org/packages/89/95/4822ffe94723553789aef783104f4f18fc20d7c4c68e1bbd633e11d09758/faker-40.13.0.tar.gz", hash = "sha256:a0751c84c3abac17327d7bb4c98e8afe70ebf7821e01dd7d0b15cd8856415525", size = 1962043, upload-time = "2026-04-06T16:44:55.68Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/2b/5c/39452a6b6aa76ffa518fa7308e1975b37e9ba77caa6172a69d61e7180221/faker-40.12.0-py3-none-any.whl", hash = "sha256:6238a4058a8b581892e3d78fe5fdfa7568739e1c8283e4ede83f1dde0bfc1a3b", size = 1994601, upload-time = "2026-03-30T18:00:54.804Z" }, + { url = "https://files.pythonhosted.org/packages/da/8a/708103325edff16a0b0e004de0d37db8ba216a32713948c64d71f6d4a4c2/faker-40.13.0-py3-none-any.whl", hash = "sha256:c1298fd0d819b3688fb5fd358c4ba8f56c7c8c740b411fd3dbd8e30bf2c05019", size = 1994597, upload-time = "2026-04-06T16:44:53.698Z" }, ] [[package]] @@ -2086,7 +2073,7 @@ wheels = [ [[package]] name = "gevent" -version = "25.9.1" +version = "26.4.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "cffi", marker = "platform_python_implementation == 'CPython' and sys_platform == 'win32'" }, @@ -2094,16 +2081,16 @@ dependencies = [ { name = "zope-event" }, { name = "zope-interface" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/9e/48/b3ef2673ffb940f980966694e40d6d32560f3ffa284ecaeb5ea3a90a6d3f/gevent-25.9.1.tar.gz", hash = "sha256:adf9cd552de44a4e6754c51ff2e78d9193b7fa6eab123db9578a210e657235dd", size = 5059025, upload-time = "2025-09-17T16:15:34.528Z" } +sdist = { url = "https://files.pythonhosted.org/packages/20/27/1062fa31333dc3428a1f5f33cd6598b0552165ba679ca3ba116de42c9e8e/gevent-26.4.0.tar.gz", hash = "sha256:288d03addfccf0d1c67268358b6759b04392bf3bc35d26f3d9a45c82899c292d", size = 6242440, upload-time = "2026-04-09T12:08:19.482Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/f7/49/e55930ba5259629eb28ac7ee1abbca971996a9165f902f0249b561602f24/gevent-25.9.1-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:46b188248c84ffdec18a686fcac5dbb32365d76912e14fda350db5dc0bfd4f86", size = 2955991, upload-time = "2025-09-17T14:52:30.568Z" }, - { url = "https://files.pythonhosted.org/packages/aa/88/63dc9e903980e1da1e16541ec5c70f2b224ec0a8e34088cb42794f1c7f52/gevent-25.9.1-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:f2b54ea3ca6f0c763281cd3f96010ac7e98c2e267feb1221b5a26e2ca0b9a692", size = 1808503, upload-time = "2025-09-17T15:41:25.59Z" }, - { url = "https://files.pythonhosted.org/packages/7a/8d/7236c3a8f6ef7e94c22e658397009596fa90f24c7d19da11ad7ab3a9248e/gevent-25.9.1-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:7a834804ac00ed8a92a69d3826342c677be651b1c3cd66cc35df8bc711057aa2", size = 1890001, upload-time = "2025-09-17T15:49:01.227Z" }, - { url = "https://files.pythonhosted.org/packages/4f/63/0d7f38c4a2085ecce26b50492fc6161aa67250d381e26d6a7322c309b00f/gevent-25.9.1-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:323a27192ec4da6b22a9e51c3d9d896ff20bc53fdc9e45e56eaab76d1c39dd74", size = 1855335, upload-time = "2025-09-17T15:49:20.582Z" }, - { url = "https://files.pythonhosted.org/packages/95/18/da5211dfc54c7a57e7432fd9a6ffeae1ce36fe5a313fa782b1c96529ea3d/gevent-25.9.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:6ea78b39a2c51d47ff0f130f4c755a9a4bbb2dd9721149420ad4712743911a51", size = 2109046, upload-time = "2025-09-17T15:15:13.817Z" }, - { url = "https://files.pythonhosted.org/packages/a6/5a/7bb5ec8e43a2c6444853c4a9f955f3e72f479d7c24ea86c95fb264a2de65/gevent-25.9.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:dc45cd3e1cc07514a419960af932a62eb8515552ed004e56755e4bf20bad30c5", size = 1827099, upload-time = "2025-09-17T15:52:41.384Z" }, - { url = "https://files.pythonhosted.org/packages/ca/d4/b63a0a60635470d7d986ef19897e893c15326dd69e8fb342c76a4f07fe9e/gevent-25.9.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:34e01e50c71eaf67e92c186ee0196a039d6e4f4b35670396baed4a2d8f1b347f", size = 2172623, upload-time = "2025-09-17T15:24:12.03Z" }, - { url = "https://files.pythonhosted.org/packages/d5/98/caf06d5d22a7c129c1fb2fc1477306902a2c8ddfd399cd26bbbd4caf2141/gevent-25.9.1-cp312-cp312-win_amd64.whl", hash = "sha256:4acd6bcd5feabf22c7c5174bd3b9535ee9f088d2bbce789f740ad8d6554b18f3", size = 1682837, upload-time = "2025-09-17T19:48:47.318Z" }, + { url = "https://files.pythonhosted.org/packages/3d/16/131d3874f50974b355c90a061a12d3fe2292cde0f875a1fa3d8b224f1251/gevent-26.4.0-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:318a0a73f664113e8d86d0cb0e328e7650e2d7d9c2e045418ab6fb1285831ad3", size = 2928699, upload-time = "2026-04-08T21:25:36.215Z" }, + { url = "https://files.pythonhosted.org/packages/ea/8b/199e59b303adaff7f7365def9ab569c7ecd863363c974548bce3ddc2c89d/gevent-26.4.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:ce7aa033a3f68beb6732d1450a80c1af29e63e0c2d01abad7918cf2507f72fa6", size = 1783821, upload-time = "2026-04-08T22:23:18.73Z" }, + { url = "https://files.pythonhosted.org/packages/e2/2d/b8249c9bd3f386191311c3a9bec4068e192a3f9df2fad92a71a15265ba15/gevent-26.4.0-cp312-cp312-manylinux_2_28_ppc64le.whl", hash = "sha256:a1b897c952baefd72232efaeb3bdb1ca2fa7ae94cbfe68ac21201b03e843190a", size = 1879424, upload-time = "2026-04-08T22:27:10.561Z" }, + { url = "https://files.pythonhosted.org/packages/ef/89/59216985c1f2c11f2f28bbc88e583588ad44cdde823c530ad4e307be6612/gevent-26.4.0-cp312-cp312-manylinux_2_28_s390x.whl", hash = "sha256:7eef2ea508ce41795e20587a5fc868ae4919543097c81a40fbdfd65bc479f54f", size = 1830575, upload-time = "2026-04-08T22:34:37.093Z" }, + { url = "https://files.pythonhosted.org/packages/ee/a9/2d67d2b0aa0ca9d7bb7fe73c3bbb97b3695cb15c338a6ea7734f58da9add/gevent-26.4.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:f7e12fdd28cc9f39a463d8df5172d698c64a8ed385a21d98e7092fd8308a139a", size = 2113898, upload-time = "2026-04-08T21:54:14.9Z" }, + { url = "https://files.pythonhosted.org/packages/95/a3/457d58d9b3e7da17c8456d841c37a32af8d231a1d71237ad201b19129317/gevent-26.4.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:d48e3ee13d7678c24c22f19d441ad6bc220a79f23662d03ff36fae0d62efdb59", size = 1795890, upload-time = "2026-04-08T22:26:53.252Z" }, + { url = "https://files.pythonhosted.org/packages/a7/cc/cbe78f2626643b20275aaa41cd2cc45ba75056e3665bde36bc190af3cae0/gevent-26.4.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:c58c8e034f94329be4dc0979fba3301005a433dbab42cea0b2c33fd736946872", size = 2139791, upload-time = "2026-04-08T22:00:02.375Z" }, + { url = "https://files.pythonhosted.org/packages/f6/df/7875e08b06a95f4577b71708ec470d029fadf873a66eb813a2861d79dfb5/gevent-26.4.0-cp312-cp312-win_amd64.whl", hash = "sha256:1c737e6ac6ce1398df0e3f41c58d982e397c993cbe73ac05b7edbe39e128c9cb", size = 1680530, upload-time = "2026-04-08T23:15:38.714Z" }, ] [[package]] @@ -2160,7 +2147,7 @@ wheels = [ [[package]] name = "google-api-core" -version = "2.30.2" +version = "2.30.3" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "google-auth" }, @@ -2169,9 +2156,9 @@ dependencies = [ { name = "protobuf" }, { name = "requests" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/1a/2e/83ca41eb400eb228f9279ec14ed66f6475218b59af4c6daec2d5a509fe83/google_api_core-2.30.2.tar.gz", hash = "sha256:9a8113e1a88bdc09a7ff629707f2214d98d61c7f6ceb0ea38c42a095d02dc0f9", size = 176862, upload-time = "2026-04-02T21:23:44.876Z" } +sdist = { url = "https://files.pythonhosted.org/packages/16/ce/502a57fb0ec752026d24df1280b162294b22a0afb98a326084f9a979138b/google_api_core-2.30.3.tar.gz", hash = "sha256:e601a37f148585319b26db36e219df68c5d07b6382cff2d580e83404e44d641b", size = 177001, upload-time = "2026-04-10T00:41:28.035Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/84/e1/ebd5100cbb202e561c0c8b59e485ef3bd63fa9beb610f3fdcaea443f0288/google_api_core-2.30.2-py3-none-any.whl", hash = "sha256:a4c226766d6af2580577db1f1a51bf53cd262f722b49731ce7414c43068a9594", size = 173236, upload-time = "2026-04-02T21:23:06.395Z" }, + { url = "https://files.pythonhosted.org/packages/03/15/e56f351cf6ef1cfea58e6ac226a7318ed1deb2218c4b3cc9bd9e4b786c5a/google_api_core-2.30.3-py3-none-any.whl", hash = "sha256:a85761ba72c444dad5d611c2220633480b2b6be2521eca69cca2dbb3ffd6bfe8", size = 173274, upload-time = "2026-04-09T22:57:16.198Z" }, ] [package.optional-dependencies] @@ -2182,7 +2169,7 @@ grpc = [ [[package]] name = "google-api-python-client" -version = "2.193.0" +version = "2.194.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "google-api-core" }, @@ -2191,22 +2178,22 @@ dependencies = [ { name = "httplib2" }, { name = "uritemplate" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/90/f4/e14b6815d3b1885328dd209676a3a4c704882743ac94e18ef0093894f5c8/google_api_python_client-2.193.0.tar.gz", hash = "sha256:8f88d16e89d11341e0a8b199cafde0fb7e6b44260dffb88d451577cbd1bb5d33", size = 14281006, upload-time = "2026-03-17T18:25:29.415Z" } +sdist = { url = "https://files.pythonhosted.org/packages/60/ab/e83af0eb043e4ccc49571ca7a6a49984e9d00f4e9e6e6f1238d60bc84dce/google_api_python_client-2.194.0.tar.gz", hash = "sha256:db92647bd1a90f40b79c9618461553c2b20b6a43ce7395fa6de07132dc14f023", size = 14443469, upload-time = "2026-04-08T23:07:35.757Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/f0/6d/fe75167797790a56d17799b75e1129bb93f7ff061efc7b36e9731bd4be2b/google_api_python_client-2.193.0-py3-none-any.whl", hash = "sha256:c42aa324b822109901cfecab5dc4fc3915d35a7b376835233c916c70610322db", size = 14856490, upload-time = "2026-03-17T18:25:26.608Z" }, + { url = "https://files.pythonhosted.org/packages/b0/34/5a624e49f179aa5b0cb87b2ce8093960299030ff40423bfbde09360eb908/google_api_python_client-2.194.0-py3-none-any.whl", hash = "sha256:61eaaac3b8fc8fdf11c08af87abc3d1342d1b37319cc1b57405f86ef7697e717", size = 15016514, upload-time = "2026-04-08T23:07:33.093Z" }, ] [[package]] name = "google-auth" -version = "2.49.1" +version = "2.49.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "cryptography" }, { name = "pyasn1-modules" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/ea/80/6a696a07d3d3b0a92488933532f03dbefa4a24ab80fb231395b9a2a1be77/google_auth-2.49.1.tar.gz", hash = "sha256:16d40da1c3c5a0533f57d268fe72e0ebb0ae1cc3b567024122651c045d879b64", size = 333825, upload-time = "2026-03-12T19:30:58.135Z" } +sdist = { url = "https://files.pythonhosted.org/packages/c6/fc/e925290a1ad95c975c459e2df070fac2b90954e13a0370ac505dff78cb99/google_auth-2.49.2.tar.gz", hash = "sha256:c1ae38500e73065dcae57355adb6278cf8b5c8e391994ae9cbadbcb9631ab409", size = 333958, upload-time = "2026-04-10T00:41:21.888Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/e9/eb/c6c2478d8a8d633460be40e2a8a6f8f429171997a35a96f81d3b680dec83/google_auth-2.49.1-py3-none-any.whl", hash = "sha256:195ebe3dca18eddd1b3db5edc5189b76c13e96f29e73043b923ebcf3f1a860f7", size = 240737, upload-time = "2026-03-12T19:30:53.159Z" }, + { url = "https://files.pythonhosted.org/packages/73/76/d241a5c927433420507215df6cac1b1fa4ac0ba7a794df42a84326c68da8/google_auth-2.49.2-py3-none-any.whl", hash = "sha256:c2720924dfc82dedb962c9f52cabb2ab16714fd0a6a707e40561d217574ed6d5", size = 240638, upload-time = "2026-04-10T00:41:14.501Z" }, ] [package.optional-dependencies] @@ -2229,7 +2216,7 @@ wheels = [ [[package]] name = "google-cloud-aiplatform" -version = "1.145.0" +version = "1.147.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "docstring-parser" }, @@ -2245,9 +2232,9 @@ dependencies = [ { name = "pydantic" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/26/e5/6442d9d2c019456638825d4665b1e87ec4eaf1d182950ba426d0f0210eab/google_cloud_aiplatform-1.145.0.tar.gz", hash = "sha256:7894c4f3d2684bdb60e9a122004c01678e3b585174a27298ae7a3ed1e5eaf3bd", size = 10222904, upload-time = "2026-04-02T14:06:58.322Z" } +sdist = { url = "https://files.pythonhosted.org/packages/23/93/9bfcaaf1ceab12999a881ccf69ebd9b30f467ec5623989c66894e81fc139/google_cloud_aiplatform-1.147.0.tar.gz", hash = "sha256:b2e1b669ba37f02426e03eb13187eebf4cbfeaa0a3bfed37b5578abb375ab689", size = 10235245, upload-time = "2026-04-09T17:14:49.179Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/3d/c6/23e98d3407d5e2416a3dfaecb0a053da899848c50db69e5f2b61a555ce06/google_cloud_aiplatform-1.145.0-py2.py3-none-any.whl", hash = "sha256:4d1c31797a8bd8f3342ed5f186dd30d1f6bca73ddbee2bde452777100d2ddc11", size = 8396640, upload-time = "2026-04-02T14:06:54.125Z" }, + { url = "https://files.pythonhosted.org/packages/d3/d2/1c1c582f6bbed9bbc0daa5acf3a5d98751ca8bc48584548d28569b8ce1a7/google_cloud_aiplatform-1.147.0-py2.py3-none-any.whl", hash = "sha256:29f7ae020718d3c45094f0475464e06a97f81b1572bea150ae6a1b22c5f45997", size = 8408951, upload-time = "2026-04-09T17:14:45.482Z" }, ] [[package]] @@ -2330,7 +2317,7 @@ wheels = [ [[package]] name = "google-genai" -version = "1.65.0" +version = "1.72.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "anyio" }, @@ -2344,9 +2331,9 @@ dependencies = [ { name = "typing-extensions" }, { name = "websockets" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/79/f9/cc1191c2540d6a4e24609a586c4ed45d2db57cfef47931c139ee70e5874a/google_genai-1.65.0.tar.gz", hash = "sha256:d470eb600af802d58a79c7f13342d9ea0d05d965007cae8f76c7adff3d7a4750", size = 497206, upload-time = "2026-02-26T00:20:33.824Z" } +sdist = { url = "https://files.pythonhosted.org/packages/3c/20/2aff5ea3cd7459f85101d119c136d9ca4369fcda3dcf0cfee89b305611a4/google_genai-1.72.0.tar.gz", hash = "sha256:abe7d3aecfafb464b904e3a09c81b626fb425e160e123e71a5125a7021cea7b2", size = 522844, upload-time = "2026-04-09T21:35:46.283Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/68/3c/3fea4e7c91357c71782d7dcaad7a2577d636c90317e003386893c25bc62c/google_genai-1.65.0-py3-none-any.whl", hash = "sha256:68c025205856919bc03edb0155c11b4b833810b7ce17ad4b7a9eeba5158f6c44", size = 724429, upload-time = "2026-02-26T00:20:32.186Z" }, + { url = "https://files.pythonhosted.org/packages/9f/3d/9f70246114cdf56a2615a40428ced08bc844f5a26247fe812b2f0dd4eaca/google_genai-1.72.0-py3-none-any.whl", hash = "sha256:ea861e4c6946e3185c24b40d95503e088fc230a73a71fec0ef78164b369a8489", size = 764230, upload-time = "2026-04-09T21:35:44.587Z" }, ] [[package]] @@ -2637,16 +2624,16 @@ wheels = [ [[package]] name = "holo-search-sdk" -version = "0.4.1" +version = "0.4.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "numpy" }, { name = "psycopg", extra = ["binary"] }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/0b/b8/70a4999dabbba15e98d201a7399aab76ab96931ad1a27392ba5252cc9165/holo_search_sdk-0.4.1.tar.gz", hash = "sha256:9aea98b6078b9202abb568ed69d798d5e0505d2b4cc3a136a6aa84402bcd2133", size = 56701, upload-time = "2026-01-28T01:44:57.645Z" } +sdist = { url = "https://files.pythonhosted.org/packages/a0/6d/62bc3f27002a6e1fa6aefdc17f9e95bec67eebb5348542637bf01c8caa6a/holo_search_sdk-0.4.2.tar.gz", hash = "sha256:630ade92c82d3d610a6e4f933f530045a6acbab4528512f5dc5d7f67dd743263", size = 57433, upload-time = "2026-03-25T05:59:25.146Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/8a/30/3059a979272f90a96f31b167443cc27675e8cc8f970a3ac0cb80bf803c70/holo_search_sdk-0.4.1-py3-none-any.whl", hash = "sha256:ef1059895ea936ff6a087f68dac92bd1ae0320e51ec5b1d4e7bed7a5dd6beb45", size = 32647, upload-time = "2026-01-28T01:44:56.098Z" }, + { url = "https://files.pythonhosted.org/packages/fd/9a/5021e499a1aa4fc1f1b8ca5dcbc9987d2ab7115da4fa9d1e464a6590d142/holo_search_sdk-0.4.2-py3-none-any.whl", hash = "sha256:b0ef8e6ee6a6980526317951ab0967d18dd2973500b7e3f38259f061471ac5da", size = 33488, upload-time = "2026-03-25T05:59:23.216Z" }, ] [[package]] @@ -2786,14 +2773,14 @@ wheels = [ [[package]] name = "hypothesis" -version = "6.151.11" +version = "6.151.12" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "sortedcontainers" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/a9/58/41af0d539b3c95644d1e4e353cbd6ac9473e892ea21802546a8886b79078/hypothesis-6.151.11.tar.gz", hash = "sha256:f33dcb68b62c7b07c9ac49664989be898fa8ce57583f0dc080259a197c6c7ff1", size = 463779, upload-time = "2026-04-05T17:35:55.935Z" } +sdist = { url = "https://files.pythonhosted.org/packages/ce/ab/67ca321d1ab96fd3828b12142f1c258e2d4a668a025d06cd50ab3409787f/hypothesis-6.151.12.tar.gz", hash = "sha256:be485f503979af4c3dfa19e3fc2b967d0458e7f8c4e28128d7e215e0a55102e0", size = 463900, upload-time = "2026-04-08T19:40:06.205Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/1d/06/f49393eca84b87b17a67aaebf9f6251190ba1e9fe9f2236504049fc43fee/hypothesis-6.151.11-py3-none-any.whl", hash = "sha256:7ac05173206746cec8312f95164a30a4eb4916815413a278922e63ff1e404648", size = 529572, upload-time = "2026-04-05T17:35:53.438Z" }, + { url = "https://files.pythonhosted.org/packages/0e/5a/6cecf134b631050a1f8605096adbe812483b60790d951470989d39b56860/hypothesis-6.151.12-py3-none-any.whl", hash = "sha256:37d4f3a768365c30571b11dfd7a6857a12173d933010b2c4ab65619f1b5952c5", size = 529656, upload-time = "2026-04-08T19:40:03.126Z" }, ] [[package]] @@ -2961,11 +2948,11 @@ wheels = [ [[package]] name = "json-repair" -version = "0.55.1" +version = "0.59.2" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/c0/de/71d6bb078d167c0d0959776cee6b6bb8d2ad843f512a5222d7151dde4955/json_repair-0.55.1.tar.gz", hash = "sha256:b27aa0f6bf2e5bf58554037468690446ef26f32ca79c8753282adb3df25fb888", size = 39231, upload-time = "2026-01-23T09:37:20.93Z" } +sdist = { url = "https://files.pythonhosted.org/packages/ed/cb/a49f1661737a78098ce33668350590c981a4163055bc9a01e0cc688d896a/json_repair-0.59.2.tar.gz", hash = "sha256:1d8abb2fa94c4035a66ef9892ea3785dace8dcf09c583e6de781cfd31b278b3d", size = 48341, upload-time = "2026-04-11T15:55:41.145Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/56/da/289ba9eb550ae420cfc457926f6c49b87cacf8083ee9927e96921888a665/json_repair-0.55.1-py3-none-any.whl", hash = "sha256:a1bcc151982a12bc3ef9e9528198229587b1074999cfe08921ab6333b0c8e206", size = 29743, upload-time = "2026-01-23T09:37:19.404Z" }, + { url = "https://files.pythonhosted.org/packages/e1/03/7afcecb4242d93b684708b47fb014abdc1922a01b38c0e30f1117ae74a83/json_repair-0.59.2-py3-none-any.whl", hash = "sha256:6ca6238519c24f671bcb05d1f38a0d6a452bb4ca5af82137595c5c2f1a0fb785", size = 46918, upload-time = "2026-04-11T15:55:39.817Z" }, ] [[package]] @@ -3052,7 +3039,7 @@ sdist = { url = "https://files.pythonhosted.org/packages/0e/72/a3add0e4eec4eb9e2 [[package]] name = "langfuse" -version = "4.0.6" +version = "4.2.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "backoff" }, @@ -3064,14 +3051,14 @@ dependencies = [ { name = "pydantic" }, { name = "wrapt" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/ab/d0/6d79ed5614f86f27f5df199cf10c6facf6874ff6f91b828ae4dad90aa86d/langfuse-4.0.6.tar.gz", hash = "sha256:83a6f8cc8f1431fa2958c91e2673bc4179f993297e9b1acd1dbf001785e6cf83", size = 274094, upload-time = "2026-04-01T20:04:15.153Z" } +sdist = { url = "https://files.pythonhosted.org/packages/45/9c/b912a00ffae92ff9955cdd9b74fb839be58f631d4329ae2a8a0376f697f2/langfuse-4.2.0.tar.gz", hash = "sha256:d0bd26d5065cf6a59d7d1093b08d8910e2458dc3da7ed8ccec160db114c18342", size = 275582, upload-time = "2026-04-10T11:55:25.21Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/50/b4/088048e37b6d7ec1b52c6a11bc33101454285a22eaab8303dcccfd78344d/langfuse-4.0.6-py3-none-any.whl", hash = "sha256:0562b1dcf83247f9d8349f0f755eaed9a7f952fee67e66580970f0738bf3adbf", size = 472841, upload-time = "2026-04-01T20:04:16.451Z" }, + { url = "https://files.pythonhosted.org/packages/be/0a/b84e3e68a690ccfe6d64953c572772c685fcb0915b7f2ee3a87c22e388ab/langfuse-4.2.0-py3-none-any.whl", hash = "sha256:bfd760bf10fd0228f297f6369436620f76d16b589de46393d65706b27e4e4082", size = 475449, upload-time = "2026-04-10T11:55:23.624Z" }, ] [[package]] name = "langsmith" -version = "0.7.25" +version = "0.7.30" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "httpx" }, @@ -3084,9 +3071,9 @@ dependencies = [ { name = "xxhash" }, { name = "zstandard" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/7e/d7/21ffae5ccdc3c9b8de283e8f8bf48a92039681df0d39f15133d8ff8965bd/langsmith-0.7.25.tar.gz", hash = "sha256:d17da71f156ca69eafd28ac9627c8e0e93170260ec37cd27cedc83205a067598", size = 1145410, upload-time = "2026-04-03T13:11:42.36Z" } +sdist = { url = "https://files.pythonhosted.org/packages/46/e7/d27d952ce9824d684a3bb500a06541a2d55734bc4d849cdfcca2dfd4d93a/langsmith-0.7.30.tar.gz", hash = "sha256:d9df7ba5e42f818b63bda78776c8f2fc853388be3ae77b117e5d183a149321a2", size = 1106040, upload-time = "2026-04-09T21:12:01.892Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/29/13/67889d41baf7dbaf13ffd0b334a0f284e107fad1cc8782a1abb1e56e5eeb/langsmith-0.7.25-py3-none-any.whl", hash = "sha256:55ecc24c547f6c79b5a684ff8685c669eec34e52fcac5d2c0af7d613aef5a632", size = 359417, upload-time = "2026-04-03T13:11:40.729Z" }, + { url = "https://files.pythonhosted.org/packages/37/19/96250cf58070c5563446651b03bb76c2eb5afbf08e754840ab639532d8c6/langsmith-0.7.30-py3-none-any.whl", hash = "sha256:43dd9f8d290e4d406606d6cc0bd62f5d1050963f05fe0ab6ffe50acf41f2f55a", size = 372682, upload-time = "2026-04-09T21:12:00.481Z" }, ] [[package]] @@ -3144,15 +3131,14 @@ wheels = [ [[package]] name = "llvmlite" -version = "0.45.1" +version = "0.47.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/99/8d/5baf1cef7f9c084fb35a8afbde88074f0d6a727bc63ef764fe0e7543ba40/llvmlite-0.45.1.tar.gz", hash = "sha256:09430bb9d0bb58fc45a45a57c7eae912850bedc095cd0810a57de109c69e1c32", size = 185600, upload-time = "2025-10-01T17:59:52.046Z" } +sdist = { url = "https://files.pythonhosted.org/packages/01/88/a8952b6d5c21e74cbf158515b779666f692846502623e9e3c39d8e8ba25f/llvmlite-0.47.0.tar.gz", hash = "sha256:62031ce968ec74e95092184d4b0e857e444f8fdff0b8f9213707699570c33ccc", size = 193614, upload-time = "2026-03-31T18:29:53.497Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/e2/7c/82cbd5c656e8991bcc110c69d05913be2229302a92acb96109e166ae31fb/llvmlite-0.45.1-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:28e763aba92fe9c72296911e040231d486447c01d4f90027c8e893d89d49b20e", size = 43043524, upload-time = "2025-10-01T18:03:30.666Z" }, - { url = "https://files.pythonhosted.org/packages/9d/bc/5314005bb2c7ee9f33102c6456c18cc81745d7055155d1218f1624463774/llvmlite-0.45.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:1a53f4b74ee9fd30cb3d27d904dadece67a7575198bd80e687ee76474620735f", size = 37253123, upload-time = "2025-10-01T18:04:18.177Z" }, - { url = "https://files.pythonhosted.org/packages/96/76/0f7154952f037cb320b83e1c952ec4a19d5d689cf7d27cb8a26887d7bbc1/llvmlite-0.45.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5b3796b1b1e1c14dcae34285d2f4ea488402fbd2c400ccf7137603ca3800864f", size = 56288211, upload-time = "2025-10-01T18:01:24.079Z" }, - { url = "https://files.pythonhosted.org/packages/00/b1/0b581942be2683ceb6862d558979e87387e14ad65a1e4db0e7dd671fa315/llvmlite-0.45.1-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:779e2f2ceefef0f4368548685f0b4adde34e5f4b457e90391f570a10b348d433", size = 55140958, upload-time = "2025-10-01T18:02:30.482Z" }, - { url = "https://files.pythonhosted.org/packages/33/94/9ba4ebcf4d541a325fd8098ddc073b663af75cc8b065b6059848f7d4dce7/llvmlite-0.45.1-cp312-cp312-win_amd64.whl", hash = "sha256:9e6c9949baf25d9aa9cd7cf0f6d011b9ca660dd17f5ba2b23bdbdb77cc86b116", size = 38132231, upload-time = "2025-10-01T18:05:03.664Z" }, + { url = "https://files.pythonhosted.org/packages/fa/48/4b7fe0e34c169fa2f12532916133e0b219d2823b540733651b34fdac509a/llvmlite-0.47.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:306a265f408c259067257a732c8e159284334018b4083a9e35f67d19792b164f", size = 37232769, upload-time = "2026-03-31T18:28:43.735Z" }, + { url = "https://files.pythonhosted.org/packages/e6/4b/e3f2cd17822cf772a4a51a0a8080b0032e6d37b2dbe8cfb724eac4e31c52/llvmlite-0.47.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5853bf26160857c0c2573415ff4efe01c4c651e59e2c55c2a088740acfee51cd", size = 56275178, upload-time = "2026-03-31T18:28:48.342Z" }, + { url = "https://files.pythonhosted.org/packages/b6/55/a3b4a543185305a9bdf3d9759d53646ed96e55e7dfd43f53e7a421b8fbae/llvmlite-0.47.0-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:003bcf7fa579e14db59c1a1e113f93ab8a06b56a4be31c7f08264d1d4072d077", size = 55128632, upload-time = "2026-03-31T18:28:52.901Z" }, + { url = "https://files.pythonhosted.org/packages/2f/f5/d281ae0f79378a5a91f308ea9fdb9f9cc068fddd09629edc0725a5a8fde1/llvmlite-0.47.0-cp312-cp312-win_amd64.whl", hash = "sha256:f3079f25bdc24cd9d27c4b2b5e68f5f60c4fdb7e8ad5ee2b9b006007558f9df7", size = 38138692, upload-time = "2026-03-31T18:28:57.147Z" }, ] [[package]] @@ -3269,7 +3255,7 @@ wheels = [ [[package]] name = "mlflow-skinny" -version = "3.10.1" +version = "3.11.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "cachetools" }, @@ -3292,9 +3278,9 @@ dependencies = [ { name = "typing-extensions" }, { name = "uvicorn" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/71/65/5b2c28e74c167ba8a5afe59399ef44291a0f140487f534db1900f09f59f6/mlflow_skinny-3.10.1.tar.gz", hash = "sha256:3d1c5c30245b6e7065b492b09dd47be7528e0a14c4266b782fe58f9bcd1e0be0", size = 2478631, upload-time = "2026-03-05T10:49:01.47Z" } +sdist = { url = "https://files.pythonhosted.org/packages/40/77/fe2027ddad9e52ed1ac360fbc262169e6366f6678632e350cbd0d901bb9b/mlflow_skinny-3.11.1.tar.gz", hash = "sha256:86ce63491349f6713afc8a4ef0bf77a8314d0e79e03753cb150d6c860a0b0475", size = 2642799, upload-time = "2026-04-07T14:26:43.818Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/4b/52/17460157271e70b0d8444d27f8ad730ef7d95fb82fac59dc19f11519b921/mlflow_skinny-3.10.1-py3-none-any.whl", hash = "sha256:df1dd507d8ddadf53bfab2423c76cdcafc235cd1a46921a06d1a6b4dd04b023c", size = 2987098, upload-time = "2026-03-05T10:48:59.566Z" }, + { url = "https://files.pythonhosted.org/packages/d1/a7/e61ec397b34dc3c9e91572f45e41617f429d5c524d38a4e1aa2316ee1b5e/mlflow_skinny-3.11.1-py3-none-any.whl", hash = "sha256:82ffd5f6980320b4ac19f741e7a754faa1d01707e632b002ea68e04fd25a0535", size = 3171551, upload-time = "2026-04-07T14:26:41.762Z" }, ] [[package]] @@ -3509,19 +3495,18 @@ wheels = [ [[package]] name = "numba" -version = "0.62.1" +version = "0.65.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "llvmlite" }, { name = "numpy" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/a3/20/33dbdbfe60e5fd8e3dbfde299d106279a33d9f8308346022316781368591/numba-0.62.1.tar.gz", hash = "sha256:7b774242aa890e34c21200a1fc62e5b5757d5286267e71103257f4e2af0d5161", size = 2749817, upload-time = "2025-09-29T10:46:31.551Z" } +sdist = { url = "https://files.pythonhosted.org/packages/49/61/7299643b9c18d669e04be7c5bcb64d985070d07553274817b45b049e7bfe/numba-0.65.0.tar.gz", hash = "sha256:edad0d9f6682e93624c00125a471ae4df186175d71fd604c983c377cdc03e68b", size = 2764131, upload-time = "2026-04-01T03:52:01.946Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/5e/fa/30fa6873e9f821c0ae755915a3ca444e6ff8d6a7b6860b669a3d33377ac7/numba-0.62.1-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:1b743b32f8fa5fff22e19c2e906db2f0a340782caf024477b97801b918cf0494", size = 2685346, upload-time = "2025-09-29T10:43:43.677Z" }, - { url = "https://files.pythonhosted.org/packages/a9/d5/504ce8dc46e0dba2790c77e6b878ee65b60fe3e7d6d0006483ef6fde5a97/numba-0.62.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:90fa21b0142bcf08ad8e32a97d25d0b84b1e921bc9423f8dda07d3652860eef6", size = 2688139, upload-time = "2025-09-29T10:44:04.894Z" }, - { url = "https://files.pythonhosted.org/packages/50/5f/6a802741176c93f2ebe97ad90751894c7b0c922b52ba99a4395e79492205/numba-0.62.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:6ef84d0ac19f1bf80431347b6f4ce3c39b7ec13f48f233a48c01e2ec06ecbc59", size = 3796453, upload-time = "2025-09-29T10:42:52.771Z" }, - { url = "https://files.pythonhosted.org/packages/7e/df/efd21527d25150c4544eccc9d0b7260a5dec4b7e98b5a581990e05a133c0/numba-0.62.1-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9315cc5e441300e0ca07c828a627d92a6802bcbf27c5487f31ae73783c58da53", size = 3496451, upload-time = "2025-09-29T10:43:19.279Z" }, - { url = "https://files.pythonhosted.org/packages/80/44/79bfdab12a02796bf4f1841630355c82b5a69933b1d50eb15c7fa37dabe8/numba-0.62.1-cp312-cp312-win_amd64.whl", hash = "sha256:44e3aa6228039992f058f5ebfcfd372c83798e9464297bdad8cc79febcf7891e", size = 2745552, upload-time = "2025-09-29T10:44:26.399Z" }, + { url = "https://files.pythonhosted.org/packages/6c/2f/8bd31a1ea43c01ac215283d83aa5f8d5acbe7a36c85b82f1757bfe9ccb31/numba-0.65.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:b27ee4847e1bfb17e9604d100417ee7c1d10f15a6711c6213404b3da13a0b2aa", size = 2680705, upload-time = "2026-04-01T03:51:32.597Z" }, + { url = "https://files.pythonhosted.org/packages/73/36/88406bd58600cc696417b8e5dd6a056478da808f3eaf48d18e2421e0c2d9/numba-0.65.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:a52d92ffd297c10364bce60cd1fcb88f99284ab5df085f2c6bcd1cb33b529a6f", size = 3801411, upload-time = "2026-04-01T03:51:34.321Z" }, + { url = "https://files.pythonhosted.org/packages/0c/61/ce753a1d7646dd477e16d15e89473703faebb8995d2f71d7ad69a540b565/numba-0.65.0-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:da8e371e328c06d0010c3d8b44b21858652831b85bcfba78cb22c042e22dbd8e", size = 3501622, upload-time = "2026-04-01T03:51:36.348Z" }, + { url = "https://files.pythonhosted.org/packages/7d/86/db87a5393f1b1fabef53ac3ba4e6b938bb27e40a04ad7cc512098fcae032/numba-0.65.0-cp312-cp312-win_amd64.whl", hash = "sha256:59bb9f2bb9f1238dfd8e927ba50645c18ae769fef4f3d58ea0ea22a2683b91f5", size = 2749979, upload-time = "2026-04-01T03:51:37.88Z" }, ] [[package]] @@ -3545,30 +3530,33 @@ wheels = [ [[package]] name = "numpy" -version = "1.26.4" +version = "2.4.4" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/65/6e/09db70a523a96d25e115e71cc56a6f9031e7b8cd166c1ac8438307c14058/numpy-1.26.4.tar.gz", hash = "sha256:2a02aba9ed12e4ac4eb3ea9421c420301a0c6460d9830d74a9df87efa4912010", size = 15786129, upload-time = "2024-02-06T00:26:44.495Z" } +sdist = { url = "https://files.pythonhosted.org/packages/d7/9f/b8cef5bffa569759033adda9481211426f12f53299629b410340795c2514/numpy-2.4.4.tar.gz", hash = "sha256:2d390634c5182175533585cc89f3608a4682ccb173cc9bb940b2881c8d6f8fa0", size = 20731587, upload-time = "2026-03-29T13:22:01.298Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/95/12/8f2020a8e8b8383ac0177dc9570aad031a3beb12e38847f7129bacd96228/numpy-1.26.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:b3ce300f3644fb06443ee2222c2201dd3a89ea6040541412b8fa189341847218", size = 20335901, upload-time = "2024-02-05T23:55:32.801Z" }, - { url = "https://files.pythonhosted.org/packages/75/5b/ca6c8bd14007e5ca171c7c03102d17b4f4e0ceb53957e8c44343a9546dcc/numpy-1.26.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:03a8c78d01d9781b28a6989f6fa1bb2c4f2d51201cf99d3dd875df6fbd96b23b", size = 13685868, upload-time = "2024-02-05T23:55:56.28Z" }, - { url = "https://files.pythonhosted.org/packages/79/f8/97f10e6755e2a7d027ca783f63044d5b1bc1ae7acb12afe6a9b4286eac17/numpy-1.26.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9fad7dcb1aac3c7f0584a5a8133e3a43eeb2fe127f47e3632d43d677c66c102b", size = 13925109, upload-time = "2024-02-05T23:56:20.368Z" }, - { url = "https://files.pythonhosted.org/packages/0f/50/de23fde84e45f5c4fda2488c759b69990fd4512387a8632860f3ac9cd225/numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:675d61ffbfa78604709862923189bad94014bef562cc35cf61d3a07bba02a7ed", size = 17950613, upload-time = "2024-02-05T23:56:56.054Z" }, - { url = "https://files.pythonhosted.org/packages/4c/0c/9c603826b6465e82591e05ca230dfc13376da512b25ccd0894709b054ed0/numpy-1.26.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:ab47dbe5cc8210f55aa58e4805fe224dac469cde56b9f731a4c098b91917159a", size = 13572172, upload-time = "2024-02-05T23:57:21.56Z" }, - { url = "https://files.pythonhosted.org/packages/76/8c/2ba3902e1a0fc1c74962ea9bb33a534bb05984ad7ff9515bf8d07527cadd/numpy-1.26.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:1dda2e7b4ec9dd512f84935c5f126c8bd8b9f2fc001e9f54af255e8c5f16b0e0", size = 17786643, upload-time = "2024-02-05T23:57:56.585Z" }, - { url = "https://files.pythonhosted.org/packages/28/4a/46d9e65106879492374999e76eb85f87b15328e06bd1550668f79f7b18c6/numpy-1.26.4-cp312-cp312-win32.whl", hash = "sha256:50193e430acfc1346175fcbdaa28ffec49947a06918b7b92130744e81e640110", size = 5677803, upload-time = "2024-02-05T23:58:08.963Z" }, - { url = "https://files.pythonhosted.org/packages/16/2e/86f24451c2d530c88daf997cb8d6ac622c1d40d19f5a031ed68a4b73a374/numpy-1.26.4-cp312-cp312-win_amd64.whl", hash = "sha256:08beddf13648eb95f8d867350f6a018a4be2e5ad54c8d8caed89ebca558b2818", size = 15517754, upload-time = "2024-02-05T23:58:36.364Z" }, + { url = "https://files.pythonhosted.org/packages/28/05/32396bec30fb2263770ee910142f49c1476d08e8ad41abf8403806b520ce/numpy-2.4.4-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:15716cfef24d3a9762e3acdf87e27f58dc823d1348f765bbea6bef8c639bfa1b", size = 16689272, upload-time = "2026-03-29T13:18:49.223Z" }, + { url = "https://files.pythonhosted.org/packages/c5/f3/a983d28637bfcd763a9c7aafdb6d5c0ebf3d487d1e1459ffdb57e2f01117/numpy-2.4.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:23cbfd4c17357c81021f21540da84ee282b9c8fba38a03b7b9d09ba6b951421e", size = 14699573, upload-time = "2026-03-29T13:18:52.629Z" }, + { url = "https://files.pythonhosted.org/packages/9b/fd/e5ecca1e78c05106d98028114f5c00d3eddb41207686b2b7de3e477b0e22/numpy-2.4.4-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:8b3b60bb7cba2c8c81837661c488637eee696f59a877788a396d33150c35d842", size = 5204782, upload-time = "2026-03-29T13:18:55.579Z" }, + { url = "https://files.pythonhosted.org/packages/de/2f/702a4594413c1a8632092beae8aba00f1d67947389369b3777aed783fdca/numpy-2.4.4-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:e4a010c27ff6f210ff4c6ef34394cd61470d01014439b192ec22552ee867f2a8", size = 6552038, upload-time = "2026-03-29T13:18:57.769Z" }, + { url = "https://files.pythonhosted.org/packages/7f/37/eed308a8f56cba4d1fdf467a4fc67ef4ff4bf1c888f5fc980481890104b1/numpy-2.4.4-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f9e75681b59ddaa5e659898085ae0eaea229d054f2ac0c7e563a62205a700121", size = 15670666, upload-time = "2026-03-29T13:19:00.341Z" }, + { url = "https://files.pythonhosted.org/packages/0a/0d/0e3ecece05b7a7e87ab9fb587855548da437a061326fff64a223b6dcb78a/numpy-2.4.4-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:81f4a14bee47aec54f883e0cad2d73986640c1590eb9bfaaba7ad17394481e6e", size = 16645480, upload-time = "2026-03-29T13:19:03.63Z" }, + { url = "https://files.pythonhosted.org/packages/34/49/f2312c154b82a286758ee2f1743336d50651f8b5195db18cdb63675ff649/numpy-2.4.4-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:62d6b0f03b694173f9fcb1fb317f7222fd0b0b103e784c6549f5e53a27718c44", size = 17020036, upload-time = "2026-03-29T13:19:07.428Z" }, + { url = "https://files.pythonhosted.org/packages/7b/e9/736d17bd77f1b0ec4f9901aaec129c00d59f5d84d5e79bba540ef12c2330/numpy-2.4.4-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:fbc356aae7adf9e6336d336b9c8111d390a05df88f1805573ebb0807bd06fd1d", size = 18368643, upload-time = "2026-03-29T13:19:10.775Z" }, + { url = "https://files.pythonhosted.org/packages/63/f6/d417977c5f519b17c8a5c3bc9e8304b0908b0e21136fe43bf628a1343914/numpy-2.4.4-cp312-cp312-win32.whl", hash = "sha256:0d35aea54ad1d420c812bfa0385c71cd7cc5bcf7c65fed95fc2cd02fe8c79827", size = 5961117, upload-time = "2026-03-29T13:19:13.464Z" }, + { url = "https://files.pythonhosted.org/packages/2d/5b/e1deebf88ff431b01b7406ca3583ab2bbb90972bbe1c568732e49c844f7e/numpy-2.4.4-cp312-cp312-win_amd64.whl", hash = "sha256:b5f0362dc928a6ecd9db58868fca5e48485205e3855957bdedea308f8672ea4a", size = 12320584, upload-time = "2026-03-29T13:19:16.155Z" }, + { url = "https://files.pythonhosted.org/packages/58/89/e4e856ac82a68c3ed64486a544977d0e7bdd18b8da75b78a577ca31c4395/numpy-2.4.4-cp312-cp312-win_arm64.whl", hash = "sha256:846300f379b5b12cc769334464656bc882e0735d27d9726568bc932fdc49d5ec", size = 10221450, upload-time = "2026-03-29T13:19:18.994Z" }, ] [[package]] name = "numpy-typing-compat" -version = "20250818.1.25" +version = "20251206.2.4" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "numpy" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/ff/a7/780dc00f4fed2f2b653f76a196b3a6807c7c667f30ae95a7fd082c1081d8/numpy_typing_compat-20250818.1.25.tar.gz", hash = "sha256:8ff461725af0b436e9b0445d07712f1e6e3a97540a3542810f65f936dcc587a5", size = 5027, upload-time = "2025-08-18T23:46:39.062Z" } +sdist = { url = "https://files.pythonhosted.org/packages/42/5f/29fd5f29b0a5d96e2def96ecba3112fc330ecd16e8c97c2b332563c5e201/numpy_typing_compat-20251206.2.4.tar.gz", hash = "sha256:59882d23aaff054a2536da80564012cdce33487657be4d79c5925bb8705fcabc", size = 5011, upload-time = "2025-12-06T20:02:04.942Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/1e/71/30e8d317b6896acbc347d3089764b6209ba299095550773e14d27dcf035f/numpy_typing_compat-20250818.1.25-py3-none-any.whl", hash = "sha256:4f91427369583074b236c804dd27559134f08ec4243485034c8e7d258cbd9cd3", size = 6355, upload-time = "2025-08-18T23:46:30.927Z" }, + { url = "https://files.pythonhosted.org/packages/63/7c/5c2892e6bc0628a2ccf4e938e1e2db22794657ccb374672d66e20d73839e/numpy_typing_compat-20251206.2.4-py3-none-any.whl", hash = "sha256:a82e723bd20efaa4cf2886709d4264c144f1f2b609bda83d1545113b7e47a5b5", size = 6300, upload-time = "2025-12-06T20:01:57.578Z" }, ] [[package]] @@ -3720,59 +3708,59 @@ wheels = [ [[package]] name = "opentelemetry-api" -version = "1.40.0" +version = "1.41.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "importlib-metadata" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/2c/1d/4049a9e8698361cc1a1aa03a6c59e4fa4c71e0c0f94a30f988a6876a2ae6/opentelemetry_api-1.40.0.tar.gz", hash = "sha256:159be641c0b04d11e9ecd576906462773eb97ae1b657730f0ecf64d32071569f", size = 70851, upload-time = "2026-03-04T14:17:21.555Z" } +sdist = { url = "https://files.pythonhosted.org/packages/47/8e/3778a7e87801d994869a9396b9fc2a289e5f9be91ff54a27d41eace494b0/opentelemetry_api-1.41.0.tar.gz", hash = "sha256:9421d911326ec12dee8bc933f7839090cad7a3f13fcfb0f9e82f8174dc003c09", size = 71416, upload-time = "2026-04-09T14:38:34.544Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/5f/bf/93795954016c522008da367da292adceed71cca6ee1717e1d64c83089099/opentelemetry_api-1.40.0-py3-none-any.whl", hash = "sha256:82dd69331ae74b06f6a874704be0cfaa49a1650e1537d4a813b86ecef7d0ecf9", size = 68676, upload-time = "2026-03-04T14:17:01.24Z" }, + { url = "https://files.pythonhosted.org/packages/58/ee/99ab786653b3bda9c37ade7e24a7b607a1b1f696063172768417539d876d/opentelemetry_api-1.41.0-py3-none-any.whl", hash = "sha256:0e77c806e6a89c9e4f8d372034622f3e1418a11bdbe1c80a50b3d3397ad0fa4f", size = 69007, upload-time = "2026-04-09T14:38:11.833Z" }, ] [[package]] name = "opentelemetry-distro" -version = "0.61b0" +version = "0.62b0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "opentelemetry-api" }, { name = "opentelemetry-instrumentation" }, { name = "opentelemetry-sdk" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/f5/00/1f8acc51326956a596fefaf67751380001af36029132a7a07d4debce3c06/opentelemetry_distro-0.61b0.tar.gz", hash = "sha256:975b845f50181ad53753becf4fd4b123b54fa04df5a9d78812264436d6518981", size = 2590, upload-time = "2026-03-04T14:20:12.453Z" } +sdist = { url = "https://files.pythonhosted.org/packages/72/c6/52b0dbcc8fbdecf179047921940516cbb8aaf05f6b737faa526ad76fec51/opentelemetry_distro-0.62b0.tar.gz", hash = "sha256:aa0308fbe50ad8f17d4446982dbf26870e20b8031ba38d8e1224ecf7aedd3184", size = 2611, upload-time = "2026-04-09T14:40:20.404Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/56/2c/efcc995cd7484e6e55b1d26bd7fa6c55ca96bd415ff94310b52c19f330b0/opentelemetry_distro-0.61b0-py3-none-any.whl", hash = "sha256:f21d1ac0627549795d75e332006dd068877f00e461b1b2e8fe4568d6eb7b9590", size = 3349, upload-time = "2026-03-04T14:18:57.788Z" }, + { url = "https://files.pythonhosted.org/packages/b3/7e/5858bba1c7ed880c7b0fe7d9a1ea40ab8affd18c9ebc1e16c2d69c501da1/opentelemetry_distro-0.62b0-py3-none-any.whl", hash = "sha256:23e9065a35cef12868ad5efb18ce9c88a9103800256b318dec4c9c850c6c78c1", size = 3348, upload-time = "2026-04-09T14:39:17.406Z" }, ] [[package]] name = "opentelemetry-exporter-otlp" -version = "1.40.0" +version = "1.41.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "opentelemetry-exporter-otlp-proto-grpc" }, { name = "opentelemetry-exporter-otlp-proto-http" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/d0/37/b6708e0eff5c5fb9aba2e0ea09f7f3bcbfd12a592d2a780241b5f6014df7/opentelemetry_exporter_otlp-1.40.0.tar.gz", hash = "sha256:7caa0870b95e2fcb59d64e16e2b639ecffb07771b6cd0000b5d12e5e4fef765a", size = 6152, upload-time = "2026-03-04T14:17:23.235Z" } +sdist = { url = "https://files.pythonhosted.org/packages/65/b7/845565a2ab5d22c1486bc7729a06b05cd0964c61539d766e1f107c9eea0c/opentelemetry_exporter_otlp-1.41.0.tar.gz", hash = "sha256:97ff847321f8d4c919032a67d20d3137fb7b34eac0c47f13f71112858927fc5b", size = 6152, upload-time = "2026-04-09T14:38:35.895Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/2d/fc/aea77c28d9f3ffef2fdafdc3f4a235aee4091d262ddabd25882f47ce5c5f/opentelemetry_exporter_otlp-1.40.0-py3-none-any.whl", hash = "sha256:48c87e539ec9afb30dc443775a1334cc5487de2f72a770a4c00b1610bf6c697d", size = 7023, upload-time = "2026-03-04T14:17:03.612Z" }, + { url = "https://files.pythonhosted.org/packages/e0/f2/f1076fff152858773f22cda146713f9ae3661795af6bacd411a76f2151ac/opentelemetry_exporter_otlp-1.41.0-py3-none-any.whl", hash = "sha256:443b6a45c990ae4c55e147f97049a86c5f5b704f3d78b48b44a073a886ec4d6e", size = 7022, upload-time = "2026-04-09T14:38:13.934Z" }, ] [[package]] name = "opentelemetry-exporter-otlp-proto-common" -version = "1.40.0" +version = "1.41.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "opentelemetry-proto" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/51/bc/1559d46557fe6eca0b46c88d4c2676285f1f3be2e8d06bb5d15fbffc814a/opentelemetry_exporter_otlp_proto_common-1.40.0.tar.gz", hash = "sha256:1cbee86a4064790b362a86601ee7934f368b81cd4cc2f2e163902a6e7818a0fa", size = 20416, upload-time = "2026-03-04T14:17:23.801Z" } +sdist = { url = "https://files.pythonhosted.org/packages/8c/28/e8eca94966fe9a1465f6094dc5ddc5398473682180279c94020bc23b4906/opentelemetry_exporter_otlp_proto_common-1.41.0.tar.gz", hash = "sha256:966bbce537e9edb166154779a7c4f8ab6b8654a03a28024aeaf1a3eacb07d6ee", size = 20411, upload-time = "2026-04-09T14:38:36.572Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/8b/ca/8f122055c97a932311a3f640273f084e738008933503d0c2563cd5d591fc/opentelemetry_exporter_otlp_proto_common-1.40.0-py3-none-any.whl", hash = "sha256:7081ff453835a82417bf38dccf122c827c3cbc94f2079b03bba02a3165f25149", size = 18369, upload-time = "2026-03-04T14:17:04.796Z" }, + { url = "https://files.pythonhosted.org/packages/26/c4/78b9bf2d9c1d5e494f44932988d9d91c51a66b9a7b48adf99b62f7c65318/opentelemetry_exporter_otlp_proto_common-1.41.0-py3-none-any.whl", hash = "sha256:7a99177bf61f85f4f9ed2072f54d676364719c066f6d11f515acc6c745c7acf0", size = 18366, upload-time = "2026-04-09T14:38:15.135Z" }, ] [[package]] name = "opentelemetry-exporter-otlp-proto-grpc" -version = "1.40.0" +version = "1.41.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "googleapis-common-protos" }, @@ -3783,14 +3771,14 @@ dependencies = [ { name = "opentelemetry-sdk" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/8f/7f/b9e60435cfcc7590fa87436edad6822240dddbc184643a2a005301cc31f4/opentelemetry_exporter_otlp_proto_grpc-1.40.0.tar.gz", hash = "sha256:bd4015183e40b635b3dab8da528b27161ba83bf4ef545776b196f0fb4ec47740", size = 25759, upload-time = "2026-03-04T14:17:24.4Z" } +sdist = { url = "https://files.pythonhosted.org/packages/42/46/d75a3f8c91915f2e58f61d0a2e4ada63891e7c7a37a20ff7949ba184a6b2/opentelemetry_exporter_otlp_proto_grpc-1.41.0.tar.gz", hash = "sha256:f704201251c6f65772b11bddea1c948000554459101bdbb0116e0a01b70592f6", size = 25754, upload-time = "2026-04-09T14:38:37.423Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/96/6f/7ee0980afcbdcd2d40362da16f7f9796bd083bf7f0b8e038abfbc0300f5d/opentelemetry_exporter_otlp_proto_grpc-1.40.0-py3-none-any.whl", hash = "sha256:2aa0ca53483fe0cf6405087a7491472b70335bc5c7944378a0a8e72e86995c52", size = 20304, upload-time = "2026-03-04T14:17:05.942Z" }, + { url = "https://files.pythonhosted.org/packages/81/f6/b09e2e0c9f0b5750cebc6eaf31527b910821453cef40a5a0fe93550422b2/opentelemetry_exporter_otlp_proto_grpc-1.41.0-py3-none-any.whl", hash = "sha256:3a1a86bd24806ccf136ec9737dbfa4c09b069f9130ff66b0acb014f9c5255fd1", size = 20299, upload-time = "2026-04-09T14:38:17.01Z" }, ] [[package]] name = "opentelemetry-exporter-otlp-proto-http" -version = "1.40.0" +version = "1.41.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "googleapis-common-protos" }, @@ -3801,14 +3789,14 @@ dependencies = [ { name = "requests" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/2e/fa/73d50e2c15c56be4d000c98e24221d494674b0cc95524e2a8cb3856d95a4/opentelemetry_exporter_otlp_proto_http-1.40.0.tar.gz", hash = "sha256:db48f5e0f33217588bbc00274a31517ba830da576e59503507c839b38fa0869c", size = 17772, upload-time = "2026-03-04T14:17:25.324Z" } +sdist = { url = "https://files.pythonhosted.org/packages/19/63/d9f43cd75f3fabb7e01148c89cfa9491fc18f6580a6764c554ff7c953c46/opentelemetry_exporter_otlp_proto_http-1.41.0.tar.gz", hash = "sha256:dcd6e0686f56277db4eecbadd5262124e8f2cc739cadbc3fae3d08a12c976cf5", size = 24139, upload-time = "2026-04-09T14:38:38.128Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/a0/3a/8865d6754e61c9fb170cdd530a124a53769ee5f740236064816eb0ca7301/opentelemetry_exporter_otlp_proto_http-1.40.0-py3-none-any.whl", hash = "sha256:a8d1dab28f504c5d96577d6509f80a8150e44e8f45f82cdbe0e34c99ab040069", size = 19960, upload-time = "2026-03-04T14:17:07.153Z" }, + { url = "https://files.pythonhosted.org/packages/64/b5/a214cd907eedc17699d1c2d602288ae17cb775526df04db3a3b3585329d2/opentelemetry_exporter_otlp_proto_http-1.41.0-py3-none-any.whl", hash = "sha256:a9c4ee69cce9c3f4d7ee736ad1b44e3c9654002c0816900abbafd9f3cf289751", size = 22673, upload-time = "2026-04-09T14:38:18.349Z" }, ] [[package]] name = "opentelemetry-instrumentation" -version = "0.61b0" +version = "0.62b0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "opentelemetry-api" }, @@ -3816,14 +3804,14 @@ dependencies = [ { name = "packaging" }, { name = "wrapt" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/da/37/6bf8e66bfcee5d3c6515b79cb2ee9ad05fe573c20f7ceb288d0e7eeec28c/opentelemetry_instrumentation-0.61b0.tar.gz", hash = "sha256:cb21b48db738c9de196eba6b805b4ff9de3b7f187e4bbf9a466fa170514f1fc7", size = 32606, upload-time = "2026-03-04T14:20:16.825Z" } +sdist = { url = "https://files.pythonhosted.org/packages/f9/fd/b8e90bb340957f059084376f94cff336b0e871a42feba7d3f7342365e987/opentelemetry_instrumentation-0.62b0.tar.gz", hash = "sha256:aa1b0b9ab2e1722c2a8a5384fb016fc28d30bba51826676c8036074790d2861e", size = 34042, upload-time = "2026-04-09T14:40:22.843Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/d8/3e/f6f10f178b6316de67f0dfdbbb699a24fbe8917cf1743c1595fb9dcdd461/opentelemetry_instrumentation-0.61b0-py3-none-any.whl", hash = "sha256:92a93a280e69788e8f88391247cc530fd81f16f2b011979d4d6398f805cfbc63", size = 33448, upload-time = "2026-03-04T14:19:02.447Z" }, + { url = "https://files.pythonhosted.org/packages/00/b6/3356d2e335e3c449c5183e9b023f30f04f1b7073a6583c68745ea2e704b1/opentelemetry_instrumentation-0.62b0-py3-none-any.whl", hash = "sha256:30d4e76486eae64fb095264a70c2c809c4bed17b73373e53091470661f7d477c", size = 34158, upload-time = "2026-04-09T14:39:21.428Z" }, ] [[package]] name = "opentelemetry-instrumentation-asgi" -version = "0.61b0" +version = "0.62b0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "asgiref" }, @@ -3832,28 +3820,28 @@ dependencies = [ { name = "opentelemetry-semantic-conventions" }, { name = "opentelemetry-util-http" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/00/3e/143cf5c034e58037307e6a24f06e0dd64b2c49ae60a965fc580027581931/opentelemetry_instrumentation_asgi-0.61b0.tar.gz", hash = "sha256:9d08e127244361dc33976d39dd4ca8f128b5aa5a7ae425208400a80a095019b5", size = 26691, upload-time = "2026-03-04T14:20:21.038Z" } +sdist = { url = "https://files.pythonhosted.org/packages/f1/38/999bf777774878971c2716de4b7a03cd57a7decb4af25090e703b79fa0e5/opentelemetry_instrumentation_asgi-0.62b0.tar.gz", hash = "sha256:93cde8c62e5918a3c1ff9ba020518127300e5e0816b7e8b14baf46a26ba619fc", size = 26779, upload-time = "2026-04-09T14:40:26.566Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/19/78/154470cf9d741a7487fbb5067357b87386475bbb77948a6707cae982e158/opentelemetry_instrumentation_asgi-0.61b0-py3-none-any.whl", hash = "sha256:e4b3ce6b66074e525e717efff20745434e5efd5d9df6557710856fba356da7a4", size = 16980, upload-time = "2026-03-04T14:19:10.894Z" }, + { url = "https://files.pythonhosted.org/packages/25/cf/29df82f5870178143bdb5c9a7be044b9f78c71e1c5dcf995242e86d80158/opentelemetry_instrumentation_asgi-0.62b0-py3-none-any.whl", hash = "sha256:89b62a6f996b260b162f515c25e6d78e39286e4cbe2f935899e51b32f31027e2", size = 17011, upload-time = "2026-04-09T14:39:27.305Z" }, ] [[package]] name = "opentelemetry-instrumentation-celery" -version = "0.61b0" +version = "0.62b0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "opentelemetry-api" }, { name = "opentelemetry-instrumentation" }, { name = "opentelemetry-semantic-conventions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/8d/43/e79108a804d16b1dc8ff28edd0e94ac393cf6359a5adcd7cdd2ec4be85f4/opentelemetry_instrumentation_celery-0.61b0.tar.gz", hash = "sha256:0e352a567dc89ed8bc083fc635035ce3c5b96bbbd92831ffd676e93b87f8e94f", size = 14780, upload-time = "2026-03-04T14:20:27.776Z" } +sdist = { url = "https://files.pythonhosted.org/packages/01/b4/20a3c8c669dc45aa3703c0370041d67e8be613f1829523cdaf634a5f9626/opentelemetry_instrumentation_celery-0.62b0.tar.gz", hash = "sha256:55e8fa48e5b886bcca448fa32e28a6cc2165157745e8328de479a826d3903095", size = 14808, upload-time = "2026-04-09T14:40:31.603Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/a2/ed/c05f3c84b455654eb6c047474ffde61ed92efc24030f64213c98bca9d44b/opentelemetry_instrumentation_celery-0.61b0-py3-none-any.whl", hash = "sha256:01235733ff0cdf571cb03b270645abb14b9c8d830313dc5842097ec90146320b", size = 13856, upload-time = "2026-03-04T14:19:20.98Z" }, + { url = "https://files.pythonhosted.org/packages/f6/60/cf951e6bd6ec62ec55bd2384e0ba9841ea38f2d128c773d85dc60da97172/opentelemetry_instrumentation_celery-0.62b0-py3-none-any.whl", hash = "sha256:cadfd3e65287a36099dce5ba7e05d98e4c5f9479a455241e01d140ecc5c10935", size = 13864, upload-time = "2026-04-09T14:39:35.009Z" }, ] [[package]] name = "opentelemetry-instrumentation-fastapi" -version = "0.61b0" +version = "0.62b0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "opentelemetry-api" }, @@ -3862,14 +3850,14 @@ dependencies = [ { name = "opentelemetry-semantic-conventions" }, { name = "opentelemetry-util-http" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/37/35/aa727bb6e6ef930dcdc96a617b83748fece57b43c47d83ba8d83fbeca657/opentelemetry_instrumentation_fastapi-0.61b0.tar.gz", hash = "sha256:3a24f35b07c557ae1bbc483bf8412221f25d79a405f8b047de8b670722e2fa9f", size = 24800, upload-time = "2026-03-04T14:20:32.759Z" } +sdist = { url = "https://files.pythonhosted.org/packages/37/09/92740c6d114d1bef392557a03ae6de64065c83c1b331dae9b57fe718497c/opentelemetry_instrumentation_fastapi-0.62b0.tar.gz", hash = "sha256:e4748e4e575077e08beaf2c5d2f369da63dd90882d89d73c4192a97356637dec", size = 25056, upload-time = "2026-04-09T14:40:36.438Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/91/05/acfeb2cccd434242a0a7d0ea29afaf077e04b42b35b485d89aee4e0d9340/opentelemetry_instrumentation_fastapi-0.61b0-py3-none-any.whl", hash = "sha256:a1a844d846540d687d377516b2ff698b51d87c781b59f47c214359c4a241047c", size = 13485, upload-time = "2026-03-04T14:19:30.351Z" }, + { url = "https://files.pythonhosted.org/packages/64/bb/186ffe0fde0ad33ceb50e1d3596cc849b732d3b825592a6a507a40c8c49b/opentelemetry_instrumentation_fastapi-0.62b0-py3-none-any.whl", hash = "sha256:06d3272ad15f9daea5a0a27c32831aff376110a4b0394197120256ef6d610e6e", size = 13482, upload-time = "2026-04-09T14:39:43.446Z" }, ] [[package]] name = "opentelemetry-instrumentation-flask" -version = "0.61b0" +version = "0.62b0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "opentelemetry-api" }, @@ -3879,14 +3867,14 @@ dependencies = [ { name = "opentelemetry-util-http" }, { name = "packaging" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/d9/33/d6852d8f2c3eef86f2f8c858d6f5315983c7063e07e595519e96d4c31c06/opentelemetry_instrumentation_flask-0.61b0.tar.gz", hash = "sha256:e9faf58dfd9860a1868442d180142645abdafc1a652dd73d469a5efd106a7d49", size = 24071, upload-time = "2026-03-04T14:20:33.437Z" } +sdist = { url = "https://files.pythonhosted.org/packages/8e/86/522294f6a80d59560d8f722da59513d2ed2d53c6178fa109789dacc5dd50/opentelemetry_instrumentation_flask-0.62b0.tar.gz", hash = "sha256:330e903c0e92b06aae32f9eb7b8a923599d7a29440f50841a59dbba34ec6dd9f", size = 24100, upload-time = "2026-04-09T14:40:37.111Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/3e/41/619f3530324a58491f2d20f216a10dd7393629b29db4610dda642a27f4ed/opentelemetry_instrumentation_flask-0.61b0-py3-none-any.whl", hash = "sha256:e8ce474d7ce543bfbbb3e93f8a6f8263348af9d7b45502f387420cf3afa71253", size = 15996, upload-time = "2026-03-04T14:19:31.304Z" }, + { url = "https://files.pythonhosted.org/packages/bc/c8/9f3bb38281bcb50c93c3d2358b303645f6917bf972c167484c09f9a97ff1/opentelemetry_instrumentation_flask-0.62b0-py3-none-any.whl", hash = "sha256:8c1f8986ec3887d08899d2eb654625252c929105174911b3b50dcf12b1001807", size = 16006, upload-time = "2026-04-09T14:39:44.401Z" }, ] [[package]] name = "opentelemetry-instrumentation-httpx" -version = "0.61b0" +version = "0.62b0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "opentelemetry-api" }, @@ -3895,14 +3883,14 @@ dependencies = [ { name = "opentelemetry-util-http" }, { name = "wrapt" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/cd/2a/e2becd55e33c29d1d9ef76e2579040ed1951cb33bacba259f6aff2fdd2a6/opentelemetry_instrumentation_httpx-0.61b0.tar.gz", hash = "sha256:6569ec097946c5551c2a4252f74c98666addd1bf047c1dde6b4ef426719ff8dd", size = 24104, upload-time = "2026-03-04T14:20:34.752Z" } +sdist = { url = "https://files.pythonhosted.org/packages/77/a7/63e2c6325c8e99cd9b8e0229a8b61c37520ee537214a2c8d514e84486a94/opentelemetry_instrumentation_httpx-0.62b0.tar.gz", hash = "sha256:d865398db3f3c289ba226e355bf4d94460a4301c0c8916e3136caea55ae18000", size = 24182, upload-time = "2026-04-09T14:40:38.719Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/af/88/dde310dce56e2d85cf1a09507f5888544955309edc4b8d22971d6d3d1417/opentelemetry_instrumentation_httpx-0.61b0-py3-none-any.whl", hash = "sha256:dee05c93a6593a5dc3ae5d9d5c01df8b4e2c5d02e49275e5558534ee46343d5e", size = 17198, upload-time = "2026-03-04T14:19:33.585Z" }, + { url = "https://files.pythonhosted.org/packages/c0/5e/7d5fc28487637871b015128cd5dbb3c36f6d343a9098b893bd803d5a9cca/opentelemetry_instrumentation_httpx-0.62b0-py3-none-any.whl", hash = "sha256:c7660b939c12608fec67743126e9b4dc23dceef0ed631c415924966b0d1579e3", size = 17200, upload-time = "2026-04-09T14:39:46.618Z" }, ] [[package]] name = "opentelemetry-instrumentation-redis" -version = "0.61b0" +version = "0.62b0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "opentelemetry-api" }, @@ -3910,14 +3898,14 @@ dependencies = [ { name = "opentelemetry-semantic-conventions" }, { name = "wrapt" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/cf/21/26205f89358a5f2be3ee5512d3d3bce16b622977f64aeaa9d3fa8887dd39/opentelemetry_instrumentation_redis-0.61b0.tar.gz", hash = "sha256:ae0fbb56be9a641e621d55b02a7d62977a2c77c5ee760addd79b9b266e46e523", size = 14781, upload-time = "2026-03-04T14:20:45.694Z" } +sdist = { url = "https://files.pythonhosted.org/packages/55/7d/5acdb4e4e36c522f9393cfa91f7a431ee089663c77855e524bc97f993020/opentelemetry_instrumentation_redis-0.62b0.tar.gz", hash = "sha256:513bc6679ee251436f0aff7be7ddab6186637dde09a795a8dc9659103f103bef", size = 14796, upload-time = "2026-04-09T14:40:48.391Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/a5/e1/8f4c8e4194291dbe828aeabe779050a8497b379ad90040a5a0a7074b1d08/opentelemetry_instrumentation_redis-0.61b0-py3-none-any.whl", hash = "sha256:8d4e850bbb5f8eeafa44c0eac3a007990c7125de187bc9c3659e29ff7e091172", size = 15506, upload-time = "2026-03-04T14:19:48.588Z" }, + { url = "https://files.pythonhosted.org/packages/de/42/a13a7da074c972a51c14277e7f747e90037b9d815515c73b802e95897690/opentelemetry_instrumentation_redis-0.62b0-py3-none-any.whl", hash = "sha256:92ada3d7bdf395785f660549b0e6e8e5bac7cab80e7f1369a7d02228b27684c3", size = 15501, upload-time = "2026-04-09T14:40:00.69Z" }, ] [[package]] name = "opentelemetry-instrumentation-sqlalchemy" -version = "0.61b0" +version = "0.62b0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "opentelemetry-api" }, @@ -3926,14 +3914,14 @@ dependencies = [ { name = "packaging" }, { name = "wrapt" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/9e/4f/3a325b180944610697a0a926d49d782b41a86120050d44fefb2715b630ac/opentelemetry_instrumentation_sqlalchemy-0.61b0.tar.gz", hash = "sha256:13a3a159a2043a52f0180b3757fbaa26741b0e08abb50deddce4394c118956e6", size = 15343, upload-time = "2026-03-04T14:20:47.648Z" } +sdist = { url = "https://files.pythonhosted.org/packages/2a/3d/40adc8c38e5be017ceb230a28ca57ca81981d4dc0c4b902cc930c77fd14f/opentelemetry_instrumentation_sqlalchemy-0.62b0.tar.gz", hash = "sha256:d02f85b83f349e9ef70a34cb3f4c3a3481fa15b11747f09209818663e161cac4", size = 18539, upload-time = "2026-04-09T14:40:50.251Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/1f/97/b906a930c6a1a20c53ecc8b58cabc2cdd0ce560a2b5d44259084ffe4333e/opentelemetry_instrumentation_sqlalchemy-0.61b0-py3-none-any.whl", hash = "sha256:f115e0be54116ba4c327b8d7b68db4045ee18d44439d888ab8130a549c50d1c1", size = 14547, upload-time = "2026-03-04T14:19:53.088Z" }, + { url = "https://files.pythonhosted.org/packages/e7/e0/77954ac593f34740dc32e28a15fe7170e90f6ba6398eaaa5c88b34c05ed1/opentelemetry_instrumentation_sqlalchemy-0.62b0-py3-none-any.whl", hash = "sha256:ec576e0660080d9d15ce4fa44d2a07fff8cb4b796a84344cb0f2c9e5d6e26f79", size = 15534, upload-time = "2026-04-09T14:40:03.957Z" }, ] [[package]] name = "opentelemetry-instrumentation-wsgi" -version = "0.61b0" +version = "0.62b0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "opentelemetry-api" }, @@ -3941,75 +3929,75 @@ dependencies = [ { name = "opentelemetry-semantic-conventions" }, { name = "opentelemetry-util-http" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/89/e5/189f2845362cfe78e356ba127eab21456309def411c6874aa4800c3de816/opentelemetry_instrumentation_wsgi-0.61b0.tar.gz", hash = "sha256:380f2ae61714e5303275a80b2e14c58571573cd1fddf496d8c39fb9551c5e532", size = 19898, upload-time = "2026-03-04T14:20:54.068Z" } +sdist = { url = "https://files.pythonhosted.org/packages/b7/5c/ed45ff053d76c94c59173f2bcde3d61052adb10214f70f028f760aa56625/opentelemetry_instrumentation_wsgi-0.62b0.tar.gz", hash = "sha256:d179f969ecce0c29a15ffd4d982580dfae57c8ff2fd4d9366e299a6d4815e668", size = 19922, upload-time = "2026-04-09T14:40:56.227Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/96/75/d6b42ba26f3c921be6d01b16561b7bb863f843bad7ac3a5011f62617bcab/opentelemetry_instrumentation_wsgi-0.61b0-py3-none-any.whl", hash = "sha256:bd33b0824166f24134a3400648805e8d2e6a7951f070241294e8b8866611d7fa", size = 14628, upload-time = "2026-03-04T14:20:03.934Z" }, + { url = "https://files.pythonhosted.org/packages/f6/cb/753dbbe624df88594fa35a3ff26302fea22623385ed64462f6c8ee7c81eb/opentelemetry_instrumentation_wsgi-0.62b0-py3-none-any.whl", hash = "sha256:2714ab5ab2f35e67dc181ffa3a43fa15313c85c09b4d024c36d72cf1efa29c9a", size = 14628, upload-time = "2026-04-09T14:40:13.529Z" }, ] [[package]] name = "opentelemetry-propagator-b3" -version = "1.40.0" +version = "1.41.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "opentelemetry-api" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/eb/fe/e0c84af5c654ec42165ba57af83c7f67e4b8af77f836ddc29dee59ff73c6/opentelemetry_propagator_b3-1.40.0.tar.gz", hash = "sha256:59b6925498947c08a1b7e0dd38193ff97e5009bec74ec23824300c2e32f77bcf", size = 9587, upload-time = "2026-03-04T14:17:30.079Z" } +sdist = { url = "https://files.pythonhosted.org/packages/ef/43/cea77e171c014324876104cf2a17c78f5e931408b977b9e64979f950912c/opentelemetry_propagator_b3-1.41.0.tar.gz", hash = "sha256:ef98b715b3a05e8b0b03ebaea1bf295b4ad61a0e306e2d1da81d32af7395e6ad", size = 9588, upload-time = "2026-04-09T14:38:43.328Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/8f/84/8654cc0539b5145046b2e60d058cebad401a600dd0b1240f1711c6788643/opentelemetry_propagator_b3-1.40.0-py3-none-any.whl", hash = "sha256:cb72a1698fd1d1b434f70dc90c1de62da8ade1dd84850d1f040eccf6a420fa7b", size = 8922, upload-time = "2026-03-04T14:17:14.732Z" }, + { url = "https://files.pythonhosted.org/packages/50/c1/11345c06774ec6ed6d89e3994dd1f62ad2ab41dfeb312eacd6b2a2323280/opentelemetry_propagator_b3-1.41.0-py3-none-any.whl", hash = "sha256:0b085c26ba59fcb66771226f967e91886bdeef998b3b5f2e9da6a604918c6f90", size = 8923, upload-time = "2026-04-09T14:38:26.865Z" }, ] [[package]] name = "opentelemetry-proto" -version = "1.40.0" +version = "1.41.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "protobuf" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/4c/77/dd38991db037fdfce45849491cb61de5ab000f49824a00230afb112a4392/opentelemetry_proto-1.40.0.tar.gz", hash = "sha256:03f639ca129ba513f5819810f5b1f42bcb371391405d99c168fe6937c62febcd", size = 45667, upload-time = "2026-03-04T14:17:31.194Z" } +sdist = { url = "https://files.pythonhosted.org/packages/e0/d9/08e3dc6156878713e8c811682bc76151f5fe1a3cb7f3abda3966fd56e71e/opentelemetry_proto-1.41.0.tar.gz", hash = "sha256:95d2e576f9fb1800473a3e4cfcca054295d06bdb869fda4dc9f4f779dc68f7b6", size = 45669, upload-time = "2026-04-09T14:38:45.978Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/b9/b2/189b2577dde745b15625b3214302605b1353436219d42b7912e77fa8dc24/opentelemetry_proto-1.40.0-py3-none-any.whl", hash = "sha256:266c4385d88923a23d63e353e9761af0f47a6ed0d486979777fe4de59dc9b25f", size = 72073, upload-time = "2026-03-04T14:17:16.673Z" }, + { url = "https://files.pythonhosted.org/packages/49/8c/65ef7a9383a363864772022e822b5d5c6988e6f9dabeebb9278f5b86ebc3/opentelemetry_proto-1.41.0-py3-none-any.whl", hash = "sha256:b970ab537309f9eed296be482c3e7cca05d8aca8165346e929f658dbe153b247", size = 72074, upload-time = "2026-04-09T14:38:29.38Z" }, ] [[package]] name = "opentelemetry-sdk" -version = "1.40.0" +version = "1.41.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "opentelemetry-api" }, { name = "opentelemetry-semantic-conventions" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/58/fd/3c3125b20ba18ce2155ba9ea74acb0ae5d25f8cd39cfd37455601b7955cc/opentelemetry_sdk-1.40.0.tar.gz", hash = "sha256:18e9f5ec20d859d268c7cb3c5198c8d105d073714db3de50b593b8c1345a48f2", size = 184252, upload-time = "2026-03-04T14:17:31.87Z" } +sdist = { url = "https://files.pythonhosted.org/packages/f8/0e/a586df1186f9f56b5a0879d52653effc40357b8e88fc50fe300038c3c08b/opentelemetry_sdk-1.41.0.tar.gz", hash = "sha256:7bddf3961131b318fc2d158947971a8e37e38b1cd23470cfb72b624e7cc108bd", size = 230181, upload-time = "2026-04-09T14:38:47.225Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/2c/c5/6a852903d8bfac758c6dc6e9a68b015d3c33f2f1be5e9591e0f4b69c7e0a/opentelemetry_sdk-1.40.0-py3-none-any.whl", hash = "sha256:787d2154a71f4b3d81f20524a8ce061b7db667d24e46753f32a7bc48f1c1f3f1", size = 141951, upload-time = "2026-03-04T14:17:17.961Z" }, + { url = "https://files.pythonhosted.org/packages/2c/13/a7825118208cb32e6a4edcd0a99f925cbef81e77b3b0aedfd9125583c543/opentelemetry_sdk-1.41.0-py3-none-any.whl", hash = "sha256:a596f5687964a3e0d7f8edfdcf5b79cbca9c93c7025ebf5fb00f398a9443b0bd", size = 180214, upload-time = "2026-04-09T14:38:30.657Z" }, ] [[package]] name = "opentelemetry-semantic-conventions" -version = "0.61b0" +version = "0.62b0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "opentelemetry-api" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/6d/c0/4ae7973f3c2cfd2b6e321f1675626f0dab0a97027cc7a297474c9c8f3d04/opentelemetry_semantic_conventions-0.61b0.tar.gz", hash = "sha256:072f65473c5d7c6dc0355b27d6c9d1a679d63b6d4b4b16a9773062cb7e31192a", size = 145755, upload-time = "2026-03-04T14:17:32.664Z" } +sdist = { url = "https://files.pythonhosted.org/packages/a3/b0/c14f723e86c049b7bf8ff431160d982519b97a7be2857ed2247377397a24/opentelemetry_semantic_conventions-0.62b0.tar.gz", hash = "sha256:cbfb3c8fc259575cf68a6e1b94083cc35adc4a6b06e8cf431efa0d62606c0097", size = 145753, upload-time = "2026-04-09T14:38:48.274Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/b2/37/cc6a55e448deaa9b27377d087da8615a3416d8ad523d5960b78dbeadd02a/opentelemetry_semantic_conventions-0.61b0-py3-none-any.whl", hash = "sha256:fa530a96be229795f8cef353739b618148b0fe2b4b3f005e60e262926c4d38e2", size = 231621, upload-time = "2026-03-04T14:17:19.33Z" }, + { url = "https://files.pythonhosted.org/packages/58/6c/5e86fa1759a525ef91c2d8b79d668574760ff3f900d114297765eb8786cb/opentelemetry_semantic_conventions-0.62b0-py3-none-any.whl", hash = "sha256:0ddac1ce59eaf1a827d9987ab60d9315fb27aea23304144242d1fcad9e16b489", size = 231619, upload-time = "2026-04-09T14:38:32.394Z" }, ] [[package]] name = "opentelemetry-util-http" -version = "0.61b0" +version = "0.62b0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/57/3c/f0196223efc5c4ca19f8fad3d5462b171ac6333013335ce540c01af419e9/opentelemetry_util_http-0.61b0.tar.gz", hash = "sha256:1039cb891334ad2731affdf034d8fb8b48c239af9b6dd295e5fabd07f1c95572", size = 11361, upload-time = "2026-03-04T14:20:57.01Z" } +sdist = { url = "https://files.pythonhosted.org/packages/9b/e7/830f7c57135158eb8a8efd3f94ab191a89e3b8a49bed314a35ee501da3f2/opentelemetry_util_http-0.62b0.tar.gz", hash = "sha256:a62e4b19b8a432c0de657f167dee3455516136bb9c6ed463ca8063019970d835", size = 11393, upload-time = "2026-04-09T14:40:59.442Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/0d/e5/c08aaaf2f64288d2b6ef65741d2de5454e64af3e050f34285fb1907492fe/opentelemetry_util_http-0.61b0-py3-none-any.whl", hash = "sha256:8e715e848233e9527ea47e275659ea60a57a75edf5206a3b937e236a6da5fc33", size = 9281, upload-time = "2026-03-04T14:20:08.364Z" }, + { url = "https://files.pythonhosted.org/packages/3d/7f/5c1b7d4385852b9e5eacd4e7f9d8b565d3d351d17463b24916ad098adf1a/opentelemetry_util_http-0.62b0-py3-none-any.whl", hash = "sha256:c20462808d8cc95b69b0dc4a3e02a9d36beb663347e96c931f51ffd78bd318ad", size = 9294, upload-time = "2026-04-09T14:40:19.014Z" }, ] [[package]] name = "opik" -version = "1.10.58" +version = "1.11.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "boto3-stubs", extra = ["bedrock-runtime"] }, @@ -4027,22 +4015,23 @@ dependencies = [ { name = "tenacity" }, { name = "tqdm" }, { name = "uuid6" }, + { name = "watchfiles" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/52/bc/54673138cf374226ab9fcdd5685e92442c0d5a95775ff22b870c767387e6/opik-1.10.58.tar.gz", hash = "sha256:058f8b3e3171a1f5e75f25cf1fea392b8f2e0ddba18765fafd24cd756783002b", size = 833671, upload-time = "2026-04-01T11:43:21.571Z" } +sdist = { url = "https://files.pythonhosted.org/packages/a3/b9/f6c7e41cb6c02f6e68fde9b6dacf377dcf42079cdbaf891f9fecf4dc958b/opik-1.11.2.tar.gz", hash = "sha256:79e054595b29e1ca8a4fd67d023249f0cf355ea9efbe3e00c28f51628d053d63", size = 871557, upload-time = "2026-04-10T10:48:14.965Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/33/9a/99cf048209f10f8444544202b007d5fbe0a6104465d29038b25932b1c79f/opik-1.10.58-py3-none-any.whl", hash = "sha256:29be9d7f846f3229a027250997195e583da840179ad03f3d28b1d613687963e3", size = 1400658, upload-time = "2026-04-01T11:43:20.096Z" }, + { url = "https://files.pythonhosted.org/packages/99/2d/e5536a2a1b6fdd920d995e09315523be53bde5fe01f104894d9ba7421a8c/opik-1.11.2-py3-none-any.whl", hash = "sha256:1016b6db7563d847e50e463a2ae09e595b6921372dd52edeada660b82036e1b2", size = 1451056, upload-time = "2026-04-10T10:48:12.927Z" }, ] [[package]] name = "optype" -version = "0.14.0" +version = "0.17.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/94/ca/d3a2abcf12cc8c18ccac1178ef87ab50a235bf386d2401341776fdad18aa/optype-0.14.0.tar.gz", hash = "sha256:925cf060b7d1337647f880401f6094321e7d8e837533b8e159b9a92afa3157c6", size = 100880, upload-time = "2025-10-01T04:49:56.232Z" } +sdist = { url = "https://files.pythonhosted.org/packages/81/9f/3b13bab05debf685678b8af004e46b8c67c6f98ffa08eaf5d33bcf162c16/optype-0.17.0.tar.gz", hash = "sha256:31351a1e64d9eba7bf67e14deefb286e85c66458db63c67dd5e26dd72e4664e5", size = 53484, upload-time = "2026-03-08T23:03:12.594Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/84/a6/11b0eb65eeafa87260d36858b69ec4e0072d09e37ea6714280960030bc93/optype-0.14.0-py3-none-any.whl", hash = "sha256:50d02edafd04edf2e5e27d6249760a51b2198adb9f6ffd778030b3d2806b026b", size = 89465, upload-time = "2025-10-01T04:49:54.674Z" }, + { url = "https://files.pythonhosted.org/packages/6b/44/dca78187415947d1bb90b2ee2a58e47d9573528331e8dc6196996b53612a/optype-0.17.0-py3-none-any.whl", hash = "sha256:8c2d88ff13149454bcf6eb47502f80d288bc542e7238fcc412ac4d222c439397", size = 65854, upload-time = "2026-03-08T23:03:11.425Z" }, ] [package.optional-dependencies] @@ -4116,32 +4105,32 @@ wheels = [ [[package]] name = "packaging" -version = "23.2" +version = "26.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/fb/2b/9b9c33ffed44ee921d0967086d653047286054117d584f1b1a7c22ceaf7b/packaging-23.2.tar.gz", hash = "sha256:048fb0e9405036518eaaf48a55953c750c11e1a1b68e0dd1a9d62ed0c092cfc5", size = 146714, upload-time = "2023-10-01T13:50:05.279Z" } +sdist = { url = "https://files.pythonhosted.org/packages/65/ee/299d360cdc32edc7d2cf530f3accf79c4fca01e96ffc950d8a52213bd8e4/packaging-26.0.tar.gz", hash = "sha256:00243ae351a257117b6a241061796684b084ed1c516a08c48a3f7e147a9d80b4", size = 143416, upload-time = "2026-01-21T20:50:39.064Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/ec/1a/610693ac4ee14fcdf2d9bf3c493370e4f2ef7ae2e19217d7a237ff42367d/packaging-23.2-py3-none-any.whl", hash = "sha256:8c491190033a9af7e1d931d0b5dacc2ef47509b34dd0de67ed209b5203fc88c7", size = 53011, upload-time = "2023-10-01T13:50:03.745Z" }, + { url = "https://files.pythonhosted.org/packages/b7/b9/c538f279a4e237a006a2c98387d081e9eb060d203d8ed34467cc0f0b9b53/packaging-26.0-py3-none-any.whl", hash = "sha256:b36f1fef9334a5588b4166f8bcd26a14e521f2b55e6b9de3aaa80d3ff7a37529", size = 74366, upload-time = "2026-01-21T20:50:37.788Z" }, ] [[package]] name = "pandas" -version = "3.0.1" +version = "3.0.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "numpy" }, { name = "python-dateutil" }, { name = "tzdata", marker = "sys_platform == 'emscripten' or sys_platform == 'win32'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/2e/0c/b28ed414f080ee0ad153f848586d61d1878f91689950f037f976ce15f6c8/pandas-3.0.1.tar.gz", hash = "sha256:4186a699674af418f655dbd420ed87f50d56b4cd6603784279d9eef6627823c8", size = 4641901, upload-time = "2026-02-17T22:20:16.434Z" } +sdist = { url = "https://files.pythonhosted.org/packages/da/99/b342345300f13440fe9fe385c3c481e2d9a595ee3bab4d3219247ac94e9a/pandas-3.0.2.tar.gz", hash = "sha256:f4753e73e34c8d83221ba58f232433fca2748be8b18dbca02d242ed153945043", size = 4645855, upload-time = "2026-03-31T06:48:30.816Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/37/51/b467209c08dae2c624873d7491ea47d2b47336e5403309d433ea79c38571/pandas-3.0.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:476f84f8c20c9f5bc47252b66b4bb25e1a9fc2fa98cead96744d8116cb85771d", size = 10344357, upload-time = "2026-02-17T22:18:38.262Z" }, - { url = "https://files.pythonhosted.org/packages/7c/f1/e2567ffc8951ab371db2e40b2fe068e36b81d8cf3260f06ae508700e5504/pandas-3.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:0ab749dfba921edf641d4036c4c21c0b3ea70fea478165cb98a998fb2a261955", size = 9884543, upload-time = "2026-02-17T22:18:41.476Z" }, - { url = "https://files.pythonhosted.org/packages/d7/39/327802e0b6d693182403c144edacbc27eb82907b57062f23ef5a4c4a5ea7/pandas-3.0.1-cp312-cp312-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b8e36891080b87823aff3640c78649b91b8ff6eea3c0d70aeabd72ea43ab069b", size = 10396030, upload-time = "2026-02-17T22:18:43.822Z" }, - { url = "https://files.pythonhosted.org/packages/3d/fe/89d77e424365280b79d99b3e1e7d606f5165af2f2ecfaf0c6d24c799d607/pandas-3.0.1-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:532527a701281b9dd371e2f582ed9094f4c12dd9ffb82c0c54ee28d8ac9520c4", size = 10876435, upload-time = "2026-02-17T22:18:45.954Z" }, - { url = "https://files.pythonhosted.org/packages/b5/a6/2a75320849dd154a793f69c951db759aedb8d1dd3939eeacda9bdcfa1629/pandas-3.0.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:356e5c055ed9b0da1580d465657bc7d00635af4fd47f30afb23025352ba764d1", size = 11405133, upload-time = "2026-02-17T22:18:48.533Z" }, - { url = "https://files.pythonhosted.org/packages/58/53/1d68fafb2e02d7881df66aa53be4cd748d25cbe311f3b3c85c93ea5d30ca/pandas-3.0.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:9d810036895f9ad6345b8f2a338dd6998a74e8483847403582cab67745bff821", size = 11932065, upload-time = "2026-02-17T22:18:50.837Z" }, - { url = "https://files.pythonhosted.org/packages/75/08/67cc404b3a966b6df27b38370ddd96b3b023030b572283d035181854aac5/pandas-3.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:536232a5fe26dd989bd633e7a0c450705fdc86a207fec7254a55e9a22950fe43", size = 9741627, upload-time = "2026-02-17T22:18:53.905Z" }, - { url = "https://files.pythonhosted.org/packages/86/4f/caf9952948fb00d23795f09b893d11f1cacb384e666854d87249530f7cbe/pandas-3.0.1-cp312-cp312-win_arm64.whl", hash = "sha256:0f463ebfd8de7f326d38037c7363c6dacb857c5881ab8961fb387804d6daf2f7", size = 9052483, upload-time = "2026-02-17T22:18:57.31Z" }, + { url = "https://files.pythonhosted.org/packages/f3/b0/c20bd4d6d3f736e6bd6b55794e9cd0a617b858eaad27c8f410ea05d953b7/pandas-3.0.2-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:232a70ebb568c0c4d2db4584f338c1577d81e3af63292208d615907b698a0f18", size = 10347921, upload-time = "2026-03-31T06:46:33.36Z" }, + { url = "https://files.pythonhosted.org/packages/35/d0/4831af68ce30cc2d03c697bea8450e3225a835ef497d0d70f31b8cdde965/pandas-3.0.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:970762605cff1ca0d3f71ed4f3a769ea8f85fc8e6348f6e110b8fea7e6eb5a14", size = 9888127, upload-time = "2026-03-31T06:46:36.253Z" }, + { url = "https://files.pythonhosted.org/packages/61/a9/16ea9346e1fc4a96e2896242d9bc674764fb9049b0044c0132502f7a771e/pandas-3.0.2-cp312-cp312-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:aff4e6f4d722e0652707d7bcb190c445fe58428500c6d16005b02401764b1b3d", size = 10399577, upload-time = "2026-03-31T06:46:39.224Z" }, + { url = "https://files.pythonhosted.org/packages/c4/a8/3a61a721472959ab0ce865ef05d10b0d6bfe27ce8801c99f33d4fa996e65/pandas-3.0.2-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ef8b27695c3d3dc78403c9a7d5e59a62d5464a7e1123b4e0042763f7104dc74f", size = 10880030, upload-time = "2026-03-31T06:46:42.412Z" }, + { url = "https://files.pythonhosted.org/packages/da/65/7225c0ea4d6ce9cb2160a7fb7f39804871049f016e74782e5dade4d14109/pandas-3.0.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:f8d68083e49e16b84734eb1a4dcae4259a75c90fb6e2251ab9a00b61120c06ab", size = 11409468, upload-time = "2026-03-31T06:46:45.2Z" }, + { url = "https://files.pythonhosted.org/packages/fa/5b/46e7c76032639f2132359b5cf4c785dd8cf9aea5ea64699eac752f02b9db/pandas-3.0.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:32cc41f310ebd4a296d93515fcac312216adfedb1894e879303987b8f1e2b97d", size = 11936381, upload-time = "2026-03-31T06:46:48.293Z" }, + { url = "https://files.pythonhosted.org/packages/7b/8b/721a9cff6fa6a91b162eb51019c6243b82b3226c71bb6c8ef4a9bd65cbc6/pandas-3.0.2-cp312-cp312-win_amd64.whl", hash = "sha256:a4785e1d6547d8427c5208b748ae2efb64659a21bd82bf440d4262d02bfa02a4", size = 9744993, upload-time = "2026-03-31T06:46:51.488Z" }, + { url = "https://files.pythonhosted.org/packages/d5/18/7f0bd34ae27b28159aa80f2a6799f47fda34f7fb938a76e20c7b7fe3b200/pandas-3.0.2-cp312-cp312-win_arm64.whl", hash = "sha256:08504503f7101300107ecdc8df73658e4347586db5cfdadabc1592e9d7e7a0fd", size = 9056118, upload-time = "2026-03-31T06:46:54.548Z" }, ] [package.optional-dependencies] @@ -4658,11 +4647,11 @@ wheels = [ [[package]] name = "pyjwt" -version = "2.12.0" +version = "2.12.1" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/a8/10/e8192be5f38f3e8e7e046716de4cae33d56fd5ae08927a823bb916be36c1/pyjwt-2.12.0.tar.gz", hash = "sha256:2f62390b667cd8257de560b850bb5a883102a388829274147f1d724453f8fb02", size = 102511, upload-time = "2026-03-12T17:15:30.831Z" } +sdist = { url = "https://files.pythonhosted.org/packages/c2/27/a3b6e5bf6ff856d2509292e95c8f57f0df7017cf5394921fc4e4ef40308a/pyjwt-2.12.1.tar.gz", hash = "sha256:c74a7a2adf861c04d002db713dd85f84beb242228e671280bf709d765b03672b", size = 102564, upload-time = "2026-03-13T19:27:37.25Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/15/70/70f895f404d363d291dcf62c12c85fdd47619ad9674ac0f53364d035925a/pyjwt-2.12.0-py3-none-any.whl", hash = "sha256:9bb459d1bdd0387967d287f5656bf7ec2b9a26645d1961628cda1764e087fd6e", size = 29700, upload-time = "2026-03-12T17:15:29.257Z" }, + { url = "https://files.pythonhosted.org/packages/e5/7a/8dd906bd22e79e47397a61742927f6747fe93242ef86645ee9092e610244/pyjwt-2.12.1-py3-none-any.whl", hash = "sha256:28ca37c070cad8ba8cd9790cd940535d40274d22f80ab87f3ac6a713e6e8454c", size = 29726, upload-time = "2026-03-13T19:27:35.677Z" }, ] [package.optional-dependencies] @@ -4672,7 +4661,7 @@ crypto = [ [[package]] name = "pymilvus" -version = "2.6.11" +version = "2.6.12" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "cachetools" }, @@ -4684,9 +4673,9 @@ dependencies = [ { name = "requests" }, { name = "setuptools" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/c9/e6/0adc3b374f5c5d1eebd4f551b455c6865c449b170b17545001b208e2b153/pymilvus-2.6.11.tar.gz", hash = "sha256:a40c10322cde25184a8c3d84993a14dfb67ad2bdcfc5dff7e68b11a79ff8f6d8", size = 1583634, upload-time = "2026-03-27T06:25:46.023Z" } +sdist = { url = "https://files.pythonhosted.org/packages/2c/d7/c5d1381248a33975ccc864a0f980f93270ecc35354de8646c8a16443cccb/pymilvus-2.6.12.tar.gz", hash = "sha256:8323e990dc305e607fef525498eb779e42940a69e0691dde009cd02d48845f7a", size = 1584521, upload-time = "2026-04-09T07:49:11.374Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/9c/1c/bccb331d71f824738f80f11e9b8b4da47973c903826355526ae4fa2b762f/pymilvus-2.6.11-py3-none-any.whl", hash = "sha256:a11e1718b15045361c71ca671b959900cb7e2faae863c896f6b7e87bf2e4d10a", size = 315252, upload-time = "2026-03-27T06:25:44.215Z" }, + { url = "https://files.pythonhosted.org/packages/ce/5d/44b0fa94c91503381e6f12298277f84f8e7b0bb00715ab89fc273c4d681e/pymilvus-2.6.12-py3-none-any.whl", hash = "sha256:69051b8b62712f157b2b50aeb7bde7fd7cdb5940aac0122094eb3cd58bc20f0d", size = 315183, upload-time = "2026-04-09T07:49:09.013Z" }, ] [[package]] @@ -4763,11 +4752,11 @@ wheels = [ [[package]] name = "pypdf" -version = "6.9.2" +version = "6.10.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/31/83/691bdb309306232362503083cb15777491045dd54f45393a317dc7d8082f/pypdf-6.9.2.tar.gz", hash = "sha256:7f850faf2b0d4ab936582c05da32c52214c2b089d61a316627b5bfb5b0dab46c", size = 5311837, upload-time = "2026-03-23T14:53:27.983Z" } +sdist = { url = "https://files.pythonhosted.org/packages/b8/9f/ca96abf18683ca12602065e4ed2bec9050b672c87d317f1079abc7b6d993/pypdf-6.10.0.tar.gz", hash = "sha256:4c5a48ba258c37024ec2505f7e8fd858525f5502784a2e1c8d415604af29f6ef", size = 5314833, upload-time = "2026-04-10T09:34:57.102Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/a5/7e/c85f41243086a8fe5d1baeba527cb26a1918158a565932b41e0f7c0b32e9/pypdf-6.9.2-py3-none-any.whl", hash = "sha256:662cf29bcb419a36a1365232449624ab40b7c2d0cfc28e54f42eeecd1fd7e844", size = 333744, upload-time = "2026-03-23T14:53:26.573Z" }, + { url = "https://files.pythonhosted.org/packages/55/f2/7ebe366f633f30a6ad105f650f44f24f98cb1335c4157d21ae47138b3482/pypdf-6.10.0-py3-none-any.whl", hash = "sha256:90005e959e1596c6e6c84c8b0ad383285b3e17011751cedd17f2ce8fcdfc86de", size = 334459, upload-time = "2026-04-10T09:34:54.966Z" }, ] [[package]] @@ -4842,7 +4831,7 @@ wheels = [ [[package]] name = "pytest" -version = "9.0.2" +version = "9.0.3" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "colorama", marker = "sys_platform == 'win32'" }, @@ -4851,9 +4840,9 @@ dependencies = [ { name = "pluggy" }, { name = "pygments" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/d1/db/7ef3487e0fb0049ddb5ce41d3a49c235bf9ad299b6a25d5780a89f19230f/pytest-9.0.2.tar.gz", hash = "sha256:75186651a92bd89611d1d9fc20f0b4345fd827c41ccd5c299a868a05d70edf11", size = 1568901, upload-time = "2025-12-06T21:30:51.014Z" } +sdist = { url = "https://files.pythonhosted.org/packages/7d/0d/549bd94f1a0a402dc8cf64563a117c0f3765662e2e668477624baeec44d5/pytest-9.0.3.tar.gz", hash = "sha256:b86ada508af81d19edeb213c681b1d48246c1a91d304c6c81a427674c17eb91c", size = 1572165, upload-time = "2026-04-07T17:16:18.027Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/3b/ab/b3226f0bd7cdcf710fbede2b3548584366da3b19b5021e74f5bde2a8fa3f/pytest-9.0.2-py3-none-any.whl", hash = "sha256:711ffd45bf766d5264d487b917733b453d917afd2b0ad65223959f59089f875b", size = 374801, upload-time = "2025-12-06T21:30:49.154Z" }, + { url = "https://files.pythonhosted.org/packages/d4/24/a372aaf5c9b7208e7112038812994107bc65a84cd00e0354a88c2c77a617/pytest-9.0.3-py3-none-any.whl", hash = "sha256:2c5efc453d45394fdd706ade797c0a81091eccd1d6e4bccfcd476e2b8e0ab5d9", size = 375249, upload-time = "2026-04-07T17:16:16.13Z" }, ] [[package]] @@ -5248,15 +5237,15 @@ wheels = [ [[package]] name = "resend" -version = "2.26.0" +version = "2.27.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "requests" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/07/ff/6a4e5e758fc2145c6a7d8563934d8ee24bf96a0212d7ec7d1af1f155bb74/resend-2.26.0.tar.gz", hash = "sha256:957a6a59dc597ce27fbd6d5383220dd9cc497fab99d4f3d775c8a42a449a569e", size = 36238, upload-time = "2026-03-20T22:49:09.728Z" } +sdist = { url = "https://files.pythonhosted.org/packages/96/da/3d342cacbde7143e36782243caa3715d9e49cadb43e804419493c784869b/resend-2.27.0.tar.gz", hash = "sha256:abc183da7566c1fdba8221ec5acd9f954c2ff516a0c2615bee2a41bc9db3e277", size = 37177, upload-time = "2026-04-01T21:19:31.823Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/16/c2/f88d3299d97aa1d36a923d0846fe185fcf5355ca898c954b2e5a79f090b5/resend-2.26.0-py2.py3-none-any.whl", hash = "sha256:5e25a804a84a68df504f2ade5369ac37e0139e37788a1f20b66c88696595b4bc", size = 57699, upload-time = "2026-03-20T22:49:08.354Z" }, + { url = "https://files.pythonhosted.org/packages/b4/95/783b09d24c8f40b900a2728b67fd3c1401d4a6afcdf1db1c8475c249559d/resend-2.27.0-py2.py3-none-any.whl", hash = "sha256:5bc8ddebb0418127fc3e47eb29ab72af727861481c4b051b96cb693df8f8dc40", size = 59831, upload-time = "2026-04-01T21:19:30.471Z" }, ] [[package]] @@ -5310,27 +5299,27 @@ wheels = [ [[package]] name = "ruff" -version = "0.15.9" +version = "0.15.10" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/e6/97/e9f1ca355108ef7194e38c812ef40ba98c7208f47b13ad78d023caa583da/ruff-0.15.9.tar.gz", hash = "sha256:29cbb1255a9797903f6dde5ba0188c707907ff44a9006eb273b5a17bfa0739a2", size = 4617361, upload-time = "2026-04-02T18:17:20.829Z" } +sdist = { url = "https://files.pythonhosted.org/packages/e7/d9/aa3f7d59a10ef6b14fe3431706f854dbf03c5976be614a9796d36326810c/ruff-0.15.10.tar.gz", hash = "sha256:d1f86e67ebfdef88e00faefa1552b5e510e1d35f3be7d423dc7e84e63788c94e", size = 4631728, upload-time = "2026-04-09T14:06:09.884Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/0b/1f/9cdfd0ac4b9d1e5a6cf09bedabdf0b56306ab5e333c85c87281273e7b041/ruff-0.15.9-py3-none-linux_armv6l.whl", hash = "sha256:6efbe303983441c51975c243e26dff328aca11f94b70992f35b093c2e71801e1", size = 10511206, upload-time = "2026-04-02T18:16:41.574Z" }, - { url = "https://files.pythonhosted.org/packages/3d/f6/32bfe3e9c136b35f02e489778d94384118bb80fd92c6d92e7ccd97db12ce/ruff-0.15.9-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:4965bac6ac9ea86772f4e23587746f0b7a395eccabb823eb8bfacc3fa06069f7", size = 10923307, upload-time = "2026-04-02T18:17:08.645Z" }, - { url = "https://files.pythonhosted.org/packages/ca/25/de55f52ab5535d12e7aaba1de37a84be6179fb20bddcbe71ec091b4a3243/ruff-0.15.9-py3-none-macosx_11_0_arm64.whl", hash = "sha256:eaf05aad70ca5b5a0a4b0e080df3a6b699803916d88f006efd1f5b46302daab8", size = 10316722, upload-time = "2026-04-02T18:16:44.206Z" }, - { url = "https://files.pythonhosted.org/packages/48/11/690d75f3fd6278fe55fff7c9eb429c92d207e14b25d1cae4064a32677029/ruff-0.15.9-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9439a342adb8725f32f92732e2bafb6d5246bd7a5021101166b223d312e8fc59", size = 10623674, upload-time = "2026-04-02T18:16:50.951Z" }, - { url = "https://files.pythonhosted.org/packages/bd/ec/176f6987be248fc5404199255522f57af1b4a5a1b57727e942479fec98ad/ruff-0.15.9-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:9c5e6faf9d97c8edc43877c3f406f47446fc48c40e1442d58cfcdaba2acea745", size = 10351516, upload-time = "2026-04-02T18:16:57.206Z" }, - { url = "https://files.pythonhosted.org/packages/b2/fc/51cffbd2b3f240accc380171d51446a32aa2ea43a40d4a45ada67368fbd2/ruff-0.15.9-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7b34a9766aeec27a222373d0b055722900fbc0582b24f39661aa96f3fe6ad901", size = 11150202, upload-time = "2026-04-02T18:17:06.452Z" }, - { url = "https://files.pythonhosted.org/packages/d6/d4/25292a6dfc125f6b6528fe6af31f5e996e19bf73ca8e3ce6eb7fa5b95885/ruff-0.15.9-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:89dd695bc72ae76ff484ae54b7e8b0f6b50f49046e198355e44ea656e521fef9", size = 11988891, upload-time = "2026-04-02T18:17:18.575Z" }, - { url = "https://files.pythonhosted.org/packages/13/e1/1eebcb885c10e19f969dcb93d8413dfee8172578709d7ee933640f5e7147/ruff-0.15.9-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ce187224ef1de1bd225bc9a152ac7102a6171107f026e81f317e4257052916d5", size = 11480576, upload-time = "2026-04-02T18:16:52.986Z" }, - { url = "https://files.pythonhosted.org/packages/ff/6b/a1548ac378a78332a4c3dcf4a134c2475a36d2a22ddfa272acd574140b50/ruff-0.15.9-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2b0c7c341f68adb01c488c3b7d4b49aa8ea97409eae6462d860a79cf55f431b6", size = 11254525, upload-time = "2026-04-02T18:17:02.041Z" }, - { url = "https://files.pythonhosted.org/packages/42/aa/4bb3af8e61acd9b1281db2ab77e8b2c3c5e5599bf2a29d4a942f1c62b8d6/ruff-0.15.9-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:55cc15eee27dc0eebdfcb0d185a6153420efbedc15eb1d38fe5e685657b0f840", size = 11204072, upload-time = "2026-04-02T18:17:13.581Z" }, - { url = "https://files.pythonhosted.org/packages/69/48/d550dc2aa6e423ea0bcc1d0ff0699325ffe8a811e2dba156bd80750b86dc/ruff-0.15.9-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:a6537f6eed5cda688c81073d46ffdfb962a5f29ecb6f7e770b2dc920598997ed", size = 10594998, upload-time = "2026-04-02T18:16:46.369Z" }, - { url = "https://files.pythonhosted.org/packages/63/47/321167e17f5344ed5ec6b0aa2cff64efef5f9e985af8f5622cfa6536043f/ruff-0.15.9-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:6d3fcbca7388b066139c523bda744c822258ebdcfbba7d24410c3f454cc9af71", size = 10359769, upload-time = "2026-04-02T18:17:10.994Z" }, - { url = "https://files.pythonhosted.org/packages/67/5e/074f00b9785d1d2c6f8c22a21e023d0c2c1817838cfca4c8243200a1fa87/ruff-0.15.9-py3-none-musllinux_1_2_i686.whl", hash = "sha256:058d8e99e1bfe79d8a0def0b481c56059ee6716214f7e425d8e737e412d69677", size = 10850236, upload-time = "2026-04-02T18:16:48.749Z" }, - { url = "https://files.pythonhosted.org/packages/76/37/804c4135a2a2caf042925d30d5f68181bdbd4461fd0d7739da28305df593/ruff-0.15.9-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:8e1ddb11dbd61d5983fa2d7d6370ef3eb210951e443cace19594c01c72abab4c", size = 11358343, upload-time = "2026-04-02T18:16:55.068Z" }, - { url = "https://files.pythonhosted.org/packages/88/3d/1364fcde8656962782aa9ea93c92d98682b1ecec2f184e625a965ad3b4a6/ruff-0.15.9-py3-none-win32.whl", hash = "sha256:bde6ff36eaf72b700f32b7196088970bf8fdb2b917b7accd8c371bfc0fd573ec", size = 10583382, upload-time = "2026-04-02T18:17:04.261Z" }, - { url = "https://files.pythonhosted.org/packages/4c/56/5c7084299bd2cacaa07ae63a91c6f4ba66edc08bf28f356b24f6b717c799/ruff-0.15.9-py3-none-win_amd64.whl", hash = "sha256:45a70921b80e1c10cf0b734ef09421f71b5aa11d27404edc89d7e8a69505e43d", size = 11744969, upload-time = "2026-04-02T18:16:59.611Z" }, - { url = "https://files.pythonhosted.org/packages/03/36/76704c4f312257d6dbaae3c959add2a622f63fcca9d864659ce6d8d97d3d/ruff-0.15.9-py3-none-win_arm64.whl", hash = "sha256:0694e601c028fd97dc5c6ee244675bc241aeefced7ef80cd9c6935a871078f53", size = 11005870, upload-time = "2026-04-02T18:17:15.773Z" }, + { url = "https://files.pythonhosted.org/packages/eb/00/a1c2fdc9939b2c03691edbda290afcd297f1f389196172826b03d6b6a595/ruff-0.15.10-py3-none-linux_armv6l.whl", hash = "sha256:0744e31482f8f7d0d10a11fcbf897af272fefdfcb10f5af907b18c2813ff4d5f", size = 10563362, upload-time = "2026-04-09T14:06:21.189Z" }, + { url = "https://files.pythonhosted.org/packages/5c/15/006990029aea0bebe9d33c73c3e28c80c391ebdba408d1b08496f00d422d/ruff-0.15.10-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:b1e7c16ea0ff5a53b7c2df52d947e685973049be1cdfe2b59a9c43601897b22e", size = 10951122, upload-time = "2026-04-09T14:06:02.236Z" }, + { url = "https://files.pythonhosted.org/packages/f2/c0/4ac978fe874d0618c7da647862afe697b281c2806f13ce904ad652fa87e4/ruff-0.15.10-py3-none-macosx_11_0_arm64.whl", hash = "sha256:93cc06a19e5155b4441dd72808fdf84290d84ad8a39ca3b0f994363ade4cebb1", size = 10314005, upload-time = "2026-04-09T14:06:00.026Z" }, + { url = "https://files.pythonhosted.org/packages/da/73/c209138a5c98c0d321266372fc4e33ad43d506d7e5dd817dd89b60a8548f/ruff-0.15.10-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:83e1dd04312997c99ea6965df66a14fb4f03ba978564574ffc68b0d61fd3989e", size = 10643450, upload-time = "2026-04-09T14:05:42.137Z" }, + { url = "https://files.pythonhosted.org/packages/ec/76/0deec355d8ec10709653635b1f90856735302cb8e149acfdf6f82a5feb70/ruff-0.15.10-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:8154d43684e4333360fedd11aaa40b1b08a4e37d8ffa9d95fee6fa5b37b6fab1", size = 10379597, upload-time = "2026-04-09T14:05:49.984Z" }, + { url = "https://files.pythonhosted.org/packages/dc/be/86bba8fc8798c081e28a4b3bb6d143ccad3fd5f6f024f02002b8f08a9fa3/ruff-0.15.10-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8ab88715f3a6deb6bde6c227f3a123410bec7b855c3ae331b4c006189e895cef", size = 11146645, upload-time = "2026-04-09T14:06:12.246Z" }, + { url = "https://files.pythonhosted.org/packages/a8/89/140025e65911b281c57be1d385ba1d932c2366ca88ae6663685aed8d4881/ruff-0.15.10-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a768ff5969b4f44c349d48edf4ab4f91eddb27fd9d77799598e130fb628aa158", size = 12030289, upload-time = "2026-04-09T14:06:04.776Z" }, + { url = "https://files.pythonhosted.org/packages/88/de/ddacca9545a5e01332567db01d44bd8cf725f2db3b3d61a80550b48308ea/ruff-0.15.10-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0ee3ef42dab7078bda5ff6a1bcba8539e9857deb447132ad5566a038674540d0", size = 11496266, upload-time = "2026-04-09T14:05:55.485Z" }, + { url = "https://files.pythonhosted.org/packages/bc/bb/7ddb00a83760ff4a83c4e2fc231fd63937cc7317c10c82f583302e0f6586/ruff-0.15.10-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:51cb8cc943e891ba99989dd92d61e29b1d231e14811db9be6440ecf25d5c1609", size = 11256418, upload-time = "2026-04-09T14:05:57.69Z" }, + { url = "https://files.pythonhosted.org/packages/dc/8d/55de0d35aacf6cd50b6ee91ee0f291672080021896543776f4170fc5c454/ruff-0.15.10-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:e59c9bdc056a320fb9ea1700a8d591718b8faf78af065484e801258d3a76bc3f", size = 11288416, upload-time = "2026-04-09T14:05:44.695Z" }, + { url = "https://files.pythonhosted.org/packages/68/cf/9438b1a27426ec46a80e0a718093c7f958ef72f43eb3111862949ead3cc1/ruff-0.15.10-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:136c00ca2f47b0018b073f28cb5c1506642a830ea941a60354b0e8bc8076b151", size = 10621053, upload-time = "2026-04-09T14:05:52.782Z" }, + { url = "https://files.pythonhosted.org/packages/4c/50/e29be6e2c135e9cd4cb15fbade49d6a2717e009dff3766dd080fcb82e251/ruff-0.15.10-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:8b80a2f3c9c8a950d6237f2ca12b206bccff626139be9fa005f14feb881a1ae8", size = 10378302, upload-time = "2026-04-09T14:06:14.361Z" }, + { url = "https://files.pythonhosted.org/packages/18/2f/e0b36a6f99c51bb89f3a30239bc7bf97e87a37ae80aa2d6542d6e5150364/ruff-0.15.10-py3-none-musllinux_1_2_i686.whl", hash = "sha256:e3e53c588164dc025b671c9df2462429d60357ea91af7e92e9d56c565a9f1b07", size = 10850074, upload-time = "2026-04-09T14:06:16.581Z" }, + { url = "https://files.pythonhosted.org/packages/11/08/874da392558ce087a0f9b709dc6ec0d60cbc694c1c772dab8d5f31efe8cb/ruff-0.15.10-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:b0c52744cf9f143a393e284125d2576140b68264a93c6716464e129a3e9adb48", size = 11358051, upload-time = "2026-04-09T14:06:18.948Z" }, + { url = "https://files.pythonhosted.org/packages/e4/46/602938f030adfa043e67112b73821024dc79f3ab4df5474c25fa4c1d2d14/ruff-0.15.10-py3-none-win32.whl", hash = "sha256:d4272e87e801e9a27a2e8df7b21011c909d9ddd82f4f3281d269b6ba19789ca5", size = 10588964, upload-time = "2026-04-09T14:06:07.14Z" }, + { url = "https://files.pythonhosted.org/packages/25/b6/261225b875d7a13b33a6d02508c39c28450b2041bb01d0f7f1a83d569512/ruff-0.15.10-py3-none-win_amd64.whl", hash = "sha256:28cb32d53203242d403d819fd6983152489b12e4a3ae44993543d6fe62ab42ed", size = 11745044, upload-time = "2026-04-09T14:05:39.473Z" }, + { url = "https://files.pythonhosted.org/packages/58/ed/dea90a65b7d9e69888890fb14c90d7f51bf0c1e82ad800aeb0160e4bacfd/ruff-0.15.10-py3-none-win_arm64.whl", hash = "sha256:601d1610a9e1f1c2165a4f561eeaa2e2ea1e97f3287c5aa258d3dab8b57c6188", size = 11035607, upload-time = "2026-04-09T14:05:47.593Z" }, ] [[package]] @@ -5381,29 +5370,29 @@ wheels = [ [[package]] name = "sendgrid" -version = "6.12.4" +version = "6.12.5" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "ecdsa" }, + { name = "cryptography" }, { name = "python-http-client" }, { name = "werkzeug" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/11/31/62e00433878dccf33edf07f8efa417b9030a2464eb3b04bbd797a11b4447/sendgrid-6.12.4.tar.gz", hash = "sha256:9e88b849daf0fa4bdf256c3b5da9f5a3272402c0c2fd6b1928c9de440db0a03d", size = 50271, upload-time = "2025-06-12T10:29:37.213Z" } +sdist = { url = "https://files.pythonhosted.org/packages/da/fa/f718b2b953f99c1f0085811598ac7e31ccbd4229a81ec2a5290be868187a/sendgrid-6.12.5.tar.gz", hash = "sha256:ea9aae30cd55c332e266bccd11185159482edfc07c149b6cd15cf08869fabdb7", size = 50310, upload-time = "2025-09-19T06:23:09.229Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/c2/9c/45d068fd831a65e6ed1e2ab3233de58784842afdc62fdcdd0a01bbb6b39d/sendgrid-6.12.4-py3-none-any.whl", hash = "sha256:9a211b96241e63bd5b9ed9afcc8608f4bcac426e4a319b3920ab877c8426e92c", size = 102122, upload-time = "2025-06-12T10:29:35.457Z" }, + { url = "https://files.pythonhosted.org/packages/bd/55/b3c3880a77082e8f7374954e0074aafafaa9bc78bdf9c8f5a92c2e7afc6a/sendgrid-6.12.5-py3-none-any.whl", hash = "sha256:96f92cc91634bf552fdb766b904bbb53968018da7ae41fdac4d1090dc0311ca8", size = 102173, upload-time = "2025-09-19T06:23:07.93Z" }, ] [[package]] name = "sentry-sdk" -version = "2.55.0" +version = "2.57.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "certifi" }, { name = "urllib3" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/e9/b8/285293dc60fc198fffc3fcdbc7c6d4e646e0f74e61461c355d40faa64ceb/sentry_sdk-2.55.0.tar.gz", hash = "sha256:3774c4d8820720ca4101548131b9c162f4c9426eb7f4d24aca453012a7470f69", size = 424505, upload-time = "2026-03-17T14:15:51.707Z" } +sdist = { url = "https://files.pythonhosted.org/packages/4f/87/46c0406d8b5ddd026f73adaf5ab75ce144219c41a4830b52df4b9ab55f7f/sentry_sdk-2.57.0.tar.gz", hash = "sha256:4be8d1e71c32fb27f79c577a337ac8912137bba4bcbc64a4ec1da4d6d8dc5199", size = 435288, upload-time = "2026-03-31T09:39:29.264Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/9a/66/20465097782d7e1e742d846407ea7262d338c6e876ddddad38ca8907b38f/sentry_sdk-2.55.0-py2.py3-none-any.whl", hash = "sha256:97026981cb15699394474a196b88503a393cbc58d182ece0d3abe12b9bd978d4", size = 449284, upload-time = "2026-03-17T14:15:49.604Z" }, + { url = "https://files.pythonhosted.org/packages/c9/64/982e07b93219cb52e1cca5d272cb579e2f3eb001956c9e7a9a6d106c9473/sentry_sdk-2.57.0-py2.py3-none-any.whl", hash = "sha256:812c8bf5ff3d2f0e89c82f5ce80ab3a6423e102729c4706af7413fd1eb480585", size = 456489, upload-time = "2026-03-31T09:39:27.524Z" }, ] [package.optional-dependencies] @@ -5709,7 +5698,7 @@ wheels = [ [[package]] name = "tablestore" -version = "6.4.3" +version = "6.4.4" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "aiohttp" }, @@ -5722,9 +5711,9 @@ dependencies = [ { name = "six" }, { name = "urllib3" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/85/0b/c875c2314d472eed9f9644a94ae0aa7e702a6084779a0136e539d5e7ed32/tablestore-6.4.3.tar.gz", hash = "sha256:4981139e68705052ade6341060a4b6238b1fb9a8c18b43a77383fda14f7554a9", size = 5072450, upload-time = "2026-03-31T04:34:37.832Z" } +sdist = { url = "https://files.pythonhosted.org/packages/2f/bc/84d7592188b060950f4fe5713eb3b03068d42b2e43ad37decdb5242c1879/tablestore-6.4.4.tar.gz", hash = "sha256:0f40834030aff0c67e568b09deaab97144229b569710d66557edf7a06a5dcb19", size = 5076731, upload-time = "2026-04-09T09:40:20.399Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/39/e0/e11626aea61e1352dafe7707c548d482769afd3ca28f45653d380ba85a5d/tablestore-6.4.3-py3-none-any.whl", hash = "sha256:207b89324cd4157db4559c7619d42b9510a55c0565f00a439389f14426d114c5", size = 5115764, upload-time = "2026-03-31T04:34:35.761Z" }, + { url = "https://files.pythonhosted.org/packages/45/3f/48af1e72e59d60481724b326317bd311615bdedc31f8f81f9508fb84cda6/tablestore-6.4.4-py3-none-any.whl", hash = "sha256:984f086fa7acabaa3558da93205ad6df562b266b85fd249bc5891f2dd1d65814", size = 5118758, upload-time = "2026-04-09T09:40:17.209Z" }, ] [[package]] @@ -5779,7 +5768,7 @@ wheels = [ [[package]] name = "testcontainers" -version = "4.14.1" +version = "4.14.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "docker" }, @@ -5788,9 +5777,9 @@ dependencies = [ { name = "urllib3" }, { name = "wrapt" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/8b/02/ef62dec9e4f804189c44df23f0b86897c738d38e9c48282fcd410308632f/testcontainers-4.14.1.tar.gz", hash = "sha256:316f1bb178d829c003acd650233e3ff3c59a833a08d8661c074f58a4fbd42a64", size = 80148, upload-time = "2026-01-31T23:13:46.915Z" } +sdist = { url = "https://files.pythonhosted.org/packages/ca/ac/a597c3a0e02b26cbed6dd07df68be1e57684766fd1c381dee9b170a99690/testcontainers-4.14.2.tar.gz", hash = "sha256:1340ccf16fe3acd9389a6c9e1d9ab21d9fe99a8afdf8165f89c3e69c1967d239", size = 166841, upload-time = "2026-03-18T05:19:16.696Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/c8/31/5e7b23f9e43ff7fd46d243808d70c5e8daf3bc08ecf5a7fb84d5e38f7603/testcontainers-4.14.1-py3-none-any.whl", hash = "sha256:03dfef4797b31c82e7b762a454b6afec61a2a512ad54af47ab41e4fa5415f891", size = 125640, upload-time = "2026-01-31T23:13:45.464Z" }, + { url = "https://files.pythonhosted.org/packages/13/2d/26b8b30067d94339afee62c3edc9b803a6eb9332f521ba77d8aaab5de873/testcontainers-4.14.2-py3-none-any.whl", hash = "sha256:0d0522c3cd8f8d9627cda41f7a6b51b639fa57bdc492923c045117933c668d68", size = 125712, upload-time = "2026-03-18T05:19:15.29Z" }, ] [[package]] @@ -5964,11 +5953,11 @@ wheels = [ [[package]] name = "types-aiofiles" -version = "25.1.0.20251011" +version = "25.1.0.20260409" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/84/6c/6d23908a8217e36704aa9c79d99a620f2fdd388b66a4b7f72fbc6b6ff6c6/types_aiofiles-25.1.0.20251011.tar.gz", hash = "sha256:1c2b8ab260cb3cd40c15f9d10efdc05a6e1e6b02899304d80dfa0410e028d3ff", size = 14535, upload-time = "2025-10-11T02:44:51.237Z" } +sdist = { url = "https://files.pythonhosted.org/packages/6c/66/9e62a2692792bc96c0f423f478149f4a7b84720704c546c8960b0a047c89/types_aiofiles-25.1.0.20260409.tar.gz", hash = "sha256:49e67d72bdcf9fe406f5815758a78dc34a1249bb5aa2adba78a80aec0a775435", size = 14812, upload-time = "2026-04-09T04:22:35.308Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/71/0f/76917bab27e270bb6c32addd5968d69e558e5b6f7fb4ac4cbfa282996a96/types_aiofiles-25.1.0.20251011-py3-none-any.whl", hash = "sha256:8ff8de7f9d42739d8f0dadcceeb781ce27cd8d8c4152d4a7c52f6b20edb8149c", size = 14338, upload-time = "2025-10-11T02:44:50.054Z" }, + { url = "https://files.pythonhosted.org/packages/27/d0/28236f869ba4dfb223ecdbc267eb2bdb634b81a561dd992230a4f9ec48fa/types_aiofiles-25.1.0.20260409-py3-none-any.whl", hash = "sha256:923fedb532c772cc0f62e0ce4282725afa82ca5b41cabd9857f06b55e5eee8de", size = 14372, upload-time = "2026-04-09T04:22:34.328Z" }, ] [[package]] @@ -5994,229 +5983,229 @@ wheels = [ [[package]] name = "types-cachetools" -version = "6.2.0.20260317" +version = "6.2.0.20260408" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/8b/7f/16a4d8344c28193a5a74358028c2d2f753f0d9658dd98b9e1967c50045a2/types_cachetools-6.2.0.20260317.tar.gz", hash = "sha256:6d91855bcc944665897c125e720aa3c80aace929b77a64e796343701df4f61c6", size = 9812, upload-time = "2026-03-17T04:06:32.007Z" } +sdist = { url = "https://files.pythonhosted.org/packages/ec/61/475b0e8f4a92e5e33affcc6f4e6344c6dee540824021d22f695ea170da63/types_cachetools-6.2.0.20260408.tar.gz", hash = "sha256:0d8ae2dd5ba0b4cfe6a55c34396dd0415f1be07d0033d84781cdc4ed9c2ebc6b", size = 9854, upload-time = "2026-04-08T04:31:49.665Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/17/9a/b00b23054934c4d569c19f7278c4fb32746cd36a64a175a216d3073a4713/types_cachetools-6.2.0.20260317-py3-none-any.whl", hash = "sha256:92fa9bc50e4629e31fca67ceb3fb1de71791e314fa16c0a0d2728724dc222c8b", size = 9346, upload-time = "2026-03-17T04:06:31.184Z" }, + { url = "https://files.pythonhosted.org/packages/bb/7d/579f50f4f004ee93c7d1baa95339591cac1fe02f4e3fb8fc0f900ee4a80f/types_cachetools-6.2.0.20260408-py3-none-any.whl", hash = "sha256:470e0b274737feae74beed3d764885bf4664002ecc393fba3778846b13ce92cb", size = 9350, upload-time = "2026-04-08T04:31:48.826Z" }, ] [[package]] name = "types-cffi" -version = "2.0.0.20260402" +version = "2.0.0.20260408" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "types-setuptools" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/cb/85/3896bfcb4e7c32904f762c36ff0afa96d3e39bfce5a95a41635af79c8761/types_cffi-2.0.0.20260402.tar.gz", hash = "sha256:47e1320c009f630c59c55c8e3d2b8c501e280babf52e92f6109cbfb0864ba367", size = 17476, upload-time = "2026-04-02T04:21:09.332Z" } +sdist = { url = "https://files.pythonhosted.org/packages/64/67/eb4ef3408fdc0b4e5af38b30c0e6ad4663b41bdae9fb85a9f09a8db61a99/types_cffi-2.0.0.20260408.tar.gz", hash = "sha256:aa8b9c456ab715c079fc655929811f21f331bfb940f4a821987c581bf4e36230", size = 17541, upload-time = "2026-04-08T04:36:03.918Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/ae/26/aacfef05841e31c65f889ae4225c6bce6b84cd5d3882c42a3661030f29ee/types_cffi-2.0.0.20260402-py3-none-any.whl", hash = "sha256:f647a400fba0a31d603479169d82ee5359db79bd1136e41dc7e6489296e3a2b2", size = 20103, upload-time = "2026-04-02T04:21:08.199Z" }, + { url = "https://files.pythonhosted.org/packages/c3/a3/7fbd93ededcc7c77e9e5948b9794161733ebdbf618a27965b1bea0e728a4/types_cffi-2.0.0.20260408-py3-none-any.whl", hash = "sha256:68bd296742b4ff7c0afe3547f50bd0acc55416ecf322ffefd2b7344ef6388a42", size = 20101, upload-time = "2026-04-08T04:36:02.995Z" }, ] [[package]] name = "types-colorama" -version = "0.4.15.20250801" +version = "0.4.15.20260408" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/99/37/af713e7d73ca44738c68814cbacf7a655aa40ddd2e8513d431ba78ace7b3/types_colorama-0.4.15.20250801.tar.gz", hash = "sha256:02565d13d68963d12237d3f330f5ecd622a3179f7b5b14ee7f16146270c357f5", size = 10437, upload-time = "2025-08-01T03:48:22.605Z" } +sdist = { url = "https://files.pythonhosted.org/packages/83/c0/1c02ed9edf3462a392f4ea4bda80fa10c538c63d1d7be255dc7dcb545007/types_colorama-0.4.15.20260408.tar.gz", hash = "sha256:9a816657927489463edec1b7b47933b73fe737d37a3616bf596b7de843441032", size = 10623, upload-time = "2026-04-08T04:28:31.763Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/95/3a/44ccbbfef6235aeea84c74041dc6dfee6c17ff3ddba782a0250e41687ec7/types_colorama-0.4.15.20250801-py3-none-any.whl", hash = "sha256:b6e89bd3b250fdad13a8b6a465c933f4a5afe485ea2e2f104d739be50b13eea9", size = 10743, upload-time = "2025-08-01T03:48:21.774Z" }, + { url = "https://files.pythonhosted.org/packages/b9/65/d03948be8ae9362ad26f36443eab051fe5524295fe008126cd65792f9833/types_colorama-0.4.15.20260408-py3-none-any.whl", hash = "sha256:7327a51c760d94f7df2e8c72c275a4468c03c3abb606d23995cb37e3d24d9132", size = 10763, upload-time = "2026-04-08T04:28:30.688Z" }, ] [[package]] name = "types-defusedxml" -version = "0.7.0.20260402" +version = "0.7.0.20260408" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/d3/3c/8e1243dda2fef73be93081d896503352fb92e2351b0b17ac172bbdb70ebf/types_defusedxml-0.7.0.20260402.tar.gz", hash = "sha256:4cc91b225e77c7fcf88b3fb7d821a37fb4e14530727c790b6b8a19f2968d6074", size = 10604, upload-time = "2026-04-02T04:19:00.265Z" } +sdist = { url = "https://files.pythonhosted.org/packages/39/af/d324da5ffbf0af40477533a09ee6c902de335c445a8dcc88c58f62af6e5f/types_defusedxml-0.7.0.20260408.tar.gz", hash = "sha256:f35377d59344f98b57f9bf319cff2107aac35f9e4d42f9ed6cfeeafacffadb00", size = 10638, upload-time = "2026-04-08T04:26:12.239Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/ad/4e/68f85712dfbcc929c54d57e9b0e7503c198fa65896cae2f6337840ab1cc5/types_defusedxml-0.7.0.20260402-py3-none-any.whl", hash = "sha256:200f3cb340c3c576adeb28cf365399e9bb059b34662b86ad4617692284c98bdb", size = 13434, upload-time = "2026-04-02T04:18:59.263Z" }, + { url = "https://files.pythonhosted.org/packages/ed/68/7570cfb818d6a5b3ff964114527e28e360eccf18329b457f057a18596e64/types_defusedxml-0.7.0.20260408-py3-none-any.whl", hash = "sha256:2d68db82412170b91b3e490b7c118a4f4e5a27756a126e2453f629c8d514b106", size = 13435, upload-time = "2026-04-08T04:26:11.347Z" }, ] [[package]] name = "types-deprecated" -version = "1.3.1.20260402" +version = "1.3.1.20260408" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/e2/ff/7e237c5118c1bd15e5205789901f7e01db232b0c61ca7c7c05de0394f5da/types_deprecated-1.3.1.20260402.tar.gz", hash = "sha256:00828ef7dce735d778583d00611f97da05b86b783ee14b0f22af2f945363cd12", size = 8481, upload-time = "2026-04-02T04:18:28.704Z" } +sdist = { url = "https://files.pythonhosted.org/packages/1a/db/076de3e81b106d3cec17aec9640ab1b2d02f29bad441de280459c161ce65/types_deprecated-1.3.1.20260408.tar.gz", hash = "sha256:62d6a86d0cc754c14bb2de31162d069b1c6a07ce11ee65e5258f8f75308eb3a3", size = 8524, upload-time = "2026-04-08T04:26:39.894Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/ed/3c/59aa775db5f69eba978390c33e1fd617817381cd87424ac1cff4bf2fb6c5/types_deprecated-1.3.1.20260402-py3-none-any.whl", hash = "sha256:ddf1813bd99cd1c00358cb0cb079878fdaa74509e7e482b79627f74f768f31a9", size = 9077, upload-time = "2026-04-02T04:18:27.867Z" }, + { url = "https://files.pythonhosted.org/packages/53/d0/d3258379deb749d949c3c72313981c9d2cceec518b87dcf506f022f5d49f/types_deprecated-1.3.1.20260408-py3-none-any.whl", hash = "sha256:b64e1eab560d4fa9394a27a3099211344b0e0f2f3ac8026d825c86e70d65cdd5", size = 9079, upload-time = "2026-04-08T04:26:38.752Z" }, ] [[package]] name = "types-docutils" -version = "0.22.3.20260322" +version = "0.22.3.20260408" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/44/bb/243a87fc1605a4a94c2c343d6dbddbf0d7ef7c0b9550f360b8cda8e82c39/types_docutils-0.22.3.20260322.tar.gz", hash = "sha256:e2450bb997283c3141ec5db3e436b91f0aa26efe35eb9165178ca976ccb4930b", size = 57311, upload-time = "2026-03-22T04:08:44.064Z" } +sdist = { url = "https://files.pythonhosted.org/packages/3c/49/48a386fe15539556de085b87a69568b028cca2fa4b92596a3d4f79ac6784/types_docutils-0.22.3.20260408.tar.gz", hash = "sha256:22d5d45e4e0d65a1bc8280987a73e28669bb1cc9d16b18d0afc91713d1be26da", size = 57383, upload-time = "2026-04-08T04:27:26.924Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/c6/4a/22c090cd4615a16917dff817cbe7c5956da376c961e024c241cd962d2c3d/types_docutils-0.22.3.20260322-py3-none-any.whl", hash = "sha256:681d4510ce9b80a0c6a593f0f9843d81f8caa786db7b39ba04d9fd5480ac4442", size = 91978, upload-time = "2026-03-22T04:08:43.117Z" }, + { url = "https://files.pythonhosted.org/packages/08/47/1667fda6e9fcb044f8fb797f6dc4367b88dc2ab40f1a035e387f5405e870/types_docutils-0.22.3.20260408-py3-none-any.whl", hash = "sha256:2545a86966022cdf1468d430b0007eba0837be77974a7f3fafa1b04a6815d531", size = 91981, upload-time = "2026-04-08T04:27:25.934Z" }, ] [[package]] name = "types-flask-cors" -version = "6.0.0.20260402" +version = "6.0.0.20260408" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "flask" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/b5/59/84d8ed3801cbf28876067387e1055467e94e3dd404e93e35fe2ec5e46729/types_flask_cors-6.0.0.20260402.tar.gz", hash = "sha256:57350b504328df7ec13a12599e67939189cb644c5d0efec9af80ed03c592052c", size = 10126, upload-time = "2026-04-02T04:20:57.954Z" } +sdist = { url = "https://files.pythonhosted.org/packages/22/68/e58191af5b56e836a4a2e2583ecfad91bde176940edf1bfc8ea706a5f74d/types_flask_cors-6.0.0.20260408.tar.gz", hash = "sha256:8c440c158c335819bb9286870c9770687ae6d943510fdd97e4b573324f8d2178", size = 10223, upload-time = "2026-04-08T04:35:42.608Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/51/71/d86f7644a18a8ccdddf50b9969fc94abbecd0ac52594880dc5667ca53e5e/types_flask_cors-6.0.0.20260402-py3-none-any.whl", hash = "sha256:e018d34946c110f5acfa71cc708ec66b47c4292131647e54889600c20892ca26", size = 9990, upload-time = "2026-04-02T04:20:57.12Z" }, + { url = "https://files.pythonhosted.org/packages/a2/8d/eb905e231aaed6c0853f002446a1fb12d5a32d79b688f2cdd4f8d6e6ce03/types_flask_cors-6.0.0.20260408-py3-none-any.whl", hash = "sha256:ccd8801862b3ebd27754734b84fc3dcfebd0f8056380ae88254c7dd799d64a39", size = 9993, upload-time = "2026-04-08T04:35:41.452Z" }, ] [[package]] name = "types-flask-migrate" -version = "4.1.0.20260402" +version = "4.1.0.20260408" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "flask" }, { name = "flask-sqlalchemy" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/a8/85/291317e13f72d5b2b6c1fe2c59c77a45d07bb225bf5bb2768da6a7b96351/types_flask_migrate-4.1.0.20260402.tar.gz", hash = "sha256:8e0062f063ecbe5c73b53ffc1e86f4d6de5ab970142c7d2dea939c5680ba817a", size = 8717, upload-time = "2026-04-02T04:21:45.77Z" } +sdist = { url = "https://files.pythonhosted.org/packages/46/31/56f5607fca2ad4e41da095b0e22ce9749d29d985df8c0229907bf662413b/types_flask_migrate-4.1.0.20260408.tar.gz", hash = "sha256:65ef927584777eac9a4591eb8320c09eb6eb8862d2ffdd6e23ad485a2869b228", size = 8773, upload-time = "2026-04-08T04:36:22.197Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/d4/d9/716b9cb9fca0f87e95f573e21e5ffe83d1cf9919ceb2e1cca8bc71488746/types_flask_migrate-4.1.0.20260402-py3-none-any.whl", hash = "sha256:6989d40d3cfae1c5f70c8f20ba39e714949b633329cc23b2dd00e82fd5b07d1c", size = 8669, upload-time = "2026-04-02T04:21:44.967Z" }, + { url = "https://files.pythonhosted.org/packages/cd/4b/1d1b300251d33f1a97004ef8bba53139116b00872cfb85521d1412259627/types_flask_migrate-4.1.0.20260408-py3-none-any.whl", hash = "sha256:ffdacb78f6697422aa09bdebba34f4133b2443a95a59761c279fc5d368c009d9", size = 8670, upload-time = "2026-04-08T04:36:21.394Z" }, ] [[package]] name = "types-gevent" -version = "25.9.0.20260402" +version = "26.4.0.20260409" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "types-greenlet" }, { name = "types-psutil" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/1c/2f/a2056079f14aeacf538b51b0e6585328c3584fa8e6f4758214c9773ea4b0/types_gevent-25.9.0.20260402.tar.gz", hash = "sha256:24297e6f5733e187a517f08dde6df7b2147e14f7de4d343148f410dffebb5381", size = 38270, upload-time = "2026-04-02T04:22:00.125Z" } +sdist = { url = "https://files.pythonhosted.org/packages/a4/ea/17fa935aa62d45cb9f67947e93c3c0c1ed97a76d579b12e1623cd348b68a/types_gevent-26.4.0.20260409.tar.gz", hash = "sha256:6b029c599fe4ec0efce8cd2bf5e5ae958d9808aa5b2f7bdfcb9b9eb42d91cc6a", size = 38333, upload-time = "2026-04-09T04:22:42.334Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/9e/2f/995920b5cc58bc9041ded8ea2fda32719f6c513bc6e43a0c5234780936db/types_gevent-25.9.0.20260402-py3-none-any.whl", hash = "sha256:178ba12e426c987dd69ef0b8ce9f1095a965103a0d673294831f49f7127bc5ba", size = 55494, upload-time = "2026-04-02T04:21:59.144Z" }, + { url = "https://files.pythonhosted.org/packages/29/17/04671a7e3de8c0fdd4c39dc43830b496ad68998d37cff38e0d9701f77a67/types_gevent-26.4.0.20260409-py3-none-any.whl", hash = "sha256:f5f5eb7365a9b8b738787a2dc93c509ee0ca919c6d4388504f2cd09e476d4066", size = 55491, upload-time = "2026-04-09T04:22:41.226Z" }, ] [[package]] name = "types-greenlet" -version = "3.3.0.20251206" +version = "3.4.0.20260409" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/fc/d3/23f4ab29a5ce239935bb3c157defcf50df8648c16c65965fae03980d67f3/types_greenlet-3.3.0.20251206.tar.gz", hash = "sha256:3e1ab312ab7154c08edc2e8110fbf00d9920323edc1144ad459b7b0052063055", size = 8901, upload-time = "2025-12-06T03:01:38.634Z" } +sdist = { url = "https://files.pythonhosted.org/packages/27/a6/668751bc864efe820e1eb12c2a77f9e62537f433cc002e483ad01badb04b/types_greenlet-3.4.0.20260409.tar.gz", hash = "sha256:81d2cf628934a16856bb9e54136def8de5356e934f0ad5d5474f219a0c5cb205", size = 8976, upload-time = "2026-04-09T04:22:31.693Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/7c/8f/aabde1b6e49b25a6804c12a707829e44ba0f5520563c09271f05d3196142/types_greenlet-3.3.0.20251206-py3-none-any.whl", hash = "sha256:8d11041c0b0db545619e8c8a1266aa4aaa4ebeae8ae6b4b7049917a6045a5590", size = 8809, upload-time = "2025-12-06T03:01:37.651Z" }, + { url = "https://files.pythonhosted.org/packages/4f/3f/c8a4d8782f78fccb4b5fe91c5eae2efce6648072754bc7096b1e3b5407ad/types_greenlet-3.4.0.20260409-py3-none-any.whl", hash = "sha256:cbceadb4594eccd95b57b3f7fa8a9b851488f5e6c05026f4a3db9aac02ec8333", size = 8812, upload-time = "2026-04-09T04:22:30.734Z" }, ] [[package]] name = "types-html5lib" -version = "1.1.11.20260402" +version = "1.1.11.20260408" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "types-webencodings" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/13/95/74eabb3bd0bb2f2b3a8ba56a55e87ee4b76f2b39e2a690eca399deffc837/types_html5lib-1.1.11.20260402.tar.gz", hash = "sha256:a167a30b9619a6eea82ec8b8948044859e033966a4721db34187d647c3a6c1f3", size = 18268, upload-time = "2026-04-02T04:21:56.528Z" } +sdist = { url = "https://files.pythonhosted.org/packages/16/59/914d00107c770e49fa57d4c4572e0371bbce14321385fd2ea3e06691b62d/types_html5lib-1.1.11.20260408.tar.gz", hash = "sha256:8a281aa367bc77dbc758358cd9bef79530f2d154eeed9b33705bb035a0dab9e4", size = 18316, upload-time = "2026-04-08T04:35:49.581Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/79/a9/fac9d4313b1851620610f46d086ba288482c0d5384ebf6feafb5bc4bdd15/types_html5lib-1.1.11.20260402-py3-none-any.whl", hash = "sha256:245d02cf53ef62d7342268c53dbc2af2d200849feec03f77f5909655cb54ab0d", size = 24314, upload-time = "2026-04-02T04:21:55.659Z" }, + { url = "https://files.pythonhosted.org/packages/61/19/12d95e98e42e120522665ec6850b38df8d2c1cca94e21c4d7f8578acb64e/types_html5lib-1.1.11.20260408-py3-none-any.whl", hash = "sha256:d18dc4b90d6d6745585790b920db13ede43e1f8ff6ee1ac0ceb0dec4223a06fa", size = 24313, upload-time = "2026-04-08T04:35:48.679Z" }, ] [[package]] name = "types-jmespath" -version = "1.1.0.20260124" +version = "1.1.0.20260408" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/2b/ca/c8d7fc6e450c2f8fc6f510cb194754c43b17f933f2dcabcfc6985cbb97a8/types_jmespath-1.1.0.20260124.tar.gz", hash = "sha256:29d86868e72c0820914577077b27d167dcab08b1fc92157a29d537ff7153fdfe", size = 10709, upload-time = "2026-01-24T03:18:46.557Z" } +sdist = { url = "https://files.pythonhosted.org/packages/a1/5e/33881ff525fbaa71cb6192d81fd4039607006ff48f85c40ef1e20d72d1d3/types_jmespath-1.1.0.20260408.tar.gz", hash = "sha256:42483cfc3d16bdd88c1150a7419d59ef59b8bdc4db3eec8ebf6971a0dad1a425", size = 10733, upload-time = "2026-04-08T04:29:22.923Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/61/91/915c4a6e6e9bd2bca3ec0c21c1771b175c59e204b85e57f3f572370fe753/types_jmespath-1.1.0.20260124-py3-none-any.whl", hash = "sha256:ec387666d446b15624215aa9cbd2867ffd885b6c74246d357c65e830c7a138b3", size = 11509, upload-time = "2026-01-24T03:18:45.536Z" }, + { url = "https://files.pythonhosted.org/packages/e4/f8/4c34097ce72dc8ea533db26a0162c53837398b26d4a0645ca3c7df74370b/types_jmespath-1.1.0.20260408-py3-none-any.whl", hash = "sha256:58a29fe039e5d3f9d0d42f1b067b9efa7c3e29c7e6df9c6830cbe5fa44ffb943", size = 11512, upload-time = "2026-04-08T04:29:22.133Z" }, ] [[package]] name = "types-markdown" -version = "3.10.2.20260211" +version = "3.10.2.20260408" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/6d/2e/35b30a09f6ee8a69142408d3ceb248c4454aa638c0a414d8704a3ef79563/types_markdown-3.10.2.20260211.tar.gz", hash = "sha256:66164310f88c11a58c6c706094c6f8c537c418e3525d33b76276a5fbd66b01ce", size = 19768, upload-time = "2026-02-11T04:19:29.497Z" } +sdist = { url = "https://files.pythonhosted.org/packages/dd/0e/a690840934c459aa50e0470e7550d7f151632eafa4a8e3c21d18009ad15c/types_markdown-3.10.2.20260408.tar.gz", hash = "sha256:d5cba15ed65a1420e80e31c17e3d4a2ad7208a3f3a4da97fd2c5f093caf523cd", size = 19784, upload-time = "2026-04-08T04:33:07.644Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/54/c9/659fa2df04b232b0bfcd05d2418e683080e91ec68f636f3c0a5a267350e7/types_markdown-3.10.2.20260211-py3-none-any.whl", hash = "sha256:2d94d08587e3738203b3c4479c449845112b171abe8b5cadc9b0c12fcf3e99da", size = 25854, upload-time = "2026-02-11T04:19:28.647Z" }, + { url = "https://files.pythonhosted.org/packages/75/7e/265a8df257c8dced6ea89295f793a19f0a49ccbfeae1ed562368b2caf7a3/types_markdown-3.10.2.20260408-py3-none-any.whl", hash = "sha256:b0bbe8b7a8174db732067b86e391262898f5f536589ea81efec6d35ceb829331", size = 25857, upload-time = "2026-04-08T04:33:06.769Z" }, ] [[package]] name = "types-oauthlib" -version = "3.3.0.20260324" +version = "3.3.0.20260408" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/91/38/543938f86d81bd6a78b8c355fe81bb8da0a26e4c28addfe3443e38a683d2/types_oauthlib-3.3.0.20260324.tar.gz", hash = "sha256:3c4cc07fa33886f881682237c1e445c5f1778b44efea118f4c1e4ede82cb52f2", size = 26030, upload-time = "2026-03-24T04:06:30.898Z" } +sdist = { url = "https://files.pythonhosted.org/packages/79/7e/4cf7b08b4c6b266d9967c02ebdba8c5390029d5750def924b23679a730a0/types_oauthlib-3.3.0.20260408.tar.gz", hash = "sha256:deaeccbc33634f5efa7ef320924bce743495f5a1520073ce4fa0fea441bf063d", size = 26066, upload-time = "2026-04-08T04:28:45.636Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/0e/60/26f0ddade4b2bb17b3d8f3ebaac436e5487caec28831da3d7ea309fe93b9/types_oauthlib-3.3.0.20260324-py3-none-any.whl", hash = "sha256:d24662033b04f4d50a2f1fed04c1b43ff2554aa037c1dafa0424f87100a46ccd", size = 48984, upload-time = "2026-03-24T04:06:29.696Z" }, + { url = "https://files.pythonhosted.org/packages/d7/77/6866665af7b414bbffd37028a92618d3771402ae587e9cd2d70efcb6d8f6/types_oauthlib-3.3.0.20260408-py3-none-any.whl", hash = "sha256:1c305d18a05636fac4953800aa0982e54c258562838dafeaa6a3d05b7f4669fe", size = 48987, upload-time = "2026-04-08T04:28:44.719Z" }, ] [[package]] name = "types-objgraph" -version = "3.6.0.20240907" +version = "3.6.0.20260408" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/22/48/ba0ec63d392904eee34ef1cbde2d8798f79a3663950e42fbbc25fd1bd6f7/types-objgraph-3.6.0.20240907.tar.gz", hash = "sha256:2e3dee675843ae387889731550b0ddfed06e9420946cf78a4bca565b5fc53634", size = 2928, upload-time = "2024-09-07T02:35:21.214Z" } +sdist = { url = "https://files.pythonhosted.org/packages/82/4b/e43381191b1d9d1a0d8b1d7da12ee28ea63f97f38bc6694231dde066b3c8/types_objgraph-3.6.0.20260408.tar.gz", hash = "sha256:9937aae5ad5bb625a2091b33e2f67e979f61e3719078d318b2261c4e7f13ac9a", size = 7606, upload-time = "2026-04-08T04:30:41.072Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/16/c9/6d647a947f3937b19bcc6d52262921ddad60d90060ff66511a4bd7e990c5/types_objgraph-3.6.0.20240907-py3-none-any.whl", hash = "sha256:67207633a9b5789ee1911d740b269c3371081b79c0d8f68b00e7b8539f5c43f5", size = 3314, upload-time = "2024-09-07T02:35:19.865Z" }, + { url = "https://files.pythonhosted.org/packages/38/99/6f618e0931367814b2ab9ad2b946f0f0ca4b8b02405a3552bcb90acc1b5c/types_objgraph-3.6.0.20260408-py3-none-any.whl", hash = "sha256:fd4ae0c6c10e260a3dfdd778f3f79f081c804f9d48b5bb6c2299d8645b287c5f", size = 8059, upload-time = "2026-04-08T04:30:40.301Z" }, ] [[package]] name = "types-olefile" -version = "0.47.0.20240806" +version = "0.47.0.20260408" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/49/18/9d87a1bc394323ce22690308c751680c4301fc3fbe47cd58e16d760b563a/types-olefile-0.47.0.20240806.tar.gz", hash = "sha256:96490f208cbb449a52283855319d73688ba9167ae58858ef8c506bf7ca2c6b67", size = 4369, upload-time = "2024-08-06T02:30:01.966Z" } +sdist = { url = "https://files.pythonhosted.org/packages/84/d0/1c7e058666a4e5364463ad0a7bfd7a0bbc2180df427016cc7ebeeddb0b29/types_olefile-0.47.0.20260408.tar.gz", hash = "sha256:01fbcb6332152e88486634460c8d59a8c75da9a50d85e0ff6f754c02db3fc23e", size = 9037, upload-time = "2026-04-08T04:30:13.973Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/a9/4d/f8acae53dd95353f8a789a06ea27423ae41f2067eb6ce92946fdc6a1f7a7/types_olefile-0.47.0.20240806-py3-none-any.whl", hash = "sha256:c760a3deab7adb87a80d33b0e4edbbfbab865204a18d5121746022d7f8555118", size = 4758, upload-time = "2024-08-06T02:30:01.15Z" }, + { url = "https://files.pythonhosted.org/packages/37/20/b88f8f1336fd3772813c21dd536e50d10c4416294539b8b623e769e9b4a2/types_olefile-0.47.0.20260408-py3-none-any.whl", hash = "sha256:2499f110beb659504173dcd66c069a297b210f715ad45a713e097ecd3a265992", size = 9493, upload-time = "2026-04-08T04:30:13.173Z" }, ] [[package]] name = "types-openpyxl" -version = "3.1.5.20260402" +version = "3.1.5.20260408" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/6b/8f/d9daf094e0bb468b26e74c1bf9e0170e58c3f16e583d244e9f32078b6bcc/types_openpyxl-3.1.5.20260402.tar.gz", hash = "sha256:855ad28d47c0965048082dfca424d6ebd54d8861d72abcee9106ba5868899e7f", size = 101310, upload-time = "2026-04-02T04:17:37.6Z" } +sdist = { url = "https://files.pythonhosted.org/packages/74/c9/24f03f9d9fedd164de699c1418869bef9b819f59f75e7f647f5788c02d98/types_openpyxl-3.1.5.20260408.tar.gz", hash = "sha256:b49274d086fbb6e6bcd2a67d161dd1161d4d380488e4cce546d647d45eadcac2", size = 101361, upload-time = "2026-04-08T04:30:37.809Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/58/ee/a0b22012076cf23b73fbb82d9c40843cbf6b1d228d7a2dc883da0a905a16/types_openpyxl-3.1.5.20260402-py3-none-any.whl", hash = "sha256:1d149989f0aad4e2074e96b87a045136399e27bc2a33cfefcd0eb4cad8ea5b4c", size = 166046, upload-time = "2026-04-02T04:17:36.162Z" }, + { url = "https://files.pythonhosted.org/packages/5f/e8/64db0c32c6fda4b198aff84329ceeea00a93d16c28f05e9fe0404ddb8b8c/types_openpyxl-3.1.5.20260408-py3-none-any.whl", hash = "sha256:7ab7586796aed017cde50b81bd67ba024120c39b99c102320b57dd91390c317f", size = 166044, upload-time = "2026-04-08T04:30:36.449Z" }, ] [[package]] name = "types-pexpect" -version = "4.9.0.20260127" +version = "4.9.0.20260408" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/2e/32/7e03a07e16f79a404d6200ed6bdfcc320d0fb833436a5c6895a1403dedb7/types_pexpect-4.9.0.20260127.tar.gz", hash = "sha256:f8d43efc24251a8e533c71ea9be03d19bb5d08af096d561611697af9720cba7f", size = 13461, upload-time = "2026-01-27T03:28:30.923Z" } +sdist = { url = "https://files.pythonhosted.org/packages/9a/0f/5e9aa68e4595264e968ffaa3358afb2a8d60093f460aaa7e0398c0d9bfd0/types_pexpect-4.9.0.20260408.tar.gz", hash = "sha256:faedd97fc8086b224bc1966770c486ac6ec96bef07dc47cc2724fe4ae62f8f4a", size = 13471, upload-time = "2026-04-08T04:32:30.624Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/8a/d9/7ac5c9aa5a89a1a64cd835ae348227f4939406d826e461b85b690a8ba1c2/types_pexpect-4.9.0.20260127-py3-none-any.whl", hash = "sha256:69216c0ebf0fe45ad2900823133959b027e9471e24fc3f2e4c7b00605555da5f", size = 17078, upload-time = "2026-01-27T03:28:29.848Z" }, + { url = "https://files.pythonhosted.org/packages/c2/13/96004a3e5dd6c6e7de4edc421d5a2926e062d22be7b006edab747ed42830/types_pexpect-4.9.0.20260408-py3-none-any.whl", hash = "sha256:ba6699609bb6593f00ef7204efc390fd10bc14a8d632f22c8dea13f263b16fcc", size = 17083, upload-time = "2026-04-08T04:32:29.75Z" }, ] [[package]] name = "types-protobuf" -version = "7.34.1.20260403" +version = "7.34.1.20260408" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/ae/b3/c2e407ea36e0e4355c135127cee1b88a2cc9a2c92eafca50a360ab9f2708/types_protobuf-7.34.1.20260403.tar.gz", hash = "sha256:8d7881867888e667eb9563c08a916fccdc12bdb5f9f34c31d217cce876e36765", size = 68782, upload-time = "2026-04-03T04:18:09.428Z" } +sdist = { url = "https://files.pythonhosted.org/packages/5b/b1/4521e68c2cc17703d80eb42796751345376dd4c706f84007ef5e7c707774/types_protobuf-7.34.1.20260408.tar.gz", hash = "sha256:e2c0a0430e08c75b52671a6f0035abfdcc791aad12af16274282de1b721758ab", size = 68835, upload-time = "2026-04-08T04:26:43.613Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/7d/95/24fb0f6fe37b41cf94f9b9912712645e17d8048d4becaf37c1607ddd8e32/types_protobuf-7.34.1.20260403-py3-none-any.whl", hash = "sha256:16d9bbca52ab0f306279958878567df2520f3f5579059419b0ce149a0ad1e332", size = 86011, upload-time = "2026-04-03T04:18:08.245Z" }, + { url = "https://files.pythonhosted.org/packages/ef/b5/0bc9874d89c58fb0ce851e150055ce732d254dbb10b06becbc7635d0d635/types_protobuf-7.34.1.20260408-py3-none-any.whl", hash = "sha256:ebbcd4e27b145aef6a59bc0cb6c013b3528151c1ba5e7f7337aeee355d276a5e", size = 86012, upload-time = "2026-04-08T04:26:42.566Z" }, ] [[package]] name = "types-psutil" -version = "7.2.2.20260402" +version = "7.2.2.20260408" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/31/a2/a608db0caf0d71bd231305dc3ab3f5d65624d77761003696a3ca8c6fad40/types_psutil-7.2.2.20260402.tar.gz", hash = "sha256:9f36eebf15ad8487f8004ed67c8e008b84b63ba00cfb709a3f60275058217329", size = 26522, upload-time = "2026-04-02T04:18:47.916Z" } +sdist = { url = "https://files.pythonhosted.org/packages/44/14/279fd5defebbd560ede04aecd38f7651cccee7336f2264d0889d8c9a9d43/types_psutil-7.2.2.20260408.tar.gz", hash = "sha256:e8053450685965b8cd52afb62569073d00ea9967ae78bb45dff5f606847f97f2", size = 26556, upload-time = "2026-04-08T04:27:44.349Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/81/8a/f4b3ca3154e8a77df91eb7a28c208af721d48f8a4aca667f582523a0beff/types_psutil-7.2.2.20260402-py3-none-any.whl", hash = "sha256:653d1fd908e68cc0666754b16a0cee28efbded0c401caa5314d2aeea67f227cd", size = 32860, upload-time = "2026-04-02T04:18:46.671Z" }, + { url = "https://files.pythonhosted.org/packages/af/40/2fd92a4a1ee088c4dbcc44c977908d9869838d9cd2a2fa2e001352f56694/types_psutil-7.2.2.20260408-py3-none-any.whl", hash = "sha256:0c334f6f6bc9e9c24fca5c7d1f0b6971c961a0a2e3956dc5ce704722c01f9762", size = 32861, upload-time = "2026-04-08T04:27:42.929Z" }, ] [[package]] name = "types-psycopg2" -version = "2.9.21.20260223" +version = "2.9.21.20260408" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/55/1f/4daff0ce5e8e191844e65aaa793ed1b9cb40027dc2700906ecf2b6bcc0ed/types_psycopg2-2.9.21.20260223.tar.gz", hash = "sha256:78ed70de2e56bc6b5c26c8c1da8e9af54e49fdc3c94d1504609f3519e2b84f02", size = 27090, upload-time = "2026-02-23T04:11:18.177Z" } +sdist = { url = "https://files.pythonhosted.org/packages/cd/24/d8ae11a0c056535557aaabeb7d7838423abdfdcf1e5f8dfb2c04d316c65d/types_psycopg2-2.9.21.20260408.tar.gz", hash = "sha256:bb65cd12f53b6633077fd782607a33065e1f3bf585219c9f786b61ad2b72211c", size = 27078, upload-time = "2026-04-08T04:26:15.848Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/8d/e7/c566df58410bc0728348b514e718f0b38fa0d248b5c10599a11494ba25d2/types_psycopg2-2.9.21.20260223-py3-none-any.whl", hash = "sha256:c6228ade72d813b0624f4c03feeb89471950ac27cd0506b5debed6f053086bc8", size = 24919, upload-time = "2026-02-23T04:11:17.214Z" }, + { url = "https://files.pythonhosted.org/packages/1a/fe/9aab9239640107b6e46afddcee578a916b8b98bfee36e03da5b0d2c95124/types_psycopg2-2.9.21.20260408-py3-none-any.whl", hash = "sha256:49b086bfc9e0ce901c6537403ead1c19c75275571040b037af0248a8e48c322f", size = 24921, upload-time = "2026-04-08T04:26:14.715Z" }, ] [[package]] name = "types-pygments" -version = "2.20.0.20260406" +version = "2.20.0.20260408" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "types-docutils" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/08/bd/d17c28a4c65c556bc4c4bc8f363aa2fbfc91b397e3c0019839d74d9ead31/types_pygments-2.20.0.20260406.tar.gz", hash = "sha256:d3ed7ecd7c34a382459d28ce624b87e1dee03d6844e43aa7590ef4b8c7c9dfce", size = 19486, upload-time = "2026-04-06T04:33:59.632Z" } +sdist = { url = "https://files.pythonhosted.org/packages/fe/89/4b443128fa540c54a8f7ecdeec225aab4818534167c4a2d133099dc00fa6/types_pygments-2.20.0.20260408.tar.gz", hash = "sha256:e8a56a3ab1aee7f4ed8f1876d2f62c96e0f41ede52405a7d30c888f3989d8f00", size = 21115, upload-time = "2026-04-08T04:34:24.29Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/eb/00/dca7518e6f99ce0f235ec1c6512593ee4bd25109ae1c912bf9ee836a26e1/types_pygments-2.20.0.20260406-py3-none-any.whl", hash = "sha256:6bb0c79874c304977e1c097f7007140e16fe78c443329154db803d7910d945b3", size = 27278, upload-time = "2026-04-06T04:33:58.744Z" }, + { url = "https://files.pythonhosted.org/packages/5e/d8/30924b38eef70caef6b05af5440c84d7673cea2a042e206f404c8100a88d/types_pygments-2.20.0.20260408-py3-none-any.whl", hash = "sha256:6d347d5967b5f0654b659a8b8461a870b207b7e60cd4d646bbc047f6a8db8e1e", size = 29055, upload-time = "2026-04-08T04:34:23.412Z" }, ] [[package]] name = "types-pymysql" -version = "1.1.0.20251220" +version = "1.1.0.20260408" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/d3/59/e959dd6d2f8e3b3c3f058d79ac9ece328922a5a8770c707fe9c3a757481c/types_pymysql-1.1.0.20251220.tar.gz", hash = "sha256:ae1c3df32a777489431e2e9963880a0df48f6591e0aa2fd3a6fabd9dee6eca54", size = 22184, upload-time = "2025-12-20T03:07:38.689Z" } +sdist = { url = "https://files.pythonhosted.org/packages/b3/04/c3570f05ebab083f28698c829dddf754ffefc30aae4e29915610848e44db/types_pymysql-1.1.0.20260408.tar.gz", hash = "sha256:b784dc37908479e3767e2d794ab507b3674adb1c686ca3d13fc9e2960dbcb9ec", size = 22344, upload-time = "2026-04-08T04:27:47.651Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/8b/fa/4f4d3bfca9ef6dd17d69ed18b96564c53b32d3ce774132308d0bee849f10/types_pymysql-1.1.0.20251220-py3-none-any.whl", hash = "sha256:fa1082af7dea6c53b6caa5784241924b1296ea3a8d3bd060417352c5e10c0618", size = 23067, upload-time = "2025-12-20T03:07:37.766Z" }, + { url = "https://files.pythonhosted.org/packages/70/b3/15dee33878709705a4cc83bcc1bb30e00e95bbe038b472cb1207a15b50a1/types_pymysql-1.1.0.20260408-py3-none-any.whl", hash = "sha256:da630647eaaa7a926a3907794f4067f269cd245b2c202c74aa3c6a3bd660a9db", size = 23071, upload-time = "2026-04-08T04:27:46.735Z" }, ] [[package]] @@ -6234,38 +6223,38 @@ wheels = [ [[package]] name = "types-python-dateutil" -version = "2.9.0.20260402" +version = "2.9.0.20260408" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/a7/30/c5d9efbff5422b20c9551dc5af237d1ab0c3d33729a9b3239a876ca47dd4/types_python_dateutil-2.9.0.20260402.tar.gz", hash = "sha256:a980142b9966713acb382c467e35c5cc4208a2f91b10b8d785a0ae6765df6c0b", size = 16941, upload-time = "2026-04-02T04:18:35.834Z" } +sdist = { url = "https://files.pythonhosted.org/packages/88/f3/2427775f80cd5e19a0a71ba8e5ab7645a01a852f43a5fd0ffc24f66338e0/types_python_dateutil-2.9.0.20260408.tar.gz", hash = "sha256:8b056ec01568674235f64ecbcef928972a5fac412f5aab09c516dfa2acfbb582", size = 16981, upload-time = "2026-04-08T04:28:10.995Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/e6/d7/fe753bf8329c8c3c1addcba1d2bf716c33898216757abb24f8b80f82d040/types_python_dateutil-2.9.0.20260402-py3-none-any.whl", hash = "sha256:7827e6a9c93587cc18e766944254d1351a2396262e4abe1510cbbd7601c5e01f", size = 18436, upload-time = "2026-04-02T04:18:34.806Z" }, + { url = "https://files.pythonhosted.org/packages/fd/c6/eeba37bfee282a6a97f889faef9352d6172c6a5088eb9a4daf570d9d748d/types_python_dateutil-2.9.0.20260408-py3-none-any.whl", hash = "sha256:473139d514a71c9d1fbd8bb328974bedcb1cc3dba57aad04ffa4157f483c216f", size = 18437, upload-time = "2026-04-08T04:28:10.095Z" }, ] [[package]] name = "types-python-http-client" -version = "3.3.7.20250708" +version = "3.3.7.20260408" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/55/a0/0ad93698a3ebc6846ca23aca20ff6f6f8ebe7b4f0c1de7f19e87c03dbe8f/types_python_http_client-3.3.7.20250708.tar.gz", hash = "sha256:5f85b32dc64671a4e5e016142169aa187c5abed0b196680944e4efd3d5ce3322", size = 7707, upload-time = "2025-07-08T03:14:36.197Z" } +sdist = { url = "https://files.pythonhosted.org/packages/2a/30/f741f5edce6b02a838a30064360f5480510d5f2861561f44c5e33bc1dd96/types_python_http_client-3.3.7.20260408.tar.gz", hash = "sha256:ae84aadeec645ede7e602714090e6c8ebca97dbf28af509ac5eeccfc300174a2", size = 7684, upload-time = "2026-04-08T04:27:09.67Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/85/4f/b88274658cf489e35175be8571c970e9a1219713bafd8fc9e166d7351ecb/types_python_http_client-3.3.7.20250708-py3-none-any.whl", hash = "sha256:e2fc253859decab36713d82fc7f205868c3ddeaee79dbb55956ad9ca77abe12b", size = 8890, upload-time = "2025-07-08T03:14:35.506Z" }, + { url = "https://files.pythonhosted.org/packages/f1/d8/45c6c04924086e8856e7f9a33a38ee713992d9ae9cd6d449de97badcba3c/types_python_http_client-3.3.7.20260408-py3-none-any.whl", hash = "sha256:3f310282e0fe2a18c5291f935538e1f97b9f80d2c5571aad155e66806719017c", size = 8851, upload-time = "2026-04-08T04:27:08.877Z" }, ] [[package]] name = "types-pywin32" -version = "311.0.0.20260402" +version = "311.0.0.20260408" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/b3/f0/fc3c923b5d7822f3a93c7b242a69de0e1945e7c153cc5367074621a6509f/types_pywin32-311.0.0.20260402.tar.gz", hash = "sha256:637f041065f02fb49cbaba530ae8cf2e483b5d2c145a9bf97fd084c3e913c7e3", size = 332312, upload-time = "2026-04-02T04:18:52.748Z" } +sdist = { url = "https://files.pythonhosted.org/packages/30/40/0d182fbf578f30f7ff2b07b8fe494cc42178992d0087a739c70990adca8c/types_pywin32-311.0.0.20260408.tar.gz", hash = "sha256:cb86c6beae20195165e770a65c3ee707746dc777ca8e03e4f06a66d4013a4bd0", size = 332341, upload-time = "2026-04-08T04:33:29.824Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/80/0c/a2ee20785df4ebcda6d6ec62d58b7c08a37072f9d00cda4f9548e9c8e5aa/types_pywin32-311.0.0.20260402-py3-none-any.whl", hash = "sha256:4db644fcf40ee85a3ee2551f110d009e427c01569ed4670bb53cfe999df0929f", size = 395413, upload-time = "2026-04-02T04:18:51.529Z" }, + { url = "https://files.pythonhosted.org/packages/0d/b5/3cc67baf622805270d84e2252dfa130daf7ccd49795f80b51350abb91bd9/types_pywin32-311.0.0.20260408-py3-none-any.whl", hash = "sha256:0b691da60aaed0ee7169a69268bad1e2051eb52f4acc10248c103aadcd1f2451", size = 395413, upload-time = "2026-04-08T04:33:28.476Z" }, ] [[package]] name = "types-pyyaml" -version = "6.0.12.20250915" +version = "6.0.12.20260408" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/7e/69/3c51b36d04da19b92f9e815be12753125bd8bc247ba0470a982e6979e71c/types_pyyaml-6.0.12.20250915.tar.gz", hash = "sha256:0f8b54a528c303f0e6f7165687dd33fafa81c807fcac23f632b63aa624ced1d3", size = 17522, upload-time = "2025-09-15T03:01:00.728Z" } +sdist = { url = "https://files.pythonhosted.org/packages/74/73/b759b1e413c31034cc01ecdfb96b38115d0ab4db55a752a3929f0cd449fd/types_pyyaml-6.0.12.20260408.tar.gz", hash = "sha256:92a73f2b8d7f39ef392a38131f76b970f8c66e4c42b3125ae872b7c93b556307", size = 17735, upload-time = "2026-04-08T04:30:50.974Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/bd/e0/1eed384f02555dde685fff1a1ac805c1c7dcb6dd019c916fe659b1c1f9ec/types_pyyaml-6.0.12.20250915-py3-none-any.whl", hash = "sha256:e7d4d9e064e89a3b3cae120b4990cd370874d2bf12fa5f46c97018dd5d3c9ab6", size = 20338, upload-time = "2025-09-15T03:00:59.218Z" }, + { url = "https://files.pythonhosted.org/packages/1c/f0/c391068b86abb708882c6d75a08cd7d25b2c7227dab527b3a3685a3c635b/types_pyyaml-6.0.12.20260408-py3-none-any.whl", hash = "sha256:fbc42037d12159d9c801ebfcc79ebd28335a7c13b08a4cfbc6916df78fee9384", size = 20339, upload-time = "2026-04-08T04:30:50.113Z" }, ] [[package]] @@ -6283,11 +6272,11 @@ wheels = [ [[package]] name = "types-regex" -version = "2026.4.4.20260405" +version = "2026.4.4.20260408" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/74/9c/dd7b36fe87902a161a69c4a6959e3a6afae09c2c600916beb1aecd300870/types_regex-2026.4.4.20260405.tar.gz", hash = "sha256:993b76a255d9b83fd68eed2fc52b2746be51a93b833796be4fcf9412efa0da51", size = 13143, upload-time = "2026-04-05T04:26:56.614Z" } +sdist = { url = "https://files.pythonhosted.org/packages/92/42/d7c691fc5a8a8ecfba3f23c1c4c087a089af0767610d88c29201193d8f60/types_regex-2026.4.4.20260408.tar.gz", hash = "sha256:86b2975ff11b06e7f538839821510daea2566d9cb18bb8acde47834315409cf9", size = 13182, upload-time = "2026-04-08T04:31:11.887Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/51/83/5dbae203616699890efcdb2a2670d62baf5ed93634f75d793157f1edefb3/types_regex-2026.4.4.20260405-py3-none-any.whl", hash = "sha256:40443cb88c43b9940dd4c904e251be7e65dab3798b2cf6f5ff19501ae99b2ab5", size = 11119, upload-time = "2026-04-05T04:26:55.636Z" }, + { url = "https://files.pythonhosted.org/packages/e1/92/e109654a804d11d9b60d67c7b29d64b2beac6b2e3209ea075e268e5a1021/types_regex-2026.4.4.20260408-py3-none-any.whl", hash = "sha256:d436bcc409abf9b06747b7e038014afc6d40ef7b72329655c353a1955534068f", size = 11116, upload-time = "2026-04-08T04:31:11.01Z" }, ] [[package]] @@ -6313,67 +6302,67 @@ wheels = [ [[package]] name = "types-setuptools" -version = "82.0.0.20260402" +version = "82.0.0.20260408" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/e9/f8/74f8a76b4311e70772c0df8f2d432040a3b0facd7bcce6b72b0b26e1746b/types_setuptools-82.0.0.20260402.tar.gz", hash = "sha256:63d2b10ba7958396ad79bbc24d2f6311484e452daad4637ffd40407983a27069", size = 44805, upload-time = "2026-04-02T04:17:49.229Z" } +sdist = { url = "https://files.pythonhosted.org/packages/c3/12/3464b410c50420dd4674fa5fe9d3880711c1dbe1a06f5fe4960ee9067b9e/types_setuptools-82.0.0.20260408.tar.gz", hash = "sha256:036c68caf7e672a699f5ebbf914708d40644c14e05298bc49f7272be91cf43d3", size = 44861, upload-time = "2026-04-08T04:29:33.292Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/0e/e9/22451997f70ac2c5f18dc5f988750c986011fb049d9021767277119e63fa/types_setuptools-82.0.0.20260402-py3-none-any.whl", hash = "sha256:4b9a9f6c3c4c65107a3956ad6a6acbccec38e398ff6d5f78d5df7f103dadb8d6", size = 68429, upload-time = "2026-04-02T04:17:48.11Z" }, + { url = "https://files.pythonhosted.org/packages/3d/e1/46a4fc3ef03aabf5d18bac9df5cf37c6b02c3bddf3e05c3533f4b4588331/types_setuptools-82.0.0.20260408-py3-none-any.whl", hash = "sha256:ece0a215cdfa6463a65fd6f68bd940f39e455729300ddfe61cab1147ed1d2462", size = 68428, upload-time = "2026-04-08T04:29:32.175Z" }, ] [[package]] name = "types-shapely" -version = "2.1.0.20260402" +version = "2.1.0.20260408" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "numpy" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/a3/f7/46e95b09434105d7b772d05657495f2900bae8e108fdf4e6d8b5902aa28c/types_shapely-2.1.0.20260402.tar.gz", hash = "sha256:0eb592328170433b4724430a64c309bf07ba69d5d11489d3dba21382d78f5297", size = 26481, upload-time = "2026-04-02T04:20:03.104Z" } +sdist = { url = "https://files.pythonhosted.org/packages/10/8d/bf9e3eb51249601e22d797481999a06fb34998c4db5c76804394f8a3fa28/types_shapely-2.1.0.20260408.tar.gz", hash = "sha256:8552549d9429baa52ec4331e43b5db3b334fc3a7f30da48663010b7454b1451c", size = 26529, upload-time = "2026-04-08T04:34:42.111Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/14/3a/1aa3a62f5b85d4a9e649e7b42842a9e5503fef7eb50c480137a6b94f8bb1/types_shapely-2.1.0.20260402-py3-none-any.whl", hash = "sha256:8d70a16f615a104fd8abdd73e684d4e83b9dedf31d6432ecf86945b5ef0e35de", size = 37817, upload-time = "2026-04-02T04:20:02.17Z" }, + { url = "https://files.pythonhosted.org/packages/8e/3d/cbec691f56e71636192a07bf6809f598bed06d869b03b4e2b1ad2f7df032/types_shapely-2.1.0.20260408-py3-none-any.whl", hash = "sha256:8a31e2b074342a363f0c9d0c7d6e1e6c0dcce302a92ef94d64d0ca2a2b94a1d1", size = 37818, upload-time = "2026-04-08T04:34:41.243Z" }, ] [[package]] name = "types-simplejson" -version = "3.20.0.20260402" +version = "3.20.0.20260408" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/94/93/2ff2f4b8ccd942ee3a4b62c013d2c1779e416d303950060ed8b3f1a4fc11/types_simplejson-3.20.0.20260402.tar.gz", hash = "sha256:ee2bbf65830fe93270a1c0406f3474c952fe1232532c7b6f3eb9500edb308c5a", size = 10650, upload-time = "2026-04-02T04:19:26.266Z" } +sdist = { url = "https://files.pythonhosted.org/packages/9f/36/e319fd0f6d906dbf7c2c03eef17db77ef461197a75b253fccd9c7c695d3e/types_simplejson-3.20.0.20260408.tar.gz", hash = "sha256:0b0e1bf61e70f81dfe6ef4c2b9c02e39403848c0652df334e7a430c3a26c06b3", size = 10693, upload-time = "2026-04-08T04:28:07.8Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/2c/2a/7ba2bede9c2b25fb338d0bda9925a23b73a5ac99fd97304ebe067c090e33/types_simplejson-3.20.0.20260402-py3-none-any.whl", hash = "sha256:b3bdef21bc24fee26b80385ffea5163b6b10381089aa619fe2f8f8d3790e6148", size = 10419, upload-time = "2026-04-02T04:19:25.464Z" }, + { url = "https://files.pythonhosted.org/packages/22/c0/01a5a4c3948c2269cf9d727e5e66a8b404e03beb4f9522680a3f71097011/types_simplejson-3.20.0.20260408-py3-none-any.whl", hash = "sha256:f9e542199cb159ed34ad54b6ceb3dc9af890c256b810ad1bd7c69c61db7d2236", size = 10415, upload-time = "2026-04-08T04:28:06.984Z" }, ] [[package]] name = "types-six" -version = "1.17.0.20251009" +version = "1.17.0.20260408" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/ee/f7/448215bc7695cfa0c8a7e0dcfa54fe31b1d52fb87004fed32e659dd85c80/types_six-1.17.0.20251009.tar.gz", hash = "sha256:efe03064ecd0ffb0f7afe133990a2398d8493d8d1c1cc10ff3dfe476d57ba44f", size = 15552, upload-time = "2025-10-09T02:54:26.02Z" } +sdist = { url = "https://files.pythonhosted.org/packages/9d/95/14bb40b2fa8f19234d60b370bfa1ff64b42509b6d2dee070132949ce4f80/types_six-1.17.0.20260408.tar.gz", hash = "sha256:b28579aedb204d07abac52e49c87e2b4c03cb6171bd764bd9b7775ba58fffaba", size = 15766, upload-time = "2026-04-08T04:26:23.54Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/b8/2f/94baa623421940e3eb5d2fc63570ebb046f2bb4d9573b8787edab3ed2526/types_six-1.17.0.20251009-py3-none-any.whl", hash = "sha256:2494f4c2a58ada0edfe01ea84b58468732e43394c572d9cf5b1dd06d86c487a3", size = 19935, upload-time = "2025-10-09T02:54:25.096Z" }, + { url = "https://files.pythonhosted.org/packages/83/04/3e9c382043579b5170c3bc38d13154d48e8ef2c89c4473748a33e3c9bccd/types_six-1.17.0.20260408-py3-none-any.whl", hash = "sha256:02208fa1099944ed0c8f8de42f065ffd63c55cd7b59f49be49802626b8d58318", size = 19937, upload-time = "2026-04-08T04:26:22.259Z" }, ] [[package]] name = "types-tensorflow" -version = "2.18.0.20260402" +version = "2.18.0.20260408" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "numpy" }, { name = "types-protobuf" }, { name = "types-requests" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/b9/d9/1ca68336ce7ad8c4a19001fce85f47ffae9d7ac335e5ddd73497b6bfbca4/types_tensorflow-2.18.0.20260402.tar.gz", hash = "sha256:607c4a5895d44c88c7c465410093ee050aa760c3cedab5b9662f475c5e2137d3", size = 259058, upload-time = "2026-04-02T04:22:39.113Z" } +sdist = { url = "https://files.pythonhosted.org/packages/cb/15/d9f1a54e75008fde3dc48f333b4d3c86f0d27b822e3a9c109214f8957ae6/types_tensorflow-2.18.0.20260408.tar.gz", hash = "sha256:68bfbcc76dd9e314eae0a91964edf463c52fc0e3d60189542efbf67006e71015", size = 259103, upload-time = "2026-04-08T04:36:45.263Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/c1/6c/0ad58c7246a5369ceb2ae16c146ac0684a0827f499a8141fc3d13743c38b/types_tensorflow-2.18.0.20260402-py3-none-any.whl", hash = "sha256:0d4a74921c457ade8f46eb09cf728a1732156678e497ce15a88b9c0c16dc2fe5", size = 329776, upload-time = "2026-04-02T04:22:37.903Z" }, + { url = "https://files.pythonhosted.org/packages/11/64/4005df91e916f586d9f80c3f052f2ae41afbcd9c9a54d33005fabeefcaab/types_tensorflow-2.18.0.20260408-py3-none-any.whl", hash = "sha256:01cff182dd6c38c300b27b9d1a26791f04607d914fa9429e5f85766c3bc0d71d", size = 329775, upload-time = "2026-04-08T04:36:43.863Z" }, ] [[package]] name = "types-tqdm" -version = "4.67.3.20260402" +version = "4.67.3.20260408" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "types-requests" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/54/42/e9e6688891d8db77b5795ec02b329524170892ff81bec63c4c4ca7425b30/types_tqdm-4.67.3.20260402.tar.gz", hash = "sha256:e0739f3bc5d1c801999a202f0537280aa1bc2e669c49f5be91bfb99376690624", size = 18077, upload-time = "2026-04-02T04:22:23.049Z" } +sdist = { url = "https://files.pythonhosted.org/packages/43/42/2e2968e68a694d3dac3a47aa0df06e46be1a6eef498e5bd15f4c54674eb9/types_tqdm-4.67.3.20260408.tar.gz", hash = "sha256:fd849a79891ae7136ed47541aface15c35bd9a13160fa8a93e42e10f60cf4c8d", size = 18119, upload-time = "2026-04-08T04:36:52.488Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/4f/73/a6cf75de5be376d7b57ce6c934ae9bc90aa5be6ada4ac50a99ecbdf9763e/types_tqdm-4.67.3.20260402-py3-none-any.whl", hash = "sha256:b5d1a65fe3286e1a855e51ddebf63d3641daf9bad285afd1ec56808eb59df76e", size = 24562, upload-time = "2026-04-02T04:22:22.114Z" }, + { url = "https://files.pythonhosted.org/packages/14/5d/7dedddc32ab7bc2344ece772b5e0f03ec63a1d47ad259696689713c1cf50/types_tqdm-4.67.3.20260408-py3-none-any.whl", hash = "sha256:3b9ed74ebef04df8f53d470ffdc84348e93496d8acafa08bf79fafce0f2f5b5d", size = 24561, upload-time = "2026-04-08T04:36:51.538Z" }, ] [[package]] @@ -6783,20 +6772,19 @@ wheels = [ [[package]] name = "weaviate-client" -version = "4.20.4" +version = "4.20.5" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "authlib" }, - { name = "deprecation" }, { name = "grpcio" }, { name = "httpx" }, { name = "protobuf" }, { name = "pydantic" }, { name = "validators" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/c9/1c/82b560254f612f95b644849d86e092da6407f17965d61e22b583b30b72cf/weaviate_client-4.20.4.tar.gz", hash = "sha256:08703234b59e4e03739f39e740e9e88cb50cd0aa147d9408b88ea6ce995c37b6", size = 809529, upload-time = "2026-03-10T15:08:13.845Z" } +sdist = { url = "https://files.pythonhosted.org/packages/81/c8/aa47cfa0a2b1e260846eaf04ce4cc2ab1bb03f29d793e7b009bc3e3babc7/weaviate_client-4.20.5.tar.gz", hash = "sha256:c07c688f0e6b78723dfecbcfeebf897cefa75f1a89c63ebd84aab88c662e4394", size = 811866, upload-time = "2026-04-09T20:08:45.268Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/1d/d7/9461c3e7d8c44080d2307078e33dc7fefefa3171c8f930f2b83a5cbf67f2/weaviate_client-4.20.4-py3-none-any.whl", hash = "sha256:7af3a213bebcb30dcf456b0db8b6225d8926106b835d7b883276de9dc1c301fe", size = 619517, upload-time = "2026-03-10T15:08:12.047Z" }, + { url = "https://files.pythonhosted.org/packages/58/ba/d55f1a665802f736436d09198afc0d00806a405aadb9977193a2f009cfcb/weaviate_client-4.20.5-py3-none-any.whl", hash = "sha256:3f508e3dc08257f85230f9d2ea0562443ed0715e7e89156f22b7e950d6c08cdb", size = 620766, upload-time = "2026-04-09T20:08:43.215Z" }, ] [[package]] diff --git a/web/app/account/(commonLayout)/account-page/AvatarWithEdit.tsx b/web/app/account/(commonLayout)/account-page/AvatarWithEdit.tsx index ccd2dd53cc..d462ca6449 100644 --- a/web/app/account/(commonLayout)/account-page/AvatarWithEdit.tsx +++ b/web/app/account/(commonLayout)/account-page/AvatarWithEdit.tsx @@ -2,7 +2,7 @@ import type { Area } from 'react-easy-crop' import type { OnImageInput } from '@/app/components/base/app-icon-picker/ImageInput' -import type { AvatarProps } from '@/app/components/base/avatar' +import type { AvatarProps } from '@/app/components/base/ui/avatar' import type { ImageFile } from '@/types/app' import { RiDeleteBin5Line, RiPencilLine } from '@remixicon/react' import * as React from 'react' @@ -10,10 +10,10 @@ import { useCallback, useState } from 'react' import { useTranslation } from 'react-i18next' import ImageInput from '@/app/components/base/app-icon-picker/ImageInput' import getCroppedImg from '@/app/components/base/app-icon-picker/utils' -import { Avatar } from '@/app/components/base/avatar' import Button from '@/app/components/base/button' import Divider from '@/app/components/base/divider' import { useLocalFileUploader } from '@/app/components/base/image-uploader/hooks' +import { Avatar } from '@/app/components/base/ui/avatar' import { Dialog, DialogContent } from '@/app/components/base/ui/dialog' import { toast } from '@/app/components/base/ui/toast' import { DISABLE_UPLOAD_IMAGE_AS_ICON } from '@/config' diff --git a/web/app/account/(commonLayout)/avatar.tsx b/web/app/account/(commonLayout)/avatar.tsx index 36a510cf63..b81c96df74 100644 --- a/web/app/account/(commonLayout)/avatar.tsx +++ b/web/app/account/(commonLayout)/avatar.tsx @@ -6,9 +6,9 @@ import { import { Fragment } from 'react' import { useTranslation } from 'react-i18next' import { resetUser } from '@/app/components/base/amplitude/utils' -import { Avatar } from '@/app/components/base/avatar' import { LogOut01 } from '@/app/components/base/icons/src/vender/line/general' import PremiumBadge from '@/app/components/base/premium-badge' +import { Avatar } from '@/app/components/base/ui/avatar' import { useProviderContext } from '@/context/provider-context' import { useRouter } from '@/next/navigation' import { useLogout, useUserProfile } from '@/service/use-common' diff --git a/web/app/account/oauth/authorize/page.tsx b/web/app/account/oauth/authorize/page.tsx index 2c849fd542..b0b7f557a4 100644 --- a/web/app/account/oauth/authorize/page.tsx +++ b/web/app/account/oauth/authorize/page.tsx @@ -10,9 +10,9 @@ import { import * as React from 'react' import { useEffect, useRef } from 'react' import { useTranslation } from 'react-i18next' -import { Avatar } from '@/app/components/base/avatar' import Button from '@/app/components/base/button' import Loading from '@/app/components/base/loading' +import { Avatar } from '@/app/components/base/ui/avatar' import { toast } from '@/app/components/base/ui/toast' import { useLanguage } from '@/app/components/header/account-setting/model-provider-page/hooks' import { setPostLoginRedirect } from '@/app/signin/utils/post-login-redirect' diff --git a/web/app/components/app-initializer.tsx b/web/app/components/app-initializer.tsx index e08ece6666..30d8f3e410 100644 --- a/web/app/components/app-initializer.tsx +++ b/web/app/components/app-initializer.tsx @@ -9,6 +9,7 @@ import { EDUCATION_VERIFYING_LOCALSTORAGE_ITEM, } from '@/app/education-apply/constants' import { usePathname, useRouter, useSearchParams } from '@/next/navigation' +import { rememberCreateAppExternalAttribution } from '@/utils/create-app-tracking' import { sendGAEvent } from '@/utils/gtag' import { fetchSetupStatusWithCache } from '@/utils/setup-status' import { resolvePostLoginRedirect } from '../signin/utils/post-login-redirect' @@ -45,6 +46,8 @@ export const AppInitializer = ({ (async () => { const action = searchParams.get('action') + rememberCreateAppExternalAttribution({ searchParams }) + if (oauthNewUser) { let utmInfo = null const utmInfoStr = Cookies.get('utm_info') diff --git a/web/app/components/app/app-access-control/add-member-or-group-pop.tsx b/web/app/components/app/app-access-control/add-member-or-group-pop.tsx index d7e48f2d1f..b25fb94191 100644 --- a/web/app/components/app/app-access-control/add-member-or-group-pop.tsx +++ b/web/app/components/app/app-access-control/add-member-or-group-pop.tsx @@ -5,12 +5,12 @@ import { RiAddCircleFill, RiArrowRightSLine, RiOrganizationChart } from '@remixi import { useDebounce } from 'ahooks' import { useCallback, useEffect, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' +import { Avatar } from '@/app/components/base/ui/avatar' import { useSelector } from '@/context/app-context' import { SubjectType } from '@/models/access-control' import { useSearchForWhiteListCandidates } from '@/service/access-control' import { cn } from '@/utils/classnames' import useAccessControlStore from '../../../../context/access-control-store' -import { Avatar } from '../../base/avatar' import Button from '../../base/button' import Checkbox from '../../base/checkbox' import Input from '../../base/input' diff --git a/web/app/components/app/app-access-control/specific-groups-or-members.tsx b/web/app/components/app/app-access-control/specific-groups-or-members.tsx index 2c0e4b2694..8f4e71c8d2 100644 --- a/web/app/components/app/app-access-control/specific-groups-or-members.tsx +++ b/web/app/components/app/app-access-control/specific-groups-or-members.tsx @@ -3,10 +3,10 @@ import type { AccessControlAccount, AccessControlGroup } from '@/models/access-c import { RiAlertFill, RiCloseCircleFill, RiLockLine, RiOrganizationChart } from '@remixicon/react' import { useCallback, useEffect } from 'react' import { useTranslation } from 'react-i18next' +import { Avatar } from '@/app/components/base/ui/avatar' import { AccessMode } from '@/models/access-control' import { useAppWhiteListSubjects } from '@/service/access-control' import useAccessControlStore from '../../../../context/access-control-store' -import { Avatar } from '../../base/avatar' import Loading from '../../base/loading' import Tooltip from '../../base/tooltip' import AddMemberOrGroupDialog from './add-member-or-group-pop' diff --git a/web/app/components/app/configuration/debug/debug-with-multiple-model/__tests__/chat-item.spec.tsx b/web/app/components/app/configuration/debug/debug-with-multiple-model/__tests__/chat-item.spec.tsx index 80bb26a052..61eb8f2ae8 100644 --- a/web/app/components/app/configuration/debug/debug-with-multiple-model/__tests__/chat-item.spec.tsx +++ b/web/app/components/app/configuration/debug/debug-with-multiple-model/__tests__/chat-item.spec.tsx @@ -90,7 +90,7 @@ vi.mock('@/app/components/base/chat/chat', () => ({ }, })) -vi.mock('@/app/components/base/avatar', () => ({ +vi.mock('@/app/components/base/ui/avatar', () => ({ Avatar: ({ name }: { name: string }) =>
{name}
, })) diff --git a/web/app/components/app/configuration/debug/debug-with-multiple-model/chat-item.tsx b/web/app/components/app/configuration/debug/debug-with-multiple-model/chat-item.tsx index e957fc24c4..56345890ff 100644 --- a/web/app/components/app/configuration/debug/debug-with-multiple-model/chat-item.tsx +++ b/web/app/components/app/configuration/debug/debug-with-multiple-model/chat-item.tsx @@ -7,11 +7,11 @@ import { useCallback, useMemo, } from 'react' -import { Avatar } from '@/app/components/base/avatar' import Chat from '@/app/components/base/chat/chat' import { useChat } from '@/app/components/base/chat/chat/hooks' import { getLastAnswer } from '@/app/components/base/chat/utils' import { useFeatures } from '@/app/components/base/features/hooks' +import { Avatar } from '@/app/components/base/ui/avatar' import { ModelFeatureEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' import { useAppContext } from '@/context/app-context' import { useDebugConfigurationContext } from '@/context/debug-configuration' diff --git a/web/app/components/app/configuration/debug/debug-with-single-model/index.tsx b/web/app/components/app/configuration/debug/debug-with-single-model/index.tsx index 84ff8b5ede..a9f9f1116b 100644 --- a/web/app/components/app/configuration/debug/debug-with-single-model/index.tsx +++ b/web/app/components/app/configuration/debug/debug-with-single-model/index.tsx @@ -3,11 +3,11 @@ import type { ChatConfig, ChatItem, OnSend } from '@/app/components/base/chat/ty import type { FileEntity } from '@/app/components/base/file-uploader/types' import { memo, useCallback, useImperativeHandle, useMemo } from 'react' import { useStore as useAppStore } from '@/app/components/app/store' -import { Avatar } from '@/app/components/base/avatar' import Chat from '@/app/components/base/chat/chat' import { useChat } from '@/app/components/base/chat/chat/hooks' import { getLastAnswer, isValidGeneratedAnswer } from '@/app/components/base/chat/utils' import { useFeatures } from '@/app/components/base/features/hooks' +import { Avatar } from '@/app/components/base/ui/avatar' import { ModelFeatureEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' import { useAppContext } from '@/context/app-context' import { useDebugConfigurationContext } from '@/context/debug-configuration' diff --git a/web/app/components/app/create-app-dialog/app-list/__tests__/index.spec.tsx b/web/app/components/app/create-app-dialog/app-list/__tests__/index.spec.tsx index 3ebc5f7157..a319bb58f7 100644 --- a/web/app/components/app/create-app-dialog/app-list/__tests__/index.spec.tsx +++ b/web/app/components/app/create-app-dialog/app-list/__tests__/index.spec.tsx @@ -4,7 +4,6 @@ import { AppModeEnum } from '@/types/app' import Apps from '../index' const mockUseExploreAppList = vi.fn() -const mockTrackEvent = vi.fn() const mockImportDSL = vi.fn() const mockFetchAppDetail = vi.fn() const mockHandleCheckPluginDependencies = vi.fn() @@ -12,6 +11,7 @@ const mockGetRedirection = vi.fn() const mockPush = vi.fn() const mockToastSuccess = vi.fn() const mockToastError = vi.fn() +const mockTrackCreateApp = vi.fn() let latestDebounceFn = () => {} vi.mock('ahooks', () => ({ @@ -92,8 +92,8 @@ vi.mock('@/app/components/base/ui/toast', () => ({ error: (...args: unknown[]) => mockToastError(...args), }, })) -vi.mock('@/app/components/base/amplitude', () => ({ - trackEvent: (...args: unknown[]) => mockTrackEvent(...args), +vi.mock('@/utils/create-app-tracking', () => ({ + trackCreateApp: (...args: unknown[]) => mockTrackCreateApp(...args), })) vi.mock('@/service/apps', () => ({ importDSL: (...args: unknown[]) => mockImportDSL(...args), @@ -246,10 +246,9 @@ describe('Apps', () => { })) }) - expect(mockTrackEvent).toHaveBeenCalledWith('create_app_with_template', expect.objectContaining({ - template_id: 'Alpha', - template_name: 'Alpha', - })) + expect(mockTrackCreateApp).toHaveBeenCalledWith({ + appMode: AppModeEnum.CHAT, + }) expect(mockToastSuccess).toHaveBeenCalledWith('app.newApp.appCreated') expect(onSuccess).toHaveBeenCalled() expect(mockHandleCheckPluginDependencies).toHaveBeenCalledWith('created-app-id') diff --git a/web/app/components/app/create-app-dialog/app-list/index.tsx b/web/app/components/app/create-app-dialog/app-list/index.tsx index 1aa40d2014..daf49115c8 100644 --- a/web/app/components/app/create-app-dialog/app-list/index.tsx +++ b/web/app/components/app/create-app-dialog/app-list/index.tsx @@ -8,7 +8,6 @@ import * as React from 'react' import { useMemo, useState } from 'react' import { useTranslation } from 'react-i18next' import AppTypeSelector from '@/app/components/app/type-selector' -import { trackEvent } from '@/app/components/base/amplitude' import Divider from '@/app/components/base/divider' import Input from '@/app/components/base/input' import Loading from '@/app/components/base/loading' @@ -25,6 +24,7 @@ import { useExploreAppList } from '@/service/use-explore' import { AppModeEnum } from '@/types/app' import { getRedirection } from '@/utils/app-redirection' import { cn } from '@/utils/classnames' +import { trackCreateApp } from '@/utils/create-app-tracking' import AppCard from '../app-card' import Sidebar, { AppCategories, AppCategoryLabel } from './sidebar' @@ -127,14 +127,7 @@ const Apps = ({ icon_background, description, }) - - // Track app creation from template - trackEvent('create_app_with_template', { - app_mode: mode, - template_id: currApp?.app.id, - template_name: currApp?.app.name, - description, - }) + trackCreateApp({ appMode: mode }) setIsShowCreateModal(false) toast.success(t('newApp.appCreated', { ns: 'app' })) diff --git a/web/app/components/app/create-app-modal/__tests__/index.spec.tsx b/web/app/components/app/create-app-modal/__tests__/index.spec.tsx index ee24ab4006..3e06b89f0e 100644 --- a/web/app/components/app/create-app-modal/__tests__/index.spec.tsx +++ b/web/app/components/app/create-app-modal/__tests__/index.spec.tsx @@ -1,7 +1,6 @@ import type { App } from '@/types/app' import { fireEvent, render, screen, waitFor } from '@testing-library/react' import { afterAll, beforeEach, describe, expect, it, vi } from 'vitest' -import { trackEvent } from '@/app/components/base/amplitude' import { NEED_REFRESH_APP_LIST_KEY } from '@/config' import { useAppContext } from '@/context/app-context' @@ -10,6 +9,7 @@ import { useRouter } from '@/next/navigation' import { createApp } from '@/service/apps' import { AppModeEnum } from '@/types/app' import { getRedirection } from '@/utils/app-redirection' +import { trackCreateApp } from '@/utils/create-app-tracking' import CreateAppModal from '../index' const ahooksMocks = vi.hoisted(() => ({ @@ -31,8 +31,8 @@ vi.mock('ahooks', () => ({ vi.mock('@/next/navigation', () => ({ useRouter: vi.fn(), })) -vi.mock('@/app/components/base/amplitude', () => ({ - trackEvent: vi.fn(), +vi.mock('@/utils/create-app-tracking', () => ({ + trackCreateApp: vi.fn(), })) vi.mock('@/service/apps', () => ({ createApp: vi.fn(), @@ -87,7 +87,7 @@ vi.mock('@/hooks/use-theme', () => ({ const mockUseRouter = vi.mocked(useRouter) const mockPush = vi.fn() const mockCreateApp = vi.mocked(createApp) -const mockTrackEvent = vi.mocked(trackEvent) +const mockTrackCreateApp = vi.mocked(trackCreateApp) const mockGetRedirection = vi.mocked(getRedirection) const mockUseProviderContext = vi.mocked(useProviderContext) const mockUseAppContext = vi.mocked(useAppContext) @@ -178,10 +178,7 @@ describe('CreateAppModal', () => { mode: AppModeEnum.ADVANCED_CHAT, })) - expect(mockTrackEvent).toHaveBeenCalledWith('create_app', { - app_mode: AppModeEnum.ADVANCED_CHAT, - description: '', - }) + expect(mockTrackCreateApp).toHaveBeenCalledWith({ appMode: AppModeEnum.ADVANCED_CHAT }) expect(mockToastSuccess).toHaveBeenCalledWith('app.newApp.appCreated') expect(onSuccess).toHaveBeenCalled() expect(onClose).toHaveBeenCalled() diff --git a/web/app/components/app/create-app-modal/index.tsx b/web/app/components/app/create-app-modal/index.tsx index f2ced9b6c0..96c3045c59 100644 --- a/web/app/components/app/create-app-modal/index.tsx +++ b/web/app/components/app/create-app-modal/index.tsx @@ -6,7 +6,6 @@ import { RiArrowRightLine, RiArrowRightSLine, RiExchange2Fill } from '@remixicon import { useDebounceFn, useKeyPress } from 'ahooks' import { useCallback, useEffect, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' -import { trackEvent } from '@/app/components/base/amplitude' import AppIcon from '@/app/components/base/app-icon' import Button from '@/app/components/base/button' import Divider from '@/app/components/base/divider' @@ -25,6 +24,7 @@ import { createApp } from '@/service/apps' import { AppModeEnum } from '@/types/app' import { getRedirection } from '@/utils/app-redirection' import { cn } from '@/utils/classnames' +import { trackCreateApp } from '@/utils/create-app-tracking' import { basePath } from '@/utils/var' import AppIconPicker from '../../base/app-icon-picker' import ShortcutsName from '../../workflow/shortcuts-name' @@ -80,11 +80,7 @@ function CreateApp({ onClose, onSuccess, onCreateFromTemplate, defaultAppMode }: mode: appMode, }) - // Track app creation success - trackEvent('create_app', { - app_mode: appMode, - description, - }) + trackCreateApp({ appMode: app.mode }) toast.success(t('newApp.appCreated', { ns: 'app' })) onSuccess() diff --git a/web/app/components/app/create-from-dsl-modal/__tests__/index.spec.tsx b/web/app/components/app/create-from-dsl-modal/__tests__/index.spec.tsx index c1ffbc22e8..e106cc7eb3 100644 --- a/web/app/components/app/create-from-dsl-modal/__tests__/index.spec.tsx +++ b/web/app/components/app/create-from-dsl-modal/__tests__/index.spec.tsx @@ -2,12 +2,13 @@ import { act, fireEvent, render, screen, waitFor } from '@testing-library/react' import { NEED_REFRESH_APP_LIST_KEY } from '@/config' import { DSLImportMode, DSLImportStatus } from '@/models/app' +import { AppModeEnum } from '@/types/app' import CreateFromDSLModal, { CreateFromDSLModalTab } from '../index' const mockPush = vi.fn() const mockImportDSL = vi.fn() const mockImportDSLConfirm = vi.fn() -const mockTrackEvent = vi.fn() +const mockTrackCreateApp = vi.fn() const mockHandleCheckPluginDependencies = vi.fn() const mockGetRedirection = vi.fn() const toastMocks = vi.hoisted(() => ({ @@ -43,8 +44,8 @@ vi.mock('@/next/navigation', () => ({ }), })) -vi.mock('@/app/components/base/amplitude', () => ({ - trackEvent: (...args: unknown[]) => mockTrackEvent(...args), +vi.mock('@/utils/create-app-tracking', () => ({ + trackCreateApp: (...args: unknown[]) => mockTrackCreateApp(...args), })) vi.mock('@/service/apps', () => ({ @@ -172,7 +173,7 @@ describe('CreateFromDSLModal', () => { id: 'import-1', status: DSLImportStatus.COMPLETED, app_id: 'app-1', - app_mode: 'chat', + app_mode: AppModeEnum.CHAT, }) render( @@ -196,10 +197,7 @@ describe('CreateFromDSLModal', () => { mode: DSLImportMode.YAML_URL, yaml_url: 'https://example.com/app.yml', }) - expect(mockTrackEvent).toHaveBeenCalledWith('create_app_with_dsl', expect.objectContaining({ - creation_method: 'dsl_url', - has_warnings: false, - })) + expect(mockTrackCreateApp).toHaveBeenCalledWith({ appMode: AppModeEnum.CHAT }) expect(handleSuccess).toHaveBeenCalledTimes(1) expect(handleClose).toHaveBeenCalledTimes(1) expect(localStorage.getItem(NEED_REFRESH_APP_LIST_KEY)).toBe('1') @@ -212,7 +210,7 @@ describe('CreateFromDSLModal', () => { id: 'import-2', status: DSLImportStatus.COMPLETED_WITH_WARNINGS, app_id: 'app-2', - app_mode: 'chat', + app_mode: AppModeEnum.CHAT, }) render( @@ -275,7 +273,7 @@ describe('CreateFromDSLModal', () => { mockImportDSLConfirm.mockResolvedValue({ status: DSLImportStatus.COMPLETED, app_id: 'app-3', - app_mode: 'workflow', + app_mode: AppModeEnum.WORKFLOW, }) render( @@ -305,6 +303,7 @@ describe('CreateFromDSLModal', () => { expect(mockImportDSLConfirm).toHaveBeenCalledWith({ import_id: 'import-3', }) + expect(mockTrackCreateApp).toHaveBeenCalledWith({ appMode: AppModeEnum.WORKFLOW }) }) it('should ignore empty import responses and prevent duplicate submissions while a request is in flight', async () => { @@ -332,7 +331,7 @@ describe('CreateFromDSLModal', () => { id: 'import-in-flight', status: DSLImportStatus.COMPLETED, app_id: 'app-1', - app_mode: 'chat', + app_mode: AppModeEnum.CHAT, }) }) diff --git a/web/app/components/app/create-from-dsl-modal/index.tsx b/web/app/components/app/create-from-dsl-modal/index.tsx index dd17655e3c..77000dbf0a 100644 --- a/web/app/components/app/create-from-dsl-modal/index.tsx +++ b/web/app/components/app/create-from-dsl-modal/index.tsx @@ -6,7 +6,6 @@ import { useDebounceFn, useKeyPress } from 'ahooks' import { noop } from 'es-toolkit/function' import { useEffect, useMemo, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' -import { trackEvent } from '@/app/components/base/amplitude' import Button from '@/app/components/base/button' import Input from '@/app/components/base/input' import Modal from '@/app/components/base/modal' @@ -27,6 +26,7 @@ import { } from '@/service/apps' import { getRedirection } from '@/utils/app-redirection' import { cn } from '@/utils/classnames' +import { trackCreateApp } from '@/utils/create-app-tracking' import ShortcutsName from '../../workflow/shortcuts-name' import Uploader from './uploader' @@ -112,12 +112,7 @@ const CreateFromDSLModal = ({ show, onSuccess, onClose, activeTab = CreateFromDS return const { id, status, app_id, app_mode, imported_dsl_version, current_dsl_version } = response if (status === DSLImportStatus.COMPLETED || status === DSLImportStatus.COMPLETED_WITH_WARNINGS) { - // Track app creation from DSL import - trackEvent('create_app_with_dsl', { - app_mode, - creation_method: currentTab === CreateFromDSLModalTab.FROM_FILE ? 'dsl_file' : 'dsl_url', - has_warnings: status === DSLImportStatus.COMPLETED_WITH_WARNINGS, - }) + trackCreateApp({ appMode: app_mode }) if (onSuccess) onSuccess() @@ -179,6 +174,7 @@ const CreateFromDSLModal = ({ show, onSuccess, onClose, activeTab = CreateFromDS const { status, app_id, app_mode } = response if (status === DSLImportStatus.COMPLETED) { + trackCreateApp({ appMode: app_mode }) if (onSuccess) onSuccess() if (onClose) @@ -228,7 +224,7 @@ const CreateFromDSLModal = ({ show, onSuccess, onClose, activeTab = CreateFromDS isShow={show} onClose={noop} > -
+
{t('importFromDSL', { ns: 'app' })}
-
+
{ tabs.map(tab => (
-
DSL URL
+
DSL URL
-
{t('newApp.appCreateDSLErrorTitle', { ns: 'app' })}
-
+
{t('newApp.appCreateDSLErrorTitle', { ns: 'app' })}
+
{t('newApp.appCreateDSLErrorPart1', { ns: 'app' })}
{t('newApp.appCreateDSLErrorPart2', { ns: 'app' })}

diff --git a/web/app/components/apps/__tests__/index.spec.tsx b/web/app/components/apps/__tests__/index.spec.tsx index da4fbc2d44..aae862c865 100644 --- a/web/app/components/apps/__tests__/index.spec.tsx +++ b/web/app/components/apps/__tests__/index.spec.tsx @@ -1,12 +1,48 @@ import type { ReactNode } from 'react' +import type { App } from '@/models/explore' import { QueryClient, QueryClientProvider } from '@tanstack/react-query' -import { render, screen } from '@testing-library/react' +import { fireEvent, render, screen, waitFor } from '@testing-library/react' import * as React from 'react' +import { useContextSelector } from 'use-context-selector' +import AppListContext from '@/context/app-list-context' +import { fetchAppDetail } from '@/service/explore' +import { AppModeEnum } from '@/types/app' import Apps from '../index' let documentTitleCalls: string[] = [] let educationInitCalls: number = 0 +const mockHandleImportDSL = vi.fn() +const mockHandleImportDSLConfirm = vi.fn() +const mockTrackCreateApp = vi.fn() +const mockFetchAppDetail = vi.mocked(fetchAppDetail) + +const mockTemplateApp: App = { + app_id: 'template-1', + category: 'Assistant', + app: { + id: 'template-1', + mode: AppModeEnum.CHAT, + icon_type: 'emoji', + icon: '🤖', + icon_background: '#fff', + icon_url: '', + name: 'Sample App', + description: 'Sample App', + use_icon_as_answer_icon: false, + }, + description: 'Sample App', + can_trial: true, + copyright: '', + privacy_policy: null, + custom_disclaimer: null, + position: 1, + is_listed: true, + install_count: 0, + installed: false, + editable: false, + is_agent: false, +} vi.mock('@/hooks/use-document-title', () => ({ default: (title: string) => { @@ -22,17 +58,80 @@ vi.mock('@/app/education-apply/hooks', () => ({ vi.mock('@/hooks/use-import-dsl', () => ({ useImportDSL: () => ({ - handleImportDSL: vi.fn(), - handleImportDSLConfirm: vi.fn(), + handleImportDSL: mockHandleImportDSL, + handleImportDSLConfirm: mockHandleImportDSLConfirm, versions: [], isFetching: false, }), })) -vi.mock('../list', () => ({ - default: () => { - return React.createElement('div', { 'data-testid': 'apps-list' }, 'Apps List') - }, +vi.mock('../list', () => { + const MockList = () => { + const setShowTryAppPanel = useContextSelector(AppListContext, ctx => ctx.setShowTryAppPanel) + return React.createElement( + 'div', + { 'data-testid': 'apps-list' }, + React.createElement('span', null, 'Apps List'), + React.createElement( + 'button', + { + 'data-testid': 'open-preview', + 'onClick': () => setShowTryAppPanel(true, { + appId: mockTemplateApp.app_id, + app: mockTemplateApp, + }), + }, + 'Open Preview', + ), + ) + } + + return { default: MockList } +}) + +vi.mock('../../explore/try-app', () => ({ + default: ({ onCreate, onClose }: { onCreate: () => void, onClose: () => void }) => ( +
+ + +
+ ), +})) + +vi.mock('../../explore/create-app-modal', () => ({ + default: ({ show, onConfirm, onHide }: { show: boolean, onConfirm: (payload: Record) => Promise, onHide: () => void }) => show + ? ( +
+ + +
+ ) + : null, +})) + +vi.mock('../../app/create-from-dsl-modal/dsl-confirm-modal', () => ({ + default: ({ onConfirm }: { onConfirm: () => void }) => ( + + ), +})) + +vi.mock('@/service/explore', () => ({ + fetchAppDetail: vi.fn(), +})) + +vi.mock('@/utils/create-app-tracking', () => ({ + trackCreateApp: (...args: unknown[]) => mockTrackCreateApp(...args), })) describe('Apps', () => { @@ -59,6 +158,14 @@ describe('Apps', () => { vi.clearAllMocks() documentTitleCalls = [] educationInitCalls = 0 + mockFetchAppDetail.mockResolvedValue({ + id: 'template-1', + name: 'Sample App', + icon: '🤖', + icon_background: '#fff', + mode: AppModeEnum.CHAT, + export_data: 'yaml-content', + }) }) describe('Rendering', () => { @@ -116,6 +223,25 @@ describe('Apps', () => { ) expect(screen.getByTestId('apps-list')).toBeInTheDocument() }) + + it('should track template preview creation after a successful import', async () => { + mockHandleImportDSL.mockImplementation(async (_payload: unknown, options: { onSuccess?: () => void }) => { + options.onSuccess?.() + }) + + renderWithClient() + + fireEvent.click(screen.getByTestId('open-preview')) + fireEvent.click(await screen.findByTestId('try-app-create')) + fireEvent.click(await screen.findByTestId('confirm-create')) + + await waitFor(() => { + expect(mockFetchAppDetail).toHaveBeenCalledWith('template-1') + expect(mockTrackCreateApp).toHaveBeenCalledWith({ + appMode: AppModeEnum.CHAT, + }) + }) + }) }) describe('Styling', () => { diff --git a/web/app/components/apps/index.tsx b/web/app/components/apps/index.tsx index b6ca60bd7b..9bf07e81e6 100644 --- a/web/app/components/apps/index.tsx +++ b/web/app/components/apps/index.tsx @@ -1,7 +1,7 @@ 'use client' import type { CreateAppModalProps } from '../explore/create-app-modal' import type { TryAppSelection } from '@/types/try-app' -import { useCallback, useState } from 'react' +import { useCallback, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' import { useEducationInit } from '@/app/education-apply/hooks' import AppListContext from '@/context/app-list-context' @@ -10,6 +10,7 @@ import { useImportDSL } from '@/hooks/use-import-dsl' import { DSLImportMode } from '@/models/app' import dynamic from '@/next/dynamic' import { fetchAppDetail } from '@/service/explore' +import { trackCreateApp } from '@/utils/create-app-tracking' import List from './list' const DSLConfirmModal = dynamic(() => import('../app/create-from-dsl-modal/dsl-confirm-modal'), { ssr: false }) @@ -23,6 +24,7 @@ const Apps = () => { useEducationInit() const [currentTryAppParams, setCurrentTryAppParams] = useState(undefined) + const currentCreateAppModeRef = useRef(null) const currApp = currentTryAppParams?.app const [isShowTryAppPanel, setIsShowTryAppPanel] = useState(false) const hideTryAppPanel = useCallback(() => { @@ -40,6 +42,12 @@ const Apps = () => { const handleShowFromTryApp = useCallback(() => { setIsShowCreateModal(true) }, []) + const trackCurrentCreateApp = useCallback(() => { + if (!currentCreateAppModeRef.current) + return + + trackCreateApp({ appMode: currentCreateAppModeRef.current }) + }, []) const [controlRefreshList, setControlRefreshList] = useState(0) const [controlHideCreateFromTemplatePanel, setControlHideCreateFromTemplatePanel] = useState(0) @@ -59,11 +67,14 @@ const Apps = () => { const onConfirmDSL = useCallback(async () => { await handleImportDSLConfirm({ - onSuccess, + onSuccess: () => { + trackCurrentCreateApp() + onSuccess() + }, }) - }, [handleImportDSLConfirm, onSuccess]) + }, [handleImportDSLConfirm, onSuccess, trackCurrentCreateApp]) - const onCreate: CreateAppModalProps['onConfirm'] = async ({ + const onCreate: CreateAppModalProps['onConfirm'] = useCallback(async ({ name, icon_type, icon, @@ -72,9 +83,10 @@ const Apps = () => { }) => { hideTryAppPanel() - const { export_data } = await fetchAppDetail( + const { export_data, mode } = await fetchAppDetail( currApp?.app.id as string, ) + currentCreateAppModeRef.current = mode const payload = { mode: DSLImportMode.YAML_CONTENT, yaml_content: export_data, @@ -86,13 +98,14 @@ const Apps = () => { } await handleImportDSL(payload, { onSuccess: () => { + trackCurrentCreateApp() setIsShowCreateModal(false) }, onPending: () => { setShowDSLConfirmModal(true) }, }) - } + }, [currApp?.app.id, handleImportDSL, hideTryAppPanel, trackCurrentCreateApp]) return ( = ({ }) => { useEffect(() => { // Only enable in Saas edition with valid API key - if (!isAmplitudeEnabled) - return + // if (!isAmplitudeEnabled) + // return // Initialize Amplitude amplitude.init(AMPLITUDE_API_KEY, { diff --git a/web/app/components/base/chat/chat-with-history/chat-wrapper.tsx b/web/app/components/base/chat/chat-with-history/chat-wrapper.tsx index 720355d09f..9c3a7cc8f7 100644 --- a/web/app/components/base/chat/chat-with-history/chat-wrapper.tsx +++ b/web/app/components/base/chat/chat-with-history/chat-wrapper.tsx @@ -11,6 +11,7 @@ import AppIcon from '@/app/components/base/app-icon' import InputsForm from '@/app/components/base/chat/chat-with-history/inputs-form' import SuggestedQuestions from '@/app/components/base/chat/chat/answer/suggested-questions' import { Markdown } from '@/app/components/base/markdown' +import { Avatar } from '@/app/components/base/ui/avatar' import { InputVarType } from '@/app/components/workflow/types' import { AppSourceType, @@ -23,7 +24,6 @@ import { submitHumanInputForm as submitHumanInputFormService } from '@/service/w import { TransferMethod } from '@/types/app' import { cn } from '@/utils/classnames' import { formatBooleanInputs } from '@/utils/model-config' -import { Avatar } from '../../avatar' import Chat from '../chat' import { useChat } from '../chat/hooks' import { getLastAnswer, isValidGeneratedAnswer } from '../utils' diff --git a/web/app/components/base/chat/embedded-chatbot/chat-wrapper.tsx b/web/app/components/base/chat/embedded-chatbot/chat-wrapper.tsx index c518a9d078..451f566505 100644 --- a/web/app/components/base/chat/embedded-chatbot/chat-wrapper.tsx +++ b/web/app/components/base/chat/embedded-chatbot/chat-wrapper.tsx @@ -12,6 +12,7 @@ import SuggestedQuestions from '@/app/components/base/chat/chat/answer/suggested import InputsForm from '@/app/components/base/chat/embedded-chatbot/inputs-form' import LogoAvatar from '@/app/components/base/logo/logo-embedded-chat-avatar' import { Markdown } from '@/app/components/base/markdown' +import { Avatar } from '@/app/components/base/ui/avatar' import { InputVarType } from '@/app/components/workflow/types' import { AppSourceType, @@ -23,7 +24,6 @@ import { import { submitHumanInputForm as submitHumanInputFormService } from '@/service/workflow' import { TransferMethod } from '@/types/app' import { cn } from '@/utils/classnames' -import { Avatar } from '../../avatar' import Chat from '../chat' import { useChat } from '../chat/hooks' import { getLastAnswer, isValidGeneratedAnswer } from '../utils' diff --git a/web/app/components/base/form/components/field/__tests__/number-input.spec.tsx b/web/app/components/base/form/components/field/__tests__/number-input.spec.tsx index eb5b419d78..714c280008 100644 --- a/web/app/components/base/form/components/field/__tests__/number-input.spec.tsx +++ b/web/app/components/base/form/components/field/__tests__/number-input.spec.tsx @@ -27,7 +27,7 @@ describe('NumberInputField', () => { it('should update value when users click increment', () => { render() - fireEvent.click(screen.getByRole('button', { name: 'common.operation.increment' })) + fireEvent.click(screen.getByRole('button', { name: 'Increment value' })) expect(mockField.handleChange).toHaveBeenCalledWith(3) }) diff --git a/web/app/components/base/avatar/__tests__/index.spec.tsx b/web/app/components/base/ui/avatar/__tests__/index.spec.tsx similarity index 99% rename from web/app/components/base/avatar/__tests__/index.spec.tsx rename to web/app/components/base/ui/avatar/__tests__/index.spec.tsx index 69c56ac993..8be3f8bf0f 100644 --- a/web/app/components/base/avatar/__tests__/index.spec.tsx +++ b/web/app/components/base/ui/avatar/__tests__/index.spec.tsx @@ -1,5 +1,5 @@ import { render, screen } from '@testing-library/react' -import { Avatar } from '../index' +import { Avatar } from '..' describe('Avatar', () => { describe('Rendering', () => { diff --git a/web/app/components/base/avatar/index.stories.tsx b/web/app/components/base/ui/avatar/index.stories.tsx similarity index 100% rename from web/app/components/base/avatar/index.stories.tsx rename to web/app/components/base/ui/avatar/index.stories.tsx diff --git a/web/app/components/base/avatar/index.tsx b/web/app/components/base/ui/avatar/index.tsx similarity index 92% rename from web/app/components/base/avatar/index.tsx rename to web/app/components/base/ui/avatar/index.tsx index 885022dded..0842a1734d 100644 --- a/web/app/components/base/avatar/index.tsx +++ b/web/app/components/base/ui/avatar/index.tsx @@ -36,7 +36,7 @@ function AvatarRoot({ return ( ) diff --git a/web/app/components/base/ui/number-field/__tests__/index.spec.tsx b/web/app/components/base/ui/number-field/__tests__/index.spec.tsx index 4cc07bc8eb..f988e2b312 100644 --- a/web/app/components/base/ui/number-field/__tests__/index.spec.tsx +++ b/web/app/components/base/ui/number-field/__tests__/index.spec.tsx @@ -172,13 +172,13 @@ describe('NumberField wrapper', () => { // Increment and decrement buttons should preserve accessible naming, icon fallbacks, and spacing variants. describe('Control buttons', () => { - it('should provide localized aria labels and default icons when labels are not provided', () => { + it('should provide english fallback aria labels and default icons when labels are not provided', () => { renderNumberField({ controlsProps: {}, }) - const increment = screen.getByRole('button', { name: 'common.operation.increment' }) - const decrement = screen.getByRole('button', { name: 'common.operation.decrement' }) + const increment = screen.getByRole('button', { name: 'Increment value' }) + const decrement = screen.getByRole('button', { name: 'Decrement value' }) expect(increment.querySelector('.i-ri-arrow-up-s-line')).toBeInTheDocument() expect(decrement.querySelector('.i-ri-arrow-down-s-line')).toBeInTheDocument() @@ -217,11 +217,11 @@ describe('NumberField wrapper', () => { }, }) - expect(screen.getByRole('button', { name: 'common.operation.increment' })).toBeInTheDocument() - expect(screen.getByRole('button', { name: 'common.operation.decrement' })).toBeInTheDocument() + expect(screen.getByRole('button', { name: 'Increment value' })).toBeInTheDocument() + expect(screen.getByRole('button', { name: 'Decrement value' })).toBeInTheDocument() }) - it('should rely on aria-labelledby when provided instead of injecting a translated aria-label', () => { + it('should rely on aria-labelledby when provided instead of injecting a fallback aria-label', () => { render( <> Increment from label diff --git a/web/app/components/base/ui/number-field/index.tsx b/web/app/components/base/ui/number-field/index.tsx index 97f1cc7d31..7d4c43b815 100644 --- a/web/app/components/base/ui/number-field/index.tsx +++ b/web/app/components/base/ui/number-field/index.tsx @@ -4,7 +4,6 @@ import type { VariantProps } from 'class-variance-authority' import { NumberField as BaseNumberField } from '@base-ui/react/number-field' import { cva } from 'class-variance-authority' import * as React from 'react' -import { useTranslation } from 'react-i18next' import { cn } from '@/utils/classnames' export const NumberField = BaseNumberField.Root @@ -188,18 +187,19 @@ type NumberFieldButtonVariantProps = Omit< export type NumberFieldButtonProps = React.ComponentPropsWithoutRef & NumberFieldButtonVariantProps +const incrementAriaLabel = 'Increment value' +const decrementAriaLabel = 'Decrement value' + export function NumberFieldIncrement({ className, children, size = 'regular', ...props }: NumberFieldButtonProps) { - const { t } = useTranslation() - return ( {children ??
() return ( ({ useExploreAppList: () => ({ @@ -45,6 +46,9 @@ vi.mock('@/hooks/use-import-dsl', () => ({ isFetching: false, }), })) +vi.mock('@/utils/create-app-tracking', () => ({ + trackCreateApp: (...args: unknown[]) => mockTrackCreateApp(...args), +})) vi.mock('@/app/components/explore/create-app-modal', () => ({ default: (props: CreateAppModalProps) => { @@ -214,7 +218,7 @@ describe('AppList', () => { categories: ['Writing'], allList: [createApp()], }; - (fetchAppDetail as unknown as Mock).mockResolvedValue({ export_data: 'yaml-content' }) + (fetchAppDetail as unknown as Mock).mockResolvedValue({ export_data: 'yaml-content', mode: AppModeEnum.CHAT }) mockHandleImportDSL.mockImplementation(async (_payload: unknown, options: { onSuccess?: () => void, onPending?: () => void }) => { options.onPending?.() }) @@ -235,6 +239,9 @@ describe('AppList', () => { fireEvent.click(screen.getByTestId('dsl-confirm')) await waitFor(() => { expect(mockHandleImportDSLConfirm).toHaveBeenCalledTimes(1) + expect(mockTrackCreateApp).toHaveBeenCalledWith({ + appMode: AppModeEnum.CHAT, + }) expect(onSuccess).toHaveBeenCalledTimes(1) }) }) @@ -307,7 +314,7 @@ describe('AppList', () => { categories: ['Writing'], allList: [createApp()], }; - (fetchAppDetail as unknown as Mock).mockResolvedValue({ export_data: 'yaml' }) + (fetchAppDetail as unknown as Mock).mockResolvedValue({ export_data: 'yaml', mode: AppModeEnum.CHAT }) renderAppList(true) fireEvent.click(screen.getByText('explore.appCard.addToWorkspace')) @@ -325,7 +332,7 @@ describe('AppList', () => { categories: ['Writing'], allList: [createApp()], }; - (fetchAppDetail as unknown as Mock).mockResolvedValue({ export_data: 'yaml' }) + (fetchAppDetail as unknown as Mock).mockResolvedValue({ export_data: 'yaml', mode: AppModeEnum.CHAT }) mockHandleImportDSL.mockImplementation(async (_payload: unknown, options: { onSuccess?: () => void }) => { options.onSuccess?.() }) @@ -337,6 +344,9 @@ describe('AppList', () => { await waitFor(() => { expect(screen.queryByTestId('create-app-modal')).not.toBeInTheDocument() }) + expect(mockTrackCreateApp).toHaveBeenCalledWith({ + appMode: AppModeEnum.CHAT, + }) }) it('should cancel DSL confirm modal', async () => { @@ -345,7 +355,7 @@ describe('AppList', () => { categories: ['Writing'], allList: [createApp()], }; - (fetchAppDetail as unknown as Mock).mockResolvedValue({ export_data: 'yaml' }) + (fetchAppDetail as unknown as Mock).mockResolvedValue({ export_data: 'yaml', mode: AppModeEnum.CHAT }) mockHandleImportDSL.mockImplementation(async (_payload: unknown, options: { onPending?: () => void }) => { options.onPending?.() }) @@ -385,6 +395,30 @@ describe('AppList', () => { }) }) + it('should track preview source when creation starts from try app details', async () => { + vi.useRealTimers() + mockExploreData = { + categories: ['Writing'], + allList: [createApp()], + }; + (fetchAppDetail as unknown as Mock).mockResolvedValue({ export_data: 'yaml', mode: AppModeEnum.CHAT }) + mockHandleImportDSL.mockImplementation(async (_payload: unknown, options: { onSuccess?: () => void }) => { + options.onSuccess?.() + }) + + renderAppList(true) + + fireEvent.click(screen.getByText('explore.appCard.try')) + fireEvent.click(screen.getByTestId('try-app-create')) + fireEvent.click(await screen.findByTestId('confirm-create')) + + await waitFor(() => { + expect(mockTrackCreateApp).toHaveBeenCalledWith({ + appMode: AppModeEnum.CHAT, + }) + }) + }) + it('should close try app panel when close is clicked', () => { mockExploreData = { categories: ['Writing'], diff --git a/web/app/components/explore/app-list/index.tsx b/web/app/components/explore/app-list/index.tsx index d508f141b4..684ab9e267 100644 --- a/web/app/components/explore/app-list/index.tsx +++ b/web/app/components/explore/app-list/index.tsx @@ -6,7 +6,7 @@ import type { TryAppSelection } from '@/types/try-app' import { useDebounceFn } from 'ahooks' import { useQueryState } from 'nuqs' import * as React from 'react' -import { useCallback, useMemo, useState } from 'react' +import { useCallback, useMemo, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' import DSLConfirmModal from '@/app/components/app/create-from-dsl-modal/dsl-confirm-modal' import Button from '@/app/components/base/button' @@ -26,6 +26,7 @@ import { fetchAppDetail } from '@/service/explore' import { useMembers } from '@/service/use-common' import { useExploreAppList } from '@/service/use-explore' import { cn } from '@/utils/classnames' +import { trackCreateApp } from '@/utils/create-app-tracking' import TryApp from '../try-app' import s from './style.module.css' @@ -101,6 +102,7 @@ const Apps = ({ const [showDSLConfirmModal, setShowDSLConfirmModal] = useState(false) const [currentTryApp, setCurrentTryApp] = useState(undefined) + const currentCreateAppModeRef = useRef(null) const isShowTryAppPanel = !!currentTryApp const hideTryAppPanel = useCallback(() => { setCurrentTryApp(undefined) @@ -112,8 +114,14 @@ const Apps = ({ setCurrApp(currentTryApp?.app || null) setIsShowCreateModal(true) }, [currentTryApp?.app]) + const trackCurrentCreateApp = useCallback(() => { + if (!currentCreateAppModeRef.current) + return - const onCreate: CreateAppModalProps['onConfirm'] = async ({ + trackCreateApp({ appMode: currentCreateAppModeRef.current }) + }, []) + + const onCreate: CreateAppModalProps['onConfirm'] = useCallback(async ({ name, icon_type, icon, @@ -122,9 +130,10 @@ const Apps = ({ }) => { hideTryAppPanel() - const { export_data } = await fetchAppDetail( + const { export_data, mode } = await fetchAppDetail( currApp?.app.id as string, ) + currentCreateAppModeRef.current = mode const payload = { mode: DSLImportMode.YAML_CONTENT, yaml_content: export_data, @@ -136,19 +145,23 @@ const Apps = ({ } await handleImportDSL(payload, { onSuccess: () => { + trackCurrentCreateApp() setIsShowCreateModal(false) }, onPending: () => { setShowDSLConfirmModal(true) }, }) - } + }, [currApp?.app.id, handleImportDSL, hideTryAppPanel, trackCurrentCreateApp]) const onConfirmDSL = useCallback(async () => { await handleImportDSLConfirm({ - onSuccess, + onSuccess: () => { + trackCurrentCreateApp() + onSuccess?.() + }, }) - }, [handleImportDSLConfirm, onSuccess]) + }, [handleImportDSLConfirm, onSuccess, trackCurrentCreateApp]) if (isLoading) { return ( diff --git a/web/app/components/header/account-dropdown/index.tsx b/web/app/components/header/account-dropdown/index.tsx index f5b0352a40..442554615b 100644 --- a/web/app/components/header/account-dropdown/index.tsx +++ b/web/app/components/header/account-dropdown/index.tsx @@ -4,9 +4,9 @@ import type { MouseEventHandler, ReactNode } from 'react' import { useState } from 'react' import { useTranslation } from 'react-i18next' import { resetUser } from '@/app/components/base/amplitude/utils' -import { Avatar } from '@/app/components/base/avatar' import PremiumBadge from '@/app/components/base/premium-badge' import ThemeSwitcher from '@/app/components/base/theme-switcher' +import { Avatar } from '@/app/components/base/ui/avatar' import { DropdownMenu, DropdownMenuContent, DropdownMenuGroup, DropdownMenuItem, DropdownMenuLinkItem, DropdownMenuSeparator, DropdownMenuTrigger } from '@/app/components/base/ui/dropdown-menu' import { ACCOUNT_SETTING_TAB } from '@/app/components/header/account-setting/constants' import { IS_CLOUD_EDITION } from '@/config' diff --git a/web/app/components/header/account-setting/members-page/index.tsx b/web/app/components/header/account-setting/members-page/index.tsx index 875ffba3e0..6ac9ee5d2d 100644 --- a/web/app/components/header/account-setting/members-page/index.tsx +++ b/web/app/components/header/account-setting/members-page/index.tsx @@ -2,7 +2,7 @@ import type { InvitationResult } from '@/models/common' import { useState } from 'react' import { useTranslation } from 'react-i18next' -import { Avatar } from '@/app/components/base/avatar' +import { Avatar } from '@/app/components/base/ui/avatar' import { Tooltip, TooltipContent, TooltipTrigger } from '@/app/components/base/ui/tooltip' import { NUM_INFINITE } from '@/app/components/billing/config' import { Plan } from '@/app/components/billing/type' diff --git a/web/app/components/header/account-setting/members-page/transfer-ownership-modal/member-selector.tsx b/web/app/components/header/account-setting/members-page/transfer-ownership-modal/member-selector.tsx index a6617ac4d2..59e69b92e2 100644 --- a/web/app/components/header/account-setting/members-page/transfer-ownership-modal/member-selector.tsx +++ b/web/app/components/header/account-setting/members-page/transfer-ownership-modal/member-selector.tsx @@ -3,9 +3,9 @@ import type { FC } from 'react' import * as React from 'react' import { useMemo, useState } from 'react' import { useTranslation } from 'react-i18next' -import { Avatar } from '@/app/components/base/avatar' import Input from '@/app/components/base/input' import { PortalToFollowElem, PortalToFollowElemContent, PortalToFollowElemTrigger } from '@/app/components/base/portal-to-follow-elem' +import { Avatar } from '@/app/components/base/ui/avatar' import { useMembers } from '@/service/use-common' import { cn } from '@/utils/classnames' diff --git a/web/app/components/workflow/__tests__/block-icon.spec.tsx b/web/app/components/workflow/__tests__/block-icon.spec.tsx new file mode 100644 index 0000000000..c3b30a67b6 --- /dev/null +++ b/web/app/components/workflow/__tests__/block-icon.spec.tsx @@ -0,0 +1,46 @@ +import { render } from '@testing-library/react' +import { API_PREFIX } from '@/config' +import BlockIcon, { VarBlockIcon } from '../block-icon' +import { BlockEnum } from '../types' + +describe('BlockIcon', () => { + it('renders the default workflow icon container for regular nodes', () => { + const { container } = render() + + const iconContainer = container.firstElementChild + expect(iconContainer).toHaveClass('w-4', 'h-4', 'bg-util-colors-blue-brand-blue-brand-500', 'extra-class') + expect(iconContainer?.querySelector('svg')).toBeInTheDocument() + }) + + it('normalizes protected plugin icon urls for tool-like nodes', () => { + const { container } = render( + , + ) + + const iconContainer = container.firstElementChild as HTMLElement + const backgroundIcon = iconContainer.querySelector('div') as HTMLElement + + expect(iconContainer).not.toHaveClass('bg-util-colors-blue-blue-500') + expect(backgroundIcon.style.backgroundImage).toContain( + `${API_PREFIX}/workspaces/current/plugin/icon/plugin-tool.png`, + ) + }) +}) + +describe('VarBlockIcon', () => { + it('renders the compact icon variant without the default container wrapper', () => { + const { container } = render( + , + ) + + expect(container.querySelector('.custom-var-icon')).toBeInTheDocument() + expect(container.querySelector('svg')).toBeInTheDocument() + expect(container.querySelector('.bg-util-colors-warning-warning-500')).not.toBeInTheDocument() + }) +}) diff --git a/web/app/components/workflow/__tests__/context.spec.tsx b/web/app/components/workflow/__tests__/context.spec.tsx new file mode 100644 index 0000000000..ccf1eaa9b1 --- /dev/null +++ b/web/app/components/workflow/__tests__/context.spec.tsx @@ -0,0 +1,39 @@ +import { render, screen } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import { WorkflowContextProvider } from '../context' +import { useStore, useWorkflowStore } from '../store' + +const StoreConsumer = () => { + const showSingleRunPanel = useStore(s => s.showSingleRunPanel) + const store = useWorkflowStore() + + return ( + + ) +} + +describe('WorkflowContextProvider', () => { + it('provides the workflow store to descendants and keeps the same store across rerenders', async () => { + const user = userEvent.setup() + const { rerender } = render( + + + , + ) + + expect(screen.getByRole('button', { name: 'closed' })).toBeInTheDocument() + + await user.click(screen.getByRole('button', { name: 'closed' })) + expect(screen.getByRole('button', { name: 'open' })).toBeInTheDocument() + + rerender( + + + , + ) + + expect(screen.getByRole('button', { name: 'open' })).toBeInTheDocument() + }) +}) diff --git a/web/app/components/workflow/__tests__/index.spec.tsx b/web/app/components/workflow/__tests__/index.spec.tsx new file mode 100644 index 0000000000..77b61e54e7 --- /dev/null +++ b/web/app/components/workflow/__tests__/index.spec.tsx @@ -0,0 +1,67 @@ +import type { Edge, Node } from '../types' +import { render, screen } from '@testing-library/react' +import { useStoreApi } from 'reactflow' +import { useDatasetsDetailStore } from '../datasets-detail-store/store' +import WorkflowWithDefaultContext from '../index' +import { BlockEnum } from '../types' +import { useWorkflowHistoryStore } from '../workflow-history-store' + +const nodes: Node[] = [ + { + id: 'node-start', + type: 'custom', + position: { x: 0, y: 0 }, + data: { + title: 'Start', + desc: '', + type: BlockEnum.Start, + }, + }, +] + +const edges: Edge[] = [ + { + id: 'edge-1', + source: 'node-start', + target: 'node-end', + sourceHandle: null, + targetHandle: null, + type: 'custom', + data: { + sourceType: BlockEnum.Start, + targetType: BlockEnum.End, + }, + }, +] + +const ContextConsumer = () => { + const { store, shortcutsEnabled } = useWorkflowHistoryStore() + const datasetCount = useDatasetsDetailStore(state => Object.keys(state.datasetsDetail).length) + const reactFlowStore = useStoreApi() + + return ( +
+ {`history:${store.getState().nodes.length}`} + {` shortcuts:${String(shortcutsEnabled)}`} + {` datasets:${datasetCount}`} + {` reactflow:${String(!!reactFlowStore)}`} +
+ ) +} + +describe('WorkflowWithDefaultContext', () => { + it('wires the ReactFlow, workflow history, and datasets detail providers around its children', () => { + render( + + + , + ) + + expect( + screen.getByText('history:1 shortcuts:true datasets:0 reactflow:true'), + ).toBeInTheDocument() + }) +}) diff --git a/web/app/components/workflow/__tests__/shortcuts-name.spec.tsx b/web/app/components/workflow/__tests__/shortcuts-name.spec.tsx new file mode 100644 index 0000000000..87efddb005 --- /dev/null +++ b/web/app/components/workflow/__tests__/shortcuts-name.spec.tsx @@ -0,0 +1,51 @@ +import { render, screen } from '@testing-library/react' +import ShortcutsName from '../shortcuts-name' + +describe('ShortcutsName', () => { + const originalNavigator = globalThis.navigator + + afterEach(() => { + Object.defineProperty(globalThis, 'navigator', { + value: originalNavigator, + writable: true, + configurable: true, + }) + }) + + it('renders mac-friendly key labels and style variants', () => { + Object.defineProperty(globalThis, 'navigator', { + value: { userAgent: 'Macintosh' }, + writable: true, + configurable: true, + }) + + const { container } = render( + , + ) + + expect(screen.getByText('⌘')).toBeInTheDocument() + expect(screen.getByText('⇧')).toBeInTheDocument() + expect(screen.getByText('s')).toBeInTheDocument() + expect(container.querySelector('.system-kbd')).toHaveClass( + 'bg-components-kbd-bg-white', + 'text-text-tertiary', + ) + }) + + it('keeps raw key names on non-mac systems', () => { + Object.defineProperty(globalThis, 'navigator', { + value: { userAgent: 'Windows NT' }, + writable: true, + configurable: true, + }) + + render() + + expect(screen.getByText('ctrl')).toBeInTheDocument() + expect(screen.getByText('alt')).toBeInTheDocument() + }) +}) diff --git a/web/app/components/workflow/__tests__/workflow-history-store.spec.tsx b/web/app/components/workflow/__tests__/workflow-history-store.spec.tsx new file mode 100644 index 0000000000..931cd97c02 --- /dev/null +++ b/web/app/components/workflow/__tests__/workflow-history-store.spec.tsx @@ -0,0 +1,97 @@ +import type { Edge, Node } from '../types' +import type { WorkflowHistoryState } from '../workflow-history-store' +import { render, renderHook, screen } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import { BlockEnum } from '../types' +import { useWorkflowHistoryStore, WorkflowHistoryProvider } from '../workflow-history-store' + +const nodes: Node[] = [ + { + id: 'node-1', + type: 'custom', + position: { x: 0, y: 0 }, + data: { + title: 'Start', + desc: '', + type: BlockEnum.Start, + selected: true, + }, + selected: true, + }, +] + +const edges: Edge[] = [ + { + id: 'edge-1', + source: 'node-1', + target: 'node-2', + sourceHandle: null, + targetHandle: null, + type: 'custom', + selected: true, + data: { + sourceType: BlockEnum.Start, + targetType: BlockEnum.End, + }, + }, +] + +const HistoryConsumer = () => { + const { store, shortcutsEnabled, setShortcutsEnabled } = useWorkflowHistoryStore() + + return ( + + ) +} + +describe('WorkflowHistoryProvider', () => { + it('provides workflow history state and shortcut toggles', async () => { + const user = userEvent.setup() + + render( + + + , + ) + + expect(screen.getByRole('button', { name: 'nodes:1 shortcuts:true' })).toBeInTheDocument() + + await user.click(screen.getByRole('button', { name: 'nodes:1 shortcuts:true' })) + expect(screen.getByRole('button', { name: 'nodes:1 shortcuts:false' })).toBeInTheDocument() + }) + + it('sanitizes selected flags when history state is replaced through the exposed store api', () => { + const wrapper = ({ children }: { children: React.ReactNode }) => ( + + {children} + + ) + + const { result } = renderHook(() => useWorkflowHistoryStore(), { wrapper }) + const nextState: WorkflowHistoryState = { + workflowHistoryEvent: undefined, + workflowHistoryEventMeta: undefined, + nodes, + edges, + } + + result.current.store.setState(nextState) + + expect(result.current.store.getState().nodes[0].data.selected).toBe(false) + expect(result.current.store.getState().edges[0].selected).toBe(false) + }) + + it('throws when consumed outside the provider', () => { + expect(() => renderHook(() => useWorkflowHistoryStore())).toThrow( + 'useWorkflowHistoryStoreApi must be used within a WorkflowHistoryProvider', + ) + }) +}) diff --git a/web/app/components/workflow/block-selector/__tests__/all-tools.spec.tsx b/web/app/components/workflow/block-selector/__tests__/all-tools.spec.tsx new file mode 100644 index 0000000000..64f012fae3 --- /dev/null +++ b/web/app/components/workflow/block-selector/__tests__/all-tools.spec.tsx @@ -0,0 +1,140 @@ +import { render, screen, waitFor } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import { useMarketplacePlugins } from '@/app/components/plugins/marketplace/hooks' +import { useGlobalPublicStore } from '@/context/global-public-context' +import { useGetLanguage } from '@/context/i18n' +import useTheme from '@/hooks/use-theme' +import { Theme } from '@/types/app' +import AllTools from '../all-tools' +import { createGlobalPublicStoreState, createToolProvider } from './factories' + +vi.mock('@/context/global-public-context', () => ({ + useGlobalPublicStore: vi.fn(), +})) + +vi.mock('@/context/i18n', () => ({ + useGetLanguage: vi.fn(), +})) + +vi.mock('@/hooks/use-theme', () => ({ + default: vi.fn(), +})) + +vi.mock('@/app/components/plugins/marketplace/hooks', () => ({ + useMarketplacePlugins: vi.fn(), +})) + +vi.mock('@/app/components/workflow/nodes/_base/components/mcp-tool-availability', () => ({ + useMCPToolAvailability: () => ({ + allowed: true, + }), +})) + +vi.mock('@/utils/var', async importOriginal => ({ + ...(await importOriginal()), + getMarketplaceUrl: () => 'https://marketplace.test/tools', +})) + +const mockUseMarketplacePlugins = vi.mocked(useMarketplacePlugins) +const mockUseGlobalPublicStore = vi.mocked(useGlobalPublicStore) +const mockUseGetLanguage = vi.mocked(useGetLanguage) +const mockUseTheme = vi.mocked(useTheme) + +const createMarketplacePluginsMock = () => ({ + plugins: [], + total: 0, + resetPlugins: vi.fn(), + queryPlugins: vi.fn(), + queryPluginsWithDebounced: vi.fn(), + cancelQueryPluginsWithDebounced: vi.fn(), + isLoading: false, + isFetchingNextPage: false, + hasNextPage: false, + fetchNextPage: vi.fn(), + page: 0, +}) + +describe('AllTools', () => { + beforeEach(() => { + vi.clearAllMocks() + mockUseGlobalPublicStore.mockImplementation(selector => selector(createGlobalPublicStoreState(false))) + mockUseGetLanguage.mockReturnValue('en_US') + mockUseTheme.mockReturnValue({ theme: Theme.light } as ReturnType) + mockUseMarketplacePlugins.mockReturnValue(createMarketplacePluginsMock()) + }) + + it('filters tools by the active tab', async () => { + const user = userEvent.setup() + + render( + , + ) + + expect(screen.getByText('Built In Provider')).toBeInTheDocument() + expect(screen.getByText('Custom Provider')).toBeInTheDocument() + + await user.click(screen.getByText('workflow.tabs.customTool')) + + expect(screen.getByText('Custom Provider')).toBeInTheDocument() + expect(screen.queryByText('Built In Provider')).not.toBeInTheDocument() + }) + + it('filters the rendered tools by the search text', () => { + render( + , + ) + + expect(screen.getByText('Report Toolkit')).toBeInTheDocument() + expect(screen.queryByText('Other Toolkit')).not.toBeInTheDocument() + }) + + it('shows the empty state when no tool matches the current filter', async () => { + render( + , + ) + + await waitFor(() => { + expect(screen.getByText('workflow.tabs.noPluginsFound')).toBeInTheDocument() + }) + }) +}) diff --git a/web/app/components/workflow/block-selector/__tests__/blocks.spec.tsx b/web/app/components/workflow/block-selector/__tests__/blocks.spec.tsx new file mode 100644 index 0000000000..00972f808c --- /dev/null +++ b/web/app/components/workflow/block-selector/__tests__/blocks.spec.tsx @@ -0,0 +1,79 @@ +import type { NodeDefault } from '../../types' +import { render, screen } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import { BlockEnum } from '../../types' +import Blocks from '../blocks' +import { BlockClassificationEnum } from '../types' + +const runtimeState = vi.hoisted(() => ({ + nodes: [] as Array<{ data: { type?: BlockEnum } }>, +})) + +vi.mock('reactflow', () => ({ + useStoreApi: () => ({ + getState: () => ({ + getNodes: () => runtimeState.nodes, + }), + }), +})) + +const createBlock = (type: BlockEnum, title: string, classification = BlockClassificationEnum.Default): NodeDefault => ({ + metaData: { + classification, + sort: 0, + type, + title, + author: 'Dify', + description: `${title} description`, + }, + defaultValue: {}, + checkValid: () => ({ isValid: true }), +}) + +describe('Blocks', () => { + beforeEach(() => { + runtimeState.nodes = [] + }) + + it('renders grouped blocks, filters duplicate knowledge-base nodes, and selects a block', async () => { + const user = userEvent.setup() + const onSelect = vi.fn() + + runtimeState.nodes = [{ data: { type: BlockEnum.KnowledgeBase } }] + + render( + , + ) + + expect(screen.getByText('LLM')).toBeInTheDocument() + expect(screen.getByText('Exit Loop')).toBeInTheDocument() + expect(screen.getByText('workflow.nodes.loop.loopNode')).toBeInTheDocument() + expect(screen.queryByText('Knowledge Retrieval')).not.toBeInTheDocument() + + await user.click(screen.getByText('LLM')) + + expect(onSelect).toHaveBeenCalledWith(BlockEnum.LLM) + }) + + it('shows the empty state when no block matches the search', () => { + render( + , + ) + + expect(screen.getByText('workflow.tabs.noResult')).toBeInTheDocument() + }) +}) diff --git a/web/app/components/workflow/block-selector/__tests__/factories.ts b/web/app/components/workflow/block-selector/__tests__/factories.ts new file mode 100644 index 0000000000..b7d82f7cb3 --- /dev/null +++ b/web/app/components/workflow/block-selector/__tests__/factories.ts @@ -0,0 +1,101 @@ +import type { ToolWithProvider } from '../../types' +import type { Plugin } from '@/app/components/plugins/types' +import type { Tool } from '@/app/components/tools/types' +import { PluginCategoryEnum } from '@/app/components/plugins/types' +import { CollectionType } from '@/app/components/tools/types' +import { defaultSystemFeatures } from '@/types/feature' + +export const createTool = ( + name: string, + label: string, + description = `${label} description`, +): Tool => ({ + name, + author: 'author', + label: { + en_US: label, + zh_Hans: label, + }, + description: { + en_US: description, + zh_Hans: description, + }, + parameters: [], + labels: [], + output_schema: {}, +}) + +export const createToolProvider = ( + overrides: Partial = {}, +): ToolWithProvider => ({ + id: 'provider-1', + name: 'provider-one', + author: 'Provider Author', + description: { + en_US: 'Provider description', + zh_Hans: 'Provider description', + }, + icon: 'icon', + icon_dark: 'icon-dark', + label: { + en_US: 'Provider One', + zh_Hans: 'Provider One', + }, + type: CollectionType.builtIn, + team_credentials: {}, + is_team_authorization: false, + allow_delete: false, + labels: [], + plugin_id: 'plugin-1', + tools: [createTool('tool-a', 'Tool A')], + meta: { version: '1.0.0' } as ToolWithProvider['meta'], + plugin_unique_identifier: 'plugin-1@1.0.0', + ...overrides, +}) + +export const createPlugin = (overrides: Partial = {}): Plugin => ({ + type: 'plugin', + org: 'org', + author: 'author', + name: 'Plugin One', + plugin_id: 'plugin-1', + version: '1.0.0', + latest_version: '1.0.0', + latest_package_identifier: 'plugin-1@1.0.0', + icon: 'icon', + verified: true, + label: { + en_US: 'Plugin One', + zh_Hans: 'Plugin One', + }, + brief: { + en_US: 'Plugin description', + zh_Hans: 'Plugin description', + }, + description: { + en_US: 'Plugin description', + zh_Hans: 'Plugin description', + }, + introduction: 'Plugin introduction', + repository: 'https://example.com/plugin', + category: PluginCategoryEnum.tool, + tags: [], + badges: [], + install_count: 0, + endpoint: { + settings: [], + }, + verification: { + authorized_category: 'community', + }, + from: 'github', + ...overrides, +}) + +export const createGlobalPublicStoreState = (enableMarketplace: boolean) => ({ + systemFeatures: { + ...defaultSystemFeatures, + enable_marketplace: enableMarketplace, + }, + setSystemFeatures: vi.fn(), +}) diff --git a/web/app/components/workflow/block-selector/__tests__/featured-tools.spec.tsx b/web/app/components/workflow/block-selector/__tests__/featured-tools.spec.tsx new file mode 100644 index 0000000000..1720a2d897 --- /dev/null +++ b/web/app/components/workflow/block-selector/__tests__/featured-tools.spec.tsx @@ -0,0 +1,101 @@ +import { render, screen } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import { useGetLanguage } from '@/context/i18n' +import useTheme from '@/hooks/use-theme' +import { Theme } from '@/types/app' +import FeaturedTools from '../featured-tools' +import { createPlugin, createToolProvider } from './factories' + +vi.mock('@/context/i18n', () => ({ + useGetLanguage: vi.fn(), +})) + +vi.mock('@/hooks/use-theme', () => ({ + default: vi.fn(), +})) + +vi.mock('@/app/components/workflow/nodes/_base/components/mcp-tool-availability', () => ({ + useMCPToolAvailability: () => ({ + allowed: true, + }), +})) + +vi.mock('@/utils/var', async importOriginal => ({ + ...(await importOriginal()), + getMarketplaceUrl: () => 'https://marketplace.test/tools', +})) + +const mockUseGetLanguage = vi.mocked(useGetLanguage) +const mockUseTheme = vi.mocked(useTheme) + +describe('FeaturedTools', () => { + beforeEach(() => { + vi.clearAllMocks() + localStorage.clear() + mockUseGetLanguage.mockReturnValue('en_US') + mockUseTheme.mockReturnValue({ theme: Theme.light } as ReturnType) + }) + + it('shows more featured tools when the list exceeds the initial quota', async () => { + const user = userEvent.setup() + const plugins = Array.from({ length: 6 }, (_, index) => + createPlugin({ + plugin_id: `plugin-${index + 1}`, + latest_package_identifier: `plugin-${index + 1}@1.0.0`, + label: { en_US: `Plugin ${index + 1}`, zh_Hans: `Plugin ${index + 1}` }, + })) + const providers = plugins.map((plugin, index) => + createToolProvider({ + id: `provider-${index + 1}`, + plugin_id: plugin.plugin_id, + label: { en_US: `Provider ${index + 1}`, zh_Hans: `Provider ${index + 1}` }, + }), + ) + const providerMap = new Map(providers.map(provider => [provider.plugin_id!, provider])) + + render( + , + ) + + expect(screen.getByText('Provider 1')).toBeInTheDocument() + expect(screen.queryByText('Provider 6')).not.toBeInTheDocument() + + await user.click(screen.getByText('workflow.tabs.showMoreFeatured')) + + expect(screen.getByText('Provider 6')).toBeInTheDocument() + }) + + it('honors the persisted collapsed state', () => { + localStorage.setItem('workflow_tools_featured_collapsed', 'true') + + render( + , + ) + + expect(screen.getByText('workflow.tabs.featuredTools')).toBeInTheDocument() + expect(screen.queryByText('Provider One')).not.toBeInTheDocument() + }) + + it('shows the marketplace empty state when no featured tools are available', () => { + render( + , + ) + + expect(screen.getByText('workflow.tabs.noFeaturedPlugins')).toBeInTheDocument() + }) +}) diff --git a/web/app/components/workflow/block-selector/__tests__/hooks.spec.tsx b/web/app/components/workflow/block-selector/__tests__/hooks.spec.tsx new file mode 100644 index 0000000000..6d27560802 --- /dev/null +++ b/web/app/components/workflow/block-selector/__tests__/hooks.spec.tsx @@ -0,0 +1,52 @@ +import { act, renderHook } from '@testing-library/react' +import { useTabs, useToolTabs } from '../hooks' +import { TabsEnum, ToolTypeEnum } from '../types' + +describe('block-selector hooks', () => { + it('falls back to the first valid tab when the preferred start tab is disabled', () => { + const { result } = renderHook(() => useTabs({ + noStart: false, + hasUserInputNode: true, + defaultActiveTab: TabsEnum.Start, + })) + + expect(result.current.tabs.find(tab => tab.key === TabsEnum.Start)?.disabled).toBe(true) + expect(result.current.activeTab).toBe(TabsEnum.Blocks) + }) + + it('keeps the start tab enabled when forcing it on and resets to a valid tab after disabling blocks', () => { + const props: Parameters[0] = { + noBlocks: false, + noStart: false, + hasUserInputNode: true, + forceEnableStartTab: true, + } + + const { result, rerender } = renderHook(nextProps => useTabs(nextProps), { + initialProps: props, + }) + + expect(result.current.tabs.find(tab => tab.key === TabsEnum.Start)?.disabled).toBeFalsy() + + act(() => { + result.current.setActiveTab(TabsEnum.Blocks) + }) + + rerender({ + ...props, + noBlocks: true, + noSources: true, + noTools: true, + }) + + expect(result.current.activeTab).toBe(TabsEnum.Start) + }) + + it('returns the MCP tab only when it is not hidden', () => { + const { result: visible } = renderHook(() => useToolTabs()) + const { result: hidden } = renderHook(() => useToolTabs(true)) + + expect(visible.current.some(tab => tab.key === ToolTypeEnum.MCP)).toBe(true) + expect(hidden.current.some(tab => tab.key === ToolTypeEnum.MCP)).toBe(false) + }) +}) diff --git a/web/app/components/workflow/block-selector/__tests__/index.spec.tsx b/web/app/components/workflow/block-selector/__tests__/index.spec.tsx new file mode 100644 index 0000000000..735a831c10 --- /dev/null +++ b/web/app/components/workflow/block-selector/__tests__/index.spec.tsx @@ -0,0 +1,90 @@ +import type { NodeDefault, ToolWithProvider } from '../../types' +import { screen } from '@testing-library/react' +import { renderWorkflowComponent } from '../../__tests__/workflow-test-env' +import { BlockEnum } from '../../types' +import NodeSelectorWrapper from '../index' +import { BlockClassificationEnum } from '../types' + +vi.mock('reactflow', async () => + (await import('../../__tests__/reactflow-mock-state')).createReactFlowModuleMock()) + +vi.mock('@/service/use-plugins', () => ({ + useFeaturedToolsRecommendations: () => ({ + plugins: [], + isLoading: false, + }), +})) + +vi.mock('@/service/use-tools', () => ({ + useAllBuiltInTools: () => ({ data: [] }), + useAllCustomTools: () => ({ data: [] }), + useAllWorkflowTools: () => ({ data: [] }), + useAllMCPTools: () => ({ data: [] }), + useInvalidateAllBuiltInTools: () => vi.fn(), +})) + +vi.mock('@/context/global-public-context', () => ({ + useGlobalPublicStore: (selector: (state: { systemFeatures: { enable_marketplace: boolean } }) => unknown) => selector({ + systemFeatures: { enable_marketplace: false }, + }), +})) + +const createBlock = (type: BlockEnum, title: string): NodeDefault => ({ + metaData: { + type, + title, + sort: 0, + classification: BlockClassificationEnum.Default, + author: 'Dify', + description: `${title} description`, + }, + defaultValue: {}, + checkValid: () => ({ isValid: true }), +}) + +const dataSource: ToolWithProvider = { + id: 'datasource-1', + name: 'datasource', + author: 'Dify', + description: { en_US: 'Data source', zh_Hans: '数据源' }, + icon: 'icon', + label: { en_US: 'Data Source', zh_Hans: 'Data Source' }, + type: 'datasource' as ToolWithProvider['type'], + team_credentials: {}, + is_team_authorization: false, + allow_delete: false, + labels: [], + tools: [], + meta: { version: '1.0.0' } as ToolWithProvider['meta'], +} + +describe('NodeSelectorWrapper', () => { + it('filters hidden block types from hooks store and forwards data sources', async () => { + renderWorkflowComponent( + , + { + hooksStoreProps: { + availableNodesMetaData: { + nodes: [ + createBlock(BlockEnum.Start, 'Start'), + createBlock(BlockEnum.Tool, 'Tool'), + createBlock(BlockEnum.Code, 'Code'), + createBlock(BlockEnum.DataSource, 'Data Source'), + ], + }, + }, + initialStoreState: { + dataSourceList: [dataSource], + }, + }, + ) + + expect(await screen.findByText('Code')).toBeInTheDocument() + expect(screen.queryByText('Start')).not.toBeInTheDocument() + expect(screen.queryByText('Tool')).not.toBeInTheDocument() + }) +}) diff --git a/web/app/components/workflow/block-selector/__tests__/main.spec.tsx b/web/app/components/workflow/block-selector/__tests__/main.spec.tsx new file mode 100644 index 0000000000..1deb6ce84c --- /dev/null +++ b/web/app/components/workflow/block-selector/__tests__/main.spec.tsx @@ -0,0 +1,95 @@ +import type { NodeDefault } from '../../types' +import { screen, waitFor } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import { renderWorkflowComponent } from '../../__tests__/workflow-test-env' +import { BlockEnum } from '../../types' +import NodeSelector from '../main' +import { BlockClassificationEnum } from '../types' + +vi.mock('reactflow', () => ({ + useStoreApi: () => ({ + getState: () => ({ + getNodes: () => [], + }), + }), +})) + +vi.mock('@/context/global-public-context', () => ({ + useGlobalPublicStore: (selector: (state: { systemFeatures: { enable_marketplace: boolean } }) => unknown) => selector({ + systemFeatures: { enable_marketplace: false }, + }), +})) + +vi.mock('@/service/use-plugins', () => ({ + useFeaturedToolsRecommendations: () => ({ + plugins: [], + isLoading: false, + }), +})) + +vi.mock('@/service/use-tools', () => ({ + useAllBuiltInTools: () => ({ data: [] }), + useAllCustomTools: () => ({ data: [] }), + useAllWorkflowTools: () => ({ data: [] }), + useAllMCPTools: () => ({ data: [] }), + useInvalidateAllBuiltInTools: () => vi.fn(), +})) + +const createBlock = (type: BlockEnum, title: string): NodeDefault => ({ + metaData: { + classification: BlockClassificationEnum.Default, + sort: 0, + type, + title, + author: 'Dify', + description: `${title} description`, + }, + defaultValue: {}, + checkValid: () => ({ isValid: true }), +}) + +describe('NodeSelector', () => { + it('opens with the real blocks tab, filters by search, selects a block, and clears search after close', async () => { + const user = userEvent.setup() + const onSelect = vi.fn() + + renderWorkflowComponent( + ( + + )} + />, + ) + + await user.click(screen.getByRole('button', { name: 'selector-closed' })) + + const searchInput = screen.getByPlaceholderText('workflow.tabs.searchBlock') + expect(screen.getByText('LLM')).toBeInTheDocument() + expect(screen.getByText('End')).toBeInTheDocument() + + await user.type(searchInput, 'LLM') + expect(screen.getByText('LLM')).toBeInTheDocument() + expect(screen.queryByText('End')).not.toBeInTheDocument() + + await user.click(screen.getByText('LLM')) + + expect(onSelect).toHaveBeenCalledWith(BlockEnum.LLM, undefined) + await waitFor(() => { + expect(screen.queryByPlaceholderText('workflow.tabs.searchBlock')).not.toBeInTheDocument() + }) + + await user.click(screen.getByRole('button', { name: 'selector-closed' })) + + const reopenedInput = screen.getByPlaceholderText('workflow.tabs.searchBlock') as HTMLInputElement + expect(reopenedInput.value).toBe('') + expect(screen.getByText('End')).toBeInTheDocument() + }) +}) diff --git a/web/app/components/workflow/block-selector/__tests__/tools.spec.tsx b/web/app/components/workflow/block-selector/__tests__/tools.spec.tsx new file mode 100644 index 0000000000..a800342e6e --- /dev/null +++ b/web/app/components/workflow/block-selector/__tests__/tools.spec.tsx @@ -0,0 +1,95 @@ +import { render, screen } from '@testing-library/react' +import { CollectionType } from '@/app/components/tools/types' +import { useGetLanguage } from '@/context/i18n' +import useTheme from '@/hooks/use-theme' +import { Theme } from '@/types/app' +import Tools from '../tools' +import { ViewType } from '../view-type-select' +import { createToolProvider } from './factories' + +vi.mock('@/context/i18n', () => ({ + useGetLanguage: vi.fn(), +})) + +vi.mock('@/hooks/use-theme', () => ({ + default: vi.fn(), +})) + +vi.mock('@/app/components/workflow/nodes/_base/components/mcp-tool-availability', () => ({ + useMCPToolAvailability: () => ({ + allowed: true, + }), +})) + +const mockUseGetLanguage = vi.mocked(useGetLanguage) +const mockUseTheme = vi.mocked(useTheme) + +describe('Tools', () => { + beforeEach(() => { + vi.clearAllMocks() + mockUseGetLanguage.mockReturnValue('en_US') + mockUseTheme.mockReturnValue({ theme: Theme.light } as ReturnType) + }) + + it('shows the empty state when there are no tools and no search text', () => { + render( + , + ) + + expect(screen.getByText('No tools available')).toBeInTheDocument() + }) + + it('renders tree groups for built-in and custom providers', () => { + render( + , + ) + + expect(screen.getByText('Built In')).toBeInTheDocument() + expect(screen.getByText('workflow.tabs.customTool')).toBeInTheDocument() + expect(screen.getByText('Built In Provider')).toBeInTheDocument() + expect(screen.getByText('Custom Provider')).toBeInTheDocument() + }) + + it('shows the alphabetical index in flat view when enough tools are present', () => { + const { container } = render( + + createToolProvider({ + id: `provider-${index}`, + label: { + en_US: `${String.fromCharCode(65 + index)} Provider`, + zh_Hans: `${String.fromCharCode(65 + index)} Provider`, + }, + }))} + onSelect={vi.fn()} + viewType={ViewType.flat} + hasSearchText={false} + />, + ) + + expect(container.querySelector('.index-bar')).toBeInTheDocument() + expect(screen.getByText('A Provider')).toBeInTheDocument() + expect(screen.getByText('K Provider')).toBeInTheDocument() + }) +}) diff --git a/web/app/components/workflow/block-selector/tool/__tests__/tool.spec.tsx b/web/app/components/workflow/block-selector/tool/__tests__/tool.spec.tsx new file mode 100644 index 0000000000..d9fad38854 --- /dev/null +++ b/web/app/components/workflow/block-selector/tool/__tests__/tool.spec.tsx @@ -0,0 +1,99 @@ +import { render, screen } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import { trackEvent } from '@/app/components/base/amplitude' +import { CollectionType } from '@/app/components/tools/types' +import { useGetLanguage } from '@/context/i18n' +import useTheme from '@/hooks/use-theme' +import { Theme } from '@/types/app' +import { BlockEnum } from '../../../types' +import { createTool, createToolProvider } from '../../__tests__/factories' +import { ViewType } from '../../view-type-select' +import Tool from '../tool' + +vi.mock('@/context/i18n', () => ({ + useGetLanguage: vi.fn(), +})) + +vi.mock('@/hooks/use-theme', () => ({ + default: vi.fn(), +})) + +vi.mock('@/app/components/base/amplitude', () => ({ + trackEvent: vi.fn(), +})) + +vi.mock('@/app/components/workflow/nodes/_base/components/mcp-tool-availability', () => ({ + useMCPToolAvailability: () => ({ + allowed: true, + }), +})) + +const mockUseGetLanguage = vi.mocked(useGetLanguage) +const mockUseTheme = vi.mocked(useTheme) +const mockTrackEvent = vi.mocked(trackEvent) + +describe('Tool', () => { + beforeEach(() => { + vi.clearAllMocks() + mockUseGetLanguage.mockReturnValue('en_US') + mockUseTheme.mockReturnValue({ theme: Theme.light } as ReturnType) + }) + + it('expands a provider and selects an action item', async () => { + const user = userEvent.setup() + const onSelect = vi.fn() + + render( + , + ) + + await user.click(screen.getByText('Provider One')) + await user.click(screen.getByText('Tool B')) + + expect(onSelect).toHaveBeenCalledWith(BlockEnum.Tool, expect.objectContaining({ + provider_id: 'provider-1', + provider_name: 'provider-one', + tool_name: 'tool-b', + title: 'Tool B', + })) + expect(mockTrackEvent).toHaveBeenCalledWith('tool_selected', { + tool_name: 'tool-b', + plugin_id: 'plugin-1', + }) + }) + + it('selects workflow tools directly without expanding the provider', async () => { + const user = userEvent.setup() + const onSelect = vi.fn() + + render( + , + ) + + await user.click(screen.getByText('Workflow Tool')) + + expect(onSelect).toHaveBeenCalledWith(BlockEnum.Tool, expect.objectContaining({ + provider_type: CollectionType.workflow, + tool_name: 'workflow-tool', + tool_label: 'Workflow Tool', + })) + }) +}) diff --git a/web/app/components/workflow/block-selector/tool/tool-list-flat-view/__tests__/list.spec.tsx b/web/app/components/workflow/block-selector/tool/tool-list-flat-view/__tests__/list.spec.tsx new file mode 100644 index 0000000000..ecb5dfe0a6 --- /dev/null +++ b/web/app/components/workflow/block-selector/tool/tool-list-flat-view/__tests__/list.spec.tsx @@ -0,0 +1,66 @@ +import { render, screen } from '@testing-library/react' +import { useGetLanguage } from '@/context/i18n' +import useTheme from '@/hooks/use-theme' +import { Theme } from '@/types/app' +import { createToolProvider } from '../../../__tests__/factories' +import List from '../list' + +vi.mock('@/context/i18n', () => ({ + useGetLanguage: vi.fn(), +})) + +vi.mock('@/hooks/use-theme', () => ({ + default: vi.fn(), +})) + +vi.mock('@/app/components/workflow/nodes/_base/components/mcp-tool-availability', () => ({ + useMCPToolAvailability: () => ({ + allowed: true, + }), +})) + +const mockUseGetLanguage = vi.mocked(useGetLanguage) +const mockUseTheme = vi.mocked(useTheme) + +describe('ToolListFlatView', () => { + beforeEach(() => { + vi.clearAllMocks() + mockUseGetLanguage.mockReturnValue('en_US') + mockUseTheme.mockReturnValue({ theme: Theme.light } as ReturnType) + }) + + it('assigns the first tool of each letter to the shared refs and renders the index bar', () => { + const toolRefs = { + current: {} as Record, + } + + render( + ), + createToolProvider({ + id: 'provider-b', + label: { en_US: 'B Provider', zh_Hans: 'B Provider' }, + letter: 'B', + } as ReturnType), + ]} + isShowLetterIndex + indexBar={
} + hasSearchText={false} + onSelect={vi.fn()} + toolRefs={toolRefs} + />, + ) + + expect(screen.getByText('A Provider')).toBeInTheDocument() + expect(screen.getByText('B Provider')).toBeInTheDocument() + expect(screen.getByTestId('index-bar')).toBeInTheDocument() + expect(toolRefs.current.A).toBeTruthy() + expect(toolRefs.current.B).toBeTruthy() + }) +}) diff --git a/web/app/components/workflow/block-selector/tool/tool-list-tree-view/__tests__/item.spec.tsx b/web/app/components/workflow/block-selector/tool/tool-list-tree-view/__tests__/item.spec.tsx new file mode 100644 index 0000000000..027ad7c11c --- /dev/null +++ b/web/app/components/workflow/block-selector/tool/tool-list-tree-view/__tests__/item.spec.tsx @@ -0,0 +1,47 @@ +import { render, screen } from '@testing-library/react' +import { useGetLanguage } from '@/context/i18n' +import useTheme from '@/hooks/use-theme' +import { Theme } from '@/types/app' +import { createToolProvider } from '../../../__tests__/factories' +import Item from '../item' + +vi.mock('@/context/i18n', () => ({ + useGetLanguage: vi.fn(), +})) + +vi.mock('@/hooks/use-theme', () => ({ + default: vi.fn(), +})) + +vi.mock('@/app/components/workflow/nodes/_base/components/mcp-tool-availability', () => ({ + useMCPToolAvailability: () => ({ + allowed: true, + }), +})) + +const mockUseGetLanguage = vi.mocked(useGetLanguage) +const mockUseTheme = vi.mocked(useTheme) + +describe('ToolListTreeView Item', () => { + beforeEach(() => { + vi.clearAllMocks() + mockUseGetLanguage.mockReturnValue('en_US') + mockUseTheme.mockReturnValue({ theme: Theme.light } as ReturnType) + }) + + it('renders the group heading and its provider list', () => { + render( + , + ) + + expect(screen.getByText('My Group')).toBeInTheDocument() + expect(screen.getByText('Provider Alpha')).toBeInTheDocument() + }) +}) diff --git a/web/app/components/workflow/block-selector/tool/tool-list-tree-view/__tests__/list.spec.tsx b/web/app/components/workflow/block-selector/tool/tool-list-tree-view/__tests__/list.spec.tsx new file mode 100644 index 0000000000..7b3c083e85 --- /dev/null +++ b/web/app/components/workflow/block-selector/tool/tool-list-tree-view/__tests__/list.spec.tsx @@ -0,0 +1,56 @@ +import { render, screen } from '@testing-library/react' +import { useGetLanguage } from '@/context/i18n' +import useTheme from '@/hooks/use-theme' +import { Theme } from '@/types/app' +import { createToolProvider } from '../../../__tests__/factories' +import { CUSTOM_GROUP_NAME } from '../../../index-bar' +import List from '../list' + +vi.mock('@/context/i18n', () => ({ + useGetLanguage: vi.fn(), +})) + +vi.mock('@/hooks/use-theme', () => ({ + default: vi.fn(), +})) + +vi.mock('@/app/components/workflow/nodes/_base/components/mcp-tool-availability', () => ({ + useMCPToolAvailability: () => ({ + allowed: true, + }), +})) + +const mockUseGetLanguage = vi.mocked(useGetLanguage) +const mockUseTheme = vi.mocked(useTheme) + +describe('ToolListTreeView', () => { + beforeEach(() => { + vi.clearAllMocks() + mockUseGetLanguage.mockReturnValue('en_US') + mockUseTheme.mockReturnValue({ theme: Theme.light } as ReturnType) + }) + + it('translates built-in special group names and renders the nested providers', () => { + render( + , + ) + + expect(screen.getByText('BuiltIn')).toBeInTheDocument() + expect(screen.getByText('workflow.tabs.customTool')).toBeInTheDocument() + expect(screen.getByText('Built In Provider')).toBeInTheDocument() + expect(screen.getByText('Custom Provider')).toBeInTheDocument() + }) +}) diff --git a/web/app/components/workflow/datasets-detail-store/__tests__/store.spec.tsx b/web/app/components/workflow/datasets-detail-store/__tests__/store.spec.tsx new file mode 100644 index 0000000000..a031c6370e --- /dev/null +++ b/web/app/components/workflow/datasets-detail-store/__tests__/store.spec.tsx @@ -0,0 +1,91 @@ +import type { DataSet } from '@/models/datasets' +import { renderHook } from '@testing-library/react' +import { ChunkingMode, DatasetPermission, DataSourceType } from '@/models/datasets' +import { DatasetsDetailContext } from '../provider' +import { createDatasetsDetailStore, useDatasetsDetailStore } from '../store' + +const createDataset = (id: string, name = `dataset-${id}`): DataSet => ({ + id, + name, + indexing_status: 'completed', + icon_info: { + icon: 'book', + icon_type: 'emoji' as DataSet['icon_info']['icon_type'], + }, + description: `${name} description`, + permission: DatasetPermission.onlyMe, + data_source_type: DataSourceType.FILE, + indexing_technique: 'high_quality' as DataSet['indexing_technique'], + created_by: 'user-1', + updated_by: 'user-1', + updated_at: 1, + app_count: 0, + doc_form: ChunkingMode.text, + document_count: 0, + total_document_count: 0, + word_count: 0, + provider: 'provider', + embedding_model: 'model', + embedding_model_provider: 'provider', + embedding_available: true, + retrieval_model_dict: {} as DataSet['retrieval_model_dict'], + retrieval_model: {} as DataSet['retrieval_model'], + tags: [], + external_knowledge_info: { + external_knowledge_id: '', + external_knowledge_api_id: '', + external_knowledge_api_name: '', + external_knowledge_api_endpoint: '', + }, + external_retrieval_model: { + top_k: 1, + score_threshold: 0, + score_threshold_enabled: false, + }, + built_in_field_enabled: false, + runtime_mode: 'general', + enable_api: false, + is_multimodal: false, +}) + +describe('datasets-detail-store store', () => { + it('merges dataset details by id', () => { + const store = createDatasetsDetailStore() + + store.getState().updateDatasetsDetail([ + createDataset('dataset-1', 'Dataset One'), + createDataset('dataset-2', 'Dataset Two'), + ]) + store.getState().updateDatasetsDetail([ + createDataset('dataset-2', 'Dataset Two Updated'), + ]) + + expect(store.getState().datasetsDetail).toMatchObject({ + 'dataset-1': { name: 'Dataset One' }, + 'dataset-2': { name: 'Dataset Two Updated' }, + }) + }) + + it('reads state from the datasets detail context', () => { + const store = createDatasetsDetailStore() + store.getState().updateDatasetsDetail([createDataset('dataset-3')]) + const wrapper = ({ children }: { children: React.ReactNode }) => ( + + {children} + + ) + + const { result } = renderHook( + () => useDatasetsDetailStore(state => state.datasetsDetail['dataset-3']?.name), + { wrapper }, + ) + + expect(result.current).toBe('dataset-dataset-3') + }) + + it('throws when the datasets detail provider is missing', () => { + expect(() => renderHook(() => useDatasetsDetailStore(state => state.datasetsDetail))).toThrow( + 'Missing DatasetsDetailContext.Provider in the tree', + ) + }) +}) diff --git a/web/app/components/workflow/hooks-store/__tests__/store.spec.tsx b/web/app/components/workflow/hooks-store/__tests__/store.spec.tsx new file mode 100644 index 0000000000..131290b834 --- /dev/null +++ b/web/app/components/workflow/hooks-store/__tests__/store.spec.tsx @@ -0,0 +1,41 @@ +import { renderHook } from '@testing-library/react' +import { HooksStoreContext } from '../provider' +import { createHooksStore, useHooksStore } from '../store' + +describe('hooks-store store', () => { + it('creates default callbacks and refreshes selected handlers', () => { + const store = createHooksStore({}) + const handleBackupDraft = vi.fn() + + expect(store.getState().availableNodesMetaData).toEqual({ nodes: [] }) + expect(store.getState().hasNodeInspectVars('node-1')).toBe(false) + expect(store.getState().getWorkflowRunAndTraceUrl('run-1')).toEqual({ + runUrl: '', + traceUrl: '', + }) + + store.getState().refreshAll({ handleBackupDraft }) + + expect(store.getState().handleBackupDraft).toBe(handleBackupDraft) + }) + + it('reads state from the hooks store context', () => { + const handleRun = vi.fn() + const store = createHooksStore({ handleRun }) + const wrapper = ({ children }: { children: React.ReactNode }) => ( + + {children} + + ) + + const { result } = renderHook(() => useHooksStore(state => state.handleRun), { wrapper }) + + expect(result.current).toBe(handleRun) + }) + + it('throws when the hooks store provider is missing', () => { + expect(() => renderHook(() => useHooksStore(state => state.handleRun))).toThrow( + 'Missing HooksStoreContext.Provider in the tree', + ) + }) +}) diff --git a/web/app/components/workflow/hooks/__tests__/use-DSL.spec.ts b/web/app/components/workflow/hooks/__tests__/use-DSL.spec.ts new file mode 100644 index 0000000000..f10777ae69 --- /dev/null +++ b/web/app/components/workflow/hooks/__tests__/use-DSL.spec.ts @@ -0,0 +1,19 @@ +import { renderWorkflowHook } from '../../__tests__/workflow-test-env' +import { useDSL } from '../use-DSL' + +describe('useDSL', () => { + it('returns the DSL handlers from hooks store', () => { + const exportCheck = vi.fn() + const handleExportDSL = vi.fn() + + const { result } = renderWorkflowHook(() => useDSL(), { + hooksStoreProps: { + exportCheck, + handleExportDSL, + }, + }) + + expect(result.current.exportCheck).toBe(exportCheck) + expect(result.current.handleExportDSL).toBe(handleExportDSL) + }) +}) diff --git a/web/app/components/workflow/hooks/__tests__/use-edges-interactions-without-sync.spec.ts b/web/app/components/workflow/hooks/__tests__/use-edges-interactions-without-sync.spec.ts new file mode 100644 index 0000000000..b38aca6398 --- /dev/null +++ b/web/app/components/workflow/hooks/__tests__/use-edges-interactions-without-sync.spec.ts @@ -0,0 +1,90 @@ +import { act, waitFor } from '@testing-library/react' +import { useEdges } from 'reactflow' +import { createEdge, createNode } from '../../__tests__/fixtures' +import { renderWorkflowFlowHook } from '../../__tests__/workflow-test-env' +import { NodeRunningStatus } from '../../types' +import { useEdgesInteractionsWithoutSync } from '../use-edges-interactions-without-sync' + +type EdgeRuntimeState = { + _sourceRunningStatus?: NodeRunningStatus + _targetRunningStatus?: NodeRunningStatus + _waitingRun?: boolean +} + +const getEdgeRuntimeState = (edge?: { data?: unknown }): EdgeRuntimeState => + (edge?.data ?? {}) as EdgeRuntimeState + +const createFlowNodes = () => [ + createNode({ id: 'a' }), + createNode({ id: 'b' }), + createNode({ id: 'c' }), +] + +const createFlowEdges = () => [ + createEdge({ + id: 'e1', + source: 'a', + target: 'b', + data: { + _sourceRunningStatus: NodeRunningStatus.Running, + _targetRunningStatus: NodeRunningStatus.Running, + _waitingRun: true, + }, + }), + createEdge({ + id: 'e2', + source: 'b', + target: 'c', + data: { + _sourceRunningStatus: NodeRunningStatus.Succeeded, + _targetRunningStatus: undefined, + _waitingRun: false, + }, + }), +] + +const renderEdgesInteractionsHook = () => + renderWorkflowFlowHook(() => ({ + ...useEdgesInteractionsWithoutSync(), + edges: useEdges(), + }), { + nodes: createFlowNodes(), + edges: createFlowEdges(), + }) + +describe('useEdgesInteractionsWithoutSync', () => { + it('clears running status and waitingRun on all edges', () => { + const { result } = renderEdgesInteractionsHook() + + act(() => { + result.current.handleEdgeCancelRunningStatus() + }) + + return waitFor(() => { + result.current.edges.forEach((edge) => { + const edgeState = getEdgeRuntimeState(edge) + expect(edgeState._sourceRunningStatus).toBeUndefined() + expect(edgeState._targetRunningStatus).toBeUndefined() + expect(edgeState._waitingRun).toBe(false) + }) + }) + }) + + it('does not mutate the original edges array', () => { + const edges = createFlowEdges() + const originalData = { ...getEdgeRuntimeState(edges[0]) } + const { result } = renderWorkflowFlowHook(() => ({ + ...useEdgesInteractionsWithoutSync(), + edges: useEdges(), + }), { + nodes: createFlowNodes(), + edges, + }) + + act(() => { + result.current.handleEdgeCancelRunningStatus() + }) + + expect(getEdgeRuntimeState(edges[0])._sourceRunningStatus).toBe(originalData._sourceRunningStatus) + }) +}) diff --git a/web/app/components/workflow/hooks/__tests__/use-edges-interactions.helpers.spec.ts b/web/app/components/workflow/hooks/__tests__/use-edges-interactions.helpers.spec.ts new file mode 100644 index 0000000000..3741bcc653 --- /dev/null +++ b/web/app/components/workflow/hooks/__tests__/use-edges-interactions.helpers.spec.ts @@ -0,0 +1,114 @@ +import { createEdge, createNode } from '../../__tests__/fixtures' +import { getNodesConnectedSourceOrTargetHandleIdsMap } from '../../utils' +import { + applyConnectedHandleNodeData, + buildContextMenuEdges, + clearEdgeMenuIfNeeded, + clearNodeSelectionState, + updateEdgeHoverState, + updateEdgeSelectionState, +} from '../use-edges-interactions.helpers' + +vi.mock('../../utils', () => ({ + getNodesConnectedSourceOrTargetHandleIdsMap: vi.fn(), +})) + +const mockGetNodesConnectedSourceOrTargetHandleIdsMap = vi.mocked(getNodesConnectedSourceOrTargetHandleIdsMap) + +describe('use-edges-interactions.helpers', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('applyConnectedHandleNodeData should merge connected handle metadata into matching nodes', () => { + mockGetNodesConnectedSourceOrTargetHandleIdsMap.mockReturnValue({ + 'node-1': { + _connectedSourceHandleIds: ['branch-a'], + }, + }) + + const nodes = [ + createNode({ id: 'node-1', data: { title: 'Source' } }), + createNode({ id: 'node-2', data: { title: 'Target' } }), + ] + const edgeChanges = [{ + type: 'add', + edge: createEdge({ id: 'edge-1', source: 'node-1', target: 'node-2' }), + }] + + const result = applyConnectedHandleNodeData(nodes, edgeChanges) + + expect(result[0].data._connectedSourceHandleIds).toEqual(['branch-a']) + expect(result[1].data._connectedSourceHandleIds).toEqual([]) + expect(mockGetNodesConnectedSourceOrTargetHandleIdsMap).toHaveBeenCalledWith(edgeChanges, nodes) + }) + + it('clearEdgeMenuIfNeeded should return true only when the open menu belongs to a removed edge', () => { + expect(clearEdgeMenuIfNeeded({ + edgeMenu: { edgeId: 'edge-1' }, + edgeIds: ['edge-1', 'edge-2'], + })).toBe(true) + + expect(clearEdgeMenuIfNeeded({ + edgeMenu: { edgeId: 'edge-3' }, + edgeIds: ['edge-1', 'edge-2'], + })).toBe(false) + + expect(clearEdgeMenuIfNeeded({ + edgeIds: ['edge-1'], + })).toBe(false) + }) + + it('updateEdgeHoverState should toggle only the hovered edge flag', () => { + const edges = [ + createEdge({ id: 'edge-1', data: { _hovering: false } }), + createEdge({ id: 'edge-2', data: { _hovering: false } }), + ] + + const result = updateEdgeHoverState(edges, 'edge-2', true) + + expect(result.find(edge => edge.id === 'edge-1')?.data._hovering).toBe(false) + expect(result.find(edge => edge.id === 'edge-2')?.data._hovering).toBe(true) + }) + + it('updateEdgeSelectionState should update selected flags for select changes only', () => { + const edges = [ + createEdge({ id: 'edge-1', selected: false }), + createEdge({ id: 'edge-2', selected: true }), + ] + + const result = updateEdgeSelectionState(edges, [ + { type: 'select', id: 'edge-1', selected: true }, + { type: 'remove', id: 'edge-2' }, + ]) + + expect(result.find(edge => edge.id === 'edge-1')?.selected).toBe(true) + expect(result.find(edge => edge.id === 'edge-2')?.selected).toBe(true) + }) + + it('buildContextMenuEdges should select the target edge and clear bundled markers', () => { + const edges = [ + createEdge({ id: 'edge-1', selected: true, data: { _isBundled: true } }), + createEdge({ id: 'edge-2', selected: false, data: { _isBundled: true } }), + ] + + const result = buildContextMenuEdges(edges, 'edge-2') + + expect(result.find(edge => edge.id === 'edge-1')?.selected).toBe(false) + expect(result.find(edge => edge.id === 'edge-2')?.selected).toBe(true) + expect(result.every(edge => edge.data._isBundled === false)).toBe(true) + }) + + it('clearNodeSelectionState should clear selected state and bundled markers on every node', () => { + const nodes = [ + createNode({ id: 'node-1', selected: true, data: { selected: true, _isBundled: true } }), + createNode({ id: 'node-2', selected: false, data: { selected: true, _isBundled: true } }), + ] + + const result = clearNodeSelectionState(nodes) + + expect(result.every(node => node.selected === false)).toBe(true) + expect(result.every(node => node.data.selected === false)).toBe(true) + expect(result.every(node => node.data._isBundled === false)).toBe(true) + }) +}) diff --git a/web/app/components/workflow/hooks/__tests__/use-fetch-workflow-inspect-vars.spec.ts b/web/app/components/workflow/hooks/__tests__/use-fetch-workflow-inspect-vars.spec.ts new file mode 100644 index 0000000000..e1e26732ae --- /dev/null +++ b/web/app/components/workflow/hooks/__tests__/use-fetch-workflow-inspect-vars.spec.ts @@ -0,0 +1,187 @@ +import type { SchemaTypeDefinition } from '@/service/use-common' +import type { VarInInspect } from '@/types/workflow' +import { act, waitFor } from '@testing-library/react' +import { FlowType } from '@/types/common' +import { createNode } from '../../__tests__/fixtures' +import { resetReactFlowMockState, rfState } from '../../__tests__/reactflow-mock-state' +import { renderWorkflowHook } from '../../__tests__/workflow-test-env' +import { BlockEnum, VarType } from '../../types' +import { useSetWorkflowVarsWithValue } from '../use-fetch-workflow-inspect-vars' + +const mockFetchAllInspectVars = vi.hoisted(() => vi.fn()) +const mockInvalidateConversationVarValues = vi.hoisted(() => vi.fn()) +const mockInvalidateSysVarValues = vi.hoisted(() => vi.fn()) +const mockHandleCancelAllNodeSuccessStatus = vi.hoisted(() => vi.fn()) +const mockToNodeOutputVars = vi.hoisted(() => vi.fn()) + +const schemaTypeDefinitions: SchemaTypeDefinition[] = [{ + name: 'simple', + schema: { + properties: {}, + }, +}] + +vi.mock('reactflow', async () => + (await import('../../__tests__/reactflow-mock-state')).createReactFlowModuleMock()) + +vi.mock('@/service/use-tools', async () => + (await import('../../__tests__/service-mock-factory')).createToolServiceMock()) + +vi.mock('@/service/use-workflow', () => ({ + useInvalidateConversationVarValues: () => mockInvalidateConversationVarValues, + useInvalidateSysVarValues: () => mockInvalidateSysVarValues, +})) + +vi.mock('@/service/workflow', () => ({ + fetchAllInspectVars: (...args: unknown[]) => mockFetchAllInspectVars(...args), +})) + +vi.mock('../use-nodes-interactions-without-sync', () => ({ + useNodesInteractionsWithoutSync: () => ({ + handleCancelAllNodeSuccessStatus: mockHandleCancelAllNodeSuccessStatus, + }), +})) + +vi.mock('@/app/components/workflow/nodes/_base/components/variable/use-match-schema-type', () => ({ + default: () => ({ + schemaTypeDefinitions, + }), +})) + +vi.mock('@/app/components/workflow/nodes/_base/components/variable/utils', () => ({ + toNodeOutputVars: (...args: unknown[]) => mockToNodeOutputVars(...args), +})) + +const createInspectVar = (overrides: Partial = {}): VarInInspect => ({ + id: 'var-1', + type: 'node', + name: 'answer', + description: 'Answer', + selector: ['node-1', 'answer'], + value_type: VarType.string, + value: 'hello', + edited: false, + visible: true, + is_truncated: false, + full_content: { + size_bytes: 5, + download_url: 'https://example.com/answer.txt', + }, + ...overrides, +}) + +describe('use-fetch-workflow-inspect-vars', () => { + beforeEach(() => { + vi.clearAllMocks() + resetReactFlowMockState() + rfState.nodes = [ + createNode({ + id: 'node-1', + data: { + type: BlockEnum.Code, + title: 'Code', + desc: '', + }, + }), + ] + mockToNodeOutputVars.mockReturnValue([{ + nodeId: 'node-1', + vars: [{ + variable: 'answer', + schemaType: 'simple', + }], + }]) + }) + + it('fetches inspect vars, invalidates cached values, and stores schema-enriched node vars', async () => { + mockFetchAllInspectVars.mockResolvedValue([ + createInspectVar(), + createInspectVar({ + id: 'missing-node-var', + selector: ['missing-node', 'answer'], + }), + ]) + + const { result, store } = renderWorkflowHook( + () => useSetWorkflowVarsWithValue({ + flowType: FlowType.appFlow, + flowId: 'flow-1', + }), + { + initialStoreState: { + dataSourceList: [], + }, + }, + ) + + await act(async () => { + await result.current.fetchInspectVars({}) + }) + + expect(mockInvalidateConversationVarValues).toHaveBeenCalledTimes(1) + expect(mockInvalidateSysVarValues).toHaveBeenCalledTimes(1) + expect(mockFetchAllInspectVars).toHaveBeenCalledWith(FlowType.appFlow, 'flow-1') + expect(mockHandleCancelAllNodeSuccessStatus).toHaveBeenCalledTimes(1) + expect(store.getState().nodesWithInspectVars).toEqual([ + expect.objectContaining({ + nodeId: 'node-1', + nodeType: BlockEnum.Code, + title: 'Code', + vars: [ + expect.objectContaining({ + id: 'var-1', + selector: ['node-1', 'answer'], + schemaType: 'simple', + value: 'hello', + }), + ], + }), + ]) + }) + + it('accepts passed-in vars and plugin metadata without refetching from the API', async () => { + const passedInVars = [ + createInspectVar({ + id: 'var-2', + value: 'passed-in', + }), + ] + const passedInPluginInfo = { + buildInTools: [], + customTools: [], + workflowTools: [], + mcpTools: [], + dataSourceList: [], + } + + const { result, store } = renderWorkflowHook( + () => useSetWorkflowVarsWithValue({ + flowType: FlowType.appFlow, + flowId: 'flow-2', + }), + { + initialStoreState: { + dataSourceList: [], + }, + }, + ) + + await act(async () => { + await result.current.fetchInspectVars({ + passInVars: true, + vars: passedInVars, + passedInAllPluginInfoList: passedInPluginInfo, + passedInSchemaTypeDefinitions: schemaTypeDefinitions, + }) + }) + + await waitFor(() => { + expect(mockFetchAllInspectVars).not.toHaveBeenCalled() + expect(store.getState().nodesWithInspectVars[0]?.vars[0]).toMatchObject({ + id: 'var-2', + value: 'passed-in', + schemaType: 'simple', + }) + }) + }) +}) diff --git a/web/app/components/workflow/hooks/__tests__/use-inspect-vars-crud-common.spec.ts b/web/app/components/workflow/hooks/__tests__/use-inspect-vars-crud-common.spec.ts new file mode 100644 index 0000000000..7b2006aa77 --- /dev/null +++ b/web/app/components/workflow/hooks/__tests__/use-inspect-vars-crud-common.spec.ts @@ -0,0 +1,210 @@ +import type { SchemaTypeDefinition } from '@/service/use-common' +import type { VarInInspect } from '@/types/workflow' +import { act, waitFor } from '@testing-library/react' +import { FlowType } from '@/types/common' +import { createNode } from '../../__tests__/fixtures' +import { resetReactFlowMockState, rfState } from '../../__tests__/reactflow-mock-state' +import { renderWorkflowHook } from '../../__tests__/workflow-test-env' +import { BlockEnum, VarType } from '../../types' +import { useInspectVarsCrudCommon } from '../use-inspect-vars-crud-common' + +const mockFetchNodeInspectVars = vi.hoisted(() => vi.fn()) +const mockDoDeleteAllInspectorVars = vi.hoisted(() => vi.fn()) +const mockInvalidateConversationVarValues = vi.hoisted(() => vi.fn()) +const mockInvalidateSysVarValues = vi.hoisted(() => vi.fn()) +const mockHandleCancelNodeSuccessStatus = vi.hoisted(() => vi.fn()) +const mockHandleEdgeCancelRunningStatus = vi.hoisted(() => vi.fn()) +const mockToNodeOutputVars = vi.hoisted(() => vi.fn()) + +const schemaTypeDefinitions: SchemaTypeDefinition[] = [{ + name: 'simple', + schema: { + properties: {}, + }, +}] + +vi.mock('reactflow', async () => + (await import('../../__tests__/reactflow-mock-state')).createReactFlowModuleMock()) + +vi.mock('@/service/use-flow', () => ({ + default: () => ({ + useInvalidateConversationVarValues: () => mockInvalidateConversationVarValues, + useInvalidateSysVarValues: () => mockInvalidateSysVarValues, + useResetConversationVar: () => ({ mutateAsync: vi.fn() }), + useResetToLastRunValue: () => ({ mutateAsync: vi.fn() }), + useDeleteAllInspectorVars: () => ({ mutateAsync: mockDoDeleteAllInspectorVars }), + useDeleteNodeInspectorVars: () => ({ mutate: vi.fn() }), + useDeleteInspectVar: () => ({ mutate: vi.fn() }), + useEditInspectorVar: () => ({ mutateAsync: vi.fn() }), + }), +})) + +vi.mock('@/service/use-tools', async () => + (await import('../../__tests__/service-mock-factory')).createToolServiceMock()) + +vi.mock('@/service/workflow', () => ({ + fetchNodeInspectVars: (...args: unknown[]) => mockFetchNodeInspectVars(...args), +})) + +vi.mock('../use-nodes-interactions-without-sync', () => ({ + useNodesInteractionsWithoutSync: () => ({ + handleCancelNodeSuccessStatus: mockHandleCancelNodeSuccessStatus, + }), +})) + +vi.mock('../use-edges-interactions-without-sync', () => ({ + useEdgesInteractionsWithoutSync: () => ({ + handleEdgeCancelRunningStatus: mockHandleEdgeCancelRunningStatus, + }), +})) + +vi.mock('@/app/components/workflow/nodes/_base/components/variable/utils', async importOriginal => ({ + ...(await importOriginal()), + toNodeOutputVars: (...args: unknown[]) => mockToNodeOutputVars(...args), +})) + +const createInspectVar = (overrides: Partial = {}): VarInInspect => ({ + id: 'var-1', + type: 'node', + name: 'answer', + description: 'Answer', + selector: ['node-1', 'answer'], + value_type: VarType.string, + value: 'hello', + edited: false, + visible: true, + is_truncated: false, + full_content: { + size_bytes: 5, + download_url: 'https://example.com/answer.txt', + }, + ...overrides, +}) + +describe('useInspectVarsCrudCommon', () => { + beforeEach(() => { + vi.clearAllMocks() + resetReactFlowMockState() + rfState.nodes = [ + createNode({ + id: 'node-1', + data: { + type: BlockEnum.Code, + title: 'Code', + desc: '', + }, + }), + ] + mockToNodeOutputVars.mockReturnValue([{ + nodeId: 'node-1', + vars: [{ + variable: 'answer', + schemaType: 'simple', + }], + }]) + }) + + it('invalidates cached system vars without refetching node values for system selectors', async () => { + const { result } = renderWorkflowHook( + () => useInspectVarsCrudCommon({ + flowId: 'flow-1', + flowType: FlowType.appFlow, + }), + { + initialStoreState: { + dataSourceList: [], + }, + }, + ) + + await act(async () => { + await result.current.fetchInspectVarValue(['sys', 'query'], schemaTypeDefinitions) + }) + + expect(mockInvalidateSysVarValues).toHaveBeenCalledTimes(1) + expect(mockFetchNodeInspectVars).not.toHaveBeenCalled() + }) + + it('fetches node inspect vars, adds schema types, and marks the node as fetched', async () => { + mockFetchNodeInspectVars.mockResolvedValue([ + createInspectVar(), + ]) + + const { result, store } = renderWorkflowHook( + () => useInspectVarsCrudCommon({ + flowId: 'flow-1', + flowType: FlowType.appFlow, + }), + { + initialStoreState: { + dataSourceList: [], + nodesWithInspectVars: [{ + nodeId: 'node-1', + nodePayload: { + type: BlockEnum.Code, + title: 'Code', + desc: '', + } as never, + nodeType: BlockEnum.Code, + title: 'Code', + vars: [], + }], + }, + }, + ) + + await act(async () => { + await result.current.fetchInspectVarValue(['node-1', 'answer'], schemaTypeDefinitions) + }) + + await waitFor(() => { + expect(mockFetchNodeInspectVars).toHaveBeenCalledWith(FlowType.appFlow, 'flow-1', 'node-1') + expect(store.getState().nodesWithInspectVars[0]).toMatchObject({ + nodeId: 'node-1', + isValueFetched: true, + vars: [ + expect.objectContaining({ + id: 'var-1', + schemaType: 'simple', + }), + ], + }) + }) + }) + + it('deletes all inspect vars, invalidates cached values, and clears edge running state', async () => { + mockDoDeleteAllInspectorVars.mockResolvedValue(undefined) + + const { result, store } = renderWorkflowHook( + () => useInspectVarsCrudCommon({ + flowId: 'flow-1', + flowType: FlowType.appFlow, + }), + { + initialStoreState: { + nodesWithInspectVars: [{ + nodeId: 'node-1', + nodePayload: { + type: BlockEnum.Code, + title: 'Code', + desc: '', + } as never, + nodeType: BlockEnum.Code, + title: 'Code', + vars: [createInspectVar()], + }], + }, + }, + ) + + await act(async () => { + await result.current.deleteAllInspectorVars() + }) + + expect(mockDoDeleteAllInspectorVars).toHaveBeenCalledTimes(1) + expect(mockInvalidateConversationVarValues).toHaveBeenCalledTimes(1) + expect(mockInvalidateSysVarValues).toHaveBeenCalledTimes(1) + expect(mockHandleEdgeCancelRunningStatus).toHaveBeenCalledTimes(1) + expect(store.getState().nodesWithInspectVars).toEqual([]) + }) +}) diff --git a/web/app/components/workflow/hooks/__tests__/use-inspect-vars-crud.spec.ts b/web/app/components/workflow/hooks/__tests__/use-inspect-vars-crud.spec.ts new file mode 100644 index 0000000000..193e4307de --- /dev/null +++ b/web/app/components/workflow/hooks/__tests__/use-inspect-vars-crud.spec.ts @@ -0,0 +1,135 @@ +import type { VarInInspect } from '@/types/workflow' +import { FlowType } from '@/types/common' +import { renderWorkflowHook } from '../../__tests__/workflow-test-env' +import { BlockEnum, VarType } from '../../types' +import useInspectVarsCrud from '../use-inspect-vars-crud' + +const mockUseConversationVarValues = vi.hoisted(() => vi.fn()) +const mockUseSysVarValues = vi.hoisted(() => vi.fn()) + +vi.mock('@/service/use-workflow', () => ({ + useConversationVarValues: (...args: unknown[]) => mockUseConversationVarValues(...args), + useSysVarValues: (...args: unknown[]) => mockUseSysVarValues(...args), +})) + +const createInspectVar = (overrides: Partial = {}): VarInInspect => ({ + id: 'var-1', + type: 'node', + name: 'answer', + description: 'Answer', + selector: ['node-1', 'answer'], + value_type: VarType.string, + value: 'hello', + edited: false, + visible: true, + is_truncated: false, + full_content: { + size_bytes: 5, + download_url: 'https://example.com/answer.txt', + }, + ...overrides, +}) + +describe('useInspectVarsCrud', () => { + beforeEach(() => { + vi.clearAllMocks() + mockUseConversationVarValues.mockReturnValue({ + data: [createInspectVar({ + id: 'conversation-var', + name: 'history', + selector: ['conversation', 'history'], + })], + }) + mockUseSysVarValues.mockReturnValue({ + data: [ + createInspectVar({ + id: 'query-var', + name: 'query', + selector: ['sys', 'query'], + }), + createInspectVar({ + id: 'files-var', + name: 'files', + selector: ['sys', 'files'], + }), + createInspectVar({ + id: 'time-var', + name: 'time', + selector: ['sys', 'time'], + }), + ], + }) + }) + + it('appends query/files system vars to start-node inspect vars and filters them from the system list', () => { + const hasNodeInspectVars = vi.fn(() => true) + const deleteAllInspectorVars = vi.fn() + const fetchInspectVarValue = vi.fn() + + const { result } = renderWorkflowHook(() => useInspectVarsCrud(), { + initialStoreState: { + nodesWithInspectVars: [{ + nodeId: 'start-node', + nodePayload: { + type: BlockEnum.Start, + title: 'Start', + desc: '', + } as never, + nodeType: BlockEnum.Start, + title: 'Start', + vars: [createInspectVar({ + id: 'start-answer', + selector: ['start-node', 'answer'], + })], + }], + }, + hooksStoreProps: { + configsMap: { + flowId: 'flow-1', + flowType: FlowType.appFlow, + fileSettings: {} as never, + }, + hasNodeInspectVars, + fetchInspectVarValue, + editInspectVarValue: vi.fn(), + renameInspectVarName: vi.fn(), + appendNodeInspectVars: vi.fn(), + deleteInspectVar: vi.fn(), + deleteNodeInspectorVars: vi.fn(), + deleteAllInspectorVars, + isInspectVarEdited: vi.fn(() => false), + resetToLastRunVar: vi.fn(), + invalidateSysVarValues: vi.fn(), + resetConversationVar: vi.fn(), + invalidateConversationVarValues: vi.fn(), + hasSetInspectVar: vi.fn(() => false), + }, + }) + + expect(result.current.conversationVars).toHaveLength(1) + expect(result.current.systemVars.map(item => item.name)).toEqual(['time']) + expect(result.current.nodesWithInspectVars[0]?.vars.map(item => item.name)).toEqual([ + 'answer', + 'query', + 'files', + ]) + expect(result.current.hasNodeInspectVars).toBe(hasNodeInspectVars) + expect(result.current.fetchInspectVarValue).toBe(fetchInspectVarValue) + expect(result.current.deleteAllInspectorVars).toBe(deleteAllInspectorVars) + }) + + it('uses an empty flow id for rag pipeline conversation and system value queries', () => { + renderWorkflowHook(() => useInspectVarsCrud(), { + hooksStoreProps: { + configsMap: { + flowId: 'rag-flow', + flowType: FlowType.ragPipeline, + fileSettings: {} as never, + }, + }, + }) + + expect(mockUseConversationVarValues).toHaveBeenCalledWith(FlowType.ragPipeline, '') + expect(mockUseSysVarValues).toHaveBeenCalledWith(FlowType.ragPipeline, '') + }) +}) diff --git a/web/app/components/workflow/hooks/__tests__/use-nodes-available-var-list.spec.ts b/web/app/components/workflow/hooks/__tests__/use-nodes-available-var-list.spec.ts new file mode 100644 index 0000000000..55db395f2e --- /dev/null +++ b/web/app/components/workflow/hooks/__tests__/use-nodes-available-var-list.spec.ts @@ -0,0 +1,110 @@ +import type { Node, NodeOutPutVar, Var } from '../../types' +import { renderHook } from '@testing-library/react' +import { BlockEnum, VarType } from '../../types' +import useNodesAvailableVarList, { useGetNodesAvailableVarList } from '../use-nodes-available-var-list' + +const mockGetTreeLeafNodes = vi.hoisted(() => vi.fn()) +const mockGetBeforeNodesInSameBranchIncludeParent = vi.hoisted(() => vi.fn()) +const mockGetNodeAvailableVars = vi.hoisted(() => vi.fn()) + +vi.mock('@/app/components/workflow/hooks', () => ({ + useIsChatMode: () => true, + useWorkflow: () => ({ + getTreeLeafNodes: mockGetTreeLeafNodes, + getBeforeNodesInSameBranchIncludeParent: mockGetBeforeNodesInSameBranchIncludeParent, + }), + useWorkflowVariables: () => ({ + getNodeAvailableVars: mockGetNodeAvailableVars, + }), +})) + +const createNode = (overrides: Partial = {}): Node => ({ + id: 'node-1', + type: 'custom', + position: { x: 0, y: 0 }, + data: { + type: BlockEnum.LLM, + title: 'Node', + desc: '', + }, + ...overrides, +} as Node) + +const outputVars: NodeOutPutVar[] = [{ + nodeId: 'vars-node', + title: 'Vars', + vars: [{ + variable: 'name', + type: VarType.string, + }] satisfies Var[], +}] + +describe('useNodesAvailableVarList', () => { + beforeEach(() => { + vi.clearAllMocks() + mockGetBeforeNodesInSameBranchIncludeParent.mockImplementation((nodeId: string) => [createNode({ id: `before-${nodeId}` })]) + mockGetTreeLeafNodes.mockImplementation((nodeId: string) => [createNode({ id: `leaf-${nodeId}` })]) + mockGetNodeAvailableVars.mockReturnValue(outputVars) + }) + + it('builds availability per node, carrying loop nodes and parent iteration context', () => { + const loopNode = createNode({ + id: 'loop-1', + data: { + type: BlockEnum.Loop, + title: 'Loop', + desc: '', + }, + }) + const childNode = createNode({ + id: 'child-1', + parentId: 'loop-1', + data: { + type: BlockEnum.LLM, + title: 'Writer', + desc: '', + }, + }) + const filterVar = vi.fn(() => true) + + const { result } = renderHook(() => useNodesAvailableVarList([loopNode, childNode], { + filterVar, + hideEnv: true, + hideChatVar: true, + })) + + expect(mockGetBeforeNodesInSameBranchIncludeParent).toHaveBeenCalledWith('loop-1') + expect(mockGetBeforeNodesInSameBranchIncludeParent).toHaveBeenCalledWith('child-1') + expect(result.current['loop-1']?.availableNodes.map(node => node.id)).toEqual(['before-loop-1', 'loop-1']) + expect(result.current['child-1']?.availableVars).toBe(outputVars) + expect(mockGetNodeAvailableVars).toHaveBeenNthCalledWith(2, expect.objectContaining({ + parentNode: loopNode, + isChatMode: true, + filterVar, + hideEnv: true, + hideChatVar: true, + })) + }) + + it('returns a callback version that can use leaf nodes or caller-provided nodes', () => { + const firstNode = createNode({ id: 'node-a' }) + const secondNode = createNode({ id: 'node-b' }) + const filterVar = vi.fn(() => true) + const passedInAvailableNodes = [createNode({ id: 'manual-node' })] + + const { result } = renderHook(() => useGetNodesAvailableVarList()) + + const leafMap = result.current.getNodesAvailableVarList([firstNode], { + onlyLeafNodeVar: true, + filterVar, + }) + const manualMap = result.current.getNodesAvailableVarList([secondNode], { + filterVar, + passedInAvailableNodes, + }) + + expect(mockGetTreeLeafNodes).toHaveBeenCalledWith('node-a') + expect(leafMap['node-a']?.availableNodes.map(node => node.id)).toEqual(['leaf-node-a']) + expect(manualMap['node-b']?.availableNodes).toBe(passedInAvailableNodes) + }) +}) diff --git a/web/app/components/workflow/hooks/__tests__/use-nodes-interactions-without-sync.spec.ts b/web/app/components/workflow/hooks/__tests__/use-nodes-interactions-without-sync.spec.ts new file mode 100644 index 0000000000..1a2ebe9385 --- /dev/null +++ b/web/app/components/workflow/hooks/__tests__/use-nodes-interactions-without-sync.spec.ts @@ -0,0 +1,119 @@ +import { act, waitFor } from '@testing-library/react' +import { useNodes } from 'reactflow' +import { createNode } from '../../__tests__/fixtures' +import { renderWorkflowFlowHook } from '../../__tests__/workflow-test-env' +import { NodeRunningStatus } from '../../types' +import { useNodesInteractionsWithoutSync } from '../use-nodes-interactions-without-sync' + +type NodeRuntimeState = { + _runningStatus?: NodeRunningStatus + _waitingRun?: boolean +} + +const getNodeRuntimeState = (node?: { data?: unknown }): NodeRuntimeState => + (node?.data ?? {}) as NodeRuntimeState + +const createFlowNodes = () => [ + createNode({ id: 'n1', data: { _runningStatus: NodeRunningStatus.Running, _waitingRun: true } }), + createNode({ id: 'n2', position: { x: 100, y: 0 }, data: { _runningStatus: NodeRunningStatus.Succeeded, _waitingRun: false } }), + createNode({ id: 'n3', position: { x: 200, y: 0 }, data: { _runningStatus: NodeRunningStatus.Failed, _waitingRun: true } }), +] + +const renderNodesInteractionsHook = () => + renderWorkflowFlowHook(() => ({ + ...useNodesInteractionsWithoutSync(), + nodes: useNodes(), + }), { + nodes: createFlowNodes(), + edges: [], + }) + +describe('useNodesInteractionsWithoutSync', () => { + it('clears _runningStatus and _waitingRun on all nodes', async () => { + const { result } = renderNodesInteractionsHook() + + act(() => { + result.current.handleNodeCancelRunningStatus() + }) + + await waitFor(() => { + result.current.nodes.forEach((node) => { + const nodeState = getNodeRuntimeState(node) + expect(nodeState._runningStatus).toBeUndefined() + expect(nodeState._waitingRun).toBe(false) + }) + }) + }) + + it('clears _runningStatus only for Succeeded nodes', async () => { + const { result } = renderNodesInteractionsHook() + + act(() => { + result.current.handleCancelAllNodeSuccessStatus() + }) + + await waitFor(() => { + const n1 = result.current.nodes.find(node => node.id === 'n1') + const n2 = result.current.nodes.find(node => node.id === 'n2') + const n3 = result.current.nodes.find(node => node.id === 'n3') + + expect(getNodeRuntimeState(n1)._runningStatus).toBe(NodeRunningStatus.Running) + expect(getNodeRuntimeState(n2)._runningStatus).toBeUndefined() + expect(getNodeRuntimeState(n3)._runningStatus).toBe(NodeRunningStatus.Failed) + }) + }) + + it('does not modify _waitingRun when clearing all success status', async () => { + const { result } = renderNodesInteractionsHook() + + act(() => { + result.current.handleCancelAllNodeSuccessStatus() + }) + + await waitFor(() => { + expect(getNodeRuntimeState(result.current.nodes.find(node => node.id === 'n1'))._waitingRun).toBe(true) + expect(getNodeRuntimeState(result.current.nodes.find(node => node.id === 'n3'))._waitingRun).toBe(true) + }) + }) + + it('clears _runningStatus and _waitingRun for the specified succeeded node', async () => { + const { result } = renderNodesInteractionsHook() + + act(() => { + result.current.handleCancelNodeSuccessStatus('n2') + }) + + await waitFor(() => { + const n2 = result.current.nodes.find(node => node.id === 'n2') + expect(getNodeRuntimeState(n2)._runningStatus).toBeUndefined() + expect(getNodeRuntimeState(n2)._waitingRun).toBe(false) + }) + }) + + it('does not modify nodes that are not succeeded', async () => { + const { result } = renderNodesInteractionsHook() + + act(() => { + result.current.handleCancelNodeSuccessStatus('n1') + }) + + await waitFor(() => { + const n1 = result.current.nodes.find(node => node.id === 'n1') + expect(getNodeRuntimeState(n1)._runningStatus).toBe(NodeRunningStatus.Running) + expect(getNodeRuntimeState(n1)._waitingRun).toBe(true) + }) + }) + + it('does not modify other nodes', async () => { + const { result } = renderNodesInteractionsHook() + + act(() => { + result.current.handleCancelNodeSuccessStatus('n2') + }) + + await waitFor(() => { + const n1 = result.current.nodes.find(node => node.id === 'n1') + expect(getNodeRuntimeState(n1)._runningStatus).toBe(NodeRunningStatus.Running) + }) + }) +}) diff --git a/web/app/components/workflow/hooks/__tests__/use-nodes-interactions.spec.ts b/web/app/components/workflow/hooks/__tests__/use-nodes-interactions.spec.ts new file mode 100644 index 0000000000..35a309902e --- /dev/null +++ b/web/app/components/workflow/hooks/__tests__/use-nodes-interactions.spec.ts @@ -0,0 +1,205 @@ +import type { Edge, Node } from '../../types' +import { act } from '@testing-library/react' +import { createEdge, createNode } from '../../__tests__/fixtures' +import { resetReactFlowMockState, rfState } from '../../__tests__/reactflow-mock-state' +import { renderWorkflowHook } from '../../__tests__/workflow-test-env' +import { BlockEnum } from '../../types' +import { useNodesInteractions } from '../use-nodes-interactions' + +const mockHandleSyncWorkflowDraft = vi.hoisted(() => vi.fn()) +const mockSaveStateToHistory = vi.hoisted(() => vi.fn()) +const mockUndo = vi.hoisted(() => vi.fn()) +const mockRedo = vi.hoisted(() => vi.fn()) + +const runtimeState = vi.hoisted(() => ({ + nodesReadOnly: false, + workflowReadOnly: false, +})) + +let currentNodes: Node[] = [] +let currentEdges: Edge[] = [] + +vi.mock('reactflow', async () => + (await import('../../__tests__/reactflow-mock-state')).createReactFlowModuleMock()) + +vi.mock('../use-workflow', () => ({ + useWorkflow: () => ({ + getAfterNodesInSameBranch: () => [], + }), + useNodesReadOnly: () => ({ + getNodesReadOnly: () => runtimeState.nodesReadOnly, + }), + useWorkflowReadOnly: () => ({ + getWorkflowReadOnly: () => runtimeState.workflowReadOnly, + }), +})) + +vi.mock('../use-helpline', () => ({ + useHelpline: () => ({ + handleSetHelpline: () => ({ + showHorizontalHelpLineNodes: [], + showVerticalHelpLineNodes: [], + }), + }), +})) + +vi.mock('../use-nodes-meta-data', () => ({ + useNodesMetaData: () => ({ + nodesMap: {}, + }), +})) + +vi.mock('../use-nodes-sync-draft', () => ({ + useNodesSyncDraft: () => ({ + handleSyncWorkflowDraft: mockHandleSyncWorkflowDraft, + }), +})) + +vi.mock('../use-auto-generate-webhook-url', () => ({ + useAutoGenerateWebhookUrl: () => vi.fn(), +})) + +vi.mock('../use-inspect-vars-crud', () => ({ + default: () => ({ + deleteNodeInspectorVars: vi.fn(), + }), +})) + +vi.mock('../../nodes/iteration/use-interactions', () => ({ + useNodeIterationInteractions: () => ({ + handleNodeIterationChildDrag: () => ({ restrictPosition: {} }), + handleNodeIterationChildrenCopy: vi.fn(), + }), +})) + +vi.mock('../../nodes/loop/use-interactions', () => ({ + useNodeLoopInteractions: () => ({ + handleNodeLoopChildDrag: () => ({ restrictPosition: {} }), + handleNodeLoopChildrenCopy: vi.fn(), + }), +})) + +vi.mock('../use-workflow-history', async importOriginal => ({ + ...(await importOriginal()), + useWorkflowHistory: () => ({ + saveStateToHistory: mockSaveStateToHistory, + undo: mockUndo, + redo: mockRedo, + }), +})) + +describe('useNodesInteractions', () => { + beforeEach(() => { + vi.clearAllMocks() + resetReactFlowMockState() + runtimeState.nodesReadOnly = false + runtimeState.workflowReadOnly = false + currentNodes = [ + createNode({ + id: 'node-1', + position: { x: 10, y: 20 }, + data: { + type: BlockEnum.Code, + title: 'Code', + desc: '', + }, + }), + ] + currentEdges = [ + createEdge({ + id: 'edge-1', + source: 'node-1', + target: 'node-2', + }), + ] + rfState.nodes = currentNodes as unknown as typeof rfState.nodes + rfState.edges = currentEdges as unknown as typeof rfState.edges + }) + + it('persists node drags only when the node position actually changes', () => { + const node = currentNodes[0] + const movedNode = { + ...node, + position: { x: 120, y: 80 }, + } + + const { result, store } = renderWorkflowHook(() => useNodesInteractions(), { + historyStore: { + nodes: currentNodes, + edges: currentEdges, + }, + }) + + act(() => { + result.current.handleNodeDragStart({} as never, node, currentNodes) + result.current.handleNodeDragStop({} as never, movedNode, currentNodes) + }) + + expect(store.getState().nodeAnimation).toBe(false) + expect(mockHandleSyncWorkflowDraft).toHaveBeenCalledTimes(1) + expect(mockSaveStateToHistory).toHaveBeenCalledWith('NodeDragStop', { + nodeId: 'node-1', + }) + }) + + it('restores history snapshots on undo and clears the edge menu', () => { + const historyNodes = [ + createNode({ + id: 'history-node', + data: { + type: BlockEnum.End, + title: 'End', + desc: '', + }, + }), + ] + const historyEdges = [ + createEdge({ + id: 'history-edge', + source: 'history-node', + target: 'node-1', + }), + ] + + const { result, store } = renderWorkflowHook(() => useNodesInteractions(), { + initialStoreState: { + edgeMenu: { + id: 'edge-1', + } as never, + }, + historyStore: { + nodes: historyNodes, + edges: historyEdges, + }, + }) + + act(() => { + result.current.handleHistoryBack() + }) + + expect(mockUndo).toHaveBeenCalledTimes(1) + expect(rfState.setNodes).toHaveBeenCalledWith(historyNodes) + expect(rfState.setEdges).toHaveBeenCalledWith(historyEdges) + expect(store.getState().edgeMenu).toBeUndefined() + }) + + it('skips undo and redo when the workflow is read-only', () => { + runtimeState.workflowReadOnly = true + const { result } = renderWorkflowHook(() => useNodesInteractions(), { + historyStore: { + nodes: currentNodes, + edges: currentEdges, + }, + }) + + act(() => { + result.current.handleHistoryBack() + result.current.handleHistoryForward() + }) + + expect(mockUndo).not.toHaveBeenCalled() + expect(mockRedo).not.toHaveBeenCalled() + expect(rfState.setNodes).not.toHaveBeenCalled() + expect(rfState.setEdges).not.toHaveBeenCalled() + }) +}) diff --git a/web/app/components/workflow/hooks/__tests__/use-nodes-meta-data.spec.tsx b/web/app/components/workflow/hooks/__tests__/use-nodes-meta-data.spec.tsx new file mode 100644 index 0000000000..9dffa46cb2 --- /dev/null +++ b/web/app/components/workflow/hooks/__tests__/use-nodes-meta-data.spec.tsx @@ -0,0 +1,153 @@ +import type { Node } from '../../types' +import { CollectionType } from '@/app/components/tools/types' +import { renderWorkflowHook } from '../../__tests__/workflow-test-env' +import { BlockEnum } from '../../types' +import { useNodeMetaData, useNodesMetaData } from '../use-nodes-meta-data' + +const buildInToolsState = vi.hoisted(() => [] as Array<{ id: string, author: string, description: Record }>) +const customToolsState = vi.hoisted(() => [] as Array<{ id: string, author: string, description: Record }>) +const workflowToolsState = vi.hoisted(() => [] as Array<{ id: string, author: string, description: Record }>) + +vi.mock('@/context/i18n', () => ({ + useGetLanguage: () => 'en-US', +})) + +vi.mock('@/service/use-tools', () => ({ + useAllBuiltInTools: () => ({ data: buildInToolsState }), + useAllCustomTools: () => ({ data: customToolsState }), + useAllWorkflowTools: () => ({ data: workflowToolsState }), +})) + +const createNode = (overrides: Partial = {}): Node => ({ + id: 'node-1', + type: 'custom', + position: { x: 0, y: 0 }, + data: { + type: BlockEnum.LLM, + title: 'Node', + desc: '', + }, + ...overrides, +} as Node) + +describe('useNodesMetaData', () => { + beforeEach(() => { + vi.clearAllMocks() + buildInToolsState.length = 0 + customToolsState.length = 0 + workflowToolsState.length = 0 + }) + + it('returns empty metadata collections when the hooks store has no node map', () => { + const { result } = renderWorkflowHook(() => useNodesMetaData(), { + hooksStoreProps: {}, + }) + + expect(result.current).toEqual({ + nodes: [], + nodesMap: {}, + }) + }) + + it('resolves built-in tool metadata from tool providers', () => { + buildInToolsState.push({ + id: 'provider-1', + author: 'Provider Author', + description: { + 'en-US': 'Built-in provider description', + }, + }) + + const toolNode = createNode({ + data: { + type: BlockEnum.Tool, + title: 'Tool Node', + desc: '', + provider_type: CollectionType.builtIn, + provider_id: 'provider-1', + }, + }) + + const { result } = renderWorkflowHook(() => useNodeMetaData(toolNode), { + hooksStoreProps: { + availableNodesMetaData: { + nodes: [], + }, + }, + }) + + expect(result.current).toEqual(expect.objectContaining({ + author: 'Provider Author', + description: 'Built-in provider description', + })) + }) + + it('prefers workflow store data for datasource nodes and keeps generic metadata for normal blocks', () => { + const datasourceNode = createNode({ + data: { + type: BlockEnum.DataSource, + title: 'Dataset', + desc: '', + plugin_id: 'datasource-1', + }, + }) + + const normalNode = createNode({ + data: { + type: BlockEnum.LLM, + title: 'Writer', + desc: '', + }, + }) + + const datasource = { + plugin_id: 'datasource-1', + author: 'Datasource Author', + description: { + 'en-US': 'Datasource description', + }, + } + + const metadataMap = { + [BlockEnum.LLM]: { + metaData: { + type: BlockEnum.LLM, + title: 'LLM', + author: 'Dify', + description: 'Node description', + }, + }, + } + + const datasourceResult = renderWorkflowHook(() => useNodeMetaData(datasourceNode), { + initialStoreState: { + dataSourceList: [datasource as never], + }, + hooksStoreProps: { + availableNodesMetaData: { + nodes: [], + nodesMap: metadataMap as never, + }, + }, + }) + + const normalResult = renderWorkflowHook(() => useNodeMetaData(normalNode), { + hooksStoreProps: { + availableNodesMetaData: { + nodes: [], + nodesMap: metadataMap as never, + }, + }, + }) + + expect(datasourceResult.result.current).toEqual(expect.objectContaining({ + author: 'Datasource Author', + description: 'Datasource description', + })) + expect(normalResult.result.current).toEqual(expect.objectContaining({ + author: 'Dify', + description: 'Node description', + title: 'LLM', + })) + }) +}) diff --git a/web/app/components/workflow/hooks/__tests__/use-set-workflow-vars-with-value.spec.ts b/web/app/components/workflow/hooks/__tests__/use-set-workflow-vars-with-value.spec.ts new file mode 100644 index 0000000000..c0d693cf24 --- /dev/null +++ b/web/app/components/workflow/hooks/__tests__/use-set-workflow-vars-with-value.spec.ts @@ -0,0 +1,14 @@ +import { renderWorkflowHook } from '../../__tests__/workflow-test-env' +import { useSetWorkflowVarsWithValue } from '../use-set-workflow-vars-with-value' + +describe('useSetWorkflowVarsWithValue', () => { + it('returns fetchInspectVars from hooks store', () => { + const fetchInspectVars = vi.fn() + + const { result } = renderWorkflowHook(() => useSetWorkflowVarsWithValue(), { + hooksStoreProps: { fetchInspectVars }, + }) + + expect(result.current.fetchInspectVars).toBe(fetchInspectVars) + }) +}) diff --git a/web/app/components/workflow/hooks/__tests__/use-shortcuts.spec.ts b/web/app/components/workflow/hooks/__tests__/use-shortcuts.spec.ts new file mode 100644 index 0000000000..b3c63ff519 --- /dev/null +++ b/web/app/components/workflow/hooks/__tests__/use-shortcuts.spec.ts @@ -0,0 +1,168 @@ +import { act } from '@testing-library/react' +import { ZEN_TOGGLE_EVENT } from '@/app/components/goto-anything/actions/commands/zen' +import { renderWorkflowHook } from '../../__tests__/workflow-test-env' +import { useShortcuts } from '../use-shortcuts' + +type KeyPressRegistration = { + keyFilter: unknown + handler: (event: KeyboardEvent) => void + options?: { + events?: string[] + } +} + +const keyPressRegistrations = vi.hoisted(() => []) +const mockZoomTo = vi.hoisted(() => vi.fn()) +const mockGetZoom = vi.hoisted(() => vi.fn(() => 1)) +const mockFitView = vi.hoisted(() => vi.fn()) +const mockHandleNodesDelete = vi.hoisted(() => vi.fn()) +const mockHandleEdgeDelete = vi.hoisted(() => vi.fn()) +const mockHandleNodesCopy = vi.hoisted(() => vi.fn()) +const mockHandleNodesPaste = vi.hoisted(() => vi.fn()) +const mockHandleNodesDuplicate = vi.hoisted(() => vi.fn()) +const mockHandleHistoryBack = vi.hoisted(() => vi.fn()) +const mockHandleHistoryForward = vi.hoisted(() => vi.fn()) +const mockDimOtherNodes = vi.hoisted(() => vi.fn()) +const mockUndimAllNodes = vi.hoisted(() => vi.fn()) +const mockHandleSyncWorkflowDraft = vi.hoisted(() => vi.fn()) +const mockHandleModeHand = vi.hoisted(() => vi.fn()) +const mockHandleModePointer = vi.hoisted(() => vi.fn()) +const mockHandleLayout = vi.hoisted(() => vi.fn()) +const mockHandleToggleMaximizeCanvas = vi.hoisted(() => vi.fn()) + +vi.mock('ahooks', () => ({ + useKeyPress: (keyFilter: unknown, handler: (event: KeyboardEvent) => void, options?: { events?: string[] }) => { + keyPressRegistrations.push({ keyFilter, handler, options }) + }, +})) + +vi.mock('reactflow', () => ({ + useReactFlow: () => ({ + zoomTo: mockZoomTo, + getZoom: mockGetZoom, + fitView: mockFitView, + }), +})) + +vi.mock('..', () => ({ + useNodesInteractions: () => ({ + handleNodesCopy: mockHandleNodesCopy, + handleNodesPaste: mockHandleNodesPaste, + handleNodesDuplicate: mockHandleNodesDuplicate, + handleNodesDelete: mockHandleNodesDelete, + handleHistoryBack: mockHandleHistoryBack, + handleHistoryForward: mockHandleHistoryForward, + dimOtherNodes: mockDimOtherNodes, + undimAllNodes: mockUndimAllNodes, + }), + useEdgesInteractions: () => ({ + handleEdgeDelete: mockHandleEdgeDelete, + }), + useNodesSyncDraft: () => ({ + handleSyncWorkflowDraft: mockHandleSyncWorkflowDraft, + }), + useWorkflowCanvasMaximize: () => ({ + handleToggleMaximizeCanvas: mockHandleToggleMaximizeCanvas, + }), + useWorkflowMoveMode: () => ({ + handleModeHand: mockHandleModeHand, + handleModePointer: mockHandleModePointer, + }), + useWorkflowOrganize: () => ({ + handleLayout: mockHandleLayout, + }), +})) + +vi.mock('../../workflow-history-store', () => ({ + useWorkflowHistoryStore: () => ({ + shortcutsEnabled: true, + }), +})) + +const createKeyboardEvent = (target: HTMLElement = document.body) => ({ + preventDefault: vi.fn(), + target, +}) as unknown as KeyboardEvent + +const findRegistration = (matcher: (registration: KeyPressRegistration) => boolean) => { + const registration = keyPressRegistrations.find(matcher) + expect(registration).toBeDefined() + return registration as KeyPressRegistration +} + +describe('useShortcuts', () => { + beforeEach(() => { + keyPressRegistrations.length = 0 + vi.clearAllMocks() + }) + + it('deletes selected nodes and edges only outside editable inputs', () => { + renderWorkflowHook(() => useShortcuts()) + + const deleteShortcut = findRegistration(registration => + Array.isArray(registration.keyFilter) + && registration.keyFilter.includes('delete'), + ) + + const bodyEvent = createKeyboardEvent() + deleteShortcut.handler(bodyEvent) + + expect(bodyEvent.preventDefault).toHaveBeenCalled() + expect(mockHandleNodesDelete).toHaveBeenCalledTimes(1) + expect(mockHandleEdgeDelete).toHaveBeenCalledTimes(1) + + const inputEvent = createKeyboardEvent(document.createElement('input')) + deleteShortcut.handler(inputEvent) + + expect(mockHandleNodesDelete).toHaveBeenCalledTimes(1) + expect(mockHandleEdgeDelete).toHaveBeenCalledTimes(1) + }) + + it('runs layout and zoom shortcuts through the workflow actions', () => { + renderWorkflowHook(() => useShortcuts()) + + const layoutShortcut = findRegistration(registration => registration.keyFilter === 'ctrl.o' || registration.keyFilter === 'meta.o') + const fitViewShortcut = findRegistration(registration => registration.keyFilter === 'ctrl.1' || registration.keyFilter === 'meta.1') + const halfZoomShortcut = findRegistration(registration => registration.keyFilter === 'shift.5') + const zoomOutShortcut = findRegistration(registration => registration.keyFilter === 'ctrl.dash' || registration.keyFilter === 'meta.dash') + const zoomInShortcut = findRegistration(registration => registration.keyFilter === 'ctrl.equalsign' || registration.keyFilter === 'meta.equalsign') + + layoutShortcut.handler(createKeyboardEvent()) + fitViewShortcut.handler(createKeyboardEvent()) + halfZoomShortcut.handler(createKeyboardEvent()) + zoomOutShortcut.handler(createKeyboardEvent()) + zoomInShortcut.handler(createKeyboardEvent()) + + expect(mockHandleLayout).toHaveBeenCalledTimes(1) + expect(mockFitView).toHaveBeenCalledTimes(1) + expect(mockZoomTo).toHaveBeenNthCalledWith(1, 0.5) + expect(mockZoomTo).toHaveBeenNthCalledWith(2, 0.9) + expect(mockZoomTo).toHaveBeenNthCalledWith(3, 1.1) + expect(mockHandleSyncWorkflowDraft).toHaveBeenCalledTimes(4) + }) + + it('dims on shift down, undims on shift up, and responds to zen toggle events', () => { + const { unmount } = renderWorkflowHook(() => useShortcuts()) + + const shiftDownShortcut = findRegistration(registration => registration.keyFilter === 'shift' && registration.options?.events?.[0] === 'keydown') + const shiftUpShortcut = findRegistration(registration => typeof registration.keyFilter === 'function' && registration.options?.events?.[0] === 'keyup') + + shiftDownShortcut.handler(createKeyboardEvent()) + shiftUpShortcut.handler({ ...createKeyboardEvent(), key: 'Shift' } as KeyboardEvent) + + expect(mockDimOtherNodes).toHaveBeenCalledTimes(1) + expect(mockUndimAllNodes).toHaveBeenCalledTimes(1) + + act(() => { + window.dispatchEvent(new Event(ZEN_TOGGLE_EVENT)) + }) + expect(mockHandleToggleMaximizeCanvas).toHaveBeenCalledTimes(1) + + unmount() + + act(() => { + window.dispatchEvent(new Event(ZEN_TOGGLE_EVENT)) + }) + expect(mockHandleToggleMaximizeCanvas).toHaveBeenCalledTimes(1) + }) +}) diff --git a/web/app/components/workflow/hooks/__tests__/use-without-sync-hooks.spec.ts b/web/app/components/workflow/hooks/__tests__/use-without-sync-hooks.spec.ts deleted file mode 100644 index 2d40028226..0000000000 --- a/web/app/components/workflow/hooks/__tests__/use-without-sync-hooks.spec.ts +++ /dev/null @@ -1,209 +0,0 @@ -import { act, waitFor } from '@testing-library/react' -import { useEdges, useNodes } from 'reactflow' -import { createEdge, createNode } from '../../__tests__/fixtures' -import { renderWorkflowFlowHook } from '../../__tests__/workflow-test-env' -import { NodeRunningStatus } from '../../types' -import { useEdgesInteractionsWithoutSync } from '../use-edges-interactions-without-sync' -import { useNodesInteractionsWithoutSync } from '../use-nodes-interactions-without-sync' - -type EdgeRuntimeState = { - _sourceRunningStatus?: NodeRunningStatus - _targetRunningStatus?: NodeRunningStatus - _waitingRun?: boolean -} - -type NodeRuntimeState = { - _runningStatus?: NodeRunningStatus - _waitingRun?: boolean -} - -const getEdgeRuntimeState = (edge?: { data?: unknown }): EdgeRuntimeState => - (edge?.data ?? {}) as EdgeRuntimeState - -const getNodeRuntimeState = (node?: { data?: unknown }): NodeRuntimeState => - (node?.data ?? {}) as NodeRuntimeState - -describe('useEdgesInteractionsWithoutSync', () => { - const createFlowNodes = () => [ - createNode({ id: 'a' }), - createNode({ id: 'b' }), - createNode({ id: 'c' }), - ] - const createFlowEdges = () => [ - createEdge({ - id: 'e1', - source: 'a', - target: 'b', - data: { - _sourceRunningStatus: NodeRunningStatus.Running, - _targetRunningStatus: NodeRunningStatus.Running, - _waitingRun: true, - }, - }), - createEdge({ - id: 'e2', - source: 'b', - target: 'c', - data: { - _sourceRunningStatus: NodeRunningStatus.Succeeded, - _targetRunningStatus: undefined, - _waitingRun: false, - }, - }), - ] - - const renderEdgesInteractionsHook = () => - renderWorkflowFlowHook(() => ({ - ...useEdgesInteractionsWithoutSync(), - edges: useEdges(), - }), { - nodes: createFlowNodes(), - edges: createFlowEdges(), - }) - - it('should clear running status and waitingRun on all edges', () => { - const { result } = renderEdgesInteractionsHook() - - act(() => { - result.current.handleEdgeCancelRunningStatus() - }) - - return waitFor(() => { - result.current.edges.forEach((edge) => { - const edgeState = getEdgeRuntimeState(edge) - expect(edgeState._sourceRunningStatus).toBeUndefined() - expect(edgeState._targetRunningStatus).toBeUndefined() - expect(edgeState._waitingRun).toBe(false) - }) - }) - }) - - it('should not mutate original edges', () => { - const edges = createFlowEdges() - const originalData = { ...getEdgeRuntimeState(edges[0]) } - const { result } = renderWorkflowFlowHook(() => ({ - ...useEdgesInteractionsWithoutSync(), - edges: useEdges(), - }), { - nodes: createFlowNodes(), - edges, - }) - - act(() => { - result.current.handleEdgeCancelRunningStatus() - }) - - expect(getEdgeRuntimeState(edges[0])._sourceRunningStatus).toBe(originalData._sourceRunningStatus) - }) -}) - -describe('useNodesInteractionsWithoutSync', () => { - const createFlowNodes = () => [ - createNode({ id: 'n1', data: { _runningStatus: NodeRunningStatus.Running, _waitingRun: true } }), - createNode({ id: 'n2', position: { x: 100, y: 0 }, data: { _runningStatus: NodeRunningStatus.Succeeded, _waitingRun: false } }), - createNode({ id: 'n3', position: { x: 200, y: 0 }, data: { _runningStatus: NodeRunningStatus.Failed, _waitingRun: true } }), - ] - - const renderNodesInteractionsHook = () => - renderWorkflowFlowHook(() => ({ - ...useNodesInteractionsWithoutSync(), - nodes: useNodes(), - }), { - nodes: createFlowNodes(), - edges: [], - }) - - describe('handleNodeCancelRunningStatus', () => { - it('should clear _runningStatus and _waitingRun on all nodes', async () => { - const { result } = renderNodesInteractionsHook() - - act(() => { - result.current.handleNodeCancelRunningStatus() - }) - - await waitFor(() => { - result.current.nodes.forEach((node) => { - const nodeState = getNodeRuntimeState(node) - expect(nodeState._runningStatus).toBeUndefined() - expect(nodeState._waitingRun).toBe(false) - }) - }) - }) - }) - - describe('handleCancelAllNodeSuccessStatus', () => { - it('should clear _runningStatus only for Succeeded nodes', async () => { - const { result } = renderNodesInteractionsHook() - - act(() => { - result.current.handleCancelAllNodeSuccessStatus() - }) - - await waitFor(() => { - const n1 = result.current.nodes.find(node => node.id === 'n1') - const n2 = result.current.nodes.find(node => node.id === 'n2') - const n3 = result.current.nodes.find(node => node.id === 'n3') - - expect(getNodeRuntimeState(n1)._runningStatus).toBe(NodeRunningStatus.Running) - expect(getNodeRuntimeState(n2)._runningStatus).toBeUndefined() - expect(getNodeRuntimeState(n3)._runningStatus).toBe(NodeRunningStatus.Failed) - }) - }) - - it('should not modify _waitingRun', async () => { - const { result } = renderNodesInteractionsHook() - - act(() => { - result.current.handleCancelAllNodeSuccessStatus() - }) - - await waitFor(() => { - expect(getNodeRuntimeState(result.current.nodes.find(node => node.id === 'n1'))._waitingRun).toBe(true) - expect(getNodeRuntimeState(result.current.nodes.find(node => node.id === 'n3'))._waitingRun).toBe(true) - }) - }) - }) - - describe('handleCancelNodeSuccessStatus', () => { - it('should clear _runningStatus and _waitingRun for the specified Succeeded node', async () => { - const { result } = renderNodesInteractionsHook() - - act(() => { - result.current.handleCancelNodeSuccessStatus('n2') - }) - - await waitFor(() => { - const n2 = result.current.nodes.find(node => node.id === 'n2') - expect(getNodeRuntimeState(n2)._runningStatus).toBeUndefined() - expect(getNodeRuntimeState(n2)._waitingRun).toBe(false) - }) - }) - - it('should not modify nodes that are not Succeeded', async () => { - const { result } = renderNodesInteractionsHook() - - act(() => { - result.current.handleCancelNodeSuccessStatus('n1') - }) - - await waitFor(() => { - const n1 = result.current.nodes.find(node => node.id === 'n1') - expect(getNodeRuntimeState(n1)._runningStatus).toBe(NodeRunningStatus.Running) - expect(getNodeRuntimeState(n1)._waitingRun).toBe(true) - }) - }) - - it('should not modify other nodes', async () => { - const { result } = renderNodesInteractionsHook() - - act(() => { - result.current.handleCancelNodeSuccessStatus('n2') - }) - - await waitFor(() => { - const n1 = result.current.nodes.find(node => node.id === 'n1') - expect(getNodeRuntimeState(n1)._runningStatus).toBe(NodeRunningStatus.Running) - }) - }) - }) -}) diff --git a/web/app/components/workflow/hooks/__tests__/use-workflow-canvas-maximize.spec.ts b/web/app/components/workflow/hooks/__tests__/use-workflow-canvas-maximize.spec.ts new file mode 100644 index 0000000000..f4cde1e72a --- /dev/null +++ b/web/app/components/workflow/hooks/__tests__/use-workflow-canvas-maximize.spec.ts @@ -0,0 +1,59 @@ +import { renderWorkflowHook } from '../../__tests__/workflow-test-env' +import { WorkflowRunningStatus } from '../../types' +import { useWorkflowCanvasMaximize } from '../use-workflow-canvas-maximize' + +const mockEmit = vi.hoisted(() => vi.fn()) + +vi.mock('@/context/event-emitter', () => ({ + useEventEmitterContextContext: () => ({ + eventEmitter: { + emit: mockEmit, + }, + }), +})) + +describe('useWorkflowCanvasMaximize', () => { + beforeEach(() => { + vi.clearAllMocks() + localStorage.clear() + }) + + it('toggles maximize state, persists it, and emits the canvas event', () => { + const { result, store } = renderWorkflowHook(() => useWorkflowCanvasMaximize(), { + initialStoreState: { + maximizeCanvas: false, + }, + }) + + result.current.handleToggleMaximizeCanvas() + + expect(store.getState().maximizeCanvas).toBe(true) + expect(localStorage.getItem('workflow-canvas-maximize')).toBe('true') + expect(mockEmit).toHaveBeenCalledWith({ + type: 'workflow-canvas-maximize', + payload: true, + }) + }) + + it('does nothing while workflow nodes are read-only', () => { + const { result, store } = renderWorkflowHook(() => useWorkflowCanvasMaximize(), { + initialStoreState: { + maximizeCanvas: false, + workflowRunningData: { + result: { + status: WorkflowRunningStatus.Running, + inputs_truncated: false, + process_data_truncated: false, + outputs_truncated: false, + }, + }, + }, + }) + + result.current.handleToggleMaximizeCanvas() + + expect(store.getState().maximizeCanvas).toBe(false) + expect(localStorage.getItem('workflow-canvas-maximize')).toBeNull() + expect(mockEmit).not.toHaveBeenCalled() + }) +}) diff --git a/web/app/components/workflow/hooks/__tests__/use-workflow-history.spec.tsx b/web/app/components/workflow/hooks/__tests__/use-workflow-history.spec.tsx new file mode 100644 index 0000000000..54917d009c --- /dev/null +++ b/web/app/components/workflow/hooks/__tests__/use-workflow-history.spec.tsx @@ -0,0 +1,141 @@ +import type { Edge, Node } from '../../types' +import { act } from '@testing-library/react' +import { renderWorkflowHook } from '../../__tests__/workflow-test-env' +import { BlockEnum } from '../../types' +import { useWorkflowHistory, WorkflowHistoryEvent } from '../use-workflow-history' + +const reactFlowState = vi.hoisted(() => ({ + edges: [] as Edge[], + nodes: [] as Node[], +})) + +vi.mock('es-toolkit/compat', () => ({ + debounce: unknown>(fn: T) => fn, +})) + +vi.mock('reactflow', async () => { + const actual = await vi.importActual('reactflow') + return { + ...actual, + useStoreApi: () => ({ + getState: () => ({ + getNodes: () => reactFlowState.nodes, + edges: reactFlowState.edges, + }), + }), + } +}) + +vi.mock('react-i18next', async () => { + const actual = await vi.importActual('react-i18next') + return { + ...actual, + useTranslation: () => ({ + t: (key: string) => key, + }), + } +}) + +const nodes: Node[] = [{ + id: 'node-1', + type: 'custom', + position: { x: 0, y: 0 }, + data: { + type: BlockEnum.Start, + title: 'Start', + desc: '', + }, +}] + +const edges: Edge[] = [{ + id: 'edge-1', + source: 'node-1', + target: 'node-2', + type: 'custom', + data: { + sourceType: BlockEnum.Start, + targetType: BlockEnum.End, + }, +}] + +describe('useWorkflowHistory', () => { + beforeEach(() => { + reactFlowState.nodes = nodes + reactFlowState.edges = edges + }) + + it('stores the latest workflow graph snapshot for supported events', () => { + const { result } = renderWorkflowHook(() => useWorkflowHistory(), { + historyStore: { + nodes, + edges, + }, + }) + + act(() => { + result.current.saveStateToHistory(WorkflowHistoryEvent.NodeAdd, { nodeId: 'node-1' }) + }) + + expect(result.current.store.getState().workflowHistoryEvent).toBe(WorkflowHistoryEvent.NodeAdd) + expect(result.current.store.getState().workflowHistoryEventMeta).toEqual({ nodeId: 'node-1' }) + expect(result.current.store.getState().nodes).toEqual([ + expect.objectContaining({ + id: 'node-1', + data: expect.objectContaining({ + selected: false, + title: 'Start', + }), + }), + ]) + expect(result.current.store.getState().edges).toEqual([ + expect.objectContaining({ + id: 'edge-1', + selected: false, + source: 'node-1', + target: 'node-2', + }), + ]) + }) + + it('returns translated labels and falls back for unsupported events', () => { + const { result } = renderWorkflowHook(() => useWorkflowHistory(), { + historyStore: { + nodes, + edges, + }, + }) + + expect(result.current.getHistoryLabel(WorkflowHistoryEvent.NodeDelete)).toBe('changeHistory.nodeDelete') + expect(result.current.getHistoryLabel('Unknown' as keyof typeof WorkflowHistoryEvent)).toBe('Unknown Event') + }) + + it('runs registered undo and redo callbacks', () => { + const onUndo = vi.fn() + const onRedo = vi.fn() + + const { result } = renderWorkflowHook(() => useWorkflowHistory(), { + historyStore: { + nodes, + edges, + }, + }) + + act(() => { + result.current.onUndo(onUndo) + result.current.onRedo(onRedo) + }) + + const undoSpy = vi.spyOn(result.current.store.temporal.getState(), 'undo') + const redoSpy = vi.spyOn(result.current.store.temporal.getState(), 'redo') + + act(() => { + result.current.undo() + result.current.redo() + }) + + expect(undoSpy).toHaveBeenCalled() + expect(redoSpy).toHaveBeenCalled() + expect(onUndo).toHaveBeenCalled() + expect(onRedo).toHaveBeenCalled() + }) +}) diff --git a/web/app/components/workflow/hooks/__tests__/use-workflow-organize.spec.tsx b/web/app/components/workflow/hooks/__tests__/use-workflow-organize.spec.tsx new file mode 100644 index 0000000000..424ad96630 --- /dev/null +++ b/web/app/components/workflow/hooks/__tests__/use-workflow-organize.spec.tsx @@ -0,0 +1,152 @@ +import { act } from '@testing-library/react' +import { createLoopNode, createNode } from '../../__tests__/fixtures' +import { renderWorkflowHook } from '../../__tests__/workflow-test-env' +import { useWorkflowOrganize } from '../use-workflow-organize' + +const mockSetViewport = vi.hoisted(() => vi.fn()) +const mockSetNodes = vi.hoisted(() => vi.fn()) +const mockHandleSyncWorkflowDraft = vi.hoisted(() => vi.fn()) +const mockSaveStateToHistory = vi.hoisted(() => vi.fn()) +const mockGetLayoutForChildNodes = vi.hoisted(() => vi.fn()) +const mockGetLayoutByELK = vi.hoisted(() => vi.fn()) + +const runtimeState = vi.hoisted(() => ({ + nodes: [] as ReturnType[], + edges: [] as { id: string, source: string, target: string }[], + nodesReadOnly: false, +})) + +vi.mock('reactflow', () => ({ + Position: { + Left: 'left', + Right: 'right', + Top: 'top', + Bottom: 'bottom', + }, + useStoreApi: () => ({ + getState: () => ({ + getNodes: () => runtimeState.nodes, + edges: runtimeState.edges, + setNodes: mockSetNodes, + }), + setState: vi.fn(), + }), + useReactFlow: () => ({ + setViewport: mockSetViewport, + }), +})) + +vi.mock('../use-workflow', () => ({ + useNodesReadOnly: () => ({ + getNodesReadOnly: () => runtimeState.nodesReadOnly, + nodesReadOnly: runtimeState.nodesReadOnly, + }), +})) + +vi.mock('../use-nodes-sync-draft', () => ({ + useNodesSyncDraft: () => ({ + handleSyncWorkflowDraft: (...args: unknown[]) => mockHandleSyncWorkflowDraft(...args), + }), +})) + +vi.mock('../use-workflow-history', () => ({ + useWorkflowHistory: () => ({ + saveStateToHistory: (...args: unknown[]) => mockSaveStateToHistory(...args), + }), + WorkflowHistoryEvent: { + LayoutOrganize: 'LayoutOrganize', + }, +})) + +vi.mock('../../utils/elk-layout', async importOriginal => ({ + ...(await importOriginal()), + getLayoutForChildNodes: (...args: unknown[]) => mockGetLayoutForChildNodes(...args), + getLayoutByELK: (...args: unknown[]) => mockGetLayoutByELK(...args), +})) + +describe('useWorkflowOrganize', () => { + beforeEach(() => { + vi.clearAllMocks() + vi.useFakeTimers() + runtimeState.nodesReadOnly = false + runtimeState.nodes = [] + runtimeState.edges = [] + }) + + afterEach(() => { + vi.useRealTimers() + }) + + it('resizes containers, lays out nodes, and syncs draft when editable', async () => { + runtimeState.nodes = [ + createLoopNode({ + id: 'loop-node', + width: 200, + height: 160, + }), + createNode({ + id: 'loop-child', + parentId: 'loop-node', + position: { x: 20, y: 20 }, + width: 100, + height: 60, + }), + createNode({ + id: 'top-node', + position: { x: 400, y: 0 }, + }), + ] + runtimeState.edges = [] + mockGetLayoutForChildNodes.mockResolvedValue({ + bounds: { minX: 0, minY: 0, maxX: 320, maxY: 220 }, + nodes: new Map([ + ['loop-child', { x: 40, y: 60, width: 100, height: 60 }], + ]), + }) + mockGetLayoutByELK.mockResolvedValue({ + nodes: new Map([ + ['loop-node', { x: 10, y: 20, width: 360, height: 260, layer: 0 }], + ['top-node', { x: 500, y: 30, width: 240, height: 100, layer: 0 }], + ]), + }) + + const { result } = renderWorkflowHook(() => useWorkflowOrganize()) + + await act(async () => { + await result.current.handleLayout() + }) + act(() => { + vi.runAllTimers() + }) + + expect(mockSetNodes).toHaveBeenCalledTimes(1) + const nextNodes = mockSetNodes.mock.calls[0][0] + expect(nextNodes.find((node: { id: string }) => node.id === 'loop-node')).toEqual(expect.objectContaining({ + width: expect.any(Number), + height: expect.any(Number), + position: { x: 10, y: 20 }, + })) + expect(nextNodes.find((node: { id: string }) => node.id === 'loop-child')).toEqual(expect.objectContaining({ + position: { x: 100, y: 120 }, + })) + expect(mockSetViewport).toHaveBeenCalledWith({ x: 0, y: 0, zoom: 0.7 }) + expect(mockSaveStateToHistory).toHaveBeenCalledWith('LayoutOrganize') + expect(mockHandleSyncWorkflowDraft).toHaveBeenCalled() + }) + + it('skips layout when nodes are read-only', async () => { + runtimeState.nodesReadOnly = true + runtimeState.nodes = [createNode({ id: 'n1' })] + + const { result } = renderWorkflowHook(() => useWorkflowOrganize()) + + await act(async () => { + await result.current.handleLayout() + }) + + expect(mockGetLayoutForChildNodes).not.toHaveBeenCalled() + expect(mockGetLayoutByELK).not.toHaveBeenCalled() + expect(mockSetNodes).not.toHaveBeenCalled() + expect(mockSetViewport).not.toHaveBeenCalled() + }) +}) diff --git a/web/app/components/workflow/hooks/__tests__/use-workflow-panel-interactions.spec.tsx b/web/app/components/workflow/hooks/__tests__/use-workflow-panel-interactions.spec.tsx new file mode 100644 index 0000000000..9ff61f70f9 --- /dev/null +++ b/web/app/components/workflow/hooks/__tests__/use-workflow-panel-interactions.spec.tsx @@ -0,0 +1,110 @@ +import { act } from '@testing-library/react' +import { renderWorkflowHook } from '../../__tests__/workflow-test-env' +import { ControlMode } from '../../types' +import { + useWorkflowInteractions, + useWorkflowMoveMode, +} from '../use-workflow-panel-interactions' + +const mockHandleSelectionCancel = vi.hoisted(() => vi.fn()) +const mockHandleNodeCancelRunningStatus = vi.hoisted(() => vi.fn()) +const mockHandleEdgeCancelRunningStatus = vi.hoisted(() => vi.fn()) + +const runtimeState = vi.hoisted(() => ({ + nodesReadOnly: false, +})) + +vi.mock('../use-workflow', () => ({ + useNodesReadOnly: () => ({ + getNodesReadOnly: () => runtimeState.nodesReadOnly, + nodesReadOnly: runtimeState.nodesReadOnly, + }), +})) + +vi.mock('../use-selection-interactions', () => ({ + useSelectionInteractions: () => ({ + handleSelectionCancel: (...args: unknown[]) => mockHandleSelectionCancel(...args), + }), +})) + +vi.mock('../use-nodes-interactions-without-sync', () => ({ + useNodesInteractionsWithoutSync: () => ({ + handleNodeCancelRunningStatus: (...args: unknown[]) => mockHandleNodeCancelRunningStatus(...args), + }), +})) + +vi.mock('../use-edges-interactions-without-sync', () => ({ + useEdgesInteractionsWithoutSync: () => ({ + handleEdgeCancelRunningStatus: (...args: unknown[]) => mockHandleEdgeCancelRunningStatus(...args), + }), +})) + +describe('useWorkflowInteractions', () => { + beforeEach(() => { + vi.clearAllMocks() + runtimeState.nodesReadOnly = false + }) + + it('closes the debug panel and clears running state', () => { + const { result, store } = renderWorkflowHook(() => useWorkflowInteractions(), { + initialStoreState: { + showDebugAndPreviewPanel: true, + workflowRunningData: { task_id: 'task-1' } as never, + }, + }) + + act(() => { + result.current.handleCancelDebugAndPreviewPanel() + }) + + expect(store.getState().showDebugAndPreviewPanel).toBe(false) + expect(store.getState().workflowRunningData).toBeUndefined() + expect(mockHandleNodeCancelRunningStatus).toHaveBeenCalledTimes(1) + expect(mockHandleEdgeCancelRunningStatus).toHaveBeenCalledTimes(1) + }) +}) + +describe('useWorkflowMoveMode', () => { + beforeEach(() => { + vi.clearAllMocks() + runtimeState.nodesReadOnly = false + }) + + it('switches between hand and pointer modes when editable', () => { + const { result, store } = renderWorkflowHook(() => useWorkflowMoveMode(), { + initialStoreState: { + controlMode: ControlMode.Pointer, + }, + }) + + act(() => { + result.current.handleModeHand() + }) + + expect(store.getState().controlMode).toBe(ControlMode.Hand) + expect(mockHandleSelectionCancel).toHaveBeenCalledTimes(1) + + act(() => { + result.current.handleModePointer() + }) + + expect(store.getState().controlMode).toBe(ControlMode.Pointer) + }) + + it('does not switch modes when nodes are read-only', () => { + runtimeState.nodesReadOnly = true + const { result, store } = renderWorkflowHook(() => useWorkflowMoveMode(), { + initialStoreState: { + controlMode: ControlMode.Pointer, + }, + }) + + act(() => { + result.current.handleModeHand() + result.current.handleModePointer() + }) + + expect(store.getState().controlMode).toBe(ControlMode.Pointer) + expect(mockHandleSelectionCancel).not.toHaveBeenCalled() + }) +}) diff --git a/web/app/components/workflow/hooks/__tests__/use-workflow-refresh-draft.spec.ts b/web/app/components/workflow/hooks/__tests__/use-workflow-refresh-draft.spec.ts new file mode 100644 index 0000000000..83c8a4199b --- /dev/null +++ b/web/app/components/workflow/hooks/__tests__/use-workflow-refresh-draft.spec.ts @@ -0,0 +1,14 @@ +import { renderWorkflowHook } from '../../__tests__/workflow-test-env' +import { useWorkflowRefreshDraft } from '../use-workflow-refresh-draft' + +describe('useWorkflowRefreshDraft', () => { + it('returns handleRefreshWorkflowDraft from hooks store', () => { + const handleRefreshWorkflowDraft = vi.fn() + + const { result } = renderWorkflowHook(() => useWorkflowRefreshDraft(), { + hooksStoreProps: { handleRefreshWorkflowDraft }, + }) + + expect(result.current.handleRefreshWorkflowDraft).toBe(handleRefreshWorkflowDraft) + }) +}) diff --git a/web/app/components/workflow/hooks/__tests__/use-workflow-run-event-store-only.spec.ts b/web/app/components/workflow/hooks/__tests__/use-workflow-run-event-store-only.spec.ts deleted file mode 100644 index 2085e5ab47..0000000000 --- a/web/app/components/workflow/hooks/__tests__/use-workflow-run-event-store-only.spec.ts +++ /dev/null @@ -1,242 +0,0 @@ -import type { - AgentLogResponse, - HumanInputFormFilledResponse, - HumanInputFormTimeoutResponse, - TextChunkResponse, - TextReplaceResponse, - WorkflowFinishedResponse, -} from '@/types/workflow' -import { baseRunningData, renderWorkflowHook } from '../../__tests__/workflow-test-env' -import { WorkflowRunningStatus } from '../../types' -import { useWorkflowAgentLog } from '../use-workflow-run-event/use-workflow-agent-log' -import { useWorkflowFailed } from '../use-workflow-run-event/use-workflow-failed' -import { useWorkflowFinished } from '../use-workflow-run-event/use-workflow-finished' -import { useWorkflowNodeHumanInputFormFilled } from '../use-workflow-run-event/use-workflow-node-human-input-form-filled' -import { useWorkflowNodeHumanInputFormTimeout } from '../use-workflow-run-event/use-workflow-node-human-input-form-timeout' -import { useWorkflowPaused } from '../use-workflow-run-event/use-workflow-paused' -import { useWorkflowTextChunk } from '../use-workflow-run-event/use-workflow-text-chunk' -import { useWorkflowTextReplace } from '../use-workflow-run-event/use-workflow-text-replace' - -vi.mock('@/app/components/base/file-uploader/utils', () => ({ - getFilesInLogs: vi.fn(() => []), -})) - -describe('useWorkflowFailed', () => { - it('should set status to Failed', () => { - const { result, store } = renderWorkflowHook(() => useWorkflowFailed(), { - initialStoreState: { workflowRunningData: baseRunningData() }, - }) - - result.current.handleWorkflowFailed() - - expect(store.getState().workflowRunningData!.result.status).toBe(WorkflowRunningStatus.Failed) - }) -}) - -describe('useWorkflowPaused', () => { - it('should set status to Paused', () => { - const { result, store } = renderWorkflowHook(() => useWorkflowPaused(), { - initialStoreState: { workflowRunningData: baseRunningData() }, - }) - - result.current.handleWorkflowPaused() - - expect(store.getState().workflowRunningData!.result.status).toBe(WorkflowRunningStatus.Paused) - }) -}) - -describe('useWorkflowTextChunk', () => { - it('should append text and activate result tab', () => { - const { result, store } = renderWorkflowHook(() => useWorkflowTextChunk(), { - initialStoreState: { - workflowRunningData: baseRunningData({ resultText: 'Hello' }), - }, - }) - - result.current.handleWorkflowTextChunk({ data: { text: ' World' } } as TextChunkResponse) - - const state = store.getState().workflowRunningData! - expect(state.resultText).toBe('Hello World') - expect(state.resultTabActive).toBe(true) - }) -}) - -describe('useWorkflowTextReplace', () => { - it('should replace resultText', () => { - const { result, store } = renderWorkflowHook(() => useWorkflowTextReplace(), { - initialStoreState: { - workflowRunningData: baseRunningData({ resultText: 'old text' }), - }, - }) - - result.current.handleWorkflowTextReplace({ data: { text: 'new text' } } as TextReplaceResponse) - - expect(store.getState().workflowRunningData!.resultText).toBe('new text') - }) -}) - -describe('useWorkflowFinished', () => { - it('should merge data into result and activate result tab for single string output', () => { - const { result, store } = renderWorkflowHook(() => useWorkflowFinished(), { - initialStoreState: { workflowRunningData: baseRunningData() }, - }) - - result.current.handleWorkflowFinished({ - data: { status: 'succeeded', outputs: { answer: 'hello' } }, - } as WorkflowFinishedResponse) - - const state = store.getState().workflowRunningData! - expect(state.result.status).toBe('succeeded') - expect(state.resultTabActive).toBe(true) - expect(state.resultText).toBe('hello') - }) - - it('should not activate result tab for multi-key outputs', () => { - const { result, store } = renderWorkflowHook(() => useWorkflowFinished(), { - initialStoreState: { workflowRunningData: baseRunningData() }, - }) - - result.current.handleWorkflowFinished({ - data: { status: 'succeeded', outputs: { a: 'hello', b: 'world' } }, - } as WorkflowFinishedResponse) - - expect(store.getState().workflowRunningData!.resultTabActive).toBeFalsy() - }) -}) - -describe('useWorkflowAgentLog', () => { - it('should create agent_log array when execution_metadata has no agent_log', () => { - const { result, store } = renderWorkflowHook(() => useWorkflowAgentLog(), { - initialStoreState: { - workflowRunningData: baseRunningData({ - tracing: [{ node_id: 'n1', execution_metadata: {} }], - }), - }, - }) - - result.current.handleWorkflowAgentLog({ - data: { node_id: 'n1', message_id: 'm1' }, - } as AgentLogResponse) - - const trace = store.getState().workflowRunningData!.tracing![0] - expect(trace.execution_metadata!.agent_log).toHaveLength(1) - expect(trace.execution_metadata!.agent_log![0].message_id).toBe('m1') - }) - - it('should append to existing agent_log', () => { - const { result, store } = renderWorkflowHook(() => useWorkflowAgentLog(), { - initialStoreState: { - workflowRunningData: baseRunningData({ - tracing: [{ - node_id: 'n1', - execution_metadata: { agent_log: [{ message_id: 'm1', text: 'log1' }] }, - }], - }), - }, - }) - - result.current.handleWorkflowAgentLog({ - data: { node_id: 'n1', message_id: 'm2' }, - } as AgentLogResponse) - - expect(store.getState().workflowRunningData!.tracing![0].execution_metadata!.agent_log).toHaveLength(2) - }) - - it('should update existing log entry by message_id', () => { - const { result, store } = renderWorkflowHook(() => useWorkflowAgentLog(), { - initialStoreState: { - workflowRunningData: baseRunningData({ - tracing: [{ - node_id: 'n1', - execution_metadata: { agent_log: [{ message_id: 'm1', text: 'old' }] }, - }], - }), - }, - }) - - result.current.handleWorkflowAgentLog({ - data: { node_id: 'n1', message_id: 'm1', text: 'new' }, - } as unknown as AgentLogResponse) - - const log = store.getState().workflowRunningData!.tracing![0].execution_metadata!.agent_log! - expect(log).toHaveLength(1) - expect((log[0] as unknown as { text: string }).text).toBe('new') - }) - - it('should create execution_metadata when it does not exist', () => { - const { result, store } = renderWorkflowHook(() => useWorkflowAgentLog(), { - initialStoreState: { - workflowRunningData: baseRunningData({ - tracing: [{ node_id: 'n1' }], - }), - }, - }) - - result.current.handleWorkflowAgentLog({ - data: { node_id: 'n1', message_id: 'm1' }, - } as AgentLogResponse) - - expect(store.getState().workflowRunningData!.tracing![0].execution_metadata!.agent_log).toHaveLength(1) - }) -}) - -describe('useWorkflowNodeHumanInputFormFilled', () => { - it('should remove form from humanInputFormDataList and add to humanInputFilledFormDataList', () => { - const { result, store } = renderWorkflowHook(() => useWorkflowNodeHumanInputFormFilled(), { - initialStoreState: { - workflowRunningData: baseRunningData({ - humanInputFormDataList: [ - { node_id: 'n1', form_id: 'f1', node_title: 'Node 1', form_content: '' }, - ], - }), - }, - }) - - result.current.handleWorkflowNodeHumanInputFormFilled({ - data: { node_id: 'n1', node_title: 'Node 1', rendered_content: 'done' }, - } as HumanInputFormFilledResponse) - - const state = store.getState().workflowRunningData! - expect(state.humanInputFormDataList).toHaveLength(0) - expect(state.humanInputFilledFormDataList).toHaveLength(1) - expect(state.humanInputFilledFormDataList![0].node_id).toBe('n1') - }) - - it('should create humanInputFilledFormDataList when it does not exist', () => { - const { result, store } = renderWorkflowHook(() => useWorkflowNodeHumanInputFormFilled(), { - initialStoreState: { - workflowRunningData: baseRunningData({ - humanInputFormDataList: [ - { node_id: 'n1', form_id: 'f1', node_title: 'Node 1', form_content: '' }, - ], - }), - }, - }) - - result.current.handleWorkflowNodeHumanInputFormFilled({ - data: { node_id: 'n1', node_title: 'Node 1', rendered_content: 'done' }, - } as HumanInputFormFilledResponse) - - expect(store.getState().workflowRunningData!.humanInputFilledFormDataList).toBeDefined() - }) -}) - -describe('useWorkflowNodeHumanInputFormTimeout', () => { - it('should set expiration_time on the matching form', () => { - const { result, store } = renderWorkflowHook(() => useWorkflowNodeHumanInputFormTimeout(), { - initialStoreState: { - workflowRunningData: baseRunningData({ - humanInputFormDataList: [ - { node_id: 'n1', form_id: 'f1', node_title: 'Node 1', form_content: '', expiration_time: 0 }, - ], - }), - }, - }) - - result.current.handleWorkflowNodeHumanInputFormTimeout({ - data: { node_id: 'n1', node_title: 'Node 1', expiration_time: 1000 }, - } as HumanInputFormTimeoutResponse) - - expect(store.getState().workflowRunningData!.humanInputFormDataList![0].expiration_time).toBe(1000) - }) -}) diff --git a/web/app/components/workflow/hooks/__tests__/use-workflow-run-event-with-store.spec.ts b/web/app/components/workflow/hooks/__tests__/use-workflow-run-event-with-store.spec.ts deleted file mode 100644 index 1c8a0764d1..0000000000 --- a/web/app/components/workflow/hooks/__tests__/use-workflow-run-event-with-store.spec.ts +++ /dev/null @@ -1,336 +0,0 @@ -import type { WorkflowRunningData } from '../../types' -import type { - IterationFinishedResponse, - IterationNextResponse, - LoopFinishedResponse, - LoopNextResponse, - NodeFinishedResponse, - WorkflowStartedResponse, -} from '@/types/workflow' -import { act, waitFor } from '@testing-library/react' -import { useEdges, useNodes } from 'reactflow' -import { createEdge, createNode } from '../../__tests__/fixtures' -import { baseRunningData, renderWorkflowFlowHook } from '../../__tests__/workflow-test-env' -import { DEFAULT_ITER_TIMES } from '../../constants' -import { NodeRunningStatus, WorkflowRunningStatus } from '../../types' -import { useWorkflowNodeFinished } from '../use-workflow-run-event/use-workflow-node-finished' -import { useWorkflowNodeIterationFinished } from '../use-workflow-run-event/use-workflow-node-iteration-finished' -import { useWorkflowNodeIterationNext } from '../use-workflow-run-event/use-workflow-node-iteration-next' -import { useWorkflowNodeLoopFinished } from '../use-workflow-run-event/use-workflow-node-loop-finished' -import { useWorkflowNodeLoopNext } from '../use-workflow-run-event/use-workflow-node-loop-next' -import { useWorkflowNodeRetry } from '../use-workflow-run-event/use-workflow-node-retry' -import { useWorkflowStarted } from '../use-workflow-run-event/use-workflow-started' - -type NodeRuntimeState = { - _waitingRun?: boolean - _runningStatus?: NodeRunningStatus - _retryIndex?: number - _iterationIndex?: number - _loopIndex?: number - _runningBranchId?: string -} - -type EdgeRuntimeState = { - _sourceRunningStatus?: NodeRunningStatus - _targetRunningStatus?: NodeRunningStatus - _waitingRun?: boolean -} - -const getNodeRuntimeState = (node?: { data?: unknown }): NodeRuntimeState => - (node?.data ?? {}) as NodeRuntimeState - -const getEdgeRuntimeState = (edge?: { data?: unknown }): EdgeRuntimeState => - (edge?.data ?? {}) as EdgeRuntimeState - -function createRunNodes() { - return [ - createNode({ - id: 'n1', - width: 200, - height: 80, - data: { _waitingRun: false }, - }), - ] -} - -function createRunEdges() { - return [ - createEdge({ - id: 'e1', - source: 'n0', - target: 'n1', - data: {}, - }), - ] -} - -function renderRunEventHook>( - useHook: () => T, - options?: { - nodes?: ReturnType - edges?: ReturnType - initialStoreState?: Record - }, -) { - const { nodes = createRunNodes(), edges = createRunEdges(), initialStoreState } = options ?? {} - - return renderWorkflowFlowHook(() => ({ - ...useHook(), - nodes: useNodes(), - edges: useEdges(), - }), { - nodes, - edges, - reactFlowProps: { fitView: false }, - initialStoreState, - }) -} - -describe('useWorkflowStarted', () => { - it('should initialize workflow running data and reset nodes/edges', async () => { - const { result, store } = renderRunEventHook(() => useWorkflowStarted(), { - initialStoreState: { workflowRunningData: baseRunningData() }, - }) - - act(() => { - result.current.handleWorkflowStarted({ - task_id: 'task-2', - data: { id: 'run-1', workflow_id: 'wf-1', created_at: 1000 }, - } as WorkflowStartedResponse) - }) - - const state = store.getState().workflowRunningData! - expect(state.task_id).toBe('task-2') - expect(state.result.status).toBe(WorkflowRunningStatus.Running) - expect(state.resultText).toBe('') - - await waitFor(() => { - expect(getNodeRuntimeState(result.current.nodes[0])._waitingRun).toBe(true) - expect(getNodeRuntimeState(result.current.nodes[0])._runningBranchId).toBeUndefined() - expect(getEdgeRuntimeState(result.current.edges[0])._sourceRunningStatus).toBeUndefined() - expect(getEdgeRuntimeState(result.current.edges[0])._targetRunningStatus).toBeUndefined() - expect(getEdgeRuntimeState(result.current.edges[0])._waitingRun).toBe(true) - }) - }) - - it('should resume from Paused without resetting nodes/edges', () => { - const { result, store } = renderRunEventHook(() => useWorkflowStarted(), { - initialStoreState: { - workflowRunningData: baseRunningData({ - result: { status: WorkflowRunningStatus.Paused } as WorkflowRunningData['result'], - }), - }, - }) - - act(() => { - result.current.handleWorkflowStarted({ - task_id: 'task-2', - data: { id: 'run-2', workflow_id: 'wf-1', created_at: 2000 }, - } as WorkflowStartedResponse) - }) - - expect(store.getState().workflowRunningData!.result.status).toBe(WorkflowRunningStatus.Running) - expect(getNodeRuntimeState(result.current.nodes[0])._waitingRun).toBe(false) - expect(getEdgeRuntimeState(result.current.edges[0])._waitingRun).toBeUndefined() - }) -}) - -describe('useWorkflowNodeFinished', () => { - it('should update tracing and node running status', async () => { - const { result, store } = renderRunEventHook(() => useWorkflowNodeFinished(), { - nodes: [ - createNode({ - id: 'n1', - data: { _runningStatus: NodeRunningStatus.Running }, - }), - ], - initialStoreState: { - workflowRunningData: baseRunningData({ - tracing: [{ id: 'trace-1', node_id: 'n1', status: NodeRunningStatus.Running }], - }), - }, - }) - - act(() => { - result.current.handleWorkflowNodeFinished({ - data: { id: 'trace-1', node_id: 'n1', status: NodeRunningStatus.Succeeded }, - } as NodeFinishedResponse) - }) - - const trace = store.getState().workflowRunningData!.tracing![0] - expect(trace.status).toBe(NodeRunningStatus.Succeeded) - - await waitFor(() => { - expect(getNodeRuntimeState(result.current.nodes[0])._runningStatus).toBe(NodeRunningStatus.Succeeded) - expect(getEdgeRuntimeState(result.current.edges[0])._targetRunningStatus).toBe(NodeRunningStatus.Succeeded) - }) - }) - - it('should set _runningBranchId for IfElse node', async () => { - const { result } = renderRunEventHook(() => useWorkflowNodeFinished(), { - nodes: [ - createNode({ - id: 'n1', - data: { _runningStatus: NodeRunningStatus.Running }, - }), - ], - initialStoreState: { - workflowRunningData: baseRunningData({ - tracing: [{ id: 'trace-1', node_id: 'n1', status: NodeRunningStatus.Running }], - }), - }, - }) - - act(() => { - result.current.handleWorkflowNodeFinished({ - data: { - id: 'trace-1', - node_id: 'n1', - node_type: 'if-else', - status: NodeRunningStatus.Succeeded, - outputs: { selected_case_id: 'branch-a' }, - }, - } as unknown as NodeFinishedResponse) - }) - - await waitFor(() => { - expect(getNodeRuntimeState(result.current.nodes[0])._runningBranchId).toBe('branch-a') - }) - }) -}) - -describe('useWorkflowNodeRetry', () => { - it('should push retry data to tracing and update _retryIndex', async () => { - const { result, store } = renderRunEventHook(() => useWorkflowNodeRetry(), { - initialStoreState: { workflowRunningData: baseRunningData() }, - }) - - act(() => { - result.current.handleWorkflowNodeRetry({ - data: { node_id: 'n1', retry_index: 2 }, - } as NodeFinishedResponse) - }) - - expect(store.getState().workflowRunningData!.tracing).toHaveLength(1) - - await waitFor(() => { - expect(getNodeRuntimeState(result.current.nodes[0])._retryIndex).toBe(2) - }) - }) -}) - -describe('useWorkflowNodeIterationNext', () => { - it('should set _iterationIndex and increment iterTimes', async () => { - const { result, store } = renderRunEventHook(() => useWorkflowNodeIterationNext(), { - initialStoreState: { - workflowRunningData: baseRunningData(), - iterTimes: 3, - }, - }) - - act(() => { - result.current.handleWorkflowNodeIterationNext({ - data: { node_id: 'n1' }, - } as IterationNextResponse) - }) - - await waitFor(() => { - expect(getNodeRuntimeState(result.current.nodes[0])._iterationIndex).toBe(3) - }) - expect(store.getState().iterTimes).toBe(4) - }) -}) - -describe('useWorkflowNodeIterationFinished', () => { - it('should update tracing, reset iterTimes, update node status and edges', async () => { - const { result, store } = renderRunEventHook(() => useWorkflowNodeIterationFinished(), { - nodes: [ - createNode({ - id: 'n1', - data: { _runningStatus: NodeRunningStatus.Running }, - }), - ], - initialStoreState: { - workflowRunningData: baseRunningData({ - tracing: [{ id: 'iter-1', node_id: 'n1', status: NodeRunningStatus.Running }], - }), - iterTimes: 10, - }, - }) - - act(() => { - result.current.handleWorkflowNodeIterationFinished({ - data: { id: 'iter-1', node_id: 'n1', status: NodeRunningStatus.Succeeded }, - } as IterationFinishedResponse) - }) - - expect(store.getState().iterTimes).toBe(DEFAULT_ITER_TIMES) - - await waitFor(() => { - expect(getNodeRuntimeState(result.current.nodes[0])._runningStatus).toBe(NodeRunningStatus.Succeeded) - expect(getEdgeRuntimeState(result.current.edges[0])._targetRunningStatus).toBe(NodeRunningStatus.Succeeded) - }) - }) -}) - -describe('useWorkflowNodeLoopNext', () => { - it('should set _loopIndex and reset child nodes to waiting', async () => { - const { result } = renderRunEventHook(() => useWorkflowNodeLoopNext(), { - nodes: [ - createNode({ id: 'n1', data: {} }), - createNode({ - id: 'n2', - position: { x: 300, y: 0 }, - parentId: 'n1', - data: { _waitingRun: false }, - }), - ], - edges: [], - initialStoreState: { workflowRunningData: baseRunningData() }, - }) - - act(() => { - result.current.handleWorkflowNodeLoopNext({ - data: { node_id: 'n1', index: 5 }, - } as LoopNextResponse) - }) - - await waitFor(() => { - expect(getNodeRuntimeState(result.current.nodes.find(node => node.id === 'n1'))._loopIndex).toBe(5) - expect(getNodeRuntimeState(result.current.nodes.find(node => node.id === 'n2'))._waitingRun).toBe(true) - expect(getNodeRuntimeState(result.current.nodes.find(node => node.id === 'n2'))._runningStatus).toBe(NodeRunningStatus.Waiting) - }) - }) -}) - -describe('useWorkflowNodeLoopFinished', () => { - it('should update tracing, node status and edges', async () => { - const { result, store } = renderRunEventHook(() => useWorkflowNodeLoopFinished(), { - nodes: [ - createNode({ - id: 'n1', - data: { _runningStatus: NodeRunningStatus.Running }, - }), - ], - initialStoreState: { - workflowRunningData: baseRunningData({ - tracing: [{ id: 'loop-1', node_id: 'n1', status: NodeRunningStatus.Running }], - }), - }, - }) - - act(() => { - result.current.handleWorkflowNodeLoopFinished({ - data: { id: 'loop-1', node_id: 'n1', status: NodeRunningStatus.Succeeded }, - } as LoopFinishedResponse) - }) - - const trace = store.getState().workflowRunningData!.tracing![0] - expect(trace.status).toBe(NodeRunningStatus.Succeeded) - - await waitFor(() => { - expect(getNodeRuntimeState(result.current.nodes[0])._runningStatus).toBe(NodeRunningStatus.Succeeded) - expect(getEdgeRuntimeState(result.current.edges[0])._targetRunningStatus).toBe(NodeRunningStatus.Succeeded) - }) - }) -}) diff --git a/web/app/components/workflow/hooks/__tests__/use-workflow-run-event-with-viewport.spec.ts b/web/app/components/workflow/hooks/__tests__/use-workflow-run-event-with-viewport.spec.ts deleted file mode 100644 index 73b16acf2e..0000000000 --- a/web/app/components/workflow/hooks/__tests__/use-workflow-run-event-with-viewport.spec.ts +++ /dev/null @@ -1,331 +0,0 @@ -import type { - HumanInputRequiredResponse, - IterationStartedResponse, - LoopStartedResponse, - NodeStartedResponse, -} from '@/types/workflow' -import { act, waitFor } from '@testing-library/react' -import { useEdges, useNodes, useStoreApi } from 'reactflow' -import { createEdge, createNode } from '../../__tests__/fixtures' -import { baseRunningData, renderWorkflowFlowHook } from '../../__tests__/workflow-test-env' -import { DEFAULT_ITER_TIMES } from '../../constants' -import { NodeRunningStatus } from '../../types' -import { useWorkflowNodeHumanInputRequired } from '../use-workflow-run-event/use-workflow-node-human-input-required' -import { useWorkflowNodeIterationStarted } from '../use-workflow-run-event/use-workflow-node-iteration-started' -import { useWorkflowNodeLoopStarted } from '../use-workflow-run-event/use-workflow-node-loop-started' -import { useWorkflowNodeStarted } from '../use-workflow-run-event/use-workflow-node-started' - -type NodeRuntimeState = { - _waitingRun?: boolean - _runningStatus?: NodeRunningStatus - _iterationLength?: number - _loopLength?: number -} - -type EdgeRuntimeState = { - _sourceRunningStatus?: NodeRunningStatus - _targetRunningStatus?: NodeRunningStatus - _waitingRun?: boolean -} - -const getNodeRuntimeState = (node?: { data?: unknown }): NodeRuntimeState => - (node?.data ?? {}) as NodeRuntimeState - -const getEdgeRuntimeState = (edge?: { data?: unknown }): EdgeRuntimeState => - (edge?.data ?? {}) as EdgeRuntimeState - -const containerParams = { clientWidth: 1200, clientHeight: 800 } - -function createViewportNodes() { - return [ - createNode({ - id: 'n0', - width: 200, - height: 80, - data: { _runningStatus: NodeRunningStatus.Succeeded }, - }), - createNode({ - id: 'n1', - position: { x: 100, y: 50 }, - width: 200, - height: 80, - data: { _waitingRun: true }, - }), - createNode({ - id: 'n2', - position: { x: 400, y: 50 }, - width: 200, - height: 80, - parentId: 'n1', - data: { _waitingRun: true }, - }), - ] -} - -function createViewportEdges() { - return [ - createEdge({ - id: 'e1', - source: 'n0', - target: 'n1', - sourceHandle: 'source', - data: {}, - }), - ] -} - -function renderViewportHook>( - useHook: () => T, - options?: { - nodes?: ReturnType - edges?: ReturnType - initialStoreState?: Record - }, -) { - const { - nodes = createViewportNodes(), - edges = createViewportEdges(), - initialStoreState, - } = options ?? {} - - return renderWorkflowFlowHook(() => ({ - ...useHook(), - nodes: useNodes(), - edges: useEdges(), - reactFlowStore: useStoreApi(), - }), { - nodes, - edges, - reactFlowProps: { fitView: false }, - initialStoreState, - }) -} - -describe('useWorkflowNodeStarted', () => { - it('should push to tracing, set node running, and adjust viewport for root node', async () => { - const { result, store } = renderViewportHook(() => useWorkflowNodeStarted(), { - initialStoreState: { workflowRunningData: baseRunningData() }, - }) - - act(() => { - result.current.handleWorkflowNodeStarted( - { data: { node_id: 'n1' } } as NodeStartedResponse, - containerParams, - ) - }) - - const tracing = store.getState().workflowRunningData!.tracing! - expect(tracing).toHaveLength(1) - expect(tracing[0].status).toBe(NodeRunningStatus.Running) - - await waitFor(() => { - const transform = result.current.reactFlowStore.getState().transform - expect(transform[0]).toBe(200) - expect(transform[1]).toBe(310) - expect(transform[2]).toBe(1) - - const node = result.current.nodes.find(item => item.id === 'n1') - expect(getNodeRuntimeState(node)._runningStatus).toBe(NodeRunningStatus.Running) - expect(getNodeRuntimeState(node)._waitingRun).toBe(false) - expect(getEdgeRuntimeState(result.current.edges[0])._targetRunningStatus).toBe(NodeRunningStatus.Running) - }) - }) - - it('should not adjust viewport for child node (has parentId)', async () => { - const { result } = renderViewportHook(() => useWorkflowNodeStarted(), { - initialStoreState: { workflowRunningData: baseRunningData() }, - }) - - act(() => { - result.current.handleWorkflowNodeStarted( - { data: { node_id: 'n2' } } as NodeStartedResponse, - containerParams, - ) - }) - - await waitFor(() => { - const transform = result.current.reactFlowStore.getState().transform - expect(transform[0]).toBe(0) - expect(transform[1]).toBe(0) - expect(transform[2]).toBe(1) - expect(getNodeRuntimeState(result.current.nodes.find(item => item.id === 'n2'))._runningStatus).toBe(NodeRunningStatus.Running) - }) - }) - - it('should update existing tracing entry if node_id exists at non-zero index', () => { - const { result, store } = renderViewportHook(() => useWorkflowNodeStarted(), { - initialStoreState: { - workflowRunningData: baseRunningData({ - tracing: [ - { node_id: 'n0', status: NodeRunningStatus.Succeeded }, - { node_id: 'n1', status: NodeRunningStatus.Succeeded }, - ], - }), - }, - }) - - act(() => { - result.current.handleWorkflowNodeStarted( - { data: { node_id: 'n1' } } as NodeStartedResponse, - containerParams, - ) - }) - - const tracing = store.getState().workflowRunningData!.tracing! - expect(tracing).toHaveLength(2) - expect(tracing[1].status).toBe(NodeRunningStatus.Running) - }) -}) - -describe('useWorkflowNodeIterationStarted', () => { - it('should push to tracing, reset iterTimes, set viewport, and update node with _iterationLength', async () => { - const { result, store } = renderViewportHook(() => useWorkflowNodeIterationStarted(), { - nodes: createViewportNodes().slice(0, 2), - initialStoreState: { - workflowRunningData: baseRunningData(), - iterTimes: 99, - }, - }) - - act(() => { - result.current.handleWorkflowNodeIterationStarted( - { data: { node_id: 'n1', metadata: { iterator_length: 10 } } } as IterationStartedResponse, - containerParams, - ) - }) - - const tracing = store.getState().workflowRunningData!.tracing! - expect(tracing[0].status).toBe(NodeRunningStatus.Running) - expect(store.getState().iterTimes).toBe(DEFAULT_ITER_TIMES) - - await waitFor(() => { - const transform = result.current.reactFlowStore.getState().transform - expect(transform[0]).toBe(200) - expect(transform[1]).toBe(310) - expect(transform[2]).toBe(1) - - const node = result.current.nodes.find(item => item.id === 'n1') - expect(getNodeRuntimeState(node)._runningStatus).toBe(NodeRunningStatus.Running) - expect(getNodeRuntimeState(node)._iterationLength).toBe(10) - expect(getNodeRuntimeState(node)._waitingRun).toBe(false) - expect(getEdgeRuntimeState(result.current.edges[0])._targetRunningStatus).toBe(NodeRunningStatus.Running) - }) - }) -}) - -describe('useWorkflowNodeLoopStarted', () => { - it('should push to tracing, set viewport, and update node with _loopLength', async () => { - const { result, store } = renderViewportHook(() => useWorkflowNodeLoopStarted(), { - nodes: createViewportNodes().slice(0, 2), - initialStoreState: { workflowRunningData: baseRunningData() }, - }) - - act(() => { - result.current.handleWorkflowNodeLoopStarted( - { data: { node_id: 'n1', metadata: { loop_length: 5 } } } as LoopStartedResponse, - containerParams, - ) - }) - - expect(store.getState().workflowRunningData!.tracing![0].status).toBe(NodeRunningStatus.Running) - - await waitFor(() => { - const transform = result.current.reactFlowStore.getState().transform - expect(transform[0]).toBe(200) - expect(transform[1]).toBe(310) - expect(transform[2]).toBe(1) - - const node = result.current.nodes.find(item => item.id === 'n1') - expect(getNodeRuntimeState(node)._runningStatus).toBe(NodeRunningStatus.Running) - expect(getNodeRuntimeState(node)._loopLength).toBe(5) - expect(getNodeRuntimeState(node)._waitingRun).toBe(false) - expect(getEdgeRuntimeState(result.current.edges[0])._targetRunningStatus).toBe(NodeRunningStatus.Running) - }) - }) -}) - -describe('useWorkflowNodeHumanInputRequired', () => { - it('should create humanInputFormDataList and set tracing/node to Paused', async () => { - const { result, store } = renderViewportHook(() => useWorkflowNodeHumanInputRequired(), { - nodes: [ - createNode({ id: 'n1', data: { _runningStatus: NodeRunningStatus.Running } }), - createNode({ id: 'n2', position: { x: 300, y: 0 }, data: { _runningStatus: NodeRunningStatus.Running } }), - ], - edges: [], - initialStoreState: { - workflowRunningData: baseRunningData({ - tracing: [{ node_id: 'n1', status: NodeRunningStatus.Running }], - }), - }, - }) - - act(() => { - result.current.handleWorkflowNodeHumanInputRequired({ - data: { node_id: 'n1', form_id: 'f1', node_title: 'Node 1', form_content: 'content' }, - } as HumanInputRequiredResponse) - }) - - const state = store.getState().workflowRunningData! - expect(state.humanInputFormDataList).toHaveLength(1) - expect(state.humanInputFormDataList![0].form_id).toBe('f1') - expect(state.tracing![0].status).toBe(NodeRunningStatus.Paused) - - await waitFor(() => { - expect(getNodeRuntimeState(result.current.nodes.find(item => item.id === 'n1'))._runningStatus).toBe(NodeRunningStatus.Paused) - }) - }) - - it('should update existing form entry for same node_id', () => { - const { result, store } = renderViewportHook(() => useWorkflowNodeHumanInputRequired(), { - nodes: [ - createNode({ id: 'n1', data: { _runningStatus: NodeRunningStatus.Running } }), - createNode({ id: 'n2', position: { x: 300, y: 0 }, data: { _runningStatus: NodeRunningStatus.Running } }), - ], - edges: [], - initialStoreState: { - workflowRunningData: baseRunningData({ - tracing: [{ node_id: 'n1', status: NodeRunningStatus.Running }], - humanInputFormDataList: [ - { node_id: 'n1', form_id: 'old', node_title: 'Node 1', form_content: 'old' }, - ], - }), - }, - }) - - act(() => { - result.current.handleWorkflowNodeHumanInputRequired({ - data: { node_id: 'n1', form_id: 'new', node_title: 'Node 1', form_content: 'new' }, - } as HumanInputRequiredResponse) - }) - - const formList = store.getState().workflowRunningData!.humanInputFormDataList! - expect(formList).toHaveLength(1) - expect(formList[0].form_id).toBe('new') - }) - - it('should append new form entry for different node_id', () => { - const { result, store } = renderViewportHook(() => useWorkflowNodeHumanInputRequired(), { - nodes: [ - createNode({ id: 'n1', data: { _runningStatus: NodeRunningStatus.Running } }), - createNode({ id: 'n2', position: { x: 300, y: 0 }, data: { _runningStatus: NodeRunningStatus.Running } }), - ], - edges: [], - initialStoreState: { - workflowRunningData: baseRunningData({ - tracing: [{ node_id: 'n2', status: NodeRunningStatus.Running }], - humanInputFormDataList: [ - { node_id: 'n1', form_id: 'f1', node_title: 'Node 1', form_content: '' }, - ], - }), - }, - }) - - act(() => { - result.current.handleWorkflowNodeHumanInputRequired({ - data: { node_id: 'n2', form_id: 'f2', node_title: 'Node 2', form_content: 'content2' }, - } as HumanInputRequiredResponse) - }) - - expect(store.getState().workflowRunningData!.humanInputFormDataList).toHaveLength(2) - }) -}) diff --git a/web/app/components/workflow/hooks/__tests__/use-workflow-run.spec.ts b/web/app/components/workflow/hooks/__tests__/use-workflow-run.spec.ts new file mode 100644 index 0000000000..ff8c64656e --- /dev/null +++ b/web/app/components/workflow/hooks/__tests__/use-workflow-run.spec.ts @@ -0,0 +1,24 @@ +import { renderWorkflowHook } from '../../__tests__/workflow-test-env' +import { useWorkflowRun } from '../use-workflow-run' + +describe('useWorkflowRun', () => { + it('returns workflow run handlers from hooks store', () => { + const handlers = { + handleBackupDraft: vi.fn(), + handleLoadBackupDraft: vi.fn(), + handleRestoreFromPublishedWorkflow: vi.fn(), + handleRun: vi.fn(), + handleStopRun: vi.fn(), + } + + const { result } = renderWorkflowHook(() => useWorkflowRun(), { + hooksStoreProps: handlers, + }) + + expect(result.current.handleBackupDraft).toBe(handlers.handleBackupDraft) + expect(result.current.handleLoadBackupDraft).toBe(handlers.handleLoadBackupDraft) + expect(result.current.handleRestoreFromPublishedWorkflow).toBe(handlers.handleRestoreFromPublishedWorkflow) + expect(result.current.handleRun).toBe(handlers.handleRun) + expect(result.current.handleStopRun).toBe(handlers.handleStopRun) + }) +}) diff --git a/web/app/components/workflow/hooks/__tests__/use-workflow-search.spec.tsx b/web/app/components/workflow/hooks/__tests__/use-workflow-search.spec.tsx new file mode 100644 index 0000000000..4e9f4c9b45 --- /dev/null +++ b/web/app/components/workflow/hooks/__tests__/use-workflow-search.spec.tsx @@ -0,0 +1,119 @@ +import type { CommonNodeType, Node, ToolWithProvider } from '../../types' +import { act, renderHook } from '@testing-library/react' +import { workflowNodesAction } from '@/app/components/goto-anything/actions/workflow-nodes' +import { CollectionType } from '@/app/components/tools/types' +import { BlockEnum } from '../../types' +import { useWorkflowSearch } from '../use-workflow-search' + +const mockHandleNodeSelect = vi.hoisted(() => vi.fn()) +const runtimeNodes = vi.hoisted(() => [] as Node[]) + +vi.mock('reactflow', () => ({ + useNodes: () => runtimeNodes, +})) + +vi.mock('../use-nodes-interactions', () => ({ + useNodesInteractions: () => ({ + handleNodeSelect: mockHandleNodeSelect, + }), +})) + +vi.mock('@/service/use-tools', () => ({ + useAllBuiltInTools: () => ({ + data: [{ + id: 'provider-1', + icon: 'tool-icon', + tools: [], + }] satisfies Partial[], + }), + useAllCustomTools: () => ({ data: [] }), + useAllWorkflowTools: () => ({ data: [] }), + useAllMCPTools: () => ({ data: [] }), +})) + +const createNode = (overrides: Partial = {}): Node => ({ + id: 'node-1', + type: 'custom', + position: { x: 0, y: 0 }, + data: { + type: BlockEnum.LLM, + title: 'Writer', + desc: 'Draft content', + } as CommonNodeType, + ...overrides, +}) + +describe('useWorkflowSearch', () => { + beforeEach(() => { + vi.clearAllMocks() + runtimeNodes.length = 0 + workflowNodesAction.searchFn = undefined + }) + + it('registers workflow node search results with tool icons and llm metadata scoring', async () => { + runtimeNodes.push( + createNode({ + id: 'llm-1', + data: { + type: BlockEnum.LLM, + title: 'Writer', + desc: 'Draft content', + model: { + provider: 'openai', + name: 'gpt-4o', + mode: 'chat', + }, + } as CommonNodeType, + }), + createNode({ + id: 'tool-1', + data: { + type: BlockEnum.Tool, + title: 'Google Search', + desc: 'Search the web', + provider_type: CollectionType.builtIn, + provider_id: 'provider-1', + } as CommonNodeType, + }), + createNode({ + id: 'internal-start', + data: { + type: BlockEnum.IterationStart, + title: 'Internal Start', + desc: '', + } as CommonNodeType, + }), + ) + + const { unmount } = renderHook(() => useWorkflowSearch()) + + const llmResults = await workflowNodesAction.search('', 'gpt') + expect(llmResults.map(item => item.id)).toEqual(['llm-1']) + expect(llmResults[0]?.title).toBe('Writer') + + const toolResults = await workflowNodesAction.search('', 'search') + expect(toolResults.map(item => item.id)).toEqual(['tool-1']) + expect(toolResults[0]?.description).toBe('Search the web') + + unmount() + + expect(workflowNodesAction.searchFn).toBeUndefined() + }) + + it('binds the node selection listener to handleNodeSelect', () => { + const { unmount } = renderHook(() => useWorkflowSearch()) + + act(() => { + document.dispatchEvent(new CustomEvent('workflow:select-node', { + detail: { + nodeId: 'node-42', + focus: false, + }, + })) + }) + + expect(mockHandleNodeSelect).toHaveBeenCalledWith('node-42') + + unmount() + }) +}) diff --git a/web/app/components/workflow/hooks/__tests__/use-workflow-start-run.spec.tsx b/web/app/components/workflow/hooks/__tests__/use-workflow-start-run.spec.tsx new file mode 100644 index 0000000000..fdde912285 --- /dev/null +++ b/web/app/components/workflow/hooks/__tests__/use-workflow-start-run.spec.tsx @@ -0,0 +1,28 @@ +import { renderWorkflowHook } from '../../__tests__/workflow-test-env' +import { useWorkflowStartRun } from '../use-workflow-start-run' + +describe('useWorkflowStartRun', () => { + it('returns start-run handlers from hooks store', () => { + const handlers = { + handleStartWorkflowRun: vi.fn(), + handleWorkflowStartRunInWorkflow: vi.fn(), + handleWorkflowStartRunInChatflow: vi.fn(), + handleWorkflowTriggerScheduleRunInWorkflow: vi.fn(), + handleWorkflowTriggerWebhookRunInWorkflow: vi.fn(), + handleWorkflowTriggerPluginRunInWorkflow: vi.fn(), + handleWorkflowRunAllTriggersInWorkflow: vi.fn(), + } + + const { result } = renderWorkflowHook(() => useWorkflowStartRun(), { + hooksStoreProps: handlers, + }) + + expect(result.current.handleStartWorkflowRun).toBe(handlers.handleStartWorkflowRun) + expect(result.current.handleWorkflowStartRunInWorkflow).toBe(handlers.handleWorkflowStartRunInWorkflow) + expect(result.current.handleWorkflowStartRunInChatflow).toBe(handlers.handleWorkflowStartRunInChatflow) + expect(result.current.handleWorkflowTriggerScheduleRunInWorkflow).toBe(handlers.handleWorkflowTriggerScheduleRunInWorkflow) + expect(result.current.handleWorkflowTriggerWebhookRunInWorkflow).toBe(handlers.handleWorkflowTriggerWebhookRunInWorkflow) + expect(result.current.handleWorkflowTriggerPluginRunInWorkflow).toBe(handlers.handleWorkflowTriggerPluginRunInWorkflow) + expect(result.current.handleWorkflowRunAllTriggersInWorkflow).toBe(handlers.handleWorkflowRunAllTriggersInWorkflow) + }) +}) diff --git a/web/app/components/workflow/hooks/__tests__/use-workflow-update.spec.tsx b/web/app/components/workflow/hooks/__tests__/use-workflow-update.spec.tsx new file mode 100644 index 0000000000..8bd2a1c4f3 --- /dev/null +++ b/web/app/components/workflow/hooks/__tests__/use-workflow-update.spec.tsx @@ -0,0 +1,66 @@ +import { act } from '@testing-library/react' +import { createNode } from '../../__tests__/fixtures' +import { renderWorkflowHook } from '../../__tests__/workflow-test-env' +import { useWorkflowUpdate } from '../use-workflow-update' + +const mockSetViewport = vi.hoisted(() => vi.fn()) +const mockEventEmit = vi.hoisted(() => vi.fn()) +const mockInitialNodes = vi.hoisted(() => vi.fn((nodes: unknown[], _edges: unknown[]) => nodes)) +const mockInitialEdges = vi.hoisted(() => vi.fn((edges: unknown[], _nodes: unknown[]) => edges)) + +vi.mock('reactflow', () => ({ + Position: { + Left: 'left', + Right: 'right', + Top: 'top', + Bottom: 'bottom', + }, + useReactFlow: () => ({ + setViewport: mockSetViewport, + }), +})) + +vi.mock('@/context/event-emitter', () => ({ + useEventEmitterContextContext: () => ({ + eventEmitter: { + emit: (...args: unknown[]) => mockEventEmit(...args), + }, + }), +})) + +vi.mock('../../utils', async importOriginal => ({ + ...(await importOriginal()), + initialNodes: (nodes: unknown[], edges: unknown[]) => mockInitialNodes(nodes, edges), + initialEdges: (edges: unknown[], nodes: unknown[]) => mockInitialEdges(edges, nodes), +})) + +describe('useWorkflowUpdate', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('emits initialized data and only sets a valid viewport', () => { + const { result } = renderWorkflowHook(() => useWorkflowUpdate()) + + act(() => { + result.current.handleUpdateWorkflowCanvas({ + nodes: [createNode({ id: 'n1' })], + edges: [], + viewport: { x: 10, y: 20, zoom: 0.5 }, + } as never) + result.current.handleUpdateWorkflowCanvas({ + nodes: [], + edges: [], + viewport: { x: 'bad' } as never, + }) + }) + + expect(mockInitialNodes).toHaveBeenCalled() + expect(mockInitialEdges).toHaveBeenCalled() + expect(mockEventEmit).toHaveBeenCalledWith(expect.objectContaining({ + type: 'WORKFLOW_DATA_UPDATE', + })) + expect(mockSetViewport).toHaveBeenCalledTimes(1) + expect(mockSetViewport).toHaveBeenCalledWith({ x: 10, y: 20, zoom: 0.5 }) + }) +}) diff --git a/web/app/components/workflow/hooks/__tests__/use-workflow-zoom.spec.ts b/web/app/components/workflow/hooks/__tests__/use-workflow-zoom.spec.ts new file mode 100644 index 0000000000..83bc1b27ad --- /dev/null +++ b/web/app/components/workflow/hooks/__tests__/use-workflow-zoom.spec.ts @@ -0,0 +1,86 @@ +import { act, renderHook } from '@testing-library/react' +import { useWorkflowZoom } from '../use-workflow-zoom' + +const { + mockFitView, + mockZoomIn, + mockZoomOut, + mockZoomTo, + mockHandleSyncWorkflowDraft, + runtimeState, +} = vi.hoisted(() => ({ + mockFitView: vi.fn(), + mockZoomIn: vi.fn(), + mockZoomOut: vi.fn(), + mockZoomTo: vi.fn(), + mockHandleSyncWorkflowDraft: vi.fn(), + runtimeState: { + workflowReadOnly: false, + }, +})) + +vi.mock('reactflow', () => ({ + useReactFlow: () => ({ + fitView: mockFitView, + zoomIn: mockZoomIn, + zoomOut: mockZoomOut, + zoomTo: mockZoomTo, + }), +})) + +vi.mock('../use-nodes-sync-draft', () => ({ + useNodesSyncDraft: () => ({ + handleSyncWorkflowDraft: (...args: unknown[]) => mockHandleSyncWorkflowDraft(...args), + }), +})) + +vi.mock('../use-workflow', () => ({ + useWorkflowReadOnly: () => ({ + getWorkflowReadOnly: () => runtimeState.workflowReadOnly, + }), +})) + +describe('useWorkflowZoom', () => { + beforeEach(() => { + vi.clearAllMocks() + runtimeState.workflowReadOnly = false + }) + + it('runs zoom actions and syncs the workflow draft when editable', () => { + const { result } = renderHook(() => useWorkflowZoom()) + + act(() => { + result.current.handleFitView() + result.current.handleBackToOriginalSize() + result.current.handleSizeToHalf() + result.current.handleZoomOut() + result.current.handleZoomIn() + }) + + expect(mockFitView).toHaveBeenCalledTimes(1) + expect(mockZoomTo).toHaveBeenCalledWith(1) + expect(mockZoomTo).toHaveBeenCalledWith(0.5) + expect(mockZoomOut).toHaveBeenCalledTimes(1) + expect(mockZoomIn).toHaveBeenCalledTimes(1) + expect(mockHandleSyncWorkflowDraft).toHaveBeenCalledTimes(5) + }) + + it('blocks zoom actions when the workflow is read-only', () => { + runtimeState.workflowReadOnly = true + const { result } = renderHook(() => useWorkflowZoom()) + + act(() => { + result.current.handleFitView() + result.current.handleBackToOriginalSize() + result.current.handleSizeToHalf() + result.current.handleZoomOut() + result.current.handleZoomIn() + }) + + expect(mockFitView).not.toHaveBeenCalled() + expect(mockZoomTo).not.toHaveBeenCalled() + expect(mockZoomOut).not.toHaveBeenCalled() + expect(mockZoomIn).not.toHaveBeenCalled() + expect(mockHandleSyncWorkflowDraft).not.toHaveBeenCalled() + }) +}) diff --git a/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/test-helpers.ts b/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/test-helpers.ts new file mode 100644 index 0000000000..8c2ed18f19 --- /dev/null +++ b/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/test-helpers.ts @@ -0,0 +1,186 @@ +import type { WorkflowRunningData } from '../../../types' +import type { + IterationFinishedResponse, + IterationNextResponse, + LoopFinishedResponse, + LoopNextResponse, + NodeFinishedResponse, + NodeStartedResponse, + WorkflowStartedResponse, +} from '@/types/workflow' +import { useEdges, useNodes, useStoreApi } from 'reactflow' +import { createEdge, createNode } from '../../../__tests__/fixtures' +import { renderWorkflowFlowHook } from '../../../__tests__/workflow-test-env' +import { NodeRunningStatus, WorkflowRunningStatus } from '../../../types' + +type NodeRuntimeState = { + _waitingRun?: boolean + _runningStatus?: NodeRunningStatus + _retryIndex?: number + _iterationIndex?: number + _iterationLength?: number + _loopIndex?: number + _loopLength?: number + _runningBranchId?: string +} + +type EdgeRuntimeState = { + _sourceRunningStatus?: NodeRunningStatus + _targetRunningStatus?: NodeRunningStatus + _waitingRun?: boolean +} + +export const getNodeRuntimeState = (node?: { data?: unknown }): NodeRuntimeState => + (node?.data ?? {}) as NodeRuntimeState + +export const getEdgeRuntimeState = (edge?: { data?: unknown }): EdgeRuntimeState => + (edge?.data ?? {}) as EdgeRuntimeState + +function createRunNodes() { + return [ + createNode({ + id: 'n1', + width: 200, + height: 80, + data: { _waitingRun: false }, + }), + ] +} + +function createRunEdges() { + return [ + createEdge({ + id: 'e1', + source: 'n0', + target: 'n1', + data: {}, + }), + ] +} + +export function createViewportNodes() { + return [ + createNode({ + id: 'n0', + width: 200, + height: 80, + data: { _runningStatus: NodeRunningStatus.Succeeded }, + }), + createNode({ + id: 'n1', + position: { x: 100, y: 50 }, + width: 200, + height: 80, + data: { _waitingRun: true }, + }), + createNode({ + id: 'n2', + position: { x: 400, y: 50 }, + width: 200, + height: 80, + parentId: 'n1', + data: { _waitingRun: true }, + }), + ] +} + +function createViewportEdges() { + return [ + createEdge({ + id: 'e1', + source: 'n0', + target: 'n1', + sourceHandle: 'source', + data: {}, + }), + ] +} + +export const containerParams = { clientWidth: 1200, clientHeight: 800 } + +export function renderRunEventHook>( + useHook: () => T, + options?: { + nodes?: ReturnType + edges?: ReturnType + initialStoreState?: Record + }, +) { + const { nodes = createRunNodes(), edges = createRunEdges(), initialStoreState } = options ?? {} + + return renderWorkflowFlowHook(() => ({ + ...useHook(), + nodes: useNodes(), + edges: useEdges(), + }), { + nodes, + edges, + reactFlowProps: { fitView: false }, + initialStoreState, + }) +} + +export function renderViewportHook>( + useHook: () => T, + options?: { + nodes?: ReturnType + edges?: ReturnType + initialStoreState?: Record + }, +) { + const { + nodes = createViewportNodes(), + edges = createViewportEdges(), + initialStoreState, + } = options ?? {} + + return renderWorkflowFlowHook(() => ({ + ...useHook(), + nodes: useNodes(), + edges: useEdges(), + reactFlowStore: useStoreApi(), + }), { + nodes, + edges, + reactFlowProps: { fitView: false }, + initialStoreState, + }) +} + +export const createStartedResponse = (overrides: Partial = {}): WorkflowStartedResponse => ({ + task_id: 'task-2', + data: { id: 'run-1', workflow_id: 'wf-1', created_at: 1000 }, + ...overrides, +} as WorkflowStartedResponse) + +export const createNodeFinishedResponse = (overrides: Partial = {}): NodeFinishedResponse => ({ + data: { id: 'trace-1', node_id: 'n1', status: NodeRunningStatus.Succeeded }, + ...overrides, +} as NodeFinishedResponse) + +export const createIterationNextResponse = (overrides: Partial = {}): IterationNextResponse => ({ + data: { node_id: 'n1' }, + ...overrides, +} as IterationNextResponse) + +export const createIterationFinishedResponse = (overrides: Partial = {}): IterationFinishedResponse => ({ + data: { id: 'iter-1', node_id: 'n1', status: NodeRunningStatus.Succeeded }, + ...overrides, +} as IterationFinishedResponse) + +export const createLoopNextResponse = (overrides: Partial = {}): LoopNextResponse => ({ + data: { node_id: 'n1', index: 5 }, + ...overrides, +} as LoopNextResponse) + +export const createLoopFinishedResponse = (overrides: Partial = {}): LoopFinishedResponse => ({ + data: { id: 'loop-1', node_id: 'n1', status: NodeRunningStatus.Succeeded }, + ...overrides, +} as LoopFinishedResponse) + +export const createNodeStartedResponse = (overrides: Partial = {}): NodeStartedResponse => ({ + data: { node_id: 'n1' }, + ...overrides, +} as NodeStartedResponse) + +export const pausedRunningData = (): WorkflowRunningData['result'] => ({ status: WorkflowRunningStatus.Paused } as WorkflowRunningData['result']) diff --git a/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-agent-log.spec.ts b/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-agent-log.spec.ts new file mode 100644 index 0000000000..cabfc0f6d1 --- /dev/null +++ b/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-agent-log.spec.ts @@ -0,0 +1,83 @@ +import type { AgentLogResponse } from '@/types/workflow' +import { baseRunningData, renderWorkflowHook } from '../../../__tests__/workflow-test-env' +import { useWorkflowAgentLog } from '../use-workflow-agent-log' + +vi.mock('@/app/components/base/file-uploader/utils', () => ({ + getFilesInLogs: vi.fn(() => []), +})) + +describe('useWorkflowAgentLog', () => { + it('creates agent_log when execution_metadata has none', () => { + const { result, store } = renderWorkflowHook(() => useWorkflowAgentLog(), { + initialStoreState: { + workflowRunningData: baseRunningData({ + tracing: [{ node_id: 'n1', execution_metadata: {} }], + }), + }, + }) + + result.current.handleWorkflowAgentLog({ + data: { node_id: 'n1', message_id: 'm1' }, + } as AgentLogResponse) + + const trace = store.getState().workflowRunningData!.tracing![0] + expect(trace.execution_metadata!.agent_log).toHaveLength(1) + expect(trace.execution_metadata!.agent_log![0].message_id).toBe('m1') + }) + + it('appends to existing agent_log', () => { + const { result, store } = renderWorkflowHook(() => useWorkflowAgentLog(), { + initialStoreState: { + workflowRunningData: baseRunningData({ + tracing: [{ + node_id: 'n1', + execution_metadata: { agent_log: [{ message_id: 'm1', text: 'log1' }] }, + }], + }), + }, + }) + + result.current.handleWorkflowAgentLog({ + data: { node_id: 'n1', message_id: 'm2' }, + } as AgentLogResponse) + + expect(store.getState().workflowRunningData!.tracing![0].execution_metadata!.agent_log).toHaveLength(2) + }) + + it('updates an existing log entry by message_id', () => { + const { result, store } = renderWorkflowHook(() => useWorkflowAgentLog(), { + initialStoreState: { + workflowRunningData: baseRunningData({ + tracing: [{ + node_id: 'n1', + execution_metadata: { agent_log: [{ message_id: 'm1', text: 'old' }] }, + }], + }), + }, + }) + + result.current.handleWorkflowAgentLog({ + data: { node_id: 'n1', message_id: 'm1', text: 'new' }, + } as unknown as AgentLogResponse) + + const log = store.getState().workflowRunningData!.tracing![0].execution_metadata!.agent_log! + expect(log).toHaveLength(1) + expect((log[0] as unknown as { text: string }).text).toBe('new') + }) + + it('creates execution_metadata when it does not exist', () => { + const { result, store } = renderWorkflowHook(() => useWorkflowAgentLog(), { + initialStoreState: { + workflowRunningData: baseRunningData({ + tracing: [{ node_id: 'n1' }], + }), + }, + }) + + result.current.handleWorkflowAgentLog({ + data: { node_id: 'n1', message_id: 'm1' }, + } as AgentLogResponse) + + expect(store.getState().workflowRunningData!.tracing![0].execution_metadata!.agent_log).toHaveLength(1) + }) +}) diff --git a/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-failed.spec.ts b/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-failed.spec.ts new file mode 100644 index 0000000000..53ee281f7e --- /dev/null +++ b/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-failed.spec.ts @@ -0,0 +1,15 @@ +import { baseRunningData, renderWorkflowHook } from '../../../__tests__/workflow-test-env' +import { WorkflowRunningStatus } from '../../../types' +import { useWorkflowFailed } from '../use-workflow-failed' + +describe('useWorkflowFailed', () => { + it('sets status to Failed', () => { + const { result, store } = renderWorkflowHook(() => useWorkflowFailed(), { + initialStoreState: { workflowRunningData: baseRunningData() }, + }) + + result.current.handleWorkflowFailed() + + expect(store.getState().workflowRunningData!.result.status).toBe(WorkflowRunningStatus.Failed) + }) +}) diff --git a/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-finished.spec.ts b/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-finished.spec.ts new file mode 100644 index 0000000000..910b64ed18 --- /dev/null +++ b/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-finished.spec.ts @@ -0,0 +1,32 @@ +import type { WorkflowFinishedResponse } from '@/types/workflow' +import { baseRunningData, renderWorkflowHook } from '../../../__tests__/workflow-test-env' +import { useWorkflowFinished } from '../use-workflow-finished' + +describe('useWorkflowFinished', () => { + it('merges data into result and activates result tab for single string output', () => { + const { result, store } = renderWorkflowHook(() => useWorkflowFinished(), { + initialStoreState: { workflowRunningData: baseRunningData() }, + }) + + result.current.handleWorkflowFinished({ + data: { status: 'succeeded', outputs: { answer: 'hello' } }, + } as WorkflowFinishedResponse) + + const state = store.getState().workflowRunningData! + expect(state.result.status).toBe('succeeded') + expect(state.resultTabActive).toBe(true) + expect(state.resultText).toBe('hello') + }) + + it('does not activate the result tab for multi-key outputs', () => { + const { result, store } = renderWorkflowHook(() => useWorkflowFinished(), { + initialStoreState: { workflowRunningData: baseRunningData() }, + }) + + result.current.handleWorkflowFinished({ + data: { status: 'succeeded', outputs: { a: 'hello', b: 'world' } }, + } as WorkflowFinishedResponse) + + expect(store.getState().workflowRunningData!.resultTabActive).toBeFalsy() + }) +}) diff --git a/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-node-finished.spec.ts b/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-node-finished.spec.ts new file mode 100644 index 0000000000..efcdc15d88 --- /dev/null +++ b/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-node-finished.spec.ts @@ -0,0 +1,73 @@ +import { act, waitFor } from '@testing-library/react' +import { createNode } from '../../../__tests__/fixtures' +import { baseRunningData } from '../../../__tests__/workflow-test-env' +import { BlockEnum, NodeRunningStatus } from '../../../types' +import { useWorkflowNodeFinished } from '../use-workflow-node-finished' +import { + createNodeFinishedResponse, + getEdgeRuntimeState, + getNodeRuntimeState, + renderRunEventHook, +} from './test-helpers' + +describe('useWorkflowNodeFinished', () => { + it('updates tracing and node running status', async () => { + const { result, store } = renderRunEventHook(() => useWorkflowNodeFinished(), { + nodes: [ + createNode({ + id: 'n1', + data: { _runningStatus: NodeRunningStatus.Running }, + }), + ], + initialStoreState: { + workflowRunningData: baseRunningData({ + tracing: [{ id: 'trace-1', node_id: 'n1', status: NodeRunningStatus.Running }], + }), + }, + }) + + act(() => { + result.current.handleWorkflowNodeFinished(createNodeFinishedResponse()) + }) + + const trace = store.getState().workflowRunningData!.tracing![0] + expect(trace.status).toBe(NodeRunningStatus.Succeeded) + + await waitFor(() => { + expect(getNodeRuntimeState(result.current.nodes[0])._runningStatus).toBe(NodeRunningStatus.Succeeded) + expect(getEdgeRuntimeState(result.current.edges[0])._targetRunningStatus).toBe(NodeRunningStatus.Succeeded) + }) + }) + + it('sets _runningBranchId for IfElse nodes', async () => { + const { result } = renderRunEventHook(() => useWorkflowNodeFinished(), { + nodes: [ + createNode({ + id: 'n1', + data: { _runningStatus: NodeRunningStatus.Running }, + }), + ], + initialStoreState: { + workflowRunningData: baseRunningData({ + tracing: [{ id: 'trace-1', node_id: 'n1', status: NodeRunningStatus.Running }], + }), + }, + }) + + act(() => { + result.current.handleWorkflowNodeFinished(createNodeFinishedResponse({ + data: { + id: 'trace-1', + node_id: 'n1', + node_type: BlockEnum.IfElse, + status: NodeRunningStatus.Succeeded, + outputs: { selected_case_id: 'branch-a' }, + } as never, + })) + }) + + await waitFor(() => { + expect(getNodeRuntimeState(result.current.nodes[0])._runningBranchId).toBe('branch-a') + }) + }) +}) diff --git a/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-node-human-input-form-filled.spec.ts b/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-node-human-input-form-filled.spec.ts new file mode 100644 index 0000000000..aa8e89327b --- /dev/null +++ b/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-node-human-input-form-filled.spec.ts @@ -0,0 +1,44 @@ +import type { HumanInputFormFilledResponse } from '@/types/workflow' +import { baseRunningData, renderWorkflowHook } from '../../../__tests__/workflow-test-env' +import { useWorkflowNodeHumanInputFormFilled } from '../use-workflow-node-human-input-form-filled' + +describe('useWorkflowNodeHumanInputFormFilled', () => { + it('removes the form from pending and adds it to filled', () => { + const { result, store } = renderWorkflowHook(() => useWorkflowNodeHumanInputFormFilled(), { + initialStoreState: { + workflowRunningData: baseRunningData({ + humanInputFormDataList: [ + { node_id: 'n1', form_id: 'f1', node_title: 'Node 1', form_content: '' }, + ], + }), + }, + }) + + result.current.handleWorkflowNodeHumanInputFormFilled({ + data: { node_id: 'n1', node_title: 'Node 1', rendered_content: 'done' }, + } as HumanInputFormFilledResponse) + + const state = store.getState().workflowRunningData! + expect(state.humanInputFormDataList).toHaveLength(0) + expect(state.humanInputFilledFormDataList).toHaveLength(1) + expect(state.humanInputFilledFormDataList![0].node_id).toBe('n1') + }) + + it('creates humanInputFilledFormDataList when it does not exist', () => { + const { result, store } = renderWorkflowHook(() => useWorkflowNodeHumanInputFormFilled(), { + initialStoreState: { + workflowRunningData: baseRunningData({ + humanInputFormDataList: [ + { node_id: 'n1', form_id: 'f1', node_title: 'Node 1', form_content: '' }, + ], + }), + }, + }) + + result.current.handleWorkflowNodeHumanInputFormFilled({ + data: { node_id: 'n1', node_title: 'Node 1', rendered_content: 'done' }, + } as HumanInputFormFilledResponse) + + expect(store.getState().workflowRunningData!.humanInputFilledFormDataList).toBeDefined() + }) +}) diff --git a/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-node-human-input-form-timeout.spec.ts b/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-node-human-input-form-timeout.spec.ts new file mode 100644 index 0000000000..e528b49846 --- /dev/null +++ b/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-node-human-input-form-timeout.spec.ts @@ -0,0 +1,23 @@ +import type { HumanInputFormTimeoutResponse } from '@/types/workflow' +import { baseRunningData, renderWorkflowHook } from '../../../__tests__/workflow-test-env' +import { useWorkflowNodeHumanInputFormTimeout } from '../use-workflow-node-human-input-form-timeout' + +describe('useWorkflowNodeHumanInputFormTimeout', () => { + it('sets expiration_time on the matching form', () => { + const { result, store } = renderWorkflowHook(() => useWorkflowNodeHumanInputFormTimeout(), { + initialStoreState: { + workflowRunningData: baseRunningData({ + humanInputFormDataList: [ + { node_id: 'n1', form_id: 'f1', node_title: 'Node 1', form_content: '', expiration_time: 0 }, + ], + }), + }, + }) + + result.current.handleWorkflowNodeHumanInputFormTimeout({ + data: { node_id: 'n1', node_title: 'Node 1', expiration_time: 1000 }, + } as HumanInputFormTimeoutResponse) + + expect(store.getState().workflowRunningData!.humanInputFormDataList![0].expiration_time).toBe(1000) + }) +}) diff --git a/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-node-human-input-required.spec.ts b/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-node-human-input-required.spec.ts new file mode 100644 index 0000000000..23fdf8a3c3 --- /dev/null +++ b/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-node-human-input-required.spec.ts @@ -0,0 +1,96 @@ +import type { HumanInputRequiredResponse } from '@/types/workflow' +import { act, waitFor } from '@testing-library/react' +import { createNode } from '../../../__tests__/fixtures' +import { baseRunningData } from '../../../__tests__/workflow-test-env' +import { NodeRunningStatus } from '../../../types' +import { useWorkflowNodeHumanInputRequired } from '../use-workflow-node-human-input-required' +import { + getNodeRuntimeState, + renderViewportHook, +} from './test-helpers' + +describe('useWorkflowNodeHumanInputRequired', () => { + it('creates humanInputFormDataList and sets tracing and node to Paused', async () => { + const { result, store } = renderViewportHook(() => useWorkflowNodeHumanInputRequired(), { + nodes: [ + createNode({ id: 'n1', data: { _runningStatus: NodeRunningStatus.Running } }), + createNode({ id: 'n2', position: { x: 300, y: 0 }, data: { _runningStatus: NodeRunningStatus.Running } }), + ], + edges: [], + initialStoreState: { + workflowRunningData: baseRunningData({ + tracing: [{ node_id: 'n1', status: NodeRunningStatus.Running }], + }), + }, + }) + + act(() => { + result.current.handleWorkflowNodeHumanInputRequired({ + data: { node_id: 'n1', form_id: 'f1', node_title: 'Node 1', form_content: 'content' }, + } as HumanInputRequiredResponse) + }) + + const state = store.getState().workflowRunningData! + expect(state.humanInputFormDataList).toHaveLength(1) + expect(state.humanInputFormDataList![0].form_id).toBe('f1') + expect(state.tracing![0].status).toBe(NodeRunningStatus.Paused) + + await waitFor(() => { + expect(getNodeRuntimeState(result.current.nodes.find(item => item.id === 'n1'))._runningStatus).toBe(NodeRunningStatus.Paused) + }) + }) + + it('updates existing form entry for the same node_id', () => { + const { result, store } = renderViewportHook(() => useWorkflowNodeHumanInputRequired(), { + nodes: [ + createNode({ id: 'n1', data: { _runningStatus: NodeRunningStatus.Running } }), + createNode({ id: 'n2', position: { x: 300, y: 0 }, data: { _runningStatus: NodeRunningStatus.Running } }), + ], + edges: [], + initialStoreState: { + workflowRunningData: baseRunningData({ + tracing: [{ node_id: 'n1', status: NodeRunningStatus.Running }], + humanInputFormDataList: [ + { node_id: 'n1', form_id: 'old', node_title: 'Node 1', form_content: 'old' }, + ], + }), + }, + }) + + act(() => { + result.current.handleWorkflowNodeHumanInputRequired({ + data: { node_id: 'n1', form_id: 'new', node_title: 'Node 1', form_content: 'new' }, + } as HumanInputRequiredResponse) + }) + + const formList = store.getState().workflowRunningData!.humanInputFormDataList! + expect(formList).toHaveLength(1) + expect(formList[0].form_id).toBe('new') + }) + + it('appends a new form entry for a different node_id', () => { + const { result, store } = renderViewportHook(() => useWorkflowNodeHumanInputRequired(), { + nodes: [ + createNode({ id: 'n1', data: { _runningStatus: NodeRunningStatus.Running } }), + createNode({ id: 'n2', position: { x: 300, y: 0 }, data: { _runningStatus: NodeRunningStatus.Running } }), + ], + edges: [], + initialStoreState: { + workflowRunningData: baseRunningData({ + tracing: [{ node_id: 'n2', status: NodeRunningStatus.Running }], + humanInputFormDataList: [ + { node_id: 'n1', form_id: 'f1', node_title: 'Node 1', form_content: '' }, + ], + }), + }, + }) + + act(() => { + result.current.handleWorkflowNodeHumanInputRequired({ + data: { node_id: 'n2', form_id: 'f2', node_title: 'Node 2', form_content: 'content2' }, + } as HumanInputRequiredResponse) + }) + + expect(store.getState().workflowRunningData!.humanInputFormDataList).toHaveLength(2) + }) +}) diff --git a/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-node-iteration-finished.spec.ts b/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-node-iteration-finished.spec.ts new file mode 100644 index 0000000000..87617f0835 --- /dev/null +++ b/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-node-iteration-finished.spec.ts @@ -0,0 +1,42 @@ +import { act, waitFor } from '@testing-library/react' +import { createNode } from '../../../__tests__/fixtures' +import { baseRunningData } from '../../../__tests__/workflow-test-env' +import { DEFAULT_ITER_TIMES } from '../../../constants' +import { NodeRunningStatus } from '../../../types' +import { useWorkflowNodeIterationFinished } from '../use-workflow-node-iteration-finished' +import { + createIterationFinishedResponse, + getEdgeRuntimeState, + getNodeRuntimeState, + renderRunEventHook, +} from './test-helpers' + +describe('useWorkflowNodeIterationFinished', () => { + it('updates tracing, resets iterTimes, updates node status and edges', async () => { + const { result, store } = renderRunEventHook(() => useWorkflowNodeIterationFinished(), { + nodes: [ + createNode({ + id: 'n1', + data: { _runningStatus: NodeRunningStatus.Running }, + }), + ], + initialStoreState: { + workflowRunningData: baseRunningData({ + tracing: [{ id: 'iter-1', node_id: 'n1', status: NodeRunningStatus.Running }], + }), + iterTimes: 10, + }, + }) + + act(() => { + result.current.handleWorkflowNodeIterationFinished(createIterationFinishedResponse()) + }) + + expect(store.getState().iterTimes).toBe(DEFAULT_ITER_TIMES) + + await waitFor(() => { + expect(getNodeRuntimeState(result.current.nodes[0])._runningStatus).toBe(NodeRunningStatus.Succeeded) + expect(getEdgeRuntimeState(result.current.edges[0])._targetRunningStatus).toBe(NodeRunningStatus.Succeeded) + }) + }) +}) diff --git a/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-node-iteration-next.spec.ts b/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-node-iteration-next.spec.ts new file mode 100644 index 0000000000..ac5f2f02ea --- /dev/null +++ b/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-node-iteration-next.spec.ts @@ -0,0 +1,28 @@ +import { act, waitFor } from '@testing-library/react' +import { baseRunningData } from '../../../__tests__/workflow-test-env' +import { useWorkflowNodeIterationNext } from '../use-workflow-node-iteration-next' +import { + createIterationNextResponse, + getNodeRuntimeState, + renderRunEventHook, +} from './test-helpers' + +describe('useWorkflowNodeIterationNext', () => { + it('sets _iterationIndex and increments iterTimes', async () => { + const { result, store } = renderRunEventHook(() => useWorkflowNodeIterationNext(), { + initialStoreState: { + workflowRunningData: baseRunningData(), + iterTimes: 3, + }, + }) + + act(() => { + result.current.handleWorkflowNodeIterationNext(createIterationNextResponse()) + }) + + await waitFor(() => { + expect(getNodeRuntimeState(result.current.nodes[0])._iterationIndex).toBe(3) + }) + expect(store.getState().iterTimes).toBe(4) + }) +}) diff --git a/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-node-iteration-started.spec.ts b/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-node-iteration-started.spec.ts new file mode 100644 index 0000000000..ccff1b288b --- /dev/null +++ b/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-node-iteration-started.spec.ts @@ -0,0 +1,49 @@ +import type { IterationStartedResponse } from '@/types/workflow' +import { act, waitFor } from '@testing-library/react' +import { baseRunningData } from '../../../__tests__/workflow-test-env' +import { DEFAULT_ITER_TIMES } from '../../../constants' +import { NodeRunningStatus } from '../../../types' +import { useWorkflowNodeIterationStarted } from '../use-workflow-node-iteration-started' +import { + containerParams, + createViewportNodes, + getEdgeRuntimeState, + getNodeRuntimeState, + renderViewportHook, +} from './test-helpers' + +describe('useWorkflowNodeIterationStarted', () => { + it('pushes to tracing, resets iterTimes, sets viewport, and updates node with _iterationLength', async () => { + const { result, store } = renderViewportHook(() => useWorkflowNodeIterationStarted(), { + nodes: createViewportNodes().slice(0, 2), + initialStoreState: { + workflowRunningData: baseRunningData(), + iterTimes: 99, + }, + }) + + act(() => { + result.current.handleWorkflowNodeIterationStarted( + { data: { node_id: 'n1', metadata: { iterator_length: 10 } } } as IterationStartedResponse, + containerParams, + ) + }) + + const tracing = store.getState().workflowRunningData!.tracing! + expect(tracing[0].status).toBe(NodeRunningStatus.Running) + expect(store.getState().iterTimes).toBe(DEFAULT_ITER_TIMES) + + await waitFor(() => { + const transform = result.current.reactFlowStore.getState().transform + expect(transform[0]).toBe(200) + expect(transform[1]).toBe(310) + expect(transform[2]).toBe(1) + + const node = result.current.nodes.find(item => item.id === 'n1') + expect(getNodeRuntimeState(node)._runningStatus).toBe(NodeRunningStatus.Running) + expect(getNodeRuntimeState(node)._iterationLength).toBe(10) + expect(getNodeRuntimeState(node)._waitingRun).toBe(false) + expect(getEdgeRuntimeState(result.current.edges[0])._targetRunningStatus).toBe(NodeRunningStatus.Running) + }) + }) +}) diff --git a/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-node-loop-finished.spec.ts b/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-node-loop-finished.spec.ts new file mode 100644 index 0000000000..7acd9897ed --- /dev/null +++ b/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-node-loop-finished.spec.ts @@ -0,0 +1,40 @@ +import { act, waitFor } from '@testing-library/react' +import { createNode } from '../../../__tests__/fixtures' +import { baseRunningData } from '../../../__tests__/workflow-test-env' +import { NodeRunningStatus } from '../../../types' +import { useWorkflowNodeLoopFinished } from '../use-workflow-node-loop-finished' +import { + createLoopFinishedResponse, + getEdgeRuntimeState, + getNodeRuntimeState, + renderRunEventHook, +} from './test-helpers' + +describe('useWorkflowNodeLoopFinished', () => { + it('updates tracing, node status and edges', async () => { + const { result, store } = renderRunEventHook(() => useWorkflowNodeLoopFinished(), { + nodes: [ + createNode({ + id: 'n1', + data: { _runningStatus: NodeRunningStatus.Running }, + }), + ], + initialStoreState: { + workflowRunningData: baseRunningData({ + tracing: [{ id: 'loop-1', node_id: 'n1', status: NodeRunningStatus.Running }], + }), + }, + }) + + act(() => { + result.current.handleWorkflowNodeLoopFinished(createLoopFinishedResponse()) + }) + + expect(store.getState().workflowRunningData!.tracing![0].status).toBe(NodeRunningStatus.Succeeded) + + await waitFor(() => { + expect(getNodeRuntimeState(result.current.nodes[0])._runningStatus).toBe(NodeRunningStatus.Succeeded) + expect(getEdgeRuntimeState(result.current.edges[0])._targetRunningStatus).toBe(NodeRunningStatus.Succeeded) + }) + }) +}) diff --git a/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-node-loop-next.spec.ts b/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-node-loop-next.spec.ts new file mode 100644 index 0000000000..5baa44c983 --- /dev/null +++ b/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-node-loop-next.spec.ts @@ -0,0 +1,38 @@ +import { act, waitFor } from '@testing-library/react' +import { createNode } from '../../../__tests__/fixtures' +import { baseRunningData } from '../../../__tests__/workflow-test-env' +import { NodeRunningStatus } from '../../../types' +import { useWorkflowNodeLoopNext } from '../use-workflow-node-loop-next' +import { + createLoopNextResponse, + getNodeRuntimeState, + renderRunEventHook, +} from './test-helpers' + +describe('useWorkflowNodeLoopNext', () => { + it('sets _loopIndex and resets child nodes to waiting', async () => { + const { result } = renderRunEventHook(() => useWorkflowNodeLoopNext(), { + nodes: [ + createNode({ id: 'n1', data: {} }), + createNode({ + id: 'n2', + position: { x: 300, y: 0 }, + parentId: 'n1', + data: { _waitingRun: false }, + }), + ], + edges: [], + initialStoreState: { workflowRunningData: baseRunningData() }, + }) + + act(() => { + result.current.handleWorkflowNodeLoopNext(createLoopNextResponse()) + }) + + await waitFor(() => { + expect(getNodeRuntimeState(result.current.nodes.find(node => node.id === 'n1'))._loopIndex).toBe(5) + expect(getNodeRuntimeState(result.current.nodes.find(node => node.id === 'n2'))._waitingRun).toBe(true) + expect(getNodeRuntimeState(result.current.nodes.find(node => node.id === 'n2'))._runningStatus).toBe(NodeRunningStatus.Waiting) + }) + }) +}) diff --git a/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-node-loop-started.spec.ts b/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-node-loop-started.spec.ts new file mode 100644 index 0000000000..b0e8bf2cc5 --- /dev/null +++ b/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-node-loop-started.spec.ts @@ -0,0 +1,43 @@ +import type { LoopStartedResponse } from '@/types/workflow' +import { act, waitFor } from '@testing-library/react' +import { baseRunningData } from '../../../__tests__/workflow-test-env' +import { NodeRunningStatus } from '../../../types' +import { useWorkflowNodeLoopStarted } from '../use-workflow-node-loop-started' +import { + containerParams, + createViewportNodes, + getEdgeRuntimeState, + getNodeRuntimeState, + renderViewportHook, +} from './test-helpers' + +describe('useWorkflowNodeLoopStarted', () => { + it('pushes to tracing, sets viewport, and updates node with _loopLength', async () => { + const { result, store } = renderViewportHook(() => useWorkflowNodeLoopStarted(), { + nodes: createViewportNodes().slice(0, 2), + initialStoreState: { workflowRunningData: baseRunningData() }, + }) + + act(() => { + result.current.handleWorkflowNodeLoopStarted( + { data: { node_id: 'n1', metadata: { loop_length: 5 } } } as LoopStartedResponse, + containerParams, + ) + }) + + expect(store.getState().workflowRunningData!.tracing![0].status).toBe(NodeRunningStatus.Running) + + await waitFor(() => { + const transform = result.current.reactFlowStore.getState().transform + expect(transform[0]).toBe(200) + expect(transform[1]).toBe(310) + expect(transform[2]).toBe(1) + + const node = result.current.nodes.find(item => item.id === 'n1') + expect(getNodeRuntimeState(node)._runningStatus).toBe(NodeRunningStatus.Running) + expect(getNodeRuntimeState(node)._loopLength).toBe(5) + expect(getNodeRuntimeState(node)._waitingRun).toBe(false) + expect(getEdgeRuntimeState(result.current.edges[0])._targetRunningStatus).toBe(NodeRunningStatus.Running) + }) + }) +}) diff --git a/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-node-retry.spec.ts b/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-node-retry.spec.ts new file mode 100644 index 0000000000..b3c6b814b1 --- /dev/null +++ b/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-node-retry.spec.ts @@ -0,0 +1,27 @@ +import { act, waitFor } from '@testing-library/react' +import { baseRunningData } from '../../../__tests__/workflow-test-env' +import { useWorkflowNodeRetry } from '../use-workflow-node-retry' +import { + getNodeRuntimeState, + renderRunEventHook, +} from './test-helpers' + +describe('useWorkflowNodeRetry', () => { + it('pushes retry data to tracing and updates _retryIndex', async () => { + const { result, store } = renderRunEventHook(() => useWorkflowNodeRetry(), { + initialStoreState: { workflowRunningData: baseRunningData() }, + }) + + act(() => { + result.current.handleWorkflowNodeRetry({ + data: { node_id: 'n1', retry_index: 2 }, + } as never) + }) + + expect(store.getState().workflowRunningData!.tracing).toHaveLength(1) + + await waitFor(() => { + expect(getNodeRuntimeState(result.current.nodes[0])._retryIndex).toBe(2) + }) + }) +}) diff --git a/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-node-started.spec.ts b/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-node-started.spec.ts new file mode 100644 index 0000000000..a8a52e0a84 --- /dev/null +++ b/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-node-started.spec.ts @@ -0,0 +1,80 @@ +import { act, waitFor } from '@testing-library/react' +import { baseRunningData } from '../../../__tests__/workflow-test-env' +import { NodeRunningStatus } from '../../../types' +import { useWorkflowNodeStarted } from '../use-workflow-node-started' +import { + containerParams, + createNodeStartedResponse, + getEdgeRuntimeState, + getNodeRuntimeState, + renderViewportHook, +} from './test-helpers' + +describe('useWorkflowNodeStarted', () => { + it('pushes to tracing, sets node running, and adjusts viewport for root node', async () => { + const { result, store } = renderViewportHook(() => useWorkflowNodeStarted(), { + initialStoreState: { workflowRunningData: baseRunningData() }, + }) + + act(() => { + result.current.handleWorkflowNodeStarted(createNodeStartedResponse(), containerParams) + }) + + const tracing = store.getState().workflowRunningData!.tracing! + expect(tracing).toHaveLength(1) + expect(tracing[0].status).toBe(NodeRunningStatus.Running) + + await waitFor(() => { + const transform = result.current.reactFlowStore.getState().transform + expect(transform[0]).toBe(200) + expect(transform[1]).toBe(310) + expect(transform[2]).toBe(1) + + const node = result.current.nodes.find(item => item.id === 'n1') + expect(getNodeRuntimeState(node)._runningStatus).toBe(NodeRunningStatus.Running) + expect(getNodeRuntimeState(node)._waitingRun).toBe(false) + expect(getEdgeRuntimeState(result.current.edges[0])._targetRunningStatus).toBe(NodeRunningStatus.Running) + }) + }) + + it('does not adjust viewport for child nodes', async () => { + const { result } = renderViewportHook(() => useWorkflowNodeStarted(), { + initialStoreState: { workflowRunningData: baseRunningData() }, + }) + + act(() => { + result.current.handleWorkflowNodeStarted(createNodeStartedResponse({ + data: { node_id: 'n2' } as never, + }), containerParams) + }) + + await waitFor(() => { + const transform = result.current.reactFlowStore.getState().transform + expect(transform[0]).toBe(0) + expect(transform[1]).toBe(0) + expect(transform[2]).toBe(1) + expect(getNodeRuntimeState(result.current.nodes.find(item => item.id === 'n2'))._runningStatus).toBe(NodeRunningStatus.Running) + }) + }) + + it('updates existing tracing entry when node_id already exists', () => { + const { result, store } = renderViewportHook(() => useWorkflowNodeStarted(), { + initialStoreState: { + workflowRunningData: baseRunningData({ + tracing: [ + { node_id: 'n0', status: NodeRunningStatus.Succeeded } as never, + { node_id: 'n1', status: NodeRunningStatus.Succeeded } as never, + ], + }), + }, + }) + + act(() => { + result.current.handleWorkflowNodeStarted(createNodeStartedResponse(), containerParams) + }) + + const tracing = store.getState().workflowRunningData!.tracing! + expect(tracing).toHaveLength(2) + expect(tracing[1].status).toBe(NodeRunningStatus.Running) + }) +}) diff --git a/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-paused.spec.ts b/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-paused.spec.ts new file mode 100644 index 0000000000..9cfb8f62d9 --- /dev/null +++ b/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-paused.spec.ts @@ -0,0 +1,15 @@ +import { baseRunningData, renderWorkflowHook } from '../../../__tests__/workflow-test-env' +import { WorkflowRunningStatus } from '../../../types' +import { useWorkflowPaused } from '../use-workflow-paused' + +describe('useWorkflowPaused', () => { + it('sets status to Paused', () => { + const { result, store } = renderWorkflowHook(() => useWorkflowPaused(), { + initialStoreState: { workflowRunningData: baseRunningData() }, + }) + + result.current.handleWorkflowPaused() + + expect(store.getState().workflowRunningData!.result.status).toBe(WorkflowRunningStatus.Paused) + }) +}) diff --git a/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-run-event.spec.ts b/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-run-event.spec.ts new file mode 100644 index 0000000000..fb8ea51638 --- /dev/null +++ b/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-run-event.spec.ts @@ -0,0 +1,54 @@ +import { renderHook } from '@testing-library/react' +import { useWorkflowRunEvent } from '../use-workflow-run-event' + +const handlers = vi.hoisted(() => ({ + handleWorkflowStarted: vi.fn(), + handleWorkflowFinished: vi.fn(), + handleWorkflowFailed: vi.fn(), + handleWorkflowNodeStarted: vi.fn(), + handleWorkflowNodeFinished: vi.fn(), + handleWorkflowNodeIterationStarted: vi.fn(), + handleWorkflowNodeIterationNext: vi.fn(), + handleWorkflowNodeIterationFinished: vi.fn(), + handleWorkflowNodeLoopStarted: vi.fn(), + handleWorkflowNodeLoopNext: vi.fn(), + handleWorkflowNodeLoopFinished: vi.fn(), + handleWorkflowNodeRetry: vi.fn(), + handleWorkflowTextChunk: vi.fn(), + handleWorkflowTextReplace: vi.fn(), + handleWorkflowAgentLog: vi.fn(), + handleWorkflowPaused: vi.fn(), + handleWorkflowNodeHumanInputRequired: vi.fn(), + handleWorkflowNodeHumanInputFormFilled: vi.fn(), + handleWorkflowNodeHumanInputFormTimeout: vi.fn(), +})) + +vi.mock('..', () => ({ + useWorkflowStarted: () => ({ handleWorkflowStarted: handlers.handleWorkflowStarted }), + useWorkflowFinished: () => ({ handleWorkflowFinished: handlers.handleWorkflowFinished }), + useWorkflowFailed: () => ({ handleWorkflowFailed: handlers.handleWorkflowFailed }), + useWorkflowNodeStarted: () => ({ handleWorkflowNodeStarted: handlers.handleWorkflowNodeStarted }), + useWorkflowNodeFinished: () => ({ handleWorkflowNodeFinished: handlers.handleWorkflowNodeFinished }), + useWorkflowNodeIterationStarted: () => ({ handleWorkflowNodeIterationStarted: handlers.handleWorkflowNodeIterationStarted }), + useWorkflowNodeIterationNext: () => ({ handleWorkflowNodeIterationNext: handlers.handleWorkflowNodeIterationNext }), + useWorkflowNodeIterationFinished: () => ({ handleWorkflowNodeIterationFinished: handlers.handleWorkflowNodeIterationFinished }), + useWorkflowNodeLoopStarted: () => ({ handleWorkflowNodeLoopStarted: handlers.handleWorkflowNodeLoopStarted }), + useWorkflowNodeLoopNext: () => ({ handleWorkflowNodeLoopNext: handlers.handleWorkflowNodeLoopNext }), + useWorkflowNodeLoopFinished: () => ({ handleWorkflowNodeLoopFinished: handlers.handleWorkflowNodeLoopFinished }), + useWorkflowNodeRetry: () => ({ handleWorkflowNodeRetry: handlers.handleWorkflowNodeRetry }), + useWorkflowTextChunk: () => ({ handleWorkflowTextChunk: handlers.handleWorkflowTextChunk }), + useWorkflowTextReplace: () => ({ handleWorkflowTextReplace: handlers.handleWorkflowTextReplace }), + useWorkflowAgentLog: () => ({ handleWorkflowAgentLog: handlers.handleWorkflowAgentLog }), + useWorkflowPaused: () => ({ handleWorkflowPaused: handlers.handleWorkflowPaused }), + useWorkflowNodeHumanInputRequired: () => ({ handleWorkflowNodeHumanInputRequired: handlers.handleWorkflowNodeHumanInputRequired }), + useWorkflowNodeHumanInputFormFilled: () => ({ handleWorkflowNodeHumanInputFormFilled: handlers.handleWorkflowNodeHumanInputFormFilled }), + useWorkflowNodeHumanInputFormTimeout: () => ({ handleWorkflowNodeHumanInputFormTimeout: handlers.handleWorkflowNodeHumanInputFormTimeout }), +})) + +describe('useWorkflowRunEvent', () => { + it('returns the composed handlers from all workflow event hooks', () => { + const { result } = renderHook(() => useWorkflowRunEvent()) + + expect(result.current).toEqual(handlers) + }) +}) diff --git a/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-started.spec.ts b/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-started.spec.ts new file mode 100644 index 0000000000..4fd49c9c6a --- /dev/null +++ b/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-started.spec.ts @@ -0,0 +1,56 @@ +import { act, waitFor } from '@testing-library/react' +import { baseRunningData } from '../../../__tests__/workflow-test-env' +import { WorkflowRunningStatus } from '../../../types' +import { useWorkflowStarted } from '../use-workflow-started' +import { + createStartedResponse, + getEdgeRuntimeState, + getNodeRuntimeState, + pausedRunningData, + renderRunEventHook, +} from './test-helpers' + +describe('useWorkflowStarted', () => { + it('initializes workflow running data and resets nodes and edges', async () => { + const { result, store } = renderRunEventHook(() => useWorkflowStarted(), { + initialStoreState: { workflowRunningData: baseRunningData() }, + }) + + act(() => { + result.current.handleWorkflowStarted(createStartedResponse()) + }) + + const state = store.getState().workflowRunningData! + expect(state.task_id).toBe('task-2') + expect(state.result.status).toBe(WorkflowRunningStatus.Running) + expect(state.resultText).toBe('') + + await waitFor(() => { + expect(getNodeRuntimeState(result.current.nodes[0])._waitingRun).toBe(true) + expect(getNodeRuntimeState(result.current.nodes[0])._runningBranchId).toBeUndefined() + expect(getEdgeRuntimeState(result.current.edges[0])._sourceRunningStatus).toBeUndefined() + expect(getEdgeRuntimeState(result.current.edges[0])._targetRunningStatus).toBeUndefined() + expect(getEdgeRuntimeState(result.current.edges[0])._waitingRun).toBe(true) + }) + }) + + it('resumes from Paused without resetting nodes or edges', () => { + const { result, store } = renderRunEventHook(() => useWorkflowStarted(), { + initialStoreState: { + workflowRunningData: baseRunningData({ + result: pausedRunningData(), + }), + }, + }) + + act(() => { + result.current.handleWorkflowStarted(createStartedResponse({ + data: { id: 'run-2', workflow_id: 'wf-1', created_at: 2000 }, + })) + }) + + expect(store.getState().workflowRunningData!.result.status).toBe(WorkflowRunningStatus.Running) + expect(getNodeRuntimeState(result.current.nodes[0])._waitingRun).toBe(false) + expect(getEdgeRuntimeState(result.current.edges[0])._waitingRun).toBeUndefined() + }) +}) diff --git a/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-text-chunk.spec.ts b/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-text-chunk.spec.ts new file mode 100644 index 0000000000..fcf36fe596 --- /dev/null +++ b/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-text-chunk.spec.ts @@ -0,0 +1,19 @@ +import type { TextChunkResponse } from '@/types/workflow' +import { baseRunningData, renderWorkflowHook } from '../../../__tests__/workflow-test-env' +import { useWorkflowTextChunk } from '../use-workflow-text-chunk' + +describe('useWorkflowTextChunk', () => { + it('appends text and activates the result tab', () => { + const { result, store } = renderWorkflowHook(() => useWorkflowTextChunk(), { + initialStoreState: { + workflowRunningData: baseRunningData({ resultText: 'Hello' }), + }, + }) + + result.current.handleWorkflowTextChunk({ data: { text: ' World' } } as TextChunkResponse) + + const state = store.getState().workflowRunningData! + expect(state.resultText).toBe('Hello World') + expect(state.resultTabActive).toBe(true) + }) +}) diff --git a/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-text-replace.spec.ts b/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-text-replace.spec.ts new file mode 100644 index 0000000000..f9c1dcb256 --- /dev/null +++ b/web/app/components/workflow/hooks/use-workflow-run-event/__tests__/use-workflow-text-replace.spec.ts @@ -0,0 +1,17 @@ +import type { TextReplaceResponse } from '@/types/workflow' +import { baseRunningData, renderWorkflowHook } from '../../../__tests__/workflow-test-env' +import { useWorkflowTextReplace } from '../use-workflow-text-replace' + +describe('useWorkflowTextReplace', () => { + it('replaces resultText', () => { + const { result, store } = renderWorkflowHook(() => useWorkflowTextReplace(), { + initialStoreState: { + workflowRunningData: baseRunningData({ resultText: 'old text' }), + }, + }) + + result.current.handleWorkflowTextReplace({ data: { text: 'new text' } } as TextReplaceResponse) + + expect(store.getState().workflowRunningData!.resultText).toBe('new text') + }) +}) diff --git a/web/app/components/workflow/nodes/_base/components/editor/base.tsx b/web/app/components/workflow/nodes/_base/components/editor/base.tsx index 6ed582369c..c0545ff01c 100644 --- a/web/app/components/workflow/nodes/_base/components/editor/base.tsx +++ b/web/app/components/workflow/nodes/_base/components/editor/base.tsx @@ -84,7 +84,7 @@ const Base: FC = ({ return ( -
+
{title}
{ expect(screen.getByRole('link', { name: 'workflow.panel.helpLink' })).toHaveAttribute('href', 'https://docs.example.com/node') }) + it('should hide change action when node is undeletable', () => { + mockUseNodeMetaData.mockReturnValueOnce({ + isTypeFixed: false, + isSingleton: true, + isUndeletable: true, + description: 'Undeletable node', + author: 'Dify', + } as ReturnType) + + renderWorkflowFlowComponent( + , + { + nodes: [], + edges: [], + }, + ) + + expect(screen.getByText('workflow.panel.runThisStep')).toBeInTheDocument() + expect(screen.queryByText('workflow.panel.change')).not.toBeInTheDocument() + expect(screen.queryByText('common.operation.delete')).not.toBeInTheDocument() + }) + it('should render workflow-tool and readonly popup variants', () => { mockUseAllWorkflowTools.mockReturnValueOnce({ data: [{ id: 'workflow-tool', workflow_app_id: 'app-123' }], diff --git a/web/app/components/workflow/nodes/_base/components/panel-operator/panel-operator-popup.tsx b/web/app/components/workflow/nodes/_base/components/panel-operator/panel-operator-popup.tsx index b460aa651c..a93c3e1d14 100644 --- a/web/app/components/workflow/nodes/_base/components/panel-operator/panel-operator-popup.tsx +++ b/web/app/components/workflow/nodes/_base/components/panel-operator/panel-operator-popup.tsx @@ -47,7 +47,7 @@ const PanelOperatorPopup = ({ const { nodesReadOnly } = useNodesReadOnly() const edge = edges.find(edge => edge.target === id) const nodeMetaData = useNodeMetaData({ id, data } as Node) - const showChangeBlock = !nodeMetaData.isTypeFixed && !nodesReadOnly + const showChangeBlock = !nodeMetaData.isTypeFixed && !nodeMetaData.isUndeletable && !nodesReadOnly const isChildNode = !!(data.isInIteration || data.isInLoop) const { data: workflowTools } = useAllWorkflowTools() diff --git a/web/app/components/workflow/nodes/_base/components/variable/__tests__/output-var-list.spec.tsx b/web/app/components/workflow/nodes/_base/components/variable/__tests__/output-var-list.spec.tsx new file mode 100644 index 0000000000..24464e4f08 --- /dev/null +++ b/web/app/components/workflow/nodes/_base/components/variable/__tests__/output-var-list.spec.tsx @@ -0,0 +1,209 @@ +import type { OutputVar } from '../../../../code/types' +import { cleanup, fireEvent, render, screen } from '@testing-library/react' +import OutputVarList from '../output-var-list' + +vi.mock('../var-type-picker', () => ({ + default: (props: { value: string, onChange: (v: string) => void, readonly: boolean }) => ( + + ), +})) + +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: { error: vi.fn() }, +})) + +describe('OutputVarList', () => { + const createOutputs = (entries: Record = {}): OutputVar => { + const result: OutputVar = {} + for (const [key, type] of Object.entries(entries)) + result[key] = { type: type as OutputVar[string]['type'], children: null } + return result + } + + // Render the component and trigger a rename at the given index. + // Returns the newOutputs passed to onChange. + const collectRenameResult = ( + outputs: OutputVar, + outputKeyOrders: string[], + renameIndex: number, + newName: string, + ): OutputVar => { + let captured: OutputVar | undefined + + render( + { captured = newOutputs }} + onRemove={vi.fn()} + />, + ) + + const inputs = screen.getAllByRole('textbox') + fireEvent.change(inputs[renameIndex], { target: { value: newName } }) + + return captured! + } + + beforeEach(() => { + vi.clearAllMocks() + }) + + describe('duplicate name handling', () => { + it('should preserve outputs entry when renaming one of two duplicate-name variables', () => { + const outputs = createOutputs({ var_1: 'string' }) + const outputKeyOrders = ['var_1', 'var_1'] + + const newOutputs = collectRenameResult(outputs, outputKeyOrders, 1, '') + + // Renamed entry gets a new key '' + expect(newOutputs['']).toEqual({ type: 'string', children: null }) + // Original key 'var_1' must survive because index 0 still uses it + expect(newOutputs.var_1).toEqual({ type: 'string', children: null }) + }) + + it('should delete old key when renamed entry is the only one using it', () => { + const outputs = createOutputs({ var_1: 'string', var_2: 'number' }) + const outputKeyOrders = ['var_1', 'var_2'] + + const newOutputs = collectRenameResult(outputs, outputKeyOrders, 1, 'renamed') + + expect(newOutputs.renamed).toEqual({ type: 'number', children: null }) + expect(newOutputs.var_2).toBeUndefined() + expect(newOutputs.var_1).toEqual({ type: 'string', children: null }) + }) + + it('should keep outputs key alive when duplicate is renamed back to unique name', () => { + // Step 1: rename var_2 -> var_1 (creates duplicate) + const outputs = createOutputs({ var_1: 'string', var_2: 'number' }) + const afterFirst = collectRenameResult(outputs, ['var_1', 'var_2'], 1, 'var_1') + + expect(afterFirst.var_2).toBeUndefined() + expect(afterFirst.var_1).toBeDefined() + + // Clean up first render before the second to avoid DOM collision + cleanup() + + // Step 2: rename second var_1 -> var_2 (restores unique names) + const afterSecond = collectRenameResult(afterFirst, ['var_1', 'var_1'], 1, 'var_2') + + // var_1 must survive because index 0 still uses it + expect(afterSecond.var_1).toBeDefined() + expect(afterSecond.var_2).toBeDefined() + }) + }) + + describe('removal with duplicate names', () => { + it('should call onRemove with correct index when removing a duplicate', () => { + const outputs = createOutputs({ var_1: 'string' }) + const onRemove = vi.fn() + + render( + , + ) + + // The second remove button (index 1 in the row) + const buttons = screen.getAllByRole('button') + fireEvent.click(buttons[1]) + + expect(onRemove).toHaveBeenCalledWith(1) + }) + }) + + describe('normal operation', () => { + it('should render one row per outputKeyOrders entry', () => { + const outputs = createOutputs({ a: 'string', b: 'number' }) + const onChange = vi.fn() + + render( + , + ) + + const inputs = screen.getAllByRole('textbox') + expect(inputs).toHaveLength(2) + expect(inputs[0]).toHaveValue('a') + expect(inputs[1]).toHaveValue('b') + }) + + it('should call onChange with updated outputs when renaming', () => { + const outputs = createOutputs({ var_1: 'string' }) + const onChange = vi.fn() + + render( + , + ) + + fireEvent.change(screen.getByRole('textbox'), { target: { value: 'new_name' } }) + + expect(onChange).toHaveBeenCalledWith( + expect.objectContaining({ + new_name: { type: 'string', children: null }, + }), + 0, + 'new_name', + ) + }) + + it('should call onRemove when remove button is clicked', () => { + const outputs = createOutputs({ var_1: 'string' }) + const onRemove = vi.fn() + + render( + , + ) + + fireEvent.click(screen.getByRole('button')) + + expect(onRemove).toHaveBeenCalledWith(0) + }) + + it('should render inputs as readonly when readonly is true', () => { + const outputs = createOutputs({ var_1: 'string' }) + + render( + , + ) + + expect(screen.getByRole('textbox')).toHaveAttribute('readonly') + }) + }) +}) diff --git a/web/app/components/workflow/nodes/_base/components/variable/output-var-list.tsx b/web/app/components/workflow/nodes/_base/components/variable/output-var-list.tsx index b9a1bc524e..79238aa6de 100644 --- a/web/app/components/workflow/nodes/_base/components/variable/output-var-list.tsx +++ b/web/app/components/workflow/nodes/_base/components/variable/output-var-list.tsx @@ -59,7 +59,9 @@ const OutputVarList: FC = ({ const newOutputs = produce(outputs, (draft) => { draft[newKey] = draft[oldKey] - delete draft[oldKey] + // Only delete old key if no other entry shares this name + if (!list.some((item, i) => i !== index && item.variable === oldKey)) + delete draft[oldKey] }) onChange(newOutputs, index, newKey) } diff --git a/web/app/components/workflow/nodes/_base/hooks/__tests__/use-toggle-expend.spec.ts b/web/app/components/workflow/nodes/_base/hooks/__tests__/use-toggle-expend.spec.ts new file mode 100644 index 0000000000..266800d9aa --- /dev/null +++ b/web/app/components/workflow/nodes/_base/hooks/__tests__/use-toggle-expend.spec.ts @@ -0,0 +1,123 @@ +import { act, renderHook } from '@testing-library/react' +import { useRef } from 'react' +import useToggleExpend from '../use-toggle-expend' + +type HookProps = { + hasFooter?: boolean + isInNode?: boolean + clientHeight?: number +} + +/** + * Wrapper that provides a real ref whose `.current.clientHeight` is stubbed + * so we can verify the height math without a real DOM layout pass. + */ +function useHarness({ hasFooter, isInNode, clientHeight = 400 }: HookProps) { + const ref = useRef(null) + + // Stub a ref-like object so measurements are deterministic. + if (!ref.current) { + Object.defineProperty(ref, 'current', { + value: { clientHeight } as HTMLDivElement, + writable: true, + }) + } + + return useToggleExpend({ ref, hasFooter, isInNode }) +} + +describe('useToggleExpend', () => { + describe('collapsed state', () => { + it('returns empty wrapClassName and zero expand height when collapsed', () => { + const { result } = renderHook(() => useHarness({ clientHeight: 400 })) + + expect(result.current.isExpand).toBe(false) + expect(result.current.wrapClassName).toBe('') + expect(result.current.editorExpandHeight).toBe(0) + }) + }) + + describe('expanded state (node context)', () => { + it('uses fixed positioning inside a workflow node panel', () => { + const { result } = renderHook(() => + useHarness({ isInNode: true, clientHeight: 400 }), + ) + + act(() => { + result.current.setIsExpand(true) + }) + + expect(result.current.isExpand).toBe(true) + expect(result.current.wrapClassName).toContain('fixed') + expect(result.current.wrapClassName).toContain('bg-components-panel-bg') + expect(result.current.wrapStyle).toEqual( + expect.objectContaining({ boxShadow: expect.any(String) }), + ) + }) + }) + + describe('expanded state (execution-log / webapp context)', () => { + it('fills its positioned ancestor edge-to-edge without hardcoded offsets', () => { + const { result } = renderHook(() => + useHarness({ isInNode: false, clientHeight: 400 }), + ) + + act(() => { + result.current.setIsExpand(true) + }) + + // The expanded panel must fill the nearest positioned ancestor entirely + // (absolute + inset-0). Previously it used hardcoded `top-[52px]` which + // assumed a 52px header that does not exist in the conversation-log + // layout, causing the expanded panel to overlap the status bar above + // the editor (#34887). + expect(result.current.wrapClassName).toContain('absolute') + expect(result.current.wrapClassName).toContain('inset-0') + expect(result.current.wrapClassName).not.toMatch(/top-\[\d+px\]/) + expect(result.current.wrapClassName).not.toMatch(/left-\d+/) + expect(result.current.wrapClassName).not.toMatch(/right-\d+/) + expect(result.current.wrapClassName).toContain('bg-components-panel-bg') + }) + }) + + describe('expanded state height math', () => { + it('subtracts the 29px chrome when hasFooter is false', () => { + const { result } = renderHook(() => + useHarness({ hasFooter: false, clientHeight: 400 }), + ) + + act(() => { + result.current.setIsExpand(true) + }) + + // 400 (clientHeight) - 29 (title bar) = 371 + expect(result.current.editorExpandHeight).toBe(371) + }) + + it('subtracts the 56px chrome when hasFooter is true', () => { + const { result } = renderHook(() => + useHarness({ hasFooter: true, clientHeight: 400 }), + ) + + act(() => { + result.current.setIsExpand(true) + }) + + // 400 (clientHeight) - 56 (title bar + footer) = 344 + expect(result.current.editorExpandHeight).toBe(344) + }) + + it('never returns a negative height even if chrome exceeds wrap', () => { + const { result } = renderHook(() => + useHarness({ hasFooter: true, clientHeight: 20 }), + ) + + act(() => { + result.current.setIsExpand(true) + }) + + // 20 - 56 would be -36; clamped to 0. + expect(result.current.editorExpandHeight).toBe(0) + }) + }) +}) diff --git a/web/app/components/workflow/nodes/_base/hooks/use-output-var-list.ts b/web/app/components/workflow/nodes/_base/hooks/use-output-var-list.ts index 09b8fde0b5..a77af2daef 100644 --- a/web/app/components/workflow/nodes/_base/hooks/use-output-var-list.ts +++ b/web/app/components/workflow/nodes/_base/hooks/use-output-var-list.ts @@ -134,19 +134,24 @@ function useOutputVarList({ return } + const newOutputKeyOrders = outputKeyOrders.filter((_, i) => i !== index) const newInputs = produce(inputs, (draft: any) => { - delete draft[varKey][key] + // Only delete from outputs when no remaining entry shares this name + if (!newOutputKeyOrders.includes(key)) + delete draft[varKey][key] if ((inputs as CodeNodeType).type === BlockEnum.Code && (inputs as CodeNodeType).error_strategy === ErrorHandleTypeEnum.defaultValue && varKey === 'outputs') draft.default_value = getDefaultValue(draft as any) }) setInputs(newInputs) - onOutputKeyOrdersChange(outputKeyOrders.filter((_, i) => i !== index)) - const varId = nodesWithInspectVars.find(node => node.nodeId === id)?.vars.find((varItem) => { - return varItem.name === key - })?.id - if (varId) - deleteInspectVar(id, varId) + onOutputKeyOrdersChange(newOutputKeyOrders) + if (!newOutputKeyOrders.includes(key)) { + const varId = nodesWithInspectVars.find(node => node.nodeId === id)?.vars.find((varItem) => { + return varItem.name === key + })?.id + if (varId) + deleteInspectVar(id, varId) + } }, [outputKeyOrders, isVarUsedInNodes, id, inputs, setInputs, onOutputKeyOrdersChange, nodesWithInspectVars, deleteInspectVar, showRemoveVarConfirm, varKey]) return { diff --git a/web/app/components/workflow/nodes/_base/hooks/use-toggle-expend.ts b/web/app/components/workflow/nodes/_base/hooks/use-toggle-expend.ts index c123c00e2d..1afeb8db12 100644 --- a/web/app/components/workflow/nodes/_base/hooks/use-toggle-expend.ts +++ b/web/app/components/workflow/nodes/_base/hooks/use-toggle-expend.ts @@ -1,4 +1,4 @@ -import { useEffect, useState } from 'react' +import { useLayoutEffect, useState } from 'react' type Params = { ref?: React.RefObject @@ -6,30 +6,62 @@ type Params = { isInNode?: boolean } +// Chrome (title bar + optional footer) heights subtracted from the wrap so +// the editor body never paints underneath its own controls. +const CHROME_HEIGHT_WITH_FOOTER = 56 +const CHROME_HEIGHT_WITHOUT_FOOTER = 29 + +/** + * Controls the expand/collapse behavior of the code editor wrapper used across + * workflow nodes and execution-log panels. + * + * Returns: + * - `wrapClassName` / `wrapStyle` — positioning + shadow applied to the outer + * wrapper when the editor is expanded. + * - `editorExpandHeight` — height for the editor body (wrap minus chrome). + * - `isExpand` / `setIsExpand` — state + setter for the consumer. + * + * Height is measured via `useLayoutEffect` so the first expanded render + * already has the correct value — the previous `useEffect` implementation + * left the editor at the collapsed height for one paint on first expand. + */ const useToggleExpend = ({ ref, hasFooter = true, isInNode }: Params) => { const [isExpand, setIsExpand] = useState(false) - const [wrapHeight, setWrapHeight] = useState(ref?.current?.clientHeight) - const editorExpandHeight = isExpand ? wrapHeight! - (hasFooter ? 56 : 29) : 0 - useEffect(() => { + const [wrapHeight, setWrapHeight] = useState(undefined) + + useLayoutEffect(() => { if (!ref?.current) return - setWrapHeight(ref.current?.clientHeight) - }, [isExpand]) + setWrapHeight(ref.current.clientHeight) + }, [isExpand, ref]) + + const chromeHeight = hasFooter ? CHROME_HEIGHT_WITH_FOOTER : CHROME_HEIGHT_WITHOUT_FOOTER + const editorExpandHeight = isExpand && wrapHeight !== undefined + ? Math.max(0, wrapHeight - chromeHeight) + : 0 const wrapClassName = (() => { if (!isExpand) return '' if (isInNode) - return 'fixed z-10 right-[9px] top-[166px] bottom-[8px] p-4 bg-components-panel-bg rounded-xl' + return 'fixed z-10 right-[9px] top-[166px] bottom-[8px] p-4 bg-components-panel-bg rounded-xl' - return 'absolute z-10 left-4 right-6 top-[52px] bottom-0 pb-4 bg-components-panel-bg' + // Fill the nearest positioned ancestor entirely. Previously hardcoded + // `top-[52px] left-4 right-6` offsets assumed a 52px header above the + // scroll container — that assumption no longer holds in the conversation + // log (result-panel) layout, where the status bar above the editor is + // taller than 52px, causing the expanded panel to partially overlap the + // status bar (issue #34887). + return 'absolute z-10 inset-0 pb-4 bg-components-panel-bg' })() + const wrapStyle = isExpand ? { boxShadow: '0px 0px 12px -4px rgba(16, 24, 40, 0.05), 0px -3px 6px -2px rgba(16, 24, 40, 0.03)', } : {} + return { wrapClassName, wrapStyle, diff --git a/web/app/components/workflow/nodes/agent/__tests__/node.spec.tsx b/web/app/components/workflow/nodes/agent/__tests__/node.spec.tsx new file mode 100644 index 0000000000..025c9bd84c --- /dev/null +++ b/web/app/components/workflow/nodes/agent/__tests__/node.spec.tsx @@ -0,0 +1,249 @@ +import type { ReactNode } from 'react' +import type { AgentNodeType } from '../types' +import type useConfig from '../use-config' +import type { StrategyParamItem } from '@/app/components/plugins/types' +import { render, screen } from '@testing-library/react' +import { FormTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' +import { BlockEnum } from '@/app/components/workflow/types' +import { VarType } from '../../tool/types' +import Node from '../node' + +const mockUseConfig = vi.hoisted(() => vi.fn()) +const mockModelBar = vi.hoisted(() => vi.fn()) +const mockToolIcon = vi.hoisted(() => vi.fn()) + +vi.mock('../use-config', () => ({ + __esModule: true, + default: (...args: unknown[]) => mockUseConfig(...args), +})) + +vi.mock('@/hooks/use-i18n', () => ({ + useRenderI18nObject: () => (value: string | { en_US?: string }) => typeof value === 'string' ? value : value.en_US || '', +})) + +vi.mock('../components/model-bar', () => ({ + ModelBar: (props: { provider?: string, model?: string, param: string }) => { + mockModelBar(props) + return
{props.provider ? `${props.param}:${props.provider}/${props.model}` : `${props.param}:empty-model`}
+ }, +})) + +vi.mock('../components/tool-icon', () => ({ + ToolIcon: (props: { providerName: string }) => { + mockToolIcon(props) + return
{`tool:${props.providerName}`}
+ }, +})) + +vi.mock('../../_base/components/group', () => ({ + Group: ({ label, children }: { label: ReactNode, children: ReactNode }) => ( +
+
{label}
+ {children} +
+ ), + GroupLabel: ({ className, children }: { className?: string, children: ReactNode }) =>
{children}
, +})) + +vi.mock('../../_base/components/setting-item', () => ({ + SettingItem: ({ + label, + status, + tooltip, + children, + }: { + label: ReactNode + status?: string + tooltip?: string + children?: ReactNode + }) => ( +
+ {`${label}:${status || 'normal'}:${tooltip || ''}`} + {children} +
+ ), +})) + +const createStrategyParam = (overrides: Partial = {}): StrategyParamItem => ({ + name: 'requiredModel', + type: FormTypeEnum.modelSelector, + required: true, + label: { en_US: 'Required Model' } as StrategyParamItem['label'], + help: { en_US: 'Required model help' } as StrategyParamItem['help'], + placeholder: { en_US: 'Required model placeholder' } as StrategyParamItem['placeholder'], + scope: 'global', + default: null, + options: [], + template: { enabled: false }, + auto_generate: { type: 'none' }, + ...overrides, +}) + +const createData = (overrides: Partial = {}): AgentNodeType => ({ + title: 'Agent', + desc: '', + type: BlockEnum.Agent, + output_schema: {}, + agent_strategy_provider_name: 'provider/agent', + agent_strategy_name: 'react', + agent_strategy_label: 'React Agent', + plugin_unique_identifier: 'provider/agent:1.0.0', + agent_parameters: { + optionalModel: { + type: VarType.constant, + value: { provider: 'openai', model: 'gpt-4o' }, + }, + toolParam: { + type: VarType.constant, + value: { provider_name: 'author/tool-a' }, + }, + multiToolParam: { + type: VarType.constant, + value: [ + { provider_name: 'author/tool-b' }, + { provider_name: 'author/tool-c' }, + ], + }, + }, + ...overrides, +}) + +const createConfigResult = (overrides: Partial> = {}): ReturnType => ({ + readOnly: false, + inputs: createData(), + setInputs: vi.fn(), + handleVarListChange: vi.fn(), + handleAddVariable: vi.fn(), + currentStrategy: { + identity: { + author: 'provider', + name: 'react', + icon: 'icon', + label: { en_US: 'React Agent' } as StrategyParamItem['label'], + provider: 'provider/agent', + }, + parameters: [ + createStrategyParam(), + createStrategyParam({ + name: 'optionalModel', + required: false, + }), + createStrategyParam({ + name: 'toolParam', + type: FormTypeEnum.toolSelector, + required: false, + }), + createStrategyParam({ + name: 'multiToolParam', + type: FormTypeEnum.multiToolSelector, + required: false, + }), + ], + description: { en_US: 'agent description' } as StrategyParamItem['label'], + output_schema: {}, + features: [], + }, + formData: {}, + onFormChange: vi.fn(), + currentStrategyStatus: { + plugin: { source: 'marketplace', installed: true }, + isExistInPlugin: false, + }, + strategyProvider: undefined, + pluginDetail: ({ + declaration: { + label: { en_US: 'Plugin Marketplace' } as never, + }, + } as never), + availableVars: [], + availableNodesWithParent: [], + outputSchema: [], + handleMemoryChange: vi.fn(), + isChatMode: true, + ...overrides, +}) + +describe('agent/node', () => { + beforeEach(() => { + vi.clearAllMocks() + mockUseConfig.mockReturnValue(createConfigResult()) + }) + + it('renders the not-set state when no strategy is configured', () => { + mockUseConfig.mockReturnValue(createConfigResult({ + inputs: createData({ + agent_strategy_name: undefined, + agent_strategy_label: undefined, + agent_parameters: {}, + }), + currentStrategy: undefined, + })) + + render( + , + ) + + expect(screen.getByText('workflow.nodes.agent.strategyNotSet:normal:')).toBeInTheDocument() + expect(mockModelBar).not.toHaveBeenCalled() + expect(mockToolIcon).not.toHaveBeenCalled() + }) + + it('renders strategy status, required and selected model bars, and tool icons', () => { + render( + , + ) + + expect(screen.getByText(/workflow.nodes.agent.strategy.shortLabel:error:/)).toHaveTextContent('React Agent') + expect(screen.getByText(/workflow.nodes.agent.strategy.shortLabel:error:/)).toHaveTextContent('Plugin Marketplace') + expect(screen.getByText('requiredModel:empty-model')).toBeInTheDocument() + expect(screen.getByText('optionalModel:openai/gpt-4o')).toBeInTheDocument() + expect(screen.getByText('tool:author/tool-a')).toBeInTheDocument() + expect(screen.getByText('tool:author/tool-b')).toBeInTheDocument() + expect(screen.getByText('tool:author/tool-c')).toBeInTheDocument() + expect(mockModelBar).toHaveBeenCalledTimes(2) + expect(mockToolIcon).toHaveBeenCalledTimes(3) + }) + + it('skips optional models and empty tool values when no configuration is provided', () => { + mockUseConfig.mockReturnValue(createConfigResult({ + inputs: createData({ + agent_parameters: {}, + }), + currentStrategy: { + ...createConfigResult().currentStrategy!, + parameters: [ + createStrategyParam({ + name: 'optionalModel', + required: false, + }), + createStrategyParam({ + name: 'toolParam', + type: FormTypeEnum.toolSelector, + required: false, + }), + ], + }, + currentStrategyStatus: { + plugin: { source: 'marketplace', installed: true }, + isExistInPlugin: true, + }, + })) + + render( + , + ) + + expect(mockModelBar).not.toHaveBeenCalled() + expect(mockToolIcon).not.toHaveBeenCalled() + expect(screen.queryByText('optionalModel:empty-model')).not.toBeInTheDocument() + }) +}) diff --git a/web/app/components/workflow/nodes/agent/__tests__/panel.spec.tsx b/web/app/components/workflow/nodes/agent/__tests__/panel.spec.tsx new file mode 100644 index 0000000000..15001b4757 --- /dev/null +++ b/web/app/components/workflow/nodes/agent/__tests__/panel.spec.tsx @@ -0,0 +1,297 @@ +import type { ReactNode } from 'react' +import type { AgentNodeType } from '../types' +import type useConfig from '../use-config' +import type { StrategyParamItem } from '@/app/components/plugins/types' +import type { NodePanelProps } from '@/app/components/workflow/types' +import { render, screen } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import { FormTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' +import { BlockEnum } from '@/app/components/workflow/types' +import Panel from '../panel' +import { AgentFeature } from '../types' + +const mockUseConfig = vi.hoisted(() => vi.fn()) +const mockResetEditor = vi.hoisted(() => vi.fn()) +const mockAgentStrategy = vi.hoisted(() => vi.fn()) +const mockMemoryConfig = vi.hoisted(() => vi.fn()) + +vi.mock('../use-config', () => ({ + __esModule: true, + default: (...args: unknown[]) => mockUseConfig(...args), +})) + +vi.mock('../../../store', () => ({ + useStore: (selector: (state: { setControlPromptEditorRerenderKey: typeof mockResetEditor }) => unknown) => selector({ + setControlPromptEditorRerenderKey: mockResetEditor, + }), +})) + +vi.mock('../../_base/components/agent-strategy', () => ({ + AgentStrategy: (props: { + strategy?: { + agent_strategy_provider_name: string + agent_strategy_name: string + agent_strategy_label: string + agent_output_schema: AgentNodeType['output_schema'] + plugin_unique_identifier: string + meta?: AgentNodeType['meta'] + } + formSchema: Array<{ variable: string, tooltip?: StrategyParamItem['help'] }> + formValue: Record + onStrategyChange: (strategy: { + agent_strategy_provider_name: string + agent_strategy_name: string + agent_strategy_label: string + agent_output_schema: AgentNodeType['output_schema'] + plugin_unique_identifier: string + meta?: AgentNodeType['meta'] + }) => void + onFormValueChange: (value: Record) => void + }) => { + mockAgentStrategy(props) + return ( +
+ + +
+ ) + }, +})) + +vi.mock('../../_base/components/memory-config', () => ({ + __esModule: true, + default: (props: { + readonly?: boolean + config: { data?: AgentNodeType['memory'] } + onChange: (value?: AgentNodeType['memory']) => void + }) => { + mockMemoryConfig(props) + return ( + + ) + }, +})) + +vi.mock('../../_base/components/output-vars', () => ({ + __esModule: true, + default: ({ children }: { children: ReactNode }) =>
{children}
, + VarItem: ({ name, type, description }: { name: string, type: string, description?: string }) => ( +
{`${name}:${type}:${description || ''}`}
+ ), +})) + +const createStrategyParam = (overrides: Partial = {}): StrategyParamItem => ({ + name: 'instruction', + type: FormTypeEnum.any, + required: true, + label: { en_US: 'Instruction' } as StrategyParamItem['label'], + help: { en_US: 'Instruction help' } as StrategyParamItem['help'], + placeholder: { en_US: 'Instruction placeholder' } as StrategyParamItem['placeholder'], + scope: 'global', + default: null, + options: [], + template: { enabled: false }, + auto_generate: { type: 'none' }, + ...overrides, +}) + +const createData = (overrides: Partial = {}): AgentNodeType => ({ + title: 'Agent', + desc: '', + type: BlockEnum.Agent, + output_schema: { + properties: { + summary: { + type: 'string', + description: 'summary output', + }, + }, + }, + agent_strategy_provider_name: 'provider/agent', + agent_strategy_name: 'react', + agent_strategy_label: 'React Agent', + plugin_unique_identifier: 'provider/agent:1.0.0', + meta: { version: '1.0.0' } as AgentNodeType['meta'], + memory: { + window: { + enabled: false, + size: 3, + }, + query_prompt_template: '', + } as AgentNodeType['memory'], + ...overrides, +}) + +const createConfigResult = (overrides: Partial> = {}): ReturnType => ({ + readOnly: false, + inputs: createData(), + setInputs: vi.fn(), + handleVarListChange: vi.fn(), + handleAddVariable: vi.fn(), + currentStrategy: { + identity: { + author: 'provider', + name: 'react', + icon: 'icon', + label: { en_US: 'React Agent' } as StrategyParamItem['label'], + provider: 'provider/agent', + }, + parameters: [ + createStrategyParam(), + createStrategyParam({ + name: 'modelParam', + type: FormTypeEnum.modelSelector, + required: false, + }), + ], + description: { en_US: 'agent description' } as StrategyParamItem['label'], + output_schema: {}, + features: [AgentFeature.HISTORY_MESSAGES], + }, + formData: { + instruction: 'Plan and answer', + }, + onFormChange: vi.fn(), + currentStrategyStatus: { + plugin: { source: 'marketplace', installed: true }, + isExistInPlugin: true, + }, + strategyProvider: undefined, + pluginDetail: undefined, + availableVars: [], + availableNodesWithParent: [], + outputSchema: [{ + name: 'summary', + type: 'String', + description: 'summary output', + }], + handleMemoryChange: vi.fn(), + isChatMode: true, + ...overrides, +}) + +const panelProps = {} as NodePanelProps['panelProps'] + +describe('agent/panel', () => { + beforeEach(() => { + vi.clearAllMocks() + mockUseConfig.mockReturnValue(createConfigResult()) + }) + + it('renders strategy data, forwards strategy and form updates, and exposes output vars', async () => { + const user = userEvent.setup() + const setInputs = vi.fn() + const onFormChange = vi.fn() + const handleMemoryChange = vi.fn() + + mockUseConfig.mockReturnValue(createConfigResult({ + setInputs, + onFormChange, + handleMemoryChange, + })) + + render( + , + ) + + expect(screen.getByText('text:String:workflow.nodes.agent.outputVars.text')).toBeInTheDocument() + expect(screen.getByText('usage:object:workflow.nodes.agent.outputVars.usage')).toBeInTheDocument() + expect(screen.getByText('files:Array[File]:workflow.nodes.agent.outputVars.files.title')).toBeInTheDocument() + expect(screen.getByText('json:Array[Object]:workflow.nodes.agent.outputVars.json')).toBeInTheDocument() + expect(screen.getByText('summary:String:summary output')).toBeInTheDocument() + expect(mockAgentStrategy).toHaveBeenCalledWith(expect.objectContaining({ + formSchema: expect.arrayContaining([ + expect.objectContaining({ + variable: 'instruction', + tooltip: { en_US: 'Instruction help' }, + }), + expect.objectContaining({ + variable: 'modelParam', + }), + ]), + formValue: { + instruction: 'Plan and answer', + }, + })) + + await user.click(screen.getByRole('button', { name: 'change-strategy' })) + await user.click(screen.getByRole('button', { name: 'change-form' })) + await user.click(screen.getByRole('button', { name: 'change-memory' })) + + expect(setInputs).toHaveBeenCalledWith(expect.objectContaining({ + agent_strategy_provider_name: 'provider/updated', + agent_strategy_name: 'updated', + agent_strategy_label: 'Updated Strategy', + plugin_unique_identifier: 'provider/updated:1.0.0', + output_schema: expect.objectContaining({ + properties: expect.objectContaining({ + structured: expect.any(Object), + }), + }), + })) + expect(onFormChange).toHaveBeenCalledWith({ instruction: 'Use the tool' }) + expect(handleMemoryChange).toHaveBeenCalledWith(expect.objectContaining({ + query_prompt_template: 'history', + })) + expect(mockResetEditor).toHaveBeenCalledTimes(1) + }) + + it('hides memory config when chat mode support is unavailable', () => { + mockUseConfig.mockReturnValue(createConfigResult({ + isChatMode: false, + currentStrategy: { + ...createConfigResult().currentStrategy!, + features: [], + }, + })) + + render( + , + ) + + expect(screen.queryByRole('button', { name: 'change-memory' })).not.toBeInTheDocument() + expect(mockMemoryConfig).not.toHaveBeenCalled() + }) +}) diff --git a/web/app/components/workflow/nodes/agent/__tests__/use-config.spec.ts b/web/app/components/workflow/nodes/agent/__tests__/use-config.spec.ts new file mode 100644 index 0000000000..9e09ab6d78 --- /dev/null +++ b/web/app/components/workflow/nodes/agent/__tests__/use-config.spec.ts @@ -0,0 +1,422 @@ +import type { AgentNodeType } from '../types' +import type { StrategyParamItem } from '@/app/components/plugins/types' +import { act, renderHook, waitFor } from '@testing-library/react' +import { FormTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' +import { BlockEnum, VarType as WorkflowVarType } from '@/app/components/workflow/types' +import { VarType } from '../../tool/types' +import useConfig, { useStrategyInfo } from '../use-config' + +const mockUseNodesReadOnly = vi.hoisted(() => vi.fn()) +const mockUseIsChatMode = vi.hoisted(() => vi.fn()) +const mockUseNodeCrud = vi.hoisted(() => vi.fn()) +const mockUseVarList = vi.hoisted(() => vi.fn()) +const mockUseAvailableVarList = vi.hoisted(() => vi.fn()) +const mockUseStrategyProviderDetail = vi.hoisted(() => vi.fn()) +const mockUseFetchPluginsInMarketPlaceByIds = vi.hoisted(() => vi.fn()) +const mockUseCheckInstalled = vi.hoisted(() => vi.fn()) +const mockGenerateAgentToolValue = vi.hoisted(() => vi.fn()) +const mockToolParametersToFormSchemas = vi.hoisted(() => vi.fn()) + +vi.mock('@/app/components/workflow/hooks', () => ({ + useNodesReadOnly: (...args: unknown[]) => mockUseNodesReadOnly(...args), + useIsChatMode: (...args: unknown[]) => mockUseIsChatMode(...args), +})) + +vi.mock('@/app/components/workflow/nodes/_base/hooks/use-node-crud', () => ({ + __esModule: true, + default: (...args: unknown[]) => mockUseNodeCrud(...args), +})) + +vi.mock('@/app/components/workflow/nodes/_base/hooks/use-var-list', () => ({ + __esModule: true, + default: (...args: unknown[]) => mockUseVarList(...args), +})) + +vi.mock('@/app/components/workflow/nodes/_base/hooks/use-available-var-list', () => ({ + __esModule: true, + default: (...args: unknown[]) => mockUseAvailableVarList(...args), +})) + +vi.mock('@/service/use-strategy', () => ({ + useStrategyProviderDetail: (...args: unknown[]) => mockUseStrategyProviderDetail(...args), +})) + +vi.mock('@/service/use-plugins', () => ({ + useFetchPluginsInMarketPlaceByIds: (...args: unknown[]) => mockUseFetchPluginsInMarketPlaceByIds(...args), + useCheckInstalled: (...args: unknown[]) => mockUseCheckInstalled(...args), +})) + +vi.mock('@/app/components/tools/utils/to-form-schema', () => ({ + generateAgentToolValue: (...args: unknown[]) => mockGenerateAgentToolValue(...args), + toolParametersToFormSchemas: (...args: unknown[]) => mockToolParametersToFormSchemas(...args), +})) + +const createStrategyParam = (overrides: Partial = {}): StrategyParamItem => ({ + name: 'instruction', + type: FormTypeEnum.any, + required: true, + label: { en_US: 'Instruction' } as StrategyParamItem['label'], + help: { en_US: 'Instruction help' } as StrategyParamItem['help'], + placeholder: { en_US: 'Instruction placeholder' } as StrategyParamItem['placeholder'], + scope: 'global', + default: null, + options: [], + template: { enabled: false }, + auto_generate: { type: 'none' }, + ...overrides, +}) + +const createToolValue = () => ({ + settings: { + api_key: 'secret', + }, + parameters: { + query: 'weather', + }, + schemas: [ + { + variable: 'api_key', + form: 'form', + }, + { + variable: 'query', + form: 'llm', + }, + ], +}) + +const createData = (overrides: Partial = {}): AgentNodeType => ({ + title: 'Agent', + desc: '', + type: BlockEnum.Agent, + output_schema: { + properties: { + summary: { + type: 'string', + description: 'summary output', + }, + items: { + type: 'array', + items: { + type: 'number', + }, + description: 'items output', + }, + }, + }, + agent_strategy_provider_name: 'provider/agent', + agent_strategy_name: 'react', + agent_strategy_label: 'React Agent', + plugin_unique_identifier: 'provider/agent:1.0.0', + agent_parameters: { + instruction: { + type: VarType.variable, + value: '#start.topic#', + }, + modelParam: { + type: VarType.constant, + value: { + provider: 'openai', + model: 'gpt-4o', + }, + }, + }, + meta: { version: '1.0.0' } as AgentNodeType['meta'], + ...overrides, +}) + +describe('agent/use-config', () => { + const providerRefetch = vi.fn() + const marketplaceRefetch = vi.fn() + const setInputs = vi.fn() + const handleVarListChange = vi.fn() + const handleAddVariable = vi.fn() + let currentInputs: AgentNodeType + + beforeEach(() => { + vi.clearAllMocks() + currentInputs = createData({ + tool_node_version: '2', + }) + + mockUseNodesReadOnly.mockReturnValue({ nodesReadOnly: false, getNodesReadOnly: () => false }) + mockUseIsChatMode.mockReturnValue(true) + mockUseNodeCrud.mockImplementation(() => ({ + inputs: currentInputs, + setInputs, + })) + mockUseVarList.mockReturnValue({ + handleVarListChange, + handleAddVariable, + } as never) + mockUseAvailableVarList.mockReturnValue({ + availableVars: [{ + nodeId: 'node-1', + title: 'Start', + vars: [{ + variable: 'topic', + type: WorkflowVarType.string, + }], + }], + availableNodesWithParent: [{ + nodeId: 'node-1', + title: 'Start', + }], + } as never) + mockUseStrategyProviderDetail.mockReturnValue({ + isLoading: false, + isError: false, + data: { + declaration: { + strategies: [{ + identity: { + name: 'react', + }, + parameters: [ + createStrategyParam(), + createStrategyParam({ + name: 'modelParam', + type: FormTypeEnum.modelSelector, + required: false, + }), + ], + }], + }, + }, + refetch: providerRefetch, + } as never) + mockUseFetchPluginsInMarketPlaceByIds.mockReturnValue({ + isLoading: false, + data: { + data: { + plugins: [{ id: 'provider/agent' }], + }, + }, + refetch: marketplaceRefetch, + } as never) + mockUseCheckInstalled.mockReturnValue({ + data: { + plugins: [{ + declaration: { + label: { en_US: 'Installed Agent Plugin' }, + }, + }], + }, + } as never) + mockToolParametersToFormSchemas.mockImplementation(value => value as never) + mockGenerateAgentToolValue.mockImplementation((_value, schemas, isLLM) => ({ + kind: isLLM ? 'llm' : 'setting', + fields: (schemas as Array<{ variable: string }>).map(item => item.variable), + }) as never) + }) + + it('returns an undefined strategy status while strategy data is still loading and can refetch dependencies', () => { + mockUseStrategyProviderDetail.mockReturnValue({ + isLoading: true, + isError: false, + data: undefined, + refetch: providerRefetch, + } as never) + + const { result } = renderHook(() => useStrategyInfo('provider/agent', 'react')) + + expect(result.current.strategyStatus).toBeUndefined() + expect(result.current.strategy).toBeUndefined() + + act(() => { + result.current.refetch() + }) + + expect(providerRefetch).toHaveBeenCalledTimes(1) + expect(marketplaceRefetch).toHaveBeenCalledTimes(1) + }) + + it('resolves strategy status for external plugins that are missing or not installed', () => { + mockUseStrategyProviderDetail.mockReturnValue({ + isLoading: false, + isError: true, + data: { + declaration: { + strategies: [], + }, + }, + refetch: providerRefetch, + } as never) + mockUseFetchPluginsInMarketPlaceByIds.mockReturnValue({ + isLoading: false, + data: { + data: { + plugins: [], + }, + }, + refetch: marketplaceRefetch, + } as never) + + const { result } = renderHook(() => useStrategyInfo('provider/agent', 'react')) + + expect(result.current.strategyStatus).toEqual({ + plugin: { + source: 'external', + installed: false, + }, + isExistInPlugin: false, + }) + }) + + it('exposes derived form data, strategy state, output schema, and setter helpers', () => { + const { result } = renderHook(() => useConfig('agent-node', currentInputs)) + + expect(result.current.readOnly).toBe(false) + expect(result.current.isChatMode).toBe(true) + expect(result.current.formData).toEqual({ + instruction: '#start.topic#', + modelParam: { + provider: 'openai', + model: 'gpt-4o', + }, + }) + expect(result.current.currentStrategyStatus).toEqual({ + plugin: { + source: 'marketplace', + installed: true, + }, + isExistInPlugin: true, + }) + expect(result.current.availableVars).toHaveLength(1) + expect(result.current.availableNodesWithParent).toEqual([{ + nodeId: 'node-1', + title: 'Start', + }]) + expect(result.current.outputSchema).toEqual([ + { name: 'summary', type: 'String', description: 'summary output' }, + { name: 'items', type: 'Array[Number]', description: 'items output' }, + ]) + + setInputs.mockClear() + + act(() => { + result.current.onFormChange({ + instruction: '#start.updated#', + modelParam: { + provider: 'anthropic', + model: 'claude-sonnet', + }, + }) + result.current.handleMemoryChange({ + window: { + enabled: true, + size: 6, + }, + query_prompt_template: 'history', + } as AgentNodeType['memory']) + }) + + expect(setInputs).toHaveBeenNthCalledWith(1, expect.objectContaining({ + agent_parameters: { + instruction: { + type: VarType.variable, + value: '#start.updated#', + }, + modelParam: { + type: VarType.constant, + value: { + provider: 'anthropic', + model: 'claude-sonnet', + }, + }, + }, + })) + expect(setInputs).toHaveBeenNthCalledWith(2, expect.objectContaining({ + memory: { + window: { + enabled: true, + size: 6, + }, + query_prompt_template: 'history', + }, + })) + expect(result.current.handleVarListChange).toBe(handleVarListChange) + expect(result.current.handleAddVariable).toBe(handleAddVariable) + expect(result.current.pluginDetail).toEqual({ + declaration: { + label: { en_US: 'Installed Agent Plugin' }, + }, + }) + }) + + it('formats legacy tool selector values before exposing the node config', async () => { + currentInputs = createData({ + tool_node_version: undefined, + agent_parameters: { + toolParam: { + type: VarType.constant, + value: createToolValue(), + }, + multiToolParam: { + type: VarType.constant, + value: [createToolValue()], + }, + }, + }) + mockUseStrategyProviderDetail.mockReturnValue({ + isLoading: false, + isError: false, + data: { + declaration: { + strategies: [{ + identity: { + name: 'react', + }, + parameters: [ + createStrategyParam({ + name: 'toolParam', + type: FormTypeEnum.toolSelector, + required: false, + }), + createStrategyParam({ + name: 'multiToolParam', + type: FormTypeEnum.multiToolSelector, + required: false, + }), + ], + }], + }, + }, + refetch: providerRefetch, + } as never) + + renderHook(() => useConfig('agent-node', currentInputs)) + + await waitFor(() => { + expect(setInputs).toHaveBeenCalledWith(expect.objectContaining({ + tool_node_version: '2', + agent_parameters: expect.objectContaining({ + toolParam: expect.objectContaining({ + value: expect.objectContaining({ + settings: { + kind: 'setting', + fields: ['api_key'], + }, + parameters: { + kind: 'llm', + fields: ['query'], + }, + }), + }), + multiToolParam: expect.objectContaining({ + value: [expect.objectContaining({ + settings: { + kind: 'setting', + fields: ['api_key'], + }, + parameters: { + kind: 'llm', + fields: ['query'], + }, + })], + }), + }), + })) + }) + }) +}) diff --git a/web/app/components/workflow/nodes/agent/__tests__/use-single-run-form-params.spec.ts b/web/app/components/workflow/nodes/agent/__tests__/use-single-run-form-params.spec.ts new file mode 100644 index 0000000000..33075e685f --- /dev/null +++ b/web/app/components/workflow/nodes/agent/__tests__/use-single-run-form-params.spec.ts @@ -0,0 +1,144 @@ +import type { AgentNodeType } from '../types' +import type { InputVar } from '@/app/components/workflow/types' +import { renderHook } from '@testing-library/react' +import formatTracing from '@/app/components/workflow/run/utils/format-log' +import { BlockEnum, InputVarType } from '@/app/components/workflow/types' +import useNodeCrud from '../../_base/hooks/use-node-crud' +import { VarType } from '../../tool/types' +import { useStrategyInfo } from '../use-config' +import useSingleRunFormParams from '../use-single-run-form-params' + +vi.mock('@/app/components/workflow/run/utils/format-log', () => ({ + __esModule: true, + default: vi.fn(), +})) + +vi.mock('../../_base/hooks/use-node-crud', () => ({ + __esModule: true, + default: vi.fn(), +})) + +vi.mock('../use-config', async () => { + const actual = await vi.importActual('../use-config') + return { + ...actual, + useStrategyInfo: vi.fn(), + } +}) + +const mockFormatTracing = vi.mocked(formatTracing) +const mockUseNodeCrud = vi.mocked(useNodeCrud) +const mockUseStrategyInfo = vi.mocked(useStrategyInfo) + +const createData = (overrides: Partial = {}): AgentNodeType => ({ + title: 'Agent', + desc: '', + type: BlockEnum.Agent, + output_schema: {}, + agent_strategy_provider_name: 'provider/agent', + agent_strategy_name: 'react', + agent_strategy_label: 'React Agent', + agent_parameters: { + prompt: { + type: VarType.variable, + value: '#start.topic#', + }, + summary: { + type: VarType.variable, + value: '#node-2.answer#', + }, + count: { + type: VarType.constant, + value: 2, + }, + }, + ...overrides, +}) + +describe('agent/use-single-run-form-params', () => { + beforeEach(() => { + vi.clearAllMocks() + mockUseNodeCrud.mockReturnValue({ + inputs: createData(), + setInputs: vi.fn(), + } as unknown as ReturnType) + mockUseStrategyInfo.mockReturnValue({ + strategyProvider: undefined, + strategy: { + parameters: [ + { name: 'prompt', type: 'string' }, + { name: 'summary', type: 'string' }, + { name: 'count', type: 'number' }, + ], + }, + strategyStatus: undefined, + refetch: vi.fn(), + } as unknown as ReturnType) + mockFormatTracing.mockReturnValue([{ + id: 'agent-node', + status: 'succeeded', + }] as unknown as ReturnType) + }) + + it('builds a single-run variable form, returns node info, and skips malformed dependent vars', () => { + const setRunInputData = vi.fn() + const getInputVars = vi.fn<() => InputVar[]>(() => [ + { + label: 'Prompt', + variable: '#start.topic#', + type: InputVarType.textInput, + required: true, + }, + { + label: 'Broken', + variable: undefined as unknown as string, + type: InputVarType.textInput, + required: false, + }, + ]) + + const { result } = renderHook(() => useSingleRunFormParams({ + id: 'agent-node', + payload: createData(), + runInputData: { topic: 'finance' }, + runInputDataRef: { current: { topic: 'finance' } }, + getInputVars, + setRunInputData, + toVarInputs: () => [], + runResult: { id: 'trace-1' } as never, + })) + + expect(getInputVars).toHaveBeenCalledWith(['#start.topic#', '#node-2.answer#']) + expect(result.current.forms).toHaveLength(1) + expect(result.current.forms[0].inputs).toHaveLength(2) + expect(result.current.forms[0].values).toEqual({ topic: 'finance' }) + expect(result.current.nodeInfo).toEqual({ + id: 'agent-node', + status: 'succeeded', + }) + + result.current.forms[0].onChange({ topic: 'updated' }) + + expect(setRunInputData).toHaveBeenCalledWith({ topic: 'updated' }) + expect(result.current.getDependentVars()).toEqual([ + ['start', 'topic'], + ]) + }) + + it('returns an empty form list when no variable input is required and no run result is available', () => { + const { result } = renderHook(() => useSingleRunFormParams({ + id: 'agent-node', + payload: createData(), + runInputData: {}, + runInputDataRef: { current: {} }, + getInputVars: () => [], + setRunInputData: vi.fn(), + toVarInputs: () => [], + runResult: undefined as never, + })) + + expect(result.current.forms).toEqual([]) + expect(result.current.nodeInfo).toBeUndefined() + expect(result.current.getDependentVars()).toEqual([]) + }) +}) diff --git a/web/app/components/workflow/nodes/agent/components/__tests__/model-bar.spec.tsx b/web/app/components/workflow/nodes/agent/components/__tests__/model-bar.spec.tsx new file mode 100644 index 0000000000..d85f54ed19 --- /dev/null +++ b/web/app/components/workflow/nodes/agent/components/__tests__/model-bar.spec.tsx @@ -0,0 +1,78 @@ +import type { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' +import { fireEvent, render, screen } from '@testing-library/react' +import { ModelBar } from '../model-bar' + +type ModelProviderItem = { + provider: string + models: Array<{ model: string }> +} + +const mockModelLists = new Map() + +vi.mock('@/app/components/header/account-setting/model-provider-page/hooks', () => ({ + useModelList: (modelType: ModelTypeEnum) => ({ + data: mockModelLists.get(modelType) || [], + }), +})) + +vi.mock('@/app/components/header/account-setting/model-provider-page/model-selector', () => ({ + default: ({ + defaultModel, + modelList, + }: { + defaultModel?: { provider: string, model: string } + modelList: ModelProviderItem[] + }) => ( +
+ {defaultModel ? `${defaultModel.provider}/${defaultModel.model}` : 'no-model'} + : + {modelList.length} +
+ ), +})) + +vi.mock('@/app/components/header/indicator', () => ({ + default: ({ color }: { color: string }) =>
{`indicator:${color}`}
, +})) + +describe('agent/model-bar', () => { + beforeEach(() => { + vi.clearAllMocks() + mockModelLists.clear() + mockModelLists.set('llm' as ModelTypeEnum, [{ provider: 'openai', models: [{ model: 'gpt-4o' }] }]) + mockModelLists.set('moderation' as ModelTypeEnum, []) + mockModelLists.set('rerank' as ModelTypeEnum, []) + mockModelLists.set('speech2text' as ModelTypeEnum, []) + mockModelLists.set('text-embedding' as ModelTypeEnum, []) + mockModelLists.set('tts' as ModelTypeEnum, []) + }) + + it('should render an empty readonly selector with a warning when no model is selected', () => { + render() + + const emptySelector = screen.getByText((_, element) => element?.textContent === 'no-model:0') + + fireEvent.mouseEnter(emptySelector) + + expect(emptySelector).toBeInTheDocument() + expect(screen.getByText('indicator:red')).toBeInTheDocument() + expect(screen.getByText('workflow.nodes.agent.modelNotSelected')).toBeInTheDocument() + }) + + it('should render the selected model without warning when it is installed', () => { + render() + + expect(screen.getByText('openai/gpt-4o:1')).toBeInTheDocument() + expect(screen.queryByText('indicator:red')).not.toBeInTheDocument() + }) + + it('should show a warning tooltip when the selected model is not installed', () => { + render() + + fireEvent.mouseEnter(screen.getByText('openai/gpt-4.1:1')) + + expect(screen.getByText('openai/gpt-4.1:1')).toBeInTheDocument() + expect(screen.getByText('indicator:red')).toBeInTheDocument() + expect(screen.getByText('workflow.nodes.agent.modelNotInstallTooltip')).toBeInTheDocument() + }) +}) diff --git a/web/app/components/workflow/nodes/agent/components/__tests__/tool-icon.spec.tsx b/web/app/components/workflow/nodes/agent/components/__tests__/tool-icon.spec.tsx new file mode 100644 index 0000000000..30a12bb528 --- /dev/null +++ b/web/app/components/workflow/nodes/agent/components/__tests__/tool-icon.spec.tsx @@ -0,0 +1,113 @@ +import { fireEvent, render, screen } from '@testing-library/react' +import { ToolIcon } from '../tool-icon' + +type ToolProvider = { + id?: string + name?: string + icon?: string | { content: string, background: string } + is_team_authorization?: boolean +} + +let mockBuiltInTools: ToolProvider[] | undefined +let mockCustomTools: ToolProvider[] | undefined +let mockWorkflowTools: ToolProvider[] | undefined +let mockMcpTools: ToolProvider[] | undefined +let mockMarketplaceIcon: string | { content: string, background: string } | undefined + +vi.mock('@/service/use-tools', () => ({ + useAllBuiltInTools: () => ({ data: mockBuiltInTools }), + useAllCustomTools: () => ({ data: mockCustomTools }), + useAllWorkflowTools: () => ({ data: mockWorkflowTools }), + useAllMCPTools: () => ({ data: mockMcpTools }), +})) + +vi.mock('@/app/components/base/app-icon', () => ({ + default: ({ + icon, + background, + className, + }: { + icon?: string + background?: string + className?: string + }) =>
{`app-icon:${background}:${icon}`}
, +})) + +vi.mock('@/app/components/base/icons/src/vender/other', () => ({ + Group: ({ className }: { className?: string }) =>
group-icon
, +})) + +vi.mock('@/app/components/header/indicator', () => ({ + default: ({ color }: { color: string }) =>
{`indicator:${color}`}
, +})) + +vi.mock('@/utils/get-icon', () => ({ + getIconFromMarketPlace: () => mockMarketplaceIcon, +})) + +describe('agent/tool-icon', () => { + beforeEach(() => { + vi.clearAllMocks() + mockBuiltInTools = [] + mockCustomTools = [] + mockWorkflowTools = [] + mockMcpTools = [] + mockMarketplaceIcon = undefined + }) + + it('should render a string icon, recover from fetch errors, and keep installed tools warning-free', () => { + mockBuiltInTools = [{ + name: 'author/tool-a', + icon: 'https://example.com/tool-a.png', + is_team_authorization: true, + }] + + render() + + const icon = screen.getByRole('img', { name: 'tool icon' }) + expect(icon).toHaveAttribute('src', 'https://example.com/tool-a.png') + expect(screen.queryByText(/indicator:/)).not.toBeInTheDocument() + + fireEvent.mouseEnter(icon) + expect(screen.queryByText('workflow.nodes.agent.toolNotInstallTooltip')).not.toBeInTheDocument() + + fireEvent.error(icon) + expect(screen.getByText('group-icon')).toBeInTheDocument() + }) + + it('should render authorization and installation warnings with the correct icon sources', () => { + mockWorkflowTools = [{ + id: 'author/tool-b', + icon: { + content: 'B', + background: '#fff', + }, + is_team_authorization: false, + }] + + const { rerender } = render() + + fireEvent.mouseEnter(screen.getByText('app-icon:#fff:B')) + expect(screen.getByText('indicator:yellow')).toBeInTheDocument() + expect(screen.getByText('workflow.nodes.agent.toolNotAuthorizedTooltip:{"tool":"tool-b"}')).toBeInTheDocument() + + mockWorkflowTools = [] + mockMarketplaceIcon = 'https://example.com/market-tool.png' + rerender() + + const marketplaceIcon = screen.getByRole('img', { name: 'tool icon' }) + fireEvent.mouseEnter(marketplaceIcon) + expect(marketplaceIcon).toHaveAttribute('src', 'https://example.com/market-tool.png') + expect(screen.getByText('indicator:red')).toBeInTheDocument() + expect(screen.getByText('workflow.nodes.agent.toolNotInstallTooltip:{"tool":"tool-c"}')).toBeInTheDocument() + }) + + it('should fall back to the group icon while tool data is still loading', () => { + mockBuiltInTools = undefined + + render() + + expect(screen.getByText('group-icon')).toBeInTheDocument() + expect(screen.queryByText(/indicator:/)).not.toBeInTheDocument() + }) +}) diff --git a/web/app/components/workflow/nodes/answer/__tests__/panel.spec.tsx b/web/app/components/workflow/nodes/answer/__tests__/panel.spec.tsx new file mode 100644 index 0000000000..b5fbdf163f --- /dev/null +++ b/web/app/components/workflow/nodes/answer/__tests__/panel.spec.tsx @@ -0,0 +1,92 @@ +import type { AnswerNodeType } from '../types' +import type { PanelProps } from '@/types/workflow' +import { fireEvent, render, screen } from '@testing-library/react' +import { BlockEnum } from '@/app/components/workflow/types' +import Panel from '../panel' + +type MockEditorProps = { + readOnly: boolean + title: string + value: string + onChange: (value: string) => void + nodesOutputVars: unknown[] + availableNodes: unknown[] +} + +const mockUseConfig = vi.hoisted(() => vi.fn()) +const mockUseAvailableVarList = vi.hoisted(() => vi.fn()) +const mockEditorRender = vi.hoisted(() => vi.fn()) + +vi.mock('../use-config', () => ({ + __esModule: true, + default: (...args: unknown[]) => mockUseConfig(...args), +})) + +vi.mock('@/app/components/workflow/nodes/_base/hooks/use-available-var-list', () => ({ + __esModule: true, + default: (...args: unknown[]) => mockUseAvailableVarList(...args), +})) + +vi.mock('@/app/components/workflow/nodes/_base/components/prompt/editor', () => ({ + __esModule: true, + default: (props: MockEditorProps) => { + mockEditorRender(props) + return ( + + ) + }, +})) + +const createData = (overrides: Partial = {}): AnswerNodeType => ({ + title: 'Answer', + desc: '', + type: BlockEnum.Answer, + variables: [], + answer: 'Initial answer', + ...overrides, +}) + +describe('AnswerPanel', () => { + const handleAnswerChange = vi.fn() + + beforeEach(() => { + vi.clearAllMocks() + mockUseConfig.mockReturnValue({ + readOnly: false, + inputs: createData(), + handleAnswerChange, + filterVar: vi.fn(), + }) + mockUseAvailableVarList.mockReturnValue({ + availableVars: [{ variable: 'context', type: 'string' }], + availableNodesWithParent: [{ value: 'node-1', label: 'Node 1' }], + }) + }) + + it('should pass editor state and available variables through to the prompt editor', () => { + render() + + expect(screen.getByRole('button', { name: 'workflow.nodes.answer.answer:Initial answer' })).toBeInTheDocument() + expect(mockEditorRender).toHaveBeenCalledWith(expect.objectContaining({ + readOnly: false, + title: 'workflow.nodes.answer.answer', + value: 'Initial answer', + nodesOutputVars: [{ variable: 'context', type: 'string' }], + availableNodes: [{ value: 'node-1', label: 'Node 1' }], + isSupportFileVar: true, + justVar: true, + })) + }) + + it('should delegate answer edits to use-config', () => { + render() + + fireEvent.click(screen.getByRole('button', { name: 'workflow.nodes.answer.answer:Initial answer' })) + + expect(handleAnswerChange).toHaveBeenCalledWith('Updated answer') + }) +}) diff --git a/web/app/components/workflow/nodes/answer/__tests__/use-config.spec.ts b/web/app/components/workflow/nodes/answer/__tests__/use-config.spec.ts new file mode 100644 index 0000000000..106355e8c5 --- /dev/null +++ b/web/app/components/workflow/nodes/answer/__tests__/use-config.spec.ts @@ -0,0 +1,81 @@ +import type { AnswerNodeType } from '../types' +import { act, renderHook } from '@testing-library/react' +import { BlockEnum, VarType } from '@/app/components/workflow/types' +import useConfig from '../use-config' + +const mockUseNodesReadOnly = vi.hoisted(() => vi.fn()) +const mockUseNodeCrud = vi.hoisted(() => vi.fn()) +const mockUseVarList = vi.hoisted(() => vi.fn()) + +vi.mock('@/app/components/workflow/hooks', () => ({ + useNodesReadOnly: () => mockUseNodesReadOnly(), +})) + +vi.mock('@/app/components/workflow/nodes/_base/hooks/use-node-crud', () => ({ + __esModule: true, + default: (...args: unknown[]) => mockUseNodeCrud(...args), +})) + +vi.mock('@/app/components/workflow/nodes/_base/hooks/use-var-list', () => ({ + __esModule: true, + default: (...args: unknown[]) => mockUseVarList(...args), +})) + +const createPayload = (overrides: Partial = {}): AnswerNodeType => ({ + title: 'Answer', + desc: '', + type: BlockEnum.Answer, + variables: [], + answer: 'Initial answer', + ...overrides, +}) + +describe('answer/use-config', () => { + const mockSetInputs = vi.fn() + const mockHandleVarListChange = vi.fn() + const mockHandleAddVariable = vi.fn() + let currentInputs: AnswerNodeType + + beforeEach(() => { + vi.clearAllMocks() + currentInputs = createPayload() + + mockUseNodesReadOnly.mockReturnValue({ nodesReadOnly: false }) + mockUseNodeCrud.mockReturnValue({ + inputs: currentInputs, + setInputs: mockSetInputs, + }) + mockUseVarList.mockReturnValue({ + handleVarListChange: mockHandleVarListChange, + handleAddVariable: mockHandleAddVariable, + }) + }) + + it('should update the answer text and expose var-list handlers', () => { + const { result } = renderHook(() => useConfig('answer-node', currentInputs)) + + act(() => { + result.current.handleAnswerChange('Updated answer') + }) + + expect(mockSetInputs).toHaveBeenCalledWith(expect.objectContaining({ + answer: 'Updated answer', + })) + expect(result.current.handleVarListChange).toBe(mockHandleVarListChange) + expect(result.current.handleAddVariable).toBe(mockHandleAddVariable) + expect(result.current.readOnly).toBe(false) + }) + + it('should filter out array-object variables from the prompt editor picker', () => { + const { result } = renderHook(() => useConfig('answer-node', currentInputs)) + + expect(result.current.filterVar({ + variable: 'items', + type: VarType.arrayObject, + })).toBe(false) + expect(result.current.filterVar({ + variable: 'message', + type: VarType.string, + })).toBe(true) + }) +}) diff --git a/web/app/components/workflow/nodes/assigner/__tests__/node.spec.tsx b/web/app/components/workflow/nodes/assigner/__tests__/node.spec.tsx new file mode 100644 index 0000000000..a1fd87d386 --- /dev/null +++ b/web/app/components/workflow/nodes/assigner/__tests__/node.spec.tsx @@ -0,0 +1,150 @@ +import type { AssignerNodeOperation, AssignerNodeType } from '../types' +import { render, screen } from '@testing-library/react' +import { useNodes } from 'reactflow' +import { BlockEnum } from '@/app/components/workflow/types' +import Node from '../node' +import { AssignerNodeInputType, WriteMode } from '../types' + +vi.mock('reactflow', async () => { + const actual = await vi.importActual('reactflow') + return { + ...actual, + useNodes: vi.fn(), + } +}) + +vi.mock('@/app/components/workflow/nodes/_base/components/variable/variable-label', () => ({ + VariableLabelInNode: ({ + variables, + nodeTitle, + nodeType, + rightSlot, + }: { + variables: string[] + nodeTitle?: string + nodeType?: BlockEnum + rightSlot?: React.ReactNode + }) => ( +
+ {`${nodeTitle}:${nodeType}:${variables.join('.')}`} + {rightSlot} +
+ ), +})) + +const mockUseNodes = vi.mocked(useNodes) + +const createOperation = (overrides: Partial = {}): AssignerNodeOperation => ({ + variable_selector: ['node-1', 'count'], + input_type: AssignerNodeInputType.variable, + operation: WriteMode.overwrite, + value: ['node-2', 'result'], + ...overrides, +}) + +const createData = (overrides: Partial = {}): AssignerNodeType => ({ + title: 'Assigner', + desc: '', + type: BlockEnum.VariableAssigner, + version: '2', + items: [createOperation()], + ...overrides, +}) + +describe('assigner/node', () => { + beforeEach(() => { + vi.clearAllMocks() + mockUseNodes.mockReturnValue([ + { + id: 'node-1', + data: { + title: 'Answer', + type: BlockEnum.Answer, + }, + }, + { + id: 'start-node', + data: { + title: 'Start', + type: BlockEnum.Start, + }, + }, + ] as ReturnType) + }) + + it('renders the empty-state hint when no assignable variable is configured', () => { + render( + , + ) + + expect(screen.getByText('workflow.nodes.assigner.varNotSet')).toBeInTheDocument() + }) + + it('renders both version 2 and legacy previews with resolved node labels', () => { + const { container, rerender } = render( + , + ) + + expect(screen.getByText('Answer:answer:node-1.count')).toBeInTheDocument() + expect(screen.getByText('workflow.nodes.assigner.operations.over-write')).toBeInTheDocument() + + rerender( + , + ) + + expect(screen.getByText('Start:start:sys.query')).toBeInTheDocument() + expect(screen.getByText('workflow.nodes.assigner.operations.append')).toBeInTheDocument() + + rerender( + , + ) + + expect(container).toBeEmptyDOMElement() + }) + + it('skips empty v2 operations and resolves system variables through the start node', () => { + render( + , + ) + + expect(screen.getByText('Start:start:sys.query')).toBeInTheDocument() + expect(screen.queryByText('undefined:undefined:')).not.toBeInTheDocument() + }) +}) diff --git a/web/app/components/workflow/nodes/assigner/__tests__/panel.spec.tsx b/web/app/components/workflow/nodes/assigner/__tests__/panel.spec.tsx new file mode 100644 index 0000000000..c70c84beab --- /dev/null +++ b/web/app/components/workflow/nodes/assigner/__tests__/panel.spec.tsx @@ -0,0 +1,119 @@ +import type { AssignerNodeOperation, AssignerNodeType } from '../types' +import type { PanelProps } from '@/types/workflow' +import { render, screen } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import { BlockEnum } from '@/app/components/workflow/types' +import Panel from '../panel' +import { AssignerNodeInputType, WriteMode } from '../types' + +type MockVarListProps = { + readonly: boolean + nodeId: string + list: AssignerNodeOperation[] + onChange: (list: AssignerNodeOperation[]) => void +} + +const mockUseConfig = vi.hoisted(() => vi.fn()) +const mockUseHandleAddOperationItem = vi.hoisted(() => vi.fn()) +const mockVarListRender = vi.hoisted(() => vi.fn()) + +const createOperation = (overrides: Partial = {}): AssignerNodeOperation => ({ + variable_selector: ['node-1', 'count'], + input_type: AssignerNodeInputType.variable, + operation: WriteMode.overwrite, + value: ['node-2', 'result'], + ...overrides, +}) + +vi.mock('../use-config', () => ({ + __esModule: true, + default: (...args: unknown[]) => mockUseConfig(...args), +})) + +vi.mock('../hooks', () => ({ + useHandleAddOperationItem: () => mockUseHandleAddOperationItem, +})) + +vi.mock('../components/var-list', () => ({ + __esModule: true, + default: (props: MockVarListProps) => { + mockVarListRender(props) + return ( +
+
{props.list.map(item => item.variable_selector.join('.')).join(',')}
+ +
+ ) + }, +})) + +const createData = (overrides: Partial = {}): AssignerNodeType => ({ + title: 'Assigner', + desc: '', + type: BlockEnum.VariableAssigner, + version: '2', + items: [createOperation()], + ...overrides, +}) + +const panelProps = {} as PanelProps + +describe('assigner/panel', () => { + const handleOperationListChanges = vi.fn() + + beforeEach(() => { + vi.clearAllMocks() + mockUseHandleAddOperationItem.mockReturnValue([ + createOperation(), + createOperation({ variable_selector: [] }), + ]) + mockUseConfig.mockReturnValue({ + readOnly: false, + inputs: createData(), + handleOperationListChanges, + getAssignedVarType: vi.fn(), + getToAssignedVarType: vi.fn(), + writeModeTypesNum: [], + writeModeTypesArr: [], + writeModeTypes: [], + filterAssignedVar: vi.fn(), + filterToAssignedVar: vi.fn(), + }) + }) + + it('passes the resolved config to the variable list and appends operations through the add button', async () => { + const user = userEvent.setup() + + render( + , + ) + + expect(screen.getByText('workflow.nodes.assigner.variables')).toBeInTheDocument() + expect(screen.getByText('node-1.count')).toBeInTheDocument() + expect(mockVarListRender).toHaveBeenCalledWith(expect.objectContaining({ + readonly: false, + nodeId: 'assigner-node', + list: createData().items, + })) + + await user.click(screen.getAllByRole('button')[0]!) + + expect(mockUseHandleAddOperationItem).toHaveBeenCalledWith(createData().items) + expect(handleOperationListChanges).toHaveBeenCalledWith([ + createOperation(), + createOperation({ variable_selector: [] }), + ]) + + await user.click(screen.getByRole('button', { name: 'emit-list-change' })) + + expect(handleOperationListChanges).toHaveBeenCalledWith([ + createOperation({ variable_selector: ['node-1', 'updated'] }), + ]) + }) +}) diff --git a/web/app/components/workflow/nodes/assigner/__tests__/use-single-run-form-params.spec.ts b/web/app/components/workflow/nodes/assigner/__tests__/use-single-run-form-params.spec.ts new file mode 100644 index 0000000000..0551d1fd30 --- /dev/null +++ b/web/app/components/workflow/nodes/assigner/__tests__/use-single-run-form-params.spec.ts @@ -0,0 +1,85 @@ +import type { AssignerNodeOperation, AssignerNodeType } from '../types' +import type { InputVar } from '@/app/components/workflow/types' +import { renderHook } from '@testing-library/react' +import { BlockEnum, InputVarType } from '@/app/components/workflow/types' +import useNodeCrud from '../../_base/hooks/use-node-crud' +import { AssignerNodeInputType, WriteMode } from '../types' +import useSingleRunFormParams from '../use-single-run-form-params' + +vi.mock('../../_base/hooks/use-node-crud', () => ({ + __esModule: true, + default: vi.fn(), +})) + +const mockUseNodeCrud = vi.mocked(useNodeCrud) + +const createOperation = (overrides: Partial = {}): AssignerNodeOperation => ({ + variable_selector: ['node-1', 'target'], + input_type: AssignerNodeInputType.variable, + operation: WriteMode.overwrite, + value: ['node-2', 'result'], + ...overrides, +}) + +const createData = (overrides: Partial = {}): AssignerNodeType => ({ + title: 'Assigner', + desc: '', + type: BlockEnum.VariableAssigner, + version: '2', + items: [ + createOperation(), + createOperation({ operation: WriteMode.append, value: ['node-3', 'items'] }), + createOperation({ operation: WriteMode.clear, value: ['node-4', 'unused'] }), + createOperation({ operation: WriteMode.set, input_type: AssignerNodeInputType.constant, value: 'fixed' }), + createOperation({ operation: WriteMode.increment, input_type: AssignerNodeInputType.constant, value: 2 }), + ], + ...overrides, +}) + +describe('assigner/use-single-run-form-params', () => { + beforeEach(() => { + vi.clearAllMocks() + mockUseNodeCrud.mockReturnValue({ + inputs: createData(), + setInputs: vi.fn(), + } as unknown as ReturnType) + }) + + it('exposes only variable-driven dependencies in the single-run form', () => { + const setRunInputData = vi.fn() + const varInputs: InputVar[] = [{ + label: 'Result', + variable: 'result', + type: InputVarType.textInput, + required: true, + }] + const varSelectorsToVarInputs = vi.fn(() => varInputs) + + const { result } = renderHook(() => useSingleRunFormParams({ + id: 'assigner-node', + payload: createData(), + runInputData: { result: 'hello' }, + runInputDataRef: { current: {} }, + getInputVars: () => [], + setRunInputData, + toVarInputs: () => [], + varSelectorsToVarInputs, + })) + + expect(varSelectorsToVarInputs).toHaveBeenCalledWith([ + ['node-2', 'result'], + ['node-3', 'items'], + ]) + expect(result.current.forms).toHaveLength(1) + expect(result.current.forms[0].inputs).toEqual(varInputs) + expect(result.current.forms[0].values).toEqual({ result: 'hello' }) + + result.current.forms[0].onChange({ result: 'updated' }) + + expect(setRunInputData).toHaveBeenCalledWith({ result: 'updated' }) + expect(result.current.getDependentVars()).toEqual([ + ['node-2', 'result'], + ['node-3', 'items'], + ]) + }) +}) diff --git a/web/app/components/workflow/nodes/assigner/components/__tests__/operation-selector.spec.tsx b/web/app/components/workflow/nodes/assigner/components/__tests__/operation-selector.spec.tsx new file mode 100644 index 0000000000..63813c8a46 --- /dev/null +++ b/web/app/components/workflow/nodes/assigner/components/__tests__/operation-selector.spec.tsx @@ -0,0 +1,52 @@ +import { render, screen } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import { VarType } from '@/app/components/workflow/types' +import { WriteMode } from '../../types' +import OperationSelector from '../operation-selector' + +describe('assigner/operation-selector', () => { + it('shows numeric write modes and emits the selected operation', async () => { + const user = userEvent.setup() + const onSelect = vi.fn() + + render( + , + ) + + await user.click(screen.getByText('workflow.nodes.assigner.operations.over-write')) + + expect(screen.getByText('workflow.nodes.assigner.operations.title')).toBeInTheDocument() + expect(screen.getByText('workflow.nodes.assigner.operations.clear')).toBeInTheDocument() + expect(screen.getByText('workflow.nodes.assigner.operations.set')).toBeInTheDocument() + expect(screen.getByText('workflow.nodes.assigner.operations.+=')).toBeInTheDocument() + + await user.click(screen.getAllByText('workflow.nodes.assigner.operations.+=').at(-1)!) + + expect(onSelect).toHaveBeenCalledWith({ value: WriteMode.increment, name: WriteMode.increment }) + }) + + it('does not open when the selector is disabled', async () => { + const user = userEvent.setup() + + render( + , + ) + + await user.click(screen.getByText('workflow.nodes.assigner.operations.over-write')) + + expect(screen.queryByText('workflow.nodes.assigner.operations.title')).not.toBeInTheDocument() + }) +}) diff --git a/web/app/components/workflow/nodes/assigner/components/var-list/__tests__/branches.spec.tsx b/web/app/components/workflow/nodes/assigner/components/var-list/__tests__/branches.spec.tsx new file mode 100644 index 0000000000..a9b5a304f4 --- /dev/null +++ b/web/app/components/workflow/nodes/assigner/components/var-list/__tests__/branches.spec.tsx @@ -0,0 +1,213 @@ +import type { ComponentProps } from 'react' +import { fireEvent, render, screen } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import { VarType } from '@/app/components/workflow/types' +import { AssignerNodeInputType, WriteMode } from '../../../types' +import VarList from '../index' + +vi.mock('@/app/components/workflow/nodes/_base/components/variable/var-reference-picker', () => ({ + __esModule: true, + default: ({ + popupFor = 'assigned', + onOpen, + onChange, + }: { + popupFor?: string + onOpen?: () => void + onChange: (value: string[]) => void + }) => ( +
+ + +
+ ), +})) + +vi.mock('../../operation-selector', () => ({ + __esModule: true, + default: ({ + onSelect, + }: { + onSelect: (item: { value: string }) => void + }) => ( +
+ + +
+ ), +})) + +const createOperation = ( + overrides: Partial['list'][number]> = {}, +): ComponentProps['list'][number] => ({ + variable_selector: ['node-a', 'flag'], + input_type: AssignerNodeInputType.variable, + operation: WriteMode.overwrite, + value: ['node-a', 'answer'], + ...overrides, +}) + +const renderVarList = (props: Partial> = {}) => { + const handleChange = vi.fn() + const handleOpen = vi.fn() + + const result = render( + VarType.string} + getToAssignedVarType={() => VarType.string} + writeModeTypes={[WriteMode.overwrite, WriteMode.clear, WriteMode.set]} + writeModeTypesArr={[WriteMode.overwrite, WriteMode.clear]} + writeModeTypesNum={[WriteMode.increment]} + {...props} + />, + ) + + return { + ...result, + handleChange, + handleOpen, + } +} + +describe('assigner/var-list branches', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('resets operation metadata when the assigned variable changes', async () => { + const user = userEvent.setup() + const { handleChange, handleOpen } = renderVarList({ + list: [createOperation({ + operation: WriteMode.set, + input_type: AssignerNodeInputType.constant, + value: 'stale', + })], + }) + + await user.click(screen.getByTestId('assigned-picker-trigger')) + await user.click(screen.getByRole('button', { name: 'select-assigned' })) + + expect(handleOpen).toHaveBeenCalledWith(0) + expect(handleChange).toHaveBeenLastCalledWith([ + createOperation({ + variable_selector: ['node-b', 'total'], + operation: WriteMode.overwrite, + input_type: AssignerNodeInputType.variable, + value: undefined, + }), + ], ['node-b', 'total']) + }) + + it('switches back to variable mode when the selected operation no longer requires a constant', async () => { + const user = userEvent.setup() + const { handleChange } = renderVarList({ + list: [createOperation({ + operation: WriteMode.set, + input_type: AssignerNodeInputType.constant, + value: 'hello', + })], + }) + + await user.click(screen.getByRole('button', { name: 'operation-overwrite' })) + + expect(handleChange).toHaveBeenLastCalledWith([ + createOperation({ + operation: WriteMode.overwrite, + input_type: AssignerNodeInputType.variable, + value: '', + }), + ]) + }) + + it('updates string and number constant inputs through the inline editors', () => { + const { handleChange, rerender } = renderVarList({ + list: [createOperation({ + operation: WriteMode.set, + input_type: AssignerNodeInputType.constant, + value: 1, + })], + getAssignedVarType: () => VarType.number, + getToAssignedVarType: () => VarType.number, + }) + + fireEvent.change(screen.getByRole('spinbutton'), { + target: { value: '2' }, + }) + + expect(handleChange).toHaveBeenLastCalledWith([ + createOperation({ + operation: WriteMode.set, + input_type: AssignerNodeInputType.constant, + value: 2, + }), + ], 2) + + rerender( + VarType.string} + getToAssignedVarType={() => VarType.string} + writeModeTypes={[WriteMode.overwrite, WriteMode.clear, WriteMode.set]} + writeModeTypesArr={[WriteMode.overwrite, WriteMode.clear]} + writeModeTypesNum={[WriteMode.increment]} + />, + ) + + fireEvent.change(screen.getByRole('textbox'), { + target: { value: 'updated' }, + }) + + expect(handleChange).toHaveBeenLastCalledWith([ + createOperation({ + operation: WriteMode.set, + input_type: AssignerNodeInputType.constant, + value: 'updated', + }), + ], 'updated') + }) + + it('updates numeric write-mode inputs through the dedicated number field', () => { + const { handleChange } = renderVarList({ + list: [createOperation({ + operation: WriteMode.increment, + value: 2, + })], + getAssignedVarType: () => VarType.number, + getToAssignedVarType: () => VarType.number, + writeModeTypesNum: [WriteMode.increment], + }) + + fireEvent.change(screen.getByRole('spinbutton'), { + target: { value: '5' }, + }) + + expect(handleChange).toHaveBeenLastCalledWith([ + createOperation({ + operation: WriteMode.increment, + value: 5, + }), + ], 5) + }) +}) diff --git a/web/app/components/workflow/nodes/assigner/components/var-list/__tests__/index.spec.tsx b/web/app/components/workflow/nodes/assigner/components/var-list/__tests__/index.spec.tsx new file mode 100644 index 0000000000..f7408ab814 --- /dev/null +++ b/web/app/components/workflow/nodes/assigner/components/var-list/__tests__/index.spec.tsx @@ -0,0 +1,146 @@ +import type { ComponentProps } from 'react' +import { fireEvent, screen } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import { createNode, resetFixtureCounters } from '@/app/components/workflow/__tests__/fixtures' +import { renderWorkflowFlowComponent } from '@/app/components/workflow/__tests__/workflow-test-env' +import { BlockEnum, VarType } from '@/app/components/workflow/types' +import { AssignerNodeInputType, WriteMode } from '../../../types' +import VarList from '../index' + +const sourceNode = createNode({ + id: 'node-a', + data: { + type: BlockEnum.Answer, + title: 'Answer Node', + outputs: { + answer: { type: VarType.string }, + flag: { type: VarType.boolean }, + }, + }, +}) + +const currentNode = createNode({ + id: 'node-current', + data: { + type: BlockEnum.VariableAssigner, + title: 'Assigner Node', + }, +}) + +const createOperation = (overrides: Partial['list'][number]> = {}) => ({ + variable_selector: ['node-a', 'flag'], + input_type: AssignerNodeInputType.variable, + operation: WriteMode.overwrite, + value: ['node-a', 'answer'], + ...overrides, +}) + +const renderVarList = (props: Partial> = {}) => { + const handleChange = vi.fn() + const handleOpen = vi.fn() + + const result = renderWorkflowFlowComponent( + VarType.string} + getToAssignedVarType={() => VarType.string} + writeModeTypes={[WriteMode.overwrite, WriteMode.clear, WriteMode.set]} + writeModeTypesArr={[WriteMode.overwrite, WriteMode.clear]} + writeModeTypesNum={[WriteMode.increment]} + {...props} + />, + { + nodes: [sourceNode, currentNode], + edges: [], + hooksStoreProps: {}, + }, + ) + + return { + ...result, + handleChange, + handleOpen, + } +} + +describe('assigner/var-list', () => { + beforeEach(() => { + resetFixtureCounters() + }) + + it('renders the empty placeholder when no operations are configured', () => { + renderVarList() + + expect(screen.getByText('workflow.nodes.assigner.noVarTip')).toBeInTheDocument() + }) + + it('switches a boolean assignment to constant mode and updates the selected value', async () => { + const user = userEvent.setup() + const list = [createOperation()] + const { handleChange, rerender } = renderVarList({ + list, + getAssignedVarType: () => VarType.boolean, + getToAssignedVarType: () => VarType.boolean, + }) + + await user.click(screen.getByText('workflow.nodes.assigner.operations.over-write')) + await user.click(screen.getAllByText('workflow.nodes.assigner.operations.set').at(-1)!) + + expect(handleChange.mock.lastCall?.[0]).toEqual([ + createOperation({ + operation: WriteMode.set, + input_type: AssignerNodeInputType.constant, + value: false, + }), + ]) + + rerender( + VarType.boolean} + getToAssignedVarType={() => VarType.boolean} + writeModeTypes={[WriteMode.overwrite, WriteMode.clear, WriteMode.set]} + writeModeTypesArr={[WriteMode.overwrite, WriteMode.clear]} + writeModeTypesNum={[WriteMode.increment]} + />, + ) + + await user.click(screen.getByText('True')) + + expect(handleChange.mock.lastCall?.[0]).toEqual([ + createOperation({ + operation: WriteMode.set, + input_type: AssignerNodeInputType.constant, + value: true, + }), + ]) + }) + + it('opens the assigned-variable picker and removes an operation', () => { + const { handleChange, handleOpen } = renderVarList({ + list: [createOperation()], + }) + + fireEvent.click(screen.getAllByTestId('var-reference-picker-trigger')[0]!) + expect(handleOpen).toHaveBeenCalledWith(0) + + const buttons = screen.getAllByRole('button') + fireEvent.click(buttons[buttons.length - 1]!) + + expect(handleChange).toHaveBeenLastCalledWith([]) + }) +}) diff --git a/web/app/components/workflow/nodes/code/__tests__/node.spec.tsx b/web/app/components/workflow/nodes/code/__tests__/node.spec.tsx new file mode 100644 index 0000000000..a8648324ed --- /dev/null +++ b/web/app/components/workflow/nodes/code/__tests__/node.spec.tsx @@ -0,0 +1,29 @@ +import type { CodeNodeType } from '../types' +import { render } from '@testing-library/react' +import { BlockEnum } from '@/app/components/workflow/types' +import Node from '../node' +import { CodeLanguage } from '../types' + +const createData = (overrides: Partial = {}): CodeNodeType => ({ + title: 'Code', + desc: '', + type: BlockEnum.Code, + variables: [], + code_language: CodeLanguage.javascript, + code: 'function main() { return {} }', + outputs: {}, + ...overrides, +}) + +describe('code/node', () => { + it('renders an empty summary container', () => { + const { container } = render( + , + ) + + expect(container.firstChild).toBeEmptyDOMElement() + }) +}) diff --git a/web/app/components/workflow/nodes/code/__tests__/panel.spec.tsx b/web/app/components/workflow/nodes/code/__tests__/panel.spec.tsx new file mode 100644 index 0000000000..72d640651d --- /dev/null +++ b/web/app/components/workflow/nodes/code/__tests__/panel.spec.tsx @@ -0,0 +1,295 @@ +import type { ReactNode } from 'react' +import type { CodeNodeType, OutputVar } from '../types' +import type useConfig from '../use-config' +import type { NodePanelProps, Variable } from '@/app/components/workflow/types' +import { render, screen } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import { BlockEnum, VarType } from '@/app/components/workflow/types' +import Panel from '../panel' +import { CodeLanguage } from '../types' + +const mockUseConfig = vi.hoisted(() => vi.fn()) +const mockExtractFunctionParams = vi.hoisted(() => vi.fn()) +const mockExtractReturnType = vi.hoisted(() => vi.fn()) +const mockCodeEditor = vi.hoisted(() => vi.fn()) +const mockVarList = vi.hoisted(() => vi.fn()) +const mockOutputVarList = vi.hoisted(() => vi.fn()) +const mockRemoveEffectVarConfirm = vi.hoisted(() => vi.fn()) + +vi.mock('../use-config', () => ({ + __esModule: true, + default: (...args: unknown[]) => mockUseConfig(...args), +})) + +vi.mock('../code-parser', () => ({ + extractFunctionParams: (...args: unknown[]) => mockExtractFunctionParams(...args), + extractReturnType: (...args: unknown[]) => mockExtractReturnType(...args), +})) + +vi.mock('@/app/components/workflow/nodes/_base/components/editor/code-editor', () => ({ + __esModule: true, + default: (props: { + readOnly: boolean + language: CodeLanguage + value: string + onChange: (value: string) => void + onGenerated: (value: string) => void + title: ReactNode + }) => { + mockCodeEditor(props) + return ( +
+
{props.readOnly ? 'editor:readonly' : 'editor:editable'}
+
{props.language}
+
{props.title}
+ + +
+ ) + }, +})) + +vi.mock('@/app/components/workflow/nodes/_base/components/selector', () => ({ + __esModule: true, + default: (props: { + value: CodeLanguage + onChange: (value: CodeLanguage) => void + }) => ( + + ), +})) + +vi.mock('@/app/components/workflow/nodes/_base/components/variable/var-list', () => ({ + __esModule: true, + default: (props: { + readonly: boolean + list: Variable[] + onChange: (list: Variable[]) => void + }) => { + mockVarList(props) + return ( +
+
{props.readonly ? 'var-list:readonly' : 'var-list:editable'}
+ +
+ ) + }, +})) + +vi.mock('@/app/components/workflow/nodes/_base/components/variable/output-var-list', () => ({ + __esModule: true, + default: (props: { + readonly: boolean + outputs: OutputVar + onChange: (outputs: OutputVar) => void + onRemove: (name: string) => void + }) => { + mockOutputVarList(props) + return ( +
+
{props.readonly ? 'output-list:readonly' : 'output-list:editable'}
+ + +
+ ) + }, +})) + +vi.mock('../../_base/components/remove-effect-var-confirm', () => ({ + __esModule: true, + default: (props: { + isShow: boolean + onCancel: () => void + onConfirm: () => void + }) => { + mockRemoveEffectVarConfirm(props) + return props.isShow + ? ( +
+ + +
+ ) + : null + }, +})) + +const createData = (overrides: Partial = {}): CodeNodeType => ({ + title: 'Code', + desc: '', + type: BlockEnum.Code, + code_language: CodeLanguage.javascript, + code: 'function main({ foo }) { return { result: foo } }', + variables: [{ + variable: 'foo', + value_selector: ['start', 'foo'], + value_type: VarType.string, + }], + outputs: { + result: { + type: VarType.string, + children: null, + }, + }, + ...overrides, +}) + +const createConfigResult = (overrides: Partial> = {}): ReturnType => ({ + readOnly: false, + inputs: createData(), + outputKeyOrders: ['result'], + handleCodeAndVarsChange: vi.fn(), + handleVarListChange: vi.fn(), + handleAddVariable: vi.fn(), + handleRemoveVariable: vi.fn(), + handleSyncFunctionSignature: vi.fn(), + handleCodeChange: vi.fn(), + handleCodeLanguageChange: vi.fn(), + handleVarsChange: vi.fn(), + handleAddOutputVariable: vi.fn(), + filterVar: vi.fn(() => true), + isShowRemoveVarConfirm: true, + hideRemoveVarConfirm: vi.fn(), + onRemoveVarConfirm: vi.fn(), + ...overrides, +}) + +const renderPanel = (data: CodeNodeType = createData()) => { + const props: NodePanelProps = { + id: 'code-node', + data, + panelProps: { + getInputVars: vi.fn(() => []), + toVarInputs: vi.fn(() => []), + runInputData: {}, + runInputDataRef: { current: {} }, + setRunInputData: vi.fn(), + runResult: null, + }, + } + + return render() +} + +describe('code/panel', () => { + beforeEach(() => { + vi.clearAllMocks() + mockExtractFunctionParams.mockReturnValue(['summary', 'count']) + mockExtractReturnType.mockReturnValue({ + result: { + type: VarType.string, + children: null, + }, + }) + mockUseConfig.mockReturnValue(createConfigResult()) + }) + + it('renders editable controls and forwards all input, output, and code actions', async () => { + const user = userEvent.setup() + const config = createConfigResult() + mockUseConfig.mockReturnValue(config) + + renderPanel() + + expect(screen.getByText('workflow.nodes.code.inputVars')).toBeInTheDocument() + expect(screen.getByText('workflow.nodes.code.outputVars')).toBeInTheDocument() + expect(screen.getByText('editor:editable')).toBeInTheDocument() + expect(screen.getByText('language:javascript')).toBeInTheDocument() + + const addButtons = screen.getAllByTestId('add-button') + await user.click(addButtons[0]!) + await user.click(screen.getByTestId('sync-button')) + await user.click(screen.getByRole('button', { name: 'change-code' })) + await user.click(screen.getByRole('button', { name: 'generate-code' })) + await user.click(screen.getByRole('button', { name: 'language:javascript' })) + await user.click(screen.getByRole('button', { name: 'change-var-list' })) + await user.click(screen.getByRole('button', { name: 'change-output-list' })) + await user.click(screen.getByRole('button', { name: 'remove-output' })) + await user.click(addButtons[1]!) + await user.click(screen.getByRole('button', { name: 'cancel-remove' })) + await user.click(screen.getByRole('button', { name: 'confirm-remove' })) + + expect(config.handleAddVariable).toHaveBeenCalled() + expect(config.handleSyncFunctionSignature).toHaveBeenCalled() + expect(config.handleCodeChange).toHaveBeenCalledWith('generated code body') + expect(config.handleCodeLanguageChange).toHaveBeenCalledWith(CodeLanguage.python3) + expect(config.handleVarListChange).toHaveBeenCalledWith([{ + variable: 'changed', + value_selector: ['start', 'changed'], + }]) + expect(config.handleVarsChange).toHaveBeenCalledWith({ + next_result: { + type: VarType.number, + children: null, + }, + }) + expect(config.handleRemoveVariable).toHaveBeenCalledWith('result') + expect(config.handleAddOutputVariable).toHaveBeenCalled() + expect(config.hideRemoveVarConfirm).toHaveBeenCalled() + expect(config.onRemoveVarConfirm).toHaveBeenCalled() + expect(config.handleCodeAndVarsChange).toHaveBeenCalledWith( + 'generated signature code', + [{ + variable: 'summary', + value_selector: [], + }, { + variable: 'count', + value_selector: [], + }], + { + result: { + type: VarType.string, + children: null, + }, + }, + ) + expect(mockExtractFunctionParams).toHaveBeenCalledWith('generated signature code', CodeLanguage.javascript) + expect(mockExtractReturnType).toHaveBeenCalledWith('generated signature code', CodeLanguage.javascript) + }) + + it('removes input actions in readonly mode and passes readonly state to child sections', () => { + mockUseConfig.mockReturnValue(createConfigResult({ + readOnly: true, + isShowRemoveVarConfirm: false, + })) + + renderPanel() + + expect(screen.queryByTestId('sync-button')).not.toBeInTheDocument() + expect(screen.getAllByTestId('add-button')).toHaveLength(1) + expect(screen.getByText('editor:readonly')).toBeInTheDocument() + expect(screen.getByText('var-list:readonly')).toBeInTheDocument() + expect(screen.getByText('output-list:readonly')).toBeInTheDocument() + expect(mockRemoveEffectVarConfirm).toHaveBeenCalled() + }) +}) diff --git a/web/app/components/workflow/nodes/code/__tests__/use-config.spec.ts b/web/app/components/workflow/nodes/code/__tests__/use-config.spec.ts new file mode 100644 index 0000000000..b02ff8a4fc --- /dev/null +++ b/web/app/components/workflow/nodes/code/__tests__/use-config.spec.ts @@ -0,0 +1,315 @@ +import type { CodeNodeType, OutputVar } from '../types' +import type { Var, Variable } from '@/app/components/workflow/types' +import { act, renderHook, waitFor } from '@testing-library/react' +import { useNodesReadOnly } from '@/app/components/workflow/hooks' +import useNodeCrud from '@/app/components/workflow/nodes/_base/hooks/use-node-crud' +import { useStore } from '@/app/components/workflow/store' +import { BlockEnum, VarType } from '@/app/components/workflow/types' +import { fetchNodeDefault, fetchPipelineNodeDefault } from '@/service/workflow' +import useOutputVarList from '../../_base/hooks/use-output-var-list' +import useVarList from '../../_base/hooks/use-var-list' +import { CodeLanguage } from '../types' +import useConfig from '../use-config' + +vi.mock('@/app/components/workflow/hooks', () => ({ + useNodesReadOnly: vi.fn(), +})) + +vi.mock('@/app/components/workflow/nodes/_base/hooks/use-node-crud', () => ({ + __esModule: true, + default: vi.fn(), +})) + +vi.mock('@/app/components/workflow/nodes/_base/hooks/use-var-list', () => ({ + __esModule: true, + default: vi.fn(), +})) + +vi.mock('@/app/components/workflow/nodes/_base/hooks/use-output-var-list', () => ({ + __esModule: true, + default: vi.fn(), +})) + +vi.mock('@/app/components/workflow/store', () => ({ + useStore: vi.fn(), +})) + +vi.mock('@/service/workflow', () => ({ + fetchNodeDefault: vi.fn(), + fetchPipelineNodeDefault: vi.fn(), +})) + +const mockUseNodesReadOnly = vi.mocked(useNodesReadOnly) +const mockUseNodeCrud = vi.mocked(useNodeCrud) +const mockUseVarList = vi.mocked(useVarList) +const mockUseOutputVarList = vi.mocked(useOutputVarList) +const mockUseStore = vi.mocked(useStore) +const mockFetchNodeDefault = vi.mocked(fetchNodeDefault) +const mockFetchPipelineNodeDefault = vi.mocked(fetchPipelineNodeDefault) + +const createVariable = (variable: string, valueType: VarType = VarType.string): Variable => ({ + variable, + value_selector: ['start', variable], + value_type: valueType, +}) + +const createOutputs = (name = 'result', type: VarType = VarType.string): OutputVar => ({ + [name]: { + type, + children: null, + }, +}) + +const createData = (overrides: Partial = {}): CodeNodeType => ({ + title: 'Code', + desc: '', + type: BlockEnum.Code, + code_language: CodeLanguage.javascript, + code: 'function main({ foo }) { return { result: foo } }', + variables: [createVariable('foo')], + outputs: createOutputs(), + ...overrides, +}) + +describe('code/use-config', () => { + const mockSetInputs = vi.fn() + const mockHandleVarListChange = vi.fn() + const mockHandleAddVariable = vi.fn() + const mockHandleVarsChange = vi.fn() + const mockHandleAddOutputVariable = vi.fn() + const mockHandleRemoveVariable = vi.fn() + const mockHideRemoveVarConfirm = vi.fn() + const mockOnRemoveVarConfirm = vi.fn() + + let workflowStoreState: { + appId?: string + pipelineId?: string + nodesDefaultConfigs?: Record + } + let currentInputs: CodeNodeType + let javaScriptConfig: CodeNodeType + let pythonConfig: CodeNodeType + + beforeEach(() => { + vi.clearAllMocks() + + javaScriptConfig = createData({ + code_language: CodeLanguage.javascript, + code: 'function main({ query }) { return { result: query } }', + variables: [createVariable('query')], + outputs: createOutputs('result'), + }) + pythonConfig = createData({ + code_language: CodeLanguage.python3, + code: 'def main(name: str):\n return {"result": name}', + variables: [createVariable('name')], + outputs: createOutputs('result'), + }) + currentInputs = createData() + workflowStoreState = { + appId: undefined, + pipelineId: undefined, + nodesDefaultConfigs: { + [BlockEnum.Code]: createData({ + code_language: CodeLanguage.javascript, + code: 'function main() { return { default_result: "" } }', + variables: [], + outputs: createOutputs('default_result'), + }), + }, + } + + mockUseNodesReadOnly.mockReturnValue({ + nodesReadOnly: false, + getNodesReadOnly: () => false, + }) + mockUseNodeCrud.mockImplementation(() => ({ + inputs: currentInputs, + setInputs: mockSetInputs, + })) + mockUseVarList.mockReturnValue({ + handleVarListChange: mockHandleVarListChange, + handleAddVariable: mockHandleAddVariable, + } as ReturnType) + mockUseOutputVarList.mockReturnValue({ + handleVarsChange: mockHandleVarsChange, + handleAddVariable: mockHandleAddOutputVariable, + handleRemoveVariable: mockHandleRemoveVariable, + isShowRemoveVarConfirm: false, + hideRemoveVarConfirm: mockHideRemoveVarConfirm, + onRemoveVarConfirm: mockOnRemoveVarConfirm, + } as ReturnType) + mockUseStore.mockImplementation(selector => selector(workflowStoreState as never)) + mockFetchNodeDefault.mockResolvedValue({ config: javaScriptConfig } as never) + mockFetchPipelineNodeDefault.mockResolvedValue({ config: javaScriptConfig } as never) + mockFetchNodeDefault + .mockResolvedValueOnce({ config: javaScriptConfig } as never) + .mockResolvedValueOnce({ config: pythonConfig } as never) + mockFetchPipelineNodeDefault + .mockResolvedValueOnce({ config: javaScriptConfig } as never) + .mockResolvedValueOnce({ config: pythonConfig } as never) + }) + + it('hydrates node defaults when the code payload is empty and syncs output key order', async () => { + currentInputs = createData({ + code: '', + variables: [], + outputs: {}, + }) + + const { result } = renderHook(() => useConfig('code-node', currentInputs)) + + await waitFor(() => { + expect(mockSetInputs).toHaveBeenCalledWith(expect.objectContaining({ + code: workflowStoreState.nodesDefaultConfigs?.[BlockEnum.Code]?.code, + outputs: workflowStoreState.nodesDefaultConfigs?.[BlockEnum.Code]?.outputs, + })) + }) + + expect(result.current.handleVarListChange).toBe(mockHandleVarListChange) + expect(result.current.handleAddVariable).toBe(mockHandleAddVariable) + expect(result.current.handleVarsChange).toBe(mockHandleVarsChange) + expect(result.current.handleAddOutputVariable).toBe(mockHandleAddOutputVariable) + expect(result.current.handleRemoveVariable).toBe(mockHandleRemoveVariable) + expect(result.current.hideRemoveVarConfirm).toBe(mockHideRemoveVarConfirm) + expect(result.current.onRemoveVarConfirm).toBe(mockOnRemoveVarConfirm) + expect(result.current.outputKeyOrders).toEqual(['default_result']) + expect(result.current.filterVar({ type: VarType.file } as Var)).toBe(true) + expect(result.current.filterVar({ type: VarType.secret } as Var)).toBe(true) + }) + + it('fetches app and pipeline defaults, switches language, and updates code and output vars together', async () => { + workflowStoreState.appId = 'app-1' + workflowStoreState.pipelineId = 'pipeline-1' + + const { result } = renderHook(() => useConfig('code-node', currentInputs)) + + await waitFor(() => { + expect(mockFetchNodeDefault).toHaveBeenCalledWith('app-1', BlockEnum.Code, { code_language: CodeLanguage.javascript }) + expect(mockFetchNodeDefault).toHaveBeenCalledWith('app-1', BlockEnum.Code, { code_language: CodeLanguage.python3 }) + expect(mockFetchPipelineNodeDefault).toHaveBeenCalledWith('pipeline-1', BlockEnum.Code, { code_language: CodeLanguage.javascript }) + expect(mockFetchPipelineNodeDefault).toHaveBeenCalledWith('pipeline-1', BlockEnum.Code, { code_language: CodeLanguage.python3 }) + }) + + mockSetInputs.mockClear() + + act(() => { + result.current.handleCodeLanguageChange(CodeLanguage.python3) + result.current.handleCodeChange('function main({ bar }) { return { result: bar } }') + result.current.handleCodeAndVarsChange( + 'function main({ amount }) { return { total: amount } }', + [createVariable('amount', VarType.number)], + createOutputs('total', VarType.number), + ) + }) + + expect(mockSetInputs).toHaveBeenCalledWith(expect.objectContaining({ + code_language: CodeLanguage.python3, + code: pythonConfig.code, + variables: pythonConfig.variables, + outputs: pythonConfig.outputs, + })) + expect(mockSetInputs).toHaveBeenCalledWith(expect.objectContaining({ + code: 'function main({ bar }) { return { result: bar } }', + })) + expect(mockSetInputs).toHaveBeenCalledWith(expect.objectContaining({ + code: 'function main({ amount }) { return { total: amount } }', + variables: [expect.objectContaining({ variable: 'amount' })], + outputs: createOutputs('total', VarType.number), + })) + expect(result.current.outputKeyOrders).toEqual(['total']) + }) + + it('syncs javascript and python function signatures and keeps json code unchanged', () => { + currentInputs = createData({ + code_language: CodeLanguage.javascript, + code: 'function main() { return { result: "" } }', + variables: [createVariable('foo'), createVariable('bar')], + }) + + const { result, rerender } = renderHook(() => useConfig('code-node', currentInputs)) + + act(() => { + result.current.handleSyncFunctionSignature() + }) + + expect(mockSetInputs).toHaveBeenCalledWith(expect.objectContaining({ + code: 'function main({foo, bar}) { return { result: "" } }', + })) + + mockSetInputs.mockClear() + currentInputs = createData({ + code_language: CodeLanguage.python3, + code: 'def main():\n return {"result": ""}', + variables: [ + createVariable('text', VarType.string), + createVariable('score', VarType.number), + createVariable('payload', VarType.object), + createVariable('items', VarType.array), + createVariable('numbers', VarType.arrayNumber), + createVariable('names', VarType.arrayString), + createVariable('records', VarType.arrayObject), + ], + }) + rerender() + + act(() => { + result.current.handleSyncFunctionSignature() + }) + + expect(mockSetInputs).toHaveBeenCalledWith(expect.objectContaining({ + code: 'def main(text: str, score: float, payload: dict, items: list, numbers: list[float], names: list[str], records: list[dict]):\n return {"result": ""}', + })) + + mockSetInputs.mockClear() + currentInputs = createData({ + code_language: CodeLanguage.json, + code: '{"result": true}', + }) + rerender() + + act(() => { + result.current.handleSyncFunctionSignature() + }) + + expect(mockSetInputs).toHaveBeenCalledWith(expect.objectContaining({ + code: '{"result": true}', + })) + }) + + it('keeps language changes local when no fetched default exists and preserves existing output order', async () => { + currentInputs = createData({ + outputs: { + summary: { + type: VarType.string, + children: null, + }, + count: { + type: VarType.number, + children: null, + }, + }, + }) + workflowStoreState.appId = undefined + workflowStoreState.pipelineId = undefined + + const { result } = renderHook(() => useConfig('code-node', currentInputs)) + + await waitFor(() => { + expect(result.current.outputKeyOrders).toEqual(['summary', 'count']) + }) + + mockSetInputs.mockClear() + + act(() => { + result.current.handleCodeLanguageChange(CodeLanguage.python3) + }) + + expect(mockSetInputs).toHaveBeenCalledWith(expect.objectContaining({ + code_language: CodeLanguage.python3, + code: currentInputs.code, + variables: currentInputs.variables, + outputs: currentInputs.outputs, + })) + }) +}) diff --git a/web/app/components/workflow/nodes/code/__tests__/use-single-run-form-params.spec.ts b/web/app/components/workflow/nodes/code/__tests__/use-single-run-form-params.spec.ts new file mode 100644 index 0000000000..39e9d8139a --- /dev/null +++ b/web/app/components/workflow/nodes/code/__tests__/use-single-run-form-params.spec.ts @@ -0,0 +1,80 @@ +import type { CodeNodeType } from '../types' +import { renderHook } from '@testing-library/react' +import { BlockEnum, InputVarType, VarType } from '@/app/components/workflow/types' +import useNodeCrud from '../../_base/hooks/use-node-crud' +import { CodeLanguage } from '../types' +import useSingleRunFormParams from '../use-single-run-form-params' + +vi.mock('../../_base/hooks/use-node-crud', () => ({ + __esModule: true, + default: vi.fn(), +})) + +const mockUseNodeCrud = vi.mocked(useNodeCrud) + +const createData = (overrides: Partial = {}): CodeNodeType => ({ + title: 'Code', + desc: '', + type: BlockEnum.Code, + code_language: CodeLanguage.javascript, + code: 'function main({ amount }) { return { result: amount } }', + variables: [{ + variable: 'amount', + value_selector: ['start', 'amount'], + value_type: VarType.number, + }], + outputs: { + result: { + type: VarType.number, + children: null, + }, + }, + ...overrides, +}) + +describe('code/use-single-run-form-params', () => { + beforeEach(() => { + vi.clearAllMocks() + mockUseNodeCrud.mockReturnValue({ + inputs: createData(), + setInputs: vi.fn(), + } as unknown as ReturnType) + }) + + it('builds a single form, updates run input values, and exposes dependent vars', () => { + const setRunInputData = vi.fn() + + const { result } = renderHook(() => useSingleRunFormParams({ + id: 'code-node', + payload: createData(), + runInputData: { amount: 1 }, + runInputDataRef: { current: { amount: 1 } }, + getInputVars: () => [], + setRunInputData, + toVarInputs: variables => variables.map(variable => ({ + type: InputVarType.number, + label: variable.variable, + variable: variable.variable, + required: false, + })), + })) + + expect(result.current.forms).toEqual([{ + inputs: [{ + type: InputVarType.number, + label: 'amount', + variable: 'amount', + required: false, + }], + values: { amount: 1 }, + onChange: expect.any(Function), + }]) + + result.current.forms[0]?.onChange({ amount: 3 }) + + expect(setRunInputData).toHaveBeenCalledWith({ amount: 3 }) + expect(result.current.getDependentVars()).toEqual([['start', 'amount']]) + expect(result.current.getDependentVar('amount')).toEqual(['start', 'amount']) + expect(result.current.getDependentVar('missing')).toBeUndefined() + }) +}) diff --git a/web/app/components/workflow/nodes/data-source/__tests__/before-run-form.spec.tsx b/web/app/components/workflow/nodes/data-source/__tests__/before-run-form.spec.tsx new file mode 100644 index 0000000000..c12ec212bf --- /dev/null +++ b/web/app/components/workflow/nodes/data-source/__tests__/before-run-form.spec.tsx @@ -0,0 +1,205 @@ +import type { ReactNode } from 'react' +import type { CustomRunFormProps, DataSourceNodeType } from '../types' +import { render, screen } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import { DatasourceType } from '@/models/pipeline' +import { FlowType } from '@/types/common' +import { BlockEnum } from '../../../types' +import BeforeRunForm from '../before-run-form' +import useBeforeRunForm from '../hooks/use-before-run-form' + +const mockUseDataSourceStore = vi.hoisted(() => vi.fn()) +const mockSetCurrentCredentialId = vi.hoisted(() => vi.fn()) +const mockClearOnlineDocumentData = vi.hoisted(() => vi.fn()) +const mockClearWebsiteCrawlData = vi.hoisted(() => vi.fn()) +const mockClearOnlineDriveData = vi.hoisted(() => vi.fn()) + +vi.mock('@/app/components/datasets/documents/create-from-pipeline/data-source/store', () => ({ + useDataSourceStore: () => mockUseDataSourceStore(), +})) + +vi.mock('@/app/components/datasets/documents/create-from-pipeline/data-source/store/provider', () => ({ + __esModule: true, + default: ({ children }: { children: ReactNode }) => <>{children}, +})) + +vi.mock('@/app/components/datasets/documents/create-from-pipeline/data-source/local-file', () => ({ + __esModule: true, + default: ({ allowedExtensions }: { allowedExtensions: string[] }) =>
{allowedExtensions.join(',')}
, +})) + +vi.mock('@/app/components/datasets/documents/create-from-pipeline/data-source/online-documents', () => ({ + __esModule: true, + default: ({ onCredentialChange }: { onCredentialChange: (credentialId: string) => void }) => ( + + ), +})) + +vi.mock('@/app/components/datasets/documents/create-from-pipeline/data-source/website-crawl', () => ({ + __esModule: true, + default: ({ onCredentialChange }: { onCredentialChange: (credentialId: string) => void }) => ( + + ), +})) + +vi.mock('@/app/components/datasets/documents/create-from-pipeline/data-source/online-drive', () => ({ + __esModule: true, + default: ({ onCredentialChange }: { onCredentialChange: (credentialId: string) => void }) => ( + + ), +})) + +vi.mock('@/app/components/rag-pipeline/components/panel/test-run/preparation/hooks', () => ({ + useOnlineDocument: () => ({ clearOnlineDocumentData: mockClearOnlineDocumentData }), + useWebsiteCrawl: () => ({ clearWebsiteCrawlData: mockClearWebsiteCrawlData }), + useOnlineDrive: () => ({ clearOnlineDriveData: mockClearOnlineDriveData }), +})) + +vi.mock('@/app/components/workflow/nodes/_base/components/before-run-form/panel-wrap', () => ({ + __esModule: true, + default: ({ nodeName, onHide, children }: { nodeName: string, onHide: () => void, children: ReactNode }) => ( +
+
{nodeName}
+ + {children} +
+ ), +})) + +vi.mock('../hooks/use-before-run-form', () => ({ + __esModule: true, + default: vi.fn(), +})) + +const mockUseBeforeRunForm = vi.mocked(useBeforeRunForm) + +const createData = (overrides: Partial = {}): DataSourceNodeType => ({ + title: 'Datasource', + desc: '', + type: BlockEnum.DataSource, + plugin_id: 'plugin-id', + provider_type: DatasourceType.localFile, + provider_name: 'file', + datasource_name: 'local-file', + datasource_label: 'Local File', + datasource_parameters: {}, + datasource_configurations: {}, + fileExtensions: ['pdf', 'md'], + ...overrides, +}) + +const createProps = (overrides: Partial = {}): CustomRunFormProps => ({ + nodeId: 'data-source-node', + flowId: 'flow-id', + flowType: FlowType.ragPipeline, + payload: createData(), + setRunResult: vi.fn(), + setIsRunAfterSingleRun: vi.fn(), + isPaused: false, + isRunAfterSingleRun: false, + onSuccess: vi.fn(), + onCancel: vi.fn(), + appendNodeInspectVars: vi.fn(), + ...overrides, +}) + +describe('data-source/before-run-form', () => { + beforeEach(() => { + vi.clearAllMocks() + mockUseDataSourceStore.mockReturnValue({ + getState: () => ({ + setCurrentCredentialId: mockSetCurrentCredentialId, + }), + }) + mockUseBeforeRunForm.mockReturnValue({ + isPending: false, + handleRunWithSyncDraft: vi.fn(), + datasourceType: DatasourceType.localFile, + datasourceNodeData: createData(), + startRunBtnDisabled: false, + }) + }) + + it('renders the local-file preparation form and triggers run/cancel actions', async () => { + const user = userEvent.setup() + const onCancel = vi.fn() + const handleRunWithSyncDraft = vi.fn() + + mockUseBeforeRunForm.mockReturnValueOnce({ + isPending: false, + handleRunWithSyncDraft, + datasourceType: DatasourceType.localFile, + datasourceNodeData: createData(), + startRunBtnDisabled: false, + }) + + render() + + expect(screen.getByText('Datasource')).toBeInTheDocument() + expect(screen.getByText('pdf,md')).toBeInTheDocument() + + await user.click(screen.getByRole('button', { name: 'common.operation.cancel' })) + await user.click(screen.getByRole('button', { name: 'workflow.singleRun.startRun' })) + + expect(onCancel).toHaveBeenCalled() + expect(handleRunWithSyncDraft).toHaveBeenCalled() + }) + + it('clears stale online document data before switching credentials', async () => { + const user = userEvent.setup() + + mockUseBeforeRunForm.mockReturnValueOnce({ + isPending: false, + handleRunWithSyncDraft: vi.fn(), + datasourceType: DatasourceType.onlineDocument, + datasourceNodeData: createData({ provider_type: DatasourceType.onlineDocument }), + startRunBtnDisabled: true, + }) + + render() + + await user.click(screen.getByRole('button', { name: 'online-documents' })) + + expect(mockClearOnlineDocumentData).toHaveBeenCalled() + expect(mockSetCurrentCredentialId).toHaveBeenCalledWith('credential-doc') + expect(screen.getByRole('button', { name: 'workflow.singleRun.startRun' })).toBeDisabled() + }) + + it('clears website crawl data before switching credentials', async () => { + const user = userEvent.setup() + + mockUseBeforeRunForm.mockReturnValueOnce({ + isPending: false, + handleRunWithSyncDraft: vi.fn(), + datasourceType: DatasourceType.websiteCrawl, + datasourceNodeData: createData({ provider_type: DatasourceType.websiteCrawl }), + startRunBtnDisabled: false, + }) + + render() + + await user.click(screen.getByRole('button', { name: 'website-crawl' })) + + expect(mockClearWebsiteCrawlData).toHaveBeenCalled() + expect(mockSetCurrentCredentialId).toHaveBeenCalledWith('credential-site') + }) + + it('clears online drive data before switching credentials', async () => { + const user = userEvent.setup() + + mockUseBeforeRunForm.mockReturnValueOnce({ + isPending: false, + handleRunWithSyncDraft: vi.fn(), + datasourceType: DatasourceType.onlineDrive, + datasourceNodeData: createData({ provider_type: DatasourceType.onlineDrive }), + startRunBtnDisabled: false, + }) + + render() + + await user.click(screen.getByRole('button', { name: 'online-drive' })) + + expect(mockClearOnlineDriveData).toHaveBeenCalled() + expect(mockSetCurrentCredentialId).toHaveBeenCalledWith('credential-drive') + }) +}) diff --git a/web/app/components/workflow/nodes/data-source/__tests__/panel.spec.tsx b/web/app/components/workflow/nodes/data-source/__tests__/panel.spec.tsx new file mode 100644 index 0000000000..8160da6502 --- /dev/null +++ b/web/app/components/workflow/nodes/data-source/__tests__/panel.spec.tsx @@ -0,0 +1,194 @@ +import type { ReactNode } from 'react' +import type { DataSourceNodeType } from '../types' +import type { NodePanelProps } from '@/app/components/workflow/types' +import { fireEvent, render, screen } from '@testing-library/react' +import { toolParametersToFormSchemas } from '@/app/components/tools/utils/to-form-schema' +import { useNodesReadOnly } from '@/app/components/workflow/hooks' +import { useStore } from '@/app/components/workflow/store' +import { BlockEnum, VarType } from '@/app/components/workflow/types' +import useMatchSchemaType, { getMatchedSchemaType } from '../../_base/components/variable/use-match-schema-type' +import ToolForm from '../../tool/components/tool-form' +import { useConfig } from '../hooks/use-config' +import Panel from '../panel' + +const mockWrapStructuredVarItem = vi.hoisted(() => vi.fn((payload: unknown) => payload)) + +vi.mock('@/app/components/base/tag-input', () => ({ + __esModule: true, + default: ({ + items, + onChange, + placeholder, + }: { + items: string[] + onChange: (items: string[]) => void + placeholder?: string + }) => ( + + ), +})) + +vi.mock('@/app/components/tools/utils/to-form-schema', () => ({ + toolParametersToFormSchemas: vi.fn(), +})) + +vi.mock('@/app/components/workflow/hooks', () => ({ + useNodesReadOnly: vi.fn(), +})) + +vi.mock('@/app/components/workflow/store', () => ({ + useStore: vi.fn(), +})) + +vi.mock('@/app/components/workflow/utils/tool', () => ({ + wrapStructuredVarItem: (payload: unknown) => mockWrapStructuredVarItem(payload), +})) + +vi.mock('../../_base/components/output-vars', () => ({ + __esModule: true, + default: ({ children }: { children: ReactNode }) =>
{children}
, + VarItem: ({ name, type }: { name: string, type: string }) =>
{`${name}:${type}`}
, +})) + +vi.mock('../../_base/components/variable/object-child-tree-panel/show', () => ({ + __esModule: true, + default: ({ payload }: { payload: { name: string } }) =>
{payload.name}
, +})) + +vi.mock('../../_base/components/variable/use-match-schema-type', () => ({ + __esModule: true, + default: vi.fn(), + getMatchedSchemaType: vi.fn(), +})) + +vi.mock('../../tool/components/tool-form', () => ({ + __esModule: true, + default: vi.fn(({ onChange, onManageInputField }: { onChange: (value: unknown) => void, onManageInputField?: () => void }) => ( +
+ + +
+ )), +})) + +vi.mock('../hooks/use-config', () => ({ + useConfig: vi.fn(), +})) + +const mockUseNodesReadOnly = vi.mocked(useNodesReadOnly) +const mockUseStore = vi.mocked(useStore) +const mockUseConfig = vi.mocked(useConfig) +const mockToolParametersToFormSchemas = vi.mocked(toolParametersToFormSchemas) +const mockUseMatchSchemaType = vi.mocked(useMatchSchemaType) +const mockGetMatchedSchemaType = vi.mocked(getMatchedSchemaType) +const mockToolForm = vi.mocked(ToolForm) + +const setShowInputFieldPanel = vi.fn() + +const createData = (overrides: Partial = {}): DataSourceNodeType => ({ + title: 'Datasource', + desc: '', + type: BlockEnum.DataSource, + plugin_id: 'plugin-1', + provider_type: 'remote', + provider_name: 'provider', + datasource_name: 'source-a', + datasource_label: 'Source A', + datasource_parameters: {}, + datasource_configurations: {}, + fileExtensions: ['pdf'], + ...overrides, +}) + +const panelProps = {} as NodePanelProps['panelProps'] + +describe('data-source/panel', () => { + beforeEach(() => { + vi.clearAllMocks() + mockUseNodesReadOnly.mockReturnValue({ nodesReadOnly: false, getNodesReadOnly: () => false }) + mockUseStore.mockImplementation((selector) => { + const select = selector as (state: unknown) => unknown + return select({ + dataSourceList: [{ + plugin_id: 'plugin-1', + is_authorized: true, + tools: [{ + name: 'source-a', + parameters: [{ name: 'dataset' }], + }], + }], + pipelineId: 'pipeline-1', + setShowInputFieldPanel, + }) + }) + mockUseConfig.mockReturnValue({ + handleFileExtensionsChange: vi.fn(), + handleParametersChange: vi.fn(), + outputSchema: [], + hasObjectOutput: false, + }) + mockToolParametersToFormSchemas.mockReturnValue([{ name: 'dataset' }] as never) + mockUseMatchSchemaType.mockReturnValue({ schemaTypeDefinitions: {} } as ReturnType) + mockGetMatchedSchemaType.mockReturnValue('') + }) + + it('renders the authorized tool form path and forwards parameter changes', () => { + const handleParametersChange = vi.fn() + mockUseConfig.mockReturnValueOnce({ + handleFileExtensionsChange: vi.fn(), + handleParametersChange, + outputSchema: [{ + name: 'metadata', + value: { type: 'object' }, + }], + hasObjectOutput: true, + }) + mockGetMatchedSchemaType.mockReturnValueOnce('json') + + render( + , + ) + + fireEvent.click(screen.getByRole('button', { name: 'tool-form-change' })) + fireEvent.click(screen.getByRole('button', { name: 'manage-input-field' })) + + expect(handleParametersChange).toHaveBeenCalledWith({ dataset: 'docs' }) + expect(setShowInputFieldPanel).toHaveBeenCalledWith(true) + expect(mockToolForm).toHaveBeenCalledWith(expect.objectContaining({ + nodeId: 'data-source-node', + showManageInputField: true, + value: {}, + }), undefined) + expect(screen.getByText('metadata')).toBeInTheDocument() + }) + + it('renders the local-file path and updates supported file extensions', () => { + const handleFileExtensionsChange = vi.fn() + mockUseConfig.mockReturnValueOnce({ + handleFileExtensionsChange, + handleParametersChange: vi.fn(), + outputSchema: [], + hasObjectOutput: false, + }) + + render( + , + ) + + fireEvent.click(screen.getByRole('button', { name: 'workflow.nodes.dataSource.supportedFileFormatsPlaceholder' })) + + expect(handleFileExtensionsChange).toHaveBeenCalledWith(['pdf', 'txt']) + expect(screen.getByText(`datasource_type:${VarType.string}`)).toBeInTheDocument() + expect(screen.getByText(`file:${VarType.file}`)).toBeInTheDocument() + }) +}) diff --git a/web/app/components/workflow/nodes/data-source/hooks/__tests__/use-before-run-form.branches.spec.tsx b/web/app/components/workflow/nodes/data-source/hooks/__tests__/use-before-run-form.branches.spec.tsx new file mode 100644 index 0000000000..09172dd673 --- /dev/null +++ b/web/app/components/workflow/nodes/data-source/hooks/__tests__/use-before-run-form.branches.spec.tsx @@ -0,0 +1,308 @@ +import type { CustomRunFormProps, DataSourceNodeType } from '../../types' +import type { NodeRunResult, VarInInspect } from '@/types/workflow' +import { act, renderHook } from '@testing-library/react' +import { useStoreApi } from 'reactflow' +import { useDataSourceStore, useDataSourceStoreWithSelector } from '@/app/components/datasets/documents/create-from-pipeline/data-source/store' +import { BlockEnum, NodeRunningStatus } from '@/app/components/workflow/types' +import { DatasourceType } from '@/models/pipeline' +import { useDatasourceSingleRun } from '@/service/use-pipeline' +import { useInvalidLastRun } from '@/service/use-workflow' +import { fetchNodeInspectVars } from '@/service/workflow' +import { FlowType } from '@/types/common' +import { useNodeDataUpdate, useNodesSyncDraft } from '../../../../hooks' +import useBeforeRunForm from '../use-before-run-form' + +type DataSourceStoreState = { + currentNodeIdRef: { current: string } + currentCredentialId: string + setCurrentCredentialId: (credentialId: string) => void + currentCredentialIdRef: { current: string } + localFileList: Array<{ + file: { + id: string + name: string + type: string + size: number + extension: string + mime_type: string + } + }> + onlineDocuments: Array> + websitePages: Array> + selectedFileIds: string[] + onlineDriveFileList: Array<{ id: string, type: string }> + bucket?: string +} + +type DatasourceSingleRunOptions = { + onError?: () => void + onSettled?: (data?: NodeRunResult) => void +} + +const mockHandleNodeDataUpdate = vi.hoisted(() => vi.fn()) +const mockHandleSyncWorkflowDraft = vi.hoisted(() => vi.fn()) +const mockMutateAsync = vi.hoisted(() => vi.fn()) +const mockInvalidLastRun = vi.hoisted(() => vi.fn()) +const mockFetchNodeInspectVars = vi.hoisted(() => vi.fn()) +const mockUseDataSourceStore = vi.hoisted(() => vi.fn()) +const mockUseDataSourceStoreWithSelector = vi.hoisted(() => vi.fn()) + +vi.mock('reactflow', async () => { + const actual = await vi.importActual('reactflow') + return { + ...actual, + useStoreApi: vi.fn(), + } +}) + +vi.mock('@/app/components/workflow/hooks', () => ({ + useNodeDataUpdate: vi.fn(), + useNodesSyncDraft: vi.fn(), +})) + +vi.mock('@/service/use-pipeline', () => ({ + useDatasourceSingleRun: vi.fn(), +})) + +vi.mock('@/service/use-workflow', () => ({ + useInvalidLastRun: vi.fn(), +})) + +vi.mock('@/service/workflow', () => ({ + fetchNodeInspectVars: vi.fn(), +})) + +vi.mock('@/app/components/datasets/documents/create-from-pipeline/data-source/store', () => ({ + useDataSourceStore: vi.fn(), + useDataSourceStoreWithSelector: vi.fn(), +})) + +const mockUseStoreApi = vi.mocked(useStoreApi) +const mockUseNodeDataUpdateHook = vi.mocked(useNodeDataUpdate) +const mockUseNodesSyncDraftHook = vi.mocked(useNodesSyncDraft) +const mockUseDatasourceSingleRunHook = vi.mocked(useDatasourceSingleRun) +const mockUseInvalidLastRunHook = vi.mocked(useInvalidLastRun) +const mockFetchNodeInspectVarsFn = vi.mocked(fetchNodeInspectVars) +const mockUseDataSourceStoreHook = vi.mocked(useDataSourceStore) +const mockUseDataSourceStoreWithSelectorHook = vi.mocked(useDataSourceStoreWithSelector) + +const createData = (overrides: Partial = {}): DataSourceNodeType => ({ + title: 'Datasource', + desc: '', + type: BlockEnum.DataSource, + plugin_id: 'plugin-id', + provider_type: DatasourceType.localFile, + provider_name: 'provider', + datasource_name: 'local-file', + datasource_label: 'Local File', + datasource_parameters: {}, + datasource_configurations: {}, + fileExtensions: ['pdf'], + ...overrides, +}) + +const createProps = (overrides: Partial = {}): CustomRunFormProps => ({ + nodeId: 'data-source-node', + flowId: 'flow-id', + flowType: FlowType.ragPipeline, + payload: createData(), + setRunResult: vi.fn(), + setIsRunAfterSingleRun: vi.fn(), + isPaused: false, + isRunAfterSingleRun: false, + onSuccess: vi.fn(), + onCancel: vi.fn(), + appendNodeInspectVars: vi.fn(), + ...overrides, +}) + +describe('data-source/hooks/use-before-run-form branches', () => { + let dataSourceStoreState: DataSourceStoreState + + beforeEach(() => { + vi.clearAllMocks() + + dataSourceStoreState = { + currentNodeIdRef: { current: 'data-source-node' }, + currentCredentialId: 'credential-1', + setCurrentCredentialId: vi.fn(), + currentCredentialIdRef: { current: 'credential-1' }, + localFileList: [], + onlineDocuments: [], + websitePages: [], + selectedFileIds: [], + onlineDriveFileList: [], + bucket: 'drive-bucket', + } + + mockUseStoreApi.mockReturnValue({ + getState: () => ({ + getNodes: () => [{ id: 'data-source-node', data: { title: 'Datasource' } }], + }), + } as ReturnType) + + mockUseNodeDataUpdateHook.mockReturnValue({ + handleNodeDataUpdate: mockHandleNodeDataUpdate, + handleNodeDataUpdateWithSyncDraft: vi.fn(), + } as ReturnType) + mockUseNodesSyncDraftHook.mockReturnValue({ + handleSyncWorkflowDraft: (...args: unknown[]) => { + mockHandleSyncWorkflowDraft(...args) + const callbacks = args[2] as { onSuccess?: () => void } | undefined + callbacks?.onSuccess?.() + }, + } as ReturnType) + mockUseDatasourceSingleRunHook.mockReturnValue({ + mutateAsync: (...args: unknown[]) => mockMutateAsync(...args), + isPending: false, + } as ReturnType) + mockUseInvalidLastRunHook.mockReturnValue(mockInvalidLastRun) + mockFetchNodeInspectVarsFn.mockImplementation((...args: unknown[]) => mockFetchNodeInspectVars(...args)) + mockUseDataSourceStoreHook.mockImplementation(() => mockUseDataSourceStore()) + mockUseDataSourceStoreWithSelectorHook.mockImplementation(selector => + mockUseDataSourceStoreWithSelector(selector as unknown as (state: DataSourceStoreState) => unknown)) + + mockUseDataSourceStore.mockImplementation(() => ({ + getState: () => dataSourceStoreState, + })) + mockUseDataSourceStoreWithSelector.mockImplementation((selector: (state: DataSourceStoreState) => unknown) => + selector(dataSourceStoreState)) + mockFetchNodeInspectVars.mockResolvedValue([{ name: 'metadata' }] as VarInInspect[]) + }) + + it('derives disabled states for online documents and website crawl sources', () => { + const { result, rerender } = renderHook( + ({ payload }) => useBeforeRunForm(createProps({ payload })), + { + initialProps: { + payload: createData({ provider_type: DatasourceType.onlineDocument }), + }, + }, + ) + + expect(result.current.startRunBtnDisabled).toBe(true) + + dataSourceStoreState.onlineDocuments = [{ + workspace_id: 'workspace-1', + id: 'doc-1', + title: 'Document', + }] + rerender({ payload: createData({ provider_type: DatasourceType.onlineDocument }) }) + expect(result.current.startRunBtnDisabled).toBe(false) + + rerender({ payload: createData({ provider_type: DatasourceType.websiteCrawl }) }) + expect(result.current.startRunBtnDisabled).toBe(true) + + dataSourceStoreState.websitePages = [{ url: 'https://example.com' }] + rerender({ payload: createData({ provider_type: DatasourceType.websiteCrawl }) }) + expect(result.current.startRunBtnDisabled).toBe(false) + }) + + it('returns the settled run result directly when chained single-run execution should stop', async () => { + dataSourceStoreState.localFileList = [{ + file: { + id: 'file-1', + name: 'doc.pdf', + type: 'document', + size: 12, + extension: 'pdf', + mime_type: 'application/pdf', + }, + }] + + mockMutateAsync.mockImplementation((_payload: unknown, options: DatasourceSingleRunOptions) => { + options.onSettled?.({ status: NodeRunningStatus.Succeeded } as NodeRunResult) + return Promise.resolve(undefined) + }) + + const props = createProps({ + isRunAfterSingleRun: true, + payload: createData({ + _singleRunningStatus: NodeRunningStatus.Running, + } as Partial), + }) + const { result } = renderHook(() => useBeforeRunForm(props)) + + await act(async () => { + result.current.handleRunWithSyncDraft() + await Promise.resolve() + }) + + expect(props.setRunResult).toHaveBeenCalledWith({ status: NodeRunningStatus.Succeeded }) + expect(mockFetchNodeInspectVars).not.toHaveBeenCalled() + expect(props.onSuccess).not.toHaveBeenCalled() + }) + + it('builds online document datasource info before running', async () => { + dataSourceStoreState.onlineDocuments = [{ + workspace_id: 'workspace-1', + id: 'doc-1', + title: 'Document', + url: 'https://example.com/doc', + }] + + mockMutateAsync.mockImplementation((payload: unknown, options: DatasourceSingleRunOptions) => { + options.onSettled?.({ status: NodeRunningStatus.Succeeded } as NodeRunResult) + return Promise.resolve(payload) + }) + + const { result } = renderHook(() => useBeforeRunForm(createProps({ + payload: createData({ provider_type: DatasourceType.onlineDocument }), + }))) + + await act(async () => { + result.current.handleRunWithSyncDraft() + await Promise.resolve() + }) + + expect(mockMutateAsync).toHaveBeenCalledWith(expect.objectContaining({ + datasource_type: DatasourceType.onlineDocument, + datasource_info: { + workspace_id: 'workspace-1', + page: { + id: 'doc-1', + title: 'Document', + url: 'https://example.com/doc', + }, + credential_id: 'credential-1', + }, + }), expect.any(Object)) + }) + + it('builds website crawl datasource info and skips the failure update while paused', async () => { + dataSourceStoreState.websitePages = [{ + url: 'https://example.com', + title: 'Example', + }] + + mockMutateAsync.mockImplementation((payload: unknown, options: DatasourceSingleRunOptions) => { + options.onError?.() + return Promise.resolve(payload) + }) + + const { result } = renderHook(() => useBeforeRunForm(createProps({ + isPaused: true, + payload: createData({ provider_type: DatasourceType.websiteCrawl }), + }))) + + await act(async () => { + result.current.handleRunWithSyncDraft() + await Promise.resolve() + }) + + expect(mockMutateAsync).toHaveBeenCalledWith(expect.objectContaining({ + datasource_type: DatasourceType.websiteCrawl, + datasource_info: { + url: 'https://example.com', + title: 'Example', + credential_id: 'credential-1', + }, + }), expect.any(Object)) + expect(mockInvalidLastRun).toHaveBeenCalled() + expect(mockHandleNodeDataUpdate).not.toHaveBeenCalledWith(expect.objectContaining({ + data: expect.objectContaining({ + _singleRunningStatus: NodeRunningStatus.Failed, + }), + })) + }) +}) diff --git a/web/app/components/workflow/nodes/data-source/hooks/__tests__/use-before-run-form.spec.tsx b/web/app/components/workflow/nodes/data-source/hooks/__tests__/use-before-run-form.spec.tsx new file mode 100644 index 0000000000..b4e79b3334 --- /dev/null +++ b/web/app/components/workflow/nodes/data-source/hooks/__tests__/use-before-run-form.spec.tsx @@ -0,0 +1,307 @@ +import type { CustomRunFormProps, DataSourceNodeType } from '../../types' +import type { NodeRunResult, VarInInspect } from '@/types/workflow' +import { act, renderHook } from '@testing-library/react' +import { useStoreApi } from 'reactflow' +import { useDataSourceStore, useDataSourceStoreWithSelector } from '@/app/components/datasets/documents/create-from-pipeline/data-source/store' +import { BlockEnum, NodeRunningStatus } from '@/app/components/workflow/types' +import { DatasourceType } from '@/models/pipeline' +import { useDatasourceSingleRun } from '@/service/use-pipeline' +import { useInvalidLastRun } from '@/service/use-workflow' +import { fetchNodeInspectVars } from '@/service/workflow' +import { TransferMethod } from '@/types/app' +import { FlowType } from '@/types/common' +import { useNodeDataUpdate, useNodesSyncDraft } from '../../../../hooks' +import useBeforeRunForm from '../use-before-run-form' + +type DataSourceStoreState = { + currentNodeIdRef: { current: string } + currentCredentialId: string + setCurrentCredentialId: (credentialId: string) => void + currentCredentialIdRef: { current: string } + localFileList: Array<{ + file: { + id: string + name: string + type: string + size: number + extension: string + mime_type: string + } + }> + onlineDocuments: Array> + websitePages: Array> + selectedFileIds: string[] + onlineDriveFileList: Array<{ id: string, type: string }> + bucket?: string +} + +type DatasourceSingleRunOptions = { + onError?: () => void + onSettled?: (data?: NodeRunResult) => void +} + +const mockHandleNodeDataUpdate = vi.hoisted(() => vi.fn()) +const mockHandleSyncWorkflowDraft = vi.hoisted(() => vi.fn()) +const mockMutateAsync = vi.hoisted(() => vi.fn()) +const mockInvalidLastRun = vi.hoisted(() => vi.fn()) +const mockFetchNodeInspectVars = vi.hoisted(() => vi.fn()) +const mockUseDataSourceStore = vi.hoisted(() => vi.fn()) +const mockUseDataSourceStoreWithSelector = vi.hoisted(() => vi.fn()) + +vi.mock('reactflow', async () => { + const actual = await vi.importActual('reactflow') + return { + ...actual, + useStoreApi: vi.fn(), + } +}) + +vi.mock('@/app/components/workflow/hooks', () => ({ + useNodeDataUpdate: vi.fn(), + useNodesSyncDraft: vi.fn(), +})) + +vi.mock('@/service/use-pipeline', () => ({ + useDatasourceSingleRun: vi.fn(), +})) + +vi.mock('@/service/use-workflow', () => ({ + useInvalidLastRun: vi.fn(), +})) + +vi.mock('@/service/workflow', () => ({ + fetchNodeInspectVars: vi.fn(), +})) + +vi.mock('@/app/components/datasets/documents/create-from-pipeline/data-source/store', () => ({ + useDataSourceStore: vi.fn(), + useDataSourceStoreWithSelector: vi.fn(), +})) + +const mockUseStoreApi = vi.mocked(useStoreApi) +const mockUseNodeDataUpdateHook = vi.mocked(useNodeDataUpdate) +const mockUseNodesSyncDraftHook = vi.mocked(useNodesSyncDraft) +const mockUseDatasourceSingleRunHook = vi.mocked(useDatasourceSingleRun) +const mockUseInvalidLastRunHook = vi.mocked(useInvalidLastRun) +const mockFetchNodeInspectVarsFn = vi.mocked(fetchNodeInspectVars) +const mockUseDataSourceStoreHook = vi.mocked(useDataSourceStore) +const mockUseDataSourceStoreWithSelectorHook = vi.mocked(useDataSourceStoreWithSelector) + +const createData = (overrides: Partial = {}): DataSourceNodeType => ({ + title: 'Datasource', + desc: '', + type: BlockEnum.DataSource, + plugin_id: 'plugin-id', + provider_type: DatasourceType.localFile, + provider_name: 'provider', + datasource_name: 'local-file', + datasource_label: 'Local File', + datasource_parameters: {}, + datasource_configurations: {}, + fileExtensions: ['pdf'], + ...overrides, +}) + +const createProps = (overrides: Partial = {}): CustomRunFormProps => ({ + nodeId: 'data-source-node', + flowId: 'flow-id', + flowType: FlowType.ragPipeline, + payload: createData(), + setRunResult: vi.fn(), + setIsRunAfterSingleRun: vi.fn(), + isPaused: false, + isRunAfterSingleRun: false, + onSuccess: vi.fn(), + onCancel: vi.fn(), + appendNodeInspectVars: vi.fn(), + ...overrides, +}) + +describe('data-source/hooks/use-before-run-form', () => { + let dataSourceStoreState: DataSourceStoreState + + beforeEach(() => { + vi.clearAllMocks() + + dataSourceStoreState = { + currentNodeIdRef: { current: 'data-source-node' }, + currentCredentialId: 'credential-1', + setCurrentCredentialId: vi.fn(), + currentCredentialIdRef: { current: 'credential-1' }, + localFileList: [], + onlineDocuments: [], + websitePages: [], + selectedFileIds: [], + onlineDriveFileList: [], + bucket: 'drive-bucket', + } + + mockUseStoreApi.mockReturnValue({ + getState: () => ({ + getNodes: () => [ + { + id: 'data-source-node', + data: { + title: 'Datasource', + }, + }, + ], + }), + } as ReturnType) + + mockUseNodeDataUpdateHook.mockReturnValue({ + handleNodeDataUpdate: mockHandleNodeDataUpdate, + handleNodeDataUpdateWithSyncDraft: vi.fn(), + } as ReturnType) + mockUseNodesSyncDraftHook.mockReturnValue({ + handleSyncWorkflowDraft: (...args: unknown[]) => { + mockHandleSyncWorkflowDraft(...args) + const callbacks = args[2] as { onSuccess?: () => void } | undefined + callbacks?.onSuccess?.() + }, + } as ReturnType) + mockUseDatasourceSingleRunHook.mockReturnValue({ + mutateAsync: (...args: unknown[]) => mockMutateAsync(...args), + isPending: false, + } as ReturnType) + mockUseInvalidLastRunHook.mockReturnValue(mockInvalidLastRun) + mockFetchNodeInspectVarsFn.mockImplementation((...args: unknown[]) => mockFetchNodeInspectVars(...args)) + mockUseDataSourceStoreHook.mockImplementation(() => mockUseDataSourceStore()) + mockUseDataSourceStoreWithSelectorHook.mockImplementation(selector => + mockUseDataSourceStoreWithSelector(selector as unknown as (state: DataSourceStoreState) => unknown)) + + mockUseDataSourceStore.mockImplementation(() => ({ + getState: () => dataSourceStoreState, + })) + mockUseDataSourceStoreWithSelector.mockImplementation((selector: (state: DataSourceStoreState) => unknown) => + selector(dataSourceStoreState)) + mockFetchNodeInspectVars.mockResolvedValue([{ name: 'metadata' }] as VarInInspect[]) + }) + + it('derives the run button disabled state from the selected datasource payload', () => { + const { result, rerender } = renderHook( + ({ payload }) => useBeforeRunForm(createProps({ payload })), + { + initialProps: { + payload: createData(), + }, + }, + ) + + expect(result.current.startRunBtnDisabled).toBe(true) + + dataSourceStoreState.localFileList = [{ + file: { + id: 'file-1', + name: 'doc.pdf', + type: 'document', + size: 12, + extension: 'pdf', + mime_type: 'application/pdf', + }, + }] + rerender({ payload: createData() }) + expect(result.current.startRunBtnDisabled).toBe(false) + + dataSourceStoreState.selectedFileIds = [] + rerender({ + payload: createData({ + provider_type: DatasourceType.onlineDrive, + }), + }) + expect(result.current.startRunBtnDisabled).toBe(true) + }) + + it('syncs the draft, runs the datasource, and appends inspect vars on success', async () => { + dataSourceStoreState.localFileList = [{ + file: { + id: 'file-1', + name: 'doc.pdf', + type: 'document', + size: 12, + extension: 'pdf', + mime_type: 'application/pdf', + }, + }] + + mockMutateAsync.mockImplementation((payload: unknown, options: DatasourceSingleRunOptions) => { + options.onSettled?.({ status: NodeRunningStatus.Succeeded } as NodeRunResult) + return Promise.resolve(payload) + }) + + const props = createProps() + const { result } = renderHook(() => useBeforeRunForm(props)) + + await act(async () => { + result.current.handleRunWithSyncDraft() + await Promise.resolve() + }) + + expect(props.setIsRunAfterSingleRun).toHaveBeenCalledWith(true) + expect(mockHandleNodeDataUpdate).toHaveBeenNthCalledWith(1, { + id: 'data-source-node', + data: expect.objectContaining({ + _singleRunningStatus: NodeRunningStatus.Running, + }), + }) + expect(mockMutateAsync).toHaveBeenCalledWith(expect.objectContaining({ + pipeline_id: 'flow-id', + start_node_id: 'data-source-node', + datasource_type: DatasourceType.localFile, + datasource_info: expect.objectContaining({ + related_id: 'file-1', + transfer_method: TransferMethod.local_file, + }), + }), expect.any(Object)) + expect(mockFetchNodeInspectVars).toHaveBeenCalledWith(FlowType.ragPipeline, 'flow-id', 'data-source-node') + expect(props.appendNodeInspectVars).toHaveBeenCalledWith('data-source-node', [{ name: 'metadata' }], [ + { + id: 'data-source-node', + data: { + title: 'Datasource', + }, + }, + ]) + expect(props.onSuccess).toHaveBeenCalled() + expect(mockHandleNodeDataUpdate).toHaveBeenLastCalledWith({ + id: 'data-source-node', + data: expect.objectContaining({ + _isSingleRun: false, + _singleRunningStatus: NodeRunningStatus.Succeeded, + }), + }) + }) + + it('marks the last run invalid and updates the node to failed when the single run errors', async () => { + dataSourceStoreState.selectedFileIds = ['drive-file-1'] + dataSourceStoreState.onlineDriveFileList = [{ + id: 'drive-file-1', + type: 'file', + }] + + mockMutateAsync.mockImplementation((_payload: unknown, options: DatasourceSingleRunOptions) => { + options.onError?.() + return Promise.resolve(undefined) + }) + + const { result } = renderHook(() => useBeforeRunForm(createProps({ + payload: createData({ + provider_type: DatasourceType.onlineDrive, + }), + }))) + + await act(async () => { + result.current.handleRunWithSyncDraft() + await Promise.resolve() + }) + + expect(mockInvalidLastRun).toHaveBeenCalled() + expect(mockHandleNodeDataUpdate).toHaveBeenLastCalledWith({ + id: 'data-source-node', + data: expect.objectContaining({ + _isSingleRun: false, + _singleRunningStatus: NodeRunningStatus.Failed, + }), + }) + }) +}) diff --git a/web/app/components/workflow/nodes/document-extractor/__tests__/node.spec.tsx b/web/app/components/workflow/nodes/document-extractor/__tests__/node.spec.tsx new file mode 100644 index 0000000000..2044d7e6b9 --- /dev/null +++ b/web/app/components/workflow/nodes/document-extractor/__tests__/node.spec.tsx @@ -0,0 +1,74 @@ +import type { DocExtractorNodeType } from '../types' +import { render, screen } from '@testing-library/react' +import { useNodes } from 'reactflow' +import { BlockEnum } from '@/app/components/workflow/types' +import Node from '../node' + +vi.mock('reactflow', async () => { + const actual = await vi.importActual('reactflow') + return { + ...actual, + useNodes: vi.fn(), + } +}) + +vi.mock('@/app/components/workflow/nodes/_base/components/variable/variable-label', () => ({ + VariableLabelInNode: ({ + variables, + nodeTitle, + nodeType, + }: { + variables: string[] + nodeTitle?: string + nodeType?: BlockEnum + }) =>
{`${nodeTitle}:${nodeType}:${variables.join('.')}`}
, +})) + +const mockUseNodes = vi.mocked(useNodes) + +const createData = (overrides: Partial = {}): DocExtractorNodeType => ({ + title: 'Document Extractor', + desc: '', + type: BlockEnum.DocExtractor, + variable_selector: ['node-1', 'files'], + is_array_file: false, + ...overrides, +}) + +describe('document-extractor/node', () => { + beforeEach(() => { + vi.clearAllMocks() + mockUseNodes.mockReturnValue([ + { + id: 'node-1', + data: { + title: 'Input Files', + type: BlockEnum.Start, + }, + }, + ] as ReturnType) + }) + + it('renders nothing when no input variable is configured', () => { + const { container } = render( + , + ) + + expect(container).toBeEmptyDOMElement() + }) + + it('renders the selected input variable label', () => { + render( + , + ) + + expect(screen.getByText('workflow.nodes.docExtractor.inputVar')).toBeInTheDocument() + expect(screen.getByText('Input Files:start:node-1.files')).toBeInTheDocument() + }) +}) diff --git a/web/app/components/workflow/nodes/document-extractor/__tests__/panel.spec.tsx b/web/app/components/workflow/nodes/document-extractor/__tests__/panel.spec.tsx new file mode 100644 index 0000000000..06512f94c6 --- /dev/null +++ b/web/app/components/workflow/nodes/document-extractor/__tests__/panel.spec.tsx @@ -0,0 +1,144 @@ +import type { ReactNode } from 'react' +import type { DocExtractorNodeType } from '../types' +import type { PanelProps } from '@/types/workflow' +import { render, screen } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import { LanguagesSupported } from '@/i18n-config/language' +import { BlockEnum } from '../../../types' +import Panel from '../panel' +import useConfig from '../use-config' + +let mockLocale = 'en-US' + +vi.mock('@/app/components/workflow/nodes/_base/components/field', () => ({ + __esModule: true, + default: ({ title, children }: { title: ReactNode, children: ReactNode }) => ( +
+
{title}
+ {children} +
+ ), +})) + +vi.mock('@/app/components/workflow/nodes/_base/components/output-vars', () => ({ + __esModule: true, + default: ({ children }: { children: ReactNode }) =>
{children}
, + VarItem: ({ name, type }: { name: string, type: string }) =>
{`${name}:${type}`}
, +})) + +vi.mock('@/app/components/workflow/nodes/_base/components/split', () => ({ + __esModule: true, + default: () =>
split
, +})) + +vi.mock('@/app/components/workflow/nodes/_base/components/variable/var-reference-picker', () => ({ + __esModule: true, + default: ({ + onChange, + }: { + onChange: (value: string[]) => void + }) => , +})) + +vi.mock('@/app/components/workflow/nodes/_base/hooks/use-node-help-link', () => ({ + useNodeHelpLink: () => 'https://docs.example.com/document-extractor', +})) + +vi.mock('@/service/use-common', () => ({ + useFileSupportTypes: () => ({ + data: { + allowed_extensions: ['PDF', 'md', 'md', 'DOCX'], + }, + }), +})) + +vi.mock('@/context/i18n', () => ({ + useLocale: () => mockLocale, +})) + +vi.mock('../use-config', () => ({ + __esModule: true, + default: vi.fn(), +})) + +const mockUseConfig = vi.mocked(useConfig) + +const createData = (overrides: Partial = {}): DocExtractorNodeType => ({ + title: 'Document Extractor', + desc: '', + type: BlockEnum.DocExtractor, + variable_selector: ['node-1', 'files'], + is_array_file: false, + ...overrides, +}) + +const createConfigResult = (overrides: Partial> = {}): ReturnType => ({ + readOnly: false, + inputs: createData(), + handleVarChanges: vi.fn(), + filterVar: () => true, + ...overrides, +}) + +const panelProps: PanelProps = { + getInputVars: vi.fn(() => []), + toVarInputs: vi.fn(() => []), + runInputData: {}, + runInputDataRef: { current: {} }, + setRunInputData: vi.fn(), + runResult: null, +} + +describe('document-extractor/panel', () => { + beforeEach(() => { + vi.clearAllMocks() + mockLocale = 'en-US' + mockUseConfig.mockReturnValue(createConfigResult()) + }) + + it('wires variable changes and renders supported file types for english locales', async () => { + const user = userEvent.setup() + const handleVarChanges = vi.fn() + + mockUseConfig.mockReturnValueOnce(createConfigResult({ + inputs: createData({ is_array_file: false }), + handleVarChanges, + })) + + render( + , + ) + + await user.click(screen.getByRole('button', { name: 'pick-file-var' })) + + expect(handleVarChanges).toHaveBeenCalledWith(['node-1', 'files']) + expect(screen.getByText('workflow.nodes.docExtractor.supportFileTypes:{"types":"pdf, markdown, docx"}')).toBeInTheDocument() + expect(screen.getByRole('link', { name: 'workflow.nodes.docExtractor.learnMore' })).toHaveAttribute( + 'href', + 'https://docs.example.com/document-extractor', + ) + expect(screen.getByText('text:string')).toBeInTheDocument() + }) + + it('uses chinese separators and array output types when the input is an array of files', () => { + mockLocale = LanguagesSupported[1] + mockUseConfig.mockReturnValueOnce(createConfigResult({ + inputs: createData({ is_array_file: true }), + })) + + render( + , + ) + + expect(screen.getByText('workflow.nodes.docExtractor.supportFileTypes:{"types":"pdf、 markdown、 docx"}')).toBeInTheDocument() + expect(screen.getByText('text:array[string]')).toBeInTheDocument() + }) +}) diff --git a/web/app/components/workflow/nodes/document-extractor/__tests__/use-config.spec.ts b/web/app/components/workflow/nodes/document-extractor/__tests__/use-config.spec.ts new file mode 100644 index 0000000000..d988b2751d --- /dev/null +++ b/web/app/components/workflow/nodes/document-extractor/__tests__/use-config.spec.ts @@ -0,0 +1,100 @@ +import type { DocExtractorNodeType } from '../types' +import { renderHook } from '@testing-library/react' +import { useStoreApi } from 'reactflow' +import { + useIsChatMode, + useNodesReadOnly, + useWorkflow, + useWorkflowVariables, +} from '@/app/components/workflow/hooks' +import useNodeCrud from '@/app/components/workflow/nodes/_base/hooks/use-node-crud' +import { BlockEnum, VarType } from '@/app/components/workflow/types' +import useConfig from '../use-config' + +const mockUseStoreApi = vi.mocked(useStoreApi) +const mockUseNodesReadOnly = vi.mocked(useNodesReadOnly) +const mockUseNodeCrud = vi.mocked(useNodeCrud) +const mockUseIsChatMode = vi.mocked(useIsChatMode) +const mockUseWorkflow = vi.mocked(useWorkflow) +const mockUseWorkflowVariables = vi.mocked(useWorkflowVariables) + +vi.mock('reactflow', async () => { + const actual = await vi.importActual('reactflow') + return { + ...actual, + useStoreApi: vi.fn(), + } +}) + +vi.mock('@/app/components/workflow/hooks', () => ({ + useIsChatMode: vi.fn(), + useNodesReadOnly: vi.fn(), + useWorkflow: vi.fn(), + useWorkflowVariables: vi.fn(), +})) + +vi.mock('@/app/components/workflow/nodes/_base/hooks/use-node-crud', () => ({ + __esModule: true, + default: vi.fn(), +})) + +const setInputs = vi.fn() +const getCurrentVariableType = vi.fn() + +const createData = (overrides: Partial = {}): DocExtractorNodeType => ({ + title: 'Document Extractor', + desc: '', + type: BlockEnum.DocExtractor, + variable_selector: ['node-1', 'files'], + is_array_file: false, + ...overrides, +}) + +describe('document-extractor/use-config', () => { + beforeEach(() => { + vi.clearAllMocks() + mockUseNodesReadOnly.mockReturnValue({ nodesReadOnly: false, getNodesReadOnly: () => false }) + mockUseIsChatMode.mockReturnValue(false) + mockUseWorkflow.mockReturnValue({ + getBeforeNodesInSameBranch: vi.fn(() => [{ id: 'start-node' }]), + } as unknown as ReturnType) + mockUseWorkflowVariables.mockReturnValue({ + getCurrentVariableType, + } as unknown as ReturnType) + mockUseStoreApi.mockReturnValue({ + getState: () => ({ + getNodes: () => [ + { id: 'doc-node', parentId: 'loop-1', data: { type: BlockEnum.DocExtractor } }, + { id: 'loop-1', data: { type: BlockEnum.Loop } }, + ], + }), + } as ReturnType) + mockUseNodeCrud.mockReturnValue({ + inputs: createData(), + setInputs, + } as ReturnType) + }) + + it('updates the selected variable and tracks array file output types', () => { + getCurrentVariableType.mockReturnValue(VarType.arrayFile) + + const { result } = renderHook(() => useConfig('doc-node', createData())) + + result.current.handleVarChanges(['node-2', 'files']) + + expect(getCurrentVariableType).toHaveBeenCalled() + expect(setInputs).toHaveBeenCalledWith(expect.objectContaining({ + variable_selector: ['node-2', 'files'], + is_array_file: true, + })) + }) + + it('only accepts file variables in the picker filter', () => { + const { result } = renderHook(() => useConfig('doc-node', createData())) + + expect(result.current.readOnly).toBe(false) + expect(result.current.filterVar({ type: VarType.file } as never)).toBe(true) + expect(result.current.filterVar({ type: VarType.arrayFile } as never)).toBe(true) + expect(result.current.filterVar({ type: VarType.string } as never)).toBe(false) + }) +}) diff --git a/web/app/components/workflow/nodes/document-extractor/__tests__/use-single-run-form-params.spec.ts b/web/app/components/workflow/nodes/document-extractor/__tests__/use-single-run-form-params.spec.ts new file mode 100644 index 0000000000..935118f26e --- /dev/null +++ b/web/app/components/workflow/nodes/document-extractor/__tests__/use-single-run-form-params.spec.ts @@ -0,0 +1,43 @@ +import type { DocExtractorNodeType } from '../types' +import { renderHook } from '@testing-library/react' +import { BlockEnum } from '@/app/components/workflow/types' +import useSingleRunFormParams from '../use-single-run-form-params' + +const createData = (overrides: Partial = {}): DocExtractorNodeType => ({ + title: 'Document Extractor', + desc: '', + type: BlockEnum.DocExtractor, + variable_selector: ['start', 'files'], + is_array_file: false, + ...overrides, +}) + +describe('document-extractor/use-single-run-form-params', () => { + it('exposes a single files form and updates run input values', () => { + const setRunInputData = vi.fn() + + const { result } = renderHook(() => useSingleRunFormParams({ + id: 'doc-node', + payload: createData(), + runInputData: { files: ['old-file'] }, + runInputDataRef: { current: {} }, + getInputVars: () => [], + setRunInputData, + toVarInputs: () => [], + })) + + expect(result.current.forms).toHaveLength(1) + expect(result.current.forms[0].inputs).toEqual([ + expect.objectContaining({ + variable: 'files', + required: true, + }), + ]) + + result.current.forms[0].onChange({ files: ['new-file'] }) + + expect(setRunInputData).toHaveBeenCalledWith({ files: ['new-file'] }) + expect(result.current.getDependentVars()).toEqual([['start', 'files']]) + expect(result.current.getDependentVar('files')).toEqual(['start', 'files']) + }) +}) diff --git a/web/app/components/workflow/nodes/end/__tests__/panel.spec.tsx b/web/app/components/workflow/nodes/end/__tests__/panel.spec.tsx new file mode 100644 index 0000000000..b4218e338b --- /dev/null +++ b/web/app/components/workflow/nodes/end/__tests__/panel.spec.tsx @@ -0,0 +1,58 @@ +import type { EndNodeType } from '../types' +import type { PanelProps } from '@/types/workflow' +import { fireEvent, render, screen } from '@testing-library/react' +import { BlockEnum } from '@/app/components/workflow/types' +import Panel from '../panel' + +const mockUseConfig = vi.hoisted(() => vi.fn()) + +vi.mock('../use-config', () => ({ + __esModule: true, + default: (...args: unknown[]) => mockUseConfig(...args), +})) + +const createData = (overrides: Partial = {}): EndNodeType => ({ + title: 'End', + desc: '', + type: BlockEnum.End, + outputs: [], + ...overrides, +}) + +describe('EndPanel', () => { + const handleVarListChange = vi.fn() + const handleAddVariable = vi.fn() + + beforeEach(() => { + vi.clearAllMocks() + mockUseConfig.mockReturnValue({ + readOnly: false, + inputs: createData(), + handleVarListChange, + handleAddVariable, + }) + }) + + it('should show the output field and allow adding output variables when writable', () => { + render() + + expect(screen.getByText('workflow.nodes.end.output.variable')).toBeInTheDocument() + + fireEvent.click(screen.getByTestId('add-button')) + + expect(handleAddVariable).toHaveBeenCalledTimes(1) + }) + + it('should hide the add action when the node is read-only', () => { + mockUseConfig.mockReturnValue({ + readOnly: true, + inputs: createData(), + handleVarListChange, + handleAddVariable, + }) + + render() + + expect(screen.queryByTestId('add-button')).not.toBeInTheDocument() + }) +}) diff --git a/web/app/components/workflow/nodes/end/__tests__/use-config.spec.ts b/web/app/components/workflow/nodes/end/__tests__/use-config.spec.ts new file mode 100644 index 0000000000..8d0cbff547 --- /dev/null +++ b/web/app/components/workflow/nodes/end/__tests__/use-config.spec.ts @@ -0,0 +1,76 @@ +import type { EndNodeType } from '../types' +import { act, renderHook } from '@testing-library/react' +import { BlockEnum } from '@/app/components/workflow/types' +import useConfig from '../use-config' + +const mockUseNodesReadOnly = vi.hoisted(() => vi.fn()) +const mockUseNodeCrud = vi.hoisted(() => vi.fn()) +const mockUseVarList = vi.hoisted(() => vi.fn()) + +vi.mock('@/app/components/workflow/hooks', () => ({ + useNodesReadOnly: () => mockUseNodesReadOnly(), +})) + +vi.mock('@/app/components/workflow/nodes/_base/hooks/use-node-crud', () => ({ + __esModule: true, + default: (...args: unknown[]) => mockUseNodeCrud(...args), +})) + +vi.mock('@/app/components/workflow/nodes/_base/hooks/use-var-list', () => ({ + __esModule: true, + default: (...args: unknown[]) => mockUseVarList(...args), +})) + +const createPayload = (overrides: Partial = {}): EndNodeType => ({ + title: 'End', + desc: '', + type: BlockEnum.End, + outputs: [], + ...overrides, +}) + +describe('end/use-config', () => { + const mockHandleVarListChange = vi.fn() + const mockHandleAddVariable = vi.fn() + const mockSetInputs = vi.fn() + let currentInputs: EndNodeType + + beforeEach(() => { + vi.clearAllMocks() + currentInputs = createPayload() + + mockUseNodesReadOnly.mockReturnValue({ nodesReadOnly: true }) + mockUseNodeCrud.mockReturnValue({ + inputs: currentInputs, + setInputs: mockSetInputs, + }) + mockUseVarList.mockReturnValue({ + handleVarListChange: mockHandleVarListChange, + handleAddVariable: mockHandleAddVariable, + }) + }) + + it('should build var-list handlers against outputs and surface the readonly state', () => { + const { result } = renderHook(() => useConfig('end-node', currentInputs)) + const config = mockUseVarList.mock.calls[0][0] as { setInputs: (inputs: EndNodeType) => void } + + expect(mockUseVarList).toHaveBeenCalledWith(expect.objectContaining({ + inputs: currentInputs, + setInputs: expect.any(Function), + varKey: 'outputs', + })) + expect(result.current.readOnly).toBe(true) + expect(result.current.handleVarListChange).toBe(mockHandleVarListChange) + expect(result.current.handleAddVariable).toBe(mockHandleAddVariable) + + act(() => { + config.setInputs(createPayload({ + outputs: currentInputs.outputs, + })) + }) + + expect(mockSetInputs).toHaveBeenCalledWith(expect.objectContaining({ + outputs: currentInputs.outputs, + })) + }) +}) diff --git a/web/app/components/workflow/nodes/http/__tests__/node.spec.tsx b/web/app/components/workflow/nodes/http/__tests__/node.spec.tsx new file mode 100644 index 0000000000..428aabd99e --- /dev/null +++ b/web/app/components/workflow/nodes/http/__tests__/node.spec.tsx @@ -0,0 +1,67 @@ +import type { HttpNodeType } from '../types' +import { render, screen } from '@testing-library/react' +import { BlockEnum } from '@/app/components/workflow/types' +import Node from '../node' +import { AuthorizationType, BodyType, Method } from '../types' + +const mockReadonlyInputWithSelectVar = vi.hoisted(() => vi.fn()) + +vi.mock('@/app/components/workflow/nodes/_base/components/readonly-input-with-select-var', () => ({ + __esModule: true, + default: (props: { value: string, nodeId: string, className?: string }) => { + mockReadonlyInputWithSelectVar(props) + return
{props.value}
+ }, +})) + +const createData = (overrides: Partial = {}): HttpNodeType => ({ + title: 'HTTP Request', + desc: '', + type: BlockEnum.HttpRequest, + variables: [], + method: Method.get, + url: 'https://api.example.com', + authorization: { type: AuthorizationType.none }, + headers: '', + params: '', + body: { type: BodyType.none, data: [] }, + timeout: { connect: 5, read: 10, write: 15 }, + ssl_verify: true, + ...overrides, +}) + +describe('http/node', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('renders the request method and forwards the URL to the readonly input', () => { + render( + , + ) + + expect(screen.getByText('post')).toBeInTheDocument() + expect(screen.getByTestId('readonly-input')).toHaveTextContent('https://api.example.com/users') + expect(mockReadonlyInputWithSelectVar).toHaveBeenCalledWith(expect.objectContaining({ + nodeId: 'http-node', + value: 'https://api.example.com/users', + })) + }) + + it('renders nothing when the request URL is empty', () => { + const { container } = render( + , + ) + + expect(container).toBeEmptyDOMElement() + }) +}) diff --git a/web/app/components/workflow/nodes/http/__tests__/panel.spec.tsx b/web/app/components/workflow/nodes/http/__tests__/panel.spec.tsx new file mode 100644 index 0000000000..e8ce5ac5c3 --- /dev/null +++ b/web/app/components/workflow/nodes/http/__tests__/panel.spec.tsx @@ -0,0 +1,295 @@ +import type { ReactNode } from 'react' +import type { HttpNodeType } from '../types' +import type { NodePanelProps } from '@/app/components/workflow/types' +import { render, screen } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import { BlockEnum } from '@/app/components/workflow/types' +import Panel from '../panel' +import { AuthorizationType, BodyPayloadValueType, BodyType, Method } from '../types' + +const mockUseConfig = vi.hoisted(() => vi.fn()) +const mockAuthorizationModal = vi.hoisted(() => vi.fn()) +const mockCurlPanel = vi.hoisted(() => vi.fn()) +const mockApiInput = vi.hoisted(() => vi.fn()) +const mockKeyValue = vi.hoisted(() => vi.fn()) +const mockEditBody = vi.hoisted(() => vi.fn()) +const mockTimeout = vi.hoisted(() => vi.fn()) + +type ApiInputProps = { + method: Method + url: string + onMethodChange: (method: Method) => void + onUrlChange: (url: string) => void +} + +type KeyValueProps = { + nodeId: string + list: Array<{ key: string, value: string }> + onChange: (value: Array<{ key: string, value: string }>) => void + onAdd: () => void +} + +type EditBodyProps = { + payload: HttpNodeType['body'] + onChange: (value: HttpNodeType['body']) => void +} + +type TimeoutProps = { + payload: HttpNodeType['timeout'] + onChange: (value: HttpNodeType['timeout']) => void +} + +vi.mock('../use-config', () => ({ + __esModule: true, + default: (...args: unknown[]) => mockUseConfig(...args), +})) + +vi.mock('../components/authorization', () => ({ + __esModule: true, + default: (props: { nodeId: string, payload: HttpNodeType['authorization'], onChange: (value: HttpNodeType['authorization']) => void, onHide: () => void }) => { + mockAuthorizationModal(props) + return
{props.nodeId}
+ }, +})) + +vi.mock('../components/curl-panel', () => ({ + __esModule: true, + default: (props: { nodeId: string, onHide: () => void, handleCurlImport: (node: HttpNodeType) => void }) => { + mockCurlPanel(props) + return
{props.nodeId}
+ }, +})) + +vi.mock('../components/api-input', () => ({ + __esModule: true, + default: (props: ApiInputProps) => { + mockApiInput(props) + return ( +
+
{`${props.method}:${props.url}`}
+ + +
+ ) + }, +})) + +vi.mock('../components/key-value', () => ({ + __esModule: true, + default: (props: KeyValueProps) => { + mockKeyValue(props) + return ( +
+
{props.list.map(item => `${item.key}:${item.value}`).join(',')}
+ + +
+ ) + }, +})) + +vi.mock('../components/edit-body', () => ({ + __esModule: true, + default: (props: EditBodyProps) => { + mockEditBody(props) + return ( + + ) + }, +})) + +vi.mock('../components/timeout', () => ({ + __esModule: true, + default: (props: TimeoutProps) => { + mockTimeout(props) + return ( + + ) + }, +})) + +vi.mock('@/app/components/workflow/nodes/_base/components/output-vars', () => ({ + __esModule: true, + default: ({ children }: { children: ReactNode }) =>
{children}
, + VarItem: ({ name, type }: { name: string, type: string }) =>
{`${name}:${type}`}
, +})) + +const createData = (overrides: Partial = {}): HttpNodeType => ({ + title: 'HTTP Request', + desc: '', + type: BlockEnum.HttpRequest, + variables: [], + method: Method.get, + url: 'https://api.example.com', + authorization: { type: AuthorizationType.none }, + headers: '', + params: '', + body: { type: BodyType.none, data: [] }, + timeout: { connect: 5, read: 10, write: 15 }, + ssl_verify: true, + ...overrides, +}) + +const panelProps = {} as NodePanelProps['panelProps'] + +describe('http/panel', () => { + const handleMethodChange = vi.fn() + const handleUrlChange = vi.fn() + const setHeaders = vi.fn() + const addHeader = vi.fn() + const setParams = vi.fn() + const addParam = vi.fn() + const setBody = vi.fn() + const showAuthorization = vi.fn() + const hideAuthorization = vi.fn() + const setAuthorization = vi.fn() + const setTimeout = vi.fn() + const showCurlPanel = vi.fn() + const hideCurlPanel = vi.fn() + const handleCurlImport = vi.fn() + const handleSSLVerifyChange = vi.fn() + + const createConfigResult = (overrides: Record = {}) => ({ + readOnly: false, + isDataReady: true, + inputs: createData({ + authorization: { type: AuthorizationType.apiKey, config: null }, + }), + handleMethodChange, + handleUrlChange, + headers: [{ key: 'accept', value: 'application/json' }], + setHeaders, + addHeader, + params: [{ key: 'page', value: '1' }], + setParams, + addParam, + setBody, + isShowAuthorization: false, + showAuthorization, + hideAuthorization, + setAuthorization, + setTimeout, + isShowCurlPanel: false, + showCurlPanel, + hideCurlPanel, + handleCurlImport, + handleSSLVerifyChange, + ...overrides, + }) + + beforeEach(() => { + vi.clearAllMocks() + mockUseConfig.mockReturnValue(createConfigResult()) + }) + + it('renders request fields, forwards child changes, and wires header operations', async () => { + const user = userEvent.setup() + + render( + , + ) + + expect(screen.getByText('get:https://api.example.com')).toBeInTheDocument() + expect(screen.getByText('body:string')).toBeInTheDocument() + expect(screen.getByText('status_code:number')).toBeInTheDocument() + expect(screen.getByText('headers:object')).toBeInTheDocument() + expect(screen.getByText('files:Array[File]')).toBeInTheDocument() + + await user.click(screen.getByRole('button', { name: 'emit-method-change' })) + await user.click(screen.getByRole('button', { name: 'emit-url-change' })) + await user.click(screen.getAllByRole('button', { name: 'emit-key-value-change' })[0]!) + await user.click(screen.getAllByRole('button', { name: 'emit-key-value-add' })[0]!) + await user.click(screen.getAllByRole('button', { name: 'emit-key-value-change' })[1]!) + await user.click(screen.getAllByRole('button', { name: 'emit-key-value-add' })[1]!) + await user.click(screen.getByRole('button', { name: 'emit-body-change' })) + await user.click(screen.getByRole('button', { name: 'emit-timeout-change' })) + await user.click(screen.getByText('workflow.nodes.http.authorization.authorization')) + await user.click(screen.getByText('workflow.nodes.http.curl.title')) + await user.click(screen.getByRole('switch')) + + expect(handleMethodChange).toHaveBeenCalledWith(Method.post) + expect(handleUrlChange).toHaveBeenCalledWith('https://changed.example.com') + expect(setHeaders).toHaveBeenCalledWith([{ key: 'x-token', value: '123' }]) + expect(addHeader).toHaveBeenCalledTimes(1) + expect(setParams).toHaveBeenCalledWith([{ key: 'x-token', value: '123' }]) + expect(addParam).toHaveBeenCalledTimes(1) + expect(setBody).toHaveBeenCalledWith({ + type: BodyType.json, + data: [{ type: 'text', value: '{"hello":"world"}' }], + }) + expect(setTimeout).toHaveBeenCalledWith(expect.objectContaining({ connect: 9 })) + expect(showAuthorization).toHaveBeenCalledTimes(1) + expect(showCurlPanel).toHaveBeenCalledTimes(1) + expect(handleSSLVerifyChange).toHaveBeenCalledWith(false) + expect(mockApiInput).toHaveBeenCalledWith(expect.objectContaining({ + method: Method.get, + url: 'https://api.example.com', + })) + }) + + it('returns null before the config data is ready', () => { + mockUseConfig.mockReturnValueOnce(createConfigResult({ isDataReady: false })) + + const { container } = render( + , + ) + + expect(container).toBeEmptyDOMElement() + }) + + it('renders auth and curl panels only when writable and toggled on', () => { + mockUseConfig.mockReturnValueOnce(createConfigResult({ + isShowAuthorization: true, + isShowCurlPanel: true, + })) + + const { rerender } = render( + , + ) + + expect(screen.getByTestId('authorization-modal')).toHaveTextContent('http-node') + expect(screen.getByTestId('curl-panel')).toHaveTextContent('http-node') + + mockUseConfig.mockReturnValueOnce(createConfigResult({ + readOnly: true, + isShowAuthorization: true, + isShowCurlPanel: true, + })) + + rerender( + , + ) + + expect(screen.queryByTestId('authorization-modal')).not.toBeInTheDocument() + expect(screen.queryByTestId('curl-panel')).not.toBeInTheDocument() + expect(screen.getByRole('switch')).toHaveAttribute('aria-disabled', 'true') + }) +}) diff --git a/web/app/components/workflow/nodes/http/__tests__/use-config.spec.ts b/web/app/components/workflow/nodes/http/__tests__/use-config.spec.ts new file mode 100644 index 0000000000..e771122e28 --- /dev/null +++ b/web/app/components/workflow/nodes/http/__tests__/use-config.spec.ts @@ -0,0 +1,271 @@ +import type { HttpNodeType } from '../types' +import { act, renderHook, waitFor } from '@testing-library/react' +import { useNodesReadOnly } from '@/app/components/workflow/hooks' +import useNodeCrud from '@/app/components/workflow/nodes/_base/hooks/use-node-crud' +import { useStore } from '@/app/components/workflow/store' +import { BlockEnum, VarType } from '@/app/components/workflow/types' +import useVarList from '../../_base/hooks/use-var-list' +import useKeyValueList from '../hooks/use-key-value-list' +import { APIType, AuthorizationType, BodyPayloadValueType, BodyType, Method } from '../types' +import useConfig from '../use-config' + +vi.mock('@/app/components/workflow/hooks', () => ({ + useNodesReadOnly: vi.fn(), +})) + +vi.mock('@/app/components/workflow/nodes/_base/hooks/use-node-crud', () => ({ + __esModule: true, + default: vi.fn(), +})) + +vi.mock('@/app/components/workflow/nodes/_base/hooks/use-var-list', () => ({ + __esModule: true, + default: vi.fn(), +})) + +vi.mock('../hooks/use-key-value-list', () => ({ + __esModule: true, + default: vi.fn(), +})) + +vi.mock('@/app/components/workflow/store', () => ({ + useStore: vi.fn(), +})) + +const mockUseNodesReadOnly = vi.mocked(useNodesReadOnly) +const mockUseNodeCrud = vi.mocked(useNodeCrud) +const mockUseVarList = vi.mocked(useVarList) +const mockUseKeyValueList = vi.mocked(useKeyValueList) +const mockUseStore = vi.mocked(useStore) + +const createPayload = (overrides: Partial = {}): HttpNodeType => ({ + title: 'HTTP Request', + desc: '', + type: BlockEnum.HttpRequest, + variables: [], + method: Method.get, + url: 'https://api.example.com', + authorization: { type: AuthorizationType.none }, + headers: 'accept:application/json', + params: 'page:1', + body: { + type: BodyType.json, + data: '{"name":"alice"}', + }, + timeout: { connect: 5, read: 10, write: 15 }, + ssl_verify: true, + ...overrides, +}) + +describe('http/use-config', () => { + const mockSetInputs = vi.fn() + const mockHandleVarListChange = vi.fn() + const mockHandleAddVariable = vi.fn() + const headerSetList = vi.fn() + const headerAddItem = vi.fn() + const headerToggle = vi.fn() + const paramSetList = vi.fn() + const paramAddItem = vi.fn() + const paramToggle = vi.fn() + let currentInputs: HttpNodeType + let headerFieldChange: ((value: string) => void) | undefined + let paramFieldChange: ((value: string) => void) | undefined + + beforeEach(() => { + vi.clearAllMocks() + currentInputs = createPayload() + headerFieldChange = undefined + paramFieldChange = undefined + + mockUseNodesReadOnly.mockReturnValue({ nodesReadOnly: false, getNodesReadOnly: () => false }) + mockUseNodeCrud.mockImplementation(() => ({ + inputs: currentInputs, + setInputs: mockSetInputs, + })) + mockUseVarList.mockReturnValue({ + handleVarListChange: mockHandleVarListChange, + handleAddVariable: mockHandleAddVariable, + } as ReturnType) + mockUseKeyValueList.mockImplementation((value, onChange) => { + if (value === currentInputs.headers) { + headerFieldChange = onChange + return { + list: [{ id: 'header-1', key: 'accept', value: 'application/json' }], + setList: headerSetList, + addItem: headerAddItem, + isKeyValueEdit: true, + toggleIsKeyValueEdit: headerToggle, + } + } + + paramFieldChange = onChange + return { + list: [{ id: 'param-1', key: 'page', value: '1' }], + setList: paramSetList, + addItem: paramAddItem, + isKeyValueEdit: false, + toggleIsKeyValueEdit: paramToggle, + } + }) + mockUseStore.mockImplementation((selector) => { + const state = { + nodesDefaultConfigs: { + [BlockEnum.HttpRequest]: createPayload({ + method: Method.post, + url: 'https://default.example.com', + headers: '', + params: '', + body: { type: BodyType.none, data: [] }, + timeout: { connect: 1, read: 2, write: 3 }, + ssl_verify: false, + }), + }, + } + + return selector(state as never) + }) + }) + + it('stays pending when the node default config is unavailable', () => { + mockUseStore.mockImplementation((selector) => { + return selector({ nodesDefaultConfigs: {} } as never) + }) + + const { result } = renderHook(() => useConfig('http-node', currentInputs)) + + expect(result.current.isDataReady).toBe(false) + expect(mockSetInputs).not.toHaveBeenCalled() + }) + + it('hydrates defaults, normalizes body payloads, and exposes var-list and key-value helpers', async () => { + const { result } = renderHook(() => useConfig('http-node', currentInputs)) + + await waitFor(() => { + expect(mockSetInputs).toHaveBeenCalledWith(expect.objectContaining({ + method: Method.get, + url: 'https://api.example.com', + body: { + type: BodyType.json, + data: [{ + type: BodyPayloadValueType.text, + value: '{"name":"alice"}', + }], + }, + ssl_verify: true, + })) + }) + + expect(result.current.isDataReady).toBe(true) + expect(result.current.readOnly).toBe(false) + expect(result.current.handleVarListChange).toBe(mockHandleVarListChange) + expect(result.current.handleAddVariable).toBe(mockHandleAddVariable) + expect(result.current.headers).toEqual([{ id: 'header-1', key: 'accept', value: 'application/json' }]) + expect(result.current.setHeaders).toBe(headerSetList) + expect(result.current.addHeader).toBe(headerAddItem) + expect(result.current.isHeaderKeyValueEdit).toBe(true) + expect(result.current.toggleIsHeaderKeyValueEdit).toBe(headerToggle) + expect(result.current.params).toEqual([{ id: 'param-1', key: 'page', value: '1' }]) + expect(result.current.setParams).toBe(paramSetList) + expect(result.current.addParam).toBe(paramAddItem) + expect(result.current.isParamKeyValueEdit).toBe(false) + expect(result.current.toggleIsParamKeyValueEdit).toBe(paramToggle) + expect(result.current.filterVar({ type: VarType.string } as never)).toBe(true) + expect(result.current.filterVar({ type: VarType.number } as never)).toBe(true) + expect(result.current.filterVar({ type: VarType.secret } as never)).toBe(true) + expect(result.current.filterVar({ type: VarType.file } as never)).toBe(false) + }) + + it('initializes empty body data arrays when the payload body is missing', async () => { + currentInputs = createPayload({ + body: { + type: BodyType.formData, + data: undefined as unknown as HttpNodeType['body']['data'], + }, + }) + + renderHook(() => useConfig('http-node', currentInputs)) + + await waitFor(() => { + expect(mockSetInputs).toHaveBeenCalledWith(expect.objectContaining({ + body: { + type: BodyType.formData, + data: [], + }, + })) + }) + }) + + it('updates request fields, authorization state, curl imports, and ssl verification', async () => { + const { result } = renderHook(() => useConfig('http-node', currentInputs)) + + await waitFor(() => { + expect(result.current.isDataReady).toBe(true) + }) + + mockSetInputs.mockClear() + + act(() => { + result.current.handleMethodChange(Method.delete) + result.current.handleUrlChange('https://changed.example.com') + headerFieldChange?.('x-token:123') + paramFieldChange?.('size:20') + result.current.setBody({ type: BodyType.rawText, data: 'raw payload' }) + result.current.showAuthorization() + }) + + expect(result.current.isShowAuthorization).toBe(true) + + act(() => { + result.current.hideAuthorization() + result.current.setAuthorization({ + type: AuthorizationType.apiKey, + config: { + type: APIType.bearer, + api_key: 'secret', + }, + }) + result.current.setTimeout({ connect: 30, read: 40, write: 50 }) + result.current.showCurlPanel() + }) + + expect(result.current.isShowCurlPanel).toBe(true) + + act(() => { + result.current.hideCurlPanel() + result.current.handleCurlImport(createPayload({ + method: Method.patch, + url: 'https://imported.example.com', + headers: 'authorization:Bearer imported', + params: 'debug:true', + body: { type: BodyType.json, data: [{ type: BodyPayloadValueType.text, value: '{"ok":true}' }] }, + })) + result.current.handleSSLVerifyChange(false) + }) + + expect(result.current.isShowAuthorization).toBe(false) + expect(result.current.isShowCurlPanel).toBe(false) + expect(mockSetInputs).toHaveBeenCalledWith(expect.objectContaining({ method: Method.delete })) + expect(mockSetInputs).toHaveBeenCalledWith(expect.objectContaining({ url: 'https://changed.example.com' })) + expect(mockSetInputs).toHaveBeenCalledWith(expect.objectContaining({ headers: 'x-token:123' })) + expect(mockSetInputs).toHaveBeenCalledWith(expect.objectContaining({ params: 'size:20' })) + expect(mockSetInputs).toHaveBeenCalledWith(expect.objectContaining({ + body: { type: BodyType.rawText, data: 'raw payload' }, + })) + expect(mockSetInputs).toHaveBeenCalledWith(expect.objectContaining({ + authorization: expect.objectContaining({ + type: AuthorizationType.apiKey, + }), + })) + expect(mockSetInputs).toHaveBeenCalledWith(expect.objectContaining({ + timeout: { connect: 30, read: 40, write: 50 }, + })) + expect(mockSetInputs).toHaveBeenCalledWith(expect.objectContaining({ + method: Method.patch, + url: 'https://imported.example.com', + headers: 'authorization:Bearer imported', + params: 'debug:true', + body: { type: BodyType.json, data: [{ type: BodyPayloadValueType.text, value: '{"ok":true}' }] }, + })) + expect(mockSetInputs).toHaveBeenCalledWith(expect.objectContaining({ ssl_verify: false })) + }) +}) diff --git a/web/app/components/workflow/nodes/human-input/__tests__/node.spec.tsx b/web/app/components/workflow/nodes/human-input/__tests__/node.spec.tsx new file mode 100644 index 0000000000..915f9136be --- /dev/null +++ b/web/app/components/workflow/nodes/human-input/__tests__/node.spec.tsx @@ -0,0 +1,83 @@ +import type { HumanInputNodeType } from '../types' +import { render, screen } from '@testing-library/react' +import { BlockEnum, InputVarType } from '@/app/components/workflow/types' +import Node from '../node' +import { DeliveryMethodType, UserActionButtonType } from '../types' + +vi.mock('../../_base/components/node-handle', () => ({ + NodeSourceHandle: (props: { handleId: string }) =>
{`handle:${props.handleId}`}
, +})) + +const createData = (overrides: Partial = {}): HumanInputNodeType => ({ + title: 'Human Input', + desc: '', + type: BlockEnum.HumanInput, + delivery_methods: [{ + id: 'dm-webapp', + type: DeliveryMethodType.WebApp, + enabled: true, + }, { + id: 'dm-email', + type: DeliveryMethodType.Email, + enabled: true, + }], + form_content: 'Please review this request', + inputs: [{ + type: InputVarType.textInput, + output_variable_name: 'review_result', + default: { + selector: [], + type: 'constant', + value: '', + }, + }], + user_actions: [{ + id: 'approve', + title: 'Approve', + button_style: UserActionButtonType.Primary, + }, { + id: 'reject', + title: 'Reject', + button_style: UserActionButtonType.Default, + }], + timeout: 3, + timeout_unit: 'day', + ...overrides, +}) + +describe('human-input/node', () => { + it('renders delivery methods, user action handles, and the timeout handle', () => { + render( + , + ) + + expect(screen.getByText('workflow.nodes.humanInput.deliveryMethod.title')).toBeInTheDocument() + expect(screen.getByText('webapp')).toBeInTheDocument() + expect(screen.getByText('email')).toBeInTheDocument() + expect(screen.getByText('approve')).toBeInTheDocument() + expect(screen.getByText('reject')).toBeInTheDocument() + expect(screen.getByText('Timeout')).toBeInTheDocument() + expect(screen.getByText('handle:approve')).toBeInTheDocument() + expect(screen.getByText('handle:reject')).toBeInTheDocument() + expect(screen.getByText('handle:__timeout')).toBeInTheDocument() + }) + + it('keeps the timeout handle when delivery methods and actions are empty', () => { + render( + , + ) + + expect(screen.queryByText('workflow.nodes.humanInput.deliveryMethod.title')).not.toBeInTheDocument() + expect(screen.getByText('Timeout')).toBeInTheDocument() + expect(screen.getByText('handle:__timeout')).toBeInTheDocument() + }) +}) diff --git a/web/app/components/workflow/nodes/human-input/__tests__/panel.spec.tsx b/web/app/components/workflow/nodes/human-input/__tests__/panel.spec.tsx new file mode 100644 index 0000000000..937a2da61a --- /dev/null +++ b/web/app/components/workflow/nodes/human-input/__tests__/panel.spec.tsx @@ -0,0 +1,386 @@ +import type { ReactNode } from 'react' +import type useConfig from '../hooks/use-config' +import type { HumanInputNodeType } from '../types' +import type { NodePanelProps } from '@/app/components/workflow/types' +import { render, screen } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import copy from 'copy-to-clipboard' +import { toast } from '@/app/components/base/ui/toast' +import { BlockEnum, InputVarType, VarType } from '@/app/components/workflow/types' +import Panel from '../panel' +import { DeliveryMethodType, UserActionButtonType } from '../types' + +const mockUseConfig = vi.hoisted(() => vi.fn()) +const mockUseStore = vi.hoisted(() => vi.fn()) +const mockUseAvailableVarList = vi.hoisted(() => vi.fn()) +const mockDeliveryMethod = vi.hoisted(() => vi.fn()) +const mockFormContent = vi.hoisted(() => vi.fn()) +const mockFormContentPreview = vi.hoisted(() => vi.fn()) +const mockTimeoutInput = vi.hoisted(() => vi.fn()) +const mockUserActionItem = vi.hoisted(() => vi.fn()) + +vi.mock('copy-to-clipboard', () => ({ + __esModule: true, + default: vi.fn(), +})) + +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: { + success: vi.fn(), + }, +})) + +vi.mock('@/app/components/base/tooltip', () => ({ + __esModule: true, + default: () =>
tooltip
, +})) + +vi.mock('@/app/components/base/action-button', () => ({ + __esModule: true, + default: (props: { + children: ReactNode + onClick: () => void + }) => ( + + ), +})) + +vi.mock('@/app/components/workflow/store', () => ({ + useStore: (selector: (state: { nodePanelWidth: number }) => unknown) => mockUseStore(selector), +})) + +vi.mock('@/app/components/workflow/nodes/_base/hooks/use-available-var-list', () => ({ + __esModule: true, + default: (...args: unknown[]) => mockUseAvailableVarList(...args), +})) + +vi.mock('../hooks/use-config', () => ({ + __esModule: true, + default: (...args: unknown[]) => mockUseConfig(...args), +})) + +vi.mock('../components/delivery-method', () => ({ + __esModule: true, + default: (props: { + readonly: boolean + onChange: (methods: HumanInputNodeType['delivery_methods']) => void + }) => { + mockDeliveryMethod(props) + return ( + + ) + }, +})) + +vi.mock('../components/form-content', () => ({ + __esModule: true, + default: (props: { + readonly: boolean + isExpand: boolean + onChange: (value: string) => void + onFormInputsChange: (value: HumanInputNodeType['inputs']) => void + onFormInputItemRename: (oldName: string, newName: string) => void + onFormInputItemRemove: (name: string) => void + }) => { + mockFormContent(props) + return ( +
+
{props.readonly ? 'form-content:readonly' : `form-content:${props.isExpand ? 'expanded' : 'collapsed'}`}
+ + + + +
+ ) + }, +})) + +vi.mock('../components/form-content-preview', () => ({ + __esModule: true, + default: (props: { + onClose: () => void + }) => { + mockFormContentPreview(props) + return ( +
+
form-preview
+ +
+ ) + }, +})) + +vi.mock('../components/timeout', () => ({ + __esModule: true, + default: (props: { + readonly: boolean + onChange: (value: { timeout: number, unit: 'hour' | 'day' }) => void + }) => { + mockTimeoutInput(props) + return ( + + ) + }, +})) + +vi.mock('../components/user-action', () => ({ + __esModule: true, + default: (props: { + readonly: boolean + data: HumanInputNodeType['user_actions'][number] + onChange: (value: HumanInputNodeType['user_actions'][number]) => void + onDelete: (id: string) => void + }) => { + mockUserActionItem(props) + return ( +
+
{`${props.data.id}:${props.readonly ? 'readonly' : 'editable'}`}
+ + +
+ ) + }, +})) + +vi.mock('@/app/components/workflow/nodes/_base/components/output-vars', () => ({ + __esModule: true, + default: (props: { + children: ReactNode + collapsed?: boolean + onCollapse?: (collapsed: boolean) => void + }) => ( +
+ + {props.children} +
+ ), + VarItem: ({ name, type, description }: { name: string, type: string, description: string }) => ( +
{`${name}:${type}:${description}`}
+ ), +})) + +vi.mock('@remixicon/react', () => ({ + RiAddLine: () => add-icon, + RiClipboardLine: () => clipboard-icon, + RiCollapseDiagonalLine: () => collapse-icon, + RiExpandDiagonalLine: () => expand-icon, + RiEyeLine: () => preview-icon, +})) + +const mockCopy = vi.mocked(copy) +const mockToastSuccess = vi.mocked(toast.success) + +const createData = (overrides: Partial = {}): HumanInputNodeType => ({ + title: 'Human Input', + desc: '', + type: BlockEnum.HumanInput, + delivery_methods: [{ + id: 'dm-webapp', + type: DeliveryMethodType.WebApp, + enabled: true, + }], + form_content: 'Please review this request', + inputs: [{ + type: InputVarType.textInput, + output_variable_name: 'review_result', + default: { + selector: [], + type: 'constant', + value: '', + }, + }], + user_actions: [{ + id: 'approve', + title: 'Approve', + button_style: UserActionButtonType.Primary, + }], + timeout: 3, + timeout_unit: 'day', + ...overrides, +}) + +const createConfigResult = (overrides: Partial> = {}): ReturnType => ({ + readOnly: false, + inputs: createData(), + handleDeliveryMethodChange: vi.fn(), + handleUserActionAdd: vi.fn(), + handleUserActionChange: vi.fn(), + handleUserActionDelete: vi.fn(), + handleTimeoutChange: vi.fn(), + handleFormContentChange: vi.fn(), + handleFormInputsChange: vi.fn(), + handleFormInputItemRename: vi.fn(), + handleFormInputItemRemove: vi.fn(), + editorKey: 1, + structuredOutputCollapsed: true, + setStructuredOutputCollapsed: vi.fn(), + ...overrides, +}) + +const renderPanel = (data: HumanInputNodeType = createData()) => { + const props: NodePanelProps = { + id: 'human-input-node', + data, + panelProps: { + getInputVars: vi.fn(() => []), + toVarInputs: vi.fn(() => []), + runInputData: {}, + runInputDataRef: { current: {} }, + setRunInputData: vi.fn(), + runResult: null, + }, + } + + return render() +} + +describe('human-input/panel', () => { + beforeEach(() => { + vi.clearAllMocks() + mockUseStore.mockImplementation(selector => selector({ nodePanelWidth: 480 })) + mockUseAvailableVarList.mockImplementation((_id, options?: { filterVar?: (payload: { type: VarType }) => boolean }) => ({ + availableVars: [{ + variable: ['start', 'email'], + type: VarType.string, + }, { + variable: ['start', 'files'], + type: VarType.file, + }].filter(variable => options?.filterVar ? options.filterVar({ type: variable.type } as never) : true), + availableNodesWithParent: [{ + id: 'start-node', + data: { + title: 'Start', + type: BlockEnum.Start, + }, + }], + })) + mockUseConfig.mockReturnValue(createConfigResult()) + }) + + it('renders editable controls, forwards updates, and toggles preview and output sections', async () => { + const user = userEvent.setup() + const config = createConfigResult() + mockUseConfig.mockReturnValue(config) + + const { container } = renderPanel() + + expect(screen.getByRole('button', { name: 'delivery-method:editable' })).toBeInTheDocument() + expect(screen.getByText('form-content:collapsed')).toBeInTheDocument() + expect(screen.getByText('approve:editable')).toBeInTheDocument() + expect(screen.getByText('review_result:string:Form input value')).toBeInTheDocument() + expect(screen.getByText('__action_id:string:Action ID user triggered')).toBeInTheDocument() + expect(screen.getByText('__rendered_content:string:Rendered content')).toBeInTheDocument() + + await user.click(screen.getByRole('button', { name: 'delivery-method:editable' })) + await user.click(screen.getByRole('button', { name: /workflow\.nodes\.humanInput\.formContent\.preview/ })) + await user.click(screen.getByRole('button', { name: 'change-form-content' })) + await user.click(screen.getByRole('button', { name: 'change-form-inputs' })) + await user.click(screen.getByRole('button', { name: 'rename-form-input' })) + await user.click(screen.getByRole('button', { name: 'remove-form-input' })) + await user.click(screen.getByRole('button', { name: 'action-button' })) + await user.click(screen.getByRole('button', { name: 'change-action-approve' })) + await user.click(screen.getByRole('button', { name: 'delete-action-approve' })) + await user.click(screen.getByRole('button', { name: 'timeout:editable' })) + await user.click(screen.getByRole('button', { name: 'toggle-output-vars' })) + await user.click(screen.getByRole('button', { name: 'close-preview' })) + + const iconContainers = container.querySelectorAll('div.flex.size-6.cursor-pointer') + await user.click(iconContainers[0] as HTMLElement) + await user.click(iconContainers[1] as HTMLElement) + + expect(config.handleDeliveryMethodChange).toHaveBeenCalledWith([{ + id: 'dm-email', + type: DeliveryMethodType.Email, + enabled: true, + }]) + expect(config.handleFormContentChange).toHaveBeenCalledWith('Updated content') + expect(config.handleFormInputsChange).toHaveBeenCalled() + expect(config.handleFormInputItemRename).toHaveBeenCalledWith('name', 'email') + expect(config.handleFormInputItemRemove).toHaveBeenCalledWith('name') + expect(config.handleUserActionAdd).toHaveBeenCalledWith({ + id: 'action_2', + title: 'Button Text 2', + button_style: UserActionButtonType.Default, + }) + expect(config.handleUserActionChange).toHaveBeenCalledWith(0, { + id: 'approve', + title: 'Approve updated', + button_style: UserActionButtonType.Primary, + }) + expect(config.handleUserActionDelete).toHaveBeenCalledWith('approve') + expect(config.handleTimeoutChange).toHaveBeenCalledWith({ timeout: 8, unit: 'hour' }) + expect(config.setStructuredOutputCollapsed).toHaveBeenCalledWith(false) + expect(mockCopy).toHaveBeenCalledWith('Please review this request') + expect(mockToastSuccess).toHaveBeenCalledWith('common.actionMsg.copySuccessfully') + expect(mockFormContentPreview).toHaveBeenCalled() + }) + + it('renders readonly and empty states without preview or add controls', () => { + mockUseConfig.mockReturnValue(createConfigResult({ + readOnly: true, + inputs: createData({ + user_actions: [], + }), + structuredOutputCollapsed: false, + })) + + renderPanel() + + expect(screen.getByRole('button', { name: 'delivery-method:readonly' })).toBeInTheDocument() + expect(screen.getByText('form-content:readonly')).toBeInTheDocument() + expect(screen.getByText('workflow.nodes.humanInput.userActions.emptyTip')).toBeInTheDocument() + expect(screen.queryByRole('button', { name: /workflow\.nodes\.humanInput\.formContent\.preview/ })).not.toBeInTheDocument() + expect(screen.queryByRole('button', { name: 'action-button' })).not.toBeInTheDocument() + expect(screen.getByRole('button', { name: 'timeout:readonly' })).toBeInTheDocument() + expect(screen.queryByText('form-preview')).not.toBeInTheDocument() + }) +}) diff --git a/web/app/components/workflow/nodes/human-input/components/delivery-method/__tests__/email-configure-modal.spec.tsx b/web/app/components/workflow/nodes/human-input/components/delivery-method/__tests__/email-configure-modal.spec.tsx new file mode 100644 index 0000000000..cec9ffe69a --- /dev/null +++ b/web/app/components/workflow/nodes/human-input/components/delivery-method/__tests__/email-configure-modal.spec.tsx @@ -0,0 +1,180 @@ +import type { EmailConfig } from '../../../types' +import { fireEvent, render, screen } from '@testing-library/react' +import EmailConfigureModal from '../email-configure-modal' + +const mockToastError = vi.hoisted(() => vi.fn()) +const mockUseAppContextSelector = vi.hoisted(() => vi.fn()) + +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: { + error: (message: string) => mockToastError(message), + }, +})) + +vi.mock('@/context/app-context', () => ({ + useSelector: (selector: (state: { userProfile: { email: string } }) => string) => + mockUseAppContextSelector(selector), +})) + +vi.mock('../mail-body-input', () => ({ + default: ({ value, onChange }: { value: string, onChange: (value: string) => void }) => ( +