diff --git a/.agents/skills/frontend-testing/SKILL.md b/.agents/skills/frontend-testing/SKILL.md
index 4da070bdbf..105c979c58 100644
--- a/.agents/skills/frontend-testing/SKILL.md
+++ b/.agents/skills/frontend-testing/SKILL.md
@@ -200,7 +200,7 @@ When assigned to test a directory/path, test **ALL content** within that path:
- ✅ **Import real project components** directly (including base components and siblings)
- ✅ **Only mock**: API services (`@/service/*`), `next/navigation`, complex context providers
-- ❌ **DO NOT mock** base components (`@/app/components/base/*`)
+- ❌ **DO NOT mock** base components (`@/app/components/base/*`) or dify-ui primitives (`@langgenius/dify-ui/*`)
- ❌ **DO NOT mock** sibling/child components in the same directory
> See [Test Structure Template](#test-structure-template) for correct import/mock patterns.
@@ -325,12 +325,12 @@ For more detailed information, refer to:
### Reference Examples in Codebase
- `web/utils/classnames.spec.ts` - Utility function tests
-- `web/app/components/base/button/index.spec.tsx` - Component tests
+- `web/app/components/base/radio/__tests__/index.spec.tsx` - Component tests
- `web/__mocks__/provider-context.ts` - Mock factory example
### Project Configuration
-- `web/vitest.config.ts` - Vitest configuration
+- `web/vite.config.ts` - Vite/Vitest configuration
- `web/vitest.setup.ts` - Test environment setup
- `web/scripts/analyze-component.js` - Component analysis tool
- Modules are not mocked automatically. Global mocks live in `web/vitest.setup.ts` (for example `react-i18next`, `next/image`); mock other modules like `ky` or `mime` locally in test files.
diff --git a/.agents/skills/frontend-testing/references/checklist.md b/.agents/skills/frontend-testing/references/checklist.md
index 10b8fb66f9..99258498dd 100644
--- a/.agents/skills/frontend-testing/references/checklist.md
+++ b/.agents/skills/frontend-testing/references/checklist.md
@@ -36,7 +36,7 @@ Use this checklist when generating or reviewing tests for Dify frontend componen
### Integration vs Mocking
-- [ ] **DO NOT mock base components** (`Loading`, `Button`, `Tooltip`, etc.)
+- [ ] **DO NOT mock base components or dify-ui primitives** (base `Loading`, `Input`, `Badge`; dify-ui `Button`, `Tooltip`, `Dialog`, etc.)
- [ ] Import real project components instead of mocking
- [ ] Only mock: API calls, complex context providers, third-party libs with side effects
- [ ] Prefer integration testing when using single spec file
@@ -73,7 +73,7 @@ Use this checklist when generating or reviewing tests for Dify frontend componen
### Mocks
-- [ ] **DO NOT mock base components** (`@/app/components/base/*`)
+- [ ] **DO NOT mock base components or dify-ui primitives** (`@/app/components/base/*` or `@langgenius/dify-ui/*`)
- [ ] `vi.clearAllMocks()` in `beforeEach` (not `afterEach`)
- [ ] Shared mock state reset in `beforeEach`
- [ ] i18n uses global mock (auto-loaded in `web/vitest.setup.ts`); only override locally for custom translations
diff --git a/.agents/skills/frontend-testing/references/mocking.md b/.agents/skills/frontend-testing/references/mocking.md
index f58377c4a5..8c2f1c0c58 100644
--- a/.agents/skills/frontend-testing/references/mocking.md
+++ b/.agents/skills/frontend-testing/references/mocking.md
@@ -2,29 +2,27 @@
## ⚠️ Important: What NOT to Mock
-### DO NOT Mock Base Components
+### DO NOT Mock Base Components or dify-ui Primitives
-**Never mock components from `@/app/components/base/`** such as:
+**Never mock components from `@/app/components/base/` or from `@langgenius/dify-ui/*`** such as:
-- `Loading`, `Spinner`
-- `Button`, `Input`, `Select`
-- `Tooltip`, `Modal`, `Dropdown`
-- `Icon`, `Badge`, `Tag`
+- Legacy base (`@/app/components/base/*`): `Loading`, `Spinner`, `Input`, `Badge`, `Tag`
+- dify-ui primitives (`@langgenius/dify-ui/*`): `Button`, `Tooltip`, `Dialog`, `Popover`, `DropdownMenu`, `ContextMenu`, `Select`, `AlertDialog`, `Toast`
**Why?**
-- Base components will have their own dedicated tests
+- These components have their own dedicated tests
- Mocking them creates false positives (tests pass but real integration fails)
- Using real components tests actual integration behavior
```typescript
-// ❌ WRONG: Don't mock base components
+// ❌ WRONG: Don't mock base components or dify-ui primitives
vi.mock('@/app/components/base/loading', () => () =>
  <div>Loading</div>
)
-vi.mock('@/app/components/base/button', () => ({ children }: any) => <button>{children}</button>)
+vi.mock('@langgenius/dify-ui/button', () => ({ Button: ({ children }: any) => <button>{children}</button> }))
-// ✅ CORRECT: Import and use real base components
+// ✅ CORRECT: Import and use the real components
import Loading from '@/app/components/base/loading'
-import Button from '@/app/components/base/button'
+import { Button } from '@langgenius/dify-ui/button'
// They will render normally in tests
```
@@ -319,7 +317,7 @@ const renderWithQueryClient = (ui: React.ReactElement) => {
### ✅ DO
-1. **Use real base components** - Import from `@/app/components/base/` directly
+1. **Use real base components and dify-ui primitives** - Import from `@/app/components/base/` or `@langgenius/dify-ui/*` directly
1. **Use real project components** - Prefer importing over mocking
1. **Use real Zustand stores** - Set test state via `store.setState()`
1. **Reset mocks in `beforeEach`**, not `afterEach`
@@ -330,7 +328,7 @@ const renderWithQueryClient = (ui: React.ReactElement) => {
### ❌ DON'T
-1. **Don't mock base components** (`Loading`, `Button`, `Tooltip`, etc.)
+1. **Don't mock base components or dify-ui primitives** (`Loading`, `Input`, `Button`, `Tooltip`, `Dialog`, etc.)
1. **Don't mock Zustand store modules** - Use real stores with `setState()`
1. Don't mock components you can import directly
1. Don't create overly simplified mocks that miss conditional logic
@@ -342,7 +340,7 @@ const renderWithQueryClient = (ui: React.ReactElement) => {
```
Need to use a component in test?
│
-├─ Is it from @/app/components/base/*?
+├─ Is it from @/app/components/base/* or @langgenius/dify-ui/*?
│ └─ YES → Import real component, DO NOT mock
│
├─ Is it a project component?
diff --git a/.github/workflows/anti-slop.yml b/.github/workflows/anti-slop.yml
deleted file mode 100644
index b0f0a36bc9..0000000000
--- a/.github/workflows/anti-slop.yml
+++ /dev/null
@@ -1,19 +0,0 @@
-name: Anti-Slop PR Check
-
-on:
- pull_request_target:
- types: [opened, edited, synchronize]
-
-permissions:
- pull-requests: write
- contents: read
-
-jobs:
- anti-slop:
- runs-on: ubuntu-latest
- steps:
- - uses: peakoss/anti-slop@85daca1880e9e1af197fc06ea03349daf08f4202 # v0.2.1
- with:
- github-token: ${{ secrets.GITHUB_TOKEN }}
- close-pr: false
- failure-add-pr-labels: "needs-revision"
diff --git a/.github/workflows/api-tests.yml b/.github/workflows/api-tests.yml
index fd910531db..717413937f 100644
--- a/.github/workflows/api-tests.yml
+++ b/.github/workflows/api-tests.yml
@@ -35,7 +35,7 @@ jobs:
persist-credentials: false
- name: Setup UV and Python
- uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
+ uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
with:
enable-cache: true
python-version: ${{ matrix.python-version }}
@@ -84,7 +84,7 @@ jobs:
persist-credentials: false
- name: Setup UV and Python
- uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
+ uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
with:
enable-cache: true
python-version: ${{ matrix.python-version }}
@@ -105,7 +105,7 @@ jobs:
run: sh .github/workflows/expose_service_ports.sh
- name: Set up Sandbox
- uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0
+ uses: hoverkraft-tech/compose-action@d2bee4f07e8ca410d6b196d00f90c12e7d48c33a # v2.6.0
with:
compose-file: |
docker/docker-compose.middleware.yaml
@@ -156,7 +156,7 @@ jobs:
persist-credentials: false
- name: Setup UV and Python
- uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
+ uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
with:
enable-cache: true
python-version: "3.12"
diff --git a/.github/workflows/autofix.yml b/.github/workflows/autofix.yml
index 3946834e09..35683b112f 100644
--- a/.github/workflows/autofix.yml
+++ b/.github/workflows/autofix.yml
@@ -25,7 +25,7 @@ jobs:
- name: Check Docker Compose inputs
if: github.event_name != 'merge_group'
id: docker-compose-changes
- uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5
+ uses: tj-actions/changed-files@9426d40962ed5378910ee2e21d5f8c6fcbf2dd96 # v47.0.6
with:
files: |
docker/generate_docker_compose
@@ -35,7 +35,7 @@ jobs:
- name: Check web inputs
if: github.event_name != 'merge_group'
id: web-changes
- uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5
+ uses: tj-actions/changed-files@9426d40962ed5378910ee2e21d5f8c6fcbf2dd96 # v47.0.6
with:
files: |
web/**
@@ -48,7 +48,7 @@ jobs:
- name: Check api inputs
if: github.event_name != 'merge_group'
id: api-changes
- uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5
+ uses: tj-actions/changed-files@9426d40962ed5378910ee2e21d5f8c6fcbf2dd96 # v47.0.6
with:
files: |
api/**
@@ -58,7 +58,7 @@ jobs:
python-version: "3.11"
- if: github.event_name != 'merge_group'
- uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
+ uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
- name: Generate Docker Compose
if: github.event_name != 'merge_group' && steps.docker-compose-changes.outputs.any_changed == 'true'
@@ -123,4 +123,4 @@ jobs:
vp exec eslint --concurrency=2 --prune-suppressions --quiet || true
- if: github.event_name != 'merge_group'
- uses: autofix-ci/action@7a166d7532b277f34e16238930461bf77f9d7ed8 # v1.3.3
+ uses: autofix-ci/action@c5b2d67aa2274e7b5a18224e8171550871fc7e4a # v1.3.4
diff --git a/.github/workflows/db-migration-test.yml b/.github/workflows/db-migration-test.yml
index 5991abe3ba..17b867dd6d 100644
--- a/.github/workflows/db-migration-test.yml
+++ b/.github/workflows/db-migration-test.yml
@@ -19,7 +19,7 @@ jobs:
persist-credentials: false
- name: Setup UV and Python
- uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
+ uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
with:
enable-cache: true
python-version: "3.12"
@@ -40,7 +40,7 @@ jobs:
cp middleware.env.example middleware.env
- name: Set up Middlewares
- uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0
+ uses: hoverkraft-tech/compose-action@d2bee4f07e8ca410d6b196d00f90c12e7d48c33a # v2.6.0
with:
compose-file: |
docker/docker-compose.middleware.yaml
@@ -69,7 +69,7 @@ jobs:
persist-credentials: false
- name: Setup UV and Python
- uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
+ uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
with:
enable-cache: true
python-version: "3.12"
@@ -94,7 +94,7 @@ jobs:
sed -i 's/DB_USERNAME=postgres/DB_USERNAME=mysql/' middleware.env
- name: Set up Middlewares
- uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0
+ uses: hoverkraft-tech/compose-action@d2bee4f07e8ca410d6b196d00f90c12e7d48c33a # v2.6.0
with:
compose-file: |
docker/docker-compose.middleware.yaml
diff --git a/.github/workflows/pyrefly-diff.yml b/.github/workflows/pyrefly-diff.yml
index ac3732579c..eb15cd6f75 100644
--- a/.github/workflows/pyrefly-diff.yml
+++ b/.github/workflows/pyrefly-diff.yml
@@ -22,7 +22,7 @@ jobs:
fetch-depth: 0
- name: Setup Python & UV
- uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
+ uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
with:
enable-cache: true
diff --git a/.github/workflows/pyrefly-type-coverage-comment.yml b/.github/workflows/pyrefly-type-coverage-comment.yml
index 974da99aad..3c6c96a664 100644
--- a/.github/workflows/pyrefly-type-coverage-comment.yml
+++ b/.github/workflows/pyrefly-type-coverage-comment.yml
@@ -24,7 +24,7 @@ jobs:
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
- name: Setup Python & UV
- uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
+ uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
with:
enable-cache: true
diff --git a/.github/workflows/pyrefly-type-coverage.yml b/.github/workflows/pyrefly-type-coverage.yml
index c795c32e31..0599c94eef 100644
--- a/.github/workflows/pyrefly-type-coverage.yml
+++ b/.github/workflows/pyrefly-type-coverage.yml
@@ -22,7 +22,7 @@ jobs:
fetch-depth: 0
- name: Setup Python & UV
- uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
+ uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
with:
enable-cache: true
diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml
index 29f5b090f8..d8c7ebbad3 100644
--- a/.github/workflows/style.yml
+++ b/.github/workflows/style.yml
@@ -25,7 +25,7 @@ jobs:
- name: Check changed files
id: changed-files
- uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5
+ uses: tj-actions/changed-files@9426d40962ed5378910ee2e21d5f8c6fcbf2dd96 # v47.0.6
with:
files: |
api/**
@@ -33,7 +33,7 @@ jobs:
- name: Setup UV and Python
if: steps.changed-files.outputs.any_changed == 'true'
- uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
+ uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
with:
enable-cache: false
python-version: "3.12"
@@ -73,7 +73,7 @@ jobs:
- name: Check changed files
id: changed-files
- uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5
+ uses: tj-actions/changed-files@9426d40962ed5378910ee2e21d5f8c6fcbf2dd96 # v47.0.6
with:
files: |
web/**
@@ -95,7 +95,7 @@ jobs:
- name: Restore ESLint cache
if: steps.changed-files.outputs.any_changed == 'true'
id: eslint-cache-restore
- uses: actions/cache/restore@668228422ae6a00e4ad889ee87cd7109ec5666a7 # v5.0.4
+ uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
with:
path: .eslintcache
key: ${{ runner.os }}-eslint-${{ hashFiles('pnpm-lock.yaml', 'eslint.config.mjs', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-${{ github.sha }}
@@ -124,7 +124,7 @@ jobs:
- name: Save ESLint cache
if: steps.changed-files.outputs.any_changed == 'true' && success() && steps.eslint-cache-restore.outputs.cache-hit != 'true'
- uses: actions/cache/save@668228422ae6a00e4ad889ee87cd7109ec5666a7 # v5.0.4
+ uses: actions/cache/save@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
with:
path: .eslintcache
key: ${{ steps.eslint-cache-restore.outputs.cache-primary-key }}
@@ -142,7 +142,7 @@ jobs:
- name: Check changed files
id: changed-files
- uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5
+ uses: tj-actions/changed-files@9426d40962ed5378910ee2e21d5f8c6fcbf2dd96 # v47.0.6
with:
files: |
**.sh
diff --git a/.github/workflows/tool-test-sdks.yaml b/.github/workflows/tool-test-sdks.yaml
index 467f31fccf..bf33207a14 100644
--- a/.github/workflows/tool-test-sdks.yaml
+++ b/.github/workflows/tool-test-sdks.yaml
@@ -30,7 +30,7 @@ jobs:
persist-credentials: false
- name: Use Node.js
- uses: actions/setup-node@53b83947a5a98c8d113130e565377fae1a50d02f # v6.3.0
+ uses: actions/setup-node@48b55a011bda9f5d6aeb4c2d9c7362e8dae4041e # v6.4.0
with:
node-version: 22
cache: ''
diff --git a/.github/workflows/translate-i18n-claude.yml b/.github/workflows/translate-i18n-claude.yml
index 541200293d..eecbbb1a56 100644
--- a/.github/workflows/translate-i18n-claude.yml
+++ b/.github/workflows/translate-i18n-claude.yml
@@ -158,7 +158,7 @@ jobs:
- name: Run Claude Code for Translation Sync
if: steps.context.outputs.CHANGED_FILES != ''
- uses: anthropics/claude-code-action@b47fd721da662d48c5680e154ad16a73ed74d2e0 # v1.0.93
+ uses: anthropics/claude-code-action@38ec876110f9fbf8b950c79f534430740c3ac009 # v1.0.101
with:
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
github_token: ${{ secrets.GITHUB_TOKEN }}
diff --git a/.github/workflows/vdb-tests-full.yml b/.github/workflows/vdb-tests-full.yml
index f0def8fe7a..b79e8927d7 100644
--- a/.github/workflows/vdb-tests-full.yml
+++ b/.github/workflows/vdb-tests-full.yml
@@ -36,7 +36,7 @@ jobs:
remove_tool_cache: true
- name: Setup UV and Python
- uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
+ uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
with:
enable-cache: true
python-version: ${{ matrix.python-version }}
@@ -65,7 +65,7 @@ jobs:
# tiflash
- name: Set up Full Vector Store Matrix
- uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0
+ uses: hoverkraft-tech/compose-action@d2bee4f07e8ca410d6b196d00f90c12e7d48c33a # v2.6.0
with:
compose-file: |
docker/docker-compose.yaml
diff --git a/.github/workflows/vdb-tests.yml b/.github/workflows/vdb-tests.yml
index f3966f15b9..bd13d662c3 100644
--- a/.github/workflows/vdb-tests.yml
+++ b/.github/workflows/vdb-tests.yml
@@ -33,7 +33,7 @@ jobs:
remove_tool_cache: true
- name: Setup UV and Python
- uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
+ uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
with:
enable-cache: true
python-version: ${{ matrix.python-version }}
@@ -62,7 +62,7 @@ jobs:
# tiflash
- name: Set up Vector Stores for Smoke Coverage
- uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0
+ uses: hoverkraft-tech/compose-action@d2bee4f07e8ca410d6b196d00f90c12e7d48c33a # v2.6.0
with:
compose-file: |
docker/docker-compose.yaml
diff --git a/.github/workflows/web-e2e.yml b/.github/workflows/web-e2e.yml
index 10dc31bde8..6bd4d4f406 100644
--- a/.github/workflows/web-e2e.yml
+++ b/.github/workflows/web-e2e.yml
@@ -28,7 +28,7 @@ jobs:
uses: ./.github/actions/setup-web
- name: Setup UV and Python
- uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
+ uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
with:
enable-cache: true
python-version: "3.12"
diff --git a/.gitignore b/.gitignore
index 3493a7c756..836bddbb49 100644
--- a/.gitignore
+++ b/.gitignore
@@ -237,6 +237,10 @@ scripts/stress-test/reports/
.playwright-mcp/
.serena/
+# vitest browser mode attachments (failure screenshots, traces, etc.)
+.vitest-attachments/
+**/__screenshots__/
+
# settings
*.local.json
*.local.md
diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py
index 051d08aa36..9102983d86 100644
--- a/api/controllers/console/app/app.py
+++ b/api/controllers/console/app/app.py
@@ -129,6 +129,7 @@ class AppNamePayload(BaseModel):
class AppIconPayload(BaseModel):
icon: str | None = Field(default=None, description="Icon data")
+ icon_type: IconType | None = Field(default=None, description="Icon type")
icon_background: str | None = Field(default=None, description="Icon background color")
@@ -729,7 +730,12 @@ class AppIconApi(Resource):
args = AppIconPayload.model_validate(console_ns.payload or {})
app_service = AppService()
- app_model = app_service.update_app_icon(app_model, args.icon or "", args.icon_background or "")
+ app_model = app_service.update_app_icon(
+ app_model,
+ args.icon or "",
+ args.icon_background or "",
+ args.icon_type,
+ )
response_model = AppDetail.model_validate(app_model, from_attributes=True)
return response_model.model_dump(mode="json")
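With `icon_type` added to `AppIconPayload`, emoji and image icons can be distinguished at update time. A hedged sketch of the resulting request handling (field names come from the payload model above; the `"emoji"` value for `IconType` is an assumption for illustration, not confirmed by this diff):

```python
# Hypothetical request body for the icon endpoint after this change; assumes
# IconType accepts an "emoji" member.
payload = AppIconPayload.model_validate(
    {"icon": "🤖", "icon_type": "emoji", "icon_background": "#FFEAD5"}
)
app_model = app_service.update_app_icon(
    app_model,
    payload.icon or "",
    payload.icon_background or "",
    payload.icon_type,
)
```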
diff --git a/api/controllers/console/app/conversation_variables.py b/api/controllers/console/app/conversation_variables.py
index cead33d14f..9c8b095b9f 100644
--- a/api/controllers/console/app/conversation_variables.py
+++ b/api/controllers/console/app/conversation_variables.py
@@ -45,7 +45,7 @@ class ConversationVariableResponse(ResponseModel):
def _normalize_value_type(cls, value: Any) -> str:
exposed_type = getattr(value, "exposed_type", None)
if callable(exposed_type):
- return str(exposed_type().value)
+ return str(exposed_type())
if isinstance(value, str):
return value
try:
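The same `.value` drop recurs in the hunks below (`workflow_draft_variable.py` and `service_api/app/conversation.py`). A plausible reading, assuming `exposed_type()` now returns a `StrEnum`-style member, is that `str()` already yields the raw value, making `.value` redundant:

```python
from enum import StrEnum

# Assumed stand-in for the type exposed_type() returns: for a StrEnum member,
# str(member) and member.value are the same string.
class ExposedType(StrEnum):
    STRING = "string"
    NUMBER = "number"

assert str(ExposedType.STRING) == "string"
assert str(ExposedType.STRING) == ExposedType.STRING.value
```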
diff --git a/api/controllers/console/app/workflow_draft_variable.py b/api/controllers/console/app/workflow_draft_variable.py
index f6319573e0..e32ba5f66c 100644
--- a/api/controllers/console/app/workflow_draft_variable.py
+++ b/api/controllers/console/app/workflow_draft_variable.py
@@ -102,7 +102,7 @@ def _serialize_var_value(variable: WorkflowDraftVariable):
def _serialize_variable_type(workflow_draft_var: WorkflowDraftVariable) -> str:
value_type = workflow_draft_var.value_type
- return value_type.exposed_type().value
+ return str(value_type.exposed_type())
class FullContentDict(TypedDict):
@@ -122,7 +122,7 @@ def _serialize_full_content(variable: WorkflowDraftVariable) -> FullContentDict
result: FullContentDict = {
"size_bytes": variable_file.size,
- "value_type": variable_file.value_type.exposed_type().value,
+ "value_type": str(variable_file.value_type.exposed_type()),
"length": variable_file.length,
"download_url": file_helpers.get_signed_file_url(variable_file.upload_file_id, as_attachment=True),
}
@@ -598,7 +598,7 @@ class EnvironmentVariableCollectionApi(Resource):
"name": v.name,
"description": v.description,
"selector": v.selector,
- "value_type": v.value_type.exposed_type().value,
+ "value_type": str(v.value_type.exposed_type()),
"value": v.value,
# Do not track edited for env vars.
"edited": False,
diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py
index ea0fdef0a7..d001dfba64 100644
--- a/api/controllers/console/datasets/datasets.py
+++ b/api/controllers/console/datasets/datasets.py
@@ -50,6 +50,7 @@ from fields.dataset_fields import (
from fields.document_fields import document_status_fields
from graphon.model_runtime.entities.model_entities import ModelType
from libs.login import current_account_with_tenant, login_required
+from libs.url_utils import normalize_api_base_url
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
from models.dataset import DatasetPermission, DatasetPermissionEnum
from models.enums import ApiTokenType, SegmentStatus
@@ -889,7 +890,8 @@ class DatasetApiBaseUrlApi(Resource):
@login_required
@account_initialization_required
def get(self):
- return {"api_base_url": (dify_config.SERVICE_API_URL or request.host_url.rstrip("/")) + "/v1"}
+ base = dify_config.SERVICE_API_URL or request.host_url.rstrip("/")
+ return {"api_base_url": normalize_api_base_url(base)}
@console_ns.route("/datasets/retrieval-setting")
diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py
index 44404005b2..c01286cc59 100644
--- a/api/controllers/console/workspace/account.py
+++ b/api/controllers/console/workspace/account.py
@@ -595,13 +595,25 @@ class ChangeEmailSendEmailApi(Resource):
account = None
user_email = None
email_for_sending = args.email.lower()
- if args.phase is not None and args.phase == "new_email":
+ # Default to the initial phase; any legacy/unexpected client input is
+ # coerced back to `old_email` so we never trust the caller to declare
+ # later phases without a verified predecessor token.
+ send_phase = AccountService.CHANGE_EMAIL_PHASE_OLD
+ if args.phase is not None and args.phase == AccountService.CHANGE_EMAIL_PHASE_NEW:
+ send_phase = AccountService.CHANGE_EMAIL_PHASE_NEW
if args.token is None:
raise InvalidTokenError()
reset_data = AccountService.get_change_email_data(args.token)
if reset_data is None:
raise InvalidTokenError()
+
+ # The token used to request a new-email code must come from the
+ # old-email verification step. This prevents the bypass described
+ # in GHSA-4q3w-q5mc-45rq where the phase-1 token was reused here.
+ token_phase = reset_data.get(AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY)
+ if token_phase != AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED:
+ raise InvalidTokenError()
user_email = reset_data.get("email", "")
if user_email.lower() != current_user.email.lower():
@@ -620,7 +632,7 @@ class ChangeEmailSendEmailApi(Resource):
email=email_for_sending,
old_email=user_email,
language=language,
- phase=args.phase,
+ phase=send_phase,
)
return {"result": "success", "data": token}
@@ -655,12 +667,31 @@ class ChangeEmailCheckApi(Resource):
AccountService.add_change_email_error_rate_limit(user_email)
raise EmailCodeError()
+ # Only advance tokens that were minted by the matching send-code step;
+ # refuse tokens that have already progressed or lack a phase marker so
+ # the chain `old_email -> old_email_verified -> new_email -> new_email_verified`
+ # is strictly enforced.
+ phase_transitions = {
+ AccountService.CHANGE_EMAIL_PHASE_OLD: AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED,
+ AccountService.CHANGE_EMAIL_PHASE_NEW: AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED,
+ }
+ token_phase = token_data.get(AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY)
+ if not isinstance(token_phase, str):
+ raise InvalidTokenError()
+ refreshed_phase = phase_transitions.get(token_phase)
+ if refreshed_phase is None:
+ raise InvalidTokenError()
+
# Verified, revoke the first token
AccountService.revoke_change_email_token(args.token)
- # Refresh token data by generating a new token
+ # Refresh token data by generating a new token that carries the
+ # upgraded phase so later steps can check it.
_, new_token = AccountService.generate_change_email_token(
- user_email, code=args.code, old_email=token_data.get("old_email"), additional_data={}
+ user_email,
+ code=args.code,
+ old_email=token_data.get("old_email"),
+ additional_data={AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: refreshed_phase},
)
AccountService.reset_change_email_error_rate_limit(user_email)
@@ -690,13 +721,29 @@ class ChangeEmailResetApi(Resource):
if not reset_data:
raise InvalidTokenError()
- AccountService.revoke_change_email_token(args.token)
+ # Only tokens that completed both verification phases may be used to
+ # change the email. This closes GHSA-4q3w-q5mc-45rq where a token from
+ # the initial send-code step could be replayed directly here.
+ token_phase = reset_data.get(AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY)
+ if token_phase != AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED:
+ raise InvalidTokenError()
+
+ # Bind the new email to the token that was mailed and verified, so a
+ # verified token cannot be reused with a different `new_email` value.
+ token_email = reset_data.get("email")
+ normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email
+ if normalized_token_email != normalized_new_email:
+ raise InvalidTokenError()
old_email = reset_data.get("old_email", "")
current_user, _ = current_account_with_tenant()
if current_user.email.lower() != old_email.lower():
raise AccountNotFound()
+ # Revoke only after all checks pass so failed attempts don't burn a
+ # legitimately verified token.
+ AccountService.revoke_change_email_token(args.token)
+
updated_account = AccountService.update_account_email(current_user, email=normalized_new_email)
AccountService.send_change_email_completed_notify_email(
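Taken together, the three endpoints now enforce a linear token chain, as the comments spell out. A minimal model of that state machine (the phase strings mirror the chain named in the comment above; the actual `AccountService` constant values are not shown in this diff):

```python
# Phase chain from the comment above:
# old_email -> old_email_verified -> new_email -> new_email_verified
PHASE_OLD = "old_email"
PHASE_OLD_VERIFIED = "old_email_verified"
PHASE_NEW = "new_email"
PHASE_NEW_VERIFIED = "new_email_verified"

# Each check step only advances tokens minted by the matching send step.
TRANSITIONS = {
    PHASE_OLD: PHASE_OLD_VERIFIED,
    PHASE_NEW: PHASE_NEW_VERIFIED,
}

def advance_phase(token_phase: str | None) -> str:
    if not isinstance(token_phase, str) or token_phase not in TRANSITIONS:
        # Replayed, already-advanced, or phase-less tokens are rejected,
        # closing the GHSA-4q3w-q5mc-45rq bypass.
        raise ValueError("invalid token phase")
    return TRANSITIONS[token_phase]

assert advance_phase(PHASE_OLD) == PHASE_OLD_VERIFIED
assert advance_phase(PHASE_NEW) == PHASE_NEW_VERIFIED
```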
diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py
index 471594f349..34c9534de8 100644
--- a/api/controllers/console/workspace/tool_providers.py
+++ b/api/controllers/console/workspace/tool_providers.py
@@ -1131,6 +1131,14 @@ class ToolMCPAuthApi(Resource):
with sessionmaker(db.engine).begin() as session:
service = MCPToolManageService(session=session)
service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id)
+ parsed = urlparse(server_url)
+ sanitized_url = f"{parsed.scheme}://{parsed.hostname}{parsed.path}"
+ logger.warning(
+ "MCP authorization failed for provider %s (url=%s)",
+ provider_id,
+ sanitized_url,
+ exc_info=True,
+ )
raise ValueError(f"Failed to connect to MCP server: {e}") from e
diff --git a/api/controllers/service_api/app/conversation.py b/api/controllers/service_api/app/conversation.py
index c4353ca7b8..ca4b18cb5e 100644
--- a/api/controllers/service_api/app/conversation.py
+++ b/api/controllers/service_api/app/conversation.py
@@ -84,10 +84,10 @@ class ConversationVariableResponse(ResponseModel):
def normalize_value_type(cls, value: Any) -> str:
exposed_type = getattr(value, "exposed_type", None)
if callable(exposed_type):
- return str(exposed_type().value)
+ return str(exposed_type())
if isinstance(value, str):
try:
- return str(SegmentType(value).exposed_type().value)
+ return str(SegmentType(value).exposed_type())
except ValueError:
return value
try:
diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py
index 790602ef5d..c22102c2ba 100644
--- a/api/core/agent/base_agent_runner.py
+++ b/api/core/agent/base_agent_runner.py
@@ -42,7 +42,7 @@ from graphon.model_runtime.entities import (
)
from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
from graphon.model_runtime.entities.model_entities import ModelFeature
-from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
from models.enums import CreatorUserRole
from models.model import Conversation, Message, MessageAgentThought, MessageFile
diff --git a/api/core/app/app_config/easy_ui_based_app/model_config/converter.py b/api/core/app/app_config/easy_ui_based_app/model_config/converter.py
index dbd7527fc6..5df3df2b3e 100644
--- a/api/core/app/app_config/easy_ui_based_app/model_config/converter.py
+++ b/api/core/app/app_config/easy_ui_based_app/model_config/converter.py
@@ -7,7 +7,7 @@ from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotIni
from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
from graphon.model_runtime.entities.llm_entities import LLMMode
from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
-from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
class ModelConfigConverter:
diff --git a/api/core/app/apps/agent_chat/app_runner.py b/api/core/app/apps/agent_chat/app_runner.py
index 09ddce327e..cae0eee0df 100644
--- a/api/core/app/apps/agent_chat/app_runner.py
+++ b/api/core/app/apps/agent_chat/app_runner.py
@@ -18,7 +18,7 @@ from core.moderation.base import ModerationError
from extensions.ext_database import db
from graphon.model_runtime.entities.llm_entities import LLMMode
from graphon.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
-from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
from models.model import App, Conversation, Message
logger = logging.getLogger(__name__)
diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py
index dfe6133cb6..e2e07ebaff 100644
--- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py
+++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py
@@ -59,7 +59,7 @@ from graphon.model_runtime.entities.message_entities import (
AssistantPromptMessage,
TextPromptMessageContent,
)
-from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
from libs.datetime_utils import naive_utc_now
from models.model import AppMode, Conversation, Message, MessageAgentThought, MessageFile, UploadFile
diff --git a/api/core/app/workflow/file_runtime.py b/api/core/app/workflow/file_runtime.py
index 68e5e5f0c8..3a6f9d575a 100644
--- a/api/core/app/workflow/file_runtime.py
+++ b/api/core/app/workflow/file_runtime.py
@@ -12,13 +12,14 @@ from typing import TYPE_CHECKING, Literal
from configs import dify_config
from core.app.file_access import DatabaseFileAccessController, FileAccessControllerProtocol
from core.db.session_factory import session_factory
-from core.helper.ssrf_proxy import ssrf_proxy
+from core.helper.ssrf_proxy import graphon_ssrf_proxy
from core.tools.signature import sign_tool_file
from core.workflow.file_reference import parse_file_reference
from extensions.ext_storage import storage
from graphon.file import FileTransferMethod
-from graphon.file.protocols import HttpResponseProtocol, WorkflowFileRuntimeProtocol
+from graphon.file.protocols import WorkflowFileRuntimeProtocol
from graphon.file.runtime import set_workflow_file_runtime
+from graphon.http.protocols import HttpResponseProtocol
if TYPE_CHECKING:
from graphon.file import File
@@ -43,7 +44,7 @@ class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol):
return dify_config.MULTIMODAL_SEND_FORMAT
def http_get(self, url: str, *, follow_redirects: bool = True) -> HttpResponseProtocol:
- return ssrf_proxy.get(url, follow_redirects=follow_redirects)
+ return graphon_ssrf_proxy.get(url, follow_redirects=follow_redirects)
def storage_load(self, path: str, *, stream: bool = False) -> bytes | Generator:
return storage.load(path, stream=stream)
diff --git a/api/core/app/workflow/layers/persistence.py b/api/core/app/workflow/layers/persistence.py
index 87f005a250..d521304615 100644
--- a/api/core/app/workflow/layers/persistence.py
+++ b/api/core/app/workflow/layers/persistence.py
@@ -349,7 +349,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
execution.total_tokens = runtime_state.total_tokens
execution.total_steps = runtime_state.node_run_steps
execution.outputs = execution.outputs or runtime_state.outputs
- execution.exceptions_count = runtime_state.exceptions_count
+ execution.exceptions_count = max(execution.exceptions_count, runtime_state.exceptions_count)
def _update_node_execution(
self,
diff --git a/api/core/datasource/datasource_manager.py b/api/core/datasource/datasource_manager.py
index dc831e5cac..f0dcb13b62 100644
--- a/api/core/datasource/datasource_manager.py
+++ b/api/core/datasource/datasource_manager.py
@@ -352,11 +352,11 @@ class DatasourceManager:
raise ValueError(f"UploadFile not found for file_id={file_id}, tenant_id={tenant_id}")
file_info = File(
- id=upload_file.id,
+ file_id=upload_file.id,
filename=upload_file.name,
extension="." + upload_file.extension,
mime_type=upload_file.mime_type,
- type=FileType.CUSTOM,
+ file_type=FileType.CUSTOM,
transfer_method=FileTransferMethod.LOCAL_FILE,
remote_url=upload_file.source_url,
reference=build_file_reference(record_id=str(upload_file.id)),
diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py
index 6bbf163c9d..38b87e2cd1 100644
--- a/api/core/entities/provider_configuration.py
+++ b/api/core/entities/provider_configuration.py
@@ -31,7 +31,7 @@ from graphon.model_runtime.entities.provider_entities import (
FormType,
ProviderEntity,
)
-from graphon.model_runtime.model_providers.__base.ai_model import AIModel
+from graphon.model_runtime.model_providers.base.ai_model import AIModel
from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from graphon.model_runtime.runtime import ModelRuntime
from libs.datetime_utils import naive_utc_now
@@ -363,7 +363,7 @@ class ProviderConfiguration(BaseModel):
)
for key, value in validated_credentials.items():
- if key in provider_credential_secret_variables:
+ if key in provider_credential_secret_variables and isinstance(value, str):
validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
return validated_credentials
@@ -912,7 +912,7 @@ class ProviderConfiguration(BaseModel):
)
for key, value in validated_credentials.items():
- if key in provider_credential_secret_variables:
+ if key in provider_credential_secret_variables and isinstance(value, str):
validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
return validated_credentials
diff --git a/api/core/helper/code_executor/template_transformer.py b/api/core/helper/code_executor/template_transformer.py
index b96a9ce380..38864a1830 100644
--- a/api/core/helper/code_executor/template_transformer.py
+++ b/api/core/helper/code_executor/template_transformer.py
@@ -102,7 +102,7 @@ class TemplateTransformer(ABC):
@classmethod
def serialize_inputs(cls, inputs: Mapping[str, Any]) -> str:
- inputs_json_str = dumps_with_segments(inputs, ensure_ascii=False).encode()
+ inputs_json_str = dumps_with_segments(inputs).encode()
input_base64_encoded = b64encode(inputs_json_str).decode("utf-8")
return input_base64_encoded
diff --git a/api/core/helper/moderation.py b/api/core/helper/moderation.py
index dc37a36943..f169f247cf 100644
--- a/api/core/helper/moderation.py
+++ b/api/core/helper/moderation.py
@@ -8,7 +8,7 @@ from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_
from extensions.ext_hosting_provider import hosting_configuration
from graphon.model_runtime.entities.model_entities import ModelType
from graphon.model_runtime.errors.invoke import InvokeBadRequestError
-from graphon.model_runtime.model_providers.__base.moderation_model import ModerationModel
+from graphon.model_runtime.model_providers.base.moderation_model import ModerationModel
from models.provider import ProviderType
logger = logging.getLogger(__name__)
diff --git a/api/core/helper/ssrf_proxy.py b/api/core/helper/ssrf_proxy.py
index e38592bb7b..91e92712b7 100644
--- a/api/core/helper/ssrf_proxy.py
+++ b/api/core/helper/ssrf_proxy.py
@@ -12,6 +12,7 @@ from pydantic import TypeAdapter, ValidationError
from configs import dify_config
from core.helper.http_client_pooling import get_pooled_http_client
from core.tools.errors import ToolSSRFError
+from graphon.http.response import HttpResponse
logger = logging.getLogger(__name__)
@@ -267,4 +268,47 @@ class SSRFProxy:
return patch(url=url, max_retries=max_retries, **kwargs)
+def _to_graphon_http_response(response: httpx.Response) -> HttpResponse:
+ """Convert an ``httpx`` response into Graphon's transport-agnostic wrapper."""
+ return HttpResponse(
+ status_code=response.status_code,
+ headers=dict(response.headers),
+ content=response.content,
+ url=str(response.url) if response.url else None,
+ reason_phrase=response.reason_phrase,
+ fallback_text=response.text,
+ )
+
+
+class GraphonSSRFProxy:
+ """Adapter exposing SSRF helpers behind Graphon's ``HttpClientProtocol``."""
+
+ @property
+ def max_retries_exceeded_error(self) -> type[Exception]:
+ return max_retries_exceeded_error
+
+ @property
+ def request_error(self) -> type[Exception]:
+ return request_error
+
+ def get(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse:
+ return _to_graphon_http_response(get(url=url, max_retries=max_retries, **kwargs))
+
+ def head(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse:
+ return _to_graphon_http_response(head(url=url, max_retries=max_retries, **kwargs))
+
+ def post(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse:
+ return _to_graphon_http_response(post(url=url, max_retries=max_retries, **kwargs))
+
+ def put(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse:
+ return _to_graphon_http_response(put(url=url, max_retries=max_retries, **kwargs))
+
+ def delete(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse:
+ return _to_graphon_http_response(delete(url=url, max_retries=max_retries, **kwargs))
+
+ def patch(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse:
+ return _to_graphon_http_response(patch(url=url, max_retries=max_retries, **kwargs))
+
+
ssrf_proxy = SSRFProxy()
+graphon_ssrf_proxy = GraphonSSRFProxy()
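A usage sketch of the new adapter, matching the call site in `file_runtime.py` above: callers that need Graphon's transport-agnostic `HttpResponse` go through `graphon_ssrf_proxy`, while existing `httpx`-based callers keep using `ssrf_proxy`.

```python
# Hedged usage sketch within this codebase (not standalone-runnable): the
# adapter returns graphon.http.response.HttpResponse instead of httpx.Response.
from core.helper.ssrf_proxy import graphon_ssrf_proxy

response = graphon_ssrf_proxy.get("https://example.com/file.png", follow_redirects=True)
print(response.status_code, len(response.content))
```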
diff --git a/api/core/mcp/client/streamable_client.py b/api/core/mcp/client/streamable_client.py
index 5c3cd0d8f8..acba3e666b 100644
--- a/api/core/mcp/client/streamable_client.py
+++ b/api/core/mcp/client/streamable_client.py
@@ -303,9 +303,16 @@ class StreamableHTTPTransport:
if response.status_code == 404:
if isinstance(message.root, JSONRPCRequest):
+ error_msg = (
+ f"MCP server URL returned 404 Not Found: {self.url} "
+ "— verify the server URL is correct and the server is running"
+ if is_initialization
+ else "Session terminated by server"
+ )
self._send_session_terminated_error(
ctx.server_to_client_queue,
message.root.id,
+ message=error_msg,
)
return
@@ -381,12 +388,13 @@ class StreamableHTTPTransport:
self,
server_to_client_queue: ServerToClientQueue,
request_id: RequestId,
+ message: str = "Session terminated by server",
):
"""Send a session terminated error response."""
jsonrpc_error = JSONRPCError(
jsonrpc="2.0",
id=request_id,
- error=ErrorData(code=32600, message="Session terminated by server"),
+ error=ErrorData(code=32600, message=message),
)
session_message = SessionMessage(JSONRPCMessage(jsonrpc_error))
server_to_client_queue.put(session_message)
diff --git a/api/core/model_manager.py b/api/core/model_manager.py
index d8d8dfedd8..86d0e3baaa 100644
--- a/api/core/model_manager.py
+++ b/api/core/model_manager.py
@@ -1,6 +1,6 @@
import logging
from collections.abc import Callable, Generator, Iterable, Mapping, Sequence
-from typing import IO, Any, Literal, Optional, Union, cast, overload
+from typing import IO, Any, Literal, Optional, ParamSpec, TypeVar, Union, cast, overload
from configs import dify_config
from core.entities import PluginCredentialType
@@ -18,15 +18,17 @@ from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelFe
from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult
from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult
from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeConnectionError, InvokeRateLimitError
-from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
-from graphon.model_runtime.model_providers.__base.moderation_model import ModerationModel
-from graphon.model_runtime.model_providers.__base.rerank_model import RerankModel
-from graphon.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
-from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
-from graphon.model_runtime.model_providers.__base.tts_model import TTSModel
+from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
+from graphon.model_runtime.model_providers.base.moderation_model import ModerationModel
+from graphon.model_runtime.model_providers.base.rerank_model import RerankModel
+from graphon.model_runtime.model_providers.base.speech2text_model import Speech2TextModel
+from graphon.model_runtime.model_providers.base.text_embedding_model import TextEmbeddingModel
+from graphon.model_runtime.model_providers.base.tts_model import TTSModel
from models.provider import ProviderType
logger = logging.getLogger(__name__)
+P = ParamSpec("P")
+R = TypeVar("R")
class ModelInstance:
@@ -168,7 +170,7 @@ class ModelInstance:
return cast(
Union[LLMResult, Generator],
self._round_robin_invoke(
- function=self.model_type_instance.invoke,
+ self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
prompt_messages=list(prompt_messages),
@@ -193,7 +195,7 @@ class ModelInstance:
if not isinstance(self.model_type_instance, LargeLanguageModel):
raise Exception("Model type instance is not LargeLanguageModel")
return self._round_robin_invoke(
- function=self.model_type_instance.get_num_tokens,
+ self.model_type_instance.get_num_tokens,
model=self.model_name,
credentials=self.credentials,
prompt_messages=list(prompt_messages),
@@ -213,7 +215,7 @@ class ModelInstance:
if not isinstance(self.model_type_instance, TextEmbeddingModel):
raise Exception("Model type instance is not TextEmbeddingModel")
return self._round_robin_invoke(
- function=self.model_type_instance.invoke,
+ self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
texts=texts,
@@ -235,7 +237,7 @@ class ModelInstance:
if not isinstance(self.model_type_instance, TextEmbeddingModel):
raise Exception("Model type instance is not TextEmbeddingModel")
return self._round_robin_invoke(
- function=self.model_type_instance.invoke,
+ self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
multimodel_documents=multimodel_documents,
@@ -252,7 +254,7 @@ class ModelInstance:
if not isinstance(self.model_type_instance, TextEmbeddingModel):
raise Exception("Model type instance is not TextEmbeddingModel")
return self._round_robin_invoke(
- function=self.model_type_instance.get_num_tokens,
+ self.model_type_instance.get_num_tokens,
model=self.model_name,
credentials=self.credentials,
texts=texts,
@@ -277,7 +279,7 @@ class ModelInstance:
if not isinstance(self.model_type_instance, RerankModel):
raise Exception("Model type instance is not RerankModel")
return self._round_robin_invoke(
- function=self.model_type_instance.invoke,
+ self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
query=query,
@@ -305,7 +307,7 @@ class ModelInstance:
if not isinstance(self.model_type_instance, RerankModel):
raise Exception("Model type instance is not RerankModel")
return self._round_robin_invoke(
- function=self.model_type_instance.invoke_multimodal_rerank,
+ self.model_type_instance.invoke_multimodal_rerank,
model=self.model_name,
credentials=self.credentials,
query=query,
@@ -324,7 +326,7 @@ class ModelInstance:
if not isinstance(self.model_type_instance, ModerationModel):
raise Exception("Model type instance is not ModerationModel")
return self._round_robin_invoke(
- function=self.model_type_instance.invoke,
+ self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
text=text,
@@ -340,7 +342,7 @@ class ModelInstance:
if not isinstance(self.model_type_instance, Speech2TextModel):
raise Exception("Model type instance is not Speech2TextModel")
return self._round_robin_invoke(
- function=self.model_type_instance.invoke,
+ self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
file=file,
@@ -357,14 +359,14 @@ class ModelInstance:
if not isinstance(self.model_type_instance, TTSModel):
raise Exception("Model type instance is not TTSModel")
return self._round_robin_invoke(
- function=self.model_type_instance.invoke,
+ self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
content_text=content_text,
voice=voice,
)
- def _round_robin_invoke[**P, R](self, function: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R:
+ def _round_robin_invoke(self, function: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R:
"""
Round-robin invoke
:param function: function to invoke
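The call sites above switch from `function=...` to positional because, once `**kwargs` is typed as `P.kwargs`, a keyword argument literally named `function` becomes ambiguous to type checkers; the PEP 695 inline form `[**P, R]` is likewise replaced by module-level `ParamSpec`/`TypeVar` declarations. A self-contained sketch of the pattern:

```python
from collections.abc import Callable
from typing import ParamSpec, TypeVar

P = ParamSpec("P")
R = TypeVar("R")

# The wrapped callable is passed positionally; its own arguments flow through
# *args/**kwargs with their original types preserved by the ParamSpec.
def round_robin_invoke(function: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R:
    return function(*args, **kwargs)

def add(a: int, b: int = 0) -> int:
    return a + b

assert round_robin_invoke(add, 1, b=2) == 3
```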
diff --git a/api/core/plugin/impl/model_runtime.py b/api/core/plugin/impl/model_runtime.py
index e3fba4ef3a..4e66d58b5e 100644
--- a/api/core/plugin/impl/model_runtime.py
+++ b/api/core/plugin/impl/model_runtime.py
@@ -66,15 +66,15 @@ class PluginModelRuntime(ModelRuntime):
if not provider_schema.icon_small:
raise ValueError(f"Provider {provider} does not have small icon.")
file_name = (
- provider_schema.icon_small.zh_Hans if lang.lower() == "zh_hans" else provider_schema.icon_small.en_US
+ provider_schema.icon_small.zh_hans if lang.lower() == "zh_hans" else provider_schema.icon_small.en_us
)
elif icon_type.lower() == "icon_small_dark":
if not provider_schema.icon_small_dark:
raise ValueError(f"Provider {provider} does not have small dark icon.")
file_name = (
- provider_schema.icon_small_dark.zh_Hans
+ provider_schema.icon_small_dark.zh_hans
if lang.lower() == "zh_hans"
- else provider_schema.icon_small_dark.en_US
+ else provider_schema.icon_small_dark.en_us
)
else:
raise ValueError(f"Unsupported icon type: {icon_type}.")
diff --git a/api/core/prompt/agent_history_prompt_transform.py b/api/core/prompt/agent_history_prompt_transform.py
index 8f1d51f08a..7c6280fe93 100644
--- a/api/core/prompt/agent_history_prompt_transform.py
+++ b/api/core/prompt/agent_history_prompt_transform.py
@@ -10,7 +10,7 @@ from graphon.model_runtime.entities.message_entities import (
SystemPromptMessage,
UserPromptMessage,
)
-from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
class AgentHistoryPromptTransform(PromptTransform):
diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py
index c3bbe8fc09..8969825be4 100644
--- a/api/core/provider_manager.py
+++ b/api/core/provider_manager.py
@@ -70,12 +70,32 @@ class ProviderManager:
Request-bound managers may carry caller identity in that runtime, and the
resulting ``ProviderConfiguration`` objects must reuse it for downstream
model-type and schema lookups.
+
+ Configuration assembly is cached per manager instance so call chains that
+ share one request-scoped manager can reuse the same provider graph instead
+ of rebuilding it for every lookup. Call ``clear_configurations_cache()``
+ when a long-lived manager needs to observe writes performed within the same
+ instance scope.
"""
+ decoding_rsa_key: Any | None
+ decoding_cipher_rsa: Any | None
+ _model_runtime: ModelRuntime
+ _configurations_cache: dict[str, ProviderConfigurations]
+
def __init__(self, model_runtime: ModelRuntime):
self.decoding_rsa_key = None
self.decoding_cipher_rsa = None
self._model_runtime = model_runtime
+ self._configurations_cache = {}
+
+ def clear_configurations_cache(self, tenant_id: str | None = None) -> None:
+ """Drop assembled provider configurations cached on this manager instance."""
+ if tenant_id is None:
+ self._configurations_cache.clear()
+ return
+
+ self._configurations_cache.pop(tenant_id, None)
def get_configurations(self, tenant_id: str) -> ProviderConfigurations:
"""
@@ -114,6 +134,10 @@ class ProviderManager:
:param tenant_id:
:return:
"""
+ cached_configurations = self._configurations_cache.get(tenant_id)
+ if cached_configurations is not None:
+ return cached_configurations
+
# Get all provider records of the workspace
provider_name_to_provider_records_dict = self._get_all_providers(tenant_id)
@@ -273,6 +297,8 @@ class ProviderManager:
provider_configurations[str(provider_id_entity)] = provider_configuration
+ self._configurations_cache[tenant_id] = provider_configurations
+
# Return the encapsulated object
return provider_configurations
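The caching itself is a plain per-instance dict keyed by tenant; stripped of the provider-specific assembly, it reduces to this sketch:

```python
# Reduced sketch of the memoization added above; _build stands in for the
# expensive provider-record assembly.
class CachingProviderManager:
    def __init__(self) -> None:
        self._configurations_cache: dict[str, dict] = {}

    def get_configurations(self, tenant_id: str) -> dict:
        cached = self._configurations_cache.get(tenant_id)
        if cached is not None:
            return cached
        configurations = self._build(tenant_id)
        self._configurations_cache[tenant_id] = configurations
        return configurations

    def clear_configurations_cache(self, tenant_id: str | None = None) -> None:
        # None drops everything; a tenant id evicts just that entry.
        if tenant_id is None:
            self._configurations_cache.clear()
        else:
            self._configurations_cache.pop(tenant_id, None)

    def _build(self, tenant_id: str) -> dict:
        return {"tenant": tenant_id}
```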
diff --git a/api/core/rag/datasource/keyword/jieba/jieba.py b/api/core/rag/datasource/keyword/jieba/jieba.py
index ed264878d3..242da520c1 100644
--- a/api/core/rag/datasource/keyword/jieba/jieba.py
+++ b/api/core/rag/datasource/keyword/jieba/jieba.py
@@ -139,8 +139,10 @@ class Jieba(BaseKeyword):
"__data__": {"index_id": self.dataset.id, "summary": None, "table": keyword_table},
}
dataset_keyword_table = self.dataset.dataset_keyword_table
- keyword_data_source_type = dataset_keyword_table.data_source_type
+ keyword_data_source_type = dataset_keyword_table.data_source_type if dataset_keyword_table else "file"
if keyword_data_source_type == "database":
+ if dataset_keyword_table is None:
+ return
dataset_keyword_table.keyword_table = dumps_with_sets(keyword_table_dict)
db.session.commit()
else:
diff --git a/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py b/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py
index 84f35c25f8..1ca6303af6 100644
--- a/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py
+++ b/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py
@@ -1,4 +1,5 @@
import re
+from collections.abc import Callable
from operator import itemgetter
from typing import cast
@@ -80,12 +81,14 @@ class JiebaKeywordTableHandler:
def extract_tags(self, sentence: str, top_k: int | None = 20, **kwargs):
# Basic frequency-based keyword extraction as a fallback when TF-IDF is unavailable.
- top_k = kwargs.pop("topK", top_k)
+ top_k = cast(int | None, kwargs.pop("topK", top_k))
+ if top_k is None:
+ top_k = 20
cut = getattr(jieba, "cut", None)
if self._lcut:
tokens = self._lcut(sentence)
elif callable(cut):
- tokens = list(cut(sentence))
+ tokens = list(cast(Callable[[str], list[str]], cut)(sentence))
else:
tokens = re.findall(r"\w+", sentence)
@@ -108,7 +111,7 @@ class JiebaKeywordTableHandler:
sentence=text,
topK=max_keywords_per_chunk,
)
- # jieba.analyse.extract_tags returns list[Any] when withFlag is False by default.
+ # jieba.analyse.extract_tags returns an untyped list when withFlag is False by default.
keywords = cast(list[str], keywords)
return set(self._expand_tokens_with_subtokens(set(keywords)))
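Once jieba's TF-IDF extractor is out of the picture, the fallback path amounts to frequency ranking over whatever tokenizer is available; roughly (a sketch, with the regex tokenizer standing in for the `jieba.cut` branch):

```python
import re
from collections import Counter

# Sketch of the frequency-based fallback: tokenize, count, and keep the
# top_k most common tokens.
def extract_tags_fallback(sentence: str, top_k: int = 20) -> list[str]:
    tokens = re.findall(r"\w+", sentence)
    return [token for token, _ in Counter(tokens).most_common(top_k)]

assert extract_tags_fallback("alpha beta alpha gamma alpha beta", top_k=2) == ["alpha", "beta"]
```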
diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py
index 7e71d67ec0..2997710daf 100644
--- a/api/core/rag/datasource/retrieval_service.py
+++ b/api/core/rag/datasource/retrieval_service.py
@@ -158,7 +158,7 @@ class RetrievalService:
)
if futures:
- for future in concurrent.futures.as_completed(futures, timeout=3600):
+ for _ in concurrent.futures.as_completed(futures, timeout=3600):
if exceptions:
for f in futures:
f.cancel()
diff --git a/api/core/rag/embedding/cached_embedding.py b/api/core/rag/embedding/cached_embedding.py
index 4926f44f16..a9995778f7 100644
--- a/api/core/rag/embedding/cached_embedding.py
+++ b/api/core/rag/embedding/cached_embedding.py
@@ -14,7 +14,7 @@ from core.rag.embedding.embedding_base import Embeddings
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from graphon.model_runtime.entities.model_entities import ModelPropertyKey
-from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
+from graphon.model_runtime.model_providers.base.text_embedding_model import TextEmbeddingModel
from libs import helper
from models.dataset import Embedding
diff --git a/api/core/rag/extractor/extract_processor.py b/api/core/rag/extractor/extract_processor.py
index fbd2a6db93..b679edab36 100644
--- a/api/core/rag/extractor/extract_processor.py
+++ b/api/core/rag/extractor/extract_processor.py
@@ -94,6 +94,7 @@ class ExtractProcessor:
cls, extract_setting: ExtractSetting, is_automatic: bool = False, file_path: str | None = None
) -> list[Document]:
if extract_setting.datasource_type == DatasourceType.FILE:
+ upload_file = extract_setting.upload_file
with tempfile.TemporaryDirectory() as temp_dir:
upload_file = extract_setting.upload_file
if not file_path:
@@ -104,6 +105,7 @@ class ExtractProcessor:
storage.download(upload_file.key, file_path)
input_file = Path(file_path)
file_extension = input_file.suffix.lower()
+ assert upload_file is not None, "upload_file is required"
etl_type = dify_config.ETL_TYPE
extractor: BaseExtractor | None = None
if etl_type == "Unstructured":
diff --git a/api/core/rag/extractor/word_extractor.py b/api/core/rag/extractor/word_extractor.py
index 052fca930d..0330a43b28 100644
--- a/api/core/rag/extractor/word_extractor.py
+++ b/api/core/rag/extractor/word_extractor.py
@@ -3,6 +3,7 @@
Supports local file paths and remote URLs (downloaded via `core.helper.ssrf_proxy`).
"""
+import inspect
import logging
import mimetypes
import os
@@ -36,8 +37,11 @@ class WordExtractor(BaseExtractor):
file_path: Path to the file to load.
"""
+ _closed: bool
+
def __init__(self, file_path: str, tenant_id: str, user_id: str):
"""Initialize with file path."""
+ self._closed = False
self.file_path = file_path
self.tenant_id = tenant_id
self.user_id = user_id
@@ -65,9 +69,27 @@ class WordExtractor(BaseExtractor):
elif not os.path.isfile(self.file_path):
raise ValueError(f"File path {self.file_path} is not a valid file or url")
+ def close(self) -> None:
+ """Best-effort cleanup for downloaded temporary files."""
+ if getattr(self, "_closed", False):
+ return
+
+ self._closed = True
+ temp_file = getattr(self, "temp_file", None)
+ if temp_file is None:
+ return
+
+ try:
+ close_result = temp_file.close()
+ if inspect.isawaitable(close_result):
+ close_awaitable = getattr(close_result, "close", None)
+ if callable(close_awaitable):
+ close_awaitable()
+ except Exception:
+ logger.debug("Failed to cleanup downloaded word temp file", exc_info=True)
+
def __del__(self):
- if hasattr(self, "temp_file"):
- self.temp_file.close()
+ self.close()
def extract(self) -> list[Document]:
"""Load given path as single page."""
diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py
index f8242efe31..7ffa9afafd 100644
--- a/api/core/rag/index_processor/processor/paragraph_index_processor.py
+++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py
@@ -609,11 +609,11 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
try:
# Create File object directly (similar to DatasetRetrieval)
file_obj = File(
- id=upload_file.id,
+ file_id=upload_file.id,
filename=upload_file.name,
extension="." + upload_file.extension,
mime_type=upload_file.mime_type,
- type=FileType.IMAGE,
+ file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
remote_url=upload_file.source_url,
reference=build_file_reference(
diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py
index 1453fe020b..5631b3a921 100644
--- a/api/core/rag/retrieval/dataset_retrieval.py
+++ b/api/core/rag/retrieval/dataset_retrieval.py
@@ -68,7 +68,7 @@ from graphon.file import File, FileTransferMethod, FileType
from graphon.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMUsage
from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType
-from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
from libs.helper import parse_uuid_str_or_none
from libs.json_in_md_parser import parse_and_check_json_markdown
from models import UploadFile
@@ -517,11 +517,11 @@ class DatasetRetrieval:
if attachments_with_bindings:
for _, upload_file in attachments_with_bindings:
attachment_info = File(
- id=upload_file.id,
+ file_id=upload_file.id,
filename=upload_file.name,
extension="." + upload_file.extension,
mime_type=upload_file.mime_type,
- type=FileType.IMAGE,
+ file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
remote_url=upload_file.source_url,
reference=build_file_reference(
diff --git a/api/core/rag/retrieval/router/multi_dataset_function_call_router.py b/api/core/rag/retrieval/router/multi_dataset_function_call_router.py
index e617a9660e..426d1b67dc 100644
--- a/api/core/rag/retrieval/router/multi_dataset_function_call_router.py
+++ b/api/core/rag/retrieval/router/multi_dataset_function_call_router.py
@@ -28,7 +28,7 @@ class FunctionCallMultiDatasetRouter:
SystemPromptMessage(content="You are a helpful AI assistant."),
UserPromptMessage(content=query),
]
- result: LLMResult = model_instance.invoke_llm(
+ result: LLMResult = model_instance.invoke_llm( # pyright: ignore[reportCallIssue, reportArgumentType]
prompt_messages=prompt_messages,
tools=dataset_tools,
stream=False,
diff --git a/api/core/rag/splitter/fixed_text_splitter.py b/api/core/rag/splitter/fixed_text_splitter.py
index 2581c354dd..52c9a02f97 100644
--- a/api/core/rag/splitter/fixed_text_splitter.py
+++ b/api/core/rag/splitter/fixed_text_splitter.py
@@ -4,12 +4,12 @@ from __future__ import annotations
import codecs
import re
-from collections.abc import Collection
+from collections.abc import Set as AbstractSet
from typing import Any, Literal
from core.model_manager import ModelInstance
from core.rag.splitter.text_splitter import RecursiveCharacterTextSplitter
-from graphon.model_runtime.model_providers.__base.tokenizers.gpt2_tokenizer import GPT2Tokenizer
+from graphon.model_runtime.model_providers.base.tokenizers.gpt2_tokenizer import GPT2Tokenizer
class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
@@ -21,8 +21,8 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
def from_encoder[T: EnhanceRecursiveCharacterTextSplitter](
cls: type[T],
embedding_model_instance: ModelInstance | None,
- allowed_special: Literal["all"] | set[str] = set(),
- disallowed_special: Literal["all"] | Collection[str] = "all",
+ allowed_special: Literal["all"] | AbstractSet[str] = frozenset(),
+ disallowed_special: Literal["all"] | AbstractSet[str] = "all",
**kwargs: Any,
) -> T:
def _token_encoder(texts: list[str]) -> list[int]:
@@ -40,6 +40,7 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
return [len(text) for text in texts]
+ _ = _token_encoder # kept for future token-length wiring
return cls(length_function=_character_encoder, **kwargs)
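The `set()` → `frozenset()` change sidesteps Python's shared-mutable-default pitfall, which the old signature was one `add()` away from; a minimal illustration (function names are illustrative, not from the codebase):

```python
from collections.abc import Set as AbstractSet


def bad(allowed: set[str] = set()) -> set[str]:
    allowed.add("oops")  # every call mutates the one shared default object
    return allowed


def good(allowed: AbstractSet[str] = frozenset()) -> AbstractSet[str]:
    return allowed | {"ok"}  # the immutable default can never be mutated


assert bad() == {"oops"} and bad() == {"oops"}  # same object both times
assert good() == {"ok"}
```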
diff --git a/api/core/rag/splitter/text_splitter.py b/api/core/rag/splitter/text_splitter.py
index 7f2117e2dd..a8d9013fbc 100644
--- a/api/core/rag/splitter/text_splitter.py
+++ b/api/core/rag/splitter/text_splitter.py
@@ -4,7 +4,8 @@ import copy
import logging
import re
from abc import ABC, abstractmethod
-from collections.abc import Callable, Collection, Iterable, Sequence, Set
+from collections.abc import Callable, Iterable, Sequence
+from collections.abc import Set as AbstractSet
from dataclasses import dataclass
from typing import Any, Literal
@@ -187,8 +188,8 @@ class TokenTextSplitter(TextSplitter):
self,
encoding_name: str = "gpt2",
model_name: str | None = None,
- allowed_special: Literal["all"] | Set[str] = set(),
- disallowed_special: Literal["all"] | Collection[str] = "all",
+ allowed_special: Literal["all"] | AbstractSet[str] = frozenset(),
+ disallowed_special: Literal["all"] | AbstractSet[str] = "all",
**kwargs: Any,
):
"""Create a new TextSplitter."""
@@ -207,8 +208,8 @@ class TokenTextSplitter(TextSplitter):
else:
enc = tiktoken.get_encoding(encoding_name)
self._tokenizer = enc
- self._allowed_special = allowed_special
- self._disallowed_special = disallowed_special
+ self._allowed_special: Literal["all"] | AbstractSet[str] = allowed_special
+ self._disallowed_special: Literal["all"] | AbstractSet[str] = disallowed_special
def split_text(self, text: str) -> list[str]:
def _encode(_text: str) -> list[int]:
diff --git a/api/core/repositories/human_input_repository.py b/api/core/repositories/human_input_repository.py
index 02625e242f..740d727e26 100644
--- a/api/core/repositories/human_input_repository.py
+++ b/api/core/repositories/human_input_repository.py
@@ -8,7 +8,7 @@ from sqlalchemy import select
from sqlalchemy.orm import Session, selectinload
from core.db.session_factory import session_factory
-from core.workflow.human_input_compat import (
+from core.workflow.human_input_adapter import (
BoundRecipient,
DeliveryChannelConfig,
EmailDeliveryMethod,
diff --git a/api/core/tools/tool_file_manager.py b/api/core/tools/tool_file_manager.py
index b3424cd9a5..c87e8a3ae0 100644
--- a/api/core/tools/tool_file_manager.py
+++ b/api/core/tools/tool_file_manager.py
@@ -28,7 +28,7 @@ class ToolFileManager:
def _build_graph_file_reference(tool_file: ToolFile) -> File:
extension = guess_extension(tool_file.mimetype) or ".bin"
return File(
- type=get_file_type_by_mime_type(tool_file.mimetype),
+ file_type=get_file_type_by_mime_type(tool_file.mimetype),
transfer_method=FileTransferMethod.TOOL_FILE,
remote_url=tool_file.original_url,
reference=build_file_reference(record_id=str(tool_file.id)),
diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py
index f4588904d3..87cf6d7085 100644
--- a/api/core/tools/tool_manager.py
+++ b/api/core/tools/tool_manager.py
@@ -1082,7 +1082,12 @@ class ToolManager:
continue
tool_input = ToolNodeData.ToolInput.model_validate(tool_configurations.get(parameter.name, {}))
if tool_input.type == "variable":
- variable = variable_pool.get(tool_input.value)
+ variable_selector = tool_input.value
+ if not isinstance(variable_selector, list) or not all(
+ isinstance(selector_part, str) for selector_part in variable_selector
+ ):
+ raise ToolParameterError("Variable tool input must be a variable selector")
+ variable = variable_pool.get(variable_selector)
if variable is None:
raise ToolParameterError(f"Variable {tool_input.value} does not exist")
parameter_value = variable.value
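A self-contained sketch of the selector guard added above, with a stand-in dict-based pool (hypothetical; the real `variable_pool.get` accepts the list selector directly):

```python
class ToolParameterError(ValueError):
    pass


def resolve_variable_input(value: object, pool: dict[tuple[str, ...], object]) -> object:
    """Accept only a list[str] selector before consulting the variable pool."""
    if not isinstance(value, list) or not all(isinstance(part, str) for part in value):
        raise ToolParameterError("Variable tool input must be a variable selector")
    variable = pool.get(tuple(value))
    if variable is None:
        raise ToolParameterError(f"Variable {value} does not exist")
    return variable


assert resolve_variable_input(["node1", "text"], {("node1", "text"): "hi"}) == "hi"
```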
diff --git a/api/core/tools/utils/message_transformer.py b/api/core/tools/utils/message_transformer.py
index 79d0c114d4..5679466cbc 100644
--- a/api/core/tools/utils/message_transformer.py
+++ b/api/core/tools/utils/message_transformer.py
@@ -41,6 +41,10 @@ def safe_json_value(v):
return v.hex()
elif isinstance(v, memoryview):
return v.tobytes().hex()
+ elif isinstance(v, np.integer):
+ return int(v)
+ elif isinstance(v, np.floating):
+ return float(v)
elif isinstance(v, np.ndarray):
return v.tolist()
elif isinstance(v, dict):
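Why the new `np.integer`/`np.floating` branches matter: numpy scalars are not JSON-serializable, so `safe_json_value` must coerce them to builtins alongside the existing `np.ndarray` handling. A sketch, assuming numpy is installed:

```python
import json

import numpy as np


def to_json_safe(v):
    if isinstance(v, np.integer):
        return int(v)
    if isinstance(v, np.floating):
        return float(v)
    if isinstance(v, np.ndarray):
        return v.tolist()
    return v


# json.dumps(np.int64(3)) raises TypeError; the coerced value serializes fine.
assert json.dumps(to_json_safe(np.int64(3))) == "3"
assert json.dumps(to_json_safe(np.float32(1.5))) == "1.5"
```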
diff --git a/api/core/tools/utils/model_invocation_utils.py b/api/core/tools/utils/model_invocation_utils.py
index 9e1d41cb39..a3623d4ecd 100644
--- a/api/core/tools/utils/model_invocation_utils.py
+++ b/api/core/tools/utils/model_invocation_utils.py
@@ -21,7 +21,7 @@ from graphon.model_runtime.errors.invoke import (
InvokeRateLimitError,
InvokeServerUnavailableError,
)
-from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
from graphon.model_runtime.utils.encoders import jsonable_encoder
from models.tools import ToolModelInvoke
diff --git a/api/core/tools/utils/web_reader_tool.py b/api/core/tools/utils/web_reader_tool.py
index ed3ed3e0de..94a2c0427b 100644
--- a/api/core/tools/utils/web_reader_tool.py
+++ b/api/core/tools/utils/web_reader_tool.py
@@ -105,7 +105,7 @@ class Article:
def extract_using_readabilipy(html: str):
- json_article: dict[str, Any] = simple_json_from_html_string(html, use_readability=True)
+ json_article: dict[str, Any] = simple_json_from_html_string(html, use_readability=False)
article = Article(
title=json_article.get("title") or "",
author=json_article.get("byline") or "",
diff --git a/api/core/tools/workflow_as_tool/tool.py b/api/core/tools/workflow_as_tool/tool.py
index 52ab605963..cd8c6352b5 100644
--- a/api/core/tools/workflow_as_tool/tool.py
+++ b/api/core/tools/workflow_as_tool/tool.py
@@ -357,7 +357,10 @@ class WorkflowTool(Tool):
def _update_file_mapping(self, file_dict: dict[str, Any]) -> dict[str, Any]:
file_id = resolve_file_record_id(file_dict.get("reference") or file_dict.get("related_id"))
- transfer_method = FileTransferMethod.value_of(file_dict.get("transfer_method"))
+ transfer_method_value = file_dict.get("transfer_method")
+ if not isinstance(transfer_method_value, str):
+ raise ValueError("Workflow file mapping is missing a valid transfer_method")
+ transfer_method = FileTransferMethod.value_of(transfer_method_value)
match transfer_method:
case FileTransferMethod.TOOL_FILE:
file_dict["tool_file_id"] = file_id
diff --git a/api/core/workflow/human_input_compat.py b/api/core/workflow/human_input_adapter.py
similarity index 74%
rename from api/core/workflow/human_input_compat.py
rename to api/core/workflow/human_input_adapter.py
index 75a0a0c202..4b765e6aea 100644
--- a/api/core/workflow/human_input_compat.py
+++ b/api/core/workflow/human_input_adapter.py
@@ -1,8 +1,8 @@
-"""Workflow-layer adapters for legacy human-input payload keys.
+"""Workflow-to-Graphon adapters for persisted node payloads.
-Stored workflow graphs and editor payloads may still use Dify-specific human
-input recipient keys. Normalize them here before handing configs to
-`graphon` so graph-owned models only see graph-neutral field names.
+Stored workflow graphs and editor payloads still contain a small set of
+Dify-owned field spellings and value shapes. Adapt them here before handing the
+payload to Graphon so Graphon-owned models only see current contracts.
"""
from __future__ import annotations
@@ -185,7 +185,7 @@ def _copy_mapping(value: object) -> dict[str, Any] | None:
return None
-def normalize_human_input_node_data_for_graph(node_data: Mapping[str, Any] | BaseModel) -> dict[str, Any]:
+def adapt_human_input_node_data_for_graph(node_data: Mapping[str, Any] | BaseModel) -> dict[str, Any]:
normalized = _copy_mapping(node_data)
if normalized is None:
raise TypeError(f"human-input node data must be a mapping, got {type(node_data).__name__}")
@@ -215,7 +215,7 @@ def normalize_human_input_node_data_for_graph(node_data: Mapping[str, Any] | Bas
def parse_human_input_delivery_methods(node_data: Mapping[str, Any] | BaseModel) -> list[DeliveryChannelConfig]:
- normalized = normalize_human_input_node_data_for_graph(node_data)
+ normalized = adapt_human_input_node_data_for_graph(node_data)
raw_delivery_methods = normalized.get("delivery_methods")
if not isinstance(raw_delivery_methods, list):
return []
@@ -229,17 +229,20 @@ def is_human_input_webapp_enabled(node_data: Mapping[str, Any] | BaseModel) -> b
return False
-def normalize_node_data_for_graph(node_data: Mapping[str, Any] | BaseModel) -> dict[str, Any]:
+def adapt_node_data_for_graph(node_data: Mapping[str, Any] | BaseModel) -> dict[str, Any]:
normalized = _copy_mapping(node_data)
if normalized is None:
raise TypeError(f"node data must be a mapping, got {type(node_data).__name__}")
- if normalized.get("type") != BuiltinNodeTypes.HUMAN_INPUT:
- return normalized
- return normalize_human_input_node_data_for_graph(normalized)
+ node_type = normalized.get("type")
+ if node_type == BuiltinNodeTypes.HUMAN_INPUT:
+ return adapt_human_input_node_data_for_graph(normalized)
+ if node_type == BuiltinNodeTypes.TOOL:
+ return _adapt_tool_node_data_for_graph(normalized)
+ return normalized
-def normalize_node_config_for_graph(node_config: Mapping[str, Any] | BaseModel) -> dict[str, Any]:
+def adapt_node_config_for_graph(node_config: Mapping[str, Any] | BaseModel) -> dict[str, Any]:
normalized = _copy_mapping(node_config)
if normalized is None:
raise TypeError(f"node config must be a mapping, got {type(node_config).__name__}")
@@ -248,10 +251,65 @@ def normalize_node_config_for_graph(node_config: Mapping[str, Any] | BaseModel)
if data_mapping is None:
return normalized
- normalized["data"] = normalize_node_data_for_graph(data_mapping)
+ normalized["data"] = adapt_node_data_for_graph(data_mapping)
return normalized
+def _adapt_tool_node_data_for_graph(node_data: Mapping[str, Any]) -> dict[str, Any]:
+ normalized = dict(node_data)
+
+ raw_tool_configurations = normalized.get("tool_configurations")
+ if not isinstance(raw_tool_configurations, Mapping):
+ return normalized
+
+ existing_tool_parameters = normalized.get("tool_parameters")
+ normalized_tool_parameters = dict(existing_tool_parameters) if isinstance(existing_tool_parameters, Mapping) else {}
+ normalized_tool_configurations: dict[str, Any] = {}
+ found_legacy_tool_inputs = False
+
+ for name, value in raw_tool_configurations.items():
+ if not isinstance(value, Mapping):
+ normalized_tool_configurations[name] = value
+ continue
+
+ input_type = value.get("type")
+ input_value = value.get("value")
+ if input_type not in {"mixed", "variable", "constant"}:
+ normalized_tool_configurations[name] = value
+ continue
+
+ found_legacy_tool_inputs = True
+ normalized_tool_parameters.setdefault(name, dict(value))
+
+ flattened_value = _flatten_legacy_tool_configuration_value(
+ input_type=input_type,
+ input_value=input_value,
+ )
+ if flattened_value is not None:
+ normalized_tool_configurations[name] = flattened_value
+
+ if not found_legacy_tool_inputs:
+ return normalized
+
+ normalized["tool_parameters"] = normalized_tool_parameters
+ normalized["tool_configurations"] = normalized_tool_configurations
+ return normalized
+
+
+def _flatten_legacy_tool_configuration_value(*, input_type: Any, input_value: Any) -> str | int | float | bool | None:
+ if input_type in {"mixed", "constant"} and isinstance(input_value, str | int | float | bool):
+ return input_value
+
+ if (
+ input_type == "variable"
+ and isinstance(input_value, list)
+ and all(isinstance(item, str) for item in input_value)
+ ):
+ return "{{#" + ".".join(input_value) + "#}}"
+
+ return None
+
+
def _normalize_email_recipients(recipients: Mapping[str, Any]) -> dict[str, Any]:
normalized = dict(recipients)
@@ -291,9 +349,9 @@ __all__ = [
"MemberRecipient",
"WebAppDeliveryMethod",
"_WebAppDeliveryConfig",
+ "adapt_human_input_node_data_for_graph",
+ "adapt_node_config_for_graph",
+ "adapt_node_data_for_graph",
"is_human_input_webapp_enabled",
- "normalize_human_input_node_data_for_graph",
- "normalize_node_config_for_graph",
- "normalize_node_data_for_graph",
"parse_human_input_delivery_methods",
]
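Concretely, the flattening helper turns legacy `{"type": ..., "value": ...}` tool inputs into scalar or inline-template form. A mirror of `_flatten_legacy_tool_configuration_value` for illustration, with a usage check:

```python
def flatten_legacy_value(input_type, input_value):
    # Mirror of _flatten_legacy_tool_configuration_value above, for illustration.
    if input_type in {"mixed", "constant"} and isinstance(input_value, (str, int, float, bool)):
        return input_value
    if (
        input_type == "variable"
        and isinstance(input_value, list)
        and all(isinstance(item, str) for item in input_value)
    ):
        return "{{#" + ".".join(input_value) + "#}}"
    return None


assert flatten_legacy_value("constant", 42) == 42
assert flatten_legacy_value("variable", ["start", "query"]) == "{{#start.query#}}"
assert flatten_legacy_value("variable", "not-a-selector") is None
```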
diff --git a/api/core/workflow/node_factory.py b/api/core/workflow/node_factory.py
index 351da3444f..de4eae1b22 100644
--- a/api/core/workflow/node_factory.py
+++ b/api/core/workflow/node_factory.py
@@ -15,12 +15,12 @@ from core.helper.code_executor.code_executor import (
CodeExecutionError,
CodeExecutor,
)
-from core.helper.ssrf_proxy import ssrf_proxy
+from core.helper.ssrf_proxy import graphon_ssrf_proxy
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.trigger.constants import TRIGGER_NODE_TYPES
-from core.workflow.human_input_compat import normalize_node_config_for_graph
+from core.workflow.human_input_adapter import adapt_node_config_for_graph
from core.workflow.node_runtime import (
DifyFileReferenceFactory,
DifyHumanInputNodeRuntime,
@@ -46,7 +46,7 @@ from graphon.enums import BuiltinNodeTypes, NodeType
from graphon.file.file_manager import file_manager
from graphon.graph.graph import NodeFactory
from graphon.model_runtime.memory import PromptMessageMemory
-from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
from graphon.nodes.base.node import Node
from graphon.nodes.code.code_node import WorkflowCodeExecutor
from graphon.nodes.code.entities import CodeLanguage
@@ -121,6 +121,7 @@ def get_node_type_classes_mapping() -> Mapping[NodeType, Mapping[str, type[Node]
def resolve_workflow_node_class(*, node_type: NodeType, node_version: str) -> type[Node]:
+ """Resolve the production node class for the requested type/version."""
node_mapping = get_node_type_classes_mapping().get(node_type)
if not node_mapping:
raise ValueError(f"No class mapping found for node type: {node_type}")
@@ -297,7 +298,7 @@ class DifyNodeFactory(NodeFactory):
)
self._jinja2_template_renderer = CodeExecutorJinja2TemplateRenderer()
self._template_transform_max_output_length = dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH
- self._http_request_http_client = ssrf_proxy
+ self._http_request_http_client = graphon_ssrf_proxy
self._bound_tool_file_manager_factory = lambda: DifyToolFileManager(
self._dify_context,
conversation_id_getter=self._conversation_id,
@@ -364,10 +365,14 @@ class DifyNodeFactory(NodeFactory):
(including pydantic ValidationError, which subclasses ValueError),
if node type is unknown, or if no implementation exists for the resolved version
"""
- typed_node_config = NodeConfigDictAdapter.validate_python(normalize_node_config_for_graph(node_config))
+ typed_node_config = NodeConfigDictAdapter.validate_python(adapt_node_config_for_graph(node_config))
node_id = typed_node_config["id"]
node_data = typed_node_config["data"]
node_class = self._resolve_node_class(node_type=node_data.type, node_version=str(node_data.version))
+ # Graph configs are initially validated against permissive shared node data.
+ # Re-validate using the resolved node class so workflow-local node schemas
+ # stay explicit and constructors receive the concrete typed payload.
+ resolved_node_data = self._validate_resolved_node_data(node_class, node_data)
node_type = node_data.type
node_init_kwargs_factories: Mapping[NodeType, Callable[[], dict[str, object]]] = {
BuiltinNodeTypes.CODE: lambda: {
@@ -391,7 +396,7 @@ class DifyNodeFactory(NodeFactory):
},
BuiltinNodeTypes.LLM: lambda: self._build_llm_compatible_node_init_kwargs(
node_class=node_class,
- node_data=node_data,
+ node_data=resolved_node_data,
wrap_model_instance=True,
include_http_client=True,
include_llm_file_saver=True,
@@ -405,7 +410,7 @@ class DifyNodeFactory(NodeFactory):
},
BuiltinNodeTypes.QUESTION_CLASSIFIER: lambda: self._build_llm_compatible_node_init_kwargs(
node_class=node_class,
- node_data=node_data,
+ node_data=resolved_node_data,
wrap_model_instance=True,
include_http_client=True,
include_llm_file_saver=True,
@@ -415,7 +420,7 @@ class DifyNodeFactory(NodeFactory):
),
BuiltinNodeTypes.PARAMETER_EXTRACTOR: lambda: self._build_llm_compatible_node_init_kwargs(
node_class=node_class,
- node_data=node_data,
+ node_data=resolved_node_data,
wrap_model_instance=True,
include_http_client=False,
include_llm_file_saver=False,
@@ -436,8 +441,8 @@ class DifyNodeFactory(NodeFactory):
}
node_init_kwargs = node_init_kwargs_factories.get(node_type, lambda: {})()
return node_class(
- id=node_id,
- config=typed_node_config,
+ node_id=node_id,
+ config=resolved_node_data,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
**node_init_kwargs,
@@ -448,7 +453,10 @@ class DifyNodeFactory(NodeFactory):
"""
Re-validate the permissive graph payload with the concrete NodeData model declared by the resolved node class.
"""
- return node_class.validate_node_data(node_data)
+ validate_node_data = getattr(node_class, "validate_node_data", None)
+ if callable(validate_node_data):
+ return cast("BaseNodeData", validate_node_data(node_data))
+ return node_data
@staticmethod
def _resolve_node_class(*, node_type: NodeType, node_version: str) -> type[Node]:
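The fallback above only calls `validate_node_data` when the resolved node class actually defines it; a tiny sketch of that duck-typed re-validation (class names here are hypothetical):

```python
from typing import Any


class StrictNode:
    @classmethod
    def validate_node_data(cls, data: Any) -> dict:
        # Re-validate the permissive payload against the node's own schema.
        if "title" not in data:
            raise ValueError("title is required")
        return dict(data)


class LegacyNode:
    """Defines no validate_node_data; its payload passes through unchanged."""


def revalidate(node_class: type, data: Any) -> Any:
    validate = getattr(node_class, "validate_node_data", None)
    return validate(data) if callable(validate) else data


assert revalidate(StrictNode, {"title": "t"}) == {"title": "t"}
assert revalidate(LegacyNode, {"anything": 1}) == {"anything": 1}
```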
diff --git a/api/core/workflow/node_runtime.py b/api/core/workflow/node_runtime.py
index 2e632e56f0..b8725853c4 100644
--- a/api/core/workflow/node_runtime.py
+++ b/api/core/workflow/node_runtime.py
@@ -2,7 +2,7 @@ from __future__ import annotations
from collections.abc import Callable, Generator, Mapping, Sequence
from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, cast
+from typing import TYPE_CHECKING, Any, Literal, cast, overload
from sqlalchemy import select
from sqlalchemy.orm import Session
@@ -41,7 +41,7 @@ from graphon.model_runtime.entities.llm_entities import (
)
from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from graphon.model_runtime.entities.model_entities import AIModelEntity
-from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
from graphon.nodes.human_input.entities import HumanInputNodeData
from graphon.nodes.llm.runtime_protocols import (
PreparedLLMProtocol,
@@ -64,7 +64,7 @@ from models.dataset import SegmentAttachmentBinding
from models.model import UploadFile
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
-from .human_input_compat import (
+from .human_input_adapter import (
BoundRecipient,
DeliveryChannelConfig,
DeliveryMethodType,
@@ -173,6 +173,28 @@ class DifyPreparedLLM(PreparedLLMProtocol):
def get_llm_num_tokens(self, prompt_messages: Sequence[PromptMessage]) -> int:
return self._model_instance.get_llm_num_tokens(prompt_messages)
+ @overload
+ def invoke_llm(
+ self,
+ *,
+ prompt_messages: Sequence[PromptMessage],
+ model_parameters: Mapping[str, Any],
+ tools: Sequence[PromptMessageTool] | None,
+ stop: Sequence[str] | None,
+ stream: Literal[False],
+ ) -> LLMResult: ...
+
+ @overload
+ def invoke_llm(
+ self,
+ *,
+ prompt_messages: Sequence[PromptMessage],
+ model_parameters: Mapping[str, Any],
+ tools: Sequence[PromptMessageTool] | None,
+ stop: Sequence[str] | None,
+ stream: Literal[True],
+ ) -> Generator[LLMResultChunk, None, None]: ...
+
def invoke_llm(
self,
*,
@@ -190,6 +212,28 @@ class DifyPreparedLLM(PreparedLLMProtocol):
stream=stream,
)
+ @overload
+ def invoke_llm_with_structured_output(
+ self,
+ *,
+ prompt_messages: Sequence[PromptMessage],
+ json_schema: Mapping[str, Any],
+ model_parameters: Mapping[str, Any],
+ stop: Sequence[str] | None,
+ stream: Literal[False],
+ ) -> LLMResultWithStructuredOutput: ...
+
+ @overload
+ def invoke_llm_with_structured_output(
+ self,
+ *,
+ prompt_messages: Sequence[PromptMessage],
+ json_schema: Mapping[str, Any],
+ model_parameters: Mapping[str, Any],
+ stop: Sequence[str] | None,
+ stream: Literal[True],
+ ) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
+
def invoke_llm_with_structured_output(
self,
*,
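The `@overload` pairs above let the type checker pick the return type from the `stream` flag instead of forcing callers through a union. A reduced, self-contained version of the pattern:

```python
from collections.abc import Generator
from typing import Literal, overload


@overload
def invoke(*, prompt: str, stream: Literal[False]) -> str: ...
@overload
def invoke(*, prompt: str, stream: Literal[True]) -> Generator[str, None, None]: ...


def invoke(*, prompt: str, stream: bool) -> str | Generator[str, None, None]:
    if not stream:
        return prompt.upper()

    def _chunks() -> Generator[str, None, None]:
        yield from prompt.upper()

    return _chunks()


reply: str = invoke(prompt="hi", stream=False)   # checker infers str
chunks = list(invoke(prompt="hi", stream=True))  # checker infers Generator
assert reply == "HI" and chunks == ["H", "I"]
```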
diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py
index 7b000101b0..68a24e86b1 100644
--- a/api/core/workflow/nodes/agent/agent_node.py
+++ b/api/core/workflow/nodes/agent/agent_node.py
@@ -5,7 +5,6 @@ from typing import TYPE_CHECKING, Any
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext
from core.workflow.system_variables import SystemVariableKey, get_system_text
-from graphon.entities.graph_config import NodeConfigDict
from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus
from graphon.node_events import NodeEventBase, NodeRunResult, StreamCompletedEvent
from graphon.nodes.base.node import Node
@@ -35,18 +34,18 @@ class AgentNode(Node[AgentNodeData]):
def __init__(
self,
- id: str,
- config: NodeConfigDict,
+ node_id: str,
+ config: AgentNodeData,
+ *,
graph_init_params: GraphInitParams,
graph_runtime_state: GraphRuntimeState,
- *,
strategy_resolver: AgentStrategyResolver,
presentation_provider: AgentStrategyPresentationProvider,
runtime_support: AgentRuntimeSupport,
message_transformer: AgentMessageTransformer,
) -> None:
super().__init__(
- id=id,
+ node_id=node_id,
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py
index e4f6b3b470..f3006c4242 100644
--- a/api/core/workflow/nodes/datasource/datasource_node.py
+++ b/api/core/workflow/nodes/datasource/datasource_node.py
@@ -7,7 +7,6 @@ from core.datasource.entities.datasource_entities import DatasourceProviderType
from core.plugin.impl.exc import PluginDaemonClientSideError
from core.workflow.file_reference import resolve_file_record_id
from core.workflow.system_variables import SystemVariableKey, get_system_segment
-from graphon.entities.graph_config import NodeConfigDict
from graphon.enums import (
BuiltinNodeTypes,
NodeExecutionType,
@@ -36,13 +35,14 @@ class DatasourceNode(Node[DatasourceNodeData]):
def __init__(
self,
- id: str,
- config: NodeConfigDict,
+ node_id: str,
+ config: DatasourceNodeData,
+ *,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
- ):
+ ) -> None:
super().__init__(
- id=id,
+ node_id=node_id,
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
diff --git a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py
index d5cab05dbe..9c1b7ab2c4 100644
--- a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py
+++ b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py
@@ -7,7 +7,6 @@ from core.rag.index_processor.index_processor_base import SummaryIndexSettingDic
from core.rag.summary_index.summary_index import SummaryIndex
from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE
from core.workflow.system_variables import SystemVariableKey, get_system_segment, get_system_text
-from graphon.entities.graph_config import NodeConfigDict
from graphon.enums import NodeExecutionType, WorkflowNodeExecutionStatus
from graphon.node_events import NodeRunResult
from graphon.nodes.base.node import Node
@@ -32,12 +31,18 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
def __init__(
self,
- id: str,
- config: NodeConfigDict,
+ node_id: str,
+ config: KnowledgeIndexNodeData,
+ *,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
) -> None:
- super().__init__(id, config, graph_init_params, graph_runtime_state)
+ super().__init__(
+ node_id=node_id,
+ config=config,
+ graph_init_params=graph_init_params,
+ graph_runtime_state=graph_runtime_state,
+ )
self.index_processor = IndexProcessor()
self.summary_index_service = SummaryIndex()
diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py
index 47ad14b499..25f73e446d 100644
--- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py
+++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py
@@ -14,7 +14,6 @@ from core.rag.data_post_processor.data_post_processor import RerankingModelDict,
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.workflow.file_reference import parse_file_reference
from graphon.entities import GraphInitParams
-from graphon.entities.graph_config import NodeConfigDict
from graphon.enums import (
BuiltinNodeTypes,
WorkflowNodeExecutionMetadataKey,
@@ -50,6 +49,18 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
+def _normalize_metadata_filter_scalar(value: object) -> str | int | float | None:
+ if value is None or isinstance(value, (str, float)):
+ return value
+ if isinstance(value, int) and not isinstance(value, bool):
+ return value
+ return str(value)
+
+
+def _normalize_metadata_filter_sequence_item(value: object) -> str:
+ return value if isinstance(value, str) else str(value)
+
+
class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeData]):
node_type = BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL
@@ -59,13 +70,14 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
def __init__(
self,
- id: str,
- config: NodeConfigDict,
+ node_id: str,
+ config: KnowledgeRetrievalNodeData,
+ *,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
- ):
+ ) -> None:
super().__init__(
- id=id,
+ node_id=node_id,
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
@@ -282,18 +294,21 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
resolved_conditions: list[Condition] = []
for cond in conditions.conditions or []:
value = cond.value
+ resolved_value: str | Sequence[str] | int | float | None
if isinstance(value, str):
segment_group = variable_pool.convert_template(value)
if len(segment_group.value) == 1:
- resolved_value = segment_group.value[0].to_object()
+ resolved_value = _normalize_metadata_filter_scalar(segment_group.value[0].to_object())
else:
resolved_value = segment_group.text
elif isinstance(value, Sequence) and all(isinstance(v, str) for v in value):
- resolved_values = []
- for v in value: # type: ignore
+ resolved_values: list[str] = []
+ for v in value:
segment_group = variable_pool.convert_template(v)
if len(segment_group.value) == 1:
- resolved_values.append(segment_group.value[0].to_object())
+ resolved_values.append(
+ _normalize_metadata_filter_sequence_item(segment_group.value[0].to_object())
+ )
else:
resolved_values.append(segment_group.text)
resolved_value = resolved_values
diff --git a/api/factories/file_factory/builders.py b/api/factories/file_factory/builders.py
index ce1fa441c2..1d2ad4d445 100644
--- a/api/factories/file_factory/builders.py
+++ b/api/factories/file_factory/builders.py
@@ -148,11 +148,11 @@ def _build_from_local_file(
)
return File(
- id=mapping.get("id"),
+ file_id=mapping.get("id"),
filename=row.name,
extension="." + row.extension,
mime_type=row.mime_type,
- type=file_type,
+ file_type=file_type,
transfer_method=transfer_method,
remote_url=row.source_url,
reference=build_file_reference(record_id=str(row.id)),
@@ -196,11 +196,11 @@ def _build_from_remote_url(
)
return File(
- id=mapping.get("id"),
+ file_id=mapping.get("id"),
filename=upload_file.name,
extension="." + upload_file.extension,
mime_type=upload_file.mime_type,
- type=file_type,
+ file_type=file_type,
transfer_method=transfer_method,
remote_url=helpers.get_signed_file_url(upload_file_id=str(upload_file_id)),
reference=build_file_reference(record_id=str(upload_file.id)),
@@ -222,9 +222,9 @@ def _build_from_remote_url(
)
return File(
- id=mapping.get("id"),
+ file_id=mapping.get("id"),
filename=filename,
- type=file_type,
+ file_type=file_type,
transfer_method=transfer_method,
remote_url=url,
mime_type=mime_type,
@@ -263,9 +263,9 @@ def _build_from_tool_file(
)
return File(
- id=mapping.get("id"),
+ file_id=mapping.get("id"),
filename=tool_file.name,
- type=file_type,
+ file_type=file_type,
transfer_method=transfer_method,
remote_url=tool_file.original_url,
reference=build_file_reference(record_id=str(tool_file.id)),
@@ -306,9 +306,9 @@ def _build_from_datasource_file(
)
return File(
- id=mapping.get("datasource_file_id"),
+ file_id=mapping.get("datasource_file_id"),
filename=datasource_file.name,
- type=file_type,
+ file_type=file_type,
transfer_method=FileTransferMethod.TOOL_FILE,
remote_url=datasource_file.source_url,
reference=build_file_reference(record_id=str(datasource_file.id)),
diff --git a/api/fields/_value_type_serializer.py b/api/fields/_value_type_serializer.py
index b5acbbbcb4..d518114777 100644
--- a/api/fields/_value_type_serializer.py
+++ b/api/fields/_value_type_serializer.py
@@ -10,9 +10,9 @@ class _VarTypedDict(TypedDict, total=False):
def serialize_value_type(v: _VarTypedDict | Segment) -> str:
if isinstance(v, Segment):
- return v.value_type.exposed_type().value
+ return str(v.value_type.exposed_type())
else:
value_type = v.get("value_type")
if value_type is None:
raise ValueError("value_type is required but not provided")
- return value_type.exposed_type().value
+ return str(value_type.exposed_type())
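`str(...)` and `.value` coincide only when `exposed_type()` returns a `StrEnum`; a quick check of the assumption behind this change (enum names here are illustrative):

```python
from enum import Enum, StrEnum


class Exposed(StrEnum):
    STRING = "string"


class Legacy(Enum):
    STRING = "string"


assert str(Exposed.STRING) == Exposed.STRING.value == "string"
assert str(Legacy.STRING) == "Legacy.STRING"  # a plain Enum would break here
```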
diff --git a/api/fields/conversation_variable_fields.py b/api/fields/conversation_variable_fields.py
index cf4a71d545..e4219ba1ee 100644
--- a/api/fields/conversation_variable_fields.py
+++ b/api/fields/conversation_variable_fields.py
@@ -57,10 +57,10 @@ class ConversationVariableResponse(ResponseModel):
def _normalize_value_type(cls, value: Any) -> str:
exposed_type = getattr(value, "exposed_type", None)
if callable(exposed_type):
- return str(exposed_type().value)
+ return str(exposed_type())
if isinstance(value, str):
try:
- return str(SegmentType(value).exposed_type().value)
+ return str(SegmentType(value).exposed_type())
except ValueError:
return value
try:
diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py
index f9b5e98936..6e947858ba 100644
--- a/api/fields/workflow_fields.py
+++ b/api/fields/workflow_fields.py
@@ -26,7 +26,7 @@ class EnvironmentVariableField(fields.Raw):
"id": value.id,
"name": value.name,
"value": value.value,
- "value_type": value.value_type.exposed_type().value,
+ "value_type": str(value.value_type.exposed_type()),
"description": value.description,
}
if isinstance(value, dict):
diff --git a/api/libs/token.py b/api/libs/token.py
index a34db70764..5b043465ac 100644
--- a/api/libs/token.py
+++ b/api/libs/token.py
@@ -47,23 +47,17 @@ def _cookie_domain() -> str | None:
def _real_cookie_name(cookie_name: str) -> str:
if is_secure() and _cookie_domain() is None:
return "__Host-" + cookie_name
- else:
- return cookie_name
+ return cookie_name
def _try_extract_from_header(request: Request) -> str | None:
auth_header = request.headers.get("Authorization")
- if auth_header:
- if " " not in auth_header:
- return None
- else:
- auth_scheme, auth_token = auth_header.split(None, 1)
- auth_scheme = auth_scheme.lower()
- if auth_scheme != "bearer":
- return None
- else:
- return auth_token
- return None
+ if not auth_header or " " not in auth_header:
+ return None
+ auth_scheme, auth_token = auth_header.split(None, 1)
+ if auth_scheme.lower() != "bearer":
+ return None
+ return auth_token
def extract_refresh_token(request: Request) -> str | None:
@@ -90,14 +84,9 @@ def extract_webapp_access_token(request: Request) -> str | None:
def extract_webapp_passport(app_code: str, request: Request) -> str | None:
- def _try_extract_passport_token_from_cookie(request: Request) -> str | None:
- return request.cookies.get(_real_cookie_name(COOKIE_NAME_PASSPORT + "-" + app_code))
-
- def _try_extract_passport_token_from_header(request: Request) -> str | None:
- return request.headers.get(HEADER_NAME_PASSPORT)
-
- ret = _try_extract_passport_token_from_cookie(request) or _try_extract_passport_token_from_header(request)
- return ret
+ return request.cookies.get(_real_cookie_name(COOKIE_NAME_PASSPORT + "-" + app_code)) or request.headers.get(
+ HEADER_NAME_PASSPORT
+ )
def set_access_token_to_cookie(request: Request, response: Response, token: str, samesite: str = "Lax"):
@@ -209,22 +198,18 @@ def check_csrf_token(request: Request, user_id: str):
if not csrf_token:
_unauthorized()
- verified = {}
try:
verified = PassportService().verify(csrf_token)
- except:
+ except Exception:
_unauthorized()
+ raise # unreachable, but helps the type checker see verified is always bound
if verified.get("sub") != user_id:
_unauthorized()
exp: int | None = verified.get("exp")
- if not exp:
+ if not exp or exp < int(datetime.now(UTC).timestamp()):
_unauthorized()
- else:
- time_now = int(datetime.now().timestamp())
- if exp < time_now:
- _unauthorized()
def generate_csrf_token(user_id: str) -> str:
diff --git a/api/libs/url_utils.py b/api/libs/url_utils.py
new file mode 100644
index 0000000000..adcac3add0
--- /dev/null
+++ b/api/libs/url_utils.py
@@ -0,0 +1,3 @@
+def normalize_api_base_url(base_url: str) -> str:
+ """Normalize a base URL to always end with /v1, avoiding double /v1 suffixes."""
+ return base_url.rstrip("/").removesuffix("/v1").rstrip("/") + "/v1"
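The helper is idempotent across the common input shapes; a usage sketch (the URLs are hypothetical):

```python
def normalize_api_base_url(base_url: str) -> str:
    """Normalize a base URL to always end with /v1, avoiding double /v1 suffixes."""
    return base_url.rstrip("/").removesuffix("/v1").rstrip("/") + "/v1"


assert normalize_api_base_url("https://api.example.com") == "https://api.example.com/v1"
assert normalize_api_base_url("https://api.example.com/") == "https://api.example.com/v1"
assert normalize_api_base_url("https://api.example.com/v1") == "https://api.example.com/v1"
assert normalize_api_base_url("https://api.example.com/v1/") == "https://api.example.com/v1"
```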
diff --git a/api/models/comment.py b/api/models/comment.py
index 308339e6f6..1154e16788 100644
--- a/api/models/comment.py
+++ b/api/models/comment.py
@@ -3,6 +3,7 @@
from datetime import datetime
from typing import Optional
+import sqlalchemy as sa
from sqlalchemy import Index, func
from sqlalchemy.orm import Mapped, mapped_column, relationship
@@ -36,24 +37,24 @@ class WorkflowComment(Base):
__tablename__ = "workflow_comments"
__table_args__ = (
- db.PrimaryKeyConstraint("id", name="workflow_comments_pkey"),
+ sa.PrimaryKeyConstraint("id", name="workflow_comments_pkey"),
Index("workflow_comments_app_idx", "tenant_id", "app_id"),
Index("workflow_comments_created_at_idx", "created_at"),
)
- id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuidv7()"))
+ id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuidv7()"))
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
- position_x: Mapped[float] = mapped_column(db.Float)
- position_y: Mapped[float] = mapped_column(db.Float)
- content: Mapped[str] = mapped_column(db.Text, nullable=False)
+ position_x: Mapped[float] = mapped_column(sa.Float)
+ position_y: Mapped[float] = mapped_column(sa.Float)
+ content: Mapped[str] = mapped_column(sa.Text, nullable=False)
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
- created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(
- db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
+ sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
- resolved: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
- resolved_at: Mapped[datetime | None] = mapped_column(db.DateTime)
+ resolved: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
+ resolved_at: Mapped[datetime | None] = mapped_column(sa.DateTime)
resolved_by: Mapped[str | None] = mapped_column(StringUUID)
# Relationships
@@ -143,20 +144,20 @@ class WorkflowCommentReply(Base):
__tablename__ = "workflow_comment_replies"
__table_args__ = (
- db.PrimaryKeyConstraint("id", name="workflow_comment_replies_pkey"),
+ sa.PrimaryKeyConstraint("id", name="workflow_comment_replies_pkey"),
Index("comment_replies_comment_idx", "comment_id"),
Index("comment_replies_created_at_idx", "created_at"),
)
- id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuidv7()"))
+ id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuidv7()"))
comment_id: Mapped[str] = mapped_column(
- StringUUID, db.ForeignKey("workflow_comments.id", ondelete="CASCADE"), nullable=False
+ StringUUID, sa.ForeignKey("workflow_comments.id", ondelete="CASCADE"), nullable=False
)
- content: Mapped[str] = mapped_column(db.Text, nullable=False)
+ content: Mapped[str] = mapped_column(sa.Text, nullable=False)
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
- created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(
- db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
+ sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
# Relationships
comment: Mapped["WorkflowComment"] = relationship("WorkflowComment", back_populates="replies")
@@ -187,18 +188,18 @@ class WorkflowCommentMention(Base):
__tablename__ = "workflow_comment_mentions"
__table_args__ = (
- db.PrimaryKeyConstraint("id", name="workflow_comment_mentions_pkey"),
+ sa.PrimaryKeyConstraint("id", name="workflow_comment_mentions_pkey"),
Index("comment_mentions_comment_idx", "comment_id"),
Index("comment_mentions_reply_idx", "reply_id"),
Index("comment_mentions_user_idx", "mentioned_user_id"),
)
- id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuidv7()"))
+ id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuidv7()"))
comment_id: Mapped[str] = mapped_column(
- StringUUID, db.ForeignKey("workflow_comments.id", ondelete="CASCADE"), nullable=False
+ StringUUID, sa.ForeignKey("workflow_comments.id", ondelete="CASCADE"), nullable=False
)
reply_id: Mapped[str | None] = mapped_column(
- StringUUID, db.ForeignKey("workflow_comment_replies.id", ondelete="CASCADE"), nullable=True
+ StringUUID, sa.ForeignKey("workflow_comment_replies.id", ondelete="CASCADE"), nullable=True
)
mentioned_user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
diff --git a/api/models/human_input.py b/api/models/human_input.py
index b4c7a634b6..7447d3efcb 100644
--- a/api/models/human_input.py
+++ b/api/models/human_input.py
@@ -6,7 +6,7 @@ import sqlalchemy as sa
from pydantic import BaseModel, Field
from sqlalchemy.orm import Mapped, mapped_column, relationship
-from core.workflow.human_input_compat import DeliveryMethodType
+from core.workflow.human_input_adapter import DeliveryMethodType
from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus
from libs.helper import generate_string
diff --git a/api/models/model.py b/api/models/model.py
index 7fe0731098..a1117fc43a 100644
--- a/api/models/model.py
+++ b/api/models/model.py
@@ -25,6 +25,7 @@ from graphon.enums import WorkflowExecutionStatus
from graphon.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType
from graphon.file import helpers as file_helpers
from libs.helper import generate_string # type: ignore[import-not-found]
+from libs.url_utils import normalize_api_base_url
from libs.uuid_utils import uuidv7
from models.utils.file_input_compat import build_file_from_input_mapping
@@ -446,7 +447,8 @@ class App(Base):
@property
def api_base_url(self) -> str:
- return (dify_config.SERVICE_API_URL or request.host_url.rstrip("/")) + "/v1"
+ base = dify_config.SERVICE_API_URL or request.host_url.rstrip("/")
+ return normalize_api_base_url(base)
@property
def tenant(self) -> Tenant | None:
diff --git a/api/models/utils/file_input_compat.py b/api/models/utils/file_input_compat.py
index a2dc8f6157..77dcbd13d4 100644
--- a/api/models/utils/file_input_compat.py
+++ b/api/models/utils/file_input_compat.py
@@ -5,7 +5,8 @@ from functools import lru_cache
from typing import Any
-from core.workflow.file_reference import parse_file_reference
+from core.workflow.file_reference import parse_file_reference, resolve_file_record_id
-from graphon.file import File, FileTransferMethod
+from graphon.file import File, FileTransferMethod, FileType
+from graphon.file.constants import FILE_MODEL_IDENTITY, maybe_file_object
@lru_cache(maxsize=1)
@@ -43,6 +44,124 @@ def resolve_file_mapping_tenant_id(
return tenant_resolver()
+def build_file_from_mapping_without_lookup(*, file_mapping: Mapping[str, Any]) -> File:
+ """Build a graph `File` directly from serialized metadata."""
+
+ def _coerce_file_type(value: Any) -> FileType:
+ if isinstance(value, FileType):
+ return value
+ if isinstance(value, str):
+ return FileType.value_of(value)
+ raise ValueError("file type is required in file mapping")
+
+ mapping = dict(file_mapping)
+ transfer_method_value = mapping.get("transfer_method")
+ if isinstance(transfer_method_value, FileTransferMethod):
+ transfer_method = transfer_method_value
+ elif isinstance(transfer_method_value, str):
+ transfer_method = FileTransferMethod.value_of(transfer_method_value)
+ else:
+ raise ValueError("transfer_method is required in file mapping")
+
+ file_id = mapping.get("file_id")
+ if not isinstance(file_id, str) or not file_id:
+ legacy_id = mapping.get("id")
+ file_id = legacy_id if isinstance(legacy_id, str) and legacy_id else None
+
+ related_id = resolve_file_record_id(mapping)
+ if related_id is None:
+ raw_related_id = mapping.get("related_id")
+ related_id = raw_related_id if isinstance(raw_related_id, str) and raw_related_id else None
+
+ remote_url = mapping.get("remote_url")
+ if not isinstance(remote_url, str) or not remote_url:
+ url = mapping.get("url")
+ remote_url = url if isinstance(url, str) and url else None
+
+ reference = mapping.get("reference")
+ if not isinstance(reference, str) or not reference:
+ reference = None
+
+ filename = mapping.get("filename")
+ if not isinstance(filename, str):
+ filename = None
+
+ extension = mapping.get("extension")
+ if not isinstance(extension, str):
+ extension = None
+
+ mime_type = mapping.get("mime_type")
+ if not isinstance(mime_type, str):
+ mime_type = None
+
+ size = mapping.get("size", -1)
+ if not isinstance(size, int):
+ size = -1
+
+ storage_key = mapping.get("storage_key")
+ if not isinstance(storage_key, str):
+ storage_key = None
+
+ tenant_id = mapping.get("tenant_id")
+ if not isinstance(tenant_id, str):
+ tenant_id = None
+
+ dify_model_identity = mapping.get("dify_model_identity")
+ if not isinstance(dify_model_identity, str):
+ dify_model_identity = FILE_MODEL_IDENTITY
+
+ tool_file_id = mapping.get("tool_file_id")
+ if not isinstance(tool_file_id, str):
+ tool_file_id = None
+
+ upload_file_id = mapping.get("upload_file_id")
+ if not isinstance(upload_file_id, str):
+ upload_file_id = None
+
+ datasource_file_id = mapping.get("datasource_file_id")
+ if not isinstance(datasource_file_id, str):
+ datasource_file_id = None
+
+ return File(
+ file_id=file_id,
+ tenant_id=tenant_id,
+ file_type=_coerce_file_type(mapping.get("file_type", mapping.get("type"))),
+ transfer_method=transfer_method,
+ remote_url=remote_url,
+ reference=reference,
+ related_id=related_id,
+ filename=filename,
+ extension=extension,
+ mime_type=mime_type,
+ size=size,
+ storage_key=storage_key,
+ dify_model_identity=dify_model_identity,
+ url=remote_url,
+ tool_file_id=tool_file_id,
+ upload_file_id=upload_file_id,
+ datasource_file_id=datasource_file_id,
+ )
+
+
+def rebuild_serialized_graph_files_without_lookup(value: Any) -> Any:
+ """Recursively rebuild serialized graph file payloads into `File` objects.
+
+ `graphon` 0.2.2 no longer accepts legacy serialized file mappings via
+ `model_validate_json()`. Dify keeps this recovery path at the model boundary
+ so historical JSON blobs remain readable without reintroducing global graph
+ patches or test-local coercion.
+ """
+ if isinstance(value, list):
+ return [rebuild_serialized_graph_files_without_lookup(item) for item in value]
+
+ if isinstance(value, dict):
+ if maybe_file_object(value):
+ return build_file_from_mapping_without_lookup(file_mapping=value)
+ return {key: rebuild_serialized_graph_files_without_lookup(item) for key, item in value.items()}
+
+ return value
+
+
def build_file_from_stored_mapping(
*,
file_mapping: Mapping[str, Any],
@@ -76,12 +195,7 @@ def build_file_from_stored_mapping(
pass
if transfer_method == FileTransferMethod.REMOTE_URL and record_id is None:
- remote_url = mapping.get("remote_url")
- if not isinstance(remote_url, str) or not remote_url:
- url = mapping.get("url")
- if isinstance(url, str) and url:
- mapping["remote_url"] = url
- return File.model_validate(mapping)
+ return build_file_from_mapping_without_lookup(file_mapping=mapping)
return file_factory.build_from_mapping(
mapping=mapping,
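A reduced sketch of the recursive rebuild contract, with a stand-in predicate and return value (the real `maybe_file_object` lives in `graphon.file.constants`, and the identity string is assumed here):

```python
from typing import Any


def looks_like_file(value: dict[str, Any]) -> bool:
    # Stand-in for graphon.file.constants.maybe_file_object.
    return value.get("dify_model_identity") == "__dify__file__"


def rebuild(value: Any) -> Any:
    """Walk lists/dicts; replace serialized file mappings with File objects."""
    if isinstance(value, list):
        return [rebuild(item) for item in value]
    if isinstance(value, dict):
        if looks_like_file(value):
            return ("File", value.get("filename"))  # stand-in for a real File
        return {key: rebuild(item) for key, item in value.items()}
    return value


payload = {"outputs": [{"dify_model_identity": "__dify__file__", "filename": "a.png"}]}
assert rebuild(payload) == {"outputs": [("File", "a.png")]}
```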
diff --git a/api/models/workflow.py b/api/models/workflow.py
index dfda03c2ee..d127244b0f 100644
--- a/api/models/workflow.py
+++ b/api/models/workflow.py
@@ -24,7 +24,7 @@ from sqlalchemy.orm import Mapped, mapped_column
from typing_extensions import deprecated
from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE
-from core.workflow.human_input_compat import normalize_node_config_for_graph
+from core.workflow.human_input_adapter import adapt_node_config_for_graph
from core.workflow.variable_prefixes import (
CONVERSATION_VARIABLE_NODE_ID,
SYSTEM_VARIABLE_NODE_ID,
@@ -64,7 +64,10 @@ from .base import Base, DefaultFieldsDCMixin, TypeBase
from .engine import db
from .enums import CreatorUserRole, DraftVariableType, ExecutionOffLoadType, WorkflowRunTriggeredFrom
from .types import EnumText, LongText, StringUUID
-from .utils.file_input_compat import build_file_from_stored_mapping
+from .utils.file_input_compat import (
+ build_file_from_mapping_without_lookup,
+ build_file_from_stored_mapping,
+)
logger = logging.getLogger(__name__)
@@ -290,7 +293,7 @@ class Workflow(Base): # bug
node_config: dict[str, Any] = next(filter(lambda node: node["id"] == node_id, nodes))
except StopIteration:
raise NodeNotFoundError(node_id)
- return NodeConfigDictAdapter.validate_python(normalize_node_config_for_graph(node_config))
+ return NodeConfigDictAdapter.validate_python(adapt_node_config_for_graph(node_config))
@staticmethod
def get_node_type_from_node_config(node_config: NodeConfigDict) -> NodeType:
@@ -1688,7 +1691,7 @@ class WorkflowDraftVariable(Base):
return cast(Any, value)
normalized_file = dict(value)
normalized_file.pop("tenant_id", None)
- return File.model_validate(normalized_file)
+ return build_file_from_mapping_without_lookup(file_mapping=normalized_file)
elif isinstance(value, list) and value:
value_list = cast(list[Any], value)
first: Any = value_list[0]
@@ -1698,7 +1701,7 @@ class WorkflowDraftVariable(Base):
for item in value_list:
normalized_file = dict(cast(dict[str, Any], item))
normalized_file.pop("tenant_id", None)
- file_list.append(File.model_validate(normalized_file))
+ file_list.append(build_file_from_mapping_without_lookup(file_mapping=normalized_file))
return cast(Any, file_list)
else:
return cast(Any, value)
diff --git a/api/providers/trace/trace-tencent/src/dify_trace_tencent/tencent_trace.py b/api/providers/trace/trace-tencent/src/dify_trace_tencent/tencent_trace.py
index cfcf6b307e..a8c480e4a5 100644
--- a/api/providers/trace/trace-tencent/src/dify_trace_tencent/tencent_trace.py
+++ b/api/providers/trace/trace-tencent/src/dify_trace_tencent/tencent_trace.py
@@ -1,7 +1,6 @@
-"""
-Tencent APM tracing implementation with separated concerns
-"""
+"""Tencent APM tracing with idempotent client cleanup."""
+import inspect
import logging
from sqlalchemy import select
@@ -38,10 +37,18 @@ class TencentDataTrace(BaseTraceInstance):
"""
Tencent APM trace implementation with single responsibility principle.
Acts as a coordinator that delegates specific tasks to specialized classes.
+
+ The instance owns a long-lived ``TencentTraceClient``. Cleanup may happen
+ explicitly in tests or implicitly during garbage collection, so shutdown
+ must be safe to call multiple times.
"""
+ trace_client: TencentTraceClient
+ _closed: bool
+
def __init__(self, tencent_config: TencentConfig):
super().__init__(tencent_config)
+ self._closed = False
self.trace_client = TencentTraceClient(
service_name=tencent_config.service_name,
endpoint=tencent_config.endpoint,
@@ -513,10 +520,25 @@ class TencentDataTrace(BaseTraceInstance):
except Exception:
logger.debug("[Tencent APM] Failed to record message trace duration")
- def __del__(self):
- """Ensure proper cleanup on garbage collection."""
+ def close(self) -> None:
+ """Synchronously and idempotently shutdown the underlying trace client."""
+ if getattr(self, "_closed", False):
+ return
+
+ self._closed = True
+ trace_client = getattr(self, "trace_client", None)
+ if trace_client is None:
+ return
+
try:
- if hasattr(self, "trace_client"):
- self.trace_client.shutdown()
+ shutdown_result = trace_client.shutdown()
+ if inspect.isawaitable(shutdown_result):
+ close_awaitable = getattr(shutdown_result, "close", None)
+ if callable(close_awaitable):
+ close_awaitable()
except Exception:
logger.exception("[Tencent APM] Failed to shutdown trace client during cleanup")
+
+ def __del__(self):
+ """Ensure best-effort cleanup on garbage collection without retrying shutdown."""
+ self.close()
diff --git a/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_tencent_trace.py b/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_tencent_trace.py
index a91a0aa558..54524b09ca 100644
--- a/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_tencent_trace.py
+++ b/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_tencent_trace.py
@@ -1,5 +1,7 @@
+import gc
import logging
-from unittest.mock import MagicMock, patch
+import warnings
+from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from dify_trace_tencent.config import TencentConfig
@@ -632,13 +634,38 @@ class TestTencentDataTrace:
with patch("dify_trace_tencent.tencent_trace.logger.debug") as mock_log:
tencent_data_trace._record_message_trace_duration(trace_info)
- def test_del(self, tencent_data_trace):
+ def test_close(self, tencent_data_trace):
client = tencent_data_trace.trace_client
- tencent_data_trace.__del__()
+ tencent_data_trace.close()
client.shutdown.assert_called_once()
- def test_del_exception(self, tencent_data_trace):
+ def test_close_is_idempotent(self, tencent_data_trace):
+ client = tencent_data_trace.trace_client
+
+ tencent_data_trace.close()
+ tencent_data_trace.close()
+
+ client.shutdown.assert_called_once()
+
+ def test_close_exception(self, tencent_data_trace):
tencent_data_trace.trace_client.shutdown.side_effect = Exception("error")
with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log:
- tencent_data_trace.__del__()
+ tencent_data_trace.close()
mock_log.assert_called_once_with("[Tencent APM] Failed to shutdown trace client during cleanup")
+
+ def test_close_handles_async_shutdown_mock(self, tencent_data_trace):
+ shutdown = AsyncMock()
+ tencent_data_trace.trace_client.shutdown = shutdown
+
+ with warnings.catch_warnings(record=True) as caught:
+ warnings.simplefilter("always")
+ tencent_data_trace.close()
+ gc.collect()
+
+ shutdown.assert_called_once()
+ assert not [
+ warning
+ for warning in caught
+ if issubclass(warning.category, RuntimeWarning)
+ and "AsyncMockMixin._execute_mock_call" in str(warning.message)
+ ]
diff --git a/api/pyproject.toml b/api/pyproject.toml
index 12b8b3d782..b13c744f0a 100644
--- a/api/pyproject.toml
+++ b/api/pyproject.toml
@@ -6,7 +6,7 @@ requires-python = "~=3.12.0"
dependencies = [
# Legacy: mature and widely deployed
"bleach>=6.3.0",
- "boto3>=1.42.88",
+ "boto3>=1.42.91",
"celery>=5.6.3",
"croniter>=6.2.2",
"flask-cors>=6.0.2",
@@ -30,7 +30,7 @@ dependencies = [
"flask-migrate>=4.1.0,<5.0.0",
"flask-orjson>=2.0.0,<3.0.0",
"flask-restx>=1.3.2,<2.0.0",
- "google-cloud-aiplatform>=1.147.0,<2.0.0",
+ "google-cloud-aiplatform>=1.148.1,<2.0.0",
"httpx[socks]>=0.28.1,<1.0.0",
"opentelemetry-distro>=0.62b0,<1.0.0",
"opentelemetry-instrumentation-celery>=0.62b0,<1.0.0",
@@ -44,9 +44,9 @@ dependencies = [
# Emerging: newer and fast-moving, use compatible pins
"fastopenapi[flask]~=0.7.0",
- "graphon~=0.1.2",
+ "graphon~=0.2.2",
"httpx-sse~=0.4.0",
- "json-repair~=0.59.2",
+ "json-repair~=0.59.4",
]
# Before adding new dependency, consider place it in
# alphabet order (a-z) and suitable group.
@@ -173,8 +173,8 @@ dev = [
# "locust>=2.40.4", # Temporarily removed due to compatibility issues. Uncomment when resolved.
"pytest-timeout>=2.4.0",
"pytest-xdist>=3.8.0",
- "pyrefly>=0.60.0",
- "xinference-client>=2.4.0",
+ "pyrefly>=0.61.1",
+ "xinference-client>=2.5.0",
]
############################################################
@@ -183,13 +183,13 @@ dev = [
############################################################
storage = [
"azure-storage-blob>=12.28.0",
- "bce-python-sdk>=0.9.69",
+ "bce-python-sdk>=0.9.70",
"cos-python-sdk-v5>=1.9.41",
"esdk-obs-python>=3.22.2",
"google-cloud-storage>=3.10.1",
"opendal>=0.46.0",
"oss2>=2.19.1",
- "supabase>=2.18.1",
+ "supabase>=2.28.3",
"tos>=2.9.0",
]
@@ -266,7 +266,7 @@ vdb-vastbase = ["dify-vdb-vastbase"]
vdb-vikingdb = ["dify-vdb-vikingdb"]
vdb-weaviate = ["dify-vdb-weaviate"]
# Optional client used by some tests / integrations (not a vector backend plugin)
-vdb-xinference = ["xinference-client>=2.4.0"]
+vdb-xinference = ["xinference-client>=2.5.0"]
trace-all = [
"dify-trace-aliyun",
diff --git a/api/services/account_service.py b/api/services/account_service.py
index ccc4a7c1fa..b6554a3de7 100644
--- a/api/services/account_service.py
+++ b/api/services/account_service.py
@@ -112,6 +112,14 @@ REFRESH_TOKEN_EXPIRY = timedelta(days=dify_config.REFRESH_TOKEN_EXPIRE_DAYS)
class AccountService:
+ # Phase-bound token metadata for the change-email flow. Tokens carry the
+ # current phase so that downstream endpoints can enforce proper progression
+ CHANGE_EMAIL_TOKEN_PHASE_KEY = "email_change_phase"
+ CHANGE_EMAIL_PHASE_OLD = "old_email"
+ CHANGE_EMAIL_PHASE_OLD_VERIFIED = "old_email_verified"
+ CHANGE_EMAIL_PHASE_NEW = "new_email"
+ CHANGE_EMAIL_PHASE_NEW_VERIFIED = "new_email_verified"
+
reset_password_rate_limiter = RateLimiter(prefix="reset_password_rate_limit", max_attempts=1, time_window=60 * 1)
email_register_rate_limiter = RateLimiter(prefix="email_register_rate_limit", max_attempts=1, time_window=60 * 1)
email_code_login_rate_limiter = RateLimiter(
@@ -576,13 +584,20 @@ class AccountService:
raise ValueError("Email must be provided.")
if not phase:
raise ValueError("phase must be provided.")
+ if phase not in (cls.CHANGE_EMAIL_PHASE_OLD, cls.CHANGE_EMAIL_PHASE_NEW):
+ raise ValueError("phase must be one of old_email or new_email.")
if cls.change_email_rate_limiter.is_rate_limited(account_email):
from controllers.console.auth.error import EmailChangeRateLimitExceededError
raise EmailChangeRateLimitExceededError(int(cls.change_email_rate_limiter.time_window / 60))
- code, token = cls.generate_change_email_token(account_email, account, old_email=old_email)
+ code, token = cls.generate_change_email_token(
+ account_email,
+ account,
+ old_email=old_email,
+ additional_data={cls.CHANGE_EMAIL_TOKEN_PHASE_KEY: phase},
+ )
send_change_mail_task.delay(
language=language,
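
These constants stamp a phase marker into every change-email token, and the controller tests later in this patch accept a token only when its recorded phase matches the step being performed: the old address must be verified before a code is sent to the new one, and only a new_email_verified token may reach the reset endpoint. A small sketch of that progression using the constant values defined above; ALLOWED_PHASE_TRANSITIONS and next_phase are illustrative helpers, not code from the patch:

CHANGE_EMAIL_TOKEN_PHASE_KEY = "email_change_phase"

# Verifying the code for a phase upgrades the token to its *_verified form;
# any other or missing marker is rejected outright.
ALLOWED_PHASE_TRANSITIONS = {
    "old_email": "old_email_verified",
    "new_email": "new_email_verified",
}


def next_phase(token_data: dict) -> str:
    phase = token_data.get(CHANGE_EMAIL_TOKEN_PHASE_KEY)
    if phase not in ALLOWED_PHASE_TRANSITIONS:
        raise ValueError("invalid or missing change-email token phase")
    return ALLOWED_PHASE_TRANSITIONS[phase]


# A phase-1 (old_email) token can only become old_email_verified; it can never
# be replayed to complete the new-email or reset steps directly.
assert next_phase({CHANGE_EMAIL_TOKEN_PHASE_KEY: "old_email"}) == "old_email_verified"
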
diff --git a/api/services/app_service.py b/api/services/app_service.py
index afd98e2975..a046b909b3 100644
--- a/api/services/app_service.py
+++ b/api/services/app_service.py
@@ -16,7 +16,7 @@ from core.tools.utils.configuration import ToolParameterConfigurationManager
from events.app_event import app_was_created, app_was_deleted, app_was_updated
from extensions.ext_database import db
from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
-from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
from libs.datetime_utils import naive_utc_now
from libs.login import current_user
from models import Account
@@ -303,17 +303,22 @@ class AppService:
return app
- def update_app_icon(self, app: App, icon: str, icon_background: str) -> App:
+ def update_app_icon(
+ self, app: App, icon: str, icon_background: str, icon_type: IconType | str | None = None
+ ) -> App:
"""
Update app icon
:param app: App instance
:param icon: new icon
:param icon_background: new icon_background
+ :param icon_type: new icon type
:return: App instance
"""
assert current_user is not None
app.icon = icon
app.icon_background = icon_background
+ if icon_type is not None:
+ app.icon_type = icon_type if isinstance(icon_type, IconType) else IconType(icon_type)
app.updated_by = current_user.id
app.updated_at = naive_utc_now()
db.session.commit()
diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py
index e6f5f80a6d..894cb05687 100644
--- a/api/services/dataset_service.py
+++ b/api/services/dataset_service.py
@@ -30,7 +30,7 @@ from extensions.ext_database import db
from extensions.ext_redis import redis_client
from graphon.file import helpers as file_helpers
from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType
-from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
+from graphon.model_runtime.model_providers.base.text_embedding_model import TextEmbeddingModel
from libs import helper
from libs.datetime_utils import naive_utc_now
from libs.login import current_user
diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py
index ca84b2a3d8..2e5987dd28 100644
--- a/api/services/hit_testing_service.py
+++ b/api/services/hit_testing_service.py
@@ -1,10 +1,10 @@
import json
import logging
import time
-from typing import Any, TypedDict
+from typing import Any, TypedDict, cast
from core.app.app_config.entities import ModelConfig
-from core.rag.datasource.retrieval_service import RetrievalService
+from core.rag.datasource.retrieval_service import DefaultRetrievalModelDict, RetrievalService
from core.rag.index_processor.constant.query_type import QueryType
from core.rag.models.document import Document
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
@@ -36,6 +36,10 @@ default_retrieval_model = {
}
+class HitTestingRetrievalModelDict(DefaultRetrievalModelDict, total=False):
+ metadata_filtering_conditions: dict[str, Any]
+
+
class HitTestingService:
@classmethod
def retrieve(
@@ -51,17 +55,18 @@ class HitTestingService:
start = time.perf_counter()
        # get the retrieval model; if it is not set, use the default
- if not retrieval_model:
- retrieval_model = dataset.retrieval_model or default_retrieval_model
- assert isinstance(retrieval_model, dict)
+ resolved_retrieval_model = cast(
+ HitTestingRetrievalModelDict,
+ retrieval_model or dataset.retrieval_model or default_retrieval_model,
+ )
document_ids_filter = None
- metadata_filtering_conditions = retrieval_model.get("metadata_filtering_conditions", {})
- if metadata_filtering_conditions and query:
+ metadata_filtering_conditions_raw = resolved_retrieval_model.get("metadata_filtering_conditions", {})
+ if metadata_filtering_conditions_raw and query:
dataset_retrieval = DatasetRetrieval()
from core.rag.entities import MetadataFilteringCondition
- metadata_filtering_conditions = MetadataFilteringCondition.model_validate(metadata_filtering_conditions)
+ metadata_filtering_conditions = MetadataFilteringCondition.model_validate(metadata_filtering_conditions_raw)
metadata_filter_document_ids, metadata_condition = dataset_retrieval.get_metadata_filter_condition(
dataset_ids=[dataset.id],
@@ -78,19 +83,21 @@ class HitTestingService:
if metadata_condition and not document_ids_filter:
return cls.compact_retrieve_response(query, [])
all_documents = RetrievalService.retrieve(
- retrieval_method=RetrievalMethod(retrieval_model.get("search_method", RetrievalMethod.SEMANTIC_SEARCH)),
+ retrieval_method=RetrievalMethod(
+ resolved_retrieval_model.get("search_method", RetrievalMethod.SEMANTIC_SEARCH)
+ ),
dataset_id=dataset.id,
query=query,
attachment_ids=attachment_ids,
- top_k=retrieval_model.get("top_k", 4),
- score_threshold=retrieval_model.get("score_threshold", 0.0)
- if retrieval_model["score_threshold_enabled"]
+ top_k=resolved_retrieval_model.get("top_k", 4),
+ score_threshold=resolved_retrieval_model.get("score_threshold", 0.0)
+ if resolved_retrieval_model["score_threshold_enabled"]
else 0.0,
- reranking_model=retrieval_model.get("reranking_model", None)
- if retrieval_model["reranking_enable"]
+ reranking_model=resolved_retrieval_model.get("reranking_model", None)
+ if resolved_retrieval_model["reranking_enable"]
else None,
- reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
- weights=retrieval_model.get("weights", None),
+ reranking_mode=resolved_retrieval_model.get("reranking_mode") or "reranking_model",
+ weights=resolved_retrieval_model.get("weights", None),
document_ids_filter=document_ids_filter,
)
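
The HitTestingRetrievalModelDict addition replaces the reassign-and-assert handling of retrieval_model with a single cast into a TypedDict that also admits metadata_filtering_conditions, so every later .get() stays well typed. A short self-contained sketch of the same TypedDict-plus-cast idiom; the field names below are an illustrative subset, not the real DefaultRetrievalModelDict:

from typing import Any, TypedDict, cast


class BaseRetrievalModelDict(TypedDict, total=False):
    search_method: str
    top_k: int
    score_threshold_enabled: bool


class RetrievalModelWithMetadataDict(BaseRetrievalModelDict, total=False):
    metadata_filtering_conditions: dict[str, Any]


def resolve_retrieval_model(
    override: dict | None, dataset_default: dict | None
) -> RetrievalModelWithMetadataDict:
    # First non-empty mapping wins; the cast only narrows the static type.
    fallback: RetrievalModelWithMetadataDict = {"search_method": "semantic_search", "top_k": 4}
    return cast(RetrievalModelWithMetadataDict, override or dataset_default or fallback)


model = resolve_retrieval_model(None, {"search_method": "hybrid_search", "top_k": 2})
assert model.get("top_k", 4) == 2
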
diff --git a/api/services/human_input_delivery_test_service.py b/api/services/human_input_delivery_test_service.py
index 68ef67dec1..8b4983e5f7 100644
--- a/api/services/human_input_delivery_test_service.py
+++ b/api/services/human_input_delivery_test_service.py
@@ -8,7 +8,7 @@ from sqlalchemy import Engine, select
from sqlalchemy.orm import sessionmaker
from configs import dify_config
-from core.workflow.human_input_compat import (
+from core.workflow.human_input_adapter import (
DeliveryChannelConfig,
EmailDeliveryConfig,
EmailDeliveryMethod,
diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py
index 968600d1bc..9db6682e10 100644
--- a/api/services/rag_pipeline/rag_pipeline.py
+++ b/api/services/rag_pipeline/rag_pipeline.py
@@ -476,7 +476,7 @@ class RagPipelineService:
:param filters: filter by node config parameters.
:return:
"""
- node_type_enum = NodeType(node_type)
+ node_type_enum: NodeType = node_type
node_mapping = get_node_type_classes_mapping()
# return default block config
diff --git a/api/services/variable_truncator.py b/api/services/variable_truncator.py
index c96050ce13..1529c2b98f 100644
--- a/api/services/variable_truncator.py
+++ b/api/services/variable_truncator.py
@@ -169,7 +169,7 @@ class VariableTruncator(BaseTruncator):
return TruncationResult(StringSegment(value=fallback_result.value), True)
# Apply final fallback - convert to JSON string and truncate
- json_str = dumps_with_segments(result.value, ensure_ascii=False)
+ json_str = dumps_with_segments(result.value)
if len(json_str) > self._max_size_bytes:
json_str = json_str[: self._max_size_bytes] + "..."
return TruncationResult(result=StringSegment(value=json_str), truncated=True)
diff --git a/api/services/workflow_draft_variable_service.py b/api/services/workflow_draft_variable_service.py
index 5ec00ee336..96f936ff9b 100644
--- a/api/services/workflow_draft_variable_service.py
+++ b/api/services/workflow_draft_variable_service.py
@@ -146,7 +146,7 @@ class DraftVarLoader(VariableLoader):
variable = segment_to_variable(
segment=segment,
selector=draft_var.get_selector(),
- id=draft_var.id,
+ variable_id=draft_var.id,
name=draft_var.name,
description=draft_var.description,
)
@@ -180,7 +180,7 @@ class DraftVarLoader(VariableLoader):
variable = segment_to_variable(
segment=segment,
selector=draft_var.get_selector(),
- id=draft_var.id,
+ variable_id=draft_var.id,
name=draft_var.name,
description=draft_var.description,
)
@@ -191,7 +191,7 @@ class DraftVarLoader(VariableLoader):
variable = segment_to_variable(
segment=segment,
selector=draft_var.get_selector(),
- id=draft_var.id,
+ variable_id=draft_var.id,
name=draft_var.name,
description=draft_var.description,
)
@@ -1067,7 +1067,7 @@ class DraftVariableSaver:
filename = f"{self._generate_filename(name)}.txt"
else:
# For other types, store as JSON
- original_content_serialized = dumps_with_segments(value_seg.value, ensure_ascii=False)
+ original_content_serialized = dumps_with_segments(value_seg.value)
content_type = "application/json"
filename = f"{self._generate_filename(name)}.json"
diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py
index d71223314e..d4b9095ce5 100644
--- a/api/services/workflow_service.py
+++ b/api/services/workflow_service.py
@@ -18,9 +18,9 @@ from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly,
from core.repositories import DifyCoreRepositoryFactory
from core.repositories.human_input_repository import FormCreateParams, HumanInputFormRepositoryImpl
from core.trigger.constants import is_trigger_node_type
-from core.workflow.human_input_compat import (
+from core.workflow.human_input_adapter import (
DeliveryChannelConfig,
- normalize_human_input_node_data_for_graph,
+ adapt_human_input_node_data_for_graph,
parse_human_input_delivery_methods,
)
from core.workflow.node_factory import (
@@ -791,7 +791,7 @@ class WorkflowService:
:param filters: filter by node config parameters.
:return:
"""
- node_type_enum = NodeType(node_type)
+ node_type_enum: NodeType = node_type
node_mapping = get_node_type_classes_mapping()
# return default block config
@@ -1096,7 +1096,7 @@ class WorkflowService:
raise ValueError("Node type must be human-input.")
node_data = HumanInputNodeData.model_validate(
- normalize_human_input_node_data_for_graph(node_config["data"]),
+ adapt_human_input_node_data_for_graph(node_config["data"]),
from_attributes=True,
)
delivery_method = self._resolve_human_input_delivery_method(
@@ -1237,9 +1237,10 @@ class WorkflowService:
variable_pool=variable_pool,
start_at=time.perf_counter(),
)
+ node_data = HumanInputNode.validate_node_data(adapt_human_input_node_data_for_graph(node_config["data"]))
node = HumanInputNode(
- id=node_config["id"],
- config=node_config,
+ node_id=node_config["id"],
+ config=node_data,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
runtime=DifyHumanInputNodeRuntime(run_context),
@@ -1529,7 +1530,7 @@ class WorkflowService:
from graphon.nodes.human_input.entities import HumanInputNodeData
try:
- HumanInputNodeData.model_validate(normalize_human_input_node_data_for_graph(node_data))
+ HumanInputNodeData.model_validate(adapt_human_input_node_data_for_graph(node_data))
except Exception as e:
raise ValueError(f"Invalid HumanInput node data: {str(e)}")
diff --git a/api/tasks/mail_human_input_delivery_task.py b/api/tasks/mail_human_input_delivery_task.py
index f8ae3f4b6e..2a60be7762 100644
--- a/api/tasks/mail_human_input_delivery_task.py
+++ b/api/tasks/mail_human_input_delivery_task.py
@@ -11,7 +11,7 @@ from sqlalchemy.orm import Session, sessionmaker
from configs import dify_config
from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext
-from core.workflow.human_input_compat import EmailDeliveryConfig, EmailDeliveryMethod
+from core.workflow.human_input_adapter import EmailDeliveryConfig, EmailDeliveryMethod
from extensions.ext_database import db
from extensions.ext_mail import mail
from graphon.runtime import GraphRuntimeState, VariablePool
diff --git a/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py b/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py
index b5318aaa2b..2392084c36 100644
--- a/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py
+++ b/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py
@@ -1,5 +1,6 @@
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY
from core.workflow.nodes.datasource.datasource_node import DatasourceNode
+from core.workflow.nodes.datasource.entities import DatasourceNodeData
from graphon.enums import WorkflowNodeExecutionStatus
from graphon.node_events import NodeRunResult, StreamCompletedEvent
@@ -69,19 +70,16 @@ def test_node_integration_minimal_stream(mocker):
mocker.patch("core.workflow.nodes.datasource.datasource_node.DatasourceManager", new=_Mgr)
node = DatasourceNode(
- id="n",
- config={
- "id": "n",
- "data": {
- "type": "datasource",
- "version": "1",
- "title": "Datasource",
- "provider_type": "plugin",
- "provider_name": "p",
- "plugin_id": "plug",
- "datasource_name": "ds",
- },
- },
+ node_id="n",
+ config=DatasourceNodeData(
+ type="datasource",
+ version="1",
+ title="Datasource",
+ provider_type="plugin",
+ provider_name="p",
+ plugin_id="plug",
+ datasource_name="ds",
+ ),
graph_init_params=_GP(),
graph_runtime_state=_GS(vp),
)
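
This test and the node tests updated below all follow the same constructor change: nodes now receive node_id plus an already-validated data model instead of the raw {"id": ..., "data": ...} mapping. A minimal runnable sketch of that convention; the Dummy* names are illustrative, and it assumes pydantic v2, which the model_validate calls in this patch rely on:

from pydantic import BaseModel


class DummyNodeData(BaseModel):
    type: str
    version: str
    title: str


class DummyNode:
    def __init__(self, *, node_id: str, config: DummyNodeData) -> None:
        # The node keeps the id and the typed data object; no dict parsing happens here.
        self.id = node_id
        self.config = config


raw = {"id": "n1", "data": {"type": "datasource", "version": "1", "title": "Datasource"}}
node = DummyNode(node_id=raw["id"], config=DummyNodeData.model_validate(raw["data"]))
assert node.config.title == "Datasource"
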
diff --git a/api/tests/integration_tests/workflow/nodes/test_code.py b/api/tests/integration_tests/workflow/nodes/test_code.py
index e3476c292b..aaa6092993 100644
--- a/api/tests/integration_tests/workflow/nodes/test_code.py
+++ b/api/tests/integration_tests/workflow/nodes/test_code.py
@@ -11,6 +11,7 @@ from graphon.enums import WorkflowNodeExecutionStatus
from graphon.graph import Graph
from graphon.node_events import NodeRunResult
from graphon.nodes.code.code_node import CodeNode
+from graphon.nodes.code.entities import CodeNodeData
from graphon.nodes.code.limits import CodeNodeLimits
from graphon.runtime import GraphRuntimeState, VariablePool
from tests.workflow_test_utils import build_test_graph_init_params
@@ -64,8 +65,8 @@ def init_code_node(code_config: dict):
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start")
node = CodeNode(
- id=str(uuid.uuid4()),
- config=code_config,
+ node_id=str(uuid.uuid4()),
+ config=CodeNodeData.model_validate(code_config["data"]),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
code_executor=node_factory._code_executor,
diff --git a/api/tests/integration_tests/workflow/nodes/test_http.py b/api/tests/integration_tests/workflow/nodes/test_http.py
index aa6cf1e021..b9f7b9575b 100644
--- a/api/tests/integration_tests/workflow/nodes/test_http.py
+++ b/api/tests/integration_tests/workflow/nodes/test_http.py
@@ -14,7 +14,7 @@ from core.workflow.system_variables import build_system_variables
from graphon.enums import WorkflowNodeExecutionStatus
from graphon.file.file_manager import file_manager
from graphon.graph import Graph
-from graphon.nodes.http_request import HttpRequestNode, HttpRequestNodeConfig
+from graphon.nodes.http_request import HttpRequestNode, HttpRequestNodeConfig, HttpRequestNodeData
from graphon.runtime import GraphRuntimeState, VariablePool
from tests.workflow_test_utils import build_test_graph_init_params
@@ -75,8 +75,8 @@ def init_http_node(config: dict):
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start")
node = HttpRequestNode(
- id=str(uuid.uuid4()),
- config=config,
+ node_id=str(uuid.uuid4()),
+ config=HttpRequestNodeData.model_validate(config["data"]),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
http_request_config=HTTP_REQUEST_CONFIG,
@@ -723,8 +723,8 @@ def test_nested_object_variable_selector(setup_http_mock):
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start")
node = HttpRequestNode(
- id=str(uuid.uuid4()),
- config=graph_config["nodes"][1],
+ node_id=str(uuid.uuid4()),
+ config=HttpRequestNodeData.model_validate(graph_config["nodes"][1]["data"]),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
http_request_config=HTTP_REQUEST_CONFIG,
diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py
index fa5d63cfbf..3eead70163 100644
--- a/api/tests/integration_tests/workflow/nodes/test_llm.py
+++ b/api/tests/integration_tests/workflow/nodes/test_llm.py
@@ -11,6 +11,7 @@ from core.workflow.system_variables import build_system_variables
from extensions.ext_database import db
from graphon.enums import WorkflowNodeExecutionStatus
from graphon.node_events import StreamCompletedEvent
+from graphon.nodes.llm.entities import LLMNodeData
from graphon.nodes.llm.file_saver import LLMFileSaver
from graphon.nodes.llm.node import LLMNode
from graphon.nodes.llm.protocols import CredentialsProvider, ModelFactory
@@ -75,8 +76,8 @@ def init_llm_node(config: dict) -> LLMNode:
llm_file_saver = MagicMock(spec=LLMFileSaver)
node = LLMNode(
- id=str(uuid.uuid4()),
- config=config,
+ node_id=str(uuid.uuid4()),
+ config=LLMNodeData.model_validate(config["data"]),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
credentials_provider=MagicMock(spec=CredentialsProvider),
diff --git a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py
index 52886855b8..f2eabb86c3 100644
--- a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py
+++ b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py
@@ -11,6 +11,7 @@ from extensions.ext_database import db
from graphon.enums import WorkflowNodeExecutionStatus
from graphon.model_runtime.entities import AssistantPromptMessage, UserPromptMessage
from graphon.nodes.llm.protocols import CredentialsProvider, ModelFactory
+from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeData
from graphon.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
from graphon.runtime import GraphRuntimeState, VariablePool
from tests.integration_tests.workflow.nodes.__mock.model import get_mocked_fetch_model_instance
@@ -69,8 +70,8 @@ def init_parameter_extractor_node(config: dict, memory=None):
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
node = ParameterExtractorNode(
- id=str(uuid.uuid4()),
- config=config,
+ node_id=str(uuid.uuid4()),
+ config=ParameterExtractorNodeData.model_validate(config["data"]),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
credentials_provider=MagicMock(spec=CredentialsProvider),
diff --git a/api/tests/integration_tests/workflow/nodes/test_template_transform.py b/api/tests/integration_tests/workflow/nodes/test_template_transform.py
index 9e3e1a47e3..e2e0723fb8 100644
--- a/api/tests/integration_tests/workflow/nodes/test_template_transform.py
+++ b/api/tests/integration_tests/workflow/nodes/test_template_transform.py
@@ -6,6 +6,7 @@ from core.workflow.node_factory import DifyNodeFactory
from core.workflow.system_variables import build_system_variables
from graphon.enums import WorkflowNodeExecutionStatus
from graphon.graph import Graph
+from graphon.nodes.template_transform.entities import TemplateTransformNodeData
from graphon.nodes.template_transform.template_transform_node import TemplateTransformNode
from graphon.runtime import GraphRuntimeState, VariablePool
from graphon.template_rendering import TemplateRenderError
@@ -86,8 +87,8 @@ def test_execute_template_transform():
assert graph is not None
node = TemplateTransformNode(
- id=str(uuid.uuid4()),
- config=config,
+ node_id=str(uuid.uuid4()),
+ config=TemplateTransformNodeData.model_validate(config["data"]),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
jinja2_template_renderer=_SimpleJinja2Renderer(),
diff --git a/api/tests/integration_tests/workflow/nodes/test_tool.py b/api/tests/integration_tests/workflow/nodes/test_tool.py
index f9ec51ee10..a8e9422c1e 100644
--- a/api/tests/integration_tests/workflow/nodes/test_tool.py
+++ b/api/tests/integration_tests/workflow/nodes/test_tool.py
@@ -11,6 +11,7 @@ from graphon.enums import WorkflowNodeExecutionStatus
from graphon.graph import Graph
from graphon.node_events import StreamCompletedEvent
from graphon.nodes.protocols import ToolFileManagerProtocol
+from graphon.nodes.tool.entities import ToolNodeData
from graphon.nodes.tool.tool_node import ToolNode
from graphon.runtime import GraphRuntimeState, VariablePool
from tests.workflow_test_utils import build_test_graph_init_params
@@ -60,8 +61,8 @@ def init_tool_node(config: dict):
tool_file_manager_factory = MagicMock(spec=ToolFileManagerProtocol)
node = ToolNode(
- id=str(uuid.uuid4()),
- config=config,
+ node_id=str(uuid.uuid4()),
+ config=ToolNodeData.model_validate(config["data"]),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
tool_file_manager_factory=tool_file_manager_factory,
diff --git a/api/tests/test_containers_integration_tests/controllers/console/app/test_app_apis.py b/api/tests/test_containers_integration_tests/controllers/console/app/test_app_apis.py
index 15dec06311..18755ef012 100644
--- a/api/tests/test_containers_integration_tests/controllers/console/app/test_app_apis.py
+++ b/api/tests/test_containers_integration_tests/controllers/console/app/test_app_apis.py
@@ -234,6 +234,35 @@ class TestAppEndpoints:
}
)
+ def test_app_icon_post_should_forward_icon_type(self, app, monkeypatch):
+ api = app_module.AppIconApi()
+ method = _unwrap(api.post)
+ payload = {
+ "icon": "https://example.com/icon.png",
+ "icon_type": "image",
+ "icon_background": "#FFFFFF",
+ }
+ app_service = MagicMock()
+ app_service.update_app_icon.return_value = SimpleNamespace()
+ response_model = MagicMock()
+ response_model.model_dump.return_value = {"id": "app-1"}
+
+ monkeypatch.setattr(app_module, "AppService", lambda: app_service)
+ monkeypatch.setattr(app_module.AppDetail, "model_validate", MagicMock(return_value=response_model))
+
+ with (
+ app.test_request_context("/console/api/apps/app-1/icon", method="POST", json=payload),
+ patch.object(type(console_ns), "payload", payload),
+ ):
+ response = method(app_model=SimpleNamespace())
+
+ assert response == {"id": "app-1"}
+ assert app_service.update_app_icon.call_args.args[1:] == (
+ payload["icon"],
+ payload["icon_background"],
+ app_module.IconType.IMAGE,
+ )
+
class TestOpsTraceEndpoints:
@pytest.fixture
diff --git a/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py b/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py
index 14d5740072..6524d6ce61 100644
--- a/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py
+++ b/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py
@@ -8,7 +8,7 @@ from sqlalchemy import Engine, select
from sqlalchemy.orm import Session
from core.repositories.human_input_repository import FormCreateParams, HumanInputFormRepositoryImpl
-from core.workflow.human_input_compat import (
+from core.workflow.human_input_adapter import (
DeliveryChannelConfig,
EmailDeliveryConfig,
EmailDeliveryMethod,
diff --git a/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py b/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py
index da4f8847d6..5aed230cd4 100644
--- a/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py
+++ b/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py
@@ -101,8 +101,8 @@ def _build_graph(
start_data = StartNodeData(title="start", variables=[])
start_node = StartNode(
- id="start",
- config={"id": "start", "data": start_data.model_dump()},
+ node_id="start",
+ config=start_data,
graph_init_params=params,
graph_runtime_state=runtime_state,
)
@@ -116,8 +116,8 @@ def _build_graph(
],
)
human_node = HumanInputNode(
- id="human",
- config={"id": "human", "data": human_data.model_dump()},
+ node_id="human",
+ config=human_data,
graph_init_params=params,
graph_runtime_state=runtime_state,
form_repository=form_repository,
@@ -130,8 +130,8 @@ def _build_graph(
desc=None,
)
end_node = EndNode(
- id="end",
- config={"id": "end", "data": end_data.model_dump()},
+ node_id="end",
+ config=end_data,
graph_init_params=params,
graph_runtime_state=runtime_state,
)
diff --git a/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py b/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py
index 2e207ddc67..35e41035df 100644
--- a/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py
+++ b/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py
@@ -123,9 +123,9 @@ class TestStorageKeyLoader(unittest.TestCase):
file_related_id = related_id
return File(
- id=str(uuid4()), # Generate new UUID for File.id
+ file_id=str(uuid4()), # Generate new UUID for File.id
tenant_id=tenant_id,
- type=FileType.DOCUMENT,
+ file_type=FileType.DOCUMENT,
transfer_method=transfer_method,
related_id=file_related_id,
remote_url=remote_url,
diff --git a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py
index aaf9a85d60..54b7afc018 100644
--- a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py
+++ b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py
@@ -271,7 +271,7 @@ def _create_recipient(
def _create_delivery(session: Session, *, form_id: str) -> HumanInputDelivery:
- from core.workflow.human_input_compat import DeliveryMethodType
+ from core.workflow.human_input_adapter import DeliveryMethodType
from models.human_input import ConsoleDeliveryPayload
delivery = HumanInputDelivery(
diff --git a/api/tests/test_containers_integration_tests/services/test_app_service.py b/api/tests/test_containers_integration_tests/services/test_app_service.py
index fa57dd4a6f..b695ae9fd9 100644
--- a/api/tests/test_containers_integration_tests/services/test_app_service.py
+++ b/api/tests/test_containers_integration_tests/services/test_app_service.py
@@ -658,15 +658,17 @@ class TestAppService:
# Update app icon
new_icon = "🌟"
new_icon_background = "#FFD93D"
+ new_icon_type = "image"
mock_current_user = create_autospec(Account, instance=True)
mock_current_user.id = account.id
mock_current_user.current_tenant_id = account.current_tenant_id
with patch("services.app_service.current_user", mock_current_user):
- updated_app = app_service.update_app_icon(app, new_icon, new_icon_background)
+ updated_app = app_service.update_app_icon(app, new_icon, new_icon_background, new_icon_type)
assert updated_app.icon == new_icon
assert updated_app.icon_background == new_icon_background
+ assert str(updated_app.icon_type).lower() == new_icon_type
assert updated_app.updated_by == account.id
# Verify other fields remain unchanged
diff --git a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py
index 18c5320d0a..80f9083e81 100644
--- a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py
+++ b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py
@@ -4,7 +4,7 @@ from unittest.mock import MagicMock
import pytest
-from core.workflow.human_input_compat import (
+from core.workflow.human_input_adapter import (
EmailDeliveryConfig,
EmailDeliveryMethod,
EmailRecipients,
diff --git a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test_service.py b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test_service.py
index 21a54e909e..ed75363f3b 100644
--- a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test_service.py
+++ b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test_service.py
@@ -8,7 +8,7 @@ import pytest
from sqlalchemy.engine import Engine
from configs import dify_config
-from core.workflow.human_input_compat import (
+from core.workflow.human_input_adapter import (
EmailDeliveryConfig,
EmailDeliveryMethod,
EmailRecipients,
diff --git a/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py b/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py
index 328bdbf055..95a867dbb5 100644
--- a/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py
+++ b/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py
@@ -10,7 +10,7 @@ from core.app.app_config.entities import WorkflowUIBasedAppConfig
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext
from core.repositories.human_input_repository import FormCreateParams, HumanInputFormRepositoryImpl
-from core.workflow.human_input_compat import (
+from core.workflow.human_input_adapter import (
EmailDeliveryConfig,
EmailDeliveryMethod,
EmailRecipients,
diff --git a/api/tests/unit_tests/controllers/console/app/test_workflow.py b/api/tests/unit_tests/controllers/console/app/test_workflow.py
index 6ff3b19362..e91c0a0597 100644
--- a/api/tests/unit_tests/controllers/console/app/test_workflow.py
+++ b/api/tests/unit_tests/controllers/console/app/test_workflow.py
@@ -31,7 +31,7 @@ def test_parse_file_with_config(monkeypatch: pytest.MonkeyPatch) -> None:
file_list = [
File(
tenant_id="t1",
- type=FileType.IMAGE,
+ file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="http://u",
)
diff --git a/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py b/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py
index b19a1740eb..22b80b748e 100644
--- a/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py
+++ b/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py
@@ -314,8 +314,8 @@ def test_workflow_file_variable_with_signed_url():
# Create a File object with LOCAL_FILE transfer method (which generates signed URLs)
test_file = File(
- id="test_file_id",
- type=FileType.IMAGE,
+ file_id="test_file_id",
+ file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="test_upload_file_id",
filename="test.jpg",
@@ -370,8 +370,8 @@ def test_workflow_file_variable_remote_url():
# Create a File object with REMOTE_URL transfer method
test_file = File(
- id="test_file_id",
- type=FileType.IMAGE,
+ file_id="test_file_id",
+ file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/test.jpg",
filename="test.jpg",
diff --git a/api/tests/unit_tests/controllers/console/datasets/test_datasets.py b/api/tests/unit_tests/controllers/console/datasets/test_datasets.py
index 94d6c17915..9465936f28 100644
--- a/api/tests/unit_tests/controllers/console/datasets/test_datasets.py
+++ b/api/tests/unit_tests/controllers/console/datasets/test_datasets.py
@@ -1772,6 +1772,21 @@ class TestDatasetApiBaseUrlApi:
assert response["api_base_url"] == "http://localhost:5000/v1"
+ def test_get_api_base_url_no_double_v1(self, app):
+ api = DatasetApiBaseUrlApi()
+ method = unwrap(api.get)
+
+ with (
+ app.test_request_context("/"),
+ patch(
+ "controllers.console.datasets.datasets.dify_config.SERVICE_API_URL",
+ "https://example.com/v1",
+ ),
+ ):
+ response = method(api)
+
+ assert response["api_base_url"] == "https://example.com/v1"
+
class TestDatasetRetrievalSettingApi:
def test_get_success(self, app):
diff --git a/api/tests/unit_tests/controllers/console/test_workspace_account.py b/api/tests/unit_tests/controllers/console/test_workspace_account.py
index c513be950b..26ff264f18 100644
--- a/api/tests/unit_tests/controllers/console/test_workspace_account.py
+++ b/api/tests/unit_tests/controllers/console/test_workspace_account.py
@@ -68,7 +68,10 @@ class TestChangeEmailSend:
mock_features.return_value = SimpleNamespace(enable_change_email=True)
mock_account = _build_account("current@example.com", "acc1")
mock_current_account.return_value = (mock_account, None)
- mock_get_change_data.return_value = {"email": "current@example.com"}
+ mock_get_change_data.return_value = {
+ "email": "current@example.com",
+ AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED,
+ }
mock_send_email.return_value = "token-abc"
with app.test_request_context(
@@ -85,12 +88,55 @@ class TestChangeEmailSend:
email="new@example.com",
old_email="current@example.com",
language="en-US",
- phase="new_email",
+ phase=AccountService.CHANGE_EMAIL_PHASE_NEW,
)
mock_extract_ip.assert_called_once()
mock_is_ip_limit.assert_called_once_with("127.0.0.1")
mock_csrf.assert_called_once()
+ @patch("controllers.console.wraps.db")
+ @patch("controllers.console.workspace.account.current_account_with_tenant")
+ @patch("controllers.console.workspace.account.AccountService.get_change_email_data")
+ @patch("controllers.console.workspace.account.AccountService.send_change_email_email")
+ @patch("controllers.console.workspace.account.AccountService.is_email_send_ip_limit", return_value=False)
+ @patch("controllers.console.workspace.account.extract_remote_ip", return_value="127.0.0.1")
+ @patch("libs.login.check_csrf_token", return_value=None)
+ @patch("controllers.console.wraps.FeatureService.get_system_features")
+ def test_should_reject_new_email_phase_when_token_phase_is_not_old_verified(
+ self,
+ mock_features,
+ mock_csrf,
+ mock_extract_ip,
+ mock_is_ip_limit,
+ mock_send_email,
+ mock_get_change_data,
+ mock_current_account,
+ mock_db,
+ app,
+ ):
+ """GHSA-4q3w-q5mc-45rq: a phase-1 token must not unlock the new-email send step."""
+ from controllers.console.auth.error import InvalidTokenError
+
+ _mock_wraps_db(mock_db)
+ mock_features.return_value = SimpleNamespace(enable_change_email=True)
+ mock_account = _build_account("current@example.com", "acc1")
+ mock_current_account.return_value = (mock_account, None)
+ mock_get_change_data.return_value = {
+ "email": "current@example.com",
+ AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_OLD,
+ }
+
+ with app.test_request_context(
+ "/account/change-email",
+ method="POST",
+ json={"email": "New@Example.com", "language": "en-US", "phase": "new_email", "token": "token-123"},
+ ):
+ _set_logged_in_user(_build_account("tester@example.com", "tester"))
+ with pytest.raises(InvalidTokenError):
+ ChangeEmailSendEmailApi().post()
+
+ mock_send_email.assert_not_called()
+
class TestChangeEmailValidity:
@patch("controllers.console.wraps.db")
@@ -122,7 +168,12 @@ class TestChangeEmailValidity:
mock_account = _build_account("user@example.com", "acc2")
mock_current_account.return_value = (mock_account, None)
mock_is_rate_limit.return_value = False
- mock_get_data.return_value = {"email": "user@example.com", "code": "1234", "old_email": "old@example.com"}
+ mock_get_data.return_value = {
+ "email": "user@example.com",
+ "code": "1234",
+ "old_email": "old@example.com",
+ AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_OLD,
+ }
mock_generate_token.return_value = (None, "new-token")
with app.test_request_context(
@@ -138,11 +189,169 @@ class TestChangeEmailValidity:
mock_add_rate.assert_not_called()
mock_revoke_token.assert_called_once_with("token-123")
mock_generate_token.assert_called_once_with(
- "user@example.com", code="1234", old_email="old@example.com", additional_data={}
+ "user@example.com",
+ code="1234",
+ old_email="old@example.com",
+ additional_data={
+ AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED,
+ },
)
mock_reset_rate.assert_called_once_with("user@example.com")
mock_csrf.assert_called_once()
+ @patch("controllers.console.wraps.db")
+ @patch("controllers.console.workspace.account.current_account_with_tenant")
+ @patch("controllers.console.workspace.account.AccountService.reset_change_email_error_rate_limit")
+ @patch("controllers.console.workspace.account.AccountService.generate_change_email_token")
+ @patch("controllers.console.workspace.account.AccountService.revoke_change_email_token")
+ @patch("controllers.console.workspace.account.AccountService.add_change_email_error_rate_limit")
+ @patch("controllers.console.workspace.account.AccountService.get_change_email_data")
+ @patch("controllers.console.workspace.account.AccountService.is_change_email_error_rate_limit")
+ @patch("libs.login.check_csrf_token", return_value=None)
+ @patch("controllers.console.wraps.FeatureService.get_system_features")
+ def test_should_upgrade_new_phase_token_to_new_verified(
+ self,
+ mock_features,
+ mock_csrf,
+ mock_is_rate_limit,
+ mock_get_data,
+ mock_add_rate,
+ mock_revoke_token,
+ mock_generate_token,
+ mock_reset_rate,
+ mock_current_account,
+ mock_db,
+ app,
+ ):
+ _mock_wraps_db(mock_db)
+ mock_features.return_value = SimpleNamespace(enable_change_email=True)
+ mock_current_account.return_value = (_build_account("old@example.com", "acc"), None)
+ mock_is_rate_limit.return_value = False
+ mock_get_data.return_value = {
+ "email": "new@example.com",
+ "code": "1234",
+ "old_email": "old@example.com",
+ AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_NEW,
+ }
+ mock_generate_token.return_value = (None, "new-verified-token")
+
+ with app.test_request_context(
+ "/account/change-email/validity",
+ method="POST",
+ json={"email": "new@example.com", "code": "1234", "token": "token-123"},
+ ):
+ _set_logged_in_user(_build_account("tester@example.com", "tester"))
+ response = ChangeEmailCheckApi().post()
+
+ assert response == {"is_valid": True, "email": "new@example.com", "token": "new-verified-token"}
+ mock_generate_token.assert_called_once_with(
+ "new@example.com",
+ code="1234",
+ old_email="old@example.com",
+ additional_data={
+ AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED,
+ },
+ )
+
+ @patch("controllers.console.wraps.db")
+ @patch("controllers.console.workspace.account.current_account_with_tenant")
+ @patch("controllers.console.workspace.account.AccountService.reset_change_email_error_rate_limit")
+ @patch("controllers.console.workspace.account.AccountService.generate_change_email_token")
+ @patch("controllers.console.workspace.account.AccountService.revoke_change_email_token")
+ @patch("controllers.console.workspace.account.AccountService.add_change_email_error_rate_limit")
+ @patch("controllers.console.workspace.account.AccountService.get_change_email_data")
+ @patch("controllers.console.workspace.account.AccountService.is_change_email_error_rate_limit")
+ @patch("libs.login.check_csrf_token", return_value=None)
+ @patch("controllers.console.wraps.FeatureService.get_system_features")
+ def test_should_reject_validity_when_token_phase_is_unknown(
+ self,
+ mock_features,
+ mock_csrf,
+ mock_is_rate_limit,
+ mock_get_data,
+ mock_add_rate,
+ mock_revoke_token,
+ mock_generate_token,
+ mock_reset_rate,
+ mock_current_account,
+ mock_db,
+ app,
+ ):
+ """A token whose phase marker is a string but not a known transition must be rejected."""
+ from controllers.console.auth.error import InvalidTokenError
+
+ _mock_wraps_db(mock_db)
+ mock_features.return_value = SimpleNamespace(enable_change_email=True)
+ mock_current_account.return_value = (_build_account("old@example.com", "acc"), None)
+ mock_is_rate_limit.return_value = False
+ mock_get_data.return_value = {
+ "email": "user@example.com",
+ "code": "1234",
+ "old_email": "old@example.com",
+ AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: "something_else",
+ }
+
+ with app.test_request_context(
+ "/account/change-email/validity",
+ method="POST",
+ json={"email": "user@example.com", "code": "1234", "token": "token-123"},
+ ):
+ _set_logged_in_user(_build_account("tester@example.com", "tester"))
+ with pytest.raises(InvalidTokenError):
+ ChangeEmailCheckApi().post()
+
+ mock_revoke_token.assert_not_called()
+ mock_generate_token.assert_not_called()
+
+ @patch("controllers.console.wraps.db")
+ @patch("controllers.console.workspace.account.current_account_with_tenant")
+ @patch("controllers.console.workspace.account.AccountService.reset_change_email_error_rate_limit")
+ @patch("controllers.console.workspace.account.AccountService.generate_change_email_token")
+ @patch("controllers.console.workspace.account.AccountService.revoke_change_email_token")
+ @patch("controllers.console.workspace.account.AccountService.add_change_email_error_rate_limit")
+ @patch("controllers.console.workspace.account.AccountService.get_change_email_data")
+ @patch("controllers.console.workspace.account.AccountService.is_change_email_error_rate_limit")
+ @patch("libs.login.check_csrf_token", return_value=None)
+ @patch("controllers.console.wraps.FeatureService.get_system_features")
+ def test_should_reject_validity_when_token_has_no_phase(
+ self,
+ mock_features,
+ mock_csrf,
+ mock_is_rate_limit,
+ mock_get_data,
+ mock_add_rate,
+ mock_revoke_token,
+ mock_generate_token,
+ mock_reset_rate,
+ mock_current_account,
+ mock_db,
+ app,
+ ):
+ """A token minted without a phase marker (e.g. a hand-crafted token) must not validate."""
+ from controllers.console.auth.error import InvalidTokenError
+
+ _mock_wraps_db(mock_db)
+ mock_features.return_value = SimpleNamespace(enable_change_email=True)
+ mock_current_account.return_value = (_build_account("old@example.com", "acc"), None)
+ mock_is_rate_limit.return_value = False
+ mock_get_data.return_value = {
+ "email": "user@example.com",
+ "code": "1234",
+ "old_email": "old@example.com",
+ }
+
+ with app.test_request_context(
+ "/account/change-email/validity",
+ method="POST",
+ json={"email": "user@example.com", "code": "1234", "token": "token-123"},
+ ):
+ _set_logged_in_user(_build_account("tester@example.com", "tester"))
+ with pytest.raises(InvalidTokenError):
+ ChangeEmailCheckApi().post()
+
+ mock_revoke_token.assert_not_called()
+ mock_generate_token.assert_not_called()
+
class TestChangeEmailReset:
@patch("controllers.console.wraps.db")
@@ -175,7 +384,11 @@ class TestChangeEmailReset:
mock_current_account.return_value = (current_user, None)
mock_is_freeze.return_value = False
mock_check_unique.return_value = True
- mock_get_data.return_value = {"old_email": "OLD@example.com"}
+ mock_get_data.return_value = {
+ "email": "new@example.com",
+ "old_email": "OLD@example.com",
+ AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED,
+ }
mock_account_after_update = _build_account("new@example.com", "acc3-updated")
mock_update_account.return_value = mock_account_after_update
@@ -194,6 +407,155 @@ class TestChangeEmailReset:
mock_send_notify.assert_called_once_with(email="new@example.com")
mock_csrf.assert_called_once()
+ @patch("controllers.console.wraps.db")
+ @patch("controllers.console.workspace.account.current_account_with_tenant")
+ @patch("controllers.console.workspace.account.AccountService.send_change_email_completed_notify_email")
+ @patch("controllers.console.workspace.account.AccountService.update_account_email")
+ @patch("controllers.console.workspace.account.AccountService.revoke_change_email_token")
+ @patch("controllers.console.workspace.account.AccountService.get_change_email_data")
+ @patch("controllers.console.workspace.account.AccountService.check_email_unique")
+ @patch("controllers.console.workspace.account.AccountService.is_account_in_freeze")
+ @patch("libs.login.check_csrf_token", return_value=None)
+ @patch("controllers.console.wraps.FeatureService.get_system_features")
+ def test_should_reject_reset_when_token_phase_is_not_new_verified(
+ self,
+ mock_features,
+ mock_csrf,
+ mock_is_freeze,
+ mock_check_unique,
+ mock_get_data,
+ mock_revoke_token,
+ mock_update_account,
+ mock_send_notify,
+ mock_current_account,
+ mock_db,
+ app,
+ ):
+ """GHSA-4q3w-q5mc-45rq PoC: phase-1 token must not be usable against /reset."""
+ from controllers.console.auth.error import InvalidTokenError
+
+ _mock_wraps_db(mock_db)
+ mock_features.return_value = SimpleNamespace(enable_change_email=True)
+ current_user = _build_account("old@example.com", "acc3")
+ mock_current_account.return_value = (current_user, None)
+ mock_is_freeze.return_value = False
+ mock_check_unique.return_value = True
+ # Simulate a token straight out of step #1 (phase=old_email) — exactly
+ # the replay used in the advisory PoC.
+ mock_get_data.return_value = {
+ "email": "old@example.com",
+ "old_email": "old@example.com",
+ AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_OLD,
+ }
+
+ with app.test_request_context(
+ "/account/change-email/reset",
+ method="POST",
+ json={"new_email": "attacker@example.com", "token": "token-from-step1"},
+ ):
+ _set_logged_in_user(_build_account("tester@example.com", "tester"))
+ with pytest.raises(InvalidTokenError):
+ ChangeEmailResetApi().post()
+
+ mock_revoke_token.assert_not_called()
+ mock_update_account.assert_not_called()
+ mock_send_notify.assert_not_called()
+
+ @patch("controllers.console.wraps.db")
+ @patch("controllers.console.workspace.account.current_account_with_tenant")
+ @patch("controllers.console.workspace.account.AccountService.send_change_email_completed_notify_email")
+ @patch("controllers.console.workspace.account.AccountService.update_account_email")
+ @patch("controllers.console.workspace.account.AccountService.revoke_change_email_token")
+ @patch("controllers.console.workspace.account.AccountService.get_change_email_data")
+ @patch("controllers.console.workspace.account.AccountService.check_email_unique")
+ @patch("controllers.console.workspace.account.AccountService.is_account_in_freeze")
+ @patch("libs.login.check_csrf_token", return_value=None)
+ @patch("controllers.console.wraps.FeatureService.get_system_features")
+ def test_should_reject_reset_when_token_email_differs_from_payload_new_email(
+ self,
+ mock_features,
+ mock_csrf,
+ mock_is_freeze,
+ mock_check_unique,
+ mock_get_data,
+ mock_revoke_token,
+ mock_update_account,
+ mock_send_notify,
+ mock_current_account,
+ mock_db,
+ app,
+ ):
+ """A verified token for address A must not be replayed to change to address B."""
+ from controllers.console.auth.error import InvalidTokenError
+
+ _mock_wraps_db(mock_db)
+ mock_features.return_value = SimpleNamespace(enable_change_email=True)
+ current_user = _build_account("old@example.com", "acc3")
+ mock_current_account.return_value = (current_user, None)
+ mock_is_freeze.return_value = False
+ mock_check_unique.return_value = True
+ mock_get_data.return_value = {
+ "email": "verified@example.com",
+ "old_email": "old@example.com",
+ AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED,
+ }
+
+ with app.test_request_context(
+ "/account/change-email/reset",
+ method="POST",
+ json={"new_email": "attacker@example.com", "token": "token-verified"},
+ ):
+ _set_logged_in_user(_build_account("tester@example.com", "tester"))
+ with pytest.raises(InvalidTokenError):
+ ChangeEmailResetApi().post()
+
+ mock_revoke_token.assert_not_called()
+ mock_update_account.assert_not_called()
+ mock_send_notify.assert_not_called()
+
+
+class TestAccountServiceSendChangeEmailEmail:
+ """Service-level coverage for the phase-bound changes in `send_change_email_email`."""
+
+ def test_should_raise_value_error_for_invalid_phase(self):
+ with pytest.raises(ValueError, match="phase must be one of"):
+ AccountService.send_change_email_email(
+ email="user@example.com",
+ old_email="user@example.com",
+ phase="old_email_verified",
+ )
+
+ @patch("services.account_service.send_change_mail_task")
+ @patch("services.account_service.AccountService.change_email_rate_limiter")
+ @patch("services.account_service.AccountService.generate_change_email_token")
+ def test_should_stamp_phase_into_generated_token(
+ self,
+ mock_generate_token,
+ mock_rate_limiter,
+ mock_mail_task,
+ ):
+ mock_rate_limiter.is_rate_limited.return_value = False
+ mock_generate_token.return_value = ("123456", "the-token")
+
+ returned = AccountService.send_change_email_email(
+ email="user@example.com",
+ old_email="user@example.com",
+ language="en-US",
+ phase=AccountService.CHANGE_EMAIL_PHASE_NEW,
+ )
+
+ assert returned == "the-token"
+ mock_generate_token.assert_called_once_with(
+ "user@example.com",
+ None,
+ old_email="user@example.com",
+ additional_data={
+ AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_NEW,
+ },
+ )
+ mock_mail_task.delay.assert_called_once()
+ mock_rate_limiter.increment_rate_limit.assert_called_once_with("user@example.com")
+
class TestAccountDeletionFeedback:
@patch("controllers.console.wraps.db")
diff --git a/api/tests/unit_tests/controllers/service_api/app/test_conversation.py b/api/tests/unit_tests/controllers/service_api/app/test_conversation.py
index 14c35a9ed5..4fb8ecf784 100644
--- a/api/tests/unit_tests/controllers/service_api/app/test_conversation.py
+++ b/api/tests/unit_tests/controllers/service_api/app/test_conversation.py
@@ -37,6 +37,8 @@ from controllers.service_api.app.conversation import (
ConversationVariableUpdatePayload,
)
from controllers.service_api.app.error import NotChatAppError
+from fields._value_type_serializer import serialize_value_type
+from graphon.variables import StringSegment
from graphon.variables.types import SegmentType
from models.model import App, AppMode, EndUser
from services.conversation_service import ConversationService
@@ -284,6 +286,32 @@ class TestConversationVariableResponseModels:
assert response.created_at == int(created_at.timestamp())
assert response.updated_at == int(created_at.timestamp())
+ def test_variable_response_normalizes_string_value_type_alias(self):
+ response = ConversationVariableResponse.model_validate(
+ {
+ "id": "550e8400-e29b-41d4-a716-446655440000",
+ "name": "foo",
+ "value_type": SegmentType.INTEGER.value,
+ }
+ )
+
+ assert response.value_type == "number"
+
+ def test_variable_response_normalizes_callable_exposed_type(self):
+ response = ConversationVariableResponse.model_validate(
+ {
+ "id": "550e8400-e29b-41d4-a716-446655440000",
+ "name": "foo",
+ "value_type": SimpleNamespace(exposed_type=lambda: SegmentType.STRING.exposed_type()),
+ }
+ )
+
+ assert response.value_type == "string"
+
+ def test_serialize_value_type_supports_segments_and_mappings(self):
+ assert serialize_value_type(StringSegment(value="hello")) == "string"
+ assert serialize_value_type({"value_type": SegmentType.INTEGER}) == "number"
+
def test_variable_pagination_response(self):
response = ConversationVariableInfiniteScrollPaginationResponse.model_validate(
{
diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py
index 3ab63aed25..dd6cd0e919 100644
--- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py
+++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py
@@ -11,8 +11,8 @@ class TestWorkflowResponseConverterFetchFilesFromVariableValue:
def create_test_file(self, file_id: str = "test_file_1") -> File:
"""Create a test File object"""
return File(
- id=file_id,
- type=FileType.DOCUMENT,
+ file_id=file_id,
+ file_type=FileType.DOCUMENT,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="related_123",
filename=f"{file_id}.txt",
diff --git a/api/tests/unit_tests/core/app/apps/test_pause_resume.py b/api/tests/unit_tests/core/app/apps/test_pause_resume.py
index a04a7b7576..6104b8d6ca 100644
--- a/api/tests/unit_tests/core/app/apps/test_pause_resume.py
+++ b/api/tests/unit_tests/core/app/apps/test_pause_resume.py
@@ -7,11 +7,11 @@ import graphon.nodes.human_input.entities # noqa: F401
from core.app.apps.advanced_chat import app_generator as adv_app_gen_module
from core.app.apps.workflow import app_generator as wf_app_gen_module
from core.app.entities.app_invoke_entities import InvokeFrom
+from core.workflow import node_factory as node_factory_module
from core.workflow.node_factory import DifyNodeFactory
from core.workflow.system_variables import build_system_variables
from graphon.entities import WorkflowStartReason
from graphon.entities.base_node_data import BaseNodeData, RetryConfig
-from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
from graphon.entities.pause_reason import SchedulingPause
from graphon.enums import BuiltinNodeTypes, NodeType, WorkflowNodeExecutionStatus
from graphon.graph import Graph
@@ -55,8 +55,21 @@ class _StubToolNode(Node[_StubToolNodeData]):
def version(cls) -> str:
return "1"
- def init_node_data(self, data):
- self._node_data = _StubToolNodeData.model_validate(data)
+ def __init__(
+ self,
+ node_id: str,
+ config: _StubToolNodeData,
+ *,
+ graph_init_params,
+ graph_runtime_state,
+ **_kwargs: Any,
+ ) -> None:
+ super().__init__(
+ node_id=node_id,
+ config=config,
+ graph_init_params=graph_init_params,
+ graph_runtime_state=graph_runtime_state,
+ )
def _get_error_strategy(self):
return self._node_data.error_strategy
@@ -89,21 +102,14 @@ class _StubToolNode(Node[_StubToolNodeData]):
def _patch_tool_node(mocker):
- original_create_node = DifyNodeFactory.create_node
+ original_resolve_node_class = node_factory_module.resolve_workflow_node_class
- def _patched_create_node(self, node_config: dict[str, object] | NodeConfigDict) -> Node:
- typed_node_config = NodeConfigDictAdapter.validate_python(node_config)
- node_data = typed_node_config["data"]
- if node_data.type == BuiltinNodeTypes.TOOL:
- return _StubToolNode(
- id=str(typed_node_config["id"]),
- config=typed_node_config,
- graph_init_params=self.graph_init_params,
- graph_runtime_state=self.graph_runtime_state,
- )
- return original_create_node(self, typed_node_config)
+ def _patched_resolve_node_class(*, node_type: NodeType, node_version: str) -> type[Node]:
+ if node_type == BuiltinNodeTypes.TOOL:
+ return _StubToolNode
+ return original_resolve_node_class(node_type=node_type, node_version=node_version)
- mocker.patch.object(DifyNodeFactory, "create_node", _patched_create_node)
+ mocker.patch.object(node_factory_module, "resolve_workflow_node_class", side_effect=_patched_resolve_node_class)
def _node_data(node_type: NodeType, data: BaseNodeData) -> dict[str, object]:
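
The fixture above no longer overrides `DifyNodeFactory.create_node`; it hooks class resolution instead, so the factory's own config adaptation still runs for the stubbed node. A condensed sketch of the pattern (assumes pytest-mock's `mocker` fixture):

from core.workflow import node_factory as node_factory_module
from graphon.enums import BuiltinNodeTypes

def _patch_node_class(mocker, stub_class):
    original = node_factory_module.resolve_workflow_node_class

    def _patched(*, node_type, node_version):
        # Substitute only the tool node; every other type resolves normally.
        if node_type == BuiltinNodeTypes.TOOL:
            return stub_class
        return original(node_type=node_type, node_version=node_version)

    # Using side_effect keeps the patched attribute a MagicMock, so call
    # assertions remain available while behavior delegates to _patched.
    mocker.patch.object(node_factory_module, "resolve_workflow_node_class", side_effect=_patched)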
diff --git a/api/tests/unit_tests/core/app/workflow/test_file_runtime.py b/api/tests/unit_tests/core/app/workflow/test_file_runtime.py
index cddd03f4b0..701863b927 100644
--- a/api/tests/unit_tests/core/app/workflow/test_file_runtime.py
+++ b/api/tests/unit_tests/core/app/workflow/test_file_runtime.py
@@ -26,8 +26,8 @@ def _build_file(
extension: str | None = None,
) -> File:
return File(
- id="file-id",
- type=FileType.IMAGE,
+ file_id="file-id",
+ file_type=FileType.IMAGE,
transfer_method=transfer_method,
reference=reference,
remote_url=remote_url,
@@ -351,7 +351,7 @@ def test_runtime_helper_wrappers_delegate_to_config_and_io(monkeypatch: pytest.M
assert runtime.multimodal_send_format == "url"
- with patch.object(file_runtime.ssrf_proxy, "get", return_value="response") as mock_get:
+ with patch.object(file_runtime.graphon_ssrf_proxy, "get", return_value="response") as mock_get:
assert runtime.http_get("http://example", follow_redirects=False) == "response"
mock_get.assert_called_once_with("http://example", follow_redirects=False)
diff --git a/api/tests/unit_tests/core/app/workflow/test_node_factory.py b/api/tests/unit_tests/core/app/workflow/test_node_factory.py
index c4bfb23272..30a068f4c5 100644
--- a/api/tests/unit_tests/core/app/workflow/test_node_factory.py
+++ b/api/tests/unit_tests/core/app/workflow/test_node_factory.py
@@ -8,8 +8,8 @@ from graphon.enums import BuiltinNodeTypes
class DummyNode:
- def __init__(self, *, id, config, graph_init_params, graph_runtime_state, **kwargs):
- self.id = id
+ def __init__(self, *, node_id, config, graph_init_params, graph_runtime_state, **kwargs):
+ self.id = node_id
self.config = config
self.graph_init_params = graph_init_params
self.graph_runtime_state = graph_runtime_state
diff --git a/api/tests/unit_tests/core/datasource/test_datasource_manager.py b/api/tests/unit_tests/core/datasource/test_datasource_manager.py
index 81315d2508..deeac49bbc 100644
--- a/api/tests/unit_tests/core/datasource/test_datasource_manager.py
+++ b/api/tests/unit_tests/core/datasource/test_datasource_manager.py
@@ -430,7 +430,7 @@ def test_stream_node_events_builds_file_and_variables_from_messages(mocker):
mocker.patch("core.datasource.datasource_manager.session_factory.create_session", return_value=_Session())
mocker.patch("core.datasource.datasource_manager.get_file_type_by_mime_type", return_value=FileType.IMAGE)
built = File(
- type=FileType.IMAGE,
+ file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.TOOL_FILE,
related_id="tool_file_1",
extension=".png",
@@ -530,7 +530,7 @@ def test_stream_node_events_online_drive_sets_variable_pool_file_and_outputs(moc
mocker.patch.object(DatasourceManager, "stream_online_results", return_value=_gen_messages_text_only("ignored"))
file_in = File(
- type=FileType.DOCUMENT,
+ file_type=FileType.DOCUMENT,
transfer_method=FileTransferMethod.TOOL_FILE,
related_id="tf",
extension=".pdf",
diff --git a/api/tests/unit_tests/core/entities/test_entities_model_entities.py b/api/tests/unit_tests/core/entities/test_entities_model_entities.py
index a0b2820157..aeca2e3afd 100644
--- a/api/tests/unit_tests/core/entities/test_entities_model_entities.py
+++ b/api/tests/unit_tests/core/entities/test_entities_model_entities.py
@@ -46,7 +46,7 @@ def test_simple_model_provider_entity_maps_from_provider_entity() -> None:
# Assert
assert simple_provider.provider == "openai"
- assert simple_provider.label.en_US == "OpenAI"
+ assert simple_provider.label.en_us == "OpenAI"
assert simple_provider.supported_model_types == [ModelType.LLM]
diff --git a/api/tests/unit_tests/core/file/test_models.py b/api/tests/unit_tests/core/file/test_models.py
index bb6e40e224..8cb0938575 100644
--- a/api/tests/unit_tests/core/file/test_models.py
+++ b/api/tests/unit_tests/core/file/test_models.py
@@ -3,9 +3,9 @@ from graphon.file import File, FileTransferMethod, FileType
def test_file():
file = File(
- id="test-file",
+ file_id="test-file",
tenant_id="test-tenant-id",
- type=FileType.IMAGE,
+ file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.TOOL_FILE,
related_id="test-related-id",
filename="image.png",
@@ -25,27 +25,21 @@ def test_file():
assert file.size == 67
-def test_file_model_validate_accepts_legacy_tenant_id():
- data = {
- "id": "test-file",
- "tenant_id": "test-tenant-id",
- "type": "image",
- "transfer_method": "tool_file",
- "related_id": "test-related-id",
- "filename": "image.png",
- "extension": ".png",
- "mime_type": "image/png",
- "size": 67,
- "storage_key": "test-storage-key",
- "url": "https://example.com/image.png",
- # Extra legacy fields
- "tool_file_id": "tool-file-123",
- "upload_file_id": "upload-file-456",
- "datasource_file_id": "datasource-file-789",
- }
+def test_file_constructor_accepts_legacy_tenant_id():
+ file = File(
+ file_id="test-file",
+ tenant_id="test-tenant-id",
+ file_type=FileType.IMAGE,
+ transfer_method=FileTransferMethod.TOOL_FILE,
+ tool_file_id="tool-file-123",
+ filename="image.png",
+ extension=".png",
+ mime_type="image/png",
+ size=67,
+ storage_key="test-storage-key",
+ url="https://example.com/image.png",
+ )
- file = File.model_validate(data)
-
- assert file.related_id == "test-related-id"
+ assert file.related_id == "tool-file-123"
assert file.storage_key == "test-storage-key"
assert "tenant_id" not in file.model_dump()
diff --git a/api/tests/unit_tests/core/helper/test_ssrf_proxy.py b/api/tests/unit_tests/core/helper/test_ssrf_proxy.py
index 3b5c5e6597..d9fed9ae2a 100644
--- a/api/tests/unit_tests/core/helper/test_ssrf_proxy.py
+++ b/api/tests/unit_tests/core/helper/test_ssrf_proxy.py
@@ -1,11 +1,17 @@
from unittest.mock import MagicMock, patch
+import httpx
import pytest
from core.helper.ssrf_proxy import (
SSRF_DEFAULT_MAX_RETRIES,
+ SSRFProxy,
_get_user_provided_host_header,
+ _to_graphon_http_response,
+ graphon_ssrf_proxy,
make_request,
+ max_retries_exceeded_error,
+ request_error,
)
@@ -174,3 +180,56 @@ class TestFollowRedirectsParameter:
call_kwargs = mock_client.request.call_args.kwargs
assert call_kwargs.get("follow_redirects") is True
+
+
+def test_to_graphon_http_response_preserves_httpx_response_fields() -> None:
+ response = httpx.Response(
+ 201,
+ headers={"X-Test": "1"},
+ content=b"payload",
+ request=httpx.Request("GET", "https://example.com/resource"),
+ )
+
+ wrapped = _to_graphon_http_response(response)
+
+ assert wrapped.status_code == 201
+ assert wrapped.headers == {"x-test": "1", "content-length": "7"}
+ assert wrapped.content == b"payload"
+ assert wrapped.url == "https://example.com/resource"
+ assert wrapped.reason_phrase == "Created"
+ assert wrapped.text == "payload"
+
+
+def test_ssrf_proxy_exposes_expected_error_types() -> None:
+ proxy = SSRFProxy()
+
+ assert proxy.max_retries_exceeded_error is max_retries_exceeded_error
+ assert proxy.request_error is request_error
+ assert graphon_ssrf_proxy.max_retries_exceeded_error is max_retries_exceeded_error
+ assert graphon_ssrf_proxy.request_error is request_error
+
+
+@pytest.mark.parametrize("method_name", ["get", "head", "post", "put", "delete", "patch"])
+def test_graphon_ssrf_proxy_wraps_module_requests(method_name: str) -> None:
+ response = httpx.Response(
+ 200,
+ headers={"X-Test": "1"},
+ content=b"ok",
+ request=httpx.Request("GET", "https://example.com/resource"),
+ )
+
+ with patch(f"core.helper.ssrf_proxy.{method_name}", return_value=response) as mock_method:
+ wrapped = getattr(graphon_ssrf_proxy, method_name)(
+ "https://example.com/resource",
+ max_retries=3,
+ headers={"X-Test": "1"},
+ )
+
+ mock_method.assert_called_once_with(
+ url="https://example.com/resource",
+ max_retries=3,
+ headers={"X-Test": "1"},
+ )
+ assert wrapped.status_code == 200
+ assert wrapped.url == "https://example.com/resource"
+ assert wrapped.content == b"ok"
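
These additions pin down the `graphon_ssrf_proxy` facade: each HTTP helper forwards to the module-level function with a keyword `url`, responses are normalized through `_to_graphon_http_response`, and the retry/transport error types are re-exported on the proxy. A caller-side sketch, assuming `max_retries_exceeded_error` and `request_error` are the exception classes the identity asserts above suggest:

from core.helper.ssrf_proxy import graphon_ssrf_proxy

def fetch(url: str):
    try:
        # Returns the wrapped response object, not a raw httpx.Response.
        return graphon_ssrf_proxy.get(url, max_retries=3)
    except graphon_ssrf_proxy.max_retries_exceeded_error:
        return None  # retries exhausted
    except graphon_ssrf_proxy.request_error:
        return None  # transport-level failure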
diff --git a/api/tests/unit_tests/core/mcp/client/test_streamable_http.py b/api/tests/unit_tests/core/mcp/client/test_streamable_http.py
index 81f8da9a62..bbbffa2e69 100644
--- a/api/tests/unit_tests/core/mcp/client/test_streamable_http.py
+++ b/api/tests/unit_tests/core/mcp/client/test_streamable_http.py
@@ -971,6 +971,23 @@ class TestHandlePostRequestNew:
assert isinstance(item, SessionMessage)
assert isinstance(item.message.root, JSONRPCError)
assert item.message.root.id == 77
+ assert item.message.root.error.message == "Session terminated by server"
+
+ def test_404_on_initialization_includes_url_in_error(self):
+ t = _new_transport(url="http://example.com/mcp/server/abc123/mcp")
+ q: queue.Queue = queue.Queue()
+ msg = _make_request_msg("initialize", 1)
+ ctx = self._make_ctx(t, q, message=msg)
+ mock_resp = MagicMock()
+ mock_resp.status_code = 404
+ ctx.client.stream = self._stream_ctx(mock_resp)
+ t._handle_post_request(ctx)
+ item = q.get_nowait()
+ assert isinstance(item, SessionMessage)
+ assert isinstance(item.message.root, JSONRPCError)
+ assert item.message.root.error.code == 32600
+ assert "404 Not Found" in item.message.root.error.message
+ assert "http://example.com/mcp/server/abc123/mcp" in item.message.root.error.message
def test_404_for_notification_no_error_sent(self):
t = _new_transport()
diff --git a/api/tests/unit_tests/core/model_runtime/test_model_provider_factory.py b/api/tests/unit_tests/core/model_runtime/test_model_provider_factory.py
index 249ecb5006..c4fd970562 100644
--- a/api/tests/unit_tests/core/model_runtime/test_model_provider_factory.py
+++ b/api/tests/unit_tests/core/model_runtime/test_model_provider_factory.py
@@ -13,12 +13,12 @@ from graphon.model_runtime.entities.provider_entities import (
ProviderCredentialSchema,
ProviderEntity,
)
-from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
-from graphon.model_runtime.model_providers.__base.moderation_model import ModerationModel
-from graphon.model_runtime.model_providers.__base.rerank_model import RerankModel
-from graphon.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
-from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
-from graphon.model_runtime.model_providers.__base.tts_model import TTSModel
+from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
+from graphon.model_runtime.model_providers.base.moderation_model import ModerationModel
+from graphon.model_runtime.model_providers.base.rerank_model import RerankModel
+from graphon.model_runtime.model_providers.base.speech2text_model import Speech2TextModel
+from graphon.model_runtime.model_providers.base.text_embedding_model import TextEmbeddingModel
+from graphon.model_runtime.model_providers.base.tts_model import TTSModel
from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
diff --git a/api/tests/unit_tests/core/plugin/test_model_runtime_adapter.py b/api/tests/unit_tests/core/plugin/test_model_runtime_adapter.py
index 68aa130518..88bf555594 100644
--- a/api/tests/unit_tests/core/plugin/test_model_runtime_adapter.py
+++ b/api/tests/unit_tests/core/plugin/test_model_runtime_adapter.py
@@ -56,7 +56,7 @@ class TestPluginModelRuntime:
assert len(providers) == 1
assert providers[0].provider == "langgenius/openai/openai"
assert providers[0].provider_name == "openai"
- assert providers[0].label.en_US == "OpenAI"
+ assert providers[0].label.en_us == "OpenAI"
client.fetch_model_providers.assert_called_once_with("tenant")
def test_fetch_model_providers_only_exposes_short_name_for_canonical_provider(self) -> None:
diff --git a/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py b/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py
index d49b6e4b71..00a4207786 100644
--- a/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py
+++ b/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py
@@ -466,7 +466,7 @@ class TestConverter:
def test_convert_parameters_to_plugin_format_with_single_file_and_selector(self):
file_param = File(
tenant_id="tenant-1",
- type=FileType.IMAGE,
+ file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/file.png",
storage_key="",
@@ -499,14 +499,14 @@ class TestConverter:
def test_convert_parameters_to_plugin_format_with_lists_and_passthrough_values(self):
file_one = File(
tenant_id="tenant-1",
- type=FileType.DOCUMENT,
+ file_type=FileType.DOCUMENT,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/a.txt",
storage_key="",
)
file_two = File(
tenant_id="tenant-1",
- type=FileType.DOCUMENT,
+ file_type=FileType.DOCUMENT,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/b.txt",
storage_key="",
diff --git a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py
index 395d392127..e536c0831f 100644
--- a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py
+++ b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py
@@ -134,9 +134,9 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg
files = [
File(
- id="file1",
+ file_id="file1",
tenant_id="tenant1",
- type=FileType.IMAGE,
+ file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/image1.jpg",
storage_key="",
@@ -245,9 +245,9 @@ def test_completion_prompt_jinja2_with_files():
completion_template = CompletionModelPromptTemplate(text="Hi {{name}}", edition_type="jinja2")
file = File(
- id="file1",
+ file_id="file1",
tenant_id="tenant1",
- type=FileType.IMAGE,
+ file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/image.jpg",
storage_key="",
@@ -379,9 +379,9 @@ def test_chat_prompt_memory_with_files_and_query():
memory = MagicMock(spec=TokenBufferMemory)
prompt_template = [ChatModelMessage(text="sys", role=PromptMessageRole.SYSTEM)]
file = File(
- id="file1",
+ file_id="file1",
tenant_id="tenant1",
- type=FileType.IMAGE,
+ file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/image.jpg",
storage_key="",
@@ -413,9 +413,9 @@ def test_chat_prompt_files_without_query_updates_last_user_or_appends_new():
transform = AdvancedPromptTransform()
model_config_mock = MagicMock(spec=ModelConfigEntity)
file = File(
- id="file1",
+ file_id="file1",
tenant_id="tenant1",
- type=FileType.IMAGE,
+ file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/image.jpg",
storage_key="",
@@ -463,9 +463,9 @@ def test_chat_prompt_files_with_query_branch():
transform = AdvancedPromptTransform()
model_config_mock = MagicMock(spec=ModelConfigEntity)
file = File(
- id="file1",
+ file_id="file1",
tenant_id="tenant1",
- type=FileType.IMAGE,
+ file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/image.jpg",
storage_key="",
diff --git a/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py
index 803afa54d7..28966242d8 100644
--- a/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py
+++ b/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py
@@ -12,7 +12,7 @@ from graphon.model_runtime.entities.message_entities import (
ToolPromptMessage,
UserPromptMessage,
)
-from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
from models.model import Conversation
diff --git a/api/tests/unit_tests/core/prompt/test_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_prompt_transform.py
index 9f9ea33695..5308c8e7b3 100644
--- a/api/tests/unit_tests/core/prompt/test_prompt_transform.py
+++ b/api/tests/unit_tests/core/prompt/test_prompt_transform.py
@@ -11,7 +11,7 @@ from graphon.model_runtime.entities.model_entities import ModelPropertyKey
# from graphon.model_runtime.entities.message_entities import UserPromptMessage
# from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey, ParameterRule
# from graphon.model_runtime.entities.provider_entities import ProviderEntity
-# from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+# from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
# from core.prompt.prompt_transform import PromptTransform
diff --git a/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py b/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py
index 64eb89590a..0220fb6d4a 100644
--- a/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py
+++ b/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py
@@ -1,12 +1,14 @@
"""Primarily used for testing merged cell scenarios"""
+import gc
import io
import os
import tempfile
+import warnings
from collections import UserDict
from pathlib import Path
from types import SimpleNamespace
-from unittest.mock import MagicMock
+from unittest.mock import AsyncMock, MagicMock
import pytest
from docx import Document
@@ -354,15 +356,46 @@ def test_init_expands_home_path_and_invalid_local_path(monkeypatch, tmp_path):
WordExtractor("not-a-file", "tenant", "user")
-def test_del_closes_temp_file():
+def test_close_closes_temp_file():
extractor = object.__new__(WordExtractor)
+ extractor._closed = False
extractor.temp_file = MagicMock()
- WordExtractor.__del__(extractor)
+ extractor.close()
extractor.temp_file.close.assert_called_once()
+def test_close_is_idempotent():
+ extractor = object.__new__(WordExtractor)
+ extractor._closed = False
+ extractor.temp_file = MagicMock()
+
+ extractor.close()
+ extractor.close()
+
+ extractor.temp_file.close.assert_called_once()
+
+
+def test_close_handles_async_close_mock():
+ extractor = object.__new__(WordExtractor)
+ extractor._closed = False
+ extractor.temp_file = MagicMock()
+ extractor.temp_file.close = AsyncMock()
+
+ with warnings.catch_warnings(record=True) as caught:
+ warnings.simplefilter("always")
+ extractor.close()
+ gc.collect()
+
+ extractor.temp_file.close.assert_called_once()
+ assert not [
+ warning
+ for warning in caught
+ if issubclass(warning.category, RuntimeWarning) and "AsyncMockMixin._execute_mock_call" in str(warning.message)
+ ]
+
+
def test_extract_images_handles_invalid_external_cases(monkeypatch):
class FakeTargetRef:
def __contains__(self, item):
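
The extractor's cleanup moves from `__del__` to an explicit, idempotent `close()` that tolerates repeated calls and, per the last test, an `AsyncMock`-patched `close` on the temp file without leaking RuntimeWarnings. Caller-side, that implies a try/finally shape; a sketch, with `extract()` assumed as the usual extractor entry point (it is not shown in this hunk):

extractor = WordExtractor("/path/to/file.docx", "tenant", "user")
try:
    docs = extractor.extract()  # assumed entry point
finally:
    extractor.close()  # safe to call more than once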
diff --git a/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py b/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py
index 8be1ac318c..18ae9fafc8 100644
--- a/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py
+++ b/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py
@@ -14,7 +14,7 @@ from core.repositories.human_input_repository import (
HumanInputFormSubmissionRepository,
_WorkspaceMemberInfo,
)
-from core.workflow.human_input_compat import (
+from core.workflow.human_input_adapter import (
EmailDeliveryConfig,
EmailDeliveryMethod,
EmailRecipients,
diff --git a/api/tests/unit_tests/core/repositories/test_human_input_repository.py b/api/tests/unit_tests/core/repositories/test_human_input_repository.py
index 1297a95df1..4248782d93 100644
--- a/api/tests/unit_tests/core/repositories/test_human_input_repository.py
+++ b/api/tests/unit_tests/core/repositories/test_human_input_repository.py
@@ -21,7 +21,7 @@ from core.repositories.human_input_repository import (
_InvalidTimeoutStatusError,
_WorkspaceMemberInfo,
)
-from core.workflow.human_input_compat import (
+from core.workflow.human_input_adapter import (
EmailDeliveryConfig,
EmailDeliveryMethod,
EmailRecipients,
diff --git a/api/tests/unit_tests/core/test_file.py b/api/tests/unit_tests/core/test_file.py
index f17927f16b..eab0176f41 100644
--- a/api/tests/unit_tests/core/test_file.py
+++ b/api/tests/unit_tests/core/test_file.py
@@ -6,9 +6,9 @@ from models.workflow import Workflow
def test_file_to_dict():
file = File(
- id="file1",
+ file_id="file1",
tenant_id="tenant1",
- type=FileType.IMAGE,
+ file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/image1.jpg",
storage_key="storage_key",
diff --git a/api/tests/unit_tests/core/test_provider_manager.py b/api/tests/unit_tests/core/test_provider_manager.py
index f45b43082c..a5a542c94f 100644
--- a/api/tests/unit_tests/core/test_provider_manager.py
+++ b/api/tests/unit_tests/core/test_provider_manager.py
@@ -372,6 +372,78 @@ def test_get_configurations_binds_manager_runtime_to_provider_configuration(
provider_configuration.bind_model_runtime.assert_called_once_with(manager._model_runtime)
+def test_get_configurations_reuses_cached_result_for_same_tenant(mocker: MockerFixture, mock_provider_entity):
+ manager = _build_provider_manager(mocker)
+ provider_configuration = Mock()
+ provider_factory = Mock()
+ provider_factory.get_providers.return_value = [mock_provider_entity]
+ custom_configuration = SimpleNamespace(provider=None, models=[])
+ system_configuration = SimpleNamespace(enabled=False, quota_configurations=[], current_quota_type=None)
+
+ with (
+ patch.object(manager, "_get_all_providers", return_value={"openai": []}) as mock_get_all_providers,
+ patch.object(manager, "_init_trial_provider_records", return_value={"openai": []}),
+ patch.object(manager, "_get_all_provider_models", return_value={"openai": []}),
+ patch.object(manager, "_get_all_preferred_model_providers", return_value={}),
+ patch.object(manager, "_get_all_provider_model_settings", return_value={}),
+ patch.object(manager, "_get_all_provider_load_balancing_configs", return_value={}),
+ patch.object(manager, "_get_all_provider_model_credentials", return_value={}),
+ patch.object(manager, "_to_custom_configuration", return_value=custom_configuration),
+ patch.object(manager, "_to_system_configuration", return_value=system_configuration),
+ patch.object(manager, "_to_model_settings", return_value=[]),
+ patch("core.provider_manager.ModelProviderFactory", return_value=provider_factory) as mock_factory_cls,
+ patch(
+ "core.provider_manager.ProviderConfiguration",
+ return_value=provider_configuration,
+ ) as mock_provider_configuration,
+ ):
+ first = manager.get_configurations("tenant-id")
+ second = manager.get_configurations("tenant-id")
+
+ assert first is second
+ mock_get_all_providers.assert_called_once_with("tenant-id")
+ mock_factory_cls.assert_called_once_with(model_runtime=manager._model_runtime)
+ mock_provider_configuration.assert_called_once()
+ provider_configuration.bind_model_runtime.assert_called_once_with(manager._model_runtime)
+
+
+def test_clear_configurations_cache_rebuilds_requested_tenant(mocker: MockerFixture, mock_provider_entity):
+ manager = _build_provider_manager(mocker)
+ provider_factory = Mock()
+ provider_factory.get_providers.return_value = [mock_provider_entity]
+ custom_configuration = SimpleNamespace(provider=None, models=[])
+ system_configuration = SimpleNamespace(enabled=False, quota_configurations=[], current_quota_type=None)
+ provider_configuration_first = Mock()
+ provider_configuration_second = Mock()
+
+ with (
+ patch.object(manager, "_get_all_providers", return_value={"openai": []}) as mock_get_all_providers,
+ patch.object(manager, "_init_trial_provider_records", return_value={"openai": []}),
+ patch.object(manager, "_get_all_provider_models", return_value={"openai": []}),
+ patch.object(manager, "_get_all_preferred_model_providers", return_value={}),
+ patch.object(manager, "_get_all_provider_model_settings", return_value={}),
+ patch.object(manager, "_get_all_provider_load_balancing_configs", return_value={}),
+ patch.object(manager, "_get_all_provider_model_credentials", return_value={}),
+ patch.object(manager, "_to_custom_configuration", return_value=custom_configuration),
+ patch.object(manager, "_to_system_configuration", return_value=system_configuration),
+ patch.object(manager, "_to_model_settings", return_value=[]),
+ patch("core.provider_manager.ModelProviderFactory", return_value=provider_factory),
+ patch(
+ "core.provider_manager.ProviderConfiguration",
+ side_effect=[provider_configuration_first, provider_configuration_second],
+ ) as mock_provider_configuration,
+ ):
+ first = manager.get_configurations("tenant-id")
+ manager.clear_configurations_cache("tenant-id")
+ second = manager.get_configurations("tenant-id")
+
+ assert first is not second
+ assert mock_get_all_providers.call_count == 2
+ assert mock_provider_configuration.call_count == 2
+ provider_configuration_first.bind_model_runtime.assert_called_once_with(manager._model_runtime)
+ provider_configuration_second.bind_model_runtime.assert_called_once_with(manager._model_runtime)
+
+
def test_get_provider_model_bundle_returns_selected_model_type_instance(mocker: MockerFixture):
manager = _build_provider_manager(mocker)
provider_configuration = Mock()
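
The two new tests pin down a per-tenant memo on `get_configurations`: repeated calls for the same tenant return the identical object without re-querying providers, and `clear_configurations_cache` forces a rebuild on the next call. An illustrative sketch of that contract (toy code; the real ProviderManager internals may differ):

class CachingManager:
    def __init__(self) -> None:
        self._cache: dict[str, object] = {}

    def get_configurations(self, tenant_id: str) -> object:
        # The same object is returned for repeat calls with the same tenant.
        if tenant_id not in self._cache:
            self._cache[tenant_id] = self._build(tenant_id)  # expensive path runs once
        return self._cache[tenant_id]

    def clear_configurations_cache(self, tenant_id: str) -> None:
        # The next get_configurations() for this tenant rebuilds from scratch.
        self._cache.pop(tenant_id, None)

    def _build(self, tenant_id: str) -> object:
        return object()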
diff --git a/api/tests/unit_tests/core/variables/test_segment.py b/api/tests/unit_tests/core/variables/test_segment.py
index 72052c8c05..9e07ea1b6d 100644
--- a/api/tests/unit_tests/core/variables/test_segment.py
+++ b/api/tests/unit_tests/core/variables/test_segment.py
@@ -1,8 +1,9 @@
import dataclasses
+from typing import Annotated
import orjson
import pytest
-from pydantic import BaseModel
+from pydantic import BaseModel, Discriminator, Tag
from core.helper import encrypter
from core.workflow.system_variables import build_bootstrap_variables, build_system_variables
@@ -12,17 +13,18 @@ from graphon.runtime import VariablePool
from graphon.variables.segment_group import SegmentGroup
from graphon.variables.segments import (
ArrayAnySegment,
+ ArrayBooleanSegment,
ArrayFileSegment,
ArrayNumberSegment,
ArrayObjectSegment,
ArrayStringSegment,
+ BooleanSegment,
FileSegment,
FloatSegment,
IntegerSegment,
NoneSegment,
ObjectSegment,
Segment,
- SegmentUnion,
StringSegment,
get_segment_discriminator,
)
@@ -47,6 +49,26 @@ from graphon.variables.variables import (
StringVariable,
Variable,
)
+from models.utils.file_input_compat import rebuild_serialized_graph_files_without_lookup
+
+type SegmentUnion = Annotated[
+ (
+ Annotated[NoneSegment, Tag(SegmentType.NONE)]
+ | Annotated[StringSegment, Tag(SegmentType.STRING)]
+ | Annotated[FloatSegment, Tag(SegmentType.FLOAT)]
+ | Annotated[IntegerSegment, Tag(SegmentType.INTEGER)]
+ | Annotated[ObjectSegment, Tag(SegmentType.OBJECT)]
+ | Annotated[FileSegment, Tag(SegmentType.FILE)]
+ | Annotated[BooleanSegment, Tag(SegmentType.BOOLEAN)]
+ | Annotated[ArrayAnySegment, Tag(SegmentType.ARRAY_ANY)]
+ | Annotated[ArrayStringSegment, Tag(SegmentType.ARRAY_STRING)]
+ | Annotated[ArrayNumberSegment, Tag(SegmentType.ARRAY_NUMBER)]
+ | Annotated[ArrayObjectSegment, Tag(SegmentType.ARRAY_OBJECT)]
+ | Annotated[ArrayFileSegment, Tag(SegmentType.ARRAY_FILE)]
+ | Annotated[ArrayBooleanSegment, Tag(SegmentType.ARRAY_BOOLEAN)]
+ ),
+ Discriminator(get_segment_discriminator),
+]
def _build_variable_pool(
@@ -123,7 +145,7 @@ def create_test_file(
) -> File:
"""Factory function to create File objects for testing"""
return File(
- type=file_type,
+ file_type=file_type,
transfer_method=transfer_method,
filename=filename,
extension=extension,
@@ -160,7 +182,7 @@ class TestSegmentDumpAndLoad:
assert restored == model
def test_all_segments_serialization(self):
- """Test serialization/deserialization of all segment types"""
+ """Test file-aware segment serialization through Dify's model boundary."""
# Create one instance of each segment type
test_file = create_test_file()
@@ -181,7 +203,7 @@ class TestSegmentDumpAndLoad:
# Test serialization and deserialization
model = _Segments(segments=all_segments)
json_str = model.model_dump_json()
- loaded = _Segments.model_validate_json(json_str)
+ loaded = _Segments.model_validate(rebuild_serialized_graph_files_without_lookup(orjson.loads(json_str)))
# Verify all segments are preserved
assert len(loaded.segments) == len(all_segments)
@@ -202,7 +224,7 @@ class TestSegmentDumpAndLoad:
assert loaded_segment.value == original.value
def test_all_variables_serialization(self):
- """Test serialization/deserialization of all variable types"""
+ """Test file-aware variable serialization through Dify's model boundary."""
# Create one instance of each variable type
test_file = create_test_file()
@@ -223,7 +245,7 @@ class TestSegmentDumpAndLoad:
# Test serialization and deserialization
model = _Variables(variables=all_variables)
json_str = model.model_dump_json()
- loaded = _Variables.model_validate_json(json_str)
+ loaded = _Variables.model_validate(rebuild_serialized_graph_files_without_lookup(orjson.loads(json_str)))
# Verify all variables are preserved
assert len(loaded.variables) == len(all_variables)
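
With `SegmentUnion` no longer exported, the test rebuilds it locally from pydantic's callable `Discriminator` plus per-member `Tag`s. A self-contained sketch of that pydantic v2 pattern, with toy models standing in for the real segment classes:

from typing import Annotated, Union

from pydantic import BaseModel, Discriminator, Tag, TypeAdapter

class StringValue(BaseModel):
    value_type: str = "string"
    value: str

class IntegerValue(BaseModel):
    value_type: str = "integer"
    value: int

def _discriminate(v: object):
    # Mirrors get_segment_discriminator: read the tag off dicts or instances.
    if isinstance(v, dict):
        return v.get("value_type")
    return getattr(v, "value_type", None)

ValueUnion = Annotated[
    Union[
        Annotated[StringValue, Tag("string")],
        Annotated[IntegerValue, Tag("integer")],
    ],
    Discriminator(_discriminate),
]

adapter = TypeAdapter(ValueUnion)
assert isinstance(adapter.validate_python({"value_type": "integer", "value": 3}), IntegerValue)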
diff --git a/api/tests/unit_tests/core/variables/test_segment_type_validation.py b/api/tests/unit_tests/core/variables/test_segment_type_validation.py
index 94e788edb2..317fe99d37 100644
--- a/api/tests/unit_tests/core/variables/test_segment_type_validation.py
+++ b/api/tests/unit_tests/core/variables/test_segment_type_validation.py
@@ -35,7 +35,7 @@ def create_test_file(
"""Factory function to create File objects for testing."""
return File(
tenant_id="test-tenant",
- type=file_type,
+ file_type=file_type,
transfer_method=transfer_method,
filename=filename,
extension=extension,
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py
index 76b2984a4b..9f3e3b00b9 100644
--- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py
+++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py
@@ -1,12 +1,13 @@
-"""
-Mock node factory for testing workflows with third-party service dependencies.
+"""Mock node factory for third-party-service workflow tests.
-This module provides a MockNodeFactory that automatically detects and mocks nodes
-requiring external services (LLM, Agent, Tool, Knowledge Retrieval, HTTP Request).
+The factory follows the same config adaptation path as production
+`DifyNodeFactory.create_node()`, but swaps selected node classes for mock
+implementations before instantiation.
"""
from typing import TYPE_CHECKING, Any
+from core.workflow.human_input_adapter import adapt_node_config_for_graph
from core.workflow.node_factory import DifyNodeFactory
from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
from graphon.enums import BuiltinNodeTypes, NodeType
@@ -82,20 +83,20 @@ class MockNodeFactory(DifyNodeFactory):
:param node_config: Node configuration dictionary
:return: Node instance (real or mocked)
"""
- typed_node_config = NodeConfigDictAdapter.validate_python(node_config)
+ typed_node_config = NodeConfigDictAdapter.validate_python(adapt_node_config_for_graph(node_config))
+ node_id = typed_node_config["id"]
node_data = typed_node_config["data"]
node_type = node_data.type
# Check if this node type should be mocked
if node_type in self._mock_node_types:
- node_id = typed_node_config["id"]
-
# Create mock node instance
mock_class = self._mock_node_types[node_type]
+ resolved_node_data = self._validate_resolved_node_data(mock_class, node_data)
if node_type == BuiltinNodeTypes.CODE:
mock_instance = mock_class(
- id=node_id,
- config=typed_node_config,
+ node_id=node_id,
+ config=resolved_node_data,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
mock_config=self.mock_config,
@@ -104,8 +105,8 @@ class MockNodeFactory(DifyNodeFactory):
)
elif node_type == BuiltinNodeTypes.HTTP_REQUEST:
mock_instance = mock_class(
- id=node_id,
- config=typed_node_config,
+ node_id=node_id,
+ config=resolved_node_data,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
mock_config=self.mock_config,
@@ -120,8 +121,8 @@ class MockNodeFactory(DifyNodeFactory):
BuiltinNodeTypes.PARAMETER_EXTRACTOR,
}:
mock_instance = mock_class(
- id=node_id,
- config=typed_node_config,
+ node_id=node_id,
+ config=resolved_node_data,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
mock_config=self.mock_config,
@@ -130,8 +131,8 @@ class MockNodeFactory(DifyNodeFactory):
)
else:
mock_instance = mock_class(
- id=node_id,
- config=typed_node_config,
+ node_id=node_id,
+ config=resolved_node_data,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
mock_config=self.mock_config,
@@ -140,7 +141,7 @@ class MockNodeFactory(DifyNodeFactory):
return mock_instance
# For non-mocked node types, use parent implementation
- return super().create_node(typed_node_config)
+ return super().create_node(node_config)
def should_mock_node(self, node_type: NodeType) -> bool:
"""
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py
index 971b9b2bbf..f9819c47ec 100644
--- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py
+++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py
@@ -55,13 +55,14 @@ class MockNodeMixin:
def __init__(
self,
- id: str,
- config: Mapping[str, Any],
+ node_id: str,
+ config: Any,
+ *,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
mock_config: Optional["MockConfig"] = None,
**kwargs: Any,
- ):
+ ) -> None:
if isinstance(self, (LLMNode, QuestionClassifierNode, ParameterExtractorNode)):
kwargs.setdefault("credentials_provider", MagicMock(spec=CredentialsProvider))
kwargs.setdefault("model_factory", MagicMock(spec=ModelFactory))
@@ -96,7 +97,7 @@ class MockNodeMixin:
kwargs.setdefault("message_transformer", MagicMock())
super().__init__(
- id=id,
+ node_id=node_id,
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
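
The mixin's `kwargs.setdefault(..., MagicMock(spec=...))` calls inject spec'd stand-ins only when the caller did not supply the dependency. A self-contained sketch of that defaulting pattern:

from unittest.mock import MagicMock

class CredentialsProvider:  # stand-in for the real protocol
    def fetch(self) -> str:
        raise NotImplementedError

def build_kwargs(**kwargs):
    # Fill the slot only when the caller left it empty; spec= restricts the
    # mock to the protocol's real attribute surface.
    kwargs.setdefault("credentials_provider", MagicMock(spec=CredentialsProvider))
    return kwargs

deps = build_kwargs()
deps["credentials_provider"].fetch.return_value = "token"
assert deps["credentials_provider"].fetch() == "token"
assert build_kwargs(credentials_provider="real")["credentials_provider"] == "real"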
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py
index 55a329eba9..75bc6d05f7 100644
--- a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py
+++ b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py
@@ -139,8 +139,8 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor
start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()}
start_node = StartNode(
- id=start_config["id"],
- config=start_config,
+ node_id=start_config["id"],
+ config=StartNodeData(title="Start", variables=[]),
graph_init_params=graph_init_params,
graph_runtime_state=runtime_state,
)
@@ -154,8 +154,8 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor
human_a_config = {"id": "human_a", "data": human_data.model_dump()}
human_a = HumanInputNode(
- id=human_a_config["id"],
- config=human_a_config,
+ node_id=human_a_config["id"],
+ config=human_data,
graph_init_params=graph_init_params,
graph_runtime_state=runtime_state,
form_repository=repo,
@@ -164,8 +164,8 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor
human_b_config = {"id": "human_b", "data": human_data.model_dump()}
human_b = HumanInputNode(
- id=human_b_config["id"],
- config=human_b_config,
+ node_id=human_b_config["id"],
+ config=human_data,
graph_init_params=graph_init_params,
graph_runtime_state=runtime_state,
form_repository=repo,
@@ -182,8 +182,8 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor
)
end_config = {"id": "end", "data": end_data.model_dump()}
end_node = EndNode(
- id=end_config["id"],
- config=end_config,
+ node_id=end_config["id"],
+ config=end_data,
graph_init_params=graph_init_params,
graph_runtime_state=runtime_state,
)
diff --git a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py
index 9c0ad25b58..76b4cd1ef4 100644
--- a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py
+++ b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py
@@ -9,6 +9,7 @@ from extensions.ext_database import db
from graphon.enums import WorkflowNodeExecutionStatus
from graphon.graph import Graph
from graphon.nodes.answer.answer_node import AnswerNode
+from graphon.nodes.answer.entities import AnswerNodeData
from graphon.runtime import GraphRuntimeState, VariablePool
from tests.workflow_test_utils import build_test_graph_init_params
@@ -66,20 +67,15 @@ def test_execute_answer():
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start")
- node_config = {
- "id": "answer",
- "data": {
- "title": "123",
- "type": "answer",
- "answer": "Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.",
- },
- }
-
node = AnswerNode(
- id=str(uuid.uuid4()),
+ node_id=str(uuid.uuid4()),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
- config=node_config,
+ config=AnswerNodeData(
+ title="123",
+ type="answer",
+ answer="Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.",
+ ),
)
# Mock db.session.close()
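
This is the shape every node-construction hunk in this patch converges on: `node_id` plus an already-validated `*NodeData` instance, with raw `{"id": ..., "data": ...}` dicts passing through `model_validate` first. Restated outside the test as a sketch (imports as shown above; the init params come from the surrounding fixture):

from graphon.nodes.answer.answer_node import AnswerNode
from graphon.nodes.answer.entities import AnswerNodeData

def build_answer_node(raw: AnswerNodeData | dict, init_params, runtime_state) -> AnswerNode:
    # Raw dicts are validated into typed node data before construction.
    data = raw if isinstance(raw, AnswerNodeData) else AnswerNodeData.model_validate(raw)
    return AnswerNode(
        node_id="answer",
        config=data,
        graph_init_params=init_params,
        graph_runtime_state=runtime_state,
    )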
diff --git a/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py b/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py
index 9cceadde49..d7ef781732 100644
--- a/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py
@@ -1,5 +1,6 @@
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY
from core.workflow.nodes.datasource.datasource_node import DatasourceNode
+from core.workflow.nodes.datasource.entities import DatasourceNodeData
from graphon.enums import WorkflowNodeExecutionStatus
from graphon.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent
@@ -77,19 +78,16 @@ def test_datasource_node_delegates_to_manager_stream(mocker):
mocker.patch("core.workflow.nodes.datasource.datasource_node.DatasourceManager", new=_Mgr)
node = DatasourceNode(
- id="n",
- config={
- "id": "n",
- "data": {
- "type": "datasource",
- "version": "1",
- "title": "Datasource",
- "provider_type": "plugin",
- "provider_name": "p",
- "plugin_id": "plug",
- "datasource_name": "ds",
- },
- },
+ node_id="n",
+ config=DatasourceNodeData(
+ type="datasource",
+ version="1",
+ title="Datasource",
+ provider_type="plugin",
+ provider_name="p",
+ plugin_id="plug",
+ datasource_name="ds",
+ ),
graph_init_params=gp,
graph_runtime_state=gs,
)
diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py
index a3cadc0681..2e89a2da3c 100644
--- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py
@@ -12,7 +12,7 @@ from core.workflow.system_variables import build_system_variables
from graphon.enums import WorkflowNodeExecutionStatus
from graphon.file.file_manager import file_manager
from graphon.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, HttpRequestNode, HttpRequestNodeConfig
-from graphon.nodes.http_request.entities import HttpRequestNodeTimeout, Response
+from graphon.nodes.http_request.entities import HttpRequestNodeData, HttpRequestNodeTimeout, Response
from graphon.runtime import GraphRuntimeState, VariablePool
from tests.workflow_test_utils import build_test_graph_init_params
@@ -66,8 +66,8 @@ def test_get_default_config_uses_injected_http_request_config():
assert default_config["retry_config"]["max_retries"] == 7
-def test_get_default_config_with_malformed_http_request_config_raises_value_error():
- with pytest.raises(ValueError, match="http_request_config must be an HttpRequestNodeConfig instance"):
+def test_get_default_config_with_malformed_http_request_config_raises_type_error():
+ with pytest.raises(TypeError, match="http_request_config must be an HttpRequestNodeConfig instance"):
HttpRequestNode.get_default_config(filters={HTTP_REQUEST_CONFIG_FILTER_KEY: "invalid"})
@@ -114,8 +114,8 @@ def _build_http_node(
start_at=time.perf_counter(),
)
return HttpRequestNode(
- id="http-node",
- config=node_config,
+ node_id="http-node",
+ config=HttpRequestNodeData.model_validate(node_data),
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
http_request_config=HTTP_REQUEST_CONFIG,
diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py
index 1d6a4da7c4..07430498e5 100644
--- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py
+++ b/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py
@@ -1,4 +1,4 @@
-from core.workflow.human_input_compat import EmailDeliveryConfig, EmailRecipients
+from core.workflow.human_input_adapter import EmailDeliveryConfig, EmailRecipients
from graphon.runtime import VariablePool
diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py
index c0e21d0bf7..0659984c76 100644
--- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py
+++ b/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py
@@ -19,7 +19,7 @@ from core.repositories.human_input_repository import (
HumanInputFormRecipientEntity,
HumanInputFormRepository,
)
-from core.workflow.human_input_compat import (
+from core.workflow.human_input_adapter import (
DeliveryMethodType,
EmailDeliveryConfig,
EmailDeliveryMethod,
@@ -136,6 +136,26 @@ class InMemoryHumanInputFormRepository(HumanInputFormRepository):
entity.status_value = HumanInputFormStatus.SUBMITTED
+def _build_human_input_node(
+ *,
+ node_id: str,
+ node_data: HumanInputNodeData | Mapping[str, Any],
+ graph_init_params: GraphInitParams,
+ graph_runtime_state: GraphRuntimeState,
+ runtime: DifyHumanInputNodeRuntime,
+) -> HumanInputNode:
+ typed_node_data = (
+ node_data if isinstance(node_data, HumanInputNodeData) else HumanInputNodeData.model_validate(node_data)
+ )
+ return HumanInputNode(
+ node_id=node_id,
+ config=typed_node_data,
+ graph_init_params=graph_init_params,
+ graph_runtime_state=graph_runtime_state,
+ runtime=runtime,
+ )
+
+
class TestDeliveryMethod:
"""Test DeliveryMethod entity."""
@@ -239,7 +259,7 @@ class TestUserAction:
data[field_name] = value
with pytest.raises(ValidationError) as exc_info:
- UserAction(**data)
+ UserAction.model_validate(data)
errors = exc_info.value.errors()
assert any(error["loc"] == (field_name,) and error["type"] == "string_too_long" for error in errors)
@@ -465,9 +485,9 @@ class TestHumanInputNodeVariableResolution:
runtime = DifyHumanInputNodeRuntime(graph_init_params.run_context)
runtime._build_form_repository = MagicMock(return_value=mock_repo) # type: ignore[attr-defined]
- node = HumanInputNode(
- id=config["id"],
- config=config,
+ node = _build_human_input_node(
+ node_id=config["id"],
+ node_data=config["data"],
graph_init_params=graph_init_params,
graph_runtime_state=runtime_state,
runtime=runtime,
@@ -530,9 +550,9 @@ class TestHumanInputNodeVariableResolution:
runtime = DifyHumanInputNodeRuntime(graph_init_params.run_context)
runtime._build_form_repository = MagicMock(return_value=mock_repo) # type: ignore[attr-defined]
- node = HumanInputNode(
- id=config["id"],
- config=config,
+ node = _build_human_input_node(
+ node_id=config["id"],
+ node_data=config["data"],
graph_init_params=graph_init_params,
graph_runtime_state=runtime_state,
runtime=runtime,
@@ -595,9 +615,9 @@ class TestHumanInputNodeVariableResolution:
runtime = DifyHumanInputNodeRuntime(graph_init_params.run_context)
runtime._build_form_repository = MagicMock(return_value=mock_repo) # type: ignore[attr-defined]
- node = HumanInputNode(
- id=config["id"],
- config=config,
+ node = _build_human_input_node(
+ node_id=config["id"],
+ node_data=config["data"],
graph_init_params=graph_init_params,
graph_runtime_state=runtime_state,
runtime=runtime,
@@ -671,9 +691,9 @@ class TestHumanInputNodeVariableResolution:
runtime = DifyHumanInputNodeRuntime(graph_init_params.run_context)
runtime._build_form_repository = MagicMock(return_value=mock_repo) # type: ignore[attr-defined]
- node = HumanInputNode(
- id=config["id"],
- config=config,
+ node = _build_human_input_node(
+ node_id=config["id"],
+ node_data=config["data"],
graph_init_params=graph_init_params,
graph_runtime_state=runtime_state,
runtime=runtime,
@@ -770,9 +790,9 @@ class TestHumanInputNodeRenderedContent:
form_repository = InMemoryHumanInputFormRepository()
runtime = DifyHumanInputNodeRuntime(graph_init_params.run_context)
runtime._build_form_repository = MagicMock(return_value=form_repository) # type: ignore[attr-defined]
- node = HumanInputNode(
- id=config["id"],
- config=config,
+ node = _build_human_input_node(
+ node_id=config["id"],
+ node_data=config["data"],
graph_init_params=graph_init_params,
graph_runtime_state=runtime_state,
runtime=runtime,
diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py
index bc98028d5b..4a9438b14f 100644
--- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py
+++ b/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py
@@ -11,6 +11,7 @@ from graphon.graph_events import (
NodeRunHumanInputFormTimeoutEvent,
NodeRunStartedEvent,
)
+from graphon.nodes.human_input.entities import HumanInputNodeData
from graphon.nodes.human_input.enums import HumanInputFormStatus
from graphon.nodes.human_input.human_input_node import HumanInputNode
from graphon.runtime import GraphRuntimeState, VariablePool
@@ -25,6 +26,28 @@ class _FakeFormRepository:
return self._form
+def _create_human_input_node(
+ *,
+ config: dict,
+ graph_init_params: GraphInitParams,
+ graph_runtime_state: GraphRuntimeState,
+ repo: _FakeFormRepository,
+) -> HumanInputNode:
+ node_data = (
+ config["data"]
+ if isinstance(config["data"], HumanInputNodeData)
+ else HumanInputNodeData.model_validate(config["data"])
+ )
+ return HumanInputNode(
+ node_id=config["id"],
+ config=node_data,
+ graph_init_params=graph_init_params,
+ graph_runtime_state=graph_runtime_state,
+ form_repository=repo,
+ runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context),
+ )
+
+
def _build_node(form_content: str = "Please enter your name:\n\n{{#$output.name#}}") -> HumanInputNode:
system_variables = default_system_variables()
graph_runtime_state = GraphRuntimeState(
@@ -80,13 +103,11 @@ def _build_node(form_content: str = "Please enter your name:\n\n{{#$output.name#
)
repo = _FakeFormRepository(fake_form)
- return HumanInputNode(
- id="node-1",
+ return _create_human_input_node(
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
- form_repository=repo,
- runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context),
+ repo=repo,
)
@@ -145,13 +166,11 @@ def _build_timeout_node() -> HumanInputNode:
)
repo = _FakeFormRepository(fake_form)
- return HumanInputNode(
- id="node-1",
+ return _create_human_input_node(
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
- form_repository=repo,
- runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context),
+ repo=repo,
)
diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py
index 82cc734274..8ffce39cd6 100644
--- a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py
+++ b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py
@@ -5,6 +5,7 @@ import pytest
from core.workflow.system_variables import default_system_variables
from graphon.entities import GraphInitParams
+from graphon.nodes.iteration.entities import IterationNodeData
from graphon.nodes.iteration.exc import IterationGraphNotFoundError
from graphon.nodes.iteration.iteration_node import IterationNode
from graphon.runtime import (
@@ -44,17 +45,14 @@ def _build_iteration_node(
) -> IterationNode:
init_params = build_test_graph_init_params(graph_config=graph_config)
return IterationNode(
- id="iteration-node",
- config={
- "id": "iteration-node",
- "data": {
- "type": "iteration",
- "title": "Iteration",
- "iterator_selector": ["start", "items"],
- "output_selector": ["iteration-node", "output"],
- "start_node_id": start_node_id,
- },
- },
+ node_id="iteration-node",
+ config=IterationNodeData(
+ type="iteration",
+ title="Iteration",
+ iterator_selector=["start", "items"],
+ output_selector=["iteration-node", "output"],
+ start_node_id=start_node_id,
+ ),
graph_init_params=init_params,
graph_runtime_state=runtime_state,
)
diff --git a/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py b/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py
index a6fca1bfb4..f254fc3d09 100644
--- a/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py
@@ -93,6 +93,25 @@ def sample_chunks():
}
+def _build_node(
+ *,
+ node_id: str,
+ node_data: KnowledgeIndexNodeData | dict[str, object],
+ graph_init_params,
+ graph_runtime_state,
+) -> KnowledgeIndexNode:
+ return KnowledgeIndexNode(
+ node_id=node_id,
+ config=(
+ node_data
+ if isinstance(node_data, KnowledgeIndexNodeData)
+ else KnowledgeIndexNodeData.model_validate(node_data)
+ ),
+ graph_init_params=graph_init_params,
+ graph_runtime_state=graph_runtime_state,
+ )
+
+
class TestKnowledgeIndexNode:
"""
Test suite for KnowledgeIndexNode.
@@ -115,9 +134,9 @@ class TestKnowledgeIndexNode:
}
# Act
- node = KnowledgeIndexNode(
- id=node_id,
- config=config,
+ node = _build_node(
+ node_id=node_id,
+ node_data=config["data"],
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -143,9 +162,9 @@ class TestKnowledgeIndexNode:
"data": sample_node_data.model_dump(),
}
- node = KnowledgeIndexNode(
- id=node_id,
- config=config,
+ node = _build_node(
+ node_id=node_id,
+ node_data=config["data"],
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -176,9 +195,9 @@ class TestKnowledgeIndexNode:
"data": sample_node_data.model_dump(),
}
- node = KnowledgeIndexNode(
- id=node_id,
- config=config,
+ node = _build_node(
+ node_id=node_id,
+ node_data=config["data"],
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -212,9 +231,9 @@ class TestKnowledgeIndexNode:
"data": sample_node_data.model_dump(),
}
- node = KnowledgeIndexNode(
- id=node_id,
- config=config,
+ node = _build_node(
+ node_id=node_id,
+ node_data=config["data"],
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -269,9 +288,9 @@ class TestKnowledgeIndexNode:
"data": sample_node_data.model_dump(),
}
- node = KnowledgeIndexNode(
- id=node_id,
- config=config,
+ node = _build_node(
+ node_id=node_id,
+ node_data=config["data"],
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -332,9 +351,9 @@ class TestKnowledgeIndexNode:
"data": sample_node_data.model_dump(),
}
- node = KnowledgeIndexNode(
- id=node_id,
- config=config,
+ node = _build_node(
+ node_id=node_id,
+ node_data=config["data"],
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -383,9 +402,9 @@ class TestKnowledgeIndexNode:
"data": sample_node_data.model_dump(),
}
- node = KnowledgeIndexNode(
- id=node_id,
- config=config,
+ node = _build_node(
+ node_id=node_id,
+ node_data=config["data"],
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -440,9 +459,9 @@ class TestKnowledgeIndexNode:
"data": sample_node_data.model_dump(),
}
- node = KnowledgeIndexNode(
- id=node_id,
- config=config,
+ node = _build_node(
+ node_id=node_id,
+ node_data=config["data"],
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -498,9 +517,9 @@ class TestKnowledgeIndexNode:
"data": sample_node_data.model_dump(),
}
- node = KnowledgeIndexNode(
- id=node_id,
- config=config,
+ node = _build_node(
+ node_id=node_id,
+ node_data=config["data"],
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -536,9 +555,9 @@ class TestKnowledgeIndexNode:
"data": sample_node_data.model_dump(),
}
- node = KnowledgeIndexNode(
- id=node_id,
- config=config,
+ node = _build_node(
+ node_id=node_id,
+ node_data=config["data"],
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -583,9 +602,9 @@ class TestKnowledgeIndexNode:
"data": sample_node_data.model_dump(),
}
- node = KnowledgeIndexNode(
- id=node_id,
- config=config,
+ node = _build_node(
+ node_id=node_id,
+ node_data=config["data"],
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -623,9 +642,9 @@ class TestInvokeKnowledgeIndex:
"data": sample_node_data.model_dump(),
}
- node = KnowledgeIndexNode(
- id=node_id,
- config=config,
+ node = _build_node(
+ node_id=node_id,
+ node_data=config["data"],
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
diff --git a/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py b/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py
index 45e8ae7d20..e923ee761b 100644
--- a/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py
@@ -14,7 +14,11 @@ from core.workflow.nodes.knowledge_retrieval.entities import (
SingleRetrievalConfig,
)
from core.workflow.nodes.knowledge_retrieval.exc import RateLimitExceededError
-from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
+from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import (
+ KnowledgeRetrievalNode,
+ _normalize_metadata_filter_scalar,
+ _normalize_metadata_filter_sequence_item,
+)
from core.workflow.nodes.knowledge_retrieval.retrieval import RAGRetrievalProtocol, Source
from core.workflow.system_variables import build_system_variables
from graphon.enums import WorkflowNodeExecutionStatus
@@ -85,6 +89,12 @@ def sample_node_data():
)
+def test_metadata_filter_normalizers_preserve_numeric_scalars_and_stringify_other_values() -> None:
+ assert _normalize_metadata_filter_scalar(3) == 3
+ assert _normalize_metadata_filter_scalar(True) == "True"
+ assert _normalize_metadata_filter_sequence_item(4) == "4"
+
+
class TestKnowledgeRetrievalNode:
"""
Test suite for KnowledgeRetrievalNode.
@@ -106,8 +116,8 @@ class TestKnowledgeRetrievalNode:
# Act
node = KnowledgeRetrievalNode(
- id=node_id,
- config=config,
+ node_id=node_id,
+ config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -135,8 +145,8 @@ class TestKnowledgeRetrievalNode:
}
node = KnowledgeRetrievalNode(
- id=node_id,
- config=config,
+ node_id=node_id,
+ config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -194,8 +204,8 @@ class TestKnowledgeRetrievalNode:
mock_rag_retrieval.llm_usage = LLMUsage.empty_usage()
node = KnowledgeRetrievalNode(
- id=node_id,
- config=config,
+ node_id=node_id,
+ config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -238,8 +248,8 @@ class TestKnowledgeRetrievalNode:
mock_rag_retrieval.llm_usage = LLMUsage.empty_usage()
node = KnowledgeRetrievalNode(
- id=node_id,
- config=config,
+ node_id=node_id,
+ config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -274,8 +284,8 @@ class TestKnowledgeRetrievalNode:
}
node = KnowledgeRetrievalNode(
- id=node_id,
- config=config,
+ node_id=node_id,
+ config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -309,8 +319,8 @@ class TestKnowledgeRetrievalNode:
}
node = KnowledgeRetrievalNode(
- id=node_id,
- config=config,
+ node_id=node_id,
+ config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -350,8 +360,8 @@ class TestKnowledgeRetrievalNode:
mock_rag_retrieval.llm_usage = LLMUsage.empty_usage()
node = KnowledgeRetrievalNode(
- id=node_id,
- config=config,
+ node_id=node_id,
+ config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -389,8 +399,8 @@ class TestKnowledgeRetrievalNode:
mock_rag_retrieval.llm_usage = LLMUsage.empty_usage()
node = KnowledgeRetrievalNode(
- id=node_id,
- config=config,
+ node_id=node_id,
+ config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -470,8 +480,8 @@ class TestFetchDatasetRetriever:
config = {"id": node_id, "data": node_data.model_dump()}
node = KnowledgeRetrievalNode(
- id=node_id,
- config=config,
+ node_id=node_id,
+ config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -507,8 +517,8 @@ class TestFetchDatasetRetriever:
}
node = KnowledgeRetrievalNode(
- id=node_id,
- config=config,
+ node_id=node_id,
+ config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -562,8 +572,8 @@ class TestFetchDatasetRetriever:
}
node = KnowledgeRetrievalNode(
- id=node_id,
- config=config,
+ node_id=node_id,
+ config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -610,8 +620,8 @@ class TestFetchDatasetRetriever:
mock_graph_runtime_state.variable_pool.add(["start", "query"], StringSegment(value="readme"))
node = KnowledgeRetrievalNode(
- id=node_id,
- config=config,
+ node_id=node_id,
+ config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -671,8 +681,8 @@ class TestFetchDatasetRetriever:
node_id = str(uuid.uuid4())
config = {"id": node_id, "data": node_data.model_dump()}
node = KnowledgeRetrievalNode(
- id=node_id,
- config=config,
+ node_id=node_id,
+ config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
diff --git a/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py b/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py
index eca34f05be..388654f279 100644
--- a/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py
+++ b/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py
@@ -1,3 +1,4 @@
+from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
@@ -5,6 +6,7 @@ import pytest
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY
from graphon.entities import GraphInitParams
from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus
+from graphon.nodes.list_operator.entities import ListOperatorNodeData
from graphon.nodes.list_operator.node import ListOperatorNode
from graphon.runtime import GraphRuntimeState
from graphon.variables import ArrayNumberSegment, ArrayStringSegment
@@ -13,11 +15,28 @@ from graphon.variables import ArrayNumberSegment, ArrayStringSegment
class TestListOperatorNode:
"""Comprehensive tests for ListOperatorNode."""
+ @staticmethod
+ def _build_node(*, config, graph_init_params, graph_runtime_state):
+ return ListOperatorNode(
+ node_id="test",
+ config=config if isinstance(config, ListOperatorNodeData) else ListOperatorNodeData.model_validate(config),
+ graph_init_params=graph_init_params,
+ graph_runtime_state=graph_runtime_state,
+ )
+
+ @staticmethod
+ def _filter_by(comparison_operator: str, value: str) -> dict[str, object]:
+ return {
+ "enabled": True,
+ "conditions": [{"comparison_operator": comparison_operator, "value": value}],
+ }
+
@pytest.fixture
def mock_graph_runtime_state(self):
"""Create mock GraphRuntimeState."""
mock_state = MagicMock(spec=GraphRuntimeState)
mock_variable_pool = MagicMock()
+ mock_variable_pool.convert_template.side_effect = lambda value: SimpleNamespace(text=value)
mock_state.variable_pool = mock_variable_pool
return mock_state
@@ -45,9 +64,8 @@ class TestListOperatorNode:
def _create_node(config, mock_variable):
mock_graph_runtime_state.variable_pool.get.return_value = mock_variable
- return ListOperatorNode(
- id="test",
- config={"id": "test", "data": config},
+ return self._build_node(
+ config=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -64,9 +82,8 @@ class TestListOperatorNode:
"limit": {"enabled": False},
}
- node = ListOperatorNode(
- id="test",
- config={"id": "test", "data": config},
+ node = self._build_node(
+ config=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -109,9 +126,8 @@ class TestListOperatorNode:
mock_var = ArrayStringSegment(value=[])
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
- node = ListOperatorNode(
- id="test",
- config={"id": "test", "data": config},
+ node = self._build_node(
+ config=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -128,11 +144,7 @@ class TestListOperatorNode:
config = {
"title": "Test",
"variable": ["sys", "items"],
- "filter_by": {
- "enabled": True,
- "condition": "contains",
- "value": "app",
- },
+ "filter_by": self._filter_by("contains", "app"),
"order_by": {"enabled": False},
"limit": {"enabled": False},
}
@@ -140,9 +152,8 @@ class TestListOperatorNode:
mock_var = ArrayStringSegment(value=["apple", "banana", "pineapple", "cherry"])
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
- node = ListOperatorNode(
- id="test",
- config={"id": "test", "data": config},
+ node = self._build_node(
+ config=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -157,11 +168,7 @@ class TestListOperatorNode:
config = {
"title": "Test",
"variable": ["sys", "items"],
- "filter_by": {
- "enabled": True,
- "condition": "not contains",
- "value": "app",
- },
+ "filter_by": self._filter_by("not contains", "app"),
"order_by": {"enabled": False},
"limit": {"enabled": False},
}
@@ -169,9 +176,8 @@ class TestListOperatorNode:
mock_var = ArrayStringSegment(value=["apple", "banana", "pineapple", "cherry"])
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
- node = ListOperatorNode(
- id="test",
- config={"id": "test", "data": config},
+ node = self._build_node(
+ config=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -186,11 +192,7 @@ class TestListOperatorNode:
config = {
"title": "Test",
"variable": ["sys", "numbers"],
- "filter_by": {
- "enabled": True,
- "condition": ">",
- "value": "5",
- },
+ "filter_by": self._filter_by(">", "5"),
"order_by": {"enabled": False},
"limit": {"enabled": False},
}
@@ -198,9 +200,8 @@ class TestListOperatorNode:
mock_var = ArrayNumberSegment(value=[1, 3, 5, 7, 9, 11])
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
- node = ListOperatorNode(
- id="test",
- config={"id": "test", "data": config},
+ node = self._build_node(
+ config=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -226,9 +227,8 @@ class TestListOperatorNode:
mock_var = ArrayStringSegment(value=["cherry", "apple", "banana"])
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
- node = ListOperatorNode(
- id="test",
- config={"id": "test", "data": config},
+ node = self._build_node(
+ config=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -254,9 +254,8 @@ class TestListOperatorNode:
mock_var = ArrayStringSegment(value=["cherry", "apple", "banana"])
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
- node = ListOperatorNode(
- id="test",
- config={"id": "test", "data": config},
+ node = self._build_node(
+ config=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -282,9 +281,8 @@ class TestListOperatorNode:
mock_var = ArrayStringSegment(value=["apple", "banana", "cherry", "date"])
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
- node = ListOperatorNode(
- id="test",
- config={"id": "test", "data": config},
+ node = self._build_node(
+ config=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -299,11 +297,7 @@ class TestListOperatorNode:
config = {
"title": "Test",
"variable": ["sys", "numbers"],
- "filter_by": {
- "enabled": True,
- "condition": ">",
- "value": "3",
- },
+ "filter_by": self._filter_by(">", "3"),
"order_by": {
"enabled": True,
"value": "desc",
@@ -317,9 +311,8 @@ class TestListOperatorNode:
mock_var = ArrayNumberSegment(value=[1, 2, 3, 4, 5, 6, 7, 8, 9])
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
- node = ListOperatorNode(
- id="test",
- config={"id": "test", "data": config},
+ node = self._build_node(
+ config=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -341,9 +334,8 @@ class TestListOperatorNode:
mock_graph_runtime_state.variable_pool.get.return_value = None
- node = ListOperatorNode(
- id="test",
- config={"id": "test", "data": config},
+ node = self._build_node(
+ config=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -366,9 +358,8 @@ class TestListOperatorNode:
mock_var = ArrayStringSegment(value=["first", "middle", "last"])
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
- node = ListOperatorNode(
- id="test",
- config={"id": "test", "data": config},
+ node = self._build_node(
+ config=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -384,11 +375,7 @@ class TestListOperatorNode:
config = {
"title": "Test",
"variable": ["sys", "items"],
- "filter_by": {
- "enabled": True,
- "condition": "start with",
- "value": "app",
- },
+ "filter_by": self._filter_by("start with", "app"),
"order_by": {"enabled": False},
"limit": {"enabled": False},
}
@@ -396,9 +383,8 @@ class TestListOperatorNode:
mock_var = ArrayStringSegment(value=["apple", "application", "banana", "apricot"])
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
- node = ListOperatorNode(
- id="test",
- config={"id": "test", "data": config},
+ node = self._build_node(
+ config=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -413,11 +399,7 @@ class TestListOperatorNode:
config = {
"title": "Test",
"variable": ["sys", "items"],
- "filter_by": {
- "enabled": True,
- "condition": "end with",
- "value": "le",
- },
+ "filter_by": self._filter_by("end with", "le"),
"order_by": {"enabled": False},
"limit": {"enabled": False},
}
@@ -425,9 +407,8 @@ class TestListOperatorNode:
mock_var = ArrayStringSegment(value=["apple", "banana", "pineapple", "table"])
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
- node = ListOperatorNode(
- id="test",
- config={"id": "test", "data": config},
+ node = self._build_node(
+ config=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -442,11 +423,7 @@ class TestListOperatorNode:
config = {
"title": "Test",
"variable": ["sys", "numbers"],
- "filter_by": {
- "enabled": True,
- "condition": "=",
- "value": "5",
- },
+ "filter_by": self._filter_by("=", "5"),
"order_by": {"enabled": False},
"limit": {"enabled": False},
}
@@ -454,9 +431,8 @@ class TestListOperatorNode:
mock_var = ArrayNumberSegment(value=[1, 3, 5, 5, 7, 9])
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
- node = ListOperatorNode(
- id="test",
- config={"id": "test", "data": config},
+ node = self._build_node(
+ config=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -471,11 +447,7 @@ class TestListOperatorNode:
config = {
"title": "Test",
"variable": ["sys", "numbers"],
- "filter_by": {
- "enabled": True,
- "condition": "≠",
- "value": "5",
- },
+ "filter_by": self._filter_by("≠", "5"),
"order_by": {"enabled": False},
"limit": {"enabled": False},
}
@@ -483,9 +455,8 @@ class TestListOperatorNode:
mock_var = ArrayNumberSegment(value=[1, 3, 5, 7, 9])
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
- node = ListOperatorNode(
- id="test",
- config={"id": "test", "data": config},
+ node = self._build_node(
+ config=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -511,9 +482,8 @@ class TestListOperatorNode:
mock_var = ArrayNumberSegment(value=[9, 3, 7, 1, 5])
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
- node = ListOperatorNode(
- id="test",
- config={"id": "test", "data": config},
+ node = self._build_node(
+ config=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py
index 4186bbdc93..212ad07bd3 100644
--- a/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py
+++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py
@@ -71,8 +71,8 @@ def _build_image_file(
mime_type: str = "image/png",
) -> File:
return File(
- id=file_id,
- type=FileType.IMAGE,
+ file_id=file_id,
+ file_type=FileType.IMAGE,
filename=f"{file_id}{extension}",
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url=remote_url,
@@ -95,6 +95,8 @@ def variable_pool() -> VariablePool:
def _fetch_prompt_messages_with_mocked_content(content):
variable_pool = VariablePool.empty()
model_instance = mock.MagicMock(spec=ModelInstance)
+ model_schema = mock.MagicMock()
+ model_schema.supports_prompt_content_type.side_effect = lambda content_type: content_type == "text"
prompt_template = [
LLMNodeChatModelMessage(
text="You are a classifier.",
@@ -106,7 +108,7 @@ def _fetch_prompt_messages_with_mocked_content(content):
with (
mock.patch(
"graphon.nodes.llm.llm_utils.fetch_model_schema",
- return_value=mock.MagicMock(features=[]),
+ return_value=model_schema,
),
mock.patch(
"graphon.nodes.llm.llm_utils.handle_list_messages",
diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py
index b1f81b6c48..c707cf28cd 100644
--- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py
@@ -140,8 +140,8 @@ def _build_image_file(
mime_type: str = "image/png",
) -> File:
return File(
- id=file_id,
- type=FileType.IMAGE,
+ file_id=file_id,
+ file_type=FileType.IMAGE,
filename=f"{file_id}{extension}",
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url=remote_url,
@@ -205,14 +205,10 @@ def llm_node(
mock_credentials_provider = mock.MagicMock(spec=CredentialsProvider)
mock_model_factory = mock.MagicMock(spec=ModelFactory)
mock_prompt_message_serializer = mock.MagicMock(spec=PromptMessageSerializerProtocol)
- node_config = {
- "id": "1",
- "data": llm_node_data.model_dump(),
- }
http_client = mock.MagicMock()
node = LLMNode(
- id="1",
- config=node_config,
+ node_id="1",
+ config=llm_node_data,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
credentials_provider=mock_credentials_provider,
@@ -403,8 +399,8 @@ def test_dify_model_access_adapters_call_managers():
def test_fetch_files_with_file_segment():
file = File(
- id="1",
- type=FileType.IMAGE,
+ file_id="1",
+ file_type=FileType.IMAGE,
filename="test.jpg",
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="1",
@@ -420,16 +416,16 @@ def test_fetch_files_with_file_segment():
def test_fetch_files_with_array_file_segment():
files = [
File(
- id="1",
- type=FileType.IMAGE,
+ file_id="1",
+ file_type=FileType.IMAGE,
filename="test1.jpg",
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="1",
storage_key="",
),
File(
- id="2",
- type=FileType.IMAGE,
+ file_id="2",
+ file_type=FileType.IMAGE,
filename="test2.jpg",
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="2",
@@ -1174,14 +1170,10 @@ def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_stat
mock_credentials_provider = mock.MagicMock(spec=CredentialsProvider)
mock_model_factory = mock.MagicMock(spec=ModelFactory)
mock_prompt_message_serializer = mock.MagicMock(spec=PromptMessageSerializerProtocol)
- node_config = {
- "id": "1",
- "data": llm_node_data.model_dump(),
- }
http_client = mock.MagicMock()
node = LLMNode(
- id="1",
- config=node_config,
+ node_id="1",
+ config=llm_node_data,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
credentials_provider=mock_credentials_provider,
@@ -1203,8 +1195,8 @@ class TestLLMNodeSaveMultiModalImageOutput:
mime_type="image/png",
)
mock_file = File(
- id=str(uuid.uuid4()),
- type=FileType.IMAGE,
+ file_id=str(uuid.uuid4()),
+ file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.TOOL_FILE,
related_id=str(uuid.uuid4()),
filename="test-file.png",
@@ -1233,8 +1225,8 @@ class TestLLMNodeSaveMultiModalImageOutput:
mime_type="image/jpg",
)
mock_file = File(
- id=str(uuid.uuid4()),
- type=FileType.IMAGE,
+ file_id=str(uuid.uuid4()),
+ file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.TOOL_FILE,
related_id=str(uuid.uuid4()),
filename="test-file.png",
@@ -1291,8 +1283,8 @@ class TestSaveMultimodalOutputAndConvertResultToMarkdown:
image_b64_data = base64.b64encode(image_raw_data).decode()
mock_saved_file = File(
- id=str(uuid.uuid4()),
- type=FileType.IMAGE,
+ file_id=str(uuid.uuid4()),
+ file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.TOOL_FILE,
filename="test.png",
extension=".png",
@@ -1457,7 +1449,6 @@ def test_invoke_llm_dispatches_to_expected_model_method(structured_output_enable
file_saver=file_saver,
file_outputs=[],
node_id="node-1",
- node_type=LLMNode.node_type,
reasoning_format="separated",
)
)
@@ -1514,7 +1505,6 @@ def test_handle_invoke_result_streaming_collects_text_metrics_and_structured_out
file_saver=mock.MagicMock(spec=LLMFileSaver),
file_outputs=[],
node_id="node-1",
- node_type=LLMNode.node_type,
model_instance=_build_prepared_llm_mock(),
reasoning_format="separated",
request_start_time=1.0,
@@ -1552,7 +1542,6 @@ def test_handle_invoke_result_wraps_structured_output_parse_errors():
file_saver=mock.MagicMock(spec=LLMFileSaver),
file_outputs=[],
node_id="node-1",
- node_type=LLMNode.node_type,
model_instance=model_instance,
)
)
diff --git a/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py b/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py
index bc44ececd8..892f6cc586 100644
--- a/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py
+++ b/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py
@@ -13,6 +13,28 @@ from graphon.template_rendering import TemplateRenderError
from tests.workflow_test_utils import build_test_graph_init_params
+def _build_template_transform_node(
+ *,
+ node_data,
+ graph_init_params,
+ graph_runtime_state,
+ node_id: str = "test_node",
+ **kwargs,
+) -> TemplateTransformNode:
+ typed_node_data = (
+ node_data
+ if isinstance(node_data, TemplateTransformNodeData)
+ else TemplateTransformNodeData.model_validate(node_data)
+ )
+ return TemplateTransformNode(
+ node_id=node_id,
+ config=typed_node_data,
+ graph_init_params=graph_init_params,
+ graph_runtime_state=graph_runtime_state,
+ **kwargs,
+ )
+
+
class TestTemplateTransformNode:
"""Comprehensive test suite for TemplateTransformNode."""
@@ -59,9 +81,8 @@ class TestTemplateTransformNode:
def test_node_initialization(self, basic_node_data, mock_graph_runtime_state, graph_init_params):
"""Test that TemplateTransformNode initializes correctly."""
mock_renderer = MagicMock()
- node = TemplateTransformNode(
- id="test_node",
- config={"id": "test_node", "data": basic_node_data},
+ node = _build_template_transform_node(
+ node_data=basic_node_data,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
jinja2_template_renderer=mock_renderer,
@@ -75,9 +96,8 @@ class TestTemplateTransformNode:
def test_get_title(self, basic_node_data, mock_graph_runtime_state, graph_init_params):
"""Test _get_title method."""
mock_renderer = MagicMock()
- node = TemplateTransformNode(
- id="test_node",
- config={"id": "test_node", "data": basic_node_data},
+ node = _build_template_transform_node(
+ node_data=basic_node_data,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
jinja2_template_renderer=mock_renderer,
@@ -88,9 +108,8 @@ class TestTemplateTransformNode:
def test_get_description(self, basic_node_data, mock_graph_runtime_state, graph_init_params):
"""Test _get_description method."""
mock_renderer = MagicMock()
- node = TemplateTransformNode(
- id="test_node",
- config={"id": "test_node", "data": basic_node_data},
+ node = _build_template_transform_node(
+ node_data=basic_node_data,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
jinja2_template_renderer=mock_renderer,
@@ -108,9 +127,8 @@ class TestTemplateTransformNode:
}
mock_renderer = MagicMock()
- node = TemplateTransformNode(
- id="test_node",
- config={"id": "test_node", "data": node_data},
+ node = _build_template_transform_node(
+ node_data=node_data,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
jinja2_template_renderer=mock_renderer,
@@ -143,9 +161,8 @@ class TestTemplateTransformNode:
mock_renderer = MagicMock()
with pytest.raises(ValueError, match="max_output_length must be a positive integer"):
- TemplateTransformNode(
- id="test_node",
- config={"id": "test_node", "data": basic_node_data},
+ _build_template_transform_node(
+ node_data=basic_node_data,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
jinja2_template_renderer=mock_renderer,
@@ -170,9 +187,8 @@ class TestTemplateTransformNode:
mock_renderer = MagicMock()
mock_renderer.render_template.return_value = "Hello Alice, you are 30 years old!"
- node = TemplateTransformNode(
- id="test_node",
- config={"id": "test_node", "data": basic_node_data},
+ node = _build_template_transform_node(
+ node_data=basic_node_data,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
jinja2_template_renderer=mock_renderer,
@@ -198,9 +214,8 @@ class TestTemplateTransformNode:
mock_renderer = MagicMock()
mock_renderer.render_template.return_value = "Value: "
- node = TemplateTransformNode(
- id="test_node",
- config={"id": "test_node", "data": node_data},
+ node = _build_template_transform_node(
+ node_data=node_data,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
jinja2_template_renderer=mock_renderer,
@@ -218,9 +233,8 @@ class TestTemplateTransformNode:
mock_renderer = MagicMock()
mock_renderer.render_template.side_effect = TemplateRenderError("Template syntax error")
- node = TemplateTransformNode(
- id="test_node",
- config={"id": "test_node", "data": basic_node_data},
+ node = _build_template_transform_node(
+ node_data=basic_node_data,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
jinja2_template_renderer=mock_renderer,
@@ -238,9 +252,8 @@ class TestTemplateTransformNode:
mock_renderer = MagicMock()
mock_renderer.render_template.return_value = "This is a very long output that exceeds the limit"
- node = TemplateTransformNode(
- id="test_node",
- config={"id": "test_node", "data": basic_node_data},
+ node = _build_template_transform_node(
+ node_data=basic_node_data,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
jinja2_template_renderer=mock_renderer,
@@ -260,9 +273,8 @@ class TestTemplateTransformNode:
mock_renderer = MagicMock()
mock_renderer.render_template.return_value = "1234567890"
- node = TemplateTransformNode(
- id="test_node",
- config={"id": "test_node", "data": basic_node_data},
+ node = _build_template_transform_node(
+ node_data=basic_node_data,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
jinja2_template_renderer=mock_renderer,
@@ -302,9 +314,8 @@ class TestTemplateTransformNode:
mock_renderer = MagicMock()
mock_renderer.render_template.return_value = "apple, banana, orange (Total: 3)"
- node = TemplateTransformNode(
- id="test_node",
- config={"id": "test_node", "data": node_data},
+ node = _build_template_transform_node(
+ node_data=node_data,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
jinja2_template_renderer=mock_renderer,
@@ -375,8 +386,8 @@ class TestTemplateTransformNode:
)
assert mapping == {
- "node_123.var1": ["sys", "input1"],
- "node_123.empty_selector": [],
+ "node_123.var1": ("sys", "input1"),
+ "node_123.empty_selector": (),
}
def test_extract_variable_selector_to_variable_mapping_ignores_invalid_entries(self):
@@ -409,9 +420,8 @@ class TestTemplateTransformNode:
mock_renderer = MagicMock()
mock_renderer.render_template.return_value = "This is a static message."
- node = TemplateTransformNode(
- id="test_node",
- config={"id": "test_node", "data": node_data},
+ node = _build_template_transform_node(
+ node_data=node_data,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
jinja2_template_renderer=mock_renderer,
@@ -448,9 +458,8 @@ class TestTemplateTransformNode:
mock_renderer = MagicMock()
mock_renderer.render_template.return_value = "Total: $31.5"
- node = TemplateTransformNode(
- id="test_node",
- config={"id": "test_node", "data": node_data},
+ node = _build_template_transform_node(
+ node_data=node_data,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
jinja2_template_renderer=mock_renderer,
@@ -477,9 +486,8 @@ class TestTemplateTransformNode:
mock_renderer = MagicMock()
mock_renderer.render_template.return_value = "Name: John Doe, Email: john@example.com"
- node = TemplateTransformNode(
- id="test_node",
- config={"id": "test_node", "data": node_data},
+ node = _build_template_transform_node(
+ node_data=node_data,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
jinja2_template_renderer=mock_renderer,
@@ -507,9 +515,8 @@ class TestTemplateTransformNode:
mock_renderer = MagicMock()
mock_renderer.render_template.return_value = "Tags: #python #ai #workflow "
- node = TemplateTransformNode(
- id="test_node",
- config={"id": "test_node", "data": node_data},
+ node = _build_template_transform_node(
+ node_data=node_data,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
jinja2_template_renderer=mock_renderer,
diff --git a/api/tests/unit_tests/core/workflow/nodes/template_transform/test_template_transform_node.py b/api/tests/unit_tests/core/workflow/nodes/template_transform/test_template_transform_node.py
index 636237e56e..a846efbb43 100644
--- a/api/tests/unit_tests/core/workflow/nodes/template_transform/test_template_transform_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/template_transform/test_template_transform_node.py
@@ -4,6 +4,7 @@ import pytest
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
from graphon.nodes.base.entities import VariableSelector
+from graphon.nodes.template_transform.entities import TemplateTransformNodeData
from graphon.nodes.template_transform.template_transform_node import (
DEFAULT_TEMPLATE_TRANSFORM_MAX_OUTPUT_LENGTH,
TemplateTransformNode,
@@ -37,15 +38,13 @@ def mock_graph_runtime_state():
def test_node_uses_default_max_output_length_when_not_overridden(graph_init_params, mock_graph_runtime_state):
node = TemplateTransformNode(
- id="test_node",
- config={
- "id": "test_node",
- "data": {
- "title": "Template Transform",
- "variables": [],
- "template": "hello",
- },
- },
+ node_id="test_node",
+ config=TemplateTransformNodeData(
+ title="Template Transform",
+ type="template-transform",
+ variables=[],
+ template="hello",
+ ),
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
jinja2_template_renderer=MagicMock(),
@@ -70,5 +69,5 @@ def test_extract_variable_selector_to_variable_mapping_accepts_mixed_valid_entri
assert mapping == {
"node_123.validated": ["sys", "input1"],
- "node_123.raw": ["sys", "input2"],
+ "node_123.raw": ("sys", "input2"),
}
diff --git a/api/tests/unit_tests/core/workflow/nodes/test_base_node.py b/api/tests/unit_tests/core/workflow/nodes/test_base_node.py
index 0522dd9d14..364408ead6 100644
--- a/api/tests/unit_tests/core/workflow/nodes/test_base_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/test_base_node.py
@@ -7,7 +7,6 @@ from core.workflow.node_runtime import resolve_dify_run_context
from core.workflow.system_variables import build_system_variables
from graphon.entities import GraphInitParams
from graphon.entities.base_node_data import BaseNodeData
-from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
from graphon.enums import BuiltinNodeTypes
from graphon.nodes.base.node import Node
from graphon.runtime import GraphRuntimeState, VariablePool
@@ -42,17 +41,19 @@ def _build_context(graph_config: Mapping[str, object]) -> tuple[GraphInitParams,
return init_params, runtime_state
-def _build_node_config() -> NodeConfigDict:
- return NodeConfigDictAdapter.validate_python(
- {
- "id": "node-1",
- "data": {
- "type": BuiltinNodeTypes.ANSWER,
- "title": "Sample",
- "foo": "bar",
- },
- }
- )
+def _build_node_config() -> dict[str, object]:
+ return {
+ "id": "node-1",
+ "data": _SampleNodeData(
+ type=BuiltinNodeTypes.ANSWER,
+ title="Sample",
+ foo="bar",
+ ),
+ }
+
+
+def _build_node_data() -> _SampleNodeData:
+ return _build_node_config()["data"] # type: ignore[return-value]
def test_node_hydrates_data_during_initialization():
@@ -60,8 +61,8 @@ def test_node_hydrates_data_during_initialization():
init_params, runtime_state = _build_context(graph_config)
node = _SampleNode(
- id="node-1",
- config=_build_node_config(),
+ node_id="node-1",
+ config=_build_node_data(),
graph_init_params=init_params,
graph_runtime_state=runtime_state,
)
@@ -86,8 +87,8 @@ def test_node_accepts_invoke_from_enum():
)
node = _SampleNode(
- id="node-1",
- config=_build_node_config(),
+ node_id="node-1",
+ config=_build_node_data(),
graph_init_params=init_params,
graph_runtime_state=runtime_state,
)
@@ -117,13 +118,7 @@ def test_missing_generic_argument_raises_type_error():
def test_base_node_data_keeps_dict_style_access_compatibility():
- node_data = _SampleNodeData.model_validate(
- {
- "type": BuiltinNodeTypes.ANSWER,
- "title": "Sample",
- "foo": "bar",
- }
- )
+ node_data = _SampleNodeData(type=BuiltinNodeTypes.ANSWER, title="Sample", foo="bar")
assert node_data["foo"] == "bar"
assert node_data.get("foo") == "bar"
@@ -133,21 +128,19 @@ def test_base_node_data_keeps_dict_style_access_compatibility():
def test_node_hydration_preserves_compatibility_extra_fields():
graph_config: dict[str, object] = {}
init_params, runtime_state = _build_context(graph_config)
- node_config = NodeConfigDictAdapter.validate_python(
- {
- "id": "node-1",
- "data": {
- "type": BuiltinNodeTypes.ANSWER,
- "title": "Sample",
- "foo": "bar",
- "compat_flag": True,
- },
- }
- )
+ node_config = {
+ "id": "node-1",
+ "data": _SampleNodeData(
+ type=BuiltinNodeTypes.ANSWER,
+ title="Sample",
+ foo="bar",
+ compat_flag=True,
+ ),
+ }
node = _SampleNode(
- id="node-1",
- config=node_config,
+ node_id="node-1",
+ config=node_config["data"],
graph_init_params=init_params,
graph_runtime_state=runtime_state,
)
diff --git a/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py b/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py
index 87ec2d5bce..dd75b32593 100644
--- a/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py
@@ -11,14 +11,16 @@ from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus
from graphon.file import File, FileTransferMethod
from graphon.node_events import NodeRunResult
from graphon.nodes.document_extractor import DocumentExtractorNode, DocumentExtractorNodeData
+from graphon.nodes.document_extractor.exc import TextExtractionError, UnsupportedFileTypeError
from graphon.nodes.document_extractor.node import (
_extract_text_from_docx,
_extract_text_from_excel,
+ _extract_text_from_file,
_extract_text_from_pdf,
_extract_text_from_plain_text,
_normalize_docx_zip,
)
-from graphon.variables import ArrayFileSegment
+from graphon.variables import ArrayFileSegment, FileSegment
from graphon.variables.segments import ArrayStringSegment
from graphon.variables.variables import StringVariable
from tests.workflow_test_utils import build_test_graph_init_params
@@ -44,11 +46,10 @@ def document_extractor_node(graph_init_params):
title="Test Document Extractor",
variable_selector=["node_id", "variable_name"],
)
- node_config = {"id": "test_node_id", "data": node_data.model_dump()}
http_client = Mock()
node = DocumentExtractorNode(
- id="test_node_id",
- config=node_config,
+ node_id="test_node_id",
+ config=node_data,
graph_init_params=graph_init_params,
graph_runtime_state=Mock(),
http_client=http_client,
@@ -341,7 +342,7 @@ def test_extract_text_from_excel_sheet_parse_error(mock_excel_file):
# Mock ExcelFile
mock_excel_instance = Mock()
mock_excel_instance.sheet_names = ["GoodSheet", "BadSheet"]
- mock_excel_instance.parse.side_effect = [df, Exception("Parse error")]
+ mock_excel_instance.parse.side_effect = [df, TypeError("Parse error")]
mock_excel_file.return_value = mock_excel_instance
file_content = b"fake_excel_mixed_content"
@@ -386,7 +387,7 @@ def test_extract_text_from_excel_all_sheets_fail(mock_excel_file):
# Mock ExcelFile
mock_excel_instance = Mock()
mock_excel_instance.sheet_names = ["BadSheet1", "BadSheet2"]
- mock_excel_instance.parse.side_effect = [Exception("Error 1"), Exception("Error 2")]
+ mock_excel_instance.parse.side_effect = [TypeError("Error 1"), TypeError("Error 2")]
mock_excel_file.return_value = mock_excel_instance
file_content = b"fake_excel_all_bad_sheets"
@@ -397,6 +398,12 @@ def test_extract_text_from_excel_all_sheets_fail(mock_excel_file):
assert mock_excel_instance.parse.call_count == 2
+@patch("pandas.ExcelFile", side_effect=RuntimeError("broken workbook"))
+def test_extract_text_from_excel_wraps_workbook_open_errors(mock_excel_file):
+ with pytest.raises(TextExtractionError, match="Failed to extract text from Excel file: broken workbook"):
+ _extract_text_from_excel(b"broken")
+
+
@patch("pandas.ExcelFile")
def test_extract_text_from_excel_numeric_type_column(mock_excel_file):
"""Test extracting text from Excel file with numeric column names."""
@@ -420,6 +427,103 @@ def test_extract_text_from_excel_numeric_type_column(mock_excel_file):
assert expected_manual == result
+@pytest.mark.parametrize(
+ ("extension", "mime_type"),
+ [
+ (".xlsx", "text/plain"),
+ (None, "application/vnd.ms-excel"),
+ ],
+)
+def test_extract_text_from_file_routes_excel_inputs(document_extractor_node, extension, mime_type):
+ file = Mock(spec=File)
+ file.extension = extension
+ file.mime_type = mime_type
+
+ with (
+ patch(
+ "graphon.nodes.document_extractor.node._download_file_content",
+ return_value=b"excel",
+ ),
+ patch(
+ "graphon.nodes.document_extractor.node._extract_text_from_excel",
+ return_value="excel text",
+ ) as mock_extract,
+ ):
+ result = _extract_text_from_file(
+ document_extractor_node.http_client,
+ file,
+ unstructured_api_config=document_extractor_node._unstructured_api_config,
+ )
+
+ assert result == "excel text"
+ mock_extract.assert_called_once_with(b"excel")
+
+
+def test_extract_text_from_file_rejects_missing_extension_and_mime_type(document_extractor_node):
+ file = Mock(spec=File)
+ file.extension = None
+ file.mime_type = None
+
+ with patch(
+ "graphon.nodes.document_extractor.node._download_file_content",
+ return_value=b"unknown",
+ ):
+ with pytest.raises(UnsupportedFileTypeError, match="Unable to determine file type"):
+ _extract_text_from_file(
+ document_extractor_node.http_client,
+ file,
+ unstructured_api_config=document_extractor_node._unstructured_api_config,
+ )
+
+
+def test_run_list_file_extraction_error_returns_failed(document_extractor_node, mock_graph_runtime_state):
+ document_extractor_node.graph_runtime_state = mock_graph_runtime_state
+ file_list = Mock(spec=ArrayFileSegment)
+ file_list.value = [Mock(spec=File)]
+ mock_graph_runtime_state.variable_pool.get.return_value = file_list
+
+ with patch(
+ "graphon.nodes.document_extractor.node._extract_text_from_file",
+ side_effect=TextExtractionError("bad file"),
+ ):
+ result = document_extractor_node._run()
+
+ assert result.status == WorkflowNodeExecutionStatus.FAILED
+ assert result.error == "bad file"
+
+
+def test_run_single_file_segment_extraction_error_returns_failed(document_extractor_node, mock_graph_runtime_state):
+ document_extractor_node.graph_runtime_state = mock_graph_runtime_state
+ file_segment = Mock(spec=FileSegment)
+ file_segment.value = Mock(spec=File)
+ mock_graph_runtime_state.variable_pool.get.return_value = file_segment
+
+ with patch(
+ "graphon.nodes.document_extractor.node._extract_text_from_file",
+ side_effect=TextExtractionError("single file failed"),
+ ):
+ result = document_extractor_node._run()
+
+ assert result.status == WorkflowNodeExecutionStatus.FAILED
+ assert result.error == "single file failed"
+
+
+def test_run_single_file_segment_returns_string_output(document_extractor_node, mock_graph_runtime_state):
+ document_extractor_node.graph_runtime_state = mock_graph_runtime_state
+ file_segment = Mock(spec=FileSegment)
+ file_segment.value = Mock(spec=File)
+ mock_graph_runtime_state.variable_pool.get.return_value = file_segment
+
+ with patch(
+ "graphon.nodes.document_extractor.node._extract_text_from_file",
+ return_value="single file text",
+ ):
+ result = document_extractor_node._run()
+
+ assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
+ assert result.outputs == {"text": "single file text"}
+
+
def _make_docx_zip(use_backslash: bool) -> bytes:
"""Helper to build a minimal in-memory DOCX zip.
diff --git a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py
index 782750e02e..aa9a1360b0 100644
--- a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py
+++ b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py
@@ -19,6 +19,20 @@ from graphon.variables import ArrayFileSegment
from tests.workflow_test_utils import build_test_graph_init_params
+def _build_if_else_node(
+ *,
+ node_data: IfElseNodeData | dict[str, object],
+ init_params,
+ graph_runtime_state,
+) -> IfElseNode:
+ return IfElseNode(
+ node_id=str(uuid.uuid4()),
+ graph_init_params=init_params,
+ graph_runtime_state=graph_runtime_state,
+ config=node_data if isinstance(node_data, IfElseNodeData) else IfElseNodeData.model_validate(node_data),
+ )
+
+
def test_execute_if_else_result_true():
graph_config = {"edges": [], "nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}]}
@@ -61,9 +75,8 @@ def test_execute_if_else_result_true():
)
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start")
- node_config = {
- "id": "if-else",
- "data": {
+ node = _build_if_else_node(
+ node_data={
"title": "123",
"type": "if-else",
"logical_operator": "and",
@@ -104,13 +117,8 @@ def test_execute_if_else_result_true():
{"comparison_operator": "not null", "variable_selector": ["start", "not_null"]},
],
},
- }
-
- node = IfElseNode(
- id=str(uuid.uuid4()),
- graph_init_params=init_params,
+ init_params=init_params,
graph_runtime_state=graph_runtime_state,
- config=node_config,
)
# Mock db.session.close()
@@ -155,9 +163,8 @@ def test_execute_if_else_result_false():
)
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start")
- node_config = {
- "id": "if-else",
- "data": {
+ node = _build_if_else_node(
+ node_data={
"title": "123",
"type": "if-else",
"logical_operator": "or",
@@ -174,13 +181,8 @@ def test_execute_if_else_result_false():
},
],
},
- }
-
- node = IfElseNode(
- id=str(uuid.uuid4()),
- graph_init_params=init_params,
+ init_params=init_params,
graph_runtime_state=graph_runtime_state,
- config=node_config,
)
# Mock db.session.close()
@@ -222,11 +224,6 @@ def test_array_file_contains_file_name():
],
)
- node_config = {
- "id": "if-else",
- "data": node_data.model_dump(),
- }
-
# Create properly configured mock for graph_init_params
graph_init_params = Mock()
graph_init_params.workflow_id = "test_workflow"
@@ -242,17 +239,12 @@ def test_array_file_contains_file_name():
}
}
- node = IfElseNode(
- id=str(uuid.uuid4()),
- graph_init_params=graph_init_params,
- graph_runtime_state=Mock(),
- config=node_config,
- )
+ node = _build_if_else_node(node_data=node_data, init_params=graph_init_params, graph_runtime_state=Mock())
node.graph_runtime_state.variable_pool.get.return_value = ArrayFileSegment(
value=[
File(
- type=FileType.IMAGE,
+ file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="1",
filename="ab",
@@ -334,11 +326,10 @@ def test_execute_if_else_boolean_conditions(condition: Condition):
"logical_operator": "and",
"conditions": [condition.model_dump()],
}
- node = IfElseNode(
- id=str(uuid.uuid4()),
- graph_init_params=init_params,
+ node = _build_if_else_node(
+ node_data=node_data,
+ init_params=init_params,
graph_runtime_state=graph_runtime_state,
- config={"id": "if-else", "data": node_data},
)
# Mock db.session.close()
@@ -400,14 +391,10 @@ def test_execute_if_else_boolean_false_conditions():
],
}
- node = IfElseNode(
- id=str(uuid.uuid4()),
- graph_init_params=init_params,
+ node = _build_if_else_node(
+ node_data=node_data,
+ init_params=init_params,
graph_runtime_state=graph_runtime_state,
- config={
- "id": "if-else",
- "data": node_data,
- },
)
# Mock db.session.close()
@@ -472,11 +459,10 @@ def test_execute_if_else_boolean_cases_structure():
}
],
}
- node = IfElseNode(
- id=str(uuid.uuid4()),
- graph_init_params=init_params,
+ node = _build_if_else_node(
+ node_data=node_data,
+ init_params=init_params,
graph_runtime_state=graph_runtime_state,
- config={"id": "if-else", "data": node_data},
)
# Mock db.session.close()
diff --git a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py
index b217e4e8e7..465a4c0ff4 100644
--- a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py
+++ b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py
@@ -19,6 +19,15 @@ from graphon.nodes.list_operator.node import ListOperatorNode, _get_file_extract
from graphon.variables import ArrayFileSegment
+def _build_list_operator_node(node_data: ListOperatorNodeData, graph_init_params) -> ListOperatorNode:
+ return ListOperatorNode(
+ node_id="test_node_id",
+ config=node_data,
+ graph_init_params=graph_init_params,
+ graph_runtime_state=MagicMock(),
+ )
+
+
@pytest.fixture
def list_operator_node():
config = {
@@ -35,10 +44,6 @@ def list_operator_node():
"title": "Test Title",
}
node_data = ListOperatorNodeData.model_validate(config)
- node_config = {
- "id": "test_node_id",
- "data": node_data.model_dump(),
- }
# Create properly configured mock for graph_init_params
graph_init_params = MagicMock()
graph_init_params.workflow_id = "test_workflow"
@@ -54,12 +59,7 @@ def list_operator_node():
}
}
- node = ListOperatorNode(
- id="test_node_id",
- config=node_config,
- graph_init_params=graph_init_params,
- graph_runtime_state=MagicMock(),
- )
+ node = _build_list_operator_node(node_data, graph_init_params)
node.graph_runtime_state = MagicMock()
node.graph_runtime_state.variable_pool = MagicMock()
return node
@@ -70,28 +70,28 @@ def test_filter_files_by_type(list_operator_node):
files = [
File(
filename="image1.jpg",
- type=FileType.IMAGE,
+ file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="related1",
storage_key="",
),
File(
filename="document1.pdf",
- type=FileType.DOCUMENT,
+ file_type=FileType.DOCUMENT,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="related2",
storage_key="",
),
File(
filename="image2.png",
- type=FileType.IMAGE,
+ file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="related3",
storage_key="",
),
File(
filename="audio1.mp3",
- type=FileType.AUDIO,
+ file_type=FileType.AUDIO,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="related4",
storage_key="",
@@ -136,7 +136,7 @@ def test_filter_files_by_type(list_operator_node):
def test_get_file_extract_string_func():
# Create a File object
file = File(
- type=FileType.DOCUMENT,
+ file_type=FileType.DOCUMENT,
transfer_method=FileTransferMethod.LOCAL_FILE,
filename="test_file.txt",
extension=".txt",
@@ -156,7 +156,7 @@ def test_get_file_extract_string_func():
# Test with empty values
empty_file = File(
- type=FileType.DOCUMENT,
+ file_type=FileType.DOCUMENT,
transfer_method=FileTransferMethod.LOCAL_FILE,
filename=None,
extension=None,
diff --git a/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py b/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py
index 543f9878de..5655f80737 100644
--- a/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py
+++ b/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py
@@ -22,10 +22,7 @@ def make_start_node(user_inputs, variables):
inputs=user_inputs,
)
- config = {
- "id": "start",
- "data": StartNodeData(title="Start", variables=variables).model_dump(),
- }
+ node_data = StartNodeData(title="Start", variables=variables)
graph_runtime_state = GraphRuntimeState(
variable_pool=variable_pool,
@@ -33,8 +30,8 @@ def make_start_node(user_inputs, variables):
)
return StartNode(
- id="start",
- config=config,
+ node_id="start",
+ config=node_data,
graph_init_params=build_test_graph_init_params(
workflow_id="wf",
graph_config={},
@@ -109,7 +106,7 @@ def test_json_object_invalid_json_string():
node = make_start_node(user_inputs, variables)
- with pytest.raises(ValueError, match="JSON object for 'profile' must be an object"):
+ with pytest.raises(TypeError, match="JSON object for 'profile' must be an object"):
node._run()
@@ -248,25 +245,22 @@ def test_start_node_outputs_full_variable_pool_snapshot():
inputs={"profile": {"age": 20, "name": "Tom"}},
)
- config = {
- "id": "start",
- "data": StartNodeData(
- title="Start",
- variables=[
- VariableEntity(
- variable="profile",
- label="profile",
- type=VariableEntityType.JSON_OBJECT,
- required=True,
- )
- ],
- ).model_dump(),
- }
+ node_data = StartNodeData(
+ title="Start",
+ variables=[
+ VariableEntity(
+ variable="profile",
+ label="profile",
+ type=VariableEntityType.JSON_OBJECT,
+ required=True,
+ )
+ ],
+ )
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
node = StartNode(
- id="start",
- config=config,
+ node_id="start",
+ config=node_data,
graph_init_params=build_test_graph_init_params(
workflow_id="wf",
graph_config={},
diff --git a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py
index c806181340..284af68319 100644
--- a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py
@@ -13,6 +13,7 @@ from core.workflow.system_variables import build_system_variables
from graphon.file import File, FileTransferMethod, FileType
from graphon.model_runtime.entities.llm_entities import LLMUsage
from graphon.node_events import StreamChunkEvent, StreamCompletedEvent
+from graphon.nodes.tool.entities import ToolNodeData
from graphon.nodes.tool_runtime_entities import ToolRuntimeHandle, ToolRuntimeMessage
from graphon.runtime import GraphRuntimeState, VariablePool
from graphon.variables.segments import ArrayFileSegment
@@ -108,8 +109,8 @@ def tool_node(monkeypatch) -> ToolNode:
runtime = _StubToolRuntime()
node = ToolNode(
- id="node-instance",
- config=config,
+ node_id="node-instance",
+ config=ToolNodeData.model_validate(config["data"]),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
tool_file_manager_factory=tool_file_manager_factory,
@@ -118,13 +119,13 @@ def tool_node(monkeypatch) -> ToolNode:
return node
-def _collect_events(generator: Generator) -> tuple[list[Any], LLMUsage]:
+def _collect_events(generator: Generator) -> list[Any]:
events: list[Any] = []
try:
while True:
events.append(next(generator))
- except StopIteration as stop:
- return events, stop.value
+ except StopIteration:
+ return events
def _run_transform(tool_node: ToolNode, message: ToolRuntimeMessage) -> tuple[list[Any], LLMUsage]:
@@ -135,12 +136,15 @@ def _run_transform(tool_node: ToolNode, message: ToolRuntimeMessage) -> tuple[li
node_id=tool_node._node_id,
tool_runtime=ToolRuntimeHandle(raw=object()),
)
- return _collect_events(generator)
+ events = _collect_events(generator)
+ completed_events = [event for event in events if isinstance(event, StreamCompletedEvent)]
+ assert completed_events
+ return events, completed_events[-1].node_run_result.llm_usage
def test_link_messages_with_file_populate_files_output(tool_node: ToolNode):
file_obj = File(
- type=FileType.DOCUMENT,
+ file_type=FileType.DOCUMENT,
transfer_method=FileTransferMethod.TOOL_FILE,
related_id="file-id",
filename="demo.pdf",
@@ -195,7 +199,7 @@ def test_plain_link_messages_remain_links(tool_node: ToolNode):
def test_image_link_messages_use_tool_file_id_metadata(tool_node: ToolNode):
file_obj = File(
- type=FileType.DOCUMENT,
+ file_type=FileType.DOCUMENT,
transfer_method=FileTransferMethod.TOOL_FILE,
related_id="file-id",
filename="demo.pdf",
diff --git a/api/tests/unit_tests/core/workflow/nodes/trigger_plugin/test_trigger_event_node.py b/api/tests/unit_tests/core/workflow/nodes/trigger_plugin/test_trigger_event_node.py
index c8ddc53284..e3b5e3b591 100644
--- a/api/tests/unit_tests/core/workflow/nodes/trigger_plugin/test_trigger_event_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/trigger_plugin/test_trigger_event_node.py
@@ -1,10 +1,10 @@
from collections.abc import Mapping
from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE
+from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData
from core.workflow.nodes.trigger_plugin.trigger_event_node import TriggerEventNode
from core.workflow.system_variables import build_system_variables
from graphon.entities import GraphInitParams
-from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from graphon.runtime import GraphRuntimeState
from tests.workflow_test_utils import build_test_graph_init_params, build_test_variable_pool
@@ -27,29 +27,24 @@ def _build_context(graph_config: Mapping[str, object]) -> tuple[GraphInitParams,
return init_params, runtime_state
-def _build_node_config() -> NodeConfigDict:
- return NodeConfigDictAdapter.validate_python(
- {
- "id": "node-1",
- "data": {
- "type": TRIGGER_PLUGIN_NODE_TYPE,
- "title": "Trigger Event",
- "plugin_id": "plugin-id",
- "provider_id": "provider-id",
- "event_name": "event-name",
- "subscription_id": "subscription-id",
- "plugin_unique_identifier": "plugin-unique-identifier",
- "event_parameters": {},
- },
- }
+def _build_node_data() -> TriggerEventNodeData:
+ return TriggerEventNodeData(
+ type=TRIGGER_PLUGIN_NODE_TYPE,
+ title="Trigger Event",
+ plugin_id="plugin-id",
+ provider_id="provider-id",
+ event_name="event-name",
+ subscription_id="subscription-id",
+ plugin_unique_identifier="plugin-unique-identifier",
+ event_parameters={},
)
def test_trigger_event_node_run_populates_trigger_info_metadata() -> None:
init_params, runtime_state = _build_context(graph_config={})
node = TriggerEventNode(
- id="node-1",
- config=_build_node_config(),
+ node_id="node-1",
+ config=_build_node_data(),
graph_init_params=init_params,
graph_runtime_state=runtime_state,
)
diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py
index 1bbc12b23f..07d03bec05 100644
--- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py
+++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py
@@ -30,11 +30,6 @@ def create_webhook_node(
tenant_id: str = "test-tenant",
) -> TriggerWebhookNode:
"""Helper function to create a webhook node with proper initialization."""
- node_config = {
- "id": "webhook-node-1",
- "data": webhook_data.model_dump(),
- }
-
graph_init_params = GraphInitParams(
workflow_id="test-workflow",
graph_config={},
@@ -56,8 +51,8 @@ def create_webhook_node(
)
node = TriggerWebhookNode(
- id="webhook-node-1",
- config=node_config,
+ node_id="webhook-node-1",
+ config=webhook_data,
graph_init_params=graph_init_params,
graph_runtime_state=runtime_state,
)
@@ -66,10 +61,6 @@ def create_webhook_node(
runtime_state.app_config = Mock()
runtime_state.app_config.tenant_id = tenant_id
- # Provide compatibility alias expected by node implementation
- # Some nodes reference `self.node_id`; expose it as an alias to `self.id` for tests
- node.node_id = node.id
-
return node
diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py
index 427afa96ec..b839490d3c 100644
--- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py
@@ -24,11 +24,6 @@ from tests.workflow_test_utils import build_test_variable_pool
def create_webhook_node(webhook_data: WebhookData, variable_pool: VariablePool) -> TriggerWebhookNode:
"""Helper function to create a webhook node with proper initialization."""
- node_config = {
- "id": "1",
- "data": webhook_data.model_dump(),
- }
-
graph_init_params = GraphInitParams(
workflow_id="1",
graph_config={},
@@ -48,8 +43,8 @@ def create_webhook_node(webhook_data: WebhookData, variable_pool: VariablePool)
start_at=0,
)
node = TriggerWebhookNode(
- id="1",
- config=node_config,
+ node_id="1",
+ config=webhook_data,
graph_init_params=graph_init_params,
graph_runtime_state=runtime_state,
)
@@ -57,9 +52,6 @@ def create_webhook_node(webhook_data: WebhookData, variable_pool: VariablePool)
# Provide tenant_id for conversion path
runtime_state.app_config = type("_AppCfg", (), {"tenant_id": "1"})()
- # Compatibility alias for some nodes referencing `self.node_id`
- node.node_id = node.id
-
return node
@@ -225,7 +217,7 @@ def test_webhook_node_run_with_file_params():
"""Test webhook node execution with file parameter extraction."""
# Create mock file objects
file1 = File(
- type=FileType.IMAGE,
+ file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="file1",
filename="image.jpg",
@@ -234,7 +226,7 @@ def test_webhook_node_run_with_file_params():
)
file2 = File(
- type=FileType.DOCUMENT,
+ file_type=FileType.DOCUMENT,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="file2",
filename="document.pdf",
@@ -269,8 +261,19 @@ def test_webhook_node_run_with_file_params():
# Mock the node's file reference boundary to avoid DB-dependent validation on upload_file_id
with patch.object(node._file_reference_factory, "build_from_mapping") as mock_file_factory:
- def _to_file(*, mapping):
- return File.model_validate(mapping)
+ def _to_file(*, mapping: dict[str, Any]) -> File:
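+            # Build the File field-by-field so the keyword names match the
+            # updated model signature (file_type rather than type) without
+            # going through model_validate.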
+ return File(
+ file_id=mapping.get("id"),
+ file_type=FileType(mapping["type"]),
+ transfer_method=FileTransferMethod(mapping["transfer_method"]),
+ related_id=mapping.get("related_id"),
+ filename=mapping.get("filename"),
+ extension=mapping.get("extension"),
+ mime_type=mapping.get("mime_type"),
+ size=mapping.get("size", -1),
+ storage_key=mapping.get("storage_key", ""),
+ remote_url=mapping.get("url"),
+ )
mock_file_factory.side_effect = _to_file
result = node._run()
@@ -284,7 +287,7 @@ def test_webhook_node_run_with_file_params():
def test_webhook_node_run_mixed_parameters():
"""Test webhook node execution with mixed parameter types."""
file_obj = File(
- type=FileType.IMAGE,
+ file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="file1",
filename="test.jpg",
@@ -317,8 +320,19 @@ def test_webhook_node_run_mixed_parameters():
# Mock the node's file reference boundary to avoid DB-dependent validation on upload_file_id
with patch.object(node._file_reference_factory, "build_from_mapping") as mock_file_factory:
- def _to_file(*, mapping):
- return File.model_validate(mapping)
+ def _to_file(*, mapping: dict[str, Any]) -> File:
+ return File(
+ file_id=mapping.get("id"),
+ file_type=FileType(mapping["type"]),
+ transfer_method=FileTransferMethod(mapping["transfer_method"]),
+ related_id=mapping.get("related_id"),
+ filename=mapping.get("filename"),
+ extension=mapping.get("extension"),
+ mime_type=mapping.get("mime_type"),
+ size=mapping.get("size", -1),
+ storage_key=mapping.get("storage_key", ""),
+ remote_url=mapping.get("url"),
+ )
mock_file_factory.side_effect = _to_file
result = node._run()
diff --git a/api/tests/unit_tests/core/workflow/test_human_input_adapter.py b/api/tests/unit_tests/core/workflow/test_human_input_adapter.py
new file mode 100644
index 0000000000..8b5fceeb37
--- /dev/null
+++ b/api/tests/unit_tests/core/workflow/test_human_input_adapter.py
@@ -0,0 +1,350 @@
+from types import SimpleNamespace
+
+import pytest
+from pydantic import BaseModel
+
+from core.workflow.human_input_adapter import (
+ DeliveryMethodType,
+ EmailDeliveryConfig,
+ EmailDeliveryMethod,
+ EmailRecipients,
+ WebAppDeliveryMethod,
+ _WebAppDeliveryConfig,
+ adapt_human_input_node_data_for_graph,
+ adapt_node_config_for_graph,
+ adapt_node_data_for_graph,
+ is_human_input_webapp_enabled,
+ parse_human_input_delivery_methods,
+)
+from graphon.enums import BuiltinNodeTypes
+from graphon.nodes.base.variable_template_parser import VariableTemplateParser
+
+
+def test_email_delivery_config_helpers_render_and_sanitize_text() -> None:
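+    # Minimal variable-pool stand-in: the helper under test only calls
+    # convert_template() and reads `.text` from the returned object.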
+ variable_pool = SimpleNamespace(
+ convert_template=lambda body: SimpleNamespace(text=body.replace("{{#node.value#}}", "42"))
+ )
+
+ rendered = EmailDeliveryConfig.render_body_template(
+ body="Open {{#url#}} and use {{#node.value#}}",
+ url="https://example.com",
+ variable_pool=variable_pool,
+ )
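+    # Per the assertions below, sanitize_subject strips tags and CR/LF while
+    # keeping the inner text, and render_markdown_body emits HTML.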
+    sanitized = EmailDeliveryConfig.sanitize_subject("Hello\r\n<script>alert(1)</script> Team")
+ html = EmailDeliveryConfig.render_markdown_body(
+ "**Hello** [mail](mailto:test@example.com)"
+ )
+
+ assert rendered == "Open https://example.com and use 42"
+ assert sanitized == "Hello alert(1) Team"
+ assert "Hello " in html
+ assert " Team")