diff --git a/.agents/skills/frontend-testing/SKILL.md b/.agents/skills/frontend-testing/SKILL.md index 4da070bdbf..105c979c58 100644 --- a/.agents/skills/frontend-testing/SKILL.md +++ b/.agents/skills/frontend-testing/SKILL.md @@ -200,7 +200,7 @@ When assigned to test a directory/path, test **ALL content** within that path: - ✅ **Import real project components** directly (including base components and siblings) - ✅ **Only mock**: API services (`@/service/*`), `next/navigation`, complex context providers -- ❌ **DO NOT mock** base components (`@/app/components/base/*`) +- ❌ **DO NOT mock** base components (`@/app/components/base/*`) or dify-ui primitives (`@langgenius/dify-ui/*`) - ❌ **DO NOT mock** sibling/child components in the same directory > See [Test Structure Template](#test-structure-template) for correct import/mock patterns. @@ -325,12 +325,12 @@ For more detailed information, refer to: ### Reference Examples in Codebase - `web/utils/classnames.spec.ts` - Utility function tests -- `web/app/components/base/button/index.spec.tsx` - Component tests +- `web/app/components/base/radio/__tests__/index.spec.tsx` - Component tests - `web/__mocks__/provider-context.ts` - Mock factory example ### Project Configuration -- `web/vitest.config.ts` - Vitest configuration +- `web/vite.config.ts` - Vite/Vitest configuration - `web/vitest.setup.ts` - Test environment setup - `web/scripts/analyze-component.js` - Component analysis tool - Modules are not mocked automatically. Global mocks live in `web/vitest.setup.ts` (for example `react-i18next`, `next/image`); mock other modules like `ky` or `mime` locally in test files. diff --git a/.agents/skills/frontend-testing/references/checklist.md b/.agents/skills/frontend-testing/references/checklist.md index 10b8fb66f9..99258498dd 100644 --- a/.agents/skills/frontend-testing/references/checklist.md +++ b/.agents/skills/frontend-testing/references/checklist.md @@ -36,7 +36,7 @@ Use this checklist when generating or reviewing tests for Dify frontend componen ### Integration vs Mocking -- [ ] **DO NOT mock base components** (`Loading`, `Button`, `Tooltip`, etc.) +- [ ] **DO NOT mock base components or dify-ui primitives** (base `Loading`, `Input`, `Badge`; dify-ui `Button`, `Tooltip`, `Dialog`, etc.) - [ ] Import real project components instead of mocking - [ ] Only mock: API calls, complex context providers, third-party libs with side effects - [ ] Prefer integration testing when using single spec file @@ -73,7 +73,7 @@ Use this checklist when generating or reviewing tests for Dify frontend componen ### Mocks -- [ ] **DO NOT mock base components** (`@/app/components/base/*`) +- [ ] **DO NOT mock base components or dify-ui primitives** (`@/app/components/base/*` or `@langgenius/dify-ui/*`) - [ ] `vi.clearAllMocks()` in `beforeEach` (not `afterEach`) - [ ] Shared mock state reset in `beforeEach` - [ ] i18n uses global mock (auto-loaded in `web/vitest.setup.ts`); only override locally for custom translations diff --git a/.agents/skills/frontend-testing/references/mocking.md b/.agents/skills/frontend-testing/references/mocking.md index f58377c4a5..8c2f1c0c58 100644 --- a/.agents/skills/frontend-testing/references/mocking.md +++ b/.agents/skills/frontend-testing/references/mocking.md @@ -2,29 +2,27 @@ ## ⚠️ Important: What NOT to Mock -### DO NOT Mock Base Components +### DO NOT Mock Base Components or dify-ui Primitives -**Never mock components from `@/app/components/base/`** such as: +**Never mock components from `@/app/components/base/` or from `@langgenius/dify-ui/*`** such as: -- `Loading`, `Spinner` -- `Button`, `Input`, `Select` -- `Tooltip`, `Modal`, `Dropdown` -- `Icon`, `Badge`, `Tag` +- Legacy base (`@/app/components/base/*`): `Loading`, `Spinner`, `Input`, `Badge`, `Tag` +- dify-ui primitives (`@langgenius/dify-ui/*`): `Button`, `Tooltip`, `Dialog`, `Popover`, `DropdownMenu`, `ContextMenu`, `Select`, `AlertDialog`, `Toast` **Why?** -- Base components will have their own dedicated tests +- These components have their own dedicated tests - Mocking them creates false positives (tests pass but real integration fails) - Using real components tests actual integration behavior ```typescript -// ❌ WRONG: Don't mock base components +// ❌ WRONG: Don't mock base components or dify-ui primitives vi.mock('@/app/components/base/loading', () => () =>
Loading
) -vi.mock('@/app/components/base/button', () => ({ children }: any) => ) +vi.mock('@langgenius/dify-ui/button', () => ({ Button: ({ children }: any) => })) -// ✅ CORRECT: Import and use real base components +// ✅ CORRECT: Import and use the real components import Loading from '@/app/components/base/loading' -import Button from '@/app/components/base/button' +import { Button } from '@langgenius/dify-ui/button' // They will render normally in tests ``` @@ -319,7 +317,7 @@ const renderWithQueryClient = (ui: React.ReactElement) => { ### ✅ DO -1. **Use real base components** - Import from `@/app/components/base/` directly +1. **Use real base components and dify-ui primitives** - Import from `@/app/components/base/` or `@langgenius/dify-ui/*` directly 1. **Use real project components** - Prefer importing over mocking 1. **Use real Zustand stores** - Set test state via `store.setState()` 1. **Reset mocks in `beforeEach`**, not `afterEach` @@ -330,7 +328,7 @@ const renderWithQueryClient = (ui: React.ReactElement) => { ### ❌ DON'T -1. **Don't mock base components** (`Loading`, `Button`, `Tooltip`, etc.) +1. **Don't mock base components or dify-ui primitives** (`Loading`, `Input`, `Button`, `Tooltip`, `Dialog`, etc.) 1. **Don't mock Zustand store modules** - Use real stores with `setState()` 1. Don't mock components you can import directly 1. Don't create overly simplified mocks that miss conditional logic @@ -342,7 +340,7 @@ const renderWithQueryClient = (ui: React.ReactElement) => { ``` Need to use a component in test? │ -├─ Is it from @/app/components/base/*? +├─ Is it from @/app/components/base/* or @langgenius/dify-ui/*? │ └─ YES → Import real component, DO NOT mock │ ├─ Is it a project component? diff --git a/.github/workflows/anti-slop.yml b/.github/workflows/anti-slop.yml deleted file mode 100644 index b0f0a36bc9..0000000000 --- a/.github/workflows/anti-slop.yml +++ /dev/null @@ -1,19 +0,0 @@ -name: Anti-Slop PR Check - -on: - pull_request_target: - types: [opened, edited, synchronize] - -permissions: - pull-requests: write - contents: read - -jobs: - anti-slop: - runs-on: ubuntu-latest - steps: - - uses: peakoss/anti-slop@85daca1880e9e1af197fc06ea03349daf08f4202 # v0.2.1 - with: - github-token: ${{ secrets.GITHUB_TOKEN }} - close-pr: false - failure-add-pr-labels: "needs-revision" diff --git a/.github/workflows/api-tests.yml b/.github/workflows/api-tests.yml index fd910531db..717413937f 100644 --- a/.github/workflows/api-tests.yml +++ b/.github/workflows/api-tests.yml @@ -35,7 +35,7 @@ jobs: persist-credentials: false - name: Setup UV and Python - uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0 + uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0 with: enable-cache: true python-version: ${{ matrix.python-version }} @@ -84,7 +84,7 @@ jobs: persist-credentials: false - name: Setup UV and Python - uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0 + uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0 with: enable-cache: true python-version: ${{ matrix.python-version }} @@ -105,7 +105,7 @@ jobs: run: sh .github/workflows/expose_service_ports.sh - name: Set up Sandbox - uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0 + uses: hoverkraft-tech/compose-action@d2bee4f07e8ca410d6b196d00f90c12e7d48c33a # v2.6.0 with: compose-file: | docker/docker-compose.middleware.yaml @@ -156,7 +156,7 @@ jobs: persist-credentials: false - name: Setup UV and Python - uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0 + uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0 with: enable-cache: true python-version: "3.12" diff --git a/.github/workflows/autofix.yml b/.github/workflows/autofix.yml index 772ab8dd56..35683b112f 100644 --- a/.github/workflows/autofix.yml +++ b/.github/workflows/autofix.yml @@ -25,7 +25,7 @@ jobs: - name: Check Docker Compose inputs if: github.event_name != 'merge_group' id: docker-compose-changes - uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5 + uses: tj-actions/changed-files@9426d40962ed5378910ee2e21d5f8c6fcbf2dd96 # v47.0.6 with: files: | docker/generate_docker_compose @@ -35,7 +35,7 @@ jobs: - name: Check web inputs if: github.event_name != 'merge_group' id: web-changes - uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5 + uses: tj-actions/changed-files@9426d40962ed5378910ee2e21d5f8c6fcbf2dd96 # v47.0.6 with: files: | web/** @@ -48,7 +48,7 @@ jobs: - name: Check api inputs if: github.event_name != 'merge_group' id: api-changes - uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5 + uses: tj-actions/changed-files@9426d40962ed5378910ee2e21d5f8c6fcbf2dd96 # v47.0.6 with: files: | api/** @@ -58,7 +58,7 @@ jobs: python-version: "3.11" - if: github.event_name != 'merge_group' - uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0 + uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0 - name: Generate Docker Compose if: github.event_name != 'merge_group' && steps.docker-compose-changes.outputs.any_changed == 'true' @@ -120,8 +120,7 @@ jobs: - name: ESLint autofix if: github.event_name != 'merge_group' && steps.web-changes.outputs.any_changed == 'true' run: | - cd web vp exec eslint --concurrency=2 --prune-suppressions --quiet || true - if: github.event_name != 'merge_group' - uses: autofix-ci/action@7a166d7532b277f34e16238930461bf77f9d7ed8 # v1.3.3 + uses: autofix-ci/action@c5b2d67aa2274e7b5a18224e8171550871fc7e4a # v1.3.4 diff --git a/.github/workflows/db-migration-test.yml b/.github/workflows/db-migration-test.yml index 5991abe3ba..17b867dd6d 100644 --- a/.github/workflows/db-migration-test.yml +++ b/.github/workflows/db-migration-test.yml @@ -19,7 +19,7 @@ jobs: persist-credentials: false - name: Setup UV and Python - uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0 + uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0 with: enable-cache: true python-version: "3.12" @@ -40,7 +40,7 @@ jobs: cp middleware.env.example middleware.env - name: Set up Middlewares - uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0 + uses: hoverkraft-tech/compose-action@d2bee4f07e8ca410d6b196d00f90c12e7d48c33a # v2.6.0 with: compose-file: | docker/docker-compose.middleware.yaml @@ -69,7 +69,7 @@ jobs: persist-credentials: false - name: Setup UV and Python - uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0 + uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0 with: enable-cache: true python-version: "3.12" @@ -94,7 +94,7 @@ jobs: sed -i 's/DB_USERNAME=postgres/DB_USERNAME=mysql/' middleware.env - name: Set up Middlewares - uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0 + uses: hoverkraft-tech/compose-action@d2bee4f07e8ca410d6b196d00f90c12e7d48c33a # v2.6.0 with: compose-file: | docker/docker-compose.middleware.yaml diff --git a/.github/workflows/pyrefly-diff-comment.yml b/.github/workflows/pyrefly-diff-comment.yml index eefb1ebbb9..c55b013dbe 100644 --- a/.github/workflows/pyrefly-diff-comment.yml +++ b/.github/workflows/pyrefly-diff-comment.yml @@ -76,13 +76,11 @@ jobs: diff += '\\n\\n... (truncated) ...'; } - const body = diff.trim() - ? '### Pyrefly Diff\n
\nbase → PR\n\n```diff\n' + diff + '\n```\n
' - : '### Pyrefly Diff\nNo changes detected.'; - - await github.rest.issues.createComment({ - issue_number: prNumber, - owner: context.repo.owner, - repo: context.repo.repo, - body, - }); + if (diff.trim()) { + await github.rest.issues.createComment({ + issue_number: prNumber, + owner: context.repo.owner, + repo: context.repo.repo, + body: '### Pyrefly Diff\n
\nbase → PR\n\n```diff\n' + diff + '\n```\n
', + }); + } diff --git a/.github/workflows/pyrefly-diff.yml b/.github/workflows/pyrefly-diff.yml index ac3732579c..eb15cd6f75 100644 --- a/.github/workflows/pyrefly-diff.yml +++ b/.github/workflows/pyrefly-diff.yml @@ -22,7 +22,7 @@ jobs: fetch-depth: 0 - name: Setup Python & UV - uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0 + uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0 with: enable-cache: true diff --git a/.github/workflows/pyrefly-type-coverage-comment.yml b/.github/workflows/pyrefly-type-coverage-comment.yml index 974da99aad..3c6c96a664 100644 --- a/.github/workflows/pyrefly-type-coverage-comment.yml +++ b/.github/workflows/pyrefly-type-coverage-comment.yml @@ -24,7 +24,7 @@ jobs: uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Setup Python & UV - uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0 + uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0 with: enable-cache: true diff --git a/.github/workflows/pyrefly-type-coverage.yml b/.github/workflows/pyrefly-type-coverage.yml index c795c32e31..0599c94eef 100644 --- a/.github/workflows/pyrefly-type-coverage.yml +++ b/.github/workflows/pyrefly-type-coverage.yml @@ -22,7 +22,7 @@ jobs: fetch-depth: 0 - name: Setup Python & UV - uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0 + uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0 with: enable-cache: true diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml index c32fc9d0cb..d8c7ebbad3 100644 --- a/.github/workflows/style.yml +++ b/.github/workflows/style.yml @@ -25,7 +25,7 @@ jobs: - name: Check changed files id: changed-files - uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5 + uses: tj-actions/changed-files@9426d40962ed5378910ee2e21d5f8c6fcbf2dd96 # v47.0.6 with: files: | api/** @@ -33,7 +33,7 @@ jobs: - name: Setup UV and Python if: steps.changed-files.outputs.any_changed == 'true' - uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0 + uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0 with: enable-cache: false python-version: "3.12" @@ -73,10 +73,12 @@ jobs: - name: Check changed files id: changed-files - uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5 + uses: tj-actions/changed-files@9426d40962ed5378910ee2e21d5f8c6fcbf2dd96 # v47.0.6 with: files: | web/** + e2e/** + sdks/nodejs-client/** packages/** package.json pnpm-lock.yaml @@ -93,16 +95,16 @@ jobs: - name: Restore ESLint cache if: steps.changed-files.outputs.any_changed == 'true' id: eslint-cache-restore - uses: actions/cache/restore@668228422ae6a00e4ad889ee87cd7109ec5666a7 # v5.0.4 + uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5 with: - path: web/.eslintcache - key: ${{ runner.os }}-web-eslint-${{ hashFiles('web/package.json', 'pnpm-lock.yaml', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-${{ github.sha }} + path: .eslintcache + key: ${{ runner.os }}-eslint-${{ hashFiles('pnpm-lock.yaml', 'eslint.config.mjs', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-${{ github.sha }} restore-keys: | - ${{ runner.os }}-web-eslint-${{ hashFiles('web/package.json', 'pnpm-lock.yaml', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}- + ${{ runner.os }}-eslint-${{ hashFiles('pnpm-lock.yaml', 'eslint.config.mjs', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}- - name: Web style check if: steps.changed-files.outputs.any_changed == 'true' - working-directory: ./web + working-directory: . run: vp run lint:ci - name: Web tsslint @@ -112,7 +114,7 @@ jobs: - name: Web type check if: steps.changed-files.outputs.any_changed == 'true' - working-directory: ./web + working-directory: . run: vp run type-check - name: Web dead code check @@ -122,9 +124,9 @@ jobs: - name: Save ESLint cache if: steps.changed-files.outputs.any_changed == 'true' && success() && steps.eslint-cache-restore.outputs.cache-hit != 'true' - uses: actions/cache/save@668228422ae6a00e4ad889ee87cd7109ec5666a7 # v5.0.4 + uses: actions/cache/save@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5 with: - path: web/.eslintcache + path: .eslintcache key: ${{ steps.eslint-cache-restore.outputs.cache-primary-key }} superlinter: @@ -140,7 +142,7 @@ jobs: - name: Check changed files id: changed-files - uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5 + uses: tj-actions/changed-files@9426d40962ed5378910ee2e21d5f8c6fcbf2dd96 # v47.0.6 with: files: | **.sh diff --git a/.github/workflows/tool-test-sdks.yaml b/.github/workflows/tool-test-sdks.yaml index 467f31fccf..bf33207a14 100644 --- a/.github/workflows/tool-test-sdks.yaml +++ b/.github/workflows/tool-test-sdks.yaml @@ -30,7 +30,7 @@ jobs: persist-credentials: false - name: Use Node.js - uses: actions/setup-node@53b83947a5a98c8d113130e565377fae1a50d02f # v6.3.0 + uses: actions/setup-node@48b55a011bda9f5d6aeb4c2d9c7362e8dae4041e # v6.4.0 with: node-version: 22 cache: '' diff --git a/.github/workflows/translate-i18n-claude.yml b/.github/workflows/translate-i18n-claude.yml index 541200293d..eecbbb1a56 100644 --- a/.github/workflows/translate-i18n-claude.yml +++ b/.github/workflows/translate-i18n-claude.yml @@ -158,7 +158,7 @@ jobs: - name: Run Claude Code for Translation Sync if: steps.context.outputs.CHANGED_FILES != '' - uses: anthropics/claude-code-action@b47fd721da662d48c5680e154ad16a73ed74d2e0 # v1.0.93 + uses: anthropics/claude-code-action@38ec876110f9fbf8b950c79f534430740c3ac009 # v1.0.101 with: anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }} github_token: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/vdb-tests-full.yml b/.github/workflows/vdb-tests-full.yml index f0def8fe7a..b79e8927d7 100644 --- a/.github/workflows/vdb-tests-full.yml +++ b/.github/workflows/vdb-tests-full.yml @@ -36,7 +36,7 @@ jobs: remove_tool_cache: true - name: Setup UV and Python - uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0 + uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0 with: enable-cache: true python-version: ${{ matrix.python-version }} @@ -65,7 +65,7 @@ jobs: # tiflash - name: Set up Full Vector Store Matrix - uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0 + uses: hoverkraft-tech/compose-action@d2bee4f07e8ca410d6b196d00f90c12e7d48c33a # v2.6.0 with: compose-file: | docker/docker-compose.yaml diff --git a/.github/workflows/vdb-tests.yml b/.github/workflows/vdb-tests.yml index f3966f15b9..bd13d662c3 100644 --- a/.github/workflows/vdb-tests.yml +++ b/.github/workflows/vdb-tests.yml @@ -33,7 +33,7 @@ jobs: remove_tool_cache: true - name: Setup UV and Python - uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0 + uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0 with: enable-cache: true python-version: ${{ matrix.python-version }} @@ -62,7 +62,7 @@ jobs: # tiflash - name: Set up Vector Stores for Smoke Coverage - uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0 + uses: hoverkraft-tech/compose-action@d2bee4f07e8ca410d6b196d00f90c12e7d48c33a # v2.6.0 with: compose-file: | docker/docker-compose.yaml diff --git a/.github/workflows/web-e2e.yml b/.github/workflows/web-e2e.yml index 10dc31bde8..6bd4d4f406 100644 --- a/.github/workflows/web-e2e.yml +++ b/.github/workflows/web-e2e.yml @@ -28,7 +28,7 @@ jobs: uses: ./.github/actions/setup-web - name: Setup UV and Python - uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0 + uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0 with: enable-cache: true python-version: "3.12" diff --git a/.github/workflows/web-tests.yml b/.github/workflows/web-tests.yml index f3ab4c62c7..2a5cf19645 100644 --- a/.github/workflows/web-tests.yml +++ b/.github/workflows/web-tests.yml @@ -89,3 +89,37 @@ jobs: flags: web env: CODECOV_TOKEN: ${{ env.CODECOV_TOKEN }} + + dify-ui-test: + name: dify-ui Tests + runs-on: ubuntu-latest + env: + CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} + defaults: + run: + shell: bash + working-directory: ./packages/dify-ui + + steps: + - name: Checkout code + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + with: + persist-credentials: false + + - name: Setup web environment + uses: ./.github/actions/setup-web + + - name: Install Chromium for Browser Mode + run: vp exec playwright install --with-deps chromium + + - name: Run dify-ui tests + run: vp test run --coverage --silent=passed-only + + - name: Report coverage + if: ${{ env.CODECOV_TOKEN != '' }} + uses: codecov/codecov-action@57e3a136b779b570ffcdbf80b3bdc90e7fab3de2 # v6.0.0 + with: + directory: packages/dify-ui/coverage + flags: dify-ui + env: + CODECOV_TOKEN: ${{ env.CODECOV_TOKEN }} diff --git a/.gitignore b/.gitignore index 53dea88899..836bddbb49 100644 --- a/.gitignore +++ b/.gitignore @@ -203,6 +203,7 @@ sdks/python-client/dify_client.egg-info .vscode/* !.vscode/launch.json.template +!.vscode/settings.example.json !.vscode/README.md api/.vscode # vscode Code History Extension @@ -236,9 +237,15 @@ scripts/stress-test/reports/ .playwright-mcp/ .serena/ +# vitest browser mode attachments (failure screenshots, traces, etc.) +.vitest-attachments/ +**/__screenshots__/ + # settings *.local.json *.local.md # Code Agent Folder .qoder/* + +.eslintcache diff --git a/.vite-hooks/pre-commit b/.vite-hooks/pre-commit index 13bbd81cf6..d48381bce2 100755 --- a/.vite-hooks/pre-commit +++ b/.vite-hooks/pre-commit @@ -56,44 +56,9 @@ if $api_modified; then fi fi -if $web_modified; then - if $skip_web_checks; then - echo "Git operation in progress, skipping web checks" - exit 0 - fi - - echo "Running ESLint on web module" - - if git diff --cached --quiet -- 'web/**/*.ts' 'web/**/*.tsx'; then - web_ts_modified=false - else - ts_diff_status=$? - if [ $ts_diff_status -eq 1 ]; then - web_ts_modified=true - else - echo "Unable to determine staged TypeScript changes (git exit code: $ts_diff_status)." - exit $ts_diff_status - fi - fi - - cd ./web || exit 1 - vp staged - - if $web_ts_modified; then - echo "Running TypeScript type-check:tsgo" - if ! npm run type-check:tsgo; then - echo "Type check failed. Please run 'npm run type-check:tsgo' to fix the errors." - exit 1 - fi - else - echo "No staged TypeScript changes detected, skipping type-check:tsgo" - fi - - echo "Running knip" - if ! npm run knip; then - echo "Knip check failed. Please run 'npm run knip' to fix the errors." - exit 1 - fi - - cd ../ +if $skip_web_checks; then + echo "Git operation in progress, skipping web checks" + exit 0 fi + +vp staged diff --git a/web/.vscode/settings.example.json b/.vscode/settings.example.json similarity index 86% rename from web/.vscode/settings.example.json rename to .vscode/settings.example.json index 4b356f5b7a..7cdbc51a3b 100644 --- a/web/.vscode/settings.example.json +++ b/.vscode/settings.example.json @@ -1,12 +1,16 @@ { - // Disable the default formatter, use eslint instead - "prettier.enable": false, - "editor.formatOnSave": false, + "cucumber.features": [ + "e2e/features/**/*.feature", + ], + "cucumber.glue": [ + "e2e/features/**/*.ts", + ], + + "tailwindCSS.experimental.configFile": "web/app/styles/globals.css", // Auto fix "editor.codeActionsOnSave": { "source.fixAll.eslint": "explicit", - "source.organizeImports": "never" }, // Silent the stylistic rules in your IDE, but still auto fix them diff --git a/api/commands/account.py b/api/commands/account.py index 6a2a2e0428..761323a73d 100644 --- a/api/commands/account.py +++ b/api/commands/account.py @@ -2,6 +2,7 @@ import base64 import secrets import click +from sqlalchemy.orm import Session from constants.languages import languages from extensions.ext_database import db @@ -43,10 +44,11 @@ def reset_password(email, new_password, password_confirm): # encrypt password with salt password_hashed = hash_password(new_password, salt) base64_password_hashed = base64.b64encode(password_hashed).decode() - account = db.session.merge(account) - account.password = base64_password_hashed - account.password_salt = base64_salt - db.session.commit() + with Session(db.engine) as session: + account = session.merge(account) + account.password = base64_password_hashed + account.password_salt = base64_salt + session.commit() AccountService.reset_login_error_rate_limit(normalized_email) click.echo(click.style("Password reset successfully.", fg="green")) @@ -77,9 +79,10 @@ def reset_email(email, new_email, email_confirm): click.echo(click.style(f"Invalid email: {new_email}", fg="red")) return - account = db.session.merge(account) - account.email = normalized_new_email - db.session.commit() + with Session(db.engine) as session: + account = session.merge(account) + account.email = normalized_new_email + session.commit() click.echo(click.style("Email updated successfully.", fg="green")) diff --git a/api/constants/dsl_version.py b/api/constants/dsl_version.py new file mode 100644 index 0000000000..b0fbe0075c --- /dev/null +++ b/api/constants/dsl_version.py @@ -0,0 +1 @@ +CURRENT_APP_DSL_VERSION = "0.6.0" diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 051d08aa36..9102983d86 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -129,6 +129,7 @@ class AppNamePayload(BaseModel): class AppIconPayload(BaseModel): icon: str | None = Field(default=None, description="Icon data") + icon_type: IconType | None = Field(default=None, description="Icon type") icon_background: str | None = Field(default=None, description="Icon background color") @@ -729,7 +730,12 @@ class AppIconApi(Resource): args = AppIconPayload.model_validate(console_ns.payload or {}) app_service = AppService() - app_model = app_service.update_app_icon(app_model, args.icon or "", args.icon_background or "") + app_model = app_service.update_app_icon( + app_model, + args.icon or "", + args.icon_background or "", + args.icon_type, + ) response_model = AppDetail.model_validate(app_model, from_attributes=True) return response_model.model_dump(mode="json") diff --git a/api/controllers/console/app/conversation_variables.py b/api/controllers/console/app/conversation_variables.py index cead33d14f..9c8b095b9f 100644 --- a/api/controllers/console/app/conversation_variables.py +++ b/api/controllers/console/app/conversation_variables.py @@ -45,7 +45,7 @@ class ConversationVariableResponse(ResponseModel): def _normalize_value_type(cls, value: Any) -> str: exposed_type = getattr(value, "exposed_type", None) if callable(exposed_type): - return str(exposed_type().value) + return str(exposed_type()) if isinstance(value, str): return value try: diff --git a/api/controllers/console/app/mcp_server.py b/api/controllers/console/app/mcp_server.py index 5b1abc98dc..d517f695b8 100644 --- a/api/controllers/console/app/mcp_server.py +++ b/api/controllers/console/app/mcp_server.py @@ -18,12 +18,6 @@ from models.enums import AppMCPServerStatus from models.model import AppMCPServer -def _to_timestamp(value: datetime | int | None) -> int | None: - if isinstance(value, datetime): - return int(value.timestamp()) - return value - - class MCPServerCreatePayload(BaseModel): description: str | None = Field(default=None, description="Server description") parameters: dict[str, Any] = Field(..., description="Server parameters configuration") @@ -36,19 +30,25 @@ class MCPServerUpdatePayload(BaseModel): status: str | None = Field(default=None, description="Server status") +def _to_timestamp(value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return int(value.timestamp()) + return value + + class AppMCPServerResponse(ResponseModel): id: str name: str server_code: str description: str - status: str + status: AppMCPServerStatus parameters: dict[str, Any] | list[Any] | str created_at: int | None = None updated_at: int | None = None @field_validator("parameters", mode="before") @classmethod - def _parse_json_string(cls, value: Any) -> Any: + def _normalize_parameters(cls, value: Any) -> Any: if isinstance(value, str): try: return json.loads(value) @@ -70,7 +70,9 @@ class AppMCPServerController(Resource): @console_ns.doc("get_app_mcp_server") @console_ns.doc(description="Get MCP server configuration for an application") @console_ns.doc(params={"app_id": "Application ID"}) - @console_ns.response(200, "Server configuration", console_ns.models[AppMCPServerResponse.__name__]) + @console_ns.response( + 200, "MCP server configuration retrieved successfully", console_ns.models[AppMCPServerResponse.__name__] + ) @login_required @account_initialization_required @setup_required @@ -85,7 +87,9 @@ class AppMCPServerController(Resource): @console_ns.doc(description="Create MCP server configuration for an application") @console_ns.doc(params={"app_id": "Application ID"}) @console_ns.expect(console_ns.models[MCPServerCreatePayload.__name__]) - @console_ns.response(200, "Server created", console_ns.models[AppMCPServerResponse.__name__]) + @console_ns.response( + 201, "MCP server configuration created successfully", console_ns.models[AppMCPServerResponse.__name__] + ) @console_ns.response(403, "Insufficient permissions") @account_initialization_required @get_app_model @@ -111,13 +115,15 @@ class AppMCPServerController(Resource): ) db.session.add(server) db.session.commit() - return AppMCPServerResponse.model_validate(server, from_attributes=True).model_dump(mode="json") + return AppMCPServerResponse.model_validate(server, from_attributes=True).model_dump(mode="json"), 201 @console_ns.doc("update_app_mcp_server") @console_ns.doc(description="Update MCP server configuration for an application") @console_ns.doc(params={"app_id": "Application ID"}) @console_ns.expect(console_ns.models[MCPServerUpdatePayload.__name__]) - @console_ns.response(200, "Server updated", console_ns.models[AppMCPServerResponse.__name__]) + @console_ns.response( + 200, "MCP server configuration updated successfully", console_ns.models[AppMCPServerResponse.__name__] + ) @console_ns.response(403, "Insufficient permissions") @console_ns.response(404, "Server not found") @get_app_model @@ -154,7 +160,7 @@ class AppMCPServerRefreshController(Resource): @console_ns.doc("refresh_app_mcp_server") @console_ns.doc(description="Refresh MCP server configuration and regenerate server code") @console_ns.doc(params={"server_id": "Server ID"}) - @console_ns.response(200, "Server refreshed", console_ns.models[AppMCPServerResponse.__name__]) + @console_ns.response(200, "MCP server refreshed successfully", console_ns.models[AppMCPServerResponse.__name__]) @console_ns.response(403, "Insufficient permissions") @console_ns.response(404, "Server not found") @setup_required diff --git a/api/controllers/console/app/workflow_draft_variable.py b/api/controllers/console/app/workflow_draft_variable.py index f6319573e0..e32ba5f66c 100644 --- a/api/controllers/console/app/workflow_draft_variable.py +++ b/api/controllers/console/app/workflow_draft_variable.py @@ -102,7 +102,7 @@ def _serialize_var_value(variable: WorkflowDraftVariable): def _serialize_variable_type(workflow_draft_var: WorkflowDraftVariable) -> str: value_type = workflow_draft_var.value_type - return value_type.exposed_type().value + return str(value_type.exposed_type()) class FullContentDict(TypedDict): @@ -122,7 +122,7 @@ def _serialize_full_content(variable: WorkflowDraftVariable) -> FullContentDict result: FullContentDict = { "size_bytes": variable_file.size, - "value_type": variable_file.value_type.exposed_type().value, + "value_type": str(variable_file.value_type.exposed_type()), "length": variable_file.length, "download_url": file_helpers.get_signed_file_url(variable_file.upload_file_id, as_attachment=True), } @@ -598,7 +598,7 @@ class EnvironmentVariableCollectionApi(Resource): "name": v.name, "description": v.description, "selector": v.selector, - "value_type": v.value_type.exposed_type().value, + "value_type": str(v.value_type.exposed_type()), "value": v.value, # Do not track edited for env vars. "edited": False, diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index ea0fdef0a7..d001dfba64 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -50,6 +50,7 @@ from fields.dataset_fields import ( from fields.document_fields import document_status_fields from graphon.model_runtime.entities.model_entities import ModelType from libs.login import current_account_with_tenant, login_required +from libs.url_utils import normalize_api_base_url from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile from models.dataset import DatasetPermission, DatasetPermissionEnum from models.enums import ApiTokenType, SegmentStatus @@ -889,7 +890,8 @@ class DatasetApiBaseUrlApi(Resource): @login_required @account_initialization_required def get(self): - return {"api_base_url": (dify_config.SERVICE_API_URL or request.host_url.rstrip("/")) + "/v1"} + base = dify_config.SERVICE_API_URL or request.host_url.rstrip("/") + return {"api_base_url": normalize_api_base_url(base)} @console_ns.route("/datasets/retrieval-setting") diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index b61e39716f..3372a967d9 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -3,18 +3,19 @@ import logging from argparse import ArgumentTypeError from collections.abc import Sequence from contextlib import ExitStack +from datetime import datetime from typing import Any, Literal, cast import sqlalchemy as sa from flask import request, send_file -from flask_restx import Resource, fields, marshal, marshal_with -from pydantic import BaseModel, Field +from flask_restx import Resource, marshal +from pydantic import BaseModel, Field, field_validator from sqlalchemy import asc, desc, func, select from werkzeug.exceptions import Forbidden, NotFound import services from controllers.common.controller_schemas import DocumentBatchDownloadZipPayload -from controllers.common.schema import get_or_create_model, register_schema_models +from controllers.common.schema import register_schema_models from controllers.console import console_ns from core.errors.error import ( LLMBadRequestError, @@ -29,11 +30,9 @@ from core.rag.extractor.entity.datasource_type import DatasourceType from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo from core.rag.index_processor.constant.index_type import IndexTechniqueType from extensions.ext_database import db -from fields.dataset_fields import dataset_fields +from fields.base import ResponseModel from fields.document_fields import ( - dataset_and_document_fields, document_fields, - document_metadata_fields, document_status_fields, document_with_segments_fields, ) @@ -72,27 +71,100 @@ from ..wraps import ( logger = logging.getLogger(__name__) -# Register models for flask_restx to avoid dict type issues in Swagger -dataset_model = get_or_create_model("Dataset", dataset_fields) +def _to_timestamp(value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return int(value.timestamp()) + return value -document_metadata_model = get_or_create_model("DocumentMetadata", document_metadata_fields) -document_fields_copy = document_fields.copy() -document_fields_copy["doc_metadata"] = fields.List( - fields.Nested(document_metadata_model), attribute="doc_metadata_details" -) -document_model = get_or_create_model("Document", document_fields_copy) +def _normalize_enum(value: Any) -> Any: + if isinstance(value, str) or value is None: + return value + return getattr(value, "value", value) -document_with_segments_fields_copy = document_with_segments_fields.copy() -document_with_segments_fields_copy["doc_metadata"] = fields.List( - fields.Nested(document_metadata_model), attribute="doc_metadata_details" -) -document_with_segments_model = get_or_create_model("DocumentWithSegments", document_with_segments_fields_copy) -dataset_and_document_fields_copy = dataset_and_document_fields.copy() -dataset_and_document_fields_copy["dataset"] = fields.Nested(dataset_model) -dataset_and_document_fields_copy["documents"] = fields.List(fields.Nested(document_model)) -dataset_and_document_model = get_or_create_model("DatasetAndDocument", dataset_and_document_fields_copy) +class DatasetResponse(ResponseModel): + id: str + name: str + description: str | None = None + permission: str | None = None + data_source_type: str | None = None + indexing_technique: str | None = None + created_by: str | None = None + created_at: int | None = None + + @field_validator("data_source_type", "indexing_technique", mode="before") + @classmethod + def _normalize_enum_fields(cls, value: Any) -> Any: + return _normalize_enum(value) + + @field_validator("created_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +class DocumentMetadataResponse(ResponseModel): + id: str + name: str + type: str + value: str | None = None + + +class DocumentResponse(ResponseModel): + id: str + position: int | None = None + data_source_type: str | None = None + data_source_info: Any = Field(default=None, validation_alias="data_source_info_dict") + data_source_detail_dict: Any = None + dataset_process_rule_id: str | None = None + name: str + created_from: str | None = None + created_by: str | None = None + created_at: int | None = None + tokens: int | None = None + indexing_status: str | None = None + error: str | None = None + enabled: bool | None = None + disabled_at: int | None = None + disabled_by: str | None = None + archived: bool | None = None + display_status: str | None = None + word_count: int | None = None + hit_count: int | None = None + doc_form: str | None = None + doc_metadata: list[DocumentMetadataResponse] = Field(default_factory=list, validation_alias="doc_metadata_details") + summary_index_status: str | None = None + need_summary: bool | None = None + + @field_validator("data_source_type", "indexing_status", "display_status", "doc_form", mode="before") + @classmethod + def _normalize_enum_fields(cls, value: Any) -> Any: + return _normalize_enum(value) + + @field_validator("doc_metadata", mode="before") + @classmethod + def _normalize_doc_metadata(cls, value: Any) -> list[Any]: + if value is None: + return [] + return value + + @field_validator("created_at", "disabled_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +class DocumentWithSegmentsResponse(DocumentResponse): + process_rule_dict: Any = None + completed_segments: int | None = None + total_segments: int | None = None + + +class DatasetAndDocumentResponse(ResponseModel): + dataset: DatasetResponse + documents: list[DocumentResponse] + batch: str class DocumentRetryPayload(BaseModel): @@ -107,6 +179,11 @@ class GenerateSummaryPayload(BaseModel): document_list: list[str] +class DocumentMetadataUpdatePayload(BaseModel): + doc_type: str | None = None + doc_metadata: Any = None + + class DocumentDatasetListParam(BaseModel): page: int = Field(1, title="Page", description="Page number.") limit: int = Field(20, title="Limit", description="Page size.") @@ -124,7 +201,13 @@ register_schema_models( DocumentRetryPayload, DocumentRenamePayload, GenerateSummaryPayload, + DocumentMetadataUpdatePayload, DocumentBatchDownloadZipPayload, + DatasetResponse, + DocumentMetadataResponse, + DocumentResponse, + DocumentWithSegmentsResponse, + DatasetAndDocumentResponse, ) @@ -357,10 +440,10 @@ class DatasetDocumentListApi(Resource): @setup_required @login_required @account_initialization_required - @marshal_with(dataset_and_document_model) @cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_rate_limit_check("knowledge") @console_ns.expect(console_ns.models[KnowledgeConfig.__name__]) + @console_ns.response(200, "Documents created successfully", console_ns.models[DatasetAndDocumentResponse.__name__]) def post(self, dataset_id): current_user, _ = current_account_with_tenant() dataset_id = str(dataset_id) @@ -398,7 +481,9 @@ class DatasetDocumentListApi(Resource): except ModelCurrentlyNotSupportError: raise ProviderModelCurrentlyNotSupportError() - return {"dataset": dataset, "documents": documents, "batch": batch} + return DatasetAndDocumentResponse.model_validate( + {"dataset": dataset, "documents": documents, "batch": batch}, from_attributes=True + ).model_dump(mode="json") @setup_required @login_required @@ -426,12 +511,13 @@ class DatasetInitApi(Resource): @console_ns.doc("init_dataset") @console_ns.doc(description="Initialize dataset with documents") @console_ns.expect(console_ns.models[KnowledgeConfig.__name__]) - @console_ns.response(201, "Dataset initialized successfully", dataset_and_document_model) + @console_ns.response( + 201, "Dataset initialized successfully", console_ns.models[DatasetAndDocumentResponse.__name__] + ) @console_ns.response(400, "Invalid request parameters") @setup_required @login_required @account_initialization_required - @marshal_with(dataset_and_document_model) @cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_rate_limit_check("knowledge") def post(self): @@ -479,9 +565,9 @@ class DatasetInitApi(Resource): except ModelCurrentlyNotSupportError: raise ProviderModelCurrentlyNotSupportError() - response = {"dataset": dataset, "documents": documents, "batch": batch} - - return response + return DatasetAndDocumentResponse.model_validate( + {"dataset": dataset, "documents": documents, "batch": batch}, from_attributes=True + ).model_dump(mode="json") @console_ns.route("/datasets//documents//indexing-estimate") @@ -988,15 +1074,7 @@ class DocumentMetadataApi(DocumentResource): @console_ns.doc("update_document_metadata") @console_ns.doc(description="Update document metadata") @console_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"}) - @console_ns.expect( - console_ns.model( - "UpdateDocumentMetadataRequest", - { - "doc_type": fields.String(description="Document type"), - "doc_metadata": fields.Raw(description="Document metadata"), - }, - ) - ) + @console_ns.expect(console_ns.models[DocumentMetadataUpdatePayload.__name__]) @console_ns.response(200, "Document metadata updated successfully") @console_ns.response(404, "Document not found") @console_ns.response(403, "Permission denied") @@ -1009,10 +1087,10 @@ class DocumentMetadataApi(DocumentResource): document_id = str(document_id) document = self.get_document(dataset_id, document_id) - req_data = request.get_json() + req_data = DocumentMetadataUpdatePayload.model_validate(request.get_json() or {}) - doc_type = req_data.get("doc_type") - doc_metadata = req_data.get("doc_metadata") + doc_type = req_data.doc_type + doc_metadata = req_data.doc_metadata # The role of the current user in the ta table must be admin, owner, dataset_operator, or editor if not current_user.is_dataset_editor: @@ -1194,7 +1272,7 @@ class DocumentRenameApi(DocumentResource): @setup_required @login_required @account_initialization_required - @marshal_with(document_model) + @console_ns.response(200, "Document renamed successfully", console_ns.models[DocumentResponse.__name__]) @console_ns.expect(console_ns.models[DocumentRenamePayload.__name__]) def post(self, dataset_id, document_id): # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator @@ -1212,7 +1290,7 @@ class DocumentRenameApi(DocumentResource): except services.errors.document.DocumentIndexingError: raise DocumentIndexingError("Cannot delete document during indexing.") - return document + return DocumentResponse.model_validate(document, from_attributes=True).model_dump(mode="json") @console_ns.route("/datasets//documents//website-sync") diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py index 44404005b2..c01286cc59 100644 --- a/api/controllers/console/workspace/account.py +++ b/api/controllers/console/workspace/account.py @@ -595,13 +595,25 @@ class ChangeEmailSendEmailApi(Resource): account = None user_email = None email_for_sending = args.email.lower() - if args.phase is not None and args.phase == "new_email": + # Default to the initial phase; any legacy/unexpected client input is + # coerced back to `old_email` so we never trust the caller to declare + # later phases without a verified predecessor token. + send_phase = AccountService.CHANGE_EMAIL_PHASE_OLD + if args.phase is not None and args.phase == AccountService.CHANGE_EMAIL_PHASE_NEW: + send_phase = AccountService.CHANGE_EMAIL_PHASE_NEW if args.token is None: raise InvalidTokenError() reset_data = AccountService.get_change_email_data(args.token) if reset_data is None: raise InvalidTokenError() + + # The token used to request a new-email code must come from the + # old-email verification step. This prevents the bypass described + # in GHSA-4q3w-q5mc-45rq where the phase-1 token was reused here. + token_phase = reset_data.get(AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY) + if token_phase != AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED: + raise InvalidTokenError() user_email = reset_data.get("email", "") if user_email.lower() != current_user.email.lower(): @@ -620,7 +632,7 @@ class ChangeEmailSendEmailApi(Resource): email=email_for_sending, old_email=user_email, language=language, - phase=args.phase, + phase=send_phase, ) return {"result": "success", "data": token} @@ -655,12 +667,31 @@ class ChangeEmailCheckApi(Resource): AccountService.add_change_email_error_rate_limit(user_email) raise EmailCodeError() + # Only advance tokens that were minted by the matching send-code step; + # refuse tokens that have already progressed or lack a phase marker so + # the chain `old_email -> old_email_verified -> new_email -> new_email_verified` + # is strictly enforced. + phase_transitions = { + AccountService.CHANGE_EMAIL_PHASE_OLD: AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED, + AccountService.CHANGE_EMAIL_PHASE_NEW: AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED, + } + token_phase = token_data.get(AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY) + if not isinstance(token_phase, str): + raise InvalidTokenError() + refreshed_phase = phase_transitions.get(token_phase) + if refreshed_phase is None: + raise InvalidTokenError() + # Verified, revoke the first token AccountService.revoke_change_email_token(args.token) - # Refresh token data by generating a new token + # Refresh token data by generating a new token that carries the + # upgraded phase so later steps can check it. _, new_token = AccountService.generate_change_email_token( - user_email, code=args.code, old_email=token_data.get("old_email"), additional_data={} + user_email, + code=args.code, + old_email=token_data.get("old_email"), + additional_data={AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: refreshed_phase}, ) AccountService.reset_change_email_error_rate_limit(user_email) @@ -690,13 +721,29 @@ class ChangeEmailResetApi(Resource): if not reset_data: raise InvalidTokenError() - AccountService.revoke_change_email_token(args.token) + # Only tokens that completed both verification phases may be used to + # change the email. This closes GHSA-4q3w-q5mc-45rq where a token from + # the initial send-code step could be replayed directly here. + token_phase = reset_data.get(AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY) + if token_phase != AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED: + raise InvalidTokenError() + + # Bind the new email to the token that was mailed and verified, so a + # verified token cannot be reused with a different `new_email` value. + token_email = reset_data.get("email") + normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email + if normalized_token_email != normalized_new_email: + raise InvalidTokenError() old_email = reset_data.get("old_email", "") current_user, _ = current_account_with_tenant() if current_user.email.lower() != old_email.lower(): raise AccountNotFound() + # Revoke only after all checks pass so failed attempts don't burn a + # legitimately verified token. + AccountService.revoke_change_email_token(args.token) + updated_account = AccountService.update_account_email(current_user, email=normalized_new_email) AccountService.send_change_email_completed_notify_email( diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index 471594f349..34c9534de8 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -1131,6 +1131,14 @@ class ToolMCPAuthApi(Resource): with sessionmaker(db.engine).begin() as session: service = MCPToolManageService(session=session) service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id) + parsed = urlparse(server_url) + sanitized_url = f"{parsed.scheme}://{parsed.hostname}{parsed.path}" + logger.warning( + "MCP authorization failed for provider %s (url=%s)", + provider_id, + sanitized_url, + exc_info=True, + ) raise ValueError(f"Failed to connect to MCP server: {e}") from e diff --git a/api/controllers/inner_api/plugin/wraps.py b/api/controllers/inner_api/plugin/wraps.py index a5846e2815..2f309262cb 100644 --- a/api/controllers/inner_api/plugin/wraps.py +++ b/api/controllers/inner_api/plugin/wraps.py @@ -20,10 +20,13 @@ class TenantUserPayload(BaseModel): def get_user(tenant_id: str, user_id: str | None) -> EndUser: """ - Get current user + Get current user. NOTE: user_id is not trusted, it could be maliciously set to any value. - As a result, it could only be considered as an end user id. + As a result, it could only be considered as an end user id. Even when a + concrete end-user ID is supplied, lookups must stay tenant-scoped so one + tenant cannot bind another tenant's user record into the plugin request + context. """ if not user_id: user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID @@ -42,7 +45,14 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser: .limit(1) ) else: - user_model = session.get(EndUser, user_id) + user_model = session.scalar( + select(EndUser) + .where( + EndUser.id == user_id, + EndUser.tenant_id == tenant_id, + ) + .limit(1) + ) if not user_model: user_model = EndUser( diff --git a/api/controllers/service_api/app/conversation.py b/api/controllers/service_api/app/conversation.py index c4353ca7b8..ca4b18cb5e 100644 --- a/api/controllers/service_api/app/conversation.py +++ b/api/controllers/service_api/app/conversation.py @@ -84,10 +84,10 @@ class ConversationVariableResponse(ResponseModel): def normalize_value_type(cls, value: Any) -> str: exposed_type = getattr(value, "exposed_type", None) if callable(exposed_type): - return str(exposed_type().value) + return str(exposed_type()) if isinstance(value, str): try: - return str(SegmentType(value).exposed_type().value) + return str(SegmentType(value).exposed_type()) except ValueError: return value try: diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index 790602ef5d..c22102c2ba 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -42,7 +42,7 @@ from graphon.model_runtime.entities import ( ) from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes from graphon.model_runtime.entities.model_entities import ModelFeature -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel from models.enums import CreatorUserRole from models.model import Conversation, Message, MessageAgentThought, MessageFile diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py index d38d24d1e7..29de0b8b1c 100644 --- a/api/core/agent/fc_agent_runner.py +++ b/api/core/agent/fc_agent_runner.py @@ -299,7 +299,9 @@ class FunctionCallAgentRunner(BaseAgentRunner): # update prompt tool for prompt_tool in prompt_messages_tools: - self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool) + tool_instance = tool_instances.get(prompt_tool.name) + if tool_instance: + self.update_prompt_message_tool(tool_instance, prompt_tool) iteration_step += 1 diff --git a/api/core/app/app_config/easy_ui_based_app/model_config/converter.py b/api/core/app/app_config/easy_ui_based_app/model_config/converter.py index dbd7527fc6..5df3df2b3e 100644 --- a/api/core/app/app_config/easy_ui_based_app/model_config/converter.py +++ b/api/core/app/app_config/easy_ui_based_app/model_config/converter.py @@ -7,7 +7,7 @@ from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotIni from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager from graphon.model_runtime.entities.llm_entities import LLMMode from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel class ModelConfigConverter: diff --git a/api/core/app/apps/agent_chat/app_runner.py b/api/core/app/apps/agent_chat/app_runner.py index 09ddce327e..cae0eee0df 100644 --- a/api/core/app/apps/agent_chat/app_runner.py +++ b/api/core/app/apps/agent_chat/app_runner.py @@ -18,7 +18,7 @@ from core.moderation.base import ModerationError from extensions.ext_database import db from graphon.model_runtime.entities.llm_entities import LLMMode from graphon.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel from models.model import App, Conversation, Message logger = logging.getLogger(__name__) diff --git a/api/core/app/file_access/scope.py b/api/core/app/file_access/scope.py index 80d504ef1c..a583301f9b 100644 --- a/api/core/app/file_access/scope.py +++ b/api/core/app/file_access/scope.py @@ -1,6 +1,6 @@ from __future__ import annotations -from collections.abc import Iterator +from collections.abc import Generator # Changed from Iterator from contextlib import contextmanager from contextvars import ContextVar from dataclasses import dataclass @@ -32,7 +32,7 @@ def get_current_file_access_scope() -> FileAccessScope | None: @contextmanager -def bind_file_access_scope(scope: FileAccessScope) -> Iterator[None]: +def bind_file_access_scope(scope: FileAccessScope) -> Generator[None, None, None]: # Changed from Iterator[None] token = _current_file_access_scope.set(scope) try: yield diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py index dfe6133cb6..e2e07ebaff 100644 --- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py @@ -59,7 +59,7 @@ from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, TextPromptMessageContent, ) -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel from libs.datetime_utils import naive_utc_now from models.model import AppMode, Conversation, Message, MessageAgentThought, MessageFile, UploadFile diff --git a/api/core/app/workflow/file_runtime.py b/api/core/app/workflow/file_runtime.py index 68e5e5f0c8..3a6f9d575a 100644 --- a/api/core/app/workflow/file_runtime.py +++ b/api/core/app/workflow/file_runtime.py @@ -12,13 +12,14 @@ from typing import TYPE_CHECKING, Literal from configs import dify_config from core.app.file_access import DatabaseFileAccessController, FileAccessControllerProtocol from core.db.session_factory import session_factory -from core.helper.ssrf_proxy import ssrf_proxy +from core.helper.ssrf_proxy import graphon_ssrf_proxy from core.tools.signature import sign_tool_file from core.workflow.file_reference import parse_file_reference from extensions.ext_storage import storage from graphon.file import FileTransferMethod -from graphon.file.protocols import HttpResponseProtocol, WorkflowFileRuntimeProtocol +from graphon.file.protocols import WorkflowFileRuntimeProtocol from graphon.file.runtime import set_workflow_file_runtime +from graphon.http.protocols import HttpResponseProtocol if TYPE_CHECKING: from graphon.file import File @@ -43,7 +44,7 @@ class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol): return dify_config.MULTIMODAL_SEND_FORMAT def http_get(self, url: str, *, follow_redirects: bool = True) -> HttpResponseProtocol: - return ssrf_proxy.get(url, follow_redirects=follow_redirects) + return graphon_ssrf_proxy.get(url, follow_redirects=follow_redirects) def storage_load(self, path: str, *, stream: bool = False) -> bytes | Generator: return storage.load(path, stream=stream) diff --git a/api/core/app/workflow/layers/persistence.py b/api/core/app/workflow/layers/persistence.py index 87f005a250..d521304615 100644 --- a/api/core/app/workflow/layers/persistence.py +++ b/api/core/app/workflow/layers/persistence.py @@ -349,7 +349,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer): execution.total_tokens = runtime_state.total_tokens execution.total_steps = runtime_state.node_run_steps execution.outputs = execution.outputs or runtime_state.outputs - execution.exceptions_count = runtime_state.exceptions_count + execution.exceptions_count = max(execution.exceptions_count, runtime_state.exceptions_count) def _update_node_execution( self, diff --git a/api/core/datasource/datasource_manager.py b/api/core/datasource/datasource_manager.py index dc831e5cac..f0dcb13b62 100644 --- a/api/core/datasource/datasource_manager.py +++ b/api/core/datasource/datasource_manager.py @@ -352,11 +352,11 @@ class DatasourceManager: raise ValueError(f"UploadFile not found for file_id={file_id}, tenant_id={tenant_id}") file_info = File( - id=upload_file.id, + file_id=upload_file.id, filename=upload_file.name, extension="." + upload_file.extension, mime_type=upload_file.mime_type, - type=FileType.CUSTOM, + file_type=FileType.CUSTOM, transfer_method=FileTransferMethod.LOCAL_FILE, remote_url=upload_file.source_url, reference=build_file_reference(record_id=str(upload_file.id)), diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index 1ab66cceee..38b87e2cd1 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -31,7 +31,7 @@ from graphon.model_runtime.entities.provider_entities import ( FormType, ProviderEntity, ) -from graphon.model_runtime.model_providers.__base.ai_model import AIModel +from graphon.model_runtime.model_providers.base.ai_model import AIModel from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from graphon.model_runtime.runtime import ModelRuntime from libs.datetime_utils import naive_utc_now @@ -318,34 +318,28 @@ class ProviderConfiguration(BaseModel): else [], ) - def validate_provider_credentials( - self, credentials: dict[str, Any], credential_id: str = "", session: Session | None = None - ): + def validate_provider_credentials(self, credentials: dict[str, Any], credential_id: str = ""): """ Validate custom credentials. :param credentials: provider credentials :param credential_id: (Optional)If provided, can use existing credential's hidden api key to validate - :param session: optional database session :return: """ + provider_credential_secret_variables = self.extract_secret_variables( + self.provider.provider_credential_schema.credential_form_schemas + if self.provider.provider_credential_schema + else [] + ) - def _validate(s: Session): - # Get provider credential secret variables - provider_credential_secret_variables = self.extract_secret_variables( - self.provider.provider_credential_schema.credential_form_schemas - if self.provider.provider_credential_schema - else [] - ) - - if credential_id: + if credential_id: + with Session(db.engine) as session: try: stmt = select(ProviderCredential).where( ProviderCredential.tenant_id == self.tenant_id, ProviderCredential.provider_name.in_(self._get_provider_names()), ProviderCredential.id == credential_id, ) - credential_record = s.execute(stmt).scalar_one_or_none() - # fix origin data + credential_record = session.execute(stmt).scalar_one_or_none() if credential_record and credential_record.encrypted_config: if not credential_record.encrypted_config.startswith("{"): original_credentials = {"openai_api_key": credential_record.encrypted_config} @@ -356,31 +350,23 @@ class ProviderConfiguration(BaseModel): except JSONDecodeError: original_credentials = {} - # encrypt credentials - for key, value in credentials.items(): - if key in provider_credential_secret_variables: - # if send [__HIDDEN__] in secret input, it will be same as original value - if value == HIDDEN_VALUE and key in original_credentials: - credentials[key] = encrypter.decrypt_token( - tenant_id=self.tenant_id, token=original_credentials[key] - ) - - model_provider_factory = self.get_model_provider_factory() - validated_credentials = model_provider_factory.provider_credentials_validate( - provider=self.provider.provider, credentials=credentials - ) - - for key, value in validated_credentials.items(): + for key, value in credentials.items(): if key in provider_credential_secret_variables: - validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value) + if value == HIDDEN_VALUE and key in original_credentials: + credentials[key] = encrypter.decrypt_token( + tenant_id=self.tenant_id, token=original_credentials[key] + ) - return validated_credentials + model_provider_factory = self.get_model_provider_factory() + validated_credentials = model_provider_factory.provider_credentials_validate( + provider=self.provider.provider, credentials=credentials + ) - if session: - return _validate(session) - else: - with Session(db.engine) as new_session: - return _validate(new_session) + for key, value in validated_credentials.items(): + if key in provider_credential_secret_variables and isinstance(value, str): + validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value) + + return validated_credentials def _generate_provider_credential_name(self, session) -> str: """ @@ -457,14 +443,16 @@ class ProviderConfiguration(BaseModel): :param credential_name: credential name :return: """ - with Session(db.engine) as session: + with Session(db.engine) as pre_session: if credential_name: - if self._check_provider_credential_name_exists(credential_name=credential_name, session=session): + if self._check_provider_credential_name_exists(credential_name=credential_name, session=pre_session): raise ValueError(f"Credential with name '{credential_name}' already exists.") else: - credential_name = self._generate_provider_credential_name(session) + credential_name = self._generate_provider_credential_name(pre_session) - credentials = self.validate_provider_credentials(credentials=credentials, session=session) + credentials = self.validate_provider_credentials(credentials=credentials) + + with Session(db.engine) as session: provider_record = self._get_provider_record(session) try: new_record = ProviderCredential( @@ -477,7 +465,6 @@ class ProviderConfiguration(BaseModel): session.flush() if not provider_record: - # If provider record does not exist, create it provider_record = Provider( tenant_id=self.tenant_id, provider_name=self.provider.provider, @@ -530,15 +517,15 @@ class ProviderConfiguration(BaseModel): :param credential_name: credential name :return: """ - with Session(db.engine) as session: + with Session(db.engine) as pre_session: if credential_name and self._check_provider_credential_name_exists( - credential_name=credential_name, session=session, exclude_id=credential_id + credential_name=credential_name, session=pre_session, exclude_id=credential_id ): raise ValueError(f"Credential with name '{credential_name}' already exists.") - credentials = self.validate_provider_credentials( - credentials=credentials, credential_id=credential_id, session=session - ) + credentials = self.validate_provider_credentials(credentials=credentials, credential_id=credential_id) + + with Session(db.engine) as session: provider_record = self._get_provider_record(session) stmt = select(ProviderCredential).where( ProviderCredential.id == credential_id, @@ -546,12 +533,10 @@ class ProviderConfiguration(BaseModel): ProviderCredential.provider_name.in_(self._get_provider_names()), ) - # Get the credential record to update credential_record = session.execute(stmt).scalar_one_or_none() if not credential_record: raise ValueError("Credential record not found.") try: - # Update credential credential_record.encrypted_config = json.dumps(credentials) credential_record.updated_at = naive_utc_now() if credential_name: @@ -879,7 +864,6 @@ class ProviderConfiguration(BaseModel): model: str, credentials: dict[str, Any], credential_id: str = "", - session: Session | None = None, ): """ Validate custom model credentials. @@ -890,16 +874,14 @@ class ProviderConfiguration(BaseModel): :param credential_id: (Optional)If provided, can use existing credential's hidden api key to validate :return: """ + provider_credential_secret_variables = self.extract_secret_variables( + self.provider.model_credential_schema.credential_form_schemas + if self.provider.model_credential_schema + else [] + ) - def _validate(s: Session): - # Get provider credential secret variables - provider_credential_secret_variables = self.extract_secret_variables( - self.provider.model_credential_schema.credential_form_schemas - if self.provider.model_credential_schema - else [] - ) - - if credential_id: + if credential_id: + with Session(db.engine) as session: try: stmt = select(ProviderModelCredential).where( ProviderModelCredential.id == credential_id, @@ -908,7 +890,7 @@ class ProviderConfiguration(BaseModel): ProviderModelCredential.model_name == model, ProviderModelCredential.model_type == model_type, ) - credential_record = s.execute(stmt).scalar_one_or_none() + credential_record = session.execute(stmt).scalar_one_or_none() original_credentials = ( json.loads(credential_record.encrypted_config) if credential_record and credential_record.encrypted_config @@ -917,31 +899,23 @@ class ProviderConfiguration(BaseModel): except JSONDecodeError: original_credentials = {} - # decrypt credentials - for key, value in credentials.items(): - if key in provider_credential_secret_variables: - # if send [__HIDDEN__] in secret input, it will be same as original value - if value == HIDDEN_VALUE and key in original_credentials: - credentials[key] = encrypter.decrypt_token( - tenant_id=self.tenant_id, token=original_credentials[key] - ) - - model_provider_factory = self.get_model_provider_factory() - validated_credentials = model_provider_factory.model_credentials_validate( - provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials - ) - - for key, value in validated_credentials.items(): + for key, value in credentials.items(): if key in provider_credential_secret_variables: - validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value) + if value == HIDDEN_VALUE and key in original_credentials: + credentials[key] = encrypter.decrypt_token( + tenant_id=self.tenant_id, token=original_credentials[key] + ) - return validated_credentials + model_provider_factory = self.get_model_provider_factory() + validated_credentials = model_provider_factory.model_credentials_validate( + provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials + ) - if session: - return _validate(session) - else: - with Session(db.engine) as new_session: - return _validate(new_session) + for key, value in validated_credentials.items(): + if key in provider_credential_secret_variables and isinstance(value, str): + validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value) + + return validated_credentials def create_custom_model_credential( self, model_type: ModelType, model: str, credentials: dict[str, Any], credential_name: str | None @@ -954,20 +928,22 @@ class ProviderConfiguration(BaseModel): :param credentials: model credentials dict :return: """ - with Session(db.engine) as session: + with Session(db.engine) as pre_session: if credential_name: if self._check_custom_model_credential_name_exists( - model=model, model_type=model_type, credential_name=credential_name, session=session + model=model, model_type=model_type, credential_name=credential_name, session=pre_session ): raise ValueError(f"Model credential with name '{credential_name}' already exists for {model}.") else: credential_name = self._generate_custom_model_credential_name( - model=model, model_type=model_type, session=session + model=model, model_type=model_type, session=pre_session ) - # validate custom model config - credentials = self.validate_custom_model_credentials( - model_type=model_type, model=model, credentials=credentials, session=session - ) + + credentials = self.validate_custom_model_credentials( + model_type=model_type, model=model, credentials=credentials + ) + + with Session(db.engine) as session: provider_model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session) try: @@ -982,7 +958,6 @@ class ProviderConfiguration(BaseModel): session.add(credential) session.flush() - # save provider model if not provider_model_record: provider_model_record = ProviderModel( tenant_id=self.tenant_id, @@ -1024,23 +999,24 @@ class ProviderConfiguration(BaseModel): :param credential_id: credential id :return: """ - with Session(db.engine) as session: + with Session(db.engine) as pre_session: if credential_name and self._check_custom_model_credential_name_exists( model=model, model_type=model_type, credential_name=credential_name, - session=session, + session=pre_session, exclude_id=credential_id, ): raise ValueError(f"Model credential with name '{credential_name}' already exists for {model}.") - # validate custom model config - credentials = self.validate_custom_model_credentials( - model_type=model_type, - model=model, - credentials=credentials, - credential_id=credential_id, - session=session, - ) + + credentials = self.validate_custom_model_credentials( + model_type=model_type, + model=model, + credentials=credentials, + credential_id=credential_id, + ) + + with Session(db.engine) as session: provider_model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session) stmt = select(ProviderModelCredential).where( @@ -1055,7 +1031,6 @@ class ProviderConfiguration(BaseModel): raise ValueError("Credential record not found.") try: - # Update credential credential_record.encrypted_config = json.dumps(credentials) credential_record.updated_at = naive_utc_now() if credential_name: diff --git a/api/core/helper/code_executor/template_transformer.py b/api/core/helper/code_executor/template_transformer.py index b96a9ce380..38864a1830 100644 --- a/api/core/helper/code_executor/template_transformer.py +++ b/api/core/helper/code_executor/template_transformer.py @@ -102,7 +102,7 @@ class TemplateTransformer(ABC): @classmethod def serialize_inputs(cls, inputs: Mapping[str, Any]) -> str: - inputs_json_str = dumps_with_segments(inputs, ensure_ascii=False).encode() + inputs_json_str = dumps_with_segments(inputs).encode() input_base64_encoded = b64encode(inputs_json_str).decode("utf-8") return input_base64_encoded diff --git a/api/core/helper/moderation.py b/api/core/helper/moderation.py index dc37a36943..f169f247cf 100644 --- a/api/core/helper/moderation.py +++ b/api/core/helper/moderation.py @@ -8,7 +8,7 @@ from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_ from extensions.ext_hosting_provider import hosting_configuration from graphon.model_runtime.entities.model_entities import ModelType from graphon.model_runtime.errors.invoke import InvokeBadRequestError -from graphon.model_runtime.model_providers.__base.moderation_model import ModerationModel +from graphon.model_runtime.model_providers.base.moderation_model import ModerationModel from models.provider import ProviderType logger = logging.getLogger(__name__) diff --git a/api/core/helper/ssrf_proxy.py b/api/core/helper/ssrf_proxy.py index e38592bb7b..91e92712b7 100644 --- a/api/core/helper/ssrf_proxy.py +++ b/api/core/helper/ssrf_proxy.py @@ -12,6 +12,7 @@ from pydantic import TypeAdapter, ValidationError from configs import dify_config from core.helper.http_client_pooling import get_pooled_http_client from core.tools.errors import ToolSSRFError +from graphon.http.response import HttpResponse logger = logging.getLogger(__name__) @@ -267,4 +268,47 @@ class SSRFProxy: return patch(url=url, max_retries=max_retries, **kwargs) +def _to_graphon_http_response(response: httpx.Response) -> HttpResponse: + """Convert an ``httpx`` response into Graphon's transport-agnostic wrapper.""" + return HttpResponse( + status_code=response.status_code, + headers=dict(response.headers), + content=response.content, + url=str(response.url) if response.url else None, + reason_phrase=response.reason_phrase, + fallback_text=response.text, + ) + + +class GraphonSSRFProxy: + """Adapter exposing SSRF helpers behind Graphon's ``HttpClientProtocol``.""" + + @property + def max_retries_exceeded_error(self) -> type[Exception]: + return max_retries_exceeded_error + + @property + def request_error(self) -> type[Exception]: + return request_error + + def get(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse: + return _to_graphon_http_response(get(url=url, max_retries=max_retries, **kwargs)) + + def head(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse: + return _to_graphon_http_response(head(url=url, max_retries=max_retries, **kwargs)) + + def post(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse: + return _to_graphon_http_response(post(url=url, max_retries=max_retries, **kwargs)) + + def put(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse: + return _to_graphon_http_response(put(url=url, max_retries=max_retries, **kwargs)) + + def delete(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse: + return _to_graphon_http_response(delete(url=url, max_retries=max_retries, **kwargs)) + + def patch(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse: + return _to_graphon_http_response(patch(url=url, max_retries=max_retries, **kwargs)) + + ssrf_proxy = SSRFProxy() +graphon_ssrf_proxy = GraphonSSRFProxy() diff --git a/api/core/mcp/client/streamable_client.py b/api/core/mcp/client/streamable_client.py index 5c3cd0d8f8..acba3e666b 100644 --- a/api/core/mcp/client/streamable_client.py +++ b/api/core/mcp/client/streamable_client.py @@ -303,9 +303,16 @@ class StreamableHTTPTransport: if response.status_code == 404: if isinstance(message.root, JSONRPCRequest): + error_msg = ( + f"MCP server URL returned 404 Not Found: {self.url} " + "— verify the server URL is correct and the server is running" + if is_initialization + else "Session terminated by server" + ) self._send_session_terminated_error( ctx.server_to_client_queue, message.root.id, + message=error_msg, ) return @@ -381,12 +388,13 @@ class StreamableHTTPTransport: self, server_to_client_queue: ServerToClientQueue, request_id: RequestId, + message: str = "Session terminated by server", ): """Send a session terminated error response.""" jsonrpc_error = JSONRPCError( jsonrpc="2.0", id=request_id, - error=ErrorData(code=32600, message="Session terminated by server"), + error=ErrorData(code=32600, message=message), ) session_message = SessionMessage(JSONRPCMessage(jsonrpc_error)) server_to_client_queue.put(session_message) diff --git a/api/core/model_manager.py b/api/core/model_manager.py index d8d8dfedd8..86d0e3baaa 100644 --- a/api/core/model_manager.py +++ b/api/core/model_manager.py @@ -1,6 +1,6 @@ import logging from collections.abc import Callable, Generator, Iterable, Mapping, Sequence -from typing import IO, Any, Literal, Optional, Union, cast, overload +from typing import IO, Any, Literal, Optional, ParamSpec, TypeVar, Union, cast, overload from configs import dify_config from core.entities import PluginCredentialType @@ -18,15 +18,17 @@ from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelFe from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeConnectionError, InvokeRateLimitError -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from graphon.model_runtime.model_providers.__base.moderation_model import ModerationModel -from graphon.model_runtime.model_providers.__base.rerank_model import RerankModel -from graphon.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel -from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel -from graphon.model_runtime.model_providers.__base.tts_model import TTSModel +from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel +from graphon.model_runtime.model_providers.base.moderation_model import ModerationModel +from graphon.model_runtime.model_providers.base.rerank_model import RerankModel +from graphon.model_runtime.model_providers.base.speech2text_model import Speech2TextModel +from graphon.model_runtime.model_providers.base.text_embedding_model import TextEmbeddingModel +from graphon.model_runtime.model_providers.base.tts_model import TTSModel from models.provider import ProviderType logger = logging.getLogger(__name__) +P = ParamSpec("P") +R = TypeVar("R") class ModelInstance: @@ -168,7 +170,7 @@ class ModelInstance: return cast( Union[LLMResult, Generator], self._round_robin_invoke( - function=self.model_type_instance.invoke, + self.model_type_instance.invoke, model=self.model_name, credentials=self.credentials, prompt_messages=list(prompt_messages), @@ -193,7 +195,7 @@ class ModelInstance: if not isinstance(self.model_type_instance, LargeLanguageModel): raise Exception("Model type instance is not LargeLanguageModel") return self._round_robin_invoke( - function=self.model_type_instance.get_num_tokens, + self.model_type_instance.get_num_tokens, model=self.model_name, credentials=self.credentials, prompt_messages=list(prompt_messages), @@ -213,7 +215,7 @@ class ModelInstance: if not isinstance(self.model_type_instance, TextEmbeddingModel): raise Exception("Model type instance is not TextEmbeddingModel") return self._round_robin_invoke( - function=self.model_type_instance.invoke, + self.model_type_instance.invoke, model=self.model_name, credentials=self.credentials, texts=texts, @@ -235,7 +237,7 @@ class ModelInstance: if not isinstance(self.model_type_instance, TextEmbeddingModel): raise Exception("Model type instance is not TextEmbeddingModel") return self._round_robin_invoke( - function=self.model_type_instance.invoke, + self.model_type_instance.invoke, model=self.model_name, credentials=self.credentials, multimodel_documents=multimodel_documents, @@ -252,7 +254,7 @@ class ModelInstance: if not isinstance(self.model_type_instance, TextEmbeddingModel): raise Exception("Model type instance is not TextEmbeddingModel") return self._round_robin_invoke( - function=self.model_type_instance.get_num_tokens, + self.model_type_instance.get_num_tokens, model=self.model_name, credentials=self.credentials, texts=texts, @@ -277,7 +279,7 @@ class ModelInstance: if not isinstance(self.model_type_instance, RerankModel): raise Exception("Model type instance is not RerankModel") return self._round_robin_invoke( - function=self.model_type_instance.invoke, + self.model_type_instance.invoke, model=self.model_name, credentials=self.credentials, query=query, @@ -305,7 +307,7 @@ class ModelInstance: if not isinstance(self.model_type_instance, RerankModel): raise Exception("Model type instance is not RerankModel") return self._round_robin_invoke( - function=self.model_type_instance.invoke_multimodal_rerank, + self.model_type_instance.invoke_multimodal_rerank, model=self.model_name, credentials=self.credentials, query=query, @@ -324,7 +326,7 @@ class ModelInstance: if not isinstance(self.model_type_instance, ModerationModel): raise Exception("Model type instance is not ModerationModel") return self._round_robin_invoke( - function=self.model_type_instance.invoke, + self.model_type_instance.invoke, model=self.model_name, credentials=self.credentials, text=text, @@ -340,7 +342,7 @@ class ModelInstance: if not isinstance(self.model_type_instance, Speech2TextModel): raise Exception("Model type instance is not Speech2TextModel") return self._round_robin_invoke( - function=self.model_type_instance.invoke, + self.model_type_instance.invoke, model=self.model_name, credentials=self.credentials, file=file, @@ -357,14 +359,14 @@ class ModelInstance: if not isinstance(self.model_type_instance, TTSModel): raise Exception("Model type instance is not TTSModel") return self._round_robin_invoke( - function=self.model_type_instance.invoke, + self.model_type_instance.invoke, model=self.model_name, credentials=self.credentials, content_text=content_text, voice=voice, ) - def _round_robin_invoke[**P, R](self, function: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R: + def _round_robin_invoke(self, function: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R: """ Round-robin invoke :param function: function to invoke diff --git a/api/core/ops/entities/config_entity.py b/api/core/ops/entities/config_entity.py index fda00ac3b9..d78ce90aa1 100644 --- a/api/core/ops/entities/config_entity.py +++ b/api/core/ops/entities/config_entity.py @@ -1,8 +1,8 @@ from enum import StrEnum -from pydantic import BaseModel, ValidationInfo, field_validator +from pydantic import BaseModel -from core.ops.utils import validate_integer_id, validate_project_name, validate_url, validate_url_with_path +from core.ops.utils import validate_project_name, validate_url class TracingProviderEnum(StrEnum): @@ -52,220 +52,5 @@ class BaseTracingConfig(BaseModel): return validate_project_name(v, default_name) -class ArizeConfig(BaseTracingConfig): - """ - Model class for Arize tracing config. - """ - - api_key: str | None = None - space_id: str | None = None - project: str | None = None - endpoint: str = "https://otlp.arize.com" - - @field_validator("project") - @classmethod - def project_validator(cls, v, info: ValidationInfo): - return cls.validate_project_field(v, "default") - - @field_validator("endpoint") - @classmethod - def endpoint_validator(cls, v, info: ValidationInfo): - return cls.validate_endpoint_url(v, "https://otlp.arize.com") - - -class PhoenixConfig(BaseTracingConfig): - """ - Model class for Phoenix tracing config. - """ - - api_key: str | None = None - project: str | None = None - endpoint: str = "https://app.phoenix.arize.com" - - @field_validator("project") - @classmethod - def project_validator(cls, v, info: ValidationInfo): - return cls.validate_project_field(v, "default") - - @field_validator("endpoint") - @classmethod - def endpoint_validator(cls, v, info: ValidationInfo): - return validate_url_with_path(v, "https://app.phoenix.arize.com") - - -class LangfuseConfig(BaseTracingConfig): - """ - Model class for Langfuse tracing config. - """ - - public_key: str - secret_key: str - host: str = "https://api.langfuse.com" - - @field_validator("host") - @classmethod - def host_validator(cls, v, info: ValidationInfo): - return validate_url_with_path(v, "https://api.langfuse.com") - - -class LangSmithConfig(BaseTracingConfig): - """ - Model class for Langsmith tracing config. - """ - - api_key: str - project: str - endpoint: str = "https://api.smith.langchain.com" - - @field_validator("endpoint") - @classmethod - def endpoint_validator(cls, v, info: ValidationInfo): - # LangSmith only allows HTTPS - return validate_url(v, "https://api.smith.langchain.com", allowed_schemes=("https",)) - - -class OpikConfig(BaseTracingConfig): - """ - Model class for Opik tracing config. - """ - - api_key: str | None = None - project: str | None = None - workspace: str | None = None - url: str = "https://www.comet.com/opik/api/" - - @field_validator("project") - @classmethod - def project_validator(cls, v, info: ValidationInfo): - return cls.validate_project_field(v, "Default Project") - - @field_validator("url") - @classmethod - def url_validator(cls, v, info: ValidationInfo): - return validate_url_with_path(v, "https://www.comet.com/opik/api/", required_suffix="/api/") - - -class WeaveConfig(BaseTracingConfig): - """ - Model class for Weave tracing config. - """ - - api_key: str - entity: str | None = None - project: str - endpoint: str = "https://trace.wandb.ai" - host: str | None = None - - @field_validator("endpoint") - @classmethod - def endpoint_validator(cls, v, info: ValidationInfo): - # Weave only allows HTTPS for endpoint - return validate_url(v, "https://trace.wandb.ai", allowed_schemes=("https",)) - - @field_validator("host") - @classmethod - def host_validator(cls, v, info: ValidationInfo): - if v is not None and v.strip() != "": - return validate_url(v, v, allowed_schemes=("https", "http")) - return v - - -class AliyunConfig(BaseTracingConfig): - """ - Model class for Aliyun tracing config. - """ - - app_name: str = "dify_app" - license_key: str - endpoint: str - - @field_validator("app_name") - @classmethod - def app_name_validator(cls, v, info: ValidationInfo): - return cls.validate_project_field(v, "dify_app") - - @field_validator("license_key") - @classmethod - def license_key_validator(cls, v, info: ValidationInfo): - if not v or v.strip() == "": - raise ValueError("License key cannot be empty") - return v - - @field_validator("endpoint") - @classmethod - def endpoint_validator(cls, v, info: ValidationInfo): - # aliyun uses two URL formats, which may include a URL path - return validate_url_with_path(v, "https://tracing-analysis-dc-hz.aliyuncs.com") - - -class TencentConfig(BaseTracingConfig): - """ - Tencent APM tracing config - """ - - token: str - endpoint: str - service_name: str - - @field_validator("token") - @classmethod - def token_validator(cls, v, info: ValidationInfo): - if not v or v.strip() == "": - raise ValueError("Token cannot be empty") - return v - - @field_validator("endpoint") - @classmethod - def endpoint_validator(cls, v, info: ValidationInfo): - return cls.validate_endpoint_url(v, "https://apm.tencentcloudapi.com") - - @field_validator("service_name") - @classmethod - def service_name_validator(cls, v, info: ValidationInfo): - return cls.validate_project_field(v, "dify_app") - - -class MLflowConfig(BaseTracingConfig): - """ - Model class for MLflow tracing config. - """ - - tracking_uri: str = "http://localhost:5000" - experiment_id: str = "0" # Default experiment id in MLflow is 0 - username: str | None = None - password: str | None = None - - @field_validator("tracking_uri") - @classmethod - def tracking_uri_validator(cls, v, info: ValidationInfo): - if isinstance(v, str) and v.startswith("databricks"): - raise ValueError( - "Please use Databricks tracing config below to record traces to Databricks-managed MLflow instances." - ) - return validate_url_with_path(v, "http://localhost:5000") - - @field_validator("experiment_id") - @classmethod - def experiment_id_validator(cls, v, info: ValidationInfo): - return validate_integer_id(v) - - -class DatabricksConfig(BaseTracingConfig): - """ - Model class for Databricks (Databricks-managed MLflow) tracing config. - """ - - experiment_id: str - host: str - client_id: str | None = None - client_secret: str | None = None - personal_access_token: str | None = None - - @field_validator("experiment_id") - @classmethod - def experiment_id_validator(cls, v, info: ValidationInfo): - return validate_integer_id(v) - - OPS_FILE_PATH = "ops_trace/" OPS_TRACE_FAILED_KEY = "FAILED_OPS_TRACE" diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py index cd63951537..e7ba6e502b 100644 --- a/api/core/ops/ops_trace_manager.py +++ b/api/core/ops/ops_trace_manager.py @@ -204,114 +204,117 @@ class TracingProviderConfigEntry(TypedDict): class OpsTraceProviderConfigMap(collections.UserDict[str, TracingProviderConfigEntry]): def __getitem__(self, provider: str) -> TracingProviderConfigEntry: - match provider: - case TracingProviderEnum.LANGFUSE: - from core.ops.entities.config_entity import LangfuseConfig - from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace + try: + match provider: + case TracingProviderEnum.LANGFUSE: + from dify_trace_langfuse.config import LangfuseConfig + from dify_trace_langfuse.langfuse_trace import LangFuseDataTrace - return { - "config_class": LangfuseConfig, - "secret_keys": ["public_key", "secret_key"], - "other_keys": ["host", "project_key"], - "trace_instance": LangFuseDataTrace, - } + return { + "config_class": LangfuseConfig, + "secret_keys": ["public_key", "secret_key"], + "other_keys": ["host", "project_key"], + "trace_instance": LangFuseDataTrace, + } - case TracingProviderEnum.LANGSMITH: - from core.ops.entities.config_entity import LangSmithConfig - from core.ops.langsmith_trace.langsmith_trace import LangSmithDataTrace + case TracingProviderEnum.LANGSMITH: + from dify_trace_langsmith.config import LangSmithConfig + from dify_trace_langsmith.langsmith_trace import LangSmithDataTrace - return { - "config_class": LangSmithConfig, - "secret_keys": ["api_key"], - "other_keys": ["project", "endpoint"], - "trace_instance": LangSmithDataTrace, - } + return { + "config_class": LangSmithConfig, + "secret_keys": ["api_key"], + "other_keys": ["project", "endpoint"], + "trace_instance": LangSmithDataTrace, + } - case TracingProviderEnum.OPIK: - from core.ops.entities.config_entity import OpikConfig - from core.ops.opik_trace.opik_trace import OpikDataTrace + case TracingProviderEnum.OPIK: + from dify_trace_opik.config import OpikConfig + from dify_trace_opik.opik_trace import OpikDataTrace - return { - "config_class": OpikConfig, - "secret_keys": ["api_key"], - "other_keys": ["project", "url", "workspace"], - "trace_instance": OpikDataTrace, - } + return { + "config_class": OpikConfig, + "secret_keys": ["api_key"], + "other_keys": ["project", "url", "workspace"], + "trace_instance": OpikDataTrace, + } - case TracingProviderEnum.WEAVE: - from core.ops.entities.config_entity import WeaveConfig - from core.ops.weave_trace.weave_trace import WeaveDataTrace + case TracingProviderEnum.WEAVE: + from dify_trace_weave.config import WeaveConfig + from dify_trace_weave.weave_trace import WeaveDataTrace - return { - "config_class": WeaveConfig, - "secret_keys": ["api_key"], - "other_keys": ["project", "entity", "endpoint", "host"], - "trace_instance": WeaveDataTrace, - } - case TracingProviderEnum.ARIZE: - from core.ops.arize_phoenix_trace.arize_phoenix_trace import ArizePhoenixDataTrace - from core.ops.entities.config_entity import ArizeConfig + return { + "config_class": WeaveConfig, + "secret_keys": ["api_key"], + "other_keys": ["project", "entity", "endpoint", "host"], + "trace_instance": WeaveDataTrace, + } + case TracingProviderEnum.ARIZE: + from dify_trace_arize_phoenix.arize_phoenix_trace import ArizePhoenixDataTrace + from dify_trace_arize_phoenix.config import ArizeConfig - return { - "config_class": ArizeConfig, - "secret_keys": ["api_key", "space_id"], - "other_keys": ["project", "endpoint"], - "trace_instance": ArizePhoenixDataTrace, - } - case TracingProviderEnum.PHOENIX: - from core.ops.arize_phoenix_trace.arize_phoenix_trace import ArizePhoenixDataTrace - from core.ops.entities.config_entity import PhoenixConfig + return { + "config_class": ArizeConfig, + "secret_keys": ["api_key", "space_id"], + "other_keys": ["project", "endpoint"], + "trace_instance": ArizePhoenixDataTrace, + } + case TracingProviderEnum.PHOENIX: + from dify_trace_arize_phoenix.arize_phoenix_trace import ArizePhoenixDataTrace + from dify_trace_arize_phoenix.config import PhoenixConfig - return { - "config_class": PhoenixConfig, - "secret_keys": ["api_key"], - "other_keys": ["project", "endpoint"], - "trace_instance": ArizePhoenixDataTrace, - } - case TracingProviderEnum.ALIYUN: - from core.ops.aliyun_trace.aliyun_trace import AliyunDataTrace - from core.ops.entities.config_entity import AliyunConfig + return { + "config_class": PhoenixConfig, + "secret_keys": ["api_key"], + "other_keys": ["project", "endpoint"], + "trace_instance": ArizePhoenixDataTrace, + } + case TracingProviderEnum.ALIYUN: + from dify_trace_aliyun.aliyun_trace import AliyunDataTrace + from dify_trace_aliyun.config import AliyunConfig - return { - "config_class": AliyunConfig, - "secret_keys": ["license_key"], - "other_keys": ["endpoint", "app_name"], - "trace_instance": AliyunDataTrace, - } - case TracingProviderEnum.MLFLOW: - from core.ops.entities.config_entity import MLflowConfig - from core.ops.mlflow_trace.mlflow_trace import MLflowDataTrace + return { + "config_class": AliyunConfig, + "secret_keys": ["license_key"], + "other_keys": ["endpoint", "app_name"], + "trace_instance": AliyunDataTrace, + } + case TracingProviderEnum.MLFLOW: + from dify_trace_mlflow.config import MLflowConfig + from dify_trace_mlflow.mlflow_trace import MLflowDataTrace - return { - "config_class": MLflowConfig, - "secret_keys": ["password"], - "other_keys": ["tracking_uri", "experiment_id", "username"], - "trace_instance": MLflowDataTrace, - } - case TracingProviderEnum.DATABRICKS: - from core.ops.entities.config_entity import DatabricksConfig - from core.ops.mlflow_trace.mlflow_trace import MLflowDataTrace + return { + "config_class": MLflowConfig, + "secret_keys": ["password"], + "other_keys": ["tracking_uri", "experiment_id", "username"], + "trace_instance": MLflowDataTrace, + } + case TracingProviderEnum.DATABRICKS: + from dify_trace_mlflow.config import DatabricksConfig + from dify_trace_mlflow.mlflow_trace import MLflowDataTrace - return { - "config_class": DatabricksConfig, - "secret_keys": ["personal_access_token", "client_secret"], - "other_keys": ["host", "client_id", "experiment_id"], - "trace_instance": MLflowDataTrace, - } + return { + "config_class": DatabricksConfig, + "secret_keys": ["personal_access_token", "client_secret"], + "other_keys": ["host", "client_id", "experiment_id"], + "trace_instance": MLflowDataTrace, + } - case TracingProviderEnum.TENCENT: - from core.ops.entities.config_entity import TencentConfig - from core.ops.tencent_trace.tencent_trace import TencentDataTrace + case TracingProviderEnum.TENCENT: + from dify_trace_tencent.config import TencentConfig + from dify_trace_tencent.tencent_trace import TencentDataTrace - return { - "config_class": TencentConfig, - "secret_keys": ["token"], - "other_keys": ["endpoint", "service_name"], - "trace_instance": TencentDataTrace, - } + return { + "config_class": TencentConfig, + "secret_keys": ["token"], + "other_keys": ["endpoint", "service_name"], + "trace_instance": TencentDataTrace, + } - case _: - raise KeyError(f"Unsupported tracing provider: {provider}") + case _: + raise KeyError(f"Unsupported tracing provider: {provider}") + except ImportError: + raise ImportError(f"Provider {provider} is not installed.") provider_config_map = OpsTraceProviderConfigMap() diff --git a/api/core/plugin/impl/model_runtime.py b/api/core/plugin/impl/model_runtime.py index e3fba4ef3a..4e66d58b5e 100644 --- a/api/core/plugin/impl/model_runtime.py +++ b/api/core/plugin/impl/model_runtime.py @@ -66,15 +66,15 @@ class PluginModelRuntime(ModelRuntime): if not provider_schema.icon_small: raise ValueError(f"Provider {provider} does not have small icon.") file_name = ( - provider_schema.icon_small.zh_Hans if lang.lower() == "zh_hans" else provider_schema.icon_small.en_US + provider_schema.icon_small.zh_hans if lang.lower() == "zh_hans" else provider_schema.icon_small.en_us ) elif icon_type.lower() == "icon_small_dark": if not provider_schema.icon_small_dark: raise ValueError(f"Provider {provider} does not have small dark icon.") file_name = ( - provider_schema.icon_small_dark.zh_Hans + provider_schema.icon_small_dark.zh_hans if lang.lower() == "zh_hans" - else provider_schema.icon_small_dark.en_US + else provider_schema.icon_small_dark.en_us ) else: raise ValueError(f"Unsupported icon type: {icon_type}.") diff --git a/api/core/prompt/agent_history_prompt_transform.py b/api/core/prompt/agent_history_prompt_transform.py index 8f1d51f08a..7c6280fe93 100644 --- a/api/core/prompt/agent_history_prompt_transform.py +++ b/api/core/prompt/agent_history_prompt_transform.py @@ -10,7 +10,7 @@ from graphon.model_runtime.entities.message_entities import ( SystemPromptMessage, UserPromptMessage, ) -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel class AgentHistoryPromptTransform(PromptTransform): diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index c3bbe8fc09..8969825be4 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -70,12 +70,32 @@ class ProviderManager: Request-bound managers may carry caller identity in that runtime, and the resulting ``ProviderConfiguration`` objects must reuse it for downstream model-type and schema lookups. + + Configuration assembly is cached per manager instance so call chains that + share one request-scoped manager can reuse the same provider graph instead + of rebuilding it for every lookup. Call ``clear_configurations_cache()`` + when a long-lived manager needs to observe writes performed within the same + instance scope. """ + decoding_rsa_key: Any | None + decoding_cipher_rsa: Any | None + _model_runtime: ModelRuntime + _configurations_cache: dict[str, ProviderConfigurations] + def __init__(self, model_runtime: ModelRuntime): self.decoding_rsa_key = None self.decoding_cipher_rsa = None self._model_runtime = model_runtime + self._configurations_cache = {} + + def clear_configurations_cache(self, tenant_id: str | None = None) -> None: + """Drop assembled provider configurations cached on this manager instance.""" + if tenant_id is None: + self._configurations_cache.clear() + return + + self._configurations_cache.pop(tenant_id, None) def get_configurations(self, tenant_id: str) -> ProviderConfigurations: """ @@ -114,6 +134,10 @@ class ProviderManager: :param tenant_id: :return: """ + cached_configurations = self._configurations_cache.get(tenant_id) + if cached_configurations is not None: + return cached_configurations + # Get all provider records of the workspace provider_name_to_provider_records_dict = self._get_all_providers(tenant_id) @@ -273,6 +297,8 @@ class ProviderManager: provider_configurations[str(provider_id_entity)] = provider_configuration + self._configurations_cache[tenant_id] = provider_configurations + # Return the encapsulated object return provider_configurations diff --git a/api/core/rag/datasource/keyword/jieba/jieba.py b/api/core/rag/datasource/keyword/jieba/jieba.py index ed264878d3..242da520c1 100644 --- a/api/core/rag/datasource/keyword/jieba/jieba.py +++ b/api/core/rag/datasource/keyword/jieba/jieba.py @@ -139,8 +139,10 @@ class Jieba(BaseKeyword): "__data__": {"index_id": self.dataset.id, "summary": None, "table": keyword_table}, } dataset_keyword_table = self.dataset.dataset_keyword_table - keyword_data_source_type = dataset_keyword_table.data_source_type + keyword_data_source_type = dataset_keyword_table.data_source_type if dataset_keyword_table else "file" if keyword_data_source_type == "database": + if dataset_keyword_table is None: + return dataset_keyword_table.keyword_table = dumps_with_sets(keyword_table_dict) db.session.commit() else: diff --git a/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py b/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py index 84f35c25f8..1ca6303af6 100644 --- a/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py +++ b/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py @@ -1,4 +1,5 @@ import re +from collections.abc import Callable from operator import itemgetter from typing import cast @@ -80,12 +81,14 @@ class JiebaKeywordTableHandler: def extract_tags(self, sentence: str, top_k: int | None = 20, **kwargs): # Basic frequency-based keyword extraction as a fallback when TF-IDF is unavailable. - top_k = kwargs.pop("topK", top_k) + top_k = cast(int | None, kwargs.pop("topK", top_k)) + if top_k is None: + top_k = 20 cut = getattr(jieba, "cut", None) if self._lcut: tokens = self._lcut(sentence) elif callable(cut): - tokens = list(cut(sentence)) + tokens = list(cast(Callable[[str], list[str]], cut)(sentence)) else: tokens = re.findall(r"\w+", sentence) @@ -108,7 +111,7 @@ class JiebaKeywordTableHandler: sentence=text, topK=max_keywords_per_chunk, ) - # jieba.analyse.extract_tags returns list[Any] when withFlag is False by default. + # jieba.analyse.extract_tags returns an untyped list when withFlag is False by default. keywords = cast(list[str], keywords) return set(self._expand_tokens_with_subtokens(set(keywords))) diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index 7e71d67ec0..2997710daf 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -158,7 +158,7 @@ class RetrievalService: ) if futures: - for future in concurrent.futures.as_completed(futures, timeout=3600): + for _ in concurrent.futures.as_completed(futures, timeout=3600): if exceptions: for f in futures: f.cancel() diff --git a/api/core/rag/embedding/cached_embedding.py b/api/core/rag/embedding/cached_embedding.py index 4926f44f16..a9995778f7 100644 --- a/api/core/rag/embedding/cached_embedding.py +++ b/api/core/rag/embedding/cached_embedding.py @@ -14,7 +14,7 @@ from core.rag.embedding.embedding_base import Embeddings from extensions.ext_database import db from extensions.ext_redis import redis_client from graphon.model_runtime.entities.model_entities import ModelPropertyKey -from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel +from graphon.model_runtime.model_providers.base.text_embedding_model import TextEmbeddingModel from libs import helper from models.dataset import Embedding diff --git a/api/core/rag/extractor/extract_processor.py b/api/core/rag/extractor/extract_processor.py index fbd2a6db93..b679edab36 100644 --- a/api/core/rag/extractor/extract_processor.py +++ b/api/core/rag/extractor/extract_processor.py @@ -94,6 +94,7 @@ class ExtractProcessor: cls, extract_setting: ExtractSetting, is_automatic: bool = False, file_path: str | None = None ) -> list[Document]: if extract_setting.datasource_type == DatasourceType.FILE: + upload_file = extract_setting.upload_file with tempfile.TemporaryDirectory() as temp_dir: upload_file = extract_setting.upload_file if not file_path: @@ -104,6 +105,7 @@ class ExtractProcessor: storage.download(upload_file.key, file_path) input_file = Path(file_path) file_extension = input_file.suffix.lower() + assert upload_file is not None, "upload_file is required" etl_type = dify_config.ETL_TYPE extractor: BaseExtractor | None = None if etl_type == "Unstructured": diff --git a/api/core/rag/extractor/word_extractor.py b/api/core/rag/extractor/word_extractor.py index 052fca930d..0330a43b28 100644 --- a/api/core/rag/extractor/word_extractor.py +++ b/api/core/rag/extractor/word_extractor.py @@ -3,6 +3,7 @@ Supports local file paths and remote URLs (downloaded via `core.helper.ssrf_proxy`). """ +import inspect import logging import mimetypes import os @@ -36,8 +37,11 @@ class WordExtractor(BaseExtractor): file_path: Path to the file to load. """ + _closed: bool + def __init__(self, file_path: str, tenant_id: str, user_id: str): """Initialize with file path.""" + self._closed = False self.file_path = file_path self.tenant_id = tenant_id self.user_id = user_id @@ -65,9 +69,27 @@ class WordExtractor(BaseExtractor): elif not os.path.isfile(self.file_path): raise ValueError(f"File path {self.file_path} is not a valid file or url") + def close(self) -> None: + """Best-effort cleanup for downloaded temporary files.""" + if getattr(self, "_closed", False): + return + + self._closed = True + temp_file = getattr(self, "temp_file", None) + if temp_file is None: + return + + try: + close_result = temp_file.close() + if inspect.isawaitable(close_result): + close_awaitable = getattr(close_result, "close", None) + if callable(close_awaitable): + close_awaitable() + except Exception: + logger.debug("Failed to cleanup downloaded word temp file", exc_info=True) + def __del__(self): - if hasattr(self, "temp_file"): - self.temp_file.close() + self.close() def extract(self) -> list[Document]: """Load given path as single page.""" diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py index f8242efe31..7ffa9afafd 100644 --- a/api/core/rag/index_processor/processor/paragraph_index_processor.py +++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py @@ -609,11 +609,11 @@ class ParagraphIndexProcessor(BaseIndexProcessor): try: # Create File object directly (similar to DatasetRetrieval) file_obj = File( - id=upload_file.id, + file_id=upload_file.id, filename=upload_file.name, extension="." + upload_file.extension, mime_type=upload_file.mime_type, - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.LOCAL_FILE, remote_url=upload_file.source_url, reference=build_file_reference( diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 1453fe020b..5631b3a921 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -68,7 +68,7 @@ from graphon.file import File, FileTransferMethod, FileType from graphon.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMUsage from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel from libs.helper import parse_uuid_str_or_none from libs.json_in_md_parser import parse_and_check_json_markdown from models import UploadFile @@ -517,11 +517,11 @@ class DatasetRetrieval: if attachments_with_bindings: for _, upload_file in attachments_with_bindings: attachment_info = File( - id=upload_file.id, + file_id=upload_file.id, filename=upload_file.name, extension="." + upload_file.extension, mime_type=upload_file.mime_type, - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.LOCAL_FILE, remote_url=upload_file.source_url, reference=build_file_reference( diff --git a/api/core/rag/retrieval/router/multi_dataset_function_call_router.py b/api/core/rag/retrieval/router/multi_dataset_function_call_router.py index e617a9660e..426d1b67dc 100644 --- a/api/core/rag/retrieval/router/multi_dataset_function_call_router.py +++ b/api/core/rag/retrieval/router/multi_dataset_function_call_router.py @@ -28,7 +28,7 @@ class FunctionCallMultiDatasetRouter: SystemPromptMessage(content="You are a helpful AI assistant."), UserPromptMessage(content=query), ] - result: LLMResult = model_instance.invoke_llm( + result: LLMResult = model_instance.invoke_llm( # pyright: ignore[reportCallIssue, reportArgumentType] prompt_messages=prompt_messages, tools=dataset_tools, stream=False, diff --git a/api/core/rag/splitter/fixed_text_splitter.py b/api/core/rag/splitter/fixed_text_splitter.py index 2581c354dd..52c9a02f97 100644 --- a/api/core/rag/splitter/fixed_text_splitter.py +++ b/api/core/rag/splitter/fixed_text_splitter.py @@ -4,12 +4,12 @@ from __future__ import annotations import codecs import re -from collections.abc import Collection +from collections.abc import Set as AbstractSet from typing import Any, Literal from core.model_manager import ModelInstance from core.rag.splitter.text_splitter import RecursiveCharacterTextSplitter -from graphon.model_runtime.model_providers.__base.tokenizers.gpt2_tokenizer import GPT2Tokenizer +from graphon.model_runtime.model_providers.base.tokenizers.gpt2_tokenizer import GPT2Tokenizer class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter): @@ -21,8 +21,8 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter): def from_encoder[T: EnhanceRecursiveCharacterTextSplitter]( cls: type[T], embedding_model_instance: ModelInstance | None, - allowed_special: Literal["all"] | set[str] = set(), - disallowed_special: Literal["all"] | Collection[str] = "all", + allowed_special: Literal["all"] | AbstractSet[str] = frozenset(), + disallowed_special: Literal["all"] | AbstractSet[str] = "all", **kwargs: Any, ) -> T: def _token_encoder(texts: list[str]) -> list[int]: @@ -40,6 +40,7 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter): return [len(text) for text in texts] + _ = _token_encoder # kept for future token-length wiring return cls(length_function=_character_encoder, **kwargs) diff --git a/api/core/rag/splitter/text_splitter.py b/api/core/rag/splitter/text_splitter.py index 7f2117e2dd..a8d9013fbc 100644 --- a/api/core/rag/splitter/text_splitter.py +++ b/api/core/rag/splitter/text_splitter.py @@ -4,7 +4,8 @@ import copy import logging import re from abc import ABC, abstractmethod -from collections.abc import Callable, Collection, Iterable, Sequence, Set +from collections.abc import Callable, Iterable, Sequence +from collections.abc import Set as AbstractSet from dataclasses import dataclass from typing import Any, Literal @@ -187,8 +188,8 @@ class TokenTextSplitter(TextSplitter): self, encoding_name: str = "gpt2", model_name: str | None = None, - allowed_special: Literal["all"] | Set[str] = set(), - disallowed_special: Literal["all"] | Collection[str] = "all", + allowed_special: Literal["all"] | AbstractSet[str] = frozenset(), + disallowed_special: Literal["all"] | AbstractSet[str] = "all", **kwargs: Any, ): """Create a new TextSplitter.""" @@ -207,8 +208,8 @@ class TokenTextSplitter(TextSplitter): else: enc = tiktoken.get_encoding(encoding_name) self._tokenizer = enc - self._allowed_special = allowed_special - self._disallowed_special = disallowed_special + self._allowed_special: Literal["all"] | AbstractSet[str] = allowed_special + self._disallowed_special: Literal["all"] | AbstractSet[str] = disallowed_special def split_text(self, text: str) -> list[str]: def _encode(_text: str) -> list[int]: diff --git a/api/core/repositories/human_input_repository.py b/api/core/repositories/human_input_repository.py index 02625e242f..740d727e26 100644 --- a/api/core/repositories/human_input_repository.py +++ b/api/core/repositories/human_input_repository.py @@ -8,7 +8,7 @@ from sqlalchemy import select from sqlalchemy.orm import Session, selectinload from core.db.session_factory import session_factory -from core.workflow.human_input_compat import ( +from core.workflow.human_input_adapter import ( BoundRecipient, DeliveryChannelConfig, EmailDeliveryMethod, diff --git a/api/core/tools/errors.py b/api/core/tools/errors.py index 4c3efd6ff9..2b26832b44 100644 --- a/api/core/tools/errors.py +++ b/api/core/tools/errors.py @@ -38,6 +38,17 @@ class ToolCredentialPolicyViolationError(ValueError): pass +class ApiToolProviderNotFoundError(ValueError): + error_code = "api_tool_provider_not_found" + provider_name: str + tenant_id: str + + def __init__(self, provider_name: str, tenant_id: str): + self.provider_name = provider_name + self.tenant_id = tenant_id + super().__init__(f"api provider {provider_name} does not exist") + + class WorkflowToolHumanInputNotSupportedError(BaseHTTPException): error_code = "workflow_tool_human_input_not_supported" description = "Workflow with Human Input nodes cannot be published as a workflow tool." diff --git a/api/core/tools/tool_file_manager.py b/api/core/tools/tool_file_manager.py index b3424cd9a5..c87e8a3ae0 100644 --- a/api/core/tools/tool_file_manager.py +++ b/api/core/tools/tool_file_manager.py @@ -28,7 +28,7 @@ class ToolFileManager: def _build_graph_file_reference(tool_file: ToolFile) -> File: extension = guess_extension(tool_file.mimetype) or ".bin" return File( - type=get_file_type_by_mime_type(tool_file.mimetype), + file_type=get_file_type_by_mime_type(tool_file.mimetype), transfer_method=FileTransferMethod.TOOL_FILE, remote_url=tool_file.original_url, reference=build_file_reference(record_id=str(tool_file.id)), diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index f4588904d3..87cf6d7085 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -1082,7 +1082,12 @@ class ToolManager: continue tool_input = ToolNodeData.ToolInput.model_validate(tool_configurations.get(parameter.name, {})) if tool_input.type == "variable": - variable = variable_pool.get(tool_input.value) + variable_selector = tool_input.value + if not isinstance(variable_selector, list) or not all( + isinstance(selector_part, str) for selector_part in variable_selector + ): + raise ToolParameterError("Variable tool input must be a variable selector") + variable = variable_pool.get(variable_selector) if variable is None: raise ToolParameterError(f"Variable {tool_input.value} does not exist") parameter_value = variable.value diff --git a/api/core/tools/utils/message_transformer.py b/api/core/tools/utils/message_transformer.py index 79d0c114d4..5679466cbc 100644 --- a/api/core/tools/utils/message_transformer.py +++ b/api/core/tools/utils/message_transformer.py @@ -41,6 +41,10 @@ def safe_json_value(v): return v.hex() elif isinstance(v, memoryview): return v.tobytes().hex() + elif isinstance(v, np.integer): + return int(v) + elif isinstance(v, np.floating): + return float(v) elif isinstance(v, np.ndarray): return v.tolist() elif isinstance(v, dict): diff --git a/api/core/tools/utils/model_invocation_utils.py b/api/core/tools/utils/model_invocation_utils.py index 9e1d41cb39..a3623d4ecd 100644 --- a/api/core/tools/utils/model_invocation_utils.py +++ b/api/core/tools/utils/model_invocation_utils.py @@ -21,7 +21,7 @@ from graphon.model_runtime.errors.invoke import ( InvokeRateLimitError, InvokeServerUnavailableError, ) -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel from graphon.model_runtime.utils.encoders import jsonable_encoder from models.tools import ToolModelInvoke diff --git a/api/core/tools/utils/web_reader_tool.py b/api/core/tools/utils/web_reader_tool.py index ed3ed3e0de..94a2c0427b 100644 --- a/api/core/tools/utils/web_reader_tool.py +++ b/api/core/tools/utils/web_reader_tool.py @@ -105,7 +105,7 @@ class Article: def extract_using_readabilipy(html: str): - json_article: dict[str, Any] = simple_json_from_html_string(html, use_readability=True) + json_article: dict[str, Any] = simple_json_from_html_string(html, use_readability=False) article = Article( title=json_article.get("title") or "", author=json_article.get("byline") or "", diff --git a/api/core/tools/workflow_as_tool/tool.py b/api/core/tools/workflow_as_tool/tool.py index 52ab605963..cd8c6352b5 100644 --- a/api/core/tools/workflow_as_tool/tool.py +++ b/api/core/tools/workflow_as_tool/tool.py @@ -357,7 +357,10 @@ class WorkflowTool(Tool): def _update_file_mapping(self, file_dict: dict[str, Any]) -> dict[str, Any]: file_id = resolve_file_record_id(file_dict.get("reference") or file_dict.get("related_id")) - transfer_method = FileTransferMethod.value_of(file_dict.get("transfer_method")) + transfer_method_value = file_dict.get("transfer_method") + if not isinstance(transfer_method_value, str): + raise ValueError("Workflow file mapping is missing a valid transfer_method") + transfer_method = FileTransferMethod.value_of(transfer_method_value) match transfer_method: case FileTransferMethod.TOOL_FILE: file_dict["tool_file_id"] = file_id diff --git a/api/core/workflow/human_input_compat.py b/api/core/workflow/human_input_adapter.py similarity index 74% rename from api/core/workflow/human_input_compat.py rename to api/core/workflow/human_input_adapter.py index 75a0a0c202..4b765e6aea 100644 --- a/api/core/workflow/human_input_compat.py +++ b/api/core/workflow/human_input_adapter.py @@ -1,8 +1,8 @@ -"""Workflow-layer adapters for legacy human-input payload keys. +"""Workflow-to-Graphon adapters for persisted node payloads. -Stored workflow graphs and editor payloads may still use Dify-specific human -input recipient keys. Normalize them here before handing configs to -`graphon` so graph-owned models only see graph-neutral field names. +Stored workflow graphs and editor payloads still contain a small set of +Dify-owned field spellings and value shapes. Adapt them here before handing the +payload to Graphon so Graphon-owned models only see current contracts. """ from __future__ import annotations @@ -185,7 +185,7 @@ def _copy_mapping(value: object) -> dict[str, Any] | None: return None -def normalize_human_input_node_data_for_graph(node_data: Mapping[str, Any] | BaseModel) -> dict[str, Any]: +def adapt_human_input_node_data_for_graph(node_data: Mapping[str, Any] | BaseModel) -> dict[str, Any]: normalized = _copy_mapping(node_data) if normalized is None: raise TypeError(f"human-input node data must be a mapping, got {type(node_data).__name__}") @@ -215,7 +215,7 @@ def normalize_human_input_node_data_for_graph(node_data: Mapping[str, Any] | Bas def parse_human_input_delivery_methods(node_data: Mapping[str, Any] | BaseModel) -> list[DeliveryChannelConfig]: - normalized = normalize_human_input_node_data_for_graph(node_data) + normalized = adapt_human_input_node_data_for_graph(node_data) raw_delivery_methods = normalized.get("delivery_methods") if not isinstance(raw_delivery_methods, list): return [] @@ -229,17 +229,20 @@ def is_human_input_webapp_enabled(node_data: Mapping[str, Any] | BaseModel) -> b return False -def normalize_node_data_for_graph(node_data: Mapping[str, Any] | BaseModel) -> dict[str, Any]: +def adapt_node_data_for_graph(node_data: Mapping[str, Any] | BaseModel) -> dict[str, Any]: normalized = _copy_mapping(node_data) if normalized is None: raise TypeError(f"node data must be a mapping, got {type(node_data).__name__}") - if normalized.get("type") != BuiltinNodeTypes.HUMAN_INPUT: - return normalized - return normalize_human_input_node_data_for_graph(normalized) + node_type = normalized.get("type") + if node_type == BuiltinNodeTypes.HUMAN_INPUT: + return adapt_human_input_node_data_for_graph(normalized) + if node_type == BuiltinNodeTypes.TOOL: + return _adapt_tool_node_data_for_graph(normalized) + return normalized -def normalize_node_config_for_graph(node_config: Mapping[str, Any] | BaseModel) -> dict[str, Any]: +def adapt_node_config_for_graph(node_config: Mapping[str, Any] | BaseModel) -> dict[str, Any]: normalized = _copy_mapping(node_config) if normalized is None: raise TypeError(f"node config must be a mapping, got {type(node_config).__name__}") @@ -248,10 +251,65 @@ def normalize_node_config_for_graph(node_config: Mapping[str, Any] | BaseModel) if data_mapping is None: return normalized - normalized["data"] = normalize_node_data_for_graph(data_mapping) + normalized["data"] = adapt_node_data_for_graph(data_mapping) return normalized +def _adapt_tool_node_data_for_graph(node_data: Mapping[str, Any]) -> dict[str, Any]: + normalized = dict(node_data) + + raw_tool_configurations = normalized.get("tool_configurations") + if not isinstance(raw_tool_configurations, Mapping): + return normalized + + existing_tool_parameters = normalized.get("tool_parameters") + normalized_tool_parameters = dict(existing_tool_parameters) if isinstance(existing_tool_parameters, Mapping) else {} + normalized_tool_configurations: dict[str, Any] = {} + found_legacy_tool_inputs = False + + for name, value in raw_tool_configurations.items(): + if not isinstance(value, Mapping): + normalized_tool_configurations[name] = value + continue + + input_type = value.get("type") + input_value = value.get("value") + if input_type not in {"mixed", "variable", "constant"}: + normalized_tool_configurations[name] = value + continue + + found_legacy_tool_inputs = True + normalized_tool_parameters.setdefault(name, dict(value)) + + flattened_value = _flatten_legacy_tool_configuration_value( + input_type=input_type, + input_value=input_value, + ) + if flattened_value is not None: + normalized_tool_configurations[name] = flattened_value + + if not found_legacy_tool_inputs: + return normalized + + normalized["tool_parameters"] = normalized_tool_parameters + normalized["tool_configurations"] = normalized_tool_configurations + return normalized + + +def _flatten_legacy_tool_configuration_value(*, input_type: Any, input_value: Any) -> str | int | float | bool | None: + if input_type in {"mixed", "constant"} and isinstance(input_value, str | int | float | bool): + return input_value + + if ( + input_type == "variable" + and isinstance(input_value, list) + and all(isinstance(item, str) for item in input_value) + ): + return "{{#" + ".".join(input_value) + "#}}" + + return None + + def _normalize_email_recipients(recipients: Mapping[str, Any]) -> dict[str, Any]: normalized = dict(recipients) @@ -291,9 +349,9 @@ __all__ = [ "MemberRecipient", "WebAppDeliveryMethod", "_WebAppDeliveryConfig", + "adapt_human_input_node_data_for_graph", + "adapt_node_config_for_graph", + "adapt_node_data_for_graph", "is_human_input_webapp_enabled", - "normalize_human_input_node_data_for_graph", - "normalize_node_config_for_graph", - "normalize_node_data_for_graph", "parse_human_input_delivery_methods", ] diff --git a/api/core/workflow/node_factory.py b/api/core/workflow/node_factory.py index 351da3444f..de4eae1b22 100644 --- a/api/core/workflow/node_factory.py +++ b/api/core/workflow/node_factory.py @@ -15,12 +15,12 @@ from core.helper.code_executor.code_executor import ( CodeExecutionError, CodeExecutor, ) -from core.helper.ssrf_proxy import ssrf_proxy +from core.helper.ssrf_proxy import graphon_ssrf_proxy from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.prompt.entities.advanced_prompt_entities import MemoryConfig from core.trigger.constants import TRIGGER_NODE_TYPES -from core.workflow.human_input_compat import normalize_node_config_for_graph +from core.workflow.human_input_adapter import adapt_node_config_for_graph from core.workflow.node_runtime import ( DifyFileReferenceFactory, DifyHumanInputNodeRuntime, @@ -46,7 +46,7 @@ from graphon.enums import BuiltinNodeTypes, NodeType from graphon.file.file_manager import file_manager from graphon.graph.graph import NodeFactory from graphon.model_runtime.memory import PromptMessageMemory -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel from graphon.nodes.base.node import Node from graphon.nodes.code.code_node import WorkflowCodeExecutor from graphon.nodes.code.entities import CodeLanguage @@ -121,6 +121,7 @@ def get_node_type_classes_mapping() -> Mapping[NodeType, Mapping[str, type[Node] def resolve_workflow_node_class(*, node_type: NodeType, node_version: str) -> type[Node]: + """Resolve the production node class for the requested type/version.""" node_mapping = get_node_type_classes_mapping().get(node_type) if not node_mapping: raise ValueError(f"No class mapping found for node type: {node_type}") @@ -297,7 +298,7 @@ class DifyNodeFactory(NodeFactory): ) self._jinja2_template_renderer = CodeExecutorJinja2TemplateRenderer() self._template_transform_max_output_length = dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH - self._http_request_http_client = ssrf_proxy + self._http_request_http_client = graphon_ssrf_proxy self._bound_tool_file_manager_factory = lambda: DifyToolFileManager( self._dify_context, conversation_id_getter=self._conversation_id, @@ -364,10 +365,14 @@ class DifyNodeFactory(NodeFactory): (including pydantic ValidationError, which subclasses ValueError), if node type is unknown, or if no implementation exists for the resolved version """ - typed_node_config = NodeConfigDictAdapter.validate_python(normalize_node_config_for_graph(node_config)) + typed_node_config = NodeConfigDictAdapter.validate_python(adapt_node_config_for_graph(node_config)) node_id = typed_node_config["id"] node_data = typed_node_config["data"] node_class = self._resolve_node_class(node_type=node_data.type, node_version=str(node_data.version)) + # Graph configs are initially validated against permissive shared node data. + # Re-validate using the resolved node class so workflow-local node schemas + # stay explicit and constructors receive the concrete typed payload. + resolved_node_data = self._validate_resolved_node_data(node_class, node_data) node_type = node_data.type node_init_kwargs_factories: Mapping[NodeType, Callable[[], dict[str, object]]] = { BuiltinNodeTypes.CODE: lambda: { @@ -391,7 +396,7 @@ class DifyNodeFactory(NodeFactory): }, BuiltinNodeTypes.LLM: lambda: self._build_llm_compatible_node_init_kwargs( node_class=node_class, - node_data=node_data, + node_data=resolved_node_data, wrap_model_instance=True, include_http_client=True, include_llm_file_saver=True, @@ -405,7 +410,7 @@ class DifyNodeFactory(NodeFactory): }, BuiltinNodeTypes.QUESTION_CLASSIFIER: lambda: self._build_llm_compatible_node_init_kwargs( node_class=node_class, - node_data=node_data, + node_data=resolved_node_data, wrap_model_instance=True, include_http_client=True, include_llm_file_saver=True, @@ -415,7 +420,7 @@ class DifyNodeFactory(NodeFactory): ), BuiltinNodeTypes.PARAMETER_EXTRACTOR: lambda: self._build_llm_compatible_node_init_kwargs( node_class=node_class, - node_data=node_data, + node_data=resolved_node_data, wrap_model_instance=True, include_http_client=False, include_llm_file_saver=False, @@ -436,8 +441,8 @@ class DifyNodeFactory(NodeFactory): } node_init_kwargs = node_init_kwargs_factories.get(node_type, lambda: {})() return node_class( - id=node_id, - config=typed_node_config, + node_id=node_id, + config=resolved_node_data, graph_init_params=self.graph_init_params, graph_runtime_state=self.graph_runtime_state, **node_init_kwargs, @@ -448,7 +453,10 @@ class DifyNodeFactory(NodeFactory): """ Re-validate the permissive graph payload with the concrete NodeData model declared by the resolved node class. """ - return node_class.validate_node_data(node_data) + validate_node_data = getattr(node_class, "validate_node_data", None) + if callable(validate_node_data): + return cast("BaseNodeData", validate_node_data(node_data)) + return node_data @staticmethod def _resolve_node_class(*, node_type: NodeType, node_version: str) -> type[Node]: diff --git a/api/core/workflow/node_runtime.py b/api/core/workflow/node_runtime.py index 2e632e56f0..b8725853c4 100644 --- a/api/core/workflow/node_runtime.py +++ b/api/core/workflow/node_runtime.py @@ -2,7 +2,7 @@ from __future__ import annotations from collections.abc import Callable, Generator, Mapping, Sequence from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any, Literal, cast, overload from sqlalchemy import select from sqlalchemy.orm import Session @@ -41,7 +41,7 @@ from graphon.model_runtime.entities.llm_entities import ( ) from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool from graphon.model_runtime.entities.model_entities import AIModelEntity -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel from graphon.nodes.human_input.entities import HumanInputNodeData from graphon.nodes.llm.runtime_protocols import ( PreparedLLMProtocol, @@ -64,7 +64,7 @@ from models.dataset import SegmentAttachmentBinding from models.model import UploadFile from services.tools.builtin_tools_manage_service import BuiltinToolManageService -from .human_input_compat import ( +from .human_input_adapter import ( BoundRecipient, DeliveryChannelConfig, DeliveryMethodType, @@ -173,6 +173,28 @@ class DifyPreparedLLM(PreparedLLMProtocol): def get_llm_num_tokens(self, prompt_messages: Sequence[PromptMessage]) -> int: return self._model_instance.get_llm_num_tokens(prompt_messages) + @overload + def invoke_llm( + self, + *, + prompt_messages: Sequence[PromptMessage], + model_parameters: Mapping[str, Any], + tools: Sequence[PromptMessageTool] | None, + stop: Sequence[str] | None, + stream: Literal[False], + ) -> LLMResult: ... + + @overload + def invoke_llm( + self, + *, + prompt_messages: Sequence[PromptMessage], + model_parameters: Mapping[str, Any], + tools: Sequence[PromptMessageTool] | None, + stop: Sequence[str] | None, + stream: Literal[True], + ) -> Generator[LLMResultChunk, None, None]: ... + def invoke_llm( self, *, @@ -190,6 +212,28 @@ class DifyPreparedLLM(PreparedLLMProtocol): stream=stream, ) + @overload + def invoke_llm_with_structured_output( + self, + *, + prompt_messages: Sequence[PromptMessage], + json_schema: Mapping[str, Any], + model_parameters: Mapping[str, Any], + stop: Sequence[str] | None, + stream: Literal[False], + ) -> LLMResultWithStructuredOutput: ... + + @overload + def invoke_llm_with_structured_output( + self, + *, + prompt_messages: Sequence[PromptMessage], + json_schema: Mapping[str, Any], + model_parameters: Mapping[str, Any], + stop: Sequence[str] | None, + stream: Literal[True], + ) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ... + def invoke_llm_with_structured_output( self, *, diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py index 7b000101b0..68a24e86b1 100644 --- a/api/core/workflow/nodes/agent/agent_node.py +++ b/api/core/workflow/nodes/agent/agent_node.py @@ -5,7 +5,6 @@ from typing import TYPE_CHECKING, Any from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext from core.workflow.system_variables import SystemVariableKey, get_system_text -from graphon.entities.graph_config import NodeConfigDict from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus from graphon.node_events import NodeEventBase, NodeRunResult, StreamCompletedEvent from graphon.nodes.base.node import Node @@ -35,18 +34,18 @@ class AgentNode(Node[AgentNodeData]): def __init__( self, - id: str, - config: NodeConfigDict, + node_id: str, + config: AgentNodeData, + *, graph_init_params: GraphInitParams, graph_runtime_state: GraphRuntimeState, - *, strategy_resolver: AgentStrategyResolver, presentation_provider: AgentStrategyPresentationProvider, runtime_support: AgentRuntimeSupport, message_transformer: AgentMessageTransformer, ) -> None: super().__init__( - id=id, + node_id=node_id, config=config, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py index e4f6b3b470..f3006c4242 100644 --- a/api/core/workflow/nodes/datasource/datasource_node.py +++ b/api/core/workflow/nodes/datasource/datasource_node.py @@ -7,7 +7,6 @@ from core.datasource.entities.datasource_entities import DatasourceProviderType from core.plugin.impl.exc import PluginDaemonClientSideError from core.workflow.file_reference import resolve_file_record_id from core.workflow.system_variables import SystemVariableKey, get_system_segment -from graphon.entities.graph_config import NodeConfigDict from graphon.enums import ( BuiltinNodeTypes, NodeExecutionType, @@ -36,13 +35,14 @@ class DatasourceNode(Node[DatasourceNodeData]): def __init__( self, - id: str, - config: NodeConfigDict, + node_id: str, + config: DatasourceNodeData, + *, graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", - ): + ) -> None: super().__init__( - id=id, + node_id=node_id, config=config, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, diff --git a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py index d5cab05dbe..9c1b7ab2c4 100644 --- a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py +++ b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py @@ -7,7 +7,6 @@ from core.rag.index_processor.index_processor_base import SummaryIndexSettingDic from core.rag.summary_index.summary_index import SummaryIndex from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE from core.workflow.system_variables import SystemVariableKey, get_system_segment, get_system_text -from graphon.entities.graph_config import NodeConfigDict from graphon.enums import NodeExecutionType, WorkflowNodeExecutionStatus from graphon.node_events import NodeRunResult from graphon.nodes.base.node import Node @@ -32,12 +31,18 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]): def __init__( self, - id: str, - config: NodeConfigDict, + node_id: str, + config: KnowledgeIndexNodeData, + *, graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", ) -> None: - super().__init__(id, config, graph_init_params, graph_runtime_state) + super().__init__( + node_id=node_id, + config=config, + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + ) self.index_processor = IndexProcessor() self.summary_index_service = SummaryIndex() diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 47ad14b499..25f73e446d 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -14,7 +14,6 @@ from core.rag.data_post_processor.data_post_processor import RerankingModelDict, from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.workflow.file_reference import parse_file_reference from graphon.entities import GraphInitParams -from graphon.entities.graph_config import NodeConfigDict from graphon.enums import ( BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey, @@ -50,6 +49,18 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +def _normalize_metadata_filter_scalar(value: object) -> str | int | float | None: + if value is None or isinstance(value, (str, float)): + return value + if isinstance(value, int) and not isinstance(value, bool): + return value + return str(value) + + +def _normalize_metadata_filter_sequence_item(value: object) -> str: + return value if isinstance(value, str) else str(value) + + class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeData]): node_type = BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL @@ -59,13 +70,14 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD def __init__( self, - id: str, - config: NodeConfigDict, + node_id: str, + config: KnowledgeRetrievalNodeData, + *, graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", - ): + ) -> None: super().__init__( - id=id, + node_id=node_id, config=config, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, @@ -282,18 +294,21 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD resolved_conditions: list[Condition] = [] for cond in conditions.conditions or []: value = cond.value + resolved_value: str | Sequence[str] | int | float | None if isinstance(value, str): segment_group = variable_pool.convert_template(value) if len(segment_group.value) == 1: - resolved_value = segment_group.value[0].to_object() + resolved_value = _normalize_metadata_filter_scalar(segment_group.value[0].to_object()) else: resolved_value = segment_group.text elif isinstance(value, Sequence) and all(isinstance(v, str) for v in value): - resolved_values = [] - for v in value: # type: ignore + resolved_values: list[str] = [] + for v in value: segment_group = variable_pool.convert_template(v) if len(segment_group.value) == 1: - resolved_values.append(segment_group.value[0].to_object()) + resolved_values.append( + _normalize_metadata_filter_sequence_item(segment_group.value[0].to_object()) + ) else: resolved_values.append(segment_group.text) resolved_value = resolved_values diff --git a/api/events/event_handlers/create_document_index.py b/api/events/event_handlers/create_document_index.py index b7e7a6e60f..0c535a1c5b 100644 --- a/api/events/event_handlers/create_document_index.py +++ b/api/events/event_handlers/create_document_index.py @@ -6,9 +6,9 @@ import click from sqlalchemy import select from werkzeug.exceptions import NotFound +from core.db.session_factory import session_factory from core.indexing_runner import DocumentIsPausedError, IndexingRunner from events.document_index_event import document_index_created -from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from models.dataset import Document from models.enums import IndexingStatus @@ -22,24 +22,25 @@ def handle(sender, **kwargs): document_ids = kwargs.get("document_ids", []) documents = [] start_at = time.perf_counter() - for document_id in document_ids: - logger.info(click.style(f"Start process document: {document_id}", fg="green")) + with session_factory.create_session() as session: + for document_id in document_ids: + logger.info(click.style(f"Start process document: {document_id}", fg="green")) - document = db.session.scalar( - select(Document).where( - Document.id == document_id, - Document.dataset_id == dataset_id, + document = session.scalar( + select(Document).where( + Document.id == document_id, + Document.dataset_id == dataset_id, + ) ) - ) - if not document: - raise NotFound("Document not found") + if not document: + raise NotFound("Document not found") - document.indexing_status = IndexingStatus.PARSING - document.processing_started_at = naive_utc_now() - documents.append(document) - db.session.add(document) - db.session.commit() + document.indexing_status = IndexingStatus.PARSING + document.processing_started_at = naive_utc_now() + documents.append(document) + session.add(document) + session.commit() with contextlib.suppress(Exception): try: diff --git a/api/events/event_handlers/create_installed_app_when_app_created.py b/api/events/event_handlers/create_installed_app_when_app_created.py index 57412cc4ad..38e102d5fd 100644 --- a/api/events/event_handlers/create_installed_app_when_app_created.py +++ b/api/events/event_handlers/create_installed_app_when_app_created.py @@ -1,5 +1,5 @@ +from core.db.session_factory import session_factory from events.app_event import app_was_created -from extensions.ext_database import db from models.model import InstalledApp @@ -12,5 +12,6 @@ def handle(sender, **kwargs): app_id=app.id, app_owner_tenant_id=app.tenant_id, ) - db.session.add(installed_app) - db.session.commit() + with session_factory.create_session() as session: + session.add(installed_app) + session.commit() diff --git a/api/events/event_handlers/create_site_record_when_app_created.py b/api/events/event_handlers/create_site_record_when_app_created.py index 84be592b1a..5e2a456dce 100644 --- a/api/events/event_handlers/create_site_record_when_app_created.py +++ b/api/events/event_handlers/create_site_record_when_app_created.py @@ -1,5 +1,5 @@ +from core.db.session_factory import session_factory from events.app_event import app_was_created -from extensions.ext_database import db from models.enums import CustomizeTokenStrategy from models.model import Site @@ -22,6 +22,6 @@ def handle(sender, **kwargs): created_by=app.created_by, updated_by=app.updated_by, ) - - db.session.add(site) - db.session.commit() + with session_factory.create_session() as session: + session.add(site) + session.commit() diff --git a/api/factories/file_factory/builders.py b/api/factories/file_factory/builders.py index 288d37d265..1d2ad4d445 100644 --- a/api/factories/file_factory/builders.py +++ b/api/factories/file_factory/builders.py @@ -10,8 +10,8 @@ from typing import Any from sqlalchemy import select from core.app.file_access import FileAccessControllerProtocol +from core.db.session_factory import session_factory from core.workflow.file_reference import build_file_reference -from extensions.ext_database import db from graphon.file import File, FileTransferMethod, FileType, FileUploadConfig, helpers, standardize_file_type from models import ToolFile, UploadFile @@ -135,29 +135,30 @@ def _build_from_local_file( UploadFile.id == upload_file_id, UploadFile.tenant_id == tenant_id, ) - row = db.session.scalar(access_controller.apply_upload_file_filters(stmt)) - if row is None: - raise ValueError("Invalid upload file") + with session_factory.create_session() as session: + row = session.scalar(access_controller.apply_upload_file_filters(stmt)) + if row is None: + raise ValueError("Invalid upload file") - detected_file_type = standardize_file_type(extension="." + row.extension, mime_type=row.mime_type) - file_type = _resolve_file_type( - detected_file_type=detected_file_type, - specified_type=mapping.get("type", "custom"), - strict_type_validation=strict_type_validation, - ) + detected_file_type = standardize_file_type(extension="." + row.extension, mime_type=row.mime_type) + file_type = _resolve_file_type( + detected_file_type=detected_file_type, + specified_type=mapping.get("type", "custom"), + strict_type_validation=strict_type_validation, + ) - return File( - id=mapping.get("id"), - filename=row.name, - extension="." + row.extension, - mime_type=row.mime_type, - type=file_type, - transfer_method=transfer_method, - remote_url=row.source_url, - reference=build_file_reference(record_id=str(row.id)), - size=row.size, - storage_key=row.key, - ) + return File( + file_id=mapping.get("id"), + filename=row.name, + extension="." + row.extension, + mime_type=row.mime_type, + file_type=file_type, + transfer_method=transfer_method, + remote_url=row.source_url, + reference=build_file_reference(record_id=str(row.id)), + size=row.size, + storage_key=row.key, + ) def _build_from_remote_url( @@ -179,32 +180,33 @@ def _build_from_remote_url( UploadFile.id == upload_file_id, UploadFile.tenant_id == tenant_id, ) - upload_file = db.session.scalar(access_controller.apply_upload_file_filters(stmt)) - if upload_file is None: - raise ValueError("Invalid upload file") + with session_factory.create_session() as session: + upload_file = session.scalar(access_controller.apply_upload_file_filters(stmt)) + if upload_file is None: + raise ValueError("Invalid upload file") - detected_file_type = standardize_file_type( - extension="." + upload_file.extension, - mime_type=upload_file.mime_type, - ) - file_type = _resolve_file_type( - detected_file_type=detected_file_type, - specified_type=mapping.get("type"), - strict_type_validation=strict_type_validation, - ) + detected_file_type = standardize_file_type( + extension="." + upload_file.extension, + mime_type=upload_file.mime_type, + ) + file_type = _resolve_file_type( + detected_file_type=detected_file_type, + specified_type=mapping.get("type"), + strict_type_validation=strict_type_validation, + ) - return File( - id=mapping.get("id"), - filename=upload_file.name, - extension="." + upload_file.extension, - mime_type=upload_file.mime_type, - type=file_type, - transfer_method=transfer_method, - remote_url=helpers.get_signed_file_url(upload_file_id=str(upload_file_id)), - reference=build_file_reference(record_id=str(upload_file.id)), - size=upload_file.size, - storage_key=upload_file.key, - ) + return File( + file_id=mapping.get("id"), + filename=upload_file.name, + extension="." + upload_file.extension, + mime_type=upload_file.mime_type, + file_type=file_type, + transfer_method=transfer_method, + remote_url=helpers.get_signed_file_url(upload_file_id=str(upload_file_id)), + reference=build_file_reference(record_id=str(upload_file.id)), + size=upload_file.size, + storage_key=upload_file.key, + ) url = mapping.get("url") or mapping.get("remote_url") if not url: @@ -220,9 +222,9 @@ def _build_from_remote_url( ) return File( - id=mapping.get("id"), + file_id=mapping.get("id"), filename=filename, - type=file_type, + file_type=file_type, transfer_method=transfer_method, remote_url=url, mime_type=mime_type, @@ -247,30 +249,31 @@ def _build_from_tool_file( ToolFile.id == tool_file_id, ToolFile.tenant_id == tenant_id, ) - tool_file = db.session.scalar(access_controller.apply_tool_file_filters(stmt)) - if tool_file is None: - raise ValueError(f"ToolFile {tool_file_id} not found") + with session_factory.create_session() as session: + tool_file = session.scalar(access_controller.apply_tool_file_filters(stmt)) + if tool_file is None: + raise ValueError(f"ToolFile {tool_file_id} not found") - extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin" - detected_file_type = standardize_file_type(extension=extension, mime_type=tool_file.mimetype) - file_type = _resolve_file_type( - detected_file_type=detected_file_type, - specified_type=mapping.get("type"), - strict_type_validation=strict_type_validation, - ) + extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin" + detected_file_type = standardize_file_type(extension=extension, mime_type=tool_file.mimetype) + file_type = _resolve_file_type( + detected_file_type=detected_file_type, + specified_type=mapping.get("type"), + strict_type_validation=strict_type_validation, + ) - return File( - id=mapping.get("id"), - filename=tool_file.name, - type=file_type, - transfer_method=transfer_method, - remote_url=tool_file.original_url, - reference=build_file_reference(record_id=str(tool_file.id)), - extension=extension, - mime_type=tool_file.mimetype, - size=tool_file.size, - storage_key=tool_file.file_key, - ) + return File( + file_id=mapping.get("id"), + filename=tool_file.name, + file_type=file_type, + transfer_method=transfer_method, + remote_url=tool_file.original_url, + reference=build_file_reference(record_id=str(tool_file.id)), + extension=extension, + mime_type=tool_file.mimetype, + size=tool_file.size, + storage_key=tool_file.file_key, + ) def _build_from_datasource_file( @@ -289,31 +292,32 @@ def _build_from_datasource_file( UploadFile.id == datasource_file_id, UploadFile.tenant_id == tenant_id, ) - datasource_file = db.session.scalar(access_controller.apply_upload_file_filters(stmt)) - if datasource_file is None: - raise ValueError(f"DatasourceFile {mapping.get('datasource_file_id')} not found") + with session_factory.create_session() as session: + datasource_file = session.scalar(access_controller.apply_upload_file_filters(stmt)) + if datasource_file is None: + raise ValueError(f"DatasourceFile {mapping.get('datasource_file_id')} not found") - extension = "." + datasource_file.key.split(".")[-1] if "." in datasource_file.key else ".bin" - detected_file_type = standardize_file_type(extension="." + extension, mime_type=datasource_file.mime_type) - file_type = _resolve_file_type( - detected_file_type=detected_file_type, - specified_type=mapping.get("type"), - strict_type_validation=strict_type_validation, - ) + extension = "." + datasource_file.key.split(".")[-1] if "." in datasource_file.key else ".bin" + detected_file_type = standardize_file_type(extension="." + extension, mime_type=datasource_file.mime_type) + file_type = _resolve_file_type( + detected_file_type=detected_file_type, + specified_type=mapping.get("type"), + strict_type_validation=strict_type_validation, + ) - return File( - id=mapping.get("datasource_file_id"), - filename=datasource_file.name, - type=file_type, - transfer_method=FileTransferMethod.TOOL_FILE, - remote_url=datasource_file.source_url, - reference=build_file_reference(record_id=str(datasource_file.id)), - extension=extension, - mime_type=datasource_file.mime_type, - size=datasource_file.size, - storage_key=datasource_file.key, - url=datasource_file.source_url, - ) + return File( + file_id=mapping.get("datasource_file_id"), + filename=datasource_file.name, + file_type=file_type, + transfer_method=FileTransferMethod.TOOL_FILE, + remote_url=datasource_file.source_url, + reference=build_file_reference(record_id=str(datasource_file.id)), + extension=extension, + mime_type=datasource_file.mime_type, + size=datasource_file.size, + storage_key=datasource_file.key, + url=datasource_file.source_url, + ) def _is_valid_mapping(mapping: Mapping[str, Any]) -> bool: diff --git a/api/fields/_value_type_serializer.py b/api/fields/_value_type_serializer.py index b5acbbbcb4..d518114777 100644 --- a/api/fields/_value_type_serializer.py +++ b/api/fields/_value_type_serializer.py @@ -10,9 +10,9 @@ class _VarTypedDict(TypedDict, total=False): def serialize_value_type(v: _VarTypedDict | Segment) -> str: if isinstance(v, Segment): - return v.value_type.exposed_type().value + return str(v.value_type.exposed_type()) else: value_type = v.get("value_type") if value_type is None: raise ValueError("value_type is required but not provided") - return value_type.exposed_type().value + return str(value_type.exposed_type()) diff --git a/api/fields/conversation_variable_fields.py b/api/fields/conversation_variable_fields.py index cf4a71d545..e4219ba1ee 100644 --- a/api/fields/conversation_variable_fields.py +++ b/api/fields/conversation_variable_fields.py @@ -57,10 +57,10 @@ class ConversationVariableResponse(ResponseModel): def _normalize_value_type(cls, value: Any) -> str: exposed_type = getattr(value, "exposed_type", None) if callable(exposed_type): - return str(exposed_type().value) + return str(exposed_type()) if isinstance(value, str): try: - return str(SegmentType(value).exposed_type().value) + return str(SegmentType(value).exposed_type()) except ValueError: return value try: diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py index f9b5e98936..6e947858ba 100644 --- a/api/fields/workflow_fields.py +++ b/api/fields/workflow_fields.py @@ -26,7 +26,7 @@ class EnvironmentVariableField(fields.Raw): "id": value.id, "name": value.name, "value": value.value, - "value_type": value.value_type.exposed_type().value, + "value_type": str(value.value_type.exposed_type()), "description": value.description, } if isinstance(value, dict): diff --git a/api/libs/flask_utils.py b/api/libs/flask_utils.py index 52fc787c79..838af2bf32 100644 --- a/api/libs/flask_utils.py +++ b/api/libs/flask_utils.py @@ -1,5 +1,5 @@ import contextvars -from collections.abc import Iterator +from collections.abc import Generator # Changed from Iterator from contextlib import contextmanager from typing import TYPE_CHECKING @@ -13,7 +13,7 @@ if TYPE_CHECKING: def preserve_flask_contexts( flask_app: Flask, context_vars: contextvars.Context, -) -> Iterator[None]: +) -> Generator[None, None, None]: # Changed from Iterator[None] """ A context manager that handles: 1. flask-login's UserProxy copy diff --git a/api/libs/oauth_data_source.py b/api/libs/oauth_data_source.py index 9b53918f24..934aacb45b 100644 --- a/api/libs/oauth_data_source.py +++ b/api/libs/oauth_data_source.py @@ -6,8 +6,8 @@ from flask_login import current_user from pydantic import TypeAdapter from sqlalchemy import select +from core.db.session_factory import session_factory from core.helper.http_client_pooling import get_pooled_http_client -from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from models.source import DataSourceOauthBinding @@ -95,27 +95,28 @@ class NotionOAuth(OAuthDataSource): pages=pages, ) # save data source binding - data_source_binding = db.session.scalar( - select(DataSourceOauthBinding).where( - DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, - DataSourceOauthBinding.provider == "notion", - DataSourceOauthBinding.access_token == access_token, + with session_factory.create_session() as session: + data_source_binding = session.scalar( + select(DataSourceOauthBinding).where( + DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, + DataSourceOauthBinding.provider == "notion", + DataSourceOauthBinding.access_token == access_token, + ) ) - ) - if data_source_binding: - data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info) - data_source_binding.disabled = False - data_source_binding.updated_at = naive_utc_now() - db.session.commit() - else: - new_data_source_binding = DataSourceOauthBinding( - tenant_id=current_user.current_tenant_id, - access_token=access_token, - source_info=SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info), - provider="notion", - ) - db.session.add(new_data_source_binding) - db.session.commit() + if data_source_binding: + data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info) + data_source_binding.disabled = False + data_source_binding.updated_at = naive_utc_now() + session.commit() + else: + new_data_source_binding = DataSourceOauthBinding( + tenant_id=current_user.current_tenant_id, + access_token=access_token, + source_info=SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info), + provider="notion", + ) + session.add(new_data_source_binding) + session.commit() def save_internal_access_token(self, access_token: str) -> None: workspace_name = self.notion_workspace_name(access_token) @@ -130,55 +131,57 @@ class NotionOAuth(OAuthDataSource): pages=pages, ) # save data source binding - data_source_binding = db.session.scalar( - select(DataSourceOauthBinding).where( - DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, - DataSourceOauthBinding.provider == "notion", - DataSourceOauthBinding.access_token == access_token, + with session_factory.create_session() as session: + data_source_binding = session.scalar( + select(DataSourceOauthBinding).where( + DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, + DataSourceOauthBinding.provider == "notion", + DataSourceOauthBinding.access_token == access_token, + ) ) - ) - if data_source_binding: - data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info) - data_source_binding.disabled = False - data_source_binding.updated_at = naive_utc_now() - db.session.commit() - else: - new_data_source_binding = DataSourceOauthBinding( - tenant_id=current_user.current_tenant_id, - access_token=access_token, - source_info=SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info), - provider="notion", - ) - db.session.add(new_data_source_binding) - db.session.commit() + if data_source_binding: + data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info) + data_source_binding.disabled = False + data_source_binding.updated_at = naive_utc_now() + session.commit() + else: + new_data_source_binding = DataSourceOauthBinding( + tenant_id=current_user.current_tenant_id, + access_token=access_token, + source_info=SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info), + provider="notion", + ) + session.add(new_data_source_binding) + session.commit() def sync_data_source(self, binding_id: str) -> None: # save data source binding - data_source_binding = db.session.scalar( - select(DataSourceOauthBinding).where( - DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, - DataSourceOauthBinding.provider == "notion", - DataSourceOauthBinding.id == binding_id, - DataSourceOauthBinding.disabled == False, + with session_factory.create_session() as session: + data_source_binding = session.scalar( + select(DataSourceOauthBinding).where( + DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, + DataSourceOauthBinding.provider == "notion", + DataSourceOauthBinding.id == binding_id, + DataSourceOauthBinding.disabled == False, + ) ) - ) - if data_source_binding: - # get all authorized pages - pages = self.get_authorized_pages(data_source_binding.access_token) - source_info = NOTION_SOURCE_INFO_ADAPTER.validate_python(data_source_binding.source_info) - new_source_info = self._build_source_info( - workspace_name=source_info["workspace_name"], - workspace_icon=source_info["workspace_icon"], - workspace_id=source_info["workspace_id"], - pages=pages, - ) - data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(new_source_info) - data_source_binding.disabled = False - data_source_binding.updated_at = naive_utc_now() - db.session.commit() - else: - raise ValueError("Data source binding not found") + if data_source_binding: + # get all authorized pages + pages = self.get_authorized_pages(data_source_binding.access_token) + source_info = NOTION_SOURCE_INFO_ADAPTER.validate_python(data_source_binding.source_info) + new_source_info = self._build_source_info( + workspace_name=source_info["workspace_name"], + workspace_icon=source_info["workspace_icon"], + workspace_id=source_info["workspace_id"], + pages=pages, + ) + data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(new_source_info) + data_source_binding.disabled = False + data_source_binding.updated_at = naive_utc_now() + session.commit() + else: + raise ValueError("Data source binding not found") def get_authorized_pages(self, access_token: str) -> list[NotionPageSummary]: pages: list[NotionPageSummary] = [] diff --git a/api/libs/token.py b/api/libs/token.py index a34db70764..5b043465ac 100644 --- a/api/libs/token.py +++ b/api/libs/token.py @@ -47,23 +47,17 @@ def _cookie_domain() -> str | None: def _real_cookie_name(cookie_name: str) -> str: if is_secure() and _cookie_domain() is None: return "__Host-" + cookie_name - else: - return cookie_name + return cookie_name def _try_extract_from_header(request: Request) -> str | None: auth_header = request.headers.get("Authorization") - if auth_header: - if " " not in auth_header: - return None - else: - auth_scheme, auth_token = auth_header.split(None, 1) - auth_scheme = auth_scheme.lower() - if auth_scheme != "bearer": - return None - else: - return auth_token - return None + if not auth_header or " " not in auth_header: + return None + auth_scheme, auth_token = auth_header.split(None, 1) + if auth_scheme.lower() != "bearer": + return None + return auth_token def extract_refresh_token(request: Request) -> str | None: @@ -90,14 +84,9 @@ def extract_webapp_access_token(request: Request) -> str | None: def extract_webapp_passport(app_code: str, request: Request) -> str | None: - def _try_extract_passport_token_from_cookie(request: Request) -> str | None: - return request.cookies.get(_real_cookie_name(COOKIE_NAME_PASSPORT + "-" + app_code)) - - def _try_extract_passport_token_from_header(request: Request) -> str | None: - return request.headers.get(HEADER_NAME_PASSPORT) - - ret = _try_extract_passport_token_from_cookie(request) or _try_extract_passport_token_from_header(request) - return ret + return request.cookies.get(_real_cookie_name(COOKIE_NAME_PASSPORT + "-" + app_code)) or request.headers.get( + HEADER_NAME_PASSPORT + ) def set_access_token_to_cookie(request: Request, response: Response, token: str, samesite: str = "Lax"): @@ -209,22 +198,18 @@ def check_csrf_token(request: Request, user_id: str): if not csrf_token: _unauthorized() - verified = {} try: verified = PassportService().verify(csrf_token) - except: + except Exception: _unauthorized() + raise # unreachable, but helps the type checker see verified is always bound if verified.get("sub") != user_id: _unauthorized() exp: int | None = verified.get("exp") - if not exp: + if not exp or exp < int(datetime.now(UTC).timestamp()): _unauthorized() - else: - time_now = int(datetime.now().timestamp()) - if exp < time_now: - _unauthorized() def generate_csrf_token(user_id: str) -> str: diff --git a/api/libs/url_utils.py b/api/libs/url_utils.py new file mode 100644 index 0000000000..adcac3add0 --- /dev/null +++ b/api/libs/url_utils.py @@ -0,0 +1,3 @@ +def normalize_api_base_url(base_url: str) -> str: + """Normalize a base URL to always end with /v1, avoiding double /v1 suffixes.""" + return base_url.rstrip("/").removesuffix("/v1").rstrip("/") + "/v1" diff --git a/api/models/comment.py b/api/models/comment.py index 308339e6f6..1154e16788 100644 --- a/api/models/comment.py +++ b/api/models/comment.py @@ -3,6 +3,7 @@ from datetime import datetime from typing import Optional +import sqlalchemy as sa from sqlalchemy import Index, func from sqlalchemy.orm import Mapped, mapped_column, relationship @@ -36,24 +37,24 @@ class WorkflowComment(Base): __tablename__ = "workflow_comments" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="workflow_comments_pkey"), + sa.PrimaryKeyConstraint("id", name="workflow_comments_pkey"), Index("workflow_comments_app_idx", "tenant_id", "app_id"), Index("workflow_comments_created_at_idx", "created_at"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuidv7()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuidv7()")) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - position_x: Mapped[float] = mapped_column(db.Float) - position_y: Mapped[float] = mapped_column(db.Float) - content: Mapped[str] = mapped_column(db.Text, nullable=False) + position_x: Mapped[float] = mapped_column(sa.Float) + position_y: Mapped[float] = mapped_column(sa.Float) + content: Mapped[str] = mapped_column(sa.Text, nullable=False) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) - created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at: Mapped[datetime] = mapped_column( - db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() + sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() ) - resolved: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) - resolved_at: Mapped[datetime | None] = mapped_column(db.DateTime) + resolved: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) + resolved_at: Mapped[datetime | None] = mapped_column(sa.DateTime) resolved_by: Mapped[str | None] = mapped_column(StringUUID) # Relationships @@ -143,20 +144,20 @@ class WorkflowCommentReply(Base): __tablename__ = "workflow_comment_replies" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="workflow_comment_replies_pkey"), + sa.PrimaryKeyConstraint("id", name="workflow_comment_replies_pkey"), Index("comment_replies_comment_idx", "comment_id"), Index("comment_replies_created_at_idx", "created_at"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuidv7()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuidv7()")) comment_id: Mapped[str] = mapped_column( - StringUUID, db.ForeignKey("workflow_comments.id", ondelete="CASCADE"), nullable=False + StringUUID, sa.ForeignKey("workflow_comments.id", ondelete="CASCADE"), nullable=False ) - content: Mapped[str] = mapped_column(db.Text, nullable=False) + content: Mapped[str] = mapped_column(sa.Text, nullable=False) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) - created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at: Mapped[datetime] = mapped_column( - db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() + sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() ) # Relationships comment: Mapped["WorkflowComment"] = relationship("WorkflowComment", back_populates="replies") @@ -187,18 +188,18 @@ class WorkflowCommentMention(Base): __tablename__ = "workflow_comment_mentions" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="workflow_comment_mentions_pkey"), + sa.PrimaryKeyConstraint("id", name="workflow_comment_mentions_pkey"), Index("comment_mentions_comment_idx", "comment_id"), Index("comment_mentions_reply_idx", "reply_id"), Index("comment_mentions_user_idx", "mentioned_user_id"), ) - id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuidv7()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuidv7()")) comment_id: Mapped[str] = mapped_column( - StringUUID, db.ForeignKey("workflow_comments.id", ondelete="CASCADE"), nullable=False + StringUUID, sa.ForeignKey("workflow_comments.id", ondelete="CASCADE"), nullable=False ) reply_id: Mapped[str | None] = mapped_column( - StringUUID, db.ForeignKey("workflow_comment_replies.id", ondelete="CASCADE"), nullable=True + StringUUID, sa.ForeignKey("workflow_comment_replies.id", ondelete="CASCADE"), nullable=True ) mentioned_user_id: Mapped[str] = mapped_column(StringUUID, nullable=False) diff --git a/api/models/dataset.py b/api/models/dataset.py index 50301dd2d7..eee5c39a0e 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -1715,7 +1715,7 @@ class SegmentAttachmentBinding(TypeBase): ) -class DocumentSegmentSummary(Base): +class DocumentSegmentSummary(TypeBase): __tablename__ = "document_segment_summaries" __table_args__ = ( sa.PrimaryKeyConstraint("id", name="document_segment_summaries_pkey"), @@ -1725,25 +1725,40 @@ class DocumentSegmentSummary(Base): sa.Index("document_segment_summaries_status_idx", "status"), ) - id: Mapped[str] = mapped_column(StringUUID, nullable=False, default=lambda: str(uuid4())) + id: Mapped[str] = mapped_column( + StringUUID, + nullable=False, + insert_default=lambda: str(uuid4()), + default_factory=lambda: str(uuid4()), + init=False, + ) dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False) document_id: Mapped[str] = mapped_column(StringUUID, nullable=False) # corresponds to DocumentSegment.id or parent chunk id chunk_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - summary_content: Mapped[str] = mapped_column(LongText, nullable=True) - summary_index_node_id: Mapped[str] = mapped_column(String(255), nullable=True) - summary_index_node_hash: Mapped[str] = mapped_column(String(255), nullable=True) - tokens: Mapped[int | None] = mapped_column(sa.Integer, nullable=True) - status: Mapped[str] = mapped_column( - EnumText(SummaryStatus, length=32), nullable=False, server_default=sa.text("'generating'") + summary_content: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None) + summary_index_node_id: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None) + summary_index_node_hash: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None) + tokens: Mapped[int | None] = mapped_column(sa.Integer, nullable=True, default=None) + status: Mapped[SummaryStatus] = mapped_column( + EnumText(SummaryStatus, length=32), + nullable=False, + server_default=sa.text("'generating'"), + default=SummaryStatus.GENERATING, + ) + error: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None) + enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"), default=True) + disabled_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, default=None) + disabled_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None) + created_at: Mapped[datetime] = mapped_column( + DateTime, nullable=False, server_default=func.current_timestamp(), init=False ) - error: Mapped[str] = mapped_column(LongText, nullable=True) - enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true")) - disabled_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) - disabled_by = mapped_column(StringUUID, nullable=True) - created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) updated_at: Mapped[datetime] = mapped_column( - DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() + DateTime, + nullable=False, + server_default=func.current_timestamp(), + onupdate=func.current_timestamp(), + init=False, ) def __repr__(self): diff --git a/api/models/human_input.py b/api/models/human_input.py index b4c7a634b6..7447d3efcb 100644 --- a/api/models/human_input.py +++ b/api/models/human_input.py @@ -6,7 +6,7 @@ import sqlalchemy as sa from pydantic import BaseModel, Field from sqlalchemy.orm import Mapped, mapped_column, relationship -from core.workflow.human_input_compat import DeliveryMethodType +from core.workflow.human_input_adapter import DeliveryMethodType from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from libs.helper import generate_string diff --git a/api/models/model.py b/api/models/model.py index 7fe0731098..a1117fc43a 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -25,6 +25,7 @@ from graphon.enums import WorkflowExecutionStatus from graphon.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType from graphon.file import helpers as file_helpers from libs.helper import generate_string # type: ignore[import-not-found] +from libs.url_utils import normalize_api_base_url from libs.uuid_utils import uuidv7 from models.utils.file_input_compat import build_file_from_input_mapping @@ -446,7 +447,8 @@ class App(Base): @property def api_base_url(self) -> str: - return (dify_config.SERVICE_API_URL or request.host_url.rstrip("/")) + "/v1" + base = dify_config.SERVICE_API_URL or request.host_url.rstrip("/") + return normalize_api_base_url(base) @property def tenant(self) -> Tenant | None: diff --git a/api/models/utils/file_input_compat.py b/api/models/utils/file_input_compat.py index a2dc8f6157..77dcbd13d4 100644 --- a/api/models/utils/file_input_compat.py +++ b/api/models/utils/file_input_compat.py @@ -5,7 +5,8 @@ from functools import lru_cache from typing import Any from core.workflow.file_reference import parse_file_reference -from graphon.file import File, FileTransferMethod +from graphon.file import File, FileTransferMethod, FileType +from graphon.file.constants import FILE_MODEL_IDENTITY, maybe_file_object @lru_cache(maxsize=1) @@ -43,6 +44,124 @@ def resolve_file_mapping_tenant_id( return tenant_resolver() +def build_file_from_mapping_without_lookup(*, file_mapping: Mapping[str, Any]) -> File: + """Build a graph `File` directly from serialized metadata.""" + + def _coerce_file_type(value: Any) -> FileType: + if isinstance(value, FileType): + return value + if isinstance(value, str): + return FileType.value_of(value) + raise ValueError("file type is required in file mapping") + + mapping = dict(file_mapping) + transfer_method_value = mapping.get("transfer_method") + if isinstance(transfer_method_value, FileTransferMethod): + transfer_method = transfer_method_value + elif isinstance(transfer_method_value, str): + transfer_method = FileTransferMethod.value_of(transfer_method_value) + else: + raise ValueError("transfer_method is required in file mapping") + + file_id = mapping.get("file_id") + if not isinstance(file_id, str) or not file_id: + legacy_id = mapping.get("id") + file_id = legacy_id if isinstance(legacy_id, str) and legacy_id else None + + related_id = resolve_file_record_id(mapping) + if related_id is None: + raw_related_id = mapping.get("related_id") + related_id = raw_related_id if isinstance(raw_related_id, str) and raw_related_id else None + + remote_url = mapping.get("remote_url") + if not isinstance(remote_url, str) or not remote_url: + url = mapping.get("url") + remote_url = url if isinstance(url, str) and url else None + + reference = mapping.get("reference") + if not isinstance(reference, str) or not reference: + reference = None + + filename = mapping.get("filename") + if not isinstance(filename, str): + filename = None + + extension = mapping.get("extension") + if not isinstance(extension, str): + extension = None + + mime_type = mapping.get("mime_type") + if not isinstance(mime_type, str): + mime_type = None + + size = mapping.get("size", -1) + if not isinstance(size, int): + size = -1 + + storage_key = mapping.get("storage_key") + if not isinstance(storage_key, str): + storage_key = None + + tenant_id = mapping.get("tenant_id") + if not isinstance(tenant_id, str): + tenant_id = None + + dify_model_identity = mapping.get("dify_model_identity") + if not isinstance(dify_model_identity, str): + dify_model_identity = FILE_MODEL_IDENTITY + + tool_file_id = mapping.get("tool_file_id") + if not isinstance(tool_file_id, str): + tool_file_id = None + + upload_file_id = mapping.get("upload_file_id") + if not isinstance(upload_file_id, str): + upload_file_id = None + + datasource_file_id = mapping.get("datasource_file_id") + if not isinstance(datasource_file_id, str): + datasource_file_id = None + + return File( + file_id=file_id, + tenant_id=tenant_id, + file_type=_coerce_file_type(mapping.get("file_type", mapping.get("type"))), + transfer_method=transfer_method, + remote_url=remote_url, + reference=reference, + related_id=related_id, + filename=filename, + extension=extension, + mime_type=mime_type, + size=size, + storage_key=storage_key, + dify_model_identity=dify_model_identity, + url=remote_url, + tool_file_id=tool_file_id, + upload_file_id=upload_file_id, + datasource_file_id=datasource_file_id, + ) + + +def rebuild_serialized_graph_files_without_lookup(value: Any) -> Any: + """Recursively rebuild serialized graph file payloads into `File` objects. + + `graphon` 0.2.2 no longer accepts legacy serialized file mappings via + `model_validate_json()`. Dify keeps this recovery path at the model boundary + so historical JSON blobs remain readable without reintroducing global graph + patches or test-local coercion. + """ + if isinstance(value, list): + return [rebuild_serialized_graph_files_without_lookup(item) for item in value] + + if isinstance(value, dict): + if maybe_file_object(value): + return build_file_from_mapping_without_lookup(file_mapping=value) + return {key: rebuild_serialized_graph_files_without_lookup(item) for key, item in value.items()} + + return value + + def build_file_from_stored_mapping( *, file_mapping: Mapping[str, Any], @@ -76,12 +195,7 @@ def build_file_from_stored_mapping( pass if transfer_method == FileTransferMethod.REMOTE_URL and record_id is None: - remote_url = mapping.get("remote_url") - if not isinstance(remote_url, str) or not remote_url: - url = mapping.get("url") - if isinstance(url, str) and url: - mapping["remote_url"] = url - return File.model_validate(mapping) + return build_file_from_mapping_without_lookup(file_mapping=mapping) return file_factory.build_from_mapping( mapping=mapping, diff --git a/api/models/workflow.py b/api/models/workflow.py index dfda03c2ee..d127244b0f 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -24,7 +24,7 @@ from sqlalchemy.orm import Mapped, mapped_column from typing_extensions import deprecated from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE -from core.workflow.human_input_compat import normalize_node_config_for_graph +from core.workflow.human_input_adapter import adapt_node_config_for_graph from core.workflow.variable_prefixes import ( CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID, @@ -64,7 +64,10 @@ from .base import Base, DefaultFieldsDCMixin, TypeBase from .engine import db from .enums import CreatorUserRole, DraftVariableType, ExecutionOffLoadType, WorkflowRunTriggeredFrom from .types import EnumText, LongText, StringUUID -from .utils.file_input_compat import build_file_from_stored_mapping +from .utils.file_input_compat import ( + build_file_from_mapping_without_lookup, + build_file_from_stored_mapping, +) logger = logging.getLogger(__name__) @@ -290,7 +293,7 @@ class Workflow(Base): # bug node_config: dict[str, Any] = next(filter(lambda node: node["id"] == node_id, nodes)) except StopIteration: raise NodeNotFoundError(node_id) - return NodeConfigDictAdapter.validate_python(normalize_node_config_for_graph(node_config)) + return NodeConfigDictAdapter.validate_python(adapt_node_config_for_graph(node_config)) @staticmethod def get_node_type_from_node_config(node_config: NodeConfigDict) -> NodeType: @@ -1688,7 +1691,7 @@ class WorkflowDraftVariable(Base): return cast(Any, value) normalized_file = dict(value) normalized_file.pop("tenant_id", None) - return File.model_validate(normalized_file) + return build_file_from_mapping_without_lookup(file_mapping=normalized_file) elif isinstance(value, list) and value: value_list = cast(list[Any], value) first: Any = value_list[0] @@ -1698,7 +1701,7 @@ class WorkflowDraftVariable(Base): for item in value_list: normalized_file = dict(cast(dict[str, Any], item)) normalized_file.pop("tenant_id", None) - file_list.append(File.model_validate(normalized_file)) + file_list.append(build_file_from_mapping_without_lookup(file_mapping=normalized_file)) return cast(Any, file_list) else: return cast(Any, value) diff --git a/api/providers/README.md b/api/providers/README.md index a00ec8bc52..5d5e6db9af 100644 --- a/api/providers/README.md +++ b/api/providers/README.md @@ -10,3 +10,6 @@ This directory holds **optional workspace packages** that plug into Dify’s API Provider tests often live next to the package, e.g. `providers///tests/unit_tests/`. Shared fixtures may live under `providers/` (e.g. `conftest.py`). +## Excluding Providers + +In order to build with selected providers, use `--no-group vdb-all` and `--no-group trace-all` to disable default ones, then use `--group vdb-` and `--group trace-` to enable specific providers. diff --git a/api/providers/trace/README.md b/api/providers/trace/README.md new file mode 100644 index 0000000000..a7ffa5ed26 --- /dev/null +++ b/api/providers/trace/README.md @@ -0,0 +1,78 @@ +# Trace providers + +This directory holds **optional workspace packages** that send Dify **ops tracing** data (workflows, messages, tools, moderation, etc.) to an external observability backend (Langfuse, LangSmith, OpenTelemetry-style exporters, and others). + +Unlike VDB providers, trace plugins are **not** discovered via entry points. The API core imports your package **explicitly** from `core/ops/ops_trace_manager.py` after you register the provider id and mapping. + +## Architecture + +| Layer | Location | Role | +|--------|----------|------| +| Contracts | `api/core/ops/base_trace_instance.py`, `api/core/ops/entities/trace_entity.py`, `api/core/ops/entities/config_entity.py` | `BaseTraceInstance`, `BaseTracingConfig`, and typed `*TraceInfo` payloads | +| Registry | `api/core/ops/ops_trace_manager.py` | `TracingProviderEnum`, `OpsTraceProviderConfigMap` — maps provider **string** → config class, encrypted keys, and trace class | +| Your package | `api/providers/trace/trace-/` | Pydantic config + subclass of `BaseTraceInstance` | + +At runtime, `OpsTraceManager` decrypts stored credentials, builds your config model, caches a trace instance, and calls `trace(trace_info)` with a concrete `BaseTraceInfo` subtype. + +## What you implement + +### 1. Config model (`BaseTracingConfig`) + +Subclass `BaseTracingConfig` from `core.ops.entities.config_entity`. Use Pydantic validators; reuse helpers from `core.ops.utils` (for example `validate_url`, `validate_url_with_path`, `validate_project_name`) where appropriate. + +Fields fall into two groups used by the manager: + +- **`secret_keys`** — names of fields that are **encrypted at rest** (API keys, tokens, passwords). +- **`other_keys`** — non-secret connection settings (hosts, project names, endpoints). + +List these key names in your `OpsTraceProviderConfigMap` entry so encrypt/decrypt and merge logic stay correct. + +### 2. Trace instance (`BaseTraceInstance`) + +Subclass `BaseTraceInstance` and implement: + +```python +def trace(self, trace_info: BaseTraceInfo) -> None: + ... +``` + +Dispatch on the concrete type with `isinstance` (see `trace_langfuse` or `trace_langsmith` for full patterns). Payload types are defined in `core/ops/entities/trace_entity.py`, including: + +- `WorkflowTraceInfo`, `WorkflowNodeTraceInfo`, `DraftNodeExecutionTrace` +- `MessageTraceInfo`, `ToolTraceInfo`, `ModerationTraceInfo`, `SuggestedQuestionTraceInfo` +- `DatasetRetrievalTraceInfo`, `GenerateNameTraceInfo`, `PromptGenerationTraceInfo` + +You may ignore categories your backend does not support; existing providers often no-op unhandled types. + +Optional: use `get_service_account_with_tenant(app_id)` from the base class when you need tenant-scoped account context. + +### 3. Register in the API core + +Upstream changes are required so Dify knows your provider exists: + +1. **`TracingProviderEnum`** (`api/core/ops/entities/config_entity.py`) — add a new member whose **value** is the stable string stored in app tracing config (e.g. `"mybackend"`). +2. **`OpsTraceProviderConfigMap.__getitem__`** (`api/core/ops/ops_trace_manager.py`) — add a `match` case for that enum member returning: + - `config_class`: your Pydantic config type + - `secret_keys` / `other_keys`: lists of field names as above + - `trace_instance`: your `BaseTraceInstance` subclass + Lazy-import your package inside the case so missing optional installs raise a clear `ImportError`. + +If the `match` case is missing, the provider string will not resolve and tracing will be disabled for that app. + +## Package layout + +Each provider is a normal uv workspace member, for example: + +- `api/providers/trace/trace-/pyproject.toml` — project name `dify-trace-`, dependencies on vendor SDKs +- `api/providers/trace/trace-/src/dify_trace_/` — `config.py`, `_trace.py`, optional `entities/`, and an empty **`py.typed`** file (PEP 561) so the API type checker treats the package as typed; list `py.typed` under `[tool.setuptools.package-data]` for that import name in `pyproject.toml`. + +Reference implementations: `trace-langfuse/`, `trace-langsmith/`, `trace-opik/`. + +## Wiring into the `api` workspace + +In `api/pyproject.toml`: + +1. **`[tool.uv.sources]`** — `dify-trace- = { workspace = true }` +2. **`[dependency-groups]`** — add `trace- = ["dify-trace-"]` and include `dify-trace-` in `trace-all` if it should ship with the default bundle + +After changing metadata, run **`uv sync`** from `api/`. diff --git a/api/providers/trace/trace-aliyun/pyproject.toml b/api/providers/trace/trace-aliyun/pyproject.toml new file mode 100644 index 0000000000..bcef7e9fb1 --- /dev/null +++ b/api/providers/trace/trace-aliyun/pyproject.toml @@ -0,0 +1,14 @@ +[project] +name = "dify-trace-aliyun" +version = "0.0.1" +dependencies = [ + # versions inherited from parent + "opentelemetry-api", + "opentelemetry-exporter-otlp-proto-grpc", + "opentelemetry-sdk", + "opentelemetry-semantic-conventions", +] +description = "Dify ops tracing provider (Aliyun)." + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/core/ops/aliyun_trace/__init__.py b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/__init__.py similarity index 100% rename from api/core/ops/aliyun_trace/__init__.py rename to api/providers/trace/trace-aliyun/src/dify_trace_aliyun/__init__.py diff --git a/api/core/ops/aliyun_trace/aliyun_trace.py b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/aliyun_trace.py similarity index 98% rename from api/core/ops/aliyun_trace/aliyun_trace.py rename to api/providers/trace/trace-aliyun/src/dify_trace_aliyun/aliyun_trace.py index 76e81242f4..54d2f8167f 100644 --- a/api/core/ops/aliyun_trace/aliyun_trace.py +++ b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/aliyun_trace.py @@ -4,7 +4,20 @@ from collections.abc import Sequence from opentelemetry.trace import SpanKind from sqlalchemy.orm import sessionmaker -from core.ops.aliyun_trace.data_exporter.traceclient import ( +from core.ops.base_trace_instance import BaseTraceInstance +from core.ops.entities.trace_entity import ( + BaseTraceInfo, + DatasetRetrievalTraceInfo, + GenerateNameTraceInfo, + MessageTraceInfo, + ModerationTraceInfo, + SuggestedQuestionTraceInfo, + ToolTraceInfo, + WorkflowTraceInfo, +) +from core.repositories import DifyCoreRepositoryFactory +from dify_trace_aliyun.config import AliyunConfig +from dify_trace_aliyun.data_exporter.traceclient import ( TraceClient, build_endpoint, convert_datetime_to_nanoseconds, @@ -12,8 +25,8 @@ from core.ops.aliyun_trace.data_exporter.traceclient import ( convert_to_trace_id, generate_span_id, ) -from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData, TraceMetadata -from core.ops.aliyun_trace.entities.semconv import ( +from dify_trace_aliyun.entities.aliyun_trace_entity import SpanData, TraceMetadata +from dify_trace_aliyun.entities.semconv import ( DIFY_APP_ID, GEN_AI_COMPLETION, GEN_AI_INPUT_MESSAGE, @@ -32,7 +45,7 @@ from core.ops.aliyun_trace.entities.semconv import ( TOOL_PARAMETERS, GenAISpanKind, ) -from core.ops.aliyun_trace.utils import ( +from dify_trace_aliyun.utils import ( create_common_span_attributes, create_links_from_trace_id, create_status_from_error, @@ -44,19 +57,6 @@ from core.ops.aliyun_trace.utils import ( get_workflow_node_status, serialize_json_data, ) -from core.ops.base_trace_instance import BaseTraceInstance -from core.ops.entities.config_entity import AliyunConfig -from core.ops.entities.trace_entity import ( - BaseTraceInfo, - DatasetRetrievalTraceInfo, - GenerateNameTraceInfo, - MessageTraceInfo, - ModerationTraceInfo, - SuggestedQuestionTraceInfo, - ToolTraceInfo, - WorkflowTraceInfo, -) -from core.repositories import DifyCoreRepositoryFactory from extensions.ext_database import db from graphon.entities import WorkflowNodeExecution from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey diff --git a/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/config.py b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/config.py new file mode 100644 index 0000000000..e0133e6cc9 --- /dev/null +++ b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/config.py @@ -0,0 +1,32 @@ +from pydantic import ValidationInfo, field_validator + +from core.ops.entities.config_entity import BaseTracingConfig +from core.ops.utils import validate_url_with_path + + +class AliyunConfig(BaseTracingConfig): + """ + Model class for Aliyun tracing config. + """ + + app_name: str = "dify_app" + license_key: str + endpoint: str + + @field_validator("app_name") + @classmethod + def app_name_validator(cls, v, info: ValidationInfo): + return cls.validate_project_field(v, "dify_app") + + @field_validator("license_key") + @classmethod + def license_key_validator(cls, v, info: ValidationInfo): + if not v or v.strip() == "": + raise ValueError("License key cannot be empty") + return v + + @field_validator("endpoint") + @classmethod + def endpoint_validator(cls, v, info: ValidationInfo): + # aliyun uses two URL formats, which may include a URL path + return validate_url_with_path(v, "https://tracing-analysis-dc-hz.aliyuncs.com") diff --git a/api/core/ops/aliyun_trace/data_exporter/__init__.py b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/data_exporter/__init__.py similarity index 100% rename from api/core/ops/aliyun_trace/data_exporter/__init__.py rename to api/providers/trace/trace-aliyun/src/dify_trace_aliyun/data_exporter/__init__.py diff --git a/api/core/ops/aliyun_trace/data_exporter/traceclient.py b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/data_exporter/traceclient.py similarity index 98% rename from api/core/ops/aliyun_trace/data_exporter/traceclient.py rename to api/providers/trace/trace-aliyun/src/dify_trace_aliyun/data_exporter/traceclient.py index 67d5163b0f..00aab6bf89 100644 --- a/api/core/ops/aliyun_trace/data_exporter/traceclient.py +++ b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/data_exporter/traceclient.py @@ -26,8 +26,8 @@ from opentelemetry.semconv.attributes import service_attributes from opentelemetry.trace import Link, SpanContext, TraceFlags from configs import dify_config -from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData -from core.ops.aliyun_trace.entities.semconv import ACS_ARMS_SERVICE_FEATURE +from dify_trace_aliyun.entities.aliyun_trace_entity import SpanData +from dify_trace_aliyun.entities.semconv import ACS_ARMS_SERVICE_FEATURE INVALID_SPAN_ID: Final[int] = 0x0000000000000000 INVALID_TRACE_ID: Final[int] = 0x00000000000000000000000000000000 diff --git a/api/core/ops/aliyun_trace/entities/__init__.py b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/entities/__init__.py similarity index 100% rename from api/core/ops/aliyun_trace/entities/__init__.py rename to api/providers/trace/trace-aliyun/src/dify_trace_aliyun/entities/__init__.py diff --git a/api/core/ops/aliyun_trace/entities/aliyun_trace_entity.py b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/entities/aliyun_trace_entity.py similarity index 100% rename from api/core/ops/aliyun_trace/entities/aliyun_trace_entity.py rename to api/providers/trace/trace-aliyun/src/dify_trace_aliyun/entities/aliyun_trace_entity.py diff --git a/api/core/ops/aliyun_trace/entities/semconv.py b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/entities/semconv.py similarity index 100% rename from api/core/ops/aliyun_trace/entities/semconv.py rename to api/providers/trace/trace-aliyun/src/dify_trace_aliyun/entities/semconv.py diff --git a/api/core/ops/arize_phoenix_trace/__init__.py b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/py.typed similarity index 100% rename from api/core/ops/arize_phoenix_trace/__init__.py rename to api/providers/trace/trace-aliyun/src/dify_trace_aliyun/py.typed diff --git a/api/core/ops/aliyun_trace/utils.py b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/utils.py similarity index 97% rename from api/core/ops/aliyun_trace/utils.py rename to api/providers/trace/trace-aliyun/src/dify_trace_aliyun/utils.py index 2e02a186cc..5678c66adb 100644 --- a/api/core/ops/aliyun_trace/utils.py +++ b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/utils.py @@ -4,7 +4,8 @@ from typing import Any, TypedDict from opentelemetry.trace import Link, Status, StatusCode -from core.ops.aliyun_trace.entities.semconv import ( +from core.rag.models.document import Document +from dify_trace_aliyun.entities.semconv import ( GEN_AI_FRAMEWORK, GEN_AI_SESSION_ID, GEN_AI_SPAN_KIND, @@ -13,7 +14,6 @@ from core.ops.aliyun_trace.entities.semconv import ( OUTPUT_VALUE, GenAISpanKind, ) -from core.rag.models.document import Document from extensions.ext_database import db from graphon.entities import WorkflowNodeExecution from graphon.enums import WorkflowNodeExecutionStatus @@ -48,7 +48,7 @@ def get_workflow_node_status(node_execution: WorkflowNodeExecution) -> Status: def create_links_from_trace_id(trace_id: str | None) -> list[Link]: - from core.ops.aliyun_trace.data_exporter.traceclient import create_link + from dify_trace_aliyun.data_exporter.traceclient import create_link links = [] if trace_id: diff --git a/api/tests/unit_tests/core/ops/aliyun_trace/data_exporter/test_traceclient.py b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/data_exporter/test_traceclient.py similarity index 86% rename from api/tests/unit_tests/core/ops/aliyun_trace/data_exporter/test_traceclient.py rename to api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/data_exporter/test_traceclient.py index acb43d4036..286dda419c 100644 --- a/api/tests/unit_tests/core/ops/aliyun_trace/data_exporter/test_traceclient.py +++ b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/data_exporter/test_traceclient.py @@ -5,10 +5,7 @@ from unittest.mock import MagicMock, patch import httpx import pytest -from opentelemetry.sdk.trace import ReadableSpan -from opentelemetry.trace import SpanKind, Status, StatusCode - -from core.ops.aliyun_trace.data_exporter.traceclient import ( +from dify_trace_aliyun.data_exporter.traceclient import ( INVALID_SPAN_ID, SpanBuilder, TraceClient, @@ -20,7 +17,9 @@ from core.ops.aliyun_trace.data_exporter.traceclient import ( create_link, generate_span_id, ) -from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData +from dify_trace_aliyun.entities.aliyun_trace_entity import SpanData +from opentelemetry.sdk.trace import ReadableSpan +from opentelemetry.trace import SpanKind, Status, StatusCode @pytest.fixture @@ -41,8 +40,8 @@ def trace_client_factory(): class TestTraceClient: - @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter") - @patch("core.ops.aliyun_trace.data_exporter.traceclient.socket.gethostname") + @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter") + @patch("dify_trace_aliyun.data_exporter.traceclient.socket.gethostname") def test_init(self, mock_gethostname, mock_exporter_class, trace_client_factory): mock_gethostname.return_value = "test-host" client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint") @@ -56,7 +55,7 @@ class TestTraceClient: client.shutdown() assert client.done is True - @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter") + @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter") def test_export(self, mock_exporter_class, trace_client_factory): mock_exporter = mock_exporter_class.return_value client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint") @@ -64,8 +63,8 @@ class TestTraceClient: client.export(spans) mock_exporter.export.assert_called_once_with(spans) - @patch("core.ops.aliyun_trace.data_exporter.traceclient.httpx.head") - @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter") + @patch("dify_trace_aliyun.data_exporter.traceclient.httpx.head") + @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter") def test_api_check_success(self, mock_exporter_class, mock_head, trace_client_factory): mock_response = MagicMock() mock_response.status_code = 405 @@ -74,8 +73,8 @@ class TestTraceClient: client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint") assert client.api_check() is True - @patch("core.ops.aliyun_trace.data_exporter.traceclient.httpx.head") - @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter") + @patch("dify_trace_aliyun.data_exporter.traceclient.httpx.head") + @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter") def test_api_check_failure_status(self, mock_exporter_class, mock_head, trace_client_factory): mock_response = MagicMock() mock_response.status_code = 500 @@ -84,8 +83,8 @@ class TestTraceClient: client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint") assert client.api_check() is False - @patch("core.ops.aliyun_trace.data_exporter.traceclient.httpx.head") - @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter") + @patch("dify_trace_aliyun.data_exporter.traceclient.httpx.head") + @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter") def test_api_check_exception(self, mock_exporter_class, mock_head, trace_client_factory): mock_head.side_effect = httpx.RequestError("Connection error") @@ -93,12 +92,12 @@ class TestTraceClient: with pytest.raises(ValueError, match="AliyunTrace API check failed: Connection error"): client.api_check() - @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter") + @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter") def test_get_project_url(self, mock_exporter_class, trace_client_factory): client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint") assert client.get_project_url() == "https://arms.console.aliyun.com/#/llm" - @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter") + @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter") def test_add_span(self, mock_exporter_class, trace_client_factory): client = trace_client_factory( service_name="test-service", @@ -134,8 +133,8 @@ class TestTraceClient: assert len(client.queue) == 2 mock_notify.assert_called_once() - @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter") - @patch("core.ops.aliyun_trace.data_exporter.traceclient.logger") + @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter") + @patch("dify_trace_aliyun.data_exporter.traceclient.logger") def test_add_span_queue_full(self, mock_logger, mock_exporter_class, trace_client_factory): client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint", max_queue_size=1) @@ -159,7 +158,7 @@ class TestTraceClient: assert len(client.queue) == 1 mock_logger.warning.assert_called_with("Queue is full, likely spans will be dropped.") - @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter") + @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter") def test_export_batch_error(self, mock_exporter_class, trace_client_factory): mock_exporter = mock_exporter_class.return_value mock_exporter.export.side_effect = Exception("Export failed") @@ -168,11 +167,11 @@ class TestTraceClient: mock_span = MagicMock(spec=ReadableSpan) client.queue.append(mock_span) - with patch("core.ops.aliyun_trace.data_exporter.traceclient.logger") as mock_logger: + with patch("dify_trace_aliyun.data_exporter.traceclient.logger") as mock_logger: client._export_batch() mock_logger.warning.assert_called() - @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter") + @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter") def test_worker_loop(self, mock_exporter_class, trace_client_factory): # We need to test the wait timeout in _worker # But _worker runs in a thread. Let's mock condition.wait. @@ -189,7 +188,7 @@ class TestTraceClient: # mock_wait might have been called assert mock_wait.called or client.done - @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter") + @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter") def test_shutdown_flushes(self, mock_exporter_class, trace_client_factory): mock_exporter = mock_exporter_class.return_value client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint") @@ -268,7 +267,7 @@ def test_generate_span_id(): assert span_id != INVALID_SPAN_ID # Test retry loop - with patch("core.ops.aliyun_trace.data_exporter.traceclient.random.getrandbits") as mock_rand: + with patch("dify_trace_aliyun.data_exporter.traceclient.random.getrandbits") as mock_rand: mock_rand.side_effect = [INVALID_SPAN_ID, 999] span_id = generate_span_id() assert span_id == 999 @@ -290,7 +289,7 @@ def test_convert_to_trace_id(): def test_convert_string_to_id(): assert convert_string_to_id("test") > 0 # Test with None string - with patch("core.ops.aliyun_trace.data_exporter.traceclient.generate_span_id") as mock_gen: + with patch("dify_trace_aliyun.data_exporter.traceclient.generate_span_id") as mock_gen: mock_gen.return_value = 12345 assert convert_string_to_id(None) == 12345 diff --git a/api/tests/unit_tests/core/ops/aliyun_trace/entities/test_aliyun_trace_entity.py b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/entities/test_aliyun_trace_entity.py similarity index 97% rename from api/tests/unit_tests/core/ops/aliyun_trace/entities/test_aliyun_trace_entity.py rename to api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/entities/test_aliyun_trace_entity.py index 2fcb927e0c..38d33dd21b 100644 --- a/api/tests/unit_tests/core/ops/aliyun_trace/entities/test_aliyun_trace_entity.py +++ b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/entities/test_aliyun_trace_entity.py @@ -1,11 +1,10 @@ import pytest +from dify_trace_aliyun.entities.aliyun_trace_entity import SpanData, TraceMetadata from opentelemetry import trace as trace_api from opentelemetry.sdk.trace import Event from opentelemetry.trace import SpanKind, Status, StatusCode from pydantic import ValidationError -from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData, TraceMetadata - class TestTraceMetadata: def test_trace_metadata_init(self): diff --git a/api/tests/unit_tests/core/ops/aliyun_trace/entities/test_semconv.py b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/entities/test_semconv.py similarity index 97% rename from api/tests/unit_tests/core/ops/aliyun_trace/entities/test_semconv.py rename to api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/entities/test_semconv.py index 3961555b9a..9cab40748f 100644 --- a/api/tests/unit_tests/core/ops/aliyun_trace/entities/test_semconv.py +++ b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/entities/test_semconv.py @@ -1,4 +1,4 @@ -from core.ops.aliyun_trace.entities.semconv import ( +from dify_trace_aliyun.entities.semconv import ( ACS_ARMS_SERVICE_FEATURE, GEN_AI_COMPLETION, GEN_AI_FRAMEWORK, diff --git a/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace.py b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/test_aliyun_trace.py similarity index 99% rename from api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace.py rename to api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/test_aliyun_trace.py index c2324fdec4..c1b11c9186 100644 --- a/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace.py +++ b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/test_aliyun_trace.py @@ -4,12 +4,11 @@ from datetime import UTC, datetime from types import SimpleNamespace from unittest.mock import MagicMock +import dify_trace_aliyun.aliyun_trace as aliyun_trace_module import pytest -from opentelemetry.trace import Link, SpanContext, SpanKind, Status, StatusCode, TraceFlags - -import core.ops.aliyun_trace.aliyun_trace as aliyun_trace_module -from core.ops.aliyun_trace.aliyun_trace import AliyunDataTrace -from core.ops.aliyun_trace.entities.semconv import ( +from dify_trace_aliyun.aliyun_trace import AliyunDataTrace +from dify_trace_aliyun.config import AliyunConfig +from dify_trace_aliyun.entities.semconv import ( GEN_AI_COMPLETION, GEN_AI_INPUT_MESSAGE, GEN_AI_OUTPUT_MESSAGE, @@ -24,7 +23,8 @@ from core.ops.aliyun_trace.entities.semconv import ( TOOL_PARAMETERS, GenAISpanKind, ) -from core.ops.entities.config_entity import AliyunConfig +from opentelemetry.trace import Link, SpanContext, SpanKind, Status, StatusCode, TraceFlags + from core.ops.entities.trace_entity import ( DatasetRetrievalTraceInfo, GenerateNameTraceInfo, diff --git a/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace_utils.py b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/test_aliyun_trace_utils.py similarity index 95% rename from api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace_utils.py rename to api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/test_aliyun_trace_utils.py index e4d8f2d5ea..a9e7b80c2a 100644 --- a/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace_utils.py +++ b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/test_aliyun_trace_utils.py @@ -1,9 +1,7 @@ import json from unittest.mock import MagicMock -from opentelemetry.trace import Link, StatusCode - -from core.ops.aliyun_trace.entities.semconv import ( +from dify_trace_aliyun.entities.semconv import ( GEN_AI_FRAMEWORK, GEN_AI_SESSION_ID, GEN_AI_SPAN_KIND, @@ -11,7 +9,7 @@ from core.ops.aliyun_trace.entities.semconv import ( INPUT_VALUE, OUTPUT_VALUE, ) -from core.ops.aliyun_trace.utils import ( +from dify_trace_aliyun.utils import ( create_common_span_attributes, create_links_from_trace_id, create_status_from_error, @@ -23,6 +21,8 @@ from core.ops.aliyun_trace.utils import ( get_workflow_node_status, serialize_json_data, ) +from opentelemetry.trace import Link, StatusCode + from core.rag.models.document import Document from graphon.entities import WorkflowNodeExecution from graphon.enums import WorkflowNodeExecutionStatus @@ -48,7 +48,7 @@ def test_get_user_id_from_message_data_with_end_user(monkeypatch): mock_session = MagicMock() mock_session.get.return_value = end_user_data - from core.ops.aliyun_trace.utils import db + from dify_trace_aliyun.utils import db monkeypatch.setattr(db, "session", mock_session) @@ -63,7 +63,7 @@ def test_get_user_id_from_message_data_end_user_not_found(monkeypatch): mock_session = MagicMock() mock_session.get.return_value = None - from core.ops.aliyun_trace.utils import db + from dify_trace_aliyun.utils import db monkeypatch.setattr(db, "session", mock_session) @@ -112,9 +112,9 @@ def test_get_workflow_node_status(): def test_create_links_from_trace_id(monkeypatch): # Mock create_link mock_link = MagicMock(spec=Link) - import core.ops.aliyun_trace.data_exporter.traceclient + import dify_trace_aliyun.data_exporter.traceclient - monkeypatch.setattr(core.ops.aliyun_trace.data_exporter.traceclient, "create_link", lambda trace_id_str: mock_link) + monkeypatch.setattr(dify_trace_aliyun.data_exporter.traceclient, "create_link", lambda trace_id_str: mock_link) # Trace ID None assert create_links_from_trace_id(None) == [] diff --git a/api/providers/trace/trace-aliyun/tests/unit_tests/test_config_entity.py b/api/providers/trace/trace-aliyun/tests/unit_tests/test_config_entity.py new file mode 100644 index 0000000000..1b24ee7421 --- /dev/null +++ b/api/providers/trace/trace-aliyun/tests/unit_tests/test_config_entity.py @@ -0,0 +1,85 @@ +import pytest +from dify_trace_aliyun.config import AliyunConfig +from pydantic import ValidationError + + +class TestAliyunConfig: + """Test cases for AliyunConfig""" + + def test_valid_config(self): + """Test valid Aliyun configuration""" + config = AliyunConfig( + app_name="test_app", + license_key="test_license_key", + endpoint="https://custom.tracing-analysis-dc-hz.aliyuncs.com", + ) + assert config.app_name == "test_app" + assert config.license_key == "test_license_key" + assert config.endpoint == "https://custom.tracing-analysis-dc-hz.aliyuncs.com" + + def test_default_values(self): + """Test default values are set correctly""" + config = AliyunConfig(license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com") + assert config.app_name == "dify_app" + + def test_missing_required_fields(self): + """Test that required fields are enforced""" + with pytest.raises(ValidationError): + AliyunConfig() + + with pytest.raises(ValidationError): + AliyunConfig(license_key="test_license") + + with pytest.raises(ValidationError): + AliyunConfig(endpoint="https://tracing-analysis-dc-hz.aliyuncs.com") + + def test_app_name_validation_empty(self): + """Test app_name validation with empty value""" + config = AliyunConfig( + license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com", app_name="" + ) + assert config.app_name == "dify_app" + + def test_endpoint_validation_empty(self): + """Test endpoint validation with empty value""" + config = AliyunConfig(license_key="test_license", endpoint="") + assert config.endpoint == "https://tracing-analysis-dc-hz.aliyuncs.com" + + def test_endpoint_validation_with_path(self): + """Test endpoint validation preserves path for Aliyun endpoints""" + config = AliyunConfig( + license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com/api/v1/traces" + ) + assert config.endpoint == "https://tracing-analysis-dc-hz.aliyuncs.com/api/v1/traces" + + def test_endpoint_validation_invalid_scheme(self): + """Test endpoint validation rejects invalid schemes""" + with pytest.raises(ValidationError, match="URL must start with https:// or http://"): + AliyunConfig(license_key="test_license", endpoint="ftp://invalid.tracing-analysis-dc-hz.aliyuncs.com") + + def test_endpoint_validation_no_scheme(self): + """Test endpoint validation rejects URLs without scheme""" + with pytest.raises(ValidationError, match="URL must start with https:// or http://"): + AliyunConfig(license_key="test_license", endpoint="invalid.tracing-analysis-dc-hz.aliyuncs.com") + + def test_license_key_required(self): + """Test that license_key is required and cannot be empty""" + with pytest.raises(ValidationError): + AliyunConfig(license_key="", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com") + + def test_valid_endpoint_format_examples(self): + """Test valid endpoint format examples from comments""" + valid_endpoints = [ + # cms2.0 public endpoint + "https://proj-xtrace-123456-cn-heyuan.cn-heyuan.log.aliyuncs.com/apm/trace/opentelemetry", + # cms2.0 intranet endpoint + "https://proj-xtrace-123456-cn-heyuan.cn-heyuan-intranet.log.aliyuncs.com/apm/trace/opentelemetry", + # xtrace public endpoint + "http://tracing-cn-heyuan.arms.aliyuncs.com", + # xtrace intranet endpoint + "http://tracing-cn-heyuan-internal.arms.aliyuncs.com", + ] + + for endpoint in valid_endpoints: + config = AliyunConfig(license_key="test_license", endpoint=endpoint) + assert config.endpoint == endpoint diff --git a/api/providers/trace/trace-arize-phoenix/pyproject.toml b/api/providers/trace/trace-arize-phoenix/pyproject.toml new file mode 100644 index 0000000000..9e756944c9 --- /dev/null +++ b/api/providers/trace/trace-arize-phoenix/pyproject.toml @@ -0,0 +1,10 @@ +[project] +name = "dify-trace-arize-phoenix" +version = "0.0.1" +dependencies = [ + "arize-phoenix-otel~=0.15.0", +] +description = "Dify ops tracing provider (Arize / Phoenix)." + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/core/ops/langfuse_trace/__init__.py b/api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/__init__.py similarity index 100% rename from api/core/ops/langfuse_trace/__init__.py rename to api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/__init__.py diff --git a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py b/api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/arize_phoenix_trace.py similarity index 99% rename from api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py rename to api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/arize_phoenix_trace.py index 78516e1a22..96df49ed0e 100644 --- a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py +++ b/api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/arize_phoenix_trace.py @@ -25,7 +25,6 @@ from opentelemetry.util.types import AttributeValue from sqlalchemy.orm import sessionmaker from core.ops.base_trace_instance import BaseTraceInstance -from core.ops.entities.config_entity import ArizeConfig, PhoenixConfig from core.ops.entities.trace_entity import ( BaseTraceInfo, DatasetRetrievalTraceInfo, @@ -39,6 +38,7 @@ from core.ops.entities.trace_entity import ( ) from core.ops.utils import JSON_DICT_ADAPTER from core.repositories import DifyCoreRepositoryFactory +from dify_trace_arize_phoenix.config import ArizeConfig, PhoenixConfig from extensions.ext_database import db from graphon.enums import WorkflowNodeExecutionStatus from models.model import EndUser, MessageFile diff --git a/api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/config.py b/api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/config.py new file mode 100644 index 0000000000..6eac5b30d2 --- /dev/null +++ b/api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/config.py @@ -0,0 +1,45 @@ +from pydantic import ValidationInfo, field_validator + +from core.ops.entities.config_entity import BaseTracingConfig +from core.ops.utils import validate_url_with_path + + +class ArizeConfig(BaseTracingConfig): + """ + Model class for Arize tracing config. + """ + + api_key: str | None = None + space_id: str | None = None + project: str | None = None + endpoint: str = "https://otlp.arize.com" + + @field_validator("project") + @classmethod + def project_validator(cls, v, info: ValidationInfo): + return cls.validate_project_field(v, "default") + + @field_validator("endpoint") + @classmethod + def endpoint_validator(cls, v, info: ValidationInfo): + return cls.validate_endpoint_url(v, "https://otlp.arize.com") + + +class PhoenixConfig(BaseTracingConfig): + """ + Model class for Phoenix tracing config. + """ + + api_key: str | None = None + project: str | None = None + endpoint: str = "https://app.phoenix.arize.com" + + @field_validator("project") + @classmethod + def project_validator(cls, v, info: ValidationInfo): + return cls.validate_project_field(v, "default") + + @field_validator("endpoint") + @classmethod + def endpoint_validator(cls, v, info: ValidationInfo): + return validate_url_with_path(v, "https://app.phoenix.arize.com") diff --git a/api/core/ops/langfuse_trace/entities/__init__.py b/api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/py.typed similarity index 100% rename from api/core/ops/langfuse_trace/entities/__init__.py rename to api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/py.typed diff --git a/api/tests/unit_tests/core/ops/arize_phoenix_trace/test_arize_phoenix_trace.py b/api/providers/trace/trace-arize-phoenix/tests/unit_tests/arize_phoenix_trace/test_arize_phoenix_trace.py similarity index 91% rename from api/tests/unit_tests/core/ops/arize_phoenix_trace/test_arize_phoenix_trace.py rename to api/providers/trace/trace-arize-phoenix/tests/unit_tests/arize_phoenix_trace/test_arize_phoenix_trace.py index 4ce9e22fd7..b0691a87ea 100644 --- a/api/tests/unit_tests/core/ops/arize_phoenix_trace/test_arize_phoenix_trace.py +++ b/api/providers/trace/trace-arize-phoenix/tests/unit_tests/arize_phoenix_trace/test_arize_phoenix_trace.py @@ -2,11 +2,7 @@ from datetime import UTC, datetime, timedelta from unittest.mock import MagicMock, patch import pytest -from opentelemetry.sdk.trace import Tracer -from opentelemetry.semconv.trace import SpanAttributes as OTELSpanAttributes -from opentelemetry.trace import StatusCode - -from core.ops.arize_phoenix_trace.arize_phoenix_trace import ( +from dify_trace_arize_phoenix.arize_phoenix_trace import ( ArizePhoenixDataTrace, datetime_to_nanos, error_to_string, @@ -15,7 +11,11 @@ from core.ops.arize_phoenix_trace.arize_phoenix_trace import ( setup_tracer, wrap_span_metadata, ) -from core.ops.entities.config_entity import ArizeConfig, PhoenixConfig +from dify_trace_arize_phoenix.config import ArizeConfig, PhoenixConfig +from opentelemetry.sdk.trace import Tracer +from opentelemetry.semconv.trace import SpanAttributes as OTELSpanAttributes +from opentelemetry.trace import StatusCode + from core.ops.entities.trace_entity import ( DatasetRetrievalTraceInfo, GenerateNameTraceInfo, @@ -80,7 +80,7 @@ def test_datetime_to_nanos(): expected = int(dt.timestamp() * 1_000_000_000) assert datetime_to_nanos(dt) == expected - with patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.datetime") as mock_dt: + with patch("dify_trace_arize_phoenix.arize_phoenix_trace.datetime") as mock_dt: mock_now = MagicMock() mock_now.timestamp.return_value = 1704110400.0 mock_dt.now.return_value = mock_now @@ -142,8 +142,8 @@ def test_wrap_span_metadata(): assert res == {"a": 1, "b": 2, "created_from": "Dify"} -@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.GrpcOTLPSpanExporter") -@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.trace_sdk.TracerProvider") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.GrpcOTLPSpanExporter") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.trace_sdk.TracerProvider") def test_setup_tracer_arize(mock_provider, mock_exporter): config = ArizeConfig(endpoint="http://a.com", api_key="k", space_id="s", project="p") setup_tracer(config) @@ -151,8 +151,8 @@ def test_setup_tracer_arize(mock_provider, mock_exporter): assert mock_exporter.call_args[1]["endpoint"] == "http://a.com/v1" -@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.HttpOTLPSpanExporter") -@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.trace_sdk.TracerProvider") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.HttpOTLPSpanExporter") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.trace_sdk.TracerProvider") def test_setup_tracer_phoenix(mock_provider, mock_exporter): config = PhoenixConfig(endpoint="http://p.com", project="p") setup_tracer(config) @@ -162,7 +162,7 @@ def test_setup_tracer_phoenix(mock_provider, mock_exporter): def test_setup_tracer_exception(): config = ArizeConfig(endpoint="http://a.com", project="p") - with patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.urlparse", side_effect=Exception("boom")): + with patch("dify_trace_arize_phoenix.arize_phoenix_trace.urlparse", side_effect=Exception("boom")): with pytest.raises(Exception, match="boom"): setup_tracer(config) @@ -172,7 +172,7 @@ def test_setup_tracer_exception(): @pytest.fixture def trace_instance(): - with patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.setup_tracer") as mock_setup: + with patch("dify_trace_arize_phoenix.arize_phoenix_trace.setup_tracer") as mock_setup: mock_tracer = MagicMock(spec=Tracer) mock_processor = MagicMock() mock_setup.return_value = (mock_tracer, mock_processor) @@ -228,9 +228,9 @@ def test_trace_exception(trace_instance): trace_instance.trace(_make_workflow_info()) -@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.sessionmaker") -@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.DifyCoreRepositoryFactory") -@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.db") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.sessionmaker") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.DifyCoreRepositoryFactory") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.db") def test_workflow_trace_full(mock_db, mock_repo_factory, mock_sessionmaker, trace_instance): mock_db.engine = MagicMock() info = _make_workflow_info() @@ -262,7 +262,7 @@ def test_workflow_trace_full(mock_db, mock_repo_factory, mock_sessionmaker, trac assert trace_instance.tracer.start_span.call_count >= 2 -@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.db") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.db") def test_workflow_trace_no_app_id(mock_db, trace_instance): mock_db.engine = MagicMock() info = _make_workflow_info() @@ -271,7 +271,7 @@ def test_workflow_trace_no_app_id(mock_db, trace_instance): trace_instance.workflow_trace(info) -@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.db") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.db") def test_message_trace_success(mock_db, trace_instance): mock_db.engine = MagicMock() info = _make_message_info() @@ -291,7 +291,7 @@ def test_message_trace_success(mock_db, trace_instance): assert trace_instance.tracer.start_span.call_count >= 1 -@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.db") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.db") def test_message_trace_with_error(mock_db, trace_instance): mock_db.engine = MagicMock() info = _make_message_info() diff --git a/api/tests/unit_tests/core/ops/test_arize_phoenix_trace.py b/api/providers/trace/trace-arize-phoenix/tests/unit_tests/test_arize_phoenix_trace.py similarity index 94% rename from api/tests/unit_tests/core/ops/test_arize_phoenix_trace.py rename to api/providers/trace/trace-arize-phoenix/tests/unit_tests/test_arize_phoenix_trace.py index 4b925390d9..a01c63ae61 100644 --- a/api/tests/unit_tests/core/ops/test_arize_phoenix_trace.py +++ b/api/providers/trace/trace-arize-phoenix/tests/unit_tests/test_arize_phoenix_trace.py @@ -1,6 +1,6 @@ +from dify_trace_arize_phoenix.arize_phoenix_trace import _NODE_TYPE_TO_SPAN_KIND, _get_node_span_kind from openinference.semconv.trace import OpenInferenceSpanKindValues -from core.ops.arize_phoenix_trace.arize_phoenix_trace import _NODE_TYPE_TO_SPAN_KIND, _get_node_span_kind from graphon.enums import BUILT_IN_NODE_TYPES, BuiltinNodeTypes diff --git a/api/providers/trace/trace-arize-phoenix/tests/unit_tests/test_config_entity.py b/api/providers/trace/trace-arize-phoenix/tests/unit_tests/test_config_entity.py new file mode 100644 index 0000000000..11e951c3b1 --- /dev/null +++ b/api/providers/trace/trace-arize-phoenix/tests/unit_tests/test_config_entity.py @@ -0,0 +1,88 @@ +import pytest +from dify_trace_arize_phoenix.config import ArizeConfig, PhoenixConfig +from pydantic import ValidationError + + +class TestArizeConfig: + """Test cases for ArizeConfig""" + + def test_valid_config(self): + """Test valid Arize configuration""" + config = ArizeConfig( + api_key="test_key", space_id="test_space", project="test_project", endpoint="https://custom.arize.com" + ) + assert config.api_key == "test_key" + assert config.space_id == "test_space" + assert config.project == "test_project" + assert config.endpoint == "https://custom.arize.com" + + def test_default_values(self): + """Test default values are set correctly""" + config = ArizeConfig() + assert config.api_key is None + assert config.space_id is None + assert config.project is None + assert config.endpoint == "https://otlp.arize.com" + + def test_project_validation_empty(self): + """Test project validation with empty value""" + config = ArizeConfig(project="") + assert config.project == "default" + + def test_project_validation_none(self): + """Test project validation with None value""" + config = ArizeConfig(project=None) + assert config.project == "default" + + def test_endpoint_validation_empty(self): + """Test endpoint validation with empty value""" + config = ArizeConfig(endpoint="") + assert config.endpoint == "https://otlp.arize.com" + + def test_endpoint_validation_with_path(self): + """Test endpoint validation normalizes URL by removing path""" + config = ArizeConfig(endpoint="https://custom.arize.com/api/v1") + assert config.endpoint == "https://custom.arize.com" + + def test_endpoint_validation_invalid_scheme(self): + """Test endpoint validation rejects invalid schemes""" + with pytest.raises(ValidationError, match="URL scheme must be one of"): + ArizeConfig(endpoint="ftp://invalid.com") + + def test_endpoint_validation_no_scheme(self): + """Test endpoint validation rejects URLs without scheme""" + with pytest.raises(ValidationError, match="URL scheme must be one of"): + ArizeConfig(endpoint="invalid.com") + + +class TestPhoenixConfig: + """Test cases for PhoenixConfig""" + + def test_valid_config(self): + """Test valid Phoenix configuration""" + config = PhoenixConfig(api_key="test_key", project="test_project", endpoint="https://custom.phoenix.com") + assert config.api_key == "test_key" + assert config.project == "test_project" + assert config.endpoint == "https://custom.phoenix.com" + + def test_default_values(self): + """Test default values are set correctly""" + config = PhoenixConfig() + assert config.api_key is None + assert config.project is None + assert config.endpoint == "https://app.phoenix.arize.com" + + def test_project_validation_empty(self): + """Test project validation with empty value""" + config = PhoenixConfig(project="") + assert config.project == "default" + + def test_endpoint_validation_with_path(self): + """Test endpoint validation with path""" + config = PhoenixConfig(endpoint="https://app.phoenix.arize.com/s/dify-integration") + assert config.endpoint == "https://app.phoenix.arize.com/s/dify-integration" + + def test_endpoint_validation_without_path(self): + """Test endpoint validation without path""" + config = PhoenixConfig(endpoint="https://app.phoenix.arize.com") + assert config.endpoint == "https://app.phoenix.arize.com" diff --git a/api/providers/trace/trace-langfuse/pyproject.toml b/api/providers/trace/trace-langfuse/pyproject.toml new file mode 100644 index 0000000000..27d2273a69 --- /dev/null +++ b/api/providers/trace/trace-langfuse/pyproject.toml @@ -0,0 +1,10 @@ +[project] +name = "dify-trace-langfuse" +version = "0.0.1" +dependencies = [ + "langfuse>=4.2.0,<5.0.0", +] +description = "Dify ops tracing provider (Langfuse)." + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/core/ops/langsmith_trace/__init__.py b/api/providers/trace/trace-langfuse/src/dify_trace_langfuse/__init__.py similarity index 100% rename from api/core/ops/langsmith_trace/__init__.py rename to api/providers/trace/trace-langfuse/src/dify_trace_langfuse/__init__.py diff --git a/api/providers/trace/trace-langfuse/src/dify_trace_langfuse/config.py b/api/providers/trace/trace-langfuse/src/dify_trace_langfuse/config.py new file mode 100644 index 0000000000..90d1a2846b --- /dev/null +++ b/api/providers/trace/trace-langfuse/src/dify_trace_langfuse/config.py @@ -0,0 +1,19 @@ +from pydantic import ValidationInfo, field_validator + +from core.ops.entities.config_entity import BaseTracingConfig +from core.ops.utils import validate_url_with_path + + +class LangfuseConfig(BaseTracingConfig): + """ + Model class for Langfuse tracing config. + """ + + public_key: str + secret_key: str + host: str = "https://api.langfuse.com" + + @field_validator("host") + @classmethod + def host_validator(cls, v, info: ValidationInfo): + return validate_url_with_path(v, "https://api.langfuse.com") diff --git a/api/core/ops/langsmith_trace/entities/__init__.py b/api/providers/trace/trace-langfuse/src/dify_trace_langfuse/entities/__init__.py similarity index 100% rename from api/core/ops/langsmith_trace/entities/__init__.py rename to api/providers/trace/trace-langfuse/src/dify_trace_langfuse/entities/__init__.py diff --git a/api/core/ops/langfuse_trace/entities/langfuse_trace_entity.py b/api/providers/trace/trace-langfuse/src/dify_trace_langfuse/entities/langfuse_trace_entity.py similarity index 100% rename from api/core/ops/langfuse_trace/entities/langfuse_trace_entity.py rename to api/providers/trace/trace-langfuse/src/dify_trace_langfuse/entities/langfuse_trace_entity.py diff --git a/api/core/ops/langfuse_trace/langfuse_trace.py b/api/providers/trace/trace-langfuse/src/dify_trace_langfuse/langfuse_trace.py similarity index 99% rename from api/core/ops/langfuse_trace/langfuse_trace.py rename to api/providers/trace/trace-langfuse/src/dify_trace_langfuse/langfuse_trace.py index 7eacc2be46..68881378a7 100644 --- a/api/core/ops/langfuse_trace/langfuse_trace.py +++ b/api/providers/trace/trace-langfuse/src/dify_trace_langfuse/langfuse_trace.py @@ -16,7 +16,6 @@ from langfuse.api.commons.types.usage import Usage from sqlalchemy.orm import sessionmaker from core.ops.base_trace_instance import BaseTraceInstance -from core.ops.entities.config_entity import LangfuseConfig from core.ops.entities.trace_entity import ( BaseTraceInfo, DatasetRetrievalTraceInfo, @@ -28,7 +27,10 @@ from core.ops.entities.trace_entity import ( TraceTaskName, WorkflowTraceInfo, ) -from core.ops.langfuse_trace.entities.langfuse_trace_entity import ( +from core.ops.utils import filter_none_values +from core.repositories import DifyCoreRepositoryFactory +from dify_trace_langfuse.config import LangfuseConfig +from dify_trace_langfuse.entities.langfuse_trace_entity import ( GenerationUsage, LangfuseGeneration, LangfuseSpan, @@ -36,8 +38,6 @@ from core.ops.langfuse_trace.entities.langfuse_trace_entity import ( LevelEnum, UnitEnum, ) -from core.ops.utils import filter_none_values -from core.repositories import DifyCoreRepositoryFactory from extensions.ext_database import db from graphon.enums import BuiltinNodeTypes from models import EndUser, WorkflowNodeExecutionTriggeredFrom diff --git a/api/core/ops/mlflow_trace/__init__.py b/api/providers/trace/trace-langfuse/src/dify_trace_langfuse/py.typed similarity index 100% rename from api/core/ops/mlflow_trace/__init__.py rename to api/providers/trace/trace-langfuse/src/dify_trace_langfuse/py.typed diff --git a/api/tests/unit_tests/core/ops/langfuse_trace/test_langfuse_trace.py b/api/providers/trace/trace-langfuse/tests/unit_tests/langfuse_trace/test_langfuse_trace.py similarity index 93% rename from api/tests/unit_tests/core/ops/langfuse_trace/test_langfuse_trace.py rename to api/providers/trace/trace-langfuse/tests/unit_tests/langfuse_trace/test_langfuse_trace.py index a0bcc92795..952f10c34f 100644 --- a/api/tests/unit_tests/core/ops/langfuse_trace/test_langfuse_trace.py +++ b/api/providers/trace/trace-langfuse/tests/unit_tests/langfuse_trace/test_langfuse_trace.py @@ -5,8 +5,16 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest +from dify_trace_langfuse.config import LangfuseConfig +from dify_trace_langfuse.entities.langfuse_trace_entity import ( + LangfuseGeneration, + LangfuseSpan, + LangfuseTrace, + LevelEnum, + UnitEnum, +) +from dify_trace_langfuse.langfuse_trace import LangFuseDataTrace -from core.ops.entities.config_entity import LangfuseConfig from core.ops.entities.trace_entity import ( DatasetRetrievalTraceInfo, GenerateNameTraceInfo, @@ -17,14 +25,6 @@ from core.ops.entities.trace_entity import ( TraceTaskName, WorkflowTraceInfo, ) -from core.ops.langfuse_trace.entities.langfuse_trace_entity import ( - LangfuseGeneration, - LangfuseSpan, - LangfuseTrace, - LevelEnum, - UnitEnum, -) -from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace from graphon.enums import BuiltinNodeTypes from models import EndUser from models.enums import MessageStatus @@ -43,7 +43,7 @@ def langfuse_config(): def trace_instance(langfuse_config, monkeypatch): # Mock Langfuse client to avoid network calls mock_client = MagicMock() - monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.Langfuse", lambda **kwargs: mock_client) + monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.Langfuse", lambda **kwargs: mock_client) instance = LangFuseDataTrace(langfuse_config) return instance @@ -51,7 +51,7 @@ def trace_instance(langfuse_config, monkeypatch): def test_init(langfuse_config, monkeypatch): mock_langfuse = MagicMock() - monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.Langfuse", mock_langfuse) + monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.Langfuse", mock_langfuse) monkeypatch.setenv("FILES_URL", "http://test.url") instance = LangFuseDataTrace(langfuse_config) @@ -140,8 +140,8 @@ def test_workflow_trace_with_message_id(trace_instance, monkeypatch): # Mock DB and Repositories mock_session = MagicMock() - monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.sessionmaker", lambda bind: lambda: mock_session) - monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db", MagicMock(engine="engine")) + monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.sessionmaker", lambda bind: lambda: mock_session) + monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.db", MagicMock(engine="engine")) # Mock node executions node_llm = MagicMock() @@ -178,7 +178,7 @@ def test_workflow_trace_with_message_id(trace_instance, monkeypatch): mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo - monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.DifyCoreRepositoryFactory", mock_factory) + monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.DifyCoreRepositoryFactory", mock_factory) monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) @@ -241,13 +241,13 @@ def test_workflow_trace_no_message_id(trace_instance, monkeypatch): error="", ) - monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.sessionmaker", lambda bind: lambda: MagicMock()) - monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db", MagicMock(engine="engine")) + monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.sessionmaker", lambda bind: lambda: MagicMock()) + monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.db", MagicMock(engine="engine")) repo = MagicMock() repo.get_by_workflow_execution.return_value = [] mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo - monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.DifyCoreRepositoryFactory", mock_factory) + monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.DifyCoreRepositoryFactory", mock_factory) monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) trace_instance.add_trace = MagicMock() @@ -280,8 +280,8 @@ def test_workflow_trace_missing_app_id(trace_instance, monkeypatch): workflow_app_log_id="log-1", error="", ) - monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.sessionmaker", lambda bind: lambda: MagicMock()) - monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db", MagicMock(engine="engine")) + monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.sessionmaker", lambda bind: lambda: MagicMock()) + monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.db", MagicMock(engine="engine")) with pytest.raises(ValueError, match="No app_id found in trace_info metadata"): trace_instance.workflow_trace(trace_info) @@ -365,7 +365,7 @@ def test_message_trace_with_end_user(trace_instance, monkeypatch): mock_end_user = MagicMock(spec=EndUser) mock_end_user.session_id = "session-id-123" - monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db.session.get", lambda model, pk: mock_end_user) + monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.db.session.get", lambda model, pk: mock_end_user) trace_instance.add_trace = MagicMock() trace_instance.add_generation = MagicMock() @@ -681,9 +681,9 @@ def test_workflow_trace_handles_usage_extraction_error(trace_instance, monkeypat repo.get_by_workflow_execution.return_value = [node] mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo - monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.DifyCoreRepositoryFactory", mock_factory) - monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.sessionmaker", lambda bind: lambda: MagicMock()) - monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db", MagicMock(engine="engine")) + monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.DifyCoreRepositoryFactory", mock_factory) + monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.sessionmaker", lambda bind: lambda: MagicMock()) + monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.db", MagicMock(engine="engine")) monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) trace_instance.add_trace = MagicMock() diff --git a/api/providers/trace/trace-langfuse/tests/unit_tests/test_config_entity.py b/api/providers/trace/trace-langfuse/tests/unit_tests/test_config_entity.py new file mode 100644 index 0000000000..103d888eef --- /dev/null +++ b/api/providers/trace/trace-langfuse/tests/unit_tests/test_config_entity.py @@ -0,0 +1,42 @@ +import pytest +from dify_trace_langfuse.config import LangfuseConfig +from pydantic import ValidationError + + +class TestLangfuseConfig: + """Test cases for LangfuseConfig""" + + def test_valid_config(self): + """Test valid Langfuse configuration""" + config = LangfuseConfig(public_key="public_key", secret_key="secret_key", host="https://custom.langfuse.com") + assert config.public_key == "public_key" + assert config.secret_key == "secret_key" + assert config.host == "https://custom.langfuse.com" + + def test_valid_config_with_path(self): + host = "https://custom.langfuse.com/api/v1" + config = LangfuseConfig(public_key="public_key", secret_key="secret_key", host=host) + assert config.public_key == "public_key" + assert config.secret_key == "secret_key" + assert config.host == host + + def test_default_values(self): + """Test default values are set correctly""" + config = LangfuseConfig(public_key="public", secret_key="secret") + assert config.host == "https://api.langfuse.com" + + def test_missing_required_fields(self): + """Test that required fields are enforced""" + with pytest.raises(ValidationError): + LangfuseConfig() + + with pytest.raises(ValidationError): + LangfuseConfig(public_key="public") + + with pytest.raises(ValidationError): + LangfuseConfig(secret_key="secret") + + def test_host_validation_empty(self): + """Test host validation with empty value""" + config = LangfuseConfig(public_key="public", secret_key="secret", host="") + assert config.host == "https://api.langfuse.com" diff --git a/api/tests/unit_tests/core/ops/test_langfuse_trace.py b/api/providers/trace/trace-langfuse/tests/unit_tests/test_langfuse_trace.py similarity index 92% rename from api/tests/unit_tests/core/ops/test_langfuse_trace.py rename to api/providers/trace/trace-langfuse/tests/unit_tests/test_langfuse_trace.py index 017ac8c891..0340ffb669 100644 --- a/api/tests/unit_tests/core/ops/test_langfuse_trace.py +++ b/api/providers/trace/trace-langfuse/tests/unit_tests/test_langfuse_trace.py @@ -4,14 +4,15 @@ from datetime import datetime, timedelta from types import SimpleNamespace from unittest.mock import MagicMock, patch -from core.ops.entities.config_entity import LangfuseConfig +from dify_trace_langfuse.config import LangfuseConfig +from dify_trace_langfuse.langfuse_trace import LangFuseDataTrace + from core.ops.entities.trace_entity import MessageTraceInfo, WorkflowTraceInfo -from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace from graphon.enums import BuiltinNodeTypes def _create_trace_instance() -> LangFuseDataTrace: - with patch("core.ops.langfuse_trace.langfuse_trace.Langfuse", autospec=True): + with patch("dify_trace_langfuse.langfuse_trace.Langfuse", autospec=True): return LangFuseDataTrace( LangfuseConfig( public_key="public-key", @@ -116,9 +117,9 @@ class TestLangFuseDataTraceCompletionStartTime: patch.object(trace, "add_span"), patch.object(trace, "add_generation") as add_generation, patch.object(trace, "get_service_account_with_tenant", return_value=MagicMock()), - patch("core.ops.langfuse_trace.langfuse_trace.db", MagicMock()), + patch("dify_trace_langfuse.langfuse_trace.db", MagicMock()), patch( - "core.ops.langfuse_trace.langfuse_trace.DifyCoreRepositoryFactory.create_workflow_node_execution_repository", + "dify_trace_langfuse.langfuse_trace.DifyCoreRepositoryFactory.create_workflow_node_execution_repository", return_value=repository, ), ): diff --git a/api/providers/trace/trace-langsmith/pyproject.toml b/api/providers/trace/trace-langsmith/pyproject.toml new file mode 100644 index 0000000000..8131952b28 --- /dev/null +++ b/api/providers/trace/trace-langsmith/pyproject.toml @@ -0,0 +1,10 @@ +[project] +name = "dify-trace-langsmith" +version = "0.0.1" +dependencies = [ + "langsmith~=0.7.30", +] +description = "Dify ops tracing provider (LangSmith)." + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/core/ops/opik_trace/__init__.py b/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/__init__.py similarity index 100% rename from api/core/ops/opik_trace/__init__.py rename to api/providers/trace/trace-langsmith/src/dify_trace_langsmith/__init__.py diff --git a/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/config.py b/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/config.py new file mode 100644 index 0000000000..498b8c5e7e --- /dev/null +++ b/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/config.py @@ -0,0 +1,20 @@ +from pydantic import ValidationInfo, field_validator + +from core.ops.entities.config_entity import BaseTracingConfig +from core.ops.utils import validate_url + + +class LangSmithConfig(BaseTracingConfig): + """ + Model class for Langsmith tracing config. + """ + + api_key: str + project: str + endpoint: str = "https://api.smith.langchain.com" + + @field_validator("endpoint") + @classmethod + def endpoint_validator(cls, v, info: ValidationInfo): + # LangSmith only allows HTTPS + return validate_url(v, "https://api.smith.langchain.com", allowed_schemes=("https",)) diff --git a/api/core/ops/tencent_trace/__init__.py b/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/entities/__init__.py similarity index 100% rename from api/core/ops/tencent_trace/__init__.py rename to api/providers/trace/trace-langsmith/src/dify_trace_langsmith/entities/__init__.py diff --git a/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py b/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/entities/langsmith_trace_entity.py similarity index 100% rename from api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py rename to api/providers/trace/trace-langsmith/src/dify_trace_langsmith/entities/langsmith_trace_entity.py diff --git a/api/core/ops/langsmith_trace/langsmith_trace.py b/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/langsmith_trace.py similarity index 99% rename from api/core/ops/langsmith_trace/langsmith_trace.py rename to api/providers/trace/trace-langsmith/src/dify_trace_langsmith/langsmith_trace.py index d960038f15..145bd70dbc 100644 --- a/api/core/ops/langsmith_trace/langsmith_trace.py +++ b/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/langsmith_trace.py @@ -9,7 +9,6 @@ from langsmith.schemas import RunBase from sqlalchemy.orm import sessionmaker from core.ops.base_trace_instance import BaseTraceInstance -from core.ops.entities.config_entity import LangSmithConfig from core.ops.entities.trace_entity import ( BaseTraceInfo, DatasetRetrievalTraceInfo, @@ -21,13 +20,14 @@ from core.ops.entities.trace_entity import ( TraceTaskName, WorkflowTraceInfo, ) -from core.ops.langsmith_trace.entities.langsmith_trace_entity import ( +from core.ops.utils import filter_none_values, generate_dotted_order +from core.repositories import DifyCoreRepositoryFactory +from dify_trace_langsmith.config import LangSmithConfig +from dify_trace_langsmith.entities.langsmith_trace_entity import ( LangSmithRunModel, LangSmithRunType, LangSmithRunUpdateModel, ) -from core.ops.utils import filter_none_values, generate_dotted_order -from core.repositories import DifyCoreRepositoryFactory from extensions.ext_database import db from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom diff --git a/api/core/ops/weave_trace/__init__.py b/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/py.typed similarity index 100% rename from api/core/ops/weave_trace/__init__.py rename to api/providers/trace/trace-langsmith/src/dify_trace_langsmith/py.typed diff --git a/api/tests/unit_tests/core/ops/langsmith_trace/test_langsmith_trace.py b/api/providers/trace/trace-langsmith/tests/unit_tests/langsmith_trace/test_langsmith_trace.py similarity index 91% rename from api/tests/unit_tests/core/ops/langsmith_trace/test_langsmith_trace.py rename to api/providers/trace/trace-langsmith/tests/unit_tests/langsmith_trace/test_langsmith_trace.py index 34c64c54a1..45e5894e4a 100644 --- a/api/tests/unit_tests/core/ops/langsmith_trace/test_langsmith_trace.py +++ b/api/providers/trace/trace-langsmith/tests/unit_tests/langsmith_trace/test_langsmith_trace.py @@ -3,8 +3,14 @@ from datetime import datetime, timedelta from unittest.mock import MagicMock import pytest +from dify_trace_langsmith.config import LangSmithConfig +from dify_trace_langsmith.entities.langsmith_trace_entity import ( + LangSmithRunModel, + LangSmithRunType, + LangSmithRunUpdateModel, +) +from dify_trace_langsmith.langsmith_trace import LangSmithDataTrace -from core.ops.entities.config_entity import LangSmithConfig from core.ops.entities.trace_entity import ( DatasetRetrievalTraceInfo, GenerateNameTraceInfo, @@ -15,12 +21,6 @@ from core.ops.entities.trace_entity import ( TraceTaskName, WorkflowTraceInfo, ) -from core.ops.langsmith_trace.entities.langsmith_trace_entity import ( - LangSmithRunModel, - LangSmithRunType, - LangSmithRunUpdateModel, -) -from core.ops.langsmith_trace.langsmith_trace import LangSmithDataTrace from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from models import EndUser @@ -38,7 +38,7 @@ def langsmith_config(): def trace_instance(langsmith_config, monkeypatch): # Mock LangSmith client mock_client = MagicMock() - monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.Client", lambda **kwargs: mock_client) + monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.Client", lambda **kwargs: mock_client) instance = LangSmithDataTrace(langsmith_config) return instance @@ -46,7 +46,7 @@ def trace_instance(langsmith_config, monkeypatch): def test_init(langsmith_config, monkeypatch): mock_client_class = MagicMock() - monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.Client", mock_client_class) + monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.Client", mock_client_class) monkeypatch.setenv("FILES_URL", "http://test.url") instance = LangSmithDataTrace(langsmith_config) @@ -138,8 +138,8 @@ def test_workflow_trace(trace_instance, monkeypatch): # Mock dependencies mock_session = MagicMock() - monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.sessionmaker", lambda bind: lambda: mock_session) - monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db", MagicMock(engine="engine")) + monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.sessionmaker", lambda bind: lambda: mock_session) + monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.db", MagicMock(engine="engine")) # Mock node executions node_llm = MagicMock() @@ -188,7 +188,7 @@ def test_workflow_trace(trace_instance, monkeypatch): mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo - monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.DifyCoreRepositoryFactory", mock_factory) + monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.DifyCoreRepositoryFactory", mock_factory) monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) @@ -252,13 +252,13 @@ def test_workflow_trace_no_start_time(trace_instance, monkeypatch): ) mock_session = MagicMock() - monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.sessionmaker", lambda bind: lambda: mock_session) - monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db", MagicMock(engine="engine")) + monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.sessionmaker", lambda bind: lambda: mock_session) + monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.db", MagicMock(engine="engine")) repo = MagicMock() repo.get_by_workflow_execution.return_value = [] mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo - monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.DifyCoreRepositoryFactory", mock_factory) + monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.DifyCoreRepositoryFactory", mock_factory) monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) trace_instance.add_run = MagicMock() @@ -283,8 +283,8 @@ def test_workflow_trace_missing_app_id(trace_instance, monkeypatch): trace_info.error = "" mock_session = MagicMock() - monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.sessionmaker", lambda bind: lambda: mock_session) - monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db", MagicMock(engine="engine")) + monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.sessionmaker", lambda bind: lambda: mock_session) + monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.db", MagicMock(engine="engine")) with pytest.raises(ValueError, match="No app_id found in trace_info metadata"): trace_instance.workflow_trace(trace_info) @@ -319,7 +319,7 @@ def test_message_trace(trace_instance, monkeypatch): # Mock EndUser lookup mock_end_user = MagicMock(spec=EndUser) mock_end_user.session_id = "session-id-123" - monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db.session.get", lambda model, pk: mock_end_user) + monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.db.session.get", lambda model, pk: mock_end_user) trace_instance.add_run = MagicMock() @@ -567,9 +567,9 @@ def test_workflow_trace_usage_extraction_error(trace_instance, monkeypatch, capl mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo - monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.DifyCoreRepositoryFactory", mock_factory) - monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.sessionmaker", lambda bind: lambda: MagicMock()) - monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db", MagicMock(engine="engine")) + monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.DifyCoreRepositoryFactory", mock_factory) + monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.sessionmaker", lambda bind: lambda: MagicMock()) + monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.db", MagicMock(engine="engine")) monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) trace_instance.add_run = MagicMock() diff --git a/api/providers/trace/trace-langsmith/tests/unit_tests/test_config_entity.py b/api/providers/trace/trace-langsmith/tests/unit_tests/test_config_entity.py new file mode 100644 index 0000000000..37efaf69cf --- /dev/null +++ b/api/providers/trace/trace-langsmith/tests/unit_tests/test_config_entity.py @@ -0,0 +1,35 @@ +import pytest +from dify_trace_langsmith.config import LangSmithConfig +from pydantic import ValidationError + + +class TestLangSmithConfig: + """Test cases for LangSmithConfig""" + + def test_valid_config(self): + """Test valid LangSmith configuration""" + config = LangSmithConfig(api_key="test_key", project="test_project", endpoint="https://custom.smith.com") + assert config.api_key == "test_key" + assert config.project == "test_project" + assert config.endpoint == "https://custom.smith.com" + + def test_default_values(self): + """Test default values are set correctly""" + config = LangSmithConfig(api_key="key", project="project") + assert config.endpoint == "https://api.smith.langchain.com" + + def test_missing_required_fields(self): + """Test that required fields are enforced""" + with pytest.raises(ValidationError): + LangSmithConfig() + + with pytest.raises(ValidationError): + LangSmithConfig(api_key="key") + + with pytest.raises(ValidationError): + LangSmithConfig(project="project") + + def test_endpoint_validation_https_only(self): + """Test endpoint validation only allows HTTPS""" + with pytest.raises(ValidationError, match="URL scheme must be one of"): + LangSmithConfig(api_key="key", project="project", endpoint="http://insecure.com") diff --git a/api/providers/trace/trace-mlflow/pyproject.toml b/api/providers/trace/trace-mlflow/pyproject.toml new file mode 100644 index 0000000000..fad6002944 --- /dev/null +++ b/api/providers/trace/trace-mlflow/pyproject.toml @@ -0,0 +1,10 @@ +[project] +name = "dify-trace-mlflow" +version = "0.0.1" +dependencies = [ + "mlflow-skinny>=3.11.1", +] +description = "Dify ops tracing provider (MLflow / Databricks)." + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/core/ops/weave_trace/entities/__init__.py b/api/providers/trace/trace-mlflow/src/dify_trace_mlflow/__init__.py similarity index 100% rename from api/core/ops/weave_trace/entities/__init__.py rename to api/providers/trace/trace-mlflow/src/dify_trace_mlflow/__init__.py diff --git a/api/providers/trace/trace-mlflow/src/dify_trace_mlflow/config.py b/api/providers/trace/trace-mlflow/src/dify_trace_mlflow/config.py new file mode 100644 index 0000000000..84914165e3 --- /dev/null +++ b/api/providers/trace/trace-mlflow/src/dify_trace_mlflow/config.py @@ -0,0 +1,46 @@ +from pydantic import ValidationInfo, field_validator + +from core.ops.entities.config_entity import BaseTracingConfig +from core.ops.utils import validate_integer_id, validate_url_with_path + + +class MLflowConfig(BaseTracingConfig): + """ + Model class for MLflow tracing config. + """ + + tracking_uri: str = "http://localhost:5000" + experiment_id: str = "0" # Default experiment id in MLflow is 0 + username: str | None = None + password: str | None = None + + @field_validator("tracking_uri") + @classmethod + def tracking_uri_validator(cls, v, info: ValidationInfo): + if isinstance(v, str) and v.startswith("databricks"): + raise ValueError( + "Please use Databricks tracing config below to record traces to Databricks-managed MLflow instances." + ) + return validate_url_with_path(v, "http://localhost:5000") + + @field_validator("experiment_id") + @classmethod + def experiment_id_validator(cls, v, info: ValidationInfo): + return validate_integer_id(v) + + +class DatabricksConfig(BaseTracingConfig): + """ + Model class for Databricks (Databricks-managed MLflow) tracing config. + """ + + experiment_id: str + host: str + client_id: str | None = None + client_secret: str | None = None + personal_access_token: str | None = None + + @field_validator("experiment_id") + @classmethod + def experiment_id_validator(cls, v, info: ValidationInfo): + return validate_integer_id(v) diff --git a/api/core/ops/mlflow_trace/mlflow_trace.py b/api/providers/trace/trace-mlflow/src/dify_trace_mlflow/mlflow_trace.py similarity index 99% rename from api/core/ops/mlflow_trace/mlflow_trace.py rename to api/providers/trace/trace-mlflow/src/dify_trace_mlflow/mlflow_trace.py index 87fcaeabcc..4e4c45a532 100644 --- a/api/core/ops/mlflow_trace/mlflow_trace.py +++ b/api/providers/trace/trace-mlflow/src/dify_trace_mlflow/mlflow_trace.py @@ -11,7 +11,6 @@ from mlflow.tracing.provider import detach_span_from_context, set_span_in_contex from sqlalchemy import select from core.ops.base_trace_instance import BaseTraceInstance -from core.ops.entities.config_entity import DatabricksConfig, MLflowConfig from core.ops.entities.trace_entity import ( BaseTraceInfo, DatasetRetrievalTraceInfo, @@ -24,6 +23,7 @@ from core.ops.entities.trace_entity import ( WorkflowTraceInfo, ) from core.ops.utils import JSON_DICT_ADAPTER +from dify_trace_mlflow.config import DatabricksConfig, MLflowConfig from extensions.ext_database import db from graphon.enums import BuiltinNodeTypes from models import EndUser diff --git a/api/providers/trace/trace-mlflow/src/dify_trace_mlflow/py.typed b/api/providers/trace/trace-mlflow/src/dify_trace_mlflow/py.typed new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/ops/mlflow_trace/test_mlflow_trace.py b/api/providers/trace/trace-mlflow/tests/unit_tests/mlflow_trace/test_mlflow_trace.py similarity index 98% rename from api/tests/unit_tests/core/ops/mlflow_trace/test_mlflow_trace.py rename to api/providers/trace/trace-mlflow/tests/unit_tests/mlflow_trace/test_mlflow_trace.py index afc5726ede..20211456e3 100644 --- a/api/tests/unit_tests/core/ops/mlflow_trace/test_mlflow_trace.py +++ b/api/providers/trace/trace-mlflow/tests/unit_tests/mlflow_trace/test_mlflow_trace.py @@ -1,4 +1,4 @@ -"""Comprehensive tests for core.ops.mlflow_trace.mlflow_trace module.""" +"""Comprehensive tests for dify_trace_mlflow.mlflow_trace module.""" from __future__ import annotations @@ -9,8 +9,9 @@ from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest +from dify_trace_mlflow.config import DatabricksConfig, MLflowConfig +from dify_trace_mlflow.mlflow_trace import MLflowDataTrace, datetime_to_nanoseconds -from core.ops.entities.config_entity import DatabricksConfig, MLflowConfig from core.ops.entities.trace_entity import ( DatasetRetrievalTraceInfo, GenerateNameTraceInfo, @@ -20,7 +21,6 @@ from core.ops.entities.trace_entity import ( ToolTraceInfo, WorkflowTraceInfo, ) -from core.ops.mlflow_trace.mlflow_trace import MLflowDataTrace, datetime_to_nanoseconds from graphon.enums import BuiltinNodeTypes # ── Helpers ────────────────────────────────────────────────────────────────── @@ -179,7 +179,7 @@ def _make_node(**overrides): @pytest.fixture def mock_mlflow(): - with patch("core.ops.mlflow_trace.mlflow_trace.mlflow") as mock: + with patch("dify_trace_mlflow.mlflow_trace.mlflow") as mock: yield mock @@ -187,10 +187,10 @@ def mock_mlflow(): def mock_tracing(): """Patch all MLflow tracing functions used by the module.""" with ( - patch("core.ops.mlflow_trace.mlflow_trace.start_span_no_context") as mock_start, - patch("core.ops.mlflow_trace.mlflow_trace.update_current_trace") as mock_update, - patch("core.ops.mlflow_trace.mlflow_trace.set_span_in_context") as mock_set, - patch("core.ops.mlflow_trace.mlflow_trace.detach_span_from_context") as mock_detach, + patch("dify_trace_mlflow.mlflow_trace.start_span_no_context") as mock_start, + patch("dify_trace_mlflow.mlflow_trace.update_current_trace") as mock_update, + patch("dify_trace_mlflow.mlflow_trace.set_span_in_context") as mock_set, + patch("dify_trace_mlflow.mlflow_trace.detach_span_from_context") as mock_detach, ): yield { "start": mock_start, @@ -202,7 +202,7 @@ def mock_tracing(): @pytest.fixture def mock_db(): - with patch("core.ops.mlflow_trace.mlflow_trace.db") as mock: + with patch("dify_trace_mlflow.mlflow_trace.db") as mock: yield mock diff --git a/api/providers/trace/trace-opik/pyproject.toml b/api/providers/trace/trace-opik/pyproject.toml new file mode 100644 index 0000000000..874997168e --- /dev/null +++ b/api/providers/trace/trace-opik/pyproject.toml @@ -0,0 +1,10 @@ +[project] +name = "dify-trace-opik" +version = "0.0.1" +dependencies = [ + "opik~=1.11.2", +] +description = "Dify ops tracing provider (Opik)." + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/providers/trace/trace-opik/src/dify_trace_opik/__init__.py b/api/providers/trace/trace-opik/src/dify_trace_opik/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/providers/trace/trace-opik/src/dify_trace_opik/config.py b/api/providers/trace/trace-opik/src/dify_trace_opik/config.py new file mode 100644 index 0000000000..c16ff1d903 --- /dev/null +++ b/api/providers/trace/trace-opik/src/dify_trace_opik/config.py @@ -0,0 +1,25 @@ +from pydantic import ValidationInfo, field_validator + +from core.ops.entities.config_entity import BaseTracingConfig +from core.ops.utils import validate_url_with_path + + +class OpikConfig(BaseTracingConfig): + """ + Model class for Opik tracing config. + """ + + api_key: str | None = None + project: str | None = None + workspace: str | None = None + url: str = "https://www.comet.com/opik/api/" + + @field_validator("project") + @classmethod + def project_validator(cls, v, info: ValidationInfo): + return cls.validate_project_field(v, "Default Project") + + @field_validator("url") + @classmethod + def url_validator(cls, v, info: ValidationInfo): + return validate_url_with_path(v, "https://www.comet.com/opik/api/", required_suffix="/api/") diff --git a/api/core/ops/opik_trace/opik_trace.py b/api/providers/trace/trace-opik/src/dify_trace_opik/opik_trace.py similarity index 99% rename from api/core/ops/opik_trace/opik_trace.py rename to api/providers/trace/trace-opik/src/dify_trace_opik/opik_trace.py index 672efe45bd..2d124ac989 100644 --- a/api/core/ops/opik_trace/opik_trace.py +++ b/api/providers/trace/trace-opik/src/dify_trace_opik/opik_trace.py @@ -10,7 +10,6 @@ from opik.id_helpers import uuid4_to_uuid7 from sqlalchemy.orm import sessionmaker from core.ops.base_trace_instance import BaseTraceInstance -from core.ops.entities.config_entity import OpikConfig from core.ops.entities.trace_entity import ( BaseTraceInfo, DatasetRetrievalTraceInfo, @@ -23,6 +22,7 @@ from core.ops.entities.trace_entity import ( WorkflowTraceInfo, ) from core.repositories import DifyCoreRepositoryFactory +from dify_trace_opik.config import OpikConfig from extensions.ext_database import db from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom diff --git a/api/providers/trace/trace-opik/src/dify_trace_opik/py.typed b/api/providers/trace/trace-opik/src/dify_trace_opik/py.typed new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/ops/opik_trace/test_opik_trace.py b/api/providers/trace/trace-opik/tests/unit_tests/opik_trace/test_opik_trace.py similarity index 93% rename from api/tests/unit_tests/core/ops/opik_trace/test_opik_trace.py rename to api/providers/trace/trace-opik/tests/unit_tests/opik_trace/test_opik_trace.py index c02ac413f2..eefed3c78c 100644 --- a/api/tests/unit_tests/core/ops/opik_trace/test_opik_trace.py +++ b/api/providers/trace/trace-opik/tests/unit_tests/opik_trace/test_opik_trace.py @@ -5,8 +5,9 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest +from dify_trace_opik.config import OpikConfig +from dify_trace_opik.opik_trace import OpikDataTrace, prepare_opik_uuid, wrap_dict, wrap_metadata -from core.ops.entities.config_entity import OpikConfig from core.ops.entities.trace_entity import ( DatasetRetrievalTraceInfo, GenerateNameTraceInfo, @@ -17,7 +18,6 @@ from core.ops.entities.trace_entity import ( TraceTaskName, WorkflowTraceInfo, ) -from core.ops.opik_trace.opik_trace import OpikDataTrace, prepare_opik_uuid, wrap_dict, wrap_metadata from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from models import EndUser from models.enums import MessageStatus @@ -37,7 +37,7 @@ def opik_config(): @pytest.fixture def trace_instance(opik_config, monkeypatch): mock_client = MagicMock() - monkeypatch.setattr("core.ops.opik_trace.opik_trace.Opik", lambda **kwargs: mock_client) + monkeypatch.setattr("dify_trace_opik.opik_trace.Opik", lambda **kwargs: mock_client) instance = OpikDataTrace(opik_config) return instance @@ -67,7 +67,7 @@ def test_prepare_opik_uuid(): def test_init(opik_config, monkeypatch): mock_opik = MagicMock() - monkeypatch.setattr("core.ops.opik_trace.opik_trace.Opik", mock_opik) + monkeypatch.setattr("dify_trace_opik.opik_trace.Opik", mock_opik) monkeypatch.setenv("FILES_URL", "http://test.url") instance = OpikDataTrace(opik_config) @@ -166,8 +166,8 @@ def test_workflow_trace_with_message_id(trace_instance, monkeypatch): ) mock_session = MagicMock() - monkeypatch.setattr("core.ops.opik_trace.opik_trace.sessionmaker", lambda bind: lambda: mock_session) - monkeypatch.setattr("core.ops.opik_trace.opik_trace.db", MagicMock(engine="engine")) + monkeypatch.setattr("dify_trace_opik.opik_trace.sessionmaker", lambda bind: lambda: mock_session) + monkeypatch.setattr("dify_trace_opik.opik_trace.db", MagicMock(engine="engine")) node_llm = MagicMock() node_llm.id = LLM_NODE_ID @@ -203,7 +203,7 @@ def test_workflow_trace_with_message_id(trace_instance, monkeypatch): mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo - monkeypatch.setattr("core.ops.opik_trace.opik_trace.DifyCoreRepositoryFactory", mock_factory) + monkeypatch.setattr("dify_trace_opik.opik_trace.DifyCoreRepositoryFactory", mock_factory) monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) @@ -250,13 +250,13 @@ def test_workflow_trace_no_message_id(trace_instance, monkeypatch): error="", ) - monkeypatch.setattr("core.ops.opik_trace.opik_trace.sessionmaker", lambda bind: lambda: MagicMock()) - monkeypatch.setattr("core.ops.opik_trace.opik_trace.db", MagicMock(engine="engine")) + monkeypatch.setattr("dify_trace_opik.opik_trace.sessionmaker", lambda bind: lambda: MagicMock()) + monkeypatch.setattr("dify_trace_opik.opik_trace.db", MagicMock(engine="engine")) repo = MagicMock() repo.get_by_workflow_execution.return_value = [] mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo - monkeypatch.setattr("core.ops.opik_trace.opik_trace.DifyCoreRepositoryFactory", mock_factory) + monkeypatch.setattr("dify_trace_opik.opik_trace.DifyCoreRepositoryFactory", mock_factory) monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) trace_instance.add_trace = MagicMock() @@ -286,8 +286,8 @@ def test_workflow_trace_missing_app_id(trace_instance, monkeypatch): workflow_app_log_id="339760b2-4b94-4532-8c81-133a97e4680e", error="", ) - monkeypatch.setattr("core.ops.opik_trace.opik_trace.sessionmaker", lambda bind: lambda: MagicMock()) - monkeypatch.setattr("core.ops.opik_trace.opik_trace.db", MagicMock(engine="engine")) + monkeypatch.setattr("dify_trace_opik.opik_trace.sessionmaker", lambda bind: lambda: MagicMock()) + monkeypatch.setattr("dify_trace_opik.opik_trace.db", MagicMock(engine="engine")) with pytest.raises(ValueError, match="No app_id found in trace_info metadata"): trace_instance.workflow_trace(trace_info) @@ -373,7 +373,7 @@ def test_message_trace_with_end_user(trace_instance, monkeypatch): mock_end_user = MagicMock(spec=EndUser) mock_end_user.session_id = "session-id-123" - monkeypatch.setattr("core.ops.opik_trace.opik_trace.db.session.get", lambda model, pk: mock_end_user) + monkeypatch.setattr("dify_trace_opik.opik_trace.db.session.get", lambda model, pk: mock_end_user) trace_instance.add_trace = MagicMock(return_value=MagicMock(id="trace_id_2")) trace_instance.add_span = MagicMock() @@ -658,9 +658,9 @@ def test_workflow_trace_usage_extraction_error_fixed(trace_instance, monkeypatch repo.get_by_workflow_execution.return_value = [node] mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo - monkeypatch.setattr("core.ops.opik_trace.opik_trace.DifyCoreRepositoryFactory", mock_factory) - monkeypatch.setattr("core.ops.opik_trace.opik_trace.sessionmaker", lambda bind: lambda: MagicMock()) - monkeypatch.setattr("core.ops.opik_trace.opik_trace.db", MagicMock(engine="engine")) + monkeypatch.setattr("dify_trace_opik.opik_trace.DifyCoreRepositoryFactory", mock_factory) + monkeypatch.setattr("dify_trace_opik.opik_trace.sessionmaker", lambda bind: lambda: MagicMock()) + monkeypatch.setattr("dify_trace_opik.opik_trace.db", MagicMock(engine="engine")) monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) trace_instance.add_trace = MagicMock() diff --git a/api/providers/trace/trace-opik/tests/unit_tests/test_config_entity.py b/api/providers/trace/trace-opik/tests/unit_tests/test_config_entity.py new file mode 100644 index 0000000000..5a54b70bba --- /dev/null +++ b/api/providers/trace/trace-opik/tests/unit_tests/test_config_entity.py @@ -0,0 +1,48 @@ +import pytest +from dify_trace_opik.config import OpikConfig +from pydantic import ValidationError + + +class TestOpikConfig: + """Test cases for OpikConfig""" + + def test_valid_config(self): + """Test valid Opik configuration""" + config = OpikConfig( + api_key="test_key", + project="test_project", + workspace="test_workspace", + url="https://custom.comet.com/opik/api/", + ) + assert config.api_key == "test_key" + assert config.project == "test_project" + assert config.workspace == "test_workspace" + assert config.url == "https://custom.comet.com/opik/api/" + + def test_default_values(self): + """Test default values are set correctly""" + config = OpikConfig() + assert config.api_key is None + assert config.project is None + assert config.workspace is None + assert config.url == "https://www.comet.com/opik/api/" + + def test_project_validation_empty(self): + """Test project validation with empty value""" + config = OpikConfig(project="") + assert config.project == "Default Project" + + def test_url_validation_empty(self): + """Test URL validation with empty value""" + config = OpikConfig(url="") + assert config.url == "https://www.comet.com/opik/api/" + + def test_url_validation_missing_suffix(self): + """Test URL validation requires /api/ suffix""" + with pytest.raises(ValidationError, match="URL should end with /api/"): + OpikConfig(url="https://custom.comet.com/opik/") + + def test_url_validation_invalid_scheme(self): + """Test URL validation rejects invalid schemes""" + with pytest.raises(ValidationError, match="URL must start with https:// or http://"): + OpikConfig(url="ftp://custom.comet.com/opik/api/") diff --git a/api/tests/unit_tests/core/ops/test_opik_trace.py b/api/providers/trace/trace-opik/tests/unit_tests/test_opik_trace.py similarity index 94% rename from api/tests/unit_tests/core/ops/test_opik_trace.py rename to api/providers/trace/trace-opik/tests/unit_tests/test_opik_trace.py index ad9d0846be..fba290f5b8 100644 --- a/api/tests/unit_tests/core/ops/test_opik_trace.py +++ b/api/providers/trace/trace-opik/tests/unit_tests/test_opik_trace.py @@ -14,8 +14,9 @@ import uuid from datetime import datetime from unittest.mock import MagicMock, patch +from dify_trace_opik.opik_trace import OpikDataTrace, _seed_to_uuid4, prepare_opik_uuid + from core.ops.entities.trace_entity import TraceTaskName, WorkflowTraceInfo -from core.ops.opik_trace.opik_trace import OpikDataTrace, _seed_to_uuid4, prepare_opik_uuid # A stable UUID4 used as the workflow_run_id throughout all tests. _WORKFLOW_RUN_ID = "a3f1b2c4-d5e6-4f78-9a0b-c1d2e3f4a5b6" @@ -56,8 +57,8 @@ def _make_workflow_trace_info( def _make_opik_trace_instance() -> OpikDataTrace: """Construct an OpikDataTrace with the Opik SDK client mocked out.""" - with patch("core.ops.opik_trace.opik_trace.Opik"): - from core.ops.entities.config_entity import OpikConfig + with patch("dify_trace_opik.opik_trace.Opik"): + from dify_trace_opik.config import OpikConfig config = OpikConfig(api_key="key", project="test-project", url="https://www.comet.com/opik/api/") instance = OpikDataTrace(config) @@ -133,10 +134,10 @@ class TestWorkflowTraceWithoutMessageId: fake_repo.get_by_workflow_execution.return_value = node_executions or [] with ( - patch("core.ops.opik_trace.opik_trace.db") as mock_db, - patch("core.ops.opik_trace.opik_trace.sessionmaker"), + patch("dify_trace_opik.opik_trace.db") as mock_db, + patch("dify_trace_opik.opik_trace.sessionmaker"), patch( - "core.ops.opik_trace.opik_trace.DifyCoreRepositoryFactory.create_workflow_node_execution_repository", + "dify_trace_opik.opik_trace.DifyCoreRepositoryFactory.create_workflow_node_execution_repository", return_value=fake_repo, ), ): @@ -265,10 +266,10 @@ class TestWorkflowTraceWithMessageId: fake_repo.get_by_workflow_execution.return_value = node_executions or [] with ( - patch("core.ops.opik_trace.opik_trace.db") as mock_db, - patch("core.ops.opik_trace.opik_trace.sessionmaker"), + patch("dify_trace_opik.opik_trace.db") as mock_db, + patch("dify_trace_opik.opik_trace.sessionmaker"), patch( - "core.ops.opik_trace.opik_trace.DifyCoreRepositoryFactory.create_workflow_node_execution_repository", + "dify_trace_opik.opik_trace.DifyCoreRepositoryFactory.create_workflow_node_execution_repository", return_value=fake_repo, ), ): diff --git a/api/providers/trace/trace-tencent/pyproject.toml b/api/providers/trace/trace-tencent/pyproject.toml new file mode 100644 index 0000000000..eab06fc708 --- /dev/null +++ b/api/providers/trace/trace-tencent/pyproject.toml @@ -0,0 +1,14 @@ +[project] +name = "dify-trace-tencent" +version = "0.0.1" +dependencies = [ + # versions inherited from parent + "opentelemetry-api", + "opentelemetry-exporter-otlp-proto-grpc", + "opentelemetry-sdk", + "opentelemetry-semantic-conventions", +] +description = "Dify ops tracing provider (Tencent APM)." + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/providers/trace/trace-tencent/src/dify_trace_tencent/__init__.py b/api/providers/trace/trace-tencent/src/dify_trace_tencent/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/ops/tencent_trace/client.py b/api/providers/trace/trace-tencent/src/dify_trace_tencent/client.py similarity index 100% rename from api/core/ops/tencent_trace/client.py rename to api/providers/trace/trace-tencent/src/dify_trace_tencent/client.py diff --git a/api/providers/trace/trace-tencent/src/dify_trace_tencent/config.py b/api/providers/trace/trace-tencent/src/dify_trace_tencent/config.py new file mode 100644 index 0000000000..398e6c55a8 --- /dev/null +++ b/api/providers/trace/trace-tencent/src/dify_trace_tencent/config.py @@ -0,0 +1,30 @@ +from pydantic import ValidationInfo, field_validator + +from core.ops.entities.config_entity import BaseTracingConfig + + +class TencentConfig(BaseTracingConfig): + """ + Tencent APM tracing config + """ + + token: str + endpoint: str + service_name: str + + @field_validator("token") + @classmethod + def token_validator(cls, v, info: ValidationInfo): + if not v or v.strip() == "": + raise ValueError("Token cannot be empty") + return v + + @field_validator("endpoint") + @classmethod + def endpoint_validator(cls, v, info: ValidationInfo): + return cls.validate_endpoint_url(v, "https://apm.tencentcloudapi.com") + + @field_validator("service_name") + @classmethod + def service_name_validator(cls, v, info: ValidationInfo): + return cls.validate_project_field(v, "dify_app") diff --git a/api/core/ops/tencent_trace/entities/__init__.py b/api/providers/trace/trace-tencent/src/dify_trace_tencent/entities/__init__.py similarity index 100% rename from api/core/ops/tencent_trace/entities/__init__.py rename to api/providers/trace/trace-tencent/src/dify_trace_tencent/entities/__init__.py diff --git a/api/core/ops/tencent_trace/entities/semconv.py b/api/providers/trace/trace-tencent/src/dify_trace_tencent/entities/semconv.py similarity index 100% rename from api/core/ops/tencent_trace/entities/semconv.py rename to api/providers/trace/trace-tencent/src/dify_trace_tencent/entities/semconv.py diff --git a/api/core/ops/tencent_trace/entities/tencent_trace_entity.py b/api/providers/trace/trace-tencent/src/dify_trace_tencent/entities/tencent_trace_entity.py similarity index 100% rename from api/core/ops/tencent_trace/entities/tencent_trace_entity.py rename to api/providers/trace/trace-tencent/src/dify_trace_tencent/entities/tencent_trace_entity.py diff --git a/api/providers/trace/trace-tencent/src/dify_trace_tencent/py.typed b/api/providers/trace/trace-tencent/src/dify_trace_tencent/py.typed new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/ops/tencent_trace/span_builder.py b/api/providers/trace/trace-tencent/src/dify_trace_tencent/span_builder.py similarity index 98% rename from api/core/ops/tencent_trace/span_builder.py rename to api/providers/trace/trace-tencent/src/dify_trace_tencent/span_builder.py index 36878dc58f..763a85ffd7 100644 --- a/api/core/ops/tencent_trace/span_builder.py +++ b/api/providers/trace/trace-tencent/src/dify_trace_tencent/span_builder.py @@ -14,7 +14,8 @@ from core.ops.entities.trace_entity import ( ToolTraceInfo, WorkflowTraceInfo, ) -from core.ops.tencent_trace.entities.semconv import ( +from core.rag.models.document import Document +from dify_trace_tencent.entities.semconv import ( GEN_AI_COMPLETION, GEN_AI_FRAMEWORK, GEN_AI_IS_ENTRY, @@ -38,9 +39,8 @@ from core.ops.tencent_trace.entities.semconv import ( TOOL_PARAMETERS, GenAISpanKind, ) -from core.ops.tencent_trace.entities.tencent_trace_entity import SpanData -from core.ops.tencent_trace.utils import TencentTraceUtils -from core.rag.models.document import Document +from dify_trace_tencent.entities.tencent_trace_entity import SpanData +from dify_trace_tencent.utils import TencentTraceUtils from graphon.entities import WorkflowNodeExecution from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus diff --git a/api/core/ops/tencent_trace/tencent_trace.py b/api/providers/trace/trace-tencent/src/dify_trace_tencent/tencent_trace.py similarity index 94% rename from api/core/ops/tencent_trace/tencent_trace.py rename to api/providers/trace/trace-tencent/src/dify_trace_tencent/tencent_trace.py index d681b9da80..a8c480e4a5 100644 --- a/api/core/ops/tencent_trace/tencent_trace.py +++ b/api/providers/trace/trace-tencent/src/dify_trace_tencent/tencent_trace.py @@ -1,14 +1,12 @@ -""" -Tencent APM tracing implementation with separated concerns -""" +"""Tencent APM tracing with idempotent client cleanup.""" +import inspect import logging from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker from core.ops.base_trace_instance import BaseTraceInstance -from core.ops.entities.config_entity import TencentConfig from core.ops.entities.trace_entity import ( BaseTraceInfo, DatasetRetrievalTraceInfo, @@ -19,11 +17,12 @@ from core.ops.entities.trace_entity import ( ToolTraceInfo, WorkflowTraceInfo, ) -from core.ops.tencent_trace.client import TencentTraceClient -from core.ops.tencent_trace.entities.tencent_trace_entity import SpanData -from core.ops.tencent_trace.span_builder import TencentSpanBuilder -from core.ops.tencent_trace.utils import TencentTraceUtils from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository +from dify_trace_tencent.client import TencentTraceClient +from dify_trace_tencent.config import TencentConfig +from dify_trace_tencent.entities.tencent_trace_entity import SpanData +from dify_trace_tencent.span_builder import TencentSpanBuilder +from dify_trace_tencent.utils import TencentTraceUtils from extensions.ext_database import db from graphon.entities.workflow_node_execution import ( WorkflowNodeExecution, @@ -38,10 +37,18 @@ class TencentDataTrace(BaseTraceInstance): """ Tencent APM trace implementation with single responsibility principle. Acts as a coordinator that delegates specific tasks to specialized classes. + + The instance owns a long-lived ``TencentTraceClient``. Cleanup may happen + explicitly in tests or implicitly during garbage collection, so shutdown + must be safe to call multiple times. """ + trace_client: TencentTraceClient + _closed: bool + def __init__(self, tencent_config: TencentConfig): super().__init__(tencent_config) + self._closed = False self.trace_client = TencentTraceClient( service_name=tencent_config.service_name, endpoint=tencent_config.endpoint, @@ -513,10 +520,25 @@ class TencentDataTrace(BaseTraceInstance): except Exception: logger.debug("[Tencent APM] Failed to record message trace duration") - def __del__(self): - """Ensure proper cleanup on garbage collection.""" + def close(self) -> None: + """Synchronously and idempotently shutdown the underlying trace client.""" + if getattr(self, "_closed", False): + return + + self._closed = True + trace_client = getattr(self, "trace_client", None) + if trace_client is None: + return + try: - if hasattr(self, "trace_client"): - self.trace_client.shutdown() + shutdown_result = trace_client.shutdown() + if inspect.isawaitable(shutdown_result): + close_awaitable = getattr(shutdown_result, "close", None) + if callable(close_awaitable): + close_awaitable() except Exception: logger.exception("[Tencent APM] Failed to shutdown trace client during cleanup") + + def __del__(self): + """Ensure best-effort cleanup on garbage collection without retrying shutdown.""" + self.close() diff --git a/api/core/ops/tencent_trace/utils.py b/api/providers/trace/trace-tencent/src/dify_trace_tencent/utils.py similarity index 100% rename from api/core/ops/tencent_trace/utils.py rename to api/providers/trace/trace-tencent/src/dify_trace_tencent/utils.py diff --git a/api/tests/unit_tests/core/ops/tencent_trace/test_client.py b/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_client.py similarity index 98% rename from api/tests/unit_tests/core/ops/tencent_trace/test_client.py rename to api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_client.py index 870c18e53e..1e656e2462 100644 --- a/api/tests/unit_tests/core/ops/tencent_trace/test_client.py +++ b/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_client.py @@ -8,13 +8,12 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest +from dify_trace_tencent import client as client_module +from dify_trace_tencent.client import TencentTraceClient, _get_opentelemetry_sdk_version +from dify_trace_tencent.entities.tencent_trace_entity import SpanData from opentelemetry.sdk.trace import Event from opentelemetry.trace import Status, StatusCode -from core.ops.tencent_trace import client as client_module -from core.ops.tencent_trace.client import TencentTraceClient, _get_opentelemetry_sdk_version -from core.ops.tencent_trace.entities.tencent_trace_entity import SpanData - metric_reader_instances: list[DummyMetricReader] = [] meter_provider_instances: list[DummyMeterProvider] = [] diff --git a/api/tests/unit_tests/core/ops/tencent_trace/test_span_builder.py b/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_span_builder.py similarity index 89% rename from api/tests/unit_tests/core/ops/tencent_trace/test_span_builder.py rename to api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_span_builder.py index 6113e5c6c8..e850a801f3 100644 --- a/api/tests/unit_tests/core/ops/tencent_trace/test_span_builder.py +++ b/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_span_builder.py @@ -1,15 +1,7 @@ from datetime import datetime from unittest.mock import MagicMock, patch -from opentelemetry.trace import StatusCode - -from core.ops.entities.trace_entity import ( - DatasetRetrievalTraceInfo, - MessageTraceInfo, - ToolTraceInfo, - WorkflowTraceInfo, -) -from core.ops.tencent_trace.entities.semconv import ( +from dify_trace_tencent.entities.semconv import ( GEN_AI_IS_ENTRY, GEN_AI_IS_STREAMING_REQUEST, GEN_AI_MODEL_NAME, @@ -23,7 +15,15 @@ from core.ops.tencent_trace.entities.semconv import ( TOOL_PARAMETERS, GenAISpanKind, ) -from core.ops.tencent_trace.span_builder import TencentSpanBuilder +from dify_trace_tencent.span_builder import TencentSpanBuilder +from opentelemetry.trace import StatusCode + +from core.ops.entities.trace_entity import ( + DatasetRetrievalTraceInfo, + MessageTraceInfo, + ToolTraceInfo, + WorkflowTraceInfo, +) from core.rag.models.document import Document from graphon.entities import WorkflowNodeExecution from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus @@ -31,7 +31,7 @@ from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutio class TestTencentSpanBuilder: def test_get_time_nanoseconds(self): - with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_datetime_to_nanoseconds") as mock_convert: + with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_datetime_to_nanoseconds") as mock_convert: mock_convert.return_value = 123456789 dt = datetime.now() result = TencentSpanBuilder._get_time_nanoseconds(dt) @@ -48,7 +48,7 @@ class TestTencentSpanBuilder: trace_info.workflow_run_outputs = {"answer": "world"} trace_info.metadata = {"conversation_id": "conv_id"} - with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: mock_convert_id.side_effect = [1, 2] # workflow_span_id, message_span_id with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): spans = TencentSpanBuilder.build_workflow_spans(trace_info, 123, "user_1") @@ -70,7 +70,7 @@ class TestTencentSpanBuilder: trace_info.workflow_run_outputs = {} trace_info.metadata = {} # No conversation_id - with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: mock_convert_id.return_value = 1 with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): spans = TencentSpanBuilder.build_workflow_spans(trace_info, 123, "user_1") @@ -98,7 +98,7 @@ class TestTencentSpanBuilder: } node_execution.outputs = {"text": "world"} - with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: mock_convert_id.return_value = 456 with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): span = TencentSpanBuilder.build_workflow_llm_span(123, 1, trace_info, node_execution) @@ -123,7 +123,7 @@ class TestTencentSpanBuilder: "usage": {"prompt_tokens": 15, "completion_tokens": 25, "total_tokens": 40}, } - with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: mock_convert_id.return_value = 456 with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): span = TencentSpanBuilder.build_workflow_llm_span(123, 1, trace_info, node_execution) @@ -142,7 +142,7 @@ class TestTencentSpanBuilder: trace_info.metadata = {"conversation_id": "conv_id"} trace_info.is_streaming_request = True - with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: mock_convert_id.return_value = 789 with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): span = TencentSpanBuilder.build_message_span(trace_info, 123, "user_1") @@ -162,7 +162,7 @@ class TestTencentSpanBuilder: trace_info.metadata = {} trace_info.is_streaming_request = False - with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: mock_convert_id.return_value = 789 with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): span = TencentSpanBuilder.build_message_span(trace_info, 123, "user_1") @@ -182,7 +182,7 @@ class TestTencentSpanBuilder: trace_info.tool_inputs = {"i": 2} trace_info.tool_outputs = "result" - with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: mock_convert_id.return_value = 101 with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): span = TencentSpanBuilder.build_tool_span(trace_info, 123, 1) @@ -204,7 +204,7 @@ class TestTencentSpanBuilder: ) trace_info.documents = [doc] - with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: mock_convert_id.return_value = 202 with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): span = TencentSpanBuilder.build_retrieval_span(trace_info, 123, 1) @@ -222,7 +222,7 @@ class TestTencentSpanBuilder: trace_info.end_time = datetime.now() trace_info.documents = [] - with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: mock_convert_id.return_value = 202 with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): span = TencentSpanBuilder.build_retrieval_span(trace_info, 123, 1) @@ -264,7 +264,7 @@ class TestTencentSpanBuilder: node_execution.created_at = datetime.now() node_execution.finished_at = datetime.now() - with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: mock_convert_id.return_value = 303 with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): span = TencentSpanBuilder.build_workflow_retrieval_span(123, 1, trace_info, node_execution) @@ -286,7 +286,7 @@ class TestTencentSpanBuilder: node_execution.created_at = datetime.now() node_execution.finished_at = datetime.now() - with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: mock_convert_id.return_value = 303 with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): span = TencentSpanBuilder.build_workflow_retrieval_span(123, 1, trace_info, node_execution) @@ -307,7 +307,7 @@ class TestTencentSpanBuilder: node_execution.created_at = datetime.now() node_execution.finished_at = datetime.now() - with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: mock_convert_id.return_value = 404 with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): span = TencentSpanBuilder.build_workflow_tool_span(123, 1, trace_info, node_execution) @@ -329,7 +329,7 @@ class TestTencentSpanBuilder: node_execution.created_at = datetime.now() node_execution.finished_at = datetime.now() - with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: mock_convert_id.return_value = 404 with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): span = TencentSpanBuilder.build_workflow_tool_span(123, 1, trace_info, node_execution) @@ -350,7 +350,7 @@ class TestTencentSpanBuilder: node_execution.created_at = datetime.now() node_execution.finished_at = datetime.now() - with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: mock_convert_id.return_value = 505 with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): span = TencentSpanBuilder.build_workflow_task_span(123, 1, trace_info, node_execution) diff --git a/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace.py b/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_tencent_trace.py similarity index 86% rename from api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace.py rename to api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_tencent_trace.py index 7afd0b824a..54524b09ca 100644 --- a/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace.py +++ b/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_tencent_trace.py @@ -1,9 +1,12 @@ +import gc import logging -from unittest.mock import MagicMock, patch +import warnings +from unittest.mock import AsyncMock, MagicMock, patch import pytest +from dify_trace_tencent.config import TencentConfig +from dify_trace_tencent.tencent_trace import TencentDataTrace -from core.ops.entities.config_entity import TencentConfig from core.ops.entities.trace_entity import ( DatasetRetrievalTraceInfo, GenerateNameTraceInfo, @@ -13,7 +16,6 @@ from core.ops.entities.trace_entity import ( ToolTraceInfo, WorkflowTraceInfo, ) -from core.ops.tencent_trace.tencent_trace import TencentDataTrace from graphon.entities import WorkflowNodeExecution from graphon.enums import BuiltinNodeTypes from models import Account, App, TenantAccountJoin @@ -28,19 +30,19 @@ def tencent_config(): @pytest.fixture def mock_trace_client(): - with patch("core.ops.tencent_trace.tencent_trace.TencentTraceClient") as mock: + with patch("dify_trace_tencent.tencent_trace.TencentTraceClient") as mock: yield mock @pytest.fixture def mock_span_builder(): - with patch("core.ops.tencent_trace.tencent_trace.TencentSpanBuilder") as mock: + with patch("dify_trace_tencent.tencent_trace.TencentSpanBuilder") as mock: yield mock @pytest.fixture def mock_trace_utils(): - with patch("core.ops.tencent_trace.tencent_trace.TencentTraceUtils") as mock: + with patch("dify_trace_tencent.tencent_trace.TencentTraceUtils") as mock: yield mock @@ -198,9 +200,9 @@ class TestTencentDataTrace: trace_info.workflow_run_id = "run-id" with patch( - "core.ops.tencent_trace.tencent_trace.TencentTraceUtils.convert_to_trace_id", side_effect=Exception("error") + "dify_trace_tencent.tencent_trace.TencentTraceUtils.convert_to_trace_id", side_effect=Exception("error") ): - with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log: + with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log: tencent_data_trace.workflow_trace(trace_info) mock_log.assert_called_once_with("[Tencent APM] Failed to process workflow trace") @@ -230,9 +232,9 @@ class TestTencentDataTrace: trace_info = MagicMock(spec=MessageTraceInfo) with patch( - "core.ops.tencent_trace.tencent_trace.TencentTraceUtils.convert_to_trace_id", side_effect=Exception("error") + "dify_trace_tencent.tencent_trace.TencentTraceUtils.convert_to_trace_id", side_effect=Exception("error") ): - with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log: + with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log: tencent_data_trace.message_trace(trace_info) mock_log.assert_called_once_with("[Tencent APM] Failed to process message trace") @@ -262,9 +264,9 @@ class TestTencentDataTrace: trace_info.message_id = "msg-id" with patch( - "core.ops.tencent_trace.tencent_trace.TencentTraceUtils.convert_to_span_id", side_effect=Exception("error") + "dify_trace_tencent.tencent_trace.TencentTraceUtils.convert_to_span_id", side_effect=Exception("error") ): - with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log: + with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log: tencent_data_trace.tool_trace(trace_info) mock_log.assert_called_once_with("[Tencent APM] Failed to process tool trace") @@ -294,22 +296,22 @@ class TestTencentDataTrace: trace_info.message_id = "msg-id" with patch( - "core.ops.tencent_trace.tencent_trace.TencentTraceUtils.convert_to_span_id", side_effect=Exception("error") + "dify_trace_tencent.tencent_trace.TencentTraceUtils.convert_to_span_id", side_effect=Exception("error") ): - with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log: + with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log: tencent_data_trace.dataset_retrieval_trace(trace_info) mock_log.assert_called_once_with("[Tencent APM] Failed to process dataset retrieval trace") def test_suggested_question_trace(self, tencent_data_trace): trace_info = MagicMock(spec=SuggestedQuestionTraceInfo) - with patch("core.ops.tencent_trace.tencent_trace.logger.info") as mock_log: + with patch("dify_trace_tencent.tencent_trace.logger.info") as mock_log: tencent_data_trace.suggested_question_trace(trace_info) mock_log.assert_called_once_with("[Tencent APM] Processing suggested question trace") def test_suggested_question_trace_exception(self, tencent_data_trace): trace_info = MagicMock(spec=SuggestedQuestionTraceInfo) - with patch("core.ops.tencent_trace.tencent_trace.logger.info", side_effect=Exception("error")): - with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log: + with patch("dify_trace_tencent.tencent_trace.logger.info", side_effect=Exception("error")): + with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log: tencent_data_trace.suggested_question_trace(trace_info) mock_log.assert_called_once_with("[Tencent APM] Failed to process suggested question trace") @@ -342,7 +344,7 @@ class TestTencentDataTrace: with patch.object(tencent_data_trace, "_get_workflow_node_executions", return_value=[node]): with patch.object(tencent_data_trace, "_build_workflow_node_span", side_effect=Exception("node error")): - with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log: + with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log: tencent_data_trace._process_workflow_nodes(trace_info, 123) # The exception should be caught by the outer handler since convert_to_span_id is called first mock_log.assert_called_once_with("[Tencent APM] Failed to process workflow nodes") @@ -351,7 +353,7 @@ class TestTencentDataTrace: trace_info = MagicMock(spec=WorkflowTraceInfo) mock_trace_utils.convert_to_span_id.side_effect = Exception("outer error") - with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log: + with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log: tencent_data_trace._process_workflow_nodes(trace_info, 123) mock_log.assert_called_once_with("[Tencent APM] Failed to process workflow nodes") @@ -381,7 +383,7 @@ class TestTencentDataTrace: node.id = "n1" mock_span_builder.build_workflow_llm_span.side_effect = Exception("error") - with patch("core.ops.tencent_trace.tencent_trace.logger.debug") as mock_log: + with patch("dify_trace_tencent.tencent_trace.logger.debug") as mock_log: result = tencent_data_trace._build_workflow_node_span(node, 123, MagicMock(), 456) assert result is None mock_log.assert_called_once() @@ -403,15 +405,13 @@ class TestTencentDataTrace: mock_executions = [MagicMock()] - with patch("core.ops.tencent_trace.tencent_trace.db") as mock_db: + with patch("dify_trace_tencent.tencent_trace.db") as mock_db: mock_db.engine = "engine" - with patch("core.ops.tencent_trace.tencent_trace.Session") as mock_session_ctx: + with patch("dify_trace_tencent.tencent_trace.Session") as mock_session_ctx: session = mock_session_ctx.return_value.__enter__.return_value session.scalar.side_effect = [app, account, tenant_join] - with patch( - "core.ops.tencent_trace.tencent_trace.SQLAlchemyWorkflowNodeExecutionRepository" - ) as mock_repo: + with patch("dify_trace_tencent.tencent_trace.SQLAlchemyWorkflowNodeExecutionRepository") as mock_repo: mock_repo.return_value.get_by_workflow_execution.return_value = mock_executions results = tencent_data_trace._get_workflow_node_executions(trace_info) @@ -423,7 +423,7 @@ class TestTencentDataTrace: trace_info = MagicMock(spec=WorkflowTraceInfo) trace_info.metadata = {} - with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log: + with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log: results = tencent_data_trace._get_workflow_node_executions(trace_info) assert results == [] mock_log.assert_called_once() @@ -432,14 +432,14 @@ class TestTencentDataTrace: trace_info = MagicMock(spec=WorkflowTraceInfo) trace_info.metadata = {"app_id": "app-1"} - with patch("core.ops.tencent_trace.tencent_trace.db") as mock_db: + with patch("dify_trace_tencent.tencent_trace.db") as mock_db: mock_db.init_app = MagicMock() # Ensure init_app is mocked mock_db.engine = "engine" - with patch("core.ops.tencent_trace.tencent_trace.Session") as mock_session_ctx: + with patch("dify_trace_tencent.tencent_trace.Session") as mock_session_ctx: session = mock_session_ctx.return_value.__enter__.return_value session.scalar.return_value = None - with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log: + with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log: results = tencent_data_trace._get_workflow_node_executions(trace_info) assert results == [] mock_log.assert_called_once() @@ -449,8 +449,8 @@ class TestTencentDataTrace: trace_info.tenant_id = "tenant-1" trace_info.metadata = {"user_id": "user-1"} - with patch("core.ops.tencent_trace.tencent_trace.sessionmaker", side_effect=Exception("Database error")): - with patch("core.ops.tencent_trace.tencent_trace.db") as mock_db: + with patch("dify_trace_tencent.tencent_trace.sessionmaker", side_effect=Exception("Database error")): + with patch("dify_trace_tencent.tencent_trace.db") as mock_db: mock_db.init_app = MagicMock() mock_db.engine = MagicMock() @@ -476,8 +476,8 @@ class TestTencentDataTrace: trace_info.tenant_id = "t" trace_info.metadata = {"user_id": "u"} - with patch("core.ops.tencent_trace.tencent_trace.sessionmaker", side_effect=Exception("error")): - with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log: + with patch("dify_trace_tencent.tencent_trace.sessionmaker", side_effect=Exception("error")): + with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log: user_id = tencent_data_trace._get_user_id(trace_info) assert user_id == "unknown" mock_log.assert_called_once_with("[Tencent APM] Failed to get user ID") @@ -519,7 +519,7 @@ class TestTencentDataTrace: node.process_data = None node.outputs = None - with patch("core.ops.tencent_trace.tencent_trace.logger.debug") as mock_log: + with patch("dify_trace_tencent.tencent_trace.logger.debug") as mock_log: tencent_data_trace._record_llm_metrics(node) # Should not crash @@ -557,7 +557,7 @@ class TestTencentDataTrace: trace_info = MagicMock(spec=MessageTraceInfo) trace_info.metadata = None - with patch("core.ops.tencent_trace.tencent_trace.logger.debug") as mock_log: + with patch("dify_trace_tencent.tencent_trace.logger.debug") as mock_log: tencent_data_trace._record_message_llm_metrics(trace_info) # Should not crash @@ -609,7 +609,7 @@ class TestTencentDataTrace: trace_info = MagicMock(spec=WorkflowTraceInfo) trace_info.start_time = MagicMock() # This might cause total_seconds() to fail if not mocked right - with patch("core.ops.tencent_trace.tencent_trace.logger.debug") as mock_log: + with patch("dify_trace_tencent.tencent_trace.logger.debug") as mock_log: tencent_data_trace._record_workflow_trace_duration(trace_info) def test_record_message_trace_duration(self, tencent_data_trace): @@ -631,16 +631,41 @@ class TestTencentDataTrace: trace_info = MagicMock(spec=MessageTraceInfo) trace_info.start_time = None - with patch("core.ops.tencent_trace.tencent_trace.logger.debug") as mock_log: + with patch("dify_trace_tencent.tencent_trace.logger.debug") as mock_log: tencent_data_trace._record_message_trace_duration(trace_info) - def test_del(self, tencent_data_trace): + def test_close(self, tencent_data_trace): client = tencent_data_trace.trace_client - tencent_data_trace.__del__() + tencent_data_trace.close() client.shutdown.assert_called_once() - def test_del_exception(self, tencent_data_trace): + def test_close_is_idempotent(self, tencent_data_trace): + client = tencent_data_trace.trace_client + + tencent_data_trace.close() + tencent_data_trace.close() + + client.shutdown.assert_called_once() + + def test_close_exception(self, tencent_data_trace): tencent_data_trace.trace_client.shutdown.side_effect = Exception("error") - with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log: - tencent_data_trace.__del__() + with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log: + tencent_data_trace.close() mock_log.assert_called_once_with("[Tencent APM] Failed to shutdown trace client during cleanup") + + def test_close_handles_async_shutdown_mock(self, tencent_data_trace): + shutdown = AsyncMock() + tencent_data_trace.trace_client.shutdown = shutdown + + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + tencent_data_trace.close() + gc.collect() + + shutdown.assert_called_once() + assert not [ + warning + for warning in caught + if issubclass(warning.category, RuntimeWarning) + and "AsyncMockMixin._execute_mock_call" in str(warning.message) + ] diff --git a/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace_utils.py b/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_tencent_trace_utils.py similarity index 88% rename from api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace_utils.py rename to api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_tencent_trace_utils.py index ef28d18e20..63c6d680d7 100644 --- a/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace_utils.py +++ b/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_tencent_trace_utils.py @@ -8,10 +8,9 @@ from datetime import UTC, datetime from unittest.mock import patch import pytest +from dify_trace_tencent.utils import TencentTraceUtils from opentelemetry.trace import Link, TraceFlags -from core.ops.tencent_trace.utils import TencentTraceUtils - def test_convert_to_trace_id_with_valid_uuid() -> None: uuid_str = "12345678-1234-5678-1234-567812345678" @@ -20,7 +19,7 @@ def test_convert_to_trace_id_with_valid_uuid() -> None: def test_convert_to_trace_id_uses_uuid4_when_none() -> None: expected_uuid = uuid.UUID("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa") - with patch("core.ops.tencent_trace.utils.uuid.uuid4", return_value=expected_uuid) as uuid4_mock: + with patch("dify_trace_tencent.utils.uuid.uuid4", return_value=expected_uuid) as uuid4_mock: assert TencentTraceUtils.convert_to_trace_id(None) == expected_uuid.int uuid4_mock.assert_called_once() @@ -45,7 +44,7 @@ def test_convert_to_span_id_is_deterministic_and_sensitive_to_type() -> None: def test_convert_to_span_id_uses_uuid4_when_none() -> None: expected_uuid = uuid.UUID("bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb") - with patch("core.ops.tencent_trace.utils.uuid.uuid4", return_value=expected_uuid) as uuid4_mock: + with patch("dify_trace_tencent.utils.uuid.uuid4", return_value=expected_uuid) as uuid4_mock: span_id = TencentTraceUtils.convert_to_span_id(None, "workflow") assert isinstance(span_id, int) uuid4_mock.assert_called_once() @@ -58,7 +57,7 @@ def test_convert_to_span_id_raises_value_error_for_invalid_uuid() -> None: def test_generate_span_id_skips_invalid_span_id() -> None: with patch( - "core.ops.tencent_trace.utils.random.getrandbits", + "dify_trace_tencent.utils.random.getrandbits", side_effect=[TencentTraceUtils.INVALID_SPAN_ID, 42], ) as bits_mock: assert TencentTraceUtils.generate_span_id() == 42 @@ -75,7 +74,7 @@ def test_convert_datetime_to_nanoseconds_uses_now_when_none() -> None: fixed = datetime(2024, 1, 2, 3, 4, 5, tzinfo=UTC) expected = int(fixed.timestamp() * 1e9) - with patch("core.ops.tencent_trace.utils.datetime") as datetime_mock: + with patch("dify_trace_tencent.utils.datetime") as datetime_mock: datetime_mock.now.return_value = fixed assert TencentTraceUtils.convert_datetime_to_nanoseconds(None) == expected datetime_mock.now.assert_called_once() @@ -100,7 +99,7 @@ def test_create_link_accepts_hex_or_uuid(trace_id_str: str, expected_trace_id: i @pytest.mark.parametrize("trace_id_str", ["g" * 32, "not-a-uuid", None]) def test_create_link_falls_back_to_uuid4(trace_id_str: object) -> None: fallback_uuid = uuid.UUID("dddddddd-dddd-dddd-dddd-dddddddddddd") - with patch("core.ops.tencent_trace.utils.uuid.uuid4", return_value=fallback_uuid) as uuid4_mock: + with patch("dify_trace_tencent.utils.uuid.uuid4", return_value=fallback_uuid) as uuid4_mock: link = TencentTraceUtils.create_link(trace_id_str) # type: ignore[arg-type] assert link.context.trace_id == fallback_uuid.int uuid4_mock.assert_called_once() diff --git a/api/providers/trace/trace-weave/pyproject.toml b/api/providers/trace/trace-weave/pyproject.toml new file mode 100644 index 0000000000..ba449f2a93 --- /dev/null +++ b/api/providers/trace/trace-weave/pyproject.toml @@ -0,0 +1,10 @@ +[project] +name = "dify-trace-weave" +version = "0.0.1" +dependencies = [ + "weave>=0.52.36", +] +description = "Dify ops tracing provider (Weave)." + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/providers/trace/trace-weave/src/dify_trace_weave/__init__.py b/api/providers/trace/trace-weave/src/dify_trace_weave/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/providers/trace/trace-weave/src/dify_trace_weave/config.py b/api/providers/trace/trace-weave/src/dify_trace_weave/config.py new file mode 100644 index 0000000000..5942bd57fe --- /dev/null +++ b/api/providers/trace/trace-weave/src/dify_trace_weave/config.py @@ -0,0 +1,29 @@ +from pydantic import ValidationInfo, field_validator + +from core.ops.entities.config_entity import BaseTracingConfig +from core.ops.utils import validate_url + + +class WeaveConfig(BaseTracingConfig): + """ + Model class for Weave tracing config. + """ + + api_key: str + entity: str | None = None + project: str + endpoint: str = "https://trace.wandb.ai" + host: str | None = None + + @field_validator("endpoint") + @classmethod + def endpoint_validator(cls, v, info: ValidationInfo): + # Weave only allows HTTPS for endpoint + return validate_url(v, "https://trace.wandb.ai", allowed_schemes=("https",)) + + @field_validator("host") + @classmethod + def host_validator(cls, v, info: ValidationInfo): + if v is not None and v.strip() != "": + return validate_url(v, v, allowed_schemes=("https", "http")) + return v diff --git a/api/providers/trace/trace-weave/src/dify_trace_weave/entities/__init__.py b/api/providers/trace/trace-weave/src/dify_trace_weave/entities/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/ops/weave_trace/entities/weave_trace_entity.py b/api/providers/trace/trace-weave/src/dify_trace_weave/entities/weave_trace_entity.py similarity index 100% rename from api/core/ops/weave_trace/entities/weave_trace_entity.py rename to api/providers/trace/trace-weave/src/dify_trace_weave/entities/weave_trace_entity.py diff --git a/api/providers/trace/trace-weave/src/dify_trace_weave/py.typed b/api/providers/trace/trace-weave/src/dify_trace_weave/py.typed new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/ops/weave_trace/weave_trace.py b/api/providers/trace/trace-weave/src/dify_trace_weave/weave_trace.py similarity index 99% rename from api/core/ops/weave_trace/weave_trace.py rename to api/providers/trace/trace-weave/src/dify_trace_weave/weave_trace.py index f79544f1c7..4292cbf0f1 100644 --- a/api/core/ops/weave_trace/weave_trace.py +++ b/api/providers/trace/trace-weave/src/dify_trace_weave/weave_trace.py @@ -17,7 +17,6 @@ from weave.trace_server.trace_server_interface import ( ) from core.ops.base_trace_instance import BaseTraceInstance -from core.ops.entities.config_entity import WeaveConfig from core.ops.entities.trace_entity import ( BaseTraceInfo, DatasetRetrievalTraceInfo, @@ -29,8 +28,9 @@ from core.ops.entities.trace_entity import ( TraceTaskName, WorkflowTraceInfo, ) -from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel from core.repositories import DifyCoreRepositoryFactory +from dify_trace_weave.config import WeaveConfig +from dify_trace_weave.entities.weave_trace_entity import WeaveTraceModel from extensions.ext_database import db from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom diff --git a/api/providers/trace/trace-weave/tests/unit_tests/test_config_entity.py b/api/providers/trace/trace-weave/tests/unit_tests/test_config_entity.py new file mode 100644 index 0000000000..eeb1fe1d87 --- /dev/null +++ b/api/providers/trace/trace-weave/tests/unit_tests/test_config_entity.py @@ -0,0 +1,61 @@ +import pytest +from dify_trace_weave.config import WeaveConfig +from pydantic import ValidationError + + +class TestWeaveConfig: + """Test cases for WeaveConfig""" + + def test_valid_config(self): + """Test valid Weave configuration""" + config = WeaveConfig( + api_key="test_key", + entity="test_entity", + project="test_project", + endpoint="https://custom.wandb.ai", + host="https://custom.host.com", + ) + assert config.api_key == "test_key" + assert config.entity == "test_entity" + assert config.project == "test_project" + assert config.endpoint == "https://custom.wandb.ai" + assert config.host == "https://custom.host.com" + + def test_default_values(self): + """Test default values are set correctly""" + config = WeaveConfig(api_key="key", project="project") + assert config.entity is None + assert config.endpoint == "https://trace.wandb.ai" + assert config.host is None + + def test_missing_required_fields(self): + """Test that required fields are enforced""" + with pytest.raises(ValidationError): + WeaveConfig() + + with pytest.raises(ValidationError): + WeaveConfig(api_key="key") + + with pytest.raises(ValidationError): + WeaveConfig(project="project") + + def test_endpoint_validation_https_only(self): + """Test endpoint validation only allows HTTPS""" + with pytest.raises(ValidationError, match="URL scheme must be one of"): + WeaveConfig(api_key="key", project="project", endpoint="http://insecure.wandb.ai") + + def test_host_validation_optional(self): + """Test host validation is optional but validates when provided""" + config = WeaveConfig(api_key="key", project="project", host=None) + assert config.host is None + + config = WeaveConfig(api_key="key", project="project", host="") + assert config.host == "" + + config = WeaveConfig(api_key="key", project="project", host="https://valid.host.com") + assert config.host == "https://valid.host.com" + + def test_host_validation_invalid_scheme(self): + """Test host validation rejects invalid schemes when provided""" + with pytest.raises(ValidationError, match="URL scheme must be one of"): + WeaveConfig(api_key="key", project="project", host="ftp://invalid.host.com") diff --git a/api/tests/unit_tests/core/ops/weave_trace/test_weave_trace.py b/api/providers/trace/trace-weave/tests/unit_tests/weave_trace/test_weave_trace.py similarity index 97% rename from api/tests/unit_tests/core/ops/weave_trace/test_weave_trace.py rename to api/providers/trace/trace-weave/tests/unit_tests/weave_trace/test_weave_trace.py index 531c7de05f..6028d0c550 100644 --- a/api/tests/unit_tests/core/ops/weave_trace/test_weave_trace.py +++ b/api/providers/trace/trace-weave/tests/unit_tests/weave_trace/test_weave_trace.py @@ -1,4 +1,4 @@ -"""Comprehensive tests for core.ops.weave_trace.weave_trace module.""" +"""Comprehensive tests for dify_trace_weave.weave_trace module.""" from __future__ import annotations @@ -7,9 +7,11 @@ from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest +from dify_trace_weave.config import WeaveConfig +from dify_trace_weave.entities.weave_trace_entity import WeaveTraceModel +from dify_trace_weave.weave_trace import WeaveDataTrace from weave.trace_server.trace_server_interface import TraceStatus -from core.ops.entities.config_entity import WeaveConfig from core.ops.entities.trace_entity import ( DatasetRetrievalTraceInfo, GenerateNameTraceInfo, @@ -20,8 +22,6 @@ from core.ops.entities.trace_entity import ( TraceTaskName, WorkflowTraceInfo, ) -from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel -from core.ops.weave_trace.weave_trace import WeaveDataTrace from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey # ── Helpers ────────────────────────────────────────────────────────────────── @@ -191,14 +191,14 @@ def _make_node(**overrides): @pytest.fixture def mock_wandb(): - with patch("core.ops.weave_trace.weave_trace.wandb") as mock: + with patch("dify_trace_weave.weave_trace.wandb") as mock: mock.login.return_value = True yield mock @pytest.fixture def mock_weave(): - with patch("core.ops.weave_trace.weave_trace.weave") as mock: + with patch("dify_trace_weave.weave_trace.weave") as mock: client = MagicMock() client.entity = "my-entity" client.project = "my-project" @@ -307,7 +307,7 @@ class TestGetProjectUrl: monkeypatch.setattr(trace_instance, "entity", None) monkeypatch.setattr(trace_instance, "project_name", None) # Force an error by making string formatting fail - with patch("core.ops.weave_trace.weave_trace.logger") as mock_logger: + with patch("dify_trace_weave.weave_trace.logger") as mock_logger: # Simulate exception via property original_entity = trace_instance.entity trace_instance.entity = None @@ -594,9 +594,9 @@ class TestWorkflowTrace: mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo - monkeypatch.setattr("core.ops.weave_trace.weave_trace.DifyCoreRepositoryFactory", mock_factory) - monkeypatch.setattr("core.ops.weave_trace.weave_trace.sessionmaker", lambda bind: MagicMock()) - monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", MagicMock(engine="engine")) + monkeypatch.setattr("dify_trace_weave.weave_trace.DifyCoreRepositoryFactory", mock_factory) + monkeypatch.setattr("dify_trace_weave.weave_trace.sessionmaker", lambda bind: MagicMock()) + monkeypatch.setattr("dify_trace_weave.weave_trace.db", MagicMock(engine="engine")) return repo def test_workflow_trace_no_nodes_no_message_id(self, trace_instance, monkeypatch): @@ -703,8 +703,8 @@ class TestWorkflowTrace: def test_workflow_trace_missing_app_id_raises(self, trace_instance, monkeypatch): """Raises ValueError when app_id is missing from metadata.""" - monkeypatch.setattr("core.ops.weave_trace.weave_trace.sessionmaker", lambda bind: MagicMock()) - monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", MagicMock(engine="engine")) + monkeypatch.setattr("dify_trace_weave.weave_trace.sessionmaker", lambda bind: MagicMock()) + monkeypatch.setattr("dify_trace_weave.weave_trace.db", MagicMock(engine="engine")) trace_info = _make_workflow_trace_info( message_id=None, @@ -802,7 +802,7 @@ class TestMessageTrace: def test_basic_message_trace(self, trace_instance, monkeypatch): """message_trace creates message run and llm child run.""" monkeypatch.setattr( - "core.ops.weave_trace.weave_trace.db.session.get", + "dify_trace_weave.weave_trace.db.session.get", lambda model, pk: None, ) @@ -824,7 +824,7 @@ class TestMessageTrace: mock_db = MagicMock() mock_db.session.get.return_value = None - monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", mock_db) + monkeypatch.setattr("dify_trace_weave.weave_trace.db", mock_db) trace_instance.start_call = MagicMock() trace_instance.finish_call = MagicMock() @@ -846,7 +846,7 @@ class TestMessageTrace: mock_db = MagicMock() mock_db.session.get.return_value = end_user - monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", mock_db) + monkeypatch.setattr("dify_trace_weave.weave_trace.db", mock_db) trace_instance.start_call = MagicMock() trace_instance.finish_call = MagicMock() @@ -866,7 +866,7 @@ class TestMessageTrace: """message_trace handles when from_end_user_id is None.""" mock_db = MagicMock() mock_db.session.get.return_value = None - monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", mock_db) + monkeypatch.setattr("dify_trace_weave.weave_trace.db", mock_db) trace_instance.start_call = MagicMock() trace_instance.finish_call = MagicMock() @@ -884,7 +884,7 @@ class TestMessageTrace: """trace_id falls back to message_id when trace_id is None.""" mock_db = MagicMock() mock_db.session.get.return_value = None - monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", mock_db) + monkeypatch.setattr("dify_trace_weave.weave_trace.db", mock_db) trace_instance.start_call = MagicMock() trace_instance.finish_call = MagicMock() @@ -899,7 +899,7 @@ class TestMessageTrace: """message_trace handles file_list=None gracefully.""" mock_db = MagicMock() mock_db.session.get.return_value = None - monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", mock_db) + monkeypatch.setattr("dify_trace_weave.weave_trace.db", mock_db) trace_instance.start_call = MagicMock() trace_instance.finish_call = MagicMock() diff --git a/api/providers/vdb/vdb-analyticdb/src/dify_vdb_analyticdb/analyticdb_vector_sql.py b/api/providers/vdb/vdb-analyticdb/src/dify_vdb_analyticdb/analyticdb_vector_sql.py index b2908ebdae..11398efb58 100644 --- a/api/providers/vdb/vdb-analyticdb/src/dify_vdb_analyticdb/analyticdb_vector_sql.py +++ b/api/providers/vdb/vdb-analyticdb/src/dify_vdb_analyticdb/analyticdb_vector_sql.py @@ -1,6 +1,6 @@ import json import uuid -from collections.abc import Iterator +from collections.abc import Generator # Added Generator from contextlib import contextmanager from typing import Any @@ -75,7 +75,7 @@ class AnalyticdbVectorBySql: ) @contextmanager - def _get_cursor(self) -> Iterator[Any]: + def _get_cursor(self) -> Generator[Any, None, None]: # Changed from Iterator[Any] assert self.pool is not None, "Connection pool is not initialized" conn = self.pool.getconn() cur = conn.cursor() diff --git a/api/pyproject.toml b/api/pyproject.toml index a1ceea181e..b13c744f0a 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -6,7 +6,7 @@ requires-python = "~=3.12.0" dependencies = [ # Legacy: mature and widely deployed "bleach>=6.3.0", - "boto3>=1.42.88", + "boto3>=1.42.91", "celery>=5.6.3", "croniter>=6.2.2", "flask-cors>=6.0.2", @@ -30,11 +30,8 @@ dependencies = [ "flask-migrate>=4.1.0,<5.0.0", "flask-orjson>=2.0.0,<3.0.0", "flask-restx>=1.3.2,<2.0.0", - "google-cloud-aiplatform>=1.147.0,<2.0.0", + "google-cloud-aiplatform>=1.148.1,<2.0.0", "httpx[socks]>=0.28.1,<1.0.0", - "langfuse>=4.2.0,<5.0.0", - "langsmith>=0.7.31,<1.0.0", - "mlflow-skinny>=3.11.1,<4.0.0", "opentelemetry-distro>=0.62b0,<1.0.0", "opentelemetry-instrumentation-celery>=0.62b0,<1.0.0", "opentelemetry-instrumentation-flask>=0.62b0,<1.0.0", @@ -44,15 +41,12 @@ dependencies = [ "opentelemetry-propagator-b3>=1.41.0,<2.0.0", "readabilipy>=0.3.0,<1.0.0", "resend>=2.27.0,<3.0.0", - "weave>=0.52.36,<1.0.0", # Emerging: newer and fast-moving, use compatible pins - "arize-phoenix-otel~=0.15.0", "fastopenapi[flask]~=0.7.0", - "graphon~=0.1.2", + "graphon~=0.2.2", "httpx-sse~=0.4.0", - "json-repair~=0.59.2", - "opik~=1.11.2", + "json-repair~=0.59.4", ] # Before adding new dependency, consider place it in # alphabet order (a-z) and suitable group. @@ -61,8 +55,8 @@ dependencies = [ packages = [] [tool.uv.workspace] -members = ["providers/vdb/*"] -exclude = ["providers/vdb/__pycache__"] +members = ["providers/vdb/*", "providers/trace/*"] +exclude = ["providers/vdb/__pycache__", "providers/trace/__pycache__"] [tool.uv.sources] dify-vdb-alibabacloud-mysql = { workspace = true } @@ -95,9 +89,17 @@ dify-vdb-upstash = { workspace = true } dify-vdb-vastbase = { workspace = true } dify-vdb-vikingdb = { workspace = true } dify-vdb-weaviate = { workspace = true } +dify-trace-aliyun = { workspace = true } +dify-trace-arize-phoenix = { workspace = true } +dify-trace-langfuse = { workspace = true } +dify-trace-langsmith = { workspace = true } +dify-trace-mlflow = { workspace = true } +dify-trace-opik = { workspace = true } +dify-trace-tencent = { workspace = true } +dify-trace-weave = { workspace = true } [tool.uv] -default-groups = ["storage", "tools", "vdb-all"] +default-groups = ["storage", "tools", "vdb-all", "trace-all"] package = false override-dependencies = [ "pyarrow>=18.0.0", @@ -171,8 +173,8 @@ dev = [ # "locust>=2.40.4", # Temporarily removed due to compatibility issues. Uncomment when resolved. "pytest-timeout>=2.4.0", "pytest-xdist>=3.8.0", - "pyrefly>=0.60.0", - "xinference-client>=2.4.0", + "pyrefly>=0.61.1", + "xinference-client>=2.5.0", ] ############################################################ @@ -181,13 +183,13 @@ dev = [ ############################################################ storage = [ "azure-storage-blob>=12.28.0", - "bce-python-sdk>=0.9.69", + "bce-python-sdk>=0.9.70", "cos-python-sdk-v5>=1.9.41", "esdk-obs-python>=3.22.2", "google-cloud-storage>=3.10.1", "opendal>=0.46.0", "oss2>=2.19.1", - "supabase>=2.18.1", + "supabase>=2.28.3", "tos>=2.9.0", ] @@ -264,7 +266,26 @@ vdb-vastbase = ["dify-vdb-vastbase"] vdb-vikingdb = ["dify-vdb-vikingdb"] vdb-weaviate = ["dify-vdb-weaviate"] # Optional client used by some tests / integrations (not a vector backend plugin) -vdb-xinference = ["xinference-client>=2.4.0"] +vdb-xinference = ["xinference-client>=2.5.0"] + +trace-all = [ + "dify-trace-aliyun", + "dify-trace-arize-phoenix", + "dify-trace-langfuse", + "dify-trace-langsmith", + "dify-trace-mlflow", + "dify-trace-opik", + "dify-trace-tencent", + "dify-trace-weave", +] +trace-aliyun = ["dify-trace-aliyun"] +trace-arize-phoenix = ["dify-trace-arize-phoenix"] +trace-langfuse = ["dify-trace-langfuse"] +trace-langsmith = ["dify-trace-langsmith"] +trace-mlflow = ["dify-trace-mlflow"] +trace-opik = ["dify-trace-opik"] +trace-tencent = ["dify-trace-tencent"] +trace-weave = ["dify-trace-weave"] [tool.pyrefly] project-includes = ["."] diff --git a/api/pyrefly-local-excludes.txt b/api/pyrefly-local-excludes.txt index 3e5ece1fcf..fbbca24558 100644 --- a/api/pyrefly-local-excludes.txt +++ b/api/pyrefly-local-excludes.txt @@ -34,12 +34,12 @@ core/external_data_tool/api/api.py core/llm_generator/llm_generator.py core/llm_generator/output_parser/structured_output.py core/mcp/mcp_client.py -core/ops/aliyun_trace/data_exporter/traceclient.py -core/ops/arize_phoenix_trace/arize_phoenix_trace.py -core/ops/mlflow_trace/mlflow_trace.py +providers/trace/trace-aliyun/src/dify_trace_aliyun/data_exporter/traceclient.py +providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/arize_phoenix_trace.py +providers/trace/trace-mlflow/src/dify_trace_mlflow/mlflow_trace.py core/ops/ops_trace_manager.py -core/ops/tencent_trace/client.py -core/ops/tencent_trace/utils.py +providers/trace/trace-tencent/src/dify_trace_tencent/client.py +providers/trace/trace-tencent/src/dify_trace_tencent/utils.py core/plugin/backwards_invocation/base.py core/plugin/backwards_invocation/model.py core/prompt/utils/extract_thread_messages.py diff --git a/api/pyrightconfig.json b/api/pyrightconfig.json index c4582e891d..ac0e2a3a53 100644 --- a/api/pyrightconfig.json +++ b/api/pyrightconfig.json @@ -5,7 +5,8 @@ ".venv", "migrations/", "core/rag", - "providers/", + "providers/vdb/", + "providers/trace/*/tests", ], "typeCheckingMode": "strict", "allowedUntypedLibraries": [ diff --git a/api/schedule/mail_clean_document_notify_task.py b/api/schedule/mail_clean_document_notify_task.py index 8479cdfb0c..2cc0192a4a 100644 --- a/api/schedule/mail_clean_document_notify_task.py +++ b/api/schedule/mail_clean_document_notify_task.py @@ -7,8 +7,8 @@ from sqlalchemy import select import app from configs import dify_config +from core.db.session_factory import session_factory from enums.cloud_plan import CloudPlan -from extensions.ext_database import db from extensions.ext_mail import mail from libs.email_i18n import EmailType, get_email_i18n_service from models import Account, Tenant, TenantAccountJoin @@ -33,67 +33,68 @@ def mail_clean_document_notify_task(): # send document clean notify mail try: - dataset_auto_disable_logs = db.session.scalars( - select(DatasetAutoDisableLog).where(DatasetAutoDisableLog.notified == False) - ).all() - # group by tenant_id - dataset_auto_disable_logs_map: dict[str, list[DatasetAutoDisableLog]] = defaultdict(list) - for dataset_auto_disable_log in dataset_auto_disable_logs: - if dataset_auto_disable_log.tenant_id not in dataset_auto_disable_logs_map: - dataset_auto_disable_logs_map[dataset_auto_disable_log.tenant_id] = [] - dataset_auto_disable_logs_map[dataset_auto_disable_log.tenant_id].append(dataset_auto_disable_log) - url = f"{dify_config.CONSOLE_WEB_URL}/datasets" - for tenant_id, tenant_dataset_auto_disable_logs in dataset_auto_disable_logs_map.items(): - features = FeatureService.get_features(tenant_id) - plan = features.billing.subscription.plan - if plan != CloudPlan.SANDBOX: - knowledge_details = [] - # check tenant - tenant = db.session.scalar(select(Tenant).where(Tenant.id == tenant_id)) - if not tenant: - continue - # check current owner - current_owner_join = db.session.scalar( - select(TenantAccountJoin) - .where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.role == "owner") - .limit(1) - ) - if not current_owner_join: - continue - account = db.session.scalar(select(Account).where(Account.id == current_owner_join.account_id)) - if not account: - continue + with session_factory.create_session() as session: + dataset_auto_disable_logs = session.scalars( + select(DatasetAutoDisableLog).where(DatasetAutoDisableLog.notified.is_(False)) + ).all() + # group by tenant_id + dataset_auto_disable_logs_map: dict[str, list[DatasetAutoDisableLog]] = defaultdict(list) + for dataset_auto_disable_log in dataset_auto_disable_logs: + if dataset_auto_disable_log.tenant_id not in dataset_auto_disable_logs_map: + dataset_auto_disable_logs_map[dataset_auto_disable_log.tenant_id] = [] + dataset_auto_disable_logs_map[dataset_auto_disable_log.tenant_id].append(dataset_auto_disable_log) + url = f"{dify_config.CONSOLE_WEB_URL}/datasets" + for tenant_id, tenant_dataset_auto_disable_logs in dataset_auto_disable_logs_map.items(): + features = FeatureService.get_features(tenant_id) + plan = features.billing.subscription.plan + if plan != CloudPlan.SANDBOX: + knowledge_details = [] + # check tenant + tenant = session.scalar(select(Tenant).where(Tenant.id == tenant_id)) + if not tenant: + continue + # check current owner + current_owner_join = session.scalar( + select(TenantAccountJoin) + .where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.role == "owner") + .limit(1) + ) + if not current_owner_join: + continue + account = session.scalar(select(Account).where(Account.id == current_owner_join.account_id)) + if not account: + continue - dataset_auto_dataset_map = {} # type: ignore + dataset_auto_dataset_map = {} # type: ignore + for dataset_auto_disable_log in tenant_dataset_auto_disable_logs: + if dataset_auto_disable_log.dataset_id not in dataset_auto_dataset_map: + dataset_auto_dataset_map[dataset_auto_disable_log.dataset_id] = [] + dataset_auto_dataset_map[dataset_auto_disable_log.dataset_id].append( + dataset_auto_disable_log.document_id + ) + + for dataset_id, document_ids in dataset_auto_dataset_map.items(): + dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id)) + if dataset: + document_count = len(document_ids) + knowledge_details.append(rf"Knowledge base {dataset.name}: {document_count} documents") + if knowledge_details: + email_service = get_email_i18n_service() + email_service.send_email( + email_type=EmailType.DOCUMENT_CLEAN_NOTIFY, + language_code="en-US", + to=account.email, + template_context={ + "userName": account.email, + "knowledge_details": knowledge_details, + "url": url, + }, + ) + + # update notified to True for dataset_auto_disable_log in tenant_dataset_auto_disable_logs: - if dataset_auto_disable_log.dataset_id not in dataset_auto_dataset_map: - dataset_auto_dataset_map[dataset_auto_disable_log.dataset_id] = [] - dataset_auto_dataset_map[dataset_auto_disable_log.dataset_id].append( - dataset_auto_disable_log.document_id - ) - - for dataset_id, document_ids in dataset_auto_dataset_map.items(): - dataset = db.session.scalar(select(Dataset).where(Dataset.id == dataset_id)) - if dataset: - document_count = len(document_ids) - knowledge_details.append(rf"Knowledge base {dataset.name}: {document_count} documents") - if knowledge_details: - email_service = get_email_i18n_service() - email_service.send_email( - email_type=EmailType.DOCUMENT_CLEAN_NOTIFY, - language_code="en-US", - to=account.email, - template_context={ - "userName": account.email, - "knowledge_details": knowledge_details, - "url": url, - }, - ) - - # update notified to True - for dataset_auto_disable_log in tenant_dataset_auto_disable_logs: - dataset_auto_disable_log.notified = True - db.session.commit() + dataset_auto_disable_log.notified = True + session.commit() end_at = time.perf_counter() logger.info(click.style(f"Send document clean notify mail succeeded: latency: {end_at - start_at}", fg="green")) except Exception: diff --git a/api/services/account_service.py b/api/services/account_service.py index ccc4a7c1fa..b6554a3de7 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -112,6 +112,14 @@ REFRESH_TOKEN_EXPIRY = timedelta(days=dify_config.REFRESH_TOKEN_EXPIRE_DAYS) class AccountService: + # Phase-bound token metadata for the change-email flow. Tokens carry the + # current phase so that downstream endpoints can enforce proper progression + CHANGE_EMAIL_TOKEN_PHASE_KEY = "email_change_phase" + CHANGE_EMAIL_PHASE_OLD = "old_email" + CHANGE_EMAIL_PHASE_OLD_VERIFIED = "old_email_verified" + CHANGE_EMAIL_PHASE_NEW = "new_email" + CHANGE_EMAIL_PHASE_NEW_VERIFIED = "new_email_verified" + reset_password_rate_limiter = RateLimiter(prefix="reset_password_rate_limit", max_attempts=1, time_window=60 * 1) email_register_rate_limiter = RateLimiter(prefix="email_register_rate_limit", max_attempts=1, time_window=60 * 1) email_code_login_rate_limiter = RateLimiter( @@ -576,13 +584,20 @@ class AccountService: raise ValueError("Email must be provided.") if not phase: raise ValueError("phase must be provided.") + if phase not in (cls.CHANGE_EMAIL_PHASE_OLD, cls.CHANGE_EMAIL_PHASE_NEW): + raise ValueError("phase must be one of old_email or new_email.") if cls.change_email_rate_limiter.is_rate_limited(account_email): from controllers.console.auth.error import EmailChangeRateLimitExceededError raise EmailChangeRateLimitExceededError(int(cls.change_email_rate_limiter.time_window / 60)) - code, token = cls.generate_change_email_token(account_email, account, old_email=old_email) + code, token = cls.generate_change_email_token( + account_email, + account, + old_email=old_email, + additional_data={cls.CHANGE_EMAIL_TOKEN_PHASE_KEY: phase}, + ) send_change_mail_task.delay( language=language, diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index 78806927bc..97aaea3395 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -17,6 +17,7 @@ from sqlalchemy import select from sqlalchemy.orm import Session from configs import dify_config +from constants.dsl_version import CURRENT_APP_DSL_VERSION from core.helper import ssrf_proxy from core.plugin.entities.plugin import PluginDependency from core.trigger.constants import ( @@ -50,7 +51,7 @@ IMPORT_INFO_REDIS_KEY_PREFIX = "app_import_info:" CHECK_DEPENDENCIES_REDIS_KEY_PREFIX = "app_check_dependencies:" IMPORT_INFO_REDIS_EXPIRY = 10 * 60 # 10 minutes DSL_MAX_SIZE = 10 * 1024 * 1024 # 10MB -CURRENT_DSL_VERSION = "0.6.0" +CURRENT_DSL_VERSION = CURRENT_APP_DSL_VERSION class Import(BaseModel): diff --git a/api/services/app_service.py b/api/services/app_service.py index afd98e2975..a046b909b3 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -16,7 +16,7 @@ from core.tools.utils.configuration import ToolParameterConfigurationManager from events.app_event import app_was_created, app_was_deleted, app_was_updated from extensions.ext_database import db from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel from libs.datetime_utils import naive_utc_now from libs.login import current_user from models import Account @@ -303,17 +303,22 @@ class AppService: return app - def update_app_icon(self, app: App, icon: str, icon_background: str) -> App: + def update_app_icon( + self, app: App, icon: str, icon_background: str, icon_type: IconType | str | None = None + ) -> App: """ Update app icon :param app: App instance :param icon: new icon :param icon_background: new icon_background + :param icon_type: new icon type :return: App instance """ assert current_user is not None app.icon = icon app.icon_background = icon_background + if icon_type is not None: + app.icon_type = icon_type if isinstance(icon_type, IconType) else IconType(icon_type) app.updated_by = current_user.id app.updated_at = naive_utc_now() db.session.commit() diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index e6f5f80a6d..894cb05687 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -30,7 +30,7 @@ from extensions.ext_database import db from extensions.ext_redis import redis_client from graphon.file import helpers as file_helpers from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType -from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel +from graphon.model_runtime.model_providers.base.text_embedding_model import TextEmbeddingModel from libs import helper from libs.datetime_utils import naive_utc_now from libs.login import current_user diff --git a/api/services/feature_service.py b/api/services/feature_service.py index d0d3fbd66b..e18eb096c9 100644 --- a/api/services/feature_service.py +++ b/api/services/feature_service.py @@ -3,6 +3,7 @@ from enum import StrEnum from pydantic import BaseModel, ConfigDict, Field from configs import dify_config +from constants.dsl_version import CURRENT_APP_DSL_VERSION from enums.cloud_plan import CloudPlan from enums.hosted_provider import HostedTrialProvider from services.billing_service import BillingService @@ -157,6 +158,7 @@ class PluginManagerModel(BaseModel): class SystemFeatureModel(BaseModel): + app_dsl_version: str = "" sso_enforced_for_signin: bool = False sso_enforced_for_signin_protocol: str = "" enable_marketplace: bool = False @@ -225,6 +227,7 @@ class FeatureService: @classmethod def get_system_features(cls, is_authenticated: bool = False) -> SystemFeatureModel: system_features = SystemFeatureModel() + system_features.app_dsl_version = CURRENT_APP_DSL_VERSION cls._fulfill_system_params_from_env(system_features) diff --git a/api/services/file_service.py b/api/services/file_service.py index 52da2a7951..f60afe2f19 100644 --- a/api/services/file_service.py +++ b/api/services/file_service.py @@ -2,7 +2,7 @@ import base64 import hashlib import os import uuid -from collections.abc import Iterator, Sequence +from collections.abc import Generator, Sequence # Changed Iterator to Generator from contextlib import contextmanager, suppress from tempfile import NamedTemporaryFile from typing import Literal @@ -324,7 +324,7 @@ class FileService: def build_upload_files_zip_tempfile( *, upload_files: Sequence[UploadFile], - ) -> Iterator[str]: + ) -> Generator[str, None, None]: # Changed from Iterator[str] """ Build a ZIP from `UploadFile`s and yield a tempfile path. diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py index ca84b2a3d8..2e5987dd28 100644 --- a/api/services/hit_testing_service.py +++ b/api/services/hit_testing_service.py @@ -1,10 +1,10 @@ import json import logging import time -from typing import Any, TypedDict +from typing import Any, TypedDict, cast from core.app.app_config.entities import ModelConfig -from core.rag.datasource.retrieval_service import RetrievalService +from core.rag.datasource.retrieval_service import DefaultRetrievalModelDict, RetrievalService from core.rag.index_processor.constant.query_type import QueryType from core.rag.models.document import Document from core.rag.retrieval.dataset_retrieval import DatasetRetrieval @@ -36,6 +36,10 @@ default_retrieval_model = { } +class HitTestingRetrievalModelDict(DefaultRetrievalModelDict, total=False): + metadata_filtering_conditions: dict[str, Any] + + class HitTestingService: @classmethod def retrieve( @@ -51,17 +55,18 @@ class HitTestingService: start = time.perf_counter() # get retrieval model , if the model is not setting , using default - if not retrieval_model: - retrieval_model = dataset.retrieval_model or default_retrieval_model - assert isinstance(retrieval_model, dict) + resolved_retrieval_model = cast( + HitTestingRetrievalModelDict, + retrieval_model or dataset.retrieval_model or default_retrieval_model, + ) document_ids_filter = None - metadata_filtering_conditions = retrieval_model.get("metadata_filtering_conditions", {}) - if metadata_filtering_conditions and query: + metadata_filtering_conditions_raw = resolved_retrieval_model.get("metadata_filtering_conditions", {}) + if metadata_filtering_conditions_raw and query: dataset_retrieval = DatasetRetrieval() from core.rag.entities import MetadataFilteringCondition - metadata_filtering_conditions = MetadataFilteringCondition.model_validate(metadata_filtering_conditions) + metadata_filtering_conditions = MetadataFilteringCondition.model_validate(metadata_filtering_conditions_raw) metadata_filter_document_ids, metadata_condition = dataset_retrieval.get_metadata_filter_condition( dataset_ids=[dataset.id], @@ -78,19 +83,21 @@ class HitTestingService: if metadata_condition and not document_ids_filter: return cls.compact_retrieve_response(query, []) all_documents = RetrievalService.retrieve( - retrieval_method=RetrievalMethod(retrieval_model.get("search_method", RetrievalMethod.SEMANTIC_SEARCH)), + retrieval_method=RetrievalMethod( + resolved_retrieval_model.get("search_method", RetrievalMethod.SEMANTIC_SEARCH) + ), dataset_id=dataset.id, query=query, attachment_ids=attachment_ids, - top_k=retrieval_model.get("top_k", 4), - score_threshold=retrieval_model.get("score_threshold", 0.0) - if retrieval_model["score_threshold_enabled"] + top_k=resolved_retrieval_model.get("top_k", 4), + score_threshold=resolved_retrieval_model.get("score_threshold", 0.0) + if resolved_retrieval_model["score_threshold_enabled"] else 0.0, - reranking_model=retrieval_model.get("reranking_model", None) - if retrieval_model["reranking_enable"] + reranking_model=resolved_retrieval_model.get("reranking_model", None) + if resolved_retrieval_model["reranking_enable"] else None, - reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model", - weights=retrieval_model.get("weights", None), + reranking_mode=resolved_retrieval_model.get("reranking_mode") or "reranking_model", + weights=resolved_retrieval_model.get("weights", None), document_ids_filter=document_ids_filter, ) diff --git a/api/services/human_input_delivery_test_service.py b/api/services/human_input_delivery_test_service.py index 68ef67dec1..8b4983e5f7 100644 --- a/api/services/human_input_delivery_test_service.py +++ b/api/services/human_input_delivery_test_service.py @@ -8,7 +8,7 @@ from sqlalchemy import Engine, select from sqlalchemy.orm import sessionmaker from configs import dify_config -from core.workflow.human_input_compat import ( +from core.workflow.human_input_adapter import ( DeliveryChannelConfig, EmailDeliveryConfig, EmailDeliveryMethod, diff --git a/api/services/rag_pipeline/pipeline_template/built_in/built_in_retrieval.py b/api/services/rag_pipeline/pipeline_template/built_in/built_in_retrieval.py index aa7456dcd3..8c9a81af87 100644 --- a/api/services/rag_pipeline/pipeline_template/built_in/built_in_retrieval.py +++ b/api/services/rag_pipeline/pipeline_template/built_in/built_in_retrieval.py @@ -50,7 +50,7 @@ class BuiltInPipelineTemplateRetrieval(PipelineTemplateRetrievalBase): :param language: language :return: """ - builtin_data: dict[str, dict[str, dict]] = cls._get_builtin_data() + builtin_data: dict[str, dict[str, dict[str, Any]]] = cls._get_builtin_data() return builtin_data.get("pipeline_templates", {}).get(language, {}) @classmethod @@ -60,5 +60,5 @@ class BuiltInPipelineTemplateRetrieval(PipelineTemplateRetrievalBase): :param template_id: Template ID :return: """ - builtin_data: dict[str, dict[str, dict]] = cls._get_builtin_data() + builtin_data: dict[str, dict[str, dict[str, Any]]] = cls._get_builtin_data() return builtin_data.get("pipeline_templates", {}).get(template_id) diff --git a/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py b/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py index 0ffbef8365..9d446f6d4b 100644 --- a/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py +++ b/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, TypedDict import yaml from sqlalchemy import select @@ -10,6 +10,30 @@ from services.rag_pipeline.pipeline_template.pipeline_template_base import Pipel from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType +class CustomizedTemplateItemDict(TypedDict): + id: str + name: str + description: str + icon: dict[str, Any] + position: int + chunk_structure: str + + +class CustomizedTemplatesResultDict(TypedDict): + pipeline_templates: list[CustomizedTemplateItemDict] + + +class CustomizedTemplateDetailDict(TypedDict): + id: str + name: str + icon_info: dict[str, Any] + description: str + chunk_structure: str + export_data: str + graph: dict[str, Any] + created_by: str + + class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase): """ Retrieval recommended app from database @@ -17,12 +41,10 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase): def get_pipeline_templates(self, language: str) -> dict[str, Any]: _, current_tenant_id = current_account_with_tenant() - result = self.fetch_pipeline_templates_from_customized(tenant_id=current_tenant_id, language=language) - return result + return self.fetch_pipeline_templates_from_customized(tenant_id=current_tenant_id, language=language) def get_pipeline_template_detail(self, template_id: str) -> dict[str, Any] | None: - result = self.fetch_pipeline_template_detail_from_db(template_id) - return result + return self.fetch_pipeline_template_detail_from_db(template_id) def get_type(self) -> str: return PipelineTemplateType.CUSTOMIZED @@ -40,9 +62,9 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase): .where(PipelineCustomizedTemplate.tenant_id == tenant_id, PipelineCustomizedTemplate.language == language) .order_by(PipelineCustomizedTemplate.position.asc(), PipelineCustomizedTemplate.created_at.desc()) ).all() - recommended_pipelines_results = [] + recommended_pipelines_results: list[CustomizedTemplateItemDict] = [] for pipeline_customized_template in pipeline_customized_templates: - recommended_pipeline_result = { + recommended_pipeline_result: CustomizedTemplateItemDict = { "id": pipeline_customized_template.id, "name": pipeline_customized_template.name, "description": pipeline_customized_template.description, diff --git a/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py b/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py index 073eed221c..2964537c35 100644 --- a/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py +++ b/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, TypedDict import yaml from sqlalchemy import select @@ -9,18 +9,41 @@ from services.rag_pipeline.pipeline_template.pipeline_template_base import Pipel from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType +class PipelineTemplateItemDict(TypedDict): + id: str + name: str + description: str + icon: dict[str, Any] + copyright: str + privacy_policy: str + position: int + chunk_structure: str + + +class PipelineTemplatesResultDict(TypedDict): + pipeline_templates: list[PipelineTemplateItemDict] + + +class PipelineTemplateDetailDict(TypedDict): + id: str + name: str + icon_info: dict[str, Any] + description: str + chunk_structure: str + export_data: str + graph: dict[str, Any] + + class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): """ Retrieval pipeline template from database """ def get_pipeline_templates(self, language: str) -> dict[str, Any]: - result = self.fetch_pipeline_templates_from_db(language) - return result + return self.fetch_pipeline_templates_from_db(language) def get_pipeline_template_detail(self, template_id: str) -> dict[str, Any] | None: - result = self.fetch_pipeline_template_detail_from_db(template_id) - return result + return self.fetch_pipeline_template_detail_from_db(template_id) def get_type(self) -> str: return PipelineTemplateType.DATABASE @@ -39,9 +62,9 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): ).all() ) - recommended_pipelines_results = [] + recommended_pipelines_results: list[PipelineTemplateItemDict] = [] for pipeline_built_in_template in pipeline_built_in_templates: - recommended_pipeline_result = { + recommended_pipeline_result: PipelineTemplateItemDict = { "id": pipeline_built_in_template.id, "name": pipeline_built_in_template.name, "description": pipeline_built_in_template.description, diff --git a/api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py b/api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py index d5ef745bec..9565ac46cc 100644 --- a/api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py +++ b/api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py @@ -17,21 +17,18 @@ class RemotePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): """ def get_pipeline_template_detail(self, template_id: str) -> dict[str, Any] | None: - result: dict[str, Any] | None try: - result = self.fetch_pipeline_template_detail_from_dify_official(template_id) + return self.fetch_pipeline_template_detail_from_dify_official(template_id) except Exception as e: logger.warning("fetch recommended app detail from dify official failed: %r, switch to database.", e) - result = DatabasePipelineTemplateRetrieval.fetch_pipeline_template_detail_from_db(template_id) - return result + return DatabasePipelineTemplateRetrieval.fetch_pipeline_template_detail_from_db(template_id) def get_pipeline_templates(self, language: str) -> dict[str, Any]: try: - result = self.fetch_pipeline_templates_from_dify_official(language) + return self.fetch_pipeline_templates_from_dify_official(language) except Exception as e: logger.warning("fetch pipeline templates from dify official failed: %r, switch to database.", e) - result = DatabasePipelineTemplateRetrieval.fetch_pipeline_templates_from_db(language) - return result + return DatabasePipelineTemplateRetrieval.fetch_pipeline_templates_from_db(language) def get_type(self) -> str: return PipelineTemplateType.REMOTE diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index 968600d1bc..9db6682e10 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -476,7 +476,7 @@ class RagPipelineService: :param filters: filter by node config parameters. :return: """ - node_type_enum = NodeType(node_type) + node_type_enum: NodeType = node_type node_mapping = get_node_type_classes_mapping() # return default block config diff --git a/api/services/summary_index_service.py b/api/services/summary_index_service.py index a91f49e9e6..cf39469be8 100644 --- a/api/services/summary_index_service.py +++ b/api/services/summary_index_service.py @@ -349,7 +349,6 @@ class SummaryIndexService: summary_record_id, ) summary_record_in_session = DocumentSegmentSummary( - id=summary_record_id, # Use the same ID if available dataset_id=dataset.id, document_id=segment.document_id, chunk_id=segment.id, @@ -360,6 +359,9 @@ class SummaryIndexService: status=SummaryStatus.COMPLETED, enabled=True, ) + if summary_record_in_session is None: + raise RuntimeError("summary_record_in_session should not be None at this point") + summary_record_in_session.id = summary_record_id session.add(summary_record_in_session) logger.info( "Created new summary record (id=%s) for segment %s after vectorization", diff --git a/api/services/tools/api_tools_manage_service.py b/api/services/tools/api_tools_manage_service.py index 8ad010f62d..5ff2c21749 100644 --- a/api/services/tools/api_tools_manage_service.py +++ b/api/services/tools/api_tools_manage_service.py @@ -4,6 +4,7 @@ from typing import Any, TypedDict, cast from httpx import get from sqlalchemy import select +from sqlalchemy.orm import sessionmaker from core.entities.provider_entities import ProviderConfig from core.tools.__base.tool_runtime import ToolRuntime @@ -15,6 +16,7 @@ from core.tools.entities.tool_entities import ( ApiProviderAuthType, ApiProviderSchemaType, ) +from core.tools.errors import ApiToolProviderNotFoundError from core.tools.tool_label_manager import ToolLabelManager from core.tools.tool_manager import ToolManager from core.tools.utils.encryption import create_tool_provider_encrypter @@ -116,71 +118,85 @@ class ApiToolManageService: privacy_policy: str, custom_disclaimer: str, labels: list[str], - ): + ) -> dict[str, Any]: """ - create api tool provider + Create a new API tool provider. + + :param user_id: The ID of the user creating the provider. + :param tenant_id: The ID of the workspace/tenant. + :param provider_name: The name of the API tool provider. + :param icon: The icon configuration for the provider. + :param credentials: The credentials for the provider. + :param schema_type: The type of schema (e.g., OpenAPI). + :param schema: The raw schema string. + :param privacy_policy: The privacy policy URL or text. + :param custom_disclaimer: Custom disclaimer text. + :param labels: A list of labels for the provider. + :return: A dictionary indicating the result status. """ + provider_name = provider_name.strip() # check if the provider exists - provider = db.session.scalar( - select(ApiToolProvider) - .where( - ApiToolProvider.tenant_id == tenant_id, - ApiToolProvider.name == provider_name, + # Create new session with automatic transaction management + with sessionmaker(db.engine, expire_on_commit=False).begin() as _session: + provider: ApiToolProvider | None = _session.scalar( + select(ApiToolProvider) + .where( + ApiToolProvider.tenant_id == tenant_id, + ApiToolProvider.name == provider_name, + ) + .limit(1) ) - .limit(1) - ) - if provider is not None: - raise ValueError(f"provider {provider_name} already exists") + if provider is not None: + raise ValueError(f"provider {provider_name} already exists") - # parse openapi to tool bundle - extra_info: dict[str, str] = {} - # extra info like description will be set here - tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info) + # parse openapi to tool bundle + extra_info: dict[str, str] = {} + # extra info like description will be set here + tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info) - if len(tool_bundles) > 100: - raise ValueError("the number of apis should be less than 100") + if len(tool_bundles) > 100: + raise ValueError("the number of apis should be less than 100") - # create db provider - db_provider = ApiToolProvider( - tenant_id=tenant_id, - user_id=user_id, - name=provider_name, - icon=json.dumps(icon), - schema=schema, - description=extra_info.get("description", ""), - schema_type_str=schema_type, - tools_str=json.dumps(jsonable_encoder(tool_bundles)), - credentials_str="{}", - privacy_policy=privacy_policy, - custom_disclaimer=custom_disclaimer, - ) + # create API tool provider + api_tool_provider = ApiToolProvider( + tenant_id=tenant_id, + user_id=user_id, + name=provider_name, + icon=json.dumps(icon), + schema=schema, + description=extra_info.get("description", ""), + schema_type_str=schema_type, + tools_str=json.dumps(jsonable_encoder(tool_bundles)), + credentials_str="{}", + privacy_policy=privacy_policy, + custom_disclaimer=custom_disclaimer, + ) - if "auth_type" not in credentials: - raise ValueError("auth_type is required") + if "auth_type" not in credentials: + raise ValueError("auth_type is required") - # get auth type, none or api key - auth_type = ApiProviderAuthType.value_of(credentials["auth_type"]) + # get auth type, none or api key + auth_type = ApiProviderAuthType.value_of(credentials["auth_type"]) - # create provider entity - provider_controller = ApiToolProviderController.from_db(db_provider, auth_type) - # load tools into provider entity - provider_controller.load_bundled_tools(tool_bundles) + # create provider entity + provider_controller = ApiToolProviderController.from_db(api_tool_provider, auth_type) + # load tools into provider entity + provider_controller.load_bundled_tools(tool_bundles) - # encrypt credentials - encrypter, _ = create_tool_provider_encrypter( - tenant_id=tenant_id, - controller=provider_controller, - ) - db_provider.credentials_str = json.dumps(encrypter.encrypt(credentials)) + # encrypt credentials + encrypter, _ = create_tool_provider_encrypter( + tenant_id=tenant_id, + controller=provider_controller, + ) + api_tool_provider.credentials_str = json.dumps(encrypter.encrypt(credentials)) - db.session.add(db_provider) - db.session.commit() + _session.add(api_tool_provider) - # update labels - ToolLabelManager.update_tool_labels(provider_controller, labels) + # update labels + ToolLabelManager.update_tool_labels(provider_controller, labels, _session) return {"result": "success"} @@ -212,16 +228,25 @@ class ApiToolManageService: @staticmethod def list_api_tool_provider_tools(user_id: str, tenant_id: str, provider_name: str) -> list[ToolApiEntity]: """ - list api tool provider tools + List tools provided by a specific API tool provider. + + :param user_id: The ID of the user requesting the list. + :param tenant_id: The ID of the workspace/tenant. + :param provider_name: The name of the API tool provider. + :return: A list of ToolApiEntity objects. """ - provider: ApiToolProvider | None = db.session.scalar( - select(ApiToolProvider) - .where( - ApiToolProvider.tenant_id == tenant_id, - ApiToolProvider.name == provider_name, + + # create new session with automatic transaction management + provider: ApiToolProvider | None = None + with sessionmaker(db.engine, expire_on_commit=False).begin() as _session: + provider = _session.scalar( + select(ApiToolProvider) + .where( + ApiToolProvider.tenant_id == tenant_id, + ApiToolProvider.name == provider_name, + ) + .limit(1) ) - .limit(1) - ) if provider is None: raise ValueError(f"you have not added provider {provider_name}") @@ -251,103 +276,133 @@ class ApiToolManageService: privacy_policy: str | None, custom_disclaimer: str, labels: list[str], - ): + ) -> dict[str, Any]: """ - update api tool provider + Update an existing API tool provider. + + :param user_id: The ID of the user updating the provider. + :param tenant_id: The ID of the workspace/tenant. + :param provider_name: The new name of the API tool provider. + :param original_provider: The original name of the API tool provider. + :param icon: The icon configuration for the provider. + :param credentials: The credentials for the provider. + :param _schema_type: The type of schema (e.g., OpenAPI). + :param schema: The raw schema string. + :param privacy_policy: The privacy policy URL or text. + :param custom_disclaimer: Custom disclaimer text. + :param labels: A list of labels for the provider. + :return: A dictionary indicating the result status. """ + provider_name = provider_name.strip() # check if the provider exists - provider = db.session.scalar( - select(ApiToolProvider) - .where( - ApiToolProvider.tenant_id == tenant_id, - ApiToolProvider.name == original_provider, + # create new session with automatic transaction management + with sessionmaker(db.engine, expire_on_commit=False).begin() as _session: + provider: ApiToolProvider | None = _session.scalar( + select(ApiToolProvider) + .where( + ApiToolProvider.tenant_id == tenant_id, + ApiToolProvider.name == original_provider, + ) + .limit(1) ) - .limit(1) - ) - if provider is None: - raise ValueError(f"api provider {provider_name} does not exists") - # parse openapi to tool bundle - extra_info: dict[str, str] = {} - # extra info like description will be set here - tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info) + if provider is None: + raise ApiToolProviderNotFoundError(provider_name=original_provider, tenant_id=tenant_id) - # update db provider - provider.name = provider_name - provider.icon = json.dumps(icon) - provider.schema = schema - provider.description = extra_info.get("description", "") - provider.schema_type_str = schema_type - provider.tools_str = json.dumps(jsonable_encoder(tool_bundles)) - provider.privacy_policy = privacy_policy - provider.custom_disclaimer = custom_disclaimer + # parse openapi to tool bundle + extra_info: dict[str, str] = {} + # extra info like description will be set here + tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info) - if "auth_type" not in credentials: - raise ValueError("auth_type is required") + # update db provider + provider.name = provider_name + provider.icon = json.dumps(icon) + provider.schema = schema + provider.description = extra_info.get("description", "") + provider.schema_type_str = schema_type + provider.tools_str = json.dumps(jsonable_encoder(tool_bundles)) + provider.privacy_policy = privacy_policy + provider.custom_disclaimer = custom_disclaimer - # get auth type, none or api key - auth_type = ApiProviderAuthType.value_of(credentials["auth_type"]) + if "auth_type" not in credentials: + raise ValueError("auth_type is required") - # create provider entity - provider_controller = ApiToolProviderController.from_db(provider, auth_type) - # load tools into provider entity - provider_controller.load_bundled_tools(tool_bundles) + # get auth type, none or api key + auth_type = ApiProviderAuthType.value_of(credentials["auth_type"]) - # get original credentials if exists - encrypter, cache = create_tool_provider_encrypter( - tenant_id=tenant_id, - controller=provider_controller, - ) + # create provider entity + provider_controller = ApiToolProviderController.from_db(provider, auth_type) + # load tools into provider entity + provider_controller.load_bundled_tools(tool_bundles) - original_credentials = encrypter.decrypt(provider.credentials) - masked_credentials = encrypter.mask_plugin_credentials(original_credentials) - # check if the credential has changed, save the original credential - for name, value in credentials.items(): - if name in masked_credentials and value == masked_credentials[name]: - credentials[name] = original_credentials[name] + # get original credentials if exists + encrypter, cache = create_tool_provider_encrypter( + tenant_id=tenant_id, + controller=provider_controller, + ) - credentials = dict(encrypter.encrypt(credentials)) - provider.credentials_str = json.dumps(credentials) + original_credentials = encrypter.decrypt(provider.credentials) + masked_credentials = encrypter.mask_plugin_credentials(original_credentials) - db.session.add(provider) - db.session.commit() + # check if the credential has changed, save the original credential + for name, value in credentials.items(): + if name in masked_credentials and value == masked_credentials[name]: + credentials[name] = original_credentials[name] + + credentials = dict(encrypter.encrypt(credentials)) + provider.credentials_str = json.dumps(credentials) + + _session.add(provider) + + # update labels + ToolLabelManager.update_tool_labels(provider_controller, labels, _session) # delete cache cache.delete() - # update labels - ToolLabelManager.update_tool_labels(provider_controller, labels) - return {"result": "success"} @staticmethod def delete_api_tool_provider(user_id: str, tenant_id: str, provider_name: str): """ - delete tool provider + Delete an API tool provider. + + :param user_id: The ID of the user performing the deletion operation. + :param tenant_id: The ID of the workspace/tenant where the provider belongs. + :param provider_name: The unique name of the API tool provider to be deleted. + :raises ValueError: If the specified provider does not exist in the tenant. + :return: A dictionary indicating the result status. """ - provider = db.session.scalar( - select(ApiToolProvider) - .where( - ApiToolProvider.tenant_id == tenant_id, - ApiToolProvider.name == provider_name, + + # create new session with automatic transaction management + with sessionmaker(db.engine, expire_on_commit=False).begin() as _session: + provider: ApiToolProvider | None = _session.scalar( + select(ApiToolProvider) + .where( + ApiToolProvider.tenant_id == tenant_id, + ApiToolProvider.name == provider_name, + ) + .limit(1) ) - .limit(1) - ) - if provider is None: - raise ValueError(f"you have not added provider {provider_name}") + if provider is None: + raise ValueError(f"you have not added provider {provider_name}") - db.session.delete(provider) - db.session.commit() + _session.delete(provider) return {"result": "success"} @staticmethod - def get_api_tool_provider(user_id: str, tenant_id: str, provider: str): + def get_api_tool_provider(user_id: str, tenant_id: str, provider: str) -> dict[str, Any]: """ - get api tool provider + Get API tool provider details. + + :param user_id: The ID of the user requesting the provider. + :param tenant_id: The ID of the workspace/tenant. + :param provider: The name of the API tool provider. + :return: A dictionary containing the provider details. """ return ToolManager.user_get_api_provider(provider=provider, tenant_id=tenant_id) @@ -360,10 +415,20 @@ class ApiToolManageService: parameters: dict[str, Any], schema_type: ApiProviderSchemaType, schema: str, - ): + ) -> dict[str, Any]: """ - test api tool before adding api tool provider + Test an API tool before adding the API tool provider. + + :param tenant_id: The ID of the workspace/tenant. + :param provider_name: The name of the API tool provider. + :param tool_name: The name of the specific tool to test. + :param credentials: The credentials for the provider. + :param parameters: The parameters to pass to the tool. + :param schema_type: The type of schema (e.g., OpenAPI). + :param schema: The raw schema string. + :return: A dictionary containing the result or error message. """ + if schema_type not in [member.value for member in ApiProviderSchemaType]: raise ValueError(f"invalid schema type {schema_type}") @@ -377,18 +442,21 @@ class ApiToolManageService: if tool_bundle is None: raise ValueError(f"invalid tool name {tool_name}") - db_provider = db.session.scalar( - select(ApiToolProvider) - .where( - ApiToolProvider.tenant_id == tenant_id, - ApiToolProvider.name == provider_name, + # create new session with automatic transaction management to get the provider + provider: ApiToolProvider | None = None + with sessionmaker(db.engine, expire_on_commit=False).begin() as _session: + provider = _session.scalar( + select(ApiToolProvider) + .where( + ApiToolProvider.tenant_id == tenant_id, + ApiToolProvider.name == provider_name, + ) + .limit(1) ) - .limit(1) - ) - if not db_provider: + if provider is None: # create a fake db provider - db_provider = ApiToolProvider( + provider = ApiToolProvider( tenant_id="", user_id="", name="", @@ -407,12 +475,12 @@ class ApiToolManageService: auth_type = ApiProviderAuthType.value_of(credentials["auth_type"]) # create provider entity - provider_controller = ApiToolProviderController.from_db(db_provider, auth_type) + provider_controller = ApiToolProviderController.from_db(provider, auth_type) # load tools into provider entity provider_controller.load_bundled_tools(tool_bundles) # decrypt credentials - if db_provider.id: + if provider.id: encrypter, _ = create_tool_provider_encrypter( tenant_id=tenant_id, controller=provider_controller, @@ -443,14 +511,21 @@ class ApiToolManageService: @staticmethod def list_api_tools(tenant_id: str) -> list[ToolProviderApiEntity]: """ - list api tools + List all API tools for a specific tenant. + + :param tenant_id: The ID of the workspace/tenant. + :return: A list of ToolProviderApiEntity objects. """ # get all api providers - db_providers = db.session.scalars(select(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id)).all() + # create new session with automatic transaction management + providers: list[ApiToolProvider] = [] + with sessionmaker(db.engine, expire_on_commit=False).begin() as _session: + providers = list( + _session.scalars(select(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id)).all() + ) result: list[ToolProviderApiEntity] = [] - - for provider in db_providers: + for provider in providers: # convert provider controller to user provider provider_controller = ToolTransformService.api_provider_to_controller(db_provider=provider) labels = ToolLabelManager.get_tool_labels(provider_controller) diff --git a/api/services/variable_truncator.py b/api/services/variable_truncator.py index c96050ce13..1529c2b98f 100644 --- a/api/services/variable_truncator.py +++ b/api/services/variable_truncator.py @@ -169,7 +169,7 @@ class VariableTruncator(BaseTruncator): return TruncationResult(StringSegment(value=fallback_result.value), True) # Apply final fallback - convert to JSON string and truncate - json_str = dumps_with_segments(result.value, ensure_ascii=False) + json_str = dumps_with_segments(result.value) if len(json_str) > self._max_size_bytes: json_str = json_str[: self._max_size_bytes] + "..." return TruncationResult(result=StringSegment(value=json_str), truncated=True) diff --git a/api/services/workflow_draft_variable_service.py b/api/services/workflow_draft_variable_service.py index 5ec00ee336..96f936ff9b 100644 --- a/api/services/workflow_draft_variable_service.py +++ b/api/services/workflow_draft_variable_service.py @@ -146,7 +146,7 @@ class DraftVarLoader(VariableLoader): variable = segment_to_variable( segment=segment, selector=draft_var.get_selector(), - id=draft_var.id, + variable_id=draft_var.id, name=draft_var.name, description=draft_var.description, ) @@ -180,7 +180,7 @@ class DraftVarLoader(VariableLoader): variable = segment_to_variable( segment=segment, selector=draft_var.get_selector(), - id=draft_var.id, + variable_id=draft_var.id, name=draft_var.name, description=draft_var.description, ) @@ -191,7 +191,7 @@ class DraftVarLoader(VariableLoader): variable = segment_to_variable( segment=segment, selector=draft_var.get_selector(), - id=draft_var.id, + variable_id=draft_var.id, name=draft_var.name, description=draft_var.description, ) @@ -1067,7 +1067,7 @@ class DraftVariableSaver: filename = f"{self._generate_filename(name)}.txt" else: # For other types, store as JSON - original_content_serialized = dumps_with_segments(value_seg.value, ensure_ascii=False) + original_content_serialized = dumps_with_segments(value_seg.value) content_type = "application/json" filename = f"{self._generate_filename(name)}.json" diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index d71223314e..d4b9095ce5 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -18,9 +18,9 @@ from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly, from core.repositories import DifyCoreRepositoryFactory from core.repositories.human_input_repository import FormCreateParams, HumanInputFormRepositoryImpl from core.trigger.constants import is_trigger_node_type -from core.workflow.human_input_compat import ( +from core.workflow.human_input_adapter import ( DeliveryChannelConfig, - normalize_human_input_node_data_for_graph, + adapt_human_input_node_data_for_graph, parse_human_input_delivery_methods, ) from core.workflow.node_factory import ( @@ -791,7 +791,7 @@ class WorkflowService: :param filters: filter by node config parameters. :return: """ - node_type_enum = NodeType(node_type) + node_type_enum: NodeType = node_type node_mapping = get_node_type_classes_mapping() # return default block config @@ -1096,7 +1096,7 @@ class WorkflowService: raise ValueError("Node type must be human-input.") node_data = HumanInputNodeData.model_validate( - normalize_human_input_node_data_for_graph(node_config["data"]), + adapt_human_input_node_data_for_graph(node_config["data"]), from_attributes=True, ) delivery_method = self._resolve_human_input_delivery_method( @@ -1237,9 +1237,10 @@ class WorkflowService: variable_pool=variable_pool, start_at=time.perf_counter(), ) + node_data = HumanInputNode.validate_node_data(adapt_human_input_node_data_for_graph(node_config["data"])) node = HumanInputNode( - id=node_config["id"], - config=node_config, + node_id=node_config["id"], + config=node_data, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, runtime=DifyHumanInputNodeRuntime(run_context), @@ -1529,7 +1530,7 @@ class WorkflowService: from graphon.nodes.human_input.entities import HumanInputNodeData try: - HumanInputNodeData.model_validate(normalize_human_input_node_data_for_graph(node_data)) + HumanInputNodeData.model_validate(adapt_human_input_node_data_for_graph(node_data)) except Exception as e: raise ValueError(f"Invalid HumanInput node data: {str(e)}") diff --git a/api/tasks/mail_human_input_delivery_task.py b/api/tasks/mail_human_input_delivery_task.py index f8ae3f4b6e..2a60be7762 100644 --- a/api/tasks/mail_human_input_delivery_task.py +++ b/api/tasks/mail_human_input_delivery_task.py @@ -11,7 +11,7 @@ from sqlalchemy.orm import Session, sessionmaker from configs import dify_config from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext -from core.workflow.human_input_compat import EmailDeliveryConfig, EmailDeliveryMethod +from core.workflow.human_input_adapter import EmailDeliveryConfig, EmailDeliveryMethod from extensions.ext_database import db from extensions.ext_mail import mail from graphon.runtime import GraphRuntimeState, VariablePool diff --git a/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py b/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py index b5318aaa2b..2392084c36 100644 --- a/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py +++ b/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py @@ -1,5 +1,6 @@ from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY from core.workflow.nodes.datasource.datasource_node import DatasourceNode +from core.workflow.nodes.datasource.entities import DatasourceNodeData from graphon.enums import WorkflowNodeExecutionStatus from graphon.node_events import NodeRunResult, StreamCompletedEvent @@ -69,19 +70,16 @@ def test_node_integration_minimal_stream(mocker): mocker.patch("core.workflow.nodes.datasource.datasource_node.DatasourceManager", new=_Mgr) node = DatasourceNode( - id="n", - config={ - "id": "n", - "data": { - "type": "datasource", - "version": "1", - "title": "Datasource", - "provider_type": "plugin", - "provider_name": "p", - "plugin_id": "plug", - "datasource_name": "ds", - }, - }, + node_id="n", + config=DatasourceNodeData( + type="datasource", + version="1", + title="Datasource", + provider_type="plugin", + provider_name="p", + plugin_id="plug", + datasource_name="ds", + ), graph_init_params=_GP(), graph_runtime_state=_GS(vp), ) diff --git a/api/tests/integration_tests/workflow/nodes/test_code.py b/api/tests/integration_tests/workflow/nodes/test_code.py index e3476c292b..aaa6092993 100644 --- a/api/tests/integration_tests/workflow/nodes/test_code.py +++ b/api/tests/integration_tests/workflow/nodes/test_code.py @@ -11,6 +11,7 @@ from graphon.enums import WorkflowNodeExecutionStatus from graphon.graph import Graph from graphon.node_events import NodeRunResult from graphon.nodes.code.code_node import CodeNode +from graphon.nodes.code.entities import CodeNodeData from graphon.nodes.code.limits import CodeNodeLimits from graphon.runtime import GraphRuntimeState, VariablePool from tests.workflow_test_utils import build_test_graph_init_params @@ -64,8 +65,8 @@ def init_code_node(code_config: dict): graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") node = CodeNode( - id=str(uuid.uuid4()), - config=code_config, + node_id=str(uuid.uuid4()), + config=CodeNodeData.model_validate(code_config["data"]), graph_init_params=init_params, graph_runtime_state=graph_runtime_state, code_executor=node_factory._code_executor, diff --git a/api/tests/integration_tests/workflow/nodes/test_http.py b/api/tests/integration_tests/workflow/nodes/test_http.py index aa6cf1e021..b9f7b9575b 100644 --- a/api/tests/integration_tests/workflow/nodes/test_http.py +++ b/api/tests/integration_tests/workflow/nodes/test_http.py @@ -14,7 +14,7 @@ from core.workflow.system_variables import build_system_variables from graphon.enums import WorkflowNodeExecutionStatus from graphon.file.file_manager import file_manager from graphon.graph import Graph -from graphon.nodes.http_request import HttpRequestNode, HttpRequestNodeConfig +from graphon.nodes.http_request import HttpRequestNode, HttpRequestNodeConfig, HttpRequestNodeData from graphon.runtime import GraphRuntimeState, VariablePool from tests.workflow_test_utils import build_test_graph_init_params @@ -75,8 +75,8 @@ def init_http_node(config: dict): graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") node = HttpRequestNode( - id=str(uuid.uuid4()), - config=config, + node_id=str(uuid.uuid4()), + config=HttpRequestNodeData.model_validate(config["data"]), graph_init_params=init_params, graph_runtime_state=graph_runtime_state, http_request_config=HTTP_REQUEST_CONFIG, @@ -723,8 +723,8 @@ def test_nested_object_variable_selector(setup_http_mock): graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") node = HttpRequestNode( - id=str(uuid.uuid4()), - config=graph_config["nodes"][1], + node_id=str(uuid.uuid4()), + config=HttpRequestNodeData.model_validate(graph_config["nodes"][1]["data"]), graph_init_params=init_params, graph_runtime_state=graph_runtime_state, http_request_config=HTTP_REQUEST_CONFIG, diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py index fa5d63cfbf..3eead70163 100644 --- a/api/tests/integration_tests/workflow/nodes/test_llm.py +++ b/api/tests/integration_tests/workflow/nodes/test_llm.py @@ -11,6 +11,7 @@ from core.workflow.system_variables import build_system_variables from extensions.ext_database import db from graphon.enums import WorkflowNodeExecutionStatus from graphon.node_events import StreamCompletedEvent +from graphon.nodes.llm.entities import LLMNodeData from graphon.nodes.llm.file_saver import LLMFileSaver from graphon.nodes.llm.node import LLMNode from graphon.nodes.llm.protocols import CredentialsProvider, ModelFactory @@ -75,8 +76,8 @@ def init_llm_node(config: dict) -> LLMNode: llm_file_saver = MagicMock(spec=LLMFileSaver) node = LLMNode( - id=str(uuid.uuid4()), - config=config, + node_id=str(uuid.uuid4()), + config=LLMNodeData.model_validate(config["data"]), graph_init_params=init_params, graph_runtime_state=graph_runtime_state, credentials_provider=MagicMock(spec=CredentialsProvider), diff --git a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py index 52886855b8..f2eabb86c3 100644 --- a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py +++ b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py @@ -11,6 +11,7 @@ from extensions.ext_database import db from graphon.enums import WorkflowNodeExecutionStatus from graphon.model_runtime.entities import AssistantPromptMessage, UserPromptMessage from graphon.nodes.llm.protocols import CredentialsProvider, ModelFactory +from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeData from graphon.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode from graphon.runtime import GraphRuntimeState, VariablePool from tests.integration_tests.workflow.nodes.__mock.model import get_mocked_fetch_model_instance @@ -69,8 +70,8 @@ def init_parameter_extractor_node(config: dict, memory=None): graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) node = ParameterExtractorNode( - id=str(uuid.uuid4()), - config=config, + node_id=str(uuid.uuid4()), + config=ParameterExtractorNodeData.model_validate(config["data"]), graph_init_params=init_params, graph_runtime_state=graph_runtime_state, credentials_provider=MagicMock(spec=CredentialsProvider), diff --git a/api/tests/integration_tests/workflow/nodes/test_template_transform.py b/api/tests/integration_tests/workflow/nodes/test_template_transform.py index 9e3e1a47e3..e2e0723fb8 100644 --- a/api/tests/integration_tests/workflow/nodes/test_template_transform.py +++ b/api/tests/integration_tests/workflow/nodes/test_template_transform.py @@ -6,6 +6,7 @@ from core.workflow.node_factory import DifyNodeFactory from core.workflow.system_variables import build_system_variables from graphon.enums import WorkflowNodeExecutionStatus from graphon.graph import Graph +from graphon.nodes.template_transform.entities import TemplateTransformNodeData from graphon.nodes.template_transform.template_transform_node import TemplateTransformNode from graphon.runtime import GraphRuntimeState, VariablePool from graphon.template_rendering import TemplateRenderError @@ -86,8 +87,8 @@ def test_execute_template_transform(): assert graph is not None node = TemplateTransformNode( - id=str(uuid.uuid4()), - config=config, + node_id=str(uuid.uuid4()), + config=TemplateTransformNodeData.model_validate(config["data"]), graph_init_params=init_params, graph_runtime_state=graph_runtime_state, jinja2_template_renderer=_SimpleJinja2Renderer(), diff --git a/api/tests/integration_tests/workflow/nodes/test_tool.py b/api/tests/integration_tests/workflow/nodes/test_tool.py index f9ec51ee10..a8e9422c1e 100644 --- a/api/tests/integration_tests/workflow/nodes/test_tool.py +++ b/api/tests/integration_tests/workflow/nodes/test_tool.py @@ -11,6 +11,7 @@ from graphon.enums import WorkflowNodeExecutionStatus from graphon.graph import Graph from graphon.node_events import StreamCompletedEvent from graphon.nodes.protocols import ToolFileManagerProtocol +from graphon.nodes.tool.entities import ToolNodeData from graphon.nodes.tool.tool_node import ToolNode from graphon.runtime import GraphRuntimeState, VariablePool from tests.workflow_test_utils import build_test_graph_init_params @@ -60,8 +61,8 @@ def init_tool_node(config: dict): tool_file_manager_factory = MagicMock(spec=ToolFileManagerProtocol) node = ToolNode( - id=str(uuid.uuid4()), - config=config, + node_id=str(uuid.uuid4()), + config=ToolNodeData.model_validate(config["data"]), graph_init_params=init_params, graph_runtime_state=graph_runtime_state, tool_file_manager_factory=tool_file_manager_factory, diff --git a/api/tests/test_containers_integration_tests/controllers/console/app/test_app_apis.py b/api/tests/test_containers_integration_tests/controllers/console/app/test_app_apis.py index 15dec06311..18755ef012 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/app/test_app_apis.py +++ b/api/tests/test_containers_integration_tests/controllers/console/app/test_app_apis.py @@ -234,6 +234,35 @@ class TestAppEndpoints: } ) + def test_app_icon_post_should_forward_icon_type(self, app, monkeypatch): + api = app_module.AppIconApi() + method = _unwrap(api.post) + payload = { + "icon": "https://example.com/icon.png", + "icon_type": "image", + "icon_background": "#FFFFFF", + } + app_service = MagicMock() + app_service.update_app_icon.return_value = SimpleNamespace() + response_model = MagicMock() + response_model.model_dump.return_value = {"id": "app-1"} + + monkeypatch.setattr(app_module, "AppService", lambda: app_service) + monkeypatch.setattr(app_module.AppDetail, "model_validate", MagicMock(return_value=response_model)) + + with ( + app.test_request_context("/console/api/apps/app-1/icon", method="POST", json=payload), + patch.object(type(console_ns), "payload", payload), + ): + response = method(app_model=SimpleNamespace()) + + assert response == {"id": "app-1"} + assert app_service.update_app_icon.call_args.args[1:] == ( + payload["icon"], + payload["icon_background"], + app_module.IconType.IMAGE, + ) + class TestOpsTraceEndpoints: @pytest.fixture diff --git a/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py b/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py index 14d5740072..6524d6ce61 100644 --- a/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py +++ b/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py @@ -8,7 +8,7 @@ from sqlalchemy import Engine, select from sqlalchemy.orm import Session from core.repositories.human_input_repository import FormCreateParams, HumanInputFormRepositoryImpl -from core.workflow.human_input_compat import ( +from core.workflow.human_input_adapter import ( DeliveryChannelConfig, EmailDeliveryConfig, EmailDeliveryMethod, diff --git a/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py b/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py index da4f8847d6..5aed230cd4 100644 --- a/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py +++ b/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py @@ -101,8 +101,8 @@ def _build_graph( start_data = StartNodeData(title="start", variables=[]) start_node = StartNode( - id="start", - config={"id": "start", "data": start_data.model_dump()}, + node_id="start", + config=start_data, graph_init_params=params, graph_runtime_state=runtime_state, ) @@ -116,8 +116,8 @@ def _build_graph( ], ) human_node = HumanInputNode( - id="human", - config={"id": "human", "data": human_data.model_dump()}, + node_id="human", + config=human_data, graph_init_params=params, graph_runtime_state=runtime_state, form_repository=form_repository, @@ -130,8 +130,8 @@ def _build_graph( desc=None, ) end_node = EndNode( - id="end", - config={"id": "end", "data": end_data.model_dump()}, + node_id="end", + config=end_data, graph_init_params=params, graph_runtime_state=runtime_state, ) diff --git a/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py b/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py index 2e207ddc67..35e41035df 100644 --- a/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py +++ b/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py @@ -123,9 +123,9 @@ class TestStorageKeyLoader(unittest.TestCase): file_related_id = related_id return File( - id=str(uuid4()), # Generate new UUID for File.id + file_id=str(uuid4()), # Generate new UUID for File.id tenant_id=tenant_id, - type=FileType.DOCUMENT, + file_type=FileType.DOCUMENT, transfer_method=transfer_method, related_id=file_related_id, remote_url=remote_url, diff --git a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py index aaf9a85d60..54b7afc018 100644 --- a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py +++ b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py @@ -271,7 +271,7 @@ def _create_recipient( def _create_delivery(session: Session, *, form_id: str) -> HumanInputDelivery: - from core.workflow.human_input_compat import DeliveryMethodType + from core.workflow.human_input_adapter import DeliveryMethodType from models.human_input import ConsoleDeliveryPayload delivery = HumanInputDelivery( diff --git a/api/tests/test_containers_integration_tests/services/test_app_service.py b/api/tests/test_containers_integration_tests/services/test_app_service.py index fa57dd4a6f..b695ae9fd9 100644 --- a/api/tests/test_containers_integration_tests/services/test_app_service.py +++ b/api/tests/test_containers_integration_tests/services/test_app_service.py @@ -658,15 +658,17 @@ class TestAppService: # Update app icon new_icon = "🌟" new_icon_background = "#FFD93D" + new_icon_type = "image" mock_current_user = create_autospec(Account, instance=True) mock_current_user.id = account.id mock_current_user.current_tenant_id = account.current_tenant_id with patch("services.app_service.current_user", mock_current_user): - updated_app = app_service.update_app_icon(app, new_icon, new_icon_background) + updated_app = app_service.update_app_icon(app, new_icon, new_icon_background, new_icon_type) assert updated_app.icon == new_icon assert updated_app.icon_background == new_icon_background + assert str(updated_app.icon_type).lower() == new_icon_type assert updated_app.updated_by == account.id # Verify other fields remain unchanged diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_document.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_document.py new file mode 100644 index 0000000000..2bec703f0c --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_document.py @@ -0,0 +1,650 @@ +"""Testcontainers integration tests for SQL-backed DocumentService paths.""" + +import datetime +import json +from unittest.mock import create_autospec, patch +from uuid import uuid4 + +import pytest +from werkzeug.exceptions import Forbidden, NotFound + +from core.rag.index_processor.constant.index_type import IndexStructureType +from extensions.storage.storage_type import StorageType +from models import Account +from models.dataset import Dataset, Document +from models.enums import CreatorUserRole, DataSourceType, DocumentCreatedFrom, IndexingStatus +from models.model import UploadFile +from services.dataset_service import DocumentService +from services.errors.account import NoPermissionError + +FIXED_UPLOAD_CREATED_AT = datetime.datetime(2024, 1, 1, 0, 0, 0) + + +class DocumentServiceIntegrationFactory: + @staticmethod + def create_dataset( + db_session_with_containers, + *, + tenant_id: str | None = None, + created_by: str | None = None, + name: str | None = None, + ) -> Dataset: + dataset = Dataset( + tenant_id=tenant_id or str(uuid4()), + name=name or f"dataset-{uuid4()}", + data_source_type=DataSourceType.UPLOAD_FILE, + created_by=created_by or str(uuid4()), + ) + db_session_with_containers.add(dataset) + db_session_with_containers.commit() + return dataset + + @staticmethod + def create_document( + db_session_with_containers, + *, + dataset: Dataset, + name: str = "doc.txt", + position: int = 1, + tenant_id: str | None = None, + indexing_status: str = IndexingStatus.COMPLETED, + enabled: bool = True, + archived: bool = False, + is_paused: bool = False, + need_summary: bool = False, + doc_form: str = IndexStructureType.PARAGRAPH_INDEX, + batch: str | None = None, + data_source_type: str = DataSourceType.UPLOAD_FILE, + data_source_info: dict | None = None, + created_by: str | None = None, + ) -> Document: + document = Document( + tenant_id=tenant_id or dataset.tenant_id, + dataset_id=dataset.id, + position=position, + data_source_type=data_source_type, + data_source_info=json.dumps(data_source_info or {}), + batch=batch or f"batch-{uuid4()}", + name=name, + created_from=DocumentCreatedFrom.WEB, + created_by=created_by or dataset.created_by, + doc_form=doc_form, + ) + document.indexing_status = indexing_status + document.enabled = enabled + document.archived = archived + document.is_paused = is_paused + document.need_summary = need_summary + if indexing_status == IndexingStatus.COMPLETED: + document.completed_at = FIXED_UPLOAD_CREATED_AT + db_session_with_containers.add(document) + db_session_with_containers.commit() + return document + + @staticmethod + def create_upload_file( + db_session_with_containers, + *, + tenant_id: str, + created_by: str, + file_id: str | None = None, + name: str = "source.txt", + ) -> UploadFile: + upload_file = UploadFile( + tenant_id=tenant_id, + storage_type=StorageType.LOCAL, + key=f"uploads/{uuid4()}", + name=name, + size=128, + extension="txt", + mime_type="text/plain", + created_by_role=CreatorUserRole.ACCOUNT, + created_by=created_by, + created_at=FIXED_UPLOAD_CREATED_AT, + used=False, + ) + if file_id: + upload_file.id = file_id + db_session_with_containers.add(upload_file) + db_session_with_containers.commit() + return upload_file + + +@pytest.fixture +def current_user_mock(): + with patch("services.dataset_service.current_user", create_autospec(Account, instance=True)) as current_user: + current_user.id = str(uuid4()) + current_user.current_tenant_id = str(uuid4()) + current_user.current_role = None + yield current_user + + +def test_get_document_returns_none_when_document_id_is_missing(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + + assert DocumentService.get_document(dataset.id, None) is None + + +def test_get_document_queries_by_dataset_and_document_id(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + document = DocumentServiceIntegrationFactory.create_document(db_session_with_containers, dataset=dataset) + + result = DocumentService.get_document(dataset.id, document.id) + + assert result is not None + assert result.id == document.id + + +def test_get_documents_by_ids_returns_empty_for_empty_input(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + + result = DocumentService.get_documents_by_ids(dataset.id, []) + + assert result == [] + + +def test_get_documents_by_ids_uses_single_batch_query(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + doc_a = DocumentServiceIntegrationFactory.create_document(db_session_with_containers, dataset=dataset, name="a.txt") + doc_b = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + name="b.txt", + position=2, + ) + + result = DocumentService.get_documents_by_ids(dataset.id, [doc_a.id, doc_b.id]) + + assert {document.id for document in result} == {doc_a.id, doc_b.id} + + +def test_update_documents_need_summary_returns_zero_for_empty_input(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + + assert DocumentService.update_documents_need_summary(dataset.id, []) == 0 + + +def test_update_documents_need_summary_updates_matching_non_qa_documents(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + paragraph_doc = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + need_summary=True, + ) + qa_doc = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + position=2, + need_summary=True, + doc_form=IndexStructureType.QA_INDEX, + ) + + updated_count = DocumentService.update_documents_need_summary( + dataset.id, + [paragraph_doc.id, qa_doc.id], + need_summary=False, + ) + + db_session_with_containers.expire_all() + refreshed_paragraph = db_session_with_containers.get(Document, paragraph_doc.id) + refreshed_qa = db_session_with_containers.get(Document, qa_doc.id) + assert updated_count == 1 + assert refreshed_paragraph is not None + assert refreshed_qa is not None + assert refreshed_paragraph.need_summary is False + assert refreshed_qa.need_summary is True + + +def test_get_document_download_url_uses_signed_url_helper(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + upload_file = DocumentServiceIntegrationFactory.create_upload_file( + db_session_with_containers, + tenant_id=dataset.tenant_id, + created_by=dataset.created_by, + ) + document = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + data_source_info={"upload_file_id": upload_file.id}, + ) + + with patch("services.dataset_service.file_helpers.get_signed_file_url", return_value="signed-url") as get_url: + result = DocumentService.get_document_download_url(document) + + assert result == "signed-url" + get_url.assert_called_once_with(upload_file_id=upload_file.id, as_attachment=True) + + +def test_get_upload_file_id_for_upload_file_document_rejects_invalid_source_type(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + document = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + data_source_type=DataSourceType.WEBSITE_CRAWL, + data_source_info={"url": "https://example.com"}, + ) + + with pytest.raises(NotFound, match="invalid source"): + DocumentService._get_upload_file_id_for_upload_file_document( + document, + invalid_source_message="invalid source", + missing_file_message="missing file", + ) + + +def test_get_upload_file_id_for_upload_file_document_rejects_missing_upload_file_id(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + document = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + data_source_info={}, + ) + + with pytest.raises(NotFound, match="missing file"): + DocumentService._get_upload_file_id_for_upload_file_document( + document, + invalid_source_message="invalid source", + missing_file_message="missing file", + ) + + +def test_get_upload_file_id_for_upload_file_document_returns_string_id(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + document = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + data_source_info={"upload_file_id": 99}, + ) + + result = DocumentService._get_upload_file_id_for_upload_file_document( + document, + invalid_source_message="invalid source", + missing_file_message="missing file", + ) + + assert result == "99" + + +def test_get_upload_file_for_upload_file_document_raises_when_file_service_returns_nothing(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + document = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + data_source_info={"upload_file_id": "missing-file"}, + ) + + with patch("services.dataset_service.FileService.get_upload_files_by_ids", return_value={}): + with pytest.raises(NotFound, match="Uploaded file not found"): + DocumentService._get_upload_file_for_upload_file_document(document) + + +def test_get_upload_file_for_upload_file_document_returns_upload_file(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + upload_file = DocumentServiceIntegrationFactory.create_upload_file( + db_session_with_containers, + tenant_id=dataset.tenant_id, + created_by=dataset.created_by, + ) + document = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + data_source_info={"upload_file_id": upload_file.id}, + ) + + result = DocumentService._get_upload_file_for_upload_file_document(document) + + assert result.id == upload_file.id + + +def test_get_upload_files_by_document_id_for_zip_download_raises_for_missing_documents(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + + with pytest.raises(NotFound, match="Document not found"): + DocumentService._get_upload_files_by_document_id_for_zip_download( + dataset_id=dataset.id, + document_ids=[str(uuid4())], + tenant_id=dataset.tenant_id, + ) + + +def test_get_upload_files_by_document_id_for_zip_download_rejects_cross_tenant_access(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + upload_file = DocumentServiceIntegrationFactory.create_upload_file( + db_session_with_containers, + tenant_id=dataset.tenant_id, + created_by=dataset.created_by, + ) + document = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + tenant_id=str(uuid4()), + data_source_info={"upload_file_id": upload_file.id}, + ) + + with pytest.raises(Forbidden, match="No permission"): + DocumentService._get_upload_files_by_document_id_for_zip_download( + dataset_id=dataset.id, + document_ids=[document.id], + tenant_id=dataset.tenant_id, + ) + + +def test_get_upload_files_by_document_id_for_zip_download_rejects_missing_upload_files(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + document = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + data_source_info={"upload_file_id": str(uuid4())}, + ) + + with pytest.raises(NotFound, match="Only uploaded-file documents can be downloaded as ZIP"): + DocumentService._get_upload_files_by_document_id_for_zip_download( + dataset_id=dataset.id, + document_ids=[document.id], + tenant_id=dataset.tenant_id, + ) + + +def test_get_upload_files_by_document_id_for_zip_download_returns_document_keyed_mapping(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + upload_file_a = DocumentServiceIntegrationFactory.create_upload_file( + db_session_with_containers, + tenant_id=dataset.tenant_id, + created_by=dataset.created_by, + name="a.txt", + ) + upload_file_b = DocumentServiceIntegrationFactory.create_upload_file( + db_session_with_containers, + tenant_id=dataset.tenant_id, + created_by=dataset.created_by, + name="b.txt", + ) + document_a = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + data_source_info={"upload_file_id": upload_file_a.id}, + ) + document_b = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + position=2, + data_source_info={"upload_file_id": upload_file_b.id}, + ) + + mapping = DocumentService._get_upload_files_by_document_id_for_zip_download( + dataset_id=dataset.id, + document_ids=[document_a.id, document_b.id], + tenant_id=dataset.tenant_id, + ) + + assert mapping[document_a.id].id == upload_file_a.id + assert mapping[document_b.id].id == upload_file_b.id + + +def test_prepare_document_batch_download_zip_raises_not_found_for_missing_dataset( + current_user_mock, flask_app_with_containers +): + with flask_app_with_containers.app_context(): + with pytest.raises(NotFound, match="Dataset not found"): + DocumentService.prepare_document_batch_download_zip( + dataset_id=str(uuid4()), + document_ids=[str(uuid4())], + tenant_id=current_user_mock.current_tenant_id, + current_user=current_user_mock, + ) + + +def test_prepare_document_batch_download_zip_translates_permission_error_to_forbidden( + db_session_with_containers, + current_user_mock, +): + dataset = DocumentServiceIntegrationFactory.create_dataset( + db_session_with_containers, + tenant_id=current_user_mock.current_tenant_id, + created_by=current_user_mock.id, + ) + + with patch( + "services.dataset_service.DatasetService.check_dataset_permission", + side_effect=NoPermissionError("denied"), + ): + with pytest.raises(Forbidden, match="denied"): + DocumentService.prepare_document_batch_download_zip( + dataset_id=dataset.id, + document_ids=[], + tenant_id=current_user_mock.current_tenant_id, + current_user=current_user_mock, + ) + + +def test_prepare_document_batch_download_zip_returns_upload_files_in_requested_order( + db_session_with_containers, + current_user_mock, +): + dataset = DocumentServiceIntegrationFactory.create_dataset( + db_session_with_containers, + tenant_id=current_user_mock.current_tenant_id, + created_by=current_user_mock.id, + ) + upload_file_a = DocumentServiceIntegrationFactory.create_upload_file( + db_session_with_containers, + tenant_id=dataset.tenant_id, + created_by=dataset.created_by, + name="a.txt", + ) + upload_file_b = DocumentServiceIntegrationFactory.create_upload_file( + db_session_with_containers, + tenant_id=dataset.tenant_id, + created_by=dataset.created_by, + name="b.txt", + ) + document_a = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + data_source_info={"upload_file_id": upload_file_a.id}, + ) + document_b = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + position=2, + data_source_info={"upload_file_id": upload_file_b.id}, + ) + + upload_files, download_name = DocumentService.prepare_document_batch_download_zip( + dataset_id=dataset.id, + document_ids=[document_b.id, document_a.id], + tenant_id=current_user_mock.current_tenant_id, + current_user=current_user_mock, + ) + + assert [upload_file.id for upload_file in upload_files] == [upload_file_b.id, upload_file_a.id] + assert download_name.endswith(".zip") + + +def test_get_document_by_dataset_id_returns_enabled_documents(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + enabled_document = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + enabled=True, + ) + DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + position=2, + enabled=False, + ) + + result = DocumentService.get_document_by_dataset_id(dataset.id) + + assert [document.id for document in result] == [enabled_document.id] + + +def test_get_working_documents_by_dataset_id_returns_completed_enabled_unarchived_documents(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + available_document = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + indexing_status=IndexingStatus.COMPLETED, + enabled=True, + archived=False, + ) + DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + position=2, + indexing_status=IndexingStatus.ERROR, + ) + + result = DocumentService.get_working_documents_by_dataset_id(dataset.id) + + assert [document.id for document in result] == [available_document.id] + + +def test_get_error_documents_by_dataset_id_returns_error_and_paused_documents(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + error_document = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + indexing_status=IndexingStatus.ERROR, + ) + paused_document = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + position=2, + indexing_status=IndexingStatus.PAUSED, + ) + DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + position=3, + indexing_status=IndexingStatus.COMPLETED, + ) + + result = DocumentService.get_error_documents_by_dataset_id(dataset.id) + + assert {document.id for document in result} == {error_document.id, paused_document.id} + + +def test_get_batch_documents_filters_by_current_user_tenant(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + batch = f"batch-{uuid4()}" + matching_document = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + batch=batch, + ) + DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + position=2, + tenant_id=str(uuid4()), + batch=batch, + ) + + with patch("services.dataset_service.current_user", create_autospec(Account, instance=True)) as current_user: + current_user.current_tenant_id = dataset.tenant_id + result = DocumentService.get_batch_documents(dataset.id, batch) + + assert [document.id for document in result] == [matching_document.id] + + +def test_get_document_file_detail_returns_upload_file(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + upload_file = DocumentServiceIntegrationFactory.create_upload_file( + db_session_with_containers, + tenant_id=dataset.tenant_id, + created_by=dataset.created_by, + ) + + result = DocumentService.get_document_file_detail(upload_file.id) + + assert result is not None + assert result.id == upload_file.id + + +def test_delete_document_emits_signal_and_commits(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + upload_file = DocumentServiceIntegrationFactory.create_upload_file( + db_session_with_containers, + tenant_id=dataset.tenant_id, + created_by=dataset.created_by, + ) + document = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + data_source_info={"upload_file_id": upload_file.id}, + ) + + with patch("services.dataset_service.document_was_deleted.send") as signal_send: + DocumentService.delete_document(document) + + assert db_session_with_containers.get(Document, document.id) is None + signal_send.assert_called_once_with( + document.id, + dataset_id=document.dataset_id, + doc_form=document.doc_form, + file_id=upload_file.id, + ) + + +def test_delete_documents_ignores_empty_input(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + + with patch("services.dataset_service.batch_clean_document_task.delay") as delay: + DocumentService.delete_documents(dataset, []) + + delay.assert_not_called() + + +def test_delete_documents_deletes_rows_and_dispatches_cleanup_task(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + dataset.chunk_structure = IndexStructureType.PARAGRAPH_INDEX + db_session_with_containers.commit() + upload_file_a = DocumentServiceIntegrationFactory.create_upload_file( + db_session_with_containers, + tenant_id=dataset.tenant_id, + created_by=dataset.created_by, + name="a.txt", + ) + upload_file_b = DocumentServiceIntegrationFactory.create_upload_file( + db_session_with_containers, + tenant_id=dataset.tenant_id, + created_by=dataset.created_by, + name="b.txt", + ) + document_a = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + data_source_info={"upload_file_id": upload_file_a.id}, + ) + document_b = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + position=2, + data_source_info={"upload_file_id": upload_file_b.id}, + ) + + with patch("services.dataset_service.batch_clean_document_task.delay") as delay: + DocumentService.delete_documents(dataset, [document_a.id, document_b.id]) + + assert db_session_with_containers.get(Document, document_a.id) is None + assert db_session_with_containers.get(Document, document_b.id) is None + delay.assert_called_once() + args = delay.call_args.args + assert args[0] == [document_a.id, document_b.id] + assert args[1] == dataset.id + assert set(args[3]) == {upload_file_a.id, upload_file_b.id} + + +def test_get_documents_position_returns_next_position_when_documents_exist(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + DocumentServiceIntegrationFactory.create_document(db_session_with_containers, dataset=dataset, position=3) + + assert DocumentService.get_documents_position(dataset.id) == 4 + + +def test_get_documents_position_defaults_to_one_when_dataset_is_empty(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + + assert DocumentService.get_documents_position(dataset.id) == 1 diff --git a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py index 18c5320d0a..80f9083e81 100644 --- a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py +++ b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py @@ -4,7 +4,7 @@ from unittest.mock import MagicMock import pytest -from core.workflow.human_input_compat import ( +from core.workflow.human_input_adapter import ( EmailDeliveryConfig, EmailDeliveryMethod, EmailRecipients, diff --git a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test_service.py b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test_service.py index 21a54e909e..ed75363f3b 100644 --- a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test_service.py +++ b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test_service.py @@ -8,7 +8,7 @@ import pytest from sqlalchemy.engine import Engine from configs import dify_config -from core.workflow.human_input_compat import ( +from core.workflow.human_input_adapter import ( EmailDeliveryConfig, EmailDeliveryMethod, EmailRecipients, diff --git a/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py b/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py index d3e765055a..af83adaae0 100644 --- a/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py +++ b/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py @@ -1,3 +1,5 @@ +import inspect +import json from unittest.mock import patch import pytest @@ -6,6 +8,8 @@ from pydantic import TypeAdapter, ValidationError from sqlalchemy.orm import Session from core.tools.entities.tool_entities import ApiProviderSchemaType +from core.tools.errors import ApiToolProviderNotFoundError +from core.tools.tool_label_manager import ToolLabelManager from models import Account, Tenant from models.tools import ApiToolProvider from services.tools.api_tools_manage_service import ApiToolManageService @@ -590,30 +594,204 @@ class TestApiToolManageService: with pytest.raises(ValueError, match="you have not added provider"): ApiToolManageService.delete_api_tool_provider(account.id, tenant.id, "nonexistent") - def test_update_api_tool_provider_not_found( + def test_update_api_tool_provider_success( self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies ): - """Test update raises ValueError when original provider not found.""" fake = Faker() + + # Firmware fix for cache.delete() in update flow + mock_encrypter = mock_external_service_dependencies["encrypter"] + from unittest.mock import MagicMock + + mock_cache = MagicMock() + mock_cache.delete.return_value = None + mock_encrypter.return_value = (mock_encrypter, mock_cache) + + # Get fake account and tenant account, tenant = self._create_test_account_and_tenant( db_session_with_containers, mock_external_service_dependencies ) - with pytest.raises(ValueError, match="does not exists"): - ApiToolManageService.update_api_tool_provider( + # original provider name + original_name = "original-provider" + + # Create original provider + _ = ApiToolManageService.create_api_tool_provider( + user_id=account.id, + tenant_id=tenant.id, + provider_name=original_name, + icon={"type": "emoji", "value": "🔧"}, + credentials={"auth_type": "none"}, + schema_type=ApiProviderSchemaType.OPENAPI, + schema=self._create_test_openapi_schema(), + privacy_policy="", + custom_disclaimer="", + labels=["old-label"], + ) + + # new provide name and new labels for update + new_name = "updated-provider" + new_labels = ["new-label-1", "new-label-2"] + + # Reset mock history so assertions focus on update path only + mock_external_service_dependencies["encrypter"].reset_mock() + mock_external_service_dependencies["provider_controller"].from_db.reset_mock() + mock_external_service_dependencies["tool_label_manager"].update_tool_labels.reset_mock() + + # Act: Update the provider with new values + result = ApiToolManageService.update_api_tool_provider( + user_id=account.id, + tenant_id=tenant.id, + # new provider name - changed 1 + provider_name=new_name, + original_provider=original_name, + # new icon - changed 2 + icon={"type": "emoji", "value": "🚀"}, + credentials={"auth_type": "none"}, + _schema_type=ApiProviderSchemaType.OPENAPI, + schema=self._create_test_openapi_schema(), + # new privacy policy - changed 3 + privacy_policy="https://new-policy.com", + # new custom disclaimer - changed 4 + custom_disclaimer="New disclaimer", + # new labels - changed 5 (However, we will not verify this, not this layer responsibility.) + labels=new_labels, + ) + + # Assert: Verify the result + assert result == {"result": "success"} + + # Get the updated provider from the database + updated_provider: ApiToolProvider | None = ( + db_session_with_containers.query(ApiToolProvider) + .filter(ApiToolProvider.tenant_id == tenant.id, ApiToolProvider.name == new_name) + .first() + ) + + # Verify the provider was updated successfully + assert updated_provider is not None + + # Manually refresh to keep object detachment + db_session_with_containers.refresh(updated_provider) + # Verify all the updated fields + # - changed 1 + assert updated_provider.name == new_name + # - changed 2 + icon_data = json.loads(updated_provider.icon) + assert icon_data["type"] == "emoji" + assert icon_data["value"] == "🚀" + # - changed 3 + assert updated_provider.privacy_policy == "https://new-policy.com" + # - changed 4 + assert updated_provider.custom_disclaimer == "New disclaimer" + + # Verify old provider name no longer exists after rename + original_provider: ApiToolProvider | None = ( + db_session_with_containers.query(ApiToolProvider) + .filter(ApiToolProvider.tenant_id == tenant.id, ApiToolProvider.name == original_name) + .first() + ) + assert original_provider is None + + # Verify update flow calls critical collaborators + mock_external_service_dependencies["provider_controller"].from_db.assert_called_once() + mock_external_service_dependencies["encrypter"].assert_called_once() + mock_cache.delete.assert_called_once() + + # Deeply verify on session propagation of labels update logics: + # Since in refactoring, we pass session down to label manager to keep atomicity. + # The assertion here is to verify this. + sig = inspect.signature(ToolLabelManager.update_tool_labels) + args, kwargs = mock_external_service_dependencies["tool_label_manager"].update_tool_labels.call_args + bound_args = sig.bind(*args, **kwargs) + passed_session = bound_args.arguments.get("session") + # Ensure the type: Session + assert isinstance(passed_session, Session), f"Expected Session object, got {type(passed_session)}" + assert passed_session is not None, ( + "Atomicity Failure: Session cannot be passed to Label Manager in update_api_tool_provider" + ) + + def test_update_api_tool_provider_not_found( + self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies + ): + """ + Test update raises ValueError when original provider not found. + + This test verifies: + - Proper error when trying to update a non-existing original provider + - No accidental upsert/new provider creation + - No external dependency invocation on early failure path + """ + # Arrange: Create test account and tenant + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Keep an existing provider in DB to ensure unrelated data remains unchanged + existing_provider_name = "existing-provider" + _ = ApiToolManageService.create_api_tool_provider( + user_id=account.id, + tenant_id=tenant.id, + provider_name=existing_provider_name, + icon={"type": "emoji", "value": "🔧"}, + credentials={"auth_type": "none"}, + schema_type=ApiProviderSchemaType.OPENAPI, + schema=self._create_test_openapi_schema(), + privacy_policy="https://existing-policy.com", + custom_disclaimer="Existing disclaimer", + labels=["existing-label"], + ) + + # Reset mock history so assertions focus on update failure path only + mock_external_service_dependencies["tool_label_manager"].update_tool_labels.reset_mock() + mock_external_service_dependencies["encrypter"].reset_mock() + mock_external_service_dependencies["provider_controller"].from_db.reset_mock() + + # Act & Assert: Verify update fails with clear error message + target_new_name = "new-provider-name" + missing_original_name = "missing-original-provider" + with pytest.raises(ApiToolProviderNotFoundError) as exc_info: + _ = ApiToolManageService.update_api_tool_provider( user_id=account.id, tenant_id=tenant.id, - provider_name="new-name", - original_provider="nonexistent", - icon={}, + provider_name=target_new_name, + original_provider=missing_original_name, + icon={"type": "emoji", "value": "🚀"}, credentials={"auth_type": "none"}, _schema_type=ApiProviderSchemaType.OPENAPI, schema=self._create_test_openapi_schema(), - privacy_policy=None, - custom_disclaimer="", - labels=[], + privacy_policy="https://new-policy.com", + custom_disclaimer="New disclaimer", + labels=["new-label"], ) + error = exc_info.value + assert error.provider_name == missing_original_name + assert error.tenant_id == tenant.id + assert error.error_code == "api_tool_provider_not_found" + + # Assert: Existing provider should remain unchanged + existing_provider: ApiToolProvider | None = ( + db_session_with_containers.query(ApiToolProvider) + .filter(ApiToolProvider.tenant_id == tenant.id, ApiToolProvider.name == existing_provider_name) + .first() + ) + assert existing_provider is not None + assert existing_provider.name == existing_provider_name + + # Assert: No new provider should be created + unexpected_new_provider: ApiToolProvider | None = ( + db_session_with_containers.query(ApiToolProvider) + .filter(ApiToolProvider.tenant_id == tenant.id, ApiToolProvider.name == target_new_name) + .first() + ) + assert unexpected_new_provider is None + + # Assert: Early failure should skip all downstream external interactions + mock_external_service_dependencies["tool_label_manager"].update_tool_labels.assert_not_called() + mock_external_service_dependencies["encrypter"].assert_not_called() + mock_external_service_dependencies["provider_controller"].from_db.assert_not_called() + def test_update_api_tool_provider_missing_auth_type( self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies ): diff --git a/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py b/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py index 328bdbf055..95a867dbb5 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py @@ -10,7 +10,7 @@ from core.app.app_config.entities import WorkflowUIBasedAppConfig from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext from core.repositories.human_input_repository import FormCreateParams, HumanInputFormRepositoryImpl -from core.workflow.human_input_compat import ( +from core.workflow.human_input_adapter import ( EmailDeliveryConfig, EmailDeliveryMethod, EmailRecipients, diff --git a/api/tests/unit_tests/controllers/console/app/test_mcp_server_response.py b/api/tests/unit_tests/controllers/console/app/test_mcp_server_response.py index baac4cd4e0..1af15d8dc6 100644 --- a/api/tests/unit_tests/controllers/console/app/test_mcp_server_response.py +++ b/api/tests/unit_tests/controllers/console/app/test_mcp_server_response.py @@ -1,6 +1,25 @@ import datetime +from types import SimpleNamespace +from unittest.mock import PropertyMock, patch -from controllers.console.app.mcp_server import AppMCPServerResponse +from flask import Flask + +from controllers.console import console_ns +from controllers.console.app.mcp_server import AppMCPServerController, AppMCPServerResponse + + +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +class _ValidatedResponse: + def __init__(self, payload): + self._payload = payload + + def model_dump(self, mode="json"): + return self._payload class TestAppMCPServerResponse: @@ -40,6 +59,18 @@ class TestAppMCPServerResponse: resp = AppMCPServerResponse.model_validate(data) assert resp.parameters == {"already": "parsed"} + def test_parameters_json_array_parsed(self): + data = { + "id": "s1", + "name": "test", + "server_code": "code", + "description": "desc", + "status": "active", + "parameters": '["a", "b"]', + } + resp = AppMCPServerResponse.model_validate(data) + assert resp.parameters == ["a", "b"] + def test_timestamps_normalized(self): dt = datetime.datetime(2024, 1, 1, 0, 0, 0, tzinfo=datetime.UTC) data = { @@ -68,3 +99,40 @@ class TestAppMCPServerResponse: resp = AppMCPServerResponse.model_validate(data) assert resp.created_at is None assert resp.updated_at is None + + +class TestAppMCPServerController: + def test_get_returns_empty_dict_when_server_missing(self): + api = AppMCPServerController() + method = unwrap(api.get) + + with patch("controllers.console.app.mcp_server.db.session.scalar", return_value=None): + response = method(api, app_model=SimpleNamespace(id="app-1")) + + assert response == {} + + def test_post_returns_201(self): + api = AppMCPServerController() + method = unwrap(api.post) + payload = {"parameters": {"timeout": 30}} + app = Flask(__name__) + app.config["TESTING"] = True + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload), + patch("controllers.console.app.mcp_server.current_account_with_tenant", return_value=(None, "tenant-1")), + patch("controllers.console.app.mcp_server.db.session.add"), + patch("controllers.console.app.mcp_server.db.session.commit"), + patch("controllers.console.app.mcp_server.AppMCPServer.generate_server_code", return_value="server-code"), + patch( + "controllers.console.app.mcp_server.AppMCPServerResponse.model_validate", + return_value=_ValidatedResponse({"id": "server-1"}), + ), + ): + response, status_code = method( + api, app_model=SimpleNamespace(id="app-1", name="Demo App", description="App description") + ) + + assert response == {"id": "server-1"} + assert status_code == 201 diff --git a/api/tests/unit_tests/controllers/console/app/test_workflow.py b/api/tests/unit_tests/controllers/console/app/test_workflow.py index 6ff3b19362..e91c0a0597 100644 --- a/api/tests/unit_tests/controllers/console/app/test_workflow.py +++ b/api/tests/unit_tests/controllers/console/app/test_workflow.py @@ -31,7 +31,7 @@ def test_parse_file_with_config(monkeypatch: pytest.MonkeyPatch) -> None: file_list = [ File( tenant_id="t1", - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="http://u", ) diff --git a/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py b/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py index b19a1740eb..22b80b748e 100644 --- a/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py +++ b/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py @@ -314,8 +314,8 @@ def test_workflow_file_variable_with_signed_url(): # Create a File object with LOCAL_FILE transfer method (which generates signed URLs) test_file = File( - id="test_file_id", - type=FileType.IMAGE, + file_id="test_file_id", + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="test_upload_file_id", filename="test.jpg", @@ -370,8 +370,8 @@ def test_workflow_file_variable_remote_url(): # Create a File object with REMOTE_URL transfer method test_file = File( - id="test_file_id", - type=FileType.IMAGE, + file_id="test_file_id", + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://example.com/test.jpg", filename="test.jpg", diff --git a/api/tests/unit_tests/controllers/console/datasets/test_datasets.py b/api/tests/unit_tests/controllers/console/datasets/test_datasets.py index 94d6c17915..9465936f28 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_datasets.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_datasets.py @@ -1772,6 +1772,21 @@ class TestDatasetApiBaseUrlApi: assert response["api_base_url"] == "http://localhost:5000/v1" + def test_get_api_base_url_no_double_v1(self, app): + api = DatasetApiBaseUrlApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets.dify_config.SERVICE_API_URL", + "https://example.com/v1", + ), + ): + response = method(api) + + assert response["api_base_url"] == "https://example.com/v1" + class TestDatasetRetrievalSettingApi: def test_get_success(self, app): diff --git a/api/tests/unit_tests/controllers/console/datasets/test_datasets_document.py b/api/tests/unit_tests/controllers/console/datasets/test_datasets_document.py index ce2278de4f..d9b02ac453 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_datasets_document.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_datasets_document.py @@ -1,3 +1,4 @@ +from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest @@ -215,17 +216,23 @@ class TestDatasetDocumentListApi: method = unwrap(api.post) payload = {"indexing_technique": "economy"} + created_dataset = SimpleNamespace(id="ds-1", name="Dataset", indexing_technique="economy") + created_document = SimpleNamespace(id="doc-1", name="Document", doc_metadata_details=None) with ( app.test_request_context("/", json=payload), patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets_document.DatasetService.get_dataset", + return_value=created_dataset, + ), patch( "controllers.console.datasets.datasets_document.DocumentService.document_create_args_validate", return_value=None, ), patch( "controllers.console.datasets.datasets_document.DocumentService.save_document_with_dataset_id", - return_value=([MagicMock()], "batch-1"), + return_value=([created_document], "batch-1"), ), ): response = method(api, "ds-1") diff --git a/api/tests/unit_tests/controllers/console/test_workspace_account.py b/api/tests/unit_tests/controllers/console/test_workspace_account.py index c513be950b..26ff264f18 100644 --- a/api/tests/unit_tests/controllers/console/test_workspace_account.py +++ b/api/tests/unit_tests/controllers/console/test_workspace_account.py @@ -68,7 +68,10 @@ class TestChangeEmailSend: mock_features.return_value = SimpleNamespace(enable_change_email=True) mock_account = _build_account("current@example.com", "acc1") mock_current_account.return_value = (mock_account, None) - mock_get_change_data.return_value = {"email": "current@example.com"} + mock_get_change_data.return_value = { + "email": "current@example.com", + AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED, + } mock_send_email.return_value = "token-abc" with app.test_request_context( @@ -85,12 +88,55 @@ class TestChangeEmailSend: email="new@example.com", old_email="current@example.com", language="en-US", - phase="new_email", + phase=AccountService.CHANGE_EMAIL_PHASE_NEW, ) mock_extract_ip.assert_called_once() mock_is_ip_limit.assert_called_once_with("127.0.0.1") mock_csrf.assert_called_once() + @patch("controllers.console.wraps.db") + @patch("controllers.console.workspace.account.current_account_with_tenant") + @patch("controllers.console.workspace.account.AccountService.get_change_email_data") + @patch("controllers.console.workspace.account.AccountService.send_change_email_email") + @patch("controllers.console.workspace.account.AccountService.is_email_send_ip_limit", return_value=False) + @patch("controllers.console.workspace.account.extract_remote_ip", return_value="127.0.0.1") + @patch("libs.login.check_csrf_token", return_value=None) + @patch("controllers.console.wraps.FeatureService.get_system_features") + def test_should_reject_new_email_phase_when_token_phase_is_not_old_verified( + self, + mock_features, + mock_csrf, + mock_extract_ip, + mock_is_ip_limit, + mock_send_email, + mock_get_change_data, + mock_current_account, + mock_db, + app, + ): + """GHSA-4q3w-q5mc-45rq: a phase-1 token must not unlock the new-email send step.""" + from controllers.console.auth.error import InvalidTokenError + + _mock_wraps_db(mock_db) + mock_features.return_value = SimpleNamespace(enable_change_email=True) + mock_account = _build_account("current@example.com", "acc1") + mock_current_account.return_value = (mock_account, None) + mock_get_change_data.return_value = { + "email": "current@example.com", + AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_OLD, + } + + with app.test_request_context( + "/account/change-email", + method="POST", + json={"email": "New@Example.com", "language": "en-US", "phase": "new_email", "token": "token-123"}, + ): + _set_logged_in_user(_build_account("tester@example.com", "tester")) + with pytest.raises(InvalidTokenError): + ChangeEmailSendEmailApi().post() + + mock_send_email.assert_not_called() + class TestChangeEmailValidity: @patch("controllers.console.wraps.db") @@ -122,7 +168,12 @@ class TestChangeEmailValidity: mock_account = _build_account("user@example.com", "acc2") mock_current_account.return_value = (mock_account, None) mock_is_rate_limit.return_value = False - mock_get_data.return_value = {"email": "user@example.com", "code": "1234", "old_email": "old@example.com"} + mock_get_data.return_value = { + "email": "user@example.com", + "code": "1234", + "old_email": "old@example.com", + AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_OLD, + } mock_generate_token.return_value = (None, "new-token") with app.test_request_context( @@ -138,11 +189,169 @@ class TestChangeEmailValidity: mock_add_rate.assert_not_called() mock_revoke_token.assert_called_once_with("token-123") mock_generate_token.assert_called_once_with( - "user@example.com", code="1234", old_email="old@example.com", additional_data={} + "user@example.com", + code="1234", + old_email="old@example.com", + additional_data={ + AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED, + }, ) mock_reset_rate.assert_called_once_with("user@example.com") mock_csrf.assert_called_once() + @patch("controllers.console.wraps.db") + @patch("controllers.console.workspace.account.current_account_with_tenant") + @patch("controllers.console.workspace.account.AccountService.reset_change_email_error_rate_limit") + @patch("controllers.console.workspace.account.AccountService.generate_change_email_token") + @patch("controllers.console.workspace.account.AccountService.revoke_change_email_token") + @patch("controllers.console.workspace.account.AccountService.add_change_email_error_rate_limit") + @patch("controllers.console.workspace.account.AccountService.get_change_email_data") + @patch("controllers.console.workspace.account.AccountService.is_change_email_error_rate_limit") + @patch("libs.login.check_csrf_token", return_value=None) + @patch("controllers.console.wraps.FeatureService.get_system_features") + def test_should_upgrade_new_phase_token_to_new_verified( + self, + mock_features, + mock_csrf, + mock_is_rate_limit, + mock_get_data, + mock_add_rate, + mock_revoke_token, + mock_generate_token, + mock_reset_rate, + mock_current_account, + mock_db, + app, + ): + _mock_wraps_db(mock_db) + mock_features.return_value = SimpleNamespace(enable_change_email=True) + mock_current_account.return_value = (_build_account("old@example.com", "acc"), None) + mock_is_rate_limit.return_value = False + mock_get_data.return_value = { + "email": "new@example.com", + "code": "1234", + "old_email": "old@example.com", + AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_NEW, + } + mock_generate_token.return_value = (None, "new-verified-token") + + with app.test_request_context( + "/account/change-email/validity", + method="POST", + json={"email": "new@example.com", "code": "1234", "token": "token-123"}, + ): + _set_logged_in_user(_build_account("tester@example.com", "tester")) + response = ChangeEmailCheckApi().post() + + assert response == {"is_valid": True, "email": "new@example.com", "token": "new-verified-token"} + mock_generate_token.assert_called_once_with( + "new@example.com", + code="1234", + old_email="old@example.com", + additional_data={ + AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED, + }, + ) + + @patch("controllers.console.wraps.db") + @patch("controllers.console.workspace.account.current_account_with_tenant") + @patch("controllers.console.workspace.account.AccountService.reset_change_email_error_rate_limit") + @patch("controllers.console.workspace.account.AccountService.generate_change_email_token") + @patch("controllers.console.workspace.account.AccountService.revoke_change_email_token") + @patch("controllers.console.workspace.account.AccountService.add_change_email_error_rate_limit") + @patch("controllers.console.workspace.account.AccountService.get_change_email_data") + @patch("controllers.console.workspace.account.AccountService.is_change_email_error_rate_limit") + @patch("libs.login.check_csrf_token", return_value=None) + @patch("controllers.console.wraps.FeatureService.get_system_features") + def test_should_reject_validity_when_token_phase_is_unknown( + self, + mock_features, + mock_csrf, + mock_is_rate_limit, + mock_get_data, + mock_add_rate, + mock_revoke_token, + mock_generate_token, + mock_reset_rate, + mock_current_account, + mock_db, + app, + ): + """A token whose phase marker is a string but not a known transition must be rejected.""" + from controllers.console.auth.error import InvalidTokenError + + _mock_wraps_db(mock_db) + mock_features.return_value = SimpleNamespace(enable_change_email=True) + mock_current_account.return_value = (_build_account("old@example.com", "acc"), None) + mock_is_rate_limit.return_value = False + mock_get_data.return_value = { + "email": "user@example.com", + "code": "1234", + "old_email": "old@example.com", + AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: "something_else", + } + + with app.test_request_context( + "/account/change-email/validity", + method="POST", + json={"email": "user@example.com", "code": "1234", "token": "token-123"}, + ): + _set_logged_in_user(_build_account("tester@example.com", "tester")) + with pytest.raises(InvalidTokenError): + ChangeEmailCheckApi().post() + + mock_revoke_token.assert_not_called() + mock_generate_token.assert_not_called() + + @patch("controllers.console.wraps.db") + @patch("controllers.console.workspace.account.current_account_with_tenant") + @patch("controllers.console.workspace.account.AccountService.reset_change_email_error_rate_limit") + @patch("controllers.console.workspace.account.AccountService.generate_change_email_token") + @patch("controllers.console.workspace.account.AccountService.revoke_change_email_token") + @patch("controllers.console.workspace.account.AccountService.add_change_email_error_rate_limit") + @patch("controllers.console.workspace.account.AccountService.get_change_email_data") + @patch("controllers.console.workspace.account.AccountService.is_change_email_error_rate_limit") + @patch("libs.login.check_csrf_token", return_value=None) + @patch("controllers.console.wraps.FeatureService.get_system_features") + def test_should_reject_validity_when_token_has_no_phase( + self, + mock_features, + mock_csrf, + mock_is_rate_limit, + mock_get_data, + mock_add_rate, + mock_revoke_token, + mock_generate_token, + mock_reset_rate, + mock_current_account, + mock_db, + app, + ): + """A token minted without a phase marker (e.g. a hand-crafted token) must not validate.""" + from controllers.console.auth.error import InvalidTokenError + + _mock_wraps_db(mock_db) + mock_features.return_value = SimpleNamespace(enable_change_email=True) + mock_current_account.return_value = (_build_account("old@example.com", "acc"), None) + mock_is_rate_limit.return_value = False + mock_get_data.return_value = { + "email": "user@example.com", + "code": "1234", + "old_email": "old@example.com", + } + + with app.test_request_context( + "/account/change-email/validity", + method="POST", + json={"email": "user@example.com", "code": "1234", "token": "token-123"}, + ): + _set_logged_in_user(_build_account("tester@example.com", "tester")) + with pytest.raises(InvalidTokenError): + ChangeEmailCheckApi().post() + + mock_revoke_token.assert_not_called() + mock_generate_token.assert_not_called() + class TestChangeEmailReset: @patch("controllers.console.wraps.db") @@ -175,7 +384,11 @@ class TestChangeEmailReset: mock_current_account.return_value = (current_user, None) mock_is_freeze.return_value = False mock_check_unique.return_value = True - mock_get_data.return_value = {"old_email": "OLD@example.com"} + mock_get_data.return_value = { + "email": "new@example.com", + "old_email": "OLD@example.com", + AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED, + } mock_account_after_update = _build_account("new@example.com", "acc3-updated") mock_update_account.return_value = mock_account_after_update @@ -194,6 +407,155 @@ class TestChangeEmailReset: mock_send_notify.assert_called_once_with(email="new@example.com") mock_csrf.assert_called_once() + @patch("controllers.console.wraps.db") + @patch("controllers.console.workspace.account.current_account_with_tenant") + @patch("controllers.console.workspace.account.AccountService.send_change_email_completed_notify_email") + @patch("controllers.console.workspace.account.AccountService.update_account_email") + @patch("controllers.console.workspace.account.AccountService.revoke_change_email_token") + @patch("controllers.console.workspace.account.AccountService.get_change_email_data") + @patch("controllers.console.workspace.account.AccountService.check_email_unique") + @patch("controllers.console.workspace.account.AccountService.is_account_in_freeze") + @patch("libs.login.check_csrf_token", return_value=None) + @patch("controllers.console.wraps.FeatureService.get_system_features") + def test_should_reject_reset_when_token_phase_is_not_new_verified( + self, + mock_features, + mock_csrf, + mock_is_freeze, + mock_check_unique, + mock_get_data, + mock_revoke_token, + mock_update_account, + mock_send_notify, + mock_current_account, + mock_db, + app, + ): + """GHSA-4q3w-q5mc-45rq PoC: phase-1 token must not be usable against /reset.""" + from controllers.console.auth.error import InvalidTokenError + + _mock_wraps_db(mock_db) + mock_features.return_value = SimpleNamespace(enable_change_email=True) + current_user = _build_account("old@example.com", "acc3") + mock_current_account.return_value = (current_user, None) + mock_is_freeze.return_value = False + mock_check_unique.return_value = True + # Simulate a token straight out of step #1 (phase=old_email) — exactly + # the replay used in the advisory PoC. + mock_get_data.return_value = { + "email": "old@example.com", + "old_email": "old@example.com", + AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_OLD, + } + + with app.test_request_context( + "/account/change-email/reset", + method="POST", + json={"new_email": "attacker@example.com", "token": "token-from-step1"}, + ): + _set_logged_in_user(_build_account("tester@example.com", "tester")) + with pytest.raises(InvalidTokenError): + ChangeEmailResetApi().post() + + mock_revoke_token.assert_not_called() + mock_update_account.assert_not_called() + mock_send_notify.assert_not_called() + + @patch("controllers.console.wraps.db") + @patch("controllers.console.workspace.account.current_account_with_tenant") + @patch("controllers.console.workspace.account.AccountService.send_change_email_completed_notify_email") + @patch("controllers.console.workspace.account.AccountService.update_account_email") + @patch("controllers.console.workspace.account.AccountService.revoke_change_email_token") + @patch("controllers.console.workspace.account.AccountService.get_change_email_data") + @patch("controllers.console.workspace.account.AccountService.check_email_unique") + @patch("controllers.console.workspace.account.AccountService.is_account_in_freeze") + @patch("libs.login.check_csrf_token", return_value=None) + @patch("controllers.console.wraps.FeatureService.get_system_features") + def test_should_reject_reset_when_token_email_differs_from_payload_new_email( + self, + mock_features, + mock_csrf, + mock_is_freeze, + mock_check_unique, + mock_get_data, + mock_revoke_token, + mock_update_account, + mock_send_notify, + mock_current_account, + mock_db, + app, + ): + """A verified token for address A must not be replayed to change to address B.""" + from controllers.console.auth.error import InvalidTokenError + + _mock_wraps_db(mock_db) + mock_features.return_value = SimpleNamespace(enable_change_email=True) + current_user = _build_account("old@example.com", "acc3") + mock_current_account.return_value = (current_user, None) + mock_is_freeze.return_value = False + mock_check_unique.return_value = True + mock_get_data.return_value = { + "email": "verified@example.com", + "old_email": "old@example.com", + AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED, + } + + with app.test_request_context( + "/account/change-email/reset", + method="POST", + json={"new_email": "attacker@example.com", "token": "token-verified"}, + ): + _set_logged_in_user(_build_account("tester@example.com", "tester")) + with pytest.raises(InvalidTokenError): + ChangeEmailResetApi().post() + + mock_revoke_token.assert_not_called() + mock_update_account.assert_not_called() + mock_send_notify.assert_not_called() + + +class TestAccountServiceSendChangeEmailEmail: + """Service-level coverage for the phase-bound changes in `send_change_email_email`.""" + + def test_should_raise_value_error_for_invalid_phase(self): + with pytest.raises(ValueError, match="phase must be one of"): + AccountService.send_change_email_email( + email="user@example.com", + old_email="user@example.com", + phase="old_email_verified", + ) + + @patch("services.account_service.send_change_mail_task") + @patch("services.account_service.AccountService.change_email_rate_limiter") + @patch("services.account_service.AccountService.generate_change_email_token") + def test_should_stamp_phase_into_generated_token( + self, + mock_generate_token, + mock_rate_limiter, + mock_mail_task, + ): + mock_rate_limiter.is_rate_limited.return_value = False + mock_generate_token.return_value = ("123456", "the-token") + + returned = AccountService.send_change_email_email( + email="user@example.com", + old_email="user@example.com", + language="en-US", + phase=AccountService.CHANGE_EMAIL_PHASE_NEW, + ) + + assert returned == "the-token" + mock_generate_token.assert_called_once_with( + "user@example.com", + None, + old_email="user@example.com", + additional_data={ + AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_NEW, + }, + ) + mock_mail_task.delay.assert_called_once() + mock_rate_limiter.increment_rate_limit.assert_called_once_with("user@example.com") + class TestAccountDeletionFeedback: @patch("controllers.console.wraps.db") diff --git a/api/tests/unit_tests/controllers/inner_api/plugin/test_plugin_wraps.py b/api/tests/unit_tests/controllers/inner_api/plugin/test_plugin_wraps.py index 0895fac3a4..d1b09c3a58 100644 --- a/api/tests/unit_tests/controllers/inner_api/plugin/test_plugin_wraps.py +++ b/api/tests/unit_tests/controllers/inner_api/plugin/test_plugin_wraps.py @@ -41,17 +41,22 @@ class TestTenantUserPayload: class TestGetUser: """Test get_user function""" + @patch("controllers.inner_api.plugin.wraps.select") @patch("controllers.inner_api.plugin.wraps.EndUser") @patch("controllers.inner_api.plugin.wraps.sessionmaker") @patch("controllers.inner_api.plugin.wraps.db") - def test_should_return_existing_user_by_id(self, mock_db, mock_sessionmaker, mock_enduser_class, app: Flask): + def test_should_return_existing_user_by_id( + self, mock_db, mock_sessionmaker, mock_enduser_class, mock_select, app: Flask + ): """Test returning existing user when found by ID""" # Arrange mock_user = MagicMock() mock_user.id = "user123" mock_session = MagicMock() mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session - mock_session.get.return_value = mock_user + mock_session.scalar.return_value = mock_user + mock_query = MagicMock() + mock_select.return_value.where.return_value.limit.return_value = mock_query # Act with app.app_context(): @@ -59,13 +64,45 @@ class TestGetUser: # Assert assert result == mock_user - mock_session.get.assert_called_once() + mock_session.scalar.assert_called_once() + @patch("controllers.inner_api.plugin.wraps.select") + @patch("controllers.inner_api.plugin.wraps.EndUser") + @patch("controllers.inner_api.plugin.wraps.sessionmaker") + @patch("controllers.inner_api.plugin.wraps.db") + def test_should_not_resolve_non_anonymous_users_across_tenants( + self, + mock_db, + mock_sessionmaker, + mock_enduser_class, + mock_select, + app: Flask, + ): + """Test that explicit user IDs remain scoped to the current tenant.""" + # Arrange + mock_session = MagicMock() + mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session + mock_session.scalar.return_value = None + mock_new_user = MagicMock() + mock_new_user.tenant_id = "tenant-current" + mock_enduser_class.return_value = mock_new_user + + # Act + with app.app_context(): + result = get_user("tenant-current", "foreign-user-id") + + # Assert + assert result == mock_new_user + mock_session.get.assert_not_called() + mock_session.scalar.assert_called_once() + mock_session.add.assert_called_once_with(mock_new_user) + + @patch("controllers.inner_api.plugin.wraps.select") @patch("controllers.inner_api.plugin.wraps.EndUser") @patch("controllers.inner_api.plugin.wraps.sessionmaker") @patch("controllers.inner_api.plugin.wraps.db") def test_should_return_existing_anonymous_user_by_session_id( - self, mock_db, mock_sessionmaker, mock_enduser_class, app: Flask + self, mock_db, mock_sessionmaker, mock_enduser_class, mock_select, app: Flask ): """Test returning existing anonymous user by session_id""" # Arrange @@ -73,8 +110,9 @@ class TestGetUser: mock_user.session_id = "anonymous_session" mock_session = MagicMock() mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session - # non-anonymous path uses session.get(); anonymous uses session.scalar() - mock_session.get.return_value = mock_user + mock_session.scalar.return_value = mock_user + mock_query = MagicMock() + mock_select.return_value.where.return_value.limit.return_value = mock_query # Act with app.app_context(): @@ -83,17 +121,22 @@ class TestGetUser: # Assert assert result == mock_user + @patch("controllers.inner_api.plugin.wraps.select") @patch("controllers.inner_api.plugin.wraps.EndUser") @patch("controllers.inner_api.plugin.wraps.sessionmaker") @patch("controllers.inner_api.plugin.wraps.db") - def test_should_create_new_user_when_not_found(self, mock_db, mock_sessionmaker, mock_enduser_class, app: Flask): + def test_should_create_new_user_when_not_found( + self, mock_db, mock_sessionmaker, mock_enduser_class, mock_select, app: Flask + ): """Test creating new user when not found in database""" # Arrange mock_session = MagicMock() mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session - mock_session.get.return_value = None + mock_session.scalar.return_value = None mock_new_user = MagicMock() mock_enduser_class.return_value = mock_new_user + mock_query = MagicMock() + mock_select.return_value.where.return_value.limit.return_value = mock_query # Act with app.app_context(): @@ -134,7 +177,7 @@ class TestGetUser: # Arrange mock_session = MagicMock() mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session - mock_session.get.side_effect = Exception("Database error") + mock_session.scalar.side_effect = Exception("Database error") # Act & Assert with app.app_context(): diff --git a/api/tests/unit_tests/controllers/service_api/app/test_conversation.py b/api/tests/unit_tests/controllers/service_api/app/test_conversation.py index 14c35a9ed5..4fb8ecf784 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_conversation.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_conversation.py @@ -37,6 +37,8 @@ from controllers.service_api.app.conversation import ( ConversationVariableUpdatePayload, ) from controllers.service_api.app.error import NotChatAppError +from fields._value_type_serializer import serialize_value_type +from graphon.variables import StringSegment from graphon.variables.types import SegmentType from models.model import App, AppMode, EndUser from services.conversation_service import ConversationService @@ -284,6 +286,32 @@ class TestConversationVariableResponseModels: assert response.created_at == int(created_at.timestamp()) assert response.updated_at == int(created_at.timestamp()) + def test_variable_response_normalizes_string_value_type_alias(self): + response = ConversationVariableResponse.model_validate( + { + "id": "550e8400-e29b-41d4-a716-446655440000", + "name": "foo", + "value_type": SegmentType.INTEGER.value, + } + ) + + assert response.value_type == "number" + + def test_variable_response_normalizes_callable_exposed_type(self): + response = ConversationVariableResponse.model_validate( + { + "id": "550e8400-e29b-41d4-a716-446655440000", + "name": "foo", + "value_type": SimpleNamespace(exposed_type=lambda: SegmentType.STRING.exposed_type()), + } + ) + + assert response.value_type == "string" + + def test_serialize_value_type_supports_segments_and_mappings(self): + assert serialize_value_type(StringSegment(value="hello")) == "string" + assert serialize_value_type({"value_type": SegmentType.INTEGER}) == "number" + def test_variable_pagination_response(self): response = ConversationVariableInfiniteScrollPaginationResponse.model_validate( { diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py index 3ab63aed25..dd6cd0e919 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py +++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py @@ -11,8 +11,8 @@ class TestWorkflowResponseConverterFetchFilesFromVariableValue: def create_test_file(self, file_id: str = "test_file_1") -> File: """Create a test File object""" return File( - id=file_id, - type=FileType.DOCUMENT, + file_id=file_id, + file_type=FileType.DOCUMENT, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="related_123", filename=f"{file_id}.txt", diff --git a/api/tests/unit_tests/core/app/apps/test_pause_resume.py b/api/tests/unit_tests/core/app/apps/test_pause_resume.py index a04a7b7576..6104b8d6ca 100644 --- a/api/tests/unit_tests/core/app/apps/test_pause_resume.py +++ b/api/tests/unit_tests/core/app/apps/test_pause_resume.py @@ -7,11 +7,11 @@ import graphon.nodes.human_input.entities # noqa: F401 from core.app.apps.advanced_chat import app_generator as adv_app_gen_module from core.app.apps.workflow import app_generator as wf_app_gen_module from core.app.entities.app_invoke_entities import InvokeFrom +from core.workflow import node_factory as node_factory_module from core.workflow.node_factory import DifyNodeFactory from core.workflow.system_variables import build_system_variables from graphon.entities import WorkflowStartReason from graphon.entities.base_node_data import BaseNodeData, RetryConfig -from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter from graphon.entities.pause_reason import SchedulingPause from graphon.enums import BuiltinNodeTypes, NodeType, WorkflowNodeExecutionStatus from graphon.graph import Graph @@ -55,8 +55,21 @@ class _StubToolNode(Node[_StubToolNodeData]): def version(cls) -> str: return "1" - def init_node_data(self, data): - self._node_data = _StubToolNodeData.model_validate(data) + def __init__( + self, + node_id: str, + config: _StubToolNodeData, + *, + graph_init_params, + graph_runtime_state, + **_kwargs: Any, + ) -> None: + super().__init__( + node_id=node_id, + config=config, + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + ) def _get_error_strategy(self): return self._node_data.error_strategy @@ -89,21 +102,14 @@ class _StubToolNode(Node[_StubToolNodeData]): def _patch_tool_node(mocker): - original_create_node = DifyNodeFactory.create_node + original_resolve_node_class = node_factory_module.resolve_workflow_node_class - def _patched_create_node(self, node_config: dict[str, object] | NodeConfigDict) -> Node: - typed_node_config = NodeConfigDictAdapter.validate_python(node_config) - node_data = typed_node_config["data"] - if node_data.type == BuiltinNodeTypes.TOOL: - return _StubToolNode( - id=str(typed_node_config["id"]), - config=typed_node_config, - graph_init_params=self.graph_init_params, - graph_runtime_state=self.graph_runtime_state, - ) - return original_create_node(self, typed_node_config) + def _patched_resolve_node_class(*, node_type: NodeType, node_version: str) -> type[Node]: + if node_type == BuiltinNodeTypes.TOOL: + return _StubToolNode + return original_resolve_node_class(node_type=node_type, node_version=node_version) - mocker.patch.object(DifyNodeFactory, "create_node", _patched_create_node) + mocker.patch.object(node_factory_module, "resolve_workflow_node_class", side_effect=_patched_resolve_node_class) def _node_data(node_type: NodeType, data: BaseNodeData) -> dict[str, object]: diff --git a/api/tests/unit_tests/core/app/workflow/test_file_runtime.py b/api/tests/unit_tests/core/app/workflow/test_file_runtime.py index cddd03f4b0..701863b927 100644 --- a/api/tests/unit_tests/core/app/workflow/test_file_runtime.py +++ b/api/tests/unit_tests/core/app/workflow/test_file_runtime.py @@ -26,8 +26,8 @@ def _build_file( extension: str | None = None, ) -> File: return File( - id="file-id", - type=FileType.IMAGE, + file_id="file-id", + file_type=FileType.IMAGE, transfer_method=transfer_method, reference=reference, remote_url=remote_url, @@ -351,7 +351,7 @@ def test_runtime_helper_wrappers_delegate_to_config_and_io(monkeypatch: pytest.M assert runtime.multimodal_send_format == "url" - with patch.object(file_runtime.ssrf_proxy, "get", return_value="response") as mock_get: + with patch.object(file_runtime.graphon_ssrf_proxy, "get", return_value="response") as mock_get: assert runtime.http_get("http://example", follow_redirects=False) == "response" mock_get.assert_called_once_with("http://example", follow_redirects=False) diff --git a/api/tests/unit_tests/core/app/workflow/test_node_factory.py b/api/tests/unit_tests/core/app/workflow/test_node_factory.py index c4bfb23272..30a068f4c5 100644 --- a/api/tests/unit_tests/core/app/workflow/test_node_factory.py +++ b/api/tests/unit_tests/core/app/workflow/test_node_factory.py @@ -8,8 +8,8 @@ from graphon.enums import BuiltinNodeTypes class DummyNode: - def __init__(self, *, id, config, graph_init_params, graph_runtime_state, **kwargs): - self.id = id + def __init__(self, *, node_id, config, graph_init_params, graph_runtime_state, **kwargs): + self.id = node_id self.config = config self.graph_init_params = graph_init_params self.graph_runtime_state = graph_runtime_state diff --git a/api/tests/unit_tests/core/datasource/test_datasource_manager.py b/api/tests/unit_tests/core/datasource/test_datasource_manager.py index 81315d2508..deeac49bbc 100644 --- a/api/tests/unit_tests/core/datasource/test_datasource_manager.py +++ b/api/tests/unit_tests/core/datasource/test_datasource_manager.py @@ -430,7 +430,7 @@ def test_stream_node_events_builds_file_and_variables_from_messages(mocker): mocker.patch("core.datasource.datasource_manager.session_factory.create_session", return_value=_Session()) mocker.patch("core.datasource.datasource_manager.get_file_type_by_mime_type", return_value=FileType.IMAGE) built = File( - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.TOOL_FILE, related_id="tool_file_1", extension=".png", @@ -530,7 +530,7 @@ def test_stream_node_events_online_drive_sets_variable_pool_file_and_outputs(moc mocker.patch.object(DatasourceManager, "stream_online_results", return_value=_gen_messages_text_only("ignored")) file_in = File( - type=FileType.DOCUMENT, + file_type=FileType.DOCUMENT, transfer_method=FileTransferMethod.TOOL_FILE, related_id="tf", extension=".pdf", diff --git a/api/tests/unit_tests/core/entities/test_entities_model_entities.py b/api/tests/unit_tests/core/entities/test_entities_model_entities.py index a0b2820157..aeca2e3afd 100644 --- a/api/tests/unit_tests/core/entities/test_entities_model_entities.py +++ b/api/tests/unit_tests/core/entities/test_entities_model_entities.py @@ -46,7 +46,7 @@ def test_simple_model_provider_entity_maps_from_provider_entity() -> None: # Assert assert simple_provider.provider == "openai" - assert simple_provider.label.en_US == "OpenAI" + assert simple_provider.label.en_us == "OpenAI" assert simple_provider.supported_model_types == [ModelType.LLM] diff --git a/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py b/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py index fe2c226843..a28143026f 100644 --- a/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py +++ b/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py @@ -345,22 +345,26 @@ def test_validate_provider_credentials_handles_hidden_secret_value() -> None: ) ] ) - session = Mock() - session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace(encrypted_config="encrypted-old-key") + mock_session = Mock() + mock_session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace( + encrypted_config="encrypted-old-key" + ) mock_factory = Mock() mock_factory.provider_credentials_validate.return_value = {"openai_api_key": "restored-key", "region": "us"} - with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory): - with patch("core.entities.provider_configuration.encrypter.decrypt_token", return_value="restored-key"): - with patch( - "core.entities.provider_configuration.encrypter.encrypt_token", - side_effect=lambda tenant_id, value: f"enc::{value}", - ): - validated = configuration.validate_provider_credentials( - credentials={"openai_api_key": HIDDEN_VALUE, "region": "us"}, - credential_id="credential-1", - session=session, - ) + with _patched_session(mock_session): + with patch( + "core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory + ): + with patch("core.entities.provider_configuration.encrypter.decrypt_token", return_value="restored-key"): + with patch( + "core.entities.provider_configuration.encrypter.encrypt_token", + side_effect=lambda tenant_id, value: f"enc::{value}", + ): + validated = configuration.validate_provider_credentials( + credentials={"openai_api_key": HIDDEN_VALUE, "region": "us"}, + credential_id="credential-1", + ) assert validated["openai_api_key"] == "enc::restored-key" assert validated["region"] == "us" @@ -370,23 +374,15 @@ def test_validate_provider_credentials_handles_hidden_secret_value() -> None: ) -def test_validate_provider_credentials_opens_session_when_not_passed() -> None: +def test_validate_provider_credentials_without_credential_id() -> None: configuration = _build_provider_configuration() - mock_session = Mock() mock_factory = Mock() mock_factory.provider_credentials_validate.return_value = {"region": "us"} - with patch("core.entities.provider_configuration.Session") as mock_session_cls: - with patch("core.entities.provider_configuration.db") as mock_db: - mock_db.engine = Mock() - mock_session_cls.return_value.__enter__.return_value = mock_session - with patch( - "core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory - ): - validated = configuration.validate_provider_credentials(credentials={"region": "us"}) + with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory): + validated = configuration.validate_provider_credentials(credentials={"region": "us"}) assert validated == {"region": "us"} - mock_session_cls.assert_called_once() def test_switch_preferred_provider_type_returns_early_when_no_change_or_unsupported() -> None: @@ -717,18 +713,22 @@ def test_check_provider_credential_name_exists_and_model_setting_lookup() -> Non def test_validate_provider_credentials_handles_invalid_original_json() -> None: configuration = _build_provider_configuration() configuration.provider.provider_credential_schema = _build_secret_provider_schema() - session = Mock() - session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace(encrypted_config="{invalid-json") + mock_session = Mock() + mock_session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace( + encrypted_config="{invalid-json" + ) mock_factory = Mock() mock_factory.provider_credentials_validate.return_value = {"openai_api_key": "new-key"} - with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory): - with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-key"): - validated = configuration.validate_provider_credentials( - credentials={"openai_api_key": HIDDEN_VALUE}, - credential_id="cred-1", - session=session, - ) + with _patched_session(mock_session): + with patch( + "core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory + ): + with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-key"): + validated = configuration.validate_provider_credentials( + credentials={"openai_api_key": HIDDEN_VALUE}, + credential_id="cred-1", + ) assert validated == {"openai_api_key": "enc-key"} @@ -1060,37 +1060,35 @@ def test_get_custom_model_credential_uses_specific_id_or_configuration_fallback( def test_validate_custom_model_credentials_supports_hidden_reuse_and_sessionless_path() -> None: configuration = _build_provider_configuration() configuration.provider.model_credential_schema = _build_secret_model_schema() - session = Mock() - session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace( + mock_session = Mock() + mock_session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace( encrypted_config='{"openai_api_key":"enc"}' ) mock_factory = Mock() mock_factory.model_credentials_validate.return_value = {"openai_api_key": "raw"} - with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory): - with patch("core.entities.provider_configuration.encrypter.decrypt_token", return_value="raw"): - with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"): - validated = configuration.validate_custom_model_credentials( - model_type=ModelType.LLM, - model="gpt-4o", - credentials={"openai_api_key": HIDDEN_VALUE}, - credential_id="cred-1", - session=session, - ) - assert validated == {"openai_api_key": "enc-new"} - - session = Mock() - mock_factory = Mock() - mock_factory.model_credentials_validate.return_value = {"region": "us"} - with _patched_session(session): + with _patched_session(mock_session): with patch( "core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory ): - validated = configuration.validate_custom_model_credentials( - model_type=ModelType.LLM, - model="gpt-4o", - credentials={"region": "us"}, - ) + with patch("core.entities.provider_configuration.encrypter.decrypt_token", return_value="raw"): + with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"): + validated = configuration.validate_custom_model_credentials( + model_type=ModelType.LLM, + model="gpt-4o", + credentials={"openai_api_key": HIDDEN_VALUE}, + credential_id="cred-1", + ) + assert validated == {"openai_api_key": "enc-new"} + + mock_factory2 = Mock() + mock_factory2.model_credentials_validate.return_value = {"region": "us"} + with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory2): + validated = configuration.validate_custom_model_credentials( + model_type=ModelType.LLM, + model="gpt-4o", + credentials={"region": "us"}, + ) assert validated == {"region": "us"} @@ -1570,18 +1568,20 @@ def test_get_specific_provider_credential_logs_when_decrypt_fails() -> None: def test_validate_provider_credentials_uses_empty_original_when_record_missing() -> None: configuration = _build_provider_configuration() configuration.provider.provider_credential_schema = _build_secret_provider_schema() - session = Mock() - session.execute.return_value.scalar_one_or_none.return_value = None + mock_session = Mock() + mock_session.execute.return_value.scalar_one_or_none.return_value = None mock_factory = Mock() mock_factory.provider_credentials_validate.return_value = {"openai_api_key": "raw"} - with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory): - with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"): - validated = configuration.validate_provider_credentials( - credentials={"openai_api_key": HIDDEN_VALUE}, - credential_id="cred-1", - session=session, - ) + with _patched_session(mock_session): + with patch( + "core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory + ): + with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"): + validated = configuration.validate_provider_credentials( + credentials={"openai_api_key": HIDDEN_VALUE}, + credential_id="cred-1", + ) assert validated == {"openai_api_key": "enc-new"} @@ -1692,20 +1692,24 @@ def test_get_specific_custom_model_credential_logs_when_decrypt_fails() -> None: def test_validate_custom_model_credentials_handles_invalid_original_json() -> None: configuration = _build_provider_configuration() configuration.provider.model_credential_schema = _build_secret_model_schema() - session = Mock() - session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace(encrypted_config="{invalid-json") + mock_session = Mock() + mock_session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace( + encrypted_config="{invalid-json" + ) mock_factory = Mock() mock_factory.model_credentials_validate.return_value = {"openai_api_key": "raw"} - with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory): - with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"): - validated = configuration.validate_custom_model_credentials( - model_type=ModelType.LLM, - model="gpt-4o", - credentials={"openai_api_key": HIDDEN_VALUE}, - credential_id="cred-1", - session=session, - ) + with _patched_session(mock_session): + with patch( + "core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory + ): + with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"): + validated = configuration.validate_custom_model_credentials( + model_type=ModelType.LLM, + model="gpt-4o", + credentials={"openai_api_key": HIDDEN_VALUE}, + credential_id="cred-1", + ) assert validated == {"openai_api_key": "enc-new"} diff --git a/api/tests/unit_tests/core/file/test_models.py b/api/tests/unit_tests/core/file/test_models.py index bb6e40e224..8cb0938575 100644 --- a/api/tests/unit_tests/core/file/test_models.py +++ b/api/tests/unit_tests/core/file/test_models.py @@ -3,9 +3,9 @@ from graphon.file import File, FileTransferMethod, FileType def test_file(): file = File( - id="test-file", + file_id="test-file", tenant_id="test-tenant-id", - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.TOOL_FILE, related_id="test-related-id", filename="image.png", @@ -25,27 +25,21 @@ def test_file(): assert file.size == 67 -def test_file_model_validate_accepts_legacy_tenant_id(): - data = { - "id": "test-file", - "tenant_id": "test-tenant-id", - "type": "image", - "transfer_method": "tool_file", - "related_id": "test-related-id", - "filename": "image.png", - "extension": ".png", - "mime_type": "image/png", - "size": 67, - "storage_key": "test-storage-key", - "url": "https://example.com/image.png", - # Extra legacy fields - "tool_file_id": "tool-file-123", - "upload_file_id": "upload-file-456", - "datasource_file_id": "datasource-file-789", - } +def test_file_constructor_accepts_legacy_tenant_id(): + file = File( + file_id="test-file", + tenant_id="test-tenant-id", + file_type=FileType.IMAGE, + transfer_method=FileTransferMethod.TOOL_FILE, + tool_file_id="tool-file-123", + filename="image.png", + extension=".png", + mime_type="image/png", + size=67, + storage_key="test-storage-key", + url="https://example.com/image.png", + ) - file = File.model_validate(data) - - assert file.related_id == "test-related-id" + assert file.related_id == "tool-file-123" assert file.storage_key == "test-storage-key" assert "tenant_id" not in file.model_dump() diff --git a/api/tests/unit_tests/core/helper/test_ssrf_proxy.py b/api/tests/unit_tests/core/helper/test_ssrf_proxy.py index 3b5c5e6597..d9fed9ae2a 100644 --- a/api/tests/unit_tests/core/helper/test_ssrf_proxy.py +++ b/api/tests/unit_tests/core/helper/test_ssrf_proxy.py @@ -1,11 +1,17 @@ from unittest.mock import MagicMock, patch +import httpx import pytest from core.helper.ssrf_proxy import ( SSRF_DEFAULT_MAX_RETRIES, + SSRFProxy, _get_user_provided_host_header, + _to_graphon_http_response, + graphon_ssrf_proxy, make_request, + max_retries_exceeded_error, + request_error, ) @@ -174,3 +180,56 @@ class TestFollowRedirectsParameter: call_kwargs = mock_client.request.call_args.kwargs assert call_kwargs.get("follow_redirects") is True + + +def test_to_graphon_http_response_preserves_httpx_response_fields() -> None: + response = httpx.Response( + 201, + headers={"X-Test": "1"}, + content=b"payload", + request=httpx.Request("GET", "https://example.com/resource"), + ) + + wrapped = _to_graphon_http_response(response) + + assert wrapped.status_code == 201 + assert wrapped.headers == {"x-test": "1", "content-length": "7"} + assert wrapped.content == b"payload" + assert wrapped.url == "https://example.com/resource" + assert wrapped.reason_phrase == "Created" + assert wrapped.text == "payload" + + +def test_ssrf_proxy_exposes_expected_error_types() -> None: + proxy = SSRFProxy() + + assert proxy.max_retries_exceeded_error is max_retries_exceeded_error + assert proxy.request_error is request_error + assert graphon_ssrf_proxy.max_retries_exceeded_error is max_retries_exceeded_error + assert graphon_ssrf_proxy.request_error is request_error + + +@pytest.mark.parametrize("method_name", ["get", "head", "post", "put", "delete", "patch"]) +def test_graphon_ssrf_proxy_wraps_module_requests(method_name: str) -> None: + response = httpx.Response( + 200, + headers={"X-Test": "1"}, + content=b"ok", + request=httpx.Request("GET", "https://example.com/resource"), + ) + + with patch(f"core.helper.ssrf_proxy.{method_name}", return_value=response) as mock_method: + wrapped = getattr(graphon_ssrf_proxy, method_name)( + "https://example.com/resource", + max_retries=3, + headers={"X-Test": "1"}, + ) + + mock_method.assert_called_once_with( + url="https://example.com/resource", + max_retries=3, + headers={"X-Test": "1"}, + ) + assert wrapped.status_code == 200 + assert wrapped.url == "https://example.com/resource" + assert wrapped.content == b"ok" diff --git a/api/tests/unit_tests/core/mcp/client/test_streamable_http.py b/api/tests/unit_tests/core/mcp/client/test_streamable_http.py index 81f8da9a62..bbbffa2e69 100644 --- a/api/tests/unit_tests/core/mcp/client/test_streamable_http.py +++ b/api/tests/unit_tests/core/mcp/client/test_streamable_http.py @@ -971,6 +971,23 @@ class TestHandlePostRequestNew: assert isinstance(item, SessionMessage) assert isinstance(item.message.root, JSONRPCError) assert item.message.root.id == 77 + assert item.message.root.error.message == "Session terminated by server" + + def test_404_on_initialization_includes_url_in_error(self): + t = _new_transport(url="http://example.com/mcp/server/abc123/mcp") + q: queue.Queue = queue.Queue() + msg = _make_request_msg("initialize", 1) + ctx = self._make_ctx(t, q, message=msg) + mock_resp = MagicMock() + mock_resp.status_code = 404 + ctx.client.stream = self._stream_ctx(mock_resp) + t._handle_post_request(ctx) + item = q.get_nowait() + assert isinstance(item, SessionMessage) + assert isinstance(item.message.root, JSONRPCError) + assert item.message.root.error.code == 32600 + assert "404 Not Found" in item.message.root.error.message + assert "http://example.com/mcp/server/abc123/mcp" in item.message.root.error.message def test_404_for_notification_no_error_sent(self): t = _new_transport() diff --git a/api/tests/unit_tests/core/model_runtime/test_model_provider_factory.py b/api/tests/unit_tests/core/model_runtime/test_model_provider_factory.py index 249ecb5006..c4fd970562 100644 --- a/api/tests/unit_tests/core/model_runtime/test_model_provider_factory.py +++ b/api/tests/unit_tests/core/model_runtime/test_model_provider_factory.py @@ -13,12 +13,12 @@ from graphon.model_runtime.entities.provider_entities import ( ProviderCredentialSchema, ProviderEntity, ) -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from graphon.model_runtime.model_providers.__base.moderation_model import ModerationModel -from graphon.model_runtime.model_providers.__base.rerank_model import RerankModel -from graphon.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel -from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel -from graphon.model_runtime.model_providers.__base.tts_model import TTSModel +from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel +from graphon.model_runtime.model_providers.base.moderation_model import ModerationModel +from graphon.model_runtime.model_providers.base.rerank_model import RerankModel +from graphon.model_runtime.model_providers.base.speech2text_model import Speech2TextModel +from graphon.model_runtime.model_providers.base.text_embedding_model import TextEmbeddingModel +from graphon.model_runtime.model_providers.base.tts_model import TTSModel from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory diff --git a/api/tests/unit_tests/core/ops/test_config_entity.py b/api/tests/unit_tests/core/ops/test_config_entity.py index 2cbff54c42..69650c85cc 100644 --- a/api/tests/unit_tests/core/ops/test_config_entity.py +++ b/api/tests/unit_tests/core/ops/test_config_entity.py @@ -1,16 +1,11 @@ -import pytest -from pydantic import ValidationError +from dify_trace_aliyun.config import AliyunConfig +from dify_trace_arize_phoenix.config import ArizeConfig, PhoenixConfig +from dify_trace_langfuse.config import LangfuseConfig +from dify_trace_langsmith.config import LangSmithConfig +from dify_trace_opik.config import OpikConfig +from dify_trace_weave.config import WeaveConfig -from core.ops.entities.config_entity import ( - AliyunConfig, - ArizeConfig, - LangfuseConfig, - LangSmithConfig, - OpikConfig, - PhoenixConfig, - TracingProviderEnum, - WeaveConfig, -) +from core.ops.entities.config_entity import TracingProviderEnum class TestTracingProviderEnum: @@ -27,349 +22,8 @@ class TestTracingProviderEnum: assert TracingProviderEnum.ALIYUN == "aliyun" -class TestArizeConfig: - """Test cases for ArizeConfig""" - - def test_valid_config(self): - """Test valid Arize configuration""" - config = ArizeConfig( - api_key="test_key", space_id="test_space", project="test_project", endpoint="https://custom.arize.com" - ) - assert config.api_key == "test_key" - assert config.space_id == "test_space" - assert config.project == "test_project" - assert config.endpoint == "https://custom.arize.com" - - def test_default_values(self): - """Test default values are set correctly""" - config = ArizeConfig() - assert config.api_key is None - assert config.space_id is None - assert config.project is None - assert config.endpoint == "https://otlp.arize.com" - - def test_project_validation_empty(self): - """Test project validation with empty value""" - config = ArizeConfig(project="") - assert config.project == "default" - - def test_project_validation_none(self): - """Test project validation with None value""" - config = ArizeConfig(project=None) - assert config.project == "default" - - def test_endpoint_validation_empty(self): - """Test endpoint validation with empty value""" - config = ArizeConfig(endpoint="") - assert config.endpoint == "https://otlp.arize.com" - - def test_endpoint_validation_with_path(self): - """Test endpoint validation normalizes URL by removing path""" - config = ArizeConfig(endpoint="https://custom.arize.com/api/v1") - assert config.endpoint == "https://custom.arize.com" - - def test_endpoint_validation_invalid_scheme(self): - """Test endpoint validation rejects invalid schemes""" - with pytest.raises(ValidationError, match="URL scheme must be one of"): - ArizeConfig(endpoint="ftp://invalid.com") - - def test_endpoint_validation_no_scheme(self): - """Test endpoint validation rejects URLs without scheme""" - with pytest.raises(ValidationError, match="URL scheme must be one of"): - ArizeConfig(endpoint="invalid.com") - - -class TestPhoenixConfig: - """Test cases for PhoenixConfig""" - - def test_valid_config(self): - """Test valid Phoenix configuration""" - config = PhoenixConfig(api_key="test_key", project="test_project", endpoint="https://custom.phoenix.com") - assert config.api_key == "test_key" - assert config.project == "test_project" - assert config.endpoint == "https://custom.phoenix.com" - - def test_default_values(self): - """Test default values are set correctly""" - config = PhoenixConfig() - assert config.api_key is None - assert config.project is None - assert config.endpoint == "https://app.phoenix.arize.com" - - def test_project_validation_empty(self): - """Test project validation with empty value""" - config = PhoenixConfig(project="") - assert config.project == "default" - - def test_endpoint_validation_with_path(self): - """Test endpoint validation with path""" - config = PhoenixConfig(endpoint="https://app.phoenix.arize.com/s/dify-integration") - assert config.endpoint == "https://app.phoenix.arize.com/s/dify-integration" - - def test_endpoint_validation_without_path(self): - """Test endpoint validation without path""" - config = PhoenixConfig(endpoint="https://app.phoenix.arize.com") - assert config.endpoint == "https://app.phoenix.arize.com" - - -class TestLangfuseConfig: - """Test cases for LangfuseConfig""" - - def test_valid_config(self): - """Test valid Langfuse configuration""" - config = LangfuseConfig(public_key="public_key", secret_key="secret_key", host="https://custom.langfuse.com") - assert config.public_key == "public_key" - assert config.secret_key == "secret_key" - assert config.host == "https://custom.langfuse.com" - - def test_valid_config_with_path(self): - host = "https://custom.langfuse.com/api/v1" - config = LangfuseConfig(public_key="public_key", secret_key="secret_key", host=host) - assert config.public_key == "public_key" - assert config.secret_key == "secret_key" - assert config.host == host - - def test_default_values(self): - """Test default values are set correctly""" - config = LangfuseConfig(public_key="public", secret_key="secret") - assert config.host == "https://api.langfuse.com" - - def test_missing_required_fields(self): - """Test that required fields are enforced""" - with pytest.raises(ValidationError): - LangfuseConfig() - - with pytest.raises(ValidationError): - LangfuseConfig(public_key="public") - - with pytest.raises(ValidationError): - LangfuseConfig(secret_key="secret") - - def test_host_validation_empty(self): - """Test host validation with empty value""" - config = LangfuseConfig(public_key="public", secret_key="secret", host="") - assert config.host == "https://api.langfuse.com" - - -class TestLangSmithConfig: - """Test cases for LangSmithConfig""" - - def test_valid_config(self): - """Test valid LangSmith configuration""" - config = LangSmithConfig(api_key="test_key", project="test_project", endpoint="https://custom.smith.com") - assert config.api_key == "test_key" - assert config.project == "test_project" - assert config.endpoint == "https://custom.smith.com" - - def test_default_values(self): - """Test default values are set correctly""" - config = LangSmithConfig(api_key="key", project="project") - assert config.endpoint == "https://api.smith.langchain.com" - - def test_missing_required_fields(self): - """Test that required fields are enforced""" - with pytest.raises(ValidationError): - LangSmithConfig() - - with pytest.raises(ValidationError): - LangSmithConfig(api_key="key") - - with pytest.raises(ValidationError): - LangSmithConfig(project="project") - - def test_endpoint_validation_https_only(self): - """Test endpoint validation only allows HTTPS""" - with pytest.raises(ValidationError, match="URL scheme must be one of"): - LangSmithConfig(api_key="key", project="project", endpoint="http://insecure.com") - - -class TestOpikConfig: - """Test cases for OpikConfig""" - - def test_valid_config(self): - """Test valid Opik configuration""" - config = OpikConfig( - api_key="test_key", - project="test_project", - workspace="test_workspace", - url="https://custom.comet.com/opik/api/", - ) - assert config.api_key == "test_key" - assert config.project == "test_project" - assert config.workspace == "test_workspace" - assert config.url == "https://custom.comet.com/opik/api/" - - def test_default_values(self): - """Test default values are set correctly""" - config = OpikConfig() - assert config.api_key is None - assert config.project is None - assert config.workspace is None - assert config.url == "https://www.comet.com/opik/api/" - - def test_project_validation_empty(self): - """Test project validation with empty value""" - config = OpikConfig(project="") - assert config.project == "Default Project" - - def test_url_validation_empty(self): - """Test URL validation with empty value""" - config = OpikConfig(url="") - assert config.url == "https://www.comet.com/opik/api/" - - def test_url_validation_missing_suffix(self): - """Test URL validation requires /api/ suffix""" - with pytest.raises(ValidationError, match="URL should end with /api/"): - OpikConfig(url="https://custom.comet.com/opik/") - - def test_url_validation_invalid_scheme(self): - """Test URL validation rejects invalid schemes""" - with pytest.raises(ValidationError, match="URL must start with https:// or http://"): - OpikConfig(url="ftp://custom.comet.com/opik/api/") - - -class TestWeaveConfig: - """Test cases for WeaveConfig""" - - def test_valid_config(self): - """Test valid Weave configuration""" - config = WeaveConfig( - api_key="test_key", - entity="test_entity", - project="test_project", - endpoint="https://custom.wandb.ai", - host="https://custom.host.com", - ) - assert config.api_key == "test_key" - assert config.entity == "test_entity" - assert config.project == "test_project" - assert config.endpoint == "https://custom.wandb.ai" - assert config.host == "https://custom.host.com" - - def test_default_values(self): - """Test default values are set correctly""" - config = WeaveConfig(api_key="key", project="project") - assert config.entity is None - assert config.endpoint == "https://trace.wandb.ai" - assert config.host is None - - def test_missing_required_fields(self): - """Test that required fields are enforced""" - with pytest.raises(ValidationError): - WeaveConfig() - - with pytest.raises(ValidationError): - WeaveConfig(api_key="key") - - with pytest.raises(ValidationError): - WeaveConfig(project="project") - - def test_endpoint_validation_https_only(self): - """Test endpoint validation only allows HTTPS""" - with pytest.raises(ValidationError, match="URL scheme must be one of"): - WeaveConfig(api_key="key", project="project", endpoint="http://insecure.wandb.ai") - - def test_host_validation_optional(self): - """Test host validation is optional but validates when provided""" - config = WeaveConfig(api_key="key", project="project", host=None) - assert config.host is None - - config = WeaveConfig(api_key="key", project="project", host="") - assert config.host == "" - - config = WeaveConfig(api_key="key", project="project", host="https://valid.host.com") - assert config.host == "https://valid.host.com" - - def test_host_validation_invalid_scheme(self): - """Test host validation rejects invalid schemes when provided""" - with pytest.raises(ValidationError, match="URL scheme must be one of"): - WeaveConfig(api_key="key", project="project", host="ftp://invalid.host.com") - - -class TestAliyunConfig: - """Test cases for AliyunConfig""" - - def test_valid_config(self): - """Test valid Aliyun configuration""" - config = AliyunConfig( - app_name="test_app", - license_key="test_license_key", - endpoint="https://custom.tracing-analysis-dc-hz.aliyuncs.com", - ) - assert config.app_name == "test_app" - assert config.license_key == "test_license_key" - assert config.endpoint == "https://custom.tracing-analysis-dc-hz.aliyuncs.com" - - def test_default_values(self): - """Test default values are set correctly""" - config = AliyunConfig(license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com") - assert config.app_name == "dify_app" - - def test_missing_required_fields(self): - """Test that required fields are enforced""" - with pytest.raises(ValidationError): - AliyunConfig() - - with pytest.raises(ValidationError): - AliyunConfig(license_key="test_license") - - with pytest.raises(ValidationError): - AliyunConfig(endpoint="https://tracing-analysis-dc-hz.aliyuncs.com") - - def test_app_name_validation_empty(self): - """Test app_name validation with empty value""" - config = AliyunConfig( - license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com", app_name="" - ) - assert config.app_name == "dify_app" - - def test_endpoint_validation_empty(self): - """Test endpoint validation with empty value""" - config = AliyunConfig(license_key="test_license", endpoint="") - assert config.endpoint == "https://tracing-analysis-dc-hz.aliyuncs.com" - - def test_endpoint_validation_with_path(self): - """Test endpoint validation preserves path for Aliyun endpoints""" - config = AliyunConfig( - license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com/api/v1/traces" - ) - assert config.endpoint == "https://tracing-analysis-dc-hz.aliyuncs.com/api/v1/traces" - - def test_endpoint_validation_invalid_scheme(self): - """Test endpoint validation rejects invalid schemes""" - with pytest.raises(ValidationError, match="URL must start with https:// or http://"): - AliyunConfig(license_key="test_license", endpoint="ftp://invalid.tracing-analysis-dc-hz.aliyuncs.com") - - def test_endpoint_validation_no_scheme(self): - """Test endpoint validation rejects URLs without scheme""" - with pytest.raises(ValidationError, match="URL must start with https:// or http://"): - AliyunConfig(license_key="test_license", endpoint="invalid.tracing-analysis-dc-hz.aliyuncs.com") - - def test_license_key_required(self): - """Test that license_key is required and cannot be empty""" - with pytest.raises(ValidationError): - AliyunConfig(license_key="", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com") - - def test_valid_endpoint_format_examples(self): - """Test valid endpoint format examples from comments""" - valid_endpoints = [ - # cms2.0 public endpoint - "https://proj-xtrace-123456-cn-heyuan.cn-heyuan.log.aliyuncs.com/apm/trace/opentelemetry", - # cms2.0 intranet endpoint - "https://proj-xtrace-123456-cn-heyuan.cn-heyuan-intranet.log.aliyuncs.com/apm/trace/opentelemetry", - # xtrace public endpoint - "http://tracing-cn-heyuan.arms.aliyuncs.com", - # xtrace intranet endpoint - "http://tracing-cn-heyuan-internal.arms.aliyuncs.com", - ] - - for endpoint in valid_endpoints: - config = AliyunConfig(license_key="test_license", endpoint=endpoint) - assert config.endpoint == endpoint - - class TestConfigIntegration: - """Integration tests for configuration classes""" + """Cross-provider configuration sanity checks""" def test_all_configs_can_be_instantiated(self): """Test that all config classes can be instantiated with valid data""" @@ -388,7 +42,6 @@ class TestConfigIntegration: def test_url_normalization_consistency(self): """Test that URL normalization works consistently across configs""" - # Test that paths are removed from endpoints arize_config = ArizeConfig(endpoint="https://arize.com/api/v1/test") phoenix_with_path_config = PhoenixConfig(endpoint="https://app.phoenix.arize.com/s/dify-integration") phoenix_without_path_config = PhoenixConfig(endpoint="https://app.phoenix.arize.com") diff --git a/api/tests/unit_tests/core/plugin/test_model_runtime_adapter.py b/api/tests/unit_tests/core/plugin/test_model_runtime_adapter.py index 68aa130518..88bf555594 100644 --- a/api/tests/unit_tests/core/plugin/test_model_runtime_adapter.py +++ b/api/tests/unit_tests/core/plugin/test_model_runtime_adapter.py @@ -56,7 +56,7 @@ class TestPluginModelRuntime: assert len(providers) == 1 assert providers[0].provider == "langgenius/openai/openai" assert providers[0].provider_name == "openai" - assert providers[0].label.en_US == "OpenAI" + assert providers[0].label.en_us == "OpenAI" client.fetch_model_providers.assert_called_once_with("tenant") def test_fetch_model_providers_only_exposes_short_name_for_canonical_provider(self) -> None: diff --git a/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py b/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py index d49b6e4b71..00a4207786 100644 --- a/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py +++ b/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py @@ -466,7 +466,7 @@ class TestConverter: def test_convert_parameters_to_plugin_format_with_single_file_and_selector(self): file_param = File( tenant_id="tenant-1", - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://example.com/file.png", storage_key="", @@ -499,14 +499,14 @@ class TestConverter: def test_convert_parameters_to_plugin_format_with_lists_and_passthrough_values(self): file_one = File( tenant_id="tenant-1", - type=FileType.DOCUMENT, + file_type=FileType.DOCUMENT, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://example.com/a.txt", storage_key="", ) file_two = File( tenant_id="tenant-1", - type=FileType.DOCUMENT, + file_type=FileType.DOCUMENT, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://example.com/b.txt", storage_key="", diff --git a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py index 395d392127..e536c0831f 100644 --- a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py @@ -134,9 +134,9 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg files = [ File( - id="file1", + file_id="file1", tenant_id="tenant1", - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://example.com/image1.jpg", storage_key="", @@ -245,9 +245,9 @@ def test_completion_prompt_jinja2_with_files(): completion_template = CompletionModelPromptTemplate(text="Hi {{name}}", edition_type="jinja2") file = File( - id="file1", + file_id="file1", tenant_id="tenant1", - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://example.com/image.jpg", storage_key="", @@ -379,9 +379,9 @@ def test_chat_prompt_memory_with_files_and_query(): memory = MagicMock(spec=TokenBufferMemory) prompt_template = [ChatModelMessage(text="sys", role=PromptMessageRole.SYSTEM)] file = File( - id="file1", + file_id="file1", tenant_id="tenant1", - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://example.com/image.jpg", storage_key="", @@ -413,9 +413,9 @@ def test_chat_prompt_files_without_query_updates_last_user_or_appends_new(): transform = AdvancedPromptTransform() model_config_mock = MagicMock(spec=ModelConfigEntity) file = File( - id="file1", + file_id="file1", tenant_id="tenant1", - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://example.com/image.jpg", storage_key="", @@ -463,9 +463,9 @@ def test_chat_prompt_files_with_query_branch(): transform = AdvancedPromptTransform() model_config_mock = MagicMock(spec=ModelConfigEntity) file = File( - id="file1", + file_id="file1", tenant_id="tenant1", - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://example.com/image.jpg", storage_key="", diff --git a/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py index 803afa54d7..28966242d8 100644 --- a/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py @@ -12,7 +12,7 @@ from graphon.model_runtime.entities.message_entities import ( ToolPromptMessage, UserPromptMessage, ) -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel from models.model import Conversation diff --git a/api/tests/unit_tests/core/prompt/test_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_prompt_transform.py index 9f9ea33695..5308c8e7b3 100644 --- a/api/tests/unit_tests/core/prompt/test_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_prompt_transform.py @@ -11,7 +11,7 @@ from graphon.model_runtime.entities.model_entities import ModelPropertyKey # from graphon.model_runtime.entities.message_entities import UserPromptMessage # from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey, ParameterRule # from graphon.model_runtime.entities.provider_entities import ProviderEntity -# from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +# from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel # from core.prompt.prompt_transform import PromptTransform diff --git a/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py b/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py index 64eb89590a..0220fb6d4a 100644 --- a/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py +++ b/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py @@ -1,12 +1,14 @@ """Primarily used for testing merged cell scenarios""" +import gc import io import os import tempfile +import warnings from collections import UserDict from pathlib import Path from types import SimpleNamespace -from unittest.mock import MagicMock +from unittest.mock import AsyncMock, MagicMock import pytest from docx import Document @@ -354,15 +356,46 @@ def test_init_expands_home_path_and_invalid_local_path(monkeypatch, tmp_path): WordExtractor("not-a-file", "tenant", "user") -def test_del_closes_temp_file(): +def test_close_closes_temp_file(): extractor = object.__new__(WordExtractor) + extractor._closed = False extractor.temp_file = MagicMock() - WordExtractor.__del__(extractor) + extractor.close() extractor.temp_file.close.assert_called_once() +def test_close_is_idempotent(): + extractor = object.__new__(WordExtractor) + extractor._closed = False + extractor.temp_file = MagicMock() + + extractor.close() + extractor.close() + + extractor.temp_file.close.assert_called_once() + + +def test_close_handles_async_close_mock(): + extractor = object.__new__(WordExtractor) + extractor._closed = False + extractor.temp_file = MagicMock() + extractor.temp_file.close = AsyncMock() + + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + extractor.close() + gc.collect() + + extractor.temp_file.close.assert_called_once() + assert not [ + warning + for warning in caught + if issubclass(warning.category, RuntimeWarning) and "AsyncMockMixin._execute_mock_call" in str(warning.message) + ] + + def test_extract_images_handles_invalid_external_cases(monkeypatch): class FakeTargetRef: def __contains__(self, item): diff --git a/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py b/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py index 8be1ac318c..18ae9fafc8 100644 --- a/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py +++ b/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py @@ -14,7 +14,7 @@ from core.repositories.human_input_repository import ( HumanInputFormSubmissionRepository, _WorkspaceMemberInfo, ) -from core.workflow.human_input_compat import ( +from core.workflow.human_input_adapter import ( EmailDeliveryConfig, EmailDeliveryMethod, EmailRecipients, diff --git a/api/tests/unit_tests/core/repositories/test_human_input_repository.py b/api/tests/unit_tests/core/repositories/test_human_input_repository.py index 1297a95df1..4248782d93 100644 --- a/api/tests/unit_tests/core/repositories/test_human_input_repository.py +++ b/api/tests/unit_tests/core/repositories/test_human_input_repository.py @@ -21,7 +21,7 @@ from core.repositories.human_input_repository import ( _InvalidTimeoutStatusError, _WorkspaceMemberInfo, ) -from core.workflow.human_input_compat import ( +from core.workflow.human_input_adapter import ( EmailDeliveryConfig, EmailDeliveryMethod, EmailRecipients, diff --git a/api/tests/unit_tests/core/test_file.py b/api/tests/unit_tests/core/test_file.py index f17927f16b..eab0176f41 100644 --- a/api/tests/unit_tests/core/test_file.py +++ b/api/tests/unit_tests/core/test_file.py @@ -6,9 +6,9 @@ from models.workflow import Workflow def test_file_to_dict(): file = File( - id="file1", + file_id="file1", tenant_id="tenant1", - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://example.com/image1.jpg", storage_key="storage_key", diff --git a/api/tests/unit_tests/core/test_provider_manager.py b/api/tests/unit_tests/core/test_provider_manager.py index f45b43082c..a5a542c94f 100644 --- a/api/tests/unit_tests/core/test_provider_manager.py +++ b/api/tests/unit_tests/core/test_provider_manager.py @@ -372,6 +372,78 @@ def test_get_configurations_binds_manager_runtime_to_provider_configuration( provider_configuration.bind_model_runtime.assert_called_once_with(manager._model_runtime) +def test_get_configurations_reuses_cached_result_for_same_tenant(mocker: MockerFixture, mock_provider_entity): + manager = _build_provider_manager(mocker) + provider_configuration = Mock() + provider_factory = Mock() + provider_factory.get_providers.return_value = [mock_provider_entity] + custom_configuration = SimpleNamespace(provider=None, models=[]) + system_configuration = SimpleNamespace(enabled=False, quota_configurations=[], current_quota_type=None) + + with ( + patch.object(manager, "_get_all_providers", return_value={"openai": []}) as mock_get_all_providers, + patch.object(manager, "_init_trial_provider_records", return_value={"openai": []}), + patch.object(manager, "_get_all_provider_models", return_value={"openai": []}), + patch.object(manager, "_get_all_preferred_model_providers", return_value={}), + patch.object(manager, "_get_all_provider_model_settings", return_value={}), + patch.object(manager, "_get_all_provider_load_balancing_configs", return_value={}), + patch.object(manager, "_get_all_provider_model_credentials", return_value={}), + patch.object(manager, "_to_custom_configuration", return_value=custom_configuration), + patch.object(manager, "_to_system_configuration", return_value=system_configuration), + patch.object(manager, "_to_model_settings", return_value=[]), + patch("core.provider_manager.ModelProviderFactory", return_value=provider_factory) as mock_factory_cls, + patch( + "core.provider_manager.ProviderConfiguration", + return_value=provider_configuration, + ) as mock_provider_configuration, + ): + first = manager.get_configurations("tenant-id") + second = manager.get_configurations("tenant-id") + + assert first is second + mock_get_all_providers.assert_called_once_with("tenant-id") + mock_factory_cls.assert_called_once_with(model_runtime=manager._model_runtime) + mock_provider_configuration.assert_called_once() + provider_configuration.bind_model_runtime.assert_called_once_with(manager._model_runtime) + + +def test_clear_configurations_cache_rebuilds_requested_tenant(mocker: MockerFixture, mock_provider_entity): + manager = _build_provider_manager(mocker) + provider_factory = Mock() + provider_factory.get_providers.return_value = [mock_provider_entity] + custom_configuration = SimpleNamespace(provider=None, models=[]) + system_configuration = SimpleNamespace(enabled=False, quota_configurations=[], current_quota_type=None) + provider_configuration_first = Mock() + provider_configuration_second = Mock() + + with ( + patch.object(manager, "_get_all_providers", return_value={"openai": []}) as mock_get_all_providers, + patch.object(manager, "_init_trial_provider_records", return_value={"openai": []}), + patch.object(manager, "_get_all_provider_models", return_value={"openai": []}), + patch.object(manager, "_get_all_preferred_model_providers", return_value={}), + patch.object(manager, "_get_all_provider_model_settings", return_value={}), + patch.object(manager, "_get_all_provider_load_balancing_configs", return_value={}), + patch.object(manager, "_get_all_provider_model_credentials", return_value={}), + patch.object(manager, "_to_custom_configuration", return_value=custom_configuration), + patch.object(manager, "_to_system_configuration", return_value=system_configuration), + patch.object(manager, "_to_model_settings", return_value=[]), + patch("core.provider_manager.ModelProviderFactory", return_value=provider_factory), + patch( + "core.provider_manager.ProviderConfiguration", + side_effect=[provider_configuration_first, provider_configuration_second], + ) as mock_provider_configuration, + ): + first = manager.get_configurations("tenant-id") + manager.clear_configurations_cache("tenant-id") + second = manager.get_configurations("tenant-id") + + assert first is not second + assert mock_get_all_providers.call_count == 2 + assert mock_provider_configuration.call_count == 2 + provider_configuration_first.bind_model_runtime.assert_called_once_with(manager._model_runtime) + provider_configuration_second.bind_model_runtime.assert_called_once_with(manager._model_runtime) + + def test_get_provider_model_bundle_returns_selected_model_type_instance(mocker: MockerFixture): manager = _build_provider_manager(mocker) provider_configuration = Mock() diff --git a/api/tests/unit_tests/core/variables/test_segment.py b/api/tests/unit_tests/core/variables/test_segment.py index 72052c8c05..9e07ea1b6d 100644 --- a/api/tests/unit_tests/core/variables/test_segment.py +++ b/api/tests/unit_tests/core/variables/test_segment.py @@ -1,8 +1,9 @@ import dataclasses +from typing import Annotated import orjson import pytest -from pydantic import BaseModel +from pydantic import BaseModel, Discriminator, Tag from core.helper import encrypter from core.workflow.system_variables import build_bootstrap_variables, build_system_variables @@ -12,17 +13,18 @@ from graphon.runtime import VariablePool from graphon.variables.segment_group import SegmentGroup from graphon.variables.segments import ( ArrayAnySegment, + ArrayBooleanSegment, ArrayFileSegment, ArrayNumberSegment, ArrayObjectSegment, ArrayStringSegment, + BooleanSegment, FileSegment, FloatSegment, IntegerSegment, NoneSegment, ObjectSegment, Segment, - SegmentUnion, StringSegment, get_segment_discriminator, ) @@ -47,6 +49,26 @@ from graphon.variables.variables import ( StringVariable, Variable, ) +from models.utils.file_input_compat import rebuild_serialized_graph_files_without_lookup + +type SegmentUnion = Annotated[ + ( + Annotated[NoneSegment, Tag(SegmentType.NONE)] + | Annotated[StringSegment, Tag(SegmentType.STRING)] + | Annotated[FloatSegment, Tag(SegmentType.FLOAT)] + | Annotated[IntegerSegment, Tag(SegmentType.INTEGER)] + | Annotated[ObjectSegment, Tag(SegmentType.OBJECT)] + | Annotated[FileSegment, Tag(SegmentType.FILE)] + | Annotated[BooleanSegment, Tag(SegmentType.BOOLEAN)] + | Annotated[ArrayAnySegment, Tag(SegmentType.ARRAY_ANY)] + | Annotated[ArrayStringSegment, Tag(SegmentType.ARRAY_STRING)] + | Annotated[ArrayNumberSegment, Tag(SegmentType.ARRAY_NUMBER)] + | Annotated[ArrayObjectSegment, Tag(SegmentType.ARRAY_OBJECT)] + | Annotated[ArrayFileSegment, Tag(SegmentType.ARRAY_FILE)] + | Annotated[ArrayBooleanSegment, Tag(SegmentType.ARRAY_BOOLEAN)] + ), + Discriminator(get_segment_discriminator), +] def _build_variable_pool( @@ -123,7 +145,7 @@ def create_test_file( ) -> File: """Factory function to create File objects for testing""" return File( - type=file_type, + file_type=file_type, transfer_method=transfer_method, filename=filename, extension=extension, @@ -160,7 +182,7 @@ class TestSegmentDumpAndLoad: assert restored == model def test_all_segments_serialization(self): - """Test serialization/deserialization of all segment types""" + """Test file-aware segment serialization through Dify's model boundary.""" # Create one instance of each segment type test_file = create_test_file() @@ -181,7 +203,7 @@ class TestSegmentDumpAndLoad: # Test serialization and deserialization model = _Segments(segments=all_segments) json_str = model.model_dump_json() - loaded = _Segments.model_validate_json(json_str) + loaded = _Segments.model_validate(rebuild_serialized_graph_files_without_lookup(orjson.loads(json_str))) # Verify all segments are preserved assert len(loaded.segments) == len(all_segments) @@ -202,7 +224,7 @@ class TestSegmentDumpAndLoad: assert loaded_segment.value == original.value def test_all_variables_serialization(self): - """Test serialization/deserialization of all variable types""" + """Test file-aware variable serialization through Dify's model boundary.""" # Create one instance of each variable type test_file = create_test_file() @@ -223,7 +245,7 @@ class TestSegmentDumpAndLoad: # Test serialization and deserialization model = _Variables(variables=all_variables) json_str = model.model_dump_json() - loaded = _Variables.model_validate_json(json_str) + loaded = _Variables.model_validate(rebuild_serialized_graph_files_without_lookup(orjson.loads(json_str))) # Verify all variables are preserved assert len(loaded.variables) == len(all_variables) diff --git a/api/tests/unit_tests/core/variables/test_segment_type_validation.py b/api/tests/unit_tests/core/variables/test_segment_type_validation.py index 94e788edb2..317fe99d37 100644 --- a/api/tests/unit_tests/core/variables/test_segment_type_validation.py +++ b/api/tests/unit_tests/core/variables/test_segment_type_validation.py @@ -35,7 +35,7 @@ def create_test_file( """Factory function to create File objects for testing.""" return File( tenant_id="test-tenant", - type=file_type, + file_type=file_type, transfer_method=transfer_method, filename=filename, extension=extension, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py index 76b2984a4b..9f3e3b00b9 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py @@ -1,12 +1,13 @@ -""" -Mock node factory for testing workflows with third-party service dependencies. +"""Mock node factory for third-party-service workflow tests. -This module provides a MockNodeFactory that automatically detects and mocks nodes -requiring external services (LLM, Agent, Tool, Knowledge Retrieval, HTTP Request). +The factory follows the same config adaptation path as production +`DifyNodeFactory.create_node()`, but swaps selected node classes for mock +implementations before instantiation. """ from typing import TYPE_CHECKING, Any +from core.workflow.human_input_adapter import adapt_node_config_for_graph from core.workflow.node_factory import DifyNodeFactory from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter from graphon.enums import BuiltinNodeTypes, NodeType @@ -82,20 +83,20 @@ class MockNodeFactory(DifyNodeFactory): :param node_config: Node configuration dictionary :return: Node instance (real or mocked) """ - typed_node_config = NodeConfigDictAdapter.validate_python(node_config) + typed_node_config = NodeConfigDictAdapter.validate_python(adapt_node_config_for_graph(node_config)) + node_id = typed_node_config["id"] node_data = typed_node_config["data"] node_type = node_data.type # Check if this node type should be mocked if node_type in self._mock_node_types: - node_id = typed_node_config["id"] - # Create mock node instance mock_class = self._mock_node_types[node_type] + resolved_node_data = self._validate_resolved_node_data(mock_class, node_data) if node_type == BuiltinNodeTypes.CODE: mock_instance = mock_class( - id=node_id, - config=typed_node_config, + node_id=node_id, + config=resolved_node_data, graph_init_params=self.graph_init_params, graph_runtime_state=self.graph_runtime_state, mock_config=self.mock_config, @@ -104,8 +105,8 @@ class MockNodeFactory(DifyNodeFactory): ) elif node_type == BuiltinNodeTypes.HTTP_REQUEST: mock_instance = mock_class( - id=node_id, - config=typed_node_config, + node_id=node_id, + config=resolved_node_data, graph_init_params=self.graph_init_params, graph_runtime_state=self.graph_runtime_state, mock_config=self.mock_config, @@ -120,8 +121,8 @@ class MockNodeFactory(DifyNodeFactory): BuiltinNodeTypes.PARAMETER_EXTRACTOR, }: mock_instance = mock_class( - id=node_id, - config=typed_node_config, + node_id=node_id, + config=resolved_node_data, graph_init_params=self.graph_init_params, graph_runtime_state=self.graph_runtime_state, mock_config=self.mock_config, @@ -130,8 +131,8 @@ class MockNodeFactory(DifyNodeFactory): ) else: mock_instance = mock_class( - id=node_id, - config=typed_node_config, + node_id=node_id, + config=resolved_node_data, graph_init_params=self.graph_init_params, graph_runtime_state=self.graph_runtime_state, mock_config=self.mock_config, @@ -140,7 +141,7 @@ class MockNodeFactory(DifyNodeFactory): return mock_instance # For non-mocked node types, use parent implementation - return super().create_node(typed_node_config) + return super().create_node(node_config) def should_mock_node(self, node_type: NodeType) -> bool: """ diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py index 971b9b2bbf..f9819c47ec 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py @@ -55,13 +55,14 @@ class MockNodeMixin: def __init__( self, - id: str, - config: Mapping[str, Any], + node_id: str, + config: Any, + *, graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", mock_config: Optional["MockConfig"] = None, **kwargs: Any, - ): + ) -> None: if isinstance(self, (LLMNode, QuestionClassifierNode, ParameterExtractorNode)): kwargs.setdefault("credentials_provider", MagicMock(spec=CredentialsProvider)) kwargs.setdefault("model_factory", MagicMock(spec=ModelFactory)) @@ -96,7 +97,7 @@ class MockNodeMixin: kwargs.setdefault("message_transformer", MagicMock()) super().__init__( - id=id, + node_id=node_id, config=config, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py index 55a329eba9..75bc6d05f7 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py @@ -139,8 +139,8 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()} start_node = StartNode( - id=start_config["id"], - config=start_config, + node_id=start_config["id"], + config=StartNodeData(title="Start", variables=[]), graph_init_params=graph_init_params, graph_runtime_state=runtime_state, ) @@ -154,8 +154,8 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor human_a_config = {"id": "human_a", "data": human_data.model_dump()} human_a = HumanInputNode( - id=human_a_config["id"], - config=human_a_config, + node_id=human_a_config["id"], + config=human_data, graph_init_params=graph_init_params, graph_runtime_state=runtime_state, form_repository=repo, @@ -164,8 +164,8 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor human_b_config = {"id": "human_b", "data": human_data.model_dump()} human_b = HumanInputNode( - id=human_b_config["id"], - config=human_b_config, + node_id=human_b_config["id"], + config=human_data, graph_init_params=graph_init_params, graph_runtime_state=runtime_state, form_repository=repo, @@ -182,8 +182,8 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor ) end_config = {"id": "end", "data": end_data.model_dump()} end_node = EndNode( - id=end_config["id"], - config=end_config, + node_id=end_config["id"], + config=end_data, graph_init_params=graph_init_params, graph_runtime_state=runtime_state, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py index 9c0ad25b58..76b4cd1ef4 100644 --- a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py +++ b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py @@ -9,6 +9,7 @@ from extensions.ext_database import db from graphon.enums import WorkflowNodeExecutionStatus from graphon.graph import Graph from graphon.nodes.answer.answer_node import AnswerNode +from graphon.nodes.answer.entities import AnswerNodeData from graphon.runtime import GraphRuntimeState, VariablePool from tests.workflow_test_utils import build_test_graph_init_params @@ -66,20 +67,15 @@ def test_execute_answer(): graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") - node_config = { - "id": "answer", - "data": { - "title": "123", - "type": "answer", - "answer": "Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.", - }, - } - node = AnswerNode( - id=str(uuid.uuid4()), + node_id=str(uuid.uuid4()), graph_init_params=init_params, graph_runtime_state=graph_runtime_state, - config=node_config, + config=AnswerNodeData( + title="123", + type="answer", + answer="Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.", + ), ) # Mock db.session.close() diff --git a/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py b/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py index 9cceadde49..d7ef781732 100644 --- a/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py @@ -1,5 +1,6 @@ from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY from core.workflow.nodes.datasource.datasource_node import DatasourceNode +from core.workflow.nodes.datasource.entities import DatasourceNodeData from graphon.enums import WorkflowNodeExecutionStatus from graphon.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent @@ -77,19 +78,16 @@ def test_datasource_node_delegates_to_manager_stream(mocker): mocker.patch("core.workflow.nodes.datasource.datasource_node.DatasourceManager", new=_Mgr) node = DatasourceNode( - id="n", - config={ - "id": "n", - "data": { - "type": "datasource", - "version": "1", - "title": "Datasource", - "provider_type": "plugin", - "provider_name": "p", - "plugin_id": "plug", - "datasource_name": "ds", - }, - }, + node_id="n", + config=DatasourceNodeData( + type="datasource", + version="1", + title="Datasource", + provider_type="plugin", + provider_name="p", + plugin_id="plug", + datasource_name="ds", + ), graph_init_params=gp, graph_runtime_state=gs, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py index a3cadc0681..2e89a2da3c 100644 --- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py @@ -12,7 +12,7 @@ from core.workflow.system_variables import build_system_variables from graphon.enums import WorkflowNodeExecutionStatus from graphon.file.file_manager import file_manager from graphon.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, HttpRequestNode, HttpRequestNodeConfig -from graphon.nodes.http_request.entities import HttpRequestNodeTimeout, Response +from graphon.nodes.http_request.entities import HttpRequestNodeData, HttpRequestNodeTimeout, Response from graphon.runtime import GraphRuntimeState, VariablePool from tests.workflow_test_utils import build_test_graph_init_params @@ -66,8 +66,8 @@ def test_get_default_config_uses_injected_http_request_config(): assert default_config["retry_config"]["max_retries"] == 7 -def test_get_default_config_with_malformed_http_request_config_raises_value_error(): - with pytest.raises(ValueError, match="http_request_config must be an HttpRequestNodeConfig instance"): +def test_get_default_config_with_malformed_http_request_config_raises_type_error(): + with pytest.raises(TypeError, match="http_request_config must be an HttpRequestNodeConfig instance"): HttpRequestNode.get_default_config(filters={HTTP_REQUEST_CONFIG_FILTER_KEY: "invalid"}) @@ -114,8 +114,8 @@ def _build_http_node( start_at=time.perf_counter(), ) return HttpRequestNode( - id="http-node", - config=node_config, + node_id="http-node", + config=HttpRequestNodeData.model_validate(node_data), graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, http_request_config=HTTP_REQUEST_CONFIG, diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py index 1d6a4da7c4..07430498e5 100644 --- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py +++ b/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py @@ -1,4 +1,4 @@ -from core.workflow.human_input_compat import EmailDeliveryConfig, EmailRecipients +from core.workflow.human_input_adapter import EmailDeliveryConfig, EmailRecipients from graphon.runtime import VariablePool diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py index c0e21d0bf7..0659984c76 100644 --- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py +++ b/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py @@ -19,7 +19,7 @@ from core.repositories.human_input_repository import ( HumanInputFormRecipientEntity, HumanInputFormRepository, ) -from core.workflow.human_input_compat import ( +from core.workflow.human_input_adapter import ( DeliveryMethodType, EmailDeliveryConfig, EmailDeliveryMethod, @@ -136,6 +136,26 @@ class InMemoryHumanInputFormRepository(HumanInputFormRepository): entity.status_value = HumanInputFormStatus.SUBMITTED +def _build_human_input_node( + *, + node_id: str, + node_data: HumanInputNodeData | Mapping[str, Any], + graph_init_params: GraphInitParams, + graph_runtime_state: GraphRuntimeState, + runtime: DifyHumanInputNodeRuntime, +) -> HumanInputNode: + typed_node_data = ( + node_data if isinstance(node_data, HumanInputNodeData) else HumanInputNodeData.model_validate(node_data) + ) + return HumanInputNode( + node_id=node_id, + config=typed_node_data, + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + runtime=runtime, + ) + + class TestDeliveryMethod: """Test DeliveryMethod entity.""" @@ -239,7 +259,7 @@ class TestUserAction: data[field_name] = value with pytest.raises(ValidationError) as exc_info: - UserAction(**data) + UserAction.model_validate(data) errors = exc_info.value.errors() assert any(error["loc"] == (field_name,) and error["type"] == "string_too_long" for error in errors) @@ -465,9 +485,9 @@ class TestHumanInputNodeVariableResolution: runtime = DifyHumanInputNodeRuntime(graph_init_params.run_context) runtime._build_form_repository = MagicMock(return_value=mock_repo) # type: ignore[attr-defined] - node = HumanInputNode( - id=config["id"], - config=config, + node = _build_human_input_node( + node_id=config["id"], + node_data=config["data"], graph_init_params=graph_init_params, graph_runtime_state=runtime_state, runtime=runtime, @@ -530,9 +550,9 @@ class TestHumanInputNodeVariableResolution: runtime = DifyHumanInputNodeRuntime(graph_init_params.run_context) runtime._build_form_repository = MagicMock(return_value=mock_repo) # type: ignore[attr-defined] - node = HumanInputNode( - id=config["id"], - config=config, + node = _build_human_input_node( + node_id=config["id"], + node_data=config["data"], graph_init_params=graph_init_params, graph_runtime_state=runtime_state, runtime=runtime, @@ -595,9 +615,9 @@ class TestHumanInputNodeVariableResolution: runtime = DifyHumanInputNodeRuntime(graph_init_params.run_context) runtime._build_form_repository = MagicMock(return_value=mock_repo) # type: ignore[attr-defined] - node = HumanInputNode( - id=config["id"], - config=config, + node = _build_human_input_node( + node_id=config["id"], + node_data=config["data"], graph_init_params=graph_init_params, graph_runtime_state=runtime_state, runtime=runtime, @@ -671,9 +691,9 @@ class TestHumanInputNodeVariableResolution: runtime = DifyHumanInputNodeRuntime(graph_init_params.run_context) runtime._build_form_repository = MagicMock(return_value=mock_repo) # type: ignore[attr-defined] - node = HumanInputNode( - id=config["id"], - config=config, + node = _build_human_input_node( + node_id=config["id"], + node_data=config["data"], graph_init_params=graph_init_params, graph_runtime_state=runtime_state, runtime=runtime, @@ -770,9 +790,9 @@ class TestHumanInputNodeRenderedContent: form_repository = InMemoryHumanInputFormRepository() runtime = DifyHumanInputNodeRuntime(graph_init_params.run_context) runtime._build_form_repository = MagicMock(return_value=form_repository) # type: ignore[attr-defined] - node = HumanInputNode( - id=config["id"], - config=config, + node = _build_human_input_node( + node_id=config["id"], + node_data=config["data"], graph_init_params=graph_init_params, graph_runtime_state=runtime_state, runtime=runtime, diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py index bc98028d5b..4a9438b14f 100644 --- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py +++ b/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py @@ -11,6 +11,7 @@ from graphon.graph_events import ( NodeRunHumanInputFormTimeoutEvent, NodeRunStartedEvent, ) +from graphon.nodes.human_input.entities import HumanInputNodeData from graphon.nodes.human_input.enums import HumanInputFormStatus from graphon.nodes.human_input.human_input_node import HumanInputNode from graphon.runtime import GraphRuntimeState, VariablePool @@ -25,6 +26,28 @@ class _FakeFormRepository: return self._form +def _create_human_input_node( + *, + config: dict, + graph_init_params: GraphInitParams, + graph_runtime_state: GraphRuntimeState, + repo: _FakeFormRepository, +) -> HumanInputNode: + node_data = ( + config["data"] + if isinstance(config["data"], HumanInputNodeData) + else HumanInputNodeData.model_validate(config["data"]) + ) + return HumanInputNode( + node_id=config["id"], + config=node_data, + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + form_repository=repo, + runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context), + ) + + def _build_node(form_content: str = "Please enter your name:\n\n{{#$output.name#}}") -> HumanInputNode: system_variables = default_system_variables() graph_runtime_state = GraphRuntimeState( @@ -80,13 +103,11 @@ def _build_node(form_content: str = "Please enter your name:\n\n{{#$output.name# ) repo = _FakeFormRepository(fake_form) - return HumanInputNode( - id="node-1", + return _create_human_input_node( config=config, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, - form_repository=repo, - runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context), + repo=repo, ) @@ -145,13 +166,11 @@ def _build_timeout_node() -> HumanInputNode: ) repo = _FakeFormRepository(fake_form) - return HumanInputNode( - id="node-1", + return _create_human_input_node( config=config, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, - form_repository=repo, - runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context), + repo=repo, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py index 82cc734274..8ffce39cd6 100644 --- a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py +++ b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py @@ -5,6 +5,7 @@ import pytest from core.workflow.system_variables import default_system_variables from graphon.entities import GraphInitParams +from graphon.nodes.iteration.entities import IterationNodeData from graphon.nodes.iteration.exc import IterationGraphNotFoundError from graphon.nodes.iteration.iteration_node import IterationNode from graphon.runtime import ( @@ -44,17 +45,14 @@ def _build_iteration_node( ) -> IterationNode: init_params = build_test_graph_init_params(graph_config=graph_config) return IterationNode( - id="iteration-node", - config={ - "id": "iteration-node", - "data": { - "type": "iteration", - "title": "Iteration", - "iterator_selector": ["start", "items"], - "output_selector": ["iteration-node", "output"], - "start_node_id": start_node_id, - }, - }, + node_id="iteration-node", + config=IterationNodeData( + type="iteration", + title="Iteration", + iterator_selector=["start", "items"], + output_selector=["iteration-node", "output"], + start_node_id=start_node_id, + ), graph_init_params=init_params, graph_runtime_state=runtime_state, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py b/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py index a6fca1bfb4..f254fc3d09 100644 --- a/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py @@ -93,6 +93,25 @@ def sample_chunks(): } +def _build_node( + *, + node_id: str, + node_data: KnowledgeIndexNodeData | dict[str, object], + graph_init_params, + graph_runtime_state, +) -> KnowledgeIndexNode: + return KnowledgeIndexNode( + node_id=node_id, + config=( + node_data + if isinstance(node_data, KnowledgeIndexNodeData) + else KnowledgeIndexNodeData.model_validate(node_data) + ), + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + ) + + class TestKnowledgeIndexNode: """ Test suite for KnowledgeIndexNode. @@ -115,9 +134,9 @@ class TestKnowledgeIndexNode: } # Act - node = KnowledgeIndexNode( - id=node_id, - config=config, + node = _build_node( + node_id=node_id, + node_data=config["data"], graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -143,9 +162,9 @@ class TestKnowledgeIndexNode: "data": sample_node_data.model_dump(), } - node = KnowledgeIndexNode( - id=node_id, - config=config, + node = _build_node( + node_id=node_id, + node_data=config["data"], graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -176,9 +195,9 @@ class TestKnowledgeIndexNode: "data": sample_node_data.model_dump(), } - node = KnowledgeIndexNode( - id=node_id, - config=config, + node = _build_node( + node_id=node_id, + node_data=config["data"], graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -212,9 +231,9 @@ class TestKnowledgeIndexNode: "data": sample_node_data.model_dump(), } - node = KnowledgeIndexNode( - id=node_id, - config=config, + node = _build_node( + node_id=node_id, + node_data=config["data"], graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -269,9 +288,9 @@ class TestKnowledgeIndexNode: "data": sample_node_data.model_dump(), } - node = KnowledgeIndexNode( - id=node_id, - config=config, + node = _build_node( + node_id=node_id, + node_data=config["data"], graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -332,9 +351,9 @@ class TestKnowledgeIndexNode: "data": sample_node_data.model_dump(), } - node = KnowledgeIndexNode( - id=node_id, - config=config, + node = _build_node( + node_id=node_id, + node_data=config["data"], graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -383,9 +402,9 @@ class TestKnowledgeIndexNode: "data": sample_node_data.model_dump(), } - node = KnowledgeIndexNode( - id=node_id, - config=config, + node = _build_node( + node_id=node_id, + node_data=config["data"], graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -440,9 +459,9 @@ class TestKnowledgeIndexNode: "data": sample_node_data.model_dump(), } - node = KnowledgeIndexNode( - id=node_id, - config=config, + node = _build_node( + node_id=node_id, + node_data=config["data"], graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -498,9 +517,9 @@ class TestKnowledgeIndexNode: "data": sample_node_data.model_dump(), } - node = KnowledgeIndexNode( - id=node_id, - config=config, + node = _build_node( + node_id=node_id, + node_data=config["data"], graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -536,9 +555,9 @@ class TestKnowledgeIndexNode: "data": sample_node_data.model_dump(), } - node = KnowledgeIndexNode( - id=node_id, - config=config, + node = _build_node( + node_id=node_id, + node_data=config["data"], graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -583,9 +602,9 @@ class TestKnowledgeIndexNode: "data": sample_node_data.model_dump(), } - node = KnowledgeIndexNode( - id=node_id, - config=config, + node = _build_node( + node_id=node_id, + node_data=config["data"], graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -623,9 +642,9 @@ class TestInvokeKnowledgeIndex: "data": sample_node_data.model_dump(), } - node = KnowledgeIndexNode( - id=node_id, - config=config, + node = _build_node( + node_id=node_id, + node_data=config["data"], graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py b/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py index 45e8ae7d20..e923ee761b 100644 --- a/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py @@ -14,7 +14,11 @@ from core.workflow.nodes.knowledge_retrieval.entities import ( SingleRetrievalConfig, ) from core.workflow.nodes.knowledge_retrieval.exc import RateLimitExceededError -from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode +from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import ( + KnowledgeRetrievalNode, + _normalize_metadata_filter_scalar, + _normalize_metadata_filter_sequence_item, +) from core.workflow.nodes.knowledge_retrieval.retrieval import RAGRetrievalProtocol, Source from core.workflow.system_variables import build_system_variables from graphon.enums import WorkflowNodeExecutionStatus @@ -85,6 +89,12 @@ def sample_node_data(): ) +def test_metadata_filter_normalizers_preserve_numeric_scalars_and_stringify_other_values() -> None: + assert _normalize_metadata_filter_scalar(3) == 3 + assert _normalize_metadata_filter_scalar(True) == "True" + assert _normalize_metadata_filter_sequence_item(4) == "4" + + class TestKnowledgeRetrievalNode: """ Test suite for KnowledgeRetrievalNode. @@ -106,8 +116,8 @@ class TestKnowledgeRetrievalNode: # Act node = KnowledgeRetrievalNode( - id=node_id, - config=config, + node_id=node_id, + config=KnowledgeRetrievalNodeData.model_validate(config["data"]), graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -135,8 +145,8 @@ class TestKnowledgeRetrievalNode: } node = KnowledgeRetrievalNode( - id=node_id, - config=config, + node_id=node_id, + config=KnowledgeRetrievalNodeData.model_validate(config["data"]), graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -194,8 +204,8 @@ class TestKnowledgeRetrievalNode: mock_rag_retrieval.llm_usage = LLMUsage.empty_usage() node = KnowledgeRetrievalNode( - id=node_id, - config=config, + node_id=node_id, + config=KnowledgeRetrievalNodeData.model_validate(config["data"]), graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -238,8 +248,8 @@ class TestKnowledgeRetrievalNode: mock_rag_retrieval.llm_usage = LLMUsage.empty_usage() node = KnowledgeRetrievalNode( - id=node_id, - config=config, + node_id=node_id, + config=KnowledgeRetrievalNodeData.model_validate(config["data"]), graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -274,8 +284,8 @@ class TestKnowledgeRetrievalNode: } node = KnowledgeRetrievalNode( - id=node_id, - config=config, + node_id=node_id, + config=KnowledgeRetrievalNodeData.model_validate(config["data"]), graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -309,8 +319,8 @@ class TestKnowledgeRetrievalNode: } node = KnowledgeRetrievalNode( - id=node_id, - config=config, + node_id=node_id, + config=KnowledgeRetrievalNodeData.model_validate(config["data"]), graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -350,8 +360,8 @@ class TestKnowledgeRetrievalNode: mock_rag_retrieval.llm_usage = LLMUsage.empty_usage() node = KnowledgeRetrievalNode( - id=node_id, - config=config, + node_id=node_id, + config=KnowledgeRetrievalNodeData.model_validate(config["data"]), graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -389,8 +399,8 @@ class TestKnowledgeRetrievalNode: mock_rag_retrieval.llm_usage = LLMUsage.empty_usage() node = KnowledgeRetrievalNode( - id=node_id, - config=config, + node_id=node_id, + config=KnowledgeRetrievalNodeData.model_validate(config["data"]), graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -470,8 +480,8 @@ class TestFetchDatasetRetriever: config = {"id": node_id, "data": node_data.model_dump()} node = KnowledgeRetrievalNode( - id=node_id, - config=config, + node_id=node_id, + config=KnowledgeRetrievalNodeData.model_validate(config["data"]), graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -507,8 +517,8 @@ class TestFetchDatasetRetriever: } node = KnowledgeRetrievalNode( - id=node_id, - config=config, + node_id=node_id, + config=KnowledgeRetrievalNodeData.model_validate(config["data"]), graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -562,8 +572,8 @@ class TestFetchDatasetRetriever: } node = KnowledgeRetrievalNode( - id=node_id, - config=config, + node_id=node_id, + config=KnowledgeRetrievalNodeData.model_validate(config["data"]), graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -610,8 +620,8 @@ class TestFetchDatasetRetriever: mock_graph_runtime_state.variable_pool.add(["start", "query"], StringSegment(value="readme")) node = KnowledgeRetrievalNode( - id=node_id, - config=config, + node_id=node_id, + config=KnowledgeRetrievalNodeData.model_validate(config["data"]), graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -671,8 +681,8 @@ class TestFetchDatasetRetriever: node_id = str(uuid.uuid4()) config = {"id": node_id, "data": node_data.model_dump()} node = KnowledgeRetrievalNode( - id=node_id, - config=config, + node_id=node_id, + config=KnowledgeRetrievalNodeData.model_validate(config["data"]), graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py b/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py index eca34f05be..388654f279 100644 --- a/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py +++ b/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py @@ -1,3 +1,4 @@ +from types import SimpleNamespace from unittest.mock import MagicMock import pytest @@ -5,6 +6,7 @@ import pytest from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY from graphon.entities import GraphInitParams from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus +from graphon.nodes.list_operator.entities import ListOperatorNodeData from graphon.nodes.list_operator.node import ListOperatorNode from graphon.runtime import GraphRuntimeState from graphon.variables import ArrayNumberSegment, ArrayStringSegment @@ -13,11 +15,28 @@ from graphon.variables import ArrayNumberSegment, ArrayStringSegment class TestListOperatorNode: """Comprehensive tests for ListOperatorNode.""" + @staticmethod + def _build_node(*, config, graph_init_params, graph_runtime_state): + return ListOperatorNode( + node_id="test", + config=config if isinstance(config, ListOperatorNodeData) else ListOperatorNodeData.model_validate(config), + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + ) + + @staticmethod + def _filter_by(comparison_operator: str, value: str) -> dict[str, object]: + return { + "enabled": True, + "conditions": [{"comparison_operator": comparison_operator, "value": value}], + } + @pytest.fixture def mock_graph_runtime_state(self): """Create mock GraphRuntimeState.""" mock_state = MagicMock(spec=GraphRuntimeState) mock_variable_pool = MagicMock() + mock_variable_pool.convert_template.side_effect = lambda value: SimpleNamespace(text=value) mock_state.variable_pool = mock_variable_pool return mock_state @@ -45,9 +64,8 @@ class TestListOperatorNode: def _create_node(config, mock_variable): mock_graph_runtime_state.variable_pool.get.return_value = mock_variable - return ListOperatorNode( - id="test", - config={"id": "test", "data": config}, + return self._build_node( + config=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -64,9 +82,8 @@ class TestListOperatorNode: "limit": {"enabled": False}, } - node = ListOperatorNode( - id="test", - config={"id": "test", "data": config}, + node = self._build_node( + config=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -109,9 +126,8 @@ class TestListOperatorNode: mock_var = ArrayStringSegment(value=[]) mock_graph_runtime_state.variable_pool.get.return_value = mock_var - node = ListOperatorNode( - id="test", - config={"id": "test", "data": config}, + node = self._build_node( + config=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -128,11 +144,7 @@ class TestListOperatorNode: config = { "title": "Test", "variable": ["sys", "items"], - "filter_by": { - "enabled": True, - "condition": "contains", - "value": "app", - }, + "filter_by": self._filter_by("contains", "app"), "order_by": {"enabled": False}, "limit": {"enabled": False}, } @@ -140,9 +152,8 @@ class TestListOperatorNode: mock_var = ArrayStringSegment(value=["apple", "banana", "pineapple", "cherry"]) mock_graph_runtime_state.variable_pool.get.return_value = mock_var - node = ListOperatorNode( - id="test", - config={"id": "test", "data": config}, + node = self._build_node( + config=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -157,11 +168,7 @@ class TestListOperatorNode: config = { "title": "Test", "variable": ["sys", "items"], - "filter_by": { - "enabled": True, - "condition": "not contains", - "value": "app", - }, + "filter_by": self._filter_by("not contains", "app"), "order_by": {"enabled": False}, "limit": {"enabled": False}, } @@ -169,9 +176,8 @@ class TestListOperatorNode: mock_var = ArrayStringSegment(value=["apple", "banana", "pineapple", "cherry"]) mock_graph_runtime_state.variable_pool.get.return_value = mock_var - node = ListOperatorNode( - id="test", - config={"id": "test", "data": config}, + node = self._build_node( + config=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -186,11 +192,7 @@ class TestListOperatorNode: config = { "title": "Test", "variable": ["sys", "numbers"], - "filter_by": { - "enabled": True, - "condition": ">", - "value": "5", - }, + "filter_by": self._filter_by(">", "5"), "order_by": {"enabled": False}, "limit": {"enabled": False}, } @@ -198,9 +200,8 @@ class TestListOperatorNode: mock_var = ArrayNumberSegment(value=[1, 3, 5, 7, 9, 11]) mock_graph_runtime_state.variable_pool.get.return_value = mock_var - node = ListOperatorNode( - id="test", - config={"id": "test", "data": config}, + node = self._build_node( + config=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -226,9 +227,8 @@ class TestListOperatorNode: mock_var = ArrayStringSegment(value=["cherry", "apple", "banana"]) mock_graph_runtime_state.variable_pool.get.return_value = mock_var - node = ListOperatorNode( - id="test", - config={"id": "test", "data": config}, + node = self._build_node( + config=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -254,9 +254,8 @@ class TestListOperatorNode: mock_var = ArrayStringSegment(value=["cherry", "apple", "banana"]) mock_graph_runtime_state.variable_pool.get.return_value = mock_var - node = ListOperatorNode( - id="test", - config={"id": "test", "data": config}, + node = self._build_node( + config=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -282,9 +281,8 @@ class TestListOperatorNode: mock_var = ArrayStringSegment(value=["apple", "banana", "cherry", "date"]) mock_graph_runtime_state.variable_pool.get.return_value = mock_var - node = ListOperatorNode( - id="test", - config={"id": "test", "data": config}, + node = self._build_node( + config=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -299,11 +297,7 @@ class TestListOperatorNode: config = { "title": "Test", "variable": ["sys", "numbers"], - "filter_by": { - "enabled": True, - "condition": ">", - "value": "3", - }, + "filter_by": self._filter_by(">", "3"), "order_by": { "enabled": True, "value": "desc", @@ -317,9 +311,8 @@ class TestListOperatorNode: mock_var = ArrayNumberSegment(value=[1, 2, 3, 4, 5, 6, 7, 8, 9]) mock_graph_runtime_state.variable_pool.get.return_value = mock_var - node = ListOperatorNode( - id="test", - config={"id": "test", "data": config}, + node = self._build_node( + config=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -341,9 +334,8 @@ class TestListOperatorNode: mock_graph_runtime_state.variable_pool.get.return_value = None - node = ListOperatorNode( - id="test", - config={"id": "test", "data": config}, + node = self._build_node( + config=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -366,9 +358,8 @@ class TestListOperatorNode: mock_var = ArrayStringSegment(value=["first", "middle", "last"]) mock_graph_runtime_state.variable_pool.get.return_value = mock_var - node = ListOperatorNode( - id="test", - config={"id": "test", "data": config}, + node = self._build_node( + config=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -384,11 +375,7 @@ class TestListOperatorNode: config = { "title": "Test", "variable": ["sys", "items"], - "filter_by": { - "enabled": True, - "condition": "start with", - "value": "app", - }, + "filter_by": self._filter_by("start with", "app"), "order_by": {"enabled": False}, "limit": {"enabled": False}, } @@ -396,9 +383,8 @@ class TestListOperatorNode: mock_var = ArrayStringSegment(value=["apple", "application", "banana", "apricot"]) mock_graph_runtime_state.variable_pool.get.return_value = mock_var - node = ListOperatorNode( - id="test", - config={"id": "test", "data": config}, + node = self._build_node( + config=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -413,11 +399,7 @@ class TestListOperatorNode: config = { "title": "Test", "variable": ["sys", "items"], - "filter_by": { - "enabled": True, - "condition": "end with", - "value": "le", - }, + "filter_by": self._filter_by("end with", "le"), "order_by": {"enabled": False}, "limit": {"enabled": False}, } @@ -425,9 +407,8 @@ class TestListOperatorNode: mock_var = ArrayStringSegment(value=["apple", "banana", "pineapple", "table"]) mock_graph_runtime_state.variable_pool.get.return_value = mock_var - node = ListOperatorNode( - id="test", - config={"id": "test", "data": config}, + node = self._build_node( + config=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -442,11 +423,7 @@ class TestListOperatorNode: config = { "title": "Test", "variable": ["sys", "numbers"], - "filter_by": { - "enabled": True, - "condition": "=", - "value": "5", - }, + "filter_by": self._filter_by("=", "5"), "order_by": {"enabled": False}, "limit": {"enabled": False}, } @@ -454,9 +431,8 @@ class TestListOperatorNode: mock_var = ArrayNumberSegment(value=[1, 3, 5, 5, 7, 9]) mock_graph_runtime_state.variable_pool.get.return_value = mock_var - node = ListOperatorNode( - id="test", - config={"id": "test", "data": config}, + node = self._build_node( + config=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -471,11 +447,7 @@ class TestListOperatorNode: config = { "title": "Test", "variable": ["sys", "numbers"], - "filter_by": { - "enabled": True, - "condition": "≠", - "value": "5", - }, + "filter_by": self._filter_by("≠", "5"), "order_by": {"enabled": False}, "limit": {"enabled": False}, } @@ -483,9 +455,8 @@ class TestListOperatorNode: mock_var = ArrayNumberSegment(value=[1, 3, 5, 7, 9]) mock_graph_runtime_state.variable_pool.get.return_value = mock_var - node = ListOperatorNode( - id="test", - config={"id": "test", "data": config}, + node = self._build_node( + config=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -511,9 +482,8 @@ class TestListOperatorNode: mock_var = ArrayNumberSegment(value=[9, 3, 7, 1, 5]) mock_graph_runtime_state.variable_pool.get.return_value = mock_var - node = ListOperatorNode( - id="test", - config={"id": "test", "data": config}, + node = self._build_node( + config=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py index 4186bbdc93..212ad07bd3 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py @@ -71,8 +71,8 @@ def _build_image_file( mime_type: str = "image/png", ) -> File: return File( - id=file_id, - type=FileType.IMAGE, + file_id=file_id, + file_type=FileType.IMAGE, filename=f"{file_id}{extension}", transfer_method=FileTransferMethod.REMOTE_URL, remote_url=remote_url, @@ -95,6 +95,8 @@ def variable_pool() -> VariablePool: def _fetch_prompt_messages_with_mocked_content(content): variable_pool = VariablePool.empty() model_instance = mock.MagicMock(spec=ModelInstance) + model_schema = mock.MagicMock() + model_schema.supports_prompt_content_type.side_effect = lambda content_type: content_type == "text" prompt_template = [ LLMNodeChatModelMessage( text="You are a classifier.", @@ -106,7 +108,7 @@ def _fetch_prompt_messages_with_mocked_content(content): with ( mock.patch( "graphon.nodes.llm.llm_utils.fetch_model_schema", - return_value=mock.MagicMock(features=[]), + return_value=model_schema, ), mock.patch( "graphon.nodes.llm.llm_utils.handle_list_messages", diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py index b1f81b6c48..c707cf28cd 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py @@ -140,8 +140,8 @@ def _build_image_file( mime_type: str = "image/png", ) -> File: return File( - id=file_id, - type=FileType.IMAGE, + file_id=file_id, + file_type=FileType.IMAGE, filename=f"{file_id}{extension}", transfer_method=FileTransferMethod.REMOTE_URL, remote_url=remote_url, @@ -205,14 +205,10 @@ def llm_node( mock_credentials_provider = mock.MagicMock(spec=CredentialsProvider) mock_model_factory = mock.MagicMock(spec=ModelFactory) mock_prompt_message_serializer = mock.MagicMock(spec=PromptMessageSerializerProtocol) - node_config = { - "id": "1", - "data": llm_node_data.model_dump(), - } http_client = mock.MagicMock() node = LLMNode( - id="1", - config=node_config, + node_id="1", + config=llm_node_data, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, credentials_provider=mock_credentials_provider, @@ -403,8 +399,8 @@ def test_dify_model_access_adapters_call_managers(): def test_fetch_files_with_file_segment(): file = File( - id="1", - type=FileType.IMAGE, + file_id="1", + file_type=FileType.IMAGE, filename="test.jpg", transfer_method=FileTransferMethod.LOCAL_FILE, related_id="1", @@ -420,16 +416,16 @@ def test_fetch_files_with_file_segment(): def test_fetch_files_with_array_file_segment(): files = [ File( - id="1", - type=FileType.IMAGE, + file_id="1", + file_type=FileType.IMAGE, filename="test1.jpg", transfer_method=FileTransferMethod.LOCAL_FILE, related_id="1", storage_key="", ), File( - id="2", - type=FileType.IMAGE, + file_id="2", + file_type=FileType.IMAGE, filename="test2.jpg", transfer_method=FileTransferMethod.LOCAL_FILE, related_id="2", @@ -1174,14 +1170,10 @@ def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_stat mock_credentials_provider = mock.MagicMock(spec=CredentialsProvider) mock_model_factory = mock.MagicMock(spec=ModelFactory) mock_prompt_message_serializer = mock.MagicMock(spec=PromptMessageSerializerProtocol) - node_config = { - "id": "1", - "data": llm_node_data.model_dump(), - } http_client = mock.MagicMock() node = LLMNode( - id="1", - config=node_config, + node_id="1", + config=llm_node_data, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, credentials_provider=mock_credentials_provider, @@ -1203,8 +1195,8 @@ class TestLLMNodeSaveMultiModalImageOutput: mime_type="image/png", ) mock_file = File( - id=str(uuid.uuid4()), - type=FileType.IMAGE, + file_id=str(uuid.uuid4()), + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.TOOL_FILE, related_id=str(uuid.uuid4()), filename="test-file.png", @@ -1233,8 +1225,8 @@ class TestLLMNodeSaveMultiModalImageOutput: mime_type="image/jpg", ) mock_file = File( - id=str(uuid.uuid4()), - type=FileType.IMAGE, + file_id=str(uuid.uuid4()), + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.TOOL_FILE, related_id=str(uuid.uuid4()), filename="test-file.png", @@ -1291,8 +1283,8 @@ class TestSaveMultimodalOutputAndConvertResultToMarkdown: image_b64_data = base64.b64encode(image_raw_data).decode() mock_saved_file = File( - id=str(uuid.uuid4()), - type=FileType.IMAGE, + file_id=str(uuid.uuid4()), + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.TOOL_FILE, filename="test.png", extension=".png", @@ -1457,7 +1449,6 @@ def test_invoke_llm_dispatches_to_expected_model_method(structured_output_enable file_saver=file_saver, file_outputs=[], node_id="node-1", - node_type=LLMNode.node_type, reasoning_format="separated", ) ) @@ -1514,7 +1505,6 @@ def test_handle_invoke_result_streaming_collects_text_metrics_and_structured_out file_saver=mock.MagicMock(spec=LLMFileSaver), file_outputs=[], node_id="node-1", - node_type=LLMNode.node_type, model_instance=_build_prepared_llm_mock(), reasoning_format="separated", request_start_time=1.0, @@ -1552,7 +1542,6 @@ def test_handle_invoke_result_wraps_structured_output_parse_errors(): file_saver=mock.MagicMock(spec=LLMFileSaver), file_outputs=[], node_id="node-1", - node_type=LLMNode.node_type, model_instance=model_instance, ) ) diff --git a/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py b/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py index bc44ececd8..892f6cc586 100644 --- a/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py +++ b/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py @@ -13,6 +13,28 @@ from graphon.template_rendering import TemplateRenderError from tests.workflow_test_utils import build_test_graph_init_params +def _build_template_transform_node( + *, + node_data, + graph_init_params, + graph_runtime_state, + node_id: str = "test_node", + **kwargs, +) -> TemplateTransformNode: + typed_node_data = ( + node_data + if isinstance(node_data, TemplateTransformNodeData) + else TemplateTransformNodeData.model_validate(node_data) + ) + return TemplateTransformNode( + node_id=node_id, + config=typed_node_data, + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + **kwargs, + ) + + class TestTemplateTransformNode: """Comprehensive test suite for TemplateTransformNode.""" @@ -59,9 +81,8 @@ class TestTemplateTransformNode: def test_node_initialization(self, basic_node_data, mock_graph_runtime_state, graph_init_params): """Test that TemplateTransformNode initializes correctly.""" mock_renderer = MagicMock() - node = TemplateTransformNode( - id="test_node", - config={"id": "test_node", "data": basic_node_data}, + node = _build_template_transform_node( + node_data=basic_node_data, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, jinja2_template_renderer=mock_renderer, @@ -75,9 +96,8 @@ class TestTemplateTransformNode: def test_get_title(self, basic_node_data, mock_graph_runtime_state, graph_init_params): """Test _get_title method.""" mock_renderer = MagicMock() - node = TemplateTransformNode( - id="test_node", - config={"id": "test_node", "data": basic_node_data}, + node = _build_template_transform_node( + node_data=basic_node_data, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, jinja2_template_renderer=mock_renderer, @@ -88,9 +108,8 @@ class TestTemplateTransformNode: def test_get_description(self, basic_node_data, mock_graph_runtime_state, graph_init_params): """Test _get_description method.""" mock_renderer = MagicMock() - node = TemplateTransformNode( - id="test_node", - config={"id": "test_node", "data": basic_node_data}, + node = _build_template_transform_node( + node_data=basic_node_data, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, jinja2_template_renderer=mock_renderer, @@ -108,9 +127,8 @@ class TestTemplateTransformNode: } mock_renderer = MagicMock() - node = TemplateTransformNode( - id="test_node", - config={"id": "test_node", "data": node_data}, + node = _build_template_transform_node( + node_data=node_data, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, jinja2_template_renderer=mock_renderer, @@ -143,9 +161,8 @@ class TestTemplateTransformNode: mock_renderer = MagicMock() with pytest.raises(ValueError, match="max_output_length must be a positive integer"): - TemplateTransformNode( - id="test_node", - config={"id": "test_node", "data": basic_node_data}, + _build_template_transform_node( + node_data=basic_node_data, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, jinja2_template_renderer=mock_renderer, @@ -170,9 +187,8 @@ class TestTemplateTransformNode: mock_renderer = MagicMock() mock_renderer.render_template.return_value = "Hello Alice, you are 30 years old!" - node = TemplateTransformNode( - id="test_node", - config={"id": "test_node", "data": basic_node_data}, + node = _build_template_transform_node( + node_data=basic_node_data, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, jinja2_template_renderer=mock_renderer, @@ -198,9 +214,8 @@ class TestTemplateTransformNode: mock_renderer = MagicMock() mock_renderer.render_template.return_value = "Value: " - node = TemplateTransformNode( - id="test_node", - config={"id": "test_node", "data": node_data}, + node = _build_template_transform_node( + node_data=node_data, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, jinja2_template_renderer=mock_renderer, @@ -218,9 +233,8 @@ class TestTemplateTransformNode: mock_renderer = MagicMock() mock_renderer.render_template.side_effect = TemplateRenderError("Template syntax error") - node = TemplateTransformNode( - id="test_node", - config={"id": "test_node", "data": basic_node_data}, + node = _build_template_transform_node( + node_data=basic_node_data, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, jinja2_template_renderer=mock_renderer, @@ -238,9 +252,8 @@ class TestTemplateTransformNode: mock_renderer = MagicMock() mock_renderer.render_template.return_value = "This is a very long output that exceeds the limit" - node = TemplateTransformNode( - id="test_node", - config={"id": "test_node", "data": basic_node_data}, + node = _build_template_transform_node( + node_data=basic_node_data, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, jinja2_template_renderer=mock_renderer, @@ -260,9 +273,8 @@ class TestTemplateTransformNode: mock_renderer = MagicMock() mock_renderer.render_template.return_value = "1234567890" - node = TemplateTransformNode( - id="test_node", - config={"id": "test_node", "data": basic_node_data}, + node = _build_template_transform_node( + node_data=basic_node_data, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, jinja2_template_renderer=mock_renderer, @@ -302,9 +314,8 @@ class TestTemplateTransformNode: mock_renderer = MagicMock() mock_renderer.render_template.return_value = "apple, banana, orange (Total: 3)" - node = TemplateTransformNode( - id="test_node", - config={"id": "test_node", "data": node_data}, + node = _build_template_transform_node( + node_data=node_data, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, jinja2_template_renderer=mock_renderer, @@ -375,8 +386,8 @@ class TestTemplateTransformNode: ) assert mapping == { - "node_123.var1": ["sys", "input1"], - "node_123.empty_selector": [], + "node_123.var1": ("sys", "input1"), + "node_123.empty_selector": (), } def test_extract_variable_selector_to_variable_mapping_ignores_invalid_entries(self): @@ -409,9 +420,8 @@ class TestTemplateTransformNode: mock_renderer = MagicMock() mock_renderer.render_template.return_value = "This is a static message." - node = TemplateTransformNode( - id="test_node", - config={"id": "test_node", "data": node_data}, + node = _build_template_transform_node( + node_data=node_data, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, jinja2_template_renderer=mock_renderer, @@ -448,9 +458,8 @@ class TestTemplateTransformNode: mock_renderer = MagicMock() mock_renderer.render_template.return_value = "Total: $31.5" - node = TemplateTransformNode( - id="test_node", - config={"id": "test_node", "data": node_data}, + node = _build_template_transform_node( + node_data=node_data, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, jinja2_template_renderer=mock_renderer, @@ -477,9 +486,8 @@ class TestTemplateTransformNode: mock_renderer = MagicMock() mock_renderer.render_template.return_value = "Name: John Doe, Email: john@example.com" - node = TemplateTransformNode( - id="test_node", - config={"id": "test_node", "data": node_data}, + node = _build_template_transform_node( + node_data=node_data, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, jinja2_template_renderer=mock_renderer, @@ -507,9 +515,8 @@ class TestTemplateTransformNode: mock_renderer = MagicMock() mock_renderer.render_template.return_value = "Tags: #python #ai #workflow " - node = TemplateTransformNode( - id="test_node", - config={"id": "test_node", "data": node_data}, + node = _build_template_transform_node( + node_data=node_data, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, jinja2_template_renderer=mock_renderer, diff --git a/api/tests/unit_tests/core/workflow/nodes/template_transform/test_template_transform_node.py b/api/tests/unit_tests/core/workflow/nodes/template_transform/test_template_transform_node.py index 636237e56e..a846efbb43 100644 --- a/api/tests/unit_tests/core/workflow/nodes/template_transform/test_template_transform_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/template_transform/test_template_transform_node.py @@ -4,6 +4,7 @@ import pytest from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from graphon.nodes.base.entities import VariableSelector +from graphon.nodes.template_transform.entities import TemplateTransformNodeData from graphon.nodes.template_transform.template_transform_node import ( DEFAULT_TEMPLATE_TRANSFORM_MAX_OUTPUT_LENGTH, TemplateTransformNode, @@ -37,15 +38,13 @@ def mock_graph_runtime_state(): def test_node_uses_default_max_output_length_when_not_overridden(graph_init_params, mock_graph_runtime_state): node = TemplateTransformNode( - id="test_node", - config={ - "id": "test_node", - "data": { - "title": "Template Transform", - "variables": [], - "template": "hello", - }, - }, + node_id="test_node", + config=TemplateTransformNodeData( + title="Template Transform", + type="template-transform", + variables=[], + template="hello", + ), graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, jinja2_template_renderer=MagicMock(), @@ -70,5 +69,5 @@ def test_extract_variable_selector_to_variable_mapping_accepts_mixed_valid_entri assert mapping == { "node_123.validated": ["sys", "input1"], - "node_123.raw": ["sys", "input2"], + "node_123.raw": ("sys", "input2"), } diff --git a/api/tests/unit_tests/core/workflow/nodes/test_base_node.py b/api/tests/unit_tests/core/workflow/nodes/test_base_node.py index 0522dd9d14..364408ead6 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_base_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_base_node.py @@ -7,7 +7,6 @@ from core.workflow.node_runtime import resolve_dify_run_context from core.workflow.system_variables import build_system_variables from graphon.entities import GraphInitParams from graphon.entities.base_node_data import BaseNodeData -from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter from graphon.enums import BuiltinNodeTypes from graphon.nodes.base.node import Node from graphon.runtime import GraphRuntimeState, VariablePool @@ -42,17 +41,19 @@ def _build_context(graph_config: Mapping[str, object]) -> tuple[GraphInitParams, return init_params, runtime_state -def _build_node_config() -> NodeConfigDict: - return NodeConfigDictAdapter.validate_python( - { - "id": "node-1", - "data": { - "type": BuiltinNodeTypes.ANSWER, - "title": "Sample", - "foo": "bar", - }, - } - ) +def _build_node_config() -> dict[str, object]: + return { + "id": "node-1", + "data": _SampleNodeData( + type=BuiltinNodeTypes.ANSWER, + title="Sample", + foo="bar", + ), + } + + +def _build_node_data() -> _SampleNodeData: + return _build_node_config()["data"] # type: ignore[return-value] def test_node_hydrates_data_during_initialization(): @@ -60,8 +61,8 @@ def test_node_hydrates_data_during_initialization(): init_params, runtime_state = _build_context(graph_config) node = _SampleNode( - id="node-1", - config=_build_node_config(), + node_id="node-1", + config=_build_node_data(), graph_init_params=init_params, graph_runtime_state=runtime_state, ) @@ -86,8 +87,8 @@ def test_node_accepts_invoke_from_enum(): ) node = _SampleNode( - id="node-1", - config=_build_node_config(), + node_id="node-1", + config=_build_node_data(), graph_init_params=init_params, graph_runtime_state=runtime_state, ) @@ -117,13 +118,7 @@ def test_missing_generic_argument_raises_type_error(): def test_base_node_data_keeps_dict_style_access_compatibility(): - node_data = _SampleNodeData.model_validate( - { - "type": BuiltinNodeTypes.ANSWER, - "title": "Sample", - "foo": "bar", - } - ) + node_data = _SampleNodeData(type=BuiltinNodeTypes.ANSWER, title="Sample", foo="bar") assert node_data["foo"] == "bar" assert node_data.get("foo") == "bar" @@ -133,21 +128,19 @@ def test_base_node_data_keeps_dict_style_access_compatibility(): def test_node_hydration_preserves_compatibility_extra_fields(): graph_config: dict[str, object] = {} init_params, runtime_state = _build_context(graph_config) - node_config = NodeConfigDictAdapter.validate_python( - { - "id": "node-1", - "data": { - "type": BuiltinNodeTypes.ANSWER, - "title": "Sample", - "foo": "bar", - "compat_flag": True, - }, - } - ) + node_config = { + "id": "node-1", + "data": _SampleNodeData( + type=BuiltinNodeTypes.ANSWER, + title="Sample", + foo="bar", + compat_flag=True, + ), + } node = _SampleNode( - id="node-1", - config=node_config, + node_id="node-1", + config=node_config["data"], graph_init_params=init_params, graph_runtime_state=runtime_state, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py b/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py index 87ec2d5bce..dd75b32593 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py @@ -11,14 +11,16 @@ from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus from graphon.file import File, FileTransferMethod from graphon.node_events import NodeRunResult from graphon.nodes.document_extractor import DocumentExtractorNode, DocumentExtractorNodeData +from graphon.nodes.document_extractor.exc import TextExtractionError, UnsupportedFileTypeError from graphon.nodes.document_extractor.node import ( _extract_text_from_docx, _extract_text_from_excel, + _extract_text_from_file, _extract_text_from_pdf, _extract_text_from_plain_text, _normalize_docx_zip, ) -from graphon.variables import ArrayFileSegment +from graphon.variables import ArrayFileSegment, FileSegment from graphon.variables.segments import ArrayStringSegment from graphon.variables.variables import StringVariable from tests.workflow_test_utils import build_test_graph_init_params @@ -44,11 +46,10 @@ def document_extractor_node(graph_init_params): title="Test Document Extractor", variable_selector=["node_id", "variable_name"], ) - node_config = {"id": "test_node_id", "data": node_data.model_dump()} http_client = Mock() node = DocumentExtractorNode( - id="test_node_id", - config=node_config, + node_id="test_node_id", + config=node_data, graph_init_params=graph_init_params, graph_runtime_state=Mock(), http_client=http_client, @@ -341,7 +342,7 @@ def test_extract_text_from_excel_sheet_parse_error(mock_excel_file): # Mock ExcelFile mock_excel_instance = Mock() mock_excel_instance.sheet_names = ["GoodSheet", "BadSheet"] - mock_excel_instance.parse.side_effect = [df, Exception("Parse error")] + mock_excel_instance.parse.side_effect = [df, TypeError("Parse error")] mock_excel_file.return_value = mock_excel_instance file_content = b"fake_excel_mixed_content" @@ -386,7 +387,7 @@ def test_extract_text_from_excel_all_sheets_fail(mock_excel_file): # Mock ExcelFile mock_excel_instance = Mock() mock_excel_instance.sheet_names = ["BadSheet1", "BadSheet2"] - mock_excel_instance.parse.side_effect = [Exception("Error 1"), Exception("Error 2")] + mock_excel_instance.parse.side_effect = [TypeError("Error 1"), TypeError("Error 2")] mock_excel_file.return_value = mock_excel_instance file_content = b"fake_excel_all_bad_sheets" @@ -397,6 +398,12 @@ def test_extract_text_from_excel_all_sheets_fail(mock_excel_file): assert mock_excel_instance.parse.call_count == 2 +@patch("pandas.ExcelFile", side_effect=RuntimeError("broken workbook")) +def test_extract_text_from_excel_wraps_workbook_open_errors(mock_excel_file): + with pytest.raises(TextExtractionError, match="Failed to extract text from Excel file: broken workbook"): + _extract_text_from_excel(b"broken") + + @patch("pandas.ExcelFile") def test_extract_text_from_excel_numeric_type_column(mock_excel_file): """Test extracting text from Excel file with numeric column names.""" @@ -420,6 +427,103 @@ def test_extract_text_from_excel_numeric_type_column(mock_excel_file): assert expected_manual == result +@pytest.mark.parametrize( + ("extension", "mime_type"), + [ + (".xlsx", "text/plain"), + (None, "application/vnd.ms-excel"), + ], +) +def test_extract_text_from_file_routes_excel_inputs(document_extractor_node, extension, mime_type): + file = Mock(spec=File) + file.extension = extension + file.mime_type = mime_type + + with ( + patch( + "graphon.nodes.document_extractor.node._download_file_content", + return_value=b"excel", + ), + patch( + "graphon.nodes.document_extractor.node._extract_text_from_excel", + return_value="excel text", + ) as mock_extract, + ): + result = _extract_text_from_file( + document_extractor_node.http_client, + file, + unstructured_api_config=document_extractor_node._unstructured_api_config, + ) + + assert result == "excel text" + mock_extract.assert_called_once_with(b"excel") + + +def test_extract_text_from_file_rejects_missing_extension_and_mime_type(document_extractor_node): + file = Mock(spec=File) + file.extension = None + file.mime_type = None + + with patch( + "graphon.nodes.document_extractor.node._download_file_content", + return_value=b"unknown", + ): + with pytest.raises(UnsupportedFileTypeError, match="Unable to determine file type"): + _extract_text_from_file( + document_extractor_node.http_client, + file, + unstructured_api_config=document_extractor_node._unstructured_api_config, + ) + + +def test_run_list_file_extraction_error_returns_failed(document_extractor_node, mock_graph_runtime_state): + document_extractor_node.graph_runtime_state = mock_graph_runtime_state + file_list = Mock(spec=ArrayFileSegment) + file_list.value = [Mock(spec=File)] + mock_graph_runtime_state.variable_pool.get.return_value = file_list + + with patch( + "graphon.nodes.document_extractor.node._extract_text_from_file", + side_effect=TextExtractionError("bad file"), + ): + result = document_extractor_node._run() + + assert result.status == WorkflowNodeExecutionStatus.FAILED + assert result.error == "bad file" + + +def test_run_single_file_segment_extraction_error_returns_failed(document_extractor_node, mock_graph_runtime_state): + document_extractor_node.graph_runtime_state = mock_graph_runtime_state + file_segment = Mock(spec=FileSegment) + file_segment.value = Mock(spec=File) + mock_graph_runtime_state.variable_pool.get.return_value = file_segment + + with patch( + "graphon.nodes.document_extractor.node._extract_text_from_file", + side_effect=TextExtractionError("single file failed"), + ): + result = document_extractor_node._run() + + assert result.status == WorkflowNodeExecutionStatus.FAILED + assert result.error == "single file failed" + + +def test_run_single_file_segment_returns_string_output(document_extractor_node, mock_graph_runtime_state): + document_extractor_node.graph_runtime_state = mock_graph_runtime_state + file_segment = Mock(spec=FileSegment) + file_segment.value = Mock(spec=File) + mock_graph_runtime_state.variable_pool.get.return_value = file_segment + + with patch( + "graphon.nodes.document_extractor.node._extract_text_from_file", + return_value="single file text", + ): + result = document_extractor_node._run() + + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs == {"text": "single file text"} + + def _make_docx_zip(use_backslash: bool) -> bytes: """Helper to build a minimal in-memory DOCX zip. diff --git a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py index 782750e02e..aa9a1360b0 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py @@ -19,6 +19,20 @@ from graphon.variables import ArrayFileSegment from tests.workflow_test_utils import build_test_graph_init_params +def _build_if_else_node( + *, + node_data: IfElseNodeData | dict[str, object], + init_params, + graph_runtime_state, +) -> IfElseNode: + return IfElseNode( + node_id=str(uuid.uuid4()), + graph_init_params=init_params, + graph_runtime_state=graph_runtime_state, + config=node_data if isinstance(node_data, IfElseNodeData) else IfElseNodeData.model_validate(node_data), + ) + + def test_execute_if_else_result_true(): graph_config = {"edges": [], "nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}]} @@ -61,9 +75,8 @@ def test_execute_if_else_result_true(): ) graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") - node_config = { - "id": "if-else", - "data": { + node = _build_if_else_node( + node_data={ "title": "123", "type": "if-else", "logical_operator": "and", @@ -104,13 +117,8 @@ def test_execute_if_else_result_true(): {"comparison_operator": "not null", "variable_selector": ["start", "not_null"]}, ], }, - } - - node = IfElseNode( - id=str(uuid.uuid4()), - graph_init_params=init_params, + init_params=init_params, graph_runtime_state=graph_runtime_state, - config=node_config, ) # Mock db.session.close() @@ -155,9 +163,8 @@ def test_execute_if_else_result_false(): ) graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") - node_config = { - "id": "if-else", - "data": { + node = _build_if_else_node( + node_data={ "title": "123", "type": "if-else", "logical_operator": "or", @@ -174,13 +181,8 @@ def test_execute_if_else_result_false(): }, ], }, - } - - node = IfElseNode( - id=str(uuid.uuid4()), - graph_init_params=init_params, + init_params=init_params, graph_runtime_state=graph_runtime_state, - config=node_config, ) # Mock db.session.close() @@ -222,11 +224,6 @@ def test_array_file_contains_file_name(): ], ) - node_config = { - "id": "if-else", - "data": node_data.model_dump(), - } - # Create properly configured mock for graph_init_params graph_init_params = Mock() graph_init_params.workflow_id = "test_workflow" @@ -242,17 +239,12 @@ def test_array_file_contains_file_name(): } } - node = IfElseNode( - id=str(uuid.uuid4()), - graph_init_params=graph_init_params, - graph_runtime_state=Mock(), - config=node_config, - ) + node = _build_if_else_node(node_data=node_data, init_params=graph_init_params, graph_runtime_state=Mock()) node.graph_runtime_state.variable_pool.get.return_value = ArrayFileSegment( value=[ File( - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="1", filename="ab", @@ -334,11 +326,10 @@ def test_execute_if_else_boolean_conditions(condition: Condition): "logical_operator": "and", "conditions": [condition.model_dump()], } - node = IfElseNode( - id=str(uuid.uuid4()), - graph_init_params=init_params, + node = _build_if_else_node( + node_data=node_data, + init_params=init_params, graph_runtime_state=graph_runtime_state, - config={"id": "if-else", "data": node_data}, ) # Mock db.session.close() @@ -400,14 +391,10 @@ def test_execute_if_else_boolean_false_conditions(): ], } - node = IfElseNode( - id=str(uuid.uuid4()), - graph_init_params=init_params, + node = _build_if_else_node( + node_data=node_data, + init_params=init_params, graph_runtime_state=graph_runtime_state, - config={ - "id": "if-else", - "data": node_data, - }, ) # Mock db.session.close() @@ -472,11 +459,10 @@ def test_execute_if_else_boolean_cases_structure(): } ], } - node = IfElseNode( - id=str(uuid.uuid4()), - graph_init_params=init_params, + node = _build_if_else_node( + node_data=node_data, + init_params=init_params, graph_runtime_state=graph_runtime_state, - config={"id": "if-else", "data": node_data}, ) # Mock db.session.close() diff --git a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py index b217e4e8e7..465a4c0ff4 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py @@ -19,6 +19,15 @@ from graphon.nodes.list_operator.node import ListOperatorNode, _get_file_extract from graphon.variables import ArrayFileSegment +def _build_list_operator_node(node_data: ListOperatorNodeData, graph_init_params) -> ListOperatorNode: + return ListOperatorNode( + node_id="test_node_id", + config=node_data, + graph_init_params=graph_init_params, + graph_runtime_state=MagicMock(), + ) + + @pytest.fixture def list_operator_node(): config = { @@ -35,10 +44,6 @@ def list_operator_node(): "title": "Test Title", } node_data = ListOperatorNodeData.model_validate(config) - node_config = { - "id": "test_node_id", - "data": node_data.model_dump(), - } # Create properly configured mock for graph_init_params graph_init_params = MagicMock() graph_init_params.workflow_id = "test_workflow" @@ -54,12 +59,7 @@ def list_operator_node(): } } - node = ListOperatorNode( - id="test_node_id", - config=node_config, - graph_init_params=graph_init_params, - graph_runtime_state=MagicMock(), - ) + node = _build_list_operator_node(node_data, graph_init_params) node.graph_runtime_state = MagicMock() node.graph_runtime_state.variable_pool = MagicMock() return node @@ -70,28 +70,28 @@ def test_filter_files_by_type(list_operator_node): files = [ File( filename="image1.jpg", - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="related1", storage_key="", ), File( filename="document1.pdf", - type=FileType.DOCUMENT, + file_type=FileType.DOCUMENT, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="related2", storage_key="", ), File( filename="image2.png", - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="related3", storage_key="", ), File( filename="audio1.mp3", - type=FileType.AUDIO, + file_type=FileType.AUDIO, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="related4", storage_key="", @@ -136,7 +136,7 @@ def test_filter_files_by_type(list_operator_node): def test_get_file_extract_string_func(): # Create a File object file = File( - type=FileType.DOCUMENT, + file_type=FileType.DOCUMENT, transfer_method=FileTransferMethod.LOCAL_FILE, filename="test_file.txt", extension=".txt", @@ -156,7 +156,7 @@ def test_get_file_extract_string_func(): # Test with empty values empty_file = File( - type=FileType.DOCUMENT, + file_type=FileType.DOCUMENT, transfer_method=FileTransferMethod.LOCAL_FILE, filename=None, extension=None, diff --git a/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py b/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py index 543f9878de..5655f80737 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py @@ -22,10 +22,7 @@ def make_start_node(user_inputs, variables): inputs=user_inputs, ) - config = { - "id": "start", - "data": StartNodeData(title="Start", variables=variables).model_dump(), - } + node_data = StartNodeData(title="Start", variables=variables) graph_runtime_state = GraphRuntimeState( variable_pool=variable_pool, @@ -33,8 +30,8 @@ def make_start_node(user_inputs, variables): ) return StartNode( - id="start", - config=config, + node_id="start", + config=node_data, graph_init_params=build_test_graph_init_params( workflow_id="wf", graph_config={}, @@ -109,7 +106,7 @@ def test_json_object_invalid_json_string(): node = make_start_node(user_inputs, variables) - with pytest.raises(ValueError, match="JSON object for 'profile' must be an object"): + with pytest.raises(TypeError, match="JSON object for 'profile' must be an object"): node._run() @@ -248,25 +245,22 @@ def test_start_node_outputs_full_variable_pool_snapshot(): inputs={"profile": {"age": 20, "name": "Tom"}}, ) - config = { - "id": "start", - "data": StartNodeData( - title="Start", - variables=[ - VariableEntity( - variable="profile", - label="profile", - type=VariableEntityType.JSON_OBJECT, - required=True, - ) - ], - ).model_dump(), - } + node_data = StartNodeData( + title="Start", + variables=[ + VariableEntity( + variable="profile", + label="profile", + type=VariableEntityType.JSON_OBJECT, + required=True, + ) + ], + ) graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) node = StartNode( - id="start", - config=config, + node_id="start", + config=node_data, graph_init_params=build_test_graph_init_params( workflow_id="wf", graph_config={}, diff --git a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py index c806181340..284af68319 100644 --- a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py @@ -13,6 +13,7 @@ from core.workflow.system_variables import build_system_variables from graphon.file import File, FileTransferMethod, FileType from graphon.model_runtime.entities.llm_entities import LLMUsage from graphon.node_events import StreamChunkEvent, StreamCompletedEvent +from graphon.nodes.tool.entities import ToolNodeData from graphon.nodes.tool_runtime_entities import ToolRuntimeHandle, ToolRuntimeMessage from graphon.runtime import GraphRuntimeState, VariablePool from graphon.variables.segments import ArrayFileSegment @@ -108,8 +109,8 @@ def tool_node(monkeypatch) -> ToolNode: runtime = _StubToolRuntime() node = ToolNode( - id="node-instance", - config=config, + node_id="node-instance", + config=ToolNodeData.model_validate(config["data"]), graph_init_params=init_params, graph_runtime_state=graph_runtime_state, tool_file_manager_factory=tool_file_manager_factory, @@ -118,13 +119,13 @@ def tool_node(monkeypatch) -> ToolNode: return node -def _collect_events(generator: Generator) -> tuple[list[Any], LLMUsage]: +def _collect_events(generator: Generator) -> list[Any]: events: list[Any] = [] try: while True: events.append(next(generator)) - except StopIteration as stop: - return events, stop.value + except StopIteration: + return events def _run_transform(tool_node: ToolNode, message: ToolRuntimeMessage) -> tuple[list[Any], LLMUsage]: @@ -135,12 +136,15 @@ def _run_transform(tool_node: ToolNode, message: ToolRuntimeMessage) -> tuple[li node_id=tool_node._node_id, tool_runtime=ToolRuntimeHandle(raw=object()), ) - return _collect_events(generator) + events = _collect_events(generator) + completed_events = [event for event in events if isinstance(event, StreamCompletedEvent)] + assert completed_events + return events, completed_events[-1].node_run_result.llm_usage def test_link_messages_with_file_populate_files_output(tool_node: ToolNode): file_obj = File( - type=FileType.DOCUMENT, + file_type=FileType.DOCUMENT, transfer_method=FileTransferMethod.TOOL_FILE, related_id="file-id", filename="demo.pdf", @@ -195,7 +199,7 @@ def test_plain_link_messages_remain_links(tool_node: ToolNode): def test_image_link_messages_use_tool_file_id_metadata(tool_node: ToolNode): file_obj = File( - type=FileType.DOCUMENT, + file_type=FileType.DOCUMENT, transfer_method=FileTransferMethod.TOOL_FILE, related_id="file-id", filename="demo.pdf", diff --git a/api/tests/unit_tests/core/workflow/nodes/trigger_plugin/test_trigger_event_node.py b/api/tests/unit_tests/core/workflow/nodes/trigger_plugin/test_trigger_event_node.py index c8ddc53284..e3b5e3b591 100644 --- a/api/tests/unit_tests/core/workflow/nodes/trigger_plugin/test_trigger_event_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/trigger_plugin/test_trigger_event_node.py @@ -1,10 +1,10 @@ from collections.abc import Mapping from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE +from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData from core.workflow.nodes.trigger_plugin.trigger_event_node import TriggerEventNode from core.workflow.system_variables import build_system_variables from graphon.entities import GraphInitParams -from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from graphon.runtime import GraphRuntimeState from tests.workflow_test_utils import build_test_graph_init_params, build_test_variable_pool @@ -27,29 +27,24 @@ def _build_context(graph_config: Mapping[str, object]) -> tuple[GraphInitParams, return init_params, runtime_state -def _build_node_config() -> NodeConfigDict: - return NodeConfigDictAdapter.validate_python( - { - "id": "node-1", - "data": { - "type": TRIGGER_PLUGIN_NODE_TYPE, - "title": "Trigger Event", - "plugin_id": "plugin-id", - "provider_id": "provider-id", - "event_name": "event-name", - "subscription_id": "subscription-id", - "plugin_unique_identifier": "plugin-unique-identifier", - "event_parameters": {}, - }, - } +def _build_node_data() -> TriggerEventNodeData: + return TriggerEventNodeData( + type=TRIGGER_PLUGIN_NODE_TYPE, + title="Trigger Event", + plugin_id="plugin-id", + provider_id="provider-id", + event_name="event-name", + subscription_id="subscription-id", + plugin_unique_identifier="plugin-unique-identifier", + event_parameters={}, ) def test_trigger_event_node_run_populates_trigger_info_metadata() -> None: init_params, runtime_state = _build_context(graph_config={}) node = TriggerEventNode( - id="node-1", - config=_build_node_config(), + node_id="node-1", + config=_build_node_data(), graph_init_params=init_params, graph_runtime_state=runtime_state, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py index 1bbc12b23f..07d03bec05 100644 --- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py +++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py @@ -30,11 +30,6 @@ def create_webhook_node( tenant_id: str = "test-tenant", ) -> TriggerWebhookNode: """Helper function to create a webhook node with proper initialization.""" - node_config = { - "id": "webhook-node-1", - "data": webhook_data.model_dump(), - } - graph_init_params = GraphInitParams( workflow_id="test-workflow", graph_config={}, @@ -56,8 +51,8 @@ def create_webhook_node( ) node = TriggerWebhookNode( - id="webhook-node-1", - config=node_config, + node_id="webhook-node-1", + config=webhook_data, graph_init_params=graph_init_params, graph_runtime_state=runtime_state, ) @@ -66,10 +61,6 @@ def create_webhook_node( runtime_state.app_config = Mock() runtime_state.app_config.tenant_id = tenant_id - # Provide compatibility alias expected by node implementation - # Some nodes reference `self.node_id`; expose it as an alias to `self.id` for tests - node.node_id = node.id - return node diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py index 427afa96ec..b839490d3c 100644 --- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py @@ -24,11 +24,6 @@ from tests.workflow_test_utils import build_test_variable_pool def create_webhook_node(webhook_data: WebhookData, variable_pool: VariablePool) -> TriggerWebhookNode: """Helper function to create a webhook node with proper initialization.""" - node_config = { - "id": "1", - "data": webhook_data.model_dump(), - } - graph_init_params = GraphInitParams( workflow_id="1", graph_config={}, @@ -48,8 +43,8 @@ def create_webhook_node(webhook_data: WebhookData, variable_pool: VariablePool) start_at=0, ) node = TriggerWebhookNode( - id="1", - config=node_config, + node_id="1", + config=webhook_data, graph_init_params=graph_init_params, graph_runtime_state=runtime_state, ) @@ -57,9 +52,6 @@ def create_webhook_node(webhook_data: WebhookData, variable_pool: VariablePool) # Provide tenant_id for conversion path runtime_state.app_config = type("_AppCfg", (), {"tenant_id": "1"})() - # Compatibility alias for some nodes referencing `self.node_id` - node.node_id = node.id - return node @@ -225,7 +217,7 @@ def test_webhook_node_run_with_file_params(): """Test webhook node execution with file parameter extraction.""" # Create mock file objects file1 = File( - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="file1", filename="image.jpg", @@ -234,7 +226,7 @@ def test_webhook_node_run_with_file_params(): ) file2 = File( - type=FileType.DOCUMENT, + file_type=FileType.DOCUMENT, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="file2", filename="document.pdf", @@ -269,8 +261,19 @@ def test_webhook_node_run_with_file_params(): # Mock the node's file reference boundary to avoid DB-dependent validation on upload_file_id with patch.object(node._file_reference_factory, "build_from_mapping") as mock_file_factory: - def _to_file(*, mapping): - return File.model_validate(mapping) + def _to_file(*, mapping: dict[str, Any]) -> File: + return File( + file_id=mapping.get("id"), + file_type=FileType(mapping["type"]), + transfer_method=FileTransferMethod(mapping["transfer_method"]), + related_id=mapping.get("related_id"), + filename=mapping.get("filename"), + extension=mapping.get("extension"), + mime_type=mapping.get("mime_type"), + size=mapping.get("size", -1), + storage_key=mapping.get("storage_key", ""), + remote_url=mapping.get("url"), + ) mock_file_factory.side_effect = _to_file result = node._run() @@ -284,7 +287,7 @@ def test_webhook_node_run_with_file_params(): def test_webhook_node_run_mixed_parameters(): """Test webhook node execution with mixed parameter types.""" file_obj = File( - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="file1", filename="test.jpg", @@ -317,8 +320,19 @@ def test_webhook_node_run_mixed_parameters(): # Mock the node's file reference boundary to avoid DB-dependent validation on upload_file_id with patch.object(node._file_reference_factory, "build_from_mapping") as mock_file_factory: - def _to_file(*, mapping): - return File.model_validate(mapping) + def _to_file(*, mapping: dict[str, Any]) -> File: + return File( + file_id=mapping.get("id"), + file_type=FileType(mapping["type"]), + transfer_method=FileTransferMethod(mapping["transfer_method"]), + related_id=mapping.get("related_id"), + filename=mapping.get("filename"), + extension=mapping.get("extension"), + mime_type=mapping.get("mime_type"), + size=mapping.get("size", -1), + storage_key=mapping.get("storage_key", ""), + remote_url=mapping.get("url"), + ) mock_file_factory.side_effect = _to_file result = node._run() diff --git a/api/tests/unit_tests/core/workflow/test_human_input_adapter.py b/api/tests/unit_tests/core/workflow/test_human_input_adapter.py new file mode 100644 index 0000000000..8b5fceeb37 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/test_human_input_adapter.py @@ -0,0 +1,350 @@ +from types import SimpleNamespace + +import pytest +from pydantic import BaseModel + +from core.workflow.human_input_adapter import ( + DeliveryMethodType, + EmailDeliveryConfig, + EmailDeliveryMethod, + EmailRecipients, + WebAppDeliveryMethod, + _WebAppDeliveryConfig, + adapt_human_input_node_data_for_graph, + adapt_node_config_for_graph, + adapt_node_data_for_graph, + is_human_input_webapp_enabled, + parse_human_input_delivery_methods, +) +from graphon.enums import BuiltinNodeTypes +from graphon.nodes.base.variable_template_parser import VariableTemplateParser + + +def test_email_delivery_config_helpers_render_and_sanitize_text() -> None: + variable_pool = SimpleNamespace( + convert_template=lambda body: SimpleNamespace(text=body.replace("{{#node.value#}}", "42")) + ) + + rendered = EmailDeliveryConfig.render_body_template( + body="Open {{#url#}} and use {{#node.value#}}", + url="https://example.com", + variable_pool=variable_pool, + ) + sanitized = EmailDeliveryConfig.sanitize_subject("Hello\r\n Team") + html = EmailDeliveryConfig.render_markdown_body( + "**Hello** [mail](mailto:test@example.com)" + ) + + assert rendered == "Open https://example.com and use 42" + assert sanitized == "Hello alert(1) Team" + assert "Hello" in html + assert " Team") - html = EmailDeliveryConfig.render_markdown_body( - "**Hello** [mail](mailto:test@example.com)" - ) - - assert rendered == "Open https://example.com and use 42" - assert sanitized == "Hello alert(1) Team" - assert "Hello" in html - assert ".txt' }) render(, { wrapper: createWrapper() }) - expect(screen.getByText('.txt')).toBeInTheDocument() + expect(screen.getByText('.txt'))!.toBeInTheDocument() }) it('should memoize the component', () => { @@ -343,7 +344,7 @@ describe('DocumentTableRow', () => { const { rerender } = render(, { wrapper }) rerender() - expect(screen.getByRole('row')).toBeInTheDocument() + expect(screen.getByRole('row'))!.toBeInTheDocument() }) }) }) diff --git a/web/app/components/datasets/documents/components/document-list/components/document-source-icon.tsx b/web/app/components/datasets/documents/components/document-list/components/document-source-icon.tsx index 5461f34921..0d51837cf2 100644 --- a/web/app/components/datasets/documents/components/document-list/components/document-source-icon.tsx +++ b/web/app/components/datasets/documents/components/document-list/components/document-source-icon.tsx @@ -39,7 +39,7 @@ const getFileExtension = (fileName: string): string => { const parts = fileName.split('.') if (parts.length <= 1 || (parts[0] === '' && parts.length === 2)) return '' - return parts[parts.length - 1].toLowerCase() + return parts[parts.length - 1]!.toLowerCase() } const DocumentSourceIcon: FC = React.memo(({ diff --git a/web/app/components/datasets/documents/components/document-list/hooks/__tests__/use-document-actions.spec.ts b/web/app/components/datasets/documents/components/document-list/hooks/__tests__/use-document-actions.spec.ts index 449478eb7b..9eebae4f81 100644 --- a/web/app/components/datasets/documents/components/document-list/hooks/__tests__/use-document-actions.spec.ts +++ b/web/app/components/datasets/documents/components/document-list/hooks/__tests__/use-document-actions.spec.ts @@ -27,7 +27,7 @@ vi.mock('@/service/knowledge/use-document', () => ({ useDocumentDownloadZip: () => ({ mutateAsync: mockDownloadZip, isPending: mockIsDownloadingZip }), })) -vi.mock('@/app/components/base/ui/toast', () => ({ +vi.mock('@langgenius/dify-ui/toast', () => ({ toast: { success: mockToastSuccess, error: mockToastError, diff --git a/web/app/components/datasets/documents/components/document-list/hooks/use-document-actions.ts b/web/app/components/datasets/documents/components/document-list/hooks/use-document-actions.ts index 8b6c40e2be..a46c4fcfcc 100644 --- a/web/app/components/datasets/documents/components/document-list/hooks/use-document-actions.ts +++ b/web/app/components/datasets/documents/components/document-list/hooks/use-document-actions.ts @@ -1,7 +1,7 @@ import type { CommonResponse } from '@/models/common' +import { toast } from '@langgenius/dify-ui/toast' import { useCallback, useMemo } from 'react' import { useTranslation } from 'react-i18next' -import { toast } from '@/app/components/base/ui/toast' import { DocumentActionType } from '@/models/datasets' import { useDocumentArchive, diff --git a/web/app/components/datasets/documents/components/documents-header.tsx b/web/app/components/datasets/documents/components/documents-header.tsx index 4e098f5eda..bd74ad0487 100644 --- a/web/app/components/datasets/documents/components/documents-header.tsx +++ b/web/app/components/datasets/documents/components/documents-header.tsx @@ -4,13 +4,13 @@ import type { Item } from '@/app/components/base/select' import type { BuiltInMetadataItem, MetadataItemWithValueLength } from '@/app/components/datasets/metadata/types' import type { SortType } from '@/service/datasets' import { PlusIcon } from '@heroicons/react/24/solid' +import { Button } from '@langgenius/dify-ui/button' import { RiDraftLine, RiExternalLinkLine } from '@remixicon/react' import { useMemo } from 'react' import { useTranslation } from 'react-i18next' import Chip from '@/app/components/base/chip' import Input from '@/app/components/base/input' import Sort from '@/app/components/base/sort' -import { Button } from '@/app/components/base/ui/button' import AutoDisabledDocument from '@/app/components/datasets/common/document-status-with-action/auto-disabled-document' import IndexFailed from '@/app/components/datasets/common/document-status-with-action/index-failed' import StatusWithAction from '@/app/components/datasets/common/document-status-with-action/status-with-action' diff --git a/web/app/components/datasets/documents/components/empty-element.tsx b/web/app/components/datasets/documents/components/empty-element.tsx index 6eacf89264..506b8ab6db 100644 --- a/web/app/components/datasets/documents/components/empty-element.tsx +++ b/web/app/components/datasets/documents/components/empty-element.tsx @@ -1,8 +1,8 @@ 'use client' import type { FC } from 'react' import { PlusIcon } from '@heroicons/react/24/solid' +import { Button } from '@langgenius/dify-ui/button' import { useTranslation } from 'react-i18next' -import { Button } from '@/app/components/base/ui/button' import s from '../style.module.css' import { FolderPlusIcon, NotionIcon, ThreeDotsIcon } from './icons' diff --git a/web/app/components/datasets/documents/components/list.tsx b/web/app/components/datasets/documents/components/list.tsx index e40e4c061b..abd40c33a0 100644 --- a/web/app/components/datasets/documents/components/list.tsx +++ b/web/app/components/datasets/documents/components/list.tsx @@ -117,8 +117,8 @@ const DocumentList: FC = ({ return (
- - +
+
e.stopPropagation()}> diff --git a/web/app/components/datasets/documents/components/operations.tsx b/web/app/components/datasets/documents/components/operations.tsx index e7bbb03c94..8692da927d 100644 --- a/web/app/components/datasets/documents/components/operations.tsx +++ b/web/app/components/datasets/documents/components/operations.tsx @@ -1,15 +1,6 @@ import type { OperationName } from '../types' import type { CommonResponse } from '@/models/common' import type { DocumentDownloadResponse } from '@/service/datasets' -import { cn } from '@langgenius/dify-ui/cn' -import { useBoolean, useDebounceFn } from 'ahooks' -import { noop } from 'es-toolkit/function' -import * as React from 'react' -import { useCallback, useState } from 'react' -import { useTranslation } from 'react-i18next' -import Divider from '@/app/components/base/divider' -import Switch from '@/app/components/base/switch' -import Tooltip from '@/app/components/base/tooltip' import { AlertDialog, AlertDialogActions, @@ -18,13 +9,22 @@ import { AlertDialogContent, AlertDialogDescription, AlertDialogTitle, -} from '@/app/components/base/ui/alert-dialog' +} from '@langgenius/dify-ui/alert-dialog' +import { cn } from '@langgenius/dify-ui/cn' import { DropdownMenu, DropdownMenuContent, DropdownMenuTrigger, -} from '@/app/components/base/ui/dropdown-menu' -import { toast } from '@/app/components/base/ui/toast' +} from '@langgenius/dify-ui/dropdown-menu' +import { Switch } from '@langgenius/dify-ui/switch' +import { toast } from '@langgenius/dify-ui/toast' +import { useBoolean, useDebounceFn } from 'ahooks' +import { noop } from 'es-toolkit/function' +import * as React from 'react' +import { useCallback, useState } from 'react' +import { useTranslation } from 'react-i18next' +import Divider from '@/app/components/base/divider' +import Tooltip from '@/app/components/base/tooltip' import { IS_CE_EDITION } from '@/config' import { DataSourceType, DocumentActionType } from '@/models/datasets' import { useRouter } from '@/next/navigation' diff --git a/web/app/components/datasets/documents/components/rename-modal.tsx b/web/app/components/datasets/documents/components/rename-modal.tsx index c6f393f1ce..fc4626676b 100644 --- a/web/app/components/datasets/documents/components/rename-modal.tsx +++ b/web/app/components/datasets/documents/components/rename-modal.tsx @@ -1,13 +1,13 @@ 'use client' import type { FC } from 'react' +import { Button } from '@langgenius/dify-ui/button' +import { toast } from '@langgenius/dify-ui/toast' import { useBoolean } from 'ahooks' import * as React from 'react' import { useState } from 'react' import { useTranslation } from 'react-i18next' import Input from '@/app/components/base/input' import Modal from '@/app/components/base/modal' -import { Button } from '@/app/components/base/ui/button' -import { toast } from '@/app/components/base/ui/toast' import { renameDocumentName } from '@/service/datasets' type Props = { diff --git a/web/app/components/datasets/documents/create-from-pipeline/__tests__/index.spec.tsx b/web/app/components/datasets/documents/create-from-pipeline/__tests__/index.spec.tsx index 8a2e251770..7daff43a8b 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/__tests__/index.spec.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/__tests__/index.spec.tsx @@ -569,7 +569,7 @@ describe('StepOneContent', () => { it('should render VectorSpaceFull when isShowVectorSpaceFull is true', () => { render() - expect(screen.getByTestId('vector-space-full')).toBeInTheDocument() + expect(screen.getByTestId('vector-space-full'))!.toBeInTheDocument() }) it('should not render VectorSpaceFull when isShowVectorSpaceFull is false', () => { @@ -587,7 +587,7 @@ describe('StepOneContent', () => { localFileListLength={2} />, ) - expect(screen.getByTestId('upgrade-card')).toBeInTheDocument() + expect(screen.getByTestId('upgrade-card'))!.toBeInTheDocument() }) it('should not render UpgradeCard when supportBatchUpload is true', () => { @@ -618,7 +618,7 @@ describe('StepOneContent', () => { render() const nextButton = screen.getByRole('button', { name: /datasetCreation.stepOne.button/i }) - expect(nextButton).toBeDisabled() + expect(nextButton)!.toBeDisabled() }) }) @@ -664,17 +664,17 @@ describe('StepTwoContent', () => { it('should render ProcessDocuments component', () => { render() - expect(screen.getByTestId('process-documents')).toBeInTheDocument() + expect(screen.getByTestId('process-documents'))!.toBeInTheDocument() }) it('should pass dataSourceNodeId to ProcessDocuments', () => { render() - expect(screen.getByTestId('datasource-node-id')).toHaveTextContent('custom-node') + expect(screen.getByTestId('datasource-node-id'))!.toHaveTextContent('custom-node') }) it('should pass isRunning to ProcessDocuments', () => { render() - expect(screen.getByTestId('is-running')).toHaveTextContent('true') + expect(screen.getByTestId('is-running'))!.toHaveTextContent('true') }) it('should call onProcess when process button is clicked', () => { @@ -709,18 +709,18 @@ describe('StepThreeContent', () => { it('should render Processing component', () => { render() - expect(screen.getByTestId('processing')).toBeInTheDocument() + expect(screen.getByTestId('processing'))!.toBeInTheDocument() }) it('should pass batchId to Processing', () => { render() - expect(screen.getByTestId('batch-id')).toHaveTextContent('batch-123') + expect(screen.getByTestId('batch-id'))!.toHaveTextContent('batch-123') }) it('should pass documents count to Processing', () => { const documents = [{ id: '1' }, { id: '2' }] render() - expect(screen.getByTestId('documents-count')).toHaveTextContent('2') + expect(screen.getByTestId('documents-count'))!.toHaveTextContent('2') }) }) @@ -787,8 +787,8 @@ describe('StepOnePreview', () => { currentLocalFile={createMockFile()} />, ) - expect(screen.getByTestId('file-preview')).toBeInTheDocument() - expect(screen.getByTestId('file-name')).toHaveTextContent('test.txt') + expect(screen.getByTestId('file-preview'))!.toBeInTheDocument() + expect(screen.getByTestId('file-name'))!.toHaveTextContent('test.txt') }) it('should render OnlineDocumentPreview when currentDocument is set', () => { @@ -799,7 +799,7 @@ describe('StepOnePreview', () => { currentDocument={createMockNotionPage()} />, ) - expect(screen.getByTestId('online-document-preview')).toBeInTheDocument() + expect(screen.getByTestId('online-document-preview'))!.toBeInTheDocument() }) it('should render WebsitePreview when currentWebsite is set', () => { @@ -809,7 +809,7 @@ describe('StepOnePreview', () => { currentWebsite={createMockCrawlResult()} />, ) - expect(screen.getByTestId('web-preview')).toBeInTheDocument() + expect(screen.getByTestId('web-preview'))!.toBeInTheDocument() }) it('should call hidePreviewLocalFile when hide button is clicked', () => { @@ -868,22 +868,22 @@ describe('StepTwoPreview', () => { it('should render ChunkPreview component', () => { render() - expect(screen.getByTestId('chunk-preview')).toBeInTheDocument() + expect(screen.getByTestId('chunk-preview'))!.toBeInTheDocument() }) it('should pass datasourceType to ChunkPreview', () => { render() - expect(screen.getByTestId('datasource-type')).toHaveTextContent(DatasourceType.onlineDocument) + expect(screen.getByTestId('datasource-type'))!.toHaveTextContent(DatasourceType.onlineDocument) }) it('should pass isIdle to ChunkPreview', () => { render() - expect(screen.getByTestId('is-idle')).toHaveTextContent('false') + expect(screen.getByTestId('is-idle'))!.toHaveTextContent('false') }) it('should pass isPendingPreview to ChunkPreview', () => { render() - expect(screen.getByTestId('is-pending')).toHaveTextContent('true') + expect(screen.getByTestId('is-pending'))!.toHaveTextContent('true') }) it('should call onPreview when preview button is clicked', () => { @@ -1092,7 +1092,7 @@ describe('Store Hooks', () => { mockStoreState.selectedFileIds = ['file-1'] const { result } = renderHook(() => useOnlineDrive()) expect(result.current.selectedOnlineDriveFileList).toHaveLength(1) - expect(result.current.selectedOnlineDriveFileList[0].id).toBe('file-1') + expect(result.current.selectedOnlineDriveFileList[0]!.id).toBe('file-1') }) }) }) @@ -1166,8 +1166,8 @@ describe('useDatasourceOptions', () => { const { result } = renderHook(() => useDatasourceOptions(mockNodes)) expect(result.current).toHaveLength(1) - expect(result.current[0].label).toBe('Local File Source') - expect(result.current[0].value).toBe('node-1') + expect(result.current[0]!.label).toBe('Local File Source') + expect(result.current[0]!.value).toBe('node-1') }) it('should return multiple options for multiple data source nodes', () => { @@ -1616,7 +1616,7 @@ describe('StepOneContent - All Datasource Types', () => { datasourceType={DatasourceType.onlineDocument} />, ) - expect(screen.getByTestId('online-documents-component')).toBeInTheDocument() + expect(screen.getByTestId('online-documents-component'))!.toBeInTheDocument() }) it('should render WebsiteCrawl when datasourceType is websiteCrawl', () => { @@ -1632,7 +1632,7 @@ describe('StepOneContent - All Datasource Types', () => { datasourceType={DatasourceType.websiteCrawl} />, ) - expect(screen.getByTestId('website-crawl-component')).toBeInTheDocument() + expect(screen.getByTestId('website-crawl-component'))!.toBeInTheDocument() }) it('should render OnlineDrive when datasourceType is onlineDrive', () => { @@ -1648,7 +1648,7 @@ describe('StepOneContent - All Datasource Types', () => { datasourceType={DatasourceType.onlineDrive} />, ) - expect(screen.getByTestId('online-drive-component')).toBeInTheDocument() + expect(screen.getByTestId('online-drive-component'))!.toBeInTheDocument() }) it('should render LocalFile when datasourceType is localFile', () => { @@ -1659,7 +1659,7 @@ describe('StepOneContent - All Datasource Types', () => { datasourceType={DatasourceType.localFile} />, ) - expect(screen.getByTestId('local-file-component')).toBeInTheDocument() + expect(screen.getByTestId('local-file-component'))!.toBeInTheDocument() }) }) @@ -1690,7 +1690,8 @@ describe('StepTwoPreview - File List Mapping', () => { ) // ChunkPreview should be rendered - expect(screen.getByTestId('chunk-preview')).toBeInTheDocument() + // ChunkPreview should be rendered + expect(screen.getByTestId('chunk-preview'))!.toBeInTheDocument() }) }) diff --git a/web/app/components/datasets/documents/create-from-pipeline/__tests__/step-indicator.spec.tsx b/web/app/components/datasets/documents/create-from-pipeline/__tests__/step-indicator.spec.tsx index 7103dced26..7bffe3577e 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/__tests__/step-indicator.spec.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/__tests__/step-indicator.spec.tsx @@ -19,14 +19,14 @@ describe('StepIndicator', () => { const { container } = render() const dots = container.querySelectorAll('.rounded-lg') // Second step (index 1) should be active - expect(dots[1].className).toContain('bg-state-accent-solid') - expect(dots[1].className).toContain('w-2') + expect(dots[1]!.className).toContain('bg-state-accent-solid') + expect(dots[1]!.className).toContain('w-2') }) it('should not apply active style to non-current steps', () => { const { container } = render() const dots = container.querySelectorAll('.rounded-lg') - expect(dots[1].className).toContain('bg-divider-solid') - expect(dots[2].className).toContain('bg-divider-solid') + expect(dots[1]!.className).toContain('bg-divider-solid') + expect(dots[2]!.className).toContain('bg-divider-solid') }) }) diff --git a/web/app/components/datasets/documents/create-from-pipeline/actions/index.tsx b/web/app/components/datasets/documents/create-from-pipeline/actions/index.tsx index 64c5faac33..caffed6500 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/actions/index.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/actions/index.tsx @@ -1,9 +1,9 @@ +import { Button } from '@langgenius/dify-ui/button' import { RiArrowRightLine } from '@remixicon/react' import * as React from 'react' import { useMemo } from 'react' import { useTranslation } from 'react-i18next' import Checkbox from '@/app/components/base/checkbox' -import { Button } from '@/app/components/base/ui/button' import Link from '@/next/link' import { useParams } from '@/next/navigation' diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source-options/__tests__/index.spec.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source-options/__tests__/index.spec.tsx index 0ac2dfce20..78542ad522 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/data-source-options/__tests__/index.spec.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/data-source-options/__tests__/index.spec.tsx @@ -129,7 +129,7 @@ describe('DatasourceIcon', () => { it('should render without crashing', () => { const { container } = render() - expect(container.firstChild).toBeInTheDocument() + expect(container.firstChild)!.toBeInTheDocument() }) it('should render icon with background image', () => { @@ -138,15 +138,16 @@ describe('DatasourceIcon', () => { const { container } = render() const iconDiv = container.querySelector('[style*="background-image"]') - expect(iconDiv).toHaveStyle({ backgroundImage: `url(${iconUrl})` }) + expect(iconDiv)!.toHaveStyle({ backgroundImage: `url(${iconUrl})` }) }) it('should render with default size (sm)', () => { const { container } = render() // Assert - Default size is 'sm' which maps to 'w-5 h-5' - expect(container.firstChild).toHaveClass('w-5') - expect(container.firstChild).toHaveClass('h-5') + // Assert - Default size is 'sm' which maps to 'w-5 h-5' + expect(container.firstChild)!.toHaveClass('w-5') + expect(container.firstChild)!.toHaveClass('h-5') }) }) @@ -157,9 +158,9 @@ describe('DatasourceIcon', () => { , ) - expect(container.firstChild).toHaveClass('w-4') - expect(container.firstChild).toHaveClass('h-4') - expect(container.firstChild).toHaveClass('rounded-[5px]') + expect(container.firstChild)!.toHaveClass('w-4') + expect(container.firstChild)!.toHaveClass('h-4') + expect(container.firstChild)!.toHaveClass('rounded-[5px]') }) it('should render with sm size', () => { @@ -167,9 +168,9 @@ describe('DatasourceIcon', () => { , ) - expect(container.firstChild).toHaveClass('w-5') - expect(container.firstChild).toHaveClass('h-5') - expect(container.firstChild).toHaveClass('rounded-md') + expect(container.firstChild)!.toHaveClass('w-5') + expect(container.firstChild)!.toHaveClass('h-5') + expect(container.firstChild)!.toHaveClass('rounded-md') }) it('should render with md size', () => { @@ -177,9 +178,9 @@ describe('DatasourceIcon', () => { , ) - expect(container.firstChild).toHaveClass('w-6') - expect(container.firstChild).toHaveClass('h-6') - expect(container.firstChild).toHaveClass('rounded-lg') + expect(container.firstChild)!.toHaveClass('w-6') + expect(container.firstChild)!.toHaveClass('h-6') + expect(container.firstChild)!.toHaveClass('rounded-lg') }) }) @@ -189,7 +190,7 @@ describe('DatasourceIcon', () => { , ) - expect(container.firstChild).toHaveClass('custom-class') + expect(container.firstChild)!.toHaveClass('custom-class') }) it('should merge custom className with default classes', () => { @@ -197,9 +198,9 @@ describe('DatasourceIcon', () => { , ) - expect(container.firstChild).toHaveClass('custom-class') - expect(container.firstChild).toHaveClass('w-5') - expect(container.firstChild).toHaveClass('h-5') + expect(container.firstChild)!.toHaveClass('custom-class') + expect(container.firstChild)!.toHaveClass('w-5') + expect(container.firstChild)!.toHaveClass('h-5') }) }) @@ -208,7 +209,7 @@ describe('DatasourceIcon', () => { const { container } = render() const iconDiv = container.querySelector('[style*="background-image"]') - expect(iconDiv).toHaveStyle({ backgroundImage: 'url()' }) + expect(iconDiv)!.toHaveStyle({ backgroundImage: 'url()' }) }) it('should handle special characters in iconUrl', () => { @@ -217,7 +218,7 @@ describe('DatasourceIcon', () => { const { container } = render() const iconDiv = container.querySelector('[style*="background-image"]') - expect(iconDiv).toHaveStyle({ backgroundImage: `url(${iconUrl})` }) + expect(iconDiv)!.toHaveStyle({ backgroundImage: `url(${iconUrl})` }) }) it('should handle data URL as iconUrl', () => { @@ -226,7 +227,7 @@ describe('DatasourceIcon', () => { const { container } = render() const iconDiv = container.querySelector('[style*="background-image"]') - expect(iconDiv).toBeInTheDocument() + expect(iconDiv)!.toBeInTheDocument() }) }) }) @@ -235,25 +236,26 @@ describe('DatasourceIcon', () => { it('should have flex container classes', () => { const { container } = render() - expect(container.firstChild).toHaveClass('flex') - expect(container.firstChild).toHaveClass('items-center') - expect(container.firstChild).toHaveClass('justify-center') + expect(container.firstChild)!.toHaveClass('flex') + expect(container.firstChild)!.toHaveClass('items-center') + expect(container.firstChild)!.toHaveClass('justify-center') }) it('should have shadow-xs class from size map', () => { const { container } = render() // Assert - Default size 'sm' has shadow-xs - expect(container.firstChild).toHaveClass('shadow-xs') + // Assert - Default size 'sm' has shadow-xs + expect(container.firstChild)!.toHaveClass('shadow-xs') }) it('should have inner div with bg-cover class', () => { const { container } = render() const innerDiv = container.querySelector('.bg-cover') - expect(innerDiv).toBeInTheDocument() - expect(innerDiv).toHaveClass('bg-center') - expect(innerDiv).toHaveClass('rounded-md') + expect(innerDiv)!.toBeInTheDocument() + expect(innerDiv)!.toHaveClass('bg-center') + expect(innerDiv)!.toHaveClass('rounded-md') }) }) }) @@ -519,13 +521,13 @@ describe('OptionCard', () => { it('should render without crashing', () => { renderWithProviders() - expect(screen.getByText('Test Option')).toBeInTheDocument() + expect(screen.getByText('Test Option'))!.toBeInTheDocument() }) it('should render label text', () => { renderWithProviders() - expect(screen.getByText('Custom Label')).toBeInTheDocument() + expect(screen.getByText('Custom Label'))!.toBeInTheDocument() }) it('should render DatasourceIcon component', () => { @@ -533,7 +535,7 @@ describe('OptionCard', () => { // Assert - DatasourceIcon container should exist const iconContainer = container.querySelector('.size-8') - expect(iconContainer).toBeInTheDocument() + expect(iconContainer)!.toBeInTheDocument() }) it('should set title attribute for label truncation', () => { @@ -542,7 +544,7 @@ describe('OptionCard', () => { renderWithProviders() const labelElement = screen.getByText(longLabel) - expect(labelElement).toHaveAttribute('title', longLabel) + expect(labelElement)!.toHaveAttribute('title', longLabel) }) }) @@ -554,8 +556,8 @@ describe('OptionCard', () => { ) const card = container.firstChild - expect(card).toHaveClass('border-components-option-card-option-selected-border') - expect(card).toHaveClass('bg-components-option-card-option-selected-bg') + expect(card)!.toHaveClass('border-components-option-card-option-selected-border') + expect(card)!.toHaveClass('bg-components-option-card-option-selected-bg') }) it('should apply unselected styles when selected is false', () => { @@ -564,22 +566,22 @@ describe('OptionCard', () => { ) const card = container.firstChild - expect(card).toHaveClass('border-components-option-card-option-border') - expect(card).toHaveClass('bg-components-option-card-option-bg') + expect(card)!.toHaveClass('border-components-option-card-option-border') + expect(card)!.toHaveClass('bg-components-option-card-option-bg') }) it('should apply text-text-primary to label when selected', () => { renderWithProviders() const label = screen.getByText('Test Option') - expect(label).toHaveClass('text-text-primary') + expect(label)!.toHaveClass('text-text-primary') }) it('should apply text-text-secondary to label when not selected', () => { renderWithProviders() const label = screen.getByText('Test Option') - expect(label).toHaveClass('text-text-secondary') + expect(label)!.toHaveClass('text-text-secondary') }) }) @@ -593,7 +595,7 @@ describe('OptionCard', () => { // Act - Click on the label text's parent card const labelElement = screen.getByText('Test Option') const card = labelElement.closest('[class*="cursor-pointer"]') - expect(card).toBeInTheDocument() + expect(card)!.toBeInTheDocument() fireEvent.click(card!) expect(mockOnClick).toHaveBeenCalledTimes(1) @@ -607,11 +609,12 @@ describe('OptionCard', () => { // Act - Click on the label text's parent card should not throw const labelElement = screen.getByText('Test Option') const card = labelElement.closest('[class*="cursor-pointer"]') - expect(card).toBeInTheDocument() + expect(card)!.toBeInTheDocument() fireEvent.click(card!) // Assert - Component should still be rendered - expect(screen.getByText('Test Option')).toBeInTheDocument() + // Assert - Component should still be rendered + expect(screen.getByText('Test Option'))!.toBeInTheDocument() }) }) @@ -631,35 +634,35 @@ describe('OptionCard', () => { it('should have cursor-pointer class', () => { const { container } = renderWithProviders() - expect(container.firstChild).toHaveClass('cursor-pointer') + expect(container.firstChild)!.toHaveClass('cursor-pointer') }) it('should have flex layout classes', () => { const { container } = renderWithProviders() - expect(container.firstChild).toHaveClass('flex') - expect(container.firstChild).toHaveClass('items-center') - expect(container.firstChild).toHaveClass('gap-2') + expect(container.firstChild)!.toHaveClass('flex') + expect(container.firstChild)!.toHaveClass('items-center') + expect(container.firstChild)!.toHaveClass('gap-2') }) it('should have rounded-xl border', () => { const { container } = renderWithProviders() - expect(container.firstChild).toHaveClass('rounded-xl') - expect(container.firstChild).toHaveClass('border') + expect(container.firstChild)!.toHaveClass('rounded-xl') + expect(container.firstChild)!.toHaveClass('border') }) it('should have padding p-3', () => { const { container } = renderWithProviders() - expect(container.firstChild).toHaveClass('p-3') + expect(container.firstChild)!.toHaveClass('p-3') }) it('should have line-clamp-2 for label truncation', () => { renderWithProviders() const label = screen.getByText('Test Option') - expect(label).toHaveClass('line-clamp-2') + expect(label)!.toHaveClass('line-clamp-2') }) }) @@ -669,7 +672,7 @@ describe('OptionCard', () => { expect(OptionCard).toBeDefined() // React.memo wraps the component, so we check it renders correctly const { container } = renderWithProviders() - expect(container.firstChild).toBeInTheDocument() + expect(container.firstChild)!.toBeInTheDocument() }) }) }) @@ -698,27 +701,27 @@ describe('DataSourceOptions', () => { it('should render without crashing', () => { renderWithProviders() - expect(screen.getByText('Data Source 1')).toBeInTheDocument() - expect(screen.getByText('Data Source 2')).toBeInTheDocument() - expect(screen.getByText('Data Source 3')).toBeInTheDocument() + expect(screen.getByText('Data Source 1'))!.toBeInTheDocument() + expect(screen.getByText('Data Source 2'))!.toBeInTheDocument() + expect(screen.getByText('Data Source 3'))!.toBeInTheDocument() }) it('should render correct number of option cards', () => { renderWithProviders() - expect(screen.getByText('Data Source 1')).toBeInTheDocument() - expect(screen.getByText('Data Source 2')).toBeInTheDocument() - expect(screen.getByText('Data Source 3')).toBeInTheDocument() + expect(screen.getByText('Data Source 1'))!.toBeInTheDocument() + expect(screen.getByText('Data Source 2'))!.toBeInTheDocument() + expect(screen.getByText('Data Source 3'))!.toBeInTheDocument() }) it('should render with grid layout', () => { const { container } = renderWithProviders() const gridContainer = container.firstChild - expect(gridContainer).toHaveClass('grid') - expect(gridContainer).toHaveClass('w-full') - expect(gridContainer).toHaveClass('grid-cols-4') - expect(gridContainer).toHaveClass('gap-1') + expect(gridContainer)!.toHaveClass('grid') + expect(gridContainer)!.toHaveClass('w-full') + expect(gridContainer)!.toHaveClass('grid-cols-4') + expect(gridContainer)!.toHaveClass('gap-1') }) it('should render no option cards when options is empty', () => { @@ -728,16 +731,17 @@ describe('DataSourceOptions', () => { expect(screen.queryByText('Data Source')).not.toBeInTheDocument() // Grid container should still exist - expect(container.firstChild).toHaveClass('grid') + // Grid container should still exist + expect(container.firstChild)!.toHaveClass('grid') }) it('should render single option card when only one option exists', () => { - const singleOption = [createMockDatasourceOption(defaultNodes[0])] + const singleOption = [createMockDatasourceOption(defaultNodes[0]!)] mockUseDatasourceOptions.mockReturnValue(singleOption) renderWithProviders() - expect(screen.getByText('Data Source 1')).toBeInTheDocument() + expect(screen.getByText('Data Source 1'))!.toBeInTheDocument() expect(screen.queryByText('Data Source 2')).not.toBeInTheDocument() }) }) @@ -778,7 +782,7 @@ describe('DataSourceOptions', () => { // Assert - Check for selected styling on second card const cards = container.querySelectorAll('.rounded-xl.border') - expect(cards[1]).toHaveClass('border-components-option-card-option-selected-border') + expect(cards[1])!.toHaveClass('border-components-option-card-option-selected-border') }) it('should show no selection when datasourceNodeId is empty', () => { @@ -816,7 +820,7 @@ describe('DataSourceOptions', () => { // Assert initial selection let cards = container.querySelectorAll('.rounded-xl.border') - expect(cards[0]).toHaveClass('border-components-option-card-option-selected-border') + expect(cards[0])!.toHaveClass('border-components-option-card-option-selected-border') // Act - Change selection rerender( @@ -831,7 +835,7 @@ describe('DataSourceOptions', () => { // Assert new selection cards = container.querySelectorAll('.rounded-xl.border') expect(cards[0]).not.toHaveClass('border-components-option-card-option-selected-border') - expect(cards[1]).toHaveClass('border-components-option-card-option-selected-border') + expect(cards[1])!.toHaveClass('border-components-option-card-option-selected-border') }) }) @@ -847,7 +851,8 @@ describe('DataSourceOptions', () => { ) // Assert - Component renders without error - expect(screen.getByText('Data Source 1')).toBeInTheDocument() + // Assert - Component renders without error + expect(screen.getByText('Data Source 1'))!.toBeInTheDocument() }) }) }) @@ -870,7 +875,7 @@ describe('DataSourceOptions', () => { expect(mockOnSelect).toHaveBeenCalledTimes(1) expect(mockOnSelect).toHaveBeenCalledWith({ nodeId: 'node-1', - nodeData: defaultOptions[0].data, + nodeData: defaultOptions[0]!.data, } satisfies Datasource) }) @@ -948,7 +953,8 @@ describe('DataSourceOptions', () => { ) // Get initial click handlers - expect(screen.getByText('Data Source 1')).toBeInTheDocument() + // Get initial click handlers + expect(screen.getByText('Data Source 1'))!.toBeInTheDocument() // Trigger clicks to test handlers work fireEvent.click(screen.getByText('Data Source 1')) @@ -1003,7 +1009,7 @@ describe('DataSourceOptions', () => { expect(mockOnSelect2).toHaveBeenCalledTimes(1) expect(mockOnSelect2).toHaveBeenCalledWith({ nodeId: 'node-3', - nodeData: defaultOptions[2].data, + nodeData: defaultOptions[2]!.data, }) }) @@ -1022,7 +1028,7 @@ describe('DataSourceOptions', () => { fireEvent.click(screen.getByText('Data Source 1')) expect(mockOnSelect).toHaveBeenCalledWith({ nodeId: 'node-1', - nodeData: defaultOptions[0].data, + nodeData: defaultOptions[0]!.data, }) // Act - Change options @@ -1045,8 +1051,8 @@ describe('DataSourceOptions', () => { // Assert - Callback receives new option data expect(mockOnSelect).toHaveBeenLastCalledWith({ - nodeId: newOptions[0].value, - nodeData: newOptions[0].data, + nodeId: newOptions[0]!.value, + nodeData: newOptions[0]!.data, }) }) }) @@ -1070,7 +1076,7 @@ describe('DataSourceOptions', () => { expect(mockOnSelect).toHaveBeenCalledTimes(1) expect(mockOnSelect).toHaveBeenCalledWith({ nodeId: 'node-2', - nodeData: defaultOptions[1].data, + nodeData: defaultOptions[1]!.data, } satisfies Datasource) }) @@ -1090,7 +1096,7 @@ describe('DataSourceOptions', () => { expect(mockOnSelect).toHaveBeenCalledTimes(1) expect(mockOnSelect).toHaveBeenCalledWith({ nodeId: 'node-1', - nodeData: defaultOptions[0].data, + nodeData: defaultOptions[0]!.data, }) }) @@ -1112,15 +1118,15 @@ describe('DataSourceOptions', () => { expect(mockOnSelect).toHaveBeenCalledTimes(3) expect(mockOnSelect).toHaveBeenNthCalledWith(1, { nodeId: 'node-1', - nodeData: defaultOptions[0].data, + nodeData: defaultOptions[0]!.data, }) expect(mockOnSelect).toHaveBeenNthCalledWith(2, { nodeId: 'node-2', - nodeData: defaultOptions[1].data, + nodeData: defaultOptions[1]!.data, }) expect(mockOnSelect).toHaveBeenNthCalledWith(3, { nodeId: 'node-3', - nodeData: defaultOptions[2].data, + nodeData: defaultOptions[2]!.data, }) }) }) @@ -1164,7 +1170,7 @@ describe('DataSourceOptions', () => { />, ) - expect(container.firstChild).toBeInTheDocument() + expect(container.firstChild)!.toBeInTheDocument() }) it('should not crash when datasourceNodeId is undefined', () => { @@ -1176,7 +1182,7 @@ describe('DataSourceOptions', () => { />, ) - expect(screen.getByText('Data Source 1')).toBeInTheDocument() + expect(screen.getByText('Data Source 1'))!.toBeInTheDocument() }) }) @@ -1202,7 +1208,7 @@ describe('DataSourceOptions', () => { renderWithProviders() - expect(screen.getByText('Minimal Option')).toBeInTheDocument() + expect(screen.getByText('Minimal Option'))!.toBeInTheDocument() }) }) @@ -1219,8 +1225,8 @@ describe('DataSourceOptions', () => { />, ) - expect(screen.getByText('Data Source 1')).toBeInTheDocument() - expect(screen.getByText('Data Source 50')).toBeInTheDocument() + expect(screen.getByText('Data Source 1'))!.toBeInTheDocument() + expect(screen.getByText('Data Source 50'))!.toBeInTheDocument() }) }) @@ -1243,7 +1249,8 @@ describe('DataSourceOptions', () => { ) // Assert - Special characters should be escaped/rendered safely - expect(screen.getByText('Data Source ')).toBeInTheDocument() + // Assert - Special characters should be escaped/rendered safely + expect(screen.getByText('Data Source '))!.toBeInTheDocument() }) it('should handle unicode characters in option labels', () => { @@ -1263,7 +1270,7 @@ describe('DataSourceOptions', () => { />, ) - expect(screen.getByText('数据源 📁 Source émoji')).toBeInTheDocument() + expect(screen.getByText('数据源 📁 Source émoji'))!.toBeInTheDocument() }) it('should handle empty string as option value', () => { @@ -1276,13 +1283,13 @@ describe('DataSourceOptions', () => { renderWithProviders() - expect(screen.getByText('Empty Value Option')).toBeInTheDocument() + expect(screen.getByText('Empty Value Option'))!.toBeInTheDocument() }) }) describe('Boundary Conditions', () => { it('should handle single option selection correctly', () => { - const singleOption = [createMockDatasourceOption(defaultNodes[0])] + const singleOption = [createMockDatasourceOption(defaultNodes[0]!)] mockUseDatasourceOptions.mockReturnValue(singleOption) const mockOnSelect = vi.fn() @@ -1327,7 +1334,7 @@ describe('DataSourceOptions', () => { const labels = screen.getAllByText('Duplicate Label') expect(labels).toHaveLength(2) - fireEvent.click(labels[1]) + fireEvent.click(labels[1]!) expect(mockOnSelect).toHaveBeenCalledWith({ nodeId: 'node-b', nodeData: expect.objectContaining({ plugin_id: 'plugin-b' }), @@ -1347,6 +1354,37 @@ describe('DataSourceOptions', () => { unmount() + // Assert - No errors thrown, component cleanly unmounted + // Assert - No errors thrown, component cleanly unmounted + // Assert - No errors thrown, component cleanly unmounted + // Assert - No errors thrown, component cleanly unmounted + // Assert - No errors thrown, component cleanly unmounted + // Assert - No errors thrown, component cleanly unmounted + // Assert - No errors thrown, component cleanly unmounted + // Assert - No errors thrown, component cleanly unmounted + // Assert - No errors thrown, component cleanly unmounted + // Assert - No errors thrown, component cleanly unmounted + // Assert - No errors thrown, component cleanly unmounted + // Assert - No errors thrown, component cleanly unmounted + // Assert - No errors thrown, component cleanly unmounted + // Assert - No errors thrown, component cleanly unmounted + // Assert - No errors thrown, component cleanly unmounted + // Assert - No errors thrown, component cleanly unmounted + // Assert - No errors thrown, component cleanly unmounted + // Assert - No errors thrown, component cleanly unmounted + // Assert - No errors thrown, component cleanly unmounted + // Assert - No errors thrown, component cleanly unmounted + // Assert - No errors thrown, component cleanly unmounted + // Assert - No errors thrown, component cleanly unmounted + // Assert - No errors thrown, component cleanly unmounted + // Assert - No errors thrown, component cleanly unmounted + // Assert - No errors thrown, component cleanly unmounted + // Assert - No errors thrown, component cleanly unmounted + // Assert - No errors thrown, component cleanly unmounted + // Assert - No errors thrown, component cleanly unmounted + // Assert - No errors thrown, component cleanly unmounted + // Assert - No errors thrown, component cleanly unmounted + // Assert - No errors thrown, component cleanly unmounted // Assert - No errors thrown, component cleanly unmounted expect(screen.queryByText('Data Source 1')).not.toBeInTheDocument() }) @@ -1367,6 +1405,37 @@ describe('DataSourceOptions', () => { // Unmount during/after interaction unmount() + // Assert - Should not throw + // Assert - Should not throw + // Assert - Should not throw + // Assert - Should not throw + // Assert - Should not throw + // Assert - Should not throw + // Assert - Should not throw + // Assert - Should not throw + // Assert - Should not throw + // Assert - Should not throw + // Assert - Should not throw + // Assert - Should not throw + // Assert - Should not throw + // Assert - Should not throw + // Assert - Should not throw + // Assert - Should not throw + // Assert - Should not throw + // Assert - Should not throw + // Assert - Should not throw + // Assert - Should not throw + // Assert - Should not throw + // Assert - Should not throw + // Assert - Should not throw + // Assert - Should not throw + // Assert - Should not throw + // Assert - Should not throw + // Assert - Should not throw + // Assert - Should not throw + // Assert - Should not throw + // Assert - Should not throw + // Assert - Should not throw // Assert - Should not throw expect(screen.queryByText('Data Source 1')).not.toBeInTheDocument() }) @@ -1392,7 +1461,7 @@ describe('DataSourceOptions', () => { const cards = container.querySelectorAll('.rounded-xl.border') expect(cards[0]).not.toHaveClass('border-components-option-card-option-selected-border') - expect(cards[1]).toHaveClass('border-components-option-card-option-selected-border') + expect(cards[1])!.toHaveClass('border-components-option-card-option-selected-border') expect(cards[2]).not.toHaveClass('border-components-option-card-option-selected-border') }) @@ -1427,7 +1496,7 @@ describe('DataSourceOptions', () => { />, ) - expect(screen.getByText('Data Source 1')).toBeInTheDocument() + expect(screen.getByText('Data Source 1'))!.toBeInTheDocument() }) it.each([ @@ -1449,7 +1518,7 @@ describe('DataSourceOptions', () => { ) if (count > 0) - expect(screen.getByText('Data Source 1')).toBeInTheDocument() + expect(screen.getByText('Data Source 1'))!.toBeInTheDocument() else expect(screen.queryByText('Data Source 1')).not.toBeInTheDocument() }) diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source-options/index.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source-options/index.tsx index 8e3a29cfe5..51cf34d273 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/data-source-options/index.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/data-source-options/index.tsx @@ -31,7 +31,7 @@ const DataSourceOptions = ({ useEffect(() => { if (options.length > 0 && !datasourceNodeId) - handelSelect(options[0].value) + handelSelect(options[0]!.value) }, []) return ( diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/base/__tests__/header.spec.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source/base/__tests__/header.spec.tsx index b736935cc8..a6abad358e 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/data-source/base/__tests__/header.spec.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/base/__tests__/header.spec.tsx @@ -2,7 +2,7 @@ import { render, screen } from '@testing-library/react' import { describe, expect, it, vi } from 'vitest' import Header from '../header' -vi.mock('@/app/components/base/ui/button', () => ({ +vi.mock('@langgenius/dify-ui/button', () => ({ Button: ({ children }: { children: React.ReactNode }) => , })) diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/base/credential-selector/__tests__/index.spec.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source/base/credential-selector/__tests__/index.spec.tsx index d595a50fe1..49b0cb0789 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/data-source/base/credential-selector/__tests__/index.spec.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/base/credential-selector/__tests__/index.spec.tsx @@ -97,8 +97,8 @@ describe('CredentialSelector', () => { render() - expect(screen.getByTestId('portal-root')).toBeInTheDocument() - expect(screen.getByTestId('portal-trigger')).toBeInTheDocument() + expect(screen.getByTestId('portal-root'))!.toBeInTheDocument() + expect(screen.getByTestId('portal-trigger'))!.toBeInTheDocument() }) it('should render current credential name in trigger', () => { @@ -106,7 +106,7 @@ describe('CredentialSelector', () => { render() - expect(screen.getByText('Credential 1')).toBeInTheDocument() + expect(screen.getByText('Credential 1'))!.toBeInTheDocument() }) it('should render credential icon with correct props', () => { @@ -116,8 +116,8 @@ describe('CredentialSelector', () => { // Assert - CredentialIcon renders an img when avatarUrl is provided const iconImg = container.querySelector('img') - expect(iconImg).toBeInTheDocument() - expect(iconImg).toHaveAttribute('src', 'https://example.com/avatar-1.png') + expect(iconImg)!.toBeInTheDocument() + expect(iconImg)!.toHaveAttribute('src', 'https://example.com/avatar-1.png') }) it('should render dropdown arrow icon', () => { @@ -126,7 +126,7 @@ describe('CredentialSelector', () => { const { container } = render() const svgIcon = container.querySelector('svg') - expect(svgIcon).toBeInTheDocument() + expect(svgIcon)!.toBeInTheDocument() }) it('should not render dropdown content initially', () => { @@ -146,7 +146,8 @@ describe('CredentialSelector', () => { fireEvent.click(trigger) // Assert - All credentials should be visible (current credential appears in both trigger and list) - expect(screen.getByTestId('portal-content')).toBeInTheDocument() + // Assert - All credentials should be visible (current credential appears in both trigger and list) + expect(screen.getByTestId('portal-content'))!.toBeInTheDocument() // 3 in dropdown list + 1 in trigger (current) = 4 total expect(screen.getAllByText(/Credential \d/)).toHaveLength(4) }) @@ -160,7 +161,7 @@ describe('CredentialSelector', () => { render() - expect(screen.getByText('Credential 1')).toBeInTheDocument() + expect(screen.getByText('Credential 1'))!.toBeInTheDocument() }) it('should display second credential when currentCredentialId matches second', () => { @@ -168,7 +169,7 @@ describe('CredentialSelector', () => { render() - expect(screen.getByText('Credential 2')).toBeInTheDocument() + expect(screen.getByText('Credential 2'))!.toBeInTheDocument() }) it('should display third credential when currentCredentialId matches third', () => { @@ -176,7 +177,7 @@ describe('CredentialSelector', () => { render() - expect(screen.getByText('Credential 3')).toBeInTheDocument() + expect(screen.getByText('Credential 3'))!.toBeInTheDocument() }) it.each([ @@ -188,7 +189,7 @@ describe('CredentialSelector', () => { render() - expect(screen.getByText(expectedName)).toBeInTheDocument() + expect(screen.getByText(expectedName))!.toBeInTheDocument() }) }) @@ -201,7 +202,7 @@ describe('CredentialSelector', () => { render() - expect(screen.getByText('Test Credential')).toBeInTheDocument() + expect(screen.getByText('Test Credential'))!.toBeInTheDocument() }) it('should render multiple credentials in dropdown', () => { @@ -226,7 +227,7 @@ describe('CredentialSelector', () => { render() - expect(screen.getByText('Test & Credential ')).toBeInTheDocument() + expect(screen.getByText('Test & Credential '))!.toBeInTheDocument() }) }) @@ -293,6 +294,37 @@ describe('CredentialSelector', () => { const props = createDefaultProps() render() + // Assert - Initially closed + // Assert - Initially closed + // Assert - Initially closed + // Assert - Initially closed + // Assert - Initially closed + // Assert - Initially closed + // Assert - Initially closed + // Assert - Initially closed + // Assert - Initially closed + // Assert - Initially closed + // Assert - Initially closed + // Assert - Initially closed + // Assert - Initially closed + // Assert - Initially closed + // Assert - Initially closed + // Assert - Initially closed + // Assert - Initially closed + // Assert - Initially closed + // Assert - Initially closed + // Assert - Initially closed + // Assert - Initially closed + // Assert - Initially closed + // Assert - Initially closed + // Assert - Initially closed + // Assert - Initially closed + // Assert - Initially closed + // Assert - Initially closed + // Assert - Initially closed + // Assert - Initially closed + // Assert - Initially closed + // Assert - Initially closed // Assert - Initially closed expect(screen.queryByTestId('portal-content')).not.toBeInTheDocument() @@ -301,7 +333,8 @@ describe('CredentialSelector', () => { fireEvent.click(trigger) // Assert - Now open - expect(screen.getByTestId('portal-content')).toBeInTheDocument() + // Assert - Now open + expect(screen.getByTestId('portal-content'))!.toBeInTheDocument() }) it('should call onCredentialChange when clicking a credential item', () => { @@ -327,7 +360,7 @@ describe('CredentialSelector', () => { const trigger = screen.getByTestId('portal-trigger') fireEvent.click(trigger) - expect(screen.getByTestId('portal-content')).toBeInTheDocument() + expect(screen.getByTestId('portal-content'))!.toBeInTheDocument() const credential2 = screen.getByText('Credential 2') fireEvent.click(credential2) @@ -347,7 +380,8 @@ describe('CredentialSelector', () => { fireEvent.click(trigger) // Assert - Should not crash - expect(trigger).toBeInTheDocument() + // Assert - Should not crash + expect(trigger)!.toBeInTheDocument() }) it('should allow selecting credentials multiple times', () => { @@ -504,7 +538,8 @@ describe('CredentialSelector', () => { render() // Assert - Should display credential 2 - expect(screen.getByText('Credential 2')).toBeInTheDocument() + // Assert - Should display credential 2 + expect(screen.getByText('Credential 2'))!.toBeInTheDocument() }) it('should update currentCredential when currentCredentialId changes', () => { @@ -512,13 +547,15 @@ describe('CredentialSelector', () => { const { rerender } = render() // Assert initial - expect(screen.getByText('Credential 1')).toBeInTheDocument() + // Assert initial + expect(screen.getByText('Credential 1'))!.toBeInTheDocument() // Act - Change currentCredentialId rerender() // Assert - Should now display credential 3 - expect(screen.getByText('Credential 3')).toBeInTheDocument() + // Assert - Should now display credential 3 + expect(screen.getByText('Credential 3'))!.toBeInTheDocument() }) it('should update currentCredential when credentials array changes', () => { @@ -526,7 +563,8 @@ describe('CredentialSelector', () => { const { rerender } = render() // Assert initial - expect(screen.getByText('Credential 1')).toBeInTheDocument() + // Assert initial + expect(screen.getByText('Credential 1'))!.toBeInTheDocument() // Act - Change credentials const newCredentials = [ @@ -535,7 +573,8 @@ describe('CredentialSelector', () => { rerender() // Assert - Should display updated name - expect(screen.getByText('Updated Credential 1')).toBeInTheDocument() + // Assert - Should display updated name + expect(screen.getByText('Updated Credential 1'))!.toBeInTheDocument() }) it('should return undefined currentCredential when id not found', () => { @@ -581,11 +620,12 @@ describe('CredentialSelector', () => { const { rerender } = render() // Assert initial - expect(screen.getByText('Credential 1')).toBeInTheDocument() + // Assert initial + expect(screen.getByText('Credential 1'))!.toBeInTheDocument() rerender() - expect(screen.getByText('Credential 2')).toBeInTheDocument() + expect(screen.getByText('Credential 2'))!.toBeInTheDocument() }) it('should re-render when credentials array reference changes', () => { @@ -598,7 +638,7 @@ describe('CredentialSelector', () => { ] rerender() - expect(screen.getByText('New Name 1')).toBeInTheDocument() + expect(screen.getByText('New Name 1'))!.toBeInTheDocument() }) it('should re-render when onCredentialChange reference changes', () => { @@ -631,7 +671,8 @@ describe('CredentialSelector', () => { render() // Assert - Should render without crashing - expect(screen.getByTestId('portal-root')).toBeInTheDocument() + // Assert - Should render without crashing + expect(screen.getByTestId('portal-root'))!.toBeInTheDocument() }) it('should handle undefined avatar_url in credential', () => { @@ -648,12 +689,14 @@ describe('CredentialSelector', () => { const { container } = render() // Assert - Should render without crashing and show first letter fallback - expect(screen.getByText('No Avatar Credential')).toBeInTheDocument() + // Assert - Should render without crashing and show first letter fallback + expect(screen.getByText('No Avatar Credential'))!.toBeInTheDocument() // When avatar_url is undefined, CredentialIcon shows first letter instead of img const iconImg = container.querySelector('img') expect(iconImg).not.toBeInTheDocument() // First letter 'N' should be displayed - expect(screen.getByText('N')).toBeInTheDocument() + // First letter 'N' should be displayed + expect(screen.getByText('N'))!.toBeInTheDocument() }) it('should handle empty string name in credential', () => { @@ -669,7 +712,8 @@ describe('CredentialSelector', () => { render() // Assert - Should render without crashing - expect(screen.getByTestId('portal-trigger')).toBeInTheDocument() + // Assert - Should render without crashing + expect(screen.getByTestId('portal-trigger'))!.toBeInTheDocument() }) it('should handle very long credential name', () => { @@ -685,7 +729,7 @@ describe('CredentialSelector', () => { render() - expect(screen.getByText(longName)).toBeInTheDocument() + expect(screen.getByText(longName))!.toBeInTheDocument() }) it('should handle special characters in credential name', () => { @@ -701,7 +745,7 @@ describe('CredentialSelector', () => { render() - expect(screen.getByText(specialName)).toBeInTheDocument() + expect(screen.getByText(specialName))!.toBeInTheDocument() }) it('should handle numeric id as string', () => { @@ -716,7 +760,7 @@ describe('CredentialSelector', () => { render() - expect(screen.getByText('Numeric ID Credential')).toBeInTheDocument() + expect(screen.getByText('Numeric ID Credential'))!.toBeInTheDocument() }) it('should handle large number of credentials', () => { @@ -728,7 +772,7 @@ describe('CredentialSelector', () => { render() - expect(screen.getByText('Credential 50')).toBeInTheDocument() + expect(screen.getByText('Credential 50'))!.toBeInTheDocument() }) it('should handle credential selection with duplicate names', () => { @@ -752,7 +796,7 @@ describe('CredentialSelector', () => { const sameNameElements = screen.getAllByText('Same Name') expect(sameNameElements.length).toBe(3) - fireEvent.click(sameNameElements[2]) + fireEvent.click(sameNameElements[2]!) // Assert - Should call with the correct id even with duplicate names expect(mockOnChange).toHaveBeenCalledWith('cred-2') @@ -787,7 +831,8 @@ describe('CredentialSelector', () => { render() // Assert - Should render without crashing - expect(screen.getByTestId('portal-trigger')).toBeInTheDocument() + // Assert - Should render without crashing + expect(screen.getByTestId('portal-trigger'))!.toBeInTheDocument() }) }) @@ -799,7 +844,7 @@ describe('CredentialSelector', () => { render() const trigger = screen.getByTestId('portal-trigger') - expect(trigger).toHaveClass('overflow-hidden') + expect(trigger)!.toHaveClass('overflow-hidden') }) it('should apply grow class to trigger', () => { @@ -808,7 +853,7 @@ describe('CredentialSelector', () => { render() const trigger = screen.getByTestId('portal-trigger') - expect(trigger).toHaveClass('grow') + expect(trigger)!.toHaveClass('grow') }) it('should apply z-10 class to dropdown content', () => { @@ -819,7 +864,7 @@ describe('CredentialSelector', () => { fireEvent.click(trigger) const content = screen.getByTestId('portal-content') - expect(content).toHaveClass('z-10') + expect(content)!.toHaveClass('z-10') }) }) @@ -831,7 +876,8 @@ describe('CredentialSelector', () => { render() // Assert - Trigger should display the correct credential - expect(screen.getByText('Credential 2')).toBeInTheDocument() + // Assert - Trigger should display the correct credential + expect(screen.getByText('Credential 2'))!.toBeInTheDocument() }) it('should pass isOpen state to Trigger component', () => { @@ -840,14 +886,15 @@ describe('CredentialSelector', () => { // Assert - Initially closed const portalRoot = screen.getByTestId('portal-root') - expect(portalRoot).toHaveAttribute('data-open', 'false') + expect(portalRoot)!.toHaveAttribute('data-open', 'false') // Act - Open const trigger = screen.getByTestId('portal-trigger') fireEvent.click(trigger) // Assert - Now open - expect(portalRoot).toHaveAttribute('data-open', 'true') + // Assert - Now open + expect(portalRoot)!.toHaveAttribute('data-open', 'true') }) it('should pass credentials to List component', () => { @@ -899,7 +946,7 @@ describe('CredentialSelector', () => { const props = createDefaultProps() render() - expect(screen.getByTestId('portal-root')).toBeInTheDocument() + expect(screen.getByTestId('portal-root'))!.toBeInTheDocument() }) it('should configure PortalToFollowElem with offset mainAxis 4', () => { @@ -907,7 +954,7 @@ describe('CredentialSelector', () => { const props = createDefaultProps() render() - expect(screen.getByTestId('portal-root')).toBeInTheDocument() + expect(screen.getByTestId('portal-root'))!.toBeInTheDocument() }) }) }) diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/base/credential-selector/index.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source/base/credential-selector/index.tsx index 2f14b0f3b8..116b762277 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/data-source/base/credential-selector/index.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/base/credential-selector/index.tsx @@ -29,7 +29,7 @@ const CredentialSelector = ({ useEffect(() => { if (!currentCredential && credentials.length) - onCredentialChange(credentials[0].id) + onCredentialChange(credentials[0]!.id) }, [currentCredential, credentials]) const handleCredentialChange = useCallback((credentialId: string) => { diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/base/credential-selector/item.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source/base/credential-selector/item.tsx index 4d54a04d1f..b162411f6c 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/data-source/base/credential-selector/item.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/base/credential-selector/item.tsx @@ -31,7 +31,7 @@ const Item = ({ name={name} size={20} /> - + {name} { diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/base/header.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source/base/header.tsx index 034556d96f..a285946272 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/data-source/base/header.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/base/header.tsx @@ -1,10 +1,10 @@ import type { CredentialSelectorProps } from './credential-selector' +import { Button } from '@langgenius/dify-ui/button' import { RiBookOpenLine, RiEqualizer2Line } from '@remixicon/react' import * as React from 'react' import { useTranslation } from 'react-i18next' import Divider from '@/app/components/base/divider' import Tooltip from '@/app/components/base/tooltip' -import { Button } from '@/app/components/base/ui/button' import CredentialSelector from './credential-selector' type HeaderProps = { diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/hooks/__tests__/use-local-file-upload.spec.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/hooks/__tests__/use-local-file-upload.spec.tsx index cc531aad8f..dc20688e9e 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/hooks/__tests__/use-local-file-upload.spec.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/hooks/__tests__/use-local-file-upload.spec.tsx @@ -18,7 +18,7 @@ const { mockNotify, mockToast } = vi.hoisted(() => { return { mockNotify, mockToast } }) -vi.mock('@/app/components/base/ui/toast', () => ({ +vi.mock('@langgenius/dify-ui/toast', () => ({ toast: mockToast, })) @@ -404,7 +404,7 @@ describe('useLocalFileUpload', () => { // Should only process first 5 files (batch_count_limit) const firstCall = mockSetLocalFileList.mock.calls[0] - expect(firstCall[0].length).toBeLessThanOrEqual(5) + expect(firstCall![0].length).toBeLessThanOrEqual(5) }) }) @@ -591,7 +591,8 @@ describe('useLocalFileUpload', () => { }) // dragover should not throw - expect(dropzone).toBeInTheDocument() + // dragover should not throw + expect(dropzone)!.toBeInTheDocument() }) it('should set dragging false on dragleave from drag overlay', async () => { @@ -715,7 +716,7 @@ describe('useLocalFileUpload', () => { await waitFor(() => { expect(mockSetLocalFileList).toHaveBeenCalled() // Should only have 1 file (limited by supportBatchUpload: false) - const callArgs = mockSetLocalFileList.mock.calls[0][0] + const callArgs = mockSetLocalFileList.mock.calls[0]![0] expect(callArgs.length).toBe(1) }) }) @@ -873,7 +874,7 @@ describe('useLocalFileUpload', () => { }) await waitFor(() => { - const callArgs = mockSetLocalFileList.mock.calls[0][0] + const callArgs = mockSetLocalFileList.mock.calls[0]![0] expect(callArgs[0].progress).toBe(PROGRESS_NOT_STARTED) }) }) @@ -899,7 +900,7 @@ describe('useLocalFileUpload', () => { await waitFor(() => { const calls = mockSetLocalFileList.mock.calls - const lastCall = calls[calls.length - 1][0] + const lastCall = calls[calls.length - 1]![0] expect(lastCall.some((f: FileItem) => f.progress === PROGRESS_ERROR)).toBe(true) }) }) diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/online-documents/__tests__/index.spec.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source/online-documents/__tests__/index.spec.tsx index 6be0e28d31..c193638a6a 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/data-source/online-documents/__tests__/index.spec.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/online-documents/__tests__/index.spec.tsx @@ -37,8 +37,8 @@ const { mockToastError } = vi.hoisted(() => ({ mockToastError: vi.fn(), })) -vi.mock('@/app/components/base/ui/toast', async (importOriginal) => { - const actual = await importOriginal() +vi.mock('@langgenius/dify-ui/toast', async (importOriginal) => { + const actual = await importOriginal() return { ...actual, toast: { diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/online-documents/index.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source/online-documents/index.tsx index 5051d343cb..22bc8a65e0 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/data-source/online-documents/index.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/online-documents/index.tsx @@ -1,11 +1,11 @@ import type { DataSourceNodeType } from '@/app/components/workflow/nodes/data-source/types' import type { DataSourceNotionPageMap, DataSourceNotionWorkspace } from '@/models/common' import type { DataSourceNodeCompletedResponse, DataSourceNodeErrorResponse } from '@/types/pipeline' +import { toast } from '@langgenius/dify-ui/toast' import { useCallback, useEffect, useMemo } from 'react' import { useShallow } from 'zustand/react/shallow' import Loading from '@/app/components/base/loading' import SearchInput from '@/app/components/base/notion-page-selector/search-input' -import { toast } from '@/app/components/base/ui/toast' import { ACCOUNT_SETTING_TAB } from '@/app/components/header/account-setting/constants' import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail' import { useDocLink } from '@/context/i18n' @@ -115,7 +115,7 @@ const OnlineDocuments = ({ const handleSelectPages = useCallback((newSelectedPagesId: Set) => { const { setSelectedPagesId, setOnlineDocuments } = dataSourceStore.getState() - const selectedPages = Array.from(newSelectedPagesId).map(pageId => PagesMapAndSelectedPagesId[pageId]) + const selectedPages = Array.from(newSelectedPagesId).map(pageId => PagesMapAndSelectedPagesId[pageId]!) setSelectedPagesId(new Set(Array.from(newSelectedPagesId))) setOnlineDocuments(selectedPages) }, [dataSourceStore, PagesMapAndSelectedPagesId]) @@ -160,7 +160,7 @@ const OnlineDocuments = ({ checkedIds={selectedPagesId} disabledValue={new Set()} searchValue={searchValue} - list={documentsData[0].pages || []} + list={documentsData[0]!.pages || []} pagesMap={PagesMapAndSelectedPagesId} onSelect={handleSelectPages} canPreview={!isInPipeline} diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/online-documents/page-selector/__tests__/index.spec.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source/online-documents/page-selector/__tests__/index.spec.tsx index a6d5738e2d..04676156e6 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/data-source/online-documents/page-selector/__tests__/index.spec.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/online-documents/page-selector/__tests__/index.spec.tsx @@ -83,7 +83,7 @@ describe('PageSelector', () => { render() - expect(screen.getByTestId('virtual-list')).toBeInTheDocument() + expect(screen.getByTestId('virtual-list'))!.toBeInTheDocument() }) it('should render empty state when list is empty', () => { @@ -94,7 +94,7 @@ describe('PageSelector', () => { render() - expect(screen.getByText('common.dataSource.notion.selector.noSearchResult')).toBeInTheDocument() + expect(screen.getByText('common.dataSource.notion.selector.noSearchResult'))!.toBeInTheDocument() expect(screen.queryByTestId('virtual-list')).not.toBeInTheDocument() }) @@ -110,8 +110,8 @@ describe('PageSelector', () => { render() - expect(screen.getByText('Page 1')).toBeInTheDocument() - expect(screen.getByText('Page 2')).toBeInTheDocument() + expect(screen.getByText('Page 1'))!.toBeInTheDocument() + expect(screen.getByText('Page 2'))!.toBeInTheDocument() }) it('should render checkboxes when isMultipleChoice is true', () => { @@ -119,7 +119,7 @@ describe('PageSelector', () => { render() - expect(getCheckbox()).toBeInTheDocument() + expect(getCheckbox())!.toBeInTheDocument() }) it('should render radio buttons when isMultipleChoice is false', () => { @@ -127,7 +127,7 @@ describe('PageSelector', () => { render() - expect(getRadio()).toBeInTheDocument() + expect(getRadio())!.toBeInTheDocument() }) it('should render preview button when canPreview is true', () => { @@ -135,7 +135,7 @@ describe('PageSelector', () => { render() - expect(screen.getByText('common.dataSource.notion.selector.preview')).toBeInTheDocument() + expect(screen.getByText('common.dataSource.notion.selector.preview'))!.toBeInTheDocument() }) it('should not render preview button when canPreview is false', () => { @@ -153,7 +153,7 @@ describe('PageSelector', () => { // Assert - NotionIcon renders svg when page_icon is null const notionIcon = document.querySelector('.h-5.w-5') - expect(notionIcon).toBeInTheDocument() + expect(notionIcon)!.toBeInTheDocument() }) it('should render page name', () => { @@ -164,7 +164,7 @@ describe('PageSelector', () => { render() - expect(screen.getByText('My Custom Page')).toBeInTheDocument() + expect(screen.getByText('My Custom Page'))!.toBeInTheDocument() }) }) @@ -181,7 +181,7 @@ describe('PageSelector', () => { render() const checkbox = getCheckbox() - expect(checkbox).toBeInTheDocument() + expect(checkbox)!.toBeInTheDocument() expect(isCheckboxChecked(checkbox)).toBe(true) }) @@ -196,7 +196,7 @@ describe('PageSelector', () => { render() const checkbox = getCheckbox() - expect(checkbox).toBeInTheDocument() + expect(checkbox)!.toBeInTheDocument() expect(isCheckboxChecked(checkbox)).toBe(false) }) @@ -206,7 +206,7 @@ describe('PageSelector', () => { render() const checkbox = getCheckbox() - expect(checkbox).toBeInTheDocument() + expect(checkbox)!.toBeInTheDocument() expect(isCheckboxChecked(checkbox)).toBe(false) }) @@ -225,9 +225,9 @@ describe('PageSelector', () => { render() const checkboxes = getAllCheckboxes() - expect(isCheckboxChecked(checkboxes[0])).toBe(true) - expect(isCheckboxChecked(checkboxes[1])).toBe(false) - expect(isCheckboxChecked(checkboxes[2])).toBe(true) + expect(isCheckboxChecked(checkboxes[0]!)).toBe(true) + expect(isCheckboxChecked(checkboxes[1]!)).toBe(false) + expect(isCheckboxChecked(checkboxes[2]!)).toBe(true) }) }) @@ -243,7 +243,7 @@ describe('PageSelector', () => { render() const checkbox = getCheckbox() - expect(checkbox).toBeInTheDocument() + expect(checkbox)!.toBeInTheDocument() expect(isCheckboxDisabled(checkbox)).toBe(true) }) @@ -258,7 +258,7 @@ describe('PageSelector', () => { render() const checkbox = getCheckbox() - expect(checkbox).toBeInTheDocument() + expect(checkbox)!.toBeInTheDocument() expect(isCheckboxDisabled(checkbox)).toBe(false) }) @@ -276,8 +276,8 @@ describe('PageSelector', () => { render() const checkboxes = getAllCheckboxes() - expect(isCheckboxDisabled(checkboxes[0])).toBe(true) - expect(isCheckboxDisabled(checkboxes[1])).toBe(false) + expect(isCheckboxDisabled(checkboxes[0]!)).toBe(true) + expect(isCheckboxDisabled(checkboxes[1]!)).toBe(false) }) }) @@ -301,6 +301,37 @@ describe('PageSelector', () => { expect(screen.getAllByText('Apple Page').length).toBeGreaterThan(0) expect(screen.getAllByText('Apple Pie').length).toBeGreaterThan(0) // Banana Page is filtered out because it doesn't contain "Apple" + // Banana Page is filtered out because it doesn't contain "Apple" + // Banana Page is filtered out because it doesn't contain "Apple" + // Banana Page is filtered out because it doesn't contain "Apple" + // Banana Page is filtered out because it doesn't contain "Apple" + // Banana Page is filtered out because it doesn't contain "Apple" + // Banana Page is filtered out because it doesn't contain "Apple" + // Banana Page is filtered out because it doesn't contain "Apple" + // Banana Page is filtered out because it doesn't contain "Apple" + // Banana Page is filtered out because it doesn't contain "Apple" + // Banana Page is filtered out because it doesn't contain "Apple" + // Banana Page is filtered out because it doesn't contain "Apple" + // Banana Page is filtered out because it doesn't contain "Apple" + // Banana Page is filtered out because it doesn't contain "Apple" + // Banana Page is filtered out because it doesn't contain "Apple" + // Banana Page is filtered out because it doesn't contain "Apple" + // Banana Page is filtered out because it doesn't contain "Apple" + // Banana Page is filtered out because it doesn't contain "Apple" + // Banana Page is filtered out because it doesn't contain "Apple" + // Banana Page is filtered out because it doesn't contain "Apple" + // Banana Page is filtered out because it doesn't contain "Apple" + // Banana Page is filtered out because it doesn't contain "Apple" + // Banana Page is filtered out because it doesn't contain "Apple" + // Banana Page is filtered out because it doesn't contain "Apple" + // Banana Page is filtered out because it doesn't contain "Apple" + // Banana Page is filtered out because it doesn't contain "Apple" + // Banana Page is filtered out because it doesn't contain "Apple" + // Banana Page is filtered out because it doesn't contain "Apple" + // Banana Page is filtered out because it doesn't contain "Apple" + // Banana Page is filtered out because it doesn't contain "Apple" + // Banana Page is filtered out because it doesn't contain "Apple" + // Banana Page is filtered out because it doesn't contain "Apple" expect(screen.queryByText('Banana Page')).not.toBeInTheDocument() }) @@ -314,7 +345,7 @@ describe('PageSelector', () => { render() - expect(screen.getByText('common.dataSource.notion.selector.noSearchResult')).toBeInTheDocument() + expect(screen.getByText('common.dataSource.notion.selector.noSearchResult'))!.toBeInTheDocument() }) it('should show all pages when searchValue is empty', () => { @@ -330,8 +361,8 @@ describe('PageSelector', () => { render() - expect(screen.getByText('Page 1')).toBeInTheDocument() - expect(screen.getByText('Page 2')).toBeInTheDocument() + expect(screen.getByText('Page 1'))!.toBeInTheDocument() + expect(screen.getByText('Page 2'))!.toBeInTheDocument() }) it('should show breadcrumbs when searchValue is present', () => { @@ -345,7 +376,8 @@ describe('PageSelector', () => { render() // Assert - page name should be visible - expect(screen.getByText('Grandchild 1')).toBeInTheDocument() + // Assert - page name should be visible + expect(screen.getByText('Grandchild 1'))!.toBeInTheDocument() }) it('should perform case-sensitive search', () => { @@ -374,7 +406,7 @@ describe('PageSelector', () => { render() - expect(screen.getByText('common.dataSource.notion.selector.preview')).toBeInTheDocument() + expect(screen.getByText('common.dataSource.notion.selector.preview'))!.toBeInTheDocument() }) it('should hide preview button when canPreview is false', () => { @@ -391,7 +423,7 @@ describe('PageSelector', () => { render() - expect(screen.getByText('common.dataSource.notion.selector.preview')).toBeInTheDocument() + expect(screen.getByText('common.dataSource.notion.selector.preview'))!.toBeInTheDocument() }) }) @@ -401,7 +433,7 @@ describe('PageSelector', () => { render() - expect(getCheckbox()).toBeInTheDocument() + expect(getCheckbox())!.toBeInTheDocument() expect(getRadio()).not.toBeInTheDocument() }) @@ -410,7 +442,7 @@ describe('PageSelector', () => { render() - expect(getRadio()).toBeInTheDocument() + expect(getRadio())!.toBeInTheDocument() expect(getCheckbox()).not.toBeInTheDocument() }) @@ -420,7 +452,7 @@ describe('PageSelector', () => { render() - expect(getCheckbox()).toBeInTheDocument() + expect(getCheckbox())!.toBeInTheDocument() }) }) @@ -449,7 +481,7 @@ describe('PageSelector', () => { render() fireEvent.click(getCheckbox()) - const calledSet = mockOnSelect.mock.calls[0][0] as Set + const calledSet = mockOnSelect.mock.calls[0]![0] as Set expect(calledSet.has('page-1')).toBe(true) }) }) @@ -498,13 +530,15 @@ describe('PageSelector', () => { const { rerender } = render() // Assert - Initial render - expect(screen.getByText('Page 1')).toBeInTheDocument() + // Assert - Initial render + expect(screen.getByText('Page 1'))!.toBeInTheDocument() // Rerender with new credential rerender() // Assert - Should still show pages (reset and rebuild) - expect(screen.getByText('Page 1')).toBeInTheDocument() + // Assert - Should still show pages (reset and rebuild) + expect(screen.getByText('Page 1'))!.toBeInTheDocument() }) }) }) @@ -521,7 +555,39 @@ describe('PageSelector', () => { render() // Assert - Only root level page should be visible initially - expect(screen.getByText(rootPage.page_name)).toBeInTheDocument() + // Assert - Only root level page should be visible initially + expect(screen.getByText(rootPage.page_name))!.toBeInTheDocument() + // Child pages should not be visible until expanded + // Child pages should not be visible until expanded + // Child pages should not be visible until expanded + // Child pages should not be visible until expanded + // Child pages should not be visible until expanded + // Child pages should not be visible until expanded + // Child pages should not be visible until expanded + // Child pages should not be visible until expanded + // Child pages should not be visible until expanded + // Child pages should not be visible until expanded + // Child pages should not be visible until expanded + // Child pages should not be visible until expanded + // Child pages should not be visible until expanded + // Child pages should not be visible until expanded + // Child pages should not be visible until expanded + // Child pages should not be visible until expanded + // Child pages should not be visible until expanded + // Child pages should not be visible until expanded + // Child pages should not be visible until expanded + // Child pages should not be visible until expanded + // Child pages should not be visible until expanded + // Child pages should not be visible until expanded + // Child pages should not be visible until expanded + // Child pages should not be visible until expanded + // Child pages should not be visible until expanded + // Child pages should not be visible until expanded + // Child pages should not be visible until expanded + // Child pages should not be visible until expanded + // Child pages should not be visible until expanded + // Child pages should not be visible until expanded + // Child pages should not be visible until expanded // Child pages should not be visible until expanded expect(screen.queryByText(childPage1.page_name)).not.toBeInTheDocument() }) @@ -540,9 +606,9 @@ describe('PageSelector', () => { if (arrowButton) fireEvent.click(arrowButton) - expect(screen.getByText(rootPage.page_name)).toBeInTheDocument() - expect(screen.getByText(childPage1.page_name)).toBeInTheDocument() - expect(screen.getByText(childPage2.page_name)).toBeInTheDocument() + expect(screen.getByText(rootPage.page_name))!.toBeInTheDocument() + expect(screen.getByText(childPage1.page_name))!.toBeInTheDocument() + expect(screen.getByText(childPage2.page_name))!.toBeInTheDocument() }) it('should maintain currentPreviewPageId state', () => { @@ -560,7 +626,7 @@ describe('PageSelector', () => { render() const previewButtons = screen.getAllByText('common.dataSource.notion.selector.preview') - fireEvent.click(previewButtons[0]) + fireEvent.click(previewButtons[0]!) expect(mockOnPreview).toHaveBeenCalledWith('page-1') }) @@ -596,13 +662,14 @@ describe('PageSelector', () => { }) const { rerender } = render() - expect(screen.getByText('Page 1')).toBeInTheDocument() + expect(screen.getByText('Page 1'))!.toBeInTheDocument() // Change credential rerender() // Assert - Component should still render correctly - expect(screen.getByText('Page 1')).toBeInTheDocument() + // Assert - Component should still render correctly + expect(screen.getByText('Page 1'))!.toBeInTheDocument() }) it('should filter root pages correctly on initialization', () => { @@ -615,7 +682,8 @@ describe('PageSelector', () => { render() // Assert - Only root level pages visible - expect(screen.getByText(rootPage.page_name)).toBeInTheDocument() + // Assert - Only root level pages visible + expect(screen.getByText(rootPage.page_name))!.toBeInTheDocument() expect(screen.queryByText(childPage1.page_name)).not.toBeInTheDocument() }) @@ -633,7 +701,8 @@ describe('PageSelector', () => { render() // Assert - Orphan page should be visible at root level - expect(screen.getByText('Orphan Page')).toBeInTheDocument() + // Assert - Orphan page should be visible at root level + expect(screen.getByText('Orphan Page'))!.toBeInTheDocument() }) }) @@ -654,8 +723,9 @@ describe('PageSelector', () => { fireEvent.click(expandArrow) // Assert - Children should be visible - expect(screen.getByText(childPage1.page_name)).toBeInTheDocument() - expect(screen.getByText(childPage2.page_name)).toBeInTheDocument() + // Assert - Children should be visible + expect(screen.getByText(childPage1.page_name))!.toBeInTheDocument() + expect(screen.getByText(childPage2.page_name))!.toBeInTheDocument() }) it('should have stable handleToggle that collapses descendants', () => { @@ -675,6 +745,37 @@ describe('PageSelector', () => { fireEvent.click(expandArrow) } + // Assert - Children should be hidden again + // Assert - Children should be hidden again + // Assert - Children should be hidden again + // Assert - Children should be hidden again + // Assert - Children should be hidden again + // Assert - Children should be hidden again + // Assert - Children should be hidden again + // Assert - Children should be hidden again + // Assert - Children should be hidden again + // Assert - Children should be hidden again + // Assert - Children should be hidden again + // Assert - Children should be hidden again + // Assert - Children should be hidden again + // Assert - Children should be hidden again + // Assert - Children should be hidden again + // Assert - Children should be hidden again + // Assert - Children should be hidden again + // Assert - Children should be hidden again + // Assert - Children should be hidden again + // Assert - Children should be hidden again + // Assert - Children should be hidden again + // Assert - Children should be hidden again + // Assert - Children should be hidden again + // Assert - Children should be hidden again + // Assert - Children should be hidden again + // Assert - Children should be hidden again + // Assert - Children should be hidden again + // Assert - Children should be hidden again + // Assert - Children should be hidden again + // Assert - Children should be hidden again + // Assert - Children should be hidden again // Assert - Children should be hidden again expect(screen.queryByText(childPage1.page_name)).not.toBeInTheDocument() expect(screen.queryByText(childPage2.page_name)).not.toBeInTheDocument() @@ -698,7 +799,7 @@ describe('PageSelector', () => { // Assert - onSelect should be called with the page and its descendants expect(mockOnSelect).toHaveBeenCalled() - const selectedSet = mockOnSelect.mock.calls[0][0] as Set + const selectedSet = mockOnSelect.mock.calls[0]![0] as Set expect(selectedSet.has('root-page')).toBe(true) }) @@ -752,7 +853,7 @@ describe('PageSelector', () => { // Assert - Tree structure should be built (verified by expand functionality) const expandArrow = document.querySelector('[class*="hover:bg-components-button-ghost-bg-hover"]') - expect(expandArrow).toBeInTheDocument() // Root page has children + expect(expandArrow)!.toBeInTheDocument() // Root page has children }) it('should recompute listMapWithChildrenAndDescendants when list changes', () => { @@ -763,7 +864,7 @@ describe('PageSelector', () => { }) const { rerender } = render() - expect(screen.getByText('Page 1')).toBeInTheDocument() + expect(screen.getByText('Page 1'))!.toBeInTheDocument() // Update with new list const newList = [ @@ -772,7 +873,7 @@ describe('PageSelector', () => { ] rerender() - expect(screen.getByText('Page 1')).toBeInTheDocument() + expect(screen.getByText('Page 1'))!.toBeInTheDocument() // Page 2 won't show because dataList state hasn't updated (only resets on credentialId change) }) @@ -793,7 +894,8 @@ describe('PageSelector', () => { rerender() // Assert - Should not throw - expect(screen.getByText('Page 1')).toBeInTheDocument() + // Assert - Should not throw + expect(screen.getByText('Page 1'))!.toBeInTheDocument() }) it('should handle empty list in memoization', () => { @@ -804,7 +906,7 @@ describe('PageSelector', () => { render() - expect(screen.getByText('common.dataSource.notion.selector.noSearchResult')).toBeInTheDocument() + expect(screen.getByText('common.dataSource.notion.selector.noSearchResult'))!.toBeInTheDocument() }) }) @@ -819,6 +921,37 @@ describe('PageSelector', () => { render() + // Initially children are hidden + // Initially children are hidden + // Initially children are hidden + // Initially children are hidden + // Initially children are hidden + // Initially children are hidden + // Initially children are hidden + // Initially children are hidden + // Initially children are hidden + // Initially children are hidden + // Initially children are hidden + // Initially children are hidden + // Initially children are hidden + // Initially children are hidden + // Initially children are hidden + // Initially children are hidden + // Initially children are hidden + // Initially children are hidden + // Initially children are hidden + // Initially children are hidden + // Initially children are hidden + // Initially children are hidden + // Initially children are hidden + // Initially children are hidden + // Initially children are hidden + // Initially children are hidden + // Initially children are hidden + // Initially children are hidden + // Initially children are hidden + // Initially children are hidden + // Initially children are hidden // Initially children are hidden expect(screen.queryByText(childPage1.page_name)).not.toBeInTheDocument() @@ -827,7 +960,8 @@ describe('PageSelector', () => { fireEvent.click(expandArrow) // Children become visible - expect(screen.getByText(childPage1.page_name)).toBeInTheDocument() + // Children become visible + expect(screen.getByText(childPage1.page_name))!.toBeInTheDocument() }) it('should check/uncheck page when clicking checkbox', () => { @@ -873,11 +1007,11 @@ describe('PageSelector', () => { render() const radios = getAllRadios() - fireEvent.click(radios[1]) // Click on page-2 + fireEvent.click(radios[1]!) // Click on page-2 // Assert - Should clear page-1 and select page-2 expect(mockOnSelect).toHaveBeenCalled() - const selectedSet = mockOnSelect.mock.calls[0][0] as Set + const selectedSet = mockOnSelect.mock.calls[0]![0] as Set expect(selectedSet.has('page-2')).toBe(true) expect(selectedSet.has('page-1')).toBe(false) }) @@ -912,7 +1046,7 @@ describe('PageSelector', () => { // Assert - Only the clicked page should be selected (no descendants) expect(mockOnSelect).toHaveBeenCalled() - const selectedSet = mockOnSelect.mock.calls[0][0] as Set + const selectedSet = mockOnSelect.mock.calls[0]![0] as Set expect(selectedSet.size).toBe(1) expect(selectedSet.has('root-page')).toBe(true) }) @@ -927,7 +1061,7 @@ describe('PageSelector', () => { render() - expect(screen.getByText('common.dataSource.notion.selector.noSearchResult')).toBeInTheDocument() + expect(screen.getByText('common.dataSource.notion.selector.noSearchResult'))!.toBeInTheDocument() }) it('should handle null page_icon', () => { @@ -941,7 +1075,7 @@ describe('PageSelector', () => { // Assert - NotionIcon renders svg (RiFileTextLine) when page_icon is null const notionIcon = document.querySelector('.h-5.w-5') - expect(notionIcon).toBeInTheDocument() + expect(notionIcon)!.toBeInTheDocument() }) it('should handle page_icon with all properties', () => { @@ -956,7 +1090,8 @@ describe('PageSelector', () => { render() // Assert - NotionIcon renders the emoji - expect(screen.getByText('📄')).toBeInTheDocument() + // Assert - NotionIcon renders the emoji + expect(screen.getByText('📄'))!.toBeInTheDocument() }) it('should handle empty searchValue correctly', () => { @@ -964,7 +1099,7 @@ describe('PageSelector', () => { render() - expect(screen.getByTestId('virtual-list')).toBeInTheDocument() + expect(screen.getByTestId('virtual-list'))!.toBeInTheDocument() }) it('should handle special characters in page name', () => { @@ -976,7 +1111,7 @@ describe('PageSelector', () => { render() - expect(screen.getByText('Test ')).toBeInTheDocument() + expect(screen.getByText('Test '))!.toBeInTheDocument() }) it('should handle unicode characters in page name', () => { @@ -988,7 +1123,7 @@ describe('PageSelector', () => { render() - expect(screen.getByText('测试页面 🔍 привет')).toBeInTheDocument() + expect(screen.getByText('测试页面 🔍 привет'))!.toBeInTheDocument() }) it('should handle very long page names', () => { @@ -1001,7 +1136,7 @@ describe('PageSelector', () => { render() - expect(screen.getByText(longName)).toBeInTheDocument() + expect(screen.getByText(longName))!.toBeInTheDocument() }) it('should handle deeply nested hierarchy', () => { @@ -1027,7 +1162,8 @@ describe('PageSelector', () => { render() // Assert - Only root level visible - expect(screen.getByText('Level 0')).toBeInTheDocument() + // Assert - Only root level visible + expect(screen.getByText('Level 0'))!.toBeInTheDocument() expect(screen.queryByText('Level 1')).not.toBeInTheDocument() }) @@ -1048,7 +1184,8 @@ describe('PageSelector', () => { render() // Assert - Should render the orphan page at root level - expect(screen.getByText('Orphan Page')).toBeInTheDocument() + // Assert - Should render the orphan page at root level + expect(screen.getByText('Orphan Page'))!.toBeInTheDocument() }) it('should handle empty checkedIds Set', () => { @@ -1057,7 +1194,7 @@ describe('PageSelector', () => { render() const checkbox = getCheckbox() - expect(checkbox).toBeInTheDocument() + expect(checkbox)!.toBeInTheDocument() expect(isCheckboxChecked(checkbox)).toBe(false) }) @@ -1067,7 +1204,7 @@ describe('PageSelector', () => { render() const checkbox = getCheckbox() - expect(checkbox).toBeInTheDocument() + expect(checkbox)!.toBeInTheDocument() expect(isCheckboxDisabled(checkbox)).toBe(false) }) @@ -1112,16 +1249,16 @@ describe('PageSelector', () => { render() - expect(screen.getByTestId('virtual-list')).toBeInTheDocument() + expect(screen.getByTestId('virtual-list'))!.toBeInTheDocument() if (propVariation.canPreview) - expect(screen.getByText('common.dataSource.notion.selector.preview')).toBeInTheDocument() + expect(screen.getByText('common.dataSource.notion.selector.preview'))!.toBeInTheDocument() else expect(screen.queryByText('common.dataSource.notion.selector.preview')).not.toBeInTheDocument() if (propVariation.isMultipleChoice) - expect(getCheckbox()).toBeInTheDocument() + expect(getCheckbox())!.toBeInTheDocument() else - expect(getRadio()).toBeInTheDocument() + expect(getRadio())!.toBeInTheDocument() }) it('should handle all default prop values', () => { @@ -1140,8 +1277,9 @@ describe('PageSelector', () => { render() // Assert - Defaults should be applied - expect(getCheckbox()).toBeInTheDocument() - expect(screen.getByText('common.dataSource.notion.selector.preview')).toBeInTheDocument() + // Assert - Defaults should be applied + expect(getCheckbox())!.toBeInTheDocument() + expect(screen.getByText('common.dataSource.notion.selector.preview'))!.toBeInTheDocument() }) }) @@ -1166,8 +1304,8 @@ describe('PageSelector', () => { recursivePushInParentDescendants(pagesMap, listTreeMap, childEntry, childEntry) expect(listTreeMap.parent).toBeDefined() - expect(listTreeMap.parent.children.has('child')).toBe(true) - expect(listTreeMap.parent.descendants.has('child')).toBe(true) + expect(listTreeMap.parent!.children.has('child')).toBe(true) + expect(listTreeMap.parent!.descendants.has('child')).toBe(true) expect(childEntry.depth).toBe(1) expect(childEntry.ancestors).toContain('Parent') }) @@ -1274,8 +1412,8 @@ describe('PageSelector', () => { expect(l2Entry.depth).toBe(2) expect(l2Entry.ancestors).toEqual(['Level 0', 'Level 1']) - expect(listTreeMap.l1.children.has('l2')).toBe(true) - expect(listTreeMap.l0.descendants.has('l2')).toBe(true) + expect(listTreeMap.l1!.children.has('l2')).toBe(true) + expect(listTreeMap.l0!.descendants.has('l2')).toBe(true) }) it('should update existing parent entry', () => { @@ -1329,7 +1467,7 @@ describe('PageSelector', () => { // Assert - Item should have preview styling class const itemContainer = screen.getByText('Test Page').closest('[class*="group"]') - expect(itemContainer).toHaveClass('bg-state-base-hover') + expect(itemContainer)!.toHaveClass('bg-state-base-hover') }) it('should show arrow for pages with children', () => { @@ -1343,7 +1481,7 @@ describe('PageSelector', () => { // Assert - Root page should have expand arrow const arrowContainer = document.querySelector('[class*="hover:bg-components-button-ghost-bg-hover"]') - expect(arrowContainer).toBeInTheDocument() + expect(arrowContainer)!.toBeInTheDocument() }) it('should not show arrow for leaf pages', () => { diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/online-documents/page-selector/__tests__/utils.spec.ts b/web/app/components/datasets/documents/create-from-pipeline/data-source/online-documents/page-selector/__tests__/utils.spec.ts index 2a081ef418..a7175a47de 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/data-source/online-documents/page-selector/__tests__/utils.spec.ts +++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/online-documents/page-selector/__tests__/utils.spec.ts @@ -28,11 +28,11 @@ describe('recursivePushInParentDescendants', () => { child1: makePageEntry({ page_id: 'child1', parent_id: 'parent1', page_name: 'Child' }), } - recursivePushInParentDescendants(pagesMap, listTreeMap, listTreeMap.child1, listTreeMap.child1) + recursivePushInParentDescendants(pagesMap, listTreeMap, listTreeMap.child1!, listTreeMap.child1!) expect(listTreeMap.parent1).toBeDefined() - expect(listTreeMap.parent1.children.has('child1')).toBe(true) - expect(listTreeMap.parent1.descendants.has('child1')).toBe(true) + expect(listTreeMap.parent1!.children.has('child1')).toBe(true) + expect(listTreeMap.parent1!.descendants.has('child1')).toBe(true) }) it('should recursively populate ancestors for deeply nested items', () => { @@ -47,11 +47,11 @@ describe('recursivePushInParentDescendants', () => { child: makePageEntry({ page_id: 'child', parent_id: 'parent', page_name: 'Child' }), } - recursivePushInParentDescendants(pagesMap, listTreeMap, listTreeMap.child, listTreeMap.child) + recursivePushInParentDescendants(pagesMap, listTreeMap, listTreeMap.child!, listTreeMap.child!) - expect(listTreeMap.child.depth).toBe(2) - expect(listTreeMap.child.ancestors).toContain('Grandparent') - expect(listTreeMap.child.ancestors).toContain('Parent') + expect(listTreeMap.child!.depth).toBe(2) + expect(listTreeMap.child!.ancestors).toContain('Grandparent') + expect(listTreeMap.child!.ancestors).toContain('Parent') }) it('should do nothing for root parent', () => { @@ -63,7 +63,7 @@ describe('recursivePushInParentDescendants', () => { root_child: makePageEntry({ page_id: 'root_child', parent_id: 'root', page_name: 'Root Child' }), } - recursivePushInParentDescendants(pagesMap, listTreeMap, listTreeMap.root_child, listTreeMap.root_child) + recursivePushInParentDescendants(pagesMap, listTreeMap, listTreeMap.root_child!, listTreeMap.root_child!) // No new entries should be added since parent is root expect(Object.keys(listTreeMap)).toEqual(['root_child']) @@ -76,7 +76,7 @@ describe('recursivePushInParentDescendants', () => { // Should not throw recursivePushInParentDescendants(pagesMap, listTreeMap, current, current) - expect(listTreeMap.orphan.depth).toBe(0) + expect(listTreeMap.orphan!.depth).toBe(0) }) it('should add to existing parent entry when parent already in tree', () => { @@ -91,10 +91,10 @@ describe('recursivePushInParentDescendants', () => { child2: makePageEntry({ page_id: 'child2', parent_id: 'parent', page_name: 'Child2' }), } - recursivePushInParentDescendants(pagesMap, listTreeMap, listTreeMap.child2, listTreeMap.child2) + recursivePushInParentDescendants(pagesMap, listTreeMap, listTreeMap.child2!, listTreeMap.child2!) - expect(listTreeMap.parent.children.has('child2')).toBe(true) - expect(listTreeMap.parent.descendants.has('child2')).toBe(true) - expect(listTreeMap.parent.children.has('child1')).toBe(true) + expect(listTreeMap.parent!.children.has('child2')).toBe(true) + expect(listTreeMap.parent!.descendants.has('child2')).toBe(true) + expect(listTreeMap.parent!.children.has('child1')).toBe(true) }) }) diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/online-documents/title.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source/online-documents/title.tsx index 4f555f3e1f..2d4eb81212 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/data-source/online-documents/title.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/online-documents/title.tsx @@ -11,7 +11,7 @@ const Title = ({ const { t } = useTranslation() return ( -
+
{t('onlineDocument.pageSelectorTitle', { ns: 'datasetPipeline', name })}
) diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/__tests__/index.spec.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/__tests__/index.spec.tsx index 7c1941afd9..c8fdf49fd1 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/__tests__/index.spec.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/__tests__/index.spec.tsx @@ -49,8 +49,8 @@ const { mockToastError } = vi.hoisted(() => ({ mockToastError: vi.fn(), })) -vi.mock('@/app/components/base/ui/toast', async (importOriginal) => { - const actual = await importOriginal() +vi.mock('@langgenius/dify-ui/toast', async (importOriginal) => { + const actual = await importOriginal() return { ...actual, toast: { @@ -259,8 +259,8 @@ describe('OnlineDrive', () => { render() - expect(screen.getByTestId('header')).toBeInTheDocument() - expect(screen.getByTestId('file-list')).toBeInTheDocument() + expect(screen.getByTestId('header'))!.toBeInTheDocument() + expect(screen.getByTestId('file-list'))!.toBeInTheDocument() }) it('should render Header with correct props', () => { @@ -271,9 +271,9 @@ describe('OnlineDrive', () => { render() - expect(screen.getByTestId('header-doc-title')).toHaveTextContent('Docs') - expect(screen.getByTestId('header-plugin-name')).toHaveTextContent('My Online Drive') - expect(screen.getByTestId('header-credential-id')).toHaveTextContent('cred-123') + expect(screen.getByTestId('header-doc-title'))!.toHaveTextContent('Docs') + expect(screen.getByTestId('header-plugin-name'))!.toHaveTextContent('My Online Drive') + expect(screen.getByTestId('header-credential-id'))!.toHaveTextContent('cred-123') }) it('should render FileList with correct props', () => { @@ -290,11 +290,11 @@ describe('OnlineDrive', () => { render() - expect(screen.getByTestId('file-list')).toBeInTheDocument() - expect(screen.getByTestId('file-list-keywords')).toHaveTextContent('search-term') - expect(screen.getByTestId('file-list-breadcrumbs')).toHaveTextContent('folder1/folder2') - expect(screen.getByTestId('file-list-bucket')).toHaveTextContent('my-bucket') - expect(screen.getByTestId('file-list-selected-count')).toHaveTextContent('2') + expect(screen.getByTestId('file-list'))!.toBeInTheDocument() + expect(screen.getByTestId('file-list-keywords'))!.toHaveTextContent('search-term') + expect(screen.getByTestId('file-list-breadcrumbs'))!.toHaveTextContent('folder1/folder2') + expect(screen.getByTestId('file-list-bucket'))!.toHaveTextContent('my-bucket') + expect(screen.getByTestId('file-list-selected-count'))!.toHaveTextContent('2') }) it('should pass docLink with correct path to Header', () => { @@ -371,7 +371,7 @@ describe('OnlineDrive', () => { render() - expect(screen.getByTestId('header-plugin-name')).toHaveTextContent('Custom Online Drive') + expect(screen.getByTestId('header-plugin-name'))!.toHaveTextContent('Custom Online Drive') }) }) @@ -411,7 +411,7 @@ describe('OnlineDrive', () => { render() - expect(screen.getByTestId('file-list-is-in-pipeline')).toHaveTextContent('true') + expect(screen.getByTestId('file-list-is-in-pipeline'))!.toHaveTextContent('true') }) }) @@ -421,7 +421,7 @@ describe('OnlineDrive', () => { render() - expect(screen.getByTestId('file-list-support-batch')).toHaveTextContent('true') + expect(screen.getByTestId('file-list-support-batch'))!.toHaveTextContent('true') }) it('should pass supportBatchUpload false to FileList when supportBatchUpload is false', () => { @@ -429,7 +429,7 @@ describe('OnlineDrive', () => { render() - expect(screen.getByTestId('file-list-support-batch')).toHaveTextContent('false') + expect(screen.getByTestId('file-list-support-batch'))!.toHaveTextContent('false') }) it.each([ @@ -441,7 +441,7 @@ describe('OnlineDrive', () => { render() - expect(screen.getByTestId('file-list-support-batch')).toHaveTextContent(expected) + expect(screen.getByTestId('file-list-support-batch'))!.toHaveTextContent(expected) }) }) @@ -504,7 +504,7 @@ describe('OnlineDrive', () => { render() await waitFor(() => { - expect(screen.getByTestId('file-list-loading')).toHaveTextContent('true') + expect(screen.getByTestId('file-list-loading'))!.toHaveTextContent('true') }) }) @@ -566,7 +566,8 @@ describe('OnlineDrive', () => { render() // Assert - filteredOnlineDriveFileList should have 2 items matching 'test' - expect(screen.getByTestId('file-list-count')).toHaveTextContent('2') + // Assert - filteredOnlineDriveFileList should have 2 items matching 'test' + expect(screen.getByTestId('file-list-count'))!.toHaveTextContent('2') }) it('should return all files when keywords is empty', () => { @@ -580,7 +581,7 @@ describe('OnlineDrive', () => { render() - expect(screen.getByTestId('file-list-count')).toHaveTextContent('3') + expect(screen.getByTestId('file-list-count'))!.toHaveTextContent('3') }) it('should filter files case-insensitively', () => { @@ -594,7 +595,7 @@ describe('OnlineDrive', () => { render() - expect(screen.getByTestId('file-list-count')).toHaveTextContent('2') + expect(screen.getByTestId('file-list-count'))!.toHaveTextContent('2') }) }) @@ -932,7 +933,7 @@ describe('OnlineDrive', () => { render() - expect(screen.getByTestId('header-credentials-count')).toHaveTextContent('0') + expect(screen.getByTestId('header-credentials-count'))!.toHaveTextContent('0') }) it('should handle undefined credentials data', () => { @@ -943,7 +944,7 @@ describe('OnlineDrive', () => { render() - expect(screen.getByTestId('header-credentials-count')).toHaveTextContent('0') + expect(screen.getByTestId('header-credentials-count'))!.toHaveTextContent('0') }) it('should handle undefined pipelineId', async () => { @@ -969,7 +970,7 @@ describe('OnlineDrive', () => { render() - expect(screen.getByTestId('file-list-count')).toHaveTextContent('0') + expect(screen.getByTestId('file-list-count'))!.toHaveTextContent('0') }) it('should handle empty breadcrumbs', () => { @@ -978,7 +979,7 @@ describe('OnlineDrive', () => { render() - expect(screen.getByTestId('file-list-breadcrumbs')).toHaveTextContent('') + expect(screen.getByTestId('file-list-breadcrumbs'))!.toHaveTextContent('') }) it('should handle empty bucket', () => { @@ -987,7 +988,7 @@ describe('OnlineDrive', () => { render() - expect(screen.getByTestId('file-list-bucket')).toHaveTextContent('') + expect(screen.getByTestId('file-list-bucket'))!.toHaveTextContent('') }) it('should handle special characters in keywords', () => { @@ -1001,7 +1002,8 @@ describe('OnlineDrive', () => { render() // Assert - Should find file with special characters - expect(screen.getByTestId('file-list-count')).toHaveTextContent('1') + // Assert - Should find file with special characters + expect(screen.getByTestId('file-list-count'))!.toHaveTextContent('1') }) it('should handle very long file names', () => { @@ -1013,7 +1015,7 @@ describe('OnlineDrive', () => { render() - expect(screen.getByTestId('file-list-count')).toHaveTextContent('1') + expect(screen.getByTestId('file-list-count'))!.toHaveTextContent('1') }) it('should handle bucket list initiation response', async () => { @@ -1051,10 +1053,10 @@ describe('OnlineDrive', () => { render() - expect(screen.getByTestId('header')).toBeInTheDocument() - expect(screen.getByTestId('file-list')).toBeInTheDocument() - expect(screen.getByTestId('file-list-is-in-pipeline')).toHaveTextContent(String(propVariation.isInPipeline)) - expect(screen.getByTestId('file-list-support-batch')).toHaveTextContent(String(propVariation.supportBatchUpload)) + expect(screen.getByTestId('header'))!.toBeInTheDocument() + expect(screen.getByTestId('file-list'))!.toBeInTheDocument() + expect(screen.getByTestId('file-list-is-in-pipeline'))!.toHaveTextContent(String(propVariation.isInPipeline)) + expect(screen.getByTestId('file-list-support-batch'))!.toHaveTextContent(String(propVariation.supportBatchUpload)) }) it.each([ @@ -1117,7 +1119,7 @@ describe('Header', () => { render(
) - expect(screen.getByText('Documentation')).toBeInTheDocument() + expect(screen.getByText('Documentation'))!.toBeInTheDocument() }) it('should render doc link with correct href', () => { @@ -1129,9 +1131,9 @@ describe('Header', () => { render(
) const link = screen.getByRole('link') - expect(link).toHaveAttribute('href', 'https://custom-docs.com/path') - expect(link).toHaveAttribute('target', '_blank') - expect(link).toHaveAttribute('rel', 'noopener noreferrer') + expect(link)!.toHaveAttribute('href', 'https://custom-docs.com/path') + expect(link)!.toHaveAttribute('target', '_blank') + expect(link)!.toHaveAttribute('rel', 'noopener noreferrer') }) it('should render doc title text', () => { @@ -1139,7 +1141,7 @@ describe('Header', () => { render(
) - expect(screen.getByText('My Documentation Title')).toBeInTheDocument() + expect(screen.getByText('My Documentation Title'))!.toBeInTheDocument() }) it('should render configuration button', () => { @@ -1147,7 +1149,7 @@ describe('Header', () => { render(
) - expect(screen.getByRole('button')).toBeInTheDocument() + expect(screen.getByRole('button'))!.toBeInTheDocument() }) }) @@ -1164,7 +1166,7 @@ describe('Header', () => { render(
) if (docTitle) - expect(screen.getByText(docTitle)).toBeInTheDocument() + expect(screen.getByText(docTitle))!.toBeInTheDocument() }) }) @@ -1178,7 +1180,7 @@ describe('Header', () => { render(
) - expect(screen.getByRole('link')).toHaveAttribute('href', docLink) + expect(screen.getByRole('link'))!.toHaveAttribute('href', docLink) }) }) @@ -1209,7 +1211,7 @@ describe('Header', () => { render(
) const titleSpan = screen.getByTitle('Accessible Title') - expect(titleSpan).toBeInTheDocument() + expect(titleSpan)!.toBeInTheDocument() }) }) }) @@ -1437,10 +1439,10 @@ describe('utils', () => { const result = convertOnlineDriveData(data, [], 'my-bucket') expect(result.fileList).toHaveLength(4) - expect(result.fileList[0].type).toBe(OnlineDriveFileType.folder) - expect(result.fileList[1].type).toBe(OnlineDriveFileType.file) - expect(result.fileList[2].type).toBe(OnlineDriveFileType.folder) - expect(result.fileList[3].type).toBe(OnlineDriveFileType.file) + expect(result.fileList[0]!.type).toBe(OnlineDriveFileType.folder) + expect(result.fileList[1]!.type).toBe(OnlineDriveFileType.file) + expect(result.fileList[2]!.type).toBe(OnlineDriveFileType.folder) + expect(result.fileList[3]!.type).toBe(OnlineDriveFileType.file) }) }) @@ -1539,7 +1541,7 @@ describe('utils', () => { const result = convertOnlineDriveData(data, [], 'my-bucket') - expect(result.fileList[0].size).toBe(0) + expect(result.fileList[0]!.size).toBe(0) }) it('should handle files with very large size', () => { @@ -1555,7 +1557,7 @@ describe('utils', () => { const result = convertOnlineDriveData(data, [], 'my-bucket') - expect(result.fileList[0].size).toBe(largeSize) + expect(result.fileList[0]!.size).toBe(largeSize) }) it('should handle files with special characters in name', () => { @@ -1574,9 +1576,9 @@ describe('utils', () => { const result = convertOnlineDriveData(data, [], 'my-bucket') - expect(result.fileList[0].name).toBe('file[1] (copy).txt') - expect(result.fileList[1].name).toBe('doc-with-dash_and_underscore.pdf') - expect(result.fileList[2].name).toBe('file with spaces.txt') + expect(result.fileList[0]!.name).toBe('file[1] (copy).txt') + expect(result.fileList[1]!.name).toBe('doc-with-dash_and_underscore.pdf') + expect(result.fileList[2]!.name).toBe('file with spaces.txt') }) it('should handle complex next_page_parameters', () => { diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/__tests__/utils.spec.ts b/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/__tests__/utils.spec.ts index 7c5761be8a..9ac2ef9f89 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/__tests__/utils.spec.ts +++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/__tests__/utils.spec.ts @@ -93,10 +93,10 @@ describe('online-drive utils', () => { const result = convertOnlineDriveData(data, [], 'bucket-1') expect(result.fileList).toHaveLength(2) - expect(result.fileList[0].type).toBe(OnlineDriveFileType.file) - expect(result.fileList[0].size).toBe(100) - expect(result.fileList[1].type).toBe(OnlineDriveFileType.folder) - expect(result.fileList[1].size).toBeUndefined() + expect(result.fileList[0]!.type).toBe(OnlineDriveFileType.file) + expect(result.fileList[0]!.size).toBe(100) + expect(result.fileList[1]!.type).toBe(OnlineDriveFileType.folder) + expect(result.fileList[1]!.size).toBeUndefined() expect(result.isTruncated).toBe(true) expect(result.nextPageParameters).toEqual({ token: 'next' }) expect(result.hasBucket).toBe(true) diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/connect/index.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/connect/index.tsx index 5b1b0a6b1a..6a7190161d 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/connect/index.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/connect/index.tsx @@ -1,7 +1,7 @@ import type { DataSourceNodeType } from '@/app/components/workflow/nodes/data-source/types' +import { Button } from '@langgenius/dify-ui/button' import { useTranslation } from 'react-i18next' import { Icon3Dots } from '@/app/components/base/icons/src/vender/line/others' -import { Button } from '@/app/components/base/ui/button' import BlockIcon from '@/app/components/workflow/block-icon' import { useToolIcon } from '@/app/components/workflow/hooks' import { BlockEnum } from '@/app/components/workflow/types' diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/file-list/header/__tests__/index.spec.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/file-list/header/__tests__/index.spec.tsx index 07308361ad..dcb1922fe9 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/file-list/header/__tests__/index.spec.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/file-list/header/__tests__/index.spec.tsx @@ -60,7 +60,8 @@ describe('Header', () => { render(
) // Assert - search input should be visible - expect(screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder')).toBeInTheDocument() + // Assert - search input should be visible + expect(screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder'))!.toBeInTheDocument() }) it('should render with correct container styles', () => { @@ -70,12 +71,12 @@ describe('Header', () => { // Assert - container should have correct class names const wrapper = container.firstChild as HTMLElement - expect(wrapper).toHaveClass('flex') - expect(wrapper).toHaveClass('items-center') - expect(wrapper).toHaveClass('gap-x-2') - expect(wrapper).toHaveClass('bg-components-panel-bg') - expect(wrapper).toHaveClass('p-1') - expect(wrapper).toHaveClass('pl-3') + expect(wrapper)!.toHaveClass('flex') + expect(wrapper)!.toHaveClass('items-center') + expect(wrapper)!.toHaveClass('gap-x-2') + expect(wrapper)!.toHaveClass('bg-components-panel-bg') + expect(wrapper)!.toHaveClass('p-1') + expect(wrapper)!.toHaveClass('pl-3') }) it('should render Input component with correct props', () => { @@ -84,8 +85,8 @@ describe('Header', () => { render(
) const input = screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder') - expect(input).toBeInTheDocument() - expect(input).toHaveValue('test-value') + expect(input)!.toBeInTheDocument() + expect(input)!.toHaveValue('test-value') }) it('should render Input with search icon', () => { @@ -95,7 +96,7 @@ describe('Header', () => { // Assert - Input should have search icon class const searchIcon = container.querySelector('.i-ri-search-line.h-4.w-4') - expect(searchIcon).toBeInTheDocument() + expect(searchIcon)!.toBeInTheDocument() }) it('should render Input with correct wrapper width', () => { @@ -105,7 +106,7 @@ describe('Header', () => { // Assert - Input wrapper should have w-[200px] class const inputWrapper = container.querySelector('.w-\\[200px\\]') - expect(inputWrapper).toBeInTheDocument() + expect(inputWrapper)!.toBeInTheDocument() }) }) @@ -117,7 +118,7 @@ describe('Header', () => { render(
) const input = screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder') - expect(input).toHaveValue('') + expect(input)!.toHaveValue('') }) it('should display input value correctly', () => { @@ -126,7 +127,7 @@ describe('Header', () => { render(
) const input = screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder') - expect(input).toHaveValue('search-query') + expect(input)!.toHaveValue('search-query') }) it('should handle special characters in inputValue', () => { @@ -136,7 +137,7 @@ describe('Header', () => { render(
) const input = screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder') - expect(input).toHaveValue(specialChars) + expect(input)!.toHaveValue(specialChars) }) it('should handle unicode characters in inputValue', () => { @@ -146,7 +147,7 @@ describe('Header', () => { render(
) const input = screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder') - expect(input).toHaveValue(unicodeValue) + expect(input)!.toHaveValue(unicodeValue) }) }) @@ -157,7 +158,8 @@ describe('Header', () => { render(
) // Assert - Component should render without errors - expect(screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder')).toBeInTheDocument() + // Assert - Component should render without errors + expect(screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder'))!.toBeInTheDocument() }) it('should render with single breadcrumb', () => { @@ -165,7 +167,7 @@ describe('Header', () => { render(
) - expect(screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder')).toBeInTheDocument() + expect(screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder'))!.toBeInTheDocument() }) it('should render with multiple breadcrumbs', () => { @@ -173,7 +175,7 @@ describe('Header', () => { render(
) - expect(screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder')).toBeInTheDocument() + expect(screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder'))!.toBeInTheDocument() }) }) @@ -184,7 +186,8 @@ describe('Header', () => { render(
) // Assert - keywords are passed through, component renders - expect(screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder')).toBeInTheDocument() + // Assert - keywords are passed through, component renders + expect(screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder'))!.toBeInTheDocument() }) }) @@ -194,7 +197,7 @@ describe('Header', () => { render(
) - expect(screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder')).toBeInTheDocument() + expect(screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder'))!.toBeInTheDocument() }) it('should render with bucket value', () => { @@ -202,7 +205,7 @@ describe('Header', () => { render(
) - expect(screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder')).toBeInTheDocument() + expect(screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder'))!.toBeInTheDocument() }) }) @@ -212,7 +215,7 @@ describe('Header', () => { render(
) - expect(screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder')).toBeInTheDocument() + expect(screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder'))!.toBeInTheDocument() }) it('should handle positive search results', () => { @@ -221,7 +224,8 @@ describe('Header', () => { render(
) // Assert - Breadcrumbs will show search results text when keywords exist and results > 0 - expect(screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder')).toBeInTheDocument() + // Assert - Breadcrumbs will show search results text when keywords exist and results > 0 + expect(screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder'))!.toBeInTheDocument() }) it('should handle large search results count', () => { @@ -229,7 +233,7 @@ describe('Header', () => { render(
) - expect(screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder')).toBeInTheDocument() + expect(screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder'))!.toBeInTheDocument() }) }) @@ -239,7 +243,7 @@ describe('Header', () => { render(
) - expect(screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder')).toBeInTheDocument() + expect(screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder'))!.toBeInTheDocument() }) it('should render correctly when isInPipeline is true', () => { @@ -247,7 +251,7 @@ describe('Header', () => { render(
) - expect(screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder')).toBeInTheDocument() + expect(screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder'))!.toBeInTheDocument() }) }) }) @@ -265,7 +269,7 @@ describe('Header', () => { expect(mockHandleInputChange).toHaveBeenCalledTimes(1) // Verify that onChange event was triggered (React's synthetic event structure) - expect(mockHandleInputChange.mock.calls[0][0]).toHaveProperty('type', 'change') + expect(mockHandleInputChange.mock.calls[0]![0]).toHaveProperty('type', 'change') }) it('should call handleInputChange on each keystroke', () => { @@ -290,7 +294,7 @@ describe('Header', () => { fireEvent.change(input, { target: { value: '' } }) expect(mockHandleInputChange).toHaveBeenCalledTimes(1) - expect(mockHandleInputChange.mock.calls[0][0]).toHaveProperty('type', 'change') + expect(mockHandleInputChange.mock.calls[0]![0]).toHaveProperty('type', 'change') }) it('should handle whitespace-only input', () => { @@ -302,7 +306,7 @@ describe('Header', () => { fireEvent.change(input, { target: { value: ' ' } }) expect(mockHandleInputChange).toHaveBeenCalledTimes(1) - expect(mockHandleInputChange.mock.calls[0][0]).toHaveProperty('type', 'change') + expect(mockHandleInputChange.mock.calls[0]![0]).toHaveProperty('type', 'change') }) }) @@ -317,7 +321,7 @@ describe('Header', () => { // Act - Find and click the clear icon container const clearButton = screen.getByTestId('input-clear') - expect(clearButton).toBeInTheDocument() + expect(clearButton)!.toBeInTheDocument() fireEvent.click(clearButton!) expect(mockHandleResetKeywords).toHaveBeenCalledTimes(1) @@ -338,7 +342,7 @@ describe('Header', () => { // Act & Assert - Clear icon should be visible const clearIcon = screen.getByTestId('input-clear') - expect(clearIcon).toBeInTheDocument() + expect(clearIcon)!.toBeInTheDocument() }) }) }) @@ -365,21 +369,23 @@ describe('Header', () => { rerender(
) // Assert - Component renders without errors - expect(screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder')).toBeInTheDocument() + // Assert - Component renders without errors + expect(screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder'))!.toBeInTheDocument() }) it('should re-render when inputValue changes', () => { const props = createDefaultProps({ inputValue: 'initial' }) const { rerender } = render(
) const input = screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder') - expect(input).toHaveValue('initial') + expect(input)!.toHaveValue('initial') // Act - Rerender with different inputValue const newProps = createDefaultProps({ inputValue: 'changed' }) rerender(
) // Assert - Input value should be updated - expect(input).toHaveValue('changed') + // Assert - Input value should be updated + expect(input)!.toHaveValue('changed') }) it('should re-render when breadcrumbs change', () => { @@ -391,7 +397,8 @@ describe('Header', () => { rerender(
) // Assert - Component renders without errors - expect(screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder')).toBeInTheDocument() + // Assert - Component renders without errors + expect(screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder'))!.toBeInTheDocument() }) it('should re-render when keywords change', () => { @@ -403,7 +410,8 @@ describe('Header', () => { rerender(
) // Assert - Component renders without errors - expect(screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder')).toBeInTheDocument() + // Assert - Component renders without errors + expect(screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder'))!.toBeInTheDocument() }) }) @@ -415,7 +423,7 @@ describe('Header', () => { render(
) const input = screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder') - expect(input).toHaveValue(longValue) + expect(input)!.toHaveValue(longValue) }) it('should handle very long breadcrumb paths', () => { @@ -424,7 +432,7 @@ describe('Header', () => { render(
) - expect(screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder')).toBeInTheDocument() + expect(screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder'))!.toBeInTheDocument() }) it('should handle breadcrumbs with special characters', () => { @@ -433,7 +441,7 @@ describe('Header', () => { render(
) - expect(screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder')).toBeInTheDocument() + expect(screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder'))!.toBeInTheDocument() }) it('should handle breadcrumbs with unicode names', () => { @@ -442,7 +450,7 @@ describe('Header', () => { render(
) - expect(screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder')).toBeInTheDocument() + expect(screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder'))!.toBeInTheDocument() }) it('should handle bucket with special characters', () => { @@ -450,7 +458,7 @@ describe('Header', () => { render(
) - expect(screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder')).toBeInTheDocument() + expect(screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder'))!.toBeInTheDocument() }) it('should pass the event object to handleInputChange callback', () => { @@ -463,7 +471,7 @@ describe('Header', () => { // Assert - Verify the event object is passed correctly expect(mockHandleInputChange).toHaveBeenCalledTimes(1) - const eventArg = mockHandleInputChange.mock.calls[0][0] + const eventArg = mockHandleInputChange.mock.calls[0]![0] expect(eventArg).toHaveProperty('type', 'change') expect(eventArg).toHaveProperty('target') }) @@ -480,7 +488,7 @@ describe('Header', () => { render(
) - expect(screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder')).toBeInTheDocument() + expect(screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder'))!.toBeInTheDocument() }) it.each([ @@ -493,7 +501,7 @@ describe('Header', () => { render(
) - expect(screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder')).toBeInTheDocument() + expect(screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder'))!.toBeInTheDocument() }) it.each([ @@ -507,7 +515,7 @@ describe('Header', () => { render(
) const input = screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder') - expect(input).toHaveValue(inputValue) + expect(input)!.toHaveValue(inputValue) }) }) @@ -525,7 +533,8 @@ describe('Header', () => { render(
) // Assert - Component should render successfully, meaning props are passed correctly - expect(screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder')).toBeInTheDocument() + // Assert - Component should render successfully, meaning props are passed correctly + expect(screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder'))!.toBeInTheDocument() }) it('should pass correct props to Input component', () => { @@ -540,7 +549,7 @@ describe('Header', () => { render(
) const input = screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder') - expect(input).toHaveValue('test-input') + expect(input)!.toHaveValue('test-input') // Test onChange handler fireEvent.change(input, { target: { value: 'new-value' } }) diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/file-list/header/breadcrumbs/__tests__/bucket.spec.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/file-list/header/breadcrumbs/__tests__/bucket.spec.tsx index c407be51ac..83e17e6e04 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/file-list/header/breadcrumbs/__tests__/bucket.spec.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/file-list/header/breadcrumbs/__tests__/bucket.spec.tsx @@ -22,18 +22,18 @@ describe('Bucket', () => { it('should render bucket name', () => { render() - expect(screen.getByText('my-bucket')).toBeInTheDocument() + expect(screen.getByText('my-bucket'))!.toBeInTheDocument() }) it('should render bucket icon', () => { render() - expect(screen.getByTestId('buckets-gray')).toBeInTheDocument() + expect(screen.getByTestId('buckets-gray'))!.toBeInTheDocument() }) it('should call handleBackToBucketList on icon button click', () => { render() const buttons = screen.getAllByRole('button') - fireEvent.click(buttons[0]) + fireEvent.click(buttons[0]!) expect(defaultProps.handleBackToBucketList).toHaveBeenCalledOnce() }) diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/file-list/header/breadcrumbs/__tests__/index.spec.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/file-list/header/breadcrumbs/__tests__/index.spec.tsx index a6aaf3a50b..906a9e01e0 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/file-list/header/breadcrumbs/__tests__/index.spec.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/file-list/header/breadcrumbs/__tests__/index.spec.tsx @@ -1,4 +1,4 @@ -import { fireEvent, render, screen, waitFor } from '@testing-library/react' +import { fireEvent, render, screen } from '@testing-library/react' import * as React from 'react' import Breadcrumbs from '../index' @@ -44,6 +44,16 @@ const resetMockStoreState = () => { mockStoreState.setBucket = vi.fn() } +const getDropdownTrigger = () => { + return document.querySelector('[aria-haspopup="menu"]') as HTMLElement | null +} + +const openCollapsedBreadcrumbDropdown = () => { + const dropdownTrigger = getDropdownTrigger() + expect(dropdownTrigger).toBeInTheDocument() + fireEvent.click(dropdownTrigger as HTMLElement) +} + describe('Breadcrumbs', () => { beforeEach(() => { vi.clearAllMocks() @@ -58,7 +68,7 @@ describe('Breadcrumbs', () => { // Assert - Container should be in the document const container = document.querySelector('.flex.grow') - expect(container).toBeInTheDocument() + expect(container)!.toBeInTheDocument() }) it('should render with correct container styles', () => { @@ -67,10 +77,10 @@ describe('Breadcrumbs', () => { const { container } = render() const wrapper = container.firstChild as HTMLElement - expect(wrapper).toHaveClass('flex') - expect(wrapper).toHaveClass('grow') - expect(wrapper).toHaveClass('items-center') - expect(wrapper).toHaveClass('overflow-hidden') + expect(wrapper)!.toHaveClass('flex') + expect(wrapper)!.toHaveClass('grow') + expect(wrapper)!.toHaveClass('items-center') + expect(wrapper)!.toHaveClass('overflow-hidden') }) describe('Search Results Display', () => { @@ -84,7 +94,8 @@ describe('Breadcrumbs', () => { render() // Assert - Search result text should be displayed - expect(screen.getByText(/datasetPipeline\.onlineDrive\.breadcrumbs\.searchResult/)).toBeInTheDocument() + // Assert - Search result text should be displayed + expect(screen.getByText(/datasetPipeline\.onlineDrive\.breadcrumbs\.searchResult/))!.toBeInTheDocument() }) it('should not show search results when keywords is empty', () => { @@ -121,7 +132,8 @@ describe('Breadcrumbs', () => { render() // Assert - Should use bucket name in search result - expect(screen.getByText(/searchResult.*my-bucket/i)).toBeInTheDocument() + // Assert - Should use bucket name in search result + expect(screen.getByText(/searchResult.*my-bucket/i))!.toBeInTheDocument() }) it('should use last breadcrumb as folderName when breadcrumbs exist', () => { @@ -135,7 +147,8 @@ describe('Breadcrumbs', () => { render() // Assert - Should use last breadcrumb in search result - expect(screen.getByText(/searchResult.*folder2/i)).toBeInTheDocument() + // Assert - Should use last breadcrumb in search result + expect(screen.getByText(/searchResult.*folder2/i))!.toBeInTheDocument() }) }) @@ -150,7 +163,7 @@ describe('Breadcrumbs', () => { render() - expect(screen.getByText('datasetPipeline.onlineDrive.breadcrumbs.allBuckets')).toBeInTheDocument() + expect(screen.getByText('datasetPipeline.onlineDrive.breadcrumbs.allBuckets'))!.toBeInTheDocument() }) it('should not show all buckets title when breadcrumbs exist', () => { @@ -174,6 +187,37 @@ describe('Breadcrumbs', () => { render() + // Assert - Should show bucket name instead + // Assert - Should show bucket name instead + // Assert - Should show bucket name instead + // Assert - Should show bucket name instead + // Assert - Should show bucket name instead + // Assert - Should show bucket name instead + // Assert - Should show bucket name instead + // Assert - Should show bucket name instead + // Assert - Should show bucket name instead + // Assert - Should show bucket name instead + // Assert - Should show bucket name instead + // Assert - Should show bucket name instead + // Assert - Should show bucket name instead + // Assert - Should show bucket name instead + // Assert - Should show bucket name instead + // Assert - Should show bucket name instead + // Assert - Should show bucket name instead + // Assert - Should show bucket name instead + // Assert - Should show bucket name instead + // Assert - Should show bucket name instead + // Assert - Should show bucket name instead + // Assert - Should show bucket name instead + // Assert - Should show bucket name instead + // Assert - Should show bucket name instead + // Assert - Should show bucket name instead + // Assert - Should show bucket name instead + // Assert - Should show bucket name instead + // Assert - Should show bucket name instead + // Assert - Should show bucket name instead + // Assert - Should show bucket name instead + // Assert - Should show bucket name instead // Assert - Should show bucket name instead expect(screen.queryByText('datasetPipeline.onlineDrive.breadcrumbs.allBuckets')).not.toBeInTheDocument() }) @@ -190,7 +234,8 @@ describe('Breadcrumbs', () => { render() // Assert - Bucket name should be displayed - expect(screen.getByText('test-bucket')).toBeInTheDocument() + // Assert - Bucket name should be displayed + expect(screen.getByText('test-bucket'))!.toBeInTheDocument() }) it('should not render Bucket when hasBucket is false', () => { @@ -202,6 +247,37 @@ describe('Breadcrumbs', () => { render() + // Assert - Bucket should not be displayed, Drive should be shown instead + // Assert - Bucket should not be displayed, Drive should be shown instead + // Assert - Bucket should not be displayed, Drive should be shown instead + // Assert - Bucket should not be displayed, Drive should be shown instead + // Assert - Bucket should not be displayed, Drive should be shown instead + // Assert - Bucket should not be displayed, Drive should be shown instead + // Assert - Bucket should not be displayed, Drive should be shown instead + // Assert - Bucket should not be displayed, Drive should be shown instead + // Assert - Bucket should not be displayed, Drive should be shown instead + // Assert - Bucket should not be displayed, Drive should be shown instead + // Assert - Bucket should not be displayed, Drive should be shown instead + // Assert - Bucket should not be displayed, Drive should be shown instead + // Assert - Bucket should not be displayed, Drive should be shown instead + // Assert - Bucket should not be displayed, Drive should be shown instead + // Assert - Bucket should not be displayed, Drive should be shown instead + // Assert - Bucket should not be displayed, Drive should be shown instead + // Assert - Bucket should not be displayed, Drive should be shown instead + // Assert - Bucket should not be displayed, Drive should be shown instead + // Assert - Bucket should not be displayed, Drive should be shown instead + // Assert - Bucket should not be displayed, Drive should be shown instead + // Assert - Bucket should not be displayed, Drive should be shown instead + // Assert - Bucket should not be displayed, Drive should be shown instead + // Assert - Bucket should not be displayed, Drive should be shown instead + // Assert - Bucket should not be displayed, Drive should be shown instead + // Assert - Bucket should not be displayed, Drive should be shown instead + // Assert - Bucket should not be displayed, Drive should be shown instead + // Assert - Bucket should not be displayed, Drive should be shown instead + // Assert - Bucket should not be displayed, Drive should be shown instead + // Assert - Bucket should not be displayed, Drive should be shown instead + // Assert - Bucket should not be displayed, Drive should be shown instead + // Assert - Bucket should not be displayed, Drive should be shown instead // Assert - Bucket should not be displayed, Drive should be shown instead expect(screen.queryByText('test-bucket')).not.toBeInTheDocument() }) @@ -217,7 +293,8 @@ describe('Breadcrumbs', () => { render() // Assert - "All Files" should be displayed - expect(screen.getByText('datasetPipeline.onlineDrive.breadcrumbs.allFiles')).toBeInTheDocument() + // Assert - "All Files" should be displayed + expect(screen.getByText('datasetPipeline.onlineDrive.breadcrumbs.allFiles'))!.toBeInTheDocument() }) it('should not render Drive component when hasBucket is true', () => { @@ -243,8 +320,8 @@ describe('Breadcrumbs', () => { render() - expect(screen.getByText('folder1')).toBeInTheDocument() - expect(screen.getByText('folder2')).toBeInTheDocument() + expect(screen.getByText('folder1'))!.toBeInTheDocument() + expect(screen.getByText('folder2'))!.toBeInTheDocument() }) it('should render last breadcrumb as active', () => { @@ -257,8 +334,8 @@ describe('Breadcrumbs', () => { // Assert - Last breadcrumb should have active styles const lastBreadcrumb = screen.getByText('folder2') - expect(lastBreadcrumb).toHaveClass('system-sm-medium') - expect(lastBreadcrumb).toHaveClass('text-text-secondary') + expect(lastBreadcrumb)!.toHaveClass('system-sm-medium') + expect(lastBreadcrumb)!.toHaveClass('text-text-secondary') }) it('should render non-last breadcrumbs with tertiary styles', () => { @@ -271,8 +348,8 @@ describe('Breadcrumbs', () => { // Assert - First breadcrumb should have tertiary styles const firstBreadcrumb = screen.getByText('folder1') - expect(firstBreadcrumb).toHaveClass('system-sm-regular') - expect(firstBreadcrumb).toHaveClass('text-text-tertiary') + expect(firstBreadcrumb)!.toHaveClass('system-sm-regular') + expect(firstBreadcrumb)!.toHaveClass('text-text-tertiary') }) }) @@ -287,7 +364,8 @@ describe('Breadcrumbs', () => { render() // Assert - Dropdown trigger (more button) should be present - expect(screen.getByRole('button', { name: '' })).toBeInTheDocument() + // Assert - Dropdown trigger (more button) should be present + expect(screen.getByRole('button', { name: '' }))!.toBeInTheDocument() }) it('should not show dropdown when breadcrumbs do not exceed displayBreadcrumbNum', () => { @@ -301,8 +379,10 @@ describe('Breadcrumbs', () => { // Assert - Should not have dropdown, just regular breadcrumbs // All breadcrumbs should be directly visible - expect(screen.getByText('folder1')).toBeInTheDocument() - expect(screen.getByText('folder2')).toBeInTheDocument() + // Assert - Should not have dropdown, just regular breadcrumbs + // All breadcrumbs should be directly visible + expect(screen.getByText('folder1'))!.toBeInTheDocument() + expect(screen.getByText('folder2'))!.toBeInTheDocument() // Count buttons - should be 3 (allFiles + folder1 + folder2) const buttons = container.querySelectorAll('button') expect(buttons.length).toBe(3) @@ -318,9 +398,41 @@ describe('Breadcrumbs', () => { render() // Assert - First breadcrumb and last breadcrumb should be visible - expect(screen.getByText('folder1')).toBeInTheDocument() - expect(screen.getByText('folder2')).toBeInTheDocument() - expect(screen.getByText('folder5')).toBeInTheDocument() + // Assert - First breadcrumb and last breadcrumb should be visible + expect(screen.getByText('folder1'))!.toBeInTheDocument() + expect(screen.getByText('folder2'))!.toBeInTheDocument() + expect(screen.getByText('folder5'))!.toBeInTheDocument() + // Middle breadcrumbs should be in dropdown + // Middle breadcrumbs should be in dropdown + // Middle breadcrumbs should be in dropdown + // Middle breadcrumbs should be in dropdown + // Middle breadcrumbs should be in dropdown + // Middle breadcrumbs should be in dropdown + // Middle breadcrumbs should be in dropdown + // Middle breadcrumbs should be in dropdown + // Middle breadcrumbs should be in dropdown + // Middle breadcrumbs should be in dropdown + // Middle breadcrumbs should be in dropdown + // Middle breadcrumbs should be in dropdown + // Middle breadcrumbs should be in dropdown + // Middle breadcrumbs should be in dropdown + // Middle breadcrumbs should be in dropdown + // Middle breadcrumbs should be in dropdown + // Middle breadcrumbs should be in dropdown + // Middle breadcrumbs should be in dropdown + // Middle breadcrumbs should be in dropdown + // Middle breadcrumbs should be in dropdown + // Middle breadcrumbs should be in dropdown + // Middle breadcrumbs should be in dropdown + // Middle breadcrumbs should be in dropdown + // Middle breadcrumbs should be in dropdown + // Middle breadcrumbs should be in dropdown + // Middle breadcrumbs should be in dropdown + // Middle breadcrumbs should be in dropdown + // Middle breadcrumbs should be in dropdown + // Middle breadcrumbs should be in dropdown + // Middle breadcrumbs should be in dropdown + // Middle breadcrumbs should be in dropdown // Middle breadcrumbs should be in dropdown expect(screen.queryByText('folder3')).not.toBeInTheDocument() expect(screen.queryByText('folder4')).not.toBeInTheDocument() @@ -335,15 +447,11 @@ describe('Breadcrumbs', () => { render() // Act - Click on dropdown trigger (the ... button) - const dropdownTrigger = screen.getAllByRole('button').find(btn => btn.querySelector('svg')) - if (dropdownTrigger) - fireEvent.click(dropdownTrigger) + openCollapsedBreadcrumbDropdown() // Assert - Collapsed breadcrumbs should be visible - await waitFor(() => { - expect(screen.getByText('folder3')).toBeInTheDocument() - expect(screen.getByText('folder4')).toBeInTheDocument() - }) + expect(await screen.findByText('folder3')).toBeInTheDocument() + expect(await screen.findByText('folder4')).toBeInTheDocument() }) }) }) @@ -357,7 +465,8 @@ describe('Breadcrumbs', () => { render() // Assert - Only Drive should be visible - expect(screen.getByText('datasetPipeline.onlineDrive.breadcrumbs.allFiles')).toBeInTheDocument() + // Assert - Only Drive should be visible + expect(screen.getByText('datasetPipeline.onlineDrive.breadcrumbs.allFiles'))!.toBeInTheDocument() }) it('should handle single breadcrumb', () => { @@ -366,7 +475,7 @@ describe('Breadcrumbs', () => { render() - expect(screen.getByText('single-folder')).toBeInTheDocument() + expect(screen.getByText('single-folder'))!.toBeInTheDocument() }) it('should handle breadcrumbs with special characters', () => { @@ -377,8 +486,8 @@ describe('Breadcrumbs', () => { render() - expect(screen.getByText('folder [1]')).toBeInTheDocument() - expect(screen.getByText('folder (copy)')).toBeInTheDocument() + expect(screen.getByText('folder [1]'))!.toBeInTheDocument() + expect(screen.getByText('folder (copy)'))!.toBeInTheDocument() }) it('should handle breadcrumbs with unicode characters', () => { @@ -389,8 +498,8 @@ describe('Breadcrumbs', () => { render() - expect(screen.getByText('文件夹')).toBeInTheDocument() - expect(screen.getByText('フォルダ')).toBeInTheDocument() + expect(screen.getByText('文件夹'))!.toBeInTheDocument() + expect(screen.getByText('フォルダ'))!.toBeInTheDocument() }) }) @@ -403,7 +512,7 @@ describe('Breadcrumbs', () => { render() - expect(screen.getByText(/searchResult/)).toBeInTheDocument() + expect(screen.getByText(/searchResult/))!.toBeInTheDocument() }) it('should handle whitespace keywords', () => { @@ -415,7 +524,8 @@ describe('Breadcrumbs', () => { render() // Assert - Whitespace is truthy, so should show search results - expect(screen.getByText(/searchResult/)).toBeInTheDocument() + // Assert - Whitespace is truthy, so should show search results + expect(screen.getByText(/searchResult/))!.toBeInTheDocument() }) }) @@ -428,7 +538,7 @@ describe('Breadcrumbs', () => { render() - expect(screen.getByText('production-bucket')).toBeInTheDocument() + expect(screen.getByText('production-bucket'))!.toBeInTheDocument() }) it('should handle bucket with special characters', () => { @@ -439,7 +549,7 @@ describe('Breadcrumbs', () => { render() - expect(screen.getByText('bucket-v2.0_backup')).toBeInTheDocument() + expect(screen.getByText('bucket-v2.0_backup'))!.toBeInTheDocument() }) }) @@ -452,6 +562,37 @@ describe('Breadcrumbs', () => { render() + // Assert - Should not show search results + // Assert - Should not show search results + // Assert - Should not show search results + // Assert - Should not show search results + // Assert - Should not show search results + // Assert - Should not show search results + // Assert - Should not show search results + // Assert - Should not show search results + // Assert - Should not show search results + // Assert - Should not show search results + // Assert - Should not show search results + // Assert - Should not show search results + // Assert - Should not show search results + // Assert - Should not show search results + // Assert - Should not show search results + // Assert - Should not show search results + // Assert - Should not show search results + // Assert - Should not show search results + // Assert - Should not show search results + // Assert - Should not show search results + // Assert - Should not show search results + // Assert - Should not show search results + // Assert - Should not show search results + // Assert - Should not show search results + // Assert - Should not show search results + // Assert - Should not show search results + // Assert - Should not show search results + // Assert - Should not show search results + // Assert - Should not show search results + // Assert - Should not show search results + // Assert - Should not show search results // Assert - Should not show search results expect(screen.queryByText(/searchResult/)).not.toBeInTheDocument() }) @@ -464,7 +605,7 @@ describe('Breadcrumbs', () => { render() - expect(screen.getByText(/searchResult.*10000/)).toBeInTheDocument() + expect(screen.getByText(/searchResult.*10000/))!.toBeInTheDocument() }) }) @@ -480,9 +621,7 @@ describe('Breadcrumbs', () => { // Assert - Should collapse because 3 > 2 // Dropdown should be present - const buttons = screen.getAllByRole('button') - const hasDropdownTrigger = buttons.some(btn => btn.querySelector('svg')) - expect(hasDropdownTrigger).toBe(true) + expect(getDropdownTrigger()).toBeInTheDocument() }) it('should use displayBreadcrumbNum=3 when isInPipeline is false', () => { @@ -495,9 +634,10 @@ describe('Breadcrumbs', () => { render() // Assert - Should NOT collapse because 3 <= 3 - expect(screen.getByText('folder1')).toBeInTheDocument() - expect(screen.getByText('folder2')).toBeInTheDocument() - expect(screen.getByText('folder3')).toBeInTheDocument() + // Assert - Should NOT collapse because 3 <= 3 + expect(screen.getByText('folder1'))!.toBeInTheDocument() + expect(screen.getByText('folder2'))!.toBeInTheDocument() + expect(screen.getByText('folder3'))!.toBeInTheDocument() }) it('should reduce displayBreadcrumbNum by 1 when bucket is set', () => { @@ -511,9 +651,7 @@ describe('Breadcrumbs', () => { render() // Assert - Should collapse because 3 > 2 - const buttons = screen.getAllByRole('button') - const hasDropdownTrigger = buttons.some(btn => btn.querySelector('svg')) - expect(hasDropdownTrigger).toBe(true) + expect(getDropdownTrigger()).toBeInTheDocument() }) }) }) @@ -533,9 +671,11 @@ describe('Breadcrumbs', () => { // Assert - displayBreadcrumbNum = 3, so 4 breadcrumbs should collapse // First 2 visible, dropdown, last 1 visible - expect(screen.getByText('a')).toBeInTheDocument() - expect(screen.getByText('b')).toBeInTheDocument() - expect(screen.getByText('d')).toBeInTheDocument() + // Assert - displayBreadcrumbNum = 3, so 4 breadcrumbs should collapse + // First 2 visible, dropdown, last 1 visible + expect(screen.getByText('a'))!.toBeInTheDocument() + expect(screen.getByText('b'))!.toBeInTheDocument() + expect(screen.getByText('d'))!.toBeInTheDocument() expect(screen.queryByText('c')).not.toBeInTheDocument() }) @@ -550,8 +690,9 @@ describe('Breadcrumbs', () => { render() // Assert - displayBreadcrumbNum = 2, so 3 breadcrumbs should collapse - expect(screen.getByText('a')).toBeInTheDocument() - expect(screen.getByText('c')).toBeInTheDocument() + // Assert - displayBreadcrumbNum = 2, so 3 breadcrumbs should collapse + expect(screen.getByText('a'))!.toBeInTheDocument() + expect(screen.getByText('c'))!.toBeInTheDocument() expect(screen.queryByText('b')).not.toBeInTheDocument() }) @@ -566,8 +707,9 @@ describe('Breadcrumbs', () => { render() // Assert - displayBreadcrumbNum = 3 - 1 = 2, so 3 breadcrumbs should collapse - expect(screen.getByText('a')).toBeInTheDocument() - expect(screen.getByText('c')).toBeInTheDocument() + // Assert - displayBreadcrumbNum = 3 - 1 = 2, so 3 breadcrumbs should collapse + expect(screen.getByText('a'))!.toBeInTheDocument() + expect(screen.getByText('c'))!.toBeInTheDocument() expect(screen.queryByText('b')).not.toBeInTheDocument() }) }) @@ -582,9 +724,7 @@ describe('Breadcrumbs', () => { render() // Act - Click dropdown to see collapsed items - const dropdownTrigger = screen.getAllByRole('button').find(btn => btn.querySelector('svg')) - if (dropdownTrigger) - fireEvent.click(dropdownTrigger) + openCollapsedBreadcrumbDropdown() // prefixBreadcrumbs = ['f1', 'f2'] // collapsedBreadcrumbs = ['f3', 'f4'] @@ -592,10 +732,8 @@ describe('Breadcrumbs', () => { expect(screen.getByText('f1')).toBeInTheDocument() expect(screen.getByText('f2')).toBeInTheDocument() expect(screen.getByText('f5')).toBeInTheDocument() - await waitFor(() => { - expect(screen.getByText('f3')).toBeInTheDocument() - expect(screen.getByText('f4')).toBeInTheDocument() - }) + expect(await screen.findByText('f3')).toBeInTheDocument() + expect(await screen.findByText('f4')).toBeInTheDocument() }) it('should not collapse when breadcrumbs.length <= displayBreadcrumbNum', () => { @@ -608,8 +746,9 @@ describe('Breadcrumbs', () => { render() // Assert - All breadcrumbs should be visible - expect(screen.getByText('f1')).toBeInTheDocument() - expect(screen.getByText('f2')).toBeInTheDocument() + // Assert - All breadcrumbs should be visible + expect(screen.getByText('f1'))!.toBeInTheDocument() + expect(screen.getByText('f2'))!.toBeInTheDocument() }) }) }) @@ -627,7 +766,7 @@ describe('Breadcrumbs', () => { // Act - Click bucket icon button (first button in Bucket component) const buttons = screen.getAllByRole('button') - fireEvent.click(buttons[0]) // Bucket icon button + fireEvent.click(buttons[0]!) // Bucket icon button expect(mockStoreState.setOnlineDriveFileList).toHaveBeenCalledWith([]) expect(mockStoreState.setSelectedFileIds).toHaveBeenCalledWith([]) @@ -739,15 +878,8 @@ describe('Breadcrumbs', () => { render() // Act - Open dropdown and click on collapsed breadcrumb (f3, index=2) - const dropdownTrigger = screen.getAllByRole('button').find(btn => btn.querySelector('svg')) - if (dropdownTrigger) - fireEvent.click(dropdownTrigger) - - await waitFor(() => { - expect(screen.getByText('f3')).toBeInTheDocument() - }) - - fireEvent.click(screen.getByText('f3')) + openCollapsedBreadcrumbDropdown() + fireEvent.click(await screen.findByText('f3')) // Assert - Should slice to index 2 + 1 = 3 expect(mockStoreState.setBreadcrumbs).toHaveBeenCalledWith(['f1', 'f2', 'f3']) @@ -771,19 +903,19 @@ describe('Breadcrumbs', () => { // Assert - Component should render without errors const container = document.querySelector('.flex.grow') - expect(container).toBeInTheDocument() + expect(container)!.toBeInTheDocument() }) it('should re-render when breadcrumbs change', () => { mockStoreState.hasBucket = false const props = createDefaultProps({ breadcrumbs: ['folder1'] }) const { rerender } = render() - expect(screen.getByText('folder1')).toBeInTheDocument() + expect(screen.getByText('folder1'))!.toBeInTheDocument() // Act - Rerender with different breadcrumbs rerender() - expect(screen.getByText('folder2')).toBeInTheDocument() + expect(screen.getByText('folder2'))!.toBeInTheDocument() }) }) @@ -798,7 +930,7 @@ describe('Breadcrumbs', () => { render() - expect(screen.getByText(longName)).toBeInTheDocument() + expect(screen.getByText(longName))!.toBeInTheDocument() }) it('should handle many breadcrumbs', async () => { @@ -810,17 +942,13 @@ describe('Breadcrumbs', () => { render() // Act - Open dropdown - const dropdownTrigger = screen.getAllByRole('button').find(btn => btn.querySelector('svg')) - if (dropdownTrigger) - fireEvent.click(dropdownTrigger) + openCollapsedBreadcrumbDropdown() // Assert - First, last, and collapsed should be accessible expect(screen.getByText('folder-0')).toBeInTheDocument() expect(screen.getByText('folder-1')).toBeInTheDocument() expect(screen.getByText('folder-19')).toBeInTheDocument() - await waitFor(() => { - expect(screen.getByText('folder-2')).toBeInTheDocument() - }) + expect(await screen.findByText('folder-2')).toBeInTheDocument() }) it('should handle empty bucket string', () => { @@ -833,7 +961,8 @@ describe('Breadcrumbs', () => { render() // Assert - Should show all buckets title - expect(screen.getByText('datasetPipeline.onlineDrive.breadcrumbs.allBuckets')).toBeInTheDocument() + // Assert - Should show all buckets title + expect(screen.getByText('datasetPipeline.onlineDrive.breadcrumbs.allBuckets'))!.toBeInTheDocument() }) it('should handle breadcrumb with only whitespace', () => { @@ -845,7 +974,8 @@ describe('Breadcrumbs', () => { render() // Assert - Both should be rendered - expect(screen.getByText('normal-folder')).toBeInTheDocument() + // Assert - Both should be rendered + expect(screen.getByText('normal-folder'))!.toBeInTheDocument() }) }) @@ -863,7 +993,7 @@ describe('Breadcrumbs', () => { // Assert - Component should render without errors const container = document.querySelector('.flex.grow') - expect(container).toBeInTheDocument() + expect(container)!.toBeInTheDocument() }) it.each([ @@ -879,9 +1009,7 @@ describe('Breadcrumbs', () => { render() // Assert - Should collapse because breadcrumbs.length > expectedNum - const buttons = screen.getAllByRole('button') - const hasDropdownTrigger = buttons.some(btn => btn.querySelector('svg')) - expect(hasDropdownTrigger).toBe(true) + expect(getDropdownTrigger()).toBeInTheDocument() }) }) @@ -916,7 +1044,8 @@ describe('Breadcrumbs', () => { render() // Assert - Search result should be shown, navigation elements should be hidden - expect(screen.getByText(/searchResult/)).toBeInTheDocument() + // Assert - Search result should be shown, navigation elements should be hidden + expect(screen.getByText(/searchResult/))!.toBeInTheDocument() expect(screen.queryByText('my-bucket')).not.toBeInTheDocument() }) }) diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/file-list/header/breadcrumbs/dropdown/__tests__/index.spec.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/file-list/header/breadcrumbs/dropdown/__tests__/index.spec.tsx index 0157d3cf79..d57e8340e9 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/file-list/header/breadcrumbs/dropdown/__tests__/index.spec.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/file-list/header/breadcrumbs/dropdown/__tests__/index.spec.tsx @@ -23,7 +23,8 @@ describe('Dropdown', () => { render() // Assert - Trigger button should be visible - expect(screen.getByRole('button')).toBeInTheDocument() + // Assert - Trigger button should be visible + expect(screen.getByRole('button'))!.toBeInTheDocument() }) it('should render trigger button with more icon', () => { @@ -31,10 +32,10 @@ describe('Dropdown', () => { const { container } = render() - // Assert - Button should have RiMoreFill icon (rendered as svg) + // Assert - Button should have the more icon const button = screen.getByRole('button') expect(button).toBeInTheDocument() - expect(container.querySelector('svg')).toBeInTheDocument() + expect(container.querySelector('.i-ri-more-fill')).toBeInTheDocument() }) it('should render separator after dropdown', () => { @@ -43,7 +44,8 @@ describe('Dropdown', () => { render() // Assert - Separator "/" should be visible - expect(screen.getByText('/')).toBeInTheDocument() + // Assert - Separator "/" should be visible + expect(screen.getByText('/'))!.toBeInTheDocument() }) it('should render trigger button with correct default styles', () => { @@ -52,11 +54,11 @@ describe('Dropdown', () => { render() const button = screen.getByRole('button') - expect(button).toHaveClass('flex') - expect(button).toHaveClass('size-6') - expect(button).toHaveClass('items-center') - expect(button).toHaveClass('justify-center') - expect(button).toHaveClass('rounded-md') + expect(button)!.toHaveClass('flex') + expect(button)!.toHaveClass('size-6') + expect(button)!.toHaveClass('items-center') + expect(button)!.toHaveClass('justify-center') + expect(button)!.toHaveClass('rounded-md') }) it('should not render menu content when closed', () => { @@ -64,6 +66,37 @@ describe('Dropdown', () => { render() + // Assert - Menu content should not be visible when dropdown is closed + // Assert - Menu content should not be visible when dropdown is closed + // Assert - Menu content should not be visible when dropdown is closed + // Assert - Menu content should not be visible when dropdown is closed + // Assert - Menu content should not be visible when dropdown is closed + // Assert - Menu content should not be visible when dropdown is closed + // Assert - Menu content should not be visible when dropdown is closed + // Assert - Menu content should not be visible when dropdown is closed + // Assert - Menu content should not be visible when dropdown is closed + // Assert - Menu content should not be visible when dropdown is closed + // Assert - Menu content should not be visible when dropdown is closed + // Assert - Menu content should not be visible when dropdown is closed + // Assert - Menu content should not be visible when dropdown is closed + // Assert - Menu content should not be visible when dropdown is closed + // Assert - Menu content should not be visible when dropdown is closed + // Assert - Menu content should not be visible when dropdown is closed + // Assert - Menu content should not be visible when dropdown is closed + // Assert - Menu content should not be visible when dropdown is closed + // Assert - Menu content should not be visible when dropdown is closed + // Assert - Menu content should not be visible when dropdown is closed + // Assert - Menu content should not be visible when dropdown is closed + // Assert - Menu content should not be visible when dropdown is closed + // Assert - Menu content should not be visible when dropdown is closed + // Assert - Menu content should not be visible when dropdown is closed + // Assert - Menu content should not be visible when dropdown is closed + // Assert - Menu content should not be visible when dropdown is closed + // Assert - Menu content should not be visible when dropdown is closed + // Assert - Menu content should not be visible when dropdown is closed + // Assert - Menu content should not be visible when dropdown is closed + // Assert - Menu content should not be visible when dropdown is closed + // Assert - Menu content should not be visible when dropdown is closed // Assert - Menu content should not be visible when dropdown is closed expect(screen.queryByText('visible-folder')).not.toBeInTheDocument() }) @@ -77,8 +110,8 @@ describe('Dropdown', () => { // Assert - Menu items should be visible await waitFor(() => { - expect(screen.getByText('test-folder1')).toBeInTheDocument() - expect(screen.getByText('test-folder2')).toBeInTheDocument() + expect(screen.getByText('test-folder1'))!.toBeInTheDocument() + expect(screen.getByText('test-folder2'))!.toBeInTheDocument() }) }) }) @@ -98,7 +131,7 @@ describe('Dropdown', () => { fireEvent.click(screen.getByRole('button')) await waitFor(() => { - expect(screen.getByText('folder1')).toBeInTheDocument() + expect(screen.getByText('folder1'))!.toBeInTheDocument() }) fireEvent.click(screen.getByText('folder1')) @@ -120,7 +153,7 @@ describe('Dropdown', () => { fireEvent.click(screen.getByRole('button')) await waitFor(() => { - expect(screen.getByText('folder2')).toBeInTheDocument() + expect(screen.getByText('folder2'))!.toBeInTheDocument() }) fireEvent.click(screen.getByText('folder2')) @@ -140,9 +173,9 @@ describe('Dropdown', () => { fireEvent.click(screen.getByRole('button')) await waitFor(() => { - expect(screen.getByText('folder-a')).toBeInTheDocument() - expect(screen.getByText('folder-b')).toBeInTheDocument() - expect(screen.getByText('folder-c')).toBeInTheDocument() + expect(screen.getByText('folder-a'))!.toBeInTheDocument() + expect(screen.getByText('folder-b'))!.toBeInTheDocument() + expect(screen.getByText('folder-c'))!.toBeInTheDocument() }) }) @@ -155,7 +188,7 @@ describe('Dropdown', () => { fireEvent.click(screen.getByRole('button')) await waitFor(() => { - expect(screen.getByText('single-folder')).toBeInTheDocument() + expect(screen.getByText('single-folder'))!.toBeInTheDocument() }) }) @@ -170,7 +203,8 @@ describe('Dropdown', () => { // Assert - Menu should be rendered but with no items await waitFor(() => { // The menu container should exist but be empty - expect(screen.getByRole('button')).toBeInTheDocument() + // The menu container should exist but be empty + expect(screen.getByRole('button'))!.toBeInTheDocument() }) }) @@ -183,9 +217,9 @@ describe('Dropdown', () => { fireEvent.click(screen.getByRole('button')) await waitFor(() => { - expect(screen.getByText('folder [1]')).toBeInTheDocument() - expect(screen.getByText('folder (copy)')).toBeInTheDocument() - expect(screen.getByText('folder-v2.0')).toBeInTheDocument() + expect(screen.getByText('folder [1]'))!.toBeInTheDocument() + expect(screen.getByText('folder (copy)'))!.toBeInTheDocument() + expect(screen.getByText('folder-v2.0'))!.toBeInTheDocument() }) }) @@ -198,9 +232,9 @@ describe('Dropdown', () => { fireEvent.click(screen.getByRole('button')) await waitFor(() => { - expect(screen.getByText('文件夹')).toBeInTheDocument() - expect(screen.getByText('フォルダ')).toBeInTheDocument() - expect(screen.getByText('Папка')).toBeInTheDocument() + expect(screen.getByText('文件夹'))!.toBeInTheDocument() + expect(screen.getByText('フォルダ'))!.toBeInTheDocument() + expect(screen.getByText('Папка'))!.toBeInTheDocument() }) }) }) @@ -218,7 +252,7 @@ describe('Dropdown', () => { fireEvent.click(screen.getByRole('button')) await waitFor(() => { - expect(screen.getByText('folder1')).toBeInTheDocument() + expect(screen.getByText('folder1'))!.toBeInTheDocument() }) fireEvent.click(screen.getByText('folder1')) @@ -236,6 +270,37 @@ describe('Dropdown', () => { render() + // Assert - Menu content should not be visible + // Assert - Menu content should not be visible + // Assert - Menu content should not be visible + // Assert - Menu content should not be visible + // Assert - Menu content should not be visible + // Assert - Menu content should not be visible + // Assert - Menu content should not be visible + // Assert - Menu content should not be visible + // Assert - Menu content should not be visible + // Assert - Menu content should not be visible + // Assert - Menu content should not be visible + // Assert - Menu content should not be visible + // Assert - Menu content should not be visible + // Assert - Menu content should not be visible + // Assert - Menu content should not be visible + // Assert - Menu content should not be visible + // Assert - Menu content should not be visible + // Assert - Menu content should not be visible + // Assert - Menu content should not be visible + // Assert - Menu content should not be visible + // Assert - Menu content should not be visible + // Assert - Menu content should not be visible + // Assert - Menu content should not be visible + // Assert - Menu content should not be visible + // Assert - Menu content should not be visible + // Assert - Menu content should not be visible + // Assert - Menu content should not be visible + // Assert - Menu content should not be visible + // Assert - Menu content should not be visible + // Assert - Menu content should not be visible + // Assert - Menu content should not be visible // Assert - Menu content should not be visible expect(screen.queryByText('test-folder')).not.toBeInTheDocument() }) @@ -247,7 +312,7 @@ describe('Dropdown', () => { fireEvent.click(screen.getByRole('button')) await waitFor(() => { - expect(screen.getByText('test-folder')).toBeInTheDocument() + expect(screen.getByText('test-folder'))!.toBeInTheDocument() }) }) @@ -258,7 +323,7 @@ describe('Dropdown', () => { // Act - Open and then close fireEvent.click(screen.getByRole('button')) await waitFor(() => { - expect(screen.getByText('test-folder')).toBeInTheDocument() + expect(screen.getByText('test-folder'))!.toBeInTheDocument() }) fireEvent.click(screen.getByRole('button')) @@ -280,7 +345,7 @@ describe('Dropdown', () => { fireEvent.click(screen.getByRole('button')) await waitFor(() => { - expect(screen.getByText('test-folder')).toBeInTheDocument() + expect(screen.getByText('test-folder'))!.toBeInTheDocument() }) fireEvent.click(screen.getByText('test-folder')) @@ -297,14 +362,15 @@ describe('Dropdown', () => { const button = screen.getByRole('button') // Assert - Initial state (closed): should have hover:bg-state-base-hover - expect(button).toHaveClass('hover:bg-state-base-hover') + // Assert - Initial state (closed): should have hover:bg-state-base-hover + expect(button)!.toHaveClass('hover:bg-state-base-hover') // Act - Open dropdown fireEvent.click(button) // Assert - Open state: should have bg-state-base-hover await waitFor(() => { - expect(button).toHaveClass('bg-state-base-hover') + expect(button)!.toHaveClass('bg-state-base-hover') }) }) }) @@ -317,6 +383,37 @@ describe('Dropdown', () => { const props = createDefaultProps({ breadcrumbs: ['folder'] }) render() + // Act & Assert - Initially closed + // Act & Assert - Initially closed + // Act & Assert - Initially closed + // Act & Assert - Initially closed + // Act & Assert - Initially closed + // Act & Assert - Initially closed + // Act & Assert - Initially closed + // Act & Assert - Initially closed + // Act & Assert - Initially closed + // Act & Assert - Initially closed + // Act & Assert - Initially closed + // Act & Assert - Initially closed + // Act & Assert - Initially closed + // Act & Assert - Initially closed + // Act & Assert - Initially closed + // Act & Assert - Initially closed + // Act & Assert - Initially closed + // Act & Assert - Initially closed + // Act & Assert - Initially closed + // Act & Assert - Initially closed + // Act & Assert - Initially closed + // Act & Assert - Initially closed + // Act & Assert - Initially closed + // Act & Assert - Initially closed + // Act & Assert - Initially closed + // Act & Assert - Initially closed + // Act & Assert - Initially closed + // Act & Assert - Initially closed + // Act & Assert - Initially closed + // Act & Assert - Initially closed + // Act & Assert - Initially closed // Act & Assert - Initially closed expect(screen.queryByText('folder')).not.toBeInTheDocument() @@ -325,7 +422,7 @@ describe('Dropdown', () => { // Assert - Now open await waitFor(() => { - expect(screen.getByText('folder')).toBeInTheDocument() + expect(screen.getByText('folder'))!.toBeInTheDocument() }) }) @@ -338,7 +435,7 @@ describe('Dropdown', () => { // 1st click - open fireEvent.click(button) await waitFor(() => { - expect(screen.getByText('folder')).toBeInTheDocument() + expect(screen.getByText('folder'))!.toBeInTheDocument() }) // 2nd click - close @@ -350,7 +447,7 @@ describe('Dropdown', () => { // 3rd click - open again fireEvent.click(button) await waitFor(() => { - expect(screen.getByText('folder')).toBeInTheDocument() + expect(screen.getByText('folder'))!.toBeInTheDocument() }) }) }) @@ -368,7 +465,7 @@ describe('Dropdown', () => { fireEvent.click(screen.getByRole('button')) await waitFor(() => { - expect(screen.getByText('folder1')).toBeInTheDocument() + expect(screen.getByText('folder1'))!.toBeInTheDocument() }) fireEvent.click(screen.getByText('folder1')) @@ -394,7 +491,7 @@ describe('Dropdown', () => { fireEvent.click(screen.getByRole('button')) await waitFor(() => { - expect(screen.getByText('folder1')).toBeInTheDocument() + expect(screen.getByText('folder1'))!.toBeInTheDocument() }) fireEvent.click(screen.getByText('folder1')) @@ -423,7 +520,7 @@ describe('Dropdown', () => { // Act - Open and click fireEvent.click(screen.getByRole('button')) await waitFor(() => { - expect(screen.getByText('folder')).toBeInTheDocument() + expect(screen.getByText('folder'))!.toBeInTheDocument() }) fireEvent.click(screen.getByText('folder')) @@ -431,7 +528,7 @@ describe('Dropdown', () => { rerender() fireEvent.click(screen.getByRole('button')) await waitFor(() => { - expect(screen.getByText('folder')).toBeInTheDocument() + expect(screen.getByText('folder'))!.toBeInTheDocument() }) fireEvent.click(screen.getByText('folder')) @@ -450,7 +547,7 @@ describe('Dropdown', () => { // Act - Open and click with first callback fireEvent.click(screen.getByRole('button')) await waitFor(() => { - expect(screen.getByText('folder')).toBeInTheDocument() + expect(screen.getByText('folder'))!.toBeInTheDocument() }) fireEvent.click(screen.getByText('folder')) @@ -466,7 +563,7 @@ describe('Dropdown', () => { // Open and click with second callback fireEvent.click(screen.getByRole('button')) await waitFor(() => { - expect(screen.getByText('folder')).toBeInTheDocument() + expect(screen.getByText('folder'))!.toBeInTheDocument() }) fireEvent.click(screen.getByText('folder')) @@ -482,7 +579,8 @@ describe('Dropdown', () => { rerender() // Assert - Component should render without errors - expect(screen.getByRole('button')).toBeInTheDocument() + // Assert - Component should render without errors + expect(screen.getByRole('button'))!.toBeInTheDocument() }) }) @@ -499,7 +597,7 @@ describe('Dropdown', () => { // Assert - Should handle gracefully (open after odd number of clicks) await waitFor(() => { - expect(screen.getByText('folder')).toBeInTheDocument() + expect(screen.getByText('folder'))!.toBeInTheDocument() }) }) @@ -513,7 +611,7 @@ describe('Dropdown', () => { fireEvent.click(screen.getByRole('button')) await waitFor(() => { - expect(screen.getByText(longName)).toBeInTheDocument() + expect(screen.getByText(longName))!.toBeInTheDocument() }) }) @@ -528,8 +626,8 @@ describe('Dropdown', () => { // Assert - First and last items should be visible await waitFor(() => { - expect(screen.getByText('folder-0')).toBeInTheDocument() - expect(screen.getByText('folder-19')).toBeInTheDocument() + expect(screen.getByText('folder-0'))!.toBeInTheDocument() + expect(screen.getByText('folder-19'))!.toBeInTheDocument() }) }) @@ -544,7 +642,7 @@ describe('Dropdown', () => { fireEvent.click(screen.getByRole('button')) await waitFor(() => { - expect(screen.getByText('folder')).toBeInTheDocument() + expect(screen.getByText('folder'))!.toBeInTheDocument() }) fireEvent.click(screen.getByText('folder')) @@ -562,7 +660,7 @@ describe('Dropdown', () => { fireEvent.click(screen.getByRole('button')) await waitFor(() => { - expect(screen.getByText('folder')).toBeInTheDocument() + expect(screen.getByText('folder'))!.toBeInTheDocument() }) fireEvent.click(screen.getByText('folder')) @@ -578,7 +676,7 @@ describe('Dropdown', () => { fireEvent.click(screen.getByRole('button')) await waitFor(() => { - expect(screen.getByText('normal-folder')).toBeInTheDocument() + expect(screen.getByText('normal-folder'))!.toBeInTheDocument() }) }) @@ -591,7 +689,7 @@ describe('Dropdown', () => { fireEvent.click(screen.getByRole('button')) await waitFor(() => { - expect(screen.getByText('folder')).toBeInTheDocument() + expect(screen.getByText('folder'))!.toBeInTheDocument() }) }) }) @@ -613,9 +711,9 @@ describe('Dropdown', () => { fireEvent.click(screen.getByRole('button')) await waitFor(() => { - expect(screen.getByText(breadcrumbs[0])).toBeInTheDocument() + expect(screen.getByText(breadcrumbs[0]!))!.toBeInTheDocument() }) - fireEvent.click(screen.getByText(breadcrumbs[0])) + fireEvent.click(screen.getByText(breadcrumbs[0]!)) expect(mockOnBreadcrumbClick).toHaveBeenCalledWith(expectedIndex) }) @@ -634,7 +732,7 @@ describe('Dropdown', () => { // Assert - Should render without errors await waitFor(() => { if (breadcrumbs.length > 0) - expect(screen.getByText(breadcrumbs[0])).toBeInTheDocument() + expect(screen.getByText(breadcrumbs[0]!))!.toBeInTheDocument() }) }) }) @@ -650,9 +748,9 @@ describe('Dropdown', () => { fireEvent.click(screen.getByRole('button')) await waitFor(() => { - expect(screen.getByText('Documents')).toBeInTheDocument() - expect(screen.getByText('Projects')).toBeInTheDocument() - expect(screen.getByText('Archive')).toBeInTheDocument() + expect(screen.getByText('Documents'))!.toBeInTheDocument() + expect(screen.getByText('Projects'))!.toBeInTheDocument() + expect(screen.getByText('Archive'))!.toBeInTheDocument() }) }) @@ -668,7 +766,7 @@ describe('Dropdown', () => { // Act - Open and click on second item fireEvent.click(screen.getByRole('button')) await waitFor(() => { - expect(screen.getByText('second')).toBeInTheDocument() + expect(screen.getByText('second'))!.toBeInTheDocument() }) fireEvent.click(screen.getByText('second')) @@ -687,7 +785,7 @@ describe('Dropdown', () => { // Act - Open and click on middle item fireEvent.click(screen.getByRole('button')) await waitFor(() => { - expect(screen.getByText('item2')).toBeInTheDocument() + expect(screen.getByText('item2'))!.toBeInTheDocument() }) fireEvent.click(screen.getByText('item2')) @@ -714,7 +812,7 @@ describe('Dropdown', () => { fireEvent.click(screen.getByRole('button')) await waitFor(() => { - expect(screen.getByText(`folder-${String.fromCharCode(97 + i)}`)).toBeInTheDocument() + expect(screen.getByText(`folder-${String.fromCharCode(97 + i)}`))!.toBeInTheDocument() }) fireEvent.click(screen.getByText(`folder-${String.fromCharCode(97 + i)}`)) @@ -731,7 +829,7 @@ describe('Dropdown', () => { render() const button = screen.getByRole('button') - expect(button).toBeInTheDocument() + expect(button)!.toBeInTheDocument() expect(button.tagName).toBe('BUTTON') }) @@ -741,7 +839,7 @@ describe('Dropdown', () => { render() const button = screen.getByRole('button') - expect(button).toHaveAttribute('type', 'button') + expect(button)!.toHaveAttribute('type', 'button') }) }) }) diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/file-list/header/breadcrumbs/dropdown/index.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/file-list/header/breadcrumbs/dropdown/index.tsx index 7178b45b34..43b5fcc71a 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/file-list/header/breadcrumbs/dropdown/index.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/file-list/header/breadcrumbs/dropdown/index.tsx @@ -1,12 +1,11 @@ import { cn } from '@langgenius/dify-ui/cn' -import { RiMoreFill } from '@remixicon/react' +import { + DropdownMenu, + DropdownMenuContent, + DropdownMenuTrigger, +} from '@langgenius/dify-ui/dropdown-menu' import * as React from 'react' import { useCallback, useState } from 'react' -import { - PortalToFollowElem, - PortalToFollowElemContent, - PortalToFollowElemTrigger, -} from '@/app/components/base/portal-to-follow-elem' import Menu from './menu' type DropdownProps = { @@ -22,26 +21,17 @@ const Dropdown = ({ }: DropdownProps) => { const [open, setOpen] = useState(false) - const handleTrigger = useCallback(() => { - setOpen(prev => !prev) - }, []) - const handleBreadCrumbClick = useCallback((index: number) => { onBreadcrumbClick(index) setOpen(false) }, [onBreadcrumbClick]) return ( - - + }> - - + + - + / - + ) } diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/file-list/header/breadcrumbs/dropdown/item.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/file-list/header/breadcrumbs/dropdown/item.tsx index 864cade85c..6f04ede88a 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/file-list/header/breadcrumbs/dropdown/item.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/file-list/header/breadcrumbs/dropdown/item.tsx @@ -18,7 +18,7 @@ const Item = ({ return (
{name} diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/file-list/header/breadcrumbs/index.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/file-list/header/breadcrumbs/index.tsx index 714b393d03..1dfa0443d0 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/file-list/header/breadcrumbs/index.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/file-list/header/breadcrumbs/index.tsx @@ -83,7 +83,7 @@ const Breadcrumbs = ({ return (
{showSearchResult && ( -
+
{t('onlineDrive.breadcrumbs.searchResult', { ns: 'datasetPipeline', searchResultsLength, @@ -92,7 +92,7 @@ const Breadcrumbs = ({
)} {!showSearchResult && showBucketListTitle && ( -
+
{t('onlineDrive.breadcrumbs.allBuckets', { ns: 'datasetPipeline' })}
)} @@ -152,7 +152,7 @@ const Breadcrumbs = ({ void diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/file-list/list/index.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/file-list/list/index.tsx index f0245cc9a4..00386ec135 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/file-list/list/index.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/file-list/list/index.tsx @@ -39,7 +39,7 @@ const List = ({ if (anchorRef.current) { observerRef.current = new IntersectionObserver((entries) => { const { setNextPageParameters, currentNextPageParametersRef, isTruncated } = dataSourceStore.getState() - if (entries[0].isIntersecting && isTruncated.current && !isLoading) + if (entries[0]!.isIntersecting && isTruncated.current && !isLoading) setNextPageParameters(currentNextPageParametersRef.current) }, { rootMargin: '100px', diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/file-list/list/utils.ts b/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/file-list/list/utils.ts index 6b367ebe3c..07b23a56ff 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/file-list/list/utils.ts +++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/file-list/list/utils.ts @@ -8,7 +8,7 @@ export const getFileExtension = (fileName: string): string => { if (parts.length <= 1 || (parts[0] === '' && parts.length === 2)) return '' - return parts[parts.length - 1].toLowerCase() + return parts[parts.length - 1]!.toLowerCase() } export const getFileType = (fileName: string) => { @@ -17,13 +17,13 @@ export const getFileType = (fileName: string) => { if (extension === 'gif') return FileAppearanceTypeEnum.gif - if (FILE_EXTS.image.includes(extension.toUpperCase())) + if (FILE_EXTS.image!.includes(extension.toUpperCase())) return FileAppearanceTypeEnum.image - if (FILE_EXTS.video.includes(extension.toUpperCase())) + if (FILE_EXTS.video!.includes(extension.toUpperCase())) return FileAppearanceTypeEnum.video - if (FILE_EXTS.audio.includes(extension.toUpperCase())) + if (FILE_EXTS.audio!.includes(extension.toUpperCase())) return FileAppearanceTypeEnum.audio if (extension === 'html' || extension === 'htm' || extension === 'xml' || extension === 'json') @@ -44,7 +44,7 @@ export const getFileType = (fileName: string) => { if (extension === 'pptx' || extension === 'ppt') return FileAppearanceTypeEnum.ppt - if (FILE_EXTS.document.includes(extension.toUpperCase())) + if (FILE_EXTS.document!.includes(extension.toUpperCase())) return FileAppearanceTypeEnum.document return FileAppearanceTypeEnum.custom diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/header.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/header.tsx index 5c12aaa68c..bc51751aef 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/header.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/header.tsx @@ -1,7 +1,7 @@ +import { Button } from '@langgenius/dify-ui/button' import { RiBookOpenLine, RiEqualizer2Line } from '@remixicon/react' import * as React from 'react' import Divider from '@/app/components/base/divider' -import { Button } from '@/app/components/base/ui/button' type HeaderProps = { onClickConfiguration?: () => void diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/index.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/index.tsx index 2113e8841c..76614e3865 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/index.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/index.tsx @@ -1,10 +1,10 @@ import type { DataSourceNodeType } from '@/app/components/workflow/nodes/data-source/types' import type { OnlineDriveFile } from '@/models/pipeline' import type { DataSourceNodeCompletedResponse, DataSourceNodeErrorResponse } from '@/types/pipeline' +import { toast } from '@langgenius/dify-ui/toast' import { produce } from 'immer' import { useCallback, useEffect, useMemo, useRef, useState } from 'react' import { useShallow } from 'zustand/react/shallow' -import { toast } from '@/app/components/base/ui/toast' import { ACCOUNT_SETTING_TAB } from '@/app/components/header/account-setting/constants' import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail' import { useDocLink } from '@/context/i18n' diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/utils.ts b/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/utils.ts index 89b84069fe..8f0ff13ac5 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/utils.ts +++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/utils.ts @@ -10,7 +10,7 @@ export const isBucketListInitiation = (data: OnlineDriveData[], prefix: string[] if (bucket || prefix.length > 0) return false const hasBucket = data.every(item => !!item.bucket) - return hasBucket && (data.length > 1 || (data.length === 1 && !!data[0].bucket && data[0].files.length === 0)) + return hasBucket && (data.length > 1 || (data.length === 1 && !!data[0]!.bucket && data[0]!.files.length === 0)) } export const convertOnlineDriveData = (data: OnlineDriveData[], prefix: string[], bucket: string): { @@ -38,7 +38,7 @@ export const convertOnlineDriveData = (data: OnlineDriveData[], prefix: string[] hasBucket = true } else { - data[0].files.forEach((file) => { + data[0]!.files.forEach((file) => { const { id, name, size, type } = file const isFileType = isFile(type) fileList.push({ @@ -48,9 +48,9 @@ export const convertOnlineDriveData = (data: OnlineDriveData[], prefix: string[] type: isFileType ? OnlineDriveFileType.file : OnlineDriveFileType.folder, }) }) - isTruncated = data[0].is_truncated ?? false - nextPageParameters = data[0].next_page_parameters ?? {} - hasBucket = !!data[0].bucket + isTruncated = data[0]!.is_truncated ?? false + nextPageParameters = data[0]!.next_page_parameters ?? {} + hasBucket = !!data[0]!.bucket } return { fileList, isTruncated, nextPageParameters, hasBucket } } diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/website-crawl/base/__tests__/crawled-result-item.spec.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source/website-crawl/base/__tests__/crawled-result-item.spec.tsx index 62dba84e30..cded02b431 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/data-source/website-crawl/base/__tests__/crawled-result-item.spec.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/website-crawl/base/__tests__/crawled-result-item.spec.tsx @@ -3,7 +3,7 @@ import { render, screen } from '@testing-library/react' import { beforeEach, describe, expect, it, vi } from 'vitest' import CrawledResultItem from '../crawled-result-item' -vi.mock('@/app/components/base/ui/button', () => ({ +vi.mock('@langgenius/dify-ui/button', () => ({ Button: ({ children, onClick }: { children: React.ReactNode, onClick: () => void }) => ( ), diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/website-crawl/base/__tests__/crawled-result.spec.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source/website-crawl/base/__tests__/crawled-result.spec.tsx index 9c71f91d8d..6c476e6dd6 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/data-source/website-crawl/base/__tests__/crawled-result.spec.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/website-crawl/base/__tests__/crawled-result.spec.tsx @@ -49,9 +49,9 @@ const createItem = (url: string): CrawlResultItem => ({ }) const defaultList: CrawlResultItem[] = [ - createItem('https://example.com/a'), - createItem('https://example.com/b'), - createItem('https://example.com/c'), + createItem('https://example.com/a')!, + createItem('https://example.com/b')!, + createItem('https://example.com/c')!, ] describe('CrawledResult', () => { @@ -70,19 +70,18 @@ describe('CrawledResult', () => { it('should render scrap time info with correct total and time', () => { render() - expect( - screen.getByText(/scrapTimeInfo/), - ).toBeInTheDocument() + expect(screen.getByText(/scrapTimeInfo/))!.toBeInTheDocument() // The global i18n mock serialises params, so verify total and time appear - expect(screen.getByText(/"total":3/)).toBeInTheDocument() - expect(screen.getByText(/"time":"12.3"/)).toBeInTheDocument() + // The global i18n mock serialises params, so verify total and time appear + expect(screen.getByText(/"total":3/))!.toBeInTheDocument() + expect(screen.getByText(/"time":"12.3"/))!.toBeInTheDocument() }) it('should render all items from list', () => { render() for (const item of defaultList) { - expect(screen.getByTestId(`crawled-item-${item.source_url}`)).toBeInTheDocument() + expect(screen.getByTestId(`crawled-item-${item.source_url}`))!.toBeInTheDocument() } }) @@ -91,7 +90,7 @@ describe('CrawledResult', () => { , ) - expect(container.firstChild).toHaveClass('my-custom-class') + expect(container.firstChild)!.toHaveClass('my-custom-class') }) }) @@ -100,7 +99,7 @@ describe('CrawledResult', () => { it('should show check-all checkbox in multiple choice mode', () => { render() - expect(screen.getByTestId('check-all-checkbox')).toBeInTheDocument() + expect(screen.getByTestId('check-all-checkbox'))!.toBeInTheDocument() }) it('should hide check-all checkbox in single choice mode', () => { @@ -117,7 +116,7 @@ describe('CrawledResult', () => { render( , ) @@ -150,13 +149,13 @@ describe('CrawledResult', () => { render( , ) - fireEvent.click(screen.getByTestId(`check-${defaultList[1].source_url}`)) + fireEvent.click(screen.getByTestId(`check-${defaultList[1]!.source_url}`)) expect(onSelectedChange).toHaveBeenCalledWith([defaultList[0], defaultList[1]]) }) @@ -166,13 +165,13 @@ describe('CrawledResult', () => { render( , ) - fireEvent.click(screen.getByTestId(`check-${defaultList[1].source_url}`)) + fireEvent.click(screen.getByTestId(`check-${defaultList[1]!.source_url}`)) expect(onSelectedChange).toHaveBeenCalledWith([defaultList[1]]) }) @@ -182,13 +181,13 @@ describe('CrawledResult', () => { render( , ) - fireEvent.click(screen.getByTestId(`check-${defaultList[0].source_url}`)) + fireEvent.click(screen.getByTestId(`check-${defaultList[0]!.source_url}`)) expect(onSelectedChange).toHaveBeenCalledWith([defaultList[1]]) }) @@ -206,7 +205,7 @@ describe('CrawledResult', () => { />, ) - fireEvent.click(screen.getByTestId(`preview-${defaultList[1].source_url}`)) + fireEvent.click(screen.getByTestId(`preview-${defaultList[1]!.source_url}`)) expect(onPreview).toHaveBeenCalledWith(defaultList[1], 1) }) diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/website-crawl/base/__tests__/index.spec.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source/website-crawl/base/__tests__/index.spec.tsx index 99002e687c..fa5633e2df 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/data-source/website-crawl/base/__tests__/index.spec.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/website-crawl/base/__tests__/index.spec.tsx @@ -39,7 +39,7 @@ describe('CheckboxWithLabel', () => { it('should render without crashing', () => { render() - expect(screen.getByText('Test Label')).toBeInTheDocument() + expect(screen.getByText('Test Label'))!.toBeInTheDocument() }) it('should render checkbox in unchecked state', () => { @@ -47,7 +47,7 @@ describe('CheckboxWithLabel', () => { // Assert - Custom checkbox component uses div with data-testid const checkbox = container.querySelector('[data-testid^="checkbox"]') - expect(checkbox).toBeInTheDocument() + expect(checkbox)!.toBeInTheDocument() expect(checkbox).not.toHaveClass('bg-components-checkbox-bg') }) @@ -56,7 +56,7 @@ describe('CheckboxWithLabel', () => { // Assert - Checked state has check icon const checkIcon = container.querySelector('[data-testid^="check-icon"]') - expect(checkIcon).toBeInTheDocument() + expect(checkIcon)!.toBeInTheDocument() }) it('should render tooltip when provided', () => { @@ -64,7 +64,7 @@ describe('CheckboxWithLabel', () => { // Assert - Tooltip trigger should be present const tooltipTrigger = document.querySelector('[class*="ml-0.5"]') - expect(tooltipTrigger).toBeInTheDocument() + expect(tooltipTrigger)!.toBeInTheDocument() }) it('should not render tooltip when not provided', () => { @@ -82,14 +82,14 @@ describe('CheckboxWithLabel', () => { ) const label = container.querySelector('label') - expect(label).toHaveClass('custom-class') + expect(label)!.toHaveClass('custom-class') }) it('should apply custom labelClassName', () => { render() const labelText = screen.getByText('Test Label') - expect(labelText).toHaveClass('custom-label-class') + expect(labelText)!.toHaveClass('custom-label-class') }) }) @@ -147,8 +147,8 @@ describe('CrawledResultItem', () => { it('should render without crashing', () => { render() - expect(screen.getByText('Test Page Title')).toBeInTheDocument() - expect(screen.getByText('https://example.com/page1')).toBeInTheDocument() + expect(screen.getByText('Test Page Title'))!.toBeInTheDocument() + expect(screen.getByText('https://example.com/page1'))!.toBeInTheDocument() }) it('should render checkbox when isMultipleChoice is true', () => { @@ -156,7 +156,7 @@ describe('CrawledResultItem', () => { // Assert - Custom checkbox uses data-testid const checkbox = container.querySelector('[data-testid^="checkbox"]') - expect(checkbox).toBeInTheDocument() + expect(checkbox)!.toBeInTheDocument() }) it('should render radio when isMultipleChoice is false', () => { @@ -164,7 +164,7 @@ describe('CrawledResultItem', () => { // Assert - Radio component has size-4 rounded-full classes const radio = container.querySelector('.size-4.rounded-full') - expect(radio).toBeInTheDocument() + expect(radio)!.toBeInTheDocument() }) it('should render checkbox as checked when isChecked is true', () => { @@ -172,13 +172,13 @@ describe('CrawledResultItem', () => { // Assert - Checked state shows check icon const checkIcon = container.querySelector('[data-testid^="check-icon"]') - expect(checkIcon).toBeInTheDocument() + expect(checkIcon)!.toBeInTheDocument() }) it('should render preview button when showPreview is true', () => { render() - expect(screen.getByRole('button')).toBeInTheDocument() + expect(screen.getByRole('button'))!.toBeInTheDocument() }) it('should not render preview button when showPreview is false', () => { @@ -191,15 +191,15 @@ describe('CrawledResultItem', () => { const { container } = render() const item = container.firstChild - expect(item).toHaveClass('bg-state-base-active') + expect(item)!.toHaveClass('bg-state-base-active') }) it('should apply hover styles when isPreview is false', () => { const { container } = render() const item = container.firstChild - expect(item).toHaveClass('group') - expect(item).toHaveClass('hover:bg-state-base-hover') + expect(item)!.toHaveClass('group') + expect(item)!.toHaveClass('hover:bg-state-base-hover') }) }) @@ -209,7 +209,7 @@ describe('CrawledResultItem', () => { render() - expect(screen.getByText('Custom Title')).toBeInTheDocument() + expect(screen.getByText('Custom Title'))!.toBeInTheDocument() }) it('should display payload source_url', () => { @@ -217,7 +217,7 @@ describe('CrawledResultItem', () => { render() - expect(screen.getByText('https://custom.url/path')).toBeInTheDocument() + expect(screen.getByText('https://custom.url/path'))!.toBeInTheDocument() }) it('should set title attribute for truncation tooltip', () => { @@ -226,7 +226,7 @@ describe('CrawledResultItem', () => { render() const titleElement = screen.getByText('Very Long Title') - expect(titleElement).toHaveAttribute('title', 'Very Long Title') + expect(titleElement)!.toHaveAttribute('title', 'Very Long Title') }) }) @@ -310,22 +310,24 @@ describe('CrawledResult', () => { render() // Assert - Check for time info which contains total count - expect(screen.getByText(/1.5/)).toBeInTheDocument() + // Assert - Check for time info which contains total count + expect(screen.getByText(/1.5/))!.toBeInTheDocument() }) it('should render all list items', () => { render() - expect(screen.getByText('Page 1')).toBeInTheDocument() - expect(screen.getByText('Page 2')).toBeInTheDocument() - expect(screen.getByText('Page 3')).toBeInTheDocument() + expect(screen.getByText('Page 1'))!.toBeInTheDocument() + expect(screen.getByText('Page 2'))!.toBeInTheDocument() + expect(screen.getByText('Page 3'))!.toBeInTheDocument() }) it('should display scrape time info', () => { render() // Assert - Check for the time display - expect(screen.getByText(/2.5/)).toBeInTheDocument() + // Assert - Check for the time display + expect(screen.getByText(/2.5/))!.toBeInTheDocument() }) it('should render select all checkbox when isMultipleChoice is true', () => { @@ -350,7 +352,7 @@ describe('CrawledResult', () => { it('should show "Select All" when not all items are checked', () => { render() - expect(screen.getByText(/selectAll|Select All/i)).toBeInTheDocument() + expect(screen.getByText(/selectAll|Select All/i))!.toBeInTheDocument() }) it('should show "Reset All" when all items are checked', () => { @@ -358,7 +360,7 @@ describe('CrawledResult', () => { render() - expect(screen.getByText(/resetAll|Reset All/i)).toBeInTheDocument() + expect(screen.getByText(/resetAll|Reset All/i))!.toBeInTheDocument() }) }) @@ -368,7 +370,7 @@ describe('CrawledResult', () => { , ) - expect(container.firstChild).toHaveClass('custom-class') + expect(container.firstChild)!.toHaveClass('custom-class') }) it('should highlight item at previewIndex', () => { @@ -378,7 +380,7 @@ describe('CrawledResult', () => { // Assert - Second item should have active state const items = container.querySelectorAll('[class*="rounded-lg"][class*="cursor-pointer"]') - expect(items[1]).toHaveClass('bg-state-base-active') + expect(items[1])!.toHaveClass('bg-state-base-active') }) it('should pass showPreview to items', () => { @@ -411,7 +413,7 @@ describe('CrawledResult', () => { // Act - Click select all checkbox (first checkbox) const checkboxes = container.querySelectorAll('[data-testid^="checkbox"]') - fireEvent.click(checkboxes[0]) + fireEvent.click(checkboxes[0]!) expect(mockOnSelectedChange).toHaveBeenCalledWith(list) }) @@ -429,7 +431,7 @@ describe('CrawledResult', () => { ) const checkboxes = container.querySelectorAll('[data-testid^="checkbox"]') - fireEvent.click(checkboxes[0]) + fireEvent.click(checkboxes[0]!) expect(mockOnSelectedChange).toHaveBeenCalledWith([]) }) @@ -441,14 +443,14 @@ describe('CrawledResult', () => { , ) // Act - Click second item checkbox (index 2, accounting for select all) const checkboxes = container.querySelectorAll('[data-testid^="checkbox"]') - fireEvent.click(checkboxes[2]) + fireEvent.click(checkboxes[2]!) expect(mockOnSelectedChange).toHaveBeenCalledWith([list[0], list[1]]) }) @@ -460,14 +462,14 @@ describe('CrawledResult', () => { , ) // Act - Uncheck first item (index 1, after select all) const checkboxes = container.querySelectorAll('[data-testid^="checkbox"]') - fireEvent.click(checkboxes[1]) + fireEvent.click(checkboxes[1]!) expect(mockOnSelectedChange).toHaveBeenCalledWith([list[1]]) }) @@ -479,7 +481,7 @@ describe('CrawledResult', () => { , @@ -487,7 +489,7 @@ describe('CrawledResult', () => { // Act - Click second item radio (Radio uses size-4 rounded-full classes) const radios = container.querySelectorAll('.size-4.rounded-full') - fireEvent.click(radios[1]) + fireEvent.click(radios[1]!) // Assert - Should only select the clicked item expect(mockOnSelectedChange).toHaveBeenCalledWith([list[1]]) @@ -506,7 +508,7 @@ describe('CrawledResult', () => { ) const buttons = screen.getAllByRole('button') - fireEvent.click(buttons[1]) // Second item's preview button + fireEvent.click(buttons[1]!) // Second item's preview button expect(mockOnPreview).toHaveBeenCalledWith(list[1], 1) }) @@ -525,10 +527,11 @@ describe('CrawledResult', () => { // Act - Click preview button should trigger early return in handlePreview const buttons = screen.getAllByRole('button') - fireEvent.click(buttons[0]) + fireEvent.click(buttons[0]!) // Assert - Should not throw error, component still renders - expect(screen.getByText('Page 1')).toBeInTheDocument() + // Assert - Should not throw error, component still renders + expect(screen.getByText('Page 1'))!.toBeInTheDocument() }) }) @@ -537,7 +540,8 @@ describe('CrawledResult', () => { render() // Assert - Should show time info with 0 count - expect(screen.getByText(/0.5/)).toBeInTheDocument() + // Assert - Should show time info with 0 count + expect(screen.getByText(/0.5/))!.toBeInTheDocument() }) it('should handle single item list', () => { @@ -545,13 +549,13 @@ describe('CrawledResult', () => { render() - expect(screen.getByText('Test Page Title')).toBeInTheDocument() + expect(screen.getByText('Test Page Title'))!.toBeInTheDocument() }) it('should format usedTime to one decimal place', () => { render() - expect(screen.getByText(/1.6/)).toBeInTheDocument() + expect(screen.getByText(/1.6/))!.toBeInTheDocument() }) }) }) @@ -571,13 +575,13 @@ describe('Crawling', () => { it('should render without crashing', () => { render() - expect(screen.getByText(/5\/10/)).toBeInTheDocument() + expect(screen.getByText(/5\/10/))!.toBeInTheDocument() }) it('should display crawled count and total', () => { render() - expect(screen.getByText(/3\/15/)).toBeInTheDocument() + expect(screen.getByText(/3\/15/))!.toBeInTheDocument() }) it('should render skeleton items', () => { @@ -602,19 +606,19 @@ describe('Crawling', () => { , ) - expect(container.firstChild).toHaveClass('custom-crawling-class') + expect(container.firstChild)!.toHaveClass('custom-crawling-class') }) it('should handle zero values', () => { render() - expect(screen.getByText(/0\/0/)).toBeInTheDocument() + expect(screen.getByText(/0\/0/))!.toBeInTheDocument() }) it('should handle large numbers', () => { render() - expect(screen.getByText(/999\/1000/)).toBeInTheDocument() + expect(screen.getByText(/999\/1000/))!.toBeInTheDocument() }) }) @@ -623,9 +627,10 @@ describe('Crawling', () => { const { container } = render() // Assert - Check for various width classes - expect(container.querySelector('.w-\\[35\\%\\]')).toBeInTheDocument() - expect(container.querySelector('.w-\\[50\\%\\]')).toBeInTheDocument() - expect(container.querySelector('.w-\\[40\\%\\]')).toBeInTheDocument() + // Assert - Check for various width classes + expect(container.querySelector('.w-\\[35\\%\\]'))!.toBeInTheDocument() + expect(container.querySelector('.w-\\[50\\%\\]'))!.toBeInTheDocument() + expect(container.querySelector('.w-\\[40\\%\\]'))!.toBeInTheDocument() }) }) }) @@ -644,27 +649,27 @@ describe('ErrorMessage', () => { it('should render without crashing', () => { render() - expect(screen.getByText('Error Title')).toBeInTheDocument() + expect(screen.getByText('Error Title'))!.toBeInTheDocument() }) it('should render error icon', () => { const { container } = render() const icon = container.querySelector('svg') - expect(icon).toBeInTheDocument() - expect(icon).toHaveClass('text-text-destructive') + expect(icon)!.toBeInTheDocument() + expect(icon)!.toHaveClass('text-text-destructive') }) it('should render title', () => { render() - expect(screen.getByText('Custom Error Title')).toBeInTheDocument() + expect(screen.getByText('Custom Error Title'))!.toBeInTheDocument() }) it('should render error message when provided', () => { render() - expect(screen.getByText('Detailed error description')).toBeInTheDocument() + expect(screen.getByText('Detailed error description'))!.toBeInTheDocument() }) it('should not render error message when not provided', () => { @@ -682,14 +687,15 @@ describe('ErrorMessage', () => { , ) - expect(container.firstChild).toHaveClass('custom-error-class') + expect(container.firstChild)!.toHaveClass('custom-error-class') }) it('should render with empty errorMsg', () => { render() // Assert - Empty string should not render message div - expect(screen.getByText('Error Title')).toBeInTheDocument() + // Assert - Empty string should not render message div + expect(screen.getByText('Error Title'))!.toBeInTheDocument() }) it('should handle long title text', () => { @@ -697,7 +703,7 @@ describe('ErrorMessage', () => { render() - expect(screen.getByText(longTitle)).toBeInTheDocument() + expect(screen.getByText(longTitle))!.toBeInTheDocument() }) it('should handle long error message', () => { @@ -705,7 +711,7 @@ describe('ErrorMessage', () => { render() - expect(screen.getByText(longErrorMsg)).toBeInTheDocument() + expect(screen.getByText(longErrorMsg))!.toBeInTheDocument() }) }) @@ -713,19 +719,19 @@ describe('ErrorMessage', () => { it('should have error background styling', () => { const { container } = render() - expect(container.firstChild).toHaveClass('bg-toast-error-bg') + expect(container.firstChild)!.toHaveClass('bg-toast-error-bg') }) it('should have border styling', () => { const { container } = render() - expect(container.firstChild).toHaveClass('border-components-panel-border') + expect(container.firstChild)!.toHaveClass('border-components-panel-border') }) it('should have rounded-sm corners', () => { const { container } = render() - expect(container.firstChild).toHaveClass('rounded-xl') + expect(container.firstChild)!.toHaveClass('rounded-xl') }) }) }) @@ -744,8 +750,9 @@ describe('Base Components Integration', () => { ) // Assert - Both items should render - expect(screen.getByText('Page 1')).toBeInTheDocument() - expect(screen.getByText('Page 2')).toBeInTheDocument() + // Assert - Both items should render + expect(screen.getByText('Page 1'))!.toBeInTheDocument() + expect(screen.getByText('Page 2'))!.toBeInTheDocument() }) it('should render CrawledResult with CheckboxWithLabel for select all', () => { @@ -784,13 +791,13 @@ describe('Base Components Integration', () => { // Act - Select first item (index 1, after select all) const checkboxes = container.querySelectorAll('[data-testid^="checkbox"]') - fireEvent.click(checkboxes[1]) + fireEvent.click(checkboxes[1]!) expect(mockOnSelectedChange).toHaveBeenCalledWith([list[0]]) // Act - Preview second item const previewButtons = screen.getAllByRole('button') - fireEvent.click(previewButtons[1]) + fireEvent.click(previewButtons[1]!) expect(mockOnPreview).toHaveBeenCalledWith(list[1], 1) }) diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/website-crawl/base/crawled-result-item.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source/website-crawl/base/crawled-result-item.tsx index 664a251e25..19019b4dd7 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/data-source/website-crawl/base/crawled-result-item.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/website-crawl/base/crawled-result-item.tsx @@ -1,12 +1,12 @@ 'use client' import type { CrawlResultItem as CrawlResultItemType } from '@/models/datasets' +import { Button } from '@langgenius/dify-ui/button' import { cn } from '@langgenius/dify-ui/cn' import * as React from 'react' import { useCallback } from 'react' import { useTranslation } from 'react-i18next' import Checkbox from '@/app/components/base/checkbox' import Radio from '@/app/components/base/radio/ui' -import { Button } from '@/app/components/base/ui/button' type CrawledResultItemProps = { payload: CrawlResultItemType diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/website-crawl/base/crawled-result.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source/website-crawl/base/crawled-result.tsx index 8cd9a46054..b91c5bd6bc 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/data-source/website-crawl/base/crawled-result.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/website-crawl/base/crawled-result.tsx @@ -59,7 +59,7 @@ const CrawledResult = ({ const handlePreview = useCallback((index: number) => { if (!onPreview) return - onPreview(list[index], index) + onPreview(list[index]!, index) }, [list, onPreview]) return ( diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/website-crawl/base/options/__tests__/index.spec.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source/website-crawl/base/options/__tests__/index.spec.tsx index cea569fa5f..2ab8ad6d4b 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/data-source/website-crawl/base/options/__tests__/index.spec.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/website-crawl/base/options/__tests__/index.spec.tsx @@ -10,8 +10,8 @@ const { mockToastError } = vi.hoisted(() => ({ mockToastError: vi.fn(), })) -vi.mock('@/app/components/base/ui/toast', async (importOriginal) => { - const actual = await importOriginal() +vi.mock('@langgenius/dify-ui/toast', async (importOriginal) => { + const actual = await importOriginal() return { ...actual, toast: { diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/website-crawl/base/options/index.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source/website-crawl/base/options/index.tsx index 899c70e216..46c6c6f462 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/data-source/website-crawl/base/options/index.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/website-crawl/base/options/index.tsx @@ -1,5 +1,7 @@ import type { RAGPipelineVariables } from '@/models/pipeline' +import { Button } from '@langgenius/dify-ui/button' import { cn } from '@langgenius/dify-ui/cn' +import { toast } from '@langgenius/dify-ui/toast' import { RiPlayLargeLine } from '@remixicon/react' import { useBoolean } from 'ahooks' import { useEffect, useMemo } from 'react' @@ -8,8 +10,6 @@ import { useAppForm } from '@/app/components/base/form' import BaseField from '@/app/components/base/form/form-scenarios/base/field' import { generateZodSchema } from '@/app/components/base/form/form-scenarios/base/utils' import { ArrowDownRoundFill } from '@/app/components/base/icons/src/vender/solid/general' -import { Button } from '@/app/components/base/ui/button' -import { toast } from '@/app/components/base/ui/toast' import { useConfigurations, useInitialData } from '@/app/components/rag-pipeline/hooks/use-input-fields' import { CrawlStep } from '@/models/datasets' @@ -43,7 +43,7 @@ const Options = ({ if (!result.success) { const issues = result.error.issues const firstIssue = issues[0] - const errorMessage = `"${firstIssue.path.join('.')}" ${firstIssue.message}` + const errorMessage = `"${firstIssue!.path.join('.')}" ${firstIssue!.message}` toast.error(errorMessage) return errorMessage } diff --git a/web/app/components/datasets/documents/create-from-pipeline/hooks/__tests__/use-datasource-store.spec.ts b/web/app/components/datasets/documents/create-from-pipeline/hooks/__tests__/use-datasource-store.spec.ts index 155b41541b..70d95000d7 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/hooks/__tests__/use-datasource-store.spec.ts +++ b/web/app/components/datasets/documents/create-from-pipeline/hooks/__tests__/use-datasource-store.spec.ts @@ -85,7 +85,7 @@ describe('useOnlineDocument', () => { const { result } = renderHook(() => useOnlineDocument(), { wrapper: createWrapper(store) }) expect(result.current.PagesMapAndSelectedPagesId).toHaveProperty('p1') - expect(result.current.PagesMapAndSelectedPagesId.p1.workspace_id).toBe('w1') + expect(result.current.PagesMapAndSelectedPagesId.p1!.workspace_id).toBe('w1') }) it('should hide preview online document', () => { diff --git a/web/app/components/datasets/documents/create-from-pipeline/hooks/use-datasource-actions.ts b/web/app/components/datasets/documents/create-from-pipeline/hooks/use-datasource-actions.ts index 66bd325c33..bad3b779fa 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/hooks/use-datasource-actions.ts +++ b/web/app/components/datasets/documents/create-from-pipeline/hooks/use-datasource-actions.ts @@ -251,7 +251,7 @@ export const useDatasourceActions = ({ if (datasourceType === DatasourceType.onlineDocument) { const allIds = currentWorkspacePages?.map(page => page.page_id) || [] if (onlineDocuments.length < allIds.length) { - const selectedPages = Array.from(allIds).map(pageId => PagesMapAndSelectedPagesId[pageId]) + const selectedPages = Array.from(allIds).map(pageId => PagesMapAndSelectedPagesId[pageId]!) setOnlineDocuments(selectedPages) setSelectedPagesId(new Set(allIds)) } diff --git a/web/app/components/datasets/documents/create-from-pipeline/hooks/use-datasource-ui-state.ts b/web/app/components/datasets/documents/create-from-pipeline/hooks/use-datasource-ui-state.ts index e398f90a48..f4c222f652 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/hooks/use-datasource-ui-state.ts +++ b/web/app/components/datasets/documents/create-from-pipeline/hooks/use-datasource-ui-state.ts @@ -122,7 +122,7 @@ export const useDatasourceUIState = ({ return { datasourceType, - isShowVectorSpaceFull, + isShowVectorSpaceFull: isShowVectorSpaceFull!, nextBtnDisabled, showSelect, totalOptions, diff --git a/web/app/components/datasets/documents/create-from-pipeline/index.tsx b/web/app/components/datasets/documents/create-from-pipeline/index.tsx index 62c1b919fe..8b8fad5885 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/index.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/index.tsx @@ -133,7 +133,7 @@ const CreateFormPipeline = () => { [DatasourceType.onlineDrive]: selectedFileIds.length, } const count = datasourceType ? multipleCheckMap[datasourceType] : 0 - if (count > 1) { + if (count! > 1) { showPlanUpgradeModal() return } diff --git a/web/app/components/datasets/documents/create-from-pipeline/left-header.tsx b/web/app/components/datasets/documents/create-from-pipeline/left-header.tsx index 925da57197..ab432d6962 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/left-header.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/left-header.tsx @@ -1,8 +1,8 @@ import type { Step } from './step-indicator' +import { Button } from '@langgenius/dify-ui/button' import { RiArrowLeftLine } from '@remixicon/react' import * as React from 'react' import Effect from '@/app/components/base/effect' -import { Button } from '@/app/components/base/ui/button' import Link from '@/next/link' import { useParams } from '@/next/navigation' import StepIndicator from './step-indicator' diff --git a/web/app/components/datasets/documents/create-from-pipeline/preview/__tests__/file-preview.spec.tsx b/web/app/components/datasets/documents/create-from-pipeline/preview/__tests__/file-preview.spec.tsx index 715d1650df..82011102b6 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/preview/__tests__/file-preview.spec.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/preview/__tests__/file-preview.spec.tsx @@ -39,30 +39,30 @@ describe('FilePreview', () => { it('should render preview label', () => { render() - expect(screen.getByText('datasetPipeline.addDocuments.stepOne.preview')).toBeInTheDocument() + expect(screen.getByText('datasetPipeline.addDocuments.stepOne.preview'))!.toBeInTheDocument() }) it('should render file name', () => { render() - expect(screen.getByText('document.pdf')).toBeInTheDocument() + expect(screen.getByText('document.pdf'))!.toBeInTheDocument() }) it('should render file content when loaded', () => { render() - expect(screen.getByText('file content here with some text')).toBeInTheDocument() + expect(screen.getByText('file content here with some text'))!.toBeInTheDocument() }) it('should render loading state', () => { mockIsFetching = true render() - expect(screen.getByTestId('loading')).toBeInTheDocument() + expect(screen.getByTestId('loading'))!.toBeInTheDocument() }) it('should call hidePreview when close button clicked', () => { render() const buttons = screen.getAllByRole('button') const closeBtn = buttons[buttons.length - 1] - fireEvent.click(closeBtn) + fireEvent.click(closeBtn!) expect(defaultProps.hidePreview).toHaveBeenCalled() }) }) diff --git a/web/app/components/datasets/documents/create-from-pipeline/preview/__tests__/online-document-preview.spec.tsx b/web/app/components/datasets/documents/create-from-pipeline/preview/__tests__/online-document-preview.spec.tsx index 1e094fedb0..96033a6f60 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/preview/__tests__/online-document-preview.spec.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/preview/__tests__/online-document-preview.spec.tsx @@ -9,8 +9,8 @@ const { mockToastError } = vi.hoisted(() => ({ mockToastError: vi.fn(), })) -vi.mock('@/app/components/base/ui/toast', async (importOriginal) => { - const actual = await importOriginal() +vi.mock('@langgenius/dify-ui/toast', async (importOriginal) => { + const actual = await importOriginal() return { ...actual, toast: { diff --git a/web/app/components/datasets/documents/create-from-pipeline/preview/__tests__/web-preview.spec.tsx b/web/app/components/datasets/documents/create-from-pipeline/preview/__tests__/web-preview.spec.tsx index 1f59e11035..7bef4d0716 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/preview/__tests__/web-preview.spec.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/preview/__tests__/web-preview.spec.tsx @@ -20,29 +20,29 @@ describe('WebPreview', () => { it('should render preview label', () => { render() - expect(screen.getByText('datasetPipeline.addDocuments.stepOne.preview')).toBeInTheDocument() + expect(screen.getByText('datasetPipeline.addDocuments.stepOne.preview'))!.toBeInTheDocument() }) it('should render page title', () => { render() - expect(screen.getByText('Test Page')).toBeInTheDocument() + expect(screen.getByText('Test Page'))!.toBeInTheDocument() }) it('should render source URL', () => { render() - expect(screen.getByText('https://example.com')).toBeInTheDocument() + expect(screen.getByText('https://example.com'))!.toBeInTheDocument() }) it('should render markdown content', () => { render() - expect(screen.getByText('Hello **markdown** content')).toBeInTheDocument() + expect(screen.getByText('Hello **markdown** content'))!.toBeInTheDocument() }) it('should call hidePreview when close button clicked', () => { render() const buttons = screen.getAllByRole('button') const closeBtn = buttons[buttons.length - 1] - fireEvent.click(closeBtn) + fireEvent.click(closeBtn!) expect(defaultProps.hidePreview).toHaveBeenCalled() }) }) diff --git a/web/app/components/datasets/documents/create-from-pipeline/preview/chunk-preview.tsx b/web/app/components/datasets/documents/create-from-pipeline/preview/chunk-preview.tsx index a56e5bd6dc..9b56c6c8a3 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/preview/chunk-preview.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/preview/chunk-preview.tsx @@ -1,13 +1,13 @@ import type { NotionPage } from '@/models/common' import type { CrawlResultItem, CustomFile, DocumentItem, FileIndexingEstimateResponse } from '@/models/datasets' import type { OnlineDriveFile } from '@/models/pipeline' +import { Button } from '@langgenius/dify-ui/button' import { RiSearchEyeLine } from '@remixicon/react' import * as React from 'react' import { useState } from 'react' import { useTranslation } from 'react-i18next' import Badge from '@/app/components/base/badge' import { SkeletonContainer, SkeletonPoint, SkeletonRectangle, SkeletonRow } from '@/app/components/base/skeleton' -import { Button } from '@/app/components/base/ui/button' import SummaryLabel from '@/app/components/datasets/documents/detail/completed/common/summary-label' import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail' import { ChunkingMode } from '@/models/datasets' @@ -55,9 +55,9 @@ const ChunkPreview = ({ const currentDocForm = useDatasetDetailContextWithSelector(s => s.dataset?.doc_form) const [previewFile, setPreviewFile] = useState(localFiles[0] as DocumentItem) - const [previewOnlineDocument, setPreviewOnlineDocument] = useState(onlineDocuments[0]) - const [previewWebsitePage, setPreviewWebsitePage] = useState(websitePages[0]) - const [previewOnlineDriveFile, setPreviewOnlineDriveFile] = useState(onlineDriveFiles[0]) + const [previewOnlineDocument, setPreviewOnlineDocument] = useState(onlineDocuments[0]!) + const [previewWebsitePage, setPreviewWebsitePage] = useState(websitePages[0]!) + const [previewOnlineDriveFile, setPreviewOnlineDriveFile] = useState(onlineDriveFiles[0]!) return ( -
+
+
{t('addDocuments.stepOne.preview', { ns: 'datasetPipeline' })}
-
{`${fileName}.${file.extension || ''}`}
-
+
{`${fileName}.${file.extension || ''}`}
+
)} {!isFetching && fileData && ( -
+
{fileData.content}
)} diff --git a/web/app/components/datasets/documents/create-from-pipeline/preview/online-document-preview.tsx b/web/app/components/datasets/documents/create-from-pipeline/preview/online-document-preview.tsx index ff2f9f46a4..793ba1f22b 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/preview/online-document-preview.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/preview/online-document-preview.tsx @@ -1,12 +1,12 @@ 'use client' import type { NotionPage } from '@/models/common' +import { toast } from '@langgenius/dify-ui/toast' import { RiCloseLine } from '@remixicon/react' import * as React from 'react' import { useEffect, useState } from 'react' import { useTranslation } from 'react-i18next' import { Notion } from '@/app/components/base/icons/src/public/common' import { Markdown } from '@/app/components/base/markdown' -import { toast } from '@/app/components/base/ui/toast' import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail' import { usePreviewOnlineDocument } from '@/service/use-pipeline' import { formatNumberAbbreviated } from '@/utils/format' @@ -50,12 +50,12 @@ const OnlineDocumentPreview = ({ }, [currentPage.page_id]) return ( -
-
+
+
{t('addDocuments.stepOne.preview', { ns: 'datasetPipeline' })}
-
{currentPage?.page_name}
-
+
{currentPage?.page_name}
+
{currentPage.type} · @@ -76,7 +76,7 @@ const OnlineDocumentPreview = ({
)} {!isPending && content && ( -
+
)} diff --git a/web/app/components/datasets/documents/create-from-pipeline/preview/web-preview.tsx b/web/app/components/datasets/documents/create-from-pipeline/preview/web-preview.tsx index a527f12c0d..f83d68902f 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/preview/web-preview.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/preview/web-preview.tsx @@ -17,12 +17,12 @@ const WebsitePreview = ({ const { t } = useTranslation() return ( -
-
+
+
{t('addDocuments.stepOne.preview', { ns: 'datasetPipeline' })}
-
{currentWebsite.title}
-
+
{currentWebsite.title}
+
{currentWebsite.source_url} · @@ -38,7 +38,7 @@ const WebsitePreview = ({
-
+
{currentWebsite.markdown}
diff --git a/web/app/components/datasets/documents/create-from-pipeline/process-documents/__tests__/components.spec.tsx b/web/app/components/datasets/documents/create-from-pipeline/process-documents/__tests__/components.spec.tsx index ff5f8afa66..16b6ef1373 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/process-documents/__tests__/components.spec.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/process-documents/__tests__/components.spec.tsx @@ -11,8 +11,8 @@ const { mockToastError } = vi.hoisted(() => ({ mockToastError: vi.fn(), })) -vi.mock('@/app/components/base/ui/toast', async (importOriginal) => { - const actual = await importOriginal() +vi.mock('@langgenius/dify-ui/toast', async (importOriginal) => { + const actual = await importOriginal() return { ...actual, toast: { diff --git a/web/app/components/datasets/documents/create-from-pipeline/process-documents/__tests__/form.spec.tsx b/web/app/components/datasets/documents/create-from-pipeline/process-documents/__tests__/form.spec.tsx index 09f28fc5da..dc54ba2757 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/process-documents/__tests__/form.spec.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/process-documents/__tests__/form.spec.tsx @@ -8,8 +8,8 @@ const { mockToastError } = vi.hoisted(() => ({ mockToastError: vi.fn(), })) -vi.mock('@/app/components/base/ui/toast', async (importOriginal) => { - const actual = await importOriginal() +vi.mock('@langgenius/dify-ui/toast', async (importOriginal) => { + const actual = await importOriginal() return { ...actual, toast: { diff --git a/web/app/components/datasets/documents/create-from-pipeline/process-documents/__tests__/header.spec.tsx b/web/app/components/datasets/documents/create-from-pipeline/process-documents/__tests__/header.spec.tsx index 7e9eabaeda..431fa76f2c 100644 --- a/web/app/components/datasets/documents/create-from-pipeline/process-documents/__tests__/header.spec.tsx +++ b/web/app/components/datasets/documents/create-from-pipeline/process-documents/__tests__/header.spec.tsx @@ -2,7 +2,7 @@ import { fireEvent, render, screen } from '@testing-library/react' import { beforeEach, describe, expect, it, vi } from 'vitest' import Header from '../header' -vi.mock('@/app/components/base/ui/button', () => ({ +vi.mock('@langgenius/dify-ui/button', () => ({ Button: ({ children, onClick, disabled, variant }: { children: React.ReactNode, onClick: () => void, disabled?: boolean, variant: string }) => (