diff --git a/.agents/skills/frontend-testing/SKILL.md b/.agents/skills/frontend-testing/SKILL.md
index 4da070bdbf..105c979c58 100644
--- a/.agents/skills/frontend-testing/SKILL.md
+++ b/.agents/skills/frontend-testing/SKILL.md
@@ -200,7 +200,7 @@ When assigned to test a directory/path, test **ALL content** within that path:
- ✅ **Import real project components** directly (including base components and siblings)
- ✅ **Only mock**: API services (`@/service/*`), `next/navigation`, complex context providers
-- ❌ **DO NOT mock** base components (`@/app/components/base/*`)
+- ❌ **DO NOT mock** base components (`@/app/components/base/*`) or dify-ui primitives (`@langgenius/dify-ui/*`)
- ❌ **DO NOT mock** sibling/child components in the same directory
> See [Test Structure Template](#test-structure-template) for correct import/mock patterns.
@@ -325,12 +325,12 @@ For more detailed information, refer to:
### Reference Examples in Codebase
- `web/utils/classnames.spec.ts` - Utility function tests
-- `web/app/components/base/button/index.spec.tsx` - Component tests
+- `web/app/components/base/radio/__tests__/index.spec.tsx` - Component tests
- `web/__mocks__/provider-context.ts` - Mock factory example
### Project Configuration
-- `web/vitest.config.ts` - Vitest configuration
+- `web/vite.config.ts` - Vite/Vitest configuration
- `web/vitest.setup.ts` - Test environment setup
- `web/scripts/analyze-component.js` - Component analysis tool
- Modules are not mocked automatically. Global mocks live in `web/vitest.setup.ts` (for example `react-i18next`, `next/image`); mock other modules like `ky` or `mime` locally in test files.
diff --git a/.agents/skills/frontend-testing/references/checklist.md b/.agents/skills/frontend-testing/references/checklist.md
index 10b8fb66f9..99258498dd 100644
--- a/.agents/skills/frontend-testing/references/checklist.md
+++ b/.agents/skills/frontend-testing/references/checklist.md
@@ -36,7 +36,7 @@ Use this checklist when generating or reviewing tests for Dify frontend componen
### Integration vs Mocking
-- [ ] **DO NOT mock base components** (`Loading`, `Button`, `Tooltip`, etc.)
+- [ ] **DO NOT mock base components or dify-ui primitives** (base `Loading`, `Input`, `Badge`; dify-ui `Button`, `Tooltip`, `Dialog`, etc.)
- [ ] Import real project components instead of mocking
- [ ] Only mock: API calls, complex context providers, third-party libs with side effects
- [ ] Prefer integration testing when using single spec file
@@ -73,7 +73,7 @@ Use this checklist when generating or reviewing tests for Dify frontend componen
### Mocks
-- [ ] **DO NOT mock base components** (`@/app/components/base/*`)
+- [ ] **DO NOT mock base components or dify-ui primitives** (`@/app/components/base/*` or `@langgenius/dify-ui/*`)
- [ ] `vi.clearAllMocks()` in `beforeEach` (not `afterEach`)
- [ ] Shared mock state reset in `beforeEach`
- [ ] i18n uses global mock (auto-loaded in `web/vitest.setup.ts`); only override locally for custom translations
diff --git a/.agents/skills/frontend-testing/references/mocking.md b/.agents/skills/frontend-testing/references/mocking.md
index f58377c4a5..8c2f1c0c58 100644
--- a/.agents/skills/frontend-testing/references/mocking.md
+++ b/.agents/skills/frontend-testing/references/mocking.md
@@ -2,29 +2,27 @@
## ⚠️ Important: What NOT to Mock
-### DO NOT Mock Base Components
+### DO NOT Mock Base Components or dify-ui Primitives
-**Never mock components from `@/app/components/base/`** such as:
+**Never mock components from `@/app/components/base/` or from `@langgenius/dify-ui/*`** such as:
-- `Loading`, `Spinner`
-- `Button`, `Input`, `Select`
-- `Tooltip`, `Modal`, `Dropdown`
-- `Icon`, `Badge`, `Tag`
+- Legacy base (`@/app/components/base/*`): `Loading`, `Spinner`, `Input`, `Badge`, `Tag`
+- dify-ui primitives (`@langgenius/dify-ui/*`): `Button`, `Tooltip`, `Dialog`, `Popover`, `DropdownMenu`, `ContextMenu`, `Select`, `AlertDialog`, `Toast`
**Why?**
-- Base components will have their own dedicated tests
+- These components have their own dedicated tests
- Mocking them creates false positives (tests pass but real integration fails)
- Using real components tests actual integration behavior
```typescript
-// ❌ WRONG: Don't mock base components
+// ❌ WRONG: Don't mock base components or dify-ui primitives
vi.mock('@/app/components/base/loading', () => () =>
  <div>Loading</div>
)
-vi.mock('@/app/components/base/button', () => ({ children }: any) => <button>{children}</button>)
+vi.mock('@langgenius/dify-ui/button', () => ({ Button: ({ children }: any) => <button>{children}</button> }))
-// ✅ CORRECT: Import and use real base components
+// ✅ CORRECT: Import and use the real components
import Loading from '@/app/components/base/loading'
-import Button from '@/app/components/base/button'
+import { Button } from '@langgenius/dify-ui/button'
// They will render normally in tests
```
@@ -319,7 +317,7 @@ const renderWithQueryClient = (ui: React.ReactElement) => {
### ✅ DO
-1. **Use real base components** - Import from `@/app/components/base/` directly
+1. **Use real base components and dify-ui primitives** - Import from `@/app/components/base/` or `@langgenius/dify-ui/*` directly
1. **Use real project components** - Prefer importing over mocking
1. **Use real Zustand stores** - Set test state via `store.setState()`
1. **Reset mocks in `beforeEach`**, not `afterEach`
@@ -330,7 +328,7 @@ const renderWithQueryClient = (ui: React.ReactElement) => {
### ❌ DON'T
-1. **Don't mock base components** (`Loading`, `Button`, `Tooltip`, etc.)
+1. **Don't mock base components or dify-ui primitives** (`Loading`, `Input`, `Button`, `Tooltip`, `Dialog`, etc.)
1. **Don't mock Zustand store modules** - Use real stores with `setState()`
1. Don't mock components you can import directly
1. Don't create overly simplified mocks that miss conditional logic
@@ -342,7 +340,7 @@ const renderWithQueryClient = (ui: React.ReactElement) => {
```
Need to use a component in test?
│
-├─ Is it from @/app/components/base/*?
+├─ Is it from @/app/components/base/* or @langgenius/dify-ui/*?
│ └─ YES → Import real component, DO NOT mock
│
├─ Is it a project component?
diff --git a/.github/workflows/anti-slop.yml b/.github/workflows/anti-slop.yml
deleted file mode 100644
index b0f0a36bc9..0000000000
--- a/.github/workflows/anti-slop.yml
+++ /dev/null
@@ -1,19 +0,0 @@
-name: Anti-Slop PR Check
-
-on:
- pull_request_target:
- types: [opened, edited, synchronize]
-
-permissions:
- pull-requests: write
- contents: read
-
-jobs:
- anti-slop:
- runs-on: ubuntu-latest
- steps:
- - uses: peakoss/anti-slop@85daca1880e9e1af197fc06ea03349daf08f4202 # v0.2.1
- with:
- github-token: ${{ secrets.GITHUB_TOKEN }}
- close-pr: false
- failure-add-pr-labels: "needs-revision"
diff --git a/.github/workflows/api-tests.yml b/.github/workflows/api-tests.yml
index fd910531db..717413937f 100644
--- a/.github/workflows/api-tests.yml
+++ b/.github/workflows/api-tests.yml
@@ -35,7 +35,7 @@ jobs:
persist-credentials: false
- name: Setup UV and Python
- uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
+ uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
with:
enable-cache: true
python-version: ${{ matrix.python-version }}
@@ -84,7 +84,7 @@ jobs:
persist-credentials: false
- name: Setup UV and Python
- uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
+ uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
with:
enable-cache: true
python-version: ${{ matrix.python-version }}
@@ -105,7 +105,7 @@ jobs:
run: sh .github/workflows/expose_service_ports.sh
- name: Set up Sandbox
- uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0
+ uses: hoverkraft-tech/compose-action@d2bee4f07e8ca410d6b196d00f90c12e7d48c33a # v2.6.0
with:
compose-file: |
docker/docker-compose.middleware.yaml
@@ -156,7 +156,7 @@ jobs:
persist-credentials: false
- name: Setup UV and Python
- uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
+ uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
with:
enable-cache: true
python-version: "3.12"
diff --git a/.github/workflows/autofix.yml b/.github/workflows/autofix.yml
index 772ab8dd56..35683b112f 100644
--- a/.github/workflows/autofix.yml
+++ b/.github/workflows/autofix.yml
@@ -25,7 +25,7 @@ jobs:
- name: Check Docker Compose inputs
if: github.event_name != 'merge_group'
id: docker-compose-changes
- uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5
+ uses: tj-actions/changed-files@9426d40962ed5378910ee2e21d5f8c6fcbf2dd96 # v47.0.6
with:
files: |
docker/generate_docker_compose
@@ -35,7 +35,7 @@ jobs:
- name: Check web inputs
if: github.event_name != 'merge_group'
id: web-changes
- uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5
+ uses: tj-actions/changed-files@9426d40962ed5378910ee2e21d5f8c6fcbf2dd96 # v47.0.6
with:
files: |
web/**
@@ -48,7 +48,7 @@ jobs:
- name: Check api inputs
if: github.event_name != 'merge_group'
id: api-changes
- uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5
+ uses: tj-actions/changed-files@9426d40962ed5378910ee2e21d5f8c6fcbf2dd96 # v47.0.6
with:
files: |
api/**
@@ -58,7 +58,7 @@ jobs:
python-version: "3.11"
- if: github.event_name != 'merge_group'
- uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
+ uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
- name: Generate Docker Compose
if: github.event_name != 'merge_group' && steps.docker-compose-changes.outputs.any_changed == 'true'
@@ -120,8 +120,7 @@ jobs:
- name: ESLint autofix
if: github.event_name != 'merge_group' && steps.web-changes.outputs.any_changed == 'true'
run: |
- cd web
vp exec eslint --concurrency=2 --prune-suppressions --quiet || true
- if: github.event_name != 'merge_group'
- uses: autofix-ci/action@7a166d7532b277f34e16238930461bf77f9d7ed8 # v1.3.3
+ uses: autofix-ci/action@c5b2d67aa2274e7b5a18224e8171550871fc7e4a # v1.3.4
diff --git a/.github/workflows/db-migration-test.yml b/.github/workflows/db-migration-test.yml
index 5991abe3ba..17b867dd6d 100644
--- a/.github/workflows/db-migration-test.yml
+++ b/.github/workflows/db-migration-test.yml
@@ -19,7 +19,7 @@ jobs:
persist-credentials: false
- name: Setup UV and Python
- uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
+ uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
with:
enable-cache: true
python-version: "3.12"
@@ -40,7 +40,7 @@ jobs:
cp middleware.env.example middleware.env
- name: Set up Middlewares
- uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0
+ uses: hoverkraft-tech/compose-action@d2bee4f07e8ca410d6b196d00f90c12e7d48c33a # v2.6.0
with:
compose-file: |
docker/docker-compose.middleware.yaml
@@ -69,7 +69,7 @@ jobs:
persist-credentials: false
- name: Setup UV and Python
- uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
+ uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
with:
enable-cache: true
python-version: "3.12"
@@ -94,7 +94,7 @@ jobs:
sed -i 's/DB_USERNAME=postgres/DB_USERNAME=mysql/' middleware.env
- name: Set up Middlewares
- uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0
+ uses: hoverkraft-tech/compose-action@d2bee4f07e8ca410d6b196d00f90c12e7d48c33a # v2.6.0
with:
compose-file: |
docker/docker-compose.middleware.yaml
diff --git a/.github/workflows/pyrefly-diff-comment.yml b/.github/workflows/pyrefly-diff-comment.yml
index eefb1ebbb9..c55b013dbe 100644
--- a/.github/workflows/pyrefly-diff-comment.yml
+++ b/.github/workflows/pyrefly-diff-comment.yml
@@ -76,13 +76,11 @@ jobs:
diff += '\\n\\n... (truncated) ...';
}
- const body = diff.trim()
- ? '### Pyrefly Diff\n\nbase → PR\n\n```diff\n' + diff + '\n```\n'
- : '### Pyrefly Diff\nNo changes detected.';
-
- await github.rest.issues.createComment({
- issue_number: prNumber,
- owner: context.repo.owner,
- repo: context.repo.repo,
- body,
- });
+ if (diff.trim()) {
+ await github.rest.issues.createComment({
+ issue_number: prNumber,
+ owner: context.repo.owner,
+ repo: context.repo.repo,
+ body: '### Pyrefly Diff\n\nbase → PR\n\n```diff\n' + diff + '\n```\n',
+ });
+ }
diff --git a/.github/workflows/pyrefly-diff.yml b/.github/workflows/pyrefly-diff.yml
index ac3732579c..eb15cd6f75 100644
--- a/.github/workflows/pyrefly-diff.yml
+++ b/.github/workflows/pyrefly-diff.yml
@@ -22,7 +22,7 @@ jobs:
fetch-depth: 0
- name: Setup Python & UV
- uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
+ uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
with:
enable-cache: true
diff --git a/.github/workflows/pyrefly-type-coverage-comment.yml b/.github/workflows/pyrefly-type-coverage-comment.yml
index 974da99aad..3c6c96a664 100644
--- a/.github/workflows/pyrefly-type-coverage-comment.yml
+++ b/.github/workflows/pyrefly-type-coverage-comment.yml
@@ -24,7 +24,7 @@ jobs:
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
- name: Setup Python & UV
- uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
+ uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
with:
enable-cache: true
diff --git a/.github/workflows/pyrefly-type-coverage.yml b/.github/workflows/pyrefly-type-coverage.yml
index c795c32e31..0599c94eef 100644
--- a/.github/workflows/pyrefly-type-coverage.yml
+++ b/.github/workflows/pyrefly-type-coverage.yml
@@ -22,7 +22,7 @@ jobs:
fetch-depth: 0
- name: Setup Python & UV
- uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
+ uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
with:
enable-cache: true
diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml
index c32fc9d0cb..d8c7ebbad3 100644
--- a/.github/workflows/style.yml
+++ b/.github/workflows/style.yml
@@ -25,7 +25,7 @@ jobs:
- name: Check changed files
id: changed-files
- uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5
+ uses: tj-actions/changed-files@9426d40962ed5378910ee2e21d5f8c6fcbf2dd96 # v47.0.6
with:
files: |
api/**
@@ -33,7 +33,7 @@ jobs:
- name: Setup UV and Python
if: steps.changed-files.outputs.any_changed == 'true'
- uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
+ uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
with:
enable-cache: false
python-version: "3.12"
@@ -73,10 +73,12 @@ jobs:
- name: Check changed files
id: changed-files
- uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5
+ uses: tj-actions/changed-files@9426d40962ed5378910ee2e21d5f8c6fcbf2dd96 # v47.0.6
with:
files: |
web/**
+ e2e/**
+ sdks/nodejs-client/**
packages/**
package.json
pnpm-lock.yaml
@@ -93,16 +95,16 @@ jobs:
- name: Restore ESLint cache
if: steps.changed-files.outputs.any_changed == 'true'
id: eslint-cache-restore
- uses: actions/cache/restore@668228422ae6a00e4ad889ee87cd7109ec5666a7 # v5.0.4
+ uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
with:
- path: web/.eslintcache
- key: ${{ runner.os }}-web-eslint-${{ hashFiles('web/package.json', 'pnpm-lock.yaml', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-${{ github.sha }}
+ path: .eslintcache
+ key: ${{ runner.os }}-eslint-${{ hashFiles('pnpm-lock.yaml', 'eslint.config.mjs', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-${{ github.sha }}
restore-keys: |
- ${{ runner.os }}-web-eslint-${{ hashFiles('web/package.json', 'pnpm-lock.yaml', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-
+ ${{ runner.os }}-eslint-${{ hashFiles('pnpm-lock.yaml', 'eslint.config.mjs', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-
- name: Web style check
if: steps.changed-files.outputs.any_changed == 'true'
- working-directory: ./web
+ working-directory: .
run: vp run lint:ci
- name: Web tsslint
@@ -112,7 +114,7 @@ jobs:
- name: Web type check
if: steps.changed-files.outputs.any_changed == 'true'
- working-directory: ./web
+ working-directory: .
run: vp run type-check
- name: Web dead code check
@@ -122,9 +124,9 @@ jobs:
- name: Save ESLint cache
if: steps.changed-files.outputs.any_changed == 'true' && success() && steps.eslint-cache-restore.outputs.cache-hit != 'true'
- uses: actions/cache/save@668228422ae6a00e4ad889ee87cd7109ec5666a7 # v5.0.4
+ uses: actions/cache/save@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
with:
- path: web/.eslintcache
+ path: .eslintcache
key: ${{ steps.eslint-cache-restore.outputs.cache-primary-key }}
superlinter:
@@ -140,7 +142,7 @@ jobs:
- name: Check changed files
id: changed-files
- uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5
+ uses: tj-actions/changed-files@9426d40962ed5378910ee2e21d5f8c6fcbf2dd96 # v47.0.6
with:
files: |
**.sh
diff --git a/.github/workflows/tool-test-sdks.yaml b/.github/workflows/tool-test-sdks.yaml
index 467f31fccf..bf33207a14 100644
--- a/.github/workflows/tool-test-sdks.yaml
+++ b/.github/workflows/tool-test-sdks.yaml
@@ -30,7 +30,7 @@ jobs:
persist-credentials: false
- name: Use Node.js
- uses: actions/setup-node@53b83947a5a98c8d113130e565377fae1a50d02f # v6.3.0
+ uses: actions/setup-node@48b55a011bda9f5d6aeb4c2d9c7362e8dae4041e # v6.4.0
with:
node-version: 22
cache: ''
diff --git a/.github/workflows/translate-i18n-claude.yml b/.github/workflows/translate-i18n-claude.yml
index 541200293d..eecbbb1a56 100644
--- a/.github/workflows/translate-i18n-claude.yml
+++ b/.github/workflows/translate-i18n-claude.yml
@@ -158,7 +158,7 @@ jobs:
- name: Run Claude Code for Translation Sync
if: steps.context.outputs.CHANGED_FILES != ''
- uses: anthropics/claude-code-action@b47fd721da662d48c5680e154ad16a73ed74d2e0 # v1.0.93
+ uses: anthropics/claude-code-action@38ec876110f9fbf8b950c79f534430740c3ac009 # v1.0.101
with:
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
github_token: ${{ secrets.GITHUB_TOKEN }}
diff --git a/.github/workflows/vdb-tests-full.yml b/.github/workflows/vdb-tests-full.yml
index f0def8fe7a..b79e8927d7 100644
--- a/.github/workflows/vdb-tests-full.yml
+++ b/.github/workflows/vdb-tests-full.yml
@@ -36,7 +36,7 @@ jobs:
remove_tool_cache: true
- name: Setup UV and Python
- uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
+ uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
with:
enable-cache: true
python-version: ${{ matrix.python-version }}
@@ -65,7 +65,7 @@ jobs:
# tiflash
- name: Set up Full Vector Store Matrix
- uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0
+ uses: hoverkraft-tech/compose-action@d2bee4f07e8ca410d6b196d00f90c12e7d48c33a # v2.6.0
with:
compose-file: |
docker/docker-compose.yaml
diff --git a/.github/workflows/vdb-tests.yml b/.github/workflows/vdb-tests.yml
index f3966f15b9..bd13d662c3 100644
--- a/.github/workflows/vdb-tests.yml
+++ b/.github/workflows/vdb-tests.yml
@@ -33,7 +33,7 @@ jobs:
remove_tool_cache: true
- name: Setup UV and Python
- uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
+ uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
with:
enable-cache: true
python-version: ${{ matrix.python-version }}
@@ -62,7 +62,7 @@ jobs:
# tiflash
- name: Set up Vector Stores for Smoke Coverage
- uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0
+ uses: hoverkraft-tech/compose-action@d2bee4f07e8ca410d6b196d00f90c12e7d48c33a # v2.6.0
with:
compose-file: |
docker/docker-compose.yaml
diff --git a/.github/workflows/web-e2e.yml b/.github/workflows/web-e2e.yml
index 10dc31bde8..6bd4d4f406 100644
--- a/.github/workflows/web-e2e.yml
+++ b/.github/workflows/web-e2e.yml
@@ -28,7 +28,7 @@ jobs:
uses: ./.github/actions/setup-web
- name: Setup UV and Python
- uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
+ uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
with:
enable-cache: true
python-version: "3.12"
diff --git a/.github/workflows/web-tests.yml b/.github/workflows/web-tests.yml
index f3ab4c62c7..2a5cf19645 100644
--- a/.github/workflows/web-tests.yml
+++ b/.github/workflows/web-tests.yml
@@ -89,3 +89,37 @@ jobs:
flags: web
env:
CODECOV_TOKEN: ${{ env.CODECOV_TOKEN }}
+
+ dify-ui-test:
+ name: dify-ui Tests
+ runs-on: ubuntu-latest
+ env:
+ CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
+ defaults:
+ run:
+ shell: bash
+ working-directory: ./packages/dify-ui
+
+ steps:
+ - name: Checkout code
+ uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
+ with:
+ persist-credentials: false
+
+ - name: Setup web environment
+ uses: ./.github/actions/setup-web
+
+ - name: Install Chromium for Browser Mode
+ run: vp exec playwright install --with-deps chromium
+
+ - name: Run dify-ui tests
+ run: vp test run --coverage --silent=passed-only
+
+ - name: Report coverage
+ if: ${{ env.CODECOV_TOKEN != '' }}
+ uses: codecov/codecov-action@57e3a136b779b570ffcdbf80b3bdc90e7fab3de2 # v6.0.0
+ with:
+ directory: packages/dify-ui/coverage
+ flags: dify-ui
+ env:
+ CODECOV_TOKEN: ${{ env.CODECOV_TOKEN }}
diff --git a/.gitignore b/.gitignore
index 53dea88899..836bddbb49 100644
--- a/.gitignore
+++ b/.gitignore
@@ -203,6 +203,7 @@ sdks/python-client/dify_client.egg-info
.vscode/*
!.vscode/launch.json.template
+!.vscode/settings.example.json
!.vscode/README.md
api/.vscode
# vscode Code History Extension
@@ -236,9 +237,15 @@ scripts/stress-test/reports/
.playwright-mcp/
.serena/
+# vitest browser mode attachments (failure screenshots, traces, etc.)
+.vitest-attachments/
+**/__screenshots__/
+
# settings
*.local.json
*.local.md
# Code Agent Folder
.qoder/*
+
+.eslintcache
diff --git a/.vite-hooks/pre-commit b/.vite-hooks/pre-commit
index 13bbd81cf6..d48381bce2 100755
--- a/.vite-hooks/pre-commit
+++ b/.vite-hooks/pre-commit
@@ -56,44 +56,9 @@ if $api_modified; then
fi
fi
-if $web_modified; then
- if $skip_web_checks; then
- echo "Git operation in progress, skipping web checks"
- exit 0
- fi
-
- echo "Running ESLint on web module"
-
- if git diff --cached --quiet -- 'web/**/*.ts' 'web/**/*.tsx'; then
- web_ts_modified=false
- else
- ts_diff_status=$?
- if [ $ts_diff_status -eq 1 ]; then
- web_ts_modified=true
- else
- echo "Unable to determine staged TypeScript changes (git exit code: $ts_diff_status)."
- exit $ts_diff_status
- fi
- fi
-
- cd ./web || exit 1
- vp staged
-
- if $web_ts_modified; then
- echo "Running TypeScript type-check:tsgo"
- if ! npm run type-check:tsgo; then
- echo "Type check failed. Please run 'npm run type-check:tsgo' to fix the errors."
- exit 1
- fi
- else
- echo "No staged TypeScript changes detected, skipping type-check:tsgo"
- fi
-
- echo "Running knip"
- if ! npm run knip; then
- echo "Knip check failed. Please run 'npm run knip' to fix the errors."
- exit 1
- fi
-
- cd ../
+if $skip_web_checks; then
+ echo "Git operation in progress, skipping web checks"
+ exit 0
fi
+
+vp staged
diff --git a/web/.vscode/settings.example.json b/.vscode/settings.example.json
similarity index 86%
rename from web/.vscode/settings.example.json
rename to .vscode/settings.example.json
index 4b356f5b7a..7cdbc51a3b 100644
--- a/web/.vscode/settings.example.json
+++ b/.vscode/settings.example.json
@@ -1,12 +1,16 @@
{
- // Disable the default formatter, use eslint instead
- "prettier.enable": false,
- "editor.formatOnSave": false,
+ "cucumber.features": [
+ "e2e/features/**/*.feature",
+ ],
+ "cucumber.glue": [
+ "e2e/features/**/*.ts",
+ ],
+
+ "tailwindCSS.experimental.configFile": "web/app/styles/globals.css",
// Auto fix
"editor.codeActionsOnSave": {
"source.fixAll.eslint": "explicit",
- "source.organizeImports": "never"
},
// Silent the stylistic rules in your IDE, but still auto fix them
diff --git a/api/commands/account.py b/api/commands/account.py
index 6a2a2e0428..761323a73d 100644
--- a/api/commands/account.py
+++ b/api/commands/account.py
@@ -2,6 +2,7 @@ import base64
import secrets
import click
+from sqlalchemy.orm import Session
from constants.languages import languages
from extensions.ext_database import db
@@ -43,10 +44,11 @@ def reset_password(email, new_password, password_confirm):
# encrypt password with salt
password_hashed = hash_password(new_password, salt)
base64_password_hashed = base64.b64encode(password_hashed).decode()
- account = db.session.merge(account)
- account.password = base64_password_hashed
- account.password_salt = base64_salt
- db.session.commit()
+ with Session(db.engine) as session:
+ account = session.merge(account)
+ account.password = base64_password_hashed
+ account.password_salt = base64_salt
+ session.commit()
AccountService.reset_login_error_rate_limit(normalized_email)
click.echo(click.style("Password reset successfully.", fg="green"))
@@ -77,9 +79,10 @@ def reset_email(email, new_email, email_confirm):
click.echo(click.style(f"Invalid email: {new_email}", fg="red"))
return
- account = db.session.merge(account)
- account.email = normalized_new_email
- db.session.commit()
+ with Session(db.engine) as session:
+ account = session.merge(account)
+ account.email = normalized_new_email
+ session.commit()
click.echo(click.style("Email updated successfully.", fg="green"))
diff --git a/api/constants/dsl_version.py b/api/constants/dsl_version.py
new file mode 100644
index 0000000000..b0fbe0075c
--- /dev/null
+++ b/api/constants/dsl_version.py
@@ -0,0 +1 @@
+CURRENT_APP_DSL_VERSION = "0.6.0"
diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py
index 051d08aa36..9102983d86 100644
--- a/api/controllers/console/app/app.py
+++ b/api/controllers/console/app/app.py
@@ -129,6 +129,7 @@ class AppNamePayload(BaseModel):
class AppIconPayload(BaseModel):
icon: str | None = Field(default=None, description="Icon data")
+ icon_type: IconType | None = Field(default=None, description="Icon type")
icon_background: str | None = Field(default=None, description="Icon background color")
@@ -729,7 +730,12 @@ class AppIconApi(Resource):
args = AppIconPayload.model_validate(console_ns.payload or {})
app_service = AppService()
- app_model = app_service.update_app_icon(app_model, args.icon or "", args.icon_background or "")
+ app_model = app_service.update_app_icon(
+ app_model,
+ args.icon or "",
+ args.icon_background or "",
+ args.icon_type,
+ )
response_model = AppDetail.model_validate(app_model, from_attributes=True)
return response_model.model_dump(mode="json")
diff --git a/api/controllers/console/app/conversation_variables.py b/api/controllers/console/app/conversation_variables.py
index cead33d14f..9c8b095b9f 100644
--- a/api/controllers/console/app/conversation_variables.py
+++ b/api/controllers/console/app/conversation_variables.py
@@ -45,7 +45,7 @@ class ConversationVariableResponse(ResponseModel):
def _normalize_value_type(cls, value: Any) -> str:
exposed_type = getattr(value, "exposed_type", None)
if callable(exposed_type):
- return str(exposed_type().value)
+ return str(exposed_type())
if isinstance(value, str):
return value
try:
diff --git a/api/controllers/console/app/mcp_server.py b/api/controllers/console/app/mcp_server.py
index 5b1abc98dc..d517f695b8 100644
--- a/api/controllers/console/app/mcp_server.py
+++ b/api/controllers/console/app/mcp_server.py
@@ -18,12 +18,6 @@ from models.enums import AppMCPServerStatus
from models.model import AppMCPServer
-def _to_timestamp(value: datetime | int | None) -> int | None:
- if isinstance(value, datetime):
- return int(value.timestamp())
- return value
-
-
class MCPServerCreatePayload(BaseModel):
description: str | None = Field(default=None, description="Server description")
parameters: dict[str, Any] = Field(..., description="Server parameters configuration")
@@ -36,19 +30,25 @@ class MCPServerUpdatePayload(BaseModel):
status: str | None = Field(default=None, description="Server status")
+def _to_timestamp(value: datetime | int | None) -> int | None:
+ if isinstance(value, datetime):
+ return int(value.timestamp())
+ return value
+
+
class AppMCPServerResponse(ResponseModel):
id: str
name: str
server_code: str
description: str
- status: str
+ status: AppMCPServerStatus
parameters: dict[str, Any] | list[Any] | str
created_at: int | None = None
updated_at: int | None = None
@field_validator("parameters", mode="before")
@classmethod
- def _parse_json_string(cls, value: Any) -> Any:
+ def _normalize_parameters(cls, value: Any) -> Any:
if isinstance(value, str):
try:
return json.loads(value)
@@ -70,7 +70,9 @@ class AppMCPServerController(Resource):
@console_ns.doc("get_app_mcp_server")
@console_ns.doc(description="Get MCP server configuration for an application")
@console_ns.doc(params={"app_id": "Application ID"})
- @console_ns.response(200, "Server configuration", console_ns.models[AppMCPServerResponse.__name__])
+ @console_ns.response(
+ 200, "MCP server configuration retrieved successfully", console_ns.models[AppMCPServerResponse.__name__]
+ )
@login_required
@account_initialization_required
@setup_required
@@ -85,7 +87,9 @@ class AppMCPServerController(Resource):
@console_ns.doc(description="Create MCP server configuration for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[MCPServerCreatePayload.__name__])
- @console_ns.response(200, "Server created", console_ns.models[AppMCPServerResponse.__name__])
+ @console_ns.response(
+ 201, "MCP server configuration created successfully", console_ns.models[AppMCPServerResponse.__name__]
+ )
@console_ns.response(403, "Insufficient permissions")
@account_initialization_required
@get_app_model
@@ -111,13 +115,15 @@ class AppMCPServerController(Resource):
)
db.session.add(server)
db.session.commit()
- return AppMCPServerResponse.model_validate(server, from_attributes=True).model_dump(mode="json")
+ return AppMCPServerResponse.model_validate(server, from_attributes=True).model_dump(mode="json"), 201
@console_ns.doc("update_app_mcp_server")
@console_ns.doc(description="Update MCP server configuration for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[MCPServerUpdatePayload.__name__])
- @console_ns.response(200, "Server updated", console_ns.models[AppMCPServerResponse.__name__])
+ @console_ns.response(
+ 200, "MCP server configuration updated successfully", console_ns.models[AppMCPServerResponse.__name__]
+ )
@console_ns.response(403, "Insufficient permissions")
@console_ns.response(404, "Server not found")
@get_app_model
@@ -154,7 +160,7 @@ class AppMCPServerRefreshController(Resource):
@console_ns.doc("refresh_app_mcp_server")
@console_ns.doc(description="Refresh MCP server configuration and regenerate server code")
@console_ns.doc(params={"server_id": "Server ID"})
- @console_ns.response(200, "Server refreshed", console_ns.models[AppMCPServerResponse.__name__])
+ @console_ns.response(200, "MCP server refreshed successfully", console_ns.models[AppMCPServerResponse.__name__])
@console_ns.response(403, "Insufficient permissions")
@console_ns.response(404, "Server not found")
@setup_required
diff --git a/api/controllers/console/app/workflow_draft_variable.py b/api/controllers/console/app/workflow_draft_variable.py
index f6319573e0..e32ba5f66c 100644
--- a/api/controllers/console/app/workflow_draft_variable.py
+++ b/api/controllers/console/app/workflow_draft_variable.py
@@ -102,7 +102,7 @@ def _serialize_var_value(variable: WorkflowDraftVariable):
def _serialize_variable_type(workflow_draft_var: WorkflowDraftVariable) -> str:
value_type = workflow_draft_var.value_type
- return value_type.exposed_type().value
+ return str(value_type.exposed_type())
class FullContentDict(TypedDict):
@@ -122,7 +122,7 @@ def _serialize_full_content(variable: WorkflowDraftVariable) -> FullContentDict
result: FullContentDict = {
"size_bytes": variable_file.size,
- "value_type": variable_file.value_type.exposed_type().value,
+ "value_type": str(variable_file.value_type.exposed_type()),
"length": variable_file.length,
"download_url": file_helpers.get_signed_file_url(variable_file.upload_file_id, as_attachment=True),
}
@@ -598,7 +598,7 @@ class EnvironmentVariableCollectionApi(Resource):
"name": v.name,
"description": v.description,
"selector": v.selector,
- "value_type": v.value_type.exposed_type().value,
+ "value_type": str(v.value_type.exposed_type()),
"value": v.value,
# Do not track edited for env vars.
"edited": False,
diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py
index ea0fdef0a7..d001dfba64 100644
--- a/api/controllers/console/datasets/datasets.py
+++ b/api/controllers/console/datasets/datasets.py
@@ -50,6 +50,7 @@ from fields.dataset_fields import (
from fields.document_fields import document_status_fields
from graphon.model_runtime.entities.model_entities import ModelType
from libs.login import current_account_with_tenant, login_required
+from libs.url_utils import normalize_api_base_url
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
from models.dataset import DatasetPermission, DatasetPermissionEnum
from models.enums import ApiTokenType, SegmentStatus
@@ -889,7 +890,8 @@ class DatasetApiBaseUrlApi(Resource):
@login_required
@account_initialization_required
def get(self):
- return {"api_base_url": (dify_config.SERVICE_API_URL or request.host_url.rstrip("/")) + "/v1"}
+ base = dify_config.SERVICE_API_URL or request.host_url.rstrip("/")
+ return {"api_base_url": normalize_api_base_url(base)}
@console_ns.route("/datasets/retrieval-setting")
diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py
index b61e39716f..3372a967d9 100644
--- a/api/controllers/console/datasets/datasets_document.py
+++ b/api/controllers/console/datasets/datasets_document.py
@@ -3,18 +3,19 @@ import logging
from argparse import ArgumentTypeError
from collections.abc import Sequence
from contextlib import ExitStack
+from datetime import datetime
from typing import Any, Literal, cast
import sqlalchemy as sa
from flask import request, send_file
-from flask_restx import Resource, fields, marshal, marshal_with
-from pydantic import BaseModel, Field
+from flask_restx import Resource, marshal
+from pydantic import BaseModel, Field, field_validator
from sqlalchemy import asc, desc, func, select
from werkzeug.exceptions import Forbidden, NotFound
import services
from controllers.common.controller_schemas import DocumentBatchDownloadZipPayload
-from controllers.common.schema import get_or_create_model, register_schema_models
+from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from core.errors.error import (
LLMBadRequestError,
@@ -29,11 +30,9 @@ from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from extensions.ext_database import db
-from fields.dataset_fields import dataset_fields
+from fields.base import ResponseModel
from fields.document_fields import (
- dataset_and_document_fields,
document_fields,
- document_metadata_fields,
document_status_fields,
document_with_segments_fields,
)
@@ -72,27 +71,100 @@ from ..wraps import (
logger = logging.getLogger(__name__)
-# Register models for flask_restx to avoid dict type issues in Swagger
-dataset_model = get_or_create_model("Dataset", dataset_fields)
+def _to_timestamp(value: datetime | int | None) -> int | None:
+ if isinstance(value, datetime):
+ return int(value.timestamp())
+ return value
-document_metadata_model = get_or_create_model("DocumentMetadata", document_metadata_fields)
-document_fields_copy = document_fields.copy()
-document_fields_copy["doc_metadata"] = fields.List(
- fields.Nested(document_metadata_model), attribute="doc_metadata_details"
-)
-document_model = get_or_create_model("Document", document_fields_copy)
+def _normalize_enum(value: Any) -> Any:
+ if isinstance(value, str) or value is None:
+ return value
+ return getattr(value, "value", value)
-document_with_segments_fields_copy = document_with_segments_fields.copy()
-document_with_segments_fields_copy["doc_metadata"] = fields.List(
- fields.Nested(document_metadata_model), attribute="doc_metadata_details"
-)
-document_with_segments_model = get_or_create_model("DocumentWithSegments", document_with_segments_fields_copy)
-dataset_and_document_fields_copy = dataset_and_document_fields.copy()
-dataset_and_document_fields_copy["dataset"] = fields.Nested(dataset_model)
-dataset_and_document_fields_copy["documents"] = fields.List(fields.Nested(document_model))
-dataset_and_document_model = get_or_create_model("DatasetAndDocument", dataset_and_document_fields_copy)
+class DatasetResponse(ResponseModel):
+ id: str
+ name: str
+ description: str | None = None
+ permission: str | None = None
+ data_source_type: str | None = None
+ indexing_technique: str | None = None
+ created_by: str | None = None
+ created_at: int | None = None
+
+ @field_validator("data_source_type", "indexing_technique", mode="before")
+ @classmethod
+ def _normalize_enum_fields(cls, value: Any) -> Any:
+ return _normalize_enum(value)
+
+ @field_validator("created_at", mode="before")
+ @classmethod
+ def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
+ return _to_timestamp(value)
+
+
+class DocumentMetadataResponse(ResponseModel):
+ id: str
+ name: str
+ type: str
+ value: str | None = None
+
+
+class DocumentResponse(ResponseModel):
+ id: str
+ position: int | None = None
+ data_source_type: str | None = None
+ data_source_info: Any = Field(default=None, validation_alias="data_source_info_dict")
+ data_source_detail_dict: Any = None
+ dataset_process_rule_id: str | None = None
+ name: str
+ created_from: str | None = None
+ created_by: str | None = None
+ created_at: int | None = None
+ tokens: int | None = None
+ indexing_status: str | None = None
+ error: str | None = None
+ enabled: bool | None = None
+ disabled_at: int | None = None
+ disabled_by: str | None = None
+ archived: bool | None = None
+ display_status: str | None = None
+ word_count: int | None = None
+ hit_count: int | None = None
+ doc_form: str | None = None
+ doc_metadata: list[DocumentMetadataResponse] = Field(default_factory=list, validation_alias="doc_metadata_details")
+ summary_index_status: str | None = None
+ need_summary: bool | None = None
+
+ @field_validator("data_source_type", "indexing_status", "display_status", "doc_form", mode="before")
+ @classmethod
+ def _normalize_enum_fields(cls, value: Any) -> Any:
+ return _normalize_enum(value)
+
+ @field_validator("doc_metadata", mode="before")
+ @classmethod
+ def _normalize_doc_metadata(cls, value: Any) -> list[Any]:
+ if value is None:
+ return []
+ return value
+
+ @field_validator("created_at", "disabled_at", mode="before")
+ @classmethod
+ def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
+ return _to_timestamp(value)
+
+
+class DocumentWithSegmentsResponse(DocumentResponse):
+ process_rule_dict: Any = None
+ completed_segments: int | None = None
+ total_segments: int | None = None
+
+
+class DatasetAndDocumentResponse(ResponseModel):
+ dataset: DatasetResponse
+ documents: list[DocumentResponse]
+ batch: str
class DocumentRetryPayload(BaseModel):
@@ -107,6 +179,11 @@ class GenerateSummaryPayload(BaseModel):
document_list: list[str]
+class DocumentMetadataUpdatePayload(BaseModel):
+ doc_type: str | None = None
+ doc_metadata: Any = None
+
+
class DocumentDatasetListParam(BaseModel):
page: int = Field(1, title="Page", description="Page number.")
limit: int = Field(20, title="Limit", description="Page size.")
@@ -124,7 +201,13 @@ register_schema_models(
DocumentRetryPayload,
DocumentRenamePayload,
GenerateSummaryPayload,
+ DocumentMetadataUpdatePayload,
DocumentBatchDownloadZipPayload,
+ DatasetResponse,
+ DocumentMetadataResponse,
+ DocumentResponse,
+ DocumentWithSegmentsResponse,
+ DatasetAndDocumentResponse,
)
@@ -357,10 +440,10 @@ class DatasetDocumentListApi(Resource):
@setup_required
@login_required
@account_initialization_required
- @marshal_with(dataset_and_document_model)
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.expect(console_ns.models[KnowledgeConfig.__name__])
+ @console_ns.response(200, "Documents created successfully", console_ns.models[DatasetAndDocumentResponse.__name__])
def post(self, dataset_id):
current_user, _ = current_account_with_tenant()
dataset_id = str(dataset_id)
@@ -398,7 +481,9 @@ class DatasetDocumentListApi(Resource):
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
- return {"dataset": dataset, "documents": documents, "batch": batch}
+ return DatasetAndDocumentResponse.model_validate(
+ {"dataset": dataset, "documents": documents, "batch": batch}, from_attributes=True
+ ).model_dump(mode="json")
@setup_required
@login_required
@@ -426,12 +511,13 @@ class DatasetInitApi(Resource):
@console_ns.doc("init_dataset")
@console_ns.doc(description="Initialize dataset with documents")
@console_ns.expect(console_ns.models[KnowledgeConfig.__name__])
- @console_ns.response(201, "Dataset initialized successfully", dataset_and_document_model)
+ @console_ns.response(
+ 201, "Dataset initialized successfully", console_ns.models[DatasetAndDocumentResponse.__name__]
+ )
@console_ns.response(400, "Invalid request parameters")
@setup_required
@login_required
@account_initialization_required
- @marshal_with(dataset_and_document_model)
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self):
@@ -479,9 +565,9 @@ class DatasetInitApi(Resource):
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
- response = {"dataset": dataset, "documents": documents, "batch": batch}
-
- return response
+ return DatasetAndDocumentResponse.model_validate(
+ {"dataset": dataset, "documents": documents, "batch": batch}, from_attributes=True
+ ).model_dump(mode="json")
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/indexing-estimate")
@@ -988,15 +1074,7 @@ class DocumentMetadataApi(DocumentResource):
@console_ns.doc("update_document_metadata")
@console_ns.doc(description="Update document metadata")
@console_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
- @console_ns.expect(
- console_ns.model(
- "UpdateDocumentMetadataRequest",
- {
- "doc_type": fields.String(description="Document type"),
- "doc_metadata": fields.Raw(description="Document metadata"),
- },
- )
- )
+ @console_ns.expect(console_ns.models[DocumentMetadataUpdatePayload.__name__])
@console_ns.response(200, "Document metadata updated successfully")
@console_ns.response(404, "Document not found")
@console_ns.response(403, "Permission denied")
@@ -1009,10 +1087,10 @@ class DocumentMetadataApi(DocumentResource):
document_id = str(document_id)
document = self.get_document(dataset_id, document_id)
- req_data = request.get_json()
+ req_data = DocumentMetadataUpdatePayload.model_validate(request.get_json() or {})
- doc_type = req_data.get("doc_type")
- doc_metadata = req_data.get("doc_metadata")
+ doc_type = req_data.doc_type
+ doc_metadata = req_data.doc_metadata
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
if not current_user.is_dataset_editor:
@@ -1194,7 +1272,7 @@ class DocumentRenameApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
- @marshal_with(document_model)
+ @console_ns.response(200, "Document renamed successfully", console_ns.models[DocumentResponse.__name__])
@console_ns.expect(console_ns.models[DocumentRenamePayload.__name__])
def post(self, dataset_id, document_id):
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
@@ -1212,7 +1290,7 @@ class DocumentRenameApi(DocumentResource):
except services.errors.document.DocumentIndexingError:
raise DocumentIndexingError("Cannot delete document during indexing.")
- return document
+ return DocumentResponse.model_validate(document, from_attributes=True).model_dump(mode="json")
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/website-sync")
diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py
index 44404005b2..c01286cc59 100644
--- a/api/controllers/console/workspace/account.py
+++ b/api/controllers/console/workspace/account.py
@@ -595,13 +595,25 @@ class ChangeEmailSendEmailApi(Resource):
account = None
user_email = None
email_for_sending = args.email.lower()
- if args.phase is not None and args.phase == "new_email":
+ # Default to the initial phase; any legacy/unexpected client input is
+ # coerced back to `old_email` so we never trust the caller to declare
+ # later phases without a verified predecessor token.
+ send_phase = AccountService.CHANGE_EMAIL_PHASE_OLD
+ if args.phase is not None and args.phase == AccountService.CHANGE_EMAIL_PHASE_NEW:
+ send_phase = AccountService.CHANGE_EMAIL_PHASE_NEW
if args.token is None:
raise InvalidTokenError()
reset_data = AccountService.get_change_email_data(args.token)
if reset_data is None:
raise InvalidTokenError()
+
+ # The token used to request a new-email code must come from the
+ # old-email verification step. This prevents the bypass described
+ # in GHSA-4q3w-q5mc-45rq where the phase-1 token was reused here.
+ token_phase = reset_data.get(AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY)
+ if token_phase != AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED:
+ raise InvalidTokenError()
user_email = reset_data.get("email", "")
if user_email.lower() != current_user.email.lower():
@@ -620,7 +632,7 @@ class ChangeEmailSendEmailApi(Resource):
email=email_for_sending,
old_email=user_email,
language=language,
- phase=args.phase,
+ phase=send_phase,
)
return {"result": "success", "data": token}
@@ -655,12 +667,31 @@ class ChangeEmailCheckApi(Resource):
AccountService.add_change_email_error_rate_limit(user_email)
raise EmailCodeError()
+ # Only advance tokens that were minted by the matching send-code step;
+ # refuse tokens that have already progressed or lack a phase marker so
+ # the chain `old_email -> old_email_verified -> new_email -> new_email_verified`
+ # is strictly enforced.
+ phase_transitions = {
+ AccountService.CHANGE_EMAIL_PHASE_OLD: AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED,
+ AccountService.CHANGE_EMAIL_PHASE_NEW: AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED,
+ }
+ token_phase = token_data.get(AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY)
+ if not isinstance(token_phase, str):
+ raise InvalidTokenError()
+ refreshed_phase = phase_transitions.get(token_phase)
+ if refreshed_phase is None:
+ raise InvalidTokenError()
+
# Verified, revoke the first token
AccountService.revoke_change_email_token(args.token)
- # Refresh token data by generating a new token
+ # Refresh token data by generating a new token that carries the
+ # upgraded phase so later steps can check it.
_, new_token = AccountService.generate_change_email_token(
- user_email, code=args.code, old_email=token_data.get("old_email"), additional_data={}
+ user_email,
+ code=args.code,
+ old_email=token_data.get("old_email"),
+ additional_data={AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: refreshed_phase},
)
AccountService.reset_change_email_error_rate_limit(user_email)
@@ -690,13 +721,29 @@ class ChangeEmailResetApi(Resource):
if not reset_data:
raise InvalidTokenError()
- AccountService.revoke_change_email_token(args.token)
+ # Only tokens that completed both verification phases may be used to
+ # change the email. This closes GHSA-4q3w-q5mc-45rq where a token from
+ # the initial send-code step could be replayed directly here.
+ token_phase = reset_data.get(AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY)
+ if token_phase != AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED:
+ raise InvalidTokenError()
+
+ # Bind the new email to the token that was mailed and verified, so a
+ # verified token cannot be reused with a different `new_email` value.
+ token_email = reset_data.get("email")
+ normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email
+ if normalized_token_email != normalized_new_email:
+ raise InvalidTokenError()
old_email = reset_data.get("old_email", "")
current_user, _ = current_account_with_tenant()
if current_user.email.lower() != old_email.lower():
raise AccountNotFound()
+ # Revoke only after all checks pass so failed attempts don't burn a
+ # legitimately verified token.
+ AccountService.revoke_change_email_token(args.token)
+
updated_account = AccountService.update_account_email(current_user, email=normalized_new_email)
AccountService.send_change_email_completed_notify_email(
diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py
index 471594f349..34c9534de8 100644
--- a/api/controllers/console/workspace/tool_providers.py
+++ b/api/controllers/console/workspace/tool_providers.py
@@ -1131,6 +1131,14 @@ class ToolMCPAuthApi(Resource):
with sessionmaker(db.engine).begin() as session:
service = MCPToolManageService(session=session)
service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id)
+ parsed = urlparse(server_url)
+ sanitized_url = f"{parsed.scheme}://{parsed.hostname}{parsed.path}"
+ logger.warning(
+ "MCP authorization failed for provider %s (url=%s)",
+ provider_id,
+ sanitized_url,
+ exc_info=True,
+ )
raise ValueError(f"Failed to connect to MCP server: {e}") from e
diff --git a/api/controllers/inner_api/plugin/wraps.py b/api/controllers/inner_api/plugin/wraps.py
index a5846e2815..2f309262cb 100644
--- a/api/controllers/inner_api/plugin/wraps.py
+++ b/api/controllers/inner_api/plugin/wraps.py
@@ -20,10 +20,13 @@ class TenantUserPayload(BaseModel):
def get_user(tenant_id: str, user_id: str | None) -> EndUser:
"""
- Get current user
+ Get current user.
NOTE: user_id is not trusted, it could be maliciously set to any value.
- As a result, it could only be considered as an end user id.
+ As a result, it could only be considered as an end user id. Even when a
+ concrete end-user ID is supplied, lookups must stay tenant-scoped so one
+ tenant cannot bind another tenant's user record into the plugin request
+ context.
"""
if not user_id:
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID
@@ -42,7 +45,14 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser:
.limit(1)
)
else:
- user_model = session.get(EndUser, user_id)
+ user_model = session.scalar(
+ select(EndUser)
+ .where(
+ EndUser.id == user_id,
+ EndUser.tenant_id == tenant_id,
+ )
+ .limit(1)
+ )
if not user_model:
user_model = EndUser(
diff --git a/api/controllers/service_api/app/conversation.py b/api/controllers/service_api/app/conversation.py
index c4353ca7b8..ca4b18cb5e 100644
--- a/api/controllers/service_api/app/conversation.py
+++ b/api/controllers/service_api/app/conversation.py
@@ -84,10 +84,10 @@ class ConversationVariableResponse(ResponseModel):
def normalize_value_type(cls, value: Any) -> str:
exposed_type = getattr(value, "exposed_type", None)
if callable(exposed_type):
- return str(exposed_type().value)
+ return str(exposed_type())
if isinstance(value, str):
try:
- return str(SegmentType(value).exposed_type().value)
+ return str(SegmentType(value).exposed_type())
except ValueError:
return value
try:
diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py
index 790602ef5d..c22102c2ba 100644
--- a/api/core/agent/base_agent_runner.py
+++ b/api/core/agent/base_agent_runner.py
@@ -42,7 +42,7 @@ from graphon.model_runtime.entities import (
)
from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
from graphon.model_runtime.entities.model_entities import ModelFeature
-from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
from models.enums import CreatorUserRole
from models.model import Conversation, Message, MessageAgentThought, MessageFile
diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py
index d38d24d1e7..29de0b8b1c 100644
--- a/api/core/agent/fc_agent_runner.py
+++ b/api/core/agent/fc_agent_runner.py
@@ -299,7 +299,9 @@ class FunctionCallAgentRunner(BaseAgentRunner):
# update prompt tool
for prompt_tool in prompt_messages_tools:
- self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool)
+ tool_instance = tool_instances.get(prompt_tool.name)
+ if tool_instance:
+ self.update_prompt_message_tool(tool_instance, prompt_tool)
iteration_step += 1
diff --git a/api/core/app/app_config/easy_ui_based_app/model_config/converter.py b/api/core/app/app_config/easy_ui_based_app/model_config/converter.py
index dbd7527fc6..5df3df2b3e 100644
--- a/api/core/app/app_config/easy_ui_based_app/model_config/converter.py
+++ b/api/core/app/app_config/easy_ui_based_app/model_config/converter.py
@@ -7,7 +7,7 @@ from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotIni
from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
from graphon.model_runtime.entities.llm_entities import LLMMode
from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
-from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
class ModelConfigConverter:
diff --git a/api/core/app/apps/agent_chat/app_runner.py b/api/core/app/apps/agent_chat/app_runner.py
index 09ddce327e..cae0eee0df 100644
--- a/api/core/app/apps/agent_chat/app_runner.py
+++ b/api/core/app/apps/agent_chat/app_runner.py
@@ -18,7 +18,7 @@ from core.moderation.base import ModerationError
from extensions.ext_database import db
from graphon.model_runtime.entities.llm_entities import LLMMode
from graphon.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
-from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
from models.model import App, Conversation, Message
logger = logging.getLogger(__name__)
diff --git a/api/core/app/file_access/scope.py b/api/core/app/file_access/scope.py
index 80d504ef1c..a583301f9b 100644
--- a/api/core/app/file_access/scope.py
+++ b/api/core/app/file_access/scope.py
@@ -1,6 +1,6 @@
from __future__ import annotations
-from collections.abc import Iterator
+from collections.abc import Generator
from contextlib import contextmanager
from contextvars import ContextVar
from dataclasses import dataclass
@@ -32,7 +32,7 @@ def get_current_file_access_scope() -> FileAccessScope | None:
@contextmanager
-def bind_file_access_scope(scope: FileAccessScope) -> Iterator[None]:
+def bind_file_access_scope(scope: FileAccessScope) -> Generator[None, None, None]:
token = _current_file_access_scope.set(scope)
try:
yield
diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py
index dfe6133cb6..e2e07ebaff 100644
--- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py
+++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py
@@ -59,7 +59,7 @@ from graphon.model_runtime.entities.message_entities import (
AssistantPromptMessage,
TextPromptMessageContent,
)
-from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
from libs.datetime_utils import naive_utc_now
from models.model import AppMode, Conversation, Message, MessageAgentThought, MessageFile, UploadFile
diff --git a/api/core/app/workflow/file_runtime.py b/api/core/app/workflow/file_runtime.py
index 68e5e5f0c8..3a6f9d575a 100644
--- a/api/core/app/workflow/file_runtime.py
+++ b/api/core/app/workflow/file_runtime.py
@@ -12,13 +12,14 @@ from typing import TYPE_CHECKING, Literal
from configs import dify_config
from core.app.file_access import DatabaseFileAccessController, FileAccessControllerProtocol
from core.db.session_factory import session_factory
-from core.helper.ssrf_proxy import ssrf_proxy
+from core.helper.ssrf_proxy import graphon_ssrf_proxy
from core.tools.signature import sign_tool_file
from core.workflow.file_reference import parse_file_reference
from extensions.ext_storage import storage
from graphon.file import FileTransferMethod
-from graphon.file.protocols import HttpResponseProtocol, WorkflowFileRuntimeProtocol
+from graphon.file.protocols import WorkflowFileRuntimeProtocol
from graphon.file.runtime import set_workflow_file_runtime
+from graphon.http.protocols import HttpResponseProtocol
if TYPE_CHECKING:
from graphon.file import File
@@ -43,7 +44,7 @@ class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol):
return dify_config.MULTIMODAL_SEND_FORMAT
def http_get(self, url: str, *, follow_redirects: bool = True) -> HttpResponseProtocol:
- return ssrf_proxy.get(url, follow_redirects=follow_redirects)
+ return graphon_ssrf_proxy.get(url, follow_redirects=follow_redirects)
def storage_load(self, path: str, *, stream: bool = False) -> bytes | Generator:
return storage.load(path, stream=stream)
diff --git a/api/core/app/workflow/layers/persistence.py b/api/core/app/workflow/layers/persistence.py
index 87f005a250..d521304615 100644
--- a/api/core/app/workflow/layers/persistence.py
+++ b/api/core/app/workflow/layers/persistence.py
@@ -349,7 +349,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
execution.total_tokens = runtime_state.total_tokens
execution.total_steps = runtime_state.node_run_steps
execution.outputs = execution.outputs or runtime_state.outputs
- execution.exceptions_count = runtime_state.exceptions_count
+ execution.exceptions_count = max(execution.exceptions_count, runtime_state.exceptions_count)
def _update_node_execution(
self,
diff --git a/api/core/datasource/datasource_manager.py b/api/core/datasource/datasource_manager.py
index dc831e5cac..f0dcb13b62 100644
--- a/api/core/datasource/datasource_manager.py
+++ b/api/core/datasource/datasource_manager.py
@@ -352,11 +352,11 @@ class DatasourceManager:
raise ValueError(f"UploadFile not found for file_id={file_id}, tenant_id={tenant_id}")
file_info = File(
- id=upload_file.id,
+ file_id=upload_file.id,
filename=upload_file.name,
extension="." + upload_file.extension,
mime_type=upload_file.mime_type,
- type=FileType.CUSTOM,
+ file_type=FileType.CUSTOM,
transfer_method=FileTransferMethod.LOCAL_FILE,
remote_url=upload_file.source_url,
reference=build_file_reference(record_id=str(upload_file.id)),
diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py
index 1ab66cceee..38b87e2cd1 100644
--- a/api/core/entities/provider_configuration.py
+++ b/api/core/entities/provider_configuration.py
@@ -31,7 +31,7 @@ from graphon.model_runtime.entities.provider_entities import (
FormType,
ProviderEntity,
)
-from graphon.model_runtime.model_providers.__base.ai_model import AIModel
+from graphon.model_runtime.model_providers.base.ai_model import AIModel
from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from graphon.model_runtime.runtime import ModelRuntime
from libs.datetime_utils import naive_utc_now
@@ -318,34 +318,28 @@ class ProviderConfiguration(BaseModel):
else [],
)
- def validate_provider_credentials(
- self, credentials: dict[str, Any], credential_id: str = "", session: Session | None = None
- ):
+ def validate_provider_credentials(self, credentials: dict[str, Any], credential_id: str = ""):
"""
Validate custom credentials.
:param credentials: provider credentials
:param credential_id: (Optional)If provided, can use existing credential's hidden api key to validate
- :param session: optional database session
:return:
"""
+ provider_credential_secret_variables = self.extract_secret_variables(
+ self.provider.provider_credential_schema.credential_form_schemas
+ if self.provider.provider_credential_schema
+ else []
+ )
- def _validate(s: Session):
- # Get provider credential secret variables
- provider_credential_secret_variables = self.extract_secret_variables(
- self.provider.provider_credential_schema.credential_form_schemas
- if self.provider.provider_credential_schema
- else []
- )
-
- if credential_id:
+ if credential_id:
+ with Session(db.engine) as session:
try:
stmt = select(ProviderCredential).where(
ProviderCredential.tenant_id == self.tenant_id,
ProviderCredential.provider_name.in_(self._get_provider_names()),
ProviderCredential.id == credential_id,
)
- credential_record = s.execute(stmt).scalar_one_or_none()
- # fix origin data
+ credential_record = session.execute(stmt).scalar_one_or_none()
if credential_record and credential_record.encrypted_config:
if not credential_record.encrypted_config.startswith("{"):
original_credentials = {"openai_api_key": credential_record.encrypted_config}
@@ -356,31 +350,23 @@ class ProviderConfiguration(BaseModel):
except JSONDecodeError:
original_credentials = {}
- # encrypt credentials
- for key, value in credentials.items():
- if key in provider_credential_secret_variables:
- # if send [__HIDDEN__] in secret input, it will be same as original value
- if value == HIDDEN_VALUE and key in original_credentials:
- credentials[key] = encrypter.decrypt_token(
- tenant_id=self.tenant_id, token=original_credentials[key]
- )
-
- model_provider_factory = self.get_model_provider_factory()
- validated_credentials = model_provider_factory.provider_credentials_validate(
- provider=self.provider.provider, credentials=credentials
- )
-
- for key, value in validated_credentials.items():
+ for key, value in credentials.items():
if key in provider_credential_secret_variables:
- validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
+ if value == HIDDEN_VALUE and key in original_credentials:
+ credentials[key] = encrypter.decrypt_token(
+ tenant_id=self.tenant_id, token=original_credentials[key]
+ )
- return validated_credentials
+ model_provider_factory = self.get_model_provider_factory()
+ validated_credentials = model_provider_factory.provider_credentials_validate(
+ provider=self.provider.provider, credentials=credentials
+ )
- if session:
- return _validate(session)
- else:
- with Session(db.engine) as new_session:
- return _validate(new_session)
+ for key, value in validated_credentials.items():
+ if key in provider_credential_secret_variables and isinstance(value, str):
+ validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
+
+ return validated_credentials
def _generate_provider_credential_name(self, session) -> str:
"""
@@ -457,14 +443,16 @@ class ProviderConfiguration(BaseModel):
:param credential_name: credential name
:return:
"""
- with Session(db.engine) as session:
+ with Session(db.engine) as pre_session:
if credential_name:
- if self._check_provider_credential_name_exists(credential_name=credential_name, session=session):
+ if self._check_provider_credential_name_exists(credential_name=credential_name, session=pre_session):
raise ValueError(f"Credential with name '{credential_name}' already exists.")
else:
- credential_name = self._generate_provider_credential_name(session)
+ credential_name = self._generate_provider_credential_name(pre_session)
- credentials = self.validate_provider_credentials(credentials=credentials, session=session)
+ credentials = self.validate_provider_credentials(credentials=credentials)
+
+ with Session(db.engine) as session:
provider_record = self._get_provider_record(session)
try:
new_record = ProviderCredential(
@@ -477,7 +465,6 @@ class ProviderConfiguration(BaseModel):
session.flush()
if not provider_record:
- # If provider record does not exist, create it
provider_record = Provider(
tenant_id=self.tenant_id,
provider_name=self.provider.provider,
@@ -530,15 +517,15 @@ class ProviderConfiguration(BaseModel):
:param credential_name: credential name
:return:
"""
- with Session(db.engine) as session:
+ with Session(db.engine) as pre_session:
if credential_name and self._check_provider_credential_name_exists(
- credential_name=credential_name, session=session, exclude_id=credential_id
+ credential_name=credential_name, session=pre_session, exclude_id=credential_id
):
raise ValueError(f"Credential with name '{credential_name}' already exists.")
- credentials = self.validate_provider_credentials(
- credentials=credentials, credential_id=credential_id, session=session
- )
+ credentials = self.validate_provider_credentials(credentials=credentials, credential_id=credential_id)
+
+ with Session(db.engine) as session:
provider_record = self._get_provider_record(session)
stmt = select(ProviderCredential).where(
ProviderCredential.id == credential_id,
@@ -546,12 +533,10 @@ class ProviderConfiguration(BaseModel):
ProviderCredential.provider_name.in_(self._get_provider_names()),
)
- # Get the credential record to update
credential_record = session.execute(stmt).scalar_one_or_none()
if not credential_record:
raise ValueError("Credential record not found.")
try:
- # Update credential
credential_record.encrypted_config = json.dumps(credentials)
credential_record.updated_at = naive_utc_now()
if credential_name:
@@ -879,7 +864,6 @@ class ProviderConfiguration(BaseModel):
model: str,
credentials: dict[str, Any],
credential_id: str = "",
- session: Session | None = None,
):
"""
Validate custom model credentials.
@@ -890,16 +874,14 @@ class ProviderConfiguration(BaseModel):
:param credential_id: (Optional)If provided, can use existing credential's hidden api key to validate
:return:
"""
+ provider_credential_secret_variables = self.extract_secret_variables(
+ self.provider.model_credential_schema.credential_form_schemas
+ if self.provider.model_credential_schema
+ else []
+ )
- def _validate(s: Session):
- # Get provider credential secret variables
- provider_credential_secret_variables = self.extract_secret_variables(
- self.provider.model_credential_schema.credential_form_schemas
- if self.provider.model_credential_schema
- else []
- )
-
- if credential_id:
+ if credential_id:
+ with Session(db.engine) as session:
try:
stmt = select(ProviderModelCredential).where(
ProviderModelCredential.id == credential_id,
@@ -908,7 +890,7 @@ class ProviderConfiguration(BaseModel):
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type,
)
- credential_record = s.execute(stmt).scalar_one_or_none()
+ credential_record = session.execute(stmt).scalar_one_or_none()
original_credentials = (
json.loads(credential_record.encrypted_config)
if credential_record and credential_record.encrypted_config
@@ -917,31 +899,23 @@ class ProviderConfiguration(BaseModel):
except JSONDecodeError:
original_credentials = {}
- # decrypt credentials
- for key, value in credentials.items():
- if key in provider_credential_secret_variables:
- # if send [__HIDDEN__] in secret input, it will be same as original value
- if value == HIDDEN_VALUE and key in original_credentials:
- credentials[key] = encrypter.decrypt_token(
- tenant_id=self.tenant_id, token=original_credentials[key]
- )
-
- model_provider_factory = self.get_model_provider_factory()
- validated_credentials = model_provider_factory.model_credentials_validate(
- provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
- )
-
- for key, value in validated_credentials.items():
+ for key, value in credentials.items():
if key in provider_credential_secret_variables:
- validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
+ if value == HIDDEN_VALUE and key in original_credentials:
+ credentials[key] = encrypter.decrypt_token(
+ tenant_id=self.tenant_id, token=original_credentials[key]
+ )
- return validated_credentials
+ model_provider_factory = self.get_model_provider_factory()
+ validated_credentials = model_provider_factory.model_credentials_validate(
+ provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
+ )
- if session:
- return _validate(session)
- else:
- with Session(db.engine) as new_session:
- return _validate(new_session)
+ for key, value in validated_credentials.items():
+ if key in provider_credential_secret_variables and isinstance(value, str):
+ validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
+
+ return validated_credentials
def create_custom_model_credential(
self, model_type: ModelType, model: str, credentials: dict[str, Any], credential_name: str | None
@@ -954,20 +928,22 @@ class ProviderConfiguration(BaseModel):
:param credentials: model credentials dict
:return:
"""
- with Session(db.engine) as session:
+ with Session(db.engine) as pre_session:
if credential_name:
if self._check_custom_model_credential_name_exists(
- model=model, model_type=model_type, credential_name=credential_name, session=session
+ model=model, model_type=model_type, credential_name=credential_name, session=pre_session
):
raise ValueError(f"Model credential with name '{credential_name}' already exists for {model}.")
else:
credential_name = self._generate_custom_model_credential_name(
- model=model, model_type=model_type, session=session
+ model=model, model_type=model_type, session=pre_session
)
- # validate custom model config
- credentials = self.validate_custom_model_credentials(
- model_type=model_type, model=model, credentials=credentials, session=session
- )
+
+ credentials = self.validate_custom_model_credentials(
+ model_type=model_type, model=model, credentials=credentials
+ )
+
+ with Session(db.engine) as session:
provider_model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session)
try:
@@ -982,7 +958,6 @@ class ProviderConfiguration(BaseModel):
session.add(credential)
session.flush()
- # save provider model
if not provider_model_record:
provider_model_record = ProviderModel(
tenant_id=self.tenant_id,
@@ -1024,23 +999,24 @@ class ProviderConfiguration(BaseModel):
:param credential_id: credential id
:return:
"""
- with Session(db.engine) as session:
+ with Session(db.engine) as pre_session:
if credential_name and self._check_custom_model_credential_name_exists(
model=model,
model_type=model_type,
credential_name=credential_name,
- session=session,
+ session=pre_session,
exclude_id=credential_id,
):
raise ValueError(f"Model credential with name '{credential_name}' already exists for {model}.")
- # validate custom model config
- credentials = self.validate_custom_model_credentials(
- model_type=model_type,
- model=model,
- credentials=credentials,
- credential_id=credential_id,
- session=session,
- )
+
+ credentials = self.validate_custom_model_credentials(
+ model_type=model_type,
+ model=model,
+ credentials=credentials,
+ credential_id=credential_id,
+ )
+
+ with Session(db.engine) as session:
provider_model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session)
stmt = select(ProviderModelCredential).where(
@@ -1055,7 +1031,6 @@ class ProviderConfiguration(BaseModel):
raise ValueError("Credential record not found.")
try:
- # Update credential
credential_record.encrypted_config = json.dumps(credentials)
credential_record.updated_at = naive_utc_now()
if credential_name:
diff --git a/api/core/helper/code_executor/template_transformer.py b/api/core/helper/code_executor/template_transformer.py
index b96a9ce380..38864a1830 100644
--- a/api/core/helper/code_executor/template_transformer.py
+++ b/api/core/helper/code_executor/template_transformer.py
@@ -102,7 +102,7 @@ class TemplateTransformer(ABC):
@classmethod
def serialize_inputs(cls, inputs: Mapping[str, Any]) -> str:
- inputs_json_str = dumps_with_segments(inputs, ensure_ascii=False).encode()
+ inputs_json_str = dumps_with_segments(inputs).encode()
input_base64_encoded = b64encode(inputs_json_str).decode("utf-8")
return input_base64_encoded
diff --git a/api/core/helper/moderation.py b/api/core/helper/moderation.py
index dc37a36943..f169f247cf 100644
--- a/api/core/helper/moderation.py
+++ b/api/core/helper/moderation.py
@@ -8,7 +8,7 @@ from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_
from extensions.ext_hosting_provider import hosting_configuration
from graphon.model_runtime.entities.model_entities import ModelType
from graphon.model_runtime.errors.invoke import InvokeBadRequestError
-from graphon.model_runtime.model_providers.__base.moderation_model import ModerationModel
+from graphon.model_runtime.model_providers.base.moderation_model import ModerationModel
from models.provider import ProviderType
logger = logging.getLogger(__name__)
diff --git a/api/core/helper/ssrf_proxy.py b/api/core/helper/ssrf_proxy.py
index e38592bb7b..91e92712b7 100644
--- a/api/core/helper/ssrf_proxy.py
+++ b/api/core/helper/ssrf_proxy.py
@@ -12,6 +12,7 @@ from pydantic import TypeAdapter, ValidationError
from configs import dify_config
from core.helper.http_client_pooling import get_pooled_http_client
from core.tools.errors import ToolSSRFError
+from graphon.http.response import HttpResponse
logger = logging.getLogger(__name__)
@@ -267,4 +268,47 @@ class SSRFProxy:
return patch(url=url, max_retries=max_retries, **kwargs)
+def _to_graphon_http_response(response: httpx.Response) -> HttpResponse:
+ """Convert an ``httpx`` response into Graphon's transport-agnostic wrapper."""
+ return HttpResponse(
+ status_code=response.status_code,
+ headers=dict(response.headers),
+ content=response.content,
+ url=str(response.url) if response.url else None,
+ reason_phrase=response.reason_phrase,
+ fallback_text=response.text,
+ )
+
+
+class GraphonSSRFProxy:
+ """Adapter exposing SSRF helpers behind Graphon's ``HttpClientProtocol``."""
+
+ @property
+ def max_retries_exceeded_error(self) -> type[Exception]:
+ return max_retries_exceeded_error
+
+ @property
+ def request_error(self) -> type[Exception]:
+ return request_error
+
+ def get(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse:
+ return _to_graphon_http_response(get(url=url, max_retries=max_retries, **kwargs))
+
+ def head(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse:
+ return _to_graphon_http_response(head(url=url, max_retries=max_retries, **kwargs))
+
+ def post(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse:
+ return _to_graphon_http_response(post(url=url, max_retries=max_retries, **kwargs))
+
+ def put(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse:
+ return _to_graphon_http_response(put(url=url, max_retries=max_retries, **kwargs))
+
+ def delete(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse:
+ return _to_graphon_http_response(delete(url=url, max_retries=max_retries, **kwargs))
+
+ def patch(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse:
+ return _to_graphon_http_response(patch(url=url, max_retries=max_retries, **kwargs))
+
+
ssrf_proxy = SSRFProxy()
+graphon_ssrf_proxy = GraphonSSRFProxy()
diff --git a/api/core/mcp/client/streamable_client.py b/api/core/mcp/client/streamable_client.py
index 5c3cd0d8f8..acba3e666b 100644
--- a/api/core/mcp/client/streamable_client.py
+++ b/api/core/mcp/client/streamable_client.py
@@ -303,9 +303,16 @@ class StreamableHTTPTransport:
if response.status_code == 404:
if isinstance(message.root, JSONRPCRequest):
+ error_msg = (
+ f"MCP server URL returned 404 Not Found: {self.url} "
+ "— verify the server URL is correct and the server is running"
+ if is_initialization
+ else "Session terminated by server"
+ )
self._send_session_terminated_error(
ctx.server_to_client_queue,
message.root.id,
+ message=error_msg,
)
return
@@ -381,12 +388,13 @@ class StreamableHTTPTransport:
self,
server_to_client_queue: ServerToClientQueue,
request_id: RequestId,
+ message: str = "Session terminated by server",
):
"""Send a session terminated error response."""
jsonrpc_error = JSONRPCError(
jsonrpc="2.0",
id=request_id,
- error=ErrorData(code=32600, message="Session terminated by server"),
+ error=ErrorData(code=32600, message=message),
)
session_message = SessionMessage(JSONRPCMessage(jsonrpc_error))
server_to_client_queue.put(session_message)
diff --git a/api/core/model_manager.py b/api/core/model_manager.py
index d8d8dfedd8..86d0e3baaa 100644
--- a/api/core/model_manager.py
+++ b/api/core/model_manager.py
@@ -1,6 +1,6 @@
import logging
from collections.abc import Callable, Generator, Iterable, Mapping, Sequence
-from typing import IO, Any, Literal, Optional, Union, cast, overload
+from typing import IO, Any, Literal, Optional, ParamSpec, TypeVar, Union, cast, overload
from configs import dify_config
from core.entities import PluginCredentialType
@@ -18,15 +18,17 @@ from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelFe
from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult
from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult
from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeConnectionError, InvokeRateLimitError
-from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
-from graphon.model_runtime.model_providers.__base.moderation_model import ModerationModel
-from graphon.model_runtime.model_providers.__base.rerank_model import RerankModel
-from graphon.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
-from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
-from graphon.model_runtime.model_providers.__base.tts_model import TTSModel
+from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
+from graphon.model_runtime.model_providers.base.moderation_model import ModerationModel
+from graphon.model_runtime.model_providers.base.rerank_model import RerankModel
+from graphon.model_runtime.model_providers.base.speech2text_model import Speech2TextModel
+from graphon.model_runtime.model_providers.base.text_embedding_model import TextEmbeddingModel
+from graphon.model_runtime.model_providers.base.tts_model import TTSModel
from models.provider import ProviderType
logger = logging.getLogger(__name__)
+P = ParamSpec("P")
+R = TypeVar("R")
class ModelInstance:
@@ -168,7 +170,7 @@ class ModelInstance:
return cast(
Union[LLMResult, Generator],
self._round_robin_invoke(
- function=self.model_type_instance.invoke,
+ self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
prompt_messages=list(prompt_messages),
@@ -193,7 +195,7 @@ class ModelInstance:
if not isinstance(self.model_type_instance, LargeLanguageModel):
raise Exception("Model type instance is not LargeLanguageModel")
return self._round_robin_invoke(
- function=self.model_type_instance.get_num_tokens,
+ self.model_type_instance.get_num_tokens,
model=self.model_name,
credentials=self.credentials,
prompt_messages=list(prompt_messages),
@@ -213,7 +215,7 @@ class ModelInstance:
if not isinstance(self.model_type_instance, TextEmbeddingModel):
raise Exception("Model type instance is not TextEmbeddingModel")
return self._round_robin_invoke(
- function=self.model_type_instance.invoke,
+ self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
texts=texts,
@@ -235,7 +237,7 @@ class ModelInstance:
if not isinstance(self.model_type_instance, TextEmbeddingModel):
raise Exception("Model type instance is not TextEmbeddingModel")
return self._round_robin_invoke(
- function=self.model_type_instance.invoke,
+ self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
multimodel_documents=multimodel_documents,
@@ -252,7 +254,7 @@ class ModelInstance:
if not isinstance(self.model_type_instance, TextEmbeddingModel):
raise Exception("Model type instance is not TextEmbeddingModel")
return self._round_robin_invoke(
- function=self.model_type_instance.get_num_tokens,
+ self.model_type_instance.get_num_tokens,
model=self.model_name,
credentials=self.credentials,
texts=texts,
@@ -277,7 +279,7 @@ class ModelInstance:
if not isinstance(self.model_type_instance, RerankModel):
raise Exception("Model type instance is not RerankModel")
return self._round_robin_invoke(
- function=self.model_type_instance.invoke,
+ self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
query=query,
@@ -305,7 +307,7 @@ class ModelInstance:
if not isinstance(self.model_type_instance, RerankModel):
raise Exception("Model type instance is not RerankModel")
return self._round_robin_invoke(
- function=self.model_type_instance.invoke_multimodal_rerank,
+ self.model_type_instance.invoke_multimodal_rerank,
model=self.model_name,
credentials=self.credentials,
query=query,
@@ -324,7 +326,7 @@ class ModelInstance:
if not isinstance(self.model_type_instance, ModerationModel):
raise Exception("Model type instance is not ModerationModel")
return self._round_robin_invoke(
- function=self.model_type_instance.invoke,
+ self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
text=text,
@@ -340,7 +342,7 @@ class ModelInstance:
if not isinstance(self.model_type_instance, Speech2TextModel):
raise Exception("Model type instance is not Speech2TextModel")
return self._round_robin_invoke(
- function=self.model_type_instance.invoke,
+ self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
file=file,
@@ -357,14 +359,14 @@ class ModelInstance:
if not isinstance(self.model_type_instance, TTSModel):
raise Exception("Model type instance is not TTSModel")
return self._round_robin_invoke(
- function=self.model_type_instance.invoke,
+ self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
content_text=content_text,
voice=voice,
)
- def _round_robin_invoke[**P, R](self, function: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R:
+ def _round_robin_invoke(self, function: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R:
"""
Round-robin invoke
:param function: function to invoke
diff --git a/api/core/ops/entities/config_entity.py b/api/core/ops/entities/config_entity.py
index fda00ac3b9..d78ce90aa1 100644
--- a/api/core/ops/entities/config_entity.py
+++ b/api/core/ops/entities/config_entity.py
@@ -1,8 +1,8 @@
from enum import StrEnum
-from pydantic import BaseModel, ValidationInfo, field_validator
+from pydantic import BaseModel
-from core.ops.utils import validate_integer_id, validate_project_name, validate_url, validate_url_with_path
+from core.ops.utils import validate_project_name, validate_url
class TracingProviderEnum(StrEnum):
@@ -52,220 +52,5 @@ class BaseTracingConfig(BaseModel):
return validate_project_name(v, default_name)
-class ArizeConfig(BaseTracingConfig):
- """
- Model class for Arize tracing config.
- """
-
- api_key: str | None = None
- space_id: str | None = None
- project: str | None = None
- endpoint: str = "https://otlp.arize.com"
-
- @field_validator("project")
- @classmethod
- def project_validator(cls, v, info: ValidationInfo):
- return cls.validate_project_field(v, "default")
-
- @field_validator("endpoint")
- @classmethod
- def endpoint_validator(cls, v, info: ValidationInfo):
- return cls.validate_endpoint_url(v, "https://otlp.arize.com")
-
-
-class PhoenixConfig(BaseTracingConfig):
- """
- Model class for Phoenix tracing config.
- """
-
- api_key: str | None = None
- project: str | None = None
- endpoint: str = "https://app.phoenix.arize.com"
-
- @field_validator("project")
- @classmethod
- def project_validator(cls, v, info: ValidationInfo):
- return cls.validate_project_field(v, "default")
-
- @field_validator("endpoint")
- @classmethod
- def endpoint_validator(cls, v, info: ValidationInfo):
- return validate_url_with_path(v, "https://app.phoenix.arize.com")
-
-
-class LangfuseConfig(BaseTracingConfig):
- """
- Model class for Langfuse tracing config.
- """
-
- public_key: str
- secret_key: str
- host: str = "https://api.langfuse.com"
-
- @field_validator("host")
- @classmethod
- def host_validator(cls, v, info: ValidationInfo):
- return validate_url_with_path(v, "https://api.langfuse.com")
-
-
-class LangSmithConfig(BaseTracingConfig):
- """
- Model class for Langsmith tracing config.
- """
-
- api_key: str
- project: str
- endpoint: str = "https://api.smith.langchain.com"
-
- @field_validator("endpoint")
- @classmethod
- def endpoint_validator(cls, v, info: ValidationInfo):
- # LangSmith only allows HTTPS
- return validate_url(v, "https://api.smith.langchain.com", allowed_schemes=("https",))
-
-
-class OpikConfig(BaseTracingConfig):
- """
- Model class for Opik tracing config.
- """
-
- api_key: str | None = None
- project: str | None = None
- workspace: str | None = None
- url: str = "https://www.comet.com/opik/api/"
-
- @field_validator("project")
- @classmethod
- def project_validator(cls, v, info: ValidationInfo):
- return cls.validate_project_field(v, "Default Project")
-
- @field_validator("url")
- @classmethod
- def url_validator(cls, v, info: ValidationInfo):
- return validate_url_with_path(v, "https://www.comet.com/opik/api/", required_suffix="/api/")
-
-
-class WeaveConfig(BaseTracingConfig):
- """
- Model class for Weave tracing config.
- """
-
- api_key: str
- entity: str | None = None
- project: str
- endpoint: str = "https://trace.wandb.ai"
- host: str | None = None
-
- @field_validator("endpoint")
- @classmethod
- def endpoint_validator(cls, v, info: ValidationInfo):
- # Weave only allows HTTPS for endpoint
- return validate_url(v, "https://trace.wandb.ai", allowed_schemes=("https",))
-
- @field_validator("host")
- @classmethod
- def host_validator(cls, v, info: ValidationInfo):
- if v is not None and v.strip() != "":
- return validate_url(v, v, allowed_schemes=("https", "http"))
- return v
-
-
-class AliyunConfig(BaseTracingConfig):
- """
- Model class for Aliyun tracing config.
- """
-
- app_name: str = "dify_app"
- license_key: str
- endpoint: str
-
- @field_validator("app_name")
- @classmethod
- def app_name_validator(cls, v, info: ValidationInfo):
- return cls.validate_project_field(v, "dify_app")
-
- @field_validator("license_key")
- @classmethod
- def license_key_validator(cls, v, info: ValidationInfo):
- if not v or v.strip() == "":
- raise ValueError("License key cannot be empty")
- return v
-
- @field_validator("endpoint")
- @classmethod
- def endpoint_validator(cls, v, info: ValidationInfo):
- # aliyun uses two URL formats, which may include a URL path
- return validate_url_with_path(v, "https://tracing-analysis-dc-hz.aliyuncs.com")
-
-
-class TencentConfig(BaseTracingConfig):
- """
- Tencent APM tracing config
- """
-
- token: str
- endpoint: str
- service_name: str
-
- @field_validator("token")
- @classmethod
- def token_validator(cls, v, info: ValidationInfo):
- if not v or v.strip() == "":
- raise ValueError("Token cannot be empty")
- return v
-
- @field_validator("endpoint")
- @classmethod
- def endpoint_validator(cls, v, info: ValidationInfo):
- return cls.validate_endpoint_url(v, "https://apm.tencentcloudapi.com")
-
- @field_validator("service_name")
- @classmethod
- def service_name_validator(cls, v, info: ValidationInfo):
- return cls.validate_project_field(v, "dify_app")
-
-
-class MLflowConfig(BaseTracingConfig):
- """
- Model class for MLflow tracing config.
- """
-
- tracking_uri: str = "http://localhost:5000"
- experiment_id: str = "0" # Default experiment id in MLflow is 0
- username: str | None = None
- password: str | None = None
-
- @field_validator("tracking_uri")
- @classmethod
- def tracking_uri_validator(cls, v, info: ValidationInfo):
- if isinstance(v, str) and v.startswith("databricks"):
- raise ValueError(
- "Please use Databricks tracing config below to record traces to Databricks-managed MLflow instances."
- )
- return validate_url_with_path(v, "http://localhost:5000")
-
- @field_validator("experiment_id")
- @classmethod
- def experiment_id_validator(cls, v, info: ValidationInfo):
- return validate_integer_id(v)
-
-
-class DatabricksConfig(BaseTracingConfig):
- """
- Model class for Databricks (Databricks-managed MLflow) tracing config.
- """
-
- experiment_id: str
- host: str
- client_id: str | None = None
- client_secret: str | None = None
- personal_access_token: str | None = None
-
- @field_validator("experiment_id")
- @classmethod
- def experiment_id_validator(cls, v, info: ValidationInfo):
- return validate_integer_id(v)
-
-
OPS_FILE_PATH = "ops_trace/"
OPS_TRACE_FAILED_KEY = "FAILED_OPS_TRACE"
diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py
index cd63951537..e7ba6e502b 100644
--- a/api/core/ops/ops_trace_manager.py
+++ b/api/core/ops/ops_trace_manager.py
@@ -204,114 +204,117 @@ class TracingProviderConfigEntry(TypedDict):
class OpsTraceProviderConfigMap(collections.UserDict[str, TracingProviderConfigEntry]):
def __getitem__(self, provider: str) -> TracingProviderConfigEntry:
- match provider:
- case TracingProviderEnum.LANGFUSE:
- from core.ops.entities.config_entity import LangfuseConfig
- from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace
+ try:
+ match provider:
+ case TracingProviderEnum.LANGFUSE:
+ from dify_trace_langfuse.config import LangfuseConfig
+ from dify_trace_langfuse.langfuse_trace import LangFuseDataTrace
- return {
- "config_class": LangfuseConfig,
- "secret_keys": ["public_key", "secret_key"],
- "other_keys": ["host", "project_key"],
- "trace_instance": LangFuseDataTrace,
- }
+ return {
+ "config_class": LangfuseConfig,
+ "secret_keys": ["public_key", "secret_key"],
+ "other_keys": ["host", "project_key"],
+ "trace_instance": LangFuseDataTrace,
+ }
- case TracingProviderEnum.LANGSMITH:
- from core.ops.entities.config_entity import LangSmithConfig
- from core.ops.langsmith_trace.langsmith_trace import LangSmithDataTrace
+ case TracingProviderEnum.LANGSMITH:
+ from dify_trace_langsmith.config import LangSmithConfig
+ from dify_trace_langsmith.langsmith_trace import LangSmithDataTrace
- return {
- "config_class": LangSmithConfig,
- "secret_keys": ["api_key"],
- "other_keys": ["project", "endpoint"],
- "trace_instance": LangSmithDataTrace,
- }
+ return {
+ "config_class": LangSmithConfig,
+ "secret_keys": ["api_key"],
+ "other_keys": ["project", "endpoint"],
+ "trace_instance": LangSmithDataTrace,
+ }
- case TracingProviderEnum.OPIK:
- from core.ops.entities.config_entity import OpikConfig
- from core.ops.opik_trace.opik_trace import OpikDataTrace
+ case TracingProviderEnum.OPIK:
+ from dify_trace_opik.config import OpikConfig
+ from dify_trace_opik.opik_trace import OpikDataTrace
- return {
- "config_class": OpikConfig,
- "secret_keys": ["api_key"],
- "other_keys": ["project", "url", "workspace"],
- "trace_instance": OpikDataTrace,
- }
+ return {
+ "config_class": OpikConfig,
+ "secret_keys": ["api_key"],
+ "other_keys": ["project", "url", "workspace"],
+ "trace_instance": OpikDataTrace,
+ }
- case TracingProviderEnum.WEAVE:
- from core.ops.entities.config_entity import WeaveConfig
- from core.ops.weave_trace.weave_trace import WeaveDataTrace
+ case TracingProviderEnum.WEAVE:
+ from dify_trace_weave.config import WeaveConfig
+ from dify_trace_weave.weave_trace import WeaveDataTrace
- return {
- "config_class": WeaveConfig,
- "secret_keys": ["api_key"],
- "other_keys": ["project", "entity", "endpoint", "host"],
- "trace_instance": WeaveDataTrace,
- }
- case TracingProviderEnum.ARIZE:
- from core.ops.arize_phoenix_trace.arize_phoenix_trace import ArizePhoenixDataTrace
- from core.ops.entities.config_entity import ArizeConfig
+ return {
+ "config_class": WeaveConfig,
+ "secret_keys": ["api_key"],
+ "other_keys": ["project", "entity", "endpoint", "host"],
+ "trace_instance": WeaveDataTrace,
+ }
+ case TracingProviderEnum.ARIZE:
+ from dify_trace_arize_phoenix.arize_phoenix_trace import ArizePhoenixDataTrace
+ from dify_trace_arize_phoenix.config import ArizeConfig
- return {
- "config_class": ArizeConfig,
- "secret_keys": ["api_key", "space_id"],
- "other_keys": ["project", "endpoint"],
- "trace_instance": ArizePhoenixDataTrace,
- }
- case TracingProviderEnum.PHOENIX:
- from core.ops.arize_phoenix_trace.arize_phoenix_trace import ArizePhoenixDataTrace
- from core.ops.entities.config_entity import PhoenixConfig
+ return {
+ "config_class": ArizeConfig,
+ "secret_keys": ["api_key", "space_id"],
+ "other_keys": ["project", "endpoint"],
+ "trace_instance": ArizePhoenixDataTrace,
+ }
+ case TracingProviderEnum.PHOENIX:
+ from dify_trace_arize_phoenix.arize_phoenix_trace import ArizePhoenixDataTrace
+ from dify_trace_arize_phoenix.config import PhoenixConfig
- return {
- "config_class": PhoenixConfig,
- "secret_keys": ["api_key"],
- "other_keys": ["project", "endpoint"],
- "trace_instance": ArizePhoenixDataTrace,
- }
- case TracingProviderEnum.ALIYUN:
- from core.ops.aliyun_trace.aliyun_trace import AliyunDataTrace
- from core.ops.entities.config_entity import AliyunConfig
+ return {
+ "config_class": PhoenixConfig,
+ "secret_keys": ["api_key"],
+ "other_keys": ["project", "endpoint"],
+ "trace_instance": ArizePhoenixDataTrace,
+ }
+ case TracingProviderEnum.ALIYUN:
+ from dify_trace_aliyun.aliyun_trace import AliyunDataTrace
+ from dify_trace_aliyun.config import AliyunConfig
- return {
- "config_class": AliyunConfig,
- "secret_keys": ["license_key"],
- "other_keys": ["endpoint", "app_name"],
- "trace_instance": AliyunDataTrace,
- }
- case TracingProviderEnum.MLFLOW:
- from core.ops.entities.config_entity import MLflowConfig
- from core.ops.mlflow_trace.mlflow_trace import MLflowDataTrace
+ return {
+ "config_class": AliyunConfig,
+ "secret_keys": ["license_key"],
+ "other_keys": ["endpoint", "app_name"],
+ "trace_instance": AliyunDataTrace,
+ }
+ case TracingProviderEnum.MLFLOW:
+ from dify_trace_mlflow.config import MLflowConfig
+ from dify_trace_mlflow.mlflow_trace import MLflowDataTrace
- return {
- "config_class": MLflowConfig,
- "secret_keys": ["password"],
- "other_keys": ["tracking_uri", "experiment_id", "username"],
- "trace_instance": MLflowDataTrace,
- }
- case TracingProviderEnum.DATABRICKS:
- from core.ops.entities.config_entity import DatabricksConfig
- from core.ops.mlflow_trace.mlflow_trace import MLflowDataTrace
+ return {
+ "config_class": MLflowConfig,
+ "secret_keys": ["password"],
+ "other_keys": ["tracking_uri", "experiment_id", "username"],
+ "trace_instance": MLflowDataTrace,
+ }
+ case TracingProviderEnum.DATABRICKS:
+ from dify_trace_mlflow.config import DatabricksConfig
+ from dify_trace_mlflow.mlflow_trace import MLflowDataTrace
- return {
- "config_class": DatabricksConfig,
- "secret_keys": ["personal_access_token", "client_secret"],
- "other_keys": ["host", "client_id", "experiment_id"],
- "trace_instance": MLflowDataTrace,
- }
+ return {
+ "config_class": DatabricksConfig,
+ "secret_keys": ["personal_access_token", "client_secret"],
+ "other_keys": ["host", "client_id", "experiment_id"],
+ "trace_instance": MLflowDataTrace,
+ }
- case TracingProviderEnum.TENCENT:
- from core.ops.entities.config_entity import TencentConfig
- from core.ops.tencent_trace.tencent_trace import TencentDataTrace
+ case TracingProviderEnum.TENCENT:
+ from dify_trace_tencent.config import TencentConfig
+ from dify_trace_tencent.tencent_trace import TencentDataTrace
- return {
- "config_class": TencentConfig,
- "secret_keys": ["token"],
- "other_keys": ["endpoint", "service_name"],
- "trace_instance": TencentDataTrace,
- }
+ return {
+ "config_class": TencentConfig,
+ "secret_keys": ["token"],
+ "other_keys": ["endpoint", "service_name"],
+ "trace_instance": TencentDataTrace,
+ }
- case _:
- raise KeyError(f"Unsupported tracing provider: {provider}")
+ case _:
+ raise KeyError(f"Unsupported tracing provider: {provider}")
+ except ImportError:
+ raise ImportError(f"Provider {provider} is not installed.")
provider_config_map = OpsTraceProviderConfigMap()
diff --git a/api/core/plugin/impl/model_runtime.py b/api/core/plugin/impl/model_runtime.py
index e3fba4ef3a..4e66d58b5e 100644
--- a/api/core/plugin/impl/model_runtime.py
+++ b/api/core/plugin/impl/model_runtime.py
@@ -66,15 +66,15 @@ class PluginModelRuntime(ModelRuntime):
if not provider_schema.icon_small:
raise ValueError(f"Provider {provider} does not have small icon.")
file_name = (
- provider_schema.icon_small.zh_Hans if lang.lower() == "zh_hans" else provider_schema.icon_small.en_US
+ provider_schema.icon_small.zh_hans if lang.lower() == "zh_hans" else provider_schema.icon_small.en_us
)
elif icon_type.lower() == "icon_small_dark":
if not provider_schema.icon_small_dark:
raise ValueError(f"Provider {provider} does not have small dark icon.")
file_name = (
- provider_schema.icon_small_dark.zh_Hans
+ provider_schema.icon_small_dark.zh_hans
if lang.lower() == "zh_hans"
- else provider_schema.icon_small_dark.en_US
+ else provider_schema.icon_small_dark.en_us
)
else:
raise ValueError(f"Unsupported icon type: {icon_type}.")
diff --git a/api/core/prompt/agent_history_prompt_transform.py b/api/core/prompt/agent_history_prompt_transform.py
index 8f1d51f08a..7c6280fe93 100644
--- a/api/core/prompt/agent_history_prompt_transform.py
+++ b/api/core/prompt/agent_history_prompt_transform.py
@@ -10,7 +10,7 @@ from graphon.model_runtime.entities.message_entities import (
SystemPromptMessage,
UserPromptMessage,
)
-from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
class AgentHistoryPromptTransform(PromptTransform):
diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py
index c3bbe8fc09..8969825be4 100644
--- a/api/core/provider_manager.py
+++ b/api/core/provider_manager.py
@@ -70,12 +70,32 @@ class ProviderManager:
Request-bound managers may carry caller identity in that runtime, and the
resulting ``ProviderConfiguration`` objects must reuse it for downstream
model-type and schema lookups.
+
+ Configuration assembly is cached per manager instance so call chains that
+ share one request-scoped manager can reuse the same provider graph instead
+ of rebuilding it for every lookup. Call ``clear_configurations_cache()``
+ when a long-lived manager needs to observe writes performed within the same
+ instance scope.
"""
+ decoding_rsa_key: Any | None
+ decoding_cipher_rsa: Any | None
+ _model_runtime: ModelRuntime
+ _configurations_cache: dict[str, ProviderConfigurations]
+
def __init__(self, model_runtime: ModelRuntime):
self.decoding_rsa_key = None
self.decoding_cipher_rsa = None
self._model_runtime = model_runtime
+ self._configurations_cache = {}
+
+ def clear_configurations_cache(self, tenant_id: str | None = None) -> None:
+ """Drop assembled provider configurations cached on this manager instance."""
+ if tenant_id is None:
+ self._configurations_cache.clear()
+ return
+
+ self._configurations_cache.pop(tenant_id, None)
def get_configurations(self, tenant_id: str) -> ProviderConfigurations:
"""
@@ -114,6 +134,10 @@ class ProviderManager:
:param tenant_id:
:return:
"""
+ cached_configurations = self._configurations_cache.get(tenant_id)
+ if cached_configurations is not None:
+ return cached_configurations
+
# Get all provider records of the workspace
provider_name_to_provider_records_dict = self._get_all_providers(tenant_id)
@@ -273,6 +297,8 @@ class ProviderManager:
provider_configurations[str(provider_id_entity)] = provider_configuration
+ self._configurations_cache[tenant_id] = provider_configurations
+
# Return the encapsulated object
return provider_configurations
diff --git a/api/core/rag/datasource/keyword/jieba/jieba.py b/api/core/rag/datasource/keyword/jieba/jieba.py
index ed264878d3..242da520c1 100644
--- a/api/core/rag/datasource/keyword/jieba/jieba.py
+++ b/api/core/rag/datasource/keyword/jieba/jieba.py
@@ -139,8 +139,10 @@ class Jieba(BaseKeyword):
"__data__": {"index_id": self.dataset.id, "summary": None, "table": keyword_table},
}
dataset_keyword_table = self.dataset.dataset_keyword_table
- keyword_data_source_type = dataset_keyword_table.data_source_type
+ keyword_data_source_type = dataset_keyword_table.data_source_type if dataset_keyword_table else "file"
if keyword_data_source_type == "database":
+ if dataset_keyword_table is None:
+ return
dataset_keyword_table.keyword_table = dumps_with_sets(keyword_table_dict)
db.session.commit()
else:
diff --git a/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py b/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py
index 84f35c25f8..1ca6303af6 100644
--- a/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py
+++ b/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py
@@ -1,4 +1,5 @@
import re
+from collections.abc import Callable
from operator import itemgetter
from typing import cast
@@ -80,12 +81,14 @@ class JiebaKeywordTableHandler:
def extract_tags(self, sentence: str, top_k: int | None = 20, **kwargs):
# Basic frequency-based keyword extraction as a fallback when TF-IDF is unavailable.
- top_k = kwargs.pop("topK", top_k)
+ top_k = cast(int | None, kwargs.pop("topK", top_k))
+ if top_k is None:
+ top_k = 20
cut = getattr(jieba, "cut", None)
if self._lcut:
tokens = self._lcut(sentence)
elif callable(cut):
- tokens = list(cut(sentence))
+ tokens = list(cast(Callable[[str], list[str]], cut)(sentence))
else:
tokens = re.findall(r"\w+", sentence)
@@ -108,7 +111,7 @@ class JiebaKeywordTableHandler:
sentence=text,
topK=max_keywords_per_chunk,
)
- # jieba.analyse.extract_tags returns list[Any] when withFlag is False by default.
+ # jieba.analyse.extract_tags returns an untyped list when withFlag is False by default.
keywords = cast(list[str], keywords)
return set(self._expand_tokens_with_subtokens(set(keywords)))
diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py
index 7e71d67ec0..2997710daf 100644
--- a/api/core/rag/datasource/retrieval_service.py
+++ b/api/core/rag/datasource/retrieval_service.py
@@ -158,7 +158,7 @@ class RetrievalService:
)
if futures:
- for future in concurrent.futures.as_completed(futures, timeout=3600):
+ for _ in concurrent.futures.as_completed(futures, timeout=3600):
if exceptions:
for f in futures:
f.cancel()
diff --git a/api/core/rag/embedding/cached_embedding.py b/api/core/rag/embedding/cached_embedding.py
index 4926f44f16..a9995778f7 100644
--- a/api/core/rag/embedding/cached_embedding.py
+++ b/api/core/rag/embedding/cached_embedding.py
@@ -14,7 +14,7 @@ from core.rag.embedding.embedding_base import Embeddings
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from graphon.model_runtime.entities.model_entities import ModelPropertyKey
-from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
+from graphon.model_runtime.model_providers.base.text_embedding_model import TextEmbeddingModel
from libs import helper
from models.dataset import Embedding
diff --git a/api/core/rag/extractor/extract_processor.py b/api/core/rag/extractor/extract_processor.py
index fbd2a6db93..b679edab36 100644
--- a/api/core/rag/extractor/extract_processor.py
+++ b/api/core/rag/extractor/extract_processor.py
@@ -94,6 +94,7 @@ class ExtractProcessor:
cls, extract_setting: ExtractSetting, is_automatic: bool = False, file_path: str | None = None
) -> list[Document]:
if extract_setting.datasource_type == DatasourceType.FILE:
+ upload_file = extract_setting.upload_file
with tempfile.TemporaryDirectory() as temp_dir:
upload_file = extract_setting.upload_file
if not file_path:
@@ -104,6 +105,7 @@ class ExtractProcessor:
storage.download(upload_file.key, file_path)
input_file = Path(file_path)
file_extension = input_file.suffix.lower()
+ assert upload_file is not None, "upload_file is required"
etl_type = dify_config.ETL_TYPE
extractor: BaseExtractor | None = None
if etl_type == "Unstructured":
diff --git a/api/core/rag/extractor/word_extractor.py b/api/core/rag/extractor/word_extractor.py
index 052fca930d..0330a43b28 100644
--- a/api/core/rag/extractor/word_extractor.py
+++ b/api/core/rag/extractor/word_extractor.py
@@ -3,6 +3,7 @@
Supports local file paths and remote URLs (downloaded via `core.helper.ssrf_proxy`).
"""
+import inspect
import logging
import mimetypes
import os
@@ -36,8 +37,11 @@ class WordExtractor(BaseExtractor):
file_path: Path to the file to load.
"""
+ _closed: bool
+
def __init__(self, file_path: str, tenant_id: str, user_id: str):
"""Initialize with file path."""
+ self._closed = False
self.file_path = file_path
self.tenant_id = tenant_id
self.user_id = user_id
@@ -65,9 +69,27 @@ class WordExtractor(BaseExtractor):
elif not os.path.isfile(self.file_path):
raise ValueError(f"File path {self.file_path} is not a valid file or url")
+ def close(self) -> None:
+ """Best-effort cleanup for downloaded temporary files."""
+ if getattr(self, "_closed", False):
+ return
+
+ self._closed = True
+ temp_file = getattr(self, "temp_file", None)
+ if temp_file is None:
+ return
+
+ try:
+ close_result = temp_file.close()
+ if inspect.isawaitable(close_result):
+ close_awaitable = getattr(close_result, "close", None)
+ if callable(close_awaitable):
+ close_awaitable()
+ except Exception:
+ logger.debug("Failed to cleanup downloaded word temp file", exc_info=True)
+
def __del__(self):
- if hasattr(self, "temp_file"):
- self.temp_file.close()
+ self.close()
def extract(self) -> list[Document]:
"""Load given path as single page."""
diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py
index f8242efe31..7ffa9afafd 100644
--- a/api/core/rag/index_processor/processor/paragraph_index_processor.py
+++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py
@@ -609,11 +609,11 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
try:
# Create File object directly (similar to DatasetRetrieval)
file_obj = File(
- id=upload_file.id,
+ file_id=upload_file.id,
filename=upload_file.name,
extension="." + upload_file.extension,
mime_type=upload_file.mime_type,
- type=FileType.IMAGE,
+ file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
remote_url=upload_file.source_url,
reference=build_file_reference(
diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py
index 1453fe020b..5631b3a921 100644
--- a/api/core/rag/retrieval/dataset_retrieval.py
+++ b/api/core/rag/retrieval/dataset_retrieval.py
@@ -68,7 +68,7 @@ from graphon.file import File, FileTransferMethod, FileType
from graphon.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMUsage
from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType
-from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
from libs.helper import parse_uuid_str_or_none
from libs.json_in_md_parser import parse_and_check_json_markdown
from models import UploadFile
@@ -517,11 +517,11 @@ class DatasetRetrieval:
if attachments_with_bindings:
for _, upload_file in attachments_with_bindings:
attachment_info = File(
- id=upload_file.id,
+ file_id=upload_file.id,
filename=upload_file.name,
extension="." + upload_file.extension,
mime_type=upload_file.mime_type,
- type=FileType.IMAGE,
+ file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
remote_url=upload_file.source_url,
reference=build_file_reference(
diff --git a/api/core/rag/retrieval/router/multi_dataset_function_call_router.py b/api/core/rag/retrieval/router/multi_dataset_function_call_router.py
index e617a9660e..426d1b67dc 100644
--- a/api/core/rag/retrieval/router/multi_dataset_function_call_router.py
+++ b/api/core/rag/retrieval/router/multi_dataset_function_call_router.py
@@ -28,7 +28,7 @@ class FunctionCallMultiDatasetRouter:
SystemPromptMessage(content="You are a helpful AI assistant."),
UserPromptMessage(content=query),
]
- result: LLMResult = model_instance.invoke_llm(
+ result: LLMResult = model_instance.invoke_llm( # pyright: ignore[reportCallIssue, reportArgumentType]
prompt_messages=prompt_messages,
tools=dataset_tools,
stream=False,
diff --git a/api/core/rag/splitter/fixed_text_splitter.py b/api/core/rag/splitter/fixed_text_splitter.py
index 2581c354dd..52c9a02f97 100644
--- a/api/core/rag/splitter/fixed_text_splitter.py
+++ b/api/core/rag/splitter/fixed_text_splitter.py
@@ -4,12 +4,12 @@ from __future__ import annotations
import codecs
import re
-from collections.abc import Collection
+from collections.abc import Set as AbstractSet
from typing import Any, Literal
from core.model_manager import ModelInstance
from core.rag.splitter.text_splitter import RecursiveCharacterTextSplitter
-from graphon.model_runtime.model_providers.__base.tokenizers.gpt2_tokenizer import GPT2Tokenizer
+from graphon.model_runtime.model_providers.base.tokenizers.gpt2_tokenizer import GPT2Tokenizer
class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
@@ -21,8 +21,8 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
def from_encoder[T: EnhanceRecursiveCharacterTextSplitter](
cls: type[T],
embedding_model_instance: ModelInstance | None,
- allowed_special: Literal["all"] | set[str] = set(),
- disallowed_special: Literal["all"] | Collection[str] = "all",
+ allowed_special: Literal["all"] | AbstractSet[str] = frozenset(),
+ disallowed_special: Literal["all"] | AbstractSet[str] = "all",
**kwargs: Any,
) -> T:
def _token_encoder(texts: list[str]) -> list[int]:
@@ -40,6 +40,7 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
return [len(text) for text in texts]
+ _ = _token_encoder # kept for future token-length wiring
return cls(length_function=_character_encoder, **kwargs)
diff --git a/api/core/rag/splitter/text_splitter.py b/api/core/rag/splitter/text_splitter.py
index 7f2117e2dd..a8d9013fbc 100644
--- a/api/core/rag/splitter/text_splitter.py
+++ b/api/core/rag/splitter/text_splitter.py
@@ -4,7 +4,8 @@ import copy
import logging
import re
from abc import ABC, abstractmethod
-from collections.abc import Callable, Collection, Iterable, Sequence, Set
+from collections.abc import Callable, Iterable, Sequence
+from collections.abc import Set as AbstractSet
from dataclasses import dataclass
from typing import Any, Literal
@@ -187,8 +188,8 @@ class TokenTextSplitter(TextSplitter):
self,
encoding_name: str = "gpt2",
model_name: str | None = None,
- allowed_special: Literal["all"] | Set[str] = set(),
- disallowed_special: Literal["all"] | Collection[str] = "all",
+ allowed_special: Literal["all"] | AbstractSet[str] = frozenset(),
+ disallowed_special: Literal["all"] | AbstractSet[str] = "all",
**kwargs: Any,
):
"""Create a new TextSplitter."""
@@ -207,8 +208,8 @@ class TokenTextSplitter(TextSplitter):
else:
enc = tiktoken.get_encoding(encoding_name)
self._tokenizer = enc
- self._allowed_special = allowed_special
- self._disallowed_special = disallowed_special
+ self._allowed_special: Literal["all"] | AbstractSet[str] = allowed_special
+ self._disallowed_special: Literal["all"] | AbstractSet[str] = disallowed_special
def split_text(self, text: str) -> list[str]:
def _encode(_text: str) -> list[int]:
diff --git a/api/core/repositories/human_input_repository.py b/api/core/repositories/human_input_repository.py
index 02625e242f..740d727e26 100644
--- a/api/core/repositories/human_input_repository.py
+++ b/api/core/repositories/human_input_repository.py
@@ -8,7 +8,7 @@ from sqlalchemy import select
from sqlalchemy.orm import Session, selectinload
from core.db.session_factory import session_factory
-from core.workflow.human_input_compat import (
+from core.workflow.human_input_adapter import (
BoundRecipient,
DeliveryChannelConfig,
EmailDeliveryMethod,
diff --git a/api/core/tools/errors.py b/api/core/tools/errors.py
index 4c3efd6ff9..2b26832b44 100644
--- a/api/core/tools/errors.py
+++ b/api/core/tools/errors.py
@@ -38,6 +38,17 @@ class ToolCredentialPolicyViolationError(ValueError):
pass
+class ApiToolProviderNotFoundError(ValueError):
+ error_code = "api_tool_provider_not_found"
+ provider_name: str
+ tenant_id: str
+
+ def __init__(self, provider_name: str, tenant_id: str):
+ self.provider_name = provider_name
+ self.tenant_id = tenant_id
+ super().__init__(f"api provider {provider_name} does not exist")
+
+
class WorkflowToolHumanInputNotSupportedError(BaseHTTPException):
error_code = "workflow_tool_human_input_not_supported"
description = "Workflow with Human Input nodes cannot be published as a workflow tool."
diff --git a/api/core/tools/tool_file_manager.py b/api/core/tools/tool_file_manager.py
index b3424cd9a5..c87e8a3ae0 100644
--- a/api/core/tools/tool_file_manager.py
+++ b/api/core/tools/tool_file_manager.py
@@ -28,7 +28,7 @@ class ToolFileManager:
def _build_graph_file_reference(tool_file: ToolFile) -> File:
extension = guess_extension(tool_file.mimetype) or ".bin"
return File(
- type=get_file_type_by_mime_type(tool_file.mimetype),
+ file_type=get_file_type_by_mime_type(tool_file.mimetype),
transfer_method=FileTransferMethod.TOOL_FILE,
remote_url=tool_file.original_url,
reference=build_file_reference(record_id=str(tool_file.id)),
diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py
index f4588904d3..87cf6d7085 100644
--- a/api/core/tools/tool_manager.py
+++ b/api/core/tools/tool_manager.py
@@ -1082,7 +1082,12 @@ class ToolManager:
continue
tool_input = ToolNodeData.ToolInput.model_validate(tool_configurations.get(parameter.name, {}))
if tool_input.type == "variable":
- variable = variable_pool.get(tool_input.value)
+ variable_selector = tool_input.value
+ if not isinstance(variable_selector, list) or not all(
+ isinstance(selector_part, str) for selector_part in variable_selector
+ ):
+ raise ToolParameterError("Variable tool input must be a variable selector")
+ variable = variable_pool.get(variable_selector)
if variable is None:
raise ToolParameterError(f"Variable {tool_input.value} does not exist")
parameter_value = variable.value
diff --git a/api/core/tools/utils/message_transformer.py b/api/core/tools/utils/message_transformer.py
index 79d0c114d4..5679466cbc 100644
--- a/api/core/tools/utils/message_transformer.py
+++ b/api/core/tools/utils/message_transformer.py
@@ -41,6 +41,10 @@ def safe_json_value(v):
return v.hex()
elif isinstance(v, memoryview):
return v.tobytes().hex()
+ elif isinstance(v, np.integer):
+ return int(v)
+ elif isinstance(v, np.floating):
+ return float(v)
elif isinstance(v, np.ndarray):
return v.tolist()
elif isinstance(v, dict):
diff --git a/api/core/tools/utils/model_invocation_utils.py b/api/core/tools/utils/model_invocation_utils.py
index 9e1d41cb39..a3623d4ecd 100644
--- a/api/core/tools/utils/model_invocation_utils.py
+++ b/api/core/tools/utils/model_invocation_utils.py
@@ -21,7 +21,7 @@ from graphon.model_runtime.errors.invoke import (
InvokeRateLimitError,
InvokeServerUnavailableError,
)
-from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
from graphon.model_runtime.utils.encoders import jsonable_encoder
from models.tools import ToolModelInvoke
diff --git a/api/core/tools/utils/web_reader_tool.py b/api/core/tools/utils/web_reader_tool.py
index ed3ed3e0de..94a2c0427b 100644
--- a/api/core/tools/utils/web_reader_tool.py
+++ b/api/core/tools/utils/web_reader_tool.py
@@ -105,7 +105,7 @@ class Article:
def extract_using_readabilipy(html: str):
- json_article: dict[str, Any] = simple_json_from_html_string(html, use_readability=True)
+ json_article: dict[str, Any] = simple_json_from_html_string(html, use_readability=False)
article = Article(
title=json_article.get("title") or "",
author=json_article.get("byline") or "",
diff --git a/api/core/tools/workflow_as_tool/tool.py b/api/core/tools/workflow_as_tool/tool.py
index 52ab605963..cd8c6352b5 100644
--- a/api/core/tools/workflow_as_tool/tool.py
+++ b/api/core/tools/workflow_as_tool/tool.py
@@ -357,7 +357,10 @@ class WorkflowTool(Tool):
def _update_file_mapping(self, file_dict: dict[str, Any]) -> dict[str, Any]:
file_id = resolve_file_record_id(file_dict.get("reference") or file_dict.get("related_id"))
- transfer_method = FileTransferMethod.value_of(file_dict.get("transfer_method"))
+ transfer_method_value = file_dict.get("transfer_method")
+ if not isinstance(transfer_method_value, str):
+ raise ValueError("Workflow file mapping is missing a valid transfer_method")
+ transfer_method = FileTransferMethod.value_of(transfer_method_value)
match transfer_method:
case FileTransferMethod.TOOL_FILE:
file_dict["tool_file_id"] = file_id
diff --git a/api/core/workflow/human_input_compat.py b/api/core/workflow/human_input_adapter.py
similarity index 74%
rename from api/core/workflow/human_input_compat.py
rename to api/core/workflow/human_input_adapter.py
index 75a0a0c202..4b765e6aea 100644
--- a/api/core/workflow/human_input_compat.py
+++ b/api/core/workflow/human_input_adapter.py
@@ -1,8 +1,8 @@
-"""Workflow-layer adapters for legacy human-input payload keys.
+"""Workflow-to-Graphon adapters for persisted node payloads.
-Stored workflow graphs and editor payloads may still use Dify-specific human
-input recipient keys. Normalize them here before handing configs to
-`graphon` so graph-owned models only see graph-neutral field names.
+Stored workflow graphs and editor payloads still contain a small set of
+Dify-owned field spellings and value shapes. Adapt them here before handing the
+payload to Graphon so Graphon-owned models only see current contracts.
"""
from __future__ import annotations
@@ -185,7 +185,7 @@ def _copy_mapping(value: object) -> dict[str, Any] | None:
return None
-def normalize_human_input_node_data_for_graph(node_data: Mapping[str, Any] | BaseModel) -> dict[str, Any]:
+def adapt_human_input_node_data_for_graph(node_data: Mapping[str, Any] | BaseModel) -> dict[str, Any]:
normalized = _copy_mapping(node_data)
if normalized is None:
raise TypeError(f"human-input node data must be a mapping, got {type(node_data).__name__}")
@@ -215,7 +215,7 @@ def normalize_human_input_node_data_for_graph(node_data: Mapping[str, Any] | Bas
def parse_human_input_delivery_methods(node_data: Mapping[str, Any] | BaseModel) -> list[DeliveryChannelConfig]:
- normalized = normalize_human_input_node_data_for_graph(node_data)
+ normalized = adapt_human_input_node_data_for_graph(node_data)
raw_delivery_methods = normalized.get("delivery_methods")
if not isinstance(raw_delivery_methods, list):
return []
@@ -229,17 +229,20 @@ def is_human_input_webapp_enabled(node_data: Mapping[str, Any] | BaseModel) -> b
return False
-def normalize_node_data_for_graph(node_data: Mapping[str, Any] | BaseModel) -> dict[str, Any]:
+def adapt_node_data_for_graph(node_data: Mapping[str, Any] | BaseModel) -> dict[str, Any]:
normalized = _copy_mapping(node_data)
if normalized is None:
raise TypeError(f"node data must be a mapping, got {type(node_data).__name__}")
- if normalized.get("type") != BuiltinNodeTypes.HUMAN_INPUT:
- return normalized
- return normalize_human_input_node_data_for_graph(normalized)
+ node_type = normalized.get("type")
+ if node_type == BuiltinNodeTypes.HUMAN_INPUT:
+ return adapt_human_input_node_data_for_graph(normalized)
+ if node_type == BuiltinNodeTypes.TOOL:
+ return _adapt_tool_node_data_for_graph(normalized)
+ return normalized
-def normalize_node_config_for_graph(node_config: Mapping[str, Any] | BaseModel) -> dict[str, Any]:
+def adapt_node_config_for_graph(node_config: Mapping[str, Any] | BaseModel) -> dict[str, Any]:
normalized = _copy_mapping(node_config)
if normalized is None:
raise TypeError(f"node config must be a mapping, got {type(node_config).__name__}")
@@ -248,10 +251,65 @@ def normalize_node_config_for_graph(node_config: Mapping[str, Any] | BaseModel)
if data_mapping is None:
return normalized
- normalized["data"] = normalize_node_data_for_graph(data_mapping)
+ normalized["data"] = adapt_node_data_for_graph(data_mapping)
return normalized
+def _adapt_tool_node_data_for_graph(node_data: Mapping[str, Any]) -> dict[str, Any]:
+ normalized = dict(node_data)
+
+ raw_tool_configurations = normalized.get("tool_configurations")
+ if not isinstance(raw_tool_configurations, Mapping):
+ return normalized
+
+ existing_tool_parameters = normalized.get("tool_parameters")
+ normalized_tool_parameters = dict(existing_tool_parameters) if isinstance(existing_tool_parameters, Mapping) else {}
+ normalized_tool_configurations: dict[str, Any] = {}
+ found_legacy_tool_inputs = False
+
+ for name, value in raw_tool_configurations.items():
+ if not isinstance(value, Mapping):
+ normalized_tool_configurations[name] = value
+ continue
+
+ input_type = value.get("type")
+ input_value = value.get("value")
+ if input_type not in {"mixed", "variable", "constant"}:
+ normalized_tool_configurations[name] = value
+ continue
+
+ found_legacy_tool_inputs = True
+ normalized_tool_parameters.setdefault(name, dict(value))
+
+ flattened_value = _flatten_legacy_tool_configuration_value(
+ input_type=input_type,
+ input_value=input_value,
+ )
+ if flattened_value is not None:
+ normalized_tool_configurations[name] = flattened_value
+
+ if not found_legacy_tool_inputs:
+ return normalized
+
+ normalized["tool_parameters"] = normalized_tool_parameters
+ normalized["tool_configurations"] = normalized_tool_configurations
+ return normalized
+
+
+def _flatten_legacy_tool_configuration_value(*, input_type: Any, input_value: Any) -> str | int | float | bool | None:
+ if input_type in {"mixed", "constant"} and isinstance(input_value, str | int | float | bool):
+ return input_value
+
+ if (
+ input_type == "variable"
+ and isinstance(input_value, list)
+ and all(isinstance(item, str) for item in input_value)
+ ):
+ return "{{#" + ".".join(input_value) + "#}}"
+
+ return None
+
+
def _normalize_email_recipients(recipients: Mapping[str, Any]) -> dict[str, Any]:
normalized = dict(recipients)
@@ -291,9 +349,9 @@ __all__ = [
"MemberRecipient",
"WebAppDeliveryMethod",
"_WebAppDeliveryConfig",
+ "adapt_human_input_node_data_for_graph",
+ "adapt_node_config_for_graph",
+ "adapt_node_data_for_graph",
"is_human_input_webapp_enabled",
- "normalize_human_input_node_data_for_graph",
- "normalize_node_config_for_graph",
- "normalize_node_data_for_graph",
"parse_human_input_delivery_methods",
]
diff --git a/api/core/workflow/node_factory.py b/api/core/workflow/node_factory.py
index 351da3444f..de4eae1b22 100644
--- a/api/core/workflow/node_factory.py
+++ b/api/core/workflow/node_factory.py
@@ -15,12 +15,12 @@ from core.helper.code_executor.code_executor import (
CodeExecutionError,
CodeExecutor,
)
-from core.helper.ssrf_proxy import ssrf_proxy
+from core.helper.ssrf_proxy import graphon_ssrf_proxy
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.trigger.constants import TRIGGER_NODE_TYPES
-from core.workflow.human_input_compat import normalize_node_config_for_graph
+from core.workflow.human_input_adapter import adapt_node_config_for_graph
from core.workflow.node_runtime import (
DifyFileReferenceFactory,
DifyHumanInputNodeRuntime,
@@ -46,7 +46,7 @@ from graphon.enums import BuiltinNodeTypes, NodeType
from graphon.file.file_manager import file_manager
from graphon.graph.graph import NodeFactory
from graphon.model_runtime.memory import PromptMessageMemory
-from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
from graphon.nodes.base.node import Node
from graphon.nodes.code.code_node import WorkflowCodeExecutor
from graphon.nodes.code.entities import CodeLanguage
@@ -121,6 +121,7 @@ def get_node_type_classes_mapping() -> Mapping[NodeType, Mapping[str, type[Node]
def resolve_workflow_node_class(*, node_type: NodeType, node_version: str) -> type[Node]:
+ """Resolve the production node class for the requested type/version."""
node_mapping = get_node_type_classes_mapping().get(node_type)
if not node_mapping:
raise ValueError(f"No class mapping found for node type: {node_type}")
@@ -297,7 +298,7 @@ class DifyNodeFactory(NodeFactory):
)
self._jinja2_template_renderer = CodeExecutorJinja2TemplateRenderer()
self._template_transform_max_output_length = dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH
- self._http_request_http_client = ssrf_proxy
+ self._http_request_http_client = graphon_ssrf_proxy
self._bound_tool_file_manager_factory = lambda: DifyToolFileManager(
self._dify_context,
conversation_id_getter=self._conversation_id,
@@ -364,10 +365,14 @@ class DifyNodeFactory(NodeFactory):
(including pydantic ValidationError, which subclasses ValueError),
if node type is unknown, or if no implementation exists for the resolved version
"""
- typed_node_config = NodeConfigDictAdapter.validate_python(normalize_node_config_for_graph(node_config))
+ typed_node_config = NodeConfigDictAdapter.validate_python(adapt_node_config_for_graph(node_config))
node_id = typed_node_config["id"]
node_data = typed_node_config["data"]
node_class = self._resolve_node_class(node_type=node_data.type, node_version=str(node_data.version))
+ # Graph configs are initially validated against permissive shared node data.
+ # Re-validate using the resolved node class so workflow-local node schemas
+ # stay explicit and constructors receive the concrete typed payload.
+ resolved_node_data = self._validate_resolved_node_data(node_class, node_data)
node_type = node_data.type
node_init_kwargs_factories: Mapping[NodeType, Callable[[], dict[str, object]]] = {
BuiltinNodeTypes.CODE: lambda: {
@@ -391,7 +396,7 @@ class DifyNodeFactory(NodeFactory):
},
BuiltinNodeTypes.LLM: lambda: self._build_llm_compatible_node_init_kwargs(
node_class=node_class,
- node_data=node_data,
+ node_data=resolved_node_data,
wrap_model_instance=True,
include_http_client=True,
include_llm_file_saver=True,
@@ -405,7 +410,7 @@ class DifyNodeFactory(NodeFactory):
},
BuiltinNodeTypes.QUESTION_CLASSIFIER: lambda: self._build_llm_compatible_node_init_kwargs(
node_class=node_class,
- node_data=node_data,
+ node_data=resolved_node_data,
wrap_model_instance=True,
include_http_client=True,
include_llm_file_saver=True,
@@ -415,7 +420,7 @@ class DifyNodeFactory(NodeFactory):
),
BuiltinNodeTypes.PARAMETER_EXTRACTOR: lambda: self._build_llm_compatible_node_init_kwargs(
node_class=node_class,
- node_data=node_data,
+ node_data=resolved_node_data,
wrap_model_instance=True,
include_http_client=False,
include_llm_file_saver=False,
@@ -436,8 +441,8 @@ class DifyNodeFactory(NodeFactory):
}
node_init_kwargs = node_init_kwargs_factories.get(node_type, lambda: {})()
return node_class(
- id=node_id,
- config=typed_node_config,
+ node_id=node_id,
+ config=resolved_node_data,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
**node_init_kwargs,
@@ -448,7 +453,10 @@ class DifyNodeFactory(NodeFactory):
"""
Re-validate the permissive graph payload with the concrete NodeData model declared by the resolved node class.
"""
- return node_class.validate_node_data(node_data)
+ validate_node_data = getattr(node_class, "validate_node_data", None)
+ if callable(validate_node_data):
+ return cast("BaseNodeData", validate_node_data(node_data))
+ return node_data
@staticmethod
def _resolve_node_class(*, node_type: NodeType, node_version: str) -> type[Node]:
diff --git a/api/core/workflow/node_runtime.py b/api/core/workflow/node_runtime.py
index 2e632e56f0..b8725853c4 100644
--- a/api/core/workflow/node_runtime.py
+++ b/api/core/workflow/node_runtime.py
@@ -2,7 +2,7 @@ from __future__ import annotations
from collections.abc import Callable, Generator, Mapping, Sequence
from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, cast
+from typing import TYPE_CHECKING, Any, Literal, cast, overload
from sqlalchemy import select
from sqlalchemy.orm import Session
@@ -41,7 +41,7 @@ from graphon.model_runtime.entities.llm_entities import (
)
from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from graphon.model_runtime.entities.model_entities import AIModelEntity
-from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
from graphon.nodes.human_input.entities import HumanInputNodeData
from graphon.nodes.llm.runtime_protocols import (
PreparedLLMProtocol,
@@ -64,7 +64,7 @@ from models.dataset import SegmentAttachmentBinding
from models.model import UploadFile
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
-from .human_input_compat import (
+from .human_input_adapter import (
BoundRecipient,
DeliveryChannelConfig,
DeliveryMethodType,
@@ -173,6 +173,28 @@ class DifyPreparedLLM(PreparedLLMProtocol):
def get_llm_num_tokens(self, prompt_messages: Sequence[PromptMessage]) -> int:
return self._model_instance.get_llm_num_tokens(prompt_messages)
+ @overload
+ def invoke_llm(
+ self,
+ *,
+ prompt_messages: Sequence[PromptMessage],
+ model_parameters: Mapping[str, Any],
+ tools: Sequence[PromptMessageTool] | None,
+ stop: Sequence[str] | None,
+ stream: Literal[False],
+ ) -> LLMResult: ...
+
+ @overload
+ def invoke_llm(
+ self,
+ *,
+ prompt_messages: Sequence[PromptMessage],
+ model_parameters: Mapping[str, Any],
+ tools: Sequence[PromptMessageTool] | None,
+ stop: Sequence[str] | None,
+ stream: Literal[True],
+ ) -> Generator[LLMResultChunk, None, None]: ...
+
def invoke_llm(
self,
*,
@@ -190,6 +212,28 @@ class DifyPreparedLLM(PreparedLLMProtocol):
stream=stream,
)
+ @overload
+ def invoke_llm_with_structured_output(
+ self,
+ *,
+ prompt_messages: Sequence[PromptMessage],
+ json_schema: Mapping[str, Any],
+ model_parameters: Mapping[str, Any],
+ stop: Sequence[str] | None,
+ stream: Literal[False],
+ ) -> LLMResultWithStructuredOutput: ...
+
+ @overload
+ def invoke_llm_with_structured_output(
+ self,
+ *,
+ prompt_messages: Sequence[PromptMessage],
+ json_schema: Mapping[str, Any],
+ model_parameters: Mapping[str, Any],
+ stop: Sequence[str] | None,
+ stream: Literal[True],
+ ) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
+
def invoke_llm_with_structured_output(
self,
*,
diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py
index 7b000101b0..68a24e86b1 100644
--- a/api/core/workflow/nodes/agent/agent_node.py
+++ b/api/core/workflow/nodes/agent/agent_node.py
@@ -5,7 +5,6 @@ from typing import TYPE_CHECKING, Any
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext
from core.workflow.system_variables import SystemVariableKey, get_system_text
-from graphon.entities.graph_config import NodeConfigDict
from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus
from graphon.node_events import NodeEventBase, NodeRunResult, StreamCompletedEvent
from graphon.nodes.base.node import Node
@@ -35,18 +34,18 @@ class AgentNode(Node[AgentNodeData]):
def __init__(
self,
- id: str,
- config: NodeConfigDict,
+ node_id: str,
+ config: AgentNodeData,
+ *,
graph_init_params: GraphInitParams,
graph_runtime_state: GraphRuntimeState,
- *,
strategy_resolver: AgentStrategyResolver,
presentation_provider: AgentStrategyPresentationProvider,
runtime_support: AgentRuntimeSupport,
message_transformer: AgentMessageTransformer,
) -> None:
super().__init__(
- id=id,
+ node_id=node_id,
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py
index e4f6b3b470..f3006c4242 100644
--- a/api/core/workflow/nodes/datasource/datasource_node.py
+++ b/api/core/workflow/nodes/datasource/datasource_node.py
@@ -7,7 +7,6 @@ from core.datasource.entities.datasource_entities import DatasourceProviderType
from core.plugin.impl.exc import PluginDaemonClientSideError
from core.workflow.file_reference import resolve_file_record_id
from core.workflow.system_variables import SystemVariableKey, get_system_segment
-from graphon.entities.graph_config import NodeConfigDict
from graphon.enums import (
BuiltinNodeTypes,
NodeExecutionType,
@@ -36,13 +35,14 @@ class DatasourceNode(Node[DatasourceNodeData]):
def __init__(
self,
- id: str,
- config: NodeConfigDict,
+ node_id: str,
+ config: DatasourceNodeData,
+ *,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
- ):
+ ) -> None:
super().__init__(
- id=id,
+ node_id=node_id,
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
diff --git a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py
index d5cab05dbe..9c1b7ab2c4 100644
--- a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py
+++ b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py
@@ -7,7 +7,6 @@ from core.rag.index_processor.index_processor_base import SummaryIndexSettingDic
from core.rag.summary_index.summary_index import SummaryIndex
from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE
from core.workflow.system_variables import SystemVariableKey, get_system_segment, get_system_text
-from graphon.entities.graph_config import NodeConfigDict
from graphon.enums import NodeExecutionType, WorkflowNodeExecutionStatus
from graphon.node_events import NodeRunResult
from graphon.nodes.base.node import Node
@@ -32,12 +31,18 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
def __init__(
self,
- id: str,
- config: NodeConfigDict,
+ node_id: str,
+ config: KnowledgeIndexNodeData,
+ *,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
) -> None:
- super().__init__(id, config, graph_init_params, graph_runtime_state)
+ super().__init__(
+ node_id=node_id,
+ config=config,
+ graph_init_params=graph_init_params,
+ graph_runtime_state=graph_runtime_state,
+ )
self.index_processor = IndexProcessor()
self.summary_index_service = SummaryIndex()
diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py
index 47ad14b499..25f73e446d 100644
--- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py
+++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py
@@ -14,7 +14,6 @@ from core.rag.data_post_processor.data_post_processor import RerankingModelDict,
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.workflow.file_reference import parse_file_reference
from graphon.entities import GraphInitParams
-from graphon.entities.graph_config import NodeConfigDict
from graphon.enums import (
BuiltinNodeTypes,
WorkflowNodeExecutionMetadataKey,
@@ -50,6 +49,18 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
+def _normalize_metadata_filter_scalar(value: object) -> str | int | float | None:
+ if value is None or isinstance(value, (str, float)):
+ return value
+ if isinstance(value, int) and not isinstance(value, bool):
+ return value
+ return str(value)
+
+
+def _normalize_metadata_filter_sequence_item(value: object) -> str:
+ return value if isinstance(value, str) else str(value)
+
+
class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeData]):
node_type = BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL
@@ -59,13 +70,14 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
def __init__(
self,
- id: str,
- config: NodeConfigDict,
+ node_id: str,
+ config: KnowledgeRetrievalNodeData,
+ *,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
- ):
+ ) -> None:
super().__init__(
- id=id,
+ node_id=node_id,
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
@@ -282,18 +294,21 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
resolved_conditions: list[Condition] = []
for cond in conditions.conditions or []:
value = cond.value
+ resolved_value: str | Sequence[str] | int | float | None
if isinstance(value, str):
segment_group = variable_pool.convert_template(value)
if len(segment_group.value) == 1:
- resolved_value = segment_group.value[0].to_object()
+ resolved_value = _normalize_metadata_filter_scalar(segment_group.value[0].to_object())
else:
resolved_value = segment_group.text
elif isinstance(value, Sequence) and all(isinstance(v, str) for v in value):
- resolved_values = []
- for v in value: # type: ignore
+ resolved_values: list[str] = []
+ for v in value:
segment_group = variable_pool.convert_template(v)
if len(segment_group.value) == 1:
- resolved_values.append(segment_group.value[0].to_object())
+ resolved_values.append(
+ _normalize_metadata_filter_sequence_item(segment_group.value[0].to_object())
+ )
else:
resolved_values.append(segment_group.text)
resolved_value = resolved_values
diff --git a/api/events/event_handlers/create_document_index.py b/api/events/event_handlers/create_document_index.py
index b7e7a6e60f..0c535a1c5b 100644
--- a/api/events/event_handlers/create_document_index.py
+++ b/api/events/event_handlers/create_document_index.py
@@ -6,9 +6,9 @@ import click
from sqlalchemy import select
from werkzeug.exceptions import NotFound
+from core.db.session_factory import session_factory
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from events.document_index_event import document_index_created
-from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.dataset import Document
from models.enums import IndexingStatus
@@ -22,24 +22,25 @@ def handle(sender, **kwargs):
document_ids = kwargs.get("document_ids", [])
documents = []
start_at = time.perf_counter()
- for document_id in document_ids:
- logger.info(click.style(f"Start process document: {document_id}", fg="green"))
+ with session_factory.create_session() as session:
+ for document_id in document_ids:
+ logger.info(click.style(f"Start process document: {document_id}", fg="green"))
- document = db.session.scalar(
- select(Document).where(
- Document.id == document_id,
- Document.dataset_id == dataset_id,
+ document = session.scalar(
+ select(Document).where(
+ Document.id == document_id,
+ Document.dataset_id == dataset_id,
+ )
)
- )
- if not document:
- raise NotFound("Document not found")
+ if not document:
+ raise NotFound("Document not found")
- document.indexing_status = IndexingStatus.PARSING
- document.processing_started_at = naive_utc_now()
- documents.append(document)
- db.session.add(document)
- db.session.commit()
+ document.indexing_status = IndexingStatus.PARSING
+ document.processing_started_at = naive_utc_now()
+ documents.append(document)
+ session.add(document)
+ session.commit()
with contextlib.suppress(Exception):
try:
diff --git a/api/events/event_handlers/create_installed_app_when_app_created.py b/api/events/event_handlers/create_installed_app_when_app_created.py
index 57412cc4ad..38e102d5fd 100644
--- a/api/events/event_handlers/create_installed_app_when_app_created.py
+++ b/api/events/event_handlers/create_installed_app_when_app_created.py
@@ -1,5 +1,5 @@
+from core.db.session_factory import session_factory
from events.app_event import app_was_created
-from extensions.ext_database import db
from models.model import InstalledApp
@@ -12,5 +12,6 @@ def handle(sender, **kwargs):
app_id=app.id,
app_owner_tenant_id=app.tenant_id,
)
- db.session.add(installed_app)
- db.session.commit()
+ with session_factory.create_session() as session:
+ session.add(installed_app)
+ session.commit()
diff --git a/api/events/event_handlers/create_site_record_when_app_created.py b/api/events/event_handlers/create_site_record_when_app_created.py
index 84be592b1a..5e2a456dce 100644
--- a/api/events/event_handlers/create_site_record_when_app_created.py
+++ b/api/events/event_handlers/create_site_record_when_app_created.py
@@ -1,5 +1,5 @@
+from core.db.session_factory import session_factory
from events.app_event import app_was_created
-from extensions.ext_database import db
from models.enums import CustomizeTokenStrategy
from models.model import Site
@@ -22,6 +22,6 @@ def handle(sender, **kwargs):
created_by=app.created_by,
updated_by=app.updated_by,
)
-
- db.session.add(site)
- db.session.commit()
+ with session_factory.create_session() as session:
+ session.add(site)
+ session.commit()
diff --git a/api/factories/file_factory/builders.py b/api/factories/file_factory/builders.py
index 288d37d265..1d2ad4d445 100644
--- a/api/factories/file_factory/builders.py
+++ b/api/factories/file_factory/builders.py
@@ -10,8 +10,8 @@ from typing import Any
from sqlalchemy import select
from core.app.file_access import FileAccessControllerProtocol
+from core.db.session_factory import session_factory
from core.workflow.file_reference import build_file_reference
-from extensions.ext_database import db
from graphon.file import File, FileTransferMethod, FileType, FileUploadConfig, helpers, standardize_file_type
from models import ToolFile, UploadFile
@@ -135,29 +135,30 @@ def _build_from_local_file(
UploadFile.id == upload_file_id,
UploadFile.tenant_id == tenant_id,
)
- row = db.session.scalar(access_controller.apply_upload_file_filters(stmt))
- if row is None:
- raise ValueError("Invalid upload file")
+ with session_factory.create_session() as session:
+ row = session.scalar(access_controller.apply_upload_file_filters(stmt))
+ if row is None:
+ raise ValueError("Invalid upload file")
- detected_file_type = standardize_file_type(extension="." + row.extension, mime_type=row.mime_type)
- file_type = _resolve_file_type(
- detected_file_type=detected_file_type,
- specified_type=mapping.get("type", "custom"),
- strict_type_validation=strict_type_validation,
- )
+ detected_file_type = standardize_file_type(extension="." + row.extension, mime_type=row.mime_type)
+ file_type = _resolve_file_type(
+ detected_file_type=detected_file_type,
+ specified_type=mapping.get("type", "custom"),
+ strict_type_validation=strict_type_validation,
+ )
- return File(
- id=mapping.get("id"),
- filename=row.name,
- extension="." + row.extension,
- mime_type=row.mime_type,
- type=file_type,
- transfer_method=transfer_method,
- remote_url=row.source_url,
- reference=build_file_reference(record_id=str(row.id)),
- size=row.size,
- storage_key=row.key,
- )
+ return File(
+ file_id=mapping.get("id"),
+ filename=row.name,
+ extension="." + row.extension,
+ mime_type=row.mime_type,
+ file_type=file_type,
+ transfer_method=transfer_method,
+ remote_url=row.source_url,
+ reference=build_file_reference(record_id=str(row.id)),
+ size=row.size,
+ storage_key=row.key,
+ )
def _build_from_remote_url(
@@ -179,32 +180,33 @@ def _build_from_remote_url(
UploadFile.id == upload_file_id,
UploadFile.tenant_id == tenant_id,
)
- upload_file = db.session.scalar(access_controller.apply_upload_file_filters(stmt))
- if upload_file is None:
- raise ValueError("Invalid upload file")
+ with session_factory.create_session() as session:
+ upload_file = session.scalar(access_controller.apply_upload_file_filters(stmt))
+ if upload_file is None:
+ raise ValueError("Invalid upload file")
- detected_file_type = standardize_file_type(
- extension="." + upload_file.extension,
- mime_type=upload_file.mime_type,
- )
- file_type = _resolve_file_type(
- detected_file_type=detected_file_type,
- specified_type=mapping.get("type"),
- strict_type_validation=strict_type_validation,
- )
+ detected_file_type = standardize_file_type(
+ extension="." + upload_file.extension,
+ mime_type=upload_file.mime_type,
+ )
+ file_type = _resolve_file_type(
+ detected_file_type=detected_file_type,
+ specified_type=mapping.get("type"),
+ strict_type_validation=strict_type_validation,
+ )
- return File(
- id=mapping.get("id"),
- filename=upload_file.name,
- extension="." + upload_file.extension,
- mime_type=upload_file.mime_type,
- type=file_type,
- transfer_method=transfer_method,
- remote_url=helpers.get_signed_file_url(upload_file_id=str(upload_file_id)),
- reference=build_file_reference(record_id=str(upload_file.id)),
- size=upload_file.size,
- storage_key=upload_file.key,
- )
+ return File(
+ file_id=mapping.get("id"),
+ filename=upload_file.name,
+ extension="." + upload_file.extension,
+ mime_type=upload_file.mime_type,
+ file_type=file_type,
+ transfer_method=transfer_method,
+ remote_url=helpers.get_signed_file_url(upload_file_id=str(upload_file_id)),
+ reference=build_file_reference(record_id=str(upload_file.id)),
+ size=upload_file.size,
+ storage_key=upload_file.key,
+ )
url = mapping.get("url") or mapping.get("remote_url")
if not url:
@@ -220,9 +222,9 @@ def _build_from_remote_url(
)
return File(
- id=mapping.get("id"),
+ file_id=mapping.get("id"),
filename=filename,
- type=file_type,
+ file_type=file_type,
transfer_method=transfer_method,
remote_url=url,
mime_type=mime_type,
@@ -247,30 +249,31 @@ def _build_from_tool_file(
ToolFile.id == tool_file_id,
ToolFile.tenant_id == tenant_id,
)
- tool_file = db.session.scalar(access_controller.apply_tool_file_filters(stmt))
- if tool_file is None:
- raise ValueError(f"ToolFile {tool_file_id} not found")
+ with session_factory.create_session() as session:
+ tool_file = session.scalar(access_controller.apply_tool_file_filters(stmt))
+ if tool_file is None:
+ raise ValueError(f"ToolFile {tool_file_id} not found")
- extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin"
- detected_file_type = standardize_file_type(extension=extension, mime_type=tool_file.mimetype)
- file_type = _resolve_file_type(
- detected_file_type=detected_file_type,
- specified_type=mapping.get("type"),
- strict_type_validation=strict_type_validation,
- )
+ extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin"
+ detected_file_type = standardize_file_type(extension=extension, mime_type=tool_file.mimetype)
+ file_type = _resolve_file_type(
+ detected_file_type=detected_file_type,
+ specified_type=mapping.get("type"),
+ strict_type_validation=strict_type_validation,
+ )
- return File(
- id=mapping.get("id"),
- filename=tool_file.name,
- type=file_type,
- transfer_method=transfer_method,
- remote_url=tool_file.original_url,
- reference=build_file_reference(record_id=str(tool_file.id)),
- extension=extension,
- mime_type=tool_file.mimetype,
- size=tool_file.size,
- storage_key=tool_file.file_key,
- )
+ return File(
+ file_id=mapping.get("id"),
+ filename=tool_file.name,
+ file_type=file_type,
+ transfer_method=transfer_method,
+ remote_url=tool_file.original_url,
+ reference=build_file_reference(record_id=str(tool_file.id)),
+ extension=extension,
+ mime_type=tool_file.mimetype,
+ size=tool_file.size,
+ storage_key=tool_file.file_key,
+ )
def _build_from_datasource_file(
@@ -289,31 +292,32 @@ def _build_from_datasource_file(
UploadFile.id == datasource_file_id,
UploadFile.tenant_id == tenant_id,
)
- datasource_file = db.session.scalar(access_controller.apply_upload_file_filters(stmt))
- if datasource_file is None:
- raise ValueError(f"DatasourceFile {mapping.get('datasource_file_id')} not found")
+ with session_factory.create_session() as session:
+ datasource_file = session.scalar(access_controller.apply_upload_file_filters(stmt))
+ if datasource_file is None:
+ raise ValueError(f"DatasourceFile {mapping.get('datasource_file_id')} not found")
- extension = "." + datasource_file.key.split(".")[-1] if "." in datasource_file.key else ".bin"
- detected_file_type = standardize_file_type(extension="." + extension, mime_type=datasource_file.mime_type)
- file_type = _resolve_file_type(
- detected_file_type=detected_file_type,
- specified_type=mapping.get("type"),
- strict_type_validation=strict_type_validation,
- )
+ extension = "." + datasource_file.key.split(".")[-1] if "." in datasource_file.key else ".bin"
+ detected_file_type = standardize_file_type(extension=extension, mime_type=datasource_file.mime_type)
+ file_type = _resolve_file_type(
+ detected_file_type=detected_file_type,
+ specified_type=mapping.get("type"),
+ strict_type_validation=strict_type_validation,
+ )
- return File(
- id=mapping.get("datasource_file_id"),
- filename=datasource_file.name,
- type=file_type,
- transfer_method=FileTransferMethod.TOOL_FILE,
- remote_url=datasource_file.source_url,
- reference=build_file_reference(record_id=str(datasource_file.id)),
- extension=extension,
- mime_type=datasource_file.mime_type,
- size=datasource_file.size,
- storage_key=datasource_file.key,
- url=datasource_file.source_url,
- )
+ return File(
+ file_id=mapping.get("datasource_file_id"),
+ filename=datasource_file.name,
+ file_type=file_type,
+ transfer_method=FileTransferMethod.TOOL_FILE,
+ remote_url=datasource_file.source_url,
+ reference=build_file_reference(record_id=str(datasource_file.id)),
+ extension=extension,
+ mime_type=datasource_file.mime_type,
+ size=datasource_file.size,
+ storage_key=datasource_file.key,
+ url=datasource_file.source_url,
+ )
def _is_valid_mapping(mapping: Mapping[str, Any]) -> bool:
diff --git a/api/fields/_value_type_serializer.py b/api/fields/_value_type_serializer.py
index b5acbbbcb4..d518114777 100644
--- a/api/fields/_value_type_serializer.py
+++ b/api/fields/_value_type_serializer.py
@@ -10,9 +10,9 @@ class _VarTypedDict(TypedDict, total=False):
def serialize_value_type(v: _VarTypedDict | Segment) -> str:
if isinstance(v, Segment):
- return v.value_type.exposed_type().value
+ return str(v.value_type.exposed_type())
else:
value_type = v.get("value_type")
if value_type is None:
raise ValueError("value_type is required but not provided")
- return value_type.exposed_type().value
+ return str(value_type.exposed_type())
diff --git a/api/fields/conversation_variable_fields.py b/api/fields/conversation_variable_fields.py
index cf4a71d545..e4219ba1ee 100644
--- a/api/fields/conversation_variable_fields.py
+++ b/api/fields/conversation_variable_fields.py
@@ -57,10 +57,10 @@ class ConversationVariableResponse(ResponseModel):
def _normalize_value_type(cls, value: Any) -> str:
exposed_type = getattr(value, "exposed_type", None)
if callable(exposed_type):
- return str(exposed_type().value)
+ return str(exposed_type())
if isinstance(value, str):
try:
- return str(SegmentType(value).exposed_type().value)
+ return str(SegmentType(value).exposed_type())
except ValueError:
return value
try:
diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py
index f9b5e98936..6e947858ba 100644
--- a/api/fields/workflow_fields.py
+++ b/api/fields/workflow_fields.py
@@ -26,7 +26,7 @@ class EnvironmentVariableField(fields.Raw):
"id": value.id,
"name": value.name,
"value": value.value,
- "value_type": value.value_type.exposed_type().value,
+ "value_type": str(value.value_type.exposed_type()),
"description": value.description,
}
if isinstance(value, dict):
diff --git a/api/libs/flask_utils.py b/api/libs/flask_utils.py
index 52fc787c79..838af2bf32 100644
--- a/api/libs/flask_utils.py
+++ b/api/libs/flask_utils.py
@@ -1,5 +1,5 @@
import contextvars
-from collections.abc import Iterator
+from collections.abc import Generator
from contextlib import contextmanager
from typing import TYPE_CHECKING
@@ -13,7 +13,7 @@ if TYPE_CHECKING:
def preserve_flask_contexts(
flask_app: Flask,
context_vars: contextvars.Context,
-) -> Iterator[None]:
+) -> Generator[None, None, None]:
"""
A context manager that handles:
1. flask-login's UserProxy copy
diff --git a/api/libs/oauth_data_source.py b/api/libs/oauth_data_source.py
index 9b53918f24..934aacb45b 100644
--- a/api/libs/oauth_data_source.py
+++ b/api/libs/oauth_data_source.py
@@ -6,8 +6,8 @@ from flask_login import current_user
from pydantic import TypeAdapter
from sqlalchemy import select
+from core.db.session_factory import session_factory
from core.helper.http_client_pooling import get_pooled_http_client
-from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.source import DataSourceOauthBinding
@@ -95,27 +95,28 @@ class NotionOAuth(OAuthDataSource):
pages=pages,
)
# save data source binding
- data_source_binding = db.session.scalar(
- select(DataSourceOauthBinding).where(
- DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
- DataSourceOauthBinding.provider == "notion",
- DataSourceOauthBinding.access_token == access_token,
+ with session_factory.create_session() as session:
+ data_source_binding = session.scalar(
+ select(DataSourceOauthBinding).where(
+ DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
+ DataSourceOauthBinding.provider == "notion",
+ DataSourceOauthBinding.access_token == access_token,
+ )
)
- )
- if data_source_binding:
- data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info)
- data_source_binding.disabled = False
- data_source_binding.updated_at = naive_utc_now()
- db.session.commit()
- else:
- new_data_source_binding = DataSourceOauthBinding(
- tenant_id=current_user.current_tenant_id,
- access_token=access_token,
- source_info=SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info),
- provider="notion",
- )
- db.session.add(new_data_source_binding)
- db.session.commit()
+ if data_source_binding:
+ data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info)
+ data_source_binding.disabled = False
+ data_source_binding.updated_at = naive_utc_now()
+ session.commit()
+ else:
+ new_data_source_binding = DataSourceOauthBinding(
+ tenant_id=current_user.current_tenant_id,
+ access_token=access_token,
+ source_info=SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info),
+ provider="notion",
+ )
+ session.add(new_data_source_binding)
+ session.commit()
def save_internal_access_token(self, access_token: str) -> None:
workspace_name = self.notion_workspace_name(access_token)
@@ -130,55 +131,57 @@ class NotionOAuth(OAuthDataSource):
pages=pages,
)
# save data source binding
- data_source_binding = db.session.scalar(
- select(DataSourceOauthBinding).where(
- DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
- DataSourceOauthBinding.provider == "notion",
- DataSourceOauthBinding.access_token == access_token,
+ with session_factory.create_session() as session:
+ data_source_binding = session.scalar(
+ select(DataSourceOauthBinding).where(
+ DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
+ DataSourceOauthBinding.provider == "notion",
+ DataSourceOauthBinding.access_token == access_token,
+ )
)
- )
- if data_source_binding:
- data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info)
- data_source_binding.disabled = False
- data_source_binding.updated_at = naive_utc_now()
- db.session.commit()
- else:
- new_data_source_binding = DataSourceOauthBinding(
- tenant_id=current_user.current_tenant_id,
- access_token=access_token,
- source_info=SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info),
- provider="notion",
- )
- db.session.add(new_data_source_binding)
- db.session.commit()
+ if data_source_binding:
+ data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info)
+ data_source_binding.disabled = False
+ data_source_binding.updated_at = naive_utc_now()
+ session.commit()
+ else:
+ new_data_source_binding = DataSourceOauthBinding(
+ tenant_id=current_user.current_tenant_id,
+ access_token=access_token,
+ source_info=SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info),
+ provider="notion",
+ )
+ session.add(new_data_source_binding)
+ session.commit()
def sync_data_source(self, binding_id: str) -> None:
# save data source binding
- data_source_binding = db.session.scalar(
- select(DataSourceOauthBinding).where(
- DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
- DataSourceOauthBinding.provider == "notion",
- DataSourceOauthBinding.id == binding_id,
- DataSourceOauthBinding.disabled == False,
+ with session_factory.create_session() as session:
+ data_source_binding = session.scalar(
+ select(DataSourceOauthBinding).where(
+ DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
+ DataSourceOauthBinding.provider == "notion",
+ DataSourceOauthBinding.id == binding_id,
+ DataSourceOauthBinding.disabled == False,
+ )
)
- )
- if data_source_binding:
- # get all authorized pages
- pages = self.get_authorized_pages(data_source_binding.access_token)
- source_info = NOTION_SOURCE_INFO_ADAPTER.validate_python(data_source_binding.source_info)
- new_source_info = self._build_source_info(
- workspace_name=source_info["workspace_name"],
- workspace_icon=source_info["workspace_icon"],
- workspace_id=source_info["workspace_id"],
- pages=pages,
- )
- data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(new_source_info)
- data_source_binding.disabled = False
- data_source_binding.updated_at = naive_utc_now()
- db.session.commit()
- else:
- raise ValueError("Data source binding not found")
+ if data_source_binding:
+ # get all authorized pages
+ pages = self.get_authorized_pages(data_source_binding.access_token)
+ source_info = NOTION_SOURCE_INFO_ADAPTER.validate_python(data_source_binding.source_info)
+ new_source_info = self._build_source_info(
+ workspace_name=source_info["workspace_name"],
+ workspace_icon=source_info["workspace_icon"],
+ workspace_id=source_info["workspace_id"],
+ pages=pages,
+ )
+ data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(new_source_info)
+ data_source_binding.disabled = False
+ data_source_binding.updated_at = naive_utc_now()
+ session.commit()
+ else:
+ raise ValueError("Data source binding not found")
def get_authorized_pages(self, access_token: str) -> list[NotionPageSummary]:
pages: list[NotionPageSummary] = []
diff --git a/api/libs/token.py b/api/libs/token.py
index a34db70764..5b043465ac 100644
--- a/api/libs/token.py
+++ b/api/libs/token.py
@@ -47,23 +47,17 @@ def _cookie_domain() -> str | None:
def _real_cookie_name(cookie_name: str) -> str:
if is_secure() and _cookie_domain() is None:
return "__Host-" + cookie_name
- else:
- return cookie_name
+ return cookie_name
def _try_extract_from_header(request: Request) -> str | None:
auth_header = request.headers.get("Authorization")
- if auth_header:
- if " " not in auth_header:
- return None
- else:
- auth_scheme, auth_token = auth_header.split(None, 1)
- auth_scheme = auth_scheme.lower()
- if auth_scheme != "bearer":
- return None
- else:
- return auth_token
- return None
+ if not auth_header or " " not in auth_header:
+ return None
+ auth_scheme, auth_token = auth_header.split(None, 1)
+ if auth_scheme.lower() != "bearer":
+ return None
+ return auth_token
def extract_refresh_token(request: Request) -> str | None:
@@ -90,14 +84,9 @@ def extract_webapp_access_token(request: Request) -> str | None:
def extract_webapp_passport(app_code: str, request: Request) -> str | None:
- def _try_extract_passport_token_from_cookie(request: Request) -> str | None:
- return request.cookies.get(_real_cookie_name(COOKIE_NAME_PASSPORT + "-" + app_code))
-
- def _try_extract_passport_token_from_header(request: Request) -> str | None:
- return request.headers.get(HEADER_NAME_PASSPORT)
-
- ret = _try_extract_passport_token_from_cookie(request) or _try_extract_passport_token_from_header(request)
- return ret
+ return request.cookies.get(_real_cookie_name(COOKIE_NAME_PASSPORT + "-" + app_code)) or request.headers.get(
+ HEADER_NAME_PASSPORT
+ )
def set_access_token_to_cookie(request: Request, response: Response, token: str, samesite: str = "Lax"):
@@ -209,22 +198,18 @@ def check_csrf_token(request: Request, user_id: str):
if not csrf_token:
_unauthorized()
- verified = {}
try:
verified = PassportService().verify(csrf_token)
- except:
+ except Exception:
_unauthorized()
+ raise # unreachable, but helps the type checker see verified is always bound
if verified.get("sub") != user_id:
_unauthorized()
exp: int | None = verified.get("exp")
- if not exp:
+ if not exp or exp < int(datetime.now(UTC).timestamp()):
_unauthorized()
- else:
- time_now = int(datetime.now().timestamp())
- if exp < time_now:
- _unauthorized()
def generate_csrf_token(user_id: str) -> str:
diff --git a/api/libs/url_utils.py b/api/libs/url_utils.py
new file mode 100644
index 0000000000..adcac3add0
--- /dev/null
+++ b/api/libs/url_utils.py
@@ -0,0 +1,3 @@
+def normalize_api_base_url(base_url: str) -> str:
+ """Normalize a base URL to always end with /v1, avoiding double /v1 suffixes."""
+ return base_url.rstrip("/").removesuffix("/v1").rstrip("/") + "/v1"
diff --git a/api/models/comment.py b/api/models/comment.py
index 308339e6f6..1154e16788 100644
--- a/api/models/comment.py
+++ b/api/models/comment.py
@@ -3,6 +3,7 @@
from datetime import datetime
from typing import Optional
+import sqlalchemy as sa
from sqlalchemy import Index, func
from sqlalchemy.orm import Mapped, mapped_column, relationship
@@ -36,24 +37,24 @@ class WorkflowComment(Base):
__tablename__ = "workflow_comments"
__table_args__ = (
- db.PrimaryKeyConstraint("id", name="workflow_comments_pkey"),
+ sa.PrimaryKeyConstraint("id", name="workflow_comments_pkey"),
Index("workflow_comments_app_idx", "tenant_id", "app_id"),
Index("workflow_comments_created_at_idx", "created_at"),
)
- id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuidv7()"))
+ id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuidv7()"))
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
- position_x: Mapped[float] = mapped_column(db.Float)
- position_y: Mapped[float] = mapped_column(db.Float)
- content: Mapped[str] = mapped_column(db.Text, nullable=False)
+ position_x: Mapped[float] = mapped_column(sa.Float)
+ position_y: Mapped[float] = mapped_column(sa.Float)
+ content: Mapped[str] = mapped_column(sa.Text, nullable=False)
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
- created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(
- db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
+ sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
- resolved: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
- resolved_at: Mapped[datetime | None] = mapped_column(db.DateTime)
+ resolved: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
+ resolved_at: Mapped[datetime | None] = mapped_column(sa.DateTime)
resolved_by: Mapped[str | None] = mapped_column(StringUUID)
# Relationships
@@ -143,20 +144,20 @@ class WorkflowCommentReply(Base):
__tablename__ = "workflow_comment_replies"
__table_args__ = (
- db.PrimaryKeyConstraint("id", name="workflow_comment_replies_pkey"),
+ sa.PrimaryKeyConstraint("id", name="workflow_comment_replies_pkey"),
Index("comment_replies_comment_idx", "comment_id"),
Index("comment_replies_created_at_idx", "created_at"),
)
- id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuidv7()"))
+ id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuidv7()"))
comment_id: Mapped[str] = mapped_column(
- StringUUID, db.ForeignKey("workflow_comments.id", ondelete="CASCADE"), nullable=False
+ StringUUID, sa.ForeignKey("workflow_comments.id", ondelete="CASCADE"), nullable=False
)
- content: Mapped[str] = mapped_column(db.Text, nullable=False)
+ content: Mapped[str] = mapped_column(sa.Text, nullable=False)
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
- created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(
- db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
+ sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
# Relationships
comment: Mapped["WorkflowComment"] = relationship("WorkflowComment", back_populates="replies")
@@ -187,18 +188,18 @@ class WorkflowCommentMention(Base):
__tablename__ = "workflow_comment_mentions"
__table_args__ = (
- db.PrimaryKeyConstraint("id", name="workflow_comment_mentions_pkey"),
+ sa.PrimaryKeyConstraint("id", name="workflow_comment_mentions_pkey"),
Index("comment_mentions_comment_idx", "comment_id"),
Index("comment_mentions_reply_idx", "reply_id"),
Index("comment_mentions_user_idx", "mentioned_user_id"),
)
- id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuidv7()"))
+ id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuidv7()"))
comment_id: Mapped[str] = mapped_column(
- StringUUID, db.ForeignKey("workflow_comments.id", ondelete="CASCADE"), nullable=False
+ StringUUID, sa.ForeignKey("workflow_comments.id", ondelete="CASCADE"), nullable=False
)
reply_id: Mapped[str | None] = mapped_column(
- StringUUID, db.ForeignKey("workflow_comment_replies.id", ondelete="CASCADE"), nullable=True
+ StringUUID, sa.ForeignKey("workflow_comment_replies.id", ondelete="CASCADE"), nullable=True
)
mentioned_user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
diff --git a/api/models/dataset.py b/api/models/dataset.py
index 50301dd2d7..eee5c39a0e 100644
--- a/api/models/dataset.py
+++ b/api/models/dataset.py
@@ -1715,7 +1715,7 @@ class SegmentAttachmentBinding(TypeBase):
)
-class DocumentSegmentSummary(Base):
+class DocumentSegmentSummary(TypeBase):
__tablename__ = "document_segment_summaries"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="document_segment_summaries_pkey"),
@@ -1725,25 +1725,40 @@ class DocumentSegmentSummary(Base):
sa.Index("document_segment_summaries_status_idx", "status"),
)
- id: Mapped[str] = mapped_column(StringUUID, nullable=False, default=lambda: str(uuid4()))
+ id: Mapped[str] = mapped_column(
+ StringUUID,
+ nullable=False,
+ insert_default=lambda: str(uuid4()),
+ default_factory=lambda: str(uuid4()),
+ init=False,
+ )
dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
document_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
# corresponds to DocumentSegment.id or parent chunk id
chunk_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
- summary_content: Mapped[str] = mapped_column(LongText, nullable=True)
- summary_index_node_id: Mapped[str] = mapped_column(String(255), nullable=True)
- summary_index_node_hash: Mapped[str] = mapped_column(String(255), nullable=True)
- tokens: Mapped[int | None] = mapped_column(sa.Integer, nullable=True)
- status: Mapped[str] = mapped_column(
- EnumText(SummaryStatus, length=32), nullable=False, server_default=sa.text("'generating'")
+ summary_content: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
+ summary_index_node_id: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None)
+ summary_index_node_hash: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None)
+ tokens: Mapped[int | None] = mapped_column(sa.Integer, nullable=True, default=None)
+ status: Mapped[SummaryStatus] = mapped_column(
+ EnumText(SummaryStatus, length=32),
+ nullable=False,
+ server_default=sa.text("'generating'"),
+ default=SummaryStatus.GENERATING,
+ )
+ error: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
+ enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"), default=True)
+ disabled_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, default=None)
+ disabled_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
+ created_at: Mapped[datetime] = mapped_column(
+ DateTime, nullable=False, server_default=func.current_timestamp(), init=False
)
- error: Mapped[str] = mapped_column(LongText, nullable=True)
- enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
- disabled_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
- disabled_by = mapped_column(StringUUID, nullable=True)
- created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(
- DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
+ DateTime,
+ nullable=False,
+ server_default=func.current_timestamp(),
+ onupdate=func.current_timestamp(),
+ init=False,
)
def __repr__(self):
diff --git a/api/models/human_input.py b/api/models/human_input.py
index b4c7a634b6..7447d3efcb 100644
--- a/api/models/human_input.py
+++ b/api/models/human_input.py
@@ -6,7 +6,7 @@ import sqlalchemy as sa
from pydantic import BaseModel, Field
from sqlalchemy.orm import Mapped, mapped_column, relationship
-from core.workflow.human_input_compat import DeliveryMethodType
+from core.workflow.human_input_adapter import DeliveryMethodType
from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus
from libs.helper import generate_string
diff --git a/api/models/model.py b/api/models/model.py
index 7fe0731098..a1117fc43a 100644
--- a/api/models/model.py
+++ b/api/models/model.py
@@ -25,6 +25,7 @@ from graphon.enums import WorkflowExecutionStatus
from graphon.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType
from graphon.file import helpers as file_helpers
from libs.helper import generate_string # type: ignore[import-not-found]
+from libs.url_utils import normalize_api_base_url
from libs.uuid_utils import uuidv7
from models.utils.file_input_compat import build_file_from_input_mapping
@@ -446,7 +447,8 @@ class App(Base):
@property
def api_base_url(self) -> str:
- return (dify_config.SERVICE_API_URL or request.host_url.rstrip("/")) + "/v1"
+ base = dify_config.SERVICE_API_URL or request.host_url.rstrip("/")
+ return normalize_api_base_url(base)
@property
def tenant(self) -> Tenant | None:
diff --git a/api/models/utils/file_input_compat.py b/api/models/utils/file_input_compat.py
index a2dc8f6157..77dcbd13d4 100644
--- a/api/models/utils/file_input_compat.py
+++ b/api/models/utils/file_input_compat.py
@@ -5,7 +5,8 @@ from functools import lru_cache
from typing import Any
from core.workflow.file_reference import parse_file_reference
-from graphon.file import File, FileTransferMethod
+from graphon.file import File, FileTransferMethod, FileType
+from graphon.file.constants import FILE_MODEL_IDENTITY, maybe_file_object
@lru_cache(maxsize=1)
@@ -43,6 +44,124 @@ def resolve_file_mapping_tenant_id(
return tenant_resolver()
+def build_file_from_mapping_without_lookup(*, file_mapping: Mapping[str, Any]) -> File:
+ """Build a graph `File` directly from serialized metadata."""
+
+ def _coerce_file_type(value: Any) -> FileType:
+ if isinstance(value, FileType):
+ return value
+ if isinstance(value, str):
+ return FileType.value_of(value)
+ raise ValueError("file type is required in file mapping")
+
+ mapping = dict(file_mapping)
+ transfer_method_value = mapping.get("transfer_method")
+ if isinstance(transfer_method_value, FileTransferMethod):
+ transfer_method = transfer_method_value
+ elif isinstance(transfer_method_value, str):
+ transfer_method = FileTransferMethod.value_of(transfer_method_value)
+ else:
+ raise ValueError("transfer_method is required in file mapping")
+
+ file_id = mapping.get("file_id")
+ if not isinstance(file_id, str) or not file_id:
+ legacy_id = mapping.get("id")
+ file_id = legacy_id if isinstance(legacy_id, str) and legacy_id else None
+
+ related_id = resolve_file_record_id(mapping)
+ if related_id is None:
+ raw_related_id = mapping.get("related_id")
+ related_id = raw_related_id if isinstance(raw_related_id, str) and raw_related_id else None
+
+ remote_url = mapping.get("remote_url")
+ if not isinstance(remote_url, str) or not remote_url:
+ url = mapping.get("url")
+ remote_url = url if isinstance(url, str) and url else None
+
+ reference = mapping.get("reference")
+ if not isinstance(reference, str) or not reference:
+ reference = None
+
+ filename = mapping.get("filename")
+ if not isinstance(filename, str):
+ filename = None
+
+ extension = mapping.get("extension")
+ if not isinstance(extension, str):
+ extension = None
+
+ mime_type = mapping.get("mime_type")
+ if not isinstance(mime_type, str):
+ mime_type = None
+
+ size = mapping.get("size", -1)
+ if not isinstance(size, int):
+ size = -1
+
+ storage_key = mapping.get("storage_key")
+ if not isinstance(storage_key, str):
+ storage_key = None
+
+ tenant_id = mapping.get("tenant_id")
+ if not isinstance(tenant_id, str):
+ tenant_id = None
+
+ dify_model_identity = mapping.get("dify_model_identity")
+ if not isinstance(dify_model_identity, str):
+ dify_model_identity = FILE_MODEL_IDENTITY
+
+ tool_file_id = mapping.get("tool_file_id")
+ if not isinstance(tool_file_id, str):
+ tool_file_id = None
+
+ upload_file_id = mapping.get("upload_file_id")
+ if not isinstance(upload_file_id, str):
+ upload_file_id = None
+
+ datasource_file_id = mapping.get("datasource_file_id")
+ if not isinstance(datasource_file_id, str):
+ datasource_file_id = None
+
+ return File(
+ file_id=file_id,
+ tenant_id=tenant_id,
+ file_type=_coerce_file_type(mapping.get("file_type", mapping.get("type"))),
+ transfer_method=transfer_method,
+ remote_url=remote_url,
+ reference=reference,
+ related_id=related_id,
+ filename=filename,
+ extension=extension,
+ mime_type=mime_type,
+ size=size,
+ storage_key=storage_key,
+ dify_model_identity=dify_model_identity,
+ url=remote_url,
+ tool_file_id=tool_file_id,
+ upload_file_id=upload_file_id,
+ datasource_file_id=datasource_file_id,
+ )
+
+
+def rebuild_serialized_graph_files_without_lookup(value: Any) -> Any:
+ """Recursively rebuild serialized graph file payloads into `File` objects.
+
+ `graphon` 0.2.2 no longer accepts legacy serialized file mappings via
+ `model_validate_json()`. Dify keeps this recovery path at the model boundary
+ so historical JSON blobs remain readable without reintroducing global graph
+ patches or test-local coercion.
+ """
+ if isinstance(value, list):
+ return [rebuild_serialized_graph_files_without_lookup(item) for item in value]
+
+ if isinstance(value, dict):
+ if maybe_file_object(value):
+ return build_file_from_mapping_without_lookup(file_mapping=value)
+ return {key: rebuild_serialized_graph_files_without_lookup(item) for key, item in value.items()}
+
+ return value
+
+
def build_file_from_stored_mapping(
*,
file_mapping: Mapping[str, Any],
@@ -76,12 +195,7 @@ def build_file_from_stored_mapping(
pass
if transfer_method == FileTransferMethod.REMOTE_URL and record_id is None:
- remote_url = mapping.get("remote_url")
- if not isinstance(remote_url, str) or not remote_url:
- url = mapping.get("url")
- if isinstance(url, str) and url:
- mapping["remote_url"] = url
- return File.model_validate(mapping)
+ return build_file_from_mapping_without_lookup(file_mapping=mapping)
return file_factory.build_from_mapping(
mapping=mapping,
diff --git a/api/models/workflow.py b/api/models/workflow.py
index dfda03c2ee..d127244b0f 100644
--- a/api/models/workflow.py
+++ b/api/models/workflow.py
@@ -24,7 +24,7 @@ from sqlalchemy.orm import Mapped, mapped_column
from typing_extensions import deprecated
from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE
-from core.workflow.human_input_compat import normalize_node_config_for_graph
+from core.workflow.human_input_adapter import adapt_node_config_for_graph
from core.workflow.variable_prefixes import (
CONVERSATION_VARIABLE_NODE_ID,
SYSTEM_VARIABLE_NODE_ID,
@@ -64,7 +64,10 @@ from .base import Base, DefaultFieldsDCMixin, TypeBase
from .engine import db
from .enums import CreatorUserRole, DraftVariableType, ExecutionOffLoadType, WorkflowRunTriggeredFrom
from .types import EnumText, LongText, StringUUID
-from .utils.file_input_compat import build_file_from_stored_mapping
+from .utils.file_input_compat import (
+ build_file_from_mapping_without_lookup,
+ build_file_from_stored_mapping,
+)
logger = logging.getLogger(__name__)
@@ -290,7 +293,7 @@ class Workflow(Base): # bug
node_config: dict[str, Any] = next(filter(lambda node: node["id"] == node_id, nodes))
except StopIteration:
raise NodeNotFoundError(node_id)
- return NodeConfigDictAdapter.validate_python(normalize_node_config_for_graph(node_config))
+ return NodeConfigDictAdapter.validate_python(adapt_node_config_for_graph(node_config))
@staticmethod
def get_node_type_from_node_config(node_config: NodeConfigDict) -> NodeType:
@@ -1688,7 +1691,7 @@ class WorkflowDraftVariable(Base):
return cast(Any, value)
normalized_file = dict(value)
normalized_file.pop("tenant_id", None)
- return File.model_validate(normalized_file)
+ return build_file_from_mapping_without_lookup(file_mapping=normalized_file)
elif isinstance(value, list) and value:
value_list = cast(list[Any], value)
first: Any = value_list[0]
@@ -1698,7 +1701,7 @@ class WorkflowDraftVariable(Base):
for item in value_list:
normalized_file = dict(cast(dict[str, Any], item))
normalized_file.pop("tenant_id", None)
- file_list.append(File.model_validate(normalized_file))
+ file_list.append(build_file_from_mapping_without_lookup(file_mapping=normalized_file))
return cast(Any, file_list)
else:
return cast(Any, value)
diff --git a/api/providers/README.md b/api/providers/README.md
index a00ec8bc52..5d5e6db9af 100644
--- a/api/providers/README.md
+++ b/api/providers/README.md
@@ -10,3 +10,6 @@ This directory holds **optional workspace packages** that plug into Dify’s API
 Provider tests often live next to the package, e.g. `providers/<group>/<provider>/tests/unit_tests/`. Shared fixtures may live under `providers/` (e.g. `conftest.py`).
+## Excluding Providers
+
+To build with only selected providers, use `--no-group vdb-all` and `--no-group trace-all` to disable the default bundles, then use `--group vdb-<provider>` and `--group trace-<provider>` to enable specific providers.
diff --git a/api/providers/trace/README.md b/api/providers/trace/README.md
new file mode 100644
index 0000000000..a7ffa5ed26
--- /dev/null
+++ b/api/providers/trace/README.md
@@ -0,0 +1,78 @@
+# Trace providers
+
+This directory holds **optional workspace packages** that send Dify **ops tracing** data (workflows, messages, tools, moderation, etc.) to an external observability backend (Langfuse, LangSmith, OpenTelemetry-style exporters, and others).
+
+Unlike VDB providers, trace plugins are **not** discovered via entry points. The API core imports your package **explicitly** from `core/ops/ops_trace_manager.py` after you register the provider id and mapping.
+
+## Architecture
+
+| Layer | Location | Role |
+|--------|----------|------|
+| Contracts | `api/core/ops/base_trace_instance.py`, `api/core/ops/entities/trace_entity.py`, `api/core/ops/entities/config_entity.py` | `BaseTraceInstance`, `BaseTracingConfig`, and typed `*TraceInfo` payloads |
+| Registry | `api/core/ops/ops_trace_manager.py` | `TracingProviderEnum`, `OpsTraceProviderConfigMap` — maps provider **string** → config class, encrypted keys, and trace class |
+| Your package | `api/providers/trace/trace-<provider>/` | Pydantic config + subclass of `BaseTraceInstance` |
+
+At runtime, `OpsTraceManager` decrypts stored credentials, builds your config model, caches a trace instance, and calls `trace(trace_info)` with a concrete `BaseTraceInfo` subtype.
+
+## What you implement
+
+### 1. Config model (`BaseTracingConfig`)
+
+Subclass `BaseTracingConfig` from `core.ops.entities.config_entity`. Use Pydantic validators; reuse helpers from `core.ops.utils` (for example `validate_url`, `validate_url_with_path`, `validate_project_name`) where appropriate.
+
+Fields fall into two groups used by the manager:
+
+- **`secret_keys`** — names of fields that are **encrypted at rest** (API keys, tokens, passwords).
+- **`other_keys`** — non-secret connection settings (hosts, project names, endpoints).
+
+List these key names in your `OpsTraceProviderConfigMap` entry so encrypt/decrypt and merge logic stay correct.
+
+### 2. Trace instance (`BaseTraceInstance`)
+
+Subclass `BaseTraceInstance` and implement:
+
+```python
+def trace(self, trace_info: BaseTraceInfo) -> None:
+ ...
+```
+
+Dispatch on the concrete type with `isinstance` (see `trace_langfuse` or `trace_langsmith` for full patterns). Payload types are defined in `core/ops/entities/trace_entity.py`, including:
+
+- `WorkflowTraceInfo`, `WorkflowNodeTraceInfo`, `DraftNodeExecutionTrace`
+- `MessageTraceInfo`, `ToolTraceInfo`, `ModerationTraceInfo`, `SuggestedQuestionTraceInfo`
+- `DatasetRetrievalTraceInfo`, `GenerateNameTraceInfo`, `PromptGenerationTraceInfo`
+
+You may ignore categories your backend does not support; existing providers often no-op unhandled types.
+
+Optional: use `get_service_account_with_tenant(app_id)` from the base class when you need tenant-scoped account context.
+
+### 3. Register in the API core
+
+Upstream changes are required so Dify knows your provider exists:
+
+1. **`TracingProviderEnum`** (`api/core/ops/entities/config_entity.py`) — add a new member whose **value** is the stable string stored in app tracing config (e.g. `"mybackend"`).
+2. **`OpsTraceProviderConfigMap.__getitem__`** (`api/core/ops/ops_trace_manager.py`) — add a `match` case for that enum member returning:
+ - `config_class`: your Pydantic config type
+ - `secret_keys` / `other_keys`: lists of field names as above
+ - `trace_instance`: your `BaseTraceInstance` subclass
+ Lazy-import your package inside the case so missing optional installs raise a clear `ImportError`.
+
+If the `match` case is missing, the provider string will not resolve and tracing will be disabled for that app.
+
+## Package layout
+
+Each provider is a normal uv workspace member, for example:
+
+- `api/providers/trace/trace-<provider>/pyproject.toml` — project name `dify-trace-<provider>`, dependencies on vendor SDKs
+- `api/providers/trace/trace-<provider>/src/dify_trace_<provider>/` — `config.py`, `_trace.py`, optional `entities/`, and an empty **`py.typed`** file (PEP 561) so the API type checker treats the package as typed; list `py.typed` under `[tool.setuptools.package-data]` for that import name in `pyproject.toml`.
+
+Reference implementations: `trace-langfuse/`, `trace-langsmith/`, `trace-opik/`.
+
+## Wiring into the `api` workspace
+
+In `api/pyproject.toml`:
+
+1. **`[tool.uv.sources]`** — `dify-trace-<provider> = { workspace = true }`
+2. **`[dependency-groups]`** — add `trace-<provider> = ["dify-trace-<provider>"]` and include `dify-trace-<provider>` in `trace-all` if it should ship with the default bundle
+
+After changing metadata, run **`uv sync`** from `api/`.
diff --git a/api/providers/trace/trace-aliyun/pyproject.toml b/api/providers/trace/trace-aliyun/pyproject.toml
new file mode 100644
index 0000000000..bcef7e9fb1
--- /dev/null
+++ b/api/providers/trace/trace-aliyun/pyproject.toml
@@ -0,0 +1,14 @@
+[project]
+name = "dify-trace-aliyun"
+version = "0.0.1"
+dependencies = [
+ # versions inherited from parent
+ "opentelemetry-api",
+ "opentelemetry-exporter-otlp-proto-grpc",
+ "opentelemetry-sdk",
+ "opentelemetry-semantic-conventions",
+]
+description = "Dify ops tracing provider (Aliyun)."
+
+[tool.setuptools.packages.find]
+where = ["src"]
diff --git a/api/core/ops/aliyun_trace/__init__.py b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/__init__.py
similarity index 100%
rename from api/core/ops/aliyun_trace/__init__.py
rename to api/providers/trace/trace-aliyun/src/dify_trace_aliyun/__init__.py
diff --git a/api/core/ops/aliyun_trace/aliyun_trace.py b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/aliyun_trace.py
similarity index 98%
rename from api/core/ops/aliyun_trace/aliyun_trace.py
rename to api/providers/trace/trace-aliyun/src/dify_trace_aliyun/aliyun_trace.py
index 76e81242f4..54d2f8167f 100644
--- a/api/core/ops/aliyun_trace/aliyun_trace.py
+++ b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/aliyun_trace.py
@@ -4,7 +4,20 @@ from collections.abc import Sequence
from opentelemetry.trace import SpanKind
from sqlalchemy.orm import sessionmaker
-from core.ops.aliyun_trace.data_exporter.traceclient import (
+from core.ops.base_trace_instance import BaseTraceInstance
+from core.ops.entities.trace_entity import (
+ BaseTraceInfo,
+ DatasetRetrievalTraceInfo,
+ GenerateNameTraceInfo,
+ MessageTraceInfo,
+ ModerationTraceInfo,
+ SuggestedQuestionTraceInfo,
+ ToolTraceInfo,
+ WorkflowTraceInfo,
+)
+from core.repositories import DifyCoreRepositoryFactory
+from dify_trace_aliyun.config import AliyunConfig
+from dify_trace_aliyun.data_exporter.traceclient import (
TraceClient,
build_endpoint,
convert_datetime_to_nanoseconds,
@@ -12,8 +25,8 @@ from core.ops.aliyun_trace.data_exporter.traceclient import (
convert_to_trace_id,
generate_span_id,
)
-from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData, TraceMetadata
-from core.ops.aliyun_trace.entities.semconv import (
+from dify_trace_aliyun.entities.aliyun_trace_entity import SpanData, TraceMetadata
+from dify_trace_aliyun.entities.semconv import (
DIFY_APP_ID,
GEN_AI_COMPLETION,
GEN_AI_INPUT_MESSAGE,
@@ -32,7 +45,7 @@ from core.ops.aliyun_trace.entities.semconv import (
TOOL_PARAMETERS,
GenAISpanKind,
)
-from core.ops.aliyun_trace.utils import (
+from dify_trace_aliyun.utils import (
create_common_span_attributes,
create_links_from_trace_id,
create_status_from_error,
@@ -44,19 +57,6 @@ from core.ops.aliyun_trace.utils import (
get_workflow_node_status,
serialize_json_data,
)
-from core.ops.base_trace_instance import BaseTraceInstance
-from core.ops.entities.config_entity import AliyunConfig
-from core.ops.entities.trace_entity import (
- BaseTraceInfo,
- DatasetRetrievalTraceInfo,
- GenerateNameTraceInfo,
- MessageTraceInfo,
- ModerationTraceInfo,
- SuggestedQuestionTraceInfo,
- ToolTraceInfo,
- WorkflowTraceInfo,
-)
-from core.repositories import DifyCoreRepositoryFactory
from extensions.ext_database import db
from graphon.entities import WorkflowNodeExecution
from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey
diff --git a/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/config.py b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/config.py
new file mode 100644
index 0000000000..e0133e6cc9
--- /dev/null
+++ b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/config.py
@@ -0,0 +1,32 @@
+from pydantic import ValidationInfo, field_validator
+
+from core.ops.entities.config_entity import BaseTracingConfig
+from core.ops.utils import validate_url_with_path
+
+
+class AliyunConfig(BaseTracingConfig):
+ """
+ Model class for Aliyun tracing config.
+ """
+
+ app_name: str = "dify_app"
+ license_key: str
+ endpoint: str
+
+ @field_validator("app_name")
+ @classmethod
+ def app_name_validator(cls, v, info: ValidationInfo):
+ return cls.validate_project_field(v, "dify_app")
+
+ @field_validator("license_key")
+ @classmethod
+ def license_key_validator(cls, v, info: ValidationInfo):
+ if not v or v.strip() == "":
+ raise ValueError("License key cannot be empty")
+ return v
+
+ @field_validator("endpoint")
+ @classmethod
+ def endpoint_validator(cls, v, info: ValidationInfo):
+ # aliyun uses two URL formats, which may include a URL path
+ return validate_url_with_path(v, "https://tracing-analysis-dc-hz.aliyuncs.com")
diff --git a/api/core/ops/aliyun_trace/data_exporter/__init__.py b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/data_exporter/__init__.py
similarity index 100%
rename from api/core/ops/aliyun_trace/data_exporter/__init__.py
rename to api/providers/trace/trace-aliyun/src/dify_trace_aliyun/data_exporter/__init__.py
diff --git a/api/core/ops/aliyun_trace/data_exporter/traceclient.py b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/data_exporter/traceclient.py
similarity index 98%
rename from api/core/ops/aliyun_trace/data_exporter/traceclient.py
rename to api/providers/trace/trace-aliyun/src/dify_trace_aliyun/data_exporter/traceclient.py
index 67d5163b0f..00aab6bf89 100644
--- a/api/core/ops/aliyun_trace/data_exporter/traceclient.py
+++ b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/data_exporter/traceclient.py
@@ -26,8 +26,8 @@ from opentelemetry.semconv.attributes import service_attributes
from opentelemetry.trace import Link, SpanContext, TraceFlags
from configs import dify_config
-from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData
-from core.ops.aliyun_trace.entities.semconv import ACS_ARMS_SERVICE_FEATURE
+from dify_trace_aliyun.entities.aliyun_trace_entity import SpanData
+from dify_trace_aliyun.entities.semconv import ACS_ARMS_SERVICE_FEATURE
INVALID_SPAN_ID: Final[int] = 0x0000000000000000
INVALID_TRACE_ID: Final[int] = 0x00000000000000000000000000000000
diff --git a/api/core/ops/aliyun_trace/entities/__init__.py b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/entities/__init__.py
similarity index 100%
rename from api/core/ops/aliyun_trace/entities/__init__.py
rename to api/providers/trace/trace-aliyun/src/dify_trace_aliyun/entities/__init__.py
diff --git a/api/core/ops/aliyun_trace/entities/aliyun_trace_entity.py b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/entities/aliyun_trace_entity.py
similarity index 100%
rename from api/core/ops/aliyun_trace/entities/aliyun_trace_entity.py
rename to api/providers/trace/trace-aliyun/src/dify_trace_aliyun/entities/aliyun_trace_entity.py
diff --git a/api/core/ops/aliyun_trace/entities/semconv.py b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/entities/semconv.py
similarity index 100%
rename from api/core/ops/aliyun_trace/entities/semconv.py
rename to api/providers/trace/trace-aliyun/src/dify_trace_aliyun/entities/semconv.py
diff --git a/api/core/ops/arize_phoenix_trace/__init__.py b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/py.typed
similarity index 100%
rename from api/core/ops/arize_phoenix_trace/__init__.py
rename to api/providers/trace/trace-aliyun/src/dify_trace_aliyun/py.typed
diff --git a/api/core/ops/aliyun_trace/utils.py b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/utils.py
similarity index 97%
rename from api/core/ops/aliyun_trace/utils.py
rename to api/providers/trace/trace-aliyun/src/dify_trace_aliyun/utils.py
index 2e02a186cc..5678c66adb 100644
--- a/api/core/ops/aliyun_trace/utils.py
+++ b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/utils.py
@@ -4,7 +4,8 @@ from typing import Any, TypedDict
from opentelemetry.trace import Link, Status, StatusCode
-from core.ops.aliyun_trace.entities.semconv import (
+from core.rag.models.document import Document
+from dify_trace_aliyun.entities.semconv import (
GEN_AI_FRAMEWORK,
GEN_AI_SESSION_ID,
GEN_AI_SPAN_KIND,
@@ -13,7 +14,6 @@ from core.ops.aliyun_trace.entities.semconv import (
OUTPUT_VALUE,
GenAISpanKind,
)
-from core.rag.models.document import Document
from extensions.ext_database import db
from graphon.entities import WorkflowNodeExecution
from graphon.enums import WorkflowNodeExecutionStatus
@@ -48,7 +48,7 @@ def get_workflow_node_status(node_execution: WorkflowNodeExecution) -> Status:
def create_links_from_trace_id(trace_id: str | None) -> list[Link]:
- from core.ops.aliyun_trace.data_exporter.traceclient import create_link
+ from dify_trace_aliyun.data_exporter.traceclient import create_link
links = []
if trace_id:
diff --git a/api/tests/unit_tests/core/ops/aliyun_trace/data_exporter/test_traceclient.py b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/data_exporter/test_traceclient.py
similarity index 86%
rename from api/tests/unit_tests/core/ops/aliyun_trace/data_exporter/test_traceclient.py
rename to api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/data_exporter/test_traceclient.py
index acb43d4036..286dda419c 100644
--- a/api/tests/unit_tests/core/ops/aliyun_trace/data_exporter/test_traceclient.py
+++ b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/data_exporter/test_traceclient.py
@@ -5,10 +5,7 @@ from unittest.mock import MagicMock, patch
import httpx
import pytest
-from opentelemetry.sdk.trace import ReadableSpan
-from opentelemetry.trace import SpanKind, Status, StatusCode
-
-from core.ops.aliyun_trace.data_exporter.traceclient import (
+from dify_trace_aliyun.data_exporter.traceclient import (
INVALID_SPAN_ID,
SpanBuilder,
TraceClient,
@@ -20,7 +17,9 @@ from core.ops.aliyun_trace.data_exporter.traceclient import (
create_link,
generate_span_id,
)
-from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData
+from dify_trace_aliyun.entities.aliyun_trace_entity import SpanData
+from opentelemetry.sdk.trace import ReadableSpan
+from opentelemetry.trace import SpanKind, Status, StatusCode
@pytest.fixture
@@ -41,8 +40,8 @@ def trace_client_factory():
class TestTraceClient:
- @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
- @patch("core.ops.aliyun_trace.data_exporter.traceclient.socket.gethostname")
+ @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter")
+ @patch("dify_trace_aliyun.data_exporter.traceclient.socket.gethostname")
def test_init(self, mock_gethostname, mock_exporter_class, trace_client_factory):
mock_gethostname.return_value = "test-host"
client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint")
@@ -56,7 +55,7 @@ class TestTraceClient:
client.shutdown()
assert client.done is True
- @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
+ @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter")
def test_export(self, mock_exporter_class, trace_client_factory):
mock_exporter = mock_exporter_class.return_value
client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint")
@@ -64,8 +63,8 @@ class TestTraceClient:
client.export(spans)
mock_exporter.export.assert_called_once_with(spans)
- @patch("core.ops.aliyun_trace.data_exporter.traceclient.httpx.head")
- @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
+ @patch("dify_trace_aliyun.data_exporter.traceclient.httpx.head")
+ @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter")
def test_api_check_success(self, mock_exporter_class, mock_head, trace_client_factory):
mock_response = MagicMock()
mock_response.status_code = 405
@@ -74,8 +73,8 @@ class TestTraceClient:
client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint")
assert client.api_check() is True
- @patch("core.ops.aliyun_trace.data_exporter.traceclient.httpx.head")
- @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
+ @patch("dify_trace_aliyun.data_exporter.traceclient.httpx.head")
+ @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter")
def test_api_check_failure_status(self, mock_exporter_class, mock_head, trace_client_factory):
mock_response = MagicMock()
mock_response.status_code = 500
@@ -84,8 +83,8 @@ class TestTraceClient:
client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint")
assert client.api_check() is False
- @patch("core.ops.aliyun_trace.data_exporter.traceclient.httpx.head")
- @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
+ @patch("dify_trace_aliyun.data_exporter.traceclient.httpx.head")
+ @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter")
def test_api_check_exception(self, mock_exporter_class, mock_head, trace_client_factory):
mock_head.side_effect = httpx.RequestError("Connection error")
@@ -93,12 +92,12 @@ class TestTraceClient:
with pytest.raises(ValueError, match="AliyunTrace API check failed: Connection error"):
client.api_check()
- @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
+ @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter")
def test_get_project_url(self, mock_exporter_class, trace_client_factory):
client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint")
assert client.get_project_url() == "https://arms.console.aliyun.com/#/llm"
- @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
+ @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter")
def test_add_span(self, mock_exporter_class, trace_client_factory):
client = trace_client_factory(
service_name="test-service",
@@ -134,8 +133,8 @@ class TestTraceClient:
assert len(client.queue) == 2
mock_notify.assert_called_once()
- @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
- @patch("core.ops.aliyun_trace.data_exporter.traceclient.logger")
+ @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter")
+ @patch("dify_trace_aliyun.data_exporter.traceclient.logger")
def test_add_span_queue_full(self, mock_logger, mock_exporter_class, trace_client_factory):
client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint", max_queue_size=1)
@@ -159,7 +158,7 @@ class TestTraceClient:
assert len(client.queue) == 1
mock_logger.warning.assert_called_with("Queue is full, likely spans will be dropped.")
- @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
+ @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter")
def test_export_batch_error(self, mock_exporter_class, trace_client_factory):
mock_exporter = mock_exporter_class.return_value
mock_exporter.export.side_effect = Exception("Export failed")
@@ -168,11 +167,11 @@ class TestTraceClient:
mock_span = MagicMock(spec=ReadableSpan)
client.queue.append(mock_span)
- with patch("core.ops.aliyun_trace.data_exporter.traceclient.logger") as mock_logger:
+ with patch("dify_trace_aliyun.data_exporter.traceclient.logger") as mock_logger:
client._export_batch()
mock_logger.warning.assert_called()
- @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
+ @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter")
def test_worker_loop(self, mock_exporter_class, trace_client_factory):
# We need to test the wait timeout in _worker
# But _worker runs in a thread. Let's mock condition.wait.
@@ -189,7 +188,7 @@ class TestTraceClient:
# mock_wait might have been called
assert mock_wait.called or client.done
- @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
+ @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter")
def test_shutdown_flushes(self, mock_exporter_class, trace_client_factory):
mock_exporter = mock_exporter_class.return_value
client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint")
@@ -268,7 +267,7 @@ def test_generate_span_id():
assert span_id != INVALID_SPAN_ID
# Test retry loop
- with patch("core.ops.aliyun_trace.data_exporter.traceclient.random.getrandbits") as mock_rand:
+ with patch("dify_trace_aliyun.data_exporter.traceclient.random.getrandbits") as mock_rand:
mock_rand.side_effect = [INVALID_SPAN_ID, 999]
span_id = generate_span_id()
assert span_id == 999
@@ -290,7 +289,7 @@ def test_convert_to_trace_id():
def test_convert_string_to_id():
assert convert_string_to_id("test") > 0
# Test with None string
- with patch("core.ops.aliyun_trace.data_exporter.traceclient.generate_span_id") as mock_gen:
+ with patch("dify_trace_aliyun.data_exporter.traceclient.generate_span_id") as mock_gen:
mock_gen.return_value = 12345
assert convert_string_to_id(None) == 12345
diff --git a/api/tests/unit_tests/core/ops/aliyun_trace/entities/test_aliyun_trace_entity.py b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/entities/test_aliyun_trace_entity.py
similarity index 97%
rename from api/tests/unit_tests/core/ops/aliyun_trace/entities/test_aliyun_trace_entity.py
rename to api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/entities/test_aliyun_trace_entity.py
index 2fcb927e0c..38d33dd21b 100644
--- a/api/tests/unit_tests/core/ops/aliyun_trace/entities/test_aliyun_trace_entity.py
+++ b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/entities/test_aliyun_trace_entity.py
@@ -1,11 +1,10 @@
import pytest
+from dify_trace_aliyun.entities.aliyun_trace_entity import SpanData, TraceMetadata
from opentelemetry import trace as trace_api
from opentelemetry.sdk.trace import Event
from opentelemetry.trace import SpanKind, Status, StatusCode
from pydantic import ValidationError
-from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData, TraceMetadata
-
class TestTraceMetadata:
def test_trace_metadata_init(self):
diff --git a/api/tests/unit_tests/core/ops/aliyun_trace/entities/test_semconv.py b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/entities/test_semconv.py
similarity index 97%
rename from api/tests/unit_tests/core/ops/aliyun_trace/entities/test_semconv.py
rename to api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/entities/test_semconv.py
index 3961555b9a..9cab40748f 100644
--- a/api/tests/unit_tests/core/ops/aliyun_trace/entities/test_semconv.py
+++ b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/entities/test_semconv.py
@@ -1,4 +1,4 @@
-from core.ops.aliyun_trace.entities.semconv import (
+from dify_trace_aliyun.entities.semconv import (
ACS_ARMS_SERVICE_FEATURE,
GEN_AI_COMPLETION,
GEN_AI_FRAMEWORK,
diff --git a/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace.py b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/test_aliyun_trace.py
similarity index 99%
rename from api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace.py
rename to api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/test_aliyun_trace.py
index c2324fdec4..c1b11c9186 100644
--- a/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace.py
+++ b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/test_aliyun_trace.py
@@ -4,12 +4,11 @@ from datetime import UTC, datetime
from types import SimpleNamespace
from unittest.mock import MagicMock
+import dify_trace_aliyun.aliyun_trace as aliyun_trace_module
import pytest
-from opentelemetry.trace import Link, SpanContext, SpanKind, Status, StatusCode, TraceFlags
-
-import core.ops.aliyun_trace.aliyun_trace as aliyun_trace_module
-from core.ops.aliyun_trace.aliyun_trace import AliyunDataTrace
-from core.ops.aliyun_trace.entities.semconv import (
+from dify_trace_aliyun.aliyun_trace import AliyunDataTrace
+from dify_trace_aliyun.config import AliyunConfig
+from dify_trace_aliyun.entities.semconv import (
GEN_AI_COMPLETION,
GEN_AI_INPUT_MESSAGE,
GEN_AI_OUTPUT_MESSAGE,
@@ -24,7 +23,8 @@ from core.ops.aliyun_trace.entities.semconv import (
TOOL_PARAMETERS,
GenAISpanKind,
)
-from core.ops.entities.config_entity import AliyunConfig
+from opentelemetry.trace import Link, SpanContext, SpanKind, Status, StatusCode, TraceFlags
+
from core.ops.entities.trace_entity import (
DatasetRetrievalTraceInfo,
GenerateNameTraceInfo,
diff --git a/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace_utils.py b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/test_aliyun_trace_utils.py
similarity index 95%
rename from api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace_utils.py
rename to api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/test_aliyun_trace_utils.py
index e4d8f2d5ea..a9e7b80c2a 100644
--- a/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace_utils.py
+++ b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/test_aliyun_trace_utils.py
@@ -1,9 +1,7 @@
import json
from unittest.mock import MagicMock
-from opentelemetry.trace import Link, StatusCode
-
-from core.ops.aliyun_trace.entities.semconv import (
+from dify_trace_aliyun.entities.semconv import (
GEN_AI_FRAMEWORK,
GEN_AI_SESSION_ID,
GEN_AI_SPAN_KIND,
@@ -11,7 +9,7 @@ from core.ops.aliyun_trace.entities.semconv import (
INPUT_VALUE,
OUTPUT_VALUE,
)
-from core.ops.aliyun_trace.utils import (
+from dify_trace_aliyun.utils import (
create_common_span_attributes,
create_links_from_trace_id,
create_status_from_error,
@@ -23,6 +21,8 @@ from core.ops.aliyun_trace.utils import (
get_workflow_node_status,
serialize_json_data,
)
+from opentelemetry.trace import Link, StatusCode
+
from core.rag.models.document import Document
from graphon.entities import WorkflowNodeExecution
from graphon.enums import WorkflowNodeExecutionStatus
@@ -48,7 +48,7 @@ def test_get_user_id_from_message_data_with_end_user(monkeypatch):
mock_session = MagicMock()
mock_session.get.return_value = end_user_data
- from core.ops.aliyun_trace.utils import db
+ from dify_trace_aliyun.utils import db
monkeypatch.setattr(db, "session", mock_session)
@@ -63,7 +63,7 @@ def test_get_user_id_from_message_data_end_user_not_found(monkeypatch):
mock_session = MagicMock()
mock_session.get.return_value = None
- from core.ops.aliyun_trace.utils import db
+ from dify_trace_aliyun.utils import db
monkeypatch.setattr(db, "session", mock_session)
@@ -112,9 +112,9 @@ def test_get_workflow_node_status():
def test_create_links_from_trace_id(monkeypatch):
# Mock create_link
mock_link = MagicMock(spec=Link)
- import core.ops.aliyun_trace.data_exporter.traceclient
+ import dify_trace_aliyun.data_exporter.traceclient
- monkeypatch.setattr(core.ops.aliyun_trace.data_exporter.traceclient, "create_link", lambda trace_id_str: mock_link)
+ monkeypatch.setattr(dify_trace_aliyun.data_exporter.traceclient, "create_link", lambda trace_id_str: mock_link)
# Trace ID None
assert create_links_from_trace_id(None) == []
diff --git a/api/providers/trace/trace-aliyun/tests/unit_tests/test_config_entity.py b/api/providers/trace/trace-aliyun/tests/unit_tests/test_config_entity.py
new file mode 100644
index 0000000000..1b24ee7421
--- /dev/null
+++ b/api/providers/trace/trace-aliyun/tests/unit_tests/test_config_entity.py
@@ -0,0 +1,85 @@
+import pytest
+from dify_trace_aliyun.config import AliyunConfig
+from pydantic import ValidationError
+
+
+class TestAliyunConfig:
+ """Test cases for AliyunConfig"""
+
+ def test_valid_config(self):
+ """Test valid Aliyun configuration"""
+ config = AliyunConfig(
+ app_name="test_app",
+ license_key="test_license_key",
+ endpoint="https://custom.tracing-analysis-dc-hz.aliyuncs.com",
+ )
+ assert config.app_name == "test_app"
+ assert config.license_key == "test_license_key"
+ assert config.endpoint == "https://custom.tracing-analysis-dc-hz.aliyuncs.com"
+
+ def test_default_values(self):
+ """Test default values are set correctly"""
+ config = AliyunConfig(license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com")
+ assert config.app_name == "dify_app"
+
+ def test_missing_required_fields(self):
+ """Test that required fields are enforced"""
+ with pytest.raises(ValidationError):
+ AliyunConfig()
+
+ with pytest.raises(ValidationError):
+ AliyunConfig(license_key="test_license")
+
+ with pytest.raises(ValidationError):
+ AliyunConfig(endpoint="https://tracing-analysis-dc-hz.aliyuncs.com")
+
+ def test_app_name_validation_empty(self):
+ """Test app_name validation with empty value"""
+ config = AliyunConfig(
+ license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com", app_name=""
+ )
+ assert config.app_name == "dify_app"
+
+ def test_endpoint_validation_empty(self):
+ """Test endpoint validation with empty value"""
+ config = AliyunConfig(license_key="test_license", endpoint="")
+ assert config.endpoint == "https://tracing-analysis-dc-hz.aliyuncs.com"
+
+ def test_endpoint_validation_with_path(self):
+ """Test endpoint validation preserves path for Aliyun endpoints"""
+ config = AliyunConfig(
+ license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com/api/v1/traces"
+ )
+ assert config.endpoint == "https://tracing-analysis-dc-hz.aliyuncs.com/api/v1/traces"
+
+ def test_endpoint_validation_invalid_scheme(self):
+ """Test endpoint validation rejects invalid schemes"""
+ with pytest.raises(ValidationError, match="URL must start with https:// or http://"):
+ AliyunConfig(license_key="test_license", endpoint="ftp://invalid.tracing-analysis-dc-hz.aliyuncs.com")
+
+ def test_endpoint_validation_no_scheme(self):
+ """Test endpoint validation rejects URLs without scheme"""
+ with pytest.raises(ValidationError, match="URL must start with https:// or http://"):
+ AliyunConfig(license_key="test_license", endpoint="invalid.tracing-analysis-dc-hz.aliyuncs.com")
+
+ def test_license_key_required(self):
+ """Test that license_key is required and cannot be empty"""
+ with pytest.raises(ValidationError):
+ AliyunConfig(license_key="", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com")
+
+ def test_valid_endpoint_format_examples(self):
+ """Test valid endpoint format examples from comments"""
+ valid_endpoints = [
+ # cms2.0 public endpoint
+ "https://proj-xtrace-123456-cn-heyuan.cn-heyuan.log.aliyuncs.com/apm/trace/opentelemetry",
+ # cms2.0 intranet endpoint
+ "https://proj-xtrace-123456-cn-heyuan.cn-heyuan-intranet.log.aliyuncs.com/apm/trace/opentelemetry",
+ # xtrace public endpoint
+ "http://tracing-cn-heyuan.arms.aliyuncs.com",
+ # xtrace intranet endpoint
+ "http://tracing-cn-heyuan-internal.arms.aliyuncs.com",
+ ]
+
+ for endpoint in valid_endpoints:
+ config = AliyunConfig(license_key="test_license", endpoint=endpoint)
+ assert config.endpoint == endpoint
diff --git a/api/providers/trace/trace-arize-phoenix/pyproject.toml b/api/providers/trace/trace-arize-phoenix/pyproject.toml
new file mode 100644
index 0000000000..9e756944c9
--- /dev/null
+++ b/api/providers/trace/trace-arize-phoenix/pyproject.toml
@@ -0,0 +1,10 @@
+[project]
+name = "dify-trace-arize-phoenix"
+version = "0.0.1"
+dependencies = [
+ "arize-phoenix-otel~=0.15.0",
+]
+description = "Dify ops tracing provider (Arize / Phoenix)."
+
+[tool.setuptools.packages.find]
+where = ["src"]
diff --git a/api/core/ops/langfuse_trace/__init__.py b/api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/__init__.py
similarity index 100%
rename from api/core/ops/langfuse_trace/__init__.py
rename to api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/__init__.py
diff --git a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py b/api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/arize_phoenix_trace.py
similarity index 99%
rename from api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py
rename to api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/arize_phoenix_trace.py
index 78516e1a22..96df49ed0e 100644
--- a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py
+++ b/api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/arize_phoenix_trace.py
@@ -25,7 +25,6 @@ from opentelemetry.util.types import AttributeValue
from sqlalchemy.orm import sessionmaker
from core.ops.base_trace_instance import BaseTraceInstance
-from core.ops.entities.config_entity import ArizeConfig, PhoenixConfig
from core.ops.entities.trace_entity import (
BaseTraceInfo,
DatasetRetrievalTraceInfo,
@@ -39,6 +38,7 @@ from core.ops.entities.trace_entity import (
)
from core.ops.utils import JSON_DICT_ADAPTER
from core.repositories import DifyCoreRepositoryFactory
+from dify_trace_arize_phoenix.config import ArizeConfig, PhoenixConfig
from extensions.ext_database import db
from graphon.enums import WorkflowNodeExecutionStatus
from models.model import EndUser, MessageFile
diff --git a/api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/config.py b/api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/config.py
new file mode 100644
index 0000000000..6eac5b30d2
--- /dev/null
+++ b/api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/config.py
@@ -0,0 +1,45 @@
+from pydantic import ValidationInfo, field_validator
+
+from core.ops.entities.config_entity import BaseTracingConfig
+from core.ops.utils import validate_url_with_path
+
+
+class ArizeConfig(BaseTracingConfig):
+ """
+ Model class for Arize tracing config.
+ """
+
+ api_key: str | None = None
+ space_id: str | None = None
+ project: str | None = None
+ endpoint: str = "https://otlp.arize.com"
+
+ @field_validator("project")
+ @classmethod
+ def project_validator(cls, v, info: ValidationInfo):
+ return cls.validate_project_field(v, "default")
+
+ @field_validator("endpoint")
+ @classmethod
+ def endpoint_validator(cls, v, info: ValidationInfo):
+ return cls.validate_endpoint_url(v, "https://otlp.arize.com")
+
+
+class PhoenixConfig(BaseTracingConfig):
+ """
+ Model class for Phoenix tracing config.
+ """
+
+ api_key: str | None = None
+ project: str | None = None
+ endpoint: str = "https://app.phoenix.arize.com"
+
+ @field_validator("project")
+ @classmethod
+ def project_validator(cls, v, info: ValidationInfo):
+ return cls.validate_project_field(v, "default")
+
+ @field_validator("endpoint")
+ @classmethod
+ def endpoint_validator(cls, v, info: ValidationInfo):
+ return validate_url_with_path(v, "https://app.phoenix.arize.com")
diff --git a/api/core/ops/langfuse_trace/entities/__init__.py b/api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/py.typed
similarity index 100%
rename from api/core/ops/langfuse_trace/entities/__init__.py
rename to api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/py.typed
diff --git a/api/tests/unit_tests/core/ops/arize_phoenix_trace/test_arize_phoenix_trace.py b/api/providers/trace/trace-arize-phoenix/tests/unit_tests/arize_phoenix_trace/test_arize_phoenix_trace.py
similarity index 91%
rename from api/tests/unit_tests/core/ops/arize_phoenix_trace/test_arize_phoenix_trace.py
rename to api/providers/trace/trace-arize-phoenix/tests/unit_tests/arize_phoenix_trace/test_arize_phoenix_trace.py
index 4ce9e22fd7..b0691a87ea 100644
--- a/api/tests/unit_tests/core/ops/arize_phoenix_trace/test_arize_phoenix_trace.py
+++ b/api/providers/trace/trace-arize-phoenix/tests/unit_tests/arize_phoenix_trace/test_arize_phoenix_trace.py
@@ -2,11 +2,7 @@ from datetime import UTC, datetime, timedelta
from unittest.mock import MagicMock, patch
import pytest
-from opentelemetry.sdk.trace import Tracer
-from opentelemetry.semconv.trace import SpanAttributes as OTELSpanAttributes
-from opentelemetry.trace import StatusCode
-
-from core.ops.arize_phoenix_trace.arize_phoenix_trace import (
+from dify_trace_arize_phoenix.arize_phoenix_trace import (
ArizePhoenixDataTrace,
datetime_to_nanos,
error_to_string,
@@ -15,7 +11,11 @@ from core.ops.arize_phoenix_trace.arize_phoenix_trace import (
setup_tracer,
wrap_span_metadata,
)
-from core.ops.entities.config_entity import ArizeConfig, PhoenixConfig
+from dify_trace_arize_phoenix.config import ArizeConfig, PhoenixConfig
+from opentelemetry.sdk.trace import Tracer
+from opentelemetry.semconv.trace import SpanAttributes as OTELSpanAttributes
+from opentelemetry.trace import StatusCode
+
from core.ops.entities.trace_entity import (
DatasetRetrievalTraceInfo,
GenerateNameTraceInfo,
@@ -80,7 +80,7 @@ def test_datetime_to_nanos():
expected = int(dt.timestamp() * 1_000_000_000)
assert datetime_to_nanos(dt) == expected
- with patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.datetime") as mock_dt:
+ with patch("dify_trace_arize_phoenix.arize_phoenix_trace.datetime") as mock_dt:
mock_now = MagicMock()
mock_now.timestamp.return_value = 1704110400.0
mock_dt.now.return_value = mock_now
@@ -142,8 +142,8 @@ def test_wrap_span_metadata():
assert res == {"a": 1, "b": 2, "created_from": "Dify"}
-@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.GrpcOTLPSpanExporter")
-@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.trace_sdk.TracerProvider")
+@patch("dify_trace_arize_phoenix.arize_phoenix_trace.GrpcOTLPSpanExporter")
+@patch("dify_trace_arize_phoenix.arize_phoenix_trace.trace_sdk.TracerProvider")
def test_setup_tracer_arize(mock_provider, mock_exporter):
config = ArizeConfig(endpoint="http://a.com", api_key="k", space_id="s", project="p")
setup_tracer(config)
@@ -151,8 +151,8 @@ def test_setup_tracer_arize(mock_provider, mock_exporter):
assert mock_exporter.call_args[1]["endpoint"] == "http://a.com/v1"
-@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.HttpOTLPSpanExporter")
-@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.trace_sdk.TracerProvider")
+@patch("dify_trace_arize_phoenix.arize_phoenix_trace.HttpOTLPSpanExporter")
+@patch("dify_trace_arize_phoenix.arize_phoenix_trace.trace_sdk.TracerProvider")
def test_setup_tracer_phoenix(mock_provider, mock_exporter):
config = PhoenixConfig(endpoint="http://p.com", project="p")
setup_tracer(config)
@@ -162,7 +162,7 @@ def test_setup_tracer_phoenix(mock_provider, mock_exporter):
def test_setup_tracer_exception():
config = ArizeConfig(endpoint="http://a.com", project="p")
- with patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.urlparse", side_effect=Exception("boom")):
+ with patch("dify_trace_arize_phoenix.arize_phoenix_trace.urlparse", side_effect=Exception("boom")):
with pytest.raises(Exception, match="boom"):
setup_tracer(config)
@@ -172,7 +172,7 @@ def test_setup_tracer_exception():
@pytest.fixture
def trace_instance():
- with patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.setup_tracer") as mock_setup:
+ with patch("dify_trace_arize_phoenix.arize_phoenix_trace.setup_tracer") as mock_setup:
mock_tracer = MagicMock(spec=Tracer)
mock_processor = MagicMock()
mock_setup.return_value = (mock_tracer, mock_processor)
@@ -228,9 +228,9 @@ def test_trace_exception(trace_instance):
trace_instance.trace(_make_workflow_info())
-@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.sessionmaker")
-@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.DifyCoreRepositoryFactory")
-@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.db")
+@patch("dify_trace_arize_phoenix.arize_phoenix_trace.sessionmaker")
+@patch("dify_trace_arize_phoenix.arize_phoenix_trace.DifyCoreRepositoryFactory")
+@patch("dify_trace_arize_phoenix.arize_phoenix_trace.db")
def test_workflow_trace_full(mock_db, mock_repo_factory, mock_sessionmaker, trace_instance):
mock_db.engine = MagicMock()
info = _make_workflow_info()
@@ -262,7 +262,7 @@ def test_workflow_trace_full(mock_db, mock_repo_factory, mock_sessionmaker, trac
assert trace_instance.tracer.start_span.call_count >= 2
-@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.db")
+@patch("dify_trace_arize_phoenix.arize_phoenix_trace.db")
def test_workflow_trace_no_app_id(mock_db, trace_instance):
mock_db.engine = MagicMock()
info = _make_workflow_info()
@@ -271,7 +271,7 @@ def test_workflow_trace_no_app_id(mock_db, trace_instance):
trace_instance.workflow_trace(info)
-@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.db")
+@patch("dify_trace_arize_phoenix.arize_phoenix_trace.db")
def test_message_trace_success(mock_db, trace_instance):
mock_db.engine = MagicMock()
info = _make_message_info()
@@ -291,7 +291,7 @@ def test_message_trace_success(mock_db, trace_instance):
assert trace_instance.tracer.start_span.call_count >= 1
-@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.db")
+@patch("dify_trace_arize_phoenix.arize_phoenix_trace.db")
def test_message_trace_with_error(mock_db, trace_instance):
mock_db.engine = MagicMock()
info = _make_message_info()
diff --git a/api/tests/unit_tests/core/ops/test_arize_phoenix_trace.py b/api/providers/trace/trace-arize-phoenix/tests/unit_tests/test_arize_phoenix_trace.py
similarity index 94%
rename from api/tests/unit_tests/core/ops/test_arize_phoenix_trace.py
rename to api/providers/trace/trace-arize-phoenix/tests/unit_tests/test_arize_phoenix_trace.py
index 4b925390d9..a01c63ae61 100644
--- a/api/tests/unit_tests/core/ops/test_arize_phoenix_trace.py
+++ b/api/providers/trace/trace-arize-phoenix/tests/unit_tests/test_arize_phoenix_trace.py
@@ -1,6 +1,6 @@
+from dify_trace_arize_phoenix.arize_phoenix_trace import _NODE_TYPE_TO_SPAN_KIND, _get_node_span_kind
from openinference.semconv.trace import OpenInferenceSpanKindValues
-from core.ops.arize_phoenix_trace.arize_phoenix_trace import _NODE_TYPE_TO_SPAN_KIND, _get_node_span_kind
from graphon.enums import BUILT_IN_NODE_TYPES, BuiltinNodeTypes
diff --git a/api/providers/trace/trace-arize-phoenix/tests/unit_tests/test_config_entity.py b/api/providers/trace/trace-arize-phoenix/tests/unit_tests/test_config_entity.py
new file mode 100644
index 0000000000..11e951c3b1
--- /dev/null
+++ b/api/providers/trace/trace-arize-phoenix/tests/unit_tests/test_config_entity.py
@@ -0,0 +1,88 @@
+import pytest
+from dify_trace_arize_phoenix.config import ArizeConfig, PhoenixConfig
+from pydantic import ValidationError
+
+
+class TestArizeConfig:
+ """Test cases for ArizeConfig"""
+
+ def test_valid_config(self):
+ """Test valid Arize configuration"""
+ config = ArizeConfig(
+ api_key="test_key", space_id="test_space", project="test_project", endpoint="https://custom.arize.com"
+ )
+ assert config.api_key == "test_key"
+ assert config.space_id == "test_space"
+ assert config.project == "test_project"
+ assert config.endpoint == "https://custom.arize.com"
+
+ def test_default_values(self):
+ """Test default values are set correctly"""
+ config = ArizeConfig()
+ assert config.api_key is None
+ assert config.space_id is None
+ assert config.project is None
+ assert config.endpoint == "https://otlp.arize.com"
+
+ def test_project_validation_empty(self):
+ """Test project validation with empty value"""
+ config = ArizeConfig(project="")
+ assert config.project == "default"
+
+ def test_project_validation_none(self):
+ """Test project validation with None value"""
+ config = ArizeConfig(project=None)
+ assert config.project == "default"
+
+ def test_endpoint_validation_empty(self):
+ """Test endpoint validation with empty value"""
+ config = ArizeConfig(endpoint="")
+ assert config.endpoint == "https://otlp.arize.com"
+
+ def test_endpoint_validation_with_path(self):
+ """Test endpoint validation normalizes URL by removing path"""
+ config = ArizeConfig(endpoint="https://custom.arize.com/api/v1")
+ assert config.endpoint == "https://custom.arize.com"
+
+ def test_endpoint_validation_invalid_scheme(self):
+ """Test endpoint validation rejects invalid schemes"""
+ with pytest.raises(ValidationError, match="URL scheme must be one of"):
+ ArizeConfig(endpoint="ftp://invalid.com")
+
+ def test_endpoint_validation_no_scheme(self):
+ """Test endpoint validation rejects URLs without scheme"""
+ with pytest.raises(ValidationError, match="URL scheme must be one of"):
+ ArizeConfig(endpoint="invalid.com")
+
+
+class TestPhoenixConfig:
+ """Test cases for PhoenixConfig"""
+
+ def test_valid_config(self):
+ """Test valid Phoenix configuration"""
+ config = PhoenixConfig(api_key="test_key", project="test_project", endpoint="https://custom.phoenix.com")
+ assert config.api_key == "test_key"
+ assert config.project == "test_project"
+ assert config.endpoint == "https://custom.phoenix.com"
+
+ def test_default_values(self):
+ """Test default values are set correctly"""
+ config = PhoenixConfig()
+ assert config.api_key is None
+ assert config.project is None
+ assert config.endpoint == "https://app.phoenix.arize.com"
+
+ def test_project_validation_empty(self):
+ """Test project validation with empty value"""
+ config = PhoenixConfig(project="")
+ assert config.project == "default"
+
+ def test_endpoint_validation_with_path(self):
+ """Test endpoint validation with path"""
+ config = PhoenixConfig(endpoint="https://app.phoenix.arize.com/s/dify-integration")
+ assert config.endpoint == "https://app.phoenix.arize.com/s/dify-integration"
+
+ def test_endpoint_validation_without_path(self):
+ """Test endpoint validation without path"""
+ config = PhoenixConfig(endpoint="https://app.phoenix.arize.com")
+ assert config.endpoint == "https://app.phoenix.arize.com"
diff --git a/api/providers/trace/trace-langfuse/pyproject.toml b/api/providers/trace/trace-langfuse/pyproject.toml
new file mode 100644
index 0000000000..27d2273a69
--- /dev/null
+++ b/api/providers/trace/trace-langfuse/pyproject.toml
@@ -0,0 +1,10 @@
+[project]
+name = "dify-trace-langfuse"
+version = "0.0.1"
+dependencies = [
+ "langfuse>=4.2.0,<5.0.0",
+]
+description = "Dify ops tracing provider (Langfuse)."
+
+[tool.setuptools.packages.find]
+where = ["src"]
diff --git a/api/core/ops/langsmith_trace/__init__.py b/api/providers/trace/trace-langfuse/src/dify_trace_langfuse/__init__.py
similarity index 100%
rename from api/core/ops/langsmith_trace/__init__.py
rename to api/providers/trace/trace-langfuse/src/dify_trace_langfuse/__init__.py
diff --git a/api/providers/trace/trace-langfuse/src/dify_trace_langfuse/config.py b/api/providers/trace/trace-langfuse/src/dify_trace_langfuse/config.py
new file mode 100644
index 0000000000..90d1a2846b
--- /dev/null
+++ b/api/providers/trace/trace-langfuse/src/dify_trace_langfuse/config.py
@@ -0,0 +1,19 @@
+from pydantic import ValidationInfo, field_validator
+
+from core.ops.entities.config_entity import BaseTracingConfig
+from core.ops.utils import validate_url_with_path
+
+
+class LangfuseConfig(BaseTracingConfig):
+ """
+ Model class for Langfuse tracing config.
+ """
+
+ public_key: str
+ secret_key: str
+ host: str = "https://api.langfuse.com"
+
+ @field_validator("host")
+ @classmethod
+ def host_validator(cls, v, info: ValidationInfo):
+ return validate_url_with_path(v, "https://api.langfuse.com")
diff --git a/api/core/ops/langsmith_trace/entities/__init__.py b/api/providers/trace/trace-langfuse/src/dify_trace_langfuse/entities/__init__.py
similarity index 100%
rename from api/core/ops/langsmith_trace/entities/__init__.py
rename to api/providers/trace/trace-langfuse/src/dify_trace_langfuse/entities/__init__.py
diff --git a/api/core/ops/langfuse_trace/entities/langfuse_trace_entity.py b/api/providers/trace/trace-langfuse/src/dify_trace_langfuse/entities/langfuse_trace_entity.py
similarity index 100%
rename from api/core/ops/langfuse_trace/entities/langfuse_trace_entity.py
rename to api/providers/trace/trace-langfuse/src/dify_trace_langfuse/entities/langfuse_trace_entity.py
diff --git a/api/core/ops/langfuse_trace/langfuse_trace.py b/api/providers/trace/trace-langfuse/src/dify_trace_langfuse/langfuse_trace.py
similarity index 99%
rename from api/core/ops/langfuse_trace/langfuse_trace.py
rename to api/providers/trace/trace-langfuse/src/dify_trace_langfuse/langfuse_trace.py
index 7eacc2be46..68881378a7 100644
--- a/api/core/ops/langfuse_trace/langfuse_trace.py
+++ b/api/providers/trace/trace-langfuse/src/dify_trace_langfuse/langfuse_trace.py
@@ -16,7 +16,6 @@ from langfuse.api.commons.types.usage import Usage
from sqlalchemy.orm import sessionmaker
from core.ops.base_trace_instance import BaseTraceInstance
-from core.ops.entities.config_entity import LangfuseConfig
from core.ops.entities.trace_entity import (
BaseTraceInfo,
DatasetRetrievalTraceInfo,
@@ -28,7 +27,10 @@ from core.ops.entities.trace_entity import (
TraceTaskName,
WorkflowTraceInfo,
)
-from core.ops.langfuse_trace.entities.langfuse_trace_entity import (
+from core.ops.utils import filter_none_values
+from core.repositories import DifyCoreRepositoryFactory
+from dify_trace_langfuse.config import LangfuseConfig
+from dify_trace_langfuse.entities.langfuse_trace_entity import (
GenerationUsage,
LangfuseGeneration,
LangfuseSpan,
@@ -36,8 +38,6 @@ from core.ops.langfuse_trace.entities.langfuse_trace_entity import (
LevelEnum,
UnitEnum,
)
-from core.ops.utils import filter_none_values
-from core.repositories import DifyCoreRepositoryFactory
from extensions.ext_database import db
from graphon.enums import BuiltinNodeTypes
from models import EndUser, WorkflowNodeExecutionTriggeredFrom
diff --git a/api/core/ops/mlflow_trace/__init__.py b/api/providers/trace/trace-langfuse/src/dify_trace_langfuse/py.typed
similarity index 100%
rename from api/core/ops/mlflow_trace/__init__.py
rename to api/providers/trace/trace-langfuse/src/dify_trace_langfuse/py.typed
diff --git a/api/tests/unit_tests/core/ops/langfuse_trace/test_langfuse_trace.py b/api/providers/trace/trace-langfuse/tests/unit_tests/langfuse_trace/test_langfuse_trace.py
similarity index 93%
rename from api/tests/unit_tests/core/ops/langfuse_trace/test_langfuse_trace.py
rename to api/providers/trace/trace-langfuse/tests/unit_tests/langfuse_trace/test_langfuse_trace.py
index a0bcc92795..952f10c34f 100644
--- a/api/tests/unit_tests/core/ops/langfuse_trace/test_langfuse_trace.py
+++ b/api/providers/trace/trace-langfuse/tests/unit_tests/langfuse_trace/test_langfuse_trace.py
@@ -5,8 +5,16 @@ from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
+from dify_trace_langfuse.config import LangfuseConfig
+from dify_trace_langfuse.entities.langfuse_trace_entity import (
+ LangfuseGeneration,
+ LangfuseSpan,
+ LangfuseTrace,
+ LevelEnum,
+ UnitEnum,
+)
+from dify_trace_langfuse.langfuse_trace import LangFuseDataTrace
-from core.ops.entities.config_entity import LangfuseConfig
from core.ops.entities.trace_entity import (
DatasetRetrievalTraceInfo,
GenerateNameTraceInfo,
@@ -17,14 +25,6 @@ from core.ops.entities.trace_entity import (
TraceTaskName,
WorkflowTraceInfo,
)
-from core.ops.langfuse_trace.entities.langfuse_trace_entity import (
- LangfuseGeneration,
- LangfuseSpan,
- LangfuseTrace,
- LevelEnum,
- UnitEnum,
-)
-from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace
from graphon.enums import BuiltinNodeTypes
from models import EndUser
from models.enums import MessageStatus
@@ -43,7 +43,7 @@ def langfuse_config():
def trace_instance(langfuse_config, monkeypatch):
# Mock Langfuse client to avoid network calls
mock_client = MagicMock()
- monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.Langfuse", lambda **kwargs: mock_client)
+ monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.Langfuse", lambda **kwargs: mock_client)
instance = LangFuseDataTrace(langfuse_config)
return instance
@@ -51,7 +51,7 @@ def trace_instance(langfuse_config, monkeypatch):
def test_init(langfuse_config, monkeypatch):
mock_langfuse = MagicMock()
- monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.Langfuse", mock_langfuse)
+ monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.Langfuse", mock_langfuse)
monkeypatch.setenv("FILES_URL", "http://test.url")
instance = LangFuseDataTrace(langfuse_config)
@@ -140,8 +140,8 @@ def test_workflow_trace_with_message_id(trace_instance, monkeypatch):
# Mock DB and Repositories
mock_session = MagicMock()
- monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.sessionmaker", lambda bind: lambda: mock_session)
- monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db", MagicMock(engine="engine"))
+ monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.sessionmaker", lambda bind: lambda: mock_session)
+ monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.db", MagicMock(engine="engine"))
# Mock node executions
node_llm = MagicMock()
@@ -178,7 +178,7 @@ def test_workflow_trace_with_message_id(trace_instance, monkeypatch):
mock_factory = MagicMock()
mock_factory.create_workflow_node_execution_repository.return_value = repo
- monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.DifyCoreRepositoryFactory", mock_factory)
+ monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.DifyCoreRepositoryFactory", mock_factory)
monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock())
@@ -241,13 +241,13 @@ def test_workflow_trace_no_message_id(trace_instance, monkeypatch):
error="",
)
- monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.sessionmaker", lambda bind: lambda: MagicMock())
- monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db", MagicMock(engine="engine"))
+ monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.sessionmaker", lambda bind: lambda: MagicMock())
+ monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.db", MagicMock(engine="engine"))
repo = MagicMock()
repo.get_by_workflow_execution.return_value = []
mock_factory = MagicMock()
mock_factory.create_workflow_node_execution_repository.return_value = repo
- monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.DifyCoreRepositoryFactory", mock_factory)
+ monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.DifyCoreRepositoryFactory", mock_factory)
monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock())
trace_instance.add_trace = MagicMock()
@@ -280,8 +280,8 @@ def test_workflow_trace_missing_app_id(trace_instance, monkeypatch):
workflow_app_log_id="log-1",
error="",
)
- monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.sessionmaker", lambda bind: lambda: MagicMock())
- monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db", MagicMock(engine="engine"))
+ monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.sessionmaker", lambda bind: lambda: MagicMock())
+ monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.db", MagicMock(engine="engine"))
with pytest.raises(ValueError, match="No app_id found in trace_info metadata"):
trace_instance.workflow_trace(trace_info)
@@ -365,7 +365,7 @@ def test_message_trace_with_end_user(trace_instance, monkeypatch):
mock_end_user = MagicMock(spec=EndUser)
mock_end_user.session_id = "session-id-123"
- monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db.session.get", lambda model, pk: mock_end_user)
+ monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.db.session.get", lambda model, pk: mock_end_user)
trace_instance.add_trace = MagicMock()
trace_instance.add_generation = MagicMock()
@@ -681,9 +681,9 @@ def test_workflow_trace_handles_usage_extraction_error(trace_instance, monkeypat
repo.get_by_workflow_execution.return_value = [node]
mock_factory = MagicMock()
mock_factory.create_workflow_node_execution_repository.return_value = repo
- monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.DifyCoreRepositoryFactory", mock_factory)
- monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.sessionmaker", lambda bind: lambda: MagicMock())
- monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db", MagicMock(engine="engine"))
+ monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.DifyCoreRepositoryFactory", mock_factory)
+ monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.sessionmaker", lambda bind: lambda: MagicMock())
+ monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.db", MagicMock(engine="engine"))
monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock())
trace_instance.add_trace = MagicMock()
diff --git a/api/providers/trace/trace-langfuse/tests/unit_tests/test_config_entity.py b/api/providers/trace/trace-langfuse/tests/unit_tests/test_config_entity.py
new file mode 100644
index 0000000000..103d888eef
--- /dev/null
+++ b/api/providers/trace/trace-langfuse/tests/unit_tests/test_config_entity.py
@@ -0,0 +1,42 @@
+import pytest
+from dify_trace_langfuse.config import LangfuseConfig
+from pydantic import ValidationError
+
+
+class TestLangfuseConfig:
+ """Test cases for LangfuseConfig"""
+
+ def test_valid_config(self):
+ """Test valid Langfuse configuration"""
+ config = LangfuseConfig(public_key="public_key", secret_key="secret_key", host="https://custom.langfuse.com")
+ assert config.public_key == "public_key"
+ assert config.secret_key == "secret_key"
+ assert config.host == "https://custom.langfuse.com"
+
+ def test_valid_config_with_path(self):
+ host = "https://custom.langfuse.com/api/v1"
+ config = LangfuseConfig(public_key="public_key", secret_key="secret_key", host=host)
+ assert config.public_key == "public_key"
+ assert config.secret_key == "secret_key"
+ assert config.host == host
+
+ def test_default_values(self):
+ """Test default values are set correctly"""
+ config = LangfuseConfig(public_key="public", secret_key="secret")
+ assert config.host == "https://api.langfuse.com"
+
+ def test_missing_required_fields(self):
+ """Test that required fields are enforced"""
+ with pytest.raises(ValidationError):
+ LangfuseConfig()
+
+ with pytest.raises(ValidationError):
+ LangfuseConfig(public_key="public")
+
+ with pytest.raises(ValidationError):
+ LangfuseConfig(secret_key="secret")
+
+ def test_host_validation_empty(self):
+ """Test host validation with empty value"""
+ config = LangfuseConfig(public_key="public", secret_key="secret", host="")
+ assert config.host == "https://api.langfuse.com"
diff --git a/api/tests/unit_tests/core/ops/test_langfuse_trace.py b/api/providers/trace/trace-langfuse/tests/unit_tests/test_langfuse_trace.py
similarity index 92%
rename from api/tests/unit_tests/core/ops/test_langfuse_trace.py
rename to api/providers/trace/trace-langfuse/tests/unit_tests/test_langfuse_trace.py
index 017ac8c891..0340ffb669 100644
--- a/api/tests/unit_tests/core/ops/test_langfuse_trace.py
+++ b/api/providers/trace/trace-langfuse/tests/unit_tests/test_langfuse_trace.py
@@ -4,14 +4,15 @@ from datetime import datetime, timedelta
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
-from core.ops.entities.config_entity import LangfuseConfig
+from dify_trace_langfuse.config import LangfuseConfig
+from dify_trace_langfuse.langfuse_trace import LangFuseDataTrace
+
from core.ops.entities.trace_entity import MessageTraceInfo, WorkflowTraceInfo
-from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace
from graphon.enums import BuiltinNodeTypes
def _create_trace_instance() -> LangFuseDataTrace:
- with patch("core.ops.langfuse_trace.langfuse_trace.Langfuse", autospec=True):
+ with patch("dify_trace_langfuse.langfuse_trace.Langfuse", autospec=True):
return LangFuseDataTrace(
LangfuseConfig(
public_key="public-key",
@@ -116,9 +117,9 @@ class TestLangFuseDataTraceCompletionStartTime:
patch.object(trace, "add_span"),
patch.object(trace, "add_generation") as add_generation,
patch.object(trace, "get_service_account_with_tenant", return_value=MagicMock()),
- patch("core.ops.langfuse_trace.langfuse_trace.db", MagicMock()),
+ patch("dify_trace_langfuse.langfuse_trace.db", MagicMock()),
patch(
- "core.ops.langfuse_trace.langfuse_trace.DifyCoreRepositoryFactory.create_workflow_node_execution_repository",
+ "dify_trace_langfuse.langfuse_trace.DifyCoreRepositoryFactory.create_workflow_node_execution_repository",
return_value=repository,
),
):
diff --git a/api/providers/trace/trace-langsmith/pyproject.toml b/api/providers/trace/trace-langsmith/pyproject.toml
new file mode 100644
index 0000000000..8131952b28
--- /dev/null
+++ b/api/providers/trace/trace-langsmith/pyproject.toml
@@ -0,0 +1,10 @@
+[project]
+name = "dify-trace-langsmith"
+version = "0.0.1"
+dependencies = [
+ "langsmith~=0.7.30",
+]
+description = "Dify ops tracing provider (LangSmith)."
+
+[tool.setuptools.packages.find]
+where = ["src"]
diff --git a/api/core/ops/opik_trace/__init__.py b/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/__init__.py
similarity index 100%
rename from api/core/ops/opik_trace/__init__.py
rename to api/providers/trace/trace-langsmith/src/dify_trace_langsmith/__init__.py
diff --git a/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/config.py b/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/config.py
new file mode 100644
index 0000000000..498b8c5e7e
--- /dev/null
+++ b/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/config.py
@@ -0,0 +1,20 @@
+from pydantic import ValidationInfo, field_validator
+
+from core.ops.entities.config_entity import BaseTracingConfig
+from core.ops.utils import validate_url
+
+
+class LangSmithConfig(BaseTracingConfig):
+ """
+    Model class for LangSmith tracing config.
+ """
+
+ api_key: str
+ project: str
+ endpoint: str = "https://api.smith.langchain.com"
+
+ @field_validator("endpoint")
+ @classmethod
+ def endpoint_validator(cls, v, info: ValidationInfo):
+ # LangSmith only allows HTTPS
+ return validate_url(v, "https://api.smith.langchain.com", allowed_schemes=("https",))
diff --git a/api/core/ops/tencent_trace/__init__.py b/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/entities/__init__.py
similarity index 100%
rename from api/core/ops/tencent_trace/__init__.py
rename to api/providers/trace/trace-langsmith/src/dify_trace_langsmith/entities/__init__.py
diff --git a/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py b/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/entities/langsmith_trace_entity.py
similarity index 100%
rename from api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py
rename to api/providers/trace/trace-langsmith/src/dify_trace_langsmith/entities/langsmith_trace_entity.py
diff --git a/api/core/ops/langsmith_trace/langsmith_trace.py b/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/langsmith_trace.py
similarity index 99%
rename from api/core/ops/langsmith_trace/langsmith_trace.py
rename to api/providers/trace/trace-langsmith/src/dify_trace_langsmith/langsmith_trace.py
index d960038f15..145bd70dbc 100644
--- a/api/core/ops/langsmith_trace/langsmith_trace.py
+++ b/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/langsmith_trace.py
@@ -9,7 +9,6 @@ from langsmith.schemas import RunBase
from sqlalchemy.orm import sessionmaker
from core.ops.base_trace_instance import BaseTraceInstance
-from core.ops.entities.config_entity import LangSmithConfig
from core.ops.entities.trace_entity import (
BaseTraceInfo,
DatasetRetrievalTraceInfo,
@@ -21,13 +20,14 @@ from core.ops.entities.trace_entity import (
TraceTaskName,
WorkflowTraceInfo,
)
-from core.ops.langsmith_trace.entities.langsmith_trace_entity import (
+from core.ops.utils import filter_none_values, generate_dotted_order
+from core.repositories import DifyCoreRepositoryFactory
+from dify_trace_langsmith.config import LangSmithConfig
+from dify_trace_langsmith.entities.langsmith_trace_entity import (
LangSmithRunModel,
LangSmithRunType,
LangSmithRunUpdateModel,
)
-from core.ops.utils import filter_none_values, generate_dotted_order
-from core.repositories import DifyCoreRepositoryFactory
from extensions.ext_database import db
from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey
from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom
diff --git a/api/core/ops/weave_trace/__init__.py b/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/py.typed
similarity index 100%
rename from api/core/ops/weave_trace/__init__.py
rename to api/providers/trace/trace-langsmith/src/dify_trace_langsmith/py.typed
diff --git a/api/tests/unit_tests/core/ops/langsmith_trace/test_langsmith_trace.py b/api/providers/trace/trace-langsmith/tests/unit_tests/langsmith_trace/test_langsmith_trace.py
similarity index 91%
rename from api/tests/unit_tests/core/ops/langsmith_trace/test_langsmith_trace.py
rename to api/providers/trace/trace-langsmith/tests/unit_tests/langsmith_trace/test_langsmith_trace.py
index 34c64c54a1..45e5894e4a 100644
--- a/api/tests/unit_tests/core/ops/langsmith_trace/test_langsmith_trace.py
+++ b/api/providers/trace/trace-langsmith/tests/unit_tests/langsmith_trace/test_langsmith_trace.py
@@ -3,8 +3,14 @@ from datetime import datetime, timedelta
from unittest.mock import MagicMock
import pytest
+from dify_trace_langsmith.config import LangSmithConfig
+from dify_trace_langsmith.entities.langsmith_trace_entity import (
+ LangSmithRunModel,
+ LangSmithRunType,
+ LangSmithRunUpdateModel,
+)
+from dify_trace_langsmith.langsmith_trace import LangSmithDataTrace
-from core.ops.entities.config_entity import LangSmithConfig
from core.ops.entities.trace_entity import (
DatasetRetrievalTraceInfo,
GenerateNameTraceInfo,
@@ -15,12 +21,6 @@ from core.ops.entities.trace_entity import (
TraceTaskName,
WorkflowTraceInfo,
)
-from core.ops.langsmith_trace.entities.langsmith_trace_entity import (
- LangSmithRunModel,
- LangSmithRunType,
- LangSmithRunUpdateModel,
-)
-from core.ops.langsmith_trace.langsmith_trace import LangSmithDataTrace
from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey
from models import EndUser
@@ -38,7 +38,7 @@ def langsmith_config():
def trace_instance(langsmith_config, monkeypatch):
# Mock LangSmith client
mock_client = MagicMock()
- monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.Client", lambda **kwargs: mock_client)
+ monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.Client", lambda **kwargs: mock_client)
instance = LangSmithDataTrace(langsmith_config)
return instance
@@ -46,7 +46,7 @@ def trace_instance(langsmith_config, monkeypatch):
def test_init(langsmith_config, monkeypatch):
mock_client_class = MagicMock()
- monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.Client", mock_client_class)
+ monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.Client", mock_client_class)
monkeypatch.setenv("FILES_URL", "http://test.url")
instance = LangSmithDataTrace(langsmith_config)
@@ -138,8 +138,8 @@ def test_workflow_trace(trace_instance, monkeypatch):
# Mock dependencies
mock_session = MagicMock()
- monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.sessionmaker", lambda bind: lambda: mock_session)
- monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db", MagicMock(engine="engine"))
+ monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.sessionmaker", lambda bind: lambda: mock_session)
+ monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.db", MagicMock(engine="engine"))
# Mock node executions
node_llm = MagicMock()
@@ -188,7 +188,7 @@ def test_workflow_trace(trace_instance, monkeypatch):
mock_factory = MagicMock()
mock_factory.create_workflow_node_execution_repository.return_value = repo
- monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.DifyCoreRepositoryFactory", mock_factory)
+ monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.DifyCoreRepositoryFactory", mock_factory)
monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock())
@@ -252,13 +252,13 @@ def test_workflow_trace_no_start_time(trace_instance, monkeypatch):
)
mock_session = MagicMock()
- monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.sessionmaker", lambda bind: lambda: mock_session)
- monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db", MagicMock(engine="engine"))
+ monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.sessionmaker", lambda bind: lambda: mock_session)
+ monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.db", MagicMock(engine="engine"))
repo = MagicMock()
repo.get_by_workflow_execution.return_value = []
mock_factory = MagicMock()
mock_factory.create_workflow_node_execution_repository.return_value = repo
- monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.DifyCoreRepositoryFactory", mock_factory)
+ monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.DifyCoreRepositoryFactory", mock_factory)
monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock())
trace_instance.add_run = MagicMock()
@@ -283,8 +283,8 @@ def test_workflow_trace_missing_app_id(trace_instance, monkeypatch):
trace_info.error = ""
mock_session = MagicMock()
- monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.sessionmaker", lambda bind: lambda: mock_session)
- monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db", MagicMock(engine="engine"))
+ monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.sessionmaker", lambda bind: lambda: mock_session)
+ monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.db", MagicMock(engine="engine"))
with pytest.raises(ValueError, match="No app_id found in trace_info metadata"):
trace_instance.workflow_trace(trace_info)
@@ -319,7 +319,7 @@ def test_message_trace(trace_instance, monkeypatch):
# Mock EndUser lookup
mock_end_user = MagicMock(spec=EndUser)
mock_end_user.session_id = "session-id-123"
- monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db.session.get", lambda model, pk: mock_end_user)
+ monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.db.session.get", lambda model, pk: mock_end_user)
trace_instance.add_run = MagicMock()
@@ -567,9 +567,9 @@ def test_workflow_trace_usage_extraction_error(trace_instance, monkeypatch, capl
mock_factory = MagicMock()
mock_factory.create_workflow_node_execution_repository.return_value = repo
- monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.DifyCoreRepositoryFactory", mock_factory)
- monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.sessionmaker", lambda bind: lambda: MagicMock())
- monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db", MagicMock(engine="engine"))
+ monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.DifyCoreRepositoryFactory", mock_factory)
+ monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.sessionmaker", lambda bind: lambda: MagicMock())
+ monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.db", MagicMock(engine="engine"))
monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock())
trace_instance.add_run = MagicMock()
diff --git a/api/providers/trace/trace-langsmith/tests/unit_tests/test_config_entity.py b/api/providers/trace/trace-langsmith/tests/unit_tests/test_config_entity.py
new file mode 100644
index 0000000000..37efaf69cf
--- /dev/null
+++ b/api/providers/trace/trace-langsmith/tests/unit_tests/test_config_entity.py
@@ -0,0 +1,35 @@
+import pytest
+from dify_trace_langsmith.config import LangSmithConfig
+from pydantic import ValidationError
+
+
+class TestLangSmithConfig:
+ """Test cases for LangSmithConfig"""
+
+ def test_valid_config(self):
+ """Test valid LangSmith configuration"""
+ config = LangSmithConfig(api_key="test_key", project="test_project", endpoint="https://custom.smith.com")
+ assert config.api_key == "test_key"
+ assert config.project == "test_project"
+ assert config.endpoint == "https://custom.smith.com"
+
+ def test_default_values(self):
+ """Test default values are set correctly"""
+ config = LangSmithConfig(api_key="key", project="project")
+ assert config.endpoint == "https://api.smith.langchain.com"
+
+ def test_missing_required_fields(self):
+ """Test that required fields are enforced"""
+ with pytest.raises(ValidationError):
+ LangSmithConfig()
+
+ with pytest.raises(ValidationError):
+ LangSmithConfig(api_key="key")
+
+ with pytest.raises(ValidationError):
+ LangSmithConfig(project="project")
+
+ def test_endpoint_validation_https_only(self):
+ """Test endpoint validation only allows HTTPS"""
+ with pytest.raises(ValidationError, match="URL scheme must be one of"):
+ LangSmithConfig(api_key="key", project="project", endpoint="http://insecure.com")
diff --git a/api/providers/trace/trace-mlflow/pyproject.toml b/api/providers/trace/trace-mlflow/pyproject.toml
new file mode 100644
index 0000000000..fad6002944
--- /dev/null
+++ b/api/providers/trace/trace-mlflow/pyproject.toml
@@ -0,0 +1,10 @@
+[project]
+name = "dify-trace-mlflow"
+version = "0.0.1"
+dependencies = [
+ "mlflow-skinny>=3.11.1",
+]
+description = "Dify ops tracing provider (MLflow / Databricks)."
+
+[tool.setuptools.packages.find]
+where = ["src"]
diff --git a/api/core/ops/weave_trace/entities/__init__.py b/api/providers/trace/trace-mlflow/src/dify_trace_mlflow/__init__.py
similarity index 100%
rename from api/core/ops/weave_trace/entities/__init__.py
rename to api/providers/trace/trace-mlflow/src/dify_trace_mlflow/__init__.py
diff --git a/api/providers/trace/trace-mlflow/src/dify_trace_mlflow/config.py b/api/providers/trace/trace-mlflow/src/dify_trace_mlflow/config.py
new file mode 100644
index 0000000000..84914165e3
--- /dev/null
+++ b/api/providers/trace/trace-mlflow/src/dify_trace_mlflow/config.py
@@ -0,0 +1,46 @@
+from pydantic import ValidationInfo, field_validator
+
+from core.ops.entities.config_entity import BaseTracingConfig
+from core.ops.utils import validate_integer_id, validate_url_with_path
+
+
+class MLflowConfig(BaseTracingConfig):
+ """
+ Model class for MLflow tracing config.
+ """
+
+ tracking_uri: str = "http://localhost:5000"
+ experiment_id: str = "0" # Default experiment id in MLflow is 0
+ username: str | None = None
+ password: str | None = None
+
+ @field_validator("tracking_uri")
+ @classmethod
+ def tracking_uri_validator(cls, v, info: ValidationInfo):
+ if isinstance(v, str) and v.startswith("databricks"):
+ raise ValueError(
+ "Please use Databricks tracing config below to record traces to Databricks-managed MLflow instances."
+ )
+ return validate_url_with_path(v, "http://localhost:5000")
+
+ @field_validator("experiment_id")
+ @classmethod
+ def experiment_id_validator(cls, v, info: ValidationInfo):
+ return validate_integer_id(v)
+
+
+class DatabricksConfig(BaseTracingConfig):
+ """
+ Model class for Databricks (Databricks-managed MLflow) tracing config.
+ """
+
+ experiment_id: str
+ host: str
+ client_id: str | None = None
+ client_secret: str | None = None
+ personal_access_token: str | None = None
+
+ @field_validator("experiment_id")
+ @classmethod
+ def experiment_id_validator(cls, v, info: ValidationInfo):
+ return validate_integer_id(v)
diff --git a/api/core/ops/mlflow_trace/mlflow_trace.py b/api/providers/trace/trace-mlflow/src/dify_trace_mlflow/mlflow_trace.py
similarity index 99%
rename from api/core/ops/mlflow_trace/mlflow_trace.py
rename to api/providers/trace/trace-mlflow/src/dify_trace_mlflow/mlflow_trace.py
index 87fcaeabcc..4e4c45a532 100644
--- a/api/core/ops/mlflow_trace/mlflow_trace.py
+++ b/api/providers/trace/trace-mlflow/src/dify_trace_mlflow/mlflow_trace.py
@@ -11,7 +11,6 @@ from mlflow.tracing.provider import detach_span_from_context, set_span_in_contex
from sqlalchemy import select
from core.ops.base_trace_instance import BaseTraceInstance
-from core.ops.entities.config_entity import DatabricksConfig, MLflowConfig
from core.ops.entities.trace_entity import (
BaseTraceInfo,
DatasetRetrievalTraceInfo,
@@ -24,6 +23,7 @@ from core.ops.entities.trace_entity import (
WorkflowTraceInfo,
)
from core.ops.utils import JSON_DICT_ADAPTER
+from dify_trace_mlflow.config import DatabricksConfig, MLflowConfig
from extensions.ext_database import db
from graphon.enums import BuiltinNodeTypes
from models import EndUser
diff --git a/api/providers/trace/trace-mlflow/src/dify_trace_mlflow/py.typed b/api/providers/trace/trace-mlflow/src/dify_trace_mlflow/py.typed
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/tests/unit_tests/core/ops/mlflow_trace/test_mlflow_trace.py b/api/providers/trace/trace-mlflow/tests/unit_tests/mlflow_trace/test_mlflow_trace.py
similarity index 98%
rename from api/tests/unit_tests/core/ops/mlflow_trace/test_mlflow_trace.py
rename to api/providers/trace/trace-mlflow/tests/unit_tests/mlflow_trace/test_mlflow_trace.py
index afc5726ede..20211456e3 100644
--- a/api/tests/unit_tests/core/ops/mlflow_trace/test_mlflow_trace.py
+++ b/api/providers/trace/trace-mlflow/tests/unit_tests/mlflow_trace/test_mlflow_trace.py
@@ -1,4 +1,4 @@
-"""Comprehensive tests for core.ops.mlflow_trace.mlflow_trace module."""
+"""Comprehensive tests for dify_trace_mlflow.mlflow_trace module."""
from __future__ import annotations
@@ -9,8 +9,9 @@ from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
+from dify_trace_mlflow.config import DatabricksConfig, MLflowConfig
+from dify_trace_mlflow.mlflow_trace import MLflowDataTrace, datetime_to_nanoseconds
-from core.ops.entities.config_entity import DatabricksConfig, MLflowConfig
from core.ops.entities.trace_entity import (
DatasetRetrievalTraceInfo,
GenerateNameTraceInfo,
@@ -20,7 +21,6 @@ from core.ops.entities.trace_entity import (
ToolTraceInfo,
WorkflowTraceInfo,
)
-from core.ops.mlflow_trace.mlflow_trace import MLflowDataTrace, datetime_to_nanoseconds
from graphon.enums import BuiltinNodeTypes
# ── Helpers ──────────────────────────────────────────────────────────────────
@@ -179,7 +179,7 @@ def _make_node(**overrides):
@pytest.fixture
def mock_mlflow():
- with patch("core.ops.mlflow_trace.mlflow_trace.mlflow") as mock:
+ with patch("dify_trace_mlflow.mlflow_trace.mlflow") as mock:
yield mock
@@ -187,10 +187,10 @@ def mock_mlflow():
def mock_tracing():
"""Patch all MLflow tracing functions used by the module."""
with (
- patch("core.ops.mlflow_trace.mlflow_trace.start_span_no_context") as mock_start,
- patch("core.ops.mlflow_trace.mlflow_trace.update_current_trace") as mock_update,
- patch("core.ops.mlflow_trace.mlflow_trace.set_span_in_context") as mock_set,
- patch("core.ops.mlflow_trace.mlflow_trace.detach_span_from_context") as mock_detach,
+ patch("dify_trace_mlflow.mlflow_trace.start_span_no_context") as mock_start,
+ patch("dify_trace_mlflow.mlflow_trace.update_current_trace") as mock_update,
+ patch("dify_trace_mlflow.mlflow_trace.set_span_in_context") as mock_set,
+ patch("dify_trace_mlflow.mlflow_trace.detach_span_from_context") as mock_detach,
):
yield {
"start": mock_start,
@@ -202,7 +202,7 @@ def mock_tracing():
@pytest.fixture
def mock_db():
- with patch("core.ops.mlflow_trace.mlflow_trace.db") as mock:
+ with patch("dify_trace_mlflow.mlflow_trace.db") as mock:
yield mock
diff --git a/api/providers/trace/trace-opik/pyproject.toml b/api/providers/trace/trace-opik/pyproject.toml
new file mode 100644
index 0000000000..874997168e
--- /dev/null
+++ b/api/providers/trace/trace-opik/pyproject.toml
@@ -0,0 +1,10 @@
+[project]
+name = "dify-trace-opik"
+version = "0.0.1"
+dependencies = [
+ "opik~=1.11.2",
+]
+description = "Dify ops tracing provider (Opik)."
+
+[tool.setuptools.packages.find]
+where = ["src"]
diff --git a/api/providers/trace/trace-opik/src/dify_trace_opik/__init__.py b/api/providers/trace/trace-opik/src/dify_trace_opik/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/providers/trace/trace-opik/src/dify_trace_opik/config.py b/api/providers/trace/trace-opik/src/dify_trace_opik/config.py
new file mode 100644
index 0000000000..c16ff1d903
--- /dev/null
+++ b/api/providers/trace/trace-opik/src/dify_trace_opik/config.py
@@ -0,0 +1,25 @@
+from pydantic import ValidationInfo, field_validator
+
+from core.ops.entities.config_entity import BaseTracingConfig
+from core.ops.utils import validate_url_with_path
+
+
+class OpikConfig(BaseTracingConfig):
+ """
+ Model class for Opik tracing config.
+ """
+
+ api_key: str | None = None
+ project: str | None = None
+ workspace: str | None = None
+ url: str = "https://www.comet.com/opik/api/"
+
+ @field_validator("project")
+ @classmethod
+ def project_validator(cls, v, info: ValidationInfo):
+ return cls.validate_project_field(v, "Default Project")
+
+ @field_validator("url")
+ @classmethod
+ def url_validator(cls, v, info: ValidationInfo):
+ return validate_url_with_path(v, "https://www.comet.com/opik/api/", required_suffix="/api/")
diff --git a/api/core/ops/opik_trace/opik_trace.py b/api/providers/trace/trace-opik/src/dify_trace_opik/opik_trace.py
similarity index 99%
rename from api/core/ops/opik_trace/opik_trace.py
rename to api/providers/trace/trace-opik/src/dify_trace_opik/opik_trace.py
index 672efe45bd..2d124ac989 100644
--- a/api/core/ops/opik_trace/opik_trace.py
+++ b/api/providers/trace/trace-opik/src/dify_trace_opik/opik_trace.py
@@ -10,7 +10,6 @@ from opik.id_helpers import uuid4_to_uuid7
from sqlalchemy.orm import sessionmaker
from core.ops.base_trace_instance import BaseTraceInstance
-from core.ops.entities.config_entity import OpikConfig
from core.ops.entities.trace_entity import (
BaseTraceInfo,
DatasetRetrievalTraceInfo,
@@ -23,6 +22,7 @@ from core.ops.entities.trace_entity import (
WorkflowTraceInfo,
)
from core.repositories import DifyCoreRepositoryFactory
+from dify_trace_opik.config import OpikConfig
from extensions.ext_database import db
from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey
from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom
diff --git a/api/providers/trace/trace-opik/src/dify_trace_opik/py.typed b/api/providers/trace/trace-opik/src/dify_trace_opik/py.typed
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/tests/unit_tests/core/ops/opik_trace/test_opik_trace.py b/api/providers/trace/trace-opik/tests/unit_tests/opik_trace/test_opik_trace.py
similarity index 93%
rename from api/tests/unit_tests/core/ops/opik_trace/test_opik_trace.py
rename to api/providers/trace/trace-opik/tests/unit_tests/opik_trace/test_opik_trace.py
index c02ac413f2..eefed3c78c 100644
--- a/api/tests/unit_tests/core/ops/opik_trace/test_opik_trace.py
+++ b/api/providers/trace/trace-opik/tests/unit_tests/opik_trace/test_opik_trace.py
@@ -5,8 +5,9 @@ from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
+from dify_trace_opik.config import OpikConfig
+from dify_trace_opik.opik_trace import OpikDataTrace, prepare_opik_uuid, wrap_dict, wrap_metadata
-from core.ops.entities.config_entity import OpikConfig
from core.ops.entities.trace_entity import (
DatasetRetrievalTraceInfo,
GenerateNameTraceInfo,
@@ -17,7 +18,6 @@ from core.ops.entities.trace_entity import (
TraceTaskName,
WorkflowTraceInfo,
)
-from core.ops.opik_trace.opik_trace import OpikDataTrace, prepare_opik_uuid, wrap_dict, wrap_metadata
from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey
from models import EndUser
from models.enums import MessageStatus
@@ -37,7 +37,7 @@ def opik_config():
@pytest.fixture
def trace_instance(opik_config, monkeypatch):
mock_client = MagicMock()
- monkeypatch.setattr("core.ops.opik_trace.opik_trace.Opik", lambda **kwargs: mock_client)
+ monkeypatch.setattr("dify_trace_opik.opik_trace.Opik", lambda **kwargs: mock_client)
instance = OpikDataTrace(opik_config)
return instance
@@ -67,7 +67,7 @@ def test_prepare_opik_uuid():
def test_init(opik_config, monkeypatch):
mock_opik = MagicMock()
- monkeypatch.setattr("core.ops.opik_trace.opik_trace.Opik", mock_opik)
+ monkeypatch.setattr("dify_trace_opik.opik_trace.Opik", mock_opik)
monkeypatch.setenv("FILES_URL", "http://test.url")
instance = OpikDataTrace(opik_config)
@@ -166,8 +166,8 @@ def test_workflow_trace_with_message_id(trace_instance, monkeypatch):
)
mock_session = MagicMock()
- monkeypatch.setattr("core.ops.opik_trace.opik_trace.sessionmaker", lambda bind: lambda: mock_session)
- monkeypatch.setattr("core.ops.opik_trace.opik_trace.db", MagicMock(engine="engine"))
+ monkeypatch.setattr("dify_trace_opik.opik_trace.sessionmaker", lambda bind: lambda: mock_session)
+ monkeypatch.setattr("dify_trace_opik.opik_trace.db", MagicMock(engine="engine"))
node_llm = MagicMock()
node_llm.id = LLM_NODE_ID
@@ -203,7 +203,7 @@ def test_workflow_trace_with_message_id(trace_instance, monkeypatch):
mock_factory = MagicMock()
mock_factory.create_workflow_node_execution_repository.return_value = repo
- monkeypatch.setattr("core.ops.opik_trace.opik_trace.DifyCoreRepositoryFactory", mock_factory)
+ monkeypatch.setattr("dify_trace_opik.opik_trace.DifyCoreRepositoryFactory", mock_factory)
monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock())
@@ -250,13 +250,13 @@ def test_workflow_trace_no_message_id(trace_instance, monkeypatch):
error="",
)
- monkeypatch.setattr("core.ops.opik_trace.opik_trace.sessionmaker", lambda bind: lambda: MagicMock())
- monkeypatch.setattr("core.ops.opik_trace.opik_trace.db", MagicMock(engine="engine"))
+ monkeypatch.setattr("dify_trace_opik.opik_trace.sessionmaker", lambda bind: lambda: MagicMock())
+ monkeypatch.setattr("dify_trace_opik.opik_trace.db", MagicMock(engine="engine"))
repo = MagicMock()
repo.get_by_workflow_execution.return_value = []
mock_factory = MagicMock()
mock_factory.create_workflow_node_execution_repository.return_value = repo
- monkeypatch.setattr("core.ops.opik_trace.opik_trace.DifyCoreRepositoryFactory", mock_factory)
+ monkeypatch.setattr("dify_trace_opik.opik_trace.DifyCoreRepositoryFactory", mock_factory)
monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock())
trace_instance.add_trace = MagicMock()
@@ -286,8 +286,8 @@ def test_workflow_trace_missing_app_id(trace_instance, monkeypatch):
workflow_app_log_id="339760b2-4b94-4532-8c81-133a97e4680e",
error="",
)
- monkeypatch.setattr("core.ops.opik_trace.opik_trace.sessionmaker", lambda bind: lambda: MagicMock())
- monkeypatch.setattr("core.ops.opik_trace.opik_trace.db", MagicMock(engine="engine"))
+ monkeypatch.setattr("dify_trace_opik.opik_trace.sessionmaker", lambda bind: lambda: MagicMock())
+ monkeypatch.setattr("dify_trace_opik.opik_trace.db", MagicMock(engine="engine"))
with pytest.raises(ValueError, match="No app_id found in trace_info metadata"):
trace_instance.workflow_trace(trace_info)
@@ -373,7 +373,7 @@ def test_message_trace_with_end_user(trace_instance, monkeypatch):
mock_end_user = MagicMock(spec=EndUser)
mock_end_user.session_id = "session-id-123"
- monkeypatch.setattr("core.ops.opik_trace.opik_trace.db.session.get", lambda model, pk: mock_end_user)
+ monkeypatch.setattr("dify_trace_opik.opik_trace.db.session.get", lambda model, pk: mock_end_user)
trace_instance.add_trace = MagicMock(return_value=MagicMock(id="trace_id_2"))
trace_instance.add_span = MagicMock()
@@ -658,9 +658,9 @@ def test_workflow_trace_usage_extraction_error_fixed(trace_instance, monkeypatch
repo.get_by_workflow_execution.return_value = [node]
mock_factory = MagicMock()
mock_factory.create_workflow_node_execution_repository.return_value = repo
- monkeypatch.setattr("core.ops.opik_trace.opik_trace.DifyCoreRepositoryFactory", mock_factory)
- monkeypatch.setattr("core.ops.opik_trace.opik_trace.sessionmaker", lambda bind: lambda: MagicMock())
- monkeypatch.setattr("core.ops.opik_trace.opik_trace.db", MagicMock(engine="engine"))
+ monkeypatch.setattr("dify_trace_opik.opik_trace.DifyCoreRepositoryFactory", mock_factory)
+ monkeypatch.setattr("dify_trace_opik.opik_trace.sessionmaker", lambda bind: lambda: MagicMock())
+ monkeypatch.setattr("dify_trace_opik.opik_trace.db", MagicMock(engine="engine"))
monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock())
trace_instance.add_trace = MagicMock()
diff --git a/api/providers/trace/trace-opik/tests/unit_tests/test_config_entity.py b/api/providers/trace/trace-opik/tests/unit_tests/test_config_entity.py
new file mode 100644
index 0000000000..5a54b70bba
--- /dev/null
+++ b/api/providers/trace/trace-opik/tests/unit_tests/test_config_entity.py
@@ -0,0 +1,48 @@
+import pytest
+from dify_trace_opik.config import OpikConfig
+from pydantic import ValidationError
+
+
+class TestOpikConfig:
+ """Test cases for OpikConfig"""
+
+ def test_valid_config(self):
+ """Test valid Opik configuration"""
+ config = OpikConfig(
+ api_key="test_key",
+ project="test_project",
+ workspace="test_workspace",
+ url="https://custom.comet.com/opik/api/",
+ )
+ assert config.api_key == "test_key"
+ assert config.project == "test_project"
+ assert config.workspace == "test_workspace"
+ assert config.url == "https://custom.comet.com/opik/api/"
+
+ def test_default_values(self):
+ """Test default values are set correctly"""
+ config = OpikConfig()
+ assert config.api_key is None
+ assert config.project is None
+ assert config.workspace is None
+ assert config.url == "https://www.comet.com/opik/api/"
+
+ def test_project_validation_empty(self):
+ """Test project validation with empty value"""
+ config = OpikConfig(project="")
+ assert config.project == "Default Project"
+
+ def test_url_validation_empty(self):
+ """Test URL validation with empty value"""
+ config = OpikConfig(url="")
+ assert config.url == "https://www.comet.com/opik/api/"
+
+ def test_url_validation_missing_suffix(self):
+ """Test URL validation requires /api/ suffix"""
+ with pytest.raises(ValidationError, match="URL should end with /api/"):
+ OpikConfig(url="https://custom.comet.com/opik/")
+
+ def test_url_validation_invalid_scheme(self):
+ """Test URL validation rejects invalid schemes"""
+ with pytest.raises(ValidationError, match="URL must start with https:// or http://"):
+ OpikConfig(url="ftp://custom.comet.com/opik/api/")
diff --git a/api/tests/unit_tests/core/ops/test_opik_trace.py b/api/providers/trace/trace-opik/tests/unit_tests/test_opik_trace.py
similarity index 94%
rename from api/tests/unit_tests/core/ops/test_opik_trace.py
rename to api/providers/trace/trace-opik/tests/unit_tests/test_opik_trace.py
index ad9d0846be..fba290f5b8 100644
--- a/api/tests/unit_tests/core/ops/test_opik_trace.py
+++ b/api/providers/trace/trace-opik/tests/unit_tests/test_opik_trace.py
@@ -14,8 +14,9 @@ import uuid
from datetime import datetime
from unittest.mock import MagicMock, patch
+from dify_trace_opik.opik_trace import OpikDataTrace, _seed_to_uuid4, prepare_opik_uuid
+
from core.ops.entities.trace_entity import TraceTaskName, WorkflowTraceInfo
-from core.ops.opik_trace.opik_trace import OpikDataTrace, _seed_to_uuid4, prepare_opik_uuid
# A stable UUID4 used as the workflow_run_id throughout all tests.
_WORKFLOW_RUN_ID = "a3f1b2c4-d5e6-4f78-9a0b-c1d2e3f4a5b6"
@@ -56,8 +57,8 @@ def _make_workflow_trace_info(
def _make_opik_trace_instance() -> OpikDataTrace:
"""Construct an OpikDataTrace with the Opik SDK client mocked out."""
- with patch("core.ops.opik_trace.opik_trace.Opik"):
- from core.ops.entities.config_entity import OpikConfig
+ with patch("dify_trace_opik.opik_trace.Opik"):
+ from dify_trace_opik.config import OpikConfig
config = OpikConfig(api_key="key", project="test-project", url="https://www.comet.com/opik/api/")
instance = OpikDataTrace(config)
@@ -133,10 +134,10 @@ class TestWorkflowTraceWithoutMessageId:
fake_repo.get_by_workflow_execution.return_value = node_executions or []
with (
- patch("core.ops.opik_trace.opik_trace.db") as mock_db,
- patch("core.ops.opik_trace.opik_trace.sessionmaker"),
+ patch("dify_trace_opik.opik_trace.db") as mock_db,
+ patch("dify_trace_opik.opik_trace.sessionmaker"),
patch(
- "core.ops.opik_trace.opik_trace.DifyCoreRepositoryFactory.create_workflow_node_execution_repository",
+ "dify_trace_opik.opik_trace.DifyCoreRepositoryFactory.create_workflow_node_execution_repository",
return_value=fake_repo,
),
):
@@ -265,10 +266,10 @@ class TestWorkflowTraceWithMessageId:
fake_repo.get_by_workflow_execution.return_value = node_executions or []
with (
- patch("core.ops.opik_trace.opik_trace.db") as mock_db,
- patch("core.ops.opik_trace.opik_trace.sessionmaker"),
+ patch("dify_trace_opik.opik_trace.db") as mock_db,
+ patch("dify_trace_opik.opik_trace.sessionmaker"),
patch(
- "core.ops.opik_trace.opik_trace.DifyCoreRepositoryFactory.create_workflow_node_execution_repository",
+ "dify_trace_opik.opik_trace.DifyCoreRepositoryFactory.create_workflow_node_execution_repository",
return_value=fake_repo,
),
):
diff --git a/api/providers/trace/trace-tencent/pyproject.toml b/api/providers/trace/trace-tencent/pyproject.toml
new file mode 100644
index 0000000000..eab06fc708
--- /dev/null
+++ b/api/providers/trace/trace-tencent/pyproject.toml
@@ -0,0 +1,14 @@
+[project]
+name = "dify-trace-tencent"
+version = "0.0.1"
+dependencies = [
+ # versions inherited from parent
+ "opentelemetry-api",
+ "opentelemetry-exporter-otlp-proto-grpc",
+ "opentelemetry-sdk",
+ "opentelemetry-semantic-conventions",
+]
+description = "Dify ops tracing provider (Tencent APM)."
+
+[tool.setuptools.packages.find]
+where = ["src"]
diff --git a/api/providers/trace/trace-tencent/src/dify_trace_tencent/__init__.py b/api/providers/trace/trace-tencent/src/dify_trace_tencent/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/ops/tencent_trace/client.py b/api/providers/trace/trace-tencent/src/dify_trace_tencent/client.py
similarity index 100%
rename from api/core/ops/tencent_trace/client.py
rename to api/providers/trace/trace-tencent/src/dify_trace_tencent/client.py
diff --git a/api/providers/trace/trace-tencent/src/dify_trace_tencent/config.py b/api/providers/trace/trace-tencent/src/dify_trace_tencent/config.py
new file mode 100644
index 0000000000..398e6c55a8
--- /dev/null
+++ b/api/providers/trace/trace-tencent/src/dify_trace_tencent/config.py
@@ -0,0 +1,30 @@
+from pydantic import ValidationInfo, field_validator
+
+from core.ops.entities.config_entity import BaseTracingConfig
+
+
+class TencentConfig(BaseTracingConfig):
+ """
+ Tencent APM tracing config
+ """
+
+ token: str
+ endpoint: str
+ service_name: str
+
+ @field_validator("token")
+ @classmethod
+ def token_validator(cls, v, info: ValidationInfo):
+ if not v or v.strip() == "":
+ raise ValueError("Token cannot be empty")
+ return v
+
+ @field_validator("endpoint")
+ @classmethod
+ def endpoint_validator(cls, v, info: ValidationInfo):
+ return cls.validate_endpoint_url(v, "https://apm.tencentcloudapi.com")
+
+ @field_validator("service_name")
+ @classmethod
+ def service_name_validator(cls, v, info: ValidationInfo):
+ return cls.validate_project_field(v, "dify_app")
diff --git a/api/core/ops/tencent_trace/entities/__init__.py b/api/providers/trace/trace-tencent/src/dify_trace_tencent/entities/__init__.py
similarity index 100%
rename from api/core/ops/tencent_trace/entities/__init__.py
rename to api/providers/trace/trace-tencent/src/dify_trace_tencent/entities/__init__.py
diff --git a/api/core/ops/tencent_trace/entities/semconv.py b/api/providers/trace/trace-tencent/src/dify_trace_tencent/entities/semconv.py
similarity index 100%
rename from api/core/ops/tencent_trace/entities/semconv.py
rename to api/providers/trace/trace-tencent/src/dify_trace_tencent/entities/semconv.py
diff --git a/api/core/ops/tencent_trace/entities/tencent_trace_entity.py b/api/providers/trace/trace-tencent/src/dify_trace_tencent/entities/tencent_trace_entity.py
similarity index 100%
rename from api/core/ops/tencent_trace/entities/tencent_trace_entity.py
rename to api/providers/trace/trace-tencent/src/dify_trace_tencent/entities/tencent_trace_entity.py
diff --git a/api/providers/trace/trace-tencent/src/dify_trace_tencent/py.typed b/api/providers/trace/trace-tencent/src/dify_trace_tencent/py.typed
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/ops/tencent_trace/span_builder.py b/api/providers/trace/trace-tencent/src/dify_trace_tencent/span_builder.py
similarity index 98%
rename from api/core/ops/tencent_trace/span_builder.py
rename to api/providers/trace/trace-tencent/src/dify_trace_tencent/span_builder.py
index 36878dc58f..763a85ffd7 100644
--- a/api/core/ops/tencent_trace/span_builder.py
+++ b/api/providers/trace/trace-tencent/src/dify_trace_tencent/span_builder.py
@@ -14,7 +14,8 @@ from core.ops.entities.trace_entity import (
ToolTraceInfo,
WorkflowTraceInfo,
)
-from core.ops.tencent_trace.entities.semconv import (
+from core.rag.models.document import Document
+from dify_trace_tencent.entities.semconv import (
GEN_AI_COMPLETION,
GEN_AI_FRAMEWORK,
GEN_AI_IS_ENTRY,
@@ -38,9 +39,8 @@ from core.ops.tencent_trace.entities.semconv import (
TOOL_PARAMETERS,
GenAISpanKind,
)
-from core.ops.tencent_trace.entities.tencent_trace_entity import SpanData
-from core.ops.tencent_trace.utils import TencentTraceUtils
-from core.rag.models.document import Document
+from dify_trace_tencent.entities.tencent_trace_entity import SpanData
+from dify_trace_tencent.utils import TencentTraceUtils
from graphon.entities import WorkflowNodeExecution
from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
diff --git a/api/core/ops/tencent_trace/tencent_trace.py b/api/providers/trace/trace-tencent/src/dify_trace_tencent/tencent_trace.py
similarity index 94%
rename from api/core/ops/tencent_trace/tencent_trace.py
rename to api/providers/trace/trace-tencent/src/dify_trace_tencent/tencent_trace.py
index d681b9da80..a8c480e4a5 100644
--- a/api/core/ops/tencent_trace/tencent_trace.py
+++ b/api/providers/trace/trace-tencent/src/dify_trace_tencent/tencent_trace.py
@@ -1,14 +1,12 @@
-"""
-Tencent APM tracing implementation with separated concerns
-"""
+"""Tencent APM tracing with idempotent client cleanup."""
+import inspect
import logging
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
from core.ops.base_trace_instance import BaseTraceInstance
-from core.ops.entities.config_entity import TencentConfig
from core.ops.entities.trace_entity import (
BaseTraceInfo,
DatasetRetrievalTraceInfo,
@@ -19,11 +17,12 @@ from core.ops.entities.trace_entity import (
ToolTraceInfo,
WorkflowTraceInfo,
)
-from core.ops.tencent_trace.client import TencentTraceClient
-from core.ops.tencent_trace.entities.tencent_trace_entity import SpanData
-from core.ops.tencent_trace.span_builder import TencentSpanBuilder
-from core.ops.tencent_trace.utils import TencentTraceUtils
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
+from dify_trace_tencent.client import TencentTraceClient
+from dify_trace_tencent.config import TencentConfig
+from dify_trace_tencent.entities.tencent_trace_entity import SpanData
+from dify_trace_tencent.span_builder import TencentSpanBuilder
+from dify_trace_tencent.utils import TencentTraceUtils
from extensions.ext_database import db
from graphon.entities.workflow_node_execution import (
WorkflowNodeExecution,
@@ -38,10 +37,18 @@ class TencentDataTrace(BaseTraceInstance):
"""
Tencent APM trace implementation with single responsibility principle.
Acts as a coordinator that delegates specific tasks to specialized classes.
+
+ The instance owns a long-lived ``TencentTraceClient``. Cleanup may happen
+ explicitly in tests or implicitly during garbage collection, so shutdown
+ must be safe to call multiple times.
"""
+ trace_client: TencentTraceClient
+ _closed: bool
+
def __init__(self, tencent_config: TencentConfig):
super().__init__(tencent_config)
+ self._closed = False
self.trace_client = TencentTraceClient(
service_name=tencent_config.service_name,
endpoint=tencent_config.endpoint,
@@ -513,10 +520,25 @@ class TencentDataTrace(BaseTraceInstance):
except Exception:
logger.debug("[Tencent APM] Failed to record message trace duration")
- def __del__(self):
- """Ensure proper cleanup on garbage collection."""
+ def close(self) -> None:
+ """Synchronously and idempotently shutdown the underlying trace client."""
+ if getattr(self, "_closed", False):
+ return
+
+ self._closed = True
+ trace_client = getattr(self, "trace_client", None)
+ if trace_client is None:
+ return
+
try:
- if hasattr(self, "trace_client"):
- self.trace_client.shutdown()
+ shutdown_result = trace_client.shutdown()
+ if inspect.isawaitable(shutdown_result):
+ close_awaitable = getattr(shutdown_result, "close", None)
+ if callable(close_awaitable):
+ close_awaitable()
except Exception:
logger.exception("[Tencent APM] Failed to shutdown trace client during cleanup")
+
+ def __del__(self):
+ """Ensure best-effort cleanup on garbage collection without retrying shutdown."""
+ self.close()
diff --git a/api/core/ops/tencent_trace/utils.py b/api/providers/trace/trace-tencent/src/dify_trace_tencent/utils.py
similarity index 100%
rename from api/core/ops/tencent_trace/utils.py
rename to api/providers/trace/trace-tencent/src/dify_trace_tencent/utils.py
diff --git a/api/tests/unit_tests/core/ops/tencent_trace/test_client.py b/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_client.py
similarity index 98%
rename from api/tests/unit_tests/core/ops/tencent_trace/test_client.py
rename to api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_client.py
index 870c18e53e..1e656e2462 100644
--- a/api/tests/unit_tests/core/ops/tencent_trace/test_client.py
+++ b/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_client.py
@@ -8,13 +8,12 @@ from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
+from dify_trace_tencent import client as client_module
+from dify_trace_tencent.client import TencentTraceClient, _get_opentelemetry_sdk_version
+from dify_trace_tencent.entities.tencent_trace_entity import SpanData
from opentelemetry.sdk.trace import Event
from opentelemetry.trace import Status, StatusCode
-from core.ops.tencent_trace import client as client_module
-from core.ops.tencent_trace.client import TencentTraceClient, _get_opentelemetry_sdk_version
-from core.ops.tencent_trace.entities.tencent_trace_entity import SpanData
-
metric_reader_instances: list[DummyMetricReader] = []
meter_provider_instances: list[DummyMeterProvider] = []
diff --git a/api/tests/unit_tests/core/ops/tencent_trace/test_span_builder.py b/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_span_builder.py
similarity index 89%
rename from api/tests/unit_tests/core/ops/tencent_trace/test_span_builder.py
rename to api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_span_builder.py
index 6113e5c6c8..e850a801f3 100644
--- a/api/tests/unit_tests/core/ops/tencent_trace/test_span_builder.py
+++ b/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_span_builder.py
@@ -1,15 +1,7 @@
from datetime import datetime
from unittest.mock import MagicMock, patch
-from opentelemetry.trace import StatusCode
-
-from core.ops.entities.trace_entity import (
- DatasetRetrievalTraceInfo,
- MessageTraceInfo,
- ToolTraceInfo,
- WorkflowTraceInfo,
-)
-from core.ops.tencent_trace.entities.semconv import (
+from dify_trace_tencent.entities.semconv import (
GEN_AI_IS_ENTRY,
GEN_AI_IS_STREAMING_REQUEST,
GEN_AI_MODEL_NAME,
@@ -23,7 +15,15 @@ from core.ops.tencent_trace.entities.semconv import (
TOOL_PARAMETERS,
GenAISpanKind,
)
-from core.ops.tencent_trace.span_builder import TencentSpanBuilder
+from dify_trace_tencent.span_builder import TencentSpanBuilder
+from opentelemetry.trace import StatusCode
+
+from core.ops.entities.trace_entity import (
+ DatasetRetrievalTraceInfo,
+ MessageTraceInfo,
+ ToolTraceInfo,
+ WorkflowTraceInfo,
+)
from core.rag.models.document import Document
from graphon.entities import WorkflowNodeExecution
from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
@@ -31,7 +31,7 @@ from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutio
class TestTencentSpanBuilder:
def test_get_time_nanoseconds(self):
- with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_datetime_to_nanoseconds") as mock_convert:
+ with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_datetime_to_nanoseconds") as mock_convert:
mock_convert.return_value = 123456789
dt = datetime.now()
result = TencentSpanBuilder._get_time_nanoseconds(dt)
@@ -48,7 +48,7 @@ class TestTencentSpanBuilder:
trace_info.workflow_run_outputs = {"answer": "world"}
trace_info.metadata = {"conversation_id": "conv_id"}
- with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
+ with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
mock_convert_id.side_effect = [1, 2] # workflow_span_id, message_span_id
with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
spans = TencentSpanBuilder.build_workflow_spans(trace_info, 123, "user_1")
@@ -70,7 +70,7 @@ class TestTencentSpanBuilder:
trace_info.workflow_run_outputs = {}
trace_info.metadata = {} # No conversation_id
- with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
+ with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
mock_convert_id.return_value = 1
with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
spans = TencentSpanBuilder.build_workflow_spans(trace_info, 123, "user_1")
@@ -98,7 +98,7 @@ class TestTencentSpanBuilder:
}
node_execution.outputs = {"text": "world"}
- with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
+ with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
mock_convert_id.return_value = 456
with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
span = TencentSpanBuilder.build_workflow_llm_span(123, 1, trace_info, node_execution)
@@ -123,7 +123,7 @@ class TestTencentSpanBuilder:
"usage": {"prompt_tokens": 15, "completion_tokens": 25, "total_tokens": 40},
}
- with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
+ with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
mock_convert_id.return_value = 456
with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
span = TencentSpanBuilder.build_workflow_llm_span(123, 1, trace_info, node_execution)
@@ -142,7 +142,7 @@ class TestTencentSpanBuilder:
trace_info.metadata = {"conversation_id": "conv_id"}
trace_info.is_streaming_request = True
- with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
+ with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
mock_convert_id.return_value = 789
with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
span = TencentSpanBuilder.build_message_span(trace_info, 123, "user_1")
@@ -162,7 +162,7 @@ class TestTencentSpanBuilder:
trace_info.metadata = {}
trace_info.is_streaming_request = False
- with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
+ with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
mock_convert_id.return_value = 789
with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
span = TencentSpanBuilder.build_message_span(trace_info, 123, "user_1")
@@ -182,7 +182,7 @@ class TestTencentSpanBuilder:
trace_info.tool_inputs = {"i": 2}
trace_info.tool_outputs = "result"
- with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
+ with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
mock_convert_id.return_value = 101
with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
span = TencentSpanBuilder.build_tool_span(trace_info, 123, 1)
@@ -204,7 +204,7 @@ class TestTencentSpanBuilder:
)
trace_info.documents = [doc]
- with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
+ with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
mock_convert_id.return_value = 202
with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
span = TencentSpanBuilder.build_retrieval_span(trace_info, 123, 1)
@@ -222,7 +222,7 @@ class TestTencentSpanBuilder:
trace_info.end_time = datetime.now()
trace_info.documents = []
- with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
+ with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
mock_convert_id.return_value = 202
with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
span = TencentSpanBuilder.build_retrieval_span(trace_info, 123, 1)
@@ -264,7 +264,7 @@ class TestTencentSpanBuilder:
node_execution.created_at = datetime.now()
node_execution.finished_at = datetime.now()
- with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
+ with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
mock_convert_id.return_value = 303
with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
span = TencentSpanBuilder.build_workflow_retrieval_span(123, 1, trace_info, node_execution)
@@ -286,7 +286,7 @@ class TestTencentSpanBuilder:
node_execution.created_at = datetime.now()
node_execution.finished_at = datetime.now()
- with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
+ with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
mock_convert_id.return_value = 303
with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
span = TencentSpanBuilder.build_workflow_retrieval_span(123, 1, trace_info, node_execution)
@@ -307,7 +307,7 @@ class TestTencentSpanBuilder:
node_execution.created_at = datetime.now()
node_execution.finished_at = datetime.now()
- with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
+ with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
mock_convert_id.return_value = 404
with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
span = TencentSpanBuilder.build_workflow_tool_span(123, 1, trace_info, node_execution)
@@ -329,7 +329,7 @@ class TestTencentSpanBuilder:
node_execution.created_at = datetime.now()
node_execution.finished_at = datetime.now()
- with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
+ with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
mock_convert_id.return_value = 404
with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
span = TencentSpanBuilder.build_workflow_tool_span(123, 1, trace_info, node_execution)
@@ -350,7 +350,7 @@ class TestTencentSpanBuilder:
node_execution.created_at = datetime.now()
node_execution.finished_at = datetime.now()
- with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
+ with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
mock_convert_id.return_value = 505
with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
span = TencentSpanBuilder.build_workflow_task_span(123, 1, trace_info, node_execution)
diff --git a/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace.py b/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_tencent_trace.py
similarity index 86%
rename from api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace.py
rename to api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_tencent_trace.py
index 7afd0b824a..54524b09ca 100644
--- a/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace.py
+++ b/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_tencent_trace.py
@@ -1,9 +1,12 @@
+import gc
import logging
-from unittest.mock import MagicMock, patch
+import warnings
+from unittest.mock import AsyncMock, MagicMock, patch
import pytest
+from dify_trace_tencent.config import TencentConfig
+from dify_trace_tencent.tencent_trace import TencentDataTrace
-from core.ops.entities.config_entity import TencentConfig
from core.ops.entities.trace_entity import (
DatasetRetrievalTraceInfo,
GenerateNameTraceInfo,
@@ -13,7 +16,6 @@ from core.ops.entities.trace_entity import (
ToolTraceInfo,
WorkflowTraceInfo,
)
-from core.ops.tencent_trace.tencent_trace import TencentDataTrace
from graphon.entities import WorkflowNodeExecution
from graphon.enums import BuiltinNodeTypes
from models import Account, App, TenantAccountJoin
@@ -28,19 +30,19 @@ def tencent_config():
@pytest.fixture
def mock_trace_client():
- with patch("core.ops.tencent_trace.tencent_trace.TencentTraceClient") as mock:
+ with patch("dify_trace_tencent.tencent_trace.TencentTraceClient") as mock:
yield mock
@pytest.fixture
def mock_span_builder():
- with patch("core.ops.tencent_trace.tencent_trace.TencentSpanBuilder") as mock:
+ with patch("dify_trace_tencent.tencent_trace.TencentSpanBuilder") as mock:
yield mock
@pytest.fixture
def mock_trace_utils():
- with patch("core.ops.tencent_trace.tencent_trace.TencentTraceUtils") as mock:
+ with patch("dify_trace_tencent.tencent_trace.TencentTraceUtils") as mock:
yield mock
@@ -198,9 +200,9 @@ class TestTencentDataTrace:
trace_info.workflow_run_id = "run-id"
with patch(
- "core.ops.tencent_trace.tencent_trace.TencentTraceUtils.convert_to_trace_id", side_effect=Exception("error")
+ "dify_trace_tencent.tencent_trace.TencentTraceUtils.convert_to_trace_id", side_effect=Exception("error")
):
- with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log:
+ with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log:
tencent_data_trace.workflow_trace(trace_info)
mock_log.assert_called_once_with("[Tencent APM] Failed to process workflow trace")
@@ -230,9 +232,9 @@ class TestTencentDataTrace:
trace_info = MagicMock(spec=MessageTraceInfo)
with patch(
- "core.ops.tencent_trace.tencent_trace.TencentTraceUtils.convert_to_trace_id", side_effect=Exception("error")
+ "dify_trace_tencent.tencent_trace.TencentTraceUtils.convert_to_trace_id", side_effect=Exception("error")
):
- with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log:
+ with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log:
tencent_data_trace.message_trace(trace_info)
mock_log.assert_called_once_with("[Tencent APM] Failed to process message trace")
@@ -262,9 +264,9 @@ class TestTencentDataTrace:
trace_info.message_id = "msg-id"
with patch(
- "core.ops.tencent_trace.tencent_trace.TencentTraceUtils.convert_to_span_id", side_effect=Exception("error")
+ "dify_trace_tencent.tencent_trace.TencentTraceUtils.convert_to_span_id", side_effect=Exception("error")
):
- with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log:
+ with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log:
tencent_data_trace.tool_trace(trace_info)
mock_log.assert_called_once_with("[Tencent APM] Failed to process tool trace")
@@ -294,22 +296,22 @@ class TestTencentDataTrace:
trace_info.message_id = "msg-id"
with patch(
- "core.ops.tencent_trace.tencent_trace.TencentTraceUtils.convert_to_span_id", side_effect=Exception("error")
+ "dify_trace_tencent.tencent_trace.TencentTraceUtils.convert_to_span_id", side_effect=Exception("error")
):
- with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log:
+ with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log:
tencent_data_trace.dataset_retrieval_trace(trace_info)
mock_log.assert_called_once_with("[Tencent APM] Failed to process dataset retrieval trace")
def test_suggested_question_trace(self, tencent_data_trace):
trace_info = MagicMock(spec=SuggestedQuestionTraceInfo)
- with patch("core.ops.tencent_trace.tencent_trace.logger.info") as mock_log:
+ with patch("dify_trace_tencent.tencent_trace.logger.info") as mock_log:
tencent_data_trace.suggested_question_trace(trace_info)
mock_log.assert_called_once_with("[Tencent APM] Processing suggested question trace")
def test_suggested_question_trace_exception(self, tencent_data_trace):
trace_info = MagicMock(spec=SuggestedQuestionTraceInfo)
- with patch("core.ops.tencent_trace.tencent_trace.logger.info", side_effect=Exception("error")):
- with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log:
+ with patch("dify_trace_tencent.tencent_trace.logger.info", side_effect=Exception("error")):
+ with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log:
tencent_data_trace.suggested_question_trace(trace_info)
mock_log.assert_called_once_with("[Tencent APM] Failed to process suggested question trace")
@@ -342,7 +344,7 @@ class TestTencentDataTrace:
with patch.object(tencent_data_trace, "_get_workflow_node_executions", return_value=[node]):
with patch.object(tencent_data_trace, "_build_workflow_node_span", side_effect=Exception("node error")):
- with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log:
+ with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log:
tencent_data_trace._process_workflow_nodes(trace_info, 123)
# The exception should be caught by the outer handler since convert_to_span_id is called first
mock_log.assert_called_once_with("[Tencent APM] Failed to process workflow nodes")
@@ -351,7 +353,7 @@ class TestTencentDataTrace:
trace_info = MagicMock(spec=WorkflowTraceInfo)
mock_trace_utils.convert_to_span_id.side_effect = Exception("outer error")
- with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log:
+ with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log:
tencent_data_trace._process_workflow_nodes(trace_info, 123)
mock_log.assert_called_once_with("[Tencent APM] Failed to process workflow nodes")
@@ -381,7 +383,7 @@ class TestTencentDataTrace:
node.id = "n1"
mock_span_builder.build_workflow_llm_span.side_effect = Exception("error")
- with patch("core.ops.tencent_trace.tencent_trace.logger.debug") as mock_log:
+ with patch("dify_trace_tencent.tencent_trace.logger.debug") as mock_log:
result = tencent_data_trace._build_workflow_node_span(node, 123, MagicMock(), 456)
assert result is None
mock_log.assert_called_once()
@@ -403,15 +405,13 @@ class TestTencentDataTrace:
mock_executions = [MagicMock()]
- with patch("core.ops.tencent_trace.tencent_trace.db") as mock_db:
+ with patch("dify_trace_tencent.tencent_trace.db") as mock_db:
mock_db.engine = "engine"
- with patch("core.ops.tencent_trace.tencent_trace.Session") as mock_session_ctx:
+ with patch("dify_trace_tencent.tencent_trace.Session") as mock_session_ctx:
session = mock_session_ctx.return_value.__enter__.return_value
session.scalar.side_effect = [app, account, tenant_join]
- with patch(
- "core.ops.tencent_trace.tencent_trace.SQLAlchemyWorkflowNodeExecutionRepository"
- ) as mock_repo:
+ with patch("dify_trace_tencent.tencent_trace.SQLAlchemyWorkflowNodeExecutionRepository") as mock_repo:
mock_repo.return_value.get_by_workflow_execution.return_value = mock_executions
results = tencent_data_trace._get_workflow_node_executions(trace_info)
@@ -423,7 +423,7 @@ class TestTencentDataTrace:
trace_info = MagicMock(spec=WorkflowTraceInfo)
trace_info.metadata = {}
- with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log:
+ with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log:
results = tencent_data_trace._get_workflow_node_executions(trace_info)
assert results == []
mock_log.assert_called_once()
@@ -432,14 +432,14 @@ class TestTencentDataTrace:
trace_info = MagicMock(spec=WorkflowTraceInfo)
trace_info.metadata = {"app_id": "app-1"}
- with patch("core.ops.tencent_trace.tencent_trace.db") as mock_db:
+ with patch("dify_trace_tencent.tencent_trace.db") as mock_db:
mock_db.init_app = MagicMock() # Ensure init_app is mocked
mock_db.engine = "engine"
- with patch("core.ops.tencent_trace.tencent_trace.Session") as mock_session_ctx:
+ with patch("dify_trace_tencent.tencent_trace.Session") as mock_session_ctx:
session = mock_session_ctx.return_value.__enter__.return_value
session.scalar.return_value = None
- with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log:
+ with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log:
results = tencent_data_trace._get_workflow_node_executions(trace_info)
assert results == []
mock_log.assert_called_once()
@@ -449,8 +449,8 @@ class TestTencentDataTrace:
trace_info.tenant_id = "tenant-1"
trace_info.metadata = {"user_id": "user-1"}
- with patch("core.ops.tencent_trace.tencent_trace.sessionmaker", side_effect=Exception("Database error")):
- with patch("core.ops.tencent_trace.tencent_trace.db") as mock_db:
+ with patch("dify_trace_tencent.tencent_trace.sessionmaker", side_effect=Exception("Database error")):
+ with patch("dify_trace_tencent.tencent_trace.db") as mock_db:
mock_db.init_app = MagicMock()
mock_db.engine = MagicMock()
@@ -476,8 +476,8 @@ class TestTencentDataTrace:
trace_info.tenant_id = "t"
trace_info.metadata = {"user_id": "u"}
- with patch("core.ops.tencent_trace.tencent_trace.sessionmaker", side_effect=Exception("error")):
- with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log:
+ with patch("dify_trace_tencent.tencent_trace.sessionmaker", side_effect=Exception("error")):
+ with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log:
user_id = tencent_data_trace._get_user_id(trace_info)
assert user_id == "unknown"
mock_log.assert_called_once_with("[Tencent APM] Failed to get user ID")
@@ -519,7 +519,7 @@ class TestTencentDataTrace:
node.process_data = None
node.outputs = None
- with patch("core.ops.tencent_trace.tencent_trace.logger.debug") as mock_log:
+ with patch("dify_trace_tencent.tencent_trace.logger.debug") as mock_log:
tencent_data_trace._record_llm_metrics(node)
# Should not crash
@@ -557,7 +557,7 @@ class TestTencentDataTrace:
trace_info = MagicMock(spec=MessageTraceInfo)
trace_info.metadata = None
- with patch("core.ops.tencent_trace.tencent_trace.logger.debug") as mock_log:
+ with patch("dify_trace_tencent.tencent_trace.logger.debug") as mock_log:
tencent_data_trace._record_message_llm_metrics(trace_info)
# Should not crash
@@ -609,7 +609,7 @@ class TestTencentDataTrace:
trace_info = MagicMock(spec=WorkflowTraceInfo)
trace_info.start_time = MagicMock() # This might cause total_seconds() to fail if not mocked right
- with patch("core.ops.tencent_trace.tencent_trace.logger.debug") as mock_log:
+ with patch("dify_trace_tencent.tencent_trace.logger.debug") as mock_log:
tencent_data_trace._record_workflow_trace_duration(trace_info)
def test_record_message_trace_duration(self, tencent_data_trace):
@@ -631,16 +631,41 @@ class TestTencentDataTrace:
trace_info = MagicMock(spec=MessageTraceInfo)
trace_info.start_time = None
- with patch("core.ops.tencent_trace.tencent_trace.logger.debug") as mock_log:
+ with patch("dify_trace_tencent.tencent_trace.logger.debug") as mock_log:
tencent_data_trace._record_message_trace_duration(trace_info)
- def test_del(self, tencent_data_trace):
+ def test_close(self, tencent_data_trace):
client = tencent_data_trace.trace_client
- tencent_data_trace.__del__()
+ tencent_data_trace.close()
client.shutdown.assert_called_once()
- def test_del_exception(self, tencent_data_trace):
+ def test_close_is_idempotent(self, tencent_data_trace):
+ client = tencent_data_trace.trace_client
+
+ tencent_data_trace.close()
+ tencent_data_trace.close()
+
+ client.shutdown.assert_called_once()
+
+ def test_close_exception(self, tencent_data_trace):
tencent_data_trace.trace_client.shutdown.side_effect = Exception("error")
- with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log:
- tencent_data_trace.__del__()
+ with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log:
+ tencent_data_trace.close()
mock_log.assert_called_once_with("[Tencent APM] Failed to shutdown trace client during cleanup")
+
+ def test_close_handles_async_shutdown_mock(self, tencent_data_trace):
+ shutdown = AsyncMock()
+ tencent_data_trace.trace_client.shutdown = shutdown
+
+ with warnings.catch_warnings(record=True) as caught:
+ warnings.simplefilter("always")
+ tencent_data_trace.close()
+ gc.collect()
+
+ shutdown.assert_called_once()
+ assert not [
+ warning
+ for warning in caught
+ if issubclass(warning.category, RuntimeWarning)
+ and "AsyncMockMixin._execute_mock_call" in str(warning.message)
+ ]
diff --git a/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace_utils.py b/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_tencent_trace_utils.py
similarity index 88%
rename from api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace_utils.py
rename to api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_tencent_trace_utils.py
index ef28d18e20..63c6d680d7 100644
--- a/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace_utils.py
+++ b/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_tencent_trace_utils.py
@@ -8,10 +8,9 @@ from datetime import UTC, datetime
from unittest.mock import patch
import pytest
+from dify_trace_tencent.utils import TencentTraceUtils
from opentelemetry.trace import Link, TraceFlags
-from core.ops.tencent_trace.utils import TencentTraceUtils
-
def test_convert_to_trace_id_with_valid_uuid() -> None:
uuid_str = "12345678-1234-5678-1234-567812345678"
@@ -20,7 +19,7 @@ def test_convert_to_trace_id_with_valid_uuid() -> None:
def test_convert_to_trace_id_uses_uuid4_when_none() -> None:
expected_uuid = uuid.UUID("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa")
- with patch("core.ops.tencent_trace.utils.uuid.uuid4", return_value=expected_uuid) as uuid4_mock:
+ with patch("dify_trace_tencent.utils.uuid.uuid4", return_value=expected_uuid) as uuid4_mock:
assert TencentTraceUtils.convert_to_trace_id(None) == expected_uuid.int
uuid4_mock.assert_called_once()
@@ -45,7 +44,7 @@ def test_convert_to_span_id_is_deterministic_and_sensitive_to_type() -> None:
def test_convert_to_span_id_uses_uuid4_when_none() -> None:
expected_uuid = uuid.UUID("bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb")
- with patch("core.ops.tencent_trace.utils.uuid.uuid4", return_value=expected_uuid) as uuid4_mock:
+ with patch("dify_trace_tencent.utils.uuid.uuid4", return_value=expected_uuid) as uuid4_mock:
span_id = TencentTraceUtils.convert_to_span_id(None, "workflow")
assert isinstance(span_id, int)
uuid4_mock.assert_called_once()
@@ -58,7 +57,7 @@ def test_convert_to_span_id_raises_value_error_for_invalid_uuid() -> None:
def test_generate_span_id_skips_invalid_span_id() -> None:
with patch(
- "core.ops.tencent_trace.utils.random.getrandbits",
+ "dify_trace_tencent.utils.random.getrandbits",
side_effect=[TencentTraceUtils.INVALID_SPAN_ID, 42],
) as bits_mock:
assert TencentTraceUtils.generate_span_id() == 42
@@ -75,7 +74,7 @@ def test_convert_datetime_to_nanoseconds_uses_now_when_none() -> None:
fixed = datetime(2024, 1, 2, 3, 4, 5, tzinfo=UTC)
expected = int(fixed.timestamp() * 1e9)
- with patch("core.ops.tencent_trace.utils.datetime") as datetime_mock:
+ with patch("dify_trace_tencent.utils.datetime") as datetime_mock:
datetime_mock.now.return_value = fixed
assert TencentTraceUtils.convert_datetime_to_nanoseconds(None) == expected
datetime_mock.now.assert_called_once()
@@ -100,7 +99,7 @@ def test_create_link_accepts_hex_or_uuid(trace_id_str: str, expected_trace_id: i
@pytest.mark.parametrize("trace_id_str", ["g" * 32, "not-a-uuid", None])
def test_create_link_falls_back_to_uuid4(trace_id_str: object) -> None:
fallback_uuid = uuid.UUID("dddddddd-dddd-dddd-dddd-dddddddddddd")
- with patch("core.ops.tencent_trace.utils.uuid.uuid4", return_value=fallback_uuid) as uuid4_mock:
+ with patch("dify_trace_tencent.utils.uuid.uuid4", return_value=fallback_uuid) as uuid4_mock:
link = TencentTraceUtils.create_link(trace_id_str) # type: ignore[arg-type]
assert link.context.trace_id == fallback_uuid.int
uuid4_mock.assert_called_once()
diff --git a/api/providers/trace/trace-weave/pyproject.toml b/api/providers/trace/trace-weave/pyproject.toml
new file mode 100644
index 0000000000..ba449f2a93
--- /dev/null
+++ b/api/providers/trace/trace-weave/pyproject.toml
@@ -0,0 +1,10 @@
+[project]
+name = "dify-trace-weave"
+version = "0.0.1"
+dependencies = [
+ "weave>=0.52.36",
+]
+description = "Dify ops tracing provider (Weave)."
+
+[tool.setuptools.packages.find]
+where = ["src"]
diff --git a/api/providers/trace/trace-weave/src/dify_trace_weave/__init__.py b/api/providers/trace/trace-weave/src/dify_trace_weave/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/providers/trace/trace-weave/src/dify_trace_weave/config.py b/api/providers/trace/trace-weave/src/dify_trace_weave/config.py
new file mode 100644
index 0000000000..5942bd57fe
--- /dev/null
+++ b/api/providers/trace/trace-weave/src/dify_trace_weave/config.py
@@ -0,0 +1,29 @@
+from pydantic import ValidationInfo, field_validator
+
+from core.ops.entities.config_entity import BaseTracingConfig
+from core.ops.utils import validate_url
+
+
+class WeaveConfig(BaseTracingConfig):
+ """
+ Model class for Weave tracing config.
+ """
+
+ api_key: str
+ entity: str | None = None
+ project: str
+ endpoint: str = "https://trace.wandb.ai"
+ host: str | None = None
+
+ @field_validator("endpoint")
+ @classmethod
+ def endpoint_validator(cls, v, info: ValidationInfo):
+ # Weave only allows HTTPS for endpoint
+ return validate_url(v, "https://trace.wandb.ai", allowed_schemes=("https",))
+
+ @field_validator("host")
+ @classmethod
+ def host_validator(cls, v, info: ValidationInfo):
+ if v is not None and v.strip() != "":
+ return validate_url(v, v, allowed_schemes=("https", "http"))
+ return v
diff --git a/api/providers/trace/trace-weave/src/dify_trace_weave/entities/__init__.py b/api/providers/trace/trace-weave/src/dify_trace_weave/entities/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/ops/weave_trace/entities/weave_trace_entity.py b/api/providers/trace/trace-weave/src/dify_trace_weave/entities/weave_trace_entity.py
similarity index 100%
rename from api/core/ops/weave_trace/entities/weave_trace_entity.py
rename to api/providers/trace/trace-weave/src/dify_trace_weave/entities/weave_trace_entity.py
diff --git a/api/providers/trace/trace-weave/src/dify_trace_weave/py.typed b/api/providers/trace/trace-weave/src/dify_trace_weave/py.typed
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/ops/weave_trace/weave_trace.py b/api/providers/trace/trace-weave/src/dify_trace_weave/weave_trace.py
similarity index 99%
rename from api/core/ops/weave_trace/weave_trace.py
rename to api/providers/trace/trace-weave/src/dify_trace_weave/weave_trace.py
index f79544f1c7..4292cbf0f1 100644
--- a/api/core/ops/weave_trace/weave_trace.py
+++ b/api/providers/trace/trace-weave/src/dify_trace_weave/weave_trace.py
@@ -17,7 +17,6 @@ from weave.trace_server.trace_server_interface import (
)
from core.ops.base_trace_instance import BaseTraceInstance
-from core.ops.entities.config_entity import WeaveConfig
from core.ops.entities.trace_entity import (
BaseTraceInfo,
DatasetRetrievalTraceInfo,
@@ -29,8 +28,9 @@ from core.ops.entities.trace_entity import (
TraceTaskName,
WorkflowTraceInfo,
)
-from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel
from core.repositories import DifyCoreRepositoryFactory
+from dify_trace_weave.config import WeaveConfig
+from dify_trace_weave.entities.weave_trace_entity import WeaveTraceModel
from extensions.ext_database import db
from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey
from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom
diff --git a/api/providers/trace/trace-weave/tests/unit_tests/test_config_entity.py b/api/providers/trace/trace-weave/tests/unit_tests/test_config_entity.py
new file mode 100644
index 0000000000..eeb1fe1d87
--- /dev/null
+++ b/api/providers/trace/trace-weave/tests/unit_tests/test_config_entity.py
@@ -0,0 +1,61 @@
+import pytest
+from dify_trace_weave.config import WeaveConfig
+from pydantic import ValidationError
+
+
+class TestWeaveConfig:
+ """Test cases for WeaveConfig"""
+
+ def test_valid_config(self):
+ """Test valid Weave configuration"""
+ config = WeaveConfig(
+ api_key="test_key",
+ entity="test_entity",
+ project="test_project",
+ endpoint="https://custom.wandb.ai",
+ host="https://custom.host.com",
+ )
+ assert config.api_key == "test_key"
+ assert config.entity == "test_entity"
+ assert config.project == "test_project"
+ assert config.endpoint == "https://custom.wandb.ai"
+ assert config.host == "https://custom.host.com"
+
+ def test_default_values(self):
+ """Test default values are set correctly"""
+ config = WeaveConfig(api_key="key", project="project")
+ assert config.entity is None
+ assert config.endpoint == "https://trace.wandb.ai"
+ assert config.host is None
+
+ def test_missing_required_fields(self):
+ """Test that required fields are enforced"""
+ with pytest.raises(ValidationError):
+ WeaveConfig()
+
+ with pytest.raises(ValidationError):
+ WeaveConfig(api_key="key")
+
+ with pytest.raises(ValidationError):
+ WeaveConfig(project="project")
+
+ def test_endpoint_validation_https_only(self):
+ """Test endpoint validation only allows HTTPS"""
+ with pytest.raises(ValidationError, match="URL scheme must be one of"):
+ WeaveConfig(api_key="key", project="project", endpoint="http://insecure.wandb.ai")
+
+ def test_host_validation_optional(self):
+ """Test host validation is optional but validates when provided"""
+ config = WeaveConfig(api_key="key", project="project", host=None)
+ assert config.host is None
+
+ config = WeaveConfig(api_key="key", project="project", host="")
+ assert config.host == ""
+
+ config = WeaveConfig(api_key="key", project="project", host="https://valid.host.com")
+ assert config.host == "https://valid.host.com"
+
+ def test_host_validation_invalid_scheme(self):
+ """Test host validation rejects invalid schemes when provided"""
+ with pytest.raises(ValidationError, match="URL scheme must be one of"):
+ WeaveConfig(api_key="key", project="project", host="ftp://invalid.host.com")
diff --git a/api/tests/unit_tests/core/ops/weave_trace/test_weave_trace.py b/api/providers/trace/trace-weave/tests/unit_tests/weave_trace/test_weave_trace.py
similarity index 97%
rename from api/tests/unit_tests/core/ops/weave_trace/test_weave_trace.py
rename to api/providers/trace/trace-weave/tests/unit_tests/weave_trace/test_weave_trace.py
index 531c7de05f..6028d0c550 100644
--- a/api/tests/unit_tests/core/ops/weave_trace/test_weave_trace.py
+++ b/api/providers/trace/trace-weave/tests/unit_tests/weave_trace/test_weave_trace.py
@@ -1,4 +1,4 @@
-"""Comprehensive tests for core.ops.weave_trace.weave_trace module."""
+"""Comprehensive tests for dify_trace_weave.weave_trace module."""
from __future__ import annotations
@@ -7,9 +7,11 @@ from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
+from dify_trace_weave.config import WeaveConfig
+from dify_trace_weave.entities.weave_trace_entity import WeaveTraceModel
+from dify_trace_weave.weave_trace import WeaveDataTrace
from weave.trace_server.trace_server_interface import TraceStatus
-from core.ops.entities.config_entity import WeaveConfig
from core.ops.entities.trace_entity import (
DatasetRetrievalTraceInfo,
GenerateNameTraceInfo,
@@ -20,8 +22,6 @@ from core.ops.entities.trace_entity import (
TraceTaskName,
WorkflowTraceInfo,
)
-from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel
-from core.ops.weave_trace.weave_trace import WeaveDataTrace
from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey
# ── Helpers ──────────────────────────────────────────────────────────────────
@@ -191,14 +191,14 @@ def _make_node(**overrides):
@pytest.fixture
def mock_wandb():
- with patch("core.ops.weave_trace.weave_trace.wandb") as mock:
+ with patch("dify_trace_weave.weave_trace.wandb") as mock:
mock.login.return_value = True
yield mock
@pytest.fixture
def mock_weave():
- with patch("core.ops.weave_trace.weave_trace.weave") as mock:
+ with patch("dify_trace_weave.weave_trace.weave") as mock:
client = MagicMock()
client.entity = "my-entity"
client.project = "my-project"
@@ -307,7 +307,7 @@ class TestGetProjectUrl:
monkeypatch.setattr(trace_instance, "entity", None)
monkeypatch.setattr(trace_instance, "project_name", None)
# Force an error by making string formatting fail
- with patch("core.ops.weave_trace.weave_trace.logger") as mock_logger:
+ with patch("dify_trace_weave.weave_trace.logger") as mock_logger:
# Simulate exception via property
original_entity = trace_instance.entity
trace_instance.entity = None
@@ -594,9 +594,9 @@ class TestWorkflowTrace:
mock_factory = MagicMock()
mock_factory.create_workflow_node_execution_repository.return_value = repo
- monkeypatch.setattr("core.ops.weave_trace.weave_trace.DifyCoreRepositoryFactory", mock_factory)
- monkeypatch.setattr("core.ops.weave_trace.weave_trace.sessionmaker", lambda bind: MagicMock())
- monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", MagicMock(engine="engine"))
+ monkeypatch.setattr("dify_trace_weave.weave_trace.DifyCoreRepositoryFactory", mock_factory)
+ monkeypatch.setattr("dify_trace_weave.weave_trace.sessionmaker", lambda bind: MagicMock())
+ monkeypatch.setattr("dify_trace_weave.weave_trace.db", MagicMock(engine="engine"))
return repo
def test_workflow_trace_no_nodes_no_message_id(self, trace_instance, monkeypatch):
@@ -703,8 +703,8 @@ class TestWorkflowTrace:
def test_workflow_trace_missing_app_id_raises(self, trace_instance, monkeypatch):
"""Raises ValueError when app_id is missing from metadata."""
- monkeypatch.setattr("core.ops.weave_trace.weave_trace.sessionmaker", lambda bind: MagicMock())
- monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", MagicMock(engine="engine"))
+ monkeypatch.setattr("dify_trace_weave.weave_trace.sessionmaker", lambda bind: MagicMock())
+ monkeypatch.setattr("dify_trace_weave.weave_trace.db", MagicMock(engine="engine"))
trace_info = _make_workflow_trace_info(
message_id=None,
@@ -802,7 +802,7 @@ class TestMessageTrace:
def test_basic_message_trace(self, trace_instance, monkeypatch):
"""message_trace creates message run and llm child run."""
monkeypatch.setattr(
- "core.ops.weave_trace.weave_trace.db.session.get",
+ "dify_trace_weave.weave_trace.db.session.get",
lambda model, pk: None,
)
@@ -824,7 +824,7 @@ class TestMessageTrace:
mock_db = MagicMock()
mock_db.session.get.return_value = None
- monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", mock_db)
+ monkeypatch.setattr("dify_trace_weave.weave_trace.db", mock_db)
trace_instance.start_call = MagicMock()
trace_instance.finish_call = MagicMock()
@@ -846,7 +846,7 @@ class TestMessageTrace:
mock_db = MagicMock()
mock_db.session.get.return_value = end_user
- monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", mock_db)
+ monkeypatch.setattr("dify_trace_weave.weave_trace.db", mock_db)
trace_instance.start_call = MagicMock()
trace_instance.finish_call = MagicMock()
@@ -866,7 +866,7 @@ class TestMessageTrace:
"""message_trace handles when from_end_user_id is None."""
mock_db = MagicMock()
mock_db.session.get.return_value = None
- monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", mock_db)
+ monkeypatch.setattr("dify_trace_weave.weave_trace.db", mock_db)
trace_instance.start_call = MagicMock()
trace_instance.finish_call = MagicMock()
@@ -884,7 +884,7 @@ class TestMessageTrace:
"""trace_id falls back to message_id when trace_id is None."""
mock_db = MagicMock()
mock_db.session.get.return_value = None
- monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", mock_db)
+ monkeypatch.setattr("dify_trace_weave.weave_trace.db", mock_db)
trace_instance.start_call = MagicMock()
trace_instance.finish_call = MagicMock()
@@ -899,7 +899,7 @@ class TestMessageTrace:
"""message_trace handles file_list=None gracefully."""
mock_db = MagicMock()
mock_db.session.get.return_value = None
- monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", mock_db)
+ monkeypatch.setattr("dify_trace_weave.weave_trace.db", mock_db)
trace_instance.start_call = MagicMock()
trace_instance.finish_call = MagicMock()
diff --git a/api/providers/vdb/vdb-analyticdb/src/dify_vdb_analyticdb/analyticdb_vector_sql.py b/api/providers/vdb/vdb-analyticdb/src/dify_vdb_analyticdb/analyticdb_vector_sql.py
index b2908ebdae..11398efb58 100644
--- a/api/providers/vdb/vdb-analyticdb/src/dify_vdb_analyticdb/analyticdb_vector_sql.py
+++ b/api/providers/vdb/vdb-analyticdb/src/dify_vdb_analyticdb/analyticdb_vector_sql.py
@@ -1,6 +1,6 @@
import json
import uuid
-from collections.abc import Iterator
+from collections.abc import Generator
from contextlib import contextmanager
from typing import Any
@@ -75,7 +75,7 @@ class AnalyticdbVectorBySql:
)
@contextmanager
- def _get_cursor(self) -> Iterator[Any]:
+ def _get_cursor(self) -> Generator[Any, None, None]:
assert self.pool is not None, "Connection pool is not initialized"
conn = self.pool.getconn()
cur = conn.cursor()
diff --git a/api/pyproject.toml b/api/pyproject.toml
index a1ceea181e..b13c744f0a 100644
--- a/api/pyproject.toml
+++ b/api/pyproject.toml
@@ -6,7 +6,7 @@ requires-python = "~=3.12.0"
dependencies = [
# Legacy: mature and widely deployed
"bleach>=6.3.0",
- "boto3>=1.42.88",
+ "boto3>=1.42.91",
"celery>=5.6.3",
"croniter>=6.2.2",
"flask-cors>=6.0.2",
@@ -30,11 +30,8 @@ dependencies = [
"flask-migrate>=4.1.0,<5.0.0",
"flask-orjson>=2.0.0,<3.0.0",
"flask-restx>=1.3.2,<2.0.0",
- "google-cloud-aiplatform>=1.147.0,<2.0.0",
+ "google-cloud-aiplatform>=1.148.1,<2.0.0",
"httpx[socks]>=0.28.1,<1.0.0",
- "langfuse>=4.2.0,<5.0.0",
- "langsmith>=0.7.31,<1.0.0",
- "mlflow-skinny>=3.11.1,<4.0.0",
"opentelemetry-distro>=0.62b0,<1.0.0",
"opentelemetry-instrumentation-celery>=0.62b0,<1.0.0",
"opentelemetry-instrumentation-flask>=0.62b0,<1.0.0",
@@ -44,15 +41,12 @@ dependencies = [
"opentelemetry-propagator-b3>=1.41.0,<2.0.0",
"readabilipy>=0.3.0,<1.0.0",
"resend>=2.27.0,<3.0.0",
- "weave>=0.52.36,<1.0.0",
# Emerging: newer and fast-moving, use compatible pins
- "arize-phoenix-otel~=0.15.0",
"fastopenapi[flask]~=0.7.0",
- "graphon~=0.1.2",
+ "graphon~=0.2.2",
"httpx-sse~=0.4.0",
- "json-repair~=0.59.2",
- "opik~=1.11.2",
+ "json-repair~=0.59.4",
]
# Before adding new dependency, consider place it in
# alphabet order (a-z) and suitable group.
@@ -61,8 +55,8 @@ dependencies = [
packages = []
[tool.uv.workspace]
-members = ["providers/vdb/*"]
-exclude = ["providers/vdb/__pycache__"]
+members = ["providers/vdb/*", "providers/trace/*"]
+exclude = ["providers/vdb/__pycache__", "providers/trace/__pycache__"]
[tool.uv.sources]
dify-vdb-alibabacloud-mysql = { workspace = true }
@@ -95,9 +89,17 @@ dify-vdb-upstash = { workspace = true }
dify-vdb-vastbase = { workspace = true }
dify-vdb-vikingdb = { workspace = true }
dify-vdb-weaviate = { workspace = true }
+dify-trace-aliyun = { workspace = true }
+dify-trace-arize-phoenix = { workspace = true }
+dify-trace-langfuse = { workspace = true }
+dify-trace-langsmith = { workspace = true }
+dify-trace-mlflow = { workspace = true }
+dify-trace-opik = { workspace = true }
+dify-trace-tencent = { workspace = true }
+dify-trace-weave = { workspace = true }
[tool.uv]
-default-groups = ["storage", "tools", "vdb-all"]
+default-groups = ["storage", "tools", "vdb-all", "trace-all"]
package = false
override-dependencies = [
"pyarrow>=18.0.0",
@@ -171,8 +173,8 @@ dev = [
# "locust>=2.40.4", # Temporarily removed due to compatibility issues. Uncomment when resolved.
"pytest-timeout>=2.4.0",
"pytest-xdist>=3.8.0",
- "pyrefly>=0.60.0",
- "xinference-client>=2.4.0",
+ "pyrefly>=0.61.1",
+ "xinference-client>=2.5.0",
]
############################################################
@@ -181,13 +183,13 @@ dev = [
############################################################
storage = [
"azure-storage-blob>=12.28.0",
- "bce-python-sdk>=0.9.69",
+ "bce-python-sdk>=0.9.70",
"cos-python-sdk-v5>=1.9.41",
"esdk-obs-python>=3.22.2",
"google-cloud-storage>=3.10.1",
"opendal>=0.46.0",
"oss2>=2.19.1",
- "supabase>=2.18.1",
+ "supabase>=2.28.3",
"tos>=2.9.0",
]
@@ -264,7 +266,26 @@ vdb-vastbase = ["dify-vdb-vastbase"]
vdb-vikingdb = ["dify-vdb-vikingdb"]
vdb-weaviate = ["dify-vdb-weaviate"]
# Optional client used by some tests / integrations (not a vector backend plugin)
-vdb-xinference = ["xinference-client>=2.4.0"]
+vdb-xinference = ["xinference-client>=2.5.0"]
+
+trace-all = [
+ "dify-trace-aliyun",
+ "dify-trace-arize-phoenix",
+ "dify-trace-langfuse",
+ "dify-trace-langsmith",
+ "dify-trace-mlflow",
+ "dify-trace-opik",
+ "dify-trace-tencent",
+ "dify-trace-weave",
+]
+trace-aliyun = ["dify-trace-aliyun"]
+trace-arize-phoenix = ["dify-trace-arize-phoenix"]
+trace-langfuse = ["dify-trace-langfuse"]
+trace-langsmith = ["dify-trace-langsmith"]
+trace-mlflow = ["dify-trace-mlflow"]
+trace-opik = ["dify-trace-opik"]
+trace-tencent = ["dify-trace-tencent"]
+trace-weave = ["dify-trace-weave"]
[tool.pyrefly]
project-includes = ["."]
diff --git a/api/pyrefly-local-excludes.txt b/api/pyrefly-local-excludes.txt
index 3e5ece1fcf..fbbca24558 100644
--- a/api/pyrefly-local-excludes.txt
+++ b/api/pyrefly-local-excludes.txt
@@ -34,12 +34,12 @@ core/external_data_tool/api/api.py
core/llm_generator/llm_generator.py
core/llm_generator/output_parser/structured_output.py
core/mcp/mcp_client.py
-core/ops/aliyun_trace/data_exporter/traceclient.py
-core/ops/arize_phoenix_trace/arize_phoenix_trace.py
-core/ops/mlflow_trace/mlflow_trace.py
+providers/trace/trace-aliyun/src/dify_trace_aliyun/data_exporter/traceclient.py
+providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/arize_phoenix_trace.py
+providers/trace/trace-mlflow/src/dify_trace_mlflow/mlflow_trace.py
core/ops/ops_trace_manager.py
-core/ops/tencent_trace/client.py
-core/ops/tencent_trace/utils.py
+providers/trace/trace-tencent/src/dify_trace_tencent/client.py
+providers/trace/trace-tencent/src/dify_trace_tencent/utils.py
core/plugin/backwards_invocation/base.py
core/plugin/backwards_invocation/model.py
core/prompt/utils/extract_thread_messages.py
diff --git a/api/pyrightconfig.json b/api/pyrightconfig.json
index c4582e891d..ac0e2a3a53 100644
--- a/api/pyrightconfig.json
+++ b/api/pyrightconfig.json
@@ -5,7 +5,8 @@
".venv",
"migrations/",
"core/rag",
- "providers/",
+ "providers/vdb/",
+ "providers/trace/*/tests",
],
"typeCheckingMode": "strict",
"allowedUntypedLibraries": [
diff --git a/api/schedule/mail_clean_document_notify_task.py b/api/schedule/mail_clean_document_notify_task.py
index 8479cdfb0c..2cc0192a4a 100644
--- a/api/schedule/mail_clean_document_notify_task.py
+++ b/api/schedule/mail_clean_document_notify_task.py
@@ -7,8 +7,8 @@ from sqlalchemy import select
import app
from configs import dify_config
+from core.db.session_factory import session_factory
from enums.cloud_plan import CloudPlan
-from extensions.ext_database import db
from extensions.ext_mail import mail
from libs.email_i18n import EmailType, get_email_i18n_service
from models import Account, Tenant, TenantAccountJoin
@@ -33,67 +33,68 @@ def mail_clean_document_notify_task():
# send document clean notify mail
try:
- dataset_auto_disable_logs = db.session.scalars(
- select(DatasetAutoDisableLog).where(DatasetAutoDisableLog.notified == False)
- ).all()
- # group by tenant_id
- dataset_auto_disable_logs_map: dict[str, list[DatasetAutoDisableLog]] = defaultdict(list)
- for dataset_auto_disable_log in dataset_auto_disable_logs:
- if dataset_auto_disable_log.tenant_id not in dataset_auto_disable_logs_map:
- dataset_auto_disable_logs_map[dataset_auto_disable_log.tenant_id] = []
- dataset_auto_disable_logs_map[dataset_auto_disable_log.tenant_id].append(dataset_auto_disable_log)
- url = f"{dify_config.CONSOLE_WEB_URL}/datasets"
- for tenant_id, tenant_dataset_auto_disable_logs in dataset_auto_disable_logs_map.items():
- features = FeatureService.get_features(tenant_id)
- plan = features.billing.subscription.plan
- if plan != CloudPlan.SANDBOX:
- knowledge_details = []
- # check tenant
- tenant = db.session.scalar(select(Tenant).where(Tenant.id == tenant_id))
- if not tenant:
- continue
- # check current owner
- current_owner_join = db.session.scalar(
- select(TenantAccountJoin)
- .where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.role == "owner")
- .limit(1)
- )
- if not current_owner_join:
- continue
- account = db.session.scalar(select(Account).where(Account.id == current_owner_join.account_id))
- if not account:
- continue
+ with session_factory.create_session() as session:
+ dataset_auto_disable_logs = session.scalars(
+ select(DatasetAutoDisableLog).where(DatasetAutoDisableLog.notified.is_(False))
+ ).all()
+ # group by tenant_id
+ dataset_auto_disable_logs_map: dict[str, list[DatasetAutoDisableLog]] = defaultdict(list)
+ for dataset_auto_disable_log in dataset_auto_disable_logs:
+ if dataset_auto_disable_log.tenant_id not in dataset_auto_disable_logs_map:
+ dataset_auto_disable_logs_map[dataset_auto_disable_log.tenant_id] = []
+ dataset_auto_disable_logs_map[dataset_auto_disable_log.tenant_id].append(dataset_auto_disable_log)
+ url = f"{dify_config.CONSOLE_WEB_URL}/datasets"
+ for tenant_id, tenant_dataset_auto_disable_logs in dataset_auto_disable_logs_map.items():
+ features = FeatureService.get_features(tenant_id)
+ plan = features.billing.subscription.plan
+ if plan != CloudPlan.SANDBOX:
+ knowledge_details = []
+ # check tenant
+ tenant = session.scalar(select(Tenant).where(Tenant.id == tenant_id))
+ if not tenant:
+ continue
+ # check current owner
+ current_owner_join = session.scalar(
+ select(TenantAccountJoin)
+ .where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.role == "owner")
+ .limit(1)
+ )
+ if not current_owner_join:
+ continue
+ account = session.scalar(select(Account).where(Account.id == current_owner_join.account_id))
+ if not account:
+ continue
- dataset_auto_dataset_map = {} # type: ignore
+ dataset_auto_dataset_map = {} # type: ignore
+ for dataset_auto_disable_log in tenant_dataset_auto_disable_logs:
+ if dataset_auto_disable_log.dataset_id not in dataset_auto_dataset_map:
+ dataset_auto_dataset_map[dataset_auto_disable_log.dataset_id] = []
+ dataset_auto_dataset_map[dataset_auto_disable_log.dataset_id].append(
+ dataset_auto_disable_log.document_id
+ )
+
+ for dataset_id, document_ids in dataset_auto_dataset_map.items():
+ dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id))
+ if dataset:
+ document_count = len(document_ids)
+ knowledge_details.append(rf"Knowledge base {dataset.name}: {document_count} documents")
+ if knowledge_details:
+ email_service = get_email_i18n_service()
+ email_service.send_email(
+ email_type=EmailType.DOCUMENT_CLEAN_NOTIFY,
+ language_code="en-US",
+ to=account.email,
+ template_context={
+ "userName": account.email,
+ "knowledge_details": knowledge_details,
+ "url": url,
+ },
+ )
+
+ # update notified to True
for dataset_auto_disable_log in tenant_dataset_auto_disable_logs:
- if dataset_auto_disable_log.dataset_id not in dataset_auto_dataset_map:
- dataset_auto_dataset_map[dataset_auto_disable_log.dataset_id] = []
- dataset_auto_dataset_map[dataset_auto_disable_log.dataset_id].append(
- dataset_auto_disable_log.document_id
- )
-
- for dataset_id, document_ids in dataset_auto_dataset_map.items():
- dataset = db.session.scalar(select(Dataset).where(Dataset.id == dataset_id))
- if dataset:
- document_count = len(document_ids)
- knowledge_details.append(rf"Knowledge base {dataset.name}: {document_count} documents")
- if knowledge_details:
- email_service = get_email_i18n_service()
- email_service.send_email(
- email_type=EmailType.DOCUMENT_CLEAN_NOTIFY,
- language_code="en-US",
- to=account.email,
- template_context={
- "userName": account.email,
- "knowledge_details": knowledge_details,
- "url": url,
- },
- )
-
- # update notified to True
- for dataset_auto_disable_log in tenant_dataset_auto_disable_logs:
- dataset_auto_disable_log.notified = True
- db.session.commit()
+ dataset_auto_disable_log.notified = True
+ session.commit()
end_at = time.perf_counter()
logger.info(click.style(f"Send document clean notify mail succeeded: latency: {end_at - start_at}", fg="green"))
except Exception:
diff --git a/api/services/account_service.py b/api/services/account_service.py
index ccc4a7c1fa..b6554a3de7 100644
--- a/api/services/account_service.py
+++ b/api/services/account_service.py
@@ -112,6 +112,14 @@ REFRESH_TOKEN_EXPIRY = timedelta(days=dify_config.REFRESH_TOKEN_EXPIRE_DAYS)
class AccountService:
+ # Phase-bound token metadata for the change-email flow. Tokens carry the
+ # current phase so that downstream endpoints can enforce proper progression
+ CHANGE_EMAIL_TOKEN_PHASE_KEY = "email_change_phase"
+ CHANGE_EMAIL_PHASE_OLD = "old_email"
+ CHANGE_EMAIL_PHASE_OLD_VERIFIED = "old_email_verified"
+ CHANGE_EMAIL_PHASE_NEW = "new_email"
+ CHANGE_EMAIL_PHASE_NEW_VERIFIED = "new_email_verified"
+
reset_password_rate_limiter = RateLimiter(prefix="reset_password_rate_limit", max_attempts=1, time_window=60 * 1)
email_register_rate_limiter = RateLimiter(prefix="email_register_rate_limit", max_attempts=1, time_window=60 * 1)
email_code_login_rate_limiter = RateLimiter(
@@ -576,13 +584,20 @@ class AccountService:
raise ValueError("Email must be provided.")
if not phase:
raise ValueError("phase must be provided.")
+ if phase not in (cls.CHANGE_EMAIL_PHASE_OLD, cls.CHANGE_EMAIL_PHASE_NEW):
+ raise ValueError("phase must be one of old_email or new_email.")
if cls.change_email_rate_limiter.is_rate_limited(account_email):
from controllers.console.auth.error import EmailChangeRateLimitExceededError
raise EmailChangeRateLimitExceededError(int(cls.change_email_rate_limiter.time_window / 60))
- code, token = cls.generate_change_email_token(account_email, account, old_email=old_email)
+ code, token = cls.generate_change_email_token(
+ account_email,
+ account,
+ old_email=old_email,
+ additional_data={cls.CHANGE_EMAIL_TOKEN_PHASE_KEY: phase},
+ )
send_change_mail_task.delay(
language=language,
diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py
index 78806927bc..97aaea3395 100644
--- a/api/services/app_dsl_service.py
+++ b/api/services/app_dsl_service.py
@@ -17,6 +17,7 @@ from sqlalchemy import select
from sqlalchemy.orm import Session
from configs import dify_config
+from constants.dsl_version import CURRENT_APP_DSL_VERSION
from core.helper import ssrf_proxy
from core.plugin.entities.plugin import PluginDependency
from core.trigger.constants import (
@@ -50,7 +51,7 @@ IMPORT_INFO_REDIS_KEY_PREFIX = "app_import_info:"
CHECK_DEPENDENCIES_REDIS_KEY_PREFIX = "app_check_dependencies:"
IMPORT_INFO_REDIS_EXPIRY = 10 * 60 # 10 minutes
DSL_MAX_SIZE = 10 * 1024 * 1024 # 10MB
-CURRENT_DSL_VERSION = "0.6.0"
+CURRENT_DSL_VERSION = CURRENT_APP_DSL_VERSION
class Import(BaseModel):
diff --git a/api/services/app_service.py b/api/services/app_service.py
index afd98e2975..a046b909b3 100644
--- a/api/services/app_service.py
+++ b/api/services/app_service.py
@@ -16,7 +16,7 @@ from core.tools.utils.configuration import ToolParameterConfigurationManager
from events.app_event import app_was_created, app_was_deleted, app_was_updated
from extensions.ext_database import db
from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
-from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
from libs.datetime_utils import naive_utc_now
from libs.login import current_user
from models import Account
@@ -303,17 +303,22 @@ class AppService:
return app
- def update_app_icon(self, app: App, icon: str, icon_background: str) -> App:
+ def update_app_icon(
+ self, app: App, icon: str, icon_background: str, icon_type: IconType | str | None = None
+ ) -> App:
"""
Update app icon
:param app: App instance
:param icon: new icon
:param icon_background: new icon_background
+ :param icon_type: new icon type
:return: App instance
"""
assert current_user is not None
app.icon = icon
app.icon_background = icon_background
+ if icon_type is not None:
+ app.icon_type = icon_type if isinstance(icon_type, IconType) else IconType(icon_type)
app.updated_by = current_user.id
app.updated_at = naive_utc_now()
db.session.commit()
diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py
index e6f5f80a6d..894cb05687 100644
--- a/api/services/dataset_service.py
+++ b/api/services/dataset_service.py
@@ -30,7 +30,7 @@ from extensions.ext_database import db
from extensions.ext_redis import redis_client
from graphon.file import helpers as file_helpers
from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType
-from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
+from graphon.model_runtime.model_providers.base.text_embedding_model import TextEmbeddingModel
from libs import helper
from libs.datetime_utils import naive_utc_now
from libs.login import current_user
diff --git a/api/services/feature_service.py b/api/services/feature_service.py
index d0d3fbd66b..e18eb096c9 100644
--- a/api/services/feature_service.py
+++ b/api/services/feature_service.py
@@ -3,6 +3,7 @@ from enum import StrEnum
from pydantic import BaseModel, ConfigDict, Field
from configs import dify_config
+from constants.dsl_version import CURRENT_APP_DSL_VERSION
from enums.cloud_plan import CloudPlan
from enums.hosted_provider import HostedTrialProvider
from services.billing_service import BillingService
@@ -157,6 +158,7 @@ class PluginManagerModel(BaseModel):
class SystemFeatureModel(BaseModel):
+ app_dsl_version: str = ""
sso_enforced_for_signin: bool = False
sso_enforced_for_signin_protocol: str = ""
enable_marketplace: bool = False
@@ -225,6 +227,7 @@ class FeatureService:
@classmethod
def get_system_features(cls, is_authenticated: bool = False) -> SystemFeatureModel:
system_features = SystemFeatureModel()
+ system_features.app_dsl_version = CURRENT_APP_DSL_VERSION
cls._fulfill_system_params_from_env(system_features)
diff --git a/api/services/file_service.py b/api/services/file_service.py
index 52da2a7951..f60afe2f19 100644
--- a/api/services/file_service.py
+++ b/api/services/file_service.py
@@ -2,7 +2,7 @@ import base64
import hashlib
import os
import uuid
-from collections.abc import Iterator, Sequence
+from collections.abc import Generator, Sequence
from contextlib import contextmanager, suppress
from tempfile import NamedTemporaryFile
from typing import Literal
@@ -324,7 +324,7 @@ class FileService:
def build_upload_files_zip_tempfile(
*,
upload_files: Sequence[UploadFile],
- ) -> Iterator[str]:
+ ) -> Generator[str, None, None]:
"""
Build a ZIP from `UploadFile`s and yield a tempfile path.
diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py
index ca84b2a3d8..2e5987dd28 100644
--- a/api/services/hit_testing_service.py
+++ b/api/services/hit_testing_service.py
@@ -1,10 +1,10 @@
import json
import logging
import time
-from typing import Any, TypedDict
+from typing import Any, TypedDict, cast
from core.app.app_config.entities import ModelConfig
-from core.rag.datasource.retrieval_service import RetrievalService
+from core.rag.datasource.retrieval_service import DefaultRetrievalModelDict, RetrievalService
from core.rag.index_processor.constant.query_type import QueryType
from core.rag.models.document import Document
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
@@ -36,6 +36,10 @@ default_retrieval_model = {
}
+class HitTestingRetrievalModelDict(DefaultRetrievalModelDict, total=False):
+ metadata_filtering_conditions: dict[str, Any]
+
+
class HitTestingService:
@classmethod
def retrieve(
@@ -51,17 +55,18 @@ class HitTestingService:
start = time.perf_counter()
# get retrieval model , if the model is not setting , using default
- if not retrieval_model:
- retrieval_model = dataset.retrieval_model or default_retrieval_model
- assert isinstance(retrieval_model, dict)
+ resolved_retrieval_model = cast(
+ HitTestingRetrievalModelDict,
+ retrieval_model or dataset.retrieval_model or default_retrieval_model,
+ )
document_ids_filter = None
- metadata_filtering_conditions = retrieval_model.get("metadata_filtering_conditions", {})
- if metadata_filtering_conditions and query:
+ metadata_filtering_conditions_raw = resolved_retrieval_model.get("metadata_filtering_conditions", {})
+ if metadata_filtering_conditions_raw and query:
dataset_retrieval = DatasetRetrieval()
from core.rag.entities import MetadataFilteringCondition
- metadata_filtering_conditions = MetadataFilteringCondition.model_validate(metadata_filtering_conditions)
+ metadata_filtering_conditions = MetadataFilteringCondition.model_validate(metadata_filtering_conditions_raw)
metadata_filter_document_ids, metadata_condition = dataset_retrieval.get_metadata_filter_condition(
dataset_ids=[dataset.id],
@@ -78,19 +83,21 @@ class HitTestingService:
if metadata_condition and not document_ids_filter:
return cls.compact_retrieve_response(query, [])
all_documents = RetrievalService.retrieve(
- retrieval_method=RetrievalMethod(retrieval_model.get("search_method", RetrievalMethod.SEMANTIC_SEARCH)),
+ retrieval_method=RetrievalMethod(
+ resolved_retrieval_model.get("search_method", RetrievalMethod.SEMANTIC_SEARCH)
+ ),
dataset_id=dataset.id,
query=query,
attachment_ids=attachment_ids,
- top_k=retrieval_model.get("top_k", 4),
- score_threshold=retrieval_model.get("score_threshold", 0.0)
- if retrieval_model["score_threshold_enabled"]
+ top_k=resolved_retrieval_model.get("top_k", 4),
+ score_threshold=resolved_retrieval_model.get("score_threshold", 0.0)
+ if resolved_retrieval_model["score_threshold_enabled"]
else 0.0,
- reranking_model=retrieval_model.get("reranking_model", None)
- if retrieval_model["reranking_enable"]
+ reranking_model=resolved_retrieval_model.get("reranking_model", None)
+ if resolved_retrieval_model["reranking_enable"]
else None,
- reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
- weights=retrieval_model.get("weights", None),
+ reranking_mode=resolved_retrieval_model.get("reranking_mode") or "reranking_model",
+ weights=resolved_retrieval_model.get("weights", None),
document_ids_filter=document_ids_filter,
)
diff --git a/api/services/human_input_delivery_test_service.py b/api/services/human_input_delivery_test_service.py
index 68ef67dec1..8b4983e5f7 100644
--- a/api/services/human_input_delivery_test_service.py
+++ b/api/services/human_input_delivery_test_service.py
@@ -8,7 +8,7 @@ from sqlalchemy import Engine, select
from sqlalchemy.orm import sessionmaker
from configs import dify_config
-from core.workflow.human_input_compat import (
+from core.workflow.human_input_adapter import (
DeliveryChannelConfig,
EmailDeliveryConfig,
EmailDeliveryMethod,
diff --git a/api/services/rag_pipeline/pipeline_template/built_in/built_in_retrieval.py b/api/services/rag_pipeline/pipeline_template/built_in/built_in_retrieval.py
index aa7456dcd3..8c9a81af87 100644
--- a/api/services/rag_pipeline/pipeline_template/built_in/built_in_retrieval.py
+++ b/api/services/rag_pipeline/pipeline_template/built_in/built_in_retrieval.py
@@ -50,7 +50,7 @@ class BuiltInPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
:param language: language
:return:
"""
- builtin_data: dict[str, dict[str, dict]] = cls._get_builtin_data()
+ builtin_data: dict[str, dict[str, dict[str, Any]]] = cls._get_builtin_data()
return builtin_data.get("pipeline_templates", {}).get(language, {})
@classmethod
@@ -60,5 +60,5 @@ class BuiltInPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
:param template_id: Template ID
:return:
"""
- builtin_data: dict[str, dict[str, dict]] = cls._get_builtin_data()
+ builtin_data: dict[str, dict[str, dict[str, Any]]] = cls._get_builtin_data()
return builtin_data.get("pipeline_templates", {}).get(template_id)
diff --git a/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py b/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py
index 0ffbef8365..9d446f6d4b 100644
--- a/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py
+++ b/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py
@@ -1,4 +1,4 @@
-from typing import Any
+from typing import Any, TypedDict
import yaml
from sqlalchemy import select
@@ -10,6 +10,30 @@ from services.rag_pipeline.pipeline_template.pipeline_template_base import Pipel
from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType
+class CustomizedTemplateItemDict(TypedDict):
+ id: str
+ name: str
+ description: str
+ icon: dict[str, Any]
+ position: int
+ chunk_structure: str
+
+
+class CustomizedTemplatesResultDict(TypedDict):
+ pipeline_templates: list[CustomizedTemplateItemDict]
+
+
+class CustomizedTemplateDetailDict(TypedDict):
+ id: str
+ name: str
+ icon_info: dict[str, Any]
+ description: str
+ chunk_structure: str
+ export_data: str
+ graph: dict[str, Any]
+ created_by: str
+
+
class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
"""
Retrieval recommended app from database
@@ -17,12 +41,10 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
def get_pipeline_templates(self, language: str) -> dict[str, Any]:
_, current_tenant_id = current_account_with_tenant()
- result = self.fetch_pipeline_templates_from_customized(tenant_id=current_tenant_id, language=language)
- return result
+ return self.fetch_pipeline_templates_from_customized(tenant_id=current_tenant_id, language=language)
def get_pipeline_template_detail(self, template_id: str) -> dict[str, Any] | None:
- result = self.fetch_pipeline_template_detail_from_db(template_id)
- return result
+ return self.fetch_pipeline_template_detail_from_db(template_id)
def get_type(self) -> str:
return PipelineTemplateType.CUSTOMIZED
@@ -40,9 +62,9 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
.where(PipelineCustomizedTemplate.tenant_id == tenant_id, PipelineCustomizedTemplate.language == language)
.order_by(PipelineCustomizedTemplate.position.asc(), PipelineCustomizedTemplate.created_at.desc())
).all()
- recommended_pipelines_results = []
+ recommended_pipelines_results: list[CustomizedTemplateItemDict] = []
for pipeline_customized_template in pipeline_customized_templates:
- recommended_pipeline_result = {
+ recommended_pipeline_result: CustomizedTemplateItemDict = {
"id": pipeline_customized_template.id,
"name": pipeline_customized_template.name,
"description": pipeline_customized_template.description,
diff --git a/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py b/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py
index 073eed221c..2964537c35 100644
--- a/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py
+++ b/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py
@@ -1,4 +1,4 @@
-from typing import Any
+from typing import Any, TypedDict
import yaml
from sqlalchemy import select
@@ -9,18 +9,41 @@ from services.rag_pipeline.pipeline_template.pipeline_template_base import Pipel
from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType
+class PipelineTemplateItemDict(TypedDict):
+ id: str
+ name: str
+ description: str
+ icon: dict[str, Any]
+ copyright: str
+ privacy_policy: str
+ position: int
+ chunk_structure: str
+
+
+class PipelineTemplatesResultDict(TypedDict):
+ pipeline_templates: list[PipelineTemplateItemDict]
+
+
+class PipelineTemplateDetailDict(TypedDict):
+ id: str
+ name: str
+ icon_info: dict[str, Any]
+ description: str
+ chunk_structure: str
+ export_data: str
+ graph: dict[str, Any]
+
+
class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
"""
Retrieval pipeline template from database
"""
def get_pipeline_templates(self, language: str) -> dict[str, Any]:
- result = self.fetch_pipeline_templates_from_db(language)
- return result
+ return self.fetch_pipeline_templates_from_db(language)
def get_pipeline_template_detail(self, template_id: str) -> dict[str, Any] | None:
- result = self.fetch_pipeline_template_detail_from_db(template_id)
- return result
+ return self.fetch_pipeline_template_detail_from_db(template_id)
def get_type(self) -> str:
return PipelineTemplateType.DATABASE
@@ -39,9 +62,9 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
).all()
)
- recommended_pipelines_results = []
+ recommended_pipelines_results: list[PipelineTemplateItemDict] = []
for pipeline_built_in_template in pipeline_built_in_templates:
- recommended_pipeline_result = {
+ recommended_pipeline_result: PipelineTemplateItemDict = {
"id": pipeline_built_in_template.id,
"name": pipeline_built_in_template.name,
"description": pipeline_built_in_template.description,
diff --git a/api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py b/api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py
index d5ef745bec..9565ac46cc 100644
--- a/api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py
+++ b/api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py
@@ -17,21 +17,18 @@ class RemotePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
"""
def get_pipeline_template_detail(self, template_id: str) -> dict[str, Any] | None:
- result: dict[str, Any] | None
try:
- result = self.fetch_pipeline_template_detail_from_dify_official(template_id)
+ return self.fetch_pipeline_template_detail_from_dify_official(template_id)
except Exception as e:
logger.warning("fetch recommended app detail from dify official failed: %r, switch to database.", e)
- result = DatabasePipelineTemplateRetrieval.fetch_pipeline_template_detail_from_db(template_id)
- return result
+ return DatabasePipelineTemplateRetrieval.fetch_pipeline_template_detail_from_db(template_id)
def get_pipeline_templates(self, language: str) -> dict[str, Any]:
try:
- result = self.fetch_pipeline_templates_from_dify_official(language)
+ return self.fetch_pipeline_templates_from_dify_official(language)
except Exception as e:
logger.warning("fetch pipeline templates from dify official failed: %r, switch to database.", e)
- result = DatabasePipelineTemplateRetrieval.fetch_pipeline_templates_from_db(language)
- return result
+ return DatabasePipelineTemplateRetrieval.fetch_pipeline_templates_from_db(language)
def get_type(self) -> str:
return PipelineTemplateType.REMOTE
diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py
index 968600d1bc..9db6682e10 100644
--- a/api/services/rag_pipeline/rag_pipeline.py
+++ b/api/services/rag_pipeline/rag_pipeline.py
@@ -476,7 +476,7 @@ class RagPipelineService:
:param filters: filter by node config parameters.
:return:
"""
- node_type_enum = NodeType(node_type)
+ node_type_enum: NodeType = node_type
node_mapping = get_node_type_classes_mapping()
# return default block config
diff --git a/api/services/summary_index_service.py b/api/services/summary_index_service.py
index a91f49e9e6..cf39469be8 100644
--- a/api/services/summary_index_service.py
+++ b/api/services/summary_index_service.py
@@ -349,7 +349,6 @@ class SummaryIndexService:
summary_record_id,
)
summary_record_in_session = DocumentSegmentSummary(
- id=summary_record_id, # Use the same ID if available
dataset_id=dataset.id,
document_id=segment.document_id,
chunk_id=segment.id,
@@ -360,6 +359,9 @@ class SummaryIndexService:
status=SummaryStatus.COMPLETED,
enabled=True,
)
+ if summary_record_in_session is None:
+ raise RuntimeError("summary_record_in_session should not be None at this point")
+ summary_record_in_session.id = summary_record_id
session.add(summary_record_in_session)
logger.info(
"Created new summary record (id=%s) for segment %s after vectorization",
diff --git a/api/services/tools/api_tools_manage_service.py b/api/services/tools/api_tools_manage_service.py
index 8ad010f62d..5ff2c21749 100644
--- a/api/services/tools/api_tools_manage_service.py
+++ b/api/services/tools/api_tools_manage_service.py
@@ -4,6 +4,7 @@ from typing import Any, TypedDict, cast
from httpx import get
from sqlalchemy import select
+from sqlalchemy.orm import sessionmaker
from core.entities.provider_entities import ProviderConfig
from core.tools.__base.tool_runtime import ToolRuntime
@@ -15,6 +16,7 @@ from core.tools.entities.tool_entities import (
ApiProviderAuthType,
ApiProviderSchemaType,
)
+from core.tools.errors import ApiToolProviderNotFoundError
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.tool_manager import ToolManager
from core.tools.utils.encryption import create_tool_provider_encrypter
@@ -116,71 +118,85 @@ class ApiToolManageService:
privacy_policy: str,
custom_disclaimer: str,
labels: list[str],
- ):
+ ) -> dict[str, Any]:
"""
- create api tool provider
+ Create a new API tool provider.
+
+ :param user_id: The ID of the user creating the provider.
+ :param tenant_id: The ID of the workspace/tenant.
+ :param provider_name: The name of the API tool provider.
+ :param icon: The icon configuration for the provider.
+ :param credentials: The credentials for the provider.
+ :param schema_type: The type of schema (e.g., OpenAPI).
+ :param schema: The raw schema string.
+ :param privacy_policy: The privacy policy URL or text.
+ :param custom_disclaimer: Custom disclaimer text.
+ :param labels: A list of labels for the provider.
+ :return: A dictionary indicating the result status.
"""
+
provider_name = provider_name.strip()
# check if the provider exists
- provider = db.session.scalar(
- select(ApiToolProvider)
- .where(
- ApiToolProvider.tenant_id == tenant_id,
- ApiToolProvider.name == provider_name,
+ # Create new session with automatic transaction management
+ with sessionmaker(db.engine, expire_on_commit=False).begin() as _session:
+ provider: ApiToolProvider | None = _session.scalar(
+ select(ApiToolProvider)
+ .where(
+ ApiToolProvider.tenant_id == tenant_id,
+ ApiToolProvider.name == provider_name,
+ )
+ .limit(1)
)
- .limit(1)
- )
- if provider is not None:
- raise ValueError(f"provider {provider_name} already exists")
+ if provider is not None:
+ raise ValueError(f"provider {provider_name} already exists")
- # parse openapi to tool bundle
- extra_info: dict[str, str] = {}
- # extra info like description will be set here
- tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info)
+ # parse openapi to tool bundle
+ extra_info: dict[str, str] = {}
+ # extra info like description will be set here
+ tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info)
- if len(tool_bundles) > 100:
- raise ValueError("the number of apis should be less than 100")
+ if len(tool_bundles) > 100:
+ raise ValueError("the number of apis should be less than 100")
- # create db provider
- db_provider = ApiToolProvider(
- tenant_id=tenant_id,
- user_id=user_id,
- name=provider_name,
- icon=json.dumps(icon),
- schema=schema,
- description=extra_info.get("description", ""),
- schema_type_str=schema_type,
- tools_str=json.dumps(jsonable_encoder(tool_bundles)),
- credentials_str="{}",
- privacy_policy=privacy_policy,
- custom_disclaimer=custom_disclaimer,
- )
+ # create API tool provider
+ api_tool_provider = ApiToolProvider(
+ tenant_id=tenant_id,
+ user_id=user_id,
+ name=provider_name,
+ icon=json.dumps(icon),
+ schema=schema,
+ description=extra_info.get("description", ""),
+ schema_type_str=schema_type,
+ tools_str=json.dumps(jsonable_encoder(tool_bundles)),
+ credentials_str="{}",
+ privacy_policy=privacy_policy,
+ custom_disclaimer=custom_disclaimer,
+ )
- if "auth_type" not in credentials:
- raise ValueError("auth_type is required")
+ if "auth_type" not in credentials:
+ raise ValueError("auth_type is required")
- # get auth type, none or api key
- auth_type = ApiProviderAuthType.value_of(credentials["auth_type"])
+ # get auth type, none or api key
+ auth_type = ApiProviderAuthType.value_of(credentials["auth_type"])
- # create provider entity
- provider_controller = ApiToolProviderController.from_db(db_provider, auth_type)
- # load tools into provider entity
- provider_controller.load_bundled_tools(tool_bundles)
+ # create provider entity
+ provider_controller = ApiToolProviderController.from_db(api_tool_provider, auth_type)
+ # load tools into provider entity
+ provider_controller.load_bundled_tools(tool_bundles)
- # encrypt credentials
- encrypter, _ = create_tool_provider_encrypter(
- tenant_id=tenant_id,
- controller=provider_controller,
- )
- db_provider.credentials_str = json.dumps(encrypter.encrypt(credentials))
+ # encrypt credentials
+ encrypter, _ = create_tool_provider_encrypter(
+ tenant_id=tenant_id,
+ controller=provider_controller,
+ )
+ api_tool_provider.credentials_str = json.dumps(encrypter.encrypt(credentials))
- db.session.add(db_provider)
- db.session.commit()
+ _session.add(api_tool_provider)
- # update labels
- ToolLabelManager.update_tool_labels(provider_controller, labels)
+ # update labels
+ ToolLabelManager.update_tool_labels(provider_controller, labels, _session)
return {"result": "success"}
@@ -212,16 +228,25 @@ class ApiToolManageService:
@staticmethod
def list_api_tool_provider_tools(user_id: str, tenant_id: str, provider_name: str) -> list[ToolApiEntity]:
"""
- list api tool provider tools
+ List tools provided by a specific API tool provider.
+
+ :param user_id: The ID of the user requesting the list.
+ :param tenant_id: The ID of the workspace/tenant.
+ :param provider_name: The name of the API tool provider.
+ :return: A list of ToolApiEntity objects.
"""
- provider: ApiToolProvider | None = db.session.scalar(
- select(ApiToolProvider)
- .where(
- ApiToolProvider.tenant_id == tenant_id,
- ApiToolProvider.name == provider_name,
+
+ # create new session with automatic transaction management
+ provider: ApiToolProvider | None = None
+ with sessionmaker(db.engine, expire_on_commit=False).begin() as _session:
+ provider = _session.scalar(
+ select(ApiToolProvider)
+ .where(
+ ApiToolProvider.tenant_id == tenant_id,
+ ApiToolProvider.name == provider_name,
+ )
+ .limit(1)
)
- .limit(1)
- )
if provider is None:
raise ValueError(f"you have not added provider {provider_name}")
@@ -251,103 +276,133 @@ class ApiToolManageService:
privacy_policy: str | None,
custom_disclaimer: str,
labels: list[str],
- ):
+ ) -> dict[str, Any]:
"""
- update api tool provider
+ Update an existing API tool provider.
+
+ :param user_id: The ID of the user updating the provider.
+ :param tenant_id: The ID of the workspace/tenant.
+ :param provider_name: The new name of the API tool provider.
+ :param original_provider: The original name of the API tool provider.
+ :param icon: The icon configuration for the provider.
+ :param credentials: The credentials for the provider.
+ :param _schema_type: The type of schema (e.g., OpenAPI).
+ :param schema: The raw schema string.
+ :param privacy_policy: The privacy policy URL or text.
+ :param custom_disclaimer: Custom disclaimer text.
+ :param labels: A list of labels for the provider.
+ :return: A dictionary indicating the result status.
"""
+
provider_name = provider_name.strip()
# check if the provider exists
- provider = db.session.scalar(
- select(ApiToolProvider)
- .where(
- ApiToolProvider.tenant_id == tenant_id,
- ApiToolProvider.name == original_provider,
+ # create new session with automatic transaction management
+ with sessionmaker(db.engine, expire_on_commit=False).begin() as _session:
+ provider: ApiToolProvider | None = _session.scalar(
+ select(ApiToolProvider)
+ .where(
+ ApiToolProvider.tenant_id == tenant_id,
+ ApiToolProvider.name == original_provider,
+ )
+ .limit(1)
)
- .limit(1)
- )
- if provider is None:
- raise ValueError(f"api provider {provider_name} does not exists")
- # parse openapi to tool bundle
- extra_info: dict[str, str] = {}
- # extra info like description will be set here
- tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info)
+ if provider is None:
+ raise ApiToolProviderNotFoundError(provider_name=original_provider, tenant_id=tenant_id)
- # update db provider
- provider.name = provider_name
- provider.icon = json.dumps(icon)
- provider.schema = schema
- provider.description = extra_info.get("description", "")
- provider.schema_type_str = schema_type
- provider.tools_str = json.dumps(jsonable_encoder(tool_bundles))
- provider.privacy_policy = privacy_policy
- provider.custom_disclaimer = custom_disclaimer
+ # parse openapi to tool bundle
+ extra_info: dict[str, str] = {}
+ # extra info like description will be set here
+ tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info)
- if "auth_type" not in credentials:
- raise ValueError("auth_type is required")
+ # update db provider
+ provider.name = provider_name
+ provider.icon = json.dumps(icon)
+ provider.schema = schema
+ provider.description = extra_info.get("description", "")
+ provider.schema_type_str = schema_type
+ provider.tools_str = json.dumps(jsonable_encoder(tool_bundles))
+ provider.privacy_policy = privacy_policy
+ provider.custom_disclaimer = custom_disclaimer
- # get auth type, none or api key
- auth_type = ApiProviderAuthType.value_of(credentials["auth_type"])
+ if "auth_type" not in credentials:
+ raise ValueError("auth_type is required")
- # create provider entity
- provider_controller = ApiToolProviderController.from_db(provider, auth_type)
- # load tools into provider entity
- provider_controller.load_bundled_tools(tool_bundles)
+ # get auth type, none or api key
+ auth_type = ApiProviderAuthType.value_of(credentials["auth_type"])
- # get original credentials if exists
- encrypter, cache = create_tool_provider_encrypter(
- tenant_id=tenant_id,
- controller=provider_controller,
- )
+ # create provider entity
+ provider_controller = ApiToolProviderController.from_db(provider, auth_type)
+ # load tools into provider entity
+ provider_controller.load_bundled_tools(tool_bundles)
- original_credentials = encrypter.decrypt(provider.credentials)
- masked_credentials = encrypter.mask_plugin_credentials(original_credentials)
- # check if the credential has changed, save the original credential
- for name, value in credentials.items():
- if name in masked_credentials and value == masked_credentials[name]:
- credentials[name] = original_credentials[name]
+ # get original credentials if exists
+ encrypter, cache = create_tool_provider_encrypter(
+ tenant_id=tenant_id,
+ controller=provider_controller,
+ )
- credentials = dict(encrypter.encrypt(credentials))
- provider.credentials_str = json.dumps(credentials)
+ original_credentials = encrypter.decrypt(provider.credentials)
+ masked_credentials = encrypter.mask_plugin_credentials(original_credentials)
- db.session.add(provider)
- db.session.commit()
+ # check if the credential has changed, save the original credential
+ for name, value in credentials.items():
+ if name in masked_credentials and value == masked_credentials[name]:
+ credentials[name] = original_credentials[name]
+
+ credentials = dict(encrypter.encrypt(credentials))
+ provider.credentials_str = json.dumps(credentials)
+
+ _session.add(provider)
+
+ # update labels
+ ToolLabelManager.update_tool_labels(provider_controller, labels, _session)
# delete cache
cache.delete()
- # update labels
- ToolLabelManager.update_tool_labels(provider_controller, labels)
-
return {"result": "success"}
@staticmethod
def delete_api_tool_provider(user_id: str, tenant_id: str, provider_name: str):
"""
- delete tool provider
+ Delete an API tool provider.
+
+ :param user_id: The ID of the user performing the deletion operation.
+ :param tenant_id: The ID of the workspace/tenant where the provider belongs.
+ :param provider_name: The unique name of the API tool provider to be deleted.
+ :raises ValueError: If the specified provider does not exist in the tenant.
+ :return: A dictionary indicating the result status.
"""
- provider = db.session.scalar(
- select(ApiToolProvider)
- .where(
- ApiToolProvider.tenant_id == tenant_id,
- ApiToolProvider.name == provider_name,
+
+ # create new session with automatic transaction management
+ with sessionmaker(db.engine, expire_on_commit=False).begin() as _session:
+ provider: ApiToolProvider | None = _session.scalar(
+ select(ApiToolProvider)
+ .where(
+ ApiToolProvider.tenant_id == tenant_id,
+ ApiToolProvider.name == provider_name,
+ )
+ .limit(1)
)
- .limit(1)
- )
- if provider is None:
- raise ValueError(f"you have not added provider {provider_name}")
+ if provider is None:
+ raise ValueError(f"you have not added provider {provider_name}")
- db.session.delete(provider)
- db.session.commit()
+ _session.delete(provider)
return {"result": "success"}
@staticmethod
- def get_api_tool_provider(user_id: str, tenant_id: str, provider: str):
+ def get_api_tool_provider(user_id: str, tenant_id: str, provider: str) -> dict[str, Any]:
"""
- get api tool provider
+ Get API tool provider details.
+
+ :param user_id: The ID of the user requesting the provider.
+ :param tenant_id: The ID of the workspace/tenant.
+ :param provider: The name of the API tool provider.
+ :return: A dictionary containing the provider details.
"""
return ToolManager.user_get_api_provider(provider=provider, tenant_id=tenant_id)
@@ -360,10 +415,20 @@ class ApiToolManageService:
parameters: dict[str, Any],
schema_type: ApiProviderSchemaType,
schema: str,
- ):
+ ) -> dict[str, Any]:
"""
- test api tool before adding api tool provider
+ Test an API tool before adding the API tool provider.
+
+ :param tenant_id: The ID of the workspace/tenant.
+ :param provider_name: The name of the API tool provider.
+ :param tool_name: The name of the specific tool to test.
+ :param credentials: The credentials for the provider.
+ :param parameters: The parameters to pass to the tool.
+ :param schema_type: The type of schema (e.g., OpenAPI).
+ :param schema: The raw schema string.
+ :return: A dictionary containing the result or error message.
"""
+
if schema_type not in [member.value for member in ApiProviderSchemaType]:
raise ValueError(f"invalid schema type {schema_type}")
@@ -377,18 +442,21 @@ class ApiToolManageService:
if tool_bundle is None:
raise ValueError(f"invalid tool name {tool_name}")
- db_provider = db.session.scalar(
- select(ApiToolProvider)
- .where(
- ApiToolProvider.tenant_id == tenant_id,
- ApiToolProvider.name == provider_name,
+ # create new session with automatic transaction management to get the provider
+ provider: ApiToolProvider | None = None
+ with sessionmaker(db.engine, expire_on_commit=False).begin() as _session:
+ provider = _session.scalar(
+ select(ApiToolProvider)
+ .where(
+ ApiToolProvider.tenant_id == tenant_id,
+ ApiToolProvider.name == provider_name,
+ )
+ .limit(1)
)
- .limit(1)
- )
- if not db_provider:
+ if provider is None:
# create a fake db provider
- db_provider = ApiToolProvider(
+ provider = ApiToolProvider(
tenant_id="",
user_id="",
name="",
@@ -407,12 +475,12 @@ class ApiToolManageService:
auth_type = ApiProviderAuthType.value_of(credentials["auth_type"])
# create provider entity
- provider_controller = ApiToolProviderController.from_db(db_provider, auth_type)
+ provider_controller = ApiToolProviderController.from_db(provider, auth_type)
# load tools into provider entity
provider_controller.load_bundled_tools(tool_bundles)
# decrypt credentials
- if db_provider.id:
+ if provider.id:
encrypter, _ = create_tool_provider_encrypter(
tenant_id=tenant_id,
controller=provider_controller,
@@ -443,14 +511,21 @@ class ApiToolManageService:
@staticmethod
def list_api_tools(tenant_id: str) -> list[ToolProviderApiEntity]:
"""
- list api tools
+ List all API tools for a specific tenant.
+
+ :param tenant_id: The ID of the workspace/tenant.
+ :return: A list of ToolProviderApiEntity objects.
"""
# get all api providers
- db_providers = db.session.scalars(select(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id)).all()
+ # create new session with automatic transaction management
+ providers: list[ApiToolProvider] = []
+ with sessionmaker(db.engine, expire_on_commit=False).begin() as _session:
+ providers = list(
+ _session.scalars(select(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id)).all()
+ )
result: list[ToolProviderApiEntity] = []
-
- for provider in db_providers:
+ for provider in providers:
# convert provider controller to user provider
provider_controller = ToolTransformService.api_provider_to_controller(db_provider=provider)
labels = ToolLabelManager.get_tool_labels(provider_controller)
diff --git a/api/services/variable_truncator.py b/api/services/variable_truncator.py
index c96050ce13..1529c2b98f 100644
--- a/api/services/variable_truncator.py
+++ b/api/services/variable_truncator.py
@@ -169,7 +169,7 @@ class VariableTruncator(BaseTruncator):
return TruncationResult(StringSegment(value=fallback_result.value), True)
# Apply final fallback - convert to JSON string and truncate
- json_str = dumps_with_segments(result.value, ensure_ascii=False)
+ json_str = dumps_with_segments(result.value)
if len(json_str) > self._max_size_bytes:
json_str = json_str[: self._max_size_bytes] + "..."
return TruncationResult(result=StringSegment(value=json_str), truncated=True)
diff --git a/api/services/workflow_draft_variable_service.py b/api/services/workflow_draft_variable_service.py
index 5ec00ee336..96f936ff9b 100644
--- a/api/services/workflow_draft_variable_service.py
+++ b/api/services/workflow_draft_variable_service.py
@@ -146,7 +146,7 @@ class DraftVarLoader(VariableLoader):
variable = segment_to_variable(
segment=segment,
selector=draft_var.get_selector(),
- id=draft_var.id,
+ variable_id=draft_var.id,
name=draft_var.name,
description=draft_var.description,
)
@@ -180,7 +180,7 @@ class DraftVarLoader(VariableLoader):
variable = segment_to_variable(
segment=segment,
selector=draft_var.get_selector(),
- id=draft_var.id,
+ variable_id=draft_var.id,
name=draft_var.name,
description=draft_var.description,
)
@@ -191,7 +191,7 @@ class DraftVarLoader(VariableLoader):
variable = segment_to_variable(
segment=segment,
selector=draft_var.get_selector(),
- id=draft_var.id,
+ variable_id=draft_var.id,
name=draft_var.name,
description=draft_var.description,
)
@@ -1067,7 +1067,7 @@ class DraftVariableSaver:
filename = f"{self._generate_filename(name)}.txt"
else:
# For other types, store as JSON
- original_content_serialized = dumps_with_segments(value_seg.value, ensure_ascii=False)
+ original_content_serialized = dumps_with_segments(value_seg.value)
content_type = "application/json"
filename = f"{self._generate_filename(name)}.json"
diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py
index d71223314e..d4b9095ce5 100644
--- a/api/services/workflow_service.py
+++ b/api/services/workflow_service.py
@@ -18,9 +18,9 @@ from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly,
from core.repositories import DifyCoreRepositoryFactory
from core.repositories.human_input_repository import FormCreateParams, HumanInputFormRepositoryImpl
from core.trigger.constants import is_trigger_node_type
-from core.workflow.human_input_compat import (
+from core.workflow.human_input_adapter import (
DeliveryChannelConfig,
- normalize_human_input_node_data_for_graph,
+ adapt_human_input_node_data_for_graph,
parse_human_input_delivery_methods,
)
from core.workflow.node_factory import (
@@ -791,7 +791,7 @@ class WorkflowService:
:param filters: filter by node config parameters.
:return:
"""
- node_type_enum = NodeType(node_type)
+ node_type_enum: NodeType = node_type
node_mapping = get_node_type_classes_mapping()
# return default block config
@@ -1096,7 +1096,7 @@ class WorkflowService:
raise ValueError("Node type must be human-input.")
node_data = HumanInputNodeData.model_validate(
- normalize_human_input_node_data_for_graph(node_config["data"]),
+ adapt_human_input_node_data_for_graph(node_config["data"]),
from_attributes=True,
)
delivery_method = self._resolve_human_input_delivery_method(
@@ -1237,9 +1237,10 @@ class WorkflowService:
variable_pool=variable_pool,
start_at=time.perf_counter(),
)
+ node_data = HumanInputNode.validate_node_data(adapt_human_input_node_data_for_graph(node_config["data"]))
node = HumanInputNode(
- id=node_config["id"],
- config=node_config,
+ node_id=node_config["id"],
+ config=node_data,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
runtime=DifyHumanInputNodeRuntime(run_context),
@@ -1529,7 +1530,7 @@ class WorkflowService:
from graphon.nodes.human_input.entities import HumanInputNodeData
try:
- HumanInputNodeData.model_validate(normalize_human_input_node_data_for_graph(node_data))
+ HumanInputNodeData.model_validate(adapt_human_input_node_data_for_graph(node_data))
except Exception as e:
raise ValueError(f"Invalid HumanInput node data: {str(e)}")
diff --git a/api/tasks/mail_human_input_delivery_task.py b/api/tasks/mail_human_input_delivery_task.py
index f8ae3f4b6e..2a60be7762 100644
--- a/api/tasks/mail_human_input_delivery_task.py
+++ b/api/tasks/mail_human_input_delivery_task.py
@@ -11,7 +11,7 @@ from sqlalchemy.orm import Session, sessionmaker
from configs import dify_config
from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext
-from core.workflow.human_input_compat import EmailDeliveryConfig, EmailDeliveryMethod
+from core.workflow.human_input_adapter import EmailDeliveryConfig, EmailDeliveryMethod
from extensions.ext_database import db
from extensions.ext_mail import mail
from graphon.runtime import GraphRuntimeState, VariablePool
diff --git a/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py b/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py
index b5318aaa2b..2392084c36 100644
--- a/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py
+++ b/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py
@@ -1,5 +1,6 @@
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY
from core.workflow.nodes.datasource.datasource_node import DatasourceNode
+from core.workflow.nodes.datasource.entities import DatasourceNodeData
from graphon.enums import WorkflowNodeExecutionStatus
from graphon.node_events import NodeRunResult, StreamCompletedEvent
@@ -69,19 +70,16 @@ def test_node_integration_minimal_stream(mocker):
mocker.patch("core.workflow.nodes.datasource.datasource_node.DatasourceManager", new=_Mgr)
node = DatasourceNode(
- id="n",
- config={
- "id": "n",
- "data": {
- "type": "datasource",
- "version": "1",
- "title": "Datasource",
- "provider_type": "plugin",
- "provider_name": "p",
- "plugin_id": "plug",
- "datasource_name": "ds",
- },
- },
+ node_id="n",
+ config=DatasourceNodeData(
+ type="datasource",
+ version="1",
+ title="Datasource",
+ provider_type="plugin",
+ provider_name="p",
+ plugin_id="plug",
+ datasource_name="ds",
+ ),
graph_init_params=_GP(),
graph_runtime_state=_GS(vp),
)
diff --git a/api/tests/integration_tests/workflow/nodes/test_code.py b/api/tests/integration_tests/workflow/nodes/test_code.py
index e3476c292b..aaa6092993 100644
--- a/api/tests/integration_tests/workflow/nodes/test_code.py
+++ b/api/tests/integration_tests/workflow/nodes/test_code.py
@@ -11,6 +11,7 @@ from graphon.enums import WorkflowNodeExecutionStatus
from graphon.graph import Graph
from graphon.node_events import NodeRunResult
from graphon.nodes.code.code_node import CodeNode
+from graphon.nodes.code.entities import CodeNodeData
from graphon.nodes.code.limits import CodeNodeLimits
from graphon.runtime import GraphRuntimeState, VariablePool
from tests.workflow_test_utils import build_test_graph_init_params
@@ -64,8 +65,8 @@ def init_code_node(code_config: dict):
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start")
node = CodeNode(
- id=str(uuid.uuid4()),
- config=code_config,
+ node_id=str(uuid.uuid4()),
+ config=CodeNodeData.model_validate(code_config["data"]),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
code_executor=node_factory._code_executor,
diff --git a/api/tests/integration_tests/workflow/nodes/test_http.py b/api/tests/integration_tests/workflow/nodes/test_http.py
index aa6cf1e021..b9f7b9575b 100644
--- a/api/tests/integration_tests/workflow/nodes/test_http.py
+++ b/api/tests/integration_tests/workflow/nodes/test_http.py
@@ -14,7 +14,7 @@ from core.workflow.system_variables import build_system_variables
from graphon.enums import WorkflowNodeExecutionStatus
from graphon.file.file_manager import file_manager
from graphon.graph import Graph
-from graphon.nodes.http_request import HttpRequestNode, HttpRequestNodeConfig
+from graphon.nodes.http_request import HttpRequestNode, HttpRequestNodeConfig, HttpRequestNodeData
from graphon.runtime import GraphRuntimeState, VariablePool
from tests.workflow_test_utils import build_test_graph_init_params
@@ -75,8 +75,8 @@ def init_http_node(config: dict):
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start")
node = HttpRequestNode(
- id=str(uuid.uuid4()),
- config=config,
+ node_id=str(uuid.uuid4()),
+ config=HttpRequestNodeData.model_validate(config["data"]),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
http_request_config=HTTP_REQUEST_CONFIG,
@@ -723,8 +723,8 @@ def test_nested_object_variable_selector(setup_http_mock):
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start")
node = HttpRequestNode(
- id=str(uuid.uuid4()),
- config=graph_config["nodes"][1],
+ node_id=str(uuid.uuid4()),
+ config=HttpRequestNodeData.model_validate(graph_config["nodes"][1]["data"]),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
http_request_config=HTTP_REQUEST_CONFIG,
diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py
index fa5d63cfbf..3eead70163 100644
--- a/api/tests/integration_tests/workflow/nodes/test_llm.py
+++ b/api/tests/integration_tests/workflow/nodes/test_llm.py
@@ -11,6 +11,7 @@ from core.workflow.system_variables import build_system_variables
from extensions.ext_database import db
from graphon.enums import WorkflowNodeExecutionStatus
from graphon.node_events import StreamCompletedEvent
+from graphon.nodes.llm.entities import LLMNodeData
from graphon.nodes.llm.file_saver import LLMFileSaver
from graphon.nodes.llm.node import LLMNode
from graphon.nodes.llm.protocols import CredentialsProvider, ModelFactory
@@ -75,8 +76,8 @@ def init_llm_node(config: dict) -> LLMNode:
llm_file_saver = MagicMock(spec=LLMFileSaver)
node = LLMNode(
- id=str(uuid.uuid4()),
- config=config,
+ node_id=str(uuid.uuid4()),
+ config=LLMNodeData.model_validate(config["data"]),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
credentials_provider=MagicMock(spec=CredentialsProvider),
diff --git a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py
index 52886855b8..f2eabb86c3 100644
--- a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py
+++ b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py
@@ -11,6 +11,7 @@ from extensions.ext_database import db
from graphon.enums import WorkflowNodeExecutionStatus
from graphon.model_runtime.entities import AssistantPromptMessage, UserPromptMessage
from graphon.nodes.llm.protocols import CredentialsProvider, ModelFactory
+from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeData
from graphon.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
from graphon.runtime import GraphRuntimeState, VariablePool
from tests.integration_tests.workflow.nodes.__mock.model import get_mocked_fetch_model_instance
@@ -69,8 +70,8 @@ def init_parameter_extractor_node(config: dict, memory=None):
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
node = ParameterExtractorNode(
- id=str(uuid.uuid4()),
- config=config,
+ node_id=str(uuid.uuid4()),
+ config=ParameterExtractorNodeData.model_validate(config["data"]),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
credentials_provider=MagicMock(spec=CredentialsProvider),
diff --git a/api/tests/integration_tests/workflow/nodes/test_template_transform.py b/api/tests/integration_tests/workflow/nodes/test_template_transform.py
index 9e3e1a47e3..e2e0723fb8 100644
--- a/api/tests/integration_tests/workflow/nodes/test_template_transform.py
+++ b/api/tests/integration_tests/workflow/nodes/test_template_transform.py
@@ -6,6 +6,7 @@ from core.workflow.node_factory import DifyNodeFactory
from core.workflow.system_variables import build_system_variables
from graphon.enums import WorkflowNodeExecutionStatus
from graphon.graph import Graph
+from graphon.nodes.template_transform.entities import TemplateTransformNodeData
from graphon.nodes.template_transform.template_transform_node import TemplateTransformNode
from graphon.runtime import GraphRuntimeState, VariablePool
from graphon.template_rendering import TemplateRenderError
@@ -86,8 +87,8 @@ def test_execute_template_transform():
assert graph is not None
node = TemplateTransformNode(
- id=str(uuid.uuid4()),
- config=config,
+ node_id=str(uuid.uuid4()),
+ config=TemplateTransformNodeData.model_validate(config["data"]),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
jinja2_template_renderer=_SimpleJinja2Renderer(),
diff --git a/api/tests/integration_tests/workflow/nodes/test_tool.py b/api/tests/integration_tests/workflow/nodes/test_tool.py
index f9ec51ee10..a8e9422c1e 100644
--- a/api/tests/integration_tests/workflow/nodes/test_tool.py
+++ b/api/tests/integration_tests/workflow/nodes/test_tool.py
@@ -11,6 +11,7 @@ from graphon.enums import WorkflowNodeExecutionStatus
from graphon.graph import Graph
from graphon.node_events import StreamCompletedEvent
from graphon.nodes.protocols import ToolFileManagerProtocol
+from graphon.nodes.tool.entities import ToolNodeData
from graphon.nodes.tool.tool_node import ToolNode
from graphon.runtime import GraphRuntimeState, VariablePool
from tests.workflow_test_utils import build_test_graph_init_params
@@ -60,8 +61,8 @@ def init_tool_node(config: dict):
tool_file_manager_factory = MagicMock(spec=ToolFileManagerProtocol)
node = ToolNode(
- id=str(uuid.uuid4()),
- config=config,
+ node_id=str(uuid.uuid4()),
+ config=ToolNodeData.model_validate(config["data"]),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
tool_file_manager_factory=tool_file_manager_factory,
diff --git a/api/tests/test_containers_integration_tests/controllers/console/app/test_app_apis.py b/api/tests/test_containers_integration_tests/controllers/console/app/test_app_apis.py
index 15dec06311..18755ef012 100644
--- a/api/tests/test_containers_integration_tests/controllers/console/app/test_app_apis.py
+++ b/api/tests/test_containers_integration_tests/controllers/console/app/test_app_apis.py
@@ -234,6 +234,35 @@ class TestAppEndpoints:
}
)
+ def test_app_icon_post_should_forward_icon_type(self, app, monkeypatch):
+ api = app_module.AppIconApi()
+ method = _unwrap(api.post)
+ payload = {
+ "icon": "https://example.com/icon.png",
+ "icon_type": "image",
+ "icon_background": "#FFFFFF",
+ }
+ app_service = MagicMock()
+ app_service.update_app_icon.return_value = SimpleNamespace()
+ response_model = MagicMock()
+ response_model.model_dump.return_value = {"id": "app-1"}
+
+ monkeypatch.setattr(app_module, "AppService", lambda: app_service)
+ monkeypatch.setattr(app_module.AppDetail, "model_validate", MagicMock(return_value=response_model))
+
+ with (
+ app.test_request_context("/console/api/apps/app-1/icon", method="POST", json=payload),
+ patch.object(type(console_ns), "payload", payload),
+ ):
+ response = method(app_model=SimpleNamespace())
+
+ assert response == {"id": "app-1"}
+ assert app_service.update_app_icon.call_args.args[1:] == (
+ payload["icon"],
+ payload["icon_background"],
+ app_module.IconType.IMAGE,
+ )
+
class TestOpsTraceEndpoints:
@pytest.fixture
diff --git a/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py b/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py
index 14d5740072..6524d6ce61 100644
--- a/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py
+++ b/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py
@@ -8,7 +8,7 @@ from sqlalchemy import Engine, select
from sqlalchemy.orm import Session
from core.repositories.human_input_repository import FormCreateParams, HumanInputFormRepositoryImpl
-from core.workflow.human_input_compat import (
+from core.workflow.human_input_adapter import (
DeliveryChannelConfig,
EmailDeliveryConfig,
EmailDeliveryMethod,
diff --git a/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py b/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py
index da4f8847d6..5aed230cd4 100644
--- a/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py
+++ b/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py
@@ -101,8 +101,8 @@ def _build_graph(
start_data = StartNodeData(title="start", variables=[])
start_node = StartNode(
- id="start",
- config={"id": "start", "data": start_data.model_dump()},
+ node_id="start",
+ config=start_data,
graph_init_params=params,
graph_runtime_state=runtime_state,
)
@@ -116,8 +116,8 @@ def _build_graph(
],
)
human_node = HumanInputNode(
- id="human",
- config={"id": "human", "data": human_data.model_dump()},
+ node_id="human",
+ config=human_data,
graph_init_params=params,
graph_runtime_state=runtime_state,
form_repository=form_repository,
@@ -130,8 +130,8 @@ def _build_graph(
desc=None,
)
end_node = EndNode(
- id="end",
- config={"id": "end", "data": end_data.model_dump()},
+ node_id="end",
+ config=end_data,
graph_init_params=params,
graph_runtime_state=runtime_state,
)
diff --git a/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py b/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py
index 2e207ddc67..35e41035df 100644
--- a/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py
+++ b/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py
@@ -123,9 +123,9 @@ class TestStorageKeyLoader(unittest.TestCase):
file_related_id = related_id
return File(
- id=str(uuid4()), # Generate new UUID for File.id
+ file_id=str(uuid4()), # Generate new UUID for File.id
tenant_id=tenant_id,
- type=FileType.DOCUMENT,
+ file_type=FileType.DOCUMENT,
transfer_method=transfer_method,
related_id=file_related_id,
remote_url=remote_url,
diff --git a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py
index aaf9a85d60..54b7afc018 100644
--- a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py
+++ b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py
@@ -271,7 +271,7 @@ def _create_recipient(
def _create_delivery(session: Session, *, form_id: str) -> HumanInputDelivery:
- from core.workflow.human_input_compat import DeliveryMethodType
+ from core.workflow.human_input_adapter import DeliveryMethodType
from models.human_input import ConsoleDeliveryPayload
delivery = HumanInputDelivery(
diff --git a/api/tests/test_containers_integration_tests/services/test_app_service.py b/api/tests/test_containers_integration_tests/services/test_app_service.py
index fa57dd4a6f..b695ae9fd9 100644
--- a/api/tests/test_containers_integration_tests/services/test_app_service.py
+++ b/api/tests/test_containers_integration_tests/services/test_app_service.py
@@ -658,15 +658,17 @@ class TestAppService:
# Update app icon
new_icon = "🌟"
new_icon_background = "#FFD93D"
+ new_icon_type = "image"
mock_current_user = create_autospec(Account, instance=True)
mock_current_user.id = account.id
mock_current_user.current_tenant_id = account.current_tenant_id
with patch("services.app_service.current_user", mock_current_user):
- updated_app = app_service.update_app_icon(app, new_icon, new_icon_background)
+ updated_app = app_service.update_app_icon(app, new_icon, new_icon_background, new_icon_type)
assert updated_app.icon == new_icon
assert updated_app.icon_background == new_icon_background
+ assert str(updated_app.icon_type).lower() == new_icon_type
assert updated_app.updated_by == account.id
# Verify other fields remain unchanged
diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_document.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_document.py
new file mode 100644
index 0000000000..2bec703f0c
--- /dev/null
+++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_document.py
@@ -0,0 +1,650 @@
+"""Testcontainers integration tests for SQL-backed DocumentService paths."""
+
+import datetime
+import json
+from unittest.mock import create_autospec, patch
+from uuid import uuid4
+
+import pytest
+from werkzeug.exceptions import Forbidden, NotFound
+
+from core.rag.index_processor.constant.index_type import IndexStructureType
+from extensions.storage.storage_type import StorageType
+from models import Account
+from models.dataset import Dataset, Document
+from models.enums import CreatorUserRole, DataSourceType, DocumentCreatedFrom, IndexingStatus
+from models.model import UploadFile
+from services.dataset_service import DocumentService
+from services.errors.account import NoPermissionError
+
+FIXED_UPLOAD_CREATED_AT = datetime.datetime(2024, 1, 1, 0, 0, 0)
+
+
+class DocumentServiceIntegrationFactory:
+ @staticmethod
+ def create_dataset(
+ db_session_with_containers,
+ *,
+ tenant_id: str | None = None,
+ created_by: str | None = None,
+ name: str | None = None,
+ ) -> Dataset:
+ dataset = Dataset(
+ tenant_id=tenant_id or str(uuid4()),
+ name=name or f"dataset-{uuid4()}",
+ data_source_type=DataSourceType.UPLOAD_FILE,
+ created_by=created_by or str(uuid4()),
+ )
+ db_session_with_containers.add(dataset)
+ db_session_with_containers.commit()
+ return dataset
+
+ @staticmethod
+ def create_document(
+ db_session_with_containers,
+ *,
+ dataset: Dataset,
+ name: str = "doc.txt",
+ position: int = 1,
+ tenant_id: str | None = None,
+ indexing_status: str = IndexingStatus.COMPLETED,
+ enabled: bool = True,
+ archived: bool = False,
+ is_paused: bool = False,
+ need_summary: bool = False,
+ doc_form: str = IndexStructureType.PARAGRAPH_INDEX,
+ batch: str | None = None,
+ data_source_type: str = DataSourceType.UPLOAD_FILE,
+ data_source_info: dict | None = None,
+ created_by: str | None = None,
+ ) -> Document:
+ document = Document(
+ tenant_id=tenant_id or dataset.tenant_id,
+ dataset_id=dataset.id,
+ position=position,
+ data_source_type=data_source_type,
+ data_source_info=json.dumps(data_source_info or {}),
+ batch=batch or f"batch-{uuid4()}",
+ name=name,
+ created_from=DocumentCreatedFrom.WEB,
+ created_by=created_by or dataset.created_by,
+ doc_form=doc_form,
+ )
+ document.indexing_status = indexing_status
+ document.enabled = enabled
+ document.archived = archived
+ document.is_paused = is_paused
+ document.need_summary = need_summary
+ if indexing_status == IndexingStatus.COMPLETED:
+ document.completed_at = FIXED_UPLOAD_CREATED_AT
+ db_session_with_containers.add(document)
+ db_session_with_containers.commit()
+ return document
+
+ @staticmethod
+ def create_upload_file(
+ db_session_with_containers,
+ *,
+ tenant_id: str,
+ created_by: str,
+ file_id: str | None = None,
+ name: str = "source.txt",
+ ) -> UploadFile:
+ upload_file = UploadFile(
+ tenant_id=tenant_id,
+ storage_type=StorageType.LOCAL,
+ key=f"uploads/{uuid4()}",
+ name=name,
+ size=128,
+ extension="txt",
+ mime_type="text/plain",
+ created_by_role=CreatorUserRole.ACCOUNT,
+ created_by=created_by,
+ created_at=FIXED_UPLOAD_CREATED_AT,
+ used=False,
+ )
+ if file_id:
+ upload_file.id = file_id
+ db_session_with_containers.add(upload_file)
+ db_session_with_containers.commit()
+ return upload_file
+
+
+@pytest.fixture
+def current_user_mock():
+ with patch("services.dataset_service.current_user", create_autospec(Account, instance=True)) as current_user:
+ current_user.id = str(uuid4())
+ current_user.current_tenant_id = str(uuid4())
+ current_user.current_role = None
+ yield current_user
+
+
+def test_get_document_returns_none_when_document_id_is_missing(db_session_with_containers):
+ dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
+
+ assert DocumentService.get_document(dataset.id, None) is None
+
+
+def test_get_document_queries_by_dataset_and_document_id(db_session_with_containers):
+ dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
+ document = DocumentServiceIntegrationFactory.create_document(db_session_with_containers, dataset=dataset)
+
+ result = DocumentService.get_document(dataset.id, document.id)
+
+ assert result is not None
+ assert result.id == document.id
+
+
+def test_get_documents_by_ids_returns_empty_for_empty_input(db_session_with_containers):
+ dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
+
+ result = DocumentService.get_documents_by_ids(dataset.id, [])
+
+ assert result == []
+
+
+def test_get_documents_by_ids_uses_single_batch_query(db_session_with_containers):
+ dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
+ doc_a = DocumentServiceIntegrationFactory.create_document(db_session_with_containers, dataset=dataset, name="a.txt")
+ doc_b = DocumentServiceIntegrationFactory.create_document(
+ db_session_with_containers,
+ dataset=dataset,
+ name="b.txt",
+ position=2,
+ )
+
+ result = DocumentService.get_documents_by_ids(dataset.id, [doc_a.id, doc_b.id])
+
+ assert {document.id for document in result} == {doc_a.id, doc_b.id}
+
+
+def test_update_documents_need_summary_returns_zero_for_empty_input(db_session_with_containers):
+ dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
+
+ assert DocumentService.update_documents_need_summary(dataset.id, []) == 0
+
+
+def test_update_documents_need_summary_updates_matching_non_qa_documents(db_session_with_containers):
+ dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
+ paragraph_doc = DocumentServiceIntegrationFactory.create_document(
+ db_session_with_containers,
+ dataset=dataset,
+ need_summary=True,
+ )
+ qa_doc = DocumentServiceIntegrationFactory.create_document(
+ db_session_with_containers,
+ dataset=dataset,
+ position=2,
+ need_summary=True,
+ doc_form=IndexStructureType.QA_INDEX,
+ )
+
+ updated_count = DocumentService.update_documents_need_summary(
+ dataset.id,
+ [paragraph_doc.id, qa_doc.id],
+ need_summary=False,
+ )
+
+ db_session_with_containers.expire_all()
+ refreshed_paragraph = db_session_with_containers.get(Document, paragraph_doc.id)
+ refreshed_qa = db_session_with_containers.get(Document, qa_doc.id)
+ assert updated_count == 1
+ assert refreshed_paragraph is not None
+ assert refreshed_qa is not None
+ assert refreshed_paragraph.need_summary is False
+ assert refreshed_qa.need_summary is True
+
+
+def test_get_document_download_url_uses_signed_url_helper(db_session_with_containers):
+ dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
+ upload_file = DocumentServiceIntegrationFactory.create_upload_file(
+ db_session_with_containers,
+ tenant_id=dataset.tenant_id,
+ created_by=dataset.created_by,
+ )
+ document = DocumentServiceIntegrationFactory.create_document(
+ db_session_with_containers,
+ dataset=dataset,
+ data_source_info={"upload_file_id": upload_file.id},
+ )
+
+ with patch("services.dataset_service.file_helpers.get_signed_file_url", return_value="signed-url") as get_url:
+ result = DocumentService.get_document_download_url(document)
+
+ assert result == "signed-url"
+ get_url.assert_called_once_with(upload_file_id=upload_file.id, as_attachment=True)
+
+
+def test_get_upload_file_id_for_upload_file_document_rejects_invalid_source_type(db_session_with_containers):
+ dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
+ document = DocumentServiceIntegrationFactory.create_document(
+ db_session_with_containers,
+ dataset=dataset,
+ data_source_type=DataSourceType.WEBSITE_CRAWL,
+ data_source_info={"url": "https://example.com"},
+ )
+
+ with pytest.raises(NotFound, match="invalid source"):
+ DocumentService._get_upload_file_id_for_upload_file_document(
+ document,
+ invalid_source_message="invalid source",
+ missing_file_message="missing file",
+ )
+
+
+def test_get_upload_file_id_for_upload_file_document_rejects_missing_upload_file_id(db_session_with_containers):
+ dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
+ document = DocumentServiceIntegrationFactory.create_document(
+ db_session_with_containers,
+ dataset=dataset,
+ data_source_info={},
+ )
+
+ with pytest.raises(NotFound, match="missing file"):
+ DocumentService._get_upload_file_id_for_upload_file_document(
+ document,
+ invalid_source_message="invalid source",
+ missing_file_message="missing file",
+ )
+
+
+def test_get_upload_file_id_for_upload_file_document_returns_string_id(db_session_with_containers):
+ dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
+ document = DocumentServiceIntegrationFactory.create_document(
+ db_session_with_containers,
+ dataset=dataset,
+ data_source_info={"upload_file_id": 99},
+ )
+
+ result = DocumentService._get_upload_file_id_for_upload_file_document(
+ document,
+ invalid_source_message="invalid source",
+ missing_file_message="missing file",
+ )
+
+ assert result == "99"
+
+
+def test_get_upload_file_for_upload_file_document_raises_when_file_service_returns_nothing(db_session_with_containers):
+ dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
+ document = DocumentServiceIntegrationFactory.create_document(
+ db_session_with_containers,
+ dataset=dataset,
+ data_source_info={"upload_file_id": "missing-file"},
+ )
+
+ with patch("services.dataset_service.FileService.get_upload_files_by_ids", return_value={}):
+ with pytest.raises(NotFound, match="Uploaded file not found"):
+ DocumentService._get_upload_file_for_upload_file_document(document)
+
+
+def test_get_upload_file_for_upload_file_document_returns_upload_file(db_session_with_containers):
+ dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
+ upload_file = DocumentServiceIntegrationFactory.create_upload_file(
+ db_session_with_containers,
+ tenant_id=dataset.tenant_id,
+ created_by=dataset.created_by,
+ )
+ document = DocumentServiceIntegrationFactory.create_document(
+ db_session_with_containers,
+ dataset=dataset,
+ data_source_info={"upload_file_id": upload_file.id},
+ )
+
+ result = DocumentService._get_upload_file_for_upload_file_document(document)
+
+ assert result.id == upload_file.id
+
+
+def test_get_upload_files_by_document_id_for_zip_download_raises_for_missing_documents(db_session_with_containers):
+ dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
+
+ with pytest.raises(NotFound, match="Document not found"):
+ DocumentService._get_upload_files_by_document_id_for_zip_download(
+ dataset_id=dataset.id,
+ document_ids=[str(uuid4())],
+ tenant_id=dataset.tenant_id,
+ )
+
+
+def test_get_upload_files_by_document_id_for_zip_download_rejects_cross_tenant_access(db_session_with_containers):
+ dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
+ upload_file = DocumentServiceIntegrationFactory.create_upload_file(
+ db_session_with_containers,
+ tenant_id=dataset.tenant_id,
+ created_by=dataset.created_by,
+ )
+ document = DocumentServiceIntegrationFactory.create_document(
+ db_session_with_containers,
+ dataset=dataset,
+ tenant_id=str(uuid4()),
+ data_source_info={"upload_file_id": upload_file.id},
+ )
+
+ with pytest.raises(Forbidden, match="No permission"):
+ DocumentService._get_upload_files_by_document_id_for_zip_download(
+ dataset_id=dataset.id,
+ document_ids=[document.id],
+ tenant_id=dataset.tenant_id,
+ )
+
+
+def test_get_upload_files_by_document_id_for_zip_download_rejects_missing_upload_files(db_session_with_containers):
+ dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
+ document = DocumentServiceIntegrationFactory.create_document(
+ db_session_with_containers,
+ dataset=dataset,
+ data_source_info={"upload_file_id": str(uuid4())},
+ )
+
+ with pytest.raises(NotFound, match="Only uploaded-file documents can be downloaded as ZIP"):
+ DocumentService._get_upload_files_by_document_id_for_zip_download(
+ dataset_id=dataset.id,
+ document_ids=[document.id],
+ tenant_id=dataset.tenant_id,
+ )
+
+
+def test_get_upload_files_by_document_id_for_zip_download_returns_document_keyed_mapping(db_session_with_containers):
+ dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
+ upload_file_a = DocumentServiceIntegrationFactory.create_upload_file(
+ db_session_with_containers,
+ tenant_id=dataset.tenant_id,
+ created_by=dataset.created_by,
+ name="a.txt",
+ )
+ upload_file_b = DocumentServiceIntegrationFactory.create_upload_file(
+ db_session_with_containers,
+ tenant_id=dataset.tenant_id,
+ created_by=dataset.created_by,
+ name="b.txt",
+ )
+ document_a = DocumentServiceIntegrationFactory.create_document(
+ db_session_with_containers,
+ dataset=dataset,
+ data_source_info={"upload_file_id": upload_file_a.id},
+ )
+ document_b = DocumentServiceIntegrationFactory.create_document(
+ db_session_with_containers,
+ dataset=dataset,
+ position=2,
+ data_source_info={"upload_file_id": upload_file_b.id},
+ )
+
+ mapping = DocumentService._get_upload_files_by_document_id_for_zip_download(
+ dataset_id=dataset.id,
+ document_ids=[document_a.id, document_b.id],
+ tenant_id=dataset.tenant_id,
+ )
+
+ assert mapping[document_a.id].id == upload_file_a.id
+ assert mapping[document_b.id].id == upload_file_b.id
+
+
+def test_prepare_document_batch_download_zip_raises_not_found_for_missing_dataset(
+ current_user_mock, flask_app_with_containers
+):
+ with flask_app_with_containers.app_context():
+ with pytest.raises(NotFound, match="Dataset not found"):
+ DocumentService.prepare_document_batch_download_zip(
+ dataset_id=str(uuid4()),
+ document_ids=[str(uuid4())],
+ tenant_id=current_user_mock.current_tenant_id,
+ current_user=current_user_mock,
+ )
+
+
+def test_prepare_document_batch_download_zip_translates_permission_error_to_forbidden(
+ db_session_with_containers,
+ current_user_mock,
+):
+ dataset = DocumentServiceIntegrationFactory.create_dataset(
+ db_session_with_containers,
+ tenant_id=current_user_mock.current_tenant_id,
+ created_by=current_user_mock.id,
+ )
+
+ with patch(
+ "services.dataset_service.DatasetService.check_dataset_permission",
+ side_effect=NoPermissionError("denied"),
+ ):
+ with pytest.raises(Forbidden, match="denied"):
+ DocumentService.prepare_document_batch_download_zip(
+ dataset_id=dataset.id,
+ document_ids=[],
+ tenant_id=current_user_mock.current_tenant_id,
+ current_user=current_user_mock,
+ )
+
+
+def test_prepare_document_batch_download_zip_returns_upload_files_in_requested_order(
+ db_session_with_containers,
+ current_user_mock,
+):
+ dataset = DocumentServiceIntegrationFactory.create_dataset(
+ db_session_with_containers,
+ tenant_id=current_user_mock.current_tenant_id,
+ created_by=current_user_mock.id,
+ )
+ upload_file_a = DocumentServiceIntegrationFactory.create_upload_file(
+ db_session_with_containers,
+ tenant_id=dataset.tenant_id,
+ created_by=dataset.created_by,
+ name="a.txt",
+ )
+ upload_file_b = DocumentServiceIntegrationFactory.create_upload_file(
+ db_session_with_containers,
+ tenant_id=dataset.tenant_id,
+ created_by=dataset.created_by,
+ name="b.txt",
+ )
+ document_a = DocumentServiceIntegrationFactory.create_document(
+ db_session_with_containers,
+ dataset=dataset,
+ data_source_info={"upload_file_id": upload_file_a.id},
+ )
+ document_b = DocumentServiceIntegrationFactory.create_document(
+ db_session_with_containers,
+ dataset=dataset,
+ position=2,
+ data_source_info={"upload_file_id": upload_file_b.id},
+ )
+
+ upload_files, download_name = DocumentService.prepare_document_batch_download_zip(
+ dataset_id=dataset.id,
+ document_ids=[document_b.id, document_a.id],
+ tenant_id=current_user_mock.current_tenant_id,
+ current_user=current_user_mock,
+ )
+
+ assert [upload_file.id for upload_file in upload_files] == [upload_file_b.id, upload_file_a.id]
+ assert download_name.endswith(".zip")
+
+
+def test_get_document_by_dataset_id_returns_enabled_documents(db_session_with_containers):
+ dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
+ enabled_document = DocumentServiceIntegrationFactory.create_document(
+ db_session_with_containers,
+ dataset=dataset,
+ enabled=True,
+ )
+ DocumentServiceIntegrationFactory.create_document(
+ db_session_with_containers,
+ dataset=dataset,
+ position=2,
+ enabled=False,
+ )
+
+ result = DocumentService.get_document_by_dataset_id(dataset.id)
+
+ assert [document.id for document in result] == [enabled_document.id]
+
+
+def test_get_working_documents_by_dataset_id_returns_completed_enabled_unarchived_documents(db_session_with_containers):
+ dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
+ available_document = DocumentServiceIntegrationFactory.create_document(
+ db_session_with_containers,
+ dataset=dataset,
+ indexing_status=IndexingStatus.COMPLETED,
+ enabled=True,
+ archived=False,
+ )
+ DocumentServiceIntegrationFactory.create_document(
+ db_session_with_containers,
+ dataset=dataset,
+ position=2,
+ indexing_status=IndexingStatus.ERROR,
+ )
+
+ result = DocumentService.get_working_documents_by_dataset_id(dataset.id)
+
+ assert [document.id for document in result] == [available_document.id]
+
+
+def test_get_error_documents_by_dataset_id_returns_error_and_paused_documents(db_session_with_containers):
+ dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
+ error_document = DocumentServiceIntegrationFactory.create_document(
+ db_session_with_containers,
+ dataset=dataset,
+ indexing_status=IndexingStatus.ERROR,
+ )
+ paused_document = DocumentServiceIntegrationFactory.create_document(
+ db_session_with_containers,
+ dataset=dataset,
+ position=2,
+ indexing_status=IndexingStatus.PAUSED,
+ )
+ DocumentServiceIntegrationFactory.create_document(
+ db_session_with_containers,
+ dataset=dataset,
+ position=3,
+ indexing_status=IndexingStatus.COMPLETED,
+ )
+
+ result = DocumentService.get_error_documents_by_dataset_id(dataset.id)
+
+ assert {document.id for document in result} == {error_document.id, paused_document.id}
+
+
+def test_get_batch_documents_filters_by_current_user_tenant(db_session_with_containers):
+ dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
+ batch = f"batch-{uuid4()}"
+ matching_document = DocumentServiceIntegrationFactory.create_document(
+ db_session_with_containers,
+ dataset=dataset,
+ batch=batch,
+ )
+ DocumentServiceIntegrationFactory.create_document(
+ db_session_with_containers,
+ dataset=dataset,
+ position=2,
+ tenant_id=str(uuid4()),
+ batch=batch,
+ )
+
+ with patch("services.dataset_service.current_user", create_autospec(Account, instance=True)) as current_user:
+ current_user.current_tenant_id = dataset.tenant_id
+ result = DocumentService.get_batch_documents(dataset.id, batch)
+
+ assert [document.id for document in result] == [matching_document.id]
+
+
+def test_get_document_file_detail_returns_upload_file(db_session_with_containers):
+ dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
+ upload_file = DocumentServiceIntegrationFactory.create_upload_file(
+ db_session_with_containers,
+ tenant_id=dataset.tenant_id,
+ created_by=dataset.created_by,
+ )
+
+ result = DocumentService.get_document_file_detail(upload_file.id)
+
+ assert result is not None
+ assert result.id == upload_file.id
+
+
+def test_delete_document_emits_signal_and_commits(db_session_with_containers):
+ dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
+ upload_file = DocumentServiceIntegrationFactory.create_upload_file(
+ db_session_with_containers,
+ tenant_id=dataset.tenant_id,
+ created_by=dataset.created_by,
+ )
+ document = DocumentServiceIntegrationFactory.create_document(
+ db_session_with_containers,
+ dataset=dataset,
+ data_source_info={"upload_file_id": upload_file.id},
+ )
+
+ with patch("services.dataset_service.document_was_deleted.send") as signal_send:
+ DocumentService.delete_document(document)
+
+ assert db_session_with_containers.get(Document, document.id) is None
+ signal_send.assert_called_once_with(
+ document.id,
+ dataset_id=document.dataset_id,
+ doc_form=document.doc_form,
+ file_id=upload_file.id,
+ )
+
+
+def test_delete_documents_ignores_empty_input(db_session_with_containers):
+ dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
+
+ with patch("services.dataset_service.batch_clean_document_task.delay") as delay:
+ DocumentService.delete_documents(dataset, [])
+
+ delay.assert_not_called()
+
+
+def test_delete_documents_deletes_rows_and_dispatches_cleanup_task(db_session_with_containers):
+ dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
+ dataset.chunk_structure = IndexStructureType.PARAGRAPH_INDEX
+ db_session_with_containers.commit()
+ upload_file_a = DocumentServiceIntegrationFactory.create_upload_file(
+ db_session_with_containers,
+ tenant_id=dataset.tenant_id,
+ created_by=dataset.created_by,
+ name="a.txt",
+ )
+ upload_file_b = DocumentServiceIntegrationFactory.create_upload_file(
+ db_session_with_containers,
+ tenant_id=dataset.tenant_id,
+ created_by=dataset.created_by,
+ name="b.txt",
+ )
+ document_a = DocumentServiceIntegrationFactory.create_document(
+ db_session_with_containers,
+ dataset=dataset,
+ data_source_info={"upload_file_id": upload_file_a.id},
+ )
+ document_b = DocumentServiceIntegrationFactory.create_document(
+ db_session_with_containers,
+ dataset=dataset,
+ position=2,
+ data_source_info={"upload_file_id": upload_file_b.id},
+ )
+
+ with patch("services.dataset_service.batch_clean_document_task.delay") as delay:
+ DocumentService.delete_documents(dataset, [document_a.id, document_b.id])
+
+ assert db_session_with_containers.get(Document, document_a.id) is None
+ assert db_session_with_containers.get(Document, document_b.id) is None
+ delay.assert_called_once()
+ args = delay.call_args.args
+ assert args[0] == [document_a.id, document_b.id]
+ assert args[1] == dataset.id
+ assert set(args[3]) == {upload_file_a.id, upload_file_b.id}
+
+
+def test_get_documents_position_returns_next_position_when_documents_exist(db_session_with_containers):
+ dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
+ DocumentServiceIntegrationFactory.create_document(db_session_with_containers, dataset=dataset, position=3)
+
+ assert DocumentService.get_documents_position(dataset.id) == 4
+
+
+def test_get_documents_position_defaults_to_one_when_dataset_is_empty(db_session_with_containers):
+ dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
+
+ assert DocumentService.get_documents_position(dataset.id) == 1
diff --git a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py
index 18c5320d0a..80f9083e81 100644
--- a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py
+++ b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py
@@ -4,7 +4,7 @@ from unittest.mock import MagicMock
import pytest
-from core.workflow.human_input_compat import (
+from core.workflow.human_input_adapter import (
EmailDeliveryConfig,
EmailDeliveryMethod,
EmailRecipients,
diff --git a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test_service.py b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test_service.py
index 21a54e909e..ed75363f3b 100644
--- a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test_service.py
+++ b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test_service.py
@@ -8,7 +8,7 @@ import pytest
from sqlalchemy.engine import Engine
from configs import dify_config
-from core.workflow.human_input_compat import (
+from core.workflow.human_input_adapter import (
EmailDeliveryConfig,
EmailDeliveryMethod,
EmailRecipients,
diff --git a/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py b/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py
index d3e765055a..af83adaae0 100644
--- a/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py
+++ b/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py
@@ -1,3 +1,5 @@
+import inspect
+import json
from unittest.mock import patch
import pytest
@@ -6,6 +8,8 @@ from pydantic import TypeAdapter, ValidationError
from sqlalchemy.orm import Session
from core.tools.entities.tool_entities import ApiProviderSchemaType
+from core.tools.errors import ApiToolProviderNotFoundError
+from core.tools.tool_label_manager import ToolLabelManager
from models import Account, Tenant
from models.tools import ApiToolProvider
from services.tools.api_tools_manage_service import ApiToolManageService
@@ -590,30 +594,204 @@ class TestApiToolManageService:
with pytest.raises(ValueError, match="you have not added provider"):
ApiToolManageService.delete_api_tool_provider(account.id, tenant.id, "nonexistent")
- def test_update_api_tool_provider_not_found(
+ def test_update_api_tool_provider_success(
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
):
- """Test update raises ValueError when original provider not found."""
fake = Faker()
+
+        # Provide a mock cache so cache.delete() invoked in the update flow is a harmless no-op
+ mock_encrypter = mock_external_service_dependencies["encrypter"]
+ from unittest.mock import MagicMock
+
+ mock_cache = MagicMock()
+ mock_cache.delete.return_value = None
+ mock_encrypter.return_value = (mock_encrypter, mock_cache)
+
+        # Create test account and tenant
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
- with pytest.raises(ValueError, match="does not exists"):
- ApiToolManageService.update_api_tool_provider(
+ # original provider name
+ original_name = "original-provider"
+
+ # Create original provider
+ _ = ApiToolManageService.create_api_tool_provider(
+ user_id=account.id,
+ tenant_id=tenant.id,
+ provider_name=original_name,
+ icon={"type": "emoji", "value": "🔧"},
+ credentials={"auth_type": "none"},
+ schema_type=ApiProviderSchemaType.OPENAPI,
+ schema=self._create_test_openapi_schema(),
+ privacy_policy="",
+ custom_disclaimer="",
+ labels=["old-label"],
+ )
+
+        # New provider name and new labels for the update
+ new_name = "updated-provider"
+ new_labels = ["new-label-1", "new-label-2"]
+
+ # Reset mock history so assertions focus on update path only
+ mock_external_service_dependencies["encrypter"].reset_mock()
+ mock_external_service_dependencies["provider_controller"].from_db.reset_mock()
+ mock_external_service_dependencies["tool_label_manager"].update_tool_labels.reset_mock()
+
+ # Act: Update the provider with new values
+ result = ApiToolManageService.update_api_tool_provider(
+ user_id=account.id,
+ tenant_id=tenant.id,
+ # new provider name - changed 1
+ provider_name=new_name,
+ original_provider=original_name,
+ # new icon - changed 2
+ icon={"type": "emoji", "value": "🚀"},
+ credentials={"auth_type": "none"},
+ _schema_type=ApiProviderSchemaType.OPENAPI,
+ schema=self._create_test_openapi_schema(),
+ # new privacy policy - changed 3
+ privacy_policy="https://new-policy.com",
+ # new custom disclaimer - changed 4
+ custom_disclaimer="New disclaimer",
+            # new labels - changed 5 (not asserted here; label persistence is not this layer's responsibility)
+ labels=new_labels,
+ )
+
+ # Assert: Verify the result
+ assert result == {"result": "success"}
+
+ # Get the updated provider from the database
+ updated_provider: ApiToolProvider | None = (
+ db_session_with_containers.query(ApiToolProvider)
+ .filter(ApiToolProvider.tenant_id == tenant.id, ApiToolProvider.name == new_name)
+ .first()
+ )
+
+ # Verify the provider was updated successfully
+ assert updated_provider is not None
+
+        # Manually refresh so the instance reflects the committed database state
+ db_session_with_containers.refresh(updated_provider)
+ # Verify all the updated fields
+ # - changed 1
+ assert updated_provider.name == new_name
+ # - changed 2
+ icon_data = json.loads(updated_provider.icon)
+ assert icon_data["type"] == "emoji"
+ assert icon_data["value"] == "🚀"
+ # - changed 3
+ assert updated_provider.privacy_policy == "https://new-policy.com"
+ # - changed 4
+ assert updated_provider.custom_disclaimer == "New disclaimer"
+
+ # Verify old provider name no longer exists after rename
+ original_provider: ApiToolProvider | None = (
+ db_session_with_containers.query(ApiToolProvider)
+ .filter(ApiToolProvider.tenant_id == tenant.id, ApiToolProvider.name == original_name)
+ .first()
+ )
+ assert original_provider is None
+
+ # Verify update flow calls critical collaborators
+ mock_external_service_dependencies["provider_controller"].from_db.assert_called_once()
+ mock_external_service_dependencies["encrypter"].assert_called_once()
+ mock_cache.delete.assert_called_once()
+
+        # Verify session propagation into the label-update logic:
+        # the refactoring passes the session down to the label manager to preserve atomicity,
+        # and the assertions below confirm that.
+ sig = inspect.signature(ToolLabelManager.update_tool_labels)
+ args, kwargs = mock_external_service_dependencies["tool_label_manager"].update_tool_labels.call_args
+ bound_args = sig.bind(*args, **kwargs)
+ passed_session = bound_args.arguments.get("session")
+        # Ensure the passed object is a SQLAlchemy Session
+ assert isinstance(passed_session, Session), f"Expected Session object, got {type(passed_session)}"
+ assert passed_session is not None, (
+ "Atomicity Failure: Session cannot be passed to Label Manager in update_api_tool_provider"
+ )
+
+ def test_update_api_tool_provider_not_found(
+ self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
+ ):
+ """
+ Test update raises ValueError when original provider not found.
+
+ This test verifies:
+ - Proper error when trying to update a non-existing original provider
+ - No accidental upsert/new provider creation
+ - No external dependency invocation on early failure path
+ """
+ # Arrange: Create test account and tenant
+ account, tenant = self._create_test_account_and_tenant(
+ db_session_with_containers, mock_external_service_dependencies
+ )
+
+ # Keep an existing provider in DB to ensure unrelated data remains unchanged
+ existing_provider_name = "existing-provider"
+ _ = ApiToolManageService.create_api_tool_provider(
+ user_id=account.id,
+ tenant_id=tenant.id,
+ provider_name=existing_provider_name,
+ icon={"type": "emoji", "value": "🔧"},
+ credentials={"auth_type": "none"},
+ schema_type=ApiProviderSchemaType.OPENAPI,
+ schema=self._create_test_openapi_schema(),
+ privacy_policy="https://existing-policy.com",
+ custom_disclaimer="Existing disclaimer",
+ labels=["existing-label"],
+ )
+
+ # Reset mock history so assertions focus on update failure path only
+ mock_external_service_dependencies["tool_label_manager"].update_tool_labels.reset_mock()
+ mock_external_service_dependencies["encrypter"].reset_mock()
+ mock_external_service_dependencies["provider_controller"].from_db.reset_mock()
+
+ # Act & Assert: Verify update fails with clear error message
+ target_new_name = "new-provider-name"
+ missing_original_name = "missing-original-provider"
+ with pytest.raises(ApiToolProviderNotFoundError) as exc_info:
+ _ = ApiToolManageService.update_api_tool_provider(
user_id=account.id,
tenant_id=tenant.id,
- provider_name="new-name",
- original_provider="nonexistent",
- icon={},
+ provider_name=target_new_name,
+ original_provider=missing_original_name,
+ icon={"type": "emoji", "value": "🚀"},
credentials={"auth_type": "none"},
_schema_type=ApiProviderSchemaType.OPENAPI,
schema=self._create_test_openapi_schema(),
- privacy_policy=None,
- custom_disclaimer="",
- labels=[],
+ privacy_policy="https://new-policy.com",
+ custom_disclaimer="New disclaimer",
+ labels=["new-label"],
)
+ error = exc_info.value
+ assert error.provider_name == missing_original_name
+ assert error.tenant_id == tenant.id
+ assert error.error_code == "api_tool_provider_not_found"
+
+ # Assert: Existing provider should remain unchanged
+ existing_provider: ApiToolProvider | None = (
+ db_session_with_containers.query(ApiToolProvider)
+ .filter(ApiToolProvider.tenant_id == tenant.id, ApiToolProvider.name == existing_provider_name)
+ .first()
+ )
+ assert existing_provider is not None
+ assert existing_provider.name == existing_provider_name
+
+ # Assert: No new provider should be created
+ unexpected_new_provider: ApiToolProvider | None = (
+ db_session_with_containers.query(ApiToolProvider)
+ .filter(ApiToolProvider.tenant_id == tenant.id, ApiToolProvider.name == target_new_name)
+ .first()
+ )
+ assert unexpected_new_provider is None
+
+ # Assert: Early failure should skip all downstream external interactions
+ mock_external_service_dependencies["tool_label_manager"].update_tool_labels.assert_not_called()
+ mock_external_service_dependencies["encrypter"].assert_not_called()
+ mock_external_service_dependencies["provider_controller"].from_db.assert_not_called()
+
def test_update_api_tool_provider_missing_auth_type(
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
):
diff --git a/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py b/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py
index 328bdbf055..95a867dbb5 100644
--- a/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py
+++ b/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py
@@ -10,7 +10,7 @@ from core.app.app_config.entities import WorkflowUIBasedAppConfig
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext
from core.repositories.human_input_repository import FormCreateParams, HumanInputFormRepositoryImpl
-from core.workflow.human_input_compat import (
+from core.workflow.human_input_adapter import (
EmailDeliveryConfig,
EmailDeliveryMethod,
EmailRecipients,
diff --git a/api/tests/unit_tests/controllers/console/app/test_mcp_server_response.py b/api/tests/unit_tests/controllers/console/app/test_mcp_server_response.py
index baac4cd4e0..1af15d8dc6 100644
--- a/api/tests/unit_tests/controllers/console/app/test_mcp_server_response.py
+++ b/api/tests/unit_tests/controllers/console/app/test_mcp_server_response.py
@@ -1,6 +1,25 @@
import datetime
+from types import SimpleNamespace
+from unittest.mock import PropertyMock, patch
-from controllers.console.app.mcp_server import AppMCPServerResponse
+from flask import Flask
+
+from controllers.console import console_ns
+from controllers.console.app.mcp_server import AppMCPServerController, AppMCPServerResponse
+
+
+def unwrap(func):
+ while hasattr(func, "__wrapped__"):
+ func = func.__wrapped__
+ return func
+
+
+class _ValidatedResponse:
+ def __init__(self, payload):
+ self._payload = payload
+
+ def model_dump(self, mode="json"):
+ return self._payload
class TestAppMCPServerResponse:
@@ -40,6 +59,18 @@ class TestAppMCPServerResponse:
resp = AppMCPServerResponse.model_validate(data)
assert resp.parameters == {"already": "parsed"}
+ def test_parameters_json_array_parsed(self):
+ data = {
+ "id": "s1",
+ "name": "test",
+ "server_code": "code",
+ "description": "desc",
+ "status": "active",
+ "parameters": '["a", "b"]',
+ }
+ resp = AppMCPServerResponse.model_validate(data)
+ assert resp.parameters == ["a", "b"]
+
def test_timestamps_normalized(self):
dt = datetime.datetime(2024, 1, 1, 0, 0, 0, tzinfo=datetime.UTC)
data = {
@@ -68,3 +99,40 @@ class TestAppMCPServerResponse:
resp = AppMCPServerResponse.model_validate(data)
assert resp.created_at is None
assert resp.updated_at is None
+
+
+class TestAppMCPServerController:
+ def test_get_returns_empty_dict_when_server_missing(self):
+ api = AppMCPServerController()
+ method = unwrap(api.get)
+
+ with patch("controllers.console.app.mcp_server.db.session.scalar", return_value=None):
+ response = method(api, app_model=SimpleNamespace(id="app-1"))
+
+ assert response == {}
+
+ def test_post_returns_201(self):
+ api = AppMCPServerController()
+ method = unwrap(api.post)
+ payload = {"parameters": {"timeout": 30}}
+ app = Flask(__name__)
+ app.config["TESTING"] = True
+
+ with (
+ app.test_request_context("/", json=payload),
+ patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload),
+ patch("controllers.console.app.mcp_server.current_account_with_tenant", return_value=(None, "tenant-1")),
+ patch("controllers.console.app.mcp_server.db.session.add"),
+ patch("controllers.console.app.mcp_server.db.session.commit"),
+ patch("controllers.console.app.mcp_server.AppMCPServer.generate_server_code", return_value="server-code"),
+ patch(
+ "controllers.console.app.mcp_server.AppMCPServerResponse.model_validate",
+ return_value=_ValidatedResponse({"id": "server-1"}),
+ ),
+ ):
+ response, status_code = method(
+ api, app_model=SimpleNamespace(id="app-1", name="Demo App", description="App description")
+ )
+
+ assert response == {"id": "server-1"}
+ assert status_code == 201
diff --git a/api/tests/unit_tests/controllers/console/app/test_workflow.py b/api/tests/unit_tests/controllers/console/app/test_workflow.py
index 6ff3b19362..e91c0a0597 100644
--- a/api/tests/unit_tests/controllers/console/app/test_workflow.py
+++ b/api/tests/unit_tests/controllers/console/app/test_workflow.py
@@ -31,7 +31,7 @@ def test_parse_file_with_config(monkeypatch: pytest.MonkeyPatch) -> None:
file_list = [
File(
tenant_id="t1",
- type=FileType.IMAGE,
+ file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="http://u",
)
diff --git a/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py b/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py
index b19a1740eb..22b80b748e 100644
--- a/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py
+++ b/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py
@@ -314,8 +314,8 @@ def test_workflow_file_variable_with_signed_url():
# Create a File object with LOCAL_FILE transfer method (which generates signed URLs)
test_file = File(
- id="test_file_id",
- type=FileType.IMAGE,
+ file_id="test_file_id",
+ file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="test_upload_file_id",
filename="test.jpg",
@@ -370,8 +370,8 @@ def test_workflow_file_variable_remote_url():
# Create a File object with REMOTE_URL transfer method
test_file = File(
- id="test_file_id",
- type=FileType.IMAGE,
+ file_id="test_file_id",
+ file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/test.jpg",
filename="test.jpg",
diff --git a/api/tests/unit_tests/controllers/console/datasets/test_datasets.py b/api/tests/unit_tests/controllers/console/datasets/test_datasets.py
index 94d6c17915..9465936f28 100644
--- a/api/tests/unit_tests/controllers/console/datasets/test_datasets.py
+++ b/api/tests/unit_tests/controllers/console/datasets/test_datasets.py
@@ -1772,6 +1772,21 @@ class TestDatasetApiBaseUrlApi:
assert response["api_base_url"] == "http://localhost:5000/v1"
+ def test_get_api_base_url_no_double_v1(self, app):
+ api = DatasetApiBaseUrlApi()
+ method = unwrap(api.get)
+
+ with (
+ app.test_request_context("/"),
+ patch(
+ "controllers.console.datasets.datasets.dify_config.SERVICE_API_URL",
+ "https://example.com/v1",
+ ),
+ ):
+ response = method(api)
+
+ assert response["api_base_url"] == "https://example.com/v1"
+
class TestDatasetRetrievalSettingApi:
def test_get_success(self, app):
diff --git a/api/tests/unit_tests/controllers/console/datasets/test_datasets_document.py b/api/tests/unit_tests/controllers/console/datasets/test_datasets_document.py
index ce2278de4f..d9b02ac453 100644
--- a/api/tests/unit_tests/controllers/console/datasets/test_datasets_document.py
+++ b/api/tests/unit_tests/controllers/console/datasets/test_datasets_document.py
@@ -1,3 +1,4 @@
+from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
@@ -215,17 +216,23 @@ class TestDatasetDocumentListApi:
method = unwrap(api.post)
payload = {"indexing_technique": "economy"}
+ created_dataset = SimpleNamespace(id="ds-1", name="Dataset", indexing_technique="economy")
+ created_document = SimpleNamespace(id="doc-1", name="Document", doc_metadata_details=None)
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
+ patch(
+ "controllers.console.datasets.datasets_document.DatasetService.get_dataset",
+ return_value=created_dataset,
+ ),
patch(
"controllers.console.datasets.datasets_document.DocumentService.document_create_args_validate",
return_value=None,
),
patch(
"controllers.console.datasets.datasets_document.DocumentService.save_document_with_dataset_id",
- return_value=([MagicMock()], "batch-1"),
+ return_value=([created_document], "batch-1"),
),
):
response = method(api, "ds-1")
diff --git a/api/tests/unit_tests/controllers/console/test_workspace_account.py b/api/tests/unit_tests/controllers/console/test_workspace_account.py
index c513be950b..26ff264f18 100644
--- a/api/tests/unit_tests/controllers/console/test_workspace_account.py
+++ b/api/tests/unit_tests/controllers/console/test_workspace_account.py
@@ -68,7 +68,10 @@ class TestChangeEmailSend:
mock_features.return_value = SimpleNamespace(enable_change_email=True)
mock_account = _build_account("current@example.com", "acc1")
mock_current_account.return_value = (mock_account, None)
- mock_get_change_data.return_value = {"email": "current@example.com"}
+ mock_get_change_data.return_value = {
+ "email": "current@example.com",
+ AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED,
+ }
mock_send_email.return_value = "token-abc"
with app.test_request_context(
@@ -85,12 +88,55 @@ class TestChangeEmailSend:
email="new@example.com",
old_email="current@example.com",
language="en-US",
- phase="new_email",
+ phase=AccountService.CHANGE_EMAIL_PHASE_NEW,
)
mock_extract_ip.assert_called_once()
mock_is_ip_limit.assert_called_once_with("127.0.0.1")
mock_csrf.assert_called_once()
+ @patch("controllers.console.wraps.db")
+ @patch("controllers.console.workspace.account.current_account_with_tenant")
+ @patch("controllers.console.workspace.account.AccountService.get_change_email_data")
+ @patch("controllers.console.workspace.account.AccountService.send_change_email_email")
+ @patch("controllers.console.workspace.account.AccountService.is_email_send_ip_limit", return_value=False)
+ @patch("controllers.console.workspace.account.extract_remote_ip", return_value="127.0.0.1")
+ @patch("libs.login.check_csrf_token", return_value=None)
+ @patch("controllers.console.wraps.FeatureService.get_system_features")
+ def test_should_reject_new_email_phase_when_token_phase_is_not_old_verified(
+ self,
+ mock_features,
+ mock_csrf,
+ mock_extract_ip,
+ mock_is_ip_limit,
+ mock_send_email,
+ mock_get_change_data,
+ mock_current_account,
+ mock_db,
+ app,
+ ):
+ """GHSA-4q3w-q5mc-45rq: a phase-1 token must not unlock the new-email send step."""
+ from controllers.console.auth.error import InvalidTokenError
+
+ _mock_wraps_db(mock_db)
+ mock_features.return_value = SimpleNamespace(enable_change_email=True)
+ mock_account = _build_account("current@example.com", "acc1")
+ mock_current_account.return_value = (mock_account, None)
+ mock_get_change_data.return_value = {
+ "email": "current@example.com",
+ AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_OLD,
+ }
+
+ with app.test_request_context(
+ "/account/change-email",
+ method="POST",
+ json={"email": "New@Example.com", "language": "en-US", "phase": "new_email", "token": "token-123"},
+ ):
+ _set_logged_in_user(_build_account("tester@example.com", "tester"))
+ with pytest.raises(InvalidTokenError):
+ ChangeEmailSendEmailApi().post()
+
+ mock_send_email.assert_not_called()
+
class TestChangeEmailValidity:
@patch("controllers.console.wraps.db")
@@ -122,7 +168,12 @@ class TestChangeEmailValidity:
mock_account = _build_account("user@example.com", "acc2")
mock_current_account.return_value = (mock_account, None)
mock_is_rate_limit.return_value = False
- mock_get_data.return_value = {"email": "user@example.com", "code": "1234", "old_email": "old@example.com"}
+ mock_get_data.return_value = {
+ "email": "user@example.com",
+ "code": "1234",
+ "old_email": "old@example.com",
+ AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_OLD,
+ }
mock_generate_token.return_value = (None, "new-token")
with app.test_request_context(
@@ -138,11 +189,169 @@ class TestChangeEmailValidity:
mock_add_rate.assert_not_called()
mock_revoke_token.assert_called_once_with("token-123")
mock_generate_token.assert_called_once_with(
- "user@example.com", code="1234", old_email="old@example.com", additional_data={}
+ "user@example.com",
+ code="1234",
+ old_email="old@example.com",
+ additional_data={
+ AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED,
+ },
)
mock_reset_rate.assert_called_once_with("user@example.com")
mock_csrf.assert_called_once()
+ @patch("controllers.console.wraps.db")
+ @patch("controllers.console.workspace.account.current_account_with_tenant")
+ @patch("controllers.console.workspace.account.AccountService.reset_change_email_error_rate_limit")
+ @patch("controllers.console.workspace.account.AccountService.generate_change_email_token")
+ @patch("controllers.console.workspace.account.AccountService.revoke_change_email_token")
+ @patch("controllers.console.workspace.account.AccountService.add_change_email_error_rate_limit")
+ @patch("controllers.console.workspace.account.AccountService.get_change_email_data")
+ @patch("controllers.console.workspace.account.AccountService.is_change_email_error_rate_limit")
+ @patch("libs.login.check_csrf_token", return_value=None)
+ @patch("controllers.console.wraps.FeatureService.get_system_features")
+ def test_should_upgrade_new_phase_token_to_new_verified(
+ self,
+ mock_features,
+ mock_csrf,
+ mock_is_rate_limit,
+ mock_get_data,
+ mock_add_rate,
+ mock_revoke_token,
+ mock_generate_token,
+ mock_reset_rate,
+ mock_current_account,
+ mock_db,
+ app,
+ ):
+ _mock_wraps_db(mock_db)
+ mock_features.return_value = SimpleNamespace(enable_change_email=True)
+ mock_current_account.return_value = (_build_account("old@example.com", "acc"), None)
+ mock_is_rate_limit.return_value = False
+ mock_get_data.return_value = {
+ "email": "new@example.com",
+ "code": "1234",
+ "old_email": "old@example.com",
+ AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_NEW,
+ }
+ mock_generate_token.return_value = (None, "new-verified-token")
+
+ with app.test_request_context(
+ "/account/change-email/validity",
+ method="POST",
+ json={"email": "new@example.com", "code": "1234", "token": "token-123"},
+ ):
+ _set_logged_in_user(_build_account("tester@example.com", "tester"))
+ response = ChangeEmailCheckApi().post()
+
+ assert response == {"is_valid": True, "email": "new@example.com", "token": "new-verified-token"}
+ mock_generate_token.assert_called_once_with(
+ "new@example.com",
+ code="1234",
+ old_email="old@example.com",
+ additional_data={
+ AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED,
+ },
+ )
+
+ @patch("controllers.console.wraps.db")
+ @patch("controllers.console.workspace.account.current_account_with_tenant")
+ @patch("controllers.console.workspace.account.AccountService.reset_change_email_error_rate_limit")
+ @patch("controllers.console.workspace.account.AccountService.generate_change_email_token")
+ @patch("controllers.console.workspace.account.AccountService.revoke_change_email_token")
+ @patch("controllers.console.workspace.account.AccountService.add_change_email_error_rate_limit")
+ @patch("controllers.console.workspace.account.AccountService.get_change_email_data")
+ @patch("controllers.console.workspace.account.AccountService.is_change_email_error_rate_limit")
+ @patch("libs.login.check_csrf_token", return_value=None)
+ @patch("controllers.console.wraps.FeatureService.get_system_features")
+ def test_should_reject_validity_when_token_phase_is_unknown(
+ self,
+ mock_features,
+ mock_csrf,
+ mock_is_rate_limit,
+ mock_get_data,
+ mock_add_rate,
+ mock_revoke_token,
+ mock_generate_token,
+ mock_reset_rate,
+ mock_current_account,
+ mock_db,
+ app,
+ ):
+ """A token whose phase marker is a string but not a known transition must be rejected."""
+ from controllers.console.auth.error import InvalidTokenError
+
+ _mock_wraps_db(mock_db)
+ mock_features.return_value = SimpleNamespace(enable_change_email=True)
+ mock_current_account.return_value = (_build_account("old@example.com", "acc"), None)
+ mock_is_rate_limit.return_value = False
+ mock_get_data.return_value = {
+ "email": "user@example.com",
+ "code": "1234",
+ "old_email": "old@example.com",
+ AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: "something_else",
+ }
+
+ with app.test_request_context(
+ "/account/change-email/validity",
+ method="POST",
+ json={"email": "user@example.com", "code": "1234", "token": "token-123"},
+ ):
+ _set_logged_in_user(_build_account("tester@example.com", "tester"))
+ with pytest.raises(InvalidTokenError):
+ ChangeEmailCheckApi().post()
+
+ mock_revoke_token.assert_not_called()
+ mock_generate_token.assert_not_called()
+
+ @patch("controllers.console.wraps.db")
+ @patch("controllers.console.workspace.account.current_account_with_tenant")
+ @patch("controllers.console.workspace.account.AccountService.reset_change_email_error_rate_limit")
+ @patch("controllers.console.workspace.account.AccountService.generate_change_email_token")
+ @patch("controllers.console.workspace.account.AccountService.revoke_change_email_token")
+ @patch("controllers.console.workspace.account.AccountService.add_change_email_error_rate_limit")
+ @patch("controllers.console.workspace.account.AccountService.get_change_email_data")
+ @patch("controllers.console.workspace.account.AccountService.is_change_email_error_rate_limit")
+ @patch("libs.login.check_csrf_token", return_value=None)
+ @patch("controllers.console.wraps.FeatureService.get_system_features")
+ def test_should_reject_validity_when_token_has_no_phase(
+ self,
+ mock_features,
+ mock_csrf,
+ mock_is_rate_limit,
+ mock_get_data,
+ mock_add_rate,
+ mock_revoke_token,
+ mock_generate_token,
+ mock_reset_rate,
+ mock_current_account,
+ mock_db,
+ app,
+ ):
+ """A token minted without a phase marker (e.g. a hand-crafted token) must not validate."""
+ from controllers.console.auth.error import InvalidTokenError
+
+ _mock_wraps_db(mock_db)
+ mock_features.return_value = SimpleNamespace(enable_change_email=True)
+ mock_current_account.return_value = (_build_account("old@example.com", "acc"), None)
+ mock_is_rate_limit.return_value = False
+ mock_get_data.return_value = {
+ "email": "user@example.com",
+ "code": "1234",
+ "old_email": "old@example.com",
+ }
+
+ with app.test_request_context(
+ "/account/change-email/validity",
+ method="POST",
+ json={"email": "user@example.com", "code": "1234", "token": "token-123"},
+ ):
+ _set_logged_in_user(_build_account("tester@example.com", "tester"))
+ with pytest.raises(InvalidTokenError):
+ ChangeEmailCheckApi().post()
+
+ mock_revoke_token.assert_not_called()
+ mock_generate_token.assert_not_called()
+
class TestChangeEmailReset:
@patch("controllers.console.wraps.db")
@@ -175,7 +384,11 @@ class TestChangeEmailReset:
mock_current_account.return_value = (current_user, None)
mock_is_freeze.return_value = False
mock_check_unique.return_value = True
- mock_get_data.return_value = {"old_email": "OLD@example.com"}
+ mock_get_data.return_value = {
+ "email": "new@example.com",
+ "old_email": "OLD@example.com",
+ AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED,
+ }
mock_account_after_update = _build_account("new@example.com", "acc3-updated")
mock_update_account.return_value = mock_account_after_update
@@ -194,6 +407,155 @@ class TestChangeEmailReset:
mock_send_notify.assert_called_once_with(email="new@example.com")
mock_csrf.assert_called_once()
+ @patch("controllers.console.wraps.db")
+ @patch("controllers.console.workspace.account.current_account_with_tenant")
+ @patch("controllers.console.workspace.account.AccountService.send_change_email_completed_notify_email")
+ @patch("controllers.console.workspace.account.AccountService.update_account_email")
+ @patch("controllers.console.workspace.account.AccountService.revoke_change_email_token")
+ @patch("controllers.console.workspace.account.AccountService.get_change_email_data")
+ @patch("controllers.console.workspace.account.AccountService.check_email_unique")
+ @patch("controllers.console.workspace.account.AccountService.is_account_in_freeze")
+ @patch("libs.login.check_csrf_token", return_value=None)
+ @patch("controllers.console.wraps.FeatureService.get_system_features")
+ def test_should_reject_reset_when_token_phase_is_not_new_verified(
+ self,
+ mock_features,
+ mock_csrf,
+ mock_is_freeze,
+ mock_check_unique,
+ mock_get_data,
+ mock_revoke_token,
+ mock_update_account,
+ mock_send_notify,
+ mock_current_account,
+ mock_db,
+ app,
+ ):
+ """GHSA-4q3w-q5mc-45rq PoC: phase-1 token must not be usable against /reset."""
+ from controllers.console.auth.error import InvalidTokenError
+
+ _mock_wraps_db(mock_db)
+ mock_features.return_value = SimpleNamespace(enable_change_email=True)
+ current_user = _build_account("old@example.com", "acc3")
+ mock_current_account.return_value = (current_user, None)
+ mock_is_freeze.return_value = False
+ mock_check_unique.return_value = True
+ # Simulate a token straight out of step #1 (phase=old_email) — exactly
+ # the replay used in the advisory PoC.
+ mock_get_data.return_value = {
+ "email": "old@example.com",
+ "old_email": "old@example.com",
+ AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_OLD,
+ }
+
+ with app.test_request_context(
+ "/account/change-email/reset",
+ method="POST",
+ json={"new_email": "attacker@example.com", "token": "token-from-step1"},
+ ):
+ _set_logged_in_user(_build_account("tester@example.com", "tester"))
+ with pytest.raises(InvalidTokenError):
+ ChangeEmailResetApi().post()
+
+ mock_revoke_token.assert_not_called()
+ mock_update_account.assert_not_called()
+ mock_send_notify.assert_not_called()
+
+ @patch("controllers.console.wraps.db")
+ @patch("controllers.console.workspace.account.current_account_with_tenant")
+ @patch("controllers.console.workspace.account.AccountService.send_change_email_completed_notify_email")
+ @patch("controllers.console.workspace.account.AccountService.update_account_email")
+ @patch("controllers.console.workspace.account.AccountService.revoke_change_email_token")
+ @patch("controllers.console.workspace.account.AccountService.get_change_email_data")
+ @patch("controllers.console.workspace.account.AccountService.check_email_unique")
+ @patch("controllers.console.workspace.account.AccountService.is_account_in_freeze")
+ @patch("libs.login.check_csrf_token", return_value=None)
+ @patch("controllers.console.wraps.FeatureService.get_system_features")
+ def test_should_reject_reset_when_token_email_differs_from_payload_new_email(
+ self,
+ mock_features,
+ mock_csrf,
+ mock_is_freeze,
+ mock_check_unique,
+ mock_get_data,
+ mock_revoke_token,
+ mock_update_account,
+ mock_send_notify,
+ mock_current_account,
+ mock_db,
+ app,
+ ):
+ """A verified token for address A must not be replayed to change to address B."""
+ from controllers.console.auth.error import InvalidTokenError
+
+ _mock_wraps_db(mock_db)
+ mock_features.return_value = SimpleNamespace(enable_change_email=True)
+ current_user = _build_account("old@example.com", "acc3")
+ mock_current_account.return_value = (current_user, None)
+ mock_is_freeze.return_value = False
+ mock_check_unique.return_value = True
+ mock_get_data.return_value = {
+ "email": "verified@example.com",
+ "old_email": "old@example.com",
+ AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED,
+ }
+
+ with app.test_request_context(
+ "/account/change-email/reset",
+ method="POST",
+ json={"new_email": "attacker@example.com", "token": "token-verified"},
+ ):
+ _set_logged_in_user(_build_account("tester@example.com", "tester"))
+ with pytest.raises(InvalidTokenError):
+ ChangeEmailResetApi().post()
+
+ mock_revoke_token.assert_not_called()
+ mock_update_account.assert_not_called()
+ mock_send_notify.assert_not_called()
+
+
+class TestAccountServiceSendChangeEmailEmail:
+ """Service-level coverage for the phase-bound changes in `send_change_email_email`."""
+
+ def test_should_raise_value_error_for_invalid_phase(self):
+ with pytest.raises(ValueError, match="phase must be one of"):
+ AccountService.send_change_email_email(
+ email="user@example.com",
+ old_email="user@example.com",
+ phase="old_email_verified",
+ )
+
+ @patch("services.account_service.send_change_mail_task")
+ @patch("services.account_service.AccountService.change_email_rate_limiter")
+ @patch("services.account_service.AccountService.generate_change_email_token")
+ def test_should_stamp_phase_into_generated_token(
+ self,
+ mock_generate_token,
+ mock_rate_limiter,
+ mock_mail_task,
+ ):
+ mock_rate_limiter.is_rate_limited.return_value = False
+ mock_generate_token.return_value = ("123456", "the-token")
+
+ returned = AccountService.send_change_email_email(
+ email="user@example.com",
+ old_email="user@example.com",
+ language="en-US",
+ phase=AccountService.CHANGE_EMAIL_PHASE_NEW,
+ )
+
+ assert returned == "the-token"
+ mock_generate_token.assert_called_once_with(
+ "user@example.com",
+ None,
+ old_email="user@example.com",
+ additional_data={
+ AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_NEW,
+ },
+ )
+ mock_mail_task.delay.assert_called_once()
+ mock_rate_limiter.increment_rate_limit.assert_called_once_with("user@example.com")
+
class TestAccountDeletionFeedback:
@patch("controllers.console.wraps.db")
diff --git a/api/tests/unit_tests/controllers/inner_api/plugin/test_plugin_wraps.py b/api/tests/unit_tests/controllers/inner_api/plugin/test_plugin_wraps.py
index 0895fac3a4..d1b09c3a58 100644
--- a/api/tests/unit_tests/controllers/inner_api/plugin/test_plugin_wraps.py
+++ b/api/tests/unit_tests/controllers/inner_api/plugin/test_plugin_wraps.py
@@ -41,17 +41,22 @@ class TestTenantUserPayload:
class TestGetUser:
"""Test get_user function"""
+ @patch("controllers.inner_api.plugin.wraps.select")
@patch("controllers.inner_api.plugin.wraps.EndUser")
@patch("controllers.inner_api.plugin.wraps.sessionmaker")
@patch("controllers.inner_api.plugin.wraps.db")
- def test_should_return_existing_user_by_id(self, mock_db, mock_sessionmaker, mock_enduser_class, app: Flask):
+ def test_should_return_existing_user_by_id(
+ self, mock_db, mock_sessionmaker, mock_enduser_class, mock_select, app: Flask
+ ):
"""Test returning existing user when found by ID"""
# Arrange
mock_user = MagicMock()
mock_user.id = "user123"
mock_session = MagicMock()
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
- mock_session.get.return_value = mock_user
+ mock_session.scalar.return_value = mock_user
+ mock_query = MagicMock()
+ mock_select.return_value.where.return_value.limit.return_value = mock_query
# Act
with app.app_context():
@@ -59,13 +64,45 @@ class TestGetUser:
# Assert
assert result == mock_user
- mock_session.get.assert_called_once()
+ mock_session.scalar.assert_called_once()
+ @patch("controllers.inner_api.plugin.wraps.select")
+ @patch("controllers.inner_api.plugin.wraps.EndUser")
+ @patch("controllers.inner_api.plugin.wraps.sessionmaker")
+ @patch("controllers.inner_api.plugin.wraps.db")
+ def test_should_not_resolve_non_anonymous_users_across_tenants(
+ self,
+ mock_db,
+ mock_sessionmaker,
+ mock_enduser_class,
+ mock_select,
+ app: Flask,
+ ):
+ """Test that explicit user IDs remain scoped to the current tenant."""
+ # Arrange
+ mock_session = MagicMock()
+ mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
+ mock_session.scalar.return_value = None
+ mock_new_user = MagicMock()
+ mock_new_user.tenant_id = "tenant-current"
+ mock_enduser_class.return_value = mock_new_user
+
+ # Act
+ with app.app_context():
+ result = get_user("tenant-current", "foreign-user-id")
+
+ # Assert
+ assert result == mock_new_user
+ mock_session.get.assert_not_called()
+ mock_session.scalar.assert_called_once()
+ mock_session.add.assert_called_once_with(mock_new_user)
+
+ @patch("controllers.inner_api.plugin.wraps.select")
@patch("controllers.inner_api.plugin.wraps.EndUser")
@patch("controllers.inner_api.plugin.wraps.sessionmaker")
@patch("controllers.inner_api.plugin.wraps.db")
def test_should_return_existing_anonymous_user_by_session_id(
- self, mock_db, mock_sessionmaker, mock_enduser_class, app: Flask
+ self, mock_db, mock_sessionmaker, mock_enduser_class, mock_select, app: Flask
):
"""Test returning existing anonymous user by session_id"""
# Arrange
@@ -73,8 +110,9 @@ class TestGetUser:
mock_user.session_id = "anonymous_session"
mock_session = MagicMock()
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
- # non-anonymous path uses session.get(); anonymous uses session.scalar()
- mock_session.get.return_value = mock_user
+ mock_session.scalar.return_value = mock_user
+ mock_query = MagicMock()
+ mock_select.return_value.where.return_value.limit.return_value = mock_query
# Act
with app.app_context():
@@ -83,17 +121,22 @@ class TestGetUser:
# Assert
assert result == mock_user
+ @patch("controllers.inner_api.plugin.wraps.select")
@patch("controllers.inner_api.plugin.wraps.EndUser")
@patch("controllers.inner_api.plugin.wraps.sessionmaker")
@patch("controllers.inner_api.plugin.wraps.db")
- def test_should_create_new_user_when_not_found(self, mock_db, mock_sessionmaker, mock_enduser_class, app: Flask):
+ def test_should_create_new_user_when_not_found(
+ self, mock_db, mock_sessionmaker, mock_enduser_class, mock_select, app: Flask
+ ):
"""Test creating new user when not found in database"""
# Arrange
mock_session = MagicMock()
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
- mock_session.get.return_value = None
+ mock_session.scalar.return_value = None
mock_new_user = MagicMock()
mock_enduser_class.return_value = mock_new_user
+ mock_query = MagicMock()
+ mock_select.return_value.where.return_value.limit.return_value = mock_query
# Act
with app.app_context():
@@ -134,7 +177,7 @@ class TestGetUser:
# Arrange
mock_session = MagicMock()
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
- mock_session.get.side_effect = Exception("Database error")
+ mock_session.scalar.side_effect = Exception("Database error")
# Act & Assert
with app.app_context():
diff --git a/api/tests/unit_tests/controllers/service_api/app/test_conversation.py b/api/tests/unit_tests/controllers/service_api/app/test_conversation.py
index 14c35a9ed5..4fb8ecf784 100644
--- a/api/tests/unit_tests/controllers/service_api/app/test_conversation.py
+++ b/api/tests/unit_tests/controllers/service_api/app/test_conversation.py
@@ -37,6 +37,8 @@ from controllers.service_api.app.conversation import (
ConversationVariableUpdatePayload,
)
from controllers.service_api.app.error import NotChatAppError
+from fields._value_type_serializer import serialize_value_type
+from graphon.variables import StringSegment
from graphon.variables.types import SegmentType
from models.model import App, AppMode, EndUser
from services.conversation_service import ConversationService
@@ -284,6 +286,32 @@ class TestConversationVariableResponseModels:
assert response.created_at == int(created_at.timestamp())
assert response.updated_at == int(created_at.timestamp())
+ def test_variable_response_normalizes_string_value_type_alias(self):
+ response = ConversationVariableResponse.model_validate(
+ {
+ "id": "550e8400-e29b-41d4-a716-446655440000",
+ "name": "foo",
+ "value_type": SegmentType.INTEGER.value,
+ }
+ )
+
+ assert response.value_type == "number"
+
+ def test_variable_response_normalizes_callable_exposed_type(self):
+ response = ConversationVariableResponse.model_validate(
+ {
+ "id": "550e8400-e29b-41d4-a716-446655440000",
+ "name": "foo",
+ "value_type": SimpleNamespace(exposed_type=lambda: SegmentType.STRING.exposed_type()),
+ }
+ )
+
+ assert response.value_type == "string"
+
+ def test_serialize_value_type_supports_segments_and_mappings(self):
+ assert serialize_value_type(StringSegment(value="hello")) == "string"
+ assert serialize_value_type({"value_type": SegmentType.INTEGER}) == "number"
+
def test_variable_pagination_response(self):
response = ConversationVariableInfiniteScrollPaginationResponse.model_validate(
{
diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py
index 3ab63aed25..dd6cd0e919 100644
--- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py
+++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py
@@ -11,8 +11,8 @@ class TestWorkflowResponseConverterFetchFilesFromVariableValue:
def create_test_file(self, file_id: str = "test_file_1") -> File:
"""Create a test File object"""
return File(
- id=file_id,
- type=FileType.DOCUMENT,
+ file_id=file_id,
+ file_type=FileType.DOCUMENT,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="related_123",
filename=f"{file_id}.txt",
diff --git a/api/tests/unit_tests/core/app/apps/test_pause_resume.py b/api/tests/unit_tests/core/app/apps/test_pause_resume.py
index a04a7b7576..6104b8d6ca 100644
--- a/api/tests/unit_tests/core/app/apps/test_pause_resume.py
+++ b/api/tests/unit_tests/core/app/apps/test_pause_resume.py
@@ -7,11 +7,11 @@ import graphon.nodes.human_input.entities # noqa: F401
from core.app.apps.advanced_chat import app_generator as adv_app_gen_module
from core.app.apps.workflow import app_generator as wf_app_gen_module
from core.app.entities.app_invoke_entities import InvokeFrom
+from core.workflow import node_factory as node_factory_module
from core.workflow.node_factory import DifyNodeFactory
from core.workflow.system_variables import build_system_variables
from graphon.entities import WorkflowStartReason
from graphon.entities.base_node_data import BaseNodeData, RetryConfig
-from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
from graphon.entities.pause_reason import SchedulingPause
from graphon.enums import BuiltinNodeTypes, NodeType, WorkflowNodeExecutionStatus
from graphon.graph import Graph
@@ -55,8 +55,21 @@ class _StubToolNode(Node[_StubToolNodeData]):
def version(cls) -> str:
return "1"
- def init_node_data(self, data):
- self._node_data = _StubToolNodeData.model_validate(data)
+ def __init__(
+ self,
+ node_id: str,
+ config: _StubToolNodeData,
+ *,
+ graph_init_params,
+ graph_runtime_state,
+ **_kwargs: Any,
+ ) -> None:
+ super().__init__(
+ node_id=node_id,
+ config=config,
+ graph_init_params=graph_init_params,
+ graph_runtime_state=graph_runtime_state,
+ )
def _get_error_strategy(self):
return self._node_data.error_strategy
@@ -89,21 +102,14 @@ class _StubToolNode(Node[_StubToolNodeData]):
def _patch_tool_node(mocker):
- original_create_node = DifyNodeFactory.create_node
+ original_resolve_node_class = node_factory_module.resolve_workflow_node_class
- def _patched_create_node(self, node_config: dict[str, object] | NodeConfigDict) -> Node:
- typed_node_config = NodeConfigDictAdapter.validate_python(node_config)
- node_data = typed_node_config["data"]
- if node_data.type == BuiltinNodeTypes.TOOL:
- return _StubToolNode(
- id=str(typed_node_config["id"]),
- config=typed_node_config,
- graph_init_params=self.graph_init_params,
- graph_runtime_state=self.graph_runtime_state,
- )
- return original_create_node(self, typed_node_config)
+ def _patched_resolve_node_class(*, node_type: NodeType, node_version: str) -> type[Node]:
+ if node_type == BuiltinNodeTypes.TOOL:
+ return _StubToolNode
+ return original_resolve_node_class(node_type=node_type, node_version=node_version)
- mocker.patch.object(DifyNodeFactory, "create_node", _patched_create_node)
+ mocker.patch.object(node_factory_module, "resolve_workflow_node_class", side_effect=_patched_resolve_node_class)
def _node_data(node_type: NodeType, data: BaseNodeData) -> dict[str, object]:
diff --git a/api/tests/unit_tests/core/app/workflow/test_file_runtime.py b/api/tests/unit_tests/core/app/workflow/test_file_runtime.py
index cddd03f4b0..701863b927 100644
--- a/api/tests/unit_tests/core/app/workflow/test_file_runtime.py
+++ b/api/tests/unit_tests/core/app/workflow/test_file_runtime.py
@@ -26,8 +26,8 @@ def _build_file(
extension: str | None = None,
) -> File:
return File(
- id="file-id",
- type=FileType.IMAGE,
+ file_id="file-id",
+ file_type=FileType.IMAGE,
transfer_method=transfer_method,
reference=reference,
remote_url=remote_url,
@@ -351,7 +351,7 @@ def test_runtime_helper_wrappers_delegate_to_config_and_io(monkeypatch: pytest.M
assert runtime.multimodal_send_format == "url"
- with patch.object(file_runtime.ssrf_proxy, "get", return_value="response") as mock_get:
+ with patch.object(file_runtime.graphon_ssrf_proxy, "get", return_value="response") as mock_get:
assert runtime.http_get("http://example", follow_redirects=False) == "response"
mock_get.assert_called_once_with("http://example", follow_redirects=False)
diff --git a/api/tests/unit_tests/core/app/workflow/test_node_factory.py b/api/tests/unit_tests/core/app/workflow/test_node_factory.py
index c4bfb23272..30a068f4c5 100644
--- a/api/tests/unit_tests/core/app/workflow/test_node_factory.py
+++ b/api/tests/unit_tests/core/app/workflow/test_node_factory.py
@@ -8,8 +8,8 @@ from graphon.enums import BuiltinNodeTypes
class DummyNode:
- def __init__(self, *, id, config, graph_init_params, graph_runtime_state, **kwargs):
- self.id = id
+ def __init__(self, *, node_id, config, graph_init_params, graph_runtime_state, **kwargs):
+ self.id = node_id
self.config = config
self.graph_init_params = graph_init_params
self.graph_runtime_state = graph_runtime_state
diff --git a/api/tests/unit_tests/core/datasource/test_datasource_manager.py b/api/tests/unit_tests/core/datasource/test_datasource_manager.py
index 81315d2508..deeac49bbc 100644
--- a/api/tests/unit_tests/core/datasource/test_datasource_manager.py
+++ b/api/tests/unit_tests/core/datasource/test_datasource_manager.py
@@ -430,7 +430,7 @@ def test_stream_node_events_builds_file_and_variables_from_messages(mocker):
mocker.patch("core.datasource.datasource_manager.session_factory.create_session", return_value=_Session())
mocker.patch("core.datasource.datasource_manager.get_file_type_by_mime_type", return_value=FileType.IMAGE)
built = File(
- type=FileType.IMAGE,
+ file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.TOOL_FILE,
related_id="tool_file_1",
extension=".png",
@@ -530,7 +530,7 @@ def test_stream_node_events_online_drive_sets_variable_pool_file_and_outputs(moc
mocker.patch.object(DatasourceManager, "stream_online_results", return_value=_gen_messages_text_only("ignored"))
file_in = File(
- type=FileType.DOCUMENT,
+ file_type=FileType.DOCUMENT,
transfer_method=FileTransferMethod.TOOL_FILE,
related_id="tf",
extension=".pdf",
diff --git a/api/tests/unit_tests/core/entities/test_entities_model_entities.py b/api/tests/unit_tests/core/entities/test_entities_model_entities.py
index a0b2820157..aeca2e3afd 100644
--- a/api/tests/unit_tests/core/entities/test_entities_model_entities.py
+++ b/api/tests/unit_tests/core/entities/test_entities_model_entities.py
@@ -46,7 +46,7 @@ def test_simple_model_provider_entity_maps_from_provider_entity() -> None:
# Assert
assert simple_provider.provider == "openai"
- assert simple_provider.label.en_US == "OpenAI"
+ assert simple_provider.label.en_us == "OpenAI"
assert simple_provider.supported_model_types == [ModelType.LLM]
diff --git a/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py b/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py
index fe2c226843..a28143026f 100644
--- a/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py
+++ b/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py
@@ -345,22 +345,26 @@ def test_validate_provider_credentials_handles_hidden_secret_value() -> None:
)
]
)
- session = Mock()
- session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace(encrypted_config="encrypted-old-key")
+ mock_session = Mock()
+ mock_session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace(
+ encrypted_config="encrypted-old-key"
+ )
mock_factory = Mock()
mock_factory.provider_credentials_validate.return_value = {"openai_api_key": "restored-key", "region": "us"}
- with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory):
- with patch("core.entities.provider_configuration.encrypter.decrypt_token", return_value="restored-key"):
- with patch(
- "core.entities.provider_configuration.encrypter.encrypt_token",
- side_effect=lambda tenant_id, value: f"enc::{value}",
- ):
- validated = configuration.validate_provider_credentials(
- credentials={"openai_api_key": HIDDEN_VALUE, "region": "us"},
- credential_id="credential-1",
- session=session,
- )
+ with _patched_session(mock_session):
+ with patch(
+ "core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory
+ ):
+ with patch("core.entities.provider_configuration.encrypter.decrypt_token", return_value="restored-key"):
+ with patch(
+ "core.entities.provider_configuration.encrypter.encrypt_token",
+ side_effect=lambda tenant_id, value: f"enc::{value}",
+ ):
+ validated = configuration.validate_provider_credentials(
+ credentials={"openai_api_key": HIDDEN_VALUE, "region": "us"},
+ credential_id="credential-1",
+ )
assert validated["openai_api_key"] == "enc::restored-key"
assert validated["region"] == "us"
@@ -370,23 +374,15 @@ def test_validate_provider_credentials_handles_hidden_secret_value() -> None:
)
-def test_validate_provider_credentials_opens_session_when_not_passed() -> None:
+def test_validate_provider_credentials_without_credential_id() -> None:
configuration = _build_provider_configuration()
- mock_session = Mock()
mock_factory = Mock()
mock_factory.provider_credentials_validate.return_value = {"region": "us"}
- with patch("core.entities.provider_configuration.Session") as mock_session_cls:
- with patch("core.entities.provider_configuration.db") as mock_db:
- mock_db.engine = Mock()
- mock_session_cls.return_value.__enter__.return_value = mock_session
- with patch(
- "core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory
- ):
- validated = configuration.validate_provider_credentials(credentials={"region": "us"})
+ with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory):
+ validated = configuration.validate_provider_credentials(credentials={"region": "us"})
assert validated == {"region": "us"}
- mock_session_cls.assert_called_once()
def test_switch_preferred_provider_type_returns_early_when_no_change_or_unsupported() -> None:
@@ -717,18 +713,22 @@ def test_check_provider_credential_name_exists_and_model_setting_lookup() -> Non
def test_validate_provider_credentials_handles_invalid_original_json() -> None:
configuration = _build_provider_configuration()
configuration.provider.provider_credential_schema = _build_secret_provider_schema()
- session = Mock()
- session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace(encrypted_config="{invalid-json")
+ mock_session = Mock()
+ mock_session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace(
+ encrypted_config="{invalid-json"
+ )
mock_factory = Mock()
mock_factory.provider_credentials_validate.return_value = {"openai_api_key": "new-key"}
- with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory):
- with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-key"):
- validated = configuration.validate_provider_credentials(
- credentials={"openai_api_key": HIDDEN_VALUE},
- credential_id="cred-1",
- session=session,
- )
+ with _patched_session(mock_session):
+ with patch(
+ "core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory
+ ):
+ with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-key"):
+ validated = configuration.validate_provider_credentials(
+ credentials={"openai_api_key": HIDDEN_VALUE},
+ credential_id="cred-1",
+ )
assert validated == {"openai_api_key": "enc-key"}
@@ -1060,37 +1060,35 @@ def test_get_custom_model_credential_uses_specific_id_or_configuration_fallback(
def test_validate_custom_model_credentials_supports_hidden_reuse_and_sessionless_path() -> None:
configuration = _build_provider_configuration()
configuration.provider.model_credential_schema = _build_secret_model_schema()
- session = Mock()
- session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace(
+ mock_session = Mock()
+ mock_session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace(
encrypted_config='{"openai_api_key":"enc"}'
)
mock_factory = Mock()
mock_factory.model_credentials_validate.return_value = {"openai_api_key": "raw"}
- with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory):
- with patch("core.entities.provider_configuration.encrypter.decrypt_token", return_value="raw"):
- with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"):
- validated = configuration.validate_custom_model_credentials(
- model_type=ModelType.LLM,
- model="gpt-4o",
- credentials={"openai_api_key": HIDDEN_VALUE},
- credential_id="cred-1",
- session=session,
- )
- assert validated == {"openai_api_key": "enc-new"}
-
- session = Mock()
- mock_factory = Mock()
- mock_factory.model_credentials_validate.return_value = {"region": "us"}
- with _patched_session(session):
+ with _patched_session(mock_session):
with patch(
"core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory
):
- validated = configuration.validate_custom_model_credentials(
- model_type=ModelType.LLM,
- model="gpt-4o",
- credentials={"region": "us"},
- )
+ with patch("core.entities.provider_configuration.encrypter.decrypt_token", return_value="raw"):
+ with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"):
+ validated = configuration.validate_custom_model_credentials(
+ model_type=ModelType.LLM,
+ model="gpt-4o",
+ credentials={"openai_api_key": HIDDEN_VALUE},
+ credential_id="cred-1",
+ )
+ assert validated == {"openai_api_key": "enc-new"}
+
+ mock_factory2 = Mock()
+ mock_factory2.model_credentials_validate.return_value = {"region": "us"}
+ with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory2):
+ validated = configuration.validate_custom_model_credentials(
+ model_type=ModelType.LLM,
+ model="gpt-4o",
+ credentials={"region": "us"},
+ )
assert validated == {"region": "us"}
@@ -1570,18 +1568,20 @@ def test_get_specific_provider_credential_logs_when_decrypt_fails() -> None:
def test_validate_provider_credentials_uses_empty_original_when_record_missing() -> None:
configuration = _build_provider_configuration()
configuration.provider.provider_credential_schema = _build_secret_provider_schema()
- session = Mock()
- session.execute.return_value.scalar_one_or_none.return_value = None
+ mock_session = Mock()
+ mock_session.execute.return_value.scalar_one_or_none.return_value = None
mock_factory = Mock()
mock_factory.provider_credentials_validate.return_value = {"openai_api_key": "raw"}
- with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory):
- with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"):
- validated = configuration.validate_provider_credentials(
- credentials={"openai_api_key": HIDDEN_VALUE},
- credential_id="cred-1",
- session=session,
- )
+ with _patched_session(mock_session):
+ with patch(
+ "core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory
+ ):
+ with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"):
+ validated = configuration.validate_provider_credentials(
+ credentials={"openai_api_key": HIDDEN_VALUE},
+ credential_id="cred-1",
+ )
assert validated == {"openai_api_key": "enc-new"}
@@ -1692,20 +1692,24 @@ def test_get_specific_custom_model_credential_logs_when_decrypt_fails() -> None:
def test_validate_custom_model_credentials_handles_invalid_original_json() -> None:
configuration = _build_provider_configuration()
configuration.provider.model_credential_schema = _build_secret_model_schema()
- session = Mock()
- session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace(encrypted_config="{invalid-json")
+ mock_session = Mock()
+ mock_session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace(
+ encrypted_config="{invalid-json"
+ )
mock_factory = Mock()
mock_factory.model_credentials_validate.return_value = {"openai_api_key": "raw"}
- with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory):
- with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"):
- validated = configuration.validate_custom_model_credentials(
- model_type=ModelType.LLM,
- model="gpt-4o",
- credentials={"openai_api_key": HIDDEN_VALUE},
- credential_id="cred-1",
- session=session,
- )
+ with _patched_session(mock_session):
+ with patch(
+ "core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory
+ ):
+ with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"):
+ validated = configuration.validate_custom_model_credentials(
+ model_type=ModelType.LLM,
+ model="gpt-4o",
+ credentials={"openai_api_key": HIDDEN_VALUE},
+ credential_id="cred-1",
+ )
assert validated == {"openai_api_key": "enc-new"}
diff --git a/api/tests/unit_tests/core/file/test_models.py b/api/tests/unit_tests/core/file/test_models.py
index bb6e40e224..8cb0938575 100644
--- a/api/tests/unit_tests/core/file/test_models.py
+++ b/api/tests/unit_tests/core/file/test_models.py
@@ -3,9 +3,9 @@ from graphon.file import File, FileTransferMethod, FileType
def test_file():
file = File(
- id="test-file",
+ file_id="test-file",
tenant_id="test-tenant-id",
- type=FileType.IMAGE,
+ file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.TOOL_FILE,
related_id="test-related-id",
filename="image.png",
@@ -25,27 +25,21 @@ def test_file():
assert file.size == 67
-def test_file_model_validate_accepts_legacy_tenant_id():
- data = {
- "id": "test-file",
- "tenant_id": "test-tenant-id",
- "type": "image",
- "transfer_method": "tool_file",
- "related_id": "test-related-id",
- "filename": "image.png",
- "extension": ".png",
- "mime_type": "image/png",
- "size": 67,
- "storage_key": "test-storage-key",
- "url": "https://example.com/image.png",
- # Extra legacy fields
- "tool_file_id": "tool-file-123",
- "upload_file_id": "upload-file-456",
- "datasource_file_id": "datasource-file-789",
- }
+def test_file_constructor_accepts_legacy_tenant_id():
+ file = File(
+ file_id="test-file",
+ tenant_id="test-tenant-id",
+ file_type=FileType.IMAGE,
+ transfer_method=FileTransferMethod.TOOL_FILE,
+ tool_file_id="tool-file-123",
+ filename="image.png",
+ extension=".png",
+ mime_type="image/png",
+ size=67,
+ storage_key="test-storage-key",
+ url="https://example.com/image.png",
+ )
- file = File.model_validate(data)
-
- assert file.related_id == "test-related-id"
+ assert file.related_id == "tool-file-123"
assert file.storage_key == "test-storage-key"
assert "tenant_id" not in file.model_dump()
diff --git a/api/tests/unit_tests/core/helper/test_ssrf_proxy.py b/api/tests/unit_tests/core/helper/test_ssrf_proxy.py
index 3b5c5e6597..d9fed9ae2a 100644
--- a/api/tests/unit_tests/core/helper/test_ssrf_proxy.py
+++ b/api/tests/unit_tests/core/helper/test_ssrf_proxy.py
@@ -1,11 +1,17 @@
from unittest.mock import MagicMock, patch
+import httpx
import pytest
from core.helper.ssrf_proxy import (
SSRF_DEFAULT_MAX_RETRIES,
+ SSRFProxy,
_get_user_provided_host_header,
+ _to_graphon_http_response,
+ graphon_ssrf_proxy,
make_request,
+ max_retries_exceeded_error,
+ request_error,
)
@@ -174,3 +180,56 @@ class TestFollowRedirectsParameter:
call_kwargs = mock_client.request.call_args.kwargs
assert call_kwargs.get("follow_redirects") is True
+
+
+def test_to_graphon_http_response_preserves_httpx_response_fields() -> None:
+ response = httpx.Response(
+ 201,
+ headers={"X-Test": "1"},
+ content=b"payload",
+ request=httpx.Request("GET", "https://example.com/resource"),
+ )
+
+ wrapped = _to_graphon_http_response(response)
+
+ assert wrapped.status_code == 201
+ assert wrapped.headers == {"x-test": "1", "content-length": "7"}
+ assert wrapped.content == b"payload"
+ assert wrapped.url == "https://example.com/resource"
+ assert wrapped.reason_phrase == "Created"
+ assert wrapped.text == "payload"
+
+
+def test_ssrf_proxy_exposes_expected_error_types() -> None:
+ proxy = SSRFProxy()
+
+ assert proxy.max_retries_exceeded_error is max_retries_exceeded_error
+ assert proxy.request_error is request_error
+ assert graphon_ssrf_proxy.max_retries_exceeded_error is max_retries_exceeded_error
+ assert graphon_ssrf_proxy.request_error is request_error
+
+
+@pytest.mark.parametrize("method_name", ["get", "head", "post", "put", "delete", "patch"])
+def test_graphon_ssrf_proxy_wraps_module_requests(method_name: str) -> None:
+ response = httpx.Response(
+ 200,
+ headers={"X-Test": "1"},
+ content=b"ok",
+ request=httpx.Request("GET", "https://example.com/resource"),
+ )
+
+ with patch(f"core.helper.ssrf_proxy.{method_name}", return_value=response) as mock_method:
+ wrapped = getattr(graphon_ssrf_proxy, method_name)(
+ "https://example.com/resource",
+ max_retries=3,
+ headers={"X-Test": "1"},
+ )
+
+ mock_method.assert_called_once_with(
+ url="https://example.com/resource",
+ max_retries=3,
+ headers={"X-Test": "1"},
+ )
+ assert wrapped.status_code == 200
+ assert wrapped.url == "https://example.com/resource"
+ assert wrapped.content == b"ok"
diff --git a/api/tests/unit_tests/core/mcp/client/test_streamable_http.py b/api/tests/unit_tests/core/mcp/client/test_streamable_http.py
index 81f8da9a62..bbbffa2e69 100644
--- a/api/tests/unit_tests/core/mcp/client/test_streamable_http.py
+++ b/api/tests/unit_tests/core/mcp/client/test_streamable_http.py
@@ -971,6 +971,23 @@ class TestHandlePostRequestNew:
assert isinstance(item, SessionMessage)
assert isinstance(item.message.root, JSONRPCError)
assert item.message.root.id == 77
+ assert item.message.root.error.message == "Session terminated by server"
+
+ def test_404_on_initialization_includes_url_in_error(self):
+ t = _new_transport(url="http://example.com/mcp/server/abc123/mcp")
+ q: queue.Queue = queue.Queue()
+ msg = _make_request_msg("initialize", 1)
+ ctx = self._make_ctx(t, q, message=msg)
+ mock_resp = MagicMock()
+ mock_resp.status_code = 404
+ ctx.client.stream = self._stream_ctx(mock_resp)
+ t._handle_post_request(ctx)
+ item = q.get_nowait()
+ assert isinstance(item, SessionMessage)
+ assert isinstance(item.message.root, JSONRPCError)
+ assert item.message.root.error.code == 32600
+ assert "404 Not Found" in item.message.root.error.message
+ assert "http://example.com/mcp/server/abc123/mcp" in item.message.root.error.message
def test_404_for_notification_no_error_sent(self):
t = _new_transport()
diff --git a/api/tests/unit_tests/core/model_runtime/test_model_provider_factory.py b/api/tests/unit_tests/core/model_runtime/test_model_provider_factory.py
index 249ecb5006..c4fd970562 100644
--- a/api/tests/unit_tests/core/model_runtime/test_model_provider_factory.py
+++ b/api/tests/unit_tests/core/model_runtime/test_model_provider_factory.py
@@ -13,12 +13,12 @@ from graphon.model_runtime.entities.provider_entities import (
ProviderCredentialSchema,
ProviderEntity,
)
-from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
-from graphon.model_runtime.model_providers.__base.moderation_model import ModerationModel
-from graphon.model_runtime.model_providers.__base.rerank_model import RerankModel
-from graphon.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
-from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
-from graphon.model_runtime.model_providers.__base.tts_model import TTSModel
+from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
+from graphon.model_runtime.model_providers.base.moderation_model import ModerationModel
+from graphon.model_runtime.model_providers.base.rerank_model import RerankModel
+from graphon.model_runtime.model_providers.base.speech2text_model import Speech2TextModel
+from graphon.model_runtime.model_providers.base.text_embedding_model import TextEmbeddingModel
+from graphon.model_runtime.model_providers.base.tts_model import TTSModel
from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
diff --git a/api/tests/unit_tests/core/ops/test_config_entity.py b/api/tests/unit_tests/core/ops/test_config_entity.py
index 2cbff54c42..69650c85cc 100644
--- a/api/tests/unit_tests/core/ops/test_config_entity.py
+++ b/api/tests/unit_tests/core/ops/test_config_entity.py
@@ -1,16 +1,11 @@
-import pytest
-from pydantic import ValidationError
+from dify_trace_aliyun.config import AliyunConfig
+from dify_trace_arize_phoenix.config import ArizeConfig, PhoenixConfig
+from dify_trace_langfuse.config import LangfuseConfig
+from dify_trace_langsmith.config import LangSmithConfig
+from dify_trace_opik.config import OpikConfig
+from dify_trace_weave.config import WeaveConfig
-from core.ops.entities.config_entity import (
- AliyunConfig,
- ArizeConfig,
- LangfuseConfig,
- LangSmithConfig,
- OpikConfig,
- PhoenixConfig,
- TracingProviderEnum,
- WeaveConfig,
-)
+from core.ops.entities.config_entity import TracingProviderEnum
class TestTracingProviderEnum:
@@ -27,349 +22,8 @@ class TestTracingProviderEnum:
assert TracingProviderEnum.ALIYUN == "aliyun"
-class TestArizeConfig:
- """Test cases for ArizeConfig"""
-
- def test_valid_config(self):
- """Test valid Arize configuration"""
- config = ArizeConfig(
- api_key="test_key", space_id="test_space", project="test_project", endpoint="https://custom.arize.com"
- )
- assert config.api_key == "test_key"
- assert config.space_id == "test_space"
- assert config.project == "test_project"
- assert config.endpoint == "https://custom.arize.com"
-
- def test_default_values(self):
- """Test default values are set correctly"""
- config = ArizeConfig()
- assert config.api_key is None
- assert config.space_id is None
- assert config.project is None
- assert config.endpoint == "https://otlp.arize.com"
-
- def test_project_validation_empty(self):
- """Test project validation with empty value"""
- config = ArizeConfig(project="")
- assert config.project == "default"
-
- def test_project_validation_none(self):
- """Test project validation with None value"""
- config = ArizeConfig(project=None)
- assert config.project == "default"
-
- def test_endpoint_validation_empty(self):
- """Test endpoint validation with empty value"""
- config = ArizeConfig(endpoint="")
- assert config.endpoint == "https://otlp.arize.com"
-
- def test_endpoint_validation_with_path(self):
- """Test endpoint validation normalizes URL by removing path"""
- config = ArizeConfig(endpoint="https://custom.arize.com/api/v1")
- assert config.endpoint == "https://custom.arize.com"
-
- def test_endpoint_validation_invalid_scheme(self):
- """Test endpoint validation rejects invalid schemes"""
- with pytest.raises(ValidationError, match="URL scheme must be one of"):
- ArizeConfig(endpoint="ftp://invalid.com")
-
- def test_endpoint_validation_no_scheme(self):
- """Test endpoint validation rejects URLs without scheme"""
- with pytest.raises(ValidationError, match="URL scheme must be one of"):
- ArizeConfig(endpoint="invalid.com")
-
-
-class TestPhoenixConfig:
- """Test cases for PhoenixConfig"""
-
- def test_valid_config(self):
- """Test valid Phoenix configuration"""
- config = PhoenixConfig(api_key="test_key", project="test_project", endpoint="https://custom.phoenix.com")
- assert config.api_key == "test_key"
- assert config.project == "test_project"
- assert config.endpoint == "https://custom.phoenix.com"
-
- def test_default_values(self):
- """Test default values are set correctly"""
- config = PhoenixConfig()
- assert config.api_key is None
- assert config.project is None
- assert config.endpoint == "https://app.phoenix.arize.com"
-
- def test_project_validation_empty(self):
- """Test project validation with empty value"""
- config = PhoenixConfig(project="")
- assert config.project == "default"
-
- def test_endpoint_validation_with_path(self):
- """Test endpoint validation with path"""
- config = PhoenixConfig(endpoint="https://app.phoenix.arize.com/s/dify-integration")
- assert config.endpoint == "https://app.phoenix.arize.com/s/dify-integration"
-
- def test_endpoint_validation_without_path(self):
- """Test endpoint validation without path"""
- config = PhoenixConfig(endpoint="https://app.phoenix.arize.com")
- assert config.endpoint == "https://app.phoenix.arize.com"
-
-
-class TestLangfuseConfig:
- """Test cases for LangfuseConfig"""
-
- def test_valid_config(self):
- """Test valid Langfuse configuration"""
- config = LangfuseConfig(public_key="public_key", secret_key="secret_key", host="https://custom.langfuse.com")
- assert config.public_key == "public_key"
- assert config.secret_key == "secret_key"
- assert config.host == "https://custom.langfuse.com"
-
- def test_valid_config_with_path(self):
- host = "https://custom.langfuse.com/api/v1"
- config = LangfuseConfig(public_key="public_key", secret_key="secret_key", host=host)
- assert config.public_key == "public_key"
- assert config.secret_key == "secret_key"
- assert config.host == host
-
- def test_default_values(self):
- """Test default values are set correctly"""
- config = LangfuseConfig(public_key="public", secret_key="secret")
- assert config.host == "https://api.langfuse.com"
-
- def test_missing_required_fields(self):
- """Test that required fields are enforced"""
- with pytest.raises(ValidationError):
- LangfuseConfig()
-
- with pytest.raises(ValidationError):
- LangfuseConfig(public_key="public")
-
- with pytest.raises(ValidationError):
- LangfuseConfig(secret_key="secret")
-
- def test_host_validation_empty(self):
- """Test host validation with empty value"""
- config = LangfuseConfig(public_key="public", secret_key="secret", host="")
- assert config.host == "https://api.langfuse.com"
-
-
-class TestLangSmithConfig:
- """Test cases for LangSmithConfig"""
-
- def test_valid_config(self):
- """Test valid LangSmith configuration"""
- config = LangSmithConfig(api_key="test_key", project="test_project", endpoint="https://custom.smith.com")
- assert config.api_key == "test_key"
- assert config.project == "test_project"
- assert config.endpoint == "https://custom.smith.com"
-
- def test_default_values(self):
- """Test default values are set correctly"""
- config = LangSmithConfig(api_key="key", project="project")
- assert config.endpoint == "https://api.smith.langchain.com"
-
- def test_missing_required_fields(self):
- """Test that required fields are enforced"""
- with pytest.raises(ValidationError):
- LangSmithConfig()
-
- with pytest.raises(ValidationError):
- LangSmithConfig(api_key="key")
-
- with pytest.raises(ValidationError):
- LangSmithConfig(project="project")
-
- def test_endpoint_validation_https_only(self):
- """Test endpoint validation only allows HTTPS"""
- with pytest.raises(ValidationError, match="URL scheme must be one of"):
- LangSmithConfig(api_key="key", project="project", endpoint="http://insecure.com")
-
-
-class TestOpikConfig:
- """Test cases for OpikConfig"""
-
- def test_valid_config(self):
- """Test valid Opik configuration"""
- config = OpikConfig(
- api_key="test_key",
- project="test_project",
- workspace="test_workspace",
- url="https://custom.comet.com/opik/api/",
- )
- assert config.api_key == "test_key"
- assert config.project == "test_project"
- assert config.workspace == "test_workspace"
- assert config.url == "https://custom.comet.com/opik/api/"
-
- def test_default_values(self):
- """Test default values are set correctly"""
- config = OpikConfig()
- assert config.api_key is None
- assert config.project is None
- assert config.workspace is None
- assert config.url == "https://www.comet.com/opik/api/"
-
- def test_project_validation_empty(self):
- """Test project validation with empty value"""
- config = OpikConfig(project="")
- assert config.project == "Default Project"
-
- def test_url_validation_empty(self):
- """Test URL validation with empty value"""
- config = OpikConfig(url="")
- assert config.url == "https://www.comet.com/opik/api/"
-
- def test_url_validation_missing_suffix(self):
- """Test URL validation requires /api/ suffix"""
- with pytest.raises(ValidationError, match="URL should end with /api/"):
- OpikConfig(url="https://custom.comet.com/opik/")
-
- def test_url_validation_invalid_scheme(self):
- """Test URL validation rejects invalid schemes"""
- with pytest.raises(ValidationError, match="URL must start with https:// or http://"):
- OpikConfig(url="ftp://custom.comet.com/opik/api/")
-
-
-class TestWeaveConfig:
- """Test cases for WeaveConfig"""
-
- def test_valid_config(self):
- """Test valid Weave configuration"""
- config = WeaveConfig(
- api_key="test_key",
- entity="test_entity",
- project="test_project",
- endpoint="https://custom.wandb.ai",
- host="https://custom.host.com",
- )
- assert config.api_key == "test_key"
- assert config.entity == "test_entity"
- assert config.project == "test_project"
- assert config.endpoint == "https://custom.wandb.ai"
- assert config.host == "https://custom.host.com"
-
- def test_default_values(self):
- """Test default values are set correctly"""
- config = WeaveConfig(api_key="key", project="project")
- assert config.entity is None
- assert config.endpoint == "https://trace.wandb.ai"
- assert config.host is None
-
- def test_missing_required_fields(self):
- """Test that required fields are enforced"""
- with pytest.raises(ValidationError):
- WeaveConfig()
-
- with pytest.raises(ValidationError):
- WeaveConfig(api_key="key")
-
- with pytest.raises(ValidationError):
- WeaveConfig(project="project")
-
- def test_endpoint_validation_https_only(self):
- """Test endpoint validation only allows HTTPS"""
- with pytest.raises(ValidationError, match="URL scheme must be one of"):
- WeaveConfig(api_key="key", project="project", endpoint="http://insecure.wandb.ai")
-
- def test_host_validation_optional(self):
- """Test host validation is optional but validates when provided"""
- config = WeaveConfig(api_key="key", project="project", host=None)
- assert config.host is None
-
- config = WeaveConfig(api_key="key", project="project", host="")
- assert config.host == ""
-
- config = WeaveConfig(api_key="key", project="project", host="https://valid.host.com")
- assert config.host == "https://valid.host.com"
-
- def test_host_validation_invalid_scheme(self):
- """Test host validation rejects invalid schemes when provided"""
- with pytest.raises(ValidationError, match="URL scheme must be one of"):
- WeaveConfig(api_key="key", project="project", host="ftp://invalid.host.com")
-
-
-class TestAliyunConfig:
- """Test cases for AliyunConfig"""
-
- def test_valid_config(self):
- """Test valid Aliyun configuration"""
- config = AliyunConfig(
- app_name="test_app",
- license_key="test_license_key",
- endpoint="https://custom.tracing-analysis-dc-hz.aliyuncs.com",
- )
- assert config.app_name == "test_app"
- assert config.license_key == "test_license_key"
- assert config.endpoint == "https://custom.tracing-analysis-dc-hz.aliyuncs.com"
-
- def test_default_values(self):
- """Test default values are set correctly"""
- config = AliyunConfig(license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com")
- assert config.app_name == "dify_app"
-
- def test_missing_required_fields(self):
- """Test that required fields are enforced"""
- with pytest.raises(ValidationError):
- AliyunConfig()
-
- with pytest.raises(ValidationError):
- AliyunConfig(license_key="test_license")
-
- with pytest.raises(ValidationError):
- AliyunConfig(endpoint="https://tracing-analysis-dc-hz.aliyuncs.com")
-
- def test_app_name_validation_empty(self):
- """Test app_name validation with empty value"""
- config = AliyunConfig(
- license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com", app_name=""
- )
- assert config.app_name == "dify_app"
-
- def test_endpoint_validation_empty(self):
- """Test endpoint validation with empty value"""
- config = AliyunConfig(license_key="test_license", endpoint="")
- assert config.endpoint == "https://tracing-analysis-dc-hz.aliyuncs.com"
-
- def test_endpoint_validation_with_path(self):
- """Test endpoint validation preserves path for Aliyun endpoints"""
- config = AliyunConfig(
- license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com/api/v1/traces"
- )
- assert config.endpoint == "https://tracing-analysis-dc-hz.aliyuncs.com/api/v1/traces"
-
- def test_endpoint_validation_invalid_scheme(self):
- """Test endpoint validation rejects invalid schemes"""
- with pytest.raises(ValidationError, match="URL must start with https:// or http://"):
- AliyunConfig(license_key="test_license", endpoint="ftp://invalid.tracing-analysis-dc-hz.aliyuncs.com")
-
- def test_endpoint_validation_no_scheme(self):
- """Test endpoint validation rejects URLs without scheme"""
- with pytest.raises(ValidationError, match="URL must start with https:// or http://"):
- AliyunConfig(license_key="test_license", endpoint="invalid.tracing-analysis-dc-hz.aliyuncs.com")
-
- def test_license_key_required(self):
- """Test that license_key is required and cannot be empty"""
- with pytest.raises(ValidationError):
- AliyunConfig(license_key="", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com")
-
- def test_valid_endpoint_format_examples(self):
- """Test valid endpoint format examples from comments"""
- valid_endpoints = [
- # cms2.0 public endpoint
- "https://proj-xtrace-123456-cn-heyuan.cn-heyuan.log.aliyuncs.com/apm/trace/opentelemetry",
- # cms2.0 intranet endpoint
- "https://proj-xtrace-123456-cn-heyuan.cn-heyuan-intranet.log.aliyuncs.com/apm/trace/opentelemetry",
- # xtrace public endpoint
- "http://tracing-cn-heyuan.arms.aliyuncs.com",
- # xtrace intranet endpoint
- "http://tracing-cn-heyuan-internal.arms.aliyuncs.com",
- ]
-
- for endpoint in valid_endpoints:
- config = AliyunConfig(license_key="test_license", endpoint=endpoint)
- assert config.endpoint == endpoint
-
-
class TestConfigIntegration:
- """Integration tests for configuration classes"""
+ """Cross-provider configuration sanity checks"""
def test_all_configs_can_be_instantiated(self):
"""Test that all config classes can be instantiated with valid data"""
@@ -388,7 +42,6 @@ class TestConfigIntegration:
def test_url_normalization_consistency(self):
"""Test that URL normalization works consistently across configs"""
- # Test that paths are removed from endpoints
arize_config = ArizeConfig(endpoint="https://arize.com/api/v1/test")
phoenix_with_path_config = PhoenixConfig(endpoint="https://app.phoenix.arize.com/s/dify-integration")
phoenix_without_path_config = PhoenixConfig(endpoint="https://app.phoenix.arize.com")
diff --git a/api/tests/unit_tests/core/plugin/test_model_runtime_adapter.py b/api/tests/unit_tests/core/plugin/test_model_runtime_adapter.py
index 68aa130518..88bf555594 100644
--- a/api/tests/unit_tests/core/plugin/test_model_runtime_adapter.py
+++ b/api/tests/unit_tests/core/plugin/test_model_runtime_adapter.py
@@ -56,7 +56,7 @@ class TestPluginModelRuntime:
assert len(providers) == 1
assert providers[0].provider == "langgenius/openai/openai"
assert providers[0].provider_name == "openai"
- assert providers[0].label.en_US == "OpenAI"
+ assert providers[0].label.en_us == "OpenAI"
client.fetch_model_providers.assert_called_once_with("tenant")
def test_fetch_model_providers_only_exposes_short_name_for_canonical_provider(self) -> None:
diff --git a/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py b/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py
index d49b6e4b71..00a4207786 100644
--- a/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py
+++ b/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py
@@ -466,7 +466,7 @@ class TestConverter:
def test_convert_parameters_to_plugin_format_with_single_file_and_selector(self):
file_param = File(
tenant_id="tenant-1",
- type=FileType.IMAGE,
+ file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/file.png",
storage_key="",
@@ -499,14 +499,14 @@ class TestConverter:
def test_convert_parameters_to_plugin_format_with_lists_and_passthrough_values(self):
file_one = File(
tenant_id="tenant-1",
- type=FileType.DOCUMENT,
+ file_type=FileType.DOCUMENT,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/a.txt",
storage_key="",
)
file_two = File(
tenant_id="tenant-1",
- type=FileType.DOCUMENT,
+ file_type=FileType.DOCUMENT,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/b.txt",
storage_key="",
diff --git a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py
index 395d392127..e536c0831f 100644
--- a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py
+++ b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py
@@ -134,9 +134,9 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg
files = [
File(
- id="file1",
+ file_id="file1",
tenant_id="tenant1",
- type=FileType.IMAGE,
+ file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/image1.jpg",
storage_key="",
@@ -245,9 +245,9 @@ def test_completion_prompt_jinja2_with_files():
completion_template = CompletionModelPromptTemplate(text="Hi {{name}}", edition_type="jinja2")
file = File(
- id="file1",
+ file_id="file1",
tenant_id="tenant1",
- type=FileType.IMAGE,
+ file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/image.jpg",
storage_key="",
@@ -379,9 +379,9 @@ def test_chat_prompt_memory_with_files_and_query():
memory = MagicMock(spec=TokenBufferMemory)
prompt_template = [ChatModelMessage(text="sys", role=PromptMessageRole.SYSTEM)]
file = File(
- id="file1",
+ file_id="file1",
tenant_id="tenant1",
- type=FileType.IMAGE,
+ file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/image.jpg",
storage_key="",
@@ -413,9 +413,9 @@ def test_chat_prompt_files_without_query_updates_last_user_or_appends_new():
transform = AdvancedPromptTransform()
model_config_mock = MagicMock(spec=ModelConfigEntity)
file = File(
- id="file1",
+ file_id="file1",
tenant_id="tenant1",
- type=FileType.IMAGE,
+ file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/image.jpg",
storage_key="",
@@ -463,9 +463,9 @@ def test_chat_prompt_files_with_query_branch():
transform = AdvancedPromptTransform()
model_config_mock = MagicMock(spec=ModelConfigEntity)
file = File(
- id="file1",
+ file_id="file1",
tenant_id="tenant1",
- type=FileType.IMAGE,
+ file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/image.jpg",
storage_key="",
diff --git a/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py
index 803afa54d7..28966242d8 100644
--- a/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py
+++ b/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py
@@ -12,7 +12,7 @@ from graphon.model_runtime.entities.message_entities import (
ToolPromptMessage,
UserPromptMessage,
)
-from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
from models.model import Conversation
diff --git a/api/tests/unit_tests/core/prompt/test_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_prompt_transform.py
index 9f9ea33695..5308c8e7b3 100644
--- a/api/tests/unit_tests/core/prompt/test_prompt_transform.py
+++ b/api/tests/unit_tests/core/prompt/test_prompt_transform.py
@@ -11,7 +11,7 @@ from graphon.model_runtime.entities.model_entities import ModelPropertyKey
# from graphon.model_runtime.entities.message_entities import UserPromptMessage
# from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey, ParameterRule
# from graphon.model_runtime.entities.provider_entities import ProviderEntity
-# from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+# from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
# from core.prompt.prompt_transform import PromptTransform
diff --git a/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py b/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py
index 64eb89590a..0220fb6d4a 100644
--- a/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py
+++ b/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py
@@ -1,12 +1,14 @@
"""Primarily used for testing merged cell scenarios"""
+import gc
import io
import os
import tempfile
+import warnings
from collections import UserDict
from pathlib import Path
from types import SimpleNamespace
-from unittest.mock import MagicMock
+from unittest.mock import AsyncMock, MagicMock
import pytest
from docx import Document
@@ -354,15 +356,46 @@ def test_init_expands_home_path_and_invalid_local_path(monkeypatch, tmp_path):
WordExtractor("not-a-file", "tenant", "user")
-def test_del_closes_temp_file():
+def test_close_closes_temp_file():
extractor = object.__new__(WordExtractor)
+ extractor._closed = False
extractor.temp_file = MagicMock()
- WordExtractor.__del__(extractor)
+ extractor.close()
extractor.temp_file.close.assert_called_once()
+def test_close_is_idempotent():
+ extractor = object.__new__(WordExtractor)
+ extractor._closed = False
+ extractor.temp_file = MagicMock()
+
+ extractor.close()
+ extractor.close()
+
+ extractor.temp_file.close.assert_called_once()
+
+
+def test_close_handles_async_close_mock():
+ extractor = object.__new__(WordExtractor)
+ extractor._closed = False
+ extractor.temp_file = MagicMock()
+ extractor.temp_file.close = AsyncMock()
+
+ with warnings.catch_warnings(record=True) as caught:
+ warnings.simplefilter("always")
+ extractor.close()
+ gc.collect()
+
+ extractor.temp_file.close.assert_called_once()
+ assert not [
+ warning
+ for warning in caught
+ if issubclass(warning.category, RuntimeWarning) and "AsyncMockMixin._execute_mock_call" in str(warning.message)
+ ]
+
+
def test_extract_images_handles_invalid_external_cases(monkeypatch):
class FakeTargetRef:
def __contains__(self, item):
diff --git a/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py b/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py
index 8be1ac318c..18ae9fafc8 100644
--- a/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py
+++ b/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py
@@ -14,7 +14,7 @@ from core.repositories.human_input_repository import (
HumanInputFormSubmissionRepository,
_WorkspaceMemberInfo,
)
-from core.workflow.human_input_compat import (
+from core.workflow.human_input_adapter import (
EmailDeliveryConfig,
EmailDeliveryMethod,
EmailRecipients,
diff --git a/api/tests/unit_tests/core/repositories/test_human_input_repository.py b/api/tests/unit_tests/core/repositories/test_human_input_repository.py
index 1297a95df1..4248782d93 100644
--- a/api/tests/unit_tests/core/repositories/test_human_input_repository.py
+++ b/api/tests/unit_tests/core/repositories/test_human_input_repository.py
@@ -21,7 +21,7 @@ from core.repositories.human_input_repository import (
_InvalidTimeoutStatusError,
_WorkspaceMemberInfo,
)
-from core.workflow.human_input_compat import (
+from core.workflow.human_input_adapter import (
EmailDeliveryConfig,
EmailDeliveryMethod,
EmailRecipients,
diff --git a/api/tests/unit_tests/core/test_file.py b/api/tests/unit_tests/core/test_file.py
index f17927f16b..eab0176f41 100644
--- a/api/tests/unit_tests/core/test_file.py
+++ b/api/tests/unit_tests/core/test_file.py
@@ -6,9 +6,9 @@ from models.workflow import Workflow
def test_file_to_dict():
file = File(
- id="file1",
+ file_id="file1",
tenant_id="tenant1",
- type=FileType.IMAGE,
+ file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/image1.jpg",
storage_key="storage_key",
diff --git a/api/tests/unit_tests/core/test_provider_manager.py b/api/tests/unit_tests/core/test_provider_manager.py
index f45b43082c..a5a542c94f 100644
--- a/api/tests/unit_tests/core/test_provider_manager.py
+++ b/api/tests/unit_tests/core/test_provider_manager.py
@@ -372,6 +372,78 @@ def test_get_configurations_binds_manager_runtime_to_provider_configuration(
provider_configuration.bind_model_runtime.assert_called_once_with(manager._model_runtime)
+def test_get_configurations_reuses_cached_result_for_same_tenant(mocker: MockerFixture, mock_provider_entity):
+ manager = _build_provider_manager(mocker)
+ provider_configuration = Mock()
+ provider_factory = Mock()
+ provider_factory.get_providers.return_value = [mock_provider_entity]
+ custom_configuration = SimpleNamespace(provider=None, models=[])
+ system_configuration = SimpleNamespace(enabled=False, quota_configurations=[], current_quota_type=None)
+
+ with (
+ patch.object(manager, "_get_all_providers", return_value={"openai": []}) as mock_get_all_providers,
+ patch.object(manager, "_init_trial_provider_records", return_value={"openai": []}),
+ patch.object(manager, "_get_all_provider_models", return_value={"openai": []}),
+ patch.object(manager, "_get_all_preferred_model_providers", return_value={}),
+ patch.object(manager, "_get_all_provider_model_settings", return_value={}),
+ patch.object(manager, "_get_all_provider_load_balancing_configs", return_value={}),
+ patch.object(manager, "_get_all_provider_model_credentials", return_value={}),
+ patch.object(manager, "_to_custom_configuration", return_value=custom_configuration),
+ patch.object(manager, "_to_system_configuration", return_value=system_configuration),
+ patch.object(manager, "_to_model_settings", return_value=[]),
+ patch("core.provider_manager.ModelProviderFactory", return_value=provider_factory) as mock_factory_cls,
+ patch(
+ "core.provider_manager.ProviderConfiguration",
+ return_value=provider_configuration,
+ ) as mock_provider_configuration,
+ ):
+ first = manager.get_configurations("tenant-id")
+ second = manager.get_configurations("tenant-id")
+
+ assert first is second
+ mock_get_all_providers.assert_called_once_with("tenant-id")
+ mock_factory_cls.assert_called_once_with(model_runtime=manager._model_runtime)
+ mock_provider_configuration.assert_called_once()
+ provider_configuration.bind_model_runtime.assert_called_once_with(manager._model_runtime)
+
+
+def test_clear_configurations_cache_rebuilds_requested_tenant(mocker: MockerFixture, mock_provider_entity):
+ manager = _build_provider_manager(mocker)
+ provider_factory = Mock()
+ provider_factory.get_providers.return_value = [mock_provider_entity]
+ custom_configuration = SimpleNamespace(provider=None, models=[])
+ system_configuration = SimpleNamespace(enabled=False, quota_configurations=[], current_quota_type=None)
+ provider_configuration_first = Mock()
+ provider_configuration_second = Mock()
+
+ with (
+ patch.object(manager, "_get_all_providers", return_value={"openai": []}) as mock_get_all_providers,
+ patch.object(manager, "_init_trial_provider_records", return_value={"openai": []}),
+ patch.object(manager, "_get_all_provider_models", return_value={"openai": []}),
+ patch.object(manager, "_get_all_preferred_model_providers", return_value={}),
+ patch.object(manager, "_get_all_provider_model_settings", return_value={}),
+ patch.object(manager, "_get_all_provider_load_balancing_configs", return_value={}),
+ patch.object(manager, "_get_all_provider_model_credentials", return_value={}),
+ patch.object(manager, "_to_custom_configuration", return_value=custom_configuration),
+ patch.object(manager, "_to_system_configuration", return_value=system_configuration),
+ patch.object(manager, "_to_model_settings", return_value=[]),
+ patch("core.provider_manager.ModelProviderFactory", return_value=provider_factory),
+ patch(
+ "core.provider_manager.ProviderConfiguration",
+ side_effect=[provider_configuration_first, provider_configuration_second],
+ ) as mock_provider_configuration,
+ ):
+ first = manager.get_configurations("tenant-id")
+ manager.clear_configurations_cache("tenant-id")
+ second = manager.get_configurations("tenant-id")
+
+ assert first is not second
+ assert mock_get_all_providers.call_count == 2
+ assert mock_provider_configuration.call_count == 2
+ provider_configuration_first.bind_model_runtime.assert_called_once_with(manager._model_runtime)
+ provider_configuration_second.bind_model_runtime.assert_called_once_with(manager._model_runtime)
+
+
def test_get_provider_model_bundle_returns_selected_model_type_instance(mocker: MockerFixture):
manager = _build_provider_manager(mocker)
provider_configuration = Mock()
diff --git a/api/tests/unit_tests/core/variables/test_segment.py b/api/tests/unit_tests/core/variables/test_segment.py
index 72052c8c05..9e07ea1b6d 100644
--- a/api/tests/unit_tests/core/variables/test_segment.py
+++ b/api/tests/unit_tests/core/variables/test_segment.py
@@ -1,8 +1,9 @@
import dataclasses
+from typing import Annotated
import orjson
import pytest
-from pydantic import BaseModel
+from pydantic import BaseModel, Discriminator, Tag
from core.helper import encrypter
from core.workflow.system_variables import build_bootstrap_variables, build_system_variables
@@ -12,17 +13,18 @@ from graphon.runtime import VariablePool
from graphon.variables.segment_group import SegmentGroup
from graphon.variables.segments import (
ArrayAnySegment,
+ ArrayBooleanSegment,
ArrayFileSegment,
ArrayNumberSegment,
ArrayObjectSegment,
ArrayStringSegment,
+ BooleanSegment,
FileSegment,
FloatSegment,
IntegerSegment,
NoneSegment,
ObjectSegment,
Segment,
- SegmentUnion,
StringSegment,
get_segment_discriminator,
)
@@ -47,6 +49,26 @@ from graphon.variables.variables import (
StringVariable,
Variable,
)
+from models.utils.file_input_compat import rebuild_serialized_graph_files_without_lookup
+
+type SegmentUnion = Annotated[
+ (
+ Annotated[NoneSegment, Tag(SegmentType.NONE)]
+ | Annotated[StringSegment, Tag(SegmentType.STRING)]
+ | Annotated[FloatSegment, Tag(SegmentType.FLOAT)]
+ | Annotated[IntegerSegment, Tag(SegmentType.INTEGER)]
+ | Annotated[ObjectSegment, Tag(SegmentType.OBJECT)]
+ | Annotated[FileSegment, Tag(SegmentType.FILE)]
+ | Annotated[BooleanSegment, Tag(SegmentType.BOOLEAN)]
+ | Annotated[ArrayAnySegment, Tag(SegmentType.ARRAY_ANY)]
+ | Annotated[ArrayStringSegment, Tag(SegmentType.ARRAY_STRING)]
+ | Annotated[ArrayNumberSegment, Tag(SegmentType.ARRAY_NUMBER)]
+ | Annotated[ArrayObjectSegment, Tag(SegmentType.ARRAY_OBJECT)]
+ | Annotated[ArrayFileSegment, Tag(SegmentType.ARRAY_FILE)]
+ | Annotated[ArrayBooleanSegment, Tag(SegmentType.ARRAY_BOOLEAN)]
+ ),
+ Discriminator(get_segment_discriminator),
+]
def _build_variable_pool(
@@ -123,7 +145,7 @@ def create_test_file(
) -> File:
"""Factory function to create File objects for testing"""
return File(
- type=file_type,
+ file_type=file_type,
transfer_method=transfer_method,
filename=filename,
extension=extension,
@@ -160,7 +182,7 @@ class TestSegmentDumpAndLoad:
assert restored == model
def test_all_segments_serialization(self):
- """Test serialization/deserialization of all segment types"""
+ """Test file-aware segment serialization through Dify's model boundary."""
# Create one instance of each segment type
test_file = create_test_file()
@@ -181,7 +203,7 @@ class TestSegmentDumpAndLoad:
# Test serialization and deserialization
model = _Segments(segments=all_segments)
json_str = model.model_dump_json()
- loaded = _Segments.model_validate_json(json_str)
+ loaded = _Segments.model_validate(rebuild_serialized_graph_files_without_lookup(orjson.loads(json_str)))
# Verify all segments are preserved
assert len(loaded.segments) == len(all_segments)
@@ -202,7 +224,7 @@ class TestSegmentDumpAndLoad:
assert loaded_segment.value == original.value
def test_all_variables_serialization(self):
- """Test serialization/deserialization of all variable types"""
+ """Test file-aware variable serialization through Dify's model boundary."""
# Create one instance of each variable type
test_file = create_test_file()
@@ -223,7 +245,7 @@ class TestSegmentDumpAndLoad:
# Test serialization and deserialization
model = _Variables(variables=all_variables)
json_str = model.model_dump_json()
- loaded = _Variables.model_validate_json(json_str)
+ loaded = _Variables.model_validate(rebuild_serialized_graph_files_without_lookup(orjson.loads(json_str)))
# Verify all variables are preserved
assert len(loaded.variables) == len(all_variables)
diff --git a/api/tests/unit_tests/core/variables/test_segment_type_validation.py b/api/tests/unit_tests/core/variables/test_segment_type_validation.py
index 94e788edb2..317fe99d37 100644
--- a/api/tests/unit_tests/core/variables/test_segment_type_validation.py
+++ b/api/tests/unit_tests/core/variables/test_segment_type_validation.py
@@ -35,7 +35,7 @@ def create_test_file(
"""Factory function to create File objects for testing."""
return File(
tenant_id="test-tenant",
- type=file_type,
+ file_type=file_type,
transfer_method=transfer_method,
filename=filename,
extension=extension,
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py
index 76b2984a4b..9f3e3b00b9 100644
--- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py
+++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py
@@ -1,12 +1,13 @@
-"""
-Mock node factory for testing workflows with third-party service dependencies.
+"""Mock node factory for third-party-service workflow tests.
-This module provides a MockNodeFactory that automatically detects and mocks nodes
-requiring external services (LLM, Agent, Tool, Knowledge Retrieval, HTTP Request).
+The factory follows the same config adaptation path as production
+`DifyNodeFactory.create_node()`, but swaps selected node classes for mock
+implementations before instantiation.
"""
from typing import TYPE_CHECKING, Any
+from core.workflow.human_input_adapter import adapt_node_config_for_graph
from core.workflow.node_factory import DifyNodeFactory
from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
from graphon.enums import BuiltinNodeTypes, NodeType
@@ -82,20 +83,20 @@ class MockNodeFactory(DifyNodeFactory):
:param node_config: Node configuration dictionary
:return: Node instance (real or mocked)
"""
- typed_node_config = NodeConfigDictAdapter.validate_python(node_config)
+ typed_node_config = NodeConfigDictAdapter.validate_python(adapt_node_config_for_graph(node_config))
+ node_id = typed_node_config["id"]
node_data = typed_node_config["data"]
node_type = node_data.type
# Check if this node type should be mocked
if node_type in self._mock_node_types:
- node_id = typed_node_config["id"]
-
# Create mock node instance
mock_class = self._mock_node_types[node_type]
+ resolved_node_data = self._validate_resolved_node_data(mock_class, node_data)
if node_type == BuiltinNodeTypes.CODE:
mock_instance = mock_class(
- id=node_id,
- config=typed_node_config,
+ node_id=node_id,
+ config=resolved_node_data,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
mock_config=self.mock_config,
@@ -104,8 +105,8 @@ class MockNodeFactory(DifyNodeFactory):
)
elif node_type == BuiltinNodeTypes.HTTP_REQUEST:
mock_instance = mock_class(
- id=node_id,
- config=typed_node_config,
+ node_id=node_id,
+ config=resolved_node_data,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
mock_config=self.mock_config,
@@ -120,8 +121,8 @@ class MockNodeFactory(DifyNodeFactory):
BuiltinNodeTypes.PARAMETER_EXTRACTOR,
}:
mock_instance = mock_class(
- id=node_id,
- config=typed_node_config,
+ node_id=node_id,
+ config=resolved_node_data,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
mock_config=self.mock_config,
@@ -130,8 +131,8 @@ class MockNodeFactory(DifyNodeFactory):
)
else:
mock_instance = mock_class(
- id=node_id,
- config=typed_node_config,
+ node_id=node_id,
+ config=resolved_node_data,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
mock_config=self.mock_config,
@@ -140,7 +141,7 @@ class MockNodeFactory(DifyNodeFactory):
return mock_instance
# For non-mocked node types, use parent implementation
- return super().create_node(typed_node_config)
+ return super().create_node(node_config)
def should_mock_node(self, node_type: NodeType) -> bool:
"""
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py
index 971b9b2bbf..f9819c47ec 100644
--- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py
+++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py
@@ -55,13 +55,14 @@ class MockNodeMixin:
def __init__(
self,
- id: str,
- config: Mapping[str, Any],
+ node_id: str,
+ config: Any,
+ *,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
mock_config: Optional["MockConfig"] = None,
**kwargs: Any,
- ):
+ ) -> None:
if isinstance(self, (LLMNode, QuestionClassifierNode, ParameterExtractorNode)):
kwargs.setdefault("credentials_provider", MagicMock(spec=CredentialsProvider))
kwargs.setdefault("model_factory", MagicMock(spec=ModelFactory))
@@ -96,7 +97,7 @@ class MockNodeMixin:
kwargs.setdefault("message_transformer", MagicMock())
super().__init__(
- id=id,
+ node_id=node_id,
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py
index 55a329eba9..75bc6d05f7 100644
--- a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py
+++ b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py
@@ -139,8 +139,8 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor
start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()}
start_node = StartNode(
- id=start_config["id"],
- config=start_config,
+ node_id=start_config["id"],
+ config=StartNodeData(title="Start", variables=[]),
graph_init_params=graph_init_params,
graph_runtime_state=runtime_state,
)
@@ -154,8 +154,8 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor
human_a_config = {"id": "human_a", "data": human_data.model_dump()}
human_a = HumanInputNode(
- id=human_a_config["id"],
- config=human_a_config,
+ node_id=human_a_config["id"],
+ config=human_data,
graph_init_params=graph_init_params,
graph_runtime_state=runtime_state,
form_repository=repo,
@@ -164,8 +164,8 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor
human_b_config = {"id": "human_b", "data": human_data.model_dump()}
human_b = HumanInputNode(
- id=human_b_config["id"],
- config=human_b_config,
+ node_id=human_b_config["id"],
+ config=human_data,
graph_init_params=graph_init_params,
graph_runtime_state=runtime_state,
form_repository=repo,
@@ -182,8 +182,8 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor
)
end_config = {"id": "end", "data": end_data.model_dump()}
end_node = EndNode(
- id=end_config["id"],
- config=end_config,
+ node_id=end_config["id"],
+ config=end_data,
graph_init_params=graph_init_params,
graph_runtime_state=runtime_state,
)
diff --git a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py
index 9c0ad25b58..76b4cd1ef4 100644
--- a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py
+++ b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py
@@ -9,6 +9,7 @@ from extensions.ext_database import db
from graphon.enums import WorkflowNodeExecutionStatus
from graphon.graph import Graph
from graphon.nodes.answer.answer_node import AnswerNode
+from graphon.nodes.answer.entities import AnswerNodeData
from graphon.runtime import GraphRuntimeState, VariablePool
from tests.workflow_test_utils import build_test_graph_init_params
@@ -66,20 +67,15 @@ def test_execute_answer():
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start")
- node_config = {
- "id": "answer",
- "data": {
- "title": "123",
- "type": "answer",
- "answer": "Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.",
- },
- }
-
node = AnswerNode(
- id=str(uuid.uuid4()),
+ node_id=str(uuid.uuid4()),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
- config=node_config,
+ config=AnswerNodeData(
+ title="123",
+ type="answer",
+ answer="Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.",
+ ),
)
# Mock db.session.close()
diff --git a/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py b/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py
index 9cceadde49..d7ef781732 100644
--- a/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py
@@ -1,5 +1,6 @@
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY
from core.workflow.nodes.datasource.datasource_node import DatasourceNode
+from core.workflow.nodes.datasource.entities import DatasourceNodeData
from graphon.enums import WorkflowNodeExecutionStatus
from graphon.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent
@@ -77,19 +78,16 @@ def test_datasource_node_delegates_to_manager_stream(mocker):
mocker.patch("core.workflow.nodes.datasource.datasource_node.DatasourceManager", new=_Mgr)
node = DatasourceNode(
- id="n",
- config={
- "id": "n",
- "data": {
- "type": "datasource",
- "version": "1",
- "title": "Datasource",
- "provider_type": "plugin",
- "provider_name": "p",
- "plugin_id": "plug",
- "datasource_name": "ds",
- },
- },
+ node_id="n",
+ config=DatasourceNodeData(
+ type="datasource",
+ version="1",
+ title="Datasource",
+ provider_type="plugin",
+ provider_name="p",
+ plugin_id="plug",
+ datasource_name="ds",
+ ),
graph_init_params=gp,
graph_runtime_state=gs,
)
diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py
index a3cadc0681..2e89a2da3c 100644
--- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py
@@ -12,7 +12,7 @@ from core.workflow.system_variables import build_system_variables
from graphon.enums import WorkflowNodeExecutionStatus
from graphon.file.file_manager import file_manager
from graphon.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, HttpRequestNode, HttpRequestNodeConfig
-from graphon.nodes.http_request.entities import HttpRequestNodeTimeout, Response
+from graphon.nodes.http_request.entities import HttpRequestNodeData, HttpRequestNodeTimeout, Response
from graphon.runtime import GraphRuntimeState, VariablePool
from tests.workflow_test_utils import build_test_graph_init_params
@@ -66,8 +66,8 @@ def test_get_default_config_uses_injected_http_request_config():
assert default_config["retry_config"]["max_retries"] == 7
-def test_get_default_config_with_malformed_http_request_config_raises_value_error():
- with pytest.raises(ValueError, match="http_request_config must be an HttpRequestNodeConfig instance"):
+def test_get_default_config_with_malformed_http_request_config_raises_type_error():
+ with pytest.raises(TypeError, match="http_request_config must be an HttpRequestNodeConfig instance"):
HttpRequestNode.get_default_config(filters={HTTP_REQUEST_CONFIG_FILTER_KEY: "invalid"})
@@ -114,8 +114,8 @@ def _build_http_node(
start_at=time.perf_counter(),
)
return HttpRequestNode(
- id="http-node",
- config=node_config,
+ node_id="http-node",
+ config=HttpRequestNodeData.model_validate(node_data),
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
http_request_config=HTTP_REQUEST_CONFIG,
diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py
index 1d6a4da7c4..07430498e5 100644
--- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py
+++ b/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py
@@ -1,4 +1,4 @@
-from core.workflow.human_input_compat import EmailDeliveryConfig, EmailRecipients
+from core.workflow.human_input_adapter import EmailDeliveryConfig, EmailRecipients
from graphon.runtime import VariablePool
diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py
index c0e21d0bf7..0659984c76 100644
--- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py
+++ b/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py
@@ -19,7 +19,7 @@ from core.repositories.human_input_repository import (
HumanInputFormRecipientEntity,
HumanInputFormRepository,
)
-from core.workflow.human_input_compat import (
+from core.workflow.human_input_adapter import (
DeliveryMethodType,
EmailDeliveryConfig,
EmailDeliveryMethod,
@@ -136,6 +136,26 @@ class InMemoryHumanInputFormRepository(HumanInputFormRepository):
entity.status_value = HumanInputFormStatus.SUBMITTED
+def _build_human_input_node(
+ *,
+ node_id: str,
+ node_data: HumanInputNodeData | Mapping[str, Any],
+ graph_init_params: GraphInitParams,
+ graph_runtime_state: GraphRuntimeState,
+ runtime: DifyHumanInputNodeRuntime,
+) -> HumanInputNode:
+ typed_node_data = (
+ node_data if isinstance(node_data, HumanInputNodeData) else HumanInputNodeData.model_validate(node_data)
+ )
+ return HumanInputNode(
+ node_id=node_id,
+ config=typed_node_data,
+ graph_init_params=graph_init_params,
+ graph_runtime_state=graph_runtime_state,
+ runtime=runtime,
+ )
+
+
class TestDeliveryMethod:
"""Test DeliveryMethod entity."""
@@ -239,7 +259,7 @@ class TestUserAction:
data[field_name] = value
with pytest.raises(ValidationError) as exc_info:
- UserAction(**data)
+ UserAction.model_validate(data)
errors = exc_info.value.errors()
assert any(error["loc"] == (field_name,) and error["type"] == "string_too_long" for error in errors)
@@ -465,9 +485,9 @@ class TestHumanInputNodeVariableResolution:
runtime = DifyHumanInputNodeRuntime(graph_init_params.run_context)
runtime._build_form_repository = MagicMock(return_value=mock_repo) # type: ignore[attr-defined]
- node = HumanInputNode(
- id=config["id"],
- config=config,
+ node = _build_human_input_node(
+ node_id=config["id"],
+ node_data=config["data"],
graph_init_params=graph_init_params,
graph_runtime_state=runtime_state,
runtime=runtime,
@@ -530,9 +550,9 @@ class TestHumanInputNodeVariableResolution:
runtime = DifyHumanInputNodeRuntime(graph_init_params.run_context)
runtime._build_form_repository = MagicMock(return_value=mock_repo) # type: ignore[attr-defined]
- node = HumanInputNode(
- id=config["id"],
- config=config,
+ node = _build_human_input_node(
+ node_id=config["id"],
+ node_data=config["data"],
graph_init_params=graph_init_params,
graph_runtime_state=runtime_state,
runtime=runtime,
@@ -595,9 +615,9 @@ class TestHumanInputNodeVariableResolution:
runtime = DifyHumanInputNodeRuntime(graph_init_params.run_context)
runtime._build_form_repository = MagicMock(return_value=mock_repo) # type: ignore[attr-defined]
- node = HumanInputNode(
- id=config["id"],
- config=config,
+ node = _build_human_input_node(
+ node_id=config["id"],
+ node_data=config["data"],
graph_init_params=graph_init_params,
graph_runtime_state=runtime_state,
runtime=runtime,
@@ -671,9 +691,9 @@ class TestHumanInputNodeVariableResolution:
runtime = DifyHumanInputNodeRuntime(graph_init_params.run_context)
runtime._build_form_repository = MagicMock(return_value=mock_repo) # type: ignore[attr-defined]
- node = HumanInputNode(
- id=config["id"],
- config=config,
+ node = _build_human_input_node(
+ node_id=config["id"],
+ node_data=config["data"],
graph_init_params=graph_init_params,
graph_runtime_state=runtime_state,
runtime=runtime,
@@ -770,9 +790,9 @@ class TestHumanInputNodeRenderedContent:
form_repository = InMemoryHumanInputFormRepository()
runtime = DifyHumanInputNodeRuntime(graph_init_params.run_context)
runtime._build_form_repository = MagicMock(return_value=form_repository) # type: ignore[attr-defined]
- node = HumanInputNode(
- id=config["id"],
- config=config,
+ node = _build_human_input_node(
+ node_id=config["id"],
+ node_data=config["data"],
graph_init_params=graph_init_params,
graph_runtime_state=runtime_state,
runtime=runtime,
diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py
index bc98028d5b..4a9438b14f 100644
--- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py
+++ b/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py
@@ -11,6 +11,7 @@ from graphon.graph_events import (
NodeRunHumanInputFormTimeoutEvent,
NodeRunStartedEvent,
)
+from graphon.nodes.human_input.entities import HumanInputNodeData
from graphon.nodes.human_input.enums import HumanInputFormStatus
from graphon.nodes.human_input.human_input_node import HumanInputNode
from graphon.runtime import GraphRuntimeState, VariablePool
@@ -25,6 +26,28 @@ class _FakeFormRepository:
return self._form
+def _create_human_input_node(
+ *,
+ config: dict,
+ graph_init_params: GraphInitParams,
+ graph_runtime_state: GraphRuntimeState,
+ repo: _FakeFormRepository,
+) -> HumanInputNode:
+ node_data = (
+ config["data"]
+ if isinstance(config["data"], HumanInputNodeData)
+ else HumanInputNodeData.model_validate(config["data"])
+ )
+ return HumanInputNode(
+ node_id=config["id"],
+ config=node_data,
+ graph_init_params=graph_init_params,
+ graph_runtime_state=graph_runtime_state,
+ form_repository=repo,
+ runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context),
+ )
+
+
def _build_node(form_content: str = "Please enter your name:\n\n{{#$output.name#}}") -> HumanInputNode:
system_variables = default_system_variables()
graph_runtime_state = GraphRuntimeState(
@@ -80,13 +103,11 @@ def _build_node(form_content: str = "Please enter your name:\n\n{{#$output.name#
)
repo = _FakeFormRepository(fake_form)
- return HumanInputNode(
- id="node-1",
+ return _create_human_input_node(
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
- form_repository=repo,
- runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context),
+ repo=repo,
)
@@ -145,13 +166,11 @@ def _build_timeout_node() -> HumanInputNode:
)
repo = _FakeFormRepository(fake_form)
- return HumanInputNode(
- id="node-1",
+ return _create_human_input_node(
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
- form_repository=repo,
- runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context),
+ repo=repo,
)
diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py
index 82cc734274..8ffce39cd6 100644
--- a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py
+++ b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py
@@ -5,6 +5,7 @@ import pytest
from core.workflow.system_variables import default_system_variables
from graphon.entities import GraphInitParams
+from graphon.nodes.iteration.entities import IterationNodeData
from graphon.nodes.iteration.exc import IterationGraphNotFoundError
from graphon.nodes.iteration.iteration_node import IterationNode
from graphon.runtime import (
@@ -44,17 +45,14 @@ def _build_iteration_node(
) -> IterationNode:
init_params = build_test_graph_init_params(graph_config=graph_config)
return IterationNode(
- id="iteration-node",
- config={
- "id": "iteration-node",
- "data": {
- "type": "iteration",
- "title": "Iteration",
- "iterator_selector": ["start", "items"],
- "output_selector": ["iteration-node", "output"],
- "start_node_id": start_node_id,
- },
- },
+ node_id="iteration-node",
+ config=IterationNodeData(
+ type="iteration",
+ title="Iteration",
+ iterator_selector=["start", "items"],
+ output_selector=["iteration-node", "output"],
+ start_node_id=start_node_id,
+ ),
graph_init_params=init_params,
graph_runtime_state=runtime_state,
)
diff --git a/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py b/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py
index a6fca1bfb4..f254fc3d09 100644
--- a/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py
@@ -93,6 +93,25 @@ def sample_chunks():
}
+def _build_node(
+ *,
+ node_id: str,
+ node_data: KnowledgeIndexNodeData | dict[str, object],
+ graph_init_params,
+ graph_runtime_state,
+) -> KnowledgeIndexNode:
+ return KnowledgeIndexNode(
+ node_id=node_id,
+ config=(
+ node_data
+ if isinstance(node_data, KnowledgeIndexNodeData)
+ else KnowledgeIndexNodeData.model_validate(node_data)
+ ),
+ graph_init_params=graph_init_params,
+ graph_runtime_state=graph_runtime_state,
+ )
+
+
class TestKnowledgeIndexNode:
"""
Test suite for KnowledgeIndexNode.
@@ -115,9 +134,9 @@ class TestKnowledgeIndexNode:
}
# Act
- node = KnowledgeIndexNode(
- id=node_id,
- config=config,
+ node = _build_node(
+ node_id=node_id,
+ node_data=config["data"],
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -143,9 +162,9 @@ class TestKnowledgeIndexNode:
"data": sample_node_data.model_dump(),
}
- node = KnowledgeIndexNode(
- id=node_id,
- config=config,
+ node = _build_node(
+ node_id=node_id,
+ node_data=config["data"],
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -176,9 +195,9 @@ class TestKnowledgeIndexNode:
"data": sample_node_data.model_dump(),
}
- node = KnowledgeIndexNode(
- id=node_id,
- config=config,
+ node = _build_node(
+ node_id=node_id,
+ node_data=config["data"],
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -212,9 +231,9 @@ class TestKnowledgeIndexNode:
"data": sample_node_data.model_dump(),
}
- node = KnowledgeIndexNode(
- id=node_id,
- config=config,
+ node = _build_node(
+ node_id=node_id,
+ node_data=config["data"],
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -269,9 +288,9 @@ class TestKnowledgeIndexNode:
"data": sample_node_data.model_dump(),
}
- node = KnowledgeIndexNode(
- id=node_id,
- config=config,
+ node = _build_node(
+ node_id=node_id,
+ node_data=config["data"],
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -332,9 +351,9 @@ class TestKnowledgeIndexNode:
"data": sample_node_data.model_dump(),
}
- node = KnowledgeIndexNode(
- id=node_id,
- config=config,
+ node = _build_node(
+ node_id=node_id,
+ node_data=config["data"],
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -383,9 +402,9 @@ class TestKnowledgeIndexNode:
"data": sample_node_data.model_dump(),
}
- node = KnowledgeIndexNode(
- id=node_id,
- config=config,
+ node = _build_node(
+ node_id=node_id,
+ node_data=config["data"],
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -440,9 +459,9 @@ class TestKnowledgeIndexNode:
"data": sample_node_data.model_dump(),
}
- node = KnowledgeIndexNode(
- id=node_id,
- config=config,
+ node = _build_node(
+ node_id=node_id,
+ node_data=config["data"],
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -498,9 +517,9 @@ class TestKnowledgeIndexNode:
"data": sample_node_data.model_dump(),
}
- node = KnowledgeIndexNode(
- id=node_id,
- config=config,
+ node = _build_node(
+ node_id=node_id,
+ node_data=config["data"],
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -536,9 +555,9 @@ class TestKnowledgeIndexNode:
"data": sample_node_data.model_dump(),
}
- node = KnowledgeIndexNode(
- id=node_id,
- config=config,
+ node = _build_node(
+ node_id=node_id,
+ node_data=config["data"],
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -583,9 +602,9 @@ class TestKnowledgeIndexNode:
"data": sample_node_data.model_dump(),
}
- node = KnowledgeIndexNode(
- id=node_id,
- config=config,
+ node = _build_node(
+ node_id=node_id,
+ node_data=config["data"],
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -623,9 +642,9 @@ class TestInvokeKnowledgeIndex:
"data": sample_node_data.model_dump(),
}
- node = KnowledgeIndexNode(
- id=node_id,
- config=config,
+ node = _build_node(
+ node_id=node_id,
+ node_data=config["data"],
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
diff --git a/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py b/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py
index 45e8ae7d20..e923ee761b 100644
--- a/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py
@@ -14,7 +14,11 @@ from core.workflow.nodes.knowledge_retrieval.entities import (
SingleRetrievalConfig,
)
from core.workflow.nodes.knowledge_retrieval.exc import RateLimitExceededError
-from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
+from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import (
+ KnowledgeRetrievalNode,
+ _normalize_metadata_filter_scalar,
+ _normalize_metadata_filter_sequence_item,
+)
from core.workflow.nodes.knowledge_retrieval.retrieval import RAGRetrievalProtocol, Source
from core.workflow.system_variables import build_system_variables
from graphon.enums import WorkflowNodeExecutionStatus
@@ -85,6 +89,12 @@ def sample_node_data():
)
+def test_metadata_filter_normalizers_preserve_numeric_scalars_and_stringify_other_values() -> None:
+ assert _normalize_metadata_filter_scalar(3) == 3
+ assert _normalize_metadata_filter_scalar(True) == "True"
+ assert _normalize_metadata_filter_sequence_item(4) == "4"
+
+
class TestKnowledgeRetrievalNode:
"""
Test suite for KnowledgeRetrievalNode.
@@ -106,8 +116,8 @@ class TestKnowledgeRetrievalNode:
# Act
node = KnowledgeRetrievalNode(
- id=node_id,
- config=config,
+ node_id=node_id,
+ config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -135,8 +145,8 @@ class TestKnowledgeRetrievalNode:
}
node = KnowledgeRetrievalNode(
- id=node_id,
- config=config,
+ node_id=node_id,
+ config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -194,8 +204,8 @@ class TestKnowledgeRetrievalNode:
mock_rag_retrieval.llm_usage = LLMUsage.empty_usage()
node = KnowledgeRetrievalNode(
- id=node_id,
- config=config,
+ node_id=node_id,
+ config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -238,8 +248,8 @@ class TestKnowledgeRetrievalNode:
mock_rag_retrieval.llm_usage = LLMUsage.empty_usage()
node = KnowledgeRetrievalNode(
- id=node_id,
- config=config,
+ node_id=node_id,
+ config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -274,8 +284,8 @@ class TestKnowledgeRetrievalNode:
}
node = KnowledgeRetrievalNode(
- id=node_id,
- config=config,
+ node_id=node_id,
+ config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -309,8 +319,8 @@ class TestKnowledgeRetrievalNode:
}
node = KnowledgeRetrievalNode(
- id=node_id,
- config=config,
+ node_id=node_id,
+ config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -350,8 +360,8 @@ class TestKnowledgeRetrievalNode:
mock_rag_retrieval.llm_usage = LLMUsage.empty_usage()
node = KnowledgeRetrievalNode(
- id=node_id,
- config=config,
+ node_id=node_id,
+ config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -389,8 +399,8 @@ class TestKnowledgeRetrievalNode:
mock_rag_retrieval.llm_usage = LLMUsage.empty_usage()
node = KnowledgeRetrievalNode(
- id=node_id,
- config=config,
+ node_id=node_id,
+ config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -470,8 +480,8 @@ class TestFetchDatasetRetriever:
config = {"id": node_id, "data": node_data.model_dump()}
node = KnowledgeRetrievalNode(
- id=node_id,
- config=config,
+ node_id=node_id,
+ config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -507,8 +517,8 @@ class TestFetchDatasetRetriever:
}
node = KnowledgeRetrievalNode(
- id=node_id,
- config=config,
+ node_id=node_id,
+ config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -562,8 +572,8 @@ class TestFetchDatasetRetriever:
}
node = KnowledgeRetrievalNode(
- id=node_id,
- config=config,
+ node_id=node_id,
+ config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -610,8 +620,8 @@ class TestFetchDatasetRetriever:
mock_graph_runtime_state.variable_pool.add(["start", "query"], StringSegment(value="readme"))
node = KnowledgeRetrievalNode(
- id=node_id,
- config=config,
+ node_id=node_id,
+ config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -671,8 +681,8 @@ class TestFetchDatasetRetriever:
node_id = str(uuid.uuid4())
config = {"id": node_id, "data": node_data.model_dump()}
node = KnowledgeRetrievalNode(
- id=node_id,
- config=config,
+ node_id=node_id,
+ config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
diff --git a/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py b/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py
index eca34f05be..388654f279 100644
--- a/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py
+++ b/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py
@@ -1,3 +1,4 @@
+from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
@@ -5,6 +6,7 @@ import pytest
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY
from graphon.entities import GraphInitParams
from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus
+from graphon.nodes.list_operator.entities import ListOperatorNodeData
from graphon.nodes.list_operator.node import ListOperatorNode
from graphon.runtime import GraphRuntimeState
from graphon.variables import ArrayNumberSegment, ArrayStringSegment
@@ -13,11 +15,28 @@ from graphon.variables import ArrayNumberSegment, ArrayStringSegment
class TestListOperatorNode:
"""Comprehensive tests for ListOperatorNode."""
+ @staticmethod
+ def _build_node(*, config, graph_init_params, graph_runtime_state):
+ return ListOperatorNode(
+ node_id="test",
+ config=config if isinstance(config, ListOperatorNodeData) else ListOperatorNodeData.model_validate(config),
+ graph_init_params=graph_init_params,
+ graph_runtime_state=graph_runtime_state,
+ )
+
+ @staticmethod
+ def _filter_by(comparison_operator: str, value: str) -> dict[str, object]:
+ return {
+ "enabled": True,
+ "conditions": [{"comparison_operator": comparison_operator, "value": value}],
+ }
+
@pytest.fixture
def mock_graph_runtime_state(self):
"""Create mock GraphRuntimeState."""
mock_state = MagicMock(spec=GraphRuntimeState)
mock_variable_pool = MagicMock()
+ mock_variable_pool.convert_template.side_effect = lambda value: SimpleNamespace(text=value)
mock_state.variable_pool = mock_variable_pool
return mock_state
@@ -45,9 +64,8 @@ class TestListOperatorNode:
def _create_node(config, mock_variable):
mock_graph_runtime_state.variable_pool.get.return_value = mock_variable
- return ListOperatorNode(
- id="test",
- config={"id": "test", "data": config},
+ return self._build_node(
+ config=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -64,9 +82,8 @@ class TestListOperatorNode:
"limit": {"enabled": False},
}
- node = ListOperatorNode(
- id="test",
- config={"id": "test", "data": config},
+ node = self._build_node(
+ config=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -109,9 +126,8 @@ class TestListOperatorNode:
mock_var = ArrayStringSegment(value=[])
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
- node = ListOperatorNode(
- id="test",
- config={"id": "test", "data": config},
+ node = self._build_node(
+ config=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -128,11 +144,7 @@ class TestListOperatorNode:
config = {
"title": "Test",
"variable": ["sys", "items"],
- "filter_by": {
- "enabled": True,
- "condition": "contains",
- "value": "app",
- },
+ "filter_by": self._filter_by("contains", "app"),
"order_by": {"enabled": False},
"limit": {"enabled": False},
}
@@ -140,9 +152,8 @@ class TestListOperatorNode:
mock_var = ArrayStringSegment(value=["apple", "banana", "pineapple", "cherry"])
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
- node = ListOperatorNode(
- id="test",
- config={"id": "test", "data": config},
+ node = self._build_node(
+ config=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -157,11 +168,7 @@ class TestListOperatorNode:
config = {
"title": "Test",
"variable": ["sys", "items"],
- "filter_by": {
- "enabled": True,
- "condition": "not contains",
- "value": "app",
- },
+ "filter_by": self._filter_by("not contains", "app"),
"order_by": {"enabled": False},
"limit": {"enabled": False},
}
@@ -169,9 +176,8 @@ class TestListOperatorNode:
mock_var = ArrayStringSegment(value=["apple", "banana", "pineapple", "cherry"])
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
- node = ListOperatorNode(
- id="test",
- config={"id": "test", "data": config},
+ node = self._build_node(
+ config=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -186,11 +192,7 @@ class TestListOperatorNode:
config = {
"title": "Test",
"variable": ["sys", "numbers"],
- "filter_by": {
- "enabled": True,
- "condition": ">",
- "value": "5",
- },
+ "filter_by": self._filter_by(">", "5"),
"order_by": {"enabled": False},
"limit": {"enabled": False},
}
@@ -198,9 +200,8 @@ class TestListOperatorNode:
mock_var = ArrayNumberSegment(value=[1, 3, 5, 7, 9, 11])
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
- node = ListOperatorNode(
- id="test",
- config={"id": "test", "data": config},
+ node = self._build_node(
+ config=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -226,9 +227,8 @@ class TestListOperatorNode:
mock_var = ArrayStringSegment(value=["cherry", "apple", "banana"])
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
- node = ListOperatorNode(
- id="test",
- config={"id": "test", "data": config},
+ node = self._build_node(
+ config=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -254,9 +254,8 @@ class TestListOperatorNode:
mock_var = ArrayStringSegment(value=["cherry", "apple", "banana"])
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
- node = ListOperatorNode(
- id="test",
- config={"id": "test", "data": config},
+ node = self._build_node(
+ config=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -282,9 +281,8 @@ class TestListOperatorNode:
mock_var = ArrayStringSegment(value=["apple", "banana", "cherry", "date"])
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
- node = ListOperatorNode(
- id="test",
- config={"id": "test", "data": config},
+ node = self._build_node(
+ config=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -299,11 +297,7 @@ class TestListOperatorNode:
config = {
"title": "Test",
"variable": ["sys", "numbers"],
- "filter_by": {
- "enabled": True,
- "condition": ">",
- "value": "3",
- },
+ "filter_by": self._filter_by(">", "3"),
"order_by": {
"enabled": True,
"value": "desc",
@@ -317,9 +311,8 @@ class TestListOperatorNode:
mock_var = ArrayNumberSegment(value=[1, 2, 3, 4, 5, 6, 7, 8, 9])
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
- node = ListOperatorNode(
- id="test",
- config={"id": "test", "data": config},
+ node = self._build_node(
+ config=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -341,9 +334,8 @@ class TestListOperatorNode:
mock_graph_runtime_state.variable_pool.get.return_value = None
- node = ListOperatorNode(
- id="test",
- config={"id": "test", "data": config},
+ node = self._build_node(
+ config=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -366,9 +358,8 @@ class TestListOperatorNode:
mock_var = ArrayStringSegment(value=["first", "middle", "last"])
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
- node = ListOperatorNode(
- id="test",
- config={"id": "test", "data": config},
+ node = self._build_node(
+ config=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -384,11 +375,7 @@ class TestListOperatorNode:
config = {
"title": "Test",
"variable": ["sys", "items"],
- "filter_by": {
- "enabled": True,
- "condition": "start with",
- "value": "app",
- },
+ "filter_by": self._filter_by("start with", "app"),
"order_by": {"enabled": False},
"limit": {"enabled": False},
}
@@ -396,9 +383,8 @@ class TestListOperatorNode:
mock_var = ArrayStringSegment(value=["apple", "application", "banana", "apricot"])
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
- node = ListOperatorNode(
- id="test",
- config={"id": "test", "data": config},
+ node = self._build_node(
+ config=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -413,11 +399,7 @@ class TestListOperatorNode:
config = {
"title": "Test",
"variable": ["sys", "items"],
- "filter_by": {
- "enabled": True,
- "condition": "end with",
- "value": "le",
- },
+ "filter_by": self._filter_by("end with", "le"),
"order_by": {"enabled": False},
"limit": {"enabled": False},
}
@@ -425,9 +407,8 @@ class TestListOperatorNode:
mock_var = ArrayStringSegment(value=["apple", "banana", "pineapple", "table"])
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
- node = ListOperatorNode(
- id="test",
- config={"id": "test", "data": config},
+ node = self._build_node(
+ config=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -442,11 +423,7 @@ class TestListOperatorNode:
config = {
"title": "Test",
"variable": ["sys", "numbers"],
- "filter_by": {
- "enabled": True,
- "condition": "=",
- "value": "5",
- },
+ "filter_by": self._filter_by("=", "5"),
"order_by": {"enabled": False},
"limit": {"enabled": False},
}
@@ -454,9 +431,8 @@ class TestListOperatorNode:
mock_var = ArrayNumberSegment(value=[1, 3, 5, 5, 7, 9])
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
- node = ListOperatorNode(
- id="test",
- config={"id": "test", "data": config},
+ node = self._build_node(
+ config=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -471,11 +447,7 @@ class TestListOperatorNode:
config = {
"title": "Test",
"variable": ["sys", "numbers"],
- "filter_by": {
- "enabled": True,
- "condition": "≠",
- "value": "5",
- },
+ "filter_by": self._filter_by("≠", "5"),
"order_by": {"enabled": False},
"limit": {"enabled": False},
}
@@ -483,9 +455,8 @@ class TestListOperatorNode:
mock_var = ArrayNumberSegment(value=[1, 3, 5, 7, 9])
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
- node = ListOperatorNode(
- id="test",
- config={"id": "test", "data": config},
+ node = self._build_node(
+ config=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -511,9 +482,8 @@ class TestListOperatorNode:
mock_var = ArrayNumberSegment(value=[9, 3, 7, 1, 5])
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
- node = ListOperatorNode(
- id="test",
- config={"id": "test", "data": config},
+ node = self._build_node(
+ config=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py
index 4186bbdc93..212ad07bd3 100644
--- a/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py
+++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py
@@ -71,8 +71,8 @@ def _build_image_file(
mime_type: str = "image/png",
) -> File:
return File(
- id=file_id,
- type=FileType.IMAGE,
+ file_id=file_id,
+ file_type=FileType.IMAGE,
filename=f"{file_id}{extension}",
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url=remote_url,
@@ -95,6 +95,8 @@ def variable_pool() -> VariablePool:
def _fetch_prompt_messages_with_mocked_content(content):
variable_pool = VariablePool.empty()
model_instance = mock.MagicMock(spec=ModelInstance)
+ model_schema = mock.MagicMock()
+ model_schema.supports_prompt_content_type.side_effect = lambda content_type: content_type == "text"
prompt_template = [
LLMNodeChatModelMessage(
text="You are a classifier.",
@@ -106,7 +108,7 @@ def _fetch_prompt_messages_with_mocked_content(content):
with (
mock.patch(
"graphon.nodes.llm.llm_utils.fetch_model_schema",
- return_value=mock.MagicMock(features=[]),
+ return_value=model_schema,
),
mock.patch(
"graphon.nodes.llm.llm_utils.handle_list_messages",
diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py
index b1f81b6c48..c707cf28cd 100644
--- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py
@@ -140,8 +140,8 @@ def _build_image_file(
mime_type: str = "image/png",
) -> File:
return File(
- id=file_id,
- type=FileType.IMAGE,
+ file_id=file_id,
+ file_type=FileType.IMAGE,
filename=f"{file_id}{extension}",
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url=remote_url,
@@ -205,14 +205,10 @@ def llm_node(
mock_credentials_provider = mock.MagicMock(spec=CredentialsProvider)
mock_model_factory = mock.MagicMock(spec=ModelFactory)
mock_prompt_message_serializer = mock.MagicMock(spec=PromptMessageSerializerProtocol)
- node_config = {
- "id": "1",
- "data": llm_node_data.model_dump(),
- }
http_client = mock.MagicMock()
node = LLMNode(
- id="1",
- config=node_config,
+ node_id="1",
+ config=llm_node_data,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
credentials_provider=mock_credentials_provider,
@@ -403,8 +399,8 @@ def test_dify_model_access_adapters_call_managers():
def test_fetch_files_with_file_segment():
file = File(
- id="1",
- type=FileType.IMAGE,
+ file_id="1",
+ file_type=FileType.IMAGE,
filename="test.jpg",
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="1",
@@ -420,16 +416,16 @@ def test_fetch_files_with_file_segment():
def test_fetch_files_with_array_file_segment():
files = [
File(
- id="1",
- type=FileType.IMAGE,
+ file_id="1",
+ file_type=FileType.IMAGE,
filename="test1.jpg",
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="1",
storage_key="",
),
File(
- id="2",
- type=FileType.IMAGE,
+ file_id="2",
+ file_type=FileType.IMAGE,
filename="test2.jpg",
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="2",
@@ -1174,14 +1170,10 @@ def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_stat
mock_credentials_provider = mock.MagicMock(spec=CredentialsProvider)
mock_model_factory = mock.MagicMock(spec=ModelFactory)
mock_prompt_message_serializer = mock.MagicMock(spec=PromptMessageSerializerProtocol)
- node_config = {
- "id": "1",
- "data": llm_node_data.model_dump(),
- }
http_client = mock.MagicMock()
node = LLMNode(
- id="1",
- config=node_config,
+ node_id="1",
+ config=llm_node_data,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
credentials_provider=mock_credentials_provider,
@@ -1203,8 +1195,8 @@ class TestLLMNodeSaveMultiModalImageOutput:
mime_type="image/png",
)
mock_file = File(
- id=str(uuid.uuid4()),
- type=FileType.IMAGE,
+ file_id=str(uuid.uuid4()),
+ file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.TOOL_FILE,
related_id=str(uuid.uuid4()),
filename="test-file.png",
@@ -1233,8 +1225,8 @@ class TestLLMNodeSaveMultiModalImageOutput:
mime_type="image/jpg",
)
mock_file = File(
- id=str(uuid.uuid4()),
- type=FileType.IMAGE,
+ file_id=str(uuid.uuid4()),
+ file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.TOOL_FILE,
related_id=str(uuid.uuid4()),
filename="test-file.png",
@@ -1291,8 +1283,8 @@ class TestSaveMultimodalOutputAndConvertResultToMarkdown:
image_b64_data = base64.b64encode(image_raw_data).decode()
mock_saved_file = File(
- id=str(uuid.uuid4()),
- type=FileType.IMAGE,
+ file_id=str(uuid.uuid4()),
+ file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.TOOL_FILE,
filename="test.png",
extension=".png",
@@ -1457,7 +1449,6 @@ def test_invoke_llm_dispatches_to_expected_model_method(structured_output_enable
file_saver=file_saver,
file_outputs=[],
node_id="node-1",
- node_type=LLMNode.node_type,
reasoning_format="separated",
)
)
@@ -1514,7 +1505,6 @@ def test_handle_invoke_result_streaming_collects_text_metrics_and_structured_out
file_saver=mock.MagicMock(spec=LLMFileSaver),
file_outputs=[],
node_id="node-1",
- node_type=LLMNode.node_type,
model_instance=_build_prepared_llm_mock(),
reasoning_format="separated",
request_start_time=1.0,
@@ -1552,7 +1542,6 @@ def test_handle_invoke_result_wraps_structured_output_parse_errors():
file_saver=mock.MagicMock(spec=LLMFileSaver),
file_outputs=[],
node_id="node-1",
- node_type=LLMNode.node_type,
model_instance=model_instance,
)
)
diff --git a/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py b/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py
index bc44ececd8..892f6cc586 100644
--- a/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py
+++ b/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py
@@ -13,6 +13,28 @@ from graphon.template_rendering import TemplateRenderError
from tests.workflow_test_utils import build_test_graph_init_params
+def _build_template_transform_node(
+ *,
+ node_data,
+ graph_init_params,
+ graph_runtime_state,
+ node_id: str = "test_node",
+ **kwargs,
+) -> TemplateTransformNode:
+ typed_node_data = (
+ node_data
+ if isinstance(node_data, TemplateTransformNodeData)
+ else TemplateTransformNodeData.model_validate(node_data)
+ )
+ return TemplateTransformNode(
+ node_id=node_id,
+ config=typed_node_data,
+ graph_init_params=graph_init_params,
+ graph_runtime_state=graph_runtime_state,
+ **kwargs,
+ )
+
+
class TestTemplateTransformNode:
"""Comprehensive test suite for TemplateTransformNode."""
@@ -59,9 +81,8 @@ class TestTemplateTransformNode:
def test_node_initialization(self, basic_node_data, mock_graph_runtime_state, graph_init_params):
"""Test that TemplateTransformNode initializes correctly."""
mock_renderer = MagicMock()
- node = TemplateTransformNode(
- id="test_node",
- config={"id": "test_node", "data": basic_node_data},
+ node = _build_template_transform_node(
+ node_data=basic_node_data,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
jinja2_template_renderer=mock_renderer,
@@ -75,9 +96,8 @@ class TestTemplateTransformNode:
def test_get_title(self, basic_node_data, mock_graph_runtime_state, graph_init_params):
"""Test _get_title method."""
mock_renderer = MagicMock()
- node = TemplateTransformNode(
- id="test_node",
- config={"id": "test_node", "data": basic_node_data},
+ node = _build_template_transform_node(
+ node_data=basic_node_data,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
jinja2_template_renderer=mock_renderer,
@@ -88,9 +108,8 @@ class TestTemplateTransformNode:
def test_get_description(self, basic_node_data, mock_graph_runtime_state, graph_init_params):
"""Test _get_description method."""
mock_renderer = MagicMock()
- node = TemplateTransformNode(
- id="test_node",
- config={"id": "test_node", "data": basic_node_data},
+ node = _build_template_transform_node(
+ node_data=basic_node_data,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
jinja2_template_renderer=mock_renderer,
@@ -108,9 +127,8 @@ class TestTemplateTransformNode:
}
mock_renderer = MagicMock()
- node = TemplateTransformNode(
- id="test_node",
- config={"id": "test_node", "data": node_data},
+ node = _build_template_transform_node(
+ node_data=node_data,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
jinja2_template_renderer=mock_renderer,
@@ -143,9 +161,8 @@ class TestTemplateTransformNode:
mock_renderer = MagicMock()
with pytest.raises(ValueError, match="max_output_length must be a positive integer"):
- TemplateTransformNode(
- id="test_node",
- config={"id": "test_node", "data": basic_node_data},
+ _build_template_transform_node(
+ node_data=basic_node_data,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
jinja2_template_renderer=mock_renderer,
@@ -170,9 +187,8 @@ class TestTemplateTransformNode:
mock_renderer = MagicMock()
mock_renderer.render_template.return_value = "Hello Alice, you are 30 years old!"
- node = TemplateTransformNode(
- id="test_node",
- config={"id": "test_node", "data": basic_node_data},
+ node = _build_template_transform_node(
+ node_data=basic_node_data,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
jinja2_template_renderer=mock_renderer,
@@ -198,9 +214,8 @@ class TestTemplateTransformNode:
mock_renderer = MagicMock()
mock_renderer.render_template.return_value = "Value: "
- node = TemplateTransformNode(
- id="test_node",
- config={"id": "test_node", "data": node_data},
+ node = _build_template_transform_node(
+ node_data=node_data,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
jinja2_template_renderer=mock_renderer,
@@ -218,9 +233,8 @@ class TestTemplateTransformNode:
mock_renderer = MagicMock()
mock_renderer.render_template.side_effect = TemplateRenderError("Template syntax error")
- node = TemplateTransformNode(
- id="test_node",
- config={"id": "test_node", "data": basic_node_data},
+ node = _build_template_transform_node(
+ node_data=basic_node_data,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
jinja2_template_renderer=mock_renderer,
@@ -238,9 +252,8 @@ class TestTemplateTransformNode:
mock_renderer = MagicMock()
mock_renderer.render_template.return_value = "This is a very long output that exceeds the limit"
- node = TemplateTransformNode(
- id="test_node",
- config={"id": "test_node", "data": basic_node_data},
+ node = _build_template_transform_node(
+ node_data=basic_node_data,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
jinja2_template_renderer=mock_renderer,
@@ -260,9 +273,8 @@ class TestTemplateTransformNode:
mock_renderer = MagicMock()
mock_renderer.render_template.return_value = "1234567890"
- node = TemplateTransformNode(
- id="test_node",
- config={"id": "test_node", "data": basic_node_data},
+ node = _build_template_transform_node(
+ node_data=basic_node_data,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
jinja2_template_renderer=mock_renderer,
@@ -302,9 +314,8 @@ class TestTemplateTransformNode:
mock_renderer = MagicMock()
mock_renderer.render_template.return_value = "apple, banana, orange (Total: 3)"
- node = TemplateTransformNode(
- id="test_node",
- config={"id": "test_node", "data": node_data},
+ node = _build_template_transform_node(
+ node_data=node_data,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
jinja2_template_renderer=mock_renderer,
@@ -375,8 +386,8 @@ class TestTemplateTransformNode:
)
assert mapping == {
- "node_123.var1": ["sys", "input1"],
- "node_123.empty_selector": [],
+ "node_123.var1": ("sys", "input1"),
+ "node_123.empty_selector": (),
}
def test_extract_variable_selector_to_variable_mapping_ignores_invalid_entries(self):
@@ -409,9 +420,8 @@ class TestTemplateTransformNode:
mock_renderer = MagicMock()
mock_renderer.render_template.return_value = "This is a static message."
- node = TemplateTransformNode(
- id="test_node",
- config={"id": "test_node", "data": node_data},
+ node = _build_template_transform_node(
+ node_data=node_data,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
jinja2_template_renderer=mock_renderer,
@@ -448,9 +458,8 @@ class TestTemplateTransformNode:
mock_renderer = MagicMock()
mock_renderer.render_template.return_value = "Total: $31.5"
- node = TemplateTransformNode(
- id="test_node",
- config={"id": "test_node", "data": node_data},
+ node = _build_template_transform_node(
+ node_data=node_data,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
jinja2_template_renderer=mock_renderer,
@@ -477,9 +486,8 @@ class TestTemplateTransformNode:
mock_renderer = MagicMock()
mock_renderer.render_template.return_value = "Name: John Doe, Email: john@example.com"
- node = TemplateTransformNode(
- id="test_node",
- config={"id": "test_node", "data": node_data},
+ node = _build_template_transform_node(
+ node_data=node_data,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
jinja2_template_renderer=mock_renderer,
@@ -507,9 +515,8 @@ class TestTemplateTransformNode:
mock_renderer = MagicMock()
mock_renderer.render_template.return_value = "Tags: #python #ai #workflow "
- node = TemplateTransformNode(
- id="test_node",
- config={"id": "test_node", "data": node_data},
+ node = _build_template_transform_node(
+ node_data=node_data,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
jinja2_template_renderer=mock_renderer,
diff --git a/api/tests/unit_tests/core/workflow/nodes/template_transform/test_template_transform_node.py b/api/tests/unit_tests/core/workflow/nodes/template_transform/test_template_transform_node.py
index 636237e56e..a846efbb43 100644
--- a/api/tests/unit_tests/core/workflow/nodes/template_transform/test_template_transform_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/template_transform/test_template_transform_node.py
@@ -4,6 +4,7 @@ import pytest
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
from graphon.nodes.base.entities import VariableSelector
+from graphon.nodes.template_transform.entities import TemplateTransformNodeData
from graphon.nodes.template_transform.template_transform_node import (
DEFAULT_TEMPLATE_TRANSFORM_MAX_OUTPUT_LENGTH,
TemplateTransformNode,
@@ -37,15 +38,13 @@ def mock_graph_runtime_state():
def test_node_uses_default_max_output_length_when_not_overridden(graph_init_params, mock_graph_runtime_state):
node = TemplateTransformNode(
- id="test_node",
- config={
- "id": "test_node",
- "data": {
- "title": "Template Transform",
- "variables": [],
- "template": "hello",
- },
- },
+ node_id="test_node",
+ config=TemplateTransformNodeData(
+ title="Template Transform",
+ type="template-transform",
+ variables=[],
+ template="hello",
+ ),
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
jinja2_template_renderer=MagicMock(),
@@ -70,5 +69,5 @@ def test_extract_variable_selector_to_variable_mapping_accepts_mixed_valid_entri
assert mapping == {
"node_123.validated": ["sys", "input1"],
- "node_123.raw": ["sys", "input2"],
+ "node_123.raw": ("sys", "input2"),
}
diff --git a/api/tests/unit_tests/core/workflow/nodes/test_base_node.py b/api/tests/unit_tests/core/workflow/nodes/test_base_node.py
index 0522dd9d14..364408ead6 100644
--- a/api/tests/unit_tests/core/workflow/nodes/test_base_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/test_base_node.py
@@ -7,7 +7,6 @@ from core.workflow.node_runtime import resolve_dify_run_context
from core.workflow.system_variables import build_system_variables
from graphon.entities import GraphInitParams
from graphon.entities.base_node_data import BaseNodeData
-from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
from graphon.enums import BuiltinNodeTypes
from graphon.nodes.base.node import Node
from graphon.runtime import GraphRuntimeState, VariablePool
@@ -42,17 +41,19 @@ def _build_context(graph_config: Mapping[str, object]) -> tuple[GraphInitParams,
return init_params, runtime_state
-def _build_node_config() -> NodeConfigDict:
- return NodeConfigDictAdapter.validate_python(
- {
- "id": "node-1",
- "data": {
- "type": BuiltinNodeTypes.ANSWER,
- "title": "Sample",
- "foo": "bar",
- },
- }
- )
+def _build_node_config() -> dict[str, object]:
+ return {
+ "id": "node-1",
+ "data": _SampleNodeData(
+ type=BuiltinNodeTypes.ANSWER,
+ title="Sample",
+ foo="bar",
+ ),
+ }
+
+
+def _build_node_data() -> _SampleNodeData:
+ return _build_node_config()["data"] # type: ignore[return-value]
def test_node_hydrates_data_during_initialization():
@@ -60,8 +61,8 @@ def test_node_hydrates_data_during_initialization():
init_params, runtime_state = _build_context(graph_config)
node = _SampleNode(
- id="node-1",
- config=_build_node_config(),
+ node_id="node-1",
+ config=_build_node_data(),
graph_init_params=init_params,
graph_runtime_state=runtime_state,
)
@@ -86,8 +87,8 @@ def test_node_accepts_invoke_from_enum():
)
node = _SampleNode(
- id="node-1",
- config=_build_node_config(),
+ node_id="node-1",
+ config=_build_node_data(),
graph_init_params=init_params,
graph_runtime_state=runtime_state,
)
@@ -117,13 +118,7 @@ def test_missing_generic_argument_raises_type_error():
def test_base_node_data_keeps_dict_style_access_compatibility():
- node_data = _SampleNodeData.model_validate(
- {
- "type": BuiltinNodeTypes.ANSWER,
- "title": "Sample",
- "foo": "bar",
- }
- )
+ node_data = _SampleNodeData(type=BuiltinNodeTypes.ANSWER, title="Sample", foo="bar")
assert node_data["foo"] == "bar"
assert node_data.get("foo") == "bar"
@@ -133,21 +128,19 @@ def test_base_node_data_keeps_dict_style_access_compatibility():
def test_node_hydration_preserves_compatibility_extra_fields():
graph_config: dict[str, object] = {}
init_params, runtime_state = _build_context(graph_config)
- node_config = NodeConfigDictAdapter.validate_python(
- {
- "id": "node-1",
- "data": {
- "type": BuiltinNodeTypes.ANSWER,
- "title": "Sample",
- "foo": "bar",
- "compat_flag": True,
- },
- }
- )
+ node_config = {
+ "id": "node-1",
+ "data": _SampleNodeData(
+ type=BuiltinNodeTypes.ANSWER,
+ title="Sample",
+ foo="bar",
+ compat_flag=True,
+ ),
+ }
node = _SampleNode(
- id="node-1",
- config=node_config,
+ node_id="node-1",
+ config=node_config["data"],
graph_init_params=init_params,
graph_runtime_state=runtime_state,
)
diff --git a/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py b/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py
index 87ec2d5bce..dd75b32593 100644
--- a/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py
@@ -11,14 +11,16 @@ from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus
from graphon.file import File, FileTransferMethod
from graphon.node_events import NodeRunResult
from graphon.nodes.document_extractor import DocumentExtractorNode, DocumentExtractorNodeData
+from graphon.nodes.document_extractor.exc import TextExtractionError, UnsupportedFileTypeError
from graphon.nodes.document_extractor.node import (
_extract_text_from_docx,
_extract_text_from_excel,
+ _extract_text_from_file,
_extract_text_from_pdf,
_extract_text_from_plain_text,
_normalize_docx_zip,
)
-from graphon.variables import ArrayFileSegment
+from graphon.variables import ArrayFileSegment, FileSegment
from graphon.variables.segments import ArrayStringSegment
from graphon.variables.variables import StringVariable
from tests.workflow_test_utils import build_test_graph_init_params
@@ -44,11 +46,10 @@ def document_extractor_node(graph_init_params):
title="Test Document Extractor",
variable_selector=["node_id", "variable_name"],
)
- node_config = {"id": "test_node_id", "data": node_data.model_dump()}
http_client = Mock()
node = DocumentExtractorNode(
- id="test_node_id",
- config=node_config,
+ node_id="test_node_id",
+ config=node_data,
graph_init_params=graph_init_params,
graph_runtime_state=Mock(),
http_client=http_client,
@@ -341,7 +342,7 @@ def test_extract_text_from_excel_sheet_parse_error(mock_excel_file):
# Mock ExcelFile
mock_excel_instance = Mock()
mock_excel_instance.sheet_names = ["GoodSheet", "BadSheet"]
- mock_excel_instance.parse.side_effect = [df, Exception("Parse error")]
+ mock_excel_instance.parse.side_effect = [df, TypeError("Parse error")]
mock_excel_file.return_value = mock_excel_instance
file_content = b"fake_excel_mixed_content"
@@ -386,7 +387,7 @@ def test_extract_text_from_excel_all_sheets_fail(mock_excel_file):
# Mock ExcelFile
mock_excel_instance = Mock()
mock_excel_instance.sheet_names = ["BadSheet1", "BadSheet2"]
- mock_excel_instance.parse.side_effect = [Exception("Error 1"), Exception("Error 2")]
+ mock_excel_instance.parse.side_effect = [TypeError("Error 1"), TypeError("Error 2")]
mock_excel_file.return_value = mock_excel_instance
file_content = b"fake_excel_all_bad_sheets"
@@ -397,6 +398,12 @@ def test_extract_text_from_excel_all_sheets_fail(mock_excel_file):
assert mock_excel_instance.parse.call_count == 2
+@patch("pandas.ExcelFile", side_effect=RuntimeError("broken workbook"))
+def test_extract_text_from_excel_wraps_workbook_open_errors(mock_excel_file):
+ with pytest.raises(TextExtractionError, match="Failed to extract text from Excel file: broken workbook"):
+ _extract_text_from_excel(b"broken")
+
+
@patch("pandas.ExcelFile")
def test_extract_text_from_excel_numeric_type_column(mock_excel_file):
"""Test extracting text from Excel file with numeric column names."""
@@ -420,6 +427,103 @@ def test_extract_text_from_excel_numeric_type_column(mock_excel_file):
assert expected_manual == result
+@pytest.mark.parametrize(
+ ("extension", "mime_type"),
+ [
+ (".xlsx", "text/plain"),
+ (None, "application/vnd.ms-excel"),
+ ],
+)
+def test_extract_text_from_file_routes_excel_inputs(document_extractor_node, extension, mime_type):
+ file = Mock(spec=File)
+ file.extension = extension
+ file.mime_type = mime_type
+
+ with (
+ patch(
+ "graphon.nodes.document_extractor.node._download_file_content",
+ return_value=b"excel",
+ ),
+ patch(
+ "graphon.nodes.document_extractor.node._extract_text_from_excel",
+ return_value="excel text",
+ ) as mock_extract,
+ ):
+ result = _extract_text_from_file(
+ document_extractor_node.http_client,
+ file,
+ unstructured_api_config=document_extractor_node._unstructured_api_config,
+ )
+
+ assert result == "excel text"
+ mock_extract.assert_called_once_with(b"excel")
+
+
+def test_extract_text_from_file_rejects_missing_extension_and_mime_type(document_extractor_node):
+ file = Mock(spec=File)
+ file.extension = None
+ file.mime_type = None
+
+ with patch(
+ "graphon.nodes.document_extractor.node._download_file_content",
+ return_value=b"unknown",
+ ):
+ with pytest.raises(UnsupportedFileTypeError, match="Unable to determine file type"):
+ _extract_text_from_file(
+ document_extractor_node.http_client,
+ file,
+ unstructured_api_config=document_extractor_node._unstructured_api_config,
+ )
+
+
+def test_run_list_file_extraction_error_returns_failed(document_extractor_node, mock_graph_runtime_state):
+ document_extractor_node.graph_runtime_state = mock_graph_runtime_state
+ file_list = Mock(spec=ArrayFileSegment)
+ file_list.value = [Mock(spec=File)]
+ mock_graph_runtime_state.variable_pool.get.return_value = file_list
+
+ with patch(
+ "graphon.nodes.document_extractor.node._extract_text_from_file",
+ side_effect=TextExtractionError("bad file"),
+ ):
+ result = document_extractor_node._run()
+
+ assert result.status == WorkflowNodeExecutionStatus.FAILED
+ assert result.error == "bad file"
+
+
+def test_run_single_file_segment_extraction_error_returns_failed(document_extractor_node, mock_graph_runtime_state):
+ document_extractor_node.graph_runtime_state = mock_graph_runtime_state
+ file_segment = Mock(spec=FileSegment)
+ file_segment.value = Mock(spec=File)
+ mock_graph_runtime_state.variable_pool.get.return_value = file_segment
+
+ with patch(
+ "graphon.nodes.document_extractor.node._extract_text_from_file",
+ side_effect=TextExtractionError("single file failed"),
+ ):
+ result = document_extractor_node._run()
+
+ assert result.status == WorkflowNodeExecutionStatus.FAILED
+ assert result.error == "single file failed"
+
+
+def test_run_single_file_segment_returns_string_output(document_extractor_node, mock_graph_runtime_state):
+ document_extractor_node.graph_runtime_state = mock_graph_runtime_state
+ file_segment = Mock(spec=FileSegment)
+ file_segment.value = Mock(spec=File)
+ mock_graph_runtime_state.variable_pool.get.return_value = file_segment
+
+ with patch(
+ "graphon.nodes.document_extractor.node._extract_text_from_file",
+ return_value="single file text",
+ ):
+ result = document_extractor_node._run()
+
+ assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
+ assert result.outputs == {"text": "single file text"}
+
+
def _make_docx_zip(use_backslash: bool) -> bytes:
"""Helper to build a minimal in-memory DOCX zip.
diff --git a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py
index 782750e02e..aa9a1360b0 100644
--- a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py
+++ b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py
@@ -19,6 +19,20 @@ from graphon.variables import ArrayFileSegment
from tests.workflow_test_utils import build_test_graph_init_params
+def _build_if_else_node(
+ *,
+ node_data: IfElseNodeData | dict[str, object],
+ init_params,
+ graph_runtime_state,
+) -> IfElseNode:
+ return IfElseNode(
+ node_id=str(uuid.uuid4()),
+ graph_init_params=init_params,
+ graph_runtime_state=graph_runtime_state,
+ config=node_data if isinstance(node_data, IfElseNodeData) else IfElseNodeData.model_validate(node_data),
+ )
+
+
def test_execute_if_else_result_true():
graph_config = {"edges": [], "nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}]}
@@ -61,9 +75,8 @@ def test_execute_if_else_result_true():
)
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start")
- node_config = {
- "id": "if-else",
- "data": {
+ node = _build_if_else_node(
+ node_data={
"title": "123",
"type": "if-else",
"logical_operator": "and",
@@ -104,13 +117,8 @@ def test_execute_if_else_result_true():
{"comparison_operator": "not null", "variable_selector": ["start", "not_null"]},
],
},
- }
-
- node = IfElseNode(
- id=str(uuid.uuid4()),
- graph_init_params=init_params,
+ init_params=init_params,
graph_runtime_state=graph_runtime_state,
- config=node_config,
)
# Mock db.session.close()
@@ -155,9 +163,8 @@ def test_execute_if_else_result_false():
)
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start")
- node_config = {
- "id": "if-else",
- "data": {
+ node = _build_if_else_node(
+ node_data={
"title": "123",
"type": "if-else",
"logical_operator": "or",
@@ -174,13 +181,8 @@ def test_execute_if_else_result_false():
},
],
},
- }
-
- node = IfElseNode(
- id=str(uuid.uuid4()),
- graph_init_params=init_params,
+ init_params=init_params,
graph_runtime_state=graph_runtime_state,
- config=node_config,
)
# Mock db.session.close()
@@ -222,11 +224,6 @@ def test_array_file_contains_file_name():
],
)
- node_config = {
- "id": "if-else",
- "data": node_data.model_dump(),
- }
-
# Create properly configured mock for graph_init_params
graph_init_params = Mock()
graph_init_params.workflow_id = "test_workflow"
@@ -242,17 +239,12 @@ def test_array_file_contains_file_name():
}
}
- node = IfElseNode(
- id=str(uuid.uuid4()),
- graph_init_params=graph_init_params,
- graph_runtime_state=Mock(),
- config=node_config,
- )
+ node = _build_if_else_node(node_data=node_data, init_params=graph_init_params, graph_runtime_state=Mock())
node.graph_runtime_state.variable_pool.get.return_value = ArrayFileSegment(
value=[
File(
- type=FileType.IMAGE,
+ file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="1",
filename="ab",
@@ -334,11 +326,10 @@ def test_execute_if_else_boolean_conditions(condition: Condition):
"logical_operator": "and",
"conditions": [condition.model_dump()],
}
- node = IfElseNode(
- id=str(uuid.uuid4()),
- graph_init_params=init_params,
+ node = _build_if_else_node(
+ node_data=node_data,
+ init_params=init_params,
graph_runtime_state=graph_runtime_state,
- config={"id": "if-else", "data": node_data},
)
# Mock db.session.close()
@@ -400,14 +391,10 @@ def test_execute_if_else_boolean_false_conditions():
],
}
- node = IfElseNode(
- id=str(uuid.uuid4()),
- graph_init_params=init_params,
+ node = _build_if_else_node(
+ node_data=node_data,
+ init_params=init_params,
graph_runtime_state=graph_runtime_state,
- config={
- "id": "if-else",
- "data": node_data,
- },
)
# Mock db.session.close()
@@ -472,11 +459,10 @@ def test_execute_if_else_boolean_cases_structure():
}
],
}
- node = IfElseNode(
- id=str(uuid.uuid4()),
- graph_init_params=init_params,
+ node = _build_if_else_node(
+ node_data=node_data,
+ init_params=init_params,
graph_runtime_state=graph_runtime_state,
- config={"id": "if-else", "data": node_data},
)
# Mock db.session.close()
diff --git a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py
index b217e4e8e7..465a4c0ff4 100644
--- a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py
+++ b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py
@@ -19,6 +19,15 @@ from graphon.nodes.list_operator.node import ListOperatorNode, _get_file_extract
from graphon.variables import ArrayFileSegment
+def _build_list_operator_node(node_data: ListOperatorNodeData, graph_init_params) -> ListOperatorNode:
+ return ListOperatorNode(
+ node_id="test_node_id",
+ config=node_data,
+ graph_init_params=graph_init_params,
+ graph_runtime_state=MagicMock(),
+ )
+
+
@pytest.fixture
def list_operator_node():
config = {
@@ -35,10 +44,6 @@ def list_operator_node():
"title": "Test Title",
}
node_data = ListOperatorNodeData.model_validate(config)
- node_config = {
- "id": "test_node_id",
- "data": node_data.model_dump(),
- }
# Create properly configured mock for graph_init_params
graph_init_params = MagicMock()
graph_init_params.workflow_id = "test_workflow"
@@ -54,12 +59,7 @@ def list_operator_node():
}
}
- node = ListOperatorNode(
- id="test_node_id",
- config=node_config,
- graph_init_params=graph_init_params,
- graph_runtime_state=MagicMock(),
- )
+ node = _build_list_operator_node(node_data, graph_init_params)
node.graph_runtime_state = MagicMock()
node.graph_runtime_state.variable_pool = MagicMock()
return node
@@ -70,28 +70,28 @@ def test_filter_files_by_type(list_operator_node):
files = [
File(
filename="image1.jpg",
- type=FileType.IMAGE,
+ file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="related1",
storage_key="",
),
File(
filename="document1.pdf",
- type=FileType.DOCUMENT,
+ file_type=FileType.DOCUMENT,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="related2",
storage_key="",
),
File(
filename="image2.png",
- type=FileType.IMAGE,
+ file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="related3",
storage_key="",
),
File(
filename="audio1.mp3",
- type=FileType.AUDIO,
+ file_type=FileType.AUDIO,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="related4",
storage_key="",
@@ -136,7 +136,7 @@ def test_filter_files_by_type(list_operator_node):
def test_get_file_extract_string_func():
# Create a File object
file = File(
- type=FileType.DOCUMENT,
+ file_type=FileType.DOCUMENT,
transfer_method=FileTransferMethod.LOCAL_FILE,
filename="test_file.txt",
extension=".txt",
@@ -156,7 +156,7 @@ def test_get_file_extract_string_func():
# Test with empty values
empty_file = File(
- type=FileType.DOCUMENT,
+ file_type=FileType.DOCUMENT,
transfer_method=FileTransferMethod.LOCAL_FILE,
filename=None,
extension=None,
diff --git a/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py b/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py
index 543f9878de..5655f80737 100644
--- a/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py
+++ b/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py
@@ -22,10 +22,7 @@ def make_start_node(user_inputs, variables):
inputs=user_inputs,
)
- config = {
- "id": "start",
- "data": StartNodeData(title="Start", variables=variables).model_dump(),
- }
+ node_data = StartNodeData(title="Start", variables=variables)
graph_runtime_state = GraphRuntimeState(
variable_pool=variable_pool,
@@ -33,8 +30,8 @@ def make_start_node(user_inputs, variables):
)
return StartNode(
- id="start",
- config=config,
+ node_id="start",
+ config=node_data,
graph_init_params=build_test_graph_init_params(
workflow_id="wf",
graph_config={},
@@ -109,7 +106,7 @@ def test_json_object_invalid_json_string():
node = make_start_node(user_inputs, variables)
- with pytest.raises(ValueError, match="JSON object for 'profile' must be an object"):
+ with pytest.raises(TypeError, match="JSON object for 'profile' must be an object"):
node._run()
@@ -248,25 +245,22 @@ def test_start_node_outputs_full_variable_pool_snapshot():
inputs={"profile": {"age": 20, "name": "Tom"}},
)
- config = {
- "id": "start",
- "data": StartNodeData(
- title="Start",
- variables=[
- VariableEntity(
- variable="profile",
- label="profile",
- type=VariableEntityType.JSON_OBJECT,
- required=True,
- )
- ],
- ).model_dump(),
- }
+ node_data = StartNodeData(
+ title="Start",
+ variables=[
+ VariableEntity(
+ variable="profile",
+ label="profile",
+ type=VariableEntityType.JSON_OBJECT,
+ required=True,
+ )
+ ],
+ )
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
node = StartNode(
- id="start",
- config=config,
+ node_id="start",
+ config=node_data,
graph_init_params=build_test_graph_init_params(
workflow_id="wf",
graph_config={},
diff --git a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py
index c806181340..284af68319 100644
--- a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py
@@ -13,6 +13,7 @@ from core.workflow.system_variables import build_system_variables
from graphon.file import File, FileTransferMethod, FileType
from graphon.model_runtime.entities.llm_entities import LLMUsage
from graphon.node_events import StreamChunkEvent, StreamCompletedEvent
+from graphon.nodes.tool.entities import ToolNodeData
from graphon.nodes.tool_runtime_entities import ToolRuntimeHandle, ToolRuntimeMessage
from graphon.runtime import GraphRuntimeState, VariablePool
from graphon.variables.segments import ArrayFileSegment
@@ -108,8 +109,8 @@ def tool_node(monkeypatch) -> ToolNode:
runtime = _StubToolRuntime()
node = ToolNode(
- id="node-instance",
- config=config,
+ node_id="node-instance",
+ config=ToolNodeData.model_validate(config["data"]),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
tool_file_manager_factory=tool_file_manager_factory,
@@ -118,13 +119,13 @@ def tool_node(monkeypatch) -> ToolNode:
return node
-def _collect_events(generator: Generator) -> tuple[list[Any], LLMUsage]:
+def _collect_events(generator: Generator) -> list[Any]:
events: list[Any] = []
try:
while True:
events.append(next(generator))
- except StopIteration as stop:
- return events, stop.value
+ except StopIteration:
+ return events
def _run_transform(tool_node: ToolNode, message: ToolRuntimeMessage) -> tuple[list[Any], LLMUsage]:
@@ -135,12 +136,15 @@ def _run_transform(tool_node: ToolNode, message: ToolRuntimeMessage) -> tuple[li
node_id=tool_node._node_id,
tool_runtime=ToolRuntimeHandle(raw=object()),
)
- return _collect_events(generator)
+ events = _collect_events(generator)
+ completed_events = [event for event in events if isinstance(event, StreamCompletedEvent)]
+ assert completed_events
+ return events, completed_events[-1].node_run_result.llm_usage
def test_link_messages_with_file_populate_files_output(tool_node: ToolNode):
file_obj = File(
- type=FileType.DOCUMENT,
+ file_type=FileType.DOCUMENT,
transfer_method=FileTransferMethod.TOOL_FILE,
related_id="file-id",
filename="demo.pdf",
@@ -195,7 +199,7 @@ def test_plain_link_messages_remain_links(tool_node: ToolNode):
def test_image_link_messages_use_tool_file_id_metadata(tool_node: ToolNode):
file_obj = File(
- type=FileType.DOCUMENT,
+ file_type=FileType.DOCUMENT,
transfer_method=FileTransferMethod.TOOL_FILE,
related_id="file-id",
filename="demo.pdf",
diff --git a/api/tests/unit_tests/core/workflow/nodes/trigger_plugin/test_trigger_event_node.py b/api/tests/unit_tests/core/workflow/nodes/trigger_plugin/test_trigger_event_node.py
index c8ddc53284..e3b5e3b591 100644
--- a/api/tests/unit_tests/core/workflow/nodes/trigger_plugin/test_trigger_event_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/trigger_plugin/test_trigger_event_node.py
@@ -1,10 +1,10 @@
from collections.abc import Mapping
from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE
+from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData
from core.workflow.nodes.trigger_plugin.trigger_event_node import TriggerEventNode
from core.workflow.system_variables import build_system_variables
from graphon.entities import GraphInitParams
-from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from graphon.runtime import GraphRuntimeState
from tests.workflow_test_utils import build_test_graph_init_params, build_test_variable_pool
@@ -27,29 +27,24 @@ def _build_context(graph_config: Mapping[str, object]) -> tuple[GraphInitParams,
return init_params, runtime_state
-def _build_node_config() -> NodeConfigDict:
- return NodeConfigDictAdapter.validate_python(
- {
- "id": "node-1",
- "data": {
- "type": TRIGGER_PLUGIN_NODE_TYPE,
- "title": "Trigger Event",
- "plugin_id": "plugin-id",
- "provider_id": "provider-id",
- "event_name": "event-name",
- "subscription_id": "subscription-id",
- "plugin_unique_identifier": "plugin-unique-identifier",
- "event_parameters": {},
- },
- }
+def _build_node_data() -> TriggerEventNodeData:
+ return TriggerEventNodeData(
+ type=TRIGGER_PLUGIN_NODE_TYPE,
+ title="Trigger Event",
+ plugin_id="plugin-id",
+ provider_id="provider-id",
+ event_name="event-name",
+ subscription_id="subscription-id",
+ plugin_unique_identifier="plugin-unique-identifier",
+ event_parameters={},
)
def test_trigger_event_node_run_populates_trigger_info_metadata() -> None:
init_params, runtime_state = _build_context(graph_config={})
node = TriggerEventNode(
- id="node-1",
- config=_build_node_config(),
+ node_id="node-1",
+ config=_build_node_data(),
graph_init_params=init_params,
graph_runtime_state=runtime_state,
)
diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py
index 1bbc12b23f..07d03bec05 100644
--- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py
+++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py
@@ -30,11 +30,6 @@ def create_webhook_node(
tenant_id: str = "test-tenant",
) -> TriggerWebhookNode:
"""Helper function to create a webhook node with proper initialization."""
- node_config = {
- "id": "webhook-node-1",
- "data": webhook_data.model_dump(),
- }
-
graph_init_params = GraphInitParams(
workflow_id="test-workflow",
graph_config={},
@@ -56,8 +51,8 @@ def create_webhook_node(
)
node = TriggerWebhookNode(
- id="webhook-node-1",
- config=node_config,
+ node_id="webhook-node-1",
+ config=webhook_data,
graph_init_params=graph_init_params,
graph_runtime_state=runtime_state,
)
@@ -66,10 +61,6 @@ def create_webhook_node(
runtime_state.app_config = Mock()
runtime_state.app_config.tenant_id = tenant_id
- # Provide compatibility alias expected by node implementation
- # Some nodes reference `self.node_id`; expose it as an alias to `self.id` for tests
- node.node_id = node.id
-
return node
diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py
index 427afa96ec..b839490d3c 100644
--- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py
@@ -24,11 +24,6 @@ from tests.workflow_test_utils import build_test_variable_pool
def create_webhook_node(webhook_data: WebhookData, variable_pool: VariablePool) -> TriggerWebhookNode:
"""Helper function to create a webhook node with proper initialization."""
- node_config = {
- "id": "1",
- "data": webhook_data.model_dump(),
- }
-
graph_init_params = GraphInitParams(
workflow_id="1",
graph_config={},
@@ -48,8 +43,8 @@ def create_webhook_node(webhook_data: WebhookData, variable_pool: VariablePool)
start_at=0,
)
node = TriggerWebhookNode(
- id="1",
- config=node_config,
+ node_id="1",
+ config=webhook_data,
graph_init_params=graph_init_params,
graph_runtime_state=runtime_state,
)
@@ -57,9 +52,6 @@ def create_webhook_node(webhook_data: WebhookData, variable_pool: VariablePool)
# Provide tenant_id for conversion path
runtime_state.app_config = type("_AppCfg", (), {"tenant_id": "1"})()
- # Compatibility alias for some nodes referencing `self.node_id`
- node.node_id = node.id
-
return node
@@ -225,7 +217,7 @@ def test_webhook_node_run_with_file_params():
"""Test webhook node execution with file parameter extraction."""
# Create mock file objects
file1 = File(
- type=FileType.IMAGE,
+ file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="file1",
filename="image.jpg",
@@ -234,7 +226,7 @@ def test_webhook_node_run_with_file_params():
)
file2 = File(
- type=FileType.DOCUMENT,
+ file_type=FileType.DOCUMENT,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="file2",
filename="document.pdf",
@@ -269,8 +261,19 @@ def test_webhook_node_run_with_file_params():
# Mock the node's file reference boundary to avoid DB-dependent validation on upload_file_id
with patch.object(node._file_reference_factory, "build_from_mapping") as mock_file_factory:
- def _to_file(*, mapping):
- return File.model_validate(mapping)
+ def _to_file(*, mapping: dict[str, Any]) -> File:
+ return File(
+ file_id=mapping.get("id"),
+ file_type=FileType(mapping["type"]),
+ transfer_method=FileTransferMethod(mapping["transfer_method"]),
+ related_id=mapping.get("related_id"),
+ filename=mapping.get("filename"),
+ extension=mapping.get("extension"),
+ mime_type=mapping.get("mime_type"),
+ size=mapping.get("size", -1),
+ storage_key=mapping.get("storage_key", ""),
+ remote_url=mapping.get("url"),
+ )
mock_file_factory.side_effect = _to_file
result = node._run()
@@ -284,7 +287,7 @@ def test_webhook_node_run_with_file_params():
def test_webhook_node_run_mixed_parameters():
"""Test webhook node execution with mixed parameter types."""
file_obj = File(
- type=FileType.IMAGE,
+ file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="file1",
filename="test.jpg",
@@ -317,8 +320,19 @@ def test_webhook_node_run_mixed_parameters():
# Mock the node's file reference boundary to avoid DB-dependent validation on upload_file_id
with patch.object(node._file_reference_factory, "build_from_mapping") as mock_file_factory:
- def _to_file(*, mapping):
- return File.model_validate(mapping)
+ def _to_file(*, mapping: dict[str, Any]) -> File:
+ return File(
+ file_id=mapping.get("id"),
+ file_type=FileType(mapping["type"]),
+ transfer_method=FileTransferMethod(mapping["transfer_method"]),
+ related_id=mapping.get("related_id"),
+ filename=mapping.get("filename"),
+ extension=mapping.get("extension"),
+ mime_type=mapping.get("mime_type"),
+ size=mapping.get("size", -1),
+ storage_key=mapping.get("storage_key", ""),
+ remote_url=mapping.get("url"),
+ )
mock_file_factory.side_effect = _to_file
result = node._run()
diff --git a/api/tests/unit_tests/core/workflow/test_human_input_adapter.py b/api/tests/unit_tests/core/workflow/test_human_input_adapter.py
new file mode 100644
index 0000000000..8b5fceeb37
--- /dev/null
+++ b/api/tests/unit_tests/core/workflow/test_human_input_adapter.py
@@ -0,0 +1,350 @@
+from types import SimpleNamespace
+
+import pytest
+from pydantic import BaseModel
+
+from core.workflow.human_input_adapter import (
+ DeliveryMethodType,
+ EmailDeliveryConfig,
+ EmailDeliveryMethod,
+ EmailRecipients,
+ WebAppDeliveryMethod,
+ _WebAppDeliveryConfig,
+ adapt_human_input_node_data_for_graph,
+ adapt_node_config_for_graph,
+ adapt_node_data_for_graph,
+ is_human_input_webapp_enabled,
+ parse_human_input_delivery_methods,
+)
+from graphon.enums import BuiltinNodeTypes
+from graphon.nodes.base.variable_template_parser import VariableTemplateParser
+
+
+def test_email_delivery_config_helpers_render_and_sanitize_text() -> None:
+ variable_pool = SimpleNamespace(
+ convert_template=lambda body: SimpleNamespace(text=body.replace("{{#node.value#}}", "42"))
+ )
+
+ rendered = EmailDeliveryConfig.render_body_template(
+ body="Open {{#url#}} and use {{#node.value#}}",
+ url="https://example.com",
+ variable_pool=variable_pool,
+ )
+ sanitized = EmailDeliveryConfig.sanitize_subject("Hello\r\n<script>alert(1)</script> Team")
+ html = EmailDeliveryConfig.render_markdown_body(
+ "**Hello** [mail](mailto:test@example.com)"
+ )
+
+ assert rendered == "Open https://example.com and use 42"
+ assert sanitized == "Hello alert(1) Team"
+ assert "Hello" in html
+ assert " Team")
- html = EmailDeliveryConfig.render_markdown_body(
- "**Hello** [mail](mailto:test@example.com)"
- )
-
- assert rendered == "Open https://example.com and use 42"
- assert sanitized == "Hello alert(1) Team"
- assert "Hello" in html
- assert ".txt' })
render(, { wrapper: createWrapper() })
- expect(screen.getByText('.txt')).toBeInTheDocument()
+ expect(screen.getByText('.txt'))!.toBeInTheDocument()
})
it('should memoize the component', () => {
@@ -343,7 +344,7 @@ describe('DocumentTableRow', () => {
const { rerender } = render(, { wrapper })
rerender()
- expect(screen.getByRole('row')).toBeInTheDocument()
+ expect(screen.getByRole('row'))!.toBeInTheDocument()
})
})
})
diff --git a/web/app/components/datasets/documents/components/document-list/components/document-source-icon.tsx b/web/app/components/datasets/documents/components/document-list/components/document-source-icon.tsx
index 5461f34921..0d51837cf2 100644
--- a/web/app/components/datasets/documents/components/document-list/components/document-source-icon.tsx
+++ b/web/app/components/datasets/documents/components/document-list/components/document-source-icon.tsx
@@ -39,7 +39,7 @@ const getFileExtension = (fileName: string): string => {
const parts = fileName.split('.')
if (parts.length <= 1 || (parts[0] === '' && parts.length === 2))
return ''
- return parts[parts.length - 1].toLowerCase()
+ return parts[parts.length - 1]!.toLowerCase()
}
const DocumentSourceIcon: FC = React.memo(({
diff --git a/web/app/components/datasets/documents/components/document-list/hooks/__tests__/use-document-actions.spec.ts b/web/app/components/datasets/documents/components/document-list/hooks/__tests__/use-document-actions.spec.ts
index 449478eb7b..9eebae4f81 100644
--- a/web/app/components/datasets/documents/components/document-list/hooks/__tests__/use-document-actions.spec.ts
+++ b/web/app/components/datasets/documents/components/document-list/hooks/__tests__/use-document-actions.spec.ts
@@ -27,7 +27,7 @@ vi.mock('@/service/knowledge/use-document', () => ({
useDocumentDownloadZip: () => ({ mutateAsync: mockDownloadZip, isPending: mockIsDownloadingZip }),
}))
-vi.mock('@/app/components/base/ui/toast', () => ({
+vi.mock('@langgenius/dify-ui/toast', () => ({
toast: {
success: mockToastSuccess,
error: mockToastError,
diff --git a/web/app/components/datasets/documents/components/document-list/hooks/use-document-actions.ts b/web/app/components/datasets/documents/components/document-list/hooks/use-document-actions.ts
index 8b6c40e2be..a46c4fcfcc 100644
--- a/web/app/components/datasets/documents/components/document-list/hooks/use-document-actions.ts
+++ b/web/app/components/datasets/documents/components/document-list/hooks/use-document-actions.ts
@@ -1,7 +1,7 @@
import type { CommonResponse } from '@/models/common'
+import { toast } from '@langgenius/dify-ui/toast'
import { useCallback, useMemo } from 'react'
import { useTranslation } from 'react-i18next'
-import { toast } from '@/app/components/base/ui/toast'
import { DocumentActionType } from '@/models/datasets'
import {
useDocumentArchive,
diff --git a/web/app/components/datasets/documents/components/documents-header.tsx b/web/app/components/datasets/documents/components/documents-header.tsx
index 4e098f5eda..bd74ad0487 100644
--- a/web/app/components/datasets/documents/components/documents-header.tsx
+++ b/web/app/components/datasets/documents/components/documents-header.tsx
@@ -4,13 +4,13 @@ import type { Item } from '@/app/components/base/select'
import type { BuiltInMetadataItem, MetadataItemWithValueLength } from '@/app/components/datasets/metadata/types'
import type { SortType } from '@/service/datasets'
import { PlusIcon } from '@heroicons/react/24/solid'
+import { Button } from '@langgenius/dify-ui/button'
import { RiDraftLine, RiExternalLinkLine } from '@remixicon/react'
import { useMemo } from 'react'
import { useTranslation } from 'react-i18next'
import Chip from '@/app/components/base/chip'
import Input from '@/app/components/base/input'
import Sort from '@/app/components/base/sort'
-import { Button } from '@/app/components/base/ui/button'
import AutoDisabledDocument from '@/app/components/datasets/common/document-status-with-action/auto-disabled-document'
import IndexFailed from '@/app/components/datasets/common/document-status-with-action/index-failed'
import StatusWithAction from '@/app/components/datasets/common/document-status-with-action/status-with-action'
diff --git a/web/app/components/datasets/documents/components/empty-element.tsx b/web/app/components/datasets/documents/components/empty-element.tsx
index 6eacf89264..506b8ab6db 100644
--- a/web/app/components/datasets/documents/components/empty-element.tsx
+++ b/web/app/components/datasets/documents/components/empty-element.tsx
@@ -1,8 +1,8 @@
'use client'
import type { FC } from 'react'
import { PlusIcon } from '@heroicons/react/24/solid'
+import { Button } from '@langgenius/dify-ui/button'
import { useTranslation } from 'react-i18next'
-import { Button } from '@/app/components/base/ui/button'
import s from '../style.module.css'
import { FolderPlusIcon, NotionIcon, ThreeDotsIcon } from './icons'
diff --git a/web/app/components/datasets/documents/components/list.tsx b/web/app/components/datasets/documents/components/list.tsx
index e40e4c061b..abd40c33a0 100644
--- a/web/app/components/datasets/documents/components/list.tsx
+++ b/web/app/components/datasets/documents/components/list.tsx
@@ -117,8 +117,8 @@ const DocumentList: FC = ({
return (
-
-
+
+
e.stopPropagation()}>
diff --git a/web/app/components/datasets/documents/components/operations.tsx b/web/app/components/datasets/documents/components/operations.tsx
index e7bbb03c94..8692da927d 100644
--- a/web/app/components/datasets/documents/components/operations.tsx
+++ b/web/app/components/datasets/documents/components/operations.tsx
@@ -1,15 +1,6 @@
import type { OperationName } from '../types'
import type { CommonResponse } from '@/models/common'
import type { DocumentDownloadResponse } from '@/service/datasets'
-import { cn } from '@langgenius/dify-ui/cn'
-import { useBoolean, useDebounceFn } from 'ahooks'
-import { noop } from 'es-toolkit/function'
-import * as React from 'react'
-import { useCallback, useState } from 'react'
-import { useTranslation } from 'react-i18next'
-import Divider from '@/app/components/base/divider'
-import Switch from '@/app/components/base/switch'
-import Tooltip from '@/app/components/base/tooltip'
import {
AlertDialog,
AlertDialogActions,
@@ -18,13 +9,22 @@ import {
AlertDialogContent,
AlertDialogDescription,
AlertDialogTitle,
-} from '@/app/components/base/ui/alert-dialog'
+} from '@langgenius/dify-ui/alert-dialog'
+import { cn } from '@langgenius/dify-ui/cn'
import {
DropdownMenu,
DropdownMenuContent,
DropdownMenuTrigger,
-} from '@/app/components/base/ui/dropdown-menu'
-import { toast } from '@/app/components/base/ui/toast'
+} from '@langgenius/dify-ui/dropdown-menu'
+import { Switch } from '@langgenius/dify-ui/switch'
+import { toast } from '@langgenius/dify-ui/toast'
+import { useBoolean, useDebounceFn } from 'ahooks'
+import { noop } from 'es-toolkit/function'
+import * as React from 'react'
+import { useCallback, useState } from 'react'
+import { useTranslation } from 'react-i18next'
+import Divider from '@/app/components/base/divider'
+import Tooltip from '@/app/components/base/tooltip'
import { IS_CE_EDITION } from '@/config'
import { DataSourceType, DocumentActionType } from '@/models/datasets'
import { useRouter } from '@/next/navigation'
diff --git a/web/app/components/datasets/documents/components/rename-modal.tsx b/web/app/components/datasets/documents/components/rename-modal.tsx
index c6f393f1ce..fc4626676b 100644
--- a/web/app/components/datasets/documents/components/rename-modal.tsx
+++ b/web/app/components/datasets/documents/components/rename-modal.tsx
@@ -1,13 +1,13 @@
'use client'
import type { FC } from 'react'
+import { Button } from '@langgenius/dify-ui/button'
+import { toast } from '@langgenius/dify-ui/toast'
import { useBoolean } from 'ahooks'
import * as React from 'react'
import { useState } from 'react'
import { useTranslation } from 'react-i18next'
import Input from '@/app/components/base/input'
import Modal from '@/app/components/base/modal'
-import { Button } from '@/app/components/base/ui/button'
-import { toast } from '@/app/components/base/ui/toast'
import { renameDocumentName } from '@/service/datasets'
type Props = {
diff --git a/web/app/components/datasets/documents/create-from-pipeline/__tests__/index.spec.tsx b/web/app/components/datasets/documents/create-from-pipeline/__tests__/index.spec.tsx
index 8a2e251770..7daff43a8b 100644
--- a/web/app/components/datasets/documents/create-from-pipeline/__tests__/index.spec.tsx
+++ b/web/app/components/datasets/documents/create-from-pipeline/__tests__/index.spec.tsx
@@ -569,7 +569,7 @@ describe('StepOneContent', () => {
it('should render VectorSpaceFull when isShowVectorSpaceFull is true', () => {
render()
- expect(screen.getByTestId('vector-space-full')).toBeInTheDocument()
+ expect(screen.getByTestId('vector-space-full'))!.toBeInTheDocument()
})
it('should not render VectorSpaceFull when isShowVectorSpaceFull is false', () => {
@@ -587,7 +587,7 @@ describe('StepOneContent', () => {
localFileListLength={2}
/>,
)
- expect(screen.getByTestId('upgrade-card')).toBeInTheDocument()
+ expect(screen.getByTestId('upgrade-card'))!.toBeInTheDocument()
})
it('should not render UpgradeCard when supportBatchUpload is true', () => {
@@ -618,7 +618,7 @@ describe('StepOneContent', () => {
render()
const nextButton = screen.getByRole('button', { name: /datasetCreation.stepOne.button/i })
- expect(nextButton).toBeDisabled()
+ expect(nextButton)!.toBeDisabled()
})
})
@@ -664,17 +664,17 @@ describe('StepTwoContent', () => {
it('should render ProcessDocuments component', () => {
render()
- expect(screen.getByTestId('process-documents')).toBeInTheDocument()
+ expect(screen.getByTestId('process-documents'))!.toBeInTheDocument()
})
it('should pass dataSourceNodeId to ProcessDocuments', () => {
render()
- expect(screen.getByTestId('datasource-node-id')).toHaveTextContent('custom-node')
+ expect(screen.getByTestId('datasource-node-id'))!.toHaveTextContent('custom-node')
})
it('should pass isRunning to ProcessDocuments', () => {
render()
- expect(screen.getByTestId('is-running')).toHaveTextContent('true')
+ expect(screen.getByTestId('is-running'))!.toHaveTextContent('true')
})
it('should call onProcess when process button is clicked', () => {
@@ -709,18 +709,18 @@ describe('StepThreeContent', () => {
it('should render Processing component', () => {
render()
- expect(screen.getByTestId('processing')).toBeInTheDocument()
+ expect(screen.getByTestId('processing'))!.toBeInTheDocument()
})
it('should pass batchId to Processing', () => {
render()
- expect(screen.getByTestId('batch-id')).toHaveTextContent('batch-123')
+ expect(screen.getByTestId('batch-id'))!.toHaveTextContent('batch-123')
})
it('should pass documents count to Processing', () => {
const documents = [{ id: '1' }, { id: '2' }]
render()
- expect(screen.getByTestId('documents-count')).toHaveTextContent('2')
+ expect(screen.getByTestId('documents-count'))!.toHaveTextContent('2')
})
})
@@ -787,8 +787,8 @@ describe('StepOnePreview', () => {
currentLocalFile={createMockFile()}
/>,
)
- expect(screen.getByTestId('file-preview')).toBeInTheDocument()
- expect(screen.getByTestId('file-name')).toHaveTextContent('test.txt')
+ expect(screen.getByTestId('file-preview'))!.toBeInTheDocument()
+ expect(screen.getByTestId('file-name'))!.toHaveTextContent('test.txt')
})
it('should render OnlineDocumentPreview when currentDocument is set', () => {
@@ -799,7 +799,7 @@ describe('StepOnePreview', () => {
currentDocument={createMockNotionPage()}
/>,
)
- expect(screen.getByTestId('online-document-preview')).toBeInTheDocument()
+ expect(screen.getByTestId('online-document-preview'))!.toBeInTheDocument()
})
it('should render WebsitePreview when currentWebsite is set', () => {
@@ -809,7 +809,7 @@ describe('StepOnePreview', () => {
currentWebsite={createMockCrawlResult()}
/>,
)
- expect(screen.getByTestId('web-preview')).toBeInTheDocument()
+ expect(screen.getByTestId('web-preview'))!.toBeInTheDocument()
})
it('should call hidePreviewLocalFile when hide button is clicked', () => {
@@ -868,22 +868,22 @@ describe('StepTwoPreview', () => {
it('should render ChunkPreview component', () => {
render()
- expect(screen.getByTestId('chunk-preview')).toBeInTheDocument()
+ expect(screen.getByTestId('chunk-preview'))!.toBeInTheDocument()
})
it('should pass datasourceType to ChunkPreview', () => {
render()
- expect(screen.getByTestId('datasource-type')).toHaveTextContent(DatasourceType.onlineDocument)
+ expect(screen.getByTestId('datasource-type'))!.toHaveTextContent(DatasourceType.onlineDocument)
})
it('should pass isIdle to ChunkPreview', () => {
render()
- expect(screen.getByTestId('is-idle')).toHaveTextContent('false')
+ expect(screen.getByTestId('is-idle'))!.toHaveTextContent('false')
})
it('should pass isPendingPreview to ChunkPreview', () => {
render()
- expect(screen.getByTestId('is-pending')).toHaveTextContent('true')
+ expect(screen.getByTestId('is-pending'))!.toHaveTextContent('true')
})
it('should call onPreview when preview button is clicked', () => {
@@ -1092,7 +1092,7 @@ describe('Store Hooks', () => {
mockStoreState.selectedFileIds = ['file-1']
const { result } = renderHook(() => useOnlineDrive())
expect(result.current.selectedOnlineDriveFileList).toHaveLength(1)
- expect(result.current.selectedOnlineDriveFileList[0].id).toBe('file-1')
+ expect(result.current.selectedOnlineDriveFileList[0]!.id).toBe('file-1')
})
})
})
@@ -1166,8 +1166,8 @@ describe('useDatasourceOptions', () => {
const { result } = renderHook(() => useDatasourceOptions(mockNodes))
expect(result.current).toHaveLength(1)
- expect(result.current[0].label).toBe('Local File Source')
- expect(result.current[0].value).toBe('node-1')
+ expect(result.current[0]!.label).toBe('Local File Source')
+ expect(result.current[0]!.value).toBe('node-1')
})
it('should return multiple options for multiple data source nodes', () => {
@@ -1616,7 +1616,7 @@ describe('StepOneContent - All Datasource Types', () => {
datasourceType={DatasourceType.onlineDocument}
/>,
)
- expect(screen.getByTestId('online-documents-component')).toBeInTheDocument()
+ expect(screen.getByTestId('online-documents-component'))!.toBeInTheDocument()
})
it('should render WebsiteCrawl when datasourceType is websiteCrawl', () => {
@@ -1632,7 +1632,7 @@ describe('StepOneContent - All Datasource Types', () => {
datasourceType={DatasourceType.websiteCrawl}
/>,
)
- expect(screen.getByTestId('website-crawl-component')).toBeInTheDocument()
+ expect(screen.getByTestId('website-crawl-component'))!.toBeInTheDocument()
})
it('should render OnlineDrive when datasourceType is onlineDrive', () => {
@@ -1648,7 +1648,7 @@ describe('StepOneContent - All Datasource Types', () => {
datasourceType={DatasourceType.onlineDrive}
/>,
)
- expect(screen.getByTestId('online-drive-component')).toBeInTheDocument()
+ expect(screen.getByTestId('online-drive-component'))!.toBeInTheDocument()
})
it('should render LocalFile when datasourceType is localFile', () => {
@@ -1659,7 +1659,7 @@ describe('StepOneContent - All Datasource Types', () => {
datasourceType={DatasourceType.localFile}
/>,
)
- expect(screen.getByTestId('local-file-component')).toBeInTheDocument()
+ expect(screen.getByTestId('local-file-component'))!.toBeInTheDocument()
})
})
@@ -1690,7 +1690,7 @@ describe('StepTwoPreview - File List Mapping', () => {
)
// ChunkPreview should be rendered
- expect(screen.getByTestId('chunk-preview')).toBeInTheDocument()
+ expect(screen.getByTestId('chunk-preview'))!.toBeInTheDocument()
})
})
diff --git a/web/app/components/datasets/documents/create-from-pipeline/__tests__/step-indicator.spec.tsx b/web/app/components/datasets/documents/create-from-pipeline/__tests__/step-indicator.spec.tsx
index 7103dced26..7bffe3577e 100644
--- a/web/app/components/datasets/documents/create-from-pipeline/__tests__/step-indicator.spec.tsx
+++ b/web/app/components/datasets/documents/create-from-pipeline/__tests__/step-indicator.spec.tsx
@@ -19,14 +19,14 @@ describe('StepIndicator', () => {
const { container } = render()
const dots = container.querySelectorAll('.rounded-lg')
// Second step (index 1) should be active
- expect(dots[1].className).toContain('bg-state-accent-solid')
- expect(dots[1].className).toContain('w-2')
+ expect(dots[1]!.className).toContain('bg-state-accent-solid')
+ expect(dots[1]!.className).toContain('w-2')
})
it('should not apply active style to non-current steps', () => {
const { container } = render()
const dots = container.querySelectorAll('.rounded-lg')
- expect(dots[1].className).toContain('bg-divider-solid')
- expect(dots[2].className).toContain('bg-divider-solid')
+ expect(dots[1]!.className).toContain('bg-divider-solid')
+ expect(dots[2]!.className).toContain('bg-divider-solid')
})
})
diff --git a/web/app/components/datasets/documents/create-from-pipeline/actions/index.tsx b/web/app/components/datasets/documents/create-from-pipeline/actions/index.tsx
index 64c5faac33..caffed6500 100644
--- a/web/app/components/datasets/documents/create-from-pipeline/actions/index.tsx
+++ b/web/app/components/datasets/documents/create-from-pipeline/actions/index.tsx
@@ -1,9 +1,9 @@
+import { Button } from '@langgenius/dify-ui/button'
import { RiArrowRightLine } from '@remixicon/react'
import * as React from 'react'
import { useMemo } from 'react'
import { useTranslation } from 'react-i18next'
import Checkbox from '@/app/components/base/checkbox'
-import { Button } from '@/app/components/base/ui/button'
import Link from '@/next/link'
import { useParams } from '@/next/navigation'
diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source-options/__tests__/index.spec.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source-options/__tests__/index.spec.tsx
index 0ac2dfce20..78542ad522 100644
--- a/web/app/components/datasets/documents/create-from-pipeline/data-source-options/__tests__/index.spec.tsx
+++ b/web/app/components/datasets/documents/create-from-pipeline/data-source-options/__tests__/index.spec.tsx
@@ -129,7 +129,7 @@ describe('DatasourceIcon', () => {
it('should render without crashing', () => {
const { container } = render()
- expect(container.firstChild).toBeInTheDocument()
+ expect(container.firstChild)!.toBeInTheDocument()
})
it('should render icon with background image', () => {
@@ -138,15 +138,15 @@ describe('DatasourceIcon', () => {
const { container } = render()
const iconDiv = container.querySelector('[style*="background-image"]')
- expect(iconDiv).toHaveStyle({ backgroundImage: `url(${iconUrl})` })
+ expect(iconDiv)!.toHaveStyle({ backgroundImage: `url(${iconUrl})` })
})
it('should render with default size (sm)', () => {
const { container } = render()
// Assert - Default size is 'sm' which maps to 'w-5 h-5'
- expect(container.firstChild).toHaveClass('w-5')
- expect(container.firstChild).toHaveClass('h-5')
+ expect(container.firstChild)!.toHaveClass('w-5')
+ expect(container.firstChild)!.toHaveClass('h-5')
})
})
@@ -157,9 +158,9 @@ describe('DatasourceIcon', () => {
,
)
- expect(container.firstChild).toHaveClass('w-4')
- expect(container.firstChild).toHaveClass('h-4')
- expect(container.firstChild).toHaveClass('rounded-[5px]')
+ expect(container.firstChild)!.toHaveClass('w-4')
+ expect(container.firstChild)!.toHaveClass('h-4')
+ expect(container.firstChild)!.toHaveClass('rounded-[5px]')
})
it('should render with sm size', () => {
@@ -167,9 +168,9 @@ describe('DatasourceIcon', () => {
,
)
- expect(container.firstChild).toHaveClass('w-5')
- expect(container.firstChild).toHaveClass('h-5')
- expect(container.firstChild).toHaveClass('rounded-md')
+ expect(container.firstChild)!.toHaveClass('w-5')
+ expect(container.firstChild)!.toHaveClass('h-5')
+ expect(container.firstChild)!.toHaveClass('rounded-md')
})
it('should render with md size', () => {
@@ -177,9 +178,9 @@ describe('DatasourceIcon', () => {
,
)
- expect(container.firstChild).toHaveClass('w-6')
- expect(container.firstChild).toHaveClass('h-6')
- expect(container.firstChild).toHaveClass('rounded-lg')
+ expect(container.firstChild)!.toHaveClass('w-6')
+ expect(container.firstChild)!.toHaveClass('h-6')
+ expect(container.firstChild)!.toHaveClass('rounded-lg')
})
})
@@ -189,7 +190,7 @@ describe('DatasourceIcon', () => {
,
)
- expect(container.firstChild).toHaveClass('custom-class')
+ expect(container.firstChild)!.toHaveClass('custom-class')
})
it('should merge custom className with default classes', () => {
@@ -197,9 +198,9 @@ describe('DatasourceIcon', () => {
,
)
- expect(container.firstChild).toHaveClass('custom-class')
- expect(container.firstChild).toHaveClass('w-5')
- expect(container.firstChild).toHaveClass('h-5')
+ expect(container.firstChild)!.toHaveClass('custom-class')
+ expect(container.firstChild)!.toHaveClass('w-5')
+ expect(container.firstChild)!.toHaveClass('h-5')
})
})
@@ -208,7 +209,7 @@ describe('DatasourceIcon', () => {
const { container } = render()
const iconDiv = container.querySelector('[style*="background-image"]')
- expect(iconDiv).toHaveStyle({ backgroundImage: 'url()' })
+ expect(iconDiv)!.toHaveStyle({ backgroundImage: 'url()' })
})
it('should handle special characters in iconUrl', () => {
@@ -217,7 +218,7 @@ describe('DatasourceIcon', () => {
const { container } = render()
const iconDiv = container.querySelector('[style*="background-image"]')
- expect(iconDiv).toHaveStyle({ backgroundImage: `url(${iconUrl})` })
+ expect(iconDiv)!.toHaveStyle({ backgroundImage: `url(${iconUrl})` })
})
it('should handle data URL as iconUrl', () => {
@@ -226,7 +227,7 @@ describe('DatasourceIcon', () => {
const { container } = render()
const iconDiv = container.querySelector('[style*="background-image"]')
- expect(iconDiv).toBeInTheDocument()
+ expect(iconDiv)!.toBeInTheDocument()
})
})
})
@@ -235,25 +236,25 @@ describe('DatasourceIcon', () => {
it('should have flex container classes', () => {
const { container } = render()
- expect(container.firstChild).toHaveClass('flex')
- expect(container.firstChild).toHaveClass('items-center')
- expect(container.firstChild).toHaveClass('justify-center')
+ expect(container.firstChild)!.toHaveClass('flex')
+ expect(container.firstChild)!.toHaveClass('items-center')
+ expect(container.firstChild)!.toHaveClass('justify-center')
})
it('should have shadow-xs class from size map', () => {
const { container } = render()
// Assert - Default size 'sm' has shadow-xs
- expect(container.firstChild).toHaveClass('shadow-xs')
+ expect(container.firstChild)!.toHaveClass('shadow-xs')
})
it('should have inner div with bg-cover class', () => {
const { container } = render()
const innerDiv = container.querySelector('.bg-cover')
- expect(innerDiv).toBeInTheDocument()
- expect(innerDiv).toHaveClass('bg-center')
- expect(innerDiv).toHaveClass('rounded-md')
+ expect(innerDiv)!.toBeInTheDocument()
+ expect(innerDiv)!.toHaveClass('bg-center')
+ expect(innerDiv)!.toHaveClass('rounded-md')
})
})
})
@@ -519,13 +521,13 @@ describe('OptionCard', () => {
it('should render without crashing', () => {
renderWithProviders()
- expect(screen.getByText('Test Option')).toBeInTheDocument()
+ expect(screen.getByText('Test Option'))!.toBeInTheDocument()
})
it('should render label text', () => {
renderWithProviders()
- expect(screen.getByText('Custom Label')).toBeInTheDocument()
+ expect(screen.getByText('Custom Label'))!.toBeInTheDocument()
})
it('should render DatasourceIcon component', () => {
@@ -533,7 +535,7 @@ describe('OptionCard', () => {
// Assert - DatasourceIcon container should exist
const iconContainer = container.querySelector('.size-8')
- expect(iconContainer).toBeInTheDocument()
+ expect(iconContainer)!.toBeInTheDocument()
})
it('should set title attribute for label truncation', () => {
@@ -542,7 +544,7 @@ describe('OptionCard', () => {
renderWithProviders()
const labelElement = screen.getByText(longLabel)
- expect(labelElement).toHaveAttribute('title', longLabel)
+ expect(labelElement)!.toHaveAttribute('title', longLabel)
})
})
@@ -554,8 +556,8 @@ describe('OptionCard', () => {
)
const card = container.firstChild
- expect(card).toHaveClass('border-components-option-card-option-selected-border')
- expect(card).toHaveClass('bg-components-option-card-option-selected-bg')
+ expect(card)!.toHaveClass('border-components-option-card-option-selected-border')
+ expect(card)!.toHaveClass('bg-components-option-card-option-selected-bg')
})
it('should apply unselected styles when selected is false', () => {
@@ -564,22 +566,22 @@ describe('OptionCard', () => {
)
const card = container.firstChild
- expect(card).toHaveClass('border-components-option-card-option-border')
- expect(card).toHaveClass('bg-components-option-card-option-bg')
+ expect(card)!.toHaveClass('border-components-option-card-option-border')
+ expect(card)!.toHaveClass('bg-components-option-card-option-bg')
})
it('should apply text-text-primary to label when selected', () => {
renderWithProviders()
const label = screen.getByText('Test Option')
- expect(label).toHaveClass('text-text-primary')
+ expect(label)!.toHaveClass('text-text-primary')
})
it('should apply text-text-secondary to label when not selected', () => {
renderWithProviders()
const label = screen.getByText('Test Option')
- expect(label).toHaveClass('text-text-secondary')
+ expect(label)!.toHaveClass('text-text-secondary')
})
})
@@ -593,7 +595,7 @@ describe('OptionCard', () => {
// Act - Click on the label text's parent card
const labelElement = screen.getByText('Test Option')
const card = labelElement.closest('[class*="cursor-pointer"]')
- expect(card).toBeInTheDocument()
+ expect(card)!.toBeInTheDocument()
fireEvent.click(card!)
expect(mockOnClick).toHaveBeenCalledTimes(1)
@@ -607,11 +609,11 @@ describe('OptionCard', () => {
// Act - Click on the label text's parent card should not throw
const labelElement = screen.getByText('Test Option')
const card = labelElement.closest('[class*="cursor-pointer"]')
- expect(card).toBeInTheDocument()
+ expect(card)!.toBeInTheDocument()
fireEvent.click(card!)
// Assert - Component should still be rendered
- expect(screen.getByText('Test Option')).toBeInTheDocument()
+ expect(screen.getByText('Test Option'))!.toBeInTheDocument()
})
})
@@ -631,35 +634,35 @@ describe('OptionCard', () => {
it('should have cursor-pointer class', () => {
const { container } = renderWithProviders()
- expect(container.firstChild).toHaveClass('cursor-pointer')
+ expect(container.firstChild)!.toHaveClass('cursor-pointer')
})
it('should have flex layout classes', () => {
const { container } = renderWithProviders()
- expect(container.firstChild).toHaveClass('flex')
- expect(container.firstChild).toHaveClass('items-center')
- expect(container.firstChild).toHaveClass('gap-2')
+ expect(container.firstChild)!.toHaveClass('flex')
+ expect(container.firstChild)!.toHaveClass('items-center')
+ expect(container.firstChild)!.toHaveClass('gap-2')
})
it('should have rounded-xl border', () => {
const { container } = renderWithProviders()
- expect(container.firstChild).toHaveClass('rounded-xl')
- expect(container.firstChild).toHaveClass('border')
+ expect(container.firstChild)!.toHaveClass('rounded-xl')
+ expect(container.firstChild)!.toHaveClass('border')
})
it('should have padding p-3', () => {
const { container } = renderWithProviders()
- expect(container.firstChild).toHaveClass('p-3')
+ expect(container.firstChild)!.toHaveClass('p-3')
})
it('should have line-clamp-2 for label truncation', () => {
renderWithProviders()
const label = screen.getByText('Test Option')
- expect(label).toHaveClass('line-clamp-2')
+ expect(label)!.toHaveClass('line-clamp-2')
})
})
@@ -669,7 +672,7 @@ describe('OptionCard', () => {
expect(OptionCard).toBeDefined()
// React.memo wraps the component, so we check it renders correctly
const { container } = renderWithProviders()
- expect(container.firstChild).toBeInTheDocument()
+ expect(container.firstChild)!.toBeInTheDocument()
})
})
})
@@ -698,27 +701,27 @@ describe('DataSourceOptions', () => {
it('should render without crashing', () => {
renderWithProviders()
- expect(screen.getByText('Data Source 1')).toBeInTheDocument()
- expect(screen.getByText('Data Source 2')).toBeInTheDocument()
- expect(screen.getByText('Data Source 3')).toBeInTheDocument()
+ expect(screen.getByText('Data Source 1'))!.toBeInTheDocument()
+ expect(screen.getByText('Data Source 2'))!.toBeInTheDocument()
+ expect(screen.getByText('Data Source 3'))!.toBeInTheDocument()
})
it('should render correct number of option cards', () => {
renderWithProviders()
- expect(screen.getByText('Data Source 1')).toBeInTheDocument()
- expect(screen.getByText('Data Source 2')).toBeInTheDocument()
- expect(screen.getByText('Data Source 3')).toBeInTheDocument()
+ expect(screen.getByText('Data Source 1'))!.toBeInTheDocument()
+ expect(screen.getByText('Data Source 2'))!.toBeInTheDocument()
+ expect(screen.getByText('Data Source 3'))!.toBeInTheDocument()
})
it('should render with grid layout', () => {
const { container } = renderWithProviders()
const gridContainer = container.firstChild
- expect(gridContainer).toHaveClass('grid')
- expect(gridContainer).toHaveClass('w-full')
- expect(gridContainer).toHaveClass('grid-cols-4')
- expect(gridContainer).toHaveClass('gap-1')
+ expect(gridContainer)!.toHaveClass('grid')
+ expect(gridContainer)!.toHaveClass('w-full')
+ expect(gridContainer)!.toHaveClass('grid-cols-4')
+ expect(gridContainer)!.toHaveClass('gap-1')
})
it('should render no option cards when options is empty', () => {
@@ -728,16 +731,16 @@ describe('DataSourceOptions', () => {
expect(screen.queryByText('Data Source')).not.toBeInTheDocument()
// Grid container should still exist
- expect(container.firstChild).toHaveClass('grid')
+ expect(container.firstChild)!.toHaveClass('grid')
})
it('should render single option card when only one option exists', () => {
- const singleOption = [createMockDatasourceOption(defaultNodes[0])]
+ const singleOption = [createMockDatasourceOption(defaultNodes[0]!)]
mockUseDatasourceOptions.mockReturnValue(singleOption)
renderWithProviders()
- expect(screen.getByText('Data Source 1')).toBeInTheDocument()
+ expect(screen.getByText('Data Source 1'))!.toBeInTheDocument()
expect(screen.queryByText('Data Source 2')).not.toBeInTheDocument()
})
})
@@ -778,7 +782,7 @@ describe('DataSourceOptions', () => {
// Assert - Check for selected styling on second card
const cards = container.querySelectorAll('.rounded-xl.border')
- expect(cards[1]).toHaveClass('border-components-option-card-option-selected-border')
+ expect(cards[1])!.toHaveClass('border-components-option-card-option-selected-border')
})
it('should show no selection when datasourceNodeId is empty', () => {
@@ -816,7 +820,7 @@ describe('DataSourceOptions', () => {
// Assert initial selection
let cards = container.querySelectorAll('.rounded-xl.border')
- expect(cards[0]).toHaveClass('border-components-option-card-option-selected-border')
+ expect(cards[0])!.toHaveClass('border-components-option-card-option-selected-border')
// Act - Change selection
rerender(
@@ -831,7 +835,7 @@ describe('DataSourceOptions', () => {
// Assert new selection
cards = container.querySelectorAll('.rounded-xl.border')
expect(cards[0]).not.toHaveClass('border-components-option-card-option-selected-border')
- expect(cards[1]).toHaveClass('border-components-option-card-option-selected-border')
+ expect(cards[1])!.toHaveClass('border-components-option-card-option-selected-border')
})
})
@@ -847,7 +851,7 @@ describe('DataSourceOptions', () => {
)
// Assert - Component renders without error
- expect(screen.getByText('Data Source 1')).toBeInTheDocument()
+ expect(screen.getByText('Data Source 1'))!.toBeInTheDocument()
})
})
})
@@ -870,7 +875,7 @@ describe('DataSourceOptions', () => {
expect(mockOnSelect).toHaveBeenCalledTimes(1)
expect(mockOnSelect).toHaveBeenCalledWith({
nodeId: 'node-1',
- nodeData: defaultOptions[0].data,
+ nodeData: defaultOptions[0]!.data,
} satisfies Datasource)
})
@@ -948,7 +953,7 @@ describe('DataSourceOptions', () => {
)
// Get initial click handlers
- expect(screen.getByText('Data Source 1')).toBeInTheDocument()
+ expect(screen.getByText('Data Source 1'))!.toBeInTheDocument()
// Trigger clicks to test handlers work
fireEvent.click(screen.getByText('Data Source 1'))
@@ -1003,7 +1009,7 @@ describe('DataSourceOptions', () => {
expect(mockOnSelect2).toHaveBeenCalledTimes(1)
expect(mockOnSelect2).toHaveBeenCalledWith({
nodeId: 'node-3',
- nodeData: defaultOptions[2].data,
+ nodeData: defaultOptions[2]!.data,
})
})
@@ -1022,7 +1028,7 @@ describe('DataSourceOptions', () => {
fireEvent.click(screen.getByText('Data Source 1'))
expect(mockOnSelect).toHaveBeenCalledWith({
nodeId: 'node-1',
- nodeData: defaultOptions[0].data,
+ nodeData: defaultOptions[0]!.data,
})
// Act - Change options
@@ -1045,8 +1051,8 @@ describe('DataSourceOptions', () => {
// Assert - Callback receives new option data
expect(mockOnSelect).toHaveBeenLastCalledWith({
- nodeId: newOptions[0].value,
- nodeData: newOptions[0].data,
+ nodeId: newOptions[0]!.value,
+ nodeData: newOptions[0]!.data,
})
})
})
@@ -1070,7 +1076,7 @@ describe('DataSourceOptions', () => {
expect(mockOnSelect).toHaveBeenCalledTimes(1)
expect(mockOnSelect).toHaveBeenCalledWith({
nodeId: 'node-2',
- nodeData: defaultOptions[1].data,
+ nodeData: defaultOptions[1]!.data,
} satisfies Datasource)
})
@@ -1090,7 +1096,7 @@ describe('DataSourceOptions', () => {
expect(mockOnSelect).toHaveBeenCalledTimes(1)
expect(mockOnSelect).toHaveBeenCalledWith({
nodeId: 'node-1',
- nodeData: defaultOptions[0].data,
+ nodeData: defaultOptions[0]!.data,
})
})
@@ -1112,15 +1118,15 @@ describe('DataSourceOptions', () => {
expect(mockOnSelect).toHaveBeenCalledTimes(3)
expect(mockOnSelect).toHaveBeenNthCalledWith(1, {
nodeId: 'node-1',
- nodeData: defaultOptions[0].data,
+ nodeData: defaultOptions[0]!.data,
})
expect(mockOnSelect).toHaveBeenNthCalledWith(2, {
nodeId: 'node-2',
- nodeData: defaultOptions[1].data,
+ nodeData: defaultOptions[1]!.data,
})
expect(mockOnSelect).toHaveBeenNthCalledWith(3, {
nodeId: 'node-3',
- nodeData: defaultOptions[2].data,
+ nodeData: defaultOptions[2]!.data,
})
})
})
@@ -1164,7 +1170,7 @@ describe('DataSourceOptions', () => {
/>,
)
- expect(container.firstChild).toBeInTheDocument()
+ expect(container.firstChild)!.toBeInTheDocument()
})
it('should not crash when datasourceNodeId is undefined', () => {
@@ -1176,7 +1182,7 @@ describe('DataSourceOptions', () => {
/>,
)
- expect(screen.getByText('Data Source 1')).toBeInTheDocument()
+ expect(screen.getByText('Data Source 1'))!.toBeInTheDocument()
})
})
@@ -1202,7 +1208,7 @@ describe('DataSourceOptions', () => {
renderWithProviders()
- expect(screen.getByText('Minimal Option')).toBeInTheDocument()
+ expect(screen.getByText('Minimal Option'))!.toBeInTheDocument()
})
})
@@ -1219,8 +1225,8 @@ describe('DataSourceOptions', () => {
/>,
)
- expect(screen.getByText('Data Source 1')).toBeInTheDocument()
- expect(screen.getByText('Data Source 50')).toBeInTheDocument()
+ expect(screen.getByText('Data Source 1'))!.toBeInTheDocument()
+ expect(screen.getByText('Data Source 50'))!.toBeInTheDocument()
})
})
@@ -1243,7 +1249,7 @@ describe('DataSourceOptions', () => {
)
// Assert - Special characters should be escaped/rendered safely
- expect(screen.getByText('Data Source ')).toBeInTheDocument()
+ expect(screen.getByText('Data Source '))!.toBeInTheDocument()
})
it('should handle unicode characters in option labels', () => {
@@ -1263,7 +1270,7 @@ describe('DataSourceOptions', () => {
/>,
)
- expect(screen.getByText('数据源 📁 Source émoji')).toBeInTheDocument()
+ expect(screen.getByText('数据源 📁 Source émoji'))!.toBeInTheDocument()
})
it('should handle empty string as option value', () => {
@@ -1276,13 +1283,13 @@ describe('DataSourceOptions', () => {
renderWithProviders()
- expect(screen.getByText('Empty Value Option')).toBeInTheDocument()
+ expect(screen.getByText('Empty Value Option'))!.toBeInTheDocument()
})
})
describe('Boundary Conditions', () => {
it('should handle single option selection correctly', () => {
- const singleOption = [createMockDatasourceOption(defaultNodes[0])]
+ const singleOption = [createMockDatasourceOption(defaultNodes[0]!)]
mockUseDatasourceOptions.mockReturnValue(singleOption)
const mockOnSelect = vi.fn()
@@ -1327,7 +1334,7 @@ describe('DataSourceOptions', () => {
const labels = screen.getAllByText('Duplicate Label')
expect(labels).toHaveLength(2)
- fireEvent.click(labels[1])
+ fireEvent.click(labels[1]!)
expect(mockOnSelect).toHaveBeenCalledWith({
nodeId: 'node-b',
nodeData: expect.objectContaining({ plugin_id: 'plugin-b' }),
@@ -1392,7 +1461,7 @@ describe('DataSourceOptions', () => {
const cards = container.querySelectorAll('.rounded-xl.border')
expect(cards[0]).not.toHaveClass('border-components-option-card-option-selected-border')
- expect(cards[1]).toHaveClass('border-components-option-card-option-selected-border')
+ expect(cards[1]!).toHaveClass('border-components-option-card-option-selected-border')
expect(cards[2]).not.toHaveClass('border-components-option-card-option-selected-border')
})
@@ -1427,7 +1496,7 @@ describe('DataSourceOptions', () => {
/>,
)
- expect(screen.getByText('Data Source 1')).toBeInTheDocument()
+ expect(screen.getByText('Data Source 1'))!.toBeInTheDocument()
})
it.each([
@@ -1449,7 +1518,7 @@ describe('DataSourceOptions', () => {
)
if (count > 0)
- expect(screen.getByText('Data Source 1')).toBeInTheDocument()
+ expect(screen.getByText('Data Source 1'))!.toBeInTheDocument()
else
expect(screen.queryByText('Data Source 1')).not.toBeInTheDocument()
})
diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source-options/index.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source-options/index.tsx
index 8e3a29cfe5..51cf34d273 100644
--- a/web/app/components/datasets/documents/create-from-pipeline/data-source-options/index.tsx
+++ b/web/app/components/datasets/documents/create-from-pipeline/data-source-options/index.tsx
@@ -31,7 +31,7 @@ const DataSourceOptions = ({
useEffect(() => {
if (options.length > 0 && !datasourceNodeId)
- handelSelect(options[0].value)
+ handelSelect(options[0]!.value)
}, [])
return (
diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/base/__tests__/header.spec.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source/base/__tests__/header.spec.tsx
index b736935cc8..a6abad358e 100644
--- a/web/app/components/datasets/documents/create-from-pipeline/data-source/base/__tests__/header.spec.tsx
+++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/base/__tests__/header.spec.tsx
@@ -2,7 +2,7 @@ import { render, screen } from '@testing-library/react'
import { describe, expect, it, vi } from 'vitest'
import Header from '../header'
-vi.mock('@/app/components/base/ui/button', () => ({
+vi.mock('@langgenius/dify-ui/button', () => ({
Button: ({ children }: { children: React.ReactNode }) => <button>{children}</button>,
}))
diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/base/credential-selector/__tests__/index.spec.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source/base/credential-selector/__tests__/index.spec.tsx
index d595a50fe1..49b0cb0789 100644
--- a/web/app/components/datasets/documents/create-from-pipeline/data-source/base/credential-selector/__tests__/index.spec.tsx
+++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/base/credential-selector/__tests__/index.spec.tsx
@@ -97,8 +97,8 @@ describe('CredentialSelector', () => {
render()
- expect(screen.getByTestId('portal-root')).toBeInTheDocument()
- expect(screen.getByTestId('portal-trigger')).toBeInTheDocument()
+ expect(screen.getByTestId('portal-root'))!.toBeInTheDocument()
+ expect(screen.getByTestId('portal-trigger'))!.toBeInTheDocument()
})
it('should render current credential name in trigger', () => {
@@ -106,7 +106,7 @@ describe('CredentialSelector', () => {
render()
- expect(screen.getByText('Credential 1')).toBeInTheDocument()
+ expect(screen.getByText('Credential 1'))!.toBeInTheDocument()
})
it('should render credential icon with correct props', () => {
@@ -116,8 +116,8 @@ describe('CredentialSelector', () => {
// Assert - CredentialIcon renders an img when avatarUrl is provided
const iconImg = container.querySelector('img')
- expect(iconImg).toBeInTheDocument()
- expect(iconImg).toHaveAttribute('src', 'https://example.com/avatar-1.png')
+ expect(iconImg)!.toBeInTheDocument()
+ expect(iconImg)!.toHaveAttribute('src', 'https://example.com/avatar-1.png')
})
it('should render dropdown arrow icon', () => {
@@ -126,7 +126,7 @@ describe('CredentialSelector', () => {
const { container } = render()
const svgIcon = container.querySelector('svg')
- expect(svgIcon).toBeInTheDocument()
+ expect(svgIcon)!.toBeInTheDocument()
})
it('should not render dropdown content initially', () => {
@@ -146,7 +146,7 @@ describe('CredentialSelector', () => {
fireEvent.click(trigger)
// Assert - All credentials should be visible (current credential appears in both trigger and list)
- expect(screen.getByTestId('portal-content')).toBeInTheDocument()
+ expect(screen.getByTestId('portal-content'))!.toBeInTheDocument()
// 3 in dropdown list + 1 in trigger (current) = 4 total
expect(screen.getAllByText(/Credential \d/)).toHaveLength(4)
})
@@ -160,7 +161,7 @@ describe('CredentialSelector', () => {
render()
- expect(screen.getByText('Credential 1')).toBeInTheDocument()
+ expect(screen.getByText('Credential 1'))!.toBeInTheDocument()
})
it('should display second credential when currentCredentialId matches second', () => {
@@ -168,7 +169,7 @@ describe('CredentialSelector', () => {
render()
- expect(screen.getByText('Credential 2')).toBeInTheDocument()
+ expect(screen.getByText('Credential 2'))!.toBeInTheDocument()
})
it('should display third credential when currentCredentialId matches third', () => {
@@ -176,7 +177,7 @@ describe('CredentialSelector', () => {
render()
- expect(screen.getByText('Credential 3')).toBeInTheDocument()
+ expect(screen.getByText('Credential 3'))!.toBeInTheDocument()
})
it.each([
@@ -188,7 +189,7 @@ describe('CredentialSelector', () => {
render()
- expect(screen.getByText(expectedName)).toBeInTheDocument()
+ expect(screen.getByText(expectedName))!.toBeInTheDocument()
})
})
@@ -201,7 +202,7 @@ describe('CredentialSelector', () => {
render()
- expect(screen.getByText('Test Credential')).toBeInTheDocument()
+ expect(screen.getByText('Test Credential'))!.toBeInTheDocument()
})
it('should render multiple credentials in dropdown', () => {
@@ -226,7 +227,7 @@ describe('CredentialSelector', () => {
render()
- expect(screen.getByText('Test & Credential ')).toBeInTheDocument()
+ expect(screen.getByText('Test & Credential '))!.toBeInTheDocument()
})
})
@@ -293,6 +294,7 @@ describe('CredentialSelector', () => {
const props = createDefaultProps()
render()
+ // Assert - Initially closed
// Assert - Initially closed
expect(screen.queryByTestId('portal-content')).not.toBeInTheDocument()
@@ -301,7 +333,7 @@ describe('CredentialSelector', () => {
fireEvent.click(trigger)
// Assert - Now open
- expect(screen.getByTestId('portal-content')).toBeInTheDocument()
+ expect(screen.getByTestId('portal-content'))!.toBeInTheDocument()
})
it('should call onCredentialChange when clicking a credential item', () => {
@@ -327,7 +360,7 @@ describe('CredentialSelector', () => {
const trigger = screen.getByTestId('portal-trigger')
fireEvent.click(trigger)
- expect(screen.getByTestId('portal-content')).toBeInTheDocument()
+ expect(screen.getByTestId('portal-content'))!.toBeInTheDocument()
const credential2 = screen.getByText('Credential 2')
fireEvent.click(credential2)
@@ -347,7 +380,7 @@ describe('CredentialSelector', () => {
fireEvent.click(trigger)
// Assert - Should not crash
- expect(trigger).toBeInTheDocument()
+ expect(trigger)!.toBeInTheDocument()
})
it('should allow selecting credentials multiple times', () => {
@@ -504,7 +538,7 @@ describe('CredentialSelector', () => {
render()
// Assert - Should display credential 2
- expect(screen.getByText('Credential 2')).toBeInTheDocument()
+ expect(screen.getByText('Credential 2'))!.toBeInTheDocument()
})
it('should update currentCredential when currentCredentialId changes', () => {
@@ -512,13 +547,13 @@ describe('CredentialSelector', () => {
const { rerender } = render()
// Assert initial
- expect(screen.getByText('Credential 1')).toBeInTheDocument()
+ expect(screen.getByText('Credential 1'))!.toBeInTheDocument()
// Act - Change currentCredentialId
rerender()
// Assert - Should now display credential 3
- expect(screen.getByText('Credential 3')).toBeInTheDocument()
+ expect(screen.getByText('Credential 3'))!.toBeInTheDocument()
})
it('should update currentCredential when credentials array changes', () => {
@@ -526,7 +563,7 @@ describe('CredentialSelector', () => {
const { rerender } = render()
// Assert initial
- expect(screen.getByText('Credential 1')).toBeInTheDocument()
+ expect(screen.getByText('Credential 1'))!.toBeInTheDocument()
// Act - Change credentials
const newCredentials = [
@@ -535,7 +573,7 @@ describe('CredentialSelector', () => {
rerender()
// Assert - Should display updated name
- expect(screen.getByText('Updated Credential 1')).toBeInTheDocument()
+ expect(screen.getByText('Updated Credential 1'))!.toBeInTheDocument()
})
it('should return undefined currentCredential when id not found', () => {
@@ -581,11 +620,11 @@ describe('CredentialSelector', () => {
const { rerender } = render()
// Assert initial
- expect(screen.getByText('Credential 1')).toBeInTheDocument()
+ expect(screen.getByText('Credential 1'))!.toBeInTheDocument()
rerender()
- expect(screen.getByText('Credential 2')).toBeInTheDocument()
+ expect(screen.getByText('Credential 2'))!.toBeInTheDocument()
})
it('should re-render when credentials array reference changes', () => {
@@ -598,7 +638,7 @@ describe('CredentialSelector', () => {
]
rerender()
- expect(screen.getByText('New Name 1')).toBeInTheDocument()
+ expect(screen.getByText('New Name 1'))!.toBeInTheDocument()
})
it('should re-render when onCredentialChange reference changes', () => {
@@ -631,7 +671,7 @@ describe('CredentialSelector', () => {
render()
// Assert - Should render without crashing
- expect(screen.getByTestId('portal-root')).toBeInTheDocument()
+ expect(screen.getByTestId('portal-root'))!.toBeInTheDocument()
})
it('should handle undefined avatar_url in credential', () => {
@@ -648,12 +689,12 @@ describe('CredentialSelector', () => {
const { container } = render()
// Assert - Should render without crashing and show first letter fallback
- expect(screen.getByText('No Avatar Credential')).toBeInTheDocument()
+ expect(screen.getByText('No Avatar Credential'))!.toBeInTheDocument()
// When avatar_url is undefined, CredentialIcon shows first letter instead of img
const iconImg = container.querySelector('img')
expect(iconImg).not.toBeInTheDocument()
// First letter 'N' should be displayed
- expect(screen.getByText('N')).toBeInTheDocument()
+ expect(screen.getByText('N'))!.toBeInTheDocument()
})
it('should handle empty string name in credential', () => {
@@ -669,7 +712,7 @@ describe('CredentialSelector', () => {
render()
// Assert - Should render without crashing
- expect(screen.getByTestId('portal-trigger')).toBeInTheDocument()
+ expect(screen.getByTestId('portal-trigger'))!.toBeInTheDocument()
})
it('should handle very long credential name', () => {
@@ -685,7 +729,7 @@ describe('CredentialSelector', () => {
render()
- expect(screen.getByText(longName)).toBeInTheDocument()
+ expect(screen.getByText(longName))!.toBeInTheDocument()
})
it('should handle special characters in credential name', () => {
@@ -701,7 +745,7 @@ describe('CredentialSelector', () => {
render()
- expect(screen.getByText(specialName)).toBeInTheDocument()
+ expect(screen.getByText(specialName))!.toBeInTheDocument()
})
it('should handle numeric id as string', () => {
@@ -716,7 +760,7 @@ describe('CredentialSelector', () => {
render()
- expect(screen.getByText('Numeric ID Credential')).toBeInTheDocument()
+ expect(screen.getByText('Numeric ID Credential'))!.toBeInTheDocument()
})
it('should handle large number of credentials', () => {
@@ -728,7 +772,7 @@ describe('CredentialSelector', () => {
render()
- expect(screen.getByText('Credential 50')).toBeInTheDocument()
+ expect(screen.getByText('Credential 50'))!.toBeInTheDocument()
})
it('should handle credential selection with duplicate names', () => {
@@ -752,7 +796,7 @@ describe('CredentialSelector', () => {
const sameNameElements = screen.getAllByText('Same Name')
expect(sameNameElements.length).toBe(3)
- fireEvent.click(sameNameElements[2])
+ fireEvent.click(sameNameElements[2]!)
// Assert - Should call with the correct id even with duplicate names
expect(mockOnChange).toHaveBeenCalledWith('cred-2')
@@ -787,7 +831,7 @@ describe('CredentialSelector', () => {
render()
// Assert - Should render without crashing
- expect(screen.getByTestId('portal-trigger')).toBeInTheDocument()
+ expect(screen.getByTestId('portal-trigger'))!.toBeInTheDocument()
})
})
@@ -799,7 +844,7 @@ describe('CredentialSelector', () => {
render()
const trigger = screen.getByTestId('portal-trigger')
- expect(trigger).toHaveClass('overflow-hidden')
+ expect(trigger)!.toHaveClass('overflow-hidden')
})
it('should apply grow class to trigger', () => {
@@ -808,7 +853,7 @@ describe('CredentialSelector', () => {
render()
const trigger = screen.getByTestId('portal-trigger')
- expect(trigger).toHaveClass('grow')
+ expect(trigger)!.toHaveClass('grow')
})
it('should apply z-10 class to dropdown content', () => {
@@ -819,7 +864,7 @@ describe('CredentialSelector', () => {
fireEvent.click(trigger)
const content = screen.getByTestId('portal-content')
- expect(content).toHaveClass('z-10')
+ expect(content)!.toHaveClass('z-10')
})
})
@@ -831,7 +876,7 @@ describe('CredentialSelector', () => {
render()
// Assert - Trigger should display the correct credential
- expect(screen.getByText('Credential 2')).toBeInTheDocument()
+ expect(screen.getByText('Credential 2'))!.toBeInTheDocument()
})
it('should pass isOpen state to Trigger component', () => {
@@ -840,14 +886,14 @@ describe('CredentialSelector', () => {
// Assert - Initially closed
const portalRoot = screen.getByTestId('portal-root')
- expect(portalRoot).toHaveAttribute('data-open', 'false')
+ expect(portalRoot)!.toHaveAttribute('data-open', 'false')
// Act - Open
const trigger = screen.getByTestId('portal-trigger')
fireEvent.click(trigger)
// Assert - Now open
- expect(portalRoot).toHaveAttribute('data-open', 'true')
+ expect(portalRoot)!.toHaveAttribute('data-open', 'true')
})
it('should pass credentials to List component', () => {
@@ -899,7 +946,7 @@ describe('CredentialSelector', () => {
const props = createDefaultProps()
render()
- expect(screen.getByTestId('portal-root')).toBeInTheDocument()
+ expect(screen.getByTestId('portal-root'))!.toBeInTheDocument()
})
it('should configure PortalToFollowElem with offset mainAxis 4', () => {
@@ -907,7 +954,7 @@ describe('CredentialSelector', () => {
const props = createDefaultProps()
render()
- expect(screen.getByTestId('portal-root')).toBeInTheDocument()
+ expect(screen.getByTestId('portal-root'))!.toBeInTheDocument()
})
})
})
diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/base/credential-selector/index.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source/base/credential-selector/index.tsx
index 2f14b0f3b8..116b762277 100644
--- a/web/app/components/datasets/documents/create-from-pipeline/data-source/base/credential-selector/index.tsx
+++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/base/credential-selector/index.tsx
@@ -29,7 +29,7 @@ const CredentialSelector = ({
useEffect(() => {
if (!currentCredential && credentials.length)
- onCredentialChange(credentials[0].id)
+ onCredentialChange(credentials[0]!.id)
}, [currentCredential, credentials])
const handleCredentialChange = useCallback((credentialId: string) => {
diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/base/credential-selector/item.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source/base/credential-selector/item.tsx
index 4d54a04d1f..b162411f6c 100644
--- a/web/app/components/datasets/documents/create-from-pipeline/data-source/base/credential-selector/item.tsx
+++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/base/credential-selector/item.tsx
@@ -31,7 +31,7 @@ const Item = ({
name={name}
size={20}
/>
-
+
{name}
{
diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/base/header.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source/base/header.tsx
index 034556d96f..a285946272 100644
--- a/web/app/components/datasets/documents/create-from-pipeline/data-source/base/header.tsx
+++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/base/header.tsx
@@ -1,10 +1,10 @@
import type { CredentialSelectorProps } from './credential-selector'
+import { Button } from '@langgenius/dify-ui/button'
import { RiBookOpenLine, RiEqualizer2Line } from '@remixicon/react'
import * as React from 'react'
import { useTranslation } from 'react-i18next'
import Divider from '@/app/components/base/divider'
import Tooltip from '@/app/components/base/tooltip'
-import { Button } from '@/app/components/base/ui/button'
import CredentialSelector from './credential-selector'
type HeaderProps = {
diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/hooks/__tests__/use-local-file-upload.spec.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/hooks/__tests__/use-local-file-upload.spec.tsx
index cc531aad8f..dc20688e9e 100644
--- a/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/hooks/__tests__/use-local-file-upload.spec.tsx
+++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/local-file/hooks/__tests__/use-local-file-upload.spec.tsx
@@ -18,7 +18,7 @@ const { mockNotify, mockToast } = vi.hoisted(() => {
return { mockNotify, mockToast }
})
-vi.mock('@/app/components/base/ui/toast', () => ({
+vi.mock('@langgenius/dify-ui/toast', () => ({
toast: mockToast,
}))
@@ -404,7 +404,7 @@ describe('useLocalFileUpload', () => {
// Should only process first 5 files (batch_count_limit)
const firstCall = mockSetLocalFileList.mock.calls[0]
- expect(firstCall[0].length).toBeLessThanOrEqual(5)
+ expect(firstCall![0].length).toBeLessThanOrEqual(5)
})
})
@@ -591,7 +591,7 @@ describe('useLocalFileUpload', () => {
})
// dragover should not throw
- expect(dropzone).toBeInTheDocument()
+ expect(dropzone)!.toBeInTheDocument()
})
it('should set dragging false on dragleave from drag overlay', async () => {
@@ -715,7 +716,7 @@ describe('useLocalFileUpload', () => {
await waitFor(() => {
expect(mockSetLocalFileList).toHaveBeenCalled()
// Should only have 1 file (limited by supportBatchUpload: false)
- const callArgs = mockSetLocalFileList.mock.calls[0][0]
+ const callArgs = mockSetLocalFileList.mock.calls[0]![0]
expect(callArgs.length).toBe(1)
})
})
@@ -873,7 +874,7 @@ describe('useLocalFileUpload', () => {
})
await waitFor(() => {
- const callArgs = mockSetLocalFileList.mock.calls[0][0]
+ const callArgs = mockSetLocalFileList.mock.calls[0]![0]
expect(callArgs[0].progress).toBe(PROGRESS_NOT_STARTED)
})
})
@@ -899,7 +900,7 @@ describe('useLocalFileUpload', () => {
await waitFor(() => {
const calls = mockSetLocalFileList.mock.calls
- const lastCall = calls[calls.length - 1][0]
+ const lastCall = calls[calls.length - 1]![0]
expect(lastCall.some((f: FileItem) => f.progress === PROGRESS_ERROR)).toBe(true)
})
})
diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/online-documents/__tests__/index.spec.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source/online-documents/__tests__/index.spec.tsx
index 6be0e28d31..c193638a6a 100644
--- a/web/app/components/datasets/documents/create-from-pipeline/data-source/online-documents/__tests__/index.spec.tsx
+++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/online-documents/__tests__/index.spec.tsx
@@ -37,8 +37,8 @@ const { mockToastError } = vi.hoisted(() => ({
mockToastError: vi.fn(),
}))
-vi.mock('@/app/components/base/ui/toast', async (importOriginal) => {
-  const actual = await importOriginal<typeof import('@/app/components/base/ui/toast')>()
+vi.mock('@langgenius/dify-ui/toast', async (importOriginal) => {
+  const actual = await importOriginal<typeof import('@langgenius/dify-ui/toast')>()
return {
...actual,
toast: {
diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/online-documents/index.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source/online-documents/index.tsx
index 5051d343cb..22bc8a65e0 100644
--- a/web/app/components/datasets/documents/create-from-pipeline/data-source/online-documents/index.tsx
+++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/online-documents/index.tsx
@@ -1,11 +1,11 @@
import type { DataSourceNodeType } from '@/app/components/workflow/nodes/data-source/types'
import type { DataSourceNotionPageMap, DataSourceNotionWorkspace } from '@/models/common'
import type { DataSourceNodeCompletedResponse, DataSourceNodeErrorResponse } from '@/types/pipeline'
+import { toast } from '@langgenius/dify-ui/toast'
import { useCallback, useEffect, useMemo } from 'react'
import { useShallow } from 'zustand/react/shallow'
import Loading from '@/app/components/base/loading'
import SearchInput from '@/app/components/base/notion-page-selector/search-input'
-import { toast } from '@/app/components/base/ui/toast'
import { ACCOUNT_SETTING_TAB } from '@/app/components/header/account-setting/constants'
import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail'
import { useDocLink } from '@/context/i18n'
@@ -115,7 +115,7 @@ const OnlineDocuments = ({
const handleSelectPages = useCallback((newSelectedPagesId: Set<string>) => {
const { setSelectedPagesId, setOnlineDocuments } = dataSourceStore.getState()
- const selectedPages = Array.from(newSelectedPagesId).map(pageId => PagesMapAndSelectedPagesId[pageId])
+ const selectedPages = Array.from(newSelectedPagesId).map(pageId => PagesMapAndSelectedPagesId[pageId]!)
setSelectedPagesId(new Set(Array.from(newSelectedPagesId)))
setOnlineDocuments(selectedPages)
}, [dataSourceStore, PagesMapAndSelectedPagesId])
@@ -160,7 +160,7 @@ const OnlineDocuments = ({
checkedIds={selectedPagesId}
disabledValue={new Set()}
searchValue={searchValue}
- list={documentsData[0].pages || []}
+ list={documentsData[0]!.pages || []}
pagesMap={PagesMapAndSelectedPagesId}
onSelect={handleSelectPages}
canPreview={!isInPipeline}
diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/online-documents/page-selector/__tests__/index.spec.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source/online-documents/page-selector/__tests__/index.spec.tsx
index a6d5738e2d..04676156e6 100644
--- a/web/app/components/datasets/documents/create-from-pipeline/data-source/online-documents/page-selector/__tests__/index.spec.tsx
+++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/online-documents/page-selector/__tests__/index.spec.tsx
@@ -83,7 +83,7 @@ describe('PageSelector', () => {
render()
- expect(screen.getByTestId('virtual-list')).toBeInTheDocument()
+ expect(screen.getByTestId('virtual-list'))!.toBeInTheDocument()
})
it('should render empty state when list is empty', () => {
@@ -94,7 +94,7 @@ describe('PageSelector', () => {
render()
- expect(screen.getByText('common.dataSource.notion.selector.noSearchResult')).toBeInTheDocument()
+ expect(screen.getByText('common.dataSource.notion.selector.noSearchResult'))!.toBeInTheDocument()
expect(screen.queryByTestId('virtual-list')).not.toBeInTheDocument()
})
@@ -110,8 +110,8 @@ describe('PageSelector', () => {
render()
- expect(screen.getByText('Page 1')).toBeInTheDocument()
- expect(screen.getByText('Page 2')).toBeInTheDocument()
+ expect(screen.getByText('Page 1'))!.toBeInTheDocument()
+ expect(screen.getByText('Page 2'))!.toBeInTheDocument()
})
it('should render checkboxes when isMultipleChoice is true', () => {
@@ -119,7 +119,7 @@ describe('PageSelector', () => {
render()
- expect(getCheckbox()).toBeInTheDocument()
+ expect(getCheckbox())!.toBeInTheDocument()
})
it('should render radio buttons when isMultipleChoice is false', () => {
@@ -127,7 +127,7 @@ describe('PageSelector', () => {
render()
- expect(getRadio()).toBeInTheDocument()
+ expect(getRadio())!.toBeInTheDocument()
})
it('should render preview button when canPreview is true', () => {
@@ -135,7 +135,7 @@ describe('PageSelector', () => {
render()
- expect(screen.getByText('common.dataSource.notion.selector.preview')).toBeInTheDocument()
+ expect(screen.getByText('common.dataSource.notion.selector.preview'))!.toBeInTheDocument()
})
it('should not render preview button when canPreview is false', () => {
@@ -153,7 +153,7 @@ describe('PageSelector', () => {
// Assert - NotionIcon renders svg when page_icon is null
const notionIcon = document.querySelector('.h-5.w-5')
- expect(notionIcon).toBeInTheDocument()
+ expect(notionIcon)!.toBeInTheDocument()
})
it('should render page name', () => {
@@ -164,7 +164,7 @@ describe('PageSelector', () => {
render()
- expect(screen.getByText('My Custom Page')).toBeInTheDocument()
+ expect(screen.getByText('My Custom Page'))!.toBeInTheDocument()
})
})
@@ -181,7 +181,7 @@ describe('PageSelector', () => {
render()
const checkbox = getCheckbox()
- expect(checkbox).toBeInTheDocument()
+ expect(checkbox)!.toBeInTheDocument()
expect(isCheckboxChecked(checkbox)).toBe(true)
})
@@ -196,7 +196,7 @@ describe('PageSelector', () => {
render()
const checkbox = getCheckbox()
- expect(checkbox).toBeInTheDocument()
+ expect(checkbox)!.toBeInTheDocument()
expect(isCheckboxChecked(checkbox)).toBe(false)
})
@@ -206,7 +206,7 @@ describe('PageSelector', () => {
render()
const checkbox = getCheckbox()
- expect(checkbox).toBeInTheDocument()
+ expect(checkbox)!.toBeInTheDocument()
expect(isCheckboxChecked(checkbox)).toBe(false)
})
@@ -225,9 +225,9 @@ describe('PageSelector', () => {
render()
const checkboxes = getAllCheckboxes()
- expect(isCheckboxChecked(checkboxes[0])).toBe(true)
- expect(isCheckboxChecked(checkboxes[1])).toBe(false)
- expect(isCheckboxChecked(checkboxes[2])).toBe(true)
+ expect(isCheckboxChecked(checkboxes[0]!)).toBe(true)
+ expect(isCheckboxChecked(checkboxes[1]!)).toBe(false)
+ expect(isCheckboxChecked(checkboxes[2]!)).toBe(true)
})
})
@@ -243,7 +243,7 @@ describe('PageSelector', () => {
render()
const checkbox = getCheckbox()
- expect(checkbox).toBeInTheDocument()
+ expect(checkbox)!.toBeInTheDocument()
expect(isCheckboxDisabled(checkbox)).toBe(true)
})
@@ -258,7 +258,7 @@ describe('PageSelector', () => {
render()
const checkbox = getCheckbox()
- expect(checkbox).toBeInTheDocument()
+ expect(checkbox)!.toBeInTheDocument()
expect(isCheckboxDisabled(checkbox)).toBe(false)
})
@@ -276,8 +276,8 @@ describe('PageSelector', () => {
render()
const checkboxes = getAllCheckboxes()
- expect(isCheckboxDisabled(checkboxes[0])).toBe(true)
- expect(isCheckboxDisabled(checkboxes[1])).toBe(false)
+ expect(isCheckboxDisabled(checkboxes[0]!)).toBe(true)
+ expect(isCheckboxDisabled(checkboxes[1]!)).toBe(false)
})
})
@@ -301,6 +301,7 @@ describe('PageSelector', () => {
expect(screen.getAllByText('Apple Page').length).toBeGreaterThan(0)
expect(screen.getAllByText('Apple Pie').length).toBeGreaterThan(0)
// Banana Page is filtered out because it doesn't contain "Apple"
+ // Banana Page is filtered out because it doesn't contain "Apple"
expect(screen.queryByText('Banana Page')).not.toBeInTheDocument()
})
@@ -314,7 +345,7 @@ describe('PageSelector', () => {
render()
- expect(screen.getByText('common.dataSource.notion.selector.noSearchResult')).toBeInTheDocument()
+ expect(screen.getByText('common.dataSource.notion.selector.noSearchResult'))!.toBeInTheDocument()
})
it('should show all pages when searchValue is empty', () => {
@@ -330,8 +361,8 @@ describe('PageSelector', () => {
render()
- expect(screen.getByText('Page 1')).toBeInTheDocument()
- expect(screen.getByText('Page 2')).toBeInTheDocument()
+ expect(screen.getByText('Page 1'))!.toBeInTheDocument()
+ expect(screen.getByText('Page 2'))!.toBeInTheDocument()
})
it('should show breadcrumbs when searchValue is present', () => {
@@ -345,7 +376,7 @@ describe('PageSelector', () => {
render()
// Assert - page name should be visible
- expect(screen.getByText('Grandchild 1')).toBeInTheDocument()
+ expect(screen.getByText('Grandchild 1'))!.toBeInTheDocument()
})
it('should perform case-sensitive search', () => {
@@ -374,7 +406,7 @@ describe('PageSelector', () => {
render()
- expect(screen.getByText('common.dataSource.notion.selector.preview')).toBeInTheDocument()
+ expect(screen.getByText('common.dataSource.notion.selector.preview'))!.toBeInTheDocument()
})
it('should hide preview button when canPreview is false', () => {
@@ -391,7 +423,7 @@ describe('PageSelector', () => {
render()
- expect(screen.getByText('common.dataSource.notion.selector.preview')).toBeInTheDocument()
+ expect(screen.getByText('common.dataSource.notion.selector.preview'))!.toBeInTheDocument()
})
})
@@ -401,7 +433,7 @@ describe('PageSelector', () => {
render()
- expect(getCheckbox()).toBeInTheDocument()
+ expect(getCheckbox())!.toBeInTheDocument()
expect(getRadio()).not.toBeInTheDocument()
})
@@ -410,7 +442,7 @@ describe('PageSelector', () => {
render()
- expect(getRadio()).toBeInTheDocument()
+ expect(getRadio())!.toBeInTheDocument()
expect(getCheckbox()).not.toBeInTheDocument()
})
@@ -420,7 +452,7 @@ describe('PageSelector', () => {
render()
- expect(getCheckbox()).toBeInTheDocument()
+ expect(getCheckbox())!.toBeInTheDocument()
})
})
@@ -449,7 +481,7 @@ describe('PageSelector', () => {
render()
fireEvent.click(getCheckbox())
- const calledSet = mockOnSelect.mock.calls[0][0] as Set
+ const calledSet = mockOnSelect.mock.calls[0]![0] as Set
expect(calledSet.has('page-1')).toBe(true)
})
})
@@ -498,13 +530,15 @@ describe('PageSelector', () => {
const { rerender } = render()
// Assert - Initial render
- expect(screen.getByText('Page 1')).toBeInTheDocument()
+ expect(screen.getByText('Page 1'))!.toBeInTheDocument()
// Rerender with new credential
rerender()
// Assert - Should still show pages (reset and rebuild)
- expect(screen.getByText('Page 1')).toBeInTheDocument()
+ expect(screen.getByText('Page 1'))!.toBeInTheDocument()
})
})
})
@@ -521,7 +555,39 @@ describe('PageSelector', () => {
render()
// Assert - Only root level page should be visible initially
- expect(screen.getByText(rootPage.page_name)).toBeInTheDocument()
+ expect(screen.getByText(rootPage.page_name))!.toBeInTheDocument()
// Child pages should not be visible until expanded
expect(screen.queryByText(childPage1.page_name)).not.toBeInTheDocument()
})
@@ -540,9 +606,9 @@ describe('PageSelector', () => {
if (arrowButton)
fireEvent.click(arrowButton)
- expect(screen.getByText(rootPage.page_name)).toBeInTheDocument()
- expect(screen.getByText(childPage1.page_name)).toBeInTheDocument()
- expect(screen.getByText(childPage2.page_name)).toBeInTheDocument()
+ expect(screen.getByText(rootPage.page_name))!.toBeInTheDocument()
+ expect(screen.getByText(childPage1.page_name))!.toBeInTheDocument()
+ expect(screen.getByText(childPage2.page_name))!.toBeInTheDocument()
})
it('should maintain currentPreviewPageId state', () => {
@@ -560,7 +626,7 @@ describe('PageSelector', () => {
render()
const previewButtons = screen.getAllByText('common.dataSource.notion.selector.preview')
- fireEvent.click(previewButtons[0])
+ fireEvent.click(previewButtons[0]!)
expect(mockOnPreview).toHaveBeenCalledWith('page-1')
})
@@ -596,13 +662,14 @@ describe('PageSelector', () => {
})
const { rerender } = render()
- expect(screen.getByText('Page 1')).toBeInTheDocument()
+ expect(screen.getByText('Page 1'))!.toBeInTheDocument()
// Change credential
rerender()
// Assert - Component should still render correctly
- expect(screen.getByText('Page 1')).toBeInTheDocument()
+ expect(screen.getByText('Page 1'))!.toBeInTheDocument()
})
it('should filter root pages correctly on initialization', () => {
@@ -615,7 +682,8 @@ describe('PageSelector', () => {
render()
// Assert - Only root level pages visible
- expect(screen.getByText(rootPage.page_name)).toBeInTheDocument()
+ expect(screen.getByText(rootPage.page_name))!.toBeInTheDocument()
expect(screen.queryByText(childPage1.page_name)).not.toBeInTheDocument()
})
@@ -633,7 +701,8 @@ describe('PageSelector', () => {
render()
// Assert - Orphan page should be visible at root level
- expect(screen.getByText('Orphan Page')).toBeInTheDocument()
+ expect(screen.getByText('Orphan Page'))!.toBeInTheDocument()
})
})
@@ -654,8 +723,9 @@ describe('PageSelector', () => {
fireEvent.click(expandArrow)
// Assert - Children should be visible
- expect(screen.getByText(childPage1.page_name)).toBeInTheDocument()
- expect(screen.getByText(childPage2.page_name)).toBeInTheDocument()
+ expect(screen.getByText(childPage1.page_name))!.toBeInTheDocument()
+ expect(screen.getByText(childPage2.page_name))!.toBeInTheDocument()
})
it('should have stable handleToggle that collapses descendants', () => {
@@ -675,6 +745,37 @@ describe('PageSelector', () => {
fireEvent.click(expandArrow)
}
// Assert - Children should be hidden again
expect(screen.queryByText(childPage1.page_name)).not.toBeInTheDocument()
expect(screen.queryByText(childPage2.page_name)).not.toBeInTheDocument()
@@ -698,7 +799,7 @@ describe('PageSelector', () => {
// Assert - onSelect should be called with the page and its descendants
expect(mockOnSelect).toHaveBeenCalled()
- const selectedSet = mockOnSelect.mock.calls[0][0] as Set
+ const selectedSet = mockOnSelect.mock.calls[0]![0] as Set
expect(selectedSet.has('root-page')).toBe(true)
})
@@ -752,7 +853,7 @@ describe('PageSelector', () => {
// Assert - Tree structure should be built (verified by expand functionality)
const expandArrow = document.querySelector('[class*="hover:bg-components-button-ghost-bg-hover"]')
- expect(expandArrow).toBeInTheDocument() // Root page has children
+ expect(expandArrow)!.toBeInTheDocument() // Root page has children
})
it('should recompute listMapWithChildrenAndDescendants when list changes', () => {
@@ -763,7 +864,7 @@ describe('PageSelector', () => {
})
const { rerender } = render()
- expect(screen.getByText('Page 1')).toBeInTheDocument()
+ expect(screen.getByText('Page 1'))!.toBeInTheDocument()
// Update with new list
const newList = [
@@ -772,7 +873,7 @@ describe('PageSelector', () => {
]
rerender()
- expect(screen.getByText('Page 1')).toBeInTheDocument()
+ expect(screen.getByText('Page 1'))!.toBeInTheDocument()
// Page 2 won't show because dataList state hasn't updated (only resets on credentialId change)
})
@@ -793,7 +894,8 @@ describe('PageSelector', () => {
rerender()
// Assert - Should not throw
- expect(screen.getByText('Page 1')).toBeInTheDocument()
+ expect(screen.getByText('Page 1'))!.toBeInTheDocument()
})
it('should handle empty list in memoization', () => {
@@ -804,7 +906,7 @@ describe('PageSelector', () => {
render()
- expect(screen.getByText('common.dataSource.notion.selector.noSearchResult')).toBeInTheDocument()
+ expect(screen.getByText('common.dataSource.notion.selector.noSearchResult'))!.toBeInTheDocument()
})
})
@@ -819,6 +921,37 @@ describe('PageSelector', () => {
render()
// Initially children are hidden
expect(screen.queryByText(childPage1.page_name)).not.toBeInTheDocument()
@@ -827,7 +960,8 @@ describe('PageSelector', () => {
fireEvent.click(expandArrow)
// Children become visible
- expect(screen.getByText(childPage1.page_name)).toBeInTheDocument()
+ expect(screen.getByText(childPage1.page_name))!.toBeInTheDocument()
})
it('should check/uncheck page when clicking checkbox', () => {
@@ -873,11 +1007,11 @@ describe('PageSelector', () => {
render()
const radios = getAllRadios()
- fireEvent.click(radios[1]) // Click on page-2
+ fireEvent.click(radios[1]!) // Click on page-2
// Assert - Should clear page-1 and select page-2
expect(mockOnSelect).toHaveBeenCalled()
- const selectedSet = mockOnSelect.mock.calls[0][0] as Set
+ const selectedSet = mockOnSelect.mock.calls[0]![0] as Set
expect(selectedSet.has('page-2')).toBe(true)
expect(selectedSet.has('page-1')).toBe(false)
})
@@ -912,7 +1046,7 @@ describe('PageSelector', () => {
// Assert - Only the clicked page should be selected (no descendants)
expect(mockOnSelect).toHaveBeenCalled()
- const selectedSet = mockOnSelect.mock.calls[0][0] as Set
+ const selectedSet = mockOnSelect.mock.calls[0]![0] as Set
expect(selectedSet.size).toBe(1)
expect(selectedSet.has('root-page')).toBe(true)
})
@@ -927,7 +1061,7 @@ describe('PageSelector', () => {
render()
- expect(screen.getByText('common.dataSource.notion.selector.noSearchResult')).toBeInTheDocument()
+ expect(screen.getByText('common.dataSource.notion.selector.noSearchResult'))!.toBeInTheDocument()
})
it('should handle null page_icon', () => {
@@ -941,7 +1075,7 @@ describe('PageSelector', () => {
// Assert - NotionIcon renders svg (RiFileTextLine) when page_icon is null
const notionIcon = document.querySelector('.h-5.w-5')
- expect(notionIcon).toBeInTheDocument()
+ expect(notionIcon)!.toBeInTheDocument()
})
it('should handle page_icon with all properties', () => {
@@ -956,7 +1090,8 @@ describe('PageSelector', () => {
render()
// Assert - NotionIcon renders the emoji
- expect(screen.getByText('📄')).toBeInTheDocument()
+ expect(screen.getByText('📄'))!.toBeInTheDocument()
})
it('should handle empty searchValue correctly', () => {
@@ -964,7 +1099,7 @@ describe('PageSelector', () => {
render()
- expect(screen.getByTestId('virtual-list')).toBeInTheDocument()
+ expect(screen.getByTestId('virtual-list'))!.toBeInTheDocument()
})
it('should handle special characters in page name', () => {
@@ -976,7 +1111,7 @@ describe('PageSelector', () => {
render()
- expect(screen.getByText('Test ')).toBeInTheDocument()
+ expect(screen.getByText('Test '))!.toBeInTheDocument()
})
it('should handle unicode characters in page name', () => {
@@ -988,7 +1123,7 @@ describe('PageSelector', () => {
render()
- expect(screen.getByText('测试页面 🔍 привет')).toBeInTheDocument()
+ expect(screen.getByText('测试页面 🔍 привет'))!.toBeInTheDocument()
})
it('should handle very long page names', () => {
@@ -1001,7 +1136,7 @@ describe('PageSelector', () => {
render()
- expect(screen.getByText(longName)).toBeInTheDocument()
+ expect(screen.getByText(longName))!.toBeInTheDocument()
})
it('should handle deeply nested hierarchy', () => {
@@ -1027,7 +1162,8 @@ describe('PageSelector', () => {
render()
// Assert - Only root level visible
- expect(screen.getByText('Level 0')).toBeInTheDocument()
+ expect(screen.getByText('Level 0'))!.toBeInTheDocument()
expect(screen.queryByText('Level 1')).not.toBeInTheDocument()
})
@@ -1048,7 +1184,8 @@ describe('PageSelector', () => {
render()
// Assert - Should render the orphan page at root level
- expect(screen.getByText('Orphan Page')).toBeInTheDocument()
+ expect(screen.getByText('Orphan Page'))!.toBeInTheDocument()
})
it('should handle empty checkedIds Set', () => {
@@ -1057,7 +1194,7 @@ describe('PageSelector', () => {
render()
const checkbox = getCheckbox()
- expect(checkbox).toBeInTheDocument()
+ expect(checkbox)!.toBeInTheDocument()
expect(isCheckboxChecked(checkbox)).toBe(false)
})
@@ -1067,7 +1204,7 @@ describe('PageSelector', () => {
render()
const checkbox = getCheckbox()
- expect(checkbox).toBeInTheDocument()
+ expect(checkbox)!.toBeInTheDocument()
expect(isCheckboxDisabled(checkbox)).toBe(false)
})
@@ -1112,16 +1249,16 @@ describe('PageSelector', () => {
render()
- expect(screen.getByTestId('virtual-list')).toBeInTheDocument()
+ expect(screen.getByTestId('virtual-list'))!.toBeInTheDocument()
if (propVariation.canPreview)
- expect(screen.getByText('common.dataSource.notion.selector.preview')).toBeInTheDocument()
+ expect(screen.getByText('common.dataSource.notion.selector.preview'))!.toBeInTheDocument()
else
expect(screen.queryByText('common.dataSource.notion.selector.preview')).not.toBeInTheDocument()
if (propVariation.isMultipleChoice)
- expect(getCheckbox()).toBeInTheDocument()
+ expect(getCheckbox())!.toBeInTheDocument()
else
- expect(getRadio()).toBeInTheDocument()
+ expect(getRadio())!.toBeInTheDocument()
})
it('should handle all default prop values', () => {
@@ -1140,8 +1277,9 @@ describe('PageSelector', () => {
render()
// Assert - Defaults should be applied
- expect(getCheckbox()).toBeInTheDocument()
- expect(screen.getByText('common.dataSource.notion.selector.preview')).toBeInTheDocument()
+ expect(getCheckbox())!.toBeInTheDocument()
+ expect(screen.getByText('common.dataSource.notion.selector.preview'))!.toBeInTheDocument()
})
})
@@ -1166,8 +1304,8 @@ describe('PageSelector', () => {
recursivePushInParentDescendants(pagesMap, listTreeMap, childEntry, childEntry)
expect(listTreeMap.parent).toBeDefined()
- expect(listTreeMap.parent.children.has('child')).toBe(true)
- expect(listTreeMap.parent.descendants.has('child')).toBe(true)
+ expect(listTreeMap.parent!.children.has('child')).toBe(true)
+ expect(listTreeMap.parent!.descendants.has('child')).toBe(true)
expect(childEntry.depth).toBe(1)
expect(childEntry.ancestors).toContain('Parent')
})
@@ -1274,8 +1412,8 @@ describe('PageSelector', () => {
expect(l2Entry.depth).toBe(2)
expect(l2Entry.ancestors).toEqual(['Level 0', 'Level 1'])
- expect(listTreeMap.l1.children.has('l2')).toBe(true)
- expect(listTreeMap.l0.descendants.has('l2')).toBe(true)
+ expect(listTreeMap.l1!.children.has('l2')).toBe(true)
+ expect(listTreeMap.l0!.descendants.has('l2')).toBe(true)
})
it('should update existing parent entry', () => {
@@ -1329,7 +1467,7 @@ describe('PageSelector', () => {
// Assert - Item should have preview styling class
const itemContainer = screen.getByText('Test Page').closest('[class*="group"]')
- expect(itemContainer).toHaveClass('bg-state-base-hover')
+ expect(itemContainer)!.toHaveClass('bg-state-base-hover')
})
it('should show arrow for pages with children', () => {
@@ -1343,7 +1481,7 @@ describe('PageSelector', () => {
// Assert - Root page should have expand arrow
const arrowContainer = document.querySelector('[class*="hover:bg-components-button-ghost-bg-hover"]')
- expect(arrowContainer).toBeInTheDocument()
+ expect(arrowContainer)!.toBeInTheDocument()
})
it('should not show arrow for leaf pages', () => {
diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/online-documents/page-selector/__tests__/utils.spec.ts b/web/app/components/datasets/documents/create-from-pipeline/data-source/online-documents/page-selector/__tests__/utils.spec.ts
index 2a081ef418..a7175a47de 100644
--- a/web/app/components/datasets/documents/create-from-pipeline/data-source/online-documents/page-selector/__tests__/utils.spec.ts
+++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/online-documents/page-selector/__tests__/utils.spec.ts
@@ -28,11 +28,11 @@ describe('recursivePushInParentDescendants', () => {
child1: makePageEntry({ page_id: 'child1', parent_id: 'parent1', page_name: 'Child' }),
}
- recursivePushInParentDescendants(pagesMap, listTreeMap, listTreeMap.child1, listTreeMap.child1)
+ recursivePushInParentDescendants(pagesMap, listTreeMap, listTreeMap.child1!, listTreeMap.child1!)
expect(listTreeMap.parent1).toBeDefined()
- expect(listTreeMap.parent1.children.has('child1')).toBe(true)
- expect(listTreeMap.parent1.descendants.has('child1')).toBe(true)
+ expect(listTreeMap.parent1!.children.has('child1')).toBe(true)
+ expect(listTreeMap.parent1!.descendants.has('child1')).toBe(true)
})
it('should recursively populate ancestors for deeply nested items', () => {
@@ -47,11 +47,11 @@ describe('recursivePushInParentDescendants', () => {
child: makePageEntry({ page_id: 'child', parent_id: 'parent', page_name: 'Child' }),
}
- recursivePushInParentDescendants(pagesMap, listTreeMap, listTreeMap.child, listTreeMap.child)
+ recursivePushInParentDescendants(pagesMap, listTreeMap, listTreeMap.child!, listTreeMap.child!)
- expect(listTreeMap.child.depth).toBe(2)
- expect(listTreeMap.child.ancestors).toContain('Grandparent')
- expect(listTreeMap.child.ancestors).toContain('Parent')
+ expect(listTreeMap.child!.depth).toBe(2)
+ expect(listTreeMap.child!.ancestors).toContain('Grandparent')
+ expect(listTreeMap.child!.ancestors).toContain('Parent')
})
it('should do nothing for root parent', () => {
@@ -63,7 +63,7 @@ describe('recursivePushInParentDescendants', () => {
root_child: makePageEntry({ page_id: 'root_child', parent_id: 'root', page_name: 'Root Child' }),
}
- recursivePushInParentDescendants(pagesMap, listTreeMap, listTreeMap.root_child, listTreeMap.root_child)
+ recursivePushInParentDescendants(pagesMap, listTreeMap, listTreeMap.root_child!, listTreeMap.root_child!)
// No new entries should be added since parent is root
expect(Object.keys(listTreeMap)).toEqual(['root_child'])
@@ -76,7 +76,7 @@ describe('recursivePushInParentDescendants', () => {
// Should not throw
recursivePushInParentDescendants(pagesMap, listTreeMap, current, current)
- expect(listTreeMap.orphan.depth).toBe(0)
+ expect(listTreeMap.orphan!.depth).toBe(0)
})
it('should add to existing parent entry when parent already in tree', () => {
@@ -91,10 +91,10 @@ describe('recursivePushInParentDescendants', () => {
child2: makePageEntry({ page_id: 'child2', parent_id: 'parent', page_name: 'Child2' }),
}
- recursivePushInParentDescendants(pagesMap, listTreeMap, listTreeMap.child2, listTreeMap.child2)
+ recursivePushInParentDescendants(pagesMap, listTreeMap, listTreeMap.child2!, listTreeMap.child2!)
- expect(listTreeMap.parent.children.has('child2')).toBe(true)
- expect(listTreeMap.parent.descendants.has('child2')).toBe(true)
- expect(listTreeMap.parent.children.has('child1')).toBe(true)
+ expect(listTreeMap.parent!.children.has('child2')).toBe(true)
+ expect(listTreeMap.parent!.descendants.has('child2')).toBe(true)
+ expect(listTreeMap.parent!.children.has('child1')).toBe(true)
})
})
diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/online-documents/title.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source/online-documents/title.tsx
index 4f555f3e1f..2d4eb81212 100644
--- a/web/app/components/datasets/documents/create-from-pipeline/data-source/online-documents/title.tsx
+++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/online-documents/title.tsx
@@ -11,7 +11,7 @@ const Title = ({
const { t } = useTranslation()
return (
-
+
{t('onlineDocument.pageSelectorTitle', { ns: 'datasetPipeline', name })}
)
diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/__tests__/index.spec.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/__tests__/index.spec.tsx
index 7c1941afd9..c8fdf49fd1 100644
--- a/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/__tests__/index.spec.tsx
+++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/__tests__/index.spec.tsx
@@ -49,8 +49,8 @@ const { mockToastError } = vi.hoisted(() => ({
mockToastError: vi.fn(),
}))
-vi.mock('@/app/components/base/ui/toast', async (importOriginal) => {
- const actual = await importOriginal()
+vi.mock('@langgenius/dify-ui/toast', async (importOriginal) => {
+ const actual = await importOriginal()
return {
...actual,
toast: {
@@ -259,8 +259,8 @@ describe('OnlineDrive', () => {
render()
- expect(screen.getByTestId('header')).toBeInTheDocument()
- expect(screen.getByTestId('file-list')).toBeInTheDocument()
+ expect(screen.getByTestId('header'))!.toBeInTheDocument()
+ expect(screen.getByTestId('file-list'))!.toBeInTheDocument()
})
it('should render Header with correct props', () => {
@@ -271,9 +271,9 @@ describe('OnlineDrive', () => {
render()
- expect(screen.getByTestId('header-doc-title')).toHaveTextContent('Docs')
- expect(screen.getByTestId('header-plugin-name')).toHaveTextContent('My Online Drive')
- expect(screen.getByTestId('header-credential-id')).toHaveTextContent('cred-123')
+ expect(screen.getByTestId('header-doc-title'))!.toHaveTextContent('Docs')
+ expect(screen.getByTestId('header-plugin-name'))!.toHaveTextContent('My Online Drive')
+ expect(screen.getByTestId('header-credential-id'))!.toHaveTextContent('cred-123')
})
it('should render FileList with correct props', () => {
@@ -290,11 +290,11 @@ describe('OnlineDrive', () => {
render()
- expect(screen.getByTestId('file-list')).toBeInTheDocument()
- expect(screen.getByTestId('file-list-keywords')).toHaveTextContent('search-term')
- expect(screen.getByTestId('file-list-breadcrumbs')).toHaveTextContent('folder1/folder2')
- expect(screen.getByTestId('file-list-bucket')).toHaveTextContent('my-bucket')
- expect(screen.getByTestId('file-list-selected-count')).toHaveTextContent('2')
+ expect(screen.getByTestId('file-list'))!.toBeInTheDocument()
+ expect(screen.getByTestId('file-list-keywords'))!.toHaveTextContent('search-term')
+ expect(screen.getByTestId('file-list-breadcrumbs'))!.toHaveTextContent('folder1/folder2')
+ expect(screen.getByTestId('file-list-bucket'))!.toHaveTextContent('my-bucket')
+ expect(screen.getByTestId('file-list-selected-count'))!.toHaveTextContent('2')
})
it('should pass docLink with correct path to Header', () => {
@@ -371,7 +371,7 @@ describe('OnlineDrive', () => {
render()
- expect(screen.getByTestId('header-plugin-name')).toHaveTextContent('Custom Online Drive')
+ expect(screen.getByTestId('header-plugin-name'))!.toHaveTextContent('Custom Online Drive')
})
})
@@ -411,7 +411,7 @@ describe('OnlineDrive', () => {
render()
- expect(screen.getByTestId('file-list-is-in-pipeline')).toHaveTextContent('true')
+ expect(screen.getByTestId('file-list-is-in-pipeline'))!.toHaveTextContent('true')
})
})
@@ -421,7 +421,7 @@ describe('OnlineDrive', () => {
render()
- expect(screen.getByTestId('file-list-support-batch')).toHaveTextContent('true')
+ expect(screen.getByTestId('file-list-support-batch'))!.toHaveTextContent('true')
})
it('should pass supportBatchUpload false to FileList when supportBatchUpload is false', () => {
@@ -429,7 +429,7 @@ describe('OnlineDrive', () => {
render()
- expect(screen.getByTestId('file-list-support-batch')).toHaveTextContent('false')
+ expect(screen.getByTestId('file-list-support-batch'))!.toHaveTextContent('false')
})
it.each([
@@ -441,7 +441,7 @@ describe('OnlineDrive', () => {
render()
- expect(screen.getByTestId('file-list-support-batch')).toHaveTextContent(expected)
+ expect(screen.getByTestId('file-list-support-batch'))!.toHaveTextContent(expected)
})
})
@@ -504,7 +504,7 @@ describe('OnlineDrive', () => {
render()
await waitFor(() => {
- expect(screen.getByTestId('file-list-loading')).toHaveTextContent('true')
+ expect(screen.getByTestId('file-list-loading'))!.toHaveTextContent('true')
})
})
@@ -566,7 +566,8 @@ describe('OnlineDrive', () => {
render()
// Assert - filteredOnlineDriveFileList should have 2 items matching 'test'
- expect(screen.getByTestId('file-list-count')).toHaveTextContent('2')
+ expect(screen.getByTestId('file-list-count'))!.toHaveTextContent('2')
})
it('should return all files when keywords is empty', () => {
@@ -580,7 +581,7 @@ describe('OnlineDrive', () => {
render()
- expect(screen.getByTestId('file-list-count')).toHaveTextContent('3')
+ expect(screen.getByTestId('file-list-count'))!.toHaveTextContent('3')
})
it('should filter files case-insensitively', () => {
@@ -594,7 +595,7 @@ describe('OnlineDrive', () => {
render()
- expect(screen.getByTestId('file-list-count')).toHaveTextContent('2')
+ expect(screen.getByTestId('file-list-count'))!.toHaveTextContent('2')
})
})
@@ -932,7 +933,7 @@ describe('OnlineDrive', () => {
render()
- expect(screen.getByTestId('header-credentials-count')).toHaveTextContent('0')
+ expect(screen.getByTestId('header-credentials-count'))!.toHaveTextContent('0')
})
it('should handle undefined credentials data', () => {
@@ -943,7 +944,7 @@ describe('OnlineDrive', () => {
render()
- expect(screen.getByTestId('header-credentials-count')).toHaveTextContent('0')
+ expect(screen.getByTestId('header-credentials-count'))!.toHaveTextContent('0')
})
it('should handle undefined pipelineId', async () => {
@@ -969,7 +970,7 @@ describe('OnlineDrive', () => {
render()
- expect(screen.getByTestId('file-list-count')).toHaveTextContent('0')
+ expect(screen.getByTestId('file-list-count'))!.toHaveTextContent('0')
})
it('should handle empty breadcrumbs', () => {
@@ -978,7 +979,7 @@ describe('OnlineDrive', () => {
render()
- expect(screen.getByTestId('file-list-breadcrumbs')).toHaveTextContent('')
+ expect(screen.getByTestId('file-list-breadcrumbs'))!.toHaveTextContent('')
})
it('should handle empty bucket', () => {
@@ -987,7 +988,7 @@ describe('OnlineDrive', () => {
render()
- expect(screen.getByTestId('file-list-bucket')).toHaveTextContent('')
+ expect(screen.getByTestId('file-list-bucket'))!.toHaveTextContent('')
})
it('should handle special characters in keywords', () => {
@@ -1001,7 +1002,8 @@ describe('OnlineDrive', () => {
render()
// Assert - Should find file with special characters
- expect(screen.getByTestId('file-list-count')).toHaveTextContent('1')
+ expect(screen.getByTestId('file-list-count'))!.toHaveTextContent('1')
})
it('should handle very long file names', () => {
@@ -1013,7 +1015,7 @@ describe('OnlineDrive', () => {
render()
- expect(screen.getByTestId('file-list-count')).toHaveTextContent('1')
+ expect(screen.getByTestId('file-list-count'))!.toHaveTextContent('1')
})
it('should handle bucket list initiation response', async () => {
@@ -1051,10 +1053,10 @@ describe('OnlineDrive', () => {
render()
- expect(screen.getByTestId('header')).toBeInTheDocument()
- expect(screen.getByTestId('file-list')).toBeInTheDocument()
- expect(screen.getByTestId('file-list-is-in-pipeline')).toHaveTextContent(String(propVariation.isInPipeline))
- expect(screen.getByTestId('file-list-support-batch')).toHaveTextContent(String(propVariation.supportBatchUpload))
+ expect(screen.getByTestId('header'))!.toBeInTheDocument()
+ expect(screen.getByTestId('file-list'))!.toBeInTheDocument()
+ expect(screen.getByTestId('file-list-is-in-pipeline'))!.toHaveTextContent(String(propVariation.isInPipeline))
+ expect(screen.getByTestId('file-list-support-batch'))!.toHaveTextContent(String(propVariation.supportBatchUpload))
})
it.each([
@@ -1117,7 +1119,7 @@ describe('Header', () => {
render()
- expect(screen.getByText('Documentation')).toBeInTheDocument()
+ expect(screen.getByText('Documentation'))!.toBeInTheDocument()
})
it('should render doc link with correct href', () => {
@@ -1129,9 +1131,9 @@ describe('Header', () => {
render()
const link = screen.getByRole('link')
- expect(link).toHaveAttribute('href', 'https://custom-docs.com/path')
- expect(link).toHaveAttribute('target', '_blank')
- expect(link).toHaveAttribute('rel', 'noopener noreferrer')
+ expect(link)!.toHaveAttribute('href', 'https://custom-docs.com/path')
+ expect(link)!.toHaveAttribute('target', '_blank')
+ expect(link)!.toHaveAttribute('rel', 'noopener noreferrer')
})
it('should render doc title text', () => {
@@ -1139,7 +1141,7 @@ describe('Header', () => {
render()
- expect(screen.getByText('My Documentation Title')).toBeInTheDocument()
+ expect(screen.getByText('My Documentation Title'))!.toBeInTheDocument()
})
it('should render configuration button', () => {
@@ -1147,7 +1149,7 @@ describe('Header', () => {
render()
- expect(screen.getByRole('button')).toBeInTheDocument()
+ expect(screen.getByRole('button'))!.toBeInTheDocument()
})
})
@@ -1164,7 +1166,7 @@ describe('Header', () => {
render()
if (docTitle)
- expect(screen.getByText(docTitle)).toBeInTheDocument()
+ expect(screen.getByText(docTitle))!.toBeInTheDocument()
})
})
@@ -1178,7 +1180,7 @@ describe('Header', () => {
render()
- expect(screen.getByRole('link')).toHaveAttribute('href', docLink)
+ expect(screen.getByRole('link'))!.toHaveAttribute('href', docLink)
})
})
@@ -1209,7 +1211,7 @@ describe('Header', () => {
render()
const titleSpan = screen.getByTitle('Accessible Title')
- expect(titleSpan).toBeInTheDocument()
+ expect(titleSpan)!.toBeInTheDocument()
})
})
})
@@ -1437,10 +1439,10 @@ describe('utils', () => {
const result = convertOnlineDriveData(data, [], 'my-bucket')
expect(result.fileList).toHaveLength(4)
- expect(result.fileList[0].type).toBe(OnlineDriveFileType.folder)
- expect(result.fileList[1].type).toBe(OnlineDriveFileType.file)
- expect(result.fileList[2].type).toBe(OnlineDriveFileType.folder)
- expect(result.fileList[3].type).toBe(OnlineDriveFileType.file)
+ expect(result.fileList[0]!.type).toBe(OnlineDriveFileType.folder)
+ expect(result.fileList[1]!.type).toBe(OnlineDriveFileType.file)
+ expect(result.fileList[2]!.type).toBe(OnlineDriveFileType.folder)
+ expect(result.fileList[3]!.type).toBe(OnlineDriveFileType.file)
})
})
@@ -1539,7 +1541,7 @@ describe('utils', () => {
const result = convertOnlineDriveData(data, [], 'my-bucket')
- expect(result.fileList[0].size).toBe(0)
+ expect(result.fileList[0]!.size).toBe(0)
})
it('should handle files with very large size', () => {
@@ -1555,7 +1557,7 @@ describe('utils', () => {
const result = convertOnlineDriveData(data, [], 'my-bucket')
- expect(result.fileList[0].size).toBe(largeSize)
+ expect(result.fileList[0]!.size).toBe(largeSize)
})
it('should handle files with special characters in name', () => {
@@ -1574,9 +1576,9 @@ describe('utils', () => {
const result = convertOnlineDriveData(data, [], 'my-bucket')
- expect(result.fileList[0].name).toBe('file[1] (copy).txt')
- expect(result.fileList[1].name).toBe('doc-with-dash_and_underscore.pdf')
- expect(result.fileList[2].name).toBe('file with spaces.txt')
+ expect(result.fileList[0]!.name).toBe('file[1] (copy).txt')
+ expect(result.fileList[1]!.name).toBe('doc-with-dash_and_underscore.pdf')
+ expect(result.fileList[2]!.name).toBe('file with spaces.txt')
})
it('should handle complex next_page_parameters', () => {
diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/__tests__/utils.spec.ts b/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/__tests__/utils.spec.ts
index 7c5761be8a..9ac2ef9f89 100644
--- a/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/__tests__/utils.spec.ts
+++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/__tests__/utils.spec.ts
@@ -93,10 +93,10 @@ describe('online-drive utils', () => {
const result = convertOnlineDriveData(data, [], 'bucket-1')
expect(result.fileList).toHaveLength(2)
- expect(result.fileList[0].type).toBe(OnlineDriveFileType.file)
- expect(result.fileList[0].size).toBe(100)
- expect(result.fileList[1].type).toBe(OnlineDriveFileType.folder)
- expect(result.fileList[1].size).toBeUndefined()
+ expect(result.fileList[0]!.type).toBe(OnlineDriveFileType.file)
+ expect(result.fileList[0]!.size).toBe(100)
+ expect(result.fileList[1]!.type).toBe(OnlineDriveFileType.folder)
+ expect(result.fileList[1]!.size).toBeUndefined()
expect(result.isTruncated).toBe(true)
expect(result.nextPageParameters).toEqual({ token: 'next' })
expect(result.hasBucket).toBe(true)
diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/connect/index.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/connect/index.tsx
index 5b1b0a6b1a..6a7190161d 100644
--- a/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/connect/index.tsx
+++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/connect/index.tsx
@@ -1,7 +1,7 @@
import type { DataSourceNodeType } from '@/app/components/workflow/nodes/data-source/types'
+import { Button } from '@langgenius/dify-ui/button'
import { useTranslation } from 'react-i18next'
import { Icon3Dots } from '@/app/components/base/icons/src/vender/line/others'
-import { Button } from '@/app/components/base/ui/button'
import BlockIcon from '@/app/components/workflow/block-icon'
import { useToolIcon } from '@/app/components/workflow/hooks'
import { BlockEnum } from '@/app/components/workflow/types'
diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/file-list/header/__tests__/index.spec.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/file-list/header/__tests__/index.spec.tsx
index 07308361ad..dcb1922fe9 100644
--- a/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/file-list/header/__tests__/index.spec.tsx
+++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/file-list/header/__tests__/index.spec.tsx
@@ -60,7 +60,8 @@ describe('Header', () => {
render()
// Assert - search input should be visible
- expect(screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder')).toBeInTheDocument()
+ // Assert - search input should be visible
+ expect(screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder'))!.toBeInTheDocument()
})
it('should render with correct container styles', () => {
@@ -70,12 +71,12 @@ describe('Header', () => {
// Assert - container should have correct class names
const wrapper = container.firstChild as HTMLElement
- expect(wrapper).toHaveClass('flex')
- expect(wrapper).toHaveClass('items-center')
- expect(wrapper).toHaveClass('gap-x-2')
- expect(wrapper).toHaveClass('bg-components-panel-bg')
- expect(wrapper).toHaveClass('p-1')
- expect(wrapper).toHaveClass('pl-3')
+ expect(wrapper)!.toHaveClass('flex')
+ expect(wrapper)!.toHaveClass('items-center')
+ expect(wrapper)!.toHaveClass('gap-x-2')
+ expect(wrapper)!.toHaveClass('bg-components-panel-bg')
+ expect(wrapper)!.toHaveClass('p-1')
+ expect(wrapper)!.toHaveClass('pl-3')
})
it('should render Input component with correct props', () => {
@@ -84,8 +85,8 @@ describe('Header', () => {
render()
const input = screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder')
- expect(input).toBeInTheDocument()
- expect(input).toHaveValue('test-value')
+ expect(input)!.toBeInTheDocument()
+ expect(input)!.toHaveValue('test-value')
})
it('should render Input with search icon', () => {
@@ -95,7 +96,7 @@ describe('Header', () => {
// Assert - Input should have search icon class
const searchIcon = container.querySelector('.i-ri-search-line.h-4.w-4')
- expect(searchIcon).toBeInTheDocument()
+ expect(searchIcon)!.toBeInTheDocument()
})
it('should render Input with correct wrapper width', () => {
@@ -105,7 +106,7 @@ describe('Header', () => {
// Assert - Input wrapper should have w-[200px] class
const inputWrapper = container.querySelector('.w-\\[200px\\]')
- expect(inputWrapper).toBeInTheDocument()
+ expect(inputWrapper)!.toBeInTheDocument()
})
})
@@ -117,7 +118,7 @@ describe('Header', () => {
render()
const input = screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder')
- expect(input).toHaveValue('')
+ expect(input)!.toHaveValue('')
})
it('should display input value correctly', () => {
@@ -126,7 +127,7 @@ describe('Header', () => {
render()
const input = screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder')
- expect(input).toHaveValue('search-query')
+ expect(input)!.toHaveValue('search-query')
})
it('should handle special characters in inputValue', () => {
@@ -136,7 +137,7 @@ describe('Header', () => {
render()
const input = screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder')
- expect(input).toHaveValue(specialChars)
+ expect(input)!.toHaveValue(specialChars)
})
it('should handle unicode characters in inputValue', () => {
@@ -146,7 +147,7 @@ describe('Header', () => {
render()
const input = screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder')
- expect(input).toHaveValue(unicodeValue)
+ expect(input)!.toHaveValue(unicodeValue)
})
})
@@ -157,7 +158,7 @@ describe('Header', () => {
render()
// Assert - Component should render without errors
- expect(screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder')).toBeInTheDocument()
+ expect(screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder'))!.toBeInTheDocument()
})
it('should render with single breadcrumb', () => {
@@ -165,7 +167,7 @@ describe('Header', () => {
render()
- expect(screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder')).toBeInTheDocument()
+ expect(screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder'))!.toBeInTheDocument()
})
it('should render with multiple breadcrumbs', () => {
@@ -173,7 +175,7 @@ describe('Header', () => {
render()
- expect(screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder')).toBeInTheDocument()
+ expect(screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder'))!.toBeInTheDocument()
})
})
@@ -184,7 +186,7 @@ describe('Header', () => {
render()
// Assert - keywords are passed through, component renders
- expect(screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder')).toBeInTheDocument()
+ expect(screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder'))!.toBeInTheDocument()
})
})
@@ -194,7 +197,7 @@ describe('Header', () => {
render()
- expect(screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder')).toBeInTheDocument()
+ expect(screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder'))!.toBeInTheDocument()
})
it('should render with bucket value', () => {
@@ -202,7 +205,7 @@ describe('Header', () => {
render()
- expect(screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder')).toBeInTheDocument()
+ expect(screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder'))!.toBeInTheDocument()
})
})
@@ -212,7 +215,7 @@ describe('Header', () => {
render()
- expect(screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder')).toBeInTheDocument()
+ expect(screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder'))!.toBeInTheDocument()
})
it('should handle positive search results', () => {
@@ -221,7 +224,7 @@ describe('Header', () => {
render()
// Assert - Breadcrumbs will show search results text when keywords exist and results > 0
- expect(screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder')).toBeInTheDocument()
+ expect(screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder'))!.toBeInTheDocument()
})
it('should handle large search results count', () => {
@@ -229,7 +233,7 @@ describe('Header', () => {
render()
- expect(screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder')).toBeInTheDocument()
+ expect(screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder'))!.toBeInTheDocument()
})
})
@@ -239,7 +243,7 @@ describe('Header', () => {
render()
- expect(screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder')).toBeInTheDocument()
+ expect(screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder'))!.toBeInTheDocument()
})
it('should render correctly when isInPipeline is true', () => {
@@ -247,7 +251,7 @@ describe('Header', () => {
render()
- expect(screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder')).toBeInTheDocument()
+ expect(screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder'))!.toBeInTheDocument()
})
})
})
@@ -265,7 +269,7 @@ describe('Header', () => {
expect(mockHandleInputChange).toHaveBeenCalledTimes(1)
// Verify that onChange event was triggered (React's synthetic event structure)
- expect(mockHandleInputChange.mock.calls[0][0]).toHaveProperty('type', 'change')
+ expect(mockHandleInputChange.mock.calls[0]![0]).toHaveProperty('type', 'change')
})
it('should call handleInputChange on each keystroke', () => {
@@ -290,7 +294,7 @@ describe('Header', () => {
fireEvent.change(input, { target: { value: '' } })
expect(mockHandleInputChange).toHaveBeenCalledTimes(1)
- expect(mockHandleInputChange.mock.calls[0][0]).toHaveProperty('type', 'change')
+ expect(mockHandleInputChange.mock.calls[0]![0]).toHaveProperty('type', 'change')
})
it('should handle whitespace-only input', () => {
@@ -302,7 +306,7 @@ describe('Header', () => {
fireEvent.change(input, { target: { value: ' ' } })
expect(mockHandleInputChange).toHaveBeenCalledTimes(1)
- expect(mockHandleInputChange.mock.calls[0][0]).toHaveProperty('type', 'change')
+ expect(mockHandleInputChange.mock.calls[0]![0]).toHaveProperty('type', 'change')
})
})
@@ -317,7 +321,7 @@ describe('Header', () => {
// Act - Find and click the clear icon container
const clearButton = screen.getByTestId('input-clear')
- expect(clearButton).toBeInTheDocument()
+ expect(clearButton)!.toBeInTheDocument()
fireEvent.click(clearButton!)
expect(mockHandleResetKeywords).toHaveBeenCalledTimes(1)
@@ -338,7 +342,7 @@ describe('Header', () => {
// Act & Assert - Clear icon should be visible
const clearIcon = screen.getByTestId('input-clear')
- expect(clearIcon).toBeInTheDocument()
+ expect(clearIcon)!.toBeInTheDocument()
})
})
})
@@ -365,21 +369,21 @@ describe('Header', () => {
rerender()
// Assert - Component renders without errors
- expect(screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder')).toBeInTheDocument()
+ expect(screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder'))!.toBeInTheDocument()
})
it('should re-render when inputValue changes', () => {
const props = createDefaultProps({ inputValue: 'initial' })
const { rerender } = render()
const input = screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder')
- expect(input).toHaveValue('initial')
+ expect(input)!.toHaveValue('initial')
// Act - Rerender with different inputValue
const newProps = createDefaultProps({ inputValue: 'changed' })
rerender()
// Assert - Input value should be updated
- expect(input).toHaveValue('changed')
+ expect(input)!.toHaveValue('changed')
})
it('should re-render when breadcrumbs change', () => {
@@ -391,7 +397,7 @@ describe('Header', () => {
rerender()
// Assert - Component renders without errors
- expect(screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder')).toBeInTheDocument()
+ expect(screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder'))!.toBeInTheDocument()
})
it('should re-render when keywords change', () => {
@@ -403,7 +410,7 @@ describe('Header', () => {
rerender()
// Assert - Component renders without errors
- expect(screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder')).toBeInTheDocument()
+ expect(screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder'))!.toBeInTheDocument()
})
})
@@ -415,7 +423,7 @@ describe('Header', () => {
render()
const input = screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder')
- expect(input).toHaveValue(longValue)
+ expect(input)!.toHaveValue(longValue)
})
it('should handle very long breadcrumb paths', () => {
@@ -424,7 +432,7 @@ describe('Header', () => {
render()
- expect(screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder')).toBeInTheDocument()
+ expect(screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder'))!.toBeInTheDocument()
})
it('should handle breadcrumbs with special characters', () => {
@@ -433,7 +441,7 @@ describe('Header', () => {
render()
- expect(screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder')).toBeInTheDocument()
+ expect(screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder'))!.toBeInTheDocument()
})
it('should handle breadcrumbs with unicode names', () => {
@@ -442,7 +450,7 @@ describe('Header', () => {
render()
- expect(screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder')).toBeInTheDocument()
+ expect(screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder'))!.toBeInTheDocument()
})
it('should handle bucket with special characters', () => {
@@ -450,7 +458,7 @@ describe('Header', () => {
render()
- expect(screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder')).toBeInTheDocument()
+ expect(screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder'))!.toBeInTheDocument()
})
it('should pass the event object to handleInputChange callback', () => {
@@ -463,7 +471,7 @@ describe('Header', () => {
// Assert - Verify the event object is passed correctly
expect(mockHandleInputChange).toHaveBeenCalledTimes(1)
- const eventArg = mockHandleInputChange.mock.calls[0][0]
+ const eventArg = mockHandleInputChange.mock.calls[0]![0]
expect(eventArg).toHaveProperty('type', 'change')
expect(eventArg).toHaveProperty('target')
})
@@ -480,7 +488,7 @@ describe('Header', () => {
render()
- expect(screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder')).toBeInTheDocument()
+ expect(screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder'))!.toBeInTheDocument()
})
it.each([
@@ -493,7 +501,7 @@ describe('Header', () => {
render()
- expect(screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder')).toBeInTheDocument()
+ expect(screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder'))!.toBeInTheDocument()
})
it.each([
@@ -507,7 +515,7 @@ describe('Header', () => {
render()
const input = screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder')
- expect(input).toHaveValue(inputValue)
+ expect(input)!.toHaveValue(inputValue)
})
})
@@ -525,7 +533,7 @@ describe('Header', () => {
render()
// Assert - Component should render successfully, meaning props are passed correctly
- expect(screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder')).toBeInTheDocument()
+ expect(screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder'))!.toBeInTheDocument()
})
it('should pass correct props to Input component', () => {
@@ -540,7 +549,7 @@ describe('Header', () => {
render()
const input = screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder')
- expect(input).toHaveValue('test-input')
+ expect(input)!.toHaveValue('test-input')
// Test onChange handler
fireEvent.change(input, { target: { value: 'new-value' } })
diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/file-list/header/breadcrumbs/__tests__/bucket.spec.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/file-list/header/breadcrumbs/__tests__/bucket.spec.tsx
index c407be51ac..83e17e6e04 100644
--- a/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/file-list/header/breadcrumbs/__tests__/bucket.spec.tsx
+++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/file-list/header/breadcrumbs/__tests__/bucket.spec.tsx
@@ -22,18 +22,18 @@ describe('Bucket', () => {
it('should render bucket name', () => {
render()
- expect(screen.getByText('my-bucket')).toBeInTheDocument()
+ expect(screen.getByText('my-bucket'))!.toBeInTheDocument()
})
it('should render bucket icon', () => {
render()
- expect(screen.getByTestId('buckets-gray')).toBeInTheDocument()
+ expect(screen.getByTestId('buckets-gray'))!.toBeInTheDocument()
})
it('should call handleBackToBucketList on icon button click', () => {
render()
const buttons = screen.getAllByRole('button')
- fireEvent.click(buttons[0])
+ fireEvent.click(buttons[0]!)
expect(defaultProps.handleBackToBucketList).toHaveBeenCalledOnce()
})
diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/file-list/header/breadcrumbs/__tests__/index.spec.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/file-list/header/breadcrumbs/__tests__/index.spec.tsx
index a6aaf3a50b..906a9e01e0 100644
--- a/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/file-list/header/breadcrumbs/__tests__/index.spec.tsx
+++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/file-list/header/breadcrumbs/__tests__/index.spec.tsx
@@ -1,4 +1,4 @@
-import { fireEvent, render, screen, waitFor } from '@testing-library/react'
+import { fireEvent, render, screen } from '@testing-library/react'
import * as React from 'react'
import Breadcrumbs from '../index'
@@ -44,6 +44,16 @@ const resetMockStoreState = () => {
mockStoreState.setBucket = vi.fn()
}
+const getDropdownTrigger = () => {
+ return document.querySelector('[aria-haspopup="menu"]') as HTMLElement | null
+}
+
+const openCollapsedBreadcrumbDropdown = () => {
+ const dropdownTrigger = getDropdownTrigger()
+ expect(dropdownTrigger).toBeInTheDocument()
+ fireEvent.click(dropdownTrigger as HTMLElement)
+}
+
describe('Breadcrumbs', () => {
beforeEach(() => {
vi.clearAllMocks()
@@ -58,7 +68,7 @@ describe('Breadcrumbs', () => {
// Assert - Container should be in the document
const container = document.querySelector('.flex.grow')
- expect(container).toBeInTheDocument()
+ expect(container)!.toBeInTheDocument()
})
it('should render with correct container styles', () => {
@@ -67,10 +77,10 @@ describe('Breadcrumbs', () => {
const { container } = render()
const wrapper = container.firstChild as HTMLElement
- expect(wrapper).toHaveClass('flex')
- expect(wrapper).toHaveClass('grow')
- expect(wrapper).toHaveClass('items-center')
- expect(wrapper).toHaveClass('overflow-hidden')
+ expect(wrapper)!.toHaveClass('flex')
+ expect(wrapper)!.toHaveClass('grow')
+ expect(wrapper)!.toHaveClass('items-center')
+ expect(wrapper)!.toHaveClass('overflow-hidden')
})
describe('Search Results Display', () => {
@@ -84,7 +94,7 @@ describe('Breadcrumbs', () => {
render()
// Assert - Search result text should be displayed
- expect(screen.getByText(/datasetPipeline\.onlineDrive\.breadcrumbs\.searchResult/)).toBeInTheDocument()
+ expect(screen.getByText(/datasetPipeline\.onlineDrive\.breadcrumbs\.searchResult/))!.toBeInTheDocument()
})
it('should not show search results when keywords is empty', () => {
@@ -121,7 +132,7 @@ describe('Breadcrumbs', () => {
render()
// Assert - Should use bucket name in search result
- expect(screen.getByText(/searchResult.*my-bucket/i)).toBeInTheDocument()
+ expect(screen.getByText(/searchResult.*my-bucket/i))!.toBeInTheDocument()
})
it('should use last breadcrumb as folderName when breadcrumbs exist', () => {
@@ -135,7 +147,7 @@ describe('Breadcrumbs', () => {
render()
// Assert - Should use last breadcrumb in search result
- expect(screen.getByText(/searchResult.*folder2/i)).toBeInTheDocument()
+ expect(screen.getByText(/searchResult.*folder2/i))!.toBeInTheDocument()
})
})
@@ -150,7 +163,7 @@ describe('Breadcrumbs', () => {
render()
- expect(screen.getByText('datasetPipeline.onlineDrive.breadcrumbs.allBuckets')).toBeInTheDocument()
+ expect(screen.getByText('datasetPipeline.onlineDrive.breadcrumbs.allBuckets'))!.toBeInTheDocument()
})
it('should not show all buckets title when breadcrumbs exist', () => {
@@ -190,7 +234,7 @@ describe('Breadcrumbs', () => {
render()
// Assert - Bucket name should be displayed
- expect(screen.getByText('test-bucket')).toBeInTheDocument()
+ expect(screen.getByText('test-bucket'))!.toBeInTheDocument()
})
it('should not render Bucket when hasBucket is false', () => {
@@ -217,7 +293,7 @@ describe('Breadcrumbs', () => {
render()
// Assert - "All Files" should be displayed
- expect(screen.getByText('datasetPipeline.onlineDrive.breadcrumbs.allFiles')).toBeInTheDocument()
+ expect(screen.getByText('datasetPipeline.onlineDrive.breadcrumbs.allFiles'))!.toBeInTheDocument()
})
it('should not render Drive component when hasBucket is true', () => {
@@ -243,8 +320,8 @@ describe('Breadcrumbs', () => {
render()
- expect(screen.getByText('folder1')).toBeInTheDocument()
- expect(screen.getByText('folder2')).toBeInTheDocument()
+ expect(screen.getByText('folder1'))!.toBeInTheDocument()
+ expect(screen.getByText('folder2'))!.toBeInTheDocument()
})
it('should render last breadcrumb as active', () => {
@@ -257,8 +334,8 @@ describe('Breadcrumbs', () => {
// Assert - Last breadcrumb should have active styles
const lastBreadcrumb = screen.getByText('folder2')
- expect(lastBreadcrumb).toHaveClass('system-sm-medium')
- expect(lastBreadcrumb).toHaveClass('text-text-secondary')
+ expect(lastBreadcrumb)!.toHaveClass('system-sm-medium')
+ expect(lastBreadcrumb)!.toHaveClass('text-text-secondary')
})
it('should render non-last breadcrumbs with tertiary styles', () => {
@@ -271,8 +348,8 @@ describe('Breadcrumbs', () => {
// Assert - First breadcrumb should have tertiary styles
const firstBreadcrumb = screen.getByText('folder1')
- expect(firstBreadcrumb).toHaveClass('system-sm-regular')
- expect(firstBreadcrumb).toHaveClass('text-text-tertiary')
+ expect(firstBreadcrumb)!.toHaveClass('system-sm-regular')
+ expect(firstBreadcrumb)!.toHaveClass('text-text-tertiary')
})
})
@@ -287,7 +364,7 @@ describe('Breadcrumbs', () => {
render()
// Assert - Dropdown trigger (more button) should be present
- expect(screen.getByRole('button', { name: '' })).toBeInTheDocument()
+ expect(screen.getByRole('button', { name: '' }))!.toBeInTheDocument()
})
it('should not show dropdown when breadcrumbs do not exceed displayBreadcrumbNum', () => {
@@ -301,8 +379,8 @@ describe('Breadcrumbs', () => {
// Assert - Should not have dropdown, just regular breadcrumbs
// All breadcrumbs should be directly visible
- expect(screen.getByText('folder1')).toBeInTheDocument()
- expect(screen.getByText('folder2')).toBeInTheDocument()
+ expect(screen.getByText('folder1'))!.toBeInTheDocument()
+ expect(screen.getByText('folder2'))!.toBeInTheDocument()
// Count buttons - should be 3 (allFiles + folder1 + folder2)
const buttons = container.querySelectorAll('button')
expect(buttons.length).toBe(3)
@@ -318,9 +398,41 @@ describe('Breadcrumbs', () => {
render()
// Assert - First breadcrumb and last breadcrumb should be visible
- expect(screen.getByText('folder1')).toBeInTheDocument()
- expect(screen.getByText('folder2')).toBeInTheDocument()
- expect(screen.getByText('folder5')).toBeInTheDocument()
+ expect(screen.getByText('folder1')).toBeInTheDocument()
+ expect(screen.getByText('folder2')).toBeInTheDocument()
+ expect(screen.getByText('folder5')).toBeInTheDocument()
// Middle breadcrumbs should be in dropdown
expect(screen.queryByText('folder3')).not.toBeInTheDocument()
expect(screen.queryByText('folder4')).not.toBeInTheDocument()
@@ -335,15 +447,11 @@ describe('Breadcrumbs', () => {
render()
// Act - Click on dropdown trigger (the ... button)
- const dropdownTrigger = screen.getAllByRole('button').find(btn => btn.querySelector('svg'))
- if (dropdownTrigger)
- fireEvent.click(dropdownTrigger)
+ openCollapsedBreadcrumbDropdown()
// Assert - Collapsed breadcrumbs should be visible
- await waitFor(() => {
- expect(screen.getByText('folder3')).toBeInTheDocument()
- expect(screen.getByText('folder4')).toBeInTheDocument()
- })
+ expect(await screen.findByText('folder3')).toBeInTheDocument()
+ expect(await screen.findByText('folder4')).toBeInTheDocument()
})
})
})
@@ -357,7 +465,8 @@ describe('Breadcrumbs', () => {
render()
// Assert - Only Drive should be visible
- expect(screen.getByText('datasetPipeline.onlineDrive.breadcrumbs.allFiles')).toBeInTheDocument()
+ expect(screen.getByText('datasetPipeline.onlineDrive.breadcrumbs.allFiles')).toBeInTheDocument()
})
it('should handle single breadcrumb', () => {
@@ -366,7 +475,7 @@ describe('Breadcrumbs', () => {
render()
- expect(screen.getByText('single-folder')).toBeInTheDocument()
+ expect(screen.getByText('single-folder')).toBeInTheDocument()
})
it('should handle breadcrumbs with special characters', () => {
@@ -377,8 +486,8 @@ describe('Breadcrumbs', () => {
render()
- expect(screen.getByText('folder [1]')).toBeInTheDocument()
- expect(screen.getByText('folder (copy)')).toBeInTheDocument()
+ expect(screen.getByText('folder [1]')).toBeInTheDocument()
+ expect(screen.getByText('folder (copy)')).toBeInTheDocument()
})
it('should handle breadcrumbs with unicode characters', () => {
@@ -389,8 +498,8 @@ describe('Breadcrumbs', () => {
render()
- expect(screen.getByText('文件夹')).toBeInTheDocument()
- expect(screen.getByText('フォルダ')).toBeInTheDocument()
+ expect(screen.getByText('文件夹')).toBeInTheDocument()
+ expect(screen.getByText('フォルダ')).toBeInTheDocument()
})
})
@@ -403,7 +512,7 @@ describe('Breadcrumbs', () => {
render()
- expect(screen.getByText(/searchResult/)).toBeInTheDocument()
+ expect(screen.getByText(/searchResult/)).toBeInTheDocument()
})
it('should handle whitespace keywords', () => {
@@ -415,7 +524,8 @@ describe('Breadcrumbs', () => {
render()
// Assert - Whitespace is truthy, so should show search results
- expect(screen.getByText(/searchResult/)).toBeInTheDocument()
+ expect(screen.getByText(/searchResult/)).toBeInTheDocument()
})
})
@@ -428,7 +538,7 @@ describe('Breadcrumbs', () => {
render()
- expect(screen.getByText('production-bucket')).toBeInTheDocument()
+ expect(screen.getByText('production-bucket')).toBeInTheDocument()
})
it('should handle bucket with special characters', () => {
@@ -439,7 +549,7 @@ describe('Breadcrumbs', () => {
render()
- expect(screen.getByText('bucket-v2.0_backup')).toBeInTheDocument()
+ expect(screen.getByText('bucket-v2.0_backup')).toBeInTheDocument()
})
})
@@ -452,6 +562,37 @@ describe('Breadcrumbs', () => {
render()
// Assert - Should not show search results
expect(screen.queryByText(/searchResult/)).not.toBeInTheDocument()
})
@@ -464,7 +605,7 @@ describe('Breadcrumbs', () => {
render()
- expect(screen.getByText(/searchResult.*10000/)).toBeInTheDocument()
+ expect(screen.getByText(/searchResult.*10000/)).toBeInTheDocument()
})
})
@@ -480,9 +621,7 @@ describe('Breadcrumbs', () => {
// Assert - Should collapse because 3 > 2
// Dropdown should be present
- const buttons = screen.getAllByRole('button')
- const hasDropdownTrigger = buttons.some(btn => btn.querySelector('svg'))
- expect(hasDropdownTrigger).toBe(true)
+ expect(getDropdownTrigger()).toBeInTheDocument()
})
it('should use displayBreadcrumbNum=3 when isInPipeline is false', () => {
@@ -495,9 +634,10 @@ describe('Breadcrumbs', () => {
render()
// Assert - Should NOT collapse because 3 <= 3
- expect(screen.getByText('folder1')).toBeInTheDocument()
- expect(screen.getByText('folder2')).toBeInTheDocument()
- expect(screen.getByText('folder3')).toBeInTheDocument()
+ expect(screen.getByText('folder1')).toBeInTheDocument()
+ expect(screen.getByText('folder2')).toBeInTheDocument()
+ expect(screen.getByText('folder3')).toBeInTheDocument()
})
it('should reduce displayBreadcrumbNum by 1 when bucket is set', () => {
@@ -511,9 +651,7 @@ describe('Breadcrumbs', () => {
render()
// Assert - Should collapse because 3 > 2
- const buttons = screen.getAllByRole('button')
- const hasDropdownTrigger = buttons.some(btn => btn.querySelector('svg'))
- expect(hasDropdownTrigger).toBe(true)
+ expect(getDropdownTrigger()).toBeInTheDocument()
})
})
})
@@ -533,9 +671,11 @@ describe('Breadcrumbs', () => {
// Assert - displayBreadcrumbNum = 3, so 4 breadcrumbs should collapse
// First 2 visible, dropdown, last 1 visible
- expect(screen.getByText('a')).toBeInTheDocument()
- expect(screen.getByText('b')).toBeInTheDocument()
- expect(screen.getByText('d')).toBeInTheDocument()
+ expect(screen.getByText('a')).toBeInTheDocument()
+ expect(screen.getByText('b')).toBeInTheDocument()
+ expect(screen.getByText('d')).toBeInTheDocument()
expect(screen.queryByText('c')).not.toBeInTheDocument()
})
@@ -550,8 +690,9 @@ describe('Breadcrumbs', () => {
render()
// Assert - displayBreadcrumbNum = 2, so 3 breadcrumbs should collapse
- expect(screen.getByText('a')).toBeInTheDocument()
- expect(screen.getByText('c')).toBeInTheDocument()
+ expect(screen.getByText('a')).toBeInTheDocument()
+ expect(screen.getByText('c')).toBeInTheDocument()
expect(screen.queryByText('b')).not.toBeInTheDocument()
})
@@ -566,8 +707,9 @@ describe('Breadcrumbs', () => {
render()
// Assert - displayBreadcrumbNum = 3 - 1 = 2, so 3 breadcrumbs should collapse
- expect(screen.getByText('a')).toBeInTheDocument()
- expect(screen.getByText('c')).toBeInTheDocument()
+ expect(screen.getByText('a')).toBeInTheDocument()
+ expect(screen.getByText('c')).toBeInTheDocument()
expect(screen.queryByText('b')).not.toBeInTheDocument()
})
})
@@ -582,9 +724,7 @@ describe('Breadcrumbs', () => {
render()
// Act - Click dropdown to see collapsed items
- const dropdownTrigger = screen.getAllByRole('button').find(btn => btn.querySelector('svg'))
- if (dropdownTrigger)
- fireEvent.click(dropdownTrigger)
+ openCollapsedBreadcrumbDropdown()
// prefixBreadcrumbs = ['f1', 'f2']
// collapsedBreadcrumbs = ['f3', 'f4']
@@ -592,10 +732,8 @@ describe('Breadcrumbs', () => {
expect(screen.getByText('f1')).toBeInTheDocument()
expect(screen.getByText('f2')).toBeInTheDocument()
expect(screen.getByText('f5')).toBeInTheDocument()
- await waitFor(() => {
- expect(screen.getByText('f3')).toBeInTheDocument()
- expect(screen.getByText('f4')).toBeInTheDocument()
- })
+ expect(await screen.findByText('f3')).toBeInTheDocument()
+ expect(await screen.findByText('f4')).toBeInTheDocument()
})
it('should not collapse when breadcrumbs.length <= displayBreadcrumbNum', () => {
@@ -608,8 +746,9 @@ describe('Breadcrumbs', () => {
render()
// Assert - All breadcrumbs should be visible
- expect(screen.getByText('f1')).toBeInTheDocument()
- expect(screen.getByText('f2')).toBeInTheDocument()
+ expect(screen.getByText('f1')).toBeInTheDocument()
+ expect(screen.getByText('f2')).toBeInTheDocument()
})
})
})
@@ -627,7 +766,7 @@ describe('Breadcrumbs', () => {
// Act - Click bucket icon button (first button in Bucket component)
const buttons = screen.getAllByRole('button')
- fireEvent.click(buttons[0]) // Bucket icon button
+ fireEvent.click(buttons[0]!) // Bucket icon button
expect(mockStoreState.setOnlineDriveFileList).toHaveBeenCalledWith([])
expect(mockStoreState.setSelectedFileIds).toHaveBeenCalledWith([])
@@ -739,15 +878,8 @@ describe('Breadcrumbs', () => {
render()
// Act - Open dropdown and click on collapsed breadcrumb (f3, index=2)
- const dropdownTrigger = screen.getAllByRole('button').find(btn => btn.querySelector('svg'))
- if (dropdownTrigger)
- fireEvent.click(dropdownTrigger)
-
- await waitFor(() => {
- expect(screen.getByText('f3')).toBeInTheDocument()
- })
-
- fireEvent.click(screen.getByText('f3'))
+ openCollapsedBreadcrumbDropdown()
+ fireEvent.click(await screen.findByText('f3'))
// Assert - Should slice to index 2 + 1 = 3
expect(mockStoreState.setBreadcrumbs).toHaveBeenCalledWith(['f1', 'f2', 'f3'])
@@ -771,19 +903,19 @@ describe('Breadcrumbs', () => {
// Assert - Component should render without errors
const container = document.querySelector('.flex.grow')
- expect(container).toBeInTheDocument()
+ expect(container).toBeInTheDocument()
})
it('should re-render when breadcrumbs change', () => {
mockStoreState.hasBucket = false
const props = createDefaultProps({ breadcrumbs: ['folder1'] })
const { rerender } = render()
- expect(screen.getByText('folder1')).toBeInTheDocument()
+ expect(screen.getByText('folder1')).toBeInTheDocument()
// Act - Rerender with different breadcrumbs
rerender()
- expect(screen.getByText('folder2')).toBeInTheDocument()
+ expect(screen.getByText('folder2')).toBeInTheDocument()
})
})
@@ -798,7 +930,7 @@ describe('Breadcrumbs', () => {
render()
- expect(screen.getByText(longName)).toBeInTheDocument()
+ expect(screen.getByText(longName)).toBeInTheDocument()
})
it('should handle many breadcrumbs', async () => {
@@ -810,17 +942,13 @@ describe('Breadcrumbs', () => {
render()
// Act - Open dropdown
- const dropdownTrigger = screen.getAllByRole('button').find(btn => btn.querySelector('svg'))
- if (dropdownTrigger)
- fireEvent.click(dropdownTrigger)
+ openCollapsedBreadcrumbDropdown()
// Assert - First, last, and collapsed should be accessible
expect(screen.getByText('folder-0')).toBeInTheDocument()
expect(screen.getByText('folder-1')).toBeInTheDocument()
expect(screen.getByText('folder-19')).toBeInTheDocument()
- await waitFor(() => {
- expect(screen.getByText('folder-2')).toBeInTheDocument()
- })
+ expect(await screen.findByText('folder-2')).toBeInTheDocument()
})
it('should handle empty bucket string', () => {
@@ -833,7 +961,8 @@ describe('Breadcrumbs', () => {
render()
// Assert - Should show all buckets title
- expect(screen.getByText('datasetPipeline.onlineDrive.breadcrumbs.allBuckets')).toBeInTheDocument()
+ expect(screen.getByText('datasetPipeline.onlineDrive.breadcrumbs.allBuckets')).toBeInTheDocument()
})
it('should handle breadcrumb with only whitespace', () => {
@@ -845,7 +974,8 @@ describe('Breadcrumbs', () => {
render()
// Assert - Both should be rendered
- expect(screen.getByText('normal-folder')).toBeInTheDocument()
+ expect(screen.getByText('normal-folder')).toBeInTheDocument()
})
})
@@ -863,7 +993,7 @@ describe('Breadcrumbs', () => {
// Assert - Component should render without errors
const container = document.querySelector('.flex.grow')
- expect(container).toBeInTheDocument()
+ expect(container).toBeInTheDocument()
})
it.each([
@@ -879,9 +1009,7 @@ describe('Breadcrumbs', () => {
render()
// Assert - Should collapse because breadcrumbs.length > expectedNum
- const buttons = screen.getAllByRole('button')
- const hasDropdownTrigger = buttons.some(btn => btn.querySelector('svg'))
- expect(hasDropdownTrigger).toBe(true)
+ expect(getDropdownTrigger()).toBeInTheDocument()
})
})
@@ -916,7 +1044,8 @@ describe('Breadcrumbs', () => {
render()
// Assert - Search result should be shown, navigation elements should be hidden
- expect(screen.getByText(/searchResult/)).toBeInTheDocument()
+ expect(screen.getByText(/searchResult/)).toBeInTheDocument()
expect(screen.queryByText('my-bucket')).not.toBeInTheDocument()
})
})
diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/file-list/header/breadcrumbs/dropdown/__tests__/index.spec.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/file-list/header/breadcrumbs/dropdown/__tests__/index.spec.tsx
index 0157d3cf79..d57e8340e9 100644
--- a/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/file-list/header/breadcrumbs/dropdown/__tests__/index.spec.tsx
+++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/file-list/header/breadcrumbs/dropdown/__tests__/index.spec.tsx
@@ -23,7 +23,8 @@ describe('Dropdown', () => {
render()
// Assert - Trigger button should be visible
- expect(screen.getByRole('button')).toBeInTheDocument()
+ expect(screen.getByRole('button')).toBeInTheDocument()
})
it('should render trigger button with more icon', () => {
@@ -31,10 +32,10 @@ describe('Dropdown', () => {
const { container } = render()
- // Assert - Button should have RiMoreFill icon (rendered as svg)
+ // Assert - Button should have the more icon
const button = screen.getByRole('button')
expect(button).toBeInTheDocument()
- expect(container.querySelector('svg')).toBeInTheDocument()
+ expect(container.querySelector('.i-ri-more-fill')).toBeInTheDocument()
})
it('should render separator after dropdown', () => {
@@ -43,7 +44,8 @@ describe('Dropdown', () => {
render()
// Assert - Separator "/" should be visible
- expect(screen.getByText('/')).toBeInTheDocument()
+ expect(screen.getByText('/')).toBeInTheDocument()
})
it('should render trigger button with correct default styles', () => {
@@ -52,11 +54,11 @@ describe('Dropdown', () => {
render()
const button = screen.getByRole('button')
- expect(button).toHaveClass('flex')
- expect(button).toHaveClass('size-6')
- expect(button).toHaveClass('items-center')
- expect(button).toHaveClass('justify-center')
- expect(button).toHaveClass('rounded-md')
+ expect(button).toHaveClass('flex')
+ expect(button).toHaveClass('size-6')
+ expect(button).toHaveClass('items-center')
+ expect(button).toHaveClass('justify-center')
+ expect(button).toHaveClass('rounded-md')
})
it('should not render menu content when closed', () => {
@@ -64,6 +66,37 @@ describe('Dropdown', () => {
render()
// Assert - Menu content should not be visible when dropdown is closed
expect(screen.queryByText('visible-folder')).not.toBeInTheDocument()
})
@@ -77,8 +110,8 @@ describe('Dropdown', () => {
// Assert - Menu items should be visible
await waitFor(() => {
- expect(screen.getByText('test-folder1')).toBeInTheDocument()
- expect(screen.getByText('test-folder2')).toBeInTheDocument()
+ expect(screen.getByText('test-folder1')).toBeInTheDocument()
+ expect(screen.getByText('test-folder2')).toBeInTheDocument()
})
})
})
@@ -98,7 +131,7 @@ describe('Dropdown', () => {
fireEvent.click(screen.getByRole('button'))
await waitFor(() => {
- expect(screen.getByText('folder1')).toBeInTheDocument()
+ expect(screen.getByText('folder1')).toBeInTheDocument()
})
fireEvent.click(screen.getByText('folder1'))
@@ -120,7 +153,7 @@ describe('Dropdown', () => {
fireEvent.click(screen.getByRole('button'))
await waitFor(() => {
- expect(screen.getByText('folder2')).toBeInTheDocument()
+ expect(screen.getByText('folder2')).toBeInTheDocument()
})
fireEvent.click(screen.getByText('folder2'))
@@ -140,9 +173,9 @@ describe('Dropdown', () => {
fireEvent.click(screen.getByRole('button'))
await waitFor(() => {
- expect(screen.getByText('folder-a')).toBeInTheDocument()
- expect(screen.getByText('folder-b')).toBeInTheDocument()
- expect(screen.getByText('folder-c')).toBeInTheDocument()
+ expect(screen.getByText('folder-a')).toBeInTheDocument()
+ expect(screen.getByText('folder-b')).toBeInTheDocument()
+ expect(screen.getByText('folder-c')).toBeInTheDocument()
})
})
@@ -155,7 +188,7 @@ describe('Dropdown', () => {
fireEvent.click(screen.getByRole('button'))
await waitFor(() => {
- expect(screen.getByText('single-folder')).toBeInTheDocument()
+ expect(screen.getByText('single-folder')).toBeInTheDocument()
})
})
@@ -170,7 +203,8 @@ describe('Dropdown', () => {
// Assert - Menu should be rendered but with no items
await waitFor(() => {
// The menu container should exist but be empty
- expect(screen.getByRole('button')).toBeInTheDocument()
+ expect(screen.getByRole('button')).toBeInTheDocument()
})
})
@@ -183,9 +217,9 @@ describe('Dropdown', () => {
fireEvent.click(screen.getByRole('button'))
await waitFor(() => {
- expect(screen.getByText('folder [1]')).toBeInTheDocument()
- expect(screen.getByText('folder (copy)')).toBeInTheDocument()
- expect(screen.getByText('folder-v2.0')).toBeInTheDocument()
+ expect(screen.getByText('folder [1]')).toBeInTheDocument()
+ expect(screen.getByText('folder (copy)')).toBeInTheDocument()
+ expect(screen.getByText('folder-v2.0')).toBeInTheDocument()
})
})
@@ -198,9 +232,9 @@ describe('Dropdown', () => {
fireEvent.click(screen.getByRole('button'))
await waitFor(() => {
- expect(screen.getByText('文件夹')).toBeInTheDocument()
- expect(screen.getByText('フォルダ')).toBeInTheDocument()
- expect(screen.getByText('Папка')).toBeInTheDocument()
+ expect(screen.getByText('文件夹')).toBeInTheDocument()
+ expect(screen.getByText('フォルダ')).toBeInTheDocument()
+ expect(screen.getByText('Папка')).toBeInTheDocument()
})
})
})
@@ -218,7 +252,7 @@ describe('Dropdown', () => {
fireEvent.click(screen.getByRole('button'))
await waitFor(() => {
- expect(screen.getByText('folder1')).toBeInTheDocument()
+ expect(screen.getByText('folder1')).toBeInTheDocument()
})
fireEvent.click(screen.getByText('folder1'))
@@ -236,6 +270,37 @@ describe('Dropdown', () => {
render()
// Assert - Menu content should not be visible
expect(screen.queryByText('test-folder')).not.toBeInTheDocument()
})
@@ -247,7 +312,7 @@ describe('Dropdown', () => {
fireEvent.click(screen.getByRole('button'))
await waitFor(() => {
- expect(screen.getByText('test-folder')).toBeInTheDocument()
+ expect(screen.getByText('test-folder')).toBeInTheDocument()
})
})
@@ -258,7 +323,7 @@ describe('Dropdown', () => {
// Act - Open and then close
fireEvent.click(screen.getByRole('button'))
await waitFor(() => {
- expect(screen.getByText('test-folder')).toBeInTheDocument()
+ expect(screen.getByText('test-folder')).toBeInTheDocument()
})
fireEvent.click(screen.getByRole('button'))
@@ -280,7 +345,7 @@ describe('Dropdown', () => {
fireEvent.click(screen.getByRole('button'))
await waitFor(() => {
- expect(screen.getByText('test-folder')).toBeInTheDocument()
+ expect(screen.getByText('test-folder')).toBeInTheDocument()
})
fireEvent.click(screen.getByText('test-folder'))
@@ -297,14 +362,15 @@ describe('Dropdown', () => {
const button = screen.getByRole('button')
// Assert - Initial state (closed): should have hover:bg-state-base-hover
- expect(button).toHaveClass('hover:bg-state-base-hover')
+ expect(button).toHaveClass('hover:bg-state-base-hover')
// Act - Open dropdown
fireEvent.click(button)
// Assert - Open state: should have bg-state-base-hover
await waitFor(() => {
- expect(button).toHaveClass('bg-state-base-hover')
+ expect(button).toHaveClass('bg-state-base-hover')
})
})
})
@@ -317,6 +383,37 @@ describe('Dropdown', () => {
const props = createDefaultProps({ breadcrumbs: ['folder'] })
render()
// Act & Assert - Initially closed
expect(screen.queryByText('folder')).not.toBeInTheDocument()
@@ -325,7 +422,7 @@ describe('Dropdown', () => {
// Assert - Now open
await waitFor(() => {
- expect(screen.getByText('folder')).toBeInTheDocument()
+ expect(screen.getByText('folder')).toBeInTheDocument()
})
})
@@ -338,7 +435,7 @@ describe('Dropdown', () => {
// 1st click - open
fireEvent.click(button)
await waitFor(() => {
- expect(screen.getByText('folder')).toBeInTheDocument()
+ expect(screen.getByText('folder')).toBeInTheDocument()
})
// 2nd click - close
@@ -350,7 +447,7 @@ describe('Dropdown', () => {
// 3rd click - open again
fireEvent.click(button)
await waitFor(() => {
- expect(screen.getByText('folder')).toBeInTheDocument()
+ expect(screen.getByText('folder')).toBeInTheDocument()
})
})
})
@@ -368,7 +465,7 @@ describe('Dropdown', () => {
fireEvent.click(screen.getByRole('button'))
await waitFor(() => {
- expect(screen.getByText('folder1')).toBeInTheDocument()
+ expect(screen.getByText('folder1')).toBeInTheDocument()
})
fireEvent.click(screen.getByText('folder1'))
@@ -394,7 +491,7 @@ describe('Dropdown', () => {
fireEvent.click(screen.getByRole('button'))
await waitFor(() => {
- expect(screen.getByText('folder1')).toBeInTheDocument()
+ expect(screen.getByText('folder1')).toBeInTheDocument()
})
fireEvent.click(screen.getByText('folder1'))
@@ -423,7 +520,7 @@ describe('Dropdown', () => {
// Act - Open and click
fireEvent.click(screen.getByRole('button'))
await waitFor(() => {
- expect(screen.getByText('folder')).toBeInTheDocument()
+ expect(screen.getByText('folder')).toBeInTheDocument()
})
fireEvent.click(screen.getByText('folder'))
@@ -431,7 +528,7 @@ describe('Dropdown', () => {
rerender()
fireEvent.click(screen.getByRole('button'))
await waitFor(() => {
- expect(screen.getByText('folder')).toBeInTheDocument()
+ expect(screen.getByText('folder')).toBeInTheDocument()
})
fireEvent.click(screen.getByText('folder'))
@@ -450,7 +547,7 @@ describe('Dropdown', () => {
// Act - Open and click with first callback
fireEvent.click(screen.getByRole('button'))
await waitFor(() => {
- expect(screen.getByText('folder')).toBeInTheDocument()
+ expect(screen.getByText('folder')).toBeInTheDocument()
})
fireEvent.click(screen.getByText('folder'))
@@ -466,7 +563,7 @@ describe('Dropdown', () => {
// Open and click with second callback
fireEvent.click(screen.getByRole('button'))
await waitFor(() => {
- expect(screen.getByText('folder')).toBeInTheDocument()
+ expect(screen.getByText('folder')).toBeInTheDocument()
})
fireEvent.click(screen.getByText('folder'))
@@ -482,7 +579,8 @@ describe('Dropdown', () => {
rerender()
// Assert - Component should render without errors
- expect(screen.getByRole('button')).toBeInTheDocument()
+ expect(screen.getByRole('button')).toBeInTheDocument()
})
})
@@ -499,7 +597,7 @@ describe('Dropdown', () => {
// Assert - Should handle gracefully (open after odd number of clicks)
await waitFor(() => {
- expect(screen.getByText('folder')).toBeInTheDocument()
+ expect(screen.getByText('folder')).toBeInTheDocument()
})
})
@@ -513,7 +611,7 @@ describe('Dropdown', () => {
fireEvent.click(screen.getByRole('button'))
await waitFor(() => {
- expect(screen.getByText(longName)).toBeInTheDocument()
+ expect(screen.getByText(longName)).toBeInTheDocument()
})
})
@@ -528,8 +626,8 @@ describe('Dropdown', () => {
// Assert - First and last items should be visible
await waitFor(() => {
- expect(screen.getByText('folder-0')).toBeInTheDocument()
- expect(screen.getByText('folder-19')).toBeInTheDocument()
+ expect(screen.getByText('folder-0')).toBeInTheDocument()
+ expect(screen.getByText('folder-19')).toBeInTheDocument()
})
})
@@ -544,7 +642,7 @@ describe('Dropdown', () => {
fireEvent.click(screen.getByRole('button'))
await waitFor(() => {
- expect(screen.getByText('folder')).toBeInTheDocument()
+ expect(screen.getByText('folder')).toBeInTheDocument()
})
fireEvent.click(screen.getByText('folder'))
@@ -562,7 +660,7 @@ describe('Dropdown', () => {
fireEvent.click(screen.getByRole('button'))
await waitFor(() => {
- expect(screen.getByText('folder')).toBeInTheDocument()
+ expect(screen.getByText('folder'))!.toBeInTheDocument()
})
fireEvent.click(screen.getByText('folder'))
@@ -578,7 +676,7 @@ describe('Dropdown', () => {
fireEvent.click(screen.getByRole('button'))
await waitFor(() => {
- expect(screen.getByText('normal-folder')).toBeInTheDocument()
+ expect(screen.getByText('normal-folder'))!.toBeInTheDocument()
})
})
@@ -591,7 +689,7 @@ describe('Dropdown', () => {
fireEvent.click(screen.getByRole('button'))
await waitFor(() => {
- expect(screen.getByText('folder')).toBeInTheDocument()
+ expect(screen.getByText('folder'))!.toBeInTheDocument()
})
})
})
@@ -613,9 +711,9 @@ describe('Dropdown', () => {
fireEvent.click(screen.getByRole('button'))
await waitFor(() => {
- expect(screen.getByText(breadcrumbs[0])).toBeInTheDocument()
+ expect(screen.getByText(breadcrumbs[0]!))!.toBeInTheDocument()
})
- fireEvent.click(screen.getByText(breadcrumbs[0]))
+ fireEvent.click(screen.getByText(breadcrumbs[0]!))
expect(mockOnBreadcrumbClick).toHaveBeenCalledWith(expectedIndex)
})
@@ -634,7 +732,7 @@ describe('Dropdown', () => {
// Assert - Should render without errors
await waitFor(() => {
if (breadcrumbs.length > 0)
- expect(screen.getByText(breadcrumbs[0])).toBeInTheDocument()
+ expect(screen.getByText(breadcrumbs[0]!))!.toBeInTheDocument()
})
})
})
@@ -650,9 +748,9 @@ describe('Dropdown', () => {
fireEvent.click(screen.getByRole('button'))
await waitFor(() => {
- expect(screen.getByText('Documents')).toBeInTheDocument()
- expect(screen.getByText('Projects')).toBeInTheDocument()
- expect(screen.getByText('Archive')).toBeInTheDocument()
+ expect(screen.getByText('Documents'))!.toBeInTheDocument()
+ expect(screen.getByText('Projects'))!.toBeInTheDocument()
+ expect(screen.getByText('Archive'))!.toBeInTheDocument()
})
})
@@ -668,7 +766,7 @@ describe('Dropdown', () => {
// Act - Open and click on second item
fireEvent.click(screen.getByRole('button'))
await waitFor(() => {
- expect(screen.getByText('second')).toBeInTheDocument()
+ expect(screen.getByText('second'))!.toBeInTheDocument()
})
fireEvent.click(screen.getByText('second'))
@@ -687,7 +785,7 @@ describe('Dropdown', () => {
// Act - Open and click on middle item
fireEvent.click(screen.getByRole('button'))
await waitFor(() => {
- expect(screen.getByText('item2')).toBeInTheDocument()
+ expect(screen.getByText('item2'))!.toBeInTheDocument()
})
fireEvent.click(screen.getByText('item2'))
@@ -714,7 +812,7 @@ describe('Dropdown', () => {
fireEvent.click(screen.getByRole('button'))
await waitFor(() => {
- expect(screen.getByText(`folder-${String.fromCharCode(97 + i)}`)).toBeInTheDocument()
+ expect(screen.getByText(`folder-${String.fromCharCode(97 + i)}`))!.toBeInTheDocument()
})
fireEvent.click(screen.getByText(`folder-${String.fromCharCode(97 + i)}`))
@@ -731,7 +829,7 @@ describe('Dropdown', () => {
render()
const button = screen.getByRole('button')
- expect(button).toBeInTheDocument()
+ expect(button)!.toBeInTheDocument()
expect(button.tagName).toBe('BUTTON')
})
@@ -741,7 +839,7 @@ describe('Dropdown', () => {
render()
const button = screen.getByRole('button')
- expect(button).toHaveAttribute('type', 'button')
+ expect(button)!.toHaveAttribute('type', 'button')
})
})
})
diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/file-list/header/breadcrumbs/dropdown/index.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/file-list/header/breadcrumbs/dropdown/index.tsx
index 7178b45b34..43b5fcc71a 100644
--- a/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/file-list/header/breadcrumbs/dropdown/index.tsx
+++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/file-list/header/breadcrumbs/dropdown/index.tsx
@@ -1,12 +1,11 @@
import { cn } from '@langgenius/dify-ui/cn'
-import { RiMoreFill } from '@remixicon/react'
+import {
+ DropdownMenu,
+ DropdownMenuContent,
+ DropdownMenuTrigger,
+} from '@langgenius/dify-ui/dropdown-menu'
import * as React from 'react'
import { useCallback, useState } from 'react'
-import {
- PortalToFollowElem,
- PortalToFollowElemContent,
- PortalToFollowElemTrigger,
-} from '@/app/components/base/portal-to-follow-elem'
import Menu from './menu'
type DropdownProps = {
@@ -22,26 +21,17 @@ const Dropdown = ({
}: DropdownProps) => {
const [open, setOpen] = useState(false)
- const handleTrigger = useCallback(() => {
- setOpen(prev => !prev)
- }, [])
-
const handleBreadCrumbClick = useCallback((index: number) => {
onBreadcrumbClick(index)
setOpen(false)
}, [onBreadcrumbClick])
return (
-
-
+ }>
-
-
+
+
-
+
/
-
+
)
}
diff --git a/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/file-list/header/breadcrumbs/dropdown/item.tsx b/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/file-list/header/breadcrumbs/dropdown/item.tsx
index 864cade85c..6f04ede88a 100644
--- a/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/file-list/header/breadcrumbs/dropdown/item.tsx
+++ b/web/app/components/datasets/documents/create-from-pipeline/data-source/online-drive/file-list/header/breadcrumbs/dropdown/item.tsx
@@ -18,7 +18,7 @@ const Item = ({
return (