diff --git a/.agents/skills/e2e-cucumber-playwright/SKILL.md b/.agents/skills/e2e-cucumber-playwright/SKILL.md new file mode 100644 index 00000000000..de6b58f26df --- /dev/null +++ b/.agents/skills/e2e-cucumber-playwright/SKILL.md @@ -0,0 +1,79 @@ +--- +name: e2e-cucumber-playwright +description: Write, update, or review Dify end-to-end tests under `e2e/` that use Cucumber, Gherkin, and Playwright. Use when the task involves `.feature` files, `features/step-definitions/`, `features/support/`, `DifyWorld`, scenario tags, locator/assertion choices, or E2E testing best practices for this repository. +--- + +# Dify E2E Cucumber + Playwright + +Use this skill for Dify's repository-level E2E suite in `e2e/`. Use [`e2e/AGENTS.md`](../../../e2e/AGENTS.md) as the canonical guide for local architecture and conventions, then apply Playwright/Cucumber best practices only where they fit the current suite. + +## Scope + +- Use this skill for `.feature` files, Cucumber step definitions, `DifyWorld`, hooks, tags, and E2E review work under `e2e/`. +- Do not use this skill for Vitest or React Testing Library work under `web/`; use `frontend-testing` instead. +- Do not use this skill for backend test or API review tasks under `api/`. + +## Read Order + +1. Read [`e2e/AGENTS.md`](../../../e2e/AGENTS.md) first. +2. Read only the files directly involved in the task: + - target `.feature` files under `e2e/features/` + - related step files under `e2e/features/step-definitions/` + - `e2e/features/support/hooks.ts` and `e2e/features/support/world.ts` when session lifecycle or shared state matters + - `e2e/scripts/run-cucumber.ts` and `e2e/cucumber.config.ts` when tags or execution flow matter +3. Read [`references/playwright-best-practices.md`](references/playwright-best-practices.md) only when locator, assertion, isolation, or waiting choices are involved. +4. Read [`references/cucumber-best-practices.md`](references/cucumber-best-practices.md) only when scenario wording, step granularity, tags, or expression design are involved. +5. Re-check official docs with Context7 before introducing a new Playwright or Cucumber pattern. + +## Local Rules + +- `e2e/` uses Cucumber for scenarios and Playwright as the browser layer. +- `DifyWorld` is the per-scenario context object. Type `this` as `DifyWorld` and use `async function`, not arrow functions. +- Keep glue organized by capability under `e2e/features/step-definitions/`; use `common/` only for broadly reusable steps. +- Browser session behavior comes from `features/support/hooks.ts`: + - default: authenticated session with shared storage state + - `@unauthenticated`: clean browser context + - `@authenticated`: readability/selective-run tag only unless implementation changes + - `@fresh`: only for `e2e:full*` flows +- Do not import Playwright Test runner patterns that bypass the current Cucumber + `DifyWorld` architecture unless the task is explicitly about changing that architecture. + +## Workflow + +1. Rebuild local context. + - Inspect the target feature area. + - Reuse an existing step when wording and behavior already match. + - Add a new step only for a genuinely new user action or assertion. + - Keep edits close to the current capability folder unless the step is broadly reusable. +2. Write behavior-first scenarios. + - Describe user-observable behavior, not DOM mechanics. + - Keep each scenario focused on one workflow or outcome. + - Keep scenarios independent and re-runnable. +3. Write step definitions in the local style. + - Keep one step to one user-visible action or one assertion. + - Prefer Cucumber Expressions such as `{string}` and `{int}`. + - Scope locators to stable containers when the page has repeated elements. + - Avoid page-object layers or extra helper abstractions unless repeated complexity clearly justifies them. +4. Use Playwright in the local style. + - Prefer user-facing locators: `getByRole`, `getByLabel`, `getByPlaceholder`, `getByText`, then `getByTestId` for explicit contracts. + - Use web-first `expect(...)` assertions. + - Do not use `waitForTimeout`, manual polling, or raw visibility checks when a locator action or retrying assertion already expresses the behavior. +5. Validate narrowly. + - Run the narrowest tagged scenario or flow that exercises the change. + - Run `pnpm -C e2e check`. + - Broaden verification only when the change affects hooks, tags, setup, or shared step semantics. + +## Review Checklist + +- Does the scenario describe behavior rather than implementation? +- Does it fit the current session model, tags, and `DifyWorld` usage? +- Should an existing step be reused instead of adding a new one? +- Are locators user-facing and assertions web-first? +- Does the change introduce hidden coupling across scenarios, tags, or instance state? +- Does it document or implement behavior that differs from the real hooks or configuration? + +Lead findings with correctness, flake risk, and architecture drift. + +## References + +- [`references/playwright-best-practices.md`](references/playwright-best-practices.md) +- [`references/cucumber-best-practices.md`](references/cucumber-best-practices.md) diff --git a/.agents/skills/e2e-cucumber-playwright/agents/openai.yaml b/.agents/skills/e2e-cucumber-playwright/agents/openai.yaml new file mode 100644 index 00000000000..605cce041d9 --- /dev/null +++ b/.agents/skills/e2e-cucumber-playwright/agents/openai.yaml @@ -0,0 +1,4 @@ +interface: + display_name: "E2E Cucumber + Playwright" + short_description: "Write and review Dify E2E scenarios." + default_prompt: "Use $e2e-cucumber-playwright to write or review a Dify E2E scenario under e2e/." diff --git a/.agents/skills/e2e-cucumber-playwright/references/cucumber-best-practices.md b/.agents/skills/e2e-cucumber-playwright/references/cucumber-best-practices.md new file mode 100644 index 00000000000..d7a1a52852b --- /dev/null +++ b/.agents/skills/e2e-cucumber-playwright/references/cucumber-best-practices.md @@ -0,0 +1,93 @@ +# Cucumber Best Practices For Dify E2E + +Use this reference when writing or reviewing Gherkin scenarios, step definitions, parameter expressions, and step reuse in Dify's `e2e/` suite. + +Official sources: + +- https://cucumber.io/docs/guides/10-minute-tutorial/ +- https://cucumber.io/docs/cucumber/step-definitions/ +- https://cucumber.io/docs/cucumber/cucumber-expressions/ + +## What Matters Most + +### 1. Treat scenarios as executable specifications + +Cucumber scenarios should describe examples of behavior, not test implementation recipes. + +Apply it like this: + +- write what the user does and what should happen +- avoid UI-internal wording such as selector details, DOM structure, or component names +- keep language concrete enough that the scenario reads like living documentation + +### 2. Keep scenarios focused + +A scenario should usually prove one workflow or business outcome. If a scenario wanders across several unrelated behaviors, split it. + +In Dify's suite, this means: + +- one capability-focused scenario per feature path +- no long setup chains when existing bootstrap or reusable steps already cover them +- no hidden dependency on another scenario's side effects + +### 3. Reuse steps, but only when behavior really matches + +Good reuse reduces duplication. Bad reuse hides meaning. + +Prefer reuse when: + +- the user action is genuinely the same +- the expected outcome is genuinely the same +- the wording stays natural across features + +Write a new step when: + +- the behavior is materially different +- reusing the old wording would make the scenario misleading +- a supposedly generic step would become an implementation-detail wrapper + +### 4. Prefer Cucumber Expressions + +Use Cucumber Expressions for parameters unless regex is clearly necessary. + +Common examples: + +- `{string}` for labels, names, and visible text +- `{int}` for counts +- `{float}` for decimal values +- `{word}` only when the value is truly a single token + +Keep expressions readable. If a step needs complicated parsing logic, first ask whether the scenario wording should be simpler. + +### 5. Keep step definitions thin and meaningful + +Step definitions are glue between Gherkin and automation, not a second abstraction language. + +For Dify: + +- type `this` as `DifyWorld` +- use `async function` +- keep each step to one user-visible action or assertion +- rely on `DifyWorld` and existing support code for shared context +- avoid leaking cross-scenario state + +### 6. Use tags intentionally + +Tags should communicate run scope or session semantics, not become ad hoc metadata. + +In Dify's current suite: + +- capability tags group related scenarios +- `@unauthenticated` changes session behavior +- `@authenticated` is descriptive/selective, not a behavior switch by itself +- `@fresh` belongs to reset/full-install flows only + +If a proposed tag implies behavior, verify that hooks or runner configuration actually implement it. + +## Review Questions + +- Does the scenario read like a real example of product behavior? +- Are the steps behavior-oriented instead of implementation-oriented? +- Is a reused step still truthful in this feature? +- Is a new tag documenting real behavior, or inventing semantics that the suite does not implement? +- Would a new reader understand the outcome without opening the step-definition file? diff --git a/.agents/skills/e2e-cucumber-playwright/references/playwright-best-practices.md b/.agents/skills/e2e-cucumber-playwright/references/playwright-best-practices.md new file mode 100644 index 00000000000..02e763d46b6 --- /dev/null +++ b/.agents/skills/e2e-cucumber-playwright/references/playwright-best-practices.md @@ -0,0 +1,96 @@ +# Playwright Best Practices For Dify E2E + +Use this reference when writing or reviewing locator, assertion, isolation, or synchronization logic for Dify's Cucumber-based E2E suite. + +Official sources: + +- https://playwright.dev/docs/best-practices +- https://playwright.dev/docs/locators +- https://playwright.dev/docs/test-assertions +- https://playwright.dev/docs/browser-contexts + +## What Matters Most + +### 1. Keep scenarios isolated + +Playwright's model is built around clean browser contexts so one test does not leak into another. In Dify's suite, that principle maps to per-scenario session setup in `features/support/hooks.ts` and `DifyWorld`. + +Apply it like this: + +- do not depend on another scenario having run first +- do not persist ad hoc scenario state outside `DifyWorld` +- do not couple ordinary scenarios to `@fresh` behavior +- when a flow needs special auth/session semantics, express that through the existing tag model or explicit hook changes + +### 2. Prefer user-facing locators + +Playwright recommends built-in locators that reflect what users perceive on the page. + +Preferred order in this repository: + +1. `getByRole` +2. `getByLabel` +3. `getByPlaceholder` +4. `getByText` +5. `getByTestId` when an explicit test contract is the most stable option + +Avoid raw CSS/XPath selectors unless no stable user-facing contract exists and adding one is not practical. + +Also remember: + +- repeated content usually needs scoping to a stable container +- exact text matching is often too brittle when role/name or label already exists +- `getByTestId` is acceptable when semantics are weak but the contract is intentional + +### 3. Use web-first assertions + +Playwright assertions auto-wait and retry. Prefer them over manual state inspection. + +Prefer: + +- `await expect(page).toHaveURL(...)` +- `await expect(locator).toBeVisible()` +- `await expect(locator).toBeHidden()` +- `await expect(locator).toBeEnabled()` +- `await expect(locator).toHaveText(...)` + +Avoid: + +- `expect(await locator.isVisible()).toBe(true)` +- custom polling loops for DOM state +- `waitForTimeout` as synchronization + +If a condition genuinely needs custom retry logic, use Playwright's polling/assertion tools deliberately and keep that choice local and explicit. + +### 4. Let actions wait for actionability + +Locator actions already wait for the element to be actionable. Do not preface every click/fill with extra timing logic unless the action needs a specific visible/ready assertion for clarity. + +Good pattern: + +- assert a meaningful visible state when that is part of the behavior +- then click/fill/select via locator APIs + +Bad pattern: + +- stack arbitrary waits before every action +- wait on unstable implementation details instead of the visible state the user cares about + +### 5. Match debugging to the current suite + +Playwright's wider ecosystem supports traces and rich debugging tools. Dify's current suite already captures: + +- full-page screenshots +- page HTML +- console errors +- page errors + +Use the existing artifact flow by default. If a task is specifically about improving diagnostics, confirm the change fits the current Cucumber architecture before importing broader Playwright tooling. + +## Review Questions + +- Would this locator survive DOM refactors that do not change user-visible behavior? +- Is this assertion using Playwright's retrying semantics? +- Is any explicit wait masking a real readiness problem? +- Does this code preserve per-scenario isolation? +- Is a new abstraction really needed, or does it bypass the existing `DifyWorld` + step-definition model? diff --git a/.agents/skills/frontend-testing/SKILL.md b/.agents/skills/frontend-testing/SKILL.md index 4da070bdbfe..105c979c585 100644 --- a/.agents/skills/frontend-testing/SKILL.md +++ b/.agents/skills/frontend-testing/SKILL.md @@ -200,7 +200,7 @@ When assigned to test a directory/path, test **ALL content** within that path: - ✅ **Import real project components** directly (including base components and siblings) - ✅ **Only mock**: API services (`@/service/*`), `next/navigation`, complex context providers -- ❌ **DO NOT mock** base components (`@/app/components/base/*`) +- ❌ **DO NOT mock** base components (`@/app/components/base/*`) or dify-ui primitives (`@langgenius/dify-ui/*`) - ❌ **DO NOT mock** sibling/child components in the same directory > See [Test Structure Template](#test-structure-template) for correct import/mock patterns. @@ -325,12 +325,12 @@ For more detailed information, refer to: ### Reference Examples in Codebase - `web/utils/classnames.spec.ts` - Utility function tests -- `web/app/components/base/button/index.spec.tsx` - Component tests +- `web/app/components/base/radio/__tests__/index.spec.tsx` - Component tests - `web/__mocks__/provider-context.ts` - Mock factory example ### Project Configuration -- `web/vitest.config.ts` - Vitest configuration +- `web/vite.config.ts` - Vite/Vitest configuration - `web/vitest.setup.ts` - Test environment setup - `web/scripts/analyze-component.js` - Component analysis tool - Modules are not mocked automatically. Global mocks live in `web/vitest.setup.ts` (for example `react-i18next`, `next/image`); mock other modules like `ky` or `mime` locally in test files. diff --git a/.agents/skills/frontend-testing/references/checklist.md b/.agents/skills/frontend-testing/references/checklist.md index 10b8fb66f9f..99258498dd1 100644 --- a/.agents/skills/frontend-testing/references/checklist.md +++ b/.agents/skills/frontend-testing/references/checklist.md @@ -36,7 +36,7 @@ Use this checklist when generating or reviewing tests for Dify frontend componen ### Integration vs Mocking -- [ ] **DO NOT mock base components** (`Loading`, `Button`, `Tooltip`, etc.) +- [ ] **DO NOT mock base components or dify-ui primitives** (base `Loading`, `Input`, `Badge`; dify-ui `Button`, `Tooltip`, `Dialog`, etc.) - [ ] Import real project components instead of mocking - [ ] Only mock: API calls, complex context providers, third-party libs with side effects - [ ] Prefer integration testing when using single spec file @@ -73,7 +73,7 @@ Use this checklist when generating or reviewing tests for Dify frontend componen ### Mocks -- [ ] **DO NOT mock base components** (`@/app/components/base/*`) +- [ ] **DO NOT mock base components or dify-ui primitives** (`@/app/components/base/*` or `@langgenius/dify-ui/*`) - [ ] `vi.clearAllMocks()` in `beforeEach` (not `afterEach`) - [ ] Shared mock state reset in `beforeEach` - [ ] i18n uses global mock (auto-loaded in `web/vitest.setup.ts`); only override locally for custom translations diff --git a/.agents/skills/frontend-testing/references/mocking.md b/.agents/skills/frontend-testing/references/mocking.md index f58377c4a5f..8c2f1c0c583 100644 --- a/.agents/skills/frontend-testing/references/mocking.md +++ b/.agents/skills/frontend-testing/references/mocking.md @@ -2,29 +2,27 @@ ## ⚠️ Important: What NOT to Mock -### DO NOT Mock Base Components +### DO NOT Mock Base Components or dify-ui Primitives -**Never mock components from `@/app/components/base/`** such as: +**Never mock components from `@/app/components/base/` or from `@langgenius/dify-ui/*`** such as: -- `Loading`, `Spinner` -- `Button`, `Input`, `Select` -- `Tooltip`, `Modal`, `Dropdown` -- `Icon`, `Badge`, `Tag` +- Legacy base (`@/app/components/base/*`): `Loading`, `Spinner`, `Input`, `Badge`, `Tag` +- dify-ui primitives (`@langgenius/dify-ui/*`): `Button`, `Tooltip`, `Dialog`, `Popover`, `DropdownMenu`, `ContextMenu`, `Select`, `AlertDialog`, `Toast` **Why?** -- Base components will have their own dedicated tests +- These components have their own dedicated tests - Mocking them creates false positives (tests pass but real integration fails) - Using real components tests actual integration behavior ```typescript -// ❌ WRONG: Don't mock base components +// ❌ WRONG: Don't mock base components or dify-ui primitives vi.mock('@/app/components/base/loading', () => () =>
Loading
) -vi.mock('@/app/components/base/button', () => ({ children }: any) => ) +vi.mock('@langgenius/dify-ui/button', () => ({ Button: ({ children }: any) => })) -// ✅ CORRECT: Import and use real base components +// ✅ CORRECT: Import and use the real components import Loading from '@/app/components/base/loading' -import Button from '@/app/components/base/button' +import { Button } from '@langgenius/dify-ui/button' // They will render normally in tests ``` @@ -319,7 +317,7 @@ const renderWithQueryClient = (ui: React.ReactElement) => { ### ✅ DO -1. **Use real base components** - Import from `@/app/components/base/` directly +1. **Use real base components and dify-ui primitives** - Import from `@/app/components/base/` or `@langgenius/dify-ui/*` directly 1. **Use real project components** - Prefer importing over mocking 1. **Use real Zustand stores** - Set test state via `store.setState()` 1. **Reset mocks in `beforeEach`**, not `afterEach` @@ -330,7 +328,7 @@ const renderWithQueryClient = (ui: React.ReactElement) => { ### ❌ DON'T -1. **Don't mock base components** (`Loading`, `Button`, `Tooltip`, etc.) +1. **Don't mock base components or dify-ui primitives** (`Loading`, `Input`, `Button`, `Tooltip`, `Dialog`, etc.) 1. **Don't mock Zustand store modules** - Use real stores with `setState()` 1. Don't mock components you can import directly 1. Don't create overly simplified mocks that miss conditional logic @@ -342,7 +340,7 @@ const renderWithQueryClient = (ui: React.ReactElement) => { ``` Need to use a component in test? │ -├─ Is it from @/app/components/base/*? +├─ Is it from @/app/components/base/* or @langgenius/dify-ui/*? │ └─ YES → Import real component, DO NOT mock │ ├─ Is it a project component? diff --git a/.claude/skills/e2e-cucumber-playwright b/.claude/skills/e2e-cucumber-playwright new file mode 120000 index 00000000000..71b0eae34ff --- /dev/null +++ b/.claude/skills/e2e-cucumber-playwright @@ -0,0 +1 @@ +../../.agents/skills/e2e-cucumber-playwright \ No newline at end of file diff --git a/.devcontainer/post_create_command.sh b/.devcontainer/post_create_command.sh index b92d4c35a89..7460636824d 100755 --- a/.devcontainer/post_create_command.sh +++ b/.devcontainer/post_create_command.sh @@ -7,7 +7,7 @@ cd web && pnpm install pipx install uv echo "alias start-api=\"cd $WORKSPACE_ROOT/api && uv run python -m flask run --host 0.0.0.0 --port=5001 --debug\"" >> ~/.bashrc -echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P threads -c 1 --loglevel INFO -Q dataset,dataset_summary,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention\"" >> ~/.bashrc +echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P threads -c 1 --loglevel INFO -Q dataset,dataset_summary,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_publisher,trigger_refresh_executor,retention\"" >> ~/.bashrc echo "alias start-web=\"cd $WORKSPACE_ROOT/web && pnpm dev:inspect\"" >> ~/.bashrc echo "alias start-web-prod=\"cd $WORKSPACE_ROOT/web && pnpm build && pnpm start\"" >> ~/.bashrc echo "alias start-containers=\"cd $WORKSPACE_ROOT/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env up -d\"" >> ~/.bashrc diff --git a/.github/dependabot.yml b/.github/dependabot.yml index a183f0b58c1..266fa17c296 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -1,106 +1,6 @@ version: 2 updates: - - package-ecosystem: "pip" - directory: "/api" - open-pull-requests-limit: 10 - schedule: - interval: "weekly" - groups: - flask: - patterns: - - "flask" - - "flask-*" - - "werkzeug" - - "gunicorn" - google: - patterns: - - "google-*" - - "googleapis-*" - opentelemetry: - patterns: - - "opentelemetry-*" - pydantic: - patterns: - - "pydantic" - - "pydantic-*" - llm: - patterns: - - "langfuse" - - "langsmith" - - "litellm" - - "mlflow*" - - "opik" - - "weave*" - - "arize*" - - "tiktoken" - - "transformers" - database: - patterns: - - "sqlalchemy" - - "psycopg2*" - - "psycogreen" - - "redis*" - - "alembic*" - storage: - patterns: - - "boto3*" - - "botocore*" - - "azure-*" - - "bce-*" - - "cos-python-*" - - "esdk-obs-*" - - "google-cloud-storage" - - "opendal" - - "oss2" - - "supabase*" - - "tos*" - vdb: - patterns: - - "alibabacloud*" - - "chromadb" - - "clickhouse-*" - - "clickzetta-*" - - "couchbase" - - "elasticsearch" - - "opensearch-py" - - "oracledb" - - "pgvect*" - - "pymilvus" - - "pymochow" - - "pyobvector" - - "qdrant-client" - - "intersystems-*" - - "tablestore" - - "tcvectordb" - - "tidb-vector" - - "upstash-*" - - "volcengine-*" - - "weaviate-*" - - "xinference-*" - - "mo-vector" - - "mysql-connector-*" - dev: - patterns: - - "coverage" - - "dotenv-linter" - - "faker" - - "lxml-stubs" - - "basedpyright" - - "ruff" - - "pytest*" - - "types-*" - - "boto3-stubs" - - "hypothesis" - - "pandas-stubs" - - "scipy-stubs" - - "import-linter" - - "celery-types" - - "mypy*" - - "pyrefly" - python-packages: - patterns: - - "*" - package-ecosystem: "uv" directory: "/api" open-pull-requests-limit: 10 diff --git a/.github/labeler.yml b/.github/labeler.yml index d1d324d3813..3b9dc24749f 100644 --- a/.github/labeler.yml +++ b/.github/labeler.yml @@ -1,3 +1,10 @@ web: - changed-files: - - any-glob-to-any-file: 'web/**' + - any-glob-to-any-file: + - 'web/**' + - 'packages/**' + - 'package.json' + - 'pnpm-lock.yaml' + - 'pnpm-workspace.yaml' + - '.npmrc' + - '.nvmrc' diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index 50dbde2aeed..1e848612ec5 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -7,6 +7,7 @@ ## Summary + ## Screenshots @@ -17,7 +18,7 @@ ## Checklist - [ ] This change requires a documentation update, included: [Dify Document](https://github.com/langgenius/dify-docs) -- [x] I understand that this PR may be closed in case there was no previous discussion or issues. (This doesn't apply to typos!) -- [x] I've added a test for each change that was introduced, and I tried as much as possible to make a single atomic change. -- [x] I've updated the documentation accordingly. -- [x] I ran `make lint` and `make type-check` (backend) and `cd web && npx lint-staged` (frontend) to appease the lint gods +- [ ] I understand that this PR may be closed in case there was no previous discussion or issues. (This doesn't apply to typos!) +- [ ] I've added a test for each change that was introduced, and I tried as much as possible to make a single atomic change. +- [ ] I've updated the documentation accordingly. +- [ ] I ran `make lint && make type-check` (backend) and `cd web && pnpm exec vp staged` (frontend) to appease the lint gods diff --git a/.github/scripts/generate-i18n-changes.mjs b/.github/scripts/generate-i18n-changes.mjs new file mode 100644 index 00000000000..3d25115ac34 --- /dev/null +++ b/.github/scripts/generate-i18n-changes.mjs @@ -0,0 +1,82 @@ +import { execFileSync } from 'node:child_process' +import fs from 'node:fs' +import path from 'node:path' + +const repoRoot = process.cwd() +const baseSha = process.env.BASE_SHA || '' +const headSha = process.env.HEAD_SHA || '' +const files = (process.env.CHANGED_FILES || '').split(/\s+/).filter(Boolean) +const outputPath = process.env.I18N_CHANGES_OUTPUT_PATH || '/tmp/i18n-changes.json' + +const englishPath = fileStem => path.join(repoRoot, 'web', 'i18n', 'en-US', `${fileStem}.json`) + +const readCurrentJson = (fileStem) => { + const filePath = englishPath(fileStem) + if (!fs.existsSync(filePath)) + return null + + return JSON.parse(fs.readFileSync(filePath, 'utf8')) +} + +const readBaseJson = (fileStem) => { + if (!baseSha) + return null + + try { + const relativePath = `web/i18n/en-US/${fileStem}.json` + const content = execFileSync('git', ['show', `${baseSha}:${relativePath}`], { encoding: 'utf8' }) + return JSON.parse(content) + } + catch { + return null + } +} + +const compareJson = (beforeValue, afterValue) => JSON.stringify(beforeValue) === JSON.stringify(afterValue) + +const changes = {} + +for (const fileStem of files) { + const currentJson = readCurrentJson(fileStem) + const beforeJson = readBaseJson(fileStem) || {} + const afterJson = currentJson || {} + const added = {} + const updated = {} + const deleted = [] + + for (const [key, value] of Object.entries(afterJson)) { + if (!(key in beforeJson)) { + added[key] = value + continue + } + + if (!compareJson(beforeJson[key], value)) { + updated[key] = { + before: beforeJson[key], + after: value, + } + } + } + + for (const key of Object.keys(beforeJson)) { + if (!(key in afterJson)) + deleted.push(key) + } + + changes[fileStem] = { + fileDeleted: currentJson === null, + added, + updated, + deleted, + } +} + +fs.writeFileSync( + outputPath, + JSON.stringify({ + baseSha, + headSha, + files, + changes, + }) +) diff --git a/.github/workflows/anti-slop.yml b/.github/workflows/anti-slop.yml deleted file mode 100644 index b0f0a36bc9b..00000000000 --- a/.github/workflows/anti-slop.yml +++ /dev/null @@ -1,19 +0,0 @@ -name: Anti-Slop PR Check - -on: - pull_request_target: - types: [opened, edited, synchronize] - -permissions: - pull-requests: write - contents: read - -jobs: - anti-slop: - runs-on: ubuntu-latest - steps: - - uses: peakoss/anti-slop@85daca1880e9e1af197fc06ea03349daf08f4202 # v0.2.1 - with: - github-token: ${{ secrets.GITHUB_TOKEN }} - close-pr: false - failure-add-pr-labels: "needs-revision" diff --git a/.github/workflows/api-tests.yml b/.github/workflows/api-tests.yml index cd967b76cf7..717413937f7 100644 --- a/.github/workflows/api-tests.yml +++ b/.github/workflows/api-tests.yml @@ -35,7 +35,7 @@ jobs: persist-credentials: false - name: Setup UV and Python - uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0 + uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0 with: enable-cache: true python-version: ${{ matrix.python-version }} @@ -54,7 +54,7 @@ jobs: run: uv run --project api bash dev/pytest/pytest_unit_tests.sh - name: Upload unit coverage data - uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0 + uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1 with: name: api-coverage-unit path: coverage-unit @@ -84,7 +84,7 @@ jobs: persist-credentials: false - name: Setup UV and Python - uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0 + uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0 with: enable-cache: true python-version: ${{ matrix.python-version }} @@ -105,7 +105,7 @@ jobs: run: sh .github/workflows/expose_service_ports.sh - name: Set up Sandbox - uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0 + uses: hoverkraft-tech/compose-action@d2bee4f07e8ca410d6b196d00f90c12e7d48c33a # v2.6.0 with: compose-file: | docker/docker-compose.middleware.yaml @@ -129,7 +129,7 @@ jobs: api/tests/test_containers_integration_tests - name: Upload integration coverage data - uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0 + uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1 with: name: api-coverage-integration path: coverage-integration @@ -156,7 +156,7 @@ jobs: persist-credentials: false - name: Setup UV and Python - uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0 + uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0 with: enable-cache: true python-version: "3.12" diff --git a/.github/workflows/autofix.yml b/.github/workflows/autofix.yml index 9648c34274a..35683b112ff 100644 --- a/.github/workflows/autofix.yml +++ b/.github/workflows/autofix.yml @@ -25,7 +25,7 @@ jobs: - name: Check Docker Compose inputs if: github.event_name != 'merge_group' id: docker-compose-changes - uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5 + uses: tj-actions/changed-files@9426d40962ed5378910ee2e21d5f8c6fcbf2dd96 # v47.0.6 with: files: | docker/generate_docker_compose @@ -35,18 +35,20 @@ jobs: - name: Check web inputs if: github.event_name != 'merge_group' id: web-changes - uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5 + uses: tj-actions/changed-files@9426d40962ed5378910ee2e21d5f8c6fcbf2dd96 # v47.0.6 with: files: | web/** + packages/** package.json pnpm-lock.yaml pnpm-workspace.yaml + .npmrc .nvmrc - name: Check api inputs if: github.event_name != 'merge_group' id: api-changes - uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5 + uses: tj-actions/changed-files@9426d40962ed5378910ee2e21d5f8c6fcbf2dd96 # v47.0.6 with: files: | api/** @@ -56,7 +58,7 @@ jobs: python-version: "3.11" - if: github.event_name != 'merge_group' - uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0 + uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0 - name: Generate Docker Compose if: github.event_name != 'merge_group' && steps.docker-compose-changes.outputs.any_changed == 'true' @@ -118,8 +120,7 @@ jobs: - name: ESLint autofix if: github.event_name != 'merge_group' && steps.web-changes.outputs.any_changed == 'true' run: | - cd web vp exec eslint --concurrency=2 --prune-suppressions --quiet || true - if: github.event_name != 'merge_group' - uses: autofix-ci/action@7a166d7532b277f34e16238930461bf77f9d7ed8 # v1.3.3 + uses: autofix-ci/action@c5b2d67aa2274e7b5a18224e8171550871fc7e4a # v1.3.4 diff --git a/.github/workflows/build-push.yml b/.github/workflows/build-push.yml index a23edc70e53..5f16fc69271 100644 --- a/.github/workflows/build-push.yml +++ b/.github/workflows/build-push.yml @@ -65,7 +65,7 @@ jobs: echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV - name: Login to Docker Hub - uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # v4.0.0 + uses: docker/login-action@4907a6ddec9925e35a0a9e82d7399ccc52663121 # v4.1.0 with: username: ${{ env.DOCKERHUB_USER }} password: ${{ env.DOCKERHUB_TOKEN }} @@ -81,7 +81,7 @@ jobs: - name: Build Docker image id: build - uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # v7.0.0 + uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # v7.1.0 with: context: ${{ matrix.build_context }} file: ${{ matrix.file }} @@ -101,7 +101,7 @@ jobs: touch "/tmp/digests/${sanitized_digest}" - name: Upload digest - uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0 + uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1 with: name: digests-${{ matrix.artifact_context }}-${{ env.PLATFORM_PAIR }} path: /tmp/digests/* @@ -130,7 +130,7 @@ jobs: merge-multiple: true - name: Login to Docker Hub - uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # v4.0.0 + uses: docker/login-action@4907a6ddec9925e35a0a9e82d7399ccc52663121 # v4.1.0 with: username: ${{ env.DOCKERHUB_USER }} password: ${{ env.DOCKERHUB_TOKEN }} diff --git a/.github/workflows/db-migration-test.yml b/.github/workflows/db-migration-test.yml index 5991abe3ba1..17b867dd6dd 100644 --- a/.github/workflows/db-migration-test.yml +++ b/.github/workflows/db-migration-test.yml @@ -19,7 +19,7 @@ jobs: persist-credentials: false - name: Setup UV and Python - uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0 + uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0 with: enable-cache: true python-version: "3.12" @@ -40,7 +40,7 @@ jobs: cp middleware.env.example middleware.env - name: Set up Middlewares - uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0 + uses: hoverkraft-tech/compose-action@d2bee4f07e8ca410d6b196d00f90c12e7d48c33a # v2.6.0 with: compose-file: | docker/docker-compose.middleware.yaml @@ -69,7 +69,7 @@ jobs: persist-credentials: false - name: Setup UV and Python - uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0 + uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0 with: enable-cache: true python-version: "3.12" @@ -94,7 +94,7 @@ jobs: sed -i 's/DB_USERNAME=postgres/DB_USERNAME=mysql/' middleware.env - name: Set up Middlewares - uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0 + uses: hoverkraft-tech/compose-action@d2bee4f07e8ca410d6b196d00f90c12e7d48c33a # v2.6.0 with: compose-file: | docker/docker-compose.middleware.yaml diff --git a/.github/workflows/docker-build.yml b/.github/workflows/docker-build.yml index cbeb1a3bb1b..5752076c364 100644 --- a/.github/workflows/docker-build.yml +++ b/.github/workflows/docker-build.yml @@ -6,12 +6,7 @@ on: - "main" paths: - api/Dockerfile - - web/docker/** - web/Dockerfile - - package.json - - pnpm-lock.yaml - - pnpm-workspace.yaml - - .nvmrc concurrency: group: docker-build-${{ github.head_ref || github.run_id }} @@ -48,7 +43,7 @@ jobs: uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4.0.0 - name: Build Docker Image - uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # v7.0.0 + uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # v7.1.0 with: push: false context: ${{ matrix.context }} diff --git a/.github/workflows/main-ci.yml b/.github/workflows/main-ci.yml index 104368d1929..ba36b5c07a6 100644 --- a/.github/workflows/main-ci.yml +++ b/.github/workflows/main-ci.yml @@ -65,9 +65,11 @@ jobs: - 'docker/volumes/sandbox/conf/**' web: - 'web/**' + - 'packages/**' - 'package.json' - 'pnpm-lock.yaml' - 'pnpm-workspace.yaml' + - '.npmrc' - '.nvmrc' - '.github/workflows/web-tests.yml' - '.github/actions/setup-web/**' @@ -77,9 +79,11 @@ jobs: - 'api/uv.lock' - 'e2e/**' - 'web/**' + - 'packages/**' - 'package.json' - 'pnpm-lock.yaml' - 'pnpm-workspace.yaml' + - '.npmrc' - '.nvmrc' - 'docker/docker-compose.middleware.yaml' - 'docker/middleware.env.example' @@ -88,6 +92,7 @@ jobs: vdb: - 'api/core/rag/datasource/**' - 'api/tests/integration_tests/vdb/**' + - 'api/providers/vdb/*/tests/**' - '.github/workflows/vdb-tests.yml' - '.github/workflows/expose_service_ports.sh' - 'docker/.env.example' diff --git a/.github/workflows/pyrefly-diff-comment.yml b/.github/workflows/pyrefly-diff-comment.yml index 0278e1e0d33..c55b013dbeb 100644 --- a/.github/workflows/pyrefly-diff-comment.yml +++ b/.github/workflows/pyrefly-diff-comment.yml @@ -21,7 +21,7 @@ jobs: if: ${{ github.event.workflow_run.conclusion == 'success' && github.event.workflow_run.pull_requests[0].head.repo.full_name != github.repository }} steps: - name: Download pyrefly diff artifact - uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0 + uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0 with: github-token: ${{ secrets.GITHUB_TOKEN }} script: | @@ -49,7 +49,7 @@ jobs: run: unzip -o pyrefly_diff.zip - name: Post comment - uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0 + uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0 with: github-token: ${{ secrets.GITHUB_TOKEN }} script: | @@ -76,13 +76,11 @@ jobs: diff += '\\n\\n... (truncated) ...'; } - const body = diff.trim() - ? '### Pyrefly Diff\n
\nbase → PR\n\n```diff\n' + diff + '\n```\n
' - : '### Pyrefly Diff\nNo changes detected.'; - - await github.rest.issues.createComment({ - issue_number: prNumber, - owner: context.repo.owner, - repo: context.repo.repo, - body, - }); + if (diff.trim()) { + await github.rest.issues.createComment({ + issue_number: prNumber, + owner: context.repo.owner, + repo: context.repo.repo, + body: '### Pyrefly Diff\n
\nbase → PR\n\n```diff\n' + diff + '\n```\n
', + }); + } diff --git a/.github/workflows/pyrefly-diff.yml b/.github/workflows/pyrefly-diff.yml index 8623d35b043..eb15cd6f758 100644 --- a/.github/workflows/pyrefly-diff.yml +++ b/.github/workflows/pyrefly-diff.yml @@ -22,7 +22,7 @@ jobs: fetch-depth: 0 - name: Setup Python & UV - uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0 + uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0 with: enable-cache: true @@ -66,7 +66,7 @@ jobs: echo ${{ github.event.pull_request.number }} > pr_number.txt - name: Upload pyrefly diff - uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0 + uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1 with: name: pyrefly_diff path: | @@ -75,7 +75,7 @@ jobs: - name: Comment PR with pyrefly diff if: ${{ github.event.pull_request.head.repo.full_name == github.repository && steps.line_count_check.outputs.same == 'false' }} - uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0 + uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0 with: github-token: ${{ secrets.GITHUB_TOKEN }} script: | diff --git a/.github/workflows/pyrefly-type-coverage-comment.yml b/.github/workflows/pyrefly-type-coverage-comment.yml new file mode 100644 index 00000000000..3c6c96a664f --- /dev/null +++ b/.github/workflows/pyrefly-type-coverage-comment.yml @@ -0,0 +1,118 @@ +name: Comment with Pyrefly Type Coverage + +on: + workflow_run: + workflows: + - Pyrefly Type Coverage + types: + - completed + +permissions: {} + +jobs: + comment: + name: Comment PR with type coverage + runs-on: ubuntu-latest + permissions: + actions: read + contents: read + issues: write + pull-requests: write + if: ${{ github.event.workflow_run.conclusion == 'success' && github.event.workflow_run.pull_requests[0].head.repo.full_name != github.repository }} + steps: + - name: Checkout default branch (trusted code) + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + + - name: Setup Python & UV + uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0 + with: + enable-cache: true + + - name: Install dependencies + run: uv sync --project api --dev + + - name: Download type coverage artifact + uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0 + with: + github-token: ${{ secrets.GITHUB_TOKEN }} + script: | + const fs = require('fs'); + const artifacts = await github.rest.actions.listWorkflowRunArtifacts({ + owner: context.repo.owner, + repo: context.repo.repo, + run_id: ${{ github.event.workflow_run.id }}, + }); + const match = artifacts.data.artifacts.find((artifact) => + artifact.name === 'pyrefly_type_coverage' + ); + if (!match) { + throw new Error('pyrefly_type_coverage artifact not found'); + } + const download = await github.rest.actions.downloadArtifact({ + owner: context.repo.owner, + repo: context.repo.repo, + artifact_id: match.id, + archive_format: 'zip', + }); + fs.writeFileSync('pyrefly_type_coverage.zip', Buffer.from(download.data)); + + - name: Unzip artifact + run: unzip -o pyrefly_type_coverage.zip + + - name: Render coverage markdown from structured data + id: render + run: | + comment_body="$(uv run --directory api python libs/pyrefly_type_coverage.py \ + --base base_report.json \ + < pr_report.json)" + + { + echo "### Pyrefly Type Coverage" + echo "" + echo "$comment_body" + } > /tmp/type_coverage_comment.md + + - name: Post comment + uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0 + with: + github-token: ${{ secrets.GITHUB_TOKEN }} + script: | + const fs = require('fs'); + const body = fs.readFileSync('/tmp/type_coverage_comment.md', { encoding: 'utf8' }); + let prNumber = null; + try { + prNumber = parseInt(fs.readFileSync('pr_number.txt', { encoding: 'utf8' }), 10); + } catch (err) { + const prs = context.payload.workflow_run.pull_requests || []; + if (prs.length > 0 && prs[0].number) { + prNumber = prs[0].number; + } + } + if (!prNumber) { + throw new Error('PR number not found in artifact or workflow_run payload'); + } + + // Update existing comment if one exists, otherwise create new + const { data: comments } = await github.rest.issues.listComments({ + issue_number: prNumber, + owner: context.repo.owner, + repo: context.repo.repo, + }); + const marker = '### Pyrefly Type Coverage'; + const existing = comments.find(c => c.body.startsWith(marker)); + + if (existing) { + await github.rest.issues.updateComment({ + comment_id: existing.id, + owner: context.repo.owner, + repo: context.repo.repo, + body, + }); + } else { + await github.rest.issues.createComment({ + issue_number: prNumber, + owner: context.repo.owner, + repo: context.repo.repo, + body, + }); + } diff --git a/.github/workflows/pyrefly-type-coverage.yml b/.github/workflows/pyrefly-type-coverage.yml new file mode 100644 index 00000000000..0599c94eef1 --- /dev/null +++ b/.github/workflows/pyrefly-type-coverage.yml @@ -0,0 +1,120 @@ +name: Pyrefly Type Coverage + +on: + pull_request: + paths: + - 'api/**/*.py' + +permissions: + contents: read + +jobs: + pyrefly-type-coverage: + runs-on: ubuntu-latest + permissions: + contents: read + issues: write + pull-requests: write + steps: + - name: Checkout PR branch + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + with: + fetch-depth: 0 + + - name: Setup Python & UV + uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0 + with: + enable-cache: true + + - name: Install dependencies + run: uv sync --project api --dev + + - name: Run pyrefly report on PR branch + run: | + uv run --directory api --dev pyrefly report 2>/dev/null > /tmp/pyrefly_report_pr.tmp && \ + mv /tmp/pyrefly_report_pr.tmp /tmp/pyrefly_report_pr.json || \ + echo '{}' > /tmp/pyrefly_report_pr.json + + - name: Save helper script from base branch + run: | + git show ${{ github.event.pull_request.base.sha }}:api/libs/pyrefly_type_coverage.py > /tmp/pyrefly_type_coverage.py 2>/dev/null \ + || cp api/libs/pyrefly_type_coverage.py /tmp/pyrefly_type_coverage.py + + - name: Checkout base branch + run: git checkout ${{ github.base_ref }} + + - name: Run pyrefly report on base branch + run: | + uv run --directory api --dev pyrefly report 2>/dev/null > /tmp/pyrefly_report_base.tmp && \ + mv /tmp/pyrefly_report_base.tmp /tmp/pyrefly_report_base.json || \ + echo '{}' > /tmp/pyrefly_report_base.json + + - name: Generate coverage comparison + id: coverage + run: | + comment_body="$(uv run --directory api python /tmp/pyrefly_type_coverage.py \ + --base /tmp/pyrefly_report_base.json \ + < /tmp/pyrefly_report_pr.json)" + + { + echo "### Pyrefly Type Coverage" + echo "" + echo "$comment_body" + } | tee -a "$GITHUB_STEP_SUMMARY" > /tmp/type_coverage_comment.md + + # Save structured data for the fork-PR comment workflow + cp /tmp/pyrefly_report_pr.json pr_report.json + cp /tmp/pyrefly_report_base.json base_report.json + + - name: Save PR number + run: | + echo ${{ github.event.pull_request.number }} > pr_number.txt + + - name: Upload type coverage artifact + uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1 + with: + name: pyrefly_type_coverage + path: | + pr_report.json + base_report.json + pr_number.txt + + - name: Comment PR with type coverage + if: ${{ github.event.pull_request.head.repo.full_name == github.repository }} + uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0 + with: + github-token: ${{ secrets.GITHUB_TOKEN }} + script: | + const fs = require('fs'); + const marker = '### Pyrefly Type Coverage'; + let body; + try { + body = fs.readFileSync('/tmp/type_coverage_comment.md', { encoding: 'utf8' }); + } catch { + body = `${marker}\n\n_Coverage report unavailable._`; + } + const prNumber = context.payload.pull_request.number; + + // Update existing comment if one exists, otherwise create new + const { data: comments } = await github.rest.issues.listComments({ + issue_number: prNumber, + owner: context.repo.owner, + repo: context.repo.repo, + }); + const existing = comments.find(c => c.body.startsWith(marker)); + + if (existing) { + await github.rest.issues.updateComment({ + comment_id: existing.id, + owner: context.repo.owner, + repo: context.repo.repo, + body, + }); + } else { + await github.rest.issues.createComment({ + issue_number: prNumber, + owner: context.repo.owner, + repo: context.repo.repo, + body, + }); + } diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml index 5cf52daed20..c74f4a670a8 100644 --- a/.github/workflows/stale.yml +++ b/.github/workflows/stale.yml @@ -23,8 +23,8 @@ jobs: days-before-issue-stale: 15 days-before-issue-close: 3 repo-token: ${{ secrets.GITHUB_TOKEN }} - stale-issue-message: "Close due to it's no longer active, if you have any questions, you can reopen it." - stale-pr-message: "Close due to it's no longer active, if you have any questions, you can reopen it." + stale-issue-message: "Closed due to inactivity. If you have any questions, you can reopen it." + stale-pr-message: "Closed due to inactivity. If you have any questions, you can reopen it." stale-issue-label: 'no-issue-activity' stale-pr-label: 'no-pr-activity' - any-of-labels: 'duplicate,question,invalid,wontfix,no-issue-activity,no-pr-activity,enhancement,cant-reproduce,help-wanted' + any-of-labels: '🌚 invalid,🙋‍♂️ question,wont-fix,no-issue-activity,no-pr-activity,💪 enhancement,🤔 cant-reproduce,🙏 help wanted' diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml index 9bc4ceaa93e..d8c7ebbad3b 100644 --- a/.github/workflows/style.yml +++ b/.github/workflows/style.yml @@ -25,7 +25,7 @@ jobs: - name: Check changed files id: changed-files - uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5 + uses: tj-actions/changed-files@9426d40962ed5378910ee2e21d5f8c6fcbf2dd96 # v47.0.6 with: files: | api/** @@ -33,7 +33,7 @@ jobs: - name: Setup UV and Python if: steps.changed-files.outputs.any_changed == 'true' - uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0 + uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0 with: enable-cache: false python-version: "3.12" @@ -73,13 +73,17 @@ jobs: - name: Check changed files id: changed-files - uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5 + uses: tj-actions/changed-files@9426d40962ed5378910ee2e21d5f8c6fcbf2dd96 # v47.0.6 with: files: | web/** + e2e/** + sdks/nodejs-client/** + packages/** package.json pnpm-lock.yaml pnpm-workspace.yaml + .npmrc .nvmrc .github/workflows/style.yml .github/actions/setup-web/** @@ -91,16 +95,16 @@ jobs: - name: Restore ESLint cache if: steps.changed-files.outputs.any_changed == 'true' id: eslint-cache-restore - uses: actions/cache/restore@668228422ae6a00e4ad889ee87cd7109ec5666a7 # v5.0.4 + uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5 with: - path: web/.eslintcache - key: ${{ runner.os }}-web-eslint-${{ hashFiles('web/package.json', 'pnpm-lock.yaml', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-${{ github.sha }} + path: .eslintcache + key: ${{ runner.os }}-eslint-${{ hashFiles('pnpm-lock.yaml', 'eslint.config.mjs', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-${{ github.sha }} restore-keys: | - ${{ runner.os }}-web-eslint-${{ hashFiles('web/package.json', 'pnpm-lock.yaml', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}- + ${{ runner.os }}-eslint-${{ hashFiles('pnpm-lock.yaml', 'eslint.config.mjs', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}- - name: Web style check if: steps.changed-files.outputs.any_changed == 'true' - working-directory: ./web + working-directory: . run: vp run lint:ci - name: Web tsslint @@ -110,7 +114,7 @@ jobs: - name: Web type check if: steps.changed-files.outputs.any_changed == 'true' - working-directory: ./web + working-directory: . run: vp run type-check - name: Web dead code check @@ -120,9 +124,9 @@ jobs: - name: Save ESLint cache if: steps.changed-files.outputs.any_changed == 'true' && success() && steps.eslint-cache-restore.outputs.cache-hit != 'true' - uses: actions/cache/save@668228422ae6a00e4ad889ee87cd7109ec5666a7 # v5.0.4 + uses: actions/cache/save@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5 with: - path: web/.eslintcache + path: .eslintcache key: ${{ steps.eslint-cache-restore.outputs.cache-primary-key }} superlinter: @@ -138,7 +142,7 @@ jobs: - name: Check changed files id: changed-files - uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5 + uses: tj-actions/changed-files@9426d40962ed5378910ee2e21d5f8c6fcbf2dd96 # v47.0.6 with: files: | **.sh @@ -149,7 +153,7 @@ jobs: .editorconfig - name: Super-linter - uses: super-linter/super-linter/slim@61abc07d755095a68f4987d1c2c3d1d64408f1f9 # v8.5.0 + uses: super-linter/super-linter/slim@9e863354e3ff62e0727d37183162c4a88873df41 # v8.6.0 if: steps.changed-files.outputs.any_changed == 'true' env: BASH_SEVERITY: warning diff --git a/.github/workflows/tool-test-sdks.yaml b/.github/workflows/tool-test-sdks.yaml index 536a52b5600..bf33207a147 100644 --- a/.github/workflows/tool-test-sdks.yaml +++ b/.github/workflows/tool-test-sdks.yaml @@ -9,6 +9,7 @@ on: - package.json - pnpm-lock.yaml - pnpm-workspace.yaml + - .npmrc concurrency: group: sdk-tests-${{ github.head_ref || github.run_id }} @@ -29,7 +30,7 @@ jobs: persist-credentials: false - name: Use Node.js - uses: actions/setup-node@53b83947a5a98c8d113130e565377fae1a50d02f # v6.3.0 + uses: actions/setup-node@48b55a011bda9f5d6aeb4c2d9c7362e8dae4041e # v6.4.0 with: node-version: 22 cache: '' diff --git a/.github/workflows/translate-i18n-claude.yml b/.github/workflows/translate-i18n-claude.yml index 33af4f36fdd..eecbbb1a565 100644 --- a/.github/workflows/translate-i18n-claude.yml +++ b/.github/workflows/translate-i18n-claude.yml @@ -68,89 +68,7 @@ jobs: " web/i18n-config/languages.ts | sed 's/[[:space:]]*$//') generate_changes_json() { - node <<'NODE' - const { execFileSync } = require('node:child_process') - const fs = require('node:fs') - const path = require('node:path') - - const repoRoot = process.cwd() - const baseSha = process.env.BASE_SHA || '' - const headSha = process.env.HEAD_SHA || '' - const files = (process.env.CHANGED_FILES || '').split(/\s+/).filter(Boolean) - - const englishPath = fileStem => path.join(repoRoot, 'web', 'i18n', 'en-US', `${fileStem}.json`) - - const readCurrentJson = (fileStem) => { - const filePath = englishPath(fileStem) - if (!fs.existsSync(filePath)) - return null - - return JSON.parse(fs.readFileSync(filePath, 'utf8')) - } - - const readBaseJson = (fileStem) => { - if (!baseSha) - return null - - try { - const relativePath = `web/i18n/en-US/${fileStem}.json` - const content = execFileSync('git', ['show', `${baseSha}:${relativePath}`], { encoding: 'utf8' }) - return JSON.parse(content) - } - catch (error) { - return null - } - } - - const compareJson = (beforeValue, afterValue) => JSON.stringify(beforeValue) === JSON.stringify(afterValue) - - const changes = {} - - for (const fileStem of files) { - const currentJson = readCurrentJson(fileStem) - const beforeJson = readBaseJson(fileStem) || {} - const afterJson = currentJson || {} - const added = {} - const updated = {} - const deleted = [] - - for (const [key, value] of Object.entries(afterJson)) { - if (!(key in beforeJson)) { - added[key] = value - continue - } - - if (!compareJson(beforeJson[key], value)) { - updated[key] = { - before: beforeJson[key], - after: value, - } - } - } - - for (const key of Object.keys(beforeJson)) { - if (!(key in afterJson)) - deleted.push(key) - } - - changes[fileStem] = { - fileDeleted: currentJson === null, - added, - updated, - deleted, - } - } - - fs.writeFileSync( - '/tmp/i18n-changes.json', - JSON.stringify({ - baseSha, - headSha, - files, - changes, - }) - ) - NODE + node .github/scripts/generate-i18n-changes.mjs } if [ "${{ github.event_name }}" = "repository_dispatch" ]; then @@ -240,7 +158,7 @@ jobs: - name: Run Claude Code for Translation Sync if: steps.context.outputs.CHANGED_FILES != '' - uses: anthropics/claude-code-action@88c168b39e7e64da0286d812b6e9fbebb6708185 # v1.0.82 + uses: anthropics/claude-code-action@38ec876110f9fbf8b950c79f534430740c3ac009 # v1.0.101 with: anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }} github_token: ${{ secrets.GITHUB_TOKEN }} @@ -270,7 +188,7 @@ jobs: Tool rules: - Use Read for repository files. - Use Edit for JSON updates. - - Use Bash only for `pnpm`. + - Use Bash only for `vp`. - Do not use Bash for `git`, `gh`, or branch management. Required execution plan: @@ -292,7 +210,7 @@ jobs: - Read the current English JSON file for any file that still exists so wording, placeholders, and surrounding terminology stay accurate. - If `Structured change set available` is `false`, treat this as a scoped full sync and use the current English files plus scoped checks as the source of truth. 4. Run a scoped pre-check before editing: - - `pnpm --dir ${{ github.workspace }}/web run i18n:check ${{ steps.context.outputs.FILE_ARGS }} ${{ steps.context.outputs.LANG_ARGS }}` + - `vp run dify-web#i18n:check ${{ steps.context.outputs.FILE_ARGS }} ${{ steps.context.outputs.LANG_ARGS }}` - Use this command as the source of truth for missing and extra keys inside the current scope. 5. Apply translations. - For every target language and scoped file: @@ -300,19 +218,19 @@ jobs: - If the locale file does not exist yet, create it with `Write` and then continue with `Edit` as needed. - ADD missing keys. - UPDATE stale translations when the English value changed. - - DELETE removed keys. Prefer `pnpm --dir ${{ github.workspace }}/web run i18n:check ${{ steps.context.outputs.FILE_ARGS }} ${{ steps.context.outputs.LANG_ARGS }} --auto-remove` for extra keys so deletions stay in scope. + - DELETE removed keys. Prefer `vp run dify-web#i18n:check ${{ steps.context.outputs.FILE_ARGS }} ${{ steps.context.outputs.LANG_ARGS }} --auto-remove` for extra keys so deletions stay in scope. - Preserve placeholders exactly: `{{variable}}`, `${variable}`, HTML tags, component tags, and variable names. - Match the existing terminology and register used by each locale. - Prefer one Edit per file when stable, but prioritize correctness over batching. 6. Verify only the edited files. - - Run `pnpm --dir ${{ github.workspace }}/web lint:fix --quiet -- ` - - Run `pnpm --dir ${{ github.workspace }}/web run i18n:check ${{ steps.context.outputs.FILE_ARGS }} ${{ steps.context.outputs.LANG_ARGS }}` + - Run `vp run dify-web#lint:fix --quiet -- ` + - Run `vp run dify-web#i18n:check ${{ steps.context.outputs.FILE_ARGS }} ${{ steps.context.outputs.LANG_ARGS }}` - If verification fails, fix the remaining problems before continuing. 7. Stop after the scoped locale files are updated and verification passes. - Do not create branches, commits, or pull requests. claude_args: | --max-turns 120 - --allowedTools "Read,Write,Edit,Bash(pnpm *),Bash(pnpm:*),Glob,Grep" + --allowedTools "Read,Write,Edit,Bash(vp *),Bash(vp:*),Glob,Grep" - name: Prepare branch metadata id: pr_meta @@ -354,6 +272,7 @@ jobs: - name: Create or update translation PR if: steps.pr_meta.outputs.has_changes == 'true' env: + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} BRANCH_NAME: ${{ steps.pr_meta.outputs.branch_name }} FILES_IN_SCOPE: ${{ steps.context.outputs.CHANGED_FILES }} TARGET_LANGS: ${{ steps.context.outputs.TARGET_LANGS }} @@ -402,8 +321,8 @@ jobs: '', '## Verification', '', - `- \`pnpm --dir web run i18n:check --file ${process.env.FILES_IN_SCOPE} --lang ${process.env.TARGET_LANGS}\``, - `- \`pnpm --dir web lint:fix --quiet -- \``, + `- \`vp run dify-web#i18n:check --file ${process.env.FILES_IN_SCOPE} --lang ${process.env.TARGET_LANGS}\``, + `- \`vp run dify-web#lint:fix --quiet -- \``, '', '## Notes', '', diff --git a/.github/workflows/trigger-i18n-sync.yml b/.github/workflows/trigger-i18n-sync.yml index a1ca42b26ee..790ea9126d0 100644 --- a/.github/workflows/trigger-i18n-sync.yml +++ b/.github/workflows/trigger-i18n-sync.yml @@ -42,88 +42,7 @@ jobs: fi export BASE_SHA HEAD_SHA CHANGED_FILES - node <<'NODE' - const { execFileSync } = require('node:child_process') - const fs = require('node:fs') - const path = require('node:path') - - const repoRoot = process.cwd() - const baseSha = process.env.BASE_SHA || '' - const headSha = process.env.HEAD_SHA || '' - const files = (process.env.CHANGED_FILES || '').split(/\s+/).filter(Boolean) - - const englishPath = fileStem => path.join(repoRoot, 'web', 'i18n', 'en-US', `${fileStem}.json`) - - const readCurrentJson = (fileStem) => { - const filePath = englishPath(fileStem) - if (!fs.existsSync(filePath)) - return null - - return JSON.parse(fs.readFileSync(filePath, 'utf8')) - } - - const readBaseJson = (fileStem) => { - if (!baseSha) - return null - - try { - const relativePath = `web/i18n/en-US/${fileStem}.json` - const content = execFileSync('git', ['show', `${baseSha}:${relativePath}`], { encoding: 'utf8' }) - return JSON.parse(content) - } - catch (error) { - return null - } - } - - const compareJson = (beforeValue, afterValue) => JSON.stringify(beforeValue) === JSON.stringify(afterValue) - - const changes = {} - - for (const fileStem of files) { - const beforeJson = readBaseJson(fileStem) || {} - const afterJson = readCurrentJson(fileStem) || {} - const added = {} - const updated = {} - const deleted = [] - - for (const [key, value] of Object.entries(afterJson)) { - if (!(key in beforeJson)) { - added[key] = value - continue - } - - if (!compareJson(beforeJson[key], value)) { - updated[key] = { - before: beforeJson[key], - after: value, - } - } - } - - for (const key of Object.keys(beforeJson)) { - if (!(key in afterJson)) - deleted.push(key) - } - - changes[fileStem] = { - fileDeleted: readCurrentJson(fileStem) === null, - added, - updated, - deleted, - } - } - - fs.writeFileSync( - '/tmp/i18n-changes.json', - JSON.stringify({ - baseSha, - headSha, - files, - changes, - }) - ) - NODE + node .github/scripts/generate-i18n-changes.mjs if [ -n "$CHANGED_FILES" ]; then echo "has_changes=true" >> "$GITHUB_OUTPUT" @@ -137,7 +56,7 @@ jobs: - name: Trigger i18n sync workflow if: steps.detect.outputs.has_changes == 'true' - uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0 + uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0 env: BASE_SHA: ${{ steps.detect.outputs.base_sha }} HEAD_SHA: ${{ steps.detect.outputs.head_sha }} diff --git a/.github/workflows/vdb-tests-full.yml b/.github/workflows/vdb-tests-full.yml index 01d25902f60..b79e8927d75 100644 --- a/.github/workflows/vdb-tests-full.yml +++ b/.github/workflows/vdb-tests-full.yml @@ -36,7 +36,7 @@ jobs: remove_tool_cache: true - name: Setup UV and Python - uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0 + uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0 with: enable-cache: true python-version: ${{ matrix.python-version }} @@ -65,7 +65,7 @@ jobs: # tiflash - name: Set up Full Vector Store Matrix - uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0 + uses: hoverkraft-tech/compose-action@d2bee4f07e8ca410d6b196d00f90c12e7d48c33a # v2.6.0 with: compose-file: | docker/docker-compose.yaml @@ -89,7 +89,7 @@ jobs: cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env # - name: Check VDB Ready (TiDB) -# run: uv run --project api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py +# run: uv run --project api python api/providers/vdb/tidb-vector/tests/integration_tests/check_tiflash_ready.py - name: Test Vector Stores run: uv run --project api bash dev/pytest/pytest_vdb.sh diff --git a/.github/workflows/vdb-tests.yml b/.github/workflows/vdb-tests.yml index 47ec70f6032..bd13d662c3f 100644 --- a/.github/workflows/vdb-tests.yml +++ b/.github/workflows/vdb-tests.yml @@ -33,7 +33,7 @@ jobs: remove_tool_cache: true - name: Setup UV and Python - uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0 + uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0 with: enable-cache: true python-version: ${{ matrix.python-version }} @@ -62,7 +62,7 @@ jobs: # tiflash - name: Set up Vector Stores for Smoke Coverage - uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0 + uses: hoverkraft-tech/compose-action@d2bee4f07e8ca410d6b196d00f90c12e7d48c33a # v2.6.0 with: compose-file: | docker/docker-compose.yaml @@ -81,12 +81,12 @@ jobs: cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env # - name: Check VDB Ready (TiDB) -# run: uv run --project api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py +# run: uv run --project api python api/providers/vdb/tidb-vector/tests/integration_tests/check_tiflash_ready.py - name: Test Vector Stores run: | uv run --project api pytest --timeout "${PYTEST_TIMEOUT:-180}" \ - api/tests/integration_tests/vdb/chroma \ - api/tests/integration_tests/vdb/pgvector \ - api/tests/integration_tests/vdb/qdrant \ - api/tests/integration_tests/vdb/weaviate + api/providers/vdb/vdb-chroma/tests/integration_tests \ + api/providers/vdb/vdb-pgvector/tests/integration_tests \ + api/providers/vdb/vdb-qdrant/tests/integration_tests \ + api/providers/vdb/vdb-weaviate/tests/integration_tests diff --git a/.github/workflows/web-e2e.yml b/.github/workflows/web-e2e.yml index eb752619be8..6bd4d4f4069 100644 --- a/.github/workflows/web-e2e.yml +++ b/.github/workflows/web-e2e.yml @@ -28,7 +28,7 @@ jobs: uses: ./.github/actions/setup-web - name: Setup UV and Python - uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0 + uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0 with: enable-cache: true python-version: "3.12" @@ -53,7 +53,7 @@ jobs: - name: Upload Cucumber report if: ${{ !cancelled() }} - uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0 + uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1 with: name: cucumber-report path: e2e/cucumber-report @@ -61,7 +61,7 @@ jobs: - name: Upload E2E logs if: ${{ !cancelled() }} - uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0 + uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1 with: name: e2e-logs path: e2e/.logs diff --git a/.github/workflows/web-tests.yml b/.github/workflows/web-tests.yml index 3c36335e791..2a5cf196452 100644 --- a/.github/workflows/web-tests.yml +++ b/.github/workflows/web-tests.yml @@ -43,7 +43,7 @@ jobs: - name: Upload blob report if: ${{ !cancelled() }} - uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0 + uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1 with: name: blob-report-${{ matrix.shardIndex }} path: web/.vitest-reports/* @@ -89,3 +89,37 @@ jobs: flags: web env: CODECOV_TOKEN: ${{ env.CODECOV_TOKEN }} + + dify-ui-test: + name: dify-ui Tests + runs-on: ubuntu-latest + env: + CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} + defaults: + run: + shell: bash + working-directory: ./packages/dify-ui + + steps: + - name: Checkout code + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + with: + persist-credentials: false + + - name: Setup web environment + uses: ./.github/actions/setup-web + + - name: Install Chromium for Browser Mode + run: vp exec playwright install --with-deps chromium + + - name: Run dify-ui tests + run: vp test run --coverage --silent=passed-only + + - name: Report coverage + if: ${{ env.CODECOV_TOKEN != '' }} + uses: codecov/codecov-action@57e3a136b779b570ffcdbf80b3bdc90e7fab3de2 # v6.0.0 + with: + directory: packages/dify-ui/coverage + flags: dify-ui + env: + CODECOV_TOKEN: ${{ env.CODECOV_TOKEN }} diff --git a/.gitignore b/.gitignore index f703fc02e9b..836bddbb498 100644 --- a/.gitignore +++ b/.gitignore @@ -203,6 +203,7 @@ sdks/python-client/dify_client.egg-info .vscode/* !.vscode/launch.json.template +!.vscode/settings.example.json !.vscode/README.md api/.vscode # vscode Code History Extension @@ -212,7 +213,7 @@ api/.vscode # pnpm /.pnpm-store -/node_modules +node_modules .vite-hooks/_ # plugin migrate @@ -236,9 +237,15 @@ scripts/stress-test/reports/ .playwright-mcp/ .serena/ +# vitest browser mode attachments (failure screenshots, traces, etc.) +.vitest-attachments/ +**/__screenshots__/ + # settings *.local.json *.local.md # Code Agent Folder .qoder/* + +.eslintcache diff --git a/.npmrc b/.npmrc new file mode 100644 index 00000000000..cffe8cdef13 --- /dev/null +++ b/.npmrc @@ -0,0 +1 @@ +save-exact=true diff --git a/.vite-hooks/pre-commit b/.vite-hooks/pre-commit index 54e09f80d6d..d48381bce26 100755 --- a/.vite-hooks/pre-commit +++ b/.vite-hooks/pre-commit @@ -56,64 +56,9 @@ if $api_modified; then fi fi -if $web_modified; then - if $skip_web_checks; then - echo "Git operation in progress, skipping web checks" - exit 0 - fi - - echo "Running ESLint on web module" - - if git diff --cached --quiet -- 'web/**/*.ts' 'web/**/*.tsx'; then - web_ts_modified=false - else - ts_diff_status=$? - if [ $ts_diff_status -eq 1 ]; then - web_ts_modified=true - else - echo "Unable to determine staged TypeScript changes (git exit code: $ts_diff_status)." - exit $ts_diff_status - fi - fi - - cd ./web || exit 1 - vp staged - - if $web_ts_modified; then - echo "Running TypeScript type-check:tsgo" - if ! pnpm run type-check:tsgo; then - echo "Type check failed. Please run 'pnpm run type-check:tsgo' to fix the errors." - exit 1 - fi - else - echo "No staged TypeScript changes detected, skipping type-check:tsgo" - fi - - echo "Running unit tests check" - modified_files=$(git diff --cached --name-only -- utils | grep -v '\.spec\.ts$' || true) - - if [ -n "$modified_files" ]; then - for file in $modified_files; do - test_file="${file%.*}.spec.ts" - echo "Checking for test file: $test_file" - - # check if the test file exists - if [ -f "../$test_file" ]; then - echo "Detected changes in $file, running corresponding unit tests..." - pnpm run test "../$test_file" - - if [ $? -ne 0 ]; then - echo "Unit tests failed. Please fix the errors before committing." - exit 1 - fi - echo "Unit tests for $file passed." - else - echo "Warning: $file does not have a corresponding test file." - fi - - done - echo "All unit tests for modified web/utils files have passed." - fi - - cd ../ +if $skip_web_checks; then + echo "Git operation in progress, skipping web checks" + exit 0 fi + +vp staged diff --git a/.vscode/launch.json.template b/.vscode/launch.json.template index c3e2c50c52f..2611b75c6c0 100644 --- a/.vscode/launch.json.template +++ b/.vscode/launch.json.template @@ -2,21 +2,10 @@ "version": "0.2.0", "configurations": [ { - "name": "Python: Flask API", + "name": "Python: API (gevent)", "type": "debugpy", "request": "launch", - "module": "flask", - "env": { - "FLASK_APP": "app.py", - "FLASK_ENV": "development" - }, - "args": [ - "run", - "--host=0.0.0.0", - "--port=5001", - "--no-debugger", - "--no-reload" - ], + "program": "${workspaceFolder}/api/app.py", "jinja": true, "justMyCode": true, "cwd": "${workspaceFolder}/api", diff --git a/web/.vscode/settings.example.json b/.vscode/settings.example.json similarity index 86% rename from web/.vscode/settings.example.json rename to .vscode/settings.example.json index 4b356f5b7a4..7cdbc51a3bd 100644 --- a/web/.vscode/settings.example.json +++ b/.vscode/settings.example.json @@ -1,12 +1,16 @@ { - // Disable the default formatter, use eslint instead - "prettier.enable": false, - "editor.formatOnSave": false, + "cucumber.features": [ + "e2e/features/**/*.feature", + ], + "cucumber.glue": [ + "e2e/features/**/*.ts", + ], + + "tailwindCSS.experimental.configFile": "web/app/styles/globals.css", // Auto fix "editor.codeActionsOnSave": { "source.fixAll.eslint": "explicit", - "source.organizeImports": "never" }, // Silent the stylistic rules in your IDE, but still auto fix them diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 775401bfa5c..d7f007af679 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -97,8 +97,3 @@ Feel free to reach out if you encounter any issues during the setup process. ## Getting Help If you ever get stuck or get a burning question while contributing, simply shoot your queries our way via the related GitHub issue, or hop onto our [Discord](https://discord.gg/8Tpq4AcN9c) for a quick chat. - -## Automated Agent Contributions - -> [!NOTE] -> If you are an automated agent, just add 🤖🤖🤖 to the end of the PR title to opt-in. diff --git a/api/.env.example b/api/.env.example index 40055e3704f..3e772be5b96 100644 --- a/api/.env.example +++ b/api/.env.example @@ -33,6 +33,9 @@ TRIGGER_URL=http://localhost:5001 # The time in seconds after the signature is rejected FILES_ACCESS_TIMEOUT=300 +# Collaboration mode toggle +ENABLE_COLLABORATION_MODE=false + # Access token expiration time in minutes ACCESS_TOKEN_EXPIRE_MINUTES=60 @@ -57,6 +60,9 @@ REDIS_SSL_CERTFILE= REDIS_SSL_KEYFILE= # Path to client private key file for SSL authentication REDIS_DB=0 +# Optional global prefix for Redis keys, topics, streams, and Celery Redis transport artifacts. +# Leave empty to preserve current unprefixed behavior. +REDIS_KEY_PREFIX= # redis Sentinel configuration. REDIS_USE_SENTINEL=false @@ -71,6 +77,13 @@ REDIS_USE_CLUSTERS=false REDIS_CLUSTERS= REDIS_CLUSTERS_PASSWORD= +REDIS_RETRY_RETRIES=3 +REDIS_RETRY_BACKOFF_BASE=1.0 +REDIS_RETRY_BACKOFF_CAP=10.0 +REDIS_SOCKET_TIMEOUT=5.0 +REDIS_SOCKET_CONNECT_TIMEOUT=5.0 +REDIS_HEALTH_CHECK_INTERVAL=30 + # celery configuration CELERY_BROKER_URL=redis://:difyai123456@localhost:${REDIS_PORT}/1 CELERY_BACKEND=redis @@ -102,6 +115,7 @@ S3_BUCKET_NAME=your-bucket-name S3_ACCESS_KEY=your-access-key S3_SECRET_KEY=your-secret-key S3_REGION=your-region +S3_ADDRESS_STYLE=auto # Workflow run and Conversation archive storage (S3-compatible) ARCHIVE_STORAGE_ENABLED=false diff --git a/api/.ruff.toml b/api/.ruff.toml index 2a825f1ef05..dd78024a023 100644 --- a/api/.ruff.toml +++ b/api/.ruff.toml @@ -69,8 +69,6 @@ ignore = [ "FURB152", # math-constant "UP007", # non-pep604-annotation "UP032", # f-string - "UP045", # non-pep604-annotation-optional - "B005", # strip-with-multi-characters "B006", # mutable-argument-default "B007", # unused-loop-control-variable "B026", # star-arg-unpacking-after-keyword-arg @@ -84,7 +82,6 @@ ignore = [ "SIM102", # collapsible-if "SIM103", # needless-bool "SIM105", # suppressible-exception - "SIM107", # return-in-try-except-finally "SIM108", # if-else-block-instead-of-if-exp "SIM113", # enumerate-for-loop "SIM117", # multiple-with-statements @@ -93,32 +90,22 @@ ignore = [ ] [lint.per-file-ignores] -"__init__.py" = [ - "F401", # unused-import - "F811", # redefined-while-unused -] "configs/*" = [ "N802", # invalid-function-name ] -"graphon/model_runtime/callbacks/base_callback.py" = ["T201"] -"core/workflow/callbacks/workflow_logging_callback.py" = ["T201"] "libs/gmpy2_pkcs10aep_cipher.py" = [ "N803", # invalid-argument-name ] "tests/*" = [ - "F811", # redefined-while-unused "T201", # allow print in tests, "S110", # allow ignoring exceptions in tests code (currently) - ] -"controllers/console/explore/trial.py" = ["TID251"] -"controllers/console/human_input_form.py" = ["TID251"] -"controllers/web/human_input_form.py" = ["TID251"] - -[lint.flake8-tidy-imports] [lint.flake8-tidy-imports.banned-api."flask_restx.reqparse"] msg = "Use Pydantic payload/query models instead of reqparse." [lint.flake8-tidy-imports.banned-api."flask_restx.reqparse.RequestParser"] msg = "Use Pydantic payload/query models instead of reqparse." + +[lint.isort] +known-first-party = ["graphon"] \ No newline at end of file diff --git a/api/.vscode/launch.json.example b/api/.vscode/launch.json.example index 6bdfa2c0399..1001559176e 100644 --- a/api/.vscode/launch.json.example +++ b/api/.vscode/launch.json.example @@ -3,29 +3,21 @@ "compounds": [ { "name": "Launch Flask and Celery", - "configurations": ["Python: Flask", "Python: Celery"] + "configurations": ["Python: API (gevent)", "Python: Celery"] } ], "configurations": [ { - "name": "Python: Flask", - "consoleName": "Flask", + "name": "Python: API (gevent)", + "consoleName": "API", "type": "debugpy", "request": "launch", "python": "${workspaceFolder}/.venv/bin/python", "cwd": "${workspaceFolder}", "envFile": ".env", - "module": "flask", + "program": "${workspaceFolder}/app.py", "justMyCode": true, - "jinja": true, - "env": { - "FLASK_APP": "app.py", - "GEVENT_SUPPORT": "True" - }, - "args": [ - "run", - "--port=5001" - ] + "jinja": true }, { "name": "Python: Celery", diff --git a/api/Dockerfile b/api/Dockerfile index 7e0a4399545..6098652573c 100644 --- a/api/Dockerfile +++ b/api/Dockerfile @@ -21,8 +21,9 @@ RUN apt-get update \ # for building gmpy2 libmpfr-dev libmpc-dev -# Install Python dependencies +# Install Python dependencies (workspace members under providers/vdb/) COPY pyproject.toml uv.lock ./ +COPY providers ./providers RUN uv sync --locked --no-dev # production stage diff --git a/api/app.py b/api/app.py index c018c8a0457..e53b037be50 100644 --- a/api/app.py +++ b/api/app.py @@ -1,5 +1,6 @@ from __future__ import annotations +import logging import sys from typing import TYPE_CHECKING, cast @@ -9,17 +10,35 @@ if TYPE_CHECKING: celery: Celery +HOST = "0.0.0.0" +PORT = 5001 +logger = logging.getLogger(__name__) + + def is_db_command() -> bool: if len(sys.argv) > 1 and sys.argv[0].endswith("flask") and sys.argv[1] == "db": return True return False +def log_startup_banner(host: str, port: int) -> None: + debugger_attached = sys.gettrace() is not None + logger.info("Serving Dify API via gevent WebSocket server") + logger.info("Bound to http://%s:%s", host, port) + logger.info("Debugger attached: %s", "on" if debugger_attached else "off") + logger.info("Press CTRL+C to quit") + + # create app +flask_app = None +socketio_app = None + if is_db_command(): from app_factory import create_migrations_app app = create_migrations_app() + socketio_app = app + flask_app = app else: # Gunicorn and Celery handle monkey patching automatically in production by # specifying the `gevent` worker class. Manual monkey patching is not required here. @@ -30,8 +49,14 @@ else: from app_factory import create_app - app = create_app() + socketio_app, flask_app = create_app() + app = flask_app celery = cast("Celery", app.extensions["celery"]) if __name__ == "__main__": - app.run(host="0.0.0.0", port=5001) + from gevent import pywsgi + from geventwebsocket.handler import WebSocketHandler # type: ignore[reportMissingTypeStubs] + + log_startup_banner(HOST, PORT) + server = pywsgi.WSGIServer((HOST, PORT), socketio_app, handler_class=WebSocketHandler) + server.serve_forever() diff --git a/api/app_factory.py b/api/app_factory.py index 76838f9925d..48e50ceae9a 100644 --- a/api/app_factory.py +++ b/api/app_factory.py @@ -1,6 +1,7 @@ import logging import time +import socketio # type: ignore[reportMissingTypeStubs] from flask import request from opentelemetry.trace import get_current_span from opentelemetry.trace.span import INVALID_SPAN_ID, INVALID_TRACE_ID @@ -10,6 +11,7 @@ from contexts.wrapper import RecyclableContextVar from controllers.console.error import UnauthorizedAndForceLogout from core.logging.context import init_request_context from dify_app import DifyApp +from extensions.ext_socketio import sio from services.enterprise.enterprise_service import EnterpriseService from services.feature_service import LicenseStatus @@ -122,14 +124,18 @@ def create_flask_app_with_configs() -> DifyApp: return dify_app -def create_app() -> DifyApp: +def create_app() -> tuple[socketio.WSGIApp, DifyApp]: start_time = time.perf_counter() app = create_flask_app_with_configs() initialize_extensions(app) + + sio.app = app + socketio_app = socketio.WSGIApp(sio, app) + end_time = time.perf_counter() if dify_config.DEBUG: logger.info("Finished create_app (%s ms)", round((end_time - start_time) * 1000, 2)) - return app + return socketio_app, app def initialize_extensions(app: DifyApp): diff --git a/api/celery_healthcheck.py b/api/celery_healthcheck.py new file mode 100644 index 00000000000..23d856d7d04 --- /dev/null +++ b/api/celery_healthcheck.py @@ -0,0 +1,18 @@ +# This module provides a lightweight Celery instance for use in Docker health checks. +# Unlike celery_entrypoint.py, this does NOT import app.py and therefore avoids +# initializing all Flask extensions (DB, Redis, storage, blueprints, etc.). +# Using this module keeps the health check fast and low-cost. +from celery import Celery + +from configs import dify_config +from extensions.ext_celery import get_celery_broker_transport_options, get_celery_ssl_options + +celery = Celery(broker=dify_config.CELERY_BROKER_URL) + +broker_transport_options = get_celery_broker_transport_options() +if broker_transport_options: + celery.conf.update(broker_transport_options=broker_transport_options) + +ssl_options = get_celery_ssl_options() +if ssl_options: + celery.conf.update(broker_use_ssl=ssl_options) diff --git a/api/commands/account.py b/api/commands/account.py index 84af7a5ae62..761323a73d0 100644 --- a/api/commands/account.py +++ b/api/commands/account.py @@ -2,7 +2,7 @@ import base64 import secrets import click -from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm import Session from constants.languages import languages from extensions.ext_database import db @@ -25,30 +25,32 @@ def reset_password(email, new_password, password_confirm): return normalized_email = email.strip().lower() - with sessionmaker(db.engine, expire_on_commit=False).begin() as session: - account = AccountService.get_account_by_email_with_case_fallback(email.strip(), session=session) + account = AccountService.get_account_by_email_with_case_fallback(email.strip()) - if not account: - click.echo(click.style(f"Account not found for email: {email}", fg="red")) - return + if not account: + click.echo(click.style(f"Account not found for email: {email}", fg="red")) + return - try: - valid_password(new_password) - except: - click.echo(click.style(f"Invalid password. Must match {password_pattern}", fg="red")) - return + try: + valid_password(new_password) + except: + click.echo(click.style(f"Invalid password. Must match {password_pattern}", fg="red")) + return - # generate password salt - salt = secrets.token_bytes(16) - base64_salt = base64.b64encode(salt).decode() + # generate password salt + salt = secrets.token_bytes(16) + base64_salt = base64.b64encode(salt).decode() - # encrypt password with salt - password_hashed = hash_password(new_password, salt) - base64_password_hashed = base64.b64encode(password_hashed).decode() + # encrypt password with salt + password_hashed = hash_password(new_password, salt) + base64_password_hashed = base64.b64encode(password_hashed).decode() + with Session(db.engine) as session: + account = session.merge(account) account.password = base64_password_hashed account.password_salt = base64_salt - AccountService.reset_login_error_rate_limit(normalized_email) - click.echo(click.style("Password reset successfully.", fg="green")) + session.commit() + AccountService.reset_login_error_rate_limit(normalized_email) + click.echo(click.style("Password reset successfully.", fg="green")) @click.command("reset-email", help="Reset the account email.") @@ -65,21 +67,23 @@ def reset_email(email, new_email, email_confirm): return normalized_new_email = new_email.strip().lower() - with sessionmaker(db.engine, expire_on_commit=False).begin() as session: - account = AccountService.get_account_by_email_with_case_fallback(email.strip(), session=session) + account = AccountService.get_account_by_email_with_case_fallback(email.strip()) - if not account: - click.echo(click.style(f"Account not found for email: {email}", fg="red")) - return + if not account: + click.echo(click.style(f"Account not found for email: {email}", fg="red")) + return - try: - email_validate(normalized_new_email) - except: - click.echo(click.style(f"Invalid email: {new_email}", fg="red")) - return + try: + email_validate(normalized_new_email) + except: + click.echo(click.style(f"Invalid email: {new_email}", fg="red")) + return + with Session(db.engine) as session: + account = session.merge(account) account.email = normalized_new_email - click.echo(click.style("Email updated successfully.", fg="green")) + session.commit() + click.echo(click.style("Email updated successfully.", fg="green")) @click.command("create-tenant", help="Create account and tenant.") diff --git a/api/commands/retention.py b/api/commands/retention.py index 82a77ea77a7..657a2a2e839 100644 --- a/api/commands/retention.py +++ b/api/commands/retention.py @@ -1,7 +1,7 @@ import datetime import logging import time -from typing import Any +from typing import TypedDict import click import sqlalchemy as sa @@ -503,7 +503,19 @@ def _find_orphaned_draft_variables(batch_size: int = 1000) -> list[str]: return [row[0] for row in result] -def _count_orphaned_draft_variables() -> dict[str, Any]: +class _AppOrphanCounts(TypedDict): + variables: int + files: int + + +class OrphanedDraftVariableStatsDict(TypedDict): + total_orphaned_variables: int + total_orphaned_files: int + orphaned_app_count: int + orphaned_by_app: dict[str, _AppOrphanCounts] + + +def _count_orphaned_draft_variables() -> OrphanedDraftVariableStatsDict: """ Count orphaned draft variables by app, including associated file counts. @@ -526,7 +538,7 @@ def _count_orphaned_draft_variables() -> dict[str, Any]: with db.engine.connect() as conn: result = conn.execute(sa.text(variables_query)) - orphaned_by_app = {} + orphaned_by_app: dict[str, _AppOrphanCounts] = {} total_files = 0 for row in result: diff --git a/api/commands/vector.py b/api/commands/vector.py index cb7eb7c4522..956f20d6bb2 100644 --- a/api/commands/vector.py +++ b/api/commands/vector.py @@ -341,11 +341,10 @@ def add_qdrant_index(field: str): click.echo(click.style("No dataset collection bindings found.", fg="red")) return import qdrant_client + from dify_vdb_qdrant.qdrant_vector import PathQdrantParams, QdrantConfig from qdrant_client.http.exceptions import UnexpectedResponse from qdrant_client.http.models import PayloadSchemaType - from core.rag.datasource.vdb.qdrant.qdrant_vector import PathQdrantParams, QdrantConfig - for binding in bindings: if dify_config.QDRANT_URL is None: raise ValueError("Qdrant URL is required.") diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index 7c02606affd..038f9612767 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -1284,6 +1284,13 @@ class PositionConfig(BaseSettings): return {item.strip() for item in self.POSITION_TOOL_EXCLUDES.split(",") if item.strip() != ""} +class CollaborationConfig(BaseSettings): + ENABLE_COLLABORATION_MODE: bool = Field( + description="Whether to enable collaboration mode features across the workspace", + default=False, + ) + + class LoginConfig(BaseSettings): ENABLE_EMAIL_CODE_LOGIN: bool = Field( description="whether to enable email code login", @@ -1409,6 +1416,7 @@ class FeatureConfig( WorkflowConfig, WorkflowNodeExecutionConfig, WorkspaceConfig, + CollaborationConfig, LoginConfig, AccountConfig, SwaggerUIConfig, diff --git a/api/configs/middleware/__init__.py b/api/configs/middleware/__init__.py index 15ac8bf0bfd..c392b8840f2 100644 --- a/api/configs/middleware/__init__.py +++ b/api/configs/middleware/__init__.py @@ -1,5 +1,5 @@ import os -from typing import Any, Literal +from typing import Any, Literal, TypedDict from urllib.parse import parse_qsl, quote_plus from pydantic import Field, NonNegativeFloat, NonNegativeInt, PositiveFloat, PositiveInt, computed_field @@ -107,6 +107,17 @@ class KeywordStoreConfig(BaseSettings): ) +class SQLAlchemyEngineOptionsDict(TypedDict): + pool_size: int + max_overflow: int + pool_recycle: int + pool_pre_ping: bool + connect_args: dict[str, str] + pool_use_lifo: bool + pool_reset_on_return: None + pool_timeout: int + + class DatabaseConfig(BaseSettings): # Database type selector DB_TYPE: Literal["postgresql", "mysql", "oceanbase", "seekdb"] = Field( @@ -149,6 +160,16 @@ class DatabaseConfig(BaseSettings): default="", ) + DB_SESSION_TIMEZONE_OVERRIDE: str = Field( + description=( + "PostgreSQL session timezone override injected via startup options." + " Default is 'UTC' for out-of-the-box consistency." + " Set to empty string to disable app-level timezone injection, for example when using RDS Proxy" + " together with a database-side default timezone." + ), + default="UTC", + ) + @computed_field # type: ignore[prop-decorator] @property def SQLALCHEMY_DATABASE_URI_SCHEME(self) -> str: @@ -209,21 +230,22 @@ class DatabaseConfig(BaseSettings): @computed_field # type: ignore[prop-decorator] @property - def SQLALCHEMY_ENGINE_OPTIONS(self) -> dict[str, Any]: + def SQLALCHEMY_ENGINE_OPTIONS(self) -> SQLAlchemyEngineOptionsDict: # Parse DB_EXTRAS for 'options' db_extras_dict = dict(parse_qsl(self.DB_EXTRAS)) options = db_extras_dict.get("options", "") - connect_args = {} + connect_args: dict[str, str] = {} # Use the dynamic SQLALCHEMY_DATABASE_URI_SCHEME property if self.SQLALCHEMY_DATABASE_URI_SCHEME.startswith("postgresql"): - timezone_opt = "-c timezone=UTC" - if options: - merged_options = f"{options} {timezone_opt}" - else: - merged_options = timezone_opt - connect_args = {"options": merged_options} + merged_options = options.strip() + session_timezone_override = self.DB_SESSION_TIMEZONE_OVERRIDE.strip() + if session_timezone_override: + timezone_opt = f"-c timezone={session_timezone_override}" + merged_options = f"{merged_options} {timezone_opt}".strip() if merged_options else timezone_opt + if merged_options: + connect_args = {"options": merged_options} - return { + result: SQLAlchemyEngineOptionsDict = { "pool_size": self.SQLALCHEMY_POOL_SIZE, "max_overflow": self.SQLALCHEMY_MAX_OVERFLOW, "pool_recycle": self.SQLALCHEMY_POOL_RECYCLE, @@ -233,6 +255,7 @@ class DatabaseConfig(BaseSettings): "pool_reset_on_return": None, "pool_timeout": self.SQLALCHEMY_POOL_TIMEOUT, } + return result class CeleryConfig(DatabaseConfig): diff --git a/api/configs/middleware/cache/redis_config.py b/api/configs/middleware/cache/redis_config.py index 3b912075459..2def0a0d4ef 100644 --- a/api/configs/middleware/cache/redis_config.py +++ b/api/configs/middleware/cache/redis_config.py @@ -32,6 +32,11 @@ class RedisConfig(BaseSettings): default=0, ) + REDIS_KEY_PREFIX: str = Field( + description="Optional global prefix for Redis keys, topics, and transport artifacts", + default="", + ) + REDIS_USE_SSL: bool = Field( description="Enable SSL/TLS for the Redis connection", default=False, @@ -117,6 +122,37 @@ class RedisConfig(BaseSettings): default=None, ) + REDIS_RETRY_RETRIES: NonNegativeInt = Field( + description="Maximum number of retries per Redis command on " + "transient failures (ConnectionError, TimeoutError, socket.timeout)", + default=3, + ) + + REDIS_RETRY_BACKOFF_BASE: PositiveFloat = Field( + description="Base delay in seconds for exponential backoff between retries", + default=1.0, + ) + + REDIS_RETRY_BACKOFF_CAP: PositiveFloat = Field( + description="Maximum backoff delay in seconds between retries", + default=10.0, + ) + + REDIS_SOCKET_TIMEOUT: PositiveFloat | None = Field( + description="Socket timeout in seconds for Redis read/write operations", + default=5.0, + ) + + REDIS_SOCKET_CONNECT_TIMEOUT: PositiveFloat | None = Field( + description="Socket timeout in seconds for Redis connection establishment", + default=5.0, + ) + + REDIS_HEALTH_CHECK_INTERVAL: NonNegativeInt = Field( + description="Interval in seconds between Redis connection health checks (0 to disable)", + default=30, + ) + @field_validator("REDIS_MAX_CONNECTIONS", mode="before") @classmethod def _empty_string_to_none_for_max_conns(cls, v): diff --git a/api/configs/middleware/vdb/hologres_config.py b/api/configs/middleware/vdb/hologres_config.py index 9812cce268c..788b3cfb783 100644 --- a/api/configs/middleware/vdb/hologres_config.py +++ b/api/configs/middleware/vdb/hologres_config.py @@ -1,4 +1,3 @@ -from holo_search_sdk.types import BaseQuantizationType, DistanceType, TokenizerType from pydantic import Field from pydantic_settings import BaseSettings @@ -42,17 +41,17 @@ class HologresConfig(BaseSettings): default="public", ) - HOLOGRES_TOKENIZER: TokenizerType = Field( + HOLOGRES_TOKENIZER: str = Field( description="Tokenizer for full-text search index (e.g., 'jieba', 'ik', 'standard', 'simple').", default="jieba", ) - HOLOGRES_DISTANCE_METHOD: DistanceType = Field( + HOLOGRES_DISTANCE_METHOD: str = Field( description="Distance method for vector index (e.g., 'Cosine', 'Euclidean', 'InnerProduct').", default="Cosine", ) - HOLOGRES_BASE_QUANTIZATION_TYPE: BaseQuantizationType = Field( + HOLOGRES_BASE_QUANTIZATION_TYPE: str = Field( description="Base quantization type for vector index (e.g., 'rabitq', 'sq8', 'fp16', 'fp32').", default="rabitq", ) diff --git a/api/configs/middleware/vdb/iris_config.py b/api/configs/middleware/vdb/iris_config.py index c532d191c3b..f5993dd8f83 100644 --- a/api/configs/middleware/vdb/iris_config.py +++ b/api/configs/middleware/vdb/iris_config.py @@ -1,5 +1,7 @@ """Configuration for InterSystems IRIS vector database.""" +from typing import Any + from pydantic import Field, PositiveInt, model_validator from pydantic_settings import BaseSettings @@ -64,7 +66,7 @@ class IrisVectorConfig(BaseSettings): @model_validator(mode="before") @classmethod - def validate_config(cls, values: dict) -> dict: + def validate_config(cls, values: dict[str, Any]) -> dict[str, Any]: """Validate IRIS configuration values. Args: diff --git a/api/constants/dsl_version.py b/api/constants/dsl_version.py new file mode 100644 index 00000000000..b0fbe0075c7 --- /dev/null +++ b/api/constants/dsl_version.py @@ -0,0 +1 @@ +CURRENT_APP_DSL_VERSION = "0.6.0" diff --git a/api/controllers/common/controller_schemas.py b/api/controllers/common/controller_schemas.py new file mode 100644 index 00000000000..c12d5764737 --- /dev/null +++ b/api/controllers/common/controller_schemas.py @@ -0,0 +1,104 @@ +from typing import Any, Literal +from uuid import UUID + +from pydantic import BaseModel, Field, model_validator + +from libs.helper import UUIDStrOrEmpty + +# --- Conversation schemas --- + + +class ConversationRenamePayload(BaseModel): + name: str | None = None + auto_generate: bool = False + + @model_validator(mode="after") + def validate_name_requirement(self): + if not self.auto_generate: + if self.name is None or not self.name.strip(): + raise ValueError("name is required when auto_generate is false") + return self + + +# --- Message schemas --- + + +class MessageListQuery(BaseModel): + conversation_id: UUIDStrOrEmpty = Field(description="Conversation UUID") + first_id: UUIDStrOrEmpty | None = Field(default=None, description="First message ID for pagination") + limit: int = Field(default=20, ge=1, le=100, description="Number of messages to return (1-100)") + + +class MessageFeedbackPayload(BaseModel): + rating: Literal["like", "dislike"] | None = None + content: str | None = None + + +# --- Saved message schemas --- + + +class SavedMessageListQuery(BaseModel): + last_id: UUIDStrOrEmpty | None = None + limit: int = Field(default=20, ge=1, le=100) + + +class SavedMessageCreatePayload(BaseModel): + message_id: UUIDStrOrEmpty + + +# --- Workflow schemas --- + + +class DefaultBlockConfigQuery(BaseModel): + q: str | None = None + + +class WorkflowListQuery(BaseModel): + page: int = Field(default=1, ge=1, le=99999) + limit: int = Field(default=10, ge=1, le=100) + user_id: str | None = None + named_only: bool = False + + +class WorkflowRunPayload(BaseModel): + inputs: dict[str, Any] + files: list[dict[str, Any]] | None = None + + +class WorkflowUpdatePayload(BaseModel): + marked_name: str | None = Field(default=None, max_length=20) + marked_comment: str | None = Field(default=None, max_length=100) + + +# --- Dataset schemas --- + + +DOCUMENT_BATCH_DOWNLOAD_ZIP_MAX_DOCS = 100 + + +class ChildChunkCreatePayload(BaseModel): + content: str + + +class ChildChunkUpdatePayload(BaseModel): + content: str + + +class DocumentBatchDownloadZipPayload(BaseModel): + """Request payload for bulk downloading documents as a zip archive.""" + + document_ids: list[UUID] = Field(..., min_length=1, max_length=DOCUMENT_BATCH_DOWNLOAD_ZIP_MAX_DOCS) + + +class MetadataUpdatePayload(BaseModel): + name: str + + +# --- Audio schemas --- + + +class TextToAudioPayload(BaseModel): + message_id: str | None = Field(default=None, description="Message ID") + voice: str | None = Field(default=None, description="Voice to use for TTS") + text: str | None = Field(default=None, description="Text to convert to audio") + streaming: bool | None = Field(default=None, description="Enable streaming response") diff --git a/api/controllers/common/fields.py b/api/controllers/common/fields.py index 4fe3fc90621..8e665c13865 100644 --- a/api/controllers/common/fields.py +++ b/api/controllers/common/fields.py @@ -2,9 +2,9 @@ from __future__ import annotations from typing import Any -from graphon.file import helpers as file_helpers from pydantic import BaseModel, ConfigDict, computed_field +from graphon.file import helpers as file_helpers from models.model import IconType type JSONValue = str | int | float | bool | None | dict[str, Any] | list[Any] diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index d624b10b22b..980e828945f 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -65,6 +65,7 @@ from .app import ( statistic, workflow, workflow_app_log, + workflow_comment, workflow_draft_variable, workflow_run, workflow_statistic, @@ -116,6 +117,7 @@ from .explore import ( saved_message, trial, ) +from .socketio import workflow as socketio_workflow # pyright: ignore[reportUnusedImport] # Import tag controllers from .tag import tags @@ -201,6 +203,7 @@ __all__ = [ "saved_message", "setup", "site", + "socketio_workflow", "spec", "statistic", "tags", @@ -211,6 +214,7 @@ __all__ = [ "website", "workflow", "workflow_app_log", + "workflow_comment", "workflow_draft_variable", "workflow_run", "workflow_statistic", diff --git a/api/controllers/console/admin.py b/api/controllers/console/admin.py index 9b8408980de..dce394be97f 100644 --- a/api/controllers/console/admin.py +++ b/api/controllers/console/admin.py @@ -2,6 +2,7 @@ import csv import io from collections.abc import Callable from functools import wraps +from typing import cast from flask import request from flask_restx import Resource @@ -17,7 +18,7 @@ from core.db.session_factory import session_factory from extensions.ext_database import db from libs.token import extract_access_token from models.model import App, ExporleBanner, InstalledApp, RecommendedApp, TrialApp -from services.billing_service import BillingService +from services.billing_service import BillingService, LangContentDict DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" @@ -328,7 +329,7 @@ class UpsertNotificationApi(Resource): def post(self): payload = UpsertNotificationPayload.model_validate(console_ns.payload) result = BillingService.upsert_notification( - contents=[c.model_dump() for c in payload.contents], + contents=[cast(LangContentDict, c.model_dump()) for c in payload.contents], frequency=payload.frequency, status=payload.status, notification_id=payload.notification_id, diff --git a/api/controllers/console/apikey.py b/api/controllers/console/apikey.py index 772bb9d0f19..b03d9b4a4cb 100644 --- a/api/controllers/console/apikey.py +++ b/api/controllers/console/apikey.py @@ -1,12 +1,16 @@ +from datetime import datetime + import flask_restx -from flask_restx import Resource, fields, marshal_with +from flask_restx import Resource from flask_restx._http import HTTPStatus +from pydantic import field_validator from sqlalchemy import delete, func, select from sqlalchemy.orm import sessionmaker from werkzeug.exceptions import Forbidden +from controllers.common.schema import register_schema_models from extensions.ext_database import db -from libs.helper import TimestampField +from fields.base import ResponseModel from libs.login import current_account_with_tenant, login_required from models.dataset import Dataset from models.enums import ApiTokenType @@ -16,21 +20,31 @@ from services.api_token_service import ApiTokenCache from . import console_ns from .wraps import account_initialization_required, edit_permission_required, setup_required -api_key_fields = { - "id": fields.String, - "type": fields.String, - "token": fields.String, - "last_used_at": TimestampField, - "created_at": TimestampField, -} -api_key_item_model = console_ns.model("ApiKeyItem", api_key_fields) +def _to_timestamp(value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return int(value.timestamp()) + return value -api_key_list = {"data": fields.List(fields.Nested(api_key_item_model), attribute="items")} -api_key_list_model = console_ns.model( - "ApiKeyList", {"data": fields.List(fields.Nested(api_key_item_model), attribute="items")} -) +class ApiKeyItem(ResponseModel): + id: str + type: str + token: str + last_used_at: int | None = None + created_at: int | None = None + + @field_validator("last_used_at", "created_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +class ApiKeyList(ResponseModel): + data: list[ApiKeyItem] + + +register_schema_models(console_ns, ApiKeyItem, ApiKeyList) def _get_resource(resource_id, tenant_id, resource_model): @@ -54,7 +68,6 @@ class BaseApiKeyListResource(Resource): token_prefix: str | None = None max_keys = 10 - @marshal_with(api_key_list_model) def get(self, resource_id): assert self.resource_id_field is not None, "resource_id_field must be set" resource_id = str(resource_id) @@ -66,9 +79,8 @@ class BaseApiKeyListResource(Resource): ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id ) ).all() - return {"items": keys} + return ApiKeyList.model_validate({"data": keys}, from_attributes=True).model_dump(mode="json") - @marshal_with(api_key_item_model) @edit_permission_required def post(self, resource_id): assert self.resource_id_field is not None, "resource_id_field must be set" @@ -100,7 +112,7 @@ class BaseApiKeyListResource(Resource): api_token.type = self.resource_type db.session.add(api_token) db.session.commit() - return api_token, 201 + return ApiKeyItem.model_validate(api_token, from_attributes=True).model_dump(mode="json"), 201 class BaseApiKeyResource(Resource): @@ -147,7 +159,7 @@ class AppApiKeyListResource(BaseApiKeyListResource): @console_ns.doc("get_app_api_keys") @console_ns.doc(description="Get all API keys for an app") @console_ns.doc(params={"resource_id": "App ID"}) - @console_ns.response(200, "Success", api_key_list_model) + @console_ns.response(200, "API keys retrieved successfully", console_ns.models[ApiKeyList.__name__]) def get(self, resource_id): # type: ignore """Get all API keys for an app""" return super().get(resource_id) @@ -155,7 +167,7 @@ class AppApiKeyListResource(BaseApiKeyListResource): @console_ns.doc("create_app_api_key") @console_ns.doc(description="Create a new API key for an app") @console_ns.doc(params={"resource_id": "App ID"}) - @console_ns.response(201, "API key created successfully", api_key_item_model) + @console_ns.response(201, "API key created successfully", console_ns.models[ApiKeyItem.__name__]) @console_ns.response(400, "Maximum keys exceeded") def post(self, resource_id): # type: ignore """Create a new API key for an app""" @@ -187,7 +199,7 @@ class DatasetApiKeyListResource(BaseApiKeyListResource): @console_ns.doc("get_dataset_api_keys") @console_ns.doc(description="Get all API keys for a dataset") @console_ns.doc(params={"resource_id": "Dataset ID"}) - @console_ns.response(200, "Success", api_key_list_model) + @console_ns.response(200, "API keys retrieved successfully", console_ns.models[ApiKeyList.__name__]) def get(self, resource_id): # type: ignore """Get all API keys for a dataset""" return super().get(resource_id) @@ -195,7 +207,7 @@ class DatasetApiKeyListResource(BaseApiKeyListResource): @console_ns.doc("create_dataset_api_key") @console_ns.doc(description="Create a new API key for a dataset") @console_ns.doc(params={"resource_id": "Dataset ID"}) - @console_ns.response(201, "API key created successfully", api_key_item_model) + @console_ns.response(201, "API key created successfully", console_ns.models[ApiKeyItem.__name__]) @console_ns.response(400, "Maximum keys exceeded") def post(self, resource_id): # type: ignore """Create a new API key for a dataset""" diff --git a/api/controllers/console/app/advanced_prompt_template.py b/api/controllers/console/app/advanced_prompt_template.py index 3bd61feb44c..ed66da1be5a 100644 --- a/api/controllers/console/app/advanced_prompt_template.py +++ b/api/controllers/console/app/advanced_prompt_template.py @@ -5,7 +5,7 @@ from pydantic import BaseModel, Field from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, setup_required from libs.login import login_required -from services.advanced_prompt_template_service import AdvancedPromptTemplateService +from services.advanced_prompt_template_service import AdvancedPromptTemplateArgs, AdvancedPromptTemplateService class AdvancedPromptTemplateQuery(BaseModel): @@ -35,5 +35,10 @@ class AdvancedPromptTemplateList(Resource): @account_initialization_required def get(self): args = AdvancedPromptTemplateQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore - - return AdvancedPromptTemplateService.get_prompt(args.model_dump()) + prompt_args: AdvancedPromptTemplateArgs = { + "app_mode": args.app_mode, + "model_mode": args.model_mode, + "model_name": args.model_name, + "has_context": args.has_context, + } + return AdvancedPromptTemplateService.get_prompt(prompt_args) diff --git a/api/controllers/console/app/annotation.py b/api/controllers/console/app/annotation.py index 9931bb5dd7d..528785931ea 100644 --- a/api/controllers/console/app/annotation.py +++ b/api/controllers/console/app/annotation.py @@ -25,7 +25,13 @@ from fields.annotation_fields import ( ) from libs.helper import uuid_value from libs.login import login_required -from services.annotation_service import AppAnnotationService +from services.annotation_service import ( + AppAnnotationService, + EnableAnnotationArgs, + UpdateAnnotationArgs, + UpdateAnnotationSettingArgs, + UpsertAnnotationArgs, +) DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" @@ -120,7 +126,12 @@ class AnnotationReplyActionApi(Resource): args = AnnotationReplyPayload.model_validate(console_ns.payload) match action: case "enable": - result = AppAnnotationService.enable_app_annotation(args.model_dump(), app_id) + enable_args: EnableAnnotationArgs = { + "score_threshold": args.score_threshold, + "embedding_provider_name": args.embedding_provider_name, + "embedding_model_name": args.embedding_model_name, + } + result = AppAnnotationService.enable_app_annotation(enable_args, app_id) case "disable": result = AppAnnotationService.disable_app_annotation(app_id) return result, 200 @@ -161,7 +172,8 @@ class AppAnnotationSettingUpdateApi(Resource): args = AnnotationSettingUpdatePayload.model_validate(console_ns.payload) - result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, args.model_dump()) + setting_args: UpdateAnnotationSettingArgs = {"score_threshold": args.score_threshold} + result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, setting_args) return result, 200 @@ -237,8 +249,16 @@ class AnnotationApi(Resource): def post(self, app_id): app_id = str(app_id) args = CreateAnnotationPayload.model_validate(console_ns.payload) - data = args.model_dump(exclude_none=True) - annotation = AppAnnotationService.up_insert_app_annotation_from_message(data, app_id) + upsert_args: UpsertAnnotationArgs = {} + if args.answer is not None: + upsert_args["answer"] = args.answer + if args.content is not None: + upsert_args["content"] = args.content + if args.message_id is not None: + upsert_args["message_id"] = args.message_id + if args.question is not None: + upsert_args["question"] = args.question + annotation = AppAnnotationService.up_insert_app_annotation_from_message(upsert_args, app_id) return Annotation.model_validate(annotation, from_attributes=True).model_dump(mode="json") @setup_required @@ -315,9 +335,12 @@ class AnnotationUpdateDeleteApi(Resource): app_id = str(app_id) annotation_id = str(annotation_id) args = UpdateAnnotationPayload.model_validate(console_ns.payload) - annotation = AppAnnotationService.update_app_annotation_directly( - args.model_dump(exclude_none=True), app_id, annotation_id - ) + update_args: UpdateAnnotationArgs = {} + if args.answer is not None: + update_args["answer"] = args.answer + if args.question is not None: + update_args["question"] = args.question + annotation = AppAnnotationService.update_app_annotation_directly(update_args, app_id, annotation_id) return Annotation.model_validate(annotation, from_attributes=True).model_dump(mode="json") @setup_required diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index c67ca57c63a..9102983d869 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -5,11 +5,9 @@ from typing import Any, Literal from flask import request from flask_restx import Resource -from graphon.enums import WorkflowExecutionStatus -from graphon.file import helpers as file_helpers -from pydantic import AliasChoices, BaseModel, ConfigDict, Field, computed_field, field_validator +from pydantic import AliasChoices, BaseModel, Field, computed_field, field_validator from sqlalchemy import select -from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm import Session from werkzeug.exceptions import BadRequest from controllers.common.helpers import FileInfo @@ -26,25 +24,27 @@ from controllers.console.wraps import ( setup_required, ) from core.ops.ops_trace_manager import OpsTraceManager +from core.rag.entities import PreProcessingRule, Rule, Segmentation from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.trigger.constants import TRIGGER_NODE_TYPES from extensions.ext_database import db +from fields.base import ResponseModel +from graphon.enums import WorkflowExecutionStatus +from libs.helper import build_icon_url from libs.login import current_account_with_tenant, login_required from models import App, DatasetPermissionEnum, Workflow from models.model import IconType -from services.app_dsl_service import AppDslService, ImportMode +from services.app_dsl_service import AppDslService from services.app_service import AppService from services.enterprise.enterprise_service import EnterpriseService +from services.entities.dsl_entities import ImportMode, ImportStatus from services.entities.knowledge_entities.knowledge_entities import ( DataSource, InfoList, NotionIcon, NotionInfo, NotionPage, - PreProcessingRule, RerankingModel, - Rule, - Segmentation, WebsiteInfo, WeightKeywordSetting, WeightModel, @@ -129,6 +129,7 @@ class AppNamePayload(BaseModel): class AppIconPayload(BaseModel): icon: str | None = Field(default=None, description="Icon data") + icon_type: IconType | None = Field(default=None, description="Icon type") icon_background: str | None = Field(default=None, description="Icon background color") @@ -155,31 +156,12 @@ class AppTracePayload(BaseModel): type JSONValue = Any -class ResponseModel(BaseModel): - model_config = ConfigDict( - from_attributes=True, - extra="ignore", - populate_by_name=True, - serialize_by_alias=True, - protected_namespaces=(), - ) - - def _to_timestamp(value: datetime | int | None) -> int | None: if isinstance(value, datetime): return int(value.timestamp()) return value -def _build_icon_url(icon_type: str | IconType | None, icon: str | None) -> str | None: - if icon is None or icon_type is None: - return None - icon_type_value = icon_type.value if isinstance(icon_type, IconType) else str(icon_type) - if icon_type_value.lower() != IconType.IMAGE: - return None - return file_helpers.get_signed_file_url(icon) - - class Tag(ResponseModel): id: str name: str @@ -302,7 +284,7 @@ class Site(ResponseModel): @computed_field(return_type=str | None) # type: ignore @property def icon_url(self) -> str | None: - return _build_icon_url(self.icon_type, self.icon) + return build_icon_url(self.icon_type, self.icon) @field_validator("icon_type", mode="before") @classmethod @@ -352,7 +334,7 @@ class AppPartial(ResponseModel): @computed_field(return_type=str | None) # type: ignore @property def icon_url(self) -> str | None: - return _build_icon_url(self.icon_type, self.icon) + return build_icon_url(self.icon_type, self.icon) @field_validator("created_at", "updated_at", mode="before") @classmethod @@ -400,7 +382,7 @@ class AppDetailWithSite(AppDetail): @computed_field(return_type=str | None) # type: ignore @property def icon_url(self) -> str | None: - return _build_icon_url(self.icon_type, self.icon) + return build_icon_url(self.icon_type, self.icon) class AppPagination(ResponseModel): @@ -642,7 +624,7 @@ class AppCopyApi(Resource): args = CopyAppPayload.model_validate(console_ns.payload or {}) - with sessionmaker(db.engine, expire_on_commit=False).begin() as session: + with Session(db.engine, expire_on_commit=False) as session: import_service = AppDslService(session) yaml_content = import_service.export_dsl(app_model=app_model, include_secret=True) result = import_service.import_app( @@ -655,6 +637,13 @@ class AppCopyApi(Resource): icon=args.icon, icon_background=args.icon_background, ) + if result.status == ImportStatus.FAILED: + session.rollback() + return result.model_dump(mode="json"), 400 + if result.status == ImportStatus.PENDING: + session.rollback() + return result.model_dump(mode="json"), 202 + session.commit() # Inherit web app permission from original app if result.app_id and FeatureService.get_system_features().webapp_auth.enabled: @@ -741,7 +730,12 @@ class AppIconApi(Resource): args = AppIconPayload.model_validate(console_ns.payload or {}) app_service = AppService() - app_model = app_service.update_app_icon(app_model, args.icon or "", args.icon_background or "") + app_model = app_service.update_app_icon( + app_model, + args.icon or "", + args.icon_background or "", + args.icon_type, + ) response_model = AppDetail.model_validate(app_model, from_attributes=True) return response_model.model_dump(mode="json") diff --git a/api/controllers/console/app/app_import.py b/api/controllers/console/app/app_import.py index 16e1fa32455..e91dc9cfe51 100644 --- a/api/controllers/console/app/app_import.py +++ b/api/controllers/console/app/app_import.py @@ -1,7 +1,8 @@ -from flask_restx import Resource, fields, marshal_with +from flask_restx import Resource from pydantic import BaseModel, Field -from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm import Session +from controllers.common.schema import register_schema_models from controllers.console.app.wraps import get_app_model from controllers.console.wraps import ( account_initialization_required, @@ -10,34 +11,15 @@ from controllers.console.wraps import ( setup_required, ) from extensions.ext_database import db -from fields.app_fields import ( - app_import_check_dependencies_fields, - app_import_fields, - leaked_dependency_fields, -) from libs.login import current_account_with_tenant, login_required from models.model import App -from services.app_dsl_service import AppDslService, ImportStatus +from services.app_dsl_service import AppDslService, Import from services.enterprise.enterprise_service import EnterpriseService +from services.entities.dsl_entities import CheckDependenciesResult, ImportStatus from services.feature_service import FeatureService from .. import console_ns -# Register models for flask_restx to avoid dict type issues in Swagger -# Register base model first -leaked_dependency_model = console_ns.model("LeakedDependency", leaked_dependency_fields) - -app_import_model = console_ns.model("AppImport", app_import_fields) - -# For nested models, need to replace nested dict with registered model -app_import_check_dependencies_fields_copy = app_import_check_dependencies_fields.copy() -app_import_check_dependencies_fields_copy["leaked_dependencies"] = fields.List(fields.Nested(leaked_dependency_model)) -app_import_check_dependencies_model = console_ns.model( - "AppImportCheckDependencies", app_import_check_dependencies_fields_copy -) - -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" - class AppImportPayload(BaseModel): mode: str = Field(..., description="Import mode") @@ -51,18 +33,18 @@ class AppImportPayload(BaseModel): app_id: str | None = Field(None) -console_ns.schema_model( - AppImportPayload.__name__, AppImportPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) -) +register_schema_models(console_ns, AppImportPayload, Import, CheckDependenciesResult) @console_ns.route("/apps/imports") class AppImportApi(Resource): @console_ns.expect(console_ns.models[AppImportPayload.__name__]) + @console_ns.response(200, "Import completed", console_ns.models[Import.__name__]) + @console_ns.response(202, "Import pending confirmation", console_ns.models[Import.__name__]) + @console_ns.response(400, "Import failed", console_ns.models[Import.__name__]) @setup_required @login_required @account_initialization_required - @marshal_with(app_import_model) @cloud_edition_billing_resource_check("apps") @edit_permission_required def post(self): @@ -70,8 +52,9 @@ class AppImportApi(Resource): current_user, _ = current_account_with_tenant() args = AppImportPayload.model_validate(console_ns.payload) - # Create service with session - with sessionmaker(db.engine).begin() as session: + # AppDslService performs internal commits for some creation paths, so use a plain + # Session here instead of nesting it inside sessionmaker(...).begin(). + with Session(db.engine, expire_on_commit=False) as session: import_service = AppDslService(session) # Import app account = current_user @@ -87,35 +70,45 @@ class AppImportApi(Resource): icon_background=args.icon_background, app_id=args.app_id, ) + if result.status == ImportStatus.FAILED: + session.rollback() + else: + session.commit() if result.app_id and FeatureService.get_system_features().webapp_auth.enabled: # update web app setting as private EnterpriseService.WebAppAuth.update_app_access_mode(result.app_id, "private") # Return appropriate status code based on result status = result.status - if status == ImportStatus.FAILED: - return result.model_dump(mode="json"), 400 - elif status == ImportStatus.PENDING: - return result.model_dump(mode="json"), 202 - return result.model_dump(mode="json"), 200 + match status: + case ImportStatus.FAILED: + return result.model_dump(mode="json"), 400 + case ImportStatus.PENDING: + return result.model_dump(mode="json"), 202 + case ImportStatus.COMPLETED | ImportStatus.COMPLETED_WITH_WARNINGS: + return result.model_dump(mode="json"), 200 @console_ns.route("/apps/imports//confirm") class AppImportConfirmApi(Resource): + @console_ns.response(200, "Import confirmed", console_ns.models[Import.__name__]) + @console_ns.response(400, "Import failed", console_ns.models[Import.__name__]) @setup_required @login_required @account_initialization_required - @marshal_with(app_import_model) @edit_permission_required def post(self, import_id): # Check user role first current_user, _ = current_account_with_tenant() - # Create service with session - with sessionmaker(db.engine).begin() as session: + with Session(db.engine, expire_on_commit=False) as session: import_service = AppDslService(session) # Confirm import account = current_user result = import_service.confirm_import(import_id=import_id, account=account) + if result.status == ImportStatus.FAILED: + session.rollback() + else: + session.commit() # Return appropriate status code based on result if result.status == ImportStatus.FAILED: @@ -125,14 +118,14 @@ class AppImportConfirmApi(Resource): @console_ns.route("/apps/imports//check-dependencies") class AppImportCheckDependenciesApi(Resource): + @console_ns.response(200, "Dependencies checked", console_ns.models[CheckDependenciesResult.__name__]) @setup_required @login_required @get_app_model @account_initialization_required - @marshal_with(app_import_check_dependencies_model) @edit_permission_required def get(self, app_model: App): - with sessionmaker(db.engine).begin() as session: + with Session(db.engine, expire_on_commit=False) as session: import_service = AppDslService(session) result = import_service.check_dependencies(app_model=app_model) diff --git a/api/controllers/console/app/audio.py b/api/controllers/console/app/audio.py index 78ddb904e14..91fbe4a85ae 100644 --- a/api/controllers/console/app/audio.py +++ b/api/controllers/console/app/audio.py @@ -2,7 +2,6 @@ import logging from flask import request from flask_restx import Resource, fields -from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, Field from werkzeug.exceptions import InternalServerError @@ -23,6 +22,7 @@ from controllers.console.app.error import ( from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError +from graphon.model_runtime.errors.invoke import InvokeError from libs.login import login_required from models import App, AppMode from services.audio_service import AudioService diff --git a/api/controllers/console/app/completion.py b/api/controllers/console/app/completion.py index d83925d173a..fe274e4c9a8 100644 --- a/api/controllers/console/app/completion.py +++ b/api/controllers/console/app/completion.py @@ -3,7 +3,6 @@ from typing import Any, Literal from flask import request from flask_restx import Resource -from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, Field, field_validator from werkzeug.exceptions import InternalServerError, NotFound @@ -27,6 +26,7 @@ from core.errors.error import ( QuotaExceededError, ) from core.helper.trace_id_helper import get_external_trace_id +from graphon.model_runtime.errors.invoke import InvokeError from libs import helper from libs.helper import uuid_value from libs.login import current_user, login_required diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py index d329d22309d..b2b1049f0c3 100644 --- a/api/controllers/console/app/conversation.py +++ b/api/controllers/console/app/conversation.py @@ -2,20 +2,37 @@ from typing import Literal import sqlalchemy as sa from flask import abort, request -from flask_restx import Resource, fields, marshal_with +from flask_restx import Resource from pydantic import BaseModel, Field, field_validator from sqlalchemy import func, or_ from sqlalchemy.orm import selectinload from werkzeug.exceptions import NotFound +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required from core.app.entities.app_invoke_entities import InvokeFrom from extensions.ext_database import db -from fields.raws import FilesContainedField +from fields.conversation_fields import ( + Conversation as ConversationResponse, +) +from fields.conversation_fields import ( + ConversationDetail as ConversationDetailResponse, +) +from fields.conversation_fields import ( + ConversationMessageDetail as ConversationMessageDetailResponse, +) +from fields.conversation_fields import ( + ConversationPagination as ConversationPaginationResponse, +) +from fields.conversation_fields import ( + ConversationWithSummaryPagination as ConversationWithSummaryPaginationResponse, +) +from fields.conversation_fields import ( + ResultResponse, +) from libs.datetime_utils import naive_utc_now, parse_time_range -from libs.helper import TimestampField from libs.login import current_account_with_tenant, login_required from models import Conversation, EndUser, Message, MessageAnnotation from models.model import AppMode @@ -62,267 +79,16 @@ console_ns.schema_model( ChatConversationQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), ) -# Register models for flask_restx to avoid dict type issues in Swagger -# Register in dependency order: base models first, then dependent models - -# Base models -simple_account_model = console_ns.model( - "SimpleAccount", - { - "id": fields.String, - "name": fields.String, - "email": fields.String, - }, -) - -feedback_stat_model = console_ns.model( - "FeedbackStat", - { - "like": fields.Integer, - "dislike": fields.Integer, - }, -) - -status_count_model = console_ns.model( - "StatusCount", - { - "success": fields.Integer, - "failed": fields.Integer, - "partial_success": fields.Integer, - "paused": fields.Integer, - }, -) - -message_file_model = console_ns.model( - "MessageFile", - { - "id": fields.String, - "filename": fields.String, - "type": fields.String, - "url": fields.String, - "mime_type": fields.String, - "size": fields.Integer, - "transfer_method": fields.String, - "belongs_to": fields.String(default="user"), - "upload_file_id": fields.String(default=None), - }, -) - -agent_thought_model = console_ns.model( - "AgentThought", - { - "id": fields.String, - "chain_id": fields.String, - "message_id": fields.String, - "position": fields.Integer, - "thought": fields.String, - "tool": fields.String, - "tool_labels": fields.Raw, - "tool_input": fields.String, - "created_at": TimestampField, - "observation": fields.String, - "files": fields.List(fields.String), - }, -) - -simple_model_config_model = console_ns.model( - "SimpleModelConfig", - { - "model": fields.Raw(attribute="model_dict"), - "pre_prompt": fields.String, - }, -) - -model_config_model = console_ns.model( - "ModelConfig", - { - "opening_statement": fields.String, - "suggested_questions": fields.Raw, - "model": fields.Raw, - "user_input_form": fields.Raw, - "pre_prompt": fields.String, - "agent_mode": fields.Raw, - }, -) - -# Models that depend on simple_account_model -feedback_model = console_ns.model( - "Feedback", - { - "rating": fields.String, - "content": fields.String, - "from_source": fields.String, - "from_end_user_id": fields.String, - "from_account": fields.Nested(simple_account_model, allow_null=True), - }, -) - -annotation_model = console_ns.model( - "Annotation", - { - "id": fields.String, - "question": fields.String, - "content": fields.String, - "account": fields.Nested(simple_account_model, allow_null=True), - "created_at": TimestampField, - }, -) - -annotation_hit_history_model = console_ns.model( - "AnnotationHitHistory", - { - "annotation_id": fields.String(attribute="id"), - "annotation_create_account": fields.Nested(simple_account_model, allow_null=True), - "created_at": TimestampField, - }, -) - - -class MessageTextField(fields.Raw): - def format(self, value): - return value[0]["text"] if value else "" - - -# Simple message detail model -simple_message_detail_model = console_ns.model( - "SimpleMessageDetail", - { - "inputs": FilesContainedField, - "query": fields.String, - "message": MessageTextField, - "answer": fields.String, - }, -) - -# Message detail model that depends on multiple models -message_detail_model = console_ns.model( - "MessageDetail", - { - "id": fields.String, - "conversation_id": fields.String, - "inputs": FilesContainedField, - "query": fields.String, - "message": fields.Raw, - "message_tokens": fields.Integer, - "answer": fields.String(attribute="re_sign_file_url_answer"), - "answer_tokens": fields.Integer, - "provider_response_latency": fields.Float, - "from_source": fields.String, - "from_end_user_id": fields.String, - "from_account_id": fields.String, - "feedbacks": fields.List(fields.Nested(feedback_model)), - "workflow_run_id": fields.String, - "annotation": fields.Nested(annotation_model, allow_null=True), - "annotation_hit_history": fields.Nested(annotation_hit_history_model, allow_null=True), - "created_at": TimestampField, - "agent_thoughts": fields.List(fields.Nested(agent_thought_model)), - "message_files": fields.List(fields.Nested(message_file_model)), - "metadata": fields.Raw(attribute="message_metadata_dict"), - "status": fields.String, - "error": fields.String, - "parent_message_id": fields.String, - }, -) - -# Conversation models -conversation_fields_model = console_ns.model( - "Conversation", - { - "id": fields.String, - "status": fields.String, - "from_source": fields.String, - "from_end_user_id": fields.String, - "from_end_user_session_id": fields.String(), - "from_account_id": fields.String, - "from_account_name": fields.String, - "read_at": TimestampField, - "created_at": TimestampField, - "updated_at": TimestampField, - "annotation": fields.Nested(annotation_model, allow_null=True), - "model_config": fields.Nested(simple_model_config_model), - "user_feedback_stats": fields.Nested(feedback_stat_model), - "admin_feedback_stats": fields.Nested(feedback_stat_model), - "message": fields.Nested(simple_message_detail_model, attribute="first_message"), - }, -) - -conversation_pagination_model = console_ns.model( - "ConversationPagination", - { - "page": fields.Integer, - "limit": fields.Integer(attribute="per_page"), - "total": fields.Integer, - "has_more": fields.Boolean(attribute="has_next"), - "data": fields.List(fields.Nested(conversation_fields_model), attribute="items"), - }, -) - -conversation_message_detail_model = console_ns.model( - "ConversationMessageDetail", - { - "id": fields.String, - "status": fields.String, - "from_source": fields.String, - "from_end_user_id": fields.String, - "from_account_id": fields.String, - "created_at": TimestampField, - "model_config": fields.Nested(model_config_model), - "message": fields.Nested(message_detail_model, attribute="first_message"), - }, -) - -conversation_with_summary_model = console_ns.model( - "ConversationWithSummary", - { - "id": fields.String, - "status": fields.String, - "from_source": fields.String, - "from_end_user_id": fields.String, - "from_end_user_session_id": fields.String, - "from_account_id": fields.String, - "from_account_name": fields.String, - "name": fields.String, - "summary": fields.String(attribute="summary_or_query"), - "read_at": TimestampField, - "created_at": TimestampField, - "updated_at": TimestampField, - "annotated": fields.Boolean, - "model_config": fields.Nested(simple_model_config_model), - "message_count": fields.Integer, - "user_feedback_stats": fields.Nested(feedback_stat_model), - "admin_feedback_stats": fields.Nested(feedback_stat_model), - "status_count": fields.Nested(status_count_model), - }, -) - -conversation_with_summary_pagination_model = console_ns.model( - "ConversationWithSummaryPagination", - { - "page": fields.Integer, - "limit": fields.Integer(attribute="per_page"), - "total": fields.Integer, - "has_more": fields.Boolean(attribute="has_next"), - "data": fields.List(fields.Nested(conversation_with_summary_model), attribute="items"), - }, -) - -conversation_detail_model = console_ns.model( - "ConversationDetail", - { - "id": fields.String, - "status": fields.String, - "from_source": fields.String, - "from_end_user_id": fields.String, - "from_account_id": fields.String, - "created_at": TimestampField, - "updated_at": TimestampField, - "annotated": fields.Boolean, - "introduction": fields.String, - "model_config": fields.Nested(model_config_model), - "message_count": fields.Integer, - "user_feedback_stats": fields.Nested(feedback_stat_model), - "admin_feedback_stats": fields.Nested(feedback_stat_model), - }, +register_schema_models( + console_ns, + CompletionConversationQuery, + ChatConversationQuery, + ConversationResponse, + ConversationPaginationResponse, + ConversationMessageDetailResponse, + ConversationWithSummaryPaginationResponse, + ConversationDetailResponse, + ResultResponse, ) @@ -332,13 +98,12 @@ class CompletionConversationApi(Resource): @console_ns.doc(description="Get completion conversations with pagination and filtering") @console_ns.doc(params={"app_id": "Application ID"}) @console_ns.expect(console_ns.models[CompletionConversationQuery.__name__]) - @console_ns.response(200, "Success", conversation_pagination_model) + @console_ns.response(200, "Success", console_ns.models[ConversationPaginationResponse.__name__]) @console_ns.response(403, "Insufficient permissions") @setup_required @login_required @account_initialization_required @get_app_model(mode=AppMode.COMPLETION) - @marshal_with(conversation_pagination_model) @edit_permission_required def get(self, app_model): current_user, _ = current_account_with_tenant() @@ -394,7 +159,9 @@ class CompletionConversationApi(Resource): conversations = db.paginate(query, page=args.page, per_page=args.limit, error_out=False) - return conversations + return ConversationPaginationResponse.model_validate(conversations, from_attributes=True).model_dump( + mode="json" + ) @console_ns.route("/apps//completion-conversations/") @@ -402,19 +169,19 @@ class CompletionConversationDetailApi(Resource): @console_ns.doc("get_completion_conversation") @console_ns.doc(description="Get completion conversation details with messages") @console_ns.doc(params={"app_id": "Application ID", "conversation_id": "Conversation ID"}) - @console_ns.response(200, "Success", conversation_message_detail_model) + @console_ns.response(200, "Success", console_ns.models[ConversationMessageDetailResponse.__name__]) @console_ns.response(403, "Insufficient permissions") @console_ns.response(404, "Conversation not found") @setup_required @login_required @account_initialization_required @get_app_model(mode=AppMode.COMPLETION) - @marshal_with(conversation_message_detail_model) @edit_permission_required def get(self, app_model, conversation_id): conversation_id = str(conversation_id) - - return _get_conversation(app_model, conversation_id) + return ConversationMessageDetailResponse.model_validate( + _get_conversation(app_model, conversation_id), from_attributes=True + ).model_dump(mode="json") @console_ns.doc("delete_completion_conversation") @console_ns.doc(description="Delete a completion conversation") @@ -436,7 +203,7 @@ class CompletionConversationDetailApi(Resource): except ConversationNotExistsError: raise NotFound("Conversation Not Exists.") - return {"result": "success"}, 204 + return ResultResponse(result="success").model_dump(mode="json"), 204 @console_ns.route("/apps//chat-conversations") @@ -445,13 +212,12 @@ class ChatConversationApi(Resource): @console_ns.doc(description="Get chat conversations with pagination, filtering and summary") @console_ns.doc(params={"app_id": "Application ID"}) @console_ns.expect(console_ns.models[ChatConversationQuery.__name__]) - @console_ns.response(200, "Success", conversation_with_summary_pagination_model) + @console_ns.response(200, "Success", console_ns.models[ConversationWithSummaryPaginationResponse.__name__]) @console_ns.response(403, "Insufficient permissions") @setup_required @login_required @account_initialization_required @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) - @marshal_with(conversation_with_summary_pagination_model) @edit_permission_required def get(self, app_model): current_user, _ = current_account_with_tenant() @@ -546,7 +312,9 @@ class ChatConversationApi(Resource): conversations = db.paginate(query, page=args.page, per_page=args.limit, error_out=False) - return conversations + return ConversationWithSummaryPaginationResponse.model_validate(conversations, from_attributes=True).model_dump( + mode="json" + ) @console_ns.route("/apps//chat-conversations/") @@ -554,19 +322,19 @@ class ChatConversationDetailApi(Resource): @console_ns.doc("get_chat_conversation") @console_ns.doc(description="Get chat conversation details") @console_ns.doc(params={"app_id": "Application ID", "conversation_id": "Conversation ID"}) - @console_ns.response(200, "Success", conversation_detail_model) + @console_ns.response(200, "Success", console_ns.models[ConversationDetailResponse.__name__]) @console_ns.response(403, "Insufficient permissions") @console_ns.response(404, "Conversation not found") @setup_required @login_required @account_initialization_required @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) - @marshal_with(conversation_detail_model) @edit_permission_required def get(self, app_model, conversation_id): conversation_id = str(conversation_id) - - return _get_conversation(app_model, conversation_id) + return ConversationDetailResponse.model_validate( + _get_conversation(app_model, conversation_id), from_attributes=True + ).model_dump(mode="json") @console_ns.doc("delete_chat_conversation") @console_ns.doc(description="Delete a chat conversation") @@ -588,7 +356,7 @@ class ChatConversationDetailApi(Resource): except ConversationNotExistsError: raise NotFound("Conversation Not Exists.") - return {"result": "success"}, 204 + return ResultResponse(result="success").model_dump(mode="json"), 204 def _get_conversation(app_model, conversation_id): diff --git a/api/controllers/console/app/conversation_variables.py b/api/controllers/console/app/conversation_variables.py index 369c26a80cb..9c8b095b9f8 100644 --- a/api/controllers/console/app/conversation_variables.py +++ b/api/controllers/console/app/conversation_variables.py @@ -1,44 +1,86 @@ +from __future__ import annotations + +from datetime import datetime +from typing import Any + from flask import request -from flask_restx import Resource, fields, marshal_with -from pydantic import BaseModel, Field +from flask_restx import Resource +from pydantic import BaseModel, Field, field_validator from sqlalchemy import select from sqlalchemy.orm import sessionmaker +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required from extensions.ext_database import db -from fields.conversation_variable_fields import ( - conversation_variable_fields, - paginated_conversation_variable_fields, -) +from fields._value_type_serializer import serialize_value_type +from fields.base import ResponseModel from libs.login import login_required from models import ConversationVariable from models.model import AppMode -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" - class ConversationVariablesQuery(BaseModel): conversation_id: str = Field(..., description="Conversation ID to filter variables") -console_ns.schema_model( - ConversationVariablesQuery.__name__, - ConversationVariablesQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), -) +def _to_timestamp(value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return int(value.timestamp()) + return value -# Register models for flask_restx to avoid dict type issues in Swagger -# Register base model first -conversation_variable_model = console_ns.model("ConversationVariable", conversation_variable_fields) -# For nested models, need to replace nested dict with registered model -paginated_conversation_variable_fields_copy = paginated_conversation_variable_fields.copy() -paginated_conversation_variable_fields_copy["data"] = fields.List( - fields.Nested(conversation_variable_model), attribute="data" -) -paginated_conversation_variable_model = console_ns.model( - "PaginatedConversationVariable", paginated_conversation_variable_fields_copy +class ConversationVariableResponse(ResponseModel): + id: str + name: str + value_type: str + value: str | None = None + description: str | None = None + created_at: int | None = None + updated_at: int | None = None + + @field_validator("value_type", mode="before") + @classmethod + def _normalize_value_type(cls, value: Any) -> str: + exposed_type = getattr(value, "exposed_type", None) + if callable(exposed_type): + return str(exposed_type()) + if isinstance(value, str): + return value + try: + return serialize_value_type(value) + except Exception: + return serialize_value_type({"value_type": value}) + + @field_validator("value", mode="before") + @classmethod + def _normalize_value(cls, value: Any | None) -> str | None: + if value is None: + return None + if isinstance(value, str): + return value + return str(value) + + @field_validator("created_at", "updated_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +class PaginatedConversationVariableResponse(ResponseModel): + page: int + limit: int + total: int + has_more: bool + data: list[ConversationVariableResponse] + + +register_schema_models( + console_ns, + ConversationVariablesQuery, + ConversationVariableResponse, + PaginatedConversationVariableResponse, ) @@ -48,12 +90,15 @@ class ConversationVariablesApi(Resource): @console_ns.doc(description="Get conversation variables for an application") @console_ns.doc(params={"app_id": "Application ID"}) @console_ns.expect(console_ns.models[ConversationVariablesQuery.__name__]) - @console_ns.response(200, "Conversation variables retrieved successfully", paginated_conversation_variable_model) + @console_ns.response( + 200, + "Conversation variables retrieved successfully", + console_ns.models[PaginatedConversationVariableResponse.__name__], + ) @setup_required @login_required @account_initialization_required @get_app_model(mode=AppMode.ADVANCED_CHAT) - @marshal_with(paginated_conversation_variable_model) def get(self, app_model): args = ConversationVariablesQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore @@ -72,17 +117,22 @@ class ConversationVariablesApi(Resource): with sessionmaker(db.engine, expire_on_commit=False).begin() as session: rows = session.scalars(stmt).all() - return { - "page": page, - "limit": page_size, - "total": len(rows), - "has_more": False, - "data": [ - { - "created_at": row.created_at, - "updated_at": row.updated_at, - **row.to_variable().model_dump(), - } - for row in rows - ], - } + response = PaginatedConversationVariableResponse.model_validate( + { + "page": page, + "limit": page_size, + "total": len(rows), + "has_more": False, + "data": [ + ConversationVariableResponse.model_validate( + { + "created_at": row.created_at, + "updated_at": row.updated_at, + **row.to_variable().model_dump(), + } + ) + for row in rows + ], + } + ) + return response.model_dump(mode="json") diff --git a/api/controllers/console/app/generator.py b/api/controllers/console/app/generator.py index 7101d5df7b4..c720a5e074b 100644 --- a/api/controllers/console/app/generator.py +++ b/api/controllers/console/app/generator.py @@ -1,7 +1,6 @@ from collections.abc import Sequence from flask_restx import Resource -from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, Field from controllers.console import console_ns @@ -20,6 +19,7 @@ from core.helper.code_executor.python3.python3_code_provider import Python3CodeP from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload from core.llm_generator.llm_generator import LLMGenerator from extensions.ext_database import db +from graphon.model_runtime.errors.invoke import InvokeError from libs.login import current_account_with_tenant, login_required from models import App from services.workflow_service import WorkflowService diff --git a/api/controllers/console/app/mcp_server.py b/api/controllers/console/app/mcp_server.py index 412fc8795a4..d517f695b8f 100644 --- a/api/controllers/console/app/mcp_server.py +++ b/api/controllers/console/app/mcp_server.py @@ -1,39 +1,68 @@ import json +from datetime import datetime +from typing import Any -from flask_restx import Resource, marshal_with -from pydantic import BaseModel, Field +from flask_restx import Resource +from pydantic import BaseModel, Field, field_validator from sqlalchemy import select from werkzeug.exceptions import NotFound +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required from extensions.ext_database import db -from fields.app_fields import app_server_fields +from fields.base import ResponseModel from libs.login import current_account_with_tenant, login_required from models.enums import AppMCPServerStatus from models.model import AppMCPServer -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" - -# Register model for flask_restx to avoid dict type issues in Swagger -app_server_model = console_ns.model("AppServer", app_server_fields) - class MCPServerCreatePayload(BaseModel): description: str | None = Field(default=None, description="Server description") - parameters: dict = Field(..., description="Server parameters configuration") + parameters: dict[str, Any] = Field(..., description="Server parameters configuration") class MCPServerUpdatePayload(BaseModel): id: str = Field(..., description="Server ID") description: str | None = Field(default=None, description="Server description") - parameters: dict = Field(..., description="Server parameters configuration") + parameters: dict[str, Any] = Field(..., description="Server parameters configuration") status: str | None = Field(default=None, description="Server status") -for model in (MCPServerCreatePayload, MCPServerUpdatePayload): - console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) +def _to_timestamp(value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return int(value.timestamp()) + return value + + +class AppMCPServerResponse(ResponseModel): + id: str + name: str + server_code: str + description: str + status: AppMCPServerStatus + parameters: dict[str, Any] | list[Any] | str + created_at: int | None = None + updated_at: int | None = None + + @field_validator("parameters", mode="before") + @classmethod + def _normalize_parameters(cls, value: Any) -> Any: + if isinstance(value, str): + try: + return json.loads(value) + except (json.JSONDecodeError, TypeError): + return value + return value + + @field_validator("created_at", "updated_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +register_schema_models(console_ns, MCPServerCreatePayload, MCPServerUpdatePayload, AppMCPServerResponse) @console_ns.route("/apps//server") @@ -41,27 +70,31 @@ class AppMCPServerController(Resource): @console_ns.doc("get_app_mcp_server") @console_ns.doc(description="Get MCP server configuration for an application") @console_ns.doc(params={"app_id": "Application ID"}) - @console_ns.response(200, "MCP server configuration retrieved successfully", app_server_model) + @console_ns.response( + 200, "MCP server configuration retrieved successfully", console_ns.models[AppMCPServerResponse.__name__] + ) @login_required @account_initialization_required @setup_required @get_app_model - @marshal_with(app_server_model) def get(self, app_model): server = db.session.scalar(select(AppMCPServer).where(AppMCPServer.app_id == app_model.id).limit(1)) - return server + if server is None: + return {} + return AppMCPServerResponse.model_validate(server, from_attributes=True).model_dump(mode="json") @console_ns.doc("create_app_mcp_server") @console_ns.doc(description="Create MCP server configuration for an application") @console_ns.doc(params={"app_id": "Application ID"}) @console_ns.expect(console_ns.models[MCPServerCreatePayload.__name__]) - @console_ns.response(201, "MCP server configuration created successfully", app_server_model) + @console_ns.response( + 201, "MCP server configuration created successfully", console_ns.models[AppMCPServerResponse.__name__] + ) @console_ns.response(403, "Insufficient permissions") @account_initialization_required @get_app_model @login_required @setup_required - @marshal_with(app_server_model) @edit_permission_required def post(self, app_model): _, current_tenant_id = current_account_with_tenant() @@ -82,20 +115,21 @@ class AppMCPServerController(Resource): ) db.session.add(server) db.session.commit() - return server + return AppMCPServerResponse.model_validate(server, from_attributes=True).model_dump(mode="json"), 201 @console_ns.doc("update_app_mcp_server") @console_ns.doc(description="Update MCP server configuration for an application") @console_ns.doc(params={"app_id": "Application ID"}) @console_ns.expect(console_ns.models[MCPServerUpdatePayload.__name__]) - @console_ns.response(200, "MCP server configuration updated successfully", app_server_model) + @console_ns.response( + 200, "MCP server configuration updated successfully", console_ns.models[AppMCPServerResponse.__name__] + ) @console_ns.response(403, "Insufficient permissions") @console_ns.response(404, "Server not found") @get_app_model @login_required @setup_required @account_initialization_required - @marshal_with(app_server_model) @edit_permission_required def put(self, app_model): payload = MCPServerUpdatePayload.model_validate(console_ns.payload or {}) @@ -118,7 +152,7 @@ class AppMCPServerController(Resource): except ValueError: raise ValueError("Invalid status") db.session.commit() - return server + return AppMCPServerResponse.model_validate(server, from_attributes=True).model_dump(mode="json") @console_ns.route("/apps//server/refresh") @@ -126,13 +160,12 @@ class AppMCPServerRefreshController(Resource): @console_ns.doc("refresh_app_mcp_server") @console_ns.doc(description="Refresh MCP server configuration and regenerate server code") @console_ns.doc(params={"server_id": "Server ID"}) - @console_ns.response(200, "MCP server refreshed successfully", app_server_model) + @console_ns.response(200, "MCP server refreshed successfully", console_ns.models[AppMCPServerResponse.__name__]) @console_ns.response(403, "Insufficient permissions") @console_ns.response(404, "Server not found") @setup_required @login_required @account_initialization_required - @marshal_with(app_server_model) @edit_permission_required def get(self, server_id): _, current_tenant_id = current_account_with_tenant() @@ -145,4 +178,4 @@ class AppMCPServerRefreshController(Resource): raise NotFound() server.server_code = AppMCPServer.generate_server_code(16) db.session.commit() - return server + return AppMCPServerResponse.model_validate(server, from_attributes=True).model_dump(mode="json") diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py index 2afe2767427..44e19b57db5 100644 --- a/api/controllers/console/app/message.py +++ b/api/controllers/console/app/message.py @@ -1,13 +1,14 @@ import logging +from datetime import datetime from typing import Literal from flask import request -from flask_restx import Resource, fields, marshal_with -from graphon.model_runtime.errors.invoke import InvokeError +from flask_restx import Resource from pydantic import BaseModel, Field, field_validator from sqlalchemy import exists, func, select from werkzeug.exceptions import InternalServerError, NotFound +from controllers.common.controller_schemas import MessageFeedbackPayload as _MessageFeedbackPayloadBase from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.app.error import ( @@ -24,10 +25,22 @@ from controllers.console.wraps import ( setup_required, ) from core.app.entities.app_invoke_entities import InvokeFrom +from core.entities.execution_extra_content import ExecutionExtraContentDomainModel from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from extensions.ext_database import db -from fields.raws import FilesContainedField -from libs.helper import TimestampField, uuid_value +from fields.base import ResponseModel +from fields.conversation_fields import ( + AgentThought, + ConversationAnnotation, + ConversationAnnotationHitHistory, + Feedback, + JSONValue, + MessageFile, + format_files_contained, + to_timestamp, +) +from graphon.model_runtime.errors.invoke import InvokeError +from libs.helper import uuid_value from libs.infinite_scroll_pagination import InfiniteScrollPagination from libs.login import current_account_with_tenant, login_required from models.enums import FeedbackFromSource, FeedbackRating @@ -59,10 +72,8 @@ class ChatMessagesQuery(BaseModel): return uuid_value(value) -class MessageFeedbackPayload(BaseModel): +class MessageFeedbackPayload(_MessageFeedbackPayloadBase): message_id: str = Field(..., description="Message ID") - rating: Literal["like", "dislike"] | None = Field(default=None, description="Feedback rating") - content: str | None = Field(default=None, description="Feedback content") @field_validator("message_id") @classmethod @@ -99,6 +110,51 @@ class SuggestedQuestionsResponse(BaseModel): data: list[str] = Field(description="Suggested question") +class MessageDetailResponse(ResponseModel): + id: str + conversation_id: str + inputs: dict[str, JSONValue] + query: str + message: JSONValue | None = None + message_tokens: int | None = None + answer: str = Field(validation_alias="re_sign_file_url_answer") + answer_tokens: int | None = None + provider_response_latency: float | None = None + from_source: str + from_end_user_id: str | None = None + from_account_id: str | None = None + feedbacks: list[Feedback] = Field(default_factory=list) + workflow_run_id: str | None = None + annotation: ConversationAnnotation | None = None + annotation_hit_history: ConversationAnnotationHitHistory | None = None + created_at: int | None = None + agent_thoughts: list[AgentThought] = Field(default_factory=list) + message_files: list[MessageFile] = Field(default_factory=list) + extra_contents: list[ExecutionExtraContentDomainModel] = Field(default_factory=list) + metadata: JSONValue | None = Field(default=None, validation_alias="message_metadata_dict") + status: str + error: str | None = None + parent_message_id: str | None = None + + @field_validator("inputs", mode="before") + @classmethod + def _normalize_inputs(cls, value: JSONValue) -> JSONValue: + return format_files_contained(value) + + @field_validator("created_at", mode="before") + @classmethod + def _normalize_created_at(cls, value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return to_timestamp(value) + return value + + +class MessageInfiniteScrollPaginationResponse(ResponseModel): + limit: int + has_more: bool + data: list[MessageDetailResponse] + + register_schema_models( console_ns, ChatMessagesQuery, @@ -106,124 +162,8 @@ register_schema_models( FeedbackExportQuery, AnnotationCountResponse, SuggestedQuestionsResponse, -) - -# Register models for flask_restx to avoid dict type issues in Swagger -# Register in dependency order: base models first, then dependent models - -# Base models -simple_account_model = console_ns.model( - "SimpleAccount", - { - "id": fields.String, - "name": fields.String, - "email": fields.String, - }, -) - -message_file_model = console_ns.model( - "MessageFile", - { - "id": fields.String, - "filename": fields.String, - "type": fields.String, - "url": fields.String, - "mime_type": fields.String, - "size": fields.Integer, - "transfer_method": fields.String, - "belongs_to": fields.String(default="user"), - "upload_file_id": fields.String(default=None), - }, -) - -agent_thought_model = console_ns.model( - "AgentThought", - { - "id": fields.String, - "chain_id": fields.String, - "message_id": fields.String, - "position": fields.Integer, - "thought": fields.String, - "tool": fields.String, - "tool_labels": fields.Raw, - "tool_input": fields.String, - "created_at": TimestampField, - "observation": fields.String, - "files": fields.List(fields.String), - }, -) - -# Models that depend on simple_account_model -feedback_model = console_ns.model( - "Feedback", - { - "rating": fields.String, - "content": fields.String, - "from_source": fields.String, - "from_end_user_id": fields.String, - "from_account": fields.Nested(simple_account_model, allow_null=True), - }, -) - -annotation_model = console_ns.model( - "Annotation", - { - "id": fields.String, - "question": fields.String, - "content": fields.String, - "account": fields.Nested(simple_account_model, allow_null=True), - "created_at": TimestampField, - }, -) - -annotation_hit_history_model = console_ns.model( - "AnnotationHitHistory", - { - "annotation_id": fields.String(attribute="id"), - "annotation_create_account": fields.Nested(simple_account_model, allow_null=True), - "created_at": TimestampField, - }, -) - -# Message detail model that depends on multiple models -message_detail_model = console_ns.model( - "MessageDetail", - { - "id": fields.String, - "conversation_id": fields.String, - "inputs": FilesContainedField, - "query": fields.String, - "message": fields.Raw, - "message_tokens": fields.Integer, - "answer": fields.String(attribute="re_sign_file_url_answer"), - "answer_tokens": fields.Integer, - "provider_response_latency": fields.Float, - "from_source": fields.String, - "from_end_user_id": fields.String, - "from_account_id": fields.String, - "feedbacks": fields.List(fields.Nested(feedback_model)), - "workflow_run_id": fields.String, - "annotation": fields.Nested(annotation_model, allow_null=True), - "annotation_hit_history": fields.Nested(annotation_hit_history_model, allow_null=True), - "created_at": TimestampField, - "agent_thoughts": fields.List(fields.Nested(agent_thought_model)), - "message_files": fields.List(fields.Nested(message_file_model)), - "extra_contents": fields.List(fields.Raw), - "metadata": fields.Raw(attribute="message_metadata_dict"), - "status": fields.String, - "error": fields.String, - "parent_message_id": fields.String, - }, -) - -# Message infinite scroll pagination model -message_infinite_scroll_pagination_model = console_ns.model( - "MessageInfiniteScrollPagination", - { - "limit": fields.Integer, - "has_more": fields.Boolean, - "data": fields.List(fields.Nested(message_detail_model)), - }, + MessageDetailResponse, + MessageInfiniteScrollPaginationResponse, ) @@ -233,13 +173,12 @@ class ChatMessageListApi(Resource): @console_ns.doc(description="Get chat messages for a conversation with pagination") @console_ns.doc(params={"app_id": "Application ID"}) @console_ns.expect(console_ns.models[ChatMessagesQuery.__name__]) - @console_ns.response(200, "Success", message_infinite_scroll_pagination_model) + @console_ns.response(200, "Success", console_ns.models[MessageInfiniteScrollPaginationResponse.__name__]) @console_ns.response(404, "Conversation not found") @login_required @account_initialization_required @setup_required @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) - @marshal_with(message_infinite_scroll_pagination_model) @edit_permission_required def get(self, app_model): args = ChatMessagesQuery.model_validate(request.args.to_dict()) @@ -299,7 +238,10 @@ class ChatMessageListApi(Resource): history_messages = list(reversed(history_messages)) attach_message_extra_contents(history_messages) - return InfiniteScrollPagination(data=history_messages, limit=args.limit, has_more=has_more) + return MessageInfiniteScrollPaginationResponse.model_validate( + InfiniteScrollPagination(data=history_messages, limit=args.limit, has_more=has_more), + from_attributes=True, + ).model_dump(mode="json") @console_ns.route("/apps//feedbacks") @@ -469,13 +411,12 @@ class MessageApi(Resource): @console_ns.doc("get_message") @console_ns.doc(description="Get message details by ID") @console_ns.doc(params={"app_id": "Application ID", "message_id": "Message ID"}) - @console_ns.response(200, "Message retrieved successfully", message_detail_model) + @console_ns.response(200, "Message retrieved successfully", console_ns.models[MessageDetailResponse.__name__]) @console_ns.response(404, "Message not found") @get_app_model @setup_required @login_required @account_initialization_required - @marshal_with(message_detail_model) def get(self, app_model, message_id: str): message_id = str(message_id) @@ -487,4 +428,4 @@ class MessageApi(Resource): raise NotFound("Message Not Exists.") attach_message_extra_contents([message]) - return message + return MessageDetailResponse.model_validate(message, from_attributes=True).model_dump(mode="json") diff --git a/api/controllers/console/app/model_config.py b/api/controllers/console/app/model_config.py index 8bb5aa2c1b1..1869cbf5f6d 100644 --- a/api/controllers/console/app/model_config.py +++ b/api/controllers/console/app/model_config.py @@ -1,9 +1,11 @@ import json -from typing import cast +from typing import Any, cast from flask import request -from flask_restx import Resource, fields +from flask_restx import Resource +from pydantic import BaseModel, Field +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required @@ -18,30 +20,30 @@ from models.model import AppMode, AppModelConfig from services.app_model_config_service import AppModelConfigService +class ModelConfigRequest(BaseModel): + provider: str | None = Field(default=None, description="Model provider") + model: str | None = Field(default=None, description="Model name") + configs: dict[str, Any] | None = Field(default=None, description="Model configuration parameters") + opening_statement: str | None = Field(default=None, description="Opening statement") + suggested_questions: list[str] | None = Field(default=None, description="Suggested questions") + more_like_this: dict[str, Any] | None = Field(default=None, description="More like this configuration") + speech_to_text: dict[str, Any] | None = Field(default=None, description="Speech to text configuration") + text_to_speech: dict[str, Any] | None = Field(default=None, description="Text to speech configuration") + retrieval_model: dict[str, Any] | None = Field(default=None, description="Retrieval model configuration") + tools: list[dict[str, Any]] | None = Field(default=None, description="Available tools") + dataset_configs: dict[str, Any] | None = Field(default=None, description="Dataset configurations") + agent_mode: dict[str, Any] | None = Field(default=None, description="Agent mode configuration") + + +register_schema_models(console_ns, ModelConfigRequest) + + @console_ns.route("/apps//model-config") class ModelConfigResource(Resource): @console_ns.doc("update_app_model_config") @console_ns.doc(description="Update application model configuration") @console_ns.doc(params={"app_id": "Application ID"}) - @console_ns.expect( - console_ns.model( - "ModelConfigRequest", - { - "provider": fields.String(description="Model provider"), - "model": fields.String(description="Model name"), - "configs": fields.Raw(description="Model configuration parameters"), - "opening_statement": fields.String(description="Opening statement"), - "suggested_questions": fields.List(fields.String(), description="Suggested questions"), - "more_like_this": fields.Raw(description="More like this configuration"), - "speech_to_text": fields.Raw(description="Speech to text configuration"), - "text_to_speech": fields.Raw(description="Text to speech configuration"), - "retrieval_model": fields.Raw(description="Retrieval model configuration"), - "tools": fields.List(fields.Raw(), description="Available tools"), - "dataset_configs": fields.Raw(description="Dataset configurations"), - "agent_mode": fields.Raw(description="Agent mode configuration"), - }, - ) - ) + @console_ns.expect(console_ns.models[ModelConfigRequest.__name__]) @console_ns.response(200, "Model configuration updated successfully") @console_ns.response(400, "Invalid configuration") @console_ns.response(404, "App not found") diff --git a/api/controllers/console/app/site.py b/api/controllers/console/app/site.py index 7f44a99ff13..9991d78d949 100644 --- a/api/controllers/console/app/site.py +++ b/api/controllers/console/app/site.py @@ -1,11 +1,12 @@ from typing import Literal -from flask_restx import Resource, marshal_with +from flask_restx import Resource from pydantic import BaseModel, Field, field_validator from sqlalchemy import select from werkzeug.exceptions import NotFound from constants.languages import supported_language +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.app.wraps import get_app_model from controllers.console.wraps import ( @@ -15,13 +16,11 @@ from controllers.console.wraps import ( setup_required, ) from extensions.ext_database import db -from fields.app_fields import app_site_fields +from fields.base import ResponseModel from libs.datetime_utils import naive_utc_now from libs.login import current_account_with_tenant, login_required from models import Site -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" - class AppSiteUpdatePayload(BaseModel): title: str | None = Field(default=None) @@ -49,13 +48,26 @@ class AppSiteUpdatePayload(BaseModel): return supported_language(value) -console_ns.schema_model( - AppSiteUpdatePayload.__name__, - AppSiteUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), -) +class AppSiteResponse(ResponseModel): + app_id: str + access_token: str | None = Field(default=None, validation_alias="code") + code: str | None = None + title: str + icon: str | None = None + icon_background: str | None = None + description: str | None = None + default_language: str + customize_domain: str | None = None + copyright: str | None = None + privacy_policy: str | None = None + custom_disclaimer: str | None = None + customize_token_strategy: str + prompt_public: bool + show_workflow_steps: bool + use_icon_as_answer_icon: bool -# Register model for flask_restx to avoid dict type issues in Swagger -app_site_model = console_ns.model("AppSite", app_site_fields) + +register_schema_models(console_ns, AppSiteUpdatePayload, AppSiteResponse) @console_ns.route("/apps//site") @@ -64,7 +76,7 @@ class AppSite(Resource): @console_ns.doc(description="Update application site configuration") @console_ns.doc(params={"app_id": "Application ID"}) @console_ns.expect(console_ns.models[AppSiteUpdatePayload.__name__]) - @console_ns.response(200, "Site configuration updated successfully", app_site_model) + @console_ns.response(200, "Site configuration updated successfully", console_ns.models[AppSiteResponse.__name__]) @console_ns.response(403, "Insufficient permissions") @console_ns.response(404, "App not found") @setup_required @@ -72,7 +84,6 @@ class AppSite(Resource): @edit_permission_required @account_initialization_required @get_app_model - @marshal_with(app_site_model) def post(self, app_model): args = AppSiteUpdatePayload.model_validate(console_ns.payload or {}) current_user, _ = current_account_with_tenant() @@ -106,7 +117,7 @@ class AppSite(Resource): site.updated_at = naive_utc_now() db.session.commit() - return site + return AppSiteResponse.model_validate(site, from_attributes=True).model_dump(mode="json") @console_ns.route("/apps//site/access-token-reset") @@ -114,7 +125,7 @@ class AppSiteAccessTokenReset(Resource): @console_ns.doc("reset_app_site_access_token") @console_ns.doc(description="Reset access token for application site") @console_ns.doc(params={"app_id": "Application ID"}) - @console_ns.response(200, "Access token reset successfully", app_site_model) + @console_ns.response(200, "Access token reset successfully", console_ns.models[AppSiteResponse.__name__]) @console_ns.response(403, "Insufficient permissions (admin/owner required)") @console_ns.response(404, "App or site not found") @setup_required @@ -122,7 +133,6 @@ class AppSiteAccessTokenReset(Resource): @is_admin_or_owner_required @account_initialization_required @get_app_model - @marshal_with(app_site_model) def post(self, app_model): current_user, _ = current_account_with_tenant() site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1)) @@ -135,4 +145,4 @@ class AppSiteAccessTokenReset(Resource): site.updated_at = naive_utc_now() db.session.commit() - return site + return AppSiteResponse.model_validate(site, from_attributes=True).model_dump(mode="json") diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index dcd24d2200f..478f783eb05 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -4,16 +4,13 @@ from collections.abc import Sequence from typing import Any from flask import abort, request -from flask_restx import Resource, fields, marshal_with -from graphon.enums import NodeType -from graphon.file import File -from graphon.graph_engine.manager import GraphEngineManager -from graphon.model_runtime.utils.encoders import jsonable_encoder +from flask_restx import Resource, fields, marshal, marshal_with from pydantic import BaseModel, Field, ValidationError, field_validator from sqlalchemy.orm import sessionmaker from werkzeug.exceptions import BadRequest, Forbidden, InternalServerError, NotFound import services +from controllers.common.controller_schemas import DefaultBlockConfigQuery, WorkflowListQuery, WorkflowUpdatePayload from controllers.console import console_ns from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync from controllers.console.app.workflow_run import workflow_run_node_execution_model @@ -38,7 +35,13 @@ from extensions.ext_database import db from extensions.ext_redis import redis_client from factories import file_factory, variable_factory from fields.member_fields import simple_account_fields +from fields.online_user_fields import online_user_list_fields from fields.workflow_fields import workflow_fields, workflow_pagination_fields +from graphon.enums import NodeType +from graphon.file import File +from graphon.file import helpers as file_helpers +from graphon.graph_engine.manager import GraphEngineManager +from graphon.model_runtime.utils.encoders import jsonable_encoder from libs import helper from libs.datetime_utils import naive_utc_now from libs.helper import TimestampField, uuid_value @@ -46,6 +49,7 @@ from libs.login import current_account_with_tenant, login_required from models import App from models.model import AppMode from models.workflow import Workflow +from repositories.workflow_collaboration_repository import WORKFLOW_ONLINE_USERS_PREFIX from services.app_generate_service import AppGenerateService from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError, WorkflowNotFoundError from services.errors.llm import InvokeRateLimitError @@ -56,6 +60,7 @@ _file_access_controller = DatabaseFileAccessController() LISTENING_RETRY_IN = 2000 DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" RESTORE_SOURCE_WORKFLOW_MUST_BE_PUBLISHED_MESSAGE = "source workflow must be published" +MAX_WORKFLOW_ONLINE_USERS_QUERY_IDS = 50 # Register models for flask_restx to avoid dict type issues in Swagger # Register in dependency order: base models first, then dependent models @@ -142,10 +147,6 @@ class PublishWorkflowPayload(BaseModel): marked_comment: str | None = Field(default=None, max_length=100) -class DefaultBlockConfigQuery(BaseModel): - q: str | None = None - - class ConvertToWorkflowPayload(BaseModel): name: str | None = None icon_type: str | None = None @@ -153,16 +154,12 @@ class ConvertToWorkflowPayload(BaseModel): icon_background: str | None = None -class WorkflowListQuery(BaseModel): - page: int = Field(default=1, ge=1, le=99999) - limit: int = Field(default=10, ge=1, le=100) - user_id: str | None = None - named_only: bool = False +class WorkflowFeaturesPayload(BaseModel): + features: dict[str, Any] = Field(..., description="Workflow feature configuration") -class WorkflowUpdatePayload(BaseModel): - marked_name: str | None = Field(default=None, max_length=20) - marked_comment: str | None = Field(default=None, max_length=100) +class WorkflowOnlineUsersQuery(BaseModel): + app_ids: str = Field(..., description="Comma-separated app IDs") class DraftWorkflowTriggerRunPayload(BaseModel): @@ -188,6 +185,8 @@ reg(DefaultBlockConfigQuery) reg(ConvertToWorkflowPayload) reg(WorkflowListQuery) reg(WorkflowUpdatePayload) +reg(WorkflowFeaturesPayload) +reg(WorkflowOnlineUsersQuery) reg(DraftWorkflowTriggerRunPayload) reg(DraftWorkflowTriggerRunAllPayload) @@ -946,6 +945,32 @@ class ConvertToWorkflowApi(Resource): } +@console_ns.route("/apps//workflows/draft/features") +class WorkflowFeaturesApi(Resource): + """Update draft workflow features.""" + + @console_ns.expect(console_ns.models[WorkflowFeaturesPayload.__name__]) + @console_ns.doc("update_workflow_features") + @console_ns.doc(description="Update draft workflow features") + @console_ns.doc(params={"app_id": "Application ID"}) + @console_ns.response(200, "Workflow features updated successfully") + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + @edit_permission_required + def post(self, app_model: App): + current_user, _ = current_account_with_tenant() + + args = WorkflowFeaturesPayload.model_validate(console_ns.payload or {}) + features = args.features + + workflow_service = WorkflowService() + workflow_service.update_draft_workflow_features(app_model=app_model, features=features, account=current_user) + + return {"result": "success"} + + @console_ns.route("/apps//workflows") class PublishedAllWorkflowApi(Resource): @console_ns.expect(console_ns.models[WorkflowListQuery.__name__]) @@ -957,7 +982,6 @@ class PublishedAllWorkflowApi(Resource): @login_required @account_initialization_required @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) - @marshal_with(workflow_pagination_model) @edit_permission_required def get(self, app_model: App): """ @@ -985,9 +1009,10 @@ class PublishedAllWorkflowApi(Resource): user_id=user_id, named_only=named_only, ) + serialized_workflows = marshal(workflows, workflow_fields_copy) return { - "items": workflows, + "items": serialized_workflows, "page": page, "limit": limit, "has_more": has_more, @@ -1355,3 +1380,62 @@ class DraftWorkflowTriggerRunAllApi(Resource): "status": "error", } ), 400 + + +@console_ns.route("/apps/workflows/online-users") +class WorkflowOnlineUsersApi(Resource): + @console_ns.expect(console_ns.models[WorkflowOnlineUsersQuery.__name__]) + @console_ns.doc("get_workflow_online_users") + @console_ns.doc(description="Get workflow online users") + @setup_required + @login_required + @account_initialization_required + @marshal_with(online_user_list_fields) + def get(self): + args = WorkflowOnlineUsersQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + + app_ids = list(dict.fromkeys(app_id.strip() for app_id in args.app_ids.split(",") if app_id.strip())) + if len(app_ids) > MAX_WORKFLOW_ONLINE_USERS_QUERY_IDS: + raise BadRequest(f"Maximum {MAX_WORKFLOW_ONLINE_USERS_QUERY_IDS} app_ids are allowed per request.") + + if not app_ids: + return {"data": []} + + _, current_tenant_id = current_account_with_tenant() + workflow_service = WorkflowService() + accessible_app_ids = workflow_service.get_accessible_app_ids(app_ids, current_tenant_id) + + results = [] + for app_id in app_ids: + if app_id not in accessible_app_ids: + continue + + users_json = redis_client.hgetall(f"{WORKFLOW_ONLINE_USERS_PREFIX}{app_id}") + + users = [] + for _, user_info_json in users_json.items(): + try: + user_info = json.loads(user_info_json) + except Exception: + continue + + if not isinstance(user_info, dict): + continue + + avatar = user_info.get("avatar") + if isinstance(avatar, str) and avatar and not avatar.startswith(("http://", "https://")): + try: + user_info["avatar"] = file_helpers.get_signed_file_url(avatar) + except Exception as exc: + logger.warning( + "Failed to sign workflow online user avatar; using original value. " + "app_id=%s avatar=%s error=%s", + app_id, + avatar, + exc, + ) + + users.append(user_info) + results.append({"app_id": app_id, "users": users}) + + return {"data": results} diff --git a/api/controllers/console/app/workflow_app_log.py b/api/controllers/console/app/workflow_app_log.py index 3b24c2a4021..4b39590235e 100644 --- a/api/controllers/console/app/workflow_app_log.py +++ b/api/controllers/console/app/workflow_app_log.py @@ -1,27 +1,26 @@ from datetime import datetime +from typing import Any from dateutil.parser import isoparse from flask import request -from flask_restx import Resource, marshal_with -from graphon.enums import WorkflowExecutionStatus +from flask_restx import Resource from pydantic import BaseModel, Field, field_validator from sqlalchemy.orm import sessionmaker +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required from extensions.ext_database import db -from fields.workflow_app_log_fields import ( - build_workflow_app_log_pagination_model, - build_workflow_archived_log_pagination_model, -) +from fields.base import ResponseModel +from fields.end_user_fields import SimpleEndUser +from fields.member_fields import SimpleAccount +from graphon.enums import WorkflowExecutionStatus from libs.login import login_required from models import App from models.model import AppMode from services.workflow_app_service import WorkflowAppService -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" - class WorkflowAppLogQuery(BaseModel): keyword: str | None = Field(default=None, description="Search keyword for filtering logs") @@ -58,13 +57,113 @@ class WorkflowAppLogQuery(BaseModel): raise ValueError("Invalid boolean value for detail") -console_ns.schema_model( - WorkflowAppLogQuery.__name__, WorkflowAppLogQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) -) +class WorkflowRunForLogResponse(ResponseModel): + id: str + version: str | None = None + status: str | None = None + triggered_from: str | None = None + error: str | None = None + elapsed_time: float | None = None + total_tokens: int | None = None + total_steps: int | None = None + created_at: int | None = None + finished_at: int | None = None + exceptions_count: int | None = None -# Register model for flask_restx to avoid dict type issues in Swagger -workflow_app_log_pagination_model = build_workflow_app_log_pagination_model(console_ns) -workflow_archived_log_pagination_model = build_workflow_archived_log_pagination_model(console_ns) + @field_validator("status", mode="before") + @classmethod + def _normalize_status(cls, value: Any) -> str | None: + if value is None: + return None + if isinstance(value, str): + return value + return str(getattr(value, "value", value)) + + @field_validator("created_at", "finished_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return int(value.timestamp()) + return value + + +class WorkflowRunForArchivedLogResponse(ResponseModel): + id: str + status: str | None = None + triggered_from: str | None = None + elapsed_time: float | None = None + total_tokens: int | None = None + + @field_validator("status", mode="before") + @classmethod + def _normalize_status(cls, value: Any) -> str | None: + if value is None: + return None + if isinstance(value, str): + return value + return str(getattr(value, "value", value)) + + +class WorkflowAppLogPartialResponse(ResponseModel): + id: str + workflow_run: WorkflowRunForLogResponse | None = None + details: Any = None + created_from: str | None = None + created_by_role: str | None = None + created_by_account: SimpleAccount | None = None + created_by_end_user: SimpleEndUser | None = None + created_at: int | None = None + + @field_validator("created_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return int(value.timestamp()) + return value + + +class WorkflowArchivedLogPartialResponse(ResponseModel): + id: str + workflow_run: WorkflowRunForArchivedLogResponse | None = None + trigger_metadata: Any = None + created_by_account: SimpleAccount | None = None + created_by_end_user: SimpleEndUser | None = None + created_at: int | None = None + + @field_validator("created_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return int(value.timestamp()) + return value + + +class WorkflowAppLogPaginationResponse(ResponseModel): + page: int + limit: int + total: int + has_more: bool + data: list[WorkflowAppLogPartialResponse] + + +class WorkflowArchivedLogPaginationResponse(ResponseModel): + page: int + limit: int + total: int + has_more: bool + data: list[WorkflowArchivedLogPartialResponse] + + +register_schema_models( + console_ns, + WorkflowAppLogQuery, + WorkflowRunForLogResponse, + WorkflowRunForArchivedLogResponse, + WorkflowAppLogPartialResponse, + WorkflowArchivedLogPartialResponse, + WorkflowAppLogPaginationResponse, + WorkflowArchivedLogPaginationResponse, +) @console_ns.route("/apps//workflow-app-logs") @@ -73,12 +172,15 @@ class WorkflowAppLogApi(Resource): @console_ns.doc(description="Get workflow application execution logs") @console_ns.doc(params={"app_id": "Application ID"}) @console_ns.expect(console_ns.models[WorkflowAppLogQuery.__name__]) - @console_ns.response(200, "Workflow app logs retrieved successfully", workflow_app_log_pagination_model) + @console_ns.response( + 200, + "Workflow app logs retrieved successfully", + console_ns.models[WorkflowAppLogPaginationResponse.__name__], + ) @setup_required @login_required @account_initialization_required @get_app_model(mode=[AppMode.WORKFLOW]) - @marshal_with(workflow_app_log_pagination_model) def get(self, app_model: App): """ Get workflow app logs @@ -87,7 +189,7 @@ class WorkflowAppLogApi(Resource): # get paginate workflow app logs workflow_app_service = WorkflowAppService() - with sessionmaker(db.engine).begin() as session: + with sessionmaker(db.engine, expire_on_commit=False).begin() as session: workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_app_logs( session=session, app_model=app_model, @@ -102,7 +204,9 @@ class WorkflowAppLogApi(Resource): created_by_account=args.created_by_account, ) - return workflow_app_log_pagination + return WorkflowAppLogPaginationResponse.model_validate( + workflow_app_log_pagination, from_attributes=True + ).model_dump(mode="json") @console_ns.route("/apps//workflow-archived-logs") @@ -111,12 +215,15 @@ class WorkflowArchivedLogApi(Resource): @console_ns.doc(description="Get workflow archived execution logs") @console_ns.doc(params={"app_id": "Application ID"}) @console_ns.expect(console_ns.models[WorkflowAppLogQuery.__name__]) - @console_ns.response(200, "Workflow archived logs retrieved successfully", workflow_archived_log_pagination_model) + @console_ns.response( + 200, + "Workflow archived logs retrieved successfully", + console_ns.models[WorkflowArchivedLogPaginationResponse.__name__], + ) @setup_required @login_required @account_initialization_required @get_app_model(mode=[AppMode.WORKFLOW]) - @marshal_with(workflow_archived_log_pagination_model) def get(self, app_model: App): """ Get workflow archived logs @@ -124,7 +231,7 @@ class WorkflowArchivedLogApi(Resource): args = WorkflowAppLogQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore workflow_app_service = WorkflowAppService() - with sessionmaker(db.engine).begin() as session: + with sessionmaker(db.engine, expire_on_commit=False).begin() as session: workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_archive_logs( session=session, app_model=app_model, @@ -132,4 +239,6 @@ class WorkflowArchivedLogApi(Resource): limit=args.limit, ) - return workflow_app_log_pagination + return WorkflowArchivedLogPaginationResponse.model_validate( + workflow_app_log_pagination, from_attributes=True + ).model_dump(mode="json") diff --git a/api/controllers/console/app/workflow_comment.py b/api/controllers/console/app/workflow_comment.py new file mode 100644 index 00000000000..e7c3e982a6b --- /dev/null +++ b/api/controllers/console/app/workflow_comment.py @@ -0,0 +1,335 @@ +import logging + +from flask_restx import Resource, marshal_with +from pydantic import BaseModel, Field, TypeAdapter + +from controllers.common.schema import register_schema_models +from controllers.console import console_ns +from controllers.console.app.wraps import get_app_model +from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required +from fields.member_fields import AccountWithRole +from fields.workflow_comment_fields import ( + workflow_comment_basic_fields, + workflow_comment_create_fields, + workflow_comment_detail_fields, + workflow_comment_reply_create_fields, + workflow_comment_reply_update_fields, + workflow_comment_resolve_fields, + workflow_comment_update_fields, +) +from libs.login import current_user, login_required +from models import App +from services.account_service import TenantService +from services.workflow_comment_service import WorkflowCommentService + +logger = logging.getLogger(__name__) +DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" + + +class WorkflowCommentCreatePayload(BaseModel): + content: str = Field(..., description="Comment content") + position_x: float = Field(..., description="Comment X position") + position_y: float = Field(..., description="Comment Y position") + mentioned_user_ids: list[str] = Field(default_factory=list, description="Mentioned user IDs") + + +class WorkflowCommentUpdatePayload(BaseModel): + content: str = Field(..., description="Comment content") + position_x: float | None = Field(default=None, description="Comment X position") + position_y: float | None = Field(default=None, description="Comment Y position") + mentioned_user_ids: list[str] | None = Field( + default=None, + description="Mentioned user IDs. Omit to keep existing mentions.", + ) + + +class WorkflowCommentReplyPayload(BaseModel): + content: str = Field(..., description="Reply content") + mentioned_user_ids: list[str] = Field(default_factory=list, description="Mentioned user IDs") + + +class WorkflowCommentMentionUsersPayload(BaseModel): + users: list[AccountWithRole] + + +for model in ( + WorkflowCommentCreatePayload, + WorkflowCommentUpdatePayload, + WorkflowCommentReplyPayload, +): + console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) +register_schema_models(console_ns, AccountWithRole, WorkflowCommentMentionUsersPayload) + +workflow_comment_basic_model = console_ns.model("WorkflowCommentBasic", workflow_comment_basic_fields) +workflow_comment_detail_model = console_ns.model("WorkflowCommentDetail", workflow_comment_detail_fields) +workflow_comment_create_model = console_ns.model("WorkflowCommentCreate", workflow_comment_create_fields) +workflow_comment_update_model = console_ns.model("WorkflowCommentUpdate", workflow_comment_update_fields) +workflow_comment_resolve_model = console_ns.model("WorkflowCommentResolve", workflow_comment_resolve_fields) +workflow_comment_reply_create_model = console_ns.model( + "WorkflowCommentReplyCreate", workflow_comment_reply_create_fields +) +workflow_comment_reply_update_model = console_ns.model( + "WorkflowCommentReplyUpdate", workflow_comment_reply_update_fields +) + + +@console_ns.route("/apps//workflow/comments") +class WorkflowCommentListApi(Resource): + """API for listing and creating workflow comments.""" + + @console_ns.doc("list_workflow_comments") + @console_ns.doc(description="Get all comments for a workflow") + @console_ns.doc(params={"app_id": "Application ID"}) + @console_ns.response(200, "Comments retrieved successfully", workflow_comment_basic_model) + @login_required + @setup_required + @account_initialization_required + @get_app_model() + @marshal_with(workflow_comment_basic_model, envelope="data") + def get(self, app_model: App): + """Get all comments for a workflow.""" + comments = WorkflowCommentService.get_comments(tenant_id=current_user.current_tenant_id, app_id=app_model.id) + + return comments + + @console_ns.doc("create_workflow_comment") + @console_ns.doc(description="Create a new workflow comment") + @console_ns.doc(params={"app_id": "Application ID"}) + @console_ns.expect(console_ns.models[WorkflowCommentCreatePayload.__name__]) + @console_ns.response(201, "Comment created successfully", workflow_comment_create_model) + @login_required + @setup_required + @account_initialization_required + @get_app_model() + @marshal_with(workflow_comment_create_model) + @edit_permission_required + def post(self, app_model: App): + """Create a new workflow comment.""" + payload = WorkflowCommentCreatePayload.model_validate(console_ns.payload or {}) + + result = WorkflowCommentService.create_comment( + tenant_id=current_user.current_tenant_id, + app_id=app_model.id, + created_by=current_user.id, + content=payload.content, + position_x=payload.position_x, + position_y=payload.position_y, + mentioned_user_ids=payload.mentioned_user_ids, + ) + + return result, 201 + + +@console_ns.route("/apps//workflow/comments/") +class WorkflowCommentDetailApi(Resource): + """API for managing individual workflow comments.""" + + @console_ns.doc("get_workflow_comment") + @console_ns.doc(description="Get a specific workflow comment") + @console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"}) + @console_ns.response(200, "Comment retrieved successfully", workflow_comment_detail_model) + @login_required + @setup_required + @account_initialization_required + @get_app_model() + @marshal_with(workflow_comment_detail_model) + def get(self, app_model: App, comment_id: str): + """Get a specific workflow comment.""" + comment = WorkflowCommentService.get_comment( + tenant_id=current_user.current_tenant_id, app_id=app_model.id, comment_id=comment_id + ) + + return comment + + @console_ns.doc("update_workflow_comment") + @console_ns.doc(description="Update a workflow comment") + @console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"}) + @console_ns.expect(console_ns.models[WorkflowCommentUpdatePayload.__name__]) + @console_ns.response(200, "Comment updated successfully", workflow_comment_update_model) + @login_required + @setup_required + @account_initialization_required + @get_app_model() + @marshal_with(workflow_comment_update_model) + @edit_permission_required + def put(self, app_model: App, comment_id: str): + """Update a workflow comment.""" + payload = WorkflowCommentUpdatePayload.model_validate(console_ns.payload or {}) + + result = WorkflowCommentService.update_comment( + tenant_id=current_user.current_tenant_id, + app_id=app_model.id, + comment_id=comment_id, + user_id=current_user.id, + content=payload.content, + position_x=payload.position_x, + position_y=payload.position_y, + mentioned_user_ids=payload.mentioned_user_ids, + ) + + return result + + @console_ns.doc("delete_workflow_comment") + @console_ns.doc(description="Delete a workflow comment") + @console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"}) + @console_ns.response(204, "Comment deleted successfully") + @login_required + @setup_required + @account_initialization_required + @get_app_model() + @edit_permission_required + def delete(self, app_model: App, comment_id: str): + """Delete a workflow comment.""" + WorkflowCommentService.delete_comment( + tenant_id=current_user.current_tenant_id, + app_id=app_model.id, + comment_id=comment_id, + user_id=current_user.id, + ) + + return {"result": "success"}, 204 + + +@console_ns.route("/apps//workflow/comments//resolve") +class WorkflowCommentResolveApi(Resource): + """API for resolving and reopening workflow comments.""" + + @console_ns.doc("resolve_workflow_comment") + @console_ns.doc(description="Resolve a workflow comment") + @console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"}) + @console_ns.response(200, "Comment resolved successfully", workflow_comment_resolve_model) + @login_required + @setup_required + @account_initialization_required + @get_app_model() + @marshal_with(workflow_comment_resolve_model) + @edit_permission_required + def post(self, app_model: App, comment_id: str): + """Resolve a workflow comment.""" + comment = WorkflowCommentService.resolve_comment( + tenant_id=current_user.current_tenant_id, + app_id=app_model.id, + comment_id=comment_id, + user_id=current_user.id, + ) + + return comment + + +@console_ns.route("/apps//workflow/comments//replies") +class WorkflowCommentReplyApi(Resource): + """API for managing comment replies.""" + + @console_ns.doc("create_workflow_comment_reply") + @console_ns.doc(description="Add a reply to a workflow comment") + @console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"}) + @console_ns.expect(console_ns.models[WorkflowCommentReplyPayload.__name__]) + @console_ns.response(201, "Reply created successfully", workflow_comment_reply_create_model) + @login_required + @setup_required + @account_initialization_required + @get_app_model() + @marshal_with(workflow_comment_reply_create_model) + @edit_permission_required + def post(self, app_model: App, comment_id: str): + """Add a reply to a workflow comment.""" + # Validate comment access first + WorkflowCommentService.validate_comment_access( + comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id + ) + + payload = WorkflowCommentReplyPayload.model_validate(console_ns.payload or {}) + + result = WorkflowCommentService.create_reply( + comment_id=comment_id, + content=payload.content, + created_by=current_user.id, + mentioned_user_ids=payload.mentioned_user_ids, + ) + + return result, 201 + + +@console_ns.route("/apps//workflow/comments//replies/") +class WorkflowCommentReplyDetailApi(Resource): + """API for managing individual comment replies.""" + + @console_ns.doc("update_workflow_comment_reply") + @console_ns.doc(description="Update a comment reply") + @console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID", "reply_id": "Reply ID"}) + @console_ns.expect(console_ns.models[WorkflowCommentReplyPayload.__name__]) + @console_ns.response(200, "Reply updated successfully", workflow_comment_reply_update_model) + @login_required + @setup_required + @account_initialization_required + @get_app_model() + @marshal_with(workflow_comment_reply_update_model) + @edit_permission_required + def put(self, app_model: App, comment_id: str, reply_id: str): + """Update a comment reply.""" + # Validate comment access first + WorkflowCommentService.validate_comment_access( + comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id + ) + + payload = WorkflowCommentReplyPayload.model_validate(console_ns.payload or {}) + + reply = WorkflowCommentService.update_reply( + tenant_id=current_user.current_tenant_id, + app_id=app_model.id, + comment_id=comment_id, + reply_id=reply_id, + user_id=current_user.id, + content=payload.content, + mentioned_user_ids=payload.mentioned_user_ids, + ) + + return reply + + @console_ns.doc("delete_workflow_comment_reply") + @console_ns.doc(description="Delete a comment reply") + @console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID", "reply_id": "Reply ID"}) + @console_ns.response(204, "Reply deleted successfully") + @login_required + @setup_required + @account_initialization_required + @get_app_model() + @edit_permission_required + def delete(self, app_model: App, comment_id: str, reply_id: str): + """Delete a comment reply.""" + # Validate comment access first + WorkflowCommentService.validate_comment_access( + comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id + ) + + WorkflowCommentService.delete_reply( + tenant_id=current_user.current_tenant_id, + app_id=app_model.id, + comment_id=comment_id, + reply_id=reply_id, + user_id=current_user.id, + ) + + return {"result": "success"}, 204 + + +@console_ns.route("/apps//workflow/comments/mention-users") +class WorkflowCommentMentionUsersApi(Resource): + """API for getting mentionable users for workflow comments.""" + + @console_ns.doc("workflow_comment_mention_users") + @console_ns.doc(description="Get all users in current tenant for mentions") + @console_ns.doc(params={"app_id": "Application ID"}) + @console_ns.response( + 200, "Mentionable users retrieved successfully", console_ns.models[WorkflowCommentMentionUsersPayload.__name__] + ) + @login_required + @setup_required + @account_initialization_required + @get_app_model() + def get(self, app_model: App): + """Get all users in current tenant for mentions.""" + members = TenantService.get_tenant_members(current_user.current_tenant) + users = TypeAdapter(list[AccountWithRole]).validate_python(members, from_attributes=True) + response = WorkflowCommentMentionUsersPayload(users=users) + return response.model_dump(mode="json"), 200 diff --git a/api/controllers/console/app/workflow_draft_variable.py b/api/controllers/console/app/workflow_draft_variable.py index 366f1453608..e32ba5f66c5 100644 --- a/api/controllers/console/app/workflow_draft_variable.py +++ b/api/controllers/console/app/workflow_draft_variable.py @@ -1,14 +1,10 @@ import logging from collections.abc import Callable from functools import wraps -from typing import Any +from typing import Any, TypedDict from flask import Response, request from flask_restx import Resource, fields, marshal, marshal_with -from graphon.file import helpers as file_helpers -from graphon.variables.segment_group import SegmentGroup -from graphon.variables.segments import ArrayFileSegment, FileSegment, Segment -from graphon.variables.types import SegmentType from pydantic import BaseModel, Field from sqlalchemy.orm import sessionmaker @@ -22,8 +18,13 @@ from controllers.web.error import InvalidArgumentError, NotFoundError from core.app.file_access import DatabaseFileAccessController from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID from extensions.ext_database import db +from factories import variable_factory from factories.file_factory import build_from_mapping, build_from_mappings from factories.variable_factory import build_segment_with_type +from graphon.file import helpers as file_helpers +from graphon.variables.segment_group import SegmentGroup +from graphon.variables.segments import ArrayFileSegment, FileSegment, Segment +from graphon.variables.types import SegmentType from libs.login import current_user, login_required from models import App, AppMode from models.workflow import WorkflowDraftVariable @@ -45,6 +46,16 @@ class WorkflowDraftVariableUpdatePayload(BaseModel): value: Any | None = Field(default=None, description="Variable value") +class ConversationVariableUpdatePayload(BaseModel): + conversation_variables: list[dict[str, Any]] = Field( + ..., description="Conversation variables for the draft workflow" + ) + + +class EnvironmentVariableUpdatePayload(BaseModel): + environment_variables: list[dict[str, Any]] = Field(..., description="Environment variables for the draft workflow") + + console_ns.schema_model( WorkflowDraftVariableListQuery.__name__, WorkflowDraftVariableListQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), @@ -53,6 +64,14 @@ console_ns.schema_model( WorkflowDraftVariableUpdatePayload.__name__, WorkflowDraftVariableUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), ) +console_ns.schema_model( + ConversationVariableUpdatePayload.__name__, + ConversationVariableUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), +) +console_ns.schema_model( + EnvironmentVariableUpdatePayload.__name__, + EnvironmentVariableUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), +) def _convert_values_to_json_serializable_object(value: Segment): @@ -83,10 +102,17 @@ def _serialize_var_value(variable: WorkflowDraftVariable): def _serialize_variable_type(workflow_draft_var: WorkflowDraftVariable) -> str: value_type = workflow_draft_var.value_type - return value_type.exposed_type().value + return str(value_type.exposed_type()) -def _serialize_full_content(variable: WorkflowDraftVariable) -> dict | None: +class FullContentDict(TypedDict): + size_bytes: int | None + value_type: str + length: int | None + download_url: str + + +def _serialize_full_content(variable: WorkflowDraftVariable) -> FullContentDict | None: """Serialize full_content information for large variables.""" if not variable.is_truncated(): return None @@ -94,12 +120,13 @@ def _serialize_full_content(variable: WorkflowDraftVariable) -> dict | None: variable_file = variable.variable_file assert variable_file is not None - return { + result: FullContentDict = { "size_bytes": variable_file.size, - "value_type": variable_file.value_type.exposed_type().value, + "value_type": str(variable_file.value_type.exposed_type()), "length": variable_file.length, "download_url": file_helpers.get_signed_file_url(variable_file.upload_file_id, as_attachment=True), } + return result def _ensure_variable_access( @@ -193,7 +220,7 @@ workflow_draft_variable_list_model = console_ns.model( ) -def _api_prerequisite(f: Callable[..., Any]) -> Callable[..., Any]: +def _api_prerequisite[**P, R](f: Callable[P, R]) -> Callable[P, R | Response]: """Common prerequisites for all draft workflow variable APIs. It ensures the following conditions are satisfied: @@ -210,7 +237,7 @@ def _api_prerequisite(f: Callable[..., Any]) -> Callable[..., Any]: @edit_permission_required @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) @wraps(f) - def wrapper(*args: Any, **kwargs: Any): + def wrapper(*args: P.args, **kwargs: P.kwargs) -> R | Response: return f(*args, **kwargs) return wrapper @@ -384,24 +411,27 @@ class VariableApi(Resource): new_value = None if raw_value is not None: - if variable.value_type == SegmentType.FILE: - if not isinstance(raw_value, dict): - raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}") - raw_value = build_from_mapping( - mapping=raw_value, - tenant_id=app_model.tenant_id, - access_controller=_file_access_controller, - ) - elif variable.value_type == SegmentType.ARRAY_FILE: - if not isinstance(raw_value, list): - raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}") - if len(raw_value) > 0 and not isinstance(raw_value[0], dict): - raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}") - raw_value = build_from_mappings( - mappings=raw_value, - tenant_id=app_model.tenant_id, - access_controller=_file_access_controller, - ) + match variable.value_type: + case SegmentType.FILE: + if not isinstance(raw_value, dict): + raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}") + raw_value = build_from_mapping( + mapping=raw_value, + tenant_id=app_model.tenant_id, + access_controller=_file_access_controller, + ) + case SegmentType.ARRAY_FILE: + if not isinstance(raw_value, list): + raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}") + if len(raw_value) > 0 and not isinstance(raw_value[0], dict): + raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}") + raw_value = build_from_mappings( + mappings=raw_value, + tenant_id=app_model.tenant_id, + access_controller=_file_access_controller, + ) + case _: + pass new_value = build_segment_with_type(variable.value_type, raw_value) draft_var_srv.update_variable(variable, name=new_name, value=new_value) db.session.commit() @@ -499,6 +529,34 @@ class ConversationVariableCollectionApi(Resource): db.session.commit() return _get_variable_list(app_model, CONVERSATION_VARIABLE_NODE_ID) + @console_ns.expect(console_ns.models[ConversationVariableUpdatePayload.__name__]) + @console_ns.doc("update_conversation_variables") + @console_ns.doc(description="Update conversation variables for workflow draft") + @console_ns.doc(params={"app_id": "Application ID"}) + @console_ns.response(200, "Conversation variables updated successfully") + @setup_required + @login_required + @account_initialization_required + @edit_permission_required + @get_app_model(mode=AppMode.ADVANCED_CHAT) + def post(self, app_model: App): + payload = ConversationVariableUpdatePayload.model_validate(console_ns.payload or {}) + + workflow_service = WorkflowService() + + conversation_variables_list = payload.conversation_variables + conversation_variables = [ + variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list + ] + + workflow_service.update_draft_workflow_conversation_variables( + app_model=app_model, + account=current_user, + conversation_variables=conversation_variables, + ) + + return {"result": "success"} + @console_ns.route("/apps//workflows/draft/system-variables") class SystemVariableCollectionApi(Resource): @@ -540,7 +598,7 @@ class EnvironmentVariableCollectionApi(Resource): "name": v.name, "description": v.description, "selector": v.selector, - "value_type": v.value_type.exposed_type().value, + "value_type": str(v.value_type.exposed_type()), "value": v.value, # Do not track edited for env vars. "edited": False, @@ -550,3 +608,31 @@ class EnvironmentVariableCollectionApi(Resource): ) return {"items": env_vars_list} + + @console_ns.expect(console_ns.models[EnvironmentVariableUpdatePayload.__name__]) + @console_ns.doc("update_environment_variables") + @console_ns.doc(description="Update environment variables for workflow draft") + @console_ns.doc(params={"app_id": "Application ID"}) + @console_ns.response(200, "Environment variables updated successfully") + @setup_required + @login_required + @account_initialization_required + @edit_permission_required + @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + def post(self, app_model: App): + payload = EnvironmentVariableUpdatePayload.model_validate(console_ns.payload or {}) + + workflow_service = WorkflowService() + + environment_variables_list = payload.environment_variables + environment_variables = [ + variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list + ] + + workflow_service.update_draft_workflow_environment_variables( + app_model=app_model, + account=current_user, + environment_variables=environment_variables, + ) + + return {"result": "success"} diff --git a/api/controllers/console/app/workflow_run.py b/api/controllers/console/app/workflow_run.py index 83e8bedc110..6748d95d6bf 100644 --- a/api/controllers/console/app/workflow_run.py +++ b/api/controllers/console/app/workflow_run.py @@ -3,8 +3,6 @@ from typing import Literal, TypedDict, cast from flask import request from flask_restx import Resource, fields, marshal_with -from graphon.entities.pause_reason import HumanInputRequired -from graphon.enums import WorkflowExecutionStatus from pydantic import BaseModel, Field, field_validator from sqlalchemy import select from sqlalchemy.orm import sessionmaker @@ -28,6 +26,8 @@ from fields.workflow_run_fields import ( workflow_run_node_execution_list_fields, workflow_run_pagination_fields, ) +from graphon.entities.pause_reason import HumanInputRequired +from graphon.enums import WorkflowExecutionStatus from libs.archive_storage import ArchiveStorageNotConfiguredError, get_archive_storage from libs.custom_inputs import time_duration from libs.helper import uuid_value @@ -36,7 +36,7 @@ from models import Account, App, AppMode, EndUser, WorkflowArchiveLog, WorkflowR from models.workflow import WorkflowRun from repositories.factory import DifyAPIRepositoryFactory from services.retention.workflow_run.constants import ARCHIVE_BUNDLE_NAME -from services.workflow_run_service import WorkflowRunService +from services.workflow_run_service import WorkflowRunListArgs, WorkflowRunService def _build_backstage_input_url(form_token: str | None) -> str | None: @@ -214,7 +214,11 @@ class AdvancedChatAppWorkflowRunListApi(Resource): Get advanced chat app workflow run list """ args_model = WorkflowRunListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore - args = args_model.model_dump(exclude_none=True) + args: WorkflowRunListArgs = {"limit": args_model.limit} + if args_model.last_id is not None: + args["last_id"] = args_model.last_id + if args_model.status is not None: + args["status"] = args_model.status # Default to DEBUGGING if not specified triggered_from = ( @@ -356,7 +360,11 @@ class WorkflowRunListApi(Resource): Get workflow run list """ args_model = WorkflowRunListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore - args = args_model.model_dump(exclude_none=True) + args: WorkflowRunListArgs = {"limit": args_model.limit} + if args_model.last_id is not None: + args["last_id"] = args_model.last_id + if args_model.status is not None: + args["status"] = args_model.status # Default to DEBUGGING for workflow if not specified (backward compatibility) triggered_from = ( diff --git a/api/controllers/console/app/workflow_trigger.py b/api/controllers/console/app/workflow_trigger.py index aa37d247383..a6715fa200f 100644 --- a/api/controllers/console/app/workflow_trigger.py +++ b/api/controllers/console/app/workflow_trigger.py @@ -1,16 +1,17 @@ import logging +from datetime import datetime from flask import request -from flask_restx import Resource, fields, marshal_with -from pydantic import BaseModel +from flask_restx import Resource +from pydantic import BaseModel, field_validator from sqlalchemy import select from sqlalchemy.orm import sessionmaker from werkzeug.exceptions import NotFound from configs import dify_config -from controllers.common.schema import get_or_create_model +from controllers.common.schema import register_schema_models from extensions.ext_database import db -from fields.workflow_trigger_fields import trigger_fields, triggers_list_fields, webhook_trigger_fields +from fields.base import ResponseModel from libs.login import current_user, login_required from models.enums import AppTriggerStatus from models.model import Account, App, AppMode @@ -21,15 +22,6 @@ from ..app.wraps import get_app_model from ..wraps import account_initialization_required, edit_permission_required, setup_required logger = logging.getLogger(__name__) -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" - -trigger_model = get_or_create_model("WorkflowTrigger", trigger_fields) - -triggers_list_fields_copy = triggers_list_fields.copy() -triggers_list_fields_copy["data"] = fields.List(fields.Nested(trigger_model)) -triggers_list_model = get_or_create_model("WorkflowTriggerList", triggers_list_fields_copy) - -webhook_trigger_model = get_or_create_model("WebhookTrigger", webhook_trigger_fields) class Parser(BaseModel): @@ -41,10 +33,52 @@ class ParserEnable(BaseModel): enable_trigger: bool -console_ns.schema_model(Parser.__name__, Parser.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) +class WorkflowTriggerResponse(ResponseModel): + id: str + trigger_type: str + title: str + node_id: str + provider_name: str + icon: str + status: str + created_at: datetime | None = None + updated_at: datetime | None = None -console_ns.schema_model( - ParserEnable.__name__, ParserEnable.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) + @field_validator("id", "trigger_type", "title", "node_id", "provider_name", "icon", "status", mode="before") + @classmethod + def _normalize_string_fields(cls, value: object) -> str: + if isinstance(value, str): + return value + return str(value) + + +class WorkflowTriggerListResponse(ResponseModel): + data: list[WorkflowTriggerResponse] + + +class WebhookTriggerResponse(ResponseModel): + id: str + webhook_id: str + webhook_url: str + webhook_debug_url: str + node_id: str + created_at: datetime | None = None + + @field_validator("id", "webhook_id", "webhook_url", "webhook_debug_url", "node_id", mode="before") + @classmethod + def _normalize_string_fields(cls, value: object) -> str: + if isinstance(value, str): + return value + return str(value) + + +register_schema_models( + console_ns, + Parser, + ParserEnable, + WorkflowTriggerResponse, + WorkflowTriggerListResponse, + WebhookTriggerResponse, ) @@ -57,28 +91,28 @@ class WebhookTriggerApi(Resource): @login_required @account_initialization_required @get_app_model(mode=AppMode.WORKFLOW) - @marshal_with(webhook_trigger_model) + @console_ns.response(200, "Success", console_ns.models[WebhookTriggerResponse.__name__]) def get(self, app_model: App): """Get webhook trigger for a node""" args = Parser.model_validate(request.args.to_dict(flat=True)) # type: ignore node_id = args.node_id - with sessionmaker(db.engine).begin() as session: + with sessionmaker(db.engine, expire_on_commit=False).begin() as session: # Get webhook trigger for this app and node - webhook_trigger = ( - session.query(WorkflowWebhookTrigger) + webhook_trigger = session.scalar( + select(WorkflowWebhookTrigger) .where( WorkflowWebhookTrigger.app_id == app_model.id, WorkflowWebhookTrigger.node_id == node_id, ) - .first() + .limit(1) ) if not webhook_trigger: raise NotFound("Webhook trigger not found for this node") - return webhook_trigger + return WebhookTriggerResponse.model_validate(webhook_trigger, from_attributes=True).model_dump(mode="json") @console_ns.route("/apps//triggers") @@ -89,13 +123,13 @@ class AppTriggersApi(Resource): @login_required @account_initialization_required @get_app_model(mode=AppMode.WORKFLOW) - @marshal_with(triggers_list_model) + @console_ns.response(200, "Success", console_ns.models[WorkflowTriggerListResponse.__name__]) def get(self, app_model: App): """Get app triggers list""" assert isinstance(current_user, Account) assert current_user.current_tenant_id is not None - with sessionmaker(db.engine).begin() as session: + with sessionmaker(db.engine, expire_on_commit=False).begin() as session: # Get all triggers for this app using select API triggers = ( session.execute( @@ -118,7 +152,9 @@ class AppTriggersApi(Resource): else: trigger.icon = "" # type: ignore - return {"data": triggers} + return WorkflowTriggerListResponse.model_validate({"data": triggers}, from_attributes=True).model_dump( + mode="json" + ) @console_ns.route("/apps//trigger-enable") @@ -129,7 +165,7 @@ class AppTriggerEnableApi(Resource): @account_initialization_required @edit_permission_required @get_app_model(mode=AppMode.WORKFLOW) - @marshal_with(trigger_model) + @console_ns.response(200, "Success", console_ns.models[WorkflowTriggerResponse.__name__]) def post(self, app_model: App): """Update app trigger (enable/disable)""" args = ParserEnable.model_validate(console_ns.payload) @@ -160,4 +196,4 @@ class AppTriggerEnableApi(Resource): else: trigger.icon = "" # type: ignore - return trigger + return WorkflowTriggerResponse.model_validate(trigger, from_attributes=True).model_dump(mode="json") diff --git a/api/controllers/console/app/wraps.py b/api/controllers/console/app/wraps.py index bd6f019eac1..c9cf08072ad 100644 --- a/api/controllers/console/app/wraps.py +++ b/api/controllers/console/app/wraps.py @@ -1,6 +1,6 @@ from collections.abc import Callable from functools import wraps -from typing import Any +from typing import overload from sqlalchemy import select @@ -23,14 +23,30 @@ def _load_app_model_with_trial(app_id: str) -> App | None: return app_model -def get_app_model( - view: Callable[..., Any] | None = None, +@overload +def get_app_model[**P, R]( + view: Callable[P, R], *, mode: AppMode | list[AppMode] | None = None, -) -> Callable[..., Any] | Callable[[Callable[..., Any]], Callable[..., Any]]: - def decorator(view_func: Callable[..., Any]) -> Callable[..., Any]: +) -> Callable[P, R]: ... + + +@overload +def get_app_model[**P, R]( + view: None = None, + *, + mode: AppMode | list[AppMode] | None = None, +) -> Callable[[Callable[P, R]], Callable[P, R]]: ... + + +def get_app_model[**P, R]( + view: Callable[P, R] | None = None, + *, + mode: AppMode | list[AppMode] | None = None, +) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]: + def decorator(view_func: Callable[P, R]) -> Callable[P, R]: @wraps(view_func) - def decorated_view(*args: Any, **kwargs: Any): + def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R: if not kwargs.get("app_id"): raise ValueError("missing app_id in path parameters") @@ -68,14 +84,30 @@ def get_app_model( return decorator(view) -def get_app_model_with_trial( - view: Callable[..., Any] | None = None, +@overload +def get_app_model_with_trial[**P, R]( + view: Callable[P, R], *, mode: AppMode | list[AppMode] | None = None, -) -> Callable[..., Any] | Callable[[Callable[..., Any]], Callable[..., Any]]: - def decorator(view_func: Callable[..., Any]) -> Callable[..., Any]: +) -> Callable[P, R]: ... + + +@overload +def get_app_model_with_trial[**P, R]( + view: None = None, + *, + mode: AppMode | list[AppMode] | None = None, +) -> Callable[[Callable[P, R]], Callable[P, R]]: ... + + +def get_app_model_with_trial[**P, R]( + view: Callable[P, R] | None = None, + *, + mode: AppMode | list[AppMode] | None = None, +) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]: + def decorator(view_func: Callable[P, R]) -> Callable[P, R]: @wraps(view_func) - def decorated_view(*args: Any, **kwargs: Any): + def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R: if not kwargs.get("app_id"): raise ValueError("missing app_id in path parameters") diff --git a/api/controllers/console/auth/activate.py b/api/controllers/console/auth/activate.py index f741107b874..f7061f820f9 100644 --- a/api/controllers/console/auth/activate.py +++ b/api/controllers/console/auth/activate.py @@ -1,8 +1,11 @@ +from typing import Any + from flask import request -from flask_restx import Resource, fields +from flask_restx import Resource from pydantic import BaseModel, Field, field_validator from constants.languages import supported_language +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.error import AlreadyActivateError from extensions.ext_database import db @@ -11,8 +14,6 @@ from libs.helper import EmailStr, timezone from models import AccountStatus from services.account_service import RegisterService -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" - class ActivateCheckQuery(BaseModel): workspace_id: str | None = Field(default=None) @@ -39,8 +40,16 @@ class ActivatePayload(BaseModel): return timezone(value) -for model in (ActivateCheckQuery, ActivatePayload): - console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) +class ActivationCheckResponse(BaseModel): + is_valid: bool = Field(description="Whether token is valid") + data: dict[str, Any] | None = Field(default=None, description="Activation data if valid") + + +class ActivationResponse(BaseModel): + result: str = Field(description="Operation result") + + +register_schema_models(console_ns, ActivateCheckQuery, ActivatePayload, ActivationCheckResponse, ActivationResponse) @console_ns.route("/activate/check") @@ -51,13 +60,7 @@ class ActivateCheckApi(Resource): @console_ns.response( 200, "Success", - console_ns.model( - "ActivationCheckResponse", - { - "is_valid": fields.Boolean(description="Whether token is valid"), - "data": fields.Raw(description="Activation data if valid"), - }, - ), + console_ns.models[ActivationCheckResponse.__name__], ) def get(self): args = ActivateCheckQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore @@ -95,12 +98,7 @@ class ActivateApi(Resource): @console_ns.response( 200, "Account activated successfully", - console_ns.model( - "ActivationResponse", - { - "result": fields.String(description="Operation result"), - }, - ), + console_ns.models[ActivationResponse.__name__], ) @console_ns.response(400, "Already activated or invalid token") def post(self): diff --git a/api/controllers/console/auth/email_register.py b/api/controllers/console/auth/email_register.py index 9e7faa09c58..1fd781b4fc2 100644 --- a/api/controllers/console/auth/email_register.py +++ b/api/controllers/console/auth/email_register.py @@ -1,7 +1,6 @@ from flask import request from flask_restx import Resource from pydantic import BaseModel, Field, field_validator -from sqlalchemy.orm import sessionmaker from configs import dify_config from constants.languages import languages @@ -14,7 +13,6 @@ from controllers.console.auth.error import ( InvalidTokenError, PasswordMismatchError, ) -from extensions.ext_database import db from libs.helper import EmailStr, extract_remote_ip from libs.password import valid_password from models import Account @@ -73,8 +71,7 @@ class EmailRegisterSendEmailApi(Resource): if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email): raise AccountInFreezeError() - with sessionmaker(db.engine).begin() as session: - account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session) + account = AccountService.get_account_by_email_with_case_fallback(args.email) token = AccountService.send_email_register_email(email=normalized_email, account=account, language=language) return {"result": "success", "data": token} @@ -145,17 +142,16 @@ class EmailRegisterResetApi(Resource): email = register_data.get("email", "") normalized_email = email.lower() - with sessionmaker(db.engine).begin() as session: - account = AccountService.get_account_by_email_with_case_fallback(email, session=session) + account = AccountService.get_account_by_email_with_case_fallback(email) - if account: - raise EmailAlreadyInUseError() - else: - account = self._create_new_account(normalized_email, args.password_confirm) - if not account: - raise AccountNotFoundError() - token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request)) - AccountService.reset_login_error_rate_limit(normalized_email) + if account: + raise EmailAlreadyInUseError() + else: + account = self._create_new_account(normalized_email, args.password_confirm) + if not account: + raise AccountNotFoundError() + token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request)) + AccountService.reset_login_error_rate_limit(normalized_email) return {"result": "success", "data": token_pair.model_dump()} diff --git a/api/controllers/console/auth/forgot_password.py b/api/controllers/console/auth/forgot_password.py index 844f3c91ff0..ed390a5f89d 100644 --- a/api/controllers/console/auth/forgot_password.py +++ b/api/controllers/console/auth/forgot_password.py @@ -3,8 +3,7 @@ import secrets from flask import request from flask_restx import Resource -from pydantic import BaseModel, Field, field_validator -from sqlalchemy.orm import sessionmaker +from pydantic import BaseModel, Field from controllers.common.schema import register_schema_models from controllers.console import console_ns @@ -20,35 +19,18 @@ from controllers.console.wraps import email_password_login_enabled, setup_requir from events.tenant_event import tenant_was_created from extensions.ext_database import db from libs.helper import EmailStr, extract_remote_ip -from libs.password import hash_password, valid_password +from libs.password import hash_password from services.account_service import AccountService, TenantService +from services.entities.auth_entities import ( + ForgotPasswordCheckPayload, + ForgotPasswordResetPayload, + ForgotPasswordSendPayload, +) from services.feature_service import FeatureService DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" -class ForgotPasswordSendPayload(BaseModel): - email: EmailStr = Field(...) - language: str | None = Field(default=None) - - -class ForgotPasswordCheckPayload(BaseModel): - email: EmailStr = Field(...) - code: str = Field(...) - token: str = Field(...) - - -class ForgotPasswordResetPayload(BaseModel): - token: str = Field(...) - new_password: str = Field(...) - password_confirm: str = Field(...) - - @field_validator("new_password", "password_confirm") - @classmethod - def validate_password(cls, value: str) -> str: - return valid_password(value) - - class ForgotPasswordEmailResponse(BaseModel): result: str = Field(description="Operation result") data: str | None = Field(default=None, description="Reset token") @@ -102,8 +84,7 @@ class ForgotPasswordSendEmailApi(Resource): else: language = "en-US" - with sessionmaker(db.engine).begin() as session: - account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session) + account = AccountService.get_account_by_email_with_case_fallback(args.email) token = AccountService.send_reset_password_email( account=account, @@ -201,17 +182,18 @@ class ForgotPasswordResetApi(Resource): password_hashed = hash_password(args.new_password, salt) email = reset_data.get("email", "") - with sessionmaker(db.engine).begin() as session: - account = AccountService.get_account_by_email_with_case_fallback(email, session=session) + account = AccountService.get_account_by_email_with_case_fallback(email) - if account: - self._update_existing_account(account, password_hashed, salt, session) - else: - raise AccountNotFound() + if account: + account = db.session.merge(account) + self._update_existing_account(account, password_hashed, salt) + db.session.commit() + else: + raise AccountNotFound() return {"result": "success"} - def _update_existing_account(self, account, password_hashed, salt, session): + def _update_existing_account(self, account, password_hashed, salt): # Update existing account credentials account.password = base64.b64encode(password_hashed).decode() account.password_salt = base64.b64encode(salt).decode() diff --git a/api/controllers/console/auth/login.py b/api/controllers/console/auth/login.py index 400df138b87..8216b3d0da1 100644 --- a/api/controllers/console/auth/login.py +++ b/api/controllers/console/auth/login.py @@ -1,9 +1,10 @@ -from typing import Any +import logging import flask_login from flask import make_response, request from flask_restx import Resource from pydantic import BaseModel, Field +from werkzeug.exceptions import Unauthorized import services from configs import dify_config @@ -42,18 +43,18 @@ from libs.token import ( set_csrf_token_to_cookie, set_refresh_token_to_cookie, ) -from services.account_service import AccountService, RegisterService, TenantService +from services.account_service import AccountService, InvitationDetailDict, RegisterService, TenantService from services.billing_service import BillingService +from services.entities.auth_entities import LoginFailureReason, LoginPayloadBase from services.errors.account import AccountRegisterError from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError from services.feature_service import FeatureService DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" +logger = logging.getLogger(__name__) -class LoginPayload(BaseModel): - email: EmailStr = Field(..., description="Email address") - password: str = Field(..., description="Password") +class LoginPayload(LoginPayloadBase): remember_me: bool = Field(default=False, description="Remember me flag") invite_token: str | None = Field(default=None, description="Invitation token") @@ -94,14 +95,16 @@ class LoginApi(Resource): normalized_email = request_email.lower() if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email): + _log_console_login_failure(email=normalized_email, reason=LoginFailureReason.ACCOUNT_IN_FREEZE) raise AccountInFreezeError() is_login_error_rate_limit = AccountService.is_login_error_rate_limit(normalized_email) if is_login_error_rate_limit: + _log_console_login_failure(email=normalized_email, reason=LoginFailureReason.LOGIN_RATE_LIMITED) raise EmailPasswordLoginLimitError() invite_token = args.invite_token - invitation_data: dict[str, Any] | None = None + invitation_data: InvitationDetailDict | None = None if invite_token: invitation_data = RegisterService.get_invitation_with_case_fallback(None, request_email, invite_token) if invitation_data is None: @@ -113,14 +116,20 @@ class LoginApi(Resource): invitee_email = data.get("email") if data else None invitee_email_normalized = invitee_email.lower() if isinstance(invitee_email, str) else invitee_email if invitee_email_normalized != normalized_email: + _log_console_login_failure( + email=normalized_email, + reason=LoginFailureReason.INVALID_INVITATION_EMAIL, + ) raise InvalidEmailError() account = _authenticate_account_with_case_fallback( request_email, normalized_email, args.password, invite_token ) except services.errors.account.AccountLoginError: + _log_console_login_failure(email=normalized_email, reason=LoginFailureReason.ACCOUNT_BANNED) raise AccountBannedError() except services.errors.account.AccountPasswordError as exc: AccountService.add_login_error_rate_limit(normalized_email) + _log_console_login_failure(email=normalized_email, reason=LoginFailureReason.INVALID_CREDENTIALS) raise AuthenticationFailedError() from exc # SELF_HOSTED only have one workspace tenants = TenantService.get_join_tenants(account) @@ -243,20 +252,27 @@ class EmailCodeLoginApi(Resource): token_data = AccountService.get_email_code_login_data(args.token) if token_data is None: + _log_console_login_failure(email=user_email, reason=LoginFailureReason.INVALID_EMAIL_CODE_TOKEN) raise InvalidTokenError() token_email = token_data.get("email") normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email if normalized_token_email != user_email: + _log_console_login_failure(email=user_email, reason=LoginFailureReason.EMAIL_CODE_EMAIL_MISMATCH) raise InvalidEmailError() if token_data["code"] != args.code: + _log_console_login_failure(email=user_email, reason=LoginFailureReason.INVALID_EMAIL_CODE) raise EmailCodeError() AccountService.revoke_email_code_login_token(args.token) try: account = _get_account_with_case_fallback(original_email) + except Unauthorized as exc: + _log_console_login_failure(email=user_email, reason=LoginFailureReason.ACCOUNT_BANNED) + raise AccountBannedError() from exc except AccountRegisterError: + _log_console_login_failure(email=user_email, reason=LoginFailureReason.ACCOUNT_IN_FREEZE) raise AccountInFreezeError() if account: tenants = TenantService.get_join_tenants(account) @@ -282,6 +298,7 @@ class EmailCodeLoginApi(Resource): except WorkSpaceNotAllowedCreateError: raise NotAllowedCreateWorkspace() except AccountRegisterError: + _log_console_login_failure(email=user_email, reason=LoginFailureReason.ACCOUNT_IN_FREEZE) raise AccountInFreezeError() except WorkspacesLimitExceededError: raise WorkspacesLimitExceeded() @@ -339,3 +356,12 @@ def _authenticate_account_with_case_fallback( if original_email == normalized_email: raise return AccountService.authenticate(normalized_email, password, invite_token) + + +def _log_console_login_failure(*, email: str, reason: LoginFailureReason) -> None: + logger.warning( + "Console login failed: email=%s reason=%s ip_address=%s", + email, + reason, + extract_remote_ip(request), + ) diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py index 5c7011fd22f..d31fb4a46c1 100644 --- a/api/controllers/console/auth/oauth.py +++ b/api/controllers/console/auth/oauth.py @@ -4,7 +4,6 @@ import urllib.parse import httpx from flask import current_app, redirect, request from flask_restx import Resource -from sqlalchemy.orm import sessionmaker from werkzeug.exceptions import Unauthorized from configs import dify_config @@ -180,8 +179,7 @@ def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> account: Account | None = Account.get_by_openid(provider, user_info.id) if not account: - with sessionmaker(db.engine).begin() as session: - account = AccountService.get_account_by_email_with_case_fallback(user_info.email, session=session) + account = AccountService.get_account_by_email_with_case_fallback(user_info.email) return account diff --git a/api/controllers/console/auth/oauth_server.py b/api/controllers/console/auth/oauth_server.py index b55cda42448..727428c8e7f 100644 --- a/api/controllers/console/auth/oauth_server.py +++ b/api/controllers/console/auth/oauth_server.py @@ -5,11 +5,11 @@ from typing import Concatenate from flask import jsonify, request from flask.typing import ResponseReturnValue from flask_restx import Resource -from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel from werkzeug.exceptions import BadRequest, NotFound from controllers.console.wraps import account_initialization_required, setup_required +from graphon.model_runtime.utils.encoders import jsonable_encoder from libs.login import current_account_with_tenant, login_required from models import Account from models.model import OAuthProviderApp diff --git a/api/controllers/console/billing/billing.py b/api/controllers/console/billing/billing.py index 23c01eedb1d..45de338559a 100644 --- a/api/controllers/console/billing/billing.py +++ b/api/controllers/console/billing/billing.py @@ -2,18 +2,17 @@ import base64 from typing import Literal from flask import request -from flask_restx import Resource, fields +from flask_restx import Resource from pydantic import BaseModel, Field from werkzeug.exceptions import BadRequest +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required from enums.cloud_plan import CloudPlan from libs.login import current_account_with_tenant, login_required from services.billing_service import BillingService -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" - class SubscriptionQuery(BaseModel): plan: Literal[CloudPlan.PROFESSIONAL, CloudPlan.TEAM] = Field(..., description="Subscription plan") @@ -24,8 +23,7 @@ class PartnerTenantsPayload(BaseModel): click_id: str = Field(..., description="Click Id from partner referral link") -for model in (SubscriptionQuery, PartnerTenantsPayload): - console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) +register_schema_models(console_ns, SubscriptionQuery, PartnerTenantsPayload) @console_ns.route("/billing/subscription") @@ -58,12 +56,7 @@ class PartnerTenants(Resource): @console_ns.doc("sync_partner_tenants_bindings") @console_ns.doc(description="Sync partner tenants bindings") @console_ns.doc(params={"partner_key": "Partner key"}) - @console_ns.expect( - console_ns.model( - "SyncPartnerTenantsBindingsRequest", - {"click_id": fields.String(required=True, description="Click Id from partner referral link")}, - ) - ) + @console_ns.expect(console_ns.models[PartnerTenantsPayload.__name__]) @console_ns.response(200, "Tenants synced to partner successfully") @console_ns.response(400, "Invalid partner information") @setup_required diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py index ac143490458..ed3c1a59d4e 100644 --- a/api/controllers/console/datasets/data_source.py +++ b/api/controllers/console/datasets/data_source.py @@ -158,10 +158,13 @@ class DataSourceApi(Resource): @login_required @account_initialization_required def patch(self, binding_id, action: Literal["enable", "disable"]): + _, current_tenant_id = current_account_with_tenant() binding_id = str(binding_id) with sessionmaker(db.engine, expire_on_commit=False).begin() as session: data_source_binding = session.execute( - select(DataSourceOauthBinding).filter_by(id=binding_id) + select(DataSourceOauthBinding).where( + DataSourceOauthBinding.id == binding_id, DataSourceOauthBinding.tenant_id == current_tenant_id + ) ).scalar_one_or_none() if data_source_binding is None: raise NotFound("Data source binding not found.") @@ -221,11 +224,11 @@ class DataSourceNotionListApi(Resource): raise ValueError("Dataset is not notion type.") documents = session.scalars( - select(Document).filter_by( - dataset_id=query.dataset_id, - tenant_id=current_tenant_id, - data_source_type="notion_import", - enabled=True, + select(Document).where( + Document.dataset_id == query.dataset_id, + Document.tenant_id == current_tenant_id, + Document.data_source_type == "notion_import", + Document.enabled.is_(True), ) ).all() if documents: diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index f23c7eb4310..d001dfba64e 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -2,7 +2,6 @@ from typing import Any, cast from flask import request from flask_restx import Resource, fields, marshal, marshal_with -from graphon.model_runtime.entities.model_entities import ModelType from pydantic import BaseModel, Field, field_validator from sqlalchemy import func, select from werkzeug.exceptions import Forbidden, NotFound @@ -11,10 +10,7 @@ import services from configs import dify_config from controllers.common.schema import get_or_create_model, register_schema_models from controllers.console import console_ns -from controllers.console.apikey import ( - api_key_item_model, - api_key_list_model, -) +from controllers.console.apikey import ApiKeyItem, ApiKeyList from controllers.console.app.error import ProviderNotInitializeError from controllers.console.datasets.error import DatasetInUseError, DatasetNameDuplicateError, IndexingEstimateError from controllers.console.wraps import ( @@ -52,7 +48,9 @@ from fields.dataset_fields import ( weighted_score_fields, ) from fields.document_fields import document_status_fields +from graphon.model_runtime.entities.model_entities import ModelType from libs.login import current_account_with_tenant, login_required +from libs.url_utils import normalize_api_base_url from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile from models.dataset import DatasetPermission, DatasetPermissionEnum from models.enums import ApiTokenType, SegmentStatus @@ -785,23 +783,23 @@ class DatasetApiKeyApi(Resource): @console_ns.doc("get_dataset_api_keys") @console_ns.doc(description="Get dataset API keys") - @console_ns.response(200, "API keys retrieved successfully", api_key_list_model) + @console_ns.response(200, "API keys retrieved successfully", console_ns.models[ApiKeyList.__name__]) @setup_required @login_required @account_initialization_required - @marshal_with(api_key_list_model) def get(self): _, current_tenant_id = current_account_with_tenant() keys = db.session.scalars( select(ApiToken).where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_tenant_id) ).all() - return {"items": keys} + return ApiKeyList.model_validate({"data": keys}, from_attributes=True).model_dump(mode="json") + @console_ns.response(200, "API key created successfully", console_ns.models[ApiKeyItem.__name__]) + @console_ns.response(400, "Maximum keys exceeded") @setup_required @login_required @is_admin_or_owner_required @account_initialization_required - @marshal_with(api_key_item_model) def post(self): _, current_tenant_id = current_account_with_tenant() @@ -828,7 +826,7 @@ class DatasetApiKeyApi(Resource): api_token.type = self.resource_type db.session.add(api_token) db.session.commit() - return api_token, 200 + return ApiKeyItem.model_validate(api_token, from_attributes=True).model_dump(mode="json"), 200 @console_ns.route("/datasets/api-keys/") @@ -892,7 +890,8 @@ class DatasetApiBaseUrlApi(Resource): @login_required @account_initialization_required def get(self): - return {"api_base_url": (dify_config.SERVICE_API_URL or request.host_url.rstrip("/")) + "/v1"} + base = dify_config.SERVICE_API_URL or request.host_url.rstrip("/") + return {"api_base_url": normalize_api_base_url(base)} @console_ns.route("/datasets/retrieval-setting") diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index ab367d84838..3372a967d9b 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -3,20 +3,19 @@ import logging from argparse import ArgumentTypeError from collections.abc import Sequence from contextlib import ExitStack +from datetime import datetime from typing import Any, Literal, cast -from uuid import UUID import sqlalchemy as sa from flask import request, send_file -from flask_restx import Resource, fields, marshal, marshal_with -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError -from pydantic import BaseModel, Field +from flask_restx import Resource, marshal +from pydantic import BaseModel, Field, field_validator from sqlalchemy import asc, desc, func, select from werkzeug.exceptions import Forbidden, NotFound import services -from controllers.common.schema import get_or_create_model, register_schema_models +from controllers.common.controller_schemas import DocumentBatchDownloadZipPayload +from controllers.common.schema import register_schema_models from controllers.console import console_ns from core.errors.error import ( LLMBadRequestError, @@ -31,14 +30,14 @@ from core.rag.extractor.entity.datasource_type import DatasourceType from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo from core.rag.index_processor.constant.index_type import IndexTechniqueType from extensions.ext_database import db -from fields.dataset_fields import dataset_fields +from fields.base import ResponseModel from fields.document_fields import ( - dataset_and_document_fields, document_fields, - document_metadata_fields, document_status_fields, document_with_segments_fields, ) +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from libs.datetime_utils import naive_utc_now from libs.login import current_account_with_tenant, login_required from models import DatasetProcessRule, Document, DocumentSegment, UploadFile @@ -71,31 +70,101 @@ from ..wraps import ( logger = logging.getLogger(__name__) -# NOTE: Keep constants near the top of the module for discoverability. -DOCUMENT_BATCH_DOWNLOAD_ZIP_MAX_DOCS = 100 + +def _to_timestamp(value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return int(value.timestamp()) + return value -# Register models for flask_restx to avoid dict type issues in Swagger -dataset_model = get_or_create_model("Dataset", dataset_fields) +def _normalize_enum(value: Any) -> Any: + if isinstance(value, str) or value is None: + return value + return getattr(value, "value", value) -document_metadata_model = get_or_create_model("DocumentMetadata", document_metadata_fields) -document_fields_copy = document_fields.copy() -document_fields_copy["doc_metadata"] = fields.List( - fields.Nested(document_metadata_model), attribute="doc_metadata_details" -) -document_model = get_or_create_model("Document", document_fields_copy) +class DatasetResponse(ResponseModel): + id: str + name: str + description: str | None = None + permission: str | None = None + data_source_type: str | None = None + indexing_technique: str | None = None + created_by: str | None = None + created_at: int | None = None -document_with_segments_fields_copy = document_with_segments_fields.copy() -document_with_segments_fields_copy["doc_metadata"] = fields.List( - fields.Nested(document_metadata_model), attribute="doc_metadata_details" -) -document_with_segments_model = get_or_create_model("DocumentWithSegments", document_with_segments_fields_copy) + @field_validator("data_source_type", "indexing_technique", mode="before") + @classmethod + def _normalize_enum_fields(cls, value: Any) -> Any: + return _normalize_enum(value) -dataset_and_document_fields_copy = dataset_and_document_fields.copy() -dataset_and_document_fields_copy["dataset"] = fields.Nested(dataset_model) -dataset_and_document_fields_copy["documents"] = fields.List(fields.Nested(document_model)) -dataset_and_document_model = get_or_create_model("DatasetAndDocument", dataset_and_document_fields_copy) + @field_validator("created_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +class DocumentMetadataResponse(ResponseModel): + id: str + name: str + type: str + value: str | None = None + + +class DocumentResponse(ResponseModel): + id: str + position: int | None = None + data_source_type: str | None = None + data_source_info: Any = Field(default=None, validation_alias="data_source_info_dict") + data_source_detail_dict: Any = None + dataset_process_rule_id: str | None = None + name: str + created_from: str | None = None + created_by: str | None = None + created_at: int | None = None + tokens: int | None = None + indexing_status: str | None = None + error: str | None = None + enabled: bool | None = None + disabled_at: int | None = None + disabled_by: str | None = None + archived: bool | None = None + display_status: str | None = None + word_count: int | None = None + hit_count: int | None = None + doc_form: str | None = None + doc_metadata: list[DocumentMetadataResponse] = Field(default_factory=list, validation_alias="doc_metadata_details") + summary_index_status: str | None = None + need_summary: bool | None = None + + @field_validator("data_source_type", "indexing_status", "display_status", "doc_form", mode="before") + @classmethod + def _normalize_enum_fields(cls, value: Any) -> Any: + return _normalize_enum(value) + + @field_validator("doc_metadata", mode="before") + @classmethod + def _normalize_doc_metadata(cls, value: Any) -> list[Any]: + if value is None: + return [] + return value + + @field_validator("created_at", "disabled_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +class DocumentWithSegmentsResponse(DocumentResponse): + process_rule_dict: Any = None + completed_segments: int | None = None + total_segments: int | None = None + + +class DatasetAndDocumentResponse(ResponseModel): + dataset: DatasetResponse + documents: list[DocumentResponse] + batch: str class DocumentRetryPayload(BaseModel): @@ -110,10 +179,9 @@ class GenerateSummaryPayload(BaseModel): document_list: list[str] -class DocumentBatchDownloadZipPayload(BaseModel): - """Request payload for bulk downloading documents as a zip archive.""" - - document_ids: list[UUID] = Field(..., min_length=1, max_length=DOCUMENT_BATCH_DOWNLOAD_ZIP_MAX_DOCS) +class DocumentMetadataUpdatePayload(BaseModel): + doc_type: str | None = None + doc_metadata: Any = None class DocumentDatasetListParam(BaseModel): @@ -133,7 +201,13 @@ register_schema_models( DocumentRetryPayload, DocumentRenamePayload, GenerateSummaryPayload, + DocumentMetadataUpdatePayload, DocumentBatchDownloadZipPayload, + DatasetResponse, + DocumentMetadataResponse, + DocumentResponse, + DocumentWithSegmentsResponse, + DatasetAndDocumentResponse, ) @@ -280,7 +354,7 @@ class DatasetDocumentListApi(Resource): except services.errors.account.NoPermissionError as e: raise Forbidden(str(e)) - query = select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=current_tenant_id) + query = select(Document).where(Document.dataset_id == str(dataset_id), Document.tenant_id == current_tenant_id) if status: query = DocumentService.apply_display_status_filter(query, status) @@ -366,10 +440,10 @@ class DatasetDocumentListApi(Resource): @setup_required @login_required @account_initialization_required - @marshal_with(dataset_and_document_model) @cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_rate_limit_check("knowledge") @console_ns.expect(console_ns.models[KnowledgeConfig.__name__]) + @console_ns.response(200, "Documents created successfully", console_ns.models[DatasetAndDocumentResponse.__name__]) def post(self, dataset_id): current_user, _ = current_account_with_tenant() dataset_id = str(dataset_id) @@ -407,7 +481,9 @@ class DatasetDocumentListApi(Resource): except ModelCurrentlyNotSupportError: raise ProviderModelCurrentlyNotSupportError() - return {"dataset": dataset, "documents": documents, "batch": batch} + return DatasetAndDocumentResponse.model_validate( + {"dataset": dataset, "documents": documents, "batch": batch}, from_attributes=True + ).model_dump(mode="json") @setup_required @login_required @@ -435,12 +511,13 @@ class DatasetInitApi(Resource): @console_ns.doc("init_dataset") @console_ns.doc(description="Initialize dataset with documents") @console_ns.expect(console_ns.models[KnowledgeConfig.__name__]) - @console_ns.response(201, "Dataset initialized successfully", dataset_and_document_model) + @console_ns.response( + 201, "Dataset initialized successfully", console_ns.models[DatasetAndDocumentResponse.__name__] + ) @console_ns.response(400, "Invalid request parameters") @setup_required @login_required @account_initialization_required - @marshal_with(dataset_and_document_model) @cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_rate_limit_check("knowledge") def post(self): @@ -488,9 +565,9 @@ class DatasetInitApi(Resource): except ModelCurrentlyNotSupportError: raise ProviderModelCurrentlyNotSupportError() - response = {"dataset": dataset, "documents": documents, "batch": batch} - - return response + return DatasetAndDocumentResponse.model_validate( + {"dataset": dataset, "documents": documents, "batch": batch}, from_attributes=True + ).model_dump(mode="json") @console_ns.route("/datasets//documents//indexing-estimate") @@ -997,15 +1074,7 @@ class DocumentMetadataApi(DocumentResource): @console_ns.doc("update_document_metadata") @console_ns.doc(description="Update document metadata") @console_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"}) - @console_ns.expect( - console_ns.model( - "UpdateDocumentMetadataRequest", - { - "doc_type": fields.String(description="Document type"), - "doc_metadata": fields.Raw(description="Document metadata"), - }, - ) - ) + @console_ns.expect(console_ns.models[DocumentMetadataUpdatePayload.__name__]) @console_ns.response(200, "Document metadata updated successfully") @console_ns.response(404, "Document not found") @console_ns.response(403, "Permission denied") @@ -1018,10 +1087,10 @@ class DocumentMetadataApi(DocumentResource): document_id = str(document_id) document = self.get_document(dataset_id, document_id) - req_data = request.get_json() + req_data = DocumentMetadataUpdatePayload.model_validate(request.get_json() or {}) - doc_type = req_data.get("doc_type") - doc_metadata = req_data.get("doc_metadata") + doc_type = req_data.doc_type + doc_metadata = req_data.doc_metadata # The role of the current user in the ta table must be admin, owner, dataset_operator, or editor if not current_user.is_dataset_editor: @@ -1035,7 +1104,7 @@ class DocumentMetadataApi(DocumentResource): if not isinstance(doc_metadata, dict): raise ValueError("doc_metadata must be a dictionary.") - metadata_schema: dict = cast(dict, DocumentService.DOCUMENT_METADATA_SCHEMA[doc_type]) + metadata_schema: dict[str, Any] = cast(dict[str, Any], DocumentService.DOCUMENT_METADATA_SCHEMA[doc_type]) document.doc_metadata = {} if doc_type == "others": @@ -1203,7 +1272,7 @@ class DocumentRenameApi(DocumentResource): @setup_required @login_required @account_initialization_required - @marshal_with(document_model) + @console_ns.response(200, "Document renamed successfully", console_ns.models[DocumentResponse.__name__]) @console_ns.expect(console_ns.models[DocumentRenamePayload.__name__]) def post(self, dataset_id, document_id): # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator @@ -1221,7 +1290,7 @@ class DocumentRenameApi(DocumentResource): except services.errors.document.DocumentIndexingError: raise DocumentIndexingError("Cannot delete document during indexing.") - return document + return DocumentResponse.model_validate(document, from_attributes=True).model_dump(mode="json") @console_ns.route("/datasets//documents//website-sync") diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py index c5f4e3a6e26..2647bb1f5a5 100644 --- a/api/controllers/console/datasets/datasets_segments.py +++ b/api/controllers/console/datasets/datasets_segments.py @@ -2,7 +2,6 @@ import uuid from flask import request from flask_restx import Resource, marshal -from graphon.model_runtime.entities.model_entities import ModelType from pydantic import BaseModel, Field from sqlalchemy import String, cast, func, or_, select from sqlalchemy.dialects.postgresql import JSONB @@ -10,6 +9,7 @@ from werkzeug.exceptions import Forbidden, NotFound import services from configs import dify_config +from controllers.common.controller_schemas import ChildChunkCreatePayload, ChildChunkUpdatePayload from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.app.error import ProviderNotInitializeError @@ -31,6 +31,7 @@ from core.rag.index_processor.constant.index_type import IndexTechniqueType from extensions.ext_database import db from extensions.ext_redis import redis_client from fields.segment_fields import child_chunk_fields, segment_fields +from graphon.model_runtime.entities.model_entities import ModelType from libs.helper import escape_like_pattern from libs.login import current_account_with_tenant, login_required from models.dataset import ChildChunk, DocumentSegment @@ -82,14 +83,6 @@ class BatchImportPayload(BaseModel): upload_file_id: str -class ChildChunkCreatePayload(BaseModel): - content: str - - -class ChildChunkUpdatePayload(BaseModel): - content: str - - class ChildChunkBatchUpdatePayload(BaseModel): chunks: list[ChildChunkUpdateArgs] diff --git a/api/controllers/console/datasets/external.py b/api/controllers/console/datasets/external.py index f3866f6aefd..e513e8c8f90 100644 --- a/api/controllers/console/datasets/external.py +++ b/api/controllers/console/datasets/external.py @@ -227,10 +227,11 @@ class ExternalApiUseCheckApi(Resource): @login_required @account_initialization_required def get(self, external_knowledge_api_id): + _, current_tenant_id = current_account_with_tenant() external_knowledge_api_id = str(external_knowledge_api_id) external_knowledge_api_is_using, count = ExternalDatasetService.external_knowledge_api_use_check( - external_knowledge_api_id + external_knowledge_api_id, current_tenant_id ) return {"is_using": external_knowledge_api_is_using, "count": count}, 200 diff --git a/api/controllers/console/datasets/hit_testing.py b/api/controllers/console/datasets/hit_testing.py index e62be13c2ff..36a7a4bb0e0 100644 --- a/api/controllers/console/datasets/hit_testing.py +++ b/api/controllers/console/datasets/hit_testing.py @@ -1,13 +1,13 @@ -from flask_restx import Resource, fields +from __future__ import annotations -from controllers.common.schema import register_schema_model -from fields.hit_testing_fields import ( - child_chunk_fields, - document_fields, - files_fields, - hit_testing_record_fields, - segment_fields, -) +from datetime import datetime +from typing import Any + +from flask_restx import Resource +from pydantic import Field, field_validator + +from controllers.common.schema import register_schema_models +from fields.base import ResponseModel from libs.login import login_required from .. import console_ns @@ -18,39 +18,92 @@ from ..wraps import ( setup_required, ) -register_schema_model(console_ns, HitTestingPayload) + +def _to_timestamp(value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return int(value.timestamp()) + return value -def _get_or_create_model(model_name: str, field_def): - """Get or create a flask_restx model to avoid dict type issues in Swagger.""" - existing = console_ns.models.get(model_name) - if existing is None: - existing = console_ns.model(model_name, field_def) - return existing +class HitTestingDocument(ResponseModel): + id: str | None = None + data_source_type: str | None = None + name: str | None = None + doc_type: str | None = None + doc_metadata: Any | None = None -# Register models for flask_restx to avoid dict type issues in Swagger -document_model = _get_or_create_model("HitTestingDocument", document_fields) +class HitTestingSegment(ResponseModel): + id: str | None = None + position: int | None = None + document_id: str | None = None + content: str | None = None + sign_content: str | None = None + answer: str | None = None + word_count: int | None = None + tokens: int | None = None + keywords: list[str] = Field(default_factory=list) + index_node_id: str | None = None + index_node_hash: str | None = None + hit_count: int | None = None + enabled: bool | None = None + disabled_at: int | None = None + disabled_by: str | None = None + status: str | None = None + created_by: str | None = None + created_at: int | None = None + indexing_at: int | None = None + completed_at: int | None = None + error: str | None = None + stopped_at: int | None = None + document: HitTestingDocument | None = None -segment_fields_copy = segment_fields.copy() -segment_fields_copy["document"] = fields.Nested(document_model) -segment_model = _get_or_create_model("HitTestingSegment", segment_fields_copy) + @field_validator("disabled_at", "created_at", "indexing_at", "completed_at", "stopped_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) -child_chunk_model = _get_or_create_model("HitTestingChildChunk", child_chunk_fields) -files_model = _get_or_create_model("HitTestingFile", files_fields) -hit_testing_record_fields_copy = hit_testing_record_fields.copy() -hit_testing_record_fields_copy["segment"] = fields.Nested(segment_model) -hit_testing_record_fields_copy["child_chunks"] = fields.List(fields.Nested(child_chunk_model)) -hit_testing_record_fields_copy["files"] = fields.List(fields.Nested(files_model)) -hit_testing_record_model = _get_or_create_model("HitTestingRecord", hit_testing_record_fields_copy) +class HitTestingChildChunk(ResponseModel): + id: str | None = None + content: str | None = None + position: int | None = None + score: float | None = None -# Response model for hit testing API -hit_testing_response_fields = { - "query": fields.String, - "records": fields.List(fields.Nested(hit_testing_record_model)), -} -hit_testing_response_model = _get_or_create_model("HitTestingResponse", hit_testing_response_fields) + +class HitTestingFile(ResponseModel): + id: str | None = None + name: str | None = None + size: int | None = None + extension: str | None = None + mime_type: str | None = None + source_url: str | None = None + + +class HitTestingRecord(ResponseModel): + segment: HitTestingSegment | None = None + child_chunks: list[HitTestingChildChunk] = Field(default_factory=list) + score: float | None = None + tsne_position: Any | None = None + files: list[HitTestingFile] = Field(default_factory=list) + summary: str | None = None + + +class HitTestingResponse(ResponseModel): + query: str + records: list[HitTestingRecord] = Field(default_factory=list) + + +register_schema_models( + console_ns, + HitTestingPayload, + HitTestingDocument, + HitTestingSegment, + HitTestingChildChunk, + HitTestingFile, + HitTestingRecord, + HitTestingResponse, +) @console_ns.route("/datasets//hit-testing") @@ -59,7 +112,11 @@ class HitTestingApi(Resource, DatasetsHitTestingBase): @console_ns.doc(description="Test dataset knowledge retrieval") @console_ns.doc(params={"dataset_id": "Dataset ID"}) @console_ns.expect(console_ns.models[HitTestingPayload.__name__]) - @console_ns.response(200, "Hit testing completed successfully", model=hit_testing_response_model) + @console_ns.response( + 200, + "Hit testing completed successfully", + model=console_ns.models[HitTestingResponse.__name__], + ) @console_ns.response(404, "Dataset not found") @console_ns.response(400, "Invalid parameters") @setup_required @@ -74,4 +131,4 @@ class HitTestingApi(Resource, DatasetsHitTestingBase): args = payload.model_dump(exclude_none=True) self.hit_testing_args_check(args) - return self.perform_hit_testing(dataset, args) + return HitTestingResponse.model_validate(self.perform_hit_testing(dataset, args)).model_dump(mode="json") diff --git a/api/controllers/console/datasets/hit_testing_base.py b/api/controllers/console/datasets/hit_testing_base.py index 8fb3699849e..699fa599c8a 100644 --- a/api/controllers/console/datasets/hit_testing_base.py +++ b/api/controllers/console/datasets/hit_testing_base.py @@ -2,7 +2,6 @@ import logging from typing import Any from flask_restx import marshal -from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, Field from werkzeug.exceptions import Forbidden, InternalServerError, NotFound @@ -21,6 +20,7 @@ from core.errors.error import ( QuotaExceededError, ) from fields.hit_testing_fields import hit_testing_record_fields +from graphon.model_runtime.errors.invoke import InvokeError from libs.login import current_user from models.account import Account from services.dataset_service import DatasetService diff --git a/api/controllers/console/datasets/metadata.py b/api/controllers/console/datasets/metadata.py index 2e69ddc5abc..d966e1629e7 100644 --- a/api/controllers/console/datasets/metadata.py +++ b/api/controllers/console/datasets/metadata.py @@ -1,9 +1,9 @@ from typing import Literal from flask_restx import Resource, marshal_with -from pydantic import BaseModel from werkzeug.exceptions import NotFound +from controllers.common.controller_schemas import MetadataUpdatePayload from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required @@ -18,11 +18,6 @@ from services.entities.knowledge_entities.knowledge_entities import ( ) from services.metadata_service import MetadataService - -class MetadataUpdatePayload(BaseModel): - name: str - - register_schema_models( console_ns, MetadataArgs, MetadataOperationData, MetadataUpdatePayload, DocumentMetadataOperation, MetadataDetail ) diff --git a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py index bdf83b991e5..fd0a8b33bc7 100644 --- a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py +++ b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py @@ -2,8 +2,6 @@ from typing import Any from flask import make_response, redirect, request from flask_restx import Resource -from graphon.model_runtime.errors.validate import CredentialsValidateFailedError -from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, Field from werkzeug.exceptions import Forbidden, NotFound @@ -12,6 +10,8 @@ from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required from core.plugin.impl.oauth import OAuthHandler +from graphon.model_runtime.errors.validate import CredentialsValidateFailedError +from graphon.model_runtime.utils.encoders import jsonable_encoder from libs.login import current_account_with_tenant, login_required from models.provider_ids import DatasourceProviderID from services.datasource_provider_service import DatasourceProviderService diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py index 1758bad31da..4fe96902575 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py @@ -3,6 +3,7 @@ import logging from flask import request from flask_restx import Resource from pydantic import BaseModel, Field +from sqlalchemy import select from sqlalchemy.orm import sessionmaker from controllers.common.schema import register_schema_models @@ -86,8 +87,8 @@ class CustomizedPipelineTemplateApi(Resource): @enterprise_license_required def post(self, template_id: str): with sessionmaker(db.engine, expire_on_commit=False).begin() as session: - template = ( - session.query(PipelineCustomizedTemplate).where(PipelineCustomizedTemplate.id == template_id).first() + template = session.scalar( + select(PipelineCustomizedTemplate).where(PipelineCustomizedTemplate.id == template_id).limit(1) ) if not template: raise ValueError("Customized pipeline template not found.") diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py index d635dcb5301..b31d73f27db 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py @@ -1,9 +1,9 @@ import logging +from collections.abc import Callable from typing import Any, NoReturn from flask import Response, request from flask_restx import Resource, marshal, marshal_with -from graphon.variables.types import SegmentType from pydantic import BaseModel, Field from sqlalchemy.orm import sessionmaker from werkzeug.exceptions import Forbidden @@ -27,6 +27,7 @@ from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTE from extensions.ext_database import db from factories.file_factory import build_from_mapping, build_from_mappings from factories.variable_factory import build_segment_with_type +from graphon.variables.types import SegmentType from libs.login import current_user, login_required from models import Account from models.dataset import Pipeline @@ -55,7 +56,7 @@ class WorkflowDraftVariablePatchPayload(BaseModel): register_schema_models(console_ns, WorkflowDraftVariablePatchPayload) -def _api_prerequisite(f): +def _api_prerequisite[**P, R](f: Callable[P, R]) -> Callable[P, R | Response]: """Common prerequisites for all draft workflow variable APIs. It ensures the following conditions are satisfied: @@ -70,7 +71,7 @@ def _api_prerequisite(f): @login_required @account_initialization_required @get_rag_pipeline - def wrapper(*args, **kwargs): + def wrapper(*args: P.args, **kwargs: P.kwargs) -> R | Response: if not isinstance(current_user, Account) or not current_user.has_edit_permission: raise Forbidden() return f(*args, **kwargs) @@ -222,24 +223,27 @@ class RagPipelineVariableApi(Resource): new_value = None if raw_value is not None: - if variable.value_type == SegmentType.FILE: - if not isinstance(raw_value, dict): - raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}") - raw_value = build_from_mapping( - mapping=raw_value, - tenant_id=pipeline.tenant_id, - access_controller=_file_access_controller, - ) - elif variable.value_type == SegmentType.ARRAY_FILE: - if not isinstance(raw_value, list): - raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}") - if len(raw_value) > 0 and not isinstance(raw_value[0], dict): - raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}") - raw_value = build_from_mappings( - mappings=raw_value, - tenant_id=pipeline.tenant_id, - access_controller=_file_access_controller, - ) + match variable.value_type: + case SegmentType.FILE: + if not isinstance(raw_value, dict): + raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}") + raw_value = build_from_mapping( + mapping=raw_value, + tenant_id=pipeline.tenant_id, + access_controller=_file_access_controller, + ) + case SegmentType.ARRAY_FILE: + if not isinstance(raw_value, list): + raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}") + if len(raw_value) > 0 and not isinstance(raw_value[0], dict): + raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}") + raw_value = build_from_mappings( + mappings=raw_value, + tenant_id=pipeline.tenant_id, + access_controller=_file_access_controller, + ) + case _: + pass new_value = build_segment_with_type(variable.value_type, raw_value) draft_var_srv.update_variable(variable, name=new_name, value=new_value) db.session.commit() diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py index 732a6dc4468..aa27458176c 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_import.py @@ -19,7 +19,7 @@ from fields.rag_pipeline_fields import ( ) from libs.login import current_account_with_tenant, login_required from models.dataset import Pipeline -from services.app_dsl_service import ImportStatus +from services.entities.dsl_entities import ImportStatus from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService @@ -83,11 +83,13 @@ class RagPipelineImportApi(Resource): # Return appropriate status code based on result status = result.status - if status == ImportStatus.FAILED: - return result.model_dump(mode="json"), 400 - elif status == ImportStatus.PENDING: - return result.model_dump(mode="json"), 202 - return result.model_dump(mode="json"), 200 + match status: + case ImportStatus.FAILED: + return result.model_dump(mode="json"), 400 + case ImportStatus.PENDING: + return result.model_dump(mode="json"), 202 + case ImportStatus.COMPLETED | ImportStatus.COMPLETED_WITH_WARNINGS: + return result.model_dump(mode="json"), 200 @console_ns.route("/rag/pipelines/imports//confirm") diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index 70dfe47d7f8..ee146e82871 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -4,12 +4,12 @@ from typing import Any, Literal, cast from flask import abort, request from flask_restx import Resource, marshal_with # type: ignore -from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, Field, ValidationError from sqlalchemy.orm import sessionmaker from werkzeug.exceptions import BadRequest, Forbidden, InternalServerError, NotFound import services +from controllers.common.controller_schemas import DefaultBlockConfigQuery, WorkflowListQuery, WorkflowUpdatePayload from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.app.error import ( @@ -40,6 +40,7 @@ from core.app.apps.pipeline.pipeline_generator import PipelineGenerator from core.app.entities.app_invoke_entities import InvokeFrom from extensions.ext_database import db from factories import variable_factory +from graphon.model_runtime.utils.encoders import jsonable_encoder from libs import helper from libs.helper import TimestampField, UUIDStrOrEmpty from libs.login import current_account_with_tenant, current_user, login_required @@ -94,22 +95,6 @@ class PublishedWorkflowRunPayload(DraftWorkflowRunPayload): original_document_id: str | None = None -class DefaultBlockConfigQuery(BaseModel): - q: str | None = None - - -class WorkflowListQuery(BaseModel): - page: int = Field(default=1, ge=1, le=99999) - limit: int = Field(default=10, ge=1, le=100) - user_id: str | None = None - named_only: bool = False - - -class WorkflowUpdatePayload(BaseModel): - marked_name: str | None = Field(default=None, max_length=20) - marked_comment: str | None = Field(default=None, max_length=100) - - class NodeIdQuery(BaseModel): node_id: str @@ -361,89 +346,6 @@ class PublishedRagPipelineRunApi(Resource): raise InvokeRateLimitHttpError(ex.description) -# class RagPipelinePublishedDatasourceNodeRunStatusApi(Resource): -# @setup_required -# @login_required -# @account_initialization_required -# @get_rag_pipeline -# def post(self, pipeline: Pipeline, node_id: str): -# """ -# Run rag pipeline datasource -# """ -# # The role of the current user in the ta table must be admin, owner, or editor -# if not current_user.has_edit_permission: -# raise Forbidden() -# -# if not isinstance(current_user, Account): -# raise Forbidden() -# -# parser = (reqparse.RequestParser() -# .add_argument("job_id", type=str, required=True, nullable=False, location="json") -# .add_argument("datasource_type", type=str, required=True, location="json") -# ) -# args = parser.parse_args() -# -# job_id = args.get("job_id") -# if job_id == None: -# raise ValueError("missing job_id") -# datasource_type = args.get("datasource_type") -# if datasource_type == None: -# raise ValueError("missing datasource_type") -# -# rag_pipeline_service = RagPipelineService() -# result = rag_pipeline_service.run_datasource_workflow_node_status( -# pipeline=pipeline, -# node_id=node_id, -# job_id=job_id, -# account=current_user, -# datasource_type=datasource_type, -# is_published=True -# ) -# -# return result - - -# class RagPipelineDraftDatasourceNodeRunStatusApi(Resource): -# @setup_required -# @login_required -# @account_initialization_required -# @get_rag_pipeline -# def post(self, pipeline: Pipeline, node_id: str): -# """ -# Run rag pipeline datasource -# """ -# # The role of the current user in the ta table must be admin, owner, or editor -# if not current_user.has_edit_permission: -# raise Forbidden() -# -# if not isinstance(current_user, Account): -# raise Forbidden() -# -# parser = (reqparse.RequestParser() -# .add_argument("job_id", type=str, required=True, nullable=False, location="json") -# .add_argument("datasource_type", type=str, required=True, location="json") -# ) -# args = parser.parse_args() -# -# job_id = args.get("job_id") -# if job_id == None: -# raise ValueError("missing job_id") -# datasource_type = args.get("datasource_type") -# if datasource_type == None: -# raise ValueError("missing datasource_type") -# -# rag_pipeline_service = RagPipelineService() -# result = rag_pipeline_service.run_datasource_workflow_node_status( -# pipeline=pipeline, -# node_id=node_id, -# job_id=job_id, -# account=current_user, -# datasource_type=datasource_type, -# is_published=False -# ) -# -# return result -# @console_ns.route("/rag/pipelines//workflows/published/datasource/nodes//run") class RagPipelinePublishedDatasourceNodeRunApi(Resource): @console_ns.expect(console_ns.models[DatasourceNodeRunPayload.__name__]) diff --git a/api/controllers/console/explore/audio.py b/api/controllers/console/explore/audio.py index b1b01b5f51c..ab660d9dc3a 100644 --- a/api/controllers/console/explore/audio.py +++ b/api/controllers/console/explore/audio.py @@ -1,11 +1,10 @@ import logging from flask import request -from graphon.model_runtime.errors.invoke import InvokeError -from pydantic import BaseModel, Field from werkzeug.exceptions import InternalServerError import services +from controllers.common.controller_schemas import TextToAudioPayload from controllers.common.schema import register_schema_model from controllers.console.app.error import ( AppUnavailableError, @@ -20,6 +19,7 @@ from controllers.console.app.error import ( ) from controllers.console.explore.wraps import InstalledAppResource from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError +from graphon.model_runtime.errors.invoke import InvokeError from services.audio_service import AudioService from services.errors.audio import ( AudioTooLargeServiceError, @@ -32,14 +32,6 @@ from .. import console_ns logger = logging.getLogger(__name__) - -class TextToAudioPayload(BaseModel): - message_id: str | None = None - voice: str | None = None - text: str | None = None - streaming: bool | None = Field(default=None, description="Enable streaming response") - - register_schema_model(console_ns, TextToAudioPayload) diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py index eacd7332fe8..ccdccceaa67 100644 --- a/api/controllers/console/explore/completion.py +++ b/api/controllers/console/explore/completion.py @@ -2,7 +2,6 @@ import logging from typing import Any, Literal from uuid import UUID -from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, Field, field_validator from werkzeug.exceptions import InternalServerError, NotFound @@ -26,6 +25,7 @@ from core.errors.error import ( QuotaExceededError, ) from extensions.ext_database import db +from graphon.model_runtime.errors.invoke import InvokeError from libs import helper from libs.datetime_utils import naive_utc_now from libs.login import current_user diff --git a/api/controllers/console/explore/conversation.py b/api/controllers/console/explore/conversation.py index 092f509f1c3..2eb2054e64a 100644 --- a/api/controllers/console/explore/conversation.py +++ b/api/controllers/console/explore/conversation.py @@ -1,10 +1,11 @@ from typing import Any from flask import request -from pydantic import BaseModel, Field, TypeAdapter, model_validator +from pydantic import BaseModel, Field, TypeAdapter from sqlalchemy.orm import sessionmaker from werkzeug.exceptions import NotFound +from controllers.common.controller_schemas import ConversationRenamePayload from controllers.common.schema import register_schema_models from controllers.console.explore.error import NotChatAppError from controllers.console.explore.wraps import InstalledAppResource @@ -32,18 +33,6 @@ class ConversationListQuery(BaseModel): pinned: bool | None = None -class ConversationRenamePayload(BaseModel): - name: str | None = None - auto_generate: bool = False - - @model_validator(mode="after") - def validate_name_requirement(self): - if not self.auto_generate: - if self.name is None or not self.name.strip(): - raise ValueError("name is required when auto_generate is false") - return self - - register_schema_models(console_ns, ConversationListQuery, ConversationRenamePayload) diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py index 0740dd0e24c..2d9a997fbf3 100644 --- a/api/controllers/console/explore/installed_app.py +++ b/api/controllers/console/explore/installed_app.py @@ -1,21 +1,24 @@ import logging +from datetime import datetime from typing import Any from flask import request -from flask_restx import Resource, fields, marshal_with -from pydantic import BaseModel, Field +from flask_restx import Resource +from pydantic import BaseModel, Field, computed_field, field_validator from sqlalchemy import and_, select from werkzeug.exceptions import BadRequest, Forbidden, NotFound -from controllers.common.schema import get_or_create_model +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.explore.wraps import InstalledAppResource from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check from extensions.ext_database import db -from fields.installed_app_fields import app_fields, installed_app_fields, installed_app_list_fields +from fields.base import ResponseModel +from graphon.file import helpers as file_helpers from libs.datetime_utils import naive_utc_now from libs.login import current_account_with_tenant, login_required from models import App, InstalledApp, RecommendedApp +from models.model import IconType from services.account_service import TenantService from services.enterprise.enterprise_service import EnterpriseService from services.feature_service import FeatureService @@ -36,22 +39,97 @@ class InstalledAppsListQuery(BaseModel): logger = logging.getLogger(__name__) -app_model = get_or_create_model("InstalledAppInfo", app_fields) +def _build_icon_url(icon_type: str | IconType | None, icon: str | None) -> str | None: + if icon is None or icon_type is None: + return None + icon_type_value = icon_type.value if isinstance(icon_type, IconType) else str(icon_type) + if icon_type_value.lower() != IconType.IMAGE: + return None + return file_helpers.get_signed_file_url(icon) -installed_app_fields_copy = installed_app_fields.copy() -installed_app_fields_copy["app"] = fields.Nested(app_model) -installed_app_model = get_or_create_model("InstalledApp", installed_app_fields_copy) -installed_app_list_fields_copy = installed_app_list_fields.copy() -installed_app_list_fields_copy["installed_apps"] = fields.List(fields.Nested(installed_app_model)) -installed_app_list_model = get_or_create_model("InstalledAppList", installed_app_list_fields_copy) +def _safe_primitive(value: Any) -> Any: + if value is None or isinstance(value, (str, int, float, bool, datetime)): + return value + return None + + +class InstalledAppInfoResponse(ResponseModel): + id: str + name: str | None = None + mode: str | None = None + icon_type: str | None = None + icon: str | None = None + icon_background: str | None = None + use_icon_as_answer_icon: bool | None = None + + @field_validator("mode", "icon_type", mode="before") + @classmethod + def _normalize_enum_like(cls, value: Any) -> str | None: + if value is None: + return None + if isinstance(value, str): + return value + return str(getattr(value, "value", value)) + + @computed_field(return_type=str | None) # type: ignore[prop-decorator] + @property + def icon_url(self) -> str | None: + return _build_icon_url(self.icon_type, self.icon) + + +class InstalledAppResponse(ResponseModel): + id: str + app: InstalledAppInfoResponse + app_owner_tenant_id: str + is_pinned: bool + last_used_at: int | None = None + editable: bool + uninstallable: bool + + @field_validator("app", mode="before") + @classmethod + def _normalize_app(cls, value: Any) -> Any: + if isinstance(value, dict): + return value + return { + "id": _safe_primitive(getattr(value, "id", "")) or "", + "name": _safe_primitive(getattr(value, "name", None)), + "mode": _safe_primitive(getattr(value, "mode", None)), + "icon_type": _safe_primitive(getattr(value, "icon_type", None)), + "icon": _safe_primitive(getattr(value, "icon", None)), + "icon_background": _safe_primitive(getattr(value, "icon_background", None)), + "use_icon_as_answer_icon": _safe_primitive(getattr(value, "use_icon_as_answer_icon", None)), + } + + @field_validator("last_used_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return int(value.timestamp()) + return value + + +class InstalledAppListResponse(ResponseModel): + installed_apps: list[InstalledAppResponse] + + +register_schema_models( + console_ns, + InstalledAppCreatePayload, + InstalledAppUpdatePayload, + InstalledAppsListQuery, + InstalledAppInfoResponse, + InstalledAppResponse, + InstalledAppListResponse, +) @console_ns.route("/installed-apps") class InstalledAppsListApi(Resource): @login_required @account_initialization_required - @marshal_with(installed_app_list_model) + @console_ns.response(200, "Success", console_ns.models[InstalledAppListResponse.__name__]) def get(self): query = InstalledAppsListQuery.model_validate(request.args.to_dict()) current_user, current_tenant_id = current_account_with_tenant() @@ -125,7 +203,9 @@ class InstalledAppsListApi(Resource): ) ) - return {"installed_apps": installed_app_list} + return InstalledAppListResponse.model_validate( + {"installed_apps": installed_app_list}, from_attributes=True + ).model_dump(mode="json") @login_required @account_initialization_required diff --git a/api/controllers/console/explore/message.py b/api/controllers/console/explore/message.py index fcbefcda33b..209667d1d0d 100644 --- a/api/controllers/console/explore/message.py +++ b/api/controllers/console/explore/message.py @@ -2,10 +2,10 @@ import logging from typing import Literal from flask import request -from graphon.model_runtime.errors.invoke import InvokeError -from pydantic import BaseModel, Field, TypeAdapter +from pydantic import BaseModel, TypeAdapter from werkzeug.exceptions import InternalServerError, NotFound +from controllers.common.controller_schemas import MessageFeedbackPayload, MessageListQuery from controllers.common.schema import register_schema_models from controllers.console.app.error import ( AppMoreLikeThisDisabledError, @@ -24,8 +24,8 @@ from core.app.entities.app_invoke_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from fields.conversation_fields import ResultResponse from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem, SuggestedQuestionsResponse +from graphon.model_runtime.errors.invoke import InvokeError from libs import helper -from libs.helper import UUIDStrOrEmpty from libs.login import current_account_with_tenant from models.enums import FeedbackRating from models.model import AppMode @@ -44,17 +44,6 @@ from .. import console_ns logger = logging.getLogger(__name__) -class MessageListQuery(BaseModel): - conversation_id: UUIDStrOrEmpty - first_id: UUIDStrOrEmpty | None = None - limit: int = Field(default=20, ge=1, le=100) - - -class MessageFeedbackPayload(BaseModel): - rating: Literal["like", "dislike"] | None = None - content: str | None = None - - class MoreLikeThisQuery(BaseModel): response_mode: Literal["blocking", "streaming"] diff --git a/api/controllers/console/explore/recommended_app.py b/api/controllers/console/explore/recommended_app.py index c9920c97cfe..55bd679b489 100644 --- a/api/controllers/console/explore/recommended_app.py +++ b/api/controllers/console/explore/recommended_app.py @@ -1,66 +1,83 @@ +from typing import Any + from flask import request -from flask_restx import Resource, fields, marshal_with -from pydantic import BaseModel, Field +from flask_restx import Resource +from pydantic import BaseModel, Field, computed_field, field_validator from constants.languages import languages -from controllers.common.schema import get_or_create_model +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.wraps import account_initialization_required -from libs.helper import AppIconUrlField +from fields.base import ResponseModel +from libs.helper import build_icon_url from libs.login import current_user, login_required from services.recommended_app_service import RecommendedAppService -app_fields = { - "id": fields.String, - "name": fields.String, - "mode": fields.String, - "icon": fields.String, - "icon_type": fields.String, - "icon_url": AppIconUrlField, - "icon_background": fields.String, -} - -app_model = get_or_create_model("RecommendedAppInfo", app_fields) - -recommended_app_fields = { - "app": fields.Nested(app_model, attribute="app"), - "app_id": fields.String, - "description": fields.String(attribute="description"), - "copyright": fields.String, - "privacy_policy": fields.String, - "custom_disclaimer": fields.String, - "category": fields.String, - "position": fields.Integer, - "is_listed": fields.Boolean, - "can_trial": fields.Boolean, -} - -recommended_app_model = get_or_create_model("RecommendedApp", recommended_app_fields) - -recommended_app_list_fields = { - "recommended_apps": fields.List(fields.Nested(recommended_app_model)), - "categories": fields.List(fields.String), -} - -recommended_app_list_model = get_or_create_model("RecommendedAppList", recommended_app_list_fields) - class RecommendedAppsQuery(BaseModel): language: str | None = Field(default=None) -console_ns.schema_model( - RecommendedAppsQuery.__name__, - RecommendedAppsQuery.model_json_schema(ref_template="#/definitions/{model}"), +class RecommendedAppInfoResponse(ResponseModel): + id: str + name: str | None = None + mode: str | None = None + icon: str | None = None + icon_type: str | None = None + icon_background: str | None = None + + @staticmethod + def _normalize_enum_like(value: Any) -> str | None: + if value is None: + return None + if isinstance(value, str): + return value + return str(getattr(value, "value", value)) + + @field_validator("mode", "icon_type", mode="before") + @classmethod + def _normalize_enum_fields(cls, value: Any) -> str | None: + return cls._normalize_enum_like(value) + + @computed_field(return_type=str | None) # type: ignore[prop-decorator] + @property + def icon_url(self) -> str | None: + return build_icon_url(self.icon_type, self.icon) + + +class RecommendedAppResponse(ResponseModel): + app: RecommendedAppInfoResponse | None = None + app_id: str + description: str | None = None + copyright: str | None = None + privacy_policy: str | None = None + custom_disclaimer: str | None = None + category: str | None = None + position: int | None = None + is_listed: bool | None = None + can_trial: bool | None = None + + +class RecommendedAppListResponse(ResponseModel): + recommended_apps: list[RecommendedAppResponse] + categories: list[str] + + +register_schema_models( + console_ns, + RecommendedAppsQuery, + RecommendedAppInfoResponse, + RecommendedAppResponse, + RecommendedAppListResponse, ) @console_ns.route("/explore/apps") class RecommendedAppListApi(Resource): @console_ns.expect(console_ns.models[RecommendedAppsQuery.__name__]) + @console_ns.response(200, "Success", console_ns.models[RecommendedAppListResponse.__name__]) @login_required @account_initialization_required - @marshal_with(recommended_app_list_model) def get(self): # language args args = RecommendedAppsQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore @@ -72,7 +89,10 @@ class RecommendedAppListApi(Resource): else: language_prefix = languages[0] - return RecommendedAppService.get_recommended_apps_and_categories(language_prefix) + return RecommendedAppListResponse.model_validate( + RecommendedAppService.get_recommended_apps_and_categories(language_prefix), + from_attributes=True, + ).model_dump(mode="json") @console_ns.route("/explore/apps/") diff --git a/api/controllers/console/explore/saved_message.py b/api/controllers/console/explore/saved_message.py index ea3de917416..9ec4e823248 100644 --- a/api/controllers/console/explore/saved_message.py +++ b/api/controllers/console/explore/saved_message.py @@ -1,28 +1,18 @@ from flask import request -from pydantic import BaseModel, Field, TypeAdapter +from pydantic import TypeAdapter from werkzeug.exceptions import NotFound +from controllers.common.controller_schemas import SavedMessageCreatePayload, SavedMessageListQuery from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.explore.error import NotCompletionAppError from controllers.console.explore.wraps import InstalledAppResource from fields.conversation_fields import ResultResponse from fields.message_fields import SavedMessageInfiniteScrollPagination, SavedMessageItem -from libs.helper import UUIDStrOrEmpty from libs.login import current_account_with_tenant from services.errors.message import MessageNotExistsError from services.saved_message_service import SavedMessageService - -class SavedMessageListQuery(BaseModel): - last_id: UUIDStrOrEmpty | None = None - limit: int = Field(default=20, ge=1, le=100) - - -class SavedMessageCreatePayload(BaseModel): - message_id: UUIDStrOrEmpty - - register_schema_models(console_ns, SavedMessageListQuery, SavedMessageCreatePayload) diff --git a/api/controllers/console/explore/trial.py b/api/controllers/console/explore/trial.py index e432574434d..1456301a247 100644 --- a/api/controllers/console/explore/trial.py +++ b/api/controllers/console/explore/trial.py @@ -3,8 +3,6 @@ from typing import Any, Literal, cast from flask import request from flask_restx import Resource, fields, marshal, marshal_with -from graphon.graph_engine.manager import GraphEngineManager -from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel from sqlalchemy import select from werkzeug.exceptions import Forbidden, InternalServerError, NotFound @@ -61,6 +59,8 @@ from fields.workflow_fields import ( workflow_fields, workflow_partial_fields, ) +from graphon.graph_engine.manager import GraphEngineManager +from graphon.model_runtime.errors.invoke import InvokeError from libs import helper from libs.helper import uuid_value from libs.login import current_user @@ -169,6 +169,7 @@ console_ns.schema_model( class TrialAppWorkflowRunApi(TrialAppResource): + @trial_feature_enable @console_ns.expect(console_ns.models[WorkflowRunRequest.__name__]) def post(self, trial_app): """ @@ -210,6 +211,7 @@ class TrialAppWorkflowRunApi(TrialAppResource): class TrialAppWorkflowTaskStopApi(TrialAppResource): + @trial_feature_enable def post(self, trial_app, task_id: str): """ Stop workflow task @@ -290,7 +292,6 @@ class TrialChatApi(TrialAppResource): class TrialMessageSuggestedQuestionApi(TrialAppResource): - @trial_feature_enable def get(self, trial_app, message_id): app_model = trial_app app_mode = AppMode.value_of(app_model.mode) @@ -470,7 +471,6 @@ class TrialCompletionApi(TrialAppResource): class TrialSitApi(Resource): """Resource for trial app sites.""" - @trial_feature_enable @get_app_model_with_trial(None) def get(self, app_model): """Retrieve app site info. @@ -492,7 +492,6 @@ class TrialSitApi(Resource): class TrialAppParameterApi(Resource): """Resource for app variables.""" - @trial_feature_enable @get_app_model_with_trial(None) def get(self, app_model): """Retrieve app parameters.""" @@ -521,7 +520,6 @@ class TrialAppParameterApi(Resource): class AppApi(Resource): - @trial_feature_enable @get_app_model_with_trial(None) @marshal_with(app_detail_with_site_model) def get(self, app_model): @@ -534,7 +532,6 @@ class AppApi(Resource): class AppWorkflowApi(Resource): - @trial_feature_enable @get_app_model_with_trial(None) @marshal_with(workflow_model) def get(self, app_model): @@ -547,7 +544,6 @@ class AppWorkflowApi(Resource): class DatasetListApi(Resource): - @trial_feature_enable @get_app_model_with_trial(None) def get(self, app_model): page = request.args.get("page", default=1, type=int) diff --git a/api/controllers/console/explore/workflow.py b/api/controllers/console/explore/workflow.py index 42cafc71932..438cce4fd86 100644 --- a/api/controllers/console/explore/workflow.py +++ b/api/controllers/console/explore/workflow.py @@ -1,11 +1,8 @@ import logging -from typing import Any -from graphon.graph_engine.manager import GraphEngineManager -from graphon.model_runtime.errors.invoke import InvokeError -from pydantic import BaseModel from werkzeug.exceptions import InternalServerError +from controllers.common.controller_schemas import WorkflowRunPayload from controllers.common.schema import register_schema_model from controllers.console.app.error import ( CompletionRequestError, @@ -24,6 +21,8 @@ from core.errors.error import ( QuotaExceededError, ) from extensions.ext_redis import redis_client +from graphon.graph_engine.manager import GraphEngineManager +from graphon.model_runtime.errors.invoke import InvokeError from libs import helper from libs.login import current_account_with_tenant from models.model import AppMode, InstalledApp @@ -34,12 +33,6 @@ from .. import console_ns logger = logging.getLogger(__name__) - -class WorkflowRunPayload(BaseModel): - inputs: dict[str, Any] - files: list[dict[str, Any]] | None = None - - register_schema_model(console_ns, WorkflowRunPayload) diff --git a/api/controllers/console/extension.py b/api/controllers/console/extension.py index efa46c9779a..7a6356d052f 100644 --- a/api/controllers/console/extension.py +++ b/api/controllers/console/extension.py @@ -1,15 +1,18 @@ +from datetime import datetime +from typing import Any + from flask import request -from flask_restx import Resource, fields, marshal_with -from pydantic import BaseModel, Field +from flask_restx import Resource +from pydantic import BaseModel, Field, TypeAdapter, field_validator from constants import HIDDEN_VALUE -from fields.api_based_extension_fields import api_based_extension_fields +from fields.base import ResponseModel from libs.login import current_account_with_tenant, login_required from models.api_based_extension import APIBasedExtension from services.api_based_extension_service import APIBasedExtensionService from services.code_based_extension_service import CodeBasedExtensionService -from ..common.schema import register_schema_models +from ..common.schema import DEFAULT_REF_TEMPLATE_SWAGGER_2_0, register_schema_models from . import console_ns from .wraps import account_initialization_required, setup_required @@ -24,12 +27,52 @@ class APIBasedExtensionPayload(BaseModel): api_key: str = Field(description="API key for authentication") -register_schema_models(console_ns, APIBasedExtensionPayload) +class CodeBasedExtensionResponse(ResponseModel): + module: str = Field(description="Module name") + data: Any = Field(description="Extension data") -api_based_extension_model = console_ns.model("ApiBasedExtensionModel", api_based_extension_fields) +def _mask_api_key(api_key: str) -> str: + if not api_key: + return api_key + if len(api_key) <= 8: + return api_key[0] + "******" + api_key[-1] + return api_key[:3] + "******" + api_key[-3:] -api_based_extension_list_model = fields.List(fields.Nested(api_based_extension_model)) + +def _to_timestamp(value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return int(value.timestamp()) + return value + + +class APIBasedExtensionResponse(ResponseModel): + id: str + name: str + api_endpoint: str + api_key: str + created_at: int | None = None + + @field_validator("api_key", mode="before") + @classmethod + def _normalize_api_key(cls, value: str) -> str: + return _mask_api_key(value) + + @field_validator("created_at", mode="before") + @classmethod + def _normalize_created_at(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +register_schema_models(console_ns, APIBasedExtensionPayload, CodeBasedExtensionResponse, APIBasedExtensionResponse) +console_ns.schema_model( + "APIBasedExtensionListResponse", + TypeAdapter(list[APIBasedExtensionResponse]).json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), +) + + +def _serialize_api_based_extension(extension: APIBasedExtension) -> dict[str, Any]: + return APIBasedExtensionResponse.model_validate(extension, from_attributes=True).model_dump(mode="json") @console_ns.route("/code-based-extension") @@ -40,10 +83,7 @@ class CodeBasedExtensionAPI(Resource): @console_ns.response( 200, "Success", - console_ns.model( - "CodeBasedExtensionResponse", - {"module": fields.String(description="Module name"), "data": fields.Raw(description="Extension data")}, - ), + console_ns.models[CodeBasedExtensionResponse.__name__], ) @setup_required @login_required @@ -51,30 +91,34 @@ class CodeBasedExtensionAPI(Resource): def get(self): query = CodeBasedExtensionQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore - return {"module": query.module, "data": CodeBasedExtensionService.get_code_based_extension(query.module)} + return CodeBasedExtensionResponse( + module=query.module, + data=CodeBasedExtensionService.get_code_based_extension(query.module), + ).model_dump(mode="json") @console_ns.route("/api-based-extension") class APIBasedExtensionAPI(Resource): @console_ns.doc("get_api_based_extensions") @console_ns.doc(description="Get all API-based extensions for current tenant") - @console_ns.response(200, "Success", api_based_extension_list_model) + @console_ns.response(200, "Success", console_ns.models["APIBasedExtensionListResponse"]) @setup_required @login_required @account_initialization_required - @marshal_with(api_based_extension_model) def get(self): _, tenant_id = current_account_with_tenant() - return APIBasedExtensionService.get_all_by_tenant_id(tenant_id) + return [ + _serialize_api_based_extension(extension) + for extension in APIBasedExtensionService.get_all_by_tenant_id(tenant_id) + ] @console_ns.doc("create_api_based_extension") @console_ns.doc(description="Create a new API-based extension") @console_ns.expect(console_ns.models[APIBasedExtensionPayload.__name__]) - @console_ns.response(201, "Extension created successfully", api_based_extension_model) + @console_ns.response(201, "Extension created successfully", console_ns.models[APIBasedExtensionResponse.__name__]) @setup_required @login_required @account_initialization_required - @marshal_with(api_based_extension_model) def post(self): payload = APIBasedExtensionPayload.model_validate(console_ns.payload or {}) _, current_tenant_id = current_account_with_tenant() @@ -86,7 +130,7 @@ class APIBasedExtensionAPI(Resource): api_key=payload.api_key, ) - return APIBasedExtensionService.save(extension_data) + return _serialize_api_based_extension(APIBasedExtensionService.save(extension_data)) @console_ns.route("/api-based-extension/") @@ -94,26 +138,26 @@ class APIBasedExtensionDetailAPI(Resource): @console_ns.doc("get_api_based_extension") @console_ns.doc(description="Get API-based extension by ID") @console_ns.doc(params={"id": "Extension ID"}) - @console_ns.response(200, "Success", api_based_extension_model) + @console_ns.response(200, "Success", console_ns.models[APIBasedExtensionResponse.__name__]) @setup_required @login_required @account_initialization_required - @marshal_with(api_based_extension_model) def get(self, id): api_based_extension_id = str(id) _, tenant_id = current_account_with_tenant() - return APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id) + return _serialize_api_based_extension( + APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id) + ) @console_ns.doc("update_api_based_extension") @console_ns.doc(description="Update API-based extension") @console_ns.doc(params={"id": "Extension ID"}) @console_ns.expect(console_ns.models[APIBasedExtensionPayload.__name__]) - @console_ns.response(200, "Extension updated successfully", api_based_extension_model) + @console_ns.response(200, "Extension updated successfully", console_ns.models[APIBasedExtensionResponse.__name__]) @setup_required @login_required @account_initialization_required - @marshal_with(api_based_extension_model) def post(self, id): api_based_extension_id = str(id) _, current_tenant_id = current_account_with_tenant() @@ -128,7 +172,7 @@ class APIBasedExtensionDetailAPI(Resource): if payload.api_key != HIDDEN_VALUE: extension_data_from_db.api_key = payload.api_key - return APIBasedExtensionService.save(extension_data_from_db) + return _serialize_api_based_extension(APIBasedExtensionService.save(extension_data_from_db)) @console_ns.doc("delete_api_based_extension") @console_ns.doc(description="Delete API-based extension") diff --git a/api/controllers/console/human_input_form.py b/api/controllers/console/human_input_form.py index e37e78c966f..845af37365c 100644 --- a/api/controllers/console/human_input_form.py +++ b/api/controllers/console/human_input_form.py @@ -7,7 +7,8 @@ import logging from collections.abc import Generator from flask import Response, jsonify, request -from flask_restx import Resource, reqparse +from flask_restx import Resource +from pydantic import BaseModel from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker @@ -33,6 +34,11 @@ from services.workflow_event_snapshot_service import build_workflow_event_stream logger = logging.getLogger(__name__) +class HumanInputFormSubmitPayload(BaseModel): + inputs: dict + action: str + + def _jsonify_form_definition(form: Form) -> Response: payload = form.get_definition().model_dump() payload["expiration_time"] = int(form.expiration_time.timestamp()) @@ -84,10 +90,7 @@ class ConsoleHumanInputFormApi(Resource): "action": "Approve" } """ - parser = reqparse.RequestParser() - parser.add_argument("inputs", type=dict, required=True, location="json") - parser.add_argument("action", type=str, required=True, location="json") - args = parser.parse_args() + payload = HumanInputFormSubmitPayload.model_validate(request.get_json()) current_user, _ = current_account_with_tenant() service = HumanInputService(db.engine) @@ -107,8 +110,8 @@ class ConsoleHumanInputFormApi(Resource): service.submit_form_by_token( recipient_type=recipient_type, form_token=form_token, - selected_action_id=args["action"], - form_data=args["inputs"], + selected_action_id=payload.action, + form_data=payload.inputs, submission_user_id=current_user.id, ) @@ -168,12 +171,13 @@ class ConsoleWorkflowEventsApi(Resource): else: msg_generator = MessageGenerator() generator: BaseAppGenerator - if app.mode == AppMode.ADVANCED_CHAT: - generator = AdvancedChatAppGenerator() - elif app.mode == AppMode.WORKFLOW: - generator = WorkflowAppGenerator() - else: - raise InvalidArgumentError(f"cannot subscribe to workflow run, workflow_run_id={workflow_run.id}") + match app.mode: + case AppMode.ADVANCED_CHAT: + generator = AdvancedChatAppGenerator() + case AppMode.WORKFLOW: + generator = WorkflowAppGenerator() + case _: + raise InvalidArgumentError(f"cannot subscribe to workflow run, workflow_run_id={workflow_run.id}") include_state_snapshot = request.args.get("include_state_snapshot", "false").lower() == "true" diff --git a/api/controllers/console/notification.py b/api/controllers/console/notification.py index 53e4aa3d86c..5d464701737 100644 --- a/api/controllers/console/notification.py +++ b/api/controllers/console/notification.py @@ -1,3 +1,6 @@ +from collections.abc import Mapping +from typing import TypedDict + from flask import request from flask_restx import Resource from pydantic import BaseModel, Field @@ -11,9 +14,34 @@ from services.billing_service import BillingService _FALLBACK_LANG = "en-US" -def _pick_lang_content(contents: dict, lang: str) -> dict: +class NotificationLangContent(TypedDict, total=False): + lang: str + title: str + subtitle: str + body: str + titlePicUrl: str + + +class NotificationItemDict(TypedDict): + notification_id: str | None + frequency: str | None + lang: str + title: str + subtitle: str + body: str + title_pic_url: str + + +class NotificationResponseDict(TypedDict): + should_show: bool + notifications: list[NotificationItemDict] + + +def _pick_lang_content(contents: Mapping[str, NotificationLangContent], lang: str) -> NotificationLangContent: """Return the single LangContent for *lang*, falling back to English.""" - return contents.get(lang) or contents.get(_FALLBACK_LANG) or next(iter(contents.values()), {}) + return ( + contents.get(lang) or contents.get(_FALLBACK_LANG) or next(iter(contents.values()), NotificationLangContent()) + ) class DismissNotificationPayload(BaseModel): @@ -45,28 +73,30 @@ class NotificationApi(Resource): result = BillingService.get_account_notification(str(current_user.id)) # Proto JSON uses camelCase field names (Kratos default marshaling). + response: NotificationResponseDict if not result.get("shouldShow"): - return {"should_show": False, "notifications": []}, 200 + response = {"should_show": False, "notifications": []} + return response, 200 lang = current_user.interface_language or _FALLBACK_LANG - notifications = [] + notifications: list[NotificationItemDict] = [] for notification in result.get("notifications") or []: - contents: dict = notification.get("contents") or {} + contents: Mapping[str, NotificationLangContent] = notification.get("contents") or {} lang_content = _pick_lang_content(contents, lang) - notifications.append( - { - "notification_id": notification.get("notificationId"), - "frequency": notification.get("frequency"), - "lang": lang_content.get("lang", lang), - "title": lang_content.get("title", ""), - "subtitle": lang_content.get("subtitle", ""), - "body": lang_content.get("body", ""), - "title_pic_url": lang_content.get("titlePicUrl", ""), - } - ) + item: NotificationItemDict = { + "notification_id": notification.get("notificationId"), + "frequency": notification.get("frequency"), + "lang": lang_content.get("lang", lang), + "title": lang_content.get("title", ""), + "subtitle": lang_content.get("subtitle", ""), + "body": lang_content.get("body", ""), + "title_pic_url": lang_content.get("titlePicUrl", ""), + } + notifications.append(item) - return {"should_show": bool(notifications), "notifications": notifications}, 200 + response = {"should_show": bool(notifications), "notifications": notifications} + return response, 200 @console_ns.route("/notification/dismiss") diff --git a/api/controllers/console/remote_files.py b/api/controllers/console/remote_files.py index 551c86fd827..2a46d2250a0 100644 --- a/api/controllers/console/remote_files.py +++ b/api/controllers/console/remote_files.py @@ -2,7 +2,6 @@ import urllib.parse import httpx from flask_restx import Resource -from graphon.file import helpers as file_helpers from pydantic import BaseModel, Field import services @@ -16,6 +15,7 @@ from controllers.console import console_ns from core.helper import ssrf_proxy from extensions.ext_database import db from fields.file_fields import FileWithSignedUrl, RemoteFileInfo +from graphon.file import helpers as file_helpers from libs.login import current_account_with_tenant, login_required from services.file_service import FileService diff --git a/api/controllers/console/socketio/__init__.py b/api/controllers/console/socketio/__init__.py new file mode 100644 index 00000000000..8b137891791 --- /dev/null +++ b/api/controllers/console/socketio/__init__.py @@ -0,0 +1 @@ + diff --git a/api/controllers/console/socketio/workflow.py b/api/controllers/console/socketio/workflow.py new file mode 100644 index 00000000000..b4f03593fd7 --- /dev/null +++ b/api/controllers/console/socketio/workflow.py @@ -0,0 +1,108 @@ +import logging +from collections.abc import Callable +from typing import cast + +from flask import Request as FlaskRequest + +from extensions.ext_socketio import sio +from libs.passport import PassportService +from libs.token import extract_access_token +from repositories.workflow_collaboration_repository import WorkflowCollaborationRepository +from services.account_service import AccountService +from services.workflow_collaboration_service import WorkflowCollaborationService + +repository = WorkflowCollaborationRepository() +collaboration_service = WorkflowCollaborationService(repository, sio) + + +def _sio_on(event: str) -> Callable[[Callable[..., object]], Callable[..., object]]: + return cast(Callable[[Callable[..., object]], Callable[..., object]], sio.on(event)) + + +@_sio_on("connect") +def socket_connect(sid, environ, auth): + """ + WebSocket connect event, do authentication here. + """ + try: + request_environ = FlaskRequest(environ) + token = extract_access_token(request_environ) + except Exception: + logging.exception("Failed to extract token") + token = None + + if not token: + logging.warning("Socket connect rejected: missing token (sid=%s)", sid) + return False + + try: + decoded = PassportService().verify(token) + user_id = decoded.get("user_id") + if not user_id: + logging.warning("Socket connect rejected: missing user_id (sid=%s)", sid) + return False + + with sio.app.app_context(): + user = AccountService.load_logged_in_account(account_id=user_id) + if not user: + logging.warning("Socket connect rejected: user not found (user_id=%s, sid=%s)", user_id, sid) + return False + if not user.has_edit_permission: + logging.warning("Socket connect rejected: no edit permission (user_id=%s, sid=%s)", user_id, sid) + return False + + collaboration_service.save_socket_identity(sid, user) + return True + + except Exception: + logging.exception("Socket authentication failed") + return False + + +@_sio_on("user_connect") +def handle_user_connect(sid, data): + """ + Handle user connect event. Each session (tab) is treated as an independent collaborator. + """ + workflow_id = data.get("workflow_id") + if not workflow_id: + return {"msg": "workflow_id is required"}, 400 + + result = collaboration_service.authorize_and_join_workflow_room(workflow_id, sid) + if not result: + return {"msg": "unauthorized"}, 401 + + user_id, is_leader = result + return {"msg": "connected", "user_id": user_id, "sid": sid, "isLeader": is_leader} + + +@_sio_on("disconnect") +def handle_disconnect(sid): + """ + Handle session disconnect event. Remove the specific session from online users. + """ + collaboration_service.disconnect_session(sid) + + +@_sio_on("collaboration_event") +def handle_collaboration_event(sid, data): + """ + Handle general collaboration events, include: + 1. mouse_move + 2. vars_and_features_update + 3. sync_request (ask leader to update graph) + 4. app_state_update + 5. mcp_server_update + 6. workflow_update + 7. comments_update + 8. node_panel_presence + """ + return collaboration_service.relay_collaboration_event(sid, data) + + +@_sio_on("graph_event") +def handle_graph_event(sid, data): + """ + Handle graph events - simple broadcast relay. + """ + return collaboration_service.relay_graph_event(sid, data) diff --git a/api/controllers/console/tag/tags.py b/api/controllers/console/tag/tags.py index 7511c970a36..614bf03ea51 100644 --- a/api/controllers/console/tag/tags.py +++ b/api/controllers/console/tag/tags.py @@ -1,43 +1,40 @@ from typing import Literal from flask import request -from flask_restx import Namespace, Resource, fields, marshal_with -from pydantic import BaseModel, Field +from flask_restx import Resource +from pydantic import BaseModel, Field, field_validator from werkzeug.exceptions import Forbidden from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required +from fields.base import ResponseModel from libs.login import current_account_with_tenant, login_required -from services.tag_service import TagService - -dataset_tag_fields = { - "id": fields.String, - "name": fields.String, - "type": fields.String, - "binding_count": fields.String, -} - - -def build_dataset_tag_fields(api_or_ns: Namespace): - return api_or_ns.model("DataSetTag", dataset_tag_fields) +from models.enums import TagType +from services.tag_service import ( + SaveTagPayload, + TagBindingCreatePayload, + TagBindingDeletePayload, + TagService, + UpdateTagPayload, +) class TagBasePayload(BaseModel): name: str = Field(description="Tag name", min_length=1, max_length=50) - type: Literal["knowledge", "app"] | None = Field(default=None, description="Tag type") + type: TagType = Field(description="Tag type") class TagBindingPayload(BaseModel): tag_ids: list[str] = Field(description="Tag IDs to bind") target_id: str = Field(description="Target ID to bind tags to") - type: Literal["knowledge", "app"] | None = Field(default=None, description="Tag type") + type: TagType = Field(description="Tag type") class TagBindingRemovePayload(BaseModel): tag_id: str = Field(description="Tag ID to remove") target_id: str = Field(description="Target ID to unbind tag from") - type: Literal["knowledge", "app"] | None = Field(default=None, description="Tag type") + type: TagType = Field(description="Tag type") class TagListQueryParam(BaseModel): @@ -45,12 +42,36 @@ class TagListQueryParam(BaseModel): keyword: str | None = Field(None, description="Search keyword") +class TagResponse(ResponseModel): + id: str + name: str + type: str | None = None + binding_count: str | None = None + + @field_validator("type", mode="before") + @classmethod + def normalize_type(cls, value: TagType | str | None) -> str | None: + if value is None: + return None + if isinstance(value, TagType): + return value.value + return value + + @field_validator("binding_count", mode="before") + @classmethod + def normalize_binding_count(cls, value: int | str | None) -> str | None: + if value is None: + return None + return str(value) + + register_schema_models( console_ns, TagBasePayload, TagBindingPayload, TagBindingRemovePayload, TagListQueryParam, + TagResponse, ) @@ -62,14 +83,18 @@ class TagListApi(Resource): @console_ns.doc( params={"type": 'Tag type filter. Can be "knowledge" or "app".', "keyword": "Search keyword for tag name."} ) - @marshal_with(dataset_tag_fields) + @console_ns.doc(responses={200: ("Success", [console_ns.models[TagResponse.__name__]])}) def get(self): _, current_tenant_id = current_account_with_tenant() raw_args = request.args.to_dict() param = TagListQueryParam.model_validate(raw_args) tags = TagService.get_tags(param.type, current_tenant_id, param.keyword) - return tags, 200 + serialized_tags = [ + TagResponse.model_validate(tag, from_attributes=True).model_dump(mode="json") for tag in tags + ] + + return serialized_tags, 200 @console_ns.expect(console_ns.models[TagBasePayload.__name__]) @setup_required @@ -82,9 +107,11 @@ class TagListApi(Resource): raise Forbidden() payload = TagBasePayload.model_validate(console_ns.payload or {}) - tag = TagService.save_tags(payload.model_dump()) + tag = TagService.save_tags(SaveTagPayload(name=payload.name, type=payload.type)) - response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0} + response = TagResponse.model_validate( + {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0} + ).model_dump(mode="json") return response, 200 @@ -103,11 +130,13 @@ class TagUpdateDeleteApi(Resource): raise Forbidden() payload = TagBasePayload.model_validate(console_ns.payload or {}) - tag = TagService.update_tags(payload.model_dump(), tag_id) + tag = TagService.update_tags(UpdateTagPayload(name=payload.name, type=payload.type), tag_id) binding_count = TagService.get_tag_binding_count(tag_id) - response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count} + response = TagResponse.model_validate( + {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count} + ).model_dump(mode="json") return response, 200 @@ -136,7 +165,9 @@ class TagBindingCreateApi(Resource): raise Forbidden() payload = TagBindingPayload.model_validate(console_ns.payload or {}) - TagService.save_tag_binding(payload.model_dump()) + TagService.save_tag_binding( + TagBindingCreatePayload(tag_ids=payload.tag_ids, target_id=payload.target_id, type=payload.type) + ) return {"result": "success"}, 200 @@ -154,6 +185,8 @@ class TagBindingDeleteApi(Resource): raise Forbidden() payload = TagBindingRemovePayload.model_validate(console_ns.payload or {}) - TagService.delete_tag_binding(payload.model_dump()) + TagService.delete_tag_binding( + TagBindingDeletePayload(tag_id=payload.tag_id, target_id=payload.target_id, type=payload.type) + ) return {"result": "success"}, 200 diff --git a/api/controllers/console/workspace/__init__.py b/api/controllers/console/workspace/__init__.py index 971674cee2b..59dd29fdace 100644 --- a/api/controllers/console/workspace/__init__.py +++ b/api/controllers/console/workspace/__init__.py @@ -1,6 +1,7 @@ from collections.abc import Callable from functools import wraps +from sqlalchemy import select from sqlalchemy.orm import sessionmaker from werkzeug.exceptions import Forbidden @@ -21,12 +22,12 @@ def plugin_permission_required( tenant_id = current_tenant_id with sessionmaker(db.engine).begin() as session: - permission = ( - session.query(TenantPluginPermission) + permission = session.scalar( + select(TenantPluginPermission) .where( TenantPluginPermission.tenant_id == tenant_id, ) - .first() + .limit(1) ) if not permission: @@ -34,22 +35,24 @@ def plugin_permission_required( return view(*args, **kwargs) if install_required: - if permission.install_permission == TenantPluginPermission.InstallPermission.NOBODY: - raise Forbidden() - if permission.install_permission == TenantPluginPermission.InstallPermission.ADMINS: - if not user.is_admin_or_owner: + match permission.install_permission: + case TenantPluginPermission.InstallPermission.NOBODY: raise Forbidden() - if permission.install_permission == TenantPluginPermission.InstallPermission.EVERYONE: - pass + case TenantPluginPermission.InstallPermission.ADMINS: + if not user.is_admin_or_owner: + raise Forbidden() + case TenantPluginPermission.InstallPermission.EVERYONE: + pass if debug_required: - if permission.debug_permission == TenantPluginPermission.DebugPermission.NOBODY: - raise Forbidden() - if permission.debug_permission == TenantPluginPermission.DebugPermission.ADMINS: - if not user.is_admin_or_owner: + match permission.debug_permission: + case TenantPluginPermission.DebugPermission.NOBODY: raise Forbidden() - if permission.debug_permission == TenantPluginPermission.DebugPermission.EVERYONE: - pass + case TenantPluginPermission.DebugPermission.ADMINS: + if not user.is_admin_or_owner: + raise Forbidden() + case TenantPluginPermission.DebugPermission.EVERYONE: + pass return view(*args, **kwargs) diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py index 626d330e9de..c01286cc592 100644 --- a/api/controllers/console/workspace/account.py +++ b/api/controllers/console/workspace/account.py @@ -1,14 +1,13 @@ from __future__ import annotations from datetime import datetime -from typing import Literal +from typing import Any, Literal import pytz from flask import request -from flask_restx import Resource, fields, marshal_with +from flask_restx import Resource from pydantic import BaseModel, Field, field_validator, model_validator from sqlalchemy import select -from sqlalchemy.orm import sessionmaker from configs import dify_config from constants.languages import supported_language @@ -38,9 +37,11 @@ from controllers.console.wraps import ( setup_required, ) from extensions.ext_database import db +from fields.base import ResponseModel from fields.member_fields import Account as AccountResponse +from graphon.file import helpers as file_helpers from libs.datetime_utils import naive_utc_now -from libs.helper import EmailStr, TimestampField, extract_remote_ip, timezone +from libs.helper import EmailStr, extract_remote_ip, timezone from libs.login import current_account_with_tenant, login_required from models import AccountIntegrate, InvitationCode from models.account import AccountStatus, InvitationCodeStatus @@ -75,6 +76,10 @@ class AccountAvatarPayload(BaseModel): avatar: str +class AccountAvatarQuery(BaseModel): + avatar: str = Field(..., description="Avatar file ID") + + class AccountInterfaceLanguagePayload(BaseModel): interface_language: str @@ -160,6 +165,7 @@ def reg(cls: type[BaseModel]): reg(AccountInitPayload) reg(AccountNamePayload) reg(AccountAvatarPayload) +reg(AccountAvatarQuery) reg(AccountInterfaceLanguagePayload) reg(AccountInterfaceThemePayload) reg(AccountTimezonePayload) @@ -175,21 +181,61 @@ reg(CheckEmailUniquePayload) register_schema_models(console_ns, AccountResponse) -def _serialize_account(account) -> dict: +def _serialize_account(account) -> dict[str, Any]: return AccountResponse.model_validate(account, from_attributes=True).model_dump(mode="json") -integrate_fields = { - "provider": fields.String, - "created_at": TimestampField, - "is_bound": fields.Boolean, - "link": fields.String, -} +def _to_timestamp(value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return int(value.timestamp()) + return value -integrate_model = console_ns.model("AccountIntegrate", integrate_fields) -integrate_list_model = console_ns.model( - "AccountIntegrateList", - {"data": fields.List(fields.Nested(integrate_model))}, + +class AccountIntegrateResponse(ResponseModel): + provider: str + created_at: int | None = None + is_bound: bool + link: str | None = None + + @field_validator("created_at", mode="before") + @classmethod + def _normalize_created_at(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +class AccountIntegrateListResponse(ResponseModel): + data: list[AccountIntegrateResponse] + + +class EducationVerifyResponse(ResponseModel): + token: str | None = None + + +class EducationStatusResponse(ResponseModel): + result: bool | None = None + is_student: bool | None = None + expire_at: int | None = None + allow_refresh: bool | None = None + + @field_validator("expire_at", mode="before") + @classmethod + def _normalize_expire_at(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +class EducationAutocompleteResponse(ResponseModel): + data: list[str] = Field(default_factory=list) + curr_page: int | None = None + has_next: bool | None = None + + +register_schema_models( + console_ns, + AccountIntegrateResponse, + AccountIntegrateListResponse, + EducationVerifyResponse, + EducationStatusResponse, + EducationAutocompleteResponse, ) @@ -269,6 +315,18 @@ class AccountNameApi(Resource): @console_ns.route("/account/avatar") class AccountAvatarApi(Resource): + @console_ns.expect(console_ns.models[AccountAvatarQuery.__name__]) + @console_ns.doc("get_account_avatar") + @console_ns.doc(description="Get account avatar url") + @setup_required + @login_required + @account_initialization_required + def get(self): + args = AccountAvatarQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + + avatar_url = file_helpers.get_signed_file_url(args.avatar) + return {"avatar_url": avatar_url} + @console_ns.expect(console_ns.models[AccountAvatarPayload.__name__]) @setup_required @login_required @@ -360,7 +418,7 @@ class AccountIntegrateApi(Resource): @setup_required @login_required @account_initialization_required - @marshal_with(integrate_list_model) + @console_ns.response(200, "Success", console_ns.models[AccountIntegrateListResponse.__name__]) def get(self): account, _ = current_account_with_tenant() @@ -396,7 +454,9 @@ class AccountIntegrateApi(Resource): } ) - return {"data": integrate_data} + return AccountIntegrateListResponse( + data=[AccountIntegrateResponse.model_validate(item) for item in integrate_data] + ).model_dump(mode="json") @console_ns.route("/account/delete/verify") @@ -448,31 +508,22 @@ class AccountDeleteUpdateFeedbackApi(Resource): @console_ns.route("/account/education/verify") class EducationVerifyApi(Resource): - verify_fields = { - "token": fields.String, - } - @setup_required @login_required @account_initialization_required @only_edition_cloud @cloud_edition_billing_enabled - @marshal_with(verify_fields) + @console_ns.response(200, "Success", console_ns.models[EducationVerifyResponse.__name__]) def get(self): account, _ = current_account_with_tenant() - return BillingService.EducationIdentity.verify(account.id, account.email) + return EducationVerifyResponse.model_validate( + BillingService.EducationIdentity.verify(account.id, account.email) or {} + ).model_dump(mode="json") @console_ns.route("/account/education") class EducationApi(Resource): - status_fields = { - "result": fields.Boolean, - "is_student": fields.Boolean, - "expire_at": TimestampField, - "allow_refresh": fields.Boolean, - } - @console_ns.expect(console_ns.models[EducationActivatePayload.__name__]) @setup_required @login_required @@ -492,37 +543,33 @@ class EducationApi(Resource): @account_initialization_required @only_edition_cloud @cloud_edition_billing_enabled - @marshal_with(status_fields) + @console_ns.response(200, "Success", console_ns.models[EducationStatusResponse.__name__]) def get(self): account, _ = current_account_with_tenant() - res = BillingService.EducationIdentity.status(account.id) + res = BillingService.EducationIdentity.status(account.id) or {} # convert expire_at to UTC timestamp from isoformat if res and "expire_at" in res: res["expire_at"] = datetime.fromisoformat(res["expire_at"]).astimezone(pytz.utc) - return res + return EducationStatusResponse.model_validate(res).model_dump(mode="json") @console_ns.route("/account/education/autocomplete") class EducationAutoCompleteApi(Resource): - data_fields = { - "data": fields.List(fields.String), - "curr_page": fields.Integer, - "has_next": fields.Boolean, - } - @console_ns.expect(console_ns.models[EducationAutocompleteQuery.__name__]) @setup_required @login_required @account_initialization_required @only_edition_cloud @cloud_edition_billing_enabled - @marshal_with(data_fields) + @console_ns.response(200, "Success", console_ns.models[EducationAutocompleteResponse.__name__]) def get(self): payload = request.args.to_dict(flat=True) args = EducationAutocompleteQuery.model_validate(payload) - return BillingService.EducationIdentity.autocomplete(args.keywords, args.page, args.limit) + return EducationAutocompleteResponse.model_validate( + BillingService.EducationIdentity.autocomplete(args.keywords, args.page, args.limit) or {} + ).model_dump(mode="json") @console_ns.route("/account/change-email") @@ -548,13 +595,25 @@ class ChangeEmailSendEmailApi(Resource): account = None user_email = None email_for_sending = args.email.lower() - if args.phase is not None and args.phase == "new_email": + # Default to the initial phase; any legacy/unexpected client input is + # coerced back to `old_email` so we never trust the caller to declare + # later phases without a verified predecessor token. + send_phase = AccountService.CHANGE_EMAIL_PHASE_OLD + if args.phase is not None and args.phase == AccountService.CHANGE_EMAIL_PHASE_NEW: + send_phase = AccountService.CHANGE_EMAIL_PHASE_NEW if args.token is None: raise InvalidTokenError() reset_data = AccountService.get_change_email_data(args.token) if reset_data is None: raise InvalidTokenError() + + # The token used to request a new-email code must come from the + # old-email verification step. This prevents the bypass described + # in GHSA-4q3w-q5mc-45rq where the phase-1 token was reused here. + token_phase = reset_data.get(AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY) + if token_phase != AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED: + raise InvalidTokenError() user_email = reset_data.get("email", "") if user_email.lower() != current_user.email.lower(): @@ -562,8 +621,7 @@ class ChangeEmailSendEmailApi(Resource): user_email = current_user.email else: - with sessionmaker(db.engine).begin() as session: - account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session) + account = AccountService.get_account_by_email_with_case_fallback(args.email) if account is None: raise AccountNotFound() email_for_sending = account.email @@ -574,7 +632,7 @@ class ChangeEmailSendEmailApi(Resource): email=email_for_sending, old_email=user_email, language=language, - phase=args.phase, + phase=send_phase, ) return {"result": "success", "data": token} @@ -609,12 +667,31 @@ class ChangeEmailCheckApi(Resource): AccountService.add_change_email_error_rate_limit(user_email) raise EmailCodeError() + # Only advance tokens that were minted by the matching send-code step; + # refuse tokens that have already progressed or lack a phase marker so + # the chain `old_email -> old_email_verified -> new_email -> new_email_verified` + # is strictly enforced. + phase_transitions = { + AccountService.CHANGE_EMAIL_PHASE_OLD: AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED, + AccountService.CHANGE_EMAIL_PHASE_NEW: AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED, + } + token_phase = token_data.get(AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY) + if not isinstance(token_phase, str): + raise InvalidTokenError() + refreshed_phase = phase_transitions.get(token_phase) + if refreshed_phase is None: + raise InvalidTokenError() + # Verified, revoke the first token AccountService.revoke_change_email_token(args.token) - # Refresh token data by generating a new token + # Refresh token data by generating a new token that carries the + # upgraded phase so later steps can check it. _, new_token = AccountService.generate_change_email_token( - user_email, code=args.code, old_email=token_data.get("old_email"), additional_data={} + user_email, + code=args.code, + old_email=token_data.get("old_email"), + additional_data={AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: refreshed_phase}, ) AccountService.reset_change_email_error_rate_limit(user_email) @@ -644,13 +721,29 @@ class ChangeEmailResetApi(Resource): if not reset_data: raise InvalidTokenError() - AccountService.revoke_change_email_token(args.token) + # Only tokens that completed both verification phases may be used to + # change the email. This closes GHSA-4q3w-q5mc-45rq where a token from + # the initial send-code step could be replayed directly here. + token_phase = reset_data.get(AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY) + if token_phase != AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED: + raise InvalidTokenError() + + # Bind the new email to the token that was mailed and verified, so a + # verified token cannot be reused with a different `new_email` value. + token_email = reset_data.get("email") + normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email + if normalized_token_email != normalized_new_email: + raise InvalidTokenError() old_email = reset_data.get("old_email", "") current_user, _ = current_account_with_tenant() if current_user.email.lower() != old_email.lower(): raise AccountNotFound() + # Revoke only after all checks pass so failed attempts don't burn a + # legitimately verified token. + AccountService.revoke_change_email_token(args.token) + updated_account = AccountService.update_account_email(current_user, email=normalized_new_email) AccountService.send_change_email_completed_notify_email( diff --git a/api/controllers/console/workspace/agent_providers.py b/api/controllers/console/workspace/agent_providers.py index 3fdcbc47108..764f4887558 100644 --- a/api/controllers/console/workspace/agent_providers.py +++ b/api/controllers/console/workspace/agent_providers.py @@ -1,8 +1,8 @@ from flask_restx import Resource, fields -from graphon.model_runtime.utils.encoders import jsonable_encoder from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, setup_required +from graphon.model_runtime.utils.encoders import jsonable_encoder from libs.login import current_account_with_tenant, login_required from services.agent_service import AgentService diff --git a/api/controllers/console/workspace/endpoint.py b/api/controllers/console/workspace/endpoint.py index b6b9deb1f92..f45b72f3904 100644 --- a/api/controllers/console/workspace/endpoint.py +++ b/api/controllers/console/workspace/endpoint.py @@ -2,13 +2,13 @@ from typing import Any from flask import request from flask_restx import Resource -from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, Field from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required from core.plugin.impl.exc import PluginPermissionDeniedError +from graphon.model_runtime.utils.encoders import jsonable_encoder from libs.login import current_account_with_tenant, login_required from services.plugin.endpoint_service import EndpointService diff --git a/api/controllers/console/workspace/load_balancing_config.py b/api/controllers/console/workspace/load_balancing_config.py index e4cfca9fa4c..2a6f37aec81 100644 --- a/api/controllers/console/workspace/load_balancing_config.py +++ b/api/controllers/console/workspace/load_balancing_config.py @@ -1,12 +1,12 @@ from flask_restx import Resource -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.errors.validate import CredentialsValidateFailedError from pydantic import BaseModel from werkzeug.exceptions import Forbidden from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, setup_required +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.errors.validate import CredentialsValidateFailedError from libs.login import current_account_with_tenant, login_required from models import TenantAccountRole from services.model_load_balancing_service import ModelLoadBalancingService diff --git a/api/controllers/console/workspace/model_providers.py b/api/controllers/console/workspace/model_providers.py index cbb9677309e..4b10561fdb6 100644 --- a/api/controllers/console/workspace/model_providers.py +++ b/api/controllers/console/workspace/model_providers.py @@ -3,13 +3,13 @@ from typing import Any, Literal from flask import request, send_file from flask_restx import Resource -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.errors.validate import CredentialsValidateFailedError -from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, Field, field_validator from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.errors.validate import CredentialsValidateFailedError +from graphon.model_runtime.utils.encoders import jsonable_encoder from libs.helper import uuid_value from libs.login import current_account_with_tenant, login_required from services.billing_service import BillingService diff --git a/api/controllers/console/workspace/models.py b/api/controllers/console/workspace/models.py index 9182dbb510f..b2d07ff8f96 100644 --- a/api/controllers/console/workspace/models.py +++ b/api/controllers/console/workspace/models.py @@ -3,14 +3,14 @@ from typing import Any, cast from flask import request from flask_restx import Resource -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.errors.validate import CredentialsValidateFailedError -from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, Field, field_validator from controllers.common.schema import register_enum_models, register_schema_models from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.errors.validate import CredentialsValidateFailedError +from graphon.model_runtime.utils.encoders import jsonable_encoder from libs.helper import uuid_value from libs.login import current_account_with_tenant, login_required from services.model_load_balancing_service import ModelLoadBalancingService @@ -465,7 +465,7 @@ class ModelProviderModelDisableApi(Resource): class ParserValidate(BaseModel): model: str model_type: ModelType - credentials: dict + credentials: dict[str, Any] console_ns.schema_model( diff --git a/api/controllers/console/workspace/plugin.py b/api/controllers/console/workspace/plugin.py index aa674a63b30..b3e344ccea8 100644 --- a/api/controllers/console/workspace/plugin.py +++ b/api/controllers/console/workspace/plugin.py @@ -4,7 +4,6 @@ from typing import Any, Literal from flask import request, send_file from flask_restx import Resource -from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, Field from werkzeug.datastructures import FileStorage from werkzeug.exceptions import Forbidden @@ -15,6 +14,7 @@ from controllers.console import console_ns from controllers.console.workspace import plugin_permission_required from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required from core.plugin.impl.exc import PluginDaemonClientSideError +from graphon.model_runtime.utils.encoders import jsonable_encoder from libs.login import current_account_with_tenant, login_required from models.account import TenantPluginAutoUpgradeStrategy, TenantPluginPermission from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index c9956501e28..34c9534de80 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -5,7 +5,6 @@ from urllib.parse import urlparse from flask import make_response, redirect, request, send_file from flask_restx import Resource -from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, Field, HttpUrl, field_validator, model_validator from sqlalchemy.orm import sessionmaker from werkzeug.exceptions import Forbidden @@ -28,6 +27,7 @@ from core.plugin.entities.plugin_daemon import CredentialType from core.plugin.impl.oauth import OAuthHandler from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration from extensions.ext_database import db +from graphon.model_runtime.utils.encoders import jsonable_encoder from libs.helper import alphanumeric, uuid_value from libs.login import current_account_with_tenant, login_required from models.provider_ids import ToolProviderID @@ -1131,6 +1131,14 @@ class ToolMCPAuthApi(Resource): with sessionmaker(db.engine).begin() as session: service = MCPToolManageService(session=session) service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id) + parsed = urlparse(server_url) + sanitized_url = f"{parsed.scheme}://{parsed.hostname}{parsed.path}" + logger.warning( + "MCP authorization failed for provider %s (url=%s)", + provider_id, + sanitized_url, + exc_info=True, + ) raise ValueError(f"Failed to connect to MCP server: {e}") from e diff --git a/api/controllers/console/workspace/trigger_providers.py b/api/controllers/console/workspace/trigger_providers.py index 7a28a09861c..d11b66244fc 100644 --- a/api/controllers/console/workspace/trigger_providers.py +++ b/api/controllers/console/workspace/trigger_providers.py @@ -3,7 +3,6 @@ from typing import Any from flask import make_response, redirect, request from flask_restx import Resource -from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, model_validator from sqlalchemy.orm import sessionmaker from werkzeug.exceptions import BadRequest, Forbidden @@ -16,6 +15,7 @@ from core.plugin.impl.oauth import OAuthHandler from core.trigger.entities.entities import SubscriptionBuilderUpdater from core.trigger.trigger_manager import TriggerManager from extensions.ext_database import db +from graphon.model_runtime.utils.encoders import jsonable_encoder from libs.login import current_user, login_required from models.account import Account from models.provider_ids import TriggerProviderID diff --git a/api/controllers/console/workspace/workspace.py b/api/controllers/console/workspace/workspace.py index a06b4fd195a..565099db615 100644 --- a/api/controllers/console/workspace/workspace.py +++ b/api/controllers/console/workspace/workspace.py @@ -1,8 +1,9 @@ import logging +from datetime import datetime from flask import request -from flask_restx import Resource, fields, marshal, marshal_with -from pydantic import BaseModel, Field +from flask_restx import Resource, fields, marshal +from pydantic import BaseModel, Field, field_validator from sqlalchemy import select from werkzeug.exceptions import Unauthorized @@ -26,9 +27,10 @@ from controllers.console.wraps import ( ) from enums.cloud_plan import CloudPlan from extensions.ext_database import db +from fields.base import ResponseModel from libs.helper import TimestampField from libs.login import current_account_with_tenant, login_required -from models.account import Tenant, TenantStatus +from models.account import Tenant, TenantCustomConfigDict, TenantStatus from services.account_service import TenantService from services.billing_service import BillingService, SubscriptionPlan from services.enterprise.enterprise_service import EnterpriseService @@ -58,6 +60,37 @@ class WorkspaceInfoPayload(BaseModel): name: str +class TenantInfoResponse(ResponseModel): + id: str + name: str | None = None + plan: str | None = None + status: str | None = None + created_at: int | None = None + role: str | None = None + in_trial: bool | None = None + trial_end_reason: str | None = None + custom_config: dict | None = None + trial_credits: int | None = None + trial_credits_used: int | None = None + next_credit_reset_date: int | None = None + + @field_validator("plan", "status", "trial_end_reason", mode="before") + @classmethod + def _normalize_enum_like(cls, value): + if value is None: + return None + if isinstance(value, str): + return value + return str(getattr(value, "value", value)) + + @field_validator("created_at", mode="before") + @classmethod + def _normalize_created_at(cls, value: datetime | int | None): + if isinstance(value, datetime): + return int(value.timestamp()) + return value + + def reg(cls: type[BaseModel]): console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) @@ -66,6 +99,7 @@ reg(WorkspaceListQuery) reg(SwitchWorkspacePayload) reg(WorkspaceCustomConfigPayload) reg(WorkspaceInfoPayload) +reg(TenantInfoResponse) provider_fields = { "provider_name": fields.String, @@ -180,7 +214,7 @@ class TenantApi(Resource): @setup_required @login_required @account_initialization_required - @marshal_with(tenant_fields) + @console_ns.response(200, "Success", console_ns.models[TenantInfoResponse.__name__]) def post(self): if request.path == "/info": logger.warning("Deprecated URL /info was used.") @@ -200,7 +234,13 @@ class TenantApi(Resource): else: raise Unauthorized("workspace is archived") - return WorkspaceService.get_tenant_info(tenant), 200 + return ( + TenantInfoResponse.model_validate( + WorkspaceService.get_tenant_info(tenant), + from_attributes=True, + ).model_dump(mode="json"), + 200, + ) @console_ns.route("/workspaces/switch") @@ -240,8 +280,10 @@ class CustomConfigWorkspaceApi(Resource): args = WorkspaceCustomConfigPayload.model_validate(payload) tenant = db.get_or_404(Tenant, current_tenant_id) - custom_config_dict = { - "remove_webapp_brand": args.remove_webapp_brand, + custom_config_dict: TenantCustomConfigDict = { + "remove_webapp_brand": args.remove_webapp_brand + if args.remove_webapp_brand is not None + else tenant.custom_config_dict.get("remove_webapp_brand", False), "replace_webapp_logo": args.replace_webapp_logo if args.replace_webapp_logo is not None else tenant.custom_config_dict.get("replace_webapp_logo"), diff --git a/api/controllers/console/wraps.py b/api/controllers/console/wraps.py index 4b5fb7ca5b9..ef2931ce9b9 100644 --- a/api/controllers/console/wraps.py +++ b/api/controllers/console/wraps.py @@ -20,7 +20,7 @@ from models.account import AccountStatus from models.dataset import RateLimitLog from models.model import DifySetup from services.feature_service import FeatureService, LicenseStatus -from services.operation_service import OperationService +from services.operation_service import OperationService, UtmInfo from .error import NotInitValidateError, NotSetupError, UnauthorizedAndForceLogout @@ -205,7 +205,7 @@ def cloud_utm_record[**P, R](view: Callable[P, R]) -> Callable[P, R]: utm_info = request.cookies.get("utm_info") if utm_info: - utm_info_dict: dict = json.loads(utm_info) + utm_info_dict: UtmInfo = json.loads(utm_info) OperationService.record_utm(current_tenant_id, utm_info_dict) return view(*args, **kwargs) diff --git a/api/controllers/inner_api/app/dsl.py b/api/controllers/inner_api/app/dsl.py index 3b673d6e1d3..915a11dcddc 100644 --- a/api/controllers/inner_api/app/dsl.py +++ b/api/controllers/inner_api/app/dsl.py @@ -18,7 +18,8 @@ from controllers.inner_api.wraps import enterprise_inner_api_only from extensions.ext_database import db from models import Account, App from models.account import AccountStatus -from services.app_dsl_service import AppDslService, ImportMode, ImportStatus +from services.app_dsl_service import AppDslService +from services.entities.dsl_entities import ImportMode, ImportStatus class InnerAppDSLImportPayload(BaseModel): @@ -55,7 +56,7 @@ class EnterpriseAppDSLImport(Resource): account.set_tenant_id(workspace_id) - with Session(db.engine) as session: + with Session(db.engine, expire_on_commit=False) as session: dsl_service = AppDslService(session) result = dsl_service.import_app( account=account, @@ -64,7 +65,10 @@ class EnterpriseAppDSLImport(Resource): name=args.name, description=args.description, ) - session.commit() + if result.status == ImportStatus.FAILED: + session.rollback() + else: + session.commit() if result.status == ImportStatus.FAILED: return result.model_dump(mode="json"), 400 diff --git a/api/controllers/inner_api/plugin/plugin.py b/api/controllers/inner_api/plugin/plugin.py index 83c8fa02fee..72cab3de737 100644 --- a/api/controllers/inner_api/plugin/plugin.py +++ b/api/controllers/inner_api/plugin/plugin.py @@ -1,5 +1,4 @@ from flask_restx import Resource -from graphon.model_runtime.utils.encoders import jsonable_encoder from controllers.console.wraps import setup_required from controllers.inner_api import inner_api_ns @@ -30,6 +29,7 @@ from core.plugin.entities.request import ( ) from core.tools.entities.tool_entities import ToolProviderType from core.tools.signature import get_signed_file_url_for_plugin +from graphon.model_runtime.utils.encoders import jsonable_encoder from libs.helper import length_prefixed_response from models import Account, Tenant from models.model import EndUser diff --git a/api/controllers/inner_api/plugin/wraps.py b/api/controllers/inner_api/plugin/wraps.py index 1d378c754c6..2f309262cb8 100644 --- a/api/controllers/inner_api/plugin/wraps.py +++ b/api/controllers/inner_api/plugin/wraps.py @@ -20,10 +20,13 @@ class TenantUserPayload(BaseModel): def get_user(tenant_id: str, user_id: str | None) -> EndUser: """ - Get current user + Get current user. NOTE: user_id is not trusted, it could be maliciously set to any value. - As a result, it could only be considered as an end user id. + As a result, it could only be considered as an end user id. Even when a + concrete end-user ID is supplied, lookups must stay tenant-scoped so one + tenant cannot bind another tenant's user record into the plugin request + context. """ if not user_id: user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID @@ -42,7 +45,14 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser: .limit(1) ) else: - user_model = session.get(EndUser, user_id) + user_model = session.scalar( + select(EndUser) + .where( + EndUser.id == user_id, + EndUser.tenant_id == tenant_id, + ) + .limit(1) + ) if not user_model: user_model = EndUser( @@ -94,10 +104,9 @@ def get_user_tenant[**P, R](view_func: Callable[P, R]) -> Callable[P, R]: def plugin_data[**P, R]( - view: Callable[P, R] | None = None, *, payload_type: type[BaseModel], -) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]: +) -> Callable[[Callable[P, R]], Callable[P, R]]: def decorator(view_func: Callable[P, R]) -> Callable[P, R]: @wraps(view_func) def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R: @@ -116,7 +125,4 @@ def plugin_data[**P, R]( return decorated_view - if view is None: - return decorator - else: - return decorator(view) + return decorator diff --git a/api/controllers/mcp/mcp.py b/api/controllers/mcp/mcp.py index 3c59535a48f..f652bbc5814 100644 --- a/api/controllers/mcp/mcp.py +++ b/api/controllers/mcp/mcp.py @@ -2,8 +2,8 @@ from typing import Any, Union from flask import Response from flask_restx import Resource -from graphon.variables.input_entities import VariableEntity from pydantic import BaseModel, Field, ValidationError +from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker from controllers.common.schema import register_schema_model @@ -11,6 +11,7 @@ from controllers.mcp import mcp_ns from core.mcp import types as mcp_types from core.mcp.server.streamable_http import handle_mcp_request from extensions.ext_database import db +from graphon.variables.input_entities import VariableEntity, VariableEntityType from libs import helper from models.enums import AppMCPServerStatus from models.model import App, AppMCPServer, AppMode, EndUser @@ -80,11 +81,11 @@ class MCPAppApi(Resource): def _get_mcp_server_and_app(self, server_code: str, session: Session) -> tuple[AppMCPServer, App]: """Get and validate MCP server and app in one query session""" - mcp_server = session.query(AppMCPServer).where(AppMCPServer.server_code == server_code).first() + mcp_server = session.scalar(select(AppMCPServer).where(AppMCPServer.server_code == server_code).limit(1)) if not mcp_server: raise MCPRequestError(mcp_types.INVALID_REQUEST, "Server Not Found") - app = session.query(App).where(App.id == mcp_server.app_id).first() + app = session.scalar(select(App).where(App.id == mcp_server.app_id).limit(1)) if not app: raise MCPRequestError(mcp_types.INVALID_REQUEST, "App Not Found") @@ -157,14 +158,20 @@ class MCPAppApi(Resource): except ValidationError as e: raise MCPRequestError(mcp_types.INVALID_PARAMS, f"Invalid user_input_form: {str(e)}") - def _convert_user_input_form(self, raw_form: list[dict]) -> list[VariableEntity]: + def _convert_user_input_form(self, raw_form: list[dict[str, Any]]) -> list[VariableEntity]: """Convert raw user input form to VariableEntity objects""" return [self._create_variable_entity(item) for item in raw_form] - def _create_variable_entity(self, item: dict) -> VariableEntity: + def _create_variable_entity(self, item: dict[str, Any]) -> VariableEntity: """Create a single VariableEntity from raw form item""" - variable_type = item.get("type", "") or list(item.keys())[0] - variable = item[variable_type] + variable_type_raw: str = item.get("type", "") or list(item.keys())[0] + try: + variable_type = VariableEntityType(variable_type_raw) + except ValueError as e: + raise MCPRequestError( + mcp_types.INVALID_PARAMS, f"Invalid user_input_form variable type: {variable_type_raw}" + ) from e + variable = item[variable_type_raw] return VariableEntity( type=variable_type, @@ -177,7 +184,7 @@ class MCPAppApi(Resource): json_schema=variable.get("json_schema"), ) - def _parse_mcp_request(self, args: dict) -> mcp_types.ClientRequest | mcp_types.ClientNotification: + def _parse_mcp_request(self, args: dict[str, Any]) -> mcp_types.ClientRequest | mcp_types.ClientNotification: """Parse and validate MCP request""" try: return mcp_types.ClientRequest.model_validate(args) @@ -190,12 +197,12 @@ class MCPAppApi(Resource): def _retrieve_end_user(self, tenant_id: str, mcp_server_id: str) -> EndUser | None: """Get end user - manages its own database session""" with sessionmaker(db.engine, expire_on_commit=False).begin() as session: - return ( - session.query(EndUser) + return session.scalar( + select(EndUser) .where(EndUser.tenant_id == tenant_id) .where(EndUser.session_id == mcp_server_id) .where(EndUser.type == "mcp") - .first() + .limit(1) ) def _create_end_user( diff --git a/api/controllers/service_api/app/annotation.py b/api/controllers/service_api/app/annotation.py index c22190cbc92..00bb9aa4639 100644 --- a/api/controllers/service_api/app/annotation.py +++ b/api/controllers/service_api/app/annotation.py @@ -12,7 +12,12 @@ from controllers.service_api.wraps import validate_app_token from extensions.ext_redis import redis_client from fields.annotation_fields import Annotation, AnnotationList from models.model import App -from services.annotation_service import AppAnnotationService +from services.annotation_service import ( + AppAnnotationService, + EnableAnnotationArgs, + InsertAnnotationArgs, + UpdateAnnotationArgs, +) class AnnotationCreatePayload(BaseModel): @@ -46,10 +51,15 @@ class AnnotationReplyActionApi(Resource): @validate_app_token def post(self, app_model: App, action: Literal["enable", "disable"]): """Enable or disable annotation reply feature.""" - args = AnnotationReplyActionPayload.model_validate(service_api_ns.payload or {}).model_dump() + payload = AnnotationReplyActionPayload.model_validate(service_api_ns.payload or {}) match action: case "enable": - result = AppAnnotationService.enable_app_annotation(args, app_model.id) + enable_args: EnableAnnotationArgs = { + "score_threshold": payload.score_threshold, + "embedding_provider_name": payload.embedding_provider_name, + "embedding_model_name": payload.embedding_model_name, + } + result = AppAnnotationService.enable_app_annotation(enable_args, app_model.id) case "disable": result = AppAnnotationService.disable_app_annotation(app_model.id) return result, 200 @@ -135,8 +145,9 @@ class AnnotationListApi(Resource): @validate_app_token def post(self, app_model: App): """Create a new annotation.""" - args = AnnotationCreatePayload.model_validate(service_api_ns.payload or {}).model_dump() - annotation = AppAnnotationService.insert_app_annotation_directly(args, app_model.id) + payload = AnnotationCreatePayload.model_validate(service_api_ns.payload or {}) + insert_args: InsertAnnotationArgs = {"question": payload.question, "answer": payload.answer} + annotation = AppAnnotationService.insert_app_annotation_directly(insert_args, app_model.id) response = Annotation.model_validate(annotation, from_attributes=True) return response.model_dump(mode="json"), HTTPStatus.CREATED @@ -164,8 +175,9 @@ class AnnotationUpdateDeleteApi(Resource): @edit_permission_required def put(self, app_model: App, annotation_id: str): """Update an existing annotation.""" - args = AnnotationCreatePayload.model_validate(service_api_ns.payload or {}).model_dump() - annotation = AppAnnotationService.update_app_annotation_directly(args, app_model.id, annotation_id) + payload = AnnotationCreatePayload.model_validate(service_api_ns.payload or {}) + update_args: UpdateAnnotationArgs = {"question": payload.question, "answer": payload.answer} + annotation = AppAnnotationService.update_app_annotation_directly(update_args, app_model.id, annotation_id) response = Annotation.model_validate(annotation, from_attributes=True) return response.model_dump(mode="json") diff --git a/api/controllers/service_api/app/audio.py b/api/controllers/service_api/app/audio.py index 6228cfc25be..e818573b8f4 100644 --- a/api/controllers/service_api/app/audio.py +++ b/api/controllers/service_api/app/audio.py @@ -2,11 +2,10 @@ import logging from flask import request from flask_restx import Resource -from graphon.model_runtime.errors.invoke import InvokeError -from pydantic import BaseModel, Field from werkzeug.exceptions import InternalServerError import services +from controllers.common.controller_schemas import TextToAudioPayload from controllers.common.schema import register_schema_model from controllers.service_api import service_api_ns from controllers.service_api.app.error import ( @@ -22,6 +21,7 @@ from controllers.service_api.app.error import ( ) from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError +from graphon.model_runtime.errors.invoke import InvokeError from models.model import App, EndUser from services.audio_service import AudioService from services.errors.audio import ( @@ -86,13 +86,6 @@ class AudioApi(Resource): raise InternalServerError() -class TextToAudioPayload(BaseModel): - message_id: str | None = Field(default=None, description="Message ID") - voice: str | None = Field(default=None, description="Voice to use for TTS") - text: str | None = Field(default=None, description="Text to convert to audio") - streaming: bool | None = Field(default=None, description="Enable streaming response") - - register_schema_model(service_api_ns, TextToAudioPayload) diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py index 3142e5118e9..31f2797d66e 100644 --- a/api/controllers/service_api/app/completion.py +++ b/api/controllers/service_api/app/completion.py @@ -4,7 +4,6 @@ from uuid import UUID from flask import request from flask_restx import Resource -from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, Field, field_validator from werkzeug.exceptions import BadRequest, InternalServerError, NotFound @@ -29,6 +28,7 @@ from core.errors.error import ( QuotaExceededError, ) from core.helper.trace_id_helper import get_external_trace_id +from graphon.model_runtime.errors.invoke import InvokeError from libs import helper from libs.helper import UUIDStrOrEmpty from models.model import App, AppMode, EndUser diff --git a/api/controllers/service_api/app/conversation.py b/api/controllers/service_api/app/conversation.py index 8c9a3eb5e97..ca4b18cb5e0 100644 --- a/api/controllers/service_api/app/conversation.py +++ b/api/controllers/service_api/app/conversation.py @@ -1,26 +1,27 @@ +from datetime import datetime from typing import Any, Literal from flask import request from flask_restx import Resource -from pydantic import BaseModel, Field, TypeAdapter, field_validator, model_validator +from pydantic import BaseModel, Field, TypeAdapter, field_validator from sqlalchemy.orm import sessionmaker from werkzeug.exceptions import BadRequest, NotFound import services +from controllers.common.controller_schemas import ConversationRenamePayload from controllers.common.schema import register_schema_models from controllers.service_api import service_api_ns from controllers.service_api.app.error import NotChatAppError from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token from core.app.entities.app_invoke_entities import InvokeFrom from extensions.ext_database import db +from fields._value_type_serializer import serialize_value_type +from fields.base import ResponseModel from fields.conversation_fields import ( ConversationInfiniteScrollPagination, SimpleConversation, ) -from fields.conversation_variable_fields import ( - build_conversation_variable_infinite_scroll_pagination_model, - build_conversation_variable_model, -) +from graphon.variables.types import SegmentType from libs.helper import UUIDStrOrEmpty from models.model import App, AppMode, EndUser from services.conversation_service import ConversationService @@ -34,18 +35,6 @@ class ConversationListQuery(BaseModel): ) -class ConversationRenamePayload(BaseModel): - name: str | None = Field(default=None, description="New conversation name (required if auto_generate is false)") - auto_generate: bool = Field(default=False, description="Auto-generate conversation name") - - @model_validator(mode="after") - def validate_name_requirement(self): - if not self.auto_generate: - if self.name is None or not self.name.strip(): - raise ValueError("name is required when auto_generate is false") - return self - - class ConversationVariablesQuery(BaseModel): last_id: UUIDStrOrEmpty | None = Field(default=None, description="Last variable ID for pagination") limit: int = Field(default=20, ge=1, le=100, description="Number of variables to return") @@ -81,12 +70,70 @@ class ConversationVariableUpdatePayload(BaseModel): value: Any +class ConversationVariableResponse(ResponseModel): + id: str + name: str + value_type: str + value: str | None = None + description: str | None = None + created_at: int | None = None + updated_at: int | None = None + + @field_validator("value_type", mode="before") + @classmethod + def normalize_value_type(cls, value: Any) -> str: + exposed_type = getattr(value, "exposed_type", None) + if callable(exposed_type): + return str(exposed_type()) + if isinstance(value, str): + try: + return str(SegmentType(value).exposed_type()) + except ValueError: + return value + try: + return serialize_value_type(value) + except (AttributeError, TypeError, ValueError): + pass + + try: + return serialize_value_type({"value_type": value}) + except (AttributeError, TypeError, ValueError): + value_attr = getattr(value, "value", None) + if value_attr is not None: + return str(value_attr) + return str(value) + + @field_validator("value", mode="before") + @classmethod + def normalize_value(cls, value: Any | None) -> str | None: + if value is None: + return None + if isinstance(value, str): + return value + return str(value) + + @field_validator("created_at", "updated_at", mode="before") + @classmethod + def normalize_timestamp(cls, value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return int(value.timestamp()) + return value + + +class ConversationVariableInfiniteScrollPaginationResponse(ResponseModel): + limit: int + has_more: bool + data: list[ConversationVariableResponse] + + register_schema_models( service_api_ns, ConversationListQuery, ConversationRenamePayload, ConversationVariablesQuery, ConversationVariableUpdatePayload, + ConversationVariableResponse, + ConversationVariableInfiniteScrollPaginationResponse, ) @@ -215,8 +262,12 @@ class ConversationVariablesApi(Resource): 404: "Conversation not found", } ) + @service_api_ns.response( + 200, + "Variables retrieved successfully", + service_api_ns.models[ConversationVariableInfiniteScrollPaginationResponse.__name__], + ) @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY)) - @service_api_ns.marshal_with(build_conversation_variable_infinite_scroll_pagination_model(service_api_ns)) def get(self, app_model: App, end_user: EndUser, c_id): """List all variables for a conversation. @@ -233,9 +284,12 @@ class ConversationVariablesApi(Resource): last_id = str(query_args.last_id) if query_args.last_id else None try: - return ConversationService.get_conversational_variable( + pagination = ConversationService.get_conversational_variable( app_model, conversation_id, end_user, query_args.limit, last_id, query_args.variable_name ) + return ConversationVariableInfiniteScrollPaginationResponse.model_validate( + pagination, from_attributes=True + ).model_dump(mode="json") except services.errors.conversation.ConversationNotExistsError: raise NotFound("Conversation Not Exists.") @@ -254,8 +308,12 @@ class ConversationVariableDetailApi(Resource): 404: "Conversation or variable not found", } ) + @service_api_ns.response( + 200, + "Variable updated successfully", + service_api_ns.models[ConversationVariableResponse.__name__], + ) @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON)) - @service_api_ns.marshal_with(build_conversation_variable_model(service_api_ns)) def put(self, app_model: App, end_user: EndUser, c_id, variable_id): """Update a conversation variable's value. @@ -272,9 +330,10 @@ class ConversationVariableDetailApi(Resource): payload = ConversationVariableUpdatePayload.model_validate(service_api_ns.payload or {}) try: - return ConversationService.update_conversation_variable( + variable = ConversationService.update_conversation_variable( app_model, conversation_id, variable_id, end_user, payload.value ) + return ConversationVariableResponse.model_validate(variable, from_attributes=True).model_dump(mode="json") except services.errors.conversation.ConversationNotExistsError: raise NotFound("Conversation Not Exists.") except services.errors.conversation.ConversationVariableNotExistsError: diff --git a/api/controllers/service_api/app/message.py b/api/controllers/service_api/app/message.py index 77fee9c1421..b75b299f6fe 100644 --- a/api/controllers/service_api/app/message.py +++ b/api/controllers/service_api/app/message.py @@ -1,5 +1,4 @@ import logging -from typing import Literal from flask import request from flask_restx import Resource @@ -7,6 +6,7 @@ from pydantic import BaseModel, Field, TypeAdapter from werkzeug.exceptions import BadRequest, InternalServerError, NotFound import services +from controllers.common.controller_schemas import MessageFeedbackPayload, MessageListQuery from controllers.common.schema import register_schema_models from controllers.service_api import service_api_ns from controllers.service_api.app.error import NotChatAppError @@ -14,7 +14,6 @@ from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate from core.app.entities.app_invoke_entities import InvokeFrom from fields.conversation_fields import ResultResponse from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem -from libs.helper import UUIDStrOrEmpty from models.enums import FeedbackRating from models.model import App, AppMode, EndUser from services.errors.message import ( @@ -27,17 +26,6 @@ from services.message_service import MessageService logger = logging.getLogger(__name__) -class MessageListQuery(BaseModel): - conversation_id: UUIDStrOrEmpty - first_id: UUIDStrOrEmpty | None = None - limit: int = Field(default=20, ge=1, le=100, description="Number of messages to return") - - -class MessageFeedbackPayload(BaseModel): - rating: Literal["like", "dislike"] | None = Field(default=None, description="Feedback rating") - content: str | None = Field(default=None, description="Feedback content") - - class FeedbackListQuery(BaseModel): page: int = Field(default=1, ge=1, description="Page number") limit: int = Field(default=20, ge=1, le=101, description="Number of feedbacks per page") diff --git a/api/controllers/service_api/app/workflow.py b/api/controllers/service_api/app/workflow.py index d7992a2a3a1..cc763fa89c3 100644 --- a/api/controllers/service_api/app/workflow.py +++ b/api/controllers/service_api/app/workflow.py @@ -1,16 +1,16 @@ import logging -from typing import Any, Literal +from collections.abc import Mapping +from datetime import datetime +from typing import Literal from dateutil.parser import isoparse from flask import request -from flask_restx import Namespace, Resource, fields -from graphon.enums import WorkflowExecutionStatus -from graphon.graph_engine.manager import GraphEngineManager -from graphon.model_runtime.errors.invoke import InvokeError -from pydantic import BaseModel, Field +from flask_restx import Resource, fields +from pydantic import BaseModel, Field, field_validator from sqlalchemy.orm import sessionmaker from werkzeug.exceptions import BadRequest, InternalServerError, NotFound +from controllers.common.controller_schemas import WorkflowRunPayload as WorkflowRunPayloadBase from controllers.common.schema import register_schema_models from controllers.service_api import service_api_ns from controllers.service_api.app.error import ( @@ -32,9 +32,13 @@ from core.errors.error import ( from core.helper.trace_id_helper import get_external_trace_id from extensions.ext_database import db from extensions.ext_redis import redis_client -from fields.workflow_app_log_fields import build_workflow_app_log_pagination_model +from fields.base import ResponseModel +from fields.end_user_fields import SimpleEndUser +from fields.member_fields import SimpleAccount +from graphon.enums import WorkflowExecutionStatus +from graphon.graph_engine.manager import GraphEngineManager +from graphon.model_runtime.errors.invoke import InvokeError from libs import helper -from libs.helper import OptionalTimestampField, TimestampField from models.model import App, AppMode, EndUser from models.workflow import WorkflowRun from repositories.factory import DifyAPIRepositoryFactory @@ -46,9 +50,7 @@ from services.workflow_app_service import WorkflowAppService logger = logging.getLogger(__name__) -class WorkflowRunPayload(BaseModel): - inputs: dict[str, Any] - files: list[dict[str, Any]] | None = None +class WorkflowRunPayload(WorkflowRunPayloadBase): response_mode: Literal["blocking", "streaming"] | None = None @@ -66,38 +68,142 @@ class WorkflowLogQuery(BaseModel): register_schema_models(service_api_ns, WorkflowRunPayload, WorkflowLogQuery) +def _to_timestamp(value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return int(value.timestamp()) + return value + + +def _enum_value(value): + return getattr(value, "value", value) + + class WorkflowRunStatusField(fields.Raw): def output(self, key, obj: WorkflowRun, **kwargs): - return obj.status.value + return _enum_value(obj.status) class WorkflowRunOutputsField(fields.Raw): def output(self, key, obj: WorkflowRun, **kwargs): - if obj.status == WorkflowExecutionStatus.PAUSED: + status = _enum_value(obj.status) + if status == WorkflowExecutionStatus.PAUSED.value: return {} outputs = obj.outputs_dict return outputs or {} -workflow_run_fields = { - "id": fields.String, - "workflow_id": fields.String, - "status": WorkflowRunStatusField, - "inputs": fields.Raw, - "outputs": WorkflowRunOutputsField, - "error": fields.String, - "total_steps": fields.Integer, - "total_tokens": fields.Integer, - "created_at": TimestampField, - "finished_at": OptionalTimestampField, - "elapsed_time": fields.Float, -} +class WorkflowRunResponse(ResponseModel): + id: str + workflow_id: str + status: str + inputs: dict | list | str | int | float | bool | None = None + outputs: dict = Field(default_factory=dict) + error: str | None = None + total_steps: int | None = None + total_tokens: int | None = None + created_at: int | None = None + finished_at: int | None = None + elapsed_time: float | int | None = None + + @field_validator("created_at", "finished_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) -def build_workflow_run_model(api_or_ns: Namespace): - """Build the workflow run model for the API or Namespace.""" - return api_or_ns.model("WorkflowRun", workflow_run_fields) +class WorkflowRunForLogResponse(ResponseModel): + id: str + version: str | None = None + status: str | None = None + triggered_from: str | None = None + error: str | None = None + elapsed_time: float | int | None = None + total_tokens: int | None = None + total_steps: int | None = None + created_at: int | None = None + finished_at: int | None = None + exceptions_count: int | None = None + + @field_validator("status", "triggered_from", mode="before") + @classmethod + def _normalize_enum(cls, value): + return _enum_value(value) + + @field_validator("created_at", "finished_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +class WorkflowAppLogPartialResponse(ResponseModel): + id: str + workflow_run: WorkflowRunForLogResponse | None = None + details: dict | list | str | int | float | bool | None = None + created_from: str | None = None + created_by_role: str | None = None + created_by_account: SimpleAccount | None = None + created_by_end_user: SimpleEndUser | None = None + created_at: int | None = None + + @field_validator("created_from", "created_by_role", mode="before") + @classmethod + def _normalize_enum(cls, value): + return _enum_value(value) + + @field_validator("created_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +class WorkflowAppLogPaginationResponse(ResponseModel): + page: int + limit: int + total: int + has_more: bool + data: list[WorkflowAppLogPartialResponse] + + +register_schema_models( + service_api_ns, + WorkflowRunResponse, + WorkflowRunForLogResponse, + WorkflowAppLogPartialResponse, + WorkflowAppLogPaginationResponse, +) + + +def _serialize_workflow_run(workflow_run: WorkflowRun) -> dict: + status = _enum_value(workflow_run.status) + raw_outputs = workflow_run.outputs_dict + if status == WorkflowExecutionStatus.PAUSED.value or raw_outputs is None: + outputs: dict = {} + elif isinstance(raw_outputs, dict): + outputs = raw_outputs + elif isinstance(raw_outputs, Mapping): + outputs = dict(raw_outputs) + else: + outputs = {} + return WorkflowRunResponse.model_validate( + { + "id": workflow_run.id, + "workflow_id": workflow_run.workflow_id, + "status": status, + "inputs": workflow_run.inputs, + "outputs": outputs, + "error": workflow_run.error, + "total_steps": workflow_run.total_steps, + "total_tokens": workflow_run.total_tokens, + "created_at": workflow_run.created_at, + "finished_at": workflow_run.finished_at, + "elapsed_time": workflow_run.elapsed_time, + } + ).model_dump(mode="json") + + +def _serialize_workflow_log_pagination(pagination) -> dict: + return WorkflowAppLogPaginationResponse.model_validate(pagination, from_attributes=True).model_dump(mode="json") @service_api_ns.route("/workflows/run/") @@ -113,7 +219,11 @@ class WorkflowRunDetailApi(Resource): } ) @validate_app_token - @service_api_ns.marshal_with(build_workflow_run_model(service_api_ns)) + @service_api_ns.response( + 200, + "Workflow run details retrieved successfully", + service_api_ns.models[WorkflowRunResponse.__name__], + ) def get(self, app_model: App, workflow_run_id: str): """Get a workflow task running detail. @@ -134,7 +244,7 @@ class WorkflowRunDetailApi(Resource): ) if not workflow_run: raise NotFound("Workflow run not found.") - return workflow_run + return _serialize_workflow_run(workflow_run) @service_api_ns.route("/workflows/run") @@ -300,7 +410,11 @@ class WorkflowAppLogApi(Resource): } ) @validate_app_token - @service_api_ns.marshal_with(build_workflow_app_log_pagination_model(service_api_ns)) + @service_api_ns.response( + 200, + "Logs retrieved successfully", + service_api_ns.models[WorkflowAppLogPaginationResponse.__name__], + ) def get(self, app_model: App): """Get workflow app logs. @@ -328,4 +442,4 @@ class WorkflowAppLogApi(Resource): created_by_account=args.created_by_account, ) - return workflow_app_log_pagination + return _serialize_workflow_log_pagination(workflow_app_log_pagination) diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index 80205b283bc..76519cad0a0 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -2,7 +2,6 @@ from typing import Any, Literal, cast from flask import request from flask_restx import marshal -from graphon.model_runtime.entities.model_entities import ModelType from pydantic import BaseModel, Field, TypeAdapter, field_validator from werkzeug.exceptions import Forbidden, NotFound @@ -19,13 +18,21 @@ from core.plugin.impl.model_runtime_factory import create_plugin_provider_manage from core.rag.index_processor.constant.index_type import IndexTechniqueType from fields.dataset_fields import dataset_detail_fields from fields.tag_fields import DataSetTag +from graphon.model_runtime.entities.model_entities import ModelType from libs.login import current_user from models.account import Account from models.dataset import DatasetPermissionEnum +from models.enums import TagType from models.provider_ids import ModelProviderID from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService from services.entities.knowledge_entities.knowledge_entities import RetrievalModel -from services.tag_service import TagService +from services.tag_service import ( + SaveTagPayload, + TagBindingCreatePayload, + TagBindingDeletePayload, + TagService, + UpdateTagPayload, +) DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" @@ -513,7 +520,7 @@ class DatasetTagsApi(DatasetApiResource): raise Forbidden() payload = TagCreatePayload.model_validate(service_api_ns.payload or {}) - tag = TagService.save_tags({"name": payload.name, "type": "knowledge"}) + tag = TagService.save_tags(SaveTagPayload(name=payload.name, type=TagType.KNOWLEDGE)) response = DataSetTag.model_validate( {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0} @@ -536,9 +543,8 @@ class DatasetTagsApi(DatasetApiResource): raise Forbidden() payload = TagUpdatePayload.model_validate(service_api_ns.payload or {}) - params = {"name": payload.name, "type": "knowledge"} tag_id = payload.tag_id - tag = TagService.update_tags(params, tag_id) + tag = TagService.update_tags(UpdateTagPayload(name=payload.name, type=TagType.KNOWLEDGE), tag_id) binding_count = TagService.get_tag_binding_count(tag_id) @@ -585,7 +591,9 @@ class DatasetTagBindingApi(DatasetApiResource): raise Forbidden() payload = TagBindingPayload.model_validate(service_api_ns.payload or {}) - TagService.save_tag_binding({"tag_ids": payload.tag_ids, "target_id": payload.target_id, "type": "knowledge"}) + TagService.save_tag_binding( + TagBindingCreatePayload(tag_ids=payload.tag_ids, target_id=payload.target_id, type=TagType.KNOWLEDGE) + ) return "", 204 @@ -609,7 +617,9 @@ class DatasetTagUnbindingApi(DatasetApiResource): raise Forbidden() payload = TagUnbindingPayload.model_validate(service_api_ns.payload or {}) - TagService.delete_tag_binding({"tag_id": payload.tag_id, "target_id": payload.target_id, "type": "knowledge"}) + TagService.delete_tag_binding( + TagBindingDeletePayload(tag_id=payload.tag_id, target_id=payload.target_id, type=TagType.KNOWLEDGE) + ) return "", 204 diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index 2c094aa3e6e..bc28ecb6b77 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -1,4 +1,12 @@ +"""Service API endpoints for dataset document management. + +The canonical Service API paths use hyphenated route segments. Legacy underscore +aliases remain registered for backward compatibility, but they must stay marked +deprecated in generated API docs so clients migrate toward the canonical paths. +""" + import json +from collections.abc import Mapping from contextlib import ExitStack from typing import Self from uuid import UUID @@ -10,6 +18,7 @@ from sqlalchemy import desc, func, select from werkzeug.exceptions import Forbidden, NotFound import services +from controllers.common.controller_schemas import DocumentBatchDownloadZipPayload from controllers.common.errors import ( FilenameNotExistsError, FileTooLargeError, @@ -31,6 +40,7 @@ from controllers.service_api.wraps import ( cloud_edition_billing_resource_check, ) from core.errors.error import ProviderTokenNotInitError +from core.rag.entities import PreProcessingRule, Rule, Segmentation from core.rag.retrieval.retrieval_methods import RetrievalMethod from extensions.ext_database import db from fields.document_fields import document_fields, document_status_fields @@ -40,11 +50,8 @@ from models.enums import SegmentStatus from services.dataset_service import DatasetService, DocumentService from services.entities.knowledge_entities.knowledge_entities import ( KnowledgeConfig, - PreProcessingRule, ProcessRule, RetrievalModel, - Rule, - Segmentation, ) from services.file_service import FileService from services.summary_index_service import SummaryIndexService @@ -102,15 +109,6 @@ class DocumentListQuery(BaseModel): status: str | None = Field(default=None, description="Document status filter") -DOCUMENT_BATCH_DOWNLOAD_ZIP_MAX_DOCS = 100 - - -class DocumentBatchDownloadZipPayload(BaseModel): - """Request payload for bulk downloading uploaded documents as a ZIP archive.""" - - document_ids: list[UUID] = Field(..., min_length=1, max_length=DOCUMENT_BATCH_DOWNLOAD_ZIP_MAX_DOCS) - - register_enum_models(service_api_ns, RetrievalMethod) register_schema_models( @@ -127,12 +125,137 @@ register_schema_models( ) -@service_api_ns.route( - "/datasets//document/create_by_text", - "/datasets//document/create-by-text", -) +def _create_document_by_text(tenant_id: str, dataset_id: UUID) -> tuple[Mapping[str, object], int]: + """Create a document from text for both canonical and legacy routes.""" + payload = DocumentTextCreatePayload.model_validate(service_api_ns.payload or {}) + args = payload.model_dump(exclude_none=True) + + dataset_id_str = str(dataset_id) + tenant_id_str = str(tenant_id) + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id_str, Dataset.id == dataset_id_str).limit(1) + ) + + if not dataset: + raise ValueError("Dataset does not exist.") + + if not dataset.indexing_technique and not args["indexing_technique"]: + raise ValueError("indexing_technique is required.") + + embedding_model_provider = payload.embedding_model_provider + embedding_model = payload.embedding_model + if embedding_model_provider and embedding_model: + DatasetService.check_embedding_model_setting(tenant_id_str, embedding_model_provider, embedding_model) + + retrieval_model = payload.retrieval_model + if ( + retrieval_model + and retrieval_model.reranking_model + and retrieval_model.reranking_model.reranking_provider_name + and retrieval_model.reranking_model.reranking_model_name + ): + DatasetService.check_reranking_model_setting( + tenant_id_str, + retrieval_model.reranking_model.reranking_provider_name, + retrieval_model.reranking_model.reranking_model_name, + ) + + if not current_user: + raise ValueError("current_user is required") + + upload_file = FileService(db.engine).upload_text( + text=payload.text, text_name=payload.name, user_id=current_user.id, tenant_id=tenant_id_str + ) + data_source = { + "type": "upload_file", + "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}}, + } + args["data_source"] = data_source + knowledge_config = KnowledgeConfig.model_validate(args) + DocumentService.document_create_args_validate(knowledge_config) + + if not current_user: + raise ValueError("current_user is required") + + try: + documents, batch = DocumentService.save_document_with_dataset_id( + dataset=dataset, + knowledge_config=knowledge_config, + account=current_user, + dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None, + created_from="api", + ) + except ProviderTokenNotInitError as ex: + raise ProviderNotInitializeError(ex.description) + document = documents[0] + + documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": batch} + return documents_and_batch_fields, 200 + + +def _update_document_by_text(tenant_id: str, dataset_id: UUID, document_id: UUID) -> tuple[Mapping[str, object], int]: + """Update a document from text for both canonical and legacy routes.""" + payload = DocumentTextUpdate.model_validate(service_api_ns.payload or {}) + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == str(dataset_id)).limit(1) + ) + args = payload.model_dump(exclude_none=True) + if not dataset: + raise ValueError("Dataset does not exist.") + + retrieval_model = payload.retrieval_model + if ( + retrieval_model + and retrieval_model.reranking_model + and retrieval_model.reranking_model.reranking_provider_name + and retrieval_model.reranking_model.reranking_model_name + ): + DatasetService.check_reranking_model_setting( + tenant_id, + retrieval_model.reranking_model.reranking_provider_name, + retrieval_model.reranking_model.reranking_model_name, + ) + + # indexing_technique is already set in dataset since this is an update + args["indexing_technique"] = dataset.indexing_technique + + if args.get("text"): + text = args.get("text") + name = args.get("name") + if not current_user: + raise ValueError("current_user is required") + upload_file = FileService(db.engine).upload_text( + text=str(text), text_name=str(name), user_id=current_user.id, tenant_id=tenant_id + ) + data_source = { + "type": "upload_file", + "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}}, + } + args["data_source"] = data_source + + args["original_document_id"] = str(document_id) + knowledge_config = KnowledgeConfig.model_validate(args) + DocumentService.document_create_args_validate(knowledge_config) + + try: + documents, batch = DocumentService.save_document_with_dataset_id( + dataset=dataset, + knowledge_config=knowledge_config, + account=current_user, + dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None, + created_from="api", + ) + except ProviderTokenNotInitError as ex: + raise ProviderNotInitializeError(ex.description) + document = documents[0] + + documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": batch} + return documents_and_batch_fields, 200 + + +@service_api_ns.route("/datasets//document/create-by-text") class DocumentAddByTextApi(DatasetApiResource): - """Resource for documents.""" + """Resource for the canonical text document creation route.""" @service_api_ns.expect(service_api_ns.models[DocumentTextCreatePayload.__name__]) @service_api_ns.doc("create_document_by_text") @@ -148,81 +271,43 @@ class DocumentAddByTextApi(DatasetApiResource): @cloud_edition_billing_resource_check("vector_space", "dataset") @cloud_edition_billing_resource_check("documents", "dataset") @cloud_edition_billing_rate_limit_check("knowledge", "dataset") - def post(self, tenant_id, dataset_id): + def post(self, tenant_id: str, dataset_id: UUID): """Create document by text.""" - payload = DocumentTextCreatePayload.model_validate(service_api_ns.payload or {}) - args = payload.model_dump(exclude_none=True) + return _create_document_by_text(tenant_id=tenant_id, dataset_id=dataset_id) - dataset_id = str(dataset_id) - tenant_id = str(tenant_id) - dataset = db.session.scalar( - select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + +@service_api_ns.route("/datasets//document/create_by_text") +class DeprecatedDocumentAddByTextApi(DatasetApiResource): + """Deprecated resource alias for text document creation.""" + + @service_api_ns.expect(service_api_ns.models[DocumentTextCreatePayload.__name__]) + @service_api_ns.doc("create_document_by_text_deprecated") + @service_api_ns.doc(deprecated=True) + @service_api_ns.doc( + description=( + "Deprecated legacy alias for creating a new document by providing text content. " + "Use /datasets/{dataset_id}/document/create-by-text instead." ) - - if not dataset: - raise ValueError("Dataset does not exist.") - - if not dataset.indexing_technique and not args["indexing_technique"]: - raise ValueError("indexing_technique is required.") - - embedding_model_provider = payload.embedding_model_provider - embedding_model = payload.embedding_model - if embedding_model_provider and embedding_model: - DatasetService.check_embedding_model_setting(tenant_id, embedding_model_provider, embedding_model) - - retrieval_model = payload.retrieval_model - if ( - retrieval_model - and retrieval_model.reranking_model - and retrieval_model.reranking_model.reranking_provider_name - and retrieval_model.reranking_model.reranking_model_name - ): - DatasetService.check_reranking_model_setting( - tenant_id, - retrieval_model.reranking_model.reranking_provider_name, - retrieval_model.reranking_model.reranking_model_name, - ) - - if not current_user: - raise ValueError("current_user is required") - - upload_file = FileService(db.engine).upload_text( - text=payload.text, text_name=payload.name, user_id=current_user.id, tenant_id=tenant_id - ) - data_source = { - "type": "upload_file", - "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}}, + ) + @service_api_ns.doc(params={"dataset_id": "Dataset ID"}) + @service_api_ns.doc( + responses={ + 200: "Document created successfully", + 401: "Unauthorized - invalid API token", + 400: "Bad request - invalid parameters", } - args["data_source"] = data_source - knowledge_config = KnowledgeConfig.model_validate(args) - # validate args - DocumentService.document_create_args_validate(knowledge_config) - - if not current_user: - raise ValueError("current_user is required") - - try: - documents, batch = DocumentService.save_document_with_dataset_id( - dataset=dataset, - knowledge_config=knowledge_config, - account=current_user, - dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None, - created_from="api", - ) - except ProviderTokenNotInitError as ex: - raise ProviderNotInitializeError(ex.description) - document = documents[0] - - documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": batch} - return documents_and_batch_fields, 200 + ) + @cloud_edition_billing_resource_check("vector_space", "dataset") + @cloud_edition_billing_resource_check("documents", "dataset") + @cloud_edition_billing_rate_limit_check("knowledge", "dataset") + def post(self, tenant_id: str, dataset_id: UUID): + """Create document by text through the deprecated underscore alias.""" + return _create_document_by_text(tenant_id=tenant_id, dataset_id=dataset_id) -@service_api_ns.route( - "/datasets//documents//update_by_text", - "/datasets//documents//update-by-text", -) +@service_api_ns.route("/datasets//documents//update-by-text") class DocumentUpdateByTextApi(DatasetApiResource): - """Resource for update documents.""" + """Resource for the canonical text document update route.""" @service_api_ns.expect(service_api_ns.models[DocumentTextUpdate.__name__]) @service_api_ns.doc("update_document_by_text") @@ -239,62 +324,35 @@ class DocumentUpdateByTextApi(DatasetApiResource): @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID): """Update document by text.""" - payload = DocumentTextUpdate.model_validate(service_api_ns.payload or {}) - dataset = db.session.scalar( - select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == str(dataset_id)).limit(1) + return _update_document_by_text(tenant_id=tenant_id, dataset_id=dataset_id, document_id=document_id) + + +@service_api_ns.route("/datasets//documents//update_by_text") +class DeprecatedDocumentUpdateByTextApi(DatasetApiResource): + """Deprecated resource alias for text document updates.""" + + @service_api_ns.expect(service_api_ns.models[DocumentTextUpdate.__name__]) + @service_api_ns.doc("update_document_by_text_deprecated") + @service_api_ns.doc(deprecated=True) + @service_api_ns.doc( + description=( + "Deprecated legacy alias for updating an existing document by providing text content. " + "Use /datasets/{dataset_id}/documents/{document_id}/update-by-text instead." ) - args = payload.model_dump(exclude_none=True) - if not dataset: - raise ValueError("Dataset does not exist.") - - retrieval_model = payload.retrieval_model - if ( - retrieval_model - and retrieval_model.reranking_model - and retrieval_model.reranking_model.reranking_provider_name - and retrieval_model.reranking_model.reranking_model_name - ): - DatasetService.check_reranking_model_setting( - tenant_id, - retrieval_model.reranking_model.reranking_provider_name, - retrieval_model.reranking_model.reranking_model_name, - ) - - # indexing_technique is already set in dataset since this is an update - args["indexing_technique"] = dataset.indexing_technique - - if args.get("text"): - text = args.get("text") - name = args.get("name") - if not current_user: - raise ValueError("current_user is required") - upload_file = FileService(db.engine).upload_text( - text=str(text), text_name=str(name), user_id=current_user.id, tenant_id=tenant_id - ) - data_source = { - "type": "upload_file", - "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}}, - } - args["data_source"] = data_source - # validate args - args["original_document_id"] = str(document_id) - knowledge_config = KnowledgeConfig.model_validate(args) - DocumentService.document_create_args_validate(knowledge_config) - - try: - documents, batch = DocumentService.save_document_with_dataset_id( - dataset=dataset, - knowledge_config=knowledge_config, - account=current_user, - dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None, - created_from="api", - ) - except ProviderTokenNotInitError as ex: - raise ProviderNotInitializeError(ex.description) - document = documents[0] - - documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": batch} - return documents_and_batch_fields, 200 + ) + @service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"}) + @service_api_ns.doc( + responses={ + 200: "Document updated successfully", + 401: "Unauthorized - invalid API token", + 404: "Document not found", + } + ) + @cloud_edition_billing_resource_check("vector_space", "dataset") + @cloud_edition_billing_rate_limit_check("knowledge", "dataset") + def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID): + """Update document by text through the deprecated underscore alias.""" + return _update_document_by_text(tenant_id=tenant_id, dataset_id=dataset_id, document_id=document_id) @service_api_ns.route( @@ -529,7 +587,7 @@ class DocumentListApi(DatasetApiResource): if not dataset: raise NotFound("Dataset not found.") - query = select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=tenant_id) + query = select(Document).where(Document.dataset_id == dataset_id, Document.tenant_id == tenant_id) if query_params.status: query = DocumentService.apply_display_status_filter(query, query_params.status) diff --git a/api/controllers/service_api/dataset/metadata.py b/api/controllers/service_api/dataset/metadata.py index 52166f7fcc6..21db7d0cb8c 100644 --- a/api/controllers/service_api/dataset/metadata.py +++ b/api/controllers/service_api/dataset/metadata.py @@ -2,9 +2,9 @@ from typing import Literal from flask_login import current_user from flask_restx import marshal -from pydantic import BaseModel from werkzeug.exceptions import NotFound +from controllers.common.controller_schemas import MetadataUpdatePayload from controllers.common.schema import register_schema_model, register_schema_models from controllers.service_api import service_api_ns from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_rate_limit_check @@ -18,11 +18,6 @@ from services.entities.knowledge_entities.knowledge_entities import ( ) from services.metadata_service import MetadataService - -class MetadataUpdatePayload(BaseModel): - name: str - - register_schema_model(service_api_ns, MetadataUpdatePayload) register_schema_models( service_api_ns, diff --git a/api/controllers/service_api/dataset/rag_pipeline/serializers.py b/api/controllers/service_api/dataset/rag_pipeline/serializers.py index 8533c9c01d5..a5e8484037e 100644 --- a/api/controllers/service_api/dataset/rag_pipeline/serializers.py +++ b/api/controllers/service_api/dataset/rag_pipeline/serializers.py @@ -4,13 +4,23 @@ Serialization helpers for Service API knowledge pipeline endpoints. from __future__ import annotations -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, TypedDict if TYPE_CHECKING: from models.model import UploadFile -def serialize_upload_file(upload_file: UploadFile) -> dict[str, Any]: +class UploadFileDict(TypedDict): + id: str + name: str + size: int + extension: str + mime_type: str | None + created_by: str + created_at: str | None + + +def serialize_upload_file(upload_file: UploadFile) -> UploadFileDict: return { "id": upload_file.id, "name": upload_file.name, diff --git a/api/controllers/service_api/dataset/segment.py b/api/controllers/service_api/dataset/segment.py index 5b16da81e08..5992fa74101 100644 --- a/api/controllers/service_api/dataset/segment.py +++ b/api/controllers/service_api/dataset/segment.py @@ -2,12 +2,12 @@ from typing import Any from flask import request from flask_restx import marshal -from graphon.model_runtime.entities.model_entities import ModelType from pydantic import BaseModel, Field from sqlalchemy import select from werkzeug.exceptions import NotFound from configs import dify_config +from controllers.common.controller_schemas import ChildChunkCreatePayload, ChildChunkUpdatePayload from controllers.common.schema import register_schema_models from controllers.service_api import service_api_ns from controllers.service_api.app.error import ProviderNotInitializeError @@ -22,6 +22,7 @@ from core.model_manager import ModelManager from core.rag.index_processor.constant.index_type import IndexTechniqueType from extensions.ext_database import db from fields.segment_fields import child_chunk_fields, segment_fields +from graphon.model_runtime.entities.model_entities import ModelType from libs.login import current_account_with_tenant from models.dataset import Dataset from services.dataset_service import DatasetService, DocumentService, SegmentService @@ -32,25 +33,25 @@ from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingS from services.summary_index_service import SummaryIndexService -def _marshal_segment_with_summary(segment, dataset_id: str) -> dict: +def _marshal_segment_with_summary(segment, dataset_id: str) -> dict[str, Any]: """Marshal a single segment and enrich it with summary content.""" - segment_dict = dict(marshal(segment, segment_fields)) # type: ignore[arg-type] + segment_dict: dict[str, Any] = dict(marshal(segment, segment_fields)) # type: ignore[arg-type] summary = SummaryIndexService.get_segment_summary(segment_id=segment.id, dataset_id=dataset_id) segment_dict["summary"] = summary.summary_content if summary else None return segment_dict -def _marshal_segments_with_summary(segments, dataset_id: str) -> list[dict]: +def _marshal_segments_with_summary(segments, dataset_id: str) -> list[dict[str, Any]]: """Marshal multiple segments and enrich them with summary content (batch query).""" segment_ids = [segment.id for segment in segments] - summaries: dict = {} + summaries: dict[str, str | None] = {} if segment_ids: summary_records = SummaryIndexService.get_segments_summaries(segment_ids=segment_ids, dataset_id=dataset_id) summaries = {chunk_id: record.summary_content for chunk_id, record in summary_records.items()} - result = [] + result: list[dict[str, Any]] = [] for segment in segments: - segment_dict = dict(marshal(segment, segment_fields)) # type: ignore[arg-type] + segment_dict: dict[str, Any] = dict(marshal(segment, segment_fields)) # type: ignore[arg-type] segment_dict["summary"] = summaries.get(segment.id) result.append(segment_dict) return result @@ -69,20 +70,12 @@ class SegmentUpdatePayload(BaseModel): segment: SegmentUpdateArgs -class ChildChunkCreatePayload(BaseModel): - content: str - - class ChildChunkListQuery(BaseModel): limit: int = Field(default=20, ge=1) keyword: str | None = None page: int = Field(default=1, ge=1) -class ChildChunkUpdatePayload(BaseModel): - content: str - - register_schema_models( service_api_ns, SegmentCreatePayload, diff --git a/api/controllers/service_api/workspace/models.py b/api/controllers/service_api/workspace/models.py index c0a6cb0a763..5ac65fc4e6d 100644 --- a/api/controllers/service_api/workspace/models.py +++ b/api/controllers/service_api/workspace/models.py @@ -1,9 +1,9 @@ from flask_login import current_user from flask_restx import Resource -from graphon.model_runtime.utils.encoders import jsonable_encoder from controllers.service_api import service_api_ns from controllers.service_api.wraps import validate_dataset_token +from graphon.model_runtime.utils.encoders import jsonable_encoder from services.model_provider_service import ModelProviderService diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py index 2dd916bb312..b9389ccc479 100644 --- a/api/controllers/service_api/wraps.py +++ b/api/controllers/service_api/wraps.py @@ -1,9 +1,10 @@ +import inspect import logging import time from collections.abc import Callable from enum import StrEnum, auto from functools import wraps -from typing import Any, cast, overload +from typing import cast, overload from flask import current_app, request from flask_login import user_logged_in @@ -230,94 +231,73 @@ def cloud_edition_billing_rate_limit_check[**P, R]( return interceptor -def validate_dataset_token( - view: Callable[..., Any] | None = None, -) -> Callable[..., Any] | Callable[[Callable[..., Any]], Callable[..., Any]]: - def decorator(view_func: Callable[..., Any]) -> Callable[..., Any]: - @wraps(view_func) - def decorated(*args: Any, **kwargs: Any) -> Any: - api_token = validate_and_get_api_token("dataset") +def validate_dataset_token[R](view: Callable[..., R]) -> Callable[..., R]: + positional_parameters = [ + parameter + for parameter in inspect.signature(view).parameters.values() + if parameter.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD) + ] + expects_bound_instance = bool(positional_parameters and positional_parameters[0].name in {"self", "cls"}) - # get url path dataset_id from positional args or kwargs - # Flask passes URL path parameters as positional arguments - dataset_id = None + @wraps(view) + def decorated(*args: object, **kwargs: object) -> R: + api_token = validate_and_get_api_token("dataset") - # First try to get from kwargs (explicit parameter) - dataset_id = kwargs.get("dataset_id") + # Flask may pass URL path parameters positionally, so inspect both kwargs and args. + dataset_id = kwargs.get("dataset_id") - # If not in kwargs, try to extract from positional args - if not dataset_id and args: - # For class methods: args[0] is self, args[1] is dataset_id (if exists) - # Check if first arg is likely a class instance (has __dict__ or __class__) - if len(args) > 1 and hasattr(args[0], "__dict__"): - # This is a class method, dataset_id should be in args[1] - potential_id = args[1] - # Validate it's a string-like UUID, not another object - try: - # Try to convert to string and check if it's a valid UUID format - str_id = str(potential_id) - # Basic check: UUIDs are 36 chars with hyphens - if len(str_id) == 36 and str_id.count("-") == 4: - dataset_id = str_id - except Exception: - logger.exception("Failed to parse dataset_id from class method args") - elif len(args) > 0: - # Not a class method, check if args[0] looks like a UUID - potential_id = args[0] - try: - str_id = str(potential_id) - if len(str_id) == 36 and str_id.count("-") == 4: - dataset_id = str_id - except Exception: - logger.exception("Failed to parse dataset_id from positional args") + if not dataset_id and args: + potential_id = args[0] + try: + str_id = str(potential_id) + if len(str_id) == 36 and str_id.count("-") == 4: + dataset_id = str_id + except Exception: + logger.exception("Failed to parse dataset_id from positional args") - # Validate dataset if dataset_id is provided - if dataset_id: - dataset_id = str(dataset_id) - dataset = db.session.scalar( - select(Dataset) - .where( - Dataset.id == dataset_id, - Dataset.tenant_id == api_token.tenant_id, - ) - .limit(1) + if dataset_id: + dataset_id = str(dataset_id) + dataset = db.session.scalar( + select(Dataset) + .where( + Dataset.id == dataset_id, + Dataset.tenant_id == api_token.tenant_id, ) - if not dataset: - raise NotFound("Dataset not found.") - if not dataset.enable_api: - raise Forbidden("Dataset api access is not enabled.") - tenant_account_join = db.session.execute( - select(Tenant, TenantAccountJoin) - .where(Tenant.id == api_token.tenant_id) - .where(TenantAccountJoin.tenant_id == Tenant.id) - .where(TenantAccountJoin.role.in_(["owner"])) - .where(Tenant.status == TenantStatus.NORMAL) - ).one_or_none() # TODO: only owner information is required, so only one is returned. - if tenant_account_join: - tenant, ta = tenant_account_join - account = db.session.get(Account, ta.account_id) - # Login admin - if account: - account.current_tenant = tenant - current_app.login_manager._update_request_context_with_user(account) # type: ignore - user_logged_in.send(current_app._get_current_object(), user=current_user) # type: ignore - else: - raise Unauthorized("Tenant owner account does not exist.") + .limit(1) + ) + if not dataset: + raise NotFound("Dataset not found.") + if not dataset.enable_api: + raise Forbidden("Dataset api access is not enabled.") + + tenant_account_join = db.session.execute( + select(Tenant, TenantAccountJoin) + .where(Tenant.id == api_token.tenant_id) + .where(TenantAccountJoin.tenant_id == Tenant.id) + .where(TenantAccountJoin.role.in_(["owner"])) + .where(Tenant.status == TenantStatus.NORMAL) + ).one_or_none() # TODO: only owner information is required, so only one is returned. + if tenant_account_join: + tenant, ta = tenant_account_join + account = db.session.get(Account, ta.account_id) + # Login admin + if account: + account.current_tenant = tenant + current_app.login_manager._update_request_context_with_user(account) # type: ignore + user_logged_in.send(current_app._get_current_object(), user=current_user) # type: ignore else: - raise Unauthorized("Tenant does not exist.") - if args and isinstance(args[0], Resource): - return view_func(args[0], api_token.tenant_id, *args[1:], **kwargs) + raise Unauthorized("Tenant owner account does not exist.") + else: + raise Unauthorized("Tenant does not exist.") - return view_func(api_token.tenant_id, *args, **kwargs) + if expects_bound_instance: + if not args: + raise TypeError("validate_dataset_token expected a bound resource instance.") + return view(args[0], api_token.tenant_id, *args[1:], **kwargs) - return decorated + return view(api_token.tenant_id, *args, **kwargs) - if view: - return decorator(view) - - # if view is None, it means that the decorator is used without parentheses - # use the decorator as a function for method_decorators - return decorator + return decorated def validate_and_get_api_token(scope: str | None = None): diff --git a/api/controllers/trigger/webhook.py b/api/controllers/trigger/webhook.py index eb579da5d4d..213704383c3 100644 --- a/api/controllers/trigger/webhook.py +++ b/api/controllers/trigger/webhook.py @@ -7,7 +7,7 @@ from werkzeug.exceptions import NotFound, RequestEntityTooLarge from controllers.trigger import bp from core.trigger.debug.event_bus import TriggerDebugEventBus from core.trigger.debug.events import WebhookDebugEvent, build_webhook_pool_key -from services.trigger.webhook_service import WebhookService +from services.trigger.webhook_service import RawWebhookDataDict, WebhookService logger = logging.getLogger(__name__) @@ -23,6 +23,7 @@ def _prepare_webhook_execution(webhook_id: str, is_debug: bool = False): webhook_id, is_debug=is_debug ) + webhook_data: RawWebhookDataDict try: # Use new unified extraction and validation webhook_data = WebhookService.extract_and_validate_webhook_data(webhook_trigger, node_config) diff --git a/api/controllers/web/audio.py b/api/controllers/web/audio.py index 9ba1dc4a3ac..3ad595f1f4f 100644 --- a/api/controllers/web/audio.py +++ b/api/controllers/web/audio.py @@ -2,11 +2,11 @@ import logging from flask import request from flask_restx import fields, marshal_with -from graphon.model_runtime.errors.invoke import InvokeError -from pydantic import BaseModel, field_validator +from pydantic import field_validator from werkzeug.exceptions import InternalServerError import services +from controllers.common.controller_schemas import TextToAudioPayload as TextToAudioPayloadBase from controllers.web import web_ns from controllers.web.error import ( AppUnavailableError, @@ -21,6 +21,7 @@ from controllers.web.error import ( ) from controllers.web.wraps import WebApiResource from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError +from graphon.model_runtime.errors.invoke import InvokeError from libs.helper import uuid_value from models.model import App from services.audio_service import AudioService @@ -34,12 +35,7 @@ from services.errors.audio import ( from ..common.schema import register_schema_models -class TextToAudioPayload(BaseModel): - message_id: str | None = None - voice: str | None = None - text: str | None = None - streaming: bool | None = None - +class TextToAudioPayload(TextToAudioPayloadBase): @field_validator("message_id") @classmethod def validate_message_id(cls, value: str | None) -> str | None: diff --git a/api/controllers/web/completion.py b/api/controllers/web/completion.py index e37f9af5f0a..0528184d79e 100644 --- a/api/controllers/web/completion.py +++ b/api/controllers/web/completion.py @@ -1,7 +1,6 @@ import logging from typing import Any, Literal -from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, Field, field_validator from werkzeug.exceptions import InternalServerError, NotFound @@ -26,6 +25,7 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) +from graphon.model_runtime.errors.invoke import InvokeError from libs import helper from libs.helper import uuid_value from models.model import AppMode diff --git a/api/controllers/web/conversation.py b/api/controllers/web/conversation.py index d5baa5fb7d1..3975dd85c8a 100644 --- a/api/controllers/web/conversation.py +++ b/api/controllers/web/conversation.py @@ -1,10 +1,11 @@ from typing import Literal from flask import request -from pydantic import BaseModel, Field, TypeAdapter, field_validator, model_validator +from pydantic import BaseModel, Field, TypeAdapter, field_validator from sqlalchemy.orm import sessionmaker from werkzeug.exceptions import NotFound +from controllers.common.controller_schemas import ConversationRenamePayload from controllers.common.schema import register_schema_models from controllers.web import web_ns from controllers.web.error import NotChatAppError @@ -37,18 +38,6 @@ class ConversationListQuery(BaseModel): return uuid_value(value) -class ConversationRenamePayload(BaseModel): - name: str | None = None - auto_generate: bool = False - - @model_validator(mode="after") - def validate_name_requirement(self): - if not self.auto_generate: - if self.name is None or not self.name.strip(): - raise ValueError("name is required when auto_generate is false") - return self - - register_schema_models(web_ns, ConversationListQuery, ConversationRenamePayload) diff --git a/api/controllers/web/forgot_password.py b/api/controllers/web/forgot_password.py index d69571cc9cf..61fd794c229 100644 --- a/api/controllers/web/forgot_password.py +++ b/api/controllers/web/forgot_password.py @@ -3,8 +3,6 @@ import secrets from flask import request from flask_restx import Resource -from pydantic import BaseModel, Field, field_validator -from sqlalchemy.orm import sessionmaker from controllers.common.schema import register_schema_models from controllers.console.auth.error import ( @@ -19,33 +17,15 @@ from controllers.console.error import EmailSendIpLimitError from controllers.console.wraps import email_password_login_enabled, only_edition_enterprise, setup_required from controllers.web import web_ns from extensions.ext_database import db -from libs.helper import EmailStr, extract_remote_ip -from libs.password import hash_password, valid_password +from libs.helper import extract_remote_ip +from libs.password import hash_password from models.account import Account from services.account_service import AccountService - - -class ForgotPasswordSendPayload(BaseModel): - email: EmailStr - language: str | None = None - - -class ForgotPasswordCheckPayload(BaseModel): - email: EmailStr - code: str - token: str = Field(min_length=1) - - -class ForgotPasswordResetPayload(BaseModel): - token: str = Field(min_length=1) - new_password: str - password_confirm: str - - @field_validator("new_password", "password_confirm") - @classmethod - def validate_password(cls, value: str) -> str: - return valid_password(value) - +from services.entities.auth_entities import ( + ForgotPasswordCheckPayload, + ForgotPasswordResetPayload, + ForgotPasswordSendPayload, +) register_schema_models(web_ns, ForgotPasswordSendPayload, ForgotPasswordCheckPayload, ForgotPasswordResetPayload) @@ -81,9 +61,7 @@ class ForgotPasswordSendEmailApi(Resource): else: language = "en-US" - with sessionmaker(db.engine).begin() as session: - account = AccountService.get_account_by_email_with_case_fallback(request_email, session=session) - token = None + account = AccountService.get_account_by_email_with_case_fallback(request_email) if account is None: raise AuthenticationFailedError() else: @@ -180,13 +158,14 @@ class ForgotPasswordResetApi(Resource): email = reset_data.get("email", "") - with sessionmaker(db.engine).begin() as session: - account = AccountService.get_account_by_email_with_case_fallback(email, session=session) + account = AccountService.get_account_by_email_with_case_fallback(email) - if account: - self._update_existing_account(account, password_hashed, salt) - else: - raise AuthenticationFailedError() + if account: + account = db.session.merge(account) + self._update_existing_account(account, password_hashed, salt) + db.session.commit() + else: + raise AuthenticationFailedError() return {"result": "success"} diff --git a/api/controllers/web/human_input_form.py b/api/controllers/web/human_input_form.py index 36728a47d10..44876f83033 100644 --- a/api/controllers/web/human_input_form.py +++ b/api/controllers/web/human_input_form.py @@ -5,9 +5,11 @@ Web App Human Input Form APIs. import json import logging from datetime import datetime +from typing import Any, NotRequired, TypedDict from flask import Response, request -from flask_restx import Resource, reqparse +from flask_restx import Resource +from pydantic import BaseModel from sqlalchemy import select from werkzeug.exceptions import Forbidden @@ -23,6 +25,12 @@ from services.human_input_service import Form, FormNotFoundError, HumanInputServ logger = logging.getLogger(__name__) + +class HumanInputFormSubmitPayload(BaseModel): + inputs: dict + action: str + + _FORM_SUBMIT_RATE_LIMITER = RateLimiter( prefix="web_form_submit_rate_limit", max_attempts=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_MAX_ATTEMPTS, @@ -51,10 +59,19 @@ def _to_timestamp(value: datetime) -> int: return int(value.timestamp()) +class FormDefinitionPayload(TypedDict): + form_content: Any + inputs: Any + resolved_default_values: dict[str, str] + user_actions: Any + expiration_time: int + site: NotRequired[dict] + + def _jsonify_form_definition(form: Form, site_payload: dict | None = None) -> Response: """Return the form payload (optionally with site) as a JSON response.""" definition_payload = form.get_definition().model_dump() - payload = { + payload: FormDefinitionPayload = { "form_content": definition_payload["rendered_content"], "inputs": definition_payload["inputs"], "resolved_default_values": _stringify_default_values(definition_payload["default_values"]), @@ -85,7 +102,7 @@ class HumanInputFormApi(Resource): _FORM_ACCESS_RATE_LIMITER.increment_rate_limit(ip_address) service = HumanInputService(db.engine) - # TODO(QuantumGhost): forbid submision for form tokens + # TODO(QuantumGhost): forbid submission for form tokens # that are only for console. form = service.get_form_by_token(form_token) @@ -112,10 +129,7 @@ class HumanInputFormApi(Resource): "action": "Approve" } """ - parser = reqparse.RequestParser() - parser.add_argument("inputs", type=dict, required=True, location="json") - parser.add_argument("action", type=str, required=True, location="json") - args = parser.parse_args() + payload = HumanInputFormSubmitPayload.model_validate(request.get_json()) ip_address = extract_remote_ip(request) if _FORM_SUBMIT_RATE_LIMITER.is_rate_limited(ip_address): @@ -135,8 +149,8 @@ class HumanInputFormApi(Resource): service.submit_form_by_token( recipient_type=recipient_type, form_token=form_token, - selected_action_id=args["action"], - form_data=args["inputs"], + selected_action_id=payload.action, + form_data=payload.inputs, submission_end_user_id=None, # submission_end_user_id=_end_user.id, ) diff --git a/api/controllers/web/login.py b/api/controllers/web/login.py index a824f6d4879..2255dd03321 100644 --- a/api/controllers/web/login.py +++ b/api/controllers/web/login.py @@ -1,7 +1,10 @@ +import logging + from flask import make_response, request from flask_restx import Resource from jwt import InvalidTokenError from pydantic import BaseModel, Field, field_validator +from werkzeug.exceptions import Unauthorized import services from configs import dify_config @@ -20,7 +23,7 @@ from controllers.console.wraps import ( ) from controllers.web import web_ns from controllers.web.wraps import decode_jwt_token -from libs.helper import EmailStr +from libs.helper import EmailStr, extract_remote_ip from libs.passport import PassportService from libs.password import valid_password from libs.token import ( @@ -29,13 +32,13 @@ from libs.token import ( ) from services.account_service import AccountService from services.app_service import AppService +from services.entities.auth_entities import LoginFailureReason, LoginPayloadBase from services.webapp_auth_service import WebAppAuthService +logger = logging.getLogger(__name__) -class LoginPayload(BaseModel): - email: EmailStr - password: str +class LoginPayload(LoginPayloadBase): @field_validator("password") @classmethod def validate_password(cls, value: str) -> str: @@ -78,14 +81,18 @@ class LoginApi(Resource): def post(self): """Authenticate user and login.""" payload = LoginPayload.model_validate(web_ns.payload or {}) + normalized_email = payload.email.lower() try: account = WebAppAuthService.authenticate(payload.email, payload.password) except services.errors.account.AccountLoginError: + _log_web_login_failure(email=normalized_email, reason=LoginFailureReason.ACCOUNT_BANNED) raise AccountBannedError() except services.errors.account.AccountPasswordError: + _log_web_login_failure(email=normalized_email, reason=LoginFailureReason.INVALID_CREDENTIALS) raise AuthenticationFailedError() except services.errors.account.AccountNotFoundError: + _log_web_login_failure(email=normalized_email, reason=LoginFailureReason.ACCOUNT_NOT_FOUND) raise AuthenticationFailedError() token = WebAppAuthService.login(account=account) @@ -214,21 +221,30 @@ class EmailCodeLoginApi(Resource): token_data = WebAppAuthService.get_email_code_login_data(payload.token) if token_data is None: + _log_web_login_failure(email=user_email, reason=LoginFailureReason.INVALID_EMAIL_CODE_TOKEN) raise InvalidTokenError() token_email = token_data.get("email") if not isinstance(token_email, str): + _log_web_login_failure(email=user_email, reason=LoginFailureReason.EMAIL_CODE_EMAIL_MISMATCH) raise InvalidEmailError() normalized_token_email = token_email.lower() if normalized_token_email != user_email: + _log_web_login_failure(email=user_email, reason=LoginFailureReason.EMAIL_CODE_EMAIL_MISMATCH) raise InvalidEmailError() if token_data["code"] != payload.code: + _log_web_login_failure(email=user_email, reason=LoginFailureReason.INVALID_EMAIL_CODE) raise EmailCodeError() WebAppAuthService.revoke_email_code_login_token(payload.token) - account = WebAppAuthService.get_user_through_email(token_email) + try: + account = WebAppAuthService.get_user_through_email(token_email) + except Unauthorized as exc: + _log_web_login_failure(email=user_email, reason=LoginFailureReason.ACCOUNT_BANNED) + raise AccountBannedError() from exc if not account: + _log_web_login_failure(email=user_email, reason=LoginFailureReason.ACCOUNT_NOT_FOUND) raise AuthenticationFailedError() token = WebAppAuthService.login(account=account) @@ -236,3 +252,12 @@ class EmailCodeLoginApi(Resource): response = make_response({"result": "success", "data": {"access_token": token}}) # set_access_token_to_cookie(request, response, token, samesite="None", httponly=False) return response + + +def _log_web_login_failure(*, email: str, reason: LoginFailureReason) -> None: + logger.warning( + "Web login failed: email=%s reason=%s ip_address=%s", + email, + reason, + extract_remote_ip(request), + ) diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py index c5505dd60de..07ecf8035bb 100644 --- a/api/controllers/web/message.py +++ b/api/controllers/web/message.py @@ -2,10 +2,10 @@ import logging from typing import Literal from flask import request -from graphon.model_runtime.errors.invoke import InvokeError -from pydantic import BaseModel, Field, TypeAdapter, field_validator +from pydantic import BaseModel, Field, TypeAdapter from werkzeug.exceptions import InternalServerError, NotFound +from controllers.common.controller_schemas import MessageFeedbackPayload, MessageListQuery from controllers.common.schema import register_schema_models from controllers.web import web_ns from controllers.web.error import ( @@ -23,8 +23,8 @@ from core.app.entities.app_invoke_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from fields.conversation_fields import ResultResponse from fields.message_fields import SuggestedQuestionsResponse, WebMessageInfiniteScrollPagination, WebMessageListItem +from graphon.model_runtime.errors.invoke import InvokeError from libs import helper -from libs.helper import uuid_value from models.enums import FeedbackRating from models.model import AppMode from services.app_generate_service import AppGenerateService @@ -40,24 +40,6 @@ from services.message_service import MessageService logger = logging.getLogger(__name__) -class MessageListQuery(BaseModel): - conversation_id: str = Field(description="Conversation UUID") - first_id: str | None = Field(default=None, description="First message ID for pagination") - limit: int = Field(default=20, ge=1, le=100, description="Number of messages to return (1-100)") - - @field_validator("conversation_id", "first_id") - @classmethod - def validate_uuid(cls, value: str | None) -> str | None: - if value is None: - return value - return uuid_value(value) - - -class MessageFeedbackPayload(BaseModel): - rating: Literal["like", "dislike"] | None = Field(default=None, description="Feedback rating") - content: str | None = Field(default=None, description="Feedback content") - - class MessageMoreLikeThisQuery(BaseModel): response_mode: Literal["blocking", "streaming"] = Field( description="Response mode", diff --git a/api/controllers/web/passport.py b/api/controllers/web/passport.py index 6a2e0b65fb8..0293df74b08 100644 --- a/api/controllers/web/passport.py +++ b/api/controllers/web/passport.py @@ -1,5 +1,6 @@ import uuid from datetime import UTC, datetime, timedelta +from typing import Any from flask import make_response, request from flask_restx import Resource @@ -103,21 +104,23 @@ class PassportResource(Resource): return response -def decode_enterprise_webapp_user_id(jwt_token: str | None): +def decode_enterprise_webapp_user_id(jwt_token: str | None) -> dict[str, Any] | None: """ Decode the enterprise user session from the Authorization header. """ if not jwt_token: return None - decoded = PassportService().verify(jwt_token) + decoded: dict[str, Any] = PassportService().verify(jwt_token) source = decoded.get("token_source") if not source or source != "webapp_login_token": raise Unauthorized("Invalid token source. Expected 'webapp_login_token'.") return decoded -def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded: dict, auth_type: WebAppAuthType): +def exchange_token_for_existing_web_user( + app_code: str, enterprise_user_decoded: dict[str, Any], auth_type: WebAppAuthType +): """ Exchange a token for an existing web user session. """ @@ -138,12 +141,15 @@ def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded: if not app_model or app_model.status != "normal" or not app_model.enable_site: raise NotFound() - if auth_type == WebAppAuthType.PUBLIC: - return _exchange_for_public_app_token(app_model, site, enterprise_user_decoded) - elif auth_type == WebAppAuthType.EXTERNAL and user_auth_type != "external": - raise WebAppAuthRequiredError("Please login as external user.") - elif auth_type == WebAppAuthType.INTERNAL and user_auth_type != "internal": - raise WebAppAuthRequiredError("Please login as internal user.") + match auth_type: + case WebAppAuthType.PUBLIC: + return _exchange_for_public_app_token(app_model, site, enterprise_user_decoded) + case WebAppAuthType.EXTERNAL: + if user_auth_type != "external": + raise WebAppAuthRequiredError("Please login as external user.") + case WebAppAuthType.INTERNAL: + if user_auth_type != "internal": + raise WebAppAuthRequiredError("Please login as internal user.") end_user = None if end_user_id: diff --git a/api/controllers/web/remote_files.py b/api/controllers/web/remote_files.py index 38aeccc642b..fe31e9d4acd 100644 --- a/api/controllers/web/remote_files.py +++ b/api/controllers/web/remote_files.py @@ -1,7 +1,6 @@ import urllib.parse import httpx -from graphon.file import helpers as file_helpers from pydantic import BaseModel, Field, HttpUrl import services @@ -14,6 +13,7 @@ from controllers.common.errors import ( from core.helper import ssrf_proxy from extensions.ext_database import db from fields.file_fields import FileWithSignedUrl, RemoteFileInfo +from graphon.file import helpers as file_helpers from services.file_service import FileService from ..common.schema import register_schema_models diff --git a/api/controllers/web/saved_message.py b/api/controllers/web/saved_message.py index 29993100f60..5b206f9a98a 100644 --- a/api/controllers/web/saved_message.py +++ b/api/controllers/web/saved_message.py @@ -1,27 +1,17 @@ from flask import request -from pydantic import BaseModel, Field, TypeAdapter +from pydantic import TypeAdapter from werkzeug.exceptions import NotFound +from controllers.common.controller_schemas import SavedMessageCreatePayload, SavedMessageListQuery from controllers.common.schema import register_schema_models from controllers.web import web_ns from controllers.web.error import NotCompletionAppError from controllers.web.wraps import WebApiResource from fields.conversation_fields import ResultResponse from fields.message_fields import SavedMessageInfiniteScrollPagination, SavedMessageItem -from libs.helper import UUIDStrOrEmpty from services.errors.message import MessageNotExistsError from services.saved_message_service import SavedMessageService - -class SavedMessageListQuery(BaseModel): - last_id: UUIDStrOrEmpty | None = None - limit: int = Field(default=20, ge=1, le=100) - - -class SavedMessageCreatePayload(BaseModel): - message_id: UUIDStrOrEmpty - - register_schema_models(web_ns, SavedMessageListQuery, SavedMessageCreatePayload) diff --git a/api/controllers/web/site.py b/api/controllers/web/site.py index 1a0c6d4252d..7d2080dd91c 100644 --- a/api/controllers/web/site.py +++ b/api/controllers/web/site.py @@ -1,4 +1,4 @@ -from typing import cast +from typing import Any, cast from flask_restx import fields, marshal, marshal_with from sqlalchemy import select @@ -113,12 +113,12 @@ class AppSiteInfo: } -def serialize_site(site: Site) -> dict: +def serialize_site(site: Site) -> dict[str, Any]: """Serialize Site model using the same schema as AppSiteApi.""" - return cast(dict, marshal(site, AppSiteApi.site_fields)) + return cast(dict[str, Any], marshal(site, AppSiteApi.site_fields)) -def serialize_app_site_payload(app_model: App, site: Site, end_user_id: str | None) -> dict: +def serialize_app_site_payload(app_model: App, site: Site, end_user_id: str | None) -> dict[str, Any]: can_replace_logo = FeatureService.get_features(app_model.tenant_id).can_replace_logo app_site_info = AppSiteInfo(app_model.tenant, app_model, site, end_user_id, can_replace_logo) - return cast(dict, marshal(app_site_info, AppSiteApi.app_fields)) + return cast(dict[str, Any], marshal(app_site_info, AppSiteApi.app_fields)) diff --git a/api/controllers/web/workflow.py b/api/controllers/web/workflow.py index 7f5521f9f58..98211193a0a 100644 --- a/api/controllers/web/workflow.py +++ b/api/controllers/web/workflow.py @@ -1,11 +1,8 @@ import logging -from typing import Any -from graphon.graph_engine.manager import GraphEngineManager -from graphon.model_runtime.errors.invoke import InvokeError -from pydantic import BaseModel, Field from werkzeug.exceptions import InternalServerError +from controllers.common.controller_schemas import WorkflowRunPayload from controllers.common.schema import register_schema_models from controllers.web import web_ns from controllers.web.error import ( @@ -25,17 +22,13 @@ from core.errors.error import ( QuotaExceededError, ) from extensions.ext_redis import redis_client +from graphon.graph_engine.manager import GraphEngineManager +from graphon.model_runtime.errors.invoke import InvokeError from libs import helper from models.model import App, AppMode, EndUser from services.app_generate_service import AppGenerateService from services.errors.llm import InvokeRateLimitError - -class WorkflowRunPayload(BaseModel): - inputs: dict[str, Any] = Field(description="Input variables for the workflow") - files: list[dict[str, Any]] | None = Field(default=None, description="Files to be processed by the workflow") - - logger = logging.getLogger(__name__) register_schema_models(web_ns, WorkflowRunPayload) diff --git a/api/controllers/web/workflow_events.py b/api/controllers/web/workflow_events.py index 61568e70e67..474f9c09579 100644 --- a/api/controllers/web/workflow_events.py +++ b/api/controllers/web/workflow_events.py @@ -72,12 +72,13 @@ class WorkflowEventsApi(WebApiResource): app_mode = AppMode.value_of(app_model.mode) msg_generator = MessageGenerator() generator: BaseAppGenerator - if app_mode == AppMode.ADVANCED_CHAT: - generator = AdvancedChatAppGenerator() - elif app_mode == AppMode.WORKFLOW: - generator = WorkflowAppGenerator() - else: - raise InvalidArgumentError(f"cannot subscribe to workflow run, workflow_run_id={workflow_run.id}") + match app_mode: + case AppMode.ADVANCED_CHAT: + generator = AdvancedChatAppGenerator() + case AppMode.WORKFLOW: + generator = WorkflowAppGenerator() + case _: + raise InvalidArgumentError(f"cannot subscribe to workflow run, workflow_run_id={workflow_run.id}") include_state_snapshot = request.args.get("include_state_snapshot", "false").lower() == "true" diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index 06c746990d2..c22102c2bac 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -4,20 +4,6 @@ import uuid from decimal import Decimal from typing import Union, cast -from graphon.file import file_manager -from graphon.model_runtime.entities import ( - AssistantPromptMessage, - LLMUsage, - PromptMessage, - PromptMessageTool, - SystemPromptMessage, - TextPromptMessageContent, - ToolPromptMessage, - UserPromptMessage, -) -from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes -from graphon.model_runtime.entities.model_entities import ModelFeature -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from sqlalchemy import func, select from core.agent.entities import AgentEntity, AgentToolEntity @@ -43,6 +29,20 @@ from core.tools.tool_manager import ToolManager from core.tools.utils.dataset_retriever_tool import DatasetRetrieverTool from extensions.ext_database import db from factories import file_factory +from graphon.file import file_manager +from graphon.model_runtime.entities import ( + AssistantPromptMessage, + LLMUsage, + PromptMessage, + PromptMessageTool, + SystemPromptMessage, + TextPromptMessageContent, + ToolPromptMessage, + UserPromptMessage, +) +from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes +from graphon.model_runtime.entities.model_entities import ModelFeature +from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel from models.enums import CreatorUserRole from models.model import Conversation, Message, MessageAgentThought, MessageFile diff --git a/api/core/agent/cot_agent_runner.py b/api/core/agent/cot_agent_runner.py index 11e2aa062d2..0bc93ad34d3 100644 --- a/api/core/agent/cot_agent_runner.py +++ b/api/core/agent/cot_agent_runner.py @@ -2,16 +2,7 @@ import json import logging from abc import ABC, abstractmethod from collections.abc import Generator, Mapping, Sequence -from typing import Any - -from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage -from graphon.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - PromptMessage, - PromptMessageTool, - ToolPromptMessage, - UserPromptMessage, -) +from typing import Any, TypedDict from core.agent.base_agent_runner import BaseAgentRunner from core.agent.entities import AgentScratchpadUnit @@ -24,11 +15,26 @@ from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransfo from core.tools.__base.tool import Tool from core.tools.entities.tool_entities import ToolInvokeMeta from core.tools.tool_engine import ToolEngine +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage +from graphon.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessage, + PromptMessageTool, + ToolPromptMessage, + UserPromptMessage, +) from models.model import Message logger = logging.getLogger(__name__) +class ActionDict(TypedDict): + """Shape produced by AgentScratchpadUnit.Action.to_dict().""" + + action: str + action_input: dict[str, Any] | str + + class CotAgentRunner(BaseAgentRunner, ABC): _is_first_iteration = True _ignore_observation_providers = ["wenxin"] @@ -331,7 +337,7 @@ class CotAgentRunner(BaseAgentRunner, ABC): return tool_invoke_response, tool_invoke_meta - def _convert_dict_to_action(self, action: dict) -> AgentScratchpadUnit.Action: + def _convert_dict_to_action(self, action: ActionDict) -> AgentScratchpadUnit.Action: """ convert dict to action """ diff --git a/api/core/agent/cot_chat_agent_runner.py b/api/core/agent/cot_chat_agent_runner.py index a4c438e9296..a2186be100c 100644 --- a/api/core/agent/cot_chat_agent_runner.py +++ b/api/core/agent/cot_chat_agent_runner.py @@ -1,5 +1,6 @@ import json +from core.agent.cot_agent_runner import CotAgentRunner from graphon.file import file_manager from graphon.model_runtime.entities import ( AssistantPromptMessage, @@ -11,8 +12,6 @@ from graphon.model_runtime.entities import ( from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes from graphon.model_runtime.utils.encoders import jsonable_encoder -from core.agent.cot_agent_runner import CotAgentRunner - class CotChatAgentRunner(CotAgentRunner): def _organize_system_prompt(self) -> SystemPromptMessage: @@ -79,21 +78,18 @@ class CotChatAgentRunner(CotAgentRunner): if not agent_scratchpad: assistant_messages = [] else: - assistant_message = AssistantPromptMessage(content="") - assistant_message.content = "" # FIXME: type check tell mypy that assistant_message.content is str + content = "" for unit in agent_scratchpad: if unit.is_final(): - assert isinstance(assistant_message.content, str) - assistant_message.content += f"Final Answer: {unit.agent_response}" + content += f"Final Answer: {unit.agent_response}" else: - assert isinstance(assistant_message.content, str) - assistant_message.content += f"Thought: {unit.thought}\n\n" + content += f"Thought: {unit.thought}\n\n" if unit.action_str: - assistant_message.content += f"Action: {unit.action_str}\n\n" + content += f"Action: {unit.action_str}\n\n" if unit.observation: - assistant_message.content += f"Observation: {unit.observation}\n\n" + content += f"Observation: {unit.observation}\n\n" - assistant_messages = [assistant_message] + assistant_messages = [AssistantPromptMessage(content=content)] # query messages query_messages = self._organize_user_query(self._query, []) diff --git a/api/core/agent/cot_completion_agent_runner.py b/api/core/agent/cot_completion_agent_runner.py index d4c52a8eb16..51a30998ae2 100644 --- a/api/core/agent/cot_completion_agent_runner.py +++ b/api/core/agent/cot_completion_agent_runner.py @@ -1,5 +1,6 @@ import json +from core.agent.cot_agent_runner import CotAgentRunner from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, PromptMessage, @@ -8,8 +9,6 @@ from graphon.model_runtime.entities.message_entities import ( ) from graphon.model_runtime.utils.encoders import jsonable_encoder -from core.agent.cot_agent_runner import CotAgentRunner - class CotCompletionAgentRunner(CotAgentRunner): def _organize_instruction_prompt(self) -> str: diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py index fdffde85d01..29de0b8b1c4 100644 --- a/api/core/agent/fc_agent_runner.py +++ b/api/core/agent/fc_agent_runner.py @@ -4,6 +4,13 @@ from collections.abc import Generator from copy import deepcopy from typing import Any, Union +from core.agent.base_agent_runner import BaseAgentRunner +from core.agent.errors import AgentMaxIterationError +from core.app.apps.base_app_queue_manager import PublishFrom +from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent +from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform +from core.tools.entities.tool_entities import ToolInvokeMeta +from core.tools.tool_engine import ToolEngine from graphon.file import file_manager from graphon.model_runtime.entities import ( AssistantPromptMessage, @@ -19,14 +26,6 @@ from graphon.model_runtime.entities import ( UserPromptMessage, ) from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes - -from core.agent.base_agent_runner import BaseAgentRunner -from core.agent.errors import AgentMaxIterationError -from core.app.apps.base_app_queue_manager import PublishFrom -from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent -from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform -from core.tools.entities.tool_entities import ToolInvokeMeta -from core.tools.tool_engine import ToolEngine from models.model import Message logger = logging.getLogger(__name__) @@ -300,7 +299,9 @@ class FunctionCallAgentRunner(BaseAgentRunner): # update prompt tool for prompt_tool in prompt_messages_tools: - self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool) + tool_instance = tool_instances.get(prompt_tool.name) + if tool_instance: + self.update_prompt_message_tool(tool_instance, prompt_tool) iteration_step += 1 diff --git a/api/core/agent/output_parser/cot_output_parser.py b/api/core/agent/output_parser/cot_output_parser.py index 46c1f1230d0..f341ca5a0be 100644 --- a/api/core/agent/output_parser/cot_output_parser.py +++ b/api/core/agent/output_parser/cot_output_parser.py @@ -1,17 +1,16 @@ import json import re from collections.abc import Generator -from typing import Union - -from graphon.model_runtime.entities.llm_entities import LLMResultChunk +from typing import Any, Union from core.agent.entities import AgentScratchpadUnit +from graphon.model_runtime.entities.llm_entities import LLMResultChunk class CotAgentOutputParser: @classmethod def handle_react_stream_output( - cls, llm_response: Generator[LLMResultChunk, None, None], usage_dict: dict + cls, llm_response: Generator[LLMResultChunk, None, None], usage_dict: dict[str, Any] ) -> Generator[Union[str, AgentScratchpadUnit.Action], None, None]: def parse_action(action) -> Union[str, AgentScratchpadUnit.Action]: action_name = None diff --git a/api/core/agent/plugin_entities.py b/api/core/agent/plugin_entities.py index 90aa7b5fd47..8d25863a913 100644 --- a/api/core/agent/plugin_entities.py +++ b/api/core/agent/plugin_entities.py @@ -84,7 +84,7 @@ class AgentStrategyEntity(BaseModel): identity: AgentStrategyIdentity parameters: list[AgentStrategyParameter] = Field(default_factory=list) description: I18nObject = Field(..., description="The description of the agent strategy") - output_schema: dict | None = None + output_schema: dict[str, Any] | None = None features: list[AgentFeature] | None = None meta_version: str | None = None # pydantic configs diff --git a/api/core/app/app_config/common/parameters_mapping/__init__.py b/api/core/app/app_config/common/parameters_mapping/__init__.py index 460fdfb3ba9..68686ceda65 100644 --- a/api/core/app/app_config/common/parameters_mapping/__init__.py +++ b/api/core/app/app_config/common/parameters_mapping/__init__.py @@ -5,6 +5,10 @@ from configs import dify_config from constants import DEFAULT_FILE_NUMBER_LIMITS +class FeatureToggleDict(TypedDict): + enabled: bool + + class SystemParametersDict(TypedDict): image_file_size_limit: int video_file_size_limit: int @@ -16,12 +20,12 @@ class SystemParametersDict(TypedDict): class AppParametersDict(TypedDict): opening_statement: str | None suggested_questions: list[str] - suggested_questions_after_answer: dict[str, Any] - speech_to_text: dict[str, Any] - text_to_speech: dict[str, Any] - retriever_resource: dict[str, Any] - annotation_reply: dict[str, Any] - more_like_this: dict[str, Any] + suggested_questions_after_answer: FeatureToggleDict + speech_to_text: FeatureToggleDict + text_to_speech: FeatureToggleDict + retriever_resource: FeatureToggleDict + annotation_reply: FeatureToggleDict + more_like_this: FeatureToggleDict user_input_form: list[dict[str, Any]] sensitive_word_avoidance: dict[str, Any] file_upload: dict[str, Any] diff --git a/api/core/app/app_config/common/sensitive_word_avoidance/manager.py b/api/core/app/app_config/common/sensitive_word_avoidance/manager.py index 7d1b11c0081..c8ec7cb44dc 100644 --- a/api/core/app/app_config/common/sensitive_word_avoidance/manager.py +++ b/api/core/app/app_config/common/sensitive_word_avoidance/manager.py @@ -22,8 +22,8 @@ class SensitiveWordAvoidanceConfigManager: @classmethod def validate_and_set_defaults( - cls, tenant_id: str, config: dict, only_structure_validate: bool = False - ) -> tuple[dict, list[str]]: + cls, tenant_id: str, config: dict[str, Any], only_structure_validate: bool = False + ) -> tuple[dict[str, Any], list[str]]: if not config.get("sensitive_word_avoidance"): config["sensitive_word_avoidance"] = {"enabled": False} diff --git a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py index f04a8df1196..3d857a4e9c0 100644 --- a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py @@ -138,7 +138,9 @@ class DatasetConfigManager: ) @classmethod - def validate_and_set_defaults(cls, tenant_id: str, app_mode: AppMode, config: dict) -> tuple[dict, list[str]]: + def validate_and_set_defaults( + cls, tenant_id: str, app_mode: AppMode, config: dict[str, Any] + ) -> tuple[dict[str, Any], list[str]]: """ Validate and set defaults for dataset feature @@ -172,7 +174,7 @@ class DatasetConfigManager: return config, ["agent_mode", "dataset_configs", "dataset_query_variable"] @classmethod - def extract_dataset_config_for_legacy_compatibility(cls, tenant_id: str, app_mode: AppMode, config: dict): + def extract_dataset_config_for_legacy_compatibility(cls, tenant_id: str, app_mode: AppMode, config: dict[str, Any]): """ Extract dataset config for legacy compatibility diff --git a/api/core/app/app_config/easy_ui_based_app/model_config/converter.py b/api/core/app/app_config/easy_ui_based_app/model_config/converter.py index b7dd55632e2..5df3df2b3e0 100644 --- a/api/core/app/app_config/easy_ui_based_app/model_config/converter.py +++ b/api/core/app/app_config/easy_ui_based_app/model_config/converter.py @@ -1,14 +1,13 @@ from typing import cast -from graphon.model_runtime.entities.llm_entities import LLMMode -from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel - from core.app.app_config.entities import EasyUIBasedAppConfig from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.entities.model_entities import ModelStatus from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager +from graphon.model_runtime.entities.llm_entities import LLMMode +from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType +from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel class ModelConfigConverter: diff --git a/api/core/app/app_config/easy_ui_based_app/model_config/manager.py b/api/core/app/app_config/easy_ui_based_app/model_config/manager.py index 5cc385c3781..02498c23e1d 100644 --- a/api/core/app/app_config/easy_ui_based_app/model_config/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/model_config/manager.py @@ -1,10 +1,9 @@ from collections.abc import Mapping from typing import Any -from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType - from core.app.app_config.entities import ModelConfigEntity from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly +from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType from models.model import AppModelConfigDict from models.provider_ids import ModelProviderID @@ -41,7 +40,7 @@ class ModelConfigManager: ) @classmethod - def validate_and_set_defaults(cls, tenant_id: str, config: Mapping[str, Any]) -> tuple[dict, list[str]]: + def validate_and_set_defaults(cls, tenant_id: str, config: Mapping[str, Any]) -> tuple[dict[str, Any], list[str]]: """ Validate and set defaults for model config @@ -108,7 +107,7 @@ class ModelConfigManager: return dict(config), ["model"] @classmethod - def validate_model_completion_params(cls, cp: dict): + def validate_model_completion_params(cls, cp: dict[str, Any]): # model.completion_params if not isinstance(cp, dict): raise ValueError("model.completion_params must be of object type") diff --git a/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py b/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py index 76196e7034e..4c07445df3d 100644 --- a/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py @@ -1,7 +1,5 @@ from typing import Any -from graphon.model_runtime.entities.message_entities import PromptMessageRole - from core.app.app_config.entities import ( AdvancedChatMessageEntity, AdvancedChatPromptTemplateEntity, @@ -9,6 +7,7 @@ from core.app.app_config.entities import ( PromptTemplateEntity, ) from core.prompt.simple_prompt_transform import ModelMode +from graphon.model_runtime.entities.message_entities import PromptMessageRole from models.model import AppMode, AppModelConfigDict @@ -65,7 +64,7 @@ class PromptTemplateConfigManager: ) @classmethod - def validate_and_set_defaults(cls, app_mode: AppMode, config: dict) -> tuple[dict, list[str]]: + def validate_and_set_defaults(cls, app_mode: AppMode, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]: """ Validate pre_prompt and set defaults for prompt feature depending on the config['model'] @@ -130,7 +129,7 @@ class PromptTemplateConfigManager: return config, ["prompt_type", "pre_prompt", "chat_prompt_config", "completion_prompt_config"] @classmethod - def validate_post_prompt_and_set_defaults(cls, config: dict): + def validate_post_prompt_and_set_defaults(cls, config: dict[str, Any]): """ Validate post_prompt and set defaults for prompt feature diff --git a/api/core/app/app_config/easy_ui_based_app/variables/manager.py b/api/core/app/app_config/easy_ui_based_app/variables/manager.py index f0b71c58016..ddb500cccf5 100644 --- a/api/core/app/app_config/easy_ui_based_app/variables/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/variables/manager.py @@ -1,10 +1,9 @@ import re -from typing import cast - -from graphon.variables.input_entities import VariableEntity, VariableEntityType +from typing import Any, cast from core.app.app_config.entities import ExternalDataVariableEntity from core.external_data_tool.factory import ExternalDataToolFactory +from graphon.variables.input_entities import VariableEntity, VariableEntityType from models.model import AppModelConfigDict _ALLOWED_VARIABLE_ENTITY_TYPE = frozenset( @@ -82,7 +81,7 @@ class BasicVariablesConfigManager: return variable_entities, external_data_variables @classmethod - def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]: + def validate_and_set_defaults(cls, tenant_id: str, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]: """ Validate and set defaults for user input form @@ -99,7 +98,7 @@ class BasicVariablesConfigManager: return config, related_config_keys @classmethod - def validate_variables_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]: + def validate_variables_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]: """ Validate and set defaults for user input form @@ -164,7 +163,9 @@ class BasicVariablesConfigManager: return config, ["user_input_form"] @classmethod - def validate_external_data_tools_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]: + def validate_external_data_tools_and_set_defaults( + cls, tenant_id: str, config: dict[str, Any] + ) -> tuple[dict[str, Any], list[str]]: """ Validate and set defaults for external data fetch feature diff --git a/api/core/app/app_config/entities.py b/api/core/app/app_config/entities.py index 536617edba4..53563dc5dab 100644 --- a/api/core/app/app_config/entities.py +++ b/api/core/app/app_config/entities.py @@ -1,14 +1,14 @@ -from collections.abc import Sequence from enum import StrEnum, auto from typing import Any, Literal +from pydantic import BaseModel, Field + +from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict +from core.rag.entities import MetadataFilteringCondition from graphon.file import FileUploadConfig from graphon.model_runtime.entities.llm_entities import LLMMode from graphon.model_runtime.entities.message_entities import PromptMessageRole from graphon.variables.input_entities import VariableEntity as WorkflowVariableEntity -from pydantic import BaseModel, Field - -from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict from models.model import AppMode @@ -111,31 +111,6 @@ class ExternalDataVariableEntity(BaseModel): config: dict[str, Any] = Field(default_factory=dict) -SupportedComparisonOperator = Literal[ - # for string or array - "contains", - "not contains", - "start with", - "end with", - "is", - "is not", - "empty", - "not empty", - "in", - "not in", - # for number - "=", - "≠", - ">", - "<", - "≥", - "≤", - # for time - "before", - "after", -] - - class ModelConfig(BaseModel): provider: str name: str @@ -143,25 +118,6 @@ class ModelConfig(BaseModel): completion_params: dict[str, Any] = Field(default_factory=dict) -class Condition(BaseModel): - """ - Condition detail - """ - - name: str - comparison_operator: SupportedComparisonOperator - value: str | Sequence[str] | None | int | float = None - - -class MetadataFilteringCondition(BaseModel): - """ - Metadata Filtering Condition. - """ - - logical_operator: Literal["and", "or"] | None = "and" - conditions: list[Condition] | None = Field(default=None, deprecated=True) - - class DatasetRetrieveConfigEntity(BaseModel): """ Dataset Retrieve Config Entity. diff --git a/api/core/app/app_config/features/file_upload/manager.py b/api/core/app/app_config/features/file_upload/manager.py index e96517c4264..8f20ef2ff96 100644 --- a/api/core/app/app_config/features/file_upload/manager.py +++ b/api/core/app/app_config/features/file_upload/manager.py @@ -1,9 +1,8 @@ from collections.abc import Mapping from typing import Any -from graphon.file import FileUploadConfig - from constants import DEFAULT_FILE_NUMBER_LIMITS +from graphon.file import FileUploadConfig class FileUploadConfigManager: @@ -30,7 +29,7 @@ class FileUploadConfigManager: return FileUploadConfig.model_validate(file_upload_dict) @classmethod - def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]: + def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]: """ Validate and set defaults for file upload feature diff --git a/api/core/app/app_config/features/more_like_this/manager.py b/api/core/app/app_config/features/more_like_this/manager.py index ef71bb348a0..b167c04ab59 100644 --- a/api/core/app/app_config/features/more_like_this/manager.py +++ b/api/core/app/app_config/features/more_like_this/manager.py @@ -1,3 +1,5 @@ +from typing import Any + from pydantic import BaseModel, ConfigDict, Field, ValidationError @@ -13,7 +15,7 @@ class AppConfigModel(BaseModel): class MoreLikeThisConfigManager: @classmethod - def convert(cls, config: dict) -> bool: + def convert(cls, config: dict[str, Any]) -> bool: """ Convert model config to model config @@ -23,7 +25,7 @@ class MoreLikeThisConfigManager: return AppConfigModel.model_validate(validated_config).more_like_this.enabled @classmethod - def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]: + def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]: try: return AppConfigModel.model_validate(config).model_dump(), ["more_like_this"] except ValidationError: diff --git a/api/core/app/app_config/features/opening_statement/manager.py b/api/core/app/app_config/features/opening_statement/manager.py index 92b4185abf0..33f5aec1837 100644 --- a/api/core/app/app_config/features/opening_statement/manager.py +++ b/api/core/app/app_config/features/opening_statement/manager.py @@ -1,6 +1,9 @@ +from typing import Any + + class OpeningStatementConfigManager: @classmethod - def convert(cls, config: dict) -> tuple[str, list]: + def convert(cls, config: dict[str, Any]) -> tuple[str, list[str]]: """ Convert model config to model config @@ -15,7 +18,7 @@ class OpeningStatementConfigManager: return opening_statement, suggested_questions_list @classmethod - def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]: + def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]: """ Validate and set defaults for opening statement feature diff --git a/api/core/app/app_config/features/retrieval_resource/manager.py b/api/core/app/app_config/features/retrieval_resource/manager.py index d098abac2fa..8157fb41db9 100644 --- a/api/core/app/app_config/features/retrieval_resource/manager.py +++ b/api/core/app/app_config/features/retrieval_resource/manager.py @@ -1,6 +1,9 @@ +from typing import Any + + class RetrievalResourceConfigManager: @classmethod - def convert(cls, config: dict) -> bool: + def convert(cls, config: dict[str, Any]) -> bool: show_retrieve_source = False retriever_resource_dict = config.get("retriever_resource") if retriever_resource_dict: @@ -10,7 +13,7 @@ class RetrievalResourceConfigManager: return show_retrieve_source @classmethod - def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]: + def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]: """ Validate and set defaults for retriever resource feature diff --git a/api/core/app/app_config/features/speech_to_text/manager.py b/api/core/app/app_config/features/speech_to_text/manager.py index e10ae03e043..679b8c343be 100644 --- a/api/core/app/app_config/features/speech_to_text/manager.py +++ b/api/core/app/app_config/features/speech_to_text/manager.py @@ -1,6 +1,9 @@ +from typing import Any + + class SpeechToTextConfigManager: @classmethod - def convert(cls, config: dict) -> bool: + def convert(cls, config: dict[str, Any]) -> bool: """ Convert model config to model config @@ -15,7 +18,7 @@ class SpeechToTextConfigManager: return speech_to_text @classmethod - def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]: + def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]: """ Validate and set defaults for speech to text feature diff --git a/api/core/app/app_config/features/suggested_questions_after_answer/manager.py b/api/core/app/app_config/features/suggested_questions_after_answer/manager.py index 9ac5114d12d..2dddce349c0 100644 --- a/api/core/app/app_config/features/suggested_questions_after_answer/manager.py +++ b/api/core/app/app_config/features/suggested_questions_after_answer/manager.py @@ -1,6 +1,9 @@ +from typing import Any + + class SuggestedQuestionsAfterAnswerConfigManager: @classmethod - def convert(cls, config: dict) -> bool: + def convert(cls, config: dict[str, Any]) -> bool: """ Convert model config to model config @@ -15,7 +18,7 @@ class SuggestedQuestionsAfterAnswerConfigManager: return suggested_questions_after_answer @classmethod - def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]: + def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]: """ Validate and set defaults for suggested questions feature diff --git a/api/core/app/app_config/features/text_to_speech/manager.py b/api/core/app/app_config/features/text_to_speech/manager.py index 1c759817852..ca84ec9c3b5 100644 --- a/api/core/app/app_config/features/text_to_speech/manager.py +++ b/api/core/app/app_config/features/text_to_speech/manager.py @@ -1,9 +1,11 @@ +from typing import Any + from core.app.app_config.entities import TextToSpeechEntity class TextToSpeechConfigManager: @classmethod - def convert(cls, config: dict): + def convert(cls, config: dict[str, Any]): """ Convert model config to model config @@ -22,7 +24,7 @@ class TextToSpeechConfigManager: return text_to_speech @classmethod - def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]: + def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]: """ Validate and set defaults for text to speech feature diff --git a/api/core/app/app_config/workflow_ui_based_app/variables/manager.py b/api/core/app/app_config/workflow_ui_based_app/variables/manager.py index 62e0c31d1ae..13ace32fd60 100644 --- a/api/core/app/app_config/workflow_ui_based_app/variables/manager.py +++ b/api/core/app/app_config/workflow_ui_based_app/variables/manager.py @@ -1,8 +1,7 @@ import re -from graphon.variables.input_entities import VariableEntity - from core.app.app_config.entities import RagPipelineVariableEntity +from graphon.variables.input_entities import VariableEntity from models.workflow import Workflow diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index 985ded0f74d..9e64b471cb9 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -18,11 +18,6 @@ from constants import UUID_NIL if TYPE_CHECKING: from controllers.console.app.workflow import LoopNodeRunPayload -from graphon.graph_engine.layers import GraphEngineLayer -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError -from graphon.runtime import GraphRuntimeState -from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader - from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner @@ -48,6 +43,10 @@ from core.repositories import DifyCoreRepositoryFactory from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository from extensions.ext_database import db from factories import file_factory +from graphon.graph_engine.layers import GraphEngineLayer +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError +from graphon.runtime import GraphRuntimeState +from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader from libs.flask_utils import preserve_flask_contexts from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom from models.enums import WorkflowRunTriggeredFrom diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index a884a1c7f9b..4e57b4dedc5 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -3,14 +3,8 @@ import time from collections.abc import Mapping, Sequence from typing import Any, cast -from graphon.enums import WorkflowType -from graphon.graph_engine.command_channels import RedisChannel -from graphon.graph_engine.layers import GraphEngineLayer -from graphon.runtime import GraphRuntimeState, VariablePool -from graphon.variable_loader import VariableLoader -from graphon.variables.variables import Variable from sqlalchemy import select -from sqlalchemy.orm import Session +from sqlalchemy.orm import Session, sessionmaker from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig from core.app.apps.base_app_queue_manager import AppQueueManager @@ -43,6 +37,12 @@ from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_database import db from extensions.ext_redis import redis_client from extensions.otel import WorkflowAppRunnerHandler, trace_span +from graphon.enums import WorkflowType +from graphon.graph_engine.command_channels import RedisChannel +from graphon.graph_engine.layers import GraphEngineLayer +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.variable_loader import VariableLoader +from graphon.variables.variables import Variable from models import Workflow from models.model import App, Conversation, Message, MessageAnnotation from models.workflow import ConversationVariable @@ -363,7 +363,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): :return: List of conversation variables ready for use """ - with Session(db.engine) as session: + with sessionmaker(bind=db.engine).begin() as session: existing_variables = self._load_existing_conversation_variables(session) if not existing_variables: @@ -376,7 +376,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): # Convert to Variable objects for use in the workflow conversation_variables = [var.to_variable() for var in existing_variables] - session.commit() return cast(list[Variable], conversation_variables) def _load_existing_conversation_variables(self, session: Session) -> list[ConversationVariable]: diff --git a/api/core/app/apps/advanced_chat/generate_response_converter.py b/api/core/app/apps/advanced_chat/generate_response_converter.py index 5c9bc439920..fe2702ed69a 100644 --- a/api/core/app/apps/advanced_chat/generate_response_converter.py +++ b/api/core/app/apps/advanced_chat/generate_response_converter.py @@ -57,7 +57,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter): @classmethod def convert_stream_full_response( cls, stream_response: Generator[AppStreamResponse, None, None] - ) -> Generator[dict | str, Any, None]: + ) -> Generator[dict[str, Any] | str, Any, None]: """ Convert stream full response. :param stream_response: stream response @@ -88,7 +88,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter): @classmethod def convert_stream_simple_response( cls, stream_response: Generator[AppStreamResponse, None, None] - ) -> Generator[dict | str, Any, None]: + ) -> Generator[dict[str, Any] | str, Any, None]: """ Convert stream simple response. :param stream_response: stream response diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 5203de225cc..78b582bdf5f 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -9,14 +9,8 @@ from datetime import datetime from threading import Thread from typing import Any, Union -from graphon.entities.pause_reason import HumanInputRequired -from graphon.enums import WorkflowExecutionStatus -from graphon.model_runtime.entities.llm_entities import LLMUsage -from graphon.model_runtime.utils.encoders import jsonable_encoder -from graphon.nodes import BuiltinNodeTypes -from graphon.runtime import GraphRuntimeState from sqlalchemy import select -from sqlalchemy.orm import Session +from sqlalchemy.orm import Session, sessionmaker from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom @@ -77,6 +71,12 @@ from core.repositories.human_input_repository import HumanInputFormRepositoryImp from core.workflow.file_reference import resolve_file_record_id from core.workflow.system_variables import build_system_variables from extensions.ext_database import db +from graphon.entities.pause_reason import HumanInputRequired +from graphon.enums import WorkflowExecutionStatus +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.model_runtime.utils.encoders import jsonable_encoder +from graphon.nodes import BuiltinNodeTypes +from graphon.runtime import GraphRuntimeState from libs.datetime_utils import naive_utc_now from models import Account, Conversation, EndUser, Message, MessageFile from models.enums import CreatorUserRole, MessageFileBelongsTo, MessageStatus @@ -328,13 +328,8 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): @contextmanager def _database_session(self): """Context manager for database sessions.""" - with Session(db.engine, expire_on_commit=False) as session: - try: - yield session - session.commit() - except Exception: - session.rollback() - raise + with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session: + yield session def _ensure_workflow_initialized(self): """Fluent validation for workflow state.""" diff --git a/api/core/app/apps/agent_chat/app_generator.py b/api/core/app/apps/agent_chat/app_generator.py index 5872f6b2647..5cdc477028c 100644 --- a/api/core/app/apps/agent_chat/app_generator.py +++ b/api/core/app/apps/agent_chat/app_generator.py @@ -6,7 +6,6 @@ from collections.abc import Generator, Mapping from typing import Any, Literal, overload from flask import Flask, current_app -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from pydantic import ValidationError from configs import dify_config @@ -24,6 +23,7 @@ from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, In from core.ops.ops_trace_manager import TraceQueueManager from extensions.ext_database import db from factories import file_factory +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from libs.flask_utils import preserve_flask_contexts from models import Account, App, EndUser from services.conversation_service import ConversationService diff --git a/api/core/app/apps/agent_chat/app_runner.py b/api/core/app/apps/agent_chat/app_runner.py index a20d3f3c38f..cae0eee0df0 100644 --- a/api/core/app/apps/agent_chat/app_runner.py +++ b/api/core/app/apps/agent_chat/app_runner.py @@ -1,9 +1,6 @@ import logging from typing import cast -from graphon.model_runtime.entities.llm_entities import LLMMode -from graphon.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from sqlalchemy import select from core.agent.cot_chat_agent_runner import CotChatAgentRunner @@ -19,6 +16,9 @@ from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.moderation.base import ModerationError from extensions.ext_database import db +from graphon.model_runtime.entities.llm_entities import LLMMode +from graphon.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey +from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel from models.model import App, Conversation, Message logger = logging.getLogger(__name__) diff --git a/api/core/app/apps/agent_chat/generate_response_converter.py b/api/core/app/apps/agent_chat/generate_response_converter.py index 0c146c388fd..731c6ee12e6 100644 --- a/api/core/app/apps/agent_chat/generate_response_converter.py +++ b/api/core/app/apps/agent_chat/generate_response_converter.py @@ -1,5 +1,5 @@ from collections.abc import Generator -from typing import cast +from typing import Any, cast from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter from core.app.entities.task_entities import ( @@ -56,7 +56,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter): @classmethod def convert_stream_full_response( cls, stream_response: Generator[AppStreamResponse, None, None] - ) -> Generator[dict | str, None, None]: + ) -> Generator[dict[str, Any] | str, None, None]: """ Convert stream full response. :param stream_response: stream response @@ -87,7 +87,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter): @classmethod def convert_stream_simple_response( cls, stream_response: Generator[AppStreamResponse, None, None] - ) -> Generator[dict | str, None, None]: + ) -> Generator[dict[str, Any] | str, None, None]: """ Convert stream simple response. :param stream_response: stream response diff --git a/api/core/app/apps/base_app_generate_response_converter.py b/api/core/app/apps/base_app_generate_response_converter.py index 66390116d46..d5edfaeb25c 100644 --- a/api/core/app/apps/base_app_generate_response_converter.py +++ b/api/core/app/apps/base_app_generate_response_converter.py @@ -3,11 +3,10 @@ from abc import ABC, abstractmethod from collections.abc import Generator, Mapping from typing import Any, Union -from graphon.model_runtime.errors.invoke import InvokeError - from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.task_entities import AppBlockingResponse, AppStreamResponse from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError +from graphon.model_runtime.errors.invoke import InvokeError logger = logging.getLogger(__name__) @@ -24,7 +23,7 @@ class AppGenerateResponseConverter(ABC): return cls.convert_blocking_full_response(response) else: - def _generate_full_response() -> Generator[dict | str, Any, None]: + def _generate_full_response() -> Generator[dict[str, Any] | str, Any, None]: yield from cls.convert_stream_full_response(response) return _generate_full_response() @@ -33,7 +32,7 @@ class AppGenerateResponseConverter(ABC): return cls.convert_blocking_simple_response(response) else: - def _generate_simple_response() -> Generator[dict | str, Any, None]: + def _generate_simple_response() -> Generator[dict[str, Any] | str, Any, None]: yield from cls.convert_stream_simple_response(response) return _generate_simple_response() @@ -52,14 +51,14 @@ class AppGenerateResponseConverter(ABC): @abstractmethod def convert_stream_full_response( cls, stream_response: Generator[AppStreamResponse, None, None] - ) -> Generator[dict | str, None, None]: + ) -> Generator[dict[str, Any] | str, None, None]: raise NotImplementedError @classmethod @abstractmethod def convert_stream_simple_response( cls, stream_response: Generator[AppStreamResponse, None, None] - ) -> Generator[dict | str, None, None]: + ) -> Generator[dict[str, Any] | str, None, None]: raise NotImplementedError @classmethod @@ -107,13 +106,13 @@ class AppGenerateResponseConverter(ABC): return metadata @classmethod - def _error_to_stream_response(cls, e: Exception): + def _error_to_stream_response(cls, e: Exception) -> dict[str, Any]: """ Error to stream response. :param e: exception :return: """ - error_responses = { + error_responses: dict[type[Exception], dict[str, Any]] = { ValueError: {"code": "invalid_param", "status": 400}, ProviderTokenNotInitError: {"code": "provider_not_initialize", "status": 400}, QuotaExceededError: { @@ -127,7 +126,7 @@ class AppGenerateResponseConverter(ABC): } # Determine the response based on the type of exception - data = None + data: dict[str, Any] | None = None for k, v in error_responses.items(): if isinstance(e, k): data = v diff --git a/api/core/app/apps/base_app_generator.py b/api/core/app/apps/base_app_generator.py index 7eccd59d17c..8e8ccf2b903 100644 --- a/api/core/app/apps/base_app_generator.py +++ b/api/core/app/apps/base_app_generator.py @@ -2,9 +2,6 @@ from collections.abc import Generator, Mapping, Sequence from contextlib import AbstractContextManager, nullcontext from typing import TYPE_CHECKING, Any, Union, final -from graphon.enums import NodeType -from graphon.file import File, FileUploadConfig -from graphon.variables.input_entities import VariableEntityType from sqlalchemy.orm import Session from core.app.apps.draft_variable_saver import ( @@ -16,6 +13,9 @@ from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.app.file_access import DatabaseFileAccessController, FileAccessScope, bind_file_access_scope from extensions.ext_database import db from factories import file_factory +from graphon.enums import NodeType +from graphon.file import File, FileUploadConfig +from graphon.variables.input_entities import VariableEntityType from libs.orjson import orjson_dumps from models import Account, EndUser from services.workflow_draft_variable_service import DraftVariableSaver as DraftVariableSaverImpl diff --git a/api/core/app/apps/base_app_queue_manager.py b/api/core/app/apps/base_app_queue_manager.py index 20bf81aeecf..d1771452c5f 100644 --- a/api/core/app/apps/base_app_queue_manager.py +++ b/api/core/app/apps/base_app_queue_manager.py @@ -7,7 +7,6 @@ from enum import IntEnum, auto from typing import Any from cachetools import TTLCache, cachedmethod -from graphon.runtime import GraphRuntimeState from redis.exceptions import RedisError from sqlalchemy.orm import DeclarativeMeta @@ -22,6 +21,7 @@ from core.app.entities.queue_entities import ( WorkflowQueueMessage, ) from extensions.ext_redis import redis_client +from graphon.runtime import GraphRuntimeState logger = logging.getLogger(__name__) diff --git a/api/core/app/apps/base_app_runner.py b/api/core/app/apps/base_app_runner.py index 4aebc0cb30e..1251b397e27 100644 --- a/api/core/app/apps/base_app_runner.py +++ b/api/core/app/apps/base_app_runner.py @@ -5,17 +5,6 @@ from collections.abc import Generator, Mapping, Sequence from mimetypes import guess_extension from typing import TYPE_CHECKING, Any, Union -from graphon.file import FileTransferMethod, FileType -from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage -from graphon.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - ImagePromptMessageContent, - PromptMessage, - TextPromptMessageContent, -) -from graphon.model_runtime.entities.model_entities import ModelPropertyKey -from graphon.model_runtime.errors.invoke import InvokeBadRequestError - from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.entities.app_invoke_entities import ( @@ -41,6 +30,16 @@ from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, Comp from core.prompt.simple_prompt_transform import ModelMode, SimplePromptTransform from core.tools.tool_file_manager import ToolFileManager from extensions.ext_database import db +from graphon.file import FileTransferMethod, FileType +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage +from graphon.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + ImagePromptMessageContent, + PromptMessage, + TextPromptMessageContent, +) +from graphon.model_runtime.entities.model_entities import ModelPropertyKey +from graphon.model_runtime.errors.invoke import InvokeBadRequestError from models.enums import CreatorUserRole, MessageFileBelongsTo from models.model import App, AppMode, Message, MessageAnnotation, MessageFile diff --git a/api/core/app/apps/chat/app_generator.py b/api/core/app/apps/chat/app_generator.py index 891dcece73d..58afefe2968 100644 --- a/api/core/app/apps/chat/app_generator.py +++ b/api/core/app/apps/chat/app_generator.py @@ -6,7 +6,6 @@ from collections.abc import Generator, Mapping from typing import Any, Literal, overload from flask import Flask, copy_current_request_context, current_app -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from pydantic import ValidationError from configs import dify_config @@ -24,6 +23,7 @@ from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeF from core.ops.ops_trace_manager import TraceQueueManager from extensions.ext_database import db from factories import file_factory +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from models import Account from models.model import App, EndUser from services.conversation_service import ConversationService diff --git a/api/core/app/apps/chat/app_runner.py b/api/core/app/apps/chat/app_runner.py index 050f763e958..077c5239f39 100644 --- a/api/core/app/apps/chat/app_runner.py +++ b/api/core/app/apps/chat/app_runner.py @@ -1,8 +1,6 @@ import logging from typing import cast -from graphon.file import File -from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent from sqlalchemy import select from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom @@ -18,6 +16,8 @@ from core.model_manager import ModelInstance from core.moderation.base import ModerationError from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from extensions.ext_database import db +from graphon.file import File +from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent from models.model import App, Conversation, Message logger = logging.getLogger(__name__) diff --git a/api/core/app/apps/chat/generate_response_converter.py b/api/core/app/apps/chat/generate_response_converter.py index f23ee7f89fd..3d0375151dc 100644 --- a/api/core/app/apps/chat/generate_response_converter.py +++ b/api/core/app/apps/chat/generate_response_converter.py @@ -1,5 +1,5 @@ from collections.abc import Generator -from typing import cast +from typing import Any, cast from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter from core.app.entities.task_entities import ( @@ -56,7 +56,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter): @classmethod def convert_stream_full_response( cls, stream_response: Generator[AppStreamResponse, None, None] - ) -> Generator[dict | str, None, None]: + ) -> Generator[dict[str, Any] | str, None, None]: """ Convert stream full response. :param stream_response: stream response @@ -87,7 +87,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter): @classmethod def convert_stream_simple_response( cls, stream_response: Generator[AppStreamResponse, None, None] - ) -> Generator[dict | str, None, None]: + ) -> Generator[dict[str, Any] | str, None, None]: """ Convert stream simple response. :param stream_response: stream response diff --git a/api/core/app/apps/common/graph_runtime_state_support.py b/api/core/app/apps/common/graph_runtime_state_support.py index ab277857fe5..2a90fbdad0e 100644 --- a/api/core/app/apps/common/graph_runtime_state_support.py +++ b/api/core/app/apps/common/graph_runtime_state_support.py @@ -4,9 +4,8 @@ from __future__ import annotations from typing import TYPE_CHECKING -from graphon.runtime import GraphRuntimeState - from core.workflow.system_variables import SystemVariableKey, get_system_text +from graphon.runtime import GraphRuntimeState if TYPE_CHECKING: from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline diff --git a/api/core/app/apps/common/workflow_response_converter.py b/api/core/app/apps/common/workflow_response_converter.py index a5155316163..bd685d51894 100644 --- a/api/core/app/apps/common/workflow_response_converter.py +++ b/api/core/app/apps/common/workflow_response_converter.py @@ -6,19 +6,6 @@ from dataclasses import dataclass from datetime import datetime from typing import Any, NewType, TypedDict, Union -from graphon.entities import WorkflowStartReason -from graphon.entities.pause_reason import HumanInputRequired -from graphon.enums import ( - BuiltinNodeTypes, - WorkflowExecutionStatus, - WorkflowNodeExecutionMetadataKey, - WorkflowNodeExecutionStatus, -) -from graphon.file import FILE_MODEL_IDENTITY, File -from graphon.runtime import GraphRuntimeState -from graphon.variables.segments import ArrayFileSegment, FileSegment, Segment -from graphon.variables.variables import Variable -from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from sqlalchemy import select from sqlalchemy.orm import Session @@ -68,6 +55,19 @@ from core.workflow.human_input_forms import load_form_tokens_by_form_id from core.workflow.system_variables import SystemVariableKey, system_variables_to_mapping from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_database import db +from graphon.entities import WorkflowStartReason +from graphon.entities.pause_reason import HumanInputRequired +from graphon.enums import ( + BuiltinNodeTypes, + WorkflowExecutionStatus, + WorkflowNodeExecutionMetadataKey, + WorkflowNodeExecutionStatus, +) +from graphon.file import FILE_MODEL_IDENTITY, File +from graphon.runtime import GraphRuntimeState +from graphon.variables.segments import ArrayFileSegment, FileSegment, Segment +from graphon.variables.variables import Variable +from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from libs.datetime_utils import naive_utc_now from models import Account, EndUser from models.human_input import HumanInputForm diff --git a/api/core/app/apps/completion/app_generator.py b/api/core/app/apps/completion/app_generator.py index 61339b316ae..423bfdac519 100644 --- a/api/core/app/apps/completion/app_generator.py +++ b/api/core/app/apps/completion/app_generator.py @@ -6,7 +6,6 @@ from collections.abc import Generator, Mapping from typing import Any, Literal, overload from flask import Flask, copy_current_request_context, current_app -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from pydantic import ValidationError from sqlalchemy import select @@ -24,6 +23,7 @@ from core.app.entities.app_invoke_entities import CompletionAppGenerateEntity, I from core.ops.ops_trace_manager import TraceQueueManager from extensions.ext_database import db from factories import file_factory +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from models import Account, App, EndUser, Message from services.errors.app import MoreLikeThisDisabledError from services.errors.message import MessageNotExistsError diff --git a/api/core/app/apps/completion/app_runner.py b/api/core/app/apps/completion/app_runner.py index b216f7cf7b1..6bb1ecdcb19 100644 --- a/api/core/app/apps/completion/app_runner.py +++ b/api/core/app/apps/completion/app_runner.py @@ -1,8 +1,6 @@ import logging from typing import cast -from graphon.file import File -from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent from sqlalchemy import select from core.app.apps.base_app_queue_manager import AppQueueManager @@ -16,6 +14,8 @@ from core.model_manager import ModelInstance from core.moderation.base import ModerationError from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from extensions.ext_database import db +from graphon.file import File +from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent from models.model import App, Message logger = logging.getLogger(__name__) diff --git a/api/core/app/apps/completion/generate_response_converter.py b/api/core/app/apps/completion/generate_response_converter.py index a4f574642d6..71886b39ba4 100644 --- a/api/core/app/apps/completion/generate_response_converter.py +++ b/api/core/app/apps/completion/generate_response_converter.py @@ -1,5 +1,5 @@ from collections.abc import Generator -from typing import cast +from typing import Any, cast from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter from core.app.entities.task_entities import ( @@ -55,7 +55,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter): @classmethod def convert_stream_full_response( cls, stream_response: Generator[AppStreamResponse, None, None] - ) -> Generator[dict | str, None, None]: + ) -> Generator[dict[str, Any] | str, None, None]: """ Convert stream full response. :param stream_response: stream response @@ -85,7 +85,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter): @classmethod def convert_stream_simple_response( cls, stream_response: Generator[AppStreamResponse, None, None] - ) -> Generator[dict | str, None, None]: + ) -> Generator[dict[str, Any] | str, None, None]: """ Convert stream simple response. :param stream_response: stream response diff --git a/api/core/app/apps/pipeline/generate_response_converter.py b/api/core/app/apps/pipeline/generate_response_converter.py index cfacd8640d6..02b3160b7c2 100644 --- a/api/core/app/apps/pipeline/generate_response_converter.py +++ b/api/core/app/apps/pipeline/generate_response_converter.py @@ -1,5 +1,5 @@ from collections.abc import Generator -from typing import cast +from typing import Any, cast from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter from core.app.entities.task_entities import ( @@ -17,7 +17,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter): _blocking_response_type = WorkflowAppBlockingResponse @classmethod - def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override] + def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, Any]: # type: ignore[override] """ Convert blocking full response. :param blocking_response: blocking response @@ -26,7 +26,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter): return dict(blocking_response.model_dump()) @classmethod - def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override] + def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, Any]: # type: ignore[override] """ Convert blocking simple response. :param blocking_response: blocking response @@ -37,7 +37,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter): @classmethod def convert_stream_full_response( cls, stream_response: Generator[AppStreamResponse, None, None] - ) -> Generator[dict | str, None, None]: + ) -> Generator[dict[str, Any] | str, None, None]: """ Convert stream full response. :param stream_response: stream response @@ -66,7 +66,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter): @classmethod def convert_stream_simple_response( cls, stream_response: Generator[AppStreamResponse, None, None] - ) -> Generator[dict | str, None, None]: + ) -> Generator[dict[str, Any] | str, None, None]: """ Convert stream simple response. :param stream_response: stream response diff --git a/api/core/app/apps/pipeline/pipeline_config_manager.py b/api/core/app/apps/pipeline/pipeline_config_manager.py index 72b7f4bef6f..8bbd7455382 100644 --- a/api/core/app/apps/pipeline/pipeline_config_manager.py +++ b/api/core/app/apps/pipeline/pipeline_config_manager.py @@ -1,3 +1,5 @@ +from typing import Any + from core.app.app_config.base_app_config_manager import BaseAppConfigManager from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager from core.app.app_config.entities import RagPipelineVariableEntity, WorkflowUIBasedAppConfig @@ -34,7 +36,9 @@ class PipelineConfigManager(BaseAppConfigManager): return pipeline_config @classmethod - def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False) -> dict: + def config_validate( + cls, tenant_id: str, config: dict[str, Any], only_structure_validate: bool = False + ) -> dict[str, Any]: """ Validate for pipeline config diff --git a/api/core/app/apps/pipeline/pipeline_generator.py b/api/core/app/apps/pipeline/pipeline_generator.py index 139c7e73e00..4b2f17189bc 100644 --- a/api/core/app/apps/pipeline/pipeline_generator.py +++ b/api/core/app/apps/pipeline/pipeline_generator.py @@ -10,8 +10,6 @@ from collections.abc import Generator, Mapping from typing import Any, Literal, cast, overload from flask import Flask, current_app -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError -from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader from pydantic import ValidationError from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker @@ -43,6 +41,8 @@ from core.repositories.factory import ( WorkflowNodeExecutionRepository, ) from extensions.ext_database import db +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError +from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader from libs.flask_utils import preserve_flask_contexts from models import Account, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom from models.dataset import Document, DocumentPipelineExecutionLog, Pipeline @@ -782,7 +782,7 @@ class PipelineGenerator(BaseAppGenerator): user_id: str, all_files: list, datasource_info: Mapping[str, Any], - next_page_parameters: dict | None = None, + next_page_parameters: dict[str, Any] | None = None, ): """ Get files in a folder. diff --git a/api/core/app/apps/pipeline/pipeline_runner.py b/api/core/app/apps/pipeline/pipeline_runner.py index b4d2310da85..2ee0ae27ebc 100644 --- a/api/core/app/apps/pipeline/pipeline_runner.py +++ b/api/core/app/apps/pipeline/pipeline_runner.py @@ -2,13 +2,6 @@ import logging import time from typing import cast -from graphon.entities import GraphInitParams -from graphon.enums import WorkflowType -from graphon.graph import Graph -from graphon.graph_events import GraphEngineEvent, GraphRunFailedEvent -from graphon.runtime import GraphRuntimeState, VariablePool -from graphon.variable_loader import VariableLoader -from graphon.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput from sqlalchemy import select from core.app.apps.base_app_queue_manager import AppQueueManager @@ -22,11 +15,17 @@ from core.app.entities.app_invoke_entities import ( ) from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository -from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id +from core.workflow.node_factory import DifyGraphInitContext, DifyNodeFactory, get_default_root_node_id from core.workflow.system_variables import build_bootstrap_variables, build_system_variables from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_database import db +from graphon.enums import WorkflowType +from graphon.graph import Graph +from graphon.graph_events import GraphEngineEvent, GraphRunFailedEvent +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.variable_loader import VariableLoader +from graphon.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput from models.dataset import Document, Pipeline from models.model import EndUser from models.workflow import Workflow @@ -265,22 +264,23 @@ class PipelineRunner(WorkflowBasedAppRunner): # graph_config["nodes"] = real_run_nodes # graph_config["edges"] = real_edges # init graph - # Create required parameters for Graph.init - graph_init_params = GraphInitParams( + # Create explicit graph init context for Graph.init. + run_context = build_dify_run_context( + tenant_id=workflow.tenant_id, + app_id=self._app_id, + user_id=self.application_generate_entity.user_id, + user_from=user_from, + invoke_from=invoke_from, + ) + graph_init_context = DifyGraphInitContext( workflow_id=workflow.id, graph_config=graph_config, - run_context=build_dify_run_context( - tenant_id=workflow.tenant_id, - app_id=self._app_id, - user_id=self.application_generate_entity.user_id, - user_from=user_from, - invoke_from=invoke_from, - ), + run_context=run_context, call_depth=0, ) - node_factory = DifyNodeFactory( - graph_init_params=graph_init_params, + node_factory = DifyNodeFactory.from_graph_init_context( + graph_init_context=graph_init_context, graph_runtime_state=graph_runtime_state, ) if start_node_id is None: diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index 6074e81d1e2..6937014a066 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -8,10 +8,6 @@ from collections.abc import Generator, Mapping, Sequence from typing import TYPE_CHECKING, Any, Literal, overload from flask import Flask, current_app -from graphon.graph_engine.layers import GraphEngineLayer -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError -from graphon.runtime import GraphRuntimeState -from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader from pydantic import ValidationError from sqlalchemy import select from sqlalchemy.orm import sessionmaker @@ -38,6 +34,10 @@ from core.repositories import DifyCoreRepositoryFactory from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository from extensions.ext_database import db from factories import file_factory +from graphon.graph_engine.layers import GraphEngineLayer +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError +from graphon.runtime import GraphRuntimeState +from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader from libs.flask_utils import preserve_flask_contexts from models.account import Account from models.enums import WorkflowRunTriggeredFrom diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py index 2cb8088971a..cfb9208486f 100644 --- a/api/core/app/apps/workflow/app_runner.py +++ b/api/core/app/apps/workflow/app_runner.py @@ -3,12 +3,6 @@ import time from collections.abc import Sequence from typing import cast -from graphon.enums import WorkflowType -from graphon.graph_engine.command_channels import RedisChannel -from graphon.graph_engine.layers import GraphEngineLayer -from graphon.runtime import GraphRuntimeState, VariablePool -from graphon.variable_loader import VariableLoader - from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfig from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner @@ -21,6 +15,11 @@ from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_redis import redis_client from extensions.otel import WorkflowAppRunnerHandler, trace_span +from graphon.enums import WorkflowType +from graphon.graph_engine.command_channels import RedisChannel +from graphon.graph_engine.layers import GraphEngineLayer +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.variable_loader import VariableLoader from libs.datetime_utils import naive_utc_now from models.workflow import Workflow diff --git a/api/core/app/apps/workflow/generate_response_converter.py b/api/core/app/apps/workflow/generate_response_converter.py index c64f44a603d..c69826cbefa 100644 --- a/api/core/app/apps/workflow/generate_response_converter.py +++ b/api/core/app/apps/workflow/generate_response_converter.py @@ -1,5 +1,5 @@ from collections.abc import Generator -from typing import cast +from typing import Any, cast from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter from core.app.entities.task_entities import ( @@ -37,7 +37,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter): @classmethod def convert_stream_full_response( cls, stream_response: Generator[AppStreamResponse, None, None] - ) -> Generator[dict | str, None, None]: + ) -> Generator[dict[str, Any] | str, None, None]: """ Convert stream full response. :param stream_response: stream response @@ -66,7 +66,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter): @classmethod def convert_stream_simple_response( cls, stream_response: Generator[AppStreamResponse, None, None] - ) -> Generator[dict | str, None, None]: + ) -> Generator[dict[str, Any] | str, None, None]: """ Convert stream simple response. :param stream_response: stream response diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index 49af169e88a..15645add575 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -4,10 +4,7 @@ from collections.abc import Callable, Generator from contextlib import contextmanager from typing import Union -from graphon.entities import WorkflowStartReason -from graphon.enums import WorkflowExecutionStatus -from graphon.runtime import GraphRuntimeState -from sqlalchemy.orm import Session +from sqlalchemy.orm import Session, sessionmaker from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME from core.app.apps.base_app_queue_manager import AppQueueManager @@ -61,6 +58,9 @@ from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk from core.ops.ops_trace_manager import TraceQueueManager from core.workflow.system_variables import build_system_variables from extensions.ext_database import db +from graphon.entities import WorkflowStartReason +from graphon.enums import WorkflowExecutionStatus +from graphon.runtime import GraphRuntimeState from models import Account from models.enums import CreatorUserRole from models.model import EndUser @@ -252,13 +252,8 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport): @contextmanager def _database_session(self): """Context manager for database sessions.""" - with Session(db.engine, expire_on_commit=False) as session: - try: - yield session - session.commit() - except Exception: - session.rollback() - raise + with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session: + yield session def _ensure_workflow_initialized(self): """Fluent validation for workflow state.""" @@ -687,15 +682,16 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport): def _save_workflow_app_log(self, *, session: Session, workflow_run_id: str | None): invoke_from = self._application_generate_entity.invoke_from - if invoke_from == InvokeFrom.SERVICE_API: - created_from = WorkflowAppLogCreatedFrom.SERVICE_API - elif invoke_from == InvokeFrom.EXPLORE: - created_from = WorkflowAppLogCreatedFrom.INSTALLED_APP - elif invoke_from == InvokeFrom.WEB_APP: - created_from = WorkflowAppLogCreatedFrom.WEB_APP - else: - # not save log for debugging - return + match invoke_from: + case InvokeFrom.SERVICE_API: + created_from = WorkflowAppLogCreatedFrom.SERVICE_API + case InvokeFrom.EXPLORE: + created_from = WorkflowAppLogCreatedFrom.INSTALLED_APP + case InvokeFrom.WEB_APP: + created_from = WorkflowAppLogCreatedFrom.WEB_APP + case InvokeFrom.DEBUGGER | InvokeFrom.TRIGGER | InvokeFrom.PUBLISHED_PIPELINE | InvokeFrom.VALIDATION: + # not save log for debugging + return if not workflow_run_id: return diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py index f68c8e60b4f..047b54c86c9 100644 --- a/api/core/app/apps/workflow_app_runner.py +++ b/api/core/app/apps/workflow_app_runner.py @@ -3,7 +3,52 @@ import time from collections.abc import Mapping, Sequence from typing import Any, cast -from graphon.entities import GraphInitParams +from pydantic import ValidationError + +from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom +from core.app.entities.agent_strategy import AgentStrategyInfo +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context +from core.app.entities.queue_entities import ( + AppQueueEvent, + QueueAgentLogEvent, + QueueHumanInputFormFilledEvent, + QueueHumanInputFormTimeoutEvent, + QueueIterationCompletedEvent, + QueueIterationNextEvent, + QueueIterationStartEvent, + QueueLoopCompletedEvent, + QueueLoopNextEvent, + QueueLoopStartEvent, + QueueNodeExceptionEvent, + QueueNodeFailedEvent, + QueueNodeRetryEvent, + QueueNodeStartedEvent, + QueueNodeSucceededEvent, + QueueRetrieverResourcesEvent, + QueueTextChunkEvent, + QueueWorkflowFailedEvent, + QueueWorkflowPartialSuccessEvent, + QueueWorkflowPausedEvent, + QueueWorkflowStartedEvent, + QueueWorkflowSucceededEvent, +) +from core.rag.entities import RetrievalSourceMetadata +from core.workflow.node_factory import ( + DifyGraphInitContext, + DifyNodeFactory, + get_default_root_node_id, + resolve_workflow_node_class, +) +from core.workflow.system_variables import ( + build_bootstrap_variables, + default_system_variables, + get_node_creation_preload_selectors, + inject_default_system_variable_mappings, + preload_node_creation_variables, +) +from core.workflow.variable_pool_initializer import add_variables_to_pool +from core.workflow.workflow_entry import WorkflowEntry +from core.workflow.workflow_run_outputs import project_node_outputs_for_workflow_run from graphon.entities.graph_config import NodeConfigDictAdapter from graphon.entities.pause_reason import HumanInputRequired from graphon.graph import Graph @@ -37,47 +82,6 @@ from graphon.graph_events import ( ) from graphon.runtime import GraphRuntimeState, VariablePool from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool -from pydantic import ValidationError - -from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom -from core.app.entities.agent_strategy import AgentStrategyInfo -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context -from core.app.entities.queue_entities import ( - AppQueueEvent, - QueueAgentLogEvent, - QueueHumanInputFormFilledEvent, - QueueHumanInputFormTimeoutEvent, - QueueIterationCompletedEvent, - QueueIterationNextEvent, - QueueIterationStartEvent, - QueueLoopCompletedEvent, - QueueLoopNextEvent, - QueueLoopStartEvent, - QueueNodeExceptionEvent, - QueueNodeFailedEvent, - QueueNodeRetryEvent, - QueueNodeStartedEvent, - QueueNodeSucceededEvent, - QueueRetrieverResourcesEvent, - QueueTextChunkEvent, - QueueWorkflowFailedEvent, - QueueWorkflowPartialSuccessEvent, - QueueWorkflowPausedEvent, - QueueWorkflowStartedEvent, - QueueWorkflowSucceededEvent, -) -from core.rag.entities.citation_metadata import RetrievalSourceMetadata -from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id, resolve_workflow_node_class -from core.workflow.system_variables import ( - build_bootstrap_variables, - default_system_variables, - get_node_creation_preload_selectors, - inject_default_system_variable_mappings, - preload_node_creation_variables, -) -from core.workflow.variable_pool_initializer import add_variables_to_pool -from core.workflow.workflow_entry import WorkflowEntry -from core.workflow.workflow_run_outputs import project_node_outputs_for_workflow_run from models.workflow import Workflow from tasks.mail_human_input_delivery_task import dispatch_human_input_email_task @@ -127,24 +131,25 @@ class WorkflowBasedAppRunner: if not isinstance(graph_config.get("edges"), list): raise ValueError("edges in workflow graph must be a list") - # Create required parameters for Graph.init - graph_init_params = GraphInitParams( + # Create explicit graph init context for Graph.init. + run_context = build_dify_run_context( + tenant_id=tenant_id or "", + app_id=self._app_id, + user_id=user_id, + user_from=user_from, + invoke_from=invoke_from, + ) + graph_init_context = DifyGraphInitContext( workflow_id=workflow_id, graph_config=graph_config, - run_context=build_dify_run_context( - tenant_id=tenant_id or "", - app_id=self._app_id, - user_id=user_id, - user_from=user_from, - invoke_from=invoke_from, - ), + run_context=run_context, call_depth=0, ) # Use the provided graph_runtime_state for consistent state management - node_factory = DifyNodeFactory( - graph_init_params=graph_init_params, + node_factory = DifyNodeFactory.from_graph_init_context( + graph_init_context=graph_init_context, graph_runtime_state=graph_runtime_state, ) @@ -289,22 +294,23 @@ class WorkflowBasedAppRunner: typed_node_configs = [NodeConfigDictAdapter.validate_python(node) for node in node_configs] - # Create required parameters for Graph.init - graph_init_params = GraphInitParams( + # Create explicit graph init context for Graph.init. + run_context = build_dify_run_context( + tenant_id=workflow.tenant_id, + app_id=self._app_id, + user_id=user_id, + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + ) + graph_init_context = DifyGraphInitContext( workflow_id=workflow.id, graph_config=graph_config, - run_context=build_dify_run_context( - tenant_id=workflow.tenant_id, - app_id=self._app_id, - user_id=user_id, - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, - ), + run_context=run_context, call_depth=0, ) - node_factory = DifyNodeFactory( - graph_init_params=graph_init_params, + node_factory = DifyNodeFactory.from_graph_init_context( + graph_init_context=graph_init_context, graph_runtime_state=graph_runtime_state, ) diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py index 0cdbb5f50a1..09992f4bbf4 100644 --- a/api/core/app/entities/app_invoke_entities.py +++ b/api/core/app/entities/app_invoke_entities.py @@ -1,14 +1,14 @@ from collections.abc import Mapping, Sequence from enum import StrEnum -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any -from graphon.file import File, FileUploadConfig -from graphon.model_runtime.entities.model_entities import AIModelEntity from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator from constants import UUID_NIL from core.app.app_config.entities import EasyUIBasedAppConfig, WorkflowUIBasedAppConfig from core.entities.provider_configuration import ProviderModelBundle +from graphon.file import File, FileUploadConfig +from graphon.model_runtime.entities.model_entities import AIModelEntity if TYPE_CHECKING: from core.ops.ops_trace_manager import TraceQueueManager @@ -131,7 +131,7 @@ class AppGenerateEntity(BaseModel): extras: dict[str, Any] = Field(default_factory=dict) # tracing instance - trace_manager: Optional["TraceQueueManager"] = Field(default=None, exclude=True, repr=False) + trace_manager: "TraceQueueManager | None" = Field(default=None, exclude=True, repr=False) class EasyUIBasedAppGenerateEntity(AppGenerateEntity): diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py index 5e56341f892..221b7fb058b 100644 --- a/api/core/app/entities/queue_entities.py +++ b/api/core/app/entities/queue_entities.py @@ -3,14 +3,14 @@ from datetime import datetime from enum import StrEnum, auto from typing import Any +from pydantic import BaseModel, ConfigDict, Field + +from core.app.entities.agent_strategy import AgentStrategyInfo +from core.rag.entities import RetrievalSourceMetadata from graphon.entities import WorkflowStartReason from graphon.entities.pause_reason import PauseReason from graphon.enums import NodeType, WorkflowNodeExecutionMetadataKey from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk -from pydantic import BaseModel, ConfigDict, Field - -from core.app.entities.agent_strategy import AgentStrategyInfo -from core.rag.entities.citation_metadata import RetrievalSourceMetadata class QueueEvent(StrEnum): diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py index ba3b2e356f7..6e4ca69cf00 100644 --- a/api/core/app/entities/task_entities.py +++ b/api/core/app/entities/task_entities.py @@ -2,14 +2,14 @@ from collections.abc import Mapping, Sequence from enum import StrEnum from typing import Any +from pydantic import BaseModel, ConfigDict, Field + +from core.app.entities.agent_strategy import AgentStrategyInfo +from core.rag.entities import RetrievalSourceMetadata from graphon.entities import WorkflowStartReason from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage from graphon.nodes.human_input.entities import FormInput, UserAction -from pydantic import BaseModel, ConfigDict, Field - -from core.app.entities.agent_strategy import AgentStrategyInfo -from core.rag.entities.citation_metadata import RetrievalSourceMetadata class AnnotationReplyAccount(BaseModel): @@ -521,7 +521,7 @@ class IterationNodeStartStreamResponse(StreamResponse): node_type: str title: str created_at: int - extras: dict = Field(default_factory=dict) + extras: dict[str, Any] = Field(default_factory=dict) metadata: Mapping = {} inputs: Mapping = {} inputs_truncated: bool = False @@ -547,7 +547,7 @@ class IterationNodeNextStreamResponse(StreamResponse): title: str index: int created_at: int - extras: dict = Field(default_factory=dict) + extras: dict[str, Any] = Field(default_factory=dict) event: StreamEvent = StreamEvent.ITERATION_NEXT workflow_run_id: str @@ -571,7 +571,7 @@ class IterationNodeCompletedStreamResponse(StreamResponse): outputs: Mapping | None = None outputs_truncated: bool = False created_at: int - extras: dict | None = None + extras: dict[str, Any] | None = None inputs: Mapping | None = None inputs_truncated: bool = False status: WorkflowNodeExecutionStatus @@ -602,7 +602,7 @@ class LoopNodeStartStreamResponse(StreamResponse): node_type: str title: str created_at: int - extras: dict = Field(default_factory=dict) + extras: dict[str, Any] = Field(default_factory=dict) metadata: Mapping = {} inputs: Mapping = {} inputs_truncated: bool = False @@ -653,7 +653,7 @@ class LoopNodeCompletedStreamResponse(StreamResponse): outputs: Mapping | None = None outputs_truncated: bool = False created_at: int - extras: dict | None = None + extras: dict[str, Any] | None = None inputs: Mapping | None = None inputs_truncated: bool = False status: WorkflowNodeExecutionStatus diff --git a/api/core/app/features/hosting_moderation/hosting_moderation.py b/api/core/app/features/hosting_moderation/hosting_moderation.py index d2d2fea4fb8..d59f5125e3c 100644 --- a/api/core/app/features/hosting_moderation/hosting_moderation.py +++ b/api/core/app/features/hosting_moderation/hosting_moderation.py @@ -1,9 +1,8 @@ import logging -from graphon.model_runtime.entities.message_entities import PromptMessage - from core.app.entities.app_invoke_entities import EasyUIBasedAppGenerateEntity from core.helper import moderation +from graphon.model_runtime.entities.message_entities import PromptMessage logger = logging.getLogger(__name__) diff --git a/api/core/app/file_access/scope.py b/api/core/app/file_access/scope.py index 80d504ef1c0..a583301f9b7 100644 --- a/api/core/app/file_access/scope.py +++ b/api/core/app/file_access/scope.py @@ -1,6 +1,6 @@ from __future__ import annotations -from collections.abc import Iterator +from collections.abc import Generator # Changed from Iterator from contextlib import contextmanager from contextvars import ContextVar from dataclasses import dataclass @@ -32,7 +32,7 @@ def get_current_file_access_scope() -> FileAccessScope | None: @contextmanager -def bind_file_access_scope(scope: FileAccessScope) -> Iterator[None]: +def bind_file_access_scope(scope: FileAccessScope) -> Generator[None, None, None]: # Changed from Iterator[None] token = _current_file_access_scope.set(scope) try: yield diff --git a/api/core/app/layers/conversation_variable_persist_layer.py b/api/core/app/layers/conversation_variable_persist_layer.py index e09869f5f8f..d5e6b04a4a6 100644 --- a/api/core/app/layers/conversation_variable_persist_layer.py +++ b/api/core/app/layers/conversation_variable_persist_layer.py @@ -9,11 +9,10 @@ scope updates that matter to chat applications. import logging -from graphon.graph_engine.layers import GraphEngineLayer -from graphon.graph_events import GraphEngineEvent, NodeRunVariableUpdatedEvent - from core.workflow.system_variables import SystemVariableKey, get_system_text from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID +from graphon.graph_engine.layers import GraphEngineLayer +from graphon.graph_events import GraphEngineEvent, NodeRunVariableUpdatedEvent from services.conversation_variable_updater import ConversationVariableUpdater logger = logging.getLogger(__name__) diff --git a/api/core/app/layers/pause_state_persist_layer.py b/api/core/app/layers/pause_state_persist_layer.py index c027f427882..9811f9f8308 100644 --- a/api/core/app/layers/pause_state_persist_layer.py +++ b/api/core/app/layers/pause_state_persist_layer.py @@ -1,14 +1,14 @@ from dataclasses import dataclass from typing import Annotated, Literal, Self -from graphon.graph_engine.layers import GraphEngineLayer -from graphon.graph_events import GraphEngineEvent, GraphRunPausedEvent from pydantic import BaseModel, Field from sqlalchemy import Engine from sqlalchemy.orm import Session, sessionmaker from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity from core.workflow.system_variables import SystemVariableKey, get_system_text +from graphon.graph_engine.layers import GraphEngineLayer +from graphon.graph_events import GraphEngineEvent, GraphRunPausedEvent from models.model import AppMode from repositories.api_workflow_run_repository import APIWorkflowRunRepository from repositories.factory import DifyAPIRepositoryFactory diff --git a/api/core/app/layers/timeslice_layer.py b/api/core/app/layers/timeslice_layer.py index 8c8daf87122..bb9fc1b6fa7 100644 --- a/api/core/app/layers/timeslice_layer.py +++ b/api/core/app/layers/timeslice_layer.py @@ -3,10 +3,10 @@ import uuid from typing import ClassVar from apscheduler.schedulers.background import BackgroundScheduler # type: ignore + from graphon.graph_engine.entities.commands import CommandType, GraphEngineCommand from graphon.graph_engine.layers import GraphEngineLayer from graphon.graph_events import GraphEngineEvent - from services.workflow.entities import WorkflowScheduleCFSPlanEntity from services.workflow.scheduler import CFSPlanScheduler, SchedulerCommand diff --git a/api/core/app/layers/trigger_post_layer.py b/api/core/app/layers/trigger_post_layer.py index 77c7bec67e6..b60fe82ffe7 100644 --- a/api/core/app/layers/trigger_post_layer.py +++ b/api/core/app/layers/trigger_post_layer.py @@ -2,12 +2,12 @@ import logging from datetime import UTC, datetime from typing import Any, ClassVar -from graphon.graph_engine.layers import GraphEngineLayer -from graphon.graph_events import GraphEngineEvent, GraphRunFailedEvent, GraphRunPausedEvent, GraphRunSucceededEvent from pydantic import TypeAdapter from core.db.session_factory import session_factory from core.workflow.system_variables import SystemVariableKey, get_system_text +from graphon.graph_engine.layers import GraphEngineLayer +from graphon.graph_events import GraphEngineEvent, GraphRunFailedEvent, GraphRunPausedEvent, GraphRunSucceededEvent from models.enums import WorkflowTriggerStatus from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository from tasks.workflow_cfs_scheduler.cfs_scheduler import AsyncWorkflowCFSPlanEntity diff --git a/api/core/app/llm/model_access.py b/api/core/app/llm/model_access.py index 278d0cb30b5..c49c4eb0acf 100644 --- a/api/core/app/llm/model_access.py +++ b/api/core/app/llm/model_access.py @@ -2,16 +2,15 @@ from __future__ import annotations from typing import Any -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.nodes.llm.entities import ModelConfig -from graphon.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError -from graphon.nodes.llm.protocols import CredentialsProvider - from core.app.entities.app_invoke_entities import DifyRunContext, ModelConfigWithCredentialsEntity from core.errors.error import ProviderTokenNotInitError from core.model_manager import ModelInstance, ModelManager from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager from core.provider_manager import ProviderManager +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.nodes.llm.entities import ModelConfig +from graphon.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError +from graphon.nodes.llm.protocols import CredentialsProvider class DifyCredentialsProvider: diff --git a/api/core/app/llm/quota.py b/api/core/app/llm/quota.py index 182f1b767dc..b6039e1e4e3 100644 --- a/api/core/app/llm/quota.py +++ b/api/core/app/llm/quota.py @@ -1,6 +1,5 @@ -from graphon.model_runtime.entities.llm_entities import LLMUsage from sqlalchemy import update -from sqlalchemy.orm import Session +from sqlalchemy.orm import sessionmaker from configs import dify_config from core.entities.model_entities import ModelStatus @@ -8,6 +7,7 @@ from core.entities.provider_entities import ProviderQuotaType, QuotaUnit from core.errors.error import QuotaExceededError from core.model_manager import ModelInstance from extensions.ext_database import db +from graphon.model_runtime.entities.llm_entities import LLMUsage from libs.datetime_utils import naive_utc_now from models.provider import Provider, ProviderType from models.provider_ids import ModelProviderID @@ -57,37 +57,37 @@ def deduct_llm_quota(*, tenant_id: str, model_instance: ModelInstance, usage: LL used_quota = 1 if used_quota is not None and system_configuration.current_quota_type is not None: - if system_configuration.current_quota_type == ProviderQuotaType.TRIAL: - from services.credit_pool_service import CreditPoolService + match system_configuration.current_quota_type: + case ProviderQuotaType.TRIAL: + from services.credit_pool_service import CreditPoolService - CreditPoolService.check_and_deduct_credits( - tenant_id=tenant_id, - credits_required=used_quota, - ) - elif system_configuration.current_quota_type == ProviderQuotaType.PAID: - from services.credit_pool_service import CreditPoolService - - CreditPoolService.check_and_deduct_credits( - tenant_id=tenant_id, - credits_required=used_quota, - pool_type="paid", - ) - else: - with Session(db.engine) as session: - stmt = ( - update(Provider) - .where( - Provider.tenant_id == tenant_id, - # TODO: Use provider name with prefix after the data migration. - Provider.provider_name == ModelProviderID(model_instance.provider).provider_name, - Provider.provider_type == ProviderType.SYSTEM.value, - Provider.quota_type == system_configuration.current_quota_type, - Provider.quota_limit > Provider.quota_used, - ) - .values( - quota_used=Provider.quota_used + used_quota, - last_used=naive_utc_now(), - ) + CreditPoolService.check_and_deduct_credits( + tenant_id=tenant_id, + credits_required=used_quota, ) - session.execute(stmt) - session.commit() + case ProviderQuotaType.PAID: + from services.credit_pool_service import CreditPoolService + + CreditPoolService.check_and_deduct_credits( + tenant_id=tenant_id, + credits_required=used_quota, + pool_type="paid", + ) + case ProviderQuotaType.FREE: + with sessionmaker(bind=db.engine).begin() as session: + stmt = ( + update(Provider) + .where( + Provider.tenant_id == tenant_id, + # TODO: Use provider name with prefix after the data migration. + Provider.provider_name == ModelProviderID(model_instance.provider).provider_name, + Provider.provider_type == ProviderType.SYSTEM.value, + Provider.quota_type == system_configuration.current_quota_type, + Provider.quota_limit > Provider.quota_used, + ) + .values( + quota_used=Provider.quota_used + used_quota, + last_used=naive_utc_now(), + ) + ) + session.execute(stmt) diff --git a/api/core/app/task_pipeline/based_generate_task_pipeline.py b/api/core/app/task_pipeline/based_generate_task_pipeline.py index 10b9c36d3e2..9e688589db7 100644 --- a/api/core/app/task_pipeline/based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/based_generate_task_pipeline.py @@ -1,7 +1,6 @@ import logging import time -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from sqlalchemy import select from sqlalchemy.orm import Session @@ -18,6 +17,7 @@ from core.app.entities.task_entities import ( ) from core.errors.error import QuotaExceededError from core.moderation.output_moderation import ModerationRule, OutputModeration +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from models.enums import MessageStatus from models.model import Message diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py index e0e6a6f5c36..e2e07ebaff5 100644 --- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py @@ -4,15 +4,8 @@ from collections.abc import Generator from threading import Thread from typing import Any, cast -from graphon.file import FileTransferMethod -from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage -from graphon.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - TextPromptMessageContent, -) -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from sqlalchemy import select -from sqlalchemy.orm import Session +from sqlalchemy.orm import Session, sessionmaker from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom @@ -60,6 +53,13 @@ from core.prompt.utils.prompt_message_util import PromptMessageUtil from core.prompt.utils.prompt_template_parser import PromptTemplateParser from events.message_event import message_was_created from extensions.ext_database import db +from graphon.file import FileTransferMethod +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage +from graphon.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + TextPromptMessageContent, +) +from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel from libs.datetime_utils import naive_utc_now from models.model import AppMode, Conversation, Message, MessageAgentThought, MessageFile, UploadFile @@ -266,9 +266,8 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): event = message.event if isinstance(event, QueueErrorEvent): - with Session(db.engine) as session: + with sessionmaker(bind=db.engine).begin() as session: err = self.handle_error(event=event, session=session, message_id=self._message_id) - session.commit() yield self.error_to_stream_response(err) break elif isinstance(event, QueueStopEvent | QueueMessageEndEvent): @@ -288,10 +287,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): answer=output_moderation_answer ) - with Session(db.engine) as session: + with sessionmaker(bind=db.engine).begin() as session: # Save message self._save_message(session=session, trace_manager=trace_manager) - session.commit() message_end_resp = self._message_end_to_stream_response() yield message_end_resp elif isinstance(event, QueueRetrieverResourcesEvent): @@ -509,8 +507,8 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): :return: """ with Session(db.engine, expire_on_commit=False) as session: - agent_thought: MessageAgentThought | None = ( - session.query(MessageAgentThought).where(MessageAgentThought.id == event.agent_thought_id).first() + agent_thought: MessageAgentThought | None = session.scalar( + select(MessageAgentThought).where(MessageAgentThought.id == event.agent_thought_id).limit(1) ) if agent_thought: diff --git a/api/core/app/task_pipeline/message_file_utils.py b/api/core/app/task_pipeline/message_file_utils.py index b23a33923b3..1dd713821f4 100644 --- a/api/core/app/task_pipeline/message_file_utils.py +++ b/api/core/app/task_pipeline/message_file_utils.py @@ -1,9 +1,8 @@ from typing import TypedDict +from core.tools.signature import sign_tool_file from graphon.file import FileTransferMethod from graphon.file import helpers as file_helpers - -from core.tools.signature import sign_tool_file from models.model import MessageFile, UploadFile MAX_TOOL_FILE_EXTENSION_LENGTH = 10 @@ -40,41 +39,44 @@ def prepare_file_dict(message_file: MessageFile, upload_files_map: dict[str, Upl size = 0 extension = "" - if message_file.transfer_method == FileTransferMethod.REMOTE_URL: - url = message_file.url - if message_file.url: - filename = message_file.url.split("/")[-1].split("?")[0] - if "." in filename: - extension = "." + filename.rsplit(".", 1)[1] - elif message_file.transfer_method == FileTransferMethod.LOCAL_FILE: - if upload_file: - url = file_helpers.get_signed_file_url(upload_file_id=str(upload_file.id)) - filename = upload_file.name - mime_type = upload_file.mime_type or "application/octet-stream" - size = upload_file.size or 0 - extension = f".{upload_file.extension}" if upload_file.extension else "" - elif message_file.upload_file_id: - url = file_helpers.get_signed_file_url(upload_file_id=str(message_file.upload_file_id)) - elif message_file.transfer_method == FileTransferMethod.TOOL_FILE and message_file.url: - if message_file.url.startswith(("http://", "https://")): + match message_file.transfer_method: + case FileTransferMethod.REMOTE_URL: url = message_file.url - filename = message_file.url.split("/")[-1].split("?")[0] - if "." in filename: - extension = "." + filename.rsplit(".", 1)[1] - else: - url_parts = message_file.url.split("/") - if url_parts: - file_part = url_parts[-1].split("?")[0] - if "." in file_part: - tool_file_id, ext = file_part.rsplit(".", 1) - extension = f".{ext}" - if len(extension) > MAX_TOOL_FILE_EXTENSION_LENGTH: + if message_file.url: + filename = message_file.url.split("/")[-1].split("?")[0] + if "." in filename: + extension = "." + filename.rsplit(".", 1)[1] + case FileTransferMethod.LOCAL_FILE: + if upload_file: + url = file_helpers.get_signed_file_url(upload_file_id=str(upload_file.id)) + filename = upload_file.name + mime_type = upload_file.mime_type or "application/octet-stream" + size = upload_file.size or 0 + extension = f".{upload_file.extension}" if upload_file.extension else "" + elif message_file.upload_file_id: + url = file_helpers.get_signed_file_url(upload_file_id=str(message_file.upload_file_id)) + case FileTransferMethod.TOOL_FILE if message_file.url: + if message_file.url.startswith(("http://", "https://")): + url = message_file.url + filename = message_file.url.split("/")[-1].split("?")[0] + if "." in filename: + extension = "." + filename.rsplit(".", 1)[1] + else: + url_parts = message_file.url.split("/") + if url_parts: + file_part = url_parts[-1].split("?")[0] + if "." in file_part: + tool_file_id, ext = file_part.rsplit(".", 1) + extension = f".{ext}" + if len(extension) > MAX_TOOL_FILE_EXTENSION_LENGTH: + extension = ".bin" + else: + tool_file_id = file_part extension = ".bin" - else: - tool_file_id = file_part - extension = ".bin" - url = sign_tool_file(tool_file_id=tool_file_id, extension=extension) - filename = file_part + url = sign_tool_file(tool_file_id=tool_file_id, extension=extension) + filename = file_part + case FileTransferMethod.TOOL_FILE | FileTransferMethod.DATASOURCE_FILE: + pass transfer_method_value = message_file.transfer_method.value remote_url = message_file.url if message_file.transfer_method == FileTransferMethod.REMOTE_URL else "" diff --git a/api/core/app/workflow/file_runtime.py b/api/core/app/workflow/file_runtime.py index 8604235ef28..3a6f9d575ae 100644 --- a/api/core/app/workflow/file_runtime.py +++ b/api/core/app/workflow/file_runtime.py @@ -9,17 +9,17 @@ import urllib.parse from collections.abc import Generator from typing import TYPE_CHECKING, Literal -from graphon.file import FileTransferMethod -from graphon.file.protocols import HttpResponseProtocol, WorkflowFileRuntimeProtocol -from graphon.file.runtime import set_workflow_file_runtime - from configs import dify_config from core.app.file_access import DatabaseFileAccessController, FileAccessControllerProtocol from core.db.session_factory import session_factory -from core.helper.ssrf_proxy import ssrf_proxy +from core.helper.ssrf_proxy import graphon_ssrf_proxy from core.tools.signature import sign_tool_file from core.workflow.file_reference import parse_file_reference from extensions.ext_storage import storage +from graphon.file import FileTransferMethod +from graphon.file.protocols import WorkflowFileRuntimeProtocol +from graphon.file.runtime import set_workflow_file_runtime +from graphon.http.protocols import HttpResponseProtocol if TYPE_CHECKING: from graphon.file import File @@ -44,7 +44,7 @@ class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol): return dify_config.MULTIMODAL_SEND_FORMAT def http_get(self, url: str, *, follow_redirects: bool = True) -> HttpResponseProtocol: - return ssrf_proxy.get(url, follow_redirects=follow_redirects) + return graphon_ssrf_proxy.get(url, follow_redirects=follow_redirects) def storage_load(self, path: str, *, stream: bool = False) -> bytes | Generator: return storage.load(path, stream=stream) diff --git a/api/core/app/workflow/layers/llm_quota.py b/api/core/app/workflow/layers/llm_quota.py index c577ce07541..4a7918032ee 100644 --- a/api/core/app/workflow/layers/llm_quota.py +++ b/api/core/app/workflow/layers/llm_quota.py @@ -7,17 +7,16 @@ This layer centralizes model-quota deduction outside node implementations. import logging from typing import TYPE_CHECKING, cast, final, override +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext +from core.app.llm import deduct_llm_quota, ensure_llm_quota_available +from core.errors.error import QuotaExceededError +from core.model_manager import ModelInstance from graphon.enums import BuiltinNodeTypes from graphon.graph_engine.entities.commands import AbortCommand, CommandType from graphon.graph_engine.layers import GraphEngineLayer from graphon.graph_events import GraphEngineEvent, GraphNodeEventBase, NodeRunSucceededEvent from graphon.nodes.base.node import Node -from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext -from core.app.llm import deduct_llm_quota, ensure_llm_quota_available -from core.errors.error import QuotaExceededError -from core.model_manager import ModelInstance - if TYPE_CHECKING: from graphon.nodes.llm.node import LLMNode from graphon.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode diff --git a/api/core/app/workflow/layers/observability.py b/api/core/app/workflow/layers/observability.py index 99e8015c0b2..8b5a5b9d7ff 100644 --- a/api/core/app/workflow/layers/observability.py +++ b/api/core/app/workflow/layers/observability.py @@ -12,10 +12,6 @@ from contextvars import Token from dataclasses import dataclass from typing import cast, final, override -from graphon.enums import BuiltinNodeTypes, NodeType -from graphon.graph_engine.layers import GraphEngineLayer -from graphon.graph_events import GraphNodeEventBase -from graphon.nodes.base.node import Node from opentelemetry import context as context_api from opentelemetry.trace import Span, SpanKind, Tracer, get_tracer, set_span_in_context @@ -28,6 +24,10 @@ from extensions.otel.parser import ( ToolNodeOTelParser, ) from extensions.otel.runtime import is_instrument_flag_enabled +from graphon.enums import BuiltinNodeTypes, NodeType +from graphon.graph_engine.layers import GraphEngineLayer +from graphon.graph_events import GraphNodeEventBase +from graphon.nodes.base.node import Node logger = logging.getLogger(__name__) diff --git a/api/core/app/workflow/layers/persistence.py b/api/core/app/workflow/layers/persistence.py index ada065a9433..d521304615a 100644 --- a/api/core/app/workflow/layers/persistence.py +++ b/api/core/app/workflow/layers/persistence.py @@ -14,6 +14,13 @@ from dataclasses import dataclass from datetime import datetime from typing import Any, Union +from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity +from core.ops.entities.trace_entity import TraceTaskName +from core.ops.ops_trace_manager import TraceQueueManager, TraceTask +from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository +from core.workflow.system_variables import SystemVariableKey +from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID +from core.workflow.workflow_run_outputs import project_node_outputs_for_workflow_run from graphon.entities import WorkflowExecution, WorkflowNodeExecution from graphon.enums import ( WorkflowExecutionStatus, @@ -38,14 +45,6 @@ from graphon.graph_events import ( NodeRunSucceededEvent, ) from graphon.node_events import NodeRunResult - -from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity -from core.ops.entities.trace_entity import TraceTaskName -from core.ops.ops_trace_manager import TraceQueueManager, TraceTask -from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository -from core.workflow.system_variables import SystemVariableKey -from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID -from core.workflow.workflow_run_outputs import project_node_outputs_for_workflow_run from libs.datetime_utils import naive_utc_now @@ -350,7 +349,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer): execution.total_tokens = runtime_state.total_tokens execution.total_steps = runtime_state.node_run_steps execution.outputs = execution.outputs or runtime_state.outputs - execution.exceptions_count = runtime_state.exceptions_count + execution.exceptions_count = max(execution.exceptions_count, runtime_state.exceptions_count) def _update_node_execution( self, diff --git a/api/core/base/tts/app_generator_tts_publisher.py b/api/core/base/tts/app_generator_tts_publisher.py index 3d8a7a54f31..9e3c1872108 100644 --- a/api/core/base/tts/app_generator_tts_publisher.py +++ b/api/core/base/tts/app_generator_tts_publisher.py @@ -6,9 +6,6 @@ import re import threading from collections.abc import Iterable -from graphon.model_runtime.entities.message_entities import TextPromptMessageContent -from graphon.model_runtime.entities.model_entities import ModelType - from core.app.entities.queue_entities import ( MessageQueueMessage, QueueAgentMessageEvent, @@ -18,6 +15,8 @@ from core.app.entities.queue_entities import ( WorkflowQueueMessage, ) from core.model_manager import ModelInstance, ModelManager +from graphon.model_runtime.entities.message_entities import TextPromptMessageContent +from graphon.model_runtime.entities.model_entities import ModelType class AudioTrunk: diff --git a/api/core/callback_handler/index_tool_callback_handler.py b/api/core/callback_handler/index_tool_callback_handler.py index 6a071192447..205e0042901 100644 --- a/api/core/callback_handler/index_tool_callback_handler.py +++ b/api/core/callback_handler/index_tool_callback_handler.py @@ -6,7 +6,7 @@ from sqlalchemy import select, update from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import QueueRetrieverResourcesEvent -from core.rag.entities.citation_metadata import RetrievalSourceMetadata +from core.rag.entities import RetrievalSourceMetadata from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.models.document import Document from extensions.ext_database import db diff --git a/api/core/datasource/datasource_manager.py b/api/core/datasource/datasource_manager.py index 143d1e696bf..f0dcb13b622 100644 --- a/api/core/datasource/datasource_manager.py +++ b/api/core/datasource/datasource_manager.py @@ -3,9 +3,6 @@ from collections.abc import Generator from threading import Lock from typing import Any, cast -from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus -from graphon.file import File, FileTransferMethod, FileType, get_file_type_by_mime_type -from graphon.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent from sqlalchemy import select import contexts @@ -31,6 +28,9 @@ from core.plugin.impl.datasource import PluginDatasourceManager from core.workflow.file_reference import build_file_reference from core.workflow.nodes.datasource.entities import DatasourceParameter, OnlineDriveDownloadFileParam from factories import file_factory +from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from graphon.file import File, FileTransferMethod, FileType, get_file_type_by_mime_type +from graphon.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent from models.model import UploadFile from models.tools import ToolFile from services.datasource_provider_service import DatasourceProviderService @@ -345,18 +345,18 @@ class DatasourceManager: @classmethod def get_upload_file_by_id(cls, file_id: str, tenant_id: str) -> File: with session_factory.create_session() as session: - upload_file = ( - session.query(UploadFile).where(UploadFile.id == file_id, UploadFile.tenant_id == tenant_id).first() + upload_file = session.scalar( + select(UploadFile).where(UploadFile.id == file_id, UploadFile.tenant_id == tenant_id).limit(1) ) if not upload_file: raise ValueError(f"UploadFile not found for file_id={file_id}, tenant_id={tenant_id}") file_info = File( - id=upload_file.id, + file_id=upload_file.id, filename=upload_file.name, extension="." + upload_file.extension, mime_type=upload_file.mime_type, - type=FileType.CUSTOM, + file_type=FileType.CUSTOM, transfer_method=FileTransferMethod.LOCAL_FILE, remote_url=upload_file.source_url, reference=build_file_reference(record_id=str(upload_file.id)), diff --git a/api/core/datasource/entities/api_entities.py b/api/core/datasource/entities/api_entities.py index 14d1af2e8b4..352e6bfd492 100644 --- a/api/core/datasource/entities/api_entities.py +++ b/api/core/datasource/entities/api_entities.py @@ -1,10 +1,10 @@ -from typing import Literal, Optional +from typing import Any, Literal, TypedDict -from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, Field, field_validator from core.datasource.entities.datasource_entities import DatasourceParameter -from core.tools.entities.common_entities import I18nObject +from core.tools.entities.common_entities import I18nObject, I18nObjectDict +from graphon.model_runtime.utils.encoders import jsonable_encoder class DatasourceApiEntity(BaseModel): @@ -14,10 +14,27 @@ class DatasourceApiEntity(BaseModel): description: I18nObject parameters: list[DatasourceParameter] | None = None labels: list[str] = Field(default_factory=list) - output_schema: dict | None = None + output_schema: dict[str, Any] | None = None -ToolProviderTypeApiLiteral = Optional[Literal["builtin", "api", "workflow"]] +ToolProviderTypeApiLiteral = Literal["builtin", "api", "workflow"] | None + + +class DatasourceProviderApiEntityDict(TypedDict): + id: str + author: str + name: str + plugin_id: str | None + plugin_unique_identifier: str | None + description: I18nObjectDict + icon: str | dict + label: I18nObjectDict + type: str + team_credentials: dict[str, Any] | None + is_team_authorization: bool + allow_delete: bool + datasources: list[Any] + labels: list[str] class DatasourceProviderApiEntity(BaseModel): @@ -28,8 +45,8 @@ class DatasourceProviderApiEntity(BaseModel): icon: str | dict label: I18nObject # label type: str - masked_credentials: dict | None = None - original_credentials: dict | None = None + masked_credentials: dict[str, Any] | None = None + original_credentials: dict[str, Any] | None = None is_team_authorization: bool = False allow_delete: bool = True plugin_id: str | None = Field(default="", description="The plugin id of the datasource") @@ -42,7 +59,7 @@ class DatasourceProviderApiEntity(BaseModel): def convert_none_to_empty_list(cls, v): return v if v is not None else [] - def to_dict(self) -> dict: + def to_dict(self) -> DatasourceProviderApiEntityDict: # ------------- # overwrite datasource parameter types for temp fix datasources = jsonable_encoder(self.datasources) @@ -53,7 +70,7 @@ class DatasourceProviderApiEntity(BaseModel): parameter["type"] = "files" # ------------- - return { + result: DatasourceProviderApiEntityDict = { "id": self.id, "author": self.author, "name": self.name, @@ -69,3 +86,4 @@ class DatasourceProviderApiEntity(BaseModel): "datasources": datasources, "labels": self.labels, } + return result diff --git a/api/core/datasource/entities/common_entities.py b/api/core/datasource/entities/common_entities.py index 3c64632dbba..726dafaa62a 100644 --- a/api/core/datasource/entities/common_entities.py +++ b/api/core/datasource/entities/common_entities.py @@ -1,22 +1,3 @@ -from pydantic import BaseModel, Field, model_validator +from core.tools.entities.common_entities import I18nObject, I18nObjectDict - -class I18nObject(BaseModel): - """ - Model class for i18n object. - """ - - en_US: str - zh_Hans: str | None = Field(default=None) - pt_BR: str | None = Field(default=None) - ja_JP: str | None = Field(default=None) - - @model_validator(mode="after") - def _(self): - self.zh_Hans = self.zh_Hans or self.en_US - self.pt_BR = self.pt_BR or self.en_US - self.ja_JP = self.ja_JP or self.en_US - return self - - def to_dict(self) -> dict: - return {"zh_Hans": self.zh_Hans, "en_US": self.en_US, "pt_BR": self.pt_BR, "ja_JP": self.ja_JP} +__all__ = ["I18nObject", "I18nObjectDict"] diff --git a/api/core/datasource/entities/datasource_entities.py b/api/core/datasource/entities/datasource_entities.py index a063a3680bc..443b503a693 100644 --- a/api/core/datasource/entities/datasource_entities.py +++ b/api/core/datasource/entities/datasource_entities.py @@ -2,14 +2,14 @@ from __future__ import annotations import enum from enum import StrEnum -from typing import Any +from typing import Any, TypedDict from pydantic import BaseModel, Field, ValidationInfo, field_validator from yarl import URL from configs import dify_config from core.entities.provider_entities import ProviderConfig -from core.plugin.entities.oauth import OAuthSchema +from core.plugin.entities import OAuthSchema from core.plugin.entities.parameters import ( PluginParameter, PluginParameterOption, @@ -129,7 +129,7 @@ class DatasourceEntity(BaseModel): identity: DatasourceIdentity parameters: list[DatasourceParameter] = Field(default_factory=list) description: I18nObject = Field(..., description="The label of the datasource") - output_schema: dict | None = None + output_schema: dict[str, Any] | None = None @field_validator("parameters", mode="before") @classmethod @@ -179,6 +179,12 @@ class DatasourceProviderEntityWithPlugin(DatasourceProviderEntity): datasources: list[DatasourceEntity] = Field(default_factory=list) +class DatasourceInvokeMetaDict(TypedDict): + time_cost: float + error: str | None + tool_config: dict[str, Any] | None + + class DatasourceInvokeMeta(BaseModel): """ Datasource invoke meta @@ -186,7 +192,7 @@ class DatasourceInvokeMeta(BaseModel): time_cost: float = Field(..., description="The time cost of the tool invoke") error: str | None = None - tool_config: dict | None = None + tool_config: dict[str, Any] | None = None @classmethod def empty(cls) -> DatasourceInvokeMeta: @@ -202,12 +208,13 @@ class DatasourceInvokeMeta(BaseModel): """ return cls(time_cost=0.0, error=error, tool_config={}) - def to_dict(self) -> dict: - return { + def to_dict(self) -> DatasourceInvokeMetaDict: + result: DatasourceInvokeMetaDict = { "time_cost": self.time_cost, "error": self.error, "tool_config": self.tool_config, } + return result class DatasourceLabel(BaseModel): @@ -235,7 +242,7 @@ class OnlineDocumentPage(BaseModel): page_id: str = Field(..., description="The page id") page_name: str = Field(..., description="The page title") - page_icon: dict | None = Field(None, description="The page icon") + page_icon: dict[str, Any] | None = Field(None, description="The page icon") type: str = Field(..., description="The type of the page") last_edited_time: str = Field(..., description="The last edited time") parent_id: str | None = Field(None, description="The parent page id") @@ -294,7 +301,7 @@ class GetWebsiteCrawlRequest(BaseModel): Get website crawl request """ - crawl_parameters: dict = Field(..., description="The crawl parameters") + crawl_parameters: dict[str, Any] = Field(..., description="The crawl parameters") class WebSiteInfoDetail(BaseModel): @@ -351,7 +358,7 @@ class OnlineDriveFileBucket(BaseModel): bucket: str | None = Field(None, description="The file bucket") files: list[OnlineDriveFile] = Field(..., description="The file list") is_truncated: bool = Field(False, description="Whether the result is truncated") - next_page_parameters: dict | None = Field(None, description="Parameters for fetching the next page") + next_page_parameters: dict[str, Any] | None = Field(None, description="Parameters for fetching the next page") class OnlineDriveBrowseFilesRequest(BaseModel): @@ -362,7 +369,7 @@ class OnlineDriveBrowseFilesRequest(BaseModel): bucket: str | None = Field(None, description="The file bucket") prefix: str = Field(..., description="The parent folder ID") max_keys: int = Field(20, description="Page size for pagination") - next_page_parameters: dict | None = Field(None, description="Parameters for fetching the next page") + next_page_parameters: dict[str, Any] | None = Field(None, description="Parameters for fetching the next page") class OnlineDriveBrowseFilesResponse(BaseModel): diff --git a/api/core/datasource/utils/message_transformer.py b/api/core/datasource/utils/message_transformer.py index 04f15dee31d..6a3f9e684a9 100644 --- a/api/core/datasource/utils/message_transformer.py +++ b/api/core/datasource/utils/message_transformer.py @@ -2,11 +2,10 @@ import logging from collections.abc import Generator from mimetypes import guess_extension, guess_type -from graphon.file import File, FileTransferMethod, FileType - from core.datasource.entities.datasource_entities import DatasourceMessage from core.tools.tool_file_manager import ToolFileManager from core.workflow.file_reference import parse_file_reference +from graphon.file import File, FileTransferMethod, FileType from models.tools import ToolFile logger = logging.getLogger(__name__) @@ -71,8 +70,8 @@ class DatasourceFileMessageTransformer: if not isinstance(message.message, DatasourceMessage.BlobMessage): raise ValueError("unexpected message type") - # FIXME: should do a type check here. - assert isinstance(message.message.blob, bytes) + if not isinstance(message.message.blob, bytes): + raise TypeError(f"Expected blob to be bytes, got {type(message.message.blob).__name__}") tool_file_manager = ToolFileManager() blob_tool_file: ToolFile | None = tool_file_manager.create_file_by_raw( user_id=user_id, diff --git a/api/core/entities/__init__.py b/api/core/entities/__init__.py index b848da36649..e77eac87ba9 100644 --- a/api/core/entities/__init__.py +++ b/api/core/entities/__init__.py @@ -1 +1,8 @@ +from core.entities.plugin_credential_type import PluginCredentialType + DEFAULT_PLUGIN_ID = "langgenius" + +__all__ = [ + "DEFAULT_PLUGIN_ID", + "PluginCredentialType", +] diff --git a/api/core/entities/execution_extra_content.py b/api/core/entities/execution_extra_content.py index d304c982cdb..04ae193396b 100644 --- a/api/core/entities/execution_extra_content.py +++ b/api/core/entities/execution_extra_content.py @@ -3,9 +3,9 @@ from __future__ import annotations from collections.abc import Mapping, Sequence from typing import Any, TypeAlias -from graphon.nodes.human_input.entities import FormInput, UserAction from pydantic import BaseModel, ConfigDict, Field +from graphon.nodes.human_input.entities import FormInput, UserAction from models.execution_extra_content import ExecutionContentType diff --git a/api/core/entities/knowledge_entities.py b/api/core/entities/knowledge_entities.py index b1ba3c3e2a6..a13938f3fbd 100644 --- a/api/core/entities/knowledge_entities.py +++ b/api/core/entities/knowledge_entities.py @@ -1,3 +1,5 @@ +from typing import Any + from pydantic import BaseModel, Field, field_validator @@ -37,7 +39,7 @@ class PipelineDocument(BaseModel): id: str position: int data_source_type: str - data_source_info: dict | None = None + data_source_info: dict[str, Any] | None = None name: str indexing_status: str error: str | None = None diff --git a/api/core/entities/mcp_provider.py b/api/core/entities/mcp_provider.py index a440829b46b..bfa4f569155 100644 --- a/api/core/entities/mcp_provider.py +++ b/api/core/entities/mcp_provider.py @@ -6,7 +6,6 @@ from enum import StrEnum from typing import TYPE_CHECKING, Any from urllib.parse import urlparse -from graphon.file import helpers as file_helpers from pydantic import BaseModel from configs import dify_config @@ -16,6 +15,7 @@ from core.helper.provider_cache import NoOpProviderCredentialCache from core.mcp.types import OAuthClientInformation, OAuthClientMetadata, OAuthTokens from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolProviderType +from graphon.file import helpers as file_helpers if TYPE_CHECKING: from models.tools import MCPToolProvider diff --git a/api/core/entities/model_entities.py b/api/core/entities/model_entities.py index 84d95c38c6c..e99a131500f 100644 --- a/api/core/entities/model_entities.py +++ b/api/core/entities/model_entities.py @@ -1,10 +1,11 @@ from collections.abc import Sequence from enum import StrEnum, auto +from pydantic import BaseModel, ConfigDict + from graphon.model_runtime.entities.common_entities import I18nObject from graphon.model_runtime.entities.model_entities import ModelType, ProviderModel from graphon.model_runtime.entities.provider_entities import ProviderEntity -from pydantic import BaseModel, ConfigDict class ModelStatus(StrEnum): diff --git a/api/core/entities/plugin_credential_type.py b/api/core/entities/plugin_credential_type.py new file mode 100644 index 00000000000..005e92473c8 --- /dev/null +++ b/api/core/entities/plugin_credential_type.py @@ -0,0 +1,9 @@ +import enum + + +class PluginCredentialType(enum.Enum): + MODEL = 0 # must be 0 for API contract compatibility + TOOL = 1 # must be 1 for API contract compatibility + + def to_number(self): + return self.value diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index 782897aea98..38b87e2cd1f 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -6,22 +6,14 @@ import re from collections import defaultdict from collections.abc import Iterator, Sequence from json import JSONDecodeError +from typing import Any -from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType -from graphon.model_runtime.entities.provider_entities import ( - ConfigurateMethod, - CredentialFormSchema, - FormType, - ProviderEntity, -) -from graphon.model_runtime.model_providers.__base.ai_model import AIModel -from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory -from graphon.model_runtime.runtime import ModelRuntime from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator from sqlalchemy import func, select from sqlalchemy.orm import Session from constants import HIDDEN_VALUE +from core.entities import PluginCredentialType from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity from core.entities.provider_entities import ( CustomConfiguration, @@ -32,6 +24,16 @@ from core.entities.provider_entities import ( from core.helper import encrypter from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory +from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType +from graphon.model_runtime.entities.provider_entities import ( + ConfigurateMethod, + CredentialFormSchema, + FormType, + ProviderEntity, +) +from graphon.model_runtime.model_providers.base.ai_model import AIModel +from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory +from graphon.model_runtime.runtime import ModelRuntime from libs.datetime_utils import naive_utc_now from models.engine import db from models.enums import CredentialSourceType @@ -46,7 +48,6 @@ from models.provider import ( TenantPreferredModelProvider, ) from models.provider_ids import ModelProviderID -from services.enterprise.plugin_manager_service import PluginCredentialType logger = logging.getLogger(__name__) @@ -111,7 +112,7 @@ class ProviderConfiguration(BaseModel): return ModelProviderFactory(model_runtime=self._bound_model_runtime) return create_plugin_model_provider_factory(tenant_id=self.tenant_id) - def get_current_credentials(self, model_type: ModelType, model: str) -> dict | None: + def get_current_credentials(self, model_type: ModelType, model: str) -> dict[str, Any] | None: """ Get current credentials. @@ -233,7 +234,7 @@ class ProviderConfiguration(BaseModel): return session.execute(stmt).scalar_one_or_none() - def _get_specific_provider_credential(self, credential_id: str) -> dict | None: + def _get_specific_provider_credential(self, credential_id: str) -> dict[str, Any] | None: """ Get a specific provider credential by ID. :param credential_id: Credential ID @@ -297,7 +298,7 @@ class ProviderConfiguration(BaseModel): stmt = stmt.where(ProviderCredential.id != exclude_id) return session.execute(stmt).scalar_one_or_none() is not None - def get_provider_credential(self, credential_id: str | None = None) -> dict | None: + def get_provider_credential(self, credential_id: str | None = None) -> dict[str, Any] | None: """ Get provider credentials. @@ -317,32 +318,28 @@ class ProviderConfiguration(BaseModel): else [], ) - def validate_provider_credentials(self, credentials: dict, credential_id: str = "", session: Session | None = None): + def validate_provider_credentials(self, credentials: dict[str, Any], credential_id: str = ""): """ Validate custom credentials. :param credentials: provider credentials :param credential_id: (Optional)If provided, can use existing credential's hidden api key to validate - :param session: optional database session :return: """ + provider_credential_secret_variables = self.extract_secret_variables( + self.provider.provider_credential_schema.credential_form_schemas + if self.provider.provider_credential_schema + else [] + ) - def _validate(s: Session): - # Get provider credential secret variables - provider_credential_secret_variables = self.extract_secret_variables( - self.provider.provider_credential_schema.credential_form_schemas - if self.provider.provider_credential_schema - else [] - ) - - if credential_id: + if credential_id: + with Session(db.engine) as session: try: stmt = select(ProviderCredential).where( ProviderCredential.tenant_id == self.tenant_id, ProviderCredential.provider_name.in_(self._get_provider_names()), ProviderCredential.id == credential_id, ) - credential_record = s.execute(stmt).scalar_one_or_none() - # fix origin data + credential_record = session.execute(stmt).scalar_one_or_none() if credential_record and credential_record.encrypted_config: if not credential_record.encrypted_config.startswith("{"): original_credentials = {"openai_api_key": credential_record.encrypted_config} @@ -353,31 +350,23 @@ class ProviderConfiguration(BaseModel): except JSONDecodeError: original_credentials = {} - # encrypt credentials - for key, value in credentials.items(): - if key in provider_credential_secret_variables: - # if send [__HIDDEN__] in secret input, it will be same as original value - if value == HIDDEN_VALUE and key in original_credentials: - credentials[key] = encrypter.decrypt_token( - tenant_id=self.tenant_id, token=original_credentials[key] - ) - - model_provider_factory = self.get_model_provider_factory() - validated_credentials = model_provider_factory.provider_credentials_validate( - provider=self.provider.provider, credentials=credentials - ) - - for key, value in validated_credentials.items(): + for key, value in credentials.items(): if key in provider_credential_secret_variables: - validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value) + if value == HIDDEN_VALUE and key in original_credentials: + credentials[key] = encrypter.decrypt_token( + tenant_id=self.tenant_id, token=original_credentials[key] + ) - return validated_credentials + model_provider_factory = self.get_model_provider_factory() + validated_credentials = model_provider_factory.provider_credentials_validate( + provider=self.provider.provider, credentials=credentials + ) - if session: - return _validate(session) - else: - with Session(db.engine) as new_session: - return _validate(new_session) + for key, value in validated_credentials.items(): + if key in provider_credential_secret_variables and isinstance(value, str): + validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value) + + return validated_credentials def _generate_provider_credential_name(self, session) -> str: """ @@ -447,21 +436,23 @@ class ProviderConfiguration(BaseModel): provider_names.append(model_provider_id.provider_name) return provider_names - def create_provider_credential(self, credentials: dict, credential_name: str | None): + def create_provider_credential(self, credentials: dict[str, Any], credential_name: str | None): """ Add custom provider credentials. :param credentials: provider credentials :param credential_name: credential name :return: """ - with Session(db.engine) as session: + with Session(db.engine) as pre_session: if credential_name: - if self._check_provider_credential_name_exists(credential_name=credential_name, session=session): + if self._check_provider_credential_name_exists(credential_name=credential_name, session=pre_session): raise ValueError(f"Credential with name '{credential_name}' already exists.") else: - credential_name = self._generate_provider_credential_name(session) + credential_name = self._generate_provider_credential_name(pre_session) - credentials = self.validate_provider_credentials(credentials=credentials, session=session) + credentials = self.validate_provider_credentials(credentials=credentials) + + with Session(db.engine) as session: provider_record = self._get_provider_record(session) try: new_record = ProviderCredential( @@ -474,7 +465,6 @@ class ProviderConfiguration(BaseModel): session.flush() if not provider_record: - # If provider record does not exist, create it provider_record = Provider( tenant_id=self.tenant_id, provider_name=self.provider.provider, @@ -515,7 +505,7 @@ class ProviderConfiguration(BaseModel): def update_provider_credential( self, - credentials: dict, + credentials: dict[str, Any], credential_id: str, credential_name: str | None, ): @@ -527,15 +517,15 @@ class ProviderConfiguration(BaseModel): :param credential_name: credential name :return: """ - with Session(db.engine) as session: + with Session(db.engine) as pre_session: if credential_name and self._check_provider_credential_name_exists( - credential_name=credential_name, session=session, exclude_id=credential_id + credential_name=credential_name, session=pre_session, exclude_id=credential_id ): raise ValueError(f"Credential with name '{credential_name}' already exists.") - credentials = self.validate_provider_credentials( - credentials=credentials, credential_id=credential_id, session=session - ) + credentials = self.validate_provider_credentials(credentials=credentials, credential_id=credential_id) + + with Session(db.engine) as session: provider_record = self._get_provider_record(session) stmt = select(ProviderCredential).where( ProviderCredential.id == credential_id, @@ -543,12 +533,10 @@ class ProviderConfiguration(BaseModel): ProviderCredential.provider_name.in_(self._get_provider_names()), ) - # Get the credential record to update credential_record = session.execute(stmt).scalar_one_or_none() if not credential_record: raise ValueError("Credential record not found.") try: - # Update credential credential_record.encrypted_config = json.dumps(credentials) credential_record.updated_at = naive_utc_now() if credential_name: @@ -760,7 +748,7 @@ class ProviderConfiguration(BaseModel): def _get_specific_custom_model_credential( self, model_type: ModelType, model: str, credential_id: str - ) -> dict | None: + ) -> dict[str, Any] | None: """ Get a specific provider credential by ID. :param credential_id: Credential ID @@ -832,7 +820,9 @@ class ProviderConfiguration(BaseModel): stmt = stmt.where(ProviderModelCredential.id != exclude_id) return session.execute(stmt).scalar_one_or_none() is not None - def get_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str | None) -> dict | None: + def get_custom_model_credential( + self, model_type: ModelType, model: str, credential_id: str | None + ) -> dict[str, Any] | None: """ Get custom model credentials. @@ -872,9 +862,8 @@ class ProviderConfiguration(BaseModel): self, model_type: ModelType, model: str, - credentials: dict, + credentials: dict[str, Any], credential_id: str = "", - session: Session | None = None, ): """ Validate custom model credentials. @@ -885,16 +874,14 @@ class ProviderConfiguration(BaseModel): :param credential_id: (Optional)If provided, can use existing credential's hidden api key to validate :return: """ + provider_credential_secret_variables = self.extract_secret_variables( + self.provider.model_credential_schema.credential_form_schemas + if self.provider.model_credential_schema + else [] + ) - def _validate(s: Session): - # Get provider credential secret variables - provider_credential_secret_variables = self.extract_secret_variables( - self.provider.model_credential_schema.credential_form_schemas - if self.provider.model_credential_schema - else [] - ) - - if credential_id: + if credential_id: + with Session(db.engine) as session: try: stmt = select(ProviderModelCredential).where( ProviderModelCredential.id == credential_id, @@ -903,7 +890,7 @@ class ProviderConfiguration(BaseModel): ProviderModelCredential.model_name == model, ProviderModelCredential.model_type == model_type, ) - credential_record = s.execute(stmt).scalar_one_or_none() + credential_record = session.execute(stmt).scalar_one_or_none() original_credentials = ( json.loads(credential_record.encrypted_config) if credential_record and credential_record.encrypted_config @@ -912,34 +899,26 @@ class ProviderConfiguration(BaseModel): except JSONDecodeError: original_credentials = {} - # decrypt credentials - for key, value in credentials.items(): - if key in provider_credential_secret_variables: - # if send [__HIDDEN__] in secret input, it will be same as original value - if value == HIDDEN_VALUE and key in original_credentials: - credentials[key] = encrypter.decrypt_token( - tenant_id=self.tenant_id, token=original_credentials[key] - ) - - model_provider_factory = self.get_model_provider_factory() - validated_credentials = model_provider_factory.model_credentials_validate( - provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials - ) - - for key, value in validated_credentials.items(): + for key, value in credentials.items(): if key in provider_credential_secret_variables: - validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value) + if value == HIDDEN_VALUE and key in original_credentials: + credentials[key] = encrypter.decrypt_token( + tenant_id=self.tenant_id, token=original_credentials[key] + ) - return validated_credentials + model_provider_factory = self.get_model_provider_factory() + validated_credentials = model_provider_factory.model_credentials_validate( + provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials + ) - if session: - return _validate(session) - else: - with Session(db.engine) as new_session: - return _validate(new_session) + for key, value in validated_credentials.items(): + if key in provider_credential_secret_variables and isinstance(value, str): + validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value) + + return validated_credentials def create_custom_model_credential( - self, model_type: ModelType, model: str, credentials: dict, credential_name: str | None + self, model_type: ModelType, model: str, credentials: dict[str, Any], credential_name: str | None ) -> None: """ Create a custom model credential. @@ -949,20 +928,22 @@ class ProviderConfiguration(BaseModel): :param credentials: model credentials dict :return: """ - with Session(db.engine) as session: + with Session(db.engine) as pre_session: if credential_name: if self._check_custom_model_credential_name_exists( - model=model, model_type=model_type, credential_name=credential_name, session=session + model=model, model_type=model_type, credential_name=credential_name, session=pre_session ): raise ValueError(f"Model credential with name '{credential_name}' already exists for {model}.") else: credential_name = self._generate_custom_model_credential_name( - model=model, model_type=model_type, session=session + model=model, model_type=model_type, session=pre_session ) - # validate custom model config - credentials = self.validate_custom_model_credentials( - model_type=model_type, model=model, credentials=credentials, session=session - ) + + credentials = self.validate_custom_model_credentials( + model_type=model_type, model=model, credentials=credentials + ) + + with Session(db.engine) as session: provider_model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session) try: @@ -977,7 +958,6 @@ class ProviderConfiguration(BaseModel): session.add(credential) session.flush() - # save provider model if not provider_model_record: provider_model_record = ProviderModel( tenant_id=self.tenant_id, @@ -1002,7 +982,12 @@ class ProviderConfiguration(BaseModel): raise def update_custom_model_credential( - self, model_type: ModelType, model: str, credentials: dict, credential_name: str | None, credential_id: str + self, + model_type: ModelType, + model: str, + credentials: dict[str, Any], + credential_name: str | None, + credential_id: str, ) -> None: """ Update a custom model credential. @@ -1014,23 +999,24 @@ class ProviderConfiguration(BaseModel): :param credential_id: credential id :return: """ - with Session(db.engine) as session: + with Session(db.engine) as pre_session: if credential_name and self._check_custom_model_credential_name_exists( model=model, model_type=model_type, credential_name=credential_name, - session=session, + session=pre_session, exclude_id=credential_id, ): raise ValueError(f"Model credential with name '{credential_name}' already exists for {model}.") - # validate custom model config - credentials = self.validate_custom_model_credentials( - model_type=model_type, - model=model, - credentials=credentials, - credential_id=credential_id, - session=session, - ) + + credentials = self.validate_custom_model_credentials( + model_type=model_type, + model=model, + credentials=credentials, + credential_id=credential_id, + ) + + with Session(db.engine) as session: provider_model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session) stmt = select(ProviderModelCredential).where( @@ -1045,7 +1031,6 @@ class ProviderConfiguration(BaseModel): raise ValueError("Credential record not found.") try: - # Update credential credential_record.encrypted_config = json.dumps(credentials) credential_record.updated_at = naive_utc_now() if credential_name: @@ -1412,7 +1397,9 @@ class ProviderConfiguration(BaseModel): # Get model instance of LLM return model_provider_factory.get_model_type_instance(provider=self.provider.provider, model_type=model_type) - def get_model_schema(self, model_type: ModelType, model: str, credentials: dict | None) -> AIModelEntity | None: + def get_model_schema( + self, model_type: ModelType, model: str, credentials: dict[str, Any] | None + ) -> AIModelEntity | None: """ Get model schema """ @@ -1471,7 +1458,7 @@ class ProviderConfiguration(BaseModel): return secret_input_form_variables - def obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]): + def obfuscated_credentials(self, credentials: dict[str, Any], credential_form_schemas: list[CredentialFormSchema]): """ Obfuscated credentials. diff --git a/api/core/entities/provider_entities.py b/api/core/entities/provider_entities.py index 2c8767a32b8..72b29c2277f 100644 --- a/api/core/entities/provider_entities.py +++ b/api/core/entities/provider_entities.py @@ -1,9 +1,8 @@ from __future__ import annotations from enum import StrEnum, auto -from typing import Union +from typing import Any, Union -from graphon.model_runtime.entities.model_entities import ModelType from pydantic import BaseModel, ConfigDict, Field from core.entities.parameter_entities import ( @@ -13,6 +12,7 @@ from core.entities.parameter_entities import ( ToolSelectorScope, ) from core.tools.entities.common_entities import I18nObject +from graphon.model_runtime.entities.model_entities import ModelType class ProviderQuotaType(StrEnum): @@ -88,7 +88,7 @@ class SystemConfiguration(BaseModel): enabled: bool current_quota_type: ProviderQuotaType | None = None quota_configurations: list[QuotaConfiguration] = [] - credentials: dict | None = None + credentials: dict[str, Any] | None = None class CustomProviderConfiguration(BaseModel): @@ -96,7 +96,7 @@ class CustomProviderConfiguration(BaseModel): Model class for provider custom configuration. """ - credentials: dict + credentials: dict[str, Any] current_credential_id: str | None = None current_credential_name: str | None = None available_credentials: list[CredentialConfiguration] = [] @@ -109,7 +109,7 @@ class CustomModelConfiguration(BaseModel): model: str model_type: ModelType - credentials: dict | None + credentials: dict[str, Any] | None current_credential_id: str | None = None current_credential_name: str | None = None available_model_credentials: list[CredentialConfiguration] = [] @@ -145,7 +145,7 @@ class ModelLoadBalancingConfiguration(BaseModel): id: str name: str - credentials: dict + credentials: dict[str, Any] credential_source_type: str | None = None credential_id: str | None = None diff --git a/api/core/extension/api_based_extension_requestor.py b/api/core/extension/api_based_extension_requestor.py index f9e60990497..01139d07e27 100644 --- a/api/core/extension/api_based_extension_requestor.py +++ b/api/core/extension/api_based_extension_requestor.py @@ -1,4 +1,4 @@ -from typing import cast +from typing import Any, cast import httpx @@ -14,7 +14,7 @@ class APIBasedExtensionRequestor: self.api_endpoint = api_endpoint self.api_key = api_key - def request(self, point: APIBasedExtensionPoint, params: dict): + def request(self, point: APIBasedExtensionPoint, params: dict[str, Any]) -> dict[str, Any]: """ Request the api. @@ -49,4 +49,4 @@ class APIBasedExtensionRequestor: if response.status_code != 200: raise ValueError(f"request error, status_code: {response.status_code}, content: {response.text[:100]}") - return cast(dict, response.json()) + return cast(dict[str, Any], response.json()) diff --git a/api/core/extension/extensible.py b/api/core/extension/extensible.py index c2789a7a353..c08e319aaca 100644 --- a/api/core/extension/extensible.py +++ b/api/core/extension/extensible.py @@ -21,8 +21,8 @@ class ExtensionModule(StrEnum): class ModuleExtension(BaseModel): extension_class: Any | None = None name: str - label: dict | None = None - form_schema: list | None = None + label: dict[str, Any] | None = None + form_schema: list[dict[str, Any]] | None = None builtin: bool = True position: int | None = None @@ -32,9 +32,9 @@ class Extensible: name: str tenant_id: str - config: dict | None = None + config: dict[str, Any] | None = None - def __init__(self, tenant_id: str, config: dict | None = None): + def __init__(self, tenant_id: str, config: dict[str, Any] | None = None): self.tenant_id = tenant_id self.config = config diff --git a/api/core/external_data_tool/api/api.py b/api/core/external_data_tool/api/api.py index 564801f1890..8ce068cfbbe 100644 --- a/api/core/external_data_tool/api/api.py +++ b/api/core/external_data_tool/api/api.py @@ -1,3 +1,6 @@ +from collections.abc import Mapping +from typing import Any, TypedDict + from sqlalchemy import select from core.extension.api_based_extension_requestor import APIBasedExtensionRequestor @@ -7,6 +10,16 @@ from extensions.ext_database import db from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint +class ApiToolConfig(TypedDict, total=False): + """Expected config shape for ApiExternalDataTool. + + Not used directly in method signatures (base class accepts dict[str, Any]); + kept here to document the keys this tool reads from config. + """ + + api_based_extension_id: str + + class ApiExternalDataTool(ExternalDataTool): """ The api external data tool. @@ -16,7 +29,7 @@ class ApiExternalDataTool(ExternalDataTool): """the unique name of external data tool""" @classmethod - def validate_config(cls, tenant_id: str, config: dict): + def validate_config(cls, tenant_id: str, config: dict[str, Any]): """ Validate the incoming form config data. @@ -37,7 +50,7 @@ class ApiExternalDataTool(ExternalDataTool): if not api_based_extension: raise ValueError("api_based_extension_id is invalid") - def query(self, inputs: dict, query: str | None = None) -> str: + def query(self, inputs: Mapping[str, Any], query: str | None = None) -> str: """ Query the external data tool. diff --git a/api/core/external_data_tool/base.py b/api/core/external_data_tool/base.py index cbec2e4e421..12bea4e9e50 100644 --- a/api/core/external_data_tool/base.py +++ b/api/core/external_data_tool/base.py @@ -1,4 +1,6 @@ from abc import ABC, abstractmethod +from collections.abc import Mapping +from typing import Any from core.extension.extensible import Extensible, ExtensionModule @@ -15,14 +17,14 @@ class ExternalDataTool(Extensible, ABC): variable: str """the tool variable name of app tool""" - def __init__(self, tenant_id: str, app_id: str, variable: str, config: dict | None = None): + def __init__(self, tenant_id: str, app_id: str, variable: str, config: dict[str, Any] | None = None): super().__init__(tenant_id, config) self.app_id = app_id self.variable = variable @classmethod @abstractmethod - def validate_config(cls, tenant_id: str, config: dict): + def validate_config(cls, tenant_id: str, config: dict[str, Any]): """ Validate the incoming form config data. @@ -33,7 +35,7 @@ class ExternalDataTool(Extensible, ABC): raise NotImplementedError @abstractmethod - def query(self, inputs: dict, query: str | None = None) -> str: + def query(self, inputs: Mapping[str, Any], query: str | None = None) -> str: """ Query the external data tool. diff --git a/api/core/external_data_tool/factory.py b/api/core/external_data_tool/factory.py index 6c542d681b3..f404aa7286f 100644 --- a/api/core/external_data_tool/factory.py +++ b/api/core/external_data_tool/factory.py @@ -6,14 +6,14 @@ from extensions.ext_code_based_extension import code_based_extension class ExternalDataToolFactory: - def __init__(self, name: str, tenant_id: str, app_id: str, variable: str, config: dict): + def __init__(self, name: str, tenant_id: str, app_id: str, variable: str, config: dict[str, Any]): extension_class = code_based_extension.extension_class(ExtensionModule.EXTERNAL_DATA_TOOL, name) self.__extension_instance = extension_class( tenant_id=tenant_id, app_id=app_id, variable=variable, config=config ) @classmethod - def validate_config(cls, name: str, tenant_id: str, config: dict): + def validate_config(cls, name: str, tenant_id: str, config: dict[str, Any]) -> None: """ Validate the incoming form config data. diff --git a/api/core/helper/code_executor/code_executor.py b/api/core/helper/code_executor/code_executor.py index 68ee7509eff..96a8cf2fc1b 100644 --- a/api/core/helper/code_executor/code_executor.py +++ b/api/core/helper/code_executor/code_executor.py @@ -5,7 +5,6 @@ from threading import Lock from typing import Any import httpx -from graphon.nodes.code.entities import CodeLanguage from pydantic import BaseModel from yarl import URL @@ -15,6 +14,7 @@ from core.helper.code_executor.jinja2.jinja2_transformer import Jinja2TemplateTr from core.helper.code_executor.python3.python3_transformer import Python3TemplateTransformer from core.helper.code_executor.template_transformer import TemplateTransformer from core.helper.http_client_pooling import get_pooled_http_client +from graphon.nodes.code.entities import CodeLanguage logger = logging.getLogger(__name__) code_execution_endpoint_url = URL(str(dify_config.CODE_EXECUTION_ENDPOINT)) diff --git a/api/core/helper/code_executor/template_transformer.py b/api/core/helper/code_executor/template_transformer.py index b96a9ce3808..38864a1830c 100644 --- a/api/core/helper/code_executor/template_transformer.py +++ b/api/core/helper/code_executor/template_transformer.py @@ -102,7 +102,7 @@ class TemplateTransformer(ABC): @classmethod def serialize_inputs(cls, inputs: Mapping[str, Any]) -> str: - inputs_json_str = dumps_with_segments(inputs, ensure_ascii=False).encode() + inputs_json_str = dumps_with_segments(inputs).encode() input_base64_encoded = b64encode(inputs_json_str).decode("utf-8") return input_base64_encoded diff --git a/api/core/helper/credential_utils.py b/api/core/helper/credential_utils.py index 240f498181d..882639a16a9 100644 --- a/api/core/helper/credential_utils.py +++ b/api/core/helper/credential_utils.py @@ -2,7 +2,7 @@ Credential utility functions for checking credential existence and policy compliance. """ -from services.enterprise.plugin_manager_service import PluginCredentialType +from core.entities import PluginCredentialType def is_credential_exists(credential_id: str, credential_type: "PluginCredentialType") -> bool: diff --git a/api/core/helper/model_provider_cache.py b/api/core/helper/model_provider_cache.py index 00fcfe0b80b..10d79a82392 100644 --- a/api/core/helper/model_provider_cache.py +++ b/api/core/helper/model_provider_cache.py @@ -1,6 +1,7 @@ import json from enum import StrEnum from json import JSONDecodeError +from typing import Any from extensions.ext_redis import redis_client @@ -15,7 +16,7 @@ class ProviderCredentialsCache: def __init__(self, tenant_id: str, identity_id: str, cache_type: ProviderCredentialsCacheType): self.cache_key = f"{cache_type}_credentials:tenant_id:{tenant_id}:id:{identity_id}" - def get(self) -> dict | None: + def get(self) -> dict[str, Any] | None: """ Get cached model provider credentials. @@ -33,7 +34,7 @@ class ProviderCredentialsCache: else: return None - def set(self, credentials: dict): + def set(self, credentials: dict[str, Any]): """ Cache model provider credentials. diff --git a/api/core/helper/moderation.py b/api/core/helper/moderation.py index a1e782a094e..f169f247cff 100644 --- a/api/core/helper/moderation.py +++ b/api/core/helper/moderation.py @@ -2,14 +2,13 @@ import logging import secrets from typing import cast -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.errors.invoke import InvokeBadRequestError -from graphon.model_runtime.model_providers.__base.moderation_model import ModerationModel - from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.entities import DEFAULT_PLUGIN_ID from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory from extensions.ext_hosting_provider import hosting_configuration +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.errors.invoke import InvokeBadRequestError +from graphon.model_runtime.model_providers.base.moderation_model import ModerationModel from models.provider import ProviderType logger = logging.getLogger(__name__) diff --git a/api/core/helper/provider_cache.py b/api/core/helper/provider_cache.py index ffb51483868..9f167ca49c1 100644 --- a/api/core/helper/provider_cache.py +++ b/api/core/helper/provider_cache.py @@ -17,7 +17,7 @@ class ProviderCredentialsCache(ABC): """Generate cache key based on subclass implementation""" pass - def get(self) -> dict | None: + def get(self) -> dict[str, Any] | None: """Get cached provider credentials""" cached_credentials = redis_client.get(self.cache_key) if cached_credentials: @@ -71,7 +71,7 @@ class ToolProviderCredentialsCache(ProviderCredentialsCache): class NoOpProviderCredentialCache: """No-op provider credential cache""" - def get(self) -> dict | None: + def get(self) -> dict[str, Any] | None: """Get cached provider credentials""" return None diff --git a/api/core/helper/ssrf_proxy.py b/api/core/helper/ssrf_proxy.py index e38592bb7bb..91e92712b7c 100644 --- a/api/core/helper/ssrf_proxy.py +++ b/api/core/helper/ssrf_proxy.py @@ -12,6 +12,7 @@ from pydantic import TypeAdapter, ValidationError from configs import dify_config from core.helper.http_client_pooling import get_pooled_http_client from core.tools.errors import ToolSSRFError +from graphon.http.response import HttpResponse logger = logging.getLogger(__name__) @@ -267,4 +268,47 @@ class SSRFProxy: return patch(url=url, max_retries=max_retries, **kwargs) +def _to_graphon_http_response(response: httpx.Response) -> HttpResponse: + """Convert an ``httpx`` response into Graphon's transport-agnostic wrapper.""" + return HttpResponse( + status_code=response.status_code, + headers=dict(response.headers), + content=response.content, + url=str(response.url) if response.url else None, + reason_phrase=response.reason_phrase, + fallback_text=response.text, + ) + + +class GraphonSSRFProxy: + """Adapter exposing SSRF helpers behind Graphon's ``HttpClientProtocol``.""" + + @property + def max_retries_exceeded_error(self) -> type[Exception]: + return max_retries_exceeded_error + + @property + def request_error(self) -> type[Exception]: + return request_error + + def get(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse: + return _to_graphon_http_response(get(url=url, max_retries=max_retries, **kwargs)) + + def head(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse: + return _to_graphon_http_response(head(url=url, max_retries=max_retries, **kwargs)) + + def post(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse: + return _to_graphon_http_response(post(url=url, max_retries=max_retries, **kwargs)) + + def put(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse: + return _to_graphon_http_response(put(url=url, max_retries=max_retries, **kwargs)) + + def delete(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse: + return _to_graphon_http_response(delete(url=url, max_retries=max_retries, **kwargs)) + + def patch(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse: + return _to_graphon_http_response(patch(url=url, max_retries=max_retries, **kwargs)) + + ssrf_proxy = SSRFProxy() +graphon_ssrf_proxy = GraphonSSRFProxy() diff --git a/api/core/helper/tool_parameter_cache.py b/api/core/helper/tool_parameter_cache.py index 54674d4ff6b..bf5bf9af03b 100644 --- a/api/core/helper/tool_parameter_cache.py +++ b/api/core/helper/tool_parameter_cache.py @@ -1,6 +1,7 @@ import json from enum import StrEnum from json import JSONDecodeError +from typing import Any from extensions.ext_redis import redis_client @@ -18,7 +19,7 @@ class ToolParameterCache: f":identity_id:{identity_id}" ) - def get(self) -> dict | None: + def get(self) -> dict[str, Any] | None: """ Get cached model provider credentials. @@ -36,7 +37,7 @@ class ToolParameterCache: else: return None - def set(self, parameters: dict): + def set(self, parameters: dict[str, Any]): """Cache model provider credentials.""" redis_client.setex(self.cache_key, 86400, json.dumps(parameters)) diff --git a/api/core/hosting_configuration.py b/api/core/hosting_configuration.py index 60f5434bc1e..8bcb899b231 100644 --- a/api/core/hosting_configuration.py +++ b/api/core/hosting_configuration.py @@ -1,10 +1,12 @@ +from typing import Any + from flask import Flask -from graphon.model_runtime.entities.model_entities import ModelType from pydantic import BaseModel from configs import dify_config from core.entities import DEFAULT_PLUGIN_ID from core.entities.provider_entities import ProviderQuotaType, QuotaUnit, RestrictModel +from graphon.model_runtime.entities.model_entities import ModelType class HostingQuota(BaseModel): @@ -28,7 +30,7 @@ class FreeHostingQuota(HostingQuota): class HostingProvider(BaseModel): enabled: bool = False - credentials: dict | None = None + credentials: dict[str, Any] | None = None quota_unit: QuotaUnit | None = None quotas: list[HostingQuota] = [] diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index 8c911124f0c..9a6323fe32c 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -11,7 +11,6 @@ from typing import Any import psutil from flask import Flask, current_app -from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy import delete, func, select, update from sqlalchemy.orm.exc import ObjectDeletedError @@ -37,6 +36,7 @@ from core.tools.utils.web_reader_tool import get_image_upload_file_ids from extensions.ext_database import db from extensions.ext_redis import redis_client from extensions.ext_storage import storage +from graphon.model_runtime.entities.model_entities import ModelType from libs import helper from libs.datetime_utils import naive_utc_now from models import Account @@ -778,7 +778,9 @@ class IndexingRunner: @staticmethod def _update_document_index_status( - document_id: str, after_indexing_status: IndexingStatus, extra_update_params: dict | None = None + document_id: str, + after_indexing_status: IndexingStatus, + extra_update_params: Mapping[Any, Any] | None = None, ): """ Update the document indexing status. @@ -805,7 +807,7 @@ class IndexingRunner: db.session.commit() @staticmethod - def _update_segments_by_document(dataset_document_id: str, update_params: dict): + def _update_segments_by_document(dataset_document_id: str, update_params: Mapping[Any, Any]): """ Update the document segment by document id. """ diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py index d39630ad951..348526b0efb 100644 --- a/api/core/llm_generator/llm_generator.py +++ b/api/core/llm_generator/llm_generator.py @@ -2,14 +2,9 @@ import json import logging import re from collections.abc import Sequence -from typing import Protocol, cast +from typing import Any, Protocol, TypedDict, cast import json_repair -from graphon.enums import WorkflowNodeExecutionMetadataKey -from graphon.model_runtime.entities.llm_entities import LLMResult -from graphon.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from sqlalchemy import select from core.app.app_config.entities import ModelConfig @@ -35,6 +30,11 @@ from core.ops.utils import measure_time from core.prompt.utils.prompt_template_parser import PromptTemplateParser from extensions.ext_database import db from extensions.ext_storage import storage +from graphon.enums import WorkflowNodeExecutionMetadataKey +from graphon.model_runtime.entities.llm_entities import LLMResult +from graphon.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from models import App, Message, WorkflowNodeExecutionModel from models.workflow import Workflow @@ -49,6 +49,17 @@ class WorkflowServiceInterface(Protocol): pass +class CodeGenerateResultDict(TypedDict): + code: str + language: str + error: str + + +class StructuredOutputResultDict(TypedDict): + output: str + error: str + + class LLMGenerator: @classmethod def generate_conversation_name( @@ -293,7 +304,7 @@ class LLMGenerator: cls, tenant_id: str, args: RuleCodeGeneratePayload, - ): + ) -> CodeGenerateResultDict: if args.code_language == "python": prompt_template = PromptTemplateParser(PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE) else: @@ -362,7 +373,9 @@ class LLMGenerator: return answer.strip() @classmethod - def generate_structured_output(cls, tenant_id: str, args: RuleStructuredOutputPayload): + def generate_structured_output( + cls, tenant_id: str, args: RuleStructuredOutputPayload + ) -> StructuredOutputResultDict: model_manager = ModelManager.for_tenant(tenant_id=tenant_id) model_instance = model_manager.get_model_instance( tenant_id=tenant_id, @@ -454,7 +467,7 @@ class LLMGenerator: ): session = db.session() - app: App | None = session.query(App).where(App.id == flow_id).first() + app: App | None = session.scalar(select(App).where(App.id == flow_id).limit(1)) if not app: raise ValueError("App not found.") workflow = workflow_service.get_draft_workflow(app_model=app) @@ -520,7 +533,7 @@ class LLMGenerator: def __instruction_modify_common( tenant_id: str, model_config: ModelConfig, - last_run: dict | None, + last_run: dict[str, Any] | None, current: str | None, error_message: str | None, instruction: str, diff --git a/api/core/llm_generator/output_parser/structured_output.py b/api/core/llm_generator/output_parser/structured_output.py index a1710f11ace..d2e375626f5 100644 --- a/api/core/llm_generator/output_parser/structured_output.py +++ b/api/core/llm_generator/output_parser/structured_output.py @@ -5,6 +5,11 @@ from enum import StrEnum from typing import Any, Literal, cast, overload import json_repair +from pydantic import TypeAdapter, ValidationError + +from core.llm_generator.output_parser.errors import OutputParserError +from core.llm_generator.prompts import STRUCTURED_OUTPUT_PROMPT +from core.model_manager import ModelInstance from graphon.model_runtime.callbacks.base_callback import Callback from graphon.model_runtime.entities.llm_entities import ( LLMResult, @@ -21,11 +26,6 @@ from graphon.model_runtime.entities.message_entities import ( TextPromptMessageContent, ) from graphon.model_runtime.entities.model_entities import AIModelEntity, ParameterRule -from pydantic import TypeAdapter, ValidationError - -from core.llm_generator.output_parser.errors import OutputParserError -from core.llm_generator.prompts import STRUCTURED_OUTPUT_PROMPT -from core.model_manager import ModelInstance class ResponseFormat(StrEnum): @@ -200,9 +200,9 @@ def _handle_native_json_schema( provider: str, model_schema: AIModelEntity, structured_output_schema: Mapping, - model_parameters: dict, + model_parameters: dict[str, Any], rules: list[ParameterRule], -): +) -> dict[str, Any]: """ Handle structured output for models with native JSON schema support. @@ -224,7 +224,7 @@ def _handle_native_json_schema( return model_parameters -def _set_response_format(model_parameters: dict, rules: list): +def _set_response_format(model_parameters: dict[str, Any], rules: list[ParameterRule]) -> None: """ Set the appropriate response format parameter based on model rules. @@ -326,7 +326,7 @@ def _prepare_schema_for_model(provider: str, model_schema: AIModelEntity, schema return {"schema": processed_schema, "name": "llm_response"} -def remove_additional_properties(schema: dict): +def remove_additional_properties(schema: dict[str, Any]) -> None: """ Remove additionalProperties fields from JSON schema. Used for models like Gemini that don't support this property. @@ -349,7 +349,7 @@ def remove_additional_properties(schema: dict): remove_additional_properties(item) -def convert_boolean_to_string(schema: dict): +def convert_boolean_to_string(schema: dict[str, Any]) -> None: """ Convert boolean type specifications to string in JSON schema. diff --git a/api/core/logging/filters.py b/api/core/logging/filters.py index 1e8aa8d566b..dee1432363c 100644 --- a/api/core/logging/filters.py +++ b/api/core/logging/filters.py @@ -6,6 +6,7 @@ import logging import flask from core.logging.context import get_request_id, get_trace_id +from core.logging.structured_formatter import IdentityDict class TraceContextFilter(logging.Filter): @@ -60,7 +61,7 @@ class IdentityContextFilter(logging.Filter): record.user_type = identity.get("user_type", "") return True - def _extract_identity(self) -> dict[str, str]: + def _extract_identity(self) -> IdentityDict: """Extract identity from current_user if in request context.""" try: if not flask.has_request_context(): @@ -77,7 +78,7 @@ class IdentityContextFilter(logging.Filter): from models import Account from models.model import EndUser - identity: dict[str, str] = {} + identity: IdentityDict = {} if isinstance(user, Account): if user.current_tenant_id: diff --git a/api/core/logging/structured_formatter.py b/api/core/logging/structured_formatter.py index 4295d2dd349..ae7be91c173 100644 --- a/api/core/logging/structured_formatter.py +++ b/api/core/logging/structured_formatter.py @@ -3,13 +3,32 @@ import logging import traceback from datetime import UTC, datetime -from typing import Any +from typing import Any, NotRequired, TypedDict import orjson from configs import dify_config +class IdentityDict(TypedDict, total=False): + tenant_id: str + user_id: str + user_type: str + + +class LogDict(TypedDict): + ts: str + severity: str + service: str + caller: str + message: str + trace_id: NotRequired[str] + span_id: NotRequired[str] + identity: NotRequired[IdentityDict] + attributes: NotRequired[dict[str, Any]] + stack_trace: NotRequired[str] + + class StructuredJSONFormatter(logging.Formatter): """ JSON log formatter following the specified schema: @@ -49,9 +68,9 @@ class StructuredJSONFormatter(logging.Formatter): return json.dumps(log_dict, default=str, ensure_ascii=False) - def _build_log_dict(self, record: logging.LogRecord) -> dict[str, Any]: + def _build_log_dict(self, record: logging.LogRecord) -> LogDict: # Core fields - log_dict: dict[str, Any] = { + log_dict: LogDict = { "ts": datetime.now(UTC).isoformat(timespec="milliseconds").replace("+00:00", "Z"), "severity": self.SEVERITY_MAP.get(record.levelno, "INFO"), "service": self._service_name, @@ -84,7 +103,7 @@ class StructuredJSONFormatter(logging.Formatter): return log_dict - def _extract_identity(self, record: logging.LogRecord) -> dict[str, str] | None: + def _extract_identity(self, record: logging.LogRecord) -> IdentityDict | None: tenant_id = getattr(record, "tenant_id", None) user_id = getattr(record, "user_id", None) user_type = getattr(record, "user_type", None) @@ -92,7 +111,7 @@ class StructuredJSONFormatter(logging.Formatter): if not any([tenant_id, user_id, user_type]): return None - identity: dict[str, str] = {} + identity: IdentityDict = {} if tenant_id: identity["tenant_id"] = tenant_id if user_id: diff --git a/api/core/mcp/auth/auth_flow.py b/api/core/mcp/auth/auth_flow.py index d015769b54e..1d8356acf67 100644 --- a/api/core/mcp/auth/auth_flow.py +++ b/api/core/mcp/auth/auth_flow.py @@ -146,7 +146,7 @@ def discover_protected_resource_metadata( return ProtectedResourceMetadata.model_validate(response.json()) elif response.status_code == 404: continue # Try next URL - except (RequestError, ValidationError): + except (RequestError, ValidationError, json.JSONDecodeError): continue # Try next URL return None @@ -166,7 +166,7 @@ def discover_oauth_authorization_server_metadata( return OAuthMetadata.model_validate(response.json()) elif response.status_code == 404: continue # Try next URL - except (RequestError, ValidationError): + except (RequestError, ValidationError, json.JSONDecodeError): continue # Try next URL return None @@ -276,7 +276,7 @@ def check_support_resource_discovery(server_url: str) -> tuple[bool, str]: else: return False, "" return False, "" - except RequestError: + except (RequestError, json.JSONDecodeError, IndexError): # Not support resource discovery, fall back to well-known OAuth metadata return False, "" diff --git a/api/core/mcp/auth_client.py b/api/core/mcp/auth_client.py index d8724b8de5b..173913196eb 100644 --- a/api/core/mcp/auth_client.py +++ b/api/core/mcp/auth_client.py @@ -122,7 +122,7 @@ class MCPClientWithAuthRetry(MCPClient): logger.exception("Authentication retry failed") raise MCPAuthError(f"Authentication retry failed: {e}") from e - def _execute_with_retry(self, func: Callable[..., Any], *args, **kwargs) -> Any: + def _execute_with_retry[**P, R](self, func: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R: """ Execute a function with authentication retry logic. diff --git a/api/core/mcp/client/streamable_client.py b/api/core/mcp/client/streamable_client.py index 5c3cd0d8f80..acba3e666b1 100644 --- a/api/core/mcp/client/streamable_client.py +++ b/api/core/mcp/client/streamable_client.py @@ -303,9 +303,16 @@ class StreamableHTTPTransport: if response.status_code == 404: if isinstance(message.root, JSONRPCRequest): + error_msg = ( + f"MCP server URL returned 404 Not Found: {self.url} " + "— verify the server URL is correct and the server is running" + if is_initialization + else "Session terminated by server" + ) self._send_session_terminated_error( ctx.server_to_client_queue, message.root.id, + message=error_msg, ) return @@ -381,12 +388,13 @@ class StreamableHTTPTransport: self, server_to_client_queue: ServerToClientQueue, request_id: RequestId, + message: str = "Session terminated by server", ): """Send a session terminated error response.""" jsonrpc_error = JSONRPCError( jsonrpc="2.0", id=request_id, - error=ErrorData(code=32600, message="Session terminated by server"), + error=ErrorData(code=32600, message=message), ) session_message = SessionMessage(JSONRPCMessage(jsonrpc_error)) server_to_client_queue.put(session_message) diff --git a/api/core/mcp/entities.py b/api/core/mcp/entities.py index d6d3a677c68..21edc86a57f 100644 --- a/api/core/mcp/entities.py +++ b/api/core/mcp/entities.py @@ -1,6 +1,6 @@ from dataclasses import dataclass from enum import StrEnum -from typing import Any, TypeVar +from typing import Any from pydantic import BaseModel @@ -9,12 +9,9 @@ from core.mcp.types import LATEST_PROTOCOL_VERSION, OAuthClientInformation, OAut SUPPORTED_PROTOCOL_VERSIONS: list[str] = ["2024-11-05", "2025-03-26", LATEST_PROTOCOL_VERSION] -SessionT = TypeVar("SessionT", bound=BaseSession[Any, Any, Any, Any, Any]) -LifespanContextT = TypeVar("LifespanContextT") - @dataclass -class RequestContext[SessionT: BaseSession[Any, Any, Any, Any, Any], LifespanContextT]: +class RequestContext[SessionT: BaseSession, LifespanContextT]: request_id: RequestId meta: RequestParams.Meta | None session: SessionT diff --git a/api/core/mcp/server/streamable_http.py b/api/core/mcp/server/streamable_http.py index 278add8cc90..884610ca82d 100644 --- a/api/core/mcp/server/streamable_http.py +++ b/api/core/mcp/server/streamable_http.py @@ -1,20 +1,30 @@ import json import logging from collections.abc import Mapping -from typing import Any, cast - -from graphon.variables.input_entities import VariableEntity, VariableEntityType +from typing import Any, NotRequired, TypedDict, cast from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom from core.app.features.rate_limiting.rate_limit import RateLimitGenerator from core.mcp import types as mcp_types +from graphon.variables.input_entities import VariableEntity, VariableEntityType from models.model import App, AppMCPServer, AppMode, EndUser from services.app_generate_service import AppGenerateService logger = logging.getLogger(__name__) +class ToolParameterSchemaDict(TypedDict): + type: str + properties: dict[str, Any] + required: list[str] + + +class ToolArgumentsDict(TypedDict): + query: NotRequired[str] + inputs: dict[str, Any] + + def handle_mcp_request( app: App, request: mcp_types.ClientRequest, @@ -119,7 +129,7 @@ def handle_list_tools( mcp_types.Tool( name=app_name, description=description, - inputSchema=parameter_schema, + inputSchema=cast(dict[str, Any], parameter_schema), ) ], ) @@ -154,7 +164,7 @@ def build_parameter_schema( app_mode: str, user_input_form: list[VariableEntity], parameters_dict: dict[str, str], -) -> dict[str, Any]: +) -> ToolParameterSchemaDict: """Build parameter schema for the tool""" parameters, required = convert_input_form_to_parameters(user_input_form, parameters_dict) @@ -174,17 +184,18 @@ def build_parameter_schema( } -def prepare_tool_arguments(app: App, arguments: dict[str, Any]) -> dict[str, Any]: +def prepare_tool_arguments(app: App, arguments: dict[str, Any]) -> ToolArgumentsDict: """Prepare arguments based on app mode""" - if app.mode == AppMode.WORKFLOW: - return {"inputs": arguments} - elif app.mode == AppMode.COMPLETION: - return {"query": "", "inputs": arguments} - else: - # Chat modes - create a copy to avoid modifying original dict - args_copy = arguments.copy() - query = args_copy.pop("query", "") - return {"query": query, "inputs": args_copy} + match app.mode: + case AppMode.WORKFLOW: + return {"inputs": arguments} + case AppMode.COMPLETION: + return {"query": "", "inputs": arguments} + case _: + # Chat modes - create a copy to avoid modifying original dict + args_copy = arguments.copy() + query = args_copy.pop("query", "") + return {"query": query, "inputs": args_copy} def extract_answer_from_response(app: App, response: Any) -> str: @@ -218,17 +229,13 @@ def process_streaming_response(response: RateLimitGenerator) -> str: def process_mapping_response(app: App, response: Mapping) -> str: """Process mapping response based on app mode""" - if app.mode in { - AppMode.ADVANCED_CHAT, - AppMode.COMPLETION, - AppMode.CHAT, - AppMode.AGENT_CHAT, - }: - return response.get("answer", "") - elif app.mode == AppMode.WORKFLOW: - return json.dumps(response["data"]["outputs"], ensure_ascii=False) - else: - raise ValueError("Invalid app mode: " + str(app.mode)) + match app.mode: + case AppMode.ADVANCED_CHAT | AppMode.COMPLETION | AppMode.CHAT | AppMode.AGENT_CHAT: + return response.get("answer", "") + case AppMode.WORKFLOW: + return json.dumps(response["data"]["outputs"], ensure_ascii=False) + case _: + raise ValueError("Invalid app mode: " + str(app.mode)) def convert_input_form_to_parameters( diff --git a/api/core/mcp/session/base_session.py b/api/core/mcp/session/base_session.py index e50fd42198a..70d45b15c45 100644 --- a/api/core/mcp/session/base_session.py +++ b/api/core/mcp/session/base_session.py @@ -4,7 +4,7 @@ from collections.abc import Callable from concurrent.futures import Future, ThreadPoolExecutor, TimeoutError from datetime import timedelta from types import TracebackType -from typing import Any, Self, cast +from typing import Any, Self from httpx import HTTPStatusError from pydantic import BaseModel @@ -55,7 +55,7 @@ class RequestResponder[ReceiveRequestT: ClientRequest | ServerRequest, SendResul request: ReceiveRequestT _session: "BaseSession[Any, Any, SendResultT, ReceiveRequestT, Any]" - _on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any] + _on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], object] def __init__( self, @@ -63,7 +63,7 @@ class RequestResponder[ReceiveRequestT: ClientRequest | ServerRequest, SendResul request_meta: RequestParams.Meta | None, request: ReceiveRequestT, session: "BaseSession[Any, Any, SendResultT, ReceiveRequestT, Any]", - on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any], + on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], object], ): self.request_id = request_id self.request_meta = request_meta @@ -338,12 +338,11 @@ class BaseSession[ validated_request = self._receive_request_type.model_validate( message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True) ) - validated_request = cast(ReceiveRequestT, validated_request) responder = RequestResponder[ReceiveRequestT, SendResultT]( request_id=message.message.root.id, request_meta=validated_request.root.params.meta if validated_request.root.params else None, - request=validated_request, + request=validated_request, # type: ignore[arg-type] # mypy can't narrow constrained TypeVar from model_validate session=self, on_complete=lambda r: self._in_flight.pop(r.request_id, None), ) @@ -359,15 +358,14 @@ class BaseSession[ notification = self._receive_notification_type.model_validate( message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True) ) - notification = cast(ReceiveNotificationT, notification) # Handle cancellation notifications if isinstance(notification.root, CancelledNotification): cancelled_id = notification.root.params.requestId if cancelled_id in self._in_flight: self._in_flight[cancelled_id].cancel() else: - self._received_notification(notification) - self._handle_incoming(notification) + self._received_notification(notification) # type: ignore[arg-type] + self._handle_incoming(notification) # type: ignore[arg-type] except Exception as e: # For other validation errors, log and continue logger.warning("Failed to validate notification: %s. Message was: %s", e, message.message.root) diff --git a/api/core/mcp/types.py b/api/core/mcp/types.py index 2653d20a7db..10e3082aa36 100644 --- a/api/core/mcp/types.py +++ b/api/core/mcp/types.py @@ -31,7 +31,6 @@ ProgressToken = str | int Cursor = str Role = Literal["user", "assistant"] RequestId = Annotated[int | str, Field(union_mode="left_to_right")] -type AnyFunction = Callable[..., Any] class RequestParams(BaseModel): diff --git a/api/core/mcp/utils.py b/api/core/mcp/utils.py index 7e350441768..7b5a7635f1c 100644 --- a/api/core/mcp/utils.py +++ b/api/core/mcp/utils.py @@ -4,11 +4,11 @@ from contextlib import AbstractContextManager import httpx import httpx_sse -from graphon.model_runtime.utils.encoders import jsonable_encoder from httpx_sse import connect_sse from configs import dify_config from core.mcp.types import ErrorData, JSONRPCError +from graphon.model_runtime.utils.encoders import jsonable_encoder HTTP_REQUEST_NODE_SSL_VERIFY = dify_config.HTTP_REQUEST_NODE_SSL_VERIFY diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py index 09c84538a9a..d840ee213c6 100644 --- a/api/core/memory/token_buffer_memory.py +++ b/api/core/memory/token_buffer_memory.py @@ -1,5 +1,14 @@ from collections.abc import Sequence +from sqlalchemy import select +from sqlalchemy.orm import sessionmaker + +from core.app.app_config.features.file_upload.manager import FileUploadConfigManager +from core.app.file_access import DatabaseFileAccessController +from core.model_manager import ModelInstance +from core.prompt.utils.extract_thread_messages import extract_thread_messages +from extensions.ext_database import db +from factories import file_factory from graphon.file import file_manager from graphon.model_runtime.entities import ( AssistantPromptMessage, @@ -10,15 +19,6 @@ from graphon.model_runtime.entities import ( UserPromptMessage, ) from graphon.model_runtime.entities.message_entities import PromptMessageContentUnionTypes -from sqlalchemy import select -from sqlalchemy.orm import sessionmaker - -from core.app.app_config.features.file_upload.manager import FileUploadConfigManager -from core.app.file_access import DatabaseFileAccessController -from core.model_manager import ModelInstance -from core.prompt.utils.extract_thread_messages import extract_thread_messages -from extensions.ext_database import db -from factories import file_factory from models.model import AppMode, Conversation, Message, MessageFile from models.workflow import Workflow from repositories.api_workflow_run_repository import APIWorkflowRunRepository @@ -61,27 +61,28 @@ class TokenBufferMemory: :param is_user_message: whether this is a user message :return: PromptMessage """ - if self.conversation.mode in {AppMode.AGENT_CHAT, AppMode.COMPLETION, AppMode.CHAT}: - file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config) - elif self.conversation.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: - app = self.conversation.app - if not app: - raise ValueError("App not found for conversation") + match self.conversation.mode: + case AppMode.AGENT_CHAT | AppMode.COMPLETION | AppMode.CHAT: + file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config) + case AppMode.ADVANCED_CHAT | AppMode.WORKFLOW: + app = self.conversation.app + if not app: + raise ValueError("App not found for conversation") - if not message.workflow_run_id: - raise ValueError("Workflow run ID not found") + if not message.workflow_run_id: + raise ValueError("Workflow run ID not found") - workflow_run = self.workflow_run_repo.get_workflow_run_by_id( - tenant_id=app.tenant_id, app_id=app.id, run_id=message.workflow_run_id - ) - if not workflow_run: - raise ValueError(f"Workflow run not found: {message.workflow_run_id}") - workflow = db.session.scalar(select(Workflow).where(Workflow.id == workflow_run.workflow_id)) - if not workflow: - raise ValueError(f"Workflow not found: {workflow_run.workflow_id}") - file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False) - else: - raise AssertionError(f"Invalid app mode: {self.conversation.mode}") + workflow_run = self.workflow_run_repo.get_workflow_run_by_id( + tenant_id=app.tenant_id, app_id=app.id, run_id=message.workflow_run_id + ) + if not workflow_run: + raise ValueError(f"Workflow run not found: {message.workflow_run_id}") + workflow = db.session.scalar(select(Workflow).where(Workflow.id == workflow_run.workflow_id)) + if not workflow: + raise ValueError(f"Workflow not found: {workflow_run.workflow_id}") + file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False) + case _: + raise AssertionError(f"Invalid app mode: {self.conversation.mode}") detail = ImagePromptMessageContent.DETAIL.HIGH if file_extra_config and app_record: diff --git a/api/core/model_manager.py b/api/core/model_manager.py index 87d1d7fba60..86d0e3baaa6 100644 --- a/api/core/model_manager.py +++ b/api/core/model_manager.py @@ -1,22 +1,9 @@ import logging from collections.abc import Callable, Generator, Iterable, Mapping, Sequence -from typing import IO, Any, Literal, Optional, Union, cast, overload - -from graphon.model_runtime.callbacks.base_callback import Callback -from graphon.model_runtime.entities.llm_entities import LLMResult -from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool -from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelFeature, ModelType -from graphon.model_runtime.entities.rerank_entities import RerankResult -from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeConnectionError, InvokeRateLimitError -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from graphon.model_runtime.model_providers.__base.moderation_model import ModerationModel -from graphon.model_runtime.model_providers.__base.rerank_model import RerankModel -from graphon.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel -from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel -from graphon.model_runtime.model_providers.__base.tts_model import TTSModel +from typing import IO, Any, Literal, Optional, ParamSpec, TypeVar, Union, cast, overload from configs import dify_config +from core.entities import PluginCredentialType from core.entities.embedding_type import EmbeddingInputType from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle from core.entities.provider_entities import ModelLoadBalancingConfiguration @@ -24,10 +11,24 @@ from core.errors.error import ProviderTokenNotInitError from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager from core.provider_manager import ProviderManager from extensions.ext_redis import redis_client +from graphon.model_runtime.callbacks.base_callback import Callback +from graphon.model_runtime.entities.llm_entities import LLMResult +from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool +from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelFeature, ModelType +from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult +from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeConnectionError, InvokeRateLimitError +from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel +from graphon.model_runtime.model_providers.base.moderation_model import ModerationModel +from graphon.model_runtime.model_providers.base.rerank_model import RerankModel +from graphon.model_runtime.model_providers.base.speech2text_model import Speech2TextModel +from graphon.model_runtime.model_providers.base.text_embedding_model import TextEmbeddingModel +from graphon.model_runtime.model_providers.base.tts_model import TTSModel from models.provider import ProviderType -from services.enterprise.plugin_manager_service import PluginCredentialType logger = logging.getLogger(__name__) +P = ParamSpec("P") +R = TypeVar("R") class ModelInstance: @@ -77,7 +78,7 @@ class ModelInstance: @staticmethod def _get_load_balancing_manager( - configuration: ProviderConfiguration, model_type: ModelType, model: str, credentials: dict + configuration: ProviderConfiguration, model_type: ModelType, model: str, credentials: dict[str, Any] ) -> Optional["LBModelManager"]: """ Get load balancing model credentials @@ -115,7 +116,7 @@ class ModelInstance: def invoke_llm( self, prompt_messages: Sequence[PromptMessage], - model_parameters: dict | None = None, + model_parameters: dict[str, Any] | None = None, tools: Sequence[PromptMessageTool] | None = None, stop: list[str] | None = None, stream: Literal[True] = True, @@ -126,7 +127,7 @@ class ModelInstance: def invoke_llm( self, prompt_messages: list[PromptMessage], - model_parameters: dict | None = None, + model_parameters: dict[str, Any] | None = None, tools: Sequence[PromptMessageTool] | None = None, stop: list[str] | None = None, stream: Literal[False] = False, @@ -137,7 +138,7 @@ class ModelInstance: def invoke_llm( self, prompt_messages: list[PromptMessage], - model_parameters: dict | None = None, + model_parameters: dict[str, Any] | None = None, tools: Sequence[PromptMessageTool] | None = None, stop: list[str] | None = None, stream: bool = True, @@ -147,7 +148,7 @@ class ModelInstance: def invoke_llm( self, prompt_messages: Sequence[PromptMessage], - model_parameters: dict | None = None, + model_parameters: dict[str, Any] | None = None, tools: Sequence[PromptMessageTool] | None = None, stop: Sequence[str] | None = None, stream: bool = True, @@ -169,13 +170,13 @@ class ModelInstance: return cast( Union[LLMResult, Generator], self._round_robin_invoke( - function=self.model_type_instance.invoke, + self.model_type_instance.invoke, model=self.model_name, credentials=self.credentials, - prompt_messages=prompt_messages, + prompt_messages=list(prompt_messages), model_parameters=model_parameters, - tools=tools, - stop=stop, + tools=list(tools) if tools else None, + stop=list(stop) if stop else None, stream=stream, callbacks=callbacks, ), @@ -193,15 +194,12 @@ class ModelInstance: """ if not isinstance(self.model_type_instance, LargeLanguageModel): raise Exception("Model type instance is not LargeLanguageModel") - return cast( - int, - self._round_robin_invoke( - function=self.model_type_instance.get_num_tokens, - model=self.model_name, - credentials=self.credentials, - prompt_messages=prompt_messages, - tools=tools, - ), + return self._round_robin_invoke( + self.model_type_instance.get_num_tokens, + model=self.model_name, + credentials=self.credentials, + prompt_messages=list(prompt_messages), + tools=list(tools) if tools else None, ) def invoke_text_embedding( @@ -216,15 +214,12 @@ class ModelInstance: """ if not isinstance(self.model_type_instance, TextEmbeddingModel): raise Exception("Model type instance is not TextEmbeddingModel") - return cast( - EmbeddingResult, - self._round_robin_invoke( - function=self.model_type_instance.invoke, - model=self.model_name, - credentials=self.credentials, - texts=texts, - input_type=input_type, - ), + return self._round_robin_invoke( + self.model_type_instance.invoke, + model=self.model_name, + credentials=self.credentials, + texts=texts, + input_type=input_type, ) def invoke_multimodal_embedding( @@ -241,15 +236,12 @@ class ModelInstance: """ if not isinstance(self.model_type_instance, TextEmbeddingModel): raise Exception("Model type instance is not TextEmbeddingModel") - return cast( - EmbeddingResult, - self._round_robin_invoke( - function=self.model_type_instance.invoke, - model=self.model_name, - credentials=self.credentials, - multimodel_documents=multimodel_documents, - input_type=input_type, - ), + return self._round_robin_invoke( + self.model_type_instance.invoke, + model=self.model_name, + credentials=self.credentials, + multimodel_documents=multimodel_documents, + input_type=input_type, ) def get_text_embedding_num_tokens(self, texts: list[str]) -> list[int]: @@ -261,14 +253,11 @@ class ModelInstance: """ if not isinstance(self.model_type_instance, TextEmbeddingModel): raise Exception("Model type instance is not TextEmbeddingModel") - return cast( - list[int], - self._round_robin_invoke( - function=self.model_type_instance.get_num_tokens, - model=self.model_name, - credentials=self.credentials, - texts=texts, - ), + return self._round_robin_invoke( + self.model_type_instance.get_num_tokens, + model=self.model_name, + credentials=self.credentials, + texts=texts, ) def invoke_rerank( @@ -289,23 +278,20 @@ class ModelInstance: """ if not isinstance(self.model_type_instance, RerankModel): raise Exception("Model type instance is not RerankModel") - return cast( - RerankResult, - self._round_robin_invoke( - function=self.model_type_instance.invoke, - model=self.model_name, - credentials=self.credentials, - query=query, - docs=docs, - score_threshold=score_threshold, - top_n=top_n, - ), + return self._round_robin_invoke( + self.model_type_instance.invoke, + model=self.model_name, + credentials=self.credentials, + query=query, + docs=docs, + score_threshold=score_threshold, + top_n=top_n, ) def invoke_multimodal_rerank( self, - query: dict, - docs: list[dict], + query: MultimodalRerankInput, + docs: list[MultimodalRerankInput], score_threshold: float | None = None, top_n: int | None = None, ) -> RerankResult: @@ -320,17 +306,14 @@ class ModelInstance: """ if not isinstance(self.model_type_instance, RerankModel): raise Exception("Model type instance is not RerankModel") - return cast( - RerankResult, - self._round_robin_invoke( - function=self.model_type_instance.invoke_multimodal_rerank, - model=self.model_name, - credentials=self.credentials, - query=query, - docs=docs, - score_threshold=score_threshold, - top_n=top_n, - ), + return self._round_robin_invoke( + self.model_type_instance.invoke_multimodal_rerank, + model=self.model_name, + credentials=self.credentials, + query=query, + docs=docs, + score_threshold=score_threshold, + top_n=top_n, ) def invoke_moderation(self, text: str) -> bool: @@ -342,14 +325,11 @@ class ModelInstance: """ if not isinstance(self.model_type_instance, ModerationModel): raise Exception("Model type instance is not ModerationModel") - return cast( - bool, - self._round_robin_invoke( - function=self.model_type_instance.invoke, - model=self.model_name, - credentials=self.credentials, - text=text, - ), + return self._round_robin_invoke( + self.model_type_instance.invoke, + model=self.model_name, + credentials=self.credentials, + text=text, ) def invoke_speech2text(self, file: IO[bytes]) -> str: @@ -361,14 +341,11 @@ class ModelInstance: """ if not isinstance(self.model_type_instance, Speech2TextModel): raise Exception("Model type instance is not Speech2TextModel") - return cast( - str, - self._round_robin_invoke( - function=self.model_type_instance.invoke, - model=self.model_name, - credentials=self.credentials, - file=file, - ), + return self._round_robin_invoke( + self.model_type_instance.invoke, + model=self.model_name, + credentials=self.credentials, + file=file, ) def invoke_tts(self, content_text: str, voice: str = "") -> Iterable[bytes]: @@ -381,18 +358,15 @@ class ModelInstance: """ if not isinstance(self.model_type_instance, TTSModel): raise Exception("Model type instance is not TTSModel") - return cast( - Iterable[bytes], - self._round_robin_invoke( - function=self.model_type_instance.invoke, - model=self.model_name, - credentials=self.credentials, - content_text=content_text, - voice=voice, - ), + return self._round_robin_invoke( + self.model_type_instance.invoke, + model=self.model_name, + credentials=self.credentials, + content_text=content_text, + voice=voice, ) - def _round_robin_invoke(self, function: Callable[..., Any], *args, **kwargs): + def _round_robin_invoke(self, function: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R: """ Round-robin invoke :param function: function to invoke @@ -430,9 +404,8 @@ class ModelInstance: continue try: - if "credentials" in kwargs: - del kwargs["credentials"] - return function(*args, **kwargs, credentials=lb_config.credentials) + kwargs["credentials"] = lb_config.credentials + return function(*args, **kwargs) except InvokeRateLimitError as e: # expire in 60 seconds self.load_balancing_manager.cooldown(lb_config, expire=60) @@ -556,7 +529,7 @@ class LBModelManager: model_type: ModelType, model: str, load_balancing_configs: list[ModelLoadBalancingConfiguration], - managed_credentials: dict | None = None, + managed_credentials: dict[str, Any] | None = None, ): """ Load balancing model manager diff --git a/api/core/moderation/api/api.py b/api/core/moderation/api/api.py index 2d72b17a04d..28165592fc5 100644 --- a/api/core/moderation/api/api.py +++ b/api/core/moderation/api/api.py @@ -1,3 +1,5 @@ +from typing import Any + from pydantic import BaseModel, Field from sqlalchemy import select @@ -10,7 +12,7 @@ from models.api_based_extension import APIBasedExtension class ModerationInputParams(BaseModel): app_id: str = "" - inputs: dict = Field(default_factory=dict) + inputs: dict[str, Any] = Field(default_factory=dict) query: str = "" @@ -23,7 +25,7 @@ class ApiModeration(Moderation): name: str = "api" @classmethod - def validate_config(cls, tenant_id: str, config: dict): + def validate_config(cls, tenant_id: str, config: dict[str, Any]): """ Validate the incoming form config data. @@ -41,7 +43,7 @@ class ApiModeration(Moderation): if not extension: raise ValueError("API-based Extension not found. Please check it again.") - def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult: + def moderation_for_inputs(self, inputs: dict[str, Any], query: str = "") -> ModerationInputsResult: flagged = False preset_response = "" if self.config is None: @@ -73,7 +75,7 @@ class ApiModeration(Moderation): flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response ) - def _get_config_by_requestor(self, extension_point: APIBasedExtensionPoint, params: dict): + def _get_config_by_requestor(self, extension_point: APIBasedExtensionPoint, params: dict[str, Any]): if self.config is None: raise ValueError("The config is not set.") extension = self._get_api_based_extension(self.tenant_id, self.config.get("api_based_extension_id", "")) diff --git a/api/core/moderation/base.py b/api/core/moderation/base.py index 31dd0d5568a..e090ee89add 100644 --- a/api/core/moderation/base.py +++ b/api/core/moderation/base.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod from enum import StrEnum, auto +from typing import Any from pydantic import BaseModel, Field @@ -15,7 +16,7 @@ class ModerationInputsResult(BaseModel): flagged: bool = False action: ModerationAction preset_response: str = "" - inputs: dict = Field(default_factory=dict) + inputs: dict[str, Any] = Field(default_factory=dict) query: str = "" @@ -33,13 +34,13 @@ class Moderation(Extensible, ABC): module: ExtensionModule = ExtensionModule.MODERATION - def __init__(self, app_id: str, tenant_id: str, config: dict | None = None): + def __init__(self, app_id: str, tenant_id: str, config: dict[str, Any] | None = None): super().__init__(tenant_id, config) self.app_id = app_id @classmethod @abstractmethod - def validate_config(cls, tenant_id: str, config: dict) -> None: + def validate_config(cls, tenant_id: str, config: dict[str, Any]) -> None: """ Validate the incoming form config data. @@ -50,7 +51,7 @@ class Moderation(Extensible, ABC): raise NotImplementedError @abstractmethod - def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult: + def moderation_for_inputs(self, inputs: dict[str, Any], query: str = "") -> ModerationInputsResult: """ Moderation for inputs. After the user inputs, this method will be called to perform sensitive content review @@ -75,7 +76,7 @@ class Moderation(Extensible, ABC): raise NotImplementedError @classmethod - def _validate_inputs_and_outputs_config(cls, config: dict, is_preset_response_required: bool): + def _validate_inputs_and_outputs_config(cls, config: dict[str, Any], is_preset_response_required: bool): # inputs_config inputs_config = config.get("inputs_config") if not isinstance(inputs_config, dict): diff --git a/api/core/moderation/factory.py b/api/core/moderation/factory.py index c2c8be6d6d0..c22306ac94e 100644 --- a/api/core/moderation/factory.py +++ b/api/core/moderation/factory.py @@ -1,3 +1,5 @@ +from typing import Any + from core.extension.extensible import ExtensionModule from core.moderation.base import Moderation, ModerationInputsResult, ModerationOutputsResult from extensions.ext_code_based_extension import code_based_extension @@ -6,12 +8,12 @@ from extensions.ext_code_based_extension import code_based_extension class ModerationFactory: __extension_instance: Moderation - def __init__(self, name: str, app_id: str, tenant_id: str, config: dict): + def __init__(self, name: str, app_id: str, tenant_id: str, config: dict[str, Any]): extension_class = code_based_extension.extension_class(ExtensionModule.MODERATION, name) self.__extension_instance = extension_class(app_id, tenant_id, config) @classmethod - def validate_config(cls, name: str, tenant_id: str, config: dict): + def validate_config(cls, name: str, tenant_id: str, config: dict[str, Any]): """ Validate the incoming form config data. @@ -24,7 +26,7 @@ class ModerationFactory: # FIXME: mypy error, try to fix it instead of using type: ignore extension_class.validate_config(tenant_id, config) # type: ignore - def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult: + def moderation_for_inputs(self, inputs: dict[str, Any], query: str = "") -> ModerationInputsResult: """ Moderation for inputs. After the user inputs, this method will be called to perform sensitive content review diff --git a/api/core/moderation/keywords/keywords.py b/api/core/moderation/keywords/keywords.py index 8d8d1537439..7d80d3a53c8 100644 --- a/api/core/moderation/keywords/keywords.py +++ b/api/core/moderation/keywords/keywords.py @@ -8,7 +8,7 @@ class KeywordsModeration(Moderation): name: str = "keywords" @classmethod - def validate_config(cls, tenant_id: str, config: dict): + def validate_config(cls, tenant_id: str, config: dict[str, Any]): """ Validate the incoming form config data. @@ -28,7 +28,7 @@ class KeywordsModeration(Moderation): if len(keywords_row_len) > 100: raise ValueError("the number of rows for the keywords must be less than 100") - def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult: + def moderation_for_inputs(self, inputs: dict[str, Any], query: str = "") -> ModerationInputsResult: flagged = False preset_response = "" if self.config is None: @@ -66,7 +66,7 @@ class KeywordsModeration(Moderation): flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response ) - def _is_violated(self, inputs: dict, keywords_list: list) -> bool: + def _is_violated(self, inputs: dict[str, Any], keywords_list: list[str]) -> bool: return any(self._check_keywords_in_value(keywords_list, value) for value in inputs.values()) def _check_keywords_in_value(self, keywords_list: Sequence[str], value: Any) -> bool: diff --git a/api/core/moderation/openai_moderation/openai_moderation.py b/api/core/moderation/openai_moderation/openai_moderation.py index dd038c77f13..6e6e94502cc 100644 --- a/api/core/moderation/openai_moderation/openai_moderation.py +++ b/api/core/moderation/openai_moderation/openai_moderation.py @@ -1,14 +1,15 @@ -from graphon.model_runtime.entities.model_entities import ModelType +from typing import Any from core.model_manager import ModelManager from core.moderation.base import Moderation, ModerationAction, ModerationInputsResult, ModerationOutputsResult +from graphon.model_runtime.entities.model_entities import ModelType class OpenAIModeration(Moderation): name: str = "openai_moderation" @classmethod - def validate_config(cls, tenant_id: str, config: dict): + def validate_config(cls, tenant_id: str, config: dict[str, Any]): """ Validate the incoming form config data. @@ -18,7 +19,7 @@ class OpenAIModeration(Moderation): """ cls._validate_inputs_and_outputs_config(config, True) - def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult: + def moderation_for_inputs(self, inputs: dict[str, Any], query: str = "") -> ModerationInputsResult: flagged = False preset_response = "" if self.config is None: @@ -49,7 +50,7 @@ class OpenAIModeration(Moderation): flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response ) - def _is_violated(self, inputs: dict): + def _is_violated(self, inputs: dict[str, Any]): text = "\n".join(str(inputs.values())) model_manager = ModelManager.for_tenant(tenant_id=self.tenant_id) model_instance = model_manager.get_model_instance( diff --git a/api/core/ops/base_trace_instance.py b/api/core/ops/base_trace_instance.py index 8c081ae2253..a1f96b9edf4 100644 --- a/api/core/ops/base_trace_instance.py +++ b/api/core/ops/base_trace_instance.py @@ -56,8 +56,10 @@ class BaseTraceInstance(ABC): if not service_account: raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}") - current_tenant = ( - session.query(TenantAccountJoin).filter_by(account_id=service_account.id, current=True).first() + current_tenant = session.scalar( + select(TenantAccountJoin) + .where(TenantAccountJoin.account_id == service_account.id, TenantAccountJoin.current.is_(True)) + .limit(1) ) if not current_tenant: raise ValueError(f"Current tenant not found for account {service_account.id}") diff --git a/api/core/ops/entities/config_entity.py b/api/core/ops/entities/config_entity.py index fda00ac3b9c..d78ce90aa1a 100644 --- a/api/core/ops/entities/config_entity.py +++ b/api/core/ops/entities/config_entity.py @@ -1,8 +1,8 @@ from enum import StrEnum -from pydantic import BaseModel, ValidationInfo, field_validator +from pydantic import BaseModel -from core.ops.utils import validate_integer_id, validate_project_name, validate_url, validate_url_with_path +from core.ops.utils import validate_project_name, validate_url class TracingProviderEnum(StrEnum): @@ -52,220 +52,5 @@ class BaseTracingConfig(BaseModel): return validate_project_name(v, default_name) -class ArizeConfig(BaseTracingConfig): - """ - Model class for Arize tracing config. - """ - - api_key: str | None = None - space_id: str | None = None - project: str | None = None - endpoint: str = "https://otlp.arize.com" - - @field_validator("project") - @classmethod - def project_validator(cls, v, info: ValidationInfo): - return cls.validate_project_field(v, "default") - - @field_validator("endpoint") - @classmethod - def endpoint_validator(cls, v, info: ValidationInfo): - return cls.validate_endpoint_url(v, "https://otlp.arize.com") - - -class PhoenixConfig(BaseTracingConfig): - """ - Model class for Phoenix tracing config. - """ - - api_key: str | None = None - project: str | None = None - endpoint: str = "https://app.phoenix.arize.com" - - @field_validator("project") - @classmethod - def project_validator(cls, v, info: ValidationInfo): - return cls.validate_project_field(v, "default") - - @field_validator("endpoint") - @classmethod - def endpoint_validator(cls, v, info: ValidationInfo): - return validate_url_with_path(v, "https://app.phoenix.arize.com") - - -class LangfuseConfig(BaseTracingConfig): - """ - Model class for Langfuse tracing config. - """ - - public_key: str - secret_key: str - host: str = "https://api.langfuse.com" - - @field_validator("host") - @classmethod - def host_validator(cls, v, info: ValidationInfo): - return validate_url_with_path(v, "https://api.langfuse.com") - - -class LangSmithConfig(BaseTracingConfig): - """ - Model class for Langsmith tracing config. - """ - - api_key: str - project: str - endpoint: str = "https://api.smith.langchain.com" - - @field_validator("endpoint") - @classmethod - def endpoint_validator(cls, v, info: ValidationInfo): - # LangSmith only allows HTTPS - return validate_url(v, "https://api.smith.langchain.com", allowed_schemes=("https",)) - - -class OpikConfig(BaseTracingConfig): - """ - Model class for Opik tracing config. - """ - - api_key: str | None = None - project: str | None = None - workspace: str | None = None - url: str = "https://www.comet.com/opik/api/" - - @field_validator("project") - @classmethod - def project_validator(cls, v, info: ValidationInfo): - return cls.validate_project_field(v, "Default Project") - - @field_validator("url") - @classmethod - def url_validator(cls, v, info: ValidationInfo): - return validate_url_with_path(v, "https://www.comet.com/opik/api/", required_suffix="/api/") - - -class WeaveConfig(BaseTracingConfig): - """ - Model class for Weave tracing config. - """ - - api_key: str - entity: str | None = None - project: str - endpoint: str = "https://trace.wandb.ai" - host: str | None = None - - @field_validator("endpoint") - @classmethod - def endpoint_validator(cls, v, info: ValidationInfo): - # Weave only allows HTTPS for endpoint - return validate_url(v, "https://trace.wandb.ai", allowed_schemes=("https",)) - - @field_validator("host") - @classmethod - def host_validator(cls, v, info: ValidationInfo): - if v is not None and v.strip() != "": - return validate_url(v, v, allowed_schemes=("https", "http")) - return v - - -class AliyunConfig(BaseTracingConfig): - """ - Model class for Aliyun tracing config. - """ - - app_name: str = "dify_app" - license_key: str - endpoint: str - - @field_validator("app_name") - @classmethod - def app_name_validator(cls, v, info: ValidationInfo): - return cls.validate_project_field(v, "dify_app") - - @field_validator("license_key") - @classmethod - def license_key_validator(cls, v, info: ValidationInfo): - if not v or v.strip() == "": - raise ValueError("License key cannot be empty") - return v - - @field_validator("endpoint") - @classmethod - def endpoint_validator(cls, v, info: ValidationInfo): - # aliyun uses two URL formats, which may include a URL path - return validate_url_with_path(v, "https://tracing-analysis-dc-hz.aliyuncs.com") - - -class TencentConfig(BaseTracingConfig): - """ - Tencent APM tracing config - """ - - token: str - endpoint: str - service_name: str - - @field_validator("token") - @classmethod - def token_validator(cls, v, info: ValidationInfo): - if not v or v.strip() == "": - raise ValueError("Token cannot be empty") - return v - - @field_validator("endpoint") - @classmethod - def endpoint_validator(cls, v, info: ValidationInfo): - return cls.validate_endpoint_url(v, "https://apm.tencentcloudapi.com") - - @field_validator("service_name") - @classmethod - def service_name_validator(cls, v, info: ValidationInfo): - return cls.validate_project_field(v, "dify_app") - - -class MLflowConfig(BaseTracingConfig): - """ - Model class for MLflow tracing config. - """ - - tracking_uri: str = "http://localhost:5000" - experiment_id: str = "0" # Default experiment id in MLflow is 0 - username: str | None = None - password: str | None = None - - @field_validator("tracking_uri") - @classmethod - def tracking_uri_validator(cls, v, info: ValidationInfo): - if isinstance(v, str) and v.startswith("databricks"): - raise ValueError( - "Please use Databricks tracing config below to record traces to Databricks-managed MLflow instances." - ) - return validate_url_with_path(v, "http://localhost:5000") - - @field_validator("experiment_id") - @classmethod - def experiment_id_validator(cls, v, info: ValidationInfo): - return validate_integer_id(v) - - -class DatabricksConfig(BaseTracingConfig): - """ - Model class for Databricks (Databricks-managed MLflow) tracing config. - """ - - experiment_id: str - host: str - client_id: str | None = None - client_secret: str | None = None - personal_access_token: str | None = None - - @field_validator("experiment_id") - @classmethod - def experiment_id_validator(cls, v, info: ValidationInfo): - return validate_integer_id(v) - - OPS_FILE_PATH = "ops_trace/" OPS_TRACE_FAILED_KEY = "FAILED_OPS_TRACE" diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py index fd235faf805..e7ba6e502bb 100644 --- a/api/core/ops/ops_trace_manager.py +++ b/api/core/ops/ops_trace_manager.py @@ -204,114 +204,117 @@ class TracingProviderConfigEntry(TypedDict): class OpsTraceProviderConfigMap(collections.UserDict[str, TracingProviderConfigEntry]): def __getitem__(self, provider: str) -> TracingProviderConfigEntry: - match provider: - case TracingProviderEnum.LANGFUSE: - from core.ops.entities.config_entity import LangfuseConfig - from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace + try: + match provider: + case TracingProviderEnum.LANGFUSE: + from dify_trace_langfuse.config import LangfuseConfig + from dify_trace_langfuse.langfuse_trace import LangFuseDataTrace - return { - "config_class": LangfuseConfig, - "secret_keys": ["public_key", "secret_key"], - "other_keys": ["host", "project_key"], - "trace_instance": LangFuseDataTrace, - } + return { + "config_class": LangfuseConfig, + "secret_keys": ["public_key", "secret_key"], + "other_keys": ["host", "project_key"], + "trace_instance": LangFuseDataTrace, + } - case TracingProviderEnum.LANGSMITH: - from core.ops.entities.config_entity import LangSmithConfig - from core.ops.langsmith_trace.langsmith_trace import LangSmithDataTrace + case TracingProviderEnum.LANGSMITH: + from dify_trace_langsmith.config import LangSmithConfig + from dify_trace_langsmith.langsmith_trace import LangSmithDataTrace - return { - "config_class": LangSmithConfig, - "secret_keys": ["api_key"], - "other_keys": ["project", "endpoint"], - "trace_instance": LangSmithDataTrace, - } + return { + "config_class": LangSmithConfig, + "secret_keys": ["api_key"], + "other_keys": ["project", "endpoint"], + "trace_instance": LangSmithDataTrace, + } - case TracingProviderEnum.OPIK: - from core.ops.entities.config_entity import OpikConfig - from core.ops.opik_trace.opik_trace import OpikDataTrace + case TracingProviderEnum.OPIK: + from dify_trace_opik.config import OpikConfig + from dify_trace_opik.opik_trace import OpikDataTrace - return { - "config_class": OpikConfig, - "secret_keys": ["api_key"], - "other_keys": ["project", "url", "workspace"], - "trace_instance": OpikDataTrace, - } + return { + "config_class": OpikConfig, + "secret_keys": ["api_key"], + "other_keys": ["project", "url", "workspace"], + "trace_instance": OpikDataTrace, + } - case TracingProviderEnum.WEAVE: - from core.ops.entities.config_entity import WeaveConfig - from core.ops.weave_trace.weave_trace import WeaveDataTrace + case TracingProviderEnum.WEAVE: + from dify_trace_weave.config import WeaveConfig + from dify_trace_weave.weave_trace import WeaveDataTrace - return { - "config_class": WeaveConfig, - "secret_keys": ["api_key"], - "other_keys": ["project", "entity", "endpoint", "host"], - "trace_instance": WeaveDataTrace, - } - case TracingProviderEnum.ARIZE: - from core.ops.arize_phoenix_trace.arize_phoenix_trace import ArizePhoenixDataTrace - from core.ops.entities.config_entity import ArizeConfig + return { + "config_class": WeaveConfig, + "secret_keys": ["api_key"], + "other_keys": ["project", "entity", "endpoint", "host"], + "trace_instance": WeaveDataTrace, + } + case TracingProviderEnum.ARIZE: + from dify_trace_arize_phoenix.arize_phoenix_trace import ArizePhoenixDataTrace + from dify_trace_arize_phoenix.config import ArizeConfig - return { - "config_class": ArizeConfig, - "secret_keys": ["api_key", "space_id"], - "other_keys": ["project", "endpoint"], - "trace_instance": ArizePhoenixDataTrace, - } - case TracingProviderEnum.PHOENIX: - from core.ops.arize_phoenix_trace.arize_phoenix_trace import ArizePhoenixDataTrace - from core.ops.entities.config_entity import PhoenixConfig + return { + "config_class": ArizeConfig, + "secret_keys": ["api_key", "space_id"], + "other_keys": ["project", "endpoint"], + "trace_instance": ArizePhoenixDataTrace, + } + case TracingProviderEnum.PHOENIX: + from dify_trace_arize_phoenix.arize_phoenix_trace import ArizePhoenixDataTrace + from dify_trace_arize_phoenix.config import PhoenixConfig - return { - "config_class": PhoenixConfig, - "secret_keys": ["api_key"], - "other_keys": ["project", "endpoint"], - "trace_instance": ArizePhoenixDataTrace, - } - case TracingProviderEnum.ALIYUN: - from core.ops.aliyun_trace.aliyun_trace import AliyunDataTrace - from core.ops.entities.config_entity import AliyunConfig + return { + "config_class": PhoenixConfig, + "secret_keys": ["api_key"], + "other_keys": ["project", "endpoint"], + "trace_instance": ArizePhoenixDataTrace, + } + case TracingProviderEnum.ALIYUN: + from dify_trace_aliyun.aliyun_trace import AliyunDataTrace + from dify_trace_aliyun.config import AliyunConfig - return { - "config_class": AliyunConfig, - "secret_keys": ["license_key"], - "other_keys": ["endpoint", "app_name"], - "trace_instance": AliyunDataTrace, - } - case TracingProviderEnum.MLFLOW: - from core.ops.entities.config_entity import MLflowConfig - from core.ops.mlflow_trace.mlflow_trace import MLflowDataTrace + return { + "config_class": AliyunConfig, + "secret_keys": ["license_key"], + "other_keys": ["endpoint", "app_name"], + "trace_instance": AliyunDataTrace, + } + case TracingProviderEnum.MLFLOW: + from dify_trace_mlflow.config import MLflowConfig + from dify_trace_mlflow.mlflow_trace import MLflowDataTrace - return { - "config_class": MLflowConfig, - "secret_keys": ["password"], - "other_keys": ["tracking_uri", "experiment_id", "username"], - "trace_instance": MLflowDataTrace, - } - case TracingProviderEnum.DATABRICKS: - from core.ops.entities.config_entity import DatabricksConfig - from core.ops.mlflow_trace.mlflow_trace import MLflowDataTrace + return { + "config_class": MLflowConfig, + "secret_keys": ["password"], + "other_keys": ["tracking_uri", "experiment_id", "username"], + "trace_instance": MLflowDataTrace, + } + case TracingProviderEnum.DATABRICKS: + from dify_trace_mlflow.config import DatabricksConfig + from dify_trace_mlflow.mlflow_trace import MLflowDataTrace - return { - "config_class": DatabricksConfig, - "secret_keys": ["personal_access_token", "client_secret"], - "other_keys": ["host", "client_id", "experiment_id"], - "trace_instance": MLflowDataTrace, - } + return { + "config_class": DatabricksConfig, + "secret_keys": ["personal_access_token", "client_secret"], + "other_keys": ["host", "client_id", "experiment_id"], + "trace_instance": MLflowDataTrace, + } - case TracingProviderEnum.TENCENT: - from core.ops.entities.config_entity import TencentConfig - from core.ops.tencent_trace.tencent_trace import TencentDataTrace + case TracingProviderEnum.TENCENT: + from dify_trace_tencent.config import TencentConfig + from dify_trace_tencent.tencent_trace import TencentDataTrace - return { - "config_class": TencentConfig, - "secret_keys": ["token"], - "other_keys": ["endpoint", "service_name"], - "trace_instance": TencentDataTrace, - } + return { + "config_class": TencentConfig, + "secret_keys": ["token"], + "other_keys": ["endpoint", "service_name"], + "trace_instance": TencentDataTrace, + } - case _: - raise KeyError(f"Unsupported tracing provider: {provider}") + case _: + raise KeyError(f"Unsupported tracing provider: {provider}") + except ImportError: + raise ImportError(f"Provider {provider} is not installed.") provider_config_map = OpsTraceProviderConfigMap() @@ -324,7 +327,7 @@ class OpsTraceManager: @classmethod def encrypt_tracing_config( - cls, tenant_id: str, tracing_provider: str, tracing_config: dict, current_trace_config=None + cls, tenant_id: str, tracing_provider: str, tracing_config: dict[str, Any], current_trace_config=None ): """ Encrypt tracing config. @@ -363,7 +366,7 @@ class OpsTraceManager: return encrypted_config.model_dump() @classmethod - def decrypt_tracing_config(cls, tenant_id: str, tracing_provider: str, tracing_config: dict): + def decrypt_tracing_config(cls, tenant_id: str, tracing_provider: str, tracing_config: dict[str, Any]): """ Decrypt tracing config :param tenant_id: tenant id @@ -408,7 +411,7 @@ class OpsTraceManager: return dict(decrypted_config) @classmethod - def obfuscated_decrypt_token(cls, tracing_provider: str, decrypt_tracing_config: dict): + def obfuscated_decrypt_token(cls, tracing_provider: str, decrypt_tracing_config: dict[str, Any]): """ Decrypt tracing config :param tracing_provider: tracing provider @@ -581,7 +584,7 @@ class OpsTraceManager: return app_trace_config @staticmethod - def check_trace_config_is_effective(tracing_config: dict, tracing_provider: str): + def check_trace_config_is_effective(tracing_config: dict[str, Any], tracing_provider: str): """ Check trace config is effective :param tracing_config: tracing config @@ -596,7 +599,7 @@ class OpsTraceManager: return trace_instance(config).api_check() @staticmethod - def get_trace_config_project_key(tracing_config: dict, tracing_provider: str): + def get_trace_config_project_key(tracing_config: dict[str, Any], tracing_provider: str): """ get trace config is project key :param tracing_config: tracing config @@ -611,7 +614,7 @@ class OpsTraceManager: return trace_instance(config).get_project_key() @staticmethod - def get_trace_config_project_url(tracing_config: dict, tracing_provider: str): + def get_trace_config_project_url(tracing_config: dict[str, Any], tracing_provider: str): """ get trace config is project key :param tracing_config: tracing config @@ -1322,8 +1325,8 @@ class TraceTask: error=error, ) - def node_execution_trace(self, **kwargs) -> WorkflowNodeTraceInfo | dict: - node_data: dict = kwargs.get("node_execution_data", {}) + def node_execution_trace(self, **kwargs) -> WorkflowNodeTraceInfo | dict[str, Any]: + node_data: dict[str, Any] = kwargs.get("node_execution_data", {}) if not node_data: return {} @@ -1431,7 +1434,7 @@ class TraceTask: return node_trace return DraftNodeExecutionTrace(**node_trace.model_dump()) - def _extract_streaming_metrics(self, message_data) -> dict: + def _extract_streaming_metrics(self, message_data) -> dict[str, Any]: if not message_data.message_metadata: return {} diff --git a/api/core/plugin/backwards_invocation/app.py b/api/core/plugin/backwards_invocation/app.py index be11d2223ca..c76cb865c31 100644 --- a/api/core/plugin/backwards_invocation/app.py +++ b/api/core/plugin/backwards_invocation/app.py @@ -1,6 +1,6 @@ import uuid from collections.abc import Generator, Mapping -from typing import Any, Union, cast +from typing import Any, cast from sqlalchemy import select from sqlalchemy.orm import Session @@ -72,17 +72,18 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation): conversation_id = conversation_id or "" - if app.mode in {AppMode.ADVANCED_CHAT, AppMode.AGENT_CHAT, AppMode.CHAT}: - if not query: - raise ValueError("missing query") + match app.mode: + case AppMode.ADVANCED_CHAT | AppMode.AGENT_CHAT | AppMode.CHAT: + if not query: + raise ValueError("missing query") - return cls.invoke_chat_app(app, user, conversation_id, query, stream, inputs, files) - elif app.mode == AppMode.WORKFLOW: - return cls.invoke_workflow_app(app, user, stream, inputs, files) - elif app.mode == AppMode.COMPLETION: - return cls.invoke_completion_app(app, user, stream, inputs, files) - - raise ValueError("unexpected app type") + return cls.invoke_chat_app(app, user, conversation_id, query, stream, inputs, files) + case AppMode.WORKFLOW: + return cls.invoke_workflow_app(app, user, stream, inputs, files) + case AppMode.COMPLETION: + return cls.invoke_completion_app(app, user, stream, inputs, files) + case _: + raise ValueError("unexpected app type") @classmethod def invoke_chat_app( @@ -98,60 +99,61 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation): """ invoke chat app """ - if app.mode == AppMode.ADVANCED_CHAT: - workflow = app.workflow - if not workflow: + match app.mode: + case AppMode.ADVANCED_CHAT: + workflow = app.workflow + if not workflow: + raise ValueError("unexpected app type") + + pause_config = PauseStateLayerConfig( + session_factory=db.engine, + state_owner_user_id=workflow.created_by, + ) + + return AdvancedChatAppGenerator().generate( + app_model=app, + workflow=workflow, + user=user, + args={ + "inputs": inputs, + "query": query, + "files": files, + "conversation_id": conversation_id, + }, + invoke_from=InvokeFrom.SERVICE_API, + workflow_run_id=str(uuid.uuid4()), + streaming=stream, + pause_state_config=pause_config, + ) + case AppMode.AGENT_CHAT: + return AgentChatAppGenerator().generate( + app_model=app, + user=user, + args={ + "inputs": inputs, + "query": query, + "files": files, + "conversation_id": conversation_id, + }, + invoke_from=InvokeFrom.SERVICE_API, + streaming=stream, + ) + case AppMode.CHAT: + return ChatAppGenerator().generate( + app_model=app, + user=user, + args={ + "inputs": inputs, + "query": query, + "files": files, + "conversation_id": conversation_id, + }, + invoke_from=InvokeFrom.SERVICE_API, + streaming=stream, + ) + case _: raise ValueError("unexpected app type") - pause_config = PauseStateLayerConfig( - session_factory=db.engine, - state_owner_user_id=workflow.created_by, - ) - - return AdvancedChatAppGenerator().generate( - app_model=app, - workflow=workflow, - user=user, - args={ - "inputs": inputs, - "query": query, - "files": files, - "conversation_id": conversation_id, - }, - invoke_from=InvokeFrom.SERVICE_API, - workflow_run_id=str(uuid.uuid4()), - streaming=stream, - pause_state_config=pause_config, - ) - elif app.mode == AppMode.AGENT_CHAT: - return AgentChatAppGenerator().generate( - app_model=app, - user=user, - args={ - "inputs": inputs, - "query": query, - "files": files, - "conversation_id": conversation_id, - }, - invoke_from=InvokeFrom.SERVICE_API, - streaming=stream, - ) - elif app.mode == AppMode.CHAT: - return ChatAppGenerator().generate( - app_model=app, - user=user, - args={ - "inputs": inputs, - "query": query, - "files": files, - "conversation_id": conversation_id, - }, - invoke_from=InvokeFrom.SERVICE_API, - streaming=stream, - ) - else: - raise ValueError("unexpected app type") - @classmethod def invoke_workflow_app( cls, @@ -205,7 +207,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation): ) @classmethod - def _get_user(cls, user_id: str) -> Union[EndUser, Account]: + def _get_user(cls, user_id: str) -> EndUser | Account: """ get the user by user id """ diff --git a/api/core/plugin/backwards_invocation/model.py b/api/core/plugin/backwards_invocation/model.py index c715b9171c6..c92438960a0 100644 --- a/api/core/plugin/backwards_invocation/model.py +++ b/api/core/plugin/backwards_invocation/model.py @@ -1,20 +1,7 @@ import tempfile from binascii import hexlify, unhexlify from collections.abc import Generator - -from graphon.model_runtime.entities.llm_entities import ( - LLMResult, - LLMResultChunk, - LLMResultChunkDelta, - LLMResultChunkWithStructuredOutput, - LLMResultWithStructuredOutput, -) -from graphon.model_runtime.entities.message_entities import ( - PromptMessage, - SystemPromptMessage, - UserPromptMessage, -) -from graphon.model_runtime.entities.model_entities import ModelType +from typing import Any from core.app.llm import deduct_llm_quota from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output @@ -32,6 +19,19 @@ from core.plugin.entities.request import ( ) from core.tools.entities.tool_entities import ToolProviderType from core.tools.utils.model_invocation_utils import ModelInvocationUtils +from graphon.model_runtime.entities.llm_entities import ( + LLMResult, + LLMResultChunk, + LLMResultChunkDelta, + LLMResultChunkWithStructuredOutput, + LLMResultWithStructuredOutput, +) +from graphon.model_runtime.entities.message_entities import ( + PromptMessage, + SystemPromptMessage, + UserPromptMessage, +) +from graphon.model_runtime.entities.model_entities import ModelType from models.account import Tenant @@ -226,7 +226,7 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): # invoke model response = model_instance.invoke_tts(content_text=payload.content_text, voice=payload.voice) - def handle() -> Generator[dict, None, None]: + def handle() -> Generator[dict[str, Any], None, None]: for chunk in response: yield {"result": hexlify(chunk).decode("utf-8")} diff --git a/api/core/plugin/backwards_invocation/node.py b/api/core/plugin/backwards_invocation/node.py index 94789974942..9550e499922 100644 --- a/api/core/plugin/backwards_invocation/node.py +++ b/api/core/plugin/backwards_invocation/node.py @@ -1,3 +1,4 @@ +from core.plugin.backwards_invocation.base import BaseBackwardsInvocation from graphon.enums import BuiltinNodeTypes from graphon.nodes.llm.entities import ModelConfig as LLMModelConfig from graphon.nodes.parameter_extractor.entities import ( @@ -8,8 +9,6 @@ from graphon.nodes.question_classifier.entities import ( ClassConfig, QuestionClassifierNodeData, ) - -from core.plugin.backwards_invocation.base import BaseBackwardsInvocation from services.workflow_service import WorkflowService diff --git a/api/core/plugin/entities/__init__.py b/api/core/plugin/entities/__init__.py new file mode 100644 index 00000000000..9456ff01816 --- /dev/null +++ b/api/core/plugin/entities/__init__.py @@ -0,0 +1,5 @@ +from core.plugin.entities.oauth import OAuthSchema + +__all__ = [ + "OAuthSchema", +] diff --git a/api/core/plugin/entities/endpoint.py b/api/core/plugin/entities/endpoint.py index e5bca140f82..64199636687 100644 --- a/api/core/plugin/entities/endpoint.py +++ b/api/core/plugin/entities/endpoint.py @@ -1,4 +1,5 @@ from datetime import datetime +from typing import Any from pydantic import BaseModel, Field, model_validator @@ -31,7 +32,7 @@ class EndpointEntity(BasePluginEntity): entity of an endpoint """ - settings: dict + settings: dict[str, Any] tenant_id: str plugin_id: str expired_at: datetime diff --git a/api/core/plugin/entities/marketplace.py b/api/core/plugin/entities/marketplace.py index 2177e8af908..03398873e30 100644 --- a/api/core/plugin/entities/marketplace.py +++ b/api/core/plugin/entities/marketplace.py @@ -1,10 +1,12 @@ -from graphon.model_runtime.entities.provider_entities import ProviderEntity +from typing import Any + from pydantic import BaseModel, Field, computed_field, model_validator from core.plugin.entities.endpoint import EndpointProviderDeclaration from core.plugin.entities.plugin import PluginResourceRequirements from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolProviderEntity +from graphon.model_runtime.entities.provider_entities import ProviderEntity class MarketplacePluginDeclaration(BaseModel): @@ -40,7 +42,7 @@ class MarketplacePluginDeclaration(BaseModel): @model_validator(mode="before") @classmethod - def transform_declaration(cls, data: dict): + def transform_declaration(cls, data: dict[str, Any]) -> dict[str, Any]: if "endpoint" in data and not data["endpoint"]: del data["endpoint"] if "model" in data and not data["model"]: diff --git a/api/core/plugin/entities/oauth.py b/api/core/plugin/entities/oauth.py index d284b827287..483ebbc5351 100644 --- a/api/core/plugin/entities/oauth.py +++ b/api/core/plugin/entities/oauth.py @@ -1,5 +1,3 @@ -from collections.abc import Sequence - from pydantic import BaseModel, Field from core.entities.provider_entities import ProviderConfig @@ -10,12 +8,12 @@ class OAuthSchema(BaseModel): OAuth schema """ - client_schema: Sequence[ProviderConfig] = Field( + client_schema: list[ProviderConfig] = Field( default_factory=list, description="client schema like client_id, client_secret, etc.", ) - credentials_schema: Sequence[ProviderConfig] = Field( + credentials_schema: list[ProviderConfig] = Field( default_factory=list, description="credentials schema like access_token, refresh_token, etc.", ) diff --git a/api/core/plugin/entities/plugin.py b/api/core/plugin/entities/plugin.py index b095b4998d7..89e0e8881cb 100644 --- a/api/core/plugin/entities/plugin.py +++ b/api/core/plugin/entities/plugin.py @@ -3,7 +3,6 @@ from collections.abc import Mapping from enum import StrEnum, auto from typing import Any -from graphon.model_runtime.entities.provider_entities import ProviderEntity from packaging.version import InvalidVersion, Version from pydantic import BaseModel, Field, field_validator, model_validator @@ -14,6 +13,7 @@ from core.plugin.entities.endpoint import EndpointProviderDeclaration from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolProviderEntity from core.trigger.entities.entities import TriggerProviderEntity +from graphon.model_runtime.entities.provider_entities import ProviderEntity class PluginInstallationSource(StrEnum): @@ -123,7 +123,7 @@ class PluginDeclaration(BaseModel): @model_validator(mode="before") @classmethod - def validate_category(cls, values: dict): + def validate_category(cls, values: dict[str, Any]) -> dict[str, Any]: # auto detect category if values.get("tool"): values["category"] = PluginCategory.Tool diff --git a/api/core/plugin/entities/plugin_daemon.py b/api/core/plugin/entities/plugin_daemon.py index b57180690e8..257638ad77c 100644 --- a/api/core/plugin/entities/plugin_daemon.py +++ b/api/core/plugin/entities/plugin_daemon.py @@ -6,8 +6,6 @@ from datetime import datetime from enum import StrEnum from typing import Any -from graphon.model_runtime.entities.model_entities import AIModelEntity -from graphon.model_runtime.entities.provider_entities import ProviderEntity from pydantic import BaseModel, ConfigDict, Field from core.agent.plugin_entities import AgentProviderEntityWithPlugin @@ -18,6 +16,8 @@ from core.plugin.entities.plugin import PluginDeclaration, PluginEntity from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolProviderEntityWithPlugin from core.trigger.entities.entities import TriggerProviderEntity +from graphon.model_runtime.entities.model_entities import AIModelEntity +from graphon.model_runtime.entities.provider_entities import ProviderEntity class PluginDaemonBasicResponse[T: BaseModel | dict | list | bool | str](BaseModel): @@ -73,7 +73,7 @@ class PluginBasicBooleanResponse(BaseModel): """ result: bool - credentials: dict | None = None + credentials: dict[str, Any] | None = None class PluginModelSchemaEntity(BaseModel): diff --git a/api/core/plugin/entities/request.py b/api/core/plugin/entities/request.py index 059f3fa9be1..14748832049 100644 --- a/api/core/plugin/entities/request.py +++ b/api/core/plugin/entities/request.py @@ -4,6 +4,10 @@ from collections.abc import Mapping from typing import Any, Literal from flask import Response +from pydantic import BaseModel, ConfigDict, Field, field_validator + +from core.entities.provider_entities import BasicProviderConfig +from core.plugin.utils.http_parser import deserialize_response from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, PromptMessage, @@ -21,10 +25,6 @@ from graphon.nodes.parameter_extractor.entities import ( from graphon.nodes.question_classifier.entities import ( ClassConfig, ) -from pydantic import BaseModel, ConfigDict, Field, field_validator - -from core.entities.provider_entities import BasicProviderConfig -from core.plugin.utils.http_parser import deserialize_response class InvokeCredentials(BaseModel): @@ -49,7 +49,7 @@ class RequestInvokeTool(BaseModel): tool_type: Literal["builtin", "workflow", "api", "mcp"] provider: str tool: str - tool_parameters: dict + tool_parameters: dict[str, Any] credential_id: str | None = None @@ -209,7 +209,7 @@ class RequestInvokeEncrypt(BaseModel): opt: Literal["encrypt", "decrypt", "clear"] namespace: Literal["endpoint"] identity: str - data: dict = Field(default_factory=dict) + data: dict[str, Any] = Field(default_factory=dict) config: list[BasicProviderConfig] = Field(default_factory=list) diff --git a/api/core/plugin/impl/base.py b/api/core/plugin/impl/base.py index 7f36560b49e..9ee8469892a 100644 --- a/api/core/plugin/impl/base.py +++ b/api/core/plugin/impl/base.py @@ -5,14 +5,6 @@ from collections.abc import Callable, Generator from typing import Any, cast import httpx -from graphon.model_runtime.errors.invoke import ( - InvokeAuthorizationError, - InvokeBadRequestError, - InvokeConnectionError, - InvokeRateLimitError, - InvokeServerUnavailableError, -) -from graphon.model_runtime.errors.validate import CredentialsValidateFailedError from pydantic import BaseModel from yarl import URL @@ -37,6 +29,14 @@ from core.trigger.errors import ( TriggerPluginInvokeError, TriggerProviderCredentialValidationError, ) +from graphon.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from graphon.model_runtime.errors.validate import CredentialsValidateFailedError plugin_daemon_inner_api_baseurl = URL(str(dify_config.PLUGIN_DAEMON_URL)) _plugin_daemon_timeout_config = cast( diff --git a/api/core/plugin/impl/datasource.py b/api/core/plugin/impl/datasource.py index ce1ef714944..56c08addbab 100644 --- a/api/core/plugin/impl/datasource.py +++ b/api/core/plugin/impl/datasource.py @@ -26,7 +26,7 @@ class PluginDatasourceManager(BasePluginClient): Fetch datasource providers for the given tenant. """ - def transformer(json_response: dict[str, Any]) -> dict: + def transformer(json_response: dict[str, Any]) -> dict[str, Any]: if json_response.get("data"): for provider in json_response.get("data", []): declaration = provider.get("declaration", {}) or {} @@ -68,7 +68,7 @@ class PluginDatasourceManager(BasePluginClient): Fetch datasource providers for the given tenant. """ - def transformer(json_response: dict[str, Any]) -> dict: + def transformer(json_response: dict[str, Any]) -> dict[str, Any]: if json_response.get("data"): for provider in json_response.get("data", []): declaration = provider.get("declaration", {}) or {} @@ -110,7 +110,7 @@ class PluginDatasourceManager(BasePluginClient): tool_provider_id = DatasourceProviderID(provider_id) - def transformer(json_response: dict[str, Any]) -> dict: + def transformer(json_response: dict[str, Any]) -> dict[str, Any]: data = json_response.get("data") if data: for datasource in data.get("declaration", {}).get("datasources", []): diff --git a/api/core/plugin/impl/endpoint.py b/api/core/plugin/impl/endpoint.py index 2db5185a2c7..b335b427633 100644 --- a/api/core/plugin/impl/endpoint.py +++ b/api/core/plugin/impl/endpoint.py @@ -1,3 +1,5 @@ +from typing import Any + from core.plugin.entities.endpoint import EndpointEntityWithInstance from core.plugin.impl.base import BasePluginClient from core.plugin.impl.exc import PluginDaemonInternalServerError @@ -5,7 +7,12 @@ from core.plugin.impl.exc import PluginDaemonInternalServerError class PluginEndpointClient(BasePluginClient): def create_endpoint( - self, tenant_id: str, user_id: str, plugin_unique_identifier: str, name: str, settings: dict + self, + tenant_id: str, + user_id: str, + plugin_unique_identifier: str, + name: str, + settings: dict[str, Any], ) -> bool: """ Create an endpoint for the given plugin. @@ -49,7 +56,9 @@ class PluginEndpointClient(BasePluginClient): params={"plugin_id": plugin_id, "page": page, "page_size": page_size}, ) - def update_endpoint(self, tenant_id: str, user_id: str, endpoint_id: str, name: str, settings: dict): + def update_endpoint( + self, tenant_id: str, user_id: str, endpoint_id: str, name: str, settings: dict[str, Any] + ) -> bool: """ Update the settings of the given endpoint. """ diff --git a/api/core/plugin/impl/model.py b/api/core/plugin/impl/model.py index 1e38c24717f..47608bdfa6e 100644 --- a/api/core/plugin/impl/model.py +++ b/api/core/plugin/impl/model.py @@ -2,13 +2,6 @@ import binascii from collections.abc import Generator, Sequence from typing import IO, Any -from graphon.model_runtime.entities.llm_entities import LLMResultChunk -from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool -from graphon.model_runtime.entities.model_entities import AIModelEntity -from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult -from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult -from graphon.model_runtime.utils.encoders import jsonable_encoder - from core.plugin.entities.plugin_daemon import ( PluginBasicBooleanResponse, PluginDaemonInnerError, @@ -20,6 +13,12 @@ from core.plugin.entities.plugin_daemon import ( PluginVoicesResponse, ) from core.plugin.impl.base import BasePluginClient +from graphon.model_runtime.entities.llm_entities import LLMResultChunk +from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool +from graphon.model_runtime.entities.model_entities import AIModelEntity +from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult +from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult +from graphon.model_runtime.utils.encoders import jsonable_encoder class PluginModelClient(BasePluginClient): @@ -50,7 +49,7 @@ class PluginModelClient(BasePluginClient): provider: str, model_type: str, model: str, - credentials: dict, + credentials: dict[str, Any], ) -> AIModelEntity | None: """ Get model schema @@ -80,7 +79,7 @@ class PluginModelClient(BasePluginClient): return None def validate_provider_credentials( - self, tenant_id: str, user_id: str | None, plugin_id: str, provider: str, credentials: dict + self, tenant_id: str, user_id: str | None, plugin_id: str, provider: str, credentials: dict[str, Any] ) -> bool: """ validate the credentials of the provider @@ -118,7 +117,7 @@ class PluginModelClient(BasePluginClient): provider: str, model_type: str, model: str, - credentials: dict, + credentials: dict[str, Any], ) -> bool: """ validate the credentials of the provider @@ -157,9 +156,9 @@ class PluginModelClient(BasePluginClient): plugin_id: str, provider: str, model: str, - credentials: dict, + credentials: dict[str, Any], prompt_messages: list[PromptMessage], - model_parameters: dict | None = None, + model_parameters: dict[str, Any] | None = None, tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, stream: bool = True, @@ -206,7 +205,7 @@ class PluginModelClient(BasePluginClient): provider: str, model_type: str, model: str, - credentials: dict, + credentials: dict[str, Any], prompt_messages: list[PromptMessage], tools: list[PromptMessageTool] | None = None, ) -> int: @@ -248,7 +247,7 @@ class PluginModelClient(BasePluginClient): plugin_id: str, provider: str, model: str, - credentials: dict, + credentials: dict[str, Any], texts: list[str], input_type: str, ) -> EmbeddingResult: @@ -290,7 +289,7 @@ class PluginModelClient(BasePluginClient): plugin_id: str, provider: str, model: str, - credentials: dict, + credentials: dict[str, Any], documents: list[dict], input_type: str, ) -> EmbeddingResult: @@ -332,7 +331,7 @@ class PluginModelClient(BasePluginClient): plugin_id: str, provider: str, model: str, - credentials: dict, + credentials: dict[str, Any], texts: list[str], ) -> list[int]: """ @@ -372,7 +371,7 @@ class PluginModelClient(BasePluginClient): plugin_id: str, provider: str, model: str, - credentials: dict, + credentials: dict[str, Any], query: str, docs: list[str], score_threshold: float | None = None, @@ -418,7 +417,7 @@ class PluginModelClient(BasePluginClient): plugin_id: str, provider: str, model: str, - credentials: dict, + credentials: dict[str, Any], query: MultimodalRerankInput, docs: list[MultimodalRerankInput], score_threshold: float | None = None, @@ -463,7 +462,7 @@ class PluginModelClient(BasePluginClient): plugin_id: str, provider: str, model: str, - credentials: dict, + credentials: dict[str, Any], content_text: str, voice: str, ) -> Generator[bytes, None, None]: @@ -508,7 +507,7 @@ class PluginModelClient(BasePluginClient): plugin_id: str, provider: str, model: str, - credentials: dict, + credentials: dict[str, Any], language: str | None = None, ): """ @@ -552,7 +551,7 @@ class PluginModelClient(BasePluginClient): plugin_id: str, provider: str, model: str, - credentials: dict, + credentials: dict[str, Any], file: IO[bytes], ) -> str: """ @@ -592,7 +591,7 @@ class PluginModelClient(BasePluginClient): plugin_id: str, provider: str, model: str, - credentials: dict, + credentials: dict[str, Any], text: str, ) -> bool: """ diff --git a/api/core/plugin/impl/model_runtime.py b/api/core/plugin/impl/model_runtime.py index 22c846b6de0..4e66d58b5ed 100644 --- a/api/core/plugin/impl/model_runtime.py +++ b/api/core/plugin/impl/model_runtime.py @@ -6,13 +6,6 @@ from collections.abc import Generator, Iterable, Sequence from threading import Lock from typing import IO, Any, Union -from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk -from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool -from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelType -from graphon.model_runtime.entities.provider_entities import ProviderEntity -from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult -from graphon.model_runtime.entities.text_embedding_entities import EmbeddingInputType, EmbeddingResult -from graphon.model_runtime.runtime import ModelRuntime from pydantic import ValidationError from redis import RedisError @@ -21,6 +14,13 @@ from core.plugin.entities.plugin_daemon import PluginModelProviderEntity from core.plugin.impl.asset import PluginAssetManager from core.plugin.impl.model import PluginModelClient from extensions.ext_redis import redis_client +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk +from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool +from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelType +from graphon.model_runtime.entities.provider_entities import ProviderEntity +from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult +from graphon.model_runtime.entities.text_embedding_entities import EmbeddingInputType, EmbeddingResult +from graphon.model_runtime.runtime import ModelRuntime from models.provider_ids import ModelProviderID logger = logging.getLogger(__name__) @@ -66,15 +66,15 @@ class PluginModelRuntime(ModelRuntime): if not provider_schema.icon_small: raise ValueError(f"Provider {provider} does not have small icon.") file_name = ( - provider_schema.icon_small.zh_Hans if lang.lower() == "zh_hans" else provider_schema.icon_small.en_US + provider_schema.icon_small.zh_hans if lang.lower() == "zh_hans" else provider_schema.icon_small.en_us ) elif icon_type.lower() == "icon_small_dark": if not provider_schema.icon_small_dark: raise ValueError(f"Provider {provider} does not have small dark icon.") file_name = ( - provider_schema.icon_small_dark.zh_Hans + provider_schema.icon_small_dark.zh_hans if lang.lower() == "zh_hans" - else provider_schema.icon_small_dark.en_US + else provider_schema.icon_small_dark.en_us ) else: raise ValueError(f"Unsupported icon type: {icon_type}.") diff --git a/api/core/plugin/impl/model_runtime_factory.py b/api/core/plugin/impl/model_runtime_factory.py index 4b29a6fc56b..35abd2ae8c1 100644 --- a/api/core/plugin/impl/model_runtime_factory.py +++ b/api/core/plugin/impl/model_runtime_factory.py @@ -2,9 +2,8 @@ from __future__ import annotations from typing import TYPE_CHECKING -from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory - from core.plugin.impl.model import PluginModelClient +from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory if TYPE_CHECKING: from core.model_manager import ModelManager diff --git a/api/core/plugin/impl/plugin.py b/api/core/plugin/impl/plugin.py index ec4858ae2e1..8a7175bb51f 100644 --- a/api/core/plugin/impl/plugin.py +++ b/api/core/plugin/impl/plugin.py @@ -1,4 +1,5 @@ from collections.abc import Sequence +from typing import Any from requests import HTTPError @@ -209,7 +210,10 @@ class PluginInstaller(BasePluginClient): "GET", f"plugin/{tenant_id}/management/decode/from_identifier", PluginDecodeResponse, - params={"plugin_unique_identifier": plugin_unique_identifier}, + params={ + "plugin_unique_identifier": plugin_unique_identifier, + "PluginUniqueIdentifier": plugin_unique_identifier, # compat with daemon <= 0.5.4 + }, ) def fetch_plugin_installation_by_ids( @@ -260,7 +264,7 @@ class PluginInstaller(BasePluginClient): original_plugin_unique_identifier: str, new_plugin_unique_identifier: str, source: PluginInstallationSource, - meta: dict, + meta: dict[str, Any], ) -> PluginInstallTaskStartResponse: """ Upgrade a plugin. diff --git a/api/core/plugin/utils/converter.py b/api/core/plugin/utils/converter.py index 90350f84000..12d8e282b2a 100644 --- a/api/core/plugin/utils/converter.py +++ b/api/core/plugin/utils/converter.py @@ -1,8 +1,7 @@ from typing import Any -from graphon.file import File - from core.tools.entities.tool_entities import ToolSelector +from graphon.file import File def convert_parameters_to_plugin_format(parameters: dict[str, Any]) -> dict[str, Any]: diff --git a/api/core/prompt/advanced_prompt_transform.py b/api/core/prompt/advanced_prompt_transform.py index 19b5e9223a8..24e05ef8659 100644 --- a/api/core/prompt/advanced_prompt_transform.py +++ b/api/core/prompt/advanced_prompt_transform.py @@ -1,6 +1,13 @@ from collections.abc import Mapping, Sequence from typing import cast +from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity +from core.helper.code_executor.jinja2.jinja2_formatter import Jinja2Formatter +from core.memory.token_buffer_memory import TokenBufferMemory +from core.model_manager import ModelInstance +from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig +from core.prompt.prompt_transform import PromptTransform +from core.prompt.utils.prompt_template_parser import PromptTemplateParser from graphon.file import File, file_manager from graphon.model_runtime.entities import ( AssistantPromptMessage, @@ -13,14 +20,6 @@ from graphon.model_runtime.entities import ( from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes from graphon.runtime import VariablePool -from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity -from core.helper.code_executor.jinja2.jinja2_formatter import Jinja2Formatter -from core.memory.token_buffer_memory import TokenBufferMemory -from core.model_manager import ModelInstance -from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig -from core.prompt.prompt_transform import PromptTransform -from core.prompt.utils.prompt_template_parser import PromptTemplateParser - class AdvancedPromptTransform(PromptTransform): """ diff --git a/api/core/prompt/agent_history_prompt_transform.py b/api/core/prompt/agent_history_prompt_transform.py index 9be70199b7d..7c6280fe937 100644 --- a/api/core/prompt/agent_history_prompt_transform.py +++ b/api/core/prompt/agent_history_prompt_transform.py @@ -1,17 +1,16 @@ from typing import cast -from graphon.model_runtime.entities.message_entities import ( - PromptMessage, - SystemPromptMessage, - UserPromptMessage, -) -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel - from core.app.entities.app_invoke_entities import ( ModelConfigWithCredentialsEntity, ) from core.memory.token_buffer_memory import TokenBufferMemory from core.prompt.prompt_transform import PromptTransform +from graphon.model_runtime.entities.message_entities import ( + PromptMessage, + SystemPromptMessage, + UserPromptMessage, +) +from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel class AgentHistoryPromptTransform(PromptTransform): diff --git a/api/core/prompt/prompt_transform.py b/api/core/prompt/prompt_transform.py index 4539ae9f11b..6ff2f44cdc8 100644 --- a/api/core/prompt/prompt_transform.py +++ b/api/core/prompt/prompt_transform.py @@ -1,12 +1,11 @@ from typing import Any -from graphon.model_runtime.entities.message_entities import PromptMessage -from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey - from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.prompt.entities.advanced_prompt_entities import MemoryConfig +from graphon.model_runtime.entities.message_entities import PromptMessage +from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey class PromptTransform: diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py index c706353ffeb..1665bdeb528 100644 --- a/api/core/prompt/simple_prompt_transform.py +++ b/api/core/prompt/simple_prompt_transform.py @@ -2,8 +2,14 @@ import json import os from collections.abc import Mapping, Sequence from enum import StrEnum, auto -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any, TypedDict, cast +from core.app.app_config.entities import PromptTemplateEntity +from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity +from core.memory.token_buffer_memory import TokenBufferMemory +from core.prompt.entities.advanced_prompt_entities import MemoryConfig +from core.prompt.prompt_transform import PromptTransform +from core.prompt.utils.prompt_template_parser import PromptTemplateParser from graphon.file import file_manager from graphon.model_runtime.entities.message_entities import ( ImagePromptMessageContent, @@ -13,13 +19,6 @@ from graphon.model_runtime.entities.message_entities import ( TextPromptMessageContent, UserPromptMessage, ) - -from core.app.app_config.entities import PromptTemplateEntity -from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity -from core.memory.token_buffer_memory import TokenBufferMemory -from core.prompt.entities.advanced_prompt_entities import MemoryConfig -from core.prompt.prompt_transform import PromptTransform -from core.prompt.utils.prompt_template_parser import PromptTemplateParser from models.model import AppMode if TYPE_CHECKING: @@ -34,6 +33,13 @@ class ModelMode(StrEnum): prompt_file_contents: dict[str, Any] = {} +class PromptTemplateConfigDict(TypedDict): + prompt_template: PromptTemplateParser + custom_variable_keys: list[str] + special_variable_keys: list[str] + prompt_rules: dict[str, Any] + + class SimplePromptTransform(PromptTransform): """ Simple Prompt Transform for Chatbot App Basic Mode. @@ -89,11 +95,11 @@ class SimplePromptTransform(PromptTransform): app_mode: AppMode, model_config: ModelConfigWithCredentialsEntity, pre_prompt: str, - inputs: dict, + inputs: dict[str, Any], query: str | None = None, context: str | None = None, histories: str | None = None, - ) -> tuple[str, dict]: + ) -> tuple[str, dict[str, Any]]: # get prompt template prompt_template_config = self.get_prompt_template( app_mode=app_mode, @@ -105,18 +111,13 @@ class SimplePromptTransform(PromptTransform): with_memory_prompt=histories is not None, ) - custom_variable_keys_obj = prompt_template_config["custom_variable_keys"] - special_variable_keys_obj = prompt_template_config["special_variable_keys"] + custom_variable_keys = prompt_template_config["custom_variable_keys"] + if not isinstance(custom_variable_keys, list): + raise TypeError(f"Expected list for custom_variable_keys, got {type(custom_variable_keys)}") - # Type check for custom_variable_keys - if not isinstance(custom_variable_keys_obj, list): - raise TypeError(f"Expected list for custom_variable_keys, got {type(custom_variable_keys_obj)}") - custom_variable_keys = cast(list[str], custom_variable_keys_obj) - - # Type check for special_variable_keys - if not isinstance(special_variable_keys_obj, list): - raise TypeError(f"Expected list for special_variable_keys, got {type(special_variable_keys_obj)}") - special_variable_keys = cast(list[str], special_variable_keys_obj) + special_variable_keys = prompt_template_config["special_variable_keys"] + if not isinstance(special_variable_keys, list): + raise TypeError(f"Expected list for special_variable_keys, got {type(special_variable_keys)}") variables = {k: inputs[k] for k in custom_variable_keys if k in inputs} @@ -150,7 +151,7 @@ class SimplePromptTransform(PromptTransform): has_context: bool, query_in_prompt: bool, with_memory_prompt: bool = False, - ) -> dict[str, object]: + ) -> PromptTemplateConfigDict: prompt_rules = self._get_prompt_rule(app_mode=app_mode, provider=provider, model=model) custom_variable_keys: list[str] = [] @@ -173,18 +174,19 @@ class SimplePromptTransform(PromptTransform): prompt += prompt_rules.get("query_prompt", "{{#query#}}") special_variable_keys.append("#query#") - return { + result: PromptTemplateConfigDict = { "prompt_template": PromptTemplateParser(template=prompt), "custom_variable_keys": custom_variable_keys, "special_variable_keys": special_variable_keys, "prompt_rules": prompt_rules, } + return result def _get_chat_model_prompt_messages( self, app_mode: AppMode, pre_prompt: str, - inputs: dict, + inputs: dict[str, Any], query: str, context: str | None, files: Sequence["File"], @@ -231,7 +233,7 @@ class SimplePromptTransform(PromptTransform): self, app_mode: AppMode, pre_prompt: str, - inputs: dict, + inputs: dict[str, Any], query: str, context: str | None, files: Sequence["File"], @@ -310,7 +312,7 @@ class SimplePromptTransform(PromptTransform): return prompt_message - def _get_prompt_rule(self, app_mode: AppMode, provider: str, model: str): + def _get_prompt_rule(self, app_mode: AppMode, provider: str, model: str) -> dict[str, Any]: """ Get simple prompt rule. :param app_mode: app mode @@ -322,7 +324,7 @@ class SimplePromptTransform(PromptTransform): # Check if the prompt file is already loaded if prompt_file_name in prompt_file_contents: - return cast(dict, prompt_file_contents[prompt_file_name]) + return cast(dict[str, Any], prompt_file_contents[prompt_file_name]) # Get the absolute path of the subdirectory prompt_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "prompt_templates") @@ -335,7 +337,7 @@ class SimplePromptTransform(PromptTransform): # Store the content of the prompt file prompt_file_contents[prompt_file_name] = content - return cast(dict, content) + return cast(dict[str, Any], content) def _prompt_file_name(self, app_mode: AppMode, provider: str, model: str) -> str: # baichuan diff --git a/api/core/prompt/utils/prompt_message_util.py b/api/core/prompt/utils/prompt_message_util.py index dbda7499255..ba76eb0c4e0 100644 --- a/api/core/prompt/utils/prompt_message_util.py +++ b/api/core/prompt/utils/prompt_message_util.py @@ -1,6 +1,7 @@ from collections.abc import Sequence from typing import Any, cast +from core.prompt.simple_prompt_transform import ModelMode from graphon.model_runtime.entities import ( AssistantPromptMessage, AudioPromptMessageContent, @@ -11,8 +12,6 @@ from graphon.model_runtime.entities import ( TextPromptMessageContent, ) -from core.prompt.simple_prompt_transform import ModelMode - class PromptMessageUtil: @staticmethod diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index 5d536e0e324..8969825be4f 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -1,20 +1,12 @@ from __future__ import annotations import contextlib -import json from collections import defaultdict from collections.abc import Sequence from json import JSONDecodeError -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.entities.provider_entities import ( - ConfigurateMethod, - CredentialFormSchema, - FormType, - ProviderEntity, -) -from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory +from pydantic import TypeAdapter from sqlalchemy import select from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import Session @@ -41,6 +33,14 @@ from core.helper.position_helper import is_filtered from extensions import ext_hosting_provider from extensions.ext_database import db from extensions.ext_redis import redis_client +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.entities.provider_entities import ( + ConfigurateMethod, + CredentialFormSchema, + FormType, + ProviderEntity, +) +from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from models.provider import ( LoadBalancingModelConfig, Provider, @@ -58,6 +58,8 @@ from services.feature_service import FeatureService if TYPE_CHECKING: from graphon.model_runtime.runtime import ModelRuntime +_credentials_adapter: TypeAdapter[dict[str, Any]] = TypeAdapter(dict[str, Any]) + class ProviderManager: """ @@ -68,12 +70,32 @@ class ProviderManager: Request-bound managers may carry caller identity in that runtime, and the resulting ``ProviderConfiguration`` objects must reuse it for downstream model-type and schema lookups. + + Configuration assembly is cached per manager instance so call chains that + share one request-scoped manager can reuse the same provider graph instead + of rebuilding it for every lookup. Call ``clear_configurations_cache()`` + when a long-lived manager needs to observe writes performed within the same + instance scope. """ + decoding_rsa_key: Any | None + decoding_cipher_rsa: Any | None + _model_runtime: ModelRuntime + _configurations_cache: dict[str, ProviderConfigurations] + def __init__(self, model_runtime: ModelRuntime): self.decoding_rsa_key = None self.decoding_cipher_rsa = None self._model_runtime = model_runtime + self._configurations_cache = {} + + def clear_configurations_cache(self, tenant_id: str | None = None) -> None: + """Drop assembled provider configurations cached on this manager instance.""" + if tenant_id is None: + self._configurations_cache.clear() + return + + self._configurations_cache.pop(tenant_id, None) def get_configurations(self, tenant_id: str) -> ProviderConfigurations: """ @@ -112,6 +134,10 @@ class ProviderManager: :param tenant_id: :return: """ + cached_configurations = self._configurations_cache.get(tenant_id) + if cached_configurations is not None: + return cached_configurations + # Get all provider records of the workspace provider_name_to_provider_records_dict = self._get_all_providers(tenant_id) @@ -271,6 +297,8 @@ class ProviderManager: provider_configurations[str(provider_id_entity)] = provider_configuration + self._configurations_cache[tenant_id] = provider_configurations + # Return the encapsulated object return provider_configurations @@ -854,7 +882,7 @@ class ProviderManager: secret_variables: list[str], cache_type: ProviderCredentialsCacheType, is_provider: bool = False, - ) -> dict: + ) -> dict[str, Any]: """Get and decrypt credentials with caching.""" credentials_cache = ProviderCredentialsCache( tenant_id=tenant_id, @@ -875,8 +903,8 @@ class ProviderManager: return {"openai_api_key": encrypted_config} try: - credentials = cast(dict, json.loads(encrypted_config)) - except JSONDecodeError: + credentials = _credentials_adapter.validate_json(encrypted_config) + except (ValueError, JSONDecodeError): return {} # Decrypt secret variables @@ -959,36 +987,37 @@ class ProviderManager: raise ValueError("quota_used is None") if provider_record.quota_limit is None: raise ValueError("quota_limit is None") - if provider_quota.quota_type == ProviderQuotaType.TRIAL and trail_pool is not None: - quota_configuration = QuotaConfiguration( - quota_type=provider_quota.quota_type, - quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS, - quota_used=trail_pool.quota_used, - quota_limit=trail_pool.quota_limit, - is_valid=trail_pool.quota_limit > trail_pool.quota_used or trail_pool.quota_limit == -1, - restrict_models=provider_quota.restrict_models, - ) + match provider_quota.quota_type: + case ProviderQuotaType.TRIAL if trail_pool is not None: + quota_configuration = QuotaConfiguration( + quota_type=provider_quota.quota_type, + quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS, + quota_used=trail_pool.quota_used, + quota_limit=trail_pool.quota_limit, + is_valid=trail_pool.quota_limit > trail_pool.quota_used or trail_pool.quota_limit == -1, + restrict_models=provider_quota.restrict_models, + ) - elif provider_quota.quota_type == ProviderQuotaType.PAID and paid_pool is not None: - quota_configuration = QuotaConfiguration( - quota_type=provider_quota.quota_type, - quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS, - quota_used=paid_pool.quota_used, - quota_limit=paid_pool.quota_limit, - is_valid=paid_pool.quota_limit > paid_pool.quota_used or paid_pool.quota_limit == -1, - restrict_models=provider_quota.restrict_models, - ) + case ProviderQuotaType.PAID if paid_pool is not None: + quota_configuration = QuotaConfiguration( + quota_type=provider_quota.quota_type, + quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS, + quota_used=paid_pool.quota_used, + quota_limit=paid_pool.quota_limit, + is_valid=paid_pool.quota_limit > paid_pool.quota_used or paid_pool.quota_limit == -1, + restrict_models=provider_quota.restrict_models, + ) - else: - quota_configuration = QuotaConfiguration( - quota_type=provider_quota.quota_type, - quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS, - quota_used=provider_record.quota_used, - quota_limit=provider_record.quota_limit, - is_valid=provider_record.quota_limit > provider_record.quota_used - or provider_record.quota_limit == -1, - restrict_models=provider_quota.restrict_models, - ) + case _: + quota_configuration = QuotaConfiguration( + quota_type=provider_quota.quota_type, + quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS, + quota_used=provider_record.quota_used, + quota_limit=provider_record.quota_limit, + is_valid=provider_record.quota_limit > provider_record.quota_used + or provider_record.quota_limit == -1, + restrict_models=provider_quota.restrict_models, + ) quota_configurations.append(quota_configuration) @@ -1015,7 +1044,7 @@ class ProviderManager: if not cached_provider_credentials: provider_credentials: dict[str, Any] = {} if provider_records and provider_records[0].encrypted_config: - provider_credentials = json.loads(provider_records[0].encrypted_config) + provider_credentials = _credentials_adapter.validate_json(provider_records[0].encrypted_config) # Get provider credential secret variables provider_credential_secret_variables = self._extract_secret_variables( @@ -1162,8 +1191,10 @@ class ProviderManager: if not cached_provider_model_credentials: try: - provider_model_credentials = json.loads(load_balancing_model_config.encrypted_config) - except JSONDecodeError: + provider_model_credentials = _credentials_adapter.validate_json( + load_balancing_model_config.encrypted_config + ) + except (ValueError, JSONDecodeError): continue # Get decoding rsa key and cipher for decrypting credentials @@ -1176,7 +1207,7 @@ class ProviderManager: if variable in provider_model_credentials: try: provider_model_credentials[variable] = encrypter.decrypt_token_with_decoding( - provider_model_credentials.get(variable), + provider_model_credentials.get(variable) or "", self.decoding_rsa_key, self.decoding_cipher_rsa, ) diff --git a/api/core/rag/data_post_processor/data_post_processor.py b/api/core/rag/data_post_processor/data_post_processor.py index 9ce91f52ff9..ca530748edf 100644 --- a/api/core/rag/data_post_processor/data_post_processor.py +++ b/api/core/rag/data_post_processor/data_post_processor.py @@ -1,8 +1,5 @@ from typing import TypedDict -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError - from core.model_manager import ModelInstance, ModelManager from core.rag.data_post_processor.reorder import ReorderRunner from core.rag.index_processor.constant.query_type import QueryType @@ -11,6 +8,8 @@ from core.rag.rerank.entity.weight import KeywordSetting, VectorSetting, Weights from core.rag.rerank.rerank_base import BaseRerankRunner from core.rag.rerank.rerank_factory import RerankRunnerFactory from core.rag.rerank.rerank_type import RerankMode +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError class RerankingModelDict(TypedDict): diff --git a/api/core/rag/datasource/keyword/jieba/jieba.py b/api/core/rag/datasource/keyword/jieba/jieba.py index ed264878d36..242da520c1c 100644 --- a/api/core/rag/datasource/keyword/jieba/jieba.py +++ b/api/core/rag/datasource/keyword/jieba/jieba.py @@ -139,8 +139,10 @@ class Jieba(BaseKeyword): "__data__": {"index_id": self.dataset.id, "summary": None, "table": keyword_table}, } dataset_keyword_table = self.dataset.dataset_keyword_table - keyword_data_source_type = dataset_keyword_table.data_source_type + keyword_data_source_type = dataset_keyword_table.data_source_type if dataset_keyword_table else "file" if keyword_data_source_type == "database": + if dataset_keyword_table is None: + return dataset_keyword_table.keyword_table = dumps_with_sets(keyword_table_dict) db.session.commit() else: diff --git a/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py b/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py index 84f35c25f83..1ca6303af65 100644 --- a/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py +++ b/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py @@ -1,4 +1,5 @@ import re +from collections.abc import Callable from operator import itemgetter from typing import cast @@ -80,12 +81,14 @@ class JiebaKeywordTableHandler: def extract_tags(self, sentence: str, top_k: int | None = 20, **kwargs): # Basic frequency-based keyword extraction as a fallback when TF-IDF is unavailable. - top_k = kwargs.pop("topK", top_k) + top_k = cast(int | None, kwargs.pop("topK", top_k)) + if top_k is None: + top_k = 20 cut = getattr(jieba, "cut", None) if self._lcut: tokens = self._lcut(sentence) elif callable(cut): - tokens = list(cut(sentence)) + tokens = list(cast(Callable[[str], list[str]], cut)(sentence)) else: tokens = re.findall(r"\w+", sentence) @@ -108,7 +111,7 @@ class JiebaKeywordTableHandler: sentence=text, topK=max_keywords_per_chunk, ) - # jieba.analyse.extract_tags returns list[Any] when withFlag is False by default. + # jieba.analyse.extract_tags returns an untyped list when withFlag is False by default. keywords = cast(list[str], keywords) return set(self._expand_tokens_with_subtokens(set(keywords))) diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index fcbc3ffbfab..2997710daf5 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -4,7 +4,6 @@ from concurrent.futures import ThreadPoolExecutor from typing import Any, NotRequired, TypedDict from flask import Flask, current_app -from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy import select from sqlalchemy.orm import Session, load_only @@ -15,7 +14,7 @@ from core.rag.data_post_processor.data_post_processor import DataPostProcessor, from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.datasource.vdb.vector_factory import Vector from core.rag.embedding.retrieval import AttachmentInfoDict, RetrievalChildChunk, RetrievalSegments -from core.rag.entities.metadata_entities import MetadataCondition +from core.rag.entities import MetadataFilteringCondition from core.rag.index_processor.constant.doc_type import DocType from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.index_processor.constant.query_type import QueryType @@ -24,6 +23,7 @@ from core.rag.rerank.rerank_type import RerankMode from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.tools.signature import sign_upload_file from extensions.ext_database import db +from graphon.model_runtime.entities.model_entities import ModelType from models.dataset import ( ChildChunk, Dataset, @@ -158,7 +158,7 @@ class RetrievalService: ) if futures: - for future in concurrent.futures.as_completed(futures, timeout=3600): + for _ in concurrent.futures.as_completed(futures, timeout=3600): if exceptions: for f in futures: f.cancel() @@ -174,15 +174,17 @@ class RetrievalService: cls, dataset_id: str, query: str, - external_retrieval_model: dict | None = None, - metadata_filtering_conditions: dict | None = None, + external_retrieval_model: dict[str, Any] | None = None, + metadata_filtering_conditions: dict[str, Any] | None = None, ): stmt = select(Dataset).where(Dataset.id == dataset_id) dataset = db.session.scalar(stmt) if not dataset: return [] metadata_condition = ( - MetadataCondition.model_validate(metadata_filtering_conditions) if metadata_filtering_conditions else None + MetadataFilteringCondition.model_validate(metadata_filtering_conditions) + if metadata_filtering_conditions + else None ) all_documents = ExternalDatasetService.fetch_external_knowledge_retrieval( dataset.tenant_id, @@ -193,6 +195,23 @@ class RetrievalService: ) return all_documents + @classmethod + def _filter_documents_by_vector_score_threshold( + cls, documents: list[Document], score_threshold: float | None + ) -> list[Document]: + """Keep documents whose stored retrieval score meets the threshold. + + Used when hybrid search skips early vector thresholding but no rerank + runner applies a threshold afterward (same rule as ``calculate_vector_score``). + """ + if score_threshold is None: + return documents + return [ + document + for document in documents + if document.metadata and document.metadata.get("score", 0) >= score_threshold + ] + @classmethod def _deduplicate_documents(cls, documents: list[Document]) -> list[Document]: """Deduplicate documents in O(n) while preserving first-seen order. @@ -240,7 +259,7 @@ class RetrievalService: @classmethod def _get_dataset(cls, dataset_id: str) -> Dataset | None: with Session(db.engine) as session: - return session.query(Dataset).where(Dataset.id == dataset_id).first() + return session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1)) @classmethod def keyword_search( @@ -292,13 +311,20 @@ class RetrievalService: vector = Vector(dataset=dataset) documents = [] + # Hybrid search merges keyword / full-text / vector hits and then reranks + # (weighted fusion or reranking model). Applying the user score threshold at + # vector retrieval time uses embedding similarity, which is not comparable to + # reranked or fused scores and incorrectly drops high-quality chunks (#35233). + embedding_score_threshold = ( + 0.0 if retrieval_method == RetrievalMethod.HYBRID_SEARCH else score_threshold + ) if query_type == QueryType.TEXT_QUERY: documents.extend( vector.search_by_vector( query, search_type="similarity_score_threshold", top_k=top_k, - score_threshold=score_threshold, + score_threshold=embedding_score_threshold, filter={"group_id": [dataset.id]}, document_ids_filter=document_ids_filter, ) @@ -310,7 +336,7 @@ class RetrievalService: vector.search_by_file( file_id=query, top_k=top_k, - score_threshold=score_threshold, + score_threshold=embedding_score_threshold, filter={"group_id": [dataset.id]}, document_ids_filter=document_ids_filter, ) @@ -573,15 +599,13 @@ class RetrievalService: # Batch query summaries for segments retrieved via summary (only enabled summaries) if summary_segment_ids: - summaries = ( - session.query(DocumentSegmentSummary) - .filter( + summaries = session.scalars( + select(DocumentSegmentSummary).where( DocumentSegmentSummary.chunk_id.in_(list(summary_segment_ids)), DocumentSegmentSummary.status == "completed", - DocumentSegmentSummary.enabled == True, # Only retrieve enabled summaries + DocumentSegmentSummary.enabled.is_(True), # Only retrieve enabled summaries ) - .all() - ) + ).all() for summary in summaries: if summary.summary_content: segment_summary_map[summary.chunk_id] = summary.summary_content @@ -844,6 +868,10 @@ class RetrievalService: top_n=top_k, query_type=QueryType.TEXT_QUERY if query else QueryType.IMAGE_QUERY, ) + if not data_post_processor.rerank_runner and score_threshold: + all_documents_item = self._filter_documents_by_vector_score_threshold( + all_documents_item, score_threshold + ) all_documents.extend(all_documents_item) @@ -851,12 +879,12 @@ class RetrievalService: def get_segment_attachment_info( cls, dataset_id: str, tenant_id: str, attachment_id: str, session: Session ) -> SegmentAttachmentResult | None: - upload_file = session.query(UploadFile).where(UploadFile.id == attachment_id).first() + upload_file = session.scalar(select(UploadFile).where(UploadFile.id == attachment_id).limit(1)) if upload_file: - attachment_binding = ( - session.query(SegmentAttachmentBinding) + attachment_binding = session.scalar( + select(SegmentAttachmentBinding) .where(SegmentAttachmentBinding.attachment_id == upload_file.id) - .first() + .limit(1) ) if attachment_binding: attachment_info: AttachmentInfoDict = { @@ -875,14 +903,12 @@ class RetrievalService: cls, attachment_ids: list[str], session: Session ) -> list[SegmentAttachmentInfoResult]: attachment_infos: list[SegmentAttachmentInfoResult] = [] - upload_files = session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).all() + upload_files = session.scalars(select(UploadFile).where(UploadFile.id.in_(attachment_ids))).all() if upload_files: upload_file_ids = [upload_file.id for upload_file in upload_files] - attachment_bindings = ( - session.query(SegmentAttachmentBinding) - .where(SegmentAttachmentBinding.attachment_id.in_(upload_file_ids)) - .all() - ) + attachment_bindings = session.scalars( + select(SegmentAttachmentBinding).where(SegmentAttachmentBinding.attachment_id.in_(upload_file_ids)) + ).all() attachment_binding_map = {binding.attachment_id: binding for binding in attachment_bindings} if attachment_bindings: diff --git a/api/core/rag/datasource/vdb/vector_backend_registry.py b/api/core/rag/datasource/vdb/vector_backend_registry.py new file mode 100644 index 00000000000..15f4357cafe --- /dev/null +++ b/api/core/rag/datasource/vdb/vector_backend_registry.py @@ -0,0 +1,87 @@ +"""Vector store backend discovery. + +Backends live in workspace packages under ``api/packages/dify-vdb-*/src/dify_vdb_*``. Each package +declares third-party dependencies and registers ``importlib`` entry points in group +``dify.vector_backends`` (see each package's ``pyproject.toml``). + +Shared types and the :class:`~core.rag.datasource.vdb.vector_factory.AbstractVectorFactory` protocol +remain in this package (``vector_base``, ``vector_factory``, ``vector_type``, ``field``). + +Optional **built-in** targets in ``_BUILTIN_VECTOR_FACTORY_TARGETS`` (normally empty) load without a +distribution; entry points take precedence when both exist. + +After changing packages, run ``uv sync`` so installed dist-info entry points match ``pyproject.toml``. +""" + +from __future__ import annotations + +import importlib +import logging +from importlib.metadata import entry_points +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory + +logger = logging.getLogger(__name__) + +_VECTOR_FACTORY_CACHE: dict[str, type[AbstractVectorFactory]] = {} + +# module_path:class_name — optional fallback when no distribution registers the backend. +_BUILTIN_VECTOR_FACTORY_TARGETS: dict[str, str] = {} + + +def clear_vector_factory_cache() -> None: + """Drop lazily loaded factories (for tests or plugin reload).""" + _VECTOR_FACTORY_CACHE.clear() + + +def _vector_backend_entry_points(): + return entry_points().select(group="dify.vector_backends") + + +def _load_plugin_factory(vector_type: str) -> type[AbstractVectorFactory] | None: + for ep in _vector_backend_entry_points(): + if ep.name != vector_type: + continue + try: + loaded = ep.load() + except Exception: + logger.exception("Failed to load vector backend entry point %s", ep.name) + raise + return loaded # type: ignore[return-value] + return None + + +def _unsupported(vector_type: str) -> ValueError: + installed = sorted(ep.name for ep in _vector_backend_entry_points()) + available_msg = f" Installed backends: {', '.join(installed)}." if installed else " No backends installed." + return ValueError( + f"Vector store {vector_type!r} is not supported.{available_msg} " + "Install a plugin (uv sync --group vdb-all, or vdb- per api/pyproject.toml), " + "or register a dify.vector_backends entry point." + ) + + +def _load_builtin_factory(vector_type: str) -> type[AbstractVectorFactory]: + target = _BUILTIN_VECTOR_FACTORY_TARGETS.get(vector_type) + if not target: + raise _unsupported(vector_type) + module_path, _, attr = target.partition(":") + module = importlib.import_module(module_path) + return getattr(module, attr) # type: ignore[no-any-return] + + +def get_vector_factory_class(vector_type: str) -> type[AbstractVectorFactory]: + """Resolve :class:`AbstractVectorFactory` for a :class:`~VectorType` string value.""" + if vector_type in _VECTOR_FACTORY_CACHE: + return _VECTOR_FACTORY_CACHE[vector_type] + + plugin_cls = _load_plugin_factory(vector_type) + if plugin_cls is not None: + _VECTOR_FACTORY_CACHE[vector_type] = plugin_cls + return plugin_cls + + cls = _load_builtin_factory(vector_type) + _VECTOR_FACTORY_CACHE[vector_type] = cls + return cls diff --git a/api/core/rag/datasource/vdb/vector_base.py b/api/core/rag/datasource/vdb/vector_base.py index f29b270e40f..6fbd802a103 100644 --- a/api/core/rag/datasource/vdb/vector_base.py +++ b/api/core/rag/datasource/vdb/vector_base.py @@ -1,11 +1,20 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Any +from typing import Any, TypedDict from core.rag.models.document import Document +class VectorStoreDict(TypedDict): + class_prefix: str + + +class VectorIndexStructDict(TypedDict): + type: str + vector_store: VectorStoreDict + + class BaseVector(ABC): def __init__(self, collection_name: str): self._collection_name = collection_name diff --git a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py index 26531eab886..59d7f3c3c4b 100644 --- a/api/core/rag/datasource/vdb/vector_factory.py +++ b/api/core/rag/datasource/vdb/vector_factory.py @@ -4,12 +4,12 @@ import time from abc import ABC, abstractmethod from typing import Any -from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy import select from configs import dify_config from core.model_manager import ModelManager -from core.rag.datasource.vdb.vector_base import BaseVector +from core.rag.datasource.vdb.vector_backend_registry import get_vector_factory_class +from core.rag.datasource.vdb.vector_base import BaseVector, VectorIndexStructDict from core.rag.datasource.vdb.vector_type import VectorType from core.rag.embedding.cached_embedding import CacheEmbedding from core.rag.embedding.embedding_base import Embeddings @@ -18,6 +18,7 @@ from core.rag.models.document import Document from extensions.ext_database import db from extensions.ext_redis import redis_client from extensions.ext_storage import storage +from graphon.model_runtime.entities.model_entities import ModelType from models.dataset import Dataset, Whitelist from models.model import UploadFile @@ -30,15 +31,34 @@ class AbstractVectorFactory(ABC): raise NotImplementedError @staticmethod - def gen_index_struct_dict(vector_type: VectorType, collection_name: str): - index_struct_dict = {"type": vector_type, "vector_store": {"class_prefix": collection_name}} + def gen_index_struct_dict(vector_type: VectorType, collection_name: str) -> VectorIndexStructDict: + index_struct_dict: VectorIndexStructDict = { + "type": vector_type, + "vector_store": {"class_prefix": collection_name}, + } return index_struct_dict class Vector: def __init__(self, dataset: Dataset, attributes: list | None = None): if attributes is None: - attributes = ["doc_id", "dataset_id", "document_id", "doc_hash", "doc_type"] + # `is_summary` and `original_chunk_id` are stored on summary vectors + # by `SummaryIndexService` and read back by `RetrievalService` to + # route summary hits through their original parent chunks. They + # must be listed here so vector backends that use this list as an + # explicit return-properties projection (notably Weaviate) actually + # return those fields; without them, summary hits silently + # collapse into `is_summary = False` branches and the summary + # retrieval path is a no-op. See #34884. + attributes = [ + "doc_id", + "dataset_id", + "document_id", + "doc_hash", + "doc_type", + "is_summary", + "original_chunk_id", + ] self._dataset = dataset self._embeddings = self._get_embeddings() self._attributes = attributes @@ -66,137 +86,7 @@ class Vector: @staticmethod def get_vector_factory(vector_type: str) -> type[AbstractVectorFactory]: - match vector_type: - case VectorType.CHROMA: - from core.rag.datasource.vdb.chroma.chroma_vector import ChromaVectorFactory - - return ChromaVectorFactory - case VectorType.MILVUS: - from core.rag.datasource.vdb.milvus.milvus_vector import MilvusVectorFactory - - return MilvusVectorFactory - case VectorType.ALIBABACLOUD_MYSQL: - from core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector import ( - AlibabaCloudMySQLVectorFactory, - ) - - return AlibabaCloudMySQLVectorFactory - case VectorType.MYSCALE: - from core.rag.datasource.vdb.myscale.myscale_vector import MyScaleVectorFactory - - return MyScaleVectorFactory - case VectorType.PGVECTOR: - from core.rag.datasource.vdb.pgvector.pgvector import PGVectorFactory - - return PGVectorFactory - case VectorType.VASTBASE: - from core.rag.datasource.vdb.pyvastbase.vastbase_vector import VastbaseVectorFactory - - return VastbaseVectorFactory - case VectorType.PGVECTO_RS: - from core.rag.datasource.vdb.pgvecto_rs.pgvecto_rs import PGVectoRSFactory - - return PGVectoRSFactory - case VectorType.QDRANT: - from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantVectorFactory - - return QdrantVectorFactory - case VectorType.RELYT: - from core.rag.datasource.vdb.relyt.relyt_vector import RelytVectorFactory - - return RelytVectorFactory - case VectorType.ELASTICSEARCH: - from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory - - return ElasticSearchVectorFactory - case VectorType.ELASTICSEARCH_JA: - from core.rag.datasource.vdb.elasticsearch.elasticsearch_ja_vector import ( - ElasticSearchJaVectorFactory, - ) - - return ElasticSearchJaVectorFactory - case VectorType.TIDB_VECTOR: - from core.rag.datasource.vdb.tidb_vector.tidb_vector import TiDBVectorFactory - - return TiDBVectorFactory - case VectorType.WEAVIATE: - from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateVectorFactory - - return WeaviateVectorFactory - case VectorType.TENCENT: - from core.rag.datasource.vdb.tencent.tencent_vector import TencentVectorFactory - - return TencentVectorFactory - case VectorType.ORACLE: - from core.rag.datasource.vdb.oracle.oraclevector import OracleVectorFactory - - return OracleVectorFactory - case VectorType.OPENSEARCH: - from core.rag.datasource.vdb.opensearch.opensearch_vector import OpenSearchVectorFactory - - return OpenSearchVectorFactory - case VectorType.ANALYTICDB: - from core.rag.datasource.vdb.analyticdb.analyticdb_vector import AnalyticdbVectorFactory - - return AnalyticdbVectorFactory - case VectorType.COUCHBASE: - from core.rag.datasource.vdb.couchbase.couchbase_vector import CouchbaseVectorFactory - - return CouchbaseVectorFactory - case VectorType.BAIDU: - from core.rag.datasource.vdb.baidu.baidu_vector import BaiduVectorFactory - - return BaiduVectorFactory - case VectorType.VIKINGDB: - from core.rag.datasource.vdb.vikingdb.vikingdb_vector import VikingDBVectorFactory - - return VikingDBVectorFactory - case VectorType.UPSTASH: - from core.rag.datasource.vdb.upstash.upstash_vector import UpstashVectorFactory - - return UpstashVectorFactory - case VectorType.TIDB_ON_QDRANT: - from core.rag.datasource.vdb.tidb_on_qdrant.tidb_on_qdrant_vector import TidbOnQdrantVectorFactory - - return TidbOnQdrantVectorFactory - case VectorType.LINDORM: - from core.rag.datasource.vdb.lindorm.lindorm_vector import LindormVectorStoreFactory - - return LindormVectorStoreFactory - case VectorType.OCEANBASE | VectorType.SEEKDB: - from core.rag.datasource.vdb.oceanbase.oceanbase_vector import OceanBaseVectorFactory - - return OceanBaseVectorFactory - case VectorType.OPENGAUSS: - from core.rag.datasource.vdb.opengauss.opengauss import OpenGaussFactory - - return OpenGaussFactory - case VectorType.TABLESTORE: - from core.rag.datasource.vdb.tablestore.tablestore_vector import TableStoreVectorFactory - - return TableStoreVectorFactory - case VectorType.HUAWEI_CLOUD: - from core.rag.datasource.vdb.huawei.huawei_cloud_vector import HuaweiCloudVectorFactory - - return HuaweiCloudVectorFactory - case VectorType.MATRIXONE: - from core.rag.datasource.vdb.matrixone.matrixone_vector import MatrixoneVectorFactory - - return MatrixoneVectorFactory - case VectorType.CLICKZETTA: - from core.rag.datasource.vdb.clickzetta.clickzetta_vector import ClickzettaVectorFactory - - return ClickzettaVectorFactory - case VectorType.IRIS: - from core.rag.datasource.vdb.iris.iris_vector import IrisVectorFactory - - return IrisVectorFactory - case VectorType.HOLOGRES: - from core.rag.datasource.vdb.hologres.hologres_vector import HologresVectorFactory - - return HologresVectorFactory - case _: - raise ValueError(f"Vector store {vector_type} is not supported.") + return get_vector_factory_class(vector_type) def create(self, texts: list | None = None, **kwargs): if texts: diff --git a/api/tests/integration_tests/vdb/test_vector_store.py b/api/core/rag/datasource/vdb/vector_integration_test_support.py similarity index 83% rename from api/tests/integration_tests/vdb/test_vector_store.py rename to api/core/rag/datasource/vdb/vector_integration_test_support.py index a033443cf8d..3148b7d5c14 100644 --- a/api/tests/integration_tests/vdb/test_vector_store.py +++ b/api/core/rag/datasource/vdb/vector_integration_test_support.py @@ -1,10 +1,19 @@ +"""Shared helpers for vector DB integration tests (used by workspace packages under ``api/packages``). + +:class:`AbstractVectorTest` and helper functions live here so package tests can import +``core.rag.datasource.vdb.vector_integration_test_support`` without relying on the +``tests.*`` package. + +The ``setup_mock_redis`` fixture lives in ``api/packages/conftest.py`` and is +auto-discovered by pytest for all package tests. +""" + import uuid -from unittest.mock import MagicMock import pytest +from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.models.document import Document -from extensions import ext_redis from models.dataset import Dataset @@ -25,24 +34,10 @@ def get_example_document(doc_id: str) -> Document: return doc -@pytest.fixture -def setup_mock_redis(): - # get - ext_redis.redis_client.get = MagicMock(return_value=None) - - # set - ext_redis.redis_client.set = MagicMock(return_value=None) - - # lock - mock_redis_lock = MagicMock() - mock_redis_lock.__enter__ = MagicMock() - mock_redis_lock.__exit__ = MagicMock() - ext_redis.redis_client.lock = mock_redis_lock - - class AbstractVectorTest: + vector: BaseVector + def __init__(self): - self.vector = None self.dataset_id = str(uuid.uuid4()) self.collection_name = Dataset.gen_collection_name_by_id(self.dataset_id) + "_test" self.example_doc_id = str(uuid.uuid4()) diff --git a/api/core/rag/docstore/dataset_docstore.py b/api/core/rag/docstore/dataset_docstore.py index 40f45953af4..f4699f68697 100644 --- a/api/core/rag/docstore/dataset_docstore.py +++ b/api/core/rag/docstore/dataset_docstore.py @@ -3,13 +3,13 @@ from __future__ import annotations from collections.abc import Sequence from typing import Any -from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy import delete, func, select from core.model_manager import ModelManager from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.models.document import AttachmentDocument, Document from extensions.ext_database import db +from graphon.model_runtime.entities.model_entities import ModelType from models.dataset import ChildChunk, Dataset, DocumentSegment, SegmentAttachmentBinding @@ -244,7 +244,7 @@ class DatasetDocumentStore: return document_segment def add_multimodel_documents_binding(self, segment_id: str, multimodel_documents: list[AttachmentDocument] | None): - if multimodel_documents: + if multimodel_documents and self._document_id is not None: for multimodel_document in multimodel_documents: binding = SegmentAttachmentBinding( tenant_id=self._dataset.tenant_id, diff --git a/api/core/rag/embedding/cached_embedding.py b/api/core/rag/embedding/cached_embedding.py index 8d1c0da392d..a9995778f70 100644 --- a/api/core/rag/embedding/cached_embedding.py +++ b/api/core/rag/embedding/cached_embedding.py @@ -4,8 +4,6 @@ import pickle from typing import Any, cast import numpy as np -from graphon.model_runtime.entities.model_entities import ModelPropertyKey -from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from sqlalchemy import select from sqlalchemy.exc import IntegrityError @@ -15,6 +13,8 @@ from core.model_manager import ModelInstance from core.rag.embedding.embedding_base import Embeddings from extensions.ext_database import db from extensions.ext_redis import redis_client +from graphon.model_runtime.entities.model_entities import ModelPropertyKey +from graphon.model_runtime.model_providers.base.text_embedding_model import TextEmbeddingModel from libs import helper from models.dataset import Embedding @@ -106,7 +106,7 @@ class CacheEmbedding(Embeddings): return text_embeddings - def embed_multimodal_documents(self, multimodel_documents: list[dict]) -> list[list[float]]: + def embed_multimodal_documents(self, multimodel_documents: list[dict[str, Any]]) -> list[list[float]]: """Embed file documents.""" # use doc embedding cache or store if not exists multimodel_embeddings: list[Any] = [None for _ in range(len(multimodel_documents))] @@ -232,7 +232,7 @@ class CacheEmbedding(Embeddings): return embedding_results # type: ignore - def embed_multimodal_query(self, multimodel_document: dict) -> list[float]: + def embed_multimodal_query(self, multimodel_document: dict[str, Any]) -> list[float]: """Embed multimodal documents.""" # use doc embedding cache or store if not exists file_id = multimodel_document["file_id"] diff --git a/api/core/rag/embedding/embedding_base.py b/api/core/rag/embedding/embedding_base.py index 1be55bda80d..7ae5c09ab72 100644 --- a/api/core/rag/embedding/embedding_base.py +++ b/api/core/rag/embedding/embedding_base.py @@ -1,4 +1,5 @@ from abc import ABC, abstractmethod +from typing import Any class Embeddings(ABC): @@ -10,7 +11,7 @@ class Embeddings(ABC): raise NotImplementedError @abstractmethod - def embed_multimodal_documents(self, multimodel_documents: list[dict]) -> list[list[float]]: + def embed_multimodal_documents(self, multimodel_documents: list[dict[str, Any]]) -> list[list[float]]: """Embed file documents.""" raise NotImplementedError @@ -20,7 +21,7 @@ class Embeddings(ABC): raise NotImplementedError @abstractmethod - def embed_multimodal_query(self, multimodel_document: dict) -> list[float]: + def embed_multimodal_query(self, multimodel_document: dict[str, Any]) -> list[float]: """Embed multimodal query.""" raise NotImplementedError diff --git a/api/core/rag/entities/__init__.py b/api/core/rag/entities/__init__.py new file mode 100644 index 00000000000..373b68894bf --- /dev/null +++ b/api/core/rag/entities/__init__.py @@ -0,0 +1,34 @@ +from core.rag.entities.citation_metadata import RetrievalSourceMetadata +from core.rag.entities.context_entities import DocumentContext +from core.rag.entities.event import DatasourceCompletedEvent, DatasourceErrorEvent, DatasourceProcessingEvent +from core.rag.entities.index_entities import EconomySetting, EmbeddingSetting, IndexMethod +from core.rag.entities.metadata_entities import Condition, MetadataFilteringCondition, SupportedComparisonOperator +from core.rag.entities.processing_entities import ParentMode, PreProcessingRule, Rule, Segmentation +from core.rag.entities.retrieval_settings import ( + KeywordSetting, + RerankingModelConfig, + VectorSetting, + WeightedScoreConfig, +) + +__all__ = [ + "Condition", + "DatasourceCompletedEvent", + "DatasourceErrorEvent", + "DatasourceProcessingEvent", + "DocumentContext", + "EconomySetting", + "EmbeddingSetting", + "IndexMethod", + "KeywordSetting", + "MetadataFilteringCondition", + "ParentMode", + "PreProcessingRule", + "RerankingModelConfig", + "RetrievalSourceMetadata", + "Rule", + "Segmentation", + "SupportedComparisonOperator", + "VectorSetting", + "WeightedScoreConfig", +] diff --git a/api/core/rag/entities/index_entities.py b/api/core/rag/entities/index_entities.py new file mode 100644 index 00000000000..f86a04fa9f8 --- /dev/null +++ b/api/core/rag/entities/index_entities.py @@ -0,0 +1,30 @@ +from typing import Literal + +from pydantic import BaseModel + + +class EmbeddingSetting(BaseModel): + """ + Embedding Setting. + """ + + embedding_provider_name: str + embedding_model_name: str + + +class EconomySetting(BaseModel): + """ + Economy Setting. + """ + + keyword_number: int + + +class IndexMethod(BaseModel): + """ + Knowledge Index Setting. + """ + + indexing_technique: Literal["high_quality", "economy"] + embedding_setting: EmbeddingSetting + economy_setting: EconomySetting diff --git a/api/core/rag/entities/metadata_entities.py b/api/core/rag/entities/metadata_entities.py index b07d760cf48..a2ac44807f8 100644 --- a/api/core/rag/entities/metadata_entities.py +++ b/api/core/rag/entities/metadata_entities.py @@ -38,9 +38,9 @@ class Condition(BaseModel): value: str | Sequence[str] | None | int | float = None -class MetadataCondition(BaseModel): +class MetadataFilteringCondition(BaseModel): """ - Metadata Condition. + Metadata Filtering Condition. """ logical_operator: Literal["and", "or"] | None = "and" diff --git a/api/core/rag/entities/processing_entities.py b/api/core/rag/entities/processing_entities.py new file mode 100644 index 00000000000..1b54444a198 --- /dev/null +++ b/api/core/rag/entities/processing_entities.py @@ -0,0 +1,27 @@ +from enum import StrEnum +from typing import Literal + +from pydantic import BaseModel + + +class ParentMode(StrEnum): + FULL_DOC = "full-doc" + PARAGRAPH = "paragraph" + + +class PreProcessingRule(BaseModel): + id: str + enabled: bool + + +class Segmentation(BaseModel): + separator: str = "\n" + max_tokens: int + chunk_overlap: int = 0 + + +class Rule(BaseModel): + pre_processing_rules: list[PreProcessingRule] | None = None + segmentation: Segmentation | None = None + parent_mode: Literal["full-doc", "paragraph"] | None = None + subchunk_segmentation: Segmentation | None = None diff --git a/api/core/rag/entities/retrieval_settings.py b/api/core/rag/entities/retrieval_settings.py new file mode 100644 index 00000000000..8d40ab68fdd --- /dev/null +++ b/api/core/rag/entities/retrieval_settings.py @@ -0,0 +1,51 @@ +from pydantic import BaseModel, ConfigDict, Field + + +class RerankingModelConfig(BaseModel): + """ + Canonical reranking model configuration. + + Accepts both naming conventions: + - reranking_provider_name / reranking_model_name (services layer) + - provider / model (workflow layer via validation_alias) + """ + + model_config = ConfigDict(populate_by_name=True) + + reranking_provider_name: str = Field(validation_alias="provider") + reranking_model_name: str = Field(validation_alias="model") + + @property + def provider(self) -> str: + return self.reranking_provider_name + + @property + def model(self) -> str: + return self.reranking_model_name + + +class VectorSetting(BaseModel): + """ + Vector Setting. + """ + + vector_weight: float + embedding_provider_name: str + embedding_model_name: str + + +class KeywordSetting(BaseModel): + """ + Keyword Setting. + """ + + keyword_weight: float + + +class WeightedScoreConfig(BaseModel): + """ + Weighted score Config. + """ + + vector_setting: VectorSetting + keyword_setting: KeywordSetting diff --git a/api/core/rag/extractor/csv_extractor.py b/api/core/rag/extractor/csv_extractor.py index 3bfae9d6bd9..19bc9cec846 100644 --- a/api/core/rag/extractor/csv_extractor.py +++ b/api/core/rag/extractor/csv_extractor.py @@ -1,6 +1,7 @@ """Abstract interface for document loader implementations.""" import csv +from typing import Any import pandas as pd @@ -23,7 +24,7 @@ class CSVExtractor(BaseExtractor): encoding: str | None = None, autodetect_encoding: bool = False, source_column: str | None = None, - csv_args: dict | None = None, + csv_args: dict[str, Any] | None = None, ): """Initialize with file path.""" self._file_path = file_path diff --git a/api/core/rag/extractor/extract_processor.py b/api/core/rag/extractor/extract_processor.py index 449be6a448f..b679edab367 100644 --- a/api/core/rag/extractor/extract_processor.py +++ b/api/core/rag/extractor/extract_processor.py @@ -94,16 +94,18 @@ class ExtractProcessor: cls, extract_setting: ExtractSetting, is_automatic: bool = False, file_path: str | None = None ) -> list[Document]: if extract_setting.datasource_type == DatasourceType.FILE: + upload_file = extract_setting.upload_file with tempfile.TemporaryDirectory() as temp_dir: + upload_file = extract_setting.upload_file if not file_path: - assert extract_setting.upload_file is not None, "upload_file is required" - upload_file: UploadFile = extract_setting.upload_file + assert upload_file is not None, "upload_file is required" suffix = Path(upload_file.key).suffix # FIXME mypy: Cannot determine type of 'tempfile._get_candidate_names' better not use it here file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore storage.download(upload_file.key, file_path) input_file = Path(file_path) file_extension = input_file.suffix.lower() + assert upload_file is not None, "upload_file is required" etl_type = dify_config.ETL_TYPE extractor: BaseExtractor | None = None if etl_type == "Unstructured": @@ -113,6 +115,7 @@ class ExtractProcessor: if file_extension in {".xlsx", ".xls"}: extractor = ExcelExtractor(file_path) elif file_extension == ".pdf": + assert upload_file is not None extractor = PdfExtractor(file_path, upload_file.tenant_id, upload_file.created_by) elif file_extension in {".md", ".markdown", ".mdx"}: extractor = ( @@ -123,6 +126,7 @@ class ExtractProcessor: elif file_extension in {".htm", ".html"}: extractor = HtmlExtractor(file_path) elif file_extension == ".docx": + assert upload_file is not None extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by) elif file_extension == ".doc": extractor = UnstructuredWordExtractor(file_path, unstructured_api_url, unstructured_api_key) @@ -149,12 +153,14 @@ class ExtractProcessor: if file_extension in {".xlsx", ".xls"}: extractor = ExcelExtractor(file_path) elif file_extension == ".pdf": + assert upload_file is not None extractor = PdfExtractor(file_path, upload_file.tenant_id, upload_file.created_by) elif file_extension in {".md", ".markdown", ".mdx"}: extractor = MarkdownExtractor(file_path, autodetect_encoding=True) elif file_extension in {".htm", ".html"}: extractor = HtmlExtractor(file_path) elif file_extension == ".docx": + assert upload_file is not None extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by) elif file_extension == ".csv": extractor = CSVExtractor(file_path, autodetect_encoding=True) diff --git a/api/core/rag/extractor/firecrawl/firecrawl_app.py b/api/core/rag/extractor/firecrawl/firecrawl_app.py index 89bdd56a6c9..556158cf00a 100644 --- a/api/core/rag/extractor/firecrawl/firecrawl_app.py +++ b/api/core/rag/extractor/firecrawl/firecrawl_app.py @@ -174,21 +174,25 @@ class FirecrawlApp: return f"{self.base_url.rstrip('/')}/{path.lstrip('/')}" def _post_request(self, url, data, headers, retries=3, backoff_factor=0.5) -> httpx.Response: + response: httpx.Response | None = None for attempt in range(retries): response = httpx.post(url, headers=headers, json=data) if response.status_code == 502: time.sleep(backoff_factor * (2**attempt)) else: return response + assert response is not None, "retries must be at least 1" return response def _get_request(self, url, headers, retries=3, backoff_factor=0.5) -> httpx.Response: + response: httpx.Response | None = None for attempt in range(retries): response = httpx.get(url, headers=headers) if response.status_code == 502: time.sleep(backoff_factor * (2**attempt)) else: return response + assert response is not None, "retries must be at least 1" return response def _handle_error(self, response, action): diff --git a/api/core/rag/extractor/unstructured/unstructured_doc_extractor.py b/api/core/rag/extractor/unstructured/unstructured_doc_extractor.py index 7dd8beaa462..f9fbfbc409a 100644 --- a/api/core/rag/extractor/unstructured/unstructured_doc_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_doc_extractor.py @@ -19,12 +19,15 @@ class UnstructuredWordExtractor(BaseExtractor): def extract(self) -> list[Document]: from unstructured.__version__ import __version__ as __unstructured_version__ - from unstructured.file_utils.filetype import FileType, detect_filetype + from unstructured.file_utils.filetype import ( # pyright: ignore[reportPrivateImportUsage] + FileType, + detect_filetype, + ) unstructured_version = tuple(int(x) for x in __unstructured_version__.split(".")) # check the file extension try: - import magic # noqa: F401 + import magic # noqa: F401 # pyright: ignore[reportUnusedImport] is_doc = detect_filetype(self._file_path) == FileType.DOC except ImportError: diff --git a/api/core/rag/extractor/watercrawl/client.py b/api/core/rag/extractor/watercrawl/client.py index 7b4a388df93..d1ce142dbdd 100644 --- a/api/core/rag/extractor/watercrawl/client.py +++ b/api/core/rag/extractor/watercrawl/client.py @@ -54,8 +54,8 @@ class BaseAPIClient: self, method: str, endpoint: str, - query_params: dict | None = None, - data: dict | None = None, + query_params: dict[str, Any] | None = None, + data: dict[str, Any] | None = None, **kwargs, ) -> Response: stream = kwargs.pop("stream", False) @@ -66,19 +66,25 @@ class BaseAPIClient: return self.session.request(method, url, params=query_params, json=data, **kwargs) - def _get(self, endpoint: str, query_params: dict | None = None, **kwargs): + def _get(self, endpoint: str, query_params: dict[str, Any] | None = None, **kwargs): return self._request("GET", endpoint, query_params=query_params, **kwargs) - def _post(self, endpoint: str, query_params: dict | None = None, data: dict | None = None, **kwargs): + def _post( + self, endpoint: str, query_params: dict[str, Any] | None = None, data: dict[str, Any] | None = None, **kwargs + ): return self._request("POST", endpoint, query_params=query_params, data=data, **kwargs) - def _put(self, endpoint: str, query_params: dict | None = None, data: dict | None = None, **kwargs): + def _put( + self, endpoint: str, query_params: dict[str, Any] | None = None, data: dict[str, Any] | None = None, **kwargs + ): return self._request("PUT", endpoint, query_params=query_params, data=data, **kwargs) - def _delete(self, endpoint: str, query_params: dict | None = None, **kwargs): + def _delete(self, endpoint: str, query_params: dict[str, Any] | None = None, **kwargs): return self._request("DELETE", endpoint, query_params=query_params, **kwargs) - def _patch(self, endpoint: str, query_params: dict | None = None, data: dict | None = None, **kwargs): + def _patch( + self, endpoint: str, query_params: dict[str, Any] | None = None, data: dict[str, Any] | None = None, **kwargs + ): return self._request("PATCH", endpoint, query_params=query_params, data=data, **kwargs) @@ -99,7 +105,7 @@ class WaterCrawlAPIClient(BaseAPIClient): finally: response.close() - def process_response(self, response: Response) -> dict | bytes | list | None | Generator: + def process_response(self, response: Response) -> dict[str, Any] | bytes | list[Any] | None | Generator: if response.status_code == 401: raise WaterCrawlAuthenticationError(response) @@ -186,7 +192,7 @@ class WaterCrawlAPIClient(BaseAPIClient): yield from generator def get_crawl_request_results( - self, item_id: str, page: int = 1, page_size: int = 25, query_params: dict | None = None + self, item_id: str, page: int = 1, page_size: int = 25, query_params: dict[str, Any] | None = None ): query_params = query_params or {} query_params.update({"page": page or 1, "page_size": page_size or 25}) @@ -210,7 +216,7 @@ class WaterCrawlAPIClient(BaseAPIClient): if event_data["type"] == "result": return event_data["data"] - def download_result(self, result_object: dict): + def download_result(self, result_object: dict[str, Any]): response = httpx.get(result_object["result"], timeout=None) try: response.raise_for_status() diff --git a/api/core/rag/extractor/watercrawl/provider.py b/api/core/rag/extractor/watercrawl/provider.py index 2a9403eda0f..ae7bebcb9bb 100644 --- a/api/core/rag/extractor/watercrawl/provider.py +++ b/api/core/rag/extractor/watercrawl/provider.py @@ -120,7 +120,7 @@ class WaterCrawlProvider: } def _get_results( - self, crawl_request_id: str, query_params: dict | None = None + self, crawl_request_id: str, query_params: dict[str, Any] | None = None ) -> Generator[WatercrawlDocumentData, None, None]: page = 0 page_size = 100 diff --git a/api/core/rag/extractor/word_extractor.py b/api/core/rag/extractor/word_extractor.py index 052fca930da..0330a43b284 100644 --- a/api/core/rag/extractor/word_extractor.py +++ b/api/core/rag/extractor/word_extractor.py @@ -3,6 +3,7 @@ Supports local file paths and remote URLs (downloaded via `core.helper.ssrf_proxy`). """ +import inspect import logging import mimetypes import os @@ -36,8 +37,11 @@ class WordExtractor(BaseExtractor): file_path: Path to the file to load. """ + _closed: bool + def __init__(self, file_path: str, tenant_id: str, user_id: str): """Initialize with file path.""" + self._closed = False self.file_path = file_path self.tenant_id = tenant_id self.user_id = user_id @@ -65,9 +69,27 @@ class WordExtractor(BaseExtractor): elif not os.path.isfile(self.file_path): raise ValueError(f"File path {self.file_path} is not a valid file or url") + def close(self) -> None: + """Best-effort cleanup for downloaded temporary files.""" + if getattr(self, "_closed", False): + return + + self._closed = True + temp_file = getattr(self, "temp_file", None) + if temp_file is None: + return + + try: + close_result = temp_file.close() + if inspect.isawaitable(close_result): + close_awaitable = getattr(close_result, "close", None) + if callable(close_awaitable): + close_awaitable() + except Exception: + logger.debug("Failed to cleanup downloaded word temp file", exc_info=True) + def __del__(self): - if hasattr(self, "temp_file"): - self.temp_file.close() + self.close() def extract(self) -> list[Document]: """Load given path as single page.""" diff --git a/api/core/rag/index_processor/index_processor.py b/api/core/rag/index_processor/index_processor.py index 825ae012269..aded5315bd8 100644 --- a/api/core/rag/index_processor/index_processor.py +++ b/api/core/rag/index_processor/index_processor.py @@ -6,13 +6,13 @@ from collections.abc import Mapping from typing import Any from flask import current_app -from sqlalchemy import delete, func, select +from sqlalchemy import delete, func, select, update from core.db.session_factory import session_factory from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict from core.workflow.nodes.knowledge_index.exc import KnowledgeIndexNodeError -from core.workflow.nodes.knowledge_index.protocols import Preview, PreviewItem, QaPreview +from core.workflow.nodes.knowledge_index.protocols import IndexingResultDict, Preview, PreviewItem, QaPreview from models.dataset import Dataset, Document, DocumentSegment from .index_processor_factory import IndexProcessorFactory @@ -61,13 +61,13 @@ class IndexProcessor: chunks: Mapping[str, Any], batch: Any, summary_index_setting: SummaryIndexSettingDict | None = None, - ): + ) -> IndexingResultDict: with session_factory.create_session() as session: - document = session.query(Document).filter_by(id=document_id).first() + document = session.scalar(select(Document).where(Document.id == document_id).limit(1)) if not document: raise KnowledgeIndexNodeError(f"Document {document_id} not found.") - dataset = session.query(Dataset).filter_by(id=dataset_id).first() + dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1)) if not dataset: raise KnowledgeIndexNodeError(f"Dataset {dataset_id} not found.") @@ -104,12 +104,12 @@ class IndexProcessor: document.indexing_status = "completed" document.completed_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) document.word_count = ( - session.query(func.sum(DocumentSegment.word_count)) - .where( - DocumentSegment.document_id == document_id, - DocumentSegment.dataset_id == dataset_id, + session.scalar( + select(func.sum(DocumentSegment.word_count)).where( + DocumentSegment.document_id == document_id, + DocumentSegment.dataset_id == dataset_id, + ) ) - .scalar() ) or 0 # Update need_summary based on dataset's summary_index_setting if summary_index_setting and summary_index_setting.get("enable") is True: @@ -118,18 +118,20 @@ class IndexProcessor: document.need_summary = False session.add(document) # update document segment status - session.query(DocumentSegment).where( - DocumentSegment.document_id == document_id, - DocumentSegment.dataset_id == dataset_id, - ).update( - { - DocumentSegment.status: "completed", - DocumentSegment.enabled: True, - DocumentSegment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None), - } + session.execute( + update(DocumentSegment) + .where( + DocumentSegment.document_id == document_id, + DocumentSegment.dataset_id == dataset_id, + ) + .values( + status="completed", + enabled=True, + completed_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None), + ) ) - return { + result: IndexingResultDict = { "dataset_id": dataset_id, "dataset_name": dataset_name_value, "batch": batch, @@ -138,6 +140,7 @@ class IndexProcessor: "created_at": created_at_value.timestamp(), "display_status": "completed", } + return result def get_preview_output( self, @@ -150,11 +153,11 @@ class IndexProcessor: doc_language = None with session_factory.create_session() as session: if document_id: - document = session.query(Document).filter_by(id=document_id).first() + document = session.scalar(select(Document).where(Document.id == document_id).limit(1)) else: document = None - dataset = session.query(Dataset).filter_by(id=dataset_id).first() + dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1)) if not dataset: raise KnowledgeIndexNodeError(f"Dataset {dataset_id} not found.") diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py index 22ab492cbf3..7ffa9afafda 100644 --- a/api/core/rag/index_processor/processor/paragraph_index_processor.py +++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py @@ -3,21 +3,10 @@ import logging import re import uuid -from collections.abc import Mapping -from typing import Any, cast +from typing import Any, TypedDict, cast logger = logging.getLogger(__name__) -from graphon.file import File, FileTransferMethod, FileType, file_manager -from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage -from graphon.model_runtime.entities.message_entities import ( - ImagePromptMessageContent, - PromptMessage, - PromptMessageContentUnionTypes, - TextPromptMessageContent, - UserPromptMessage, -) -from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType from sqlalchemy import select from core.app.file_access import DatabaseFileAccessController @@ -32,6 +21,7 @@ from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.datasource.retrieval_service import RetrievalService from core.rag.datasource.vdb.vector_factory import Vector from core.rag.docstore.dataset_docstore import DatasetDocumentStore +from core.rag.entities import Rule from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.extract_processor import ExtractProcessor from core.rag.index_processor.constant.doc_type import DocType @@ -43,18 +33,33 @@ from core.tools.utils.text_processing_utils import remove_leading_symbols from core.workflow.file_reference import build_file_reference from extensions.ext_database import db from factories.file_factory import build_from_mapping +from graphon.file import File, FileTransferMethod, FileType, file_manager +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage +from graphon.model_runtime.entities.message_entities import ( + ImagePromptMessageContent, + PromptMessage, + PromptMessageContentUnionTypes, + TextPromptMessageContent, + UserPromptMessage, +) +from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType from libs import helper from models import UploadFile from models.account import Account from models.dataset import Dataset, DatasetProcessRule, DocumentSegment, SegmentAttachmentBinding from models.dataset import Document as DatasetDocument from services.account_service import AccountService -from services.entities.knowledge_entities.knowledge_entities import Rule from services.summary_index_service import SummaryIndexService _file_access_controller = DatabaseFileAccessController() +class ParagraphFormatPreviewDict(TypedDict): + chunk_structure: str + preview: list[dict[str, Any]] + total_segments: int + + class ParagraphIndexProcessor(BaseIndexProcessor): def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]: text_docs = ExtractProcessor.extract( @@ -266,16 +271,17 @@ class ParagraphIndexProcessor(BaseIndexProcessor): keyword = Keyword(dataset) keyword.add_texts(documents) - def format_preview(self, chunks: Any) -> Mapping[str, Any]: + def format_preview(self, chunks: Any) -> ParagraphFormatPreviewDict: if isinstance(chunks, list): preview = [] for content in chunks: preview.append({"content": content}) - return { + result: ParagraphFormatPreviewDict = { "chunk_structure": IndexStructureType.PARAGRAPH_INDEX, "preview": preview, "total_segments": len(chunks), } + return result else: raise ValueError("Chunks is not a list") @@ -603,11 +609,11 @@ class ParagraphIndexProcessor(BaseIndexProcessor): try: # Create File object directly (similar to DatasetRetrieval) file_obj = File( - id=upload_file.id, + file_id=upload_file.id, filename=upload_file.name, extension="." + upload_file.extension, mime_type=upload_file.mime_type, - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.LOCAL_FILE, remote_url=upload_file.source_url, reference=build_file_reference( diff --git a/api/core/rag/index_processor/processor/parent_child_index_processor.py b/api/core/rag/index_processor/processor/parent_child_index_processor.py index 1c5e02e9c8f..ba277d50186 100644 --- a/api/core/rag/index_processor/processor/parent_child_index_processor.py +++ b/api/core/rag/index_processor/processor/parent_child_index_processor.py @@ -3,8 +3,7 @@ import json import logging import uuid -from collections.abc import Mapping -from typing import Any +from typing import Any, TypedDict from sqlalchemy import delete, select @@ -17,6 +16,7 @@ from core.rag.data_post_processor.data_post_processor import RerankingModelDict from core.rag.datasource.retrieval_service import RetrievalService from core.rag.datasource.vdb.vector_factory import Vector from core.rag.docstore.dataset_docstore import DatasetDocumentStore +from core.rag.entities import ParentMode, Rule from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.extract_processor import ExtractProcessor from core.rag.index_processor.constant.doc_type import DocType @@ -30,12 +30,18 @@ from models import Account from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment from models.dataset import Document as DatasetDocument from services.account_service import AccountService -from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule from services.summary_index_service import SummaryIndexService logger = logging.getLogger(__name__) +class ParentChildFormatPreviewDict(TypedDict): + chunk_structure: str + parent_mode: str + preview: list[dict[str, Any]] + total_segments: int + + class ParentChildIndexProcessor(BaseIndexProcessor): def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]: text_docs = ExtractProcessor.extract( @@ -153,14 +159,12 @@ class ParentChildIndexProcessor(BaseIndexProcessor): if node_ids: # Find segments by index_node_id with session_factory.create_session() as session: - segments = ( - session.query(DocumentSegment) - .filter( + segments = session.scalars( + select(DocumentSegment).where( DocumentSegment.dataset_id == dataset.id, DocumentSegment.index_node_id.in_(node_ids), ) - .all() - ) + ).all() segment_ids = [segment.id for segment in segments] if segment_ids: SummaryIndexService.delete_summaries_for_segments(dataset, segment_ids) @@ -351,17 +355,18 @@ class ParentChildIndexProcessor(BaseIndexProcessor): if all_multimodal_documents and dataset.is_multimodal: vector.create_multimodal(all_multimodal_documents) - def format_preview(self, chunks: Any) -> Mapping[str, Any]: + def format_preview(self, chunks: Any) -> ParentChildFormatPreviewDict: parent_childs = ParentChildStructureChunk.model_validate(chunks) preview = [] for parent_child in parent_childs.parent_child_chunks: preview.append({"content": parent_child.parent_content, "child_chunks": parent_child.child_contents}) - return { + result: ParentChildFormatPreviewDict = { "chunk_structure": IndexStructureType.PARENT_CHILD_INDEX, "parent_mode": parent_childs.parent_mode, "preview": preview, "total_segments": len(parent_childs.parent_child_chunks), } + return result def generate_summary_preview( self, diff --git a/api/core/rag/index_processor/processor/qa_index_processor.py b/api/core/rag/index_processor/processor/qa_index_processor.py index 6874603a833..d3f311b08ea 100644 --- a/api/core/rag/index_processor/processor/qa_index_processor.py +++ b/api/core/rag/index_processor/processor/qa_index_processor.py @@ -4,11 +4,11 @@ import logging import re import threading import uuid -from collections.abc import Mapping -from typing import Any +from typing import Any, TypedDict import pandas as pd from flask import Flask, current_app +from sqlalchemy import select from werkzeug.datastructures import FileStorage from core.db.session_factory import session_factory @@ -19,6 +19,7 @@ from core.rag.data_post_processor.data_post_processor import RerankingModelDict from core.rag.datasource.retrieval_service import RetrievalService from core.rag.datasource.vdb.vector_factory import Vector from core.rag.docstore.dataset_docstore import DatasetDocumentStore +from core.rag.entities import Rule from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.extract_processor import ExtractProcessor from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType @@ -30,12 +31,17 @@ from libs import helper from models.account import Account from models.dataset import Dataset, DocumentSegment from models.dataset import Document as DatasetDocument -from services.entities.knowledge_entities.knowledge_entities import Rule from services.summary_index_service import SummaryIndexService logger = logging.getLogger(__name__) +class QAFormatPreviewDict(TypedDict): + chunk_structure: str + qa_preview: list[dict[str, Any]] + total_segments: int + + class QAIndexProcessor(BaseIndexProcessor): def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]: text_docs = ExtractProcessor.extract( @@ -158,14 +164,12 @@ class QAIndexProcessor(BaseIndexProcessor): if node_ids: # Find segments by index_node_id with session_factory.create_session() as session: - segments = ( - session.query(DocumentSegment) - .filter( + segments = session.scalars( + select(DocumentSegment).where( DocumentSegment.dataset_id == dataset.id, DocumentSegment.index_node_id.in_(node_ids), ) - .all() - ) + ).all() segment_ids = [segment.id for segment in segments] if segment_ids: SummaryIndexService.delete_summaries_for_segments(dataset, segment_ids) @@ -230,16 +234,17 @@ class QAIndexProcessor(BaseIndexProcessor): else: raise ValueError("Indexing technique must be high quality.") - def format_preview(self, chunks: Any) -> Mapping[str, Any]: + def format_preview(self, chunks: Any) -> QAFormatPreviewDict: qa_chunks = QAStructureChunk.model_validate(chunks) preview = [] for qa_chunk in qa_chunks.qa_chunks: preview.append({"question": qa_chunk.question, "answer": qa_chunk.answer}) - return { + result: QAFormatPreviewDict = { "chunk_structure": IndexStructureType.QA_INDEX, "qa_preview": preview, "total_segments": len(qa_chunks.qa_chunks), } + return result def generate_summary_preview( self, diff --git a/api/core/rag/models/document.py b/api/core/rag/models/document.py index 087736d0b0a..4ebf0959042 100644 --- a/api/core/rag/models/document.py +++ b/api/core/rag/models/document.py @@ -2,9 +2,10 @@ from abc import ABC, abstractmethod from collections.abc import Sequence from typing import Any -from graphon.file import File from pydantic import BaseModel, Field +from graphon.file import File + class ChildDocument(BaseModel): """Class for storing a piece of text and associated metadata.""" diff --git a/api/core/rag/rerank/entity/weight.py b/api/core/rag/rerank/entity/weight.py index 6dbbad2f8da..54392a03236 100644 --- a/api/core/rag/rerank/entity/weight.py +++ b/api/core/rag/rerank/entity/weight.py @@ -1,16 +1,6 @@ from pydantic import BaseModel - -class VectorSetting(BaseModel): - vector_weight: float - - embedding_provider_name: str - - embedding_model_name: str - - -class KeywordSetting(BaseModel): - keyword_weight: float +from core.rag.entities import KeywordSetting, VectorSetting class Weights(BaseModel): diff --git a/api/core/rag/rerank/rerank_model.py b/api/core/rag/rerank/rerank_model.py index 8283be19f9b..bce08f998f3 100644 --- a/api/core/rag/rerank/rerank_model.py +++ b/api/core/rag/rerank/rerank_model.py @@ -1,8 +1,5 @@ import base64 -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.entities.rerank_entities import RerankResult - from core.model_manager import ModelInstance, ModelManager from core.rag.index_processor.constant.doc_type import DocType from core.rag.index_processor.constant.query_type import QueryType @@ -10,6 +7,8 @@ from core.rag.models.document import Document from core.rag.rerank.rerank_base import BaseRerankRunner from extensions.ext_database import db from extensions.ext_storage import storage +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult from models.model import UploadFile @@ -123,7 +122,7 @@ class RerankModelRunner(BaseRerankRunner): :param query_type: query type :return: rerank result """ - docs = [] + docs: list[MultimodalRerankInput] = [] doc_ids = set() unique_documents = [] for document in documents: @@ -138,26 +137,28 @@ class RerankModelRunner(BaseRerankRunner): if upload_file: blob = storage.load_once(upload_file.key) document_file_base64 = base64.b64encode(blob).decode() - document_file_dict = { - "content": document_file_base64, - "content_type": document.metadata["doc_type"], - } - docs.append(document_file_dict) + docs.append( + MultimodalRerankInput( + content=document_file_base64, + content_type=document.metadata["doc_type"], + ) + ) else: - document_text_dict = { - "content": document.page_content, - "content_type": document.metadata.get("doc_type") or DocType.TEXT, - } - docs.append(document_text_dict) + docs.append( + MultimodalRerankInput( + content=document.page_content, + content_type=document.metadata.get("doc_type") or DocType.TEXT, + ) + ) doc_ids.add(document.metadata["doc_id"]) unique_documents.append(document) elif document.provider == "external": if document not in unique_documents: docs.append( - { - "content": document.page_content, - "content_type": document.metadata.get("doc_type") or DocType.TEXT, - } + MultimodalRerankInput( + content=document.page_content, + content_type=document.metadata.get("doc_type") or DocType.TEXT, + ) ) unique_documents.append(document) @@ -171,12 +172,12 @@ class RerankModelRunner(BaseRerankRunner): if upload_file: blob = storage.load_once(upload_file.key) file_query = base64.b64encode(blob).decode() - file_query_dict = { - "content": file_query, - "content_type": DocType.IMAGE, - } + file_query_input = MultimodalRerankInput( + content=file_query, + content_type=DocType.IMAGE, + ) rerank_result = self.rerank_model_instance.invoke_multimodal_rerank( - query=file_query_dict, docs=docs, score_threshold=score_threshold, top_n=top_n + query=file_query_input, docs=docs, score_threshold=score_threshold, top_n=top_n ) return rerank_result, unique_documents else: diff --git a/api/core/rag/rerank/weight_rerank.py b/api/core/rag/rerank/weight_rerank.py index 49123e13d05..d0732b269af 100644 --- a/api/core/rag/rerank/weight_rerank.py +++ b/api/core/rag/rerank/weight_rerank.py @@ -2,7 +2,6 @@ import math from collections import Counter import numpy as np -from graphon.model_runtime.entities.model_entities import ModelType from core.model_manager import ModelManager from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler @@ -12,6 +11,7 @@ from core.rag.index_processor.constant.query_type import QueryType from core.rag.models.document import Document from core.rag.rerank.entity.weight import VectorSetting, Weights from core.rag.rerank.rerank_base import BaseRerankRunner +from graphon.model_runtime.entities.model_entities import ModelType class WeightRerankRunner(BaseRerankRunner): diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 593e1f1420a..5631b3a921f 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -9,13 +9,8 @@ from collections.abc import Generator, Mapping from typing import Any, Union, cast from flask import Flask, current_app -from graphon.file import File, FileTransferMethod, FileType -from graphon.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMUsage -from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool -from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from sqlalchemy import and_, func, literal, or_, select -from sqlalchemy.orm import Session +from sqlalchemy import and_, func, literal, or_, select, update +from sqlalchemy.orm import sessionmaker from core.app.app_config.entities import ( DatasetEntity, @@ -39,9 +34,7 @@ from core.prompt.simple_prompt_transform import ModelMode from core.rag.data_post_processor.data_post_processor import DataPostProcessor, RerankingModelDict, WeightsDict from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler from core.rag.datasource.retrieval_service import DefaultRetrievalModelDict, RetrievalService -from core.rag.entities.citation_metadata import RetrievalSourceMetadata -from core.rag.entities.context_entities import DocumentContext -from core.rag.entities.metadata_entities import Condition, MetadataCondition +from core.rag.entities import Condition, DocumentContext, RetrievalSourceMetadata from core.rag.index_processor.constant.doc_type import DocType from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.index_processor.constant.query_type import QueryType @@ -71,6 +64,11 @@ from core.workflow.nodes.knowledge_retrieval.retrieval import ( ) from extensions.ext_database import db from extensions.ext_redis import redis_client +from graphon.file import File, FileTransferMethod, FileType +from graphon.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMUsage +from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool +from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType +from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel from libs.helper import parse_uuid_str_or_none from libs.json_in_md_parser import parse_and_check_json_markdown from models import UploadFile @@ -278,8 +276,8 @@ class DatasetRetrieval: document_ids = [i.segment.document_id for i in records] with session_factory.create_session() as session: - datasets = session.query(Dataset).where(Dataset.id.in_(dataset_ids)).all() - documents = session.query(DatasetDocument).where(DatasetDocument.id.in_(document_ids)).all() + datasets = session.scalars(select(Dataset).where(Dataset.id.in_(dataset_ids))).all() + documents = session.scalars(select(DatasetDocument).where(DatasetDocument.id.in_(document_ids))).all() dataset_map = {i.id: i for i in datasets} document_map = {i.id: i for i in documents} @@ -519,11 +517,11 @@ class DatasetRetrieval: if attachments_with_bindings: for _, upload_file in attachments_with_bindings: attachment_info = File( - id=upload_file.id, + file_id=upload_file.id, filename=upload_file.name, extension="." + upload_file.extension, mime_type=upload_file.mime_type, - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.LOCAL_FILE, remote_url=upload_file.source_url, reference=build_file_reference( @@ -604,7 +602,7 @@ class DatasetRetrieval: planning_strategy: PlanningStrategy, message_id: str | None = None, metadata_filter_document_ids: dict[str, list[str]] | None = None, - metadata_condition: MetadataCondition | None = None, + metadata_condition: MetadataFilteringCondition | None = None, ): tools = [] for dataset in available_datasets: @@ -743,7 +741,7 @@ class DatasetRetrieval: reranking_enable: bool = True, message_id: str | None = None, metadata_filter_document_ids: dict[str, list[str]] | None = None, - metadata_condition: MetadataCondition | None = None, + metadata_condition: MetadataFilteringCondition | None = None, attachment_ids: list[str] | None = None, ): if not available_datasets: @@ -877,7 +875,11 @@ class DatasetRetrieval: return retrieval_resource_list def _on_retrieval_end( - self, flask_app: Flask, documents: list[Document], message_id: str | None = None, timer: dict | None = None + self, + flask_app: Flask, + documents: list[Document], + message_id: str | None = None, + timer: dict[str, Any] | None = None, ): """Handle retrieval end.""" with flask_app.app_context(): @@ -886,7 +888,7 @@ class DatasetRetrieval: self._send_trace_task(message_id, documents, timer) return - with Session(db.engine) as session: + with sessionmaker(bind=db.engine).begin() as session: # Collect all document_ids and batch fetch DatasetDocuments document_ids = { doc.metadata["document_id"] @@ -973,15 +975,16 @@ class DatasetRetrieval: # Batch update hit_count for all segments if segment_ids_to_update: - session.query(DocumentSegment).where(DocumentSegment.id.in_(segment_ids_to_update)).update( - {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, - synchronize_session=False, + session.execute( + update(DocumentSegment) + .where(DocumentSegment.id.in_(segment_ids_to_update)) + .values(hit_count=DocumentSegment.hit_count + 1) + .execution_options(synchronize_session=False) ) - session.commit() self._send_trace_task(message_id, documents, timer) - def _send_trace_task(self, message_id: str | None, documents: list[Document], timer: dict | None): + def _send_trace_task(self, message_id: str | None, documents: list[Document], timer: dict[str, Any] | None): """Send trace task if trace manager is available.""" trace_manager: TraceQueueManager | None = ( self.application_generate_entity.trace_manager if self.application_generate_entity else None @@ -1063,7 +1066,7 @@ class DatasetRetrieval: top_k: int, all_documents: list[Document], document_ids_filter: list[str] | None = None, - metadata_condition: MetadataCondition | None = None, + metadata_condition: MetadataFilteringCondition | None = None, attachment_ids: list[str] | None = None, ): with flask_app.app_context(): @@ -1143,7 +1146,7 @@ class DatasetRetrieval: invoke_from: InvokeFrom, hit_callback: DatasetIndexToolCallbackHandler, user_id: str, - inputs: dict, + inputs: dict[str, Any], ) -> list[DatasetRetrieverBaseTool] | None: """ A dataset tool is a tool that can be used to retrieve information from a dataset @@ -1338,8 +1341,8 @@ class DatasetRetrieval: metadata_filtering_mode: str, metadata_model_config: ModelConfig, metadata_filtering_conditions: MetadataFilteringCondition | None, - inputs: dict, - ) -> tuple[dict[str, list[str]] | None, MetadataCondition | None]: + inputs: dict[str, Any], + ) -> tuple[dict[str, list[str]] | None, MetadataFilteringCondition | None]: document_query = select(DatasetDocument).where( DatasetDocument.dataset_id.in_(dataset_ids), DatasetDocument.indexing_status == "completed", @@ -1371,7 +1374,7 @@ class DatasetRetrieval: value=filter.get("value"), ) ) - metadata_condition = MetadataCondition( + metadata_condition = MetadataFilteringCondition( logical_operator=metadata_filtering_conditions.logical_operator if metadata_filtering_conditions else "or", # type: ignore @@ -1400,7 +1403,7 @@ class DatasetRetrieval: expected_value, filters, ) - metadata_condition = MetadataCondition( + metadata_condition = MetadataFilteringCondition( logical_operator=metadata_filtering_conditions.logical_operator, conditions=conditions, ) @@ -1418,7 +1421,7 @@ class DatasetRetrieval: metadata_filter_document_ids[document.dataset_id].append(document.id) # type: ignore return metadata_filter_document_ids, metadata_condition - def _replace_metadata_filter_value(self, text: str, inputs: dict) -> str: + def _replace_metadata_filter_value(self, text: str, inputs: dict[str, Any]) -> str: if not inputs: return text @@ -1723,7 +1726,7 @@ class DatasetRetrieval: self, flask_app: Flask, available_datasets: list[Dataset], - metadata_condition: MetadataCondition | None, + metadata_condition: MetadataFilteringCondition | None, metadata_filter_document_ids: dict[str, list[str]] | None, all_documents: list[Document], tenant_id: str, @@ -1825,7 +1828,7 @@ class DatasetRetrieval: def _get_available_datasets(self, tenant_id: str, dataset_ids: list[str]) -> list[Dataset]: with session_factory.create_session() as session: subquery = ( - session.query(DocumentModel.dataset_id, func.count(DocumentModel.id).label("available_document_count")) + select(DocumentModel.dataset_id, func.count(DocumentModel.id).label("available_document_count")) .where( DocumentModel.indexing_status == "completed", DocumentModel.enabled == True, @@ -1837,13 +1840,12 @@ class DatasetRetrieval: .subquery() ) - results = ( - session.query(Dataset) + results = session.scalars( + select(Dataset) .outerjoin(subquery, Dataset.id == subquery.c.dataset_id) .where(Dataset.tenant_id == tenant_id, Dataset.id.in_(dataset_ids)) .where((subquery.c.available_document_count > 0) | (Dataset.provider == "external")) - .all() - ) + ).all() available_datasets = [] for dataset in results: diff --git a/api/core/rag/retrieval/output_parser/react_output.py b/api/core/rag/retrieval/output_parser/react_output.py index 9a14d417164..29abae42803 100644 --- a/api/core/rag/retrieval/output_parser/react_output.py +++ b/api/core/rag/retrieval/output_parser/react_output.py @@ -1,7 +1,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import NamedTuple, Union +from typing import Any, NamedTuple, Union @dataclass @@ -10,7 +10,7 @@ class ReactAction: tool: str """The name of the Tool to execute.""" - tool_input: Union[str, dict] + tool_input: Union[str, dict[str, Any]] """The input to pass in to the Tool.""" log: str """Additional information to log about the action.""" @@ -19,7 +19,7 @@ class ReactAction: class ReactFinish(NamedTuple): """The final return value of an ReactFinish.""" - return_values: dict + return_values: dict[str, Any] """Dictionary of return values.""" log: str """Additional information to log about the return value""" diff --git a/api/core/rag/retrieval/router/multi_dataset_function_call_router.py b/api/core/rag/retrieval/router/multi_dataset_function_call_router.py index dce7b6226ce..426d1b67dcd 100644 --- a/api/core/rag/retrieval/router/multi_dataset_function_call_router.py +++ b/api/core/rag/retrieval/router/multi_dataset_function_call_router.py @@ -1,10 +1,9 @@ from typing import Union -from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage -from graphon.model_runtime.entities.message_entities import PromptMessageTool, SystemPromptMessage, UserPromptMessage - from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.model_manager import ModelInstance +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage +from graphon.model_runtime.entities.message_entities import PromptMessageTool, SystemPromptMessage, UserPromptMessage class FunctionCallMultiDatasetRouter: @@ -29,7 +28,7 @@ class FunctionCallMultiDatasetRouter: SystemPromptMessage(content="You are a helpful AI assistant."), UserPromptMessage(content=query), ] - result: LLMResult = model_instance.invoke_llm( + result: LLMResult = model_instance.invoke_llm( # pyright: ignore[reportCallIssue, reportArgumentType] prompt_messages=prompt_messages, tools=dataset_tools, stream=False, diff --git a/api/core/rag/retrieval/router/multi_dataset_react_route.py b/api/core/rag/retrieval/router/multi_dataset_react_route.py index dd280cdf6a7..21a9d04f7f2 100644 --- a/api/core/rag/retrieval/router/multi_dataset_react_route.py +++ b/api/core/rag/retrieval/router/multi_dataset_react_route.py @@ -1,9 +1,5 @@ from collections.abc import Generator, Sequence -from typing import Union - -from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage -from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool -from graphon.model_runtime.entities.model_entities import ModelType +from typing import Any, Union from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.app.llm import deduct_llm_quota @@ -12,6 +8,9 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate from core.rag.retrieval.output_parser.react_output import ReactAction from core.rag.retrieval.output_parser.structured_chat import StructuredChatOutputParser +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage +from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool +from graphon.model_runtime.entities.model_entities import ModelType PREFIX = """Respond to the human as helpfully and accurately as possible. You have access to the following tools:""" @@ -139,7 +138,7 @@ class ReactMultiDatasetRouter: def _invoke_llm( self, - completion_param: dict, + completion_param: dict[str, Any], model_instance: ModelInstance, prompt_messages: list[PromptMessage], stop: list[str], diff --git a/api/core/rag/splitter/fixed_text_splitter.py b/api/core/rag/splitter/fixed_text_splitter.py index 3383c7f3bd7..52c9a02f97f 100644 --- a/api/core/rag/splitter/fixed_text_splitter.py +++ b/api/core/rag/splitter/fixed_text_splitter.py @@ -4,13 +4,12 @@ from __future__ import annotations import codecs import re -from collections.abc import Collection +from collections.abc import Set as AbstractSet from typing import Any, Literal -from graphon.model_runtime.model_providers.__base.tokenizers.gpt2_tokenizer import GPT2Tokenizer - from core.model_manager import ModelInstance from core.rag.splitter.text_splitter import RecursiveCharacterTextSplitter +from graphon.model_runtime.model_providers.base.tokenizers.gpt2_tokenizer import GPT2Tokenizer class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter): @@ -22,8 +21,8 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter): def from_encoder[T: EnhanceRecursiveCharacterTextSplitter]( cls: type[T], embedding_model_instance: ModelInstance | None, - allowed_special: Literal["all"] | set[str] = set(), - disallowed_special: Literal["all"] | Collection[str] = "all", + allowed_special: Literal["all"] | AbstractSet[str] = frozenset(), + disallowed_special: Literal["all"] | AbstractSet[str] = "all", **kwargs: Any, ) -> T: def _token_encoder(texts: list[str]) -> list[int]: @@ -41,6 +40,7 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter): return [len(text) for text in texts] + _ = _token_encoder # kept for future token-length wiring return cls(length_function=_character_encoder, **kwargs) diff --git a/api/core/rag/splitter/text_splitter.py b/api/core/rag/splitter/text_splitter.py index 8977611f938..a8d9013fbc0 100644 --- a/api/core/rag/splitter/text_splitter.py +++ b/api/core/rag/splitter/text_splitter.py @@ -4,7 +4,8 @@ import copy import logging import re from abc import ABC, abstractmethod -from collections.abc import Callable, Collection, Iterable, Sequence, Set +from collections.abc import Callable, Iterable, Sequence +from collections.abc import Set as AbstractSet from dataclasses import dataclass from typing import Any, Literal @@ -63,7 +64,7 @@ class TextSplitter(BaseDocumentTransformer, ABC): def split_text(self, text: str) -> list[str]: """Split text into multiple components.""" - def create_documents(self, texts: list[str], metadatas: list[dict] | None = None) -> list[Document]: + def create_documents(self, texts: list[str], metadatas: list[dict[str, Any]] | None = None) -> list[Document]: """Create documents from a list of texts.""" _metadatas = metadatas or [{}] * len(texts) documents = [] @@ -187,8 +188,8 @@ class TokenTextSplitter(TextSplitter): self, encoding_name: str = "gpt2", model_name: str | None = None, - allowed_special: Literal["all"] | Set[str] = set(), - disallowed_special: Literal["all"] | Collection[str] = "all", + allowed_special: Literal["all"] | AbstractSet[str] = frozenset(), + disallowed_special: Literal["all"] | AbstractSet[str] = "all", **kwargs: Any, ): """Create a new TextSplitter.""" @@ -207,8 +208,8 @@ class TokenTextSplitter(TextSplitter): else: enc = tiktoken.get_encoding(encoding_name) self._tokenizer = enc - self._allowed_special = allowed_special - self._disallowed_special = disallowed_special + self._allowed_special: Literal["all"] | AbstractSet[str] = allowed_special + self._disallowed_special: Literal["all"] | AbstractSet[str] = disallowed_special def split_text(self, text: str) -> list[str]: def _encode(_text: str) -> list[int]: diff --git a/api/core/rag/summary_index/summary_index.py b/api/core/rag/summary_index/summary_index.py index 6f120bd4711..bff5f85decb 100644 --- a/api/core/rag/summary_index/summary_index.py +++ b/api/core/rag/summary_index/summary_index.py @@ -1,6 +1,8 @@ import concurrent.futures import logging +from sqlalchemy import select + from core.db.session_factory import session_factory from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict @@ -21,7 +23,7 @@ class SummaryIndex: ) -> None: if is_preview: with session_factory.create_session() as session: - dataset = session.query(Dataset).filter_by(id=dataset_id).first() + dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1)) if not dataset or dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY: return @@ -34,32 +36,31 @@ class SummaryIndex: if not document_id: return - document = session.query(Document).filter_by(id=document_id).first() + document = session.scalar(select(Document).where(Document.id == document_id).limit(1)) # Skip qa_model documents if document is None or document.doc_form == "qa_model": return - query = session.query(DocumentSegment).filter_by( - dataset_id=dataset_id, - document_id=document_id, - status="completed", - enabled=True, - ) - segments = query.all() + segments = session.scalars( + select(DocumentSegment).where( + DocumentSegment.dataset_id == dataset_id, + DocumentSegment.document_id == document_id, + DocumentSegment.status == "completed", + DocumentSegment.enabled == True, + ) + ).all() segment_ids = [segment.id for segment in segments] if not segment_ids: return - existing_summaries = ( - session.query(DocumentSegmentSummary) - .filter( + existing_summaries = session.scalars( + select(DocumentSegmentSummary).where( DocumentSegmentSummary.chunk_id.in_(segment_ids), DocumentSegmentSummary.dataset_id == dataset_id, DocumentSegmentSummary.status == "completed", ) - .all() - ) + ).all() completed_summary_segment_ids = {i.chunk_id for i in existing_summaries} # Preview mode should process segments that are MISSING completed summaries pending_segment_ids = [sid for sid in segment_ids if sid not in completed_summary_segment_ids] @@ -73,7 +74,7 @@ class SummaryIndex: def process_segment(segment_id: str) -> None: """Process a single segment in a thread with a fresh DB session.""" with session_factory.create_session() as session: - segment = session.query(DocumentSegment).filter_by(id=segment_id).first() + segment = session.scalar(select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1)) if segment is None: return try: diff --git a/api/core/repositories/celery_workflow_execution_repository.py b/api/core/repositories/celery_workflow_execution_repository.py index b07c63fdf0e..e87d1cd6b22 100644 --- a/api/core/repositories/celery_workflow_execution_repository.py +++ b/api/core/repositories/celery_workflow_execution_repository.py @@ -7,11 +7,11 @@ providing improved performance by offloading database operations to background w import logging -from graphon.entities import WorkflowExecution from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker from core.repositories.factory import WorkflowExecutionRepository +from graphon.entities import WorkflowExecution from libs.helper import extract_tenant_id from models import Account, CreatorUserRole, EndUser from models.enums import WorkflowRunTriggeredFrom diff --git a/api/core/repositories/celery_workflow_node_execution_repository.py b/api/core/repositories/celery_workflow_node_execution_repository.py index cdb3af01a8a..24515633175 100644 --- a/api/core/repositories/celery_workflow_node_execution_repository.py +++ b/api/core/repositories/celery_workflow_node_execution_repository.py @@ -8,7 +8,6 @@ providing improved performance by offloading database operations to background w import logging from collections.abc import Sequence -from graphon.entities import WorkflowNodeExecution from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker @@ -16,6 +15,7 @@ from core.repositories.factory import ( OrderConfig, WorkflowNodeExecutionRepository, ) +from graphon.entities import WorkflowNodeExecution from libs.helper import extract_tenant_id from models import Account, CreatorUserRole, EndUser from models.workflow import WorkflowNodeExecutionTriggeredFrom diff --git a/api/core/repositories/factory.py b/api/core/repositories/factory.py index ce3ad15759f..4e83e707997 100644 --- a/api/core/repositories/factory.py +++ b/api/core/repositories/factory.py @@ -9,11 +9,11 @@ from collections.abc import Sequence from dataclasses import dataclass from typing import Literal, Protocol -from graphon.entities import WorkflowExecution, WorkflowNodeExecution from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker from configs import dify_config +from graphon.entities import WorkflowExecution, WorkflowNodeExecution from libs.module_loading import import_string from models import Account, EndUser from models.enums import WorkflowRunTriggeredFrom diff --git a/api/core/repositories/human_input_repository.py b/api/core/repositories/human_input_repository.py index 72d93941498..740d727e268 100644 --- a/api/core/repositories/human_input_repository.py +++ b/api/core/repositories/human_input_repository.py @@ -4,13 +4,11 @@ from collections.abc import Mapping, Sequence from datetime import datetime from typing import Any, Protocol -from graphon.nodes.human_input.entities import FormDefinition, HumanInputNodeData -from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from sqlalchemy import select from sqlalchemy.orm import Session, selectinload from core.db.session_factory import session_factory -from core.workflow.human_input_compat import ( +from core.workflow.human_input_adapter import ( BoundRecipient, DeliveryChannelConfig, EmailDeliveryMethod, @@ -19,6 +17,8 @@ from core.workflow.human_input_compat import ( InteractiveSurfaceDeliveryMethod, is_human_input_webapp_enabled, ) +from graphon.nodes.human_input.entities import FormDefinition, HumanInputNodeData +from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from libs.datetime_utils import naive_utc_now from libs.uuid_utils import uuidv7 from models.account import Account, TenantAccountJoin diff --git a/api/core/repositories/sqlalchemy_workflow_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_execution_repository.py index d74cc8f231e..6be39023171 100644 --- a/api/core/repositories/sqlalchemy_workflow_execution_repository.py +++ b/api/core/repositories/sqlalchemy_workflow_execution_repository.py @@ -5,13 +5,13 @@ SQLAlchemy implementation of the WorkflowExecutionRepository. import json import logging -from graphon.entities import WorkflowExecution -from graphon.enums import WorkflowExecutionStatus, WorkflowType -from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker from core.repositories.factory import WorkflowExecutionRepository +from graphon.entities import WorkflowExecution +from graphon.enums import WorkflowExecutionStatus, WorkflowType +from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from libs.helper import extract_tenant_id from models import ( Account, diff --git a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py index 13e885672a1..b036687bc99 100644 --- a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py +++ b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py @@ -10,10 +10,6 @@ from concurrent.futures import ThreadPoolExecutor from typing import Any import psycopg2.errors -from graphon.entities import WorkflowNodeExecution -from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus -from graphon.model_runtime.utils.encoders import jsonable_encoder -from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from sqlalchemy import UnaryExpression, asc, desc, select from sqlalchemy.engine import Engine from sqlalchemy.exc import IntegrityError @@ -23,6 +19,10 @@ from tenacity import before_sleep_log, retry, retry_if_exception, stop_after_att from configs import dify_config from core.repositories.factory import OrderConfig, WorkflowNodeExecutionRepository from extensions.ext_storage import storage +from graphon.entities import WorkflowNodeExecution +from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from graphon.model_runtime.utils.encoders import jsonable_encoder +from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from libs.helper import extract_tenant_id from libs.uuid_utils import uuidv7 from models import ( diff --git a/api/core/schemas/resolver.py b/api/core/schemas/resolver.py index 6e26664ac25..e267c1abd9a 100644 --- a/api/core/schemas/resolver.py +++ b/api/core/schemas/resolver.py @@ -254,7 +254,7 @@ def resolve_dify_schema_refs( return resolver.resolve(schema) -def _remove_metadata_fields(schema: dict) -> dict: +def _remove_metadata_fields(schema: dict[str, Any]) -> dict[str, Any]: """ Remove metadata fields from schema that shouldn't be included in resolved output diff --git a/api/core/telemetry/gateway.py b/api/core/telemetry/gateway.py index 7b013d05638..812edeeb14a 100644 --- a/api/core/telemetry/gateway.py +++ b/api/core/telemetry/gateway.py @@ -89,7 +89,7 @@ def _get_case_routing() -> dict[TelemetryCase, CaseRoute]: return _case_routing -def __getattr__(name: str) -> dict: +def __getattr__(name: str) -> Any: """Lazy module-level access to routing tables.""" if name == "CASE_ROUTING": return _get_case_routing() diff --git a/api/core/tools/__base/tool.py b/api/core/tools/__base/tool.py index 7bb2cdb8764..ab0f73a9a22 100644 --- a/api/core/tools/__base/tool.py +++ b/api/core/tools/__base/tool.py @@ -198,7 +198,7 @@ class Tool(ABC): message=ToolInvokeMessage.TextMessage(text=text), ) - def create_blob_message(self, blob: bytes, meta: dict | None = None) -> ToolInvokeMessage: + def create_blob_message(self, blob: bytes, meta: dict[str, Any] | None = None) -> ToolInvokeMessage: """ create a blob message @@ -212,7 +212,7 @@ class Tool(ABC): meta=meta, ) - def create_json_message(self, object: dict, suppress_output: bool = False) -> ToolInvokeMessage: + def create_json_message(self, object: dict[str, Any], suppress_output: bool = False) -> ToolInvokeMessage: """ create a json message """ diff --git a/api/core/tools/builtin_tool/providers/audio/tools/asr.py b/api/core/tools/builtin_tool/providers/audio/tools/asr.py index e5390743036..95660ab93b4 100644 --- a/api/core/tools/builtin_tool/providers/audio/tools/asr.py +++ b/api/core/tools/builtin_tool/providers/audio/tools/asr.py @@ -2,15 +2,14 @@ import io from collections.abc import Generator from typing import Any -from graphon.file import FileType -from graphon.file.file_manager import download -from graphon.model_runtime.entities.model_entities import ModelType - from core.model_manager import ModelManager from core.plugin.entities.parameters import PluginParameterOption from core.tools.builtin_tool.tool import BuiltinTool from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter +from graphon.file import FileType +from graphon.file.file_manager import download +from graphon.model_runtime.entities.model_entities import ModelType from services.model_provider_service import ModelProviderService diff --git a/api/core/tools/builtin_tool/providers/audio/tools/tts.py b/api/core/tools/builtin_tool/providers/audio/tools/tts.py index f49c669fe09..ac3820f1aba 100644 --- a/api/core/tools/builtin_tool/providers/audio/tools/tts.py +++ b/api/core/tools/builtin_tool/providers/audio/tools/tts.py @@ -2,13 +2,12 @@ import io from collections.abc import Generator from typing import Any -from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType - from core.model_manager import ModelManager from core.plugin.entities.parameters import PluginParameterOption from core.tools.builtin_tool.tool import BuiltinTool from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter +from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType from services.model_provider_service import ModelProviderService diff --git a/api/core/tools/builtin_tool/tool.py b/api/core/tools/builtin_tool/tool.py index 14af63a962e..d41503e1e61 100644 --- a/api/core/tools/builtin_tool/tool.py +++ b/api/core/tools/builtin_tool/tool.py @@ -1,12 +1,11 @@ from __future__ import annotations -from graphon.model_runtime.entities.llm_entities import LLMResult -from graphon.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage - from core.tools.__base.tool import Tool from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.tool_entities import ToolProviderType from core.tools.utils.model_invocation_utils import ModelInvocationUtils +from graphon.model_runtime.entities.llm_entities import LLMResult +from graphon.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage _SUMMARY_PROMPT = """You are a professional language researcher, you are interested in the language and you can quickly aimed at the main point of an webpage and reproduce it in your own words but diff --git a/api/core/tools/custom_tool/tool.py b/api/core/tools/custom_tool/tool.py index 0a2c37c5632..168e5f4493a 100644 --- a/api/core/tools/custom_tool/tool.py +++ b/api/core/tools/custom_tool/tool.py @@ -6,7 +6,6 @@ from typing import Any, Union from urllib.parse import urlencode import httpx -from graphon.file.file_manager import download from core.helper import ssrf_proxy from core.tools.__base.tool import Tool @@ -14,6 +13,7 @@ from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType from core.tools.errors import ToolInvokeError, ToolParameterValidationError, ToolProviderCredentialValidationError +from graphon.file.file_manager import download API_TOOL_DEFAULT_TIMEOUT = ( int(getenv("API_TOOL_DEFAULT_CONNECT_TIMEOUT", "10")), diff --git a/api/core/tools/entities/api_entities.py b/api/core/tools/entities/api_entities.py index d5d3d1b1d95..42a88c00039 100644 --- a/api/core/tools/entities/api_entities.py +++ b/api/core/tools/entities/api_entities.py @@ -2,7 +2,6 @@ from collections.abc import Mapping from datetime import datetime from typing import Any, Literal -from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, Field, field_validator from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration @@ -10,6 +9,7 @@ from core.plugin.entities.plugin_daemon import CredentialType from core.tools.__base.tool import ToolParameter from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolProviderType +from graphon.model_runtime.utils.encoders import jsonable_encoder class ToolApiEntity(BaseModel): @@ -75,22 +75,27 @@ class ToolProviderApiEntity(BaseModel): parameter.pop("input_schema", None) # ------------- optional_fields = self.optional_field("server_url", self.server_url) - if self.type == ToolProviderType.MCP: - optional_fields.update(self.optional_field("updated_at", self.updated_at)) - optional_fields.update(self.optional_field("server_identifier", self.server_identifier)) - optional_fields.update( - self.optional_field( - "configuration", self.configuration.model_dump() if self.configuration else MCPConfiguration() + match self.type: + case ToolProviderType.MCP: + optional_fields.update(self.optional_field("updated_at", self.updated_at)) + optional_fields.update(self.optional_field("server_identifier", self.server_identifier)) + optional_fields.update( + self.optional_field( + "configuration", self.configuration.model_dump() if self.configuration else MCPConfiguration() + ) ) - ) - optional_fields.update( - self.optional_field("authentication", self.authentication.model_dump() if self.authentication else None) - ) - optional_fields.update(self.optional_field("is_dynamic_registration", self.is_dynamic_registration)) - optional_fields.update(self.optional_field("masked_headers", self.masked_headers)) - optional_fields.update(self.optional_field("original_headers", self.original_headers)) - elif self.type == ToolProviderType.WORKFLOW: - optional_fields.update(self.optional_field("workflow_app_id", self.workflow_app_id)) + optional_fields.update( + self.optional_field( + "authentication", self.authentication.model_dump() if self.authentication else None + ) + ) + optional_fields.update(self.optional_field("is_dynamic_registration", self.is_dynamic_registration)) + optional_fields.update(self.optional_field("masked_headers", self.masked_headers)) + optional_fields.update(self.optional_field("original_headers", self.original_headers)) + case ToolProviderType.WORKFLOW: + optional_fields.update(self.optional_field("workflow_app_id", self.workflow_app_id)) + case _: + pass return { "id": self.id, "author": self.author, diff --git a/api/core/tools/entities/common_entities.py b/api/core/tools/entities/common_entities.py index 21d310bbb9f..83a042ed631 100644 --- a/api/core/tools/entities/common_entities.py +++ b/api/core/tools/entities/common_entities.py @@ -1,6 +1,15 @@ +from typing import TypedDict + from pydantic import BaseModel, Field, model_validator +class I18nObjectDict(TypedDict): + zh_Hans: str | None + en_US: str + pt_BR: str | None + ja_JP: str | None + + class I18nObject(BaseModel): """ Model class for i18n object. @@ -18,5 +27,11 @@ class I18nObject(BaseModel): self.ja_JP = self.ja_JP or self.en_US return self - def to_dict(self): - return {"zh_Hans": self.zh_Hans, "en_US": self.en_US, "pt_BR": self.pt_BR, "ja_JP": self.ja_JP} + def to_dict(self) -> I18nObjectDict: + result: I18nObjectDict = { + "zh_Hans": self.zh_Hans, + "en_US": self.en_US, + "pt_BR": self.pt_BR, + "ja_JP": self.ja_JP, + } + return result diff --git a/api/core/tools/entities/tool_bundle.py b/api/core/tools/entities/tool_bundle.py index 10710c43767..4e07b7157a8 100644 --- a/api/core/tools/entities/tool_bundle.py +++ b/api/core/tools/entities/tool_bundle.py @@ -1,4 +1,5 @@ from collections.abc import Mapping +from typing import Any from pydantic import BaseModel, Field @@ -26,6 +27,6 @@ class ApiToolBundle(BaseModel): # icon icon: str | None = None # openapi operation - openapi: dict + openapi: dict[str, Any] # output schema output_schema: Mapping[str, object] = Field(default_factory=dict) diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py index 96268d029ed..0c77693dde4 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -6,9 +6,20 @@ from collections.abc import Mapping from enum import StrEnum, auto from typing import Any, Union -from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_serializer, field_validator, model_validator +from pydantic import ( + BaseModel, + ConfigDict, + Field, + TypeAdapter, + ValidationInfo, + field_serializer, + field_validator, + model_validator, +) +from typing_extensions import TypedDict from core.entities.provider_entities import ProviderConfig +from core.plugin.entities import OAuthSchema from core.plugin.entities.parameters import ( MCPServerParameterType, PluginParameter, @@ -18,11 +29,19 @@ from core.plugin.entities.parameters import ( cast_parameter_value, init_frontend_parameter, ) -from core.rag.entities.citation_metadata import RetrievalSourceMetadata +from core.rag.entities import RetrievalSourceMetadata from core.tools.entities.common_entities import I18nObject from core.tools.entities.constants import TOOL_SELECTOR_MODEL_IDENTITY +class EmojiIconDict(TypedDict): + background: str + content: str + + +emoji_icon_adapter: TypeAdapter[EmojiIconDict] = TypeAdapter(EmojiIconDict) + + class ToolLabelEnum(StrEnum): SEARCH = "search" IMAGE = "image" @@ -130,7 +149,7 @@ class ToolInvokeMessage(BaseModel): text: str class JsonMessage(BaseModel): - json_object: dict | list + json_object: dict[str, Any] | list[Any] suppress_output: bool = Field(default=False, description="Whether to suppress JSON output in result string") class BlobMessage(BaseModel): @@ -318,7 +337,7 @@ class ToolParameter(PluginParameter): form: ToolParameterForm = Field(..., description="The form of the parameter, schema/form/llm") llm_description: str | None = None # MCP object and array type parameters use this field to store the schema - input_schema: dict | None = None + input_schema: dict[str, Any] | None = None @classmethod def get_simple_instance( @@ -410,15 +429,6 @@ class ToolEntity(BaseModel): return value or {} -class OAuthSchema(BaseModel): - client_schema: list[ProviderConfig] = Field( - default_factory=list[ProviderConfig], description="The schema of the OAuth client" - ) - credentials_schema: list[ProviderConfig] = Field( - default_factory=list[ProviderConfig], description="The schema of the OAuth credentials" - ) - - class ToolProviderEntity(BaseModel): identity: ToolProviderIdentity plugin_id: str | None = None @@ -440,6 +450,12 @@ class WorkflowToolParameterConfiguration(BaseModel): form: ToolParameter.ToolParameterForm = Field(..., description="The form of the parameter") +class ToolInvokeMetaDict(TypedDict): + time_cost: float + error: str | None + tool_config: dict[str, Any] | None + + class ToolInvokeMeta(BaseModel): """ Tool invoke meta @@ -447,7 +463,7 @@ class ToolInvokeMeta(BaseModel): time_cost: float = Field(..., description="The time cost of the tool invoke") error: str | None = None - tool_config: dict | None = None + tool_config: dict[str, Any] | None = None @classmethod def empty(cls) -> ToolInvokeMeta: @@ -463,12 +479,13 @@ class ToolInvokeMeta(BaseModel): """ return cls(time_cost=0.0, error=error, tool_config={}) - def to_dict(self): - return { + def to_dict(self) -> ToolInvokeMetaDict: + result: ToolInvokeMetaDict = { "time_cost": self.time_cost, "error": self.error, "tool_config": self.tool_config, } + return result class ToolLabel(BaseModel): diff --git a/api/core/tools/errors.py b/api/core/tools/errors.py index 4c3efd6ff9c..2b26832b44a 100644 --- a/api/core/tools/errors.py +++ b/api/core/tools/errors.py @@ -38,6 +38,17 @@ class ToolCredentialPolicyViolationError(ValueError): pass +class ApiToolProviderNotFoundError(ValueError): + error_code = "api_tool_provider_not_found" + provider_name: str + tenant_id: str + + def __init__(self, provider_name: str, tenant_id: str): + self.provider_name = provider_name + self.tenant_id = tenant_id + super().__init__(f"api provider {provider_name} does not exist") + + class WorkflowToolHumanInputNotSupportedError(BaseHTTPException): error_code = "workflow_tool_human_input_not_supported" description = "Workflow with Human Input nodes cannot be published as a workflow tool." diff --git a/api/core/tools/mcp_tool/tool.py b/api/core/tools/mcp_tool/tool.py index f6d09472b3c..00fc8a82827 100644 --- a/api/core/tools/mcp_tool/tool.py +++ b/api/core/tools/mcp_tool/tool.py @@ -6,8 +6,6 @@ import logging from collections.abc import Generator, Mapping from typing import Any, cast -from graphon.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata - from core.mcp.auth_client import MCPClientWithAuthRetry from core.mcp.error import MCPConnectionError from core.mcp.types import ( @@ -23,6 +21,7 @@ from core.tools.__base.tool import Tool from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType from core.tools.errors import ToolInvokeError +from graphon.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata logger = logging.getLogger(__name__) diff --git a/api/core/tools/tool_engine.py b/api/core/tools/tool_engine.py index 685d687d8c4..3caacb87067 100644 --- a/api/core/tools/tool_engine.py +++ b/api/core/tools/tool_engine.py @@ -7,7 +7,6 @@ from datetime import UTC, datetime from mimetypes import guess_type from typing import Any, Union, cast -from graphon.file import FileTransferMethod, FileType from yarl import URL from core.app.entities.app_invoke_entities import InvokeFrom @@ -33,6 +32,7 @@ from core.tools.errors import ( from core.tools.utils.message_transformer import ToolFileMessageTransformer, safe_json_value from core.tools.workflow_as_tool.tool import WorkflowTool from extensions.ext_database import db +from graphon.file import FileTransferMethod, FileType from models.enums import CreatorUserRole, MessageFileBelongsTo from models.model import Message, MessageFile @@ -47,7 +47,7 @@ class ToolEngine: @staticmethod def agent_invoke( tool: Tool, - tool_parameters: Union[str, dict], + tool_parameters: Union[str, dict[str, Any]], user_id: str, tenant_id: str, message: Message, @@ -85,7 +85,8 @@ class ToolEngine: invocation_meta_dict: dict[str, ToolInvokeMeta] = {} def message_callback( - invocation_meta_dict: dict, messages: Generator[ToolInvokeMessage | ToolInvokeMeta, None, None] + invocation_meta_dict: dict[str, ToolInvokeMeta], + messages: Generator[ToolInvokeMessage | ToolInvokeMeta, None, None], ): for message in messages: if isinstance(message, ToolInvokeMeta): @@ -200,7 +201,7 @@ class ToolEngine: @staticmethod def _invoke( tool: Tool, - tool_parameters: dict, + tool_parameters: dict[str, Any], user_id: str, conversation_id: str | None = None, app_id: str | None = None, @@ -262,6 +263,8 @@ class ToolEngine: ensure_ascii=False, ) ) + elif response.type == ToolInvokeMessage.MessageType.VARIABLE: + continue else: parts.append(str(response.message)) diff --git a/api/core/tools/tool_file_manager.py b/api/core/tools/tool_file_manager.py index 7ac29cf0698..c87e8a3ae08 100644 --- a/api/core/tools/tool_file_manager.py +++ b/api/core/tools/tool_file_manager.py @@ -6,17 +6,17 @@ import os import time from collections.abc import Generator from mimetypes import guess_extension, guess_type -from typing import Union from uuid import uuid4 import httpx -from graphon.file import File, FileTransferMethod, get_file_type_by_mime_type +from sqlalchemy import select from configs import dify_config from core.db.session_factory import session_factory from core.helper import ssrf_proxy from core.workflow.file_reference import build_file_reference from extensions.ext_storage import storage +from graphon.file import File, FileTransferMethod, get_file_type_by_mime_type from models.model import MessageFile from models.tools import ToolFile @@ -28,7 +28,7 @@ class ToolFileManager: def _build_graph_file_reference(tool_file: ToolFile) -> File: extension = guess_extension(tool_file.mimetype) or ".bin" return File( - type=get_file_type_by_mime_type(tool_file.mimetype), + file_type=get_file_type_by_mime_type(tool_file.mimetype), transfer_method=FileTransferMethod.TOOL_FILE, remote_url=tool_file.original_url, reference=build_file_reference(record_id=str(tool_file.id)), @@ -157,7 +157,7 @@ class ToolFileManager: return tool_file - def get_file_binary(self, id: str) -> Union[tuple[bytes, str], None]: + def get_file_binary(self, id: str) -> tuple[bytes, str] | None: """ get file binary @@ -166,13 +166,7 @@ class ToolFileManager: :return: the binary of the file, mime type """ with session_factory.create_session() as session: - tool_file: ToolFile | None = ( - session.query(ToolFile) - .where( - ToolFile.id == id, - ) - .first() - ) + tool_file: ToolFile | None = session.scalar(select(ToolFile).where(ToolFile.id == id).limit(1)) if not tool_file: return None @@ -181,7 +175,7 @@ class ToolFileManager: return blob, tool_file.mimetype - def get_file_binary_by_message_file_id(self, id: str) -> Union[tuple[bytes, str], None]: + def get_file_binary_by_message_file_id(self, id: str) -> tuple[bytes, str] | None: """ get file binary @@ -190,13 +184,7 @@ class ToolFileManager: :return: the binary of the file, mime type """ with session_factory.create_session() as session: - message_file: MessageFile | None = ( - session.query(MessageFile) - .where( - MessageFile.id == id, - ) - .first() - ) + message_file: MessageFile | None = session.scalar(select(MessageFile).where(MessageFile.id == id).limit(1)) # Check if message_file is not None if message_file is not None: @@ -210,13 +198,7 @@ class ToolFileManager: else: tool_file_id = None - tool_file: ToolFile | None = ( - session.query(ToolFile) - .where( - ToolFile.id == tool_file_id, - ) - .first() - ) + tool_file: ToolFile | None = session.scalar(select(ToolFile).where(ToolFile.id == tool_file_id).limit(1)) if not tool_file: return None @@ -234,13 +216,7 @@ class ToolFileManager: :return: the binary of the file, mime type """ with session_factory.create_session() as session: - tool_file: ToolFile | None = ( - session.query(ToolFile) - .where( - ToolFile.id == tool_file_id, - ) - .first() - ) + tool_file: ToolFile | None = session.scalar(select(ToolFile).where(ToolFile.id == tool_file_id).limit(1)) if not tool_file: return None, None diff --git a/api/core/tools/tool_label_manager.py b/api/core/tools/tool_label_manager.py index 58190d10894..d8969a3391d 100644 --- a/api/core/tools/tool_label_manager.py +++ b/api/core/tools/tool_label_manager.py @@ -1,4 +1,5 @@ from sqlalchemy import delete, select +from sqlalchemy.orm import Session, sessionmaker from core.tools.__base.tool_provider import ToolProviderController from core.tools.builtin_tool.provider import BuiltinToolProviderController @@ -19,10 +20,18 @@ class ToolLabelManager: return list(set(tool_labels)) @classmethod - def update_tool_labels(cls, controller: ToolProviderController, labels: list[str]): + def update_tool_labels( + cls, controller: ToolProviderController, labels: list[str], session: Session | None = None + ) -> None: """ Update tool labels + + :param controller: tool provider controller + :param labels: list of tool labels + :param session: database session, if None, a new session will be created + :return: None """ + labels = cls.filter_tool_labels(labels) if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController): @@ -30,26 +39,46 @@ class ToolLabelManager: else: raise ValueError("Unsupported tool type") + if session is not None: + cls._update_tool_labels_logics(session, provider_id, controller, labels) + else: + with sessionmaker(db.engine).begin() as _session: + cls._update_tool_labels_logics(_session, provider_id, controller, labels) + + @classmethod + def _update_tool_labels_logics( + cls, session: Session, provider_id: str, controller: ToolProviderController, labels: list[str] + ) -> None: + """ + Update tool labels logics + + :param session: database session + :param provider_id: tool provider ID + :param controller: tool provider controller + :param labels: list of tool labels + :return: None + """ + # delete old labels - db.session.execute(delete(ToolLabelBinding).where(ToolLabelBinding.tool_id == provider_id)) + _ = session.execute( + delete(ToolLabelBinding).where( + ToolLabelBinding.tool_id == provider_id, ToolLabelBinding.tool_type == controller.provider_type + ) + ) # insert new labels for label in labels: - db.session.add( - ToolLabelBinding( - tool_id=provider_id, - tool_type=controller.provider_type, - label_name=label, - ) - ) - - db.session.commit() + session.add(ToolLabelBinding(tool_id=provider_id, tool_type=controller.provider_type, label_name=label)) @classmethod def get_tool_labels(cls, controller: ToolProviderController) -> list[str]: """ Get tool labels + + :param controller: tool provider controller + :return: list of tool labels (str) """ + if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController): provider_id = controller.provider_id elif isinstance(controller, BuiltinToolProviderController): @@ -60,9 +89,11 @@ class ToolLabelManager: ToolLabelBinding.tool_id == provider_id, ToolLabelBinding.tool_type == controller.provider_type, ) - labels = db.session.scalars(stmt).all() - return list(labels) + with sessionmaker(db.engine, expire_on_commit=False).begin() as _session: + labels: list[str] = list(_session.scalars(stmt).all()) + + return labels @classmethod def get_tools_labels(cls, tool_providers: list[ToolProviderController]) -> dict[str, list[str]]: @@ -78,16 +109,22 @@ class ToolLabelManager: if not tool_providers: return {} + provider_ids: list[str] = [] + provider_types: set[str] = set() + for controller in tool_providers: if not isinstance(controller, ApiToolProviderController | WorkflowToolProviderController): raise ValueError("Unsupported tool type") - - provider_ids = [] - for controller in tool_providers: - assert isinstance(controller, ApiToolProviderController | WorkflowToolProviderController) provider_ids.append(controller.provider_id) + provider_types.add(controller.provider_type) - labels = db.session.scalars(select(ToolLabelBinding).where(ToolLabelBinding.tool_id.in_(provider_ids))).all() + labels: list[ToolLabelBinding] = [] + + with sessionmaker(db.engine, expire_on_commit=False).begin() as _session: + stmt = select(ToolLabelBinding).where( + ToolLabelBinding.tool_id.in_(provider_ids), ToolLabelBinding.tool_type.in_(list(provider_types)) + ) + labels = list(_session.scalars(stmt).all()) tool_labels: dict[str, list[str]] = {label.tool_id: [] for label in labels} diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index a58d3103137..87cf6d70853 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -5,16 +5,18 @@ import time from collections.abc import Generator, Mapping from os import listdir, path from threading import Lock -from typing import TYPE_CHECKING, Any, Literal, Optional, Protocol, TypedDict, Union, cast +from typing import TYPE_CHECKING, Any, Literal, Protocol, cast import sqlalchemy as sa -from graphon.runtime import VariablePool +from pydantic import TypeAdapter from sqlalchemy import select from sqlalchemy.orm import Session +from typing_extensions import TypedDict from yarl import URL import contexts from configs import dify_config +from core.entities import PluginCredentialType from core.helper.provider_cache import ToolProviderCredentialsCache from core.plugin.impl.tool import PluginToolManager from core.tools.__base.tool_provider import ToolProviderController @@ -26,15 +28,13 @@ from core.tools.plugin_tool.tool import PluginTool from core.tools.utils.uuid_utils import is_valid_uuid from core.tools.workflow_as_tool.provider import WorkflowToolProviderController from extensions.ext_database import db +from graphon.runtime import VariablePool from models.provider_ids import ToolProviderID -from services.enterprise.plugin_manager_service import PluginCredentialType from services.tools.mcp_tools_manage_service import MCPToolManageService if TYPE_CHECKING: pass -from graphon.model_runtime.utils.encoders import jsonable_encoder - from core.agent.entities import AgentToolEntity from core.app.entities.app_invoke_entities import InvokeFrom from core.helper.module_import_helper import load_single_subclass_from_source @@ -49,15 +49,18 @@ from core.tools.entities.api_entities import ToolProviderApiEntity, ToolProvider from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ( ApiProviderAuthType, + EmojiIconDict, ToolInvokeFrom, ToolParameter, ToolProviderType, + emoji_icon_adapter, ) from core.tools.errors import ToolProviderNotFoundError from core.tools.tool_label_manager import ToolLabelManager from core.tools.utils.configuration import ToolParameterConfigurationManager from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter from core.tools.workflow_as_tool.tool import WorkflowTool +from graphon.model_runtime.utils.encoders import jsonable_encoder from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider from services.tools.tools_transform_service import ToolTransformService @@ -72,9 +75,7 @@ class ApiProviderControllerItem(TypedDict): controller: ApiToolProviderController -class EmojiIconDict(TypedDict): - background: str - content: str +_credentials_adapter: TypeAdapter[dict[str, Any]] = TypeAdapter(dict[str, Any]) class WorkflowToolRuntimeSpec(Protocol): @@ -98,7 +99,7 @@ class ToolManager: _builtin_provider_lock = Lock() _hardcoded_providers: dict[str, BuiltinToolProviderController] = {} _builtin_providers_loaded = False - _builtin_tools_labels: dict[str, Union[I18nObject, None]] = {} + _builtin_tools_labels: dict[str, I18nObject | None] = {} @classmethod def get_hardcoded_provider(cls, provider: str) -> BuiltinToolProviderController: @@ -188,7 +189,7 @@ class ToolManager: invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT, credential_id: str | None = None, - ) -> Union[BuiltinTool, PluginTool, ApiTool, WorkflowTool, MCPTool]: + ) -> BuiltinTool | PluginTool | ApiTool | WorkflowTool | MCPTool: """ get the tool runtime @@ -203,16 +204,160 @@ class ToolManager: :return: the tool """ - if provider_type == ToolProviderType.BUILT_IN: - # check if the builtin tool need credentials - provider_controller = cls.get_builtin_provider(provider_id, tenant_id) + match provider_type: + case ToolProviderType.BUILT_IN: + provider_controller = cls.get_builtin_provider(provider_id, tenant_id) - builtin_tool = provider_controller.get_tool(tool_name) - if not builtin_tool: - raise ToolProviderNotFoundError(f"builtin tool {tool_name} not found") + builtin_tool = provider_controller.get_tool(tool_name) + if not builtin_tool: + raise ToolProviderNotFoundError(f"builtin tool {tool_name} not found") + + if not provider_controller.need_credentials: + return builtin_tool.fork_tool_runtime( + runtime=ToolRuntime( + tenant_id=tenant_id, + user_id=user_id, + credentials={}, + invoke_from=invoke_from, + tool_invoke_from=tool_invoke_from, + ) + ) + builtin_provider = None + if isinstance(provider_controller, PluginToolProviderController): + provider_id_entity = ToolProviderID(provider_id) + if is_valid_uuid(credential_id): + try: + builtin_provider_stmt = select(BuiltinToolProvider).where( + BuiltinToolProvider.tenant_id == tenant_id, + BuiltinToolProvider.id == credential_id, + ) + builtin_provider = db.session.scalar(builtin_provider_stmt) + except Exception as e: + builtin_provider = None + logger.info("Error getting builtin provider %s:%s", credential_id, e, exc_info=True) + if builtin_provider is None: + raise ToolProviderNotFoundError(f"provider has been deleted: {credential_id}") + + if builtin_provider is None: + with Session(db.engine) as session: + builtin_provider = session.scalar( + sa.select(BuiltinToolProvider) + .where( + BuiltinToolProvider.tenant_id == tenant_id, + (BuiltinToolProvider.provider == str(provider_id_entity)) + | (BuiltinToolProvider.provider == provider_id_entity.provider_name), + ) + .order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc()) + ) + if builtin_provider is None: + raise ToolProviderNotFoundError(f"no default provider for {provider_id}") + else: + builtin_provider = db.session.scalar( + select(BuiltinToolProvider) + .where( + BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id) + ) + .order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc()) + .limit(1) + ) + + if builtin_provider is None: + raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found") + + from core.helper.credential_utils import check_credential_policy_compliance + + check_credential_policy_compliance( + credential_id=builtin_provider.id, + provider=provider_id, + credential_type=PluginCredentialType.TOOL, + check_existence=False, + ) + + encrypter, cache = create_provider_encrypter( + tenant_id=tenant_id, + config=[ + x.to_basic_provider_config() + for x in provider_controller.get_credentials_schema_by_type(builtin_provider.credential_type) + ], + cache=ToolProviderCredentialsCache( + tenant_id=tenant_id, provider=provider_id, credential_id=builtin_provider.id + ), + ) + + decrypted_credentials: Mapping[str, Any] = encrypter.decrypt(builtin_provider.credentials) + + if builtin_provider.expires_at != -1 and (builtin_provider.expires_at - 60) < int(time.time()): + # TODO: circular import + from core.plugin.impl.oauth import OAuthHandler + from services.tools.builtin_tools_manage_service import BuiltinToolManageService + + tool_provider = ToolProviderID(provider_id) + provider_name = tool_provider.provider_name + redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/tool/callback" + system_credentials = BuiltinToolManageService.get_oauth_client(tenant_id, provider_id) + + oauth_handler = OAuthHandler() + refreshed_credentials = oauth_handler.refresh_credentials( + tenant_id=tenant_id, + user_id=builtin_provider.user_id, + plugin_id=tool_provider.plugin_id, + provider=provider_name, + redirect_uri=redirect_uri, + system_credentials=system_credentials or {}, + credentials=decrypted_credentials, + ) + # update the credentials + builtin_provider.encrypted_credentials = json.dumps( + encrypter.encrypt(refreshed_credentials.credentials) + ) + builtin_provider.expires_at = refreshed_credentials.expires_at + db.session.commit() + decrypted_credentials = refreshed_credentials.credentials + cache.delete() - if not provider_controller.need_credentials: return builtin_tool.fork_tool_runtime( + runtime=ToolRuntime( + tenant_id=tenant_id, + user_id=user_id, + credentials=dict(decrypted_credentials), + credential_type=builtin_provider.credential_type, + runtime_parameters={}, + invoke_from=invoke_from, + tool_invoke_from=tool_invoke_from, + ) + ) + + case ToolProviderType.API: + api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id) + encrypter, _ = create_tool_provider_encrypter( + tenant_id=tenant_id, + controller=api_provider, + ) + return api_provider.get_tool(tool_name).fork_tool_runtime( + runtime=ToolRuntime( + tenant_id=tenant_id, + user_id=user_id, + credentials=dict(encrypter.decrypt(credentials)), + invoke_from=invoke_from, + tool_invoke_from=tool_invoke_from, + ) + ) + case ToolProviderType.WORKFLOW: + workflow_provider_stmt = select(WorkflowToolProvider).where( + WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id + ) + with Session(db.engine, expire_on_commit=False) as session, session.begin(): + workflow_provider = session.scalar(workflow_provider_stmt) + + if workflow_provider is None: + raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found") + + controller = ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider) + controller_tools: list[WorkflowTool] = controller.get_tools(tenant_id=workflow_provider.tenant_id) + if controller_tools is None or len(controller_tools) == 0: + raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found") + + return controller.get_tools(tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime( runtime=ToolRuntime( tenant_id=tenant_id, user_id=user_id, @@ -221,177 +366,28 @@ class ToolManager: tool_invoke_from=tool_invoke_from, ) ) - builtin_provider = None - if isinstance(provider_controller, PluginToolProviderController): - provider_id_entity = ToolProviderID(provider_id) - # get specific credentials - if is_valid_uuid(credential_id): - try: - builtin_provider_stmt = select(BuiltinToolProvider).where( - BuiltinToolProvider.tenant_id == tenant_id, - BuiltinToolProvider.id == credential_id, - ) - builtin_provider = db.session.scalar(builtin_provider_stmt) - except Exception as e: - builtin_provider = None - logger.info("Error getting builtin provider %s:%s", credential_id, e, exc_info=True) - # if the provider has been deleted, raise an error - if builtin_provider is None: - raise ToolProviderNotFoundError(f"provider has been deleted: {credential_id}") - - # fallback to the default provider - if builtin_provider is None: - # use the default provider - with Session(db.engine) as session: - builtin_provider = session.scalar( - sa.select(BuiltinToolProvider) - .where( - BuiltinToolProvider.tenant_id == tenant_id, - (BuiltinToolProvider.provider == str(provider_id_entity)) - | (BuiltinToolProvider.provider == provider_id_entity.provider_name), - ) - .order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc()) - ) - if builtin_provider is None: - raise ToolProviderNotFoundError(f"no default provider for {provider_id}") - else: - builtin_provider = db.session.scalar( - select(BuiltinToolProvider) - .where(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id)) - .order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc()) - .limit(1) - ) - - if builtin_provider is None: - raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found") - - # check if the credential is allowed to be used - from core.helper.credential_utils import check_credential_policy_compliance - - check_credential_policy_compliance( - credential_id=builtin_provider.id, - provider=provider_id, - credential_type=PluginCredentialType.TOOL, - check_existence=False, - ) - - encrypter, cache = create_provider_encrypter( - tenant_id=tenant_id, - config=[ - x.to_basic_provider_config() - for x in provider_controller.get_credentials_schema_by_type(builtin_provider.credential_type) - ], - cache=ToolProviderCredentialsCache( - tenant_id=tenant_id, provider=provider_id, credential_id=builtin_provider.id - ), - ) - - # decrypt the credentials - decrypted_credentials: Mapping[str, Any] = encrypter.decrypt(builtin_provider.credentials) - - # check if the credentials is expired - if builtin_provider.expires_at != -1 and (builtin_provider.expires_at - 60) < int(time.time()): - # TODO: circular import - from core.plugin.impl.oauth import OAuthHandler - from services.tools.builtin_tools_manage_service import BuiltinToolManageService - - # refresh the credentials - tool_provider = ToolProviderID(provider_id) - provider_name = tool_provider.provider_name - redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/tool/callback" - system_credentials = BuiltinToolManageService.get_oauth_client(tenant_id, provider_id) - - oauth_handler = OAuthHandler() - # refresh the credentials - refreshed_credentials = oauth_handler.refresh_credentials( - tenant_id=tenant_id, - user_id=builtin_provider.user_id, - plugin_id=tool_provider.plugin_id, - provider=provider_name, - redirect_uri=redirect_uri, - system_credentials=system_credentials or {}, - credentials=decrypted_credentials, - ) - # update the credentials - builtin_provider.encrypted_credentials = json.dumps( - encrypter.encrypt(refreshed_credentials.credentials) - ) - builtin_provider.expires_at = refreshed_credentials.expires_at - db.session.commit() - decrypted_credentials = refreshed_credentials.credentials - cache.delete() - - return builtin_tool.fork_tool_runtime( - runtime=ToolRuntime( - tenant_id=tenant_id, - user_id=user_id, - credentials=dict(decrypted_credentials), - credential_type=builtin_provider.credential_type, - runtime_parameters={}, - invoke_from=invoke_from, - tool_invoke_from=tool_invoke_from, - ) - ) - - elif provider_type == ToolProviderType.API: - api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id) - encrypter, _ = create_tool_provider_encrypter( - tenant_id=tenant_id, - controller=api_provider, - ) - return api_provider.get_tool(tool_name).fork_tool_runtime( - runtime=ToolRuntime( - tenant_id=tenant_id, - user_id=user_id, - credentials=dict(encrypter.decrypt(credentials)), - invoke_from=invoke_from, - tool_invoke_from=tool_invoke_from, - ) - ) - elif provider_type == ToolProviderType.WORKFLOW: - workflow_provider_stmt = select(WorkflowToolProvider).where( - WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id - ) - with Session(db.engine, expire_on_commit=False) as session, session.begin(): - workflow_provider = session.scalar(workflow_provider_stmt) - - if workflow_provider is None: - raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found") - - controller = ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider) - controller_tools: list[WorkflowTool] = controller.get_tools(tenant_id=workflow_provider.tenant_id) - if controller_tools is None or len(controller_tools) == 0: - raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found") - - return controller.get_tools(tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime( - runtime=ToolRuntime( - tenant_id=tenant_id, - user_id=user_id, - credentials={}, - invoke_from=invoke_from, - tool_invoke_from=tool_invoke_from, - ) - ) - elif provider_type == ToolProviderType.APP: - raise NotImplementedError("app provider not implemented") - elif provider_type == ToolProviderType.PLUGIN: - plugin_tool = cls.get_plugin_provider(provider_id, tenant_id).get_tool(tool_name) - runtime = getattr(plugin_tool, "runtime", None) - if runtime is not None: - runtime.user_id = user_id - runtime.invoke_from = invoke_from - runtime.tool_invoke_from = tool_invoke_from - return plugin_tool - elif provider_type == ToolProviderType.MCP: - mcp_tool = cls.get_mcp_provider_controller(tenant_id, provider_id).get_tool(tool_name) - runtime = getattr(mcp_tool, "runtime", None) - if runtime is not None: - runtime.user_id = user_id - runtime.invoke_from = invoke_from - runtime.tool_invoke_from = tool_invoke_from - return mcp_tool - else: - raise ToolProviderNotFoundError(f"provider type {provider_type.value} not found") + case ToolProviderType.APP: + raise NotImplementedError("app provider not implemented") + case ToolProviderType.PLUGIN: + plugin_tool = cls.get_plugin_provider(provider_id, tenant_id).get_tool(tool_name) + runtime = getattr(plugin_tool, "runtime", None) + if runtime is not None: + runtime.user_id = user_id + runtime.invoke_from = invoke_from + runtime.tool_invoke_from = tool_invoke_from + return plugin_tool + case ToolProviderType.MCP: + mcp_tool = cls.get_mcp_provider_controller(tenant_id, provider_id).get_tool(tool_name) + runtime = getattr(mcp_tool, "runtime", None) + if runtime is not None: + runtime.user_id = user_id + runtime.invoke_from = invoke_from + runtime.tool_invoke_from = tool_invoke_from + return mcp_tool + case ToolProviderType.DATASET_RETRIEVAL: + raise ToolProviderNotFoundError(f"provider type {provider_type.value} not found") + case _: + raise ToolProviderNotFoundError(f"provider type {provider_type} not found") @classmethod def get_agent_tool_runtime( @@ -401,7 +397,7 @@ class ToolManager: agent_tool: AgentToolEntity, user_id: str | None = None, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, - variable_pool: Optional["VariablePool"] = None, + variable_pool: "VariablePool | None" = None, ) -> Tool: """ get the agent tool runtime @@ -445,7 +441,7 @@ class ToolManager: workflow_tool: WorkflowToolRuntimeSpec, user_id: str | None = None, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, - variable_pool: Optional["VariablePool"] = None, + variable_pool: "VariablePool | None" = None, ) -> Tool: """ get the workflow tool runtime @@ -637,7 +633,7 @@ class ToolManager: cls._builtin_providers_loaded = False @classmethod - def get_tool_label(cls, tool_name: str) -> Union[I18nObject, None]: + def get_tool_label(cls, tool_name: str) -> I18nObject | None: """ get the tool label @@ -685,7 +681,7 @@ class ToolManager: with Session(db.engine, autoflush=False) as session: ids = [row.id for row in session.execute(sa.text(sql), {"tenant_id": tenant_id}).all()] - return session.query(BuiltinToolProvider).where(BuiltinToolProvider.id.in_(ids)).all() + return list(session.scalars(select(BuiltinToolProvider).where(BuiltinToolProvider.id.in_(ids)))) @classmethod def list_providers_from_api( @@ -885,7 +881,7 @@ class ToolManager: raise ValueError(f"you have not added provider {provider_name}") try: - credentials = json.loads(provider_obj.credentials_str) or {} + credentials = _credentials_adapter.validate_json(provider_obj.credentials_str) or {} except Exception: credentials = {} @@ -910,7 +906,7 @@ class ToolManager: masked_credentials = encrypter.mask_plugin_credentials(encrypter.decrypt(credentials)) try: - icon = json.loads(provider_obj.icon) + icon = emoji_icon_adapter.validate_json(provider_obj.icon) except Exception: icon = {"background": "#252525", "content": "\ud83d\ude01"} @@ -973,7 +969,7 @@ class ToolManager: if workflow_provider is None: raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found") - icon = json.loads(workflow_provider.icon) + icon = emoji_icon_adapter.validate_json(workflow_provider.icon) return icon except Exception: return {"background": "#252525", "content": "\ud83d\ude01"} @@ -990,13 +986,13 @@ class ToolManager: if api_provider is None: raise ToolProviderNotFoundError(f"api provider {provider_id} not found") - icon = json.loads(api_provider.icon) + icon = emoji_icon_adapter.validate_json(api_provider.icon) return icon except Exception: return {"background": "#252525", "content": "\ud83d\ude01"} @classmethod - def generate_mcp_tool_icon_url(cls, tenant_id: str, provider_id: str) -> EmojiIconDict | dict[str, str] | str: + def generate_mcp_tool_icon_url(cls, tenant_id: str, provider_id: str) -> EmojiIconDict | str: try: with Session(db.engine) as session: mcp_service = MCPToolManageService(session=session) @@ -1004,7 +1000,7 @@ class ToolManager: mcp_provider = mcp_service.get_provider_entity( provider_id=provider_id, tenant_id=tenant_id, by_server_id=True ) - return mcp_provider.provider_icon + return cast(EmojiIconDict | str, mcp_provider.provider_icon) except ValueError: raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found") except Exception: @@ -1016,7 +1012,7 @@ class ToolManager: tenant_id: str, provider_type: ToolProviderType, provider_id: str, - ) -> str | EmojiIconDict | dict[str, str]: + ) -> str | EmojiIconDict: """ get the tool icon @@ -1025,37 +1021,37 @@ class ToolManager: :param provider_id: the id of the provider :return: """ - provider_type = provider_type - provider_id = provider_id - if provider_type == ToolProviderType.BUILT_IN: - provider = ToolManager.get_builtin_provider(provider_id, tenant_id) - if isinstance(provider, PluginToolProviderController): + match provider_type: + case ToolProviderType.BUILT_IN: + provider = ToolManager.get_builtin_provider(provider_id, tenant_id) + if isinstance(provider, PluginToolProviderController): + try: + return cls.generate_plugin_tool_icon_url(tenant_id, provider.entity.identity.icon) + except Exception: + return {"background": "#252525", "content": "\ud83d\ude01"} + return cls.generate_builtin_tool_icon_url(provider_id) + case ToolProviderType.API: + return cls.generate_api_tool_icon_url(tenant_id, provider_id) + case ToolProviderType.WORKFLOW: + return cls.generate_workflow_tool_icon_url(tenant_id, provider_id) + case ToolProviderType.PLUGIN: + provider = ToolManager.get_plugin_provider(provider_id, tenant_id) try: return cls.generate_plugin_tool_icon_url(tenant_id, provider.entity.identity.icon) except Exception: return {"background": "#252525", "content": "\ud83d\ude01"} - return cls.generate_builtin_tool_icon_url(provider_id) - elif provider_type == ToolProviderType.API: - return cls.generate_api_tool_icon_url(tenant_id, provider_id) - elif provider_type == ToolProviderType.WORKFLOW: - return cls.generate_workflow_tool_icon_url(tenant_id, provider_id) - elif provider_type == ToolProviderType.PLUGIN: - provider = ToolManager.get_plugin_provider(provider_id, tenant_id) - try: - return cls.generate_plugin_tool_icon_url(tenant_id, provider.entity.identity.icon) - except Exception: - return {"background": "#252525", "content": "\ud83d\ude01"} - raise ValueError(f"plugin provider {provider_id} not found") - elif provider_type == ToolProviderType.MCP: - return cls.generate_mcp_tool_icon_url(tenant_id, provider_id) - else: - raise ValueError(f"provider type {provider_type} not found") + case ToolProviderType.MCP: + return cls.generate_mcp_tool_icon_url(tenant_id, provider_id) + case ToolProviderType.APP | ToolProviderType.DATASET_RETRIEVAL: + raise ValueError(f"provider type {provider_type} not found") + case _: + raise ValueError(f"provider type {provider_type} not found") @classmethod def _convert_tool_parameters_type( cls, parameters: list[ToolParameter], - variable_pool: Optional["VariablePool"], + variable_pool: "VariablePool | None", tool_configurations: Mapping[str, Any], typ: Literal["agent", "workflow", "tool"] = "workflow", ) -> dict[str, Any]: @@ -1086,7 +1082,12 @@ class ToolManager: continue tool_input = ToolNodeData.ToolInput.model_validate(tool_configurations.get(parameter.name, {})) if tool_input.type == "variable": - variable = variable_pool.get(tool_input.value) + variable_selector = tool_input.value + if not isinstance(variable_selector, list) or not all( + isinstance(selector_part, str) for selector_part in variable_selector + ): + raise ToolParameterError("Variable tool input must be a variable selector") + variable = variable_pool.get(variable_selector) if variable is None: raise ToolParameterError(f"Variable {tool_input.value} does not exist") parameter_value = variable.value diff --git a/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py index e63435db988..b6890b2611c 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py @@ -1,21 +1,20 @@ import threading from flask import Flask, current_app -from graphon.model_runtime.entities.model_entities import ModelType from pydantic import BaseModel, Field from sqlalchemy import select from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.model_manager import ModelManager -from core.rag.datasource.retrieval_service import RetrievalService -from core.rag.entities.citation_metadata import RetrievalSourceMetadata +from core.rag.datasource.retrieval_service import DefaultRetrievalModelDict, RetrievalService +from core.rag.entities import RetrievalSourceMetadata from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.models.document import Document as RagDocument from core.rag.rerank.rerank_model import RerankModelRunner from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool -from core.tools.utils.dataset_retriever.dataset_retriever_tool import DefaultRetrievalModelDict from extensions.ext_database import db +from graphon.model_runtime.entities.model_entities import ModelType from models.dataset import Dataset, Document, DocumentSegment default_retrieval_model: DefaultRetrievalModelDict = { diff --git a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py index cbd8bdb36cf..0d1dc7273b5 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py @@ -1,13 +1,11 @@ -from typing import NotRequired, TypedDict, cast +from typing import Any, cast from pydantic import BaseModel, Field from sqlalchemy import select from core.app.app_config.entities import DatasetRetrieveConfigEntity, ModelConfig -from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict -from core.rag.datasource.retrieval_service import RetrievalService -from core.rag.entities.citation_metadata import RetrievalSourceMetadata -from core.rag.entities.context_entities import DocumentContext +from core.rag.datasource.retrieval_service import DefaultRetrievalModelDict, RetrievalService +from core.rag.entities import DocumentContext, RetrievalSourceMetadata from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.models.document import Document as RetrievalDocument from core.rag.retrieval.dataset_retrieval import DatasetRetrieval @@ -18,18 +16,6 @@ from models.dataset import Dataset from models.dataset import Document as DatasetDocument from services.external_knowledge_service import ExternalDatasetService - -class DefaultRetrievalModelDict(TypedDict): - search_method: RetrievalMethod - reranking_enable: bool - reranking_model: RerankingModelDict - reranking_mode: NotRequired[str] - weights: NotRequired[WeightsDict | None] - score_threshold: NotRequired[float] - top_k: int - score_threshold_enabled: bool - - default_retrieval_model: DefaultRetrievalModelDict = { "search_method": RetrievalMethod.SEMANTIC_SEARCH, "reranking_enable": False, @@ -53,7 +39,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): dataset_id: str user_id: str | None = None retrieve_config: DatasetRetrieveConfigEntity - inputs: dict + inputs: dict[str, Any] @classmethod def from_dataset(cls, dataset: Dataset, **kwargs): diff --git a/api/core/tools/utils/dataset_retriever_tool.py b/api/core/tools/utils/dataset_retriever_tool.py index fca6e6f1c78..0bdc3df8692 100644 --- a/api/core/tools/utils/dataset_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever_tool.py @@ -33,7 +33,7 @@ class DatasetRetrieverTool(Tool): invoke_from: InvokeFrom, hit_callback: DatasetIndexToolCallbackHandler, user_id: str, - inputs: dict, + inputs: dict[str, Any], ) -> list["DatasetRetrieverTool"]: """ get dataset tool diff --git a/api/core/tools/utils/message_transformer.py b/api/core/tools/utils/message_transformer.py index bb5b3ba76e9..5679466cbcb 100644 --- a/api/core/tools/utils/message_transformer.py +++ b/api/core/tools/utils/message_transformer.py @@ -4,15 +4,16 @@ from collections.abc import Generator from datetime import date, datetime from decimal import Decimal from mimetypes import guess_extension +from typing import Any from uuid import UUID import numpy as np import pytz -from graphon.file import File, FileTransferMethod, FileType from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool_file_manager import ToolFileManager from core.workflow.file_reference import parse_file_reference +from graphon.file import File, FileTransferMethod, FileType from libs.login import current_user from models import Account @@ -40,6 +41,10 @@ def safe_json_value(v): return v.hex() elif isinstance(v, memoryview): return v.tobytes().hex() + elif isinstance(v, np.integer): + return int(v) + elif isinstance(v, np.floating): + return float(v) elif isinstance(v, np.ndarray): return v.tolist() elif isinstance(v, dict): @@ -50,7 +55,7 @@ def safe_json_value(v): return v -def safe_json_dict(d: dict): +def safe_json_dict(d: dict[str, Any]): if not isinstance(d, dict): raise TypeError("safe_json_dict() expects a dictionary (dict) as input") return {k: safe_json_value(v) for k, v in d.items()} @@ -118,7 +123,8 @@ class ToolFileMessageTransformer: if not isinstance(message.message, ToolInvokeMessage.BlobMessage): raise ValueError("unexpected message type") - assert isinstance(message.message.blob, bytes) + if not isinstance(message.message.blob, bytes): + raise TypeError(f"Expected blob to be bytes, got {type(message.message.blob).__name__}") tool_file_manager = ToolFileManager() tool_file = tool_file_manager.create_file_by_raw( user_id=user_id, @@ -195,11 +201,11 @@ class ToolFileMessageTransformer: @staticmethod def _with_tool_file_meta( - meta: dict | None, + meta: dict[str, Any] | None, *, tool_file_id: str | None = None, url: str | None = None, - ) -> dict: + ) -> dict[str, Any]: normalized_meta = meta.copy() if meta is not None else {} resolved_tool_file_id = tool_file_id or ToolFileMessageTransformer._extract_tool_file_id(url) if resolved_tool_file_id and "tool_file_id" not in normalized_meta: diff --git a/api/core/tools/utils/model_invocation_utils.py b/api/core/tools/utils/model_invocation_utils.py index 8d6f83dc07c..a3623d4ecde 100644 --- a/api/core/tools/utils/model_invocation_utils.py +++ b/api/core/tools/utils/model_invocation_utils.py @@ -8,6 +8,9 @@ import json from decimal import Decimal from typing import cast +from core.model_manager import ModelManager +from core.tools.entities.tool_entities import ToolProviderType +from extensions.ext_database import db from graphon.model_runtime.entities.llm_entities import LLMResult from graphon.model_runtime.entities.message_entities import PromptMessage from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType @@ -18,12 +21,8 @@ from graphon.model_runtime.errors.invoke import ( InvokeRateLimitError, InvokeServerUnavailableError, ) -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel from graphon.model_runtime.utils.encoders import jsonable_encoder - -from core.model_manager import ModelManager -from core.tools.entities.tool_entities import ToolProviderType -from extensions.ext_database import db from models.tools import ToolModelInvoke diff --git a/api/core/tools/utils/parser.py b/api/core/tools/utils/parser.py index f7484b93fb9..434af555832 100644 --- a/api/core/tools/utils/parser.py +++ b/api/core/tools/utils/parser.py @@ -32,7 +32,7 @@ class OpenAPISpecDict(TypedDict): class ApiBasedToolSchemaParser: @staticmethod def parse_openapi_to_tool_bundle( - openapi: Mapping[str, Any], extra_info: dict | None = None, warning: dict | None = None + openapi: Mapping[str, Any], extra_info: dict[str, Any] | None = None, warning: dict[str, Any] | None = None ) -> list[ApiToolBundle]: warning = warning if warning is not None else {} extra_info = extra_info if extra_info is not None else {} @@ -236,7 +236,7 @@ class ApiBasedToolSchemaParser: return value @staticmethod - def _get_tool_parameter_type(parameter: dict) -> ToolParameter.ToolParameterType | None: + def _get_tool_parameter_type(parameter: dict[str, Any]) -> ToolParameter.ToolParameterType | None: parameter = parameter or {} typ: str | None = None if parameter.get("format") == "binary": @@ -265,7 +265,7 @@ class ApiBasedToolSchemaParser: @staticmethod def parse_openapi_yaml_to_tool_bundle( - yaml: str, extra_info: dict | None = None, warning: dict | None = None + yaml: str, extra_info: dict[str, Any] | None = None, warning: dict[str, Any] | None = None ) -> list[ApiToolBundle]: """ parse openapi yaml to tool bundle @@ -278,14 +278,14 @@ class ApiBasedToolSchemaParser: warning = warning if warning is not None else {} extra_info = extra_info if extra_info is not None else {} - openapi: dict = safe_load(yaml) + openapi: dict[str, Any] = safe_load(yaml) if openapi is None: raise ToolApiSchemaError("Invalid openapi yaml.") return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi, extra_info=extra_info, warning=warning) @staticmethod def parse_swagger_to_openapi( - swagger: dict, extra_info: dict | None = None, warning: dict | None = None + swagger: dict[str, Any], extra_info: dict[str, Any] | None = None, warning: dict[str, Any] | None = None ) -> OpenAPISpecDict: warning = warning or {} """ @@ -351,7 +351,7 @@ class ApiBasedToolSchemaParser: @staticmethod def parse_openai_plugin_json_to_tool_bundle( - json: str, extra_info: dict | None = None, warning: dict | None = None + json: str, extra_info: dict[str, Any] | None = None, warning: dict[str, Any] | None = None ) -> list[ApiToolBundle]: """ parse openapi plugin yaml to tool bundle @@ -392,7 +392,7 @@ class ApiBasedToolSchemaParser: @staticmethod def auto_parse_to_tool_bundle( - content: str, extra_info: dict | None = None, warning: dict | None = None + content: str, extra_info: dict[str, Any] | None = None, warning: dict[str, Any] | None = None ) -> tuple[list[ApiToolBundle], ApiProviderSchemaType]: """ auto parse to tool bundle diff --git a/api/core/tools/utils/text_processing_utils.py b/api/core/tools/utils/text_processing_utils.py index 4bfaa5e49bd..1dd0605f285 100644 --- a/api/core/tools/utils/text_processing_utils.py +++ b/api/core/tools/utils/text_processing_utils.py @@ -19,5 +19,18 @@ def remove_leading_symbols(text: str) -> str: # Match Unicode ranges for punctuation and symbols # FIXME this pattern is confused quick fix for #11868 maybe refactor it later - pattern = r'^[\[\]\u2000-\u2025\u2027-\u206F\u2E00-\u2E7F\u3000-\u300F\u3011-\u303F"#$%&\'()*+,./:;<=>?@^_`~]+' + pattern = re.compile( + r""" + ^ + (?: + [\u2000-\u2025] # General Punctuation: spaces, quotes, dashes + | [\u2027-\u206F] # General Punctuation: ellipsis, underscores, etc. + | [\u2E00-\u2E7F] # Supplemental Punctuation: medieval, ancient marks + | [\u3000-\u300F] # CJK Punctuation: 、。〃「」『》』 (excludes 【】) + | [\u3012-\u303F] # CJK Punctuation: 〖〗〔〕〘〙〚〛〜 etc. + | ["#$%&'()*+,./:;<=>?@^_`~] # ASCII punctuation (excludes []【】) + )+ + """, + re.VERBOSE, + ) return re.sub(pattern, "", text) diff --git a/api/core/tools/utils/web_reader_tool.py b/api/core/tools/utils/web_reader_tool.py index ed3ed3e0de9..94a2c0427be 100644 --- a/api/core/tools/utils/web_reader_tool.py +++ b/api/core/tools/utils/web_reader_tool.py @@ -105,7 +105,7 @@ class Article: def extract_using_readabilipy(html: str): - json_article: dict[str, Any] = simple_json_from_html_string(html, use_readability=True) + json_article: dict[str, Any] = simple_json_from_html_string(html, use_readability=False) article = Article( title=json_article.get("title") or "", author=json_article.get("byline") or "", diff --git a/api/core/tools/utils/workflow_configuration_sync.py b/api/core/tools/utils/workflow_configuration_sync.py index c4b7d574493..45718cadb6b 100644 --- a/api/core/tools/utils/workflow_configuration_sync.py +++ b/api/core/tools/utils/workflow_configuration_sync.py @@ -1,13 +1,12 @@ from collections.abc import Mapping, Sequence from typing import Any +from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration +from core.tools.errors import WorkflowToolHumanInputNotSupportedError from graphon.enums import BuiltinNodeTypes from graphon.nodes.base.entities import OutputVariableEntity from graphon.variables.input_entities import VariableEntity -from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration -from core.tools.errors import WorkflowToolHumanInputNotSupportedError - class WorkflowToolConfigurationUtils: @classmethod @@ -17,10 +16,8 @@ class WorkflowToolConfigurationUtils: """ nodes = graph.get("nodes", []) start_node = next(filter(lambda x: x.get("data", {}).get("type") == "start", nodes), None) - if not start_node: return [] - return [VariableEntity.model_validate(variable) for variable in start_node.get("data", {}).get("variables", [])] @classmethod diff --git a/api/core/tools/workflow_as_tool/provider.py b/api/core/tools/workflow_as_tool/provider.py index f48b24be30e..5905fd919e9 100644 --- a/api/core/tools/workflow_as_tool/provider.py +++ b/api/core/tools/workflow_as_tool/provider.py @@ -2,8 +2,8 @@ from __future__ import annotations from collections.abc import Mapping -from graphon.variables.input_entities import VariableEntity, VariableEntityType from pydantic import Field +from sqlalchemy import select from sqlalchemy.orm import Session from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager @@ -24,6 +24,7 @@ from core.tools.entities.tool_entities import ( from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils from core.tools.workflow_as_tool.tool import WorkflowTool from extensions.ext_database import db +from graphon.variables.input_entities import VariableEntity, VariableEntityType from models.account import Account from models.model import App, AppMode from models.tools import WorkflowToolProvider @@ -96,10 +97,10 @@ class WorkflowToolProviderController(ToolProviderController): :param app: the app :return: the tool """ - workflow: Workflow | None = ( - session.query(Workflow) + workflow: Workflow | None = session.scalar( + select(Workflow) .where(Workflow.app_id == db_provider.app_id, Workflow.version == db_provider.version) - .first() + .limit(1) ) if not workflow: @@ -217,13 +218,13 @@ class WorkflowToolProviderController(ToolProviderController): return self.tools with Session(db.engine, expire_on_commit=False) as session, session.begin(): - db_provider: WorkflowToolProvider | None = ( - session.query(WorkflowToolProvider) + db_provider: WorkflowToolProvider | None = session.scalar( + select(WorkflowToolProvider) .where( WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == self.provider_id, ) - .first() + .limit(1) ) if not db_provider: diff --git a/api/core/tools/workflow_as_tool/tool.py b/api/core/tools/workflow_as_tool/tool.py index a3fb4eda928..cd8c6352b52 100644 --- a/api/core/tools/workflow_as_tool/tool.py +++ b/api/core/tools/workflow_as_tool/tool.py @@ -5,8 +5,6 @@ import logging from collections.abc import Generator, Mapping, Sequence from typing import Any, cast -from graphon.file import FILE_MODEL_IDENTITY, File, FileTransferMethod -from graphon.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata from sqlalchemy import select from core.app.file_access import DatabaseFileAccessController @@ -22,6 +20,8 @@ from core.tools.entities.tool_entities import ( from core.tools.errors import ToolInvokeError from core.workflow.file_reference import resolve_file_record_id from factories.file_factory import build_from_mapping +from graphon.file import FILE_MODEL_IDENTITY, File, FileTransferMethod +from graphon.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata from models import Account, Tenant from models.model import App, EndUser from models.utils.file_input_compat import build_file_from_stored_mapping @@ -277,7 +277,7 @@ class WorkflowTool(Tool): session.expunge(app) return app - def _transform_args(self, tool_parameters: dict) -> tuple[dict, list[dict]]: + def _transform_args(self, tool_parameters: dict[str, Any]) -> tuple[dict[str, Any], list[dict[str, str | None]]]: """ transform the tool parameters @@ -305,14 +305,15 @@ class WorkflowTool(Tool): "transfer_method": file.transfer_method.value, "type": file.type.value, } - if file.transfer_method == FileTransferMethod.TOOL_FILE: - file_dict["tool_file_id"] = resolve_file_record_id(file.reference) - elif file.transfer_method == FileTransferMethod.LOCAL_FILE: - file_dict["upload_file_id"] = resolve_file_record_id(file.reference) - elif file.transfer_method == FileTransferMethod.DATASOURCE_FILE: - file_dict["datasource_file_id"] = resolve_file_record_id(file.reference) - elif file.transfer_method == FileTransferMethod.REMOTE_URL: - file_dict["url"] = file.generate_url() + match file.transfer_method: + case FileTransferMethod.TOOL_FILE: + file_dict["tool_file_id"] = resolve_file_record_id(file.reference) + case FileTransferMethod.LOCAL_FILE: + file_dict["upload_file_id"] = resolve_file_record_id(file.reference) + case FileTransferMethod.DATASOURCE_FILE: + file_dict["datasource_file_id"] = resolve_file_record_id(file.reference) + case FileTransferMethod.REMOTE_URL: + file_dict["url"] = file.generate_url() files.append(file_dict) except Exception: @@ -322,7 +323,7 @@ class WorkflowTool(Tool): return parameters_result, files - def _extract_files(self, outputs: dict) -> tuple[dict, list[File]]: + def _extract_files(self, outputs: dict[str, Any]) -> tuple[dict[str, Any], list[File]]: """ extract files from the result @@ -354,11 +355,17 @@ class WorkflowTool(Tool): return result, files - def _update_file_mapping(self, file_dict: dict): + def _update_file_mapping(self, file_dict: dict[str, Any]) -> dict[str, Any]: file_id = resolve_file_record_id(file_dict.get("reference") or file_dict.get("related_id")) - transfer_method = FileTransferMethod.value_of(file_dict.get("transfer_method")) - if transfer_method == FileTransferMethod.TOOL_FILE: - file_dict["tool_file_id"] = file_id - elif transfer_method == FileTransferMethod.LOCAL_FILE: - file_dict["upload_file_id"] = file_id + transfer_method_value = file_dict.get("transfer_method") + if not isinstance(transfer_method_value, str): + raise ValueError("Workflow file mapping is missing a valid transfer_method") + transfer_method = FileTransferMethod.value_of(transfer_method_value) + match transfer_method: + case FileTransferMethod.TOOL_FILE: + file_dict["tool_file_id"] = file_id + case FileTransferMethod.LOCAL_FILE: + file_dict["upload_file_id"] = file_id + case FileTransferMethod.REMOTE_URL | FileTransferMethod.DATASOURCE_FILE: + pass return file_dict diff --git a/api/core/trigger/debug/event_selectors.py b/api/core/trigger/debug/event_selectors.py index 61d1cd85402..24c1271488c 100644 --- a/api/core/trigger/debug/event_selectors.py +++ b/api/core/trigger/debug/event_selectors.py @@ -8,7 +8,6 @@ from collections.abc import Mapping from datetime import datetime from typing import Any -from graphon.entities.graph_config import NodeConfigDict from pydantic import BaseModel from core.plugin.entities.request import TriggerInvokeEventResponse @@ -28,6 +27,7 @@ from core.trigger.debug.events import ( from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData from core.workflow.nodes.trigger_schedule.entities import ScheduleConfig from extensions.ext_redis import redis_client +from graphon.entities.graph_config import NodeConfigDict from libs.datetime_utils import ensure_naive_utc, naive_utc_now from libs.schedule_utils import calculate_next_run_at from models.model import App diff --git a/api/core/trigger/entities/entities.py b/api/core/trigger/entities/entities.py index 89824481b5a..a922e881cdf 100644 --- a/api/core/trigger/entities/entities.py +++ b/api/core/trigger/entities/entities.py @@ -6,6 +6,7 @@ from typing import Any, Union from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator from core.entities.provider_entities import ProviderConfig +from core.plugin.entities import OAuthSchema from core.plugin.entities.parameters import ( PluginParameterAutoGenerate, PluginParameterOption, @@ -108,13 +109,6 @@ class EventEntity(BaseModel): return v or [] -class OAuthSchema(BaseModel): - client_schema: list[ProviderConfig] = Field(default_factory=list, description="The schema of the OAuth client") - credentials_schema: list[ProviderConfig] = Field( - default_factory=list, description="The schema of the OAuth credentials" - ) - - class SubscriptionConstructor(BaseModel): """ The subscription constructor of the trigger provider diff --git a/api/core/workflow/human_input_compat.py b/api/core/workflow/human_input_adapter.py similarity index 74% rename from api/core/workflow/human_input_compat.py rename to api/core/workflow/human_input_adapter.py index c95516a240b..4b765e6aeab 100644 --- a/api/core/workflow/human_input_compat.py +++ b/api/core/workflow/human_input_adapter.py @@ -1,8 +1,8 @@ -"""Workflow-layer adapters for legacy human-input payload keys. +"""Workflow-to-Graphon adapters for persisted node payloads. -Stored workflow graphs and editor payloads may still use Dify-specific human -input recipient keys. Normalize them here before handing configs to -`graphon` so graph-owned models only see graph-neutral field names. +Stored workflow graphs and editor payloads still contain a small set of +Dify-owned field spellings and value shapes. Adapt them here before handing the +payload to Graphon so Graphon-owned models only see current contracts. """ from __future__ import annotations @@ -14,12 +14,13 @@ from typing import Annotated, Any, ClassVar, Literal import bleach import markdown +from markdown.extensions.tables import TableExtension +from pydantic import AliasChoices, BaseModel, ConfigDict, Field, TypeAdapter + from graphon.enums import BuiltinNodeTypes from graphon.nodes.base.variable_template_parser import VariableTemplateParser from graphon.runtime import VariablePool from graphon.variables.consts import SELECTORS_LENGTH -from markdown.extensions.tables import TableExtension -from pydantic import AliasChoices, BaseModel, ConfigDict, Field, TypeAdapter class DeliveryMethodType(enum.StrEnum): @@ -184,7 +185,7 @@ def _copy_mapping(value: object) -> dict[str, Any] | None: return None -def normalize_human_input_node_data_for_graph(node_data: Mapping[str, Any] | BaseModel) -> dict[str, Any]: +def adapt_human_input_node_data_for_graph(node_data: Mapping[str, Any] | BaseModel) -> dict[str, Any]: normalized = _copy_mapping(node_data) if normalized is None: raise TypeError(f"human-input node data must be a mapping, got {type(node_data).__name__}") @@ -214,7 +215,7 @@ def normalize_human_input_node_data_for_graph(node_data: Mapping[str, Any] | Bas def parse_human_input_delivery_methods(node_data: Mapping[str, Any] | BaseModel) -> list[DeliveryChannelConfig]: - normalized = normalize_human_input_node_data_for_graph(node_data) + normalized = adapt_human_input_node_data_for_graph(node_data) raw_delivery_methods = normalized.get("delivery_methods") if not isinstance(raw_delivery_methods, list): return [] @@ -228,17 +229,20 @@ def is_human_input_webapp_enabled(node_data: Mapping[str, Any] | BaseModel) -> b return False -def normalize_node_data_for_graph(node_data: Mapping[str, Any] | BaseModel) -> dict[str, Any]: +def adapt_node_data_for_graph(node_data: Mapping[str, Any] | BaseModel) -> dict[str, Any]: normalized = _copy_mapping(node_data) if normalized is None: raise TypeError(f"node data must be a mapping, got {type(node_data).__name__}") - if normalized.get("type") != BuiltinNodeTypes.HUMAN_INPUT: - return normalized - return normalize_human_input_node_data_for_graph(normalized) + node_type = normalized.get("type") + if node_type == BuiltinNodeTypes.HUMAN_INPUT: + return adapt_human_input_node_data_for_graph(normalized) + if node_type == BuiltinNodeTypes.TOOL: + return _adapt_tool_node_data_for_graph(normalized) + return normalized -def normalize_node_config_for_graph(node_config: Mapping[str, Any] | BaseModel) -> dict[str, Any]: +def adapt_node_config_for_graph(node_config: Mapping[str, Any] | BaseModel) -> dict[str, Any]: normalized = _copy_mapping(node_config) if normalized is None: raise TypeError(f"node config must be a mapping, got {type(node_config).__name__}") @@ -247,10 +251,65 @@ def normalize_node_config_for_graph(node_config: Mapping[str, Any] | BaseModel) if data_mapping is None: return normalized - normalized["data"] = normalize_node_data_for_graph(data_mapping) + normalized["data"] = adapt_node_data_for_graph(data_mapping) return normalized +def _adapt_tool_node_data_for_graph(node_data: Mapping[str, Any]) -> dict[str, Any]: + normalized = dict(node_data) + + raw_tool_configurations = normalized.get("tool_configurations") + if not isinstance(raw_tool_configurations, Mapping): + return normalized + + existing_tool_parameters = normalized.get("tool_parameters") + normalized_tool_parameters = dict(existing_tool_parameters) if isinstance(existing_tool_parameters, Mapping) else {} + normalized_tool_configurations: dict[str, Any] = {} + found_legacy_tool_inputs = False + + for name, value in raw_tool_configurations.items(): + if not isinstance(value, Mapping): + normalized_tool_configurations[name] = value + continue + + input_type = value.get("type") + input_value = value.get("value") + if input_type not in {"mixed", "variable", "constant"}: + normalized_tool_configurations[name] = value + continue + + found_legacy_tool_inputs = True + normalized_tool_parameters.setdefault(name, dict(value)) + + flattened_value = _flatten_legacy_tool_configuration_value( + input_type=input_type, + input_value=input_value, + ) + if flattened_value is not None: + normalized_tool_configurations[name] = flattened_value + + if not found_legacy_tool_inputs: + return normalized + + normalized["tool_parameters"] = normalized_tool_parameters + normalized["tool_configurations"] = normalized_tool_configurations + return normalized + + +def _flatten_legacy_tool_configuration_value(*, input_type: Any, input_value: Any) -> str | int | float | bool | None: + if input_type in {"mixed", "constant"} and isinstance(input_value, str | int | float | bool): + return input_value + + if ( + input_type == "variable" + and isinstance(input_value, list) + and all(isinstance(item, str) for item in input_value) + ): + return "{{#" + ".".join(input_value) + "#}}" + + return None + + def _normalize_email_recipients(recipients: Mapping[str, Any]) -> dict[str, Any]: normalized = dict(recipients) @@ -290,9 +349,9 @@ __all__ = [ "MemberRecipient", "WebAppDeliveryMethod", "_WebAppDeliveryConfig", + "adapt_human_input_node_data_for_graph", + "adapt_node_config_for_graph", + "adapt_node_data_for_graph", "is_human_input_webapp_enabled", - "normalize_human_input_node_data_for_graph", - "normalize_node_config_for_graph", - "normalize_node_data_for_graph", "parse_human_input_delivery_methods", ] diff --git a/api/core/workflow/node_factory.py b/api/core/workflow/node_factory.py index f6c3aee4c12..de4eae1b22c 100644 --- a/api/core/workflow/node_factory.py +++ b/api/core/workflow/node_factory.py @@ -1,25 +1,10 @@ import importlib import pkgutil from collections.abc import Callable, Iterator, Mapping, MutableMapping +from dataclasses import dataclass from functools import lru_cache from typing import TYPE_CHECKING, Any, cast, final, override -from graphon.entities.base_node_data import BaseNodeData -from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter -from graphon.enums import BuiltinNodeTypes, NodeType -from graphon.file.file_manager import file_manager -from graphon.graph.graph import NodeFactory -from graphon.model_runtime.memory import PromptMessageMemory -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from graphon.nodes.base.node import Node -from graphon.nodes.code.code_node import WorkflowCodeExecutor -from graphon.nodes.code.entities import CodeLanguage -from graphon.nodes.code.limits import CodeNodeLimits -from graphon.nodes.document_extractor import UnstructuredApiConfig -from graphon.nodes.http_request import build_http_request_config -from graphon.nodes.llm.entities import LLMNodeData -from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeData -from graphon.nodes.question_classifier.entities import QuestionClassifierNodeData from sqlalchemy import select from sqlalchemy.orm import Session @@ -30,12 +15,12 @@ from core.helper.code_executor.code_executor import ( CodeExecutionError, CodeExecutor, ) -from core.helper.ssrf_proxy import ssrf_proxy +from core.helper.ssrf_proxy import graphon_ssrf_proxy from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.prompt.entities.advanced_prompt_entities import MemoryConfig from core.trigger.constants import TRIGGER_NODE_TYPES -from core.workflow.human_input_compat import normalize_node_config_for_graph +from core.workflow.human_input_adapter import adapt_node_config_for_graph from core.workflow.node_runtime import ( DifyFileReferenceFactory, DifyHumanInputNodeRuntime, @@ -55,6 +40,22 @@ from core.workflow.nodes.agent.runtime_support import AgentRuntimeSupport from core.workflow.system_variables import SystemVariableKey, get_system_text, system_variable_selector from core.workflow.template_rendering import CodeExecutorJinja2TemplateRenderer from extensions.ext_database import db +from graphon.entities.base_node_data import BaseNodeData +from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter +from graphon.enums import BuiltinNodeTypes, NodeType +from graphon.file.file_manager import file_manager +from graphon.graph.graph import NodeFactory +from graphon.model_runtime.memory import PromptMessageMemory +from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel +from graphon.nodes.base.node import Node +from graphon.nodes.code.code_node import WorkflowCodeExecutor +from graphon.nodes.code.entities import CodeLanguage +from graphon.nodes.code.limits import CodeNodeLimits +from graphon.nodes.document_extractor import UnstructuredApiConfig +from graphon.nodes.http_request import build_http_request_config +from graphon.nodes.llm.entities import LLMNodeData +from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeData +from graphon.nodes.question_classifier.entities import QuestionClassifierNodeData from models.model import Conversation if TYPE_CHECKING: @@ -67,6 +68,31 @@ _START_NODE_TYPES: frozenset[NodeType] = frozenset( ) +@dataclass(frozen=True, slots=True) +class DifyGraphInitContext: + """Explicit graph-init values owned by the workflow layer. + + Dify is gradually removing direct `GraphInitParams` construction from its + production call sites. Keep the translation here until `graphon` exposes an + equivalent explicit API. + """ + + workflow_id: str + graph_config: Mapping[str, Any] + run_context: Mapping[str, Any] + call_depth: int + + def to_graph_init_params(self) -> "GraphInitParams": + from graphon.entities import GraphInitParams + + return GraphInitParams( + workflow_id=self.workflow_id, + graph_config=self.graph_config, + run_context=self.run_context, + call_depth=self.call_depth, + ) + + def _import_node_package(package_name: str, *, excluded_modules: frozenset[str] = frozenset()) -> None: package = importlib.import_module(package_name) for _, module_name, _ in pkgutil.walk_packages(package.__path__, package.__name__ + "."): @@ -95,6 +121,7 @@ def get_node_type_classes_mapping() -> Mapping[NodeType, Mapping[str, type[Node] def resolve_workflow_node_class(*, node_type: NodeType, node_version: str) -> type[Node]: + """Resolve the production node class for the requested type/version.""" node_mapping = get_node_type_classes_mapping().get(node_type) if not node_mapping: raise ValueError(f"No class mapping found for node type: {node_type}") @@ -237,6 +264,19 @@ class DifyNodeFactory(NodeFactory): Default implementation of NodeFactory that resolves node classes from the live registry. """ + @classmethod + def from_graph_init_context( + cls, + *, + graph_init_context: DifyGraphInitContext, + graph_runtime_state: "GraphRuntimeState", + ) -> "DifyNodeFactory": + """Bridge Dify's explicit init context into the current `graphon` API.""" + return cls( + graph_init_params=graph_init_context.to_graph_init_params(), + graph_runtime_state=graph_runtime_state, + ) + def __init__( self, graph_init_params: "GraphInitParams", @@ -258,7 +298,7 @@ class DifyNodeFactory(NodeFactory): ) self._jinja2_template_renderer = CodeExecutorJinja2TemplateRenderer() self._template_transform_max_output_length = dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH - self._http_request_http_client = ssrf_proxy + self._http_request_http_client = graphon_ssrf_proxy self._bound_tool_file_manager_factory = lambda: DifyToolFileManager( self._dify_context, conversation_id_getter=self._conversation_id, @@ -325,10 +365,14 @@ class DifyNodeFactory(NodeFactory): (including pydantic ValidationError, which subclasses ValueError), if node type is unknown, or if no implementation exists for the resolved version """ - typed_node_config = NodeConfigDictAdapter.validate_python(normalize_node_config_for_graph(node_config)) + typed_node_config = NodeConfigDictAdapter.validate_python(adapt_node_config_for_graph(node_config)) node_id = typed_node_config["id"] node_data = typed_node_config["data"] node_class = self._resolve_node_class(node_type=node_data.type, node_version=str(node_data.version)) + # Graph configs are initially validated against permissive shared node data. + # Re-validate using the resolved node class so workflow-local node schemas + # stay explicit and constructors receive the concrete typed payload. + resolved_node_data = self._validate_resolved_node_data(node_class, node_data) node_type = node_data.type node_init_kwargs_factories: Mapping[NodeType, Callable[[], dict[str, object]]] = { BuiltinNodeTypes.CODE: lambda: { @@ -352,7 +396,7 @@ class DifyNodeFactory(NodeFactory): }, BuiltinNodeTypes.LLM: lambda: self._build_llm_compatible_node_init_kwargs( node_class=node_class, - node_data=node_data, + node_data=resolved_node_data, wrap_model_instance=True, include_http_client=True, include_llm_file_saver=True, @@ -366,7 +410,7 @@ class DifyNodeFactory(NodeFactory): }, BuiltinNodeTypes.QUESTION_CLASSIFIER: lambda: self._build_llm_compatible_node_init_kwargs( node_class=node_class, - node_data=node_data, + node_data=resolved_node_data, wrap_model_instance=True, include_http_client=True, include_llm_file_saver=True, @@ -376,7 +420,7 @@ class DifyNodeFactory(NodeFactory): ), BuiltinNodeTypes.PARAMETER_EXTRACTOR: lambda: self._build_llm_compatible_node_init_kwargs( node_class=node_class, - node_data=node_data, + node_data=resolved_node_data, wrap_model_instance=True, include_http_client=False, include_llm_file_saver=False, @@ -397,8 +441,8 @@ class DifyNodeFactory(NodeFactory): } node_init_kwargs = node_init_kwargs_factories.get(node_type, lambda: {})() return node_class( - id=node_id, - config=typed_node_config, + node_id=node_id, + config=resolved_node_data, graph_init_params=self.graph_init_params, graph_runtime_state=self.graph_runtime_state, **node_init_kwargs, @@ -409,7 +453,10 @@ class DifyNodeFactory(NodeFactory): """ Re-validate the permissive graph payload with the concrete NodeData model declared by the resolved node class. """ - return node_class.validate_node_data(node_data) + validate_node_data = getattr(node_class, "validate_node_data", None) + if callable(validate_node_data): + return cast("BaseNodeData", validate_node_data(node_data)) + return node_data @staticmethod def _resolve_node_class(*, node_type: NodeType, node_version: str) -> type[Node]: diff --git a/api/core/workflow/node_runtime.py b/api/core/workflow/node_runtime.py index 19cb3a7b0ab..b8725853c4c 100644 --- a/api/core/workflow/node_runtime.py +++ b/api/core/workflow/node_runtime.py @@ -2,38 +2,8 @@ from __future__ import annotations from collections.abc import Callable, Generator, Mapping, Sequence from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any, Literal, cast, overload -from graphon.file import FileTransferMethod, FileType -from graphon.model_runtime.entities import LLMMode -from graphon.model_runtime.entities.llm_entities import ( - LLMResult, - LLMResultChunk, - LLMResultChunkWithStructuredOutput, - LLMResultWithStructuredOutput, - LLMUsage, -) -from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool -from graphon.model_runtime.entities.model_entities import AIModelEntity -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from graphon.nodes.human_input.entities import HumanInputNodeData -from graphon.nodes.llm.runtime_protocols import ( - PreparedLLMProtocol, - PromptMessageSerializerProtocol, - RetrieverAttachmentLoaderProtocol, -) -from graphon.nodes.protocols import FileReferenceFactoryProtocol, HttpClientProtocol, ToolFileManagerProtocol -from graphon.nodes.runtime import ( - HumanInputFormStateProtocol, - HumanInputNodeRuntimeProtocol, - ToolNodeRuntimeProtocol, -) -from graphon.nodes.tool.exc import ToolNodeError, ToolRuntimeInvocationError, ToolRuntimeResolutionError -from graphon.nodes.tool_runtime_entities import ( - ToolRuntimeHandle, - ToolRuntimeMessage, - ToolRuntimeParameter, -) from sqlalchemy import select from sqlalchemy.orm import Session @@ -60,11 +30,41 @@ from core.tools.utils.message_transformer import ToolFileMessageTransformer from core.workflow.file_reference import build_file_reference from extensions.ext_database import db from factories import file_factory +from graphon.file import FileTransferMethod, FileType +from graphon.model_runtime.entities import LLMMode +from graphon.model_runtime.entities.llm_entities import ( + LLMResult, + LLMResultChunk, + LLMResultChunkWithStructuredOutput, + LLMResultWithStructuredOutput, + LLMUsage, +) +from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool +from graphon.model_runtime.entities.model_entities import AIModelEntity +from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel +from graphon.nodes.human_input.entities import HumanInputNodeData +from graphon.nodes.llm.runtime_protocols import ( + PreparedLLMProtocol, + PromptMessageSerializerProtocol, + RetrieverAttachmentLoaderProtocol, +) +from graphon.nodes.protocols import FileReferenceFactoryProtocol, HttpClientProtocol, ToolFileManagerProtocol +from graphon.nodes.runtime import ( + HumanInputFormStateProtocol, + HumanInputNodeRuntimeProtocol, + ToolNodeRuntimeProtocol, +) +from graphon.nodes.tool.exc import ToolNodeError, ToolRuntimeInvocationError, ToolRuntimeResolutionError +from graphon.nodes.tool_runtime_entities import ( + ToolRuntimeHandle, + ToolRuntimeMessage, + ToolRuntimeParameter, +) from models.dataset import SegmentAttachmentBinding from models.model import UploadFile from services.tools.builtin_tools_manage_service import BuiltinToolManageService -from .human_input_compat import ( +from .human_input_adapter import ( BoundRecipient, DeliveryChannelConfig, DeliveryMethodType, @@ -76,13 +76,12 @@ from .human_input_compat import ( from .system_variables import SystemVariableKey, get_system_text if TYPE_CHECKING: + from core.tools.__base.tool import Tool + from core.tools.entities.tool_entities import ToolInvokeMessage as CoreToolInvokeMessage from graphon.file import File from graphon.nodes.llm.file_saver import LLMFileSaver from graphon.nodes.tool.entities import ToolNodeData - from core.tools.__base.tool import Tool - from core.tools.entities.tool_entities import ToolInvokeMessage as CoreToolInvokeMessage - _file_access_controller = DatabaseFileAccessController() @@ -174,6 +173,28 @@ class DifyPreparedLLM(PreparedLLMProtocol): def get_llm_num_tokens(self, prompt_messages: Sequence[PromptMessage]) -> int: return self._model_instance.get_llm_num_tokens(prompt_messages) + @overload + def invoke_llm( + self, + *, + prompt_messages: Sequence[PromptMessage], + model_parameters: Mapping[str, Any], + tools: Sequence[PromptMessageTool] | None, + stop: Sequence[str] | None, + stream: Literal[False], + ) -> LLMResult: ... + + @overload + def invoke_llm( + self, + *, + prompt_messages: Sequence[PromptMessage], + model_parameters: Mapping[str, Any], + tools: Sequence[PromptMessageTool] | None, + stop: Sequence[str] | None, + stream: Literal[True], + ) -> Generator[LLMResultChunk, None, None]: ... + def invoke_llm( self, *, @@ -191,6 +212,28 @@ class DifyPreparedLLM(PreparedLLMProtocol): stream=stream, ) + @overload + def invoke_llm_with_structured_output( + self, + *, + prompt_messages: Sequence[PromptMessage], + json_schema: Mapping[str, Any], + model_parameters: Mapping[str, Any], + stop: Sequence[str] | None, + stream: Literal[False], + ) -> LLMResultWithStructuredOutput: ... + + @overload + def invoke_llm_with_structured_output( + self, + *, + prompt_messages: Sequence[PromptMessage], + json_schema: Mapping[str, Any], + model_parameters: Mapping[str, Any], + stop: Sequence[str] | None, + stream: Literal[True], + ) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ... + def invoke_llm_with_structured_output( self, *, diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py index bfd5536e4a7..68a24e86b1e 100644 --- a/api/core/workflow/nodes/agent/agent_node.py +++ b/api/core/workflow/nodes/agent/agent_node.py @@ -3,15 +3,13 @@ from __future__ import annotations from collections.abc import Generator, Mapping, Sequence from typing import TYPE_CHECKING, Any -from graphon.entities.graph_config import NodeConfigDict +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext +from core.workflow.system_variables import SystemVariableKey, get_system_text from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus from graphon.node_events import NodeEventBase, NodeRunResult, StreamCompletedEvent from graphon.nodes.base.node import Node from graphon.nodes.base.variable_template_parser import VariableTemplateParser -from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext -from core.workflow.system_variables import SystemVariableKey, get_system_text - from .entities import AgentNodeData from .exceptions import ( AgentInvocationError, @@ -36,18 +34,18 @@ class AgentNode(Node[AgentNodeData]): def __init__( self, - id: str, - config: NodeConfigDict, + node_id: str, + config: AgentNodeData, + *, graph_init_params: GraphInitParams, graph_runtime_state: GraphRuntimeState, - *, strategy_resolver: AgentStrategyResolver, presentation_provider: AgentStrategyPresentationProvider, runtime_support: AgentRuntimeSupport, message_transformer: AgentMessageTransformer, ) -> None: super().__init__( - id=id, + node_id=node_id, config=config, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, diff --git a/api/core/workflow/nodes/agent/entities.py b/api/core/workflow/nodes/agent/entities.py index c52aad150bb..51452c29a3f 100644 --- a/api/core/workflow/nodes/agent/entities.py +++ b/api/core/workflow/nodes/agent/entities.py @@ -1,12 +1,12 @@ from enum import IntEnum, StrEnum, auto from typing import Any, Literal, Union -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import BuiltinNodeTypes, NodeType from pydantic import BaseModel from core.prompt.entities.advanced_prompt_entities import MemoryConfig from core.tools.entities.tool_entities import ToolSelector +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import BuiltinNodeTypes, NodeType class AgentNodeData(BaseNodeData): diff --git a/api/core/workflow/nodes/agent/message_transformer.py b/api/core/workflow/nodes/agent/message_transformer.py index db74590ed76..f44681377dc 100644 --- a/api/core/workflow/nodes/agent/message_transformer.py +++ b/api/core/workflow/nodes/agent/message_transformer.py @@ -3,6 +3,14 @@ from __future__ import annotations from collections.abc import Generator, Mapping from typing import Any, cast +from sqlalchemy import select +from sqlalchemy.orm import Session + +from core.app.file_access import DatabaseFileAccessController +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.utils.message_transformer import ToolFileMessageTransformer +from extensions.ext_database import db +from factories import file_factory from graphon.enums import BuiltinNodeTypes, NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from graphon.file import File, FileTransferMethod, get_file_type_by_mime_type from graphon.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata @@ -15,14 +23,6 @@ from graphon.node_events import ( StreamCompletedEvent, ) from graphon.variables.segments import ArrayFileSegment -from sqlalchemy import select -from sqlalchemy.orm import Session - -from core.app.file_access import DatabaseFileAccessController -from core.tools.entities.tool_entities import ToolInvokeMessage -from core.tools.utils.message_transformer import ToolFileMessageTransformer -from extensions.ext_database import db -from factories import file_factory from models import ToolFile from services.tools.builtin_tools_manage_service import BuiltinToolManageService diff --git a/api/core/workflow/nodes/agent/runtime_support.py b/api/core/workflow/nodes/agent/runtime_support.py index be50edbc4d4..a872774c98c 100644 --- a/api/core/workflow/nodes/agent/runtime_support.py +++ b/api/core/workflow/nodes/agent/runtime_support.py @@ -4,8 +4,6 @@ import json from collections.abc import Sequence from typing import Any, cast -from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelType -from graphon.runtime import VariablePool from packaging.version import Version from pydantic import ValidationError from sqlalchemy import select @@ -21,6 +19,8 @@ from core.tools.entities.tool_entities import ToolIdentity, ToolParameter, ToolP from core.tools.tool_manager import ToolManager from core.workflow.system_variables import SystemVariableKey, get_system_text from extensions.ext_database import db +from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelType +from graphon.runtime import VariablePool from models.model import Conversation from .entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGenerated diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py index d9247b25932..f3006c42422 100644 --- a/api/core/workflow/nodes/datasource/datasource_node.py +++ b/api/core/workflow/nodes/datasource/datasource_node.py @@ -1,7 +1,12 @@ from collections.abc import Generator, Mapping, Sequence from typing import TYPE_CHECKING, Any -from graphon.entities.graph_config import NodeConfigDict +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext +from core.datasource.datasource_manager import DatasourceManager +from core.datasource.entities.datasource_entities import DatasourceProviderType +from core.plugin.impl.exc import PluginDaemonClientSideError +from core.workflow.file_reference import resolve_file_record_id +from core.workflow.system_variables import SystemVariableKey, get_system_segment from graphon.enums import ( BuiltinNodeTypes, NodeExecutionType, @@ -12,13 +17,6 @@ from graphon.node_events import NodeRunResult, StreamCompletedEvent from graphon.nodes.base.node import Node from graphon.nodes.base.variable_template_parser import VariableTemplateParser -from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext -from core.datasource.datasource_manager import DatasourceManager -from core.datasource.entities.datasource_entities import DatasourceProviderType -from core.plugin.impl.exc import PluginDaemonClientSideError -from core.workflow.file_reference import resolve_file_record_id -from core.workflow.system_variables import SystemVariableKey, get_system_segment - from .entities import DatasourceNodeData, DatasourceParameter, OnlineDriveDownloadFileParam from .exc import DatasourceNodeError @@ -37,13 +35,14 @@ class DatasourceNode(Node[DatasourceNodeData]): def __init__( self, - id: str, - config: NodeConfigDict, + node_id: str, + config: DatasourceNodeData, + *, graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", - ): + ) -> None: super().__init__( - id=id, + node_id=node_id, config=config, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, diff --git a/api/core/workflow/nodes/datasource/entities.py b/api/core/workflow/nodes/datasource/entities.py index cad32f8d5bd..28966f2392c 100644 --- a/api/core/workflow/nodes/datasource/entities.py +++ b/api/core/workflow/nodes/datasource/entities.py @@ -1,9 +1,10 @@ from typing import Any, Literal, Union +from pydantic import BaseModel, field_validator +from pydantic_core.core_schema import ValidationInfo + from graphon.entities.base_node_data import BaseNodeData from graphon.enums import BuiltinNodeTypes, NodeType -from pydantic import BaseModel, field_validator -from pydantic_core.core_schema import ValidationInfo class DatasourceEntity(BaseModel): diff --git a/api/core/workflow/nodes/knowledge_index/entities.py b/api/core/workflow/nodes/knowledge_index/entities.py index cba6c12dca0..260881e49ca 100644 --- a/api/core/workflow/nodes/knowledge_index/entities.py +++ b/api/core/workflow/nodes/knowledge_index/entities.py @@ -1,65 +1,13 @@ -from typing import Literal, Union +from typing import Union -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import NodeType from pydantic import BaseModel +from core.rag.entities import RerankingModelConfig, WeightedScoreConfig from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE - - -class RerankingModelConfig(BaseModel): - """ - Reranking Model Config. - """ - - reranking_provider_name: str - reranking_model_name: str - - -class VectorSetting(BaseModel): - """ - Vector Setting. - """ - - vector_weight: float - embedding_provider_name: str - embedding_model_name: str - - -class KeywordSetting(BaseModel): - """ - Keyword Setting. - """ - - keyword_weight: float - - -class WeightedScoreConfig(BaseModel): - """ - Weighted score Config. - """ - - vector_setting: VectorSetting - keyword_setting: KeywordSetting - - -class EmbeddingSetting(BaseModel): - """ - Embedding Setting. - """ - - embedding_provider_name: str - embedding_model_name: str - - -class EconomySetting(BaseModel): - """ - Economy Setting. - """ - - keyword_number: int +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import NodeType class RetrievalSetting(BaseModel): @@ -77,16 +25,6 @@ class RetrievalSetting(BaseModel): weights: WeightedScoreConfig | None = None -class IndexMethod(BaseModel): - """ - Knowledge Index Setting. - """ - - indexing_technique: Literal["high_quality", "economy"] - embedding_setting: EmbeddingSetting - economy_setting: EconomySetting - - class FileInfo(BaseModel): """ File Info. diff --git a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py index bb72fe38816..9c1b7ab2c42 100644 --- a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py +++ b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py @@ -2,17 +2,15 @@ import logging from collections.abc import Mapping from typing import TYPE_CHECKING, Any -from graphon.entities.graph_config import NodeConfigDict -from graphon.enums import NodeExecutionType, WorkflowNodeExecutionStatus -from graphon.node_events import NodeRunResult -from graphon.nodes.base.node import Node -from graphon.nodes.base.template import Template - from core.rag.index_processor.index_processor import IndexProcessor from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict from core.rag.summary_index.summary_index import SummaryIndex from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE from core.workflow.system_variables import SystemVariableKey, get_system_segment, get_system_text +from graphon.enums import NodeExecutionType, WorkflowNodeExecutionStatus +from graphon.node_events import NodeRunResult +from graphon.nodes.base.node import Node +from graphon.nodes.base.template import Template from .entities import KnowledgeIndexNodeData from .exc import ( @@ -33,12 +31,18 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]): def __init__( self, - id: str, - config: NodeConfigDict, + node_id: str, + config: KnowledgeIndexNodeData, + *, graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", ) -> None: - super().__init__(id, config, graph_init_params, graph_runtime_state) + super().__init__( + node_id=node_id, + config=config, + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + ) self.index_processor = IndexProcessor() self.summary_index_service = SummaryIndex() diff --git a/api/core/workflow/nodes/knowledge_index/protocols.py b/api/core/workflow/nodes/knowledge_index/protocols.py index bb521230820..d04e79c2a8c 100644 --- a/api/core/workflow/nodes/knowledge_index/protocols.py +++ b/api/core/workflow/nodes/knowledge_index/protocols.py @@ -1,9 +1,19 @@ from collections.abc import Mapping -from typing import Any, Protocol +from typing import Any, Protocol, TypedDict from pydantic import BaseModel, Field +class IndexingResultDict(TypedDict): + dataset_id: str + dataset_name: str + batch: Any + document_id: str + document_name: str + created_at: float + display_status: str + + class PreviewItem(BaseModel): content: str | None = Field(default=None) child_chunks: list[str] | None = Field(default=None) @@ -33,15 +43,20 @@ class IndexProcessorProtocol(Protocol): original_document_id: str, chunks: Mapping[str, Any], batch: Any, - summary_index_setting: dict | None = None, - ) -> dict[str, Any]: ... + summary_index_setting: dict[str, Any] | None = None, + ) -> IndexingResultDict: ... def get_preview_output( - self, chunks: Any, dataset_id: str, document_id: str, chunk_structure: str, summary_index_setting: dict | None + self, + chunks: Any, + dataset_id: str, + document_id: str, + chunk_structure: str, + summary_index_setting: dict[str, Any] | None, ) -> Preview: ... class SummaryIndexServiceProtocol(Protocol): def generate_and_vectorize_summary( - self, dataset_id: str, document_id: str, is_preview: bool, summary_index_setting: dict | None = None + self, dataset_id: str, document_id: str, is_preview: bool, summary_index_setting: dict[str, Any] | None = None ) -> None: ... diff --git a/api/core/workflow/nodes/knowledge_retrieval/entities.py b/api/core/workflow/nodes/knowledge_retrieval/entities.py index b1fa8593efe..3825f526a2d 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/entities.py +++ b/api/core/workflow/nodes/knowledge_retrieval/entities.py @@ -1,46 +1,13 @@ -from collections.abc import Sequence from typing import Literal +from pydantic import BaseModel, Field + +from core.rag.entities import Condition, MetadataFilteringCondition, RerankingModelConfig, WeightedScoreConfig from graphon.entities.base_node_data import BaseNodeData from graphon.enums import BuiltinNodeTypes, NodeType from graphon.nodes.llm.entities import ModelConfig, VisionConfig -from pydantic import BaseModel, Field - -class RerankingModelConfig(BaseModel): - """ - Reranking Model Config. - """ - - provider: str - model: str - - -class VectorSetting(BaseModel): - """ - Vector Setting. - """ - - vector_weight: float - embedding_provider_name: str - embedding_model_name: str - - -class KeywordSetting(BaseModel): - """ - Keyword Setting. - """ - - keyword_weight: float - - -class WeightedScoreConfig(BaseModel): - """ - Weighted score Config. - """ - - vector_setting: VectorSetting - keyword_setting: KeywordSetting +__all__ = ["Condition"] class MultipleRetrievalConfig(BaseModel): @@ -64,50 +31,6 @@ class SingleRetrievalConfig(BaseModel): model: ModelConfig -SupportedComparisonOperator = Literal[ - # for string or array - "contains", - "not contains", - "start with", - "end with", - "is", - "is not", - "empty", - "not empty", - "in", - "not in", - # for number - "=", - "≠", - ">", - "<", - "≥", - "≤", - # for time - "before", - "after", -] - - -class Condition(BaseModel): - """ - Condition detail - """ - - name: str - comparison_operator: SupportedComparisonOperator - value: str | Sequence[str] | None | int | float = None - - -class MetadataFilteringCondition(BaseModel): - """ - Metadata Filtering Condition. - """ - - logical_operator: Literal["and", "or"] | None = "and" - conditions: list[Condition] | None = Field(default=None, deprecated=True) - - class KnowledgeRetrievalNodeData(BaseNodeData): """ Knowledge retrieval Node Data. diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 13624b27b37..25f73e446df 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -8,8 +8,12 @@ import logging from collections.abc import Mapping, Sequence from typing import TYPE_CHECKING, Any, Literal +from core.app.app_config.entities import DatasetRetrieveConfigEntity +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext +from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict +from core.rag.retrieval.dataset_retrieval import DatasetRetrieval +from core.workflow.file_reference import parse_file_reference from graphon.entities import GraphInitParams -from graphon.entities.graph_config import NodeConfigDict from graphon.enums import ( BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey, @@ -27,12 +31,6 @@ from graphon.variables import ( ) from graphon.variables.segments import ArrayObjectSegment -from core.app.app_config.entities import DatasetRetrieveConfigEntity -from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext -from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict -from core.rag.retrieval.dataset_retrieval import DatasetRetrieval -from core.workflow.file_reference import parse_file_reference - from .entities import ( Condition, KnowledgeRetrievalNodeData, @@ -51,6 +49,18 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +def _normalize_metadata_filter_scalar(value: object) -> str | int | float | None: + if value is None or isinstance(value, (str, float)): + return value + if isinstance(value, int) and not isinstance(value, bool): + return value + return str(value) + + +def _normalize_metadata_filter_sequence_item(value: object) -> str: + return value if isinstance(value, str) else str(value) + + class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeData]): node_type = BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL @@ -60,13 +70,14 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD def __init__( self, - id: str, - config: NodeConfigDict, + node_id: str, + config: KnowledgeRetrievalNodeData, + *, graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", - ): + ) -> None: super().__init__( - id=id, + node_id=node_id, config=config, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, @@ -283,18 +294,21 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD resolved_conditions: list[Condition] = [] for cond in conditions.conditions or []: value = cond.value + resolved_value: str | Sequence[str] | int | float | None if isinstance(value, str): segment_group = variable_pool.convert_template(value) if len(segment_group.value) == 1: - resolved_value = segment_group.value[0].to_object() + resolved_value = _normalize_metadata_filter_scalar(segment_group.value[0].to_object()) else: resolved_value = segment_group.text elif isinstance(value, Sequence) and all(isinstance(v, str) for v in value): - resolved_values = [] - for v in value: # type: ignore + resolved_values: list[str] = [] + for v in value: segment_group = variable_pool.convert_template(v) if len(segment_group.value) == 1: - resolved_values.append(segment_group.value[0].to_object()) + resolved_values.append( + _normalize_metadata_filter_sequence_item(segment_group.value[0].to_object()) + ) else: resolved_values.append(segment_group.text) resolved_value = resolved_values diff --git a/api/core/workflow/nodes/knowledge_retrieval/retrieval.py b/api/core/workflow/nodes/knowledge_retrieval/retrieval.py index 39e2008a2ca..ea45dcf5c20 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/retrieval.py +++ b/api/core/workflow/nodes/knowledge_retrieval/retrieval.py @@ -1,10 +1,10 @@ from typing import Any, Literal, Protocol -from graphon.model_runtime.entities import LLMUsage -from graphon.nodes.llm.entities import ModelConfig from pydantic import BaseModel, Field from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict +from graphon.model_runtime.entities import LLMUsage +from graphon.nodes.llm.entities import ModelConfig from .entities import MetadataFilteringCondition diff --git a/api/core/workflow/nodes/trigger_plugin/entities.py b/api/core/workflow/nodes/trigger_plugin/entities.py index bf5be2379af..23ed2cd4088 100644 --- a/api/core/workflow/nodes/trigger_plugin/entities.py +++ b/api/core/workflow/nodes/trigger_plugin/entities.py @@ -1,12 +1,12 @@ from collections.abc import Mapping from typing import Any, Literal, Union -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import NodeType from pydantic import BaseModel, Field, ValidationInfo, field_validator from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE from core.trigger.entities.entities import EventParameter +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import NodeType from .exc import TriggerEventParameterError diff --git a/api/core/workflow/nodes/trigger_plugin/trigger_event_node.py b/api/core/workflow/nodes/trigger_plugin/trigger_event_node.py index e50de11bb90..c848a862558 100644 --- a/api/core/workflow/nodes/trigger_plugin/trigger_event_node.py +++ b/api/core/workflow/nodes/trigger_plugin/trigger_event_node.py @@ -1,13 +1,12 @@ from collections.abc import Mapping from typing import Any +from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE +from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID from graphon.enums import NodeExecutionType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from graphon.node_events import NodeRunResult from graphon.nodes.base.node import Node -from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE -from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID - from .entities import TriggerEventNodeData diff --git a/api/core/workflow/nodes/trigger_schedule/entities.py b/api/core/workflow/nodes/trigger_schedule/entities.py index f14ca893c9e..683c8d420f2 100644 --- a/api/core/workflow/nodes/trigger_schedule/entities.py +++ b/api/core/workflow/nodes/trigger_schedule/entities.py @@ -1,10 +1,10 @@ -from typing import Literal, Union +from typing import Any, Literal, Union -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import NodeType from pydantic import BaseModel, Field from core.trigger.constants import TRIGGER_SCHEDULE_NODE_TYPE +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import NodeType class TriggerScheduleNodeData(BaseNodeData): @@ -16,7 +16,7 @@ class TriggerScheduleNodeData(BaseNodeData): mode: str = Field(default="visual", description="Schedule mode: visual or cron") frequency: str | None = Field(default=None, description="Frequency for visual mode: hourly, daily, weekly, monthly") cron_expression: str | None = Field(default=None, description="Cron expression for cron mode") - visual_config: dict | None = Field(default=None, description="Visual configuration details") + visual_config: dict[str, Any] | None = Field(default=None, description="Visual configuration details") timezone: str = Field(default="UTC", description="Timezone for schedule execution") diff --git a/api/core/workflow/nodes/trigger_schedule/trigger_schedule_node.py b/api/core/workflow/nodes/trigger_schedule/trigger_schedule_node.py index a9753ab387d..b46cc76a6e8 100644 --- a/api/core/workflow/nodes/trigger_schedule/trigger_schedule_node.py +++ b/api/core/workflow/nodes/trigger_schedule/trigger_schedule_node.py @@ -1,11 +1,10 @@ from collections.abc import Mapping -from graphon.enums import NodeExecutionType, WorkflowNodeExecutionStatus -from graphon.node_events import NodeRunResult -from graphon.nodes.base.node import Node - from core.trigger.constants import TRIGGER_SCHEDULE_NODE_TYPE from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID +from graphon.enums import NodeExecutionType, WorkflowNodeExecutionStatus +from graphon.node_events import NodeRunResult +from graphon.nodes.base.node import Node from .entities import TriggerScheduleNodeData diff --git a/api/core/workflow/nodes/trigger_webhook/entities.py b/api/core/workflow/nodes/trigger_webhook/entities.py index a30f877e4be..b2610394480 100644 --- a/api/core/workflow/nodes/trigger_webhook/entities.py +++ b/api/core/workflow/nodes/trigger_webhook/entities.py @@ -1,12 +1,12 @@ from collections.abc import Sequence from enum import StrEnum -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import NodeType -from graphon.variables.types import SegmentType from pydantic import BaseModel, Field, field_validator from core.trigger.constants import TRIGGER_WEBHOOK_NODE_TYPE +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import NodeType +from graphon.variables.types import SegmentType _WEBHOOK_HEADER_ALLOWED_TYPES: frozenset[SegmentType] = frozenset((SegmentType.STRING,)) diff --git a/api/core/workflow/nodes/trigger_webhook/node.py b/api/core/workflow/nodes/trigger_webhook/node.py index ebaac939345..13c4f05bfd5 100644 --- a/api/core/workflow/nodes/trigger_webhook/node.py +++ b/api/core/workflow/nodes/trigger_webhook/node.py @@ -2,6 +2,10 @@ import logging from collections.abc import Mapping from typing import Any +from core.trigger.constants import TRIGGER_WEBHOOK_NODE_TYPE +from core.workflow.file_reference import resolve_file_record_id +from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID +from factories.variable_factory import build_segment_with_type from graphon.enums import NodeExecutionType, WorkflowNodeExecutionStatus from graphon.file import FileTransferMethod from graphon.node_events import NodeRunResult @@ -10,11 +14,6 @@ from graphon.nodes.protocols import FileReferenceFactoryProtocol from graphon.variables.types import SegmentType from graphon.variables.variables import FileVariable -from core.trigger.constants import TRIGGER_WEBHOOK_NODE_TYPE -from core.workflow.file_reference import resolve_file_record_id -from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID -from factories.variable_factory import build_segment_with_type - from .entities import ContentType, WebhookData logger = logging.getLogger(__name__) @@ -29,7 +28,7 @@ class TriggerWebhookNode(Node[WebhookData]): def post_init(self) -> None: from core.workflow.node_runtime import DifyFileReferenceFactory - self._file_reference_factory = DifyFileReferenceFactory(self.graph_init_params.run_context) + self._file_reference_factory = DifyFileReferenceFactory(self.run_context) @classmethod def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: @@ -75,7 +74,7 @@ class TriggerWebhookNode(Node[WebhookData]): outputs=outputs, ) - def generate_file_var(self, param_name: str, file: dict): + def generate_file_var(self, param_name: str, file: dict[str, Any]): file_id = resolve_file_record_id(file.get("reference") or file.get("related_id")) transfer_method_value = file.get("transfer_method") if transfer_method_value: @@ -147,7 +146,7 @@ class TriggerWebhookNode(Node[WebhookData]): outputs[param_name] = str(webhook_data.get("body", {}).get("raw", "")) continue elif self.node_data.content_type == ContentType.BINARY: - raw_data: dict = webhook_data.get("body", {}).get("raw", {}) + raw_data: dict[str, Any] = webhook_data.get("body", {}).get("raw", {}) file_var = self.generate_file_var(param_name, raw_data) if file_var: outputs[param_name] = file_var @@ -155,24 +154,25 @@ class TriggerWebhookNode(Node[WebhookData]): outputs[param_name] = raw_data continue - if param_type == SegmentType.FILE: - # Get File object (already processed by webhook controller) - files = webhook_data.get("files", {}) - if files and isinstance(files, dict): - file = files.get(param_name) - if file and isinstance(file, dict): - file_var = self.generate_file_var(param_name, file) - if file_var: - outputs[param_name] = file_var + match param_type: + case SegmentType.FILE: + # Get File object (already processed by webhook controller) + files = webhook_data.get("files", {}) + if files and isinstance(files, dict): + file = files.get(param_name) + if file and isinstance(file, dict): + file_var = self.generate_file_var(param_name, file) + if file_var: + outputs[param_name] = file_var + else: + outputs[param_name] = files else: outputs[param_name] = files else: outputs[param_name] = files - else: - outputs[param_name] = files - else: - # Get regular body parameter - outputs[param_name] = webhook_data.get("body", {}).get(param_name) + case _: + # Get regular body parameter + outputs[param_name] = webhook_data.get("body", {}).get(param_name) # Include raw webhook data for debugging/advanced use outputs["_webhook_raw"] = webhook_data diff --git a/api/core/workflow/template_rendering.py b/api/core/workflow/template_rendering.py index d51cfadd098..b4ffb37549d 100644 --- a/api/core/workflow/template_rendering.py +++ b/api/core/workflow/template_rendering.py @@ -3,11 +3,10 @@ from __future__ import annotations from collections.abc import Mapping from typing import Any +from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor from graphon.nodes.code.entities import CodeLanguage from graphon.template_rendering import Jinja2TemplateRenderer, TemplateRenderError -from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor - class CodeExecutorJinja2TemplateRenderer(Jinja2TemplateRenderer): """Sandbox-backed Jinja2 renderer for workflow-owned node composition.""" diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index 2346a95d6a8..4e2f603e5be 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -1,8 +1,31 @@ import logging import time from collections.abc import Generator, Mapping, Sequence -from typing import Any +from typing import Any, TypedDict +from configs import dify_config +from context import capture_current_context +from core.app.apps.exc import GenerateTaskStoppedError +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context +from core.app.file_access import DatabaseFileAccessController +from core.app.workflow.layers.llm_quota import LLMQuotaLayer +from core.app.workflow.layers.observability import ObservabilityLayer +from core.workflow.node_factory import ( + DifyGraphInitContext, + DifyNodeFactory, + is_start_node_type, + resolve_workflow_node_class, +) +from core.workflow.system_variables import ( + default_system_variables, + get_node_creation_preload_selectors, + inject_default_system_variable_mappings, + preload_node_creation_variables, +) +from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool +from core.workflow.variable_prefixes import ENVIRONMENT_VARIABLE_NODE_ID +from extensions.otel.runtime import is_instrument_flag_enabled +from factories import file_factory from graphon.entities import GraphInitParams from graphon.entities.graph_config import NodeConfigDictAdapter from graphon.errors import WorkflowNodeRunFailedError @@ -16,25 +39,6 @@ from graphon.nodes import BuiltinNodeTypes from graphon.nodes.base.node import Node from graphon.runtime import ChildGraphNotFoundError, GraphRuntimeState, VariablePool from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool - -from configs import dify_config -from context import capture_current_context -from core.app.apps.exc import GenerateTaskStoppedError -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context -from core.app.file_access import DatabaseFileAccessController -from core.app.workflow.layers.llm_quota import LLMQuotaLayer -from core.app.workflow.layers.observability import ObservabilityLayer -from core.workflow.node_factory import DifyNodeFactory, is_start_node_type, resolve_workflow_node_class -from core.workflow.system_variables import ( - default_system_variables, - get_node_creation_preload_selectors, - inject_default_system_variable_mappings, - preload_node_creation_variables, -) -from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool -from core.workflow.variable_prefixes import ENVIRONMENT_VARIABLE_NODE_ID -from extensions.otel.runtime import is_instrument_flag_enabled -from factories import file_factory from models.workflow import Workflow logger = logging.getLogger(__name__) @@ -107,6 +111,26 @@ class _WorkflowChildEngineBuilder: return child_engine +class _NodeConfigDict(TypedDict): + id: str + width: int + height: int + type: str + data: dict[str, Any] + + +class _EdgeConfigDict(TypedDict): + source: str + target: str + sourceHandle: str + targetHandle: str + + +class SingleNodeGraphDict(TypedDict): + nodes: list[_NodeConfigDict] + edges: list[_EdgeConfigDict] + + class WorkflowEntry: def __init__( self, @@ -231,17 +255,18 @@ class WorkflowEntry: node_version = str(node_config_data.version) node_cls = resolve_workflow_node_class(node_type=node_type, node_version=node_version) - # init graph init params and runtime state - graph_init_params = GraphInitParams( + # init graph context and runtime state + run_context = build_dify_run_context( + tenant_id=workflow.tenant_id, + app_id=workflow.app_id, + user_id=user_id, + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + ) + graph_init_context = DifyGraphInitContext( workflow_id=workflow.id, graph_config=workflow.graph_dict, - run_context=build_dify_run_context( - tenant_id=workflow.tenant_id, - app_id=workflow.app_id, - user_id=user_id, - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, - ), + run_context=run_context, call_depth=0, ) graph_runtime_state = GraphRuntimeState( @@ -293,8 +318,8 @@ class WorkflowEntry: ) # init workflow run state - node_factory = DifyNodeFactory( - graph_init_params=graph_init_params, + node_factory = DifyNodeFactory.from_graph_init_context( + graph_init_context=graph_init_context, graph_runtime_state=graph_runtime_state, ) node = node_factory.create_node(node_config) @@ -318,7 +343,7 @@ class WorkflowEntry: node_data: dict[str, Any], node_width: int = 114, node_height: int = 514, - ) -> dict[str, Any]: + ) -> SingleNodeGraphDict: """ Create a minimal graph structure for testing a single node in isolation. @@ -328,14 +353,14 @@ class WorkflowEntry: :param node_height: height for UI layout (default: 100) :return: graph dictionary with start node and target node """ - node_config = { + node_config: _NodeConfigDict = { "id": node_id, "width": node_width, "height": node_height, "type": "custom", "data": node_data, } - start_node_config = { + start_node_config: _NodeConfigDict = { "id": "start", "width": node_width, "height": node_height, @@ -346,9 +371,9 @@ class WorkflowEntry: "desc": "Start", }, } - return { - "nodes": [start_node_config, node_config], - "edges": [ + return SingleNodeGraphDict( + nodes=[start_node_config, node_config], + edges=[ { "source": "start", "target": node_id, @@ -356,7 +381,7 @@ class WorkflowEntry: "targetHandle": "target", } ], - } + ) @classmethod def run_free_node( @@ -389,17 +414,18 @@ class WorkflowEntry: variable_pool = VariablePool() add_variables_to_pool(variable_pool, default_system_variables()) - # init graph init params and runtime state - graph_init_params = GraphInitParams( + # init graph context and runtime state + run_context = build_dify_run_context( + tenant_id=tenant_id, + app_id="", + user_id=user_id, + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + ) + graph_init_context = DifyGraphInitContext( workflow_id="", graph_config=graph_dict, - run_context=build_dify_run_context( - tenant_id=tenant_id, - app_id="", - user_id=user_id, - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, - ), + run_context=run_context, call_depth=0, ) graph_runtime_state = GraphRuntimeState( @@ -410,8 +436,8 @@ class WorkflowEntry: # init workflow run state node_config = NodeConfigDictAdapter.validate_python({"id": node_id, "data": node_data}) - node_factory = DifyNodeFactory( - graph_init_params=graph_init_params, + node_factory = DifyNodeFactory.from_graph_init_context( + graph_init_context=graph_init_context, graph_runtime_state=graph_runtime_state, ) node = node_factory.create_node(node_config) diff --git a/api/dify_app.py b/api/dify_app.py index d6deb8e0074..bbe3f33787a 100644 --- a/api/dify_app.py +++ b/api/dify_app.py @@ -1,5 +1,14 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + from flask import Flask +if TYPE_CHECKING: + from extensions.ext_login import DifyLoginManager + class DifyApp(Flask): - pass + """Flask application type with Dify-specific extension attributes.""" + + login_manager: DifyLoginManager diff --git a/api/docker/entrypoint.sh b/api/docker/entrypoint.sh index 6b904b7d0d6..fc118df5bc0 100755 --- a/api/docker/entrypoint.sh +++ b/api/docker/entrypoint.sh @@ -35,10 +35,10 @@ if [[ "${MODE}" == "worker" ]]; then if [[ -z "${CELERY_QUEUES}" ]]; then if [[ "${EDITION}" == "CLOUD" ]]; then # Cloud edition: separate queues for dataset and trigger tasks - DEFAULT_QUEUES="api_token,dataset,dataset_summary,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution" + DEFAULT_QUEUES="api_token,dataset,dataset_summary,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_publisher,trigger_refresh_executor,retention,workflow_based_app_execution" else # Community edition (SELF_HOSTED): dataset, pipeline and workflow have separate queues - DEFAULT_QUEUES="api_token,dataset,dataset_summary,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution" + DEFAULT_QUEUES="api_token,dataset,dataset_summary,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_publisher,trigger_refresh_executor,retention,workflow_based_app_execution" fi else DEFAULT_QUEUES="${CELERY_QUEUES}" @@ -119,14 +119,16 @@ elif [[ "${MODE}" == "job" ]]; then else if [[ "${DEBUG}" == "true" ]]; then - exec flask run --host=${DIFY_BIND_ADDRESS:-0.0.0.0} --port=${DIFY_PORT:-5001} --debug + export HOST=${DIFY_BIND_ADDRESS:-0.0.0.0} + export PORT=${DIFY_PORT:-5001} + exec python -m app else exec gunicorn \ --bind "${DIFY_BIND_ADDRESS:-0.0.0.0}:${DIFY_PORT:-5001}" \ --workers ${SERVER_WORKER_AMOUNT:-1} \ - --worker-class ${SERVER_WORKER_CLASS:-gevent} \ + --worker-class ${SERVER_WORKER_CLASS:-geventwebsocket.gunicorn.workers.GeventWebSocketWorker} \ --worker-connections ${SERVER_WORKER_CONNECTIONS:-10} \ --timeout ${GUNICORN_TIMEOUT:-200} \ - app:app + app:socketio_app fi fi diff --git a/api/enterprise/telemetry/draft_trace.py b/api/enterprise/telemetry/draft_trace.py index 5a8d0ee6f49..dff558988c1 100644 --- a/api/enterprise/telemetry/draft_trace.py +++ b/api/enterprise/telemetry/draft_trace.py @@ -3,10 +3,9 @@ from __future__ import annotations from collections.abc import Mapping from typing import Any -from graphon.enums import WorkflowNodeExecutionMetadataKey - from core.telemetry import TelemetryContext, TelemetryEvent, TraceTaskName from core.telemetry import emit as telemetry_emit +from graphon.enums import WorkflowNodeExecutionMetadataKey from models.workflow import WorkflowNodeExecutionModel diff --git a/api/enterprise/telemetry/metric_handler.py b/api/enterprise/telemetry/metric_handler.py index ffd9a7e2b58..c564ace5842 100644 --- a/api/enterprise/telemetry/metric_handler.py +++ b/api/enterprise/telemetry/metric_handler.py @@ -68,46 +68,49 @@ class EnterpriseMetricHandler: # Route to appropriate handler based on case case = envelope.case - if case == TelemetryCase.APP_CREATED: - self._on_app_created(envelope) - self._increment_diagnostic_counter("processed_total", {"case": "app_created"}) - elif case == TelemetryCase.APP_UPDATED: - self._on_app_updated(envelope) - self._increment_diagnostic_counter("processed_total", {"case": "app_updated"}) - elif case == TelemetryCase.APP_DELETED: - self._on_app_deleted(envelope) - self._increment_diagnostic_counter("processed_total", {"case": "app_deleted"}) - elif case == TelemetryCase.FEEDBACK_CREATED: - self._on_feedback_created(envelope) - self._increment_diagnostic_counter("processed_total", {"case": "feedback_created"}) - elif case == TelemetryCase.MESSAGE_RUN: - self._on_message_run(envelope) - self._increment_diagnostic_counter("processed_total", {"case": "message_run"}) - elif case == TelemetryCase.TOOL_EXECUTION: - self._on_tool_execution(envelope) - self._increment_diagnostic_counter("processed_total", {"case": "tool_execution"}) - elif case == TelemetryCase.MODERATION_CHECK: - self._on_moderation_check(envelope) - self._increment_diagnostic_counter("processed_total", {"case": "moderation_check"}) - elif case == TelemetryCase.SUGGESTED_QUESTION: - self._on_suggested_question(envelope) - self._increment_diagnostic_counter("processed_total", {"case": "suggested_question"}) - elif case == TelemetryCase.DATASET_RETRIEVAL: - self._on_dataset_retrieval(envelope) - self._increment_diagnostic_counter("processed_total", {"case": "dataset_retrieval"}) - elif case == TelemetryCase.GENERATE_NAME: - self._on_generate_name(envelope) - self._increment_diagnostic_counter("processed_total", {"case": "generate_name"}) - elif case == TelemetryCase.PROMPT_GENERATION: - self._on_prompt_generation(envelope) - self._increment_diagnostic_counter("processed_total", {"case": "prompt_generation"}) - else: - logger.warning( - "Unknown telemetry case: %s (tenant_id=%s, event_id=%s)", - case, - envelope.tenant_id, - envelope.event_id, - ) + match case: + case TelemetryCase.APP_CREATED: + self._on_app_created(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "app_created"}) + case TelemetryCase.APP_UPDATED: + self._on_app_updated(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "app_updated"}) + case TelemetryCase.APP_DELETED: + self._on_app_deleted(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "app_deleted"}) + case TelemetryCase.FEEDBACK_CREATED: + self._on_feedback_created(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "feedback_created"}) + case TelemetryCase.MESSAGE_RUN: + self._on_message_run(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "message_run"}) + case TelemetryCase.TOOL_EXECUTION: + self._on_tool_execution(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "tool_execution"}) + case TelemetryCase.MODERATION_CHECK: + self._on_moderation_check(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "moderation_check"}) + case TelemetryCase.SUGGESTED_QUESTION: + self._on_suggested_question(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "suggested_question"}) + case TelemetryCase.DATASET_RETRIEVAL: + self._on_dataset_retrieval(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "dataset_retrieval"}) + case TelemetryCase.GENERATE_NAME: + self._on_generate_name(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "generate_name"}) + case TelemetryCase.PROMPT_GENERATION: + self._on_prompt_generation(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "prompt_generation"}) + case TelemetryCase.WORKFLOW_RUN | TelemetryCase.NODE_EXECUTION | TelemetryCase.DRAFT_NODE_EXECUTION: + pass + case _: + logger.warning( + "Unknown telemetry case: %s (tenant_id=%s, event_id=%s)", + case, + envelope.tenant_id, + envelope.event_id, + ) def _is_duplicate(self, envelope: TelemetryEnvelope) -> bool: """Check if this event has already been processed. @@ -326,7 +329,7 @@ class EnterpriseMetricHandler: return include_content = exporter.include_content - attrs: dict = { + attrs: dict[str, Any] = { "dify.message.id": payload.get("message_id"), "dify.tenant_id": envelope.tenant_id, "dify.event.id": envelope.event_id, diff --git a/api/events/event_handlers/create_document_index.py b/api/events/event_handlers/create_document_index.py index b7e7a6e60f1..0c535a1c5be 100644 --- a/api/events/event_handlers/create_document_index.py +++ b/api/events/event_handlers/create_document_index.py @@ -6,9 +6,9 @@ import click from sqlalchemy import select from werkzeug.exceptions import NotFound +from core.db.session_factory import session_factory from core.indexing_runner import DocumentIsPausedError, IndexingRunner from events.document_index_event import document_index_created -from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from models.dataset import Document from models.enums import IndexingStatus @@ -22,24 +22,25 @@ def handle(sender, **kwargs): document_ids = kwargs.get("document_ids", []) documents = [] start_at = time.perf_counter() - for document_id in document_ids: - logger.info(click.style(f"Start process document: {document_id}", fg="green")) + with session_factory.create_session() as session: + for document_id in document_ids: + logger.info(click.style(f"Start process document: {document_id}", fg="green")) - document = db.session.scalar( - select(Document).where( - Document.id == document_id, - Document.dataset_id == dataset_id, + document = session.scalar( + select(Document).where( + Document.id == document_id, + Document.dataset_id == dataset_id, + ) ) - ) - if not document: - raise NotFound("Document not found") + if not document: + raise NotFound("Document not found") - document.indexing_status = IndexingStatus.PARSING - document.processing_started_at = naive_utc_now() - documents.append(document) - db.session.add(document) - db.session.commit() + document.indexing_status = IndexingStatus.PARSING + document.processing_started_at = naive_utc_now() + documents.append(document) + session.add(document) + session.commit() with contextlib.suppress(Exception): try: diff --git a/api/events/event_handlers/create_installed_app_when_app_created.py b/api/events/event_handlers/create_installed_app_when_app_created.py index 57412cc4ad0..38e102d5fd2 100644 --- a/api/events/event_handlers/create_installed_app_when_app_created.py +++ b/api/events/event_handlers/create_installed_app_when_app_created.py @@ -1,5 +1,5 @@ +from core.db.session_factory import session_factory from events.app_event import app_was_created -from extensions.ext_database import db from models.model import InstalledApp @@ -12,5 +12,6 @@ def handle(sender, **kwargs): app_id=app.id, app_owner_tenant_id=app.tenant_id, ) - db.session.add(installed_app) - db.session.commit() + with session_factory.create_session() as session: + session.add(installed_app) + session.commit() diff --git a/api/events/event_handlers/create_site_record_when_app_created.py b/api/events/event_handlers/create_site_record_when_app_created.py index 84be592b1a9..5e2a456dce3 100644 --- a/api/events/event_handlers/create_site_record_when_app_created.py +++ b/api/events/event_handlers/create_site_record_when_app_created.py @@ -1,5 +1,5 @@ +from core.db.session_factory import session_factory from events.app_event import app_was_created -from extensions.ext_database import db from models.enums import CustomizeTokenStrategy from models.model import Site @@ -22,6 +22,6 @@ def handle(sender, **kwargs): created_by=app.created_by, updated_by=app.updated_by, ) - - db.session.add(site) - db.session.commit() + with session_factory.create_session() as session: + session.add(site) + session.commit() diff --git a/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py b/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py index 7bd8e88231a..ba9758175fa 100644 --- a/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py +++ b/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py @@ -1,12 +1,11 @@ import logging -from graphon.nodes import BuiltinNodeTypes -from graphon.nodes.tool.entities import ToolEntity - from core.tools.entities.tool_entities import ToolProviderType from core.tools.tool_manager import ToolManager from core.tools.utils.configuration import ToolParameterConfigurationManager from events.app_event import app_draft_workflow_was_synced +from graphon.nodes import BuiltinNodeTypes +from graphon.nodes.tool.entities import ToolEntity logger = logging.getLogger(__name__) diff --git a/api/events/event_handlers/sync_workflow_schedule_when_app_published.py b/api/events/event_handlers/sync_workflow_schedule_when_app_published.py index 168513fc046..5f8fcd86170 100644 --- a/api/events/event_handlers/sync_workflow_schedule_when_app_published.py +++ b/api/events/event_handlers/sync_workflow_schedule_when_app_published.py @@ -2,7 +2,7 @@ import logging from typing import cast from sqlalchemy import select -from sqlalchemy.orm import Session +from sqlalchemy.orm import sessionmaker from core.workflow.nodes.trigger_schedule.entities import SchedulePlanUpdate from events.app_event import app_published_workflow_was_updated @@ -45,7 +45,7 @@ def sync_schedule_from_workflow(tenant_id: str, app_id: str, workflow: Workflow) Returns: Updated or created WorkflowSchedulePlan, or None if no schedule node """ - with Session(db.engine) as session: + with sessionmaker(db.engine).begin() as session: schedule_config = ScheduleService.extract_schedule_config(workflow) existing_plan = session.scalar( @@ -59,7 +59,6 @@ def sync_schedule_from_workflow(tenant_id: str, app_id: str, workflow: Workflow) if existing_plan: logger.info("No schedule node in workflow for app %s, removing schedule plan", app_id) ScheduleService.delete_schedule(session=session, schedule_id=existing_plan.id) - session.commit() return None if existing_plan: @@ -73,7 +72,6 @@ def sync_schedule_from_workflow(tenant_id: str, app_id: str, workflow: Workflow) schedule_id=existing_plan.id, updates=updates, ) - session.commit() return updated_plan else: new_plan = ScheduleService.create_schedule( @@ -82,5 +80,4 @@ def sync_schedule_from_workflow(tenant_id: str, app_id: str, workflow: Workflow) app_id=app_id, config=schedule_config, ) - session.commit() return new_plan diff --git a/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py b/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py index 86b5b2bbf05..6769b94cde0 100644 --- a/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py +++ b/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py @@ -1,11 +1,11 @@ from typing import cast -from graphon.nodes import BuiltinNodeTypes from sqlalchemy import delete, select from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData from events.app_event import app_published_workflow_was_updated from extensions.ext_database import db +from graphon.nodes import BuiltinNodeTypes from models.dataset import AppDatasetJoin from models.workflow import Workflow diff --git a/api/events/event_handlers/update_app_triggers_when_app_published_workflow_updated.py b/api/events/event_handlers/update_app_triggers_when_app_published_workflow_updated.py index b3917d56227..d55fe262fb2 100644 --- a/api/events/event_handlers/update_app_triggers_when_app_published_workflow_updated.py +++ b/api/events/event_handlers/update_app_triggers_when_app_published_workflow_updated.py @@ -1,7 +1,7 @@ from typing import cast from sqlalchemy import select -from sqlalchemy.orm import Session +from sqlalchemy.orm import sessionmaker from core.trigger.constants import TRIGGER_NODE_TYPES from events.app_event import app_published_workflow_was_updated @@ -31,7 +31,7 @@ def handle(sender, **kwargs): # Extract trigger info from workflow trigger_infos = get_trigger_infos_from_workflow(published_workflow) - with Session(db.engine) as session: + with sessionmaker(db.engine).begin() as session: # Get existing app triggers existing_triggers = ( session.execute( @@ -79,8 +79,6 @@ def handle(sender, **kwargs): existing_trigger.title = new_title session.add(existing_trigger) - session.commit() - def get_trigger_infos_from_workflow(published_workflow: Workflow) -> list[dict]: """ diff --git a/api/events/event_handlers/update_provider_when_message_created.py b/api/events/event_handlers/update_provider_when_message_created.py index f68cdaadde0..1d615f0f87a 100644 --- a/api/events/event_handlers/update_provider_when_message_created.py +++ b/api/events/event_handlers/update_provider_when_message_created.py @@ -135,37 +135,40 @@ def handle(sender: Message, **kwargs): model_name=model_config.model, ) if used_quota is not None: - if provider_configuration.system_configuration.current_quota_type == ProviderQuotaType.TRIAL: - from services.credit_pool_service import CreditPoolService + match provider_configuration.system_configuration.current_quota_type: + case ProviderQuotaType.TRIAL: + from services.credit_pool_service import CreditPoolService - CreditPoolService.check_and_deduct_credits( - tenant_id=tenant_id, - credits_required=used_quota, - pool_type="trial", - ) - elif provider_configuration.system_configuration.current_quota_type == ProviderQuotaType.PAID: - from services.credit_pool_service import CreditPoolService - - CreditPoolService.check_and_deduct_credits( - tenant_id=tenant_id, - credits_required=used_quota, - pool_type="paid", - ) - else: - quota_update = _ProviderUpdateOperation( - filters=_ProviderUpdateFilters( + CreditPoolService.check_and_deduct_credits( tenant_id=tenant_id, - provider_name=ModelProviderID(model_config.provider).provider_name, - provider_type=ProviderType.SYSTEM.value, - quota_type=provider_configuration.system_configuration.current_quota_type, - ), - values=_ProviderUpdateValues(quota_used=Provider.quota_used + used_quota, last_used=current_time), - additional_filters=_ProviderUpdateAdditionalFilters( - quota_limit_check=True # Provider.quota_limit > Provider.quota_used - ), - description="quota_deduction_update", - ) - updates_to_perform.append(quota_update) + credits_required=used_quota, + pool_type="trial", + ) + case ProviderQuotaType.PAID: + from services.credit_pool_service import CreditPoolService + + CreditPoolService.check_and_deduct_credits( + tenant_id=tenant_id, + credits_required=used_quota, + pool_type="paid", + ) + case ProviderQuotaType.FREE: + quota_update = _ProviderUpdateOperation( + filters=_ProviderUpdateFilters( + tenant_id=tenant_id, + provider_name=ModelProviderID(model_config.provider).provider_name, + provider_type=ProviderType.SYSTEM.value, + quota_type=provider_configuration.system_configuration.current_quota_type, + ), + values=_ProviderUpdateValues( + quota_used=Provider.quota_used + used_quota, last_used=current_time + ), + additional_filters=_ProviderUpdateAdditionalFilters( + quota_limit_check=True # Provider.quota_limit > Provider.quota_used + ), + description="quota_deduction_update", + ) + updates_to_perform.append(quota_update) # Execute all updates start_time = time_module.perf_counter() diff --git a/api/extensions/ext_celery.py b/api/extensions/ext_celery.py index 4eed34436a4..340f514fcc8 100644 --- a/api/extensions/ext_celery.py +++ b/api/extensions/ext_celery.py @@ -5,12 +5,32 @@ from typing import Any import pytz # type: ignore[import-untyped] from celery import Celery, Task from celery.schedules import crontab +from typing_extensions import TypedDict from configs import dify_config from dify_app import DifyApp +from extensions.redis_names import normalize_redis_key_prefix -def _get_celery_ssl_options() -> dict[str, Any] | None: +class _CelerySentinelKwargsDict(TypedDict): + socket_timeout: float | None + password: str | None + + +class CelerySentinelTransportDict(TypedDict, total=False): + master_name: str | None + sentinel_kwargs: _CelerySentinelKwargsDict + global_keyprefix: str + + +class CelerySSLOptionsDict(TypedDict): + ssl_cert_reqs: int + ssl_ca_certs: str | None + ssl_certfile: str | None + ssl_keyfile: str | None + + +def get_celery_ssl_options() -> CelerySSLOptionsDict | None: """Get SSL configuration for Celery broker/backend connections.""" # Only apply SSL if we're using Redis as broker/backend if not dify_config.BROKER_USE_SSL: @@ -33,14 +53,41 @@ def _get_celery_ssl_options() -> dict[str, Any] | None: ssl_cert_reqs = cert_reqs_map.get(dify_config.REDIS_SSL_CERT_REQS, ssl.CERT_NONE) - ssl_options = { - "ssl_cert_reqs": ssl_cert_reqs, - "ssl_ca_certs": dify_config.REDIS_SSL_CA_CERTS, - "ssl_certfile": dify_config.REDIS_SSL_CERTFILE, - "ssl_keyfile": dify_config.REDIS_SSL_KEYFILE, - } + return CelerySSLOptionsDict( + ssl_cert_reqs=ssl_cert_reqs, + ssl_ca_certs=dify_config.REDIS_SSL_CA_CERTS, + ssl_certfile=dify_config.REDIS_SSL_CERTFILE, + ssl_keyfile=dify_config.REDIS_SSL_KEYFILE, + ) - return ssl_options + +def get_celery_broker_transport_options() -> CelerySentinelTransportDict | dict[str, Any]: + """Get broker transport options (e.g. Redis Sentinel) for Celery connections.""" + transport_options: CelerySentinelTransportDict | dict[str, Any] + if dify_config.CELERY_USE_SENTINEL: + transport_options = CelerySentinelTransportDict( + master_name=dify_config.CELERY_SENTINEL_MASTER_NAME, + sentinel_kwargs=_CelerySentinelKwargsDict( + socket_timeout=dify_config.CELERY_SENTINEL_SOCKET_TIMEOUT, + password=dify_config.CELERY_SENTINEL_PASSWORD, + ), + ) + else: + transport_options = {} + + global_keyprefix = get_celery_redis_global_keyprefix() + if global_keyprefix: + transport_options["global_keyprefix"] = global_keyprefix + + return transport_options + + +def get_celery_redis_global_keyprefix() -> str | None: + """Return the Redis transport prefix for Celery when namespace isolation is enabled.""" + normalized_prefix = normalize_redis_key_prefix(dify_config.REDIS_KEY_PREFIX) + if not normalized_prefix: + return None + return f"{normalized_prefix}:" def init_app(app: DifyApp) -> Celery: @@ -53,16 +100,7 @@ def init_app(app: DifyApp) -> Celery: init_request_context() return self.run(*args, **kwargs) - broker_transport_options = {} - - if dify_config.CELERY_USE_SENTINEL: - broker_transport_options = { - "master_name": dify_config.CELERY_SENTINEL_MASTER_NAME, - "sentinel_kwargs": { - "socket_timeout": dify_config.CELERY_SENTINEL_SOCKET_TIMEOUT, - "password": dify_config.CELERY_SENTINEL_PASSWORD, - }, - } + broker_transport_options = get_celery_broker_transport_options() celery_app = Celery( app.name, @@ -89,7 +127,7 @@ def init_app(app: DifyApp) -> Celery: ) # Apply SSL configuration if enabled - ssl_options = _get_celery_ssl_options() + ssl_options = get_celery_ssl_options() if ssl_options: celery_app.conf.update( broker_use_ssl=ssl_options, diff --git a/api/extensions/ext_login.py b/api/extensions/ext_login.py index 02e50a90fcc..bc59eaca635 100644 --- a/api/extensions/ext_login.py +++ b/api/extensions/ext_login.py @@ -1,7 +1,8 @@ import json +from typing import cast import flask_login -from flask import Response, request +from flask import Request, Response, request from flask_login import user_loaded_from_request, user_logged_in from sqlalchemy import select from werkzeug.exceptions import NotFound, Unauthorized @@ -16,13 +17,35 @@ from models import Account, Tenant, TenantAccountJoin from models.model import AppMCPServer, EndUser from services.account_service import AccountService -login_manager = flask_login.LoginManager() +type LoginUser = Account | EndUser + + +class DifyLoginManager(flask_login.LoginManager): + """Project-specific Flask-Login manager with a stable unauthorized contract. + + Dify registers `unauthorized_handler` below to always return a JSON `Response`. + Overriding this method lets callers rely on that narrower return type instead of + Flask-Login's broader callback contract. + """ + + def unauthorized(self) -> Response: + """Return the registered unauthorized handler result as a Flask `Response`.""" + return cast(Response, super().unauthorized()) + + def load_user_from_request_context(self) -> None: + """Populate Flask-Login's request-local user cache for the current request.""" + self._load_user() + + +login_manager = DifyLoginManager() # Flask-Login configuration @login_manager.request_loader -def load_user_from_request(request_from_flask_login): +def load_user_from_request(request_from_flask_login: Request) -> LoginUser | None: """Load user based on the request.""" + del request_from_flask_login + # Skip authentication for documentation endpoints if dify_config.SWAGGER_UI_ENABLED and request.path.endswith((dify_config.SWAGGER_UI_PATH, "/swagger.json")): return None @@ -100,10 +123,12 @@ def load_user_from_request(request_from_flask_login): raise NotFound("End user not found.") return end_user + return None + @user_logged_in.connect @user_loaded_from_request.connect -def on_user_logged_in(_sender, user): +def on_user_logged_in(_sender: object, user: LoginUser) -> None: """Called when a user logged in. Note: AccountService.load_logged_in_account will populate user.current_tenant_id @@ -114,8 +139,10 @@ def on_user_logged_in(_sender, user): @login_manager.unauthorized_handler -def unauthorized_handler(): +def unauthorized_handler() -> Response: """Handle unauthorized requests.""" + # Keep this as a concrete `Response`; `DifyLoginManager.unauthorized()` narrows + # Flask-Login's callback contract based on this override. return Response( json.dumps({"code": "unauthorized", "message": "Unauthorized."}), status=401, @@ -123,5 +150,5 @@ def unauthorized_handler(): ) -def init_app(app: DifyApp): +def init_app(app: DifyApp) -> None: login_manager.init_app(app) diff --git a/api/extensions/ext_redis.py b/api/extensions/ext_redis.py index 5f528dbf9e6..9f7f73765ee 100644 --- a/api/extensions/ext_redis.py +++ b/api/extensions/ext_redis.py @@ -3,29 +3,41 @@ import logging import ssl from collections.abc import Callable from datetime import timedelta -from typing import TYPE_CHECKING, Any, Union +from typing import Any, Union, cast import redis from redis import RedisError +from redis.backoff import ExponentialWithJitterBackoff # type: ignore from redis.cache import CacheConfig from redis.client import PubSub from redis.cluster import ClusterNode, RedisCluster from redis.connection import Connection, SSLConnection +from redis.retry import Retry from redis.sentinel import Sentinel +from typing_extensions import TypedDict from configs import dify_config from dify_app import DifyApp +from extensions.redis_names import ( + normalize_redis_key_prefix, + serialize_redis_name, + serialize_redis_name_arg, + serialize_redis_name_args, +) from libs.broadcast_channel.channel import BroadcastChannel as BroadcastChannelProtocol from libs.broadcast_channel.redis.channel import BroadcastChannel as RedisBroadcastChannel from libs.broadcast_channel.redis.sharded_channel import ShardedRedisBroadcastChannel from libs.broadcast_channel.redis.streams_channel import StreamsBroadcastChannel -if TYPE_CHECKING: - from redis.lock import Lock - logger = logging.getLogger(__name__) +_normalize_redis_key_prefix = normalize_redis_key_prefix +_serialize_redis_name = serialize_redis_name +_serialize_redis_name_arg = serialize_redis_name_arg +_serialize_redis_name_args = serialize_redis_name_args + + class RedisClientWrapper: """ A wrapper class for the Redis client that addresses the issue where the global @@ -56,74 +68,189 @@ class RedisClientWrapper: if self._client is None: self._client = client - if TYPE_CHECKING: - # Type hints for IDE support and static analysis - # These are not executed at runtime but provide type information - def get(self, name: str | bytes) -> Any: ... - - def set( - self, - name: str | bytes, - value: Any, - ex: int | None = None, - px: int | None = None, - nx: bool = False, - xx: bool = False, - keepttl: bool = False, - get: bool = False, - exat: int | None = None, - pxat: int | None = None, - ) -> Any: ... - - def setex(self, name: str | bytes, time: int | timedelta, value: Any) -> Any: ... - def setnx(self, name: str | bytes, value: Any) -> Any: ... - def delete(self, *names: str | bytes) -> Any: ... - def incr(self, name: str | bytes, amount: int = 1) -> Any: ... - def expire( - self, - name: str | bytes, - time: int | timedelta, - nx: bool = False, - xx: bool = False, - gt: bool = False, - lt: bool = False, - ) -> Any: ... - def lock( - self, - name: str, - timeout: float | None = None, - sleep: float = 0.1, - blocking: bool = True, - blocking_timeout: float | None = None, - thread_local: bool = True, - ) -> Lock: ... - def zadd( - self, - name: str | bytes, - mapping: dict[str | bytes | int | float, float | int | str | bytes], - nx: bool = False, - xx: bool = False, - ch: bool = False, - incr: bool = False, - gt: bool = False, - lt: bool = False, - ) -> Any: ... - def zremrangebyscore(self, name: str | bytes, min: float | str, max: float | str) -> Any: ... - def zcard(self, name: str | bytes) -> Any: ... - def getdel(self, name: str | bytes) -> Any: ... - def pubsub(self) -> PubSub: ... - def pipeline(self, transaction: bool = True, shard_hint: str | None = None) -> Any: ... - - def __getattr__(self, item: str) -> Any: + def _require_client(self) -> redis.Redis | RedisCluster: if self._client is None: raise RuntimeError("Redis client is not initialized. Call init_app first.") - return getattr(self._client, item) + return self._client + + def _get_prefix(self) -> str: + return dify_config.REDIS_KEY_PREFIX + + def get(self, name: str | bytes) -> Any: + return self._require_client().get(_serialize_redis_name_arg(name, self._get_prefix())) + + def set( + self, + name: str | bytes, + value: Any, + ex: int | None = None, + px: int | None = None, + nx: bool = False, + xx: bool = False, + keepttl: bool = False, + get: bool = False, + exat: int | None = None, + pxat: int | None = None, + ) -> Any: + return self._require_client().set( + _serialize_redis_name_arg(name, self._get_prefix()), + value, + ex=ex, + px=px, + nx=nx, + xx=xx, + keepttl=keepttl, + get=get, + exat=exat, + pxat=pxat, + ) + + def setex(self, name: str | bytes, time: int | timedelta, value: Any) -> Any: + return self._require_client().setex(_serialize_redis_name_arg(name, self._get_prefix()), time, value) + + def setnx(self, name: str | bytes, value: Any) -> Any: + return self._require_client().setnx(_serialize_redis_name_arg(name, self._get_prefix()), value) + + def delete(self, *names: str | bytes) -> Any: + return self._require_client().delete(*_serialize_redis_name_args(names, self._get_prefix())) + + def incr(self, name: str | bytes, amount: int = 1) -> Any: + return self._require_client().incr(_serialize_redis_name_arg(name, self._get_prefix()), amount) + + def expire( + self, + name: str | bytes, + time: int | timedelta, + nx: bool = False, + xx: bool = False, + gt: bool = False, + lt: bool = False, + ) -> Any: + return self._require_client().expire( + _serialize_redis_name_arg(name, self._get_prefix()), + time, + nx=nx, + xx=xx, + gt=gt, + lt=lt, + ) + + def exists(self, *names: str | bytes) -> Any: + return self._require_client().exists(*_serialize_redis_name_args(names, self._get_prefix())) + + def ttl(self, name: str | bytes) -> Any: + return self._require_client().ttl(_serialize_redis_name_arg(name, self._get_prefix())) + + def getdel(self, name: str | bytes) -> Any: + return self._require_client().getdel(_serialize_redis_name_arg(name, self._get_prefix())) + + def lock( + self, + name: str, + timeout: float | None = None, + sleep: float = 0.1, + blocking: bool = True, + blocking_timeout: float | None = None, + thread_local: bool = True, + ) -> Any: + return self._require_client().lock( + _serialize_redis_name(name, self._get_prefix()), + timeout=timeout, + sleep=sleep, + blocking=blocking, + blocking_timeout=blocking_timeout, + thread_local=thread_local, + ) + + def hset(self, name: str | bytes, *args: Any, **kwargs: Any) -> Any: + return self._require_client().hset(_serialize_redis_name_arg(name, self._get_prefix()), *args, **kwargs) + + def hgetall(self, name: str | bytes) -> Any: + return self._require_client().hgetall(_serialize_redis_name_arg(name, self._get_prefix())) + + def hdel(self, name: str | bytes, *keys: str | bytes) -> Any: + return self._require_client().hdel(_serialize_redis_name_arg(name, self._get_prefix()), *keys) + + def hlen(self, name: str | bytes) -> Any: + return self._require_client().hlen(_serialize_redis_name_arg(name, self._get_prefix())) + + def zadd( + self, + name: str | bytes, + mapping: dict[str | bytes | int | float, float | int | str | bytes], + nx: bool = False, + xx: bool = False, + ch: bool = False, + incr: bool = False, + gt: bool = False, + lt: bool = False, + ) -> Any: + return self._require_client().zadd( + _serialize_redis_name_arg(name, self._get_prefix()), + cast(Any, mapping), + nx=nx, + xx=xx, + ch=ch, + incr=incr, + gt=gt, + lt=lt, + ) + + def zremrangebyscore(self, name: str | bytes, min: float | str, max: float | str) -> Any: + return self._require_client().zremrangebyscore(_serialize_redis_name_arg(name, self._get_prefix()), min, max) + + def zcard(self, name: str | bytes) -> Any: + return self._require_client().zcard(_serialize_redis_name_arg(name, self._get_prefix())) + + def pubsub(self) -> PubSub: + return self._require_client().pubsub() + + def pipeline(self, transaction: bool = True, shard_hint: str | None = None) -> Any: + return self._require_client().pipeline(transaction=transaction, shard_hint=shard_hint) + + def __getattr__(self, item: str) -> Any: + return getattr(self._require_client(), item) redis_client: RedisClientWrapper = RedisClientWrapper() _pubsub_redis_client: redis.Redis | RedisCluster | None = None +class RedisSSLParamsDict(TypedDict): + ssl_cert_reqs: int + ssl_ca_certs: str | None + ssl_certfile: str | None + ssl_keyfile: str | None + + +class RedisHealthParamsDict(TypedDict): + retry: Retry + socket_timeout: float | None + socket_connect_timeout: float | None + health_check_interval: int | None + + +class RedisClusterHealthParamsDict(TypedDict): + retry: Retry + socket_timeout: float | None + socket_connect_timeout: float | None + + +class RedisBaseParamsDict(TypedDict): + username: str | None + password: str | None + db: int + encoding: str + encoding_errors: str + decode_responses: bool + protocol: int + cache_config: CacheConfig | None + retry: Retry + socket_timeout: float | None + socket_connect_timeout: float | None + health_check_interval: int | None + + def _get_ssl_configuration() -> tuple[type[Union[Connection, SSLConnection]], dict[str, Any]]: """Get SSL configuration for Redis connection.""" if not dify_config.REDIS_USE_SSL: @@ -158,21 +285,60 @@ def _get_cache_configuration() -> CacheConfig | None: return CacheConfig() -def _get_base_redis_params() -> dict[str, Any]: - """Get base Redis connection parameters.""" - return { - "username": dify_config.REDIS_USERNAME, - "password": dify_config.REDIS_PASSWORD or None, - "db": dify_config.REDIS_DB, - "encoding": "utf-8", - "encoding_errors": "strict", - "decode_responses": False, - "protocol": dify_config.REDIS_SERIALIZATION_PROTOCOL, - "cache_config": _get_cache_configuration(), +def _get_retry_policy() -> Retry: + """Build the shared retry policy for Redis connections.""" + return Retry( + backoff=ExponentialWithJitterBackoff( + base=dify_config.REDIS_RETRY_BACKOFF_BASE, + cap=dify_config.REDIS_RETRY_BACKOFF_CAP, + ), + retries=dify_config.REDIS_RETRY_RETRIES, + ) + + +def _get_connection_health_params() -> RedisHealthParamsDict: + """Get connection health and retry parameters for standalone and Sentinel Redis clients.""" + return RedisHealthParamsDict( + retry=_get_retry_policy(), + socket_timeout=dify_config.REDIS_SOCKET_TIMEOUT, + socket_connect_timeout=dify_config.REDIS_SOCKET_CONNECT_TIMEOUT, + health_check_interval=dify_config.REDIS_HEALTH_CHECK_INTERVAL, + ) + + +def _get_cluster_connection_health_params() -> RedisClusterHealthParamsDict: + """Get retry and timeout parameters for Redis Cluster clients. + + RedisCluster does not support ``health_check_interval`` as a constructor + keyword (it is silently stripped by ``cleanup_kwargs``), so it is excluded + here. Only ``retry``, ``socket_timeout``, and ``socket_connect_timeout`` + are passed through. + """ + health_params = _get_connection_health_params() + result: RedisClusterHealthParamsDict = { + "retry": health_params["retry"], + "socket_timeout": health_params["socket_timeout"], + "socket_connect_timeout": health_params["socket_connect_timeout"], } + return result -def _create_sentinel_client(redis_params: dict[str, Any]) -> Union[redis.Redis, RedisCluster]: +def _get_base_redis_params() -> RedisBaseParamsDict: + """Get base Redis connection parameters including retry and health policy.""" + return RedisBaseParamsDict( + username=dify_config.REDIS_USERNAME, + password=dify_config.REDIS_PASSWORD or None, + db=dify_config.REDIS_DB, + encoding="utf-8", + encoding_errors="strict", + decode_responses=False, + protocol=dify_config.REDIS_SERIALIZATION_PROTOCOL, + cache_config=_get_cache_configuration(), + **_get_connection_health_params(), + ) + + +def _create_sentinel_client(redis_params: RedisBaseParamsDict) -> Union[redis.Redis, RedisCluster]: """Create Redis client using Sentinel configuration.""" if not dify_config.REDIS_SENTINELS: raise ValueError("REDIS_SENTINELS must be set when REDIS_USE_SENTINEL is True") @@ -196,7 +362,8 @@ def _create_sentinel_client(redis_params: dict[str, Any]) -> Union[redis.Redis, sentinel_kwargs=sentinel_kwargs, ) - master: redis.Redis = sentinel.master_for(dify_config.REDIS_SENTINEL_SERVICE_NAME, **redis_params) + params: dict[str, Any] = {**redis_params} + master: redis.Redis = sentinel.master_for(dify_config.REDIS_SENTINEL_SERVICE_NAME, **params) return master @@ -215,6 +382,7 @@ def _create_cluster_client() -> Union[redis.Redis, RedisCluster]: "password": dify_config.REDIS_CLUSTERS_PASSWORD, "protocol": dify_config.REDIS_SERIALIZATION_PROTOCOL, "cache_config": _get_cache_configuration(), + **_get_cluster_connection_health_params(), } if dify_config.REDIS_MAX_CONNECTIONS: cluster_kwargs["max_connections"] = dify_config.REDIS_MAX_CONNECTIONS @@ -222,41 +390,43 @@ def _create_cluster_client() -> Union[redis.Redis, RedisCluster]: return cluster -def _create_standalone_client(redis_params: dict[str, Any]) -> Union[redis.Redis, RedisCluster]: +def _create_standalone_client(redis_params: RedisBaseParamsDict) -> Union[redis.Redis, RedisCluster]: """Create standalone Redis client.""" connection_class, ssl_kwargs = _get_ssl_configuration() - redis_params.update( - { - "host": dify_config.REDIS_HOST, - "port": dify_config.REDIS_PORT, - "connection_class": connection_class, - } - ) + params: dict[str, Any] = { + **redis_params, + "host": dify_config.REDIS_HOST, + "port": dify_config.REDIS_PORT, + "connection_class": connection_class, + } if dify_config.REDIS_MAX_CONNECTIONS: - redis_params["max_connections"] = dify_config.REDIS_MAX_CONNECTIONS + params["max_connections"] = dify_config.REDIS_MAX_CONNECTIONS if ssl_kwargs: - redis_params.update(ssl_kwargs) + params.update(ssl_kwargs) - pool = redis.ConnectionPool(**redis_params) + pool = redis.ConnectionPool(**params) client: redis.Redis = redis.Redis(connection_pool=pool) return client def _create_pubsub_client(pubsub_url: str, use_clusters: bool) -> redis.Redis | RedisCluster: max_conns = dify_config.REDIS_MAX_CONNECTIONS - if use_clusters: - if max_conns: - return RedisCluster.from_url(pubsub_url, max_connections=max_conns) - else: - return RedisCluster.from_url(pubsub_url) + if use_clusters: + health_params = _get_cluster_connection_health_params() + kwargs: dict[str, Any] = {**health_params} + if max_conns: + kwargs["max_connections"] = max_conns + return RedisCluster.from_url(pubsub_url, **kwargs) + + standalone_health_params: dict[str, Any] = dict(_get_connection_health_params()) + kwargs = {**standalone_health_params} if max_conns: - return redis.Redis.from_url(pubsub_url, max_connections=max_conns) - else: - return redis.Redis.from_url(pubsub_url) + kwargs["max_connections"] = max_conns + return redis.Redis.from_url(pubsub_url, **kwargs) def init_app(app: DifyApp): diff --git a/api/extensions/ext_sentry.py b/api/extensions/ext_sentry.py index 5cc58f27c4d..69d1f1ab07e 100644 --- a/api/extensions/ext_sentry.py +++ b/api/extensions/ext_sentry.py @@ -5,11 +5,12 @@ from dify_app import DifyApp def init_app(app: DifyApp): if dify_config.SENTRY_DSN: import sentry_sdk - from graphon.model_runtime.errors.invoke import InvokeRateLimitError from sentry_sdk.integrations.celery import CeleryIntegration from sentry_sdk.integrations.flask import FlaskIntegration from werkzeug.exceptions import HTTPException + from graphon.model_runtime.errors.invoke import InvokeRateLimitError + try: from langfuse._utils import parse_error diff --git a/api/extensions/ext_socketio.py b/api/extensions/ext_socketio.py new file mode 100644 index 00000000000..5ed82bac8de --- /dev/null +++ b/api/extensions/ext_socketio.py @@ -0,0 +1,5 @@ +import socketio # type: ignore[reportMissingTypeStubs] + +from configs import dify_config + +sio = socketio.Server(async_mode="gevent", cors_allowed_origins=dify_config.CONSOLE_CORS_ALLOW_ORIGINS) diff --git a/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py b/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py index db599c5d495..64ff0f06745 100644 --- a/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py +++ b/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py @@ -11,12 +11,12 @@ from collections.abc import Sequence from datetime import datetime from typing import Any -from graphon.enums import WorkflowNodeExecutionStatus from sqlalchemy.orm import sessionmaker from extensions.logstore.aliyun_logstore import AliyunLogStore from extensions.logstore.repositories import safe_float, safe_int from extensions.logstore.sql_escape import escape_identifier, escape_logstore_query_value +from graphon.enums import WorkflowNodeExecutionStatus from models.enums import CreatorUserRole from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository diff --git a/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py b/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py index 3c83ab4f84e..7f77a0437a6 100644 --- a/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py +++ b/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py @@ -20,12 +20,12 @@ from collections.abc import Sequence from datetime import datetime from typing import Any, cast -from graphon.enums import WorkflowExecutionStatus from sqlalchemy.orm import sessionmaker from extensions.logstore.aliyun_logstore import AliyunLogStore from extensions.logstore.repositories import safe_float, safe_int from extensions.logstore.sql_escape import escape_identifier, escape_logstore_query_value, escape_sql_string +from graphon.enums import WorkflowExecutionStatus from libs.infinite_scroll_pagination import InfiniteScrollPagination from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom from models.workflow import WorkflowRun, WorkflowType @@ -354,11 +354,11 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): ) -> WorkflowRun | None: """Fallback to PostgreSQL query for records not in LogStore (with tenant isolation).""" from sqlalchemy import select - from sqlalchemy.orm import Session + from sqlalchemy.orm import sessionmaker from extensions.ext_database import db - with Session(db.engine) as session: + with sessionmaker(db.engine).begin() as session: stmt = select(WorkflowRun).where( WorkflowRun.id == run_id, WorkflowRun.tenant_id == tenant_id, WorkflowRun.app_id == app_id ) @@ -439,11 +439,11 @@ class LogstoreAPIWorkflowRunRepository(APIWorkflowRunRepository): def _fallback_get_workflow_run_by_id(self, run_id: str) -> WorkflowRun | None: """Fallback to PostgreSQL query for records not in LogStore.""" from sqlalchemy import select - from sqlalchemy.orm import Session + from sqlalchemy.orm import sessionmaker from extensions.ext_database import db - with Session(db.engine) as session: + with sessionmaker(db.engine).begin() as session: stmt = select(WorkflowRun).where(WorkflowRun.id == run_id) return session.scalar(stmt) diff --git a/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py b/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py index d0f3e2e2449..544109276d1 100644 --- a/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py +++ b/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py @@ -3,14 +3,14 @@ import logging import os import time -from graphon.entities import WorkflowExecution -from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker from core.repositories.factory import WorkflowExecutionRepository from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository from extensions.logstore.aliyun_logstore import AliyunLogStore +from graphon.entities import WorkflowExecution +from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from libs.helper import extract_tenant_id from models import ( Account, diff --git a/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py b/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py index 37952d6464a..dc7654a25ca 100644 --- a/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py +++ b/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py @@ -13,10 +13,6 @@ from collections.abc import Sequence from datetime import datetime from typing import Any -from graphon.entities import WorkflowNodeExecution -from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus -from graphon.model_runtime.utils.encoders import jsonable_encoder -from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker @@ -26,6 +22,10 @@ from core.repositories.factory import OrderConfig, WorkflowNodeExecutionReposito from extensions.logstore.aliyun_logstore import AliyunLogStore from extensions.logstore.repositories import safe_float, safe_int from extensions.logstore.sql_escape import escape_identifier +from graphon.entities import WorkflowNodeExecution +from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from graphon.model_runtime.utils.encoders import jsonable_encoder +from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from libs.helper import extract_tenant_id from models import ( Account, diff --git a/api/extensions/otel/celery_sqlcommenter.py b/api/extensions/otel/celery_sqlcommenter.py index 8abb1ce15ab..15e52fb5efa 100644 --- a/api/extensions/otel/celery_sqlcommenter.py +++ b/api/extensions/otel/celery_sqlcommenter.py @@ -11,7 +11,7 @@ SQLAlchemy instrumentor appends comments to SQL statements. """ import logging -from typing import Any +from typing import Any, TypedDict from celery.signals import task_postrun, task_prerun from opentelemetry import context @@ -24,9 +24,17 @@ _SQLCOMMENTER_CONTEXT_KEY = "SQLCOMMENTER_ORM_TAGS_AND_VALUES" _TOKEN_ATTR = "_dify_sqlcommenter_context_token" -def _build_celery_sqlcommenter_tags(task: Any) -> dict[str, str | int]: +class CelerySqlcommenterTagsDict(TypedDict, total=False): + framework: str + task_name: str + traceparent: str + celery_retries: int + routing_key: str + + +def _build_celery_sqlcommenter_tags(task: Any) -> CelerySqlcommenterTagsDict: """Build SQL commenter tags from the current Celery task and OpenTelemetry context.""" - tags: dict[str, str | int] = {} + tags: CelerySqlcommenterTagsDict = {} try: tags["framework"] = f"celery:{_get_celery_version()}" diff --git a/api/extensions/otel/decorators/base.py b/api/extensions/otel/decorators/base.py index 1dd92caeaea..ad838264277 100644 --- a/api/extensions/otel/decorators/base.py +++ b/api/extensions/otel/decorators/base.py @@ -37,12 +37,7 @@ def trace_span[**P, R](handler_class: type[SpanHandler] | None = None) -> Callab handler = _get_handler_instance(handler_class or SpanHandler) tracer = get_tracer(__name__) - return handler.wrapper( - tracer=tracer, - wrapped=func, - args=args, - kwargs=kwargs, - ) + return handler.wrapper(tracer, func, *args, **kwargs) return cast(Callable[P, R], wrapper) diff --git a/api/extensions/otel/decorators/handler.py b/api/extensions/otel/decorators/handler.py index e465a615a65..b0d9fa7af63 100644 --- a/api/extensions/otel/decorators/handler.py +++ b/api/extensions/otel/decorators/handler.py @@ -1,8 +1,8 @@ import inspect -from collections.abc import Callable, Mapping +from collections.abc import Callable from typing import Any -from opentelemetry.trace import SpanKind, Status, StatusCode +from opentelemetry.trace import SpanKind, Status, StatusCode, Tracer class SpanHandler: @@ -16,9 +16,9 @@ class SpanHandler: exceptions. Handlers can override the wrapper method to customize behavior. """ - _signature_cache: dict[Callable[..., Any], inspect.Signature] = {} + _signature_cache: dict[Callable[..., object], inspect.Signature] = {} - def _build_span_name(self, wrapped: Callable[..., Any]) -> str: + def _build_span_name[**P, R](self, wrapped: Callable[P, R]) -> str: """ Build the span name from the wrapped function. @@ -29,11 +29,11 @@ class SpanHandler: """ return f"{wrapped.__module__}.{wrapped.__qualname__}" - def _extract_arguments[T]( + def _extract_arguments[**P, R]( self, - wrapped: Callable[..., T], - args: tuple[object, ...], - kwargs: Mapping[str, object], + wrapped: Callable[P, R], + *args: P.args, + **kwargs: P.kwargs, ) -> dict[str, Any] | None: """ Extract function arguments using inspect.signature. @@ -59,13 +59,13 @@ class SpanHandler: except Exception: return None - def wrapper[T]( + def wrapper[**P, R]( self, - tracer: Any, - wrapped: Callable[..., T], - args: tuple[object, ...], - kwargs: Mapping[str, object], - ) -> T: + tracer: Tracer, + wrapped: Callable[P, R], + *args: P.args, + **kwargs: P.kwargs, + ) -> R: """ Fully control the wrapper behavior. diff --git a/api/extensions/otel/decorators/handlers/generate_handler.py b/api/extensions/otel/decorators/handlers/generate_handler.py index cc6c75304f2..df5142c310b 100644 --- a/api/extensions/otel/decorators/handlers/generate_handler.py +++ b/api/extensions/otel/decorators/handlers/generate_handler.py @@ -1,8 +1,7 @@ import logging -from collections.abc import Callable, Mapping -from typing import Any +from collections.abc import Callable -from opentelemetry.trace import SpanKind, Status, StatusCode +from opentelemetry.trace import SpanKind, Status, StatusCode, Tracer from opentelemetry.util.types import AttributeValue from extensions.otel.decorators.handler import SpanHandler @@ -15,15 +14,15 @@ logger = logging.getLogger(__name__) class AppGenerateHandler(SpanHandler): """Span handler for ``AppGenerateService.generate``.""" - def wrapper[T]( + def wrapper[**P, R]( self, - tracer: Any, - wrapped: Callable[..., T], - args: tuple[object, ...], - kwargs: Mapping[str, object], - ) -> T: + tracer: Tracer, + wrapped: Callable[P, R], + *args: P.args, + **kwargs: P.kwargs, + ) -> R: try: - arguments = self._extract_arguments(wrapped, args, kwargs) + arguments = self._extract_arguments(wrapped, *args, **kwargs) if not arguments: return wrapped(*args, **kwargs) diff --git a/api/extensions/otel/decorators/handlers/workflow_app_runner_handler.py b/api/extensions/otel/decorators/handlers/workflow_app_runner_handler.py index 8abd60197c6..6b2112ceb2d 100644 --- a/api/extensions/otel/decorators/handlers/workflow_app_runner_handler.py +++ b/api/extensions/otel/decorators/handlers/workflow_app_runner_handler.py @@ -1,8 +1,7 @@ import logging -from collections.abc import Callable, Mapping -from typing import Any +from collections.abc import Callable -from opentelemetry.trace import SpanKind, Status, StatusCode +from opentelemetry.trace import SpanKind, Status, StatusCode, Tracer from opentelemetry.util.types import AttributeValue from extensions.otel.decorators.handler import SpanHandler @@ -14,15 +13,15 @@ logger = logging.getLogger(__name__) class WorkflowAppRunnerHandler(SpanHandler): """Span handler for ``WorkflowAppRunner.run``.""" - def wrapper( + def wrapper[**P, R]( self, - tracer: Any, - wrapped: Callable[..., Any], - args: tuple[Any, ...], - kwargs: Mapping[str, Any], - ) -> Any: + tracer: Tracer, + wrapped: Callable[P, R], + *args: P.args, + **kwargs: P.kwargs, + ) -> R: try: - arguments = self._extract_arguments(wrapped, args, kwargs) + arguments = self._extract_arguments(wrapped, *args, **kwargs) if not arguments: return wrapped(*args, **kwargs) diff --git a/api/extensions/otel/parser/base.py b/api/extensions/otel/parser/base.py index 23d324f9ead..fbf379b3e55 100644 --- a/api/extensions/otel/parser/base.py +++ b/api/extensions/otel/parser/base.py @@ -10,17 +10,17 @@ Gate is only active in EE (``ENTERPRISE_ENABLED=True``) when import json from typing import Any, Protocol -from graphon.enums import BuiltinNodeTypes -from graphon.file import File -from graphon.graph_events import GraphNodeEventBase -from graphon.nodes.base.node import Node -from graphon.variables import Segment from opentelemetry.trace import Span from opentelemetry.trace.status import Status, StatusCode from pydantic import BaseModel from configs import dify_config from extensions.otel.semconv.gen_ai import ChainAttributes, GenAIAttributes +from graphon.enums import BuiltinNodeTypes +from graphon.file import File +from graphon.graph_events import GraphNodeEventBase +from graphon.nodes.base.node import Node +from graphon.variables import Segment def should_include_content() -> bool: diff --git a/api/extensions/otel/parser/llm.py b/api/extensions/otel/parser/llm.py index 335c5cc29e2..ec3c78a12d4 100644 --- a/api/extensions/otel/parser/llm.py +++ b/api/extensions/otel/parser/llm.py @@ -6,12 +6,12 @@ import logging from collections.abc import Mapping from typing import Any -from graphon.graph_events import GraphNodeEventBase -from graphon.nodes.base.node import Node from opentelemetry.trace import Span from extensions.otel.parser.base import DefaultNodeOTelParser, safe_json_dumps from extensions.otel.semconv.gen_ai import LLMAttributes +from graphon.graph_events import GraphNodeEventBase +from graphon.nodes.base.node import Node logger = logging.getLogger(__name__) diff --git a/api/extensions/otel/parser/retrieval.py b/api/extensions/otel/parser/retrieval.py index 6df5f62c155..56672d1fd45 100644 --- a/api/extensions/otel/parser/retrieval.py +++ b/api/extensions/otel/parser/retrieval.py @@ -6,13 +6,13 @@ import logging from collections.abc import Sequence from typing import Any -from graphon.graph_events import GraphNodeEventBase -from graphon.nodes.base.node import Node -from graphon.variables import Segment from opentelemetry.trace import Span from extensions.otel.parser.base import DefaultNodeOTelParser, safe_json_dumps from extensions.otel.semconv.gen_ai import RetrieverAttributes +from graphon.graph_events import GraphNodeEventBase +from graphon.nodes.base.node import Node +from graphon.variables import Segment logger = logging.getLogger(__name__) diff --git a/api/extensions/otel/parser/tool.py b/api/extensions/otel/parser/tool.py index b9fdd9e1caa..75ddbba4480 100644 --- a/api/extensions/otel/parser/tool.py +++ b/api/extensions/otel/parser/tool.py @@ -2,14 +2,14 @@ Parser for tool nodes that captures tool-specific metadata. """ -from graphon.enums import WorkflowNodeExecutionMetadataKey -from graphon.graph_events import GraphNodeEventBase -from graphon.nodes.base.node import Node -from graphon.nodes.tool.entities import ToolNodeData from opentelemetry.trace import Span from extensions.otel.parser.base import DefaultNodeOTelParser, safe_json_dumps from extensions.otel.semconv.gen_ai import ToolAttributes +from graphon.enums import WorkflowNodeExecutionMetadataKey +from graphon.graph_events import GraphNodeEventBase +from graphon.nodes.base.node import Node +from graphon.nodes.tool.entities import ToolNodeData class ToolNodeOTelParser: diff --git a/api/extensions/redis_names.py b/api/extensions/redis_names.py new file mode 100644 index 00000000000..9e63416daf1 --- /dev/null +++ b/api/extensions/redis_names.py @@ -0,0 +1,32 @@ +from configs import dify_config + + +def normalize_redis_key_prefix(prefix: str | None) -> str: + """Normalize the configured Redis key prefix for consistent runtime use.""" + if prefix is None: + return "" + return prefix.strip() + + +def get_redis_key_prefix() -> str: + """Read and normalize the current Redis key prefix from config.""" + return normalize_redis_key_prefix(dify_config.REDIS_KEY_PREFIX) + + +def serialize_redis_name(name: str, prefix: str | None = None) -> str: + """Convert a logical Redis name into the physical name used in Redis.""" + normalized_prefix = get_redis_key_prefix() if prefix is None else normalize_redis_key_prefix(prefix) + if not normalized_prefix: + return name + return f"{normalized_prefix}:{name}" + + +def serialize_redis_name_arg(name: str | bytes, prefix: str | None = None) -> str | bytes: + """Prefix string Redis names while preserving bytes inputs unchanged.""" + if isinstance(name, bytes): + return name + return serialize_redis_name(name, prefix) + + +def serialize_redis_name_args(names: tuple[str | bytes, ...], prefix: str | None = None) -> tuple[str | bytes, ...]: + return tuple(serialize_redis_name_arg(name, prefix) for name in names) diff --git a/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py b/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py index 18eed4e4810..05492327c83 100644 --- a/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py +++ b/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py @@ -10,6 +10,7 @@ import tempfile from collections.abc import Generator from io import BytesIO from pathlib import Path +from typing import Any import clickzetta from pydantic import BaseModel, model_validator @@ -39,7 +40,7 @@ class ClickZettaVolumeConfig(BaseModel): @model_validator(mode="before") @classmethod - def validate_config(cls, values: dict): + def validate_config(cls, values: dict[str, Any]): """Validate the configuration values. This method will first try to use CLICKZETTA_VOLUME_* environment variables, diff --git a/api/extensions/storage/clickzetta_volume/file_lifecycle.py b/api/extensions/storage/clickzetta_volume/file_lifecycle.py index 483bd6bbf69..1cb940b797b 100644 --- a/api/extensions/storage/clickzetta_volume/file_lifecycle.py +++ b/api/extensions/storage/clickzetta_volume/file_lifecycle.py @@ -13,7 +13,7 @@ import operator from dataclasses import asdict, dataclass from datetime import datetime from enum import StrEnum, auto -from typing import Any +from typing import Any, TypedDict from pydantic import TypeAdapter @@ -22,6 +22,17 @@ logger = logging.getLogger(__name__) _metadata_adapter: TypeAdapter[dict[str, Any]] = TypeAdapter(dict[str, Any]) +class StorageStatisticsDict(TypedDict): + total_files: int + active_files: int + archived_files: int + deleted_files: int + total_size: int + versions_count: int + oldest_file: str | None + newest_file: str | None + + class FileStatus(StrEnum): """File status enumeration""" @@ -54,7 +65,7 @@ class FileMetadata: return data @classmethod - def from_dict(cls, data: dict) -> FileMetadata: + def from_dict(cls, data: dict[str, Any]) -> FileMetadata: """Create instance from dictionary""" data = data.copy() data["created_at"] = datetime.fromisoformat(data["created_at"]) @@ -384,7 +395,7 @@ class FileLifecycleManager: logger.exception("Failed to cleanup old versions") return 0 - def get_storage_statistics(self) -> dict[str, Any]: + def get_storage_statistics(self) -> StorageStatisticsDict: """Get storage statistics Returns: @@ -393,16 +404,16 @@ class FileLifecycleManager: try: metadata_dict = self._load_metadata() - stats: dict[str, Any] = { - "total_files": len(metadata_dict), - "active_files": 0, - "archived_files": 0, - "deleted_files": 0, - "total_size": 0, - "versions_count": 0, - "oldest_file": None, - "newest_file": None, - } + stats = StorageStatisticsDict( + total_files=len(metadata_dict), + active_files=0, + archived_files=0, + deleted_files=0, + total_size=0, + versions_count=0, + oldest_file=None, + newest_file=None, + ) oldest_date = None newest_date = None @@ -437,9 +448,18 @@ class FileLifecycleManager: except Exception: logger.exception("Failed to get storage statistics") - return {} + return StorageStatisticsDict( + total_files=0, + active_files=0, + archived_files=0, + deleted_files=0, + total_size=0, + versions_count=0, + oldest_file=None, + newest_file=None, + ) - def _create_version_backup(self, filename: str, metadata: dict): + def _create_version_backup(self, filename: str, metadata: dict[str, Any]): """Create version backup""" try: # Read current file content @@ -467,7 +487,7 @@ class FileLifecycleManager: logger.warning("Failed to load metadata: %s", e) return {} - def _save_metadata(self, metadata_dict: dict): + def _save_metadata(self, metadata_dict: dict[str, Any]): """Save metadata file""" try: metadata_content = json.dumps(metadata_dict, indent=2, ensure_ascii=False) diff --git a/api/extensions/storage/opendal_storage.py b/api/extensions/storage/opendal_storage.py index 96f5915ff03..cd7f7db2959 100644 --- a/api/extensions/storage/opendal_storage.py +++ b/api/extensions/storage/opendal_storage.py @@ -2,6 +2,7 @@ import logging import os from collections.abc import Generator from pathlib import Path +from typing import Any import opendal from dotenv import dotenv_values @@ -19,7 +20,7 @@ def _get_opendal_kwargs(*, scheme: str, env_file_path: str = ".env", prefix: str if key.startswith(config_prefix): kwargs[key[len(config_prefix) :].lower()] = value - file_env_vars: dict = dotenv_values(env_file_path) or {} + file_env_vars: dict[str, Any] = dotenv_values(env_file_path) or {} for key, value in file_env_vars.items(): if key.startswith(config_prefix) and key[len(config_prefix) :].lower() not in kwargs and value: kwargs[key[len(config_prefix) :].lower()] = value diff --git a/api/factories/file_factory/builders.py b/api/factories/file_factory/builders.py index 7516d18c8e0..1d2ad4d4451 100644 --- a/api/factories/file_factory/builders.py +++ b/api/factories/file_factory/builders.py @@ -7,12 +7,12 @@ import uuid from collections.abc import Mapping, Sequence from typing import Any -from graphon.file import File, FileTransferMethod, FileType, FileUploadConfig, helpers, standardize_file_type from sqlalchemy import select from core.app.file_access import FileAccessControllerProtocol +from core.db.session_factory import session_factory from core.workflow.file_reference import build_file_reference -from extensions.ext_database import db +from graphon.file import File, FileTransferMethod, FileType, FileUploadConfig, helpers, standardize_file_type from models import ToolFile, UploadFile from .common import resolve_mapping_file_id @@ -135,29 +135,30 @@ def _build_from_local_file( UploadFile.id == upload_file_id, UploadFile.tenant_id == tenant_id, ) - row = db.session.scalar(access_controller.apply_upload_file_filters(stmt)) - if row is None: - raise ValueError("Invalid upload file") + with session_factory.create_session() as session: + row = session.scalar(access_controller.apply_upload_file_filters(stmt)) + if row is None: + raise ValueError("Invalid upload file") - detected_file_type = standardize_file_type(extension="." + row.extension, mime_type=row.mime_type) - file_type = _resolve_file_type( - detected_file_type=detected_file_type, - specified_type=mapping.get("type", "custom"), - strict_type_validation=strict_type_validation, - ) + detected_file_type = standardize_file_type(extension="." + row.extension, mime_type=row.mime_type) + file_type = _resolve_file_type( + detected_file_type=detected_file_type, + specified_type=mapping.get("type", "custom"), + strict_type_validation=strict_type_validation, + ) - return File( - id=mapping.get("id"), - filename=row.name, - extension="." + row.extension, - mime_type=row.mime_type, - type=file_type, - transfer_method=transfer_method, - remote_url=row.source_url, - reference=build_file_reference(record_id=str(row.id)), - size=row.size, - storage_key=row.key, - ) + return File( + file_id=mapping.get("id"), + filename=row.name, + extension="." + row.extension, + mime_type=row.mime_type, + file_type=file_type, + transfer_method=transfer_method, + remote_url=row.source_url, + reference=build_file_reference(record_id=str(row.id)), + size=row.size, + storage_key=row.key, + ) def _build_from_remote_url( @@ -179,32 +180,33 @@ def _build_from_remote_url( UploadFile.id == upload_file_id, UploadFile.tenant_id == tenant_id, ) - upload_file = db.session.scalar(access_controller.apply_upload_file_filters(stmt)) - if upload_file is None: - raise ValueError("Invalid upload file") + with session_factory.create_session() as session: + upload_file = session.scalar(access_controller.apply_upload_file_filters(stmt)) + if upload_file is None: + raise ValueError("Invalid upload file") - detected_file_type = standardize_file_type( - extension="." + upload_file.extension, - mime_type=upload_file.mime_type, - ) - file_type = _resolve_file_type( - detected_file_type=detected_file_type, - specified_type=mapping.get("type"), - strict_type_validation=strict_type_validation, - ) + detected_file_type = standardize_file_type( + extension="." + upload_file.extension, + mime_type=upload_file.mime_type, + ) + file_type = _resolve_file_type( + detected_file_type=detected_file_type, + specified_type=mapping.get("type"), + strict_type_validation=strict_type_validation, + ) - return File( - id=mapping.get("id"), - filename=upload_file.name, - extension="." + upload_file.extension, - mime_type=upload_file.mime_type, - type=file_type, - transfer_method=transfer_method, - remote_url=helpers.get_signed_file_url(upload_file_id=str(upload_file_id)), - reference=build_file_reference(record_id=str(upload_file.id)), - size=upload_file.size, - storage_key=upload_file.key, - ) + return File( + file_id=mapping.get("id"), + filename=upload_file.name, + extension="." + upload_file.extension, + mime_type=upload_file.mime_type, + file_type=file_type, + transfer_method=transfer_method, + remote_url=helpers.get_signed_file_url(upload_file_id=str(upload_file_id)), + reference=build_file_reference(record_id=str(upload_file.id)), + size=upload_file.size, + storage_key=upload_file.key, + ) url = mapping.get("url") or mapping.get("remote_url") if not url: @@ -220,9 +222,9 @@ def _build_from_remote_url( ) return File( - id=mapping.get("id"), + file_id=mapping.get("id"), filename=filename, - type=file_type, + file_type=file_type, transfer_method=transfer_method, remote_url=url, mime_type=mime_type, @@ -247,30 +249,31 @@ def _build_from_tool_file( ToolFile.id == tool_file_id, ToolFile.tenant_id == tenant_id, ) - tool_file = db.session.scalar(access_controller.apply_tool_file_filters(stmt)) - if tool_file is None: - raise ValueError(f"ToolFile {tool_file_id} not found") + with session_factory.create_session() as session: + tool_file = session.scalar(access_controller.apply_tool_file_filters(stmt)) + if tool_file is None: + raise ValueError(f"ToolFile {tool_file_id} not found") - extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin" - detected_file_type = standardize_file_type(extension=extension, mime_type=tool_file.mimetype) - file_type = _resolve_file_type( - detected_file_type=detected_file_type, - specified_type=mapping.get("type"), - strict_type_validation=strict_type_validation, - ) + extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin" + detected_file_type = standardize_file_type(extension=extension, mime_type=tool_file.mimetype) + file_type = _resolve_file_type( + detected_file_type=detected_file_type, + specified_type=mapping.get("type"), + strict_type_validation=strict_type_validation, + ) - return File( - id=mapping.get("id"), - filename=tool_file.name, - type=file_type, - transfer_method=transfer_method, - remote_url=tool_file.original_url, - reference=build_file_reference(record_id=str(tool_file.id)), - extension=extension, - mime_type=tool_file.mimetype, - size=tool_file.size, - storage_key=tool_file.file_key, - ) + return File( + file_id=mapping.get("id"), + filename=tool_file.name, + file_type=file_type, + transfer_method=transfer_method, + remote_url=tool_file.original_url, + reference=build_file_reference(record_id=str(tool_file.id)), + extension=extension, + mime_type=tool_file.mimetype, + size=tool_file.size, + storage_key=tool_file.file_key, + ) def _build_from_datasource_file( @@ -289,31 +292,32 @@ def _build_from_datasource_file( UploadFile.id == datasource_file_id, UploadFile.tenant_id == tenant_id, ) - datasource_file = db.session.scalar(access_controller.apply_upload_file_filters(stmt)) - if datasource_file is None: - raise ValueError(f"DatasourceFile {mapping.get('datasource_file_id')} not found") + with session_factory.create_session() as session: + datasource_file = session.scalar(access_controller.apply_upload_file_filters(stmt)) + if datasource_file is None: + raise ValueError(f"DatasourceFile {mapping.get('datasource_file_id')} not found") - extension = "." + datasource_file.key.split(".")[-1] if "." in datasource_file.key else ".bin" - detected_file_type = standardize_file_type(extension="." + extension, mime_type=datasource_file.mime_type) - file_type = _resolve_file_type( - detected_file_type=detected_file_type, - specified_type=mapping.get("type"), - strict_type_validation=strict_type_validation, - ) + extension = "." + datasource_file.key.split(".")[-1] if "." in datasource_file.key else ".bin" + detected_file_type = standardize_file_type(extension="." + extension, mime_type=datasource_file.mime_type) + file_type = _resolve_file_type( + detected_file_type=detected_file_type, + specified_type=mapping.get("type"), + strict_type_validation=strict_type_validation, + ) - return File( - id=mapping.get("datasource_file_id"), - filename=datasource_file.name, - type=file_type, - transfer_method=FileTransferMethod.TOOL_FILE, - remote_url=datasource_file.source_url, - reference=build_file_reference(record_id=str(datasource_file.id)), - extension=extension, - mime_type=datasource_file.mime_type, - size=datasource_file.size, - storage_key=datasource_file.key, - url=datasource_file.source_url, - ) + return File( + file_id=mapping.get("datasource_file_id"), + filename=datasource_file.name, + file_type=file_type, + transfer_method=FileTransferMethod.TOOL_FILE, + remote_url=datasource_file.source_url, + reference=build_file_reference(record_id=str(datasource_file.id)), + extension=extension, + mime_type=datasource_file.mime_type, + size=datasource_file.size, + storage_key=datasource_file.key, + url=datasource_file.source_url, + ) def _is_valid_mapping(mapping: Mapping[str, Any]) -> bool: diff --git a/api/factories/file_factory/message_files.py b/api/factories/file_factory/message_files.py index 5582b85c956..4b3d5142386 100644 --- a/api/factories/file_factory/message_files.py +++ b/api/factories/file_factory/message_files.py @@ -4,9 +4,8 @@ from __future__ import annotations from collections.abc import Sequence -from graphon.file import File, FileBelongsTo, FileTransferMethod, FileUploadConfig - from core.app.file_access import FileAccessControllerProtocol +from graphon.file import File, FileBelongsTo, FileTransferMethod, FileUploadConfig from models import MessageFile from .builders import build_from_mapping diff --git a/api/factories/file_factory/storage_keys.py b/api/factories/file_factory/storage_keys.py index db3a7f30159..dba4c84407f 100644 --- a/api/factories/file_factory/storage_keys.py +++ b/api/factories/file_factory/storage_keys.py @@ -5,12 +5,12 @@ from __future__ import annotations import uuid from collections.abc import Mapping, Sequence -from graphon.file import File, FileTransferMethod from sqlalchemy import select from sqlalchemy.orm import Session from core.app.file_access import FileAccessControllerProtocol from core.workflow.file_reference import build_file_reference, parse_file_reference +from graphon.file import File, FileTransferMethod from models import ToolFile, UploadFile diff --git a/api/factories/variable_factory.py b/api/factories/variable_factory.py index 57205b5739f..fd7acb14d3a 100644 --- a/api/factories/variable_factory.py +++ b/api/factories/variable_factory.py @@ -8,6 +8,11 @@ shared conversion functions for legacy callers and tests. from collections.abc import Mapping, Sequence from typing import Any, cast +from configs import dify_config +from core.workflow.variable_prefixes import ( + CONVERSATION_VARIABLE_NODE_ID, + ENVIRONMENT_VARIABLE_NODE_ID, +) from graphon.variables.exc import VariableError from graphon.variables.factory import ( TypeMismatchError, @@ -31,12 +36,6 @@ from graphon.variables.variables import ( VariableBase, ) -from configs import dify_config -from core.workflow.variable_prefixes import ( - CONVERSATION_VARIABLE_NODE_ID, - ENVIRONMENT_VARIABLE_NODE_ID, -) - __all__ = [ "TypeMismatchError", "UnsupportedSegmentTypeError", diff --git a/api/fields/_value_type_serializer.py b/api/fields/_value_type_serializer.py index b5acbbbcb4d..d518114777a 100644 --- a/api/fields/_value_type_serializer.py +++ b/api/fields/_value_type_serializer.py @@ -10,9 +10,9 @@ class _VarTypedDict(TypedDict, total=False): def serialize_value_type(v: _VarTypedDict | Segment) -> str: if isinstance(v, Segment): - return v.value_type.exposed_type().value + return str(v.value_type.exposed_type()) else: value_type = v.get("value_type") if value_type is None: raise ValueError("value_type is required but not provided") - return value_type.exposed_type().value + return str(value_type.exposed_type()) diff --git a/api/fields/annotation_fields.py b/api/fields/annotation_fields.py index a6469507226..b2a0e92c476 100644 --- a/api/fields/annotation_fields.py +++ b/api/fields/annotation_fields.py @@ -2,7 +2,9 @@ from __future__ import annotations from datetime import datetime -from pydantic import BaseModel, ConfigDict, Field, field_validator +from pydantic import Field, field_validator + +from fields.base import ResponseModel def _to_timestamp(value: datetime | int | None) -> int | None: @@ -11,16 +13,6 @@ def _to_timestamp(value: datetime | int | None) -> int | None: return value -class ResponseModel(BaseModel): - model_config = ConfigDict( - from_attributes=True, - extra="ignore", - populate_by_name=True, - serialize_by_alias=True, - protected_namespaces=(), - ) - - class Annotation(ResponseModel): id: str question: str | None = None diff --git a/api/fields/base.py b/api/fields/base.py new file mode 100644 index 00000000000..b806ab6c9c8 --- /dev/null +++ b/api/fields/base.py @@ -0,0 +1,13 @@ +from __future__ import annotations + +from pydantic import BaseModel, ConfigDict + + +class ResponseModel(BaseModel): + model_config = ConfigDict( + from_attributes=True, + extra="ignore", + populate_by_name=True, + serialize_by_alias=True, + protected_namespaces=(), + ) diff --git a/api/fields/conversation_fields.py b/api/fields/conversation_fields.py index b1d1b4caac3..bf5c9ffcb12 100644 --- a/api/fields/conversation_fields.py +++ b/api/fields/conversation_fields.py @@ -3,22 +3,14 @@ from __future__ import annotations from datetime import datetime from typing import Any +from pydantic import Field, field_validator, model_validator + +from fields.base import ResponseModel from graphon.file import File -from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator type JSONValue = Any -class ResponseModel(BaseModel): - model_config = ConfigDict( - from_attributes=True, - extra="ignore", - populate_by_name=True, - serialize_by_alias=True, - protected_namespaces=(), - ) - - class MessageFile(ResponseModel): id: str filename: str @@ -88,7 +80,7 @@ class Feedback(ResponseModel): from_account: SimpleAccount | None = None -class Annotation(ResponseModel): +class ConversationAnnotation(ResponseModel): id: str question: str | None = None content: str @@ -103,8 +95,8 @@ class Annotation(ResponseModel): return value -class AnnotationHitHistory(ResponseModel): - annotation_id: str +class ConversationAnnotationHitHistory(ResponseModel): + annotation_id: str = Field(validation_alias="id") annotation_create_account: SimpleAccount | None = None created_at: int | None = None @@ -151,7 +143,7 @@ class MessageDetail(ResponseModel): query: str message: JSONValue message_tokens: int - answer: str + answer: str = Field(validation_alias="re_sign_file_url_answer") answer_tokens: int provider_response_latency: float from_source: str @@ -159,12 +151,12 @@ class MessageDetail(ResponseModel): from_account_id: str | None = None feedbacks: list[Feedback] workflow_run_id: str | None = None - annotation: Annotation | None = None - annotation_hit_history: AnnotationHitHistory | None = None + annotation: ConversationAnnotation | None = None + annotation_hit_history: ConversationAnnotationHitHistory | None = None created_at: int | None = None agent_thoughts: list[AgentThought] message_files: list[MessageFile] - metadata: JSONValue + metadata: JSONValue = Field(validation_alias="message_metadata_dict") status: str error: str | None = None parent_message_id: str | None = None @@ -204,7 +196,7 @@ class ModelConfig(ResponseModel): class SimpleModelConfig(ResponseModel): - model: JSONValue | None = None + model: JSONValue | None = Field(default=None, validation_alias="model_dict") pre_prompt: str | None = None @@ -219,6 +211,11 @@ class SimpleMessageDetail(ResponseModel): def _normalize_inputs(cls, value: JSONValue) -> JSONValue: return format_files_contained(value) + @field_validator("message", mode="before") + @classmethod + def _normalize_message(cls, value: JSONValue) -> str: + return message_text(value) + class Conversation(ResponseModel): id: str @@ -231,19 +228,26 @@ class Conversation(ResponseModel): read_at: int | None = None created_at: int | None = None updated_at: int | None = None - annotation: Annotation | None = None + annotation: ConversationAnnotation | None = None model_config_: SimpleModelConfig | None = Field(default=None, alias="model_config") user_feedback_stats: FeedbackStat | None = None admin_feedback_stats: FeedbackStat | None = None - message: SimpleMessageDetail | None = None + message: SimpleMessageDetail | None = Field(default=None, validation_alias="first_message") + + @field_validator("read_at", "created_at", "updated_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return to_timestamp(value) + return value class ConversationPagination(ResponseModel): page: int - limit: int + limit: int = Field(validation_alias="per_page") total: int - has_more: bool - data: list[Conversation] + has_more: bool = Field(validation_alias="has_next") + data: list[Conversation] = Field(validation_alias="items") class ConversationMessageDetail(ResponseModel): @@ -254,7 +258,14 @@ class ConversationMessageDetail(ResponseModel): from_account_id: str | None = None created_at: int | None = None model_config_: ModelConfig | None = Field(default=None, alias="model_config") - message: MessageDetail | None = None + message: MessageDetail | None = Field(default=None, validation_alias="first_message") + + @field_validator("created_at", mode="before") + @classmethod + def _normalize_created_at(cls, value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return to_timestamp(value) + return value class ConversationWithSummary(ResponseModel): @@ -266,7 +277,7 @@ class ConversationWithSummary(ResponseModel): from_account_id: str | None = None from_account_name: str | None = None name: str - summary: str + summary: str = Field(validation_alias="summary_or_query") read_at: int | None = None created_at: int | None = None updated_at: int | None = None @@ -277,13 +288,20 @@ class ConversationWithSummary(ResponseModel): admin_feedback_stats: FeedbackStat | None = None status_count: StatusCount | None = None + @field_validator("read_at", "created_at", "updated_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return to_timestamp(value) + return value + class ConversationWithSummaryPagination(ResponseModel): page: int - limit: int + limit: int = Field(validation_alias="per_page") total: int - has_more: bool - data: list[ConversationWithSummary] + has_more: bool = Field(validation_alias="has_next") + data: list[ConversationWithSummary] = Field(validation_alias="items") class ConversationDetail(ResponseModel): @@ -301,6 +319,13 @@ class ConversationDetail(ResponseModel): user_feedback_stats: FeedbackStat | None = None admin_feedback_stats: FeedbackStat | None = None + @field_validator("created_at", "updated_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return to_timestamp(value) + return value + def to_timestamp(value: datetime | None) -> int | None: if value is None: diff --git a/api/fields/conversation_variable_fields.py b/api/fields/conversation_variable_fields.py index c55014a368e..e4219ba1ee4 100644 --- a/api/fields/conversation_variable_fields.py +++ b/api/fields/conversation_variable_fields.py @@ -1,5 +1,13 @@ -from flask_restx import Namespace, fields +from __future__ import annotations +from datetime import datetime +from typing import Any + +from flask_restx import Namespace, fields +from pydantic import field_validator + +from fields.base import ResponseModel +from graphon.variables.types import SegmentType from libs.helper import TimestampField from ._value_type_serializer import serialize_value_type @@ -29,6 +37,74 @@ conversation_variable_infinite_scroll_pagination_fields = { } +def _to_timestamp(value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return int(value.timestamp()) + return value + + +class ConversationVariableResponse(ResponseModel): + id: str + name: str + value_type: str + value: str | None = None + description: str | None = None + created_at: int | None = None + updated_at: int | None = None + + @field_validator("value_type", mode="before") + @classmethod + def _normalize_value_type(cls, value: Any) -> str: + exposed_type = getattr(value, "exposed_type", None) + if callable(exposed_type): + return str(exposed_type()) + if isinstance(value, str): + try: + return str(SegmentType(value).exposed_type()) + except ValueError: + return value + try: + return serialize_value_type(value) + except (AttributeError, TypeError, ValueError): + pass + + try: + return serialize_value_type({"value_type": value}) + except (AttributeError, TypeError, ValueError): + value_attr = getattr(value, "value", None) + if value_attr is not None: + return str(value_attr) + return str(value) + + @field_validator("value", mode="before") + @classmethod + def _normalize_value(cls, value: Any | None) -> str | None: + if value is None: + return None + if isinstance(value, str): + return value + return str(value) + + @field_validator("created_at", "updated_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +class PaginatedConversationVariableResponse(ResponseModel): + page: int + limit: int + total: int + has_more: bool + data: list[ConversationVariableResponse] + + +class ConversationVariableInfiniteScrollPaginationResponse(ResponseModel): + limit: int + has_more: bool + data: list[ConversationVariableResponse] + + def build_conversation_variable_model(api_or_ns: Namespace): """Build the conversation variable model for the API or Namespace.""" return api_or_ns.model("ConversationVariable", conversation_variable_fields) diff --git a/api/fields/end_user_fields.py b/api/fields/end_user_fields.py index df1980616aa..3851933cc2e 100644 --- a/api/fields/end_user_fields.py +++ b/api/fields/end_user_fields.py @@ -3,7 +3,9 @@ from __future__ import annotations from datetime import datetime from flask_restx import fields -from pydantic import BaseModel, ConfigDict, Field +from pydantic import Field + +from fields.base import ResponseModel simple_end_user_fields = { "id": fields.String, @@ -26,16 +28,6 @@ end_user_detail_fields = { } -class ResponseModel(BaseModel): - model_config = ConfigDict( - from_attributes=True, - extra="ignore", - populate_by_name=True, - serialize_by_alias=True, - protected_namespaces=(), - ) - - class SimpleEndUser(ResponseModel): id: str type: str diff --git a/api/fields/file_fields.py b/api/fields/file_fields.py index 913fb675f9e..ad8b95e4dc5 100644 --- a/api/fields/file_fields.py +++ b/api/fields/file_fields.py @@ -2,17 +2,9 @@ from __future__ import annotations from datetime import datetime -from pydantic import BaseModel, ConfigDict, field_validator +from pydantic import field_validator - -class ResponseModel(BaseModel): - model_config = ConfigDict( - from_attributes=True, - extra="ignore", - populate_by_name=True, - serialize_by_alias=True, - protected_namespaces=(), - ) +from fields.base import ResponseModel def _to_timestamp(value: datetime | int | None) -> int | None: diff --git a/api/fields/member_fields.py b/api/fields/member_fields.py index b8daa5af303..67b320beaa0 100644 --- a/api/fields/member_fields.py +++ b/api/fields/member_fields.py @@ -3,8 +3,10 @@ from __future__ import annotations from datetime import datetime from flask_restx import fields +from pydantic import computed_field, field_validator + +from fields.base import ResponseModel from graphon.file import helpers as file_helpers -from pydantic import BaseModel, ConfigDict, computed_field, field_validator simple_account_fields = { "id": fields.String, @@ -27,16 +29,6 @@ def _build_avatar_url(avatar: str | None) -> str | None: return file_helpers.get_signed_file_url(avatar) -class ResponseModel(BaseModel): - model_config = ConfigDict( - from_attributes=True, - extra="ignore", - populate_by_name=True, - serialize_by_alias=True, - protected_namespaces=(), - ) - - class SimpleAccount(ResponseModel): id: str name: str diff --git a/api/fields/message_fields.py b/api/fields/message_fields.py index a063a643b45..ca18f1c203d 100644 --- a/api/fields/message_fields.py +++ b/api/fields/message_fields.py @@ -3,19 +3,16 @@ from __future__ import annotations from datetime import datetime from uuid import uuid4 -from graphon.file import File -from pydantic import BaseModel, ConfigDict, Field, field_validator +from pydantic import Field, field_validator from core.entities.execution_extra_content import ExecutionExtraContentDomainModel +from fields.base import ResponseModel from fields.conversation_fields import AgentThought, JSONValue, MessageFile +from graphon.file import File type JSONValueType = JSONValue -class ResponseModel(BaseModel): - model_config = ConfigDict(from_attributes=True, extra="ignore") - - class SimpleFeedback(ResponseModel): rating: str | None = None diff --git a/api/fields/online_user_fields.py b/api/fields/online_user_fields.py new file mode 100644 index 00000000000..bdbe19679ce --- /dev/null +++ b/api/fields/online_user_fields.py @@ -0,0 +1,16 @@ +from flask_restx import fields + +online_user_partial_fields = { + "user_id": fields.String, + "username": fields.String, + "avatar": fields.String, +} + +workflow_online_users_fields = { + "app_id": fields.String, + "users": fields.List(fields.Nested(online_user_partial_fields)), +} + +online_user_list_fields = { + "data": fields.List(fields.Nested(workflow_online_users_fields)), +} diff --git a/api/fields/raws.py b/api/fields/raws.py index 4c65cdab7af..ee6f53b360c 100644 --- a/api/fields/raws.py +++ b/api/fields/raws.py @@ -1,4 +1,5 @@ from flask_restx import fields + from graphon.file import File diff --git a/api/fields/tag_fields.py b/api/fields/tag_fields.py index 7cb64e5ca8e..a3629f477a0 100644 --- a/api/fields/tag_fields.py +++ b/api/fields/tag_fields.py @@ -1,16 +1,6 @@ from __future__ import annotations -from pydantic import BaseModel, ConfigDict - - -class ResponseModel(BaseModel): - model_config = ConfigDict( - from_attributes=True, - extra="ignore", - populate_by_name=True, - serialize_by_alias=True, - protected_namespaces=(), - ) +from fields.base import ResponseModel class DataSetTag(ResponseModel): diff --git a/api/fields/workflow_app_log_fields.py b/api/fields/workflow_app_log_fields.py index d0e762f62b3..1b2c71255dd 100644 --- a/api/fields/workflow_app_log_fields.py +++ b/api/fields/workflow_app_log_fields.py @@ -1,8 +1,17 @@ -from flask_restx import Namespace, fields +from __future__ import annotations -from fields.end_user_fields import simple_end_user_fields -from fields.member_fields import simple_account_fields +from datetime import datetime +from typing import Any + +from flask_restx import Namespace, fields +from pydantic import field_validator + +from fields.base import ResponseModel +from fields.end_user_fields import SimpleEndUser, simple_end_user_fields +from fields.member_fields import SimpleAccount, simple_account_fields from fields.workflow_run_fields import ( + WorkflowRunForArchivedLogResponse, + WorkflowRunForLogResponse, build_workflow_run_for_archived_log_model, build_workflow_run_for_log_model, workflow_run_for_archived_log_fields, @@ -85,3 +94,55 @@ def build_workflow_archived_log_pagination_model(api_or_ns: Namespace): copied_fields = workflow_archived_log_pagination_fields.copy() copied_fields["data"] = fields.List(fields.Nested(workflow_archived_log_partial_model)) return api_or_ns.model("WorkflowArchivedLogPagination", copied_fields) + + +def _to_timestamp(value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return int(value.timestamp()) + return value + + +class WorkflowAppLogPartialResponse(ResponseModel): + id: str + workflow_run: WorkflowRunForLogResponse | None = None + details: Any = None + created_from: str | None = None + created_by_role: str | None = None + created_by_account: SimpleAccount | None = None + created_by_end_user: SimpleEndUser | None = None + created_at: int | None = None + + @field_validator("created_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +class WorkflowArchivedLogPartialResponse(ResponseModel): + id: str + workflow_run: WorkflowRunForArchivedLogResponse | None = None + trigger_metadata: Any = None + created_by_account: SimpleAccount | None = None + created_by_end_user: SimpleEndUser | None = None + created_at: int | None = None + + @field_validator("created_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +class WorkflowAppLogPaginationResponse(ResponseModel): + page: int + limit: int + total: int + has_more: bool + data: list[WorkflowAppLogPartialResponse] + + +class WorkflowArchivedLogPaginationResponse(ResponseModel): + page: int + limit: int + total: int + has_more: bool + data: list[WorkflowArchivedLogPartialResponse] diff --git a/api/fields/workflow_comment_fields.py b/api/fields/workflow_comment_fields.py new file mode 100644 index 00000000000..c708dd34601 --- /dev/null +++ b/api/fields/workflow_comment_fields.py @@ -0,0 +1,96 @@ +from flask_restx import fields + +from libs.helper import AvatarUrlField, TimestampField + +# basic account fields for comments +account_fields = { + "id": fields.String, + "name": fields.String, + "email": fields.String, + "avatar_url": AvatarUrlField, +} + +# Comment mention fields +workflow_comment_mention_fields = { + "mentioned_user_id": fields.String, + "mentioned_user_account": fields.Nested(account_fields, allow_null=True), + "reply_id": fields.String, +} + +# Comment reply fields +workflow_comment_reply_fields = { + "id": fields.String, + "content": fields.String, + "created_by": fields.String, + "created_by_account": fields.Nested(account_fields, allow_null=True), + "created_at": TimestampField, +} + +# Basic comment fields (for list views) +workflow_comment_basic_fields = { + "id": fields.String, + "position_x": fields.Float, + "position_y": fields.Float, + "content": fields.String, + "created_by": fields.String, + "created_by_account": fields.Nested(account_fields, allow_null=True), + "created_at": TimestampField, + "updated_at": TimestampField, + "resolved": fields.Boolean, + "resolved_at": TimestampField, + "resolved_by": fields.String, + "resolved_by_account": fields.Nested(account_fields, allow_null=True), + "reply_count": fields.Integer, + "mention_count": fields.Integer, + "participants": fields.List(fields.Nested(account_fields)), +} + +# Detailed comment fields (for single comment view) +workflow_comment_detail_fields = { + "id": fields.String, + "position_x": fields.Float, + "position_y": fields.Float, + "content": fields.String, + "created_by": fields.String, + "created_by_account": fields.Nested(account_fields, allow_null=True), + "created_at": TimestampField, + "updated_at": TimestampField, + "resolved": fields.Boolean, + "resolved_at": TimestampField, + "resolved_by": fields.String, + "resolved_by_account": fields.Nested(account_fields, allow_null=True), + "replies": fields.List(fields.Nested(workflow_comment_reply_fields)), + "mentions": fields.List(fields.Nested(workflow_comment_mention_fields)), +} + +# Comment creation response fields (simplified) +workflow_comment_create_fields = { + "id": fields.String, + "created_at": TimestampField, +} + +# Comment update response fields (simplified) +workflow_comment_update_fields = { + "id": fields.String, + "updated_at": TimestampField, +} + +# Comment resolve response fields +workflow_comment_resolve_fields = { + "id": fields.String, + "resolved": fields.Boolean, + "resolved_at": TimestampField, + "resolved_by": fields.String, +} + +# Reply creation response fields (simplified) +workflow_comment_reply_create_fields = { + "id": fields.String, + "created_at": TimestampField, +} + +# Reply update response fields +workflow_comment_reply_update_fields = { + "id": fields.String, + "updated_at": TimestampField, +} diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py index b0b6cc0b483..6e947858ba2 100644 --- a/api/fields/workflow_fields.py +++ b/api/fields/workflow_fields.py @@ -1,8 +1,8 @@ from flask_restx import fields -from graphon.variables import SecretVariable, SegmentType, VariableBase from core.helper import encrypter from fields.member_fields import simple_account_fields +from graphon.variables import SecretVariable, SegmentType, VariableBase from libs.helper import TimestampField from ._value_type_serializer import serialize_value_type @@ -26,7 +26,7 @@ class EnvironmentVariableField(fields.Raw): "id": value.id, "name": value.name, "value": value.value, - "value_type": value.value_type.exposed_type().value, + "value_type": str(value.value_type.exposed_type()), "description": value.description, } if isinstance(value, dict): diff --git a/api/fields/workflow_run_fields.py b/api/fields/workflow_run_fields.py index 35bb442c599..8c659086ed8 100644 --- a/api/fields/workflow_run_fields.py +++ b/api/fields/workflow_run_fields.py @@ -1,7 +1,14 @@ -from flask_restx import Namespace, fields +from __future__ import annotations -from fields.end_user_fields import simple_end_user_fields -from fields.member_fields import simple_account_fields +from datetime import datetime +from typing import Any + +from flask_restx import Namespace, fields +from pydantic import Field, field_validator + +from fields.base import ResponseModel +from fields.end_user_fields import SimpleEndUser, simple_end_user_fields +from fields.member_fields import SimpleAccount, simple_account_fields from libs.helper import TimestampField workflow_run_for_log_fields = { @@ -147,3 +154,174 @@ workflow_run_node_execution_fields = { workflow_run_node_execution_list_fields = { "data": fields.List(fields.Nested(workflow_run_node_execution_fields)), } + + +def _to_timestamp(value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return int(value.timestamp()) + return value + + +class WorkflowRunForLogResponse(ResponseModel): + id: str + version: str | None = None + status: str | None = None + triggered_from: str | None = None + error: str | None = None + elapsed_time: float | None = None + total_tokens: int | None = None + total_steps: int | None = None + created_at: int | None = None + finished_at: int | None = None + exceptions_count: int | None = None + + @field_validator("status", mode="before") + @classmethod + def _normalize_status(cls, value: Any) -> str | None: + if value is None or isinstance(value, str): + return value + return str(getattr(value, "value", value)) + + @field_validator("created_at", "finished_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +class WorkflowRunForArchivedLogResponse(ResponseModel): + id: str + status: str | None = None + triggered_from: str | None = None + elapsed_time: float | None = None + total_tokens: int | None = None + + @field_validator("status", mode="before") + @classmethod + def _normalize_status(cls, value: Any) -> str | None: + if value is None or isinstance(value, str): + return value + return str(getattr(value, "value", value)) + + +class WorkflowRunForListResponse(ResponseModel): + id: str + version: str | None = None + status: str | None = None + elapsed_time: float | None = None + total_tokens: int | None = None + total_steps: int | None = None + created_by_account: SimpleAccount | None = None + created_at: int | None = None + finished_at: int | None = None + exceptions_count: int | None = None + retry_index: int | None = None + + @field_validator("status", mode="before") + @classmethod + def _normalize_status(cls, value: Any) -> str | None: + if value is None or isinstance(value, str): + return value + return str(getattr(value, "value", value)) + + @field_validator("created_at", "finished_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +class AdvancedChatWorkflowRunForListResponse(WorkflowRunForListResponse): + conversation_id: str | None = None + message_id: str | None = None + + +class AdvancedChatWorkflowRunPaginationResponse(ResponseModel): + limit: int + has_more: bool + data: list[AdvancedChatWorkflowRunForListResponse] + + +class WorkflowRunPaginationResponse(ResponseModel): + limit: int + has_more: bool + data: list[WorkflowRunForListResponse] + + +class WorkflowRunCountResponse(ResponseModel): + total: int + running: int + succeeded: int + failed: int + stopped: int + partial_succeeded: int = Field(validation_alias="partial-succeeded") + + +class WorkflowRunDetailResponse(ResponseModel): + id: str + version: str | None = None + graph: Any = Field(validation_alias="graph_dict") + inputs: Any = Field(validation_alias="inputs_dict") + status: str | None = None + outputs: Any = Field(validation_alias="outputs_dict") + error: str | None = None + elapsed_time: float | None = None + total_tokens: int | None = None + total_steps: int | None = None + created_by_role: str | None = None + created_by_account: SimpleAccount | None = None + created_by_end_user: SimpleEndUser | None = None + created_at: int | None = None + finished_at: int | None = None + exceptions_count: int | None = None + + @field_validator("status", mode="before") + @classmethod + def _normalize_status(cls, value: Any) -> str | None: + if value is None or isinstance(value, str): + return value + return str(getattr(value, "value", value)) + + @field_validator("created_at", "finished_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +class WorkflowRunNodeExecutionResponse(ResponseModel): + id: str + index: int | None = None + predecessor_node_id: str | None = None + node_id: str | None = None + node_type: str | None = None + title: str | None = None + inputs: Any = Field(default=None, validation_alias="inputs_dict") + process_data: Any = Field(default=None, validation_alias="process_data_dict") + outputs: Any = Field(default=None, validation_alias="outputs_dict") + status: str | None = None + error: str | None = None + elapsed_time: float | None = None + execution_metadata: Any = Field(default=None, validation_alias="execution_metadata_dict") + extras: Any = None + created_at: int | None = None + created_by_role: str | None = None + created_by_account: SimpleAccount | None = None + created_by_end_user: SimpleEndUser | None = None + finished_at: int | None = None + inputs_truncated: bool | None = None + outputs_truncated: bool | None = None + process_data_truncated: bool | None = None + + @field_validator("status", mode="before") + @classmethod + def _normalize_status(cls, value: Any) -> str | None: + if value is None or isinstance(value, str): + return value + return str(getattr(value, "value", value)) + + @field_validator("created_at", "finished_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +class WorkflowRunNodeExecutionListResponse(ResponseModel): + data: list[WorkflowRunNodeExecutionResponse] diff --git a/api/libs/broadcast_channel/redis/_subscription.py b/api/libs/broadcast_channel/redis/_subscription.py index 40027bc4245..4db79a15a96 100644 --- a/api/libs/broadcast_channel/redis/_subscription.py +++ b/api/libs/broadcast_channel/redis/_subscription.py @@ -3,7 +3,7 @@ import queue import threading import types from collections.abc import Generator, Iterator -from typing import Self +from typing import Any, Self from libs.broadcast_channel.channel import Subscription from libs.broadcast_channel.exc import SubscriptionClosedError @@ -221,7 +221,7 @@ class RedisSubscriptionBase(Subscription): """Unsubscribe from the Redis topic using the appropriate command.""" raise NotImplementedError - def _get_message(self) -> dict | None: + def _get_message(self) -> dict[str, Any] | None: """Get a message from Redis using the appropriate method.""" raise NotImplementedError diff --git a/api/libs/broadcast_channel/redis/channel.py b/api/libs/broadcast_channel/redis/channel.py index bd6d58c53fc..b76a23eb3cd 100644 --- a/api/libs/broadcast_channel/redis/channel.py +++ b/api/libs/broadcast_channel/redis/channel.py @@ -1,5 +1,8 @@ from __future__ import annotations +from typing import Any + +from extensions.redis_names import serialize_redis_name from libs.broadcast_channel.channel import Producer, Subscriber, Subscription from redis import Redis, RedisCluster @@ -30,12 +33,13 @@ class Topic: def __init__(self, redis_client: Redis | RedisCluster, topic: str): self._client = redis_client self._topic = topic + self._redis_topic = serialize_redis_name(topic) def as_producer(self) -> Producer: return self def publish(self, payload: bytes) -> None: - self._client.publish(self._topic, payload) + self._client.publish(self._redis_topic, payload) def as_subscriber(self) -> Subscriber: return self @@ -44,7 +48,7 @@ class Topic: return _RedisSubscription( client=self._client, pubsub=self._client.pubsub(), - topic=self._topic, + topic=self._redis_topic, ) @@ -62,7 +66,7 @@ class _RedisSubscription(RedisSubscriptionBase): assert self._pubsub is not None self._pubsub.unsubscribe(self._topic) - def _get_message(self) -> dict | None: + def _get_message(self) -> dict[str, Any] | None: assert self._pubsub is not None return self._pubsub.get_message(ignore_subscribe_messages=True, timeout=1) diff --git a/api/libs/broadcast_channel/redis/sharded_channel.py b/api/libs/broadcast_channel/redis/sharded_channel.py index 20c43b8bbb3..919d8d622e0 100644 --- a/api/libs/broadcast_channel/redis/sharded_channel.py +++ b/api/libs/broadcast_channel/redis/sharded_channel.py @@ -1,5 +1,8 @@ from __future__ import annotations +from typing import Any + +from extensions.redis_names import serialize_redis_name from libs.broadcast_channel.channel import Producer, Subscriber, Subscription from redis import Redis, RedisCluster @@ -28,12 +31,13 @@ class ShardedTopic: def __init__(self, redis_client: Redis | RedisCluster, topic: str): self._client = redis_client self._topic = topic + self._redis_topic = serialize_redis_name(topic) def as_producer(self) -> Producer: return self def publish(self, payload: bytes) -> None: - self._client.spublish(self._topic, payload) # type: ignore[attr-defined,union-attr] + self._client.spublish(self._redis_topic, payload) # type: ignore[attr-defined,union-attr] def as_subscriber(self) -> Subscriber: return self @@ -42,7 +46,7 @@ class ShardedTopic: return _RedisShardedSubscription( client=self._client, pubsub=self._client.pubsub(), - topic=self._topic, + topic=self._redis_topic, ) @@ -60,7 +64,7 @@ class _RedisShardedSubscription(RedisSubscriptionBase): assert self._pubsub is not None self._pubsub.sunsubscribe(self._topic) # type: ignore[attr-defined] - def _get_message(self) -> dict | None: + def _get_message(self) -> dict[str, Any] | None: assert self._pubsub is not None # NOTE(QuantumGhost): this is an issue in # upstream code. If Sharded PubSub is used with Cluster, the diff --git a/api/libs/broadcast_channel/redis/streams_channel.py b/api/libs/broadcast_channel/redis/streams_channel.py index 983f785027a..55ff6cd4f93 100644 --- a/api/libs/broadcast_channel/redis/streams_channel.py +++ b/api/libs/broadcast_channel/redis/streams_channel.py @@ -6,6 +6,7 @@ import threading from collections.abc import Iterator from typing import Self +from extensions.redis_names import serialize_redis_name from libs.broadcast_channel.channel import Producer, Subscriber, Subscription from libs.broadcast_channel.exc import SubscriptionClosedError from redis import Redis, RedisCluster @@ -35,7 +36,7 @@ class StreamsTopic: def __init__(self, redis_client: Redis | RedisCluster, topic: str, *, retention_seconds: int = 600): self._client = redis_client self._topic = topic - self._key = f"stream:{topic}" + self._key = serialize_redis_name(f"stream:{topic}") self._retention_seconds = retention_seconds self.max_length = 5000 diff --git a/api/libs/db_migration_lock.py b/api/libs/db_migration_lock.py index 1d3a81e0a20..b5fe38342a4 100644 --- a/api/libs/db_migration_lock.py +++ b/api/libs/db_migration_lock.py @@ -14,9 +14,15 @@ from __future__ import annotations import logging import threading -from typing import Any +from typing import TYPE_CHECKING, Any +import redis +from redis.cluster import RedisCluster from redis.exceptions import LockNotOwnedError, RedisError +from redis.lock import Lock + +if TYPE_CHECKING: + from extensions.ext_redis import RedisClientWrapper logger = logging.getLogger(__name__) @@ -38,21 +44,21 @@ class DbMigrationAutoRenewLock: primary error/exit code. """ - _redis_client: Any + _redis_client: redis.Redis | RedisCluster | RedisClientWrapper _name: str _ttl_seconds: float _renew_interval_seconds: float _log_context: str | None _logger: logging.Logger - _lock: Any + _lock: Lock | None _stop_event: threading.Event | None _thread: threading.Thread | None _acquired: bool def __init__( self, - redis_client: Any, + redis_client: redis.Redis | RedisCluster | RedisClientWrapper, name: str, ttl_seconds: float = 60, renew_interval_seconds: float | None = None, @@ -97,7 +103,10 @@ class DbMigrationAutoRenewLock: timeout=self._ttl_seconds, thread_local=False, ) - acquired = bool(self._lock.acquire(*args, **kwargs)) + lock = self._lock + if lock is None: + raise RuntimeError("Redis lock initialization failed.") + acquired = bool(lock.acquire(*args, **kwargs)) self._acquired = acquired if acquired: self._start_heartbeat() @@ -127,7 +136,7 @@ class DbMigrationAutoRenewLock: ) self._thread.start() - def _heartbeat_loop(self, lock: Any, stop_event: threading.Event) -> None: + def _heartbeat_loop(self, lock: Lock, stop_event: threading.Event) -> None: while not stop_event.wait(self._renew_interval_seconds): try: lock.reacquire() diff --git a/api/libs/email_i18n.py b/api/libs/email_i18n.py index 0828cf80bf1..1519f07bb1b 100644 --- a/api/libs/email_i18n.py +++ b/api/libs/email_i18n.py @@ -37,6 +37,7 @@ class EmailType(StrEnum): ENTERPRISE_CUSTOM = auto() QUEUE_MONITOR_ALERT = auto() DOCUMENT_CLEAN_NOTIFY = auto() + WORKFLOW_COMMENT_MENTION = auto() EMAIL_REGISTER = auto() EMAIL_REGISTER_WHEN_ACCOUNT_EXIST = auto() RESET_PASSWORD_WHEN_ACCOUNT_NOT_EXIST_NO_REGISTER = auto() @@ -453,6 +454,18 @@ def create_default_email_config() -> EmailI18nConfig: branded_template_path="clean_document_job_mail_template_zh-CN.html", ), }, + EmailType.WORKFLOW_COMMENT_MENTION: { + EmailLanguage.EN_US: EmailTemplate( + subject="You were mentioned in a workflow comment", + template_path="workflow_comment_mention_template_en-US.html", + branded_template_path="without-brand/workflow_comment_mention_template_en-US.html", + ), + EmailLanguage.ZH_HANS: EmailTemplate( + subject="你在工作流评论中被提及", + template_path="workflow_comment_mention_template_zh-CN.html", + branded_template_path="without-brand/workflow_comment_mention_template_zh-CN.html", + ), + }, EmailType.TRIGGER_EVENTS_LIMIT_SANDBOX: { EmailLanguage.EN_US: EmailTemplate( subject="You’ve reached your Sandbox Trigger Events limit", diff --git a/api/libs/exception.py b/api/libs/exception.py index 73379dfdeda..1e4bbb44f63 100644 --- a/api/libs/exception.py +++ b/api/libs/exception.py @@ -1,9 +1,11 @@ +from typing import Any + from werkzeug.exceptions import HTTPException class BaseHTTPException(HTTPException): error_code: str = "unknown" - data: dict | None = None + data: dict[str, Any] | None = None def __init__(self, description=None, response=None): super().__init__(description, response) diff --git a/api/libs/external_api.py b/api/libs/external_api.py index e8592407c35..f907d177500 100644 --- a/api/libs/external_api.py +++ b/api/libs/external_api.py @@ -17,7 +17,6 @@ def http_status_message(code): def register_external_error_handlers(api: Api): - @api.errorhandler(HTTPException) def handle_http_exception(e: HTTPException): got_request_exception.send(current_app, exception=e) @@ -74,27 +73,18 @@ def register_external_error_handlers(api: Api): headers["Set-Cookie"] = build_force_logout_cookie_headers() return data, status_code, headers - _ = handle_http_exception - - @api.errorhandler(ValueError) def handle_value_error(e: ValueError): got_request_exception.send(current_app, exception=e) status_code = 400 data = {"code": "invalid_param", "message": str(e), "status": status_code} return data, status_code - _ = handle_value_error - - @api.errorhandler(AppInvokeQuotaExceededError) def handle_quota_exceeded(e: AppInvokeQuotaExceededError): got_request_exception.send(current_app, exception=e) status_code = 429 data = {"code": "too_many_requests", "message": str(e), "status": status_code} return data, status_code - _ = handle_quota_exceeded - - @api.errorhandler(Exception) def handle_general_exception(e: Exception): got_request_exception.send(current_app, exception=e) @@ -113,7 +103,10 @@ def register_external_error_handlers(api: Api): return data, status_code - _ = handle_general_exception + api.errorhandler(HTTPException)(handle_http_exception) + api.errorhandler(ValueError)(handle_value_error) + api.errorhandler(AppInvokeQuotaExceededError)(handle_quota_exceeded) + api.errorhandler(Exception)(handle_general_exception) class ExternalApi(Api): diff --git a/api/libs/flask_utils.py b/api/libs/flask_utils.py index 52fc787c790..838af2bf321 100644 --- a/api/libs/flask_utils.py +++ b/api/libs/flask_utils.py @@ -1,5 +1,5 @@ import contextvars -from collections.abc import Iterator +from collections.abc import Generator # Changed from Iterator from contextlib import contextmanager from typing import TYPE_CHECKING @@ -13,7 +13,7 @@ if TYPE_CHECKING: def preserve_flask_contexts( flask_app: Flask, context_vars: contextvars.Context, -) -> Iterator[None]: +) -> Generator[None, None, None]: # Changed from Iterator[None] """ A context manager that handles: 1. flask-login's UserProxy copy diff --git a/api/libs/helper.py b/api/libs/helper.py index a7b3da77ff8..ac69a110840 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -10,20 +10,21 @@ import uuid from collections.abc import Callable, Generator, Mapping from datetime import datetime from hashlib import sha256 -from typing import TYPE_CHECKING, Annotated, Any, Optional, Protocol, Union, cast +from typing import TYPE_CHECKING, Annotated, Any, Protocol, cast from uuid import UUID from zoneinfo import available_timezones from flask import Response, stream_with_context from flask_restx import fields -from graphon.file import helpers as file_helpers -from graphon.model_runtime.utils.encoders import jsonable_encoder -from pydantic import BaseModel +from pydantic import BaseModel, TypeAdapter from pydantic.functional_validators import AfterValidator +from typing_extensions import TypedDict from configs import dify_config from core.app.features.rate_limiting.rate_limit import RateLimitGenerator from extensions.ext_redis import redis_client +from graphon.file import helpers as file_helpers +from graphon.model_runtime.utils.encoders import jsonable_encoder if TYPE_CHECKING: from models import Account @@ -32,6 +33,17 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +class _TokenData(TypedDict, total=False): + account_id: str | None + email: str + token_type: str + code: str + old_email: str + + +_token_data_adapter: TypeAdapter[_TokenData] = TypeAdapter(_TokenData) + + def _stream_with_request_context(response: object) -> Any: """Bridge Flask's loosely-typed streaming helper without leaking casts into callers.""" return cast(Any, stream_with_context)(response) @@ -69,7 +81,7 @@ def escape_like_pattern(pattern: str) -> str: return pattern.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_") -def extract_tenant_id(user: Union["Account", "EndUser"]) -> str | None: +def extract_tenant_id(user: "Account | EndUser") -> str | None: """ Extract tenant_id from Account or EndUser object. @@ -108,10 +120,22 @@ class AppIconUrlField(fields.Raw): obj = obj["app"] if isinstance(obj, App | Site) and obj.icon_type == IconType.IMAGE: - return file_helpers.get_signed_file_url(obj.icon) + return build_icon_url(obj.icon_type, obj.icon) return None +def build_icon_url(icon_type: Any, icon: str | None) -> str | None: + if icon is None or icon_type is None: + return None + + from models.model import IconType + + icon_type_value = icon_type.value if isinstance(icon_type, IconType) else str(icon_type) + if icon_type_value.lower() != IconType.IMAGE: + return None + return file_helpers.get_signed_file_url(icon) + + class AvatarUrlField(fields.Raw): def output(self, key, obj, **kwargs): if obj is None: @@ -152,7 +176,10 @@ def email(email): EmailStr = Annotated[str, AfterValidator(email)] -def uuid_value(value: Any) -> str: +def uuid_value(value: str | UUID) -> str: + if isinstance(value, UUID): + return str(value) + if value == "": return str(value) @@ -393,9 +420,9 @@ class TokenManager: def generate_token( cls, token_type: str, - account: Optional["Account"] = None, + account: "Account | None" = None, email: str | None = None, - additional_data: dict | None = None, + additional_data: dict[str, Any] | None = None, ) -> str: if account is None and email is None: raise ValueError("Account or email must be provided") @@ -443,7 +470,7 @@ class TokenManager: if token_data_json is None: logger.warning("%s token %s not found with key %s", token_type, token, key) return None - token_data: dict[str, Any] | None = json.loads(token_data_json) + token_data = dict(_token_data_adapter.validate_json(token_data_json)) return token_data @classmethod @@ -453,9 +480,7 @@ class TokenManager: return current_token @classmethod - def _set_current_token_for_account( - cls, account_id: str, token: str, token_type: str, expiry_minutes: Union[int, float] - ): + def _set_current_token_for_account(cls, account_id: str, token: str, token_type: str, expiry_minutes: int | float): key = cls._get_account_token_key(account_id, token_type) expiry_seconds = int(expiry_minutes * 60) redis_client.setex(key, expiry_seconds, token) diff --git a/api/libs/login.py b/api/libs/login.py index 68a2050747d..067597cb3c2 100644 --- a/api/libs/login.py +++ b/api/libs/login.py @@ -2,19 +2,19 @@ from __future__ import annotations from collections.abc import Callable from functools import wraps -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast -from flask import current_app, g, has_request_context, request +from flask import Response, current_app, g, has_request_context, request from flask_login.config import EXEMPT_METHODS from werkzeug.local import LocalProxy from configs import dify_config +from dify_app import DifyApp +from extensions.ext_login import DifyLoginManager from libs.token import check_csrf_token from models import Account if TYPE_CHECKING: - from flask.typing import ResponseReturnValue - from models.model import EndUser @@ -29,7 +29,13 @@ def _resolve_current_user() -> EndUser | Account | None: return get_current_object() if callable(get_current_object) else user_proxy # type: ignore -def current_account_with_tenant(): +def _get_login_manager() -> DifyLoginManager: + """Return the project login manager with Dify's narrowed unauthorized contract.""" + app = cast(DifyApp, current_app) + return app.login_manager + + +def current_account_with_tenant() -> tuple[Account, str]: """ Resolve the underlying account for the current user proxy and ensure tenant context exists. Allows tests to supply plain Account mocks without the LocalProxy helper. @@ -42,7 +48,7 @@ def current_account_with_tenant(): return user, user.current_tenant_id -def login_required[**P, R](func: Callable[P, R]) -> Callable[P, R | ResponseReturnValue]: +def login_required[**P, R](func: Callable[P, R]) -> Callable[P, R | Response]: """ If you decorate a view with this, it will ensure that the current user is logged in and authenticated before calling the actual view. (If they are @@ -77,13 +83,16 @@ def login_required[**P, R](func: Callable[P, R]) -> Callable[P, R | ResponseRetu """ @wraps(func) - def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R | ResponseReturnValue: + def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R | Response: if request.method in EXEMPT_METHODS or dify_config.LOGIN_DISABLED: return current_app.ensure_sync(func)(*args, **kwargs) user = _resolve_current_user() if user is None or not user.is_authenticated: - return current_app.login_manager.unauthorized() # type: ignore + # `DifyLoginManager` guarantees that the registered unauthorized handler + # is surfaced here as a concrete Flask `Response`. + unauthorized_response: Response = _get_login_manager().unauthorized() + return unauthorized_response g._login_user = user # we put csrf validation here for less conflicts # TODO: maybe find a better place for it. @@ -96,7 +105,7 @@ def login_required[**P, R](func: Callable[P, R]) -> Callable[P, R | ResponseRetu def _get_user() -> EndUser | Account | None: if has_request_context(): if "_login_user" not in g: - current_app.login_manager._load_user() # type: ignore + _get_login_manager().load_user_from_request_context() return g._login_user diff --git a/api/libs/oauth_data_source.py b/api/libs/oauth_data_source.py index 9b53918f24e..934aacb45b3 100644 --- a/api/libs/oauth_data_source.py +++ b/api/libs/oauth_data_source.py @@ -6,8 +6,8 @@ from flask_login import current_user from pydantic import TypeAdapter from sqlalchemy import select +from core.db.session_factory import session_factory from core.helper.http_client_pooling import get_pooled_http_client -from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from models.source import DataSourceOauthBinding @@ -95,27 +95,28 @@ class NotionOAuth(OAuthDataSource): pages=pages, ) # save data source binding - data_source_binding = db.session.scalar( - select(DataSourceOauthBinding).where( - DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, - DataSourceOauthBinding.provider == "notion", - DataSourceOauthBinding.access_token == access_token, + with session_factory.create_session() as session: + data_source_binding = session.scalar( + select(DataSourceOauthBinding).where( + DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, + DataSourceOauthBinding.provider == "notion", + DataSourceOauthBinding.access_token == access_token, + ) ) - ) - if data_source_binding: - data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info) - data_source_binding.disabled = False - data_source_binding.updated_at = naive_utc_now() - db.session.commit() - else: - new_data_source_binding = DataSourceOauthBinding( - tenant_id=current_user.current_tenant_id, - access_token=access_token, - source_info=SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info), - provider="notion", - ) - db.session.add(new_data_source_binding) - db.session.commit() + if data_source_binding: + data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info) + data_source_binding.disabled = False + data_source_binding.updated_at = naive_utc_now() + session.commit() + else: + new_data_source_binding = DataSourceOauthBinding( + tenant_id=current_user.current_tenant_id, + access_token=access_token, + source_info=SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info), + provider="notion", + ) + session.add(new_data_source_binding) + session.commit() def save_internal_access_token(self, access_token: str) -> None: workspace_name = self.notion_workspace_name(access_token) @@ -130,55 +131,57 @@ class NotionOAuth(OAuthDataSource): pages=pages, ) # save data source binding - data_source_binding = db.session.scalar( - select(DataSourceOauthBinding).where( - DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, - DataSourceOauthBinding.provider == "notion", - DataSourceOauthBinding.access_token == access_token, + with session_factory.create_session() as session: + data_source_binding = session.scalar( + select(DataSourceOauthBinding).where( + DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, + DataSourceOauthBinding.provider == "notion", + DataSourceOauthBinding.access_token == access_token, + ) ) - ) - if data_source_binding: - data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info) - data_source_binding.disabled = False - data_source_binding.updated_at = naive_utc_now() - db.session.commit() - else: - new_data_source_binding = DataSourceOauthBinding( - tenant_id=current_user.current_tenant_id, - access_token=access_token, - source_info=SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info), - provider="notion", - ) - db.session.add(new_data_source_binding) - db.session.commit() + if data_source_binding: + data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info) + data_source_binding.disabled = False + data_source_binding.updated_at = naive_utc_now() + session.commit() + else: + new_data_source_binding = DataSourceOauthBinding( + tenant_id=current_user.current_tenant_id, + access_token=access_token, + source_info=SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info), + provider="notion", + ) + session.add(new_data_source_binding) + session.commit() def sync_data_source(self, binding_id: str) -> None: # save data source binding - data_source_binding = db.session.scalar( - select(DataSourceOauthBinding).where( - DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, - DataSourceOauthBinding.provider == "notion", - DataSourceOauthBinding.id == binding_id, - DataSourceOauthBinding.disabled == False, + with session_factory.create_session() as session: + data_source_binding = session.scalar( + select(DataSourceOauthBinding).where( + DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, + DataSourceOauthBinding.provider == "notion", + DataSourceOauthBinding.id == binding_id, + DataSourceOauthBinding.disabled == False, + ) ) - ) - if data_source_binding: - # get all authorized pages - pages = self.get_authorized_pages(data_source_binding.access_token) - source_info = NOTION_SOURCE_INFO_ADAPTER.validate_python(data_source_binding.source_info) - new_source_info = self._build_source_info( - workspace_name=source_info["workspace_name"], - workspace_icon=source_info["workspace_icon"], - workspace_id=source_info["workspace_id"], - pages=pages, - ) - data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(new_source_info) - data_source_binding.disabled = False - data_source_binding.updated_at = naive_utc_now() - db.session.commit() - else: - raise ValueError("Data source binding not found") + if data_source_binding: + # get all authorized pages + pages = self.get_authorized_pages(data_source_binding.access_token) + source_info = NOTION_SOURCE_INFO_ADAPTER.validate_python(data_source_binding.source_info) + new_source_info = self._build_source_info( + workspace_name=source_info["workspace_name"], + workspace_icon=source_info["workspace_icon"], + workspace_id=source_info["workspace_id"], + pages=pages, + ) + data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(new_source_info) + data_source_binding.disabled = False + data_source_binding.updated_at = naive_utc_now() + session.commit() + else: + raise ValueError("Data source binding not found") def get_authorized_pages(self, access_token: str) -> list[NotionPageSummary]: pages: list[NotionPageSummary] = [] diff --git a/api/libs/pyrefly_type_coverage.py b/api/libs/pyrefly_type_coverage.py new file mode 100644 index 00000000000..369b8dff3c1 --- /dev/null +++ b/api/libs/pyrefly_type_coverage.py @@ -0,0 +1,145 @@ +"""Helpers for generating type-coverage summaries from pyrefly report output.""" + +from __future__ import annotations + +import json +import sys +from pathlib import Path +from typing import TypedDict + + +class CoverageSummary(TypedDict): + n_modules: int + n_typable: int + n_typed: int + n_any: int + n_untyped: int + coverage: float + strict_coverage: float + + +_REQUIRED_KEYS = frozenset(CoverageSummary.__annotations__) + +_EMPTY_SUMMARY: CoverageSummary = { + "n_modules": 0, + "n_typable": 0, + "n_typed": 0, + "n_any": 0, + "n_untyped": 0, + "coverage": 0.0, + "strict_coverage": 0.0, +} + + +def parse_summary(report_json: str) -> CoverageSummary: + """Extract the summary section from ``pyrefly report`` JSON output. + + Returns an empty summary when *report_json* is empty or malformed so that + the CI workflow can degrade gracefully instead of crashing. + """ + if not report_json or not report_json.strip(): + return _EMPTY_SUMMARY.copy() + + try: + data = json.loads(report_json) + except json.JSONDecodeError: + return _EMPTY_SUMMARY.copy() + + summary = data.get("summary") + if not isinstance(summary, dict) or not _REQUIRED_KEYS.issubset(summary): + return _EMPTY_SUMMARY.copy() + + return { + "n_modules": summary["n_modules"], + "n_typable": summary["n_typable"], + "n_typed": summary["n_typed"], + "n_any": summary["n_any"], + "n_untyped": summary["n_untyped"], + "coverage": summary["coverage"], + "strict_coverage": summary["strict_coverage"], + } + + +def format_summary_markdown(summary: CoverageSummary) -> str: + """Format a single coverage summary as a Markdown table.""" + + return ( + "| Metric | Value |\n" + "| --- | ---: |\n" + f"| Modules | {summary['n_modules']} |\n" + f"| Typable symbols | {summary['n_typable']:,} |\n" + f"| Typed symbols | {summary['n_typed']:,} |\n" + f"| Untyped symbols | {summary['n_untyped']:,} |\n" + f"| Any symbols | {summary['n_any']:,} |\n" + f"| **Type coverage** | **{summary['coverage']:.2f}%** |\n" + f"| Strict coverage | {summary['strict_coverage']:.2f}% |" + ) + + +def format_comparison_markdown( + base: CoverageSummary, + pr: CoverageSummary, +) -> str: + """Format a comparison between base and PR coverage as Markdown.""" + + coverage_delta = pr["coverage"] - base["coverage"] + strict_delta = pr["strict_coverage"] - base["strict_coverage"] + typed_delta = pr["n_typed"] - base["n_typed"] + untyped_delta = pr["n_untyped"] - base["n_untyped"] + + def _fmt_delta(value: float, fmt: str = ".2f") -> str: + sign = "+" if value > 0 else "" + return f"{sign}{value:{fmt}}" + + lines = [ + "| Metric | Base | PR | Delta |", + "| --- | ---: | ---: | ---: |", + (f"| **Type coverage** | {base['coverage']:.2f}% | {pr['coverage']:.2f}% | {_fmt_delta(coverage_delta)}% |"), + ( + f"| Strict coverage | {base['strict_coverage']:.2f}% " + f"| {pr['strict_coverage']:.2f}% " + f"| {_fmt_delta(strict_delta)}% |" + ), + (f"| Typed symbols | {base['n_typed']:,} | {pr['n_typed']:,} | {_fmt_delta(typed_delta, ',')} |"), + (f"| Untyped symbols | {base['n_untyped']:,} | {pr['n_untyped']:,} | {_fmt_delta(untyped_delta, ',')} |"), + ( + f"| Modules | {base['n_modules']} " + f"| {pr['n_modules']} " + f"| {_fmt_delta(pr['n_modules'] - base['n_modules'], ',')} |" + ), + ] + return "\n".join(lines) + + +def main() -> int: + """Read pyrefly report JSON from stdin and print a Markdown summary. + + Accepts an optional ``--base `` argument. When provided, the output + includes a base-vs-PR comparison table. + """ + + args = sys.argv[1:] + + base_file: str | None = None + if "--base" in args: + idx = args.index("--base") + if idx + 1 >= len(args): + sys.stderr.write("error: --base requires a file path\n") + return 1 + base_file = args[idx + 1] + + pr_report = sys.stdin.read() + pr_summary = parse_summary(pr_report) + + if base_file is not None: + base_text = Path(base_file).read_text() if Path(base_file).exists() else "" + base_summary = parse_summary(base_text) + sys.stdout.write(format_comparison_markdown(base_summary, pr_summary) + "\n") + else: + sys.stdout.write(format_summary_markdown(pr_summary) + "\n") + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/api/libs/sendgrid.py b/api/libs/sendgrid.py index c047c54d062..0338641d114 100644 --- a/api/libs/sendgrid.py +++ b/api/libs/sendgrid.py @@ -1,4 +1,5 @@ import logging +from typing import Any import sendgrid from python_http_client.exceptions import ForbiddenError, UnauthorizedError @@ -12,7 +13,7 @@ class SendGridClient: self.sendgrid_api_key = sendgrid_api_key self._from = _from - def send(self, mail: dict): + def send(self, mail: dict[str, Any]): logger.debug("Sending email with SendGrid") _to = "" try: diff --git a/api/libs/smtp.py b/api/libs/smtp.py index 6f82f1440a2..53906d1769f 100644 --- a/api/libs/smtp.py +++ b/api/libs/smtp.py @@ -2,6 +2,7 @@ import logging import smtplib from email.mime.multipart import MIMEMultipart from email.mime.text import MIMEText +from typing import Any from configs import dify_config @@ -20,7 +21,7 @@ class SMTPClient: self.use_tls = use_tls self.opportunistic_tls = opportunistic_tls - def send(self, mail: dict): + def send(self, mail: dict[str, Any]): smtp: smtplib.SMTP | None = None local_host = dify_config.SMTP_LOCAL_HOSTNAME try: diff --git a/api/libs/token.py b/api/libs/token.py index a34db707649..5b043465ac1 100644 --- a/api/libs/token.py +++ b/api/libs/token.py @@ -47,23 +47,17 @@ def _cookie_domain() -> str | None: def _real_cookie_name(cookie_name: str) -> str: if is_secure() and _cookie_domain() is None: return "__Host-" + cookie_name - else: - return cookie_name + return cookie_name def _try_extract_from_header(request: Request) -> str | None: auth_header = request.headers.get("Authorization") - if auth_header: - if " " not in auth_header: - return None - else: - auth_scheme, auth_token = auth_header.split(None, 1) - auth_scheme = auth_scheme.lower() - if auth_scheme != "bearer": - return None - else: - return auth_token - return None + if not auth_header or " " not in auth_header: + return None + auth_scheme, auth_token = auth_header.split(None, 1) + if auth_scheme.lower() != "bearer": + return None + return auth_token def extract_refresh_token(request: Request) -> str | None: @@ -90,14 +84,9 @@ def extract_webapp_access_token(request: Request) -> str | None: def extract_webapp_passport(app_code: str, request: Request) -> str | None: - def _try_extract_passport_token_from_cookie(request: Request) -> str | None: - return request.cookies.get(_real_cookie_name(COOKIE_NAME_PASSPORT + "-" + app_code)) - - def _try_extract_passport_token_from_header(request: Request) -> str | None: - return request.headers.get(HEADER_NAME_PASSPORT) - - ret = _try_extract_passport_token_from_cookie(request) or _try_extract_passport_token_from_header(request) - return ret + return request.cookies.get(_real_cookie_name(COOKIE_NAME_PASSPORT + "-" + app_code)) or request.headers.get( + HEADER_NAME_PASSPORT + ) def set_access_token_to_cookie(request: Request, response: Response, token: str, samesite: str = "Lax"): @@ -209,22 +198,18 @@ def check_csrf_token(request: Request, user_id: str): if not csrf_token: _unauthorized() - verified = {} try: verified = PassportService().verify(csrf_token) - except: + except Exception: _unauthorized() + raise # unreachable, but helps the type checker see verified is always bound if verified.get("sub") != user_id: _unauthorized() exp: int | None = verified.get("exp") - if not exp: + if not exp or exp < int(datetime.now(UTC).timestamp()): _unauthorized() - else: - time_now = int(datetime.now().timestamp()) - if exp < time_now: - _unauthorized() def generate_csrf_token(user_id: str) -> str: diff --git a/api/libs/url_utils.py b/api/libs/url_utils.py new file mode 100644 index 00000000000..adcac3add0e --- /dev/null +++ b/api/libs/url_utils.py @@ -0,0 +1,3 @@ +def normalize_api_base_url(base_url: str) -> str: + """Normalize a base URL to always end with /v1, avoiding double /v1 suffixes.""" + return base_url.rstrip("/").removesuffix("/v1").rstrip("/") + "/v1" diff --git a/api/migrations/versions/2026_04_14_1500-8574b23a38fd_add_qdrant_endpoint_to_tidb_auth_bindings.py b/api/migrations/versions/2026_04_14_1500-8574b23a38fd_add_qdrant_endpoint_to_tidb_auth_bindings.py new file mode 100644 index 00000000000..0e188ec0803 --- /dev/null +++ b/api/migrations/versions/2026_04_14_1500-8574b23a38fd_add_qdrant_endpoint_to_tidb_auth_bindings.py @@ -0,0 +1,26 @@ +"""add qdrant_endpoint to tidb_auth_bindings + +Revision ID: 8574b23a38fd +Revises: 6b5f9f8b1a2c +Create Date: 2026-04-14 15:00:00.000000 + +""" + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "8574b23a38fd" +down_revision = "6b5f9f8b1a2c" +branch_labels = None +depends_on = None + + +def upgrade(): + with op.batch_alter_table("tidb_auth_bindings", schema=None) as batch_op: + batch_op.add_column(sa.Column("qdrant_endpoint", sa.String(length=512), nullable=True)) + + +def downgrade(): + with op.batch_alter_table("tidb_auth_bindings", schema=None) as batch_op: + batch_op.drop_column("qdrant_endpoint") diff --git a/api/migrations/versions/2026_04_15_1726-227822d22895_add_workflow_comments_table.py b/api/migrations/versions/2026_04_15_1726-227822d22895_add_workflow_comments_table.py new file mode 100644 index 00000000000..0548c932b58 --- /dev/null +++ b/api/migrations/versions/2026_04_15_1726-227822d22895_add_workflow_comments_table.py @@ -0,0 +1,90 @@ +"""Add workflow comments table + +Revision ID: 227822d22895 +Revises: 8574b23a38fd +Create Date: 2025-08-22 17:26:15.255980 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '227822d22895' +down_revision = '8574b23a38fd' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('workflow_comments', + sa.Column('id', models.types.StringUUID(), nullable=False), + sa.Column('tenant_id', models.types.StringUUID(), nullable=False), + sa.Column('app_id', models.types.StringUUID(), nullable=False), + sa.Column('position_x', sa.Float(), nullable=False), + sa.Column('position_y', sa.Float(), nullable=False), + sa.Column('content', sa.Text(), nullable=False), + sa.Column('created_by', models.types.StringUUID(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.Column('resolved', sa.Boolean(), server_default=sa.text('false'), nullable=False), + sa.Column('resolved_at', sa.DateTime(), nullable=True), + sa.Column('resolved_by', models.types.StringUUID(), nullable=True), + sa.PrimaryKeyConstraint('id', name='workflow_comments_pkey') + ) + with op.batch_alter_table('workflow_comments', schema=None) as batch_op: + batch_op.create_index('workflow_comments_app_idx', ['tenant_id', 'app_id'], unique=False) + batch_op.create_index('workflow_comments_created_at_idx', ['created_at'], unique=False) + + op.create_table('workflow_comment_replies', + sa.Column('id', models.types.StringUUID(), nullable=False), + sa.Column('comment_id', models.types.StringUUID(), nullable=False), + sa.Column('content', sa.Text(), nullable=False), + sa.Column('created_by', models.types.StringUUID(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.ForeignKeyConstraint(['comment_id'], ['workflow_comments.id'], name=op.f('workflow_comment_replies_comment_id_fkey'), ondelete='CASCADE'), + sa.PrimaryKeyConstraint('id', name='workflow_comment_replies_pkey') + ) + with op.batch_alter_table('workflow_comment_replies', schema=None) as batch_op: + batch_op.create_index('comment_replies_comment_idx', ['comment_id'], unique=False) + batch_op.create_index('comment_replies_created_at_idx', ['created_at'], unique=False) + + op.create_table('workflow_comment_mentions', + sa.Column('id', models.types.StringUUID(), nullable=False), + sa.Column('comment_id', models.types.StringUUID(), nullable=False), + sa.Column('reply_id', models.types.StringUUID(), nullable=True), + sa.Column('mentioned_user_id', models.types.StringUUID(), nullable=False), + sa.ForeignKeyConstraint(['comment_id'], ['workflow_comments.id'], name=op.f('workflow_comment_mentions_comment_id_fkey'), ondelete='CASCADE'), + sa.ForeignKeyConstraint(['reply_id'], ['workflow_comment_replies.id'], name=op.f('workflow_comment_mentions_reply_id_fkey'), ondelete='CASCADE'), + sa.PrimaryKeyConstraint('id', name='workflow_comment_mentions_pkey') + ) + with op.batch_alter_table('workflow_comment_mentions', schema=None) as batch_op: + batch_op.create_index('comment_mentions_comment_idx', ['comment_id'], unique=False) + batch_op.create_index('comment_mentions_reply_idx', ['reply_id'], unique=False) + batch_op.create_index('comment_mentions_user_idx', ['mentioned_user_id'], unique=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('workflow_comment_mentions', schema=None) as batch_op: + batch_op.drop_index('comment_mentions_user_idx') + batch_op.drop_index('comment_mentions_reply_idx') + batch_op.drop_index('comment_mentions_comment_idx') + + op.drop_table('workflow_comment_mentions') + with op.batch_alter_table('workflow_comment_replies', schema=None) as batch_op: + batch_op.drop_index('comment_replies_created_at_idx') + batch_op.drop_index('comment_replies_comment_idx') + + op.drop_table('workflow_comment_replies') + with op.batch_alter_table('workflow_comments', schema=None) as batch_op: + batch_op.drop_index('workflow_comments_created_at_idx') + batch_op.drop_index('workflow_comments_app_idx') + + op.drop_table('workflow_comments') + # ### end Alembic commands ### diff --git a/api/models/__init__.py b/api/models/__init__.py index fcae07f948d..85be9ca3bd2 100644 --- a/api/models/__init__.py +++ b/api/models/__init__.py @@ -9,6 +9,11 @@ from .account import ( TenantStatus, ) from .api_based_extension import APIBasedExtension, APIBasedExtensionPoint +from .comment import ( + WorkflowComment, + WorkflowCommentMention, + WorkflowCommentReply, +) from .dataset import ( AppDatasetJoin, Dataset, @@ -208,6 +213,9 @@ __all__ = [ "WorkflowAppLog", "WorkflowAppLogCreatedFrom", "WorkflowArchiveLog", + "WorkflowComment", + "WorkflowCommentMention", + "WorkflowCommentReply", "WorkflowNodeExecutionModel", "WorkflowNodeExecutionOffload", "WorkflowNodeExecutionTriggeredFrom", diff --git a/api/models/account.py b/api/models/account.py index 5960ac65643..a3074c6f637 100644 --- a/api/models/account.py +++ b/api/models/account.py @@ -2,7 +2,7 @@ import enum import json from dataclasses import field from datetime import datetime -from typing import Any, Optional +from typing import Optional, TypedDict from uuid import uuid4 import sqlalchemy as sa @@ -232,6 +232,11 @@ class TenantStatus(enum.StrEnum): ARCHIVE = "archive" +class TenantCustomConfigDict(TypedDict, total=False): + remove_webapp_brand: bool + replace_webapp_logo: str | None + + class Tenant(TypeBase): __tablename__ = "tenants" __table_args__ = (sa.PrimaryKeyConstraint("id", name="tenant_pkey"),) @@ -263,11 +268,11 @@ class Tenant(TypeBase): ) @property - def custom_config_dict(self) -> dict[str, Any]: + def custom_config_dict(self) -> TenantCustomConfigDict: return json.loads(self.custom_config) if self.custom_config else {} @custom_config_dict.setter - def custom_config_dict(self, value: dict[str, Any]) -> None: + def custom_config_dict(self, value: TenantCustomConfigDict) -> None: self.custom_config = json.dumps(value) diff --git a/api/models/base.py b/api/models/base.py index b7023b9c8bd..5acdf184f45 100644 --- a/api/models/base.py +++ b/api/models/base.py @@ -24,6 +24,8 @@ class TypeBase(MappedAsDataclass, DeclarativeBase): class DefaultFieldsMixin: + """Mixin for models that inherit from Base (non-dataclass).""" + id: Mapped[str] = mapped_column( StringUUID, primary_key=True, @@ -53,6 +55,42 @@ class DefaultFieldsMixin: return f"<{self.__class__.__name__}(id={self.id})>" +class DefaultFieldsDCMixin(MappedAsDataclass): + """Mixin for models that inherit from TypeBase (MappedAsDataclass).""" + + __abstract__ = True + + id: Mapped[str] = mapped_column( + StringUUID, + primary_key=True, + insert_default=lambda: str(uuidv7()), + default_factory=lambda: str(uuidv7()), + init=False, + ) + + created_at: Mapped[datetime] = mapped_column( + DateTime, + nullable=False, + insert_default=naive_utc_now, + default_factory=naive_utc_now, + init=False, + server_default=func.current_timestamp(), + ) + + updated_at: Mapped[datetime] = mapped_column( + DateTime, + nullable=False, + insert_default=naive_utc_now, + default_factory=naive_utc_now, + init=False, + server_default=func.current_timestamp(), + onupdate=func.current_timestamp(), + ) + + def __repr__(self) -> str: + return f"<{self.__class__.__name__}(id={self.id})>" + + def gen_uuidv4_string() -> str: """gen_uuidv4_string generate a UUIDv4 string. diff --git a/api/models/comment.py b/api/models/comment.py new file mode 100644 index 00000000000..1154e16788b --- /dev/null +++ b/api/models/comment.py @@ -0,0 +1,219 @@ +"""Workflow comment models.""" + +from datetime import datetime +from typing import Optional + +import sqlalchemy as sa +from sqlalchemy import Index, func +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from .account import Account +from .base import Base +from .engine import db +from .types import StringUUID + + +class WorkflowComment(Base): + """Workflow comment model for canvas commenting functionality. + + Comments are associated with apps rather than specific workflow versions, + since an app has only one draft workflow at a time and comments should persist + across workflow version changes. + + Attributes: + id: Comment ID + tenant_id: Workspace ID + app_id: App ID (primary association, comments belong to apps) + position_x: X coordinate on canvas + position_y: Y coordinate on canvas + content: Comment content + created_by: Creator account ID + created_at: Creation time + updated_at: Last update time + resolved: Whether comment is resolved + resolved_at: Resolution time + resolved_by: Resolver account ID + """ + + __tablename__ = "workflow_comments" + __table_args__ = ( + sa.PrimaryKeyConstraint("id", name="workflow_comments_pkey"), + Index("workflow_comments_app_idx", "tenant_id", "app_id"), + Index("workflow_comments_created_at_idx", "created_at"), + ) + + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuidv7()")) + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + position_x: Mapped[float] = mapped_column(sa.Float) + position_y: Mapped[float] = mapped_column(sa.Float) + content: Mapped[str] = mapped_column(sa.Text, nullable=False) + created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) + created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column( + sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() + ) + resolved: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) + resolved_at: Mapped[datetime | None] = mapped_column(sa.DateTime) + resolved_by: Mapped[str | None] = mapped_column(StringUUID) + + # Relationships + replies: Mapped[list["WorkflowCommentReply"]] = relationship( + "WorkflowCommentReply", back_populates="comment", cascade="all, delete-orphan" + ) + mentions: Mapped[list["WorkflowCommentMention"]] = relationship( + "WorkflowCommentMention", back_populates="comment", cascade="all, delete-orphan" + ) + + @property + def created_by_account(self): + """Get creator account.""" + if hasattr(self, "_created_by_account_cache"): + return self._created_by_account_cache + return db.session.get(Account, self.created_by) + + def cache_created_by_account(self, account: Account | None) -> None: + """Cache creator account to avoid extra queries.""" + self._created_by_account_cache = account + + @property + def resolved_by_account(self): + """Get resolver account.""" + if hasattr(self, "_resolved_by_account_cache"): + return self._resolved_by_account_cache + if self.resolved_by: + return db.session.get(Account, self.resolved_by) + return None + + def cache_resolved_by_account(self, account: Account | None) -> None: + """Cache resolver account to avoid extra queries.""" + self._resolved_by_account_cache = account + + @property + def reply_count(self): + """Get reply count.""" + return len(self.replies) + + @property + def mention_count(self): + """Get mention count.""" + return len(self.mentions) + + @property + def participants(self): + """Get all participants (creator + repliers + mentioned users).""" + participant_ids: set[str] = set() + participants: list[Account] = [] + + # Use account properties to reuse preloaded caches and avoid hidden N+1. + if self.created_by not in participant_ids: + participant_ids.add(self.created_by) + created_by_account = self.created_by_account + if created_by_account: + participants.append(created_by_account) + + for reply in self.replies: + if reply.created_by in participant_ids: + continue + participant_ids.add(reply.created_by) + reply_account = reply.created_by_account + if reply_account: + participants.append(reply_account) + + for mention in self.mentions: + if mention.mentioned_user_id in participant_ids: + continue + participant_ids.add(mention.mentioned_user_id) + mentioned_account = mention.mentioned_user_account + if mentioned_account: + participants.append(mentioned_account) + + return participants + + +class WorkflowCommentReply(Base): + """Workflow comment reply model. + + Attributes: + id: Reply ID + comment_id: Parent comment ID + content: Reply content + created_by: Creator account ID + created_at: Creation time + """ + + __tablename__ = "workflow_comment_replies" + __table_args__ = ( + sa.PrimaryKeyConstraint("id", name="workflow_comment_replies_pkey"), + Index("comment_replies_comment_idx", "comment_id"), + Index("comment_replies_created_at_idx", "created_at"), + ) + + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuidv7()")) + comment_id: Mapped[str] = mapped_column( + StringUUID, sa.ForeignKey("workflow_comments.id", ondelete="CASCADE"), nullable=False + ) + content: Mapped[str] = mapped_column(sa.Text, nullable=False) + created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) + created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column( + sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() + ) + # Relationships + comment: Mapped["WorkflowComment"] = relationship("WorkflowComment", back_populates="replies") + + @property + def created_by_account(self): + """Get creator account.""" + if hasattr(self, "_created_by_account_cache"): + return self._created_by_account_cache + return db.session.get(Account, self.created_by) + + def cache_created_by_account(self, account: Account | None) -> None: + """Cache creator account to avoid extra queries.""" + self._created_by_account_cache = account + + +class WorkflowCommentMention(Base): + """Workflow comment mention model. + + Mentions are only for internal accounts since end users + cannot access workflow canvas and commenting features. + + Attributes: + id: Mention ID + comment_id: Parent comment ID + mentioned_user_id: Mentioned account ID + """ + + __tablename__ = "workflow_comment_mentions" + __table_args__ = ( + sa.PrimaryKeyConstraint("id", name="workflow_comment_mentions_pkey"), + Index("comment_mentions_comment_idx", "comment_id"), + Index("comment_mentions_reply_idx", "reply_id"), + Index("comment_mentions_user_idx", "mentioned_user_id"), + ) + + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuidv7()")) + comment_id: Mapped[str] = mapped_column( + StringUUID, sa.ForeignKey("workflow_comments.id", ondelete="CASCADE"), nullable=False + ) + reply_id: Mapped[str | None] = mapped_column( + StringUUID, sa.ForeignKey("workflow_comment_replies.id", ondelete="CASCADE"), nullable=True + ) + mentioned_user_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + + # Relationships + comment: Mapped["WorkflowComment"] = relationship("WorkflowComment", back_populates="mentions") + reply: Mapped[Optional["WorkflowCommentReply"]] = relationship("WorkflowCommentReply") + + @property + def mentioned_user_account(self): + """Get mentioned account.""" + if hasattr(self, "_mentioned_user_account_cache"): + return self._mentioned_user_account_cache + return db.session.get(Account, self.mentioned_user_id) + + def cache_mentioned_user_account(self, account: Account | None) -> None: + """Cache mentioned account to avoid extra queries.""" + self._mentioned_user_account_cache = account diff --git a/api/models/dataset.py b/api/models/dataset.py index e323ccfd7f6..eee5c39a0e8 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -19,6 +19,7 @@ from sqlalchemy import DateTime, String, func, select from sqlalchemy.orm import Mapped, Session, mapped_column from configs import dify_config +from core.rag.entities import ParentMode, Rule from core.rag.index_processor.constant.built_in_field import BuiltInField, MetadataDataSource from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.index_processor.constant.query_type import QueryType @@ -26,7 +27,6 @@ from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.tools.signature import sign_upload_file from extensions.ext_storage import storage from libs.uuid_utils import uuidv7 -from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule from .account import Account from .base import Base, TypeBase @@ -108,6 +108,56 @@ class ExternalKnowledgeApiDict(TypedDict): created_at: str +class DocumentDict(TypedDict): + id: str + tenant_id: str + dataset_id: str + position: int + data_source_type: str + data_source_info: str | None + dataset_process_rule_id: str | None + batch: str + name: str + created_from: str + created_by: str + created_api_request_id: str | None + created_at: datetime + processing_started_at: datetime | None + file_id: str | None + word_count: int | None + parsing_completed_at: datetime | None + cleaning_completed_at: datetime | None + splitting_completed_at: datetime | None + tokens: int | None + indexing_latency: float | None + completed_at: datetime | None + is_paused: bool | None + paused_by: str | None + paused_at: datetime | None + error: str | None + stopped_at: datetime | None + indexing_status: str + enabled: bool + disabled_at: datetime | None + disabled_by: str | None + archived: bool + archived_reason: str | None + archived_by: str | None + archived_at: datetime | None + updated_at: datetime + doc_type: str | None + doc_metadata: Any + doc_form: IndexStructureType + doc_language: str | None + display_status: str | None + data_source_info_dict: dict[str, Any] + average_segment_length: int + dataset_process_rule: ProcessRuleDict | None + dataset: None + segment_count: int | None + hit_count: int | None + + class DatasetPermissionEnum(enum.StrEnum): ONLY_ME = "only_me" ALL_TEAM = "all_team_members" @@ -303,13 +353,17 @@ class Dataset(Base): if self.provider != "external": return None external_knowledge_binding = db.session.scalar( - select(ExternalKnowledgeBindings).where(ExternalKnowledgeBindings.dataset_id == self.id) + select(ExternalKnowledgeBindings).where( + ExternalKnowledgeBindings.dataset_id == self.id, + ExternalKnowledgeBindings.tenant_id == self.tenant_id, + ) ) if not external_knowledge_binding: return None external_knowledge_api = db.session.scalar( select(ExternalKnowledgeApis).where( - ExternalKnowledgeApis.id == external_knowledge_binding.external_knowledge_api_id + ExternalKnowledgeApis.id == external_knowledge_binding.external_knowledge_api_id, + ExternalKnowledgeApis.tenant_id == self.tenant_id, ) ) if external_knowledge_api is None or external_knowledge_api.settings is None: @@ -675,8 +729,8 @@ class Document(Base): ) return built_in_fields - def to_dict(self) -> dict[str, Any]: - return { + def to_dict(self) -> DocumentDict: + result: DocumentDict = { "id": self.id, "tenant_id": self.tenant_id, "dataset_id": self.dataset_id, @@ -721,10 +775,11 @@ class Document(Base): "data_source_info_dict": self.data_source_info_dict, "average_segment_length": self.average_segment_length, "dataset_process_rule": self.dataset_process_rule.to_dict() if self.dataset_process_rule else None, - "dataset": None, # Dataset class doesn't have a to_dict method + "dataset": None, "segment_count": self.segment_count, "hit_count": self.hit_count, } + return result @classmethod def from_dict(cls, data: dict[str, Any]): @@ -1250,6 +1305,7 @@ class TidbAuthBinding(TypeBase): ) account: Mapped[str] = mapped_column(String(255), nullable=False) password: Mapped[str] = mapped_column(String(255), nullable=False) + qdrant_endpoint: Mapped[str | None] = mapped_column(String(512), nullable=True, default=None) created_at: Mapped[datetime] = mapped_column( DateTime, nullable=False, server_default=func.current_timestamp(), init=False ) @@ -1496,7 +1552,7 @@ class PipelineBuiltInTemplate(TypeBase): name: Mapped[str] = mapped_column(sa.String(255), nullable=False) description: Mapped[str] = mapped_column(LongText, nullable=False) chunk_structure: Mapped[str] = mapped_column(sa.String(255), nullable=False) - icon: Mapped[dict] = mapped_column(sa.JSON, nullable=False) + icon: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False) yaml_content: Mapped[str] = mapped_column(LongText, nullable=False) copyright: Mapped[str] = mapped_column(sa.String(255), nullable=False) privacy_policy: Mapped[str] = mapped_column(sa.String(255), nullable=False) @@ -1529,7 +1585,7 @@ class PipelineCustomizedTemplate(TypeBase): name: Mapped[str] = mapped_column(sa.String(255), nullable=False) description: Mapped[str] = mapped_column(LongText, nullable=False) chunk_structure: Mapped[str] = mapped_column(sa.String(255), nullable=False) - icon: Mapped[dict] = mapped_column(sa.JSON, nullable=False) + icon: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False) position: Mapped[int] = mapped_column(sa.Integer, nullable=False) yaml_content: Mapped[str] = mapped_column(LongText, nullable=False) install_count: Mapped[int] = mapped_column(sa.Integer, nullable=False) @@ -1602,7 +1658,7 @@ class DocumentPipelineExecutionLog(TypeBase): datasource_type: Mapped[str] = mapped_column(sa.String(255), nullable=False) datasource_info: Mapped[str] = mapped_column(LongText, nullable=False) datasource_node_id: Mapped[str] = mapped_column(sa.String(255), nullable=False) - input_data: Mapped[dict] = mapped_column(sa.JSON, nullable=False) + input_data: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False) created_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True) created_at: Mapped[datetime] = mapped_column( sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False @@ -1633,7 +1689,7 @@ class PipelineRecommendedPlugin(TypeBase): ) -class SegmentAttachmentBinding(Base): +class SegmentAttachmentBinding(TypeBase): __tablename__ = "segment_attachment_bindings" __table_args__ = ( sa.PrimaryKeyConstraint("id", name="segment_attachment_binding_pkey"), @@ -1646,16 +1702,20 @@ class SegmentAttachmentBinding(Base): ), sa.Index("segment_attachment_binding_attachment_idx", "attachment_id"), ) - id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7())) + id: Mapped[str] = mapped_column( + StringUUID, insert_default=lambda: str(uuidv7()), default_factory=lambda: str(uuidv7()), init=False + ) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False) document_id: Mapped[str] = mapped_column(StringUUID, nullable=False) segment_id: Mapped[str] = mapped_column(StringUUID, nullable=False) attachment_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column( + sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False + ) -class DocumentSegmentSummary(Base): +class DocumentSegmentSummary(TypeBase): __tablename__ = "document_segment_summaries" __table_args__ = ( sa.PrimaryKeyConstraint("id", name="document_segment_summaries_pkey"), @@ -1665,25 +1725,40 @@ class DocumentSegmentSummary(Base): sa.Index("document_segment_summaries_status_idx", "status"), ) - id: Mapped[str] = mapped_column(StringUUID, nullable=False, default=lambda: str(uuid4())) + id: Mapped[str] = mapped_column( + StringUUID, + nullable=False, + insert_default=lambda: str(uuid4()), + default_factory=lambda: str(uuid4()), + init=False, + ) dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False) document_id: Mapped[str] = mapped_column(StringUUID, nullable=False) # corresponds to DocumentSegment.id or parent chunk id chunk_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - summary_content: Mapped[str] = mapped_column(LongText, nullable=True) - summary_index_node_id: Mapped[str] = mapped_column(String(255), nullable=True) - summary_index_node_hash: Mapped[str] = mapped_column(String(255), nullable=True) - tokens: Mapped[int | None] = mapped_column(sa.Integer, nullable=True) - status: Mapped[str] = mapped_column( - EnumText(SummaryStatus, length=32), nullable=False, server_default=sa.text("'generating'") + summary_content: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None) + summary_index_node_id: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None) + summary_index_node_hash: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None) + tokens: Mapped[int | None] = mapped_column(sa.Integer, nullable=True, default=None) + status: Mapped[SummaryStatus] = mapped_column( + EnumText(SummaryStatus, length=32), + nullable=False, + server_default=sa.text("'generating'"), + default=SummaryStatus.GENERATING, + ) + error: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None) + enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"), default=True) + disabled_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, default=None) + disabled_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None) + created_at: Mapped[datetime] = mapped_column( + DateTime, nullable=False, server_default=func.current_timestamp(), init=False ) - error: Mapped[str] = mapped_column(LongText, nullable=True) - enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true")) - disabled_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) - disabled_by = mapped_column(StringUUID, nullable=True) - created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) updated_at: Mapped[datetime] = mapped_column( - DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() + DateTime, + nullable=False, + server_default=func.current_timestamp(), + onupdate=func.current_timestamp(), + init=False, ) def __repr__(self): diff --git a/api/models/human_input.py b/api/models/human_input.py index 79c5d62f6a8..7447d3efcb3 100644 --- a/api/models/human_input.py +++ b/api/models/human_input.py @@ -3,11 +3,11 @@ from enum import StrEnum from typing import Annotated, Literal, Self, final import sqlalchemy as sa -from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from pydantic import BaseModel, Field from sqlalchemy.orm import Mapped, mapped_column, relationship -from core.workflow.human_input_compat import DeliveryMethodType +from core.workflow.human_input_adapter import DeliveryMethodType +from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from libs.helper import generate_string from .base import Base, DefaultFieldsMixin diff --git a/api/models/model.py b/api/models/model.py index 1d73aadf092..a1117fc43ab 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -14,17 +14,18 @@ from uuid import uuid4 import sqlalchemy as sa from flask import request from flask_login import UserMixin # type: ignore[import-untyped] -from graphon.enums import WorkflowExecutionStatus -from graphon.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType -from graphon.file import helpers as file_helpers from sqlalchemy import BigInteger, Float, Index, PrimaryKeyConstraint, String, exists, func, select, text -from sqlalchemy.orm import Mapped, Session, mapped_column +from sqlalchemy.orm import Mapped, Session, mapped_column, sessionmaker from configs import dify_config from constants import DEFAULT_FILE_NUMBER_LIMITS from core.tools.signature import sign_tool_file from extensions.storage.storage_type import StorageType +from graphon.enums import WorkflowExecutionStatus +from graphon.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType +from graphon.file import helpers as file_helpers from libs.helper import generate_string # type: ignore[import-not-found] +from libs.url_utils import normalize_api_base_url from libs.uuid_utils import uuidv7 from models.utils.file_input_compat import build_file_from_input_mapping @@ -446,7 +447,8 @@ class App(Base): @property def api_base_url(self) -> str: - return (dify_config.SERVICE_API_URL or request.host_url.rstrip("/")) + "/v1" + base = dify_config.SERVICE_API_URL or request.host_url.rstrip("/") + return normalize_api_base_url(base) @property def tenant(self) -> Tenant | None: @@ -524,7 +526,7 @@ class App(Base): if not api_provider_ids and not builtin_provider_ids: return [] - with Session(db.engine) as session: + with sessionmaker(db.engine).begin() as session: if api_provider_ids: existing_api_providers = [ str(api_provider.id) @@ -674,28 +676,24 @@ class AppModelConfig(TypeBase): def suggested_questions_list(self) -> list[str]: return json.loads(self.suggested_questions) if self.suggested_questions else [] + def _get_enabled_config(self, value: str | None, *, default_enabled: bool = False) -> EnabledConfig: + return cast(EnabledConfig, json.loads(value) if value else {"enabled": default_enabled}) + @property def suggested_questions_after_answer_dict(self) -> EnabledConfig: - return cast( - EnabledConfig, - json.loads(self.suggested_questions_after_answer) - if self.suggested_questions_after_answer - else {"enabled": False}, - ) + return self._get_enabled_config(self.suggested_questions_after_answer) @property def speech_to_text_dict(self) -> EnabledConfig: - return cast(EnabledConfig, json.loads(self.speech_to_text) if self.speech_to_text else {"enabled": False}) + return self._get_enabled_config(self.speech_to_text) @property def text_to_speech_dict(self) -> EnabledConfig: - return cast(EnabledConfig, json.loads(self.text_to_speech) if self.text_to_speech else {"enabled": False}) + return self._get_enabled_config(self.text_to_speech) @property def retriever_resource_dict(self) -> EnabledConfig: - return cast( - EnabledConfig, json.loads(self.retriever_resource) if self.retriever_resource else {"enabled": True} - ) + return self._get_enabled_config(self.retriever_resource, default_enabled=True) @property def annotation_reply_dict(self) -> AnnotationReplyConfig: @@ -722,7 +720,7 @@ class AppModelConfig(TypeBase): @property def more_like_this_dict(self) -> EnabledConfig: - return cast(EnabledConfig, json.loads(self.more_like_this) if self.more_like_this else {"enabled": False}) + return self._get_enabled_config(self.more_like_this) @property def sensitive_word_avoidance_dict(self) -> SensitiveWordAvoidanceConfig: @@ -813,60 +811,36 @@ class AppModelConfig(TypeBase): "file_upload": self.file_upload_dict, } + @staticmethod + def _dump_optional(value: Any) -> str | None: + return json.dumps(value) if value else None + def from_model_config_dict(self, model_config: AppModelConfigDict): self.opening_statement = model_config.get("opening_statement") - self.suggested_questions = ( - json.dumps(model_config.get("suggested_questions")) if model_config.get("suggested_questions") else None - ) - self.suggested_questions_after_answer = ( - json.dumps(model_config.get("suggested_questions_after_answer")) - if model_config.get("suggested_questions_after_answer") - else None - ) - self.speech_to_text = ( - json.dumps(model_config.get("speech_to_text")) if model_config.get("speech_to_text") else None - ) - self.text_to_speech = ( - json.dumps(model_config.get("text_to_speech")) if model_config.get("text_to_speech") else None - ) - self.more_like_this = ( - json.dumps(model_config.get("more_like_this")) if model_config.get("more_like_this") else None - ) - self.sensitive_word_avoidance = ( - json.dumps(model_config.get("sensitive_word_avoidance")) - if model_config.get("sensitive_word_avoidance") - else None - ) - self.external_data_tools = ( - json.dumps(model_config.get("external_data_tools")) if model_config.get("external_data_tools") else None - ) - self.model = json.dumps(model_config.get("model")) if model_config.get("model") else None - self.user_input_form = ( - json.dumps(model_config.get("user_input_form")) if model_config.get("user_input_form") else None + self.suggested_questions = self._dump_optional(model_config.get("suggested_questions")) + self.suggested_questions_after_answer = self._dump_optional( + model_config.get("suggested_questions_after_answer") ) + self.speech_to_text = self._dump_optional(model_config.get("speech_to_text")) + self.text_to_speech = self._dump_optional(model_config.get("text_to_speech")) + self.more_like_this = self._dump_optional(model_config.get("more_like_this")) + self.sensitive_word_avoidance = self._dump_optional(model_config.get("sensitive_word_avoidance")) + self.external_data_tools = self._dump_optional(model_config.get("external_data_tools")) + self.model = self._dump_optional(model_config.get("model")) + self.user_input_form = self._dump_optional(model_config.get("user_input_form")) self.dataset_query_variable = model_config.get("dataset_query_variable") self.pre_prompt = model_config.get("pre_prompt") - self.agent_mode = json.dumps(model_config.get("agent_mode")) if model_config.get("agent_mode") else None - self.retriever_resource = ( - json.dumps(model_config.get("retriever_resource")) if model_config.get("retriever_resource") else None - ) + self.agent_mode = self._dump_optional(model_config.get("agent_mode")) + self.retriever_resource = self._dump_optional(model_config.get("retriever_resource")) self.prompt_type = PromptType(model_config.get("prompt_type", "simple")) - self.chat_prompt_config = ( - json.dumps(model_config.get("chat_prompt_config")) if model_config.get("chat_prompt_config") else None - ) - self.completion_prompt_config = ( - json.dumps(model_config.get("completion_prompt_config")) - if model_config.get("completion_prompt_config") - else None - ) - self.dataset_configs = ( - json.dumps(model_config.get("dataset_configs")) if model_config.get("dataset_configs") else None - ) - self.file_upload = json.dumps(model_config.get("file_upload")) if model_config.get("file_upload") else None + self.chat_prompt_config = self._dump_optional(model_config.get("chat_prompt_config")) + self.completion_prompt_config = self._dump_optional(model_config.get("completion_prompt_config")) + self.dataset_configs = self._dump_optional(model_config.get("dataset_configs")) + self.file_upload = self._dump_optional(model_config.get("file_upload")) return self -class RecommendedApp(Base): # bug +class RecommendedApp(TypeBase): __tablename__ = "recommended_apps" __table_args__ = ( sa.PrimaryKeyConstraint("id", name="recommended_app_pkey"), @@ -874,20 +848,37 @@ class RecommendedApp(Base): # bug sa.Index("recommended_app_is_listed_idx", "is_listed", "language"), ) - id = mapped_column(StringUUID, primary_key=True, default=lambda: str(uuid4())) - app_id = mapped_column(StringUUID, nullable=False) - description = mapped_column(sa.JSON, nullable=False) + id: Mapped[str] = mapped_column( + StringUUID, + primary_key=True, + insert_default=lambda: str(uuid4()), + default_factory=lambda: str(uuid4()), + init=False, + ) + app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + description: Mapped[Any] = mapped_column(sa.JSON, nullable=False) copyright: Mapped[str] = mapped_column(String(255), nullable=False) privacy_policy: Mapped[str] = mapped_column(String(255), nullable=False) - custom_disclaimer: Mapped[str] = mapped_column(LongText, default="") category: Mapped[str] = mapped_column(String(255), nullable=False) + custom_disclaimer: Mapped[str] = mapped_column(LongText, default="") position: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0) is_listed: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, default=True) install_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0) - language = mapped_column(String(255), nullable=False, server_default=sa.text("'en-US'")) - created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = mapped_column( - sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() + language: Mapped[str] = mapped_column( + String(255), + nullable=False, + server_default=sa.text("'en-US'"), + default="en-US", + ) + created_at: Mapped[datetime] = mapped_column( + sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False + ) + updated_at: Mapped[datetime] = mapped_column( + sa.DateTime, + nullable=False, + server_default=func.current_timestamp(), + onupdate=func.current_timestamp(), + init=False, ) @property @@ -926,7 +917,7 @@ class InstalledApp(TypeBase): return db.session.scalar(select(Tenant).where(Tenant.id == self.tenant_id)) -class TrialApp(Base): +class TrialApp(TypeBase): __tablename__ = "trial_apps" __table_args__ = ( sa.PrimaryKeyConstraint("id", name="trial_app_pkey"), @@ -935,18 +926,22 @@ class TrialApp(Base): sa.UniqueConstraint("app_id", name="unique_trail_app_id"), ) - id = mapped_column(StringUUID, default=gen_uuidv4_string) - app_id = mapped_column(StringUUID, nullable=False) - tenant_id = mapped_column(StringUUID, nullable=False) - created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) - trial_limit = mapped_column(sa.Integer, nullable=False, default=3) + id: Mapped[str] = mapped_column( + StringUUID, insert_default=gen_uuidv4_string, default_factory=gen_uuidv4_string, init=False + ) + app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + created_at: Mapped[datetime] = mapped_column( + sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False + ) + trial_limit: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=3) @property def app(self) -> App | None: return db.session.scalar(select(App).where(App.id == self.app_id)) -class AccountTrialAppRecord(Base): +class AccountTrialAppRecord(TypeBase): __tablename__ = "account_trial_app_records" __table_args__ = ( sa.PrimaryKeyConstraint("id", name="user_trial_app_pkey"), @@ -954,11 +949,15 @@ class AccountTrialAppRecord(Base): sa.Index("account_trial_app_record_app_id_idx", "app_id"), sa.UniqueConstraint("account_id", "app_id", name="unique_account_trial_app_record"), ) - id = mapped_column(StringUUID, default=gen_uuidv4_string) - account_id = mapped_column(StringUUID, nullable=False) - app_id = mapped_column(StringUUID, nullable=False) - count = mapped_column(sa.Integer, nullable=False, default=0) - created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + id: Mapped[str] = mapped_column( + StringUUID, insert_default=gen_uuidv4_string, default_factory=gen_uuidv4_string, init=False + ) + account_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + count: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0) + created_at: Mapped[datetime] = mapped_column( + sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False + ) @property def app(self) -> App | None: @@ -1010,7 +1009,7 @@ class OAuthProviderApp(TypeBase): app_icon: Mapped[str] = mapped_column(String(255), nullable=False) client_id: Mapped[str] = mapped_column(String(255), nullable=False) client_secret: Mapped[str] = mapped_column(String(255), nullable=False) - app_label: Mapped[dict] = mapped_column(sa.JSON, nullable=False, default_factory=dict) + app_label: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False, default_factory=dict) redirect_uris: Mapped[list] = mapped_column(sa.JSON, nullable=False, default_factory=list) scope: Mapped[str] = mapped_column( String(255), @@ -1081,7 +1080,7 @@ class Conversation(Base): messages = db.relationship("Message", backref="conversation", lazy="select", passive_deletes="all") message_annotations = db.relationship( - "MessageAnnotation", backref="conversation", lazy="select", passive_deletes="all" + lambda: MessageAnnotation, backref="conversation", lazy="select", passive_deletes="all" ) is_deleted: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) @@ -1632,52 +1631,53 @@ class Message(Base): files: list[File] = [] for message_file in message_files: - if message_file.transfer_method == FileTransferMethod.LOCAL_FILE: - if message_file.upload_file_id is None: - raise ValueError(f"MessageFile {message_file.id} is a local file but has no upload_file_id") - file = file_factory.build_from_mapping( - mapping={ + match message_file.transfer_method: + case FileTransferMethod.LOCAL_FILE: + if message_file.upload_file_id is None: + raise ValueError(f"MessageFile {message_file.id} is a local file but has no upload_file_id") + file = file_factory.build_from_mapping( + mapping={ + "id": message_file.id, + "type": message_file.type, + "transfer_method": message_file.transfer_method, + "upload_file_id": message_file.upload_file_id, + }, + tenant_id=current_app.tenant_id, + access_controller=_get_file_access_controller(), + ) + case FileTransferMethod.REMOTE_URL: + if message_file.url is None: + raise ValueError(f"MessageFile {message_file.id} is a remote url but has no url") + file = file_factory.build_from_mapping( + mapping={ + "id": message_file.id, + "type": message_file.type, + "transfer_method": message_file.transfer_method, + "upload_file_id": message_file.upload_file_id, + "url": message_file.url, + }, + tenant_id=current_app.tenant_id, + access_controller=_get_file_access_controller(), + ) + case FileTransferMethod.TOOL_FILE: + if message_file.upload_file_id is None: + assert message_file.url is not None + message_file.upload_file_id = message_file.url.split("/")[-1].split(".")[0] + mapping = { "id": message_file.id, "type": message_file.type, "transfer_method": message_file.transfer_method, - "upload_file_id": message_file.upload_file_id, - }, - tenant_id=current_app.tenant_id, - access_controller=_get_file_access_controller(), - ) - elif message_file.transfer_method == FileTransferMethod.REMOTE_URL: - if message_file.url is None: - raise ValueError(f"MessageFile {message_file.id} is a remote url but has no url") - file = file_factory.build_from_mapping( - mapping={ - "id": message_file.id, - "type": message_file.type, - "transfer_method": message_file.transfer_method, - "upload_file_id": message_file.upload_file_id, - "url": message_file.url, - }, - tenant_id=current_app.tenant_id, - access_controller=_get_file_access_controller(), - ) - elif message_file.transfer_method == FileTransferMethod.TOOL_FILE: - if message_file.upload_file_id is None: - assert message_file.url is not None - message_file.upload_file_id = message_file.url.split("/")[-1].split(".")[0] - mapping = { - "id": message_file.id, - "type": message_file.type, - "transfer_method": message_file.transfer_method, - "tool_file_id": message_file.upload_file_id, - } - file = file_factory.build_from_mapping( - mapping=mapping, - tenant_id=current_app.tenant_id, - access_controller=_get_file_access_controller(), - ) - else: - raise ValueError( - f"MessageFile {message_file.id} has an invalid transfer_method {message_file.transfer_method}" - ) + "tool_file_id": message_file.upload_file_id, + } + file = file_factory.build_from_mapping( + mapping=mapping, + tenant_id=current_app.tenant_id, + access_controller=_get_file_access_controller(), + ) + case FileTransferMethod.DATASOURCE_FILE: + raise ValueError( + f"MessageFile {message_file.id} has an invalid transfer_method {message_file.transfer_method}" + ) files.append(file) result = cast( @@ -1839,7 +1839,7 @@ class MessageFile(TypeBase): ) -class MessageAnnotation(Base): +class MessageAnnotation(TypeBase): __tablename__ = "message_annotations" __table_args__ = ( sa.PrimaryKeyConstraint("id", name="message_annotation_pkey"), @@ -1848,17 +1848,25 @@ class MessageAnnotation(Base): sa.Index("message_annotation_message_idx", "message_id"), ) - id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4())) + id: Mapped[str] = mapped_column( + StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False + ) app_id: Mapped[str] = mapped_column(StringUUID) - conversation_id: Mapped[str | None] = mapped_column(StringUUID, sa.ForeignKey("conversations.id")) - message_id: Mapped[str | None] = mapped_column(StringUUID) question: Mapped[str] = mapped_column(LongText, nullable=False) content: Mapped[str] = mapped_column(LongText, nullable=False) - hit_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0")) account_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + conversation_id: Mapped[str | None] = mapped_column(StringUUID, sa.ForeignKey("conversations.id"), default=None) + message_id: Mapped[str | None] = mapped_column(StringUUID, default=None) + hit_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"), default=0) + created_at: Mapped[datetime] = mapped_column( + sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False + ) updated_at: Mapped[datetime] = mapped_column( - sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() + sa.DateTime, + nullable=False, + server_default=func.current_timestamp(), + onupdate=func.current_timestamp(), + init=False, ) @property @@ -2489,7 +2497,7 @@ class TraceAppConfig(TypeBase): ) app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) tracing_provider: Mapped[str | None] = mapped_column(String(255), nullable=True) - tracing_config: Mapped[dict | None] = mapped_column(sa.JSON, nullable=True) + tracing_config: Mapped[dict[str, Any] | None] = mapped_column(sa.JSON, nullable=True) created_at: Mapped[datetime] = mapped_column( sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False ) diff --git a/api/models/oauth.py b/api/models/oauth.py index 1db25524693..bd04d890d30 100644 --- a/api/models/oauth.py +++ b/api/models/oauth.py @@ -1,4 +1,5 @@ from datetime import datetime +from typing import Any import sqlalchemy as sa from sqlalchemy import func @@ -22,7 +23,7 @@ class DatasourceOauthParamConfig(TypeBase): ) plugin_id: Mapped[str] = mapped_column(sa.String(255), nullable=False) provider: Mapped[str] = mapped_column(sa.String(255), nullable=False) - system_credentials: Mapped[dict] = mapped_column(AdjustedJSON, nullable=False) + system_credentials: Mapped[dict[str, Any]] = mapped_column(AdjustedJSON, nullable=False) class DatasourceProvider(TypeBase): @@ -40,7 +41,7 @@ class DatasourceProvider(TypeBase): provider: Mapped[str] = mapped_column(sa.String(128), nullable=False) plugin_id: Mapped[str] = mapped_column(sa.String(255), nullable=False) auth_type: Mapped[str] = mapped_column(sa.String(255), nullable=False) - encrypted_credentials: Mapped[dict] = mapped_column(AdjustedJSON, nullable=False) + encrypted_credentials: Mapped[dict[str, Any]] = mapped_column(AdjustedJSON, nullable=False) avatar_url: Mapped[str] = mapped_column(LongText, nullable=True, default="default") is_default: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"), default=False) expires_at: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default="-1", default=-1) @@ -70,7 +71,7 @@ class DatasourceOauthTenantParamConfig(TypeBase): tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) provider: Mapped[str] = mapped_column(sa.String(255), nullable=False) plugin_id: Mapped[str] = mapped_column(sa.String(255), nullable=False) - client_params: Mapped[dict] = mapped_column(AdjustedJSON, nullable=False, default_factory=dict) + client_params: Mapped[dict[str, Any]] = mapped_column(AdjustedJSON, nullable=False, default_factory=dict) enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, default=False) created_at: Mapped[datetime] = mapped_column( diff --git a/api/models/provider.py b/api/models/provider.py index 8270961b311..2bb67d605b5 100644 --- a/api/models/provider.py +++ b/api/models/provider.py @@ -6,10 +6,10 @@ from functools import cached_property from uuid import uuid4 import sqlalchemy as sa -from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy import DateTime, String, func, select, text from sqlalchemy.orm import Mapped, mapped_column +from graphon.model_runtime.entities.model_entities import ModelType from libs.uuid_utils import uuidv7 from .base import TypeBase diff --git a/api/models/source.py b/api/models/source.py index a8addbe3423..8fce7df205d 100644 --- a/api/models/source.py +++ b/api/models/source.py @@ -1,5 +1,6 @@ import json from datetime import datetime +from typing import Any, TypedDict from uuid import uuid4 import sqlalchemy as sa @@ -24,7 +25,7 @@ class DataSourceOauthBinding(TypeBase): tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) access_token: Mapped[str] = mapped_column(String(255), nullable=False) provider: Mapped[str] = mapped_column(String(255), nullable=False) - source_info: Mapped[dict] = mapped_column(AdjustedJSON, nullable=False) + source_info: Mapped[dict[str, Any]] = mapped_column(AdjustedJSON, nullable=False) created_at: Mapped[datetime] = mapped_column( DateTime, nullable=False, server_default=func.current_timestamp(), init=False ) @@ -38,6 +39,17 @@ class DataSourceOauthBinding(TypeBase): disabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=True, server_default=sa.text("false"), default=False) +class DataSourceApiKeyAuthBindingDict(TypedDict): + id: str + tenant_id: str + category: str + provider: str + credentials: Any + created_at: float + updated_at: float + disabled: bool + + class DataSourceApiKeyAuthBinding(TypeBase): __tablename__ = "data_source_api_key_auth_bindings" __table_args__ = ( @@ -65,8 +77,8 @@ class DataSourceApiKeyAuthBinding(TypeBase): ) disabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=True, server_default=sa.text("false"), default=False) - def to_dict(self): - return { + def to_dict(self) -> DataSourceApiKeyAuthBindingDict: + result: DataSourceApiKeyAuthBindingDict = { "id": self.id, "tenant_id": self.tenant_id, "category": self.category, @@ -76,3 +88,4 @@ class DataSourceApiKeyAuthBinding(TypeBase): "updated_at": self.updated_at.timestamp(), "disabled": self.disabled, } + return result diff --git a/api/models/tools.py b/api/models/tools.py index d8731fb8a8a..02f8b5217d5 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -356,7 +356,7 @@ class MCPToolProvider(TypeBase): return {} @property - def headers(self) -> dict[str, Any]: + def headers(self) -> dict[str, str]: if self.encrypted_headers is None: return {} try: diff --git a/api/models/types.py b/api/models/types.py index 9ab694759f6..4f35c31a27c 100644 --- a/api/models/types.py +++ b/api/models/types.py @@ -1,6 +1,6 @@ import enum import uuid -from typing import Any +from typing import Any, cast import sqlalchemy as sa from sqlalchemy import CHAR, TEXT, VARCHAR, LargeBinary, TypeDecorator @@ -103,10 +103,14 @@ class AdjustedJSON(TypeDecorator[dict | list | None]): else: return dialect.type_descriptor(sa.JSON()) - def process_bind_param(self, value: dict | list | None, dialect: Dialect) -> dict | list | None: + def process_bind_param( + self, value: dict[str, Any] | list[Any] | None, dialect: Dialect + ) -> dict[str, Any] | list[Any] | None: return value - def process_result_value(self, value: dict | list | None, dialect: Dialect) -> dict | list | None: + def process_result_value( + self, value: dict[str, Any] | list[Any] | None, dialect: Dialect + ) -> dict[str, Any] | list[Any] | None: return value @@ -143,8 +147,14 @@ class EnumText[T: enum.StrEnum](TypeDecorator[T | None]): def process_result_value(self, value: str | None, dialect: Dialect) -> T | None: if value is None or value == "": return None - # Type annotation guarantees value is str at this point - return self._enum_class(value) + try: + # Type annotation guarantees value is str at this point + return self._enum_class(value) + except ValueError: + value_of = getattr(self._enum_class, "value_of", None) + if callable(value_of): + return cast(T, value_of(value)) + raise def compare_values(self, x: T | None, y: T | None) -> bool: if x is None or y is None: diff --git a/api/models/utils/file_input_compat.py b/api/models/utils/file_input_compat.py index f71583c1cde..77dcbd13d4d 100644 --- a/api/models/utils/file_input_compat.py +++ b/api/models/utils/file_input_compat.py @@ -4,9 +4,9 @@ from collections.abc import Callable, Mapping from functools import lru_cache from typing import Any -from graphon.file import File, FileTransferMethod - from core.workflow.file_reference import parse_file_reference +from graphon.file import File, FileTransferMethod, FileType +from graphon.file.constants import FILE_MODEL_IDENTITY, maybe_file_object @lru_cache(maxsize=1) @@ -44,6 +44,124 @@ def resolve_file_mapping_tenant_id( return tenant_resolver() +def build_file_from_mapping_without_lookup(*, file_mapping: Mapping[str, Any]) -> File: + """Build a graph `File` directly from serialized metadata.""" + + def _coerce_file_type(value: Any) -> FileType: + if isinstance(value, FileType): + return value + if isinstance(value, str): + return FileType.value_of(value) + raise ValueError("file type is required in file mapping") + + mapping = dict(file_mapping) + transfer_method_value = mapping.get("transfer_method") + if isinstance(transfer_method_value, FileTransferMethod): + transfer_method = transfer_method_value + elif isinstance(transfer_method_value, str): + transfer_method = FileTransferMethod.value_of(transfer_method_value) + else: + raise ValueError("transfer_method is required in file mapping") + + file_id = mapping.get("file_id") + if not isinstance(file_id, str) or not file_id: + legacy_id = mapping.get("id") + file_id = legacy_id if isinstance(legacy_id, str) and legacy_id else None + + related_id = resolve_file_record_id(mapping) + if related_id is None: + raw_related_id = mapping.get("related_id") + related_id = raw_related_id if isinstance(raw_related_id, str) and raw_related_id else None + + remote_url = mapping.get("remote_url") + if not isinstance(remote_url, str) or not remote_url: + url = mapping.get("url") + remote_url = url if isinstance(url, str) and url else None + + reference = mapping.get("reference") + if not isinstance(reference, str) or not reference: + reference = None + + filename = mapping.get("filename") + if not isinstance(filename, str): + filename = None + + extension = mapping.get("extension") + if not isinstance(extension, str): + extension = None + + mime_type = mapping.get("mime_type") + if not isinstance(mime_type, str): + mime_type = None + + size = mapping.get("size", -1) + if not isinstance(size, int): + size = -1 + + storage_key = mapping.get("storage_key") + if not isinstance(storage_key, str): + storage_key = None + + tenant_id = mapping.get("tenant_id") + if not isinstance(tenant_id, str): + tenant_id = None + + dify_model_identity = mapping.get("dify_model_identity") + if not isinstance(dify_model_identity, str): + dify_model_identity = FILE_MODEL_IDENTITY + + tool_file_id = mapping.get("tool_file_id") + if not isinstance(tool_file_id, str): + tool_file_id = None + + upload_file_id = mapping.get("upload_file_id") + if not isinstance(upload_file_id, str): + upload_file_id = None + + datasource_file_id = mapping.get("datasource_file_id") + if not isinstance(datasource_file_id, str): + datasource_file_id = None + + return File( + file_id=file_id, + tenant_id=tenant_id, + file_type=_coerce_file_type(mapping.get("file_type", mapping.get("type"))), + transfer_method=transfer_method, + remote_url=remote_url, + reference=reference, + related_id=related_id, + filename=filename, + extension=extension, + mime_type=mime_type, + size=size, + storage_key=storage_key, + dify_model_identity=dify_model_identity, + url=remote_url, + tool_file_id=tool_file_id, + upload_file_id=upload_file_id, + datasource_file_id=datasource_file_id, + ) + + +def rebuild_serialized_graph_files_without_lookup(value: Any) -> Any: + """Recursively rebuild serialized graph file payloads into `File` objects. + + `graphon` 0.2.2 no longer accepts legacy serialized file mappings via + `model_validate_json()`. Dify keeps this recovery path at the model boundary + so historical JSON blobs remain readable without reintroducing global graph + patches or test-local coercion. + """ + if isinstance(value, list): + return [rebuild_serialized_graph_files_without_lookup(item) for item in value] + + if isinstance(value, dict): + if maybe_file_object(value): + return build_file_from_mapping_without_lookup(file_mapping=value) + return {key: rebuild_serialized_graph_files_without_lookup(item) for key, item in value.items()} + + return value + + def build_file_from_stored_mapping( *, file_mapping: Mapping[str, Any], @@ -66,20 +184,18 @@ def build_file_from_stored_mapping( record_id = resolve_file_record_id(mapping) transfer_method = FileTransferMethod.value_of(mapping["transfer_method"]) - if transfer_method == FileTransferMethod.TOOL_FILE and record_id: - mapping["tool_file_id"] = record_id - elif transfer_method in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL] and record_id: - mapping["upload_file_id"] = record_id - elif transfer_method == FileTransferMethod.DATASOURCE_FILE and record_id: - mapping["datasource_file_id"] = record_id + match transfer_method: + case FileTransferMethod.TOOL_FILE if record_id: + mapping["tool_file_id"] = record_id + case FileTransferMethod.LOCAL_FILE | FileTransferMethod.REMOTE_URL if record_id: + mapping["upload_file_id"] = record_id + case FileTransferMethod.DATASOURCE_FILE if record_id: + mapping["datasource_file_id"] = record_id + case _: + pass if transfer_method == FileTransferMethod.REMOTE_URL and record_id is None: - remote_url = mapping.get("remote_url") - if not isinstance(remote_url, str) or not remote_url: - url = mapping.get("url") - if isinstance(url, str) and url: - mapping["remote_url"] = url - return File.model_validate(mapping) + return build_file_from_mapping_without_lookup(file_mapping=mapping) return file_factory.build_from_mapping( mapping=mapping, diff --git a/api/models/workflow.py b/api/models/workflow.py index 10630163701..d127244b0f2 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -4,23 +4,10 @@ import logging from collections.abc import Generator, Mapping, Sequence from datetime import datetime from enum import StrEnum -from typing import TYPE_CHECKING, Any, Optional, TypedDict, Union, cast +from typing import TYPE_CHECKING, Any, Optional, TypedDict, cast from uuid import uuid4 import sqlalchemy as sa -from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter -from graphon.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause -from graphon.enums import ( - BuiltinNodeTypes, - NodeType, - WorkflowExecutionStatus, - WorkflowNodeExecutionMetadataKey, - WorkflowNodeExecutionStatus, -) -from graphon.file import File -from graphon.file.constants import maybe_file_object -from graphon.variables import utils as variable_utils -from graphon.variables.variables import FloatVariable, IntegerVariable, RAGPipelineVariable, StringVariable from sqlalchemy import ( DateTime, Index, @@ -37,13 +24,26 @@ from sqlalchemy.orm import Mapped, mapped_column from typing_extensions import deprecated from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE -from core.workflow.human_input_compat import normalize_node_config_for_graph +from core.workflow.human_input_adapter import adapt_node_config_for_graph from core.workflow.variable_prefixes import ( CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID, ) from extensions.ext_storage import Storage from factories.variable_factory import TypeMismatchError, build_segment_with_type +from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter +from graphon.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause +from graphon.enums import ( + BuiltinNodeTypes, + NodeType, + WorkflowExecutionStatus, + WorkflowNodeExecutionMetadataKey, + WorkflowNodeExecutionStatus, +) +from graphon.file import File +from graphon.file.constants import maybe_file_object +from graphon.variables import utils as variable_utils +from graphon.variables.variables import FloatVariable, IntegerVariable, RAGPipelineVariable, StringVariable from libs.datetime_utils import naive_utc_now from libs.uuid_utils import uuidv7 @@ -53,19 +53,21 @@ if TYPE_CHECKING: from .model import AppMode, UploadFile -from graphon.variables import SecretVariable, Segment, SegmentType, VariableBase - from constants import DEFAULT_FILE_NUMBER_LIMITS, HIDDEN_VALUE from core.helper import encrypter from factories import variable_factory +from graphon.variables import SecretVariable, Segment, SegmentType, VariableBase from libs import helper from .account import Account -from .base import Base, DefaultFieldsMixin, TypeBase +from .base import Base, DefaultFieldsDCMixin, TypeBase from .engine import db from .enums import CreatorUserRole, DraftVariableType, ExecutionOffLoadType, WorkflowRunTriggeredFrom from .types import EnumText, LongText, StringUUID -from .utils.file_input_compat import build_file_from_stored_mapping +from .utils.file_input_compat import ( + build_file_from_mapping_without_lookup, + build_file_from_stored_mapping, +) logger = logging.getLogger(__name__) @@ -121,7 +123,7 @@ class WorkflowType(StrEnum): raise ValueError(f"invalid workflow type value {value}") @classmethod - def from_app_mode(cls, app_mode: Union[str, "AppMode"]) -> "WorkflowType": + def from_app_mode(cls, app_mode: "str | AppMode") -> "WorkflowType": """ Get workflow type from app mode. @@ -291,7 +293,7 @@ class Workflow(Base): # bug node_config: dict[str, Any] = next(filter(lambda node: node["id"] == node_id, nodes)) except StopIteration: raise NodeNotFoundError(node_id) - return NodeConfigDictAdapter.validate_python(normalize_node_config_for_graph(node_config)) + return NodeConfigDictAdapter.validate_python(adapt_node_config_for_graph(node_config)) @staticmethod def get_node_type_from_node_config(node_config: NodeConfigDict) -> NodeType: @@ -490,7 +492,7 @@ class Workflow(Base): # bug :return: hash """ - entity = {"graph": self.graph_dict, "features": self.features_dict} + entity = {"graph": self.graph_dict} return helper.generate_text_hash(json.dumps(entity, sort_keys=True)) @@ -671,6 +673,29 @@ class Workflow(Base): # bug return str(d) +class WorkflowRunDict(TypedDict): + id: str + tenant_id: str + app_id: str + workflow_id: str + type: WorkflowType + triggered_from: WorkflowRunTriggeredFrom + version: str + graph: Mapping[str, Any] + inputs: Mapping[str, Any] + status: WorkflowExecutionStatus + outputs: Mapping[str, Any] + error: str | None + elapsed_time: float + total_tokens: int + total_steps: int + created_by_role: CreatorUserRole + created_by: str + created_at: datetime + finished_at: datetime | None + exceptions_count: int + + class WorkflowRun(Base): """ Workflow Run @@ -742,8 +767,8 @@ class WorkflowRun(Base): exceptions_count: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0"), nullable=True) pause: Mapped[Optional["WorkflowPause"]] = orm.relationship( - "WorkflowPause", - primaryjoin="WorkflowRun.id == foreign(WorkflowPause.workflow_run_id)", + lambda: WorkflowPause, + primaryjoin=lambda: WorkflowRun.id == orm.foreign(WorkflowPause.workflow_run_id), uselist=False, # require explicit preloading. lazy="raise", @@ -790,29 +815,29 @@ class WorkflowRun(Base): def workflow(self): return db.session.scalar(select(Workflow).where(Workflow.id == self.workflow_id)) - def to_dict(self): - return { - "id": self.id, - "tenant_id": self.tenant_id, - "app_id": self.app_id, - "workflow_id": self.workflow_id, - "type": self.type, - "triggered_from": self.triggered_from, - "version": self.version, - "graph": self.graph_dict, - "inputs": self.inputs_dict, - "status": self.status, - "outputs": self.outputs_dict, - "error": self.error, - "elapsed_time": self.elapsed_time, - "total_tokens": self.total_tokens, - "total_steps": self.total_steps, - "created_by_role": self.created_by_role, - "created_by": self.created_by, - "created_at": self.created_at, - "finished_at": self.finished_at, - "exceptions_count": self.exceptions_count, - } + def to_dict(self) -> WorkflowRunDict: + return WorkflowRunDict( + id=self.id, + tenant_id=self.tenant_id, + app_id=self.app_id, + workflow_id=self.workflow_id, + type=self.type, + triggered_from=self.triggered_from, + version=self.version, + graph=self.graph_dict, + inputs=self.inputs_dict, + status=self.status, + outputs=self.outputs_dict, + error=self.error, + elapsed_time=self.elapsed_time, + total_tokens=self.total_tokens, + total_steps=self.total_steps, + created_by_role=self.created_by_role, + created_by=self.created_by, + created_at=self.created_at, + finished_at=self.finished_at, + exceptions_count=self.exceptions_count, + ) @classmethod def from_dict(cls, data: dict[str, Any]) -> "WorkflowRun": @@ -1051,7 +1076,7 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo ) return extras - def _get_offload_by_type(self, type_: ExecutionOffLoadType) -> Optional["WorkflowNodeExecutionOffload"]: + def _get_offload_by_type(self, type_: ExecutionOffLoadType) -> "WorkflowNodeExecutionOffload | None": return next(iter([i for i in self.offload_data if i.type_ == type_]), None) @property @@ -1196,6 +1221,18 @@ class WorkflowAppLogCreatedFrom(StrEnum): raise ValueError(f"invalid workflow app log created from value {value}") +class WorkflowAppLogDict(TypedDict): + id: str + tenant_id: str + app_id: str + workflow_id: str + workflow_run_id: str + created_from: WorkflowAppLogCreatedFrom + created_by_role: CreatorUserRole + created_by: str + created_at: datetime + + class WorkflowAppLog(TypeBase): """ Workflow App execution log, excluding workflow debugging records. @@ -1273,8 +1310,8 @@ class WorkflowAppLog(TypeBase): created_by_role = CreatorUserRole(self.created_by_role) return db.session.get(EndUser, self.created_by) if created_by_role == CreatorUserRole.END_USER else None - def to_dict(self): - return { + def to_dict(self) -> WorkflowAppLogDict: + result: WorkflowAppLogDict = { "id": self.id, "tenant_id": self.tenant_id, "app_id": self.app_id, @@ -1285,6 +1322,7 @@ class WorkflowAppLog(TypeBase): "created_by": self.created_by, "created_at": self.created_at, } + return result class WorkflowArchiveLog(TypeBase): @@ -1625,21 +1663,22 @@ class WorkflowDraftVariable(Base): # Rebuild them through the file factory so tenant ownership, signed URLs, # and storage-backed metadata come from canonical records instead of the # serialized JSON blob. - if segment_type == SegmentType.FILE: - if isinstance(value, File): - return build_segment_with_type(segment_type, value) - elif isinstance(value, dict): - file = self._rebuild_file_types(value) - return build_segment_with_type(segment_type, file) - else: - raise TypeMismatchError(f"expected dict or File for FileSegment, got {type(value)}") - if segment_type == SegmentType.ARRAY_FILE: - if not isinstance(value, list): - raise TypeMismatchError(f"expected list for ArrayFileSegment, got {type(value)}") - file_list = self._rebuild_file_types(value) - return build_segment_with_type(segment_type=segment_type, value=file_list) - - return build_segment_with_type(segment_type=segment_type, value=value) + match segment_type: + case SegmentType.FILE: + if isinstance(value, File): + return build_segment_with_type(segment_type, value) + elif isinstance(value, dict): + file = self._rebuild_file_types(value) + return build_segment_with_type(segment_type, file) + else: + raise TypeMismatchError(f"expected dict or File for FileSegment, got {type(value)}") + case SegmentType.ARRAY_FILE: + if not isinstance(value, list): + raise TypeMismatchError(f"expected list for ArrayFileSegment, got {type(value)}") + file_list = self._rebuild_file_types(value) + return build_segment_with_type(segment_type=segment_type, value=file_list) + case _: + return build_segment_with_type(segment_type=segment_type, value=value) @staticmethod def rebuild_file_types(value: Any): @@ -1652,7 +1691,7 @@ class WorkflowDraftVariable(Base): return cast(Any, value) normalized_file = dict(value) normalized_file.pop("tenant_id", None) - return File.model_validate(normalized_file) + return build_file_from_mapping_without_lookup(file_mapping=normalized_file) elif isinstance(value, list) and value: value_list = cast(list[Any], value) first: Any = value_list[0] @@ -1662,7 +1701,7 @@ class WorkflowDraftVariable(Base): for item in value_list: normalized_file = dict(cast(dict[str, Any], item)) normalized_file.pop("tenant_id", None) - file_list.append(File.model_validate(normalized_file)) + file_list.append(build_file_from_mapping_without_lookup(file_mapping=normalized_file)) return cast(Any, file_list) else: return cast(Any, value) @@ -1672,21 +1711,22 @@ class WorkflowDraftVariable(Base): # Extends `variable_factory.build_segment_with_type` functionality by # reconstructing `FileSegment`` or `ArrayFileSegment`` objects from # their serialized dictionary or list representations, respectively. - if segment_type == SegmentType.FILE: - if isinstance(value, File): - return build_segment_with_type(segment_type, value) - elif isinstance(value, dict): - file = cls.rebuild_file_types(value) - return build_segment_with_type(segment_type, file) - else: - raise TypeMismatchError(f"expected dict or File for FileSegment, got {type(value)}") - if segment_type == SegmentType.ARRAY_FILE: - if not isinstance(value, list): - raise TypeMismatchError(f"expected list for ArrayFileSegment, got {type(value)}") - file_list = cls.rebuild_file_types(value) - return build_segment_with_type(segment_type=segment_type, value=file_list) - - return build_segment_with_type(segment_type=segment_type, value=value) + match segment_type: + case SegmentType.FILE: + if isinstance(value, File): + return build_segment_with_type(segment_type, value) + elif isinstance(value, dict): + file = cls.rebuild_file_types(value) + return build_segment_with_type(segment_type, file) + else: + raise TypeMismatchError(f"expected dict or File for FileSegment, got {type(value)}") + case SegmentType.ARRAY_FILE: + if not isinstance(value, list): + raise TypeMismatchError(f"expected list for ArrayFileSegment, got {type(value)}") + file_list = cls.rebuild_file_types(value) + return build_segment_with_type(segment_type=segment_type, value=file_list) + case _: + return build_segment_with_type(segment_type=segment_type, value=value) def get_value(self) -> Segment: """Decode the serialized value into its corresponding `Segment` object. @@ -1939,7 +1979,7 @@ def is_system_variable_editable(name: str) -> bool: return name in _EDITABLE_SYSTEM_VARIABLE -class WorkflowPause(DefaultFieldsMixin, Base): +class WorkflowPause(DefaultFieldsDCMixin, TypeBase): """ WorkflowPause records the paused state and related metadata for a specific workflow run. @@ -1978,6 +2018,11 @@ class WorkflowPause(DefaultFieldsMixin, Base): nullable=False, ) + # state_object_key stores the object key referencing the serialized runtime state + # of the `GraphEngine`. This object captures the complete execution context of the + # workflow at the moment it was paused, enabling accurate resumption. + state_object_key: Mapped[str] = mapped_column(String(length=255), nullable=False) + # `resumed_at` records the timestamp when the suspended workflow was resumed. # It is set to `NULL` if the workflow has not been resumed. # @@ -1986,25 +2031,23 @@ class WorkflowPause(DefaultFieldsMixin, Base): resumed_at: Mapped[datetime | None] = mapped_column( sa.DateTime, nullable=True, + default=None, ) - # state_object_key stores the object key referencing the serialized runtime state - # of the `GraphEngine`. This object captures the complete execution context of the - # workflow at the moment it was paused, enabling accurate resumption. - state_object_key: Mapped[str] = mapped_column(String(length=255), nullable=False) - - # Relationship to WorkflowRun + # Relationship to WorkflowRun (uses lambda to resolve across Base/TypeBase registries) workflow_run: Mapped["WorkflowRun"] = orm.relationship( + lambda: WorkflowRun, foreign_keys=[workflow_run_id], # require explicit preloading. lazy="raise", uselist=False, - primaryjoin="WorkflowPause.workflow_run_id == WorkflowRun.id", + primaryjoin=lambda: WorkflowPause.workflow_run_id == WorkflowRun.id, back_populates="pause", + init=False, ) -class WorkflowPauseReason(DefaultFieldsMixin, Base): +class WorkflowPauseReason(DefaultFieldsDCMixin, TypeBase): __tablename__ = "workflow_pause_reasons" # `pause_id` represents the identifier of the pause, @@ -2047,16 +2090,20 @@ class WorkflowPauseReason(DefaultFieldsMixin, Base): lazy="raise", uselist=False, primaryjoin="WorkflowPauseReason.pause_id == WorkflowPause.id", + init=False, ) @classmethod - def from_entity(cls, pause_reason: PauseReason) -> "WorkflowPauseReason": + def from_entity(cls, *, pause_id: str, pause_reason: PauseReason) -> "WorkflowPauseReason": if isinstance(pause_reason, HumanInputRequired): return cls( - type_=PauseReasonType.HUMAN_INPUT_REQUIRED, form_id=pause_reason.form_id, node_id=pause_reason.node_id + pause_id=pause_id, + type_=PauseReasonType.HUMAN_INPUT_REQUIRED, + form_id=pause_reason.form_id, + node_id=pause_reason.node_id, ) elif isinstance(pause_reason, SchedulingPause): - return cls(type_=PauseReasonType.SCHEDULED_PAUSE, message=pause_reason.message, node_id="") + return cls(pause_id=pause_id, type_=PauseReasonType.SCHEDULED_PAUSE, message=pause_reason.message) else: raise AssertionError(f"Unknown pause reason type: {pause_reason}") diff --git a/api/providers/README.md b/api/providers/README.md new file mode 100644 index 00000000000..5d5e6db9afc --- /dev/null +++ b/api/providers/README.md @@ -0,0 +1,15 @@ +# Providers + +This directory holds **optional workspace packages** that plug into Dify’s API core. Providers are responsible for implementing the interfaces and registering themselves to the API core. Provider mechanism allows building the software with selected set of providers so as to enhance the security and flexibility of distributions. + +## Developing Providers + +- [VDB Providers](vdb/README.md) + +## Tests + +Provider tests often live next to the package, e.g. `providers///tests/unit_tests/`. Shared fixtures may live under `providers/` (e.g. `conftest.py`). + +## Excluding Providers + +In order to build with selected providers, use `--no-group vdb-all` and `--no-group trace-all` to disable default ones, then use `--group vdb-` and `--group trace-` to enable specific providers. diff --git a/api/providers/trace/README.md b/api/providers/trace/README.md new file mode 100644 index 00000000000..a7ffa5ed260 --- /dev/null +++ b/api/providers/trace/README.md @@ -0,0 +1,78 @@ +# Trace providers + +This directory holds **optional workspace packages** that send Dify **ops tracing** data (workflows, messages, tools, moderation, etc.) to an external observability backend (Langfuse, LangSmith, OpenTelemetry-style exporters, and others). + +Unlike VDB providers, trace plugins are **not** discovered via entry points. The API core imports your package **explicitly** from `core/ops/ops_trace_manager.py` after you register the provider id and mapping. + +## Architecture + +| Layer | Location | Role | +|--------|----------|------| +| Contracts | `api/core/ops/base_trace_instance.py`, `api/core/ops/entities/trace_entity.py`, `api/core/ops/entities/config_entity.py` | `BaseTraceInstance`, `BaseTracingConfig`, and typed `*TraceInfo` payloads | +| Registry | `api/core/ops/ops_trace_manager.py` | `TracingProviderEnum`, `OpsTraceProviderConfigMap` — maps provider **string** → config class, encrypted keys, and trace class | +| Your package | `api/providers/trace/trace-/` | Pydantic config + subclass of `BaseTraceInstance` | + +At runtime, `OpsTraceManager` decrypts stored credentials, builds your config model, caches a trace instance, and calls `trace(trace_info)` with a concrete `BaseTraceInfo` subtype. + +## What you implement + +### 1. Config model (`BaseTracingConfig`) + +Subclass `BaseTracingConfig` from `core.ops.entities.config_entity`. Use Pydantic validators; reuse helpers from `core.ops.utils` (for example `validate_url`, `validate_url_with_path`, `validate_project_name`) where appropriate. + +Fields fall into two groups used by the manager: + +- **`secret_keys`** — names of fields that are **encrypted at rest** (API keys, tokens, passwords). +- **`other_keys`** — non-secret connection settings (hosts, project names, endpoints). + +List these key names in your `OpsTraceProviderConfigMap` entry so encrypt/decrypt and merge logic stay correct. + +### 2. Trace instance (`BaseTraceInstance`) + +Subclass `BaseTraceInstance` and implement: + +```python +def trace(self, trace_info: BaseTraceInfo) -> None: + ... +``` + +Dispatch on the concrete type with `isinstance` (see `trace_langfuse` or `trace_langsmith` for full patterns). Payload types are defined in `core/ops/entities/trace_entity.py`, including: + +- `WorkflowTraceInfo`, `WorkflowNodeTraceInfo`, `DraftNodeExecutionTrace` +- `MessageTraceInfo`, `ToolTraceInfo`, `ModerationTraceInfo`, `SuggestedQuestionTraceInfo` +- `DatasetRetrievalTraceInfo`, `GenerateNameTraceInfo`, `PromptGenerationTraceInfo` + +You may ignore categories your backend does not support; existing providers often no-op unhandled types. + +Optional: use `get_service_account_with_tenant(app_id)` from the base class when you need tenant-scoped account context. + +### 3. Register in the API core + +Upstream changes are required so Dify knows your provider exists: + +1. **`TracingProviderEnum`** (`api/core/ops/entities/config_entity.py`) — add a new member whose **value** is the stable string stored in app tracing config (e.g. `"mybackend"`). +2. **`OpsTraceProviderConfigMap.__getitem__`** (`api/core/ops/ops_trace_manager.py`) — add a `match` case for that enum member returning: + - `config_class`: your Pydantic config type + - `secret_keys` / `other_keys`: lists of field names as above + - `trace_instance`: your `BaseTraceInstance` subclass + Lazy-import your package inside the case so missing optional installs raise a clear `ImportError`. + +If the `match` case is missing, the provider string will not resolve and tracing will be disabled for that app. + +## Package layout + +Each provider is a normal uv workspace member, for example: + +- `api/providers/trace/trace-/pyproject.toml` — project name `dify-trace-`, dependencies on vendor SDKs +- `api/providers/trace/trace-/src/dify_trace_/` — `config.py`, `_trace.py`, optional `entities/`, and an empty **`py.typed`** file (PEP 561) so the API type checker treats the package as typed; list `py.typed` under `[tool.setuptools.package-data]` for that import name in `pyproject.toml`. + +Reference implementations: `trace-langfuse/`, `trace-langsmith/`, `trace-opik/`. + +## Wiring into the `api` workspace + +In `api/pyproject.toml`: + +1. **`[tool.uv.sources]`** — `dify-trace- = { workspace = true }` +2. **`[dependency-groups]`** — add `trace- = ["dify-trace-"]` and include `dify-trace-` in `trace-all` if it should ship with the default bundle + +After changing metadata, run **`uv sync`** from `api/`. diff --git a/api/providers/trace/trace-aliyun/pyproject.toml b/api/providers/trace/trace-aliyun/pyproject.toml new file mode 100644 index 00000000000..bcef7e9fb15 --- /dev/null +++ b/api/providers/trace/trace-aliyun/pyproject.toml @@ -0,0 +1,14 @@ +[project] +name = "dify-trace-aliyun" +version = "0.0.1" +dependencies = [ + # versions inherited from parent + "opentelemetry-api", + "opentelemetry-exporter-otlp-proto-grpc", + "opentelemetry-sdk", + "opentelemetry-semantic-conventions", +] +description = "Dify ops tracing provider (Aliyun)." + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/core/ops/aliyun_trace/__init__.py b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/__init__.py similarity index 100% rename from api/core/ops/aliyun_trace/__init__.py rename to api/providers/trace/trace-aliyun/src/dify_trace_aliyun/__init__.py diff --git a/api/core/ops/aliyun_trace/aliyun_trace.py b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/aliyun_trace.py similarity index 98% rename from api/core/ops/aliyun_trace/aliyun_trace.py rename to api/providers/trace/trace-aliyun/src/dify_trace_aliyun/aliyun_trace.py index 70aaf2a07be..54d2f8167f1 100644 --- a/api/core/ops/aliyun_trace/aliyun_trace.py +++ b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/aliyun_trace.py @@ -1,12 +1,23 @@ import logging from collections.abc import Sequence -from graphon.entities import WorkflowNodeExecution -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from opentelemetry.trace import SpanKind from sqlalchemy.orm import sessionmaker -from core.ops.aliyun_trace.data_exporter.traceclient import ( +from core.ops.base_trace_instance import BaseTraceInstance +from core.ops.entities.trace_entity import ( + BaseTraceInfo, + DatasetRetrievalTraceInfo, + GenerateNameTraceInfo, + MessageTraceInfo, + ModerationTraceInfo, + SuggestedQuestionTraceInfo, + ToolTraceInfo, + WorkflowTraceInfo, +) +from core.repositories import DifyCoreRepositoryFactory +from dify_trace_aliyun.config import AliyunConfig +from dify_trace_aliyun.data_exporter.traceclient import ( TraceClient, build_endpoint, convert_datetime_to_nanoseconds, @@ -14,8 +25,8 @@ from core.ops.aliyun_trace.data_exporter.traceclient import ( convert_to_trace_id, generate_span_id, ) -from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData, TraceMetadata -from core.ops.aliyun_trace.entities.semconv import ( +from dify_trace_aliyun.entities.aliyun_trace_entity import SpanData, TraceMetadata +from dify_trace_aliyun.entities.semconv import ( DIFY_APP_ID, GEN_AI_COMPLETION, GEN_AI_INPUT_MESSAGE, @@ -34,7 +45,7 @@ from core.ops.aliyun_trace.entities.semconv import ( TOOL_PARAMETERS, GenAISpanKind, ) -from core.ops.aliyun_trace.utils import ( +from dify_trace_aliyun.utils import ( create_common_span_attributes, create_links_from_trace_id, create_status_from_error, @@ -46,20 +57,9 @@ from core.ops.aliyun_trace.utils import ( get_workflow_node_status, serialize_json_data, ) -from core.ops.base_trace_instance import BaseTraceInstance -from core.ops.entities.config_entity import AliyunConfig -from core.ops.entities.trace_entity import ( - BaseTraceInfo, - DatasetRetrievalTraceInfo, - GenerateNameTraceInfo, - MessageTraceInfo, - ModerationTraceInfo, - SuggestedQuestionTraceInfo, - ToolTraceInfo, - WorkflowTraceInfo, -) -from core.repositories import DifyCoreRepositoryFactory from extensions.ext_database import db +from graphon.entities import WorkflowNodeExecution +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from models import WorkflowNodeExecutionTriggeredFrom logger = logging.getLogger(__name__) diff --git a/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/config.py b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/config.py new file mode 100644 index 00000000000..e0133e6cc92 --- /dev/null +++ b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/config.py @@ -0,0 +1,32 @@ +from pydantic import ValidationInfo, field_validator + +from core.ops.entities.config_entity import BaseTracingConfig +from core.ops.utils import validate_url_with_path + + +class AliyunConfig(BaseTracingConfig): + """ + Model class for Aliyun tracing config. + """ + + app_name: str = "dify_app" + license_key: str + endpoint: str + + @field_validator("app_name") + @classmethod + def app_name_validator(cls, v, info: ValidationInfo): + return cls.validate_project_field(v, "dify_app") + + @field_validator("license_key") + @classmethod + def license_key_validator(cls, v, info: ValidationInfo): + if not v or v.strip() == "": + raise ValueError("License key cannot be empty") + return v + + @field_validator("endpoint") + @classmethod + def endpoint_validator(cls, v, info: ValidationInfo): + # aliyun uses two URL formats, which may include a URL path + return validate_url_with_path(v, "https://tracing-analysis-dc-hz.aliyuncs.com") diff --git a/api/core/ops/aliyun_trace/data_exporter/__init__.py b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/data_exporter/__init__.py similarity index 100% rename from api/core/ops/aliyun_trace/data_exporter/__init__.py rename to api/providers/trace/trace-aliyun/src/dify_trace_aliyun/data_exporter/__init__.py diff --git a/api/core/ops/aliyun_trace/data_exporter/traceclient.py b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/data_exporter/traceclient.py similarity index 98% rename from api/core/ops/aliyun_trace/data_exporter/traceclient.py rename to api/providers/trace/trace-aliyun/src/dify_trace_aliyun/data_exporter/traceclient.py index 67d5163b0f4..00aab6bf891 100644 --- a/api/core/ops/aliyun_trace/data_exporter/traceclient.py +++ b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/data_exporter/traceclient.py @@ -26,8 +26,8 @@ from opentelemetry.semconv.attributes import service_attributes from opentelemetry.trace import Link, SpanContext, TraceFlags from configs import dify_config -from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData -from core.ops.aliyun_trace.entities.semconv import ACS_ARMS_SERVICE_FEATURE +from dify_trace_aliyun.entities.aliyun_trace_entity import SpanData +from dify_trace_aliyun.entities.semconv import ACS_ARMS_SERVICE_FEATURE INVALID_SPAN_ID: Final[int] = 0x0000000000000000 INVALID_TRACE_ID: Final[int] = 0x00000000000000000000000000000000 diff --git a/api/core/ops/aliyun_trace/entities/__init__.py b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/entities/__init__.py similarity index 100% rename from api/core/ops/aliyun_trace/entities/__init__.py rename to api/providers/trace/trace-aliyun/src/dify_trace_aliyun/entities/__init__.py diff --git a/api/core/ops/aliyun_trace/entities/aliyun_trace_entity.py b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/entities/aliyun_trace_entity.py similarity index 100% rename from api/core/ops/aliyun_trace/entities/aliyun_trace_entity.py rename to api/providers/trace/trace-aliyun/src/dify_trace_aliyun/entities/aliyun_trace_entity.py diff --git a/api/core/ops/aliyun_trace/entities/semconv.py b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/entities/semconv.py similarity index 100% rename from api/core/ops/aliyun_trace/entities/semconv.py rename to api/providers/trace/trace-aliyun/src/dify_trace_aliyun/entities/semconv.py diff --git a/api/core/ops/arize_phoenix_trace/__init__.py b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/py.typed similarity index 100% rename from api/core/ops/arize_phoenix_trace/__init__.py rename to api/providers/trace/trace-aliyun/src/dify_trace_aliyun/py.typed diff --git a/api/core/ops/aliyun_trace/utils.py b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/utils.py similarity index 91% rename from api/core/ops/aliyun_trace/utils.py rename to api/providers/trace/trace-aliyun/src/dify_trace_aliyun/utils.py index d8e105d6a32..5678c66adbf 100644 --- a/api/core/ops/aliyun_trace/utils.py +++ b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/utils.py @@ -1,12 +1,11 @@ import json from collections.abc import Mapping -from typing import Any +from typing import Any, TypedDict -from graphon.entities import WorkflowNodeExecution -from graphon.enums import WorkflowNodeExecutionStatus from opentelemetry.trace import Link, Status, StatusCode -from core.ops.aliyun_trace.entities.semconv import ( +from core.rag.models.document import Document +from dify_trace_aliyun.entities.semconv import ( GEN_AI_FRAMEWORK, GEN_AI_SESSION_ID, GEN_AI_SPAN_KIND, @@ -15,8 +14,9 @@ from core.ops.aliyun_trace.entities.semconv import ( OUTPUT_VALUE, GenAISpanKind, ) -from core.rag.models.document import Document from extensions.ext_database import db +from graphon.entities import WorkflowNodeExecution +from graphon.enums import WorkflowNodeExecutionStatus from models import EndUser # Constants @@ -48,7 +48,7 @@ def get_workflow_node_status(node_execution: WorkflowNodeExecution) -> Status: def create_links_from_trace_id(trace_id: str | None) -> list[Link]: - from core.ops.aliyun_trace.data_exporter.traceclient import create_link + from dify_trace_aliyun.data_exporter.traceclient import create_link links = [] if trace_id: @@ -56,10 +56,22 @@ def create_links_from_trace_id(trace_id: str | None) -> list[Link]: return links -def extract_retrieval_documents(documents: list[Document]) -> list[dict[str, Any]]: - documents_data = [] +class RetrievalDocumentMetadataDict(TypedDict): + dataset_id: Any + doc_id: Any + document_id: Any + + +class RetrievalDocumentDict(TypedDict): + content: str + metadata: RetrievalDocumentMetadataDict + score: Any + + +def extract_retrieval_documents(documents: list[Document]) -> list[RetrievalDocumentDict]: + documents_data: list[RetrievalDocumentDict] = [] for document in documents: - document_data = { + document_data: RetrievalDocumentDict = { "content": document.page_content, "metadata": { "dataset_id": document.metadata.get("dataset_id"), @@ -83,7 +95,7 @@ def create_common_span_attributes( framework: str = DEFAULT_FRAMEWORK_NAME, inputs: str = "", outputs: str = "", -) -> dict[str, Any]: +) -> dict[str, str]: return { GEN_AI_SESSION_ID: session_id, GEN_AI_USER_ID: user_id, diff --git a/api/tests/unit_tests/core/ops/aliyun_trace/data_exporter/test_traceclient.py b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/data_exporter/test_traceclient.py similarity index 86% rename from api/tests/unit_tests/core/ops/aliyun_trace/data_exporter/test_traceclient.py rename to api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/data_exporter/test_traceclient.py index acb43d40361..286dda419cf 100644 --- a/api/tests/unit_tests/core/ops/aliyun_trace/data_exporter/test_traceclient.py +++ b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/data_exporter/test_traceclient.py @@ -5,10 +5,7 @@ from unittest.mock import MagicMock, patch import httpx import pytest -from opentelemetry.sdk.trace import ReadableSpan -from opentelemetry.trace import SpanKind, Status, StatusCode - -from core.ops.aliyun_trace.data_exporter.traceclient import ( +from dify_trace_aliyun.data_exporter.traceclient import ( INVALID_SPAN_ID, SpanBuilder, TraceClient, @@ -20,7 +17,9 @@ from core.ops.aliyun_trace.data_exporter.traceclient import ( create_link, generate_span_id, ) -from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData +from dify_trace_aliyun.entities.aliyun_trace_entity import SpanData +from opentelemetry.sdk.trace import ReadableSpan +from opentelemetry.trace import SpanKind, Status, StatusCode @pytest.fixture @@ -41,8 +40,8 @@ def trace_client_factory(): class TestTraceClient: - @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter") - @patch("core.ops.aliyun_trace.data_exporter.traceclient.socket.gethostname") + @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter") + @patch("dify_trace_aliyun.data_exporter.traceclient.socket.gethostname") def test_init(self, mock_gethostname, mock_exporter_class, trace_client_factory): mock_gethostname.return_value = "test-host" client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint") @@ -56,7 +55,7 @@ class TestTraceClient: client.shutdown() assert client.done is True - @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter") + @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter") def test_export(self, mock_exporter_class, trace_client_factory): mock_exporter = mock_exporter_class.return_value client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint") @@ -64,8 +63,8 @@ class TestTraceClient: client.export(spans) mock_exporter.export.assert_called_once_with(spans) - @patch("core.ops.aliyun_trace.data_exporter.traceclient.httpx.head") - @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter") + @patch("dify_trace_aliyun.data_exporter.traceclient.httpx.head") + @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter") def test_api_check_success(self, mock_exporter_class, mock_head, trace_client_factory): mock_response = MagicMock() mock_response.status_code = 405 @@ -74,8 +73,8 @@ class TestTraceClient: client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint") assert client.api_check() is True - @patch("core.ops.aliyun_trace.data_exporter.traceclient.httpx.head") - @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter") + @patch("dify_trace_aliyun.data_exporter.traceclient.httpx.head") + @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter") def test_api_check_failure_status(self, mock_exporter_class, mock_head, trace_client_factory): mock_response = MagicMock() mock_response.status_code = 500 @@ -84,8 +83,8 @@ class TestTraceClient: client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint") assert client.api_check() is False - @patch("core.ops.aliyun_trace.data_exporter.traceclient.httpx.head") - @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter") + @patch("dify_trace_aliyun.data_exporter.traceclient.httpx.head") + @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter") def test_api_check_exception(self, mock_exporter_class, mock_head, trace_client_factory): mock_head.side_effect = httpx.RequestError("Connection error") @@ -93,12 +92,12 @@ class TestTraceClient: with pytest.raises(ValueError, match="AliyunTrace API check failed: Connection error"): client.api_check() - @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter") + @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter") def test_get_project_url(self, mock_exporter_class, trace_client_factory): client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint") assert client.get_project_url() == "https://arms.console.aliyun.com/#/llm" - @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter") + @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter") def test_add_span(self, mock_exporter_class, trace_client_factory): client = trace_client_factory( service_name="test-service", @@ -134,8 +133,8 @@ class TestTraceClient: assert len(client.queue) == 2 mock_notify.assert_called_once() - @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter") - @patch("core.ops.aliyun_trace.data_exporter.traceclient.logger") + @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter") + @patch("dify_trace_aliyun.data_exporter.traceclient.logger") def test_add_span_queue_full(self, mock_logger, mock_exporter_class, trace_client_factory): client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint", max_queue_size=1) @@ -159,7 +158,7 @@ class TestTraceClient: assert len(client.queue) == 1 mock_logger.warning.assert_called_with("Queue is full, likely spans will be dropped.") - @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter") + @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter") def test_export_batch_error(self, mock_exporter_class, trace_client_factory): mock_exporter = mock_exporter_class.return_value mock_exporter.export.side_effect = Exception("Export failed") @@ -168,11 +167,11 @@ class TestTraceClient: mock_span = MagicMock(spec=ReadableSpan) client.queue.append(mock_span) - with patch("core.ops.aliyun_trace.data_exporter.traceclient.logger") as mock_logger: + with patch("dify_trace_aliyun.data_exporter.traceclient.logger") as mock_logger: client._export_batch() mock_logger.warning.assert_called() - @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter") + @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter") def test_worker_loop(self, mock_exporter_class, trace_client_factory): # We need to test the wait timeout in _worker # But _worker runs in a thread. Let's mock condition.wait. @@ -189,7 +188,7 @@ class TestTraceClient: # mock_wait might have been called assert mock_wait.called or client.done - @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter") + @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter") def test_shutdown_flushes(self, mock_exporter_class, trace_client_factory): mock_exporter = mock_exporter_class.return_value client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint") @@ -268,7 +267,7 @@ def test_generate_span_id(): assert span_id != INVALID_SPAN_ID # Test retry loop - with patch("core.ops.aliyun_trace.data_exporter.traceclient.random.getrandbits") as mock_rand: + with patch("dify_trace_aliyun.data_exporter.traceclient.random.getrandbits") as mock_rand: mock_rand.side_effect = [INVALID_SPAN_ID, 999] span_id = generate_span_id() assert span_id == 999 @@ -290,7 +289,7 @@ def test_convert_to_trace_id(): def test_convert_string_to_id(): assert convert_string_to_id("test") > 0 # Test with None string - with patch("core.ops.aliyun_trace.data_exporter.traceclient.generate_span_id") as mock_gen: + with patch("dify_trace_aliyun.data_exporter.traceclient.generate_span_id") as mock_gen: mock_gen.return_value = 12345 assert convert_string_to_id(None) == 12345 diff --git a/api/tests/unit_tests/core/ops/aliyun_trace/entities/test_aliyun_trace_entity.py b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/entities/test_aliyun_trace_entity.py similarity index 97% rename from api/tests/unit_tests/core/ops/aliyun_trace/entities/test_aliyun_trace_entity.py rename to api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/entities/test_aliyun_trace_entity.py index 2fcb927e0c3..38d33dd21ba 100644 --- a/api/tests/unit_tests/core/ops/aliyun_trace/entities/test_aliyun_trace_entity.py +++ b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/entities/test_aliyun_trace_entity.py @@ -1,11 +1,10 @@ import pytest +from dify_trace_aliyun.entities.aliyun_trace_entity import SpanData, TraceMetadata from opentelemetry import trace as trace_api from opentelemetry.sdk.trace import Event from opentelemetry.trace import SpanKind, Status, StatusCode from pydantic import ValidationError -from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData, TraceMetadata - class TestTraceMetadata: def test_trace_metadata_init(self): diff --git a/api/tests/unit_tests/core/ops/aliyun_trace/entities/test_semconv.py b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/entities/test_semconv.py similarity index 97% rename from api/tests/unit_tests/core/ops/aliyun_trace/entities/test_semconv.py rename to api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/entities/test_semconv.py index 3961555b9a4..9cab40748f6 100644 --- a/api/tests/unit_tests/core/ops/aliyun_trace/entities/test_semconv.py +++ b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/entities/test_semconv.py @@ -1,4 +1,4 @@ -from core.ops.aliyun_trace.entities.semconv import ( +from dify_trace_aliyun.entities.semconv import ( ACS_ARMS_SERVICE_FEATURE, GEN_AI_COMPLETION, GEN_AI_FRAMEWORK, diff --git a/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace.py b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/test_aliyun_trace.py similarity index 99% rename from api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace.py rename to api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/test_aliyun_trace.py index 62d631a7541..c1b11c9186b 100644 --- a/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace.py +++ b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/test_aliyun_trace.py @@ -4,14 +4,11 @@ from datetime import UTC, datetime from types import SimpleNamespace from unittest.mock import MagicMock +import dify_trace_aliyun.aliyun_trace as aliyun_trace_module import pytest -from graphon.entities import WorkflowNodeExecution -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey -from opentelemetry.trace import Link, SpanContext, SpanKind, Status, StatusCode, TraceFlags - -import core.ops.aliyun_trace.aliyun_trace as aliyun_trace_module -from core.ops.aliyun_trace.aliyun_trace import AliyunDataTrace -from core.ops.aliyun_trace.entities.semconv import ( +from dify_trace_aliyun.aliyun_trace import AliyunDataTrace +from dify_trace_aliyun.config import AliyunConfig +from dify_trace_aliyun.entities.semconv import ( GEN_AI_COMPLETION, GEN_AI_INPUT_MESSAGE, GEN_AI_OUTPUT_MESSAGE, @@ -26,7 +23,8 @@ from core.ops.aliyun_trace.entities.semconv import ( TOOL_PARAMETERS, GenAISpanKind, ) -from core.ops.entities.config_entity import AliyunConfig +from opentelemetry.trace import Link, SpanContext, SpanKind, Status, StatusCode, TraceFlags + from core.ops.entities.trace_entity import ( DatasetRetrievalTraceInfo, GenerateNameTraceInfo, @@ -36,6 +34,8 @@ from core.ops.entities.trace_entity import ( ToolTraceInfo, WorkflowTraceInfo, ) +from graphon.entities import WorkflowNodeExecution +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey class RecordingTraceClient: diff --git a/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace_utils.py b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/test_aliyun_trace_utils.py similarity index 95% rename from api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace_utils.py rename to api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/test_aliyun_trace_utils.py index 2d2be12f051..a9e7b80c2a0 100644 --- a/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace_utils.py +++ b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/test_aliyun_trace_utils.py @@ -1,11 +1,7 @@ import json from unittest.mock import MagicMock -from graphon.entities import WorkflowNodeExecution -from graphon.enums import WorkflowNodeExecutionStatus -from opentelemetry.trace import Link, StatusCode - -from core.ops.aliyun_trace.entities.semconv import ( +from dify_trace_aliyun.entities.semconv import ( GEN_AI_FRAMEWORK, GEN_AI_SESSION_ID, GEN_AI_SPAN_KIND, @@ -13,7 +9,7 @@ from core.ops.aliyun_trace.entities.semconv import ( INPUT_VALUE, OUTPUT_VALUE, ) -from core.ops.aliyun_trace.utils import ( +from dify_trace_aliyun.utils import ( create_common_span_attributes, create_links_from_trace_id, create_status_from_error, @@ -25,7 +21,11 @@ from core.ops.aliyun_trace.utils import ( get_workflow_node_status, serialize_json_data, ) +from opentelemetry.trace import Link, StatusCode + from core.rag.models.document import Document +from graphon.entities import WorkflowNodeExecution +from graphon.enums import WorkflowNodeExecutionStatus from models import EndUser @@ -48,7 +48,7 @@ def test_get_user_id_from_message_data_with_end_user(monkeypatch): mock_session = MagicMock() mock_session.get.return_value = end_user_data - from core.ops.aliyun_trace.utils import db + from dify_trace_aliyun.utils import db monkeypatch.setattr(db, "session", mock_session) @@ -63,7 +63,7 @@ def test_get_user_id_from_message_data_end_user_not_found(monkeypatch): mock_session = MagicMock() mock_session.get.return_value = None - from core.ops.aliyun_trace.utils import db + from dify_trace_aliyun.utils import db monkeypatch.setattr(db, "session", mock_session) @@ -112,9 +112,9 @@ def test_get_workflow_node_status(): def test_create_links_from_trace_id(monkeypatch): # Mock create_link mock_link = MagicMock(spec=Link) - import core.ops.aliyun_trace.data_exporter.traceclient + import dify_trace_aliyun.data_exporter.traceclient - monkeypatch.setattr(core.ops.aliyun_trace.data_exporter.traceclient, "create_link", lambda trace_id_str: mock_link) + monkeypatch.setattr(dify_trace_aliyun.data_exporter.traceclient, "create_link", lambda trace_id_str: mock_link) # Trace ID None assert create_links_from_trace_id(None) == [] diff --git a/api/providers/trace/trace-aliyun/tests/unit_tests/test_config_entity.py b/api/providers/trace/trace-aliyun/tests/unit_tests/test_config_entity.py new file mode 100644 index 00000000000..1b24ee7421f --- /dev/null +++ b/api/providers/trace/trace-aliyun/tests/unit_tests/test_config_entity.py @@ -0,0 +1,85 @@ +import pytest +from dify_trace_aliyun.config import AliyunConfig +from pydantic import ValidationError + + +class TestAliyunConfig: + """Test cases for AliyunConfig""" + + def test_valid_config(self): + """Test valid Aliyun configuration""" + config = AliyunConfig( + app_name="test_app", + license_key="test_license_key", + endpoint="https://custom.tracing-analysis-dc-hz.aliyuncs.com", + ) + assert config.app_name == "test_app" + assert config.license_key == "test_license_key" + assert config.endpoint == "https://custom.tracing-analysis-dc-hz.aliyuncs.com" + + def test_default_values(self): + """Test default values are set correctly""" + config = AliyunConfig(license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com") + assert config.app_name == "dify_app" + + def test_missing_required_fields(self): + """Test that required fields are enforced""" + with pytest.raises(ValidationError): + AliyunConfig() + + with pytest.raises(ValidationError): + AliyunConfig(license_key="test_license") + + with pytest.raises(ValidationError): + AliyunConfig(endpoint="https://tracing-analysis-dc-hz.aliyuncs.com") + + def test_app_name_validation_empty(self): + """Test app_name validation with empty value""" + config = AliyunConfig( + license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com", app_name="" + ) + assert config.app_name == "dify_app" + + def test_endpoint_validation_empty(self): + """Test endpoint validation with empty value""" + config = AliyunConfig(license_key="test_license", endpoint="") + assert config.endpoint == "https://tracing-analysis-dc-hz.aliyuncs.com" + + def test_endpoint_validation_with_path(self): + """Test endpoint validation preserves path for Aliyun endpoints""" + config = AliyunConfig( + license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com/api/v1/traces" + ) + assert config.endpoint == "https://tracing-analysis-dc-hz.aliyuncs.com/api/v1/traces" + + def test_endpoint_validation_invalid_scheme(self): + """Test endpoint validation rejects invalid schemes""" + with pytest.raises(ValidationError, match="URL must start with https:// or http://"): + AliyunConfig(license_key="test_license", endpoint="ftp://invalid.tracing-analysis-dc-hz.aliyuncs.com") + + def test_endpoint_validation_no_scheme(self): + """Test endpoint validation rejects URLs without scheme""" + with pytest.raises(ValidationError, match="URL must start with https:// or http://"): + AliyunConfig(license_key="test_license", endpoint="invalid.tracing-analysis-dc-hz.aliyuncs.com") + + def test_license_key_required(self): + """Test that license_key is required and cannot be empty""" + with pytest.raises(ValidationError): + AliyunConfig(license_key="", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com") + + def test_valid_endpoint_format_examples(self): + """Test valid endpoint format examples from comments""" + valid_endpoints = [ + # cms2.0 public endpoint + "https://proj-xtrace-123456-cn-heyuan.cn-heyuan.log.aliyuncs.com/apm/trace/opentelemetry", + # cms2.0 intranet endpoint + "https://proj-xtrace-123456-cn-heyuan.cn-heyuan-intranet.log.aliyuncs.com/apm/trace/opentelemetry", + # xtrace public endpoint + "http://tracing-cn-heyuan.arms.aliyuncs.com", + # xtrace intranet endpoint + "http://tracing-cn-heyuan-internal.arms.aliyuncs.com", + ] + + for endpoint in valid_endpoints: + config = AliyunConfig(license_key="test_license", endpoint=endpoint) + assert config.endpoint == endpoint diff --git a/api/providers/trace/trace-arize-phoenix/pyproject.toml b/api/providers/trace/trace-arize-phoenix/pyproject.toml new file mode 100644 index 00000000000..9e756944c9a --- /dev/null +++ b/api/providers/trace/trace-arize-phoenix/pyproject.toml @@ -0,0 +1,10 @@ +[project] +name = "dify-trace-arize-phoenix" +version = "0.0.1" +dependencies = [ + "arize-phoenix-otel~=0.15.0", +] +description = "Dify ops tracing provider (Arize / Phoenix)." + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/core/ops/langfuse_trace/__init__.py b/api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/__init__.py similarity index 100% rename from api/core/ops/langfuse_trace/__init__.py rename to api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/__init__.py diff --git a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py b/api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/arize_phoenix_trace.py similarity index 99% rename from api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py rename to api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/arize_phoenix_trace.py index 66933cea287..96df49ed0e7 100644 --- a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py +++ b/api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/arize_phoenix_trace.py @@ -6,7 +6,6 @@ from datetime import datetime, timedelta from typing import Any, Union, cast from urllib.parse import urlparse -from graphon.enums import WorkflowNodeExecutionStatus from openinference.semconv.trace import ( MessageAttributes, OpenInferenceMimeTypeValues, @@ -26,7 +25,6 @@ from opentelemetry.util.types import AttributeValue from sqlalchemy.orm import sessionmaker from core.ops.base_trace_instance import BaseTraceInstance -from core.ops.entities.config_entity import ArizeConfig, PhoenixConfig from core.ops.entities.trace_entity import ( BaseTraceInfo, DatasetRetrievalTraceInfo, @@ -40,7 +38,9 @@ from core.ops.entities.trace_entity import ( ) from core.ops.utils import JSON_DICT_ADAPTER from core.repositories import DifyCoreRepositoryFactory +from dify_trace_arize_phoenix.config import ArizeConfig, PhoenixConfig from extensions.ext_database import db +from graphon.enums import WorkflowNodeExecutionStatus from models.model import EndUser, MessageFile from models.workflow import WorkflowNodeExecutionTriggeredFrom @@ -778,7 +778,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance): logger.info("[Arize/Phoenix] Failed to construct project URL: %s", str(e), exc_info=True) raise ValueError(f"[Arize/Phoenix] Failed to construct project URL: {str(e)}") - def _construct_llm_attributes(self, prompts: dict | list | str | None) -> dict[str, str]: + def _construct_llm_attributes(self, prompts: dict[str, Any] | list[Any] | str | None) -> dict[str, str]: """Construct LLM attributes with passed prompts for Arize/Phoenix.""" attributes: dict[str, str] = {} @@ -797,7 +797,9 @@ class ArizePhoenixDataTrace(BaseTraceInstance): path = f"{SpanAttributes.LLM_INPUT_MESSAGES}.{message_index}.{key}" set_attribute(path, value) - def set_tool_call_attributes(message_index: int, tool_index: int, tool_call: dict | object | None) -> None: + def set_tool_call_attributes( + message_index: int, tool_index: int, tool_call: dict[str, Any] | object | None + ) -> None: """Extract and assign tool call details safely.""" if not tool_call: return diff --git a/api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/config.py b/api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/config.py new file mode 100644 index 00000000000..6eac5b30d22 --- /dev/null +++ b/api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/config.py @@ -0,0 +1,45 @@ +from pydantic import ValidationInfo, field_validator + +from core.ops.entities.config_entity import BaseTracingConfig +from core.ops.utils import validate_url_with_path + + +class ArizeConfig(BaseTracingConfig): + """ + Model class for Arize tracing config. + """ + + api_key: str | None = None + space_id: str | None = None + project: str | None = None + endpoint: str = "https://otlp.arize.com" + + @field_validator("project") + @classmethod + def project_validator(cls, v, info: ValidationInfo): + return cls.validate_project_field(v, "default") + + @field_validator("endpoint") + @classmethod + def endpoint_validator(cls, v, info: ValidationInfo): + return cls.validate_endpoint_url(v, "https://otlp.arize.com") + + +class PhoenixConfig(BaseTracingConfig): + """ + Model class for Phoenix tracing config. + """ + + api_key: str | None = None + project: str | None = None + endpoint: str = "https://app.phoenix.arize.com" + + @field_validator("project") + @classmethod + def project_validator(cls, v, info: ValidationInfo): + return cls.validate_project_field(v, "default") + + @field_validator("endpoint") + @classmethod + def endpoint_validator(cls, v, info: ValidationInfo): + return validate_url_with_path(v, "https://app.phoenix.arize.com") diff --git a/api/core/ops/langfuse_trace/entities/__init__.py b/api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/py.typed similarity index 100% rename from api/core/ops/langfuse_trace/entities/__init__.py rename to api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/py.typed diff --git a/api/tests/unit_tests/core/ops/arize_phoenix_trace/test_arize_phoenix_trace.py b/api/providers/trace/trace-arize-phoenix/tests/unit_tests/arize_phoenix_trace/test_arize_phoenix_trace.py similarity index 91% rename from api/tests/unit_tests/core/ops/arize_phoenix_trace/test_arize_phoenix_trace.py rename to api/providers/trace/trace-arize-phoenix/tests/unit_tests/arize_phoenix_trace/test_arize_phoenix_trace.py index 4ce9e22fd77..b0691a87eae 100644 --- a/api/tests/unit_tests/core/ops/arize_phoenix_trace/test_arize_phoenix_trace.py +++ b/api/providers/trace/trace-arize-phoenix/tests/unit_tests/arize_phoenix_trace/test_arize_phoenix_trace.py @@ -2,11 +2,7 @@ from datetime import UTC, datetime, timedelta from unittest.mock import MagicMock, patch import pytest -from opentelemetry.sdk.trace import Tracer -from opentelemetry.semconv.trace import SpanAttributes as OTELSpanAttributes -from opentelemetry.trace import StatusCode - -from core.ops.arize_phoenix_trace.arize_phoenix_trace import ( +from dify_trace_arize_phoenix.arize_phoenix_trace import ( ArizePhoenixDataTrace, datetime_to_nanos, error_to_string, @@ -15,7 +11,11 @@ from core.ops.arize_phoenix_trace.arize_phoenix_trace import ( setup_tracer, wrap_span_metadata, ) -from core.ops.entities.config_entity import ArizeConfig, PhoenixConfig +from dify_trace_arize_phoenix.config import ArizeConfig, PhoenixConfig +from opentelemetry.sdk.trace import Tracer +from opentelemetry.semconv.trace import SpanAttributes as OTELSpanAttributes +from opentelemetry.trace import StatusCode + from core.ops.entities.trace_entity import ( DatasetRetrievalTraceInfo, GenerateNameTraceInfo, @@ -80,7 +80,7 @@ def test_datetime_to_nanos(): expected = int(dt.timestamp() * 1_000_000_000) assert datetime_to_nanos(dt) == expected - with patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.datetime") as mock_dt: + with patch("dify_trace_arize_phoenix.arize_phoenix_trace.datetime") as mock_dt: mock_now = MagicMock() mock_now.timestamp.return_value = 1704110400.0 mock_dt.now.return_value = mock_now @@ -142,8 +142,8 @@ def test_wrap_span_metadata(): assert res == {"a": 1, "b": 2, "created_from": "Dify"} -@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.GrpcOTLPSpanExporter") -@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.trace_sdk.TracerProvider") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.GrpcOTLPSpanExporter") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.trace_sdk.TracerProvider") def test_setup_tracer_arize(mock_provider, mock_exporter): config = ArizeConfig(endpoint="http://a.com", api_key="k", space_id="s", project="p") setup_tracer(config) @@ -151,8 +151,8 @@ def test_setup_tracer_arize(mock_provider, mock_exporter): assert mock_exporter.call_args[1]["endpoint"] == "http://a.com/v1" -@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.HttpOTLPSpanExporter") -@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.trace_sdk.TracerProvider") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.HttpOTLPSpanExporter") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.trace_sdk.TracerProvider") def test_setup_tracer_phoenix(mock_provider, mock_exporter): config = PhoenixConfig(endpoint="http://p.com", project="p") setup_tracer(config) @@ -162,7 +162,7 @@ def test_setup_tracer_phoenix(mock_provider, mock_exporter): def test_setup_tracer_exception(): config = ArizeConfig(endpoint="http://a.com", project="p") - with patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.urlparse", side_effect=Exception("boom")): + with patch("dify_trace_arize_phoenix.arize_phoenix_trace.urlparse", side_effect=Exception("boom")): with pytest.raises(Exception, match="boom"): setup_tracer(config) @@ -172,7 +172,7 @@ def test_setup_tracer_exception(): @pytest.fixture def trace_instance(): - with patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.setup_tracer") as mock_setup: + with patch("dify_trace_arize_phoenix.arize_phoenix_trace.setup_tracer") as mock_setup: mock_tracer = MagicMock(spec=Tracer) mock_processor = MagicMock() mock_setup.return_value = (mock_tracer, mock_processor) @@ -228,9 +228,9 @@ def test_trace_exception(trace_instance): trace_instance.trace(_make_workflow_info()) -@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.sessionmaker") -@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.DifyCoreRepositoryFactory") -@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.db") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.sessionmaker") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.DifyCoreRepositoryFactory") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.db") def test_workflow_trace_full(mock_db, mock_repo_factory, mock_sessionmaker, trace_instance): mock_db.engine = MagicMock() info = _make_workflow_info() @@ -262,7 +262,7 @@ def test_workflow_trace_full(mock_db, mock_repo_factory, mock_sessionmaker, trac assert trace_instance.tracer.start_span.call_count >= 2 -@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.db") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.db") def test_workflow_trace_no_app_id(mock_db, trace_instance): mock_db.engine = MagicMock() info = _make_workflow_info() @@ -271,7 +271,7 @@ def test_workflow_trace_no_app_id(mock_db, trace_instance): trace_instance.workflow_trace(info) -@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.db") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.db") def test_message_trace_success(mock_db, trace_instance): mock_db.engine = MagicMock() info = _make_message_info() @@ -291,7 +291,7 @@ def test_message_trace_success(mock_db, trace_instance): assert trace_instance.tracer.start_span.call_count >= 1 -@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.db") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.db") def test_message_trace_with_error(mock_db, trace_instance): mock_db.engine = MagicMock() info = _make_message_info() diff --git a/api/tests/unit_tests/core/ops/test_arize_phoenix_trace.py b/api/providers/trace/trace-arize-phoenix/tests/unit_tests/test_arize_phoenix_trace.py similarity index 94% rename from api/tests/unit_tests/core/ops/test_arize_phoenix_trace.py rename to api/providers/trace/trace-arize-phoenix/tests/unit_tests/test_arize_phoenix_trace.py index 6b5cb5b09a8..a01c63ae610 100644 --- a/api/tests/unit_tests/core/ops/test_arize_phoenix_trace.py +++ b/api/providers/trace/trace-arize-phoenix/tests/unit_tests/test_arize_phoenix_trace.py @@ -1,7 +1,7 @@ -from graphon.enums import BUILT_IN_NODE_TYPES, BuiltinNodeTypes +from dify_trace_arize_phoenix.arize_phoenix_trace import _NODE_TYPE_TO_SPAN_KIND, _get_node_span_kind from openinference.semconv.trace import OpenInferenceSpanKindValues -from core.ops.arize_phoenix_trace.arize_phoenix_trace import _NODE_TYPE_TO_SPAN_KIND, _get_node_span_kind +from graphon.enums import BUILT_IN_NODE_TYPES, BuiltinNodeTypes class TestGetNodeSpanKind: diff --git a/api/providers/trace/trace-arize-phoenix/tests/unit_tests/test_config_entity.py b/api/providers/trace/trace-arize-phoenix/tests/unit_tests/test_config_entity.py new file mode 100644 index 00000000000..11e951c3b1e --- /dev/null +++ b/api/providers/trace/trace-arize-phoenix/tests/unit_tests/test_config_entity.py @@ -0,0 +1,88 @@ +import pytest +from dify_trace_arize_phoenix.config import ArizeConfig, PhoenixConfig +from pydantic import ValidationError + + +class TestArizeConfig: + """Test cases for ArizeConfig""" + + def test_valid_config(self): + """Test valid Arize configuration""" + config = ArizeConfig( + api_key="test_key", space_id="test_space", project="test_project", endpoint="https://custom.arize.com" + ) + assert config.api_key == "test_key" + assert config.space_id == "test_space" + assert config.project == "test_project" + assert config.endpoint == "https://custom.arize.com" + + def test_default_values(self): + """Test default values are set correctly""" + config = ArizeConfig() + assert config.api_key is None + assert config.space_id is None + assert config.project is None + assert config.endpoint == "https://otlp.arize.com" + + def test_project_validation_empty(self): + """Test project validation with empty value""" + config = ArizeConfig(project="") + assert config.project == "default" + + def test_project_validation_none(self): + """Test project validation with None value""" + config = ArizeConfig(project=None) + assert config.project == "default" + + def test_endpoint_validation_empty(self): + """Test endpoint validation with empty value""" + config = ArizeConfig(endpoint="") + assert config.endpoint == "https://otlp.arize.com" + + def test_endpoint_validation_with_path(self): + """Test endpoint validation normalizes URL by removing path""" + config = ArizeConfig(endpoint="https://custom.arize.com/api/v1") + assert config.endpoint == "https://custom.arize.com" + + def test_endpoint_validation_invalid_scheme(self): + """Test endpoint validation rejects invalid schemes""" + with pytest.raises(ValidationError, match="URL scheme must be one of"): + ArizeConfig(endpoint="ftp://invalid.com") + + def test_endpoint_validation_no_scheme(self): + """Test endpoint validation rejects URLs without scheme""" + with pytest.raises(ValidationError, match="URL scheme must be one of"): + ArizeConfig(endpoint="invalid.com") + + +class TestPhoenixConfig: + """Test cases for PhoenixConfig""" + + def test_valid_config(self): + """Test valid Phoenix configuration""" + config = PhoenixConfig(api_key="test_key", project="test_project", endpoint="https://custom.phoenix.com") + assert config.api_key == "test_key" + assert config.project == "test_project" + assert config.endpoint == "https://custom.phoenix.com" + + def test_default_values(self): + """Test default values are set correctly""" + config = PhoenixConfig() + assert config.api_key is None + assert config.project is None + assert config.endpoint == "https://app.phoenix.arize.com" + + def test_project_validation_empty(self): + """Test project validation with empty value""" + config = PhoenixConfig(project="") + assert config.project == "default" + + def test_endpoint_validation_with_path(self): + """Test endpoint validation with path""" + config = PhoenixConfig(endpoint="https://app.phoenix.arize.com/s/dify-integration") + assert config.endpoint == "https://app.phoenix.arize.com/s/dify-integration" + + def test_endpoint_validation_without_path(self): + """Test endpoint validation without path""" + config = PhoenixConfig(endpoint="https://app.phoenix.arize.com") + assert config.endpoint == "https://app.phoenix.arize.com" diff --git a/api/providers/trace/trace-langfuse/pyproject.toml b/api/providers/trace/trace-langfuse/pyproject.toml new file mode 100644 index 00000000000..27d2273a694 --- /dev/null +++ b/api/providers/trace/trace-langfuse/pyproject.toml @@ -0,0 +1,10 @@ +[project] +name = "dify-trace-langfuse" +version = "0.0.1" +dependencies = [ + "langfuse>=4.2.0,<5.0.0", +] +description = "Dify ops tracing provider (Langfuse)." + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/core/ops/langsmith_trace/__init__.py b/api/providers/trace/trace-langfuse/src/dify_trace_langfuse/__init__.py similarity index 100% rename from api/core/ops/langsmith_trace/__init__.py rename to api/providers/trace/trace-langfuse/src/dify_trace_langfuse/__init__.py diff --git a/api/providers/trace/trace-langfuse/src/dify_trace_langfuse/config.py b/api/providers/trace/trace-langfuse/src/dify_trace_langfuse/config.py new file mode 100644 index 00000000000..90d1a2846bf --- /dev/null +++ b/api/providers/trace/trace-langfuse/src/dify_trace_langfuse/config.py @@ -0,0 +1,19 @@ +from pydantic import ValidationInfo, field_validator + +from core.ops.entities.config_entity import BaseTracingConfig +from core.ops.utils import validate_url_with_path + + +class LangfuseConfig(BaseTracingConfig): + """ + Model class for Langfuse tracing config. + """ + + public_key: str + secret_key: str + host: str = "https://api.langfuse.com" + + @field_validator("host") + @classmethod + def host_validator(cls, v, info: ValidationInfo): + return validate_url_with_path(v, "https://api.langfuse.com") diff --git a/api/core/ops/langsmith_trace/entities/__init__.py b/api/providers/trace/trace-langfuse/src/dify_trace_langfuse/entities/__init__.py similarity index 100% rename from api/core/ops/langsmith_trace/entities/__init__.py rename to api/providers/trace/trace-langfuse/src/dify_trace_langfuse/entities/__init__.py diff --git a/api/core/ops/langfuse_trace/entities/langfuse_trace_entity.py b/api/providers/trace/trace-langfuse/src/dify_trace_langfuse/entities/langfuse_trace_entity.py similarity index 100% rename from api/core/ops/langfuse_trace/entities/langfuse_trace_entity.py rename to api/providers/trace/trace-langfuse/src/dify_trace_langfuse/entities/langfuse_trace_entity.py diff --git a/api/core/ops/langfuse_trace/langfuse_trace.py b/api/providers/trace/trace-langfuse/src/dify_trace_langfuse/langfuse_trace.py similarity index 93% rename from api/core/ops/langfuse_trace/langfuse_trace.py rename to api/providers/trace/trace-langfuse/src/dify_trace_langfuse/langfuse_trace.py index 9be2ce1bdfc..68881378a70 100644 --- a/api/core/ops/langfuse_trace/langfuse_trace.py +++ b/api/providers/trace/trace-langfuse/src/dify_trace_langfuse/langfuse_trace.py @@ -3,7 +3,6 @@ import os import uuid from datetime import UTC, datetime, timedelta -from graphon.enums import BuiltinNodeTypes from langfuse import Langfuse from langfuse.api import ( CreateGenerationBody, @@ -17,7 +16,6 @@ from langfuse.api.commons.types.usage import Usage from sqlalchemy.orm import sessionmaker from core.ops.base_trace_instance import BaseTraceInstance -from core.ops.entities.config_entity import LangfuseConfig from core.ops.entities.trace_entity import ( BaseTraceInfo, DatasetRetrievalTraceInfo, @@ -29,7 +27,10 @@ from core.ops.entities.trace_entity import ( TraceTaskName, WorkflowTraceInfo, ) -from core.ops.langfuse_trace.entities.langfuse_trace_entity import ( +from core.ops.utils import filter_none_values +from core.repositories import DifyCoreRepositoryFactory +from dify_trace_langfuse.config import LangfuseConfig +from dify_trace_langfuse.entities.langfuse_trace_entity import ( GenerationUsage, LangfuseGeneration, LangfuseSpan, @@ -37,9 +38,8 @@ from core.ops.langfuse_trace.entities.langfuse_trace_entity import ( LevelEnum, UnitEnum, ) -from core.ops.utils import filter_none_values -from core.repositories import DifyCoreRepositoryFactory from extensions.ext_database import db +from graphon.enums import BuiltinNodeTypes from models import EndUser, WorkflowNodeExecutionTriggeredFrom from models.enums import MessageStatus @@ -59,6 +59,24 @@ class LangFuseDataTrace(BaseTraceInstance): ) self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001") + @staticmethod + def _get_completion_start_time( + start_time: datetime | None, time_to_first_token: float | int | None + ) -> datetime | None: + """Convert a relative TTFT value in seconds into Langfuse's absolute completion start time.""" + if start_time is None or time_to_first_token is None: + return None + + try: + ttft_seconds = float(time_to_first_token) + except (TypeError, ValueError): + return None + + if ttft_seconds < 0: + return None + + return start_time + timedelta(seconds=ttft_seconds) + def trace(self, trace_info: BaseTraceInfo): if isinstance(trace_info, WorkflowTraceInfo): self.workflow_trace(trace_info) @@ -189,10 +207,18 @@ class LangFuseDataTrace(BaseTraceInstance): total_token = metadata.get("total_tokens", 0) prompt_tokens = 0 completion_tokens = 0 + completion_start_time = None try: - usage_data = process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {}) + usage_data = process_data.get("usage") + if not isinstance(usage_data, dict): + usage_data = outputs.get("usage") + if not isinstance(usage_data, dict): + usage_data = {} prompt_tokens = usage_data.get("prompt_tokens", 0) completion_tokens = usage_data.get("completion_tokens", 0) + completion_start_time = self._get_completion_start_time( + created_at, usage_data.get("time_to_first_token") + ) except Exception: logger.error("Failed to extract usage", exc_info=True) @@ -210,6 +236,7 @@ class LangFuseDataTrace(BaseTraceInstance): trace_id=trace_id, model=process_data.get("model_name"), start_time=created_at, + completion_start_time=completion_start_time, end_time=finished_at, input=inputs, output=outputs, @@ -290,11 +317,16 @@ class LangFuseDataTrace(BaseTraceInstance): unit=UnitEnum.TOKENS, totalCost=message_data.total_price, ) + completion_start_time = self._get_completion_start_time( + trace_info.start_time, + trace_info.gen_ai_server_time_to_first_token, + ) langfuse_generation_data = LangfuseGeneration( name="llm", trace_id=trace_id, start_time=trace_info.start_time, + completion_start_time=completion_start_time, end_time=trace_info.end_time, model=message_data.model_id, input=trace_info.inputs, diff --git a/api/core/ops/mlflow_trace/__init__.py b/api/providers/trace/trace-langfuse/src/dify_trace_langfuse/py.typed similarity index 100% rename from api/core/ops/mlflow_trace/__init__.py rename to api/providers/trace/trace-langfuse/src/dify_trace_langfuse/py.typed diff --git a/api/tests/unit_tests/core/ops/langfuse_trace/test_langfuse_trace.py b/api/providers/trace/trace-langfuse/tests/unit_tests/langfuse_trace/test_langfuse_trace.py similarity index 93% rename from api/tests/unit_tests/core/ops/langfuse_trace/test_langfuse_trace.py rename to api/providers/trace/trace-langfuse/tests/unit_tests/langfuse_trace/test_langfuse_trace.py index 374371fb42d..952f10c34f9 100644 --- a/api/tests/unit_tests/core/ops/langfuse_trace/test_langfuse_trace.py +++ b/api/providers/trace/trace-langfuse/tests/unit_tests/langfuse_trace/test_langfuse_trace.py @@ -5,9 +5,16 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest -from graphon.enums import BuiltinNodeTypes +from dify_trace_langfuse.config import LangfuseConfig +from dify_trace_langfuse.entities.langfuse_trace_entity import ( + LangfuseGeneration, + LangfuseSpan, + LangfuseTrace, + LevelEnum, + UnitEnum, +) +from dify_trace_langfuse.langfuse_trace import LangFuseDataTrace -from core.ops.entities.config_entity import LangfuseConfig from core.ops.entities.trace_entity import ( DatasetRetrievalTraceInfo, GenerateNameTraceInfo, @@ -18,14 +25,7 @@ from core.ops.entities.trace_entity import ( TraceTaskName, WorkflowTraceInfo, ) -from core.ops.langfuse_trace.entities.langfuse_trace_entity import ( - LangfuseGeneration, - LangfuseSpan, - LangfuseTrace, - LevelEnum, - UnitEnum, -) -from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace +from graphon.enums import BuiltinNodeTypes from models import EndUser from models.enums import MessageStatus @@ -43,7 +43,7 @@ def langfuse_config(): def trace_instance(langfuse_config, monkeypatch): # Mock Langfuse client to avoid network calls mock_client = MagicMock() - monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.Langfuse", lambda **kwargs: mock_client) + monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.Langfuse", lambda **kwargs: mock_client) instance = LangFuseDataTrace(langfuse_config) return instance @@ -51,7 +51,7 @@ def trace_instance(langfuse_config, monkeypatch): def test_init(langfuse_config, monkeypatch): mock_langfuse = MagicMock() - monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.Langfuse", mock_langfuse) + monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.Langfuse", mock_langfuse) monkeypatch.setenv("FILES_URL", "http://test.url") instance = LangFuseDataTrace(langfuse_config) @@ -140,8 +140,8 @@ def test_workflow_trace_with_message_id(trace_instance, monkeypatch): # Mock DB and Repositories mock_session = MagicMock() - monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.sessionmaker", lambda bind: lambda: mock_session) - monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db", MagicMock(engine="engine")) + monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.sessionmaker", lambda bind: lambda: mock_session) + monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.db", MagicMock(engine="engine")) # Mock node executions node_llm = MagicMock() @@ -178,7 +178,7 @@ def test_workflow_trace_with_message_id(trace_instance, monkeypatch): mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo - monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.DifyCoreRepositoryFactory", mock_factory) + monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.DifyCoreRepositoryFactory", mock_factory) monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) @@ -241,13 +241,13 @@ def test_workflow_trace_no_message_id(trace_instance, monkeypatch): error="", ) - monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.sessionmaker", lambda bind: lambda: MagicMock()) - monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db", MagicMock(engine="engine")) + monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.sessionmaker", lambda bind: lambda: MagicMock()) + monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.db", MagicMock(engine="engine")) repo = MagicMock() repo.get_by_workflow_execution.return_value = [] mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo - monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.DifyCoreRepositoryFactory", mock_factory) + monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.DifyCoreRepositoryFactory", mock_factory) monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) trace_instance.add_trace = MagicMock() @@ -280,8 +280,8 @@ def test_workflow_trace_missing_app_id(trace_instance, monkeypatch): workflow_app_log_id="log-1", error="", ) - monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.sessionmaker", lambda bind: lambda: MagicMock()) - monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db", MagicMock(engine="engine")) + monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.sessionmaker", lambda bind: lambda: MagicMock()) + monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.db", MagicMock(engine="engine")) with pytest.raises(ValueError, match="No app_id found in trace_info metadata"): trace_instance.workflow_trace(trace_info) @@ -365,7 +365,7 @@ def test_message_trace_with_end_user(trace_instance, monkeypatch): mock_end_user = MagicMock(spec=EndUser) mock_end_user.session_id = "session-id-123" - monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db.session.get", lambda model, pk: mock_end_user) + monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.db.session.get", lambda model, pk: mock_end_user) trace_instance.add_trace = MagicMock() trace_instance.add_generation = MagicMock() @@ -681,9 +681,9 @@ def test_workflow_trace_handles_usage_extraction_error(trace_instance, monkeypat repo.get_by_workflow_execution.return_value = [node] mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo - monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.DifyCoreRepositoryFactory", mock_factory) - monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.sessionmaker", lambda bind: lambda: MagicMock()) - monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db", MagicMock(engine="engine")) + monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.DifyCoreRepositoryFactory", mock_factory) + monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.sessionmaker", lambda bind: lambda: MagicMock()) + monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.db", MagicMock(engine="engine")) monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) trace_instance.add_trace = MagicMock() diff --git a/api/providers/trace/trace-langfuse/tests/unit_tests/test_config_entity.py b/api/providers/trace/trace-langfuse/tests/unit_tests/test_config_entity.py new file mode 100644 index 00000000000..103d888eefe --- /dev/null +++ b/api/providers/trace/trace-langfuse/tests/unit_tests/test_config_entity.py @@ -0,0 +1,42 @@ +import pytest +from dify_trace_langfuse.config import LangfuseConfig +from pydantic import ValidationError + + +class TestLangfuseConfig: + """Test cases for LangfuseConfig""" + + def test_valid_config(self): + """Test valid Langfuse configuration""" + config = LangfuseConfig(public_key="public_key", secret_key="secret_key", host="https://custom.langfuse.com") + assert config.public_key == "public_key" + assert config.secret_key == "secret_key" + assert config.host == "https://custom.langfuse.com" + + def test_valid_config_with_path(self): + host = "https://custom.langfuse.com/api/v1" + config = LangfuseConfig(public_key="public_key", secret_key="secret_key", host=host) + assert config.public_key == "public_key" + assert config.secret_key == "secret_key" + assert config.host == host + + def test_default_values(self): + """Test default values are set correctly""" + config = LangfuseConfig(public_key="public", secret_key="secret") + assert config.host == "https://api.langfuse.com" + + def test_missing_required_fields(self): + """Test that required fields are enforced""" + with pytest.raises(ValidationError): + LangfuseConfig() + + with pytest.raises(ValidationError): + LangfuseConfig(public_key="public") + + with pytest.raises(ValidationError): + LangfuseConfig(secret_key="secret") + + def test_host_validation_empty(self): + """Test host validation with empty value""" + config = LangfuseConfig(public_key="public", secret_key="secret", host="") + assert config.host == "https://api.langfuse.com" diff --git a/api/providers/trace/trace-langfuse/tests/unit_tests/test_langfuse_trace.py b/api/providers/trace/trace-langfuse/tests/unit_tests/test_langfuse_trace.py new file mode 100644 index 00000000000..0340ffb6695 --- /dev/null +++ b/api/providers/trace/trace-langfuse/tests/unit_tests/test_langfuse_trace.py @@ -0,0 +1,137 @@ +"""Tests for Langfuse TTFT reporting support.""" + +from datetime import datetime, timedelta +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +from dify_trace_langfuse.config import LangfuseConfig +from dify_trace_langfuse.langfuse_trace import LangFuseDataTrace + +from core.ops.entities.trace_entity import MessageTraceInfo, WorkflowTraceInfo +from graphon.enums import BuiltinNodeTypes + + +def _create_trace_instance() -> LangFuseDataTrace: + with patch("dify_trace_langfuse.langfuse_trace.Langfuse", autospec=True): + return LangFuseDataTrace( + LangfuseConfig( + public_key="public-key", + secret_key="secret-key", + host="https://cloud.langfuse.com", + ) + ) + + +class TestLangFuseDataTraceCompletionStartTime: + def test_message_trace_reports_completion_start_time(self): + trace = _create_trace_instance() + start_time = datetime(2026, 3, 11, 13, 0, 0) + trace_info = MessageTraceInfo( + trace_id="trace-123", + message_id="message-123", + message_data=SimpleNamespace( + id="message-123", + from_account_id="account-1", + from_end_user_id=None, + conversation_id="conversation-1", + model_id="gpt-4o-mini", + answer="hi there", + status="normal", + error="", + total_price=0.12, + provider_response_latency=3.5, + ), + conversation_model="chat", + message_tokens=10, + answer_tokens=20, + total_tokens=30, + error="", + inputs="hello", + outputs="hi there", + file_list=[], + start_time=start_time, + end_time=start_time + timedelta(seconds=3.5), + metadata={}, + message_file_data=None, + conversation_mode="chat", + gen_ai_server_time_to_first_token=1.2, + llm_streaming_time_to_generate=2.3, + is_streaming_request=True, + ) + + with patch.object(trace, "add_trace"), patch.object(trace, "add_generation") as add_generation: + trace.message_trace(trace_info) + + generation = add_generation.call_args.args[0] + assert generation.completion_start_time == start_time + timedelta(seconds=1.2) + + def test_workflow_trace_reports_completion_start_time_from_llm_usage(self): + trace = _create_trace_instance() + start_time = datetime(2026, 3, 11, 13, 0, 0) + node_execution = SimpleNamespace( + id="node-exec-1", + title="Chat LLM", + node_type=BuiltinNodeTypes.LLM, + status="succeeded", + process_data={ + "model_mode": "chat", + "model_name": "gpt-4o-mini", + "usage": { + "prompt_tokens": 10, + "completion_tokens": 20, + "time_to_first_token": 1.2, + }, + }, + inputs={"question": "hello"}, + outputs={"text": "hi there"}, + created_at=start_time, + elapsed_time=3.5, + metadata={}, + ) + trace_info = WorkflowTraceInfo( + trace_id="trace-123", + workflow_data={}, + conversation_id=None, + workflow_app_log_id=None, + workflow_id="workflow-1", + tenant_id="tenant-1", + workflow_run_id="workflow-run-1", + workflow_run_elapsed_time=3.5, + workflow_run_status="succeeded", + workflow_run_inputs={"question": "hello"}, + workflow_run_outputs={"answer": "hi there"}, + workflow_run_version="1", + error="", + total_tokens=30, + file_list=[], + query="hello", + metadata={"app_id": "app-1", "user_id": "user-1"}, + start_time=start_time, + end_time=start_time + timedelta(seconds=3.5), + ) + repository = MagicMock() + repository.get_by_workflow_execution.return_value = [node_execution] + + with ( + patch.object(trace, "add_trace"), + patch.object(trace, "add_span"), + patch.object(trace, "add_generation") as add_generation, + patch.object(trace, "get_service_account_with_tenant", return_value=MagicMock()), + patch("dify_trace_langfuse.langfuse_trace.db", MagicMock()), + patch( + "dify_trace_langfuse.langfuse_trace.DifyCoreRepositoryFactory.create_workflow_node_execution_repository", + return_value=repository, + ), + ): + trace.workflow_trace(trace_info) + + generation = add_generation.call_args.kwargs["langfuse_generation_data"] + assert generation.completion_start_time == start_time + timedelta(seconds=1.2) + + def test_ignores_invalid_ttft_values(self): + trace = _create_trace_instance() + start_time = datetime(2026, 3, 11, 13, 0, 0) + + assert trace._get_completion_start_time(start_time, None) is None + assert trace._get_completion_start_time(start_time, -1) is None + assert trace._get_completion_start_time(start_time, "invalid") is None diff --git a/api/providers/trace/trace-langsmith/pyproject.toml b/api/providers/trace/trace-langsmith/pyproject.toml new file mode 100644 index 00000000000..8131952b280 --- /dev/null +++ b/api/providers/trace/trace-langsmith/pyproject.toml @@ -0,0 +1,10 @@ +[project] +name = "dify-trace-langsmith" +version = "0.0.1" +dependencies = [ + "langsmith~=0.7.30", +] +description = "Dify ops tracing provider (LangSmith)." + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/core/ops/opik_trace/__init__.py b/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/__init__.py similarity index 100% rename from api/core/ops/opik_trace/__init__.py rename to api/providers/trace/trace-langsmith/src/dify_trace_langsmith/__init__.py diff --git a/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/config.py b/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/config.py new file mode 100644 index 00000000000..498b8c5e7e3 --- /dev/null +++ b/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/config.py @@ -0,0 +1,20 @@ +from pydantic import ValidationInfo, field_validator + +from core.ops.entities.config_entity import BaseTracingConfig +from core.ops.utils import validate_url + + +class LangSmithConfig(BaseTracingConfig): + """ + Model class for Langsmith tracing config. + """ + + api_key: str + project: str + endpoint: str = "https://api.smith.langchain.com" + + @field_validator("endpoint") + @classmethod + def endpoint_validator(cls, v, info: ValidationInfo): + # LangSmith only allows HTTPS + return validate_url(v, "https://api.smith.langchain.com", allowed_schemes=("https",)) diff --git a/api/core/ops/tencent_trace/__init__.py b/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/entities/__init__.py similarity index 100% rename from api/core/ops/tencent_trace/__init__.py rename to api/providers/trace/trace-langsmith/src/dify_trace_langsmith/entities/__init__.py diff --git a/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py b/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/entities/langsmith_trace_entity.py similarity index 100% rename from api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py rename to api/providers/trace/trace-langsmith/src/dify_trace_langsmith/entities/langsmith_trace_entity.py diff --git a/api/core/ops/langsmith_trace/langsmith_trace.py b/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/langsmith_trace.py similarity index 99% rename from api/core/ops/langsmith_trace/langsmith_trace.py rename to api/providers/trace/trace-langsmith/src/dify_trace_langsmith/langsmith_trace.py index 490c64af84d..145bd70dbc3 100644 --- a/api/core/ops/langsmith_trace/langsmith_trace.py +++ b/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/langsmith_trace.py @@ -4,13 +4,11 @@ import uuid from datetime import datetime, timedelta from typing import cast -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from langsmith import Client from langsmith.schemas import RunBase from sqlalchemy.orm import sessionmaker from core.ops.base_trace_instance import BaseTraceInstance -from core.ops.entities.config_entity import LangSmithConfig from core.ops.entities.trace_entity import ( BaseTraceInfo, DatasetRetrievalTraceInfo, @@ -22,14 +20,16 @@ from core.ops.entities.trace_entity import ( TraceTaskName, WorkflowTraceInfo, ) -from core.ops.langsmith_trace.entities.langsmith_trace_entity import ( +from core.ops.utils import filter_none_values, generate_dotted_order +from core.repositories import DifyCoreRepositoryFactory +from dify_trace_langsmith.config import LangSmithConfig +from dify_trace_langsmith.entities.langsmith_trace_entity import ( LangSmithRunModel, LangSmithRunType, LangSmithRunUpdateModel, ) -from core.ops.utils import filter_none_values, generate_dotted_order -from core.repositories import DifyCoreRepositoryFactory from extensions.ext_database import db +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom logger = logging.getLogger(__name__) diff --git a/api/core/ops/weave_trace/__init__.py b/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/py.typed similarity index 100% rename from api/core/ops/weave_trace/__init__.py rename to api/providers/trace/trace-langsmith/src/dify_trace_langsmith/py.typed diff --git a/api/tests/unit_tests/core/ops/langsmith_trace/test_langsmith_trace.py b/api/providers/trace/trace-langsmith/tests/unit_tests/langsmith_trace/test_langsmith_trace.py similarity index 91% rename from api/tests/unit_tests/core/ops/langsmith_trace/test_langsmith_trace.py rename to api/providers/trace/trace-langsmith/tests/unit_tests/langsmith_trace/test_langsmith_trace.py index bfe916f0182..45e5894e4a0 100644 --- a/api/tests/unit_tests/core/ops/langsmith_trace/test_langsmith_trace.py +++ b/api/providers/trace/trace-langsmith/tests/unit_tests/langsmith_trace/test_langsmith_trace.py @@ -3,9 +3,14 @@ from datetime import datetime, timedelta from unittest.mock import MagicMock import pytest -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey +from dify_trace_langsmith.config import LangSmithConfig +from dify_trace_langsmith.entities.langsmith_trace_entity import ( + LangSmithRunModel, + LangSmithRunType, + LangSmithRunUpdateModel, +) +from dify_trace_langsmith.langsmith_trace import LangSmithDataTrace -from core.ops.entities.config_entity import LangSmithConfig from core.ops.entities.trace_entity import ( DatasetRetrievalTraceInfo, GenerateNameTraceInfo, @@ -16,12 +21,7 @@ from core.ops.entities.trace_entity import ( TraceTaskName, WorkflowTraceInfo, ) -from core.ops.langsmith_trace.entities.langsmith_trace_entity import ( - LangSmithRunModel, - LangSmithRunType, - LangSmithRunUpdateModel, -) -from core.ops.langsmith_trace.langsmith_trace import LangSmithDataTrace +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from models import EndUser @@ -38,7 +38,7 @@ def langsmith_config(): def trace_instance(langsmith_config, monkeypatch): # Mock LangSmith client mock_client = MagicMock() - monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.Client", lambda **kwargs: mock_client) + monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.Client", lambda **kwargs: mock_client) instance = LangSmithDataTrace(langsmith_config) return instance @@ -46,7 +46,7 @@ def trace_instance(langsmith_config, monkeypatch): def test_init(langsmith_config, monkeypatch): mock_client_class = MagicMock() - monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.Client", mock_client_class) + monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.Client", mock_client_class) monkeypatch.setenv("FILES_URL", "http://test.url") instance = LangSmithDataTrace(langsmith_config) @@ -138,8 +138,8 @@ def test_workflow_trace(trace_instance, monkeypatch): # Mock dependencies mock_session = MagicMock() - monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.sessionmaker", lambda bind: lambda: mock_session) - monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db", MagicMock(engine="engine")) + monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.sessionmaker", lambda bind: lambda: mock_session) + monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.db", MagicMock(engine="engine")) # Mock node executions node_llm = MagicMock() @@ -188,7 +188,7 @@ def test_workflow_trace(trace_instance, monkeypatch): mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo - monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.DifyCoreRepositoryFactory", mock_factory) + monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.DifyCoreRepositoryFactory", mock_factory) monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) @@ -252,13 +252,13 @@ def test_workflow_trace_no_start_time(trace_instance, monkeypatch): ) mock_session = MagicMock() - monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.sessionmaker", lambda bind: lambda: mock_session) - monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db", MagicMock(engine="engine")) + monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.sessionmaker", lambda bind: lambda: mock_session) + monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.db", MagicMock(engine="engine")) repo = MagicMock() repo.get_by_workflow_execution.return_value = [] mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo - monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.DifyCoreRepositoryFactory", mock_factory) + monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.DifyCoreRepositoryFactory", mock_factory) monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) trace_instance.add_run = MagicMock() @@ -283,8 +283,8 @@ def test_workflow_trace_missing_app_id(trace_instance, monkeypatch): trace_info.error = "" mock_session = MagicMock() - monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.sessionmaker", lambda bind: lambda: mock_session) - monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db", MagicMock(engine="engine")) + monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.sessionmaker", lambda bind: lambda: mock_session) + monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.db", MagicMock(engine="engine")) with pytest.raises(ValueError, match="No app_id found in trace_info metadata"): trace_instance.workflow_trace(trace_info) @@ -319,7 +319,7 @@ def test_message_trace(trace_instance, monkeypatch): # Mock EndUser lookup mock_end_user = MagicMock(spec=EndUser) mock_end_user.session_id = "session-id-123" - monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db.session.get", lambda model, pk: mock_end_user) + monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.db.session.get", lambda model, pk: mock_end_user) trace_instance.add_run = MagicMock() @@ -567,9 +567,9 @@ def test_workflow_trace_usage_extraction_error(trace_instance, monkeypatch, capl mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo - monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.DifyCoreRepositoryFactory", mock_factory) - monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.sessionmaker", lambda bind: lambda: MagicMock()) - monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db", MagicMock(engine="engine")) + monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.DifyCoreRepositoryFactory", mock_factory) + monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.sessionmaker", lambda bind: lambda: MagicMock()) + monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.db", MagicMock(engine="engine")) monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) trace_instance.add_run = MagicMock() diff --git a/api/providers/trace/trace-langsmith/tests/unit_tests/test_config_entity.py b/api/providers/trace/trace-langsmith/tests/unit_tests/test_config_entity.py new file mode 100644 index 00000000000..37efaf69cfd --- /dev/null +++ b/api/providers/trace/trace-langsmith/tests/unit_tests/test_config_entity.py @@ -0,0 +1,35 @@ +import pytest +from dify_trace_langsmith.config import LangSmithConfig +from pydantic import ValidationError + + +class TestLangSmithConfig: + """Test cases for LangSmithConfig""" + + def test_valid_config(self): + """Test valid LangSmith configuration""" + config = LangSmithConfig(api_key="test_key", project="test_project", endpoint="https://custom.smith.com") + assert config.api_key == "test_key" + assert config.project == "test_project" + assert config.endpoint == "https://custom.smith.com" + + def test_default_values(self): + """Test default values are set correctly""" + config = LangSmithConfig(api_key="key", project="project") + assert config.endpoint == "https://api.smith.langchain.com" + + def test_missing_required_fields(self): + """Test that required fields are enforced""" + with pytest.raises(ValidationError): + LangSmithConfig() + + with pytest.raises(ValidationError): + LangSmithConfig(api_key="key") + + with pytest.raises(ValidationError): + LangSmithConfig(project="project") + + def test_endpoint_validation_https_only(self): + """Test endpoint validation only allows HTTPS""" + with pytest.raises(ValidationError, match="URL scheme must be one of"): + LangSmithConfig(api_key="key", project="project", endpoint="http://insecure.com") diff --git a/api/providers/trace/trace-mlflow/pyproject.toml b/api/providers/trace/trace-mlflow/pyproject.toml new file mode 100644 index 00000000000..fad60029449 --- /dev/null +++ b/api/providers/trace/trace-mlflow/pyproject.toml @@ -0,0 +1,10 @@ +[project] +name = "dify-trace-mlflow" +version = "0.0.1" +dependencies = [ + "mlflow-skinny>=3.11.1", +] +description = "Dify ops tracing provider (MLflow / Databricks)." + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/core/ops/weave_trace/entities/__init__.py b/api/providers/trace/trace-mlflow/src/dify_trace_mlflow/__init__.py similarity index 100% rename from api/core/ops/weave_trace/entities/__init__.py rename to api/providers/trace/trace-mlflow/src/dify_trace_mlflow/__init__.py diff --git a/api/providers/trace/trace-mlflow/src/dify_trace_mlflow/config.py b/api/providers/trace/trace-mlflow/src/dify_trace_mlflow/config.py new file mode 100644 index 00000000000..84914165e3b --- /dev/null +++ b/api/providers/trace/trace-mlflow/src/dify_trace_mlflow/config.py @@ -0,0 +1,46 @@ +from pydantic import ValidationInfo, field_validator + +from core.ops.entities.config_entity import BaseTracingConfig +from core.ops.utils import validate_integer_id, validate_url_with_path + + +class MLflowConfig(BaseTracingConfig): + """ + Model class for MLflow tracing config. + """ + + tracking_uri: str = "http://localhost:5000" + experiment_id: str = "0" # Default experiment id in MLflow is 0 + username: str | None = None + password: str | None = None + + @field_validator("tracking_uri") + @classmethod + def tracking_uri_validator(cls, v, info: ValidationInfo): + if isinstance(v, str) and v.startswith("databricks"): + raise ValueError( + "Please use Databricks tracing config below to record traces to Databricks-managed MLflow instances." + ) + return validate_url_with_path(v, "http://localhost:5000") + + @field_validator("experiment_id") + @classmethod + def experiment_id_validator(cls, v, info: ValidationInfo): + return validate_integer_id(v) + + +class DatabricksConfig(BaseTracingConfig): + """ + Model class for Databricks (Databricks-managed MLflow) tracing config. + """ + + experiment_id: str + host: str + client_id: str | None = None + client_secret: str | None = None + personal_access_token: str | None = None + + @field_validator("experiment_id") + @classmethod + def experiment_id_validator(cls, v, info: ValidationInfo): + return validate_integer_id(v) diff --git a/api/core/ops/mlflow_trace/mlflow_trace.py b/api/providers/trace/trace-mlflow/src/dify_trace_mlflow/mlflow_trace.py similarity index 98% rename from api/core/ops/mlflow_trace/mlflow_trace.py rename to api/providers/trace/trace-mlflow/src/dify_trace_mlflow/mlflow_trace.py index 3d8c1dd038a..4e4c45a5324 100644 --- a/api/core/ops/mlflow_trace/mlflow_trace.py +++ b/api/providers/trace/trace-mlflow/src/dify_trace_mlflow/mlflow_trace.py @@ -4,7 +4,6 @@ from datetime import datetime, timedelta from typing import Any, cast import mlflow -from graphon.enums import BuiltinNodeTypes from mlflow.entities import Document, Span, SpanEvent, SpanStatusCode, SpanType from mlflow.tracing.constant import SpanAttributeKey, TokenUsageKey, TraceMetadataKey from mlflow.tracing.fluent import start_span_no_context, update_current_trace @@ -12,7 +11,6 @@ from mlflow.tracing.provider import detach_span_from_context, set_span_in_contex from sqlalchemy import select from core.ops.base_trace_instance import BaseTraceInstance -from core.ops.entities.config_entity import DatabricksConfig, MLflowConfig from core.ops.entities.trace_entity import ( BaseTraceInfo, DatasetRetrievalTraceInfo, @@ -25,7 +23,9 @@ from core.ops.entities.trace_entity import ( WorkflowTraceInfo, ) from core.ops.utils import JSON_DICT_ADAPTER +from dify_trace_mlflow.config import DatabricksConfig, MLflowConfig from extensions.ext_database import db +from graphon.enums import BuiltinNodeTypes from models import EndUser from models.workflow import WorkflowNodeExecutionModel @@ -242,7 +242,7 @@ class MLflowDataTrace(BaseTraceInstance): return inputs, attributes - def _parse_knowledge_retrieval_outputs(self, outputs: dict): + def _parse_knowledge_retrieval_outputs(self, outputs: dict[str, Any]): """Parse KR outputs and attributes from KR workflow node""" retrieved = outputs.get("result", []) @@ -319,7 +319,7 @@ class MLflowDataTrace(BaseTraceInstance): end_time_ns=datetime_to_nanoseconds(trace_info.end_time), ) - def _get_message_user_id(self, metadata: dict) -> str | None: + def _get_message_user_id(self, metadata: dict[str, Any]) -> str | None: if (end_user_id := metadata.get("from_end_user_id")) and ( end_user_data := db.session.get(EndUser, end_user_id) ): @@ -468,7 +468,7 @@ class MLflowDataTrace(BaseTraceInstance): } return node_type_mapping.get(node_type, "CHAIN") # type: ignore[arg-type,call-overload] - def _set_trace_metadata(self, span: Span, metadata: dict): + def _set_trace_metadata(self, span: Span, metadata: dict[str, Any]): token = None try: # NB: Set span in context such that we can use update_current_trace() API @@ -490,7 +490,7 @@ class MLflowDataTrace(BaseTraceInstance): return messages return prompts # Fallback to original format - def _parse_single_message(self, item: dict): + def _parse_single_message(self, item: dict[str, Any]): """Postprocess single message format to be standard chat message""" role = item.get("role", "user") msg = {"role": role, "content": item.get("text", "")} diff --git a/api/core/rag/datasource/vdb/alibabacloud_mysql/__init__.py b/api/providers/trace/trace-mlflow/src/dify_trace_mlflow/py.typed similarity index 100% rename from api/core/rag/datasource/vdb/alibabacloud_mysql/__init__.py rename to api/providers/trace/trace-mlflow/src/dify_trace_mlflow/py.typed diff --git a/api/tests/unit_tests/core/ops/mlflow_trace/test_mlflow_trace.py b/api/providers/trace/trace-mlflow/tests/unit_tests/mlflow_trace/test_mlflow_trace.py similarity index 98% rename from api/tests/unit_tests/core/ops/mlflow_trace/test_mlflow_trace.py rename to api/providers/trace/trace-mlflow/tests/unit_tests/mlflow_trace/test_mlflow_trace.py index f4c485a9fc5..20211456e36 100644 --- a/api/tests/unit_tests/core/ops/mlflow_trace/test_mlflow_trace.py +++ b/api/providers/trace/trace-mlflow/tests/unit_tests/mlflow_trace/test_mlflow_trace.py @@ -1,4 +1,4 @@ -"""Comprehensive tests for core.ops.mlflow_trace.mlflow_trace module.""" +"""Comprehensive tests for dify_trace_mlflow.mlflow_trace module.""" from __future__ import annotations @@ -9,9 +9,9 @@ from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest -from graphon.enums import BuiltinNodeTypes +from dify_trace_mlflow.config import DatabricksConfig, MLflowConfig +from dify_trace_mlflow.mlflow_trace import MLflowDataTrace, datetime_to_nanoseconds -from core.ops.entities.config_entity import DatabricksConfig, MLflowConfig from core.ops.entities.trace_entity import ( DatasetRetrievalTraceInfo, GenerateNameTraceInfo, @@ -21,7 +21,7 @@ from core.ops.entities.trace_entity import ( ToolTraceInfo, WorkflowTraceInfo, ) -from core.ops.mlflow_trace.mlflow_trace import MLflowDataTrace, datetime_to_nanoseconds +from graphon.enums import BuiltinNodeTypes # ── Helpers ────────────────────────────────────────────────────────────────── @@ -179,7 +179,7 @@ def _make_node(**overrides): @pytest.fixture def mock_mlflow(): - with patch("core.ops.mlflow_trace.mlflow_trace.mlflow") as mock: + with patch("dify_trace_mlflow.mlflow_trace.mlflow") as mock: yield mock @@ -187,10 +187,10 @@ def mock_mlflow(): def mock_tracing(): """Patch all MLflow tracing functions used by the module.""" with ( - patch("core.ops.mlflow_trace.mlflow_trace.start_span_no_context") as mock_start, - patch("core.ops.mlflow_trace.mlflow_trace.update_current_trace") as mock_update, - patch("core.ops.mlflow_trace.mlflow_trace.set_span_in_context") as mock_set, - patch("core.ops.mlflow_trace.mlflow_trace.detach_span_from_context") as mock_detach, + patch("dify_trace_mlflow.mlflow_trace.start_span_no_context") as mock_start, + patch("dify_trace_mlflow.mlflow_trace.update_current_trace") as mock_update, + patch("dify_trace_mlflow.mlflow_trace.set_span_in_context") as mock_set, + patch("dify_trace_mlflow.mlflow_trace.detach_span_from_context") as mock_detach, ): yield { "start": mock_start, @@ -202,7 +202,7 @@ def mock_tracing(): @pytest.fixture def mock_db(): - with patch("core.ops.mlflow_trace.mlflow_trace.db") as mock: + with patch("dify_trace_mlflow.mlflow_trace.db") as mock: yield mock diff --git a/api/providers/trace/trace-opik/pyproject.toml b/api/providers/trace/trace-opik/pyproject.toml new file mode 100644 index 00000000000..874997168ec --- /dev/null +++ b/api/providers/trace/trace-opik/pyproject.toml @@ -0,0 +1,10 @@ +[project] +name = "dify-trace-opik" +version = "0.0.1" +dependencies = [ + "opik~=1.11.2", +] +description = "Dify ops tracing provider (Opik)." + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/core/rag/datasource/vdb/analyticdb/__init__.py b/api/providers/trace/trace-opik/src/dify_trace_opik/__init__.py similarity index 100% rename from api/core/rag/datasource/vdb/analyticdb/__init__.py rename to api/providers/trace/trace-opik/src/dify_trace_opik/__init__.py diff --git a/api/providers/trace/trace-opik/src/dify_trace_opik/config.py b/api/providers/trace/trace-opik/src/dify_trace_opik/config.py new file mode 100644 index 00000000000..c16ff1d9035 --- /dev/null +++ b/api/providers/trace/trace-opik/src/dify_trace_opik/config.py @@ -0,0 +1,25 @@ +from pydantic import ValidationInfo, field_validator + +from core.ops.entities.config_entity import BaseTracingConfig +from core.ops.utils import validate_url_with_path + + +class OpikConfig(BaseTracingConfig): + """ + Model class for Opik tracing config. + """ + + api_key: str | None = None + project: str | None = None + workspace: str | None = None + url: str = "https://www.comet.com/opik/api/" + + @field_validator("project") + @classmethod + def project_validator(cls, v, info: ValidationInfo): + return cls.validate_project_field(v, "Default Project") + + @field_validator("url") + @classmethod + def url_validator(cls, v, info: ValidationInfo): + return validate_url_with_path(v, "https://www.comet.com/opik/api/", required_suffix="/api/") diff --git a/api/core/ops/opik_trace/opik_trace.py b/api/providers/trace/trace-opik/src/dify_trace_opik/opik_trace.py similarity index 98% rename from api/core/ops/opik_trace/opik_trace.py rename to api/providers/trace/trace-opik/src/dify_trace_opik/opik_trace.py index 2215bdeb33b..2d124ac989c 100644 --- a/api/core/ops/opik_trace/opik_trace.py +++ b/api/providers/trace/trace-opik/src/dify_trace_opik/opik_trace.py @@ -3,15 +3,13 @@ import logging import os import uuid from datetime import datetime, timedelta -from typing import cast +from typing import Any, cast -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from opik import Opik, Trace from opik.id_helpers import uuid4_to_uuid7 from sqlalchemy.orm import sessionmaker from core.ops.base_trace_instance import BaseTraceInstance -from core.ops.entities.config_entity import OpikConfig from core.ops.entities.trace_entity import ( BaseTraceInfo, DatasetRetrievalTraceInfo, @@ -24,7 +22,9 @@ from core.ops.entities.trace_entity import ( WorkflowTraceInfo, ) from core.repositories import DifyCoreRepositoryFactory +from dify_trace_opik.config import OpikConfig from extensions.ext_database import db +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom logger = logging.getLogger(__name__) @@ -436,7 +436,7 @@ class OpikDataTrace(BaseTraceInstance): self.add_span(span_data) - def add_trace(self, opik_trace_data: dict) -> Trace: + def add_trace(self, opik_trace_data: dict[str, Any]) -> Trace: try: trace = self.opik_client.trace(**opik_trace_data) logger.debug("Opik Trace created successfully") @@ -444,7 +444,7 @@ class OpikDataTrace(BaseTraceInstance): except Exception as e: raise ValueError(f"Opik Failed to create trace: {str(e)}") - def add_span(self, opik_span_data: dict): + def add_span(self, opik_span_data: dict[str, Any]): try: self.opik_client.span(**opik_span_data) logger.debug("Opik Span created successfully") diff --git a/api/core/rag/datasource/vdb/baidu/__init__.py b/api/providers/trace/trace-opik/src/dify_trace_opik/py.typed similarity index 100% rename from api/core/rag/datasource/vdb/baidu/__init__.py rename to api/providers/trace/trace-opik/src/dify_trace_opik/py.typed diff --git a/api/tests/unit_tests/core/ops/opik_trace/test_opik_trace.py b/api/providers/trace/trace-opik/tests/unit_tests/opik_trace/test_opik_trace.py similarity index 93% rename from api/tests/unit_tests/core/ops/opik_trace/test_opik_trace.py rename to api/providers/trace/trace-opik/tests/unit_tests/opik_trace/test_opik_trace.py index 1cb32f2ee02..eefed3c78c7 100644 --- a/api/tests/unit_tests/core/ops/opik_trace/test_opik_trace.py +++ b/api/providers/trace/trace-opik/tests/unit_tests/opik_trace/test_opik_trace.py @@ -5,9 +5,9 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey +from dify_trace_opik.config import OpikConfig +from dify_trace_opik.opik_trace import OpikDataTrace, prepare_opik_uuid, wrap_dict, wrap_metadata -from core.ops.entities.config_entity import OpikConfig from core.ops.entities.trace_entity import ( DatasetRetrievalTraceInfo, GenerateNameTraceInfo, @@ -18,7 +18,7 @@ from core.ops.entities.trace_entity import ( TraceTaskName, WorkflowTraceInfo, ) -from core.ops.opik_trace.opik_trace import OpikDataTrace, prepare_opik_uuid, wrap_dict, wrap_metadata +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from models import EndUser from models.enums import MessageStatus @@ -37,7 +37,7 @@ def opik_config(): @pytest.fixture def trace_instance(opik_config, monkeypatch): mock_client = MagicMock() - monkeypatch.setattr("core.ops.opik_trace.opik_trace.Opik", lambda **kwargs: mock_client) + monkeypatch.setattr("dify_trace_opik.opik_trace.Opik", lambda **kwargs: mock_client) instance = OpikDataTrace(opik_config) return instance @@ -67,7 +67,7 @@ def test_prepare_opik_uuid(): def test_init(opik_config, monkeypatch): mock_opik = MagicMock() - monkeypatch.setattr("core.ops.opik_trace.opik_trace.Opik", mock_opik) + monkeypatch.setattr("dify_trace_opik.opik_trace.Opik", mock_opik) monkeypatch.setenv("FILES_URL", "http://test.url") instance = OpikDataTrace(opik_config) @@ -166,8 +166,8 @@ def test_workflow_trace_with_message_id(trace_instance, monkeypatch): ) mock_session = MagicMock() - monkeypatch.setattr("core.ops.opik_trace.opik_trace.sessionmaker", lambda bind: lambda: mock_session) - monkeypatch.setattr("core.ops.opik_trace.opik_trace.db", MagicMock(engine="engine")) + monkeypatch.setattr("dify_trace_opik.opik_trace.sessionmaker", lambda bind: lambda: mock_session) + monkeypatch.setattr("dify_trace_opik.opik_trace.db", MagicMock(engine="engine")) node_llm = MagicMock() node_llm.id = LLM_NODE_ID @@ -203,7 +203,7 @@ def test_workflow_trace_with_message_id(trace_instance, monkeypatch): mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo - monkeypatch.setattr("core.ops.opik_trace.opik_trace.DifyCoreRepositoryFactory", mock_factory) + monkeypatch.setattr("dify_trace_opik.opik_trace.DifyCoreRepositoryFactory", mock_factory) monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) @@ -250,13 +250,13 @@ def test_workflow_trace_no_message_id(trace_instance, monkeypatch): error="", ) - monkeypatch.setattr("core.ops.opik_trace.opik_trace.sessionmaker", lambda bind: lambda: MagicMock()) - monkeypatch.setattr("core.ops.opik_trace.opik_trace.db", MagicMock(engine="engine")) + monkeypatch.setattr("dify_trace_opik.opik_trace.sessionmaker", lambda bind: lambda: MagicMock()) + monkeypatch.setattr("dify_trace_opik.opik_trace.db", MagicMock(engine="engine")) repo = MagicMock() repo.get_by_workflow_execution.return_value = [] mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo - monkeypatch.setattr("core.ops.opik_trace.opik_trace.DifyCoreRepositoryFactory", mock_factory) + monkeypatch.setattr("dify_trace_opik.opik_trace.DifyCoreRepositoryFactory", mock_factory) monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) trace_instance.add_trace = MagicMock() @@ -286,8 +286,8 @@ def test_workflow_trace_missing_app_id(trace_instance, monkeypatch): workflow_app_log_id="339760b2-4b94-4532-8c81-133a97e4680e", error="", ) - monkeypatch.setattr("core.ops.opik_trace.opik_trace.sessionmaker", lambda bind: lambda: MagicMock()) - monkeypatch.setattr("core.ops.opik_trace.opik_trace.db", MagicMock(engine="engine")) + monkeypatch.setattr("dify_trace_opik.opik_trace.sessionmaker", lambda bind: lambda: MagicMock()) + monkeypatch.setattr("dify_trace_opik.opik_trace.db", MagicMock(engine="engine")) with pytest.raises(ValueError, match="No app_id found in trace_info metadata"): trace_instance.workflow_trace(trace_info) @@ -373,7 +373,7 @@ def test_message_trace_with_end_user(trace_instance, monkeypatch): mock_end_user = MagicMock(spec=EndUser) mock_end_user.session_id = "session-id-123" - monkeypatch.setattr("core.ops.opik_trace.opik_trace.db.session.get", lambda model, pk: mock_end_user) + monkeypatch.setattr("dify_trace_opik.opik_trace.db.session.get", lambda model, pk: mock_end_user) trace_instance.add_trace = MagicMock(return_value=MagicMock(id="trace_id_2")) trace_instance.add_span = MagicMock() @@ -658,9 +658,9 @@ def test_workflow_trace_usage_extraction_error_fixed(trace_instance, monkeypatch repo.get_by_workflow_execution.return_value = [node] mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo - monkeypatch.setattr("core.ops.opik_trace.opik_trace.DifyCoreRepositoryFactory", mock_factory) - monkeypatch.setattr("core.ops.opik_trace.opik_trace.sessionmaker", lambda bind: lambda: MagicMock()) - monkeypatch.setattr("core.ops.opik_trace.opik_trace.db", MagicMock(engine="engine")) + monkeypatch.setattr("dify_trace_opik.opik_trace.DifyCoreRepositoryFactory", mock_factory) + monkeypatch.setattr("dify_trace_opik.opik_trace.sessionmaker", lambda bind: lambda: MagicMock()) + monkeypatch.setattr("dify_trace_opik.opik_trace.db", MagicMock(engine="engine")) monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) trace_instance.add_trace = MagicMock() diff --git a/api/providers/trace/trace-opik/tests/unit_tests/test_config_entity.py b/api/providers/trace/trace-opik/tests/unit_tests/test_config_entity.py new file mode 100644 index 00000000000..5a54b70bba7 --- /dev/null +++ b/api/providers/trace/trace-opik/tests/unit_tests/test_config_entity.py @@ -0,0 +1,48 @@ +import pytest +from dify_trace_opik.config import OpikConfig +from pydantic import ValidationError + + +class TestOpikConfig: + """Test cases for OpikConfig""" + + def test_valid_config(self): + """Test valid Opik configuration""" + config = OpikConfig( + api_key="test_key", + project="test_project", + workspace="test_workspace", + url="https://custom.comet.com/opik/api/", + ) + assert config.api_key == "test_key" + assert config.project == "test_project" + assert config.workspace == "test_workspace" + assert config.url == "https://custom.comet.com/opik/api/" + + def test_default_values(self): + """Test default values are set correctly""" + config = OpikConfig() + assert config.api_key is None + assert config.project is None + assert config.workspace is None + assert config.url == "https://www.comet.com/opik/api/" + + def test_project_validation_empty(self): + """Test project validation with empty value""" + config = OpikConfig(project="") + assert config.project == "Default Project" + + def test_url_validation_empty(self): + """Test URL validation with empty value""" + config = OpikConfig(url="") + assert config.url == "https://www.comet.com/opik/api/" + + def test_url_validation_missing_suffix(self): + """Test URL validation requires /api/ suffix""" + with pytest.raises(ValidationError, match="URL should end with /api/"): + OpikConfig(url="https://custom.comet.com/opik/") + + def test_url_validation_invalid_scheme(self): + """Test URL validation rejects invalid schemes""" + with pytest.raises(ValidationError, match="URL must start with https:// or http://"): + OpikConfig(url="ftp://custom.comet.com/opik/api/") diff --git a/api/tests/unit_tests/core/ops/test_opik_trace.py b/api/providers/trace/trace-opik/tests/unit_tests/test_opik_trace.py similarity index 94% rename from api/tests/unit_tests/core/ops/test_opik_trace.py rename to api/providers/trace/trace-opik/tests/unit_tests/test_opik_trace.py index ad9d0846be1..fba290f5b8c 100644 --- a/api/tests/unit_tests/core/ops/test_opik_trace.py +++ b/api/providers/trace/trace-opik/tests/unit_tests/test_opik_trace.py @@ -14,8 +14,9 @@ import uuid from datetime import datetime from unittest.mock import MagicMock, patch +from dify_trace_opik.opik_trace import OpikDataTrace, _seed_to_uuid4, prepare_opik_uuid + from core.ops.entities.trace_entity import TraceTaskName, WorkflowTraceInfo -from core.ops.opik_trace.opik_trace import OpikDataTrace, _seed_to_uuid4, prepare_opik_uuid # A stable UUID4 used as the workflow_run_id throughout all tests. _WORKFLOW_RUN_ID = "a3f1b2c4-d5e6-4f78-9a0b-c1d2e3f4a5b6" @@ -56,8 +57,8 @@ def _make_workflow_trace_info( def _make_opik_trace_instance() -> OpikDataTrace: """Construct an OpikDataTrace with the Opik SDK client mocked out.""" - with patch("core.ops.opik_trace.opik_trace.Opik"): - from core.ops.entities.config_entity import OpikConfig + with patch("dify_trace_opik.opik_trace.Opik"): + from dify_trace_opik.config import OpikConfig config = OpikConfig(api_key="key", project="test-project", url="https://www.comet.com/opik/api/") instance = OpikDataTrace(config) @@ -133,10 +134,10 @@ class TestWorkflowTraceWithoutMessageId: fake_repo.get_by_workflow_execution.return_value = node_executions or [] with ( - patch("core.ops.opik_trace.opik_trace.db") as mock_db, - patch("core.ops.opik_trace.opik_trace.sessionmaker"), + patch("dify_trace_opik.opik_trace.db") as mock_db, + patch("dify_trace_opik.opik_trace.sessionmaker"), patch( - "core.ops.opik_trace.opik_trace.DifyCoreRepositoryFactory.create_workflow_node_execution_repository", + "dify_trace_opik.opik_trace.DifyCoreRepositoryFactory.create_workflow_node_execution_repository", return_value=fake_repo, ), ): @@ -265,10 +266,10 @@ class TestWorkflowTraceWithMessageId: fake_repo.get_by_workflow_execution.return_value = node_executions or [] with ( - patch("core.ops.opik_trace.opik_trace.db") as mock_db, - patch("core.ops.opik_trace.opik_trace.sessionmaker"), + patch("dify_trace_opik.opik_trace.db") as mock_db, + patch("dify_trace_opik.opik_trace.sessionmaker"), patch( - "core.ops.opik_trace.opik_trace.DifyCoreRepositoryFactory.create_workflow_node_execution_repository", + "dify_trace_opik.opik_trace.DifyCoreRepositoryFactory.create_workflow_node_execution_repository", return_value=fake_repo, ), ): diff --git a/api/providers/trace/trace-tencent/pyproject.toml b/api/providers/trace/trace-tencent/pyproject.toml new file mode 100644 index 00000000000..eab06fc7080 --- /dev/null +++ b/api/providers/trace/trace-tencent/pyproject.toml @@ -0,0 +1,14 @@ +[project] +name = "dify-trace-tencent" +version = "0.0.1" +dependencies = [ + # versions inherited from parent + "opentelemetry-api", + "opentelemetry-exporter-otlp-proto-grpc", + "opentelemetry-sdk", + "opentelemetry-semantic-conventions", +] +description = "Dify ops tracing provider (Tencent APM)." + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/core/rag/datasource/vdb/chroma/__init__.py b/api/providers/trace/trace-tencent/src/dify_trace_tencent/__init__.py similarity index 100% rename from api/core/rag/datasource/vdb/chroma/__init__.py rename to api/providers/trace/trace-tencent/src/dify_trace_tencent/__init__.py diff --git a/api/core/ops/tencent_trace/client.py b/api/providers/trace/trace-tencent/src/dify_trace_tencent/client.py similarity index 100% rename from api/core/ops/tencent_trace/client.py rename to api/providers/trace/trace-tencent/src/dify_trace_tencent/client.py diff --git a/api/providers/trace/trace-tencent/src/dify_trace_tencent/config.py b/api/providers/trace/trace-tencent/src/dify_trace_tencent/config.py new file mode 100644 index 00000000000..398e6c55a86 --- /dev/null +++ b/api/providers/trace/trace-tencent/src/dify_trace_tencent/config.py @@ -0,0 +1,30 @@ +from pydantic import ValidationInfo, field_validator + +from core.ops.entities.config_entity import BaseTracingConfig + + +class TencentConfig(BaseTracingConfig): + """ + Tencent APM tracing config + """ + + token: str + endpoint: str + service_name: str + + @field_validator("token") + @classmethod + def token_validator(cls, v, info: ValidationInfo): + if not v or v.strip() == "": + raise ValueError("Token cannot be empty") + return v + + @field_validator("endpoint") + @classmethod + def endpoint_validator(cls, v, info: ValidationInfo): + return cls.validate_endpoint_url(v, "https://apm.tencentcloudapi.com") + + @field_validator("service_name") + @classmethod + def service_name_validator(cls, v, info: ValidationInfo): + return cls.validate_project_field(v, "dify_app") diff --git a/api/core/ops/tencent_trace/entities/__init__.py b/api/providers/trace/trace-tencent/src/dify_trace_tencent/entities/__init__.py similarity index 100% rename from api/core/ops/tencent_trace/entities/__init__.py rename to api/providers/trace/trace-tencent/src/dify_trace_tencent/entities/__init__.py diff --git a/api/core/ops/tencent_trace/entities/semconv.py b/api/providers/trace/trace-tencent/src/dify_trace_tencent/entities/semconv.py similarity index 100% rename from api/core/ops/tencent_trace/entities/semconv.py rename to api/providers/trace/trace-tencent/src/dify_trace_tencent/entities/semconv.py diff --git a/api/core/ops/tencent_trace/entities/tencent_trace_entity.py b/api/providers/trace/trace-tencent/src/dify_trace_tencent/entities/tencent_trace_entity.py similarity index 100% rename from api/core/ops/tencent_trace/entities/tencent_trace_entity.py rename to api/providers/trace/trace-tencent/src/dify_trace_tencent/entities/tencent_trace_entity.py diff --git a/api/core/rag/datasource/vdb/couchbase/__init__.py b/api/providers/trace/trace-tencent/src/dify_trace_tencent/py.typed similarity index 100% rename from api/core/rag/datasource/vdb/couchbase/__init__.py rename to api/providers/trace/trace-tencent/src/dify_trace_tencent/py.typed diff --git a/api/core/ops/tencent_trace/span_builder.py b/api/providers/trace/trace-tencent/src/dify_trace_tencent/span_builder.py similarity index 98% rename from api/core/ops/tencent_trace/span_builder.py rename to api/providers/trace/trace-tencent/src/dify_trace_tencent/span_builder.py index f79095d9662..763a85ffd71 100644 --- a/api/core/ops/tencent_trace/span_builder.py +++ b/api/providers/trace/trace-tencent/src/dify_trace_tencent/span_builder.py @@ -6,8 +6,6 @@ import json import logging from datetime import datetime -from graphon.entities import WorkflowNodeExecution -from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from opentelemetry.trace import Status, StatusCode from core.ops.entities.trace_entity import ( @@ -16,7 +14,8 @@ from core.ops.entities.trace_entity import ( ToolTraceInfo, WorkflowTraceInfo, ) -from core.ops.tencent_trace.entities.semconv import ( +from core.rag.models.document import Document +from dify_trace_tencent.entities.semconv import ( GEN_AI_COMPLETION, GEN_AI_FRAMEWORK, GEN_AI_IS_ENTRY, @@ -40,9 +39,10 @@ from core.ops.tencent_trace.entities.semconv import ( TOOL_PARAMETERS, GenAISpanKind, ) -from core.ops.tencent_trace.entities.tencent_trace_entity import SpanData -from core.ops.tencent_trace.utils import TencentTraceUtils -from core.rag.models.document import Document +from dify_trace_tencent.entities.tencent_trace_entity import SpanData +from dify_trace_tencent.utils import TencentTraceUtils +from graphon.entities import WorkflowNodeExecution +from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus logger = logging.getLogger(__name__) diff --git a/api/core/ops/tencent_trace/tencent_trace.py b/api/providers/trace/trace-tencent/src/dify_trace_tencent/tencent_trace.py similarity index 93% rename from api/core/ops/tencent_trace/tencent_trace.py rename to api/providers/trace/trace-tencent/src/dify_trace_tencent/tencent_trace.py index 2bd6db22bf7..a8c480e4a5d 100644 --- a/api/core/ops/tencent_trace/tencent_trace.py +++ b/api/providers/trace/trace-tencent/src/dify_trace_tencent/tencent_trace.py @@ -1,18 +1,12 @@ -""" -Tencent APM tracing implementation with separated concerns -""" +"""Tencent APM tracing with idempotent client cleanup.""" +import inspect import logging -from graphon.entities.workflow_node_execution import ( - WorkflowNodeExecution, -) -from graphon.nodes import BuiltinNodeTypes from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker from core.ops.base_trace_instance import BaseTraceInstance -from core.ops.entities.config_entity import TencentConfig from core.ops.entities.trace_entity import ( BaseTraceInfo, DatasetRetrievalTraceInfo, @@ -23,12 +17,17 @@ from core.ops.entities.trace_entity import ( ToolTraceInfo, WorkflowTraceInfo, ) -from core.ops.tencent_trace.client import TencentTraceClient -from core.ops.tencent_trace.entities.tencent_trace_entity import SpanData -from core.ops.tencent_trace.span_builder import TencentSpanBuilder -from core.ops.tencent_trace.utils import TencentTraceUtils from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository +from dify_trace_tencent.client import TencentTraceClient +from dify_trace_tencent.config import TencentConfig +from dify_trace_tencent.entities.tencent_trace_entity import SpanData +from dify_trace_tencent.span_builder import TencentSpanBuilder +from dify_trace_tencent.utils import TencentTraceUtils from extensions.ext_database import db +from graphon.entities.workflow_node_execution import ( + WorkflowNodeExecution, +) +from graphon.nodes import BuiltinNodeTypes from models import Account, App, TenantAccountJoin, WorkflowNodeExecutionTriggeredFrom logger = logging.getLogger(__name__) @@ -38,10 +37,18 @@ class TencentDataTrace(BaseTraceInstance): """ Tencent APM trace implementation with single responsibility principle. Acts as a coordinator that delegates specific tasks to specialized classes. + + The instance owns a long-lived ``TencentTraceClient``. Cleanup may happen + explicitly in tests or implicitly during garbage collection, so shutdown + must be safe to call multiple times. """ + trace_client: TencentTraceClient + _closed: bool + def __init__(self, tencent_config: TencentConfig): super().__init__(tencent_config) + self._closed = False self.trace_client = TencentTraceClient( service_name=tencent_config.service_name, endpoint=tencent_config.endpoint, @@ -241,8 +248,10 @@ class TencentDataTrace(BaseTraceInstance): if not service_account: raise ValueError(f"Creator account not found for app {app_id}") - current_tenant = ( - session.query(TenantAccountJoin).filter_by(account_id=service_account.id, current=True).first() + current_tenant = session.scalar( + select(TenantAccountJoin) + .where(TenantAccountJoin.account_id == service_account.id, TenantAccountJoin.current.is_(True)) + .limit(1) ) if not current_tenant: raise ValueError(f"Current tenant not found for account {service_account.id}") @@ -511,10 +520,25 @@ class TencentDataTrace(BaseTraceInstance): except Exception: logger.debug("[Tencent APM] Failed to record message trace duration") - def __del__(self): - """Ensure proper cleanup on garbage collection.""" + def close(self) -> None: + """Synchronously and idempotently shutdown the underlying trace client.""" + if getattr(self, "_closed", False): + return + + self._closed = True + trace_client = getattr(self, "trace_client", None) + if trace_client is None: + return + try: - if hasattr(self, "trace_client"): - self.trace_client.shutdown() + shutdown_result = trace_client.shutdown() + if inspect.isawaitable(shutdown_result): + close_awaitable = getattr(shutdown_result, "close", None) + if callable(close_awaitable): + close_awaitable() except Exception: logger.exception("[Tencent APM] Failed to shutdown trace client during cleanup") + + def __del__(self): + """Ensure best-effort cleanup on garbage collection without retrying shutdown.""" + self.close() diff --git a/api/core/ops/tencent_trace/utils.py b/api/providers/trace/trace-tencent/src/dify_trace_tencent/utils.py similarity index 100% rename from api/core/ops/tencent_trace/utils.py rename to api/providers/trace/trace-tencent/src/dify_trace_tencent/utils.py diff --git a/api/tests/unit_tests/core/ops/tencent_trace/test_client.py b/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_client.py similarity index 98% rename from api/tests/unit_tests/core/ops/tencent_trace/test_client.py rename to api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_client.py index 870c18e53ee..1e656e24623 100644 --- a/api/tests/unit_tests/core/ops/tencent_trace/test_client.py +++ b/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_client.py @@ -8,13 +8,12 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest +from dify_trace_tencent import client as client_module +from dify_trace_tencent.client import TencentTraceClient, _get_opentelemetry_sdk_version +from dify_trace_tencent.entities.tencent_trace_entity import SpanData from opentelemetry.sdk.trace import Event from opentelemetry.trace import Status, StatusCode -from core.ops.tencent_trace import client as client_module -from core.ops.tencent_trace.client import TencentTraceClient, _get_opentelemetry_sdk_version -from core.ops.tencent_trace.entities.tencent_trace_entity import SpanData - metric_reader_instances: list[DummyMetricReader] = [] meter_provider_instances: list[DummyMeterProvider] = [] diff --git a/api/tests/unit_tests/core/ops/tencent_trace/test_span_builder.py b/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_span_builder.py similarity index 89% rename from api/tests/unit_tests/core/ops/tencent_trace/test_span_builder.py rename to api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_span_builder.py index 696f859b6f7..e850a801f3b 100644 --- a/api/tests/unit_tests/core/ops/tencent_trace/test_span_builder.py +++ b/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_span_builder.py @@ -1,17 +1,7 @@ from datetime import datetime from unittest.mock import MagicMock, patch -from graphon.entities import WorkflowNodeExecution -from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus -from opentelemetry.trace import StatusCode - -from core.ops.entities.trace_entity import ( - DatasetRetrievalTraceInfo, - MessageTraceInfo, - ToolTraceInfo, - WorkflowTraceInfo, -) -from core.ops.tencent_trace.entities.semconv import ( +from dify_trace_tencent.entities.semconv import ( GEN_AI_IS_ENTRY, GEN_AI_IS_STREAMING_REQUEST, GEN_AI_MODEL_NAME, @@ -25,13 +15,23 @@ from core.ops.tencent_trace.entities.semconv import ( TOOL_PARAMETERS, GenAISpanKind, ) -from core.ops.tencent_trace.span_builder import TencentSpanBuilder +from dify_trace_tencent.span_builder import TencentSpanBuilder +from opentelemetry.trace import StatusCode + +from core.ops.entities.trace_entity import ( + DatasetRetrievalTraceInfo, + MessageTraceInfo, + ToolTraceInfo, + WorkflowTraceInfo, +) from core.rag.models.document import Document +from graphon.entities import WorkflowNodeExecution +from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus class TestTencentSpanBuilder: def test_get_time_nanoseconds(self): - with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_datetime_to_nanoseconds") as mock_convert: + with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_datetime_to_nanoseconds") as mock_convert: mock_convert.return_value = 123456789 dt = datetime.now() result = TencentSpanBuilder._get_time_nanoseconds(dt) @@ -48,7 +48,7 @@ class TestTencentSpanBuilder: trace_info.workflow_run_outputs = {"answer": "world"} trace_info.metadata = {"conversation_id": "conv_id"} - with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: mock_convert_id.side_effect = [1, 2] # workflow_span_id, message_span_id with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): spans = TencentSpanBuilder.build_workflow_spans(trace_info, 123, "user_1") @@ -70,7 +70,7 @@ class TestTencentSpanBuilder: trace_info.workflow_run_outputs = {} trace_info.metadata = {} # No conversation_id - with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: mock_convert_id.return_value = 1 with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): spans = TencentSpanBuilder.build_workflow_spans(trace_info, 123, "user_1") @@ -98,7 +98,7 @@ class TestTencentSpanBuilder: } node_execution.outputs = {"text": "world"} - with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: mock_convert_id.return_value = 456 with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): span = TencentSpanBuilder.build_workflow_llm_span(123, 1, trace_info, node_execution) @@ -123,7 +123,7 @@ class TestTencentSpanBuilder: "usage": {"prompt_tokens": 15, "completion_tokens": 25, "total_tokens": 40}, } - with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: mock_convert_id.return_value = 456 with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): span = TencentSpanBuilder.build_workflow_llm_span(123, 1, trace_info, node_execution) @@ -142,7 +142,7 @@ class TestTencentSpanBuilder: trace_info.metadata = {"conversation_id": "conv_id"} trace_info.is_streaming_request = True - with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: mock_convert_id.return_value = 789 with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): span = TencentSpanBuilder.build_message_span(trace_info, 123, "user_1") @@ -162,7 +162,7 @@ class TestTencentSpanBuilder: trace_info.metadata = {} trace_info.is_streaming_request = False - with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: mock_convert_id.return_value = 789 with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): span = TencentSpanBuilder.build_message_span(trace_info, 123, "user_1") @@ -182,7 +182,7 @@ class TestTencentSpanBuilder: trace_info.tool_inputs = {"i": 2} trace_info.tool_outputs = "result" - with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: mock_convert_id.return_value = 101 with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): span = TencentSpanBuilder.build_tool_span(trace_info, 123, 1) @@ -204,7 +204,7 @@ class TestTencentSpanBuilder: ) trace_info.documents = [doc] - with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: mock_convert_id.return_value = 202 with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): span = TencentSpanBuilder.build_retrieval_span(trace_info, 123, 1) @@ -222,7 +222,7 @@ class TestTencentSpanBuilder: trace_info.end_time = datetime.now() trace_info.documents = [] - with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: mock_convert_id.return_value = 202 with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): span = TencentSpanBuilder.build_retrieval_span(trace_info, 123, 1) @@ -264,7 +264,7 @@ class TestTencentSpanBuilder: node_execution.created_at = datetime.now() node_execution.finished_at = datetime.now() - with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: mock_convert_id.return_value = 303 with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): span = TencentSpanBuilder.build_workflow_retrieval_span(123, 1, trace_info, node_execution) @@ -286,7 +286,7 @@ class TestTencentSpanBuilder: node_execution.created_at = datetime.now() node_execution.finished_at = datetime.now() - with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: mock_convert_id.return_value = 303 with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): span = TencentSpanBuilder.build_workflow_retrieval_span(123, 1, trace_info, node_execution) @@ -307,7 +307,7 @@ class TestTencentSpanBuilder: node_execution.created_at = datetime.now() node_execution.finished_at = datetime.now() - with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: mock_convert_id.return_value = 404 with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): span = TencentSpanBuilder.build_workflow_tool_span(123, 1, trace_info, node_execution) @@ -329,7 +329,7 @@ class TestTencentSpanBuilder: node_execution.created_at = datetime.now() node_execution.finished_at = datetime.now() - with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: mock_convert_id.return_value = 404 with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): span = TencentSpanBuilder.build_workflow_tool_span(123, 1, trace_info, node_execution) @@ -350,7 +350,7 @@ class TestTencentSpanBuilder: node_execution.created_at = datetime.now() node_execution.finished_at = datetime.now() - with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: mock_convert_id.return_value = 505 with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): span = TencentSpanBuilder.build_workflow_task_span(123, 1, trace_info, node_execution) diff --git a/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace.py b/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_tencent_trace.py similarity index 86% rename from api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace.py rename to api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_tencent_trace.py index 382e5dadc3b..54524b09ca6 100644 --- a/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace.py +++ b/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_tencent_trace.py @@ -1,11 +1,12 @@ +import gc import logging -from unittest.mock import MagicMock, patch +import warnings +from unittest.mock import AsyncMock, MagicMock, patch import pytest -from graphon.entities import WorkflowNodeExecution -from graphon.enums import BuiltinNodeTypes +from dify_trace_tencent.config import TencentConfig +from dify_trace_tencent.tencent_trace import TencentDataTrace -from core.ops.entities.config_entity import TencentConfig from core.ops.entities.trace_entity import ( DatasetRetrievalTraceInfo, GenerateNameTraceInfo, @@ -15,7 +16,8 @@ from core.ops.entities.trace_entity import ( ToolTraceInfo, WorkflowTraceInfo, ) -from core.ops.tencent_trace.tencent_trace import TencentDataTrace +from graphon.entities import WorkflowNodeExecution +from graphon.enums import BuiltinNodeTypes from models import Account, App, TenantAccountJoin logger = logging.getLogger(__name__) @@ -28,19 +30,19 @@ def tencent_config(): @pytest.fixture def mock_trace_client(): - with patch("core.ops.tencent_trace.tencent_trace.TencentTraceClient") as mock: + with patch("dify_trace_tencent.tencent_trace.TencentTraceClient") as mock: yield mock @pytest.fixture def mock_span_builder(): - with patch("core.ops.tencent_trace.tencent_trace.TencentSpanBuilder") as mock: + with patch("dify_trace_tencent.tencent_trace.TencentSpanBuilder") as mock: yield mock @pytest.fixture def mock_trace_utils(): - with patch("core.ops.tencent_trace.tencent_trace.TencentTraceUtils") as mock: + with patch("dify_trace_tencent.tencent_trace.TencentTraceUtils") as mock: yield mock @@ -198,9 +200,9 @@ class TestTencentDataTrace: trace_info.workflow_run_id = "run-id" with patch( - "core.ops.tencent_trace.tencent_trace.TencentTraceUtils.convert_to_trace_id", side_effect=Exception("error") + "dify_trace_tencent.tencent_trace.TencentTraceUtils.convert_to_trace_id", side_effect=Exception("error") ): - with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log: + with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log: tencent_data_trace.workflow_trace(trace_info) mock_log.assert_called_once_with("[Tencent APM] Failed to process workflow trace") @@ -230,9 +232,9 @@ class TestTencentDataTrace: trace_info = MagicMock(spec=MessageTraceInfo) with patch( - "core.ops.tencent_trace.tencent_trace.TencentTraceUtils.convert_to_trace_id", side_effect=Exception("error") + "dify_trace_tencent.tencent_trace.TencentTraceUtils.convert_to_trace_id", side_effect=Exception("error") ): - with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log: + with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log: tencent_data_trace.message_trace(trace_info) mock_log.assert_called_once_with("[Tencent APM] Failed to process message trace") @@ -262,9 +264,9 @@ class TestTencentDataTrace: trace_info.message_id = "msg-id" with patch( - "core.ops.tencent_trace.tencent_trace.TencentTraceUtils.convert_to_span_id", side_effect=Exception("error") + "dify_trace_tencent.tencent_trace.TencentTraceUtils.convert_to_span_id", side_effect=Exception("error") ): - with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log: + with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log: tencent_data_trace.tool_trace(trace_info) mock_log.assert_called_once_with("[Tencent APM] Failed to process tool trace") @@ -294,22 +296,22 @@ class TestTencentDataTrace: trace_info.message_id = "msg-id" with patch( - "core.ops.tencent_trace.tencent_trace.TencentTraceUtils.convert_to_span_id", side_effect=Exception("error") + "dify_trace_tencent.tencent_trace.TencentTraceUtils.convert_to_span_id", side_effect=Exception("error") ): - with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log: + with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log: tencent_data_trace.dataset_retrieval_trace(trace_info) mock_log.assert_called_once_with("[Tencent APM] Failed to process dataset retrieval trace") def test_suggested_question_trace(self, tencent_data_trace): trace_info = MagicMock(spec=SuggestedQuestionTraceInfo) - with patch("core.ops.tencent_trace.tencent_trace.logger.info") as mock_log: + with patch("dify_trace_tencent.tencent_trace.logger.info") as mock_log: tencent_data_trace.suggested_question_trace(trace_info) mock_log.assert_called_once_with("[Tencent APM] Processing suggested question trace") def test_suggested_question_trace_exception(self, tencent_data_trace): trace_info = MagicMock(spec=SuggestedQuestionTraceInfo) - with patch("core.ops.tencent_trace.tencent_trace.logger.info", side_effect=Exception("error")): - with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log: + with patch("dify_trace_tencent.tencent_trace.logger.info", side_effect=Exception("error")): + with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log: tencent_data_trace.suggested_question_trace(trace_info) mock_log.assert_called_once_with("[Tencent APM] Failed to process suggested question trace") @@ -342,7 +344,7 @@ class TestTencentDataTrace: with patch.object(tencent_data_trace, "_get_workflow_node_executions", return_value=[node]): with patch.object(tencent_data_trace, "_build_workflow_node_span", side_effect=Exception("node error")): - with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log: + with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log: tencent_data_trace._process_workflow_nodes(trace_info, 123) # The exception should be caught by the outer handler since convert_to_span_id is called first mock_log.assert_called_once_with("[Tencent APM] Failed to process workflow nodes") @@ -351,7 +353,7 @@ class TestTencentDataTrace: trace_info = MagicMock(spec=WorkflowTraceInfo) mock_trace_utils.convert_to_span_id.side_effect = Exception("outer error") - with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log: + with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log: tencent_data_trace._process_workflow_nodes(trace_info, 123) mock_log.assert_called_once_with("[Tencent APM] Failed to process workflow nodes") @@ -381,7 +383,7 @@ class TestTencentDataTrace: node.id = "n1" mock_span_builder.build_workflow_llm_span.side_effect = Exception("error") - with patch("core.ops.tencent_trace.tencent_trace.logger.debug") as mock_log: + with patch("dify_trace_tencent.tencent_trace.logger.debug") as mock_log: result = tencent_data_trace._build_workflow_node_span(node, 123, MagicMock(), 456) assert result is None mock_log.assert_called_once() @@ -403,16 +405,13 @@ class TestTencentDataTrace: mock_executions = [MagicMock()] - with patch("core.ops.tencent_trace.tencent_trace.db") as mock_db: + with patch("dify_trace_tencent.tencent_trace.db") as mock_db: mock_db.engine = "engine" - with patch("core.ops.tencent_trace.tencent_trace.Session") as mock_session_ctx: + with patch("dify_trace_tencent.tencent_trace.Session") as mock_session_ctx: session = mock_session_ctx.return_value.__enter__.return_value - session.scalar.side_effect = [app, account] - session.query.return_value.filter_by.return_value.first.return_value = tenant_join + session.scalar.side_effect = [app, account, tenant_join] - with patch( - "core.ops.tencent_trace.tencent_trace.SQLAlchemyWorkflowNodeExecutionRepository" - ) as mock_repo: + with patch("dify_trace_tencent.tencent_trace.SQLAlchemyWorkflowNodeExecutionRepository") as mock_repo: mock_repo.return_value.get_by_workflow_execution.return_value = mock_executions results = tencent_data_trace._get_workflow_node_executions(trace_info) @@ -424,7 +423,7 @@ class TestTencentDataTrace: trace_info = MagicMock(spec=WorkflowTraceInfo) trace_info.metadata = {} - with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log: + with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log: results = tencent_data_trace._get_workflow_node_executions(trace_info) assert results == [] mock_log.assert_called_once() @@ -433,14 +432,14 @@ class TestTencentDataTrace: trace_info = MagicMock(spec=WorkflowTraceInfo) trace_info.metadata = {"app_id": "app-1"} - with patch("core.ops.tencent_trace.tencent_trace.db") as mock_db: + with patch("dify_trace_tencent.tencent_trace.db") as mock_db: mock_db.init_app = MagicMock() # Ensure init_app is mocked mock_db.engine = "engine" - with patch("core.ops.tencent_trace.tencent_trace.Session") as mock_session_ctx: + with patch("dify_trace_tencent.tencent_trace.Session") as mock_session_ctx: session = mock_session_ctx.return_value.__enter__.return_value session.scalar.return_value = None - with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log: + with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log: results = tencent_data_trace._get_workflow_node_executions(trace_info) assert results == [] mock_log.assert_called_once() @@ -450,8 +449,8 @@ class TestTencentDataTrace: trace_info.tenant_id = "tenant-1" trace_info.metadata = {"user_id": "user-1"} - with patch("core.ops.tencent_trace.tencent_trace.sessionmaker", side_effect=Exception("Database error")): - with patch("core.ops.tencent_trace.tencent_trace.db") as mock_db: + with patch("dify_trace_tencent.tencent_trace.sessionmaker", side_effect=Exception("Database error")): + with patch("dify_trace_tencent.tencent_trace.db") as mock_db: mock_db.init_app = MagicMock() mock_db.engine = MagicMock() @@ -477,8 +476,8 @@ class TestTencentDataTrace: trace_info.tenant_id = "t" trace_info.metadata = {"user_id": "u"} - with patch("core.ops.tencent_trace.tencent_trace.sessionmaker", side_effect=Exception("error")): - with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log: + with patch("dify_trace_tencent.tencent_trace.sessionmaker", side_effect=Exception("error")): + with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log: user_id = tencent_data_trace._get_user_id(trace_info) assert user_id == "unknown" mock_log.assert_called_once_with("[Tencent APM] Failed to get user ID") @@ -520,7 +519,7 @@ class TestTencentDataTrace: node.process_data = None node.outputs = None - with patch("core.ops.tencent_trace.tencent_trace.logger.debug") as mock_log: + with patch("dify_trace_tencent.tencent_trace.logger.debug") as mock_log: tencent_data_trace._record_llm_metrics(node) # Should not crash @@ -558,7 +557,7 @@ class TestTencentDataTrace: trace_info = MagicMock(spec=MessageTraceInfo) trace_info.metadata = None - with patch("core.ops.tencent_trace.tencent_trace.logger.debug") as mock_log: + with patch("dify_trace_tencent.tencent_trace.logger.debug") as mock_log: tencent_data_trace._record_message_llm_metrics(trace_info) # Should not crash @@ -610,7 +609,7 @@ class TestTencentDataTrace: trace_info = MagicMock(spec=WorkflowTraceInfo) trace_info.start_time = MagicMock() # This might cause total_seconds() to fail if not mocked right - with patch("core.ops.tencent_trace.tencent_trace.logger.debug") as mock_log: + with patch("dify_trace_tencent.tencent_trace.logger.debug") as mock_log: tencent_data_trace._record_workflow_trace_duration(trace_info) def test_record_message_trace_duration(self, tencent_data_trace): @@ -632,16 +631,41 @@ class TestTencentDataTrace: trace_info = MagicMock(spec=MessageTraceInfo) trace_info.start_time = None - with patch("core.ops.tencent_trace.tencent_trace.logger.debug") as mock_log: + with patch("dify_trace_tencent.tencent_trace.logger.debug") as mock_log: tencent_data_trace._record_message_trace_duration(trace_info) - def test_del(self, tencent_data_trace): + def test_close(self, tencent_data_trace): client = tencent_data_trace.trace_client - tencent_data_trace.__del__() + tencent_data_trace.close() client.shutdown.assert_called_once() - def test_del_exception(self, tencent_data_trace): + def test_close_is_idempotent(self, tencent_data_trace): + client = tencent_data_trace.trace_client + + tencent_data_trace.close() + tencent_data_trace.close() + + client.shutdown.assert_called_once() + + def test_close_exception(self, tencent_data_trace): tencent_data_trace.trace_client.shutdown.side_effect = Exception("error") - with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log: - tencent_data_trace.__del__() + with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log: + tencent_data_trace.close() mock_log.assert_called_once_with("[Tencent APM] Failed to shutdown trace client during cleanup") + + def test_close_handles_async_shutdown_mock(self, tencent_data_trace): + shutdown = AsyncMock() + tencent_data_trace.trace_client.shutdown = shutdown + + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + tencent_data_trace.close() + gc.collect() + + shutdown.assert_called_once() + assert not [ + warning + for warning in caught + if issubclass(warning.category, RuntimeWarning) + and "AsyncMockMixin._execute_mock_call" in str(warning.message) + ] diff --git a/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace_utils.py b/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_tencent_trace_utils.py similarity index 88% rename from api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace_utils.py rename to api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_tencent_trace_utils.py index ef28d18e200..63c6d680d78 100644 --- a/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace_utils.py +++ b/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_tencent_trace_utils.py @@ -8,10 +8,9 @@ from datetime import UTC, datetime from unittest.mock import patch import pytest +from dify_trace_tencent.utils import TencentTraceUtils from opentelemetry.trace import Link, TraceFlags -from core.ops.tencent_trace.utils import TencentTraceUtils - def test_convert_to_trace_id_with_valid_uuid() -> None: uuid_str = "12345678-1234-5678-1234-567812345678" @@ -20,7 +19,7 @@ def test_convert_to_trace_id_with_valid_uuid() -> None: def test_convert_to_trace_id_uses_uuid4_when_none() -> None: expected_uuid = uuid.UUID("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa") - with patch("core.ops.tencent_trace.utils.uuid.uuid4", return_value=expected_uuid) as uuid4_mock: + with patch("dify_trace_tencent.utils.uuid.uuid4", return_value=expected_uuid) as uuid4_mock: assert TencentTraceUtils.convert_to_trace_id(None) == expected_uuid.int uuid4_mock.assert_called_once() @@ -45,7 +44,7 @@ def test_convert_to_span_id_is_deterministic_and_sensitive_to_type() -> None: def test_convert_to_span_id_uses_uuid4_when_none() -> None: expected_uuid = uuid.UUID("bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb") - with patch("core.ops.tencent_trace.utils.uuid.uuid4", return_value=expected_uuid) as uuid4_mock: + with patch("dify_trace_tencent.utils.uuid.uuid4", return_value=expected_uuid) as uuid4_mock: span_id = TencentTraceUtils.convert_to_span_id(None, "workflow") assert isinstance(span_id, int) uuid4_mock.assert_called_once() @@ -58,7 +57,7 @@ def test_convert_to_span_id_raises_value_error_for_invalid_uuid() -> None: def test_generate_span_id_skips_invalid_span_id() -> None: with patch( - "core.ops.tencent_trace.utils.random.getrandbits", + "dify_trace_tencent.utils.random.getrandbits", side_effect=[TencentTraceUtils.INVALID_SPAN_ID, 42], ) as bits_mock: assert TencentTraceUtils.generate_span_id() == 42 @@ -75,7 +74,7 @@ def test_convert_datetime_to_nanoseconds_uses_now_when_none() -> None: fixed = datetime(2024, 1, 2, 3, 4, 5, tzinfo=UTC) expected = int(fixed.timestamp() * 1e9) - with patch("core.ops.tencent_trace.utils.datetime") as datetime_mock: + with patch("dify_trace_tencent.utils.datetime") as datetime_mock: datetime_mock.now.return_value = fixed assert TencentTraceUtils.convert_datetime_to_nanoseconds(None) == expected datetime_mock.now.assert_called_once() @@ -100,7 +99,7 @@ def test_create_link_accepts_hex_or_uuid(trace_id_str: str, expected_trace_id: i @pytest.mark.parametrize("trace_id_str", ["g" * 32, "not-a-uuid", None]) def test_create_link_falls_back_to_uuid4(trace_id_str: object) -> None: fallback_uuid = uuid.UUID("dddddddd-dddd-dddd-dddd-dddddddddddd") - with patch("core.ops.tencent_trace.utils.uuid.uuid4", return_value=fallback_uuid) as uuid4_mock: + with patch("dify_trace_tencent.utils.uuid.uuid4", return_value=fallback_uuid) as uuid4_mock: link = TencentTraceUtils.create_link(trace_id_str) # type: ignore[arg-type] assert link.context.trace_id == fallback_uuid.int uuid4_mock.assert_called_once() diff --git a/api/providers/trace/trace-weave/pyproject.toml b/api/providers/trace/trace-weave/pyproject.toml new file mode 100644 index 00000000000..ba449f2a933 --- /dev/null +++ b/api/providers/trace/trace-weave/pyproject.toml @@ -0,0 +1,10 @@ +[project] +name = "dify-trace-weave" +version = "0.0.1" +dependencies = [ + "weave>=0.52.36", +] +description = "Dify ops tracing provider (Weave)." + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/core/rag/datasource/vdb/elasticsearch/__init__.py b/api/providers/trace/trace-weave/src/dify_trace_weave/__init__.py similarity index 100% rename from api/core/rag/datasource/vdb/elasticsearch/__init__.py rename to api/providers/trace/trace-weave/src/dify_trace_weave/__init__.py diff --git a/api/providers/trace/trace-weave/src/dify_trace_weave/config.py b/api/providers/trace/trace-weave/src/dify_trace_weave/config.py new file mode 100644 index 00000000000..5942bd57fe5 --- /dev/null +++ b/api/providers/trace/trace-weave/src/dify_trace_weave/config.py @@ -0,0 +1,29 @@ +from pydantic import ValidationInfo, field_validator + +from core.ops.entities.config_entity import BaseTracingConfig +from core.ops.utils import validate_url + + +class WeaveConfig(BaseTracingConfig): + """ + Model class for Weave tracing config. + """ + + api_key: str + entity: str | None = None + project: str + endpoint: str = "https://trace.wandb.ai" + host: str | None = None + + @field_validator("endpoint") + @classmethod + def endpoint_validator(cls, v, info: ValidationInfo): + # Weave only allows HTTPS for endpoint + return validate_url(v, "https://trace.wandb.ai", allowed_schemes=("https",)) + + @field_validator("host") + @classmethod + def host_validator(cls, v, info: ValidationInfo): + if v is not None and v.strip() != "": + return validate_url(v, v, allowed_schemes=("https", "http")) + return v diff --git a/api/core/rag/datasource/vdb/hologres/__init__.py b/api/providers/trace/trace-weave/src/dify_trace_weave/entities/__init__.py similarity index 100% rename from api/core/rag/datasource/vdb/hologres/__init__.py rename to api/providers/trace/trace-weave/src/dify_trace_weave/entities/__init__.py diff --git a/api/core/ops/weave_trace/entities/weave_trace_entity.py b/api/providers/trace/trace-weave/src/dify_trace_weave/entities/weave_trace_entity.py similarity index 100% rename from api/core/ops/weave_trace/entities/weave_trace_entity.py rename to api/providers/trace/trace-weave/src/dify_trace_weave/entities/weave_trace_entity.py diff --git a/api/core/rag/datasource/vdb/huawei/__init__.py b/api/providers/trace/trace-weave/src/dify_trace_weave/py.typed similarity index 100% rename from api/core/rag/datasource/vdb/huawei/__init__.py rename to api/providers/trace/trace-weave/src/dify_trace_weave/py.typed diff --git a/api/core/ops/weave_trace/weave_trace.py b/api/providers/trace/trace-weave/src/dify_trace_weave/weave_trace.py similarity index 99% rename from api/core/ops/weave_trace/weave_trace.py rename to api/providers/trace/trace-weave/src/dify_trace_weave/weave_trace.py index 8d9ba4694d9..4292cbf0f1b 100644 --- a/api/core/ops/weave_trace/weave_trace.py +++ b/api/providers/trace/trace-weave/src/dify_trace_weave/weave_trace.py @@ -6,7 +6,6 @@ from typing import Any, cast import wandb import weave -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from sqlalchemy.orm import sessionmaker from weave.trace_server.trace_server_interface import ( CallEndReq, @@ -18,7 +17,6 @@ from weave.trace_server.trace_server_interface import ( ) from core.ops.base_trace_instance import BaseTraceInstance -from core.ops.entities.config_entity import WeaveConfig from core.ops.entities.trace_entity import ( BaseTraceInfo, DatasetRetrievalTraceInfo, @@ -30,9 +28,11 @@ from core.ops.entities.trace_entity import ( TraceTaskName, WorkflowTraceInfo, ) -from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel from core.repositories import DifyCoreRepositoryFactory +from dify_trace_weave.config import WeaveConfig +from dify_trace_weave.entities.weave_trace_entity import WeaveTraceModel from extensions.ext_database import db +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom logger = logging.getLogger(__name__) diff --git a/api/providers/trace/trace-weave/tests/unit_tests/test_config_entity.py b/api/providers/trace/trace-weave/tests/unit_tests/test_config_entity.py new file mode 100644 index 00000000000..eeb1fe1d877 --- /dev/null +++ b/api/providers/trace/trace-weave/tests/unit_tests/test_config_entity.py @@ -0,0 +1,61 @@ +import pytest +from dify_trace_weave.config import WeaveConfig +from pydantic import ValidationError + + +class TestWeaveConfig: + """Test cases for WeaveConfig""" + + def test_valid_config(self): + """Test valid Weave configuration""" + config = WeaveConfig( + api_key="test_key", + entity="test_entity", + project="test_project", + endpoint="https://custom.wandb.ai", + host="https://custom.host.com", + ) + assert config.api_key == "test_key" + assert config.entity == "test_entity" + assert config.project == "test_project" + assert config.endpoint == "https://custom.wandb.ai" + assert config.host == "https://custom.host.com" + + def test_default_values(self): + """Test default values are set correctly""" + config = WeaveConfig(api_key="key", project="project") + assert config.entity is None + assert config.endpoint == "https://trace.wandb.ai" + assert config.host is None + + def test_missing_required_fields(self): + """Test that required fields are enforced""" + with pytest.raises(ValidationError): + WeaveConfig() + + with pytest.raises(ValidationError): + WeaveConfig(api_key="key") + + with pytest.raises(ValidationError): + WeaveConfig(project="project") + + def test_endpoint_validation_https_only(self): + """Test endpoint validation only allows HTTPS""" + with pytest.raises(ValidationError, match="URL scheme must be one of"): + WeaveConfig(api_key="key", project="project", endpoint="http://insecure.wandb.ai") + + def test_host_validation_optional(self): + """Test host validation is optional but validates when provided""" + config = WeaveConfig(api_key="key", project="project", host=None) + assert config.host is None + + config = WeaveConfig(api_key="key", project="project", host="") + assert config.host == "" + + config = WeaveConfig(api_key="key", project="project", host="https://valid.host.com") + assert config.host == "https://valid.host.com" + + def test_host_validation_invalid_scheme(self): + """Test host validation rejects invalid schemes when provided""" + with pytest.raises(ValidationError, match="URL scheme must be one of"): + WeaveConfig(api_key="key", project="project", host="ftp://invalid.host.com") diff --git a/api/tests/unit_tests/core/ops/weave_trace/test_weave_trace.py b/api/providers/trace/trace-weave/tests/unit_tests/weave_trace/test_weave_trace.py similarity index 97% rename from api/tests/unit_tests/core/ops/weave_trace/test_weave_trace.py rename to api/providers/trace/trace-weave/tests/unit_tests/weave_trace/test_weave_trace.py index 5014f40afca..6028d0c550f 100644 --- a/api/tests/unit_tests/core/ops/weave_trace/test_weave_trace.py +++ b/api/providers/trace/trace-weave/tests/unit_tests/weave_trace/test_weave_trace.py @@ -1,4 +1,4 @@ -"""Comprehensive tests for core.ops.weave_trace.weave_trace module.""" +"""Comprehensive tests for dify_trace_weave.weave_trace module.""" from __future__ import annotations @@ -7,10 +7,11 @@ from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey +from dify_trace_weave.config import WeaveConfig +from dify_trace_weave.entities.weave_trace_entity import WeaveTraceModel +from dify_trace_weave.weave_trace import WeaveDataTrace from weave.trace_server.trace_server_interface import TraceStatus -from core.ops.entities.config_entity import WeaveConfig from core.ops.entities.trace_entity import ( DatasetRetrievalTraceInfo, GenerateNameTraceInfo, @@ -21,8 +22,7 @@ from core.ops.entities.trace_entity import ( TraceTaskName, WorkflowTraceInfo, ) -from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel -from core.ops.weave_trace.weave_trace import WeaveDataTrace +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey # ── Helpers ────────────────────────────────────────────────────────────────── @@ -191,14 +191,14 @@ def _make_node(**overrides): @pytest.fixture def mock_wandb(): - with patch("core.ops.weave_trace.weave_trace.wandb") as mock: + with patch("dify_trace_weave.weave_trace.wandb") as mock: mock.login.return_value = True yield mock @pytest.fixture def mock_weave(): - with patch("core.ops.weave_trace.weave_trace.weave") as mock: + with patch("dify_trace_weave.weave_trace.weave") as mock: client = MagicMock() client.entity = "my-entity" client.project = "my-project" @@ -307,7 +307,7 @@ class TestGetProjectUrl: monkeypatch.setattr(trace_instance, "entity", None) monkeypatch.setattr(trace_instance, "project_name", None) # Force an error by making string formatting fail - with patch("core.ops.weave_trace.weave_trace.logger") as mock_logger: + with patch("dify_trace_weave.weave_trace.logger") as mock_logger: # Simulate exception via property original_entity = trace_instance.entity trace_instance.entity = None @@ -594,9 +594,9 @@ class TestWorkflowTrace: mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo - monkeypatch.setattr("core.ops.weave_trace.weave_trace.DifyCoreRepositoryFactory", mock_factory) - monkeypatch.setattr("core.ops.weave_trace.weave_trace.sessionmaker", lambda bind: MagicMock()) - monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", MagicMock(engine="engine")) + monkeypatch.setattr("dify_trace_weave.weave_trace.DifyCoreRepositoryFactory", mock_factory) + monkeypatch.setattr("dify_trace_weave.weave_trace.sessionmaker", lambda bind: MagicMock()) + monkeypatch.setattr("dify_trace_weave.weave_trace.db", MagicMock(engine="engine")) return repo def test_workflow_trace_no_nodes_no_message_id(self, trace_instance, monkeypatch): @@ -703,8 +703,8 @@ class TestWorkflowTrace: def test_workflow_trace_missing_app_id_raises(self, trace_instance, monkeypatch): """Raises ValueError when app_id is missing from metadata.""" - monkeypatch.setattr("core.ops.weave_trace.weave_trace.sessionmaker", lambda bind: MagicMock()) - monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", MagicMock(engine="engine")) + monkeypatch.setattr("dify_trace_weave.weave_trace.sessionmaker", lambda bind: MagicMock()) + monkeypatch.setattr("dify_trace_weave.weave_trace.db", MagicMock(engine="engine")) trace_info = _make_workflow_trace_info( message_id=None, @@ -802,7 +802,7 @@ class TestMessageTrace: def test_basic_message_trace(self, trace_instance, monkeypatch): """message_trace creates message run and llm child run.""" monkeypatch.setattr( - "core.ops.weave_trace.weave_trace.db.session.get", + "dify_trace_weave.weave_trace.db.session.get", lambda model, pk: None, ) @@ -824,7 +824,7 @@ class TestMessageTrace: mock_db = MagicMock() mock_db.session.get.return_value = None - monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", mock_db) + monkeypatch.setattr("dify_trace_weave.weave_trace.db", mock_db) trace_instance.start_call = MagicMock() trace_instance.finish_call = MagicMock() @@ -846,7 +846,7 @@ class TestMessageTrace: mock_db = MagicMock() mock_db.session.get.return_value = end_user - monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", mock_db) + monkeypatch.setattr("dify_trace_weave.weave_trace.db", mock_db) trace_instance.start_call = MagicMock() trace_instance.finish_call = MagicMock() @@ -866,7 +866,7 @@ class TestMessageTrace: """message_trace handles when from_end_user_id is None.""" mock_db = MagicMock() mock_db.session.get.return_value = None - monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", mock_db) + monkeypatch.setattr("dify_trace_weave.weave_trace.db", mock_db) trace_instance.start_call = MagicMock() trace_instance.finish_call = MagicMock() @@ -884,7 +884,7 @@ class TestMessageTrace: """trace_id falls back to message_id when trace_id is None.""" mock_db = MagicMock() mock_db.session.get.return_value = None - monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", mock_db) + monkeypatch.setattr("dify_trace_weave.weave_trace.db", mock_db) trace_instance.start_call = MagicMock() trace_instance.finish_call = MagicMock() @@ -899,7 +899,7 @@ class TestMessageTrace: """message_trace handles file_list=None gracefully.""" mock_db = MagicMock() mock_db.session.get.return_value = None - monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", mock_db) + monkeypatch.setattr("dify_trace_weave.weave_trace.db", mock_db) trace_instance.start_call = MagicMock() trace_instance.finish_call = MagicMock() diff --git a/api/providers/vdb/README.md b/api/providers/vdb/README.md new file mode 100644 index 00000000000..b5b4197f63c --- /dev/null +++ b/api/providers/vdb/README.md @@ -0,0 +1,58 @@ +# VDB providers + +This directory contains all VDB providers. + +## Architecture +1. **Core** (`api/core/rag/datasource/vdb/`) defines the contracts and loads plugins. +2. **Each provider** (`api/providers/vdb//`) implements those contracts and registers an entry point. +3. At runtime, **`importlib.metadata.entry_points`** resolves the backend name (e.g. `pgvector`) to a factory class. The registry caches loaded classes (see `vector_backend_registry.py`). + +### Interfaces + +| Piece | Role | +|--------|----------| +| `AbstractVectorFactory` | You subclass this. Implement `init_vector(dataset, attributes, embeddings) -> BaseVector`. Optionally use `gen_index_struct_dict()` for new datasets. | +| `BaseVector` | Your store class subclasses this: `create`, `add_texts`, `search_by_vector`, `delete`, etc. | +| `VectorType` | `StrEnum` of supported backend **string ids**. Add a member when you introduce a new backend that should be selectable like existing ones. | +| Discovery | Loads `dify.vector_backends` entry points and caches `get_vector_factory_class(vector_type)`. | + +The high-level caller is `Vector` in `vector_factory.py`: it reads the configured or dataset-specific vector type, calls `get_vector_factory_class`, instantiates the factory, and uses the returned `BaseVector` implementation. + +### Entry point name must match the vector type string + +Entry points are registered under the group **`dify.vector_backends`**. The **entry point name** (left-hand side) must be exactly the string used as `vector_type` everywhere else—typically the **`VectorType` enum value** (e.g. `PGVECTOR = "pgvector"` → entry point name `pgvector`; `TIDB_ON_QDRANT = "tidb_on_qdrant"` → `tidb_on_qdrant`). + +In `pyproject.toml`: + +```toml +[project.entry-points."dify.vector_backends"] +pgvector = "dify_vdb_pgvector.pgvector:PGVectorFactory" +``` + +The value is **`module:attribute`**: a importable module path and the class implementing `AbstractVectorFactory`. + +### How registration works + +1. On first use, `get_vector_factory_class(vector_type)` looks up `vector_type` in a process cache. +2. If missing, it scans **`entry_points().select(group="dify.vector_backends")`** for an entry whose **`name` equals `vector_type`**. +3. It loads that entry (`ep.load()`), which must return the **factory class** (not an instance). +4. There is an optional internal map `_BUILTIN_VECTOR_FACTORY_TARGETS` for non-distribution builtins; **normal VDB plugins use entry points only**. + +After you change a provider’s `pyproject.toml` (entry points or dependencies), run **`uv sync`** in `api/` so the installed environment’s dist-info matches the project metadata. + +### Package layout (VDB) + +Each backend usually follows: + +- `api/providers/vdb//pyproject.toml` — project name `dify-vdb-`, dependencies, entry points. +- `api/providers/vdb//src/dify_vdb_/` — implementation (e.g. `PGVector`, `PGVectorFactory`). + +See `vdb/pgvector/` as a reference implementation. + +### Wiring a new backend into the API workspace + +The API uses a **uv workspace** (`api/pyproject.toml`): + +1. **`[tool.uv.workspace]`** — `members = ["providers/vdb/*"]` already includes every subdirectory under `vdb/`; new folders there are workspace members. +2. **`[tool.uv.sources]`** — add a line for your package: `dify-vdb-mine = { workspace = true }`. +3. **`[project.optional-dependencies]`** — add a group such as `vdb-mine = ["dify-vdb-mine"]`, and list `dify-vdb-mine` under `vdb-all` if it should install with the default bundle. \ No newline at end of file diff --git a/api/providers/vdb/conftest.py b/api/providers/vdb/conftest.py new file mode 100644 index 00000000000..c4b1cdef29e --- /dev/null +++ b/api/providers/vdb/conftest.py @@ -0,0 +1,22 @@ +from unittest.mock import MagicMock + +import pytest + +from extensions import ext_redis + + +@pytest.fixture(autouse=True) +def _init_mock_redis(): + """Ensure redis_client has a backing client so __getattr__ never raises.""" + if ext_redis.redis_client._client is None: + ext_redis.redis_client.initialize(MagicMock()) + + +@pytest.fixture +def setup_mock_redis(monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr(ext_redis.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(ext_redis.redis_client, "set", MagicMock(return_value=None)) + mock_redis_lock = MagicMock() + mock_redis_lock.__enter__ = MagicMock() + mock_redis_lock.__exit__ = MagicMock() + monkeypatch.setattr(ext_redis.redis_client, "lock", mock_redis_lock) diff --git a/api/providers/vdb/vdb-alibabacloud-mysql/pyproject.toml b/api/providers/vdb/vdb-alibabacloud-mysql/pyproject.toml new file mode 100644 index 00000000000..bbc0e06ffab --- /dev/null +++ b/api/providers/vdb/vdb-alibabacloud-mysql/pyproject.toml @@ -0,0 +1,13 @@ +[project] +name = "dify-vdb-alibabacloud-mysql" +version = "0.0.1" +dependencies = [ + "mysql-connector-python>=9.3.0", +] +description = "Dify vector store backend (dify-vdb-alibabacloud-mysql)." + +[project.entry-points."dify.vector_backends"] +alibabacloud_mysql = "dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector:AlibabaCloudMySQLVectorFactory" + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/core/rag/datasource/vdb/iris/__init__.py b/api/providers/vdb/vdb-alibabacloud-mysql/src/dify_vdb_alibabacloud_mysql/__init__.py similarity index 100% rename from api/core/rag/datasource/vdb/iris/__init__.py rename to api/providers/vdb/vdb-alibabacloud-mysql/src/dify_vdb_alibabacloud_mysql/__init__.py diff --git a/api/core/rag/datasource/vdb/alibabacloud_mysql/alibabacloud_mysql_vector.py b/api/providers/vdb/vdb-alibabacloud-mysql/src/dify_vdb_alibabacloud_mysql/alibabacloud_mysql_vector.py similarity index 99% rename from api/core/rag/datasource/vdb/alibabacloud_mysql/alibabacloud_mysql_vector.py rename to api/providers/vdb/vdb-alibabacloud-mysql/src/dify_vdb_alibabacloud_mysql/alibabacloud_mysql_vector.py index 6e76827a422..37ffd110630 100644 --- a/api/core/rag/datasource/vdb/alibabacloud_mysql/alibabacloud_mysql_vector.py +++ b/api/providers/vdb/vdb-alibabacloud-mysql/src/dify_vdb_alibabacloud_mysql/alibabacloud_mysql_vector.py @@ -35,7 +35,7 @@ class AlibabaCloudMySQLVectorConfig(BaseModel): @model_validator(mode="before") @classmethod - def validate_config(cls, values: dict): + def validate_config(cls, values: dict[str, Any]): if not values.get("host"): raise ValueError("config ALIBABACLOUD_MYSQL_HOST is required") if not values.get("port"): diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/alibabacloud_mysql/test_alibabacloud_mysql_factory.py b/api/providers/vdb/vdb-alibabacloud-mysql/tests/unit_tests/test_alibabacloud_mysql_factory.py similarity index 94% rename from api/tests/unit_tests/core/rag/datasource/vdb/alibabacloud_mysql/test_alibabacloud_mysql_factory.py rename to api/providers/vdb/vdb-alibabacloud-mysql/tests/unit_tests/test_alibabacloud_mysql_factory.py index e063a49f22d..a907f918c36 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/alibabacloud_mysql/test_alibabacloud_mysql_factory.py +++ b/api/providers/vdb/vdb-alibabacloud-mysql/tests/unit_tests/test_alibabacloud_mysql_factory.py @@ -1,10 +1,9 @@ from types import SimpleNamespace from unittest.mock import MagicMock, patch +import dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector as alibaba_module import pytest - -import core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector as alibaba_module -from core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector import AlibabaCloudMySQLVectorFactory +from dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector import AlibabaCloudMySQLVectorFactory def test_validate_distance_function_accepts_supported_values(): diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/alibabacloud_mysql/test_alibabacloud_mysql_vector.py b/api/providers/vdb/vdb-alibabacloud-mysql/tests/unit_tests/test_alibabacloud_mysql_vector.py similarity index 87% rename from api/tests/unit_tests/core/rag/datasource/vdb/alibabacloud_mysql/test_alibabacloud_mysql_vector.py rename to api/providers/vdb/vdb-alibabacloud-mysql/tests/unit_tests/test_alibabacloud_mysql_vector.py index 8ccd739e64f..54eeb78ca98 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/alibabacloud_mysql/test_alibabacloud_mysql_vector.py +++ b/api/providers/vdb/vdb-alibabacloud-mysql/tests/unit_tests/test_alibabacloud_mysql_vector.py @@ -3,11 +3,11 @@ import unittest from unittest.mock import MagicMock, patch import pytest - -from core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector import ( +from dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector import ( AlibabaCloudMySQLVector, AlibabaCloudMySQLVectorConfig, ) + from core.rag.models.document import Document try: @@ -49,9 +49,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase): # Sample embeddings self.sample_embeddings = [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]] - @patch( - "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool" - ) + @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool") def test_init(self, mock_pool_class): """Test AlibabaCloudMySQLVector initialization.""" # Mock the connection pool @@ -76,10 +74,8 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase): assert alibabacloud_mysql_vector.distance_function == "cosine" assert alibabacloud_mysql_vector.pool is not None - @patch( - "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool" - ) - @patch("core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.redis_client") + @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool") + @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.redis_client") def test_create_collection(self, mock_redis, mock_pool_class): """Test collection creation.""" # Mock Redis operations @@ -110,9 +106,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase): assert mock_cursor.execute.call_count >= 3 # CREATE TABLE + 2 indexes mock_redis.set.assert_called_once() - @patch( - "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool" - ) + @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool") def test_vector_support_check_success(self, mock_pool_class): """Test successful vector support check.""" # Mock the connection pool @@ -129,9 +123,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase): vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config) assert vector_store is not None - @patch( - "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool" - ) + @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool") def test_vector_support_check_failure(self, mock_pool_class): """Test vector support check failure.""" # Mock the connection pool @@ -149,9 +141,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase): assert "RDS MySQL Vector functions are not available" in str(context.value) - @patch( - "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool" - ) + @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool") def test_vector_support_check_function_error(self, mock_pool_class): """Test vector support check with function not found error.""" # Mock the connection pool @@ -170,10 +160,8 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase): assert "RDS MySQL Vector functions are not available" in str(context.value) - @patch( - "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool" - ) - @patch("core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.redis_client") + @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool") + @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.redis_client") def test_create_documents(self, mock_redis, mock_pool_class): """Test creating documents with embeddings.""" # Setup mocks @@ -186,9 +174,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase): assert "doc1" in result assert "doc2" in result - @patch( - "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool" - ) + @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool") def test_add_texts(self, mock_pool_class): """Test adding texts to the vector store.""" # Mock the connection pool @@ -207,9 +193,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase): assert len(result) == 2 mock_cursor.executemany.assert_called_once() - @patch( - "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool" - ) + @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool") def test_text_exists(self, mock_pool_class): """Test checking if text exists.""" # Mock the connection pool @@ -236,9 +220,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase): assert "SELECT id FROM" in last_call[0][0] assert last_call[0][1] == ("doc1",) - @patch( - "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool" - ) + @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool") def test_text_not_exists(self, mock_pool_class): """Test checking if text does not exist.""" # Mock the connection pool @@ -260,9 +242,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase): assert not exists - @patch( - "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool" - ) + @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool") def test_get_by_ids(self, mock_pool_class): """Test getting documents by IDs.""" # Mock the connection pool @@ -288,9 +268,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase): assert docs[0].page_content == "Test document 1" assert docs[1].page_content == "Test document 2" - @patch( - "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool" - ) + @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool") def test_get_by_ids_empty_list(self, mock_pool_class): """Test getting documents with empty ID list.""" # Mock the connection pool @@ -308,9 +286,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase): assert len(docs) == 0 - @patch( - "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool" - ) + @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool") def test_delete_by_ids(self, mock_pool_class): """Test deleting documents by IDs.""" # Mock the connection pool @@ -334,9 +310,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase): assert "DELETE FROM" in delete_call[0][0] assert delete_call[0][1] == ["doc1", "doc2"] - @patch( - "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool" - ) + @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool") def test_delete_by_ids_empty_list(self, mock_pool_class): """Test deleting with empty ID list.""" # Mock the connection pool @@ -357,9 +331,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase): delete_calls = [call for call in execute_calls if "DELETE" in str(call)] assert len(delete_calls) == 0 - @patch( - "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool" - ) + @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool") def test_delete_by_ids_table_not_exists(self, mock_pool_class): """Test deleting when table doesn't exist.""" # Mock the connection pool @@ -384,9 +356,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase): # Should not raise an exception vector_store.delete_by_ids(["doc1"]) - @patch( - "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool" - ) + @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool") def test_delete_by_metadata_field(self, mock_pool_class): """Test deleting documents by metadata field.""" # Mock the connection pool @@ -410,9 +380,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase): assert "JSON_UNQUOTE(JSON_EXTRACT(meta" in delete_call[0][0] assert delete_call[0][1] == ("$.document_id", "dataset1") - @patch( - "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool" - ) + @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool") def test_search_by_vector_cosine(self, mock_pool_class): """Test vector search with cosine distance.""" # Mock the connection pool @@ -437,9 +405,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase): assert abs(docs[0].metadata["score"] - 0.9) < 0.1 # 1 - 0.1 = 0.9 assert docs[0].metadata["distance"] == 0.1 - @patch( - "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool" - ) + @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool") def test_search_by_vector_euclidean(self, mock_pool_class): """Test vector search with euclidean distance.""" config = AlibabaCloudMySQLVectorConfig( @@ -472,9 +438,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase): assert len(docs) == 1 assert abs(docs[0].metadata["score"] - 1.0 / 3.0) < 0.01 # 1/(1+2) = 1/3 - @patch( - "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool" - ) + @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool") def test_search_by_vector_with_filter(self, mock_pool_class): """Test vector search with document ID filter.""" # Mock the connection pool @@ -499,9 +463,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase): search_call = search_calls[0] assert "WHERE JSON_UNQUOTE" in search_call[0][0] - @patch( - "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool" - ) + @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool") def test_search_by_vector_with_score_threshold(self, mock_pool_class): """Test vector search with score threshold.""" # Mock the connection pool @@ -536,9 +498,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase): assert len(docs) == 1 assert docs[0].page_content == "High similarity document" - @patch( - "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool" - ) + @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool") def test_search_by_vector_invalid_top_k(self, mock_pool_class): """Test vector search with invalid top_k.""" # Mock the connection pool @@ -560,9 +520,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase): with pytest.raises(ValueError): vector_store.search_by_vector(query_vector, top_k="invalid") - @patch( - "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool" - ) + @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool") def test_search_by_full_text(self, mock_pool_class): """Test full-text search.""" # Mock the connection pool @@ -591,9 +549,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase): assert docs[0].page_content == "This document contains machine learning content" assert docs[0].metadata["score"] == 1.5 - @patch( - "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool" - ) + @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool") def test_search_by_full_text_with_filter(self, mock_pool_class): """Test full-text search with document ID filter.""" # Mock the connection pool @@ -617,9 +573,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase): search_call = search_calls[0] assert "AND JSON_UNQUOTE" in search_call[0][0] - @patch( - "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool" - ) + @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool") def test_search_by_full_text_invalid_top_k(self, mock_pool_class): """Test full-text search with invalid top_k.""" # Mock the connection pool @@ -640,9 +594,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase): with pytest.raises(ValueError): vector_store.search_by_full_text("test", top_k="invalid") - @patch( - "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool" - ) + @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool") def test_delete_collection(self, mock_pool_class): """Test deleting the entire collection.""" # Mock the connection pool @@ -665,9 +617,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase): drop_call = drop_calls[0] assert f"DROP TABLE IF EXISTS {self.collection_name.lower()}" in drop_call[0][0] - @patch( - "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool" - ) + @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool") def test_unsupported_distance_function(self, mock_pool_class): """Test that Pydantic validation rejects unsupported distance functions.""" # Test that creating config with unsupported distance function raises ValidationError diff --git a/api/providers/vdb/vdb-analyticdb/pyproject.toml b/api/providers/vdb/vdb-analyticdb/pyproject.toml new file mode 100644 index 00000000000..af5def3061d --- /dev/null +++ b/api/providers/vdb/vdb-analyticdb/pyproject.toml @@ -0,0 +1,15 @@ +[project] +name = "dify-vdb-analyticdb" +version = "0.0.1" +dependencies = [ + "alibabacloud_gpdb20160503~=5.2.0", + "alibabacloud_tea_openapi~=0.4.3", + "clickhouse-connect~=0.15.0", +] +description = "Dify vector store backend (dify-vdb-analyticdb)." + +[project.entry-points."dify.vector_backends"] +analyticdb = "dify_vdb_analyticdb.analyticdb_vector:AnalyticdbVectorFactory" + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/core/rag/datasource/vdb/lindorm/__init__.py b/api/providers/vdb/vdb-analyticdb/src/dify_vdb_analyticdb/__init__.py similarity index 100% rename from api/core/rag/datasource/vdb/lindorm/__init__.py rename to api/providers/vdb/vdb-analyticdb/src/dify_vdb_analyticdb/__init__.py diff --git a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py b/api/providers/vdb/vdb-analyticdb/src/dify_vdb_analyticdb/analyticdb_vector.py similarity index 93% rename from api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py rename to api/providers/vdb/vdb-analyticdb/src/dify_vdb_analyticdb/analyticdb_vector.py index ddb549ba9d2..e56bb74ba32 100644 --- a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py +++ b/api/providers/vdb/vdb-analyticdb/src/dify_vdb_analyticdb/analyticdb_vector.py @@ -2,16 +2,16 @@ import json from typing import Any from configs import dify_config -from core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi import ( - AnalyticdbVectorOpenAPI, - AnalyticdbVectorOpenAPIConfig, -) -from core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql import AnalyticdbVectorBySql, AnalyticdbVectorBySqlConfig from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document +from dify_vdb_analyticdb.analyticdb_vector_openapi import ( + AnalyticdbVectorOpenAPI, + AnalyticdbVectorOpenAPIConfig, +) +from dify_vdb_analyticdb.analyticdb_vector_sql import AnalyticdbVectorBySql, AnalyticdbVectorBySqlConfig from models.dataset import Dataset @@ -37,11 +37,12 @@ class AnalyticdbVector(BaseVector): def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): dimension = len(embeddings[0]) - self.analyticdb_vector._create_collection_if_not_exists(dimension) + self.analyticdb_vector.create_collection_if_not_exists(dimension) self.analyticdb_vector.add_texts(texts, embeddings) - def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): + def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs) -> list[str]: self.analyticdb_vector.add_texts(documents, embeddings) + return [] def text_exists(self, id: str) -> bool: return self.analyticdb_vector.text_exists(id) diff --git a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py b/api/providers/vdb/vdb-analyticdb/src/dify_vdb_analyticdb/analyticdb_vector_openapi.py similarity index 96% rename from api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py rename to api/providers/vdb/vdb-analyticdb/src/dify_vdb_analyticdb/analyticdb_vector_openapi.py index ce626bbd7e1..f13d9c0817a 100644 --- a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py +++ b/api/providers/vdb/vdb-analyticdb/src/dify_vdb_analyticdb/analyticdb_vector_openapi.py @@ -1,5 +1,5 @@ import json -from typing import Any +from typing import Any, TypedDict from pydantic import BaseModel, model_validator @@ -13,6 +13,13 @@ from core.rag.models.document import Document from extensions.ext_redis import redis_client +class AnalyticdbClientParamsDict(TypedDict): + access_key_id: str + access_key_secret: str + region_id: str + read_timeout: int + + class AnalyticdbVectorOpenAPIConfig(BaseModel): access_key_id: str access_key_secret: str @@ -27,7 +34,7 @@ class AnalyticdbVectorOpenAPIConfig(BaseModel): @model_validator(mode="before") @classmethod - def validate_config(cls, values: dict): + def validate_config(cls, values: dict[str, Any]): if not values["access_key_id"]: raise ValueError("config ANALYTICDB_KEY_ID is required") if not values["access_key_secret"]: @@ -44,13 +51,14 @@ class AnalyticdbVectorOpenAPIConfig(BaseModel): raise ValueError("config ANALYTICDB_NAMESPACE_PASSWORD is required") return values - def to_analyticdb_client_params(self): - return { + def to_analyticdb_client_params(self) -> AnalyticdbClientParamsDict: + result: AnalyticdbClientParamsDict = { "access_key_id": self.access_key_id, "access_key_secret": self.access_key_secret, "region_id": self.region_id, "read_timeout": self.read_timeout, } + return result class AnalyticdbVectorOpenAPI: @@ -115,7 +123,7 @@ class AnalyticdbVectorOpenAPI: else: raise ValueError(f"failed to create namespace {self.config.namespace}: {e}") - def _create_collection_if_not_exists(self, embedding_dimension: int): + def create_collection_if_not_exists(self, embedding_dimension: int): from alibabacloud_gpdb20160503 import models as gpdb_20160503_models from Tea.exceptions import TeaException diff --git a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py b/api/providers/vdb/vdb-analyticdb/src/dify_vdb_analyticdb/analyticdb_vector_sql.py similarity index 97% rename from api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py rename to api/providers/vdb/vdb-analyticdb/src/dify_vdb_analyticdb/analyticdb_vector_sql.py index 12126f32d66..11398efb586 100644 --- a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py +++ b/api/providers/vdb/vdb-analyticdb/src/dify_vdb_analyticdb/analyticdb_vector_sql.py @@ -1,5 +1,6 @@ import json import uuid +from collections.abc import Generator # Added Generator from contextlib import contextmanager from typing import Any @@ -23,7 +24,7 @@ class AnalyticdbVectorBySqlConfig(BaseModel): @model_validator(mode="before") @classmethod - def validate_config(cls, values: dict): + def validate_config(cls, values: dict[str, Any]): if not values["host"]: raise ValueError("config ANALYTICDB_HOST is required") if not values["port"]: @@ -74,7 +75,7 @@ class AnalyticdbVectorBySql: ) @contextmanager - def _get_cursor(self): + def _get_cursor(self) -> Generator[Any, None, None]: # Changed from Iterator[Any] assert self.pool is not None, "Connection pool is not initialized" conn = self.pool.getconn() cur = conn.cursor() @@ -130,7 +131,7 @@ class AnalyticdbVectorBySql: ) cur.execute(f"CREATE SCHEMA IF NOT EXISTS {self.config.namespace}") - def _create_collection_if_not_exists(self, embedding_dimension: int): + def create_collection_if_not_exists(self, embedding_dimension: int): cache_key = f"vector_indexing_{self._collection_name}" lock_name = f"{cache_key}_lock" with redis_client.lock(lock_name, timeout=20): diff --git a/api/tests/integration_tests/vdb/analyticdb/test_analyticdb.py b/api/providers/vdb/vdb-analyticdb/tests/integration_tests/test_analyticdb.py similarity index 79% rename from api/tests/integration_tests/vdb/analyticdb/test_analyticdb.py rename to api/providers/vdb/vdb-analyticdb/tests/integration_tests/test_analyticdb.py index 09815238097..2bb413dcc12 100644 --- a/api/tests/integration_tests/vdb/analyticdb/test_analyticdb.py +++ b/api/providers/vdb/vdb-analyticdb/tests/integration_tests/test_analyticdb.py @@ -1,9 +1,8 @@ -from core.rag.datasource.vdb.analyticdb.analyticdb_vector import AnalyticdbVector -from core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi import AnalyticdbVectorOpenAPIConfig -from core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql import AnalyticdbVectorBySqlConfig -from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest +from dify_vdb_analyticdb.analyticdb_vector import AnalyticdbVector +from dify_vdb_analyticdb.analyticdb_vector_openapi import AnalyticdbVectorOpenAPIConfig +from dify_vdb_analyticdb.analyticdb_vector_sql import AnalyticdbVectorBySqlConfig -pytest_plugins = ("tests.integration_tests.vdb.test_vector_store",) +from core.rag.datasource.vdb.vector_integration_test_support import AbstractVectorTest class AnalyticdbVectorTest(AbstractVectorTest): diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/analyticdb/test_analyticdb_vector.py b/api/providers/vdb/vdb-analyticdb/tests/unit_tests/test_analyticdb_vector.py similarity index 92% rename from api/tests/unit_tests/core/rag/datasource/vdb/analyticdb/test_analyticdb_vector.py rename to api/providers/vdb/vdb-analyticdb/tests/unit_tests/test_analyticdb_vector.py index 545565cdf4a..d1d471761d8 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/analyticdb/test_analyticdb_vector.py +++ b/api/providers/vdb/vdb-analyticdb/tests/unit_tests/test_analyticdb_vector.py @@ -1,12 +1,12 @@ from types import SimpleNamespace from unittest.mock import MagicMock, patch +import dify_vdb_analyticdb.analyticdb_vector as analyticdb_module import pytest +from dify_vdb_analyticdb.analyticdb_vector import AnalyticdbVector, AnalyticdbVectorFactory +from dify_vdb_analyticdb.analyticdb_vector_openapi import AnalyticdbVectorOpenAPIConfig +from dify_vdb_analyticdb.analyticdb_vector_sql import AnalyticdbVectorBySqlConfig -import core.rag.datasource.vdb.analyticdb.analyticdb_vector as analyticdb_module -from core.rag.datasource.vdb.analyticdb.analyticdb_vector import AnalyticdbVector, AnalyticdbVectorFactory -from core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi import AnalyticdbVectorOpenAPIConfig -from core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql import AnalyticdbVectorBySqlConfig from core.rag.models.document import Document @@ -71,7 +71,7 @@ def test_vector_methods_delegate_to_underlying_implementation(): assert vector.search_by_full_text("hello", top_k=2) == runner.search_by_full_text.return_value vector.delete() - runner._create_collection_if_not_exists.assert_called_once_with(2) + runner.create_collection_if_not_exists.assert_called_once_with(2) runner.add_texts.assert_any_call(texts, [[0.1, 0.2]]) runner.delete_by_ids.assert_called_once_with(["d1"]) runner.delete_by_metadata_field.assert_called_once_with("document_id", "doc-1") diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/analyticdb/test_analyticdb_vector_openapi.py b/api/providers/vdb/vdb-analyticdb/tests/unit_tests/test_analyticdb_vector_openapi.py similarity index 97% rename from api/tests/unit_tests/core/rag/datasource/vdb/analyticdb/test_analyticdb_vector_openapi.py rename to api/providers/vdb/vdb-analyticdb/tests/unit_tests/test_analyticdb_vector_openapi.py index 45777774d03..d2d735ae3ea 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/analyticdb/test_analyticdb_vector_openapi.py +++ b/api/providers/vdb/vdb-analyticdb/tests/unit_tests/test_analyticdb_vector_openapi.py @@ -4,13 +4,13 @@ import types from types import SimpleNamespace from unittest.mock import MagicMock +import dify_vdb_analyticdb.analyticdb_vector_openapi as openapi_module import pytest - -import core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi as openapi_module -from core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi import ( +from dify_vdb_analyticdb.analyticdb_vector_openapi import ( AnalyticdbVectorOpenAPI, AnalyticdbVectorOpenAPIConfig, ) + from core.rag.models.document import Document @@ -249,7 +249,7 @@ def test_create_collection_if_not_exists_creates_when_missing(monkeypatch): vector._client = MagicMock() vector._client.describe_collection.side_effect = stubs.TeaException(statusCode=404) - vector._create_collection_if_not_exists(embedding_dimension=1024) + vector.create_collection_if_not_exists(embedding_dimension=1024) vector._client.create_collection.assert_called_once() openapi_module.redis_client.set.assert_called_once() @@ -268,7 +268,7 @@ def test_create_collection_if_not_exists_skips_when_cached(monkeypatch): vector.config = _config() vector._client = MagicMock() - vector._create_collection_if_not_exists(embedding_dimension=1024) + vector.create_collection_if_not_exists(embedding_dimension=1024) vector._client.describe_collection.assert_not_called() vector._client.create_collection.assert_not_called() @@ -290,7 +290,7 @@ def test_create_collection_if_not_exists_raises_on_non_404_errors(monkeypatch): vector._client.describe_collection.side_effect = stubs.TeaException(statusCode=500) with pytest.raises(ValueError, match="failed to create collection collection_1"): - vector._create_collection_if_not_exists(embedding_dimension=512) + vector.create_collection_if_not_exists(embedding_dimension=512) def test_openapi_add_delete_and_search_methods(monkeypatch): diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/analyticdb/test_analyticdb_vector_sql.py b/api/providers/vdb/vdb-analyticdb/tests/unit_tests/test_analyticdb_vector_sql.py similarity index 98% rename from api/tests/unit_tests/core/rag/datasource/vdb/analyticdb/test_analyticdb_vector_sql.py rename to api/providers/vdb/vdb-analyticdb/tests/unit_tests/test_analyticdb_vector_sql.py index 8f1206696bb..49a2ae72d0a 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/analyticdb/test_analyticdb_vector_sql.py +++ b/api/providers/vdb/vdb-analyticdb/tests/unit_tests/test_analyticdb_vector_sql.py @@ -2,14 +2,14 @@ from contextlib import contextmanager from types import SimpleNamespace from unittest.mock import MagicMock +import dify_vdb_analyticdb.analyticdb_vector_sql as sql_module import psycopg2.errors import pytest - -import core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql as sql_module -from core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql import ( +from dify_vdb_analyticdb.analyticdb_vector_sql import ( AnalyticdbVectorBySql, AnalyticdbVectorBySqlConfig, ) + from core.rag.models.document import Document @@ -374,7 +374,7 @@ def test_create_collection_if_not_exists_creates_table_indexes_and_cache(monkeyp vector._get_cursor = _cursor_context - vector._create_collection_if_not_exists(embedding_dimension=3) + vector.create_collection_if_not_exists(embedding_dimension=3) assert any("CREATE TABLE IF NOT EXISTS dify.collection" in call.args[0] for call in cursor.execute.call_args_list) assert any("CREATE INDEX collection_embedding_idx" in call.args[0] for call in cursor.execute.call_args_list) @@ -404,7 +404,7 @@ def test_create_collection_if_not_exists_raises_for_non_existing_error(monkeypat vector._get_cursor = _cursor_context with pytest.raises(RuntimeError, match="permission denied"): - vector._create_collection_if_not_exists(embedding_dimension=3) + vector.create_collection_if_not_exists(embedding_dimension=3) def test_delete_methods_raise_when_error_is_not_missing_table(): diff --git a/api/providers/vdb/vdb-baidu/pyproject.toml b/api/providers/vdb/vdb-baidu/pyproject.toml new file mode 100644 index 00000000000..bacff08793a --- /dev/null +++ b/api/providers/vdb/vdb-baidu/pyproject.toml @@ -0,0 +1,13 @@ +[project] +name = "dify-vdb-baidu" +version = "0.0.1" +dependencies = [ + "pymochow==2.4.0", +] +description = "Dify vector store backend (dify-vdb-baidu)." + +[project.entry-points."dify.vector_backends"] +baidu = "dify_vdb_baidu.baidu_vector:BaiduVectorFactory" + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/core/rag/datasource/vdb/matrixone/__init__.py b/api/providers/vdb/vdb-baidu/src/dify_vdb_baidu/__init__.py similarity index 100% rename from api/core/rag/datasource/vdb/matrixone/__init__.py rename to api/providers/vdb/vdb-baidu/src/dify_vdb_baidu/__init__.py diff --git a/api/core/rag/datasource/vdb/baidu/baidu_vector.py b/api/providers/vdb/vdb-baidu/src/dify_vdb_baidu/baidu_vector.py similarity index 97% rename from api/core/rag/datasource/vdb/baidu/baidu_vector.py rename to api/providers/vdb/vdb-baidu/src/dify_vdb_baidu/baidu_vector.py index 2b220fc04dc..bdd5a42c873 100644 --- a/api/core/rag/datasource/vdb/baidu/baidu_vector.py +++ b/api/providers/vdb/vdb-baidu/src/dify_vdb_baidu/baidu_vector.py @@ -30,7 +30,7 @@ from pymochow.model.table import AnnSearch, BM25SearchRequest, HNSWSearchParams, from configs import dify_config from core.rag.datasource.vdb.field import Field as VDBField from core.rag.datasource.vdb.field import parse_metadata_json -from core.rag.datasource.vdb.vector_base import BaseVector +from core.rag.datasource.vdb.vector_base import BaseVector, VectorIndexStructDict from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType from core.rag.embedding.embedding_base import Embeddings @@ -59,7 +59,7 @@ class BaiduConfig(BaseModel): @model_validator(mode="before") @classmethod - def validate_config(cls, values: dict): + def validate_config(cls, values: dict[str, Any]): if not values["endpoint"]: raise ValueError("config BAIDU_VECTOR_DB_ENDPOINT is required") if not values["account"]: @@ -85,8 +85,12 @@ class BaiduVector(BaseVector): def get_type(self) -> str: return VectorType.BAIDU - def to_index_struct(self): - return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}} + def to_index_struct(self) -> VectorIndexStructDict: + result: VectorIndexStructDict = { + "type": self.get_type(), + "vector_store": {"class_prefix": self._collection_name}, + } + return result def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): self._create_table(len(embeddings[0])) diff --git a/api/tests/integration_tests/vdb/__mock/baiduvectordb.py b/api/providers/vdb/vdb-baidu/tests/integration_tests/conftest.py similarity index 100% rename from api/tests/integration_tests/vdb/__mock/baiduvectordb.py rename to api/providers/vdb/vdb-baidu/tests/integration_tests/conftest.py diff --git a/api/tests/integration_tests/vdb/baidu/test_baidu.py b/api/providers/vdb/vdb-baidu/tests/integration_tests/test_baidu.py similarity index 73% rename from api/tests/integration_tests/vdb/baidu/test_baidu.py rename to api/providers/vdb/vdb-baidu/tests/integration_tests/test_baidu.py index 716f88af672..2c1d0e3554a 100644 --- a/api/tests/integration_tests/vdb/baidu/test_baidu.py +++ b/api/providers/vdb/vdb-baidu/tests/integration_tests/test_baidu.py @@ -1,10 +1,6 @@ -from core.rag.datasource.vdb.baidu.baidu_vector import BaiduConfig, BaiduVector -from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text +from dify_vdb_baidu.baidu_vector import BaiduConfig, BaiduVector -pytest_plugins = ( - "tests.integration_tests.vdb.test_vector_store", - "tests.integration_tests.vdb.__mock.baiduvectordb", -) +from core.rag.datasource.vdb.vector_integration_test_support import AbstractVectorTest, get_example_text class BaiduVectorTest(AbstractVectorTest): diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/baidu/test_baidu_vector.py b/api/providers/vdb/vdb-baidu/tests/unit_tests/test_baidu_vector.py similarity index 99% rename from api/tests/unit_tests/core/rag/datasource/vdb/baidu/test_baidu_vector.py rename to api/providers/vdb/vdb-baidu/tests/unit_tests/test_baidu_vector.py index 487d0216970..851c09f47a3 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/baidu/test_baidu_vector.py +++ b/api/providers/vdb/vdb-baidu/tests/unit_tests/test_baidu_vector.py @@ -124,7 +124,7 @@ def _build_fake_pymochow_modules(): def baidu_module(monkeypatch): for name, module in _build_fake_pymochow_modules().items(): monkeypatch.setitem(sys.modules, name, module) - import core.rag.datasource.vdb.baidu.baidu_vector as module + import dify_vdb_baidu.baidu_vector as module return importlib.reload(module) diff --git a/api/providers/vdb/vdb-chroma/pyproject.toml b/api/providers/vdb/vdb-chroma/pyproject.toml new file mode 100644 index 00000000000..b37ee2a588f --- /dev/null +++ b/api/providers/vdb/vdb-chroma/pyproject.toml @@ -0,0 +1,13 @@ +[project] +name = "dify-vdb-chroma" +version = "0.0.1" +dependencies = [ + "chromadb==0.5.20", +] +description = "Dify vector store backend (dify-vdb-chroma)." + +[project.entry-points."dify.vector_backends"] +chroma = "dify_vdb_chroma.chroma_vector:ChromaVectorFactory" + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/core/rag/datasource/vdb/milvus/__init__.py b/api/providers/vdb/vdb-chroma/src/dify_vdb_chroma/__init__.py similarity index 100% rename from api/core/rag/datasource/vdb/milvus/__init__.py rename to api/providers/vdb/vdb-chroma/src/dify_vdb_chroma/__init__.py diff --git a/api/core/rag/datasource/vdb/chroma/chroma_vector.py b/api/providers/vdb/vdb-chroma/src/dify_vdb_chroma/chroma_vector.py similarity index 86% rename from api/core/rag/datasource/vdb/chroma/chroma_vector.py rename to api/providers/vdb/vdb-chroma/src/dify_vdb_chroma/chroma_vector.py index cbc846f7162..5b0cfbea151 100644 --- a/api/core/rag/datasource/vdb/chroma/chroma_vector.py +++ b/api/providers/vdb/vdb-chroma/src/dify_vdb_chroma/chroma_vector.py @@ -1,12 +1,12 @@ import json -from typing import Any +from typing import Any, TypedDict import chromadb -from chromadb import QueryResult, Settings +from chromadb import QueryResult, Settings # pyright: ignore[reportPrivateImportUsage] from pydantic import BaseModel from configs import dify_config -from core.rag.datasource.vdb.vector_base import BaseVector +from core.rag.datasource.vdb.vector_base import BaseVector, VectorIndexStructDict from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType from core.rag.embedding.embedding_base import Embeddings @@ -15,6 +15,15 @@ from extensions.ext_redis import redis_client from models.dataset import Dataset +class ChromaParamsDict(TypedDict): + host: str + port: int + ssl: bool + tenant: str + database: str + settings: Settings + + class ChromaConfig(BaseModel): host: str port: int @@ -23,14 +32,13 @@ class ChromaConfig(BaseModel): auth_provider: str | None = None auth_credentials: str | None = None - def to_chroma_params(self): + def to_chroma_params(self) -> ChromaParamsDict: settings = Settings( # auth chroma_client_auth_provider=self.auth_provider, chroma_client_auth_credentials=self.auth_credentials, ) - - return { + result: ChromaParamsDict = { "host": self.host, "port": self.port, "ssl": False, @@ -38,6 +46,7 @@ class ChromaConfig(BaseModel): "database": self.database, "settings": settings, } + return result class ChromaVector(BaseVector): @@ -97,14 +106,15 @@ class ChromaVector(BaseVector): def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: collection = self._client.get_or_create_collection(self._collection_name) document_ids_filter = kwargs.get("document_ids_filter") + results: QueryResult if document_ids_filter: - results: QueryResult = collection.query( + results = collection.query( query_embeddings=query_vector, n_results=kwargs.get("top_k", 4), where={"document_id": {"$in": document_ids_filter}}, # type: ignore ) else: - results: QueryResult = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4)) # type: ignore + results = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4)) # type: ignore score_threshold = float(kwargs.get("score_threshold") or 0.0) # Check if results contain data @@ -145,7 +155,10 @@ class ChromaVectorFactory(AbstractVectorFactory): else: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower() - index_struct_dict = {"type": VectorType.CHROMA, "vector_store": {"class_prefix": collection_name}} + index_struct_dict: VectorIndexStructDict = { + "type": VectorType.CHROMA, + "vector_store": {"class_prefix": collection_name}, + } dataset.index_struct = json.dumps(index_struct_dict) return ChromaVector( @@ -153,8 +166,8 @@ class ChromaVectorFactory(AbstractVectorFactory): config=ChromaConfig( host=dify_config.CHROMA_HOST or "", port=dify_config.CHROMA_PORT, - tenant=dify_config.CHROMA_TENANT or chromadb.DEFAULT_TENANT, - database=dify_config.CHROMA_DATABASE or chromadb.DEFAULT_DATABASE, + tenant=dify_config.CHROMA_TENANT or chromadb.DEFAULT_TENANT, # pyright: ignore[reportPrivateImportUsage] + database=dify_config.CHROMA_DATABASE or chromadb.DEFAULT_DATABASE, # pyright: ignore[reportPrivateImportUsage] auth_provider=dify_config.CHROMA_AUTH_PROVIDER, auth_credentials=dify_config.CHROMA_AUTH_CREDENTIALS, ), diff --git a/api/tests/integration_tests/vdb/chroma/test_chroma.py b/api/providers/vdb/vdb-chroma/tests/integration_tests/test_chroma.py similarity index 80% rename from api/tests/integration_tests/vdb/chroma/test_chroma.py rename to api/providers/vdb/vdb-chroma/tests/integration_tests/test_chroma.py index 52beba9979d..87c259f3d00 100644 --- a/api/tests/integration_tests/vdb/chroma/test_chroma.py +++ b/api/providers/vdb/vdb-chroma/tests/integration_tests/test_chroma.py @@ -1,13 +1,11 @@ import chromadb +from dify_vdb_chroma.chroma_vector import ChromaConfig, ChromaVector -from core.rag.datasource.vdb.chroma.chroma_vector import ChromaConfig, ChromaVector -from tests.integration_tests.vdb.test_vector_store import ( +from core.rag.datasource.vdb.vector_integration_test_support import ( AbstractVectorTest, get_example_text, ) -pytest_plugins = ("tests.integration_tests.vdb.test_vector_store",) - class ChromaVectorTest(AbstractVectorTest): def __init__(self): diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/chroma/test_chroma_vector.py b/api/providers/vdb/vdb-chroma/tests/unit_tests/test_chroma_vector.py similarity index 99% rename from api/tests/unit_tests/core/rag/datasource/vdb/chroma/test_chroma_vector.py rename to api/providers/vdb/vdb-chroma/tests/unit_tests/test_chroma_vector.py index 44427b7d879..b209c9df967 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/chroma/test_chroma_vector.py +++ b/api/providers/vdb/vdb-chroma/tests/unit_tests/test_chroma_vector.py @@ -47,7 +47,7 @@ def _build_fake_chroma_modules(): def chroma_module(monkeypatch): fake_chroma = _build_fake_chroma_modules() monkeypatch.setitem(sys.modules, "chromadb", fake_chroma) - import core.rag.datasource.vdb.chroma.chroma_vector as module + import dify_vdb_chroma.chroma_vector as module return importlib.reload(module) diff --git a/api/core/rag/datasource/vdb/clickzetta/README.md b/api/providers/vdb/vdb-clickzetta/README.md similarity index 99% rename from api/core/rag/datasource/vdb/clickzetta/README.md rename to api/providers/vdb/vdb-clickzetta/README.md index 969d4e40a09..faa76707ce0 100644 --- a/api/core/rag/datasource/vdb/clickzetta/README.md +++ b/api/providers/vdb/vdb-clickzetta/README.md @@ -198,4 +198,4 @@ Clickzetta supports advanced full-text search with multiple analyzers: - [Clickzetta Vector Search Documentation](https://yunqi.tech/documents/vector-search) - [Clickzetta Inverted Index Documentation](https://yunqi.tech/documents/inverted-index) -- [Clickzetta SQL Functions](https://yunqi.tech/documents/sql-reference) +- [Clickzetta SQL Functions](https://yunqi.tech/documents/sql-reference) \ No newline at end of file diff --git a/api/providers/vdb/vdb-clickzetta/pyproject.toml b/api/providers/vdb/vdb-clickzetta/pyproject.toml new file mode 100644 index 00000000000..aea94fdb2a1 --- /dev/null +++ b/api/providers/vdb/vdb-clickzetta/pyproject.toml @@ -0,0 +1,14 @@ +[project] +name = "dify-vdb-clickzetta" +version = "0.0.1" + +dependencies = [ + "clickzetta-connector-python>=0.8.102", +] +description = "Dify vector store backend (dify-vdb-clickzetta)." + +[project.entry-points."dify.vector_backends"] +clickzetta = "dify_vdb_clickzetta.clickzetta_vector:ClickzettaVectorFactory" + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/core/rag/datasource/vdb/clickzetta/__init__.py b/api/providers/vdb/vdb-clickzetta/src/dify_vdb_clickzetta/__init__.py similarity index 100% rename from api/core/rag/datasource/vdb/clickzetta/__init__.py rename to api/providers/vdb/vdb-clickzetta/src/dify_vdb_clickzetta/__init__.py diff --git a/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py b/api/providers/vdb/vdb-clickzetta/src/dify_vdb_clickzetta/clickzetta_vector.py similarity index 99% rename from api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py rename to api/providers/vdb/vdb-clickzetta/src/dify_vdb_clickzetta/clickzetta_vector.py index a4dddc68f07..72b8c5e9ebf 100644 --- a/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py +++ b/api/providers/vdb/vdb-clickzetta/src/dify_vdb_clickzetta/clickzetta_vector.py @@ -51,7 +51,7 @@ class ClickzettaConfig(BaseModel): @model_validator(mode="before") @classmethod - def validate_config(cls, values: dict): + def validate_config(cls, values: dict[str, Any]): """ Validate the configuration values. """ diff --git a/api/tests/integration_tests/vdb/clickzetta/README.md b/api/providers/vdb/vdb-clickzetta/tests/README.md similarity index 100% rename from api/tests/integration_tests/vdb/clickzetta/README.md rename to api/providers/vdb/vdb-clickzetta/tests/README.md diff --git a/api/tests/integration_tests/vdb/clickzetta/test_clickzetta.py b/api/providers/vdb/vdb-clickzetta/tests/integration_tests/test_clickzetta.py similarity index 92% rename from api/tests/integration_tests/vdb/clickzetta/test_clickzetta.py rename to api/providers/vdb/vdb-clickzetta/tests/integration_tests/test_clickzetta.py index 21de8be6e3a..1c6819f9f10 100644 --- a/api/tests/integration_tests/vdb/clickzetta/test_clickzetta.py +++ b/api/providers/vdb/vdb-clickzetta/tests/integration_tests/test_clickzetta.py @@ -2,10 +2,10 @@ import contextlib import os import pytest +from dify_vdb_clickzetta.clickzetta_vector import ClickzettaConfig, ClickzettaVector -from core.rag.datasource.vdb.clickzetta.clickzetta_vector import ClickzettaConfig, ClickzettaVector +from core.rag.datasource.vdb.vector_integration_test_support import AbstractVectorTest, get_example_text from core.rag.models.document import Document -from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text, setup_mock_redis class TestClickzettaVector(AbstractVectorTest): @@ -14,9 +14,8 @@ class TestClickzettaVector(AbstractVectorTest): """ @pytest.fixture - def vector_store(self): + def vector_store(self, setup_mock_redis): """Create a Clickzetta vector store instance for testing.""" - # Skip test if Clickzetta credentials are not configured if not os.getenv("CLICKZETTA_USERNAME"): pytest.skip("CLICKZETTA_USERNAME is not configured") if not os.getenv("CLICKZETTA_PASSWORD"): @@ -32,21 +31,19 @@ class TestClickzettaVector(AbstractVectorTest): workspace=os.getenv("CLICKZETTA_WORKSPACE", "quick_start"), vcluster=os.getenv("CLICKZETTA_VCLUSTER", "default_ap"), schema=os.getenv("CLICKZETTA_SCHEMA", "dify_test"), - batch_size=10, # Small batch size for testing + batch_size=10, enable_inverted_index=True, analyzer_type="chinese", analyzer_mode="smart", vector_distance_function="cosine_distance", ) - with setup_mock_redis(): - vector = ClickzettaVector(collection_name="test_collection_" + str(os.getpid()), config=config) + vector = ClickzettaVector(collection_name="test_collection_" + str(os.getpid()), config=config) - yield vector + yield vector - # Cleanup: delete the test collection - with contextlib.suppress(Exception): - vector.delete() + with contextlib.suppress(Exception): + vector.delete() def test_clickzetta_vector_basic_operations(self, vector_store): """Test basic CRUD operations on Clickzetta vector store.""" diff --git a/api/tests/integration_tests/vdb/clickzetta/test_docker_integration.py b/api/providers/vdb/vdb-clickzetta/tests/integration_tests/test_docker_integration.py similarity index 55% rename from api/tests/integration_tests/vdb/clickzetta/test_docker_integration.py rename to api/providers/vdb/vdb-clickzetta/tests/integration_tests/test_docker_integration.py index 60e3f30f26a..a5d32f5e81c 100644 --- a/api/tests/integration_tests/vdb/clickzetta/test_docker_integration.py +++ b/api/providers/vdb/vdb-clickzetta/tests/integration_tests/test_docker_integration.py @@ -3,16 +3,19 @@ Test Clickzetta integration in Docker environment """ +import logging import os import time import httpx from clickzetta import connect +logger = logging.getLogger(__name__) + def test_clickzetta_connection(): """Test direct connection to Clickzetta""" - print("=== Testing direct Clickzetta connection ===") + logger.info("=== Testing direct Clickzetta connection ===") try: conn = connect( username=os.getenv("CLICKZETTA_USERNAME", "test_user"), @@ -25,100 +28,93 @@ def test_clickzetta_connection(): ) with conn.cursor() as cursor: - # Test basic connectivity cursor.execute("SELECT 1 as test") result = cursor.fetchone() - print(f"✓ Connection test: {result}") + logger.info("✓ Connection test: %s", result) - # Check if our test table exists cursor.execute("SHOW TABLES IN dify") tables = cursor.fetchall() - print(f"✓ Existing tables: {[t[1] for t in tables if t[0] == 'dify']}") + logger.info("✓ Existing tables: %s", [t[1] for t in tables if t[0] == "dify"]) - # Check if test collection exists test_collection = "collection_test_dataset" if test_collection in [t[1] for t in tables if t[0] == "dify"]: cursor.execute(f"DESCRIBE dify.{test_collection}") columns = cursor.fetchall() - print(f"✓ Table structure for {test_collection}:") + logger.info("✓ Table structure for %s:", test_collection) for col in columns: - print(f" - {col[0]}: {col[1]}") + logger.info(" - %s: %s", col[0], col[1]) - # Check for indexes cursor.execute(f"SHOW INDEXES IN dify.{test_collection}") indexes = cursor.fetchall() - print(f"✓ Indexes on {test_collection}:") + logger.info("✓ Indexes on %s:", test_collection) for idx in indexes: - print(f" - {idx}") + logger.info(" - %s", idx) return True - except Exception as e: - print(f"✗ Connection test failed: {e}") + except Exception: + logger.exception("✗ Connection test failed") return False def test_dify_api(): """Test Dify API with Clickzetta backend""" - print("\n=== Testing Dify API ===") + logger.info("\n=== Testing Dify API ===") base_url = "http://localhost:5001" - # Wait for API to be ready max_retries = 30 for i in range(max_retries): try: response = httpx.get(f"{base_url}/console/api/health") if response.status_code == 200: - print("✓ Dify API is ready") + logger.info("✓ Dify API is ready") break except: if i == max_retries - 1: - print("✗ Dify API is not responding") + logger.exception("✗ Dify API is not responding") return False time.sleep(2) - # Check vector store configuration try: - # This is a simplified check - in production, you'd use proper auth - print("✓ Dify is configured to use Clickzetta as vector store") + logger.info("✓ Dify is configured to use Clickzetta as vector store") return True - except Exception as e: - print(f"✗ API test failed: {e}") + except Exception: + logger.exception("✗ API test failed") return False def verify_table_structure(): """Verify the table structure meets Dify requirements""" - print("\n=== Verifying Table Structure ===") + logger.info("\n=== Verifying Table Structure ===") expected_columns = { "id": "VARCHAR", "page_content": "VARCHAR", - "metadata": "VARCHAR", # JSON stored as VARCHAR in Clickzetta + "metadata": "VARCHAR", "vector": "ARRAY", } expected_metadata_fields = ["doc_id", "doc_hash", "document_id", "dataset_id"] - print("✓ Expected table structure:") + logger.info("✓ Expected table structure:") for col, dtype in expected_columns.items(): - print(f" - {col}: {dtype}") + logger.info(" - %s: %s", col, dtype) - print("\n✓ Required metadata fields:") + logger.info("\n✓ Required metadata fields:") for field in expected_metadata_fields: - print(f" - {field}") + logger.info(" - %s", field) - print("\n✓ Index requirements:") - print(" - Vector index (HNSW) on 'vector' column") - print(" - Full-text index on 'page_content' (optional)") - print(" - Functional index on metadata->>'$.doc_id' (recommended)") - print(" - Functional index on metadata->>'$.document_id' (recommended)") + logger.info("\n✓ Index requirements:") + logger.info(" - Vector index (HNSW) on 'vector' column") + logger.info(" - Full-text index on 'page_content' (optional)") + logger.info(" - Functional index on metadata->>'$.doc_id' (recommended)") + logger.info(" - Functional index on metadata->>'$.document_id' (recommended)") return True def main(): """Run all tests""" - print("Starting Clickzetta integration tests for Dify Docker\n") + logger.info("Starting Clickzetta integration tests for Dify Docker\n") tests = [ ("Direct Clickzetta Connection", test_clickzetta_connection), @@ -131,33 +127,34 @@ def main(): try: success = test_func() results.append((test_name, success)) - except Exception as e: - print(f"\n✗ {test_name} crashed: {e}") + except Exception: + logger.exception("\n✗ %s crashed", test_name) results.append((test_name, False)) - # Summary - print("\n" + "=" * 50) - print("Test Summary:") - print("=" * 50) + logger.info("\n%s", "=" * 50) + logger.info("Test Summary:") + logger.info("=" * 50) passed = sum(1 for _, success in results if success) total = len(results) for test_name, success in results: status = "✅ PASSED" if success else "❌ FAILED" - print(f"{test_name}: {status}") + logger.info("%s: %s", test_name, status) - print(f"\nTotal: {passed}/{total} tests passed") + logger.info("\nTotal: %s/%s tests passed", passed, total) if passed == total: - print("\n🎉 All tests passed! Clickzetta is ready for Dify Docker deployment.") - print("\nNext steps:") - print("1. Run: cd docker && docker-compose -f docker-compose.yaml -f docker-compose.clickzetta.yaml up -d") - print("2. Access Dify at http://localhost:3000") - print("3. Create a dataset and test vector storage with Clickzetta") + logger.info("\n🎉 All tests passed! Clickzetta is ready for Dify Docker deployment.") + logger.info("\nNext steps:") + logger.info( + "1. Run: cd docker && docker-compose -f docker-compose.yaml -f docker-compose.clickzetta.yaml up -d" + ) + logger.info("2. Access Dify at http://localhost:3000") + logger.info("3. Create a dataset and test vector storage with Clickzetta") return 0 else: - print("\n⚠️ Some tests failed. Please check the errors above.") + logger.error("\n⚠️ Some tests failed. Please check the errors above.") return 1 diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/clickzetta/test_clickzetta_vector.py b/api/providers/vdb/vdb-clickzetta/tests/unit_tests/test_clickzetta_vector.py similarity index 99% rename from api/tests/unit_tests/core/rag/datasource/vdb/clickzetta/test_clickzetta_vector.py rename to api/providers/vdb/vdb-clickzetta/tests/unit_tests/test_clickzetta_vector.py index 0ce5c04dd61..a7473f1b917 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/clickzetta/test_clickzetta_vector.py +++ b/api/providers/vdb/vdb-clickzetta/tests/unit_tests/test_clickzetta_vector.py @@ -47,7 +47,7 @@ def _build_fake_clickzetta_module(): @pytest.fixture def clickzetta_module(monkeypatch): monkeypatch.setitem(sys.modules, "clickzetta", _build_fake_clickzetta_module()) - import core.rag.datasource.vdb.clickzetta.clickzetta_vector as module + import dify_vdb_clickzetta.clickzetta_vector as module return importlib.reload(module) diff --git a/api/providers/vdb/vdb-couchbase/pyproject.toml b/api/providers/vdb/vdb-couchbase/pyproject.toml new file mode 100644 index 00000000000..6bc348b2eb5 --- /dev/null +++ b/api/providers/vdb/vdb-couchbase/pyproject.toml @@ -0,0 +1,14 @@ +[project] +name = "dify-vdb-couchbase" +version = "0.0.1" + +dependencies = [ + "couchbase~=4.6.0", +] +description = "Dify vector store backend (dify-vdb-couchbase)." + +[project.entry-points."dify.vector_backends"] +couchbase = "dify_vdb_couchbase.couchbase_vector:CouchbaseVectorFactory" + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/core/rag/datasource/vdb/myscale/__init__.py b/api/providers/vdb/vdb-couchbase/src/dify_vdb_couchbase/__init__.py similarity index 100% rename from api/core/rag/datasource/vdb/myscale/__init__.py rename to api/providers/vdb/vdb-couchbase/src/dify_vdb_couchbase/__init__.py diff --git a/api/core/rag/datasource/vdb/couchbase/couchbase_vector.py b/api/providers/vdb/vdb-couchbase/src/dify_vdb_couchbase/couchbase_vector.py similarity index 99% rename from api/core/rag/datasource/vdb/couchbase/couchbase_vector.py rename to api/providers/vdb/vdb-couchbase/src/dify_vdb_couchbase/couchbase_vector.py index 9a4a65cf6f0..815ac30c0bd 100644 --- a/api/core/rag/datasource/vdb/couchbase/couchbase_vector.py +++ b/api/providers/vdb/vdb-couchbase/src/dify_vdb_couchbase/couchbase_vector.py @@ -36,7 +36,7 @@ class CouchbaseConfig(BaseModel): @model_validator(mode="before") @classmethod - def validate_config(cls, values: dict): + def validate_config(cls, values: dict[str, Any]): if not values.get("connection_string"): raise ValueError("config COUCHBASE_CONNECTION_STRING is required") if not values.get("user"): diff --git a/api/tests/integration_tests/vdb/couchbase/test_couchbase.py b/api/providers/vdb/vdb-couchbase/tests/integration_tests/test_couchbase.py similarity index 80% rename from api/tests/integration_tests/vdb/couchbase/test_couchbase.py rename to api/providers/vdb/vdb-couchbase/tests/integration_tests/test_couchbase.py index 0371f042330..918dae328fa 100644 --- a/api/tests/integration_tests/vdb/couchbase/test_couchbase.py +++ b/api/providers/vdb/vdb-couchbase/tests/integration_tests/test_couchbase.py @@ -1,12 +1,14 @@ +import logging import subprocess import time -from core.rag.datasource.vdb.couchbase.couchbase_vector import CouchbaseConfig, CouchbaseVector -from tests.integration_tests.vdb.test_vector_store import ( +from dify_vdb_couchbase.couchbase_vector import CouchbaseConfig, CouchbaseVector + +from core.rag.datasource.vdb.vector_integration_test_support import ( AbstractVectorTest, ) -pytest_plugins = ("tests.integration_tests.vdb.test_vector_store",) +logger = logging.getLogger(__name__) def wait_for_healthy_container(service_name="couchbase-server", timeout=300): @@ -16,10 +18,10 @@ def wait_for_healthy_container(service_name="couchbase-server", timeout=300): ["docker", "inspect", "--format", "{{.State.Health.Status}}", service_name], capture_output=True, text=True ) if result.stdout.strip() == "healthy": - print(f"{service_name} is healthy!") + logger.info("%s is healthy!", service_name) return True else: - print(f"Waiting for {service_name} to be healthy...") + logger.info("Waiting for %s to be healthy...", service_name) time.sleep(10) raise TimeoutError(f"{service_name} did not become healthy in time") diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/couchbase/test_couchbase_vector.py b/api/providers/vdb/vdb-couchbase/tests/unit_tests/test_couchbase_vector.py similarity index 99% rename from api/tests/unit_tests/core/rag/datasource/vdb/couchbase/test_couchbase_vector.py rename to api/providers/vdb/vdb-couchbase/tests/unit_tests/test_couchbase_vector.py index 9fea187615e..7e5c40b8f2f 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/couchbase/test_couchbase_vector.py +++ b/api/providers/vdb/vdb-couchbase/tests/unit_tests/test_couchbase_vector.py @@ -154,7 +154,7 @@ def couchbase_module(monkeypatch): for name, module in _build_fake_couchbase_modules().items(): monkeypatch.setitem(sys.modules, name, module) - import core.rag.datasource.vdb.couchbase.couchbase_vector as module + import dify_vdb_couchbase.couchbase_vector as module return importlib.reload(module) diff --git a/api/providers/vdb/vdb-elasticsearch/pyproject.toml b/api/providers/vdb/vdb-elasticsearch/pyproject.toml new file mode 100644 index 00000000000..d40908f92da --- /dev/null +++ b/api/providers/vdb/vdb-elasticsearch/pyproject.toml @@ -0,0 +1,15 @@ +[project] +name = "dify-vdb-elasticsearch" +version = "0.0.1" + +dependencies = [ + "elasticsearch==8.14.0", +] +description = "Dify vector store backend (dify-vdb-elasticsearch)." + +[project.entry-points."dify.vector_backends"] +elasticsearch = "dify_vdb_elasticsearch.elasticsearch_vector:ElasticSearchVectorFactory" +elasticsearch-ja = "dify_vdb_elasticsearch.elasticsearch_ja_vector:ElasticSearchJaVectorFactory" + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/core/rag/datasource/vdb/oceanbase/__init__.py b/api/providers/vdb/vdb-elasticsearch/src/dify_vdb_elasticsearch/__init__.py similarity index 100% rename from api/core/rag/datasource/vdb/oceanbase/__init__.py rename to api/providers/vdb/vdb-elasticsearch/src/dify_vdb_elasticsearch/__init__.py diff --git a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_ja_vector.py b/api/providers/vdb/vdb-elasticsearch/src/dify_vdb_elasticsearch/elasticsearch_ja_vector.py similarity index 97% rename from api/core/rag/datasource/vdb/elasticsearch/elasticsearch_ja_vector.py rename to api/providers/vdb/vdb-elasticsearch/src/dify_vdb_elasticsearch/elasticsearch_ja_vector.py index 1e7fe526660..e2f390402ad 100644 --- a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_ja_vector.py +++ b/api/providers/vdb/vdb-elasticsearch/src/dify_vdb_elasticsearch/elasticsearch_ja_vector.py @@ -4,14 +4,14 @@ from typing import Any from flask import current_app -from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import ( +from core.rag.datasource.vdb.field import Field +from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings +from dify_vdb_elasticsearch.elasticsearch_vector import ( ElasticSearchConfig, ElasticSearchVector, ElasticSearchVectorFactory, ) -from core.rag.datasource.vdb.field import Field -from core.rag.datasource.vdb.vector_type import VectorType -from core.rag.embedding.embedding_base import Embeddings from extensions.ext_redis import redis_client from models.dataset import Dataset @@ -23,7 +23,7 @@ class ElasticSearchJaVector(ElasticSearchVector): self, embeddings: list[list[float]], metadatas: list[dict[Any, Any]] | None = None, - index_params: dict | None = None, + index_params: dict[str, Any] | None = None, ): lock_name = f"vector_indexing_lock_{self._collection_name}" with redis_client.lock(lock_name, timeout=20): diff --git a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py b/api/providers/vdb/vdb-elasticsearch/src/dify_vdb_elasticsearch/elasticsearch_vector.py similarity index 99% rename from api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py rename to api/providers/vdb/vdb-elasticsearch/src/dify_vdb_elasticsearch/elasticsearch_vector.py index 1470713b887..11463b6c58c 100644 --- a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py +++ b/api/providers/vdb/vdb-elasticsearch/src/dify_vdb_elasticsearch/elasticsearch_vector.py @@ -43,7 +43,7 @@ class ElasticSearchConfig(BaseModel): @model_validator(mode="before") @classmethod - def validate_config(cls, values: dict): + def validate_config(cls, values: dict[str, Any]): use_cloud = values.get("use_cloud", False) cloud_url = values.get("cloud_url") @@ -258,7 +258,7 @@ class ElasticSearchVector(BaseVector): self, embeddings: list[list[float]], metadatas: list[dict[Any, Any]] | None = None, - index_params: dict | None = None, + index_params: dict[str, Any] | None = None, ): lock_name = f"vector_indexing_lock_{self._collection_name}" with redis_client.lock(lock_name, timeout=20): diff --git a/api/tests/integration_tests/vdb/elasticsearch/test_elasticsearch.py b/api/providers/vdb/vdb-elasticsearch/tests/integration_tests/test_elasticsearch.py similarity index 71% rename from api/tests/integration_tests/vdb/elasticsearch/test_elasticsearch.py rename to api/providers/vdb/vdb-elasticsearch/tests/integration_tests/test_elasticsearch.py index 970d2cce1ae..c8b679e0219 100644 --- a/api/tests/integration_tests/vdb/elasticsearch/test_elasticsearch.py +++ b/api/providers/vdb/vdb-elasticsearch/tests/integration_tests/test_elasticsearch.py @@ -1,10 +1,9 @@ -from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import ElasticSearchConfig, ElasticSearchVector -from tests.integration_tests.vdb.test_vector_store import ( +from dify_vdb_elasticsearch.elasticsearch_vector import ElasticSearchConfig, ElasticSearchVector + +from core.rag.datasource.vdb.vector_integration_test_support import ( AbstractVectorTest, ) -pytest_plugins = ("tests.integration_tests.vdb.test_vector_store",) - class ElasticSearchVectorTest(AbstractVectorTest): def __init__(self): diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/elasticsearch/test_elasticsearch_ja_vector.py b/api/providers/vdb/vdb-elasticsearch/tests/unit_tests/test_elasticsearch_ja_vector.py similarity index 96% rename from api/tests/unit_tests/core/rag/datasource/vdb/elasticsearch/test_elasticsearch_ja_vector.py rename to api/providers/vdb/vdb-elasticsearch/tests/unit_tests/test_elasticsearch_ja_vector.py index edd29a46491..f81ed6beeab 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/elasticsearch/test_elasticsearch_ja_vector.py +++ b/api/providers/vdb/vdb-elasticsearch/tests/unit_tests/test_elasticsearch_ja_vector.py @@ -32,8 +32,8 @@ def elasticsearch_ja_module(monkeypatch): for name, module in _build_fake_elasticsearch_modules().items(): monkeypatch.setitem(sys.modules, name, module) - import core.rag.datasource.vdb.elasticsearch.elasticsearch_ja_vector as ja_module - import core.rag.datasource.vdb.elasticsearch.elasticsearch_vector as base_module + import dify_vdb_elasticsearch.elasticsearch_ja_vector as ja_module + import dify_vdb_elasticsearch.elasticsearch_vector as base_module importlib.reload(base_module) return importlib.reload(ja_module) diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/elasticsearch/test_elasticsearch_vector.py b/api/providers/vdb/vdb-elasticsearch/tests/unit_tests/test_elasticsearch_vector.py similarity index 99% rename from api/tests/unit_tests/core/rag/datasource/vdb/elasticsearch/test_elasticsearch_vector.py rename to api/providers/vdb/vdb-elasticsearch/tests/unit_tests/test_elasticsearch_vector.py index 9ecf0caa244..48f1f6dc266 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/elasticsearch/test_elasticsearch_vector.py +++ b/api/providers/vdb/vdb-elasticsearch/tests/unit_tests/test_elasticsearch_vector.py @@ -42,7 +42,7 @@ def elasticsearch_module(monkeypatch): for name, module in _build_fake_elasticsearch_modules().items(): monkeypatch.setitem(sys.modules, name, module) - import core.rag.datasource.vdb.elasticsearch.elasticsearch_vector as module + import dify_vdb_elasticsearch.elasticsearch_vector as module return importlib.reload(module) diff --git a/api/providers/vdb/vdb-hologres/pyproject.toml b/api/providers/vdb/vdb-hologres/pyproject.toml new file mode 100644 index 00000000000..88044bf6d6f --- /dev/null +++ b/api/providers/vdb/vdb-hologres/pyproject.toml @@ -0,0 +1,14 @@ +[project] +name = "dify-vdb-hologres" +version = "0.0.1" + +dependencies = [ + "holo-search-sdk>=0.4.2", +] +description = "Dify vector store backend (dify-vdb-hologres)." + +[project.entry-points."dify.vector_backends"] +hologres = "dify_vdb_hologres.hologres_vector:HologresVectorFactory" + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/core/rag/datasource/vdb/opengauss/__init__.py b/api/providers/vdb/vdb-hologres/src/dify_vdb_hologres/__init__.py similarity index 100% rename from api/core/rag/datasource/vdb/opengauss/__init__.py rename to api/providers/vdb/vdb-hologres/src/dify_vdb_hologres/__init__.py diff --git a/api/core/rag/datasource/vdb/hologres/hologres_vector.py b/api/providers/vdb/vdb-hologres/src/dify_vdb_hologres/hologres_vector.py similarity index 97% rename from api/core/rag/datasource/vdb/hologres/hologres_vector.py rename to api/providers/vdb/vdb-hologres/src/dify_vdb_hologres/hologres_vector.py index 13d48b5668d..80c0ed582e0 100644 --- a/api/core/rag/datasource/vdb/hologres/hologres_vector.py +++ b/api/providers/vdb/vdb-hologres/src/dify_vdb_hologres/hologres_vector.py @@ -1,7 +1,7 @@ import json import logging import time -from typing import Any +from typing import Any, cast import holo_search_sdk as holo # type: ignore from holo_search_sdk.types import BaseQuantizationType, DistanceType, TokenizerType @@ -43,7 +43,7 @@ class HologresVectorConfig(BaseModel): @model_validator(mode="before") @classmethod - def validate_config(cls, values: dict): + def validate_config(cls, values: dict[str, Any]): if not values.get("host"): raise ValueError("config HOLOGRES_HOST is required") if not values.get("database"): @@ -351,9 +351,9 @@ class HologresVectorFactory(AbstractVectorFactory): access_key_id=dify_config.HOLOGRES_ACCESS_KEY_ID or "", access_key_secret=dify_config.HOLOGRES_ACCESS_KEY_SECRET or "", schema_name=dify_config.HOLOGRES_SCHEMA, - tokenizer=dify_config.HOLOGRES_TOKENIZER, - distance_method=dify_config.HOLOGRES_DISTANCE_METHOD, - base_quantization_type=dify_config.HOLOGRES_BASE_QUANTIZATION_TYPE, + tokenizer=cast(TokenizerType, dify_config.HOLOGRES_TOKENIZER), + distance_method=cast(DistanceType, dify_config.HOLOGRES_DISTANCE_METHOD), + base_quantization_type=cast(BaseQuantizationType, dify_config.HOLOGRES_BASE_QUANTIZATION_TYPE), max_degree=dify_config.HOLOGRES_MAX_DEGREE, ef_construction=dify_config.HOLOGRES_EF_CONSTRUCTION, ), diff --git a/api/tests/integration_tests/vdb/__mock/hologres.py b/api/providers/vdb/vdb-hologres/tests/integration_tests/conftest.py similarity index 82% rename from api/tests/integration_tests/vdb/__mock/hologres.py rename to api/providers/vdb/vdb-hologres/tests/integration_tests/conftest.py index b60cf358c04..d28ded01873 100644 --- a/api/tests/integration_tests/vdb/__mock/hologres.py +++ b/api/providers/vdb/vdb-hologres/tests/integration_tests/conftest.py @@ -7,13 +7,10 @@ import pytest from _pytest.monkeypatch import MonkeyPatch from psycopg import sql as psql -# Shared in-memory storage: {table_name: {doc_id: {"id", "text", "meta", "embedding"}}} _mock_tables: dict[str, dict[str, dict[str, Any]]] = {} class MockSearchQuery: - """Mock query builder for search_vector and search_text results.""" - def __init__(self, table_name: str, search_type: str): self._table_name = table_name self._search_type = search_type @@ -32,17 +29,13 @@ class MockSearchQuery: return self def _apply_filter(self, row: dict[str, Any]) -> bool: - """Apply the filter SQL to check if a row matches.""" if self._filter_sql is None: return True - # Extract literals (the document IDs) from the filter SQL - # Filter format: meta->>'document_id' IN ('doc1', 'doc2') literals = [v for t, v in _extract_identifiers_and_literals(self._filter_sql) if t == "literal"] if not literals: return True - # Get the document_id from the row's meta field meta = row.get("meta", "{}") if isinstance(meta, str): meta = json.loads(meta) @@ -54,22 +47,17 @@ class MockSearchQuery: data = _mock_tables.get(self._table_name, {}) results = [] for row in list(data.values())[: self._limit_val]: - # Apply filter if present if not self._apply_filter(row): continue if self._search_type == "vector": - # row format expected by _process_vector_results: (distance, id, text, meta) results.append((0.1, row["id"], row["text"], row["meta"])) else: - # row format expected by _process_full_text_results: (id, text, meta, embedding, score) results.append((row["id"], row["text"], row["meta"], row.get("embedding", []), 0.9)) return results class MockTable: - """Mock table object returned by client.open_table().""" - def __init__(self, table_name: str): self._table_name = table_name @@ -97,7 +85,6 @@ class MockTable: def _extract_sql_template(query) -> str: - """Extract the SQL template string from a psycopg Composed object.""" if isinstance(query, psql.Composed): for part in query: if isinstance(part, psql.SQL): @@ -108,7 +95,6 @@ def _extract_sql_template(query) -> str: def _extract_identifiers_and_literals(query) -> list[Any]: - """Extract Identifier and Literal values from a psycopg Composed object.""" values: list[Any] = [] if isinstance(query, psql.Composed): for part in query: @@ -117,7 +103,6 @@ def _extract_identifiers_and_literals(query) -> list[Any]: elif isinstance(part, psql.Literal): values.append(("literal", part._obj)) elif isinstance(part, psql.Composed): - # Handles SQL(...).join(...) for IN clauses for sub in part: if isinstance(sub, psql.Literal): values.append(("literal", sub._obj)) @@ -125,8 +110,6 @@ def _extract_identifiers_and_literals(query) -> list[Any]: class MockHologresClient: - """Mock holo_search_sdk client that stores data in memory.""" - def connect(self): pass @@ -141,21 +124,18 @@ class MockHologresClient: params = _extract_identifiers_and_literals(query) if "CREATE TABLE" in template.upper(): - # Extract table name from first identifier table_name = next((v for t, v in params if t == "ident"), "unknown") if table_name not in _mock_tables: _mock_tables[table_name] = {} return None if "SELECT 1" in template: - # text_exists: SELECT 1 FROM {table} WHERE id = {id} LIMIT 1 table_name = next((v for t, v in params if t == "ident"), "") doc_id = next((v for t, v in params if t == "literal"), "") data = _mock_tables.get(table_name, {}) return [(1,)] if doc_id in data else [] if "SELECT id" in template: - # get_ids_by_metadata_field: SELECT id FROM {table} WHERE meta->>{key} = {value} table_name = next((v for t, v in params if t == "ident"), "") literals = [v for t, v in params if t == "literal"] key = literals[0] if len(literals) > 0 else "" @@ -166,12 +146,10 @@ class MockHologresClient: if "DELETE" in template.upper(): table_name = next((v for t, v in params if t == "ident"), "") if "id IN" in template: - # delete_by_ids ids_to_delete = [v for t, v in params if t == "literal"] for did in ids_to_delete: _mock_tables.get(table_name, {}).pop(did, None) elif "meta->>" in template: - # delete_by_metadata_field literals = [v for t, v in params if t == "literal"] key = literals[0] if len(literals) > 0 else "" value = literals[1] if len(literals) > 1 else "" @@ -190,7 +168,6 @@ class MockHologresClient: def mock_connect(**kwargs): - """Replacement for holo_search_sdk.connect() that returns a mock client.""" return MockHologresClient() diff --git a/api/tests/integration_tests/vdb/hologres/test_hologres.py b/api/providers/vdb/vdb-hologres/tests/integration_tests/test_hologres.py similarity index 94% rename from api/tests/integration_tests/vdb/hologres/test_hologres.py rename to api/providers/vdb/vdb-hologres/tests/integration_tests/test_hologres.py index d81e18841e0..04024be4aea 100644 --- a/api/tests/integration_tests/vdb/hologres/test_hologres.py +++ b/api/providers/vdb/vdb-hologres/tests/integration_tests/test_hologres.py @@ -2,16 +2,11 @@ import os import uuid from typing import cast +from dify_vdb_hologres.hologres_vector import HologresVector, HologresVectorConfig from holo_search_sdk.types import BaseQuantizationType, DistanceType, TokenizerType -from core.rag.datasource.vdb.hologres.hologres_vector import HologresVector, HologresVectorConfig +from core.rag.datasource.vdb.vector_integration_test_support import AbstractVectorTest, get_example_text from core.rag.models.document import Document -from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text - -pytest_plugins = ( - "tests.integration_tests.vdb.test_vector_store", - "tests.integration_tests.vdb.__mock.hologres", -) MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true" diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/hologres/test_hologres_vector.py b/api/providers/vdb/vdb-hologres/tests/unit_tests/test_hologres_vector.py similarity index 99% rename from api/tests/unit_tests/core/rag/datasource/vdb/hologres/test_hologres_vector.py rename to api/providers/vdb/vdb-hologres/tests/unit_tests/test_hologres_vector.py index 5d9e744ded4..f9a557ecce6 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/hologres/test_hologres_vector.py +++ b/api/providers/vdb/vdb-hologres/tests/unit_tests/test_hologres_vector.py @@ -42,7 +42,7 @@ def hologres_module(monkeypatch): for name, module in _build_fake_hologres_modules().items(): monkeypatch.setitem(sys.modules, name, module) - import core.rag.datasource.vdb.hologres.hologres_vector as module + import dify_vdb_hologres.hologres_vector as module return importlib.reload(module) diff --git a/api/providers/vdb/vdb-huawei-cloud/pyproject.toml b/api/providers/vdb/vdb-huawei-cloud/pyproject.toml new file mode 100644 index 00000000000..71af56786cf --- /dev/null +++ b/api/providers/vdb/vdb-huawei-cloud/pyproject.toml @@ -0,0 +1,14 @@ +[project] +name = "dify-vdb-huawei-cloud" +version = "0.0.1" + +dependencies = [ + "elasticsearch==8.14.0", +] +description = "Dify vector store backend (dify-vdb-huawei-cloud)." + +[project.entry-points."dify.vector_backends"] +huawei_cloud = "dify_vdb_huawei_cloud.huawei_cloud_vector:HuaweiCloudVectorFactory" + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/core/rag/datasource/vdb/opensearch/__init__.py b/api/providers/vdb/vdb-huawei-cloud/src/dify_vdb_huawei_cloud/__init__.py similarity index 100% rename from api/core/rag/datasource/vdb/opensearch/__init__.py rename to api/providers/vdb/vdb-huawei-cloud/src/dify_vdb_huawei_cloud/__init__.py diff --git a/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py b/api/providers/vdb/vdb-huawei-cloud/src/dify_vdb_huawei_cloud/huawei_cloud_vector.py similarity index 91% rename from api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py rename to api/providers/vdb/vdb-huawei-cloud/src/dify_vdb_huawei_cloud/huawei_cloud_vector.py index df02c584ede..d51075d2e8f 100644 --- a/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py +++ b/api/providers/vdb/vdb-huawei-cloud/src/dify_vdb_huawei_cloud/huawei_cloud_vector.py @@ -5,6 +5,7 @@ from typing import Any from elasticsearch import Elasticsearch from pydantic import BaseModel, model_validator +from typing_extensions import TypedDict from configs import dify_config from core.rag.datasource.vdb.field import Field @@ -19,6 +20,16 @@ from models.dataset import Dataset logger = logging.getLogger(__name__) +class HuaweiElasticsearchParamsDict(TypedDict, total=False): + hosts: list[str] + verify_certs: bool + ssl_show_warn: bool + request_timeout: int + retry_on_timeout: bool + max_retries: int + basic_auth: tuple[str, str] + + def create_ssl_context() -> ssl.SSLContext: ssl_context = ssl.create_default_context() ssl_context.check_hostname = False @@ -33,20 +44,20 @@ class HuaweiCloudVectorConfig(BaseModel): @model_validator(mode="before") @classmethod - def validate_config(cls, values: dict): + def validate_config(cls, values: dict[str, Any]): if not values["hosts"]: raise ValueError("config HOSTS is required") return values - def to_elasticsearch_params(self) -> dict[str, Any]: - params = { - "hosts": self.hosts.split(","), - "verify_certs": False, - "ssl_show_warn": False, - "request_timeout": 30000, - "retry_on_timeout": True, - "max_retries": 10, - } + def to_elasticsearch_params(self) -> HuaweiElasticsearchParamsDict: + params = HuaweiElasticsearchParamsDict( + hosts=self.hosts.split(","), + verify_certs=False, + ssl_show_warn=False, + request_timeout=30000, + retry_on_timeout=True, + max_retries=10, + ) if self.username and self.password: params["basic_auth"] = (self.username, self.password) return params @@ -158,7 +169,7 @@ class HuaweiCloudVector(BaseVector): self, embeddings: list[list[float]], metadatas: list[dict[Any, Any]] | None = None, - index_params: dict | None = None, + index_params: dict[str, Any] | None = None, ): lock_name = f"vector_indexing_lock_{self._collection_name}" with redis_client.lock(lock_name, timeout=20): diff --git a/api/tests/integration_tests/vdb/__mock/huaweicloudvectordb.py b/api/providers/vdb/vdb-huawei-cloud/tests/integration_tests/conftest.py similarity index 100% rename from api/tests/integration_tests/vdb/__mock/huaweicloudvectordb.py rename to api/providers/vdb/vdb-huawei-cloud/tests/integration_tests/conftest.py diff --git a/api/tests/integration_tests/vdb/huawei/test_huawei_cloud.py b/api/providers/vdb/vdb-huawei-cloud/tests/integration_tests/test_huawei_cloud.py similarity index 69% rename from api/tests/integration_tests/vdb/huawei/test_huawei_cloud.py rename to api/providers/vdb/vdb-huawei-cloud/tests/integration_tests/test_huawei_cloud.py index 01f511358af..bb5f5b72efa 100644 --- a/api/tests/integration_tests/vdb/huawei/test_huawei_cloud.py +++ b/api/providers/vdb/vdb-huawei-cloud/tests/integration_tests/test_huawei_cloud.py @@ -1,10 +1,6 @@ -from core.rag.datasource.vdb.huawei.huawei_cloud_vector import HuaweiCloudVector, HuaweiCloudVectorConfig -from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text +from dify_vdb_huawei_cloud.huawei_cloud_vector import HuaweiCloudVector, HuaweiCloudVectorConfig -pytest_plugins = ( - "tests.integration_tests.vdb.test_vector_store", - "tests.integration_tests.vdb.__mock.huaweicloudvectordb", -) +from core.rag.datasource.vdb.vector_integration_test_support import AbstractVectorTest, get_example_text class HuaweiCloudVectorTest(AbstractVectorTest): diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/huawei/test_huawei_cloud_vector.py b/api/providers/vdb/vdb-huawei-cloud/tests/unit_tests/test_huawei_cloud_vector.py similarity index 99% rename from api/tests/unit_tests/core/rag/datasource/vdb/huawei/test_huawei_cloud_vector.py rename to api/providers/vdb/vdb-huawei-cloud/tests/unit_tests/test_huawei_cloud_vector.py index 9d23dfcf631..ba3f14912b4 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/huawei/test_huawei_cloud_vector.py +++ b/api/providers/vdb/vdb-huawei-cloud/tests/unit_tests/test_huawei_cloud_vector.py @@ -33,7 +33,7 @@ def huawei_module(monkeypatch): for name, module in _build_fake_elasticsearch_modules().items(): monkeypatch.setitem(sys.modules, name, module) - import core.rag.datasource.vdb.huawei.huawei_cloud_vector as module + import dify_vdb_huawei_cloud.huawei_cloud_vector as module return importlib.reload(module) diff --git a/api/providers/vdb/vdb-iris/pyproject.toml b/api/providers/vdb/vdb-iris/pyproject.toml new file mode 100644 index 00000000000..6dd7a8e073f --- /dev/null +++ b/api/providers/vdb/vdb-iris/pyproject.toml @@ -0,0 +1,14 @@ +[project] +name = "dify-vdb-iris" +version = "0.0.1" + +dependencies = [ + "intersystems-irispython>=5.1.0", +] +description = "Dify vector store backend (dify-vdb-iris)." + +[project.entry-points."dify.vector_backends"] +iris = "dify_vdb_iris.iris_vector:IrisVectorFactory" + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/core/rag/datasource/vdb/oracle/__init__.py b/api/providers/vdb/vdb-iris/src/dify_vdb_iris/__init__.py similarity index 100% rename from api/core/rag/datasource/vdb/oracle/__init__.py rename to api/providers/vdb/vdb-iris/src/dify_vdb_iris/__init__.py diff --git a/api/core/rag/datasource/vdb/iris/iris_vector.py b/api/providers/vdb/vdb-iris/src/dify_vdb_iris/iris_vector.py similarity index 100% rename from api/core/rag/datasource/vdb/iris/iris_vector.py rename to api/providers/vdb/vdb-iris/src/dify_vdb_iris/iris_vector.py diff --git a/api/tests/integration_tests/vdb/iris/test_iris.py b/api/providers/vdb/vdb-iris/tests/integration_tests/test_iris.py similarity index 85% rename from api/tests/integration_tests/vdb/iris/test_iris.py rename to api/providers/vdb/vdb-iris/tests/integration_tests/test_iris.py index 4b2da8387bd..8281e89c8ab 100644 --- a/api/tests/integration_tests/vdb/iris/test_iris.py +++ b/api/providers/vdb/vdb-iris/tests/integration_tests/test_iris.py @@ -1,12 +1,11 @@ """Integration tests for IRIS vector database.""" -from core.rag.datasource.vdb.iris.iris_vector import IrisVector, IrisVectorConfig -from tests.integration_tests.vdb.test_vector_store import ( +from dify_vdb_iris.iris_vector import IrisVector, IrisVectorConfig + +from core.rag.datasource.vdb.vector_integration_test_support import ( AbstractVectorTest, ) -pytest_plugins = ("tests.integration_tests.vdb.test_vector_store",) - class IrisVectorTest(AbstractVectorTest): """Test suite for IRIS vector store implementation.""" diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/iris/test_iris_vector.py b/api/providers/vdb/vdb-iris/tests/unit_tests/test_iris_vector.py similarity index 99% rename from api/tests/unit_tests/core/rag/datasource/vdb/iris/test_iris_vector.py rename to api/providers/vdb/vdb-iris/tests/unit_tests/test_iris_vector.py index 63338ca809c..8c038e82b95 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/iris/test_iris_vector.py +++ b/api/providers/vdb/vdb-iris/tests/unit_tests/test_iris_vector.py @@ -26,7 +26,7 @@ def _build_fake_iris_module(): def iris_module(monkeypatch): monkeypatch.setitem(sys.modules, "iris", _build_fake_iris_module()) - import core.rag.datasource.vdb.iris.iris_vector as module + import dify_vdb_iris.iris_vector as module reloaded = importlib.reload(module) reloaded._pool_instance = None diff --git a/api/providers/vdb/vdb-lindorm/pyproject.toml b/api/providers/vdb/vdb-lindorm/pyproject.toml new file mode 100644 index 00000000000..0cffc674915 --- /dev/null +++ b/api/providers/vdb/vdb-lindorm/pyproject.toml @@ -0,0 +1,15 @@ +[project] +name = "dify-vdb-lindorm" +version = "0.0.1" + +dependencies = [ + "opensearch-py==3.1.0", + "tenacity>=8.0.0", +] +description = "Dify vector store backend (dify-vdb-lindorm)." + +[project.entry-points."dify.vector_backends"] +lindorm = "dify_vdb_lindorm.lindorm_vector:LindormVectorStoreFactory" + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/core/rag/datasource/vdb/pgvecto_rs/__init__.py b/api/providers/vdb/vdb-lindorm/src/dify_vdb_lindorm/__init__.py similarity index 100% rename from api/core/rag/datasource/vdb/pgvecto_rs/__init__.py rename to api/providers/vdb/vdb-lindorm/src/dify_vdb_lindorm/__init__.py diff --git a/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py b/api/providers/vdb/vdb-lindorm/src/dify_vdb_lindorm/lindorm_vector.py similarity index 96% rename from api/core/rag/datasource/vdb/lindorm/lindorm_vector.py rename to api/providers/vdb/vdb-lindorm/src/dify_vdb_lindorm/lindorm_vector.py index bfcb620618d..9187ca943dc 100644 --- a/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py +++ b/api/providers/vdb/vdb-lindorm/src/dify_vdb_lindorm/lindorm_vector.py @@ -7,6 +7,7 @@ from opensearchpy import OpenSearch, helpers from opensearchpy.helpers import BulkIndexError from pydantic import BaseModel, model_validator from tenacity import retry, stop_after_attempt, wait_exponential +from typing_extensions import TypedDict from configs import dify_config from core.rag.datasource.vdb.field import Field @@ -26,6 +27,14 @@ ROUTING_FIELD = "routing_field" UGC_INDEX_PREFIX = "ugc_index" +class LindormOpenSearchParamsDict(TypedDict, total=False): + hosts: str | None + use_ssl: bool + pool_maxsize: int + timeout: int + http_auth: tuple[str, str] + + class LindormVectorStoreConfig(BaseModel): hosts: str | None username: str | None = None @@ -35,7 +44,7 @@ class LindormVectorStoreConfig(BaseModel): @model_validator(mode="before") @classmethod - def validate_config(cls, values: dict): + def validate_config(cls, values: dict[str, Any]): if not values["hosts"]: raise ValueError("config URL is required") if not values["username"]: @@ -44,13 +53,13 @@ class LindormVectorStoreConfig(BaseModel): raise ValueError("config PASSWORD is required") return values - def to_opensearch_params(self) -> dict[str, Any]: - params: dict[str, Any] = { - "hosts": self.hosts, - "use_ssl": False, - "pool_maxsize": 128, - "timeout": 30, - } + def to_opensearch_params(self) -> LindormOpenSearchParamsDict: + params = LindormOpenSearchParamsDict( + hosts=self.hosts, + use_ssl=False, + pool_maxsize=128, + timeout=30, + ) if self.username and self.password: params["http_auth"] = (self.username, self.password) return params @@ -327,7 +336,10 @@ class LindormVectorStore(BaseVector): return docs def create_collection( - self, embeddings: list, metadatas: list[dict] | None = None, index_params: dict | None = None + self, + embeddings: list[list[float]], + metadatas: list[dict[str, Any]] | None = None, + index_params: dict[str, Any] | None = None, ): if not embeddings: raise ValueError(f"Embeddings list cannot be empty for collection create '{self._collection_name}'") diff --git a/api/tests/integration_tests/vdb/lindorm/test_lindorm.py b/api/providers/vdb/vdb-lindorm/tests/integration_tests/test_lindorm.py similarity index 88% rename from api/tests/integration_tests/vdb/lindorm/test_lindorm.py rename to api/providers/vdb/vdb-lindorm/tests/integration_tests/test_lindorm.py index b24498fdfdd..0a0c2d2d591 100644 --- a/api/tests/integration_tests/vdb/lindorm/test_lindorm.py +++ b/api/providers/vdb/vdb-lindorm/tests/integration_tests/test_lindorm.py @@ -1,9 +1,8 @@ import os -from core.rag.datasource.vdb.lindorm.lindorm_vector import LindormVectorStore, LindormVectorStoreConfig -from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest +from dify_vdb_lindorm.lindorm_vector import LindormVectorStore, LindormVectorStoreConfig -pytest_plugins = ("tests.integration_tests.vdb.test_vector_store",) +from core.rag.datasource.vdb.vector_integration_test_support import AbstractVectorTest class Config: diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/lindorm/test_lindorm_vector.py b/api/providers/vdb/vdb-lindorm/tests/unit_tests/test_lindorm_vector.py similarity index 99% rename from api/tests/unit_tests/core/rag/datasource/vdb/lindorm/test_lindorm_vector.py rename to api/providers/vdb/vdb-lindorm/tests/unit_tests/test_lindorm_vector.py index 34357d5907d..238145c1d66 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/lindorm/test_lindorm_vector.py +++ b/api/providers/vdb/vdb-lindorm/tests/unit_tests/test_lindorm_vector.py @@ -51,7 +51,7 @@ def lindorm_module(monkeypatch): for name, module in _build_fake_opensearch_modules().items(): monkeypatch.setitem(sys.modules, name, module) - import core.rag.datasource.vdb.lindorm.lindorm_vector as module + import dify_vdb_lindorm.lindorm_vector as module return importlib.reload(module) diff --git a/api/providers/vdb/vdb-matrixone/pyproject.toml b/api/providers/vdb/vdb-matrixone/pyproject.toml new file mode 100644 index 00000000000..53363ed7d9e --- /dev/null +++ b/api/providers/vdb/vdb-matrixone/pyproject.toml @@ -0,0 +1,14 @@ +[project] +name = "dify-vdb-matrixone" +version = "0.0.1" + +dependencies = [ + "mo-vector~=0.1.13", +] +description = "Dify vector store backend (dify-vdb-matrixone)." + +[project.entry-points."dify.vector_backends"] +matrixone = "dify_vdb_matrixone.matrixone_vector:MatrixoneVectorFactory" + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/core/rag/datasource/vdb/pgvector/__init__.py b/api/providers/vdb/vdb-matrixone/src/dify_vdb_matrixone/__init__.py similarity index 100% rename from api/core/rag/datasource/vdb/pgvector/__init__.py rename to api/providers/vdb/vdb-matrixone/src/dify_vdb_matrixone/__init__.py diff --git a/api/core/rag/datasource/vdb/matrixone/matrixone_vector.py b/api/providers/vdb/vdb-matrixone/src/dify_vdb_matrixone/matrixone_vector.py similarity index 99% rename from api/core/rag/datasource/vdb/matrixone/matrixone_vector.py rename to api/providers/vdb/vdb-matrixone/src/dify_vdb_matrixone/matrixone_vector.py index c6ebccd204d..75fb54e6f4b 100644 --- a/api/core/rag/datasource/vdb/matrixone/matrixone_vector.py +++ b/api/providers/vdb/vdb-matrixone/src/dify_vdb_matrixone/matrixone_vector.py @@ -43,7 +43,7 @@ class MatrixoneConfig(BaseModel): @model_validator(mode="before") @classmethod - def validate_config(cls, values: dict): + def validate_config(cls, values: dict[str, Any]): if not values["host"]: raise ValueError("config host is required") if not values["port"]: diff --git a/api/tests/integration_tests/vdb/matrixone/test_matrixone.py b/api/providers/vdb/vdb-matrixone/tests/integration_tests/test_matrixone.py similarity index 74% rename from api/tests/integration_tests/vdb/matrixone/test_matrixone.py rename to api/providers/vdb/vdb-matrixone/tests/integration_tests/test_matrixone.py index fe592f6699c..d6f4781e65e 100644 --- a/api/tests/integration_tests/vdb/matrixone/test_matrixone.py +++ b/api/providers/vdb/vdb-matrixone/tests/integration_tests/test_matrixone.py @@ -1,10 +1,9 @@ -from core.rag.datasource.vdb.matrixone.matrixone_vector import MatrixoneConfig, MatrixoneVector -from tests.integration_tests.vdb.test_vector_store import ( +from dify_vdb_matrixone.matrixone_vector import MatrixoneConfig, MatrixoneVector + +from core.rag.datasource.vdb.vector_integration_test_support import ( AbstractVectorTest, ) -pytest_plugins = ("tests.integration_tests.vdb.test_vector_store",) - class MatrixoneVectorTest(AbstractVectorTest): def __init__(self): diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/matrixone/test_matrixone_vector.py b/api/providers/vdb/vdb-matrixone/tests/unit_tests/test_matrixone_vector.py similarity index 99% rename from api/tests/unit_tests/core/rag/datasource/vdb/matrixone/test_matrixone_vector.py rename to api/providers/vdb/vdb-matrixone/tests/unit_tests/test_matrixone_vector.py index 55e7b9112ec..c22f4304e51 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/matrixone/test_matrixone_vector.py +++ b/api/providers/vdb/vdb-matrixone/tests/unit_tests/test_matrixone_vector.py @@ -36,7 +36,7 @@ def matrixone_module(monkeypatch): for name, module in _build_fake_mo_vector_modules().items(): monkeypatch.setitem(sys.modules, name, module) - import core.rag.datasource.vdb.matrixone.matrixone_vector as module + import dify_vdb_matrixone.matrixone_vector as module return importlib.reload(module) diff --git a/api/providers/vdb/vdb-milvus/pyproject.toml b/api/providers/vdb/vdb-milvus/pyproject.toml new file mode 100644 index 00000000000..57385a44310 --- /dev/null +++ b/api/providers/vdb/vdb-milvus/pyproject.toml @@ -0,0 +1,14 @@ +[project] +name = "dify-vdb-milvus" +version = "0.0.1" + +dependencies = [ + "pymilvus~=2.6.12", +] +description = "Dify vector store backend (dify-vdb-milvus)." + +[project.entry-points."dify.vector_backends"] +milvus = "dify_vdb_milvus.milvus_vector:MilvusVectorFactory" + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/core/rag/datasource/vdb/pyvastbase/__init__.py b/api/providers/vdb/vdb-milvus/src/dify_vdb_milvus/__init__.py similarity index 100% rename from api/core/rag/datasource/vdb/pyvastbase/__init__.py rename to api/providers/vdb/vdb-milvus/src/dify_vdb_milvus/__init__.py diff --git a/api/core/rag/datasource/vdb/milvus/milvus_vector.py b/api/providers/vdb/vdb-milvus/src/dify_vdb_milvus/milvus_vector.py similarity index 96% rename from api/core/rag/datasource/vdb/milvus/milvus_vector.py rename to api/providers/vdb/vdb-milvus/src/dify_vdb_milvus/milvus_vector.py index 96eb4654019..46f3224a952 100644 --- a/api/core/rag/datasource/vdb/milvus/milvus_vector.py +++ b/api/providers/vdb/vdb-milvus/src/dify_vdb_milvus/milvus_vector.py @@ -1,6 +1,6 @@ import json import logging -from typing import Any +from typing import Any, TypedDict from packaging import version from pydantic import BaseModel, model_validator @@ -20,6 +20,15 @@ from models.dataset import Dataset logger = logging.getLogger(__name__) +class MilvusParamsDict(TypedDict): + uri: str + token: str | None + user: str | None + password: str | None + db_name: str + analyzer_params: str | None + + class MilvusConfig(BaseModel): """ Configuration class for Milvus connection. @@ -36,7 +45,7 @@ class MilvusConfig(BaseModel): @model_validator(mode="before") @classmethod - def validate_config(cls, values: dict): + def validate_config(cls, values: dict[str, Any]): """ Validate the configuration values. Raises ValueError if required fields are missing. @@ -50,11 +59,11 @@ class MilvusConfig(BaseModel): raise ValueError("config MILVUS_PASSWORD is required") return values - def to_milvus_params(self): + def to_milvus_params(self) -> MilvusParamsDict: """ Convert the configuration to a dictionary of Milvus connection parameters. """ - return { + result: MilvusParamsDict = { "uri": self.uri, "token": self.token, "user": self.user, @@ -62,6 +71,7 @@ class MilvusConfig(BaseModel): "db_name": self.database, "analyzer_params": self.analyzer_params, } + return result class MilvusVector(BaseVector): @@ -292,7 +302,10 @@ class MilvusVector(BaseVector): ) def create_collection( - self, embeddings: list, metadatas: list[dict] | None = None, index_params: dict | None = None + self, + embeddings: list[list[float]], + metadatas: list[dict[str, Any]] | None = None, + index_params: dict[str, Any] | None = None, ): """ Create a new collection in Milvus with the specified schema and index parameters. @@ -352,6 +365,7 @@ class MilvusVector(BaseVector): # Create Index params for the collection index_params_obj = IndexParams() + assert index_params is not None index_params_obj.add_index(field_name=Field.VECTOR, **index_params) # Create Sparse Vector Index for the collection diff --git a/api/tests/integration_tests/vdb/milvus/test_milvus.py b/api/providers/vdb/vdb-milvus/tests/integration_tests/test_milvus.py similarity index 80% rename from api/tests/integration_tests/vdb/milvus/test_milvus.py rename to api/providers/vdb/vdb-milvus/tests/integration_tests/test_milvus.py index b5fc4b4d109..084d808bed2 100644 --- a/api/tests/integration_tests/vdb/milvus/test_milvus.py +++ b/api/providers/vdb/vdb-milvus/tests/integration_tests/test_milvus.py @@ -1,11 +1,10 @@ -from core.rag.datasource.vdb.milvus.milvus_vector import MilvusConfig, MilvusVector -from tests.integration_tests.vdb.test_vector_store import ( +from dify_vdb_milvus.milvus_vector import MilvusConfig, MilvusVector + +from core.rag.datasource.vdb.vector_integration_test_support import ( AbstractVectorTest, get_example_text, ) -pytest_plugins = ("tests.integration_tests.vdb.test_vector_store",) - class MilvusVectorTest(AbstractVectorTest): def __init__(self): diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/milvus/test_milvus.py b/api/providers/vdb/vdb-milvus/tests/unit_tests/test_milvus.py similarity index 99% rename from api/tests/unit_tests/core/rag/datasource/vdb/milvus/test_milvus.py rename to api/providers/vdb/vdb-milvus/tests/unit_tests/test_milvus.py index 2ac2c40d38c..36c0ed8f6fe 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/milvus/test_milvus.py +++ b/api/providers/vdb/vdb-milvus/tests/unit_tests/test_milvus.py @@ -103,7 +103,7 @@ def milvus_module(monkeypatch): for name, module in _build_fake_pymilvus_modules().items(): monkeypatch.setitem(sys.modules, name, module) - import core.rag.datasource.vdb.milvus.milvus_vector as module + import dify_vdb_milvus.milvus_vector as module return importlib.reload(module) diff --git a/api/providers/vdb/vdb-myscale/pyproject.toml b/api/providers/vdb/vdb-myscale/pyproject.toml new file mode 100644 index 00000000000..13e0f35d23d --- /dev/null +++ b/api/providers/vdb/vdb-myscale/pyproject.toml @@ -0,0 +1,14 @@ +[project] +name = "dify-vdb-myscale" +version = "0.0.1" + +dependencies = [ + "clickhouse-connect~=0.15.0", +] +description = "Dify vector store backend (dify-vdb-myscale)." + +[project.entry-points."dify.vector_backends"] +myscale = "dify_vdb_myscale.myscale_vector:MyScaleVectorFactory" + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/core/rag/datasource/vdb/qdrant/__init__.py b/api/providers/vdb/vdb-myscale/src/dify_vdb_myscale/__init__.py similarity index 100% rename from api/core/rag/datasource/vdb/qdrant/__init__.py rename to api/providers/vdb/vdb-myscale/src/dify_vdb_myscale/__init__.py diff --git a/api/core/rag/datasource/vdb/myscale/myscale_vector.py b/api/providers/vdb/vdb-myscale/src/dify_vdb_myscale/myscale_vector.py similarity index 100% rename from api/core/rag/datasource/vdb/myscale/myscale_vector.py rename to api/providers/vdb/vdb-myscale/src/dify_vdb_myscale/myscale_vector.py diff --git a/api/tests/integration_tests/vdb/myscale/test_myscale.py b/api/providers/vdb/vdb-myscale/tests/integration_tests/test_myscale.py similarity index 76% rename from api/tests/integration_tests/vdb/myscale/test_myscale.py rename to api/providers/vdb/vdb-myscale/tests/integration_tests/test_myscale.py index 74cefad2af7..8ea42d5f45c 100644 --- a/api/tests/integration_tests/vdb/myscale/test_myscale.py +++ b/api/providers/vdb/vdb-myscale/tests/integration_tests/test_myscale.py @@ -1,10 +1,9 @@ -from core.rag.datasource.vdb.myscale.myscale_vector import MyScaleConfig, MyScaleVector -from tests.integration_tests.vdb.test_vector_store import ( +from dify_vdb_myscale.myscale_vector import MyScaleConfig, MyScaleVector + +from core.rag.datasource.vdb.vector_integration_test_support import ( AbstractVectorTest, ) -pytest_plugins = ("tests.integration_tests.vdb.test_vector_store",) - class MyScaleVectorTest(AbstractVectorTest): def __init__(self): diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/myscale/test_myscale_vector.py b/api/providers/vdb/vdb-myscale/tests/unit_tests/test_myscale_vector.py similarity index 99% rename from api/tests/unit_tests/core/rag/datasource/vdb/myscale/test_myscale_vector.py rename to api/providers/vdb/vdb-myscale/tests/unit_tests/test_myscale_vector.py index a75ba822385..228ea926395 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/myscale/test_myscale_vector.py +++ b/api/providers/vdb/vdb-myscale/tests/unit_tests/test_myscale_vector.py @@ -42,7 +42,7 @@ def myscale_module(monkeypatch): fake_module = _build_fake_clickhouse_connect_module() monkeypatch.setitem(sys.modules, "clickhouse_connect", fake_module) - import core.rag.datasource.vdb.myscale.myscale_vector as module + import dify_vdb_myscale.myscale_vector as module return importlib.reload(module) diff --git a/api/providers/vdb/vdb-oceanbase/pyproject.toml b/api/providers/vdb/vdb-oceanbase/pyproject.toml new file mode 100644 index 00000000000..887869a41cf --- /dev/null +++ b/api/providers/vdb/vdb-oceanbase/pyproject.toml @@ -0,0 +1,16 @@ +[project] +name = "dify-vdb-oceanbase" +version = "0.0.1" + +dependencies = [ + "pyobvector~=0.2.17", + "mysql-connector-python>=9.3.0", +] +description = "Dify vector store backend (dify-vdb-oceanbase)." + +[project.entry-points."dify.vector_backends"] +oceanbase = "dify_vdb_oceanbase.oceanbase_vector:OceanBaseVectorFactory" +seekdb = "dify_vdb_oceanbase.oceanbase_vector:OceanBaseVectorFactory" + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/core/rag/datasource/vdb/relyt/__init__.py b/api/providers/vdb/vdb-oceanbase/src/dify_vdb_oceanbase/__init__.py similarity index 100% rename from api/core/rag/datasource/vdb/relyt/__init__.py rename to api/providers/vdb/vdb-oceanbase/src/dify_vdb_oceanbase/__init__.py diff --git a/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py b/api/providers/vdb/vdb-oceanbase/src/dify_vdb_oceanbase/oceanbase_vector.py similarity index 99% rename from api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py rename to api/providers/vdb/vdb-oceanbase/src/dify_vdb_oceanbase/oceanbase_vector.py index 82f419871c6..69dc42169aa 100644 --- a/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py +++ b/api/providers/vdb/vdb-oceanbase/src/dify_vdb_oceanbase/oceanbase_vector.py @@ -49,7 +49,7 @@ class OceanBaseVectorConfig(BaseModel): @model_validator(mode="before") @classmethod - def validate_config(cls, values: dict): + def validate_config(cls, values: dict[str, Any]): if not values["host"]: raise ValueError("config OCEANBASE_VECTOR_HOST is required") if not values["port"]: diff --git a/api/tests/integration_tests/vdb/oceanbase/bench_oceanbase.py b/api/providers/vdb/vdb-oceanbase/tests/integration_tests/bench_oceanbase.py similarity index 87% rename from api/tests/integration_tests/vdb/oceanbase/bench_oceanbase.py rename to api/providers/vdb/vdb-oceanbase/tests/integration_tests/bench_oceanbase.py index 8b57be08c55..50f67369425 100644 --- a/api/tests/integration_tests/vdb/oceanbase/bench_oceanbase.py +++ b/api/providers/vdb/vdb-oceanbase/tests/integration_tests/bench_oceanbase.py @@ -2,11 +2,12 @@ Benchmark: OceanBase vector store — old (single-row) vs new (batch) insertion, metadata query with/without functional index, and vector search across metrics. -Usage: - uv run --project api python -m tests.integration_tests.vdb.oceanbase.bench_oceanbase +Usage (from repo root): + uv run --project api python api/packages/dify-vdb-oceanbase/tests/bench_oceanbase.py """ import json +import logging import random import statistics import time @@ -16,6 +17,8 @@ from pyobvector import VECTOR, ObVecClient, cosine_distance, inner_product, l2_d from sqlalchemy import JSON, Column, String, text from sqlalchemy.dialects.mysql import LONGTEXT +logger = logging.getLogger(__name__) + # --------------------------------------------------------------------------- # Config # --------------------------------------------------------------------------- @@ -114,7 +117,7 @@ def bench_metadata_query(client, table, doc_id, with_index=False): try: client.perform_raw_text_sql(f"CREATE INDEX idx_metadata_doc_id ON `{table}` ((metadata->>'$.document_id'))") except Exception: - pass # already exists + logger.debug("Index idx_metadata_doc_id already exists, skipping creation") sql = text(f"SELECT id FROM `{table}` WHERE metadata->>'$.document_id' = :val") times = [] @@ -164,11 +167,11 @@ def main(): client = _make_client() client_pooled = _make_client(pool_size=5, max_overflow=10, pool_recycle=3600, pool_pre_ping=True) - print("=" * 70) - print("OceanBase Vector Store — Performance Benchmark") - print(f" Endpoint : {HOST}:{PORT}") - print(f" Vec dim : {VEC_DIM}") - print("=" * 70) + logger.info("=" * 70) + logger.info("OceanBase Vector Store — Performance Benchmark") + logger.info(" Endpoint : %s:%s", HOST, PORT) + logger.info(" Vec dim : %s", VEC_DIM) + logger.info("=" * 70) # ------------------------------------------------------------------ # 1. Insertion benchmark @@ -187,10 +190,10 @@ def main(): t_batch = bench_insert_batch(client_pooled, tbl_batch, rows, batch_size=100) speedup = t_single / t_batch if t_batch > 0 else float("inf") - print(f"\n[Insert {n_docs} docs]") - print(f" Single-row : {t_single:.2f}s") - print(f" Batch(100) : {t_batch:.2f}s") - print(f" Speedup : {speedup:.1f}x") + logger.info("\n[Insert %s docs]", n_docs) + logger.info(" Single-row : %.2fs", t_single) + logger.info(" Batch(100) : %.2fs", t_batch) + logger.info(" Speedup : %.1fx", speedup) # ------------------------------------------------------------------ # 2. Metadata query benchmark (use the 1000-doc batch table) @@ -203,16 +206,16 @@ def main(): res = conn.execute(text(f"SELECT metadata->>'$.document_id' FROM `{tbl_meta}` LIMIT 1")) doc_id_1000 = res.fetchone()[0] - print("\n[Metadata filter query — 1000 rows, by document_id]") + logger.info("\n[Metadata filter query — 1000 rows, by document_id]") times_no_idx = bench_metadata_query(client, tbl_meta, doc_id_1000, with_index=False) - print(f" Without index : {_fmt(times_no_idx)}") + logger.info(" Without index : %s", _fmt(times_no_idx)) times_with_idx = bench_metadata_query(client, tbl_meta, doc_id_1000, with_index=True) - print(f" With index : {_fmt(times_with_idx)}") + logger.info(" With index : %s", _fmt(times_with_idx)) # ------------------------------------------------------------------ # 3. Vector search benchmark — across metrics # ------------------------------------------------------------------ - print("\n[Vector search — top-10, 20 queries each, on 1000 rows]") + logger.info("\n[Vector search — top-10, 20 queries each, on 1000 rows]") for metric in ["l2", "cosine", "inner_product"]: tbl_vs = f"bench_vs_{metric}" @@ -222,7 +225,7 @@ def main(): rows_vs, _ = _gen_rows(1000) bench_insert_batch(client_pooled, tbl_vs, rows_vs, batch_size=100) times = bench_vector_search(client_pooled, tbl_vs, metric, topk=10, n_queries=20) - print(f" {metric:15s}: {_fmt(times)}") + logger.info(" %-15s: %s", metric, _fmt(times)) _drop(client_pooled, tbl_vs) # ------------------------------------------------------------------ @@ -232,9 +235,9 @@ def main(): _drop(client, f"bench_single_{n}") _drop(client, f"bench_batch_{n}") - print("\n" + "=" * 70) - print("Benchmark complete.") - print("=" * 70) + logger.info("\n%s", "=" * 70) + logger.info("Benchmark complete.") + logger.info("=" * 70) if __name__ == "__main__": diff --git a/api/tests/integration_tests/vdb/oceanbase/test_oceanbase.py b/api/providers/vdb/vdb-oceanbase/tests/integration_tests/test_oceanbase.py similarity index 82% rename from api/tests/integration_tests/vdb/oceanbase/test_oceanbase.py rename to api/providers/vdb/vdb-oceanbase/tests/integration_tests/test_oceanbase.py index 410de2c5ad2..28f22d3cbc1 100644 --- a/api/tests/integration_tests/vdb/oceanbase/test_oceanbase.py +++ b/api/providers/vdb/vdb-oceanbase/tests/integration_tests/test_oceanbase.py @@ -1,15 +1,13 @@ import pytest - -from core.rag.datasource.vdb.oceanbase.oceanbase_vector import ( +from dify_vdb_oceanbase.oceanbase_vector import ( OceanBaseVector, OceanBaseVectorConfig, ) -from tests.integration_tests.vdb.test_vector_store import ( + +from core.rag.datasource.vdb.vector_integration_test_support import ( AbstractVectorTest, ) -pytest_plugins = ("tests.integration_tests.vdb.test_vector_store",) - @pytest.fixture def oceanbase_vector(): diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/oceanbase/test_oceanbase_vector.py b/api/providers/vdb/vdb-oceanbase/tests/unit_tests/test_oceanbase_vector.py similarity index 99% rename from api/tests/unit_tests/core/rag/datasource/vdb/oceanbase/test_oceanbase_vector.py rename to api/providers/vdb/vdb-oceanbase/tests/unit_tests/test_oceanbase_vector.py index 27d8198ec02..31f9ff3e56e 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/oceanbase/test_oceanbase_vector.py +++ b/api/providers/vdb/vdb-oceanbase/tests/unit_tests/test_oceanbase_vector.py @@ -56,7 +56,7 @@ def _build_fake_pyobvector_module(): def oceanbase_module(monkeypatch): monkeypatch.setitem(sys.modules, "pyobvector", _build_fake_pyobvector_module()) - import core.rag.datasource.vdb.oceanbase.oceanbase_vector as module + import dify_vdb_oceanbase.oceanbase_vector as module return importlib.reload(module) diff --git a/api/providers/vdb/vdb-opengauss/pyproject.toml b/api/providers/vdb/vdb-opengauss/pyproject.toml new file mode 100644 index 00000000000..79be94b9e32 --- /dev/null +++ b/api/providers/vdb/vdb-opengauss/pyproject.toml @@ -0,0 +1,12 @@ +[project] +name = "dify-vdb-opengauss" +version = "0.0.1" + +dependencies = [] +description = "Dify vector store backend (dify-vdb-opengauss)." + +[project.entry-points."dify.vector_backends"] +opengauss = "dify_vdb_opengauss.opengauss:OpenGaussFactory" + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/core/rag/datasource/vdb/tablestore/__init__.py b/api/providers/vdb/vdb-opengauss/src/dify_vdb_opengauss/__init__.py similarity index 100% rename from api/core/rag/datasource/vdb/tablestore/__init__.py rename to api/providers/vdb/vdb-opengauss/src/dify_vdb_opengauss/__init__.py diff --git a/api/core/rag/datasource/vdb/opengauss/opengauss.py b/api/providers/vdb/vdb-opengauss/src/dify_vdb_opengauss/opengauss.py similarity index 99% rename from api/core/rag/datasource/vdb/opengauss/opengauss.py rename to api/providers/vdb/vdb-opengauss/src/dify_vdb_opengauss/opengauss.py index f9dbfbeeafb..acd2471cf67 100644 --- a/api/core/rag/datasource/vdb/opengauss/opengauss.py +++ b/api/providers/vdb/vdb-opengauss/src/dify_vdb_opengauss/opengauss.py @@ -29,7 +29,7 @@ class OpenGaussConfig(BaseModel): @model_validator(mode="before") @classmethod - def validate_config(cls, values: dict): + def validate_config(cls, values: dict[str, Any]): if not values["host"]: raise ValueError("config OPENGAUSS_HOST is required") if not values["port"]: diff --git a/api/tests/integration_tests/vdb/opengauss/test_opengauss.py b/api/providers/vdb/vdb-opengauss/tests/integration_tests/test_opengauss.py similarity index 82% rename from api/tests/integration_tests/vdb/opengauss/test_opengauss.py rename to api/providers/vdb/vdb-opengauss/tests/integration_tests/test_opengauss.py index 78436a19eef..8b444527d7a 100644 --- a/api/tests/integration_tests/vdb/opengauss/test_opengauss.py +++ b/api/providers/vdb/vdb-opengauss/tests/integration_tests/test_opengauss.py @@ -1,14 +1,12 @@ import time import psycopg2 +from dify_vdb_opengauss.opengauss import OpenGauss, OpenGaussConfig -from core.rag.datasource.vdb.opengauss.opengauss import OpenGauss, OpenGaussConfig -from tests.integration_tests.vdb.test_vector_store import ( +from core.rag.datasource.vdb.vector_integration_test_support import ( AbstractVectorTest, ) -pytest_plugins = ("tests.integration_tests.vdb.test_vector_store",) - class OpenGaussTest(AbstractVectorTest): def __init__(self): diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/opengauss/test_opengauss.py b/api/providers/vdb/vdb-opengauss/tests/unit_tests/test_opengauss.py similarity index 99% rename from api/tests/unit_tests/core/rag/datasource/vdb/opengauss/test_opengauss.py rename to api/providers/vdb/vdb-opengauss/tests/unit_tests/test_opengauss.py index 6641dbe4a09..09abd625fc1 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/opengauss/test_opengauss.py +++ b/api/providers/vdb/vdb-opengauss/tests/unit_tests/test_opengauss.py @@ -41,7 +41,7 @@ def opengauss_module(monkeypatch): for name, module in _build_fake_psycopg2_modules().items(): monkeypatch.setitem(sys.modules, name, module) - import core.rag.datasource.vdb.opengauss.opengauss as module + import dify_vdb_opengauss.opengauss as module return importlib.reload(module) diff --git a/api/providers/vdb/vdb-opensearch/pyproject.toml b/api/providers/vdb/vdb-opensearch/pyproject.toml new file mode 100644 index 00000000000..56f303fdf5c --- /dev/null +++ b/api/providers/vdb/vdb-opensearch/pyproject.toml @@ -0,0 +1,14 @@ +[project] +name = "dify-vdb-opensearch" +version = "0.0.1" + +dependencies = [ + "opensearch-py==3.1.0", +] +description = "Dify vector store backend (dify-vdb-opensearch)." + +[project.entry-points."dify.vector_backends"] +opensearch = "dify_vdb_opensearch.opensearch_vector:OpenSearchVectorFactory" + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/core/rag/datasource/vdb/tencent/__init__.py b/api/providers/vdb/vdb-opensearch/src/dify_vdb_opensearch/__init__.py similarity index 100% rename from api/core/rag/datasource/vdb/tencent/__init__.py rename to api/providers/vdb/vdb-opensearch/src/dify_vdb_opensearch/__init__.py diff --git a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py b/api/providers/vdb/vdb-opensearch/src/dify_vdb_opensearch/opensearch_vector.py similarity index 93% rename from api/core/rag/datasource/vdb/opensearch/opensearch_vector.py rename to api/providers/vdb/vdb-opensearch/src/dify_vdb_opensearch/opensearch_vector.py index 2f777768077..843c495d823 100644 --- a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py +++ b/api/providers/vdb/vdb-opensearch/src/dify_vdb_opensearch/opensearch_vector.py @@ -6,6 +6,7 @@ from uuid import uuid4 from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, helpers from opensearchpy.helpers import BulkIndexError from pydantic import BaseModel, model_validator +from typing_extensions import TypedDict from configs import dify_config from configs.middleware.vdb.opensearch_config import AuthMethod @@ -21,6 +22,20 @@ from models.dataset import Dataset logger = logging.getLogger(__name__) +class _OpenSearchHostDict(TypedDict): + host: str + port: int + + +class OpenSearchParamsDict(TypedDict, total=False): + hosts: list[_OpenSearchHostDict] + use_ssl: bool + verify_certs: bool + connection_class: type + pool_maxsize: int + http_auth: tuple[str | None, str | None] | Urllib3AWSV4SignerAuth + + class OpenSearchConfig(BaseModel): host: str port: int @@ -34,7 +49,7 @@ class OpenSearchConfig(BaseModel): @model_validator(mode="before") @classmethod - def validate_config(cls, values: dict): + def validate_config(cls, values: dict[str, Any]): if not values.get("host"): raise ValueError("config OPENSEARCH_HOST is required") if not values.get("port"): @@ -57,14 +72,14 @@ class OpenSearchConfig(BaseModel): service=self.aws_service, # type: ignore[arg-type] ) - def to_opensearch_params(self) -> dict[str, Any]: - params = { - "hosts": [{"host": self.host, "port": self.port}], - "use_ssl": self.secure, - "verify_certs": self.verify_certs, - "connection_class": Urllib3HttpConnection, - "pool_maxsize": 20, - } + def to_opensearch_params(self) -> OpenSearchParamsDict: + params = OpenSearchParamsDict( + hosts=[{"host": self.host, "port": self.port}], + use_ssl=self.secure, + verify_certs=self.verify_certs, + connection_class=Urllib3HttpConnection, + pool_maxsize=20, + ) if self.auth_method == "basic": logger.info("Using basic authentication for OpenSearch Vector DB") @@ -237,7 +252,10 @@ class OpenSearchVector(BaseVector): return docs def create_collection( - self, embeddings: list, metadatas: list[dict] | None = None, index_params: dict | None = None + self, + embeddings: list[list[float]], + metadatas: list[dict[str, Any]] | None = None, + index_params: dict[str, Any] | None = None, ): lock_name = f"vector_indexing_lock_{self._collection_name.lower()}" with redis_client.lock(lock_name, timeout=20): diff --git a/api/providers/vdb/vdb-opensearch/tests/unit_tests/test_opensearch.py b/api/providers/vdb/vdb-opensearch/tests/unit_tests/test_opensearch.py new file mode 100644 index 00000000000..f2ed7cb6fbe --- /dev/null +++ b/api/providers/vdb/vdb-opensearch/tests/unit_tests/test_opensearch.py @@ -0,0 +1,332 @@ +import importlib +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +from core.rag.datasource.vdb.field import Field +from core.rag.models.document import Document +from extensions import ext_redis + + +def _build_fake_opensearch_modules(): + """Build fake opensearchpy modules to avoid the ``from events import Events`` + namespace collision (opensearch-py #756).""" + opensearchpy = types.ModuleType("opensearchpy") + opensearchpy_helpers = types.ModuleType("opensearchpy.helpers") + + class BulkIndexError(Exception): + def __init__(self, errors): + super().__init__("bulk error") + self.errors = errors + + class Urllib3AWSV4SignerAuth: + def __init__(self, credentials, region, service): + self.credentials = credentials + self.region = region + self.service = service + + class Urllib3HttpConnection: + pass + + class _IndicesClient: + def __init__(self): + self.exists = MagicMock(return_value=False) + self.create = MagicMock() + self.delete = MagicMock() + + class OpenSearch: + def __init__(self, **kwargs): + self.kwargs = kwargs + self.indices = _IndicesClient() + self.search = MagicMock(return_value={"hits": {"hits": []}}) + self.get = MagicMock() + + helpers = SimpleNamespace(bulk=MagicMock()) + + opensearchpy.OpenSearch = OpenSearch + opensearchpy.Urllib3AWSV4SignerAuth = Urllib3AWSV4SignerAuth + opensearchpy.Urllib3HttpConnection = Urllib3HttpConnection + opensearchpy.helpers = helpers + opensearchpy_helpers.BulkIndexError = BulkIndexError + + return { + "opensearchpy": opensearchpy, + "opensearchpy.helpers": opensearchpy_helpers, + } + + +@pytest.fixture +def opensearch_module(monkeypatch): + for name, module in _build_fake_opensearch_modules().items(): + monkeypatch.setitem(sys.modules, name, module) + + import dify_vdb_opensearch.opensearch_vector as module + + return importlib.reload(module) + + +def _config(module, **overrides): + values = { + "host": "localhost", + "port": 9200, + "secure": False, + "user": "admin", + "password": "password", + } + values.update(overrides) + return module.OpenSearchConfig.model_validate(values) + + +def get_example_text() -> str: + return "This is a sample text for testing purposes." + + +class TestOpenSearchConfig: + def test_to_opensearch_params(self, opensearch_module): + config = _config(opensearch_module, secure=True) + params = config.to_opensearch_params() + + assert params["hosts"] == [{"host": "localhost", "port": 9200}] + assert params["use_ssl"] is True + assert params["verify_certs"] is True + assert params["connection_class"].__name__ == "Urllib3HttpConnection" + assert params["http_auth"] == ("admin", "password") + + def test_to_opensearch_params_with_aws_managed_iam(self, opensearch_module, monkeypatch): + class _Session: + def get_credentials(self): + return "creds" + + boto3 = types.ModuleType("boto3") + boto3.Session = _Session + monkeypatch.setitem(sys.modules, "boto3", boto3) + + config = _config( + opensearch_module, + secure=True, + auth_method="aws_managed_iam", + aws_region="ap-southeast-2", + aws_service="aoss", + host="aoss-endpoint.ap-southeast-2.aoss.amazonaws.com", + port=9201, + ) + params = config.to_opensearch_params() + + assert params["hosts"] == [{"host": "aoss-endpoint.ap-southeast-2.aoss.amazonaws.com", "port": 9201}] + assert params["use_ssl"] is True + assert params["verify_certs"] is True + assert params["connection_class"].__name__ == "Urllib3HttpConnection" + assert params["http_auth"].credentials == "creds" + assert params["http_auth"].region == "ap-southeast-2" + assert params["http_auth"].service == "aoss" + + +class TestOpenSearchVector: + COLLECTION_NAME = "test_collection" + EXAMPLE_DOC_ID = "example_doc_id" + + def _make_vector(self, module): + vector = module.OpenSearchVector(self.COLLECTION_NAME, _config(module)) + vector._client = MagicMock() + return vector + + @pytest.mark.parametrize( + ("search_response", "expected_length", "expected_doc_id"), + [ + ( + { + "hits": { + "total": {"value": 1}, + "hits": [ + { + "_source": { + "page_content": get_example_text(), + "metadata": {"document_id": "example_doc_id"}, + } + } + ], + } + }, + 1, + "example_doc_id", + ), + ({"hits": {"total": {"value": 0}, "hits": []}}, 0, None), + ], + ) + def test_search_by_full_text(self, opensearch_module, search_response, expected_length, expected_doc_id): + vector = self._make_vector(opensearch_module) + vector._client.search.return_value = search_response + + hits = vector.search_by_full_text(query=get_example_text()) + assert len(hits) == expected_length + if expected_length > 0: + assert hits[0].metadata["document_id"] == expected_doc_id + + def test_search_by_vector(self, opensearch_module): + vector = self._make_vector(opensearch_module) + query_vector = [0.1] * 128 + mock_response = { + "hits": { + "total": {"value": 1}, + "hits": [ + { + "_source": { + Field.CONTENT_KEY: get_example_text(), + Field.METADATA_KEY: {"document_id": self.EXAMPLE_DOC_ID}, + }, + "_score": 1.0, + } + ], + } + } + vector._client.search.return_value = mock_response + + hits = vector.search_by_vector(query_vector=query_vector) + + assert len(hits) > 0 + assert hits[0].metadata["document_id"] == self.EXAMPLE_DOC_ID + + def test_get_ids_by_metadata_field(self, opensearch_module): + vector = self._make_vector(opensearch_module) + mock_response = {"hits": {"total": {"value": 1}, "hits": [{"_id": "mock_id"}]}} + vector._client.search.return_value = mock_response + + doc = Document(page_content="Test content", metadata={"document_id": self.EXAMPLE_DOC_ID}) + embedding = [0.1] * 128 + + opensearch_module.helpers.bulk.reset_mock() + vector.add_texts([doc], [embedding]) + + ids = vector.get_ids_by_metadata_field(key="document_id", value=self.EXAMPLE_DOC_ID) + assert len(ids) == 1 + assert ids[0] == "mock_id" + + def test_add_texts(self, opensearch_module): + vector = self._make_vector(opensearch_module) + vector._client.index.return_value = {"result": "created"} + + doc = Document(page_content="Test content", metadata={"document_id": self.EXAMPLE_DOC_ID}) + embedding = [0.1] * 128 + + opensearch_module.helpers.bulk.reset_mock() + vector.add_texts([doc], [embedding]) + + mock_response = {"hits": {"total": {"value": 1}, "hits": [{"_id": "mock_id"}]}} + vector._client.search.return_value = mock_response + + ids = vector.get_ids_by_metadata_field(key="document_id", value=self.EXAMPLE_DOC_ID) + assert len(ids) == 1 + assert ids[0] == "mock_id" + + def test_delete_nonexistent_index(self, opensearch_module): + """ignore_unavailable=True handles non-existent indices gracefully.""" + vector = self._make_vector(opensearch_module) + vector.delete() + + vector._client.indices.delete.assert_called_once_with( + index=self.COLLECTION_NAME.lower(), ignore_unavailable=True + ) + + def test_delete_existing_index(self, opensearch_module): + vector = self._make_vector(opensearch_module) + vector.delete() + + vector._client.indices.delete.assert_called_once_with( + index=self.COLLECTION_NAME.lower(), ignore_unavailable=True + ) + + +@pytest.fixture(scope="module") +def setup_mock_redis(): + ext_redis.redis_client.get = MagicMock(return_value=None) + ext_redis.redis_client.set = MagicMock(return_value=None) + + mock_redis_lock = MagicMock() + mock_redis_lock.__enter__ = MagicMock() + mock_redis_lock.__exit__ = MagicMock() + ext_redis.redis_client.lock = MagicMock(return_value=mock_redis_lock) + + +@pytest.mark.usefixtures("setup_mock_redis") +class TestOpenSearchVectorWithRedis: + COLLECTION_NAME = "test_collection" + EXAMPLE_DOC_ID = "example_doc_id" + + def _make_vector(self, module): + vector = module.OpenSearchVector(self.COLLECTION_NAME, _config(module)) + vector._client = MagicMock() + return vector + + def test_search_by_full_text(self, opensearch_module): + vector = self._make_vector(opensearch_module) + search_response = { + "hits": { + "total": {"value": 1}, + "hits": [ + {"_source": {"page_content": get_example_text(), "metadata": {"document_id": "example_doc_id"}}} + ], + } + } + vector._client.search.return_value = search_response + + hits = vector.search_by_full_text(query=get_example_text()) + assert len(hits) == 1 + assert hits[0].metadata["document_id"] == "example_doc_id" + + def test_get_ids_by_metadata_field(self, opensearch_module): + vector = self._make_vector(opensearch_module) + mock_response = {"hits": {"total": {"value": 1}, "hits": [{"_id": "mock_id"}]}} + vector._client.search.return_value = mock_response + + doc = Document(page_content="Test content", metadata={"document_id": self.EXAMPLE_DOC_ID}) + embedding = [0.1] * 128 + + opensearch_module.helpers.bulk.reset_mock() + vector.add_texts([doc], [embedding]) + + ids = vector.get_ids_by_metadata_field(key="document_id", value=self.EXAMPLE_DOC_ID) + assert len(ids) == 1 + assert ids[0] == "mock_id" + + def test_add_texts(self, opensearch_module): + vector = self._make_vector(opensearch_module) + vector._client.index.return_value = {"result": "created"} + + doc = Document(page_content="Test content", metadata={"document_id": self.EXAMPLE_DOC_ID}) + embedding = [0.1] * 128 + + opensearch_module.helpers.bulk.reset_mock() + vector.add_texts([doc], [embedding]) + + mock_response = {"hits": {"total": {"value": 1}, "hits": [{"_id": "mock_id"}]}} + vector._client.search.return_value = mock_response + + ids = vector.get_ids_by_metadata_field(key="document_id", value=self.EXAMPLE_DOC_ID) + assert len(ids) == 1 + assert ids[0] == "mock_id" + + def test_search_by_vector(self, opensearch_module): + vector = self._make_vector(opensearch_module) + query_vector = [0.1] * 128 + mock_response = { + "hits": { + "total": {"value": 1}, + "hits": [ + { + "_source": { + Field.CONTENT_KEY: get_example_text(), + Field.METADATA_KEY: {"document_id": self.EXAMPLE_DOC_ID}, + }, + "_score": 1.0, + } + ], + } + } + vector._client.search.return_value = mock_response + + hits = vector.search_by_vector(query_vector=query_vector) + assert len(hits) > 0 + assert hits[0].metadata["document_id"] == self.EXAMPLE_DOC_ID diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/opensearch/test_opensearch_vector.py b/api/providers/vdb/vdb-opensearch/tests/unit_tests/test_opensearch_vector.py similarity index 98% rename from api/tests/unit_tests/core/rag/datasource/vdb/opensearch/test_opensearch_vector.py rename to api/providers/vdb/vdb-opensearch/tests/unit_tests/test_opensearch_vector.py index 1030158dd1a..1c2921f85be 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/opensearch/test_opensearch_vector.py +++ b/api/providers/vdb/vdb-opensearch/tests/unit_tests/test_opensearch_vector.py @@ -10,6 +10,8 @@ from pydantic import ValidationError from core.rag.models.document import Document +# TODO(wylswz): There's a known issue with namespace collision +# https://github.com/langgenius/dify/issues/34732 def _build_fake_opensearch_modules(): opensearchpy = types.ModuleType("opensearchpy") opensearchpy_helpers = types.ModuleType("opensearchpy.helpers") @@ -60,7 +62,7 @@ def opensearch_module(monkeypatch): for name, module in _build_fake_opensearch_modules().items(): monkeypatch.setitem(sys.modules, name, module) - import core.rag.datasource.vdb.opensearch.opensearch_vector as module + import dify_vdb_opensearch.opensearch_vector as module return importlib.reload(module) diff --git a/api/providers/vdb/vdb-oracle/pyproject.toml b/api/providers/vdb/vdb-oracle/pyproject.toml new file mode 100644 index 00000000000..6747485041e --- /dev/null +++ b/api/providers/vdb/vdb-oracle/pyproject.toml @@ -0,0 +1,14 @@ +[project] +name = "dify-vdb-oracle" +version = "0.0.1" + +dependencies = [ + "oracledb==3.4.2", +] +description = "Dify vector store backend (dify-vdb-oracle)." + +[project.entry-points."dify.vector_backends"] +oracle = "dify_vdb_oracle.oraclevector:OracleVectorFactory" + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/core/rag/datasource/vdb/tidb_on_qdrant/__init__.py b/api/providers/vdb/vdb-oracle/src/dify_vdb_oracle/__init__.py similarity index 100% rename from api/core/rag/datasource/vdb/tidb_on_qdrant/__init__.py rename to api/providers/vdb/vdb-oracle/src/dify_vdb_oracle/__init__.py diff --git a/api/core/rag/datasource/vdb/oracle/oraclevector.py b/api/providers/vdb/vdb-oracle/src/dify_vdb_oracle/oraclevector.py similarity index 99% rename from api/core/rag/datasource/vdb/oracle/oraclevector.py rename to api/providers/vdb/vdb-oracle/src/dify_vdb_oracle/oraclevector.py index cb05c22b555..70377c82c8d 100644 --- a/api/core/rag/datasource/vdb/oracle/oraclevector.py +++ b/api/providers/vdb/vdb-oracle/src/dify_vdb_oracle/oraclevector.py @@ -36,7 +36,7 @@ class OracleVectorConfig(BaseModel): @model_validator(mode="before") @classmethod - def validate_config(cls, values: dict): + def validate_config(cls, values: dict[str, Any]): if not values["user"]: raise ValueError("config ORACLE_USER is required") if not values["password"]: diff --git a/api/tests/integration_tests/vdb/oracle/test_oraclevector.py b/api/providers/vdb/vdb-oracle/tests/integration_tests/test_oraclevector.py similarity index 76% rename from api/tests/integration_tests/vdb/oracle/test_oraclevector.py rename to api/providers/vdb/vdb-oracle/tests/integration_tests/test_oraclevector.py index 8920dc97ebb..aceb41289cf 100644 --- a/api/tests/integration_tests/vdb/oracle/test_oraclevector.py +++ b/api/providers/vdb/vdb-oracle/tests/integration_tests/test_oraclevector.py @@ -1,11 +1,10 @@ -from core.rag.datasource.vdb.oracle.oraclevector import OracleVector, OracleVectorConfig -from core.rag.models.document import Document -from tests.integration_tests.vdb.test_vector_store import ( +from dify_vdb_oracle.oraclevector import OracleVector, OracleVectorConfig + +from core.rag.datasource.vdb.vector_integration_test_support import ( AbstractVectorTest, get_example_text, ) - -pytest_plugins = ("tests.integration_tests.vdb.test_vector_store",) +from core.rag.models.document import Document class OracleVectorTest(AbstractVectorTest): diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/oracle/test_oraclevector.py b/api/providers/vdb/vdb-oracle/tests/unit_tests/test_oraclevector.py similarity index 99% rename from api/tests/unit_tests/core/rag/datasource/vdb/oracle/test_oraclevector.py rename to api/providers/vdb/vdb-oracle/tests/unit_tests/test_oraclevector.py index 817a7d342b3..678cf876b07 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/oracle/test_oraclevector.py +++ b/api/providers/vdb/vdb-oracle/tests/unit_tests/test_oraclevector.py @@ -55,7 +55,7 @@ def oracle_module(monkeypatch): for name, module in _build_fake_oracle_modules().items(): monkeypatch.setitem(sys.modules, name, module) - import core.rag.datasource.vdb.oracle.oraclevector as module + import dify_vdb_oracle.oraclevector as module return importlib.reload(module) diff --git a/api/providers/vdb/vdb-pgvecto-rs/pyproject.toml b/api/providers/vdb/vdb-pgvecto-rs/pyproject.toml new file mode 100644 index 00000000000..9a25442e9ec --- /dev/null +++ b/api/providers/vdb/vdb-pgvecto-rs/pyproject.toml @@ -0,0 +1,14 @@ +[project] +name = "dify-vdb-pgvecto-rs" +version = "0.0.1" + +dependencies = [ + "pgvecto-rs[sqlalchemy]~=0.2.2", +] +description = "Dify vector store backend (dify-vdb-pgvecto-rs)." + +[project.entry-points."dify.vector_backends"] +pgvecto-rs = "dify_vdb_pgvecto_rs.pgvecto_rs:PGVectoRSFactory" + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/core/rag/datasource/vdb/tidb_vector/__init__.py b/api/providers/vdb/vdb-pgvecto-rs/src/dify_vdb_pgvecto_rs/__init__.py similarity index 100% rename from api/core/rag/datasource/vdb/tidb_vector/__init__.py rename to api/providers/vdb/vdb-pgvecto-rs/src/dify_vdb_pgvecto_rs/__init__.py diff --git a/api/core/rag/datasource/vdb/pgvecto_rs/collection.py b/api/providers/vdb/vdb-pgvecto-rs/src/dify_vdb_pgvecto_rs/collection.py similarity index 80% rename from api/core/rag/datasource/vdb/pgvecto_rs/collection.py rename to api/providers/vdb/vdb-pgvecto-rs/src/dify_vdb_pgvecto_rs/collection.py index c335bc610d7..e087ec30a5a 100644 --- a/api/core/rag/datasource/vdb/pgvecto_rs/collection.py +++ b/api/providers/vdb/vdb-pgvecto-rs/src/dify_vdb_pgvecto_rs/collection.py @@ -1,3 +1,4 @@ +from typing import Any from uuid import UUID from numpy import ndarray @@ -8,5 +9,5 @@ class CollectionORM(DeclarativeBase): __tablename__: str id: Mapped[UUID] text: Mapped[str] - meta: Mapped[dict] + meta: Mapped[dict[str, Any]] vector: Mapped[ndarray] diff --git a/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py b/api/providers/vdb/vdb-pgvecto-rs/src/dify_vdb_pgvecto_rs/pgvecto_rs.py similarity index 92% rename from api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py rename to api/providers/vdb/vdb-pgvecto-rs/src/dify_vdb_pgvecto_rs/pgvecto_rs.py index 90d9173409c..9c721c8bde5 100644 --- a/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py +++ b/api/providers/vdb/vdb-pgvecto-rs/src/dify_vdb_pgvecto_rs/pgvecto_rs.py @@ -9,15 +9,15 @@ from pydantic import BaseModel, model_validator from sqlalchemy import Float, create_engine, insert, select, text from sqlalchemy import text as sql_text from sqlalchemy.dialects import postgresql -from sqlalchemy.orm import Mapped, Session, mapped_column +from sqlalchemy.orm import Mapped, Session, mapped_column, sessionmaker from configs import dify_config -from core.rag.datasource.vdb.pgvecto_rs.collection import CollectionORM from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document +from dify_vdb_pgvecto_rs.collection import CollectionORM from extensions.ext_redis import redis_client from models.dataset import Dataset @@ -33,7 +33,7 @@ class PgvectoRSConfig(BaseModel): @model_validator(mode="before") @classmethod - def validate_config(cls, values: dict): + def validate_config(cls, values: dict[str, Any]): if not values["host"]: raise ValueError("config PGVECTO_RS_HOST is required") if not values["port"]: @@ -55,9 +55,8 @@ class PGVectoRS(BaseVector): f"postgresql+psycopg2://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}" ) self._client = create_engine(self._url) - with Session(self._client) as session: + with sessionmaker(bind=self._client).begin() as session: session.execute(text("CREATE EXTENSION IF NOT EXISTS vectors")) - session.commit() self._fields: list[str] = [] class _Table(CollectionORM): @@ -68,7 +67,7 @@ class PGVectoRS(BaseVector): primary_key=True, ) text: Mapped[str] - meta: Mapped[dict] = mapped_column(postgresql.JSONB) + meta: Mapped[dict[str, Any]] = mapped_column(postgresql.JSONB) vector: Mapped[ndarray] = mapped_column(VECTOR(dim)) self._table = _Table @@ -88,7 +87,7 @@ class PGVectoRS(BaseVector): if redis_client.get(collection_exist_cache_key): return index_name = f"{self._collection_name}_embedding_index" - with Session(self._client) as session: + with sessionmaker(bind=self._client).begin() as session: create_statement = sql_text(f""" CREATE TABLE IF NOT EXISTS {self._collection_name} ( id UUID PRIMARY KEY, @@ -111,12 +110,11 @@ class PGVectoRS(BaseVector): $$); """) session.execute(index_statement) - session.commit() redis_client.set(collection_exist_cache_key, 1, ex=3600) def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): pks = [] - with Session(self._client) as session: + with sessionmaker(bind=self._client).begin() as session: for document, embedding in zip(documents, embeddings): pk = uuid4() session.execute( @@ -128,7 +126,6 @@ class PGVectoRS(BaseVector): ), ) pks.append(pk) - session.commit() return pks @@ -145,10 +142,9 @@ class PGVectoRS(BaseVector): def delete_by_metadata_field(self, key: str, value: str): ids = self.get_ids_by_metadata_field(key, value) if ids: - with Session(self._client) as session: + with sessionmaker(bind=self._client).begin() as session: select_statement = sql_text(f"DELETE FROM {self._collection_name} WHERE id = ANY(:ids)") session.execute(select_statement, {"ids": ids}) - session.commit() def delete_by_ids(self, ids: list[str]): with Session(self._client) as session: @@ -159,15 +155,13 @@ class PGVectoRS(BaseVector): if result: ids = [item[0] for item in result] if ids: - with Session(self._client) as session: + with sessionmaker(bind=self._client).begin() as session: select_statement = sql_text(f"DELETE FROM {self._collection_name} WHERE id = ANY(:ids)") session.execute(select_statement, {"ids": ids}) - session.commit() def delete(self): - with Session(self._client) as session: + with sessionmaker(bind=self._client).begin() as session: session.execute(sql_text(f"DROP TABLE IF EXISTS {self._collection_name}")) - session.commit() def text_exists(self, id: str) -> bool: with Session(self._client) as session: diff --git a/api/tests/integration_tests/vdb/pgvecto_rs/test_pgvecto_rs.py b/api/providers/vdb/vdb-pgvecto-rs/tests/integration_tests/test_pgvecto_rs.py similarity index 82% rename from api/tests/integration_tests/vdb/pgvecto_rs/test_pgvecto_rs.py rename to api/providers/vdb/vdb-pgvecto-rs/tests/integration_tests/test_pgvecto_rs.py index 6210613d421..9fc86278519 100644 --- a/api/tests/integration_tests/vdb/pgvecto_rs/test_pgvecto_rs.py +++ b/api/providers/vdb/vdb-pgvecto-rs/tests/integration_tests/test_pgvecto_rs.py @@ -1,11 +1,10 @@ -from core.rag.datasource.vdb.pgvecto_rs.pgvecto_rs import PGVectoRS, PgvectoRSConfig -from tests.integration_tests.vdb.test_vector_store import ( +from dify_vdb_pgvecto_rs.pgvecto_rs import PGVectoRS, PgvectoRSConfig + +from core.rag.datasource.vdb.vector_integration_test_support import ( AbstractVectorTest, get_example_text, ) -pytest_plugins = ("tests.integration_tests.vdb.test_vector_store",) - class PGVectoRSVectorTest(AbstractVectorTest): def __init__(self): diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/pgvecto_rs/test_pgvecto_rs.py b/api/providers/vdb/vdb-pgvecto-rs/tests/unit_tests/test_pgvecto_rs.py similarity index 87% rename from api/tests/unit_tests/core/rag/datasource/vdb/pgvecto_rs/test_pgvecto_rs.py rename to api/providers/vdb/vdb-pgvecto-rs/tests/unit_tests/test_pgvecto_rs.py index 1aec81b8ac8..c3291f7f125 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/pgvecto_rs/test_pgvecto_rs.py +++ b/api/providers/vdb/vdb-pgvecto-rs/tests/unit_tests/test_pgvecto_rs.py @@ -53,13 +53,38 @@ def _session_factory(calls, execute_results=None): return _session +class _FakeBeginContext: + def __init__(self, session): + self._session = session + + def __enter__(self): + return self._session + + def __exit__(self, exc_type, exc, tb): + return None + + +def _sessionmaker_factory(calls, execute_results=None): + def _sessionmaker(*args, **kwargs): + session = _FakeSessionContext(calls=calls, execute_results=execute_results) + return MagicMock(begin=MagicMock(return_value=_FakeBeginContext(session))) + + return _sessionmaker + + +def _patch_both(monkeypatch, module, calls, execute_results=None): + """Patch both Session and sessionmaker on the module with the same call tracker.""" + monkeypatch.setattr(module, "Session", _session_factory(calls, execute_results)) + monkeypatch.setattr(module, "sessionmaker", _sessionmaker_factory(calls, execute_results)) + + @pytest.fixture def pgvecto_module(monkeypatch): for name, module in _build_fake_pgvecto_modules().items(): monkeypatch.setitem(sys.modules, name, module) - import core.rag.datasource.vdb.pgvecto_rs.collection as collection_module - import core.rag.datasource.vdb.pgvecto_rs.pgvecto_rs as module + import dify_vdb_pgvecto_rs.collection as collection_module + import dify_vdb_pgvecto_rs.pgvecto_rs as module return importlib.reload(module), importlib.reload(collection_module) @@ -105,7 +130,7 @@ def test_init_get_type_and_create_delegate(pgvecto_module, monkeypatch): module, _ = pgvecto_module session_calls = [] monkeypatch.setattr(module, "create_engine", MagicMock(return_value="engine")) - monkeypatch.setattr(module, "Session", _session_factory(session_calls)) + _patch_both(monkeypatch, module, session_calls) vector = module.PGVectoRS("collection_1", _config(module), dim=3) vector.create_collection = MagicMock() @@ -124,7 +149,7 @@ def test_create_collection_cache_and_sql_execution(pgvecto_module, monkeypatch): module, _ = pgvecto_module session_calls = [] monkeypatch.setattr(module, "create_engine", MagicMock(return_value="engine")) - monkeypatch.setattr(module, "Session", _session_factory(session_calls)) + _patch_both(monkeypatch, module, session_calls) lock = MagicMock() lock.__enter__.return_value = None @@ -151,10 +176,10 @@ def test_add_texts_get_ids_and_delete_methods(pgvecto_module, monkeypatch): execute_results = [SimpleNamespace(fetchall=lambda: [("id-1",), ("id-2",)]), SimpleNamespace(fetchall=lambda: [])] monkeypatch.setattr(module, "create_engine", MagicMock(return_value="engine")) - monkeypatch.setattr(module, "Session", _session_factory(init_calls)) + _patch_both(monkeypatch, module, init_calls) vector = module.PGVectoRS("collection_1", _config(module), dim=3) - monkeypatch.setattr(module, "Session", _session_factory(runtime_calls, execute_results=list(execute_results))) + _patch_both(monkeypatch, module, runtime_calls, execute_results=list(execute_results)) class _InsertBuilder: def __init__(self, table): @@ -179,6 +204,7 @@ def test_add_texts_get_ids_and_delete_methods(pgvecto_module, monkeypatch): "Session", _session_factory(runtime_calls, execute_results=[SimpleNamespace(fetchall=lambda: [("id-1",), ("id-2",)])]), ) + monkeypatch.setattr(module, "sessionmaker", _sessionmaker_factory(runtime_calls)) assert vector.get_ids_by_metadata_field("document_id", "doc-1") == ["id-1", "id-2"] monkeypatch.setattr( @@ -204,12 +230,13 @@ def test_add_texts_get_ids_and_delete_methods(pgvecto_module, monkeypatch): ], ), ) + monkeypatch.setattr(module, "sessionmaker", _sessionmaker_factory(runtime_calls)) vector.delete_by_ids(["doc-1"]) assert any("meta->>'doc_id' = ANY (:doc_ids)" in str(args[0]) for args, _ in runtime_calls) assert any("DELETE FROM collection_1 WHERE id = ANY(:ids)" in str(args[0]) for args, _ in runtime_calls) runtime_calls.clear() - monkeypatch.setattr(module, "Session", _session_factory(runtime_calls, execute_results=[MagicMock()])) + _patch_both(monkeypatch, module, runtime_calls, execute_results=[MagicMock()]) vector.delete() assert any("DROP TABLE IF EXISTS collection_1" in str(args[0]) for args, _ in runtime_calls) @@ -218,7 +245,7 @@ def test_text_exists_search_and_full_text(pgvecto_module, monkeypatch): module, _ = pgvecto_module init_calls = [] monkeypatch.setattr(module, "create_engine", MagicMock(return_value="engine")) - monkeypatch.setattr(module, "Session", _session_factory(init_calls)) + _patch_both(monkeypatch, module, init_calls) vector = module.PGVectoRS("collection_1", _config(module), dim=3) runtime_calls = [] @@ -277,7 +304,7 @@ def test_text_exists_search_and_full_text(pgvecto_module, monkeypatch): (SimpleNamespace(meta={"doc_id": "1"}, text="text-1"), 0.1), (SimpleNamespace(meta={"doc_id": "2"}, text="text-2"), 0.8), ] - monkeypatch.setattr(module, "Session", _session_factory(runtime_calls, execute_results=[rows])) + _patch_both(monkeypatch, module, runtime_calls, execute_results=[rows]) docs = vector.search_by_vector([0.1, 0.2], top_k=2, score_threshold=0.5, document_ids_filter=["d-1"]) assert len(docs) == 1 diff --git a/api/providers/vdb/vdb-pgvector/pyproject.toml b/api/providers/vdb/vdb-pgvector/pyproject.toml new file mode 100644 index 00000000000..2a972aa2772 --- /dev/null +++ b/api/providers/vdb/vdb-pgvector/pyproject.toml @@ -0,0 +1,14 @@ +[project] +name = "dify-vdb-pgvector" +version = "0.0.1" + +dependencies = [ + "pgvector==0.4.2", +] +description = "Dify vector store backend (dify-vdb-pgvector)." + +[project.entry-points."dify.vector_backends"] +pgvector = "dify_vdb_pgvector.pgvector:PGVectorFactory" + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/core/rag/datasource/vdb/upstash/__init__.py b/api/providers/vdb/vdb-pgvector/src/dify_vdb_pgvector/__init__.py similarity index 100% rename from api/core/rag/datasource/vdb/upstash/__init__.py rename to api/providers/vdb/vdb-pgvector/src/dify_vdb_pgvector/__init__.py diff --git a/api/core/rag/datasource/vdb/pgvector/pgvector.py b/api/providers/vdb/vdb-pgvector/src/dify_vdb_pgvector/pgvector.py similarity index 99% rename from api/core/rag/datasource/vdb/pgvector/pgvector.py rename to api/providers/vdb/vdb-pgvector/src/dify_vdb_pgvector/pgvector.py index 0615b8312cb..b1bdce0ad44 100644 --- a/api/core/rag/datasource/vdb/pgvector/pgvector.py +++ b/api/providers/vdb/vdb-pgvector/src/dify_vdb_pgvector/pgvector.py @@ -34,7 +34,7 @@ class PGVectorConfig(BaseModel): @model_validator(mode="before") @classmethod - def validate_config(cls, values: dict): + def validate_config(cls, values: dict[str, Any]): if not values["host"]: raise ValueError("config PGVECTOR_HOST is required") if not values["port"]: diff --git a/api/tests/integration_tests/vdb/pgvector/test_pgvector.py b/api/providers/vdb/vdb-pgvector/tests/integration_tests/test_pgvector.py similarity index 73% rename from api/tests/integration_tests/vdb/pgvector/test_pgvector.py rename to api/providers/vdb/vdb-pgvector/tests/integration_tests/test_pgvector.py index 4fdeca5a3a8..974657510ec 100644 --- a/api/tests/integration_tests/vdb/pgvector/test_pgvector.py +++ b/api/providers/vdb/vdb-pgvector/tests/integration_tests/test_pgvector.py @@ -1,10 +1,9 @@ -from core.rag.datasource.vdb.pgvector.pgvector import PGVector, PGVectorConfig -from tests.integration_tests.vdb.test_vector_store import ( +from dify_vdb_pgvector.pgvector import PGVector, PGVectorConfig + +from core.rag.datasource.vdb.vector_integration_test_support import ( AbstractVectorTest, ) -pytest_plugins = ("tests.integration_tests.vdb.test_vector_store",) - class PGVectorTest(AbstractVectorTest): def __init__(self): diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/pgvector/test_pgvector.py b/api/providers/vdb/vdb-pgvector/tests/unit_tests/test_pgvector.py similarity index 92% rename from api/tests/unit_tests/core/rag/datasource/vdb/pgvector/test_pgvector.py rename to api/providers/vdb/vdb-pgvector/tests/unit_tests/test_pgvector.py index 7505262eb78..99a6e00c16a 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/pgvector/test_pgvector.py +++ b/api/providers/vdb/vdb-pgvector/tests/unit_tests/test_pgvector.py @@ -2,13 +2,10 @@ from contextlib import contextmanager from types import SimpleNamespace from unittest.mock import MagicMock, patch +import dify_vdb_pgvector.pgvector as pgvector_module import pytest +from dify_vdb_pgvector.pgvector import PGVector, PGVectorConfig -import core.rag.datasource.vdb.pgvector.pgvector as pgvector_module -from core.rag.datasource.vdb.pgvector.pgvector import ( - PGVector, - PGVectorConfig, -) from core.rag.models.document import Document @@ -26,7 +23,7 @@ class TestPGVector: ) self.collection_name = "test_collection" - @patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool") + @patch("dify_vdb_pgvector.pgvector.psycopg2.pool.SimpleConnectionPool") def test_init(self, mock_pool_class): """Test PGVector initialization.""" mock_pool = MagicMock() @@ -41,7 +38,7 @@ class TestPGVector: assert pgvector.pg_bigm is False assert pgvector.index_hash is not None - @patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool") + @patch("dify_vdb_pgvector.pgvector.psycopg2.pool.SimpleConnectionPool") def test_init_with_pg_bigm(self, mock_pool_class): """Test PGVector initialization with pg_bigm enabled.""" config = PGVectorConfig( @@ -61,8 +58,8 @@ class TestPGVector: assert pgvector.pg_bigm is True - @patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool") - @patch("core.rag.datasource.vdb.pgvector.pgvector.redis_client") + @patch("dify_vdb_pgvector.pgvector.psycopg2.pool.SimpleConnectionPool") + @patch("dify_vdb_pgvector.pgvector.redis_client") def test_create_collection_basic(self, mock_redis, mock_pool_class): """Test basic collection creation.""" # Mock Redis operations @@ -104,8 +101,8 @@ class TestPGVector: # Verify Redis cache was set mock_redis.set.assert_called_once() - @patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool") - @patch("core.rag.datasource.vdb.pgvector.pgvector.redis_client") + @patch("dify_vdb_pgvector.pgvector.psycopg2.pool.SimpleConnectionPool") + @patch("dify_vdb_pgvector.pgvector.redis_client") def test_create_collection_with_large_dimension(self, mock_redis, mock_pool_class): """Test collection creation with dimension > 2000 (no HNSW index).""" # Mock Redis operations @@ -139,8 +136,8 @@ class TestPGVector: hnsw_index_calls = [call for call in mock_cursor.execute.call_args_list if "hnsw" in str(call)] assert len(hnsw_index_calls) == 0 - @patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool") - @patch("core.rag.datasource.vdb.pgvector.pgvector.redis_client") + @patch("dify_vdb_pgvector.pgvector.psycopg2.pool.SimpleConnectionPool") + @patch("dify_vdb_pgvector.pgvector.redis_client") def test_create_collection_with_pg_bigm(self, mock_redis, mock_pool_class): """Test collection creation with pg_bigm enabled.""" config = PGVectorConfig( @@ -180,8 +177,8 @@ class TestPGVector: bigm_index_calls = [call for call in mock_cursor.execute.call_args_list if "gin_bigm_ops" in str(call)] assert len(bigm_index_calls) == 1 - @patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool") - @patch("core.rag.datasource.vdb.pgvector.pgvector.redis_client") + @patch("dify_vdb_pgvector.pgvector.psycopg2.pool.SimpleConnectionPool") + @patch("dify_vdb_pgvector.pgvector.redis_client") def test_create_collection_creates_vector_extension(self, mock_redis, mock_pool_class): """Test that vector extension is created if it doesn't exist.""" # Mock Redis operations @@ -213,8 +210,8 @@ class TestPGVector: ] assert len(create_extension_calls) == 1 - @patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool") - @patch("core.rag.datasource.vdb.pgvector.pgvector.redis_client") + @patch("dify_vdb_pgvector.pgvector.psycopg2.pool.SimpleConnectionPool") + @patch("dify_vdb_pgvector.pgvector.redis_client") def test_create_collection_with_cache_hit(self, mock_redis, mock_pool_class): """Test that collection creation is skipped when cache exists.""" # Mock Redis operations - cache exists @@ -240,8 +237,8 @@ class TestPGVector: # Check that no SQL was executed (early return due to cache) assert mock_cursor.execute.call_count == 0 - @patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool") - @patch("core.rag.datasource.vdb.pgvector.pgvector.redis_client") + @patch("dify_vdb_pgvector.pgvector.psycopg2.pool.SimpleConnectionPool") + @patch("dify_vdb_pgvector.pgvector.redis_client") def test_create_collection_with_redis_lock(self, mock_redis, mock_pool_class): """Test that Redis lock is used during collection creation.""" # Mock Redis operations @@ -273,7 +270,7 @@ class TestPGVector: mock_lock.__enter__.assert_called_once() mock_lock.__exit__.assert_called_once() - @patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool") + @patch("dify_vdb_pgvector.pgvector.psycopg2.pool.SimpleConnectionPool") def test_get_cursor_context_manager(self, mock_pool_class): """Test that _get_cursor properly manages connection lifecycle.""" mock_pool = MagicMock() diff --git a/api/providers/vdb/vdb-qdrant/pyproject.toml b/api/providers/vdb/vdb-qdrant/pyproject.toml new file mode 100644 index 00000000000..6dd0b9560b5 --- /dev/null +++ b/api/providers/vdb/vdb-qdrant/pyproject.toml @@ -0,0 +1,14 @@ +[project] +name = "dify-vdb-qdrant" +version = "0.0.1" + +dependencies = [ + "qdrant-client==1.9.0", +] +description = "Dify vector store backend (dify-vdb-qdrant)." + +[project.entry-points."dify.vector_backends"] +qdrant = "dify_vdb_qdrant.qdrant_vector:QdrantVectorFactory" + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/core/rag/datasource/vdb/vikingdb/__init__.py b/api/providers/vdb/vdb-qdrant/src/dify_vdb_qdrant/__init__.py similarity index 100% rename from api/core/rag/datasource/vdb/vikingdb/__init__.py rename to api/providers/vdb/vdb-qdrant/src/dify_vdb_qdrant/__init__.py diff --git a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py b/api/providers/vdb/vdb-qdrant/src/dify_vdb_qdrant/qdrant_vector.py similarity index 97% rename from api/core/rag/datasource/vdb/qdrant/qdrant_vector.py rename to api/providers/vdb/vdb-qdrant/src/dify_vdb_qdrant/qdrant_vector.py index a9f946dd43c..b5ff87fc5d6 100644 --- a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py +++ b/api/providers/vdb/vdb-qdrant/src/dify_vdb_qdrant/qdrant_vector.py @@ -3,7 +3,7 @@ import os import uuid from collections.abc import Generator, Iterable, Sequence from itertools import islice -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast import qdrant_client from flask import current_app @@ -22,7 +22,7 @@ from sqlalchemy import select from configs import dify_config from core.rag.datasource.vdb.field import Field -from core.rag.datasource.vdb.vector_base import BaseVector +from core.rag.datasource.vdb.vector_base import BaseVector, VectorIndexStructDict from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType from core.rag.embedding.embedding_base import Embeddings @@ -32,7 +32,6 @@ from extensions.ext_redis import redis_client from models.dataset import Dataset, DatasetCollectionBinding if TYPE_CHECKING: - from qdrant_client import grpc # noqa from qdrant_client.conversions import common_types from qdrant_client.http import models as rest @@ -94,8 +93,12 @@ class QdrantVector(BaseVector): def get_type(self) -> str: return VectorType.QDRANT - def to_index_struct(self): - return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}} + def to_index_struct(self) -> VectorIndexStructDict: + result: VectorIndexStructDict = { + "type": self.get_type(), + "vector_store": {"class_prefix": self._collection_name}, + } + return result def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): if texts: @@ -176,7 +179,7 @@ class QdrantVector(BaseVector): for batch_ids, points in self._generate_rest_batches( texts, embeddings, filtered_metadatas, uuids, 64, self._group_id ): - self._client.upsert(collection_name=self._collection_name, points=points) + self._client.upsert(collection_name=self._collection_name, points=cast("common_types.Points", points)) added_ids.extend(batch_ids) return added_ids @@ -468,7 +471,7 @@ class QdrantVector(BaseVector): def _reload_if_needed(self): if isinstance(self._client, QdrantLocal): - self._client._load() + self._client._load() # pyright: ignore[reportPrivateUsage] @classmethod def _document_from_scored_point( diff --git a/api/tests/integration_tests/vdb/qdrant/test_qdrant.py b/api/providers/vdb/vdb-qdrant/tests/integration_tests/test_qdrant.py similarity index 95% rename from api/tests/integration_tests/vdb/qdrant/test_qdrant.py rename to api/providers/vdb/vdb-qdrant/tests/integration_tests/test_qdrant.py index 709cc2e14ef..e0badeb5de2 100644 --- a/api/tests/integration_tests/vdb/qdrant/test_qdrant.py +++ b/api/providers/vdb/vdb-qdrant/tests/integration_tests/test_qdrant.py @@ -1,12 +1,11 @@ import uuid -from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantConfig, QdrantVector -from core.rag.models.document import Document -from tests.integration_tests.vdb.test_vector_store import ( +from dify_vdb_qdrant.qdrant_vector import QdrantConfig, QdrantVector + +from core.rag.datasource.vdb.vector_integration_test_support import ( AbstractVectorTest, ) - -pytest_plugins = ("tests.integration_tests.vdb.test_vector_store",) +from core.rag.models.document import Document class QdrantVectorTest(AbstractVectorTest): diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/qdrant/test_qdrant_vector.py b/api/providers/vdb/vdb-qdrant/tests/unit_tests/test_qdrant_vector.py similarity index 99% rename from api/tests/unit_tests/core/rag/datasource/vdb/qdrant/test_qdrant_vector.py rename to api/providers/vdb/vdb-qdrant/tests/unit_tests/test_qdrant_vector.py index 04085065638..0ed5491fbe0 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/qdrant/test_qdrant_vector.py +++ b/api/providers/vdb/vdb-qdrant/tests/unit_tests/test_qdrant_vector.py @@ -125,7 +125,7 @@ def qdrant_module(monkeypatch): for name, module in _build_fake_qdrant_modules().items(): monkeypatch.setitem(sys.modules, name, module) - import core.rag.datasource.vdb.qdrant.qdrant_vector as module + import dify_vdb_qdrant.qdrant_vector as module return importlib.reload(module) diff --git a/api/providers/vdb/vdb-relyt/pyproject.toml b/api/providers/vdb/vdb-relyt/pyproject.toml new file mode 100644 index 00000000000..2a7c7fac875 --- /dev/null +++ b/api/providers/vdb/vdb-relyt/pyproject.toml @@ -0,0 +1,12 @@ +[project] +name = "dify-vdb-relyt" +version = "0.0.1" + +dependencies = [] +description = "Dify vector store backend (dify-vdb-relyt)." + +[project.entry-points."dify.vector_backends"] +relyt = "dify_vdb_relyt.relyt_vector:RelytVectorFactory" + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/core/rag/datasource/vdb/weaviate/__init__.py b/api/providers/vdb/vdb-relyt/src/dify_vdb_relyt/__init__.py similarity index 100% rename from api/core/rag/datasource/vdb/weaviate/__init__.py rename to api/providers/vdb/vdb-relyt/src/dify_vdb_relyt/__init__.py diff --git a/api/core/rag/datasource/vdb/relyt/relyt_vector.py b/api/providers/vdb/vdb-relyt/src/dify_vdb_relyt/relyt_vector.py similarity index 97% rename from api/core/rag/datasource/vdb/relyt/relyt_vector.py rename to api/providers/vdb/vdb-relyt/src/dify_vdb_relyt/relyt_vector.py index e486375ec2a..336c2d3c8af 100644 --- a/api/core/rag/datasource/vdb/relyt/relyt_vector.py +++ b/api/providers/vdb/vdb-relyt/src/dify_vdb_relyt/relyt_vector.py @@ -7,7 +7,7 @@ from pydantic import BaseModel, model_validator from sqlalchemy import Column, String, Table, create_engine, insert from sqlalchemy import text as sql_text from sqlalchemy.dialects.postgresql import JSON, TEXT -from sqlalchemy.orm import Session +from sqlalchemy.orm import Session, sessionmaker from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType @@ -26,7 +26,7 @@ from extensions.ext_redis import redis_client logger = logging.getLogger(__name__) -Base = declarative_base() # type: Any +Base: Any = declarative_base() class RelytConfig(BaseModel): @@ -38,7 +38,7 @@ class RelytConfig(BaseModel): @model_validator(mode="before") @classmethod - def validate_config(cls, values: dict): + def validate_config(cls, values: dict[str, Any]): if not values["host"]: raise ValueError("config RELYT_HOST is required") if not values["port"]: @@ -79,7 +79,7 @@ class RelytVector(BaseVector): if redis_client.get(collection_exist_cache_key): return index_name = f"{self._collection_name}_embedding_index" - with Session(self.client) as session: + with sessionmaker(bind=self.client).begin() as session: drop_statement = sql_text(f"""DROP TABLE IF EXISTS "{self._collection_name}"; """) session.execute(drop_statement) create_statement = sql_text(f""" @@ -104,7 +104,6 @@ class RelytVector(BaseVector): $$); """) session.execute(index_statement) - session.commit() redis_client.set(collection_exist_cache_key, 1, ex=3600) def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): @@ -208,9 +207,8 @@ class RelytVector(BaseVector): self.delete_by_uuids(ids) def delete(self): - with Session(self.client) as session: + with sessionmaker(bind=self.client).begin() as session: session.execute(sql_text(f"""DROP TABLE IF EXISTS "{self._collection_name}";""")) - session.commit() def text_exists(self, id: str) -> bool: with Session(self.client) as session: @@ -241,7 +239,7 @@ class RelytVector(BaseVector): self, embedding: list[float], k: int = 4, - filter: dict | None = None, + filter: dict[str, Any] | None = None, ) -> list[tuple[Document, float]]: # Add the filter if provided diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/relyt/test_relyt_vector.py b/api/providers/vdb/vdb-relyt/tests/unit_tests/test_relyt_vector.py similarity index 93% rename from api/tests/unit_tests/core/rag/datasource/vdb/relyt/test_relyt_vector.py rename to api/providers/vdb/vdb-relyt/tests/unit_tests/test_relyt_vector.py index ca8cd5e5140..f97ad1400ae 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/relyt/test_relyt_vector.py +++ b/api/providers/vdb/vdb-relyt/tests/unit_tests/test_relyt_vector.py @@ -39,12 +39,31 @@ class _FakeSession: return None +class _FakeBeginContext: + def __init__(self, session): + self._session = session + + def __enter__(self): + return self._session + + def __exit__(self, exc_type, exc, tb): + return None + + +def _patch_both(monkeypatch, module, session): + """Patch both Session and sessionmaker on the module.""" + monkeypatch.setattr(module, "Session", lambda _client: session) + monkeypatch.setattr( + module, "sessionmaker", lambda **kwargs: MagicMock(begin=MagicMock(return_value=_FakeBeginContext(session))) + ) + + @pytest.fixture def relyt_module(monkeypatch): for name, module in _build_fake_relyt_modules().items(): monkeypatch.setitem(sys.modules, name, module) - import core.rag.datasource.vdb.relyt.relyt_vector as module + import dify_vdb_relyt.relyt_vector as module return importlib.reload(module) @@ -108,13 +127,13 @@ def test_create_collection_cache_and_sql_execution(relyt_module, monkeypatch): monkeypatch.setattr(relyt_module.redis_client, "get", MagicMock(return_value=1)) session = _FakeSession() - monkeypatch.setattr(relyt_module, "Session", lambda _client: session) + _patch_both(monkeypatch, relyt_module, session) vector.create_collection(3) session.execute.assert_not_called() monkeypatch.setattr(relyt_module.redis_client, "get", MagicMock(return_value=None)) session = _FakeSession() - monkeypatch.setattr(relyt_module, "Session", lambda _client: session) + _patch_both(monkeypatch, relyt_module, session) vector.create_collection(3) executed_sql = [str(call.args[0]) for call in session.execute.call_args_list] assert any("DROP TABLE IF EXISTS" in sql for sql in executed_sql) @@ -265,15 +284,15 @@ def test_search_by_vector_filters_by_score_and_ids(relyt_module): # 8. delete commits session -def test_delete_commits_session(relyt_module, monkeypatch): +def test_delete_drops_table(relyt_module, monkeypatch): vector = relyt_module.RelytVector.__new__(relyt_module.RelytVector) vector._collection_name = "collection_1" vector.client = MagicMock() vector.embedding_dimension = 3 session = _FakeSession() - monkeypatch.setattr(relyt_module, "Session", lambda _client: session) + _patch_both(monkeypatch, relyt_module, session) vector.delete() - session.commit.assert_called_once() + session.execute.assert_called_once() def test_relyt_factory_existing_and_generated_collection(relyt_module, monkeypatch): diff --git a/api/providers/vdb/vdb-tablestore/pyproject.toml b/api/providers/vdb/vdb-tablestore/pyproject.toml new file mode 100644 index 00000000000..fd1a2d54e0c --- /dev/null +++ b/api/providers/vdb/vdb-tablestore/pyproject.toml @@ -0,0 +1,14 @@ +[project] +name = "dify-vdb-tablestore" +version = "0.0.1" + +dependencies = [ + "tablestore==6.4.4", +] +description = "Dify vector store backend (dify-vdb-tablestore)." + +[project.entry-points."dify.vector_backends"] +tablestore = "dify_vdb_tablestore.tablestore_vector:TableStoreVectorFactory" + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/tests/integration_tests/vdb/__mock/__init__.py b/api/providers/vdb/vdb-tablestore/src/dify_vdb_tablestore/__init__.py similarity index 100% rename from api/tests/integration_tests/vdb/__mock/__init__.py rename to api/providers/vdb/vdb-tablestore/src/dify_vdb_tablestore/__init__.py diff --git a/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py b/api/providers/vdb/vdb-tablestore/src/dify_vdb_tablestore/tablestore_vector.py similarity index 99% rename from api/core/rag/datasource/vdb/tablestore/tablestore_vector.py rename to api/providers/vdb/vdb-tablestore/src/dify_vdb_tablestore/tablestore_vector.py index 4a734232ec1..f9deac11e5e 100644 --- a/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py +++ b/api/providers/vdb/vdb-tablestore/src/dify_vdb_tablestore/tablestore_vector.py @@ -30,7 +30,7 @@ class TableStoreConfig(BaseModel): @model_validator(mode="before") @classmethod - def validate_config(cls, values: dict): + def validate_config(cls, values: dict[str, Any]): if not values["access_key_id"]: raise ValueError("config ACCESS_KEY_ID is required") if not values["access_key_secret"]: diff --git a/api/tests/integration_tests/vdb/tablestore/test_tablestore.py b/api/providers/vdb/vdb-tablestore/tests/integration_tests/test_tablestore.py similarity index 93% rename from api/tests/integration_tests/vdb/tablestore/test_tablestore.py rename to api/providers/vdb/vdb-tablestore/tests/integration_tests/test_tablestore.py index b60e26a881d..97c9626ee19 100644 --- a/api/tests/integration_tests/vdb/tablestore/test_tablestore.py +++ b/api/providers/vdb/vdb-tablestore/tests/integration_tests/test_tablestore.py @@ -1,20 +1,21 @@ +import logging import os import uuid import tablestore from _pytest.python_api import approx - -from core.rag.datasource.vdb.tablestore.tablestore_vector import ( +from dify_vdb_tablestore.tablestore_vector import ( TableStoreConfig, TableStoreVector, ) -from tests.integration_tests.vdb.test_vector_store import ( + +from core.rag.datasource.vdb.vector_integration_test_support import ( AbstractVectorTest, get_example_document, get_example_text, ) -pytest_plugins = ("tests.integration_tests.vdb.test_vector_store",) +logger = logging.getLogger(__name__) class TableStoreVectorTest(AbstractVectorTest): @@ -90,7 +91,7 @@ class TableStoreVectorTest(AbstractVectorTest): try: self.vector.delete() except Exception: - pass + logger.debug("Failed to delete vector store during test setup, it may not exist yet") return super().run_all_tests() diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/tablestore/test_tablestore_vector.py b/api/providers/vdb/vdb-tablestore/tests/unit_tests/test_tablestore_vector.py similarity index 99% rename from api/tests/unit_tests/core/rag/datasource/vdb/tablestore/test_tablestore_vector.py rename to api/providers/vdb/vdb-tablestore/tests/unit_tests/test_tablestore_vector.py index e3b6676d9bf..62a11e04457 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/tablestore/test_tablestore_vector.py +++ b/api/providers/vdb/vdb-tablestore/tests/unit_tests/test_tablestore_vector.py @@ -81,7 +81,7 @@ def tablestore_module(monkeypatch): fake_module = _build_fake_tablestore_module() monkeypatch.setitem(sys.modules, "tablestore", fake_module) - import core.rag.datasource.vdb.tablestore.tablestore_vector as module + import dify_vdb_tablestore.tablestore_vector as module return importlib.reload(module) diff --git a/api/providers/vdb/vdb-tencent/pyproject.toml b/api/providers/vdb/vdb-tencent/pyproject.toml new file mode 100644 index 00000000000..7bb761b1696 --- /dev/null +++ b/api/providers/vdb/vdb-tencent/pyproject.toml @@ -0,0 +1,14 @@ +[project] +name = "dify-vdb-tencent" +version = "0.0.1" + +dependencies = [ + "tcvectordb~=2.1.0", +] +description = "Dify vector store backend (dify-vdb-tencent)." + +[project.entry-points."dify.vector_backends"] +tencent = "dify_vdb_tencent.tencent_vector:TencentVectorFactory" + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/tests/integration_tests/vdb/analyticdb/__init__.py b/api/providers/vdb/vdb-tencent/src/dify_vdb_tencent/__init__.py similarity index 100% rename from api/tests/integration_tests/vdb/analyticdb/__init__.py rename to api/providers/vdb/vdb-tencent/src/dify_vdb_tencent/__init__.py diff --git a/api/core/rag/datasource/vdb/tencent/tencent_vector.py b/api/providers/vdb/vdb-tencent/src/dify_vdb_tencent/tencent_vector.py similarity index 94% rename from api/core/rag/datasource/vdb/tencent/tencent_vector.py rename to api/providers/vdb/vdb-tencent/src/dify_vdb_tencent/tencent_vector.py index 829db9db20a..2f26d6fff3f 100644 --- a/api/core/rag/datasource/vdb/tencent/tencent_vector.py +++ b/api/providers/vdb/vdb-tencent/src/dify_vdb_tencent/tencent_vector.py @@ -1,7 +1,7 @@ import json import logging import math -from typing import Any +from typing import Any, TypedDict from pydantic import BaseModel from tcvdb_text.encoder import BM25Encoder # type: ignore @@ -12,7 +12,7 @@ from tcvectordb.model.document import AnnSearch, Filter, KeywordSearch, Weighted from configs import dify_config from core.rag.datasource.vdb.field import parse_metadata_json -from core.rag.datasource.vdb.vector_base import BaseVector +from core.rag.datasource.vdb.vector_base import BaseVector, VectorIndexStructDict from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType from core.rag.embedding.embedding_base import Embeddings @@ -23,6 +23,13 @@ from models.dataset import Dataset logger = logging.getLogger(__name__) +class TencentParamsDict(TypedDict): + url: str + username: str | None + key: str | None + timeout: float + + class TencentConfig(BaseModel): url: str api_key: str | None = None @@ -36,8 +43,14 @@ class TencentConfig(BaseModel): max_upsert_batch_size: int = 128 enable_hybrid_search: bool = False # Flag to enable hybrid search - def to_tencent_params(self): - return {"url": self.url, "username": self.username, "key": self.api_key, "timeout": self.timeout} + def to_tencent_params(self) -> TencentParamsDict: + result: TencentParamsDict = { + "url": self.url, + "username": self.username, + "key": self.api_key, + "timeout": self.timeout, + } + return result bm25 = BM25Encoder.default("zh") @@ -83,8 +96,12 @@ class TencentVector(BaseVector): def get_type(self) -> str: return VectorType.TENCENT - def to_index_struct(self): - return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}} + def to_index_struct(self) -> VectorIndexStructDict: + result: VectorIndexStructDict = { + "type": self.get_type(), + "vector_store": {"class_prefix": self._collection_name}, + } + return result def _has_collection(self) -> bool: return bool( diff --git a/api/tests/integration_tests/vdb/__mock/tcvectordb.py b/api/providers/vdb/vdb-tencent/tests/integration_tests/conftest.py similarity index 100% rename from api/tests/integration_tests/vdb/__mock/tcvectordb.py rename to api/providers/vdb/vdb-tencent/tests/integration_tests/conftest.py diff --git a/api/tests/integration_tests/vdb/tcvectordb/test_tencent.py b/api/providers/vdb/vdb-tencent/tests/integration_tests/test_tencent.py similarity index 76% rename from api/tests/integration_tests/vdb/tcvectordb/test_tencent.py rename to api/providers/vdb/vdb-tencent/tests/integration_tests/test_tencent.py index 3d6deff2a06..a53ec87f927 100644 --- a/api/tests/integration_tests/vdb/tcvectordb/test_tencent.py +++ b/api/providers/vdb/vdb-tencent/tests/integration_tests/test_tencent.py @@ -1,12 +1,8 @@ from unittest.mock import MagicMock -from core.rag.datasource.vdb.tencent.tencent_vector import TencentConfig, TencentVector -from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text +from dify_vdb_tencent.tencent_vector import TencentConfig, TencentVector -pytest_plugins = ( - "tests.integration_tests.vdb.test_vector_store", - "tests.integration_tests.vdb.__mock.tcvectordb", -) +from core.rag.datasource.vdb.vector_integration_test_support import AbstractVectorTest, get_example_text mock_client = MagicMock() mock_client.list_databases.return_value = [{"name": "test"}] diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/tencent/test_tencent_vector.py b/api/providers/vdb/vdb-tencent/tests/unit_tests/test_tencent_vector.py similarity index 99% rename from api/tests/unit_tests/core/rag/datasource/vdb/tencent/test_tencent_vector.py rename to api/providers/vdb/vdb-tencent/tests/unit_tests/test_tencent_vector.py index d8f35a60197..299e40ee1ef 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/tencent/test_tencent_vector.py +++ b/api/providers/vdb/vdb-tencent/tests/unit_tests/test_tencent_vector.py @@ -140,7 +140,7 @@ def tencent_module(monkeypatch): for name, module in _build_fake_tencent_modules().items(): monkeypatch.setitem(sys.modules, name, module) - import core.rag.datasource.vdb.tencent.tencent_vector as module + import dify_vdb_tencent.tencent_vector as module return importlib.reload(module) diff --git a/api/providers/vdb/vdb-tidb-on-qdrant/pyproject.toml b/api/providers/vdb/vdb-tidb-on-qdrant/pyproject.toml new file mode 100644 index 00000000000..5040fb38ba2 --- /dev/null +++ b/api/providers/vdb/vdb-tidb-on-qdrant/pyproject.toml @@ -0,0 +1,14 @@ +[project] +name = "dify-vdb-tidb-on-qdrant" +version = "0.0.1" + +dependencies = [ + "qdrant-client==1.9.0", +] +description = "Dify vector store backend (dify-vdb-tidb-on-qdrant)." + +[project.entry-points."dify.vector_backends"] +tidb_on_qdrant = "dify_vdb_tidb_on_qdrant.tidb_on_qdrant_vector:TidbOnQdrantVectorFactory" + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/tests/integration_tests/vdb/baidu/__init__.py b/api/providers/vdb/vdb-tidb-on-qdrant/src/dify_vdb_tidb_on_qdrant/__init__.py similarity index 100% rename from api/tests/integration_tests/vdb/baidu/__init__.py rename to api/providers/vdb/vdb-tidb-on-qdrant/src/dify_vdb_tidb_on_qdrant/__init__.py diff --git a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py b/api/providers/vdb/vdb-tidb-on-qdrant/src/dify_vdb_tidb_on_qdrant/tidb_on_qdrant_vector.py similarity index 87% rename from api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py rename to api/providers/vdb/vdb-tidb-on-qdrant/src/dify_vdb_tidb_on_qdrant/tidb_on_qdrant_vector.py index 499a48ac768..abca55f5402 100644 --- a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py +++ b/api/providers/vdb/vdb-tidb-on-qdrant/src/dify_vdb_tidb_on_qdrant/tidb_on_qdrant_vector.py @@ -1,4 +1,5 @@ import json +import logging import os import uuid from collections.abc import Generator, Iterable, Sequence @@ -7,6 +8,8 @@ from typing import TYPE_CHECKING, Any import httpx import qdrant_client + +logger = logging.getLogger(__name__) from flask import current_app from httpx import DigestAuth from pydantic import BaseModel @@ -24,12 +27,12 @@ from sqlalchemy import select from configs import dify_config from core.rag.datasource.vdb.field import Field -from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService -from core.rag.datasource.vdb.vector_base import BaseVector +from core.rag.datasource.vdb.vector_base import BaseVector, VectorIndexStructDict from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document +from dify_vdb_tidb_on_qdrant.tidb_service import TidbService from extensions.ext_database import db from extensions.ext_redis import redis_client from models.dataset import Dataset, TidbAuthBinding @@ -91,8 +94,12 @@ class TidbOnQdrantVector(BaseVector): def get_type(self) -> str: return VectorType.TIDB_ON_QDRANT - def to_index_struct(self): - return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}} + def to_index_struct(self) -> VectorIndexStructDict: + result: VectorIndexStructDict = { + "type": self.get_type(), + "vector_store": {"class_prefix": self._collection_name}, + } + return result def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): if texts: @@ -288,26 +295,27 @@ class TidbOnQdrantVector(BaseVector): if not ids: return - try: - filter = models.Filter( - must=[ - models.FieldCondition( - key="metadata.doc_id", - match=models.MatchAny(any=ids), - ), - ], - ) - self._client.delete( - collection_name=self._collection_name, - points_selector=FilterSelector(filter=filter), - ) - except UnexpectedResponse as e: - # Collection does not exist, so return - if e.status_code == 404: - return - # Some other error occurred, so re-raise the exception - else: - raise e + batch_size = 1000 + for i in range(0, len(ids), batch_size): + batch = ids[i : i + batch_size] + + try: + filter = models.Filter( + must=[ + models.FieldCondition( + key="metadata.doc_id", + match=models.MatchAny(any=batch), + ), + ], + ) + self._client.delete( + collection_name=self._collection_name, + points_selector=FilterSelector(filter=filter), + ) + except UnexpectedResponse as e: + # Collection does not exist, so return + if e.status_code != 404: + raise e def text_exists(self, id: str) -> bool: all_collection_name = [] @@ -416,13 +424,16 @@ class TidbOnQdrantVector(BaseVector): class TidbOnQdrantVectorFactory(AbstractVectorFactory): def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TidbOnQdrantVector: + logger.info("init_vector: tenant_id=%s, dataset_id=%s", dataset.tenant_id, dataset.id) stmt = select(TidbAuthBinding).where(TidbAuthBinding.tenant_id == dataset.tenant_id) tidb_auth_binding = db.session.scalars(stmt).one_or_none() if not tidb_auth_binding: + logger.info("No existing TidbAuthBinding for tenant %s, acquiring lock", dataset.tenant_id) with redis_client.lock("create_tidb_serverless_cluster_lock", timeout=900): stmt = select(TidbAuthBinding).where(TidbAuthBinding.tenant_id == dataset.tenant_id) tidb_auth_binding = db.session.scalars(stmt).one_or_none() if tidb_auth_binding: + logger.info("Found binding after lock: cluster_id=%s", tidb_auth_binding.cluster_id) TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}" else: @@ -432,11 +443,18 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory): .limit(1) ) if idle_tidb_auth_binding: + logger.info( + "Assigning idle cluster %s to tenant %s", + idle_tidb_auth_binding.cluster_id, + dataset.tenant_id, + ) idle_tidb_auth_binding.active = True idle_tidb_auth_binding.tenant_id = dataset.tenant_id db.session.commit() + tidb_auth_binding = idle_tidb_auth_binding TIDB_ON_QDRANT_API_KEY = f"{idle_tidb_auth_binding.account}:{idle_tidb_auth_binding.password}" else: + logger.info("No idle clusters available, creating new cluster for tenant %s", dataset.tenant_id) new_cluster = TidbService.create_tidb_serverless_cluster( dify_config.TIDB_PROJECT_ID or "", dify_config.TIDB_API_URL or "", @@ -445,21 +463,39 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory): dify_config.TIDB_PRIVATE_KEY or "", dify_config.TIDB_REGION or "", ) + logger.info( + "New cluster created: cluster_id=%s, qdrant_endpoint=%s", + new_cluster["cluster_id"], + new_cluster.get("qdrant_endpoint"), + ) new_tidb_auth_binding = TidbAuthBinding( cluster_id=new_cluster["cluster_id"], cluster_name=new_cluster["cluster_name"], account=new_cluster["account"], password=new_cluster["password"], + qdrant_endpoint=new_cluster.get("qdrant_endpoint"), tenant_id=dataset.tenant_id, active=True, status=TidbAuthBindingStatus.ACTIVE, ) db.session.add(new_tidb_auth_binding) db.session.commit() + tidb_auth_binding = new_tidb_auth_binding TIDB_ON_QDRANT_API_KEY = f"{new_tidb_auth_binding.account}:{new_tidb_auth_binding.password}" else: + logger.info("Existing binding found: cluster_id=%s", tidb_auth_binding.cluster_id) TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}" + qdrant_url = ( + (tidb_auth_binding.qdrant_endpoint if tidb_auth_binding else None) or dify_config.TIDB_ON_QDRANT_URL or "" + ) + logger.info( + "Using qdrant endpoint: %s (from_binding=%s, fallback_global=%s)", + qdrant_url, + tidb_auth_binding.qdrant_endpoint if tidb_auth_binding else None, + dify_config.TIDB_ON_QDRANT_URL, + ) + if dataset.index_struct_dict: class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] collection_name = class_prefix @@ -474,7 +510,7 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory): collection_name=collection_name, group_id=dataset.id, config=TidbOnQdrantConfig( - endpoint=dify_config.TIDB_ON_QDRANT_URL or "", + endpoint=qdrant_url, api_key=TIDB_ON_QDRANT_API_KEY, root_path=str(config.root_path), timeout=dify_config.TIDB_ON_QDRANT_CLIENT_TIMEOUT, diff --git a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py b/api/providers/vdb/vdb-tidb-on-qdrant/src/dify_vdb_tidb_on_qdrant/tidb_service.py similarity index 73% rename from api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py rename to api/providers/vdb/vdb-tidb-on-qdrant/src/dify_vdb_tidb_on_qdrant/tidb_service.py index 37114be6e72..ece061db675 100644 --- a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py +++ b/api/providers/vdb/vdb-tidb-on-qdrant/src/dify_vdb_tidb_on_qdrant/tidb_service.py @@ -1,3 +1,4 @@ +import logging import time import uuid from collections.abc import Sequence @@ -12,6 +13,8 @@ from extensions.ext_redis import redis_client from models.dataset import TidbAuthBinding from models.enums import TidbAuthBindingStatus +logger = logging.getLogger(__name__) + # Reuse a pooled HTTP client for all TiDB Cloud requests to minimize connection churn _tidb_http_client: httpx.Client = get_pooled_http_client( "tidb:cloud", @@ -20,6 +23,46 @@ _tidb_http_client: httpx.Client = get_pooled_http_client( class TidbService: + @staticmethod + def extract_qdrant_endpoint(cluster_response: dict) -> str | None: + """Extract the qdrant endpoint URL from a Get Cluster API response. + + Reads ``endpoints.public.host`` (e.g. ``gateway01.xx.tidbcloud.com``), + prepends ``qdrant-`` and wraps it as an ``https://`` URL. + """ + endpoints = cluster_response.get("endpoints") or {} + public = endpoints.get("public") or {} + host = public.get("host") + if host: + return f"https://qdrant-{host}" + return None + + @staticmethod + def fetch_qdrant_endpoint(api_url: str, public_key: str, private_key: str, cluster_id: str) -> str | None: + """Call Get Cluster API and extract the qdrant endpoint. + + Use ``extract_qdrant_endpoint`` instead when you already have + the cluster response to avoid a redundant API call. + """ + try: + logger.info("Fetching qdrant endpoint for cluster %s", cluster_id) + cluster_response = TidbService.get_tidb_serverless_cluster(api_url, public_key, private_key, cluster_id) + if not cluster_response: + logger.warning("Empty response from Get Cluster API for cluster %s", cluster_id) + return None + qdrant_url = TidbService.extract_qdrant_endpoint(cluster_response) + if qdrant_url: + logger.info("Resolved qdrant endpoint for cluster %s: %s", cluster_id, qdrant_url) + return qdrant_url + logger.warning( + "No endpoints.public.host found for cluster %s, response keys: %s", + cluster_id, + list(cluster_response.keys()), + ) + except Exception: + logger.exception("Failed to fetch qdrant endpoint for cluster %s", cluster_id) + return None + @staticmethod def create_tidb_serverless_cluster( project_id: str, api_url: str, iam_url: str, public_key: str, private_key: str, region: str @@ -57,6 +100,7 @@ class TidbService: "rootPassword": password, } + logger.info("Creating TiDB serverless cluster: display_name=%s, region=%s", display_name, region) response = _tidb_http_client.post( f"{api_url}/clusters", json=cluster_data, auth=DigestAuth(public_key, private_key) ) @@ -64,21 +108,39 @@ class TidbService: if response.status_code == 200: response_data = response.json() cluster_id = response_data["clusterId"] + logger.info("Cluster created, cluster_id=%s, waiting for ACTIVE state", cluster_id) retry_count = 0 max_retries = 30 while retry_count < max_retries: cluster_response = TidbService.get_tidb_serverless_cluster(api_url, public_key, private_key, cluster_id) if cluster_response["state"] == "ACTIVE": user_prefix = cluster_response["userPrefix"] + qdrant_endpoint = TidbService.extract_qdrant_endpoint(cluster_response) + logger.info( + "Cluster %s is ACTIVE, user_prefix=%s, qdrant_endpoint=%s", + cluster_id, + user_prefix, + qdrant_endpoint, + ) return { "cluster_id": cluster_id, "cluster_name": display_name, "account": f"{user_prefix}.root", "password": password, + "qdrant_endpoint": qdrant_endpoint, } - time.sleep(30) # wait 30 seconds before retrying + logger.info( + "Cluster %s state=%s, retry %d/%d", + cluster_id, + cluster_response["state"], + retry_count + 1, + max_retries, + ) + time.sleep(30) retry_count += 1 + logger.error("Cluster %s did not become ACTIVE after %d retries", cluster_id, max_retries) else: + logger.error("Failed to create cluster: status=%d, body=%s", response.status_code, response.text) response.raise_for_status() @staticmethod @@ -243,19 +305,29 @@ class TidbService: if response.status_code == 200: response_data = response.json() cluster_infos = [] + logger.info("Batch created %d clusters", len(response_data.get("clusters", []))) for item in response_data["clusters"]: cache_key = f"tidb_serverless_cluster_password:{item['displayName']}" cached_password = redis_client.get(cache_key) if not cached_password: + logger.warning("No cached password for cluster %s, skipping", item["displayName"]) continue + qdrant_endpoint = TidbService.fetch_qdrant_endpoint(api_url, public_key, private_key, item["clusterId"]) + logger.info( + "Batch cluster %s: qdrant_endpoint=%s", + item["clusterId"], + qdrant_endpoint, + ) cluster_info = { "cluster_id": item["clusterId"], "cluster_name": item["displayName"], "account": "root", "password": cached_password.decode("utf-8"), + "qdrant_endpoint": qdrant_endpoint, } cluster_infos.append(cluster_info) return cluster_infos else: + logger.error("Batch create failed: status=%d, body=%s", response.status_code, response.text) response.raise_for_status() return [] diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/tidb_on_qdrant/test_tidb_on_qdrant_vector.py b/api/providers/vdb/vdb-tidb-on-qdrant/tests/unit_tests/test_tidb_on_qdrant_vector.py similarity index 64% rename from api/tests/unit_tests/core/rag/datasource/vdb/tidb_on_qdrant/test_tidb_on_qdrant_vector.py rename to api/providers/vdb/vdb-tidb-on-qdrant/tests/unit_tests/test_tidb_on_qdrant_vector.py index c25af79ae4e..76802de62ef 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/tidb_on_qdrant/test_tidb_on_qdrant_vector.py +++ b/api/providers/vdb/vdb-tidb-on-qdrant/tests/unit_tests/test_tidb_on_qdrant_vector.py @@ -2,13 +2,12 @@ from unittest.mock import patch import httpx import pytest -from qdrant_client.http import models as rest -from qdrant_client.http.exceptions import UnexpectedResponse - -from core.rag.datasource.vdb.tidb_on_qdrant.tidb_on_qdrant_vector import ( +from dify_vdb_tidb_on_qdrant.tidb_on_qdrant_vector import ( TidbOnQdrantConfig, TidbOnQdrantVector, ) +from qdrant_client.http import models as rest +from qdrant_client.http.exceptions import UnexpectedResponse class TestTidbOnQdrantVectorDeleteByIds: @@ -22,7 +21,7 @@ class TestTidbOnQdrantVectorDeleteByIds: api_key="test_api_key", ) - with patch("core.rag.datasource.vdb.tidb_on_qdrant.tidb_on_qdrant_vector.qdrant_client.QdrantClient"): + with patch("dify_vdb_tidb_on_qdrant.tidb_on_qdrant_vector.qdrant_client.QdrantClient"): vector = TidbOnQdrantVector( collection_name="test_collection", group_id="test_group", @@ -115,14 +114,12 @@ class TestTidbOnQdrantVectorDeleteByIds: assert exc_info.value.status_code == 500 - def test_delete_by_ids_with_large_batch(self, vector_instance): - """Test deletion with a large batch of IDs.""" - # Create 1000 IDs + def test_delete_by_ids_with_exactly_1000(self, vector_instance): + """Test deletion with exactly 1000 IDs triggers a single batch.""" ids = [f"doc_{i}" for i in range(1000)] vector_instance.delete_by_ids(ids) - # Verify single delete call with all IDs vector_instance._client.delete.assert_called_once() call_args = vector_instance._client.delete.call_args @@ -130,11 +127,28 @@ class TestTidbOnQdrantVectorDeleteByIds: filter_obj = filter_selector.filter field_condition = filter_obj.must[0] - # Verify all 1000 IDs are in the batch assert len(field_condition.match.any) == 1000 assert "doc_0" in field_condition.match.any assert "doc_999" in field_condition.match.any + def test_delete_by_ids_splits_into_batches(self, vector_instance): + """Test deletion with >1000 IDs triggers multiple batched calls.""" + ids = [f"doc_{i}" for i in range(2500)] + + vector_instance.delete_by_ids(ids) + + assert vector_instance._client.delete.call_count == 3 + + batches = [] + for call in vector_instance._client.delete.call_args_list: + filter_selector = call[1]["points_selector"] + field_condition = filter_selector.filter.must[0] + batches.append(field_condition.match.any) + + assert len(batches[0]) == 1000 + assert len(batches[1]) == 1000 + assert len(batches[2]) == 500 + def test_delete_by_ids_filter_structure(self, vector_instance): """Test that the filter structure is correctly constructed.""" ids = ["doc1", "doc2"] @@ -158,3 +172,57 @@ class TestTidbOnQdrantVectorDeleteByIds: # Verify MatchAny structure assert isinstance(field_condition.match, rest.MatchAny) assert field_condition.match.any == ids + + +class TestInitVectorEndpointSelection: + """Test that init_vector selects the correct qdrant endpoint. + + We avoid importing the full module (which triggers Flask app context) + by testing the endpoint selection logic directly on TidbOnQdrantConfig. + """ + + def test_uses_binding_endpoint_when_present(self): + binding_endpoint = "https://qdrant-custom.tidb.com" + global_url = "https://qdrant-global.tidb.com" + + qdrant_url = binding_endpoint or global_url or "" + + assert qdrant_url == "https://qdrant-custom.tidb.com" + config = TidbOnQdrantConfig(endpoint=qdrant_url) + assert config.endpoint == "https://qdrant-custom.tidb.com" + + def test_falls_back_to_global_when_binding_endpoint_is_none(self): + binding_endpoint = None + global_url = "https://qdrant-global.tidb.com" + + qdrant_url = binding_endpoint or global_url or "" + + assert qdrant_url == "https://qdrant-global.tidb.com" + config = TidbOnQdrantConfig(endpoint=qdrant_url) + assert config.endpoint == "https://qdrant-global.tidb.com" + + def test_falls_back_to_empty_when_both_none(self): + binding_endpoint = None + global_url = None + + qdrant_url = binding_endpoint or global_url or "" + + assert qdrant_url == "" + config = TidbOnQdrantConfig(endpoint=qdrant_url) + assert config.endpoint == "" + + def test_binding_endpoint_takes_precedence_over_global(self): + binding_endpoint = "https://qdrant-ap-southeast.tidb.com" + global_url = "https://qdrant-us-east.tidb.com" + + qdrant_url = binding_endpoint or global_url or "" + + assert qdrant_url == "https://qdrant-ap-southeast.tidb.com" + + def test_empty_string_binding_endpoint_falls_back_to_global(self): + binding_endpoint = "" + global_url = "https://qdrant-global.tidb.com" + + qdrant_url = binding_endpoint or global_url or "" + + assert qdrant_url == "https://qdrant-global.tidb.com" diff --git a/api/providers/vdb/vdb-tidb-on-qdrant/tests/unit_tests/test_tidb_service.py b/api/providers/vdb/vdb-tidb-on-qdrant/tests/unit_tests/test_tidb_service.py new file mode 100644 index 00000000000..c1ffbacbbce --- /dev/null +++ b/api/providers/vdb/vdb-tidb-on-qdrant/tests/unit_tests/test_tidb_service.py @@ -0,0 +1,218 @@ +from unittest.mock import MagicMock, patch + +import pytest +from dify_vdb_tidb_on_qdrant.tidb_service import TidbService + + +class TestExtractQdrantEndpoint: + """Unit tests for TidbService.extract_qdrant_endpoint.""" + + def test_returns_endpoint_when_host_present(self): + response = {"endpoints": {"public": {"host": "gateway01.us-east-1.tidbcloud.com", "port": 4000}}} + result = TidbService.extract_qdrant_endpoint(response) + assert result == "https://qdrant-gateway01.us-east-1.tidbcloud.com" + + def test_returns_none_when_host_missing(self): + response = {"endpoints": {"public": {}}} + assert TidbService.extract_qdrant_endpoint(response) is None + + def test_returns_none_when_public_missing(self): + response = {"endpoints": {}} + assert TidbService.extract_qdrant_endpoint(response) is None + + def test_returns_none_when_endpoints_missing(self): + assert TidbService.extract_qdrant_endpoint({}) is None + + +class TestFetchQdrantEndpoint: + """Unit tests for TidbService.fetch_qdrant_endpoint.""" + + @patch.object(TidbService, "get_tidb_serverless_cluster") + def test_returns_endpoint_when_host_present(self, mock_get_cluster): + mock_get_cluster.return_value = { + "endpoints": {"public": {"host": "gateway01.us-east-1.tidbcloud.com", "port": 4000}} + } + result = TidbService.fetch_qdrant_endpoint("url", "pub", "priv", "c-123") + assert result == "https://qdrant-gateway01.us-east-1.tidbcloud.com" + + @patch.object(TidbService, "get_tidb_serverless_cluster") + def test_returns_none_when_cluster_response_is_none(self, mock_get_cluster): + mock_get_cluster.return_value = None + assert TidbService.fetch_qdrant_endpoint("url", "pub", "priv", "c-123") is None + + @patch.object(TidbService, "get_tidb_serverless_cluster") + def test_returns_none_when_host_missing(self, mock_get_cluster): + mock_get_cluster.return_value = {"endpoints": {"public": {}}} + assert TidbService.fetch_qdrant_endpoint("url", "pub", "priv", "c-123") is None + + @patch.object(TidbService, "get_tidb_serverless_cluster") + def test_returns_none_when_endpoints_missing(self, mock_get_cluster): + mock_get_cluster.return_value = {} + assert TidbService.fetch_qdrant_endpoint("url", "pub", "priv", "c-123") is None + + @patch.object(TidbService, "get_tidb_serverless_cluster") + def test_returns_none_on_exception(self, mock_get_cluster): + mock_get_cluster.side_effect = RuntimeError("network error") + assert TidbService.fetch_qdrant_endpoint("url", "pub", "priv", "c-123") is None + + +class TestCreateTidbServerlessClusterQdrantEndpoint: + """Verify that create_tidb_serverless_cluster includes qdrant_endpoint in its result.""" + + @patch.object(TidbService, "get_tidb_serverless_cluster") + @patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client") + @patch("dify_vdb_tidb_on_qdrant.tidb_service.dify_config") + def test_result_contains_qdrant_endpoint(self, mock_config, mock_http, mock_get_cluster): + mock_config.TIDB_SPEND_LIMIT = 10 + mock_http.post.return_value = MagicMock(status_code=200, json=lambda: {"clusterId": "c-1"}) + mock_get_cluster.return_value = { + "state": "ACTIVE", + "userPrefix": "pfx", + "endpoints": {"public": {"host": "gw.tidbcloud.com", "port": 4000}}, + } + + result = TidbService.create_tidb_serverless_cluster("proj", "url", "iam", "pub", "priv", "us-east-1") + + assert result is not None + assert result["qdrant_endpoint"] == "https://qdrant-gw.tidbcloud.com" + + @patch.object(TidbService, "get_tidb_serverless_cluster") + @patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client") + @patch("dify_vdb_tidb_on_qdrant.tidb_service.dify_config") + def test_result_qdrant_endpoint_none_when_no_endpoints(self, mock_config, mock_http, mock_get_cluster): + mock_config.TIDB_SPEND_LIMIT = 10 + mock_http.post.return_value = MagicMock(status_code=200, json=lambda: {"clusterId": "c-1"}) + mock_get_cluster.return_value = {"state": "ACTIVE", "userPrefix": "pfx"} + + result = TidbService.create_tidb_serverless_cluster("proj", "url", "iam", "pub", "priv", "us-east-1") + + assert result is not None + assert result["qdrant_endpoint"] is None + + +class TestBatchCreateTidbServerlessClusterQdrantEndpoint: + """Verify that batch_create includes qdrant_endpoint per cluster.""" + + @patch.object(TidbService, "fetch_qdrant_endpoint", return_value="https://qdrant-gw.tidbcloud.com") + @patch("dify_vdb_tidb_on_qdrant.tidb_service.redis_client") + @patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client") + @patch("dify_vdb_tidb_on_qdrant.tidb_service.dify_config") + def test_batch_result_contains_qdrant_endpoint(self, mock_config, mock_http, mock_redis, mock_fetch_ep): + mock_config.TIDB_SPEND_LIMIT = 10 + cluster_name = "abc123" + mock_http.post.return_value = MagicMock( + status_code=200, + json=lambda: {"clusters": [{"clusterId": "c-1", "displayName": cluster_name}]}, + ) + mock_redis.setex = MagicMock() + mock_redis.get.return_value = b"password123" + + result = TidbService.batch_create_tidb_serverless_cluster( + batch_size=1, + project_id="proj", + api_url="url", + iam_url="iam", + public_key="pub", + private_key="priv", + region="us-east-1", + ) + + assert len(result) == 1 + assert result[0]["qdrant_endpoint"] == "https://qdrant-gw.tidbcloud.com" + + +class TestCreateTidbServerlessClusterRetry: + """Cover retry/logging paths in create_tidb_serverless_cluster.""" + + @patch.object(TidbService, "get_tidb_serverless_cluster") + @patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client") + @patch("dify_vdb_tidb_on_qdrant.tidb_service.dify_config") + def test_polls_until_active(self, mock_config, mock_http, mock_get_cluster): + mock_config.TIDB_SPEND_LIMIT = 10 + mock_http.post.return_value = MagicMock(status_code=200, json=lambda: {"clusterId": "c-1"}) + mock_get_cluster.side_effect = [ + {"state": "CREATING", "userPrefix": ""}, + {"state": "ACTIVE", "userPrefix": "pfx", "endpoints": {"public": {"host": "gw.tidb.com"}}}, + ] + + with patch("dify_vdb_tidb_on_qdrant.tidb_service.time.sleep"): + result = TidbService.create_tidb_serverless_cluster("proj", "url", "iam", "pub", "priv", "us-east-1") + + assert result is not None + assert result["qdrant_endpoint"] == "https://qdrant-gw.tidb.com" + assert mock_get_cluster.call_count == 2 + + @patch.object(TidbService, "get_tidb_serverless_cluster") + @patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client") + @patch("dify_vdb_tidb_on_qdrant.tidb_service.dify_config") + def test_returns_none_after_max_retries(self, mock_config, mock_http, mock_get_cluster): + mock_config.TIDB_SPEND_LIMIT = 10 + mock_http.post.return_value = MagicMock(status_code=200, json=lambda: {"clusterId": "c-1"}) + mock_get_cluster.return_value = {"state": "CREATING", "userPrefix": ""} + + with patch("dify_vdb_tidb_on_qdrant.tidb_service.time.sleep"): + result = TidbService.create_tidb_serverless_cluster("proj", "url", "iam", "pub", "priv", "us-east-1") + + assert result is None + + @patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client") + @patch("dify_vdb_tidb_on_qdrant.tidb_service.dify_config") + def test_raises_on_post_failure(self, mock_config, mock_http): + mock_config.TIDB_SPEND_LIMIT = 10 + mock_response = MagicMock(status_code=400, text="Bad Request") + mock_response.raise_for_status.side_effect = Exception("HTTP 400") + mock_http.post.return_value = mock_response + + with pytest.raises(Exception, match="HTTP 400"): + TidbService.create_tidb_serverless_cluster("proj", "url", "iam", "pub", "priv", "us-east-1") + + +class TestBatchCreateEdgeCases: + """Cover logging/edge-case branches in batch_create.""" + + @patch.object(TidbService, "fetch_qdrant_endpoint", return_value=None) + @patch("dify_vdb_tidb_on_qdrant.tidb_service.redis_client") + @patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client") + @patch("dify_vdb_tidb_on_qdrant.tidb_service.dify_config") + def test_skips_cluster_when_no_cached_password(self, mock_config, mock_http, mock_redis, mock_fetch_ep): + mock_config.TIDB_SPEND_LIMIT = 10 + mock_http.post.return_value = MagicMock( + status_code=200, + json=lambda: {"clusters": [{"clusterId": "c-1", "displayName": "name1"}]}, + ) + mock_redis.setex = MagicMock() + mock_redis.get.return_value = None + + result = TidbService.batch_create_tidb_serverless_cluster( + batch_size=1, + project_id="proj", + api_url="url", + iam_url="iam", + public_key="pub", + private_key="priv", + region="us-east-1", + ) + + assert len(result) == 0 + mock_fetch_ep.assert_not_called() + + @patch("dify_vdb_tidb_on_qdrant.tidb_service.redis_client") + @patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client") + @patch("dify_vdb_tidb_on_qdrant.tidb_service.dify_config") + def test_raises_on_post_failure(self, mock_config, mock_http, mock_redis): + mock_config.TIDB_SPEND_LIMIT = 10 + mock_response = MagicMock(status_code=500, text="Server Error") + mock_response.raise_for_status.side_effect = Exception("HTTP 500") + mock_http.post.return_value = mock_response + mock_redis.setex = MagicMock() + + with pytest.raises(Exception, match="HTTP 500"): + TidbService.batch_create_tidb_serverless_cluster( + batch_size=1, + project_id="proj", + api_url="url", + iam_url="iam", + public_key="pub", + private_key="priv", + region="us-east-1", + ) diff --git a/api/providers/vdb/vdb-tidb-vector/pyproject.toml b/api/providers/vdb/vdb-tidb-vector/pyproject.toml new file mode 100644 index 00000000000..0e2f0ad88f6 --- /dev/null +++ b/api/providers/vdb/vdb-tidb-vector/pyproject.toml @@ -0,0 +1,14 @@ +[project] +name = "dify-vdb-tidb-vector" +version = "0.0.1" + +dependencies = [ + "tidb-vector==0.0.15", +] +description = "Dify vector store backend (dify-vdb-tidb-vector)." + +[project.entry-points."dify.vector_backends"] +tidb_vector = "dify_vdb_tidb_vector.tidb_vector:TiDBVectorFactory" + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/tests/integration_tests/vdb/chroma/__init__.py b/api/providers/vdb/vdb-tidb-vector/src/dify_vdb_tidb_vector/__init__.py similarity index 100% rename from api/tests/integration_tests/vdb/chroma/__init__.py rename to api/providers/vdb/vdb-tidb-vector/src/dify_vdb_tidb_vector/__init__.py diff --git a/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py b/api/providers/vdb/vdb-tidb-vector/src/dify_vdb_tidb_vector/tidb_vector.py similarity index 97% rename from api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py rename to api/providers/vdb/vdb-tidb-vector/src/dify_vdb_tidb_vector/tidb_vector.py index c9489173741..c696a685ddd 100644 --- a/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py +++ b/api/providers/vdb/vdb-tidb-vector/src/dify_vdb_tidb_vector/tidb_vector.py @@ -6,7 +6,7 @@ import sqlalchemy from pydantic import BaseModel, model_validator from sqlalchemy import JSON, TEXT, Column, DateTime, String, Table, create_engine, insert from sqlalchemy import text as sql_text -from sqlalchemy.orm import Session, declarative_base +from sqlalchemy.orm import Session, declarative_base, sessionmaker from configs import dify_config from core.rag.datasource.vdb.field import Field, parse_metadata_json @@ -31,7 +31,7 @@ class TiDBVectorConfig(BaseModel): @model_validator(mode="before") @classmethod - def validate_config(cls, values: dict): + def validate_config(cls, values: dict[str, Any]): if not values["host"]: raise ValueError("config TIDB_VECTOR_HOST is required") if not values["port"]: @@ -97,8 +97,7 @@ class TiDBVector(BaseVector): if redis_client.get(collection_exist_cache_key): return tidb_dist_func = self._get_distance_func() - with Session(self._engine) as session: - session.begin() + with sessionmaker(bind=self._engine).begin() as session: create_statement = sql_text(f""" CREATE TABLE IF NOT EXISTS {self._collection_name} ( id CHAR(36) PRIMARY KEY, @@ -115,7 +114,6 @@ class TiDBVector(BaseVector): ); """) session.execute(create_statement) - session.commit() redis_client.set(collection_exist_cache_key, 1, ex=3600) def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): @@ -238,9 +236,8 @@ class TiDBVector(BaseVector): return [] def delete(self): - with Session(self._engine) as session: + with sessionmaker(bind=self._engine).begin() as session: session.execute(sql_text(f"""DROP TABLE IF EXISTS {self._collection_name};""")) - session.commit() def _get_distance_func(self) -> str: match self._distance_func: diff --git a/api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py b/api/providers/vdb/vdb-tidb-vector/tests/integration_tests/check_tiflash_ready.py similarity index 72% rename from api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py rename to api/providers/vdb/vdb-tidb-vector/tests/integration_tests/check_tiflash_ready.py index f76700aa0e8..97f8406e424 100644 --- a/api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py +++ b/api/providers/vdb/vdb-tidb-vector/tests/integration_tests/check_tiflash_ready.py @@ -1,9 +1,13 @@ +import logging import time import pymysql +logger = logging.getLogger(__name__) + def check_tiflash_ready() -> bool: + connection = None try: connection = pymysql.connect( host="localhost", @@ -23,8 +27,8 @@ def check_tiflash_ready() -> bool: cursor.execute(select_tiflash_query) result = cursor.fetchall() return result is not None and len(result) > 0 - except Exception as e: - print(f"TiFlash is not ready. Exception: {e}") + except Exception: + logger.exception("TiFlash is not ready.") return False finally: if connection: @@ -38,20 +42,20 @@ def main(): for attempt in range(max_attempts): try: is_tiflash_ready = check_tiflash_ready() - except Exception as e: - print(f"TiFlash is not ready. Exception: {e}") + except Exception: + logger.exception("TiFlash is not ready.") is_tiflash_ready = False if is_tiflash_ready: break else: - print(f"Attempt {attempt + 1} failed, retry in {retry_interval_seconds} seconds...") + logger.error("Attempt %s failed, retry in %s seconds...", attempt + 1, retry_interval_seconds) time.sleep(retry_interval_seconds) if is_tiflash_ready: - print("TiFlash is ready in TiDB.") + logger.info("TiFlash is ready in TiDB.") else: - print(f"TiFlash is not ready in TiDB after {max_attempts} attempting checks.") + logger.error("TiFlash is not ready in TiDB after %s attempting checks.", max_attempts) exit(1) diff --git a/api/tests/integration_tests/vdb/tidb_vector/test_tidb_vector.py b/api/providers/vdb/vdb-tidb-vector/tests/integration_tests/test_tidb_vector.py similarity index 77% rename from api/tests/integration_tests/vdb/tidb_vector/test_tidb_vector.py rename to api/providers/vdb/vdb-tidb-vector/tests/integration_tests/test_tidb_vector.py index 14c6d1c67c5..ac854acbf90 100644 --- a/api/tests/integration_tests/vdb/tidb_vector/test_tidb_vector.py +++ b/api/providers/vdb/vdb-tidb-vector/tests/integration_tests/test_tidb_vector.py @@ -1,10 +1,8 @@ import pytest +from dify_vdb_tidb_vector.tidb_vector import TiDBVector, TiDBVectorConfig -from core.rag.datasource.vdb.tidb_vector.tidb_vector import TiDBVector, TiDBVectorConfig -from models.dataset import Document -from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text - -pytest_plugins = ("tests.integration_tests.vdb.test_vector_store",) +from core.rag.datasource.vdb.vector_integration_test_support import AbstractVectorTest, get_example_text +from core.rag.models.document import Document @pytest.fixture diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/tidb_vector/test_tidb_vector.py b/api/providers/vdb/vdb-tidb-vector/tests/unit_tests/test_tidb_vector.py similarity index 97% rename from api/tests/unit_tests/core/rag/datasource/vdb/tidb_vector/test_tidb_vector.py rename to api/providers/vdb/vdb-tidb-vector/tests/unit_tests/test_tidb_vector.py index 951a920f3bc..bdbed2f7403 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/tidb_vector/test_tidb_vector.py +++ b/api/providers/vdb/vdb-tidb-vector/tests/unit_tests/test_tidb_vector.py @@ -12,7 +12,7 @@ from core.rag.models.document import Document @pytest.fixture def tidb_module(): - import core.rag.datasource.vdb.tidb_vector.tidb_vector as module + import dify_vdb_tidb_vector.tidb_vector as module return importlib.reload(module) @@ -137,14 +137,15 @@ def test_create_collection_executes_create_sql_and_sets_cache(tidb_module, monke session = MagicMock() - class _SessionCtx: + class _BeginCtx: def __enter__(self): return session def __exit__(self, exc_type, exc, tb): return False - monkeypatch.setattr(tidb_module, "Session", lambda _engine: _SessionCtx()) + mock_sm = MagicMock(begin=MagicMock(return_value=_BeginCtx())) + monkeypatch.setattr(tidb_module, "sessionmaker", lambda **kwargs: mock_sm) vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector) vector._collection_name = "collection_1" @@ -153,11 +154,9 @@ def test_create_collection_executes_create_sql_and_sets_cache(tidb_module, monke vector._create_collection(3) - session.begin.assert_called_once() sql = str(session.execute.call_args.args[0]) assert "VECTOR(3)" in sql assert "VEC_L2_DISTANCE" in sql - session.commit.assert_called_once() tidb_module.redis_client.set.assert_called_once() @@ -396,23 +395,22 @@ def test_search_by_vector_filters_and_scores(tidb_module, monkeypatch): def test_delete_drops_table(tidb_module, monkeypatch): session = MagicMock() session.execute.return_value = None - session.commit = MagicMock() - class _SessionCtx: + class _BeginCtx: def __enter__(self): return session def __exit__(self, exc_type, exc, tb): return False - monkeypatch.setattr(tidb_module, "Session", lambda _engine: _SessionCtx()) + mock_sm = MagicMock(begin=MagicMock(return_value=_BeginCtx())) + monkeypatch.setattr(tidb_module, "sessionmaker", lambda **kwargs: mock_sm) vector = tidb_module.TiDBVector.__new__(tidb_module.TiDBVector) vector._collection_name = "collection_1" vector._engine = MagicMock() vector.delete() drop_sql = str(session.execute.call_args.args[0]) assert "DROP TABLE IF EXISTS collection_1" in drop_sql - session.commit.assert_called_once() def test_tidb_factory_uses_existing_or_generated_collection(tidb_module, monkeypatch): diff --git a/api/providers/vdb/vdb-upstash/pyproject.toml b/api/providers/vdb/vdb-upstash/pyproject.toml new file mode 100644 index 00000000000..f71773cdbb1 --- /dev/null +++ b/api/providers/vdb/vdb-upstash/pyproject.toml @@ -0,0 +1,14 @@ +[project] +name = "dify-vdb-upstash" +version = "0.0.1" + +dependencies = [ + "upstash-vector==0.8.0", +] +description = "Dify vector store backend (dify-vdb-upstash)." + +[project.entry-points."dify.vector_backends"] +upstash = "dify_vdb_upstash.upstash_vector:UpstashVectorFactory" + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/tests/integration_tests/vdb/couchbase/__init__.py b/api/providers/vdb/vdb-upstash/src/dify_vdb_upstash/__init__.py similarity index 100% rename from api/tests/integration_tests/vdb/couchbase/__init__.py rename to api/providers/vdb/vdb-upstash/src/dify_vdb_upstash/__init__.py diff --git a/api/core/rag/datasource/vdb/upstash/upstash_vector.py b/api/providers/vdb/vdb-upstash/src/dify_vdb_upstash/upstash_vector.py similarity index 98% rename from api/core/rag/datasource/vdb/upstash/upstash_vector.py rename to api/providers/vdb/vdb-upstash/src/dify_vdb_upstash/upstash_vector.py index 289d9718533..75d70a19647 100644 --- a/api/core/rag/datasource/vdb/upstash/upstash_vector.py +++ b/api/providers/vdb/vdb-upstash/src/dify_vdb_upstash/upstash_vector.py @@ -20,7 +20,7 @@ class UpstashVectorConfig(BaseModel): @model_validator(mode="before") @classmethod - def validate_config(cls, values: dict): + def validate_config(cls, values: dict[str, Any]): if not values["url"]: raise ValueError("Upstash URL is required") if not values["token"]: diff --git a/api/tests/integration_tests/vdb/__mock/upstashvectordb.py b/api/providers/vdb/vdb-upstash/tests/integration_tests/conftest.py similarity index 94% rename from api/tests/integration_tests/vdb/__mock/upstashvectordb.py rename to api/providers/vdb/vdb-upstash/tests/integration_tests/conftest.py index 70c85d4c989..adba0c150ce 100644 --- a/api/tests/integration_tests/vdb/__mock/upstashvectordb.py +++ b/api/providers/vdb/vdb-upstash/tests/integration_tests/conftest.py @@ -6,7 +6,6 @@ from _pytest.monkeypatch import MonkeyPatch from upstash_vector import Index -# Mocking the Index class from upstash_vector class MockIndex: def __init__(self, url="", token=""): self.url = url @@ -37,7 +36,6 @@ class MockIndex: namespace: str = "", include_data: bool = False, ): - # Simple mock query, in real scenario you would calculate similarity mock_result = [] for vector_data in self.vectors: mock_result.append(vector_data) diff --git a/api/tests/integration_tests/vdb/upstash/test_upstash_vector.py b/api/providers/vdb/vdb-upstash/tests/integration_tests/test_upstash_vector.py similarity index 75% rename from api/tests/integration_tests/vdb/upstash/test_upstash_vector.py rename to api/providers/vdb/vdb-upstash/tests/integration_tests/test_upstash_vector.py index 8cea0a05eb6..f4a65030b65 100644 --- a/api/tests/integration_tests/vdb/upstash/test_upstash_vector.py +++ b/api/providers/vdb/vdb-upstash/tests/integration_tests/test_upstash_vector.py @@ -1,8 +1,7 @@ -from core.rag.datasource.vdb.upstash.upstash_vector import UpstashVector, UpstashVectorConfig -from core.rag.models.document import Document -from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text +from dify_vdb_upstash.upstash_vector import UpstashVector, UpstashVectorConfig -pytest_plugins = ("tests.integration_tests.vdb.__mock.upstashvectordb",) +from core.rag.datasource.vdb.vector_integration_test_support import AbstractVectorTest, get_example_text +from core.rag.models.document import Document class UpstashVectorTest(AbstractVectorTest): diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/upstash/test_upstash_vector.py b/api/providers/vdb/vdb-upstash/tests/unit_tests/test_upstash_vector.py similarity index 97% rename from api/tests/unit_tests/core/rag/datasource/vdb/upstash/test_upstash_vector.py rename to api/providers/vdb/vdb-upstash/tests/unit_tests/test_upstash_vector.py index ac8a63a44ba..a884275c89c 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/upstash/test_upstash_vector.py +++ b/api/providers/vdb/vdb-upstash/tests/unit_tests/test_upstash_vector.py @@ -38,11 +38,11 @@ def _build_fake_upstash_module(): @pytest.fixture def upstash_module(monkeypatch): # Remove patched modules if present - for modname in ["upstash_vector", "core.rag.datasource.vdb.upstash.upstash_vector"]: + for modname in ["upstash_vector", "dify_vdb_upstash.upstash_vector"]: if modname in sys.modules: monkeypatch.delitem(sys.modules, modname, raising=False) monkeypatch.setitem(sys.modules, "upstash_vector", _build_fake_upstash_module()) - module = importlib.import_module("core.rag.datasource.vdb.upstash.upstash_vector") + module = importlib.import_module("dify_vdb_upstash.upstash_vector") return module diff --git a/api/providers/vdb/vdb-vastbase/pyproject.toml b/api/providers/vdb/vdb-vastbase/pyproject.toml new file mode 100644 index 00000000000..287eb147dc1 --- /dev/null +++ b/api/providers/vdb/vdb-vastbase/pyproject.toml @@ -0,0 +1,14 @@ +[project] +name = "dify-vdb-vastbase" +version = "0.0.1" + +dependencies = [ + "pyobvector~=0.2.17", +] +description = "Dify vector store backend (dify-vdb-vastbase)." + +[project.entry-points."dify.vector_backends"] +vastbase = "dify_vdb_vastbase.vastbase_vector:VastbaseVectorFactory" + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/tests/integration_tests/vdb/elasticsearch/__init__.py b/api/providers/vdb/vdb-vastbase/src/dify_vdb_vastbase/__init__.py similarity index 100% rename from api/tests/integration_tests/vdb/elasticsearch/__init__.py rename to api/providers/vdb/vdb-vastbase/src/dify_vdb_vastbase/__init__.py diff --git a/api/core/rag/datasource/vdb/pyvastbase/vastbase_vector.py b/api/providers/vdb/vdb-vastbase/src/dify_vdb_vastbase/vastbase_vector.py similarity index 99% rename from api/core/rag/datasource/vdb/pyvastbase/vastbase_vector.py rename to api/providers/vdb/vdb-vastbase/src/dify_vdb_vastbase/vastbase_vector.py index d080e8da580..ab00f9db28e 100644 --- a/api/core/rag/datasource/vdb/pyvastbase/vastbase_vector.py +++ b/api/providers/vdb/vdb-vastbase/src/dify_vdb_vastbase/vastbase_vector.py @@ -28,7 +28,7 @@ class VastbaseVectorConfig(BaseModel): @model_validator(mode="before") @classmethod - def validate_config(cls, values: dict): + def validate_config(cls, values: dict[str, Any]): if not values["host"]: raise ValueError("config VASTBASE_HOST is required") if not values["port"]: diff --git a/api/tests/integration_tests/vdb/pyvastbase/test_vastbase_vector.py b/api/providers/vdb/vdb-vastbase/tests/integration_tests/test_vastbase_vector.py similarity index 72% rename from api/tests/integration_tests/vdb/pyvastbase/test_vastbase_vector.py rename to api/providers/vdb/vdb-vastbase/tests/integration_tests/test_vastbase_vector.py index a47f13625cd..0467dec37a2 100644 --- a/api/tests/integration_tests/vdb/pyvastbase/test_vastbase_vector.py +++ b/api/providers/vdb/vdb-vastbase/tests/integration_tests/test_vastbase_vector.py @@ -1,10 +1,9 @@ -from core.rag.datasource.vdb.pyvastbase.vastbase_vector import VastbaseVector, VastbaseVectorConfig -from tests.integration_tests.vdb.test_vector_store import ( +from dify_vdb_vastbase.vastbase_vector import VastbaseVector, VastbaseVectorConfig + +from core.rag.datasource.vdb.vector_integration_test_support import ( AbstractVectorTest, ) -pytest_plugins = ("tests.integration_tests.vdb.test_vector_store",) - class VastbaseVectorTest(AbstractVectorTest): def __init__(self): diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/pyvastbase/test_vastbase_vector.py b/api/providers/vdb/vdb-vastbase/tests/unit_tests/test_vastbase_vector.py similarity index 99% rename from api/tests/unit_tests/core/rag/datasource/vdb/pyvastbase/test_vastbase_vector.py rename to api/providers/vdb/vdb-vastbase/tests/unit_tests/test_vastbase_vector.py index bd8df520ba2..4dfb956c002 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/pyvastbase/test_vastbase_vector.py +++ b/api/providers/vdb/vdb-vastbase/tests/unit_tests/test_vastbase_vector.py @@ -41,7 +41,7 @@ def vastbase_module(monkeypatch): for name, module in _build_fake_psycopg2_modules().items(): monkeypatch.setitem(sys.modules, name, module) - import core.rag.datasource.vdb.pyvastbase.vastbase_vector as module + import dify_vdb_vastbase.vastbase_vector as module return importlib.reload(module) diff --git a/api/providers/vdb/vdb-vikingdb/pyproject.toml b/api/providers/vdb/vdb-vikingdb/pyproject.toml new file mode 100644 index 00000000000..fdf59f76a41 --- /dev/null +++ b/api/providers/vdb/vdb-vikingdb/pyproject.toml @@ -0,0 +1,14 @@ +[project] +name = "dify-vdb-vikingdb" +version = "0.0.1" + +dependencies = [ + "volcengine-compat~=1.0.0", +] +description = "Dify vector store backend (dify-vdb-vikingdb)." + +[project.entry-points."dify.vector_backends"] +vikingdb = "dify_vdb_vikingdb.vikingdb_vector:VikingDBVectorFactory" + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/tests/integration_tests/vdb/hologres/__init__.py b/api/providers/vdb/vdb-vikingdb/src/dify_vdb_vikingdb/__init__.py similarity index 100% rename from api/tests/integration_tests/vdb/hologres/__init__.py rename to api/providers/vdb/vdb-vikingdb/src/dify_vdb_vikingdb/__init__.py diff --git a/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py b/api/providers/vdb/vdb-vikingdb/src/dify_vdb_vikingdb/vikingdb_vector.py similarity index 100% rename from api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py rename to api/providers/vdb/vdb-vikingdb/src/dify_vdb_vikingdb/vikingdb_vector.py diff --git a/api/tests/integration_tests/vdb/__mock/vikingdb.py b/api/providers/vdb/vdb-vikingdb/tests/integration_tests/conftest.py similarity index 100% rename from api/tests/integration_tests/vdb/__mock/vikingdb.py rename to api/providers/vdb/vdb-vikingdb/tests/integration_tests/conftest.py diff --git a/api/tests/integration_tests/vdb/vikingdb/test_vikingdb.py b/api/providers/vdb/vdb-vikingdb/tests/integration_tests/test_vikingdb.py similarity index 78% rename from api/tests/integration_tests/vdb/vikingdb/test_vikingdb.py rename to api/providers/vdb/vdb-vikingdb/tests/integration_tests/test_vikingdb.py index 56311acd25e..5a3908d14be 100644 --- a/api/tests/integration_tests/vdb/vikingdb/test_vikingdb.py +++ b/api/providers/vdb/vdb-vikingdb/tests/integration_tests/test_vikingdb.py @@ -1,10 +1,6 @@ -from core.rag.datasource.vdb.vikingdb.vikingdb_vector import VikingDBConfig, VikingDBVector -from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text +from dify_vdb_vikingdb.vikingdb_vector import VikingDBConfig, VikingDBVector -pytest_plugins = ( - "tests.integration_tests.vdb.test_vector_store", - "tests.integration_tests.vdb.__mock.vikingdb", -) +from core.rag.datasource.vdb.vector_integration_test_support import AbstractVectorTest, get_example_text class VikingDBVectorTest(AbstractVectorTest): diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/vikingdb/test_vikingdb_vector.py b/api/providers/vdb/vdb-vikingdb/tests/unit_tests/test_vikingdb_vector.py similarity index 99% rename from api/tests/unit_tests/core/rag/datasource/vdb/vikingdb/test_vikingdb_vector.py rename to api/providers/vdb/vdb-vikingdb/tests/unit_tests/test_vikingdb_vector.py index 9da92af2d02..544b8163bef 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/vikingdb/test_vikingdb_vector.py +++ b/api/providers/vdb/vdb-vikingdb/tests/unit_tests/test_vikingdb_vector.py @@ -83,7 +83,7 @@ def vikingdb_module(monkeypatch): for name, module in _build_fake_vikingdb_modules().items(): monkeypatch.setitem(sys.modules, name, module) - import core.rag.datasource.vdb.vikingdb.vikingdb_vector as module + import dify_vdb_vikingdb.vikingdb_vector as module return importlib.reload(module) diff --git a/api/providers/vdb/vdb-weaviate/pyproject.toml b/api/providers/vdb/vdb-weaviate/pyproject.toml new file mode 100644 index 00000000000..035fbd396d1 --- /dev/null +++ b/api/providers/vdb/vdb-weaviate/pyproject.toml @@ -0,0 +1,14 @@ +[project] +name = "dify-vdb-weaviate" +version = "0.0.1" + +dependencies = [ + "weaviate-client==4.20.5", +] +description = "Dify vector store backend (dify-vdb-weaviate)." + +[project.entry-points."dify.vector_backends"] +weaviate = "dify_vdb_weaviate.weaviate_vector:WeaviateVectorFactory" + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/tests/integration_tests/vdb/huawei/__init__.py b/api/providers/vdb/vdb-weaviate/src/dify_vdb_weaviate/__init__.py similarity index 100% rename from api/tests/integration_tests/vdb/huawei/__init__.py rename to api/providers/vdb/vdb-weaviate/src/dify_vdb_weaviate/__init__.py diff --git a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py b/api/providers/vdb/vdb-weaviate/src/dify_vdb_weaviate/weaviate_vector.py similarity index 88% rename from api/core/rag/datasource/vdb/weaviate/weaviate_vector.py rename to api/providers/vdb/vdb-weaviate/src/dify_vdb_weaviate/weaviate_vector.py index d29d62c93f6..902e6a03a8f 100644 --- a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py +++ b/api/providers/vdb/vdb-weaviate/src/dify_vdb_weaviate/weaviate_vector.py @@ -20,11 +20,11 @@ from pydantic import BaseModel, model_validator from weaviate.classes.data import DataObject from weaviate.classes.init import Auth from weaviate.classes.query import Filter, MetadataQuery -from weaviate.exceptions import UnexpectedStatusCodeError +from weaviate.exceptions import UnexpectedStatusCodeError, WeaviateQueryError from configs import dify_config from core.rag.datasource.vdb.field import Field -from core.rag.datasource.vdb.vector_base import BaseVector +from core.rag.datasource.vdb.vector_base import BaseVector, VectorIndexStructDict from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType from core.rag.embedding.embedding_base import Embeddings @@ -82,7 +82,7 @@ class WeaviateConfig(BaseModel): @model_validator(mode="before") @classmethod - def validate_config(cls, values: dict) -> dict: + def validate_config(cls, values: dict[str, Any]) -> dict[str, Any]: """Validates that required configuration values are present.""" if not values["endpoint"]: raise ValueError("config WEAVIATE_ENDPOINT is required") @@ -184,9 +184,13 @@ class WeaviateVector(BaseVector): dataset_id = dataset.id return Dataset.gen_collection_name_by_id(dataset_id) - def to_index_struct(self) -> dict: + def to_index_struct(self) -> VectorIndexStructDict: """Returns the index structure dictionary for persistence.""" - return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}} + result: VectorIndexStructDict = { + "type": self.get_type(), + "vector_store": {"class_prefix": self._collection_name}, + } + return result def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): """ @@ -226,6 +230,8 @@ class WeaviateVector(BaseVector): wc.Property(name="doc_id", data_type=wc.DataType.TEXT), wc.Property(name="doc_type", data_type=wc.DataType.TEXT), wc.Property(name="chunk_index", data_type=wc.DataType.INT), + wc.Property(name="is_summary", data_type=wc.DataType.BOOL), + wc.Property(name="original_chunk_id", data_type=wc.DataType.TEXT), ], vector_config=wc.Configure.Vectors.self_provided(), ) @@ -258,6 +264,10 @@ class WeaviateVector(BaseVector): to_add.append(wc.Property(name="doc_type", data_type=wc.DataType.TEXT)) if "chunk_index" not in existing: to_add.append(wc.Property(name="chunk_index", data_type=wc.DataType.INT)) + if "is_summary" not in existing: + to_add.append(wc.Property(name="is_summary", data_type=wc.DataType.BOOL)) + if "original_chunk_id" not in existing: + to_add.append(wc.Property(name="original_chunk_id", data_type=wc.DataType.TEXT)) for prop in to_add: try: @@ -396,15 +406,27 @@ class WeaviateVector(BaseVector): top_k = int(kwargs.get("top_k", 4)) score_threshold = float(kwargs.get("score_threshold") or 0.0) - res = col.query.near_vector( - near_vector=query_vector, - limit=top_k, - return_properties=props, - return_metadata=MetadataQuery(distance=True), - include_vector=False, - filters=where, - target_vector="default", - ) + try: + res = col.query.near_vector( + near_vector=query_vector, + limit=top_k, + return_properties=props, + return_metadata=MetadataQuery(distance=True), + include_vector=False, + filters=where, + target_vector="default", + ) + except WeaviateQueryError: + self._ensure_properties() + res = col.query.near_vector( + near_vector=query_vector, + limit=top_k, + return_properties=props, + return_metadata=MetadataQuery(distance=True), + include_vector=False, + filters=where, + target_vector="default", + ) docs: list[Document] = [] for obj in res.objects: @@ -442,14 +464,25 @@ class WeaviateVector(BaseVector): top_k = int(kwargs.get("top_k", 4)) - res = col.query.bm25( - query=query, - query_properties=[Field.TEXT_KEY.value], - limit=top_k, - return_properties=props, - include_vector=True, - filters=where, - ) + try: + res = col.query.bm25( + query=query, + query_properties=[Field.TEXT_KEY.value], + limit=top_k, + return_properties=props, + include_vector=True, + filters=where, + ) + except WeaviateQueryError: + self._ensure_properties() + res = col.query.bm25( + query=query, + query_properties=[Field.TEXT_KEY.value], + limit=top_k, + return_properties=props, + include_vector=True, + filters=where, + ) docs: list[Document] = [] for obj in res.objects: diff --git a/api/tests/integration_tests/vdb/weaviate/test_weaviate.py b/api/providers/vdb/vdb-weaviate/tests/integration_tests/test_weaviate.py similarity index 72% rename from api/tests/integration_tests/vdb/weaviate/test_weaviate.py rename to api/providers/vdb/vdb-weaviate/tests/integration_tests/test_weaviate.py index a1d9850979f..631d23d653a 100644 --- a/api/tests/integration_tests/vdb/weaviate/test_weaviate.py +++ b/api/providers/vdb/vdb-weaviate/tests/integration_tests/test_weaviate.py @@ -1,10 +1,9 @@ -from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateConfig, WeaviateVector -from tests.integration_tests.vdb.test_vector_store import ( +from dify_vdb_weaviate.weaviate_vector import WeaviateConfig, WeaviateVector + +from core.rag.datasource.vdb.vector_integration_test_support import ( AbstractVectorTest, ) -pytest_plugins = ("tests.integration_tests.vdb.test_vector_store",) - class WeaviateVectorTest(AbstractVectorTest): def __init__(self): diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/weaviate/test_weavaite.py b/api/providers/vdb/vdb-weaviate/tests/unit_tests/test_weavaite.py similarity index 92% rename from api/tests/unit_tests/core/rag/datasource/vdb/weaviate/test_weavaite.py rename to api/providers/vdb/vdb-weaviate/tests/unit_tests/test_weavaite.py index baf8c9e5f81..c773e4d5529 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/weaviate/test_weavaite.py +++ b/api/providers/vdb/vdb-weaviate/tests/unit_tests/test_weavaite.py @@ -1,6 +1,6 @@ from unittest.mock import MagicMock, patch -from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateConfig, WeaviateVector +from dify_vdb_weaviate.weaviate_vector import WeaviateConfig, WeaviateVector def test_init_client_with_valid_config(): diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/weaviate/test_weaviate_vector.py b/api/providers/vdb/vdb-weaviate/tests/unit_tests/test_weaviate_vector.py similarity index 84% rename from api/tests/unit_tests/core/rag/datasource/vdb/weaviate/test_weaviate_vector.py rename to api/providers/vdb/vdb-weaviate/tests/unit_tests/test_weaviate_vector.py index 69d18330011..b40f7e52ca4 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/weaviate/test_weaviate_vector.py +++ b/api/providers/vdb/vdb-weaviate/tests/unit_tests/test_weaviate_vector.py @@ -14,9 +14,9 @@ from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest +from dify_vdb_weaviate import weaviate_vector as weaviate_vector_module +from dify_vdb_weaviate.weaviate_vector import WeaviateConfig, WeaviateVector -from core.rag.datasource.vdb.weaviate import weaviate_vector as weaviate_vector_module -from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateConfig, WeaviateVector from core.rag.models.document import Document @@ -40,7 +40,7 @@ class TestWeaviateVector(unittest.TestCase): with pytest.raises(ValueError, match="config WEAVIATE_ENDPOINT is required"): WeaviateConfig(endpoint="") - @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") + @patch("dify_vdb_weaviate.weaviate_vector.weaviate") def _create_weaviate_vector(self, mock_weaviate_module): """Helper to create a WeaviateVector instance with mocked client.""" mock_client = MagicMock() @@ -66,7 +66,7 @@ class TestWeaviateVector(unittest.TestCase): mock_client.close.assert_called_once() mock_debug.assert_called_once() - @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate.connect_to_custom") + @patch("dify_vdb_weaviate.weaviate_vector.weaviate.connect_to_custom") def test_init_client_reuses_cached_client_without_reconnect(self, mock_connect): cached_client = MagicMock() cached_client.is_ready.return_value = True @@ -79,7 +79,7 @@ class TestWeaviateVector(unittest.TestCase): assert client is cached_client mock_connect.assert_not_called() - @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate.connect_to_custom") + @patch("dify_vdb_weaviate.weaviate_vector.weaviate.connect_to_custom") def test_init_client_reuses_cached_client_after_lock_recheck(self, mock_connect): cached_client = MagicMock() cached_client.is_ready.side_effect = [False, True] @@ -92,8 +92,8 @@ class TestWeaviateVector(unittest.TestCase): assert client is cached_client mock_connect.assert_not_called() - @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.Auth.api_key", return_value="auth-token") - @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate.connect_to_custom") + @patch("dify_vdb_weaviate.weaviate_vector.Auth.api_key", return_value="auth-token") + @patch("dify_vdb_weaviate.weaviate_vector.weaviate.connect_to_custom") def test_init_client_parses_custom_grpc_endpoint_without_scheme(self, mock_connect, mock_api_key): mock_client = MagicMock() mock_client.is_ready.return_value = True @@ -122,7 +122,7 @@ class TestWeaviateVector(unittest.TestCase): } mock_api_key.assert_called_once_with("test-key") - @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate.connect_to_custom") + @patch("dify_vdb_weaviate.weaviate_vector.weaviate.connect_to_custom") def test_init_client_raises_when_database_not_ready(self, mock_connect): mock_client = MagicMock() mock_client.is_ready.return_value = False @@ -133,7 +133,7 @@ class TestWeaviateVector(unittest.TestCase): with pytest.raises(ConnectionError, match="Vector database is not ready"): wv._init_client(self.config) - @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") + @patch("dify_vdb_weaviate.weaviate_vector.weaviate") def test_init(self, mock_weaviate_module): """Test WeaviateVector initialization stores attributes including doc_type.""" mock_client = MagicMock() @@ -183,9 +183,9 @@ class TestWeaviateVector(unittest.TestCase): wv._create_collection.assert_called_once() wv.add_texts.assert_called_once_with([doc], [[0.1, 0.2]]) - @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.redis_client") - @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.dify_config") - @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") + @patch("dify_vdb_weaviate.weaviate_vector.redis_client") + @patch("dify_vdb_weaviate.weaviate_vector.dify_config") + @patch("dify_vdb_weaviate.weaviate_vector.weaviate") def test_create_collection_includes_doc_type_property(self, mock_weaviate_module, mock_dify_config, mock_redis): """Test that _create_collection defines doc_type in the schema properties.""" # Mock Redis @@ -232,7 +232,7 @@ class TestWeaviateVector(unittest.TestCase): f"doc_type should be in collection schema properties, got: {property_names}" ) - @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.redis_client") + @patch("dify_vdb_weaviate.weaviate_vector.redis_client") def test_create_collection_returns_early_when_cache_key_exists(self, mock_redis): mock_lock = MagicMock() mock_lock.__enter__ = MagicMock() @@ -251,7 +251,7 @@ class TestWeaviateVector(unittest.TestCase): wv._ensure_properties.assert_not_called() mock_redis.set.assert_not_called() - @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.redis_client") + @patch("dify_vdb_weaviate.weaviate_vector.redis_client") def test_create_collection_logs_and_reraises_errors(self, mock_redis): mock_lock = MagicMock() mock_lock.__enter__ = MagicMock() @@ -270,7 +270,7 @@ class TestWeaviateVector(unittest.TestCase): mock_exception.assert_called_once() - @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") + @patch("dify_vdb_weaviate.weaviate_vector.weaviate") def test_ensure_properties_adds_missing_doc_type(self, mock_weaviate_module): """Test that _ensure_properties adds doc_type when it's missing from existing schema.""" mock_client = MagicMock() @@ -305,7 +305,7 @@ class TestWeaviateVector(unittest.TestCase): added_names = [call.args[0].name for call in add_calls] assert "doc_type" in added_names, f"doc_type should be added to existing collection, added: {added_names}" - @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") + @patch("dify_vdb_weaviate.weaviate_vector.weaviate") def test_ensure_properties_adds_all_missing_core_properties(self, mock_weaviate_module): mock_client = MagicMock() mock_client.is_ready.return_value = True @@ -326,9 +326,9 @@ class TestWeaviateVector(unittest.TestCase): add_calls = mock_col.config.add_property.call_args_list added_names = [call.args[0].name for call in add_calls] - assert added_names == ["document_id", "doc_id", "doc_type", "chunk_index"] + assert added_names == ["document_id", "doc_id", "doc_type", "chunk_index", "is_summary", "original_chunk_id"] - @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") + @patch("dify_vdb_weaviate.weaviate_vector.weaviate") def test_ensure_properties_skips_existing_doc_type(self, mock_weaviate_module): """Test that _ensure_properties does not add doc_type when it already exists.""" mock_client = MagicMock() @@ -346,6 +346,8 @@ class TestWeaviateVector(unittest.TestCase): SimpleNamespace(name="doc_id"), SimpleNamespace(name="doc_type"), SimpleNamespace(name="chunk_index"), + SimpleNamespace(name="is_summary"), + SimpleNamespace(name="original_chunk_id"), ] mock_cfg = MagicMock() mock_cfg.properties = existing_props @@ -361,7 +363,7 @@ class TestWeaviateVector(unittest.TestCase): # No properties should be added mock_col.config.add_property.assert_not_called() - @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") + @patch("dify_vdb_weaviate.weaviate_vector.weaviate") def test_ensure_properties_logs_warning_when_property_addition_fails(self, mock_weaviate_module): mock_client = MagicMock() mock_client.is_ready.return_value = True @@ -383,9 +385,9 @@ class TestWeaviateVector(unittest.TestCase): with patch.object(weaviate_vector_module.logger, "warning") as mock_warning: wv._ensure_properties() - assert mock_warning.call_count == 4 + assert mock_warning.call_count == 6 - @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") + @patch("dify_vdb_weaviate.weaviate_vector.weaviate") def test_search_by_vector_returns_doc_type_in_metadata(self, mock_weaviate_module): """Test that search_by_vector returns doc_type in document metadata. @@ -432,7 +434,7 @@ class TestWeaviateVector(unittest.TestCase): assert len(docs) == 1 assert docs[0].metadata.get("doc_type") == "image" - @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") + @patch("dify_vdb_weaviate.weaviate_vector.weaviate") def test_search_by_vector_uses_document_filter_and_default_distance(self, mock_weaviate_module): mock_client = MagicMock() mock_client.is_ready.return_value = True @@ -469,7 +471,7 @@ class TestWeaviateVector(unittest.TestCase): assert docs[0].metadata["score"] == 0.0 assert mock_col.query.near_vector.call_args.kwargs["filters"] is not None - @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") + @patch("dify_vdb_weaviate.weaviate_vector.weaviate") def test_search_by_vector_returns_empty_when_collection_is_missing(self, mock_weaviate_module): mock_client = MagicMock() mock_client.is_ready.return_value = True @@ -484,7 +486,57 @@ class TestWeaviateVector(unittest.TestCase): assert wv.search_by_vector(query_vector=[0.1] * 3) == [] - @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") + @patch("dify_vdb_weaviate.weaviate_vector.weaviate") + def test_search_by_vector_retries_on_weaviate_query_error(self, mock_weaviate_module): + """Test that search_by_vector catches WeaviateQueryError, calls _ensure_properties, and retries.""" + from weaviate.exceptions import WeaviateQueryError + + mock_client = MagicMock() + mock_client.is_ready.return_value = True + mock_weaviate_module.connect_to_custom.return_value = mock_client + + mock_client.collections.exists.return_value = True + mock_col = MagicMock() + mock_client.collections.use.return_value = mock_col + + # First call raises WeaviateQueryError, second call succeeds + mock_obj = MagicMock() + mock_obj.properties = {"text": "retry result", "document_id": "doc-1"} + mock_obj.metadata.distance = 0.2 + + mock_result = MagicMock() + mock_result.objects = [mock_obj] + + mock_col.query.near_vector.side_effect = [ + WeaviateQueryError("missing property", "gRPC"), + mock_result, + ] + + # Mock _ensure_properties dependencies + mock_cfg = MagicMock() + mock_cfg.properties = [ + SimpleNamespace(name="text"), + SimpleNamespace(name="document_id"), + SimpleNamespace(name="doc_id"), + SimpleNamespace(name="doc_type"), + SimpleNamespace(name="chunk_index"), + SimpleNamespace(name="is_summary"), + SimpleNamespace(name="original_chunk_id"), + ] + mock_col.config.get.return_value = mock_cfg + + wv = WeaviateVector( + collection_name=self.collection_name, + config=self.config, + attributes=self.attributes, + ) + docs = wv.search_by_vector(query_vector=[0.1] * 3, top_k=1) + + assert mock_col.query.near_vector.call_count == 2 + assert len(docs) == 1 + assert docs[0].metadata["score"] == pytest.approx(0.8) + + @patch("dify_vdb_weaviate.weaviate_vector.weaviate") def test_search_by_full_text_returns_doc_type_in_metadata(self, mock_weaviate_module): """Test that search_by_full_text also returns doc_type in document metadata.""" mock_client = MagicMock() @@ -526,7 +578,7 @@ class TestWeaviateVector(unittest.TestCase): assert len(docs) == 1 assert docs[0].metadata.get("doc_type") == "image" - @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") + @patch("dify_vdb_weaviate.weaviate_vector.weaviate") def test_search_by_full_text_uses_document_filter(self, mock_weaviate_module): mock_client = MagicMock() mock_client.is_ready.return_value = True @@ -554,7 +606,7 @@ class TestWeaviateVector(unittest.TestCase): assert docs[0].vector == [0.3, 0.4] assert mock_col.query.bm25.call_args.kwargs["filters"] is not None - @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") + @patch("dify_vdb_weaviate.weaviate_vector.weaviate") def test_search_by_full_text_returns_empty_when_collection_is_missing(self, mock_weaviate_module): mock_client = MagicMock() mock_client.is_ready.return_value = True @@ -569,7 +621,57 @@ class TestWeaviateVector(unittest.TestCase): assert wv.search_by_full_text(query="missing") == [] - @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") + @patch("dify_vdb_weaviate.weaviate_vector.weaviate") + def test_search_by_full_text_retries_on_weaviate_query_error(self, mock_weaviate_module): + """Test that search_by_full_text catches WeaviateQueryError, calls _ensure_properties, and retries.""" + from weaviate.exceptions import WeaviateQueryError + + mock_client = MagicMock() + mock_client.is_ready.return_value = True + mock_weaviate_module.connect_to_custom.return_value = mock_client + + mock_client.collections.exists.return_value = True + mock_col = MagicMock() + mock_client.collections.use.return_value = mock_col + + # First call raises WeaviateQueryError, second call succeeds + mock_obj = MagicMock() + mock_obj.properties = {"text": "retry bm25 result", "doc_id": "segment-1"} + mock_obj.vector = {"default": [0.5, 0.6]} + + mock_result = MagicMock() + mock_result.objects = [mock_obj] + + mock_col.query.bm25.side_effect = [ + WeaviateQueryError("missing property", "gRPC"), + mock_result, + ] + + # Mock _ensure_properties dependencies + mock_cfg = MagicMock() + mock_cfg.properties = [ + SimpleNamespace(name="text"), + SimpleNamespace(name="document_id"), + SimpleNamespace(name="doc_id"), + SimpleNamespace(name="doc_type"), + SimpleNamespace(name="chunk_index"), + SimpleNamespace(name="is_summary"), + SimpleNamespace(name="original_chunk_id"), + ] + mock_col.config.get.return_value = mock_cfg + + wv = WeaviateVector( + collection_name=self.collection_name, + config=self.config, + attributes=self.attributes, + ) + docs = wv.search_by_full_text(query="retry", top_k=1) + + assert mock_col.query.bm25.call_count == 2 + assert len(docs) == 1 + assert docs[0].page_content == "retry bm25 result" + + @patch("dify_vdb_weaviate.weaviate_vector.weaviate") def test_add_texts_stores_doc_type_in_properties(self, mock_weaviate_module): """Test that add_texts includes doc_type from document metadata in stored properties.""" mock_client = MagicMock() @@ -611,7 +713,7 @@ class TestWeaviateVector(unittest.TestCase): stored_props = call_kwargs.kwargs.get("properties") assert stored_props.get("doc_type") == "image", f"doc_type should be stored in properties, got: {stored_props}" - @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") + @patch("dify_vdb_weaviate.weaviate_vector.weaviate") def test_add_texts_falls_back_to_random_uuid_and_serializes_datetime_metadata(self, mock_weaviate_module): mock_client = MagicMock() mock_client.is_ready.return_value = True @@ -635,7 +737,7 @@ class TestWeaviateVector(unittest.TestCase): with ( patch.object(wv, "_get_uuids", return_value=["not-a-uuid"]), - patch("core.rag.datasource.vdb.weaviate.weaviate_vector._uuid.uuid4", return_value="fallback-uuid"), + patch("dify_vdb_weaviate.weaviate_vector._uuid.uuid4", return_value="fallback-uuid"), ): ids = wv.add_texts(documents=[doc], embeddings=[[]]) @@ -775,9 +877,7 @@ class TestWeaviateVectorFactory(unittest.TestCase): patch.object(weaviate_vector_module.dify_config, "WEAVIATE_GRPC_ENDPOINT", "localhost:50051"), patch.object(weaviate_vector_module.dify_config, "WEAVIATE_API_KEY", "api-key"), patch.object(weaviate_vector_module.dify_config, "WEAVIATE_BATCH_SIZE", 88), - patch( - "core.rag.datasource.vdb.weaviate.weaviate_vector.WeaviateVector", return_value="vector" - ) as mock_vector, + patch("dify_vdb_weaviate.weaviate_vector.WeaviateVector", return_value="vector") as mock_vector, ): factory = weaviate_vector_module.WeaviateVectorFactory() result = factory.init_vector(dataset, attributes, MagicMock()) @@ -806,9 +906,7 @@ class TestWeaviateVectorFactory(unittest.TestCase): "gen_collection_name_by_id", return_value="GeneratedCollection_Node", ), - patch( - "core.rag.datasource.vdb.weaviate.weaviate_vector.WeaviateVector", return_value="vector" - ) as mock_vector, + patch("dify_vdb_weaviate.weaviate_vector.WeaviateVector", return_value="vector") as mock_vector, ): factory = weaviate_vector_module.WeaviateVectorFactory() result = factory.init_vector(dataset, attributes, MagicMock()) diff --git a/api/pyproject.toml b/api/pyproject.toml index 863b61cad13..fbd5d394ad4 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -4,93 +4,49 @@ version = "1.13.3" requires-python = "~=3.12.0" dependencies = [ - "aliyun-log-python-sdk~=0.9.37", - "arize-phoenix-otel~=0.15.0", - "azure-identity==1.25.3", - "beautifulsoup4==4.14.3", - "boto3==1.42.78", - "bs4~=0.0.1", - "cachetools~=5.3.0", - "celery~=5.6.2", - "charset-normalizer>=3.4.4", - "flask~=3.1.2", - "flask-compress>=1.17,<1.24", - "flask-cors~=6.0.0", - "flask-login~=0.6.3", - "flask-migrate~=4.1.0", - "flask-orjson~=2.0.0", - "flask-sqlalchemy~=3.1.1", - "gevent~=25.9.1", - "gmpy2~=2.3.0", - "google-api-core>=2.19.1", - "google-api-python-client==2.193.0", - "google-auth>=2.47.0", - "google-auth-httplib2==0.3.0", - "google-cloud-aiplatform>=1.123.0", - "googleapis-common-protos>=1.65.0", - "graphon>=0.1.2", - "gunicorn~=25.3.0", - "httpx[socks]~=0.28.0", - "jieba==0.42.1", - "json-repair>=0.55.1", - "langfuse>=3.0.0,<5.0.0", - "langsmith~=0.7.16", - "markdown~=3.10.2", - "mlflow-skinny>=3.0.0", - "numpy~=1.26.4", - "openpyxl~=3.1.5", - "opik~=1.10.37", - "litellm==1.82.6", # Pinned to avoid madoka dependency issue - "opentelemetry-api==1.40.0", - "opentelemetry-distro==0.61b0", - "opentelemetry-exporter-otlp==1.40.0", - "opentelemetry-exporter-otlp-proto-common==1.40.0", - "opentelemetry-exporter-otlp-proto-grpc==1.40.0", - "opentelemetry-exporter-otlp-proto-http==1.40.0", - "opentelemetry-instrumentation==0.61b0", - "opentelemetry-instrumentation-celery==0.61b0", - "opentelemetry-instrumentation-flask==0.61b0", - "opentelemetry-instrumentation-httpx==0.61b0", - "opentelemetry-instrumentation-redis==0.61b0", - "opentelemetry-instrumentation-sqlalchemy==0.61b0", - "opentelemetry-propagator-b3==1.40.0", - "opentelemetry-proto==1.40.0", - "opentelemetry-sdk==1.40.0", - "opentelemetry-semantic-conventions==0.61b0", - "opentelemetry-util-http==0.61b0", - "pandas[excel,output-formatting,performance]~=3.0.1", - "psycogreen~=1.0.2", - "psycopg2-binary~=2.9.6", - "pycryptodome==3.23.0", - "pydantic~=2.12.5", - "pydantic-settings~=2.13.1", - "pyjwt~=2.12.0", - "pypdfium2==5.6.0", - "python-docx~=1.2.0", - "python-dotenv==1.2.2", - "pyyaml~=6.0.1", - "readabilipy~=0.3.0", - "redis[hiredis]~=7.4.0", - "resend~=2.26.0", - "sentry-sdk[flask]~=2.55.0", - "sqlalchemy~=2.0.29", - "starlette==1.0.0", - "tiktoken~=0.12.0", - "transformers~=5.3.0", - "unstructured[docx,epub,md,ppt,pptx]~=0.21.5", - "pypandoc~=1.13", - "yarl~=1.23.0", - "sseclient-py~=1.9.0", + # Legacy: mature and widely deployed + "bleach>=6.3.0", + "boto3>=1.42.91", + "celery>=5.6.3", + "croniter>=6.2.2", + "flask-cors>=6.0.2", + "gevent>=26.4.0", + "gevent-websocket>=0.10.1", + "gmpy2>=2.3.0", + "google-api-python-client>=2.194.0", + "gunicorn>=25.3.0", + "psycogreen>=1.0.2", + "psycopg2-binary>=2.9.11", + "python-socketio>=5.13.0", + "redis[hiredis]>=7.4.0", + "sendgrid>=6.12.5", + "sseclient-py>=1.8.0", + + # Stable: production-proven, cap below the next major + "aliyun-log-python-sdk>=0.9.44,<1.0.0", + "azure-identity>=1.25.3,<2.0.0", + "flask-compress>=1.24,<2.0.0", + "flask-login>=0.6.3,<1.0.0", + "flask-migrate>=4.1.0,<5.0.0", + "flask-orjson>=2.0.0,<3.0.0", + "flask-restx>=1.3.2,<2.0.0", + "google-cloud-aiplatform>=1.148.1,<2.0.0", + "httpx[socks]>=0.28.1,<1.0.0", + "opentelemetry-distro>=0.62b0,<1.0.0", + "opentelemetry-instrumentation-celery>=0.62b0,<1.0.0", + "opentelemetry-instrumentation-flask>=0.62b0,<1.0.0", + "opentelemetry-instrumentation-httpx>=0.62b0,<1.0.0", + "opentelemetry-instrumentation-redis>=0.62b0,<1.0.0", + "opentelemetry-instrumentation-sqlalchemy>=0.62b0,<1.0.0", + "opentelemetry-propagator-b3>=1.41.0,<2.0.0", + "readabilipy>=0.3.0,<1.0.0", + "resend>=2.27.0,<3.0.0", + + # Emerging: newer and fast-moving, use compatible pins + "fastopenapi[flask]~=0.7.0", + "graphon~=0.2.2", "httpx-sse~=0.4.0", - "sendgrid~=6.12.3", - "flask-restx~=1.3.2", - "packaging~=23.2", - "croniter>=6.0.0", - "weaviate-client==4.20.4", - "apscheduler>=3.11.0", - "weave>=0.52.16", - "fastopenapi[flask]>=0.7.0", - "bleach~=6.3.0", + "json-repair~=0.59.4", ] # Before adding new dependency, consider place it in # alphabet order (a-z) and suitable group. @@ -98,9 +54,56 @@ dependencies = [ [tool.setuptools] packages = [] +[tool.uv.workspace] +members = ["providers/vdb/*", "providers/trace/*"] +exclude = ["providers/vdb/__pycache__", "providers/trace/__pycache__"] + +[tool.uv.sources] +dify-vdb-alibabacloud-mysql = { workspace = true } +dify-vdb-analyticdb = { workspace = true } +dify-vdb-baidu = { workspace = true } +dify-vdb-chroma = { workspace = true } +dify-vdb-clickzetta = { workspace = true } +dify-vdb-couchbase = { workspace = true } +dify-vdb-elasticsearch = { workspace = true } +dify-vdb-hologres = { workspace = true } +dify-vdb-huawei-cloud = { workspace = true } +dify-vdb-iris = { workspace = true } +dify-vdb-lindorm = { workspace = true } +dify-vdb-matrixone = { workspace = true } +dify-vdb-milvus = { workspace = true } +dify-vdb-myscale = { workspace = true } +dify-vdb-oceanbase = { workspace = true } +dify-vdb-opengauss = { workspace = true } +dify-vdb-opensearch = { workspace = true } +dify-vdb-oracle = { workspace = true } +dify-vdb-pgvecto-rs = { workspace = true } +dify-vdb-pgvector = { workspace = true } +dify-vdb-qdrant = { workspace = true } +dify-vdb-relyt = { workspace = true } +dify-vdb-tablestore = { workspace = true } +dify-vdb-tencent = { workspace = true } +dify-vdb-tidb-on-qdrant = { workspace = true } +dify-vdb-tidb-vector = { workspace = true } +dify-vdb-upstash = { workspace = true } +dify-vdb-vastbase = { workspace = true } +dify-vdb-vikingdb = { workspace = true } +dify-vdb-weaviate = { workspace = true } +dify-trace-aliyun = { workspace = true } +dify-trace-arize-phoenix = { workspace = true } +dify-trace-langfuse = { workspace = true } +dify-trace-langsmith = { workspace = true } +dify-trace-mlflow = { workspace = true } +dify-trace-opik = { workspace = true } +dify-trace-tencent = { workspace = true } +dify-trace-weave = { workspace = true } + [tool.uv] -default-groups = ["storage", "tools", "vdb"] +default-groups = ["storage", "tools", "vdb-all", "trace-all"] package = false +override-dependencies = [ + "pyarrow>=18.0.0", +] [dependency-groups] @@ -109,69 +112,69 @@ package = false # Required for development and running tests ############################################################ dev = [ - "coverage~=7.13.4", - "dotenv-linter~=0.7.0", - "faker~=40.11.0", - "lxml-stubs~=0.5.1", - "basedpyright~=1.38.2", - "ruff~=0.15.5", - "pytest~=9.0.2", - "pytest-benchmark~=5.2.3", - "pytest-cov~=7.1.0", - "pytest-env~=1.6.0", - "pytest-mock~=3.15.1", - "testcontainers~=4.14.1", - "types-aiofiles~=25.1.0", - "types-beautifulsoup4~=4.12.0", - "types-cachetools~=6.2.0", - "types-colorama~=0.4.15", - "types-defusedxml~=0.7.0", - "types-deprecated~=1.3.1", - "types-docutils~=0.22.3", - "types-flask-cors~=6.0.0", - "types-flask-migrate~=4.1.0", - "types-gevent~=25.9.0", - "types-greenlet~=3.3.0", - "types-html5lib~=1.1.11", - "types-markdown~=3.10.2", - "types-oauthlib~=3.3.0", - "types-objgraph~=3.6.0", - "types-olefile~=0.47.0", - "types-openpyxl~=3.1.5", - "types-pexpect~=4.9.0", - "types-protobuf~=6.32.1", - "types-psutil~=7.2.2", - "types-psycopg2~=2.9.21", - "types-pygments~=2.19.0", - "types-pymysql~=1.1.0", - "types-python-dateutil~=2.9.0", - "types-pywin32~=311.0.0", - "types-pyyaml~=6.0.12", - "types-regex~=2026.3.32", - "types-shapely~=2.1.0", - "types-simplejson>=3.20.0", - "types-six>=1.17.0", - "types-tensorflow>=2.18.0", - "types-tqdm>=4.67.0", + "coverage>=7.13.4", + "dotenv-linter>=0.7.0", + "faker>=40.15.0", + "lxml-stubs>=0.5.1", + "basedpyright>=1.39.3", + "ruff>=0.15.11", + "pytest>=9.0.3", + "pytest-benchmark>=5.2.3", + "pytest-cov>=7.1.0", + "pytest-env>=1.6.0", + "pytest-mock>=3.15.1", + "testcontainers>=4.14.2", + "types-aiofiles>=25.1.0", + "types-beautifulsoup4>=4.12.0", + "types-cachetools>=6.2.0", + "types-colorama>=0.4.15", + "types-defusedxml>=0.7.0", + "types-deprecated>=1.3.1", + "types-docutils>=0.22.3", + "types-flask-cors>=6.0.0", + "types-flask-migrate>=4.1.0", + "types-gevent>=26.4.0", + "types-greenlet>=3.4.0", + "types-html5lib>=1.1.11", + "types-markdown>=3.10.2", + "types-oauthlib>=3.3.0", + "types-objgraph>=3.6.0", + "types-olefile>=0.47.0", + "types-openpyxl>=3.1.5", + "types-pexpect>=4.9.0", + "types-protobuf>=7.34.1", + "types-psutil>=7.2.2", + "types-psycopg2>=2.9.21", + "types-pygments>=2.20.0", + "types-pymysql>=1.1.0", + "types-python-dateutil>=2.9.0", + "types-pywin32>=311.0.0", + "types-pyyaml>=6.0.12", + "types-regex>=2026.4.4", + "types-shapely>=2.1.0", + "types-simplejson>=3.20.0.20260408", + "types-six>=1.17.0.20260408", + "types-tensorflow>=2.18.0.20260408", + "types-tqdm>=4.67.3.20260408", "types-ujson>=5.10.0", - "boto3-stubs>=1.38.20", - "types-jmespath>=1.0.2.20240106", - "hypothesis>=6.131.15", + "boto3-stubs>=1.42.92", + "types-jmespath>=1.1.0.20260408", + "hypothesis>=6.152.1", "types_pyOpenSSL>=24.1.0", - "types_cffi>=1.17.0", - "types_setuptools>=80.9.0", - "pandas-stubs~=3.0.0", - "scipy-stubs>=1.15.3.0", - "types-python-http-client>=3.3.7.20240910", + "types_cffi>=2.0.0.20260408", + "types_setuptools>=82.0.0.20260408", + "pandas-stubs>=3.0.0", + "scipy-stubs>=1.17.1.4", + "types-python-http-client>=3.3.7.20260408", "import-linter>=2.3", "types-redis>=4.6.0.20241004", "celery-types>=0.23.0", - "mypy~=1.19.1", + "mypy>=1.20.1", # "locust>=2.40.4", # Temporarily removed due to compatibility issues. Uncomment when resolved. - "sseclient-py>=1.8.0", "pytest-timeout>=2.4.0", "pytest-xdist>=3.8.0", - "pyrefly>=0.57.1", + "pyrefly>=0.61.1", + "xinference-client>=2.5.0", ] ############################################################ @@ -179,54 +182,110 @@ dev = [ # Required for storage clients ############################################################ storage = [ - "azure-storage-blob==12.28.0", - "bce-python-sdk~=0.9.23", - "cos-python-sdk-v5==1.9.41", - "esdk-obs-python==3.26.2", - "google-cloud-storage>=3.0.0", - "opendal~=0.46.0", - "oss2==2.19.1", - "supabase~=2.18.1", - "tos~=2.9.0", + "azure-storage-blob>=12.28.0", + "bce-python-sdk>=0.9.70", + "cos-python-sdk-v5>=1.9.41", + "esdk-obs-python>=3.22.2", + "google-cloud-storage>=3.10.1", + "opendal>=0.46.0", + "oss2>=2.19.1", + "supabase>=2.28.3", + "tos>=2.9.0", ] ############################################################ # [ Tools ] dependency group ############################################################ -tools = ["cloudscraper~=1.2.71", "nltk~=3.9.1"] +tools = ["cloudscraper>=1.2.71", "nltk>=3.9.1"] ############################################################ -# [ VDB ] dependency group -# Required by vector store clients +# [ VDB ] workspace plugins — hollow packages under providers/vdb/* +# Each declares its own third-party deps and registers dify.vector_backends entry points. +# Use: uv sync --group vdb-all | uv sync --group vdb-qdrant ############################################################ -vdb = [ - "alibabacloud_gpdb20160503~=5.1.0", - "alibabacloud_tea_openapi~=0.4.3", - "chromadb==0.5.20", - "clickhouse-connect~=0.15.0", - "clickzetta-connector-python>=0.8.102", - "couchbase~=4.5.0", - "elasticsearch==8.14.0", - "opensearch-py==3.1.0", - "oracledb==3.4.2", - "pgvecto-rs[sqlalchemy]~=0.2.1", - "pgvector==0.4.2", - "pymilvus~=2.6.10", - "pymochow==2.3.6", - "pyobvector~=0.2.17", - "qdrant-client==1.9.0", - "intersystems-irispython>=5.1.0", - "tablestore==6.4.2", - "tcvectordb~=2.1.0", - "tidb-vector==0.0.15", - "upstash-vector==0.8.0", - "volcengine-compat~=1.0.0", - "weaviate-client==4.20.4", - "xinference-client~=2.4.0", - "mo-vector~=0.1.13", - "mysql-connector-python>=9.3.0", - "holo-search-sdk>=0.4.1", +vdb-all = [ + "dify-vdb-alibabacloud-mysql", + "dify-vdb-analyticdb", + "dify-vdb-baidu", + "dify-vdb-chroma", + "dify-vdb-clickzetta", + "dify-vdb-couchbase", + "dify-vdb-elasticsearch", + "dify-vdb-hologres", + "dify-vdb-huawei-cloud", + "dify-vdb-iris", + "dify-vdb-lindorm", + "dify-vdb-matrixone", + "dify-vdb-milvus", + "dify-vdb-myscale", + "dify-vdb-oceanbase", + "dify-vdb-opengauss", + "dify-vdb-opensearch", + "dify-vdb-oracle", + "dify-vdb-pgvecto-rs", + "dify-vdb-pgvector", + "dify-vdb-qdrant", + "dify-vdb-relyt", + "dify-vdb-tablestore", + "dify-vdb-tencent", + "dify-vdb-tidb-on-qdrant", + "dify-vdb-tidb-vector", + "dify-vdb-upstash", + "dify-vdb-vastbase", + "dify-vdb-vikingdb", + "dify-vdb-weaviate", ] +vdb-alibabacloud-mysql = ["dify-vdb-alibabacloud-mysql"] +vdb-analyticdb = ["dify-vdb-analyticdb"] +vdb-baidu = ["dify-vdb-baidu"] +vdb-chroma = ["dify-vdb-chroma"] +vdb-clickzetta = ["dify-vdb-clickzetta"] +vdb-couchbase = ["dify-vdb-couchbase"] +vdb-elasticsearch = ["dify-vdb-elasticsearch"] +vdb-hologres = ["dify-vdb-hologres"] +vdb-huawei-cloud = ["dify-vdb-huawei-cloud"] +vdb-iris = ["dify-vdb-iris"] +vdb-lindorm = ["dify-vdb-lindorm"] +vdb-matrixone = ["dify-vdb-matrixone"] +vdb-milvus = ["dify-vdb-milvus"] +vdb-myscale = ["dify-vdb-myscale"] +vdb-oceanbase = ["dify-vdb-oceanbase"] +vdb-opengauss = ["dify-vdb-opengauss"] +vdb-opensearch = ["dify-vdb-opensearch"] +vdb-oracle = ["dify-vdb-oracle"] +vdb-pgvecto-rs = ["dify-vdb-pgvecto-rs"] +vdb-pgvector = ["dify-vdb-pgvector"] +vdb-qdrant = ["dify-vdb-qdrant"] +vdb-relyt = ["dify-vdb-relyt"] +vdb-tablestore = ["dify-vdb-tablestore"] +vdb-tencent = ["dify-vdb-tencent"] +vdb-tidb-on-qdrant = ["dify-vdb-tidb-on-qdrant"] +vdb-tidb-vector = ["dify-vdb-tidb-vector"] +vdb-upstash = ["dify-vdb-upstash"] +vdb-vastbase = ["dify-vdb-vastbase"] +vdb-vikingdb = ["dify-vdb-vikingdb"] +vdb-weaviate = ["dify-vdb-weaviate"] +# Optional client used by some tests / integrations (not a vector backend plugin) +vdb-xinference = ["xinference-client>=2.5.0"] + +trace-all = [ + "dify-trace-aliyun", + "dify-trace-arize-phoenix", + "dify-trace-langfuse", + "dify-trace-langsmith", + "dify-trace-mlflow", + "dify-trace-opik", + "dify-trace-tencent", + "dify-trace-weave", +] +trace-aliyun = ["dify-trace-aliyun"] +trace-arize-phoenix = ["dify-trace-arize-phoenix"] +trace-langfuse = ["dify-trace-langfuse"] +trace-langsmith = ["dify-trace-langsmith"] +trace-mlflow = ["dify-trace-mlflow"] +trace-opik = ["dify-trace-opik"] +trace-tencent = ["dify-trace-tencent"] +trace-weave = ["dify-trace-weave"] [tool.pyrefly] project-includes = ["."] diff --git a/api/pyrefly-local-excludes.txt b/api/pyrefly-local-excludes.txt index 43f604c2de9..fbbca245587 100644 --- a/api/pyrefly-local-excludes.txt +++ b/api/pyrefly-local-excludes.txt @@ -34,42 +34,18 @@ core/external_data_tool/api/api.py core/llm_generator/llm_generator.py core/llm_generator/output_parser/structured_output.py core/mcp/mcp_client.py -core/ops/aliyun_trace/data_exporter/traceclient.py -core/ops/arize_phoenix_trace/arize_phoenix_trace.py -core/ops/mlflow_trace/mlflow_trace.py +providers/trace/trace-aliyun/src/dify_trace_aliyun/data_exporter/traceclient.py +providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/arize_phoenix_trace.py +providers/trace/trace-mlflow/src/dify_trace_mlflow/mlflow_trace.py core/ops/ops_trace_manager.py -core/ops/tencent_trace/client.py -core/ops/tencent_trace/utils.py +providers/trace/trace-tencent/src/dify_trace_tencent/client.py +providers/trace/trace-tencent/src/dify_trace_tencent/utils.py core/plugin/backwards_invocation/base.py core/plugin/backwards_invocation/model.py core/prompt/utils/extract_thread_messages.py core/rag/datasource/keyword/jieba/jieba.py core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py -core/rag/datasource/vdb/analyticdb/analyticdb_vector.py -core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py -core/rag/datasource/vdb/baidu/baidu_vector.py -core/rag/datasource/vdb/chroma/chroma_vector.py -core/rag/datasource/vdb/clickzetta/clickzetta_vector.py -core/rag/datasource/vdb/couchbase/couchbase_vector.py -core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py -core/rag/datasource/vdb/huawei/huawei_cloud_vector.py -core/rag/datasource/vdb/lindorm/lindorm_vector.py -core/rag/datasource/vdb/matrixone/matrixone_vector.py -core/rag/datasource/vdb/milvus/milvus_vector.py -core/rag/datasource/vdb/myscale/myscale_vector.py -core/rag/datasource/vdb/oceanbase/oceanbase_vector.py -core/rag/datasource/vdb/opensearch/opensearch_vector.py -core/rag/datasource/vdb/oracle/oraclevector.py -core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py -core/rag/datasource/vdb/relyt/relyt_vector.py -core/rag/datasource/vdb/tablestore/tablestore_vector.py -core/rag/datasource/vdb/tencent/tencent_vector.py -core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py -core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py -core/rag/datasource/vdb/tidb_vector/tidb_vector.py -core/rag/datasource/vdb/upstash/upstash_vector.py -core/rag/datasource/vdb/vikingdb/vikingdb_vector.py -core/rag/datasource/vdb/weaviate/weaviate_vector.py +providers/vdb/** core/rag/extractor/csv_extractor.py core/rag/extractor/excel_extractor.py core/rag/extractor/firecrawl/firecrawl_app.py diff --git a/api/pyrightconfig.json b/api/pyrightconfig.json index a8b884ea81f..ac0e2a3a53d 100644 --- a/api/pyrightconfig.json +++ b/api/pyrightconfig.json @@ -4,7 +4,9 @@ "tests/", ".venv", "migrations/", - "core/rag" + "core/rag", + "providers/vdb/", + "providers/trace/*/tests", ], "typeCheckingMode": "strict", "allowedUntypedLibraries": [ @@ -36,7 +38,9 @@ "gmpy2", "sendgrid", "sendgrid.helpers.mail", - "holo_search_sdk.types" + "holo_search_sdk.types", + "dify_vdb_qdrant", + "dify_vdb_tidb_on_qdrant" ], "reportUnknownMemberType": "hint", "reportUnknownParameterType": "hint", @@ -47,7 +51,6 @@ "reportMissingTypeArgument": "hint", "reportUnnecessaryComparison": "hint", "reportUnnecessaryIsInstance": "hint", - "reportUntypedFunctionDecorator": "hint", "reportUnnecessaryTypeIgnoreComment": "hint", "reportAttributeAccessIssue": "hint", "pythonVersion": "3.12", diff --git a/api/repositories/api_workflow_run_repository.py b/api/repositories/api_workflow_run_repository.py index 1a2a539c802..72b38e79068 100644 --- a/api/repositories/api_workflow_run_repository.py +++ b/api/repositories/api_workflow_run_repository.py @@ -36,13 +36,13 @@ Example: from collections.abc import Callable, Sequence from datetime import datetime -from typing import Protocol +from typing import Protocol, TypedDict -from graphon.entities.pause_reason import PauseReason -from graphon.enums import WorkflowType from sqlalchemy.orm import Session from core.repositories.factory import WorkflowExecutionRepository +from graphon.entities.pause_reason import PauseReason +from graphon.enums import WorkflowType from libs.infinite_scroll_pagination import InfiniteScrollPagination from models.enums import WorkflowRunTriggeredFrom from models.workflow import WorkflowAppLog, WorkflowArchiveLog, WorkflowPause, WorkflowPauseReason, WorkflowRun @@ -55,6 +55,16 @@ from repositories.types import ( ) +class RunsWithRelatedCountsDict(TypedDict): + runs: int + node_executions: int + offloads: int + app_logs: int + trigger_logs: int + pauses: int + pause_reasons: int + + class APIWorkflowRunRepository(WorkflowExecutionRepository, Protocol): """ Protocol for service-layer WorkflowRun repository operations. @@ -333,7 +343,7 @@ class APIWorkflowRunRepository(WorkflowExecutionRepository, Protocol): runs: Sequence[WorkflowRun], delete_node_executions: Callable[[Session, Sequence[WorkflowRun]], tuple[int, int]] | None = None, delete_trigger_logs: Callable[[Session, Sequence[str]], int] | None = None, - ) -> dict[str, int]: + ) -> RunsWithRelatedCountsDict: """ Delete workflow runs and their related records (node executions, offloads, app logs, trigger logs, pauses, pause reasons). @@ -400,7 +410,7 @@ class APIWorkflowRunRepository(WorkflowExecutionRepository, Protocol): runs: Sequence[WorkflowRun], count_node_executions: Callable[[Session, Sequence[WorkflowRun]], tuple[int, int]] | None = None, count_trigger_logs: Callable[[Session, Sequence[str]], int] | None = None, - ) -> dict[str, int]: + ) -> RunsWithRelatedCountsDict: """ Count workflow runs and their related records (node executions, offloads, app logs, trigger logs, pauses, pause reasons) without deleting data. diff --git a/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py b/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py index d5c6a203b14..44735eb7699 100644 --- a/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py +++ b/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py @@ -10,11 +10,11 @@ from collections.abc import Sequence from datetime import datetime from typing import Protocol, cast -from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from sqlalchemy import asc, delete, desc, func, select from sqlalchemy.engine import CursorResult from sqlalchemy.orm import Session, sessionmaker +from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload from repositories.api_workflow_node_execution_repository import ( DifyAPIWorkflowNodeExecutionRepository, diff --git a/api/repositories/sqlalchemy_api_workflow_run_repository.py b/api/repositories/sqlalchemy_api_workflow_run_repository.py index 413936b542b..474b200fc58 100644 --- a/api/repositories/sqlalchemy_api_workflow_run_repository.py +++ b/api/repositories/sqlalchemy_api_workflow_run_repository.py @@ -28,24 +28,23 @@ from decimal import Decimal from typing import Any, cast import sqlalchemy as sa -from graphon.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause -from graphon.enums import WorkflowExecutionStatus, WorkflowType -from graphon.nodes.human_input.entities import FormDefinition from pydantic import ValidationError from sqlalchemy import and_, delete, func, null, or_, select, tuple_ from sqlalchemy.engine import CursorResult from sqlalchemy.orm import Session, selectinload, sessionmaker from extensions.ext_storage import storage +from graphon.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause +from graphon.enums import WorkflowExecutionStatus, WorkflowType +from graphon.nodes.human_input.entities import FormDefinition from libs.datetime_utils import naive_utc_now from libs.helper import convert_datetime_to_date from libs.infinite_scroll_pagination import InfiniteScrollPagination from libs.time_parser import get_time_threshold -from libs.uuid_utils import uuidv7 from models.enums import WorkflowRunTriggeredFrom from models.human_input import HumanInputForm from models.workflow import WorkflowAppLog, WorkflowArchiveLog, WorkflowPause, WorkflowPauseReason, WorkflowRun -from repositories.api_workflow_run_repository import APIWorkflowRunRepository +from repositories.api_workflow_run_repository import APIWorkflowRunRepository, RunsWithRelatedCountsDict from repositories.entities.workflow_pause import WorkflowPauseEntity from repositories.types import ( AverageInteractionStats, @@ -463,7 +462,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): runs: Sequence[WorkflowRun], delete_node_executions: Callable[[Session, Sequence[WorkflowRun]], tuple[int, int]] | None = None, delete_trigger_logs: Callable[[Session, Sequence[str]], int] | None = None, - ) -> dict[str, int]: + ) -> RunsWithRelatedCountsDict: if not runs: return { "runs": 0, @@ -638,7 +637,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): runs: Sequence[WorkflowRun], count_node_executions: Callable[[Session, Sequence[WorkflowRun]], tuple[int, int]] | None = None, count_trigger_logs: Callable[[Session, Sequence[str]], int] | None = None, - ) -> dict[str, int]: + ) -> RunsWithRelatedCountsDict: if not runs: return { "runs": 0, @@ -744,12 +743,11 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): # Upload the state file # Create the pause record - pause_model = WorkflowPause() - pause_model.id = str(uuidv7()) - pause_model.workflow_id = workflow_run.workflow_id - pause_model.workflow_run_id = workflow_run.id - pause_model.state_object_key = state_obj_key - pause_model.created_at = naive_utc_now() + pause_model = WorkflowPause( + workflow_id=workflow_run.workflow_id, + workflow_run_id=workflow_run.id, + state_object_key=state_obj_key, + ) pause_reason_models = [] for reason in pause_reasons: if isinstance(reason, HumanInputRequired): diff --git a/api/repositories/sqlalchemy_execution_extra_content_repository.py b/api/repositories/sqlalchemy_execution_extra_content_repository.py index feba5f7eb65..67f8795d3fd 100644 --- a/api/repositories/sqlalchemy_execution_extra_content_repository.py +++ b/api/repositories/sqlalchemy_execution_extra_content_repository.py @@ -7,9 +7,6 @@ from collections import defaultdict from collections.abc import Sequence from typing import Any -from graphon.nodes.human_input.entities import FormDefinition -from graphon.nodes.human_input.enums import HumanInputFormStatus -from graphon.nodes.human_input.human_input_node import HumanInputNode from sqlalchemy import select from sqlalchemy.orm import Session, selectinload, sessionmaker @@ -21,6 +18,9 @@ from core.entities.execution_extra_content import ( from core.entities.execution_extra_content import ( HumanInputContent as HumanInputContentDomainModel, ) +from graphon.nodes.human_input.entities import FormDefinition +from graphon.nodes.human_input.enums import HumanInputFormStatus +from graphon.nodes.human_input.human_input_node import HumanInputNode from models.execution_extra_content import ( ExecutionExtraContent as ExecutionExtraContentModel, ) diff --git a/api/repositories/workflow_collaboration_repository.py b/api/repositories/workflow_collaboration_repository.py new file mode 100644 index 00000000000..000f80496d5 --- /dev/null +++ b/api/repositories/workflow_collaboration_repository.py @@ -0,0 +1,147 @@ +from __future__ import annotations + +import json +from typing import TypedDict + +from extensions.ext_redis import redis_client + +SESSION_STATE_TTL_SECONDS = 3600 +WORKFLOW_ONLINE_USERS_PREFIX = "workflow_online_users:" +WORKFLOW_LEADER_PREFIX = "workflow_leader:" +WS_SID_MAP_PREFIX = "ws_sid_map:" + + +class WorkflowSessionInfo(TypedDict): + user_id: str + username: str + avatar: str | None + sid: str + connected_at: int + + +class SidMapping(TypedDict): + workflow_id: str + user_id: str + + +class WorkflowCollaborationRepository: + def __init__(self) -> None: + self._redis = redis_client + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(redis_client={self._redis})" + + @staticmethod + def workflow_key(workflow_id: str) -> str: + return f"{WORKFLOW_ONLINE_USERS_PREFIX}{workflow_id}" + + @staticmethod + def leader_key(workflow_id: str) -> str: + return f"{WORKFLOW_LEADER_PREFIX}{workflow_id}" + + @staticmethod + def sid_key(sid: str) -> str: + return f"{WS_SID_MAP_PREFIX}{sid}" + + @staticmethod + def _decode(value: str | bytes | None) -> str | None: + if value is None: + return None + if isinstance(value, bytes): + return value.decode("utf-8") + return value + + def refresh_session_state(self, workflow_id: str, sid: str) -> None: + workflow_key = self.workflow_key(workflow_id) + sid_key = self.sid_key(sid) + if self._redis.exists(workflow_key): + self._redis.expire(workflow_key, SESSION_STATE_TTL_SECONDS) + if self._redis.exists(sid_key): + self._redis.expire(sid_key, SESSION_STATE_TTL_SECONDS) + + def set_session_info(self, workflow_id: str, session_info: WorkflowSessionInfo) -> None: + workflow_key = self.workflow_key(workflow_id) + self._redis.hset(workflow_key, session_info["sid"], json.dumps(session_info)) + self._redis.set( + self.sid_key(session_info["sid"]), + json.dumps({"workflow_id": workflow_id, "user_id": session_info["user_id"]}), + ex=SESSION_STATE_TTL_SECONDS, + ) + self.refresh_session_state(workflow_id, session_info["sid"]) + + def get_sid_mapping(self, sid: str) -> SidMapping | None: + raw = self._redis.get(self.sid_key(sid)) + if not raw: + return None + value = self._decode(raw) + if not value: + return None + try: + return json.loads(value) + except (TypeError, json.JSONDecodeError): + return None + + def delete_session(self, workflow_id: str, sid: str) -> None: + self._redis.hdel(self.workflow_key(workflow_id), sid) + self._redis.delete(self.sid_key(sid)) + + def session_exists(self, workflow_id: str, sid: str) -> bool: + return bool(self._redis.hexists(self.workflow_key(workflow_id), sid)) + + def sid_mapping_exists(self, sid: str) -> bool: + return bool(self._redis.exists(self.sid_key(sid))) + + def get_session_sids(self, workflow_id: str) -> list[str]: + raw_sids = self._redis.hkeys(self.workflow_key(workflow_id)) + decoded_sids: list[str] = [] + for sid in raw_sids: + decoded = self._decode(sid) + if decoded: + decoded_sids.append(decoded) + return decoded_sids + + def list_sessions(self, workflow_id: str) -> list[WorkflowSessionInfo]: + sessions_json = self._redis.hgetall(self.workflow_key(workflow_id)) + users: list[WorkflowSessionInfo] = [] + + for session_info_json in sessions_json.values(): + value = self._decode(session_info_json) + if not value: + continue + try: + session_info = json.loads(value) + except (TypeError, json.JSONDecodeError): + continue + + if not isinstance(session_info, dict): + continue + if "user_id" not in session_info or "username" not in session_info or "sid" not in session_info: + continue + + users.append( + { + "user_id": str(session_info["user_id"]), + "username": str(session_info["username"]), + "avatar": session_info.get("avatar"), + "sid": str(session_info["sid"]), + "connected_at": int(session_info.get("connected_at") or 0), + } + ) + + return users + + def get_current_leader(self, workflow_id: str) -> str | None: + raw = self._redis.get(self.leader_key(workflow_id)) + return self._decode(raw) + + def set_leader_if_absent(self, workflow_id: str, sid: str) -> bool: + return bool(self._redis.set(self.leader_key(workflow_id), sid, nx=True, ex=SESSION_STATE_TTL_SECONDS)) + + def set_leader(self, workflow_id: str, sid: str) -> None: + self._redis.set(self.leader_key(workflow_id), sid, ex=SESSION_STATE_TTL_SECONDS) + + def delete_leader(self, workflow_id: str) -> None: + self._redis.delete(self.leader_key(workflow_id)) + + def expire_leader(self, workflow_id: str) -> None: + self._redis.expire(self.leader_key(workflow_id), SESSION_STATE_TTL_SECONDS) diff --git a/api/schedule/clean_workflow_runlogs_precise.py b/api/schedule/clean_workflow_runlogs_precise.py index ebb8d529244..c5762fcdad3 100644 --- a/api/schedule/clean_workflow_runlogs_precise.py +++ b/api/schedule/clean_workflow_runlogs_precise.py @@ -4,6 +4,7 @@ import time from collections.abc import Sequence import click +from sqlalchemy import delete, select from sqlalchemy.orm import Session, sessionmaker import app @@ -113,11 +114,9 @@ def _delete_batch( try: with session.begin_nested(): workflow_run_ids = [run.id for run in workflow_runs] - message_data = ( - session.query(Message.id, Message.conversation_id) - .where(Message.workflow_run_id.in_(workflow_run_ids)) - .all() - ) + message_data = session.execute( + select(Message.id, Message.conversation_id).where(Message.workflow_run_id.in_(workflow_run_ids)) + ).all() message_id_list = [msg.id for msg in message_data] conversation_id_list = list({msg.conversation_id for msg in message_data if msg.conversation_id}) if message_id_list: @@ -132,23 +131,19 @@ def _delete_batch( SavedMessage, ] for model in message_related_models: - session.query(model).where(model.message_id.in_(message_id_list)).delete(synchronize_session=False) # type: ignore + session.execute(delete(model).where(model.message_id.in_(message_id_list))) # type: ignore # error: "DeclarativeAttributeIntercept" has no attribute "message_id". But this type is only in lib # and these 6 types all have the message_id field. - session.query(Message).where(Message.workflow_run_id.in_(workflow_run_ids)).delete( - synchronize_session=False - ) + session.execute(delete(Message).where(Message.workflow_run_id.in_(workflow_run_ids))) if conversation_id_list: - session.query(ConversationVariable).where( - ConversationVariable.conversation_id.in_(conversation_id_list) - ).delete(synchronize_session=False) - - session.query(Conversation).where(Conversation.id.in_(conversation_id_list)).delete( - synchronize_session=False + session.execute( + delete(ConversationVariable).where(ConversationVariable.conversation_id.in_(conversation_id_list)) ) + session.execute(delete(Conversation).where(Conversation.id.in_(conversation_id_list))) + def _delete_node_executions(active_session: Session, runs: Sequence[WorkflowRun]) -> tuple[int, int]: run_ids = [run.id for run in runs] repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository( diff --git a/api/schedule/create_tidb_serverless_task.py b/api/schedule/create_tidb_serverless_task.py index 6ceb3ef856a..e242b0c667a 100644 --- a/api/schedule/create_tidb_serverless_task.py +++ b/api/schedule/create_tidb_serverless_task.py @@ -1,11 +1,11 @@ import time import click +from dify_vdb_tidb_on_qdrant.tidb_service import TidbService from sqlalchemy import func, select import app from configs import dify_config -from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService from extensions.ext_database import db from models.dataset import TidbAuthBinding from models.enums import TidbAuthBindingStatus @@ -57,6 +57,7 @@ def create_clusters(batch_size): cluster_name=new_cluster["cluster_name"], account=new_cluster["account"], password=new_cluster["password"], + qdrant_endpoint=new_cluster.get("qdrant_endpoint"), active=False, status=TidbAuthBindingStatus.CREATING, ) diff --git a/api/schedule/mail_clean_document_notify_task.py b/api/schedule/mail_clean_document_notify_task.py index 8479cdfb0c8..2cc0192a4a1 100644 --- a/api/schedule/mail_clean_document_notify_task.py +++ b/api/schedule/mail_clean_document_notify_task.py @@ -7,8 +7,8 @@ from sqlalchemy import select import app from configs import dify_config +from core.db.session_factory import session_factory from enums.cloud_plan import CloudPlan -from extensions.ext_database import db from extensions.ext_mail import mail from libs.email_i18n import EmailType, get_email_i18n_service from models import Account, Tenant, TenantAccountJoin @@ -33,67 +33,68 @@ def mail_clean_document_notify_task(): # send document clean notify mail try: - dataset_auto_disable_logs = db.session.scalars( - select(DatasetAutoDisableLog).where(DatasetAutoDisableLog.notified == False) - ).all() - # group by tenant_id - dataset_auto_disable_logs_map: dict[str, list[DatasetAutoDisableLog]] = defaultdict(list) - for dataset_auto_disable_log in dataset_auto_disable_logs: - if dataset_auto_disable_log.tenant_id not in dataset_auto_disable_logs_map: - dataset_auto_disable_logs_map[dataset_auto_disable_log.tenant_id] = [] - dataset_auto_disable_logs_map[dataset_auto_disable_log.tenant_id].append(dataset_auto_disable_log) - url = f"{dify_config.CONSOLE_WEB_URL}/datasets" - for tenant_id, tenant_dataset_auto_disable_logs in dataset_auto_disable_logs_map.items(): - features = FeatureService.get_features(tenant_id) - plan = features.billing.subscription.plan - if plan != CloudPlan.SANDBOX: - knowledge_details = [] - # check tenant - tenant = db.session.scalar(select(Tenant).where(Tenant.id == tenant_id)) - if not tenant: - continue - # check current owner - current_owner_join = db.session.scalar( - select(TenantAccountJoin) - .where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.role == "owner") - .limit(1) - ) - if not current_owner_join: - continue - account = db.session.scalar(select(Account).where(Account.id == current_owner_join.account_id)) - if not account: - continue + with session_factory.create_session() as session: + dataset_auto_disable_logs = session.scalars( + select(DatasetAutoDisableLog).where(DatasetAutoDisableLog.notified.is_(False)) + ).all() + # group by tenant_id + dataset_auto_disable_logs_map: dict[str, list[DatasetAutoDisableLog]] = defaultdict(list) + for dataset_auto_disable_log in dataset_auto_disable_logs: + if dataset_auto_disable_log.tenant_id not in dataset_auto_disable_logs_map: + dataset_auto_disable_logs_map[dataset_auto_disable_log.tenant_id] = [] + dataset_auto_disable_logs_map[dataset_auto_disable_log.tenant_id].append(dataset_auto_disable_log) + url = f"{dify_config.CONSOLE_WEB_URL}/datasets" + for tenant_id, tenant_dataset_auto_disable_logs in dataset_auto_disable_logs_map.items(): + features = FeatureService.get_features(tenant_id) + plan = features.billing.subscription.plan + if plan != CloudPlan.SANDBOX: + knowledge_details = [] + # check tenant + tenant = session.scalar(select(Tenant).where(Tenant.id == tenant_id)) + if not tenant: + continue + # check current owner + current_owner_join = session.scalar( + select(TenantAccountJoin) + .where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.role == "owner") + .limit(1) + ) + if not current_owner_join: + continue + account = session.scalar(select(Account).where(Account.id == current_owner_join.account_id)) + if not account: + continue - dataset_auto_dataset_map = {} # type: ignore + dataset_auto_dataset_map = {} # type: ignore + for dataset_auto_disable_log in tenant_dataset_auto_disable_logs: + if dataset_auto_disable_log.dataset_id not in dataset_auto_dataset_map: + dataset_auto_dataset_map[dataset_auto_disable_log.dataset_id] = [] + dataset_auto_dataset_map[dataset_auto_disable_log.dataset_id].append( + dataset_auto_disable_log.document_id + ) + + for dataset_id, document_ids in dataset_auto_dataset_map.items(): + dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id)) + if dataset: + document_count = len(document_ids) + knowledge_details.append(rf"Knowledge base {dataset.name}: {document_count} documents") + if knowledge_details: + email_service = get_email_i18n_service() + email_service.send_email( + email_type=EmailType.DOCUMENT_CLEAN_NOTIFY, + language_code="en-US", + to=account.email, + template_context={ + "userName": account.email, + "knowledge_details": knowledge_details, + "url": url, + }, + ) + + # update notified to True for dataset_auto_disable_log in tenant_dataset_auto_disable_logs: - if dataset_auto_disable_log.dataset_id not in dataset_auto_dataset_map: - dataset_auto_dataset_map[dataset_auto_disable_log.dataset_id] = [] - dataset_auto_dataset_map[dataset_auto_disable_log.dataset_id].append( - dataset_auto_disable_log.document_id - ) - - for dataset_id, document_ids in dataset_auto_dataset_map.items(): - dataset = db.session.scalar(select(Dataset).where(Dataset.id == dataset_id)) - if dataset: - document_count = len(document_ids) - knowledge_details.append(rf"Knowledge base {dataset.name}: {document_count} documents") - if knowledge_details: - email_service = get_email_i18n_service() - email_service.send_email( - email_type=EmailType.DOCUMENT_CLEAN_NOTIFY, - language_code="en-US", - to=account.email, - template_context={ - "userName": account.email, - "knowledge_details": knowledge_details, - "url": url, - }, - ) - - # update notified to True - for dataset_auto_disable_log in tenant_dataset_auto_disable_logs: - dataset_auto_disable_log.notified = True - db.session.commit() + dataset_auto_disable_log.notified = True + session.commit() end_at = time.perf_counter() logger.info(click.style(f"Send document clean notify mail succeeded: latency: {end_at - start_at}", fg="green")) except Exception: diff --git a/api/schedule/update_tidb_serverless_status_task.py b/api/schedule/update_tidb_serverless_status_task.py index 10003b1b975..46d1b85aa0c 100644 --- a/api/schedule/update_tidb_serverless_status_task.py +++ b/api/schedule/update_tidb_serverless_status_task.py @@ -2,11 +2,11 @@ import time from collections.abc import Sequence import click +from dify_vdb_tidb_on_qdrant.tidb_service import TidbService from sqlalchemy import select import app from configs import dify_config -from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService from extensions.ext_database import db from models.dataset import TidbAuthBinding from models.enums import TidbAuthBindingStatus diff --git a/api/services/account_service.py b/api/services/account_service.py index 29b14447305..b6554a3de7a 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -8,8 +8,9 @@ from hashlib import sha256 from typing import Any, TypedDict, cast from pydantic import BaseModel, TypeAdapter -from sqlalchemy import func, select -from sqlalchemy.orm import Session +from sqlalchemy import delete, func, select, update + +from core.db.session_factory import session_factory class InvitationData(TypedDict): @@ -83,6 +84,12 @@ from tasks.mail_reset_password_task import ( logger = logging.getLogger(__name__) +class InvitationDetailDict(TypedDict): + account: Account + data: InvitationData + tenant: Tenant + + def _try_join_enterprise_default_workspace(account_id: str) -> None: """Best-effort join to enterprise default workspace.""" if not dify_config.ENTERPRISE_ENABLED: @@ -105,6 +112,14 @@ REFRESH_TOKEN_EXPIRY = timedelta(days=dify_config.REFRESH_TOKEN_EXPIRE_DAYS) class AccountService: + # Phase-bound token metadata for the change-email flow. Tokens carry the + # current phase so that downstream endpoints can enforce proper progression + CHANGE_EMAIL_TOKEN_PHASE_KEY = "email_change_phase" + CHANGE_EMAIL_PHASE_OLD = "old_email" + CHANGE_EMAIL_PHASE_OLD_VERIFIED = "old_email_verified" + CHANGE_EMAIL_PHASE_NEW = "new_email" + CHANGE_EMAIL_PHASE_NEW_VERIFIED = "new_email_verified" + reset_password_rate_limiter = RateLimiter(prefix="reset_password_rate_limit", max_attempts=1, time_window=60 * 1) email_register_rate_limiter = RateLimiter(prefix="email_register_rate_limit", max_attempts=1, time_window=60 * 1) email_code_login_rate_limiter = RateLimiter( @@ -144,22 +159,26 @@ class AccountService: @staticmethod def load_user(user_id: str) -> None | Account: - account = db.session.query(Account).filter_by(id=user_id).first() + account = db.session.get(Account, user_id) if not account: return None if account.status == AccountStatus.BANNED: raise Unauthorized("Account is banned.") - current_tenant = db.session.query(TenantAccountJoin).filter_by(account_id=account.id, current=True).first() + current_tenant = db.session.scalar( + select(TenantAccountJoin) + .where(TenantAccountJoin.account_id == account.id, TenantAccountJoin.current == True) + .limit(1) + ) if current_tenant: account.set_tenant_id(current_tenant.tenant_id) else: - available_ta = ( - db.session.query(TenantAccountJoin) - .filter_by(account_id=account.id) + available_ta = db.session.scalar( + select(TenantAccountJoin) + .where(TenantAccountJoin.account_id == account.id) .order_by(TenantAccountJoin.id.asc()) - .first() + .limit(1) ) if not available_ta: return None @@ -195,7 +214,7 @@ class AccountService: def authenticate(email: str, password: str, invite_token: str | None = None) -> Account: """authenticate account with email and password""" - account = db.session.query(Account).filter_by(email=email).first() + account = db.session.scalar(select(Account).where(Account.email == email).limit(1)) if not account: raise AccountPasswordError("Invalid email or password.") @@ -371,8 +390,10 @@ class AccountService: """Link account integrate""" try: # Query whether there is an existing binding record for the same provider - account_integrate: AccountIntegrate | None = ( - db.session.query(AccountIntegrate).filter_by(account_id=account.id, provider=provider).first() + account_integrate: AccountIntegrate | None = db.session.scalar( + select(AccountIntegrate) + .where(AccountIntegrate.account_id == account.id, AccountIntegrate.provider == provider) + .limit(1) ) if account_integrate: @@ -416,7 +437,9 @@ class AccountService: def update_account_email(account: Account, email: str) -> Account: """Update account email""" account.email = email - account_integrate = db.session.query(AccountIntegrate).filter_by(account_id=account.id).first() + account_integrate = db.session.scalar( + select(AccountIntegrate).where(AccountIntegrate.account_id == account.id).limit(1) + ) if account_integrate: db.session.delete(account_integrate) db.session.add(account) @@ -561,13 +584,20 @@ class AccountService: raise ValueError("Email must be provided.") if not phase: raise ValueError("phase must be provided.") + if phase not in (cls.CHANGE_EMAIL_PHASE_OLD, cls.CHANGE_EMAIL_PHASE_NEW): + raise ValueError("phase must be one of old_email or new_email.") if cls.change_email_rate_limiter.is_rate_limited(account_email): from controllers.console.auth.error import EmailChangeRateLimitExceededError raise EmailChangeRateLimitExceededError(int(cls.change_email_rate_limiter.time_window / 60)) - code, token = cls.generate_change_email_token(account_email, account, old_email=old_email) + code, token = cls.generate_change_email_token( + account_email, + account, + old_email=old_email, + additional_data={cls.CHANGE_EMAIL_TOKEN_PHASE_KEY: phase}, + ) send_change_mail_task.delay( language=language, @@ -786,19 +816,19 @@ class AccountService: return token @staticmethod - def get_account_by_email_with_case_fallback(email: str, session: Session | None = None) -> Account | None: + def get_account_by_email_with_case_fallback(email: str) -> Account | None: """ Retrieve an account by email and fall back to the lowercase email if the original lookup fails. This keeps backward compatibility for older records that stored uppercase emails while the rest of the system gradually normalizes new inputs. """ - query_session = session or db.session - account = query_session.execute(select(Account).filter_by(email=email)).scalar_one_or_none() - if account or email == email.lower(): - return account + with session_factory.create_session() as session: + account = session.execute(select(Account).where(Account.email == email)).scalar_one_or_none() + if account or email == email.lower(): + return account - return query_session.execute(select(Account).filter_by(email=email.lower())).scalar_one_or_none() + return session.execute(select(Account).where(Account.email == email.lower())).scalar_one_or_none() @classmethod def get_email_code_login_data(cls, token: str) -> dict[str, Any] | None: @@ -818,7 +848,7 @@ class AccountService: ) ) - account = db.session.query(Account).where(Account.email == email).first() + account = db.session.scalar(select(Account).where(Account.email == email).limit(1)) if not account: return None @@ -1018,7 +1048,7 @@ class AccountService: @staticmethod def check_email_unique(email: str) -> bool: - return db.session.query(Account).filter_by(email=email).first() is None + return db.session.scalar(select(Account).where(Account.email == email).limit(1)) is None class TenantService: @@ -1061,11 +1091,11 @@ class TenantService: @staticmethod def create_owner_tenant_if_not_exist(account: Account, name: str | None = None, is_setup: bool | None = False): """Check if user have a workspace or not""" - available_ta = ( - db.session.query(TenantAccountJoin) - .filter_by(account_id=account.id) + available_ta = db.session.scalar( + select(TenantAccountJoin) + .where(TenantAccountJoin.account_id == account.id) .order_by(TenantAccountJoin.id.asc()) - .first() + .limit(1) ) if available_ta: @@ -1096,7 +1126,11 @@ class TenantService: logger.error("Tenant %s has already an owner.", tenant.id) raise Exception("Tenant already has an owner.") - ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first() + ta = db.session.scalar( + select(TenantAccountJoin) + .where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == account.id) + .limit(1) + ) if ta: ta.role = TenantAccountRole(role) else: @@ -1111,11 +1145,12 @@ class TenantService: @staticmethod def get_join_tenants(account: Account) -> list[Tenant]: """Get account join tenants""" - return ( - db.session.query(Tenant) - .join(TenantAccountJoin, Tenant.id == TenantAccountJoin.tenant_id) - .where(TenantAccountJoin.account_id == account.id, Tenant.status == TenantStatus.NORMAL) - .all() + return list( + db.session.scalars( + select(Tenant) + .join(TenantAccountJoin, Tenant.id == TenantAccountJoin.tenant_id) + .where(TenantAccountJoin.account_id == account.id, Tenant.status == TenantStatus.NORMAL) + ).all() ) @staticmethod @@ -1125,7 +1160,11 @@ class TenantService: if not tenant: raise TenantNotFoundError("Tenant not found.") - ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first() + ta = db.session.scalar( + select(TenantAccountJoin) + .where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == account.id) + .limit(1) + ) if ta: tenant.role = ta.role else: @@ -1140,23 +1179,25 @@ class TenantService: if tenant_id is None: raise ValueError("Tenant ID must be provided.") - tenant_account_join = ( - db.session.query(TenantAccountJoin) + tenant_account_join = db.session.scalar( + select(TenantAccountJoin) .join(Tenant, TenantAccountJoin.tenant_id == Tenant.id) .where( TenantAccountJoin.account_id == account.id, TenantAccountJoin.tenant_id == tenant_id, Tenant.status == TenantStatus.NORMAL, ) - .first() + .limit(1) ) if not tenant_account_join: raise AccountNotLinkTenantError("Tenant not found or account is not a member of the tenant.") else: - db.session.query(TenantAccountJoin).where( - TenantAccountJoin.account_id == account.id, TenantAccountJoin.tenant_id != tenant_id - ).update({"current": False}) + db.session.execute( + update(TenantAccountJoin) + .where(TenantAccountJoin.account_id == account.id, TenantAccountJoin.tenant_id != tenant_id) + .values(current=False) + ) tenant_account_join.current = True # Set the current tenant for the account account.set_tenant_id(tenant_account_join.tenant_id) @@ -1165,8 +1206,8 @@ class TenantService: @staticmethod def get_tenant_members(tenant: Tenant) -> list[Account]: """Get tenant members""" - query = ( - db.session.query(Account, TenantAccountJoin.role) + stmt = ( + select(Account, TenantAccountJoin.role) .select_from(Account) .join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id) .where(TenantAccountJoin.tenant_id == tenant.id) @@ -1175,7 +1216,7 @@ class TenantService: # Initialize an empty list to store the updated accounts updated_accounts = [] - for account, role in query: + for account, role in db.session.execute(stmt): account.role = role updated_accounts.append(account) @@ -1184,8 +1225,8 @@ class TenantService: @staticmethod def get_dataset_operator_members(tenant: Tenant) -> list[Account]: """Get dataset admin members""" - query = ( - db.session.query(Account, TenantAccountJoin.role) + stmt = ( + select(Account, TenantAccountJoin.role) .select_from(Account) .join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id) .where(TenantAccountJoin.tenant_id == tenant.id) @@ -1195,7 +1236,7 @@ class TenantService: # Initialize an empty list to store the updated accounts updated_accounts = [] - for account, role in query: + for account, role in db.session.execute(stmt): account.role = role updated_accounts.append(account) @@ -1208,26 +1249,31 @@ class TenantService: raise ValueError("all roles must be TenantAccountRole") return ( - db.session.query(TenantAccountJoin) - .where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.role.in_([role.value for role in roles])) - .first() + db.session.scalar( + select(TenantAccountJoin) + .where( + TenantAccountJoin.tenant_id == tenant.id, + TenantAccountJoin.role.in_([role.value for role in roles]), + ) + .limit(1) + ) is not None ) @staticmethod def get_user_role(account: Account, tenant: Tenant) -> TenantAccountRole | None: """Get the role of the current account for a given tenant""" - join = ( - db.session.query(TenantAccountJoin) + join = db.session.scalar( + select(TenantAccountJoin) .where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == account.id) - .first() + .limit(1) ) return TenantAccountRole(join.role) if join else None @staticmethod def get_tenant_count() -> int: """Get tenant count""" - return cast(int, db.session.query(func.count(Tenant.id)).scalar()) + return cast(int, db.session.scalar(select(func.count(Tenant.id)))) @staticmethod def check_member_permission(tenant: Tenant, operator: Account, member: Account | None, action: str): @@ -1244,7 +1290,11 @@ class TenantService: if operator.id == member.id: raise CannotOperateSelfError("Cannot operate self.") - ta_operator = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=operator.id).first() + ta_operator = db.session.scalar( + select(TenantAccountJoin) + .where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == operator.id) + .limit(1) + ) if not ta_operator or ta_operator.role not in perms[action]: raise NoPermissionError(f"No permission to {action} member.") @@ -1262,7 +1312,11 @@ class TenantService: TenantService.check_member_permission(tenant, operator, account, "remove") - ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first() + ta = db.session.scalar( + select(TenantAccountJoin) + .where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == account.id) + .limit(1) + ) if not ta: raise MemberNotInTenantError("Member not in tenant.") @@ -1277,7 +1331,12 @@ class TenantService: should_delete_account = False if account.status == AccountStatus.PENDING: # autoflush flushes ta deletion before this query, so 0 means no remaining joins - remaining_joins = db.session.query(TenantAccountJoin).filter_by(account_id=account_id).count() + remaining_joins = ( + db.session.scalar( + select(func.count(TenantAccountJoin.id)).where(TenantAccountJoin.account_id == account_id) + ) + or 0 + ) if remaining_joins == 0: db.session.delete(account) should_delete_account = True @@ -1312,8 +1371,10 @@ class TenantService: """Update member role""" TenantService.check_member_permission(tenant, operator, member, "update") - target_member_join = ( - db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=member.id).first() + target_member_join = db.session.scalar( + select(TenantAccountJoin) + .where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == member.id) + .limit(1) ) if not target_member_join: @@ -1324,8 +1385,10 @@ class TenantService: if new_role == "owner": # Find the current owner and change their role to 'admin' - current_owner_join = ( - db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, role="owner").first() + current_owner_join = db.session.scalar( + select(TenantAccountJoin) + .where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.role == "owner") + .limit(1) ) if current_owner_join: current_owner_join.role = TenantAccountRole.ADMIN @@ -1384,10 +1447,10 @@ class RegisterService: db.session.add(dify_setup) db.session.commit() except Exception as e: - db.session.query(DifySetup).delete() - db.session.query(TenantAccountJoin).delete() - db.session.query(Account).delete() - db.session.query(Tenant).delete() + db.session.execute(delete(DifySetup)) + db.session.execute(delete(TenantAccountJoin)) + db.session.execute(delete(Account)) + db.session.execute(delete(Tenant)) db.session.commit() logger.exception("Setup account failed, email: %s, name: %s", email, name) @@ -1469,8 +1532,7 @@ class RegisterService: check_workspace_member_invite_permission(tenant.id) - with Session(db.engine) as session: - account = AccountService.get_account_by_email_with_case_fallback(email, session=session) + account = AccountService.get_account_by_email_with_case_fallback(email) if not account: TenantService.check_member_permission(tenant, inviter, None, "add") @@ -1488,7 +1550,11 @@ class RegisterService: TenantService.switch_tenant(account, tenant.id) else: TenantService.check_member_permission(tenant, inviter, account, "add") - ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first() + ta = db.session.scalar( + select(TenantAccountJoin) + .where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == account.id) + .limit(1) + ) if not ta: TenantService.create_tenant_member(tenant, account, role) @@ -1540,26 +1606,23 @@ class RegisterService: @classmethod def get_invitation_if_token_valid( cls, workspace_id: str | None, email: str | None, token: str - ) -> dict[str, Any] | None: + ) -> InvitationDetailDict | None: invitation_data = cls.get_invitation_by_token(token, workspace_id, email) if not invitation_data: return None - tenant = ( - db.session.query(Tenant) - .where(Tenant.id == invitation_data["workspace_id"], Tenant.status == "normal") - .first() + tenant = db.session.scalar( + select(Tenant).where(Tenant.id == invitation_data["workspace_id"], Tenant.status == "normal").limit(1) ) if not tenant: return None - tenant_account = ( - db.session.query(Account, TenantAccountJoin.role) + tenant_account = db.session.execute( + select(Account, TenantAccountJoin.role) .join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id) .where(Account.email == invitation_data["email"], TenantAccountJoin.tenant_id == tenant.id) - .first() - ) + ).first() if not tenant_account: return None @@ -1605,7 +1668,7 @@ class RegisterService: @classmethod def get_invitation_with_case_fallback( cls, workspace_id: str | None, email: str | None, token: str - ) -> dict[str, Any] | None: + ) -> InvitationDetailDict | None: invitation = cls.get_invitation_if_token_valid(workspace_id, email, token) if invitation or not email or email == email.lower(): return invitation diff --git a/api/services/advanced_prompt_template_service.py b/api/services/advanced_prompt_template_service.py index f2ffa3b170c..5d136e73930 100644 --- a/api/services/advanced_prompt_template_service.py +++ b/api/services/advanced_prompt_template_service.py @@ -1,4 +1,5 @@ import copy +from typing import Any, TypedDict from core.prompt.prompt_templates.advanced_prompt_templates import ( BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG, @@ -15,9 +16,18 @@ from core.prompt.prompt_templates.advanced_prompt_templates import ( from models.model import AppMode +class AdvancedPromptTemplateArgs(TypedDict): + """Expected shape of the args dict passed to AdvancedPromptTemplateService.get_prompt.""" + + app_mode: str + model_mode: str + model_name: str + has_context: str + + class AdvancedPromptTemplateService: @classmethod - def get_prompt(cls, args: dict): + def get_prompt(cls, args: AdvancedPromptTemplateArgs) -> dict[str, Any]: app_mode = args["app_mode"] model_mode = args["model_mode"] model_name = args["model_name"] @@ -29,30 +39,41 @@ class AdvancedPromptTemplateService: return cls.get_common_prompt(app_mode, model_mode, has_context) @classmethod - def get_common_prompt(cls, app_mode: str, model_mode: str, has_context: str): + def get_common_prompt(cls, app_mode: str, model_mode: str, has_context: str) -> dict[str, Any]: context_prompt = copy.deepcopy(CONTEXT) - if app_mode == AppMode.CHAT: - if model_mode == "completion": - return cls.get_completion_prompt( - copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt - ) - elif model_mode == "chat": - return cls.get_chat_prompt(copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt) - elif app_mode == AppMode.COMPLETION: - if model_mode == "completion": - return cls.get_completion_prompt( - copy.deepcopy(COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt - ) - elif model_mode == "chat": - return cls.get_chat_prompt( - copy.deepcopy(COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt - ) + match app_mode: + case AppMode.CHAT: + match model_mode: + case "completion": + return cls.get_completion_prompt( + copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt + ) + case "chat": + return cls.get_chat_prompt( + copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt + ) + case _: + pass + case AppMode.COMPLETION: + match model_mode: + case "completion": + return cls.get_completion_prompt( + copy.deepcopy(COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt + ) + case "chat": + return cls.get_chat_prompt( + copy.deepcopy(COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt + ) + case _: + pass + case _: + pass # default return empty dict return {} @classmethod - def get_completion_prompt(cls, prompt_template: dict, has_context: str, context: str): + def get_completion_prompt(cls, prompt_template: dict[str, Any], has_context: str, context: str) -> dict[str, Any]: if has_context == "true": prompt_template["completion_prompt_config"]["prompt"]["text"] = ( context + prompt_template["completion_prompt_config"]["prompt"]["text"] @@ -61,7 +82,7 @@ class AdvancedPromptTemplateService: return prompt_template @classmethod - def get_chat_prompt(cls, prompt_template: dict, has_context: str, context: str): + def get_chat_prompt(cls, prompt_template: dict[str, Any], has_context: str, context: str) -> dict[str, Any]: if has_context == "true": prompt_template["chat_prompt_config"]["prompt"][0]["text"] = ( context + prompt_template["chat_prompt_config"]["prompt"][0]["text"] @@ -70,28 +91,41 @@ class AdvancedPromptTemplateService: return prompt_template @classmethod - def get_baichuan_prompt(cls, app_mode: str, model_mode: str, has_context: str): + def get_baichuan_prompt(cls, app_mode: str, model_mode: str, has_context: str) -> dict[str, Any]: baichuan_context_prompt = copy.deepcopy(BAICHUAN_CONTEXT) - if app_mode == AppMode.CHAT: - if model_mode == "completion": - return cls.get_completion_prompt( - copy.deepcopy(BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, baichuan_context_prompt - ) - elif model_mode == "chat": - return cls.get_chat_prompt( - copy.deepcopy(BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt - ) - elif app_mode == AppMode.COMPLETION: - if model_mode == "completion": - return cls.get_completion_prompt( - copy.deepcopy(BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG), - has_context, - baichuan_context_prompt, - ) - elif model_mode == "chat": - return cls.get_chat_prompt( - copy.deepcopy(BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt - ) + match app_mode: + case AppMode.CHAT: + match model_mode: + case "completion": + return cls.get_completion_prompt( + copy.deepcopy(BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG), + has_context, + baichuan_context_prompt, + ) + case "chat": + return cls.get_chat_prompt( + copy.deepcopy(BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt + ) + case _: + pass + case AppMode.COMPLETION: + match model_mode: + case "completion": + return cls.get_completion_prompt( + copy.deepcopy(BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG), + has_context, + baichuan_context_prompt, + ) + case "chat": + return cls.get_chat_prompt( + copy.deepcopy(BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG), + has_context, + baichuan_context_prompt, + ) + case _: + pass + case _: + pass # default return empty dict return {} diff --git a/api/services/annotation_service.py b/api/services/annotation_service.py index 8ebc87a6708..ff0882ad5c6 100644 --- a/api/services/annotation_service.py +++ b/api/services/annotation_service.py @@ -1,10 +1,9 @@ import logging import uuid +from typing import TypedDict import pandas as pd - -logger = logging.getLogger(__name__) -from sqlalchemy import or_, select +from sqlalchemy import delete, or_, select, update from werkzeug.datastructures import FileStorage from werkzeug.exceptions import NotFound @@ -22,16 +21,78 @@ from tasks.annotation.disable_annotation_reply_task import disable_annotation_re from tasks.annotation.enable_annotation_reply_task import enable_annotation_reply_task from tasks.annotation.update_annotation_to_index_task import update_annotation_to_index_task +logger = logging.getLogger(__name__) + + +class AnnotationJobStatusDict(TypedDict): + job_id: str + job_status: str + + +class EmbeddingModelDict(TypedDict): + embedding_provider_name: str + embedding_model_name: str + + +class AnnotationSettingDict(TypedDict): + id: str + enabled: bool + score_threshold: float + embedding_model: EmbeddingModelDict | dict + + +class AnnotationSettingDisabledDict(TypedDict): + enabled: bool + + +class EnableAnnotationArgs(TypedDict): + """Expected shape of the args dict passed to enable_app_annotation.""" + + score_threshold: float + embedding_provider_name: str + embedding_model_name: str + + +class UpsertAnnotationArgs(TypedDict, total=False): + """Expected shape of the args dict passed to up_insert_app_annotation_from_message.""" + + answer: str + content: str + message_id: str + question: str + + +class InsertAnnotationArgs(TypedDict): + """Expected shape of the args dict passed to insert_app_annotation_directly.""" + + question: str + answer: str + + +class UpdateAnnotationArgs(TypedDict, total=False): + """Expected shape of the args dict passed to update_app_annotation_directly. + + Both fields are optional at the type level; the service validates at runtime + and raises ValueError if either is missing. + """ + + answer: str + question: str + + +class UpdateAnnotationSettingArgs(TypedDict): + """Expected shape of the args dict passed to update_app_annotation_setting.""" + + score_threshold: float + class AppAnnotationService: @classmethod - def up_insert_app_annotation_from_message(cls, args: dict, app_id: str) -> MessageAnnotation: + def up_insert_app_annotation_from_message(cls, args: UpsertAnnotationArgs, app_id: str) -> MessageAnnotation: # get app info current_user, current_tenant_id = current_account_with_tenant() - app = ( - db.session.query(App) - .where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal") - .first() + app = db.session.scalar( + select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1) ) if not app: @@ -41,9 +102,12 @@ class AppAnnotationService: if answer is None: raise ValueError("Either 'answer' or 'content' must be provided") - if args.get("message_id"): - message_id = str(args["message_id"]) - message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app.id).first() + raw_message_id = args.get("message_id") + if raw_message_id: + message_id = str(raw_message_id) + message = db.session.scalar( + select(Message).where(Message.id == message_id, Message.app_id == app.id).limit(1) + ) if not message: raise NotFound("Message Not Exists.") @@ -64,15 +128,18 @@ class AppAnnotationService: account_id=current_user.id, ) else: - question = args.get("question") - if not question: + maybe_question = args.get("question") + if not maybe_question: raise ValueError("'question' is required when 'message_id' is not provided") + question = maybe_question annotation = MessageAnnotation(app_id=app.id, content=answer, question=question, account_id=current_user.id) db.session.add(annotation) db.session.commit() - annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first() + annotation_setting = db.session.scalar( + select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).limit(1) + ) assert current_tenant_id is not None if annotation_setting: add_annotation_to_index_task.delay( @@ -85,7 +152,7 @@ class AppAnnotationService: return annotation @classmethod - def enable_app_annotation(cls, args: dict, app_id: str): + def enable_app_annotation(cls, args: EnableAnnotationArgs, app_id: str) -> AnnotationJobStatusDict: enable_app_annotation_key = f"enable_app_annotation_{str(app_id)}" cache_result = redis_client.get(enable_app_annotation_key) if cache_result is not None: @@ -109,7 +176,7 @@ class AppAnnotationService: return {"job_id": job_id, "job_status": "waiting"} @classmethod - def disable_app_annotation(cls, app_id: str): + def disable_app_annotation(cls, app_id: str) -> AnnotationJobStatusDict: _, current_tenant_id = current_account_with_tenant() disable_app_annotation_key = f"disable_app_annotation_{str(app_id)}" cache_result = redis_client.get(disable_app_annotation_key) @@ -128,10 +195,8 @@ class AppAnnotationService: def get_annotation_list_by_app_id(cls, app_id: str, page: int, limit: int, keyword: str): # get app info _, current_tenant_id = current_account_with_tenant() - app = ( - db.session.query(App) - .where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal") - .first() + app = db.session.scalar( + select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1) ) if not app: @@ -170,20 +235,17 @@ class AppAnnotationService: """ # get app info _, current_tenant_id = current_account_with_tenant() - app = ( - db.session.query(App) - .where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal") - .first() + app = db.session.scalar( + select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1) ) if not app: raise NotFound("App not found") - annotations = ( - db.session.query(MessageAnnotation) + annotations = db.session.scalars( + select(MessageAnnotation) .where(MessageAnnotation.app_id == app_id) .order_by(MessageAnnotation.created_at.desc()) - .all() - ) + ).all() # Sanitize CSV-injectable fields to prevent formula injection for annotation in annotations: @@ -197,13 +259,11 @@ class AppAnnotationService: return annotations @classmethod - def insert_app_annotation_directly(cls, args: dict, app_id: str) -> MessageAnnotation: + def insert_app_annotation_directly(cls, args: InsertAnnotationArgs, app_id: str) -> MessageAnnotation: # get app info current_user, current_tenant_id = current_account_with_tenant() - app = ( - db.session.query(App) - .where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal") - .first() + app = db.session.scalar( + select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1) ) if not app: @@ -219,7 +279,9 @@ class AppAnnotationService: db.session.add(annotation) db.session.commit() # if annotation reply is enabled , add annotation to index - annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first() + annotation_setting = db.session.scalar( + select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).limit(1) + ) if annotation_setting: add_annotation_to_index_task.delay( annotation.id, @@ -231,19 +293,17 @@ class AppAnnotationService: return annotation @classmethod - def update_app_annotation_directly(cls, args: dict, app_id: str, annotation_id: str): + def update_app_annotation_directly(cls, args: UpdateAnnotationArgs, app_id: str, annotation_id: str): # get app info _, current_tenant_id = current_account_with_tenant() - app = ( - db.session.query(App) - .where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal") - .first() + app = db.session.scalar( + select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1) ) if not app: raise NotFound("App not found") - annotation = db.session.query(MessageAnnotation).where(MessageAnnotation.id == annotation_id).first() + annotation = db.session.get(MessageAnnotation, annotation_id) if not annotation: raise NotFound("Annotation not found") @@ -252,13 +312,17 @@ class AppAnnotationService: if question is None: raise ValueError("'question' is required") - annotation.content = args["answer"] + answer = args.get("answer") + if answer is None: + raise ValueError("'answer' is required") + + annotation.content = answer annotation.question = question db.session.commit() # if annotation reply is enabled , add annotation to index - app_annotation_setting = ( - db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first() + app_annotation_setting = db.session.scalar( + select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).limit(1) ) if app_annotation_setting: @@ -276,16 +340,14 @@ class AppAnnotationService: def delete_app_annotation(cls, app_id: str, annotation_id: str): # get app info _, current_tenant_id = current_account_with_tenant() - app = ( - db.session.query(App) - .where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal") - .first() + app = db.session.scalar( + select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1) ) if not app: raise NotFound("App not found") - annotation = db.session.query(MessageAnnotation).where(MessageAnnotation.id == annotation_id).first() + annotation = db.session.get(MessageAnnotation, annotation_id) if not annotation: raise NotFound("Annotation not found") @@ -301,8 +363,8 @@ class AppAnnotationService: db.session.commit() # if annotation reply is enabled , delete annotation index - app_annotation_setting = ( - db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first() + app_annotation_setting = db.session.scalar( + select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).limit(1) ) if app_annotation_setting: @@ -314,22 +376,19 @@ class AppAnnotationService: def delete_app_annotations_in_batch(cls, app_id: str, annotation_ids: list[str]): # get app info _, current_tenant_id = current_account_with_tenant() - app = ( - db.session.query(App) - .where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal") - .first() + app = db.session.scalar( + select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1) ) if not app: raise NotFound("App not found") # Fetch annotations and their settings in a single query - annotations_to_delete = ( - db.session.query(MessageAnnotation, AppAnnotationSetting) + annotations_to_delete = db.session.execute( + select(MessageAnnotation, AppAnnotationSetting) .outerjoin(AppAnnotationSetting, MessageAnnotation.app_id == AppAnnotationSetting.app_id) .where(MessageAnnotation.id.in_(annotation_ids)) - .all() - ) + ).all() if not annotations_to_delete: return {"deleted_count": 0} @@ -338,9 +397,9 @@ class AppAnnotationService: annotation_ids_to_delete = [annotation.id for annotation, _ in annotations_to_delete] # Step 2: Bulk delete hit histories in a single query - db.session.query(AppAnnotationHitHistory).where( - AppAnnotationHitHistory.annotation_id.in_(annotation_ids_to_delete) - ).delete(synchronize_session=False) + db.session.execute( + delete(AppAnnotationHitHistory).where(AppAnnotationHitHistory.annotation_id.in_(annotation_ids_to_delete)) + ) # Step 3: Trigger async tasks for search index deletion for annotation, annotation_setting in annotations_to_delete: @@ -350,11 +409,10 @@ class AppAnnotationService: ) # Step 4: Bulk delete annotations in a single query - deleted_count = ( - db.session.query(MessageAnnotation) - .where(MessageAnnotation.id.in_(annotation_ids_to_delete)) - .delete(synchronize_session=False) + delete_result = db.session.execute( + delete(MessageAnnotation).where(MessageAnnotation.id.in_(annotation_ids_to_delete)) ) + deleted_count = getattr(delete_result, "rowcount", 0) db.session.commit() return {"deleted_count": deleted_count} @@ -375,10 +433,8 @@ class AppAnnotationService: # get app info current_user, current_tenant_id = current_account_with_tenant() - app = ( - db.session.query(App) - .where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal") - .first() + app = db.session.scalar( + select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1) ) if not app: @@ -499,16 +555,14 @@ class AppAnnotationService: def get_annotation_hit_histories(cls, app_id: str, annotation_id: str, page, limit): _, current_tenant_id = current_account_with_tenant() # get app info - app = ( - db.session.query(App) - .where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal") - .first() + app = db.session.scalar( + select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1) ) if not app: raise NotFound("App not found") - annotation = db.session.query(MessageAnnotation).where(MessageAnnotation.id == annotation_id).first() + annotation = db.session.get(MessageAnnotation, annotation_id) if not annotation: raise NotFound("Annotation not found") @@ -528,7 +582,7 @@ class AppAnnotationService: @classmethod def get_annotation_by_id(cls, annotation_id: str) -> MessageAnnotation | None: - annotation = db.session.query(MessageAnnotation).where(MessageAnnotation.id == annotation_id).first() + annotation = db.session.get(MessageAnnotation, annotation_id) if not annotation: return None @@ -548,8 +602,10 @@ class AppAnnotationService: score: float, ): # add hit count to annotation - db.session.query(MessageAnnotation).where(MessageAnnotation.id == annotation_id).update( - {MessageAnnotation.hit_count: MessageAnnotation.hit_count + 1}, synchronize_session=False + db.session.execute( + update(MessageAnnotation) + .where(MessageAnnotation.id == annotation_id) + .values(hit_count=MessageAnnotation.hit_count + 1) ) annotation_hit_history = AppAnnotationHitHistory( @@ -567,19 +623,19 @@ class AppAnnotationService: db.session.commit() @classmethod - def get_app_annotation_setting_by_app_id(cls, app_id: str): + def get_app_annotation_setting_by_app_id(cls, app_id: str) -> AnnotationSettingDict | AnnotationSettingDisabledDict: _, current_tenant_id = current_account_with_tenant() # get app info - app = ( - db.session.query(App) - .where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal") - .first() + app = db.session.scalar( + select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1) ) if not app: raise NotFound("App not found") - annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first() + annotation_setting = db.session.scalar( + select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).limit(1) + ) if annotation_setting: collection_binding_detail = annotation_setting.collection_binding_detail if collection_binding_detail: @@ -602,25 +658,25 @@ class AppAnnotationService: return {"enabled": False} @classmethod - def update_app_annotation_setting(cls, app_id: str, annotation_setting_id: str, args: dict): + def update_app_annotation_setting( + cls, app_id: str, annotation_setting_id: str, args: UpdateAnnotationSettingArgs + ) -> AnnotationSettingDict: current_user, current_tenant_id = current_account_with_tenant() # get app info - app = ( - db.session.query(App) - .where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal") - .first() + app = db.session.scalar( + select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1) ) if not app: raise NotFound("App not found") - annotation_setting = ( - db.session.query(AppAnnotationSetting) + annotation_setting = db.session.scalar( + select(AppAnnotationSetting) .where( AppAnnotationSetting.app_id == app_id, AppAnnotationSetting.id == annotation_setting_id, ) - .first() + .limit(1) ) if not annotation_setting: raise NotFound("App annotation not found") @@ -653,26 +709,26 @@ class AppAnnotationService: @classmethod def clear_all_annotations(cls, app_id: str): _, current_tenant_id = current_account_with_tenant() - app = ( - db.session.query(App) - .where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal") - .first() + app = db.session.scalar( + select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1) ) if not app: raise NotFound("App not found") # if annotation reply is enabled, delete annotation index - app_annotation_setting = ( - db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first() + app_annotation_setting = db.session.scalar( + select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).limit(1) ) - annotations_query = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app_id) - for annotation in annotations_query.yield_per(100): - annotation_hit_histories_query = db.session.query(AppAnnotationHitHistory).where( - AppAnnotationHitHistory.annotation_id == annotation.id - ) - for annotation_hit_history in annotation_hit_histories_query.yield_per(100): + annotations_iter = db.session.scalars( + select(MessageAnnotation).where(MessageAnnotation.app_id == app_id) + ).yield_per(100) + for annotation in annotations_iter: + hit_histories_iter = db.session.scalars( + select(AppAnnotationHitHistory).where(AppAnnotationHitHistory.annotation_id == annotation.id) + ).yield_per(100) + for annotation_hit_history in hit_histories_iter: db.session.delete(annotation_hit_history) # if annotation reply is enabled, delete annotation index diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index dd73e103746..97aaea3395c 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -3,27 +3,21 @@ import hashlib import logging import uuid from collections.abc import Mapping -from enum import StrEnum -from typing import cast +from typing import Any, cast from urllib.parse import urlparse from uuid import uuid4 import yaml from Crypto.Cipher import AES from Crypto.Util.Padding import pad, unpad -from graphon.enums import BuiltinNodeTypes -from graphon.model_runtime.utils.encoders import jsonable_encoder -from graphon.nodes.llm.entities import LLMNodeData -from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeData -from graphon.nodes.question_classifier.entities import QuestionClassifierNodeData -from graphon.nodes.tool.entities import ToolNodeData from packaging import version from packaging.version import parse as parse_version -from pydantic import BaseModel, Field +from pydantic import BaseModel from sqlalchemy import select from sqlalchemy.orm import Session from configs import dify_config +from constants.dsl_version import CURRENT_APP_DSL_VERSION from core.helper import ssrf_proxy from core.plugin.entities.plugin import PluginDependency from core.trigger.constants import ( @@ -36,10 +30,17 @@ from core.workflow.nodes.trigger_schedule.trigger_schedule_node import TriggerSc from events.app_event import app_model_config_was_updated, app_was_created from extensions.ext_redis import redis_client from factories import variable_factory +from graphon.enums import BuiltinNodeTypes +from graphon.model_runtime.utils.encoders import jsonable_encoder +from graphon.nodes.llm.entities import LLMNodeData +from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeData +from graphon.nodes.question_classifier.entities import QuestionClassifierNodeData +from graphon.nodes.tool.entities import ToolNodeData from libs.datetime_utils import naive_utc_now from models import Account, App, AppMode from models.model import AppModelConfig, AppModelConfigDict, IconType from models.workflow import Workflow +from services.entities.dsl_entities import CheckDependenciesResult, ImportMode, ImportStatus from services.plugin.dependencies_analysis import DependenciesAnalysisService from services.workflow_draft_variable_service import WorkflowDraftVariableService from services.workflow_service import WorkflowService @@ -50,19 +51,7 @@ IMPORT_INFO_REDIS_KEY_PREFIX = "app_import_info:" CHECK_DEPENDENCIES_REDIS_KEY_PREFIX = "app_check_dependencies:" IMPORT_INFO_REDIS_EXPIRY = 10 * 60 # 10 minutes DSL_MAX_SIZE = 10 * 1024 * 1024 # 10MB -CURRENT_DSL_VERSION = "0.6.0" - - -class ImportMode(StrEnum): - YAML_CONTENT = "yaml-content" - YAML_URL = "yaml-url" - - -class ImportStatus(StrEnum): - COMPLETED = "completed" - COMPLETED_WITH_WARNINGS = "completed-with-warnings" - PENDING = "pending" - FAILED = "failed" +CURRENT_DSL_VERSION = CURRENT_APP_DSL_VERSION class Import(BaseModel): @@ -75,10 +64,6 @@ class Import(BaseModel): error: str = "" -class CheckDependenciesResult(BaseModel): - leaked_dependencies: list[PluginDependency] = Field(default_factory=list) - - def _check_version_compatibility(imported_version: str) -> ImportStatus: """Determine import status based on version comparison""" try: @@ -416,7 +401,7 @@ class AppDslService: self, *, app: App | None, - data: dict, + data: dict[str, Any], account: Account, name: str | None = None, description: str | None = None, @@ -471,7 +456,7 @@ class AppDslService: app.updated_by = account.id self._session.add(app) - self._session.commit() + self._session.flush() app_was_created.send(app, account=account) # save dependencies @@ -483,61 +468,67 @@ class AppDslService: ) # Initialize app based on mode - if app_mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: - workflow_data = data.get("workflow") - if not workflow_data or not isinstance(workflow_data, dict): - raise ValueError("Missing workflow data for workflow/advanced chat app") + match app_mode: + case AppMode.ADVANCED_CHAT | AppMode.WORKFLOW: + workflow_data = data.get("workflow") + if not workflow_data or not isinstance(workflow_data, dict): + raise ValueError("Missing workflow data for workflow/advanced chat app") - environment_variables_list = workflow_data.get("environment_variables", []) - environment_variables = [ - variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list - ] - conversation_variables_list = workflow_data.get("conversation_variables", []) - conversation_variables = [ - variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list - ] + environment_variables_list = workflow_data.get("environment_variables", []) + environment_variables = [ + variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list + ] + conversation_variables_list = workflow_data.get("conversation_variables", []) + conversation_variables = [ + variable_factory.build_conversation_variable_from_mapping(obj) + for obj in conversation_variables_list + ] - workflow_service = WorkflowService() - current_draft_workflow = workflow_service.get_draft_workflow(app_model=app) - if current_draft_workflow: - unique_hash = current_draft_workflow.unique_hash - else: - unique_hash = None - graph = workflow_data.get("graph", {}) - for node in graph.get("nodes", []): - if node.get("data", {}).get("type", "") == BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL: - dataset_ids = node["data"].get("dataset_ids", []) - node["data"]["dataset_ids"] = [ - decrypted_id - for dataset_id in dataset_ids - if (decrypted_id := self.decrypt_dataset_id(encrypted_data=dataset_id, tenant_id=app.tenant_id)) - ] - workflow_service.sync_draft_workflow( - app_model=app, - graph=workflow_data.get("graph", {}), - features=workflow_data.get("features", {}), - unique_hash=unique_hash, - account=account, - environment_variables=environment_variables, - conversation_variables=conversation_variables, - ) - elif app_mode in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.COMPLETION}: - # Initialize model config - model_config = data.get("model_config") - if not model_config or not isinstance(model_config, dict): - raise ValueError("Missing model_config for chat/agent-chat/completion app") - # Initialize or update model config - if not app.app_model_config: - app_model_config = AppModelConfig( - app_id=app.id, created_by=account.id, updated_by=account.id - ).from_model_config_dict(cast(AppModelConfigDict, model_config)) - app_model_config.id = str(uuid4()) - app.app_model_config_id = app_model_config.id + workflow_service = WorkflowService() + current_draft_workflow = workflow_service.get_draft_workflow(app_model=app) + if current_draft_workflow: + unique_hash = current_draft_workflow.unique_hash + else: + unique_hash = None + graph = workflow_data.get("graph", {}) + for node in graph.get("nodes", []): + if node.get("data", {}).get("type", "") == BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL: + dataset_ids = node["data"].get("dataset_ids", []) + node["data"]["dataset_ids"] = [ + decrypted_id + for dataset_id in dataset_ids + if ( + decrypted_id := self.decrypt_dataset_id( + encrypted_data=dataset_id, tenant_id=app.tenant_id + ) + ) + ] + workflow_service.sync_draft_workflow( + app_model=app, + graph=workflow_data.get("graph", {}), + features=workflow_data.get("features", {}), + unique_hash=unique_hash, + account=account, + environment_variables=environment_variables, + conversation_variables=conversation_variables, + ) + case AppMode.CHAT | AppMode.AGENT_CHAT | AppMode.COMPLETION: + # Initialize model config + model_config = data.get("model_config") + if not model_config or not isinstance(model_config, dict): + raise ValueError("Missing model_config for chat/agent-chat/completion app") + # Initialize or update model config + if not app.app_model_config: + app_model_config = AppModelConfig( + app_id=app.id, created_by=account.id, updated_by=account.id + ).from_model_config_dict(cast(AppModelConfigDict, model_config)) + app_model_config.id = str(uuid4()) + app.app_model_config_id = app_model_config.id - self._session.add(app_model_config) - app_model_config_was_updated.send(app, app_model_config=app_model_config) - else: - raise ValueError("Invalid app mode") + self._session.add(app_model_config) + app_model_config_was_updated.send(app, app_model_config=app_model_config) + case _: + raise ValueError("Invalid app mode") return app @classmethod @@ -577,7 +568,7 @@ class AppDslService: @classmethod def _append_workflow_export_data( - cls, *, export_data: dict, app_model: App, include_secret: bool, workflow_id: str | None = None + cls, *, export_data: dict[str, Any], app_model: App, include_secret: bool, workflow_id: str | None = None ): """ Append workflow export data @@ -630,7 +621,7 @@ class AppDslService: ] @classmethod - def _append_model_config_export_data(cls, export_data: dict, app_model: App): + def _append_model_config_export_data(cls, export_data: dict[str, Any], app_model: App): """ Append model config export data :param export_data: export data diff --git a/api/services/app_generate_service.py b/api/services/app_generate_service.py index 40013f2b669..5e8c7aa3379 100644 --- a/api/services/app_generate_service.py +++ b/api/services/app_generate_service.py @@ -4,7 +4,7 @@ import logging import threading import uuid from collections.abc import Callable, Generator, Mapping -from typing import TYPE_CHECKING, Any, Union +from typing import TYPE_CHECKING, Any from configs import dify_config from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator @@ -88,7 +88,7 @@ class AppGenerateService: def generate( cls, app_model: App, - user: Union[Account, EndUser], + user: Account | EndUser, args: Mapping[str, Any], invoke_from: InvokeFrom, streaming: bool = True, @@ -116,139 +116,143 @@ class AppGenerateService: request_id = RateLimit.gen_request_key() try: request_id = rate_limit.enter(request_id) - if app_model.mode == AppMode.COMPLETION: - return rate_limit.generate( - CompletionAppGenerator.convert_to_event_stream( - CompletionAppGenerator().generate( - app_model=app_model, user=user, args=args, invoke_from=invoke_from, streaming=streaming - ), - ), - request_id=request_id, - ) - elif app_model.mode == AppMode.AGENT_CHAT or app_model.is_agent: - return rate_limit.generate( - AgentChatAppGenerator.convert_to_event_stream( - AgentChatAppGenerator().generate( - app_model=app_model, user=user, args=args, invoke_from=invoke_from, streaming=streaming - ), - ), - request_id, - ) - elif app_model.mode == AppMode.CHAT: - return rate_limit.generate( - ChatAppGenerator.convert_to_event_stream( - ChatAppGenerator().generate( - app_model=app_model, user=user, args=args, invoke_from=invoke_from, streaming=streaming - ), - ), - request_id=request_id, - ) - elif app_model.mode == AppMode.ADVANCED_CHAT: - workflow_id = args.get("workflow_id") - workflow = cls._get_workflow(app_model, invoke_from, workflow_id) - - if streaming: - # Streaming mode: subscribe to SSE and enqueue the execution on first subscriber - with rate_limit_context(rate_limit, request_id): - payload = AppExecutionParams.new( - app_model=app_model, - workflow=workflow, - user=user, - args=args, - invoke_from=invoke_from, - streaming=True, - call_depth=0, - ) - payload_json = payload.model_dump_json() - - def on_subscribe(): - workflow_based_app_execution_task.delay(payload_json) - - on_subscribe = cls._build_streaming_task_on_subscribe(on_subscribe) - generator = AdvancedChatAppGenerator() + effective_mode = ( + AppMode.AGENT_CHAT if app_model.is_agent and app_model.mode != AppMode.AGENT_CHAT else app_model.mode + ) + match effective_mode: + case AppMode.COMPLETION: return rate_limit.generate( - generator.convert_to_event_stream( - generator.retrieve_events( - AppMode.ADVANCED_CHAT, - payload.workflow_run_id, - on_subscribe=on_subscribe, + CompletionAppGenerator.convert_to_event_stream( + CompletionAppGenerator().generate( + app_model=app_model, user=user, args=args, invoke_from=invoke_from, streaming=streaming ), ), request_id=request_id, ) - else: - # Blocking mode: run synchronously and return JSON instead of SSE - # Keep behaviour consistent with WORKFLOW blocking branch. - advanced_generator = AdvancedChatAppGenerator() + case AppMode.AGENT_CHAT: return rate_limit.generate( - advanced_generator.convert_to_event_stream( - advanced_generator.generate( + AgentChatAppGenerator.convert_to_event_stream( + AgentChatAppGenerator().generate( + app_model=app_model, user=user, args=args, invoke_from=invoke_from, streaming=streaming + ), + ), + request_id, + ) + case AppMode.CHAT: + return rate_limit.generate( + ChatAppGenerator.convert_to_event_stream( + ChatAppGenerator().generate( + app_model=app_model, user=user, args=args, invoke_from=invoke_from, streaming=streaming + ), + ), + request_id=request_id, + ) + case AppMode.ADVANCED_CHAT: + workflow_id = args.get("workflow_id") + workflow = cls._get_workflow(app_model, invoke_from, workflow_id) + + if streaming: + # Streaming mode: subscribe to SSE and enqueue the execution on first subscriber + with rate_limit_context(rate_limit, request_id): + payload = AppExecutionParams.new( app_model=app_model, workflow=workflow, user=user, args=args, invoke_from=invoke_from, - workflow_run_id=str(uuid.uuid4()), - streaming=False, + streaming=True, + call_depth=0, ) - ), - request_id=request_id, - ) - elif app_model.mode == AppMode.WORKFLOW: - workflow_id = args.get("workflow_id") - workflow = cls._get_workflow(app_model, invoke_from, workflow_id) - if streaming: - with rate_limit_context(rate_limit, request_id): - payload = AppExecutionParams.new( - app_model=app_model, - workflow=workflow, - user=user, - args=args, - invoke_from=invoke_from, - streaming=True, - call_depth=0, - root_node_id=root_node_id, - workflow_run_id=str(uuid.uuid4()), + payload_json = payload.model_dump_json() + + def on_subscribe(): + workflow_based_app_execution_task.delay(payload_json) + + on_subscribe = cls._build_streaming_task_on_subscribe(on_subscribe) + generator = AdvancedChatAppGenerator() + return rate_limit.generate( + generator.convert_to_event_stream( + generator.retrieve_events( + AppMode.ADVANCED_CHAT, + payload.workflow_run_id, + on_subscribe=on_subscribe, + ), + ), + request_id=request_id, ) - payload_json = payload.model_dump_json() + else: + # Blocking mode: run synchronously and return JSON instead of SSE + # Keep behaviour consistent with WORKFLOW blocking branch. + advanced_generator = AdvancedChatAppGenerator() + return rate_limit.generate( + advanced_generator.convert_to_event_stream( + advanced_generator.generate( + app_model=app_model, + workflow=workflow, + user=user, + args=args, + invoke_from=invoke_from, + workflow_run_id=str(uuid.uuid4()), + streaming=False, + ) + ), + request_id=request_id, + ) + case AppMode.WORKFLOW: + workflow_id = args.get("workflow_id") + workflow = cls._get_workflow(app_model, invoke_from, workflow_id) + if streaming: + with rate_limit_context(rate_limit, request_id): + payload = AppExecutionParams.new( + app_model=app_model, + workflow=workflow, + user=user, + args=args, + invoke_from=invoke_from, + streaming=True, + call_depth=0, + root_node_id=root_node_id, + workflow_run_id=str(uuid.uuid4()), + ) + payload_json = payload.model_dump_json() - def on_subscribe(): - workflow_based_app_execution_task.delay(payload_json) + def on_subscribe(): + workflow_based_app_execution_task.delay(payload_json) - on_subscribe = cls._build_streaming_task_on_subscribe(on_subscribe) + on_subscribe = cls._build_streaming_task_on_subscribe(on_subscribe) + return rate_limit.generate( + WorkflowAppGenerator.convert_to_event_stream( + MessageBasedAppGenerator.retrieve_events( + AppMode.WORKFLOW, + payload.workflow_run_id, + on_subscribe=on_subscribe, + ), + ), + request_id, + ) + + pause_config = PauseStateLayerConfig( + session_factory=session_factory.get_session_maker(), + state_owner_user_id=workflow.created_by, + ) return rate_limit.generate( WorkflowAppGenerator.convert_to_event_stream( - MessageBasedAppGenerator.retrieve_events( - AppMode.WORKFLOW, - payload.workflow_run_id, - on_subscribe=on_subscribe, + WorkflowAppGenerator().generate( + app_model=app_model, + workflow=workflow, + user=user, + args=args, + invoke_from=invoke_from, + streaming=False, + root_node_id=root_node_id, + call_depth=0, + pause_state_config=pause_config, ), ), request_id, ) - - pause_config = PauseStateLayerConfig( - session_factory=session_factory.get_session_maker(), - state_owner_user_id=workflow.created_by, - ) - return rate_limit.generate( - WorkflowAppGenerator.convert_to_event_stream( - WorkflowAppGenerator().generate( - app_model=app_model, - workflow=workflow, - user=user, - args=args, - invoke_from=invoke_from, - streaming=False, - root_node_id=root_node_id, - call_depth=0, - pause_state_config=pause_config, - ), - ), - request_id, - ) - else: - raise ValueError(f"Invalid app mode {app_model.mode}") + case _: + raise ValueError(f"Invalid app mode {app_model.mode}") except Exception: quota_charge.refund() rate_limit.exit(request_id) @@ -280,53 +284,83 @@ class AppGenerateService: @classmethod def generate_single_iteration(cls, app_model: App, user: Account, node_id: str, args: Any, streaming: bool = True): - if app_model.mode == AppMode.ADVANCED_CHAT: - workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER) - return AdvancedChatAppGenerator.convert_to_event_stream( - AdvancedChatAppGenerator().single_iteration_generate( - app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming + match app_model.mode: + case AppMode.COMPLETION | AppMode.CHAT | AppMode.AGENT_CHAT: + raise ValueError(f"Invalid app mode {app_model.mode}") + case AppMode.ADVANCED_CHAT: + workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER) + return AdvancedChatAppGenerator.convert_to_event_stream( + AdvancedChatAppGenerator().single_iteration_generate( + app_model=app_model, + workflow=workflow, + node_id=node_id, + user=user, + args=args, + streaming=streaming, + ) ) - ) - elif app_model.mode == AppMode.WORKFLOW: - workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER) - return AdvancedChatAppGenerator.convert_to_event_stream( - WorkflowAppGenerator().single_iteration_generate( - app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming + case AppMode.WORKFLOW: + workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER) + return AdvancedChatAppGenerator.convert_to_event_stream( + WorkflowAppGenerator().single_iteration_generate( + app_model=app_model, + workflow=workflow, + node_id=node_id, + user=user, + args=args, + streaming=streaming, + ) ) - ) - else: - raise ValueError(f"Invalid app mode {app_model.mode}") + case AppMode.CHANNEL | AppMode.RAG_PIPELINE: + raise ValueError(f"Invalid app mode {app_model.mode}") + case _: + raise ValueError(f"Invalid app mode {app_model.mode}") @classmethod def generate_single_loop( cls, app_model: App, user: Account, node_id: str, args: LoopNodeRunPayload, streaming: bool = True ): - if app_model.mode == AppMode.ADVANCED_CHAT: - workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER) - return AdvancedChatAppGenerator.convert_to_event_stream( - AdvancedChatAppGenerator().single_loop_generate( - app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming + match app_model.mode: + case AppMode.COMPLETION | AppMode.CHAT | AppMode.AGENT_CHAT: + raise ValueError(f"Invalid app mode {app_model.mode}") + case AppMode.ADVANCED_CHAT: + workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER) + return AdvancedChatAppGenerator.convert_to_event_stream( + AdvancedChatAppGenerator().single_loop_generate( + app_model=app_model, + workflow=workflow, + node_id=node_id, + user=user, + args=args, + streaming=streaming, + ) ) - ) - elif app_model.mode == AppMode.WORKFLOW: - workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER) - return AdvancedChatAppGenerator.convert_to_event_stream( - WorkflowAppGenerator().single_loop_generate( - app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming + case AppMode.WORKFLOW: + workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER) + return AdvancedChatAppGenerator.convert_to_event_stream( + WorkflowAppGenerator().single_loop_generate( + app_model=app_model, + workflow=workflow, + node_id=node_id, + user=user, + args=args, + streaming=streaming, + ) ) - ) - else: - raise ValueError(f"Invalid app mode {app_model.mode}") + case AppMode.CHANNEL | AppMode.RAG_PIPELINE: + raise ValueError(f"Invalid app mode {app_model.mode}") + case _: + raise ValueError(f"Invalid app mode {app_model.mode}") @classmethod def generate_more_like_this( cls, app_model: App, - user: Union[Account, EndUser], + user: Account | EndUser, message_id: str, invoke_from: InvokeFrom, streaming: bool = True, - ) -> Union[Mapping, Generator]: + ) -> Mapping | Generator: """ Generate more like this :param app_model: app model diff --git a/api/services/app_model_config_service.py b/api/services/app_model_config_service.py index 3bc30cb3232..8252de7753a 100644 --- a/api/services/app_model_config_service.py +++ b/api/services/app_model_config_service.py @@ -1,3 +1,5 @@ +from typing import Any + from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager from core.app.apps.chat.app_config_manager import ChatAppConfigManager from core.app.apps.completion.app_config_manager import CompletionAppConfigManager @@ -6,12 +8,13 @@ from models.model import AppMode, AppModelConfigDict class AppModelConfigService: @classmethod - def validate_configuration(cls, tenant_id: str, config: dict, app_mode: AppMode) -> AppModelConfigDict: - if app_mode == AppMode.CHAT: - return ChatAppConfigManager.config_validate(tenant_id, config) - elif app_mode == AppMode.AGENT_CHAT: - return AgentChatAppConfigManager.config_validate(tenant_id, config) - elif app_mode == AppMode.COMPLETION: - return CompletionAppConfigManager.config_validate(tenant_id, config) - else: - raise ValueError(f"Invalid app mode: {app_mode}") + def validate_configuration(cls, tenant_id: str, config: dict[str, Any], app_mode: AppMode) -> AppModelConfigDict: + match app_mode: + case AppMode.CHAT: + return ChatAppConfigManager.config_validate(tenant_id, config) + case AppMode.AGENT_CHAT: + return AgentChatAppConfigManager.config_validate(tenant_id, config) + case AppMode.COMPLETION: + return CompletionAppConfigManager.config_validate(tenant_id, config) + case AppMode.WORKFLOW | AppMode.ADVANCED_CHAT | AppMode.CHANNEL | AppMode.RAG_PIPELINE: + raise ValueError(f"Invalid app mode: {app_mode}") diff --git a/api/services/app_service.py b/api/services/app_service.py index 87d52a3159c..a046b909b38 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -4,8 +4,6 @@ from typing import Any, TypedDict, cast import sqlalchemy as sa from flask_sqlalchemy.pagination import Pagination -from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from sqlalchemy import select from configs import dify_config @@ -17,6 +15,8 @@ from core.tools.tool_manager import ToolManager from core.tools.utils.configuration import ToolParameterConfigurationManager from events.app_event import app_was_created, app_was_deleted, app_was_updated from extensions.ext_database import db +from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType +from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel from libs.datetime_utils import naive_utc_now from libs.login import current_user from models import Account @@ -32,7 +32,7 @@ logger = logging.getLogger(__name__) class AppService: - def get_paginate_apps(self, user_id: str, tenant_id: str, args: dict) -> Pagination | None: + def get_paginate_apps(self, user_id: str, tenant_id: str, args: dict[str, Any]) -> Pagination | None: """ Get app list with pagination :param user_id: user id @@ -78,7 +78,7 @@ class AppService: return app_models - def create_app(self, tenant_id: str, args: dict, account: Account) -> App: + def create_app(self, tenant_id: str, args: dict[str, Any], account: Account) -> App: """ Create app :param tenant_id: tenant id @@ -303,17 +303,22 @@ class AppService: return app - def update_app_icon(self, app: App, icon: str, icon_background: str) -> App: + def update_app_icon( + self, app: App, icon: str, icon_background: str, icon_type: IconType | str | None = None + ) -> App: """ Update app icon :param app: App instance :param icon: new icon :param icon_background: new icon_background + :param icon_type: new icon type :return: App instance """ assert current_user is not None app.icon = icon app.icon_background = icon_background + if icon_type is not None: + app.icon_type = icon_type if isinstance(icon_type, IconType) else IconType(icon_type) app.updated_by = current_user.id app.updated_at = naive_utc_now() db.session.commit() @@ -389,7 +394,7 @@ class AppService: """ app_mode = AppMode.value_of(app_model.mode) - meta: dict = {"tool_icons": {}} + meta: dict[str, Any] = {"tool_icons": {}} if app_mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: workflow = app_model.workflow diff --git a/api/services/app_task_service.py b/api/services/app_task_service.py index 0842e9d3e7f..6e9d6b1c737 100644 --- a/api/services/app_task_service.py +++ b/api/services/app_task_service.py @@ -5,11 +5,10 @@ like stopping tasks, handling both legacy Redis flag mechanism and new GraphEngine command channel mechanism. """ -from graphon.graph_engine.manager import GraphEngineManager - from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.entities.app_invoke_entities import InvokeFrom from extensions.ext_redis import redis_client +from graphon.graph_engine.manager import GraphEngineManager from models.model import AppMode diff --git a/api/services/async_workflow_service.py b/api/services/async_workflow_service.py index 0133634e5a5..a731d5c048e 100644 --- a/api/services/async_workflow_service.py +++ b/api/services/async_workflow_service.py @@ -7,11 +7,11 @@ with support for different subscription tiers, rate limiting, and execution trac import json from datetime import UTC, datetime -from typing import Any, Union +from typing import Any from celery.result import AsyncResult from sqlalchemy import select -from sqlalchemy.orm import Session +from sqlalchemy.orm import Session, sessionmaker from enums.quota_type import QuotaType from extensions.ext_database import db @@ -50,7 +50,7 @@ class AsyncWorkflowService: @classmethod def trigger_workflow_async( - cls, session: Session, user: Union[Account, EndUser], trigger_data: TriggerData + cls, session: Session, user: Account | EndUser, trigger_data: TriggerData ) -> AsyncTriggerResponse: """ Universal entry point for async workflow execution - THIS METHOD WILL NOT BLOCK @@ -177,7 +177,7 @@ class AsyncWorkflowService: @classmethod def reinvoke_trigger( - cls, session: Session, user: Union[Account, EndUser], workflow_trigger_log_id: str + cls, session: Session, user: Account | EndUser, workflow_trigger_log_id: str ) -> AsyncTriggerResponse: """ Re-invoke a previously failed or rate-limited trigger - THIS METHOD WILL NOT BLOCK @@ -237,7 +237,7 @@ class AsyncWorkflowService: Returns: Trigger log as dictionary or None if not found """ - with Session(db.engine) as session: + with sessionmaker(db.engine).begin() as session: trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session) trigger_log = trigger_log_repo.get_by_id(workflow_trigger_log_id, tenant_id) @@ -263,7 +263,7 @@ class AsyncWorkflowService: Returns: List of trigger logs as dictionaries """ - with Session(db.engine) as session: + with sessionmaker(db.engine).begin() as session: trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session) logs = trigger_log_repo.get_recent_logs( tenant_id=tenant_id, app_id=app_id, hours=hours, limit=limit, offset=offset @@ -286,7 +286,7 @@ class AsyncWorkflowService: Returns: List of failed trigger logs as dictionaries """ - with Session(db.engine) as session: + with sessionmaker(db.engine).begin() as session: trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session) logs = trigger_log_repo.get_failed_for_retry( tenant_id=tenant_id, max_retry_count=max_retry_count, limit=limit diff --git a/api/services/attachment_service.py b/api/services/attachment_service.py index 2bd5627d5eb..54e664e9443 100644 --- a/api/services/attachment_service.py +++ b/api/services/attachment_service.py @@ -1,6 +1,6 @@ import base64 -from sqlalchemy import Engine +from sqlalchemy import Engine, select from sqlalchemy.orm import sessionmaker from werkzeug.exceptions import NotFound @@ -22,8 +22,8 @@ class AttachmentService: raise AssertionError("must be a sessionmaker or an Engine.") def get_file_base64(self, file_id: str) -> str: - upload_file = ( - self._session_maker(expire_on_commit=False).query(UploadFile).where(UploadFile.id == file_id).first() + upload_file = self._session_maker(expire_on_commit=False).scalar( + select(UploadFile).where(UploadFile.id == file_id).limit(1) ) if not upload_file: raise NotFound("File not found") diff --git a/api/services/audio_service.py b/api/services/audio_service.py index 1c7027efb49..60948e652b8 100644 --- a/api/services/audio_service.py +++ b/api/services/audio_service.py @@ -5,12 +5,12 @@ from collections.abc import Generator from typing import cast from flask import Response, stream_with_context -from graphon.model_runtime.entities.model_entities import ModelType from werkzeug.datastructures import FileStorage from constants import AUDIO_EXTENSIONS from core.model_manager import ModelManager from extensions.ext_database import db +from graphon.model_runtime.entities.model_entities import ModelType from models.enums import MessageStatus from models.model import App, AppMode, Message from services.errors.audio import ( diff --git a/api/services/auth/api_key_auth_service.py b/api/services/auth/api_key_auth_service.py index 3282dcfb113..36b15170567 100644 --- a/api/services/auth/api_key_auth_service.py +++ b/api/services/auth/api_key_auth_service.py @@ -1,4 +1,5 @@ import json +from typing import Any from sqlalchemy import select @@ -19,7 +20,7 @@ class ApiKeyAuthService: return data_source_api_key_bindings @staticmethod - def create_provider_auth(tenant_id: str, args: dict): + def create_provider_auth(tenant_id: str, args: dict[str, Any]): auth_result = ApiKeyAuthFactory(args["provider"], args["credentials"]).validate_credentials() if auth_result: # Encrypt the api key diff --git a/api/services/billing_service.py b/api/services/billing_service.py index 3c3e4aa6d25..a1362ccad6b 100644 --- a/api/services/billing_service.py +++ b/api/services/billing_service.py @@ -2,7 +2,7 @@ import json import logging import os from collections.abc import Sequence -from typing import Literal, TypedDict +from typing import Any, Literal, NotRequired, TypedDict import httpx from pydantic import TypeAdapter @@ -32,6 +32,103 @@ class SubscriptionPlan(TypedDict): expiration_date: int +class _BillingQuota(TypedDict): + size: int + limit: int + + +class _VectorSpaceQuota(TypedDict): + size: float + limit: int + + +class _KnowledgeRateLimit(TypedDict): + # NOTE (hj24): + # 1. Return for sandbox users but is null for other plans, it's defined but never used. + # 2. Keep it for compatibility for now, can be deprecated in future versions. + size: NotRequired[int] + # NOTE END + limit: int + + +class _BillingSubscription(TypedDict): + plan: str + interval: str + education: bool + + +class BillingInfo(TypedDict): + """Response of /subscription/info. + + NOTE (hj24): + - Fields not listed here (e.g. trigger_event, api_rate_limit) are stripped by TypeAdapter.validate_python() + - To ensure the precision, billing may convert fields like int as str, be careful when use TypeAdapter: + 1. validate_python in non-strict mode will coerce it to the expected type + 2. In strict mode, it will raise ValidationError + 3. To preserve compatibility, always keep non-strict mode here and avoid strict mode + """ + + enabled: bool + subscription: _BillingSubscription + members: _BillingQuota + apps: _BillingQuota + vector_space: _VectorSpaceQuota + knowledge_rate_limit: _KnowledgeRateLimit + documents_upload_quota: _BillingQuota + annotation_quota_limit: _BillingQuota + docs_processing: str + can_replace_logo: bool + model_load_balancing_enabled: bool + knowledge_pipeline_publish_enabled: bool + next_credit_reset_date: NotRequired[int] + + +_billing_info_adapter = TypeAdapter(BillingInfo) + + +class KnowledgeRateLimitDict(TypedDict): + limit: int + subscription_plan: str + + +class TenantFeaturePlanUsageDict(TypedDict): + result: str + history_id: str + + +class LangContentDict(TypedDict): + lang: str + title: str + subtitle: str + body: str + title_pic_url: str + + +class NotificationDict(TypedDict): + notification_id: str + contents: dict[str, LangContentDict] + frequency: Literal["once", "every_page_load"] + + +class AccountNotificationDict(TypedDict, total=False): + should_show: bool + notification: NotificationDict + shouldShow: bool + notifications: list[dict] + + +class UpsertNotificationDict(TypedDict): + notification_id: str + + +class BatchAddNotificationAccountsDict(TypedDict): + count: int + + +class DismissNotificationDict(TypedDict): + success: bool + + class BillingService: base_url = os.environ.get("BILLING_API_URL", "BILLING_API_URL") secret_key = os.environ.get("BILLING_API_SECRET_KEY", "BILLING_API_SECRET_KEY") @@ -44,11 +141,11 @@ class BillingService: _PLAN_CACHE_TTL = 600 @classmethod - def get_info(cls, tenant_id: str): + def get_info(cls, tenant_id: str) -> BillingInfo: params = {"tenant_id": tenant_id} billing_info = cls._send_request("GET", "/subscription/info", params=params) - return billing_info + return _billing_info_adapter.validate_python(billing_info) @classmethod def get_tenant_feature_plan_usage_info(cls, tenant_id: str): @@ -58,7 +155,7 @@ class BillingService: return usage_info @classmethod - def get_knowledge_rate_limit(cls, tenant_id: str): + def get_knowledge_rate_limit(cls, tenant_id: str) -> KnowledgeRateLimitDict: params = {"tenant_id": tenant_id} knowledge_rate_limit = cls._send_request("GET", "/subscription/knowledge-rate-limit", params=params) @@ -89,7 +186,9 @@ class BillingService: return cls._send_request("GET", "/invoices", params=params) @classmethod - def update_tenant_feature_plan_usage(cls, tenant_id: str, feature_key: str, delta: int) -> dict: + def update_tenant_feature_plan_usage( + cls, tenant_id: str, feature_key: str, delta: int + ) -> TenantFeaturePlanUsageDict: """ Update tenant feature plan usage. @@ -109,7 +208,7 @@ class BillingService: ) @classmethod - def refund_tenant_feature_plan_usage(cls, history_id: str) -> dict: + def refund_tenant_feature_plan_usage(cls, history_id: str) -> TenantFeaturePlanUsageDict: """ Refund a previous usage charge. @@ -405,7 +504,7 @@ class BillingService: return tenant_whitelist @classmethod - def get_account_notification(cls, account_id: str) -> dict: + def get_account_notification(cls, account_id: str) -> AccountNotificationDict: """Return the active in-product notification for account_id, if any. Calling this endpoint also marks the notification as seen; subsequent @@ -429,20 +528,20 @@ class BillingService: @classmethod def upsert_notification( cls, - contents: list[dict], + contents: list[LangContentDict], frequency: str = "once", status: str = "active", notification_id: str | None = None, start_time: str | None = None, end_time: str | None = None, - ) -> dict: + ) -> UpsertNotificationDict: """Create or update a notification. contents: list of {"lang": str, "title": str, "subtitle": str, "body": str, "title_pic_url": str} start_time / end_time: RFC3339 strings (e.g. "2026-03-01T00:00:00Z"), optional. Returns {"notification_id": str}. """ - payload: dict = { + payload: dict[str, Any] = { "contents": contents, "frequency": frequency, "status": status, @@ -456,7 +555,9 @@ class BillingService: return cls._send_request("POST", "/notifications", json=payload) @classmethod - def batch_add_notification_accounts(cls, notification_id: str, account_ids: list[str]) -> dict: + def batch_add_notification_accounts( + cls, notification_id: str, account_ids: list[str] + ) -> BatchAddNotificationAccountsDict: """Register target account IDs for a notification (max 1000 per call). Returns {"count": int}. @@ -468,7 +569,7 @@ class BillingService: ) @classmethod - def dismiss_notification(cls, notification_id: str, account_id: str) -> dict: + def dismiss_notification(cls, notification_id: str, account_id: str) -> DismissNotificationDict: """Mark a notification as dismissed for an account. Returns {"success": bool}. diff --git a/api/services/clear_free_plan_tenant_expired_logs.py b/api/services/clear_free_plan_tenant_expired_logs.py index 1c128524ad4..dcc93b4b0f1 100644 --- a/api/services/clear_free_plan_tenant_expired_logs.py +++ b/api/services/clear_free_plan_tenant_expired_logs.py @@ -6,14 +6,14 @@ from concurrent.futures import ThreadPoolExecutor import click from flask import Flask, current_app -from graphon.model_runtime.utils.encoders import jsonable_encoder -from sqlalchemy import select +from sqlalchemy import delete, func, select from sqlalchemy.orm import Session, sessionmaker from configs import dify_config from enums.cloud_plan import CloudPlan from extensions.ext_database import db from extensions.ext_storage import storage +from graphon.model_runtime.utils.encoders import jsonable_encoder from models.account import Tenant from models.model import ( App, @@ -62,13 +62,11 @@ class ClearFreePlanTenantExpiredLogs: for model, table_name in related_tables: # Query records related to expired messages - records = ( - session.query(model) - .where( + records = session.scalars( + select(model).where( model.message_id.in_(batch_message_ids), # type: ignore ) - .all() - ) + ).all() if len(records) == 0: continue @@ -103,9 +101,13 @@ class ClearFreePlanTenantExpiredLogs: except Exception: logger.exception("Failed to save %s records", table_name) - session.query(model).where( - model.id.in_(record_ids), # type: ignore - ).delete(synchronize_session=False) + session.execute( + delete(model) + .where( + model.id.in_(record_ids), # type: ignore + ) + .execution_options(synchronize_session=False) + ) click.echo( click.style( @@ -120,16 +122,15 @@ class ClearFreePlanTenantExpiredLogs: apps = db.session.scalars(select(App).where(App.tenant_id == tenant_id)).all() app_ids = [app.id for app in apps] while True: - with Session(db.engine).no_autoflush as session: - messages = ( - session.query(Message) + with sessionmaker(bind=db.engine, autoflush=False).begin() as session: + messages = session.scalars( + select(Message) .where( Message.app_id.in_(app_ids), Message.created_at < datetime.datetime.now() - datetime.timedelta(days=days), ) .limit(batch) - .all() - ) + ).all() if len(messages) == 0: break @@ -147,12 +148,11 @@ class ClearFreePlanTenantExpiredLogs: message_ids = [message.id for message in messages] # delete messages - session.query(Message).where( - Message.id.in_(message_ids), - ).delete(synchronize_session=False) + session.execute( + delete(Message).where(Message.id.in_(message_ids)).execution_options(synchronize_session=False) + ) cls._clear_message_related_tables(session, tenant_id, message_ids) - session.commit() click.echo( click.style( @@ -161,16 +161,15 @@ class ClearFreePlanTenantExpiredLogs: ) while True: - with Session(db.engine).no_autoflush as session: - conversations = ( - session.query(Conversation) + with sessionmaker(bind=db.engine, autoflush=False).begin() as session: + conversations = session.scalars( + select(Conversation) .where( Conversation.app_id.in_(app_ids), Conversation.updated_at < datetime.datetime.now() - datetime.timedelta(days=days), ) .limit(batch) - .all() - ) + ).all() if len(conversations) == 0: break @@ -187,10 +186,11 @@ class ClearFreePlanTenantExpiredLogs: ) conversation_ids = [conversation.id for conversation in conversations] - session.query(Conversation).where( - Conversation.id.in_(conversation_ids), - ).delete(synchronize_session=False) - session.commit() + session.execute( + delete(Conversation) + .where(Conversation.id.in_(conversation_ids)) + .execution_options(synchronize_session=False) + ) click.echo( click.style( @@ -294,16 +294,15 @@ class ClearFreePlanTenantExpiredLogs: break while True: - with Session(db.engine).no_autoflush as session: - workflow_app_logs = ( - session.query(WorkflowAppLog) + with sessionmaker(bind=db.engine, autoflush=False).begin() as session: + workflow_app_logs = session.scalars( + select(WorkflowAppLog) .where( WorkflowAppLog.tenant_id == tenant_id, WorkflowAppLog.created_at < datetime.datetime.now() - datetime.timedelta(days=days), ) .limit(batch) - .all() - ) + ).all() if len(workflow_app_logs) == 0: break @@ -323,10 +322,11 @@ class ClearFreePlanTenantExpiredLogs: workflow_app_log_ids = [workflow_app_log.id for workflow_app_log in workflow_app_logs] # delete workflow app logs - session.query(WorkflowAppLog).where(WorkflowAppLog.id.in_(workflow_app_log_ids)).delete( - synchronize_session=False + session.execute( + delete(WorkflowAppLog) + .where(WorkflowAppLog.id.in_(workflow_app_log_ids)) + .execution_options(synchronize_session=False) ) - session.commit() click.echo( click.style( @@ -346,8 +346,8 @@ class ClearFreePlanTenantExpiredLogs: started_at = datetime.datetime(2023, 4, 3, 8, 59, 24) current_time = started_at - with Session(db.engine) as session: - total_tenant_count = session.query(Tenant.id).count() + with sessionmaker(db.engine).begin() as session: + total_tenant_count = session.scalar(select(func.count(Tenant.id))) or 0 click.echo(click.style(f"Total tenant count: {total_tenant_count}", fg="white")) @@ -398,7 +398,7 @@ class ClearFreePlanTenantExpiredLogs: # Initial interval of 1 day, will be dynamically adjusted based on tenant count interval = datetime.timedelta(days=1) # Process tenants in this batch - with Session(db.engine) as session: + with sessionmaker(db.engine).begin() as session: # Calculate tenant count in next batch with current interval # Try different intervals until we find one with a reasonable tenant count test_intervals = [ @@ -412,9 +412,12 @@ class ClearFreePlanTenantExpiredLogs: tenant_count = 0 for test_interval in test_intervals: tenant_count = ( - session.query(Tenant.id) - .where(Tenant.created_at.between(current_time, current_time + test_interval)) - .count() + session.scalar( + select(func.count(Tenant.id)).where( + Tenant.created_at.between(current_time, current_time + test_interval) + ) + ) + or 0 ) if tenant_count <= 100: interval = test_interval @@ -436,8 +439,8 @@ class ClearFreePlanTenantExpiredLogs: batch_end = min(current_time + interval, ended_at) - rs = ( - session.query(Tenant.id) + rs = session.execute( + select(Tenant.id) .where(Tenant.created_at.between(current_time, batch_end)) .order_by(Tenant.created_at) ) diff --git a/api/services/conversation_service.py b/api/services/conversation_service.py index f5085af59b6..ee8a1c4edd5 100644 --- a/api/services/conversation_service.py +++ b/api/services/conversation_service.py @@ -3,7 +3,6 @@ import logging from collections.abc import Callable, Sequence from typing import Any -from graphon.variables.types import SegmentType from sqlalchemy import asc, desc, func, or_, select from sqlalchemy.orm import Session @@ -13,6 +12,7 @@ from core.db.session_factory import session_factory from core.llm_generator.llm_generator import LLMGenerator from extensions.ext_database import db from factories import variable_factory +from graphon.variables.types import SegmentType from libs.datetime_utils import naive_utc_now from libs.infinite_scroll_pagination import InfiniteScrollPagination from models import Account, ConversationVariable diff --git a/api/services/conversation_variable_updater.py b/api/services/conversation_variable_updater.py index 95a8951951c..287d513f480 100644 --- a/api/services/conversation_variable_updater.py +++ b/api/services/conversation_variable_updater.py @@ -1,7 +1,7 @@ -from graphon.variables.variables import VariableBase from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker +from graphon.variables.variables import VariableBase from models import ConversationVariable diff --git a/api/services/credit_pool_service.py b/api/services/credit_pool_service.py index 7826695366b..2d210db1212 100644 --- a/api/services/credit_pool_service.py +++ b/api/services/credit_pool_service.py @@ -1,7 +1,7 @@ import logging from sqlalchemy import select, update -from sqlalchemy.orm import Session +from sqlalchemy.orm import sessionmaker from configs import dify_config from core.errors.error import QuotaExceededError @@ -29,14 +29,15 @@ class CreditPoolService: @classmethod def get_pool(cls, tenant_id: str, pool_type: str = "trial") -> TenantCreditPool | None: """get tenant credit pool""" - return db.session.scalar( - select(TenantCreditPool) - .where( - TenantCreditPool.tenant_id == tenant_id, - TenantCreditPool.pool_type == pool_type, + with sessionmaker(db.engine, expire_on_commit=False).begin() as session: + return session.scalar( + select(TenantCreditPool) + .where( + TenantCreditPool.tenant_id == tenant_id, + TenantCreditPool.pool_type == pool_type, + ) + .limit(1) ) - .limit(1) - ) @classmethod def check_credits_available( @@ -71,7 +72,7 @@ class CreditPoolService: actual_credits = min(credits_required, pool.remaining_credits) try: - with Session(db.engine) as session: + with sessionmaker(db.engine).begin() as session: stmt = ( update(TenantCreditPool) .where( @@ -81,7 +82,6 @@ class CreditPoolService: .values(quota_used=TenantCreditPool.quota_used + actual_credits) ) session.execute(stmt) - session.commit() except Exception: logger.exception("Failed to deduct credits for tenant %s", tenant_id) raise QuotaExceededError("Failed to deduct credits") diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 53bc51d4574..894cb056879 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -7,15 +7,12 @@ import time import uuid from collections import Counter from collections.abc import Sequence -from typing import Any, Literal, cast +from typing import Any, Literal, TypedDict, cast import sqlalchemy as sa -from graphon.file import helpers as file_helpers -from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType -from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from redis.exceptions import LockNotOwnedError -from sqlalchemy import exists, func, select -from sqlalchemy.orm import Session +from sqlalchemy import delete, exists, func, select, update +from sqlalchemy.orm import Session, sessionmaker from werkzeug.exceptions import Forbidden, NotFound from configs import dify_config @@ -31,6 +28,9 @@ from events.dataset_event import dataset_was_deleted from events.document_event import document_was_deleted from extensions.ext_database import db from extensions.ext_redis import redis_client +from graphon.file import helpers as file_helpers +from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType +from graphon.model_runtime.model_providers.base.text_embedding_model import TextEmbeddingModel from libs import helper from libs.datetime_utils import naive_utc_now from libs.login import current_user @@ -107,6 +107,16 @@ from tasks.sync_website_document_indexing_task import sync_website_document_inde logger = logging.getLogger(__name__) +class ProcessRulesDict(TypedDict): + mode: str + rules: dict[str, Any] + + +class AutoDisableLogsDict(TypedDict): + document_ids: list[str] + count: int + + class DatasetService: @staticmethod def get_datasets(page, per_page, tenant_id=None, user=None, search=None, tag_ids=None, include_all=False): @@ -114,9 +124,11 @@ class DatasetService: if user: # get permitted dataset ids - dataset_permission = ( - db.session.query(DatasetPermission).filter_by(account_id=user.id, tenant_id=tenant_id).all() - ) + dataset_permission = db.session.scalars( + select(DatasetPermission).where( + DatasetPermission.account_id == user.id, DatasetPermission.tenant_id == tenant_id + ) + ).all() permitted_dataset_ids = {dp.dataset_id for dp in dataset_permission} if dataset_permission else None if user.current_role == TenantAccountRole.DATASET_OPERATOR: @@ -180,21 +192,20 @@ class DatasetService: return datasets.items, datasets.total @staticmethod - def get_process_rules(dataset_id): + def get_process_rules(dataset_id) -> ProcessRulesDict: # get the latest process rule - dataset_process_rule = ( - db.session.query(DatasetProcessRule) + dataset_process_rule = db.session.execute( + select(DatasetProcessRule) .where(DatasetProcessRule.dataset_id == dataset_id) .order_by(DatasetProcessRule.created_at.desc()) .limit(1) - .one_or_none() - ) + ).scalar_one_or_none() if dataset_process_rule: mode = dataset_process_rule.mode - rules = dataset_process_rule.rules_dict + rules = dataset_process_rule.rules_dict or {} else: - mode = DocumentService.DEFAULT_RULES["mode"] - rules = DocumentService.DEFAULT_RULES["rules"] + mode = str(DocumentService.DEFAULT_RULES["mode"]) + rules = dict(DocumentService.DEFAULT_RULES.get("rules") or {}) return {"mode": mode, "rules": rules} @staticmethod @@ -222,10 +233,10 @@ class DatasetService: embedding_model_provider: str | None = None, embedding_model_name: str | None = None, retrieval_model: RetrievalModel | None = None, - summary_index_setting: dict | None = None, + summary_index_setting: dict[str, Any] | None = None, ): # check if dataset name already exists - if db.session.query(Dataset).filter_by(name=name, tenant_id=tenant_id).first(): + if db.session.scalar(select(Dataset).where(Dataset.name == name, Dataset.tenant_id == tenant_id).limit(1)): raise DatasetNameDuplicateError(f"Dataset with name {name} already exists.") embedding_model = None if indexing_technique == IndexTechniqueType.HIGH_QUALITY: @@ -300,17 +311,17 @@ class DatasetService: ): if rag_pipeline_dataset_create_entity.name: # check if dataset name already exists - if ( - db.session.query(Dataset) - .filter_by(name=rag_pipeline_dataset_create_entity.name, tenant_id=tenant_id) - .first() + if db.session.scalar( + select(Dataset) + .where(Dataset.name == rag_pipeline_dataset_create_entity.name, Dataset.tenant_id == tenant_id) + .limit(1) ): raise DatasetNameDuplicateError( f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists." ) else: # generate a random name as Untitled 1 2 3 ... - datasets = db.session.query(Dataset).filter_by(tenant_id=tenant_id).all() + datasets = db.session.scalars(select(Dataset).where(Dataset.tenant_id == tenant_id)).all() names = [dataset.name for dataset in datasets] rag_pipeline_dataset_create_entity.name = generate_incremental_name( names, @@ -344,7 +355,7 @@ class DatasetService: @staticmethod def get_dataset(dataset_id) -> Dataset | None: - dataset: Dataset | None = db.session.query(Dataset).filter_by(id=dataset_id).first() + dataset: Dataset | None = db.session.get(Dataset, dataset_id) return dataset @staticmethod @@ -466,14 +477,14 @@ class DatasetService: @staticmethod def _has_dataset_same_name(tenant_id: str, dataset_id: str, name: str): - dataset = ( - db.session.query(Dataset) + dataset = db.session.scalar( + select(Dataset) .where( Dataset.id != dataset_id, Dataset.name == name, Dataset.tenant_id == tenant_id, ) - .first() + .limit(1) ) return dataset is not None @@ -517,6 +528,8 @@ class DatasetService: raise ValueError("External knowledge id is required.") if not external_knowledge_api_id: raise ValueError("External knowledge api id is required.") + # Ensure the referenced external API template exists and belongs to the dataset tenant. + ExternalDatasetService.get_external_knowledge_api(external_knowledge_api_id, dataset.tenant_id) # Update metadata fields dataset.updated_by = user.id if user else None dataset.updated_at = naive_utc_now() @@ -540,22 +553,22 @@ class DatasetService: external_knowledge_id: External knowledge identifier external_knowledge_api_id: External knowledge API identifier """ - with Session(db.engine) as session: - external_knowledge_binding = ( - session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id).first() + with sessionmaker(db.engine).begin() as session: + external_knowledge_binding = session.scalar( + select(ExternalKnowledgeBindings).where(ExternalKnowledgeBindings.dataset_id == dataset_id).limit(1) ) if not external_knowledge_binding: raise ValueError("External knowledge binding not found.") - # Update binding if values have changed - if ( - external_knowledge_binding.external_knowledge_id != external_knowledge_id - or external_knowledge_binding.external_knowledge_api_id != external_knowledge_api_id - ): - external_knowledge_binding.external_knowledge_id = external_knowledge_id - external_knowledge_binding.external_knowledge_api_id = external_knowledge_api_id - db.session.add(external_knowledge_binding) + # Update binding if values have changed + if ( + external_knowledge_binding.external_knowledge_id != external_knowledge_id + or external_knowledge_binding.external_knowledge_api_id != external_knowledge_api_id + ): + external_knowledge_binding.external_knowledge_id = external_knowledge_id + external_knowledge_binding.external_knowledge_api_id = external_knowledge_api_id + session.add(external_knowledge_binding) @staticmethod def _update_internal_dataset(dataset, data, user): @@ -596,7 +609,7 @@ class DatasetService: filtered_data["icon_info"] = data.get("icon_info") # Update dataset in database - db.session.query(Dataset).filter_by(id=dataset.id).update(filtered_data) + db.session.execute(update(Dataset).where(Dataset.id == dataset.id).values(**filtered_data)) db.session.commit() # Reload dataset to get updated values @@ -631,7 +644,7 @@ class DatasetService: if dataset.runtime_mode != DatasetRuntimeMode.RAG_PIPELINE: return - pipeline = db.session.query(Pipeline).filter_by(id=dataset.pipeline_id).first() + pipeline = db.session.get(Pipeline, dataset.pipeline_id) if not pipeline: return @@ -1138,8 +1151,10 @@ class DatasetService: if dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM: # For partial team permission, user needs explicit permission or be the creator if dataset.created_by != user.id: - user_permission = ( - db.session.query(DatasetPermission).filter_by(dataset_id=dataset.id, account_id=user.id).first() + user_permission = db.session.scalar( + select(DatasetPermission) + .where(DatasetPermission.dataset_id == dataset.id, DatasetPermission.account_id == user.id) + .limit(1) ) if not user_permission: logger.debug("User %s does not have permission to access dataset %s", user.id, dataset.id) @@ -1161,7 +1176,9 @@ class DatasetService: elif dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM: if not any( dp.dataset_id == dataset.id - for dp in db.session.query(DatasetPermission).filter_by(account_id=user.id).all() + for dp in db.session.scalars( + select(DatasetPermission).where(DatasetPermission.account_id == user.id) + ).all() ): raise NoPermissionError("You do not have permission to access this dataset.") @@ -1175,12 +1192,11 @@ class DatasetService: @staticmethod def get_related_apps(dataset_id: str): - return ( - db.session.query(AppDatasetJoin) + return db.session.scalars( + select(AppDatasetJoin) .where(AppDatasetJoin.dataset_id == dataset_id) - .order_by(db.desc(AppDatasetJoin.created_at)) - .all() - ) + .order_by(AppDatasetJoin.created_at.desc()) + ).all() @staticmethod def update_dataset_api_status(dataset_id: str, status: bool): @@ -1195,7 +1211,7 @@ class DatasetService: db.session.commit() @staticmethod - def get_dataset_auto_disable_logs(dataset_id: str): + def get_dataset_auto_disable_logs(dataset_id: str) -> AutoDisableLogsDict: assert isinstance(current_user, Account) assert current_user.current_tenant_id is not None features = FeatureService.get_features(current_user.current_tenant_id) @@ -1396,8 +1412,8 @@ class DocumentService: @staticmethod def get_document(dataset_id: str, document_id: str | None = None) -> Document | None: if document_id: - document = ( - db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() + document = db.session.scalar( + select(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).limit(1) ) return document else: @@ -1440,15 +1456,17 @@ class DocumentService: document_id_list: list[str] = [str(document_id) for document_id in document_ids] with session_factory.create_session() as session: - updated_count = ( - session.query(Document) - .filter( + result = session.execute( + update(Document) + .where( Document.id.in_(document_id_list), Document.dataset_id == dataset_id, Document.doc_form != IndexStructureType.QA_INDEX, # Skip qa_model documents ) - .update({Document.need_summary: need_summary}, synchronize_session=False) + .values(need_summary=need_summary) + .execution_options(synchronize_session=False) ) + updated_count = result.rowcount # type: ignore[union-attr,attr-defined] session.commit() logger.info( "Updated need_summary to %s for %d documents in dataset %s", @@ -1626,7 +1644,7 @@ class DocumentService: @staticmethod def get_document_by_id(document_id: str) -> Document | None: - document = db.session.query(Document).where(Document.id == document_id).first() + document = db.session.get(Document, document_id) return document @@ -1691,7 +1709,7 @@ class DocumentService: @staticmethod def get_document_file_detail(file_id: str): - file_detail = db.session.query(UploadFile).where(UploadFile.id == file_id).one_or_none() + file_detail = db.session.get(UploadFile, file_id) return file_detail @staticmethod @@ -1765,9 +1783,11 @@ class DocumentService: document.name = name db.session.add(document) if document.data_source_info_dict and "upload_file_id" in document.data_source_info_dict: - db.session.query(UploadFile).where( - UploadFile.id == document.data_source_info_dict["upload_file_id"] - ).update({UploadFile.name: name}) + db.session.execute( + update(UploadFile) + .where(UploadFile.id == document.data_source_info_dict["upload_file_id"]) + .values(name=name) + ) db.session.commit() @@ -1854,8 +1874,8 @@ class DocumentService: @staticmethod def get_documents_position(dataset_id): - document = ( - db.session.query(Document).filter_by(dataset_id=dataset_id).order_by(Document.position.desc()).first() + document = db.session.scalar( + select(Document).where(Document.dataset_id == dataset_id).order_by(Document.position.desc()).limit(1) ) if document: return document.position + 1 @@ -2012,28 +2032,28 @@ class DocumentService: if not knowledge_config.data_source.info_list.file_info_list: raise ValueError("File source info is required") upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids - files = ( - db.session.query(UploadFile) - .where( - UploadFile.tenant_id == dataset.tenant_id, - UploadFile.id.in_(upload_file_list), - ) - .all() + files = list( + db.session.scalars( + select(UploadFile).where( + UploadFile.tenant_id == dataset.tenant_id, + UploadFile.id.in_(upload_file_list), + ) + ).all() ) if len(files) != len(set(upload_file_list)): raise FileNotExistsError("One or more files not found.") file_names = [file.name for file in files] - db_documents = ( - db.session.query(Document) - .where( - Document.dataset_id == dataset.id, - Document.tenant_id == current_user.current_tenant_id, - Document.data_source_type == DataSourceType.UPLOAD_FILE, - Document.enabled == True, - Document.name.in_(file_names), - ) - .all() + db_documents = list( + db.session.scalars( + select(Document).where( + Document.dataset_id == dataset.id, + Document.tenant_id == current_user.current_tenant_id, + Document.data_source_type == DataSourceType.UPLOAD_FILE, + Document.enabled == True, + Document.name.in_(file_names), + ) + ).all() ) documents_map = {document.name: document for document in db_documents} for file in files: @@ -2079,15 +2099,15 @@ class DocumentService: raise ValueError("No notion info list found.") exist_page_ids = [] exist_document = {} - documents = ( - db.session.query(Document) - .filter_by( - dataset_id=dataset.id, - tenant_id=current_user.current_tenant_id, - data_source_type=DataSourceType.NOTION_IMPORT, - enabled=True, - ) - .all() + documents = list( + db.session.scalars( + select(Document).where( + Document.dataset_id == dataset.id, + Document.tenant_id == current_user.current_tenant_id, + Document.data_source_type == DataSourceType.NOTION_IMPORT, + Document.enabled == True, + ) + ).all() ) if documents: for document in documents: @@ -2473,7 +2493,7 @@ class DocumentService: data_source_type: str, document_form: str, document_language: str, - data_source_info: dict, + data_source_info: dict[str, Any], created_from: str, position: int, account: Account, @@ -2518,14 +2538,15 @@ class DocumentService: assert isinstance(current_user, Account) documents_count = ( - db.session.query(Document) - .where( - Document.completed_at.isnot(None), - Document.enabled == True, - Document.archived == False, - Document.tenant_id == current_user.current_tenant_id, + db.session.scalar( + select(func.count(Document.id)).where( + Document.completed_at.isnot(None), + Document.enabled == True, + Document.archived == False, + Document.tenant_id == current_user.current_tenant_id, + ) ) - .count() + or 0 ) return documents_count @@ -2575,10 +2596,10 @@ class DocumentService: raise ValueError("No file info list found.") upload_file_list = document_data.data_source.info_list.file_info_list.file_ids for file_id in upload_file_list: - file = ( - db.session.query(UploadFile) + file = db.session.scalar( + select(UploadFile) .where(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id) - .first() + .limit(1) ) # raise error if file not found @@ -2595,8 +2616,8 @@ class DocumentService: notion_info_list = document_data.data_source.info_list.notion_info_list for notion_info in notion_info_list: workspace_id = notion_info.workspace_id - data_source_binding = ( - db.session.query(DataSourceOauthBinding) + data_source_binding = db.session.scalar( + select(DataSourceOauthBinding) .where( sa.and_( DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, @@ -2605,7 +2626,7 @@ class DocumentService: DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"', ) ) - .first() + .limit(1) ) if not data_source_binding: raise ValueError("Data source binding not found.") @@ -2650,8 +2671,10 @@ class DocumentService: db.session.commit() # update document segment - db.session.query(DocumentSegment).filter_by(document_id=document.id).update( - {DocumentSegment.status: SegmentStatus.RE_SEGMENT} + db.session.execute( + update(DocumentSegment) + .where(DocumentSegment.document_id == document.id) + .values(status=SegmentStatus.RE_SEGMENT) ) db.session.commit() # trigger async task @@ -2803,6 +2826,10 @@ class DocumentService: knowledge_config.process_rule.rules.pre_processing_rules = list(unique_pre_processing_rule_dicts.values()) + if knowledge_config.process_rule.mode == ProcessRuleMode.HIERARCHICAL: + if not knowledge_config.process_rule.rules.parent_mode: + knowledge_config.process_rule.rules.parent_mode = "paragraph" + if not knowledge_config.process_rule.rules.segmentation: raise ValueError("Process rule segmentation is required") @@ -2823,7 +2850,7 @@ class DocumentService: raise ValueError("Process rule segmentation max_tokens is invalid") @classmethod - def estimate_args_validate(cls, args: dict): + def estimate_args_validate(cls, args: dict[str, Any]): if "info_list" not in args or not args["info_list"]: raise ValueError("Data source info is required") @@ -3105,7 +3132,7 @@ class DocumentService: class SegmentService: @classmethod - def segment_create_args_validate(cls, args: dict, document: Document): + def segment_create_args_validate(cls, args: dict[str, Any], document: Document): if document.doc_form == IndexStructureType.QA_INDEX: if "answer" not in args or not args["answer"]: raise ValueError("Answer is required") @@ -3122,7 +3149,7 @@ class SegmentService: raise ValueError(f"Exceeded maximum attachment limit of {single_chunk_attachment_limit}") @classmethod - def create_segment(cls, args: dict, document: Document, dataset: Dataset): + def create_segment(cls, args: dict[str, Any], document: Document, dataset: Dataset): assert isinstance(current_user, Account) assert current_user.current_tenant_id is not None @@ -3143,10 +3170,8 @@ class SegmentService: lock_name = f"add_segment_lock_document_id_{document.id}" try: with redis_client.lock(lock_name, timeout=600): - max_position = ( - db.session.query(func.max(DocumentSegment.position)) - .where(DocumentSegment.document_id == document.id) - .scalar() + max_position = db.session.scalar( + select(func.max(DocumentSegment.position)).where(DocumentSegment.document_id == document.id) ) segment_document = DocumentSegment( tenant_id=current_user.current_tenant_id, @@ -3198,7 +3223,7 @@ class SegmentService: segment_document.status = SegmentStatus.ERROR segment_document.error = str(e) db.session.commit() - segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment_document.id).first() + segment = db.session.get(DocumentSegment, segment_document.id) return segment except LockNotOwnedError: pass @@ -3221,10 +3246,8 @@ class SegmentService: model_type=ModelType.TEXT_EMBEDDING, model=dataset.embedding_model, ) - max_position = ( - db.session.query(func.max(DocumentSegment.position)) - .where(DocumentSegment.document_id == document.id) - .scalar() + max_position = db.session.scalar( + select(func.max(DocumentSegment.position)).where(DocumentSegment.document_id == document.id) ) pre_segment_data_list = [] segment_data_list = [] @@ -3369,11 +3392,7 @@ class SegmentService: else: raise ValueError("The knowledge base index technique is not high quality!") # get the process rule - processing_rule = ( - db.session.query(DatasetProcessRule) - .where(DatasetProcessRule.id == document.dataset_process_rule_id) - .first() - ) + processing_rule = db.session.get(DatasetProcessRule, document.dataset_process_rule_id) if processing_rule: VectorService.generate_child_chunks( segment, document, dataset, embedding_model_instance, processing_rule, True @@ -3391,13 +3410,13 @@ class SegmentService: # Query existing summary from database from models.dataset import DocumentSegmentSummary - existing_summary = ( - db.session.query(DocumentSegmentSummary) + existing_summary = db.session.scalar( + select(DocumentSegmentSummary) .where( DocumentSegmentSummary.chunk_id == segment.id, DocumentSegmentSummary.dataset_id == dataset.id, ) - .first() + .limit(1) ) # Check if summary has changed @@ -3473,11 +3492,7 @@ class SegmentService: else: raise ValueError("The knowledge base index technique is not high quality!") # get the process rule - processing_rule = ( - db.session.query(DatasetProcessRule) - .where(DatasetProcessRule.id == document.dataset_process_rule_id) - .first() - ) + processing_rule = db.session.get(DatasetProcessRule, document.dataset_process_rule_id) if processing_rule: VectorService.generate_child_chunks( segment, document, dataset, embedding_model_instance, processing_rule, True @@ -3489,13 +3504,13 @@ class SegmentService: if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: from models.dataset import DocumentSegmentSummary - existing_summary = ( - db.session.query(DocumentSegmentSummary) + existing_summary = db.session.scalar( + select(DocumentSegmentSummary) .where( DocumentSegmentSummary.chunk_id == segment.id, DocumentSegmentSummary.dataset_id == dataset.id, ) - .first() + .limit(1) ) if args.summary is None: @@ -3561,7 +3576,7 @@ class SegmentService: segment.status = SegmentStatus.ERROR segment.error = str(e) db.session.commit() - new_segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment.id).first() + new_segment = db.session.get(DocumentSegment, segment.id) if not new_segment: raise ValueError("new_segment is not found") return new_segment @@ -3581,15 +3596,14 @@ class SegmentService: # Get child chunk IDs before parent segment is deleted child_node_ids = [] if segment.index_node_id: - child_chunks = ( - db.session.query(ChildChunk.index_node_id) - .where( - ChildChunk.segment_id == segment.id, - ChildChunk.dataset_id == dataset.id, - ) - .all() + child_node_ids = list( + db.session.scalars( + select(ChildChunk.index_node_id).where( + ChildChunk.segment_id == segment.id, + ChildChunk.dataset_id == dataset.id, + ) + ).all() ) - child_node_ids = [chunk[0] for chunk in child_chunks if chunk[0]] delete_segment_from_index_task.delay( [segment.index_node_id], dataset.id, document.id, [segment.id], child_node_ids @@ -3608,17 +3622,14 @@ class SegmentService: # Check if segment_ids is not empty to avoid WHERE false condition if not segment_ids or len(segment_ids) == 0: return - segments_info = ( - db.session.query(DocumentSegment) - .with_entities(DocumentSegment.index_node_id, DocumentSegment.id, DocumentSegment.word_count) - .where( + segments_info = db.session.execute( + select(DocumentSegment.index_node_id, DocumentSegment.id, DocumentSegment.word_count).where( DocumentSegment.id.in_(segment_ids), DocumentSegment.dataset_id == dataset.id, DocumentSegment.document_id == document.id, DocumentSegment.tenant_id == current_user.current_tenant_id, ) - .all() - ) + ).all() if not segments_info: return @@ -3630,15 +3641,16 @@ class SegmentService: # Get child chunk IDs before parent segments are deleted child_node_ids = [] if index_node_ids: - child_chunks = ( - db.session.query(ChildChunk.index_node_id) - .where( - ChildChunk.segment_id.in_(segment_db_ids), - ChildChunk.dataset_id == dataset.id, - ) - .all() - ) - child_node_ids = [chunk[0] for chunk in child_chunks if chunk[0]] + child_node_ids = [ + nid + for nid in db.session.scalars( + select(ChildChunk.index_node_id).where( + ChildChunk.segment_id.in_(segment_db_ids), + ChildChunk.dataset_id == dataset.id, + ) + ).all() + if nid + ] # Start async cleanup with both parent and child node IDs if index_node_ids or child_node_ids: @@ -3654,7 +3666,7 @@ class SegmentService: db.session.add(document) # Delete database records - db.session.query(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)).delete() + db.session.execute(delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids))) db.session.commit() @classmethod @@ -3728,15 +3740,13 @@ class SegmentService: with redis_client.lock(lock_name, timeout=20): index_node_id = str(uuid.uuid4()) index_node_hash = helper.generate_text_hash(content) - max_position = ( - db.session.query(func.max(ChildChunk.position)) - .where( + max_position = db.session.scalar( + select(func.max(ChildChunk.position)).where( ChildChunk.tenant_id == current_user.current_tenant_id, ChildChunk.dataset_id == dataset.id, ChildChunk.document_id == document.id, ChildChunk.segment_id == segment.id, ) - .scalar() ) child_chunk = ChildChunk( tenant_id=current_user.current_tenant_id, @@ -3896,10 +3906,8 @@ class SegmentService: @classmethod def get_child_chunk_by_id(cls, child_chunk_id: str, tenant_id: str) -> ChildChunk | None: """Get a child chunk by its ID.""" - result = ( - db.session.query(ChildChunk) - .where(ChildChunk.id == child_chunk_id, ChildChunk.tenant_id == tenant_id) - .first() + result = db.session.scalar( + select(ChildChunk).where(ChildChunk.id == child_chunk_id, ChildChunk.tenant_id == tenant_id).limit(1) ) return result if isinstance(result, ChildChunk) else None @@ -3934,10 +3942,10 @@ class SegmentService: @classmethod def get_segment_by_id(cls, segment_id: str, tenant_id: str) -> DocumentSegment | None: """Get a segment by its ID.""" - result = ( - db.session.query(DocumentSegment) + result = db.session.scalar( + select(DocumentSegment) .where(DocumentSegment.id == segment_id, DocumentSegment.tenant_id == tenant_id) - .first() + .limit(1) ) return result if isinstance(result, DocumentSegment) else None @@ -3980,15 +3988,15 @@ class DatasetCollectionBindingService: def get_dataset_collection_binding( cls, provider_name: str, model_name: str, collection_type: str = "dataset" ) -> DatasetCollectionBinding: - dataset_collection_binding = ( - db.session.query(DatasetCollectionBinding) + dataset_collection_binding = db.session.scalar( + select(DatasetCollectionBinding) .where( DatasetCollectionBinding.provider_name == provider_name, DatasetCollectionBinding.model_name == model_name, DatasetCollectionBinding.type == collection_type, ) .order_by(DatasetCollectionBinding.created_at) - .first() + .limit(1) ) if not dataset_collection_binding: @@ -4006,13 +4014,13 @@ class DatasetCollectionBindingService: def get_dataset_collection_binding_by_id_and_type( cls, collection_binding_id: str, collection_type: str = "dataset" ) -> DatasetCollectionBinding: - dataset_collection_binding = ( - db.session.query(DatasetCollectionBinding) + dataset_collection_binding = db.session.scalar( + select(DatasetCollectionBinding) .where( DatasetCollectionBinding.id == collection_binding_id, DatasetCollectionBinding.type == collection_type ) .order_by(DatasetCollectionBinding.created_at) - .first() + .limit(1) ) if not dataset_collection_binding: raise ValueError("Dataset collection binding not found") @@ -4034,7 +4042,7 @@ class DatasetPermissionService: @classmethod def update_partial_member_list(cls, tenant_id, dataset_id, user_list): try: - db.session.query(DatasetPermission).where(DatasetPermission.dataset_id == dataset_id).delete() + db.session.execute(delete(DatasetPermission).where(DatasetPermission.dataset_id == dataset_id)) permissions = [] for user in user_list: permission = DatasetPermission( @@ -4070,7 +4078,7 @@ class DatasetPermissionService: @classmethod def clear_partial_member_list(cls, dataset_id): try: - db.session.query(DatasetPermission).where(DatasetPermission.dataset_id == dataset_id).delete() + db.session.execute(delete(DatasetPermission).where(DatasetPermission.dataset_id == dataset_id)) db.session.commit() except Exception as e: db.session.rollback() diff --git a/api/services/datasource_provider_service.py b/api/services/datasource_provider_service.py index 06f83a18f7a..416bc8cef9d 100644 --- a/api/services/datasource_provider_service.py +++ b/api/services/datasource_provider_service.py @@ -3,8 +3,8 @@ import time from collections.abc import Mapping from typing import Any -from graphon.model_runtime.entities.provider_entities import FormType -from sqlalchemy.orm import Session +from sqlalchemy import delete, func, select, update +from sqlalchemy.orm import Session, sessionmaker from configs import dify_config from constants import HIDDEN_VALUE, UNKNOWN_VALUE @@ -17,6 +17,7 @@ from core.plugin.impl.oauth import OAuthHandler from core.tools.utils.encryption import ProviderConfigCache, ProviderConfigEncrypter, create_provider_encrypter from extensions.ext_database import db from extensions.ext_redis import redis_client +from graphon.model_runtime.entities.provider_entities import FormType from models.oauth import DatasourceOauthParamConfig, DatasourceOauthTenantParamConfig, DatasourceProvider from models.provider_ids import DatasourceProviderID from services.plugin.plugin_service import PluginService @@ -52,13 +53,14 @@ class DatasourceProviderService: """ remove oauth custom client params """ - with Session(db.engine) as session: - session.query(DatasourceOauthTenantParamConfig).filter_by( - tenant_id=tenant_id, - provider=datasource_provider_id.provider_name, - plugin_id=datasource_provider_id.plugin_id, - ).delete() - session.commit() + with sessionmaker(bind=db.engine).begin() as session: + session.execute( + delete(DatasourceOauthTenantParamConfig).where( + DatasourceOauthTenantParamConfig.tenant_id == tenant_id, + DatasourceOauthTenantParamConfig.provider == datasource_provider_id.provider_name, + DatasourceOauthTenantParamConfig.plugin_id == datasource_provider_id.plugin_id, + ) + ) def decrypt_datasource_provider_credentials( self, @@ -108,17 +110,23 @@ class DatasourceProviderService: """ get credential by id """ - with Session(db.engine) as session: + with sessionmaker(bind=db.engine).begin() as session: if credential_id: - datasource_provider = ( - session.query(DatasourceProvider).filter_by(tenant_id=tenant_id, id=credential_id).first() + datasource_provider = session.scalar( + select(DatasourceProvider) + .where(DatasourceProvider.tenant_id == tenant_id, DatasourceProvider.id == credential_id) + .limit(1) ) else: - datasource_provider = ( - session.query(DatasourceProvider) - .filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id) + datasource_provider = session.scalar( + select(DatasourceProvider) + .where( + DatasourceProvider.tenant_id == tenant_id, + DatasourceProvider.provider == provider, + DatasourceProvider.plugin_id == plugin_id, + ) .order_by(DatasourceProvider.is_default.desc(), DatasourceProvider.created_at.asc()) - .first() + .limit(1) ) if not datasource_provider: return {} @@ -155,7 +163,6 @@ class DatasourceProviderService: datasource_provider=datasource_provider, ) datasource_provider.expires_at = refreshed_credentials.expires_at - session.commit() return self.decrypt_datasource_provider_credentials( tenant_id=tenant_id, @@ -173,13 +180,16 @@ class DatasourceProviderService: """ get all datasource credentials by provider """ - with Session(db.engine) as session: - datasource_providers = ( - session.query(DatasourceProvider) - .filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id) + with sessionmaker(bind=db.engine).begin() as session: + datasource_providers = session.scalars( + select(DatasourceProvider) + .where( + DatasourceProvider.tenant_id == tenant_id, + DatasourceProvider.provider == provider, + DatasourceProvider.plugin_id == plugin_id, + ) .order_by(DatasourceProvider.is_default.desc(), DatasourceProvider.created_at.asc()) - .all() - ) + ).all() if not datasource_providers: return [] current_user = get_current_user() @@ -223,7 +233,6 @@ class DatasourceProviderService: provider=provider, ) real_credentials_list.append(real_credentials) - session.commit() return real_credentials_list @@ -233,16 +242,16 @@ class DatasourceProviderService: """ update datasource provider name """ - with Session(db.engine) as session: - target_provider = ( - session.query(DatasourceProvider) - .filter_by( - tenant_id=tenant_id, - id=credential_id, - provider=datasource_provider_id.provider_name, - plugin_id=datasource_provider_id.plugin_id, + with sessionmaker(bind=db.engine).begin() as session: + target_provider = session.scalar( + select(DatasourceProvider) + .where( + DatasourceProvider.tenant_id == tenant_id, + DatasourceProvider.id == credential_id, + DatasourceProvider.provider == datasource_provider_id.provider_name, + DatasourceProvider.plugin_id == datasource_provider_id.plugin_id, ) - .first() + .limit(1) ) if target_provider is None: raise ValueError("provider not found") @@ -252,20 +261,19 @@ class DatasourceProviderService: # check name is exist if ( - session.query(DatasourceProvider) - .filter_by( - tenant_id=tenant_id, - name=name, - provider=datasource_provider_id.provider_name, - plugin_id=datasource_provider_id.plugin_id, + session.scalar( + select(func.count(DatasourceProvider.id)).where( + DatasourceProvider.tenant_id == tenant_id, + DatasourceProvider.name == name, + DatasourceProvider.provider == datasource_provider_id.provider_name, + DatasourceProvider.plugin_id == datasource_provider_id.plugin_id, + ) ) - .count() - > 0 - ): + or 0 + ) > 0: raise ValueError("Authorization name is already exists") target_provider.name = name - session.commit() return def set_default_datasource_provider( @@ -274,39 +282,43 @@ class DatasourceProviderService: """ set default datasource provider """ - with Session(db.engine) as session: + with sessionmaker(bind=db.engine).begin() as session: # get provider - target_provider = ( - session.query(DatasourceProvider) - .filter_by( - tenant_id=tenant_id, - id=credential_id, - provider=datasource_provider_id.provider_name, - plugin_id=datasource_provider_id.plugin_id, + target_provider = session.scalar( + select(DatasourceProvider) + .where( + DatasourceProvider.tenant_id == tenant_id, + DatasourceProvider.id == credential_id, + DatasourceProvider.provider == datasource_provider_id.provider_name, + DatasourceProvider.plugin_id == datasource_provider_id.plugin_id, ) - .first() + .limit(1) ) if target_provider is None: raise ValueError("provider not found") # clear default provider - session.query(DatasourceProvider).filter_by( - tenant_id=tenant_id, - provider=target_provider.provider, - plugin_id=target_provider.plugin_id, - is_default=True, - ).update({"is_default": False}) + session.execute( + update(DatasourceProvider) + .where( + DatasourceProvider.tenant_id == tenant_id, + DatasourceProvider.provider == target_provider.provider, + DatasourceProvider.plugin_id == target_provider.plugin_id, + DatasourceProvider.is_default.is_(True), + ) + .values(is_default=False) + .execution_options(synchronize_session=False) + ) # set new default provider target_provider.is_default = True - session.commit() return {"result": "success"} def setup_oauth_custom_client_params( self, tenant_id: str, datasource_provider_id: DatasourceProviderID, - client_params: dict | None, + client_params: dict[str, Any] | None, enabled: bool | None, ): """ @@ -314,15 +326,15 @@ class DatasourceProviderService: """ if client_params is None and enabled is None: return - with Session(db.engine) as session: - tenant_oauth_client_params = ( - session.query(DatasourceOauthTenantParamConfig) - .filter_by( - tenant_id=tenant_id, - provider=datasource_provider_id.provider_name, - plugin_id=datasource_provider_id.plugin_id, + with sessionmaker(bind=db.engine).begin() as session: + tenant_oauth_client_params = session.scalar( + select(DatasourceOauthTenantParamConfig) + .where( + DatasourceOauthTenantParamConfig.tenant_id == tenant_id, + DatasourceOauthTenantParamConfig.provider == datasource_provider_id.provider_name, + DatasourceOauthTenantParamConfig.plugin_id == datasource_provider_id.plugin_id, ) - .first() + .limit(1) ) if not tenant_oauth_client_params: @@ -340,7 +352,7 @@ class DatasourceProviderService: original_params = ( encrypter.decrypt(tenant_oauth_client_params.client_params) if tenant_oauth_client_params else {} ) - new_params: dict = { + new_params: dict[str, Any] = { key: value if value != HIDDEN_VALUE else original_params.get(key, UNKNOWN_VALUE) for key, value in client_params.items() } @@ -348,7 +360,6 @@ class DatasourceProviderService: if enabled is not None: tenant_oauth_client_params.enabled = enabled - session.commit() def is_system_oauth_params_exist(self, datasource_provider_id: DatasourceProviderID) -> bool: """ @@ -356,9 +367,14 @@ class DatasourceProviderService: """ with Session(db.engine).no_autoflush as session: return ( - session.query(DatasourceOauthParamConfig) - .filter_by(provider=datasource_provider_id.provider_name, plugin_id=datasource_provider_id.plugin_id) - .first() + session.scalar( + select(DatasourceOauthParamConfig) + .where( + DatasourceOauthParamConfig.provider == datasource_provider_id.provider_name, + DatasourceOauthParamConfig.plugin_id == datasource_provider_id.plugin_id, + ) + .limit(1) + ) is not None ) @@ -367,16 +383,16 @@ class DatasourceProviderService: check if tenant oauth params is enabled """ return ( - db.session.query(DatasourceOauthTenantParamConfig) - .filter_by( - tenant_id=tenant_id, - provider=datasource_provider_id.provider_name, - plugin_id=datasource_provider_id.plugin_id, - enabled=True, + db.session.scalar( + select(func.count(DatasourceOauthTenantParamConfig.id)).where( + DatasourceOauthTenantParamConfig.tenant_id == tenant_id, + DatasourceOauthTenantParamConfig.provider == datasource_provider_id.provider_name, + DatasourceOauthTenantParamConfig.plugin_id == datasource_provider_id.plugin_id, + DatasourceOauthTenantParamConfig.enabled == True, + ) ) - .count() - > 0 - ) + or 0 + ) > 0 def get_tenant_oauth_client( self, tenant_id: str, datasource_provider_id: DatasourceProviderID, mask: bool = False @@ -384,14 +400,14 @@ class DatasourceProviderService: """ get tenant oauth client """ - tenant_oauth_client_params = ( - db.session.query(DatasourceOauthTenantParamConfig) - .filter_by( - tenant_id=tenant_id, - provider=datasource_provider_id.provider_name, - plugin_id=datasource_provider_id.plugin_id, + tenant_oauth_client_params = db.session.scalar( + select(DatasourceOauthTenantParamConfig) + .where( + DatasourceOauthTenantParamConfig.tenant_id == tenant_id, + DatasourceOauthTenantParamConfig.provider == datasource_provider_id.provider_name, + DatasourceOauthTenantParamConfig.plugin_id == datasource_provider_id.plugin_id, ) - .first() + .limit(1) ) if tenant_oauth_client_params: encrypter, _ = self.get_oauth_encrypter(tenant_id, datasource_provider_id) @@ -428,15 +444,15 @@ class DatasourceProviderService: plugin_id = datasource_provider_id.plugin_id with Session(db.engine).no_autoflush as session: # get tenant oauth client params - tenant_oauth_client_params = ( - session.query(DatasourceOauthTenantParamConfig) - .filter_by( - tenant_id=tenant_id, - provider=provider, - plugin_id=plugin_id, - enabled=True, + tenant_oauth_client_params = session.scalar( + select(DatasourceOauthTenantParamConfig) + .where( + DatasourceOauthTenantParamConfig.tenant_id == tenant_id, + DatasourceOauthTenantParamConfig.provider == provider, + DatasourceOauthTenantParamConfig.plugin_id == plugin_id, + DatasourceOauthTenantParamConfig.enabled.is_(True), ) - .first() + .limit(1) ) if tenant_oauth_client_params: encrypter, _ = self.get_oauth_encrypter(tenant_id, datasource_provider_id) @@ -448,8 +464,13 @@ class DatasourceProviderService: is_verified = PluginService.is_plugin_verified(tenant_id, provider_controller.plugin_unique_identifier) if is_verified: # fallback to system oauth client params - oauth_client_params = ( - session.query(DatasourceOauthParamConfig).filter_by(provider=provider, plugin_id=plugin_id).first() + oauth_client_params = session.scalar( + select(DatasourceOauthParamConfig) + .where( + DatasourceOauthParamConfig.provider == provider, + DatasourceOauthParamConfig.plugin_id == plugin_id, + ) + .limit(1) ) if oauth_client_params: return oauth_client_params.system_credentials @@ -460,15 +481,13 @@ class DatasourceProviderService: def generate_next_datasource_provider_name( session: Session, tenant_id: str, provider_id: DatasourceProviderID, credential_type: CredentialType ) -> str: - db_providers = ( - session.query(DatasourceProvider) - .filter_by( - tenant_id=tenant_id, - provider=provider_id.provider_name, - plugin_id=provider_id.plugin_id, + db_providers = session.scalars( + select(DatasourceProvider).where( + DatasourceProvider.tenant_id == tenant_id, + DatasourceProvider.provider == provider_id.provider_name, + DatasourceProvider.plugin_id == provider_id.plugin_id, ) - .all() - ) + ).all() return generate_incremental_name( [provider.name for provider in db_providers], f"{credential_type.get_name()}", @@ -481,17 +500,19 @@ class DatasourceProviderService: provider_id: DatasourceProviderID, avatar_url: str | None, expire_at: int, - credentials: dict, + credentials: dict[str, Any], credential_id: str, ) -> None: """ update datasource oauth provider """ - with Session(db.engine) as session: + with sessionmaker(bind=db.engine).begin() as session: lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{CredentialType.OAUTH2.value}" with redis_client.lock(lock, timeout=20): - target_provider = ( - session.query(DatasourceProvider).filter_by(id=credential_id, tenant_id=tenant_id).first() + target_provider = session.scalar( + select(DatasourceProvider) + .where(DatasourceProvider.id == credential_id, DatasourceProvider.tenant_id == tenant_id) + .limit(1) ) if target_provider is None: raise ValueError("provider not found") @@ -501,25 +522,28 @@ class DatasourceProviderService: db_provider_name = target_provider.name else: name_conflict = ( - session.query(DatasourceProvider) - .filter_by( - tenant_id=tenant_id, - name=db_provider_name, - provider=provider_id.provider_name, - plugin_id=provider_id.plugin_id, - auth_type=CredentialType.OAUTH2.value, + session.scalar( + select(func.count(DatasourceProvider.id)).where( + DatasourceProvider.tenant_id == tenant_id, + DatasourceProvider.name == db_provider_name, + DatasourceProvider.provider == provider_id.provider_name, + DatasourceProvider.plugin_id == provider_id.plugin_id, + DatasourceProvider.auth_type == CredentialType.OAUTH2.value, + ) ) - .count() + or 0 ) if name_conflict > 0: db_provider_name = generate_incremental_name( [ provider.name - for provider in session.query(DatasourceProvider).filter_by( - tenant_id=tenant_id, - provider=provider_id.provider_name, - plugin_id=provider_id.plugin_id, - ) + for provider in session.scalars( + select(DatasourceProvider).where( + DatasourceProvider.tenant_id == tenant_id, + DatasourceProvider.provider == provider_id.provider_name, + DatasourceProvider.plugin_id == provider_id.plugin_id, + ) + ).all() ], db_provider_name, ) @@ -534,7 +558,6 @@ class DatasourceProviderService: target_provider.expires_at = expire_at target_provider.encrypted_credentials = credentials target_provider.avatar_url = avatar_url or target_provider.avatar_url - session.commit() def add_datasource_oauth_provider( self, @@ -543,13 +566,13 @@ class DatasourceProviderService: provider_id: DatasourceProviderID, avatar_url: str | None, expire_at: int, - credentials: dict, + credentials: dict[str, Any], ) -> None: """ add datasource oauth provider """ credential_type = CredentialType.OAUTH2 - with Session(db.engine) as session: + with sessionmaker(bind=db.engine).begin() as session: lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{credential_type.value}" with redis_client.lock(lock, timeout=60): db_provider_name = name @@ -562,25 +585,27 @@ class DatasourceProviderService: ) else: if ( - session.query(DatasourceProvider) - .filter_by( - tenant_id=tenant_id, - name=db_provider_name, - provider=provider_id.provider_name, - plugin_id=provider_id.plugin_id, - auth_type=credential_type.value, + session.scalar( + select(func.count(DatasourceProvider.id)).where( + DatasourceProvider.tenant_id == tenant_id, + DatasourceProvider.name == db_provider_name, + DatasourceProvider.provider == provider_id.provider_name, + DatasourceProvider.plugin_id == provider_id.plugin_id, + DatasourceProvider.auth_type == credential_type.value, + ) ) - .count() - > 0 - ): + or 0 + ) > 0: db_provider_name = generate_incremental_name( [ provider.name - for provider in session.query(DatasourceProvider).filter_by( - tenant_id=tenant_id, - provider=provider_id.provider_name, - plugin_id=provider_id.plugin_id, - ) + for provider in session.scalars( + select(DatasourceProvider).where( + DatasourceProvider.tenant_id == tenant_id, + DatasourceProvider.provider == provider_id.provider_name, + DatasourceProvider.plugin_id == provider_id.plugin_id, + ) + ).all() ], db_provider_name, ) @@ -603,14 +628,13 @@ class DatasourceProviderService: expires_at=expire_at, ) session.add(datasource_provider) - session.commit() def add_datasource_api_key_provider( self, name: str | None, tenant_id: str, provider_id: DatasourceProviderID, - credentials: dict, + credentials: dict[str, Any], ) -> None: """ validate datasource provider credentials. @@ -622,7 +646,7 @@ class DatasourceProviderService: provider_name = provider_id.provider_name plugin_id = provider_id.plugin_id - with Session(db.engine) as session: + with sessionmaker(bind=db.engine).begin() as session: lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{CredentialType.API_KEY}" with redis_client.lock(lock, timeout=20): db_provider_name = name or self.generate_next_datasource_provider_name( @@ -634,11 +658,16 @@ class DatasourceProviderService: # check name is exist if ( - session.query(DatasourceProvider) - .filter_by(tenant_id=tenant_id, plugin_id=plugin_id, provider=provider_name, name=db_provider_name) - .count() - > 0 - ): + session.scalar( + select(func.count(DatasourceProvider.id)).where( + DatasourceProvider.tenant_id == tenant_id, + DatasourceProvider.plugin_id == plugin_id, + DatasourceProvider.provider == provider_name, + DatasourceProvider.name == db_provider_name, + ) + ) + or 0 + ) > 0: raise ValueError("Authorization name is already exists") try: @@ -669,7 +698,6 @@ class DatasourceProviderService: encrypted_credentials=credentials, ) session.add(datasource_provider) - session.commit() def extract_secret_variables(self, tenant_id: str, provider_id: str, credential_type: CredentialType) -> list[str]: """ @@ -707,24 +735,27 @@ class DatasourceProviderService: :return: """ # Get all provider configurations of the current workspace - datasource_providers: list[DatasourceProvider] = ( - db.session.query(DatasourceProvider) + datasource_providers: list[DatasourceProvider] = list( + db.session.scalars( + select(DatasourceProvider).where( + DatasourceProvider.tenant_id == tenant_id, + DatasourceProvider.provider == provider, + DatasourceProvider.plugin_id == plugin_id, + ) + ).all() + ) + if not datasource_providers: + return [] + copy_credentials_list = [] + default_provider = db.session.execute( + select(DatasourceProvider.id) .where( DatasourceProvider.tenant_id == tenant_id, DatasourceProvider.provider == provider, DatasourceProvider.plugin_id == plugin_id, ) - .all() - ) - if not datasource_providers: - return [] - copy_credentials_list = [] - default_provider = ( - db.session.query(DatasourceProvider.id) - .filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id) .order_by(DatasourceProvider.is_default.desc(), DatasourceProvider.created_at.asc()) - .first() - ) + ).first() default_provider_id = default_provider.id if default_provider else None for datasource_provider in datasource_providers: encrypted_credentials = datasource_provider.encrypted_credentials @@ -880,14 +911,14 @@ class DatasourceProviderService: :return: """ # Get all provider configurations of the current workspace - datasource_providers: list[DatasourceProvider] = ( - db.session.query(DatasourceProvider) - .where( - DatasourceProvider.tenant_id == tenant_id, - DatasourceProvider.provider == provider, - DatasourceProvider.plugin_id == plugin_id, - ) - .all() + datasource_providers: list[DatasourceProvider] = list( + db.session.scalars( + select(DatasourceProvider).where( + DatasourceProvider.tenant_id == tenant_id, + DatasourceProvider.provider == provider, + DatasourceProvider.plugin_id == plugin_id, + ) + ).all() ) if not datasource_providers: return [] @@ -916,28 +947,44 @@ class DatasourceProviderService: return copy_credentials_list def update_datasource_credentials( - self, tenant_id: str, auth_id: str, provider: str, plugin_id: str, credentials: dict | None, name: str | None + self, + tenant_id: str, + auth_id: str, + provider: str, + plugin_id: str, + credentials: dict[str, Any] | None, + name: str | None, ) -> None: """ update datasource credentials. """ - with Session(db.engine) as session: - datasource_provider = ( - session.query(DatasourceProvider) - .filter_by(tenant_id=tenant_id, id=auth_id, provider=provider, plugin_id=plugin_id) - .first() + with sessionmaker(bind=db.engine).begin() as session: + datasource_provider = session.scalar( + select(DatasourceProvider) + .where( + DatasourceProvider.tenant_id == tenant_id, + DatasourceProvider.id == auth_id, + DatasourceProvider.provider == provider, + DatasourceProvider.plugin_id == plugin_id, + ) + .limit(1) ) if not datasource_provider: raise ValueError("Datasource provider not found") # update name if name and name != datasource_provider.name: if ( - session.query(DatasourceProvider) - .filter_by(tenant_id=tenant_id, name=name, provider=provider, plugin_id=plugin_id) - .count() - > 0 - ): + session.scalar( + select(func.count(DatasourceProvider.id)).where( + DatasourceProvider.tenant_id == tenant_id, + DatasourceProvider.name == name, + DatasourceProvider.provider == provider, + DatasourceProvider.plugin_id == plugin_id, + ) + ) + or 0 + ) > 0: raise ValueError("Authorization name is already exists") datasource_provider.name = name @@ -976,7 +1023,6 @@ class DatasourceProviderService: encrypted_credentials[key] = value datasource_provider.encrypted_credentials = encrypted_credentials - session.commit() def remove_datasource_credentials(self, tenant_id: str, auth_id: str, provider: str, plugin_id: str) -> None: """ @@ -987,10 +1033,15 @@ class DatasourceProviderService: :param plugin_id: plugin id :return: """ - datasource_provider = ( - db.session.query(DatasourceProvider) - .filter_by(tenant_id=tenant_id, id=auth_id, provider=provider, plugin_id=plugin_id) - .first() + datasource_provider = db.session.scalar( + select(DatasourceProvider) + .where( + DatasourceProvider.tenant_id == tenant_id, + DatasourceProvider.id == auth_id, + DatasourceProvider.provider == provider, + DatasourceProvider.plugin_id == plugin_id, + ) + .limit(1) ) if datasource_provider: db.session.delete(datasource_provider) diff --git a/api/services/end_user_service.py b/api/services/end_user_service.py index 326f46780d9..749d8dbc30e 100644 --- a/api/services/end_user_service.py +++ b/api/services/end_user_service.py @@ -1,8 +1,8 @@ import logging from collections.abc import Mapping -from sqlalchemy import case -from sqlalchemy.orm import Session +from sqlalchemy import case, select +from sqlalchemy.orm import sessionmaker from core.app.entities.app_invoke_entities import InvokeFrom from extensions.ext_database import db @@ -24,15 +24,15 @@ class EndUserService: when an end-user ID is known. """ - with Session(db.engine, expire_on_commit=False) as session: - return ( - session.query(EndUser) + with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session: + return session.scalar( + select(EndUser) .where( EndUser.id == end_user_id, EndUser.tenant_id == tenant_id, EndUser.app_id == app_id, ) - .first() + .limit(1) ) @classmethod @@ -54,11 +54,11 @@ class EndUserService: if not user_id: user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID - with Session(db.engine, expire_on_commit=False) as session: + with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session: # Query with ORDER BY to prioritize exact type matches while maintaining backward compatibility # This single query approach is more efficient than separate queries - end_user = ( - session.query(EndUser) + end_user = session.scalar( + select(EndUser) .where( EndUser.tenant_id == tenant_id, EndUser.app_id == app_id, @@ -68,7 +68,7 @@ class EndUserService: # Prioritize records with matching type (0 = match, 1 = no match) case((EndUser.type == type, 0), else_=1) ) - .first() + .limit(1) ) if end_user: @@ -82,7 +82,6 @@ class EndUserService: user_id, ) end_user.type = type - session.commit() else: # Create new end user if none exists end_user = EndUser( @@ -94,7 +93,6 @@ class EndUserService: external_user_id=user_id, ) session.add(end_user) - session.commit() return end_user @@ -135,17 +133,17 @@ class EndUserService: if not unique_app_ids: return result - with Session(db.engine, expire_on_commit=False) as session: + with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session: # Fetch existing end users for all target apps in a single query - existing_end_users: list[EndUser] = ( - session.query(EndUser) - .where( - EndUser.tenant_id == tenant_id, - EndUser.app_id.in_(unique_app_ids), - EndUser.session_id == user_id, - EndUser.type == type, - ) - .all() + existing_end_users: list[EndUser] = list( + session.scalars( + select(EndUser).where( + EndUser.tenant_id == tenant_id, + EndUser.app_id.in_(unique_app_ids), + EndUser.session_id == user_id, + EndUser.type == type, + ) + ).all() ) found_app_ids: set[str] = set() @@ -174,7 +172,6 @@ class EndUserService: ) session.add_all(new_end_users) - session.commit() for eu in new_end_users: result[eu.app_id] = eu diff --git a/api/services/enterprise/plugin_manager_service.py b/api/services/enterprise/plugin_manager_service.py index d4be36305e5..23571f2d7d8 100644 --- a/api/services/enterprise/plugin_manager_service.py +++ b/api/services/enterprise/plugin_manager_service.py @@ -1,23 +1,15 @@ -import enum import logging from pydantic import BaseModel from configs import dify_config +from core.entities import PluginCredentialType from services.enterprise.base import EnterprisePluginManagerRequest from services.errors.base import BaseServiceError logger = logging.getLogger(__name__) -class PluginCredentialType(enum.Enum): - MODEL = 0 # must be 0 for API contract compatibility - TOOL = 1 # must be 1 for API contract compatibility - - def to_number(self): - return self.value - - class CheckCredentialPolicyComplianceRequest(BaseModel): dify_credential_id: str provider: str diff --git a/api/services/entities/auth_entities.py b/api/services/entities/auth_entities.py new file mode 100644 index 00000000000..e3fb2496925 --- /dev/null +++ b/api/services/entities/auth_entities.py @@ -0,0 +1,47 @@ +from enum import StrEnum, auto + +from pydantic import BaseModel, Field, field_validator + +from libs.helper import EmailStr +from libs.password import valid_password + + +class LoginFailureReason(StrEnum): + """Bounded reason codes for failed login audit logs.""" + + ACCOUNT_BANNED = auto() + ACCOUNT_IN_FREEZE = auto() + ACCOUNT_NOT_FOUND = auto() + EMAIL_CODE_EMAIL_MISMATCH = auto() + INVALID_CREDENTIALS = auto() + INVALID_EMAIL_CODE = auto() + INVALID_EMAIL_CODE_TOKEN = auto() + INVALID_INVITATION_EMAIL = auto() + LOGIN_RATE_LIMITED = auto() + + +class LoginPayloadBase(BaseModel): + email: EmailStr + password: str + + +class ForgotPasswordSendPayload(BaseModel): + email: EmailStr + language: str | None = None + + +class ForgotPasswordCheckPayload(BaseModel): + email: EmailStr + code: str + token: str = Field(min_length=1) + + +class ForgotPasswordResetPayload(BaseModel): + token: str = Field(min_length=1) + new_password: str + password_confirm: str + + @field_validator("new_password", "password_confirm") + @classmethod + def validate_password(cls, value: str) -> str: + return valid_password(value) diff --git a/api/services/entities/dsl_entities.py b/api/services/entities/dsl_entities.py new file mode 100644 index 00000000000..05baa51fbd3 --- /dev/null +++ b/api/services/entities/dsl_entities.py @@ -0,0 +1,21 @@ +from enum import StrEnum + +from pydantic import BaseModel, Field + +from core.plugin.entities.plugin import PluginDependency + + +class ImportMode(StrEnum): + YAML_CONTENT = "yaml-content" + YAML_URL = "yaml-url" + + +class ImportStatus(StrEnum): + COMPLETED = "completed" + COMPLETED_WITH_WARNINGS = "completed-with-warnings" + PENDING = "pending" + FAILED = "failed" + + +class CheckDependenciesResult(BaseModel): + leaked_dependencies: list[PluginDependency] = Field(default_factory=list) diff --git a/api/services/entities/external_knowledge_entities/external_knowledge_entities.py b/api/services/entities/external_knowledge_entities/external_knowledge_entities.py index c9fb1c9e215..110dbe5a5e7 100644 --- a/api/services/entities/external_knowledge_entities/external_knowledge_entities.py +++ b/api/services/entities/external_knowledge_entities/external_knowledge_entities.py @@ -1,4 +1,4 @@ -from typing import Literal, Union +from typing import Any, Literal, Union from pydantic import BaseModel @@ -22,5 +22,5 @@ class ProcessStatusSetting(BaseModel): class ExternalKnowledgeApiSetting(BaseModel): url: str request_method: str - headers: dict | None = None - params: dict | None = None + headers: dict[str, Any] | None = None + params: dict[str, Any] | None = None diff --git a/api/services/entities/knowledge_entities/knowledge_entities.py b/api/services/entities/knowledge_entities/knowledge_entities.py index 66309f0e597..b1fe3528616 100644 --- a/api/services/entities/knowledge_entities/knowledge_entities.py +++ b/api/services/entities/knowledge_entities/knowledge_entities.py @@ -1,15 +1,15 @@ -from enum import StrEnum -from typing import Literal +from typing import Any, Literal from pydantic import BaseModel, field_validator +from core.rag.entities import Rule from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.retrieval.retrieval_methods import RetrievalMethod -class ParentMode(StrEnum): - FULL_DOC = "full-doc" - PARAGRAPH = "paragraph" +class RerankingModel(BaseModel): + reranking_provider_name: str | None = None + reranking_model_name: str | None = None class NotionIcon(BaseModel): @@ -53,34 +53,11 @@ class DataSource(BaseModel): info_list: InfoList -class PreProcessingRule(BaseModel): - id: str - enabled: bool - - -class Segmentation(BaseModel): - separator: str = "\n" - max_tokens: int - chunk_overlap: int = 0 - - -class Rule(BaseModel): - pre_processing_rules: list[PreProcessingRule] | None = None - segmentation: Segmentation | None = None - parent_mode: Literal["full-doc", "paragraph"] | None = None - subchunk_segmentation: Segmentation | None = None - - class ProcessRule(BaseModel): mode: Literal["automatic", "custom", "hierarchical"] rules: Rule | None = None -class RerankingModel(BaseModel): - reranking_provider_name: str | None = None - reranking_model_name: str | None = None - - class WeightVectorSetting(BaseModel): vector_weight: float embedding_provider_name: str @@ -110,7 +87,7 @@ class RetrievalModel(BaseModel): class MetaDataConfig(BaseModel): doc_type: str - doc_metadata: dict + doc_metadata: dict[str, Any] class KnowledgeConfig(BaseModel): @@ -120,7 +97,7 @@ class KnowledgeConfig(BaseModel): data_source: DataSource | None = None process_rule: ProcessRule | None = None retrieval_model: RetrievalModel | None = None - summary_index_setting: dict | None = None + summary_index_setting: dict[str, Any] | None = None doc_form: str = "text_model" doc_language: str = "English" embedding_model: str | None = None diff --git a/api/services/entities/knowledge_entities/rag_pipeline_entities.py b/api/services/entities/knowledge_entities/rag_pipeline_entities.py index 041ae4edba2..7fb7ed12bf6 100644 --- a/api/services/entities/knowledge_entities/rag_pipeline_entities.py +++ b/api/services/entities/knowledge_entities/rag_pipeline_entities.py @@ -1,10 +1,29 @@ -from typing import Literal +from typing import Any, Literal from pydantic import BaseModel, field_validator +from core.rag.entities import KeywordSetting, VectorSetting from core.rag.retrieval.retrieval_methods import RetrievalMethod +class RerankingModelConfig(BaseModel): + """ + Reranking Model Config. + """ + + reranking_provider_name: str | None = "" + reranking_model_name: str | None = "" + + +class WeightedScoreConfig(BaseModel): + """ + Weighted score Config. + """ + + vector_setting: VectorSetting | None + keyword_setting: KeywordSetting | None + + class IconInfo(BaseModel): icon: str icon_background: str | None = None @@ -27,59 +46,6 @@ class RagPipelineDatasetCreateEntity(BaseModel): yaml_content: str | None = None -class RerankingModelConfig(BaseModel): - """ - Reranking Model Config. - """ - - reranking_provider_name: str | None = "" - reranking_model_name: str | None = "" - - -class VectorSetting(BaseModel): - """ - Vector Setting. - """ - - vector_weight: float - embedding_provider_name: str - embedding_model_name: str - - -class KeywordSetting(BaseModel): - """ - Keyword Setting. - """ - - keyword_weight: float - - -class WeightedScoreConfig(BaseModel): - """ - Weighted score Config. - """ - - vector_setting: VectorSetting | None - keyword_setting: KeywordSetting | None - - -class EmbeddingSetting(BaseModel): - """ - Embedding Setting. - """ - - embedding_provider_name: str - embedding_model_name: str - - -class EconomySetting(BaseModel): - """ - Economy Setting. - """ - - keyword_number: int - - class RetrievalSetting(BaseModel): """ Retrieval Setting. @@ -95,16 +61,6 @@ class RetrievalSetting(BaseModel): weights: WeightedScoreConfig | None = None -class IndexMethod(BaseModel): - """ - Knowledge Index Setting. - """ - - indexing_technique: Literal["high_quality", "economy"] - embedding_setting: EmbeddingSetting - economy_setting: EconomySetting - - class KnowledgeConfiguration(BaseModel): """ Knowledge Base Configuration. @@ -117,7 +73,7 @@ class KnowledgeConfiguration(BaseModel): keyword_number: int | None = 10 retrieval_model: RetrievalSetting # add summary index setting - summary_index_setting: dict | None = None + summary_index_setting: dict[str, Any] | None = None @field_validator("embedding_model_provider", mode="before") @classmethod diff --git a/api/services/entities/model_provider_entities.py b/api/services/entities/model_provider_entities.py index a944ef6acdd..6679c08ebd7 100644 --- a/api/services/entities/model_provider_entities.py +++ b/api/services/entities/model_provider_entities.py @@ -1,15 +1,6 @@ from collections.abc import Sequence from enum import StrEnum -from graphon.model_runtime.entities.common_entities import I18nObject -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.entities.provider_entities import ( - ConfigurateMethod, - ModelCredentialSchema, - ProviderCredentialSchema, - ProviderHelpEntity, - SimpleProviderEntity, -) from pydantic import BaseModel, ConfigDict, model_validator from configs import dify_config @@ -24,6 +15,15 @@ from core.entities.provider_entities import ( QuotaConfiguration, UnaddedModelConfiguration, ) +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.entities.provider_entities import ( + ConfigurateMethod, + ModelCredentialSchema, + ProviderCredentialSchema, + ProviderHelpEntity, + SimpleProviderEntity, +) from models.provider import ProviderType diff --git a/api/services/external_knowledge_service.py b/api/services/external_knowledge_service.py index 9a522ece528..60b457ecd03 100644 --- a/api/services/external_knowledge_service.py +++ b/api/services/external_knowledge_service.py @@ -1,16 +1,16 @@ import json from copy import deepcopy -from typing import Any, Union, cast +from typing import Any, cast from urllib.parse import urlparse import httpx -from graphon.nodes.http_request.exc import InvalidHttpMethodError -from sqlalchemy import select +from sqlalchemy import func, select from constants import HIDDEN_VALUE from core.helper import ssrf_proxy -from core.rag.entities.metadata_entities import MetadataCondition +from core.rag.entities import MetadataFilteringCondition from extensions.ext_database import db +from graphon.nodes.http_request.exc import InvalidHttpMethodError from libs.datetime_utils import naive_utc_now from models.dataset import ( Dataset, @@ -47,7 +47,7 @@ class ExternalDatasetService: return external_knowledge_apis.items, external_knowledge_apis.total @classmethod - def validate_api_list(cls, api_settings: dict): + def validate_api_list(cls, api_settings: dict[str, Any]): if not api_settings: raise ValueError("api list is empty") if not api_settings.get("endpoint"): @@ -56,7 +56,7 @@ class ExternalDatasetService: raise ValueError("api_key is required") @staticmethod - def create_external_knowledge_api(tenant_id: str, user_id: str, args: dict) -> ExternalKnowledgeApis: + def create_external_knowledge_api(tenant_id: str, user_id: str, args: dict[str, Any]) -> ExternalKnowledgeApis: settings = args.get("settings") if settings is None: raise ValueError("settings is required") @@ -75,7 +75,7 @@ class ExternalDatasetService: return external_knowledge_api @staticmethod - def check_endpoint_and_api_key(settings: dict): + def check_endpoint_and_api_key(settings: dict[str, Any]): if "endpoint" not in settings or not settings["endpoint"]: raise ValueError("endpoint is required") if "api_key" not in settings or not settings["api_key"]: @@ -103,8 +103,10 @@ class ExternalDatasetService: @staticmethod def get_external_knowledge_api(external_knowledge_api_id: str, tenant_id: str) -> ExternalKnowledgeApis: - external_knowledge_api: ExternalKnowledgeApis | None = ( - db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id, tenant_id=tenant_id).first() + external_knowledge_api: ExternalKnowledgeApis | None = db.session.scalar( + select(ExternalKnowledgeApis) + .where(ExternalKnowledgeApis.id == external_knowledge_api_id, ExternalKnowledgeApis.tenant_id == tenant_id) + .limit(1) ) if external_knowledge_api is None: raise ValueError("api template not found") @@ -112,8 +114,10 @@ class ExternalDatasetService: @staticmethod def update_external_knowledge_api(tenant_id, user_id, external_knowledge_api_id, args) -> ExternalKnowledgeApis: - external_knowledge_api: ExternalKnowledgeApis | None = ( - db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id, tenant_id=tenant_id).first() + external_knowledge_api: ExternalKnowledgeApis | None = db.session.scalar( + select(ExternalKnowledgeApis) + .where(ExternalKnowledgeApis.id == external_knowledge_api_id, ExternalKnowledgeApis.tenant_id == tenant_id) + .limit(1) ) if external_knowledge_api is None: raise ValueError("api template not found") @@ -132,8 +136,10 @@ class ExternalDatasetService: @staticmethod def delete_external_knowledge_api(tenant_id: str, external_knowledge_api_id: str): - external_knowledge_api = ( - db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id, tenant_id=tenant_id).first() + external_knowledge_api = db.session.scalar( + select(ExternalKnowledgeApis) + .where(ExternalKnowledgeApis.id == external_knowledge_api_id, ExternalKnowledgeApis.tenant_id == tenant_id) + .limit(1) ) if external_knowledge_api is None: raise ValueError("api template not found") @@ -142,29 +148,43 @@ class ExternalDatasetService: db.session.commit() @staticmethod - def external_knowledge_api_use_check(external_knowledge_api_id: str) -> tuple[bool, int]: + def external_knowledge_api_use_check(external_knowledge_api_id: str, tenant_id: str) -> tuple[bool, int]: + """ + Return usage for an external knowledge API within a single tenant. + + The caller already scopes access by tenant, so this query must do the + same; otherwise the endpoint becomes a cross-tenant UUID oracle. + """ count = ( - db.session.query(ExternalKnowledgeBindings) - .filter_by(external_knowledge_api_id=external_knowledge_api_id) - .count() + db.session.scalar( + select(func.count(ExternalKnowledgeBindings.id)).where( + ExternalKnowledgeBindings.external_knowledge_api_id == external_knowledge_api_id, + ExternalKnowledgeBindings.tenant_id == tenant_id, + ) + ) + or 0 ) - if count > 0: - return True, count - return False, 0 + return count > 0, count @staticmethod def get_external_knowledge_binding_with_dataset_id(tenant_id: str, dataset_id: str) -> ExternalKnowledgeBindings: - external_knowledge_binding: ExternalKnowledgeBindings | None = ( - db.session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id, tenant_id=tenant_id).first() + external_knowledge_binding: ExternalKnowledgeBindings | None = db.session.scalar( + select(ExternalKnowledgeBindings) + .where(ExternalKnowledgeBindings.dataset_id == dataset_id, ExternalKnowledgeBindings.tenant_id == tenant_id) + .limit(1) ) if not external_knowledge_binding: raise ValueError("external knowledge binding not found") return external_knowledge_binding @staticmethod - def document_create_args_validate(tenant_id: str, external_knowledge_api_id: str, process_parameter: dict): - external_knowledge_api = ( - db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id, tenant_id=tenant_id).first() + def document_create_args_validate( + tenant_id: str, external_knowledge_api_id: str, process_parameter: dict[str, Any] + ): + external_knowledge_api = db.session.scalar( + select(ExternalKnowledgeApis) + .where(ExternalKnowledgeApis.id == external_knowledge_api_id, ExternalKnowledgeApis.tenant_id == tenant_id) + .limit(1) ) if external_knowledge_api is None or external_knowledge_api.settings is None: raise ValueError("api template not found") @@ -177,9 +197,7 @@ class ExternalDatasetService: raise ValueError(f"{parameter.get('name')} is required") @staticmethod - def process_external_api( - settings: ExternalKnowledgeApiSetting, files: Union[None, dict[str, Any]] - ) -> httpx.Response: + def process_external_api(settings: ExternalKnowledgeApiSetting, files: dict[str, Any] | None) -> httpx.Response: """ do http request depending on api bundle """ @@ -206,7 +224,7 @@ class ExternalDatasetService: return response @staticmethod - def assembling_headers(authorization: Authorization, headers: dict | None = None) -> dict[str, Any]: + def assembling_headers(authorization: Authorization, headers: dict[str, Any] | None = None) -> dict[str, Any]: authorization = deepcopy(authorization) if headers: headers = deepcopy(headers) @@ -232,18 +250,23 @@ class ExternalDatasetService: return headers @staticmethod - def get_external_knowledge_api_settings(settings: dict) -> ExternalKnowledgeApiSetting: + def get_external_knowledge_api_settings(settings: dict[str, Any]) -> ExternalKnowledgeApiSetting: return ExternalKnowledgeApiSetting.model_validate(settings) @staticmethod - def create_external_dataset(tenant_id: str, user_id: str, args: dict) -> Dataset: + def create_external_dataset(tenant_id: str, user_id: str, args: dict[str, Any]) -> Dataset: # check if dataset name already exists - if db.session.query(Dataset).filter_by(name=args.get("name"), tenant_id=tenant_id).first(): + if db.session.scalar( + select(Dataset).where(Dataset.name == args.get("name"), Dataset.tenant_id == tenant_id).limit(1) + ): raise DatasetNameDuplicateError(f"Dataset with name {args.get('name')} already exists.") - external_knowledge_api = ( - db.session.query(ExternalKnowledgeApis) - .filter_by(id=args.get("external_knowledge_api_id"), tenant_id=tenant_id) - .first() + external_knowledge_api = db.session.scalar( + select(ExternalKnowledgeApis) + .where( + ExternalKnowledgeApis.id == args.get("external_knowledge_api_id"), + ExternalKnowledgeApis.tenant_id == tenant_id, + ) + .limit(1) ) if external_knowledge_api is None: @@ -283,19 +306,24 @@ class ExternalDatasetService: tenant_id: str, dataset_id: str, query: str, - external_retrieval_parameters: dict, - metadata_condition: MetadataCondition | None = None, + external_retrieval_parameters: dict[str, Any], + metadata_condition: MetadataFilteringCondition | None = None, ): - external_knowledge_binding = ( - db.session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id, tenant_id=tenant_id).first() + external_knowledge_binding = db.session.scalar( + select(ExternalKnowledgeBindings) + .where(ExternalKnowledgeBindings.dataset_id == dataset_id, ExternalKnowledgeBindings.tenant_id == tenant_id) + .limit(1) ) if not external_knowledge_binding: raise ValueError("external knowledge binding not found") - external_knowledge_api = ( - db.session.query(ExternalKnowledgeApis) - .filter_by(id=external_knowledge_binding.external_knowledge_api_id) - .first() + external_knowledge_api = db.session.scalar( + select(ExternalKnowledgeApis) + .where( + ExternalKnowledgeApis.id == external_knowledge_binding.external_knowledge_api_id, + ExternalKnowledgeApis.tenant_id == tenant_id, + ) + .limit(1) ) if external_knowledge_api is None or external_knowledge_api.settings is None: raise ValueError("external api template not found") diff --git a/api/services/feature_service.py b/api/services/feature_service.py index f38e1762d11..e18eb096c96 100644 --- a/api/services/feature_service.py +++ b/api/services/feature_service.py @@ -3,6 +3,7 @@ from enum import StrEnum from pydantic import BaseModel, ConfigDict, Field from configs import dify_config +from constants.dsl_version import CURRENT_APP_DSL_VERSION from enums.cloud_plan import CloudPlan from enums.hosted_provider import HostedTrialProvider from services.billing_service import BillingService @@ -157,6 +158,7 @@ class PluginManagerModel(BaseModel): class SystemFeatureModel(BaseModel): + app_dsl_version: str = "" sso_enforced_for_signin: bool = False sso_enforced_for_signin_protocol: str = "" enable_marketplace: bool = False @@ -164,6 +166,7 @@ class SystemFeatureModel(BaseModel): enable_email_code_login: bool = False enable_email_password_login: bool = True enable_social_oauth_login: bool = False + enable_collaboration_mode: bool = False is_allow_register: bool = False is_allow_create_workspace: bool = False is_email_setup: bool = False @@ -224,6 +227,7 @@ class FeatureService: @classmethod def get_system_features(cls, is_authenticated: bool = False) -> SystemFeatureModel: system_features = SystemFeatureModel() + system_features.app_dsl_version = CURRENT_APP_DSL_VERSION cls._fulfill_system_params_from_env(system_features) @@ -244,6 +248,7 @@ class FeatureService: system_features.enable_email_code_login = dify_config.ENABLE_EMAIL_CODE_LOGIN system_features.enable_email_password_login = dify_config.ENABLE_EMAIL_PASSWORD_LOGIN system_features.enable_social_oauth_login = dify_config.ENABLE_SOCIAL_OAUTH_LOGIN + system_features.enable_collaboration_mode = dify_config.ENABLE_COLLABORATION_MODE system_features.is_allow_register = dify_config.ALLOW_REGISTER system_features.is_allow_create_workspace = dify_config.ALLOW_CREATE_WORKSPACE system_features.is_email_setup = dify_config.MAIL_TYPE is not None and dify_config.MAIL_TYPE != "" @@ -312,7 +317,10 @@ class FeatureService: features.apps.limit = billing_info["apps"]["limit"] if "vector_space" in billing_info: - features.vector_space.size = billing_info["vector_space"]["size"] + # NOTE (hj24): billing API returns vector_space.size as float (e.g. 0.0) + # but LimitationModel.size is int; truncate here for compatibility + features.vector_space.size = int(billing_info["vector_space"]["size"]) + # NOTE END features.vector_space.limit = billing_info["vector_space"]["limit"] if "documents_upload_quota" in billing_info: @@ -333,7 +341,11 @@ class FeatureService: features.model_load_balancing_enabled = billing_info["model_load_balancing_enabled"] if "knowledge_rate_limit" in billing_info: + # NOTE (hj24): + # 1. knowledge_rate_limit size is nullable, currently it's defined but never used, only limit is used. + # 2. So be careful if later we decide to use [size], we cannot assume it is always present. features.knowledge_rate_limit = billing_info["knowledge_rate_limit"]["limit"] + # NOTE END if "knowledge_pipeline_publish_enabled" in billing_info: features.knowledge_pipeline.publish_enabled = billing_info["knowledge_pipeline_publish_enabled"] diff --git a/api/services/file_service.py b/api/services/file_service.py index 50a326d8138..f60afe2f19d 100644 --- a/api/services/file_service.py +++ b/api/services/file_service.py @@ -2,13 +2,12 @@ import base64 import hashlib import os import uuid -from collections.abc import Iterator, Sequence +from collections.abc import Generator, Sequence # Changed Iterator to Generator from contextlib import contextmanager, suppress from tempfile import NamedTemporaryFile -from typing import Literal, Union +from typing import Literal from zipfile import ZIP_DEFLATED, ZipFile -from graphon.file import helpers as file_helpers from sqlalchemy import Engine, select from sqlalchemy.orm import Session, sessionmaker from werkzeug.exceptions import NotFound @@ -24,6 +23,7 @@ from core.rag.extractor.extract_processor import ExtractProcessor from extensions.ext_database import db from extensions.ext_storage import storage from extensions.storage.storage_type import StorageType +from graphon.file import helpers as file_helpers from libs.datetime_utils import naive_utc_now from libs.helper import extract_tenant_id from models import Account @@ -52,7 +52,7 @@ class FileService: filename: str, content: bytes, mimetype: str, - user: Union[Account, EndUser], + user: Account | EndUser, source: Literal["datasets"] | None = None, source_url: str = "", ) -> UploadFile: @@ -132,8 +132,8 @@ class FileService: return file_size <= file_size_limit def get_file_base64(self, file_id: str) -> str: - upload_file = ( - self._session_maker(expire_on_commit=False).query(UploadFile).where(UploadFile.id == file_id).first() + upload_file = self._session_maker(expire_on_commit=False).scalar( + select(UploadFile).where(UploadFile.id == file_id).limit(1) ) if not upload_file: raise NotFound("File not found") @@ -178,7 +178,7 @@ class FileService: Return a short text preview extracted from a document file. """ with self._session_maker(expire_on_commit=False) as session: - upload_file = session.query(UploadFile).where(UploadFile.id == file_id).first() + upload_file = session.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1)) if not upload_file: raise NotFound("File not found") @@ -200,7 +200,7 @@ class FileService: if not result: raise NotFound("File not found or signature is invalid") with self._session_maker(expire_on_commit=False) as session: - upload_file = session.query(UploadFile).where(UploadFile.id == file_id).first() + upload_file = session.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1)) if not upload_file: raise NotFound("File not found or signature is invalid") @@ -220,7 +220,7 @@ class FileService: raise NotFound("File not found or signature is invalid") with self._session_maker(expire_on_commit=False) as session: - upload_file = session.query(UploadFile).where(UploadFile.id == file_id).first() + upload_file = session.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1)) if not upload_file: raise NotFound("File not found or signature is invalid") @@ -231,7 +231,7 @@ class FileService: def get_public_image_preview(self, file_id: str): with self._session_maker(expire_on_commit=False) as session: - upload_file = session.query(UploadFile).where(UploadFile.id == file_id).first() + upload_file = session.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1)) if not upload_file: raise NotFound("File not found or signature is invalid") @@ -247,7 +247,7 @@ class FileService: def get_file_content(self, file_id: str) -> str: with self._session_maker(expire_on_commit=False) as session: - upload_file: UploadFile | None = session.query(UploadFile).where(UploadFile.id == file_id).first() + upload_file: UploadFile | None = session.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1)) if not upload_file: raise NotFound("File not found") @@ -324,7 +324,7 @@ class FileService: def build_upload_files_zip_tempfile( *, upload_files: Sequence[UploadFile], - ) -> Iterator[str]: + ) -> Generator[str, None, None]: # Changed from Iterator[str] """ Build a ZIP from `UploadFile`s and yield a tempfile path. diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py index 82e0b0f8b1f..2e5987dd28c 100644 --- a/api/services/hit_testing_service.py +++ b/api/services/hit_testing_service.py @@ -1,23 +1,32 @@ import json import logging import time -from typing import Any - -from graphon.model_runtime.entities import LLMMode +from typing import Any, TypedDict, cast from core.app.app_config.entities import ModelConfig -from core.rag.datasource.retrieval_service import RetrievalService +from core.rag.datasource.retrieval_service import DefaultRetrievalModelDict, RetrievalService from core.rag.index_processor.constant.query_type import QueryType from core.rag.models.document import Document from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.retrieval_methods import RetrievalMethod from extensions.ext_database import db +from graphon.model_runtime.entities import LLMMode from models import Account from models.dataset import Dataset, DatasetQuery from models.enums import CreatorUserRole, DatasetQuerySource logger = logging.getLogger(__name__) + +class QueryDict(TypedDict): + content: str + + +class RetrieveResponseDict(TypedDict): + query: QueryDict + records: list[dict[str, Any]] + + default_retrieval_model = { "search_method": RetrievalMethod.SEMANTIC_SEARCH, "reranking_enable": False, @@ -27,6 +36,10 @@ default_retrieval_model = { } +class HitTestingRetrievalModelDict(DefaultRetrievalModelDict, total=False): + metadata_filtering_conditions: dict[str, Any] + + class HitTestingService: @classmethod def retrieve( @@ -34,24 +47,26 @@ class HitTestingService: dataset: Dataset, query: str, account: Account, - retrieval_model: Any, # FIXME drop this any - external_retrieval_model: dict, + retrieval_model: dict[str, Any] | None, + external_retrieval_model: dict[str, Any], attachment_ids: list | None = None, limit: int = 10, ): start = time.perf_counter() # get retrieval model , if the model is not setting , using default - if not retrieval_model: - retrieval_model = dataset.retrieval_model or default_retrieval_model + resolved_retrieval_model = cast( + HitTestingRetrievalModelDict, + retrieval_model or dataset.retrieval_model or default_retrieval_model, + ) document_ids_filter = None - metadata_filtering_conditions = retrieval_model.get("metadata_filtering_conditions", {}) - if metadata_filtering_conditions and query: + metadata_filtering_conditions_raw = resolved_retrieval_model.get("metadata_filtering_conditions", {}) + if metadata_filtering_conditions_raw and query: dataset_retrieval = DatasetRetrieval() - from core.app.app_config.entities import MetadataFilteringCondition + from core.rag.entities import MetadataFilteringCondition - metadata_filtering_conditions = MetadataFilteringCondition.model_validate(metadata_filtering_conditions) + metadata_filtering_conditions = MetadataFilteringCondition.model_validate(metadata_filtering_conditions_raw) metadata_filter_document_ids, metadata_condition = dataset_retrieval.get_metadata_filter_condition( dataset_ids=[dataset.id], @@ -68,19 +83,21 @@ class HitTestingService: if metadata_condition and not document_ids_filter: return cls.compact_retrieve_response(query, []) all_documents = RetrievalService.retrieve( - retrieval_method=RetrievalMethod(retrieval_model.get("search_method", RetrievalMethod.SEMANTIC_SEARCH)), + retrieval_method=RetrievalMethod( + resolved_retrieval_model.get("search_method", RetrievalMethod.SEMANTIC_SEARCH) + ), dataset_id=dataset.id, query=query, attachment_ids=attachment_ids, - top_k=retrieval_model.get("top_k", 4), - score_threshold=retrieval_model.get("score_threshold", 0.0) - if retrieval_model["score_threshold_enabled"] + top_k=resolved_retrieval_model.get("top_k", 4), + score_threshold=resolved_retrieval_model.get("score_threshold", 0.0) + if resolved_retrieval_model["score_threshold_enabled"] else 0.0, - reranking_model=retrieval_model.get("reranking_model", None) - if retrieval_model["reranking_enable"] + reranking_model=resolved_retrieval_model.get("reranking_model", None) + if resolved_retrieval_model["reranking_enable"] else None, - reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model", - weights=retrieval_model.get("weights", None), + reranking_mode=resolved_retrieval_model.get("reranking_mode") or "reranking_model", + weights=resolved_retrieval_model.get("weights", None), document_ids_filter=document_ids_filter, ) @@ -114,8 +131,8 @@ class HitTestingService: dataset: Dataset, query: str, account: Account, - external_retrieval_model: dict | None = None, - metadata_filtering_conditions: dict | None = None, + external_retrieval_model: dict[str, Any] | None = None, + metadata_filtering_conditions: dict[str, Any] | None = None, ): if dataset.provider != "external": return { @@ -150,7 +167,7 @@ class HitTestingService: return dict(cls.compact_external_retrieve_response(dataset, query, all_documents)) @classmethod - def compact_retrieve_response(cls, query: str, documents: list[Document]) -> dict[Any, Any]: + def compact_retrieve_response(cls, query: str, documents: list[Document]) -> RetrieveResponseDict: records = RetrievalService.format_retrieval_documents(documents) return { @@ -161,7 +178,7 @@ class HitTestingService: } @classmethod - def compact_external_retrieve_response(cls, dataset: Dataset, query: str, documents: list) -> dict[Any, Any]: + def compact_external_retrieve_response(cls, dataset: Dataset, query: str, documents: list) -> RetrieveResponseDict: records = [] if dataset.provider == "external": for document in documents: diff --git a/api/services/human_input_delivery_test_service.py b/api/services/human_input_delivery_test_service.py index 77576fa4c0d..8b4983e5f70 100644 --- a/api/services/human_input_delivery_test_service.py +++ b/api/services/human_input_delivery_test_service.py @@ -4,12 +4,11 @@ from dataclasses import dataclass, field from enum import StrEnum from typing import Protocol -from graphon.runtime import VariablePool from sqlalchemy import Engine, select from sqlalchemy.orm import sessionmaker from configs import dify_config -from core.workflow.human_input_compat import ( +from core.workflow.human_input_adapter import ( DeliveryChannelConfig, EmailDeliveryConfig, EmailDeliveryMethod, @@ -18,6 +17,7 @@ from core.workflow.human_input_compat import ( ) from extensions.ext_database import db from extensions.ext_mail import mail +from graphon.runtime import VariablePool from libs.email_template_renderer import render_email_template from models import Account, TenantAccountJoin from services.feature_service import FeatureService diff --git a/api/services/human_input_service.py b/api/services/human_input_service.py index 02a6620fc74..76598d31ace 100644 --- a/api/services/human_input_service.py +++ b/api/services/human_input_service.py @@ -3,12 +3,6 @@ from collections.abc import Mapping from datetime import datetime, timedelta from typing import Any -from graphon.nodes.human_input.entities import ( - FormDefinition, - HumanInputSubmissionValidationError, - validate_human_input_submission, -) -from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from sqlalchemy import Engine, select from sqlalchemy.orm import Session, sessionmaker @@ -17,6 +11,12 @@ from core.repositories.human_input_repository import ( HumanInputFormRecord, HumanInputFormSubmissionRepository, ) +from graphon.nodes.human_input.entities import ( + FormDefinition, + HumanInputSubmissionValidationError, + validate_human_input_submission, +) +from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from libs.datetime_utils import ensure_naive_utc, naive_utc_now from libs.exception import BaseHTTPException from models.human_input import RecipientType diff --git a/api/services/message_service.py b/api/services/message_service.py index 5b133b0c04a..98f24dd6a68 100644 --- a/api/services/message_service.py +++ b/api/services/message_service.py @@ -1,6 +1,5 @@ from collections.abc import Sequence -from graphon.model_runtime.entities.model_entities import ModelType from pydantic import TypeAdapter from sqlalchemy import select from sqlalchemy.orm import sessionmaker @@ -14,6 +13,7 @@ from core.ops.entities.trace_entity import TraceTaskName from core.ops.ops_trace_manager import TraceQueueManager, TraceTask from core.ops.utils import measure_time from extensions.ext_database import db +from graphon.model_runtime.entities.model_entities import ModelType from libs.infinite_scroll_pagination import InfiniteScrollPagination from models import Account from models.enums import FeedbackFromSource, FeedbackRating diff --git a/api/services/metadata_service.py b/api/services/metadata_service.py index 12729278cc4..672f309bac0 100644 --- a/api/services/metadata_service.py +++ b/api/services/metadata_service.py @@ -1,6 +1,8 @@ import copy import logging +from sqlalchemy import delete, func, select + from core.rag.index_processor.constant.built_in_field import BuiltInField, MetadataDataSource from extensions.ext_database import db from extensions.ext_redis import redis_client @@ -25,10 +27,14 @@ class MetadataService: raise ValueError("Metadata name cannot exceed 255 characters.") current_user, current_tenant_id = current_account_with_tenant() # check if metadata name already exists - if ( - db.session.query(DatasetMetadata) - .filter_by(tenant_id=current_tenant_id, dataset_id=dataset_id, name=metadata_args.name) - .first() + if db.session.scalar( + select(DatasetMetadata) + .where( + DatasetMetadata.tenant_id == current_tenant_id, + DatasetMetadata.dataset_id == dataset_id, + DatasetMetadata.name == metadata_args.name, + ) + .limit(1) ): raise ValueError("Metadata name already exists.") for field in BuiltInField: @@ -54,10 +60,14 @@ class MetadataService: lock_key = f"dataset_metadata_lock_{dataset_id}" # check if metadata name already exists current_user, current_tenant_id = current_account_with_tenant() - if ( - db.session.query(DatasetMetadata) - .filter_by(tenant_id=current_tenant_id, dataset_id=dataset_id, name=name) - .first() + if db.session.scalar( + select(DatasetMetadata) + .where( + DatasetMetadata.tenant_id == current_tenant_id, + DatasetMetadata.dataset_id == dataset_id, + DatasetMetadata.name == name, + ) + .limit(1) ): raise ValueError("Metadata name already exists.") for field in BuiltInField: @@ -65,7 +75,11 @@ class MetadataService: raise ValueError("Metadata name already exists in Built-in fields.") try: MetadataService.knowledge_base_metadata_lock_check(dataset_id, None) - metadata = db.session.query(DatasetMetadata).filter_by(id=metadata_id, dataset_id=dataset_id).first() + metadata = db.session.scalar( + select(DatasetMetadata) + .where(DatasetMetadata.id == metadata_id, DatasetMetadata.dataset_id == dataset_id) + .limit(1) + ) if metadata is None: raise ValueError("Metadata not found.") old_name = metadata.name @@ -74,9 +88,9 @@ class MetadataService: metadata.updated_at = naive_utc_now() # update related documents - dataset_metadata_bindings = ( - db.session.query(DatasetMetadataBinding).filter_by(metadata_id=metadata_id).all() - ) + dataset_metadata_bindings = db.session.scalars( + select(DatasetMetadataBinding).where(DatasetMetadataBinding.metadata_id == metadata_id) + ).all() if dataset_metadata_bindings: document_ids = [binding.document_id for binding in dataset_metadata_bindings] documents = DocumentService.get_document_by_ids(document_ids) @@ -101,15 +115,19 @@ class MetadataService: lock_key = f"dataset_metadata_lock_{dataset_id}" try: MetadataService.knowledge_base_metadata_lock_check(dataset_id, None) - metadata = db.session.query(DatasetMetadata).filter_by(id=metadata_id, dataset_id=dataset_id).first() + metadata = db.session.scalar( + select(DatasetMetadata) + .where(DatasetMetadata.id == metadata_id, DatasetMetadata.dataset_id == dataset_id) + .limit(1) + ) if metadata is None: raise ValueError("Metadata not found.") db.session.delete(metadata) # deal related documents - dataset_metadata_bindings = ( - db.session.query(DatasetMetadataBinding).filter_by(metadata_id=metadata_id).all() - ) + dataset_metadata_bindings = db.session.scalars( + select(DatasetMetadataBinding).where(DatasetMetadataBinding.metadata_id == metadata_id) + ).all() if dataset_metadata_bindings: document_ids = [binding.document_id for binding in dataset_metadata_bindings] documents = DocumentService.get_document_by_ids(document_ids) @@ -224,16 +242,23 @@ class MetadataService: # deal metadata binding (in the same transaction as the doc_metadata update) if not operation.partial_update: - db.session.query(DatasetMetadataBinding).filter_by(document_id=operation.document_id).delete() + db.session.execute( + delete(DatasetMetadataBinding).where( + DatasetMetadataBinding.document_id == operation.document_id + ) + ) current_user, current_tenant_id = current_account_with_tenant() for metadata_value in operation.metadata_list: # check if binding already exists if operation.partial_update: - existing_binding = ( - db.session.query(DatasetMetadataBinding) - .filter_by(document_id=operation.document_id, metadata_id=metadata_value.id) - .first() + existing_binding = db.session.scalar( + select(DatasetMetadataBinding) + .where( + DatasetMetadataBinding.document_id == operation.document_id, + DatasetMetadataBinding.metadata_id == metadata_value.id, + ) + .limit(1) ) if existing_binding: continue @@ -275,9 +300,13 @@ class MetadataService: "id": item.get("id"), "name": item.get("name"), "type": item.get("type"), - "count": db.session.query(DatasetMetadataBinding) - .filter_by(metadata_id=item.get("id"), dataset_id=dataset.id) - .count(), + "count": db.session.scalar( + select(func.count(DatasetMetadataBinding.id)).where( + DatasetMetadataBinding.metadata_id == item.get("id"), + DatasetMetadataBinding.dataset_id == dataset.id, + ) + ) + or 0, } for item in dataset.doc_metadata or [] if item.get("id") != "built-in" diff --git a/api/services/model_load_balancing_service.py b/api/services/model_load_balancing_service.py index bc0bfd215cf..c269346f5f9 100644 --- a/api/services/model_load_balancing_service.py +++ b/api/services/model_load_balancing_service.py @@ -1,13 +1,7 @@ import json import logging -from typing import Any, Union +from typing import Any, TypedDict -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.entities.provider_entities import ( - ModelCredentialSchema, - ProviderCredentialSchema, -) -from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from sqlalchemy import or_, select from constants import HIDDEN_VALUE @@ -18,6 +12,12 @@ from core.model_manager import LBModelManager from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly, create_plugin_provider_manager from core.provider_manager import ProviderManager from extensions.ext_database import db +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.entities.provider_entities import ( + ModelCredentialSchema, + ProviderCredentialSchema, +) +from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from libs.datetime_utils import naive_utc_now from models.enums import CredentialSourceType from models.provider import LoadBalancingModelConfig, ProviderCredential, ProviderModelCredential @@ -25,6 +25,23 @@ from models.provider import LoadBalancingModelConfig, ProviderCredential, Provid logger = logging.getLogger(__name__) +class LoadBalancingConfigDetailDict(TypedDict): + id: str + name: str + credentials: dict[str, Any] + enabled: bool + + +class LoadBalancingConfigSummaryDict(TypedDict): + id: str + name: str + credentials: dict[str, Any] + credential_id: str | None + enabled: bool + in_cooldown: bool + ttl: int + + class ModelLoadBalancingService: @staticmethod def _get_provider_manager(tenant_id: str) -> ProviderManager: @@ -74,7 +91,7 @@ class ModelLoadBalancingService: def get_load_balancing_configs( self, tenant_id: str, provider: str, model: str, model_type: str, config_from: str = "" - ) -> tuple[bool, list[dict]]: + ) -> tuple[bool, list[LoadBalancingConfigSummaryDict]]: """ Get load balancing configurations. :param tenant_id: workspace id @@ -156,7 +173,7 @@ class ModelLoadBalancingService: decoding_rsa_key, decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id) # fetch status and ttl for each config - datas = [] + datas: list[LoadBalancingConfigSummaryDict] = [] for load_balancing_config in load_balancing_configs: in_cooldown, ttl = LBModelManager.get_config_in_cooldown_and_ttl( tenant_id=tenant_id, @@ -214,7 +231,7 @@ class ModelLoadBalancingService: def get_load_balancing_config( self, tenant_id: str, provider: str, model: str, model_type: str, config_id: str - ) -> dict | None: + ) -> LoadBalancingConfigDetailDict | None: """ Get load balancing configuration. :param tenant_id: workspace id @@ -267,12 +284,13 @@ class ModelLoadBalancingService: credentials=credentials, credential_form_schemas=credential_schemas.credential_form_schemas ) - return { + result: LoadBalancingConfigDetailDict = { "id": load_balancing_model_config.id, "name": load_balancing_model_config.name, "credentials": credentials, "enabled": load_balancing_model_config.enabled, } + return result def _init_inherit_config( self, tenant_id: str, provider: str, model: str, model_type: ModelType @@ -484,7 +502,7 @@ class ModelLoadBalancingService: provider: str, model: str, model_type: str, - credentials: dict, + credentials: dict[str, Any], config_id: str | None = None, ): """ @@ -543,7 +561,7 @@ class ModelLoadBalancingService: provider_configuration: ProviderConfiguration, model_type: ModelType, model: str, - credentials: dict, + credentials: dict[str, Any], load_balancing_model_config: LoadBalancingModelConfig | None = None, model_provider_factory: ModelProviderFactory | None = None, validate: bool = True, @@ -608,7 +626,7 @@ class ModelLoadBalancingService: def _get_credential_schema( self, provider_configuration: ProviderConfiguration - ) -> Union[ModelCredentialSchema, ProviderCredentialSchema]: + ) -> ModelCredentialSchema | ProviderCredentialSchema: """Get form schemas.""" if provider_configuration.provider.model_credential_schema: return provider_configuration.provider.model_credential_schema diff --git a/api/services/model_provider_service.py b/api/services/model_provider_service.py index 3f37c9b176d..51cda79661a 100644 --- a/api/services/model_provider_service.py +++ b/api/services/model_provider_service.py @@ -1,10 +1,10 @@ import logging - -from graphon.model_runtime.entities.model_entities import ModelType, ParameterRule +from typing import Any from core.entities.model_entities import ModelWithProviderEntity, ProviderModelWithStatusEntity from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory, create_plugin_provider_manager from core.provider_manager import ProviderManager +from graphon.model_runtime.entities.model_entities import ModelType, ParameterRule from models.provider import ProviderType from services.entities.model_provider_entities import ( CustomConfigurationResponse, @@ -168,7 +168,9 @@ class ModelProviderService: model_name=model, ) - def get_provider_credential(self, tenant_id: str, provider: str, credential_id: str | None = None) -> dict | None: + def get_provider_credential( + self, tenant_id: str, provider: str, credential_id: str | None = None + ) -> dict[str, Any] | None: """ get provider credentials. @@ -180,7 +182,7 @@ class ModelProviderService: provider_configuration = self._get_provider_configuration(tenant_id, provider) return provider_configuration.get_provider_credential(credential_id=credential_id) - def validate_provider_credentials(self, tenant_id: str, provider: str, credentials: dict): + def validate_provider_credentials(self, tenant_id: str, provider: str, credentials: dict[str, Any]): """ validate provider credentials before saving. @@ -192,7 +194,7 @@ class ModelProviderService: provider_configuration.validate_provider_credentials(credentials) def create_provider_credential( - self, tenant_id: str, provider: str, credentials: dict, credential_name: str | None + self, tenant_id: str, provider: str, credentials: dict[str, Any], credential_name: str | None ) -> None: """ Create and save new provider credentials. @@ -210,7 +212,7 @@ class ModelProviderService: self, tenant_id: str, provider: str, - credentials: dict, + credentials: dict[str, Any], credential_id: str, credential_name: str | None, ) -> None: @@ -254,7 +256,7 @@ class ModelProviderService: def get_model_credential( self, tenant_id: str, provider: str, model_type: str, model: str, credential_id: str | None - ) -> dict | None: + ) -> dict[str, Any] | None: """ Retrieve model-specific credentials. @@ -270,7 +272,9 @@ class ModelProviderService: model_type=ModelType.value_of(model_type), model=model, credential_id=credential_id ) - def validate_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict): + def validate_model_credentials( + self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict[str, Any] + ): """ validate model credentials. @@ -287,7 +291,13 @@ class ModelProviderService: ) def create_model_credential( - self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict, credential_name: str | None + self, + tenant_id: str, + provider: str, + model_type: str, + model: str, + credentials: dict[str, Any], + credential_name: str | None, ) -> None: """ create and save model credentials. @@ -314,7 +324,7 @@ class ModelProviderService: provider: str, model_type: str, model: str, - credentials: dict, + credentials: dict[str, Any], credential_id: str, credential_name: str | None, ) -> None: diff --git a/api/services/oauth_server.py b/api/services/oauth_server.py index b05b43d76ec..22648070f01 100644 --- a/api/services/oauth_server.py +++ b/api/services/oauth_server.py @@ -2,7 +2,7 @@ import enum import uuid from sqlalchemy import select -from sqlalchemy.orm import Session +from sqlalchemy.orm import sessionmaker from werkzeug.exceptions import BadRequest from extensions.ext_database import db @@ -29,7 +29,7 @@ class OAuthServerService: def get_oauth_provider_app(client_id: str) -> OAuthProviderApp | None: query = select(OAuthProviderApp).where(OAuthProviderApp.client_id == client_id) - with Session(db.engine) as session: + with sessionmaker(db.engine, expire_on_commit=False).begin() as session: return session.execute(query).scalar_one_or_none() @staticmethod diff --git a/api/services/operation_service.py b/api/services/operation_service.py index c05e9d555c3..903efd26ae7 100644 --- a/api/services/operation_service.py +++ b/api/services/operation_service.py @@ -1,8 +1,22 @@ import os +from typing import TypedDict import httpx +class UtmInfo(TypedDict, total=False): + """Expected shape of the utm_info dict passed to record_utm. + + All fields are optional; missing keys default to an empty string. + """ + + utm_source: str + utm_medium: str + utm_campaign: str + utm_content: str + utm_term: str + + class OperationService: base_url = os.environ.get("BILLING_API_URL", "BILLING_API_URL") secret_key = os.environ.get("BILLING_API_SECRET_KEY", "BILLING_API_SECRET_KEY") @@ -17,7 +31,7 @@ class OperationService: return response.json() @classmethod - def record_utm(cls, tenant_id: str, utm_info: dict): + def record_utm(cls, tenant_id: str, utm_info: UtmInfo): params = { "tenant_id": tenant_id, "utm_source": utm_info.get("utm_source", ""), diff --git a/api/services/ops_service.py b/api/services/ops_service.py index 0db3d3efec4..3ad42faf249 100644 --- a/api/services/ops_service.py +++ b/api/services/ops_service.py @@ -1,3 +1,5 @@ +from typing import Any + from sqlalchemy import select from core.ops.entities.config_entity import BaseTracingConfig @@ -135,7 +137,7 @@ class OpsService: return trace_config_data.to_dict() @classmethod - def create_tracing_app_config(cls, app_id: str, tracing_provider: str, tracing_config: dict): + def create_tracing_app_config(cls, app_id: str, tracing_provider: str, tracing_config: dict[str, Any]): """ Create tracing app config :param app_id: app id @@ -210,7 +212,7 @@ class OpsService: return {"result": "success"} @classmethod - def update_tracing_app_config(cls, app_id: str, tracing_provider: str, tracing_config: dict): + def update_tracing_app_config(cls, app_id: str, tracing_provider: str, tracing_config: dict[str, Any]): """ Update tracing app config :param app_id: app id diff --git a/api/services/plugin/endpoint_service.py b/api/services/plugin/endpoint_service.py index 11b8e0a3d91..1727cd7abd6 100644 --- a/api/services/plugin/endpoint_service.py +++ b/api/services/plugin/endpoint_service.py @@ -1,9 +1,13 @@ +from typing import Any + from core.plugin.impl.endpoint import PluginEndpointClient class EndpointService: @classmethod - def create_endpoint(cls, tenant_id: str, user_id: str, plugin_unique_identifier: str, name: str, settings: dict): + def create_endpoint( + cls, tenant_id: str, user_id: str, plugin_unique_identifier: str, name: str, settings: dict[str, Any] + ): return PluginEndpointClient().create_endpoint( tenant_id=tenant_id, user_id=user_id, @@ -32,7 +36,7 @@ class EndpointService: ) @classmethod - def update_endpoint(cls, tenant_id: str, user_id: str, endpoint_id: str, name: str, settings: dict): + def update_endpoint(cls, tenant_id: str, user_id: str, endpoint_id: str, name: str, settings: dict[str, Any]): return PluginEndpointClient().update_endpoint( tenant_id=tenant_id, user_id=user_id, diff --git a/api/services/plugin/oauth_service.py b/api/services/plugin/oauth_service.py index 88dec062a03..789b5fa5b77 100644 --- a/api/services/plugin/oauth_service.py +++ b/api/services/plugin/oauth_service.py @@ -1,5 +1,6 @@ import json import uuid +from typing import Any from core.plugin.impl.base import BasePluginClient from extensions.ext_redis import redis_client @@ -16,7 +17,7 @@ class OAuthProxyService(BasePluginClient): tenant_id: str, plugin_id: str, provider: str, - extra_data: dict = {}, + extra_data: dict[str, Any] = {}, credential_id: str | None = None, ): """ diff --git a/api/services/plugin/plugin_auto_upgrade_service.py b/api/services/plugin/plugin_auto_upgrade_service.py index 174bed488da..b96b8140acd 100644 --- a/api/services/plugin/plugin_auto_upgrade_service.py +++ b/api/services/plugin/plugin_auto_upgrade_service.py @@ -1,17 +1,17 @@ -from sqlalchemy.orm import Session +from sqlalchemy import select -from extensions.ext_database import db +from core.db.session_factory import session_factory from models.account import TenantPluginAutoUpgradeStrategy class PluginAutoUpgradeService: @staticmethod def get_strategy(tenant_id: str) -> TenantPluginAutoUpgradeStrategy | None: - with Session(db.engine) as session: - return ( - session.query(TenantPluginAutoUpgradeStrategy) + with session_factory.create_session() as session: + return session.scalar( + select(TenantPluginAutoUpgradeStrategy) .where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id) - .first() + .limit(1) ) @staticmethod @@ -23,11 +23,11 @@ class PluginAutoUpgradeService: exclude_plugins: list[str], include_plugins: list[str], ) -> bool: - with Session(db.engine) as session: - exist_strategy = ( - session.query(TenantPluginAutoUpgradeStrategy) + with session_factory.create_session() as session, session.begin(): + exist_strategy = session.scalar( + select(TenantPluginAutoUpgradeStrategy) .where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id) - .first() + .limit(1) ) if not exist_strategy: strategy = TenantPluginAutoUpgradeStrategy( @@ -46,16 +46,15 @@ class PluginAutoUpgradeService: exist_strategy.exclude_plugins = exclude_plugins exist_strategy.include_plugins = include_plugins - session.commit() return True @staticmethod def exclude_plugin(tenant_id: str, plugin_id: str) -> bool: - with Session(db.engine) as session: - exist_strategy = ( - session.query(TenantPluginAutoUpgradeStrategy) + with session_factory.create_session() as session, session.begin(): + exist_strategy = session.scalar( + select(TenantPluginAutoUpgradeStrategy) .where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id) - .first() + .limit(1) ) if not exist_strategy: # create for this tenant @@ -83,5 +82,4 @@ class PluginAutoUpgradeService: exist_strategy.upgrade_mode = TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE exist_strategy.exclude_plugins = [plugin_id] - session.commit() return True diff --git a/api/services/plugin/plugin_migration.py b/api/services/plugin/plugin_migration.py index 442ccef1da6..43a726b1000 100644 --- a/api/services/plugin/plugin_migration.py +++ b/api/services/plugin/plugin_migration.py @@ -5,7 +5,7 @@ import time from collections.abc import Mapping, Sequence from concurrent.futures import ThreadPoolExecutor from pathlib import Path -from typing import Any, TypedDict +from typing import TypedDict from uuid import uuid4 import click @@ -13,6 +13,7 @@ import sqlalchemy as sa import tqdm from flask import Flask, current_app from pydantic import TypeAdapter +from sqlalchemy import func, select from sqlalchemy.orm import Session from core.agent.entities import AgentToolEntity @@ -42,6 +43,16 @@ class _TenantPluginRecord(TypedDict): _tenant_plugin_adapter: TypeAdapter[_TenantPluginRecord] = TypeAdapter(_TenantPluginRecord) +class ExtractedPluginsDict(TypedDict): + plugins: dict[str, str] + plugin_not_exist: list[str] + + +class PluginInstallResultDict(TypedDict): + success: list[str] + failed: list[str] + + class PluginMigration: @classmethod def extract_plugins(cls, filepath: str, workers: int): @@ -56,7 +67,7 @@ class PluginMigration: current_time = started_at with Session(db.engine) as session: - total_tenant_count = session.query(Tenant.id).count() + total_tenant_count = session.scalar(select(func.count(Tenant.id))) or 0 click.echo(click.style(f"Total tenant count: {total_tenant_count}", fg="white")) @@ -113,9 +124,12 @@ class PluginMigration: tenant_count = 0 for test_interval in test_intervals: tenant_count = ( - session.query(Tenant.id) - .where(Tenant.created_at.between(current_time, current_time + test_interval)) - .count() + session.scalar( + select(func.count(Tenant.id)).where( + Tenant.created_at.between(current_time, current_time + test_interval) + ) + ) + or 0 ) if tenant_count <= 100: interval = test_interval @@ -137,8 +151,8 @@ class PluginMigration: batch_end = min(current_time + interval, ended_at) - rs = ( - session.query(Tenant.id) + rs = session.execute( + select(Tenant.id) .where(Tenant.created_at.between(current_time, batch_end)) .order_by(Tenant.created_at) ) @@ -225,7 +239,7 @@ class PluginMigration: Extract tool tables. """ with Session(db.engine) as session: - rs = session.query(BuiltinToolProvider).where(BuiltinToolProvider.tenant_id == tenant_id).all() + rs = session.scalars(select(BuiltinToolProvider).where(BuiltinToolProvider.tenant_id == tenant_id)).all() result = [] for row in rs: result.append(ToolProviderID(row.provider).plugin_id) @@ -239,7 +253,7 @@ class PluginMigration: """ with Session(db.engine) as session: - rs = session.query(Workflow).where(Workflow.tenant_id == tenant_id).all() + rs = session.scalars(select(Workflow).where(Workflow.tenant_id == tenant_id)).all() result = [] for row in rs: graph = row.graph_dict @@ -262,7 +276,7 @@ class PluginMigration: Extract app tables. """ with Session(db.engine) as session: - apps = session.query(App).where(App.tenant_id == tenant_id).all() + apps = session.scalars(select(App).where(App.tenant_id == tenant_id)).all() if not apps: return [] @@ -270,7 +284,7 @@ class PluginMigration: app.app_model_config_id for app in apps if app.is_agent or app.mode == AppMode.AGENT_CHAT ] - rs = session.query(AppModelConfig).where(AppModelConfig.id.in_(agent_app_model_config_ids)).all() + rs = session.scalars(select(AppModelConfig).where(AppModelConfig.id.in_(agent_app_model_config_ids))).all() result = [] for row in rs: agent_config = row.agent_mode_dict @@ -310,7 +324,7 @@ class PluginMigration: Path(output_file).write_text(json.dumps(cls.extract_unique_plugins(extracted_plugins))) @classmethod - def extract_unique_plugins(cls, extracted_plugins: str) -> Mapping[str, Any]: + def extract_unique_plugins(cls, extracted_plugins: str) -> ExtractedPluginsDict: plugins: dict[str, str] = {} plugin_ids = [] plugin_not_exist = [] @@ -524,7 +538,7 @@ class PluginMigration: @classmethod def handle_plugin_instance_install( cls, tenant_id: str, plugin_identifiers_map: Mapping[str, str] - ) -> Mapping[str, Any]: + ) -> PluginInstallResultDict: """ Install plugins for a tenant. """ diff --git a/api/services/plugin/plugin_parameter_service.py b/api/services/plugin/plugin_parameter_service.py index 40565c56ed9..786c09b44e1 100644 --- a/api/services/plugin/plugin_parameter_service.py +++ b/api/services/plugin/plugin_parameter_service.py @@ -1,6 +1,7 @@ from collections.abc import Mapping, Sequence from typing import Any, Literal +from sqlalchemy import select from sqlalchemy.orm import Session from core.plugin.entities.parameters import PluginParameterOption @@ -56,24 +57,24 @@ class PluginParameterService: # fetch credentials from db with Session(db.engine) as session: if credential_id: - db_record = ( - session.query(BuiltinToolProvider) + db_record = session.scalar( + select(BuiltinToolProvider) .where( BuiltinToolProvider.tenant_id == tenant_id, BuiltinToolProvider.provider == provider, BuiltinToolProvider.id == credential_id, ) - .first() + .limit(1) ) else: - db_record = ( - session.query(BuiltinToolProvider) + db_record = session.scalar( + select(BuiltinToolProvider) .where( BuiltinToolProvider.tenant_id == tenant_id, BuiltinToolProvider.provider == provider, ) .order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc()) - .first() + .limit(1) ) if db_record is None: diff --git a/api/services/plugin/plugin_permission_service.py b/api/services/plugin/plugin_permission_service.py index 60fa2696401..3cca4268d00 100644 --- a/api/services/plugin/plugin_permission_service.py +++ b/api/services/plugin/plugin_permission_service.py @@ -1,14 +1,16 @@ -from sqlalchemy.orm import Session +from sqlalchemy import select -from extensions.ext_database import db +from core.db.session_factory import session_factory from models.account import TenantPluginPermission class PluginPermissionService: @staticmethod def get_permission(tenant_id: str) -> TenantPluginPermission | None: - with Session(db.engine) as session: - return session.query(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id).first() + with session_factory.create_session() as session: + return session.scalar( + select(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id).limit(1) + ) @staticmethod def change_permission( @@ -16,9 +18,9 @@ class PluginPermissionService: install_permission: TenantPluginPermission.InstallPermission, debug_permission: TenantPluginPermission.DebugPermission, ): - with Session(db.engine) as session: - permission = ( - session.query(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id).first() + with session_factory.create_session() as session, session.begin(): + permission = session.scalar( + select(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id).limit(1) ) if not permission: permission = TenantPluginPermission( @@ -30,5 +32,4 @@ class PluginPermissionService: permission.install_permission = install_permission permission.debug_permission = debug_permission - session.commit() return True diff --git a/api/services/rag_pipeline/pipeline_generate_service.py b/api/services/rag_pipeline/pipeline_generate_service.py index 10e89b1dbab..56bc7859589 100644 --- a/api/services/rag_pipeline/pipeline_generate_service.py +++ b/api/services/rag_pipeline/pipeline_generate_service.py @@ -1,5 +1,5 @@ from collections.abc import Mapping -from typing import Any, Union +from typing import Any from configs import dify_config from core.app.apps.pipeline.pipeline_generator import PipelineGenerator @@ -17,7 +17,7 @@ class PipelineGenerateService: def generate( cls, pipeline: Pipeline, - user: Union[Account, EndUser], + user: Account | EndUser, args: Mapping[str, Any], invoke_from: InvokeFrom, streaming: bool = True, diff --git a/api/services/rag_pipeline/pipeline_template/built_in/built_in_retrieval.py b/api/services/rag_pipeline/pipeline_template/built_in/built_in_retrieval.py index 24baeb73b5b..8c9a81af877 100644 --- a/api/services/rag_pipeline/pipeline_template/built_in/built_in_retrieval.py +++ b/api/services/rag_pipeline/pipeline_template/built_in/built_in_retrieval.py @@ -1,6 +1,7 @@ import json from os import path from pathlib import Path +from typing import Any from flask import current_app @@ -13,21 +14,21 @@ class BuiltInPipelineTemplateRetrieval(PipelineTemplateRetrievalBase): Retrieval pipeline template from built-in, the location is constants/pipeline_templates.json """ - builtin_data: dict | None = None + builtin_data: dict[str, Any] | None = None def get_type(self) -> str: return PipelineTemplateType.BUILTIN - def get_pipeline_templates(self, language: str) -> dict: + def get_pipeline_templates(self, language: str) -> dict[str, Any]: result = self.fetch_pipeline_templates_from_builtin(language) return result - def get_pipeline_template_detail(self, template_id: str): + def get_pipeline_template_detail(self, template_id: str) -> dict[str, Any] | None: result = self.fetch_pipeline_template_detail_from_builtin(template_id) return result @classmethod - def _get_builtin_data(cls) -> dict: + def _get_builtin_data(cls) -> dict[str, Any]: """ Get builtin data. :return: @@ -43,21 +44,21 @@ class BuiltInPipelineTemplateRetrieval(PipelineTemplateRetrievalBase): return cls.builtin_data or {} @classmethod - def fetch_pipeline_templates_from_builtin(cls, language: str) -> dict: + def fetch_pipeline_templates_from_builtin(cls, language: str) -> dict[str, Any]: """ Fetch pipeline templates from builtin. :param language: language :return: """ - builtin_data: dict[str, dict[str, dict]] = cls._get_builtin_data() + builtin_data: dict[str, dict[str, dict[str, Any]]] = cls._get_builtin_data() return builtin_data.get("pipeline_templates", {}).get(language, {}) @classmethod - def fetch_pipeline_template_detail_from_builtin(cls, template_id: str) -> dict | None: + def fetch_pipeline_template_detail_from_builtin(cls, template_id: str) -> dict[str, Any] | None: """ Fetch pipeline template detail from builtin. :param template_id: Template ID :return: """ - builtin_data: dict[str, dict[str, dict]] = cls._get_builtin_data() + builtin_data: dict[str, dict[str, dict[str, Any]]] = cls._get_builtin_data() return builtin_data.get("pipeline_templates", {}).get(template_id) diff --git a/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py b/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py index 2ee871a2663..9d446f6d4b0 100644 --- a/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py +++ b/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py @@ -1,3 +1,5 @@ +from typing import Any, TypedDict + import yaml from sqlalchemy import select @@ -8,25 +10,47 @@ from services.rag_pipeline.pipeline_template.pipeline_template_base import Pipel from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType +class CustomizedTemplateItemDict(TypedDict): + id: str + name: str + description: str + icon: dict[str, Any] + position: int + chunk_structure: str + + +class CustomizedTemplatesResultDict(TypedDict): + pipeline_templates: list[CustomizedTemplateItemDict] + + +class CustomizedTemplateDetailDict(TypedDict): + id: str + name: str + icon_info: dict[str, Any] + description: str + chunk_structure: str + export_data: str + graph: dict[str, Any] + created_by: str + + class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase): """ Retrieval recommended app from database """ - def get_pipeline_templates(self, language: str) -> dict: + def get_pipeline_templates(self, language: str) -> dict[str, Any]: _, current_tenant_id = current_account_with_tenant() - result = self.fetch_pipeline_templates_from_customized(tenant_id=current_tenant_id, language=language) - return result + return self.fetch_pipeline_templates_from_customized(tenant_id=current_tenant_id, language=language) - def get_pipeline_template_detail(self, template_id: str): - result = self.fetch_pipeline_template_detail_from_db(template_id) - return result + def get_pipeline_template_detail(self, template_id: str) -> dict[str, Any] | None: + return self.fetch_pipeline_template_detail_from_db(template_id) def get_type(self) -> str: return PipelineTemplateType.CUSTOMIZED @classmethod - def fetch_pipeline_templates_from_customized(cls, tenant_id: str, language: str) -> dict: + def fetch_pipeline_templates_from_customized(cls, tenant_id: str, language: str) -> dict[str, Any]: """ Fetch pipeline templates from db. :param tenant_id: tenant id @@ -38,9 +62,9 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase): .where(PipelineCustomizedTemplate.tenant_id == tenant_id, PipelineCustomizedTemplate.language == language) .order_by(PipelineCustomizedTemplate.position.asc(), PipelineCustomizedTemplate.created_at.desc()) ).all() - recommended_pipelines_results = [] + recommended_pipelines_results: list[CustomizedTemplateItemDict] = [] for pipeline_customized_template in pipeline_customized_templates: - recommended_pipeline_result = { + recommended_pipeline_result: CustomizedTemplateItemDict = { "id": pipeline_customized_template.id, "name": pipeline_customized_template.name, "description": pipeline_customized_template.description, @@ -53,7 +77,7 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase): return {"pipeline_templates": recommended_pipelines_results} @classmethod - def fetch_pipeline_template_detail_from_db(cls, template_id: str) -> dict | None: + def fetch_pipeline_template_detail_from_db(cls, template_id: str) -> dict[str, Any] | None: """ Fetch pipeline template detail from db. :param template_id: Template ID diff --git a/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py b/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py index 43b21a7b320..2964537c350 100644 --- a/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py +++ b/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py @@ -1,3 +1,5 @@ +from typing import Any, TypedDict + import yaml from sqlalchemy import select @@ -7,24 +9,47 @@ from services.rag_pipeline.pipeline_template.pipeline_template_base import Pipel from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType +class PipelineTemplateItemDict(TypedDict): + id: str + name: str + description: str + icon: dict[str, Any] + copyright: str + privacy_policy: str + position: int + chunk_structure: str + + +class PipelineTemplatesResultDict(TypedDict): + pipeline_templates: list[PipelineTemplateItemDict] + + +class PipelineTemplateDetailDict(TypedDict): + id: str + name: str + icon_info: dict[str, Any] + description: str + chunk_structure: str + export_data: str + graph: dict[str, Any] + + class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): """ Retrieval pipeline template from database """ - def get_pipeline_templates(self, language: str) -> dict: - result = self.fetch_pipeline_templates_from_db(language) - return result + def get_pipeline_templates(self, language: str) -> dict[str, Any]: + return self.fetch_pipeline_templates_from_db(language) - def get_pipeline_template_detail(self, template_id: str): - result = self.fetch_pipeline_template_detail_from_db(template_id) - return result + def get_pipeline_template_detail(self, template_id: str) -> dict[str, Any] | None: + return self.fetch_pipeline_template_detail_from_db(template_id) def get_type(self) -> str: return PipelineTemplateType.DATABASE @classmethod - def fetch_pipeline_templates_from_db(cls, language: str) -> dict: + def fetch_pipeline_templates_from_db(cls, language: str) -> dict[str, Any]: """ Fetch pipeline templates from db. :param language: language @@ -37,9 +62,9 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): ).all() ) - recommended_pipelines_results = [] + recommended_pipelines_results: list[PipelineTemplateItemDict] = [] for pipeline_built_in_template in pipeline_built_in_templates: - recommended_pipeline_result = { + recommended_pipeline_result: PipelineTemplateItemDict = { "id": pipeline_built_in_template.id, "name": pipeline_built_in_template.name, "description": pipeline_built_in_template.description, @@ -54,7 +79,7 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): return {"pipeline_templates": recommended_pipelines_results} @classmethod - def fetch_pipeline_template_detail_from_db(cls, template_id: str) -> dict | None: + def fetch_pipeline_template_detail_from_db(cls, template_id: str) -> dict[str, Any] | None: """ Fetch pipeline template detail from db. :param pipeline_id: Pipeline ID diff --git a/api/services/rag_pipeline/pipeline_template/pipeline_template_base.py b/api/services/rag_pipeline/pipeline_template/pipeline_template_base.py index 21c30a49865..0ed2a4b8f28 100644 --- a/api/services/rag_pipeline/pipeline_template/pipeline_template_base.py +++ b/api/services/rag_pipeline/pipeline_template/pipeline_template_base.py @@ -1,15 +1,16 @@ from abc import ABC, abstractmethod +from typing import Any class PipelineTemplateRetrievalBase(ABC): """Interface for pipeline template retrieval.""" @abstractmethod - def get_pipeline_templates(self, language: str) -> dict: + def get_pipeline_templates(self, language: str) -> dict[str, Any]: raise NotImplementedError @abstractmethod - def get_pipeline_template_detail(self, template_id: str) -> dict | None: + def get_pipeline_template_detail(self, template_id: str) -> dict[str, Any] | None: raise NotImplementedError @abstractmethod diff --git a/api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py b/api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py index f996db11dc2..9565ac46cc1 100644 --- a/api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py +++ b/api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py @@ -1,4 +1,5 @@ import logging +from typing import Any import httpx @@ -15,28 +16,25 @@ class RemotePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): Retrieval recommended app from dify official """ - def get_pipeline_template_detail(self, template_id: str) -> dict | None: - result: dict | None + def get_pipeline_template_detail(self, template_id: str) -> dict[str, Any] | None: try: - result = self.fetch_pipeline_template_detail_from_dify_official(template_id) + return self.fetch_pipeline_template_detail_from_dify_official(template_id) except Exception as e: logger.warning("fetch recommended app detail from dify official failed: %r, switch to database.", e) - result = DatabasePipelineTemplateRetrieval.fetch_pipeline_template_detail_from_db(template_id) - return result + return DatabasePipelineTemplateRetrieval.fetch_pipeline_template_detail_from_db(template_id) - def get_pipeline_templates(self, language: str) -> dict: + def get_pipeline_templates(self, language: str) -> dict[str, Any]: try: - result = self.fetch_pipeline_templates_from_dify_official(language) + return self.fetch_pipeline_templates_from_dify_official(language) except Exception as e: logger.warning("fetch pipeline templates from dify official failed: %r, switch to database.", e) - result = DatabasePipelineTemplateRetrieval.fetch_pipeline_templates_from_db(language) - return result + return DatabasePipelineTemplateRetrieval.fetch_pipeline_templates_from_db(language) def get_type(self) -> str: return PipelineTemplateType.REMOTE @classmethod - def fetch_pipeline_template_detail_from_dify_official(cls, template_id: str) -> dict: + def fetch_pipeline_template_detail_from_dify_official(cls, template_id: str) -> dict[str, Any]: """ Fetch pipeline template detail from dify official. @@ -53,11 +51,11 @@ class RemotePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): + f" status_code: {response.status_code}," + f" response: {response.text[:1000]}" ) - data: dict = response.json() + data: dict[str, Any] = response.json() return data @classmethod - def fetch_pipeline_templates_from_dify_official(cls, language: str) -> dict: + def fetch_pipeline_templates_from_dify_official(cls, language: str) -> dict[str, Any]: """ Fetch pipeline templates from dify official. :param language: language @@ -69,6 +67,6 @@ class RemotePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): if response.status_code != 200: raise ValueError(f"fetch pipeline templates failed, status code: {response.status_code}") - result: dict = response.json() + result: dict[str, Any] = response.json() return result diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index bcf5973d7b2..9db6682e10f 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -5,19 +5,10 @@ import threading import time from collections.abc import Callable, Generator, Mapping, Sequence from datetime import UTC, datetime -from typing import Any, Union, cast +from typing import Any, cast from uuid import uuid4 from flask_login import current_user -from graphon.entities import WorkflowNodeExecution -from graphon.enums import BuiltinNodeTypes, ErrorStrategy, NodeType, WorkflowNodeExecutionStatus -from graphon.errors import WorkflowNodeRunFailedError -from graphon.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunSucceededEvent -from graphon.node_events import NodeRunResult -from graphon.nodes.base.node import Node -from graphon.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, build_http_request_config -from graphon.runtime import VariablePool -from graphon.variables.variables import Variable, VariableBase from sqlalchemy import func, select from sqlalchemy.orm import Session, sessionmaker @@ -38,11 +29,7 @@ from core.datasource.online_document.online_document_plugin import OnlineDocumen from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin from core.helper import marketplace -from core.rag.entities.event import ( - DatasourceCompletedEvent, - DatasourceErrorEvent, - DatasourceProcessingEvent, -) +from core.rag.entities import DatasourceCompletedEvent, DatasourceErrorEvent, DatasourceProcessingEvent from core.repositories.factory import DifyCoreRepositoryFactory, OrderConfig from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository from core.workflow.node_factory import LATEST_VERSION, get_node_type_classes_mapping @@ -57,6 +44,15 @@ from core.workflow.variable_pool_initializer import add_variables_to_pool from core.workflow.workflow_entry import WorkflowEntry from enterprise.telemetry.draft_trace import enqueue_draft_node_execution_trace from extensions.ext_database import db +from graphon.entities import WorkflowNodeExecution +from graphon.enums import BuiltinNodeTypes, ErrorStrategy, NodeType, WorkflowNodeExecutionStatus +from graphon.errors import WorkflowNodeRunFailedError +from graphon.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunSucceededEvent +from graphon.node_events import NodeRunResult +from graphon.nodes.base.node import Node +from graphon.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, build_http_request_config +from graphon.runtime import VariablePool +from graphon.variables.variables import Variable, VariableBase from libs.infinite_scroll_pagination import InfiniteScrollPagination from models import Account from models.dataset import ( # type: ignore @@ -108,7 +104,7 @@ class RagPipelineService: self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker) @classmethod - def get_pipeline_templates(cls, type: str = "built-in", language: str = "en-US") -> dict: + def get_pipeline_templates(cls, type: str = "built-in", language: str = "en-US") -> dict[str, Any]: if type == "built-in": mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)() @@ -124,7 +120,7 @@ class RagPipelineService: return result @classmethod - def get_pipeline_template_detail(cls, template_id: str, type: str = "built-in") -> dict | None: + def get_pipeline_template_detail(cls, template_id: str, type: str = "built-in") -> dict[str, Any] | None: """ Get pipeline template detail. @@ -135,7 +131,7 @@ class RagPipelineService: if type == "built-in": mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)() - built_in_result: dict | None = retrieval_instance.get_pipeline_template_detail(template_id) + built_in_result: dict[str, Any] | None = retrieval_instance.get_pipeline_template_detail(template_id) if built_in_result is None: logger.warning( "pipeline template retrieval returned empty result, template_id: %s, mode: %s", @@ -146,7 +142,7 @@ class RagPipelineService: else: mode = "customized" retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)() - customized_result: dict | None = retrieval_instance.get_pipeline_template_detail(template_id) + customized_result: dict[str, Any] | None = retrieval_instance.get_pipeline_template_detail(template_id) return customized_result @classmethod @@ -156,27 +152,27 @@ class RagPipelineService: :param template_id: template id :param template_info: template info """ - customized_template: PipelineCustomizedTemplate | None = ( - db.session.query(PipelineCustomizedTemplate) + customized_template: PipelineCustomizedTemplate | None = db.session.scalar( + select(PipelineCustomizedTemplate) .where( PipelineCustomizedTemplate.id == template_id, PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id, ) - .first() + .limit(1) ) if not customized_template: raise ValueError("Customized pipeline template not found.") # check template name is exist template_name = template_info.name if template_name: - template = ( - db.session.query(PipelineCustomizedTemplate) + template = db.session.scalar( + select(PipelineCustomizedTemplate) .where( PipelineCustomizedTemplate.name == template_name, PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id, PipelineCustomizedTemplate.id != template_id, ) - .first() + .limit(1) ) if template: raise ValueError("Template name is already exists") @@ -192,13 +188,13 @@ class RagPipelineService: """ Delete customized pipeline template. """ - customized_template: PipelineCustomizedTemplate | None = ( - db.session.query(PipelineCustomizedTemplate) + customized_template: PipelineCustomizedTemplate | None = db.session.scalar( + select(PipelineCustomizedTemplate) .where( PipelineCustomizedTemplate.id == template_id, PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id, ) - .first() + .limit(1) ) if not customized_template: raise ValueError("Customized pipeline template not found.") @@ -210,14 +206,14 @@ class RagPipelineService: Get draft workflow """ # fetch draft workflow by rag pipeline - workflow = ( - db.session.query(Workflow) + workflow = db.session.scalar( + select(Workflow) .where( Workflow.tenant_id == pipeline.tenant_id, Workflow.app_id == pipeline.id, Workflow.version == "draft", ) - .first() + .limit(1) ) # return draft workflow @@ -232,28 +228,28 @@ class RagPipelineService: return None # fetch published workflow by workflow_id - workflow = ( - db.session.query(Workflow) + workflow = db.session.scalar( + select(Workflow) .where( Workflow.tenant_id == pipeline.tenant_id, Workflow.app_id == pipeline.id, Workflow.id == pipeline.workflow_id, ) - .first() + .limit(1) ) return workflow def get_published_workflow_by_id(self, pipeline: Pipeline, workflow_id: str) -> Workflow | None: """Fetch a published workflow snapshot by ID for restore operations.""" - workflow = ( - db.session.query(Workflow) + workflow = db.session.scalar( + select(Workflow) .where( Workflow.tenant_id == pipeline.tenant_id, Workflow.app_id == pipeline.id, Workflow.id == workflow_id, ) - .first() + .limit(1) ) if workflow and workflow.version == Workflow.VERSION_DRAFT: raise IsDraftWorkflowError("source workflow must be published") @@ -301,7 +297,7 @@ class RagPipelineService: self, *, pipeline: Pipeline, - graph: dict, + graph: dict[str, Any], unique_hash: str | None, account: Account, environment_variables: Sequence[VariableBase], @@ -471,14 +467,16 @@ class RagPipelineService: return default_block_configs - def get_default_block_config(self, node_type: str, filters: dict | None = None) -> Mapping[str, object] | None: + def get_default_block_config( + self, node_type: str, filters: dict[str, Any] | None = None + ) -> Mapping[str, object] | None: """ Get default config of node. :param node_type: node type :param filters: filter by node config parameters. :return: """ - node_type_enum = NodeType(node_type) + node_type_enum: NodeType = node_type node_mapping = get_node_type_classes_mapping() # return default block config @@ -504,7 +502,7 @@ class RagPipelineService: return default_config def run_draft_workflow_node( - self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account + self, pipeline: Pipeline, node_id: str, user_inputs: dict[str, Any], account: Account ) -> WorkflowNodeExecutionModel | None: """ Run draft workflow node @@ -559,7 +557,7 @@ class RagPipelineService: workflow_node_execution.id ) - with Session(bind=db.engine) as session, session.begin(): + with sessionmaker(bind=db.engine).begin() as session: draft_var_saver = DraftVariableSaver( session=session, app_id=pipeline.id, @@ -573,8 +571,7 @@ class RagPipelineService: process_data=workflow_node_execution.process_data, outputs=workflow_node_execution.outputs, ) - session.commit() - if workflow_node_execution_db_model is not None: + if isinstance(workflow_node_execution_db_model, WorkflowNodeExecutionModel): enqueue_draft_node_execution_trace( execution=workflow_node_execution_db_model, outputs=workflow_node_execution.outputs, @@ -587,7 +584,7 @@ class RagPipelineService: self, pipeline: Pipeline, node_id: str, - user_inputs: dict, + user_inputs: dict[str, Any], account: Account, datasource_type: str, is_published: bool, @@ -754,7 +751,7 @@ class RagPipelineService: self, pipeline: Pipeline, node_id: str, - user_inputs: dict, + user_inputs: dict[str, Any], account: Account, datasource_type: str, is_published: bool, @@ -974,7 +971,7 @@ class RagPipelineService: if invoke_from.value == InvokeFrom.PUBLISHED_PIPELINE: document_id = get_system_segment(variable_pool, SystemVariableKey.DOCUMENT_ID) if document_id: - document = db.session.query(Document).where(Document.id == document_id.value).first() + document = db.session.get(Document, document_id.value) if document: document.indexing_status = IndexingStatus.ERROR document.error = error @@ -984,7 +981,7 @@ class RagPipelineService: return workflow_node_execution def update_workflow( - self, *, session: Session, workflow_id: str, tenant_id: str, account_id: str, data: dict + self, *, session: Session, workflow_id: str, tenant_id: str, account_id: str, data: dict[str, Any] ) -> Workflow | None: """ Update workflow attributes @@ -1104,7 +1101,9 @@ class RagPipelineService: ] return datasource_provider_variables - def get_rag_pipeline_paginate_workflow_runs(self, pipeline: Pipeline, args: dict) -> InfiniteScrollPagination: + def get_rag_pipeline_paginate_workflow_runs( + self, pipeline: Pipeline, args: dict[str, Any] + ) -> InfiniteScrollPagination: """ Get debug workflow run list Only return triggered_from == debugging @@ -1174,19 +1173,19 @@ class RagPipelineService: return list(node_executions) @classmethod - def publish_customized_pipeline_template(cls, pipeline_id: str, args: dict): + def publish_customized_pipeline_template(cls, pipeline_id: str, args: dict[str, Any]): """ Publish customized pipeline template """ - pipeline = db.session.query(Pipeline).where(Pipeline.id == pipeline_id).first() + pipeline = db.session.get(Pipeline, pipeline_id) if not pipeline: raise ValueError("Pipeline not found") if not pipeline.workflow_id: raise ValueError("Pipeline workflow not found") - workflow = db.session.query(Workflow).where(Workflow.id == pipeline.workflow_id).first() + workflow = db.session.get(Workflow, pipeline.workflow_id) if not workflow: raise ValueError("Workflow not found") - with Session(db.engine) as session: + with sessionmaker(db.engine).begin() as session: dataset = pipeline.retrieve_dataset(session=session) if not dataset: raise ValueError("Dataset not found") @@ -1194,26 +1193,26 @@ class RagPipelineService: # check template name is exist template_name = args.get("name") if template_name: - template = ( - db.session.query(PipelineCustomizedTemplate) + template = db.session.scalar( + select(PipelineCustomizedTemplate) .where( PipelineCustomizedTemplate.name == template_name, PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id, ) - .first() + .limit(1) ) if template: raise ValueError("Template name is already exists") - max_position = ( - db.session.query(func.max(PipelineCustomizedTemplate.position)) - .where(PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id) - .scalar() + max_position = db.session.scalar( + select(func.max(PipelineCustomizedTemplate.position)).where( + PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id + ) ) from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService - with Session(db.engine) as session: + with sessionmaker(db.engine).begin() as session: rag_pipeline_dsl_service = RagPipelineDslService(session) dsl = rag_pipeline_dsl_service.export_rag_pipeline_dsl(pipeline=pipeline, include_secret=True) if args.get("icon_info") is None: @@ -1239,13 +1238,14 @@ class RagPipelineService: def is_workflow_exist(self, pipeline: Pipeline) -> bool: return ( - db.session.query(Workflow) - .where( - Workflow.tenant_id == pipeline.tenant_id, - Workflow.app_id == pipeline.id, - Workflow.version == Workflow.VERSION_DRAFT, + db.session.scalar( + select(func.count(Workflow.id)).where( + Workflow.tenant_id == pipeline.tenant_id, + Workflow.app_id == pipeline.id, + Workflow.version == Workflow.VERSION_DRAFT, + ) ) - .count() + or 0 ) > 0 def get_node_last_run( @@ -1263,7 +1263,7 @@ class RagPipelineService: ) return node_exec - def set_datasource_variables(self, pipeline: Pipeline, args: dict, current_user: Account): + def set_datasource_variables(self, pipeline: Pipeline, args: dict[str, Any], current_user: Account): """ Set datasource variables """ @@ -1328,7 +1328,7 @@ class RagPipelineService: # Convert node_execution to WorkflowNodeExecution after save workflow_node_execution_db_model = repository._to_db_model(workflow_node_execution) # type: ignore - with Session(bind=db.engine) as session, session.begin(): + with sessionmaker(bind=db.engine).begin() as session: draft_var_saver = DraftVariableSaver( session=session, app_id=pipeline.id, @@ -1342,7 +1342,6 @@ class RagPipelineService: process_data=workflow_node_execution.process_data, outputs=workflow_node_execution.outputs, ) - session.commit() enqueue_draft_node_execution_trace( execution=workflow_node_execution_db_model, outputs=workflow_node_execution.outputs, @@ -1351,13 +1350,13 @@ class RagPipelineService: ) return workflow_node_execution_db_model - def get_recommended_plugins(self, type: str) -> dict: + def get_recommended_plugins(self, type: str) -> dict[str, Any]: # Query active recommended plugins - query = db.session.query(PipelineRecommendedPlugin).where(PipelineRecommendedPlugin.active == True) + stmt = select(PipelineRecommendedPlugin).where(PipelineRecommendedPlugin.active == True) if type and type != "all": - query = query.where(PipelineRecommendedPlugin.type == type) + stmt = stmt.where(PipelineRecommendedPlugin.type == type) - pipeline_recommended_plugins = query.order_by(PipelineRecommendedPlugin.position.asc()).all() + pipeline_recommended_plugins = db.session.scalars(stmt.order_by(PipelineRecommendedPlugin.position.asc())).all() if not pipeline_recommended_plugins: return { @@ -1392,18 +1391,16 @@ class RagPipelineService: "uninstalled_recommended_plugins": uninstalled_plugin_list, } - def retry_error_document(self, dataset: Dataset, document: Document, user: Union[Account, EndUser]): + def retry_error_document(self, dataset: Dataset, document: Document, user: Account | EndUser): """ Retry error document """ - document_pipeline_execution_log = ( - db.session.query(DocumentPipelineExecutionLog) - .where(DocumentPipelineExecutionLog.document_id == document.id) - .first() + document_pipeline_execution_log = db.session.scalar( + select(DocumentPipelineExecutionLog).where(DocumentPipelineExecutionLog.document_id == document.id).limit(1) ) if not document_pipeline_execution_log: raise ValueError("Document pipeline execution log not found") - pipeline = db.session.query(Pipeline).where(Pipeline.id == document_pipeline_execution_log.pipeline_id).first() + pipeline = db.session.get(Pipeline, document_pipeline_execution_log.pipeline_id) if not pipeline: raise ValueError("Pipeline not found") # convert to app config @@ -1432,23 +1429,23 @@ class RagPipelineService: """ Get datasource plugins """ - dataset: Dataset | None = ( - db.session.query(Dataset) + dataset: Dataset | None = db.session.scalar( + select(Dataset) .where( Dataset.id == dataset_id, Dataset.tenant_id == tenant_id, ) - .first() + .limit(1) ) if not dataset: raise ValueError("Dataset not found") - pipeline: Pipeline | None = ( - db.session.query(Pipeline) + pipeline: Pipeline | None = db.session.scalar( + select(Pipeline) .where( Pipeline.id == dataset.pipeline_id, Pipeline.tenant_id == tenant_id, ) - .first() + .limit(1) ) if not pipeline: raise ValueError("Pipeline not found") @@ -1530,23 +1527,23 @@ class RagPipelineService: """ Get pipeline """ - dataset: Dataset | None = ( - db.session.query(Dataset) + dataset: Dataset | None = db.session.scalar( + select(Dataset) .where( Dataset.id == dataset_id, Dataset.tenant_id == tenant_id, ) - .first() + .limit(1) ) if not dataset: raise ValueError("Dataset not found") - pipeline: Pipeline | None = ( - db.session.query(Pipeline) + pipeline: Pipeline | None = db.session.scalar( + select(Pipeline) .where( Pipeline.id == dataset.pipeline_id, Pipeline.tenant_id == tenant_id, ) - .first() + .limit(1) ) if not pipeline: raise ValueError("Pipeline not found") diff --git a/api/services/rag_pipeline/rag_pipeline_dsl_service.py b/api/services/rag_pipeline/rag_pipeline_dsl_service.py index 04156713f4f..f315d053cb9 100644 --- a/api/services/rag_pipeline/rag_pipeline_dsl_service.py +++ b/api/services/rag_pipeline/rag_pipeline_dsl_service.py @@ -5,8 +5,7 @@ import logging import uuid from collections.abc import Mapping from datetime import UTC, datetime -from enum import StrEnum -from typing import cast +from typing import Any, cast from urllib.parse import urlparse from uuid import uuid4 @@ -14,14 +13,8 @@ import yaml # type: ignore from Crypto.Cipher import AES from Crypto.Util.Padding import pad, unpad from flask_login import current_user -from graphon.enums import BuiltinNodeTypes -from graphon.model_runtime.utils.encoders import jsonable_encoder -from graphon.nodes.llm.entities import LLMNodeData -from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeData -from graphon.nodes.question_classifier.entities import QuestionClassifierNodeData -from graphon.nodes.tool.entities import ToolNodeData from packaging import version -from pydantic import BaseModel, Field +from pydantic import BaseModel from sqlalchemy import select from sqlalchemy.orm import Session @@ -34,10 +27,17 @@ from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData from extensions.ext_redis import redis_client from factories import variable_factory +from graphon.enums import BuiltinNodeTypes +from graphon.model_runtime.utils.encoders import jsonable_encoder +from graphon.nodes.llm.entities import LLMNodeData +from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeData +from graphon.nodes.question_classifier.entities import QuestionClassifierNodeData +from graphon.nodes.tool.entities import ToolNodeData from models import Account from models.dataset import Dataset, DatasetCollectionBinding, Pipeline from models.enums import CollectionBindingType, DatasetRuntimeMode from models.workflow import Workflow, WorkflowType +from services.entities.dsl_entities import CheckDependenciesResult, ImportMode, ImportStatus from services.entities.knowledge_entities.rag_pipeline_entities import ( IconInfo, KnowledgeConfiguration, @@ -54,18 +54,6 @@ DSL_MAX_SIZE = 10 * 1024 * 1024 # 10MB CURRENT_DSL_VERSION = "0.1.0" -class ImportMode(StrEnum): - YAML_CONTENT = "yaml-content" - YAML_URL = "yaml-url" - - -class ImportStatus(StrEnum): - COMPLETED = "completed" - COMPLETED_WITH_WARNINGS = "completed-with-warnings" - PENDING = "pending" - FAILED = "failed" - - class RagPipelineImportInfo(BaseModel): id: str status: ImportStatus @@ -76,10 +64,6 @@ class RagPipelineImportInfo(BaseModel): dataset_id: str | None = None -class CheckDependenciesResult(BaseModel): - leaked_dependencies: list[PluginDependency] = Field(default_factory=list) - - def _check_version_compatibility(imported_version: str) -> ImportStatus: """Determine import status based on version comparison""" try: @@ -299,7 +283,9 @@ class RagPipelineDslService: ): raise ValueError("Chunk structure is not compatible with the published pipeline") if not dataset: - datasets = self._session.query(Dataset).filter_by(tenant_id=account.current_tenant_id).all() + datasets = self._session.scalars( + select(Dataset).where(Dataset.tenant_id == account.current_tenant_id) + ).all() names = [dataset.name for dataset in datasets] generate_name = generate_incremental_name(names, name) dataset = Dataset( @@ -319,8 +305,8 @@ class RagPipelineDslService: chunk_structure=knowledge_configuration.chunk_structure, ) if knowledge_configuration.indexing_technique == IndexTechniqueType.HIGH_QUALITY: - dataset_collection_binding = ( - self._session.query(DatasetCollectionBinding) + dataset_collection_binding = self._session.scalar( + select(DatasetCollectionBinding) .where( DatasetCollectionBinding.provider_name == knowledge_configuration.embedding_model_provider, @@ -328,7 +314,7 @@ class RagPipelineDslService: DatasetCollectionBinding.type == CollectionBindingType.DATASET, ) .order_by(DatasetCollectionBinding.created_at) - .first() + .limit(1) ) if not dataset_collection_binding: @@ -456,8 +442,8 @@ class RagPipelineDslService: dataset.runtime_mode = DatasetRuntimeMode.RAG_PIPELINE dataset.chunk_structure = knowledge_configuration.chunk_structure if knowledge_configuration.indexing_technique == IndexTechniqueType.HIGH_QUALITY: - dataset_collection_binding = ( - self._session.query(DatasetCollectionBinding) + dataset_collection_binding = self._session.scalar( + select(DatasetCollectionBinding) .where( DatasetCollectionBinding.provider_name == knowledge_configuration.embedding_model_provider, @@ -465,7 +451,7 @@ class RagPipelineDslService: DatasetCollectionBinding.type == CollectionBindingType.DATASET, ) .order_by(DatasetCollectionBinding.created_at) - .first() + .limit(1) ) if not dataset_collection_binding: @@ -540,7 +526,7 @@ class RagPipelineDslService: self, *, pipeline: Pipeline | None, - data: dict, + data: dict[str, Any], account: Account, dependencies: list[PluginDependency] | None = None, ) -> Pipeline: @@ -607,14 +593,14 @@ class RagPipelineDslService: IMPORT_INFO_REDIS_EXPIRY, CheckDependenciesPendingData(pipeline_id=pipeline.id, dependencies=dependencies).model_dump_json(), ) - workflow = ( - self._session.query(Workflow) + workflow = self._session.scalar( + select(Workflow) .where( Workflow.tenant_id == pipeline.tenant_id, Workflow.app_id == pipeline.id, Workflow.version == "draft", ) - .first() + .limit(1) ) # create draft workflow if not found @@ -674,21 +660,21 @@ class RagPipelineDslService: return yaml.dump(export_data, allow_unicode=True) # type: ignore - def _append_workflow_export_data(self, *, export_data: dict, pipeline: Pipeline, include_secret: bool) -> None: + def _append_workflow_export_data( + self, *, export_data: dict[str, Any], pipeline: Pipeline, include_secret: bool + ) -> None: """ Append workflow export data :param export_data: export data :param pipeline: Pipeline instance """ - workflow = ( - self._session.query(Workflow) - .where( + workflow = self._session.scalar( + select(Workflow).where( Workflow.tenant_id == pipeline.tenant_id, Workflow.app_id == pipeline.id, Workflow.version == "draft", ) - .first() ) if not workflow: raise ValueError("Missing draft workflow configuration, please check.") @@ -920,15 +906,16 @@ class RagPipelineDslService: ): if rag_pipeline_dataset_create_entity.name: # check if dataset name already exists - if ( - self._session.query(Dataset) - .filter_by(name=rag_pipeline_dataset_create_entity.name, tenant_id=tenant_id) - .first() + if self._session.scalar( + select(Dataset).where( + Dataset.name == rag_pipeline_dataset_create_entity.name, + Dataset.tenant_id == tenant_id, + ) ): raise ValueError(f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists.") else: # generate a random name as Untitled 1 2 3 ... - datasets = self._session.query(Dataset).filter_by(tenant_id=tenant_id).all() + datasets = self._session.scalars(select(Dataset).where(Dataset.tenant_id == tenant_id)).all() names = [dataset.name for dataset in datasets] rag_pipeline_dataset_create_entity.name = generate_incremental_name( names, diff --git a/api/services/rag_pipeline/rag_pipeline_transform_service.py b/api/services/rag_pipeline/rag_pipeline_transform_service.py index c3b00fe1094..f08ec7474ba 100644 --- a/api/services/rag_pipeline/rag_pipeline_transform_service.py +++ b/api/services/rag_pipeline/rag_pipeline_transform_service.py @@ -2,6 +2,7 @@ import json import logging from datetime import UTC, datetime from pathlib import Path +from typing import Any from uuid import uuid4 import yaml @@ -154,7 +155,7 @@ class RagPipelineTransformService: raise ValueError("Unsupported doc form") return pipeline_yaml - def _deal_file_extensions(self, node: dict): + def _deal_file_extensions(self, node: dict[str, Any]): file_extensions = node.get("data", {}).get("fileExtensions", []) if not file_extensions: return node @@ -167,7 +168,7 @@ class RagPipelineTransformService: dataset: Dataset, indexing_technique: str | None, retrieval_model: RetrievalSetting | None, - node: dict, + node: dict[str, Any], ): knowledge_configuration_dict = node.get("data", {}) @@ -191,7 +192,7 @@ class RagPipelineTransformService: def _create_pipeline( self, - data: dict, + data: dict[str, Any], ) -> Pipeline: """Create a new app or update an existing one.""" pipeline_data = data.get("rag_pipeline", {}) @@ -258,7 +259,7 @@ class RagPipelineTransformService: db.session.add(pipeline) return pipeline - def _deal_dependencies(self, pipeline_yaml: dict, tenant_id: str): + def _deal_dependencies(self, pipeline_yaml: dict[str, Any], tenant_id: str): installer_manager = PluginInstaller() installed_plugins = installer_manager.list_plugins(tenant_id) diff --git a/api/services/recommend_app/buildin/buildin_retrieval.py b/api/services/recommend_app/buildin/buildin_retrieval.py index 64751d186ce..16dc66cd767 100644 --- a/api/services/recommend_app/buildin/buildin_retrieval.py +++ b/api/services/recommend_app/buildin/buildin_retrieval.py @@ -1,6 +1,7 @@ import json from os import path from pathlib import Path +from typing import Any from flask import current_app @@ -13,7 +14,7 @@ class BuildInRecommendAppRetrieval(RecommendAppRetrievalBase): Retrieval recommended app from buildin, the location is constants/recommended_apps.json """ - builtin_data: dict | None = None + builtin_data: dict[str, Any] | None = None def get_type(self) -> str: return RecommendAppType.BUILDIN @@ -53,7 +54,7 @@ class BuildInRecommendAppRetrieval(RecommendAppRetrievalBase): return builtin_data.get("recommended_apps", {}).get(language, {}) @classmethod - def fetch_recommended_app_detail_from_builtin(cls, app_id: str) -> dict | None: + def fetch_recommended_app_detail_from_builtin(cls, app_id: str) -> dict[str, Any] | None: """ Fetch recommended app detail from builtin. :param app_id: App ID diff --git a/api/services/recommend_app/database/database_retrieval.py b/api/services/recommend_app/database/database_retrieval.py index 6fb90d356d9..1df5fd13b6a 100644 --- a/api/services/recommend_app/database/database_retrieval.py +++ b/api/services/recommend_app/database/database_retrieval.py @@ -1,3 +1,5 @@ +from typing import Any, TypedDict + from sqlalchemy import select from constants.languages import languages @@ -8,16 +10,43 @@ from services.recommend_app.recommend_app_base import RecommendAppRetrievalBase from services.recommend_app.recommend_app_type import RecommendAppType +class RecommendedAppItemDict(TypedDict): + id: str + app: App | None + app_id: str + description: Any + copyright: Any + privacy_policy: Any + custom_disclaimer: str + category: str + position: int + is_listed: bool + + +class RecommendedAppsResultDict(TypedDict): + recommended_apps: list[RecommendedAppItemDict] + categories: list[str] + + +class RecommendedAppDetailDict(TypedDict): + id: str + name: str + icon: Any + icon_background: str | None + mode: str + export_data: str + + class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase): """ Retrieval recommended app from database """ - def get_recommended_apps_and_categories(self, language: str): + def get_recommended_apps_and_categories(self, language: str) -> RecommendedAppsResultDict: result = self.fetch_recommended_apps_from_db(language) return result - def get_recommend_app_detail(self, app_id: str): + def get_recommend_app_detail(self, app_id: str) -> RecommendedAppDetailDict | None: result = self.fetch_recommended_app_detail_from_db(app_id) return result @@ -25,7 +54,7 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase): return RecommendAppType.DATABASE @classmethod - def fetch_recommended_apps_from_db(cls, language: str): + def fetch_recommended_apps_from_db(cls, language: str) -> RecommendedAppsResultDict: """ Fetch recommended apps from db. :param language: language @@ -41,7 +70,7 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase): ).all() categories = set() - recommended_apps_result = [] + recommended_apps_result: list[RecommendedAppItemDict] = [] for recommended_app in recommended_apps: app = recommended_app.app if not app or not app.is_public: @@ -51,7 +80,7 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase): if not site: continue - recommended_app_result = { + recommended_app_result: RecommendedAppItemDict = { "id": recommended_app.id, "app": recommended_app.app, "app_id": recommended_app.app_id, @@ -67,10 +96,10 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase): categories.add(recommended_app.category) - return {"recommended_apps": recommended_apps_result, "categories": sorted(categories)} + return RecommendedAppsResultDict(recommended_apps=recommended_apps_result, categories=sorted(categories)) @classmethod - def fetch_recommended_app_detail_from_db(cls, app_id: str) -> dict | None: + def fetch_recommended_app_detail_from_db(cls, app_id: str) -> RecommendedAppDetailDict | None: """ Fetch recommended app detail from db. :param app_id: App ID @@ -89,11 +118,11 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase): if not app_model or not app_model.is_public: return None - return { - "id": app_model.id, - "name": app_model.name, - "icon": app_model.icon, - "icon_background": app_model.icon_background, - "mode": app_model.mode, - "export_data": AppDslService.export_dsl(app_model=app_model), - } + return RecommendedAppDetailDict( + id=app_model.id, + name=app_model.name, + icon=app_model.icon, + icon_background=app_model.icon_background, + mode=app_model.mode, + export_data=AppDslService.export_dsl(app_model=app_model), + ) diff --git a/api/services/recommend_app/remote/remote_retrieval.py b/api/services/recommend_app/remote/remote_retrieval.py index b217c9026ae..5818be04805 100644 --- a/api/services/recommend_app/remote/remote_retrieval.py +++ b/api/services/recommend_app/remote/remote_retrieval.py @@ -1,4 +1,5 @@ import logging +from typing import Any import httpx @@ -35,7 +36,7 @@ class RemoteRecommendAppRetrieval(RecommendAppRetrievalBase): return RecommendAppType.REMOTE @classmethod - def fetch_recommended_app_detail_from_dify_official(cls, app_id: str) -> dict | None: + def fetch_recommended_app_detail_from_dify_official(cls, app_id: str) -> dict[str, Any] | None: """ Fetch recommended app detail from dify official. :param app_id: App ID @@ -46,7 +47,7 @@ class RemoteRecommendAppRetrieval(RecommendAppRetrievalBase): response = httpx.get(url, timeout=httpx.Timeout(10.0, connect=3.0)) if response.status_code != 200: return None - data: dict = response.json() + data: dict[str, Any] = response.json() return data @classmethod @@ -62,7 +63,7 @@ class RemoteRecommendAppRetrieval(RecommendAppRetrievalBase): if response.status_code != 200: raise ValueError(f"fetch recommended apps failed, status code: {response.status_code}") - result: dict = response.json() + result: dict[str, Any] = response.json() if "categories" in result: result["categories"] = sorted(result["categories"]) diff --git a/api/services/recommended_app_service.py b/api/services/recommended_app_service.py index 9819822103b..134dd37a3ed 100644 --- a/api/services/recommended_app_service.py +++ b/api/services/recommended_app_service.py @@ -1,3 +1,5 @@ +from typing import Any + from sqlalchemy import select from configs import dify_config @@ -37,7 +39,7 @@ class RecommendedAppService: return result @classmethod - def get_recommend_app_detail(cls, app_id: str) -> dict | None: + def get_recommend_app_detail(cls, app_id: str) -> dict[str, Any] | None: """ Get recommend app detail. :param app_id: app id @@ -45,7 +47,7 @@ class RecommendedAppService: """ mode = dify_config.HOSTED_FETCH_APP_TEMPLATES_MODE retrieval_instance = RecommendAppRetrievalFactory.get_recommend_app_factory(mode)() - result: dict = retrieval_instance.get_recommend_app_detail(app_id) + result: dict[str, Any] = retrieval_instance.get_recommend_app_detail(app_id) if FeatureService.get_system_features().enable_trial_app: app_id = result["id"] trial_app_model = db.session.scalar(select(TrialApp).where(TrialApp.app_id == app_id).limit(1)) diff --git a/api/services/retention/conversation/messages_clean_service.py b/api/services/retention/conversation/messages_clean_service.py index 48c3e72af09..1e9f0bf1493 100644 --- a/api/services/retention/conversation/messages_clean_service.py +++ b/api/services/retention/conversation/messages_clean_service.py @@ -3,12 +3,12 @@ import logging import random import time from collections.abc import Sequence -from typing import TYPE_CHECKING, cast +from typing import TYPE_CHECKING, TypedDict, cast import sqlalchemy as sa from sqlalchemy import delete, select, tuple_ from sqlalchemy.engine import CursorResult -from sqlalchemy.orm import Session +from sqlalchemy.orm import Session, sessionmaker from configs import dify_config from extensions.ext_database import db @@ -158,6 +158,13 @@ class MessagesCleanupMetrics: self._record(self._job_duration_seconds, job_duration_seconds, attributes) +class MessagesCleanStatsDict(TypedDict): + batches: int + total_messages: int + filtered_messages: int + total_deleted: int + + class MessagesCleanService: """ Service for cleaning expired messages based on retention policies. @@ -299,7 +306,7 @@ class MessagesCleanService: task_label=task_label, ) - def run(self) -> dict[str, int]: + def run(self) -> MessagesCleanStatsDict: """ Execute the message cleanup operation. @@ -319,7 +326,7 @@ class MessagesCleanService: job_duration_seconds=time.monotonic() - run_start, ) - def _clean_messages_by_time_range(self) -> dict[str, int]: + def _clean_messages_by_time_range(self) -> MessagesCleanStatsDict: """ Clean messages within a time range using cursor-based pagination. @@ -334,7 +341,7 @@ class MessagesCleanService: Returns: Dict with statistics: batches, filtered_messages, total_deleted """ - stats = { + stats: MessagesCleanStatsDict = { "batches": 0, "total_messages": 0, "filtered_messages": 0, @@ -362,7 +369,7 @@ class MessagesCleanService: batch_deleted_messages = 0 # Step 1: Fetch a batch of messages using cursor - with Session(db.engine, expire_on_commit=False) as session: + with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session: fetch_messages_start = time.monotonic() msg_stmt = ( select(Message.id, Message.app_id, Message.created_at) @@ -470,7 +477,7 @@ class MessagesCleanService: # Step 4: Batch delete messages and their relations if not self._dry_run: - with Session(db.engine, expire_on_commit=False) as session: + with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session: delete_relations_start = time.monotonic() # Delete related records first self._batch_delete_message_relations(session, message_ids_to_delete) @@ -482,9 +489,7 @@ class MessagesCleanService: delete_result = cast(CursorResult, session.execute(delete_stmt)) messages_deleted = delete_result.rowcount delete_messages_ms = int((time.monotonic() - delete_messages_start) * 1000) - commit_start = time.monotonic() - session.commit() - commit_ms = int((time.monotonic() - commit_start) * 1000) + commit_ms = 0 stats["total_deleted"] += messages_deleted batch_deleted_messages = messages_deleted diff --git a/api/services/retention/workflow_run/archive_paid_plan_workflow_run.py b/api/services/retention/workflow_run/archive_paid_plan_workflow_run.py index 2c1f99a3bc9..21be411bea7 100644 --- a/api/services/retention/workflow_run/archive_paid_plan_workflow_run.py +++ b/api/services/retention/workflow_run/archive_paid_plan_workflow_run.py @@ -24,16 +24,16 @@ import zipfile from collections.abc import Sequence from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass, field -from typing import Any +from typing import Any, TypedDict import click -from graphon.enums import WorkflowType from sqlalchemy import inspect from sqlalchemy.orm import Session, sessionmaker from configs import dify_config from enums.cloud_plan import CloudPlan from extensions.ext_database import db +from graphon.enums import WorkflowType from libs.archive_storage import ( ArchiveStorage, ArchiveStorageNotConfiguredError, @@ -49,6 +49,23 @@ from services.retention.workflow_run.constants import ARCHIVE_BUNDLE_NAME, ARCHI logger = logging.getLogger(__name__) +class TableStatsManifestEntry(TypedDict): + row_count: int + checksum: str + size_bytes: int + + +class ArchiveManifestDict(TypedDict): + schema_version: str + workflow_run_id: str + tenant_id: str + app_id: str + workflow_id: str + created_at: str + archived_at: str + tables: dict[str, TableStatsManifestEntry] + + @dataclass class TableStats: """Statistics for a single archived table.""" @@ -472,25 +489,26 @@ class WorkflowRunArchiver: self, run: WorkflowRun, table_stats: list[TableStats], - ) -> dict[str, Any]: + ) -> ArchiveManifestDict: """Generate a manifest for the archived workflow run.""" - return { - "schema_version": ARCHIVE_SCHEMA_VERSION, - "workflow_run_id": run.id, - "tenant_id": run.tenant_id, - "app_id": run.app_id, - "workflow_id": run.workflow_id, - "created_at": run.created_at.isoformat(), - "archived_at": datetime.datetime.now(datetime.UTC).isoformat(), - "tables": { - stat.table_name: { - "row_count": stat.row_count, - "checksum": stat.checksum, - "size_bytes": stat.size_bytes, - } - for stat in table_stats - }, + tables: dict[str, TableStatsManifestEntry] = { + stat.table_name: { + "row_count": stat.row_count, + "checksum": stat.checksum, + "size_bytes": stat.size_bytes, + } + for stat in table_stats } + return ArchiveManifestDict( + schema_version=ARCHIVE_SCHEMA_VERSION, + workflow_run_id=run.id, + tenant_id=run.tenant_id, + app_id=run.app_id, + workflow_id=run.workflow_id, + created_at=run.created_at.isoformat(), + archived_at=datetime.datetime.now(datetime.UTC).isoformat(), + tables=tables, + ) def _build_archive_bundle(self, manifest_data: bytes, table_payloads: dict[str, bytes]) -> bytes: buffer = io.BytesIO() diff --git a/api/services/retention/workflow_run/clear_free_plan_expired_workflow_run_logs.py b/api/services/retention/workflow_run/clear_free_plan_expired_workflow_run_logs.py index 62bc9f5f101..58e8ac57a8f 100644 --- a/api/services/retention/workflow_run/clear_free_plan_expired_workflow_run_logs.py +++ b/api/services/retention/workflow_run/clear_free_plan_expired_workflow_run_logs.py @@ -3,7 +3,7 @@ import logging import random import time from collections.abc import Iterable, Sequence -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, TypedDict import click from sqlalchemy.orm import Session, sessionmaker @@ -12,7 +12,7 @@ from configs import dify_config from enums.cloud_plan import CloudPlan from extensions.ext_database import db from models.workflow import WorkflowRun -from repositories.api_workflow_run_repository import APIWorkflowRunRepository +from repositories.api_workflow_run_repository import APIWorkflowRunRepository, RunsWithRelatedCountsDict from repositories.factory import DifyAPIRepositoryFactory from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository from services.billing_service import BillingService, SubscriptionPlan @@ -24,6 +24,15 @@ if TYPE_CHECKING: from opentelemetry.metrics import Counter, Histogram +class RelatedCountsDict(TypedDict): + node_executions: int + offloads: int + app_logs: int + trigger_logs: int + pauses: int + pause_reasons: int + + class WorkflowRunCleanupMetrics: """ Records low-cardinality OpenTelemetry metrics for workflow run cleanup jobs. @@ -173,6 +182,9 @@ class WorkflowRunCleanupMetrics: self._record(self._job_duration_seconds, job_duration_seconds, attributes) +_RELATED_RECORD_KEYS = ("node_executions", "offloads", "app_logs", "trigger_logs", "pauses", "pause_reasons") + + class WorkflowRunCleanup: def __init__( self, @@ -230,7 +242,7 @@ class WorkflowRunCleanup: total_runs_deleted = 0 total_runs_targeted = 0 - related_totals = self._empty_related_counts() if self.dry_run else None + related_totals: RelatedCountsDict | None = self._empty_related_counts() if self.dry_run else None batch_index = 0 last_seen: tuple[datetime.datetime, str] | None = None status = "success" @@ -312,8 +324,7 @@ class WorkflowRunCleanup: int((time.monotonic() - count_start) * 1000), ) if related_totals is not None: - for key in related_totals: - related_totals[key] += batch_counts.get(key, 0) + self._accumulate_related_counts(related_totals, batch_counts) sample_ids = ", ".join(run.id for run in free_runs[:5]) click.echo( click.style( @@ -332,7 +343,10 @@ class WorkflowRunCleanup: targeted_runs=len(free_runs), skipped_runs=paid_or_skipped, deleted_runs=0, - related_counts={key: batch_counts.get(key, 0) for key in self._empty_related_counts()}, + related_counts={ + k: batch_counts[k] # type: ignore[literal-required] + for k in _RELATED_RECORD_KEYS + }, related_action="would_delete", batch_duration_seconds=time.monotonic() - batch_start, ) @@ -372,7 +386,10 @@ class WorkflowRunCleanup: targeted_runs=len(free_runs), skipped_runs=paid_or_skipped, deleted_runs=counts["runs"], - related_counts={key: counts.get(key, 0) for key in self._empty_related_counts()}, + related_counts={ + k: counts[k] # type: ignore[literal-required] + for k in _RELATED_RECORD_KEYS + }, related_action="deleted", batch_duration_seconds=time.monotonic() - batch_start, ) @@ -506,7 +523,7 @@ class WorkflowRunCleanup: return trigger_repo.count_by_run_ids(run_ids) @staticmethod - def _empty_related_counts() -> dict[str, int]: + def _empty_related_counts() -> RelatedCountsDict: return { "node_executions": 0, "offloads": 0, @@ -517,7 +534,7 @@ class WorkflowRunCleanup: } @staticmethod - def _format_related_counts(counts: dict[str, int]) -> str: + def _format_related_counts(counts: RelatedCountsDict) -> str: return ( f"node_executions {counts['node_executions']}, " f"offloads {counts['offloads']}, " @@ -527,6 +544,15 @@ class WorkflowRunCleanup: f"pause_reasons {counts['pause_reasons']}" ) + @staticmethod + def _accumulate_related_counts(totals: RelatedCountsDict, batch: RunsWithRelatedCountsDict) -> None: + totals["node_executions"] += batch.get("node_executions", 0) + totals["offloads"] += batch.get("offloads", 0) + totals["app_logs"] += batch.get("app_logs", 0) + totals["trigger_logs"] += batch.get("trigger_logs", 0) + totals["pauses"] += batch.get("pauses", 0) + totals["pause_reasons"] += batch.get("pause_reasons", 0) + def _count_node_executions(self, session: Session, runs: Sequence[WorkflowRun]) -> tuple[int, int]: run_ids = [run.id for run in runs] repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository( diff --git a/api/services/retention/workflow_run/delete_archived_workflow_run.py b/api/services/retention/workflow_run/delete_archived_workflow_run.py index 11873bf1b9a..937a1067105 100644 --- a/api/services/retention/workflow_run/delete_archived_workflow_run.py +++ b/api/services/retention/workflow_run/delete_archived_workflow_run.py @@ -14,7 +14,7 @@ from sqlalchemy.orm import Session, sessionmaker from extensions.ext_database import db from models.workflow import WorkflowRun -from repositories.api_workflow_run_repository import APIWorkflowRunRepository +from repositories.api_workflow_run_repository import APIWorkflowRunRepository, RunsWithRelatedCountsDict from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository @@ -23,7 +23,17 @@ class DeleteResult: run_id: str tenant_id: str success: bool - deleted_counts: dict[str, int] = field(default_factory=dict) + deleted_counts: RunsWithRelatedCountsDict = field( + default_factory=lambda: { # type: ignore[assignment] + "runs": 0, + "node_executions": 0, + "offloads": 0, + "app_logs": 0, + "trigger_logs": 0, + "pauses": 0, + "pause_reasons": 0, + } + ) error: str | None = None elapsed_time: float = 0.0 diff --git a/api/services/summary_index_service.py b/api/services/summary_index_service.py index 12053377e24..cf39469be85 100644 --- a/api/services/summary_index_service.py +++ b/api/services/summary_index_service.py @@ -4,10 +4,9 @@ import logging import time import uuid from datetime import UTC, datetime -from typing import Any +from typing import TypedDict, cast -from graphon.model_runtime.entities.llm_entities import LLMUsage -from graphon.model_runtime.entities.model_entities import ModelType +from sqlalchemy import select from sqlalchemy.orm import Session from core.db.session_factory import session_factory @@ -17,6 +16,8 @@ from core.rag.index_processor.constant.doc_type import DocType from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict from core.rag.models.document import Document +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.model_runtime.entities.model_entities import ModelType from libs import helper from models.dataset import Dataset, DocumentSegment, DocumentSegmentSummary from models.dataset import Document as DatasetDocument @@ -25,6 +26,22 @@ from models.enums import SummaryStatus logger = logging.getLogger(__name__) +class SummaryEntryDict(TypedDict): + segment_id: str + segment_position: int + status: str + summary_preview: str | None + error: str | None + created_at: int | None + updated_at: int | None + + +class DocumentSummaryStatusDetailDict(TypedDict): + total_segments: int + summary_status: dict[str, int] + summaries: list[SummaryEntryDict] + + class SummaryIndexService: """Service for generating and managing summary indexes.""" @@ -93,8 +110,13 @@ class SummaryIndexService: """ with session_factory.create_session() as session: # Check if summary record already exists - existing_summary = ( - session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first() + existing_summary = session.scalar( + select(DocumentSegmentSummary) + .where( + DocumentSegmentSummary.chunk_id == segment.id, + DocumentSegmentSummary.dataset_id == dataset.id, + ) + .limit(1) ) if existing_summary: @@ -293,8 +315,10 @@ class SummaryIndexService: summary_record_id, segment.id, ) - summary_record_in_session = ( - session.query(DocumentSegmentSummary).filter_by(id=summary_record_id).first() + summary_record_in_session = session.scalar( + select(DocumentSegmentSummary) + .where(DocumentSegmentSummary.id == summary_record_id) + .limit(1) ) if not summary_record_in_session: @@ -307,10 +331,13 @@ class SummaryIndexService: dataset.id, segment.id, ) - summary_record_in_session = ( - session.query(DocumentSegmentSummary) - .filter_by(chunk_id=segment.id, dataset_id=dataset.id) - .first() + summary_record_in_session = session.scalar( + select(DocumentSegmentSummary) + .where( + DocumentSegmentSummary.chunk_id == segment.id, + DocumentSegmentSummary.dataset_id == dataset.id, + ) + .limit(1) ) if not summary_record_in_session: @@ -322,7 +349,6 @@ class SummaryIndexService: summary_record_id, ) summary_record_in_session = DocumentSegmentSummary( - id=summary_record_id, # Use the same ID if available dataset_id=dataset.id, document_id=segment.document_id, chunk_id=segment.id, @@ -333,6 +359,9 @@ class SummaryIndexService: status=SummaryStatus.COMPLETED, enabled=True, ) + if summary_record_in_session is None: + raise RuntimeError("summary_record_in_session should not be None at this point") + summary_record_in_session.id = summary_record_id session.add(summary_record_in_session) logger.info( "Created new summary record (id=%s) for segment %s after vectorization", @@ -471,8 +500,10 @@ class SummaryIndexService: with session_factory.create_session() as error_session: # Try to find the record by id first # Note: Using assignment only (no type annotation) to avoid redeclaration error - summary_record_in_session = ( - error_session.query(DocumentSegmentSummary).filter_by(id=summary_record_id).first() + summary_record_in_session = error_session.scalar( + select(DocumentSegmentSummary) + .where(DocumentSegmentSummary.id == summary_record_id) + .limit(1) ) if not summary_record_in_session: # Try to find by chunk_id and dataset_id @@ -484,10 +515,13 @@ class SummaryIndexService: dataset.id, segment.id, ) - summary_record_in_session = ( - error_session.query(DocumentSegmentSummary) - .filter_by(chunk_id=segment.id, dataset_id=dataset.id) - .first() + summary_record_in_session = error_session.scalar( + select(DocumentSegmentSummary) + .where( + DocumentSegmentSummary.chunk_id == segment.id, + DocumentSegmentSummary.dataset_id == dataset.id, + ) + .limit(1) ) if summary_record_in_session: @@ -535,14 +569,12 @@ class SummaryIndexService: with session_factory.create_session() as session: # Query existing summary records - existing_summaries = ( - session.query(DocumentSegmentSummary) - .filter( + existing_summaries = session.scalars( + select(DocumentSegmentSummary).where( DocumentSegmentSummary.chunk_id.in_(segment_ids), DocumentSegmentSummary.dataset_id == dataset.id, ) - .all() - ) + ).all() existing_summary_map = {summary.chunk_id: summary for summary in existing_summaries} # Create or update records @@ -587,8 +619,13 @@ class SummaryIndexService: error: Error message """ with session_factory.create_session() as session: - summary_record = ( - session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first() + summary_record = session.scalar( + select(DocumentSegmentSummary) + .where( + DocumentSegmentSummary.chunk_id == segment.id, + DocumentSegmentSummary.dataset_id == dataset.id, + ) + .limit(1) ) if summary_record: @@ -623,8 +660,13 @@ class SummaryIndexService: with session_factory.create_session() as session: try: # Get or refresh summary record in this session - summary_record_in_session = ( - session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first() + summary_record_in_session = session.scalar( + select(DocumentSegmentSummary) + .where( + DocumentSegmentSummary.chunk_id == segment.id, + DocumentSegmentSummary.dataset_id == dataset.id, + ) + .limit(1) ) if not summary_record_in_session: @@ -694,8 +736,13 @@ class SummaryIndexService: except Exception as e: logger.exception("Failed to generate summary for segment %s", segment.id) # Update summary record with error status - summary_record_in_session = ( - session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first() + summary_record_in_session = session.scalar( + select(DocumentSegmentSummary) + .where( + DocumentSegmentSummary.chunk_id == segment.id, + DocumentSegmentSummary.dataset_id == dataset.id, + ) + .limit(1) ) if summary_record_in_session: summary_record_in_session.status = SummaryStatus.ERROR @@ -753,17 +800,17 @@ class SummaryIndexService: with session_factory.create_session() as session: # Query segments (only enabled segments) - query = session.query(DocumentSegment).filter_by( - dataset_id=dataset.id, - document_id=document.id, - status="completed", - enabled=True, # Only generate summaries for enabled segments + stmt = select(DocumentSegment).where( + DocumentSegment.dataset_id == dataset.id, + DocumentSegment.document_id == document.id, + DocumentSegment.status == "completed", + DocumentSegment.enabled.is_(True), # Only generate summaries for enabled segments ) if segment_ids: - query = query.filter(DocumentSegment.id.in_(segment_ids)) + stmt = stmt.where(DocumentSegment.id.in_(segment_ids)) - segments = query.all() + segments = list(session.scalars(stmt).all()) if not segments: logger.info("No segments found for document %s", document.id) @@ -832,15 +879,15 @@ class SummaryIndexService: from libs.datetime_utils import naive_utc_now with session_factory.create_session() as session: - query = session.query(DocumentSegmentSummary).filter_by( - dataset_id=dataset.id, - enabled=True, # Only disable enabled summaries + stmt = select(DocumentSegmentSummary).where( + DocumentSegmentSummary.dataset_id == dataset.id, + DocumentSegmentSummary.enabled.is_(True), # Only disable enabled summaries ) if segment_ids: - query = query.filter(DocumentSegmentSummary.chunk_id.in_(segment_ids)) + stmt = stmt.where(DocumentSegmentSummary.chunk_id.in_(segment_ids)) - summaries = query.all() + summaries = session.scalars(stmt).all() if not summaries: return @@ -895,15 +942,15 @@ class SummaryIndexService: return with session_factory.create_session() as session: - query = session.query(DocumentSegmentSummary).filter_by( - dataset_id=dataset.id, - enabled=False, # Only enable disabled summaries + stmt = select(DocumentSegmentSummary).where( + DocumentSegmentSummary.dataset_id == dataset.id, + DocumentSegmentSummary.enabled.is_(False), # Only enable disabled summaries ) if segment_ids: - query = query.filter(DocumentSegmentSummary.chunk_id.in_(segment_ids)) + stmt = stmt.where(DocumentSegmentSummary.chunk_id.in_(segment_ids)) - summaries = query.all() + summaries = session.scalars(stmt).all() if not summaries: return @@ -919,13 +966,13 @@ class SummaryIndexService: enabled_count = 0 for summary in summaries: # Get the original segment - segment = ( - session.query(DocumentSegment) - .filter_by( - id=summary.chunk_id, - dataset_id=dataset.id, + segment = session.scalar( + select(DocumentSegment) + .where( + DocumentSegment.id == summary.chunk_id, + DocumentSegment.dataset_id == dataset.id, ) - .first() + .limit(1) ) # Summary.enabled stays in sync with chunk.enabled, @@ -972,12 +1019,12 @@ class SummaryIndexService: segment_ids: List of segment IDs to delete summaries for. If None, delete all. """ with session_factory.create_session() as session: - query = session.query(DocumentSegmentSummary).filter_by(dataset_id=dataset.id) + stmt = select(DocumentSegmentSummary).where(DocumentSegmentSummary.dataset_id == dataset.id) if segment_ids: - query = query.filter(DocumentSegmentSummary.chunk_id.in_(segment_ids)) + stmt = stmt.where(DocumentSegmentSummary.chunk_id.in_(segment_ids)) - summaries = query.all() + summaries = session.scalars(stmt).all() if not summaries: return @@ -1030,10 +1077,13 @@ class SummaryIndexService: # Check if summary_content is empty (whitespace-only strings are considered empty) if not summary_content or not summary_content.strip(): # If summary is empty, only delete existing summary vector and record - summary_record = ( - session.query(DocumentSegmentSummary) - .filter_by(chunk_id=segment.id, dataset_id=dataset.id) - .first() + summary_record = session.scalar( + select(DocumentSegmentSummary) + .where( + DocumentSegmentSummary.chunk_id == segment.id, + DocumentSegmentSummary.dataset_id == dataset.id, + ) + .limit(1) ) if summary_record: @@ -1061,8 +1111,13 @@ class SummaryIndexService: return None # Find existing summary record - summary_record = ( - session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first() + summary_record = session.scalar( + select(DocumentSegmentSummary) + .where( + DocumentSegmentSummary.chunk_id == segment.id, + DocumentSegmentSummary.dataset_id == dataset.id, + ) + .limit(1) ) if summary_record: @@ -1146,8 +1201,13 @@ class SummaryIndexService: except Exception as e: logger.exception("Failed to update summary for segment %s", segment.id) # Update summary record with error status if it exists - summary_record = ( - session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first() + summary_record = session.scalar( + select(DocumentSegmentSummary) + .where( + DocumentSegmentSummary.chunk_id == segment.id, + DocumentSegmentSummary.dataset_id == dataset.id, + ) + .limit(1) ) if summary_record: summary_record.status = SummaryStatus.ERROR @@ -1169,14 +1229,14 @@ class SummaryIndexService: DocumentSegmentSummary instance if found, None otherwise """ with session_factory.create_session() as session: - return ( - session.query(DocumentSegmentSummary) + return session.scalar( + select(DocumentSegmentSummary) .where( DocumentSegmentSummary.chunk_id == segment_id, DocumentSegmentSummary.dataset_id == dataset_id, - DocumentSegmentSummary.enabled == True, # Only return enabled summaries + DocumentSegmentSummary.enabled.is_(True), # Only return enabled summaries ) - .first() + .limit(1) ) @staticmethod @@ -1195,15 +1255,13 @@ class SummaryIndexService: return {} with session_factory.create_session() as session: - summary_records = ( - session.query(DocumentSegmentSummary) - .where( + summary_records = session.scalars( + select(DocumentSegmentSummary).where( DocumentSegmentSummary.chunk_id.in_(segment_ids), DocumentSegmentSummary.dataset_id == dataset_id, - DocumentSegmentSummary.enabled == True, # Only return enabled summaries + DocumentSegmentSummary.enabled.is_(True), # Only return enabled summaries ) - .all() - ) + ).all() return {summary.chunk_id: summary for summary in summary_records} @@ -1223,16 +1281,16 @@ class SummaryIndexService: List of DocumentSegmentSummary instances (only enabled summaries) """ with session_factory.create_session() as session: - query = session.query(DocumentSegmentSummary).filter( + stmt = select(DocumentSegmentSummary).where( DocumentSegmentSummary.document_id == document_id, DocumentSegmentSummary.dataset_id == dataset_id, - DocumentSegmentSummary.enabled == True, # Only return enabled summaries + DocumentSegmentSummary.enabled.is_(True), # Only return enabled summaries ) if segment_ids: - query = query.filter(DocumentSegmentSummary.chunk_id.in_(segment_ids)) + stmt = stmt.where(DocumentSegmentSummary.chunk_id.in_(segment_ids)) - return query.all() + return list(session.scalars(stmt).all()) @staticmethod def get_document_summary_index_status(document_id: str, dataset_id: str, tenant_id: str) -> str | None: @@ -1249,16 +1307,15 @@ class SummaryIndexService: """ # Get all segments for this document (excluding qa_model and re_segment) with session_factory.create_session() as session: - segments = ( - session.query(DocumentSegment.id) - .where( - DocumentSegment.document_id == document_id, - DocumentSegment.status != "re_segment", - DocumentSegment.tenant_id == tenant_id, - ) - .all() + segment_ids = list( + session.scalars( + select(DocumentSegment.id).where( + DocumentSegment.document_id == document_id, + DocumentSegment.status != "re_segment", + DocumentSegment.tenant_id == tenant_id, + ) + ).all() ) - segment_ids = [seg.id for seg in segments] if not segment_ids: return None @@ -1296,15 +1353,13 @@ class SummaryIndexService: # Get all segments for these documents (excluding qa_model and re_segment) with session_factory.create_session() as session: - segments = ( - session.query(DocumentSegment.id, DocumentSegment.document_id) - .where( + segments = session.execute( + select(DocumentSegment.id, DocumentSegment.document_id).where( DocumentSegment.document_id.in_(document_ids), DocumentSegment.status != "re_segment", DocumentSegment.tenant_id == tenant_id, ) - .all() - ) + ).all() # Group segments by document_id document_segments_map: dict[str, list[str]] = {} @@ -1352,7 +1407,7 @@ class SummaryIndexService: def get_document_summary_status_detail( document_id: str, dataset_id: str, - ) -> dict[str, Any]: + ) -> DocumentSummaryStatusDetailDict: """ Get detailed summary status for a document. @@ -1403,7 +1458,7 @@ class SummaryIndexService: SummaryStatus.NOT_STARTED: 0, } - summary_list = [] + summary_list: list[SummaryEntryDict] = [] for segment in segments: summary = summary_map.get(segment.id) if summary: @@ -1438,8 +1493,8 @@ class SummaryIndexService: } ) - return { - "total_segments": total_segments, - "summary_status": status_counts, - "summaries": summary_list, - } + return DocumentSummaryStatusDetailDict( + total_segments=total_segments, + summary_status=cast(dict[str, int], status_counts), + summaries=summary_list, + ) diff --git a/api/services/tag_service.py b/api/services/tag_service.py index 194622bd862..1882c855ea7 100644 --- a/api/services/tag_service.py +++ b/api/services/tag_service.py @@ -2,6 +2,7 @@ import uuid import sqlalchemy as sa from flask_login import current_user +from pydantic import BaseModel, Field from sqlalchemy import func, select from werkzeug.exceptions import NotFound @@ -11,6 +12,28 @@ from models.enums import TagType from models.model import App, Tag, TagBinding +class SaveTagPayload(BaseModel): + name: str = Field(min_length=1, max_length=50) + type: TagType + + +class UpdateTagPayload(BaseModel): + name: str = Field(min_length=1, max_length=50) + type: TagType + + +class TagBindingCreatePayload(BaseModel): + tag_ids: list[str] + target_id: str + type: TagType + + +class TagBindingDeletePayload(BaseModel): + tag_id: str + target_id: str + type: TagType + + class TagService: @staticmethod def get_tags(tag_type: str, current_tenant_id: str, keyword: str | None = None): @@ -78,12 +101,12 @@ class TagService: return tags or [] @staticmethod - def save_tags(args: dict) -> Tag: - if TagService.get_tag_by_tag_name(args["type"], current_user.current_tenant_id, args["name"]): + def save_tags(payload: SaveTagPayload) -> Tag: + if TagService.get_tag_by_tag_name(payload.type, current_user.current_tenant_id, payload.name): raise ValueError("Tag name already exists") tag = Tag( - name=args["name"], - type=TagType(args["type"]), + name=payload.name, + type=TagType(payload.type), created_by=current_user.id, tenant_id=current_user.current_tenant_id, ) @@ -93,13 +116,24 @@ class TagService: return tag @staticmethod - def update_tags(args: dict, tag_id: str) -> Tag: - if TagService.get_tag_by_tag_name(args.get("type", ""), current_user.current_tenant_id, args.get("name", "")): - raise ValueError("Tag name already exists") + def update_tags(payload: UpdateTagPayload, tag_id: str) -> Tag: tag = db.session.scalar(select(Tag).where(Tag.id == tag_id).limit(1)) if not tag: raise NotFound("Tag not found") - tag.name = args["name"] + if payload.name != tag.name: + existing = db.session.scalar( + select(Tag) + .where( + Tag.name == payload.name, + Tag.tenant_id == current_user.current_tenant_id, + Tag.type == tag.type, + Tag.id != tag_id, + ) + .limit(1) + ) + if existing: + raise ValueError("Tag name already exists") + tag.name = payload.name db.session.commit() return tag @@ -122,21 +156,19 @@ class TagService: db.session.commit() @staticmethod - def save_tag_binding(args): - # check if target exists - TagService.check_target_exists(args["type"], args["target_id"]) - # save tag binding - for tag_id in args["tag_ids"]: + def save_tag_binding(payload: TagBindingCreatePayload): + TagService.check_target_exists(payload.type, payload.target_id) + for tag_id in payload.tag_ids: tag_binding = db.session.scalar( select(TagBinding) - .where(TagBinding.tag_id == tag_id, TagBinding.target_id == args["target_id"]) + .where(TagBinding.tag_id == tag_id, TagBinding.target_id == payload.target_id) .limit(1) ) if tag_binding: continue new_tag_binding = TagBinding( tag_id=tag_id, - target_id=args["target_id"], + target_id=payload.target_id, tenant_id=current_user.current_tenant_id, created_by=current_user.id, ) @@ -144,17 +176,15 @@ class TagService: db.session.commit() @staticmethod - def delete_tag_binding(args): - # check if target exists - TagService.check_target_exists(args["type"], args["target_id"]) - # delete tag binding - tag_bindings = db.session.scalar( + def delete_tag_binding(payload: TagBindingDeletePayload): + TagService.check_target_exists(payload.type, payload.target_id) + tag_binding = db.session.scalar( select(TagBinding) - .where(TagBinding.target_id == args["target_id"], TagBinding.tag_id == args["tag_id"]) + .where(TagBinding.target_id == payload.target_id, TagBinding.tag_id == payload.tag_id) .limit(1) ) - if tag_bindings: - db.session.delete(tag_bindings) + if tag_binding: + db.session.delete(tag_binding) db.session.commit() @staticmethod diff --git a/api/services/tools/api_tools_manage_service.py b/api/services/tools/api_tools_manage_service.py index dfc0c2c63f7..5ff2c217492 100644 --- a/api/services/tools/api_tools_manage_service.py +++ b/api/services/tools/api_tools_manage_service.py @@ -2,9 +2,9 @@ import json import logging from typing import Any, TypedDict, cast -from graphon.model_runtime.utils.encoders import jsonable_encoder from httpx import get from sqlalchemy import select +from sqlalchemy.orm import sessionmaker from core.entities.provider_entities import ProviderConfig from core.tools.__base.tool_runtime import ToolRuntime @@ -16,11 +16,13 @@ from core.tools.entities.tool_entities import ( ApiProviderAuthType, ApiProviderSchemaType, ) +from core.tools.errors import ApiToolProviderNotFoundError from core.tools.tool_label_manager import ToolLabelManager from core.tools.tool_manager import ToolManager from core.tools.utils.encryption import create_tool_provider_encrypter from core.tools.utils.parser import ApiBasedToolSchemaParser from extensions.ext_database import db +from graphon.model_runtime.utils.encoders import jsonable_encoder from models.tools import ApiToolProvider from services.tools.tools_transform_service import ToolTransformService @@ -92,7 +94,7 @@ class ApiToolManageService: @staticmethod def convert_schema_to_tool_bundles( - schema: str, extra_info: dict | None = None + schema: str, extra_info: dict[str, Any] | None = None ) -> tuple[list[ApiToolBundle], ApiProviderSchemaType]: """ convert schema to tool bundles @@ -109,78 +111,92 @@ class ApiToolManageService: user_id: str, tenant_id: str, provider_name: str, - icon: dict, - credentials: dict, + icon: dict[str, Any], + credentials: dict[str, Any], schema_type: ApiProviderSchemaType, schema: str, privacy_policy: str, custom_disclaimer: str, labels: list[str], - ): + ) -> dict[str, Any]: """ - create api tool provider + Create a new API tool provider. + + :param user_id: The ID of the user creating the provider. + :param tenant_id: The ID of the workspace/tenant. + :param provider_name: The name of the API tool provider. + :param icon: The icon configuration for the provider. + :param credentials: The credentials for the provider. + :param schema_type: The type of schema (e.g., OpenAPI). + :param schema: The raw schema string. + :param privacy_policy: The privacy policy URL or text. + :param custom_disclaimer: Custom disclaimer text. + :param labels: A list of labels for the provider. + :return: A dictionary indicating the result status. """ + provider_name = provider_name.strip() # check if the provider exists - provider = db.session.scalar( - select(ApiToolProvider) - .where( - ApiToolProvider.tenant_id == tenant_id, - ApiToolProvider.name == provider_name, + # Create new session with automatic transaction management + with sessionmaker(db.engine, expire_on_commit=False).begin() as _session: + provider: ApiToolProvider | None = _session.scalar( + select(ApiToolProvider) + .where( + ApiToolProvider.tenant_id == tenant_id, + ApiToolProvider.name == provider_name, + ) + .limit(1) ) - .limit(1) - ) - if provider is not None: - raise ValueError(f"provider {provider_name} already exists") + if provider is not None: + raise ValueError(f"provider {provider_name} already exists") - # parse openapi to tool bundle - extra_info: dict[str, str] = {} - # extra info like description will be set here - tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info) + # parse openapi to tool bundle + extra_info: dict[str, str] = {} + # extra info like description will be set here + tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info) - if len(tool_bundles) > 100: - raise ValueError("the number of apis should be less than 100") + if len(tool_bundles) > 100: + raise ValueError("the number of apis should be less than 100") - # create db provider - db_provider = ApiToolProvider( - tenant_id=tenant_id, - user_id=user_id, - name=provider_name, - icon=json.dumps(icon), - schema=schema, - description=extra_info.get("description", ""), - schema_type_str=schema_type, - tools_str=json.dumps(jsonable_encoder(tool_bundles)), - credentials_str="{}", - privacy_policy=privacy_policy, - custom_disclaimer=custom_disclaimer, - ) + # create API tool provider + api_tool_provider = ApiToolProvider( + tenant_id=tenant_id, + user_id=user_id, + name=provider_name, + icon=json.dumps(icon), + schema=schema, + description=extra_info.get("description", ""), + schema_type_str=schema_type, + tools_str=json.dumps(jsonable_encoder(tool_bundles)), + credentials_str="{}", + privacy_policy=privacy_policy, + custom_disclaimer=custom_disclaimer, + ) - if "auth_type" not in credentials: - raise ValueError("auth_type is required") + if "auth_type" not in credentials: + raise ValueError("auth_type is required") - # get auth type, none or api key - auth_type = ApiProviderAuthType.value_of(credentials["auth_type"]) + # get auth type, none or api key + auth_type = ApiProviderAuthType.value_of(credentials["auth_type"]) - # create provider entity - provider_controller = ApiToolProviderController.from_db(db_provider, auth_type) - # load tools into provider entity - provider_controller.load_bundled_tools(tool_bundles) + # create provider entity + provider_controller = ApiToolProviderController.from_db(api_tool_provider, auth_type) + # load tools into provider entity + provider_controller.load_bundled_tools(tool_bundles) - # encrypt credentials - encrypter, _ = create_tool_provider_encrypter( - tenant_id=tenant_id, - controller=provider_controller, - ) - db_provider.credentials_str = json.dumps(encrypter.encrypt(credentials)) + # encrypt credentials + encrypter, _ = create_tool_provider_encrypter( + tenant_id=tenant_id, + controller=provider_controller, + ) + api_tool_provider.credentials_str = json.dumps(encrypter.encrypt(credentials)) - db.session.add(db_provider) - db.session.commit() + _session.add(api_tool_provider) - # update labels - ToolLabelManager.update_tool_labels(provider_controller, labels) + # update labels + ToolLabelManager.update_tool_labels(provider_controller, labels, _session) return {"result": "success"} @@ -212,16 +228,25 @@ class ApiToolManageService: @staticmethod def list_api_tool_provider_tools(user_id: str, tenant_id: str, provider_name: str) -> list[ToolApiEntity]: """ - list api tool provider tools + List tools provided by a specific API tool provider. + + :param user_id: The ID of the user requesting the list. + :param tenant_id: The ID of the workspace/tenant. + :param provider_name: The name of the API tool provider. + :return: A list of ToolApiEntity objects. """ - provider: ApiToolProvider | None = db.session.scalar( - select(ApiToolProvider) - .where( - ApiToolProvider.tenant_id == tenant_id, - ApiToolProvider.name == provider_name, + + # create new session with automatic transaction management + provider: ApiToolProvider | None = None + with sessionmaker(db.engine, expire_on_commit=False).begin() as _session: + provider = _session.scalar( + select(ApiToolProvider) + .where( + ApiToolProvider.tenant_id == tenant_id, + ApiToolProvider.name == provider_name, + ) + .limit(1) ) - .limit(1) - ) if provider is None: raise ValueError(f"you have not added provider {provider_name}") @@ -244,110 +269,140 @@ class ApiToolManageService: tenant_id: str, provider_name: str, original_provider: str, - icon: dict, - credentials: dict, + icon: dict[str, Any], + credentials: dict[str, Any], _schema_type: ApiProviderSchemaType, schema: str, privacy_policy: str | None, custom_disclaimer: str, labels: list[str], - ): + ) -> dict[str, Any]: """ - update api tool provider + Update an existing API tool provider. + + :param user_id: The ID of the user updating the provider. + :param tenant_id: The ID of the workspace/tenant. + :param provider_name: The new name of the API tool provider. + :param original_provider: The original name of the API tool provider. + :param icon: The icon configuration for the provider. + :param credentials: The credentials for the provider. + :param _schema_type: The type of schema (e.g., OpenAPI). + :param schema: The raw schema string. + :param privacy_policy: The privacy policy URL or text. + :param custom_disclaimer: Custom disclaimer text. + :param labels: A list of labels for the provider. + :return: A dictionary indicating the result status. """ + provider_name = provider_name.strip() # check if the provider exists - provider = db.session.scalar( - select(ApiToolProvider) - .where( - ApiToolProvider.tenant_id == tenant_id, - ApiToolProvider.name == original_provider, + # create new session with automatic transaction management + with sessionmaker(db.engine, expire_on_commit=False).begin() as _session: + provider: ApiToolProvider | None = _session.scalar( + select(ApiToolProvider) + .where( + ApiToolProvider.tenant_id == tenant_id, + ApiToolProvider.name == original_provider, + ) + .limit(1) ) - .limit(1) - ) - if provider is None: - raise ValueError(f"api provider {provider_name} does not exists") - # parse openapi to tool bundle - extra_info: dict[str, str] = {} - # extra info like description will be set here - tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info) + if provider is None: + raise ApiToolProviderNotFoundError(provider_name=original_provider, tenant_id=tenant_id) - # update db provider - provider.name = provider_name - provider.icon = json.dumps(icon) - provider.schema = schema - provider.description = extra_info.get("description", "") - provider.schema_type_str = schema_type - provider.tools_str = json.dumps(jsonable_encoder(tool_bundles)) - provider.privacy_policy = privacy_policy - provider.custom_disclaimer = custom_disclaimer + # parse openapi to tool bundle + extra_info: dict[str, str] = {} + # extra info like description will be set here + tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info) - if "auth_type" not in credentials: - raise ValueError("auth_type is required") + # update db provider + provider.name = provider_name + provider.icon = json.dumps(icon) + provider.schema = schema + provider.description = extra_info.get("description", "") + provider.schema_type_str = schema_type + provider.tools_str = json.dumps(jsonable_encoder(tool_bundles)) + provider.privacy_policy = privacy_policy + provider.custom_disclaimer = custom_disclaimer - # get auth type, none or api key - auth_type = ApiProviderAuthType.value_of(credentials["auth_type"]) + if "auth_type" not in credentials: + raise ValueError("auth_type is required") - # create provider entity - provider_controller = ApiToolProviderController.from_db(provider, auth_type) - # load tools into provider entity - provider_controller.load_bundled_tools(tool_bundles) + # get auth type, none or api key + auth_type = ApiProviderAuthType.value_of(credentials["auth_type"]) - # get original credentials if exists - encrypter, cache = create_tool_provider_encrypter( - tenant_id=tenant_id, - controller=provider_controller, - ) + # create provider entity + provider_controller = ApiToolProviderController.from_db(provider, auth_type) + # load tools into provider entity + provider_controller.load_bundled_tools(tool_bundles) - original_credentials = encrypter.decrypt(provider.credentials) - masked_credentials = encrypter.mask_plugin_credentials(original_credentials) - # check if the credential has changed, save the original credential - for name, value in credentials.items(): - if name in masked_credentials and value == masked_credentials[name]: - credentials[name] = original_credentials[name] + # get original credentials if exists + encrypter, cache = create_tool_provider_encrypter( + tenant_id=tenant_id, + controller=provider_controller, + ) - credentials = dict(encrypter.encrypt(credentials)) - provider.credentials_str = json.dumps(credentials) + original_credentials = encrypter.decrypt(provider.credentials) + masked_credentials = encrypter.mask_plugin_credentials(original_credentials) - db.session.add(provider) - db.session.commit() + # check if the credential has changed, save the original credential + for name, value in credentials.items(): + if name in masked_credentials and value == masked_credentials[name]: + credentials[name] = original_credentials[name] + + credentials = dict(encrypter.encrypt(credentials)) + provider.credentials_str = json.dumps(credentials) + + _session.add(provider) + + # update labels + ToolLabelManager.update_tool_labels(provider_controller, labels, _session) # delete cache cache.delete() - # update labels - ToolLabelManager.update_tool_labels(provider_controller, labels) - return {"result": "success"} @staticmethod def delete_api_tool_provider(user_id: str, tenant_id: str, provider_name: str): """ - delete tool provider + Delete an API tool provider. + + :param user_id: The ID of the user performing the deletion operation. + :param tenant_id: The ID of the workspace/tenant where the provider belongs. + :param provider_name: The unique name of the API tool provider to be deleted. + :raises ValueError: If the specified provider does not exist in the tenant. + :return: A dictionary indicating the result status. """ - provider = db.session.scalar( - select(ApiToolProvider) - .where( - ApiToolProvider.tenant_id == tenant_id, - ApiToolProvider.name == provider_name, + + # create new session with automatic transaction management + with sessionmaker(db.engine, expire_on_commit=False).begin() as _session: + provider: ApiToolProvider | None = _session.scalar( + select(ApiToolProvider) + .where( + ApiToolProvider.tenant_id == tenant_id, + ApiToolProvider.name == provider_name, + ) + .limit(1) ) - .limit(1) - ) - if provider is None: - raise ValueError(f"you have not added provider {provider_name}") + if provider is None: + raise ValueError(f"you have not added provider {provider_name}") - db.session.delete(provider) - db.session.commit() + _session.delete(provider) return {"result": "success"} @staticmethod - def get_api_tool_provider(user_id: str, tenant_id: str, provider: str): + def get_api_tool_provider(user_id: str, tenant_id: str, provider: str) -> dict[str, Any]: """ - get api tool provider + Get API tool provider details. + + :param user_id: The ID of the user requesting the provider. + :param tenant_id: The ID of the workspace/tenant. + :param provider: The name of the API tool provider. + :return: A dictionary containing the provider details. """ return ToolManager.user_get_api_provider(provider=provider, tenant_id=tenant_id) @@ -356,14 +411,24 @@ class ApiToolManageService: tenant_id: str, provider_name: str, tool_name: str, - credentials: dict, - parameters: dict, + credentials: dict[str, Any], + parameters: dict[str, Any], schema_type: ApiProviderSchemaType, schema: str, - ): + ) -> dict[str, Any]: """ - test api tool before adding api tool provider + Test an API tool before adding the API tool provider. + + :param tenant_id: The ID of the workspace/tenant. + :param provider_name: The name of the API tool provider. + :param tool_name: The name of the specific tool to test. + :param credentials: The credentials for the provider. + :param parameters: The parameters to pass to the tool. + :param schema_type: The type of schema (e.g., OpenAPI). + :param schema: The raw schema string. + :return: A dictionary containing the result or error message. """ + if schema_type not in [member.value for member in ApiProviderSchemaType]: raise ValueError(f"invalid schema type {schema_type}") @@ -377,18 +442,21 @@ class ApiToolManageService: if tool_bundle is None: raise ValueError(f"invalid tool name {tool_name}") - db_provider = db.session.scalar( - select(ApiToolProvider) - .where( - ApiToolProvider.tenant_id == tenant_id, - ApiToolProvider.name == provider_name, + # create new session with automatic transaction management to get the provider + provider: ApiToolProvider | None = None + with sessionmaker(db.engine, expire_on_commit=False).begin() as _session: + provider = _session.scalar( + select(ApiToolProvider) + .where( + ApiToolProvider.tenant_id == tenant_id, + ApiToolProvider.name == provider_name, + ) + .limit(1) ) - .limit(1) - ) - if not db_provider: + if provider is None: # create a fake db provider - db_provider = ApiToolProvider( + provider = ApiToolProvider( tenant_id="", user_id="", name="", @@ -407,12 +475,12 @@ class ApiToolManageService: auth_type = ApiProviderAuthType.value_of(credentials["auth_type"]) # create provider entity - provider_controller = ApiToolProviderController.from_db(db_provider, auth_type) + provider_controller = ApiToolProviderController.from_db(provider, auth_type) # load tools into provider entity provider_controller.load_bundled_tools(tool_bundles) # decrypt credentials - if db_provider.id: + if provider.id: encrypter, _ = create_tool_provider_encrypter( tenant_id=tenant_id, controller=provider_controller, @@ -443,14 +511,21 @@ class ApiToolManageService: @staticmethod def list_api_tools(tenant_id: str) -> list[ToolProviderApiEntity]: """ - list api tools + List all API tools for a specific tenant. + + :param tenant_id: The ID of the workspace/tenant. + :return: A list of ToolProviderApiEntity objects. """ # get all api providers - db_providers = db.session.scalars(select(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id)).all() + # create new session with automatic transaction management + providers: list[ApiToolProvider] = [] + with sessionmaker(db.engine, expire_on_commit=False).begin() as _session: + providers = list( + _session.scalars(select(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id)).all() + ) result: list[ToolProviderApiEntity] = [] - - for provider in db_providers: + for provider in providers: # convert provider controller to user provider provider_controller = ToolTransformService.api_provider_to_controller(db_provider=provider) labels = ToolLabelManager.get_tool_labels(provider_controller) diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index d529d2f065f..7bd056b8a0d 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -4,8 +4,8 @@ from collections.abc import Mapping from pathlib import Path from typing import Any -from sqlalchemy import exists, select -from sqlalchemy.orm import Session +from sqlalchemy import delete, exists, func, select, update +from sqlalchemy.orm import Session, sessionmaker from configs import dify_config from constants import HIDDEN_VALUE, UNKNOWN_VALUE @@ -46,13 +46,16 @@ class BuiltinToolManageService: delete custom oauth client params """ tool_provider = ToolProviderID(provider) - with Session(db.engine) as session: - session.query(ToolOAuthTenantClient).filter_by( - tenant_id=tenant_id, - provider=tool_provider.provider_name, - plugin_id=tool_provider.plugin_id, - ).delete() - session.commit() + with sessionmaker(bind=db.engine).begin() as session: + session.execute( + delete(ToolOAuthTenantClient) + .where( + ToolOAuthTenantClient.tenant_id == tenant_id, + ToolOAuthTenantClient.provider == tool_provider.provider_name, + ToolOAuthTenantClient.plugin_id == tool_provider.plugin_id, + ) + .execution_options(synchronize_session=False) + ) return {"result": "success"} @staticmethod @@ -144,21 +147,21 @@ class BuiltinToolManageService: tenant_id: str, provider: str, credential_id: str, - credentials: dict | None = None, + credentials: dict[str, Any] | None = None, name: str | None = None, ): """ update builtin tool provider """ - with Session(db.engine) as session: + with sessionmaker(bind=db.engine).begin() as session: # get if the provider exists - db_provider = ( - session.query(BuiltinToolProvider) + db_provider = session.scalar( + select(BuiltinToolProvider) .where( BuiltinToolProvider.tenant_id == tenant_id, BuiltinToolProvider.id == credential_id, ) - .first() + .limit(1) ) if db_provider is None: raise ValueError(f"you have not added provider {provider}") @@ -174,7 +177,7 @@ class BuiltinToolManageService: ) original_credentials = encrypter.decrypt(db_provider.credentials) - new_credentials: dict = { + new_credentials: dict[str, Any] = { key: value if value != HIDDEN_VALUE else original_credentials.get(key, UNKNOWN_VALUE) for key, value in credentials.items() } @@ -203,9 +206,7 @@ class BuiltinToolManageService: db_provider.name = name - session.commit() except Exception as e: - session.rollback() raise ValueError(str(e)) return {"result": "success"} @@ -215,14 +216,14 @@ class BuiltinToolManageService: api_type: CredentialType, tenant_id: str, provider: str, - credentials: dict, + credentials: dict[str, Any], expires_at: int = -1, name: str | None = None, ): """ add builtin tool provider """ - with Session(db.engine) as session: + with sessionmaker(bind=db.engine).begin() as session: try: lock = f"builtin_tool_provider_create_lock:{tenant_id}_{provider}" with redis_client.lock(lock, timeout=20): @@ -231,7 +232,13 @@ class BuiltinToolManageService: raise ValueError(f"provider {provider} does not need credentials") provider_count = ( - session.query(BuiltinToolProvider).filter_by(tenant_id=tenant_id, provider=provider).count() + session.scalar( + select(func.count(BuiltinToolProvider.id)).where( + BuiltinToolProvider.tenant_id == tenant_id, + BuiltinToolProvider.provider == provider, + ) + ) + or 0 ) # check if the provider count is reached the limit @@ -281,9 +288,7 @@ class BuiltinToolManageService: ) session.add(db_provider) - session.commit() except Exception as e: - session.rollback() raise ValueError(str(e)) return {"result": "success"} @@ -309,16 +314,15 @@ class BuiltinToolManageService: def generate_builtin_tool_provider_name( session: Session, tenant_id: str, provider: str, credential_type: CredentialType ) -> str: - db_providers = ( - session.query(BuiltinToolProvider) - .filter_by( - tenant_id=tenant_id, - provider=provider, - credential_type=credential_type, + db_providers = session.scalars( + select(BuiltinToolProvider) + .where( + BuiltinToolProvider.tenant_id == tenant_id, + BuiltinToolProvider.provider == provider, + BuiltinToolProvider.credential_type == credential_type, ) .order_by(BuiltinToolProvider.created_at.desc()) - .all() - ) + ).all() return generate_incremental_name( [provider.name for provider in db_providers], f"{credential_type.get_name()}", @@ -379,21 +383,20 @@ class BuiltinToolManageService: """ delete tool provider """ - with Session(db.engine) as session: - db_provider = ( - session.query(BuiltinToolProvider) + with sessionmaker(bind=db.engine).begin() as session: + db_provider = session.scalar( + select(BuiltinToolProvider) .where( BuiltinToolProvider.tenant_id == tenant_id, BuiltinToolProvider.id == credential_id, ) - .first() + .limit(1) ) if db_provider is None: raise ValueError(f"you have not added provider {provider}") session.delete(db_provider) - session.commit() # delete cache provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) @@ -409,20 +412,31 @@ class BuiltinToolManageService: """ set default provider """ - with Session(db.engine) as session: + with sessionmaker(bind=db.engine).begin() as session: # get provider - target_provider = session.query(BuiltinToolProvider).filter_by(id=id, tenant_id=tenant_id).first() + target_provider = session.scalar( + select(BuiltinToolProvider) + .where(BuiltinToolProvider.id == id, BuiltinToolProvider.tenant_id == tenant_id) + .limit(1) + ) if target_provider is None: raise ValueError("provider not found") # clear default provider - session.query(BuiltinToolProvider).filter_by( - tenant_id=tenant_id, user_id=user_id, provider=provider, is_default=True - ).update({"is_default": False}) + session.execute( + update(BuiltinToolProvider) + .where( + BuiltinToolProvider.tenant_id == tenant_id, + BuiltinToolProvider.user_id == user_id, + BuiltinToolProvider.provider == provider, + BuiltinToolProvider.is_default.is_(True), + ) + .values(is_default=False) + .execution_options(synchronize_session=False) + ) # set new default provider target_provider.is_default = True - session.commit() return {"result": "success"} @@ -433,10 +447,13 @@ class BuiltinToolManageService: """ tool_provider = ToolProviderID(provider_name) with Session(db.engine, autoflush=False) as session: - system_client: ToolOAuthSystemClient | None = ( - session.query(ToolOAuthSystemClient) - .filter_by(plugin_id=tool_provider.plugin_id, provider=tool_provider.provider_name) - .first() + system_client = session.scalar( + select(ToolOAuthSystemClient) + .where( + ToolOAuthSystemClient.plugin_id == tool_provider.plugin_id, + ToolOAuthSystemClient.provider == tool_provider.provider_name, + ) + .limit(1) ) return system_client is not None @@ -447,15 +464,15 @@ class BuiltinToolManageService: """ tool_provider = ToolProviderID(provider) with Session(db.engine, autoflush=False) as session: - user_client: ToolOAuthTenantClient | None = ( - session.query(ToolOAuthTenantClient) - .filter_by( - tenant_id=tenant_id, - provider=tool_provider.provider_name, - plugin_id=tool_provider.plugin_id, - enabled=True, + user_client = session.scalar( + select(ToolOAuthTenantClient) + .where( + ToolOAuthTenantClient.tenant_id == tenant_id, + ToolOAuthTenantClient.provider == tool_provider.provider_name, + ToolOAuthTenantClient.plugin_id == tool_provider.plugin_id, + ToolOAuthTenantClient.enabled.is_(True), ) - .first() + .limit(1) ) return user_client is not None and user_client.enabled @@ -472,15 +489,15 @@ class BuiltinToolManageService: cache=NoOpProviderCredentialCache(), ) with Session(db.engine, autoflush=False) as session: - user_client: ToolOAuthTenantClient | None = ( - session.query(ToolOAuthTenantClient) - .filter_by( - tenant_id=tenant_id, - provider=tool_provider.provider_name, - plugin_id=tool_provider.plugin_id, - enabled=True, + user_client = session.scalar( + select(ToolOAuthTenantClient) + .where( + ToolOAuthTenantClient.tenant_id == tenant_id, + ToolOAuthTenantClient.provider == tool_provider.provider_name, + ToolOAuthTenantClient.plugin_id == tool_provider.plugin_id, + ToolOAuthTenantClient.enabled.is_(True), ) - .first() + .limit(1) ) oauth_params: Mapping[str, Any] | None = None if user_client: @@ -494,10 +511,13 @@ class BuiltinToolManageService: if not is_verified: return oauth_params - system_client: ToolOAuthSystemClient | None = ( - session.query(ToolOAuthSystemClient) - .filter_by(plugin_id=tool_provider.plugin_id, provider=tool_provider.provider_name) - .first() + system_client = session.scalar( + select(ToolOAuthSystemClient) + .where( + ToolOAuthSystemClient.plugin_id == tool_provider.plugin_id, + ToolOAuthSystemClient.provider == tool_provider.provider_name, + ) + .limit(1) ) if system_client: try: @@ -589,8 +609,8 @@ class BuiltinToolManageService: provider_name = provider_id_entity.provider_name if provider_id_entity.organization != "langgenius": - provider = ( - session.query(BuiltinToolProvider) + provider = session.scalar( + select(BuiltinToolProvider) .where( BuiltinToolProvider.tenant_id == tenant_id, BuiltinToolProvider.provider == full_provider_name, @@ -599,11 +619,11 @@ class BuiltinToolManageService: BuiltinToolProvider.is_default.desc(), # default=True first BuiltinToolProvider.created_at.asc(), # oldest first ) - .first() + .limit(1) ) else: - provider = ( - session.query(BuiltinToolProvider) + provider = session.scalar( + select(BuiltinToolProvider) .where( BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_name) @@ -613,7 +633,7 @@ class BuiltinToolManageService: BuiltinToolProvider.is_default.desc(), # default=True first BuiltinToolProvider.created_at.asc(), # oldest first ) - .first() + .limit(1) ) if provider is None: @@ -623,21 +643,21 @@ class BuiltinToolManageService: return provider except Exception: # it's an old provider without organization - return ( - session.query(BuiltinToolProvider) + return session.scalar( + select(BuiltinToolProvider) .where(BuiltinToolProvider.tenant_id == tenant_id, BuiltinToolProvider.provider == provider_name) .order_by( BuiltinToolProvider.is_default.desc(), # default=True first BuiltinToolProvider.created_at.asc(), # oldest first ) - .first() + .limit(1) ) @staticmethod def save_custom_oauth_client_params( tenant_id: str, provider: str, - client_params: dict | None = None, + client_params: dict[str, Any] | None = None, enable_oauth_custom_client: bool | None = None, ): """ @@ -654,15 +674,15 @@ class BuiltinToolManageService: if not isinstance(provider_controller, (BuiltinToolProviderController, PluginToolProviderController)): raise ValueError(f"Provider {provider} is not a builtin or plugin provider") - with Session(db.engine) as session: - custom_client_params = ( - session.query(ToolOAuthTenantClient) - .filter_by( - tenant_id=tenant_id, - plugin_id=tool_provider.plugin_id, - provider=tool_provider.provider_name, + with sessionmaker(bind=db.engine).begin() as session: + custom_client_params = session.scalar( + select(ToolOAuthTenantClient) + .where( + ToolOAuthTenantClient.tenant_id == tenant_id, + ToolOAuthTenantClient.plugin_id == tool_provider.plugin_id, + ToolOAuthTenantClient.provider == tool_provider.provider_name, ) - .first() + .limit(1) ) # if the record does not exist, create a basic record @@ -690,7 +710,6 @@ class BuiltinToolManageService: if enable_oauth_custom_client is not None: custom_client_params.enabled = enable_oauth_custom_client - session.commit() return {"result": "success"} @staticmethod @@ -700,14 +719,14 @@ class BuiltinToolManageService: """ with Session(db.engine) as session: tool_provider = ToolProviderID(provider) - custom_oauth_client_params: ToolOAuthTenantClient | None = ( - session.query(ToolOAuthTenantClient) - .filter_by( - tenant_id=tenant_id, - plugin_id=tool_provider.plugin_id, - provider=tool_provider.provider_name, + custom_oauth_client_params = session.scalar( + select(ToolOAuthTenantClient) + .where( + ToolOAuthTenantClient.tenant_id == tenant_id, + ToolOAuthTenantClient.plugin_id == tool_provider.plugin_id, + ToolOAuthTenantClient.provider == tool_provider.provider_name, ) - .first() + .limit(1) ) if custom_oauth_client_params is None: return {} diff --git a/api/services/tools/mcp_tools_manage_service.py b/api/services/tools/mcp_tools_manage_service.py index deb26438a8f..89762d6772a 100644 --- a/api/services/tools/mcp_tools_manage_service.py +++ b/api/services/tools/mcp_tools_manage_service.py @@ -17,6 +17,7 @@ from core.helper import encrypter from core.helper.provider_cache import NoOpProviderCredentialCache from core.mcp.auth.auth_flow import auth from core.mcp.auth_client import MCPClientWithAuthRetry +from core.mcp.entities import AuthActionType, AuthResult from core.mcp.error import MCPAuthError, MCPError from core.mcp.types import Tool as MCPTool from core.tools.entities.api_entities import ToolProviderApiEntity @@ -285,7 +286,7 @@ class MCPToolManageService: # Batch query all users to avoid N+1 problem user_ids = {provider.user_id for provider in mcp_providers} - users = self._session.query(Account).where(Account.id.in_(user_ids)).all() + users = self._session.scalars(select(Account).where(Account.id.in_(user_ids))).all() user_name_map = {user.id: user.name for user in users} return [ @@ -496,7 +497,13 @@ class MCPToolManageService: ) as mcp_client: return mcp_client.list_tools() - def execute_auth_actions(self, auth_result: Any) -> dict[str, str]: + _ACTION_TO_OAUTH: dict[AuthActionType, OAuthDataType] = { + AuthActionType.SAVE_CLIENT_INFO: OAuthDataType.CLIENT_INFO, + AuthActionType.SAVE_TOKENS: OAuthDataType.TOKENS, + AuthActionType.SAVE_CODE_VERIFIER: OAuthDataType.CODE_VERIFIER, + } + + def execute_auth_actions(self, auth_result: AuthResult) -> dict[str, str]: """ Execute the actions returned by the auth function. @@ -508,19 +515,13 @@ class MCPToolManageService: Returns: The response from the auth result """ - from core.mcp.entities import AuthAction, AuthActionType - - action: AuthAction for action in auth_result.actions: if action.provider_id is None or action.tenant_id is None: continue - if action.action_type == AuthActionType.SAVE_CLIENT_INFO: - self.save_oauth_data(action.provider_id, action.tenant_id, action.data, OAuthDataType.CLIENT_INFO) - elif action.action_type == AuthActionType.SAVE_TOKENS: - self.save_oauth_data(action.provider_id, action.tenant_id, action.data, OAuthDataType.TOKENS) - elif action.action_type == AuthActionType.SAVE_CODE_VERIFIER: - self.save_oauth_data(action.provider_id, action.tenant_id, action.data, OAuthDataType.CODE_VERIFIER) + oauth_type = self._ACTION_TO_OAUTH.get(action.action_type) + if oauth_type is not None: + self.save_oauth_data(action.provider_id, action.tenant_id, action.data, oauth_type) return auth_result.response diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py index 7cd61e3162a..47aca9b0af3 100644 --- a/api/services/tools/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -1,7 +1,6 @@ -import json import logging from collections.abc import Mapping -from typing import Any, Union +from typing import Any from pydantic import TypeAdapter, ValidationError from yarl import URL @@ -21,6 +20,7 @@ from core.tools.entities.tool_entities import ( ApiProviderAuthType, ToolParameter, ToolProviderType, + emoji_icon_adapter, ) from core.tools.plugin_tool.provider import PluginToolProviderController from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter @@ -48,21 +48,30 @@ class ToolTransformService: URL(dify_config.CONSOLE_API_URL or "/") / "console" / "api" / "workspaces" / "current" / "tool-provider" ) - if provider_type == ToolProviderType.BUILT_IN: - return str(url_prefix / "builtin" / provider_name / "icon") - elif provider_type in {ToolProviderType.API, ToolProviderType.WORKFLOW}: - try: - if isinstance(icon, str): - return json.loads(icon) + match provider_type: + case ToolProviderType.BUILT_IN: + return str(url_prefix / "builtin" / provider_name / "icon") + case ToolProviderType.API | ToolProviderType.WORKFLOW: + try: + if isinstance(icon, str): + parsed = emoji_icon_adapter.validate_json(icon) + return {"background": parsed["background"], "content": parsed["content"]} + return {"background": icon["background"], "content": icon["content"]} + except (ValueError, ValidationError, KeyError): + return {"background": "#252525", "content": "\ud83d\ude01"} + case ToolProviderType.MCP: + if isinstance(icon, Mapping): + return {"background": icon.get("background", ""), "content": icon.get("content", "")} return icon - except (json.JSONDecodeError, ValueError): - return {"background": "#252525", "content": "\ud83d\ude01"} - elif provider_type == ToolProviderType.MCP: - return icon - return "" + case ToolProviderType.PLUGIN | ToolProviderType.APP | ToolProviderType.DATASET_RETRIEVAL: + return "" + case _: + return "" @staticmethod - def repack_provider(tenant_id: str, provider: Union[dict, ToolProviderApiEntity, PluginDatasourceProviderEntity]): + def repack_provider( + tenant_id: str, provider: dict[str, Any] | ToolProviderApiEntity | PluginDatasourceProviderEntity + ): """ repack provider @@ -419,7 +428,7 @@ class ToolTransformService: @staticmethod def convert_builtin_provider_to_credential_entity( - provider: BuiltinToolProvider, credentials: dict + provider: BuiltinToolProvider, credentials: dict[str, Any] ) -> ToolProviderCredentialApiEntity: return ToolProviderCredentialApiEntity( id=provider.id, diff --git a/api/services/tools/workflow_tools_manage_service.py b/api/services/tools/workflow_tools_manage_service.py index fb6b5bea24d..8f6600af035 100644 --- a/api/services/tools/workflow_tools_manage_service.py +++ b/api/services/tools/workflow_tools_manage_service.py @@ -1,19 +1,20 @@ import json import logging from datetime import datetime +from typing import Any -from graphon.model_runtime.utils.encoders import jsonable_encoder -from sqlalchemy import or_, select -from sqlalchemy.orm import Session +from sqlalchemy import delete, or_, select +from sqlalchemy.orm import sessionmaker from core.tools.__base.tool_provider import ToolProviderController from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity -from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration +from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration, emoji_icon_adapter from core.tools.tool_label_manager import ToolLabelManager from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils from core.tools.workflow_as_tool.provider import WorkflowToolProviderController from core.tools.workflow_as_tool.tool import WorkflowTool from extensions.ext_database import db +from graphon.model_runtime.utils.encoders import jsonable_encoder from models.model import App from models.tools import WorkflowToolProvider from models.workflow import Workflow @@ -35,37 +36,50 @@ class WorkflowToolManageService: workflow_app_id: str, name: str, label: str, - icon: dict, + icon: dict[str, Any], description: str, parameters: list[WorkflowToolParameterConfiguration], privacy_policy: str = "", labels: list[str] | None = None, ): # check if the name is unique - existing_workflow_tool_provider = ( - db.session.query(WorkflowToolProvider) - .where( - WorkflowToolProvider.tenant_id == tenant_id, - # name or app_id - or_(WorkflowToolProvider.name == name, WorkflowToolProvider.app_id == workflow_app_id), + existing_workflow_tool_provider: WorkflowToolProvider | None = None + with sessionmaker(db.engine, expire_on_commit=False).begin() as _session: + # query if the name or app_id exists + existing_workflow_tool_provider = _session.scalar( + select(WorkflowToolProvider) + .where( + WorkflowToolProvider.tenant_id == tenant_id, + # name or app_id + or_(WorkflowToolProvider.name == name, WorkflowToolProvider.app_id == workflow_app_id), + ) + .limit(1) ) - .first() - ) + # if the name or app_id exists raise error if existing_workflow_tool_provider is not None: raise ValueError(f"Tool with name {name} or app_id {workflow_app_id} already exists") - app: App | None = db.session.query(App).where(App.id == workflow_app_id, App.tenant_id == tenant_id).first() + # query the app + app: App | None = None + with sessionmaker(db.engine, expire_on_commit=False).begin() as _session: + app = _session.scalar(select(App).where(App.id == workflow_app_id, App.tenant_id == tenant_id).limit(1)) + # if not found raise error if app is None: raise ValueError(f"App {workflow_app_id} not found") + # query the workflow workflow: Workflow | None = app.workflow + + # if not found raise error if workflow is None: raise ValueError(f"Workflow not found for app {workflow_app_id}") + # check if workflow configuration is synced WorkflowToolConfigurationUtils.ensure_no_human_input_nodes(workflow.graph_dict) + # create workflow tool provider workflow_tool_provider = WorkflowToolProvider( tenant_id=tenant_id, user_id=user_id, @@ -82,15 +96,18 @@ class WorkflowToolManageService: try: WorkflowToolProviderController.from_db(workflow_tool_provider) except Exception as e: + logger.warning(e, exc_info=True) raise ValueError(str(e)) - with Session(db.engine, expire_on_commit=False) as session, session.begin(): - session.add(workflow_tool_provider) + with sessionmaker(db.engine, expire_on_commit=False).begin() as _session: + _session.add(workflow_tool_provider) + # keep the session open to make orm instances in the same session if labels is not None: ToolLabelManager.update_tool_labels( ToolTransformService.workflow_provider_to_controller(workflow_tool_provider), labels ) + return {"result": "success"} @classmethod @@ -101,7 +118,7 @@ class WorkflowToolManageService: workflow_tool_id: str, name: str, label: str, - icon: dict, + icon: dict[str, Any], description: str, parameters: list[WorkflowToolParameterConfiguration], privacy_policy: str = "", @@ -109,6 +126,7 @@ class WorkflowToolManageService: ): """ Update a workflow tool. + :param user_id: the user id :param tenant_id: the tenant id :param workflow_tool_id: workflow tool id @@ -121,62 +139,82 @@ class WorkflowToolManageService: :param labels: labels :return: the updated tool """ - # check if the name is unique - existing_workflow_tool_provider = ( - db.session.query(WorkflowToolProvider) - .where( - WorkflowToolProvider.tenant_id == tenant_id, - WorkflowToolProvider.name == name, - WorkflowToolProvider.id != workflow_tool_id, - ) - .first() - ) + existing_workflow_tool_provider: WorkflowToolProvider | None = None + with sessionmaker(db.engine, expire_on_commit=False).begin() as _session: + # query if the name exists for other tools + existing_workflow_tool_provider = _session.scalar( + select(WorkflowToolProvider) + .where( + WorkflowToolProvider.tenant_id == tenant_id, + WorkflowToolProvider.name == name, + WorkflowToolProvider.id != workflow_tool_id, + ) + .limit(1) + ) + + # if the name exists raise error if existing_workflow_tool_provider is not None: raise ValueError(f"Tool with name {name} already exists") - workflow_tool_provider: WorkflowToolProvider | None = ( - db.session.query(WorkflowToolProvider) - .where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id) - .first() - ) + # query the workflow tool provider + workflow_tool_provider: WorkflowToolProvider | None = None + with sessionmaker(db.engine, expire_on_commit=False).begin() as _session: + workflow_tool_provider = _session.scalar( + select(WorkflowToolProvider) + .where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id) + .limit(1) + ) + # if not found raise error if workflow_tool_provider is None: raise ValueError(f"Tool {workflow_tool_id} not found") - app: App | None = ( - db.session.query(App).where(App.id == workflow_tool_provider.app_id, App.tenant_id == tenant_id).first() - ) + # query the app + app: App | None = None + with sessionmaker(db.engine, expire_on_commit=False).begin() as _session: + app = _session.scalar( + select(App).where(App.id == workflow_tool_provider.app_id, App.tenant_id == tenant_id).limit(1) + ) + # if not found raise error if app is None: raise ValueError(f"App {workflow_tool_provider.app_id} not found") + # query the workflow workflow: Workflow | None = app.workflow + + # if not found raise error if workflow is None: raise ValueError(f"Workflow not found for app {workflow_tool_provider.app_id}") + # check if workflow configuration is synced WorkflowToolConfigurationUtils.ensure_no_human_input_nodes(workflow.graph_dict) - workflow_tool_provider.name = name - workflow_tool_provider.label = label - workflow_tool_provider.icon = json.dumps(icon) - workflow_tool_provider.description = description - workflow_tool_provider.parameter_configuration = json.dumps([p.model_dump() for p in parameters]) - workflow_tool_provider.privacy_policy = privacy_policy - workflow_tool_provider.version = workflow.version - workflow_tool_provider.updated_at = datetime.now() + with sessionmaker(db.engine).begin() as _session: + _session.add(workflow_tool_provider) - try: - WorkflowToolProviderController.from_db(workflow_tool_provider) - except Exception as e: - raise ValueError(str(e)) + # update workflow tool provider + workflow_tool_provider.name = name + workflow_tool_provider.label = label + workflow_tool_provider.icon = json.dumps(icon) + workflow_tool_provider.description = description + workflow_tool_provider.parameter_configuration = json.dumps([p.model_dump() for p in parameters]) + workflow_tool_provider.privacy_policy = privacy_policy + workflow_tool_provider.version = workflow.version + workflow_tool_provider.updated_at = datetime.now() - db.session.commit() + try: + WorkflowToolProviderController.from_db(workflow_tool_provider) + except Exception as e: + raise ValueError(str(e)) - if labels is not None: - ToolLabelManager.update_tool_labels( - ToolTransformService.workflow_provider_to_controller(workflow_tool_provider), labels - ) + if labels is not None: + ToolLabelManager.update_tool_labels( + ToolTransformService.workflow_provider_to_controller(workflow_tool_provider), + labels, + session=_session, + ) return {"result": "success"} @@ -184,28 +222,32 @@ class WorkflowToolManageService: def list_tenant_workflow_tools(cls, user_id: str, tenant_id: str) -> list[ToolProviderApiEntity]: """ List workflow tools. + :param user_id: the user id :param tenant_id: the tenant id :return: the list of tools """ - db_tools = db.session.scalars( - select(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id) - ).all() + + providers: list[WorkflowToolProvider] = [] + with sessionmaker(db.engine, expire_on_commit=False).begin() as _session: + providers = list( + _session.scalars(select(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id)).all() + ) # Create a mapping from provider_id to app_id - provider_id_to_app_id = {provider.id: provider.app_id for provider in db_tools} + provider_id_to_app_id = {provider.id: provider.app_id for provider in providers} tools: list[WorkflowToolProviderController] = [] - for provider in db_tools: + for provider in providers: try: tools.append(ToolTransformService.workflow_provider_to_controller(provider)) except Exception: # skip deleted tools logger.exception("Failed to load workflow tool provider %s", provider.id) - labels = ToolLabelManager.get_tools_labels([t for t in tools if isinstance(t, ToolProviderController)]) + labels = ToolLabelManager.get_tools_labels([tool for tool in tools if isinstance(tool, ToolProviderController)]) - result = [] + result: list[ToolProviderApiEntity] = [] for tool in tools: workflow_app_id = provider_id_to_app_id.get(tool.provider_id) @@ -230,15 +272,18 @@ class WorkflowToolManageService: def delete_workflow_tool(cls, user_id: str, tenant_id: str, workflow_tool_id: str): """ Delete a workflow tool. + :param user_id: the user id :param tenant_id: the tenant id :param workflow_tool_id: the workflow tool id """ - db.session.query(WorkflowToolProvider).where( - WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id - ).delete() - db.session.commit() + with sessionmaker(db.engine).begin() as _session: + _ = _session.execute( + delete(WorkflowToolProvider).where( + WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id + ) + ) return {"result": "success"} @@ -246,47 +291,59 @@ class WorkflowToolManageService: def get_workflow_tool_by_tool_id(cls, user_id: str, tenant_id: str, workflow_tool_id: str): """ Get a workflow tool. + :param user_id: the user id :param tenant_id: the tenant id :param workflow_tool_id: the workflow tool id :return: the tool """ - db_tool: WorkflowToolProvider | None = ( - db.session.query(WorkflowToolProvider) - .where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id) - .first() - ) - return cls._get_workflow_tool(tenant_id, db_tool) + + tool_provider: WorkflowToolProvider | None = None + with sessionmaker(db.engine, expire_on_commit=False).begin() as _session: + tool_provider = _session.scalar( + select(WorkflowToolProvider) + .where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id) + .limit(1) + ) + + return cls._get_workflow_tool(tenant_id, tool_provider) @classmethod def get_workflow_tool_by_app_id(cls, user_id: str, tenant_id: str, workflow_app_id: str): """ Get a workflow tool. + :param user_id: the user id :param tenant_id: the tenant id :param workflow_app_id: the workflow app id :return: the tool """ - db_tool: WorkflowToolProvider | None = ( - db.session.query(WorkflowToolProvider) - .where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.app_id == workflow_app_id) - .first() - ) - return cls._get_workflow_tool(tenant_id, db_tool) + + with sessionmaker(db.engine, expire_on_commit=False).begin() as _session: + tool_provider: WorkflowToolProvider | None = _session.scalar( + select(WorkflowToolProvider) + .where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.app_id == workflow_app_id) + .limit(1) + ) + + return cls._get_workflow_tool(tenant_id, tool_provider) @classmethod def _get_workflow_tool(cls, tenant_id: str, db_tool: WorkflowToolProvider | None): """ Get a workflow tool. + :db_tool: the database tool :return: the tool """ if db_tool is None: raise ValueError("Tool not found") - workflow_app: App | None = ( - db.session.query(App).where(App.id == db_tool.app_id, App.tenant_id == db_tool.tenant_id).first() - ) + workflow_app: App | None = None + with sessionmaker(db.engine, expire_on_commit=False).begin() as _session: + workflow_app = _session.scalar( + select(App).where(App.id == db_tool.app_id, App.tenant_id == db_tool.tenant_id).limit(1) + ) if workflow_app is None: raise ValueError(f"App {db_tool.app_id} not found") @@ -309,7 +366,7 @@ class WorkflowToolManageService: "label": db_tool.label, "workflow_tool_id": db_tool.id, "workflow_app_id": db_tool.app_id, - "icon": json.loads(db_tool.icon), + "icon": emoji_icon_adapter.validate_json(db_tool.icon), "description": db_tool.description, "parameters": jsonable_encoder(db_tool.parameter_configurations), "output_schema": output_schema, @@ -326,28 +383,32 @@ class WorkflowToolManageService: def list_single_workflow_tools(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> list[ToolApiEntity]: """ List workflow tool provider tools. + :param user_id: the user id :param tenant_id: the tenant id :param workflow_tool_id: the workflow tool id :return: the list of tools """ - db_tool: WorkflowToolProvider | None = ( - db.session.query(WorkflowToolProvider) - .where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id) - .first() - ) - if db_tool is None: + provider: WorkflowToolProvider | None = None + with sessionmaker(db.engine, expire_on_commit=False).begin() as _session: + provider = _session.scalar( + select(WorkflowToolProvider) + .where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id) + .limit(1) + ) + + if provider is None: raise ValueError(f"Tool {workflow_tool_id} not found") - tool = ToolTransformService.workflow_provider_to_controller(db_tool) + tool = ToolTransformService.workflow_provider_to_controller(provider) workflow_tools: list[WorkflowTool] = tool.get_tools(tenant_id) if len(workflow_tools) == 0: raise ValueError(f"Tool {workflow_tool_id} not found") return [ ToolTransformService.convert_tool_entity_to_api_entity( - tool=tool.get_tools(db_tool.tenant_id)[0], + tool=tool.get_tools(provider.tenant_id)[0], labels=ToolLabelManager.get_tool_labels(tool), tenant_id=tenant_id, ) diff --git a/api/services/trigger/app_trigger_service.py b/api/services/trigger/app_trigger_service.py index 6d5a719f63f..723d29e947f 100644 --- a/api/services/trigger/app_trigger_service.py +++ b/api/services/trigger/app_trigger_service.py @@ -8,7 +8,7 @@ This service centralizes all AppTrigger-related business logic. import logging from sqlalchemy import update -from sqlalchemy.orm import Session +from sqlalchemy.orm import sessionmaker from extensions.ext_database import db from models.enums import AppTriggerStatus @@ -34,13 +34,12 @@ class AppTriggerService: """ try: - with Session(db.engine) as session: + with sessionmaker(bind=db.engine).begin() as session: session.execute( update(AppTrigger) .where(AppTrigger.tenant_id == tenant_id, AppTrigger.status == AppTriggerStatus.ENABLED) .values(status=AppTriggerStatus.RATE_LIMITED) ) - session.commit() logger.info("Marked all enabled triggers as rate limited for tenant %s", tenant_id) except Exception: logger.exception("Failed to mark all enabled triggers as rate limited for tenant %s", tenant_id) diff --git a/api/services/trigger/schedule_service.py b/api/services/trigger/schedule_service.py index 25e80770b83..a827222c1dc 100644 --- a/api/services/trigger/schedule_service.py +++ b/api/services/trigger/schedule_service.py @@ -2,7 +2,6 @@ import json import logging from datetime import datetime -from graphon.entities.graph_config import NodeConfigDict from sqlalchemy import select from sqlalchemy.orm import Session @@ -14,6 +13,7 @@ from core.workflow.nodes.trigger_schedule.entities import ( VisualConfig, ) from core.workflow.nodes.trigger_schedule.exc import ScheduleConfigError, ScheduleNotFoundError +from graphon.entities.graph_config import NodeConfigDict from libs.schedule_utils import calculate_next_run_at, convert_12h_to_24h from models.account import Account, TenantAccountJoin from models.trigger import WorkflowSchedulePlan diff --git a/api/services/trigger/trigger_provider_service.py b/api/services/trigger/trigger_provider_service.py index 008d8bdb8af..6e14d996ea9 100644 --- a/api/services/trigger/trigger_provider_service.py +++ b/api/services/trigger/trigger_provider_service.py @@ -3,10 +3,10 @@ import logging import time as _time import uuid from collections.abc import Mapping -from typing import Any +from typing import Any, TypedDict -from sqlalchemy import desc, func -from sqlalchemy.orm import Session +from sqlalchemy import delete, desc, func, select +from sqlalchemy.orm import Session, sessionmaker from configs import dify_config from constants import HIDDEN_VALUE, UNKNOWN_VALUE @@ -42,6 +42,10 @@ from services.plugin.plugin_service import PluginService logger = logging.getLogger(__name__) +class VerifyCredentialsResult(TypedDict): + verified: bool + + class TriggerProviderService: """Service for managing trigger providers and credentials""" @@ -69,27 +73,28 @@ class TriggerProviderService: workflows_in_use_map: dict[str, int] = {} with Session(db.engine, expire_on_commit=False) as session: # Get all subscriptions - subscriptions_db = ( - session.query(TriggerSubscription) - .filter_by(tenant_id=tenant_id, provider_id=str(provider_id)) + subscriptions_db = session.scalars( + select(TriggerSubscription) + .where( + TriggerSubscription.tenant_id == tenant_id, + TriggerSubscription.provider_id == str(provider_id), + ) .order_by(desc(TriggerSubscription.created_at)) - .all() - ) + ).all() subscriptions = [subscription.to_api_entity() for subscription in subscriptions_db] if not subscriptions: return [] - usage_counts = ( - session.query( + usage_counts = session.execute( + select( WorkflowPluginTrigger.subscription_id, func.count(func.distinct(WorkflowPluginTrigger.app_id)).label("app_count"), ) - .filter( + .where( WorkflowPluginTrigger.tenant_id == tenant_id, WorkflowPluginTrigger.subscription_id.in_([s.id for s in subscriptions]), ) .group_by(WorkflowPluginTrigger.subscription_id) - .all() - ) + ).all() workflows_in_use_map = {str(row.subscription_id): int(row.app_count) for row in usage_counts} provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id) @@ -146,15 +151,19 @@ class TriggerProviderService: """ try: provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id) - with Session(db.engine, expire_on_commit=False) as session: + with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session: # Use distributed lock to prevent race conditions lock_key = f"trigger_provider_create_lock:{tenant_id}_{provider_id}" with redis_client.lock(lock_key, timeout=20): # Check provider count limit provider_count = ( - session.query(TriggerSubscription) - .filter_by(tenant_id=tenant_id, provider_id=str(provider_id)) - .count() + session.scalar( + select(func.count(TriggerSubscription.id)).where( + TriggerSubscription.tenant_id == tenant_id, + TriggerSubscription.provider_id == str(provider_id), + ) + ) + or 0 ) if provider_count >= cls.__MAX_TRIGGER_PROVIDER_COUNT__: @@ -164,10 +173,14 @@ class TriggerProviderService: ) # Check if name already exists - existing = ( - session.query(TriggerSubscription) - .filter_by(tenant_id=tenant_id, provider_id=str(provider_id), name=name) - .first() + existing = session.scalar( + select(TriggerSubscription) + .where( + TriggerSubscription.tenant_id == tenant_id, + TriggerSubscription.provider_id == str(provider_id), + TriggerSubscription.name == name, + ) + .limit(1) ) if existing: raise ValueError(f"Credential name '{name}' already exists for this provider") @@ -205,7 +218,6 @@ class TriggerProviderService: subscription.id = subscription_id or str(uuid.uuid4()) session.add(subscription) - session.commit() return { "result": "success", @@ -241,12 +253,17 @@ class TriggerProviderService: :param expires_at: Optional new expiration timestamp :return: Success response with updated subscription info """ - with Session(db.engine, expire_on_commit=False) as session: + with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session: # Use distributed lock to prevent race conditions on the same subscription lock_key = f"trigger_subscription_update_lock:{tenant_id}_{subscription_id}" with redis_client.lock(lock_key, timeout=20): - subscription: TriggerSubscription | None = ( - session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first() + subscription = session.scalar( + select(TriggerSubscription) + .where( + TriggerSubscription.tenant_id == tenant_id, + TriggerSubscription.id == subscription_id, + ) + .limit(1) ) if not subscription: raise ValueError(f"Trigger subscription {subscription_id} not found") @@ -256,10 +273,14 @@ class TriggerProviderService: # Check for name uniqueness if name is being updated if name is not None and name != subscription.name: - existing = ( - session.query(TriggerSubscription) - .filter_by(tenant_id=tenant_id, provider_id=str(provider_id), name=name) - .first() + existing = session.scalar( + select(TriggerSubscription) + .where( + TriggerSubscription.tenant_id == tenant_id, + TriggerSubscription.provider_id == str(provider_id), + TriggerSubscription.name == name, + ) + .limit(1) ) if existing: raise ValueError(f"Subscription name '{name}' already exists for this provider") @@ -302,8 +323,6 @@ class TriggerProviderService: if expires_at is not None: subscription.expires_at = expires_at - session.commit() - # Clear subscription cache delete_cache_for_subscription( tenant_id=tenant_id, @@ -319,11 +338,18 @@ class TriggerProviderService: with Session(db.engine, expire_on_commit=False) as session: subscription: TriggerSubscription | None = None if subscription_id: - subscription = ( - session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first() + subscription = session.scalar( + select(TriggerSubscription) + .where( + TriggerSubscription.tenant_id == tenant_id, + TriggerSubscription.id == subscription_id, + ) + .limit(1) ) else: - subscription = session.query(TriggerSubscription).filter_by(tenant_id=tenant_id).first() + subscription = session.scalar( + select(TriggerSubscription).where(TriggerSubscription.tenant_id == tenant_id).limit(1) + ) if subscription: provider_controller = TriggerManager.get_trigger_provider( tenant_id, TriggerProviderID(subscription.provider_id) @@ -352,8 +378,13 @@ class TriggerProviderService: :param subscription_id: Subscription instance ID :return: Success response """ - subscription: TriggerSubscription | None = ( - session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first() + subscription = session.scalar( + select(TriggerSubscription) + .where( + TriggerSubscription.tenant_id == tenant_id, + TriggerSubscription.id == subscription_id, + ) + .limit(1) ) if not subscription: raise ValueError(f"Trigger provider subscription {subscription_id} not found") @@ -404,8 +435,15 @@ class TriggerProviderService: :param subscription_id: Subscription instance ID :return: New token info """ - with Session(db.engine) as session: - subscription = session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first() + with sessionmaker(bind=db.engine).begin() as session: + subscription = session.scalar( + select(TriggerSubscription) + .where( + TriggerSubscription.tenant_id == tenant_id, + TriggerSubscription.id == subscription_id, + ) + .limit(1) + ) if not subscription: raise ValueError(f"Trigger provider subscription {subscription_id} not found") @@ -448,7 +486,6 @@ class TriggerProviderService: # Update credentials subscription.credentials = dict(encrypter.encrypt(dict(refreshed_credentials.credentials))) subscription.credential_expires_at = refreshed_credentials.expires_at - session.commit() # Clear cache cache.delete() @@ -478,9 +515,14 @@ class TriggerProviderService: """ now_ts: int = int(now if now is not None else _time.time()) - with Session(db.engine) as session: - subscription: TriggerSubscription | None = ( - session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first() + with sessionmaker(bind=db.engine).begin() as session: + subscription = session.scalar( + select(TriggerSubscription) + .where( + TriggerSubscription.tenant_id == tenant_id, + TriggerSubscription.id == subscription_id, + ) + .limit(1) ) if subscription is None: raise ValueError(f"Trigger provider subscription {subscription_id} not found") @@ -531,7 +573,6 @@ class TriggerProviderService: # Persist refreshed properties and expires_at subscription.properties = dict(properties_encrypter.encrypt(dict(refreshed.properties))) subscription.expires_at = int(refreshed.expires_at) - session.commit() properties_cache.delete() logger.info( @@ -557,15 +598,15 @@ class TriggerProviderService: tenant_id=tenant_id, provider_id=provider_id ) with Session(db.engine, expire_on_commit=False) as session: - tenant_client: TriggerOAuthTenantClient | None = ( - session.query(TriggerOAuthTenantClient) - .filter_by( - tenant_id=tenant_id, - provider=provider_id.provider_name, - plugin_id=provider_id.plugin_id, - enabled=True, + tenant_client = session.scalar( + select(TriggerOAuthTenantClient) + .where( + TriggerOAuthTenantClient.tenant_id == tenant_id, + TriggerOAuthTenantClient.provider == provider_id.provider_name, + TriggerOAuthTenantClient.plugin_id == provider_id.plugin_id, + TriggerOAuthTenantClient.enabled.is_(True), ) - .first() + .limit(1) ) oauth_params: Mapping[str, Any] | None = None @@ -583,10 +624,13 @@ class TriggerProviderService: return None # Check for system-level OAuth client - system_client: TriggerOAuthSystemClient | None = ( - session.query(TriggerOAuthSystemClient) - .filter_by(plugin_id=provider_id.plugin_id, provider=provider_id.provider_name) - .first() + system_client = session.scalar( + select(TriggerOAuthSystemClient) + .where( + TriggerOAuthSystemClient.plugin_id == provider_id.plugin_id, + TriggerOAuthSystemClient.provider == provider_id.provider_name, + ) + .limit(1) ) if system_client: @@ -607,10 +651,13 @@ class TriggerProviderService: if not is_verified: return False with Session(db.engine, expire_on_commit=False) as session: - system_client: TriggerOAuthSystemClient | None = ( - session.query(TriggerOAuthSystemClient) - .filter_by(plugin_id=provider_id.plugin_id, provider=provider_id.provider_name) - .first() + system_client = session.scalar( + select(TriggerOAuthSystemClient) + .where( + TriggerOAuthSystemClient.plugin_id == provider_id.plugin_id, + TriggerOAuthSystemClient.provider == provider_id.provider_name, + ) + .limit(1) ) return system_client is not None @@ -639,16 +686,16 @@ class TriggerProviderService: tenant_id=tenant_id, provider_id=provider_id ) - with Session(db.engine) as session: + with sessionmaker(bind=db.engine).begin() as session: # Find existing custom client params - custom_client = ( - session.query(TriggerOAuthTenantClient) - .filter_by( - tenant_id=tenant_id, - plugin_id=provider_id.plugin_id, - provider=provider_id.provider_name, + custom_client = session.scalar( + select(TriggerOAuthTenantClient) + .where( + TriggerOAuthTenantClient.tenant_id == tenant_id, + TriggerOAuthTenantClient.plugin_id == provider_id.plugin_id, + TriggerOAuthTenantClient.provider == provider_id.provider_name, ) - .first() + .limit(1) ) # Create new record if doesn't exist @@ -683,8 +730,6 @@ class TriggerProviderService: if enabled is not None: custom_client.enabled = enabled - session.commit() - return {"result": "success"} @classmethod @@ -697,14 +742,14 @@ class TriggerProviderService: :return: Masked OAuth client parameters """ with Session(db.engine) as session: - custom_client = ( - session.query(TriggerOAuthTenantClient) - .filter_by( - tenant_id=tenant_id, - plugin_id=provider_id.plugin_id, - provider=provider_id.provider_name, + custom_client = session.scalar( + select(TriggerOAuthTenantClient) + .where( + TriggerOAuthTenantClient.tenant_id == tenant_id, + TriggerOAuthTenantClient.plugin_id == provider_id.plugin_id, + TriggerOAuthTenantClient.provider == provider_id.provider_name, ) - .first() + .limit(1) ) if custom_client is None: @@ -733,13 +778,16 @@ class TriggerProviderService: :param provider_id: Provider identifier :return: Success response """ - with Session(db.engine) as session: - session.query(TriggerOAuthTenantClient).filter_by( - tenant_id=tenant_id, - provider=provider_id.provider_name, - plugin_id=provider_id.plugin_id, - ).delete() - session.commit() + with sessionmaker(bind=db.engine).begin() as session: + session.execute( + delete(TriggerOAuthTenantClient) + .where( + TriggerOAuthTenantClient.tenant_id == tenant_id, + TriggerOAuthTenantClient.provider == provider_id.provider_name, + TriggerOAuthTenantClient.plugin_id == provider_id.plugin_id, + ) + .execution_options(synchronize_session=False) + ) return {"result": "success"} @@ -753,15 +801,15 @@ class TriggerProviderService: :return: True if enabled, False otherwise """ with Session(db.engine, expire_on_commit=False) as session: - custom_client = ( - session.query(TriggerOAuthTenantClient) - .filter_by( - tenant_id=tenant_id, - plugin_id=provider_id.plugin_id, - provider=provider_id.provider_name, - enabled=True, + custom_client = session.scalar( + select(TriggerOAuthTenantClient) + .where( + TriggerOAuthTenantClient.tenant_id == tenant_id, + TriggerOAuthTenantClient.plugin_id == provider_id.plugin_id, + TriggerOAuthTenantClient.provider == provider_id.provider_name, + TriggerOAuthTenantClient.enabled.is_(True), ) - .first() + .limit(1) ) return custom_client is not None @@ -771,7 +819,9 @@ class TriggerProviderService: Get a trigger subscription by the endpoint ID. """ with Session(db.engine, expire_on_commit=False) as session: - subscription = session.query(TriggerSubscription).filter_by(endpoint_id=endpoint_id).first() + subscription = session.scalar( + select(TriggerSubscription).where(TriggerSubscription.endpoint_id == endpoint_id).limit(1) + ) if not subscription: return None provider_controller: PluginTriggerProviderController = TriggerManager.get_trigger_provider( @@ -800,7 +850,7 @@ class TriggerProviderService: provider_id: TriggerProviderID, subscription_id: str, credentials: dict[str, Any], - ) -> dict[str, Any]: + ) -> VerifyCredentialsResult: """ Verify credentials for an existing subscription without updating it. diff --git a/api/services/trigger/trigger_service.py b/api/services/trigger/trigger_service.py index d72c0416092..911331e357c 100644 --- a/api/services/trigger/trigger_service.py +++ b/api/services/trigger/trigger_service.py @@ -5,10 +5,9 @@ from collections.abc import Mapping from typing import Any from flask import Request, Response -from graphon.entities.graph_config import NodeConfigDict from pydantic import BaseModel from sqlalchemy import select -from sqlalchemy.orm import Session +from sqlalchemy.orm import sessionmaker from core.plugin.entities.plugin_daemon import CredentialType from core.plugin.entities.request import TriggerDispatchResponse, TriggerInvokeEventResponse @@ -21,6 +20,7 @@ from core.trigger.utils.encryption import create_trigger_provider_encrypter_for_ from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData from extensions.ext_database import db from extensions.ext_redis import redis_client +from graphon.entities.graph_config import NodeConfigDict from models.model import App from models.provider_ids import TriggerProviderID from models.trigger import TriggerSubscription, WorkflowPluginTrigger @@ -215,7 +215,7 @@ class TriggerService: not_found_in_cache.append(node_info) continue - with Session(db.engine) as session: + with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session: try: # lock the concurrent plugin trigger creation redis_client.lock(f"{cls.__PLUGIN_TRIGGER_NODE_CACHE_KEY__}:apps:{app.id}:lock", timeout=10) @@ -260,7 +260,6 @@ class TriggerService: cache.model_dump_json(), ex=60 * 60, ) - session.commit() # Update existing records if subscription_id changed for node_info in nodes_in_graph: @@ -290,14 +289,12 @@ class TriggerService: cache.model_dump_json(), ex=60 * 60, ) - session.commit() # delete the nodes not found in the graph for node_id in nodes_id_in_db: if node_id not in nodes_id_in_graph: session.delete(nodes_id_in_db[node_id]) redis_client.delete(f"{cls.__PLUGIN_TRIGGER_NODE_CACHE_KEY__}:{app.id}:{node_id}") - session.commit() except Exception: logger.exception("Failed to sync plugin trigger relationships for app %s", app.id) raise diff --git a/api/services/trigger/webhook_service.py b/api/services/trigger/webhook_service.py index c03275497d7..ca4e43e5164 100644 --- a/api/services/trigger/webhook_service.py +++ b/api/services/trigger/webhook_service.py @@ -3,16 +3,13 @@ import logging import mimetypes import secrets from collections.abc import Callable, Mapping, Sequence -from typing import Any +from typing import Any, NotRequired, TypedDict import orjson from flask import request -from graphon.entities.graph_config import NodeConfigDict -from graphon.file import FileTransferMethod -from graphon.variables.types import ArrayValidation, SegmentType from pydantic import BaseModel from sqlalchemy import select -from sqlalchemy.orm import Session +from sqlalchemy.orm import Session, sessionmaker from werkzeug.datastructures import FileStorage from werkzeug.exceptions import RequestEntityTooLarge @@ -31,6 +28,9 @@ from enums.quota_type import QuotaType from extensions.ext_database import db from extensions.ext_redis import redis_client from factories import file_factory +from graphon.entities.graph_config import NodeConfigDict +from graphon.file import FileTransferMethod +from graphon.variables.types import ArrayValidation, SegmentType from models.enums import AppTriggerStatus, AppTriggerType from models.model import App from models.trigger import AppTrigger, WorkflowWebhookTrigger @@ -50,6 +50,26 @@ logger = logging.getLogger(__name__) _file_access_controller = DatabaseFileAccessController() +class RawWebhookDataDict(TypedDict): + method: str + headers: dict[str, str] + query_params: dict[str, str] + body: dict[str, Any] + files: dict[str, Any] + + +class ValidationResultDict(TypedDict): + valid: bool + error: NotRequired[str] + + +class WorkflowInputsDict(TypedDict): + webhook_data: RawWebhookDataDict + webhook_headers: dict[str, str] + webhook_query_params: dict[str, str] + webhook_body: dict[str, Any] + + class WebhookService: """Service for handling webhook operations.""" @@ -84,32 +104,32 @@ class WebhookService: """ with Session(db.engine) as session: # Get webhook trigger - webhook_trigger = ( - session.query(WorkflowWebhookTrigger).where(WorkflowWebhookTrigger.webhook_id == webhook_id).first() + webhook_trigger = session.scalar( + select(WorkflowWebhookTrigger).where(WorkflowWebhookTrigger.webhook_id == webhook_id).limit(1) ) if not webhook_trigger: raise ValueError(f"Webhook not found: {webhook_id}") if is_debug: - workflow = ( - session.query(Workflow) - .filter( + workflow = session.scalar( + select(Workflow) + .where( Workflow.app_id == webhook_trigger.app_id, Workflow.version == Workflow.VERSION_DRAFT, ) .order_by(Workflow.created_at.desc()) - .first() + .limit(1) ) else: # Check if the corresponding AppTrigger exists - app_trigger = ( - session.query(AppTrigger) - .filter( + app_trigger = session.scalar( + select(AppTrigger) + .where( AppTrigger.app_id == webhook_trigger.app_id, AppTrigger.node_id == webhook_trigger.node_id, AppTrigger.trigger_type == AppTriggerType.TRIGGER_WEBHOOK, ) - .first() + .limit(1) ) if not app_trigger: @@ -126,14 +146,14 @@ class WebhookService: raise ValueError(f"Webhook trigger is disabled for webhook {webhook_id}") # Get workflow - workflow = ( - session.query(Workflow) - .filter( + workflow = session.scalar( + select(Workflow) + .where( Workflow.app_id == webhook_trigger.app_id, Workflow.version != Workflow.VERSION_DRAFT, ) .order_by(Workflow.created_at.desc()) - .first() + .limit(1) ) if not workflow: raise ValueError(f"Workflow not found for app {webhook_trigger.app_id}") @@ -145,7 +165,7 @@ class WebhookService: @classmethod def extract_and_validate_webhook_data( cls, webhook_trigger: WorkflowWebhookTrigger, node_config: NodeConfigDict - ) -> dict[str, Any]: + ) -> RawWebhookDataDict: """Extract and validate webhook data in a single unified process. Args: @@ -165,7 +185,7 @@ class WebhookService: node_data = WebhookData.model_validate(node_config["data"], from_attributes=True) validation_result = cls._validate_http_metadata(raw_data, node_data) if not validation_result["valid"]: - raise ValueError(validation_result["error"]) + raise ValueError(validation_result.get("error", "Validation failed")) # Process and validate data according to configuration processed_data = cls._process_and_validate_data(raw_data, node_data) @@ -173,7 +193,7 @@ class WebhookService: return processed_data @classmethod - def extract_webhook_data(cls, webhook_trigger: WorkflowWebhookTrigger) -> dict[str, Any]: + def extract_webhook_data(cls, webhook_trigger: WorkflowWebhookTrigger) -> RawWebhookDataDict: """Extract raw data from incoming webhook request without type conversion. Args: @@ -189,7 +209,7 @@ class WebhookService: """ cls._validate_content_length() - data = { + data: RawWebhookDataDict = { "method": request.method, "headers": dict(request.headers), "query_params": dict(request.args), @@ -223,7 +243,7 @@ class WebhookService: return data @classmethod - def _process_and_validate_data(cls, raw_data: dict[str, Any], node_data: WebhookData) -> dict[str, Any]: + def _process_and_validate_data(cls, raw_data: RawWebhookDataDict, node_data: WebhookData) -> RawWebhookDataDict: """Process and validate webhook data according to node configuration. Args: @@ -577,21 +597,38 @@ class WebhookService: Raises: ValueError: If the value cannot be converted to the specified type """ - if param_type == SegmentType.STRING: - return value - elif param_type == SegmentType.NUMBER: - if not cls._can_convert_to_number(value): - raise ValueError(f"Cannot convert '{value}' to number") - numeric_value = float(value) - return int(numeric_value) if numeric_value.is_integer() else numeric_value - elif param_type == SegmentType.BOOLEAN: - lower_value = value.lower() - bool_map = {"true": True, "false": False, "1": True, "0": False, "yes": True, "no": False} - if lower_value not in bool_map: - raise ValueError(f"Cannot convert '{value}' to boolean") - return bool_map[lower_value] - else: - raise ValueError(f"Unsupported type '{param_type}' for form data parameter '{param_name}'") + match param_type: + case SegmentType.STRING: + return value + case SegmentType.NUMBER: + if not cls._can_convert_to_number(value): + raise ValueError(f"Cannot convert '{value}' to number") + numeric_value = float(value) + return int(numeric_value) if numeric_value.is_integer() else numeric_value + case SegmentType.BOOLEAN: + lower_value = value.lower() + bool_map = {"true": True, "false": False, "1": True, "0": False, "yes": True, "no": False} + if lower_value not in bool_map: + raise ValueError(f"Cannot convert '{value}' to boolean") + return bool_map[lower_value] + case ( + SegmentType.OBJECT + | SegmentType.FILE + | SegmentType.ARRAY_ANY + | SegmentType.ARRAY_STRING + | SegmentType.ARRAY_NUMBER + | SegmentType.ARRAY_OBJECT + | SegmentType.ARRAY_FILE + | SegmentType.ARRAY_BOOLEAN + | SegmentType.SECRET + | SegmentType.INTEGER + | SegmentType.FLOAT + | SegmentType.NONE + | SegmentType.GROUP + ): + raise ValueError(f"Unsupported type '{param_type}' for form data parameter '{param_name}'") + case _: + raise ValueError(f"Unsupported type '{param_type}' for form data parameter '{param_name}'") @classmethod def _validate_json_value(cls, param_name: str, value: Any, param_type: SegmentType | str) -> Any: @@ -664,7 +701,7 @@ class WebhookService: raise ValueError(f"Required header missing: {header_name}") @classmethod - def _validate_http_metadata(cls, webhook_data: dict[str, Any], node_data: WebhookData) -> dict[str, Any]: + def _validate_http_metadata(cls, webhook_data: RawWebhookDataDict, node_data: WebhookData) -> ValidationResultDict: """Validate HTTP method and content-type. Args: @@ -708,7 +745,7 @@ class WebhookService: return content_type.split(";")[0].strip() @classmethod - def _validation_error(cls, error_message: str) -> dict[str, Any]: + def _validation_error(cls, error_message: str) -> ValidationResultDict: """Create a standard validation error response. Args: @@ -729,7 +766,7 @@ class WebhookService: return False @classmethod - def build_workflow_inputs(cls, webhook_data: dict[str, Any]) -> dict[str, Any]: + def build_workflow_inputs(cls, webhook_data: RawWebhookDataDict) -> WorkflowInputsDict: """Construct workflow inputs payload from webhook data. Args: @@ -747,7 +784,7 @@ class WebhookService: @classmethod def trigger_workflow_execution( - cls, webhook_trigger: WorkflowWebhookTrigger, webhook_data: dict[str, Any], workflow: Workflow + cls, webhook_trigger: WorkflowWebhookTrigger, webhook_data: RawWebhookDataDict, workflow: Workflow ) -> None: """Trigger workflow execution via AsyncWorkflowService. @@ -892,7 +929,7 @@ class WebhookService: logger.warning("Failed to acquire lock for webhook sync, app %s", app.id) raise RuntimeError("Failed to acquire lock for webhook trigger synchronization") - with Session(db.engine) as session: + with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session: # fetch the non-cached nodes from DB all_records = session.scalars( select(WorkflowWebhookTrigger).where( @@ -921,14 +958,12 @@ class WebhookService: redis_client.set( f"{cls.__WEBHOOK_NODE_CACHE_KEY__}:{app.id}:{node_id}", cache.model_dump_json(), ex=60 * 60 ) - session.commit() # delete the nodes not found in the graph for node_id in nodes_id_in_db: if node_id not in nodes_id_in_graph: session.delete(nodes_id_in_db[node_id]) redis_client.delete(f"{cls.__WEBHOOK_NODE_CACHE_KEY__}:{app.id}:{node_id}") - session.commit() except Exception: logger.exception("Failed to sync webhook relationships for app %s", app.id) raise diff --git a/api/services/variable_truncator.py b/api/services/variable_truncator.py index 5427b7b3a7a..1529c2b98fa 100644 --- a/api/services/variable_truncator.py +++ b/api/services/variable_truncator.py @@ -5,6 +5,7 @@ from abc import ABC, abstractmethod from collections.abc import Mapping from typing import Any, overload +from configs import dify_config from graphon.file import File from graphon.nodes.variable_assigner.common.helpers import UpdatedVariable from graphon.variables.segments import ( @@ -21,8 +22,6 @@ from graphon.variables.segments import ( ) from graphon.variables.utils import dumps_with_segments -from configs import dify_config - _MAX_DEPTH = 100 @@ -129,6 +128,7 @@ class VariableTruncator(BaseTruncator): used_size += self.calculate_json_size(key) if used_size > budget: truncated_mapping[key] = "..." + is_truncated = True continue value_budget = (budget - used_size) // (length - len(truncated_mapping)) if isinstance(value, Segment): @@ -164,12 +164,12 @@ class VariableTruncator(BaseTruncator): result = self._truncate_segment(segment, self._max_size_bytes) if result.value_size > self._max_size_bytes: - if isinstance(result.value, str): - result = self._truncate_string(result.value, self._max_size_bytes) - return TruncationResult(StringSegment(value=result.value), True) + if isinstance(result.value, StringSegment): + fallback_result = self._truncate_string(result.value.value, self._max_size_bytes) + return TruncationResult(StringSegment(value=fallback_result.value), True) # Apply final fallback - convert to JSON string and truncate - json_str = dumps_with_segments(result.value, ensure_ascii=False) + json_str = dumps_with_segments(result.value) if len(json_str) > self._max_size_bytes: json_str = json_str[: self._max_size_bytes] + "..." return TruncationResult(result=StringSegment(value=json_str), truncated=True) diff --git a/api/services/vector_service.py b/api/services/vector_service.py index e7266cb8e94..58193d75a9a 100644 --- a/api/services/vector_service.py +++ b/api/services/vector_service.py @@ -1,21 +1,21 @@ import logging -from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy import delete, select from core.model_manager import ModelInstance, ModelManager from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.datasource.vdb.vector_factory import Vector +from core.rag.entities import ParentMode from core.rag.index_processor.constant.doc_type import DocType from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.index_processor.index_processor_base import BaseIndexProcessor from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.rag.models.document import AttachmentDocument, Document from extensions.ext_database import db +from graphon.model_runtime.entities.model_entities import ModelType from models import UploadFile from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment, SegmentAttachmentBinding from models.dataset import Document as DatasetDocument -from services.entities.knowledge_entities.knowledge_entities import ParentMode logger = logging.getLogger(__name__) diff --git a/api/services/website_service.py b/api/services/website_service.py index 6a521a9cc0c..ea584088bbc 100644 --- a/api/services/website_service.py +++ b/api/services/website_service.py @@ -3,7 +3,7 @@ from __future__ import annotations import datetime import json from dataclasses import dataclass -from typing import Any +from typing import Any, NotRequired, TypedDict, cast import httpx from flask_login import current_user @@ -91,7 +91,7 @@ class WebsiteCrawlApiRequest: return CrawlRequest(url=self.url, provider=self.provider, options=options) @classmethod - def from_args(cls, args: dict) -> WebsiteCrawlApiRequest: + def from_args(cls, args: dict[str, Any]) -> WebsiteCrawlApiRequest: """Create from Flask-RESTful parsed arguments.""" provider = args.get("provider") url = args.get("url") @@ -115,7 +115,7 @@ class WebsiteCrawlStatusApiRequest: job_id: str @classmethod - def from_args(cls, args: dict, job_id: str) -> WebsiteCrawlStatusApiRequest: + def from_args(cls, args: dict[str, Any], job_id: str) -> WebsiteCrawlStatusApiRequest: """Create from Flask-RESTful parsed arguments.""" provider = args.get("provider") if not provider: @@ -126,6 +126,15 @@ class WebsiteCrawlStatusApiRequest: return cls(provider=provider, job_id=job_id) +class CrawlStatusDict(TypedDict): + status: str + job_id: str + total: int + current: int + data: list[Any] + time_consuming: NotRequired[str | float] + + class WebsiteService: """Service class for website crawling operations using different providers.""" @@ -154,7 +163,7 @@ class WebsiteService: raise ValueError("Invalid provider") @classmethod - def _get_decrypted_api_key(cls, tenant_id: str, config: dict) -> str: + def _get_decrypted_api_key(cls, tenant_id: str, config: dict[str, Any]) -> str: """Decrypt and return the API key from config.""" api_key = config.get("api_key") if not api_key: @@ -162,7 +171,7 @@ class WebsiteService: return encrypter.decrypt_token(tenant_id=tenant_id, token=api_key) @classmethod - def document_create_args_validate(cls, args: dict): + def document_create_args_validate(cls, args: dict[str, Any]): """Validate arguments for document creation.""" try: WebsiteCrawlApiRequest.from_args(args) @@ -186,7 +195,7 @@ class WebsiteService: raise ValueError("Invalid provider") @classmethod - def _crawl_with_firecrawl(cls, request: CrawlRequest, api_key: str, config: dict) -> dict[str, Any]: + def _crawl_with_firecrawl(cls, request: CrawlRequest, api_key: str, config: dict[str, Any]) -> dict[str, Any]: firecrawl_app = FirecrawlApp(api_key=api_key, base_url=config.get("base_url")) params: dict[str, Any] @@ -216,7 +225,7 @@ class WebsiteService: return {"status": "active", "job_id": job_id} @classmethod - def _crawl_with_watercrawl(cls, request: CrawlRequest, api_key: str, config: dict) -> dict[str, Any]: + def _crawl_with_watercrawl(cls, request: CrawlRequest, api_key: str, config: dict[str, Any]) -> dict[str, Any]: # Convert CrawlOptions back to dict format for WaterCrawlProvider options = { "limit": request.options.limit, @@ -261,13 +270,13 @@ class WebsiteService: return {"status": "active", "job_id": response.json().get("data", {}).get("taskId")} @classmethod - def get_crawl_status(cls, job_id: str, provider: str) -> dict[str, Any]: + def get_crawl_status(cls, job_id: str, provider: str) -> CrawlStatusDict: """Get crawl status using string parameters.""" api_request = WebsiteCrawlStatusApiRequest(provider=provider, job_id=job_id) return cls.get_crawl_status_typed(api_request) @classmethod - def get_crawl_status_typed(cls, api_request: WebsiteCrawlStatusApiRequest) -> dict[str, Any]: + def get_crawl_status_typed(cls, api_request: WebsiteCrawlStatusApiRequest) -> CrawlStatusDict: """Get crawl status using typed request.""" api_key, config = cls._get_credentials_and_config(current_user.current_tenant_id, api_request.provider) @@ -281,10 +290,10 @@ class WebsiteService: raise ValueError("Invalid provider") @classmethod - def _get_firecrawl_status(cls, job_id: str, api_key: str, config: dict) -> dict[str, Any]: + def _get_firecrawl_status(cls, job_id: str, api_key: str, config: dict[str, Any]) -> CrawlStatusDict: firecrawl_app = FirecrawlApp(api_key=api_key, base_url=config.get("base_url")) result: CrawlStatusResponse = firecrawl_app.check_crawl_status(job_id) - crawl_status_data: dict[str, Any] = { + crawl_status_data: CrawlStatusDict = { "status": result["status"], "job_id": job_id, "total": result["total"] or 0, @@ -302,18 +311,18 @@ class WebsiteService: return crawl_status_data @classmethod - def _get_watercrawl_status(cls, job_id: str, api_key: str, config: dict[str, Any]) -> dict[str, Any]: - return dict(WaterCrawlProvider(api_key, config.get("base_url")).get_crawl_status(job_id)) + def _get_watercrawl_status(cls, job_id: str, api_key: str, config: dict[str, Any]) -> CrawlStatusDict: + return cast(CrawlStatusDict, dict(WaterCrawlProvider(api_key, config.get("base_url")).get_crawl_status(job_id))) @classmethod - def _get_jinareader_status(cls, job_id: str, api_key: str) -> dict[str, Any]: + def _get_jinareader_status(cls, job_id: str, api_key: str) -> CrawlStatusDict: response = _adaptive_http_client.post( "https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app", headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}, json={"taskId": job_id}, ) data = response.json().get("data", {}) - crawl_status_data = { + crawl_status_data: CrawlStatusDict = { "status": data.get("status", "active"), "job_id": job_id, "total": len(data.get("urls", [])), @@ -355,7 +364,9 @@ class WebsiteService: raise ValueError("Invalid provider") @classmethod - def _get_firecrawl_url_data(cls, job_id: str, url: str, api_key: str, config: dict) -> dict[str, Any] | None: + def _get_firecrawl_url_data( + cls, job_id: str, url: str, api_key: str, config: dict[str, Any] + ) -> dict[str, Any] | None: crawl_data: list[FirecrawlDocumentData] | None = None file_key = "website_files/" + job_id + ".txt" if storage.exists(file_key): @@ -429,7 +440,7 @@ class WebsiteService: raise ValueError("Invalid provider") @classmethod - def _scrape_with_firecrawl(cls, request: ScrapeRequest, api_key: str, config: dict) -> dict[str, Any]: + def _scrape_with_firecrawl(cls, request: ScrapeRequest, api_key: str, config: dict[str, Any]) -> dict[str, Any]: firecrawl_app = FirecrawlApp(api_key=api_key, base_url=config.get("base_url")) params = {"onlyMainContent": request.only_main_content} return dict(firecrawl_app.scrape_url(url=request.url, params=params)) diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index c1ad3f33ad0..5dedb9e3729 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -1,11 +1,6 @@ import json from typing import Any, TypedDict -from graphon.file import FileUploadConfig -from graphon.model_runtime.entities.llm_entities import LLMMode -from graphon.model_runtime.utils.encoders import jsonable_encoder -from graphon.nodes import BuiltinNodeTypes -from graphon.variables.input_entities import VariableEntity from sqlalchemy import select from core.app.app_config.entities import ( @@ -24,6 +19,11 @@ from core.prompt.simple_prompt_transform import SimplePromptTransform from core.prompt.utils.prompt_template_parser import PromptTemplateParser from events.app_event import app_was_created from extensions.ext_database import db +from graphon.file import FileUploadConfig +from graphon.model_runtime.entities.llm_entities import LLMMode +from graphon.model_runtime.utils.encoders import jsonable_encoder +from graphon.nodes import BuiltinNodeTypes +from graphon.variables.input_entities import VariableEntity from models import Account from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint from models.model import App, AppMode, AppModelConfig, IconType @@ -170,34 +170,38 @@ class WorkflowConverter: graph = self._append_node(graph, llm_node) - if new_app_mode == AppMode.WORKFLOW: - # convert to end node by app mode - end_node = self._convert_to_end_node() - graph = self._append_node(graph, end_node) - else: - answer_node = self._convert_to_answer_node() - graph = self._append_node(graph, answer_node) - app_model_config_dict = app_config.app_model_config_dict - # features - if new_app_mode == AppMode.ADVANCED_CHAT: - features = { - "opening_statement": app_model_config_dict.get("opening_statement"), - "suggested_questions": app_model_config_dict.get("suggested_questions"), - "suggested_questions_after_answer": app_model_config_dict.get("suggested_questions_after_answer"), - "speech_to_text": app_model_config_dict.get("speech_to_text"), - "text_to_speech": app_model_config_dict.get("text_to_speech"), - "file_upload": app_model_config_dict.get("file_upload"), - "sensitive_word_avoidance": app_model_config_dict.get("sensitive_word_avoidance"), - "retriever_resource": app_model_config_dict.get("retriever_resource"), - } - else: - features = { - "text_to_speech": app_model_config_dict.get("text_to_speech"), - "file_upload": app_model_config_dict.get("file_upload"), - "sensitive_word_avoidance": app_model_config_dict.get("sensitive_word_avoidance"), - } + match new_app_mode: + case AppMode.WORKFLOW: + end_node = self._convert_to_end_node() + graph = self._append_node(graph, end_node) + features = { + "text_to_speech": app_model_config_dict.get("text_to_speech"), + "file_upload": app_model_config_dict.get("file_upload"), + "sensitive_word_avoidance": app_model_config_dict.get("sensitive_word_avoidance"), + } + case AppMode.ADVANCED_CHAT: + answer_node = self._convert_to_answer_node() + graph = self._append_node(graph, answer_node) + features = { + "opening_statement": app_model_config_dict.get("opening_statement"), + "suggested_questions": app_model_config_dict.get("suggested_questions"), + "suggested_questions_after_answer": app_model_config_dict.get("suggested_questions_after_answer"), + "speech_to_text": app_model_config_dict.get("speech_to_text"), + "text_to_speech": app_model_config_dict.get("text_to_speech"), + "file_upload": app_model_config_dict.get("file_upload"), + "sensitive_word_avoidance": app_model_config_dict.get("sensitive_word_avoidance"), + "retriever_resource": app_model_config_dict.get("retriever_resource"), + } + case _: + answer_node = self._convert_to_answer_node() + graph = self._append_node(graph, answer_node) + features = { + "text_to_speech": app_model_config_dict.get("text_to_speech"), + "file_upload": app_model_config_dict.get("file_upload"), + "sensitive_word_avoidance": app_model_config_dict.get("sensitive_word_avoidance"), + } # create workflow record workflow = Workflow( @@ -220,19 +224,23 @@ class WorkflowConverter: def _convert_to_app_config(self, app_model: App, app_model_config: AppModelConfig) -> EasyUIBasedAppConfig: app_mode_enum = AppMode.value_of(app_model.mode) app_config: EasyUIBasedAppConfig - if app_mode_enum == AppMode.AGENT_CHAT or app_model.is_agent: - app_model.mode = AppMode.AGENT_CHAT - app_config = AgentChatAppConfigManager.get_app_config( - app_model=app_model, app_model_config=app_model_config - ) - elif app_mode_enum == AppMode.CHAT: - app_config = ChatAppConfigManager.get_app_config(app_model=app_model, app_model_config=app_model_config) - elif app_mode_enum == AppMode.COMPLETION: - app_config = CompletionAppConfigManager.get_app_config( - app_model=app_model, app_model_config=app_model_config - ) - else: - raise ValueError("Invalid app mode") + effective_mode = ( + AppMode.AGENT_CHAT if app_model.is_agent and app_mode_enum != AppMode.AGENT_CHAT else app_mode_enum + ) + match effective_mode: + case AppMode.AGENT_CHAT: + app_model.mode = AppMode.AGENT_CHAT + app_config = AgentChatAppConfigManager.get_app_config( + app_model=app_model, app_model_config=app_model_config + ) + case AppMode.CHAT: + app_config = ChatAppConfigManager.get_app_config(app_model=app_model, app_model_config=app_model_config) + case AppMode.COMPLETION: + app_config = CompletionAppConfigManager.get_app_config( + app_model=app_model, app_model_config=app_model_config + ) + case _: + raise ValueError("Invalid app mode") return app_config diff --git a/api/services/workflow_app_service.py b/api/services/workflow_app_service.py index b5ab176ad20..59e02ec9b96 100644 --- a/api/services/workflow_app_service.py +++ b/api/services/workflow_app_service.py @@ -3,10 +3,10 @@ import uuid from datetime import datetime from typing import Any, TypedDict -from graphon.enums import WorkflowExecutionStatus from sqlalchemy import and_, func, or_, select from sqlalchemy.orm import Session +from graphon.enums import WorkflowExecutionStatus from models import Account, App, EndUser, TenantAccountJoin, WorkflowAppLog, WorkflowArchiveLog, WorkflowRun from models.enums import AppTriggerType, CreatorUserRole from models.trigger import WorkflowTriggerLog diff --git a/api/services/workflow_collaboration_service.py b/api/services/workflow_collaboration_service.py new file mode 100644 index 00000000000..cf2f5090526 --- /dev/null +++ b/api/services/workflow_collaboration_service.py @@ -0,0 +1,295 @@ +from __future__ import annotations + +import logging +import time +from collections.abc import Mapping + +from sqlalchemy import select + +from core.db.session_factory import session_factory +from models.account import Account +from models.model import App +from repositories.workflow_collaboration_repository import WorkflowCollaborationRepository, WorkflowSessionInfo + +logger = logging.getLogger(__name__) + + +class WorkflowCollaborationService: + def __init__(self, repository: WorkflowCollaborationRepository, socketio) -> None: + self._repository = repository + self._socketio = socketio + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(repository={self._repository})" + + def save_socket_identity(self, sid: str, user: Account) -> None: + """Persist the authenticated console user on the raw socket session.""" + self._socketio.save_session( + sid, + { + "user_id": user.id, + "username": user.name, + "avatar": user.avatar, + "tenant_id": user.current_tenant_id, + }, + ) + + def authorize_and_join_workflow_room(self, workflow_id: str, sid: str) -> tuple[str, bool] | None: + """ + Join a collaboration room only after validating the socket session and tenant-scoped app access. + + The Socket.IO payload still calls the room key `workflow_id`, but the identifier is the workflow app's + `App.id`. Returning `None` lets the controller reject the join before any Redis or room state is created. + """ + session = self._socketio.get_session(sid) + user_id = session.get("user_id") + tenant_id = session.get("tenant_id") + if not user_id or not tenant_id: + return None + + if not self._can_access_workflow(workflow_id, str(tenant_id)): + logger.warning( + "Workflow collaboration join rejected: workflow_id=%s tenant_id=%s user_id=%s sid=%s", + workflow_id, + tenant_id, + user_id, + sid, + ) + return None + + session_info: WorkflowSessionInfo = { + "user_id": str(user_id), + "username": str(session.get("username", "Unknown")), + "avatar": session.get("avatar"), + "sid": sid, + "connected_at": int(time.time()), + } + + self._repository.set_session_info(workflow_id, session_info) + + leader_sid = self.get_or_set_leader(workflow_id, sid) + is_leader = leader_sid == sid + + self._socketio.enter_room(sid, workflow_id) + self.broadcast_online_users(workflow_id) + + self._socketio.emit("status", {"isLeader": is_leader}, room=sid) + + return str(user_id), is_leader + + def _can_access_workflow(self, workflow_id: str, tenant_id: str) -> bool: + """Check room access without relying on Flask's app-context-bound scoped session.""" + with session_factory.create_session() as session: + app_id = session.scalar(select(App.id).where(App.id == workflow_id, App.tenant_id == tenant_id).limit(1)) + return app_id is not None + + def disconnect_session(self, sid: str) -> None: + mapping = self._repository.get_sid_mapping(sid) + if not mapping: + return + + workflow_id = mapping["workflow_id"] + self._repository.delete_session(workflow_id, sid) + + self.handle_leader_disconnect(workflow_id, sid) + self.broadcast_online_users(workflow_id) + + def relay_collaboration_event(self, sid: str, data: Mapping[str, object]) -> tuple[dict[str, str], int]: + mapping = self._repository.get_sid_mapping(sid) + if not mapping: + return {"msg": "unauthorized"}, 401 + + workflow_id = mapping["workflow_id"] + user_id = mapping["user_id"] + self.refresh_session_state(workflow_id, sid) + + event_type = data.get("type") + event_data = data.get("data") + timestamp = data.get("timestamp", int(time.time())) + + if not event_type: + return {"msg": "invalid event type"}, 400 + + if event_type == "sync_request": + leader_sid = self._repository.get_current_leader(workflow_id) + target_sid: str | None + if leader_sid and self.is_session_active(workflow_id, leader_sid): + target_sid = leader_sid + else: + if leader_sid: + self._repository.delete_leader(workflow_id) + target_sid = self._select_graph_leader(workflow_id, preferred_sid=sid) + if target_sid: + self._repository.set_leader(workflow_id, target_sid) + self.broadcast_leader_change(workflow_id, target_sid) + + if not target_sid: + return {"msg": "no_active_leader"}, 200 + + self._socketio.emit( + "collaboration_update", + {"type": event_type, "userId": user_id, "data": event_data, "timestamp": timestamp}, + room=target_sid, + ) + + return {"msg": "sync_request_forwarded"}, 200 + + self._socketio.emit( + "collaboration_update", + {"type": event_type, "userId": user_id, "data": event_data, "timestamp": timestamp}, + room=workflow_id, + skip_sid=sid, + ) + + return {"msg": "event_broadcasted"}, 200 + + def relay_graph_event(self, sid: str, data: object) -> tuple[dict[str, str], int]: + mapping = self._repository.get_sid_mapping(sid) + if not mapping: + return {"msg": "unauthorized"}, 401 + + workflow_id = mapping["workflow_id"] + self.refresh_session_state(workflow_id, sid) + + self._socketio.emit("graph_update", data, room=workflow_id, skip_sid=sid) + + return {"msg": "graph_update_broadcasted"}, 200 + + def get_or_set_leader(self, workflow_id: str, sid: str) -> str: + current_leader = self._repository.get_current_leader(workflow_id) + + if current_leader: + if self.is_session_active(workflow_id, current_leader): + return current_leader + self._repository.delete_session(workflow_id, current_leader) + self._repository.delete_leader(workflow_id) + + was_set = self._repository.set_leader_if_absent(workflow_id, sid) + + if was_set: + if current_leader: + self.broadcast_leader_change(workflow_id, sid) + return sid + + current_leader = self._repository.get_current_leader(workflow_id) + if current_leader: + return current_leader + + return sid + + def handle_leader_disconnect(self, workflow_id: str, disconnected_sid: str) -> None: + current_leader = self._repository.get_current_leader(workflow_id) + if not current_leader: + return + + if current_leader != disconnected_sid: + return + + new_leader_sid = self._select_graph_leader(workflow_id) + if new_leader_sid: + self._repository.set_leader(workflow_id, new_leader_sid) + self.broadcast_leader_change(workflow_id, new_leader_sid) + else: + self._repository.delete_leader(workflow_id) + + def broadcast_leader_change(self, workflow_id: str, new_leader_sid: str | None) -> None: + for sid in self._repository.get_session_sids(workflow_id): + try: + is_leader = new_leader_sid is not None and sid == new_leader_sid + self._socketio.emit("status", {"isLeader": is_leader}, room=sid) + except Exception: + logging.exception("Failed to emit leader status to session %s", sid) + + def get_current_leader(self, workflow_id: str) -> str | None: + return self._repository.get_current_leader(workflow_id) + + def _prune_inactive_sessions(self, workflow_id: str) -> list[WorkflowSessionInfo]: + """Remove inactive sessions from storage and return active sessions only.""" + sessions = self._repository.list_sessions(workflow_id) + if not sessions: + return [] + + active_sessions: list[WorkflowSessionInfo] = [] + stale_sids: list[str] = [] + for session in sessions: + sid = session["sid"] + if self.is_session_active(workflow_id, sid): + active_sessions.append(session) + else: + stale_sids.append(sid) + + for sid in stale_sids: + self._repository.delete_session(workflow_id, sid) + + return active_sessions + + def broadcast_online_users(self, workflow_id: str) -> None: + users = self._prune_inactive_sessions(workflow_id) + users.sort(key=lambda x: x.get("connected_at") or 0) + + leader_sid = self.get_current_leader(workflow_id) + previous_leader = leader_sid + active_sids = {user["sid"] for user in users} + if leader_sid and leader_sid not in active_sids: + self._repository.delete_leader(workflow_id) + leader_sid = None + + if not leader_sid and users: + leader_sid = self._select_graph_leader(workflow_id) + if leader_sid: + self._repository.set_leader(workflow_id, leader_sid) + + if leader_sid != previous_leader: + self.broadcast_leader_change(workflow_id, leader_sid) + + self._socketio.emit( + "online_users", + {"workflow_id": workflow_id, "users": users, "leader": leader_sid}, + room=workflow_id, + ) + + def refresh_session_state(self, workflow_id: str, sid: str) -> None: + self._repository.refresh_session_state(workflow_id, sid) + self._ensure_leader(workflow_id, sid) + + def _ensure_leader(self, workflow_id: str, sid: str) -> None: + current_leader = self._repository.get_current_leader(workflow_id) + if current_leader and self.is_session_active(workflow_id, current_leader): + self._repository.expire_leader(workflow_id) + return + + if current_leader: + self._repository.delete_leader(workflow_id) + + self._repository.set_leader(workflow_id, sid) + self.broadcast_leader_change(workflow_id, sid) + + def _select_graph_leader(self, workflow_id: str, preferred_sid: str | None = None) -> str | None: + session_sids = [ + session["sid"] + for session in self._repository.list_sessions(workflow_id) + if session.get("graph_active", True) and self.is_session_active(workflow_id, session["sid"]) + ] + if not session_sids: + return None + if preferred_sid and preferred_sid in session_sids: + return preferred_sid + return session_sids[0] + + def is_session_active(self, workflow_id: str, sid: str) -> bool: + if not sid: + return False + + try: + if not self._socketio.manager.is_connected(sid, "/"): + return False + except AttributeError: + return False + + if not self._repository.session_exists(workflow_id, sid): + return False + + if not self._repository.sid_mapping_exists(sid): + return False + + return True diff --git a/api/services/workflow_comment_service.py b/api/services/workflow_comment_service.py new file mode 100644 index 00000000000..ff47e4f253e --- /dev/null +++ b/api/services/workflow_comment_service.py @@ -0,0 +1,564 @@ +import logging +from collections.abc import Sequence + +from sqlalchemy import desc, select +from sqlalchemy.orm import Session, selectinload +from werkzeug.exceptions import Forbidden, NotFound + +from configs import dify_config +from extensions.ext_database import db +from libs.datetime_utils import naive_utc_now +from libs.helper import uuid_value +from models import App, TenantAccountJoin, WorkflowComment, WorkflowCommentMention, WorkflowCommentReply +from models.account import Account +from tasks.mail_workflow_comment_task import send_workflow_comment_mention_email_task + +logger = logging.getLogger(__name__) + + +class WorkflowCommentService: + """Service for managing workflow comments.""" + + @staticmethod + def _validate_content(content: str) -> None: + if len(content.strip()) == 0: + raise ValueError("Comment content cannot be empty") + + if len(content) > 1000: + raise ValueError("Comment content cannot exceed 1000 characters") + + @staticmethod + def _filter_valid_mentioned_user_ids( + mentioned_user_ids: Sequence[str], *, session: Session, tenant_id: str + ) -> list[str]: + """Return deduplicated UUID user IDs that belong to the tenant, preserving input order.""" + unique_user_ids: list[str] = [] + seen: set[str] = set() + for user_id in mentioned_user_ids: + if not isinstance(user_id, str): + continue + if not uuid_value(user_id): + continue + if user_id in seen: + continue + seen.add(user_id) + unique_user_ids.append(user_id) + if not unique_user_ids: + return [] + + tenant_member_ids = { + str(account_id) + for account_id in session.scalars( + select(TenantAccountJoin.account_id).where( + TenantAccountJoin.tenant_id == tenant_id, + TenantAccountJoin.account_id.in_(unique_user_ids), + ) + ).all() + } + + return [user_id for user_id in unique_user_ids if user_id in tenant_member_ids] + + @staticmethod + def _format_comment_excerpt(content: str, max_length: int = 200) -> str: + """Trim comment content for email display.""" + trimmed = content.strip() + if len(trimmed) <= max_length: + return trimmed + if max_length <= 3: + return trimmed[:max_length] + return f"{trimmed[: max_length - 3].rstrip()}..." + + @staticmethod + def _build_mention_email_payloads( + session: Session, + tenant_id: str, + app_id: str, + mentioner_id: str, + mentioned_user_ids: Sequence[str], + content: str, + ) -> list[dict[str, str]]: + """Prepare email payloads for mentioned users, including workflow app link.""" + if not mentioned_user_ids: + return [] + + candidate_user_ids = [user_id for user_id in mentioned_user_ids if user_id != mentioner_id] + if not candidate_user_ids: + return [] + + app_name_value = session.scalar(select(App.name).where(App.id == app_id, App.tenant_id == tenant_id)) + app_name = app_name_value if isinstance(app_name_value, str) and app_name_value else "Dify app" + commenter_name_value = session.scalar(select(Account.name).where(Account.id == mentioner_id)) + commenter_name = ( + commenter_name_value if isinstance(commenter_name_value, str) and commenter_name_value else "Dify user" + ) + comment_excerpt = WorkflowCommentService._format_comment_excerpt(content) + base_url = dify_config.CONSOLE_WEB_URL.rstrip("/") + app_url = f"{base_url}/app/{app_id}/workflow" + + accounts = session.scalars( + select(Account) + .join(TenantAccountJoin, TenantAccountJoin.account_id == Account.id) + .where(TenantAccountJoin.tenant_id == tenant_id, Account.id.in_(candidate_user_ids)) + ).all() + + payloads: list[dict[str, str]] = [] + for account in accounts: + email = account.email + if not isinstance(email, str) or not email: + continue + mentioned_name = account.name if isinstance(account.name, str) and account.name else email + language = ( + account.interface_language + if isinstance(account.interface_language, str) and account.interface_language + else "en-US" + ) + payloads.append( + { + "language": language, + "to": email, + "mentioned_name": mentioned_name, + "commenter_name": commenter_name, + "app_name": app_name, + "comment_content": comment_excerpt, + "app_url": app_url, + } + ) + return payloads + + @staticmethod + def _dispatch_mention_emails(payloads: Sequence[dict[str, str]]) -> None: + """Enqueue mention notification emails.""" + for payload in payloads: + send_workflow_comment_mention_email_task.delay(**payload) + + @staticmethod + def get_comments(tenant_id: str, app_id: str) -> Sequence[WorkflowComment]: + """Get all comments for a workflow.""" + with Session(db.engine) as session: + # Get all comments with eager loading + stmt = ( + select(WorkflowComment) + .options(selectinload(WorkflowComment.replies), selectinload(WorkflowComment.mentions)) + .where(WorkflowComment.tenant_id == tenant_id, WorkflowComment.app_id == app_id) + .order_by(desc(WorkflowComment.created_at)) + ) + + comments = session.scalars(stmt).all() + + # Batch preload all Account objects to avoid N+1 queries + WorkflowCommentService._preload_accounts(session, comments) + + return comments + + @staticmethod + def _preload_accounts(session: Session, comments: Sequence[WorkflowComment]) -> None: + """Batch preload Account objects for comments, replies, and mentions.""" + # Collect all user IDs + user_ids: set[str] = set() + for comment in comments: + user_ids.add(comment.created_by) + if comment.resolved_by: + user_ids.add(comment.resolved_by) + user_ids.update(reply.created_by for reply in comment.replies) + user_ids.update(mention.mentioned_user_id for mention in comment.mentions) + + if not user_ids: + return + + # Batch query all accounts + accounts = session.scalars(select(Account).where(Account.id.in_(user_ids))).all() + account_map = {str(account.id): account for account in accounts} + + # Cache accounts on objects + for comment in comments: + comment.cache_created_by_account(account_map.get(comment.created_by)) + comment.cache_resolved_by_account(account_map.get(comment.resolved_by) if comment.resolved_by else None) + for reply in comment.replies: + reply.cache_created_by_account(account_map.get(reply.created_by)) + for mention in comment.mentions: + mention.cache_mentioned_user_account(account_map.get(mention.mentioned_user_id)) + + @staticmethod + def get_comment(tenant_id: str, app_id: str, comment_id: str, session: Session | None = None) -> WorkflowComment: + """Get a specific comment.""" + + def _get_comment(session: Session) -> WorkflowComment: + stmt = ( + select(WorkflowComment) + .options(selectinload(WorkflowComment.replies), selectinload(WorkflowComment.mentions)) + .where( + WorkflowComment.id == comment_id, + WorkflowComment.tenant_id == tenant_id, + WorkflowComment.app_id == app_id, + ) + ) + comment = session.scalar(stmt) + + if not comment: + raise NotFound("Comment not found") + + # Preload accounts to avoid N+1 queries + WorkflowCommentService._preload_accounts(session, [comment]) + + return comment + + if session is not None: + return _get_comment(session) + else: + with Session(db.engine, expire_on_commit=False) as session: + return _get_comment(session) + + @staticmethod + def create_comment( + tenant_id: str, + app_id: str, + created_by: str, + content: str, + position_x: float, + position_y: float, + mentioned_user_ids: list[str] | None = None, + ) -> dict: + """Create a new workflow comment and send mention notification emails.""" + WorkflowCommentService._validate_content(content) + + with Session(db.engine) as session: + comment = WorkflowComment( + tenant_id=tenant_id, + app_id=app_id, + position_x=position_x, + position_y=position_y, + content=content, + created_by=created_by, + ) + + session.add(comment) + session.flush() # Get the comment ID for mentions + + # Create mentions if specified + mentioned_user_ids = WorkflowCommentService._filter_valid_mentioned_user_ids( + mentioned_user_ids or [], + session=session, + tenant_id=tenant_id, + ) + for user_id in mentioned_user_ids: + mention = WorkflowCommentMention( + comment_id=comment.id, + reply_id=None, # This is a comment mention, not reply mention + mentioned_user_id=user_id, + ) + session.add(mention) + + mention_email_payloads = WorkflowCommentService._build_mention_email_payloads( + session=session, + tenant_id=tenant_id, + app_id=app_id, + mentioner_id=created_by, + mentioned_user_ids=mentioned_user_ids, + content=content, + ) + + session.commit() + WorkflowCommentService._dispatch_mention_emails(mention_email_payloads) + + # Return only what we need - id and created_at + return {"id": comment.id, "created_at": comment.created_at} + + @staticmethod + def update_comment( + tenant_id: str, + app_id: str, + comment_id: str, + user_id: str, + content: str, + position_x: float | None = None, + position_y: float | None = None, + mentioned_user_ids: list[str] | None = None, + ) -> dict: + """Update a workflow comment and notify newly mentioned users. + + `mentioned_user_ids=None` means "leave mentions unchanged". + Passing an explicit list replaces the existing comment mentions, including clearing them with `[]`. + """ + WorkflowCommentService._validate_content(content) + + with Session(db.engine, expire_on_commit=False) as session: + # Get comment with validation + stmt = select(WorkflowComment).where( + WorkflowComment.id == comment_id, + WorkflowComment.tenant_id == tenant_id, + WorkflowComment.app_id == app_id, + ) + comment = session.scalar(stmt) + + if not comment: + raise NotFound("Comment not found") + + # Only the creator can update the comment + if comment.created_by != user_id: + raise Forbidden("Only the comment creator can update it") + + # Update comment fields + comment.content = content + if position_x is not None: + comment.position_x = position_x + if position_y is not None: + comment.position_y = position_y + + mention_email_payloads: list[dict[str, str]] = [] + if mentioned_user_ids is not None: + # Replace comment mentions only when the client explicitly sends the mention list. + existing_mentions = session.scalars( + select(WorkflowCommentMention).where( + WorkflowCommentMention.comment_id == comment.id, + WorkflowCommentMention.reply_id.is_(None), # Only comment mentions, not reply mentions + ) + ).all() + existing_mentioned_user_ids = {mention.mentioned_user_id for mention in existing_mentions} + for mention in existing_mentions: + session.delete(mention) + + filtered_mentioned_user_ids = WorkflowCommentService._filter_valid_mentioned_user_ids( + mentioned_user_ids, + session=session, + tenant_id=tenant_id, + ) + new_mentioned_user_ids = [ + mentioned_user_id + for mentioned_user_id in filtered_mentioned_user_ids + if mentioned_user_id not in existing_mentioned_user_ids + ] + for mentioned_user_id in filtered_mentioned_user_ids: + mention = WorkflowCommentMention( + comment_id=comment.id, + reply_id=None, # This is a comment mention + mentioned_user_id=mentioned_user_id, + ) + session.add(mention) + + mention_email_payloads = WorkflowCommentService._build_mention_email_payloads( + session=session, + tenant_id=tenant_id, + app_id=app_id, + mentioner_id=user_id, + mentioned_user_ids=new_mentioned_user_ids, + content=content, + ) + + session.commit() + WorkflowCommentService._dispatch_mention_emails(mention_email_payloads) + + return {"id": comment.id, "updated_at": comment.updated_at} + + @staticmethod + def delete_comment(tenant_id: str, app_id: str, comment_id: str, user_id: str) -> None: + """Delete a workflow comment.""" + with Session(db.engine, expire_on_commit=False) as session: + comment = WorkflowCommentService.get_comment(tenant_id, app_id, comment_id, session) + + # Only the creator can delete the comment + if comment.created_by != user_id: + raise Forbidden("Only the comment creator can delete it") + + # Delete associated mentions (both comment and reply mentions) + mentions = session.scalars( + select(WorkflowCommentMention).where(WorkflowCommentMention.comment_id == comment_id) + ).all() + for mention in mentions: + session.delete(mention) + + # Delete associated replies + replies = session.scalars( + select(WorkflowCommentReply).where(WorkflowCommentReply.comment_id == comment_id) + ).all() + for reply in replies: + session.delete(reply) + + session.delete(comment) + session.commit() + + @staticmethod + def resolve_comment(tenant_id: str, app_id: str, comment_id: str, user_id: str) -> WorkflowComment: + """Resolve a workflow comment.""" + with Session(db.engine, expire_on_commit=False) as session: + comment = WorkflowCommentService.get_comment(tenant_id, app_id, comment_id, session) + if comment.resolved: + return comment + + comment.resolved = True + comment.resolved_at = naive_utc_now() + comment.resolved_by = user_id + session.commit() + + return comment + + @staticmethod + def create_reply( + comment_id: str, content: str, created_by: str, mentioned_user_ids: list[str] | None = None + ) -> dict: + """Add a reply to a workflow comment and notify mentioned users.""" + WorkflowCommentService._validate_content(content) + + with Session(db.engine, expire_on_commit=False) as session: + # Check if comment exists + comment = session.get(WorkflowComment, comment_id) + if not comment: + raise NotFound("Comment not found") + + reply = WorkflowCommentReply(comment_id=comment_id, content=content, created_by=created_by) + + session.add(reply) + session.flush() # Get the reply ID for mentions + + # Create mentions if specified + mentioned_user_ids = WorkflowCommentService._filter_valid_mentioned_user_ids( + mentioned_user_ids or [], + session=session, + tenant_id=comment.tenant_id, + ) + for user_id in mentioned_user_ids: + # Create mention linking to specific reply + mention = WorkflowCommentMention(comment_id=comment_id, reply_id=reply.id, mentioned_user_id=user_id) + session.add(mention) + + mention_email_payloads = WorkflowCommentService._build_mention_email_payloads( + session=session, + tenant_id=comment.tenant_id, + app_id=comment.app_id, + mentioner_id=created_by, + mentioned_user_ids=mentioned_user_ids, + content=content, + ) + + session.commit() + WorkflowCommentService._dispatch_mention_emails(mention_email_payloads) + + return {"id": reply.id, "created_at": reply.created_at} + + @staticmethod + def _get_reply_in_comment_scope( + *, + session: Session, + tenant_id: str, + app_id: str, + comment_id: str, + reply_id: str, + ) -> WorkflowCommentReply: + """Get a reply scoped to tenant/app/comment to prevent cross-thread mutations.""" + stmt = ( + select(WorkflowCommentReply) + .join(WorkflowComment, WorkflowComment.id == WorkflowCommentReply.comment_id) + .where( + WorkflowCommentReply.id == reply_id, + WorkflowCommentReply.comment_id == comment_id, + WorkflowComment.tenant_id == tenant_id, + WorkflowComment.app_id == app_id, + ) + .limit(1) + ) + reply = session.scalar(stmt) + if not reply: + raise NotFound("Reply not found") + return reply + + @staticmethod + def update_reply( + tenant_id: str, + app_id: str, + comment_id: str, + reply_id: str, + user_id: str, + content: str, + mentioned_user_ids: list[str] | None = None, + ) -> dict: + """Update a comment reply and notify newly mentioned users.""" + WorkflowCommentService._validate_content(content) + + with Session(db.engine, expire_on_commit=False) as session: + reply = WorkflowCommentService._get_reply_in_comment_scope( + session=session, + tenant_id=tenant_id, + app_id=app_id, + comment_id=comment_id, + reply_id=reply_id, + ) + + # Only the creator can update the reply + if reply.created_by != user_id: + raise Forbidden("Only the reply creator can update it") + + reply.content = content + + # Update mentions - first remove existing mentions for this reply + existing_mentions = session.scalars( + select(WorkflowCommentMention).where(WorkflowCommentMention.reply_id == reply.id) + ).all() + existing_mentioned_user_ids = {mention.mentioned_user_id for mention in existing_mentions} + for mention in existing_mentions: + session.delete(mention) + + # Add mentions + raw_mentioned_user_ids = mentioned_user_ids or [] + comment = session.get(WorkflowComment, reply.comment_id) + mentioned_user_ids = [] + if comment: + mentioned_user_ids = WorkflowCommentService._filter_valid_mentioned_user_ids( + raw_mentioned_user_ids, + session=session, + tenant_id=comment.tenant_id, + ) + new_mentioned_user_ids = [ + user_id for user_id in mentioned_user_ids if user_id not in existing_mentioned_user_ids + ] + for user_id_str in mentioned_user_ids: + mention = WorkflowCommentMention( + comment_id=reply.comment_id, reply_id=reply.id, mentioned_user_id=user_id_str + ) + session.add(mention) + + mention_email_payloads: list[dict[str, str]] = [] + if comment: + mention_email_payloads = WorkflowCommentService._build_mention_email_payloads( + session=session, + tenant_id=comment.tenant_id, + app_id=comment.app_id, + mentioner_id=user_id, + mentioned_user_ids=new_mentioned_user_ids, + content=content, + ) + + session.commit() + session.refresh(reply) # Refresh to get updated timestamp + WorkflowCommentService._dispatch_mention_emails(mention_email_payloads) + + return {"id": reply.id, "updated_at": reply.updated_at} + + @staticmethod + def delete_reply(tenant_id: str, app_id: str, comment_id: str, reply_id: str, user_id: str) -> None: + """Delete a comment reply.""" + with Session(db.engine, expire_on_commit=False) as session: + reply = WorkflowCommentService._get_reply_in_comment_scope( + session=session, + tenant_id=tenant_id, + app_id=app_id, + comment_id=comment_id, + reply_id=reply_id, + ) + + # Only the creator can delete the reply + if reply.created_by != user_id: + raise Forbidden("Only the reply creator can delete it") + + # Delete associated mentions first + mentions = session.scalars( + select(WorkflowCommentMention).where(WorkflowCommentMention.reply_id == reply_id) + ).all() + for mention in mentions: + session.delete(mention) + + session.delete(reply) + session.commit() + + @staticmethod + def validate_comment_access(comment_id: str, tenant_id: str, app_id: str) -> WorkflowComment: + """Validate that a comment belongs to the specified tenant and app.""" + return WorkflowCommentService.get_comment(tenant_id, app_id, comment_id) diff --git a/api/services/workflow_draft_variable_service.py b/api/services/workflow_draft_variable_service.py index 9ed60bf86b8..96f936ff9ba 100644 --- a/api/services/workflow_draft_variable_service.py +++ b/api/services/workflow_draft_variable_service.py @@ -3,23 +3,11 @@ import json import logging from collections.abc import Mapping, Sequence from concurrent.futures import ThreadPoolExecutor +from datetime import datetime from enum import StrEnum -from typing import Any, ClassVar +from typing import Any, ClassVar, NotRequired, TypedDict -from graphon.enums import NodeType -from graphon.file import File -from graphon.nodes import BuiltinNodeTypes -from graphon.nodes.variable_assigner.common.helpers import get_updated_variables -from graphon.variable_loader import VariableLoader -from graphon.variables import Segment, StringSegment, VariableBase -from graphon.variables.consts import SELECTORS_LENGTH -from graphon.variables.segments import ( - ArrayFileSegment, - FileSegment, -) -from graphon.variables.types import SegmentType -from graphon.variables.utils import dumps_with_segments -from sqlalchemy import Engine, orm, select +from sqlalchemy import Engine, delete, orm, select from sqlalchemy.dialects.mysql import insert as mysql_insert from sqlalchemy.dialects.postgresql import insert as pg_insert from sqlalchemy.orm import Session, sessionmaker @@ -39,6 +27,19 @@ from core.workflow.variable_prefixes import ( from extensions.ext_storage import storage from factories.file_factory import StorageKeyLoader from factories.variable_factory import build_segment, segment_to_variable +from graphon.enums import NodeType +from graphon.file import File +from graphon.nodes import BuiltinNodeTypes +from graphon.nodes.variable_assigner.common.helpers import get_updated_variables +from graphon.variable_loader import VariableLoader +from graphon.variables import Segment, StringSegment, VariableBase +from graphon.variables.consts import SELECTORS_LENGTH +from graphon.variables.segments import ( + ArrayFileSegment, + FileSegment, +) +from graphon.variables.types import SegmentType +from graphon.variables.utils import dumps_with_segments from libs.datetime_utils import naive_utc_now from libs.uuid_utils import uuidv7 from models import Account, App, Conversation @@ -145,7 +146,7 @@ class DraftVarLoader(VariableLoader): variable = segment_to_variable( segment=segment, selector=draft_var.get_selector(), - id=draft_var.id, + variable_id=draft_var.id, name=draft_var.name, description=draft_var.description, ) @@ -179,7 +180,7 @@ class DraftVarLoader(VariableLoader): variable = segment_to_variable( segment=segment, selector=draft_var.get_selector(), - id=draft_var.id, + variable_id=draft_var.id, name=draft_var.name, description=draft_var.description, ) @@ -190,7 +191,7 @@ class DraftVarLoader(VariableLoader): variable = segment_to_variable( segment=segment, selector=draft_var.get_selector(), - id=draft_var.id, + variable_id=draft_var.id, name=draft_var.name, description=draft_var.description, ) @@ -222,11 +223,10 @@ class WorkflowDraftVariableService: ) def get_variable(self, variable_id: str) -> WorkflowDraftVariable | None: - return ( - self._session.query(WorkflowDraftVariable) + return self._session.scalar( + select(WorkflowDraftVariable) .options(orm.selectinload(WorkflowDraftVariable.variable_file)) .where(WorkflowDraftVariable.id == variable_id) - .first() ) def get_draft_variables_by_selectors( @@ -254,20 +254,21 @@ class WorkflowDraftVariableService: # Alternatively, a `SELECT` statement could be constructed for each selector and # combined using `UNION` to fetch all rows. # Benchmarking indicates that both approaches yield comparable performance. - query = ( - self._session.query(WorkflowDraftVariable) - .options( - orm.selectinload(WorkflowDraftVariable.variable_file).selectinload( - WorkflowDraftVariableFile.upload_file + return list( + self._session.scalars( + select(WorkflowDraftVariable) + .options( + orm.selectinload(WorkflowDraftVariable.variable_file).selectinload( + WorkflowDraftVariableFile.upload_file + ) + ) + .where( + WorkflowDraftVariable.app_id == app_id, + WorkflowDraftVariable.user_id == user_id, + or_(*ors), ) ) - .where( - WorkflowDraftVariable.app_id == app_id, - WorkflowDraftVariable.user_id == user_id, - or_(*ors), - ) ) - return query.all() def list_variables_without_values( self, app_id: str, page: int, limit: int, user_id: str @@ -277,18 +278,21 @@ class WorkflowDraftVariableService: WorkflowDraftVariable.user_id == user_id, ] total = None - query = self._session.query(WorkflowDraftVariable).where(*criteria) + base_stmt = select(WorkflowDraftVariable).where(*criteria) if page == 1: - total = query.count() - variables = ( - # Do not load the `value` field - query.options( - orm.defer(WorkflowDraftVariable.value, raiseload=True), + from sqlalchemy import func as sa_func + + total = self._session.scalar(select(sa_func.count()).select_from(base_stmt.subquery())) + variables = list( + self._session.scalars( + # Do not load the `value` field + base_stmt.options( + orm.defer(WorkflowDraftVariable.value, raiseload=True), + ) + .order_by(WorkflowDraftVariable.created_at.desc()) + .limit(limit) + .offset((page - 1) * limit) ) - .order_by(WorkflowDraftVariable.created_at.desc()) - .limit(limit) - .offset((page - 1) * limit) - .all() ) return WorkflowDraftVariableList(variables=variables, total=total) @@ -299,11 +303,13 @@ class WorkflowDraftVariableService: WorkflowDraftVariable.node_id == node_id, WorkflowDraftVariable.user_id == user_id, ] - query = self._session.query(WorkflowDraftVariable).where(*criteria) - variables = ( - query.options(orm.selectinload(WorkflowDraftVariable.variable_file)) - .order_by(WorkflowDraftVariable.created_at.desc()) - .all() + variables = list( + self._session.scalars( + select(WorkflowDraftVariable) + .options(orm.selectinload(WorkflowDraftVariable.variable_file)) + .where(*criteria) + .order_by(WorkflowDraftVariable.created_at.desc()) + ) ) return WorkflowDraftVariableList(variables=variables) @@ -326,8 +332,8 @@ class WorkflowDraftVariableService: return self._get_variable(app_id, node_id, name, user_id=user_id) def _get_variable(self, app_id: str, node_id: str, name: str, user_id: str) -> WorkflowDraftVariable | None: - return ( - self._session.query(WorkflowDraftVariable) + return self._session.scalar( + select(WorkflowDraftVariable) .options(orm.selectinload(WorkflowDraftVariable.variable_file)) .where( WorkflowDraftVariable.app_id == app_id, @@ -335,7 +341,6 @@ class WorkflowDraftVariableService: WorkflowDraftVariable.name == name, WorkflowDraftVariable.user_id == user_id, ) - .first() ) def update_variable( @@ -488,20 +493,20 @@ class WorkflowDraftVariableService: self._session.delete(variable) def delete_user_workflow_variables(self, app_id: str, user_id: str): - ( - self._session.query(WorkflowDraftVariable) + self._session.execute( + delete(WorkflowDraftVariable) .where( WorkflowDraftVariable.app_id == app_id, WorkflowDraftVariable.user_id == user_id, ) - .delete(synchronize_session=False) + .execution_options(synchronize_session=False) ) def delete_app_workflow_variables(self, app_id: str): - ( - self._session.query(WorkflowDraftVariable) + self._session.execute( + delete(WorkflowDraftVariable) .where(WorkflowDraftVariable.app_id == app_id) - .delete(synchronize_session=False) + .execution_options(synchronize_session=False) ) def delete_workflow_draft_variable_file(self, deletions: list[DraftVarFileDeletion]): @@ -540,14 +545,14 @@ class WorkflowDraftVariableService: return self._delete_node_variables(app_id, node_id, user_id=user_id) def _delete_node_variables(self, app_id: str, node_id: str, user_id: str): - ( - self._session.query(WorkflowDraftVariable) + self._session.execute( + delete(WorkflowDraftVariable) .where( WorkflowDraftVariable.app_id == app_id, WorkflowDraftVariable.node_id == node_id, WorkflowDraftVariable.user_id == user_id, ) - .delete(synchronize_session=False) + .execution_options(synchronize_session=False) ) def _get_conversation_id_from_draft_variable(self, app_id: str, user_id: str) -> str | None: @@ -588,13 +593,11 @@ class WorkflowDraftVariableService: conv_id = self._get_conversation_id_from_draft_variable(workflow.app_id, account_id) if conv_id is not None: - conversation = ( - self._session.query(Conversation) - .where( + conversation = self._session.scalar( + select(Conversation).where( Conversation.id == conv_id, Conversation.app_id == workflow.app_id, ) - .first() ) # Only return the conversation ID if it exists and is valid (has a correspond conversation record in DB). if conversation is not None: @@ -723,8 +726,27 @@ def _batch_upsert_draft_variable( session.execute(stmt) -def _model_to_insertion_dict(model: WorkflowDraftVariable) -> dict[str, Any]: - d: dict[str, Any] = { +class _InsertionDict(TypedDict): + id: str + app_id: str + user_id: str | None + last_edited_at: datetime | None + node_id: str + name: str + selector: str + value_type: SegmentType + value: str + node_execution_id: str | None + file_id: str | None + visible: NotRequired[bool] + editable: NotRequired[bool] + created_at: NotRequired[datetime] + updated_at: NotRequired[datetime] + description: NotRequired[str] + + +def _model_to_insertion_dict(model: WorkflowDraftVariable) -> _InsertionDict: + d: _InsertionDict = { "id": model.id, "app_id": model.app_id, "user_id": model.user_id, @@ -1045,7 +1067,7 @@ class DraftVariableSaver: filename = f"{self._generate_filename(name)}.txt" else: # For other types, store as JSON - original_content_serialized = dumps_with_segments(value_seg.value, ensure_ascii=False) + original_content_serialized = dumps_with_segments(value_seg.value) content_type = "application/json" filename = f"{self._generate_filename(name)}.json" @@ -1075,9 +1097,8 @@ class DraftVariableSaver: ) engine = bind = self._session.get_bind() assert isinstance(engine, Engine) - with Session(bind=engine, expire_on_commit=False) as session: + with sessionmaker(bind=engine, expire_on_commit=False).begin() as session: session.add(variable_file) - session.commit() return truncation_result.result, variable_file diff --git a/api/services/workflow_event_snapshot_service.py b/api/services/workflow_event_snapshot_service.py index 601e9261fc6..5fca4447230 100644 --- a/api/services/workflow_event_snapshot_service.py +++ b/api/services/workflow_event_snapshot_service.py @@ -9,10 +9,6 @@ from collections.abc import Generator, Mapping, Sequence from dataclasses import dataclass from typing import Any -from graphon.entities import WorkflowStartReason -from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus -from graphon.runtime import GraphRuntimeState -from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from sqlalchemy import desc, select from sqlalchemy.orm import Session, sessionmaker @@ -26,6 +22,10 @@ from core.app.entities.task_entities import ( WorkflowStartStreamResponse, ) from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext +from graphon.entities import WorkflowStartReason +from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus +from graphon.runtime import GraphRuntimeState +from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from models.model import AppMode, Message from models.workflow import WorkflowNodeExecutionTriggeredFrom, WorkflowRun from repositories.api_workflow_node_execution_repository import WorkflowNodeExecutionSnapshot diff --git a/api/services/workflow_run_service.py b/api/services/workflow_run_service.py index b903d8df5fd..29b9e72a009 100644 --- a/api/services/workflow_run_service.py +++ b/api/services/workflow_run_service.py @@ -1,5 +1,6 @@ import threading from collections.abc import Sequence +from typing import TypedDict from sqlalchemy import Engine from sqlalchemy.orm import sessionmaker @@ -19,6 +20,14 @@ from repositories.api_workflow_run_repository import APIWorkflowRunRepository from repositories.factory import DifyAPIRepositoryFactory +class WorkflowRunListArgs(TypedDict, total=False): + """Expected shape of the args dict passed to workflow run pagination methods.""" + + limit: int + last_id: str + status: str + + class WorkflowRunService: _session_factory: sessionmaker _workflow_run_repo: APIWorkflowRunRepository @@ -37,7 +46,10 @@ class WorkflowRunService: self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(self._session_factory) def get_paginate_advanced_chat_workflow_runs( - self, app_model: App, args: dict, triggered_from: WorkflowRunTriggeredFrom = WorkflowRunTriggeredFrom.DEBUGGING + self, + app_model: App, + args: WorkflowRunListArgs, + triggered_from: WorkflowRunTriggeredFrom = WorkflowRunTriggeredFrom.DEBUGGING, ) -> InfiniteScrollPagination: """ Get advanced chat app workflow run list @@ -73,7 +85,10 @@ class WorkflowRunService: return pagination def get_paginate_workflow_runs( - self, app_model: App, args: dict, triggered_from: WorkflowRunTriggeredFrom = WorkflowRunTriggeredFrom.DEBUGGING + self, + app_model: App, + args: WorkflowRunListArgs, + triggered_from: WorkflowRunTriggeredFrom = WorkflowRunTriggeredFrom.DEBUGGING, ) -> InfiniteScrollPagination: """ Get workflow run list diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 8f365c7c51f..d4b9095ce53 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -5,7 +5,41 @@ import uuid from collections.abc import Callable, Generator, Mapping, Sequence from typing import Any, cast -from graphon.entities import GraphInitParams, WorkflowNodeExecution +from sqlalchemy import exists, select +from sqlalchemy.orm import Session, sessionmaker + +from configs import dify_config +from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager +from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context +from core.app.file_access import DatabaseFileAccessController +from core.entities import PluginCredentialType +from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly, create_plugin_provider_manager +from core.repositories import DifyCoreRepositoryFactory +from core.repositories.human_input_repository import FormCreateParams, HumanInputFormRepositoryImpl +from core.trigger.constants import is_trigger_node_type +from core.workflow.human_input_adapter import ( + DeliveryChannelConfig, + adapt_human_input_node_data_for_graph, + parse_human_input_delivery_methods, +) +from core.workflow.node_factory import ( + LATEST_VERSION, + DifyGraphInitContext, + get_node_type_classes_mapping, + is_start_node_type, +) +from core.workflow.node_runtime import DifyHumanInputNodeRuntime, apply_dify_debug_email_recipient +from core.workflow.system_variables import build_bootstrap_variables, build_system_variables, default_system_variables +from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool +from core.workflow.workflow_entry import WorkflowEntry +from enterprise.telemetry.draft_trace import enqueue_draft_node_execution_trace +from enums.cloud_plan import CloudPlan +from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated +from extensions.ext_database import db +from extensions.ext_storage import storage +from factories.file_factory import build_from_mapping, build_from_mappings +from graphon.entities import WorkflowNodeExecution from graphon.entities.graph_config import NodeConfigDict from graphon.entities.pause_reason import HumanInputRequired from graphon.enums import ( @@ -30,34 +64,6 @@ from graphon.variable_loader import load_into_variable_pool from graphon.variables import VariableBase from graphon.variables.input_entities import VariableEntityType from graphon.variables.variables import Variable -from sqlalchemy import exists, select -from sqlalchemy.orm import Session, sessionmaker - -from configs import dify_config -from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager -from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context -from core.app.file_access import DatabaseFileAccessController -from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly, create_plugin_provider_manager -from core.repositories import DifyCoreRepositoryFactory -from core.repositories.human_input_repository import FormCreateParams, HumanInputFormRepositoryImpl -from core.trigger.constants import is_trigger_node_type -from core.workflow.human_input_compat import ( - DeliveryChannelConfig, - normalize_human_input_node_data_for_graph, - parse_human_input_delivery_methods, -) -from core.workflow.node_factory import LATEST_VERSION, get_node_type_classes_mapping, is_start_node_type -from core.workflow.node_runtime import DifyHumanInputNodeRuntime, apply_dify_debug_email_recipient -from core.workflow.system_variables import build_bootstrap_variables, build_system_variables, default_system_variables -from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool -from core.workflow.workflow_entry import WorkflowEntry -from enterprise.telemetry.draft_trace import enqueue_draft_node_execution_trace -from enums.cloud_plan import CloudPlan -from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated -from extensions.ext_database import db -from extensions.ext_storage import storage -from factories.file_factory import build_from_mapping, build_from_mappings from libs.datetime_utils import naive_utc_now from models import Account from models.human_input import HumanInputFormRecipient, RecipientType @@ -66,7 +72,6 @@ from models.tools import WorkflowToolProvider from models.workflow import Workflow, WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom, WorkflowType from repositories.factory import DifyAPIRepositoryFactory from services.billing_service import BillingService -from services.enterprise.plugin_manager_service import PluginCredentialType from services.errors.app import ( IsDraftWorkflowError, TriggerNodeLimitExceededError, @@ -194,6 +199,16 @@ class WorkflowService: return workflow + def get_accessible_app_ids(self, app_ids: Sequence[str], tenant_id: str) -> set[str]: + """ + Return app IDs that belong to the given tenant. + """ + if not app_ids: + return set() + + stmt = select(App.id).where(App.id.in_(app_ids), App.tenant_id == tenant_id) + return {str(app_id) for app_id in db.session.scalars(stmt).all()} + def get_all_published_workflow( self, *, @@ -236,8 +251,8 @@ class WorkflowService: self, *, app_model: App, - graph: dict, - features: dict, + graph: dict[str, Any], + features: dict[str, Any], unique_hash: str | None, account: Account, environment_variables: Sequence[VariableBase], @@ -291,6 +306,78 @@ class WorkflowService: # return draft workflow return workflow + def update_draft_workflow_environment_variables( + self, + *, + app_model: App, + environment_variables: Sequence[VariableBase], + account: Account, + ): + """ + Update draft workflow environment variables + """ + # fetch draft workflow by app_model + workflow = self.get_draft_workflow(app_model=app_model) + + if not workflow: + raise ValueError("No draft workflow found.") + + workflow.environment_variables = environment_variables + workflow.updated_by = account.id + workflow.updated_at = naive_utc_now() + + # commit db session changes + db.session.commit() + + def update_draft_workflow_conversation_variables( + self, + *, + app_model: App, + conversation_variables: Sequence[VariableBase], + account: Account, + ): + """ + Update draft workflow conversation variables + """ + # fetch draft workflow by app_model + workflow = self.get_draft_workflow(app_model=app_model) + + if not workflow: + raise ValueError("No draft workflow found.") + + workflow.conversation_variables = conversation_variables + workflow.updated_by = account.id + workflow.updated_at = naive_utc_now() + + # commit db session changes + db.session.commit() + + def update_draft_workflow_features( + self, + *, + app_model: App, + features: dict, + account: Account, + ): + """ + Update draft workflow features + """ + # fetch draft workflow by app_model + workflow = self.get_draft_workflow(app_model=app_model) + + if not workflow: + raise ValueError("No draft workflow found.") + + # validate features structure + self.validate_features_structure(app_model=app_model, features=features) + + workflow.features = json.dumps(features) + workflow.updated_by = account.id + workflow.updated_at = naive_utc_now() + + # commit db session changes + db.session.commit() + def restore_published_workflow_to_draft( self, *, @@ -571,7 +658,7 @@ class WorkflowService: except Exception as e: raise ValueError(f"Failed to validate default credential for tool provider {provider}: {str(e)}") - def _validate_load_balancing_credentials(self, workflow: Workflow, node_data: dict, node_id: str) -> None: + def _validate_load_balancing_credentials(self, workflow: Workflow, node_data: dict[str, Any], node_id: str) -> None: """ Validate load balancing credentials for a workflow node. @@ -635,7 +722,7 @@ class WorkflowService: # If we can't determine the status, assume load balancing is not enabled return False - def _get_load_balancing_configs(self, tenant_id: str, provider: str, model_name: str) -> list[dict]: + def _get_load_balancing_configs(self, tenant_id: str, provider: str, model_name: str) -> list[dict[str, Any]]: """ Get all load balancing configurations for a model. @@ -659,7 +746,7 @@ class WorkflowService: _, custom_configs = model_load_balancing_service.get_load_balancing_configs( tenant_id=tenant_id, provider=provider, model=model_name, model_type="llm", config_from="custom-model" ) - all_configs = configs + custom_configs + all_configs = cast(list[dict[str, Any]], configs) + cast(list[dict[str, Any]], custom_configs) return [config for config in all_configs if config.get("credential_id")] @@ -704,7 +791,7 @@ class WorkflowService: :param filters: filter by node config parameters. :return: """ - node_type_enum = NodeType(node_type) + node_type_enum: NodeType = node_type node_mapping = get_node_type_classes_mapping() # return default block config @@ -834,10 +921,10 @@ class WorkflowService: if workflow_node_execution is None: raise ValueError(f"WorkflowNodeExecution with id {node_execution.id} not found after saving") - with Session(db.engine) as session: + with sessionmaker(db.engine).begin() as session: outputs = workflow_node_execution.load_full_outputs(session, storage) - with Session(bind=db.engine) as session, session.begin(): + with sessionmaker(bind=db.engine).begin() as session: draft_var_saver = DraftVariableSaver( session=session, app_id=app_model.id, @@ -848,7 +935,6 @@ class WorkflowService: user=account, ) draft_var_saver.save(process_data=node_execution.process_data, outputs=outputs) - session.commit() enqueue_draft_node_execution_trace( execution=workflow_node_execution, @@ -977,7 +1063,7 @@ class WorkflowService: enclosing_node_type_and_id = draft_workflow.get_enclosing_node_type_and_id(node_config) enclosing_node_id = enclosing_node_type_and_id[1] if enclosing_node_type_and_id else None - with Session(bind=db.engine) as session, session.begin(): + with sessionmaker(bind=db.engine).begin() as session: draft_var_saver = DraftVariableSaver( session=session, app_id=app_model.id, @@ -988,7 +1074,6 @@ class WorkflowService: enclosing_node_id=enclosing_node_id, ) draft_var_saver.save(outputs=outputs, process_data={}) - session.commit() return outputs @@ -1011,7 +1096,7 @@ class WorkflowService: raise ValueError("Node type must be human-input.") node_data = HumanInputNodeData.model_validate( - normalize_human_input_node_data_for_graph(node_config["data"]), + adapt_human_input_node_data_for_graph(node_config["data"]), from_attributes=True, ) delivery_method = self._resolve_human_input_delivery_method( @@ -1134,28 +1219,31 @@ class WorkflowService: node_config: NodeConfigDict, variable_pool: VariablePool, ) -> HumanInputNode: - graph_init_params = GraphInitParams( + run_context = build_dify_run_context( + tenant_id=workflow.tenant_id, + app_id=workflow.app_id, + user_id=account.id, + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + ) + graph_init_context = DifyGraphInitContext( workflow_id=workflow.id, graph_config=workflow.graph_dict, - run_context=build_dify_run_context( - tenant_id=workflow.tenant_id, - app_id=workflow.app_id, - user_id=account.id, - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, - ), + run_context=run_context, call_depth=0, ) + graph_init_params = graph_init_context.to_graph_init_params() graph_runtime_state = GraphRuntimeState( variable_pool=variable_pool, start_at=time.perf_counter(), ) + node_data = HumanInputNode.validate_node_data(adapt_human_input_node_data_for_graph(node_config["data"])) node = HumanInputNode( - id=node_config["id"], - config=node_config, + node_id=node_config["id"], + config=node_data, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, - runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context), + runtime=DifyHumanInputNodeRuntime(run_context), ) return node @@ -1209,7 +1297,7 @@ class WorkflowService: return variable_pool def run_free_workflow_node( - self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any] + self, node_data: dict[str, Any], tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any] ) -> WorkflowNodeExecution: """ Run free workflow node @@ -1356,7 +1444,7 @@ class WorkflowService: node_execution.status = WorkflowNodeExecutionStatus.FAILED node_execution.error = error - def convert_to_workflow(self, app_model: App, account: Account, args: dict) -> App: + def convert_to_workflow(self, app_model: App, account: Account, args: dict[str, Any]) -> App: """ Basic mode of chatbot app(expert mode) to workflow Completion App to Workflow App @@ -1416,19 +1504,20 @@ class WorkflowService: if node_type == BuiltinNodeTypes.HUMAN_INPUT: self._validate_human_input_node_data(node_data) - def validate_features_structure(self, app_model: App, features: dict): - if app_model.mode == AppMode.ADVANCED_CHAT: - return AdvancedChatAppConfigManager.config_validate( - tenant_id=app_model.tenant_id, config=features, only_structure_validate=True - ) - elif app_model.mode == AppMode.WORKFLOW: - return WorkflowAppConfigManager.config_validate( - tenant_id=app_model.tenant_id, config=features, only_structure_validate=True - ) - else: - raise ValueError(f"Invalid app mode: {app_model.mode}") + def validate_features_structure(self, app_model: App, features: dict[str, Any]): + match app_model.mode: + case AppMode.ADVANCED_CHAT: + return AdvancedChatAppConfigManager.config_validate( + tenant_id=app_model.tenant_id, config=features, only_structure_validate=True + ) + case AppMode.WORKFLOW: + return WorkflowAppConfigManager.config_validate( + tenant_id=app_model.tenant_id, config=features, only_structure_validate=True + ) + case _: + raise ValueError(f"Invalid app mode: {app_model.mode}") - def _validate_human_input_node_data(self, node_data: dict) -> None: + def _validate_human_input_node_data(self, node_data: dict[str, Any]) -> None: """ Validate HumanInput node data format. @@ -1441,12 +1530,12 @@ class WorkflowService: from graphon.nodes.human_input.entities import HumanInputNodeData try: - HumanInputNodeData.model_validate(normalize_human_input_node_data_for_graph(node_data)) + HumanInputNodeData.model_validate(adapt_human_input_node_data_for_graph(node_data)) except Exception as e: raise ValueError(f"Invalid HumanInput node data: {str(e)}") def update_workflow( - self, *, session: Session, workflow_id: str, tenant_id: str, account_id: str, data: dict + self, *, session: Session, workflow_id: str, tenant_id: str, account_id: str, data: dict[str, Any] ) -> Workflow | None: """ Update workflow attributes @@ -1506,14 +1595,12 @@ class WorkflowService: # Don't use workflow.tool_published as it's not accurate for specific workflow versions # Check if there's a tool provider using this specific workflow version - tool_provider = ( - session.query(WorkflowToolProvider) - .where( + tool_provider = session.scalar( + select(WorkflowToolProvider).where( WorkflowToolProvider.tenant_id == workflow.tenant_id, WorkflowToolProvider.app_id == workflow.app_id, WorkflowToolProvider.version == workflow.version, ) - .first() ) if tool_provider: diff --git a/api/tasks/add_document_to_index_task.py b/api/tasks/add_document_to_index_task.py index ae55c9ee032..c9d4673c0ad 100644 --- a/api/tasks/add_document_to_index_task.py +++ b/api/tasks/add_document_to_index_task.py @@ -3,6 +3,7 @@ import time import click from celery import shared_task +from sqlalchemy import delete, select, update from core.db.session_factory import session_factory from core.rag.index_processor.constant.doc_type import DocType @@ -30,7 +31,9 @@ def add_document_to_index_task(dataset_document_id: str): start_at = time.perf_counter() with session_factory.create_session() as session: - dataset_document = session.query(DatasetDocument).where(DatasetDocument.id == dataset_document_id).first() + dataset_document = session.scalar( + select(DatasetDocument).where(DatasetDocument.id == dataset_document_id).limit(1) + ) if not dataset_document: logger.info(click.style(f"Document not found: {dataset_document_id}", fg="red")) return @@ -45,15 +48,14 @@ def add_document_to_index_task(dataset_document_id: str): if not dataset: raise Exception(f"Document {dataset_document.id} dataset {dataset_document.dataset_id} doesn't exist.") - segments = ( - session.query(DocumentSegment) + segments = session.scalars( + select(DocumentSegment) .where( DocumentSegment.document_id == dataset_document.id, DocumentSegment.status == SegmentStatus.COMPLETED, ) .order_by(DocumentSegment.position.asc()) - .all() - ) + ).all() documents = [] multimodal_documents = [] @@ -104,18 +106,15 @@ def add_document_to_index_task(dataset_document_id: str): index_processor.load(dataset, documents, multimodal_documents=multimodal_documents) # delete auto disable log - session.query(DatasetAutoDisableLog).where( - DatasetAutoDisableLog.document_id == dataset_document.id - ).delete() + session.execute( + delete(DatasetAutoDisableLog).where(DatasetAutoDisableLog.document_id == dataset_document.id) + ) # update segment to enable - session.query(DocumentSegment).where(DocumentSegment.document_id == dataset_document.id).update( - { - DocumentSegment.enabled: True, - DocumentSegment.disabled_at: None, - DocumentSegment.disabled_by: None, - DocumentSegment.updated_at: naive_utc_now(), - } + session.execute( + update(DocumentSegment) + .where(DocumentSegment.document_id == dataset_document.id) + .values(enabled=True, disabled_at=None, disabled_by=None, updated_at=naive_utc_now()) ) session.commit() diff --git a/api/tasks/annotation/batch_import_annotations_task.py b/api/tasks/annotation/batch_import_annotations_task.py index c734e1321ba..89844ef44b4 100644 --- a/api/tasks/annotation/batch_import_annotations_task.py +++ b/api/tasks/annotation/batch_import_annotations_task.py @@ -3,6 +3,7 @@ import time import click from celery import shared_task +from sqlalchemy import select from werkzeug.exceptions import NotFound from core.db.session_factory import session_factory @@ -35,7 +36,9 @@ def batch_import_annotations_task(job_id: str, content_list: list[dict], app_id: with session_factory.create_session() as session: # get app info - app = session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first() + app = session.scalar( + select(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").limit(1) + ) if app: try: @@ -53,8 +56,8 @@ def batch_import_annotations_task(job_id: str, content_list: list[dict], app_id: ) documents.append(document) # if annotation reply is enabled , batch add annotations' index - app_annotation_setting = ( - session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first() + app_annotation_setting = session.scalar( + select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).limit(1) ) if app_annotation_setting: diff --git a/api/tasks/annotation/disable_annotation_reply_task.py b/api/tasks/annotation/disable_annotation_reply_task.py index 41cf7ccbf61..6a9b52e7e5b 100644 --- a/api/tasks/annotation/disable_annotation_reply_task.py +++ b/api/tasks/annotation/disable_annotation_reply_task.py @@ -24,14 +24,16 @@ def disable_annotation_reply_task(job_id: str, app_id: str, tenant_id: str): start_at = time.perf_counter() # get app info with session_factory.create_session() as session: - app = session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first() + app = session.scalar( + select(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").limit(1) + ) annotations_exists = session.scalar(select(exists().where(MessageAnnotation.app_id == app_id))) if not app: logger.info(click.style(f"App not found: {app_id}", fg="red")) return - app_annotation_setting = ( - session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first() + app_annotation_setting = session.scalar( + select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).limit(1) ) if not app_annotation_setting: diff --git a/api/tasks/annotation/enable_annotation_reply_task.py b/api/tasks/annotation/enable_annotation_reply_task.py index 2c07fe0f31f..4cbca13a92e 100644 --- a/api/tasks/annotation/enable_annotation_reply_task.py +++ b/api/tasks/annotation/enable_annotation_reply_task.py @@ -36,7 +36,9 @@ def enable_annotation_reply_task( start_at = time.perf_counter() # get app info with session_factory.create_session() as session: - app = session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first() + app = session.scalar( + select(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").limit(1) + ) if not app: logger.info(click.style(f"App not found: {app_id}", fg="red")) @@ -51,8 +53,8 @@ def enable_annotation_reply_task( dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( embedding_provider_name, embedding_model_name, CollectionBindingType.ANNOTATION ) - annotation_setting = ( - session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first() + annotation_setting = session.scalar( + select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).limit(1) ) if annotation_setting: if dataset_collection_binding.id != annotation_setting.collection_binding_id: diff --git a/api/tasks/app_generate/workflow_execute_task.py b/api/tasks/app_generate/workflow_execute_task.py index 8f2f5f261e6..c22e7e9918b 100644 --- a/api/tasks/app_generate/workflow_execute_task.py +++ b/api/tasks/app_generate/workflow_execute_task.py @@ -7,7 +7,6 @@ from typing import Annotated, Any from celery import shared_task from flask import current_app, json -from graphon.runtime import GraphRuntimeState from pydantic import BaseModel, Discriminator, Field, Tag from sqlalchemy import Engine, select from sqlalchemy.orm import Session, sessionmaker @@ -23,6 +22,7 @@ from core.app.entities.app_invoke_entities import ( from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, WorkflowResumptionContext from core.repositories import DifyCoreRepositoryFactory from extensions.ext_database import db +from graphon.runtime import GraphRuntimeState from libs.flask_utils import set_login_user from models.account import Account from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom diff --git a/api/tasks/async_workflow_tasks.py b/api/tasks/async_workflow_tasks.py index 0a73c912798..58092689929 100644 --- a/api/tasks/async_workflow_tasks.py +++ b/api/tasks/async_workflow_tasks.py @@ -7,15 +7,15 @@ with appropriate retry policies and error handling. import logging from datetime import UTC, datetime -from typing import Any +from typing import Any, NotRequired from celery import shared_task -from graphon.runtime import GraphRuntimeState from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker +from typing_extensions import TypedDict from configs import dify_config -from core.app.apps.workflow.app_generator import SKIP_PREPARE_USER_INPUTS_KEY, WorkflowAppGenerator +from core.app.apps.workflow.app_generator import WorkflowAppGenerator from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, WorkflowResumptionContext from core.app.layers.timeslice_layer import TimeSliceLayer @@ -23,6 +23,7 @@ from core.app.layers.trigger_post_layer import TriggerPostLayer from core.db.session_factory import session_factory from core.repositories import DifyCoreRepositoryFactory from extensions.ext_database import db +from graphon.runtime import GraphRuntimeState from models.account import Account from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom, WorkflowTriggerStatus from models.model import App, EndUser, Tenant @@ -42,6 +43,13 @@ from tasks.workflow_cfs_scheduler.entities import AsyncWorkflowQueue, AsyncWorkf logger = logging.getLogger(__name__) +class WorkflowGeneratorArgsDict(TypedDict): + inputs: dict[str, Any] + files: list[Any] + _skip_prepare_user_inputs: bool + workflow_id: NotRequired[str] + + @shared_task(queue=AsyncWorkflowQueue.PROFESSIONAL_QUEUE) def execute_workflow_professional(task_data_dict: dict[str, Any]): """Execute workflow for professional tier with highest priority""" @@ -90,15 +98,13 @@ def execute_workflow_sandbox(task_data_dict: dict[str, Any]): ) -def _build_generator_args(trigger_data: TriggerData) -> dict[str, Any]: +def _build_generator_args(trigger_data: TriggerData) -> WorkflowGeneratorArgsDict: """Build args passed into WorkflowAppGenerator.generate for Celery executions.""" - - args: dict[str, Any] = { + return { "inputs": dict(trigger_data.inputs), "files": list(trigger_data.files), - SKIP_PREPARE_USER_INPUTS_KEY: True, + "_skip_prepare_user_inputs": True, } - return args def _execute_workflow_common( @@ -156,7 +162,12 @@ def _execute_workflow_common( state_owner_user_id=workflow.created_by, ) - # Execute the workflow with the trigger type + # NOTE (hj24) + # Release the transaction before the blocking generate() call, + # otherwise the connection stays "idle in transaction" for hours. + session.commit() + # NOTE END + generator.generate( app_model=app_model, workflow=workflow, diff --git a/api/tasks/batch_clean_document_task.py b/api/tasks/batch_clean_document_task.py index 747106d373e..56c371fcc12 100644 --- a/api/tasks/batch_clean_document_task.py +++ b/api/tasks/batch_clean_document_task.py @@ -1,9 +1,11 @@ import logging import time +from typing import cast import click from celery import shared_task from sqlalchemy import delete, select +from sqlalchemy.engine import CursorResult from core.db.session_factory import session_factory from core.rag.index_processor.index_processor_factory import IndexProcessorFactory @@ -73,7 +75,7 @@ def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form try: # Fetch dataset in a fresh session to avoid DetachedInstanceError with session_factory.create_session() as session: - dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() + dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1)) if not dataset: logger.warning("Dataset not found for vector index cleanup, dataset_id: %s", dataset_id) else: @@ -92,14 +94,16 @@ def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form # ============ Step 3: Delete metadata binding (separate short transaction) ============ try: with session_factory.create_session() as session: - deleted_count = ( - session.query(DatasetMetadataBinding) - .where( - DatasetMetadataBinding.dataset_id == dataset_id, - DatasetMetadataBinding.document_id.in_(document_ids), - ) - .delete(synchronize_session=False) + result = cast( + CursorResult, + session.execute( + delete(DatasetMetadataBinding).where( + DatasetMetadataBinding.dataset_id == dataset_id, + DatasetMetadataBinding.document_id.in_(document_ids), + ) + ), ) + deleted_count = result.rowcount session.commit() logger.debug("Deleted %d metadata bindings for dataset_id: %s", deleted_count, dataset_id) except Exception: diff --git a/api/tasks/batch_create_segment_to_index_task.py b/api/tasks/batch_create_segment_to_index_task.py index 20335d9b9f9..beb23d83541 100644 --- a/api/tasks/batch_create_segment_to_index_task.py +++ b/api/tasks/batch_create_segment_to_index_task.py @@ -3,18 +3,19 @@ import tempfile import time import uuid from pathlib import Path +from typing import Any import click import pandas as pd from celery import shared_task -from graphon.model_runtime.entities.model_entities import ModelType -from sqlalchemy import func +from sqlalchemy import func, select from core.db.session_factory import session_factory from core.model_manager import ModelManager from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from extensions.ext_redis import redis_client from extensions.ext_storage import storage +from graphon.model_runtime.entities.model_entities import ModelType from libs import helper from libs.datetime_utils import naive_utc_now from models.dataset import Dataset, Document, DocumentSegment @@ -51,8 +52,8 @@ def batch_create_segment_to_index_task( # Initialize variables with default values upload_file_key: str | None = None - dataset_config: dict | None = None - document_config: dict | None = None + dataset_config: dict[str, Any] | None = None + document_config: dict[str, Any] | None = None with session_factory.create_session() as session: try: @@ -140,10 +141,8 @@ def batch_create_segment_to_index_task( content = segment["content"] doc_id = str(uuid.uuid4()) segment_hash = helper.generate_text_hash(content) - max_position = ( - session.query(func.max(DocumentSegment.position)) - .where(DocumentSegment.document_id == document_config["id"]) - .scalar() + max_position = session.scalar( + select(func.max(DocumentSegment.position)).where(DocumentSegment.document_id == document_config["id"]) ) segment_document = DocumentSegment( tenant_id=tenant_id, diff --git a/api/tasks/clean_dataset_task.py b/api/tasks/clean_dataset_task.py index 0d51a743ad1..377d0e5cc70 100644 --- a/api/tasks/clean_dataset_task.py +++ b/api/tasks/clean_dataset_task.py @@ -112,7 +112,9 @@ def clean_dataset_task( segment_ids = [segment.id for segment in segments] for segment in segments: image_upload_file_ids = get_image_upload_file_ids(segment.content) - image_files = session.query(UploadFile).where(UploadFile.id.in_(image_upload_file_ids)).all() + image_files = session.scalars( + select(UploadFile).where(UploadFile.id.in_(image_upload_file_ids)) + ).all() for image_file in image_files: if image_file is None: continue @@ -150,20 +152,22 @@ def clean_dataset_task( ) session.execute(binding_delete_stmt) - session.query(DatasetProcessRule).where(DatasetProcessRule.dataset_id == dataset_id).delete() - session.query(DatasetQuery).where(DatasetQuery.dataset_id == dataset_id).delete() - session.query(AppDatasetJoin).where(AppDatasetJoin.dataset_id == dataset_id).delete() + session.execute(delete(DatasetProcessRule).where(DatasetProcessRule.dataset_id == dataset_id)) + session.execute(delete(DatasetQuery).where(DatasetQuery.dataset_id == dataset_id)) + session.execute(delete(AppDatasetJoin).where(AppDatasetJoin.dataset_id == dataset_id)) # delete dataset metadata - session.query(DatasetMetadata).where(DatasetMetadata.dataset_id == dataset_id).delete() - session.query(DatasetMetadataBinding).where(DatasetMetadataBinding.dataset_id == dataset_id).delete() + session.execute(delete(DatasetMetadata).where(DatasetMetadata.dataset_id == dataset_id)) + session.execute(delete(DatasetMetadataBinding).where(DatasetMetadataBinding.dataset_id == dataset_id)) # delete pipeline and workflow if pipeline_id: - session.query(Pipeline).where(Pipeline.id == pipeline_id).delete() - session.query(Workflow).where( - Workflow.tenant_id == tenant_id, - Workflow.app_id == pipeline_id, - Workflow.type == WorkflowType.RAG_PIPELINE, - ).delete() + session.execute(delete(Pipeline).where(Pipeline.id == pipeline_id)) + session.execute( + delete(Workflow).where( + Workflow.tenant_id == tenant_id, + Workflow.app_id == pipeline_id, + Workflow.type == WorkflowType.RAG_PIPELINE, + ) + ) # delete files if documents: file_ids = [] @@ -174,7 +178,7 @@ def clean_dataset_task( if data_source_info and "upload_file_id" in data_source_info: file_id = data_source_info["upload_file_id"] file_ids.append(file_id) - files = session.query(UploadFile).where(UploadFile.id.in_(file_ids)).all() + files = session.scalars(select(UploadFile).where(UploadFile.id.in_(file_ids))).all() for file in files: storage.delete(file.key) diff --git a/api/tasks/clean_document_task.py b/api/tasks/clean_document_task.py index a017e9114b9..a657cd553ab 100644 --- a/api/tasks/clean_document_task.py +++ b/api/tasks/clean_document_task.py @@ -32,7 +32,7 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i with session_factory.create_session() as session: try: - dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() + dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1)) if not dataset: raise Exception("Document has no dataset") @@ -63,7 +63,7 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i if index_node_ids: index_processor = IndexProcessorFactory(doc_form).init_index_processor() with session_factory.create_session() as session: - dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() + dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1)) if dataset: index_processor.clean( dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True @@ -94,7 +94,7 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i with session_factory.create_session() as session, session.begin(): if file_id: - file = session.query(UploadFile).where(UploadFile.id == file_id).first() + file = session.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1)) if file: try: storage.delete(file.key) @@ -124,10 +124,12 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i with session_factory.create_session() as session, session.begin(): # delete dataset metadata binding - session.query(DatasetMetadataBinding).where( - DatasetMetadataBinding.dataset_id == dataset_id, - DatasetMetadataBinding.document_id == document_id, - ).delete() + session.execute( + delete(DatasetMetadataBinding).where( + DatasetMetadataBinding.dataset_id == dataset_id, + DatasetMetadataBinding.document_id == document_id, + ) + ) end_at = time.perf_counter() logger.info( diff --git a/api/tasks/clean_notion_document_task.py b/api/tasks/clean_notion_document_task.py index c22ee761d8c..e3be24ac749 100644 --- a/api/tasks/clean_notion_document_task.py +++ b/api/tasks/clean_notion_document_task.py @@ -26,7 +26,7 @@ def clean_notion_document_task(document_ids: list[str], dataset_id: str): total_index_node_ids = [] with session_factory.create_session() as session: - dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() + dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1)) if not dataset: raise Exception("Document has no dataset") @@ -41,7 +41,7 @@ def clean_notion_document_task(document_ids: list[str], dataset_id: str): total_index_node_ids.extend([segment.index_node_id for segment in segments]) with session_factory.create_session() as session: - dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() + dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1)) if dataset: index_processor.clean( dataset, total_index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True diff --git a/api/tasks/create_segment_to_index_task.py b/api/tasks/create_segment_to_index_task.py index b3cbc73d6e8..3448325104f 100644 --- a/api/tasks/create_segment_to_index_task.py +++ b/api/tasks/create_segment_to_index_task.py @@ -3,6 +3,7 @@ import time import click from celery import shared_task +from sqlalchemy import select, update from core.db.session_factory import session_factory from core.rag.index_processor.index_processor_factory import IndexProcessorFactory @@ -27,7 +28,7 @@ def create_segment_to_index_task(segment_id: str, keywords: list[str] | None = N start_at = time.perf_counter() with session_factory.create_session() as session: - segment = session.query(DocumentSegment).where(DocumentSegment.id == segment_id).first() + segment = session.scalar(select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1)) if not segment: logger.info(click.style(f"Segment not found: {segment_id}", fg="red")) return @@ -39,11 +40,10 @@ def create_segment_to_index_task(segment_id: str, keywords: list[str] | None = N try: # update segment status to indexing - session.query(DocumentSegment).filter_by(id=segment.id).update( - { - DocumentSegment.status: SegmentStatus.INDEXING, - DocumentSegment.indexing_at: naive_utc_now(), - } + session.execute( + update(DocumentSegment) + .where(DocumentSegment.id == segment.id) + .values(status=SegmentStatus.INDEXING, indexing_at=naive_utc_now()) ) session.commit() document = Document( @@ -81,11 +81,10 @@ def create_segment_to_index_task(segment_id: str, keywords: list[str] | None = N index_processor.load(dataset, [document]) # update segment to completed - session.query(DocumentSegment).filter_by(id=segment.id).update( - { - DocumentSegment.status: SegmentStatus.COMPLETED, - DocumentSegment.completed_at: naive_utc_now(), - } + session.execute( + update(DocumentSegment) + .where(DocumentSegment.id == segment.id) + .values(status=SegmentStatus.COMPLETED, completed_at=naive_utc_now()) ) session.commit() diff --git a/api/tasks/deal_dataset_index_update_task.py b/api/tasks/deal_dataset_index_update_task.py index fa844a8647d..c9b5121a087 100644 --- a/api/tasks/deal_dataset_index_update_task.py +++ b/api/tasks/deal_dataset_index_update_task.py @@ -3,6 +3,7 @@ import time import click from celery import shared_task # type: ignore +from sqlalchemy import select, update from core.db.session_factory import session_factory from core.rag.index_processor.constant.doc_type import DocType @@ -26,43 +27,42 @@ def deal_dataset_index_update_task(dataset_id: str, action: str): with session_factory.create_session() as session: try: - dataset = session.query(Dataset).filter_by(id=dataset_id).first() + dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1)) if not dataset: raise Exception("Dataset not found") index_type = dataset.doc_form or IndexStructureType.PARAGRAPH_INDEX index_processor = IndexProcessorFactory(index_type).init_index_processor() if action == "upgrade": - dataset_documents = ( - session.query(DatasetDocument) - .where( + dataset_documents = session.scalars( + select(DatasetDocument).where( DatasetDocument.dataset_id == dataset_id, DatasetDocument.indexing_status == "completed", DatasetDocument.enabled == True, DatasetDocument.archived == False, ) - .all() - ) + ).all() if dataset_documents: dataset_documents_ids = [doc.id for doc in dataset_documents] - session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update( - {"indexing_status": "indexing"}, synchronize_session=False + session.execute( + update(DatasetDocument) + .where(DatasetDocument.id.in_(dataset_documents_ids)) + .values(indexing_status="indexing") ) session.commit() for dataset_document in dataset_documents: try: # add from vector index - segments = ( - session.query(DocumentSegment) + segments = session.scalars( + select(DocumentSegment) .where( DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True, ) .order_by(DocumentSegment.position.asc()) - .all() - ) + ).all() if segments: documents = [] for segment in segments: @@ -81,32 +81,36 @@ def deal_dataset_index_update_task(dataset_id: str, action: str): # clean keywords index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=False) index_processor.load(dataset, documents, with_keywords=False) - session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update( - {"indexing_status": "completed"}, synchronize_session=False + session.execute( + update(DatasetDocument) + .where(DatasetDocument.id == dataset_document.id) + .values(indexing_status="completed") ) session.commit() except Exception as e: - session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update( - {"indexing_status": "error", "error": str(e)}, synchronize_session=False + session.execute( + update(DatasetDocument) + .where(DatasetDocument.id == dataset_document.id) + .values(indexing_status="error", error=str(e)) ) session.commit() elif action == "update": - dataset_documents = ( - session.query(DatasetDocument) - .where( + dataset_documents = session.scalars( + select(DatasetDocument).where( DatasetDocument.dataset_id == dataset_id, DatasetDocument.indexing_status == "completed", DatasetDocument.enabled == True, DatasetDocument.archived == False, ) - .all() - ) + ).all() # add new index if dataset_documents: # update document status dataset_documents_ids = [doc.id for doc in dataset_documents] - session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update( - {"indexing_status": "indexing"}, synchronize_session=False + session.execute( + update(DatasetDocument) + .where(DatasetDocument.id.in_(dataset_documents_ids)) + .values(indexing_status="indexing") ) session.commit() @@ -116,15 +120,14 @@ def deal_dataset_index_update_task(dataset_id: str, action: str): for dataset_document in dataset_documents: # update from vector index try: - segments = ( - session.query(DocumentSegment) + segments = session.scalars( + select(DocumentSegment) .where( DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True, ) .order_by(DocumentSegment.position.asc()) - .all() - ) + ).all() if segments: documents = [] multimodal_documents = [] @@ -173,13 +176,17 @@ def deal_dataset_index_update_task(dataset_id: str, action: str): index_processor.load( dataset, documents, multimodal_documents=multimodal_documents, with_keywords=False ) - session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update( - {"indexing_status": "completed"}, synchronize_session=False + session.execute( + update(DatasetDocument) + .where(DatasetDocument.id == dataset_document.id) + .values(indexing_status="completed") ) session.commit() except Exception as e: - session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update( - {"indexing_status": "error", "error": str(e)}, synchronize_session=False + session.execute( + update(DatasetDocument) + .where(DatasetDocument.id == dataset_document.id) + .values(indexing_status="error", error=str(e)) ) session.commit() else: diff --git a/api/tasks/deal_dataset_vector_index_task.py b/api/tasks/deal_dataset_vector_index_task.py index 0047e04a178..36605359dc5 100644 --- a/api/tasks/deal_dataset_vector_index_task.py +++ b/api/tasks/deal_dataset_vector_index_task.py @@ -3,7 +3,7 @@ import time import click from celery import shared_task -from sqlalchemy import select +from sqlalchemy import select, update from core.db.session_factory import session_factory from core.rag.index_processor.constant.doc_type import DocType @@ -29,7 +29,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str): with session_factory.create_session() as session: try: - dataset = session.query(Dataset).filter_by(id=dataset_id).first() + dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1)) if not dataset: raise Exception("Dataset not found") @@ -49,23 +49,24 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str): if dataset_documents: dataset_documents_ids = [doc.id for doc in dataset_documents] - session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update( - {"indexing_status": "indexing"}, synchronize_session=False + session.execute( + update(DatasetDocument) + .where(DatasetDocument.id.in_(dataset_documents_ids)) + .values(indexing_status="indexing") ) session.commit() for dataset_document in dataset_documents: try: # add from vector index - segments = ( - session.query(DocumentSegment) + segments = session.scalars( + select(DocumentSegment) .where( DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True, ) .order_by(DocumentSegment.position.asc()) - .all() - ) + ).all() if segments: documents = [] for segment in segments: @@ -82,13 +83,17 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str): documents.append(document) # save vector index index_processor.load(dataset, documents, with_keywords=False) - session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update( - {"indexing_status": "completed"}, synchronize_session=False + session.execute( + update(DatasetDocument) + .where(DatasetDocument.id == dataset_document.id) + .values(indexing_status="completed") ) session.commit() except Exception as e: - session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update( - {"indexing_status": "error", "error": str(e)}, synchronize_session=False + session.execute( + update(DatasetDocument) + .where(DatasetDocument.id == dataset_document.id) + .values(indexing_status="error", error=str(e)) ) session.commit() elif action == "update": @@ -104,8 +109,10 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str): if dataset_documents: # update document status dataset_documents_ids = [doc.id for doc in dataset_documents] - session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update( - {"indexing_status": "indexing"}, synchronize_session=False + session.execute( + update(DatasetDocument) + .where(DatasetDocument.id.in_(dataset_documents_ids)) + .values(indexing_status="indexing") ) session.commit() @@ -115,15 +122,14 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str): for dataset_document in dataset_documents: # update from vector index try: - segments = ( - session.query(DocumentSegment) + segments = session.scalars( + select(DocumentSegment) .where( DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True, ) .order_by(DocumentSegment.position.asc()) - .all() - ) + ).all() if segments: documents = [] multimodal_documents = [] @@ -172,13 +178,17 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str): index_processor.load( dataset, documents, multimodal_documents=multimodal_documents, with_keywords=False ) - session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update( - {"indexing_status": "completed"}, synchronize_session=False + session.execute( + update(DatasetDocument) + .where(DatasetDocument.id == dataset_document.id) + .values(indexing_status="completed") ) session.commit() except Exception as e: - session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update( - {"indexing_status": "error", "error": str(e)}, synchronize_session=False + session.execute( + update(DatasetDocument) + .where(DatasetDocument.id == dataset_document.id) + .values(indexing_status="error", error=str(e)) ) session.commit() else: diff --git a/api/tasks/delete_account_task.py b/api/tasks/delete_account_task.py index ecf6f9cb397..55a99dde7a1 100644 --- a/api/tasks/delete_account_task.py +++ b/api/tasks/delete_account_task.py @@ -1,6 +1,7 @@ import logging from celery import shared_task +from sqlalchemy import select from configs import dify_config from core.db.session_factory import session_factory @@ -14,7 +15,7 @@ logger = logging.getLogger(__name__) @shared_task(queue="dataset") def delete_account_task(account_id): with session_factory.create_session() as session: - account = session.query(Account).where(Account.id == account_id).first() + account = session.scalar(select(Account).where(Account.id == account_id).limit(1)) try: if dify_config.BILLING_ENABLED: BillingService.delete_account(account_id) diff --git a/api/tasks/delete_conversation_task.py b/api/tasks/delete_conversation_task.py index 9664b8ac738..0b392f60965 100644 --- a/api/tasks/delete_conversation_task.py +++ b/api/tasks/delete_conversation_task.py @@ -3,6 +3,7 @@ import time import click from celery import shared_task +from sqlalchemy import delete from core.db.session_factory import session_factory from models import ConversationVariable @@ -29,29 +30,21 @@ def delete_conversation_related_data(conversation_id: str): with session_factory.create_session() as session: try: - session.query(MessageAnnotation).where(MessageAnnotation.conversation_id == conversation_id).delete( - synchronize_session=False + session.execute(delete(MessageAnnotation).where(MessageAnnotation.conversation_id == conversation_id)) + + session.execute(delete(MessageFeedback).where(MessageFeedback.conversation_id == conversation_id)) + + session.execute( + delete(ToolConversationVariables).where(ToolConversationVariables.conversation_id == conversation_id) ) - session.query(MessageFeedback).where(MessageFeedback.conversation_id == conversation_id).delete( - synchronize_session=False - ) + session.execute(delete(ToolFile).where(ToolFile.conversation_id == conversation_id)) - session.query(ToolConversationVariables).where( - ToolConversationVariables.conversation_id == conversation_id - ).delete(synchronize_session=False) + session.execute(delete(ConversationVariable).where(ConversationVariable.conversation_id == conversation_id)) - session.query(ToolFile).where(ToolFile.conversation_id == conversation_id).delete(synchronize_session=False) + session.execute(delete(Message).where(Message.conversation_id == conversation_id)) - session.query(ConversationVariable).where(ConversationVariable.conversation_id == conversation_id).delete( - synchronize_session=False - ) - - session.query(Message).where(Message.conversation_id == conversation_id).delete(synchronize_session=False) - - session.query(PinnedConversation).where(PinnedConversation.conversation_id == conversation_id).delete( - synchronize_session=False - ) + session.execute(delete(PinnedConversation).where(PinnedConversation.conversation_id == conversation_id)) session.commit() diff --git a/api/tasks/delete_segment_from_index_task.py b/api/tasks/delete_segment_from_index_task.py index a6a2dcebc86..306a23aedae 100644 --- a/api/tasks/delete_segment_from_index_task.py +++ b/api/tasks/delete_segment_from_index_task.py @@ -3,7 +3,7 @@ import time import click from celery import shared_task -from sqlalchemy import delete +from sqlalchemy import delete, select from core.db.session_factory import session_factory from core.rag.index_processor.index_processor_factory import IndexProcessorFactory @@ -29,12 +29,12 @@ def delete_segment_from_index_task( start_at = time.perf_counter() with session_factory.create_session() as session: try: - dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() + dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1)) if not dataset: logging.warning("Dataset %s not found, skipping index cleanup", dataset_id) return - dataset_document = session.query(Document).where(Document.id == document_id).first() + dataset_document = session.scalar(select(Document).where(Document.id == document_id).limit(1)) if not dataset_document: return @@ -60,11 +60,9 @@ def delete_segment_from_index_task( ) if dataset.is_multimodal: # delete segment attachment binding - segment_attachment_bindings = ( - session.query(SegmentAttachmentBinding) - .where(SegmentAttachmentBinding.segment_id.in_(segment_ids)) - .all() - ) + segment_attachment_bindings = session.scalars( + select(SegmentAttachmentBinding).where(SegmentAttachmentBinding.segment_id.in_(segment_ids)) + ).all() if segment_attachment_bindings: attachment_ids = [binding.attachment_id for binding in segment_attachment_bindings] index_processor.clean(dataset=dataset, node_ids=attachment_ids, with_keywords=False) @@ -77,7 +75,7 @@ def delete_segment_from_index_task( session.execute(segment_attachment_bind_delete_stmt) # delete upload file - session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).delete(synchronize_session=False) + session.execute(delete(UploadFile).where(UploadFile.id.in_(attachment_ids))) session.commit() end_at = time.perf_counter() diff --git a/api/tasks/disable_segment_from_index_task.py b/api/tasks/disable_segment_from_index_task.py index bc45171623d..dd1a40844bf 100644 --- a/api/tasks/disable_segment_from_index_task.py +++ b/api/tasks/disable_segment_from_index_task.py @@ -3,6 +3,7 @@ import time import click from celery import shared_task +from sqlalchemy import select from core.db.session_factory import session_factory from core.rag.index_processor.index_processor_factory import IndexProcessorFactory @@ -24,7 +25,7 @@ def disable_segment_from_index_task(segment_id: str): start_at = time.perf_counter() with session_factory.create_session() as session: - segment = session.query(DocumentSegment).where(DocumentSegment.id == segment_id).first() + segment = session.scalar(select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1)) if not segment: logger.info(click.style(f"Segment not found: {segment_id}", fg="red")) return diff --git a/api/tasks/disable_segments_from_index_task.py b/api/tasks/disable_segments_from_index_task.py index 3cc267e821f..86e96ea3f0d 100644 --- a/api/tasks/disable_segments_from_index_task.py +++ b/api/tasks/disable_segments_from_index_task.py @@ -3,7 +3,7 @@ import time import click from celery import shared_task -from sqlalchemy import select +from sqlalchemy import select, update from core.db.session_factory import session_factory from core.rag.index_processor.index_processor_factory import IndexProcessorFactory @@ -27,12 +27,12 @@ def disable_segments_from_index_task(segment_ids: list, dataset_id: str, documen start_at = time.perf_counter() with session_factory.create_session() as session: - dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() + dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1)) if not dataset: logger.info(click.style(f"Dataset {dataset_id} not found, pass.", fg="cyan")) return - dataset_document = session.query(DatasetDocument).where(DatasetDocument.id == document_id).first() + dataset_document = session.scalar(select(DatasetDocument).where(DatasetDocument.id == document_id).limit(1)) if not dataset_document: logger.info(click.style(f"Document {document_id} not found, pass.", fg="cyan")) @@ -58,11 +58,9 @@ def disable_segments_from_index_task(segment_ids: list, dataset_id: str, documen index_node_ids = [segment.index_node_id for segment in segments] if dataset.is_multimodal: segment_ids = [segment.id for segment in segments] - segment_attachment_bindings = ( - session.query(SegmentAttachmentBinding) - .where(SegmentAttachmentBinding.segment_id.in_(segment_ids)) - .all() - ) + segment_attachment_bindings = session.scalars( + select(SegmentAttachmentBinding).where(SegmentAttachmentBinding.segment_id.in_(segment_ids)) + ).all() if segment_attachment_bindings: attachment_ids = [binding.attachment_id for binding in segment_attachment_bindings] index_node_ids.extend(attachment_ids) @@ -87,16 +85,14 @@ def disable_segments_from_index_task(segment_ids: list, dataset_id: str, documen logger.info(click.style(f"Segments removed from index latency: {end_at - start_at}", fg="green")) except Exception: # update segment error msg - session.query(DocumentSegment).where( - DocumentSegment.id.in_(segment_ids), - DocumentSegment.dataset_id == dataset_id, - DocumentSegment.document_id == document_id, - ).update( - { - "disabled_at": None, - "disabled_by": None, - "enabled": True, - } + session.execute( + update(DocumentSegment) + .where( + DocumentSegment.id.in_(segment_ids), + DocumentSegment.dataset_id == dataset_id, + DocumentSegment.document_id == document_id, + ) + .values(disabled_at=None, disabled_by=None, enabled=True) ) session.commit() finally: diff --git a/api/tasks/document_indexing_sync_task.py b/api/tasks/document_indexing_sync_task.py index f99e90062fb..90c80be3a1e 100644 --- a/api/tasks/document_indexing_sync_task.py +++ b/api/tasks/document_indexing_sync_task.py @@ -32,7 +32,9 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): tenant_id = None with session_factory.create_session() as session, session.begin(): - document = session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() + document = session.scalar( + select(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).limit(1) + ) if not document: logger.info(click.style(f"Document not found: {document_id}", fg="red")) @@ -42,7 +44,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): logger.info(click.style(f"Document {document_id} is already being processed, skipping", fg="yellow")) return - dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() + dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1)) if not dataset: raise Exception("Dataset not found") @@ -87,7 +89,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): ) with session_factory.create_session() as session, session.begin(): - document = session.query(Document).filter_by(id=document_id).first() + document = session.scalar(select(Document).where(Document.id == document_id).limit(1)) if document: document.indexing_status = IndexingStatus.ERROR document.error = "Datasource credential not found. Please reconnect your Notion workspace." @@ -112,7 +114,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): try: index_processor = IndexProcessorFactory(index_type).init_index_processor() with session_factory.create_session() as session: - dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() + dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1)) if dataset: index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) logger.info(click.style(f"Cleaned vector index for document {document_id}", fg="green")) @@ -120,7 +122,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): logger.exception("Failed to clean vector index for document %s", document_id) with session_factory.create_session() as session, session.begin(): - document = session.query(Document).filter_by(id=document_id).first() + document = session.scalar(select(Document).where(Document.id == document_id).limit(1)) if not document: logger.warning(click.style(f"Document {document_id} not found during sync", fg="yellow")) return @@ -140,7 +142,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): try: indexing_runner = IndexingRunner() with session_factory.create_session() as session: - document = session.query(Document).filter_by(id=document_id).first() + document = session.scalar(select(Document).where(Document.id == document_id).limit(1)) if document: indexing_runner.run([document]) end_at = time.perf_counter() @@ -150,7 +152,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): except Exception as e: logger.exception("document_indexing_sync_task failed for document_id: %s", document_id) with session_factory.create_session() as session, session.begin(): - document = session.query(Document).filter_by(id=document_id).first() + document = session.scalar(select(Document).where(Document.id == document_id).limit(1)) if document: document.indexing_status = IndexingStatus.ERROR document.error = str(e) diff --git a/api/tasks/document_indexing_task.py b/api/tasks/document_indexing_task.py index 23a80fa1065..31dad7937cd 100644 --- a/api/tasks/document_indexing_task.py +++ b/api/tasks/document_indexing_task.py @@ -5,6 +5,7 @@ from typing import Any, Protocol import click from celery import current_app, shared_task +from sqlalchemy import select from configs import dify_config from core.db.session_factory import session_factory @@ -53,11 +54,10 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]): Usage: _document_indexing(dataset_id, document_ids) """ - documents = [] start_at = time.perf_counter() with session_factory.create_session() as session: - dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() + dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1)) if not dataset: logger.info(click.style(f"Dataset is not found: {dataset_id}", fg="yellow")) return @@ -79,8 +79,8 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]): ) except Exception as e: for document_id in document_ids: - document = ( - session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() + document = session.scalar( + select(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).limit(1) ) if document: document.indexing_status = IndexingStatus.ERROR @@ -92,8 +92,10 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]): # Phase 1: Update status to parsing (short transaction) with session_factory.create_session() as session, session.begin(): - documents = ( - session.query(Document).where(Document.id.in_(document_ids), Document.dataset_id == dataset_id).all() + documents: list[Document] = list( + session.scalars( + select(Document).where(Document.id.in_(document_ids), Document.dataset_id == dataset_id) + ).all() ) for document in documents: @@ -122,7 +124,7 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]): # Trigger summary index generation for completed documents if enabled # Only generate for high_quality indexing technique and when summary_index_setting is enabled # Re-query dataset to get latest summary_index_setting (in case it was updated) - dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() + dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1)) if not dataset: logger.warning("Dataset %s not found after indexing", dataset_id) return @@ -134,10 +136,10 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]): session.expire_all() # Check each document's indexing status and trigger summary generation if completed - documents = ( - session.query(Document) - .where(Document.id.in_(document_ids), Document.dataset_id == dataset_id) - .all() + documents = list( + session.scalars( + select(Document).where(Document.id.in_(document_ids), Document.dataset_id == dataset_id) + ).all() ) for document in documents: diff --git a/api/tasks/document_indexing_update_task.py b/api/tasks/document_indexing_update_task.py index 62bce24de46..15f0e0162b6 100644 --- a/api/tasks/document_indexing_update_task.py +++ b/api/tasks/document_indexing_update_task.py @@ -28,7 +28,9 @@ def document_indexing_update_task(dataset_id: str, document_id: str): start_at = time.perf_counter() with session_factory.create_session() as session, session.begin(): - document = session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() + document = session.scalar( + select(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).limit(1) + ) if not document: logger.info(click.style(f"Document not found: {document_id}", fg="red")) @@ -37,7 +39,7 @@ def document_indexing_update_task(dataset_id: str, document_id: str): document.indexing_status = IndexingStatus.PARSING document.processing_started_at = naive_utc_now() - dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() + dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1)) if not dataset: return diff --git a/api/tasks/duplicate_document_indexing_task.py b/api/tasks/duplicate_document_indexing_task.py index 13c651753ff..6bc58bdf9cd 100644 --- a/api/tasks/duplicate_document_indexing_task.py +++ b/api/tasks/duplicate_document_indexing_task.py @@ -82,7 +82,7 @@ def _duplicate_document_indexing_task(dataset_id: str, document_ids: Sequence[st with session_factory.create_session() as session: try: - dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() + dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1)) if dataset is None: logger.info(click.style(f"Dataset not found: {dataset_id}", fg="red")) return diff --git a/api/tasks/enable_segment_to_index_task.py b/api/tasks/enable_segment_to_index_task.py index 5ad17d75d4e..8334ca25881 100644 --- a/api/tasks/enable_segment_to_index_task.py +++ b/api/tasks/enable_segment_to_index_task.py @@ -3,6 +3,7 @@ import time import click from celery import shared_task +from sqlalchemy import select from core.db.session_factory import session_factory from core.rag.index_processor.constant.doc_type import DocType @@ -29,7 +30,7 @@ def enable_segment_to_index_task(segment_id: str): start_at = time.perf_counter() with session_factory.create_session() as session: - segment = session.query(DocumentSegment).where(DocumentSegment.id == segment_id).first() + segment = session.scalar(select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1)) if not segment: logger.info(click.style(f"Segment not found: {segment_id}", fg="red")) return diff --git a/api/tasks/enable_segments_to_index_task.py b/api/tasks/enable_segments_to_index_task.py index d90eb4c39fb..603abf62fe3 100644 --- a/api/tasks/enable_segments_to_index_task.py +++ b/api/tasks/enable_segments_to_index_task.py @@ -3,7 +3,7 @@ import time import click from celery import shared_task -from sqlalchemy import select +from sqlalchemy import select, update from core.db.session_factory import session_factory from core.rag.index_processor.constant.doc_type import DocType @@ -30,12 +30,12 @@ def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_i """ start_at = time.perf_counter() with session_factory.create_session() as session: - dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() + dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1)) if not dataset: logger.info(click.style(f"Dataset {dataset_id} not found, pass.", fg="cyan")) return - dataset_document = session.query(DatasetDocument).where(DatasetDocument.id == document_id).first() + dataset_document = session.scalar(select(DatasetDocument).where(DatasetDocument.id == document_id).limit(1)) if not dataset_document: logger.info(click.style(f"Document {document_id} not found, pass.", fg="cyan")) @@ -123,17 +123,14 @@ def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_i except Exception as e: logger.exception("enable segments to index failed") # update segment error msg - session.query(DocumentSegment).where( - DocumentSegment.id.in_(segment_ids), - DocumentSegment.dataset_id == dataset_id, - DocumentSegment.document_id == document_id, - ).update( - { - "error": str(e), - "status": "error", - "disabled_at": naive_utc_now(), - "enabled": False, - } + session.execute( + update(DocumentSegment) + .where( + DocumentSegment.id.in_(segment_ids), + DocumentSegment.dataset_id == dataset_id, + DocumentSegment.document_id == document_id, + ) + .values(error=str(e), status="error", disabled_at=naive_utc_now(), enabled=False) ) session.commit() finally: diff --git a/api/tasks/generate_summary_index_task.py b/api/tasks/generate_summary_index_task.py index e3d82d28516..9eda5716b8c 100644 --- a/api/tasks/generate_summary_index_task.py +++ b/api/tasks/generate_summary_index_task.py @@ -5,6 +5,7 @@ import time import click from celery import shared_task +from sqlalchemy import select, update from core.db.session_factory import session_factory from core.rag.index_processor.constant.index_type import IndexTechniqueType @@ -39,12 +40,12 @@ def generate_summary_index_task(dataset_id: str, document_id: str, segment_ids: try: with session_factory.create_session() as session: - dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() + dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1)) if not dataset: logger.error(click.style(f"Dataset not found: {dataset_id}", fg="red")) return - document = session.query(DatasetDocument).where(DatasetDocument.id == document_id).first() + document = session.scalar(select(DatasetDocument).where(DatasetDocument.id == document_id).limit(1)) if not document: logger.error(click.style(f"Document not found: {document_id}", fg="red")) return @@ -108,13 +109,12 @@ def generate_summary_index_task(dataset_id: str, document_id: str, segment_ids: if segment_ids: error_message = f"Summary generation failed: {str(e)}" with session_factory.create_session() as session: - session.query(DocumentSegment).filter( - DocumentSegment.id.in_(segment_ids), - DocumentSegment.dataset_id == dataset_id, - ).update( - { - DocumentSegment.error: error_message, - }, - synchronize_session=False, + session.execute( + update(DocumentSegment) + .where( + DocumentSegment.id.in_(segment_ids), + DocumentSegment.dataset_id == dataset_id, + ) + .values(error=error_message) ) session.commit() diff --git a/api/tasks/human_input_timeout_tasks.py b/api/tasks/human_input_timeout_tasks.py index ca73b4d3745..fd743205a1b 100644 --- a/api/tasks/human_input_timeout_tasks.py +++ b/api/tasks/human_input_timeout_tasks.py @@ -2,8 +2,6 @@ import logging from datetime import timedelta from celery import shared_task -from graphon.enums import WorkflowExecutionStatus -from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from sqlalchemy import or_, select from sqlalchemy.orm import sessionmaker @@ -11,6 +9,8 @@ from configs import dify_config from core.repositories.human_input_repository import HumanInputFormSubmissionRepository from extensions.ext_database import db from extensions.ext_storage import storage +from graphon.enums import WorkflowExecutionStatus +from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from libs.datetime_utils import ensure_naive_utc, naive_utc_now from models.human_input import HumanInputForm from models.workflow import WorkflowPause, WorkflowRun diff --git a/api/tasks/mail_human_input_delivery_task.py b/api/tasks/mail_human_input_delivery_task.py index a316eec7b95..2a60be7762a 100644 --- a/api/tasks/mail_human_input_delivery_task.py +++ b/api/tasks/mail_human_input_delivery_task.py @@ -6,15 +6,15 @@ from typing import Any import click from celery import shared_task -from graphon.runtime import GraphRuntimeState, VariablePool from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker from configs import dify_config from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext -from core.workflow.human_input_compat import EmailDeliveryConfig, EmailDeliveryMethod +from core.workflow.human_input_adapter import EmailDeliveryConfig, EmailDeliveryMethod from extensions.ext_database import db from extensions.ext_mail import mail +from graphon.runtime import GraphRuntimeState, VariablePool from models.human_input import ( DeliveryMethodType, HumanInputDelivery, diff --git a/api/tasks/mail_workflow_comment_task.py b/api/tasks/mail_workflow_comment_task.py new file mode 100644 index 00000000000..36d51f05142 --- /dev/null +++ b/api/tasks/mail_workflow_comment_task.py @@ -0,0 +1,65 @@ +import logging +import time + +import click +from celery import shared_task + +from extensions.ext_mail import mail +from libs.email_i18n import EmailType, get_email_i18n_service + +logger = logging.getLogger(__name__) + + +@shared_task(queue="mail") +def send_workflow_comment_mention_email_task( + language: str, + to: str, + mentioned_name: str, + commenter_name: str, + app_name: str, + comment_content: str, + app_url: str, +): + """ + Send workflow comment mention email with internationalization support. + + Args: + language: Language code for email localization + to: Recipient email address + mentioned_name: Name of the mentioned user + commenter_name: Name of the comment author + app_name: Name of the app where the comment was made + comment_content: Comment content excerpt + app_url: Link to the app workflow page + """ + if not mail.is_inited(): + return + + logger.info(click.style(f"Start workflow comment mention mail to {to}", fg="green")) + start_at = time.perf_counter() + + try: + email_service = get_email_i18n_service() + email_service.send_email( + email_type=EmailType.WORKFLOW_COMMENT_MENTION, + language_code=language, + to=to, + template_context={ + "to": to, + "mentioned_name": mentioned_name, + "commenter_name": commenter_name, + "app_name": app_name, + "comment_content": comment_content, + "app_url": app_url, + }, + ) + + end_at = time.perf_counter() + logger.info( + click.style( + f"Send workflow comment mention mail to {to} succeeded: latency: {end_at - start_at}", + fg="green", + ) + ) + except Exception: + logger.exception("workflow comment mention email to %s failed", to) diff --git a/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py b/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py index 3c5e152520a..d8fa73b42df 100644 --- a/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py +++ b/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py @@ -10,6 +10,7 @@ from typing import Any import click from celery import shared_task # type: ignore from flask import current_app, g +from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker from configs import dify_config @@ -118,20 +119,20 @@ def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any], with Session(db.engine, expire_on_commit=False) as session: # Load required entities - account = session.query(Account).where(Account.id == user_id).first() + account = session.scalar(select(Account).where(Account.id == user_id).limit(1)) if not account: raise ValueError(f"Account {user_id} not found") - tenant = session.query(Tenant).where(Tenant.id == tenant_id).first() + tenant = session.scalar(select(Tenant).where(Tenant.id == tenant_id).limit(1)) if not tenant: raise ValueError(f"Tenant {tenant_id} not found") account.current_tenant = tenant - pipeline = session.query(Pipeline).where(Pipeline.id == pipeline_id).first() + pipeline = session.scalar(select(Pipeline).where(Pipeline.id == pipeline_id).limit(1)) if not pipeline: raise ValueError(f"Pipeline {pipeline_id} not found") - workflow = session.query(Workflow).where(Workflow.id == pipeline.workflow_id).first() + workflow = session.scalar(select(Workflow).where(Workflow.id == pipeline.workflow_id).limit(1)) if not workflow: raise ValueError(f"Workflow {pipeline.workflow_id} not found") diff --git a/api/tasks/rag_pipeline/rag_pipeline_run_task.py b/api/tasks/rag_pipeline/rag_pipeline_run_task.py index 52f66dddb8f..8e1e096ed03 100644 --- a/api/tasks/rag_pipeline/rag_pipeline_run_task.py +++ b/api/tasks/rag_pipeline/rag_pipeline_run_task.py @@ -11,7 +11,8 @@ from typing import Any import click from celery import group, shared_task from flask import current_app, g -from sqlalchemy.orm import Session, sessionmaker +from sqlalchemy import select +from sqlalchemy.orm import sessionmaker from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity @@ -130,22 +131,22 @@ def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any], workflow_thread_pool_id = rag_pipeline_invoke_entity_model.workflow_thread_pool_id application_generate_entity = rag_pipeline_invoke_entity_model.application_generate_entity - with Session(db.engine) as session: + with sessionmaker(db.engine, expire_on_commit=False).begin() as session: # Load required entities - account = session.query(Account).where(Account.id == user_id).first() + account = session.scalar(select(Account).where(Account.id == user_id).limit(1)) if not account: raise ValueError(f"Account {user_id} not found") - tenant = session.query(Tenant).where(Tenant.id == tenant_id).first() + tenant = session.scalar(select(Tenant).where(Tenant.id == tenant_id).limit(1)) if not tenant: raise ValueError(f"Tenant {tenant_id} not found") account.current_tenant = tenant - pipeline = session.query(Pipeline).where(Pipeline.id == pipeline_id).first() + pipeline = session.scalar(select(Pipeline).where(Pipeline.id == pipeline_id).limit(1)) if not pipeline: raise ValueError(f"Pipeline {pipeline_id} not found") - workflow = session.query(Workflow).where(Workflow.id == pipeline.workflow_id).first() + workflow = session.scalar(select(Workflow).where(Workflow.id == pipeline.workflow_id).limit(1)) if not workflow: raise ValueError(f"Workflow {pipeline.workflow_id} not found") diff --git a/api/tasks/recover_document_indexing_task.py b/api/tasks/recover_document_indexing_task.py index af72023da1e..73b121961c5 100644 --- a/api/tasks/recover_document_indexing_task.py +++ b/api/tasks/recover_document_indexing_task.py @@ -3,6 +3,7 @@ import time import click from celery import shared_task +from sqlalchemy import select from core.db.session_factory import session_factory from core.indexing_runner import DocumentIsPausedError, IndexingRunner @@ -24,7 +25,9 @@ def recover_document_indexing_task(dataset_id: str, document_id: str): start_at = time.perf_counter() with session_factory.create_session() as session: - document = session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() + document = session.scalar( + select(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).limit(1) + ) if not document: logger.info(click.style(f"Document not found: {document_id}", fg="red")) diff --git a/api/tasks/regenerate_summary_index_task.py b/api/tasks/regenerate_summary_index_task.py index 6f490ab7eab..e794195c92e 100644 --- a/api/tasks/regenerate_summary_index_task.py +++ b/api/tasks/regenerate_summary_index_task.py @@ -47,7 +47,7 @@ def regenerate_summary_index_task( try: with session_factory.create_session() as session: - dataset = session.query(Dataset).filter_by(id=dataset_id).first() + dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1)) if not dataset: logger.error(click.style(f"Dataset not found: {dataset_id}", fg="red")) return @@ -84,8 +84,8 @@ def regenerate_summary_index_task( # For embedding_model change: directly query all segments with existing summaries # Don't require document indexing_status == "completed" # Include summaries with status "completed" or "error" (if they have content) - segments_with_summaries = ( - session.query(DocumentSegment, DocumentSegmentSummary) + segments_with_summaries = session.execute( + select(DocumentSegment, DocumentSegmentSummary) .join( DocumentSegmentSummary, DocumentSegment.id == DocumentSegmentSummary.chunk_id, @@ -110,8 +110,7 @@ def regenerate_summary_index_task( DatasetDocument.doc_form != IndexStructureType.QA_INDEX, # Skip qa_model documents ) .order_by(DocumentSegment.document_id.asc(), DocumentSegment.position.asc()) - .all() - ) + ).all() if not segments_with_summaries: logger.info( @@ -215,8 +214,8 @@ def regenerate_summary_index_task( try: # Get all segments with existing summaries - segments = ( - session.query(DocumentSegment) + segments = session.scalars( + select(DocumentSegment) .join( DocumentSegmentSummary, DocumentSegment.id == DocumentSegmentSummary.chunk_id, @@ -229,8 +228,7 @@ def regenerate_summary_index_task( DocumentSegmentSummary.dataset_id == dataset_id, ) .order_by(DocumentSegment.position.asc()) - .all() - ) + ).all() if not segments: continue @@ -245,13 +243,13 @@ def regenerate_summary_index_task( summary_record = None try: # Get existing summary record - summary_record = ( - session.query(DocumentSegmentSummary) - .filter_by( - chunk_id=segment.id, - dataset_id=dataset_id, + summary_record = session.scalar( + select(DocumentSegmentSummary) + .where( + DocumentSegmentSummary.chunk_id == segment.id, + DocumentSegmentSummary.dataset_id == dataset_id, ) - .first() + .limit(1) ) if not summary_record: diff --git a/api/tasks/remove_app_and_related_data_task.py b/api/tasks/remove_app_and_related_data_task.py index b1840662ff4..5f1f0952af5 100644 --- a/api/tasks/remove_app_and_related_data_task.py +++ b/api/tasks/remove_app_and_related_data_task.py @@ -6,7 +6,7 @@ from typing import Any, cast import click import sqlalchemy as sa from celery import shared_task -from sqlalchemy import delete +from sqlalchemy import delete, select from sqlalchemy.engine import CursorResult from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.orm import sessionmaker @@ -99,7 +99,11 @@ def remove_app_and_related_data_task(self, tenant_id: str, app_id: str): def _delete_app_model_configs(tenant_id: str, app_id: str): def del_model_config(session, model_config_id: str): - session.query(AppModelConfig).where(AppModelConfig.id == model_config_id).delete(synchronize_session=False) + session.execute( + delete(AppModelConfig) + .where(AppModelConfig.id == model_config_id) + .execution_options(synchronize_session=False) + ) _delete_records( """select id from app_model_configs where app_id=:app_id limit 1000""", @@ -111,7 +115,7 @@ def _delete_app_model_configs(tenant_id: str, app_id: str): def _delete_app_site(tenant_id: str, app_id: str): def del_site(session, site_id: str): - session.query(Site).where(Site.id == site_id).delete(synchronize_session=False) + session.execute(delete(Site).where(Site.id == site_id).execution_options(synchronize_session=False)) _delete_records( """select id from sites where app_id=:app_id limit 1000""", @@ -123,7 +127,9 @@ def _delete_app_site(tenant_id: str, app_id: str): def _delete_app_mcp_servers(tenant_id: str, app_id: str): def del_mcp_server(session, mcp_server_id: str): - session.query(AppMCPServer).where(AppMCPServer.id == mcp_server_id).delete(synchronize_session=False) + session.execute( + delete(AppMCPServer).where(AppMCPServer.id == mcp_server_id).execution_options(synchronize_session=False) + ) _delete_records( """select id from app_mcp_servers where app_id=:app_id limit 1000""", @@ -136,12 +142,14 @@ def _delete_app_mcp_servers(tenant_id: str, app_id: str): def _delete_app_api_tokens(tenant_id: str, app_id: str): def del_api_token(session, api_token_id: str): # Fetch token details for cache invalidation - token_obj = session.query(ApiToken).where(ApiToken.id == api_token_id).first() + token_obj = session.scalar(select(ApiToken).where(ApiToken.id == api_token_id).limit(1)) if token_obj: # Invalidate cache before deletion ApiTokenCache.delete(token_obj.token, token_obj.type) - session.query(ApiToken).where(ApiToken.id == api_token_id).delete(synchronize_session=False) + session.execute( + delete(ApiToken).where(ApiToken.id == api_token_id).execution_options(synchronize_session=False) + ) _delete_records( """select id from api_tokens where app_id=:app_id limit 1000""", @@ -153,7 +161,9 @@ def _delete_app_api_tokens(tenant_id: str, app_id: str): def _delete_installed_apps(tenant_id: str, app_id: str): def del_installed_app(session, installed_app_id: str): - session.query(InstalledApp).where(InstalledApp.id == installed_app_id).delete(synchronize_session=False) + session.execute( + delete(InstalledApp).where(InstalledApp.id == installed_app_id).execution_options(synchronize_session=False) + ) _delete_records( """select id from installed_apps where tenant_id=:tenant_id and app_id=:app_id limit 1000""", @@ -165,7 +175,11 @@ def _delete_installed_apps(tenant_id: str, app_id: str): def _delete_recommended_apps(tenant_id: str, app_id: str): def del_recommended_app(session, recommended_app_id: str): - session.query(RecommendedApp).where(RecommendedApp.id == recommended_app_id).delete(synchronize_session=False) + session.execute( + delete(RecommendedApp) + .where(RecommendedApp.id == recommended_app_id) + .execution_options(synchronize_session=False) + ) _delete_records( """select id from recommended_apps where app_id=:app_id limit 1000""", @@ -177,8 +191,10 @@ def _delete_recommended_apps(tenant_id: str, app_id: str): def _delete_app_annotation_data(tenant_id: str, app_id: str): def del_annotation_hit_history(session, annotation_hit_history_id: str): - session.query(AppAnnotationHitHistory).where(AppAnnotationHitHistory.id == annotation_hit_history_id).delete( - synchronize_session=False + session.execute( + delete(AppAnnotationHitHistory) + .where(AppAnnotationHitHistory.id == annotation_hit_history_id) + .execution_options(synchronize_session=False) ) _delete_records( @@ -189,8 +205,10 @@ def _delete_app_annotation_data(tenant_id: str, app_id: str): ) def del_annotation_setting(session, annotation_setting_id: str): - session.query(AppAnnotationSetting).where(AppAnnotationSetting.id == annotation_setting_id).delete( - synchronize_session=False + session.execute( + delete(AppAnnotationSetting) + .where(AppAnnotationSetting.id == annotation_setting_id) + .execution_options(synchronize_session=False) ) _delete_records( @@ -203,7 +221,11 @@ def _delete_app_annotation_data(tenant_id: str, app_id: str): def _delete_app_dataset_joins(tenant_id: str, app_id: str): def del_dataset_join(session, dataset_join_id: str): - session.query(AppDatasetJoin).where(AppDatasetJoin.id == dataset_join_id).delete(synchronize_session=False) + session.execute( + delete(AppDatasetJoin) + .where(AppDatasetJoin.id == dataset_join_id) + .execution_options(synchronize_session=False) + ) _delete_records( """select id from app_dataset_joins where app_id=:app_id limit 1000""", @@ -215,7 +237,7 @@ def _delete_app_dataset_joins(tenant_id: str, app_id: str): def _delete_app_workflows(tenant_id: str, app_id: str): def del_workflow(session, workflow_id: str): - session.query(Workflow).where(Workflow.id == workflow_id).delete(synchronize_session=False) + session.execute(delete(Workflow).where(Workflow.id == workflow_id).execution_options(synchronize_session=False)) _delete_records( """select id from workflows where tenant_id=:tenant_id and app_id=:app_id limit 1000""", @@ -255,7 +277,11 @@ def _delete_app_workflow_node_executions(tenant_id: str, app_id: str): def _delete_app_workflow_app_logs(tenant_id: str, app_id: str): def del_workflow_app_log(session, workflow_app_log_id: str): - session.query(WorkflowAppLog).where(WorkflowAppLog.id == workflow_app_log_id).delete(synchronize_session=False) + session.execute( + delete(WorkflowAppLog) + .where(WorkflowAppLog.id == workflow_app_log_id) + .execution_options(synchronize_session=False) + ) _delete_records( """select id from workflow_app_logs where tenant_id=:tenant_id and app_id=:app_id limit 1000""", @@ -267,8 +293,10 @@ def _delete_app_workflow_app_logs(tenant_id: str, app_id: str): def _delete_app_workflow_archive_logs(tenant_id: str, app_id: str): def del_workflow_archive_log(session, workflow_archive_log_id: str): - session.query(WorkflowArchiveLog).where(WorkflowArchiveLog.id == workflow_archive_log_id).delete( - synchronize_session=False + session.execute( + delete(WorkflowArchiveLog) + .where(WorkflowArchiveLog.id == workflow_archive_log_id) + .execution_options(synchronize_session=False) ) _delete_records( @@ -306,10 +334,14 @@ def _delete_archived_workflow_run_files(tenant_id: str, app_id: str): def _delete_app_conversations(tenant_id: str, app_id: str): def del_conversation(session, conversation_id: str): - session.query(PinnedConversation).where(PinnedConversation.conversation_id == conversation_id).delete( - synchronize_session=False + session.execute( + delete(PinnedConversation) + .where(PinnedConversation.conversation_id == conversation_id) + .execution_options(synchronize_session=False) + ) + session.execute( + delete(Conversation).where(Conversation.id == conversation_id).execution_options(synchronize_session=False) ) - session.query(Conversation).where(Conversation.id == conversation_id).delete(synchronize_session=False) _delete_records( """select id from conversations where app_id=:app_id limit 1000""", @@ -329,17 +361,35 @@ def _delete_conversation_variables(*, app_id: str): def _delete_app_messages(tenant_id: str, app_id: str): def del_message(session, message_id: str): - session.query(MessageFeedback).where(MessageFeedback.message_id == message_id).delete(synchronize_session=False) - session.query(MessageAnnotation).where(MessageAnnotation.message_id == message_id).delete( - synchronize_session=False + session.execute( + delete(MessageFeedback) + .where(MessageFeedback.message_id == message_id) + .execution_options(synchronize_session=False) ) - session.query(MessageChain).where(MessageChain.message_id == message_id).delete(synchronize_session=False) - session.query(MessageAgentThought).where(MessageAgentThought.message_id == message_id).delete( - synchronize_session=False + session.execute( + delete(MessageAnnotation) + .where(MessageAnnotation.message_id == message_id) + .execution_options(synchronize_session=False) ) - session.query(MessageFile).where(MessageFile.message_id == message_id).delete(synchronize_session=False) - session.query(SavedMessage).where(SavedMessage.message_id == message_id).delete(synchronize_session=False) - session.query(Message).where(Message.id == message_id).delete() + session.execute( + delete(MessageChain) + .where(MessageChain.message_id == message_id) + .execution_options(synchronize_session=False) + ) + session.execute( + delete(MessageAgentThought) + .where(MessageAgentThought.message_id == message_id) + .execution_options(synchronize_session=False) + ) + session.execute( + delete(MessageFile).where(MessageFile.message_id == message_id).execution_options(synchronize_session=False) + ) + session.execute( + delete(SavedMessage) + .where(SavedMessage.message_id == message_id) + .execution_options(synchronize_session=False) + ) + session.execute(delete(Message).where(Message.id == message_id).execution_options(synchronize_session=False)) _delete_records( """select id from messages where app_id=:app_id limit 1000""", @@ -351,8 +401,10 @@ def _delete_app_messages(tenant_id: str, app_id: str): def _delete_workflow_tool_providers(tenant_id: str, app_id: str): def del_tool_provider(session, tool_provider_id: str): - session.query(WorkflowToolProvider).where(WorkflowToolProvider.id == tool_provider_id).delete( - synchronize_session=False + session.execute( + delete(WorkflowToolProvider) + .where(WorkflowToolProvider.id == tool_provider_id) + .execution_options(synchronize_session=False) ) _delete_records( @@ -365,7 +417,9 @@ def _delete_workflow_tool_providers(tenant_id: str, app_id: str): def _delete_app_tag_bindings(tenant_id: str, app_id: str): def del_tag_binding(session, tag_binding_id: str): - session.query(TagBinding).where(TagBinding.id == tag_binding_id).delete(synchronize_session=False) + session.execute( + delete(TagBinding).where(TagBinding.id == tag_binding_id).execution_options(synchronize_session=False) + ) _delete_records( """select id from tag_bindings where tenant_id=:tenant_id and target_id=:app_id limit 1000""", @@ -377,7 +431,7 @@ def _delete_app_tag_bindings(tenant_id: str, app_id: str): def _delete_end_users(tenant_id: str, app_id: str): def del_end_user(session, end_user_id: str): - session.query(EndUser).where(EndUser.id == end_user_id).delete(synchronize_session=False) + session.execute(delete(EndUser).where(EndUser.id == end_user_id).execution_options(synchronize_session=False)) _delete_records( """select id from end_users where tenant_id=:tenant_id and app_id=:app_id limit 1000""", @@ -389,7 +443,11 @@ def _delete_end_users(tenant_id: str, app_id: str): def _delete_trace_app_configs(tenant_id: str, app_id: str): def del_trace_app_config(session, trace_app_config_id: str): - session.query(TraceAppConfig).where(TraceAppConfig.id == trace_app_config_id).delete(synchronize_session=False) + session.execute( + delete(TraceAppConfig) + .where(TraceAppConfig.id == trace_app_config_id) + .execution_options(synchronize_session=False) + ) _delete_records( """select id from trace_app_config where app_id=:app_id limit 1000""", @@ -545,7 +603,9 @@ def _delete_draft_variable_offload_data(session, file_ids: list[str]) -> int: def _delete_app_triggers(tenant_id: str, app_id: str): def del_app_trigger(session, trigger_id: str): - session.query(AppTrigger).where(AppTrigger.id == trigger_id).delete(synchronize_session=False) + session.execute( + delete(AppTrigger).where(AppTrigger.id == trigger_id).execution_options(synchronize_session=False) + ) _delete_records( """select id from app_triggers where tenant_id=:tenant_id and app_id=:app_id limit 1000""", @@ -557,8 +617,10 @@ def _delete_app_triggers(tenant_id: str, app_id: str): def _delete_workflow_plugin_triggers(tenant_id: str, app_id: str): def del_plugin_trigger(session, trigger_id: str): - session.query(WorkflowPluginTrigger).where(WorkflowPluginTrigger.id == trigger_id).delete( - synchronize_session=False + session.execute( + delete(WorkflowPluginTrigger) + .where(WorkflowPluginTrigger.id == trigger_id) + .execution_options(synchronize_session=False) ) _delete_records( @@ -571,8 +633,10 @@ def _delete_workflow_plugin_triggers(tenant_id: str, app_id: str): def _delete_workflow_webhook_triggers(tenant_id: str, app_id: str): def del_webhook_trigger(session, trigger_id: str): - session.query(WorkflowWebhookTrigger).where(WorkflowWebhookTrigger.id == trigger_id).delete( - synchronize_session=False + session.execute( + delete(WorkflowWebhookTrigger) + .where(WorkflowWebhookTrigger.id == trigger_id) + .execution_options(synchronize_session=False) ) _delete_records( @@ -585,7 +649,11 @@ def _delete_workflow_webhook_triggers(tenant_id: str, app_id: str): def _delete_workflow_schedule_plans(tenant_id: str, app_id: str): def del_schedule_plan(session, plan_id: str): - session.query(WorkflowSchedulePlan).where(WorkflowSchedulePlan.id == plan_id).delete(synchronize_session=False) + session.execute( + delete(WorkflowSchedulePlan) + .where(WorkflowSchedulePlan.id == plan_id) + .execution_options(synchronize_session=False) + ) _delete_records( """select id from workflow_schedule_plans where tenant_id=:tenant_id and app_id=:app_id limit 1000""", @@ -597,7 +665,11 @@ def _delete_workflow_schedule_plans(tenant_id: str, app_id: str): def _delete_workflow_trigger_logs(tenant_id: str, app_id: str): def del_trigger_log(session, log_id: str): - session.query(WorkflowTriggerLog).where(WorkflowTriggerLog.id == log_id).delete(synchronize_session=False) + session.execute( + delete(WorkflowTriggerLog) + .where(WorkflowTriggerLog.id == log_id) + .execution_options(synchronize_session=False) + ) _delete_records( """select id from workflow_trigger_logs where tenant_id=:tenant_id and app_id=:app_id limit 1000""", @@ -607,7 +679,7 @@ def _delete_workflow_trigger_logs(tenant_id: str, app_id: str): ) -def _delete_records(query_sql: str, params: dict, delete_func: Callable, name: str) -> None: +def _delete_records(query_sql: str, params: dict[str, Any], delete_func: Callable, name: str) -> None: while True: with session_factory.create_session() as session: rs = session.execute(sa.text(query_sql), params) diff --git a/api/tasks/remove_document_from_index_task.py b/api/tasks/remove_document_from_index_task.py index 55259ab527e..74e8a012cf2 100644 --- a/api/tasks/remove_document_from_index_task.py +++ b/api/tasks/remove_document_from_index_task.py @@ -3,7 +3,7 @@ import time import click from celery import shared_task -from sqlalchemy import select +from sqlalchemy import select, update from core.db.session_factory import session_factory from core.rag.index_processor.index_processor_factory import IndexProcessorFactory @@ -26,7 +26,7 @@ def remove_document_from_index_task(document_id: str): start_at = time.perf_counter() with session_factory.create_session() as session: - document = session.query(Document).where(Document.id == document_id).first() + document = session.scalar(select(Document).where(Document.id == document_id).limit(1)) if not document: logger.info(click.style(f"Document not found: {document_id}", fg="red")) return @@ -68,13 +68,15 @@ def remove_document_from_index_task(document_id: str): except Exception: logger.exception("clean dataset %s from index failed", dataset.id) # update segment to disable - session.query(DocumentSegment).where(DocumentSegment.document_id == document.id).update( - { - DocumentSegment.enabled: False, - DocumentSegment.disabled_at: naive_utc_now(), - DocumentSegment.disabled_by: document.disabled_by, - DocumentSegment.updated_at: naive_utc_now(), - } + session.execute( + update(DocumentSegment) + .where(DocumentSegment.document_id == document.id) + .values( + enabled=False, + disabled_at=naive_utc_now(), + disabled_by=document.disabled_by, + updated_at=naive_utc_now(), + ) ) session.commit() diff --git a/api/tasks/retry_document_indexing_task.py b/api/tasks/retry_document_indexing_task.py index 4fcb0cf8049..7cc28d52264 100644 --- a/api/tasks/retry_document_indexing_task.py +++ b/api/tasks/retry_document_indexing_task.py @@ -32,15 +32,15 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str], user_ start_at = time.perf_counter() with session_factory.create_session() as session: try: - dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() + dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1)) if not dataset: logger.info(click.style(f"Dataset not found: {dataset_id}", fg="red")) return - user = session.query(Account).where(Account.id == user_id).first() + user = session.scalar(select(Account).where(Account.id == user_id).limit(1)) if not user: logger.info(click.style(f"User not found: {user_id}", fg="red")) return - tenant = session.query(Tenant).where(Tenant.id == dataset.tenant_id).first() + tenant = session.scalar(select(Tenant).where(Tenant.id == dataset.tenant_id).limit(1)) if not tenant: raise ValueError("Tenant not found") user.current_tenant = tenant @@ -58,10 +58,8 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str], user_ "your subscription." ) except Exception as e: - document = ( - session.query(Document) - .where(Document.id == document_id, Document.dataset_id == dataset_id) - .first() + document = session.scalar( + select(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).limit(1) ) if document: document.indexing_status = IndexingStatus.ERROR @@ -73,8 +71,8 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str], user_ return logger.info(click.style(f"Start retry document: {document_id}", fg="green")) - document = ( - session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() + document = session.scalar( + select(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).limit(1) ) if not document: logger.info(click.style(f"Document not found: {document_id}", fg="yellow")) diff --git a/api/tasks/sync_website_document_indexing_task.py b/api/tasks/sync_website_document_indexing_task.py index aa6bce958ba..ab21f63f7ea 100644 --- a/api/tasks/sync_website_document_indexing_task.py +++ b/api/tasks/sync_website_document_indexing_task.py @@ -29,7 +29,7 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str): start_at = time.perf_counter() with session_factory.create_session() as session: - dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() + dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1)) if dataset is None: raise ValueError("Dataset not found") @@ -45,8 +45,8 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str): "your subscription." ) except Exception as e: - document = ( - session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() + document = session.scalar( + select(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).limit(1) ) if document: document.indexing_status = IndexingStatus.ERROR @@ -58,7 +58,9 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str): return logger.info(click.style(f"Start sync website document: {document_id}", fg="green")) - document = session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() + document = session.scalar( + select(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).limit(1) + ) if not document: logger.info(click.style(f"Document not found: {document_id}", fg="yellow")) return diff --git a/api/tasks/trigger_processing_tasks.py b/api/tasks/trigger_processing_tasks.py index 56626e372ea..25ea53dfac5 100644 --- a/api/tasks/trigger_processing_tasks.py +++ b/api/tasks/trigger_processing_tasks.py @@ -12,7 +12,6 @@ from datetime import UTC, datetime from typing import Any from celery import shared_task -from graphon.enums import WorkflowExecutionStatus from sqlalchemy import func, select from sqlalchemy.orm import Session @@ -29,6 +28,7 @@ from core.trigger.provider import PluginTriggerProviderController from core.trigger.trigger_manager import TriggerManager from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData from enums.quota_type import QuotaType, unlimited +from graphon.enums import WorkflowExecutionStatus from models.enums import ( AppTriggerType, CreatorUserRole, diff --git a/api/tasks/trigger_subscription_refresh_tasks.py b/api/tasks/trigger_subscription_refresh_tasks.py index 7698a1a6b8b..1daf8f302cf 100644 --- a/api/tasks/trigger_subscription_refresh_tasks.py +++ b/api/tasks/trigger_subscription_refresh_tasks.py @@ -4,6 +4,7 @@ from collections.abc import Mapping from typing import Any from celery import shared_task +from sqlalchemy import select from sqlalchemy.orm import Session from configs import dify_config @@ -22,7 +23,11 @@ def _now_ts() -> int: def _load_subscription(session: Session, tenant_id: str, subscription_id: str) -> TriggerSubscription | None: - return session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first() + return session.scalar( + select(TriggerSubscription) + .where(TriggerSubscription.tenant_id == tenant_id, TriggerSubscription.id == subscription_id) + .limit(1) + ) def _refresh_oauth_if_expired(tenant_id: str, subscription: TriggerSubscription, now: int) -> None: diff --git a/api/tasks/workflow_execution_tasks.py b/api/tasks/workflow_execution_tasks.py index 0c7f74c180a..5ca04fd7c2e 100644 --- a/api/tasks/workflow_execution_tasks.py +++ b/api/tasks/workflow_execution_tasks.py @@ -7,13 +7,14 @@ improving performance by offloading storage operations to background workers. import json import logging +from typing import Any from celery import shared_task -from graphon.entities import WorkflowExecution -from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from sqlalchemy import select from core.db.session_factory import session_factory +from graphon.entities import WorkflowExecution +from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from models import CreatorUserRole, WorkflowRun from models.enums import WorkflowRunTriggeredFrom @@ -23,7 +24,7 @@ logger = logging.getLogger(__name__) @shared_task(queue="workflow_storage", bind=True, max_retries=3, default_retry_delay=60) def save_workflow_execution_task( self, - execution_data: dict, + execution_data: dict[str, Any], tenant_id: str, app_id: str, triggered_from: str, diff --git a/api/tasks/workflow_node_execution_tasks.py b/api/tasks/workflow_node_execution_tasks.py index f25ebe3bae4..0d5475a56d2 100644 --- a/api/tasks/workflow_node_execution_tasks.py +++ b/api/tasks/workflow_node_execution_tasks.py @@ -7,15 +7,16 @@ improving performance by offloading storage operations to background workers. import json import logging +from typing import Any from celery import shared_task +from sqlalchemy import select + +from core.db.session_factory import session_factory from graphon.entities.workflow_node_execution import ( WorkflowNodeExecution, ) from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter -from sqlalchemy import select - -from core.db.session_factory import session_factory from models import CreatorUserRole, WorkflowNodeExecutionModel from models.workflow import WorkflowNodeExecutionTriggeredFrom @@ -25,7 +26,7 @@ logger = logging.getLogger(__name__) @shared_task(queue="workflow_storage", bind=True, max_retries=3, default_retry_delay=60) def save_workflow_node_execution_task( self, - execution_data: dict, + execution_data: dict[str, Any], tenant_id: str, app_id: str, triggered_from: str, diff --git a/api/templates/without-brand/workflow_comment_mention_template_en-US.html b/api/templates/without-brand/workflow_comment_mention_template_en-US.html new file mode 100644 index 00000000000..1ef8fe4e3f3 --- /dev/null +++ b/api/templates/without-brand/workflow_comment_mention_template_en-US.html @@ -0,0 +1,119 @@ + + + + + + + + +
+
+ Dify Logo +
+

You were mentioned in a workflow comment

+
+

Hi {{ mentioned_name }},

+

{{ commenter_name }} mentioned you in {{ app_name }}.

+
+
+

{{ comment_content }}

+
+

Open {{ application_title }} to reply to the comment.

+
+ + + diff --git a/api/templates/without-brand/workflow_comment_mention_template_zh-CN.html b/api/templates/without-brand/workflow_comment_mention_template_zh-CN.html new file mode 100644 index 00000000000..8b9b2dbe717 --- /dev/null +++ b/api/templates/without-brand/workflow_comment_mention_template_zh-CN.html @@ -0,0 +1,119 @@ + + + + + + + + +
+
+ Dify Logo +
+

你在工作流评论中被提及

+
+

你好,{{ mentioned_name }}:

+

{{ commenter_name }} 在 {{ app_name }} 中提及了你。

+
+
+

{{ comment_content }}

+
+

请在 {{ application_title }} 中查看并回复此评论。

+
+ + + diff --git a/api/templates/workflow_comment_mention_template_en-US.html b/api/templates/workflow_comment_mention_template_en-US.html new file mode 100644 index 00000000000..1ef8fe4e3f3 --- /dev/null +++ b/api/templates/workflow_comment_mention_template_en-US.html @@ -0,0 +1,119 @@ + + + + + + + + +
+
+ Dify Logo +
+

You were mentioned in a workflow comment

+
+

Hi {{ mentioned_name }},

+

{{ commenter_name }} mentioned you in {{ app_name }}.

+
+
+

{{ comment_content }}

+
+

Open {{ application_title }} to reply to the comment.

+
+ + + diff --git a/api/templates/workflow_comment_mention_template_zh-CN.html b/api/templates/workflow_comment_mention_template_zh-CN.html new file mode 100644 index 00000000000..8b9b2dbe717 --- /dev/null +++ b/api/templates/workflow_comment_mention_template_zh-CN.html @@ -0,0 +1,119 @@ + + + + + + + + +
+
+ Dify Logo +
+

你在工作流评论中被提及

+
+

你好,{{ mentioned_name }}:

+

{{ commenter_name }} 在 {{ app_name }} 中提及了你。

+
+
+

{{ comment_content }}

+
+

请在 {{ application_title }} 中查看并回复此评论。

+
+ + + diff --git a/api/tests/__init__.py b/api/tests/__init__.py index e69de29bb2d..ced6188ce8e 100644 --- a/api/tests/__init__.py +++ b/api/tests/__init__.py @@ -0,0 +1 @@ +"""Test suite root package (enables ``import tests.integration_tests...`` with ``pythonpath = .``).""" diff --git a/api/tests/integration_tests/.env.example b/api/tests/integration_tests/.env.example index f84d39aeb5b..c07ab6d6bf6 100644 --- a/api/tests/integration_tests/.env.example +++ b/api/tests/integration_tests/.env.example @@ -33,6 +33,7 @@ REDIS_USERNAME= REDIS_PASSWORD=difyai123456 REDIS_USE_SSL=false REDIS_DB=0 +REDIS_KEY_PREFIX= # PostgreSQL database configuration DB_USERNAME=postgres diff --git a/api/tests/integration_tests/__init__.py b/api/tests/integration_tests/__init__.py index e69de29bb2d..c66cd71b7e1 100644 --- a/api/tests/integration_tests/__init__.py +++ b/api/tests/integration_tests/__init__.py @@ -0,0 +1 @@ +"""Integration tests package.""" diff --git a/api/tests/integration_tests/conftest.py b/api/tests/integration_tests/conftest.py index 44adadeaa5c..09078d196da 100644 --- a/api/tests/integration_tests/conftest.py +++ b/api/tests/integration_tests/conftest.py @@ -8,6 +8,7 @@ from collections.abc import Generator import pytest from flask import Flask from flask.testing import FlaskClient +from sqlalchemy import delete, select from sqlalchemy.orm import Session from app_factory import create_app @@ -47,7 +48,7 @@ os.environ["OPENDAL_FS_ROOT"] = "/tmp/dify-storage" os.environ.setdefault("STORAGE_TYPE", "opendal") os.environ.setdefault("OPENDAL_SCHEME", "fs") -_CACHED_APP = create_app() +_SIO_APP, _CACHED_APP = create_app() @pytest.fixture(scope="session") @@ -83,15 +84,15 @@ def setup_account(request) -> Generator[Account, None, None]: with _CACHED_APP.test_request_context(): with Session(bind=db.engine, expire_on_commit=False) as session: - account = session.query(Account).filter_by(email=email).one() + account = session.scalars(select(Account).filter_by(email=email)).one() yield account with _CACHED_APP.test_request_context(): - db.session.query(DifySetup).delete() - db.session.query(TenantAccountJoin).delete() - db.session.query(Account).delete() - db.session.query(Tenant).delete() + db.session.execute(delete(DifySetup)) + db.session.execute(delete(TenantAccountJoin)) + db.session.execute(delete(Account)) + db.session.execute(delete(Tenant)) db.session.commit() diff --git a/api/tests/integration_tests/controllers/console/app/test_workflow_draft_variable.py b/api/tests/integration_tests/controllers/console/app/test_workflow_draft_variable.py deleted file mode 100644 index 038f37af5fa..00000000000 --- a/api/tests/integration_tests/controllers/console/app/test_workflow_draft_variable.py +++ /dev/null @@ -1,47 +0,0 @@ -import uuid -from unittest import mock - -from controllers.console.app import workflow_draft_variable as draft_variable_api -from controllers.console.app import wraps -from factories.variable_factory import build_segment -from models import App, AppMode -from models.workflow import WorkflowDraftVariable -from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService - - -def _get_mock_srv_class() -> type[WorkflowDraftVariableService]: - return mock.create_autospec(WorkflowDraftVariableService) - - -class TestWorkflowDraftNodeVariableListApi: - def test_get(self, test_client, auth_header, monkeypatch): - srv_class = _get_mock_srv_class() - mock_app_model: App = App() - mock_app_model.id = str(uuid.uuid4()) - test_node_id = "test_node_id" - mock_app_model.mode = AppMode.ADVANCED_CHAT - mock_load_app_model = mock.Mock(return_value=mock_app_model) - - monkeypatch.setattr(draft_variable_api, "WorkflowDraftVariableService", srv_class) - monkeypatch.setattr(wraps, "_load_app_model", mock_load_app_model) - - var1 = WorkflowDraftVariable.new_node_variable( - app_id="test_app_1", - node_id="test_node_1", - name="str_var", - value=build_segment("str_value"), - node_execution_id=str(uuid.uuid4()), - ) - srv_instance = mock.create_autospec(WorkflowDraftVariableService, instance=True) - srv_class.return_value = srv_instance - srv_instance.list_node_variables.return_value = WorkflowDraftVariableList(variables=[var1]) - - response = test_client.get( - f"/console/api/apps/{mock_app_model.id}/workflows/draft/nodes/{test_node_id}/variables", - headers=auth_header, - ) - assert response.status_code == 200 - response_dict = response.json - assert isinstance(response_dict, dict) - assert "items" in response_dict - assert len(response_dict["items"]) == 1 diff --git a/api/tests/integration_tests/controllers/console/workspace/test_trigger_provider_permissions.py b/api/tests/integration_tests/controllers/console/workspace/test_trigger_provider_permissions.py deleted file mode 100644 index e55c12e6780..00000000000 --- a/api/tests/integration_tests/controllers/console/workspace/test_trigger_provider_permissions.py +++ /dev/null @@ -1,244 +0,0 @@ -"""Integration tests for Trigger Provider subscription permission verification.""" - -import uuid -from unittest import mock - -import pytest -from flask.testing import FlaskClient - -from controllers.console.workspace import trigger_providers as trigger_providers_api -from libs.datetime_utils import naive_utc_now -from models import Tenant -from models.account import Account, TenantAccountJoin, TenantAccountRole - - -class TestTriggerProviderSubscriptionPermissions: - """Test permission verification for Trigger Provider subscription endpoints.""" - - @pytest.fixture - def mock_account(self, monkeypatch: pytest.MonkeyPatch): - """Create a mock Account for testing.""" - - account = Account(name="Test User", email="test@example.com") - account.id = str(uuid.uuid4()) - account.last_active_at = naive_utc_now() - account.created_at = naive_utc_now() - account.updated_at = naive_utc_now() - - # Create mock tenant - tenant = Tenant(name="Test Tenant") - tenant.id = str(uuid.uuid4()) - - mock_session_instance = mock.Mock() - - mock_tenant_join = TenantAccountJoin(role=TenantAccountRole.OWNER) - monkeypatch.setattr(mock_session_instance, "scalar", mock.Mock(return_value=mock_tenant_join)) - - mock_scalars_result = mock.Mock() - mock_scalars_result.one.return_value = tenant - monkeypatch.setattr(mock_session_instance, "scalars", mock.Mock(return_value=mock_scalars_result)) - - mock_session_context = mock.Mock() - mock_session_context.__enter__.return_value = mock_session_instance - monkeypatch.setattr("models.account.Session", lambda _, expire_on_commit: mock_session_context) - - account.current_tenant = tenant - account.current_tenant_id = tenant.id - return account - - @pytest.mark.parametrize( - ("role", "list_status", "get_status", "update_status", "create_status", "build_status", "delete_status"), - [ - # Admin/Owner can do everything - (TenantAccountRole.OWNER, 200, 200, 200, 200, 200, 200), - (TenantAccountRole.ADMIN, 200, 200, 200, 200, 200, 200), - # Editor can list, get, update (parameters), but not create, build, or delete - (TenantAccountRole.EDITOR, 200, 200, 200, 403, 403, 403), - # Normal user cannot do anything - (TenantAccountRole.NORMAL, 403, 403, 403, 403, 403, 403), - # Dataset operator cannot do anything - (TenantAccountRole.DATASET_OPERATOR, 403, 403, 403, 403, 403, 403), - ], - ) - def test_trigger_subscription_permissions( - self, - test_client: FlaskClient, - auth_header, - monkeypatch, - mock_account, - role: TenantAccountRole, - list_status: int, - get_status: int, - update_status: int, - create_status: int, - build_status: int, - delete_status: int, - ): - """Test that different roles have appropriate permissions for trigger subscription operations.""" - # Set user role - mock_account.role = role - - # Mock current user - monkeypatch.setattr(trigger_providers_api, "current_user", mock_account) - - # Mock AccountService.load_user to prevent authentication issues - from services.account_service import AccountService - - mock_load_user = mock.Mock(return_value=mock_account) - monkeypatch.setattr(AccountService, "load_user", mock_load_user) - - # Test data - provider = "some_provider/some_trigger" - subscription_builder_id = str(uuid.uuid4()) - subscription_id = str(uuid.uuid4()) - - # Mock service methods - mock_list_subscriptions = mock.Mock(return_value=[]) - monkeypatch.setattr( - "services.trigger.trigger_provider_service.TriggerProviderService.list_trigger_provider_subscriptions", - mock_list_subscriptions, - ) - - mock_get_subscription_builder = mock.Mock(return_value={"id": subscription_builder_id}) - monkeypatch.setattr( - "services.trigger.trigger_subscription_builder_service.TriggerSubscriptionBuilderService.get_subscription_builder_by_id", - mock_get_subscription_builder, - ) - - mock_update_subscription_builder = mock.Mock(return_value={"id": subscription_builder_id}) - monkeypatch.setattr( - "services.trigger.trigger_subscription_builder_service.TriggerSubscriptionBuilderService.update_trigger_subscription_builder", - mock_update_subscription_builder, - ) - - mock_create_subscription_builder = mock.Mock(return_value={"id": subscription_builder_id}) - monkeypatch.setattr( - "services.trigger.trigger_subscription_builder_service.TriggerSubscriptionBuilderService.create_trigger_subscription_builder", - mock_create_subscription_builder, - ) - - mock_update_and_build_builder = mock.Mock() - monkeypatch.setattr( - "services.trigger.trigger_subscription_builder_service.TriggerSubscriptionBuilderService.update_and_build_builder", - mock_update_and_build_builder, - ) - - mock_delete_provider = mock.Mock() - mock_delete_plugin_trigger = mock.Mock() - mock_db_session = mock.Mock() - mock_db_session.commit = mock.Mock() - - def mock_session_func(engine=None): - return mock_session_context - - mock_session_context = mock.Mock() - mock_session_context.__enter__.return_value = mock_db_session - mock_session_context.__exit__.return_value = None - - monkeypatch.setattr("services.trigger.trigger_provider_service.Session", mock_session_func) - monkeypatch.setattr("services.trigger.trigger_subscription_operator_service.Session", mock_session_func) - - monkeypatch.setattr( - "services.trigger.trigger_provider_service.TriggerProviderService.delete_trigger_provider", - mock_delete_provider, - ) - monkeypatch.setattr( - "services.trigger.trigger_subscription_operator_service.TriggerSubscriptionOperatorService.delete_plugin_trigger_by_subscription", - mock_delete_plugin_trigger, - ) - - # Test 1: List subscriptions (should work for Editor, Admin, Owner) - response = test_client.get( - f"/console/api/workspaces/current/trigger-provider/{provider}/subscriptions/list", - headers=auth_header, - ) - assert response.status_code == list_status - - # Test 2: Get subscription builder (should work for Editor, Admin, Owner) - response = test_client.get( - f"/console/api/workspaces/current/trigger-provider/{provider}/subscriptions/builder/{subscription_builder_id}", - headers=auth_header, - ) - assert response.status_code == get_status - - # Test 3: Update subscription builder parameters (should work for Editor, Admin, Owner) - response = test_client.post( - f"/console/api/workspaces/current/trigger-provider/{provider}/subscriptions/builder/update/{subscription_builder_id}", - headers=auth_header, - json={"parameters": {"webhook_url": "https://example.com/webhook"}}, - ) - assert response.status_code == update_status - - # Test 4: Create subscription builder (should only work for Admin, Owner) - response = test_client.post( - f"/console/api/workspaces/current/trigger-provider/{provider}/subscriptions/builder/create", - headers=auth_header, - json={"credential_type": "api_key"}, - ) - assert response.status_code == create_status - - # Test 5: Build/activate subscription (should only work for Admin, Owner) - response = test_client.post( - f"/console/api/workspaces/current/trigger-provider/{provider}/subscriptions/builder/build/{subscription_builder_id}", - headers=auth_header, - json={"name": "Test Subscription"}, - ) - assert response.status_code == build_status - - # Test 6: Delete subscription (should only work for Admin, Owner) - response = test_client.post( - f"/console/api/workspaces/current/trigger-provider/{subscription_id}/subscriptions/delete", - headers=auth_header, - ) - assert response.status_code == delete_status - - @pytest.mark.parametrize( - ("role", "status"), - [ - (TenantAccountRole.OWNER, 200), - (TenantAccountRole.ADMIN, 200), - # Editor should be able to access logs for debugging - (TenantAccountRole.EDITOR, 200), - (TenantAccountRole.NORMAL, 403), - (TenantAccountRole.DATASET_OPERATOR, 403), - ], - ) - def test_trigger_subscription_logs_permissions( - self, - test_client: FlaskClient, - auth_header, - monkeypatch, - mock_account, - role: TenantAccountRole, - status: int, - ): - """Test that different roles have appropriate permissions for accessing subscription logs.""" - # Set user role - mock_account.role = role - - # Mock current user - monkeypatch.setattr(trigger_providers_api, "current_user", mock_account) - - # Mock AccountService.load_user to prevent authentication issues - from services.account_service import AccountService - - mock_load_user = mock.Mock(return_value=mock_account) - monkeypatch.setattr(AccountService, "load_user", mock_load_user) - - # Test data - provider = "some_provider/some_trigger" - subscription_builder_id = str(uuid.uuid4()) - - # Mock service method - mock_list_logs = mock.Mock(return_value=[]) - monkeypatch.setattr( - "services.trigger.trigger_subscription_builder_service.TriggerSubscriptionBuilderService.list_logs", - mock_list_logs, - ) - - # Test access to logs - response = test_client.get( - f"/console/api/workspaces/current/trigger-provider/{provider}/subscriptions/builder/logs/{subscription_builder_id}", - headers=auth_header, - ) - assert response.status_code == status diff --git a/api/tests/integration_tests/core/datasource/test_datasource_manager_integration.py b/api/tests/integration_tests/core/datasource/test_datasource_manager_integration.py index 91245e879e4..a876b0c4aae 100644 --- a/api/tests/integration_tests/core/datasource/test_datasource_manager_integration.py +++ b/api/tests/integration_tests/core/datasource/test_datasource_manager_integration.py @@ -1,9 +1,8 @@ from collections.abc import Generator -from graphon.node_events import StreamCompletedEvent - from core.datasource.datasource_manager import DatasourceManager from core.datasource.entities.datasource_entities import DatasourceMessage +from graphon.node_events import StreamCompletedEvent def _gen_var_stream() -> Generator[DatasourceMessage, None, None]: diff --git a/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py b/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py index 3fdea109762..2392084c366 100644 --- a/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py +++ b/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py @@ -1,8 +1,8 @@ -from graphon.enums import WorkflowNodeExecutionStatus -from graphon.node_events import NodeRunResult, StreamCompletedEvent - from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY from core.workflow.nodes.datasource.datasource_node import DatasourceNode +from core.workflow.nodes.datasource.entities import DatasourceNodeData +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.node_events import NodeRunResult, StreamCompletedEvent class _Seg: @@ -70,19 +70,16 @@ def test_node_integration_minimal_stream(mocker): mocker.patch("core.workflow.nodes.datasource.datasource_node.DatasourceManager", new=_Mgr) node = DatasourceNode( - id="n", - config={ - "id": "n", - "data": { - "type": "datasource", - "version": "1", - "title": "Datasource", - "provider_type": "plugin", - "provider_name": "p", - "plugin_id": "plug", - "datasource_name": "ds", - }, - }, + node_id="n", + config=DatasourceNodeData( + type="datasource", + version="1", + title="Datasource", + provider_type="plugin", + provider_name="p", + plugin_id="plug", + datasource_name="ds", + ), graph_init_params=_GP(), graph_runtime_state=_GS(vp), ) diff --git a/api/tests/integration_tests/factories/test_storage_key_loader.py b/api/tests/integration_tests/factories/test_storage_key_loader.py deleted file mode 100644 index c1bb8e12453..00000000000 --- a/api/tests/integration_tests/factories/test_storage_key_loader.py +++ /dev/null @@ -1,375 +0,0 @@ -import unittest -from datetime import UTC, datetime -from unittest.mock import patch -from uuid import uuid4 - -import pytest -from graphon.file import File, FileTransferMethod, FileType -from sqlalchemy.orm import Session - -from core.app.file_access import DatabaseFileAccessController -from extensions.ext_database import db -from extensions.storage.storage_type import StorageType -from factories.file_factory import StorageKeyLoader -from models import ToolFile, UploadFile -from models.enums import CreatorUserRole - - -@pytest.mark.usefixtures("flask_req_ctx") -class TestStorageKeyLoader(unittest.TestCase): - """ - Integration tests for StorageKeyLoader class. - - Tests the batched loading of storage keys from the database for files - with different transfer methods: LOCAL_FILE, REMOTE_URL, and TOOL_FILE. - """ - - def setUp(self): - """Set up test data before each test method.""" - self.session = db.session() - self.tenant_id = str(uuid4()) - self.user_id = str(uuid4()) - self.conversation_id = str(uuid4()) - - # Create test data that will be cleaned up after each test - self.test_upload_files = [] - self.test_tool_files = [] - - # Create StorageKeyLoader instance - self.loader = StorageKeyLoader( - self.session, - self.tenant_id, - access_controller=DatabaseFileAccessController(), - ) - - def tearDown(self): - """Clean up test data after each test method.""" - self.session.rollback() - - def _create_upload_file( - self, file_id: str | None = None, storage_key: str | None = None, tenant_id: str | None = None - ) -> UploadFile: - """Helper method to create an UploadFile record for testing.""" - if file_id is None: - file_id = str(uuid4()) - if storage_key is None: - storage_key = f"test_storage_key_{uuid4()}" - if tenant_id is None: - tenant_id = self.tenant_id - - upload_file = UploadFile( - tenant_id=tenant_id, - storage_type=StorageType.LOCAL, - key=storage_key, - name="test_file.txt", - size=1024, - extension=".txt", - mime_type="text/plain", - created_by_role=CreatorUserRole.ACCOUNT, - created_by=self.user_id, - created_at=datetime.now(UTC), - used=False, - ) - upload_file.id = file_id - - self.session.add(upload_file) - self.session.flush() - self.test_upload_files.append(upload_file) - - return upload_file - - def _create_tool_file( - self, file_id: str | None = None, file_key: str | None = None, tenant_id: str | None = None - ) -> ToolFile: - """Helper method to create a ToolFile record for testing.""" - if file_id is None: - file_id = str(uuid4()) - if file_key is None: - file_key = f"test_file_key_{uuid4()}" - if tenant_id is None: - tenant_id = self.tenant_id - - tool_file = ToolFile( - user_id=self.user_id, - tenant_id=tenant_id, - conversation_id=self.conversation_id, - file_key=file_key, - mimetype="text/plain", - original_url="http://example.com/file.txt", - name="test_tool_file.txt", - size=2048, - ) - tool_file.id = file_id - self.session.add(tool_file) - self.session.flush() - self.test_tool_files.append(tool_file) - - return tool_file - - def _create_file(self, related_id: str, transfer_method: FileTransferMethod, tenant_id: str | None = None) -> File: - """Helper method to create a File object for testing.""" - if tenant_id is None: - tenant_id = self.tenant_id - - # Set related_id for LOCAL_FILE and TOOL_FILE transfer methods - file_related_id = None - remote_url = None - - if transfer_method in (FileTransferMethod.LOCAL_FILE, FileTransferMethod.TOOL_FILE): - file_related_id = related_id - elif transfer_method == FileTransferMethod.REMOTE_URL: - remote_url = "https://example.com/test_file.txt" - file_related_id = related_id - - return File( - id=str(uuid4()), # Generate new UUID for File.id - tenant_id=tenant_id, - type=FileType.DOCUMENT, - transfer_method=transfer_method, - related_id=file_related_id, - remote_url=remote_url, - filename="test_file.txt", - extension=".txt", - mime_type="text/plain", - size=1024, - storage_key="initial_key", - ) - - def test_load_storage_keys_local_file(self): - """Test loading storage keys for LOCAL_FILE transfer method.""" - # Create test data - upload_file = self._create_upload_file() - file = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE) - - # Load storage keys - self.loader.load_storage_keys([file]) - - # Verify storage key was loaded correctly - assert file._storage_key == upload_file.key - - def test_load_storage_keys_remote_url(self): - """Test loading storage keys for REMOTE_URL transfer method.""" - # Create test data - upload_file = self._create_upload_file() - file = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.REMOTE_URL) - - # Load storage keys - self.loader.load_storage_keys([file]) - - # Verify storage key was loaded correctly - assert file._storage_key == upload_file.key - - def test_load_storage_keys_tool_file(self): - """Test loading storage keys for TOOL_FILE transfer method.""" - # Create test data - tool_file = self._create_tool_file() - file = self._create_file(related_id=tool_file.id, transfer_method=FileTransferMethod.TOOL_FILE) - - # Load storage keys - self.loader.load_storage_keys([file]) - - # Verify storage key was loaded correctly - assert file._storage_key == tool_file.file_key - - def test_load_storage_keys_mixed_methods(self): - """Test batch loading with mixed transfer methods.""" - # Create test data for different transfer methods - upload_file1 = self._create_upload_file() - upload_file2 = self._create_upload_file() - tool_file = self._create_tool_file() - - file1 = self._create_file(related_id=upload_file1.id, transfer_method=FileTransferMethod.LOCAL_FILE) - file2 = self._create_file(related_id=upload_file2.id, transfer_method=FileTransferMethod.REMOTE_URL) - file3 = self._create_file(related_id=tool_file.id, transfer_method=FileTransferMethod.TOOL_FILE) - - files = [file1, file2, file3] - - # Load storage keys - self.loader.load_storage_keys(files) - - # Verify all storage keys were loaded correctly - assert file1._storage_key == upload_file1.key - assert file2._storage_key == upload_file2.key - assert file3._storage_key == tool_file.file_key - - def test_load_storage_keys_empty_list(self): - """Test with empty file list.""" - # Should not raise any exceptions - self.loader.load_storage_keys([]) - - def test_load_storage_keys_ignores_legacy_file_tenant_id(self): - """Legacy file tenant_id should not override the loader tenant scope.""" - upload_file = self._create_upload_file() - file = self._create_file( - related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE, tenant_id=str(uuid4()) - ) - - self.loader.load_storage_keys([file]) - - assert file._storage_key == upload_file.key - - def test_load_storage_keys_missing_file_id(self): - """Test with None file.related_id.""" - # Create a file with valid parameters first, then manually set related_id to None - file = self._create_file(related_id=str(uuid4()), transfer_method=FileTransferMethod.LOCAL_FILE) - file.related_id = None - - # Should raise ValueError for None file related_id - with pytest.raises(ValueError) as context: - self.loader.load_storage_keys([file]) - - assert str(context.value) == "file id should not be None." - - def test_load_storage_keys_nonexistent_upload_file_records(self): - """Test with missing UploadFile database records.""" - # Create file with non-existent upload file id - non_existent_id = str(uuid4()) - file = self._create_file(related_id=non_existent_id, transfer_method=FileTransferMethod.LOCAL_FILE) - - # Should raise ValueError for missing record - with pytest.raises(ValueError): - self.loader.load_storage_keys([file]) - - def test_load_storage_keys_nonexistent_tool_file_records(self): - """Test with missing ToolFile database records.""" - # Create file with non-existent tool file id - non_existent_id = str(uuid4()) - file = self._create_file(related_id=non_existent_id, transfer_method=FileTransferMethod.TOOL_FILE) - - # Should raise ValueError for missing record - with pytest.raises(ValueError): - self.loader.load_storage_keys([file]) - - def test_load_storage_keys_invalid_uuid(self): - """Test with invalid UUID format.""" - # Create a file with valid parameters first, then manually set invalid related_id - file = self._create_file(related_id=str(uuid4()), transfer_method=FileTransferMethod.LOCAL_FILE) - file.related_id = "invalid-uuid-format" - - # Should raise ValueError for invalid UUID - with pytest.raises(ValueError): - self.loader.load_storage_keys([file]) - - def test_load_storage_keys_batch_efficiency(self): - """Test batched operations use efficient queries.""" - # Create multiple files of different types - upload_files = [self._create_upload_file() for _ in range(3)] - tool_files = [self._create_tool_file() for _ in range(2)] - - files = [] - files.extend( - [self._create_file(related_id=uf.id, transfer_method=FileTransferMethod.LOCAL_FILE) for uf in upload_files] - ) - files.extend( - [self._create_file(related_id=tf.id, transfer_method=FileTransferMethod.TOOL_FILE) for tf in tool_files] - ) - - # Mock the session to count queries - with patch.object(self.session, "scalars", wraps=self.session.scalars) as mock_scalars: - self.loader.load_storage_keys(files) - - # Should make exactly 2 queries (one for upload_files, one for tool_files) - assert mock_scalars.call_count == 2 - - # Verify all storage keys were loaded correctly - for i, file in enumerate(files[:3]): - assert file._storage_key == upload_files[i].key - for i, file in enumerate(files[3:]): - assert file._storage_key == tool_files[i].file_key - - def test_load_storage_keys_tenant_isolation(self): - """Test that tenant isolation works correctly.""" - # Create files for different tenants - other_tenant_id = str(uuid4()) - - # Create upload file for current tenant - upload_file_current = self._create_upload_file() - file_current = self._create_file( - related_id=upload_file_current.id, transfer_method=FileTransferMethod.LOCAL_FILE - ) - - # Create upload file for other tenant (but don't add to cleanup list) - upload_file_other = UploadFile( - tenant_id=other_tenant_id, - storage_type=StorageType.LOCAL, - key="other_tenant_key", - name="other_file.txt", - size=1024, - extension=".txt", - mime_type="text/plain", - created_by_role=CreatorUserRole.ACCOUNT, - created_by=self.user_id, - created_at=datetime.now(UTC), - used=False, - ) - upload_file_other.id = str(uuid4()) - self.session.add(upload_file_other) - self.session.flush() - - # Create file for other tenant but try to load with current tenant's loader - file_other = self._create_file( - related_id=upload_file_other.id, transfer_method=FileTransferMethod.LOCAL_FILE, tenant_id=other_tenant_id - ) - - # Should raise ValueError due to tenant mismatch - with pytest.raises(ValueError) as context: - self.loader.load_storage_keys([file_other]) - - assert "Upload file not found for id:" in str(context.value) - - # Current tenant's file should still work - self.loader.load_storage_keys([file_current]) - assert file_current._storage_key == upload_file_current.key - - def test_load_storage_keys_mixed_tenant_batch(self): - """Test batch with mixed tenant files (should fail on first mismatch).""" - # Create files for current tenant - upload_file_current = self._create_upload_file() - file_current = self._create_file( - related_id=upload_file_current.id, transfer_method=FileTransferMethod.LOCAL_FILE - ) - - # Create file for different tenant - other_tenant_id = str(uuid4()) - file_other = self._create_file( - related_id=str(uuid4()), transfer_method=FileTransferMethod.LOCAL_FILE, tenant_id=other_tenant_id - ) - - # Should raise ValueError on tenant mismatch - with pytest.raises(ValueError) as context: - self.loader.load_storage_keys([file_current, file_other]) - - assert "Upload file not found for id:" in str(context.value) - - def test_load_storage_keys_duplicate_file_ids(self): - """Test handling of duplicate file IDs in the batch.""" - # Create upload file - upload_file = self._create_upload_file() - - # Create two File objects with same related_id - file1 = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE) - file2 = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE) - - # Should handle duplicates gracefully - self.loader.load_storage_keys([file1, file2]) - - # Both files should have the same storage key - assert file1._storage_key == upload_file.key - assert file2._storage_key == upload_file.key - - def test_load_storage_keys_session_isolation(self): - """Test that the loader uses the provided session correctly.""" - # Create test data - upload_file = self._create_upload_file() - file = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE) - - # Create loader with different session (same underlying connection) - - with Session(bind=db.engine) as other_session: - other_loader = StorageKeyLoader( - other_session, - self.tenant_id, - access_controller=DatabaseFileAccessController(), - ) - with pytest.raises(ValueError): - other_loader.load_storage_keys([file]) diff --git a/api/tests/integration_tests/model_runtime/__mock/plugin_model.py b/api/tests/integration_tests/model_runtime/__mock/plugin_model.py index ce04a158a82..c4146d5ccdd 100644 --- a/api/tests/integration_tests/model_runtime/__mock/plugin_model.py +++ b/api/tests/integration_tests/model_runtime/__mock/plugin_model.py @@ -4,6 +4,9 @@ from collections.abc import Generator, Sequence from decimal import Decimal from json import dumps +from core.plugin.entities.plugin_daemon import PluginModelProviderEntity +from core.plugin.impl.model import PluginModelClient + # import monkeypatch from graphon.model_runtime.entities.common_entities import I18nObject from graphon.model_runtime.entities.llm_entities import ( @@ -23,9 +26,6 @@ from graphon.model_runtime.entities.model_entities import ( ) from graphon.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity -from core.plugin.entities.plugin_daemon import PluginModelProviderEntity -from core.plugin.impl.model import PluginModelClient - class MockModelClass(PluginModelClient): def fetch_model_providers(self, tenant_id: str) -> Sequence[PluginModelProviderEntity]: diff --git a/api/tests/integration_tests/services/plugin/test_plugin_lifecycle.py b/api/tests/integration_tests/services/plugin/test_plugin_lifecycle.py index 951a5ab4b43..0a19debc392 100644 --- a/api/tests/integration_tests/services/plugin/test_plugin_lifecycle.py +++ b/api/tests/integration_tests/services/plugin/test_plugin_lifecycle.py @@ -1,5 +1,5 @@ import pytest -from sqlalchemy import delete +from sqlalchemy import delete, func, select from core.db.session_factory import session_factory from models import Tenant @@ -61,7 +61,11 @@ class TestPluginPermissionLifecycle: assert perm.debug_permission == TenantPluginPermission.DebugPermission.ADMINS with session_factory.create_session() as session: - count = session.query(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant).count() + count = session.scalar( + select(func.count()) + .select_from(TenantPluginPermission) + .where(TenantPluginPermission.tenant_id == tenant) + ) assert count == 1 diff --git a/api/tests/integration_tests/services/retention/test_messages_clean_service.py b/api/tests/integration_tests/services/retention/test_messages_clean_service.py index 348bb0af4ac..352960bcc28 100644 --- a/api/tests/integration_tests/services/retention/test_messages_clean_service.py +++ b/api/tests/integration_tests/services/retention/test_messages_clean_service.py @@ -3,7 +3,7 @@ import math import uuid import pytest -from sqlalchemy import delete +from sqlalchemy import delete, func, select from core.db.session_factory import session_factory from models import Tenant @@ -210,7 +210,7 @@ class TestMessagesCleanServiceIntegration: assert stats["total_deleted"] == 0 with session_factory.create_session() as session: - remaining = session.query(Message).where(Message.id.in_(all_ids)).count() + remaining = session.scalar(select(func.count()).select_from(Message).where(Message.id.in_(all_ids))) assert remaining == len(all_ids) def test_billing_disabled_deletes_all_in_range(self, seed_messages): @@ -231,7 +231,7 @@ class TestMessagesCleanServiceIntegration: assert stats["total_deleted"] == len(all_ids) with session_factory.create_session() as session: - remaining = session.query(Message).where(Message.id.in_(all_ids)).count() + remaining = session.scalar(select(func.count()).select_from(Message).where(Message.id.in_(all_ids))) assert remaining == 0 def test_start_from_filters_correctly(self, seed_messages): @@ -254,7 +254,7 @@ class TestMessagesCleanServiceIntegration: with session_factory.create_session() as session: all_ids = list(msg_ids.values()) - remaining_ids = {r[0] for r in session.query(Message.id).where(Message.id.in_(all_ids)).all()} + remaining_ids = set(session.scalars(select(Message.id).where(Message.id.in_(all_ids))).all()) assert msg_ids["old"] not in remaining_ids assert msg_ids["very_old"] in remaining_ids @@ -282,7 +282,7 @@ class TestMessagesCleanServiceIntegration: assert stats["batches"] >= expected_batches with session_factory.create_session() as session: - remaining = session.query(Message).where(Message.id.in_(msg_ids)).count() + remaining = session.scalar(select(func.count()).select_from(Message).where(Message.id.in_(msg_ids))) assert remaining == 0 def test_no_messages_in_range_returns_empty_stats(self, seed_messages): @@ -319,9 +319,17 @@ class TestMessagesCleanServiceIntegration: assert stats["total_deleted"] == 1 with session_factory.create_session() as session: - assert session.query(Message).where(Message.id == msg_id).count() == 0 - assert session.query(MessageFeedback).where(MessageFeedback.id == fb_id).count() == 0 - assert session.query(MessageAnnotation).where(MessageAnnotation.id == ann_id).count() == 0 + assert session.scalar(select(func.count()).select_from(Message).where(Message.id == msg_id)) == 0 + assert ( + session.scalar(select(func.count()).select_from(MessageFeedback).where(MessageFeedback.id == fb_id)) + == 0 + ) + assert ( + session.scalar( + select(func.count()).select_from(MessageAnnotation).where(MessageAnnotation.id == ann_id) + ) + == 0 + ) def test_factory_from_time_range_validation(self): with pytest.raises(ValueError, match="start_from"): diff --git a/api/tests/integration_tests/services/test_workflow_draft_variable_service.py b/api/tests/integration_tests/services/test_workflow_draft_variable_service.py index 5c6636f31ec..e1306443381 100644 --- a/api/tests/integration_tests/services/test_workflow_draft_variable_service.py +++ b/api/tests/integration_tests/services/test_workflow_draft_variable_service.py @@ -3,11 +3,7 @@ import unittest import uuid import pytest -from graphon.nodes import BuiltinNodeTypes -from graphon.variables.segments import StringSegment -from graphon.variables.types import SegmentType -from graphon.variables.variables import StringVariable -from sqlalchemy import delete +from sqlalchemy import delete, func, select from sqlalchemy.orm import Session from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID @@ -15,6 +11,10 @@ from extensions.ext_database import db from extensions.ext_storage import storage from extensions.storage.storage_type import StorageType from factories.variable_factory import build_segment +from graphon.nodes import BuiltinNodeTypes +from graphon.variables.segments import StringSegment +from graphon.variables.types import SegmentType +from graphon.variables.variables import StringVariable from libs import datetime_utils from models.enums import CreatorUserRole from models.model import UploadFile @@ -38,21 +38,25 @@ class TestWorkflowDraftVariableService(unittest.TestCase): def setUp(self): self._test_app_id = str(uuid.uuid4()) + self._test_user_id = str(uuid.uuid4()) self._session: Session = db.session() sys_var = WorkflowDraftVariable.new_sys_variable( app_id=self._test_app_id, + user_id=self._test_user_id, name="sys_var", value=build_segment("sys_value"), node_execution_id=self._node_exec_id, ) conv_var = WorkflowDraftVariable.new_conversation_variable( app_id=self._test_app_id, + user_id=self._test_user_id, name="conv_var", value=build_segment("conv_value"), ) node2_vars = [ WorkflowDraftVariable.new_node_variable( app_id=self._test_app_id, + user_id=self._test_user_id, node_id=self._node2_id, name="int_var", value=build_segment(1), @@ -61,6 +65,7 @@ class TestWorkflowDraftVariableService(unittest.TestCase): ), WorkflowDraftVariable.new_node_variable( app_id=self._test_app_id, + user_id=self._test_user_id, node_id=self._node2_id, name="str_var", value=build_segment("str_value"), @@ -70,6 +75,7 @@ class TestWorkflowDraftVariableService(unittest.TestCase): ] node1_var = WorkflowDraftVariable.new_node_variable( app_id=self._test_app_id, + user_id=self._test_user_id, node_id=self._node1_id, name="str_var", value=build_segment("str_value"), @@ -141,24 +147,27 @@ class TestWorkflowDraftVariableService(unittest.TestCase): def test_delete_node_variables(self): srv = self._get_test_srv() srv.delete_node_variables(self._test_app_id, self._node2_id, user_id=self._test_user_id) - node2_var_count = ( - self._session.query(WorkflowDraftVariable) + node2_var_count = self._session.scalar( + select(func.count()) + .select_from(WorkflowDraftVariable) .where( WorkflowDraftVariable.app_id == self._test_app_id, WorkflowDraftVariable.node_id == self._node2_id, + WorkflowDraftVariable.user_id == self._test_user_id, ) - .count() ) assert node2_var_count == 0 def test_delete_variable(self): srv = self._get_test_srv() - node_1_var = ( - self._session.query(WorkflowDraftVariable).where(WorkflowDraftVariable.id == self._node1_str_var_id).one() - ) + node_1_var = self._session.scalars( + select(WorkflowDraftVariable).where(WorkflowDraftVariable.id == self._node1_str_var_id) + ).one() srv.delete_variable(node_1_var) exists = bool( - self._session.query(WorkflowDraftVariable).where(WorkflowDraftVariable.id == self._node1_str_var_id).first() + self._session.scalars( + select(WorkflowDraftVariable).where(WorkflowDraftVariable.id == self._node1_str_var_id) + ).first() ) assert exists is False @@ -248,9 +257,7 @@ class TestDraftVariableLoader(unittest.TestCase): def tearDown(self): with Session(bind=db.engine, expire_on_commit=False) as session: - session.query(WorkflowDraftVariable).where(WorkflowDraftVariable.app_id == self._test_app_id).delete( - synchronize_session=False - ) + session.execute(delete(WorkflowDraftVariable).where(WorkflowDraftVariable.app_id == self._test_app_id)) session.commit() def test_variable_loader_with_empty_selector(self): @@ -431,9 +438,11 @@ class TestDraftVariableLoader(unittest.TestCase): # Clean up with Session(bind=db.engine) as session: # Query and delete by ID to ensure they're tracked in this session - session.query(WorkflowDraftVariable).filter_by(id=offloaded_var.id).delete() - session.query(WorkflowDraftVariableFile).filter_by(id=variable_file.id).delete() - session.query(UploadFile).filter_by(id=upload_file.id).delete() + session.execute(delete(WorkflowDraftVariable).where(WorkflowDraftVariable.id == offloaded_var.id)) + session.execute( + delete(WorkflowDraftVariableFile).where(WorkflowDraftVariableFile.id == variable_file.id) + ) + session.execute(delete(UploadFile).where(UploadFile.id == upload_file.id)) session.commit() # Clean up storage try: @@ -534,9 +543,11 @@ class TestDraftVariableLoader(unittest.TestCase): # Clean up with Session(bind=db.engine) as session: # Query and delete by ID to ensure they're tracked in this session - session.query(WorkflowDraftVariable).filter_by(id=offloaded_var.id).delete() - session.query(WorkflowDraftVariableFile).filter_by(id=variable_file.id).delete() - session.query(UploadFile).filter_by(id=upload_file.id).delete() + session.execute(delete(WorkflowDraftVariable).where(WorkflowDraftVariable.id == offloaded_var.id)) + session.execute( + delete(WorkflowDraftVariableFile).where(WorkflowDraftVariableFile.id == variable_file.id) + ) + session.execute(delete(UploadFile).where(UploadFile.id == upload_file.id)) session.commit() # Clean up storage try: diff --git a/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py b/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py index 38dc8bbb281..4f444598b12 100644 --- a/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py +++ b/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py @@ -2,11 +2,11 @@ import uuid from unittest.mock import patch import pytest -from graphon.variables.segments import StringSegment -from sqlalchemy import delete +from sqlalchemy import delete, func, select from core.db.session_factory import session_factory from extensions.storage.storage_type import StorageType +from graphon.variables.segments import StringSegment from models import Tenant from models.enums import CreatorUserRole from models.model import App, UploadFile @@ -108,8 +108,12 @@ class TestDeleteDraftVariablesIntegration: app2_id = data["app2"].id with session_factory.create_session() as session: - app1_vars_before = session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count() - app2_vars_before = session.query(WorkflowDraftVariable).filter_by(app_id=app2_id).count() + app1_vars_before = session.scalar( + select(func.count()).select_from(WorkflowDraftVariable).filter_by(app_id=app1_id) + ) + app2_vars_before = session.scalar( + select(func.count()).select_from(WorkflowDraftVariable).filter_by(app_id=app2_id) + ) assert app1_vars_before == 5 assert app2_vars_before == 5 @@ -117,8 +121,12 @@ class TestDeleteDraftVariablesIntegration: assert deleted_count == 5 with session_factory.create_session() as session: - app1_vars_after = session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count() - app2_vars_after = session.query(WorkflowDraftVariable).filter_by(app_id=app2_id).count() + app1_vars_after = session.scalar( + select(func.count()).select_from(WorkflowDraftVariable).filter_by(app_id=app1_id) + ) + app2_vars_after = session.scalar( + select(func.count()).select_from(WorkflowDraftVariable).filter_by(app_id=app2_id) + ) assert app1_vars_after == 0 assert app2_vars_after == 5 @@ -130,7 +138,9 @@ class TestDeleteDraftVariablesIntegration: assert deleted_count == 5 with session_factory.create_session() as session: - remaining_vars = session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count() + remaining_vars = session.scalar( + select(func.count()).select_from(WorkflowDraftVariable).filter_by(app_id=app1_id) + ) assert remaining_vars == 0 def test_delete_draft_variables_batch_nonexistent_app(self, setup_test_data): @@ -143,14 +153,18 @@ class TestDeleteDraftVariablesIntegration: app1_id = data["app1"].id with session_factory.create_session() as session: - vars_before = session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count() + vars_before = session.scalar( + select(func.count()).select_from(WorkflowDraftVariable).filter_by(app_id=app1_id) + ) assert vars_before == 5 deleted_count = _delete_draft_variables(app1_id) assert deleted_count == 5 with session_factory.create_session() as session: - vars_after = session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count() + vars_after = session.scalar( + select(func.count()).select_from(WorkflowDraftVariable).filter_by(app_id=app1_id) + ) assert vars_after == 0 def test_batch_deletion_handles_large_dataset(self, app_and_tenant): @@ -175,7 +189,9 @@ class TestDeleteDraftVariablesIntegration: deleted_count = delete_draft_variables_batch(app.id, batch_size=8) assert deleted_count == 25 with session_factory.create_session() as session: - remaining = session.query(WorkflowDraftVariable).filter_by(app_id=app.id).count() + remaining = session.scalar( + select(func.count()).select_from(WorkflowDraftVariable).filter_by(app_id=app.id) + ) assert remaining == 0 finally: with session_factory.create_session() as session: @@ -193,7 +209,6 @@ class TestDeleteDraftVariablesWithOffloadIntegration: def setup_offload_test_data(self, app_and_tenant): tenant, app = app_and_tenant from graphon.variables.types import SegmentType - from libs.datetime_utils import naive_utc_now with session_factory.create_session() as session: @@ -307,13 +322,17 @@ class TestDeleteDraftVariablesWithOffloadIntegration: mock_storage.delete.return_value = None with session_factory.create_session() as session: - draft_vars_before = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count() - var_files_before = ( - session.query(WorkflowDraftVariableFile) - .where(WorkflowDraftVariableFile.id.in_(variable_file_ids)) - .count() + draft_vars_before = session.scalar( + select(func.count()).select_from(WorkflowDraftVariable).filter_by(app_id=app_id) + ) + var_files_before = session.scalar( + select(func.count()) + .select_from(WorkflowDraftVariableFile) + .where(WorkflowDraftVariableFile.id.in_(variable_file_ids)) + ) + upload_files_before = session.scalar( + select(func.count()).select_from(UploadFile).where(UploadFile.id.in_(upload_file_ids)) ) - upload_files_before = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count() assert draft_vars_before == 3 assert var_files_before == 2 assert upload_files_before == 2 @@ -322,16 +341,20 @@ class TestDeleteDraftVariablesWithOffloadIntegration: assert deleted_count == 3 with session_factory.create_session() as session: - draft_vars_after = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count() + draft_vars_after = session.scalar( + select(func.count()).select_from(WorkflowDraftVariable).filter_by(app_id=app_id) + ) assert draft_vars_after == 0 with session_factory.create_session() as session: - var_files_after = ( - session.query(WorkflowDraftVariableFile) + var_files_after = session.scalar( + select(func.count()) + .select_from(WorkflowDraftVariableFile) .where(WorkflowDraftVariableFile.id.in_(variable_file_ids)) - .count() ) - upload_files_after = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count() + upload_files_after = session.scalar( + select(func.count()).select_from(UploadFile).where(UploadFile.id.in_(upload_file_ids)) + ) assert var_files_after == 0 assert upload_files_after == 0 @@ -352,16 +375,20 @@ class TestDeleteDraftVariablesWithOffloadIntegration: assert deleted_count == 3 with session_factory.create_session() as session: - draft_vars_after = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count() + draft_vars_after = session.scalar( + select(func.count()).select_from(WorkflowDraftVariable).filter_by(app_id=app_id) + ) assert draft_vars_after == 0 with session_factory.create_session() as session: - var_files_after = ( - session.query(WorkflowDraftVariableFile) + var_files_after = session.scalar( + select(func.count()) + .select_from(WorkflowDraftVariableFile) .where(WorkflowDraftVariableFile.id.in_(variable_file_ids)) - .count() ) - upload_files_after = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count() + upload_files_after = session.scalar( + select(func.count()).select_from(UploadFile).where(UploadFile.id.in_(upload_file_ids)) + ) assert var_files_after == 0 assert upload_files_after == 0 @@ -425,7 +452,6 @@ class TestDeleteDraftVariablesSessionCommit: def setup_offload_test_data(self, app_and_tenant): """Create test data with offload files for session commit tests.""" from graphon.variables.types import SegmentType - from libs.datetime_utils import naive_utc_now tenant, app = app_and_tenant @@ -579,7 +605,9 @@ class TestDeleteDraftVariablesSessionCommit: # Verify all data was deleted (proves transaction was committed) with session_factory.create_session() as session: - remaining_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count() + remaining_count = session.scalar( + select(func.count()).select_from(WorkflowDraftVariable).filter_by(app_id=app_id) + ) assert deleted_count == 10 assert remaining_count == 0 @@ -592,7 +620,9 @@ class TestDeleteDraftVariablesSessionCommit: # Verify initial state with session_factory.create_session() as session: - initial_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count() + initial_count = session.scalar( + select(func.count()).select_from(WorkflowDraftVariable).filter_by(app_id=app_id) + ) assert initial_count == 10 # Perform deletion with small batch size to force multiple commits @@ -602,13 +632,17 @@ class TestDeleteDraftVariablesSessionCommit: # Verify all data is deleted in a new session (proves commits worked) with session_factory.create_session() as session: - final_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count() + final_count = session.scalar( + select(func.count()).select_from(WorkflowDraftVariable).filter_by(app_id=app_id) + ) assert final_count == 0 # Verify specific IDs are deleted with session_factory.create_session() as session: - remaining_vars = ( - session.query(WorkflowDraftVariable).where(WorkflowDraftVariable.id.in_(variable_ids)).count() + remaining_vars = session.scalar( + select(func.count()) + .select_from(WorkflowDraftVariable) + .where(WorkflowDraftVariable.id.in_(variable_ids)) ) assert remaining_vars == 0 @@ -626,7 +660,9 @@ class TestDeleteDraftVariablesSessionCommit: app_id = data["app"].id with session_factory.create_session() as session: - initial_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count() + initial_count = session.scalar( + select(func.count()).select_from(WorkflowDraftVariable).filter_by(app_id=app_id) + ) assert initial_count == 10 # Delete all in a single batch @@ -635,7 +671,9 @@ class TestDeleteDraftVariablesSessionCommit: # Verify data is persisted with session_factory.create_session() as session: - final_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count() + final_count = session.scalar( + select(func.count()).select_from(WorkflowDraftVariable).filter_by(app_id=app_id) + ) assert final_count == 0 def test_invalid_batch_size_raises_error(self, setup_commit_test_data): @@ -659,13 +697,17 @@ class TestDeleteDraftVariablesSessionCommit: # Verify initial state with session_factory.create_session() as session: - draft_vars_before = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count() - var_files_before = ( - session.query(WorkflowDraftVariableFile) - .where(WorkflowDraftVariableFile.id.in_([vf.id for vf in data["variable_files"]])) - .count() + draft_vars_before = session.scalar( + select(func.count()).select_from(WorkflowDraftVariable).filter_by(app_id=app_id) + ) + var_files_before = session.scalar( + select(func.count()) + .select_from(WorkflowDraftVariableFile) + .where(WorkflowDraftVariableFile.id.in_([vf.id for vf in data["variable_files"]])) + ) + upload_files_before = session.scalar( + select(func.count()).select_from(UploadFile).where(UploadFile.id.in_(upload_file_ids)) ) - upload_files_before = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count() assert draft_vars_before == 3 assert var_files_before == 2 assert upload_files_before == 2 @@ -676,13 +718,17 @@ class TestDeleteDraftVariablesSessionCommit: # Verify all data is persisted (deleted) in new session with session_factory.create_session() as session: - draft_vars_after = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count() - var_files_after = ( - session.query(WorkflowDraftVariableFile) - .where(WorkflowDraftVariableFile.id.in_([vf.id for vf in data["variable_files"]])) - .count() + draft_vars_after = session.scalar( + select(func.count()).select_from(WorkflowDraftVariable).filter_by(app_id=app_id) + ) + var_files_after = session.scalar( + select(func.count()) + .select_from(WorkflowDraftVariableFile) + .where(WorkflowDraftVariableFile.id.in_([vf.id for vf in data["variable_files"]])) + ) + upload_files_after = session.scalar( + select(func.count()).select_from(UploadFile).where(UploadFile.id.in_(upload_file_ids)) ) - upload_files_after = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count() assert draft_vars_after == 0 assert var_files_after == 0 assert upload_files_after == 0 diff --git a/api/tests/integration_tests/vdb/milvus/__init__.py b/api/tests/integration_tests/vdb/milvus/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/api/tests/integration_tests/vdb/myscale/__init__.py b/api/tests/integration_tests/vdb/myscale/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/api/tests/integration_tests/vdb/oceanbase/__init__.py b/api/tests/integration_tests/vdb/oceanbase/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/api/tests/integration_tests/vdb/opengauss/__init__.py b/api/tests/integration_tests/vdb/opengauss/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/api/tests/integration_tests/vdb/opensearch/__init__.py b/api/tests/integration_tests/vdb/opensearch/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/api/tests/integration_tests/vdb/opensearch/test_opensearch.py b/api/tests/integration_tests/vdb/opensearch/test_opensearch.py deleted file mode 100644 index 81ebb1d2f7d..00000000000 --- a/api/tests/integration_tests/vdb/opensearch/test_opensearch.py +++ /dev/null @@ -1,235 +0,0 @@ -from unittest.mock import MagicMock, patch - -import pytest - -from core.rag.datasource.vdb.field import Field -from core.rag.datasource.vdb.opensearch.opensearch_vector import OpenSearchConfig, OpenSearchVector -from core.rag.models.document import Document -from extensions import ext_redis - - -def get_example_text() -> str: - return "This is a sample text for testing purposes." - - -@pytest.fixture(scope="module") -def setup_mock_redis(): - ext_redis.redis_client.get = MagicMock(return_value=None) - ext_redis.redis_client.set = MagicMock(return_value=None) - - mock_redis_lock = MagicMock() - mock_redis_lock.__enter__ = MagicMock() - mock_redis_lock.__exit__ = MagicMock() - ext_redis.redis_client.lock = MagicMock(return_value=mock_redis_lock) - - -class TestOpenSearchConfig: - def test_to_opensearch_params(self): - config = OpenSearchConfig( - host="localhost", - port=9200, - secure=True, - user="admin", - password="password", - ) - - params = config.to_opensearch_params() - - assert params["hosts"] == [{"host": "localhost", "port": 9200}] - assert params["use_ssl"] is True - assert params["verify_certs"] is True - assert params["connection_class"].__name__ == "Urllib3HttpConnection" - assert params["http_auth"] == ("admin", "password") - - @patch("boto3.Session", autospec=True) - @patch("core.rag.datasource.vdb.opensearch.opensearch_vector.Urllib3AWSV4SignerAuth", autospec=True) - def test_to_opensearch_params_with_aws_managed_iam( - self, mock_aws_signer_auth: MagicMock, mock_boto_session: MagicMock - ): - mock_credentials = MagicMock() - mock_boto_session.return_value.get_credentials.return_value = mock_credentials - - mock_auth_instance = mock_aws_signer_auth.return_value - aws_region = "ap-southeast-2" - aws_service = "aoss" - host = f"aoss-endpoint.{aws_region}.aoss.amazonaws.com" - port = 9201 - - config = OpenSearchConfig( - host=host, - port=port, - secure=True, - auth_method="aws_managed_iam", - aws_region=aws_region, - aws_service=aws_service, - ) - - params = config.to_opensearch_params() - - assert params["hosts"] == [{"host": host, "port": port}] - assert params["use_ssl"] is True - assert params["verify_certs"] is True - assert params["connection_class"].__name__ == "Urllib3HttpConnection" - assert params["http_auth"] is mock_auth_instance - - mock_aws_signer_auth.assert_called_once_with( - credentials=mock_credentials, region=aws_region, service=aws_service - ) - assert mock_boto_session.return_value.get_credentials.called - - -class TestOpenSearchVector: - def setup_method(self): - self.collection_name = "test_collection" - self.example_doc_id = "example_doc_id" - self.vector = OpenSearchVector( - collection_name=self.collection_name, - config=OpenSearchConfig(host="localhost", port=9200, secure=False, user="admin", password="password"), - ) - self.vector._client = MagicMock() - - @pytest.mark.parametrize( - ("search_response", "expected_length", "expected_doc_id"), - [ - ( - { - "hits": { - "total": {"value": 1}, - "hits": [ - { - "_source": { - "page_content": get_example_text(), - "metadata": {"document_id": "example_doc_id"}, - } - } - ], - } - }, - 1, - "example_doc_id", - ), - ({"hits": {"total": {"value": 0}, "hits": []}}, 0, None), - ], - ) - def test_search_by_full_text(self, search_response, expected_length, expected_doc_id): - self.vector._client.search.return_value = search_response - - hits_by_full_text = self.vector.search_by_full_text(query=get_example_text()) - assert len(hits_by_full_text) == expected_length - if expected_length > 0: - assert hits_by_full_text[0].metadata["document_id"] == expected_doc_id - - def test_search_by_vector(self): - vector = [0.1] * 128 - mock_response = { - "hits": { - "total": {"value": 1}, - "hits": [ - { - "_source": { - Field.CONTENT_KEY: get_example_text(), - Field.METADATA_KEY: {"document_id": self.example_doc_id}, - }, - "_score": 1.0, - } - ], - } - } - self.vector._client.search.return_value = mock_response - - hits_by_vector = self.vector.search_by_vector(query_vector=vector) - - print("Hits by vector:", hits_by_vector) - print("Expected document ID:", self.example_doc_id) - print("Actual document ID:", hits_by_vector[0].metadata["document_id"] if hits_by_vector else "No hits") - - assert len(hits_by_vector) > 0, f"Expected at least one hit, got {len(hits_by_vector)}" - assert hits_by_vector[0].metadata["document_id"] == self.example_doc_id, ( - f"Expected document ID {self.example_doc_id}, got {hits_by_vector[0].metadata['document_id']}" - ) - - def test_get_ids_by_metadata_field(self): - mock_response = {"hits": {"total": {"value": 1}, "hits": [{"_id": "mock_id"}]}} - self.vector._client.search.return_value = mock_response - - doc = Document(page_content="Test content", metadata={"document_id": self.example_doc_id}) - embedding = [0.1] * 128 - - with patch("opensearchpy.helpers.bulk", autospec=True) as mock_bulk: - mock_bulk.return_value = ([], []) - self.vector.add_texts([doc], [embedding]) - - ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id) - assert len(ids) == 1 - assert ids[0] == "mock_id" - - def test_add_texts(self): - self.vector._client.index.return_value = {"result": "created"} - - doc = Document(page_content="Test content", metadata={"document_id": self.example_doc_id}) - embedding = [0.1] * 128 - - with patch("opensearchpy.helpers.bulk", autospec=True) as mock_bulk: - mock_bulk.return_value = ([], []) - self.vector.add_texts([doc], [embedding]) - - mock_response = {"hits": {"total": {"value": 1}, "hits": [{"_id": "mock_id"}]}} - self.vector._client.search.return_value = mock_response - - ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id) - assert len(ids) == 1 - assert ids[0] == "mock_id" - - def test_delete_nonexistent_index(self): - """Test deleting a non-existent index.""" - # Create a vector instance with a non-existent collection name - self.vector._client.indices.exists.return_value = False - - # Should not raise an exception - self.vector.delete() - - # Verify that exists was called but delete was not - self.vector._client.indices.exists.assert_called_once_with(index=self.collection_name.lower()) - self.vector._client.indices.delete.assert_not_called() - - def test_delete_existing_index(self): - """Test deleting an existing index.""" - self.vector._client.indices.exists.return_value = True - - self.vector.delete() - - # Verify both exists and delete were called - self.vector._client.indices.exists.assert_called_once_with(index=self.collection_name.lower()) - self.vector._client.indices.delete.assert_called_once_with(index=self.collection_name.lower()) - - -@pytest.mark.usefixtures("setup_mock_redis") -class TestOpenSearchVectorWithRedis: - def setup_method(self): - self.tester = TestOpenSearchVector() - - def test_search_by_full_text(self): - self.tester.setup_method() - search_response = { - "hits": { - "total": {"value": 1}, - "hits": [ - {"_source": {"page_content": get_example_text(), "metadata": {"document_id": "example_doc_id"}}} - ], - } - } - expected_length = 1 - expected_doc_id = "example_doc_id" - self.tester.test_search_by_full_text(search_response, expected_length, expected_doc_id) - - def test_get_ids_by_metadata_field(self): - self.tester.setup_method() - self.tester.test_get_ids_by_metadata_field() - - def test_add_texts(self): - self.tester.setup_method() - self.tester.test_add_texts() - - def test_search_by_vector(self): - self.tester.setup_method() - self.tester.test_search_by_vector() diff --git a/api/tests/integration_tests/vdb/oracle/__init__.py b/api/tests/integration_tests/vdb/oracle/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/api/tests/integration_tests/vdb/pgvecto_rs/__init__.py b/api/tests/integration_tests/vdb/pgvecto_rs/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/api/tests/integration_tests/vdb/pgvector/__init__.py b/api/tests/integration_tests/vdb/pgvector/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/api/tests/integration_tests/vdb/pyvastbase/__init__.py b/api/tests/integration_tests/vdb/pyvastbase/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/api/tests/integration_tests/vdb/qdrant/__init__.py b/api/tests/integration_tests/vdb/qdrant/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/api/tests/integration_tests/vdb/tablestore/__init__.py b/api/tests/integration_tests/vdb/tablestore/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/api/tests/integration_tests/vdb/tcvectordb/__init__.py b/api/tests/integration_tests/vdb/tcvectordb/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/api/tests/integration_tests/vdb/tidb_vector/__init__.py b/api/tests/integration_tests/vdb/tidb_vector/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/api/tests/integration_tests/vdb/upstash/__init__.py b/api/tests/integration_tests/vdb/upstash/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/api/tests/integration_tests/vdb/vikingdb/__init__.py b/api/tests/integration_tests/vdb/vikingdb/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/api/tests/integration_tests/vdb/weaviate/__init__.py b/api/tests/integration_tests/vdb/weaviate/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/api/tests/integration_tests/workflow/nodes/__mock/model.py b/api/tests/integration_tests/workflow/nodes/__mock/model.py index c0143faa853..a9a2617baed 100644 --- a/api/tests/integration_tests/workflow/nodes/__mock/model.py +++ b/api/tests/integration_tests/workflow/nodes/__mock/model.py @@ -1,12 +1,11 @@ from unittest.mock import MagicMock -from graphon.model_runtime.entities.model_entities import ModelType - from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle from core.entities.provider_entities import CustomConfiguration, CustomProviderConfiguration, SystemConfiguration from core.model_manager import ModelInstance from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory +from graphon.model_runtime.entities.model_entities import ModelType from models.provider import ProviderType diff --git a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_executor.py b/api/tests/integration_tests/workflow/nodes/code_executor/test_code_executor.py deleted file mode 100644 index 487178ff580..00000000000 --- a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_executor.py +++ /dev/null @@ -1,11 +0,0 @@ -import pytest - -from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor - -CODE_LANGUAGE = "unsupported_language" - - -def test_unsupported_with_code_template(): - with pytest.raises(CodeExecutionError) as e: - CodeExecutor.execute_workflow_code_template(language=CODE_LANGUAGE, code="", inputs={}) - assert str(e.value) == f"Unsupported language {CODE_LANGUAGE}" diff --git a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_jinja2.py b/api/tests/integration_tests/workflow/nodes/code_executor/test_code_jinja2.py deleted file mode 100644 index c8eb9ec3e4d..00000000000 --- a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_jinja2.py +++ /dev/null @@ -1,95 +0,0 @@ -import base64 - -from core.helper.code_executor.code_executor import CodeExecutor, CodeLanguage -from core.helper.code_executor.jinja2.jinja2_transformer import Jinja2TemplateTransformer - -CODE_LANGUAGE = CodeLanguage.JINJA2 - - -def test_jinja2(): - """Test basic Jinja2 template rendering.""" - template = "Hello {{template}}" - # Template must be base64 encoded to match the new safe embedding approach - template_b64 = base64.b64encode(template.encode("utf-8")).decode("utf-8") - inputs = base64.b64encode(b'{"template": "World"}').decode("utf-8") - code = ( - Jinja2TemplateTransformer.get_runner_script() - .replace(Jinja2TemplateTransformer._template_b64_placeholder, template_b64) - .replace(Jinja2TemplateTransformer._inputs_placeholder, inputs) - ) - result = CodeExecutor.execute_code( - language=CODE_LANGUAGE, preload=Jinja2TemplateTransformer.get_preload_script(), code=code - ) - assert result == "<>Hello World<>\n" - - -def test_jinja2_with_code_template(): - """Test template rendering via the high-level workflow API.""" - result = CodeExecutor.execute_workflow_code_template( - language=CODE_LANGUAGE, code="Hello {{template}}", inputs={"template": "World"} - ) - assert result == {"result": "Hello World"} - - -def test_jinja2_get_runner_script(): - """Test that runner script contains required placeholders.""" - runner_script = Jinja2TemplateTransformer.get_runner_script() - assert runner_script.count(Jinja2TemplateTransformer._template_b64_placeholder) == 1 - assert runner_script.count(Jinja2TemplateTransformer._inputs_placeholder) == 1 - assert runner_script.count(Jinja2TemplateTransformer._result_tag) == 2 - - -def test_jinja2_template_with_special_characters(): - """ - Test that templates with special characters (quotes, newlines) render correctly. - This is a regression test for issue #26818 where textarea pre-fill values - containing special characters would break template rendering. - """ - # Template with triple quotes, single quotes, double quotes, and newlines - template = """ - - - -

Status: "{{ status }}"

-
'''code block'''
- -""" - inputs = {"task": {"Task ID": "TASK-123", "Issues": "Line 1\nLine 2\nLine 3"}, "status": "completed"} - - result = CodeExecutor.execute_workflow_code_template(language=CODE_LANGUAGE, code=template, inputs=inputs) - - # Verify the template rendered correctly with all special characters - output = result["result"] - assert 'value="TASK-123"' in output - assert "" in output - assert 'Status: "completed"' in output - assert "'''code block'''" in output - - -def test_jinja2_template_with_html_textarea_prefill(): - """ - Specific test for HTML textarea with Jinja2 variable pre-fill. - Verifies fix for issue #26818. - """ - template = "" - notes_content = "This is a multi-line note.\nWith special chars: 'single' and \"double\" quotes." - inputs = {"notes": notes_content} - - result = CodeExecutor.execute_workflow_code_template(language=CODE_LANGUAGE, code=template, inputs=inputs) - - expected_output = f"" - assert result["result"] == expected_output - - -def test_jinja2_assemble_runner_script_encodes_template(): - """Test that assemble_runner_script properly base64 encodes the template.""" - template = "Hello {{ name }}!" - inputs = {"name": "World"} - - script = Jinja2TemplateTransformer.assemble_runner_script(template, inputs) - - # The template should be base64 encoded in the script - template_b64 = base64.b64encode(template.encode("utf-8")).decode("utf-8") - assert template_b64 in script - # The raw template should NOT appear in the script (it's encoded) - assert "Hello {{ name }}!" not in script diff --git a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_python3.py b/api/tests/integration_tests/workflow/nodes/code_executor/test_code_python3.py deleted file mode 100644 index 25af312afa4..00000000000 --- a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_python3.py +++ /dev/null @@ -1,36 +0,0 @@ -from textwrap import dedent - -from core.helper.code_executor.code_executor import CodeExecutor, CodeLanguage -from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider -from core.helper.code_executor.python3.python3_transformer import Python3TemplateTransformer - -CODE_LANGUAGE = CodeLanguage.PYTHON3 - - -def test_python3_plain(): - code = 'print("Hello World")' - result = CodeExecutor.execute_code(language=CODE_LANGUAGE, preload="", code=code) - assert result == "Hello World\n" - - -def test_python3_json(): - code = dedent(""" - import json - print(json.dumps({'Hello': 'World'})) - """) - result = CodeExecutor.execute_code(language=CODE_LANGUAGE, preload="", code=code) - assert result == '{"Hello": "World"}\n' - - -def test_python3_with_code_template(): - result = CodeExecutor.execute_workflow_code_template( - language=CODE_LANGUAGE, code=Python3CodeProvider.get_default_code(), inputs={"arg1": "Hello", "arg2": "World"} - ) - assert result == {"result": "HelloWorld"} - - -def test_python3_get_runner_script(): - runner_script = Python3TemplateTransformer.get_runner_script() - assert runner_script.count(Python3TemplateTransformer._code_placeholder) == 1 - assert runner_script.count(Python3TemplateTransformer._inputs_placeholder) == 1 - assert runner_script.count(Python3TemplateTransformer._result_tag) == 2 diff --git a/api/tests/integration_tests/workflow/nodes/test_code.py b/api/tests/integration_tests/workflow/nodes/test_code.py index 4f41396c22b..aaa60929935 100644 --- a/api/tests/integration_tests/workflow/nodes/test_code.py +++ b/api/tests/integration_tests/workflow/nodes/test_code.py @@ -2,17 +2,18 @@ import time import uuid import pytest -from graphon.enums import WorkflowNodeExecutionStatus -from graphon.graph import Graph -from graphon.node_events import NodeRunResult -from graphon.nodes.code.code_node import CodeNode -from graphon.nodes.code.limits import CodeNodeLimits -from graphon.runtime import GraphRuntimeState, VariablePool from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.workflow.node_factory import DifyNodeFactory from core.workflow.system_variables import build_system_variables +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.graph import Graph +from graphon.node_events import NodeRunResult +from graphon.nodes.code.code_node import CodeNode +from graphon.nodes.code.entities import CodeNodeData +from graphon.nodes.code.limits import CodeNodeLimits +from graphon.runtime import GraphRuntimeState, VariablePool from tests.workflow_test_utils import build_test_graph_init_params pytest_plugins = ("tests.integration_tests.workflow.nodes.__mock.code_executor",) @@ -64,8 +65,8 @@ def init_code_node(code_config: dict): graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") node = CodeNode( - id=str(uuid.uuid4()), - config=code_config, + node_id=str(uuid.uuid4()), + config=CodeNodeData.model_validate(code_config["data"]), graph_init_params=init_params, graph_runtime_state=graph_runtime_state, code_executor=node_factory._code_executor, diff --git a/api/tests/integration_tests/workflow/nodes/test_http.py b/api/tests/integration_tests/workflow/nodes/test_http.py index b1f937e7388..b9f7b9575be 100644 --- a/api/tests/integration_tests/workflow/nodes/test_http.py +++ b/api/tests/integration_tests/workflow/nodes/test_http.py @@ -3,11 +3,6 @@ import uuid from urllib.parse import urlencode import pytest -from graphon.enums import WorkflowNodeExecutionStatus -from graphon.file.file_manager import file_manager -from graphon.graph import Graph -from graphon.nodes.http_request import HttpRequestNode, HttpRequestNodeConfig -from graphon.runtime import GraphRuntimeState, VariablePool from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom @@ -16,6 +11,11 @@ from core.tools.tool_file_manager import ToolFileManager from core.workflow.node_factory import DifyNodeFactory from core.workflow.node_runtime import DifyFileReferenceFactory from core.workflow.system_variables import build_system_variables +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.file.file_manager import file_manager +from graphon.graph import Graph +from graphon.nodes.http_request import HttpRequestNode, HttpRequestNodeConfig, HttpRequestNodeData +from graphon.runtime import GraphRuntimeState, VariablePool from tests.workflow_test_utils import build_test_graph_init_params pytest_plugins = ("tests.integration_tests.workflow.nodes.__mock.http",) @@ -75,8 +75,8 @@ def init_http_node(config: dict): graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") node = HttpRequestNode( - id=str(uuid.uuid4()), - config=config, + node_id=str(uuid.uuid4()), + config=HttpRequestNodeData.model_validate(config["data"]), graph_init_params=init_params, graph_runtime_state=graph_runtime_state, http_request_config=HTTP_REQUEST_CONFIG, @@ -192,6 +192,7 @@ def test_custom_authorization_header(setup_http_mock): @pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True) def test_custom_auth_with_empty_api_key_raises_error(setup_http_mock): """Test: In custom authentication mode, when the api_key is empty, AuthorizationConfigError should be raised.""" + from core.workflow.system_variables import build_system_variables from graphon.enums import BuiltinNodeTypes from graphon.nodes.http_request.entities import ( HttpRequestNodeAuthorization, @@ -202,8 +203,6 @@ def test_custom_auth_with_empty_api_key_raises_error(setup_http_mock): from graphon.nodes.http_request.executor import Executor from graphon.runtime import VariablePool - from core.workflow.system_variables import build_system_variables - # Create variable pool variable_pool = VariablePool( system_variables=build_system_variables(user_id="test", files=[]), @@ -724,8 +723,8 @@ def test_nested_object_variable_selector(setup_http_mock): graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") node = HttpRequestNode( - id=str(uuid.uuid4()), - config=graph_config["nodes"][1], + node_id=str(uuid.uuid4()), + config=HttpRequestNodeData.model_validate(graph_config["nodes"][1]["data"]), graph_init_params=init_params, graph_runtime_state=graph_runtime_state, http_request_config=HTTP_REQUEST_CONFIG, diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py index f0f3fcead19..3eead701639 100644 --- a/api/tests/integration_tests/workflow/nodes/test_llm.py +++ b/api/tests/integration_tests/workflow/nodes/test_llm.py @@ -4,20 +4,20 @@ import uuid from collections.abc import Generator from unittest.mock import MagicMock, patch +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.llm_generator.output_parser.structured_output import _parse_structured_output +from core.model_manager import ModelInstance +from core.workflow.system_variables import build_system_variables +from extensions.ext_database import db from graphon.enums import WorkflowNodeExecutionStatus from graphon.node_events import StreamCompletedEvent +from graphon.nodes.llm.entities import LLMNodeData from graphon.nodes.llm.file_saver import LLMFileSaver from graphon.nodes.llm.node import LLMNode from graphon.nodes.llm.protocols import CredentialsProvider, ModelFactory from graphon.nodes.llm.runtime_protocols import PromptMessageSerializerProtocol from graphon.nodes.protocols import HttpClientProtocol from graphon.runtime import GraphRuntimeState, VariablePool - -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom -from core.llm_generator.output_parser.structured_output import _parse_structured_output -from core.model_manager import ModelInstance -from core.workflow.system_variables import build_system_variables -from extensions.ext_database import db from tests.workflow_test_utils import build_test_graph_init_params """FOR MOCK FIXTURES, DO NOT REMOVE""" @@ -76,8 +76,8 @@ def init_llm_node(config: dict) -> LLMNode: llm_file_saver = MagicMock(spec=LLMFileSaver) node = LLMNode( - id=str(uuid.uuid4()), - config=config, + node_id=str(uuid.uuid4()), + config=LLMNodeData.model_validate(config["data"]), graph_init_params=init_params, graph_runtime_state=graph_runtime_state, credentials_provider=MagicMock(spec=CredentialsProvider), diff --git a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py index fe512c25856..f2eabb86c34 100644 --- a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py +++ b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py @@ -3,17 +3,17 @@ import time import uuid from unittest.mock import MagicMock -from graphon.enums import WorkflowNodeExecutionStatus -from graphon.model_runtime.entities import AssistantPromptMessage, UserPromptMessage -from graphon.nodes.llm.protocols import CredentialsProvider, ModelFactory -from graphon.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode -from graphon.runtime import GraphRuntimeState, VariablePool - from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.model_manager import ModelInstance from core.workflow.node_runtime import DifyPromptMessageSerializer from core.workflow.system_variables import build_system_variables from extensions.ext_database import db +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.model_runtime.entities import AssistantPromptMessage, UserPromptMessage +from graphon.nodes.llm.protocols import CredentialsProvider, ModelFactory +from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeData +from graphon.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode +from graphon.runtime import GraphRuntimeState, VariablePool from tests.integration_tests.workflow.nodes.__mock.model import get_mocked_fetch_model_instance from tests.workflow_test_utils import build_test_graph_init_params @@ -70,8 +70,8 @@ def init_parameter_extractor_node(config: dict, memory=None): graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) node = ParameterExtractorNode( - id=str(uuid.uuid4()), - config=config, + node_id=str(uuid.uuid4()), + config=ParameterExtractorNodeData.model_validate(config["data"]), graph_init_params=init_params, graph_runtime_state=graph_runtime_state, credentials_provider=MagicMock(spec=CredentialsProvider), diff --git a/api/tests/integration_tests/workflow/nodes/test_template_transform.py b/api/tests/integration_tests/workflow/nodes/test_template_transform.py index 2d728569bee..e2e0723fb86 100644 --- a/api/tests/integration_tests/workflow/nodes/test_template_transform.py +++ b/api/tests/integration_tests/workflow/nodes/test_template_transform.py @@ -1,15 +1,15 @@ import time import uuid -from graphon.enums import WorkflowNodeExecutionStatus -from graphon.graph import Graph -from graphon.nodes.template_transform.template_transform_node import TemplateTransformNode -from graphon.runtime import GraphRuntimeState, VariablePool -from graphon.template_rendering import TemplateRenderError - from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.workflow.node_factory import DifyNodeFactory from core.workflow.system_variables import build_system_variables +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.graph import Graph +from graphon.nodes.template_transform.entities import TemplateTransformNodeData +from graphon.nodes.template_transform.template_transform_node import TemplateTransformNode +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.template_rendering import TemplateRenderError from tests.workflow_test_utils import build_test_graph_init_params @@ -87,8 +87,8 @@ def test_execute_template_transform(): assert graph is not None node = TemplateTransformNode( - id=str(uuid.uuid4()), - config=config, + node_id=str(uuid.uuid4()), + config=TemplateTransformNodeData.model_validate(config["data"]), graph_init_params=init_params, graph_runtime_state=graph_runtime_state, jinja2_template_renderer=_SimpleJinja2Renderer(), diff --git a/api/tests/integration_tests/workflow/nodes/test_tool.py b/api/tests/integration_tests/workflow/nodes/test_tool.py index 750ced7075e..a8e9422c1ee 100644 --- a/api/tests/integration_tests/workflow/nodes/test_tool.py +++ b/api/tests/integration_tests/workflow/nodes/test_tool.py @@ -2,18 +2,18 @@ import time import uuid from unittest.mock import MagicMock, patch -from graphon.enums import WorkflowNodeExecutionStatus -from graphon.graph import Graph -from graphon.node_events import StreamCompletedEvent -from graphon.nodes.protocols import ToolFileManagerProtocol -from graphon.nodes.tool.tool_node import ToolNode -from graphon.runtime import GraphRuntimeState, VariablePool - from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.tools.utils.configuration import ToolParameterConfigurationManager from core.workflow.node_factory import DifyNodeFactory from core.workflow.node_runtime import DifyToolNodeRuntime from core.workflow.system_variables import build_system_variables +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.graph import Graph +from graphon.node_events import StreamCompletedEvent +from graphon.nodes.protocols import ToolFileManagerProtocol +from graphon.nodes.tool.entities import ToolNodeData +from graphon.nodes.tool.tool_node import ToolNode +from graphon.runtime import GraphRuntimeState, VariablePool from tests.workflow_test_utils import build_test_graph_init_params @@ -61,8 +61,8 @@ def init_tool_node(config: dict): tool_file_manager_factory = MagicMock(spec=ToolFileManagerProtocol) node = ToolNode( - id=str(uuid.uuid4()), - config=config, + node_id=str(uuid.uuid4()), + config=ToolNodeData.model_validate(config["data"]), graph_init_params=init_params, graph_runtime_state=graph_runtime_state, tool_file_manager_factory=tool_file_manager_factory, diff --git a/api/tests/test_containers_integration_tests/conftest.py b/api/tests/test_containers_integration_tests/conftest.py index ef74893f070..66a25e5dafb 100644 --- a/api/tests/test_containers_integration_tests/conftest.py +++ b/api/tests/test_containers_integration_tests/conftest.py @@ -369,7 +369,7 @@ def _create_app_with_containers() -> Flask: # Create and configure the Flask application logger.info("Initializing Flask application...") - app = create_app() + sio_app, app = create_app() logger.info("Flask application created successfully") # Initialize database schema diff --git a/api/tests/test_containers_integration_tests/controllers/console/app/test_app_apis.py b/api/tests/test_containers_integration_tests/controllers/console/app/test_app_apis.py index 0841217fcfe..18755ef012d 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/app/test_app_apis.py +++ b/api/tests/test_containers_integration_tests/controllers/console/app/test_app_apis.py @@ -234,6 +234,35 @@ class TestAppEndpoints: } ) + def test_app_icon_post_should_forward_icon_type(self, app, monkeypatch): + api = app_module.AppIconApi() + method = _unwrap(api.post) + payload = { + "icon": "https://example.com/icon.png", + "icon_type": "image", + "icon_background": "#FFFFFF", + } + app_service = MagicMock() + app_service.update_app_icon.return_value = SimpleNamespace() + response_model = MagicMock() + response_model.model_dump.return_value = {"id": "app-1"} + + monkeypatch.setattr(app_module, "AppService", lambda: app_service) + monkeypatch.setattr(app_module.AppDetail, "model_validate", MagicMock(return_value=response_model)) + + with ( + app.test_request_context("/console/api/apps/app-1/icon", method="POST", json=payload), + patch.object(type(console_ns), "payload", payload), + ): + response = method(app_model=SimpleNamespace()) + + assert response == {"id": "app-1"} + assert app_service.update_app_icon.call_args.args[1:] == ( + payload["icon"], + payload["icon_background"], + app_module.IconType.IMAGE, + ) + class TestOpsTraceEndpoints: @pytest.fixture @@ -313,6 +342,21 @@ class TestSiteEndpoints: method = _unwrap(api.post) site = MagicMock() + site.app_id = "app-1" + site.code = "test-code" + site.title = "My Site" + site.icon = None + site.icon_background = None + site.description = "Test site" + site.default_language = "en-US" + site.customize_domain = None + site.copyright = None + site.privacy_policy = None + site.custom_disclaimer = "" + site.customize_token_strategy = "not_allow" + site.prompt_public = False + site.show_workflow_steps = True + site.use_icon_as_answer_icon = False monkeypatch.setattr( site_module.db, "session", @@ -328,13 +372,29 @@ class TestSiteEndpoints: with app.test_request_context("/", json={"title": "My Site"}): result = method(app_model=SimpleNamespace(id="app-1")) - assert result is site + assert isinstance(result, dict) + assert result["title"] == "My Site" def test_app_site_access_token_reset(self, app, monkeypatch): api = site_module.AppSiteAccessTokenReset() method = _unwrap(api.post) site = MagicMock() + site.app_id = "app-1" + site.code = "old-code" + site.title = "My Site" + site.icon = None + site.icon_background = None + site.description = None + site.default_language = "en-US" + site.customize_domain = None + site.copyright = None + site.privacy_policy = None + site.custom_disclaimer = "" + site.customize_token_strategy = "not_allow" + site.prompt_public = False + site.show_workflow_steps = True + site.use_icon_as_answer_icon = False monkeypatch.setattr( site_module.db, "session", @@ -351,7 +411,8 @@ class TestSiteEndpoints: with app.test_request_context("/"): result = method(app_model=SimpleNamespace(id="app-1")) - assert result is site + assert isinstance(result, dict) + assert result["access_token"] == "code" class TestWorkflowEndpoints: @@ -400,7 +461,7 @@ class TestWorkflowAppLogEndpoints: monkeypatch.setattr(workflow_app_log_module, "sessionmaker", DummySessionMaker) def fake_get_paginate(self, **_kwargs): - return {"items": [], "total": 0} + return {"page": 1, "limit": 20, "total": 0, "has_more": False, "data": []} monkeypatch.setattr( workflow_app_log_module.WorkflowAppService, @@ -411,7 +472,7 @@ class TestWorkflowAppLogEndpoints: with app.test_request_context("/?page=1&limit=20"): result = method(app_model=SimpleNamespace(id="app-1")) - assert result == {"items": [], "total": 0} + assert result == {"page": 1, "limit": 20, "total": 0, "has_more": False, "data": []} class TestWorkflowDraftVariableEndpoints: @@ -555,7 +616,7 @@ class TestWorkflowTriggerEndpoints: trigger = MagicMock() session = MagicMock() - session.query.return_value.where.return_value.first.return_value = trigger + session.scalar.return_value = trigger class DummySessionCtx: def __enter__(self): @@ -576,7 +637,8 @@ class TestWorkflowTriggerEndpoints: with app.test_request_context("/?node_id=node-1"): result = method(app_model=SimpleNamespace(id="app-1")) - assert result is trigger + assert isinstance(result, dict) + assert {"id", "webhook_id", "webhook_url", "webhook_debug_url", "node_id", "created_at"} <= set(result.keys()) class TestWrapsEndpoints: diff --git a/api/tests/test_containers_integration_tests/controllers/console/app/test_app_import_api.py b/api/tests/test_containers_integration_tests/controllers/console/app/test_app_import_api.py index d8c6821f8db..25d19cf35a3 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/app/test_app_import_api.py +++ b/api/tests/test_containers_integration_tests/controllers/console/app/test_app_import_api.py @@ -96,6 +96,56 @@ class TestAppImportApi: assert status == 200 assert response["status"] == ImportStatus.COMPLETED + def test_import_post_commits_session_on_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + api = app_import_module.AppImportApi() + method = _unwrap(api.post) + + _install_features(monkeypatch, enabled=False) + monkeypatch.setattr( + app_import_module.AppDslService, + "import_app", + lambda *_args, **_kwargs: _Result(ImportStatus.COMPLETED, app_id="app-123"), + ) + monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1")) + + fake_session = MagicMock() + fake_session.__enter__.return_value = fake_session + fake_session.__exit__.return_value = None + monkeypatch.setattr(app_import_module, "Session", lambda *_args, **_kwargs: fake_session) + + with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}): + response, status = method() + + fake_session.commit.assert_called_once_with() + fake_session.rollback.assert_not_called() + assert status == 200 + assert response["status"] == ImportStatus.COMPLETED + + def test_import_post_rolls_back_session_on_failure(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + api = app_import_module.AppImportApi() + method = _unwrap(api.post) + + _install_features(monkeypatch, enabled=False) + monkeypatch.setattr( + app_import_module.AppDslService, + "import_app", + lambda *_args, **_kwargs: _Result(ImportStatus.FAILED, app_id=None), + ) + monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1")) + + fake_session = MagicMock() + fake_session.__enter__.return_value = fake_session + fake_session.__exit__.return_value = None + monkeypatch.setattr(app_import_module, "Session", lambda *_args, **_kwargs: fake_session) + + with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}): + response, status = method() + + fake_session.rollback.assert_called_once_with() + fake_session.commit.assert_not_called() + assert status == 400 + assert response["status"] == ImportStatus.FAILED + class TestAppImportConfirmApi: @pytest.fixture diff --git a/api/tests/test_containers_integration_tests/controllers/console/app/test_chat_conversation_status_count_api.py b/api/tests/test_containers_integration_tests/controllers/console/app/test_chat_conversation_status_count_api.py index 5cc458fe2ef..5a22f81a697 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/app/test_chat_conversation_status_count_api.py +++ b/api/tests/test_containers_integration_tests/controllers/console/app/test_chat_conversation_status_count_api.py @@ -4,15 +4,15 @@ import json import uuid from flask.testing import FlaskClient -from graphon.enums import WorkflowExecutionStatus from sqlalchemy.orm import Session from configs import dify_config from constants import HEADER_NAME_CSRF_TOKEN +from graphon.enums import WorkflowExecutionStatus from libs.datetime_utils import naive_utc_now from libs.token import _real_cookie_name, generate_csrf_token from models import Account, DifySetup, Tenant, TenantAccountJoin -from models.account import AccountStatus, TenantAccountRole +from models.account import AccountStatus, TenantAccountRole, TenantStatus from models.enums import ConversationFromSource, CreatorUserRole from models.model import App, AppMode, Conversation, Message from models.workflow import WorkflowRun @@ -30,7 +30,7 @@ def _create_account_and_tenant(db_session: Session) -> tuple[Account, Tenant]: db_session.add(account) db_session.commit() - tenant = Tenant(name="Test Tenant", status="normal") + tenant = Tenant(name="Test Tenant", status=TenantStatus.NORMAL) db_session.add(tenant) db_session.commit() diff --git a/api/tests/test_containers_integration_tests/controllers/console/app/test_conversation_read_timestamp.py b/api/tests/test_containers_integration_tests/controllers/console/app/test_conversation_read_timestamp.py new file mode 100644 index 00000000000..fad0b8b10e2 --- /dev/null +++ b/api/tests/test_containers_integration_tests/controllers/console/app/test_conversation_read_timestamp.py @@ -0,0 +1,73 @@ +from datetime import datetime +from unittest.mock import patch + +import pytest +from sqlalchemy.orm import Session +from werkzeug.exceptions import NotFound + +from controllers.console.app.conversation import _get_conversation +from models.enums import ConversationFromSource +from models.model import AppMode, Conversation +from tests.test_containers_integration_tests.controllers.console.helpers import ( + create_console_account_and_tenant, + create_console_app, +) + + +def test_get_conversation_mark_read_keeps_updated_at_unchanged( + db_session_with_containers: Session, +): + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + + original_updated_at = datetime(2026, 2, 8, 0, 0, 0) + conversation = Conversation( + app_id=app.id, + name="read timestamp test", + inputs={}, + status="normal", + mode=AppMode.CHAT, + from_source=ConversationFromSource.CONSOLE, + from_account_id=account.id, + updated_at=original_updated_at, + ) + db_session_with_containers.add(conversation) + db_session_with_containers.commit() + + read_at = datetime(2026, 2, 9, 0, 0, 0) + + with ( + patch( + "controllers.console.app.conversation.current_account_with_tenant", + return_value=(account, tenant.id), + autospec=True, + ), + patch( + "controllers.console.app.conversation.naive_utc_now", + return_value=read_at, + autospec=True, + ), + ): + loaded = _get_conversation(app, conversation.id) + + db_session_with_containers.refresh(conversation) + + assert loaded.id == conversation.id + assert conversation.read_at == read_at + assert conversation.read_account_id == account.id + assert conversation.updated_at == original_updated_at + + +def test_get_conversation_raises_not_found_for_missing_conversation( + db_session_with_containers: Session, +): + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + + with patch( + "controllers.console.app.conversation.current_account_with_tenant", + return_value=(account, tenant.id), + autospec=True, + ): + with pytest.raises(NotFound): + _get_conversation(app, "00000000-0000-0000-0000-000000000000") diff --git a/api/tests/test_containers_integration_tests/controllers/console/app/test_message.py b/api/tests/test_containers_integration_tests/controllers/console/app/test_message.py index 6b51ec98bcd..eff6dd789d7 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/app/test_message.py +++ b/api/tests/test_containers_integration_tests/controllers/console/app/test_message.py @@ -148,14 +148,18 @@ def test_chat_message_list_success( account.id, created_at_offset_seconds=1, ) + # Capture IDs before the HTTP request detaches ORM instances from the session + app_id = app.id + conversation_id = conversation.id + second_id = second.id with patch( "controllers.console.app.message.attach_message_extra_contents", side_effect=_attach_message_extra_contents, ): response = test_client_with_containers.get( - f"/console/api/apps/{app.id}/chat-messages", - query_string={"conversation_id": conversation.id, "limit": 1}, + f"/console/api/apps/{app_id}/chat-messages", + query_string={"conversation_id": conversation_id, "limit": 1}, headers=authenticate_console_client(test_client_with_containers, account), ) @@ -165,7 +169,7 @@ def test_chat_message_list_success( assert payload["limit"] == 1 assert payload["has_more"] is True assert len(payload["data"]) == 1 - assert payload["data"][0]["id"] == second.id + assert payload["data"][0]["id"] == second_id def test_message_feedback_not_found( diff --git a/api/tests/test_containers_integration_tests/controllers/console/app/test_workflow_draft_variable.py b/api/tests/test_containers_integration_tests/controllers/console/app/test_workflow_draft_variable.py index 8ddf867370e..290be876974 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/app/test_workflow_draft_variable.py +++ b/api/tests/test_containers_integration_tests/controllers/console/app/test_workflow_draft_variable.py @@ -3,12 +3,12 @@ import uuid from flask.testing import FlaskClient -from graphon.variables.segments import StringSegment from sqlalchemy import select from sqlalchemy.orm import Session from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID from factories.variable_factory import segment_to_variable +from graphon.variables.segments import StringSegment from models import Workflow from models.model import AppMode from models.workflow import WorkflowDraftVariable diff --git a/api/tests/test_containers_integration_tests/controllers/console/auth/test_email_register.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_email_register.py index 879c337319c..320da85b601 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/auth/test_email_register.py +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_email_register.py @@ -158,7 +158,10 @@ def test_get_account_by_email_with_case_fallback_falls_back_to_lowercase(): second_result.scalar_one_or_none.return_value = expected_account mock_session.execute.side_effect = [first_result, second_result] - result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com", session=mock_session) + with patch("services.account_service.session_factory") as mock_factory: + mock_factory.create_session.return_value.__enter__ = MagicMock(return_value=mock_session) + mock_factory.create_session.return_value.__exit__ = MagicMock(return_value=False) + result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com") assert result is expected_account assert mock_session.execute.call_count == 2 diff --git a/api/tests/test_containers_integration_tests/controllers/console/auth/test_forgot_password.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_forgot_password.py index 7b7393dade9..d2703ed5cce 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/auth/test_forgot_password.py +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_forgot_password.py @@ -113,12 +113,14 @@ class TestForgotPasswordCheckApi: class TestForgotPasswordResetApi: @patch("controllers.console.auth.forgot_password.ForgotPasswordResetApi._update_existing_account") @patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback") + @patch("controllers.console.auth.forgot_password.db") @patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") def test_reset_fetches_account_with_original_email( self, mock_get_reset_data, mock_revoke_token, + mock_db, mock_get_account, mock_update_account, app, @@ -126,6 +128,7 @@ class TestForgotPasswordResetApi: mock_get_reset_data.return_value = {"phase": "reset", "email": "User@Example.com"} mock_account = MagicMock() mock_get_account.return_value = mock_account + mock_db.session.merge.return_value = mock_account wraps_features = SimpleNamespace(enable_email_password_login=True) with ( @@ -161,7 +164,10 @@ def test_get_account_by_email_with_case_fallback_falls_back_to_lowercase(): second_result.scalar_one_or_none.return_value = expected_account mock_session.execute.side_effect = [first_result, second_result] - result = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com", session=mock_session) + with patch("services.account_service.session_factory") as mock_factory: + mock_factory.create_session.return_value.__enter__ = MagicMock(return_value=mock_session) + mock_factory.create_session.return_value.__exit__ = MagicMock(return_value=False) + result = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com") assert result is expected_account assert mock_session.execute.call_count == 2 diff --git a/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py index a2f1328579f..1eabb45422a 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py @@ -437,7 +437,10 @@ class TestAccountGeneration: second_result.scalar_one_or_none.return_value = expected_account mock_session.execute.side_effect = [first_result, second_result] - result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com", session=mock_session) + with patch("services.account_service.session_factory") as mock_factory: + mock_factory.create_session.return_value.__enter__ = MagicMock(return_value=mock_session) + mock_factory.create_session.return_value.__exit__ = MagicMock(return_value=False) + result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com") assert result is expected_account assert mock_session.execute.call_count == 2 diff --git a/api/tests/test_containers_integration_tests/controllers/console/auth/test_password_reset.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_password_reset.py index 8f9db287e31..50249bcd743 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/auth/test_password_reset.py +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_password_reset.py @@ -335,10 +335,12 @@ class TestForgotPasswordResetApi: @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") @patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token") @patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback") + @patch("controllers.console.auth.forgot_password.db") @patch("controllers.console.auth.forgot_password.TenantService.get_join_tenants") def test_reset_password_success( self, mock_get_tenants, + mock_db, mock_get_account, mock_revoke_token, mock_get_data, @@ -356,6 +358,7 @@ class TestForgotPasswordResetApi: # Arrange mock_get_data.return_value = {"email": "test@example.com", "phase": "reset"} mock_get_account.return_value = mock_account + mock_db.session.merge.return_value = mock_account mock_get_tenants.return_value = [MagicMock()] # Act diff --git a/api/tests/test_containers_integration_tests/controllers/console/helpers.py b/api/tests/test_containers_integration_tests/controllers/console/helpers.py index 9e2084f3939..a8ecf94da19 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/helpers.py +++ b/api/tests/test_containers_integration_tests/controllers/console/helpers.py @@ -11,7 +11,7 @@ from constants import HEADER_NAME_CSRF_TOKEN from libs.datetime_utils import naive_utc_now from libs.token import _real_cookie_name, generate_csrf_token from models import Account, DifySetup, Tenant, TenantAccountJoin -from models.account import AccountStatus, TenantAccountRole +from models.account import AccountStatus, TenantAccountRole, TenantStatus from models.model import App, AppMode from services.account_service import AccountService @@ -37,7 +37,7 @@ def create_console_account_and_tenant(db_session: Session) -> tuple[Account, Ten db_session.add(account) db_session.commit() - tenant = Tenant(name="Test Tenant", status="normal") + tenant = Tenant(name="Test Tenant", status=TenantStatus.NORMAL) db_session.add(tenant) db_session.commit() diff --git a/api/tests/test_containers_integration_tests/controllers/mcp/test_mcp.py b/api/tests/test_containers_integration_tests/controllers/mcp/test_mcp.py index 90670a9db59..21b395a04c6 100644 --- a/api/tests/test_containers_integration_tests/controllers/mcp/test_mcp.py +++ b/api/tests/test_containers_integration_tests/controllers/mcp/test_mcp.py @@ -444,7 +444,7 @@ class TestMCPAppApi: ) session = MagicMock() - session.query().where().first.side_effect = [server, app] + session.scalar.side_effect = [server, app] result_server, result_app = api._get_mcp_server_and_app("server-1", session) diff --git a/api/tests/integration_tests/vdb/iris/__init__.py b/api/tests/test_containers_integration_tests/controllers/service_api/__init__.py similarity index 100% rename from api/tests/integration_tests/vdb/iris/__init__.py rename to api/tests/test_containers_integration_tests/controllers/service_api/__init__.py diff --git a/api/tests/integration_tests/vdb/lindorm/__init__.py b/api/tests/test_containers_integration_tests/controllers/service_api/dataset/__init__.py similarity index 100% rename from api/tests/integration_tests/vdb/lindorm/__init__.py rename to api/tests/test_containers_integration_tests/controllers/service_api/dataset/__init__.py diff --git a/api/tests/unit_tests/controllers/service_api/dataset/test_dataset.py b/api/tests/test_containers_integration_tests/controllers/service_api/dataset/test_dataset.py similarity index 50% rename from api/tests/unit_tests/controllers/service_api/dataset/test_dataset.py rename to api/tests/test_containers_integration_tests/controllers/service_api/dataset/test_dataset.py index 910d781cd02..9b913d6d3d6 100644 --- a/api/tests/unit_tests/controllers/service_api/dataset/test_dataset.py +++ b/api/tests/test_containers_integration_tests/controllers/service_api/dataset/test_dataset.py @@ -1,17 +1,16 @@ """ -Unit tests for Service API Dataset controllers. +Integration tests for Service API Dataset controllers. + +Migrated from unit_tests/controllers/service_api/dataset/test_dataset.py. Tests coverage for: - DatasetCreatePayload, DatasetUpdatePayload Pydantic models - Tag-related payloads (create, update, delete, binding) - DatasetListQuery model -- DatasetService and TagService interfaces -- Permission validation patterns +- API endpoint error handling and controller behavior -Focus on: -- Pydantic model validation -- Error type mappings -- Service method interfaces +Services (DatasetService, TagService, DocumentService) remain mocked +since these test controller-level behavior. """ import uuid @@ -19,6 +18,7 @@ from types import SimpleNamespace from unittest.mock import Mock, patch import pytest +from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden, NotFound import services @@ -36,22 +36,23 @@ from controllers.service_api.dataset.error import DatasetInUseError, DatasetName from models.account import Account from models.dataset import DatasetPermissionEnum from models.enums import TagType -from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService -from services.tag_service import TagService +from models.model import Tag + +# --------------------------------------------------------------------------- +# Pydantic model validation tests +# --------------------------------------------------------------------------- class TestDatasetCreatePayload: """Test suite for DatasetCreatePayload Pydantic model.""" def test_payload_with_required_name(self): - """Test payload with required name field.""" payload = DatasetCreatePayload(name="Test Dataset") assert payload.name == "Test Dataset" assert payload.description == "" assert payload.permission == DatasetPermissionEnum.ONLY_ME def test_payload_with_all_fields(self): - """Test payload with all fields populated.""" payload = DatasetCreatePayload( name="Full Dataset", description="A comprehensive dataset description", @@ -70,28 +71,23 @@ class TestDatasetCreatePayload: assert payload.embedding_model_provider == "openai" def test_payload_name_length_validation_min(self): - """Test name minimum length validation.""" with pytest.raises(ValueError): DatasetCreatePayload(name="") def test_payload_name_length_validation_max(self): - """Test name maximum length validation (40 chars).""" with pytest.raises(ValueError): DatasetCreatePayload(name="A" * 41) def test_payload_description_max_length(self): - """Test description maximum length (400 chars).""" with pytest.raises(ValueError): DatasetCreatePayload(name="Dataset", description="A" * 401) @pytest.mark.parametrize("technique", ["high_quality", "economy"]) def test_payload_valid_indexing_techniques(self, technique): - """Test valid indexing technique values.""" payload = DatasetCreatePayload(name="Dataset", indexing_technique=technique) assert payload.indexing_technique == technique def test_payload_with_external_knowledge_settings(self): - """Test payload with external knowledge configuration.""" payload = DatasetCreatePayload( name="External Dataset", external_knowledge_api_id="api_123", external_knowledge_id="knowledge_456" ) @@ -103,20 +99,17 @@ class TestDatasetUpdatePayload: """Test suite for DatasetUpdatePayload Pydantic model.""" def test_payload_all_optional(self): - """Test payload with all fields optional.""" payload = DatasetUpdatePayload() assert payload.name is None assert payload.description is None assert payload.permission is None def test_payload_with_partial_update(self): - """Test payload with partial update fields.""" payload = DatasetUpdatePayload(name="Updated Name", description="Updated description") assert payload.name == "Updated Name" assert payload.description == "Updated description" def test_payload_with_permission_change(self): - """Test payload with permission update.""" payload = DatasetUpdatePayload( permission=DatasetPermissionEnum.PARTIAL_TEAM, partial_member_list=[{"user_id": "user_123", "role": "editor"}], @@ -125,12 +118,8 @@ class TestDatasetUpdatePayload: assert len(payload.partial_member_list) == 1 def test_payload_name_length_validation(self): - """Test name length constraints.""" - # Minimum is 1 with pytest.raises(ValueError): DatasetUpdatePayload(name="") - - # Maximum is 40 with pytest.raises(ValueError): DatasetUpdatePayload(name="A" * 41) @@ -139,7 +128,6 @@ class TestDatasetListQuery: """Test suite for DatasetListQuery Pydantic model.""" def test_query_with_defaults(self): - """Test query with default values.""" query = DatasetListQuery() assert query.page == 1 assert query.limit == 20 @@ -148,7 +136,6 @@ class TestDatasetListQuery: assert query.tag_ids == [] def test_query_with_all_filters(self): - """Test query with all filter fields.""" query = DatasetListQuery( page=3, limit=50, keyword="machine learning", include_all=True, tag_ids=["tag1", "tag2", "tag3"] ) @@ -159,7 +146,6 @@ class TestDatasetListQuery: assert len(query.tag_ids) == 3 def test_query_with_tag_filter(self): - """Test query with tag IDs filter.""" query = DatasetListQuery(tag_ids=["tag_abc", "tag_def"]) assert query.tag_ids == ["tag_abc", "tag_def"] @@ -168,22 +154,18 @@ class TestTagCreatePayload: """Test suite for TagCreatePayload Pydantic model.""" def test_payload_with_name(self): - """Test payload with required name.""" payload = TagCreatePayload(name="New Tag") assert payload.name == "New Tag" def test_payload_name_length_min(self): - """Test name minimum length (1).""" with pytest.raises(ValueError): TagCreatePayload(name="") def test_payload_name_length_max(self): - """Test name maximum length (50).""" with pytest.raises(ValueError): TagCreatePayload(name="A" * 51) def test_payload_with_unicode_name(self): - """Test payload with unicode characters.""" payload = TagCreatePayload(name="标签 🏷️ Тег") assert payload.name == "标签 🏷️ Тег" @@ -192,13 +174,11 @@ class TestTagUpdatePayload: """Test suite for TagUpdatePayload Pydantic model.""" def test_payload_with_name_and_id(self): - """Test payload with name and tag_id.""" payload = TagUpdatePayload(name="Updated Tag", tag_id="tag_123") assert payload.name == "Updated Tag" assert payload.tag_id == "tag_123" def test_payload_requires_tag_id(self): - """Test that tag_id is required.""" with pytest.raises(ValueError): TagUpdatePayload(name="Updated Tag") @@ -207,12 +187,10 @@ class TestTagDeletePayload: """Test suite for TagDeletePayload Pydantic model.""" def test_payload_with_tag_id(self): - """Test payload with tag_id.""" payload = TagDeletePayload(tag_id="tag_to_delete") assert payload.tag_id == "tag_to_delete" def test_payload_requires_tag_id(self): - """Test that tag_id is required.""" with pytest.raises(ValueError): TagDeletePayload() @@ -221,19 +199,16 @@ class TestTagBindingPayload: """Test suite for TagBindingPayload Pydantic model.""" def test_payload_with_valid_data(self): - """Test payload with valid binding data.""" payload = TagBindingPayload(tag_ids=["tag1", "tag2"], target_id="dataset_123") assert len(payload.tag_ids) == 2 assert payload.target_id == "dataset_123" def test_payload_rejects_empty_tag_ids(self): - """Test that empty tag_ids are rejected.""" with pytest.raises(ValueError) as exc_info: TagBindingPayload(tag_ids=[], target_id="dataset_123") assert "Tag IDs is required" in str(exc_info.value) def test_payload_single_tag_id(self): - """Test payload with single tag ID.""" payload = TagBindingPayload(tag_ids=["single_tag"], target_id="dataset_456") assert payload.tag_ids == ["single_tag"] @@ -242,674 +217,14 @@ class TestTagUnbindingPayload: """Test suite for TagUnbindingPayload Pydantic model.""" def test_payload_with_valid_data(self): - """Test payload with valid unbinding data.""" payload = TagUnbindingPayload(tag_id="tag_123", target_id="dataset_456") assert payload.tag_id == "tag_123" assert payload.target_id == "dataset_456" -class TestDatasetTagsApi: - """Test suite for DatasetTagsApi endpoints.""" - - @pytest.fixture - def app(self): - """Create Flask test application.""" - from flask import Flask - - app = Flask(__name__) - app.config["TESTING"] = True - return app - - @patch("controllers.service_api.dataset.dataset.current_user") - @patch("controllers.service_api.dataset.dataset.TagService") - def test_get_tags_success(self, mock_tag_service, mock_current_user, app): - """Test successful retrieval of dataset tags.""" - # Arrange - mock_current_user needs to pass isinstance(current_user, Account) - from models.account import Account - - mock_account = Mock(spec=Account) - mock_account.current_tenant_id = "tenant_123" - # Replace the mock with our properly specced one - from controllers.service_api.dataset import dataset as dataset_module - - original_current_user = dataset_module.current_user - dataset_module.current_user = mock_account - - mock_tag = Mock() - mock_tag.id = "tag_1" - mock_tag.name = "Test Tag" - mock_tag.type = TagType.KNOWLEDGE - mock_tag.binding_count = "0" # Required for Pydantic validation - must be string - mock_tag_service.get_tags.return_value = [mock_tag] - - from controllers.service_api.dataset.dataset import DatasetTagsApi - - try: - # Act - with app.test_request_context("/", method="GET"): - api = DatasetTagsApi() - response, status_code = api.get("tenant_123") - - # Assert - assert status_code == 200 - assert len(response) == 1 - assert response[0]["id"] == "tag_1" - assert response[0]["name"] == "Test Tag" - mock_tag_service.get_tags.assert_called_once_with("knowledge", "tenant_123") - finally: - dataset_module.current_user = original_current_user - - @pytest.mark.skip(reason="Production code bug: binding_count should be string, not integer") - @patch("controllers.service_api.dataset.dataset.TagService") - @patch("controllers.service_api.dataset.dataset.service_api_ns") - def test_create_tag_success(self, mock_service_api_ns, mock_tag_service, app): - """Test successful creation of a dataset tag.""" - # Arrange - from controllers.service_api.dataset import dataset as dataset_module - from models.account import Account - - mock_account = Mock(spec=Account) - mock_account.has_edit_permission = True - mock_account.is_dataset_editor = False - original_current_user = dataset_module.current_user - dataset_module.current_user = mock_account - - mock_tag = Mock() - mock_tag.id = "new_tag_1" - mock_tag.name = "New Tag" - mock_tag.type = TagType.KNOWLEDGE - mock_tag_service.save_tags.return_value = mock_tag - mock_service_api_ns.payload = {"name": "New Tag"} - - from controllers.service_api.dataset.dataset import DatasetTagsApi - - try: - # Act - with app.test_request_context("/", method="POST", json={"name": "New Tag"}): - api = DatasetTagsApi() - response, status_code = api.post("tenant_123") - - # Assert - assert status_code == 200 - assert response["id"] == "new_tag_1" - assert response["name"] == "New Tag" - assert response["binding_count"] == 0 - finally: - dataset_module.current_user = original_current_user - - def test_create_tag_forbidden(self, app): - """Test tag creation without edit permissions.""" - # Arrange - from werkzeug.exceptions import Forbidden - - from controllers.service_api.dataset import dataset as dataset_module - from models.account import Account - - mock_account = Mock(spec=Account) - mock_account.has_edit_permission = False - mock_account.is_dataset_editor = False - original_current_user = dataset_module.current_user - dataset_module.current_user = mock_account - - from controllers.service_api.dataset.dataset import DatasetTagsApi - - try: - # Act & Assert - with app.test_request_context("/", method="POST"): - api = DatasetTagsApi() - with pytest.raises(Forbidden): - api.post("tenant_123") - finally: - dataset_module.current_user = original_current_user - - @pytest.mark.skip(reason="Production code bug: binding_count should be string, not integer") - @patch("controllers.service_api.dataset.dataset.TagService") - @patch("controllers.service_api.dataset.dataset.service_api_ns") - def test_update_tag_success(self, mock_service_api_ns, mock_tag_service, app): - """Test successful update of a dataset tag.""" - # Arrange - from controllers.service_api.dataset import dataset as dataset_module - from models.account import Account - - mock_account = Mock(spec=Account) - mock_account.has_edit_permission = True - original_current_user = dataset_module.current_user - dataset_module.current_user = mock_account - - mock_tag = Mock() - mock_tag.id = "tag_1" - mock_tag.name = "Updated Tag" - mock_tag.type = TagType.KNOWLEDGE - mock_tag.binding_count = "5" - mock_tag_service.update_tags.return_value = mock_tag - mock_tag_service.get_tag_binding_count.return_value = 5 - mock_service_api_ns.payload = {"name": "Updated Tag", "tag_id": "tag_1"} - - from controllers.service_api.dataset.dataset import DatasetTagsApi - - try: - # Act - with app.test_request_context("/", method="PATCH", json={"name": "Updated Tag", "tag_id": "tag_1"}): - api = DatasetTagsApi() - response, status_code = api.patch("tenant_123") - - # Assert - assert status_code == 200 - assert response["id"] == "tag_1" - assert response["name"] == "Updated Tag" - assert response["binding_count"] == 5 - finally: - dataset_module.current_user = original_current_user - - @pytest.mark.skip(reason="Production code bug: binding_count should be string, not integer") - @patch("controllers.service_api.dataset.dataset.TagService") - @patch("controllers.service_api.dataset.dataset.service_api_ns") - def test_delete_tag_success(self, mock_service_api_ns, mock_tag_service, app): - """Test successful deletion of a dataset tag.""" - # Arrange - from controllers.service_api.dataset import dataset as dataset_module - from models.account import Account - - mock_account = Mock(spec=Account) - mock_account.has_edit_permission = True - original_current_user = dataset_module.current_user - dataset_module.current_user = mock_account - - mock_tag_service.delete_tag.return_value = None - mock_service_api_ns.payload = {"tag_id": "tag_1"} - - from controllers.service_api.dataset.dataset import DatasetTagsApi - - try: - # Act - with app.test_request_context("/", method="DELETE", json={"tag_id": "tag_1"}): - api = DatasetTagsApi() - response = api.delete("tenant_123") - - # Assert - assert response == ("", 204) - mock_tag_service.delete_tag.assert_called_once_with("tag_1") - finally: - dataset_module.current_user = original_current_user - - -class TestDatasetTagBindingApi: - """Test suite for DatasetTagBindingApi endpoints.""" - - @pytest.fixture - def app(self): - """Create Flask test application.""" - from flask import Flask - - app = Flask(__name__) - app.config["TESTING"] = True - return app - - @patch("controllers.service_api.dataset.dataset.TagService") - @patch("controllers.service_api.dataset.dataset.service_api_ns") - def test_bind_tags_success(self, mock_service_api_ns, mock_tag_service, app): - """Test successful binding of tags to dataset.""" - # Arrange - from controllers.service_api.dataset import dataset as dataset_module - from models.account import Account - - mock_account = Mock(spec=Account) - mock_account.has_edit_permission = True - mock_account.is_dataset_editor = False - original_current_user = dataset_module.current_user - dataset_module.current_user = mock_account - - mock_tag_service.save_tag_binding.return_value = None - payload = {"tag_ids": ["tag_1", "tag_2"], "target_id": "dataset_123"} - mock_service_api_ns.payload = payload - - from controllers.service_api.dataset.dataset import DatasetTagBindingApi - - try: - # Act - with app.test_request_context("/", method="POST", json=payload): - api = DatasetTagBindingApi() - response = api.post("tenant_123") - - # Assert - assert response == ("", 204) - mock_tag_service.save_tag_binding.assert_called_once_with( - {"tag_ids": ["tag_1", "tag_2"], "target_id": "dataset_123", "type": "knowledge"} - ) - finally: - dataset_module.current_user = original_current_user - - def test_bind_tags_forbidden(self, app): - """Test tag binding without edit permissions.""" - # Arrange - from werkzeug.exceptions import Forbidden - - from controllers.service_api.dataset import dataset as dataset_module - from models.account import Account - - mock_account = Mock(spec=Account) - mock_account.has_edit_permission = False - mock_account.is_dataset_editor = False - original_current_user = dataset_module.current_user - dataset_module.current_user = mock_account - - from controllers.service_api.dataset.dataset import DatasetTagBindingApi - - try: - # Act & Assert - with app.test_request_context("/", method="POST"): - api = DatasetTagBindingApi() - with pytest.raises(Forbidden): - api.post("tenant_123") - finally: - dataset_module.current_user = original_current_user - - -class TestDatasetTagUnbindingApi: - """Test suite for DatasetTagUnbindingApi endpoints.""" - - @pytest.fixture - def app(self): - """Create Flask test application.""" - from flask import Flask - - app = Flask(__name__) - app.config["TESTING"] = True - return app - - @patch("controllers.service_api.dataset.dataset.TagService") - @patch("controllers.service_api.dataset.dataset.service_api_ns") - def test_unbind_tag_success(self, mock_service_api_ns, mock_tag_service, app): - """Test successful unbinding of tag from dataset.""" - # Arrange - from controllers.service_api.dataset import dataset as dataset_module - from models.account import Account - - mock_account = Mock(spec=Account) - mock_account.has_edit_permission = True - mock_account.is_dataset_editor = False - original_current_user = dataset_module.current_user - dataset_module.current_user = mock_account - - mock_tag_service.delete_tag_binding.return_value = None - payload = {"tag_id": "tag_1", "target_id": "dataset_123"} - mock_service_api_ns.payload = payload - - from controllers.service_api.dataset.dataset import DatasetTagUnbindingApi - - try: - # Act - with app.test_request_context("/", method="POST", json=payload): - api = DatasetTagUnbindingApi() - response = api.post("tenant_123") - - # Assert - assert response == ("", 204) - mock_tag_service.delete_tag_binding.assert_called_once_with( - {"tag_id": "tag_1", "target_id": "dataset_123", "type": "knowledge"} - ) - finally: - dataset_module.current_user = original_current_user - - -class TestDatasetTagsBindingStatusApi: - """Test suite for DatasetTagsBindingStatusApi endpoints.""" - - @pytest.fixture - def app(self): - """Create Flask test application.""" - from flask import Flask - - app = Flask(__name__) - app.config["TESTING"] = True - return app - - @patch("controllers.service_api.dataset.dataset.TagService") - def test_get_dataset_tags_binding_status(self, mock_tag_service, app): - """Test retrieval of tags bound to a specific dataset.""" - # Arrange - from controllers.service_api.dataset import dataset as dataset_module - from models.account import Account - - mock_account = Mock(spec=Account) - mock_account.current_tenant_id = "tenant_123" - original_current_user = dataset_module.current_user - dataset_module.current_user = mock_account - - mock_tag = Mock() - mock_tag.id = "tag_1" - mock_tag.name = "Test Tag" - mock_tag_service.get_tags_by_target_id.return_value = [mock_tag] - - from controllers.service_api.dataset.dataset import DatasetTagsBindingStatusApi - - try: - # Act - with app.test_request_context("/", method="GET"): - api = DatasetTagsBindingStatusApi() - response, status_code = api.get("tenant_123", dataset_id="dataset_123") - - # Assert - assert status_code == 200 - assert response["data"] == [{"id": "tag_1", "name": "Test Tag"}] - assert response["total"] == 1 - mock_tag_service.get_tags_by_target_id.assert_called_once_with("knowledge", "tenant_123", "dataset_123") - finally: - dataset_module.current_user = original_current_user - - -class TestDocumentStatusApi: - """Test suite for DocumentStatusApi batch operations.""" - - @pytest.fixture - def app(self): - """Create Flask test application.""" - from flask import Flask - - app = Flask(__name__) - app.config["TESTING"] = True - return app - - @patch("controllers.service_api.dataset.dataset.DatasetService") - @patch("controllers.service_api.dataset.dataset.DocumentService") - def test_batch_enable_documents(self, mock_doc_service, mock_dataset_service, app): - """Test batch enabling documents.""" - # Arrange - mock_dataset = Mock() - mock_dataset_service.get_dataset.return_value = mock_dataset - mock_doc_service.batch_update_document_status.return_value = None - - from controllers.service_api.dataset.dataset import DocumentStatusApi - - # Act - with app.test_request_context("/", method="PATCH", json={"document_ids": ["doc_1", "doc_2"]}): - api = DocumentStatusApi() - response, status_code = api.patch("tenant_123", "dataset_123", "enable") - - # Assert - assert status_code == 200 - assert response == {"result": "success"} - mock_doc_service.batch_update_document_status.assert_called_once() - - @patch("controllers.service_api.dataset.dataset.DatasetService") - def test_batch_update_dataset_not_found(self, mock_dataset_service, app): - """Test batch update when dataset not found.""" - # Arrange - mock_dataset_service.get_dataset.return_value = None - - from werkzeug.exceptions import NotFound - - from controllers.service_api.dataset.dataset import DocumentStatusApi - - # Act & Assert - with app.test_request_context("/", method="PATCH", json={"document_ids": ["doc_1"]}): - api = DocumentStatusApi() - with pytest.raises(NotFound) as exc_info: - api.patch("tenant_123", "dataset_123", "enable") - assert "Dataset not found" in str(exc_info.value) - - @patch("controllers.service_api.dataset.dataset.DatasetService") - @patch("controllers.service_api.dataset.dataset.DocumentService") - def test_batch_update_permission_error(self, mock_doc_service, mock_dataset_service, app): - """Test batch update with permission error.""" - # Arrange - mock_dataset = Mock() - mock_dataset_service.get_dataset.return_value = mock_dataset - from services.errors.account import NoPermissionError - - mock_dataset_service.check_dataset_permission.side_effect = NoPermissionError("No permission") - - from werkzeug.exceptions import Forbidden - - from controllers.service_api.dataset.dataset import DocumentStatusApi - - # Act & Assert - with app.test_request_context("/", method="PATCH", json={"document_ids": ["doc_1"]}): - api = DocumentStatusApi() - with pytest.raises(Forbidden): - api.patch("tenant_123", "dataset_123", "enable") - - @patch("controllers.service_api.dataset.dataset.DatasetService") - @patch("controllers.service_api.dataset.dataset.DocumentService") - def test_batch_update_invalid_action(self, mock_doc_service, mock_dataset_service, app): - """Test batch update with invalid action error.""" - # Arrange - mock_dataset = Mock() - mock_dataset_service.get_dataset.return_value = mock_dataset - mock_doc_service.batch_update_document_status.side_effect = ValueError("Invalid action") - - from controllers.service_api.dataset.dataset import DocumentStatusApi - from controllers.service_api.dataset.error import InvalidActionError - - # Act & Assert - with app.test_request_context("/", method="PATCH", json={"document_ids": ["doc_1"]}): - api = DocumentStatusApi() - with pytest.raises(InvalidActionError): - api.patch("tenant_123", "dataset_123", "invalid_action") - - """Test DatasetPermissionEnum values.""" - - def test_only_me_permission(self): - """Test ONLY_ME permission value.""" - assert DatasetPermissionEnum.ONLY_ME is not None - - def test_all_team_permission(self): - """Test ALL_TEAM permission value.""" - assert DatasetPermissionEnum.ALL_TEAM is not None - - def test_partial_team_permission(self): - """Test PARTIAL_TEAM permission value.""" - assert DatasetPermissionEnum.PARTIAL_TEAM is not None - - -class TestDatasetErrors: - """Test dataset-related error types.""" - - def test_dataset_in_use_error_can_be_raised(self): - """Test DatasetInUseError can be raised.""" - error = DatasetInUseError() - assert error is not None - - def test_dataset_name_duplicate_error_can_be_raised(self): - """Test DatasetNameDuplicateError can be raised.""" - error = DatasetNameDuplicateError() - assert error is not None - - def test_invalid_action_error_can_be_raised(self): - """Test InvalidActionError can be raised.""" - error = InvalidActionError("Invalid action") - assert error is not None - - -class TestDatasetService: - """Test DatasetService interface methods.""" - - def test_get_datasets_method_exists(self): - """Test DatasetService.get_datasets exists.""" - assert hasattr(DatasetService, "get_datasets") - - def test_get_dataset_method_exists(self): - """Test DatasetService.get_dataset exists.""" - assert hasattr(DatasetService, "get_dataset") - - def test_create_empty_dataset_method_exists(self): - """Test DatasetService.create_empty_dataset exists.""" - assert hasattr(DatasetService, "create_empty_dataset") - - def test_update_dataset_method_exists(self): - """Test DatasetService.update_dataset exists.""" - assert hasattr(DatasetService, "update_dataset") - - def test_delete_dataset_method_exists(self): - """Test DatasetService.delete_dataset exists.""" - assert hasattr(DatasetService, "delete_dataset") - - def test_check_dataset_permission_method_exists(self): - """Test DatasetService.check_dataset_permission exists.""" - assert hasattr(DatasetService, "check_dataset_permission") - - def test_check_dataset_model_setting_method_exists(self): - """Test DatasetService.check_dataset_model_setting exists.""" - assert hasattr(DatasetService, "check_dataset_model_setting") - - def test_check_embedding_model_setting_method_exists(self): - """Test DatasetService.check_embedding_model_setting exists.""" - assert hasattr(DatasetService, "check_embedding_model_setting") - - @patch.object(DatasetService, "get_datasets") - def test_get_datasets_returns_tuple(self, mock_get): - """Test get_datasets returns tuple of datasets and total.""" - mock_datasets = [Mock(), Mock()] - mock_get.return_value = (mock_datasets, 2) - - datasets, total = DatasetService.get_datasets(page=1, per_page=20, tenant_id="tenant_123", user=Mock()) - assert len(datasets) == 2 - assert total == 2 - - @patch.object(DatasetService, "get_dataset") - def test_get_dataset_returns_dataset(self, mock_get): - """Test get_dataset returns dataset object.""" - mock_dataset = Mock() - mock_dataset.id = str(uuid.uuid4()) - mock_dataset.name = "Test Dataset" - mock_get.return_value = mock_dataset - - result = DatasetService.get_dataset("dataset_id") - assert result.name == "Test Dataset" - - @patch.object(DatasetService, "get_dataset") - def test_get_dataset_returns_none_when_not_found(self, mock_get): - """Test get_dataset returns None when not found.""" - mock_get.return_value = None - - result = DatasetService.get_dataset("nonexistent_id") - assert result is None - - -class TestDatasetPermissionService: - """Test DatasetPermissionService interface.""" - - def test_check_permission_method_exists(self): - """Test DatasetPermissionService.check_permission exists.""" - assert hasattr(DatasetPermissionService, "check_permission") - - def test_get_dataset_partial_member_list_method_exists(self): - """Test DatasetPermissionService.get_dataset_partial_member_list exists.""" - assert hasattr(DatasetPermissionService, "get_dataset_partial_member_list") - - def test_update_partial_member_list_method_exists(self): - """Test DatasetPermissionService.update_partial_member_list exists.""" - assert hasattr(DatasetPermissionService, "update_partial_member_list") - - def test_clear_partial_member_list_method_exists(self): - """Test DatasetPermissionService.clear_partial_member_list exists.""" - assert hasattr(DatasetPermissionService, "clear_partial_member_list") - - -class TestDocumentService: - """Test DocumentService interface.""" - - def test_batch_update_document_status_method_exists(self): - """Test DocumentService.batch_update_document_status exists.""" - assert hasattr(DocumentService, "batch_update_document_status") - - -class TestTagService: - """Test TagService interface.""" - - def test_get_tags_method_exists(self): - """Test TagService.get_tags exists.""" - assert hasattr(TagService, "get_tags") - - def test_save_tags_method_exists(self): - """Test TagService.save_tags exists.""" - assert hasattr(TagService, "save_tags") - - def test_update_tags_method_exists(self): - """Test TagService.update_tags exists.""" - assert hasattr(TagService, "update_tags") - - def test_delete_tag_method_exists(self): - """Test TagService.delete_tag exists.""" - assert hasattr(TagService, "delete_tag") - - def test_save_tag_binding_method_exists(self): - """Test TagService.save_tag_binding exists.""" - assert hasattr(TagService, "save_tag_binding") - - def test_delete_tag_binding_method_exists(self): - """Test TagService.delete_tag_binding exists.""" - assert hasattr(TagService, "delete_tag_binding") - - def test_get_tags_by_target_id_method_exists(self): - """Test TagService.get_tags_by_target_id exists.""" - assert hasattr(TagService, "get_tags_by_target_id") - - def test_get_tag_binding_count_method_exists(self): - """Test TagService.get_tag_binding_count exists.""" - assert hasattr(TagService, "get_tag_binding_count") - - @patch.object(TagService, "get_tags") - def test_get_tags_returns_list(self, mock_get): - """Test get_tags returns list of tags.""" - mock_tags = [ - Mock(id="tag1", name="Tag One", type="knowledge"), - Mock(id="tag2", name="Tag Two", type="knowledge"), - ] - mock_get.return_value = mock_tags - - result = TagService.get_tags("knowledge", "tenant_123") - assert len(result) == 2 - - @patch.object(TagService, "save_tags") - def test_save_tags_returns_tag(self, mock_save): - """Test save_tags returns created tag.""" - mock_tag = Mock() - mock_tag.id = str(uuid.uuid4()) - mock_tag.name = "New Tag" - mock_tag.type = TagType.KNOWLEDGE - mock_save.return_value = mock_tag - - result = TagService.save_tags({"name": "New Tag", "type": "knowledge"}) - assert result.name == "New Tag" - - -class TestDocumentStatusAction: - """Test document status action values.""" - - def test_enable_action(self): - """Test enable action.""" - action = "enable" - assert action in ["enable", "disable", "archive", "un_archive"] - - def test_disable_action(self): - """Test disable action.""" - action = "disable" - assert action in ["enable", "disable", "archive", "un_archive"] - - def test_archive_action(self): - """Test archive action.""" - action = "archive" - assert action in ["enable", "disable", "archive", "un_archive"] - - def test_un_archive_action(self): - """Test un_archive action.""" - action = "un_archive" - assert action in ["enable", "disable", "archive", "un_archive"] - - -# ============================================================================= -# API Endpoint Tests -# -# ``DatasetListApi`` and ``DatasetApi`` inherit from ``DatasetApiResource`` -# whose ``method_decorators`` include ``validate_dataset_token``. -# -# Decorator strategy: -# - ``@cloud_edition_billing_rate_limit_check`` preserves ``__wrapped__`` -# → call via ``_unwrap(method)(self, …)``. -# - Methods without billing decorators → call directly; only patch ``db``, -# services, ``current_user``, and ``marshal``. -# ============================================================================= +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- def _unwrap(method): @@ -920,6 +235,15 @@ def _unwrap(method): return fn +@pytest.fixture +def app(flask_app_with_containers): + # Uses the full containerised app so that Flask config, extensions, and + # blueprint registrations match production. Most tests mock the service + # layer to isolate controller logic; a few (e.g. test_list_tags_from_db) + # exercise the real DB-backed path to validate end-to-end behaviour. + return flask_app_with_containers + + @pytest.fixture def mock_tenant(): tenant = Mock() @@ -938,12 +262,13 @@ def mock_dataset(): return dataset -class TestDatasetListApiGet: - """Test suite for DatasetListApi.get() endpoint. +# --------------------------------------------------------------------------- +# API endpoint tests — DatasetListApi +# --------------------------------------------------------------------------- - ``get`` has no billing decorators but calls ``current_user``, - ``DatasetService``, ``create_plugin_provider_manager``, and ``marshal``. - """ + +class TestDatasetListApiGet: + """Test suite for DatasetListApi.get() endpoint.""" @patch("controllers.service_api.dataset.dataset.marshal") @patch("controllers.service_api.dataset.dataset.create_plugin_provider_manager") @@ -958,7 +283,6 @@ class TestDatasetListApiGet: app, mock_tenant, ): - """Test successful dataset list retrieval.""" from controllers.service_api.dataset.dataset import DatasetListApi mock_current_user.__class__ = Account @@ -981,10 +305,7 @@ class TestDatasetListApiGet: class TestDatasetListApiPost: - """Test suite for DatasetListApi.post() endpoint. - - ``post`` is wrapped by ``@cloud_edition_billing_rate_limit_check``. - """ + """Test suite for DatasetListApi.post() endpoint.""" @patch("controllers.service_api.dataset.dataset.marshal") @patch("controllers.service_api.dataset.dataset.current_user") @@ -997,7 +318,6 @@ class TestDatasetListApiPost: app, mock_tenant, ): - """Test successful dataset creation.""" from controllers.service_api.dataset.dataset import DatasetListApi mock_current_user.__class__ = Account @@ -1024,7 +344,6 @@ class TestDatasetListApiPost: app, mock_tenant, ): - """Test DatasetNameDuplicateError when name already exists.""" from controllers.service_api.dataset.dataset import DatasetListApi mock_current_user.__class__ = Account @@ -1040,12 +359,13 @@ class TestDatasetListApiPost: _unwrap(api.post)(api, tenant_id=mock_tenant.id) -class TestDatasetApiGet: - """Test suite for DatasetApi.get() endpoint. +# --------------------------------------------------------------------------- +# API endpoint tests — DatasetApi +# --------------------------------------------------------------------------- - ``get`` has no billing decorators but calls ``DatasetService``, - ``create_plugin_provider_manager``, ``marshal``, and ``current_user``. - """ + +class TestDatasetApiGet: + """Test suite for DatasetApi.get() endpoint.""" @patch("controllers.service_api.dataset.dataset.DatasetPermissionService") @patch("controllers.service_api.dataset.dataset.marshal") @@ -1062,7 +382,6 @@ class TestDatasetApiGet: app, mock_dataset, ): - """Test successful dataset retrieval.""" from controllers.service_api.dataset.dataset import DatasetApi mock_dataset_svc.get_dataset.return_value = mock_dataset @@ -1092,7 +411,6 @@ class TestDatasetApiGet: @patch("controllers.service_api.dataset.dataset.DatasetService") def test_get_dataset_not_found(self, mock_dataset_svc, app, mock_dataset): - """Test 404 when dataset not found.""" from controllers.service_api.dataset.dataset import DatasetApi mock_dataset_svc.get_dataset.return_value = None @@ -1114,7 +432,6 @@ class TestDatasetApiGet: app, mock_dataset, ): - """Test 403 when user has no permission.""" from controllers.service_api.dataset.dataset import DatasetApi mock_dataset_svc.get_dataset.return_value = mock_dataset @@ -1130,10 +447,7 @@ class TestDatasetApiGet: class TestDatasetApiDelete: - """Test suite for DatasetApi.delete() endpoint. - - ``delete`` is wrapped by ``@cloud_edition_billing_rate_limit_check``. - """ + """Test suite for DatasetApi.delete() endpoint.""" @patch("controllers.service_api.dataset.dataset.DatasetPermissionService") @patch("controllers.service_api.dataset.dataset.current_user") @@ -1146,7 +460,6 @@ class TestDatasetApiDelete: app, mock_dataset, ): - """Test successful dataset deletion.""" from controllers.service_api.dataset.dataset import DatasetApi mock_dataset_svc.delete_dataset.return_value = True @@ -1169,7 +482,6 @@ class TestDatasetApiDelete: app, mock_dataset, ): - """Test 404 when dataset not found for deletion.""" from controllers.service_api.dataset.dataset import DatasetApi mock_dataset_svc.delete_dataset.return_value = False @@ -1191,7 +503,6 @@ class TestDatasetApiDelete: app, mock_dataset, ): - """Test DatasetInUseError when dataset is in use.""" from controllers.service_api.dataset.dataset import DatasetApi mock_dataset_svc.delete_dataset.side_effect = services.errors.dataset.DatasetInUseError() @@ -1205,12 +516,13 @@ class TestDatasetApiDelete: _unwrap(api.delete)(api, _=mock_dataset.tenant_id, dataset_id=mock_dataset.id) -class TestDocumentStatusApiPatch: - """Test suite for DocumentStatusApi.patch() endpoint. +# --------------------------------------------------------------------------- +# API endpoint tests — DocumentStatusApi +# --------------------------------------------------------------------------- - ``patch`` has no billing decorators but calls ``DatasetService``, - ``DocumentService``, and ``current_user``. - """ + +class TestDocumentStatusApiPatch: + """Test suite for DocumentStatusApi.patch() endpoint.""" @patch("controllers.service_api.dataset.dataset.DocumentService") @patch("controllers.service_api.dataset.dataset.current_user") @@ -1224,7 +536,6 @@ class TestDocumentStatusApiPatch: mock_tenant, mock_dataset, ): - """Test successful batch document status update.""" from controllers.service_api.dataset.dataset import DocumentStatusApi mock_current_user.__class__ = Account @@ -1256,7 +567,6 @@ class TestDocumentStatusApiPatch: mock_tenant, mock_dataset, ): - """Test 404 when dataset not found.""" from controllers.service_api.dataset.dataset import DocumentStatusApi mock_dataset_svc.get_dataset.return_value = None @@ -1274,6 +584,39 @@ class TestDocumentStatusApiPatch: action="enable", ) + @patch("controllers.service_api.dataset.dataset.DocumentService") + @patch("controllers.service_api.dataset.dataset.current_user") + @patch("controllers.service_api.dataset.dataset.DatasetService") + def test_batch_update_status_permission_error( + self, + mock_dataset_svc, + mock_current_user, + mock_doc_svc, + app, + mock_tenant, + mock_dataset, + ): + from controllers.service_api.dataset.dataset import DocumentStatusApi + + mock_current_user.__class__ = Account + mock_dataset_svc.get_dataset.return_value = mock_dataset + mock_dataset_svc.check_dataset_permission.side_effect = services.errors.account.NoPermissionError( + "No permission" + ) + + with app.test_request_context( + f"/datasets/{mock_dataset.id}/documents/status/enable", + method="PATCH", + json={"document_ids": ["doc-1"]}, + ): + api = DocumentStatusApi() + with pytest.raises(Forbidden): + api.patch( + tenant_id=mock_tenant.id, + dataset_id=mock_dataset.id, + action="enable", + ) + @patch("controllers.service_api.dataset.dataset.DocumentService") @patch("controllers.service_api.dataset.dataset.current_user") @patch("controllers.service_api.dataset.dataset.DatasetService") @@ -1286,7 +629,6 @@ class TestDocumentStatusApiPatch: mock_tenant, mock_dataset, ): - """Test InvalidActionError when document is indexing.""" from controllers.service_api.dataset.dataset import DocumentStatusApi mock_current_user.__class__ = Account @@ -1320,7 +662,6 @@ class TestDocumentStatusApiPatch: mock_tenant, mock_dataset, ): - """Test InvalidActionError when ValueError raised.""" from controllers.service_api.dataset.dataset import DocumentStatusApi mock_current_user.__class__ = Account @@ -1343,6 +684,11 @@ class TestDocumentStatusApiPatch: ) +# --------------------------------------------------------------------------- +# API endpoint tests — Tags +# --------------------------------------------------------------------------- + + class TestDatasetTagsApiGet: """Test suite for DatasetTagsApi.get() endpoint.""" @@ -1354,7 +700,6 @@ class TestDatasetTagsApiGet: mock_tag_svc, app, ): - """Test successful tag list retrieval.""" from controllers.service_api.dataset.dataset import DatasetTagsApi mock_current_user.__class__ = Account @@ -1368,15 +713,49 @@ class TestDatasetTagsApiGet: assert status == 200 assert len(response) == 1 + mock_tag_svc.get_tags.assert_called_once_with("knowledge", "tenant-1") + + @pytest.mark.skip(reason="Production bug: DataSetTag.binding_count is str|None but DB COUNT() returns int") + @patch("controllers.service_api.dataset.dataset.current_user") + def test_list_tags_from_db( + self, + mock_current_user, + app, + db_session_with_containers: Session, + ): + """Integration test: creates real Tag rows and retrieves them + through the controller without mocking TagService.""" + from tests.test_containers_integration_tests.controllers.console.helpers import ( + create_console_account_and_tenant, + ) + + account, tenant = create_console_account_and_tenant(db_session_with_containers) + + tag = Tag( + name="Integration Tag", + type=TagType.KNOWLEDGE, + created_by=account.id, + tenant_id=tenant.id, + ) + db_session_with_containers.add(tag) + db_session_with_containers.commit() + + mock_current_user.__class__ = Account + mock_current_user.current_tenant_id = tenant.id + + from controllers.service_api.dataset.dataset import DatasetTagsApi + + with app.test_request_context("/datasets/tags", method="GET"): + api = DatasetTagsApi() + response, status = api.get(_=None) + + assert status == 200 + assert any(t["name"] == "Integration Tag" for t in response) class TestDatasetTagsApiPost: """Test suite for DatasetTagsApi.post() endpoint.""" - # BUG: dataset.py L512 passes ``binding_count=0`` (int) to - # ``DataSetTag.model_validate()``, but ``DataSetTag.binding_count`` - # is typed ``str | None`` (see fields/tag_fields.py L20). - # This causes a Pydantic ValidationError at runtime. @pytest.mark.skip(reason="Production bug: DataSetTag.binding_count is str|None but dataset.py passes int 0") @patch("controllers.service_api.dataset.dataset.TagService") @patch("controllers.service_api.dataset.dataset.current_user") @@ -1386,7 +765,6 @@ class TestDatasetTagsApiPost: mock_tag_svc, app, ): - """Test successful tag creation.""" from controllers.service_api.dataset.dataset import DatasetTagsApi mock_current_user.__class__ = Account @@ -1409,7 +787,6 @@ class TestDatasetTagsApiPost: @patch("controllers.service_api.dataset.dataset.current_user") def test_create_tag_forbidden(self, mock_current_user, app): - """Test 403 when user lacks edit permission.""" from controllers.service_api.dataset.dataset import DatasetTagsApi mock_current_user.__class__ = Account @@ -1426,6 +803,146 @@ class TestDatasetTagsApiPost: api.post(_=None) +class TestDatasetTagsApiPatch: + """Test suite for DatasetTagsApi.patch() endpoint.""" + + @pytest.mark.skip(reason="Production bug: DataSetTag.binding_count is str|None but dataset.py passes int 0") + @patch("controllers.service_api.dataset.dataset.TagService") + @patch("controllers.service_api.dataset.dataset.service_api_ns") + @patch("controllers.service_api.dataset.dataset.current_user") + def test_update_tag_success( + self, + mock_current_user, + mock_service_api_ns, + mock_tag_svc, + app, + ): + from controllers.service_api.dataset.dataset import DatasetTagsApi + + mock_current_user.__class__ = Account + mock_current_user.has_edit_permission = True + mock_current_user.is_dataset_editor = True + + mock_tag = SimpleNamespace(id="tag-1", name="Updated Tag", type="knowledge") + mock_tag_svc.update_tags.return_value = mock_tag + mock_tag_svc.get_tag_binding_count.return_value = 5 + mock_service_api_ns.payload = {"name": "Updated Tag", "tag_id": "tag-1"} + + with app.test_request_context( + "/datasets/tags", + method="PATCH", + json={"name": "Updated Tag", "tag_id": "tag-1"}, + ): + api = DatasetTagsApi() + response, status = api.patch(_=None) + + assert status == 200 + assert response["name"] == "Updated Tag" + mock_tag_svc.update_tags.assert_called_once_with({"name": "Updated Tag", "type": "knowledge"}, "tag-1") + + @patch("controllers.service_api.dataset.dataset.current_user") + def test_update_tag_forbidden(self, mock_current_user, app): + from controllers.service_api.dataset.dataset import DatasetTagsApi + + mock_current_user.__class__ = Account + mock_current_user.has_edit_permission = False + mock_current_user.is_dataset_editor = False + + with app.test_request_context( + "/datasets/tags", + method="PATCH", + json={"name": "Updated Tag", "tag_id": "tag-1"}, + ): + api = DatasetTagsApi() + with pytest.raises(Forbidden): + api.patch(_=None) + + +class TestDatasetTagsApiDelete: + """Test suite for DatasetTagsApi.delete() endpoint.""" + + @patch("controllers.service_api.dataset.dataset.TagService") + @patch("controllers.service_api.dataset.dataset.service_api_ns") + @patch("libs.login.current_user") + def test_delete_tag_success( + self, + mock_current_user, + mock_service_api_ns, + mock_tag_svc, + app, + ): + from controllers.service_api.dataset.dataset import DatasetTagsApi + + user_obj = Mock(spec=Account) + user_obj.has_edit_permission = True + mock_current_user.has_edit_permission = True + # Assign as plain lambda to avoid AsyncMock returning a coroutine + mock_current_user._get_current_object = lambda: user_obj + + mock_tag_svc.delete_tag.return_value = None + mock_service_api_ns.payload = {"tag_id": "tag-1"} + + with app.test_request_context( + "/datasets/tags", + method="DELETE", + json={"tag_id": "tag-1"}, + ): + api = DatasetTagsApi() + result = api.delete(_=None) + + assert result == ("", 204) + mock_tag_svc.delete_tag.assert_called_once_with("tag-1") + + @patch("libs.login.current_user") + def test_delete_tag_forbidden(self, mock_current_user, app): + from controllers.service_api.dataset.dataset import DatasetTagsApi + + user_obj = Mock(spec=Account) + user_obj.has_edit_permission = False + mock_current_user.has_edit_permission = False + # Assign as plain lambda to avoid AsyncMock returning a coroutine + mock_current_user._get_current_object = lambda: user_obj + + with app.test_request_context( + "/datasets/tags", + method="DELETE", + json={"tag_id": "tag-1"}, + ): + api = DatasetTagsApi() + with pytest.raises(Forbidden): + api.delete(_=None) + + +class TestDatasetTagsBindingStatusApi: + """Test suite for DatasetTagsBindingStatusApi endpoints.""" + + @patch("controllers.service_api.dataset.dataset.TagService") + @patch("controllers.service_api.dataset.dataset.current_user") + def test_get_dataset_tags_binding_status( + self, + mock_current_user, + mock_tag_svc, + app, + ): + from controllers.service_api.dataset.dataset import DatasetTagsBindingStatusApi + + mock_current_user.__class__ = Account + mock_current_user.current_tenant_id = "tenant_123" + mock_tag = Mock() + mock_tag.id = "tag_1" + mock_tag.name = "Test Tag" + mock_tag_svc.get_tags_by_target_id.return_value = [mock_tag] + + with app.test_request_context("/", method="GET"): + api = DatasetTagsBindingStatusApi() + response, status_code = api.get("tenant_123", dataset_id="dataset_123") + + assert status_code == 200 + assert response["data"] == [{"id": "tag_1", "name": "Test Tag"}] + assert response["total"] == 1 + mock_tag_svc.get_tags_by_target_id.assert_called_once_with("knowledge", "tenant_123", "dataset_123") + + class TestDatasetTagBindingApiPost: """Test suite for DatasetTagBindingApi.post() endpoint.""" @@ -1437,7 +954,6 @@ class TestDatasetTagBindingApiPost: mock_tag_svc, app, ): - """Test successful tag binding.""" from controllers.service_api.dataset.dataset import DatasetTagBindingApi mock_current_user.__class__ = Account @@ -1454,10 +970,14 @@ class TestDatasetTagBindingApiPost: result = api.post(_=None) assert result == ("", 204) + from services.tag_service import TagBindingCreatePayload + + mock_tag_svc.save_tag_binding.assert_called_once_with( + TagBindingCreatePayload(tag_ids=["tag-1"], target_id="ds-1", type="knowledge") + ) @patch("controllers.service_api.dataset.dataset.current_user") def test_bind_tags_forbidden(self, mock_current_user, app): - """Test 403 when user lacks edit permission.""" from controllers.service_api.dataset.dataset import DatasetTagBindingApi mock_current_user.__class__ = Account @@ -1485,7 +1005,6 @@ class TestDatasetTagUnbindingApiPost: mock_tag_svc, app, ): - """Test successful tag unbinding.""" from controllers.service_api.dataset.dataset import DatasetTagUnbindingApi mock_current_user.__class__ = Account @@ -1502,10 +1021,14 @@ class TestDatasetTagUnbindingApiPost: result = api.post(_=None) assert result == ("", 204) + from services.tag_service import TagBindingDeletePayload + + mock_tag_svc.delete_tag_binding.assert_called_once_with( + TagBindingDeletePayload(tag_id="tag-1", target_id="ds-1", type="knowledge") + ) @patch("controllers.service_api.dataset.dataset.current_user") def test_unbind_tag_forbidden(self, mock_current_user, app): - """Test 403 when user lacks edit permission.""" from controllers.service_api.dataset.dataset import DatasetTagUnbindingApi mock_current_user.__class__ = Account diff --git a/api/tests/test_containers_integration_tests/controllers/service_api/test_site.py b/api/tests/test_containers_integration_tests/controllers/service_api/test_site.py new file mode 100644 index 00000000000..4e884626a77 --- /dev/null +++ b/api/tests/test_containers_integration_tests/controllers/service_api/test_site.py @@ -0,0 +1,110 @@ +""" +Testcontainers integration tests for Service API Site controller. +""" + +from __future__ import annotations + +import pytest +from flask import Flask +from sqlalchemy.orm import Session +from werkzeug.exceptions import Forbidden + +from controllers.service_api.app.site import AppSiteApi +from models.account import Tenant, TenantStatus +from models.model import App, AppMode, Site + + +@pytest.fixture +def app(flask_app_with_containers) -> Flask: + return flask_app_with_containers + + +def _unwrap(method): + fn = method + while hasattr(fn, "__wrapped__"): + fn = fn.__wrapped__ + return fn + + +def _create_tenant(db_session: Session, *, status: TenantStatus = TenantStatus.NORMAL) -> Tenant: + tenant = Tenant(name="service-api-site-tenant", status=status) + db_session.add(tenant) + db_session.commit() + return tenant + + +def _create_app(db_session: Session, tenant_id: str) -> App: + app_model = App( + tenant_id=tenant_id, + mode=AppMode.CHAT, + name="service-api-site-app", + enable_site=True, + enable_api=True, + status="normal", + ) + db_session.add(app_model) + db_session.commit() + return app_model + + +def _create_site(db_session: Session, app_id: str) -> Site: + site = Site( + app_id=app_id, + title="Service API Site", + icon_type="emoji", + icon="robot", + icon_background="#ffffff", + description="Service API test site", + default_language="en-US", + prompt_public=True, + show_workflow_steps=True, + customize_token_strategy="not_allow", + use_icon_as_answer_icon=False, + chat_color_theme="light", + chat_color_theme_inverted=False, + ) + db_session.add(site) + db_session.commit() + return site + + +class TestAppSiteApi: + def test_get_site_success(self, app: Flask, db_session_with_containers: Session) -> None: + tenant = _create_tenant(db_session_with_containers) + app_model = _create_app(db_session_with_containers, tenant.id) + _create_site(db_session_with_containers, app_model.id) + + with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test-token"}): + api = AppSiteApi() + response = _unwrap(api.get)(api, app_model=app_model) + + assert response["title"] == "Service API Site" + assert response["icon"] == "robot" + assert response["description"] == "Service API test site" + + def test_get_site_not_found(self, app: Flask, db_session_with_containers: Session) -> None: + tenant = _create_tenant(db_session_with_containers) + app_model = _create_app(db_session_with_containers, tenant.id) + + with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test-token"}): + api = AppSiteApi() + with pytest.raises(Forbidden): + _unwrap(api.get)(api, app_model=app_model) + + def test_get_site_tenant_archived(self, app: Flask, db_session_with_containers: Session) -> None: + tenant = _create_tenant(db_session_with_containers) + app_model = _create_app(db_session_with_containers, tenant.id) + _create_site(db_session_with_containers, app_model.id) + + archived_tenant = db_session_with_containers.get(Tenant, tenant.id) + assert archived_tenant is not None + archived_tenant.status = TenantStatus.ARCHIVE + db_session_with_containers.commit() + + app_model = db_session_with_containers.get(App, app_model.id) + assert app_model is not None + + with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test-token"}): + api = AppSiteApi() + with pytest.raises(Forbidden): + _unwrap(api.get)(api, app_model=app_model) diff --git a/api/tests/unit_tests/controllers/web/test_site.py b/api/tests/test_containers_integration_tests/controllers/web/test_site.py similarity index 51% rename from api/tests/unit_tests/controllers/web/test_site.py rename to api/tests/test_containers_integration_tests/controllers/web/test_site.py index 6e9d754c436..9adb26ff3d2 100644 --- a/api/tests/unit_tests/controllers/web/test_site.py +++ b/api/tests/test_containers_integration_tests/controllers/web/test_site.py @@ -1,28 +1,48 @@ -"""Unit tests for controllers.web.site endpoints.""" +"""Testcontainers integration tests for controllers.web.site endpoints.""" from __future__ import annotations from types import SimpleNamespace -from unittest.mock import MagicMock, patch +from unittest.mock import patch import pytest from flask import Flask +from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden from controllers.web.site import AppSiteApi, AppSiteInfo +from models import Tenant, TenantStatus +from models.model import App, AppMode, CustomizeTokenStrategy, Site -def _tenant(*, status: str = "normal") -> SimpleNamespace: - return SimpleNamespace( - id="tenant-1", - status=status, - plan="basic", - custom_config_dict={"remove_webapp_brand": False, "replace_webapp_logo": False}, +@pytest.fixture +def app(flask_app_with_containers) -> Flask: + return flask_app_with_containers + + +def _create_tenant(db_session: Session, *, status: TenantStatus = TenantStatus.NORMAL) -> Tenant: + tenant = Tenant(name="test-tenant", status=status) + db_session.add(tenant) + db_session.commit() + return tenant + + +def _create_app(db_session: Session, tenant_id: str, *, enable_site: bool = True) -> App: + app_model = App( + tenant_id=tenant_id, + mode=AppMode.CHAT, + name="test-app", + enable_site=enable_site, + enable_api=True, ) + db_session.add(app_model) + db_session.commit() + return app_model -def _site() -> SimpleNamespace: - return SimpleNamespace( +def _create_site(db_session: Session, app_id: str) -> Site: + site = Site( + app_id=app_id, title="Site", icon_type="emoji", icon="robot", @@ -31,77 +51,64 @@ def _site() -> SimpleNamespace: default_language="en", chat_color_theme="light", chat_color_theme_inverted=False, - copyright=None, - privacy_policy=None, - custom_disclaimer=None, + customize_token_strategy=CustomizeTokenStrategy.NOT_ALLOW, + code=f"code-{app_id[-6:]}", prompt_public=False, show_workflow_steps=True, use_icon_as_answer_icon=False, ) + db_session.add(site) + db_session.commit() + return site -# --------------------------------------------------------------------------- -# AppSiteApi -# --------------------------------------------------------------------------- class TestAppSiteApi: @patch("controllers.web.site.FeatureService.get_features") - @patch("controllers.web.site.db") - def test_happy_path(self, mock_db: MagicMock, mock_features: MagicMock, app: Flask) -> None: + def test_happy_path(self, mock_features, app: Flask, db_session_with_containers: Session) -> None: app.config["RESTX_MASK_HEADER"] = "X-Fields" - mock_features.return_value = SimpleNamespace(can_replace_logo=False) - site_obj = _site() - mock_db.session.scalar.return_value = site_obj - tenant = _tenant() - app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", tenant=tenant, enable_site=True) + tenant = _create_tenant(db_session_with_containers) + app_model = _create_app(db_session_with_containers, tenant.id) + _create_site(db_session_with_containers, app_model.id) end_user = SimpleNamespace(id="eu-1") + mock_features.return_value = SimpleNamespace(can_replace_logo=False) with app.test_request_context("/site"): result = AppSiteApi().get(app_model, end_user) - # marshal_with serializes AppSiteInfo to a dict - assert result["app_id"] == "app-1" + assert result["app_id"] == app_model.id assert result["plan"] == "basic" assert result["enable_site"] is True - @patch("controllers.web.site.db") - def test_missing_site_raises_forbidden(self, mock_db: MagicMock, app: Flask) -> None: + def test_missing_site_raises_forbidden(self, app: Flask, db_session_with_containers: Session) -> None: app.config["RESTX_MASK_HEADER"] = "X-Fields" - mock_db.session.scalar.return_value = None - tenant = _tenant() - app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", tenant=tenant, enable_site=True) + tenant = _create_tenant(db_session_with_containers) + app_model = _create_app(db_session_with_containers, tenant.id) end_user = SimpleNamespace(id="eu-1") with app.test_request_context("/site"): with pytest.raises(Forbidden): AppSiteApi().get(app_model, end_user) - @patch("controllers.web.site.db") - def test_archived_tenant_raises_forbidden(self, mock_db: MagicMock, app: Flask) -> None: + @patch("controllers.web.site.FeatureService.get_features") + def test_archived_tenant_raises_forbidden( + self, mock_features, app: Flask, db_session_with_containers: Session + ) -> None: app.config["RESTX_MASK_HEADER"] = "X-Fields" - from models.account import TenantStatus - - mock_db.session.scalar.return_value = _site() - tenant = SimpleNamespace( - id="tenant-1", - status=TenantStatus.ARCHIVE, - plan="basic", - custom_config_dict={}, - ) - app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", tenant=tenant) + tenant = _create_tenant(db_session_with_containers, status=TenantStatus.ARCHIVE) + app_model = _create_app(db_session_with_containers, tenant.id) + _create_site(db_session_with_containers, app_model.id) end_user = SimpleNamespace(id="eu-1") + mock_features.return_value = SimpleNamespace(can_replace_logo=False) with app.test_request_context("/site"): with pytest.raises(Forbidden): AppSiteApi().get(app_model, end_user) -# --------------------------------------------------------------------------- -# AppSiteInfo -# --------------------------------------------------------------------------- class TestAppSiteInfo: def test_basic_fields(self) -> None: - tenant = _tenant() - site_obj = _site() + tenant = SimpleNamespace(id="tenant-1", plan="basic", custom_config_dict={}) + site_obj = SimpleNamespace() info = AppSiteInfo(tenant, SimpleNamespace(id="app-1", enable_site=True), site_obj, "eu-1", False) assert info.app_id == "app-1" @@ -118,7 +125,7 @@ class TestAppSiteInfo: plan="pro", custom_config_dict={"remove_webapp_brand": True, "replace_webapp_logo": True}, ) - site_obj = _site() + site_obj = SimpleNamespace() info = AppSiteInfo(tenant, SimpleNamespace(id="app-1", enable_site=True), site_obj, "eu-1", True) assert info.can_replace_logo is True diff --git a/api/tests/test_containers_integration_tests/controllers/web/test_web_forgot_password.py b/api/tests/test_containers_integration_tests/controllers/web/test_web_forgot_password.py index 04ad143103b..f14b2c0ae56 100644 --- a/api/tests/test_containers_integration_tests/controllers/web/test_web_forgot_password.py +++ b/api/tests/test_containers_integration_tests/controllers/web/test_web_forgot_password.py @@ -37,10 +37,8 @@ class TestForgotPasswordSendEmailApi: @patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback") @patch("controllers.web.forgot_password.AccountService.is_email_send_ip_limit", return_value=False) @patch("controllers.web.forgot_password.extract_remote_ip", return_value="127.0.0.1") - @patch("controllers.web.forgot_password.sessionmaker") def test_should_normalize_email_before_sending( self, - mock_session_cls, mock_extract_ip, mock_rate_limit, mock_get_account, @@ -50,19 +48,16 @@ class TestForgotPasswordSendEmailApi: mock_account = MagicMock() mock_get_account.return_value = mock_account mock_send_mail.return_value = "token-123" - mock_session = MagicMock() - mock_session_cls.return_value.begin.return_value.__enter__.return_value = mock_session - with patch("controllers.web.forgot_password.db", SimpleNamespace(engine="engine")): - with app.test_request_context( - "/web/forgot-password", - method="POST", - json={"email": "User@Example.com", "language": "zh-Hans"}, - ): - response = ForgotPasswordSendEmailApi().post() + with app.test_request_context( + "/web/forgot-password", + method="POST", + json={"email": "User@Example.com", "language": "zh-Hans"}, + ): + response = ForgotPasswordSendEmailApi().post() assert response == {"result": "success", "data": "token-123"} - mock_get_account.assert_called_once_with("User@Example.com", session=mock_session) + mock_get_account.assert_called_once_with("User@Example.com") mock_send_mail.assert_called_once_with(account=mock_account, email="user@example.com", language="zh-Hans") mock_extract_ip.assert_called_once() mock_rate_limit.assert_called_once_with("127.0.0.1") @@ -153,14 +148,14 @@ class TestForgotPasswordResetApi: @patch("controllers.web.forgot_password.ForgotPasswordResetApi._update_existing_account") @patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback") - @patch("controllers.web.forgot_password.sessionmaker") + @patch("controllers.web.forgot_password.db") @patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token") @patch("controllers.web.forgot_password.AccountService.get_reset_password_data") def test_should_fetch_account_with_fallback( self, mock_get_reset_data, mock_revoke_token, - mock_session_cls, + mock_db, mock_get_account, mock_update_account, app, @@ -168,29 +163,27 @@ class TestForgotPasswordResetApi: mock_get_reset_data.return_value = {"phase": "reset", "email": "User@Example.com", "code": "1234"} mock_account = MagicMock() mock_get_account.return_value = mock_account - mock_session = MagicMock() - mock_session_cls.return_value.begin.return_value.__enter__.return_value = mock_session + mock_db.session.merge.return_value = mock_account - with patch("controllers.web.forgot_password.db", SimpleNamespace(engine="engine")): - with app.test_request_context( - "/web/forgot-password/resets", - method="POST", - json={ - "token": "token-123", - "new_password": "ValidPass123!", - "password_confirm": "ValidPass123!", - }, - ): - response = ForgotPasswordResetApi().post() + with app.test_request_context( + "/web/forgot-password/resets", + method="POST", + json={ + "token": "token-123", + "new_password": "ValidPass123!", + "password_confirm": "ValidPass123!", + }, + ): + response = ForgotPasswordResetApi().post() assert response == {"result": "success"} - mock_get_account.assert_called_once_with("User@Example.com", session=mock_session) + mock_get_account.assert_called_once_with("User@Example.com") mock_update_account.assert_called_once() mock_revoke_token.assert_called_once_with("token-123") @patch("controllers.web.forgot_password.hash_password", return_value=b"hashed-value") @patch("controllers.web.forgot_password.secrets.token_bytes", return_value=b"0123456789abcdef") - @patch("controllers.web.forgot_password.sessionmaker") + @patch("controllers.web.forgot_password.db") @patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token") @patch("controllers.web.forgot_password.AccountService.get_reset_password_data") @patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback") @@ -199,7 +192,7 @@ class TestForgotPasswordResetApi: mock_get_account, mock_get_reset_data, mock_revoke_token, - mock_session_cls, + mock_db, mock_token_bytes, mock_hash_password, app, @@ -207,20 +200,18 @@ class TestForgotPasswordResetApi: mock_get_reset_data.return_value = {"phase": "reset", "email": "user@example.com"} account = MagicMock() mock_get_account.return_value = account - mock_session = MagicMock() - mock_session_cls.return_value.begin.return_value.__enter__.return_value = mock_session + mock_db.session.merge.return_value = account - with patch("controllers.web.forgot_password.db", SimpleNamespace(engine="engine")): - with app.test_request_context( - "/web/forgot-password/resets", - method="POST", - json={ - "token": "reset-token", - "new_password": "StrongPass123!", - "password_confirm": "StrongPass123!", - }, - ): - response = ForgotPasswordResetApi().post() + with app.test_request_context( + "/web/forgot-password/resets", + method="POST", + json={ + "token": "reset-token", + "new_password": "StrongPass123!", + "password_confirm": "StrongPass123!", + }, + ): + response = ForgotPasswordResetApi().post() assert response == {"result": "success"} mock_get_reset_data.assert_called_once_with("reset-token") diff --git a/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py b/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py index 2b4c1b59abf..c342e8994b3 100644 --- a/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py +++ b/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py @@ -22,13 +22,6 @@ import uuid from time import time import pytest -from graphon.entities.pause_reason import SchedulingPause -from graphon.enums import WorkflowExecutionStatus -from graphon.graph_engine.entities.commands import GraphEngineCommand -from graphon.graph_engine.layers.base import GraphEngineLayerNotInitializedError -from graphon.graph_events import GraphRunPausedEvent -from graphon.model_runtime.entities.llm_entities import LLMUsage -from graphon.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper, VariablePool from sqlalchemy import Engine, delete, select from sqlalchemy.orm import Session @@ -40,6 +33,13 @@ from core.app.layers.pause_state_persist_layer import ( ) from core.workflow.system_variables import build_system_variables from extensions.ext_storage import storage +from graphon.entities.pause_reason import SchedulingPause +from graphon.enums import WorkflowExecutionStatus +from graphon.graph_engine.entities.commands import GraphEngineCommand +from graphon.graph_engine.layers.base import GraphEngineLayerNotInitializedError +from graphon.graph_events import GraphRunPausedEvent +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper, VariablePool from libs.datetime_utils import naive_utc_now from models import Account from models import WorkflowPause as WorkflowPauseModel @@ -88,11 +88,11 @@ class TestPauseStatePersistenceLayerTestContainers: def setup_test_data(self, db_session_with_containers, file_service, workflow_run_service): """Set up test data for each test method using TestContainers.""" # Create test tenant and account - from models.account import Tenant, TenantAccountJoin, TenantAccountRole + from models.account import AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole, TenantStatus tenant = Tenant( name="Test Tenant", - status="normal", + status=TenantStatus.NORMAL, ) db_session_with_containers.add(tenant) db_session_with_containers.commit() @@ -101,7 +101,7 @@ class TestPauseStatePersistenceLayerTestContainers: email="test@example.com", name="Test User", interface_language="en-US", - status="active", + status=AccountStatus.ACTIVE, ) db_session_with_containers.add(account) db_session_with_containers.commit() @@ -557,11 +557,9 @@ class TestPauseStatePersistenceLayerTestContainers: self.session.refresh(self.test_workflow_run) assert self.test_workflow_run.status == WorkflowExecutionStatus.RUNNING - pause_states = ( - self.session.query(WorkflowPauseModel) - .filter(WorkflowPauseModel.workflow_run_id == self.test_workflow_run_id) - .all() - ) + pause_states = self.session.scalars( + select(WorkflowPauseModel).where(WorkflowPauseModel.workflow_run_id == self.test_workflow_run_id) + ).all() assert len(pause_states) == 0 def test_layer_requires_initialization(self, db_session_with_containers): diff --git a/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py b/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py index 13caad799eb..6524d6ce61f 100644 --- a/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py +++ b/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py @@ -4,12 +4,11 @@ from __future__ import annotations from uuid import uuid4 -from graphon.nodes.human_input.entities import FormDefinition, HumanInputNodeData, UserAction from sqlalchemy import Engine, select from sqlalchemy.orm import Session from core.repositories.human_input_repository import FormCreateParams, HumanInputFormRepositoryImpl -from core.workflow.human_input_compat import ( +from core.workflow.human_input_adapter import ( DeliveryChannelConfig, EmailDeliveryConfig, EmailDeliveryMethod, @@ -18,7 +17,15 @@ from core.workflow.human_input_compat import ( MemberRecipient, WebAppDeliveryMethod, ) -from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole +from graphon.nodes.human_input.entities import FormDefinition, HumanInputNodeData, UserAction +from models.account import ( + Account, + AccountStatus, + Tenant, + TenantAccountJoin, + TenantAccountRole, + TenantStatus, +) from models.human_input import ( EmailExternalRecipientPayload, EmailMemberRecipientPayload, @@ -29,7 +36,7 @@ from models.human_input import ( def _create_tenant_with_members(session: Session, member_emails: list[str]) -> tuple[Tenant, list[Account]]: - tenant = Tenant(name="Test Tenant", status="normal") + tenant = Tenant(name="Test Tenant", status=TenantStatus.NORMAL) session.add(tenant) session.flush() @@ -39,7 +46,7 @@ def _create_tenant_with_members(session: Session, member_emails: list[str]) -> t email=email, name=f"Member {index}", interface_language="en-US", - status="active", + status=AccountStatus.ACTIVE, ) session.add(account) session.flush() diff --git a/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py b/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py index 0a9b476afc5..5aed230cd4a 100644 --- a/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py +++ b/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py @@ -4,6 +4,17 @@ from datetime import timedelta from unittest.mock import MagicMock import pytest +from sqlalchemy import delete, select +from sqlalchemy.orm import Session + +from core.app.app_config.entities import WorkflowUIBasedAppConfig +from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity +from core.app.workflow.layers import PersistenceWorkflowInfo, WorkflowPersistenceLayer +from core.repositories.human_input_repository import HumanInputFormEntity, HumanInputFormRepository +from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository +from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository +from core.workflow.node_runtime import DifyHumanInputNodeRuntime +from core.workflow.system_variables import build_system_variables from graphon.enums import WorkflowType from graphon.graph import Graph from graphon.graph_engine import GraphEngine @@ -16,20 +27,9 @@ from graphon.nodes.human_input.human_input_node import HumanInputNode from graphon.nodes.start.entities import StartNodeData from graphon.nodes.start.start_node import StartNode from graphon.runtime import GraphRuntimeState, VariablePool -from sqlalchemy import delete, select -from sqlalchemy.orm import Session - -from core.app.app_config.entities import WorkflowUIBasedAppConfig -from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity -from core.app.workflow.layers import PersistenceWorkflowInfo, WorkflowPersistenceLayer -from core.repositories.human_input_repository import HumanInputFormEntity, HumanInputFormRepository -from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository -from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository -from core.workflow.node_runtime import DifyHumanInputNodeRuntime -from core.workflow.system_variables import build_system_variables from libs.datetime_utils import naive_utc_now from models import Account -from models.account import Tenant, TenantAccountJoin, TenantAccountRole +from models.account import AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole, TenantStatus from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom from models.model import App, AppMode, IconType from models.workflow import Workflow, WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom, WorkflowRun @@ -101,8 +101,8 @@ def _build_graph( start_data = StartNodeData(title="start", variables=[]) start_node = StartNode( - id="start", - config={"id": "start", "data": start_data.model_dump()}, + node_id="start", + config=start_data, graph_init_params=params, graph_runtime_state=runtime_state, ) @@ -116,8 +116,8 @@ def _build_graph( ], ) human_node = HumanInputNode( - id="human", - config={"id": "human", "data": human_data.model_dump()}, + node_id="human", + config=human_data, graph_init_params=params, graph_runtime_state=runtime_state, form_repository=form_repository, @@ -130,8 +130,8 @@ def _build_graph( desc=None, ) end_node = EndNode( - id="end", - config={"id": "end", "data": end_data.model_dump()}, + node_id="end", + config=end_data, graph_init_params=params, graph_runtime_state=runtime_state, ) @@ -175,7 +175,7 @@ class TestHumanInputResumeNodeExecutionIntegration: def setup_test_data(self, db_session_with_containers: Session): tenant = Tenant( name="Test Tenant", - status="normal", + status=TenantStatus.NORMAL, ) db_session_with_containers.add(tenant) db_session_with_containers.commit() @@ -184,7 +184,7 @@ class TestHumanInputResumeNodeExecutionIntegration: email="test@example.com", name="Test User", interface_language="en-US", - status="active", + status=AccountStatus.ACTIVE, ) db_session_with_containers.add(account) db_session_with_containers.commit() diff --git a/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py b/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py index cc72dc1cf39..35e41035df6 100644 --- a/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py +++ b/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py @@ -4,13 +4,13 @@ from unittest.mock import patch from uuid import uuid4 import pytest -from graphon.file import File, FileTransferMethod, FileType from sqlalchemy.orm import Session from core.app.file_access import DatabaseFileAccessController from extensions.ext_database import db from extensions.storage.storage_type import StorageType from factories.file_factory import StorageKeyLoader +from graphon.file import File, FileTransferMethod, FileType from models import ToolFile, UploadFile from models.enums import CreatorUserRole @@ -123,9 +123,9 @@ class TestStorageKeyLoader(unittest.TestCase): file_related_id = related_id return File( - id=str(uuid4()), # Generate new UUID for File.id + file_id=str(uuid4()), # Generate new UUID for File.id tenant_id=tenant_id, - type=FileType.DOCUMENT, + file_type=FileType.DOCUMENT, transfer_method=transfer_method, related_id=file_related_id, remote_url=remote_url, diff --git a/api/tests/test_containers_integration_tests/helpers/execution_extra_content.py b/api/tests/test_containers_integration_tests/helpers/execution_extra_content.py index b745aed1417..2fd289dfbca 100644 --- a/api/tests/test_containers_integration_tests/helpers/execution_extra_content.py +++ b/api/tests/test_containers_integration_tests/helpers/execution_extra_content.py @@ -6,7 +6,6 @@ from decimal import Decimal from uuid import uuid4 from graphon.nodes.human_input.entities import FormDefinition, UserAction - from libs.datetime_utils import naive_utc_now from models.account import Account, Tenant, TenantAccountJoin from models.enums import ConversationFromSource, InvokeFrom diff --git a/api/tests/test_containers_integration_tests/models/test_account.py b/api/tests/test_containers_integration_tests/models/test_account.py index 078dc0e8de1..6fd6716cbb9 100644 --- a/api/tests/test_containers_integration_tests/models/test_account.py +++ b/api/tests/test_containers_integration_tests/models/test_account.py @@ -1,79 +1,202 @@ -# import secrets +""" +Integration tests for Account and Tenant model methods that interact with the database. -# import pytest -# from sqlalchemy import select -# from sqlalchemy.orm import Session -# from sqlalchemy.orm.exc import DetachedInstanceError +Migrated from unit_tests/models/test_account_models.py, replacing +@patch("models.account.db") mock patches with real PostgreSQL operations. -# from libs.datetime_utils import naive_utc_now -# from models.account import Account, Tenant, TenantAccountJoin +Covers: +- Account.current_tenant setter (sets _current_tenant and role from TenantAccountJoin) +- Account.set_tenant_id (resolves tenant + role from real join row) +- Account.get_by_openid (AccountIntegrate lookup then Account fetch) +- Tenant.get_accounts (returns accounts linked via TenantAccountJoin) +""" + +from collections.abc import Generator +from uuid import uuid4 + +import pytest +from sqlalchemy import delete +from sqlalchemy.orm import Session + +from models.account import Account, AccountIntegrate, Tenant, TenantAccountJoin, TenantAccountRole -# @pytest.fixture -# def session(db_session_with_containers): -# with Session(db_session_with_containers.get_bind()) as session: -# yield session +def _cleanup_tracked_rows(db_session: Session, tracked: list) -> None: + """Delete rows tracked during the test so committed state does not leak into the DB. + + Rolls back any pending (uncommitted) session state first, then issues DELETE + statements by primary key for each tracked entity (in reverse creation order) + and commits. This cleans up rows created via either flush() or commit(). + """ + db_session.rollback() + for entity in reversed(tracked): + db_session.execute(delete(type(entity)).where(type(entity).id == entity.id)) + db_session.commit() -# @pytest.fixture -# def account(session): -# account = Account( -# name="test account", -# email=f"test_{secrets.token_hex(8)}@example.com", -# ) -# session.add(account) -# session.commit() -# return account +def _build_tenant() -> Tenant: + return Tenant(name=f"Tenant {uuid4()}") -# @pytest.fixture -# def tenant(session): -# tenant = Tenant(name="test tenant") -# session.add(tenant) -# session.commit() -# return tenant +def _build_account(email_prefix: str = "account") -> Account: + return Account( + name=f"Account {uuid4()}", + email=f"{email_prefix}_{uuid4()}@example.com", + password="hashed-password", + password_salt="salt", + interface_language="en-US", + timezone="UTC", + ) -# @pytest.fixture -# def tenant_account_join(session, account, tenant): -# tenant_join = TenantAccountJoin(account_id=account.id, tenant_id=tenant.id) -# session.add(tenant_join) -# session.commit() -# yield tenant_join -# session.delete(tenant_join) -# session.commit() +class _DBTrackingTestBase: + """Base class providing a tracker list and shared row factories for account/tenant tests.""" + + _tracked: list + + @pytest.fixture(autouse=True) + def _setup_cleanup(self, db_session_with_containers: Session) -> Generator[None, None, None]: + self._tracked = [] + yield + _cleanup_tracked_rows(db_session_with_containers, self._tracked) + + def _create_tenant(self, db_session: Session) -> Tenant: + tenant = _build_tenant() + db_session.add(tenant) + db_session.flush() + self._tracked.append(tenant) + return tenant + + def _create_account(self, db_session: Session, email_prefix: str = "account") -> Account: + account = _build_account(email_prefix) + db_session.add(account) + db_session.flush() + self._tracked.append(account) + return account + + def _create_join( + self, db_session: Session, tenant_id: str, account_id: str, role: TenantAccountRole, current: bool = True + ) -> TenantAccountJoin: + join = TenantAccountJoin(tenant_id=tenant_id, account_id=account_id, role=role, current=current) + db_session.add(join) + db_session.flush() + self._tracked.append(join) + return join -# class TestAccountTenant: -# def test_set_current_tenant_should_reload_tenant( -# self, -# db_session_with_containers, -# account, -# tenant, -# tenant_account_join, -# ): -# with Session(db_session_with_containers.get_bind(), expire_on_commit=True) as session: -# scoped_tenant = session.scalars(select(Tenant).where(Tenant.id == tenant.id)).one() -# account.current_tenant = scoped_tenant -# scoped_tenant.created_at = naive_utc_now() -# # session.commit() +class TestAccountCurrentTenantSetter(_DBTrackingTestBase): + """Integration tests for Account.current_tenant property setter.""" -# # Ensure the tenant used in assignment is detached. -# with pytest.raises(DetachedInstanceError): -# _ = scoped_tenant.name + def test_current_tenant_property_returns_cached_tenant(self, db_session_with_containers: Session) -> None: + """current_tenant getter returns the in-memory _current_tenant without DB access.""" + account = self._create_account(db_session_with_containers) + tenant = self._create_tenant(db_session_with_containers) + account._current_tenant = tenant -# assert account._current_tenant.id == tenant.id -# assert account._current_tenant.id == tenant.id + assert account.current_tenant is tenant -# def test_set_tenant_id_should_load_tenant_as_not_expire( -# self, -# flask_app_with_containers, -# account, -# tenant, -# tenant_account_join, -# ): -# with flask_app_with_containers.test_request_context(): -# account.set_tenant_id(tenant.id) + def test_current_tenant_setter_sets_tenant_and_role_when_join_exists( + self, db_session_with_containers: Session + ) -> None: + """Setting current_tenant loads the join row and assigns role when relationship exists.""" + tenant = self._create_tenant(db_session_with_containers) + account = self._create_account(db_session_with_containers) + self._create_join(db_session_with_containers, tenant.id, account.id, TenantAccountRole.OWNER) + db_session_with_containers.commit() -# assert account._current_tenant.id == tenant.id -# assert account._current_tenant.id == tenant.id + account.current_tenant = tenant + + assert account._current_tenant is not None + assert account._current_tenant.id == tenant.id + assert account.role == TenantAccountRole.OWNER + + def test_current_tenant_setter_sets_none_when_no_join_exists(self, db_session_with_containers: Session) -> None: + """Setting current_tenant results in _current_tenant=None when no join row exists.""" + tenant = self._create_tenant(db_session_with_containers) + account = self._create_account(db_session_with_containers) + db_session_with_containers.commit() + + account.current_tenant = tenant + + assert account._current_tenant is None + + +class TestAccountSetTenantId(_DBTrackingTestBase): + """Integration tests for Account.set_tenant_id method.""" + + def test_set_tenant_id_sets_tenant_and_role_when_relationship_exists( + self, db_session_with_containers: Session + ) -> None: + """set_tenant_id loads the tenant and assigns role when a join row exists.""" + tenant = self._create_tenant(db_session_with_containers) + account = self._create_account(db_session_with_containers) + self._create_join(db_session_with_containers, tenant.id, account.id, TenantAccountRole.ADMIN) + db_session_with_containers.commit() + + account.set_tenant_id(tenant.id) + + assert account._current_tenant is not None + assert account._current_tenant.id == tenant.id + assert account.role == TenantAccountRole.ADMIN + + def test_set_tenant_id_does_not_set_tenant_when_no_relationship_exists( + self, db_session_with_containers: Session + ) -> None: + """set_tenant_id does nothing when no join row matches the tenant.""" + tenant = self._create_tenant(db_session_with_containers) + account = self._create_account(db_session_with_containers) + db_session_with_containers.commit() + + account.set_tenant_id(tenant.id) + + assert account._current_tenant is None + + +class TestAccountGetByOpenId(_DBTrackingTestBase): + """Integration tests for Account.get_by_openid class method.""" + + def test_get_by_openid_returns_account_when_integrate_exists(self, db_session_with_containers: Session) -> None: + """get_by_openid returns the Account when a matching AccountIntegrate row exists.""" + account = self._create_account(db_session_with_containers, email_prefix="openid") + provider = "google" + open_id = f"google_{uuid4()}" + + integrate = AccountIntegrate( + account_id=account.id, + provider=provider, + open_id=open_id, + encrypted_token="token", + ) + db_session_with_containers.add(integrate) + db_session_with_containers.flush() + self._tracked.append(integrate) + + result = Account.get_by_openid(provider, open_id) + + assert result is not None + assert result.id == account.id + + def test_get_by_openid_returns_none_when_no_integrate_exists(self, db_session_with_containers: Session) -> None: + """get_by_openid returns None when no AccountIntegrate row matches.""" + result = Account.get_by_openid("github", f"github_{uuid4()}") + + assert result is None + + +class TestTenantGetAccounts(_DBTrackingTestBase): + """Integration tests for Tenant.get_accounts method.""" + + def test_get_accounts_returns_linked_accounts(self, db_session_with_containers: Session) -> None: + """get_accounts returns all accounts linked to the tenant via TenantAccountJoin.""" + tenant = self._create_tenant(db_session_with_containers) + account1 = self._create_account(db_session_with_containers, email_prefix="tenant_member") + account2 = self._create_account(db_session_with_containers, email_prefix="tenant_member") + self._create_join(db_session_with_containers, tenant.id, account1.id, TenantAccountRole.OWNER, current=False) + self._create_join(db_session_with_containers, tenant.id, account2.id, TenantAccountRole.NORMAL, current=False) + + accounts = tenant.get_accounts() + + assert len(accounts) == 2 + account_ids = {a.id for a in accounts} + assert account1.id in account_ids + assert account2.id in account_ids diff --git a/api/tests/test_containers_integration_tests/models/test_conversation_message_inputs.py b/api/tests/test_containers_integration_tests/models/test_conversation_message_inputs.py new file mode 100644 index 00000000000..f10f519e250 --- /dev/null +++ b/api/tests/test_containers_integration_tests/models/test_conversation_message_inputs.py @@ -0,0 +1,149 @@ +""" +Integration tests for Conversation.inputs and Message.inputs tenant resolution. + +Migrated from unit_tests/models/test_model.py, replacing db.session.scalar monkeypatching +with a real App in PostgreSQL so the _resolve_app_tenant_id lookup executes against the DB. +""" + +from collections.abc import Generator +from unittest.mock import patch +from uuid import uuid4 + +import pytest +from sqlalchemy.orm import Session + +from core.workflow.file_reference import build_file_reference +from graphon.file import FILE_MODEL_IDENTITY, FileTransferMethod +from models.model import App, AppMode, Conversation, Message + + +def _build_local_file_mapping(record_id: str, *, tenant_id: str | None = None) -> dict: + mapping: dict = { + "dify_model_identity": FILE_MODEL_IDENTITY, + "transfer_method": FileTransferMethod.LOCAL_FILE, + "reference": build_file_reference(record_id=record_id), + "type": "document", + "filename": "example.txt", + "extension": ".txt", + "mime_type": "text/plain", + "size": 1, + } + if tenant_id is not None: + mapping["tenant_id"] = tenant_id + return mapping + + +class TestConversationMessageInputsTenantResolution: + """Integration tests for Conversation/Message.inputs tenant resolution via real DB lookup.""" + + @pytest.fixture(autouse=True) + def _auto_rollback(self, db_session_with_containers: Session) -> Generator[None, None, None]: + """Automatically rollback session changes after each test.""" + yield + db_session_with_containers.rollback() + + def _create_app(self, db_session: Session) -> App: + tenant_id = str(uuid4()) + app = App( + tenant_id=tenant_id, + name=f"App {uuid4()}", + mode=AppMode.CHAT, + enable_site=False, + enable_api=True, + is_demo=False, + is_public=False, + is_universal=False, + created_by=str(uuid4()), + updated_by=str(uuid4()), + ) + db_session.add(app) + db_session.flush() + return app + + @pytest.mark.parametrize("owner_cls", [Conversation, Message]) + def test_inputs_resolves_tenant_via_db_for_local_file( + self, + db_session_with_containers: Session, + owner_cls: type, + ) -> None: + """Inputs resolves tenant_id from real App row when file mapping has no tenant_id.""" + app = self._create_app(db_session_with_containers) + build_calls: list[tuple[dict, str]] = [] + + def fake_build_from_mapping( + *, mapping, tenant_id, config=None, strict_type_validation=False, access_controller + ): + build_calls.append((dict(mapping), tenant_id)) + return {"tenant_id": tenant_id, "upload_file_id": mapping.get("upload_file_id")} + + with patch("factories.file_factory.build_from_mapping", fake_build_from_mapping): + owner = owner_cls(app_id=app.id) + owner.inputs = {"file": _build_local_file_mapping("upload-1")} + + restored_inputs = owner.inputs + + # The tenant_id should come from the real App row in the DB + assert restored_inputs["file"] == {"tenant_id": app.tenant_id, "upload_file_id": "upload-1"} + assert len(build_calls) == 1 + assert build_calls[0][1] == app.tenant_id + + @pytest.mark.parametrize("owner_cls", [Conversation, Message]) + def test_inputs_uses_serialized_tenant_id_skipping_db_lookup( + self, + db_session_with_containers: Session, + owner_cls: type, + ) -> None: + """Inputs uses tenant_id from the file mapping payload without hitting the DB.""" + app = self._create_app(db_session_with_containers) + payload_tenant_id = "tenant-from-payload" + build_calls: list[tuple[dict, str]] = [] + + def fake_build_from_mapping( + *, mapping, tenant_id, config=None, strict_type_validation=False, access_controller + ): + build_calls.append((dict(mapping), tenant_id)) + return {"tenant_id": tenant_id, "upload_file_id": mapping.get("upload_file_id")} + + with patch("factories.file_factory.build_from_mapping", fake_build_from_mapping): + owner = owner_cls(app_id=app.id) + owner.inputs = {"file": _build_local_file_mapping("upload-1", tenant_id=payload_tenant_id)} + + restored_inputs = owner.inputs + + assert restored_inputs["file"] == {"tenant_id": payload_tenant_id, "upload_file_id": "upload-1"} + assert len(build_calls) == 1 + assert build_calls[0][1] == payload_tenant_id + + @pytest.mark.parametrize("owner_cls", [Conversation, Message]) + def test_inputs_resolves_tenant_for_file_list( + self, + db_session_with_containers: Session, + owner_cls: type, + ) -> None: + """Inputs resolves tenant_id for a list of file mappings.""" + app = self._create_app(db_session_with_containers) + build_calls: list[tuple[dict, str]] = [] + + def fake_build_from_mapping( + *, mapping, tenant_id, config=None, strict_type_validation=False, access_controller + ): + build_calls.append((dict(mapping), tenant_id)) + return {"tenant_id": tenant_id, "upload_file_id": mapping.get("upload_file_id")} + + with patch("factories.file_factory.build_from_mapping", fake_build_from_mapping): + owner = owner_cls(app_id=app.id) + owner.inputs = { + "files": [ + _build_local_file_mapping("upload-1"), + _build_local_file_mapping("upload-2"), + ] + } + + restored_inputs = owner.inputs + + assert len(build_calls) == 2 + assert all(call[1] == app.tenant_id for call in build_calls) + assert restored_inputs["files"] == [ + {"tenant_id": app.tenant_id, "upload_file_id": "upload-1"}, + {"tenant_id": app.tenant_id, "upload_file_id": "upload-2"}, + ] diff --git a/api/tests/test_containers_integration_tests/models/test_conversation_status_count.py b/api/tests/test_containers_integration_tests/models/test_conversation_status_count.py new file mode 100644 index 00000000000..6352f815df7 --- /dev/null +++ b/api/tests/test_containers_integration_tests/models/test_conversation_status_count.py @@ -0,0 +1,314 @@ +""" +Integration tests for Conversation.status_count and Site.generate_code model properties. + +Migrated from unit_tests/models/test_app_models.py TestConversationStatusCount and +test_site_generate_code, replacing db.session.scalars mocks with real PostgreSQL queries. +""" + +from collections.abc import Generator +from uuid import uuid4 + +import pytest +from sqlalchemy.orm import Session + +from graphon.enums import WorkflowExecutionStatus +from models.enums import ConversationFromSource, InvokeFrom +from models.model import App, AppMode, Conversation, Message, Site +from models.workflow import Workflow, WorkflowRun, WorkflowRunTriggeredFrom, WorkflowType + + +class TestConversationStatusCount: + """Integration tests for Conversation.status_count property.""" + + @pytest.fixture(autouse=True) + def _auto_rollback(self, db_session_with_containers: Session) -> Generator[None, None, None]: + """Automatically rollback session changes after each test.""" + yield + db_session_with_containers.rollback() + + def _create_app(self, db_session: Session, tenant_id: str, created_by: str) -> App: + app = App( + tenant_id=tenant_id, + name=f"App {uuid4()}", + mode=AppMode.ADVANCED_CHAT, + enable_site=False, + enable_api=True, + is_demo=False, + is_public=False, + is_universal=False, + created_by=created_by, + updated_by=created_by, + ) + db_session.add(app) + db_session.flush() + return app + + def _create_conversation(self, db_session: Session, app: App) -> Conversation: + conversation = Conversation( + app_id=app.id, + mode=app.mode, + name=f"Conversation {uuid4()}", + summary="", + inputs={}, + introduction="", + system_instruction="", + system_instruction_tokens=0, + status="normal", + invoke_from=InvokeFrom.WEB_APP, + from_source=ConversationFromSource.API, + dialogue_count=0, + is_deleted=False, + ) + conversation.inputs = {} + db_session.add(conversation) + db_session.flush() + return conversation + + def _create_workflow(self, db_session: Session, app: App, created_by: str) -> Workflow: + workflow = Workflow( + tenant_id=app.tenant_id, + app_id=app.id, + type=WorkflowType.CHAT, + version="draft", + graph="{}", + created_by=created_by, + ) + workflow._features = "{}" + db_session.add(workflow) + db_session.flush() + return workflow + + def _create_workflow_run( + self, db_session: Session, app: App, workflow: Workflow, status: WorkflowExecutionStatus, created_by: str + ) -> WorkflowRun: + run = WorkflowRun( + tenant_id=app.tenant_id, + app_id=app.id, + workflow_id=workflow.id, + type=WorkflowType.CHAT, + triggered_from=WorkflowRunTriggeredFrom.APP_RUN, + version="draft", + status=status, + created_by_role="account", + created_by=created_by, + ) + db_session.add(run) + db_session.flush() + return run + + def _create_message( + self, db_session: Session, app: App, conversation: Conversation, workflow_run_id: str | None = None + ) -> Message: + message = Message( + app_id=app.id, + conversation_id=conversation.id, + _inputs={}, + query="Test query", + message={"role": "user", "content": "Test query"}, + answer="Test answer", + model_provider="openai", + model_id="gpt-4", + message_tokens=10, + message_unit_price=0, + answer_tokens=10, + answer_unit_price=0, + total_price=0, + currency="USD", + from_source=ConversationFromSource.API, + invoke_from=InvokeFrom.WEB_APP, + workflow_run_id=workflow_run_id, + ) + db_session.add(message) + db_session.flush() + return message + + def test_status_count_returns_none_when_no_messages(self, db_session_with_containers: Session) -> None: + """status_count returns None when conversation has no messages with workflow_run_id.""" + tenant_id = str(uuid4()) + created_by = str(uuid4()) + + app = self._create_app(db_session_with_containers, tenant_id, created_by) + conversation = self._create_conversation(db_session_with_containers, app) + + result = conversation.status_count + + assert result is None + + def test_status_count_returns_none_when_messages_have_no_workflow_run_id( + self, db_session_with_containers: Session + ) -> None: + """status_count returns None when messages exist but none have workflow_run_id.""" + tenant_id = str(uuid4()) + created_by = str(uuid4()) + + app = self._create_app(db_session_with_containers, tenant_id, created_by) + conversation = self._create_conversation(db_session_with_containers, app) + self._create_message(db_session_with_containers, app, conversation, workflow_run_id=None) + + result = conversation.status_count + + assert result is None + + def test_status_count_counts_succeeded_workflow_run(self, db_session_with_containers: Session) -> None: + """status_count correctly counts succeeded workflow runs.""" + tenant_id = str(uuid4()) + created_by = str(uuid4()) + + app = self._create_app(db_session_with_containers, tenant_id, created_by) + conversation = self._create_conversation(db_session_with_containers, app) + workflow = self._create_workflow(db_session_with_containers, app, created_by) + run = self._create_workflow_run( + db_session_with_containers, app, workflow, WorkflowExecutionStatus.SUCCEEDED, created_by + ) + self._create_message(db_session_with_containers, app, conversation, workflow_run_id=run.id) + + result = conversation.status_count + + assert result is not None + assert result["success"] == 1 + assert result["failed"] == 0 + assert result["partial_success"] == 0 + assert result["paused"] == 0 + + def test_status_count_counts_failed_workflow_run(self, db_session_with_containers: Session) -> None: + """status_count correctly counts failed workflow runs.""" + tenant_id = str(uuid4()) + created_by = str(uuid4()) + + app = self._create_app(db_session_with_containers, tenant_id, created_by) + conversation = self._create_conversation(db_session_with_containers, app) + workflow = self._create_workflow(db_session_with_containers, app, created_by) + run = self._create_workflow_run( + db_session_with_containers, app, workflow, WorkflowExecutionStatus.FAILED, created_by + ) + self._create_message(db_session_with_containers, app, conversation, workflow_run_id=run.id) + + result = conversation.status_count + + assert result is not None + assert result["success"] == 0 + assert result["failed"] == 1 + assert result["partial_success"] == 0 + assert result["paused"] == 0 + + def test_status_count_counts_paused_workflow_run(self, db_session_with_containers: Session) -> None: + """status_count correctly counts paused workflow runs.""" + tenant_id = str(uuid4()) + created_by = str(uuid4()) + + app = self._create_app(db_session_with_containers, tenant_id, created_by) + conversation = self._create_conversation(db_session_with_containers, app) + workflow = self._create_workflow(db_session_with_containers, app, created_by) + run = self._create_workflow_run( + db_session_with_containers, app, workflow, WorkflowExecutionStatus.PAUSED, created_by + ) + self._create_message(db_session_with_containers, app, conversation, workflow_run_id=run.id) + + result = conversation.status_count + + assert result is not None + assert result["success"] == 0 + assert result["failed"] == 0 + assert result["partial_success"] == 0 + assert result["paused"] == 1 + + def test_status_count_multiple_statuses(self, db_session_with_containers: Session) -> None: + """status_count counts multiple workflow runs with different statuses.""" + tenant_id = str(uuid4()) + created_by = str(uuid4()) + + app = self._create_app(db_session_with_containers, tenant_id, created_by) + conversation = self._create_conversation(db_session_with_containers, app) + workflow = self._create_workflow(db_session_with_containers, app, created_by) + + for status in [ + WorkflowExecutionStatus.SUCCEEDED, + WorkflowExecutionStatus.FAILED, + WorkflowExecutionStatus.PARTIAL_SUCCEEDED, + WorkflowExecutionStatus.PAUSED, + ]: + run = self._create_workflow_run(db_session_with_containers, app, workflow, status, created_by) + self._create_message(db_session_with_containers, app, conversation, workflow_run_id=run.id) + + result = conversation.status_count + + assert result is not None + assert result["success"] == 1 + assert result["failed"] == 1 + assert result["partial_success"] == 1 + assert result["paused"] == 1 + + def test_status_count_filters_workflow_runs_by_app_id(self, db_session_with_containers: Session) -> None: + """status_count excludes workflow runs belonging to a different app.""" + tenant_id = str(uuid4()) + created_by = str(uuid4()) + + app = self._create_app(db_session_with_containers, tenant_id, created_by) + other_app = self._create_app(db_session_with_containers, tenant_id, created_by) + conversation = self._create_conversation(db_session_with_containers, app) + workflow = self._create_workflow(db_session_with_containers, other_app, created_by) + + # Workflow run belongs to other_app, not app + other_run = self._create_workflow_run( + db_session_with_containers, other_app, workflow, WorkflowExecutionStatus.SUCCEEDED, created_by + ) + # Message references that run but is in a conversation under app + self._create_message(db_session_with_containers, app, conversation, workflow_run_id=other_run.id) + + result = conversation.status_count + + # The run should be excluded because app_id filter doesn't match + assert result is not None + assert result["success"] == 0 + + +class TestSiteGenerateCode: + """Integration tests for Site.generate_code static method.""" + + @pytest.fixture(autouse=True) + def _auto_rollback(self, db_session_with_containers: Session) -> Generator[None, None, None]: + """Automatically rollback session changes after each test.""" + yield + db_session_with_containers.rollback() + + def test_generate_code_returns_string_of_correct_length(self, db_session_with_containers: Session) -> None: + """Site.generate_code returns a code string of the requested length.""" + code = Site.generate_code(8) + + assert isinstance(code, str) + assert len(code) == 8 + + def test_generate_code_avoids_duplicates(self, db_session_with_containers: Session) -> None: + """Site.generate_code returns a code not already in use.""" + tenant_id = str(uuid4()) + app = App( + tenant_id=tenant_id, + name="Test App", + mode=AppMode.CHAT, + enable_site=True, + enable_api=False, + is_demo=False, + is_public=False, + is_universal=False, + created_by=str(uuid4()), + updated_by=str(uuid4()), + ) + db_session_with_containers.add(app) + db_session_with_containers.flush() + + site = Site( + app_id=app.id, + title="Test Site", + default_language="en-US", + customize_token_strategy="not_allow", + ) + # Set an explicit code so generate_code must avoid it + site.code = "AAAAAAAA" + db_session_with_containers.add(site) + db_session_with_containers.flush() + + code = Site.generate_code(8) + + assert isinstance(code, str) + assert len(code) == 8 + assert code != site.code diff --git a/api/tests/test_containers_integration_tests/models/test_types_enum_text.py b/api/tests/test_containers_integration_tests/models/test_types_enum_text.py index 9cf96c1ca70..b325c97f7d1 100644 --- a/api/tests/test_containers_integration_tests/models/test_types_enum_text.py +++ b/api/tests/test_containers_integration_tests/models/test_types_enum_text.py @@ -5,11 +5,12 @@ from typing import Any, NamedTuple import pytest import sqlalchemy as sa from sqlalchemy import exc as sa_exc -from sqlalchemy import insert +from sqlalchemy import insert, select from sqlalchemy.engine import Connection, Engine from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column from sqlalchemy.sql.sqltypes import VARCHAR +from graphon.model_runtime.entities.model_entities import ModelType from models.types import EnumText _USER_TABLE = "enum_text_users" @@ -58,6 +59,13 @@ class _ColumnTest(_Base): long_value: Mapped[_EnumWithLongValue] = mapped_column(EnumText(enum_class=_EnumWithLongValue), nullable=False) +class _LegacyModelTypeRecord(_Base): + __tablename__ = "enum_text_legacy_model_type_test" + + id: Mapped[int] = mapped_column(sa.Integer, primary_key=True) + model_type: Mapped[ModelType] = mapped_column(EnumText(enum_class=ModelType), nullable=False) + + def _first[T](it: Iterable[T]) -> T: ls = list(it) if not ls: @@ -129,12 +137,12 @@ class TestEnumText: session.commit() with Session(engine_with_containers) as session: - user = session.query(_User).where(_User.id == admin_user_id).first() + user = session.scalar(select(_User).where(_User.id == admin_user_id).limit(1)) assert user.user_type == _UserType.admin assert user.user_type_nullable is None with Session(engine_with_containers) as session: - user = session.query(_User).where(_User.id == normal_user_id).first() + user = session.scalar(select(_User).where(_User.id == normal_user_id).limit(1)) assert user.user_type == _UserType.normal assert user.user_type_nullable == _UserType.normal @@ -198,6 +206,26 @@ class TestEnumText: with pytest.raises(ValueError) as exc: with Session(engine_with_containers) as session: - _user = session.query(_User).where(_User.id == 1).first() + _user = session.scalar(select(_User).where(_User.id == 1).limit(1)) assert str(exc.value) == "'invalid' is not a valid _UserType" + + def test_select_legacy_model_type_values(self, engine_with_containers: Engine): + insertion_sql = """ + INSERT INTO enum_text_legacy_model_type_test (id, model_type) VALUES + (1, 'text-generation'), + (2, 'embeddings'), + (3, 'reranking'); + """ + with Session(engine_with_containers) as session: + session.execute(sa.text(insertion_sql)) + session.commit() + + with Session(engine_with_containers) as session: + records = session.scalars(select(_LegacyModelTypeRecord).order_by(_LegacyModelTypeRecord.id)).all() + + assert [record.model_type for record in records] == [ + ModelType.LLM, + ModelType.TEXT_EMBEDDING, + ModelType.RERANK, + ] diff --git a/api/tests/test_containers_integration_tests/models/test_workflow_node_execution_model.py b/api/tests/test_containers_integration_tests/models/test_workflow_node_execution_model.py new file mode 100644 index 00000000000..14c2263110c --- /dev/null +++ b/api/tests/test_containers_integration_tests/models/test_workflow_node_execution_model.py @@ -0,0 +1,170 @@ +""" +Integration tests for WorkflowNodeExecutionModel.created_by_account and .created_by_end_user. + +Migrated from unit_tests/models/test_workflow_trigger_log.py, replacing +monkeypatch.setattr(db.session, "scalar", ...) with real Account/EndUser rows +persisted in PostgreSQL so the db.session.get() call executes against the DB. +""" + +from collections.abc import Generator +from uuid import uuid4 + +import pytest +from sqlalchemy.orm import Session + +from models.account import Account +from models.enums import CreatorUserRole +from models.model import App, AppMode, EndUser +from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom + + +class TestWorkflowNodeExecutionModelCreatedBy: + """Integration tests for WorkflowNodeExecutionModel creator lookup properties.""" + + @pytest.fixture(autouse=True) + def _auto_rollback(self, db_session_with_containers: Session) -> Generator[None, None, None]: + """Automatically rollback session changes after each test.""" + yield + db_session_with_containers.rollback() + + def _create_account(self, db_session: Session) -> Account: + account = Account( + name="Test Account", + email=f"test_{uuid4()}@example.com", + password="hashed-password", + password_salt="salt", + interface_language="en-US", + timezone="UTC", + ) + db_session.add(account) + db_session.flush() + return account + + def _create_end_user(self, db_session: Session, tenant_id: str, app_id: str) -> EndUser: + end_user = EndUser( + tenant_id=tenant_id, + app_id=app_id, + type="service_api", + external_user_id=f"ext-{uuid4()}", + name="End User", + session_id=f"session-{uuid4()}", + ) + end_user.is_anonymous = False + db_session.add(end_user) + db_session.flush() + return end_user + + def _create_app(self, db_session: Session, tenant_id: str, created_by: str) -> App: + app = App( + tenant_id=tenant_id, + name=f"App {uuid4()}", + mode=AppMode.WORKFLOW, + enable_site=False, + enable_api=True, + is_demo=False, + is_public=False, + is_universal=False, + created_by=created_by, + updated_by=created_by, + ) + db_session.add(app) + db_session.flush() + return app + + def _make_execution( + self, tenant_id: str, app_id: str, created_by_role: str, created_by: str + ) -> WorkflowNodeExecutionModel: + return WorkflowNodeExecutionModel( + tenant_id=tenant_id, + app_id=app_id, + workflow_id=str(uuid4()), + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + workflow_run_id=None, + index=1, + predecessor_node_id=None, + node_execution_id=None, + node_id="n1", + node_type="start", + title="Start", + inputs=None, + process_data=None, + outputs=None, + status="succeeded", + error=None, + elapsed_time=0.0, + execution_metadata=None, + created_by_role=created_by_role, + created_by=created_by, + ) + + def test_created_by_account_returns_account_when_role_is_account(self, db_session_with_containers: Session) -> None: + """created_by_account returns the Account row when role is ACCOUNT.""" + account = self._create_account(db_session_with_containers) + app = self._create_app(db_session_with_containers, str(uuid4()), account.id) + + execution = self._make_execution( + tenant_id=app.tenant_id, + app_id=app.id, + created_by_role=CreatorUserRole.ACCOUNT.value, + created_by=account.id, + ) + + result = execution.created_by_account + + assert result is not None + assert result.id == account.id + + def test_created_by_account_returns_none_when_role_is_end_user(self, db_session_with_containers: Session) -> None: + """created_by_account returns None when role is END_USER, even if an Account exists.""" + account = self._create_account(db_session_with_containers) + app = self._create_app(db_session_with_containers, str(uuid4()), account.id) + + execution = self._make_execution( + tenant_id=app.tenant_id, + app_id=app.id, + created_by_role=CreatorUserRole.END_USER.value, + created_by=account.id, + ) + + result = execution.created_by_account + + assert result is None + + def test_created_by_end_user_returns_end_user_when_role_is_end_user( + self, db_session_with_containers: Session + ) -> None: + """created_by_end_user returns the EndUser row when role is END_USER.""" + account = self._create_account(db_session_with_containers) + tenant_id = str(uuid4()) + app = self._create_app(db_session_with_containers, tenant_id, account.id) + end_user = self._create_end_user(db_session_with_containers, tenant_id, app.id) + + execution = self._make_execution( + tenant_id=tenant_id, + app_id=app.id, + created_by_role=CreatorUserRole.END_USER.value, + created_by=end_user.id, + ) + + result = execution.created_by_end_user + + assert result is not None + assert result.id == end_user.id + + def test_created_by_end_user_returns_none_when_role_is_account(self, db_session_with_containers: Session) -> None: + """created_by_end_user returns None when role is ACCOUNT, even if an EndUser exists.""" + account = self._create_account(db_session_with_containers) + tenant_id = str(uuid4()) + app = self._create_app(db_session_with_containers, tenant_id, account.id) + end_user = self._create_end_user(db_session_with_containers, tenant_id, app.id) + + execution = self._make_execution( + tenant_id=tenant_id, + app_id=app.id, + created_by_role=CreatorUserRole.ACCOUNT.value, + created_by=end_user.id, + ) + + result = execution.created_by_end_user + + assert result is None diff --git a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_node_execution_repository.py b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_node_execution_repository.py index a68b3a08c7a..641399c7f94 100644 --- a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_node_execution_repository.py +++ b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_node_execution_repository.py @@ -5,10 +5,10 @@ from __future__ import annotations from datetime import timedelta from uuid import uuid4 -from graphon.enums import WorkflowNodeExecutionStatus from sqlalchemy import Engine, delete from sqlalchemy.orm import Session, sessionmaker +from graphon.enums import WorkflowNodeExecutionStatus from libs.datetime_utils import naive_utc_now from models.enums import CreatorUserRole from models.workflow import WorkflowNodeExecutionModel diff --git a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py index d28cfda1598..aebe87839c5 100644 --- a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py +++ b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py @@ -8,15 +8,15 @@ from unittest.mock import Mock from uuid import uuid4 import pytest +from sqlalchemy import Engine, delete, select +from sqlalchemy.orm import Session, sessionmaker + +from extensions.ext_storage import storage from graphon.entities import WorkflowExecution from graphon.entities.pause_reason import HumanInputRequired, PauseReasonType from graphon.enums import WorkflowExecutionStatus from graphon.nodes.human_input.entities import FormDefinition, FormInput, UserAction from graphon.nodes.human_input.enums import FormInputType, HumanInputFormStatus -from sqlalchemy import Engine, delete, select -from sqlalchemy.orm import Session, sessionmaker - -from extensions.ext_storage import storage from libs.datetime_utils import naive_utc_now from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom from models.human_input import ( @@ -220,7 +220,6 @@ class TestDeleteRunsWithRelated: created_by=test_scope.user_id, ) pause = WorkflowPause( - id=str(uuid4()), workflow_id=test_scope.workflow_id, workflow_run_id=workflow_run.id, state_object_key=f"workflow-state-{uuid4()}.json", @@ -280,7 +279,6 @@ class TestCountRunsWithRelated: created_by=test_scope.user_id, ) pause = WorkflowPause( - id=str(uuid4()), workflow_id=test_scope.workflow_id, workflow_run_id=workflow_run.id, state_object_key=f"workflow-state-{uuid4()}.json", @@ -544,7 +542,6 @@ class TestPrivateWorkflowPauseEntity: status=WorkflowExecutionStatus.RUNNING, ) pause = WorkflowPause( - id=str(uuid4()), workflow_id=test_scope.workflow_id, workflow_run_id=workflow_run.id, state_object_key=f"workflow-state-{uuid4()}.json", @@ -574,7 +571,6 @@ class TestPrivateWorkflowPauseEntity: ) state_key = f"workflow-state-{uuid4()}.json" pause = WorkflowPause( - id=str(uuid4()), workflow_id=test_scope.workflow_id, workflow_run_id=workflow_run.id, state_object_key=state_key, @@ -606,7 +602,6 @@ class TestPrivateWorkflowPauseEntity: ) state_key = f"workflow-state-{uuid4()}.json" pause = WorkflowPause( - id=str(uuid4()), workflow_id=test_scope.workflow_id, workflow_run_id=workflow_run.id, state_object_key=state_key, @@ -672,7 +667,6 @@ class TestBuildHumanInputRequiredReason: status=WorkflowExecutionStatus.RUNNING, ) pause = WorkflowPause( - id=str(uuid4()), workflow_id=test_scope.workflow_id, workflow_run_id=workflow_run.id, state_object_key=f"workflow-state-{uuid4()}.json", diff --git a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py index 7f44eb6ca37..54b7afc018a 100644 --- a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py +++ b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py @@ -12,11 +12,11 @@ from decimal import Decimal from uuid import uuid4 import pytest -from graphon.nodes.human_input.entities import FormDefinition, UserAction -from graphon.nodes.human_input.enums import HumanInputFormStatus from sqlalchemy import Engine, delete, select from sqlalchemy.orm import Session, sessionmaker +from graphon.nodes.human_input.entities import FormDefinition, UserAction +from graphon.nodes.human_input.enums import HumanInputFormStatus from libs.datetime_utils import naive_utc_now from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.enums import ConversationFromSource, InvokeFrom @@ -271,7 +271,7 @@ def _create_recipient( def _create_delivery(session: Session, *, form_id: str) -> HumanInputDelivery: - from core.workflow.human_input_compat import DeliveryMethodType + from core.workflow.human_input_adapter import DeliveryMethodType from models.human_input import ConsoleDeliveryPayload delivery = HumanInputDelivery( diff --git a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_workflow_node_execution_repository.py b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_workflow_node_execution_repository.py new file mode 100644 index 00000000000..fa78f1c28b1 --- /dev/null +++ b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_workflow_node_execution_repository.py @@ -0,0 +1,395 @@ +"""Testcontainers integration tests for SQLAlchemyWorkflowNodeExecutionRepository.""" + +from __future__ import annotations + +import json +from datetime import datetime +from decimal import Decimal +from uuid import uuid4 + +from sqlalchemy import Engine +from sqlalchemy.orm import Session, sessionmaker + +from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository +from core.repositories.factory import OrderConfig +from graphon.entities import WorkflowNodeExecution +from graphon.enums import ( + BuiltinNodeTypes, + WorkflowNodeExecutionMetadataKey, + WorkflowNodeExecutionStatus, +) +from graphon.model_runtime.utils.encoders import jsonable_encoder +from models.account import Account, Tenant +from models.enums import CreatorUserRole +from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom + + +def _create_account_with_tenant(session: Session) -> Account: + tenant = Tenant(name="Test Workspace") + session.add(tenant) + session.flush() + + account = Account(name="test", email=f"test-{uuid4()}@example.com") + session.add(account) + session.flush() + + account._current_tenant = tenant + return account + + +def _make_repo(session: Session, account: Account, app_id: str) -> SQLAlchemyWorkflowNodeExecutionRepository: + engine = session.get_bind() + assert isinstance(engine, Engine) + return SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=sessionmaker(bind=engine, expire_on_commit=False), + user=account, + app_id=app_id, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + +def _create_node_execution_model( + session: Session, + *, + tenant_id: str, + app_id: str, + workflow_id: str, + workflow_run_id: str, + index: int = 1, + status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING, +) -> WorkflowNodeExecutionModel: + model = WorkflowNodeExecutionModel( + id=str(uuid4()), + tenant_id=tenant_id, + app_id=app_id, + workflow_id=workflow_id, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + workflow_run_id=workflow_run_id, + index=index, + predecessor_node_id=None, + node_execution_id=str(uuid4()), + node_id=f"node-{index}", + node_type=BuiltinNodeTypes.START, + title=f"Test Node {index}", + inputs='{"input_key": "input_value"}', + process_data='{"process_key": "process_value"}', + outputs='{"output_key": "output_value"}', + status=status, + error=None, + elapsed_time=1.5, + execution_metadata="{}", + created_at=datetime.now(), + created_by_role=CreatorUserRole.ACCOUNT, + created_by=str(uuid4()), + finished_at=None, + ) + session.add(model) + session.flush() + return model + + +class TestSave: + def test_save_new_record(self, db_session_with_containers: Session) -> None: + account = _create_account_with_tenant(db_session_with_containers) + app_id = str(uuid4()) + repo = _make_repo(db_session_with_containers, account, app_id) + + execution = WorkflowNodeExecution( + id=str(uuid4()), + workflow_id=str(uuid4()), + node_execution_id=str(uuid4()), + workflow_execution_id=str(uuid4()), + index=1, + predecessor_node_id=None, + node_id="node-1", + node_type=BuiltinNodeTypes.START, + title="Test Node", + inputs={"input_key": "input_value"}, + process_data={"process_key": "process_value"}, + outputs={"result": "success"}, + status=WorkflowNodeExecutionStatus.RUNNING, + error=None, + elapsed_time=1.5, + metadata={WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: 100}, + created_at=datetime.now(), + finished_at=None, + ) + + repo.save(execution) + + engine = db_session_with_containers.get_bind() + assert isinstance(engine, Engine) + with sessionmaker(bind=engine, expire_on_commit=False)() as verify_session: + saved = verify_session.get(WorkflowNodeExecutionModel, execution.id) + assert saved is not None + assert saved.tenant_id == account.current_tenant_id + assert saved.app_id == app_id + assert saved.node_id == "node-1" + assert saved.status == WorkflowNodeExecutionStatus.RUNNING + + def test_save_updates_existing_record(self, db_session_with_containers: Session) -> None: + account = _create_account_with_tenant(db_session_with_containers) + repo = _make_repo(db_session_with_containers, account, str(uuid4())) + + execution = WorkflowNodeExecution( + id=str(uuid4()), + workflow_id=str(uuid4()), + node_execution_id=str(uuid4()), + workflow_execution_id=str(uuid4()), + index=1, + predecessor_node_id=None, + node_id="node-1", + node_type=BuiltinNodeTypes.START, + title="Test Node", + inputs=None, + process_data=None, + outputs=None, + status=WorkflowNodeExecutionStatus.RUNNING, + error=None, + elapsed_time=0.0, + metadata=None, + created_at=datetime.now(), + finished_at=None, + ) + + repo.save(execution) + + execution.status = WorkflowNodeExecutionStatus.SUCCEEDED + execution.elapsed_time = 2.5 + repo.save(execution) + + engine = db_session_with_containers.get_bind() + assert isinstance(engine, Engine) + with sessionmaker(bind=engine, expire_on_commit=False)() as verify_session: + saved = verify_session.get(WorkflowNodeExecutionModel, execution.id) + assert saved is not None + assert saved.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert saved.elapsed_time == 2.5 + + +class TestGetByWorkflowExecution: + def test_returns_executions_ordered(self, db_session_with_containers: Session) -> None: + account = _create_account_with_tenant(db_session_with_containers) + tenant_id = account.current_tenant_id + app_id = str(uuid4()) + workflow_id = str(uuid4()) + workflow_run_id = str(uuid4()) + repo = _make_repo(db_session_with_containers, account, app_id) + + _create_node_execution_model( + db_session_with_containers, + tenant_id=tenant_id, + app_id=app_id, + workflow_id=workflow_id, + workflow_run_id=workflow_run_id, + index=1, + status=WorkflowNodeExecutionStatus.SUCCEEDED, + ) + _create_node_execution_model( + db_session_with_containers, + tenant_id=tenant_id, + app_id=app_id, + workflow_id=workflow_id, + workflow_run_id=workflow_run_id, + index=2, + status=WorkflowNodeExecutionStatus.SUCCEEDED, + ) + db_session_with_containers.commit() + + order_config = OrderConfig(order_by=["index"], order_direction="desc") + result = repo.get_by_workflow_execution( + workflow_execution_id=workflow_run_id, + order_config=order_config, + ) + + assert len(result) == 2 + assert result[0].index == 2 + assert result[1].index == 1 + assert all(isinstance(r, WorkflowNodeExecution) for r in result) + + def test_excludes_paused_executions(self, db_session_with_containers: Session) -> None: + account = _create_account_with_tenant(db_session_with_containers) + tenant_id = account.current_tenant_id + app_id = str(uuid4()) + workflow_id = str(uuid4()) + workflow_run_id = str(uuid4()) + repo = _make_repo(db_session_with_containers, account, app_id) + + _create_node_execution_model( + db_session_with_containers, + tenant_id=tenant_id, + app_id=app_id, + workflow_id=workflow_id, + workflow_run_id=workflow_run_id, + index=1, + status=WorkflowNodeExecutionStatus.RUNNING, + ) + _create_node_execution_model( + db_session_with_containers, + tenant_id=tenant_id, + app_id=app_id, + workflow_id=workflow_id, + workflow_run_id=workflow_run_id, + index=2, + status=WorkflowNodeExecutionStatus.PAUSED, + ) + db_session_with_containers.commit() + + result = repo.get_by_workflow_execution(workflow_execution_id=workflow_run_id) + + assert len(result) == 1 + assert result[0].index == 1 + + +class TestToDbModel: + def test_converts_domain_to_db_model(self, db_session_with_containers: Session) -> None: + account = _create_account_with_tenant(db_session_with_containers) + app_id = str(uuid4()) + repo = _make_repo(db_session_with_containers, account, app_id) + + domain_model = WorkflowNodeExecution( + id="test-id", + workflow_id="test-workflow-id", + node_execution_id="test-node-execution-id", + workflow_execution_id="test-workflow-run-id", + index=1, + predecessor_node_id="test-predecessor-id", + node_id="test-node-id", + node_type=BuiltinNodeTypes.START, + title="Test Node", + inputs={"input_key": "input_value"}, + process_data={"process_key": "process_value"}, + outputs={"output_key": "output_value"}, + status=WorkflowNodeExecutionStatus.RUNNING, + error=None, + elapsed_time=1.5, + metadata={ + WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: 100, + WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: Decimal("0.0"), + }, + created_at=datetime.now(), + finished_at=None, + ) + + db_model = repo._to_db_model(domain_model) + + assert isinstance(db_model, WorkflowNodeExecutionModel) + assert db_model.id == domain_model.id + assert db_model.tenant_id == account.current_tenant_id + assert db_model.app_id == app_id + assert db_model.workflow_id == domain_model.workflow_id + assert db_model.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN + assert db_model.workflow_run_id == domain_model.workflow_execution_id + assert db_model.index == domain_model.index + assert db_model.predecessor_node_id == domain_model.predecessor_node_id + assert db_model.node_execution_id == domain_model.node_execution_id + assert db_model.node_id == domain_model.node_id + assert db_model.node_type == domain_model.node_type + assert db_model.title == domain_model.title + assert db_model.inputs_dict == domain_model.inputs + assert db_model.process_data_dict == domain_model.process_data + assert db_model.outputs_dict == domain_model.outputs + assert db_model.execution_metadata_dict == jsonable_encoder(domain_model.metadata) + assert db_model.status == domain_model.status + assert db_model.error == domain_model.error + assert db_model.elapsed_time == domain_model.elapsed_time + assert db_model.created_at == domain_model.created_at + assert db_model.created_by_role == CreatorUserRole.ACCOUNT + assert db_model.created_by == account.id + assert db_model.finished_at == domain_model.finished_at + + +class TestToDomainModel: + def test_converts_db_to_domain_model(self, db_session_with_containers: Session) -> None: + account = _create_account_with_tenant(db_session_with_containers) + app_id = str(uuid4()) + repo = _make_repo(db_session_with_containers, account, app_id) + + inputs_dict = {"input_key": "input_value"} + process_data_dict = {"process_key": "process_value"} + outputs_dict = {"output_key": "output_value"} + metadata_dict = {str(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS): 100} + now = datetime.now() + + db_model = WorkflowNodeExecutionModel() + db_model.id = "test-id" + db_model.tenant_id = account.current_tenant_id + db_model.app_id = app_id + db_model.workflow_id = "test-workflow-id" + db_model.triggered_from = "workflow-run" + db_model.workflow_run_id = "test-workflow-run-id" + db_model.index = 1 + db_model.predecessor_node_id = "test-predecessor-id" + db_model.node_execution_id = "test-node-execution-id" + db_model.node_id = "test-node-id" + db_model.node_type = BuiltinNodeTypes.START + db_model.title = "Test Node" + db_model.inputs = json.dumps(inputs_dict) + db_model.process_data = json.dumps(process_data_dict) + db_model.outputs = json.dumps(outputs_dict) + db_model.status = WorkflowNodeExecutionStatus.RUNNING + db_model.error = None + db_model.elapsed_time = 1.5 + db_model.execution_metadata = json.dumps(metadata_dict) + db_model.created_at = now + db_model.created_by_role = "account" + db_model.created_by = account.id + db_model.finished_at = None + + domain_model = repo._to_domain_model(db_model) + + assert isinstance(domain_model, WorkflowNodeExecution) + assert domain_model.id == "test-id" + assert domain_model.workflow_id == "test-workflow-id" + assert domain_model.workflow_execution_id == "test-workflow-run-id" + assert domain_model.index == 1 + assert domain_model.predecessor_node_id == "test-predecessor-id" + assert domain_model.node_execution_id == "test-node-execution-id" + assert domain_model.node_id == "test-node-id" + assert domain_model.node_type == BuiltinNodeTypes.START + assert domain_model.title == "Test Node" + assert domain_model.inputs == inputs_dict + assert domain_model.process_data == process_data_dict + assert domain_model.outputs == outputs_dict + assert domain_model.status == WorkflowNodeExecutionStatus.RUNNING + assert domain_model.error is None + assert domain_model.elapsed_time == 1.5 + assert domain_model.metadata == {WorkflowNodeExecutionMetadataKey(k): v for k, v in metadata_dict.items()} + assert domain_model.created_at == now + assert domain_model.finished_at is None + + def test_domain_model_without_offload_data(self, db_session_with_containers: Session) -> None: + account = _create_account_with_tenant(db_session_with_containers) + repo = _make_repo(db_session_with_containers, account, str(uuid4())) + + process_data = {"normal": "data"} + db_model = WorkflowNodeExecutionModel() + db_model.id = str(uuid4()) + db_model.tenant_id = account.current_tenant_id + db_model.app_id = str(uuid4()) + db_model.workflow_id = str(uuid4()) + db_model.triggered_from = "workflow-run" + db_model.workflow_run_id = None + db_model.index = 1 + db_model.predecessor_node_id = None + db_model.node_execution_id = str(uuid4()) + db_model.node_id = "test-node-id" + db_model.node_type = "llm" + db_model.title = "Test Node" + db_model.inputs = None + db_model.process_data = json.dumps(process_data) + db_model.outputs = None + db_model.status = "succeeded" + db_model.error = None + db_model.elapsed_time = 1.5 + db_model.execution_metadata = "{}" + db_model.created_at = datetime.now() + db_model.created_by_role = "account" + db_model.created_by = account.id + db_model.finished_at = None + + domain_model = repo._to_domain_model(db_model) + + assert domain_model.process_data == process_data + assert domain_model.process_data_truncated is False + assert domain_model.get_truncated_process_data() is None diff --git a/api/tests/test_containers_integration_tests/repositories/test_workflow_run_repository.py b/api/tests/test_containers_integration_tests/repositories/test_workflow_run_repository.py index c5e9201ee37..d6f0657380b 100644 --- a/api/tests/test_containers_integration_tests/repositories/test_workflow_run_repository.py +++ b/api/tests/test_containers_integration_tests/repositories/test_workflow_run_repository.py @@ -7,12 +7,12 @@ from datetime import timedelta from uuid import uuid4 import pytest -from graphon.entities import WorkflowExecution -from graphon.enums import WorkflowExecutionStatus from sqlalchemy import Engine, delete from sqlalchemy import exc as sa_exc from sqlalchemy.orm import Session, sessionmaker +from graphon.entities import WorkflowExecution +from graphon.enums import WorkflowExecutionStatus from libs.datetime_utils import naive_utc_now from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom from models.workflow import WorkflowRun, WorkflowType diff --git a/api/tests/integration_tests/vdb/matrixone/__init__.py b/api/tests/test_containers_integration_tests/services/rag_pipeline/__init__.py similarity index 100% rename from api/tests/integration_tests/vdb/matrixone/__init__.py rename to api/tests/test_containers_integration_tests/services/rag_pipeline/__init__.py diff --git a/api/tests/test_containers_integration_tests/services/rag_pipeline/test_rag_pipeline_service_db.py b/api/tests/test_containers_integration_tests/services/rag_pipeline/test_rag_pipeline_service_db.py new file mode 100644 index 00000000000..8fc1809a467 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/rag_pipeline/test_rag_pipeline_service_db.py @@ -0,0 +1,255 @@ +""" +Integration tests for RagPipelineService methods that interact with the database. + +Migrated from unit_tests/services/rag_pipeline/test_rag_pipeline_service.py, replacing +db.session.scalar/commit/delete mocker patches with real PostgreSQL operations. + +Covers: +- get_pipeline: Dataset and Pipeline lookups +- update_customized_pipeline_template: find + unique-name check + commit +- delete_customized_pipeline_template: find + delete + commit +""" + +from collections.abc import Generator +from types import SimpleNamespace +from unittest.mock import patch +from uuid import uuid4 + +import pytest +from sqlalchemy.orm import Session, sessionmaker + +from models.dataset import Dataset, Pipeline, PipelineCustomizedTemplate +from models.enums import DataSourceType +from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, PipelineTemplateInfoEntity +from services.rag_pipeline.rag_pipeline import RagPipelineService + + +class TestRagPipelineServiceGetPipeline: + """Integration tests for RagPipelineService.get_pipeline.""" + + @pytest.fixture(autouse=True) + def _auto_rollback(self, db_session_with_containers: Session) -> Generator[None, None, None]: + yield + db_session_with_containers.rollback() + + def _make_service(self, flask_app_with_containers) -> RagPipelineService: + with ( + patch( + "services.rag_pipeline.rag_pipeline.DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository", + return_value=None, + ), + patch( + "services.rag_pipeline.rag_pipeline.DifyAPIRepositoryFactory.create_api_workflow_run_repository", + return_value=None, + ), + ): + session_factory = sessionmaker(bind=flask_app_with_containers.extensions["sqlalchemy"].engine) + return RagPipelineService(session_maker=session_factory) + + def _create_pipeline(self, db_session: Session, tenant_id: str, created_by: str) -> Pipeline: + pipeline = Pipeline( + tenant_id=tenant_id, + name=f"Pipeline {uuid4()}", + description="", + created_by=created_by, + ) + db_session.add(pipeline) + db_session.flush() + return pipeline + + def _create_dataset( + self, db_session: Session, tenant_id: str, created_by: str, pipeline_id: str | None = None + ) -> Dataset: + dataset = Dataset( + tenant_id=tenant_id, + name=f"Dataset {uuid4()}", + data_source_type=DataSourceType.UPLOAD_FILE, + created_by=created_by, + pipeline_id=pipeline_id, + ) + db_session.add(dataset) + db_session.flush() + return dataset + + def test_get_pipeline_raises_when_dataset_not_found( + self, db_session_with_containers: Session, flask_app_with_containers + ) -> None: + """get_pipeline raises ValueError when dataset does not exist.""" + service = self._make_service(flask_app_with_containers) + + with pytest.raises(ValueError, match="Dataset not found"): + service.get_pipeline(tenant_id=str(uuid4()), dataset_id=str(uuid4())) + + def test_get_pipeline_raises_when_pipeline_not_found( + self, db_session_with_containers: Session, flask_app_with_containers + ) -> None: + """get_pipeline raises ValueError when dataset exists but has no linked pipeline.""" + tenant_id = str(uuid4()) + created_by = str(uuid4()) + dataset = self._create_dataset(db_session_with_containers, tenant_id, created_by, pipeline_id=None) + db_session_with_containers.flush() + + service = self._make_service(flask_app_with_containers) + + with pytest.raises(ValueError, match="(Dataset not found|Pipeline not found)"): + service.get_pipeline(tenant_id=tenant_id, dataset_id=dataset.id) + + def test_get_pipeline_returns_pipeline_when_found( + self, db_session_with_containers: Session, flask_app_with_containers + ) -> None: + """get_pipeline returns the Pipeline when both Dataset and Pipeline exist.""" + tenant_id = str(uuid4()) + created_by = str(uuid4()) + + pipeline = self._create_pipeline(db_session_with_containers, tenant_id, created_by) + dataset = self._create_dataset(db_session_with_containers, tenant_id, created_by, pipeline_id=pipeline.id) + db_session_with_containers.flush() + + service = self._make_service(flask_app_with_containers) + + result = service.get_pipeline(tenant_id=tenant_id, dataset_id=dataset.id) + + assert result.id == pipeline.id + + +class TestUpdateCustomizedPipelineTemplate: + """Integration tests for RagPipelineService.update_customized_pipeline_template.""" + + @pytest.fixture(autouse=True) + def _auto_rollback(self, db_session_with_containers: Session) -> Generator[None, None, None]: + yield + db_session_with_containers.rollback() + + def _create_template( + self, db_session: Session, tenant_id: str, created_by: str, name: str = "Template" + ) -> PipelineCustomizedTemplate: + template = PipelineCustomizedTemplate( + tenant_id=tenant_id, + name=name, + description="Original description", + chunk_structure="fixed_size", + icon={"type": "emoji", "value": "📄"}, + position=1, + yaml_content="{}", + install_count=0, + language="en-US", + created_by=created_by, + ) + db_session.add(template) + db_session.flush() + return template + + def test_update_template_succeeds(self, db_session_with_containers: Session, flask_app_with_containers) -> None: + """update_customized_pipeline_template updates name and description.""" + tenant_id = str(uuid4()) + created_by = str(uuid4()) + template = self._create_template(db_session_with_containers, tenant_id, created_by) + db_session_with_containers.flush() + + fake_user = SimpleNamespace(id=created_by, current_tenant_id=tenant_id) + + with patch("services.rag_pipeline.rag_pipeline.current_user", fake_user): + info = PipelineTemplateInfoEntity( + name="Updated Name", + description="Updated description", + icon_info=IconInfo(icon="🔥"), + ) + result = RagPipelineService.update_customized_pipeline_template(template.id, info) + + assert result.name == "Updated Name" + assert result.description == "Updated description" + + def test_update_template_raises_when_not_found( + self, db_session_with_containers: Session, flask_app_with_containers + ) -> None: + """update_customized_pipeline_template raises ValueError when template doesn't exist.""" + fake_user = SimpleNamespace(id=str(uuid4()), current_tenant_id=str(uuid4())) + + with patch("services.rag_pipeline.rag_pipeline.current_user", fake_user): + info = PipelineTemplateInfoEntity( + name="New Name", + description="desc", + icon_info=IconInfo(icon="📄"), + ) + with pytest.raises(ValueError, match="Customized pipeline template not found"): + RagPipelineService.update_customized_pipeline_template(str(uuid4()), info) + + def test_update_template_raises_on_duplicate_name( + self, db_session_with_containers: Session, flask_app_with_containers + ) -> None: + """update_customized_pipeline_template raises ValueError when new name already exists.""" + tenant_id = str(uuid4()) + created_by = str(uuid4()) + template1 = self._create_template(db_session_with_containers, tenant_id, created_by, name="Original") + self._create_template(db_session_with_containers, tenant_id, created_by, name="Duplicate") + db_session_with_containers.flush() + + fake_user = SimpleNamespace(id=created_by, current_tenant_id=tenant_id) + + with patch("services.rag_pipeline.rag_pipeline.current_user", fake_user): + info = PipelineTemplateInfoEntity( + name="Duplicate", + description="desc", + icon_info=IconInfo(icon="📄"), + ) + with pytest.raises(ValueError, match="Template name is already exists"): + RagPipelineService.update_customized_pipeline_template(template1.id, info) + + +class TestDeleteCustomizedPipelineTemplate: + """Integration tests for RagPipelineService.delete_customized_pipeline_template.""" + + @pytest.fixture(autouse=True) + def _auto_rollback(self, db_session_with_containers: Session) -> Generator[None, None, None]: + yield + db_session_with_containers.rollback() + + def _create_template(self, db_session: Session, tenant_id: str, created_by: str) -> PipelineCustomizedTemplate: + template = PipelineCustomizedTemplate( + tenant_id=tenant_id, + name=f"Template {uuid4()}", + description="Description", + chunk_structure="fixed_size", + icon={"type": "emoji", "value": "📄"}, + position=1, + yaml_content="{}", + install_count=0, + language="en-US", + created_by=created_by, + ) + db_session.add(template) + db_session.flush() + return template + + def test_delete_template_succeeds(self, db_session_with_containers: Session, flask_app_with_containers) -> None: + """delete_customized_pipeline_template removes the template from the DB.""" + tenant_id = str(uuid4()) + created_by = str(uuid4()) + template = self._create_template(db_session_with_containers, tenant_id, created_by) + template_id = template.id + db_session_with_containers.flush() + + fake_user = SimpleNamespace(id=created_by, current_tenant_id=tenant_id) + + with patch("services.rag_pipeline.rag_pipeline.current_user", fake_user): + RagPipelineService.delete_customized_pipeline_template(template_id) + + # Verify the record is deleted within the same context + from sqlalchemy import select + + from extensions.ext_database import db as ext_db + + remaining = ext_db.session.scalar( + select(PipelineCustomizedTemplate).where(PipelineCustomizedTemplate.id == template_id) + ) + assert remaining is None + + def test_delete_template_raises_when_not_found( + self, db_session_with_containers: Session, flask_app_with_containers + ) -> None: + """delete_customized_pipeline_template raises ValueError when template doesn't exist.""" + fake_user = SimpleNamespace(id=str(uuid4()), current_tenant_id=str(uuid4())) + + with patch("services.rag_pipeline.rag_pipeline.current_user", fake_user): + with pytest.raises(ValueError, match="Customized pipeline template not found"): + RagPipelineService.delete_customized_pipeline_template(str(uuid4())) diff --git a/api/tests/test_containers_integration_tests/services/test_account_service.py b/api/tests/test_containers_integration_tests/services/test_account_service.py index cc9596d15fb..9a53ff087c5 100644 --- a/api/tests/test_containers_integration_tests/services/test_account_service.py +++ b/api/tests/test_containers_integration_tests/services/test_account_service.py @@ -9,7 +9,7 @@ from werkzeug.exceptions import Unauthorized from configs import dify_config from controllers.console.error import AccountNotFound, NotAllowedCreateWorkspace -from models import AccountStatus, TenantAccountJoin +from models import AccountStatus, TenantAccountJoin, TenantStatus from services.account_service import AccountService, RegisterService, TenantService, TokenPair from services.errors.account import ( AccountAlreadyInTenantError, @@ -2851,7 +2851,7 @@ class TestRegisterService: interface_language="en-US", password=existing_pending_member_password, ) - existing_account.status = "pending" + existing_account.status = AccountStatus.PENDING db_session_with_containers.commit() @@ -2941,7 +2941,7 @@ class TestRegisterService: interface_language="en-US", password=already_in_tenant_password, ) - existing_account.status = "active" + existing_account.status = AccountStatus.ACTIVE db_session_with_containers.commit() @@ -3331,7 +3331,7 @@ class TestRegisterService: TenantService.create_tenant_member(tenant, account, role="normal") # Change tenant status to non-normal - tenant.status = "archive" + tenant.status = TenantStatus.ARCHIVE db_session_with_containers.commit() diff --git a/api/tests/test_containers_integration_tests/services/test_agent_service.py b/api/tests/test_containers_integration_tests/services/test_agent_service.py index 4f3c0e42000..00a2f9a59f4 100644 --- a/api/tests/test_containers_integration_tests/services/test_agent_service.py +++ b/api/tests/test_containers_integration_tests/services/test_agent_service.py @@ -842,7 +842,6 @@ class TestAgentService: conversation, message = self._create_test_conversation_and_message(db_session_with_containers, app, account) from graphon.file import FileTransferMethod, FileType - from models.enums import CreatorUserRole # Add files to message diff --git a/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py b/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py index 33955d5d84b..77ce28b999f 100644 --- a/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py +++ b/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py @@ -1,56 +1,104 @@ +from __future__ import annotations + +import base64 import json +from types import SimpleNamespace from unittest.mock import MagicMock, patch +from uuid import uuid4 import pytest import yaml from faker import Faker -from models.model import App, AppModelConfig +from core.trigger.constants import ( + TRIGGER_PLUGIN_NODE_TYPE, + TRIGGER_SCHEDULE_NODE_TYPE, + TRIGGER_WEBHOOK_NODE_TYPE, +) +from extensions.ext_redis import redis_client +from graphon.enums import BuiltinNodeTypes +from models import Account, AppMode +from models.model import AppModelConfig, IconType +from services import app_dsl_service from services.account_service import AccountService, TenantService -from services.app_dsl_service import AppDslService, ImportMode, ImportStatus +from services.app_dsl_service import ( + CHECK_DEPENDENCIES_REDIS_KEY_PREFIX, + CURRENT_DSL_VERSION, + DSL_MAX_SIZE, + IMPORT_INFO_REDIS_EXPIRY, + IMPORT_INFO_REDIS_KEY_PREFIX, + AppDslService, + CheckDependenciesPendingData, + ImportMode, + ImportStatus, + PendingData, + _check_version_compatibility, +) from services.app_service import AppService from tests.test_containers_integration_tests.helpers import generate_valid_password +_DEFAULT_TENANT_ID = "00000000-0000-0000-0000-000000000001" +_DEFAULT_ACCOUNT_ID = "00000000-0000-0000-0000-000000000002" + + +def _account_mock(*, tenant_id: str = _DEFAULT_TENANT_ID, account_id: str = _DEFAULT_ACCOUNT_ID) -> MagicMock: + account = MagicMock(spec=Account) + account.current_tenant_id = tenant_id + account.id = account_id + return account + + +def _yaml_dump(data: dict) -> str: + return yaml.safe_dump(data, allow_unicode=True) + + +def _workflow_yaml(*, version: str = CURRENT_DSL_VERSION) -> str: + return _yaml_dump( + { + "version": version, + "kind": "app", + "app": {"name": "My App", "mode": AppMode.WORKFLOW.value}, + "workflow": {"graph": {"nodes": []}, "features": {}}, + } + ) + + +def _pending_yaml_content(version: str = "99.0.0") -> bytes: + return (f'version: "{version}"\nkind: app\napp:\n name: Loop Test\n mode: workflow\n').encode() + class TestAppDslService: """Integration tests for AppDslService using testcontainers.""" + @pytest.fixture + def app(self, flask_app_with_containers): + return flask_app_with_containers + @pytest.fixture def mock_external_service_dependencies(self): """Mock setup for external service dependencies.""" with ( patch("services.app_dsl_service.WorkflowService") as mock_workflow_service, patch("services.app_dsl_service.DependenciesAnalysisService") as mock_dependencies_service, - patch("services.app_dsl_service.WorkflowDraftVariableService") as mock_draft_variable_service, - patch("services.app_dsl_service.ssrf_proxy") as mock_ssrf_proxy, - patch("services.app_dsl_service.redis_client") as mock_redis_client, patch("services.app_dsl_service.app_was_created") as mock_app_was_created, - patch("services.app_dsl_service.app_model_config_was_updated") as mock_app_model_config_was_updated, patch("services.app_service.ModelManager.for_tenant") as mock_model_manager, patch("services.app_service.FeatureService") as mock_feature_service, patch("services.app_service.EnterpriseService") as mock_enterprise_service, ): - # Setup default mock returns mock_workflow_service.return_value.get_draft_workflow.return_value = None mock_workflow_service.return_value.sync_draft_workflow.return_value = MagicMock() mock_dependencies_service.generate_latest_dependencies.return_value = [] mock_dependencies_service.get_leaked_dependencies.return_value = [] mock_dependencies_service.generate_dependencies.return_value = [] - mock_draft_variable_service.return_value.delete_workflow_variables.return_value = None - mock_ssrf_proxy.get.return_value.content = b"test content" - mock_ssrf_proxy.get.return_value.raise_for_status.return_value = None - mock_redis_client.setex.return_value = None - mock_redis_client.get.return_value = None - mock_redis_client.delete.return_value = None mock_app_was_created.send.return_value = None - mock_app_model_config_was_updated.send.return_value = None - # Mock ModelManager for app service mock_model_instance = mock_model_manager.return_value mock_model_instance.get_default_model_instance.return_value = None - mock_model_instance.get_default_provider_model_name.return_value = ("openai", "gpt-3.5-turbo") + mock_model_instance.get_default_provider_model_name.return_value = ( + "openai", + "gpt-3.5-turbo", + ) - # Mock FeatureService and EnterpriseService mock_feature_service.get_system_features.return_value.webapp_auth.enabled = False mock_enterprise_service.WebAppAuth.update_app_access_mode.return_value = None mock_enterprise_service.WebAppAuth.cleanup_webapp.return_value = None @@ -58,34 +106,16 @@ class TestAppDslService: yield { "workflow_service": mock_workflow_service, "dependencies_service": mock_dependencies_service, - "draft_variable_service": mock_draft_variable_service, - "ssrf_proxy": mock_ssrf_proxy, - "redis_client": mock_redis_client, "app_was_created": mock_app_was_created, - "app_model_config_was_updated": mock_app_model_config_was_updated, "model_manager": mock_model_manager, "feature_service": mock_feature_service, "enterprise_service": mock_enterprise_service, } def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies): - """ - Helper method to create a test app and account for testing. - - Args: - db_session_with_containers: Database session from testcontainers infrastructure - mock_external_service_dependencies: Mock dependencies - - Returns: - tuple: (app, account) - Created app and account instances - """ fake = Faker() - - # Setup mocks for account creation with patch("services.account_service.FeatureService") as mock_account_feature_service: mock_account_feature_service.get_system_features.return_value.is_allow_register = True - - # Create account and tenant first account = AccountService.create_account( email=fake.email(), name=fake.name(), @@ -94,8 +124,6 @@ class TestAppDslService: ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant - - # Setup app creation arguments app_args = { "name": fake.company(), "description": fake.text(max_nb_chars=100), @@ -106,17 +134,11 @@ class TestAppDslService: "api_rph": 100, "api_rpm": 10, } - - # Create app app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) - return app, account - def _create_simple_yaml_content(self, app_name="Test App", app_mode="chat"): - """ - Helper method to create simple YAML content for testing. - """ + def _create_simple_yaml_content(self, app_name: str = "Test App", app_mode: str = "chat") -> str: yaml_data = { "version": "0.3.0", "kind": "app", @@ -145,88 +167,739 @@ class TestAppDslService: } return yaml.dump(yaml_data, allow_unicode=True) - def test_import_app_missing_yaml_content(self, db_session_with_containers, mock_external_service_dependencies): - """ - Test app import with missing YAML content. - """ - fake = Faker() - app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + # ── Version Compatibility ───────────────────────────────────────── - # Import app without YAML content - dsl_service = AppDslService(db_session_with_containers) - result = dsl_service.import_app( - account=account, - import_mode=ImportMode.YAML_CONTENT, - name="Missing Content App", - ) + def test_check_version_compatibility_invalid_version_returns_failed(self): + assert _check_version_compatibility("not-a-version") == ImportStatus.FAILED - # Verify import failed - assert result.status == ImportStatus.FAILED - assert result.app_id is None - assert "yaml_content is required" in result.error - assert result.imported_dsl_version == "" + def test_check_version_compatibility_newer_version_returns_pending(self): + assert _check_version_compatibility("99.0.0") == ImportStatus.PENDING - # Verify no app was created in database - apps_count = db_session_with_containers.query(App).where(App.tenant_id == account.current_tenant_id).count() - assert apps_count == 1 # Only the original test app + def test_check_version_compatibility_major_older_returns_pending(self, monkeypatch): + monkeypatch.setattr(app_dsl_service, "CURRENT_DSL_VERSION", "1.0.0") + assert _check_version_compatibility("0.9.9") == ImportStatus.PENDING - def test_import_app_missing_yaml_url(self, db_session_with_containers, mock_external_service_dependencies): - """ - Test app import with missing YAML URL. - """ - fake = Faker() - app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + def test_check_version_compatibility_minor_older_returns_completed_with_warnings( + self, + ): + assert _check_version_compatibility("0.5.0") == ImportStatus.COMPLETED_WITH_WARNINGS - # Import app without YAML URL - dsl_service = AppDslService(db_session_with_containers) - result = dsl_service.import_app( - account=account, - import_mode=ImportMode.YAML_URL, - name="Missing URL App", - ) + def test_check_version_compatibility_equal_returns_completed(self): + assert _check_version_compatibility(CURRENT_DSL_VERSION) == ImportStatus.COMPLETED - # Verify import failed - assert result.status == ImportStatus.FAILED - assert result.app_id is None - assert "yaml_url is required" in result.error - assert result.imported_dsl_version == "" + # ── Import: Validation ──────────────────────────────────────────── - # Verify no app was created in database - apps_count = db_session_with_containers.query(App).where(App.tenant_id == account.current_tenant_id).count() - assert apps_count == 1 # Only the original test app - - def test_import_app_invalid_import_mode(self, db_session_with_containers, mock_external_service_dependencies): - """ - Test app import with invalid import mode. - """ - fake = Faker() - app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) - - # Create YAML content - yaml_content = self._create_simple_yaml_content(fake.company(), "chat") - - # Import app with invalid mode should raise ValueError - dsl_service = AppDslService(db_session_with_containers) - with pytest.raises(ValueError, match="Invalid import_mode: invalid-mode"): - dsl_service.import_app( - account=account, + def test_import_app_invalid_import_mode_raises_value_error(self, db_session_with_containers): + service = AppDslService(db_session_with_containers) + with pytest.raises(ValueError, match="Invalid import_mode"): + service.import_app( + account=_account_mock(), import_mode="invalid-mode", - yaml_content=yaml_content, - name="Invalid Mode App", + yaml_content="version: '0.1.0'", ) - # Verify no app was created in database - apps_count = db_session_with_containers.query(App).where(App.tenant_id == account.current_tenant_id).count() - assert apps_count == 1 # Only the original test app + def test_import_app_missing_yaml_content(self, db_session_with_containers): + service = AppDslService(db_session_with_containers) + result = service.import_app( + account=_account_mock(), + import_mode=ImportMode.YAML_CONTENT, + yaml_content=None, + ) + assert result.status == ImportStatus.FAILED + assert "yaml_content is required" in result.error - def test_export_dsl_chat_app_success(self, db_session_with_containers, mock_external_service_dependencies): - """ - Test successful DSL export for chat app. - """ - fake = Faker() + def test_import_app_missing_yaml_url(self, db_session_with_containers): + service = AppDslService(db_session_with_containers) + result = service.import_app( + account=_account_mock(), + import_mode=ImportMode.YAML_URL, + yaml_url=None, + ) + assert result.status == ImportStatus.FAILED + assert "yaml_url is required" in result.error + + def test_import_app_yaml_not_mapping_returns_failed(self, db_session_with_containers): + service = AppDslService(db_session_with_containers) + result = service.import_app( + account=_account_mock(), + import_mode=ImportMode.YAML_CONTENT, + yaml_content="[]", + ) + assert result.status == ImportStatus.FAILED + assert "content must be a mapping" in result.error + + def test_import_app_version_not_str_returns_failed(self, db_session_with_containers): + service = AppDslService(db_session_with_containers) + yaml_content = _yaml_dump({"version": 1, "kind": "app", "app": {"name": "x", "mode": "workflow"}}) + result = service.import_app( + account=_account_mock(), + import_mode=ImportMode.YAML_CONTENT, + yaml_content=yaml_content, + ) + assert result.status == ImportStatus.FAILED + assert "Invalid version type" in result.error + + def test_import_app_missing_app_data_returns_failed(self, db_session_with_containers): + service = AppDslService(db_session_with_containers) + result = service.import_app( + account=_account_mock(), + import_mode=ImportMode.YAML_CONTENT, + yaml_content=_yaml_dump({"version": "0.6.0", "kind": "app"}), + ) + assert result.status == ImportStatus.FAILED + assert "Missing app data" in result.error + + def test_import_app_yaml_error_returns_failed(self, db_session_with_containers, monkeypatch): + def bad_safe_load(_content: str): + raise yaml.YAMLError("bad") + + monkeypatch.setattr(app_dsl_service.yaml, "safe_load", bad_safe_load) + + service = AppDslService(db_session_with_containers) + result = service.import_app( + account=_account_mock(), + import_mode=ImportMode.YAML_CONTENT, + yaml_content="x: y", + ) + assert result.status == ImportStatus.FAILED + assert result.error.startswith("Invalid YAML format:") + + def test_import_app_unexpected_error_returns_failed(self, db_session_with_containers, monkeypatch): + monkeypatch.setattr( + AppDslService, + "_create_or_update_app", + lambda *_args, **_kwargs: (_ for _ in ()).throw(ValueError("oops")), + ) + + service = AppDslService(db_session_with_containers) + result = service.import_app( + account=_account_mock(), + import_mode=ImportMode.YAML_CONTENT, + yaml_content=_workflow_yaml(), + ) + assert result.status == ImportStatus.FAILED + assert result.error == "oops" + + # ── Import: YAML URL ────────────────────────────────────────────── + + def test_import_app_yaml_url_fetch_error_returns_failed(self, db_session_with_containers, monkeypatch): + monkeypatch.setattr( + app_dsl_service.ssrf_proxy, + "get", + lambda _url, **_kw: (_ for _ in ()).throw(RuntimeError("boom")), + ) + + service = AppDslService(db_session_with_containers) + result = service.import_app( + account=_account_mock(), + import_mode=ImportMode.YAML_URL, + yaml_url="https://example.com/a.yml", + ) + assert result.status == ImportStatus.FAILED + assert "Error fetching YAML from URL: boom" in result.error + + def test_import_app_yaml_url_empty_content_returns_failed(self, db_session_with_containers, monkeypatch): + response = MagicMock() + response.content = b"" + response.raise_for_status.return_value = None + monkeypatch.setattr(app_dsl_service.ssrf_proxy, "get", lambda _url, **_kw: response) + + service = AppDslService(db_session_with_containers) + result = service.import_app( + account=_account_mock(), + import_mode=ImportMode.YAML_URL, + yaml_url="https://example.com/a.yml", + ) + assert result.status == ImportStatus.FAILED + assert "Empty content" in result.error + + def test_import_app_yaml_url_file_too_large_returns_failed(self, db_session_with_containers, monkeypatch): + response = MagicMock() + response.content = b"x" * (DSL_MAX_SIZE + 1) + response.raise_for_status.return_value = None + monkeypatch.setattr(app_dsl_service.ssrf_proxy, "get", lambda _url, **_kw: response) + + service = AppDslService(db_session_with_containers) + result = service.import_app( + account=_account_mock(), + import_mode=ImportMode.YAML_URL, + yaml_url="https://example.com/a.yml", + ) + assert result.status == ImportStatus.FAILED + assert "File size exceeds" in result.error + + def test_import_app_yaml_url_user_attachments_keeps_original_url(self, db_session_with_containers, monkeypatch): + yaml_url = "https://github.com/user-attachments/files/24290802/loop-test.yml" + yaml_bytes = _pending_yaml_content() + + requested_urls: list[str] = [] + + def fake_get(url: str, **kwargs): + requested_urls.append(url) + response = MagicMock() + response.content = yaml_bytes + response.raise_for_status.return_value = None + return response + + monkeypatch.setattr(app_dsl_service.ssrf_proxy, "get", fake_get) + + service = AppDslService(db_session_with_containers) + result = service.import_app( + account=_account_mock(), + import_mode=ImportMode.YAML_URL, + yaml_url=yaml_url, + ) + + assert result.status == ImportStatus.PENDING + assert result.imported_dsl_version == "99.0.0" + assert requested_urls == [yaml_url] + + def test_import_app_yaml_url_github_blob_rewrites_to_raw(self, db_session_with_containers, monkeypatch): + yaml_url = "https://github.com/acme/repo/blob/main/app.yml" + raw_url = "https://raw.githubusercontent.com/acme/repo/main/app.yml" + yaml_bytes = _pending_yaml_content() + + requested_urls: list[str] = [] + + def fake_get(url: str, **kwargs): + requested_urls.append(url) + assert url == raw_url + response = MagicMock() + response.content = yaml_bytes + response.raise_for_status.return_value = None + return response + + monkeypatch.setattr(app_dsl_service.ssrf_proxy, "get", fake_get) + + service = AppDslService(db_session_with_containers) + result = service.import_app( + account=_account_mock(), + import_mode=ImportMode.YAML_URL, + yaml_url=yaml_url, + ) + + assert result.status == ImportStatus.PENDING + assert requested_urls == [raw_url] + + # ── Import: App ID checks ──────────────────────────────────────── + + def test_import_app_app_id_not_found_returns_failed(self, db_session_with_containers): + service = AppDslService(db_session_with_containers) + result = service.import_app( + account=_account_mock(), + import_mode=ImportMode.YAML_CONTENT, + yaml_content=_workflow_yaml(), + app_id=str(uuid4()), + ) + assert result.status == ImportStatus.FAILED + assert result.error == "App not found" + + def test_import_app_overwrite_only_allows_workflow_and_advanced_chat( + self, db_session_with_containers, mock_external_service_dependencies + ): + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + assert app.mode == "chat" + + service = AppDslService(db_session_with_containers) + result = service.import_app( + account=account, + import_mode=ImportMode.YAML_CONTENT, + yaml_content=_workflow_yaml(), + app_id=app.id, + ) + assert result.status == ImportStatus.FAILED + assert "Only workflow or advanced chat apps" in result.error + + # ── Import: Flow ────────────────────────────────────────────────── + + def test_import_app_pending_stores_import_info_in_redis(self, db_session_with_containers): + service = AppDslService(db_session_with_containers) + result = service.import_app( + account=_account_mock(), + import_mode=ImportMode.YAML_CONTENT, + yaml_content=_workflow_yaml(version="99.0.0"), + name="n", + description="d", + icon_type="emoji", + icon="i", + icon_background="#000000", + ) + assert result.status == ImportStatus.PENDING + assert result.imported_dsl_version == "99.0.0" + + redis_key = f"{IMPORT_INFO_REDIS_KEY_PREFIX}{result.id}" + stored = redis_client.get(redis_key) + assert stored is not None + + def test_import_app_completed_uses_declared_dependencies( + self, db_session_with_containers, mock_external_service_dependencies + ): + _, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + dependencies_payload = [ + { + "type": "package", + "value": { + "plugin_unique_identifier": "langgenius/google", + "version": "1.0.0", + }, + } + ] + + service = AppDslService(db_session_with_containers) + result = service.import_app( + account=account, + import_mode=ImportMode.YAML_CONTENT, + yaml_content=_yaml_dump( + { + "version": CURRENT_DSL_VERSION, + "kind": "app", + "app": {"name": "My App", "mode": AppMode.WORKFLOW.value}, + "workflow": {"graph": {"nodes": []}, "features": {}}, + "dependencies": dependencies_payload, + } + ), + ) + + assert result.status == ImportStatus.COMPLETED + assert result.app_id is not None + + @pytest.mark.parametrize("has_workflow", [True, False]) + def test_import_app_legacy_versions_extract_dependencies( + self, db_session_with_containers, monkeypatch, has_workflow: bool + ): + monkeypatch.setattr( + AppDslService, + "_extract_dependencies_from_workflow_graph", + lambda *_args, **_kwargs: ["from-workflow"], + ) + monkeypatch.setattr( + AppDslService, + "_extract_dependencies_from_model_config", + lambda *_args, **_kwargs: ["from-model-config"], + ) + monkeypatch.setattr( + app_dsl_service.DependenciesAnalysisService, + "generate_latest_dependencies", + lambda deps: [SimpleNamespace(model_dump=lambda: {"dep": deps[0]})], + ) + + created_app = SimpleNamespace( + id=str(uuid4()), + mode=AppMode.WORKFLOW.value, + tenant_id=_DEFAULT_TENANT_ID, + ) + monkeypatch.setattr( + AppDslService, + "_create_or_update_app", + lambda *_args, **_kwargs: created_app, + ) + + draft_var_service = MagicMock() + monkeypatch.setattr( + app_dsl_service, + "WorkflowDraftVariableService", + lambda *args, **kwargs: draft_var_service, + ) + + data: dict = { + "version": "0.1.5", + "kind": "app", + "app": {"name": "Legacy", "mode": AppMode.WORKFLOW.value}, + } + if has_workflow: + data["workflow"] = {"graph": {"nodes": []}, "features": {}} + else: + data["model_config"] = {"model": {"provider": "openai"}} + + service = AppDslService(db_session_with_containers) + result = service.import_app( + account=_account_mock(), + import_mode=ImportMode.YAML_CONTENT, + yaml_content=_yaml_dump(data), + ) + assert result.status == ImportStatus.COMPLETED_WITH_WARNINGS + draft_var_service.delete_app_workflow_variables.assert_called_once_with(app_id=created_app.id) + + # ── Confirm Import ──────────────────────────────────────────────── + + def test_confirm_import_expired_returns_failed(self, db_session_with_containers): + service = AppDslService(db_session_with_containers) + result = service.confirm_import(import_id=str(uuid4()), account=_account_mock()) + assert result.status == ImportStatus.FAILED + assert "expired" in result.error + + def test_confirm_import_success_deletes_redis_key(self, db_session_with_containers, monkeypatch): + import_id = str(uuid4()) + redis_key = f"{IMPORT_INFO_REDIS_KEY_PREFIX}{import_id}" + + pending = PendingData( + import_mode=ImportMode.YAML_CONTENT, + yaml_content=_workflow_yaml(), + name="name", + description="desc", + icon_type="emoji", + icon="🤖", + icon_background="#fff", + app_id=None, + ) + redis_client.setex(redis_key, IMPORT_INFO_REDIS_EXPIRY, pending.model_dump_json()) + + created_app = SimpleNamespace( + id=str(uuid4()), + mode=AppMode.WORKFLOW.value, + tenant_id=_DEFAULT_TENANT_ID, + ) + monkeypatch.setattr( + AppDslService, + "_create_or_update_app", + lambda *_args, **_kwargs: created_app, + ) + + service = AppDslService(db_session_with_containers) + result = service.confirm_import(import_id=import_id, account=_account_mock()) + assert result.status == ImportStatus.COMPLETED + assert result.app_id == created_app.id + assert redis_client.get(redis_key) is None + + def test_confirm_import_invalid_pending_data_type_returns_failed(self, db_session_with_containers): + import_id = str(uuid4()) + redis_key = f"{IMPORT_INFO_REDIS_KEY_PREFIX}{import_id}" + redis_client.setex(redis_key, IMPORT_INFO_REDIS_EXPIRY, "123") + + service = AppDslService(db_session_with_containers) + result = service.confirm_import(import_id=import_id, account=_account_mock()) + assert result.status == ImportStatus.FAILED + assert "validation error" in result.error + + def test_confirm_import_exception_returns_failed(self, db_session_with_containers): + import_id = str(uuid4()) + redis_key = f"{IMPORT_INFO_REDIS_KEY_PREFIX}{import_id}" + redis_client.setex(redis_key, IMPORT_INFO_REDIS_EXPIRY, "not-valid-json") + + service = AppDslService(db_session_with_containers) + result = service.confirm_import(import_id=import_id, account=_account_mock()) + assert result.status == ImportStatus.FAILED + + # ── Check Dependencies ──────────────────────────────────────────── + + def test_check_dependencies_returns_empty_when_no_redis_data(self, db_session_with_containers): + service = AppDslService(db_session_with_containers) + app_model = SimpleNamespace(id=str(uuid4()), tenant_id=_DEFAULT_TENANT_ID) + result = service.check_dependencies(app_model=app_model) + assert result.leaked_dependencies == [] + + def test_check_dependencies_calls_analysis_service(self, db_session_with_containers, monkeypatch): + app_id = str(uuid4()) + pending = CheckDependenciesPendingData(dependencies=[], app_id=app_id) + redis_client.setex( + f"{CHECK_DEPENDENCIES_REDIS_KEY_PREFIX}{app_id}", + IMPORT_INFO_REDIS_EXPIRY, + pending.model_dump_json(), + ) + + dep = app_dsl_service.PluginDependency.model_validate( + { + "type": "package", + "value": { + "plugin_unique_identifier": "acme/foo", + "version": "1.0.0", + }, + } + ) + monkeypatch.setattr( + app_dsl_service.DependenciesAnalysisService, + "get_leaked_dependencies", + lambda *, tenant_id, dependencies: [dep], + ) + + service = AppDslService(db_session_with_containers) + result = service.check_dependencies(app_model=SimpleNamespace(id=app_id, tenant_id=_DEFAULT_TENANT_ID)) + assert len(result.leaked_dependencies) == 1 + + def test_check_dependencies_with_real_app(self, db_session_with_containers, mock_external_service_dependencies): + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + mock_dependencies_json = '{"app_id": "' + app.id + '", "dependencies": []}' + redis_client.setex( + f"{CHECK_DEPENDENCIES_REDIS_KEY_PREFIX}{app.id}", + IMPORT_INFO_REDIS_EXPIRY, + mock_dependencies_json, + ) + + dsl_service = AppDslService(db_session_with_containers) + result = dsl_service.check_dependencies(app_model=app) + assert result.leaked_dependencies == [] + + # ── Create/Update App ───────────────────────────────────────────── + + def test_create_or_update_app_missing_mode_raises(self, db_session_with_containers): + service = AppDslService(db_session_with_containers) + with pytest.raises(ValueError, match="loss app mode"): + service._create_or_update_app(app=None, data={"app": {}}, account=_account_mock()) + + def test_create_or_update_app_existing_app_updates_fields(self, db_session_with_containers, monkeypatch): + fixed_now = object() + monkeypatch.setattr(app_dsl_service, "naive_utc_now", lambda: fixed_now) + + workflow_service = MagicMock() + workflow_service.get_draft_workflow.return_value = None + monkeypatch.setattr(app_dsl_service, "WorkflowService", lambda: workflow_service) + monkeypatch.setattr( + app_dsl_service.variable_factory, + "build_environment_variable_from_mapping", + lambda _m: SimpleNamespace(kind="env"), + ) + monkeypatch.setattr( + app_dsl_service.variable_factory, + "build_conversation_variable_from_mapping", + lambda _m: SimpleNamespace(kind="conv"), + ) + + app = SimpleNamespace( + id=str(uuid4()), + tenant_id=_DEFAULT_TENANT_ID, + mode=AppMode.WORKFLOW.value, + name="old", + description="old-desc", + icon_type=IconType.EMOJI, + icon="old-icon", + icon_background="#111111", + updated_by=None, + updated_at=None, + app_model_config=None, + ) + service = AppDslService(db_session_with_containers) + updated = service._create_or_update_app( + app=app, + data={ + "app": { + "mode": AppMode.WORKFLOW.value, + "name": "yaml-name", + "icon_type": IconType.IMAGE, + "icon": "X", + }, + "workflow": {"graph": {"nodes": []}, "features": {}}, + }, + account=_account_mock(), + name="override-name", + description=None, + icon_background="#222222", + ) + assert updated is app + assert app.name == "override-name" + assert app.icon_type == IconType.IMAGE + assert app.icon == "X" + assert app.icon_background == "#222222" + assert app.updated_at is fixed_now + + def test_create_or_update_app_new_app_requires_tenant(self, db_session_with_containers): + account = _account_mock() + account.current_tenant_id = None + service = AppDslService(db_session_with_containers) + with pytest.raises(ValueError, match="Current tenant is not set"): + service._create_or_update_app( + app=None, + data={"app": {"mode": AppMode.WORKFLOW.value, "name": "n"}}, + account=account, + ) + + def test_create_or_update_app_creates_workflow_app_and_saves_dependencies( + self, db_session_with_containers, mock_external_service_dependencies + ): + _, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + mock_wf_svc = mock_external_service_dependencies["workflow_service"] + mock_wf_svc.return_value.get_draft_workflow.return_value = MagicMock(unique_hash="uh") + + service = AppDslService(db_session_with_containers) + deps = [ + app_dsl_service.PluginDependency.model_validate( + { + "type": "package", + "value": { + "plugin_unique_identifier": "acme/foo", + "version": "1.0.0", + }, + } + ) + ] + data = { + "app": {"mode": AppMode.WORKFLOW.value, "name": "n"}, + "workflow": { + "graph": {"nodes": []}, + "features": {}, + }, + } + + app = service._create_or_update_app(app=None, data=data, account=account, dependencies=deps) + + assert app.tenant_id == account.current_tenant_id + mock_external_service_dependencies["app_was_created"].send.assert_called_once() + mock_wf_svc.return_value.sync_draft_workflow.assert_called_once() + + stored = redis_client.get(f"{CHECK_DEPENDENCIES_REDIS_KEY_PREFIX}{app.id}") + assert stored is not None + + def test_create_or_update_app_workflow_missing_workflow_data_raises(self, db_session_with_containers): + service = AppDslService(db_session_with_containers) + with pytest.raises(ValueError, match="Missing workflow data"): + service._create_or_update_app( + app=SimpleNamespace( + id=str(uuid4()), + tenant_id=_DEFAULT_TENANT_ID, + mode=AppMode.WORKFLOW.value, + name="n", + description="d", + icon_background="#fff", + app_model_config=None, + ), + data={"app": {"mode": AppMode.WORKFLOW.value}}, + account=_account_mock(), + ) + + def test_create_or_update_app_chat_requires_model_config(self, db_session_with_containers): + service = AppDslService(db_session_with_containers) + with pytest.raises(ValueError, match="Missing model_config"): + service._create_or_update_app( + app=SimpleNamespace( + id=str(uuid4()), + tenant_id=_DEFAULT_TENANT_ID, + mode=AppMode.CHAT.value, + name="n", + description="d", + icon_background="#fff", + app_model_config=None, + ), + data={"app": {"mode": AppMode.CHAT.value}}, + account=_account_mock(), + ) + + def test_create_or_update_app_chat_creates_model_config_and_sends_event( + self, db_session_with_containers, mock_external_service_dependencies + ): + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + app.app_model_config_id = None + db_session_with_containers.commit() + + service = AppDslService(db_session_with_containers) + service._create_or_update_app( + app=app, + data={ + "app": {"mode": AppMode.CHAT.value}, + "model_config": {"model": {"provider": "openai"}}, + }, + account=account, + ) + + db_session_with_containers.expire_all() + assert app.app_model_config_id is not None + + def test_create_or_update_app_invalid_mode_raises(self, db_session_with_containers): + service = AppDslService(db_session_with_containers) + with pytest.raises(ValueError, match="Invalid app mode"): + service._create_or_update_app( + app=SimpleNamespace( + id=str(uuid4()), + tenant_id=_DEFAULT_TENANT_ID, + mode=AppMode.RAG_PIPELINE.value, + name="n", + description="d", + icon_background="#fff", + app_model_config=None, + ), + data={"app": {"mode": AppMode.RAG_PIPELINE.value}}, + account=_account_mock(), + ) + + # ── Export ───────────────────────────────────────────────────────── + + def test_export_dsl_delegates_by_mode(self, monkeypatch): + workflow_calls: list[bool] = [] + model_calls: list[bool] = [] + monkeypatch.setattr( + AppDslService, + "_append_workflow_export_data", + lambda **_kwargs: workflow_calls.append(True), + ) + monkeypatch.setattr( + AppDslService, + "_append_model_config_export_data", + lambda *_args, **_kwargs: model_calls.append(True), + ) + + workflow_app = SimpleNamespace( + mode=AppMode.WORKFLOW.value, + tenant_id=_DEFAULT_TENANT_ID, + name="n", + icon="i", + icon_type="emoji", + icon_background="#fff", + description="d", + use_icon_as_answer_icon=False, + app_model_config=None, + ) + AppDslService.export_dsl(workflow_app) + assert workflow_calls == [True] + + chat_app = SimpleNamespace( + mode=AppMode.CHAT.value, + tenant_id=_DEFAULT_TENANT_ID, + name="n", + icon="i", + icon_type="emoji", + icon_background="#fff", + description="d", + use_icon_as_answer_icon=False, + app_model_config=SimpleNamespace(to_dict=lambda: {"agent_mode": {"tools": []}}), + ) + AppDslService.export_dsl(chat_app) + assert model_calls == [True] + + def test_export_dsl_preserves_icon_and_icon_type(self, monkeypatch): + monkeypatch.setattr( + AppDslService, + "_append_workflow_export_data", + lambda **_kwargs: None, + ) + + emoji_app = SimpleNamespace( + mode=AppMode.WORKFLOW.value, + tenant_id=_DEFAULT_TENANT_ID, + name="Emoji App", + icon="🎨", + icon_type=IconType.EMOJI, + icon_background="#FF5733", + description="App with emoji icon", + use_icon_as_answer_icon=True, + app_model_config=None, + ) + yaml_output = AppDslService.export_dsl(emoji_app) + data = yaml.safe_load(yaml_output) + assert data["app"]["icon"] == "🎨" + assert data["app"]["icon_type"] == "emoji" + assert data["app"]["icon_background"] == "#FF5733" + + image_app = SimpleNamespace( + mode=AppMode.WORKFLOW.value, + tenant_id=_DEFAULT_TENANT_ID, + name="Image App", + icon="https://example.com/icon.png", + icon_type=IconType.IMAGE, + icon_background="#FFEAD5", + description="App with image icon", + use_icon_as_answer_icon=False, + app_model_config=None, + ) + yaml_output = AppDslService.export_dsl(image_app) + data = yaml.safe_load(yaml_output) + assert data["app"]["icon"] == "https://example.com/icon.png" + assert data["app"]["icon_type"] == "image" + assert data["app"]["icon_background"] == "#FFEAD5" + + def test_export_dsl_chat_app_success(self, db_session_with_containers, mock_external_service_dependencies): app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) - # Create model config for the app model_config = AppModelConfig( app_id=app.id, provider="openai", @@ -247,53 +920,38 @@ class TestAppDslService: created_by=account.id, updated_by=account.id, ) - model_config.id = fake.uuid4() - - # Set the app_model_config_id to link the config + model_config.id = str(uuid4()) app.app_model_config_id = model_config.id db_session_with_containers.add(model_config) db_session_with_containers.commit() - # Export DSL exported_dsl = AppDslService.export_dsl(app, include_secret=False) - - # Parse exported YAML exported_data = yaml.safe_load(exported_dsl) - # Verify exported data structure assert exported_data["kind"] == "app" assert exported_data["app"]["name"] == app.name assert exported_data["app"]["mode"] == app.mode - assert exported_data["app"]["icon"] == app.icon - assert exported_data["app"]["icon_background"] == app.icon_background - assert exported_data["app"]["description"] == app.description - - # Verify model config was exported assert "model_config" in exported_data - # The exported model_config structure may be different from the database structure - # Check that the model config exists and has the expected content - assert exported_data["model_config"] is not None - - # Verify dependencies were exported assert "dependencies" in exported_data - assert isinstance(exported_data["dependencies"], list) def test_export_dsl_workflow_app_success(self, db_session_with_containers, mock_external_service_dependencies): - """ - Test successful DSL export for workflow app. - """ - fake = Faker() app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) - - # Update app to workflow mode app.mode = "workflow" db_session_with_containers.commit() - # Mock workflow service to return a workflow mock_workflow = MagicMock() mock_workflow.to_dict.return_value = { - "graph": {"nodes": [{"id": "start", "type": "start", "data": {"type": "start"}}], "edges": []}, + "graph": { + "nodes": [ + { + "id": "start", + "type": "start", + "data": {"type": "start"}, + } + ], + "edges": [], + }, "features": {}, "environment_variables": [], "conversation_variables": [], @@ -302,54 +960,40 @@ class TestAppDslService: "workflow_service" ].return_value.get_draft_workflow.return_value = mock_workflow - # Export DSL exported_dsl = AppDslService.export_dsl(app, include_secret=False) - - # Parse exported YAML exported_data = yaml.safe_load(exported_dsl) - # Verify exported data structure assert exported_data["kind"] == "app" - assert exported_data["app"]["name"] == app.name assert exported_data["app"]["mode"] == "workflow" - - # Verify workflow was exported assert "workflow" in exported_data - assert "graph" in exported_data["workflow"] - assert "nodes" in exported_data["workflow"]["graph"] - - # Verify dependencies were exported assert "dependencies" in exported_data - assert isinstance(exported_data["dependencies"], list) - - # Verify workflow service was called - mock_external_service_dependencies["workflow_service"].return_value.get_draft_workflow.assert_called_once_with( - app, None - ) def test_export_dsl_with_workflow_id_success(self, db_session_with_containers, mock_external_service_dependencies): - """ - Test successful DSL export with specific workflow ID. - """ - fake = Faker() app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) - - # Update app to workflow mode app.mode = "workflow" db_session_with_containers.commit() - # Mock workflow service to return a workflow when specific workflow_id is provided mock_workflow = MagicMock() mock_workflow.to_dict.return_value = { - "graph": {"nodes": [{"id": "start", "type": "start", "data": {"type": "start"}}], "edges": []}, + "graph": { + "nodes": [ + { + "id": "start", + "type": "start", + "data": {"type": "start"}, + } + ], + "edges": [], + }, "features": {}, "environment_variables": [], "conversation_variables": [], } - # Mock the get_draft_workflow method to return different workflows based on workflow_id - def mock_get_draft_workflow(app_model, workflow_id=None): - if workflow_id == "specific-workflow-id": + workflow_id = str(uuid4()) + + def mock_get_draft_workflow(app_model, wf_id=None): + if wf_id == workflow_id: return mock_workflow return None @@ -357,78 +1001,351 @@ class TestAppDslService: "workflow_service" ].return_value.get_draft_workflow.side_effect = mock_get_draft_workflow - # Export DSL with specific workflow ID - exported_dsl = AppDslService.export_dsl(app, include_secret=False, workflow_id="specific-workflow-id") - - # Parse exported YAML + exported_dsl = AppDslService.export_dsl(app, include_secret=False, workflow_id=workflow_id) exported_data = yaml.safe_load(exported_dsl) - # Verify exported data structure assert exported_data["kind"] == "app" - assert exported_data["app"]["name"] == app.name - assert exported_data["app"]["mode"] == "workflow" - - # Verify workflow was exported assert "workflow" in exported_data - assert "graph" in exported_data["workflow"] - assert "nodes" in exported_data["workflow"]["graph"] - - # Verify dependencies were exported - assert "dependencies" in exported_data - assert isinstance(exported_data["dependencies"], list) - - # Verify workflow service was called with specific workflow ID - mock_external_service_dependencies["workflow_service"].return_value.get_draft_workflow.assert_called_once_with( - app, "specific-workflow-id" - ) def test_export_dsl_with_invalid_workflow_id_raises_error( self, db_session_with_containers, mock_external_service_dependencies ): - """ - Test that export_dsl raises error when invalid workflow ID is provided. - """ - fake = Faker() app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) - - # Update app to workflow mode app.mode = "workflow" db_session_with_containers.commit() - # Mock workflow service to return None when invalid workflow ID is provided mock_external_service_dependencies["workflow_service"].return_value.get_draft_workflow.return_value = None - # Export DSL with invalid workflow ID should raise ValueError - with pytest.raises(ValueError, match="Missing draft workflow configuration, please check."): - AppDslService.export_dsl(app, include_secret=False, workflow_id="invalid-workflow-id") + with pytest.raises( + ValueError, + match="Missing draft workflow configuration, please check.", + ): + AppDslService.export_dsl(app, include_secret=False, workflow_id=str(uuid4())) - # Verify workflow service was called with the invalid workflow ID - mock_external_service_dependencies["workflow_service"].return_value.get_draft_workflow.assert_called_once_with( - app, "invalid-workflow-id" + # ── Workflow Export Data ─────────────────────────────────────────── + + def test_append_workflow_export_data_filters_and_overrides(self, monkeypatch): + workflow_dict = { + "graph": { + "nodes": [ + { + "data": { + "type": BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL, + "dataset_ids": ["d1", "d2"], + } + }, + { + "data": { + "type": BuiltinNodeTypes.TOOL, + "credential_id": "secret", + } + }, + { + "data": { + "type": BuiltinNodeTypes.AGENT, + "agent_parameters": {"tools": {"value": [{"credential_id": "secret"}]}}, + } + }, + { + "data": { + "type": TRIGGER_SCHEDULE_NODE_TYPE, + "config": {"x": 1}, + } + }, + { + "data": { + "type": TRIGGER_WEBHOOK_NODE_TYPE, + "webhook_url": "x", + "webhook_debug_url": "y", + } + }, + { + "data": { + "type": TRIGGER_PLUGIN_NODE_TYPE, + "subscription_id": "s", + } + }, + ] + } + } + + workflow = SimpleNamespace(to_dict=lambda *, include_secret: workflow_dict) + workflow_service = MagicMock() + workflow_service.get_draft_workflow.return_value = workflow + monkeypatch.setattr(app_dsl_service, "WorkflowService", lambda: workflow_service) + + monkeypatch.setattr( + AppDslService, + "encrypt_dataset_id", + lambda *, dataset_id, tenant_id: f"enc:{tenant_id}:{dataset_id}", + ) + monkeypatch.setattr( + app_dsl_service.TriggerScheduleNode, + "get_default_config", + lambda: {"config": {"default": True}}, + ) + monkeypatch.setattr( + AppDslService, + "_extract_dependencies_from_workflow", + lambda *_args, **_kwargs: ["dep-1"], + ) + monkeypatch.setattr( + app_dsl_service.DependenciesAnalysisService, + "generate_dependencies", + lambda *, tenant_id, dependencies: [ + SimpleNamespace( + model_dump=lambda: { + "tenant": tenant_id, + "dep": dependencies[0], + } + ) + ], + ) + monkeypatch.setattr(app_dsl_service, "jsonable_encoder", lambda x: x) + + export_data: dict = {} + AppDslService._append_workflow_export_data( + export_data=export_data, + app_model=SimpleNamespace(tenant_id=_DEFAULT_TENANT_ID), + include_secret=False, + workflow_id=None, ) - def test_check_dependencies_success(self, db_session_with_containers, mock_external_service_dependencies): - """ - Test successful dependency checking. - """ - fake = Faker() - app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + nodes = export_data["workflow"]["graph"]["nodes"] + assert nodes[0]["data"]["dataset_ids"] == [ + f"enc:{_DEFAULT_TENANT_ID}:d1", + f"enc:{_DEFAULT_TENANT_ID}:d2", + ] + assert "credential_id" not in nodes[1]["data"] + assert "credential_id" not in nodes[2]["data"]["agent_parameters"]["tools"]["value"][0] + assert nodes[3]["data"]["config"] == {"default": True} + assert nodes[4]["data"]["webhook_url"] == "" + assert nodes[4]["data"]["webhook_debug_url"] == "" + assert nodes[5]["data"]["subscription_id"] == "" + assert export_data["dependencies"] == [{"tenant": _DEFAULT_TENANT_ID, "dep": "dep-1"}] - # Mock Redis to return dependencies - mock_dependencies_json = '{"app_id": "' + app.id + '", "dependencies": []}' - mock_external_service_dependencies["redis_client"].get.return_value = mock_dependencies_json + def test_append_workflow_export_data_missing_workflow_raises(self, monkeypatch): + workflow_service = MagicMock() + workflow_service.get_draft_workflow.return_value = None + monkeypatch.setattr(app_dsl_service, "WorkflowService", lambda: workflow_service) - # Check dependencies - dsl_service = AppDslService(db_session_with_containers) - result = dsl_service.check_dependencies(app_model=app) + with pytest.raises(ValueError, match="Missing draft workflow configuration"): + AppDslService._append_workflow_export_data( + export_data={}, + app_model=SimpleNamespace(tenant_id=_DEFAULT_TENANT_ID), + include_secret=False, + workflow_id=None, + ) - # Verify result - assert result.leaked_dependencies == [] + # ── Model Config Export Data ────────────────────────────────────── - # Verify Redis was queried - mock_external_service_dependencies["redis_client"].get.assert_called_once_with( - f"app_check_dependencies:{app.id}" + def test_append_model_config_export_data_filters_credential_id(self, monkeypatch): + monkeypatch.setattr( + AppDslService, + "_extract_dependencies_from_model_config", + lambda *_args, **_kwargs: ["dep-1"], + ) + monkeypatch.setattr( + app_dsl_service.DependenciesAnalysisService, + "generate_dependencies", + lambda *, tenant_id, dependencies: [ + SimpleNamespace( + model_dump=lambda: { + "tenant": tenant_id, + "dep": dependencies[0], + } + ) + ], + ) + monkeypatch.setattr(app_dsl_service, "jsonable_encoder", lambda x: x) + + app_model_config = SimpleNamespace(to_dict=lambda: {"agent_mode": {"tools": [{"credential_id": "secret"}]}}) + app_model = SimpleNamespace(tenant_id=_DEFAULT_TENANT_ID, app_model_config=app_model_config) + export_data: dict = {} + + AppDslService._append_model_config_export_data(export_data, app_model) + assert export_data["model_config"]["agent_mode"]["tools"] == [{}] + assert export_data["dependencies"] == [{"tenant": _DEFAULT_TENANT_ID, "dep": "dep-1"}] + + def test_append_model_config_export_data_requires_app_config(self): + with pytest.raises(ValueError, match="Missing app configuration"): + AppDslService._append_model_config_export_data({}, SimpleNamespace(app_model_config=None)) + + # ── Dependency Extraction ───────────────────────────────────────── + + def test_extract_dependencies_from_workflow_graph_covers_all_node_types(self, monkeypatch): + monkeypatch.setattr( + app_dsl_service.DependenciesAnalysisService, + "analyze_tool_dependency", + lambda provider_id: f"tool:{provider_id}", + ) + monkeypatch.setattr( + app_dsl_service.DependenciesAnalysisService, + "analyze_model_provider_dependency", + lambda provider: f"model:{provider}", ) - # Verify dependencies service was called - mock_external_service_dependencies["dependencies_service"].get_leaked_dependencies.assert_called_once() + monkeypatch.setattr( + app_dsl_service.ToolNodeData, + "model_validate", + lambda _d: SimpleNamespace(provider_id="p1"), + ) + monkeypatch.setattr( + app_dsl_service.LLMNodeData, + "model_validate", + lambda _d: SimpleNamespace(model=SimpleNamespace(provider="m1")), + ) + monkeypatch.setattr( + app_dsl_service.QuestionClassifierNodeData, + "model_validate", + lambda _d: SimpleNamespace(model=SimpleNamespace(provider="m2")), + ) + monkeypatch.setattr( + app_dsl_service.ParameterExtractorNodeData, + "model_validate", + lambda _d: SimpleNamespace(model=SimpleNamespace(provider="m3")), + ) + + def kr_validate(_d): + return SimpleNamespace( + retrieval_mode="multiple", + multiple_retrieval_config=SimpleNamespace( + reranking_mode="weighted_score", + weights=SimpleNamespace(vector_setting=SimpleNamespace(embedding_provider_name="m4")), + reranking_model=None, + ), + single_retrieval_config=None, + ) + + monkeypatch.setattr( + app_dsl_service.KnowledgeRetrievalNodeData, + "model_validate", + kr_validate, + ) + + graph = { + "nodes": [ + {"data": {"type": BuiltinNodeTypes.TOOL}}, + {"data": {"type": BuiltinNodeTypes.LLM}}, + {"data": {"type": BuiltinNodeTypes.QUESTION_CLASSIFIER}}, + {"data": {"type": BuiltinNodeTypes.PARAMETER_EXTRACTOR}}, + {"data": {"type": BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL}}, + {"data": {"type": "unknown"}}, + ] + } + + deps = AppDslService._extract_dependencies_from_workflow_graph(graph) + assert deps == [ + "tool:p1", + "model:m1", + "model:m2", + "model:m3", + "model:m4", + ] + + def test_extract_dependencies_from_workflow_graph_handles_exceptions(self, monkeypatch): + monkeypatch.setattr( + app_dsl_service.ToolNodeData, + "model_validate", + lambda _d: (_ for _ in ()).throw(ValueError("bad")), + ) + deps = AppDslService._extract_dependencies_from_workflow_graph( + {"nodes": [{"data": {"type": BuiltinNodeTypes.TOOL}}]} + ) + assert deps == [] + + def test_extract_dependencies_from_model_config_parses_providers(self, monkeypatch): + monkeypatch.setattr( + app_dsl_service.DependenciesAnalysisService, + "analyze_model_provider_dependency", + lambda provider: f"model:{provider}", + ) + monkeypatch.setattr( + app_dsl_service.DependenciesAnalysisService, + "analyze_tool_dependency", + lambda provider_id: f"tool:{provider_id}", + ) + + deps = AppDslService._extract_dependencies_from_model_config( + { + "model": {"provider": "p1"}, + "dataset_configs": { + "datasets": {"datasets": [{"reranking_model": {"reranking_provider_name": {"provider": "p2"}}}]} + }, + "agent_mode": {"tools": [{"provider_id": "t1"}]}, + } + ) + assert deps == ["model:p1", "model:p2", "tool:t1"] + + def test_extract_dependencies_from_model_config_handles_exceptions(self, monkeypatch): + monkeypatch.setattr( + app_dsl_service.DependenciesAnalysisService, + "analyze_model_provider_dependency", + lambda _p: (_ for _ in ()).throw(ValueError("bad")), + ) + deps = AppDslService._extract_dependencies_from_model_config({"model": {"provider": "p1"}}) + assert deps == [] + + # ── Leaked Dependencies ─────────────────────────────────────────── + + def test_get_leaked_dependencies_empty_returns_empty(self): + assert AppDslService.get_leaked_dependencies(_DEFAULT_TENANT_ID, []) == [] + + def test_get_leaked_dependencies_delegates(self, monkeypatch): + monkeypatch.setattr( + app_dsl_service.DependenciesAnalysisService, + "get_leaked_dependencies", + lambda *, tenant_id, dependencies: [SimpleNamespace(tenant_id=tenant_id, deps=dependencies)], + ) + res = AppDslService.get_leaked_dependencies(_DEFAULT_TENANT_ID, [SimpleNamespace(id="x")]) + assert len(res) == 1 + + # ── Encryption/Decryption ───────────────────────────────────────── + + def test_encrypt_decrypt_dataset_id_respects_config(self, monkeypatch): + tenant_id = _DEFAULT_TENANT_ID + dataset_uuid = "00000000-0000-0000-0000-000000000000" + + monkeypatch.setattr( + app_dsl_service.dify_config, + "DSL_EXPORT_ENCRYPT_DATASET_ID", + False, + ) + assert AppDslService.encrypt_dataset_id(dataset_id=dataset_uuid, tenant_id=tenant_id) == dataset_uuid + + monkeypatch.setattr( + app_dsl_service.dify_config, + "DSL_EXPORT_ENCRYPT_DATASET_ID", + True, + ) + encrypted = AppDslService.encrypt_dataset_id(dataset_id=dataset_uuid, tenant_id=tenant_id) + assert encrypted != dataset_uuid + assert base64.b64decode(encrypted.encode()) + assert AppDslService.decrypt_dataset_id(encrypted_data=encrypted, tenant_id=tenant_id) == dataset_uuid + + def test_decrypt_dataset_id_returns_plain_uuid_unchanged(self): + value = "00000000-0000-0000-0000-000000000000" + assert AppDslService.decrypt_dataset_id(encrypted_data=value, tenant_id=_DEFAULT_TENANT_ID) == value + + def test_decrypt_dataset_id_returns_none_on_invalid_data(self, monkeypatch): + monkeypatch.setattr( + app_dsl_service.dify_config, + "DSL_EXPORT_ENCRYPT_DATASET_ID", + True, + ) + assert AppDslService.decrypt_dataset_id(encrypted_data="not-base64", tenant_id=_DEFAULT_TENANT_ID) is None + + def test_decrypt_dataset_id_returns_none_when_decrypted_is_not_uuid(self, monkeypatch): + monkeypatch.setattr( + app_dsl_service.dify_config, + "DSL_EXPORT_ENCRYPT_DATASET_ID", + True, + ) + encrypted = AppDslService.encrypt_dataset_id(dataset_id="not-a-uuid", tenant_id=_DEFAULT_TENANT_ID) + assert AppDslService.decrypt_dataset_id(encrypted_data=encrypted, tenant_id=_DEFAULT_TENANT_ID) is None + + # ── Utility ─────────────────────────────────────────────────────── + + def test_is_valid_uuid_handles_bad_inputs(self): + assert AppDslService._is_valid_uuid("00000000-0000-0000-0000-000000000000") is True + assert AppDslService._is_valid_uuid("nope") is False diff --git a/api/tests/test_containers_integration_tests/services/test_app_service.py b/api/tests/test_containers_integration_tests/services/test_app_service.py index fa57dd4a6ff..b695ae9fd9e 100644 --- a/api/tests/test_containers_integration_tests/services/test_app_service.py +++ b/api/tests/test_containers_integration_tests/services/test_app_service.py @@ -658,15 +658,17 @@ class TestAppService: # Update app icon new_icon = "🌟" new_icon_background = "#FFD93D" + new_icon_type = "image" mock_current_user = create_autospec(Account, instance=True) mock_current_user.id = account.id mock_current_user.current_tenant_id = account.current_tenant_id with patch("services.app_service.current_user", mock_current_user): - updated_app = app_service.update_app_icon(app, new_icon, new_icon_background) + updated_app = app_service.update_app_icon(app, new_icon, new_icon_background, new_icon_type) assert updated_app.icon == new_icon assert updated_app.icon_background == new_icon_background + assert str(updated_app.icon_type).lower() == new_icon_type assert updated_app.updated_by == account.id # Verify other fields remain unchanged diff --git a/api/tests/test_containers_integration_tests/services/test_audio_service_db.py b/api/tests/test_containers_integration_tests/services/test_audio_service_db.py new file mode 100644 index 00000000000..2593b53fe84 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_audio_service_db.py @@ -0,0 +1,211 @@ +""" +Integration tests for AudioService.transcript_tts message-ID path. + +Migrated from unit_tests/services/test_audio_service.py, replacing +db.session.get mock patches with real Message rows persisted in PostgreSQL. + +Covers: +- transcript_tts with valid message_id that resolves to a real Message +- transcript_tts returns None for invalid (non-UUID) message_id +- transcript_tts returns None when message_id is a valid UUID but no row exists +- transcript_tts returns None when message exists but has an empty answer +""" + +from collections.abc import Generator +from decimal import Decimal +from unittest.mock import MagicMock, patch +from uuid import uuid4 + +import pytest +from sqlalchemy import delete, select +from sqlalchemy.orm import Session + +from core.app.entities.app_invoke_entities import InvokeFrom +from models.account import TenantAccountJoin +from models.enums import ConversationFromSource, MessageStatus +from models.model import App, AppMode, Conversation, Message +from services.audio_service import AudioService +from tests.test_containers_integration_tests.controllers.console.helpers import ( + create_console_account_and_tenant, + create_console_app, +) + + +def _create_conversation(db_session: Session, app: App, account_id: str) -> Conversation: + """Create a Conversation row via flush() so the rollback-based teardown can remove it.""" + conversation = Conversation( + app_id=app.id, + app_model_config_id=None, + model_provider=None, + model_id="", + override_model_configs=None, + mode=app.mode, + name=f"Conversation {uuid4()}", + summary="", + inputs={}, + introduction="", + system_instruction="", + system_instruction_tokens=0, + status="normal", + invoke_from=InvokeFrom.WEB_APP.value, + from_source=ConversationFromSource.CONSOLE, + from_end_user_id=None, + from_account_id=account_id, + dialogue_count=0, + is_deleted=False, + ) + db_session.add(conversation) + db_session.flush() + return conversation + + +def _create_message( + db_session: Session, + app: App, + conversation: Conversation, + account_id: str, + *, + answer: str = "Message answer text", + status: MessageStatus | str = MessageStatus.NORMAL, +) -> Message: + """Create a Message row via flush() so the rollback-based teardown can remove it.""" + message = Message( + app_id=app.id, + model_provider=None, + model_id="", + override_model_configs=None, + conversation_id=conversation.id, + inputs={}, + query="Test query", + message={"messages": [{"role": "user", "content": "Test query"}]}, + message_tokens=0, + message_unit_price=Decimal(0), + message_price_unit=Decimal("0.001"), + answer=answer, + answer_tokens=0, + answer_unit_price=Decimal(0), + answer_price_unit=Decimal("0.001"), + parent_message_id=None, + provider_response_latency=0, + total_price=Decimal(0), + currency="USD", + status=status, + invoke_from=InvokeFrom.WEB_APP.value, + from_source=ConversationFromSource.CONSOLE, + from_end_user_id=None, + from_account_id=account_id, + ) + db_session.add(message) + db_session.flush() + return message + + +class TestAudioServiceTranscriptTTSMessageLookup: + """Integration tests for AudioService.transcript_tts message-ID lookup via real DB.""" + + @pytest.fixture(autouse=True) + def _setup_cleanup(self, db_session_with_containers: Session) -> Generator[None, None, None]: + """Track rows created by shared helpers that commit, then clean up after the test. + + The shared console helpers (create_console_account_and_tenant, create_console_app) + commit their inserts so the rows survive a simple rollback. This fixture records + the app/account/tenant created per test and explicitly deletes them after the test + so the DB does not accumulate state across tests. Conversation/Message rows are + created via flush() only, so the trailing rollback removes them. + """ + self._committed_rows: list = [] + yield + db_session_with_containers.rollback() + for entity in reversed(self._committed_rows): + db_session_with_containers.execute(delete(type(entity)).where(type(entity).id == entity.id)) + db_session_with_containers.commit() + + def _setup_app_and_account(self, db_session: Session) -> tuple[App, str, str]: + """Create committed app/account/tenant using shared helpers and track them for cleanup.""" + account, tenant = create_console_account_and_tenant(db_session) + app = create_console_app(db_session, tenant_id=tenant.id, account_id=account.id, mode=AppMode.CHAT) + + # Track rows in the order they must be deleted (FK-safe: app and join before account/tenant) + self._committed_rows.append(app) + join = db_session.scalar( + select(TenantAccountJoin).where( + TenantAccountJoin.account_id == account.id, + TenantAccountJoin.tenant_id == tenant.id, + ) + ) + if join is not None: + self._committed_rows.append(join) + self._committed_rows.extend([account, tenant]) + return app, account.id, tenant.id + + def test_transcript_tts_with_message_id_success(self, db_session_with_containers: Session) -> None: + """transcript_tts invokes TTS with the message answer when message_id resolves to a real row.""" + app, account_id, _ = self._setup_app_and_account(db_session_with_containers) + conversation = _create_conversation(db_session_with_containers, app, account_id) + message = _create_message( + db_session_with_containers, + app, + conversation, + account_id, + answer="Hello from message", + ) + + mock_model_instance = MagicMock() + mock_model_instance.invoke_tts.return_value = b"audio from message" + mock_model_manager = MagicMock() + mock_model_manager.get_default_model_instance.return_value = mock_model_instance + + with patch("services.audio_service.ModelManager.for_tenant", return_value=mock_model_manager): + result = AudioService.transcript_tts( + app_model=app, + message_id=message.id, + voice="en-US-Neural", + ) + + assert result == b"audio from message" + mock_model_instance.invoke_tts.assert_called_once_with( + content_text="Hello from message", + voice="en-US-Neural", + ) + + def test_transcript_tts_returns_none_for_invalid_message_id(self, db_session_with_containers: Session) -> None: + """transcript_tts returns None immediately when message_id is not a valid UUID.""" + app, _, _ = self._setup_app_and_account(db_session_with_containers) + + result = AudioService.transcript_tts( + app_model=app, + message_id="invalid-uuid", + ) + + assert result is None + + def test_transcript_tts_returns_none_for_nonexistent_message(self, db_session_with_containers: Session) -> None: + """transcript_tts returns None when message_id is a valid UUID but no Message row exists.""" + app, _, _ = self._setup_app_and_account(db_session_with_containers) + + result = AudioService.transcript_tts( + app_model=app, + message_id=str(uuid4()), + ) + + assert result is None + + def test_transcript_tts_returns_none_for_empty_message_answer(self, db_session_with_containers: Session) -> None: + """transcript_tts returns None when the resolved message has an empty answer.""" + app, account_id, _ = self._setup_app_and_account(db_session_with_containers) + conversation = _create_conversation(db_session_with_containers, app, account_id) + message = _create_message( + db_session_with_containers, + app, + conversation, + account_id, + answer="", + status=MessageStatus.NORMAL, + ) + + result = AudioService.transcript_tts( + app_model=app, + message_id=message.id, + ) + + assert result is None diff --git a/api/tests/test_containers_integration_tests/services/test_billing_service.py b/api/tests/test_containers_integration_tests/services/test_billing_service.py index 76708b36b1c..8092c7ad757 100644 --- a/api/tests/test_containers_integration_tests/services/test_billing_service.py +++ b/api/tests/test_containers_integration_tests/services/test_billing_service.py @@ -1,9 +1,13 @@ import json +from collections.abc import Generator from unittest.mock import patch +from uuid import uuid4 import pytest +from sqlalchemy.orm import Session from extensions.ext_redis import redis_client +from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from services.billing_service import BillingService @@ -363,3 +367,62 @@ class TestBillingServiceGetPlanBulkWithCache: assert ttl_1_new <= 600 assert ttl_2 > 0 assert ttl_2 <= 600 + + +class TestBillingServiceIsTenantOwnerOrAdmin: + """ + Integration tests for BillingService.is_tenant_owner_or_admin. + + Verifies that non-privileged roles (EDITOR, DATASET_OPERATOR) raise ValueError + when checked against real TenantAccountJoin rows in PostgreSQL. + """ + + @pytest.fixture(autouse=True) + def _auto_rollback(self, db_session_with_containers: Session) -> Generator[None, None, None]: + yield + db_session_with_containers.rollback() + + def _create_account_with_tenant_role(self, db_session: Session, role: TenantAccountRole) -> tuple[Account, Tenant]: + tenant = Tenant(name=f"Tenant {uuid4()}") + db_session.add(tenant) + db_session.flush() + + account = Account( + name=f"Account {uuid4()}", + email=f"billing_{uuid4()}@example.com", + password="hashed-password", + password_salt="salt", + interface_language="en-US", + timezone="UTC", + ) + db_session.add(account) + db_session.flush() + + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=role, + current=True, + ) + db_session.add(join) + db_session.flush() + + # Wire up in-memory reference so current_tenant_id resolves + account._current_tenant = tenant + return account, tenant + + def test_is_tenant_owner_or_admin_editor_role_raises_error(self, db_session_with_containers: Session) -> None: + """is_tenant_owner_or_admin raises ValueError for EDITOR role.""" + account, _ = self._create_account_with_tenant_role(db_session_with_containers, TenantAccountRole.EDITOR) + + with pytest.raises(ValueError, match="Only team owner or team admin can perform this action"): + BillingService.is_tenant_owner_or_admin(account) + + def test_is_tenant_owner_or_admin_dataset_operator_raises_error(self, db_session_with_containers: Session) -> None: + """is_tenant_owner_or_admin raises ValueError for DATASET_OPERATOR role.""" + account, _ = self._create_account_with_tenant_role( + db_session_with_containers, TenantAccountRole.DATASET_OPERATOR + ) + + with pytest.raises(ValueError, match="Only team owner or team admin can perform this action"): + BillingService.is_tenant_owner_or_admin(account) diff --git a/api/tests/test_containers_integration_tests/services/test_conversation_service.py b/api/tests/test_containers_integration_tests/services/test_conversation_service.py index 6180d98b1ef..98c38f2b5fa 100644 --- a/api/tests/test_containers_integration_tests/services/test_conversation_service.py +++ b/api/tests/test_containers_integration_tests/services/test_conversation_service.py @@ -637,6 +637,40 @@ class TestConversationServiceSummarization: assert conversation.name == new_name assert conversation.updated_at == mock_time + @patch("services.conversation_service.LLMGenerator.generate_conversation_name") + def test_rename_with_auto_generate(self, mock_llm_generator, db_session_with_containers): + """ + Test rename delegates to auto_generate_name when auto_generate is True. + + When auto_generate is True, the service should call auto_generate_name + which uses an LLM to create a descriptive conversation title. + """ + # Arrange + app_model, user = ConversationServiceIntegrationTestDataFactory.create_app_and_account( + db_session_with_containers + ) + conversation = ConversationServiceIntegrationTestDataFactory.create_conversation( + db_session_with_containers, app_model, user + ) + ConversationServiceIntegrationTestDataFactory.create_message( + db_session_with_containers, app_model, conversation, user + ) + generated_name = "Auto Generated Name" + mock_llm_generator.return_value = generated_name + + # Act + result = ConversationService.rename( + app_model=app_model, + conversation_id=conversation.id, + user=user, + name=None, + auto_generate=True, + ) + + # Assert + assert result == conversation + assert conversation.name == generated_name + class TestConversationServiceMessageAnnotation: """ @@ -1066,3 +1100,32 @@ class TestConversationServiceExport: not_deleted = db_session_with_containers.scalar(select(Conversation).where(Conversation.id == conversation.id)) assert not_deleted is not None mock_delete_task.delay.assert_not_called() + + @patch("services.conversation_service.delete_conversation_related_data") + def test_delete_handles_exception_and_rollback(self, mock_delete_task, db_session_with_containers): + """ + Test that delete propagates exceptions and does not trigger the cleanup task. + + When a DB error occurs during deletion, the service must rollback the + transaction and re-raise the exception without scheduling async cleanup. + """ + # Arrange + app_model, user = ConversationServiceIntegrationTestDataFactory.create_app_and_account( + db_session_with_containers + ) + conversation = ConversationServiceIntegrationTestDataFactory.create_conversation( + db_session_with_containers, app_model, user + ) + conversation_id = conversation.id + + # Act — force an error during the delete to exercise the rollback path + with patch("services.conversation_service.db.session.delete", side_effect=Exception("DB error")): + with pytest.raises(Exception, match="DB error"): + ConversationService.delete(app_model=app_model, conversation_id=conversation_id, user=user) + + # Assert — async cleanup must NOT have been scheduled + mock_delete_task.delay.assert_not_called() + + # Conversation is still present because the deletion was never committed + still_there = db_session_with_containers.scalar(select(Conversation).where(Conversation.id == conversation_id)) + assert still_there is not None diff --git a/api/tests/test_containers_integration_tests/services/test_conversation_service_variables.py b/api/tests/test_containers_integration_tests/services/test_conversation_service_variables.py new file mode 100644 index 00000000000..0b7bd9ca645 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_conversation_service_variables.py @@ -0,0 +1,524 @@ +from __future__ import annotations + +from datetime import datetime, timedelta +from unittest.mock import patch +from uuid import uuid4 + +import pytest +from sqlalchemy.orm import sessionmaker + +from core.app.entities.app_invoke_entities import InvokeFrom +from extensions.ext_database import db +from graphon.variables import FloatVariable, IntegerVariable, StringVariable +from models.account import Account, Tenant, TenantAccountJoin +from models.enums import ConversationFromSource +from models.model import App, Conversation, EndUser +from models.workflow import ConversationVariable +from services.conversation_service import ConversationService +from services.errors.conversation import ( + ConversationVariableNotExistsError, + ConversationVariableTypeMismatchError, + LastConversationNotExistsError, +) + + +class ConversationServiceVariableIntegrationFactory: + @staticmethod + def create_app_and_account(db_session_with_containers): + tenant = Tenant(name=f"Tenant {uuid4()}") + db_session_with_containers.add(tenant) + db_session_with_containers.flush() + + account = Account( + name=f"Account {uuid4()}", + email=f"conversation-variable-{uuid4()}@example.com", + password="hashed-password", + password_salt="salt", + interface_language="en-US", + timezone="UTC", + ) + db_session_with_containers.add(account) + db_session_with_containers.flush() + + tenant_join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role="owner", + current=True, + ) + db_session_with_containers.add(tenant_join) + db_session_with_containers.flush() + + app = App( + tenant_id=tenant.id, + name=f"App {uuid4()}", + description="", + mode="chat", + icon_type="emoji", + icon="bot", + icon_background="#FFFFFF", + enable_site=False, + enable_api=True, + api_rpm=100, + api_rph=100, + is_demo=False, + is_public=False, + is_universal=False, + created_by=account.id, + updated_by=account.id, + ) + db_session_with_containers.add(app) + db_session_with_containers.commit() + + return app, account + + @staticmethod + def create_end_user(db_session_with_containers, app: App): + end_user = EndUser( + tenant_id=app.tenant_id, + app_id=app.id, + type=InvokeFrom.SERVICE_API.value, + external_user_id=f"external-{uuid4()}", + name=f"End User {uuid4()}", + is_anonymous=False, + session_id=f"session-{uuid4()}", + ) + db_session_with_containers.add(end_user) + db_session_with_containers.commit() + return end_user + + @staticmethod + def create_conversation( + db_session_with_containers, + app: App, + user: Account | EndUser, + *, + name: str | None = None, + invoke_from: InvokeFrom = InvokeFrom.WEB_APP, + created_at: datetime | None = None, + updated_at: datetime | None = None, + ) -> Conversation: + conversation = Conversation( + app_id=app.id, + app_model_config_id=None, + model_provider=None, + model_id="", + override_model_configs=None, + mode=app.mode, + name=name or f"Conversation {uuid4()}", + summary="", + inputs={}, + introduction="", + system_instruction="", + system_instruction_tokens=0, + status="normal", + invoke_from=invoke_from.value, + from_source=ConversationFromSource.API if isinstance(user, EndUser) else ConversationFromSource.CONSOLE, + from_end_user_id=user.id if isinstance(user, EndUser) else None, + from_account_id=user.id if isinstance(user, Account) else None, + dialogue_count=0, + is_deleted=False, + ) + conversation.inputs = {} + if created_at is not None: + conversation.created_at = created_at + if updated_at is not None: + conversation.updated_at = updated_at + + db_session_with_containers.add(conversation) + db_session_with_containers.commit() + return conversation + + @staticmethod + def create_variable( + db_session_with_containers, + *, + app: App, + conversation: Conversation, + variable: StringVariable | FloatVariable | IntegerVariable, + created_at: datetime | None = None, + ) -> ConversationVariable: + row = ConversationVariable.from_variable(app_id=app.id, conversation_id=conversation.id, variable=variable) + if created_at is not None: + row.created_at = created_at + row.updated_at = created_at + + db_session_with_containers.add(row) + db_session_with_containers.commit() + return row + + +@pytest.fixture +def real_conversation_service_session_factory(flask_app_with_containers): + del flask_app_with_containers + real_session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) + + with ( + patch("services.conversation_service.session_factory.create_session", side_effect=lambda: real_session_maker()), + patch("services.conversation_service.session_factory.get_session_maker", return_value=real_session_maker), + ): + yield + + +class TestConversationServiceVariables: + def test_get_conversational_variable_success( + self, db_session_with_containers, real_conversation_service_session_factory + ): + del real_conversation_service_session_factory + factory = ConversationServiceVariableIntegrationFactory + app, account = factory.create_app_and_account(db_session_with_containers) + conversation = factory.create_conversation(db_session_with_containers, app, account) + older_time = datetime(2024, 1, 1, 12, 0, 0) + newer_time = older_time + timedelta(minutes=5) + + first_variable = factory.create_variable( + db_session_with_containers, + app=app, + conversation=conversation, + variable=StringVariable(id=str(uuid4()), name="topic", value="billing"), + created_at=older_time, + ) + second_variable = factory.create_variable( + db_session_with_containers, + app=app, + conversation=conversation, + variable=StringVariable(id=str(uuid4()), name="priority", value="high"), + created_at=newer_time, + ) + + result = ConversationService.get_conversational_variable( + app_model=app, + conversation_id=conversation.id, + user=account, + limit=10, + last_id=None, + ) + + assert [item["id"] for item in result.data] == [first_variable.id, second_variable.id] + assert [item["name"] for item in result.data] == ["topic", "priority"] + assert result.limit == 10 + assert result.has_more is False + + def test_get_conversational_variable_with_last_id( + self, db_session_with_containers, real_conversation_service_session_factory + ): + del real_conversation_service_session_factory + factory = ConversationServiceVariableIntegrationFactory + app, account = factory.create_app_and_account(db_session_with_containers) + conversation = factory.create_conversation(db_session_with_containers, app, account) + base_time = datetime(2024, 1, 1, 9, 0, 0) + + first_variable = factory.create_variable( + db_session_with_containers, + app=app, + conversation=conversation, + variable=StringVariable(id=str(uuid4()), name="topic", value="billing"), + created_at=base_time, + ) + second_variable = factory.create_variable( + db_session_with_containers, + app=app, + conversation=conversation, + variable=StringVariable(id=str(uuid4()), name="priority", value="high"), + created_at=base_time + timedelta(minutes=1), + ) + third_variable = factory.create_variable( + db_session_with_containers, + app=app, + conversation=conversation, + variable=StringVariable(id=str(uuid4()), name="owner", value="alice"), + created_at=base_time + timedelta(minutes=2), + ) + + result = ConversationService.get_conversational_variable( + app_model=app, + conversation_id=conversation.id, + user=account, + limit=10, + last_id=first_variable.id, + ) + + assert [item["id"] for item in result.data] == [second_variable.id, third_variable.id] + assert result.has_more is False + + def test_get_conversational_variable_last_id_not_found_raises_error( + self, db_session_with_containers, real_conversation_service_session_factory + ): + del real_conversation_service_session_factory + factory = ConversationServiceVariableIntegrationFactory + app, account = factory.create_app_and_account(db_session_with_containers) + conversation = factory.create_conversation(db_session_with_containers, app, account) + + with pytest.raises(ConversationVariableNotExistsError): + ConversationService.get_conversational_variable( + app_model=app, + conversation_id=conversation.id, + user=account, + limit=10, + last_id=str(uuid4()), + ) + + def test_get_conversational_variable_sets_has_more( + self, db_session_with_containers, real_conversation_service_session_factory + ): + del real_conversation_service_session_factory + factory = ConversationServiceVariableIntegrationFactory + app, account = factory.create_app_and_account(db_session_with_containers) + conversation = factory.create_conversation(db_session_with_containers, app, account) + + for index in range(3): + factory.create_variable( + db_session_with_containers, + app=app, + conversation=conversation, + variable=StringVariable(id=str(uuid4()), name=f"var_{index}", value=f"value_{index}"), + created_at=datetime(2024, 1, 1, 10, 0, index), + ) + + result = ConversationService.get_conversational_variable( + app_model=app, + conversation_id=conversation.id, + user=account, + limit=2, + last_id=None, + ) + + assert len(result.data) == 2 + assert result.has_more is True + + def test_update_conversation_variable_success( + self, db_session_with_containers, real_conversation_service_session_factory + ): + del real_conversation_service_session_factory + factory = ConversationServiceVariableIntegrationFactory + app, account = factory.create_app_and_account(db_session_with_containers) + conversation = factory.create_conversation(db_session_with_containers, app, account) + existing = factory.create_variable( + db_session_with_containers, + app=app, + conversation=conversation, + variable=StringVariable(id=str(uuid4()), name="topic", value="billing"), + ) + updated_at = datetime(2024, 1, 1, 15, 0, 0) + + with patch("services.conversation_service.naive_utc_now", return_value=updated_at): + result = ConversationService.update_conversation_variable( + app_model=app, + conversation_id=conversation.id, + variable_id=existing.id, + user=account, + new_value="support", + ) + + db_session_with_containers.expire_all() + persisted = db_session_with_containers.get(ConversationVariable, (existing.id, conversation.id)) + + assert persisted is not None + assert persisted.to_variable().value == "support" + assert result["id"] == existing.id + assert result["value"] == "support" + assert result["updated_at"] == updated_at + + def test_update_conversation_variable_not_found_raises_error( + self, db_session_with_containers, real_conversation_service_session_factory + ): + del real_conversation_service_session_factory + factory = ConversationServiceVariableIntegrationFactory + app, account = factory.create_app_and_account(db_session_with_containers) + conversation = factory.create_conversation(db_session_with_containers, app, account) + + with pytest.raises(ConversationVariableNotExistsError): + ConversationService.update_conversation_variable( + app_model=app, + conversation_id=conversation.id, + variable_id=str(uuid4()), + user=account, + new_value="support", + ) + + def test_update_conversation_variable_type_mismatch_raises_error( + self, db_session_with_containers, real_conversation_service_session_factory + ): + del real_conversation_service_session_factory + factory = ConversationServiceVariableIntegrationFactory + app, account = factory.create_app_and_account(db_session_with_containers) + conversation = factory.create_conversation(db_session_with_containers, app, account) + existing = factory.create_variable( + db_session_with_containers, + app=app, + conversation=conversation, + variable=FloatVariable(id=str(uuid4()), name="score", value=1.5), + ) + + with pytest.raises(ConversationVariableTypeMismatchError, match="expects float"): + ConversationService.update_conversation_variable( + app_model=app, + conversation_id=conversation.id, + variable_id=existing.id, + user=account, + new_value="wrong-type", + ) + + def test_update_conversation_variable_integer_number_compatibility( + self, db_session_with_containers, real_conversation_service_session_factory + ): + del real_conversation_service_session_factory + factory = ConversationServiceVariableIntegrationFactory + app, account = factory.create_app_and_account(db_session_with_containers) + conversation = factory.create_conversation(db_session_with_containers, app, account) + existing = factory.create_variable( + db_session_with_containers, + app=app, + conversation=conversation, + variable=IntegerVariable(id=str(uuid4()), name="attempts", value=1), + ) + + result = ConversationService.update_conversation_variable( + app_model=app, + conversation_id=conversation.id, + variable_id=existing.id, + user=account, + new_value=42, + ) + + db_session_with_containers.expire_all() + persisted = db_session_with_containers.get(ConversationVariable, (existing.id, conversation.id)) + + assert persisted is not None + assert persisted.to_variable().value == 42 + assert result["value"] == 42 + + +class TestConversationServicePaginationWithContainers: + def test_pagination_by_last_id_raises_error_when_last_id_missing(self, db_session_with_containers): + factory = ConversationServiceVariableIntegrationFactory + app, account = factory.create_app_and_account(db_session_with_containers) + + with pytest.raises(LastConversationNotExistsError): + ConversationService.pagination_by_last_id( + session=db_session_with_containers, + app_model=app, + user=account, + last_id=str(uuid4()), + limit=20, + invoke_from=InvokeFrom.WEB_APP, + ) + + def test_pagination_by_last_id_with_default_desc_updated_at(self, db_session_with_containers): + factory = ConversationServiceVariableIntegrationFactory + app, account = factory.create_app_and_account(db_session_with_containers) + base_time = datetime(2024, 1, 1, 8, 0, 0) + newest = factory.create_conversation( + db_session_with_containers, + app, + account, + name="Newest", + updated_at=base_time + timedelta(minutes=2), + ) + middle = factory.create_conversation( + db_session_with_containers, + app, + account, + name="Middle", + updated_at=base_time + timedelta(minutes=1), + ) + oldest = factory.create_conversation( + db_session_with_containers, + app, + account, + name="Oldest", + updated_at=base_time, + ) + + result = ConversationService.pagination_by_last_id( + session=db_session_with_containers, + app_model=app, + user=account, + last_id=middle.id, + limit=10, + invoke_from=InvokeFrom.WEB_APP, + ) + + assert newest.id != middle.id + assert [conversation.id for conversation in result.data] == [oldest.id] + + def test_pagination_by_last_id_with_name_sort(self, db_session_with_containers): + factory = ConversationServiceVariableIntegrationFactory + app, account = factory.create_app_and_account(db_session_with_containers) + alpha = factory.create_conversation(db_session_with_containers, app, account, name="Alpha") + beta = factory.create_conversation(db_session_with_containers, app, account, name="Beta") + gamma = factory.create_conversation(db_session_with_containers, app, account, name="Gamma") + + result = ConversationService.pagination_by_last_id( + session=db_session_with_containers, + app_model=app, + user=account, + last_id=beta.id, + limit=10, + invoke_from=InvokeFrom.WEB_APP, + sort_by="name", + ) + + assert alpha.id != beta.id + assert [conversation.id for conversation in result.data] == [gamma.id] + + def test_pagination_filters_to_end_user_api_source(self, db_session_with_containers): + factory = ConversationServiceVariableIntegrationFactory + app, account = factory.create_app_and_account(db_session_with_containers) + end_user = factory.create_end_user(db_session_with_containers, app) + account_conversation = factory.create_conversation( + db_session_with_containers, + app, + account, + name="Console Conversation", + invoke_from=InvokeFrom.WEB_APP, + ) + end_user_conversation = factory.create_conversation( + db_session_with_containers, + app, + end_user, + name="API Conversation", + invoke_from=InvokeFrom.SERVICE_API, + ) + + result = ConversationService.pagination_by_last_id( + session=db_session_with_containers, + app_model=app, + user=end_user, + last_id=None, + limit=20, + invoke_from=InvokeFrom.SERVICE_API, + ) + + assert account_conversation.id != end_user_conversation.id + assert [conversation.id for conversation in result.data] == [end_user_conversation.id] + + def test_pagination_filters_to_account_console_source(self, db_session_with_containers): + factory = ConversationServiceVariableIntegrationFactory + app, account = factory.create_app_and_account(db_session_with_containers) + end_user = factory.create_end_user(db_session_with_containers, app) + account_conversation = factory.create_conversation( + db_session_with_containers, + app, + account, + name="Console Conversation", + invoke_from=InvokeFrom.WEB_APP, + ) + factory.create_conversation( + db_session_with_containers, + app, + end_user, + name="API Conversation", + invoke_from=InvokeFrom.SERVICE_API, + ) + + result = ConversationService.pagination_by_last_id( + session=db_session_with_containers, + app_model=app, + user=account, + last_id=None, + limit=20, + invoke_from=InvokeFrom.WEB_APP, + ) + + assert [conversation.id for conversation in result.data] == [account_conversation.id] diff --git a/api/tests/test_containers_integration_tests/services/test_conversation_variable_updater.py b/api/tests/test_containers_integration_tests/services/test_conversation_variable_updater.py index fb0adbbcc26..02ab3f83146 100644 --- a/api/tests/test_containers_integration_tests/services/test_conversation_variable_updater.py +++ b/api/tests/test_containers_integration_tests/services/test_conversation_variable_updater.py @@ -3,10 +3,10 @@ from uuid import uuid4 import pytest -from graphon.variables import StringVariable from sqlalchemy.orm import sessionmaker from extensions.ext_database import db +from graphon.variables import StringVariable from models.workflow import ConversationVariable from services.conversation_variable_updater import ConversationVariableNotFoundError, ConversationVariableUpdater diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service.py b/api/tests/test_containers_integration_tests/services/test_dataset_service.py index f9bfa570cbb..0de3c64c4f7 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service.py @@ -9,11 +9,11 @@ from unittest.mock import Mock, patch from uuid import uuid4 import pytest -from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy.orm import Session from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.retrieval.retrieval_methods import RetrievalMethod +from graphon.model_runtime.entities.model_entities import ModelType from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, DatasetPermissionEnum, Document, ExternalKnowledgeBindings, Pipeline from models.enums import DatasetRuntimeMode, DataSourceType, DocumentCreatedFrom, IndexingStatus diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_document.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_document.py new file mode 100644 index 00000000000..2bec703f0cc --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_document.py @@ -0,0 +1,650 @@ +"""Testcontainers integration tests for SQL-backed DocumentService paths.""" + +import datetime +import json +from unittest.mock import create_autospec, patch +from uuid import uuid4 + +import pytest +from werkzeug.exceptions import Forbidden, NotFound + +from core.rag.index_processor.constant.index_type import IndexStructureType +from extensions.storage.storage_type import StorageType +from models import Account +from models.dataset import Dataset, Document +from models.enums import CreatorUserRole, DataSourceType, DocumentCreatedFrom, IndexingStatus +from models.model import UploadFile +from services.dataset_service import DocumentService +from services.errors.account import NoPermissionError + +FIXED_UPLOAD_CREATED_AT = datetime.datetime(2024, 1, 1, 0, 0, 0) + + +class DocumentServiceIntegrationFactory: + @staticmethod + def create_dataset( + db_session_with_containers, + *, + tenant_id: str | None = None, + created_by: str | None = None, + name: str | None = None, + ) -> Dataset: + dataset = Dataset( + tenant_id=tenant_id or str(uuid4()), + name=name or f"dataset-{uuid4()}", + data_source_type=DataSourceType.UPLOAD_FILE, + created_by=created_by or str(uuid4()), + ) + db_session_with_containers.add(dataset) + db_session_with_containers.commit() + return dataset + + @staticmethod + def create_document( + db_session_with_containers, + *, + dataset: Dataset, + name: str = "doc.txt", + position: int = 1, + tenant_id: str | None = None, + indexing_status: str = IndexingStatus.COMPLETED, + enabled: bool = True, + archived: bool = False, + is_paused: bool = False, + need_summary: bool = False, + doc_form: str = IndexStructureType.PARAGRAPH_INDEX, + batch: str | None = None, + data_source_type: str = DataSourceType.UPLOAD_FILE, + data_source_info: dict | None = None, + created_by: str | None = None, + ) -> Document: + document = Document( + tenant_id=tenant_id or dataset.tenant_id, + dataset_id=dataset.id, + position=position, + data_source_type=data_source_type, + data_source_info=json.dumps(data_source_info or {}), + batch=batch or f"batch-{uuid4()}", + name=name, + created_from=DocumentCreatedFrom.WEB, + created_by=created_by or dataset.created_by, + doc_form=doc_form, + ) + document.indexing_status = indexing_status + document.enabled = enabled + document.archived = archived + document.is_paused = is_paused + document.need_summary = need_summary + if indexing_status == IndexingStatus.COMPLETED: + document.completed_at = FIXED_UPLOAD_CREATED_AT + db_session_with_containers.add(document) + db_session_with_containers.commit() + return document + + @staticmethod + def create_upload_file( + db_session_with_containers, + *, + tenant_id: str, + created_by: str, + file_id: str | None = None, + name: str = "source.txt", + ) -> UploadFile: + upload_file = UploadFile( + tenant_id=tenant_id, + storage_type=StorageType.LOCAL, + key=f"uploads/{uuid4()}", + name=name, + size=128, + extension="txt", + mime_type="text/plain", + created_by_role=CreatorUserRole.ACCOUNT, + created_by=created_by, + created_at=FIXED_UPLOAD_CREATED_AT, + used=False, + ) + if file_id: + upload_file.id = file_id + db_session_with_containers.add(upload_file) + db_session_with_containers.commit() + return upload_file + + +@pytest.fixture +def current_user_mock(): + with patch("services.dataset_service.current_user", create_autospec(Account, instance=True)) as current_user: + current_user.id = str(uuid4()) + current_user.current_tenant_id = str(uuid4()) + current_user.current_role = None + yield current_user + + +def test_get_document_returns_none_when_document_id_is_missing(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + + assert DocumentService.get_document(dataset.id, None) is None + + +def test_get_document_queries_by_dataset_and_document_id(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + document = DocumentServiceIntegrationFactory.create_document(db_session_with_containers, dataset=dataset) + + result = DocumentService.get_document(dataset.id, document.id) + + assert result is not None + assert result.id == document.id + + +def test_get_documents_by_ids_returns_empty_for_empty_input(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + + result = DocumentService.get_documents_by_ids(dataset.id, []) + + assert result == [] + + +def test_get_documents_by_ids_uses_single_batch_query(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + doc_a = DocumentServiceIntegrationFactory.create_document(db_session_with_containers, dataset=dataset, name="a.txt") + doc_b = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + name="b.txt", + position=2, + ) + + result = DocumentService.get_documents_by_ids(dataset.id, [doc_a.id, doc_b.id]) + + assert {document.id for document in result} == {doc_a.id, doc_b.id} + + +def test_update_documents_need_summary_returns_zero_for_empty_input(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + + assert DocumentService.update_documents_need_summary(dataset.id, []) == 0 + + +def test_update_documents_need_summary_updates_matching_non_qa_documents(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + paragraph_doc = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + need_summary=True, + ) + qa_doc = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + position=2, + need_summary=True, + doc_form=IndexStructureType.QA_INDEX, + ) + + updated_count = DocumentService.update_documents_need_summary( + dataset.id, + [paragraph_doc.id, qa_doc.id], + need_summary=False, + ) + + db_session_with_containers.expire_all() + refreshed_paragraph = db_session_with_containers.get(Document, paragraph_doc.id) + refreshed_qa = db_session_with_containers.get(Document, qa_doc.id) + assert updated_count == 1 + assert refreshed_paragraph is not None + assert refreshed_qa is not None + assert refreshed_paragraph.need_summary is False + assert refreshed_qa.need_summary is True + + +def test_get_document_download_url_uses_signed_url_helper(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + upload_file = DocumentServiceIntegrationFactory.create_upload_file( + db_session_with_containers, + tenant_id=dataset.tenant_id, + created_by=dataset.created_by, + ) + document = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + data_source_info={"upload_file_id": upload_file.id}, + ) + + with patch("services.dataset_service.file_helpers.get_signed_file_url", return_value="signed-url") as get_url: + result = DocumentService.get_document_download_url(document) + + assert result == "signed-url" + get_url.assert_called_once_with(upload_file_id=upload_file.id, as_attachment=True) + + +def test_get_upload_file_id_for_upload_file_document_rejects_invalid_source_type(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + document = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + data_source_type=DataSourceType.WEBSITE_CRAWL, + data_source_info={"url": "https://example.com"}, + ) + + with pytest.raises(NotFound, match="invalid source"): + DocumentService._get_upload_file_id_for_upload_file_document( + document, + invalid_source_message="invalid source", + missing_file_message="missing file", + ) + + +def test_get_upload_file_id_for_upload_file_document_rejects_missing_upload_file_id(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + document = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + data_source_info={}, + ) + + with pytest.raises(NotFound, match="missing file"): + DocumentService._get_upload_file_id_for_upload_file_document( + document, + invalid_source_message="invalid source", + missing_file_message="missing file", + ) + + +def test_get_upload_file_id_for_upload_file_document_returns_string_id(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + document = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + data_source_info={"upload_file_id": 99}, + ) + + result = DocumentService._get_upload_file_id_for_upload_file_document( + document, + invalid_source_message="invalid source", + missing_file_message="missing file", + ) + + assert result == "99" + + +def test_get_upload_file_for_upload_file_document_raises_when_file_service_returns_nothing(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + document = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + data_source_info={"upload_file_id": "missing-file"}, + ) + + with patch("services.dataset_service.FileService.get_upload_files_by_ids", return_value={}): + with pytest.raises(NotFound, match="Uploaded file not found"): + DocumentService._get_upload_file_for_upload_file_document(document) + + +def test_get_upload_file_for_upload_file_document_returns_upload_file(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + upload_file = DocumentServiceIntegrationFactory.create_upload_file( + db_session_with_containers, + tenant_id=dataset.tenant_id, + created_by=dataset.created_by, + ) + document = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + data_source_info={"upload_file_id": upload_file.id}, + ) + + result = DocumentService._get_upload_file_for_upload_file_document(document) + + assert result.id == upload_file.id + + +def test_get_upload_files_by_document_id_for_zip_download_raises_for_missing_documents(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + + with pytest.raises(NotFound, match="Document not found"): + DocumentService._get_upload_files_by_document_id_for_zip_download( + dataset_id=dataset.id, + document_ids=[str(uuid4())], + tenant_id=dataset.tenant_id, + ) + + +def test_get_upload_files_by_document_id_for_zip_download_rejects_cross_tenant_access(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + upload_file = DocumentServiceIntegrationFactory.create_upload_file( + db_session_with_containers, + tenant_id=dataset.tenant_id, + created_by=dataset.created_by, + ) + document = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + tenant_id=str(uuid4()), + data_source_info={"upload_file_id": upload_file.id}, + ) + + with pytest.raises(Forbidden, match="No permission"): + DocumentService._get_upload_files_by_document_id_for_zip_download( + dataset_id=dataset.id, + document_ids=[document.id], + tenant_id=dataset.tenant_id, + ) + + +def test_get_upload_files_by_document_id_for_zip_download_rejects_missing_upload_files(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + document = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + data_source_info={"upload_file_id": str(uuid4())}, + ) + + with pytest.raises(NotFound, match="Only uploaded-file documents can be downloaded as ZIP"): + DocumentService._get_upload_files_by_document_id_for_zip_download( + dataset_id=dataset.id, + document_ids=[document.id], + tenant_id=dataset.tenant_id, + ) + + +def test_get_upload_files_by_document_id_for_zip_download_returns_document_keyed_mapping(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + upload_file_a = DocumentServiceIntegrationFactory.create_upload_file( + db_session_with_containers, + tenant_id=dataset.tenant_id, + created_by=dataset.created_by, + name="a.txt", + ) + upload_file_b = DocumentServiceIntegrationFactory.create_upload_file( + db_session_with_containers, + tenant_id=dataset.tenant_id, + created_by=dataset.created_by, + name="b.txt", + ) + document_a = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + data_source_info={"upload_file_id": upload_file_a.id}, + ) + document_b = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + position=2, + data_source_info={"upload_file_id": upload_file_b.id}, + ) + + mapping = DocumentService._get_upload_files_by_document_id_for_zip_download( + dataset_id=dataset.id, + document_ids=[document_a.id, document_b.id], + tenant_id=dataset.tenant_id, + ) + + assert mapping[document_a.id].id == upload_file_a.id + assert mapping[document_b.id].id == upload_file_b.id + + +def test_prepare_document_batch_download_zip_raises_not_found_for_missing_dataset( + current_user_mock, flask_app_with_containers +): + with flask_app_with_containers.app_context(): + with pytest.raises(NotFound, match="Dataset not found"): + DocumentService.prepare_document_batch_download_zip( + dataset_id=str(uuid4()), + document_ids=[str(uuid4())], + tenant_id=current_user_mock.current_tenant_id, + current_user=current_user_mock, + ) + + +def test_prepare_document_batch_download_zip_translates_permission_error_to_forbidden( + db_session_with_containers, + current_user_mock, +): + dataset = DocumentServiceIntegrationFactory.create_dataset( + db_session_with_containers, + tenant_id=current_user_mock.current_tenant_id, + created_by=current_user_mock.id, + ) + + with patch( + "services.dataset_service.DatasetService.check_dataset_permission", + side_effect=NoPermissionError("denied"), + ): + with pytest.raises(Forbidden, match="denied"): + DocumentService.prepare_document_batch_download_zip( + dataset_id=dataset.id, + document_ids=[], + tenant_id=current_user_mock.current_tenant_id, + current_user=current_user_mock, + ) + + +def test_prepare_document_batch_download_zip_returns_upload_files_in_requested_order( + db_session_with_containers, + current_user_mock, +): + dataset = DocumentServiceIntegrationFactory.create_dataset( + db_session_with_containers, + tenant_id=current_user_mock.current_tenant_id, + created_by=current_user_mock.id, + ) + upload_file_a = DocumentServiceIntegrationFactory.create_upload_file( + db_session_with_containers, + tenant_id=dataset.tenant_id, + created_by=dataset.created_by, + name="a.txt", + ) + upload_file_b = DocumentServiceIntegrationFactory.create_upload_file( + db_session_with_containers, + tenant_id=dataset.tenant_id, + created_by=dataset.created_by, + name="b.txt", + ) + document_a = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + data_source_info={"upload_file_id": upload_file_a.id}, + ) + document_b = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + position=2, + data_source_info={"upload_file_id": upload_file_b.id}, + ) + + upload_files, download_name = DocumentService.prepare_document_batch_download_zip( + dataset_id=dataset.id, + document_ids=[document_b.id, document_a.id], + tenant_id=current_user_mock.current_tenant_id, + current_user=current_user_mock, + ) + + assert [upload_file.id for upload_file in upload_files] == [upload_file_b.id, upload_file_a.id] + assert download_name.endswith(".zip") + + +def test_get_document_by_dataset_id_returns_enabled_documents(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + enabled_document = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + enabled=True, + ) + DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + position=2, + enabled=False, + ) + + result = DocumentService.get_document_by_dataset_id(dataset.id) + + assert [document.id for document in result] == [enabled_document.id] + + +def test_get_working_documents_by_dataset_id_returns_completed_enabled_unarchived_documents(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + available_document = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + indexing_status=IndexingStatus.COMPLETED, + enabled=True, + archived=False, + ) + DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + position=2, + indexing_status=IndexingStatus.ERROR, + ) + + result = DocumentService.get_working_documents_by_dataset_id(dataset.id) + + assert [document.id for document in result] == [available_document.id] + + +def test_get_error_documents_by_dataset_id_returns_error_and_paused_documents(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + error_document = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + indexing_status=IndexingStatus.ERROR, + ) + paused_document = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + position=2, + indexing_status=IndexingStatus.PAUSED, + ) + DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + position=3, + indexing_status=IndexingStatus.COMPLETED, + ) + + result = DocumentService.get_error_documents_by_dataset_id(dataset.id) + + assert {document.id for document in result} == {error_document.id, paused_document.id} + + +def test_get_batch_documents_filters_by_current_user_tenant(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + batch = f"batch-{uuid4()}" + matching_document = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + batch=batch, + ) + DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + position=2, + tenant_id=str(uuid4()), + batch=batch, + ) + + with patch("services.dataset_service.current_user", create_autospec(Account, instance=True)) as current_user: + current_user.current_tenant_id = dataset.tenant_id + result = DocumentService.get_batch_documents(dataset.id, batch) + + assert [document.id for document in result] == [matching_document.id] + + +def test_get_document_file_detail_returns_upload_file(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + upload_file = DocumentServiceIntegrationFactory.create_upload_file( + db_session_with_containers, + tenant_id=dataset.tenant_id, + created_by=dataset.created_by, + ) + + result = DocumentService.get_document_file_detail(upload_file.id) + + assert result is not None + assert result.id == upload_file.id + + +def test_delete_document_emits_signal_and_commits(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + upload_file = DocumentServiceIntegrationFactory.create_upload_file( + db_session_with_containers, + tenant_id=dataset.tenant_id, + created_by=dataset.created_by, + ) + document = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + data_source_info={"upload_file_id": upload_file.id}, + ) + + with patch("services.dataset_service.document_was_deleted.send") as signal_send: + DocumentService.delete_document(document) + + assert db_session_with_containers.get(Document, document.id) is None + signal_send.assert_called_once_with( + document.id, + dataset_id=document.dataset_id, + doc_form=document.doc_form, + file_id=upload_file.id, + ) + + +def test_delete_documents_ignores_empty_input(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + + with patch("services.dataset_service.batch_clean_document_task.delay") as delay: + DocumentService.delete_documents(dataset, []) + + delay.assert_not_called() + + +def test_delete_documents_deletes_rows_and_dispatches_cleanup_task(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + dataset.chunk_structure = IndexStructureType.PARAGRAPH_INDEX + db_session_with_containers.commit() + upload_file_a = DocumentServiceIntegrationFactory.create_upload_file( + db_session_with_containers, + tenant_id=dataset.tenant_id, + created_by=dataset.created_by, + name="a.txt", + ) + upload_file_b = DocumentServiceIntegrationFactory.create_upload_file( + db_session_with_containers, + tenant_id=dataset.tenant_id, + created_by=dataset.created_by, + name="b.txt", + ) + document_a = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + data_source_info={"upload_file_id": upload_file_a.id}, + ) + document_b = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + position=2, + data_source_info={"upload_file_id": upload_file_b.id}, + ) + + with patch("services.dataset_service.batch_clean_document_task.delay") as delay: + DocumentService.delete_documents(dataset, [document_a.id, document_b.id]) + + assert db_session_with_containers.get(Document, document_a.id) is None + assert db_session_with_containers.get(Document, document_b.id) is None + delay.assert_called_once() + args = delay.call_args.args + assert args[0] == [document_a.id, document_b.id] + assert args[1] == dataset.id + assert set(args[3]) == {upload_file_a.id, upload_file_b.id} + + +def test_get_documents_position_returns_next_position_when_documents_exist(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + DocumentServiceIntegrationFactory.create_document(db_session_with_containers, dataset=dataset, position=3) + + assert DocumentService.get_documents_position(dataset.id) == 4 + + +def test_get_documents_position_defaults_to_one_when_dataset_is_empty(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + + assert DocumentService.get_documents_position(dataset.id) == 1 diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_permissions.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_permissions.py new file mode 100644 index 00000000000..1b4179c9c70 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_permissions.py @@ -0,0 +1,613 @@ +"""Testcontainers integration tests for DatasetService permission and lifecycle SQL paths.""" + +from datetime import datetime +from types import SimpleNamespace +from unittest.mock import patch +from uuid import uuid4 + +import pytest +from sqlalchemy.orm import Session +from werkzeug.exceptions import NotFound + +from core.rag.index_processor.constant.index_type import IndexTechniqueType +from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models.dataset import ( + AppDatasetJoin, + Dataset, + DatasetAutoDisableLog, + DatasetCollectionBinding, + DatasetPermission, + DatasetPermissionEnum, +) +from models.enums import DataSourceType +from services.dataset_service import DatasetCollectionBindingService, DatasetPermissionService, DatasetService +from services.errors.account import NoPermissionError + + +class DatasetPermissionIntegrationFactory: + @staticmethod + def create_account_with_tenant( + db_session_with_containers: Session, + role: TenantAccountRole = TenantAccountRole.OWNER, + ) -> tuple[Account, Tenant]: + account = Account( + email=f"{uuid4()}@example.com", + name=f"user-{uuid4()}", + interface_language="en-US", + status="active", + ) + tenant = Tenant(name=f"tenant-{uuid4()}", status="normal") + db_session_with_containers.add_all([account, tenant]) + db_session_with_containers.flush() + + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=role, + current=True, + ) + db_session_with_containers.add(join) + db_session_with_containers.commit() + + account.role = role + account._current_tenant = tenant + return account, tenant + + @staticmethod + def create_account_in_tenant( + db_session_with_containers: Session, + tenant: Tenant, + role: TenantAccountRole = TenantAccountRole.EDITOR, + ) -> Account: + account = Account( + email=f"{uuid4()}@example.com", + name=f"user-{uuid4()}", + interface_language="en-US", + status="active", + ) + db_session_with_containers.add(account) + db_session_with_containers.flush() + + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=role, + current=True, + ) + db_session_with_containers.add(join) + db_session_with_containers.commit() + + account.role = role + account._current_tenant = tenant + return account + + @staticmethod + def create_dataset( + db_session_with_containers: Session, + *, + tenant_id: str, + created_by: str, + name: str | None = None, + permission: DatasetPermissionEnum = DatasetPermissionEnum.ONLY_ME, + indexing_technique: str | None = IndexTechniqueType.HIGH_QUALITY, + enable_api: bool = True, + ) -> Dataset: + dataset = Dataset( + tenant_id=tenant_id, + name=name or f"dataset-{uuid4()}", + description="desc", + data_source_type=DataSourceType.UPLOAD_FILE, + indexing_technique=indexing_technique, + created_by=created_by, + provider="vendor", + permission=permission, + retrieval_model={"top_k": 2}, + ) + dataset.enable_api = enable_api + db_session_with_containers.add(dataset) + db_session_with_containers.commit() + return dataset + + @staticmethod + def create_dataset_permission( + db_session_with_containers: Session, + *, + dataset_id: str, + tenant_id: str, + account_id: str, + ) -> DatasetPermission: + permission = DatasetPermission( + dataset_id=dataset_id, + tenant_id=tenant_id, + account_id=account_id, + has_permission=True, + ) + db_session_with_containers.add(permission) + db_session_with_containers.commit() + return permission + + @staticmethod + def create_app_dataset_join( + db_session_with_containers: Session, + *, + dataset_id: str, + ) -> AppDatasetJoin: + join = AppDatasetJoin( + app_id=str(uuid4()), + dataset_id=dataset_id, + ) + db_session_with_containers.add(join) + db_session_with_containers.commit() + return join + + @staticmethod + def create_collection_binding( + db_session_with_containers: Session, + *, + provider_name: str, + model_name: str, + collection_type: str = "dataset", + ) -> DatasetCollectionBinding: + binding = DatasetCollectionBinding( + provider_name=provider_name, + model_name=model_name, + collection_name=f"collection_{uuid4().hex}", + type=collection_type, + ) + db_session_with_containers.add(binding) + db_session_with_containers.commit() + return binding + + @staticmethod + def create_auto_disable_log( + db_session_with_containers: Session, + *, + tenant_id: str, + dataset_id: str, + document_id: str, + ) -> DatasetAutoDisableLog: + log = DatasetAutoDisableLog( + tenant_id=tenant_id, + dataset_id=dataset_id, + document_id=document_id, + ) + db_session_with_containers.add(log) + db_session_with_containers.commit() + return log + + +class TestDatasetServicePermissionsAndLifecycle: + def test_delete_dataset_returns_false_when_dataset_is_missing(self, db_session_with_containers: Session): + owner, _tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers) + + result = DatasetService.delete_dataset(str(uuid4()), user=owner) + + assert result is False + + def test_delete_dataset_checks_permission_and_deletes_dataset(self, db_session_with_containers: Session): + owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers) + dataset = DatasetPermissionIntegrationFactory.create_dataset( + db_session_with_containers, + tenant_id=tenant.id, + created_by=owner.id, + ) + + with patch("services.dataset_service.dataset_was_deleted.send") as send_deleted_signal: + result = DatasetService.delete_dataset(dataset.id, user=owner) + + assert result is True + assert db_session_with_containers.get(Dataset, dataset.id) is None + send_deleted_signal.assert_called_once_with(dataset) + + def test_dataset_use_check_returns_true_when_join_exists(self, db_session_with_containers: Session): + owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers) + dataset = DatasetPermissionIntegrationFactory.create_dataset( + db_session_with_containers, + tenant_id=tenant.id, + created_by=owner.id, + ) + DatasetPermissionIntegrationFactory.create_app_dataset_join( + db_session_with_containers, + dataset_id=dataset.id, + ) + + assert DatasetService.dataset_use_check(dataset.id) is True + + def test_dataset_use_check_returns_false_when_join_missing(self, db_session_with_containers: Session): + owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers) + dataset = DatasetPermissionIntegrationFactory.create_dataset( + db_session_with_containers, + tenant_id=tenant.id, + created_by=owner.id, + ) + + assert DatasetService.dataset_use_check(dataset.id) is False + + def test_check_dataset_permission_rejects_cross_tenant_access(self, db_session_with_containers: Session): + owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers) + outsider, _other_tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant( + db_session_with_containers + ) + dataset = DatasetPermissionIntegrationFactory.create_dataset( + db_session_with_containers, + tenant_id=tenant.id, + created_by=owner.id, + ) + + with pytest.raises(NoPermissionError, match="do not have permission"): + DatasetService.check_dataset_permission(dataset, outsider) + + def test_check_dataset_permission_rejects_only_me_dataset_for_non_creator( + self, db_session_with_containers: Session + ): + owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers) + member = DatasetPermissionIntegrationFactory.create_account_in_tenant(db_session_with_containers, tenant) + dataset = DatasetPermissionIntegrationFactory.create_dataset( + db_session_with_containers, + tenant_id=tenant.id, + created_by=owner.id, + permission=DatasetPermissionEnum.ONLY_ME, + ) + + with pytest.raises(NoPermissionError, match="do not have permission"): + DatasetService.check_dataset_permission(dataset, member) + + def test_check_dataset_permission_rejects_partial_team_user_without_binding( + self, db_session_with_containers: Session + ): + owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers) + member = DatasetPermissionIntegrationFactory.create_account_in_tenant(db_session_with_containers, tenant) + dataset = DatasetPermissionIntegrationFactory.create_dataset( + db_session_with_containers, + tenant_id=tenant.id, + created_by=owner.id, + permission=DatasetPermissionEnum.PARTIAL_TEAM, + ) + + with pytest.raises(NoPermissionError, match="do not have permission"): + DatasetService.check_dataset_permission(dataset, member) + + def test_check_dataset_permission_allows_partial_team_creator(self, db_session_with_containers: Session): + creator, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant( + db_session_with_containers, + role=TenantAccountRole.EDITOR, + ) + dataset = DatasetPermissionIntegrationFactory.create_dataset( + db_session_with_containers, + tenant_id=tenant.id, + created_by=creator.id, + permission=DatasetPermissionEnum.PARTIAL_TEAM, + ) + + DatasetService.check_dataset_permission(dataset, creator) + + def test_check_dataset_permission_allows_partial_team_member_with_binding( + self, db_session_with_containers: Session + ): + owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers) + member = DatasetPermissionIntegrationFactory.create_account_in_tenant(db_session_with_containers, tenant) + dataset = DatasetPermissionIntegrationFactory.create_dataset( + db_session_with_containers, + tenant_id=tenant.id, + created_by=owner.id, + permission=DatasetPermissionEnum.PARTIAL_TEAM, + ) + DatasetPermissionIntegrationFactory.create_dataset_permission( + db_session_with_containers, + dataset_id=dataset.id, + tenant_id=tenant.id, + account_id=member.id, + ) + + DatasetService.check_dataset_permission(dataset, member) + + def test_check_dataset_operator_permission_rejects_only_me_for_non_creator( + self, db_session_with_containers: Session + ): + owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers) + operator = DatasetPermissionIntegrationFactory.create_account_in_tenant( + db_session_with_containers, + tenant, + role=TenantAccountRole.EDITOR, + ) + dataset = DatasetPermissionIntegrationFactory.create_dataset( + db_session_with_containers, + tenant_id=tenant.id, + created_by=owner.id, + permission=DatasetPermissionEnum.ONLY_ME, + ) + + with pytest.raises(NoPermissionError, match="do not have permission"): + DatasetService.check_dataset_operator_permission(user=operator, dataset=dataset) + + def test_check_dataset_operator_permission_rejects_partial_team_without_binding( + self, db_session_with_containers: Session + ): + owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers) + operator = DatasetPermissionIntegrationFactory.create_account_in_tenant( + db_session_with_containers, + tenant, + role=TenantAccountRole.EDITOR, + ) + dataset = DatasetPermissionIntegrationFactory.create_dataset( + db_session_with_containers, + tenant_id=tenant.id, + created_by=owner.id, + permission=DatasetPermissionEnum.PARTIAL_TEAM, + ) + + with pytest.raises(NoPermissionError, match="do not have permission"): + DatasetService.check_dataset_operator_permission(user=operator, dataset=dataset) + + def test_check_dataset_operator_permission_allows_partial_team_with_binding( + self, db_session_with_containers: Session + ): + owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers) + operator = DatasetPermissionIntegrationFactory.create_account_in_tenant( + db_session_with_containers, + tenant, + role=TenantAccountRole.EDITOR, + ) + dataset = DatasetPermissionIntegrationFactory.create_dataset( + db_session_with_containers, + tenant_id=tenant.id, + created_by=owner.id, + permission=DatasetPermissionEnum.PARTIAL_TEAM, + ) + DatasetPermissionIntegrationFactory.create_dataset_permission( + db_session_with_containers, + dataset_id=dataset.id, + tenant_id=tenant.id, + account_id=operator.id, + ) + + DatasetService.check_dataset_operator_permission(user=operator, dataset=dataset) + + def test_update_dataset_api_status_raises_not_found_for_missing_dataset(self, flask_app_with_containers): + with flask_app_with_containers.app_context(): + with pytest.raises(NotFound, match="Dataset not found"): + DatasetService.update_dataset_api_status(str(uuid4()), True) + + def test_update_dataset_api_status_requires_current_user_id(self, db_session_with_containers: Session): + owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers) + dataset = DatasetPermissionIntegrationFactory.create_dataset( + db_session_with_containers, + tenant_id=tenant.id, + created_by=owner.id, + enable_api=False, + ) + + with patch("services.dataset_service.current_user", SimpleNamespace(id=None)): + with pytest.raises(ValueError, match="Current user or current user id not found"): + DatasetService.update_dataset_api_status(dataset.id, True) + + def test_update_dataset_api_status_updates_fields_and_commits(self, db_session_with_containers: Session): + owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers) + dataset = DatasetPermissionIntegrationFactory.create_dataset( + db_session_with_containers, + tenant_id=tenant.id, + created_by=owner.id, + enable_api=False, + ) + now = datetime(2026, 4, 14, 18, 0, 0) + + with ( + patch("services.dataset_service.current_user", owner), + patch("services.dataset_service.naive_utc_now", return_value=now), + ): + DatasetService.update_dataset_api_status(dataset.id, True) + + db_session_with_containers.refresh(dataset) + assert dataset.enable_api is True + assert dataset.updated_by == owner.id + assert dataset.updated_at == now + + def test_get_dataset_auto_disable_logs_returns_empty_when_billing_is_disabled( + self, db_session_with_containers: Session + ): + owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers) + features = SimpleNamespace( + billing=SimpleNamespace(enabled=False, subscription=SimpleNamespace(plan="professional")) + ) + + with ( + patch("services.dataset_service.current_user", owner), + patch("services.dataset_service.FeatureService.get_features", return_value=features), + ): + result = DatasetService.get_dataset_auto_disable_logs(str(uuid4())) + + assert result == {"document_ids": [], "count": 0} + + def test_get_dataset_auto_disable_logs_returns_recent_document_ids(self, db_session_with_containers: Session): + owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers) + dataset = DatasetPermissionIntegrationFactory.create_dataset( + db_session_with_containers, + tenant_id=tenant.id, + created_by=owner.id, + ) + DatasetPermissionIntegrationFactory.create_auto_disable_log( + db_session_with_containers, + tenant_id=tenant.id, + dataset_id=dataset.id, + document_id=str(uuid4()), + ) + DatasetPermissionIntegrationFactory.create_auto_disable_log( + db_session_with_containers, + tenant_id=tenant.id, + dataset_id=dataset.id, + document_id=str(uuid4()), + ) + features = SimpleNamespace( + billing=SimpleNamespace(enabled=True, subscription=SimpleNamespace(plan="professional")) + ) + + with ( + patch("services.dataset_service.current_user", owner), + patch("services.dataset_service.FeatureService.get_features", return_value=features), + ): + result = DatasetService.get_dataset_auto_disable_logs(dataset.id) + + assert result["count"] == 2 + assert len(result["document_ids"]) == 2 + + +class TestDatasetCollectionBindingServiceIntegration: + def test_get_dataset_collection_binding_returns_existing_binding(self, db_session_with_containers: Session): + binding = DatasetPermissionIntegrationFactory.create_collection_binding( + db_session_with_containers, + provider_name="provider", + model_name="model", + ) + + result = DatasetCollectionBindingService.get_dataset_collection_binding("provider", "model") + + assert result.id == binding.id + + def test_get_dataset_collection_binding_creates_binding_when_missing(self, db_session_with_containers: Session): + result = DatasetCollectionBindingService.get_dataset_collection_binding("provider", "missing-model") + + persisted = db_session_with_containers.get(DatasetCollectionBinding, result.id) + assert persisted is not None + assert persisted.provider_name == "provider" + assert persisted.model_name == "missing-model" + assert persisted.type == "dataset" + assert persisted.collection_name + + def test_get_dataset_collection_binding_by_id_and_type_raises_when_missing(self, flask_app_with_containers): + with flask_app_with_containers.app_context(): + with pytest.raises(ValueError, match="Dataset collection binding not found"): + DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(str(uuid4())) + + def test_get_dataset_collection_binding_by_id_and_type_returns_binding(self, db_session_with_containers: Session): + binding = DatasetPermissionIntegrationFactory.create_collection_binding( + db_session_with_containers, + provider_name="provider", + model_name="model", + ) + + result = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(binding.id) + + assert result.id == binding.id + + +class TestDatasetPermissionServiceIntegration: + def test_get_dataset_partial_member_list_returns_scalar_results(self, db_session_with_containers: Session): + owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers) + member_a = DatasetPermissionIntegrationFactory.create_account_in_tenant(db_session_with_containers, tenant) + member_b = DatasetPermissionIntegrationFactory.create_account_in_tenant(db_session_with_containers, tenant) + dataset = DatasetPermissionIntegrationFactory.create_dataset( + db_session_with_containers, + tenant_id=tenant.id, + created_by=owner.id, + permission=DatasetPermissionEnum.PARTIAL_TEAM, + ) + DatasetPermissionIntegrationFactory.create_dataset_permission( + db_session_with_containers, + dataset_id=dataset.id, + tenant_id=tenant.id, + account_id=member_a.id, + ) + DatasetPermissionIntegrationFactory.create_dataset_permission( + db_session_with_containers, + dataset_id=dataset.id, + tenant_id=tenant.id, + account_id=member_b.id, + ) + + result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id) + + assert set(result) == {member_a.id, member_b.id} + + def test_update_partial_member_list_replaces_permissions_and_commits(self, db_session_with_containers: Session): + owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers) + member_a = DatasetPermissionIntegrationFactory.create_account_in_tenant(db_session_with_containers, tenant) + member_b = DatasetPermissionIntegrationFactory.create_account_in_tenant(db_session_with_containers, tenant) + dataset = DatasetPermissionIntegrationFactory.create_dataset( + db_session_with_containers, + tenant_id=tenant.id, + created_by=owner.id, + permission=DatasetPermissionEnum.PARTIAL_TEAM, + ) + stale_member = DatasetPermissionIntegrationFactory.create_account_in_tenant(db_session_with_containers, tenant) + DatasetPermissionIntegrationFactory.create_dataset_permission( + db_session_with_containers, + dataset_id=dataset.id, + tenant_id=tenant.id, + account_id=stale_member.id, + ) + + DatasetPermissionService.update_partial_member_list( + tenant.id, + dataset.id, + [{"user_id": member_a.id}, {"user_id": member_b.id}], + ) + + permissions = db_session_with_containers.query(DatasetPermission).filter_by(dataset_id=dataset.id).all() + assert {permission.account_id for permission in permissions} == {member_a.id, member_b.id} + + def test_check_permission_requires_dataset_editor(self): + user = SimpleNamespace(is_dataset_editor=False, is_dataset_operator=False) + dataset = SimpleNamespace(id="dataset-1", permission=DatasetPermissionEnum.ALL_TEAM) + + with pytest.raises(NoPermissionError, match="does not have permission"): + DatasetPermissionService.check_permission(user, dataset, DatasetPermissionEnum.ALL_TEAM, []) + + def test_check_permission_prevents_dataset_operator_from_changing_permission_mode(self): + user = SimpleNamespace(is_dataset_editor=True, is_dataset_operator=True) + dataset = SimpleNamespace(id="dataset-1", permission=DatasetPermissionEnum.ALL_TEAM) + + with pytest.raises(NoPermissionError, match="cannot change the dataset permissions"): + DatasetPermissionService.check_permission(user, dataset, DatasetPermissionEnum.ONLY_ME, []) + + def test_check_permission_requires_partial_member_list_for_partial_members_mode(self): + user = SimpleNamespace(is_dataset_editor=True, is_dataset_operator=True) + dataset = SimpleNamespace(id="dataset-1", permission=DatasetPermissionEnum.PARTIAL_TEAM) + + with pytest.raises(ValueError, match="Partial member list is required"): + DatasetPermissionService.check_permission(user, dataset, DatasetPermissionEnum.PARTIAL_TEAM, []) + + def test_check_permission_rejects_dataset_operator_member_list_changes(self): + user = SimpleNamespace(is_dataset_editor=True, is_dataset_operator=True) + dataset = SimpleNamespace(id="dataset-1", permission=DatasetPermissionEnum.PARTIAL_TEAM) + + with patch.object(DatasetPermissionService, "get_dataset_partial_member_list", return_value=["user-1"]): + with pytest.raises(ValueError, match="cannot change the dataset permissions"): + DatasetPermissionService.check_permission( + user, + dataset, + DatasetPermissionEnum.PARTIAL_TEAM, + [{"user_id": "user-2"}], + ) + + def test_check_permission_allows_dataset_operator_when_member_list_is_unchanged(self): + user = SimpleNamespace(is_dataset_editor=True, is_dataset_operator=True) + dataset = SimpleNamespace(id="dataset-1", permission=DatasetPermissionEnum.PARTIAL_TEAM) + + with patch.object(DatasetPermissionService, "get_dataset_partial_member_list", return_value=["user-1"]): + DatasetPermissionService.check_permission( + user, + dataset, + DatasetPermissionEnum.PARTIAL_TEAM, + [{"user_id": "user-1"}], + ) + + def test_clear_partial_member_list_deletes_permissions_and_commits(self, db_session_with_containers: Session): + owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers) + member = DatasetPermissionIntegrationFactory.create_account_in_tenant(db_session_with_containers, tenant) + dataset = DatasetPermissionIntegrationFactory.create_dataset( + db_session_with_containers, + tenant_id=tenant.id, + created_by=owner.id, + permission=DatasetPermissionEnum.PARTIAL_TEAM, + ) + DatasetPermissionIntegrationFactory.create_dataset_permission( + db_session_with_containers, + dataset_id=dataset.id, + tenant_id=tenant.id, + account_id=member.id, + ) + + DatasetPermissionService.clear_partial_member_list(dataset.id) + + remaining = db_session_with_containers.query(DatasetPermission).filter_by(dataset_id=dataset.id).all() + assert remaining == [] diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py index a814466e14f..ac0483a45d7 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py @@ -1,13 +1,21 @@ +import json from unittest.mock import Mock, patch from uuid import uuid4 import pytest -from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy.orm import Session from core.rag.index_processor.constant.index_type import IndexTechniqueType -from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole -from models.dataset import Dataset, ExternalKnowledgeBindings +from graphon.model_runtime.entities.model_entities import ModelType +from models.account import ( + Account, + AccountStatus, + Tenant, + TenantAccountJoin, + TenantAccountRole, + TenantStatus, +) +from models.dataset import Dataset, ExternalKnowledgeApis, ExternalKnowledgeBindings from models.enums import DataSourceType from services.dataset_service import DatasetService from services.errors.account import NoPermissionError @@ -25,12 +33,12 @@ class DatasetUpdateTestDataFactory: email=f"{uuid4()}@example.com", name=f"user-{uuid4()}", interface_language="en-US", - status="active", + status=AccountStatus.ACTIVE, ) db_session_with_containers.add(account) db_session_with_containers.commit() - tenant = Tenant(name=f"tenant-{account.id}", status="normal") + tenant = Tenant(name=f"tenant-{account.id}", status=TenantStatus.NORMAL) db_session_with_containers.add(tenant) db_session_with_containers.commit() @@ -103,6 +111,34 @@ class DatasetUpdateTestDataFactory: db_session_with_containers.commit() return binding + @staticmethod + def create_external_knowledge_api( + db_session_with_containers: Session, + tenant_id: str, + created_by: str, + api_id: str | None = None, + name: str = "test-api", + ) -> ExternalKnowledgeApis: + """Create a real external knowledge API template for tenant-scoped update validation.""" + external_api = ExternalKnowledgeApis( + tenant_id=tenant_id, + created_by=created_by, + updated_by=created_by, + name=name, + description="test description", + settings=json.dumps( + { + "endpoint": "https://example.com", + "api_key": "test-api-key", + } + ), + ) + if api_id is not None: + external_api.id = api_id + db_session_with_containers.add(external_api) + db_session_with_containers.commit() + return external_api + class TestDatasetServiceUpdateDataset: """ @@ -138,6 +174,11 @@ class TestDatasetServiceUpdateDataset: ) binding_id = binding.id db_session_with_containers.expunge(binding) + external_api = DatasetUpdateTestDataFactory.create_external_knowledge_api( + db_session_with_containers, + tenant_id=tenant.id, + created_by=user.id, + ) update_data = { "name": "new_name", @@ -145,7 +186,7 @@ class TestDatasetServiceUpdateDataset: "external_retrieval_model": "new_model", "permission": "only_me", "external_knowledge_id": "new_knowledge_id", - "external_knowledge_api_id": str(uuid4()), + "external_knowledge_api_id": external_api.id, } result = DatasetService.update_dataset(dataset.id, update_data, user) @@ -218,11 +259,16 @@ class TestDatasetServiceUpdateDataset: created_by=user.id, provider="external", ) + external_api = DatasetUpdateTestDataFactory.create_external_knowledge_api( + db_session_with_containers, + tenant_id=tenant.id, + created_by=user.id, + ) update_data = { "name": "new_name", "external_knowledge_id": "knowledge_id", - "external_knowledge_api_id": str(uuid4()), + "external_knowledge_api_id": external_api.id, } with pytest.raises(ValueError) as context: diff --git a/api/tests/test_containers_integration_tests/services/test_delete_archived_workflow_run.py b/api/tests/test_containers_integration_tests/services/test_delete_archived_workflow_run.py index c8f04e92159..fe426ae5161 100644 --- a/api/tests/test_containers_integration_tests/services/test_delete_archived_workflow_run.py +++ b/api/tests/test_containers_integration_tests/services/test_delete_archived_workflow_run.py @@ -5,9 +5,9 @@ Testcontainers integration tests for archived workflow run deletion service. from datetime import UTC, datetime, timedelta from uuid import uuid4 -from graphon.enums import WorkflowExecutionStatus from sqlalchemy import select +from graphon.enums import WorkflowExecutionStatus from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom from models.workflow import WorkflowArchiveLog, WorkflowRun from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion diff --git a/api/tests/test_containers_integration_tests/services/test_feature_service.py b/api/tests/test_containers_integration_tests/services/test_feature_service.py index b3e7dd2a59f..315936d7212 100644 --- a/api/tests/test_containers_integration_tests/services/test_feature_service.py +++ b/api/tests/test_containers_integration_tests/services/test_feature_service.py @@ -274,6 +274,7 @@ class TestFeatureService: mock_config.ENABLE_EMAIL_CODE_LOGIN = True mock_config.ENABLE_EMAIL_PASSWORD_LOGIN = True mock_config.ENABLE_SOCIAL_OAUTH_LOGIN = False + mock_config.ENABLE_COLLABORATION_MODE = True mock_config.ALLOW_REGISTER = False mock_config.ALLOW_CREATE_WORKSPACE = False mock_config.MAIL_TYPE = "smtp" @@ -298,6 +299,7 @@ class TestFeatureService: # Verify authentication settings assert result.enable_email_code_login is True assert result.enable_email_password_login is False + assert result.enable_collaboration_mode is True assert result.is_allow_register is False assert result.is_allow_create_workspace is False @@ -401,6 +403,7 @@ class TestFeatureService: mock_config.ENABLE_EMAIL_CODE_LOGIN = True mock_config.ENABLE_EMAIL_PASSWORD_LOGIN = True mock_config.ENABLE_SOCIAL_OAUTH_LOGIN = False + mock_config.ENABLE_COLLABORATION_MODE = False mock_config.ALLOW_REGISTER = True mock_config.ALLOW_CREATE_WORKSPACE = True mock_config.MAIL_TYPE = "smtp" @@ -422,6 +425,7 @@ class TestFeatureService: assert result.enable_email_code_login is True assert result.enable_email_password_login is True assert result.enable_social_oauth_login is False + assert result.enable_collaboration_mode is False assert result.is_allow_register is True assert result.is_allow_create_workspace is True assert result.is_email_setup is True diff --git a/api/tests/unit_tests/services/test_hit_testing_service.py b/api/tests/test_containers_integration_tests/services/test_hit_testing_service.py similarity index 51% rename from api/tests/unit_tests/services/test_hit_testing_service.py rename to api/tests/test_containers_integration_tests/services/test_hit_testing_service.py index 80e9729f5bc..f332ba05ec2 100644 --- a/api/tests/unit_tests/services/test_hit_testing_service.py +++ b/api/tests/test_containers_integration_tests/services/test_hit_testing_service.py @@ -1,239 +1,193 @@ +from __future__ import annotations + import json from typing import Any, cast from unittest.mock import ANY, MagicMock, patch +from uuid import uuid4 import pytest +from sqlalchemy import func, select +from sqlalchemy.orm import Session from core.rag.models.document import Document -from models.dataset import Dataset +from models.dataset import Dataset, DatasetQuery from services.hit_testing_service import HitTestingService -class TestHitTestingService: - """Test suite for HitTestingService""" +def _create_dataset(db_session: Session, *, provider: str = "vendor", **kwargs: Any) -> Dataset: + tenant_id = str(uuid4()) + created_by = str(uuid4()) + ds = Dataset( + tenant_id=kwargs.get("tenant_id", tenant_id), + name=kwargs.get("name", "test-dataset"), + created_by=kwargs.get("created_by", created_by), + provider=provider, + ) + db_session.add(ds) + db_session.commit() + db_session.refresh(ds) + return ds - # ===== Utility Method Tests ===== + +class TestHitTestingService: + # ── Utility methods (pure logic, no DB) ──────────────────────────── def test_escape_query_for_search_should_escape_double_quotes(self): - """Test that escape_query_for_search escapes double quotes correctly""" - # Arrange query = 'test "query" with quotes' - expected = 'test \\"query\\" with quotes' - - # Act result = HitTestingService.escape_query_for_search(query) - - # Assert - assert result == expected + assert result == 'test \\"query\\" with quotes' def test_hit_testing_args_check_should_pass_with_valid_query(self): - """Test that hit_testing_args_check passes with a valid query""" - # Arrange - args = {"query": "valid query"} - - # Act & Assert (should not raise) - HitTestingService.hit_testing_args_check(args) + HitTestingService.hit_testing_args_check({"query": "valid query"}) def test_hit_testing_args_check_should_pass_with_valid_attachments(self): - """Test that hit_testing_args_check passes with valid attachment_ids""" - # Arrange - args = {"attachment_ids": ["id1", "id2"]} - - # Act & Assert (should not raise) - HitTestingService.hit_testing_args_check(args) + HitTestingService.hit_testing_args_check({"attachment_ids": ["id1", "id2"]}) def test_hit_testing_args_check_should_raise_error_when_no_query_or_attachments(self): - """Test that hit_testing_args_check raises ValueError if both query and attachment_ids are missing""" - # Arrange - args = {} - - # Act & Assert - with pytest.raises(ValueError) as exc_info: - HitTestingService.hit_testing_args_check(args) - assert "Query or attachment_ids is required" in str(exc_info.value) + with pytest.raises(ValueError, match="Query or attachment_ids is required"): + HitTestingService.hit_testing_args_check({}) def test_hit_testing_args_check_should_raise_error_when_query_too_long(self): - """Test that hit_testing_args_check raises ValueError if query exceeds 250 characters""" - # Arrange - args = {"query": "a" * 251} - - # Act & Assert - with pytest.raises(ValueError) as exc_info: - HitTestingService.hit_testing_args_check(args) - assert "Query cannot exceed 250 characters" in str(exc_info.value) + with pytest.raises(ValueError, match="Query cannot exceed 250 characters"): + HitTestingService.hit_testing_args_check({"query": "a" * 251}) def test_hit_testing_args_check_should_raise_error_when_attachments_not_list(self): - """Test that hit_testing_args_check raises ValueError if attachment_ids is not a list""" - # Arrange - args = {"attachment_ids": "not a list"} + with pytest.raises(ValueError, match="Attachment_ids must be a list"): + HitTestingService.hit_testing_args_check({"attachment_ids": "not a list"}) - # Act & Assert - with pytest.raises(ValueError) as exc_info: - HitTestingService.hit_testing_args_check(args) - assert "Attachment_ids must be a list" in str(exc_info.value) - - # ===== Response Formatting Tests ===== + # ── Response formatting ──────────────────────────────────────────── @patch("core.rag.datasource.retrieval_service.RetrievalService.format_retrieval_documents") def test_compact_retrieve_response_should_format_correctly(self, mock_format): - """Test that compact_retrieve_response formats the response correctly""" - # Arrange query = "test query" mock_doc = MagicMock(spec=Document) - documents = [mock_doc] mock_record = MagicMock() mock_record.model_dump.return_value = {"content": "formatted content"} mock_format.return_value = [mock_record] - # Act - result = cast(dict[str, Any], HitTestingService.compact_retrieve_response(query, documents)) + result = cast(dict[str, Any], HitTestingService.compact_retrieve_response(query, [mock_doc])) - # Assert assert cast(dict[str, Any], result["query"])["content"] == query assert len(result["records"]) == 1 assert cast(dict[str, Any], result["records"][0])["content"] == "formatted content" - mock_format.assert_called_once_with(documents) + mock_format.assert_called_once_with([mock_doc]) - def test_compact_external_retrieve_response_should_return_records_for_external_provider(self): - """Test that compact_external_retrieve_response returns records when dataset provider is external""" - # Arrange - dataset = MagicMock(spec=Dataset) - dataset.provider = "external" - query = "test query" + def test_compact_external_retrieve_response_should_return_records_for_external_provider( + self, db_session_with_containers: Session + ): + dataset = _create_dataset(db_session_with_containers, provider="external") documents = [ {"content": "c1", "title": "t1", "score": 0.9, "metadata": {"m1": "v1"}}, {"content": "c2", "title": "t2", "score": 0.8, "metadata": {"m2": "v2"}}, ] - # Act - result = cast(dict[str, Any], HitTestingService.compact_external_retrieve_response(dataset, query, documents)) + result = cast( + dict[str, Any], HitTestingService.compact_external_retrieve_response(dataset, "test query", documents) + ) - # Assert - assert cast(dict[str, Any], result["query"])["content"] == query + assert cast(dict[str, Any], result["query"])["content"] == "test query" assert len(result["records"]) == 2 assert cast(dict[str, Any], result["records"][0])["content"] == "c1" assert cast(dict[str, Any], result["records"][1])["title"] == "t2" - def test_compact_external_retrieve_response_should_return_empty_for_non_external_provider(self): - """Test that compact_external_retrieve_response returns empty records for non-external provider""" - # Arrange - dataset = MagicMock(spec=Dataset) - dataset.provider = "not_external" - query = "test query" - documents = [{"content": "c1"}] + def test_compact_external_retrieve_response_should_return_empty_for_non_external_provider( + self, db_session_with_containers: Session + ): + dataset = _create_dataset(db_session_with_containers, provider="vendor") - # Act - result = cast(dict[str, Any], HitTestingService.compact_external_retrieve_response(dataset, query, documents)) + result = cast( + dict[str, Any], + HitTestingService.compact_external_retrieve_response(dataset, "test query", [{"content": "c1"}]), + ) - # Assert - assert cast(dict[str, Any], result["query"])["content"] == query + assert cast(dict[str, Any], result["query"])["content"] == "test query" assert result["records"] == [] - # ===== External Retrieve Tests ===== + # ── External retrieve (real DB) ──────────────────────────────────── @patch("core.rag.datasource.retrieval_service.RetrievalService.external_retrieve") - @patch("extensions.ext_database.db.session.add") - @patch("extensions.ext_database.db.session.commit") - def test_external_retrieve_should_succeed_for_external_provider(self, mock_commit, mock_add, mock_ext_retrieve): - """Test that external_retrieve successfully retrieves from external provider and commits query""" - # Arrange - dataset = MagicMock(spec=Dataset) - dataset.id = "dataset_id" - dataset.provider = "external" - query = 'test "query"' + def test_external_retrieve_should_succeed_for_external_provider( + self, mock_ext_retrieve, db_session_with_containers: Session + ): + dataset = _create_dataset(db_session_with_containers, provider="external") + account_id = str(uuid4()) account = MagicMock() - account.id = "account_id" - + account.id = account_id mock_ext_retrieve.return_value = [{"content": "ext content", "score": 1.0}] - # Act + before_count = db_session_with_containers.scalar(select(func.count()).select_from(DatasetQuery)) or 0 + result = cast( dict[str, Any], HitTestingService.external_retrieve( dataset=dataset, - query=query, + query='test "query"', account=account, external_retrieval_model={"model": "test"}, metadata_filtering_conditions={"key": "val"}, ), ) - # Assert - assert cast(dict[str, Any], result["query"])["content"] == query + assert cast(dict[str, Any], result["query"])["content"] == 'test "query"' assert cast(dict[str, Any], result["records"][0])["content"] == "ext content" - - # Verify call to RetrievalService.external_retrieve with escaped query mock_ext_retrieve.assert_called_once_with( - dataset_id="dataset_id", + dataset_id=dataset.id, query='test \\"query\\"', external_retrieval_model={"model": "test"}, metadata_filtering_conditions={"key": "val"}, ) - # Verify DatasetQuery record was added and committed - mock_add.assert_called_once() - mock_commit.assert_called_once() + db_session_with_containers.expire_all() + after_count = db_session_with_containers.scalar(select(func.count()).select_from(DatasetQuery)) or 0 + assert after_count == before_count + 1 - def test_external_retrieve_should_return_empty_for_non_external_provider(self): - """Test that external_retrieve returns empty results immediately if provider is not external""" - # Arrange - dataset = MagicMock(spec=Dataset) - dataset.provider = "not_external" - query = "test query" + def test_external_retrieve_should_return_empty_for_non_external_provider(self, db_session_with_containers: Session): + dataset = _create_dataset(db_session_with_containers, provider="vendor") account = MagicMock() - # Act - result = cast(dict[str, Any], HitTestingService.external_retrieve(dataset, query, account)) + result = cast(dict[str, Any], HitTestingService.external_retrieve(dataset, "test query", account)) - # Assert - assert cast(dict[str, Any], result["query"])["content"] == query + assert cast(dict[str, Any], result["query"])["content"] == "test query" assert result["records"] == [] - # ===== Retrieve Tests ===== + # ── Retrieve (real DB) ───────────────────────────────────────────── @patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve") - @patch("extensions.ext_database.db.session.add") - @patch("extensions.ext_database.db.session.commit") - def test_retrieve_should_use_default_model_when_none_provided(self, mock_commit, mock_add, mock_retrieve): - """Test that retrieve uses default model when retrieval_model is not provided""" - # Arrange - dataset = MagicMock(spec=Dataset) - dataset.id = "dataset_id" + def test_retrieve_should_use_default_model_when_none_provided( + self, mock_retrieve, db_session_with_containers: Session + ): + dataset = _create_dataset(db_session_with_containers) dataset.retrieval_model = None - query = "test query" account = MagicMock() - account.id = "account_id" - + account.id = str(uuid4()) mock_retrieve.return_value = [] - # Act + before_count = db_session_with_containers.scalar(select(func.count()).select_from(DatasetQuery)) or 0 + result = cast( dict[str, Any], HitTestingService.retrieve( - dataset=dataset, query=query, account=account, retrieval_model=None, external_retrieval_model={} + dataset=dataset, query="test query", account=account, retrieval_model=None, external_retrieval_model={} ), ) - # Assert - assert cast(dict[str, Any], result["query"])["content"] == query + assert cast(dict[str, Any], result["query"])["content"] == "test query" mock_retrieve.assert_called_once() - # Verify top_k from default_retrieval_model (4) assert mock_retrieve.call_args.kwargs["top_k"] == 4 - mock_commit.assert_called_once() + + db_session_with_containers.expire_all() + after_count = db_session_with_containers.scalar(select(func.count()).select_from(DatasetQuery)) or 0 + assert after_count == before_count + 1 @patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve") @patch("core.rag.retrieval.dataset_retrieval.DatasetRetrieval.get_metadata_filter_condition") - @patch("extensions.ext_database.db.session.add") - @patch("extensions.ext_database.db.session.commit") - def test_retrieve_should_handle_metadata_filtering(self, mock_commit, mock_add, mock_get_meta, mock_retrieve): - """Test that retrieve correctly calls metadata filtering when conditions are present""" - # Arrange - dataset = MagicMock(spec=Dataset) - dataset.id = "dataset_id" - query = "test query" + def test_retrieve_should_handle_metadata_filtering( + self, mock_get_meta, mock_retrieve, db_session_with_containers: Session + ): + dataset = _create_dataset(db_session_with_containers) account = MagicMock() - account.id = "account_id" + account.id = str(uuid4()) retrieval_model = { "search_method": "semantic_search", @@ -242,29 +196,27 @@ class TestHitTestingService: "reranking_enable": False, "score_threshold_enabled": False, } - - # Mock metadata filtering response - mock_get_meta.return_value = ({"dataset_id": ["doc_id1"]}, "condition_string") + mock_get_meta.return_value = ({dataset.id: ["doc_id1"]}, "condition_string") mock_retrieve.return_value = [] - # Act HitTestingService.retrieve( - dataset=dataset, query=query, account=account, retrieval_model=retrieval_model, external_retrieval_model={} + dataset=dataset, + query="test query", + account=account, + retrieval_model=retrieval_model, + external_retrieval_model={}, ) - # Assert mock_get_meta.assert_called_once() mock_retrieve.assert_called_once() assert mock_retrieve.call_args.kwargs["document_ids_filter"] == ["doc_id1"] @patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve") @patch("core.rag.retrieval.dataset_retrieval.DatasetRetrieval.get_metadata_filter_condition") - def test_retrieve_should_return_empty_if_metadata_filtering_fails(self, mock_get_meta, mock_retrieve): - """Test that retrieve returns empty response if metadata filtering returns condition but no document IDs""" - # Arrange - dataset = MagicMock(spec=Dataset) - dataset.id = "dataset_id" - query = "test query" + def test_retrieve_should_return_empty_if_metadata_filtering_fails( + self, mock_get_meta, mock_retrieve, db_session_with_containers: Session + ): + dataset = _create_dataset(db_session_with_containers) account = MagicMock() retrieval_model = { @@ -274,37 +226,27 @@ class TestHitTestingService: "reranking_enable": False, "score_threshold_enabled": False, } - - # Mock metadata filtering response: condition returned but no IDs mock_get_meta.return_value = ({}, "condition_string") - # Act result = cast( dict[str, Any], HitTestingService.retrieve( dataset=dataset, - query=query, + query="test query", account=account, retrieval_model=retrieval_model, external_retrieval_model={}, ), ) - # Assert assert result["records"] == [] mock_retrieve.assert_not_called() @patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve") - @patch("extensions.ext_database.db.session.add") - @patch("extensions.ext_database.db.session.commit") - def test_retrieve_should_handle_attachments(self, mock_commit, mock_add, mock_retrieve): - """Test that retrieve handles attachment_ids and adds them to DatasetQuery""" - # Arrange - dataset = MagicMock(spec=Dataset) - dataset.id = "dataset_id" - query = "test query" + def test_retrieve_should_handle_attachments(self, mock_retrieve, db_session_with_containers: Session): + dataset = _create_dataset(db_session_with_containers) account = MagicMock() - account.id = "account_id" + account.id = str(uuid4()) attachment_ids = ["att1", "att2"] retrieval_model = { @@ -315,21 +257,19 @@ class TestHitTestingService: } mock_retrieve.return_value = [] - # Act HitTestingService.retrieve( dataset=dataset, - query=query, + query="test query", account=account, retrieval_model=retrieval_model, external_retrieval_model={}, attachment_ids=attachment_ids, ) - # Assert mock_retrieve.assert_called_once_with( retrieval_method=ANY, - dataset_id="dataset_id", - query=query, + dataset_id=dataset.id, + query="test query", attachment_ids=attachment_ids, top_k=4, score_threshold=0.0, @@ -338,26 +278,27 @@ class TestHitTestingService: weights=None, document_ids_filter=None, ) - # Verify DatasetQuery record (there should be 2 queries: 1 text, 2 images) - # The content is json.dumps([{"content_type": "text_query", ...}, {"content_type": "image_query", ...}]) - called_query = mock_add.call_args[0][0] - query_content = json.loads(called_query.content) + + # Verify DatasetQuery was persisted with correct content structure + db_session_with_containers.expire_all() + latest = db_session_with_containers.scalar( + select(DatasetQuery) + .where(DatasetQuery.dataset_id == dataset.id) + .order_by(DatasetQuery.created_at.desc()) + .limit(1) + ) + assert latest is not None + query_content = json.loads(latest.content) assert len(query_content) == 3 # 1 text + 2 images assert query_content[0]["content_type"] == "text_query" assert query_content[1]["content_type"] == "image_query" assert query_content[1]["content"] == "att1" @patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve") - @patch("extensions.ext_database.db.session.add") - @patch("extensions.ext_database.db.session.commit") - def test_retrieve_should_handle_reranking_and_threshold(self, mock_commit, mock_add, mock_retrieve): - """Test that retrieve passes reranking and threshold parameters correctly""" - # Arrange - dataset = MagicMock(spec=Dataset) - dataset.id = "dataset_id" - query = "test query" + def test_retrieve_should_handle_reranking_and_threshold(self, mock_retrieve, db_session_with_containers: Session): + dataset = _create_dataset(db_session_with_containers) account = MagicMock() - account.id = "account_id" + account.id = str(uuid4()) retrieval_model = { "search_method": "hybrid_search", @@ -371,12 +312,14 @@ class TestHitTestingService: } mock_retrieve.return_value = [] - # Act HitTestingService.retrieve( - dataset=dataset, query=query, account=account, retrieval_model=retrieval_model, external_retrieval_model={} + dataset=dataset, + query="test query", + account=account, + retrieval_model=retrieval_model, + external_retrieval_model={}, ) - # Assert mock_retrieve.assert_called_once() kwargs = mock_retrieve.call_args.kwargs assert kwargs["score_threshold"] == 0.5 diff --git a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py index c46b8fba0bd..80f9083e815 100644 --- a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py +++ b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py @@ -3,15 +3,15 @@ import uuid from unittest.mock import MagicMock import pytest -from graphon.enums import BuiltinNodeTypes -from graphon.nodes.human_input.entities import HumanInputNodeData -from core.workflow.human_input_compat import ( +from core.workflow.human_input_adapter import ( EmailDeliveryConfig, EmailDeliveryMethod, EmailRecipients, ExternalRecipient, ) +from graphon.enums import BuiltinNodeTypes +from graphon.nodes.human_input.entities import HumanInputNodeData from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.model import App, AppMode from models.workflow import Workflow, WorkflowType diff --git a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test_service.py b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test_service.py index 0f252515f72..ed75363f3ba 100644 --- a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test_service.py +++ b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test_service.py @@ -5,17 +5,17 @@ from unittest.mock import MagicMock, patch from uuid import uuid4 import pytest -from graphon.runtime import VariablePool from sqlalchemy.engine import Engine from configs import dify_config -from core.workflow.human_input_compat import ( +from core.workflow.human_input_adapter import ( EmailDeliveryConfig, EmailDeliveryMethod, EmailRecipients, ExternalRecipient, MemberRecipient, ) +from graphon.runtime import VariablePool from models.account import Account, TenantAccountJoin from services import human_input_delivery_test_service as service_module from services.human_input_delivery_test_service import ( diff --git a/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py b/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py index 2340dd2a03d..cd63d3ad6c2 100644 --- a/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py +++ b/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py @@ -8,11 +8,11 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker -from graphon.file import FileType from sqlalchemy.orm import Session from enums.cloud_plan import CloudPlan from extensions.ext_redis import redis_client +from graphon.file import FileType from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.enums import ( ConversationFromSource, diff --git a/api/tests/test_containers_integration_tests/services/test_model_provider_service.py b/api/tests/test_containers_integration_tests/services/test_model_provider_service.py index ba926bf6758..8955a3b5f23 100644 --- a/api/tests/test_containers_integration_tests/services/test_model_provider_service.py +++ b/api/tests/test_containers_integration_tests/services/test_model_provider_service.py @@ -2,10 +2,10 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker -from graphon.model_runtime.entities.model_entities import FetchFrom, ModelType from sqlalchemy.orm import Session from core.entities.model_entities import ModelStatus +from graphon.model_runtime.entities.model_entities import FetchFrom, ModelType from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.provider import Provider, ProviderModel, ProviderModelSetting, ProviderType from services.model_provider_service import ModelProviderService @@ -405,11 +405,10 @@ class TestModelProviderService: mock_provider_manager = mock_external_service_dependencies["provider_manager"].return_value # Create mock models + from core.entities.model_entities import ModelWithProviderEntity, SimpleModelProviderEntity from graphon.model_runtime.entities.common_entities import I18nObject from graphon.model_runtime.entities.provider_entities import ProviderEntity - from core.entities.model_entities import ModelWithProviderEntity, SimpleModelProviderEntity - # Create real model objects instead of mocks provider_entity_1 = SimpleModelProviderEntity( ProviderEntity( @@ -644,9 +643,8 @@ class TestModelProviderService: mock_provider_manager = mock_external_service_dependencies["provider_manager"].return_value # Create mock default model response - from graphon.model_runtime.entities.common_entities import I18nObject - from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity + from graphon.model_runtime.entities.common_entities import I18nObject mock_default_model = DefaultModelEntity( model="gpt-3.5-turbo", diff --git a/api/tests/test_containers_integration_tests/services/test_ops_service.py b/api/tests/test_containers_integration_tests/services/test_ops_service.py new file mode 100644 index 00000000000..e2e1a228b28 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_ops_service.py @@ -0,0 +1,363 @@ +from __future__ import annotations + +import uuid +from unittest.mock import patch + +import pytest +from faker import Faker +from sqlalchemy import select +from sqlalchemy.orm import Session + +from core.ops.entities.config_entity import TracingProviderEnum +from models.model import TraceAppConfig +from services.account_service import AccountService, TenantService +from services.app_service import AppService +from services.ops_service import OpsService +from tests.test_containers_integration_tests.helpers import generate_valid_password + + +class TestOpsService: + @pytest.fixture + def mock_external_service_dependencies(self): + with ( + patch("services.app_service.FeatureService") as mock_feature_service, + patch("services.app_service.EnterpriseService") as mock_enterprise_service, + patch("services.app_service.ModelManager.for_tenant") as mock_model_manager, + patch("services.account_service.FeatureService") as mock_account_feature_service, + ): + mock_feature_service.get_system_features.return_value.webapp_auth.enabled = False + mock_enterprise_service.WebAppAuth.update_app_access_mode.return_value = None + mock_enterprise_service.WebAppAuth.cleanup_webapp.return_value = None + mock_account_feature_service.get_system_features.return_value.is_allow_register = True + mock_model_instance = mock_model_manager.return_value + mock_model_instance.get_default_model_instance.return_value = None + mock_model_instance.get_default_provider_model_name.return_value = ("openai", "gpt-3.5-turbo") + yield { + "feature_service": mock_feature_service, + "enterprise_service": mock_enterprise_service, + "model_manager": mock_model_manager, + "account_feature_service": mock_account_feature_service, + } + + @pytest.fixture + def mock_ops_trace_manager(self): + with patch("services.ops_service.OpsTraceManager") as mock: + yield mock + + def _create_app(self, db_session_with_containers: Session, mock_external_service_dependencies): + fake = Faker() + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=generate_valid_password(fake), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + app_service = AppService() + app = app_service.create_app( + tenant.id, + { + "name": fake.company(), + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🤖", + "icon_background": "#FF6B6B", + }, + account, + ) + return app, account + + _SENTINEL = object() + + def _insert_trace_config( + self, + db_session: Session, + app_id: str, + provider: str, + tracing_config: dict | None | object = _SENTINEL, + ) -> TraceAppConfig: + trace_config = TraceAppConfig( + app_id=app_id, + tracing_provider=provider, + tracing_config=tracing_config if tracing_config is not self._SENTINEL else {"some": "config"}, + ) + db_session.add(trace_config) + db_session.commit() + return trace_config + + # ── get_tracing_app_config ───────────────────────────────────────── + + def test_get_tracing_app_config_no_config(self, db_session_with_containers: Session, mock_ops_trace_manager): + result = OpsService.get_tracing_app_config(str(uuid.uuid4()), "arize") + assert result is None + + def test_get_tracing_app_config_no_app(self, db_session_with_containers: Session, mock_ops_trace_manager): + fake_app_id = str(uuid.uuid4()) + self._insert_trace_config(db_session_with_containers, fake_app_id, "arize") + result = OpsService.get_tracing_app_config(fake_app_id, "arize") + assert result is None + + def test_get_tracing_app_config_none_config( + self, db_session_with_containers: Session, mock_external_service_dependencies, mock_ops_trace_manager + ): + app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies) + self._insert_trace_config(db_session_with_containers, app.id, "arize", tracing_config=None) + + with pytest.raises(ValueError, match="Tracing config cannot be None."): + OpsService.get_tracing_app_config(app.id, "arize") + + @pytest.mark.parametrize( + ("provider", "default_url"), + [ + ("arize", "https://app.arize.com/"), + ("phoenix", "https://app.phoenix.arize.com/projects/"), + ("langsmith", "https://smith.langchain.com/"), + ("opik", "https://www.comet.com/opik/"), + ("weave", "https://wandb.ai/"), + ("aliyun", "https://arms.console.aliyun.com/"), + ("tencent", "https://console.cloud.tencent.com/apm"), + ("mlflow", "http://localhost:5000/"), + ("databricks", "https://www.databricks.com/"), + ], + ) + def test_get_tracing_app_config_providers_exception( + self, db_session_with_containers: Session, mock_external_service_dependencies, provider, default_url + ): + with patch("services.ops_service.OpsTraceManager") as mock_otm: + mock_otm.decrypt_tracing_config.return_value = {} + mock_otm.obfuscated_decrypt_token.return_value = {} + mock_otm.get_trace_config_project_url.side_effect = Exception("error") + mock_otm.get_trace_config_project_key.side_effect = Exception("error") + + app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies) + self._insert_trace_config(db_session_with_containers, app.id, provider) + + result = OpsService.get_tracing_app_config(app.id, provider) + + assert result is not None + assert result["tracing_config"]["project_url"] == default_url + + @pytest.mark.parametrize( + "provider", + ["arize", "phoenix", "langsmith", "opik", "weave", "aliyun", "tencent", "mlflow", "databricks"], + ) + def test_get_tracing_app_config_providers_success( + self, db_session_with_containers: Session, mock_external_service_dependencies, provider + ): + with patch("services.ops_service.OpsTraceManager") as mock_otm: + mock_otm.decrypt_tracing_config.return_value = {} + mock_otm.obfuscated_decrypt_token.return_value = {"project_url": "success_url"} + mock_otm.get_trace_config_project_url.return_value = "success_url" + + app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies) + self._insert_trace_config(db_session_with_containers, app.id, provider) + + result = OpsService.get_tracing_app_config(app.id, provider) + + assert result is not None + assert result["tracing_config"]["project_url"] == "success_url" + + def test_get_tracing_app_config_langfuse_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + with patch("services.ops_service.OpsTraceManager") as mock_otm: + mock_otm.decrypt_tracing_config.return_value = {"host": "https://api.langfuse.com"} + mock_otm.obfuscated_decrypt_token.return_value = {"host": "https://api.langfuse.com"} + mock_otm.get_trace_config_project_key.return_value = "key" + + app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies) + self._insert_trace_config(db_session_with_containers, app.id, "langfuse") + + result = OpsService.get_tracing_app_config(app.id, "langfuse") + + assert result is not None + assert result["tracing_config"]["project_url"] == "https://api.langfuse.com/project/key" + + def test_get_tracing_app_config_langfuse_exception( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + with patch("services.ops_service.OpsTraceManager") as mock_otm: + mock_otm.decrypt_tracing_config.return_value = {"host": "https://api.langfuse.com"} + mock_otm.obfuscated_decrypt_token.return_value = {"host": "https://api.langfuse.com"} + mock_otm.get_trace_config_project_key.side_effect = Exception("error") + + app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies) + self._insert_trace_config(db_session_with_containers, app.id, "langfuse") + + result = OpsService.get_tracing_app_config(app.id, "langfuse") + + assert result is not None + assert result["tracing_config"]["project_url"] == "https://api.langfuse.com/" + + # ── create_tracing_app_config ────────────────────────────────────── + + def test_create_tracing_app_config_invalid_provider(self, db_session_with_containers: Session): + result = OpsService.create_tracing_app_config(str(uuid.uuid4()), "invalid_provider", {}) + assert result == {"error": "Invalid tracing provider: invalid_provider"} + + def test_create_tracing_app_config_invalid_credentials( + self, db_session_with_containers: Session, mock_ops_trace_manager + ): + mock_ops_trace_manager.check_trace_config_is_effective.return_value = False + result = OpsService.create_tracing_app_config( + str(uuid.uuid4()), TracingProviderEnum.LANGFUSE, {"public_key": "p", "secret_key": "s"} + ) + assert result == {"error": "Invalid Credentials"} + + @pytest.mark.parametrize( + ("provider", "config"), + [ + (TracingProviderEnum.ARIZE, {}), + (TracingProviderEnum.LANGFUSE, {"public_key": "p", "secret_key": "s"}), + (TracingProviderEnum.LANGSMITH, {"api_key": "k", "project": "p"}), + (TracingProviderEnum.ALIYUN, {"license_key": "k", "endpoint": "https://aliyun.com"}), + ], + ) + def test_create_tracing_app_config_project_url_exception( + self, db_session_with_containers: Session, mock_external_service_dependencies, provider, config + ): + # Existing config causes the service to return None before reaching the DB insert + with patch("services.ops_service.OpsTraceManager") as mock_otm: + mock_otm.check_trace_config_is_effective.return_value = True + mock_otm.get_trace_config_project_url.side_effect = Exception("error") + mock_otm.get_trace_config_project_key.side_effect = Exception("error") + + app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies) + self._insert_trace_config(db_session_with_containers, app.id, str(provider)) + + result = OpsService.create_tracing_app_config(app.id, provider, config) + + assert result is None + + def test_create_tracing_app_config_langfuse_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + with patch("services.ops_service.OpsTraceManager") as mock_otm: + mock_otm.check_trace_config_is_effective.return_value = True + mock_otm.get_trace_config_project_key.return_value = "key" + mock_otm.encrypt_tracing_config.return_value = {} + + app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies) + result = OpsService.create_tracing_app_config( + app.id, + TracingProviderEnum.LANGFUSE, + {"public_key": "p", "secret_key": "s", "host": "https://api.langfuse.com"}, + ) + + assert result == {"result": "success"} + + def test_create_tracing_app_config_already_exists( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + with patch("services.ops_service.OpsTraceManager") as mock_otm: + mock_otm.check_trace_config_is_effective.return_value = True + + app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies) + self._insert_trace_config(db_session_with_containers, app.id, str(TracingProviderEnum.ARIZE)) + + result = OpsService.create_tracing_app_config(app.id, TracingProviderEnum.ARIZE, {}) + + assert result is None + + def test_create_tracing_app_config_no_app(self, db_session_with_containers: Session, mock_ops_trace_manager): + mock_ops_trace_manager.check_trace_config_is_effective.return_value = True + result = OpsService.create_tracing_app_config(str(uuid.uuid4()), TracingProviderEnum.ARIZE, {}) + assert result is None + + def test_create_tracing_app_config_with_empty_other_keys( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + # "project" is in other_keys for Arize; providing "" triggers default substitution + with patch("services.ops_service.OpsTraceManager") as mock_otm: + mock_otm.check_trace_config_is_effective.return_value = True + mock_otm.get_trace_config_project_url.side_effect = Exception("no url") + mock_otm.encrypt_tracing_config.return_value = {} + + app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies) + result = OpsService.create_tracing_app_config(app.id, TracingProviderEnum.ARIZE, {"project": ""}) + + assert result == {"result": "success"} + + def test_create_tracing_app_config_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + with patch("services.ops_service.OpsTraceManager") as mock_otm: + mock_otm.check_trace_config_is_effective.return_value = True + mock_otm.get_trace_config_project_url.return_value = "http://project_url" + mock_otm.encrypt_tracing_config.return_value = {"encrypted": "config"} + + app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies) + result = OpsService.create_tracing_app_config(app.id, TracingProviderEnum.ARIZE, {}) + + assert result == {"result": "success"} + + # ── update_tracing_app_config ────────────────────────────────────── + + def test_update_tracing_app_config_invalid_provider(self, db_session_with_containers: Session): + with pytest.raises(ValueError, match="Invalid tracing provider: invalid_provider"): + OpsService.update_tracing_app_config(str(uuid.uuid4()), "invalid_provider", {}) + + def test_update_tracing_app_config_no_config(self, db_session_with_containers: Session, mock_ops_trace_manager): + result = OpsService.update_tracing_app_config(str(uuid.uuid4()), TracingProviderEnum.ARIZE, {}) + assert result is None + + def test_update_tracing_app_config_no_app(self, db_session_with_containers: Session, mock_ops_trace_manager): + fake_app_id = str(uuid.uuid4()) + self._insert_trace_config(db_session_with_containers, fake_app_id, str(TracingProviderEnum.ARIZE)) + mock_ops_trace_manager.encrypt_tracing_config.return_value = {} + result = OpsService.update_tracing_app_config(fake_app_id, TracingProviderEnum.ARIZE, {}) + assert result is None + + def test_update_tracing_app_config_invalid_credentials( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + with patch("services.ops_service.OpsTraceManager") as mock_otm: + mock_otm.encrypt_tracing_config.return_value = {} + mock_otm.decrypt_tracing_config.return_value = {} + mock_otm.check_trace_config_is_effective.return_value = False + + app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies) + self._insert_trace_config(db_session_with_containers, app.id, str(TracingProviderEnum.ARIZE)) + + with pytest.raises(ValueError, match="Invalid Credentials"): + OpsService.update_tracing_app_config(app.id, TracingProviderEnum.ARIZE, {}) + + def test_update_tracing_app_config_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + with patch("services.ops_service.OpsTraceManager") as mock_otm: + mock_otm.encrypt_tracing_config.return_value = {"updated": "config"} + mock_otm.decrypt_tracing_config.return_value = {} + mock_otm.check_trace_config_is_effective.return_value = True + + app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies) + self._insert_trace_config(db_session_with_containers, app.id, str(TracingProviderEnum.ARIZE)) + + result = OpsService.update_tracing_app_config(app.id, TracingProviderEnum.ARIZE, {}) + + assert result is not None + assert result["app_id"] == app.id + + # ── delete_tracing_app_config ────────────────────────────────────── + + def test_delete_tracing_app_config_no_config(self, db_session_with_containers: Session): + result = OpsService.delete_tracing_app_config(str(uuid.uuid4()), "arize") + assert result is None + + def test_delete_tracing_app_config_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies) + self._insert_trace_config(db_session_with_containers, app.id, "arize") + + result = OpsService.delete_tracing_app_config(app.id, "arize") + + assert result is True + remaining = db_session_with_containers.scalar( + select(TraceAppConfig) + .where(TraceAppConfig.app_id == app.id, TraceAppConfig.tracing_provider == "arize") + .limit(1) + ) + assert remaining is None diff --git a/api/tests/test_containers_integration_tests/services/test_recommended_app_service.py b/api/tests/test_containers_integration_tests/services/test_recommended_app_service.py new file mode 100644 index 00000000000..ccc4188dbf0 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_recommended_app_service.py @@ -0,0 +1,388 @@ +from __future__ import annotations + +import uuid +from types import SimpleNamespace +from typing import Any, cast +from unittest.mock import MagicMock, patch + +import pytest +from sqlalchemy import select +from sqlalchemy.orm import Session + +from models.model import AccountTrialAppRecord, TrialApp +from services import recommended_app_service as service_module +from services.recommended_app_service import RecommendedAppService + +# ── Helpers ──────────────────────────────────────────────────────────── + + +def _apps_response( + recommended_apps: list[dict] | None = None, + categories: list[str] | None = None, +) -> dict: + if recommended_apps is None: + recommended_apps = [ + {"id": "app-1", "name": "Test App 1", "description": "d1", "category": "productivity"}, + {"id": "app-2", "name": "Test App 2", "description": "d2", "category": "communication"}, + ] + if categories is None: + categories = ["productivity", "communication", "utilities"] + return {"recommended_apps": recommended_apps, "categories": categories} + + +def _app_detail( + app_id: str = "app-123", + name: str = "Test App", + description: str = "Test description", + **kwargs: Any, +) -> dict: + detail: dict[str, Any] = { + "id": app_id, + "name": name, + "description": description, + "category": kwargs.get("category", "productivity"), + "icon": kwargs.get("icon", "🚀"), + "model_config": kwargs.get("model_config", {}), + } + detail.update(kwargs) + return detail + + +def _recommendation_detail(result: dict[str, Any] | None) -> dict[str, Any] | None: + return cast("dict[str, Any] | None", result) + + +def _mock_factory_for_apps( + monkeypatch: pytest.MonkeyPatch, + *, + mode: str, + result: dict[str, Any], + fallback_result: dict[str, Any] | None = None, +) -> tuple[MagicMock, MagicMock]: + retrieval_instance = MagicMock() + retrieval_instance.get_recommended_apps_and_categories.return_value = result + retrieval_factory = MagicMock(return_value=retrieval_instance) + monkeypatch.setattr(service_module.dify_config, "HOSTED_FETCH_APP_TEMPLATES_MODE", mode, raising=False) + monkeypatch.setattr( + service_module.RecommendAppRetrievalFactory, + "get_recommend_app_factory", + MagicMock(return_value=retrieval_factory), + ) + builtin_instance = MagicMock() + if fallback_result is not None: + builtin_instance.fetch_recommended_apps_from_builtin.return_value = fallback_result + monkeypatch.setattr( + service_module.RecommendAppRetrievalFactory, + "get_buildin_recommend_app_retrieval", + MagicMock(return_value=builtin_instance), + ) + return retrieval_instance, builtin_instance + + +# ── Pure logic tests: get_recommended_apps_and_categories ────────────── + + +class TestRecommendedAppServiceGetApps: + @patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True) + @patch("services.recommended_app_service.dify_config", autospec=True) + def test_success_with_apps(self, mock_config, mock_factory_class): + mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "remote" + expected = _apps_response() + + mock_instance = MagicMock() + mock_instance.get_recommended_apps_and_categories.return_value = expected + mock_factory = MagicMock(return_value=mock_instance) + mock_factory_class.get_recommend_app_factory.return_value = mock_factory + + result = RecommendedAppService.get_recommended_apps_and_categories("en-US") + + assert result == expected + assert len(result["recommended_apps"]) == 2 + assert len(result["categories"]) == 3 + mock_factory_class.get_recommend_app_factory.assert_called_once_with("remote") + mock_instance.get_recommended_apps_and_categories.assert_called_once_with("en-US") + + @patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True) + @patch("services.recommended_app_service.dify_config", autospec=True) + def test_fallback_to_builtin_when_empty(self, mock_config, mock_factory_class): + mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "remote" + empty_response = {"recommended_apps": [], "categories": []} + builtin_response = _apps_response( + recommended_apps=[{"id": "builtin-1", "name": "Builtin App", "category": "default"}] + ) + + mock_remote_instance = MagicMock() + mock_remote_instance.get_recommended_apps_and_categories.return_value = empty_response + mock_factory_class.get_recommend_app_factory.return_value = MagicMock(return_value=mock_remote_instance) + + mock_builtin_instance = MagicMock() + mock_builtin_instance.fetch_recommended_apps_from_builtin.return_value = builtin_response + mock_factory_class.get_buildin_recommend_app_retrieval.return_value = mock_builtin_instance + + result = RecommendedAppService.get_recommended_apps_and_categories("zh-CN") + + assert result == builtin_response + assert result["recommended_apps"][0]["id"] == "builtin-1" + mock_builtin_instance.fetch_recommended_apps_from_builtin.assert_called_once_with("en-US") + + @patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True) + @patch("services.recommended_app_service.dify_config", autospec=True) + def test_fallback_when_none_recommended_apps(self, mock_config, mock_factory_class): + mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "db" + none_response = {"recommended_apps": None, "categories": ["test"]} + builtin_response = _apps_response() + + mock_db_instance = MagicMock() + mock_db_instance.get_recommended_apps_and_categories.return_value = none_response + mock_factory_class.get_recommend_app_factory.return_value = MagicMock(return_value=mock_db_instance) + + mock_builtin_instance = MagicMock() + mock_builtin_instance.fetch_recommended_apps_from_builtin.return_value = builtin_response + mock_factory_class.get_buildin_recommend_app_retrieval.return_value = mock_builtin_instance + + result = RecommendedAppService.get_recommended_apps_and_categories("en-US") + + assert result == builtin_response + mock_builtin_instance.fetch_recommended_apps_from_builtin.assert_called_once() + + @patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True) + @patch("services.recommended_app_service.dify_config", autospec=True) + def test_different_languages(self, mock_config, mock_factory_class): + mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "builtin" + + for language in ["en-US", "zh-CN", "ja-JP", "fr-FR"]: + lang_response = _apps_response( + recommended_apps=[{"id": f"app-{language}", "name": f"App {language}", "category": "test"}] + ) + mock_instance = MagicMock() + mock_instance.get_recommended_apps_and_categories.return_value = lang_response + mock_factory_class.get_recommend_app_factory.return_value = MagicMock(return_value=mock_instance) + + result = RecommendedAppService.get_recommended_apps_and_categories(language) + + assert result["recommended_apps"][0]["id"] == f"app-{language}" + mock_instance.get_recommended_apps_and_categories.assert_called_with(language) + + @patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True) + @patch("services.recommended_app_service.dify_config", autospec=True) + def test_uses_correct_factory_mode(self, mock_config, mock_factory_class): + for mode in ["remote", "builtin", "db"]: + mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = mode + response = _apps_response() + mock_instance = MagicMock() + mock_instance.get_recommended_apps_and_categories.return_value = response + mock_factory_class.get_recommend_app_factory.return_value = MagicMock(return_value=mock_instance) + + RecommendedAppService.get_recommended_apps_and_categories("en-US") + + mock_factory_class.get_recommend_app_factory.assert_called_with(mode) + + +# ── Pure logic tests: get_recommend_app_detail ───────────────────────── + + +class TestRecommendedAppServiceGetDetail: + @patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True) + @patch("services.recommended_app_service.dify_config", autospec=True) + def test_success(self, mock_config, mock_factory_class): + mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "remote" + expected = _app_detail(app_id="app-123", name="Productivity App", description="A great app") + + mock_instance = MagicMock() + mock_instance.get_recommend_app_detail.return_value = expected + mock_factory_class.get_recommend_app_factory.return_value = MagicMock(return_value=mock_instance) + + result = _recommendation_detail(RecommendedAppService.get_recommend_app_detail("app-123")) + + assert result == expected + assert result["id"] == "app-123" + mock_instance.get_recommend_app_detail.assert_called_once_with("app-123") + + @patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True) + @patch("services.recommended_app_service.dify_config", autospec=True) + def test_different_modes(self, mock_config, mock_factory_class): + for mode in ["remote", "builtin", "db"]: + mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = mode + detail = _app_detail(app_id="test-app", name=f"App from {mode}") + mock_instance = MagicMock() + mock_instance.get_recommend_app_detail.return_value = detail + mock_factory_class.get_recommend_app_factory.return_value = MagicMock(return_value=mock_instance) + + result = _recommendation_detail(RecommendedAppService.get_recommend_app_detail("test-app")) + + assert result["name"] == f"App from {mode}" + mock_factory_class.get_recommend_app_factory.assert_called_with(mode) + + @patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True) + @patch("services.recommended_app_service.dify_config", autospec=True) + def test_returns_none_when_not_found(self, mock_config, mock_factory_class): + mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "remote" + mock_instance = MagicMock() + mock_instance.get_recommend_app_detail.return_value = None + mock_factory_class.get_recommend_app_factory.return_value = MagicMock(return_value=mock_instance) + + result = _recommendation_detail(RecommendedAppService.get_recommend_app_detail("nonexistent")) + + assert result is None + mock_instance.get_recommend_app_detail.assert_called_once_with("nonexistent") + + @patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True) + @patch("services.recommended_app_service.dify_config", autospec=True) + def test_returns_empty_dict(self, mock_config, mock_factory_class): + mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "builtin" + mock_instance = MagicMock() + mock_instance.get_recommend_app_detail.return_value = {} + mock_factory_class.get_recommend_app_factory.return_value = MagicMock(return_value=mock_instance) + + result = _recommendation_detail(RecommendedAppService.get_recommend_app_detail("app-empty")) + + assert result == {} + + @patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True) + @patch("services.recommended_app_service.dify_config", autospec=True) + def test_complex_model_config(self, mock_config, mock_factory_class): + mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "remote" + complex_config = { + "provider": "openai", + "model": "gpt-4", + "parameters": {"temperature": 0.7, "max_tokens": 2000, "top_p": 1.0}, + } + expected = _app_detail( + app_id="complex-app", + name="Complex App", + model_config=complex_config, + workflows=["workflow-1", "workflow-2"], + tools=["tool-1", "tool-2", "tool-3"], + ) + mock_instance = MagicMock() + mock_instance.get_recommend_app_detail.return_value = expected + mock_factory_class.get_recommend_app_factory.return_value = MagicMock(return_value=mock_instance) + + result = _recommendation_detail(RecommendedAppService.get_recommend_app_detail("complex-app")) + + assert result["model_config"] == complex_config + assert len(result["workflows"]) == 2 + assert len(result["tools"]) == 3 + + +# ── Integration tests: trial app features (real DB) ──────────────────── + + +class TestRecommendedAppServiceTrialFeatures: + def test_get_apps_should_not_query_trial_table_when_disabled( + self, db_session_with_containers: Session, monkeypatch: pytest.MonkeyPatch + ): + expected = {"recommended_apps": [{"app_id": "app-1"}], "categories": ["all"]} + retrieval_instance, builtin_instance = _mock_factory_for_apps(monkeypatch, mode="remote", result=expected) + monkeypatch.setattr( + service_module.FeatureService, + "get_system_features", + MagicMock(return_value=SimpleNamespace(enable_trial_app=False)), + ) + + result = RecommendedAppService.get_recommended_apps_and_categories("en-US") + + assert result == expected + retrieval_instance.get_recommended_apps_and_categories.assert_called_once_with("en-US") + builtin_instance.fetch_recommended_apps_from_builtin.assert_not_called() + + def test_get_apps_should_enrich_can_trial_when_enabled( + self, db_session_with_containers: Session, monkeypatch: pytest.MonkeyPatch + ): + app_id_1 = str(uuid.uuid4()) + app_id_2 = str(uuid.uuid4()) + tenant_id = str(uuid.uuid4()) + + # app_id_1 has a TrialApp record; app_id_2 does not + db_session_with_containers.add(TrialApp(app_id=app_id_1, tenant_id=tenant_id)) + db_session_with_containers.commit() + + remote_result = {"recommended_apps": [], "categories": []} + fallback_result = { + "recommended_apps": [{"app_id": app_id_1}, {"app_id": app_id_2}], + "categories": ["all"], + } + _, builtin_instance = _mock_factory_for_apps( + monkeypatch, mode="remote", result=remote_result, fallback_result=fallback_result + ) + monkeypatch.setattr( + service_module.FeatureService, + "get_system_features", + MagicMock(return_value=SimpleNamespace(enable_trial_app=True)), + ) + + result = RecommendedAppService.get_recommended_apps_and_categories("ja-JP") + + builtin_instance.fetch_recommended_apps_from_builtin.assert_called_once_with("en-US") + assert result["recommended_apps"][0]["can_trial"] is True + assert result["recommended_apps"][1]["can_trial"] is False + + @pytest.mark.parametrize("has_trial_app", [True, False]) + def test_get_detail_should_set_can_trial_when_enabled( + self, + db_session_with_containers: Session, + monkeypatch: pytest.MonkeyPatch, + has_trial_app: bool, + ): + app_id = str(uuid.uuid4()) + tenant_id = str(uuid.uuid4()) + + if has_trial_app: + db_session_with_containers.add(TrialApp(app_id=app_id, tenant_id=tenant_id)) + db_session_with_containers.commit() + + detail = {"id": app_id, "name": "Test App"} + retrieval_instance = MagicMock() + retrieval_instance.get_recommend_app_detail.return_value = detail + retrieval_factory = MagicMock(return_value=retrieval_instance) + monkeypatch.setattr(service_module.dify_config, "HOSTED_FETCH_APP_TEMPLATES_MODE", "remote", raising=False) + monkeypatch.setattr( + service_module.RecommendAppRetrievalFactory, + "get_recommend_app_factory", + MagicMock(return_value=retrieval_factory), + ) + monkeypatch.setattr( + service_module.FeatureService, + "get_system_features", + MagicMock(return_value=SimpleNamespace(enable_trial_app=True)), + ) + + result = cast(dict[str, Any], RecommendedAppService.get_recommend_app_detail(app_id)) + + assert result["id"] == app_id + assert result["can_trial"] is has_trial_app + + def test_add_trial_app_record_increments_count_for_existing(self, db_session_with_containers: Session): + app_id = str(uuid.uuid4()) + account_id = str(uuid.uuid4()) + + db_session_with_containers.add(AccountTrialAppRecord(app_id=app_id, account_id=account_id, count=3)) + db_session_with_containers.commit() + + RecommendedAppService.add_trial_app_record(app_id, account_id) + + db_session_with_containers.expire_all() + record = db_session_with_containers.scalar( + select(AccountTrialAppRecord) + .where(AccountTrialAppRecord.app_id == app_id, AccountTrialAppRecord.account_id == account_id) + .limit(1) + ) + assert record is not None + assert record.count == 4 + + def test_add_trial_app_record_creates_new_record(self, db_session_with_containers: Session): + app_id = str(uuid.uuid4()) + account_id = str(uuid.uuid4()) + + RecommendedAppService.add_trial_app_record(app_id, account_id) + + db_session_with_containers.expire_all() + record = db_session_with_containers.scalar( + select(AccountTrialAppRecord) + .where(AccountTrialAppRecord.app_id == app_id, AccountTrialAppRecord.account_id == account_id) + .limit(1) + ) + assert record is not None + assert record.app_id == app_id + assert record.account_id == account_id + assert record.count == 1 diff --git a/api/tests/test_containers_integration_tests/services/test_schedule_service.py b/api/tests/test_containers_integration_tests/services/test_schedule_service.py new file mode 100644 index 00000000000..87f33062582 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_schedule_service.py @@ -0,0 +1,387 @@ +"""Testcontainers integration tests for schedule service SQL-backed behavior.""" + +from datetime import datetime +from types import SimpleNamespace +from uuid import uuid4 + +import pytest +from sqlalchemy import delete, select +from sqlalchemy.orm import Session + +from core.workflow.nodes.trigger_schedule.entities import ScheduleConfig, SchedulePlanUpdate +from core.workflow.nodes.trigger_schedule.exc import ScheduleNotFoundError +from events.event_handlers.sync_workflow_schedule_when_app_published import sync_schedule_from_workflow +from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models.trigger import WorkflowSchedulePlan +from services.errors.account import AccountNotFoundError +from services.trigger.schedule_service import ScheduleService + + +class ScheduleServiceIntegrationFactory: + @staticmethod + def create_account_with_tenant( + db_session_with_containers: Session, + role: TenantAccountRole = TenantAccountRole.OWNER, + ) -> tuple[Account, Tenant]: + account = Account( + email=f"{uuid4()}@example.com", + name=f"user-{uuid4()}", + interface_language="en-US", + status="active", + ) + tenant = Tenant(name=f"tenant-{uuid4()}", status="normal") + db_session_with_containers.add_all([account, tenant]) + db_session_with_containers.flush() + + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=role, + current=True, + ) + db_session_with_containers.add(join) + db_session_with_containers.commit() + + account.current_tenant = tenant + return account, tenant + + @staticmethod + def create_schedule_plan( + db_session_with_containers: Session, + *, + tenant_id: str, + app_id: str | None = None, + node_id: str = "start", + cron_expression: str = "30 10 * * *", + timezone: str = "UTC", + next_run_at: datetime | None = None, + ) -> WorkflowSchedulePlan: + schedule = WorkflowSchedulePlan( + tenant_id=tenant_id, + app_id=app_id or str(uuid4()), + node_id=node_id, + cron_expression=cron_expression, + timezone=timezone, + next_run_at=next_run_at, + ) + db_session_with_containers.add(schedule) + db_session_with_containers.commit() + return schedule + + +def _cron_workflow( + *, + node_id: str = "start", + cron_expression: str = "30 10 * * *", + timezone: str = "UTC", +): + return SimpleNamespace( + graph_dict={ + "nodes": [ + { + "id": node_id, + "data": { + "type": "trigger-schedule", + "mode": "cron", + "cron_expression": cron_expression, + "timezone": timezone, + }, + } + ] + } + ) + + +def _no_schedule_workflow(): + return SimpleNamespace( + graph_dict={ + "nodes": [ + { + "id": "node-1", + "data": {"type": "llm"}, + } + ] + } + ) + + +class TestScheduleServiceIntegration: + def test_create_schedule_persists_schedule(self, db_session_with_containers: Session): + account, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant(db_session_with_containers) + expected_next_run = datetime(2026, 1, 1, 10, 30, 0) + config = ScheduleConfig( + node_id="start", + cron_expression="30 10 * * *", + timezone="UTC", + ) + + with pytest.MonkeyPatch.context() as monkeypatch: + monkeypatch.setattr( + "services.trigger.schedule_service.calculate_next_run_at", + lambda *_args, **_kwargs: expected_next_run, + ) + schedule = ScheduleService.create_schedule( + session=db_session_with_containers, + tenant_id=tenant.id, + app_id=str(uuid4()), + config=config, + ) + + persisted = db_session_with_containers.get(WorkflowSchedulePlan, schedule.id) + assert persisted is not None + assert persisted.tenant_id == tenant.id + assert persisted.node_id == "start" + assert persisted.cron_expression == "30 10 * * *" + assert persisted.timezone == "UTC" + assert persisted.next_run_at == expected_next_run + + def test_update_schedule_updates_fields_and_recomputes_next_run(self, db_session_with_containers: Session): + _account, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant(db_session_with_containers) + schedule = ScheduleServiceIntegrationFactory.create_schedule_plan( + db_session_with_containers, + tenant_id=tenant.id, + cron_expression="30 10 * * *", + timezone="UTC", + ) + expected_next_run = datetime(2026, 1, 2, 12, 0, 0) + + with pytest.MonkeyPatch.context() as monkeypatch: + monkeypatch.setattr( + "services.trigger.schedule_service.calculate_next_run_at", + lambda *_args, **_kwargs: expected_next_run, + ) + updated = ScheduleService.update_schedule( + session=db_session_with_containers, + schedule_id=schedule.id, + updates=SchedulePlanUpdate( + cron_expression="0 12 * * *", + timezone="America/New_York", + ), + ) + + db_session_with_containers.refresh(updated) + assert updated.cron_expression == "0 12 * * *" + assert updated.timezone == "America/New_York" + assert updated.next_run_at == expected_next_run + + def test_update_schedule_updates_only_node_id_without_recomputing_time(self, db_session_with_containers: Session): + _account, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant(db_session_with_containers) + initial_next_run = datetime(2026, 1, 1, 10, 0, 0) + schedule = ScheduleServiceIntegrationFactory.create_schedule_plan( + db_session_with_containers, + tenant_id=tenant.id, + next_run_at=initial_next_run, + ) + + with pytest.MonkeyPatch.context() as monkeypatch: + calls: list[tuple] = [] + + def _track(*args, **kwargs): + calls.append((args, kwargs)) + return datetime(2026, 1, 9, 10, 0, 0) + + monkeypatch.setattr("services.trigger.schedule_service.calculate_next_run_at", _track) + updated = ScheduleService.update_schedule( + session=db_session_with_containers, + schedule_id=schedule.id, + updates=SchedulePlanUpdate(node_id="node-new"), + ) + + db_session_with_containers.refresh(updated) + assert updated.node_id == "node-new" + assert updated.next_run_at == initial_next_run + assert calls == [] + + def test_update_schedule_not_found_raises(self, db_session_with_containers: Session): + with pytest.raises(ScheduleNotFoundError, match="Schedule not found"): + ScheduleService.update_schedule( + session=db_session_with_containers, + schedule_id=str(uuid4()), + updates=SchedulePlanUpdate(node_id="node-new"), + ) + + def test_delete_schedule_removes_row(self, db_session_with_containers: Session): + _account, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant(db_session_with_containers) + schedule = ScheduleServiceIntegrationFactory.create_schedule_plan( + db_session_with_containers, + tenant_id=tenant.id, + ) + + ScheduleService.delete_schedule( + session=db_session_with_containers, + schedule_id=schedule.id, + ) + db_session_with_containers.commit() + + assert db_session_with_containers.get(WorkflowSchedulePlan, schedule.id) is None + + def test_delete_schedule_not_found_raises(self, db_session_with_containers: Session): + with pytest.raises(ScheduleNotFoundError, match="Schedule not found"): + ScheduleService.delete_schedule( + session=db_session_with_containers, + schedule_id=str(uuid4()), + ) + + def test_get_tenant_owner_returns_owner_account(self, db_session_with_containers: Session): + owner, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant( + db_session_with_containers, + role=TenantAccountRole.OWNER, + ) + + result = ScheduleService.get_tenant_owner( + session=db_session_with_containers, + tenant_id=tenant.id, + ) + + assert result.id == owner.id + + def test_get_tenant_owner_falls_back_to_admin(self, db_session_with_containers: Session): + admin, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant( + db_session_with_containers, + role=TenantAccountRole.ADMIN, + ) + + result = ScheduleService.get_tenant_owner( + session=db_session_with_containers, + tenant_id=tenant.id, + ) + + assert result.id == admin.id + + def test_get_tenant_owner_raises_when_account_record_missing(self, db_session_with_containers: Session): + _account, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant(db_session_with_containers) + db_session_with_containers.execute(delete(TenantAccountJoin)) + missing_account_id = str(uuid4()) + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=missing_account_id, + role=TenantAccountRole.OWNER, + current=True, + ) + db_session_with_containers.add(join) + db_session_with_containers.commit() + + with pytest.raises(AccountNotFoundError, match=missing_account_id): + ScheduleService.get_tenant_owner(session=db_session_with_containers, tenant_id=tenant.id) + + def test_get_tenant_owner_raises_when_no_owner_or_admin_found(self, db_session_with_containers: Session): + _account, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant(db_session_with_containers) + db_session_with_containers.execute(delete(TenantAccountJoin)) + db_session_with_containers.commit() + + with pytest.raises(AccountNotFoundError, match=tenant.id): + ScheduleService.get_tenant_owner(session=db_session_with_containers, tenant_id=tenant.id) + + def test_update_next_run_at_updates_persisted_value(self, db_session_with_containers: Session): + _account, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant(db_session_with_containers) + schedule = ScheduleServiceIntegrationFactory.create_schedule_plan( + db_session_with_containers, + tenant_id=tenant.id, + ) + expected_next_run = datetime(2026, 1, 3, 10, 30, 0) + + with pytest.MonkeyPatch.context() as monkeypatch: + monkeypatch.setattr( + "services.trigger.schedule_service.calculate_next_run_at", + lambda *_args, **_kwargs: expected_next_run, + ) + result = ScheduleService.update_next_run_at( + session=db_session_with_containers, + schedule_id=schedule.id, + ) + + db_session_with_containers.refresh(schedule) + assert result == expected_next_run + assert schedule.next_run_at == expected_next_run + + def test_update_next_run_at_raises_when_schedule_not_found(self, db_session_with_containers: Session): + with pytest.raises(ScheduleNotFoundError, match="Schedule not found"): + ScheduleService.update_next_run_at( + session=db_session_with_containers, + schedule_id=str(uuid4()), + ) + + +class TestSyncScheduleFromWorkflowIntegration: + def test_sync_schedule_create_new(self, db_session_with_containers: Session): + _account, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant(db_session_with_containers) + app_id = str(uuid4()) + expected_next_run = datetime(2026, 1, 4, 10, 30, 0) + + with pytest.MonkeyPatch.context() as monkeypatch: + monkeypatch.setattr( + "services.trigger.schedule_service.calculate_next_run_at", + lambda *_args, **_kwargs: expected_next_run, + ) + result = sync_schedule_from_workflow( + tenant_id=tenant.id, + app_id=app_id, + workflow=_cron_workflow(), + ) + + assert result is not None + persisted = db_session_with_containers.execute( + select(WorkflowSchedulePlan).where(WorkflowSchedulePlan.app_id == app_id) + ).scalar_one() + assert persisted.node_id == "start" + assert persisted.cron_expression == "30 10 * * *" + assert persisted.timezone == "UTC" + assert persisted.next_run_at == expected_next_run + + def test_sync_schedule_update_existing(self, db_session_with_containers: Session): + _account, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant(db_session_with_containers) + app_id = str(uuid4()) + existing = ScheduleServiceIntegrationFactory.create_schedule_plan( + db_session_with_containers, + tenant_id=tenant.id, + app_id=app_id, + node_id="old-start", + cron_expression="30 10 * * *", + timezone="UTC", + ) + existing_id = existing.id + expected_next_run = datetime(2026, 1, 5, 12, 0, 0) + + with pytest.MonkeyPatch.context() as monkeypatch: + monkeypatch.setattr( + "services.trigger.schedule_service.calculate_next_run_at", + lambda *_args, **_kwargs: expected_next_run, + ) + result = sync_schedule_from_workflow( + tenant_id=tenant.id, + app_id=app_id, + workflow=_cron_workflow( + node_id="start", + cron_expression="0 12 * * *", + timezone="America/New_York", + ), + ) + + assert result is not None + db_session_with_containers.expire_all() + persisted = db_session_with_containers.get(WorkflowSchedulePlan, existing_id) + assert persisted is not None + assert persisted.node_id == "start" + assert persisted.cron_expression == "0 12 * * *" + assert persisted.timezone == "America/New_York" + assert persisted.next_run_at == expected_next_run + + def test_sync_schedule_remove_when_no_config(self, db_session_with_containers: Session): + _account, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant(db_session_with_containers) + app_id = str(uuid4()) + existing = ScheduleServiceIntegrationFactory.create_schedule_plan( + db_session_with_containers, + tenant_id=tenant.id, + app_id=app_id, + ) + existing_id = existing.id + + result = sync_schedule_from_workflow( + tenant_id=tenant.id, + app_id=app_id, + workflow=_no_schedule_workflow(), + ) + + assert result is None + db_session_with_containers.expire_all() + assert db_session_with_containers.get(WorkflowSchedulePlan, existing_id) is None diff --git a/api/tests/test_containers_integration_tests/services/test_tag_service.py b/api/tests/test_containers_integration_tests/services/test_tag_service.py index f504f355890..5a6bf0466ef 100644 --- a/api/tests/test_containers_integration_tests/services/test_tag_service.py +++ b/api/tests/test_containers_integration_tests/services/test_tag_service.py @@ -12,7 +12,13 @@ from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset from models.enums import DataSourceType, TagType from models.model import App, Tag, TagBinding -from services.tag_service import TagService +from services.tag_service import ( + SaveTagPayload, + TagBindingCreatePayload, + TagBindingDeletePayload, + TagService, + UpdateTagPayload, +) class TestTagService: @@ -685,7 +691,7 @@ class TestTagService: db_session_with_containers, mock_external_service_dependencies ) - tag_args = {"name": "test_tag_name", "type": "knowledge"} + tag_args = SaveTagPayload(name="test_tag_name", type="knowledge") # Act: Execute the method under test result = TagService.save_tags(tag_args) @@ -725,7 +731,7 @@ class TestTagService: ) # Create first tag - tag_args = {"name": "duplicate_tag", "type": "app"} + tag_args = SaveTagPayload(name="duplicate_tag", type="app") TagService.save_tags(tag_args) # Act & Assert: Verify proper error handling @@ -749,11 +755,11 @@ class TestTagService: ) # Create a tag to update - tag_args = {"name": "original_name", "type": "knowledge"} + tag_args = SaveTagPayload(name="original_name", type="knowledge") tag = TagService.save_tags(tag_args) # Update args - update_args = {"name": "updated_name", "type": "knowledge"} + update_args = UpdateTagPayload(name="updated_name", type="knowledge") # Act: Execute the method under test result = TagService.update_tags(update_args, tag.id) @@ -793,7 +799,7 @@ class TestTagService: non_existent_tag_id = str(uuid.uuid4()) - update_args = {"name": "updated_name", "type": "knowledge"} + update_args = UpdateTagPayload(name="updated_name", type="knowledge") # Act & Assert: Verify proper error handling with pytest.raises(NotFound) as exc_info: @@ -817,14 +823,14 @@ class TestTagService: ) # Create two tags - tag1_args = {"name": "first_tag", "type": "app"} + tag1_args = SaveTagPayload(name="first_tag", type="app") tag1 = TagService.save_tags(tag1_args) - tag2_args = {"name": "second_tag", "type": "app"} + tag2_args = SaveTagPayload(name="second_tag", type="app") tag2 = TagService.save_tags(tag2_args) # Try to update second tag with first tag's name - update_args = {"name": "first_tag", "type": "app"} + update_args = UpdateTagPayload(name="first_tag", type="app") # Act & Assert: Verify proper error handling with pytest.raises(ValueError) as exc_info: @@ -988,8 +994,10 @@ class TestTagService: dataset = self._create_test_dataset(db_session_with_containers, mock_external_service_dependencies, tenant.id) # Act: Execute the method under test - binding_args = {"type": "knowledge", "target_id": dataset.id, "tag_ids": [tag.id for tag in tags]} - TagService.save_tag_binding(binding_args) + binding_payload = TagBindingCreatePayload( + type="knowledge", target_id=dataset.id, tag_ids=[tag.id for tag in tags] + ) + TagService.save_tag_binding(binding_payload) # Assert: Verify the expected outcomes @@ -1030,11 +1038,11 @@ class TestTagService: app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, tenant.id) # Create first binding - binding_args = {"type": "app", "target_id": app.id, "tag_ids": [tag.id]} - TagService.save_tag_binding(binding_args) + binding_payload = TagBindingCreatePayload(type="app", target_id=app.id, tag_ids=[tag.id]) + TagService.save_tag_binding(binding_payload) # Act: Try to create duplicate binding - TagService.save_tag_binding(binding_args) + TagService.save_tag_binding(binding_payload) # Assert: Verify the expected outcomes @@ -1071,11 +1079,10 @@ class TestTagService: non_existent_target_id = str(uuid.uuid4()) # Act & Assert: Verify proper error handling - binding_args = {"type": "invalid_type", "target_id": non_existent_target_id, "tag_ids": [tag.id]} + from pydantic import ValidationError - with pytest.raises(NotFound) as exc_info: - TagService.save_tag_binding(binding_args) - assert "Invalid binding type" in str(exc_info.value) + with pytest.raises(ValidationError): + TagBindingCreatePayload(type="invalid_type", target_id=non_existent_target_id, tag_ids=[tag.id]) def test_delete_tag_binding_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ @@ -1113,8 +1120,8 @@ class TestTagService: assert binding_before is not None # Act: Execute the method under test - delete_args = {"type": "knowledge", "target_id": dataset.id, "tag_id": tag.id} - TagService.delete_tag_binding(delete_args) + delete_payload = TagBindingDeletePayload(type="knowledge", target_id=dataset.id, tag_id=tag.id) + TagService.delete_tag_binding(delete_payload) # Assert: Verify the expected outcomes # Verify tag binding was deleted @@ -1149,8 +1156,8 @@ class TestTagService: app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, tenant.id) # Act: Try to delete non-existent binding - delete_args = {"type": "app", "target_id": app.id, "tag_id": tag.id} - TagService.delete_tag_binding(delete_args) + delete_payload = TagBindingDeletePayload(type="app", target_id=app.id, tag_id=tag.id) + TagService.delete_tag_binding(delete_payload) # Assert: Verify the expected outcomes # No error should be raised, and database state should remain unchanged diff --git a/api/tests/test_containers_integration_tests/services/test_webapp_auth_service.py b/api/tests/test_containers_integration_tests/services/test_webapp_auth_service.py index 4fe65d58032..7825f502f77 100644 --- a/api/tests/test_containers_integration_tests/services/test_webapp_auth_service.py +++ b/api/tests/test_containers_integration_tests/services/test_webapp_auth_service.py @@ -233,11 +233,10 @@ class TestWebAppAuthService: assert result.status == AccountStatus.ACTIVE # Verify database state - - db_session_with_containers.refresh(result) - assert result.id is not None - assert result.password is not None - assert result.password_salt is not None + refreshed = db_session_with_containers.get(Account, result.id) + assert refreshed is not None + assert refreshed.password is not None + assert refreshed.password_salt is not None def test_authenticate_account_not_found( self, db_session_with_containers: Session, mock_external_service_dependencies @@ -414,9 +413,8 @@ class TestWebAppAuthService: assert result.status == AccountStatus.ACTIVE # Verify database state - - db_session_with_containers.refresh(result) - assert result.id is not None + refreshed = db_session_with_containers.get(Account, result.id) + assert refreshed is not None def test_get_user_through_email_not_found( self, db_session_with_containers: Session, mock_external_service_dependencies diff --git a/api/tests/test_containers_integration_tests/services/test_webhook_service_relationships.py b/api/tests/test_containers_integration_tests/services/test_webhook_service_relationships.py new file mode 100644 index 00000000000..ec10c51e049 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_webhook_service_relationships.py @@ -0,0 +1,507 @@ +from __future__ import annotations + +import json +from types import SimpleNamespace +from unittest.mock import MagicMock, patch +from uuid import uuid4 + +import pytest +from sqlalchemy import select +from sqlalchemy.orm import Session + +from core.trigger.constants import TRIGGER_WEBHOOK_NODE_TYPE +from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models.enums import AppTriggerStatus, AppTriggerType +from models.model import App +from models.trigger import AppTrigger, WorkflowWebhookTrigger +from models.workflow import Workflow +from services.errors.app import QuotaExceededError +from services.trigger.webhook_service import WebhookService + + +class WebhookServiceRelationshipFactory: + @staticmethod + def create_account_and_tenant(db_session_with_containers: Session) -> tuple[Account, Tenant]: + account = Account( + name=f"Account {uuid4()}", + email=f"webhook-{uuid4()}@example.com", + password="hashed-password", + password_salt="salt", + interface_language="en-US", + timezone="UTC", + ) + db_session_with_containers.add(account) + db_session_with_containers.commit() + + tenant = Tenant(name=f"Tenant {uuid4()}", plan="basic", status="normal") + db_session_with_containers.add(tenant) + db_session_with_containers.commit() + + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=TenantAccountRole.OWNER, + current=True, + ) + db_session_with_containers.add(join) + db_session_with_containers.commit() + + account.current_tenant = tenant + return account, tenant + + @staticmethod + def create_app(db_session_with_containers: Session, tenant: Tenant, account: Account) -> App: + app = App( + tenant_id=tenant.id, + name=f"Webhook App {uuid4()}", + description="", + mode="workflow", + icon_type="emoji", + icon="bot", + icon_background="#FFFFFF", + enable_site=False, + enable_api=True, + api_rpm=100, + api_rph=100, + is_demo=False, + is_public=False, + is_universal=False, + created_by=account.id, + updated_by=account.id, + ) + db_session_with_containers.add(app) + db_session_with_containers.commit() + return app + + @staticmethod + def create_workflow( + db_session_with_containers: Session, + *, + app: App, + account: Account, + node_ids: list[str], + version: str, + ) -> Workflow: + graph = { + "nodes": [ + { + "id": node_id, + "data": { + "type": TRIGGER_WEBHOOK_NODE_TYPE, + "title": f"Webhook {node_id}", + "method": "post", + "content_type": "application/json", + "headers": [], + "params": [], + "body": [], + "status_code": 200, + "response_body": '{"status": "ok"}', + "timeout": 30, + }, + } + for node_id in node_ids + ], + "edges": [], + } + + workflow = Workflow( + tenant_id=app.tenant_id, + app_id=app.id, + type="workflow", + graph=json.dumps(graph), + features=json.dumps({}), + created_by=account.id, + updated_by=account.id, + environment_variables=[], + conversation_variables=[], + version=version, + ) + db_session_with_containers.add(workflow) + db_session_with_containers.commit() + return workflow + + @staticmethod + def create_webhook_trigger( + db_session_with_containers: Session, + *, + app: App, + account: Account, + node_id: str, + webhook_id: str | None = None, + ) -> WorkflowWebhookTrigger: + webhook_trigger = WorkflowWebhookTrigger( + app_id=app.id, + node_id=node_id, + tenant_id=app.tenant_id, + webhook_id=webhook_id or uuid4().hex[:24], + created_by=account.id, + ) + db_session_with_containers.add(webhook_trigger) + db_session_with_containers.commit() + return webhook_trigger + + @staticmethod + def create_app_trigger( + db_session_with_containers: Session, + *, + app: App, + node_id: str, + status: AppTriggerStatus, + ) -> AppTrigger: + app_trigger = AppTrigger( + tenant_id=app.tenant_id, + app_id=app.id, + node_id=node_id, + trigger_type=AppTriggerType.TRIGGER_WEBHOOK, + provider_name="webhook", + title=f"Webhook {node_id}", + status=status, + ) + db_session_with_containers.add(app_trigger) + db_session_with_containers.commit() + return app_trigger + + +class TestWebhookServiceLookupWithContainers: + def test_get_webhook_trigger_and_workflow_raises_when_app_trigger_missing( + self, db_session_with_containers: Session, flask_app_with_containers + ): + del flask_app_with_containers + factory = WebhookServiceRelationshipFactory + account, tenant = factory.create_account_and_tenant(db_session_with_containers) + app = factory.create_app(db_session_with_containers, tenant, account) + factory.create_workflow( + db_session_with_containers, app=app, account=account, node_ids=["node-1"], version="2026-04-14.001" + ) + webhook_trigger = factory.create_webhook_trigger( + db_session_with_containers, app=app, account=account, node_id="node-1" + ) + + with pytest.raises(ValueError, match="App trigger not found"): + WebhookService.get_webhook_trigger_and_workflow(webhook_trigger.webhook_id) + + def test_get_webhook_trigger_and_workflow_raises_when_app_trigger_rate_limited( + self, db_session_with_containers: Session, flask_app_with_containers + ): + del flask_app_with_containers + factory = WebhookServiceRelationshipFactory + account, tenant = factory.create_account_and_tenant(db_session_with_containers) + app = factory.create_app(db_session_with_containers, tenant, account) + factory.create_workflow( + db_session_with_containers, app=app, account=account, node_ids=["node-1"], version="2026-04-14.001" + ) + webhook_trigger = factory.create_webhook_trigger( + db_session_with_containers, app=app, account=account, node_id="node-1" + ) + factory.create_app_trigger( + db_session_with_containers, app=app, node_id="node-1", status=AppTriggerStatus.RATE_LIMITED + ) + + with pytest.raises(ValueError, match="rate limited"): + WebhookService.get_webhook_trigger_and_workflow(webhook_trigger.webhook_id) + + def test_get_webhook_trigger_and_workflow_raises_when_app_trigger_disabled( + self, db_session_with_containers: Session, flask_app_with_containers + ): + del flask_app_with_containers + factory = WebhookServiceRelationshipFactory + account, tenant = factory.create_account_and_tenant(db_session_with_containers) + app = factory.create_app(db_session_with_containers, tenant, account) + factory.create_workflow( + db_session_with_containers, app=app, account=account, node_ids=["node-1"], version="2026-04-14.001" + ) + webhook_trigger = factory.create_webhook_trigger( + db_session_with_containers, app=app, account=account, node_id="node-1" + ) + factory.create_app_trigger( + db_session_with_containers, app=app, node_id="node-1", status=AppTriggerStatus.DISABLED + ) + + with pytest.raises(ValueError, match="disabled"): + WebhookService.get_webhook_trigger_and_workflow(webhook_trigger.webhook_id) + + def test_get_webhook_trigger_and_workflow_raises_when_workflow_missing( + self, db_session_with_containers: Session, flask_app_with_containers + ): + del flask_app_with_containers + factory = WebhookServiceRelationshipFactory + account, tenant = factory.create_account_and_tenant(db_session_with_containers) + app = factory.create_app(db_session_with_containers, tenant, account) + webhook_trigger = factory.create_webhook_trigger( + db_session_with_containers, app=app, account=account, node_id="node-1" + ) + factory.create_app_trigger( + db_session_with_containers, app=app, node_id="node-1", status=AppTriggerStatus.ENABLED + ) + + with pytest.raises(ValueError, match="Workflow not found"): + WebhookService.get_webhook_trigger_and_workflow(webhook_trigger.webhook_id) + + def test_get_webhook_trigger_and_workflow_returns_debug_draft_workflow( + self, db_session_with_containers: Session, flask_app_with_containers + ): + del flask_app_with_containers + factory = WebhookServiceRelationshipFactory + account, tenant = factory.create_account_and_tenant(db_session_with_containers) + app = factory.create_app(db_session_with_containers, tenant, account) + factory.create_workflow( + db_session_with_containers, + app=app, + account=account, + node_ids=["published-node"], + version="2026-04-14.001", + ) + draft_workflow = factory.create_workflow( + db_session_with_containers, + app=app, + account=account, + node_ids=["debug-node"], + version=Workflow.VERSION_DRAFT, + ) + webhook_trigger = factory.create_webhook_trigger( + db_session_with_containers, app=app, account=account, node_id="debug-node" + ) + + got_trigger, got_workflow, got_node_config = WebhookService.get_webhook_trigger_and_workflow( + webhook_trigger.webhook_id, + is_debug=True, + ) + + assert got_trigger.id == webhook_trigger.id + assert got_workflow.id == draft_workflow.id + assert got_node_config["id"] == "debug-node" + + +class TestWebhookServiceTriggerExecutionWithContainers: + def test_trigger_workflow_execution_triggers_async_workflow_successfully( + self, db_session_with_containers: Session, flask_app_with_containers + ): + del flask_app_with_containers + factory = WebhookServiceRelationshipFactory + account, tenant = factory.create_account_and_tenant(db_session_with_containers) + app = factory.create_app(db_session_with_containers, tenant, account) + workflow = factory.create_workflow( + db_session_with_containers, app=app, account=account, node_ids=["node-1"], version="2026-04-14.001" + ) + webhook_trigger = factory.create_webhook_trigger( + db_session_with_containers, app=app, account=account, node_id="node-1" + ) + + end_user = SimpleNamespace(id=str(uuid4())) + webhook_data = {"body": {"value": 1}, "headers": {}, "query_params": {}, "files": {}, "method": "POST"} + + with ( + patch( + "services.trigger.webhook_service.EndUserService.get_or_create_end_user_by_type", + return_value=end_user, + ), + patch("services.trigger.webhook_service.QuotaType.TRIGGER.consume") as mock_consume, + patch("services.trigger.webhook_service.AsyncWorkflowService.trigger_workflow_async") as mock_trigger, + ): + WebhookService.trigger_workflow_execution(webhook_trigger, webhook_data, workflow) + + mock_consume.assert_called_once_with(webhook_trigger.tenant_id) + mock_trigger.assert_called_once() + trigger_args = mock_trigger.call_args.args + assert trigger_args[1] is end_user + assert trigger_args[2].workflow_id == workflow.id + assert trigger_args[2].root_node_id == webhook_trigger.node_id + + def test_trigger_workflow_execution_marks_tenant_rate_limited_when_quota_exceeded( + self, db_session_with_containers: Session, flask_app_with_containers + ): + del flask_app_with_containers + factory = WebhookServiceRelationshipFactory + account, tenant = factory.create_account_and_tenant(db_session_with_containers) + app = factory.create_app(db_session_with_containers, tenant, account) + workflow = factory.create_workflow( + db_session_with_containers, app=app, account=account, node_ids=["node-1"], version="2026-04-14.001" + ) + webhook_trigger = factory.create_webhook_trigger( + db_session_with_containers, app=app, account=account, node_id="node-1" + ) + + with ( + patch( + "services.trigger.webhook_service.EndUserService.get_or_create_end_user_by_type", + return_value=SimpleNamespace(id=str(uuid4())), + ), + patch( + "services.trigger.webhook_service.QuotaType.TRIGGER.consume", + side_effect=QuotaExceededError(feature="trigger", tenant_id=tenant.id, required=1), + ), + patch( + "services.trigger.webhook_service.AppTriggerService.mark_tenant_triggers_rate_limited" + ) as mock_mark_rate_limited, + ): + with pytest.raises(QuotaExceededError): + WebhookService.trigger_workflow_execution( + webhook_trigger, + {"body": {}, "headers": {}, "query_params": {}, "files": {}, "method": "POST"}, + workflow, + ) + + mock_mark_rate_limited.assert_called_once_with(tenant.id) + + def test_trigger_workflow_execution_logs_and_reraises_unexpected_errors( + self, db_session_with_containers: Session, flask_app_with_containers + ): + del flask_app_with_containers + factory = WebhookServiceRelationshipFactory + account, tenant = factory.create_account_and_tenant(db_session_with_containers) + app = factory.create_app(db_session_with_containers, tenant, account) + workflow = factory.create_workflow( + db_session_with_containers, app=app, account=account, node_ids=["node-1"], version="2026-04-14.001" + ) + webhook_trigger = factory.create_webhook_trigger( + db_session_with_containers, app=app, account=account, node_id="node-1" + ) + + with ( + patch( + "services.trigger.webhook_service.EndUserService.get_or_create_end_user_by_type", + side_effect=RuntimeError("boom"), + ), + patch("services.trigger.webhook_service.logger.exception") as mock_logger_exception, + ): + with pytest.raises(RuntimeError, match="boom"): + WebhookService.trigger_workflow_execution( + webhook_trigger, + {"body": {}, "headers": {}, "query_params": {}, "files": {}, "method": "POST"}, + workflow, + ) + + mock_logger_exception.assert_called_once() + + +class TestWebhookServiceRelationshipSyncWithContainers: + def test_sync_webhook_relationships_raises_when_workflow_exceeds_node_limit( + self, db_session_with_containers: Session, flask_app_with_containers + ): + del flask_app_with_containers + factory = WebhookServiceRelationshipFactory + account, tenant = factory.create_account_and_tenant(db_session_with_containers) + app = factory.create_app(db_session_with_containers, tenant, account) + node_ids = [f"node-{index}" for index in range(WebhookService.MAX_WEBHOOK_NODES_PER_WORKFLOW + 1)] + workflow = factory.create_workflow( + db_session_with_containers, app=app, account=account, node_ids=node_ids, version=Workflow.VERSION_DRAFT + ) + + with pytest.raises(ValueError, match="maximum webhook node limit"): + WebhookService.sync_webhook_relationships(app, workflow) + + def test_sync_webhook_relationships_raises_when_lock_not_acquired( + self, db_session_with_containers: Session, flask_app_with_containers + ): + del flask_app_with_containers + factory = WebhookServiceRelationshipFactory + account, tenant = factory.create_account_and_tenant(db_session_with_containers) + app = factory.create_app(db_session_with_containers, tenant, account) + workflow = factory.create_workflow( + db_session_with_containers, app=app, account=account, node_ids=["node-1"], version=Workflow.VERSION_DRAFT + ) + lock = MagicMock() + lock.acquire.return_value = False + + with patch("services.trigger.webhook_service.redis_client.lock", return_value=lock): + with pytest.raises(RuntimeError, match="Failed to acquire lock"): + WebhookService.sync_webhook_relationships(app, workflow) + + def test_sync_webhook_relationships_creates_missing_records_and_deletes_stale_records( + self, db_session_with_containers: Session, flask_app_with_containers + ): + del flask_app_with_containers + factory = WebhookServiceRelationshipFactory + account, tenant = factory.create_account_and_tenant(db_session_with_containers) + app = factory.create_app(db_session_with_containers, tenant, account) + stale_trigger = factory.create_webhook_trigger( + db_session_with_containers, + app=app, + account=account, + node_id="node-stale", + webhook_id="stale-webhook-id-000001", + ) + stale_trigger_id = stale_trigger.id + workflow = factory.create_workflow( + db_session_with_containers, + app=app, + account=account, + node_ids=["node-new"], + version=Workflow.VERSION_DRAFT, + ) + + with patch( + "services.trigger.webhook_service.WebhookService.generate_webhook_id", return_value="new-webhook-id-000001" + ): + WebhookService.sync_webhook_relationships(app, workflow) + + db_session_with_containers.expire_all() + records = db_session_with_containers.scalars( + select(WorkflowWebhookTrigger).where(WorkflowWebhookTrigger.app_id == app.id) + ).all() + + assert [record.node_id for record in records] == ["node-new"] + assert records[0].webhook_id == "new-webhook-id-000001" + assert db_session_with_containers.get(WorkflowWebhookTrigger, stale_trigger_id) is None + + def test_sync_webhook_relationships_sets_redis_cache_for_new_record( + self, db_session_with_containers: Session, flask_app_with_containers + ): + del flask_app_with_containers + factory = WebhookServiceRelationshipFactory + account, tenant = factory.create_account_and_tenant(db_session_with_containers) + app = factory.create_app(db_session_with_containers, tenant, account) + workflow = factory.create_workflow( + db_session_with_containers, + app=app, + account=account, + node_ids=["node-cache"], + version=Workflow.VERSION_DRAFT, + ) + cache_key = f"{WebhookService.__WEBHOOK_NODE_CACHE_KEY__}:{app.id}:node-cache" + + with patch( + "services.trigger.webhook_service.WebhookService.generate_webhook_id", return_value="cache-webhook-id-00001" + ): + WebhookService.sync_webhook_relationships(app, workflow) + + cached_payload = WebhookServiceRelationshipFactory._read_cache(cache_key) + assert cached_payload is not None + assert cached_payload["node_id"] == "node-cache" + assert cached_payload["webhook_id"] == "cache-webhook-id-00001" + + def test_sync_webhook_relationships_logs_when_lock_release_fails( + self, db_session_with_containers: Session, flask_app_with_containers + ): + del flask_app_with_containers + factory = WebhookServiceRelationshipFactory + account, tenant = factory.create_account_and_tenant(db_session_with_containers) + app = factory.create_app(db_session_with_containers, tenant, account) + workflow = factory.create_workflow( + db_session_with_containers, app=app, account=account, node_ids=[], version=Workflow.VERSION_DRAFT + ) + lock = MagicMock() + lock.acquire.return_value = True + lock.release.side_effect = RuntimeError("release failed") + + with ( + patch("services.trigger.webhook_service.redis_client.lock", return_value=lock), + patch("services.trigger.webhook_service.logger.exception") as mock_logger_exception, + ): + WebhookService.sync_webhook_relationships(app, workflow) + + mock_logger_exception.assert_called_once() + + +def _read_cache(cache_key: str) -> dict[str, str] | None: + from extensions.ext_redis import redis_client + + cached = redis_client.get(cache_key) + if not cached: + return None + if isinstance(cached, bytes): + cached = cached.decode("utf-8") + return json.loads(cached) + + +WebhookServiceRelationshipFactory._read_cache = staticmethod(_read_cache) diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py index 749c6fff5bc..1e57b5603de 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py @@ -8,9 +8,9 @@ from unittest.mock import patch import pytest from faker import Faker -from graphon.enums import WorkflowExecutionStatus from sqlalchemy.orm import Session +from graphon.enums import WorkflowExecutionStatus from models import EndUser, Workflow, WorkflowAppLog, WorkflowArchiveLog, WorkflowRun from models.enums import AppTriggerType, CreatorUserRole, WorkflowRunTriggeredFrom from models.workflow import WorkflowAppLogCreatedFrom diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py index 0c281c8c33b..86cf2327c7f 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py @@ -1,9 +1,9 @@ import pytest from faker import Faker -from graphon.variables.segments import StringSegment from sqlalchemy.orm import Session from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID +from graphon.variables.segments import StringSegment from models import App, Workflow from models.enums import DraftVariableType from models.workflow import WorkflowDraftVariable diff --git a/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py b/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py index d3e765055a0..af83adaae01 100644 --- a/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py +++ b/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py @@ -1,3 +1,5 @@ +import inspect +import json from unittest.mock import patch import pytest @@ -6,6 +8,8 @@ from pydantic import TypeAdapter, ValidationError from sqlalchemy.orm import Session from core.tools.entities.tool_entities import ApiProviderSchemaType +from core.tools.errors import ApiToolProviderNotFoundError +from core.tools.tool_label_manager import ToolLabelManager from models import Account, Tenant from models.tools import ApiToolProvider from services.tools.api_tools_manage_service import ApiToolManageService @@ -590,30 +594,204 @@ class TestApiToolManageService: with pytest.raises(ValueError, match="you have not added provider"): ApiToolManageService.delete_api_tool_provider(account.id, tenant.id, "nonexistent") - def test_update_api_tool_provider_not_found( + def test_update_api_tool_provider_success( self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies ): - """Test update raises ValueError when original provider not found.""" fake = Faker() + + # Firmware fix for cache.delete() in update flow + mock_encrypter = mock_external_service_dependencies["encrypter"] + from unittest.mock import MagicMock + + mock_cache = MagicMock() + mock_cache.delete.return_value = None + mock_encrypter.return_value = (mock_encrypter, mock_cache) + + # Get fake account and tenant account, tenant = self._create_test_account_and_tenant( db_session_with_containers, mock_external_service_dependencies ) - with pytest.raises(ValueError, match="does not exists"): - ApiToolManageService.update_api_tool_provider( + # original provider name + original_name = "original-provider" + + # Create original provider + _ = ApiToolManageService.create_api_tool_provider( + user_id=account.id, + tenant_id=tenant.id, + provider_name=original_name, + icon={"type": "emoji", "value": "🔧"}, + credentials={"auth_type": "none"}, + schema_type=ApiProviderSchemaType.OPENAPI, + schema=self._create_test_openapi_schema(), + privacy_policy="", + custom_disclaimer="", + labels=["old-label"], + ) + + # new provide name and new labels for update + new_name = "updated-provider" + new_labels = ["new-label-1", "new-label-2"] + + # Reset mock history so assertions focus on update path only + mock_external_service_dependencies["encrypter"].reset_mock() + mock_external_service_dependencies["provider_controller"].from_db.reset_mock() + mock_external_service_dependencies["tool_label_manager"].update_tool_labels.reset_mock() + + # Act: Update the provider with new values + result = ApiToolManageService.update_api_tool_provider( + user_id=account.id, + tenant_id=tenant.id, + # new provider name - changed 1 + provider_name=new_name, + original_provider=original_name, + # new icon - changed 2 + icon={"type": "emoji", "value": "🚀"}, + credentials={"auth_type": "none"}, + _schema_type=ApiProviderSchemaType.OPENAPI, + schema=self._create_test_openapi_schema(), + # new privacy policy - changed 3 + privacy_policy="https://new-policy.com", + # new custom disclaimer - changed 4 + custom_disclaimer="New disclaimer", + # new labels - changed 5 (However, we will not verify this, not this layer responsibility.) + labels=new_labels, + ) + + # Assert: Verify the result + assert result == {"result": "success"} + + # Get the updated provider from the database + updated_provider: ApiToolProvider | None = ( + db_session_with_containers.query(ApiToolProvider) + .filter(ApiToolProvider.tenant_id == tenant.id, ApiToolProvider.name == new_name) + .first() + ) + + # Verify the provider was updated successfully + assert updated_provider is not None + + # Manually refresh to keep object detachment + db_session_with_containers.refresh(updated_provider) + # Verify all the updated fields + # - changed 1 + assert updated_provider.name == new_name + # - changed 2 + icon_data = json.loads(updated_provider.icon) + assert icon_data["type"] == "emoji" + assert icon_data["value"] == "🚀" + # - changed 3 + assert updated_provider.privacy_policy == "https://new-policy.com" + # - changed 4 + assert updated_provider.custom_disclaimer == "New disclaimer" + + # Verify old provider name no longer exists after rename + original_provider: ApiToolProvider | None = ( + db_session_with_containers.query(ApiToolProvider) + .filter(ApiToolProvider.tenant_id == tenant.id, ApiToolProvider.name == original_name) + .first() + ) + assert original_provider is None + + # Verify update flow calls critical collaborators + mock_external_service_dependencies["provider_controller"].from_db.assert_called_once() + mock_external_service_dependencies["encrypter"].assert_called_once() + mock_cache.delete.assert_called_once() + + # Deeply verify on session propagation of labels update logics: + # Since in refactoring, we pass session down to label manager to keep atomicity. + # The assertion here is to verify this. + sig = inspect.signature(ToolLabelManager.update_tool_labels) + args, kwargs = mock_external_service_dependencies["tool_label_manager"].update_tool_labels.call_args + bound_args = sig.bind(*args, **kwargs) + passed_session = bound_args.arguments.get("session") + # Ensure the type: Session + assert isinstance(passed_session, Session), f"Expected Session object, got {type(passed_session)}" + assert passed_session is not None, ( + "Atomicity Failure: Session cannot be passed to Label Manager in update_api_tool_provider" + ) + + def test_update_api_tool_provider_not_found( + self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies + ): + """ + Test update raises ValueError when original provider not found. + + This test verifies: + - Proper error when trying to update a non-existing original provider + - No accidental upsert/new provider creation + - No external dependency invocation on early failure path + """ + # Arrange: Create test account and tenant + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + # Keep an existing provider in DB to ensure unrelated data remains unchanged + existing_provider_name = "existing-provider" + _ = ApiToolManageService.create_api_tool_provider( + user_id=account.id, + tenant_id=tenant.id, + provider_name=existing_provider_name, + icon={"type": "emoji", "value": "🔧"}, + credentials={"auth_type": "none"}, + schema_type=ApiProviderSchemaType.OPENAPI, + schema=self._create_test_openapi_schema(), + privacy_policy="https://existing-policy.com", + custom_disclaimer="Existing disclaimer", + labels=["existing-label"], + ) + + # Reset mock history so assertions focus on update failure path only + mock_external_service_dependencies["tool_label_manager"].update_tool_labels.reset_mock() + mock_external_service_dependencies["encrypter"].reset_mock() + mock_external_service_dependencies["provider_controller"].from_db.reset_mock() + + # Act & Assert: Verify update fails with clear error message + target_new_name = "new-provider-name" + missing_original_name = "missing-original-provider" + with pytest.raises(ApiToolProviderNotFoundError) as exc_info: + _ = ApiToolManageService.update_api_tool_provider( user_id=account.id, tenant_id=tenant.id, - provider_name="new-name", - original_provider="nonexistent", - icon={}, + provider_name=target_new_name, + original_provider=missing_original_name, + icon={"type": "emoji", "value": "🚀"}, credentials={"auth_type": "none"}, _schema_type=ApiProviderSchemaType.OPENAPI, schema=self._create_test_openapi_schema(), - privacy_policy=None, - custom_disclaimer="", - labels=[], + privacy_policy="https://new-policy.com", + custom_disclaimer="New disclaimer", + labels=["new-label"], ) + error = exc_info.value + assert error.provider_name == missing_original_name + assert error.tenant_id == tenant.id + assert error.error_code == "api_tool_provider_not_found" + + # Assert: Existing provider should remain unchanged + existing_provider: ApiToolProvider | None = ( + db_session_with_containers.query(ApiToolProvider) + .filter(ApiToolProvider.tenant_id == tenant.id, ApiToolProvider.name == existing_provider_name) + .first() + ) + assert existing_provider is not None + assert existing_provider.name == existing_provider_name + + # Assert: No new provider should be created + unexpected_new_provider: ApiToolProvider | None = ( + db_session_with_containers.query(ApiToolProvider) + .filter(ApiToolProvider.tenant_id == tenant.id, ApiToolProvider.name == target_new_name) + .first() + ) + assert unexpected_new_provider is None + + # Assert: Early failure should skip all downstream external interactions + mock_external_service_dependencies["tool_label_manager"].update_tool_labels.assert_not_called() + mock_external_service_dependencies["encrypter"].assert_not_called() + mock_external_service_dependencies["provider_controller"].from_db.assert_not_called() + def test_update_api_tool_provider_missing_auth_type( self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies ): diff --git a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py index ce2fd2eeb1a..ce5c2bd162f 100644 --- a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py +++ b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py @@ -5,9 +5,6 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker -from graphon.model_runtime.entities.llm_entities import LLMMode -from graphon.model_runtime.entities.message_entities import PromptMessageRole -from graphon.variables.input_entities import VariableEntity, VariableEntityType from sqlalchemy.orm import Session from core.app.app_config.entities import ( @@ -21,6 +18,9 @@ from core.app.app_config.entities import ( PromptTemplateEntity, ) from core.prompt.utils.prompt_template_parser import PromptTemplateParser +from graphon.model_runtime.entities.llm_entities import LLMMode +from graphon.model_runtime.entities.message_entities import PromptMessageRole +from graphon.variables.input_entities import VariableEntity, VariableEntityType from models import Account, Tenant from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint from models.model import App, AppMode, AppModelConfig diff --git a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_node_execution_service_repository.py b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_node_execution_service_repository.py index 7c43bf676b0..4dab895135b 100644 --- a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_node_execution_service_repository.py +++ b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_node_execution_service_repository.py @@ -1,10 +1,10 @@ from datetime import datetime, timedelta from uuid import uuid4 -from graphon.enums import WorkflowNodeExecutionStatus from sqlalchemy import Engine, select from sqlalchemy.orm import Session, sessionmaker +from graphon.enums import WorkflowNodeExecutionStatus from libs.datetime_utils import naive_utc_now from models.enums import CreatorUserRole from models.workflow import WorkflowNodeExecutionModel diff --git a/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py index 4b04c1accb1..fcc15aad426 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py @@ -2,6 +2,7 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker +from sqlalchemy import select from sqlalchemy.orm import Session from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType @@ -530,22 +531,18 @@ class TestAddDocumentToIndexTask: redis_client.set(indexing_cache_key, "processing", ex=300) # Verify logs exist before processing - existing_logs = ( - db_session_with_containers.query(DatasetAutoDisableLog) - .where(DatasetAutoDisableLog.document_id == document.id) - .all() - ) + existing_logs = db_session_with_containers.scalars( + select(DatasetAutoDisableLog).where(DatasetAutoDisableLog.document_id == document.id) + ).all() assert len(existing_logs) == 2 # Act: Execute the task add_document_to_index_task(document.id) # Assert: Verify auto disable logs were deleted - remaining_logs = ( - db_session_with_containers.query(DatasetAutoDisableLog) - .where(DatasetAutoDisableLog.document_id == document.id) - .all() - ) + remaining_logs = db_session_with_containers.scalars( + select(DatasetAutoDisableLog).where(DatasetAutoDisableLog.document_id == document.id) + ).all() assert len(remaining_logs) == 0 # Verify index processing occurred normally diff --git a/api/tests/test_containers_integration_tests/tasks/test_batch_clean_document_task.py b/api/tests/test_containers_integration_tests/tasks/test_batch_clean_document_task.py index 6cbbe431370..e29ca7ebab4 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_batch_clean_document_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_batch_clean_document_task.py @@ -11,6 +11,7 @@ from unittest.mock import Mock, patch import pytest from faker import Faker +from sqlalchemy import func, select from sqlalchemy.orm import Session from core.rag.index_processor.constant.index_type import IndexStructureType @@ -267,11 +268,13 @@ class TestBatchCleanDocumentTask: db_session_with_containers.commit() # Ensure all changes are committed # Check that segment is deleted - deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first() + deleted_segment = db_session_with_containers.scalar( + select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1) + ) assert deleted_segment is None # Check that upload file is deleted - deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first() + deleted_file = db_session_with_containers.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1)) assert deleted_file is None def test_batch_clean_document_task_with_image_files( @@ -319,7 +322,9 @@ class TestBatchCleanDocumentTask: db_session_with_containers.commit() # Check that segment is deleted - deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first() + deleted_segment = db_session_with_containers.scalar( + select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1) + ) assert deleted_segment is None # Verify that the task completed successfully by checking the log output @@ -360,14 +365,14 @@ class TestBatchCleanDocumentTask: db_session_with_containers.commit() # Check that upload file is deleted - deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first() + deleted_file = db_session_with_containers.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1)) assert deleted_file is None # Verify database cleanup db_session_with_containers.commit() # Check that upload file is deleted - deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first() + deleted_file = db_session_with_containers.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1)) assert deleted_file is None def test_batch_clean_document_task_dataset_not_found( @@ -410,7 +415,9 @@ class TestBatchCleanDocumentTask: db_session_with_containers.commit() # Document should still exist since cleanup failed - existing_document = db_session_with_containers.query(Document).filter_by(id=document_id).first() + existing_document = db_session_with_containers.scalar( + select(Document).where(Document.id == document_id).limit(1) + ) assert existing_document is not None def test_batch_clean_document_task_storage_cleanup_failure( @@ -453,11 +460,13 @@ class TestBatchCleanDocumentTask: db_session_with_containers.commit() # Check that segment is deleted from database - deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first() + deleted_segment = db_session_with_containers.scalar( + select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1) + ) assert deleted_segment is None # Check that upload file is deleted from database - deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first() + deleted_file = db_session_with_containers.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1)) assert deleted_file is None def test_batch_clean_document_task_multiple_documents( @@ -510,12 +519,16 @@ class TestBatchCleanDocumentTask: # Check that all segments are deleted for segment_id in segment_ids: - deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first() + deleted_segment = db_session_with_containers.scalar( + select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1) + ) assert deleted_segment is None # Check that all upload files are deleted for file_id in file_ids: - deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first() + deleted_file = db_session_with_containers.scalar( + select(UploadFile).where(UploadFile.id == file_id).limit(1) + ) assert deleted_file is None def test_batch_clean_document_task_different_doc_forms( @@ -564,7 +577,9 @@ class TestBatchCleanDocumentTask: db_session_with_containers.commit() # Check that segment is deleted - deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first() + deleted_segment = db_session_with_containers.scalar( + select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1) + ) assert deleted_segment is None except Exception as e: @@ -574,7 +589,9 @@ class TestBatchCleanDocumentTask: db_session_with_containers.commit() # Check if the segment still exists (task may have failed before deletion) - existing_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first() + existing_segment = db_session_with_containers.scalar( + select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1) + ) if existing_segment is not None: # If segment still exists, the task failed before deletion # This is acceptable in test environments with external service issues @@ -645,12 +662,16 @@ class TestBatchCleanDocumentTask: # Check that all segments are deleted for segment_id in segment_ids: - deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first() + deleted_segment = db_session_with_containers.scalar( + select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1) + ) assert deleted_segment is None # Check that all upload files are deleted for file_id in file_ids: - deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first() + deleted_file = db_session_with_containers.scalar( + select(UploadFile).where(UploadFile.id == file_id).limit(1) + ) assert deleted_file is None def test_batch_clean_document_task_integration_with_real_database( @@ -699,8 +720,16 @@ class TestBatchCleanDocumentTask: db_session_with_containers.commit() # Verify initial state - assert db_session_with_containers.query(DocumentSegment).filter_by(document_id=document.id).count() == 3 - assert db_session_with_containers.query(UploadFile).filter_by(id=upload_file.id).first() is not None + assert ( + db_session_with_containers.scalar( + select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id == document.id) + ) + == 3 + ) + assert ( + db_session_with_containers.scalar(select(UploadFile).where(UploadFile.id == upload_file.id).limit(1)) + is not None + ) # Store original IDs for verification document_id = document.id @@ -720,13 +749,20 @@ class TestBatchCleanDocumentTask: # Check that all segments are deleted for segment_id in segment_ids: - deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first() + deleted_segment = db_session_with_containers.scalar( + select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1) + ) assert deleted_segment is None # Check that upload file is deleted - deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first() + deleted_file = db_session_with_containers.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1)) assert deleted_file is None # Verify final database state - assert db_session_with_containers.query(DocumentSegment).filter_by(document_id=document_id).count() == 0 - assert db_session_with_containers.query(UploadFile).filter_by(id=file_id).first() is None + assert ( + db_session_with_containers.scalar( + select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id == document_id) + ) + == 0 + ) + assert db_session_with_containers.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1)) is None diff --git a/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py index f9ae33b32ff..05827112d4a 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py @@ -17,6 +17,7 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker +from sqlalchemy import delete, select from sqlalchemy.orm import Session from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType @@ -37,13 +38,13 @@ class TestBatchCreateSegmentToIndexTask: from extensions.ext_redis import redis_client # Clear all test data - db_session_with_containers.query(DocumentSegment).delete() - db_session_with_containers.query(Document).delete() - db_session_with_containers.query(Dataset).delete() - db_session_with_containers.query(UploadFile).delete() - db_session_with_containers.query(TenantAccountJoin).delete() - db_session_with_containers.query(Tenant).delete() - db_session_with_containers.query(Account).delete() + db_session_with_containers.execute(delete(DocumentSegment)) + db_session_with_containers.execute(delete(Document)) + db_session_with_containers.execute(delete(Dataset)) + db_session_with_containers.execute(delete(UploadFile)) + db_session_with_containers.execute(delete(TenantAccountJoin)) + db_session_with_containers.execute(delete(Tenant)) + db_session_with_containers.execute(delete(Account)) db_session_with_containers.commit() # Clear Redis cache @@ -292,12 +293,9 @@ class TestBatchCreateSegmentToIndexTask: # Verify results # Check that segments were created - segments = ( - db_session_with_containers.query(DocumentSegment) - .filter_by(document_id=document.id) - .order_by(DocumentSegment.position) - .all() - ) + segments = db_session_with_containers.scalars( + select(DocumentSegment).where(DocumentSegment.document_id == document.id).order_by(DocumentSegment.position) + ).all() assert len(segments) == 3 # Verify segment content and metadata @@ -367,11 +365,11 @@ class TestBatchCreateSegmentToIndexTask: # Verify no segments were created (since dataset doesn't exist) - segments = db_session_with_containers.query(DocumentSegment).all() + segments = db_session_with_containers.scalars(select(DocumentSegment)).all() assert len(segments) == 0 # Verify no documents were modified - documents = db_session_with_containers.query(Document).all() + documents = db_session_with_containers.scalars(select(Document)).all() assert len(documents) == 0 def test_batch_create_segment_to_index_task_document_not_found( @@ -415,12 +413,14 @@ class TestBatchCreateSegmentToIndexTask: # Verify no segments were created - segments = db_session_with_containers.query(DocumentSegment).all() + segments = db_session_with_containers.scalars(select(DocumentSegment)).all() assert len(segments) == 0 # Verify dataset remains unchanged (no segments were added to the dataset) db_session_with_containers.refresh(dataset) - segments_for_dataset = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all() + segments_for_dataset = db_session_with_containers.scalars( + select(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id) + ).all() assert len(segments_for_dataset) == 0 def test_batch_create_segment_to_index_task_document_not_available( @@ -516,7 +516,9 @@ class TestBatchCreateSegmentToIndexTask: assert cache_value == b"error" # Verify no segments were created - segments = db_session_with_containers.query(DocumentSegment).filter_by(document_id=document.id).all() + segments = db_session_with_containers.scalars( + select(DocumentSegment).where(DocumentSegment.document_id == document.id) + ).all() assert len(segments) == 0 def test_batch_create_segment_to_index_task_upload_file_not_found( @@ -560,7 +562,7 @@ class TestBatchCreateSegmentToIndexTask: # Verify no segments were created - segments = db_session_with_containers.query(DocumentSegment).all() + segments = db_session_with_containers.scalars(select(DocumentSegment)).all() assert len(segments) == 0 # Verify document remains unchanged @@ -611,7 +613,7 @@ class TestBatchCreateSegmentToIndexTask: # Verify error handling # Since exception was raised, no segments should be created - segments = db_session_with_containers.query(DocumentSegment).all() + segments = db_session_with_containers.scalars(select(DocumentSegment)).all() assert len(segments) == 0 # Verify document remains unchanged @@ -682,12 +684,9 @@ class TestBatchCreateSegmentToIndexTask: # Verify results # Check that new segments were created with correct positions - all_segments = ( - db_session_with_containers.query(DocumentSegment) - .filter_by(document_id=document.id) - .order_by(DocumentSegment.position) - .all() - ) + all_segments = db_session_with_containers.scalars( + select(DocumentSegment).where(DocumentSegment.document_id == document.id).order_by(DocumentSegment.position) + ).all() assert len(all_segments) == 6 # 3 existing + 3 new # Verify position ordering diff --git a/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py b/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py index 1dd37fbc92c..32bc2fc0bd7 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py @@ -16,6 +16,7 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker +from sqlalchemy import delete, select from sqlalchemy.orm import Session from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType @@ -52,18 +53,18 @@ class TestCleanDatasetTask: from extensions.ext_redis import redis_client # Clear all test data using the provided session fixture - db_session_with_containers.query(DatasetMetadataBinding).delete() - db_session_with_containers.query(DatasetMetadata).delete() - db_session_with_containers.query(AppDatasetJoin).delete() - db_session_with_containers.query(DatasetQuery).delete() - db_session_with_containers.query(DatasetProcessRule).delete() - db_session_with_containers.query(DocumentSegment).delete() - db_session_with_containers.query(Document).delete() - db_session_with_containers.query(Dataset).delete() - db_session_with_containers.query(UploadFile).delete() - db_session_with_containers.query(TenantAccountJoin).delete() - db_session_with_containers.query(Tenant).delete() - db_session_with_containers.query(Account).delete() + db_session_with_containers.execute(delete(DatasetMetadataBinding)) + db_session_with_containers.execute(delete(DatasetMetadata)) + db_session_with_containers.execute(delete(AppDatasetJoin)) + db_session_with_containers.execute(delete(DatasetQuery)) + db_session_with_containers.execute(delete(DatasetProcessRule)) + db_session_with_containers.execute(delete(DocumentSegment)) + db_session_with_containers.execute(delete(Document)) + db_session_with_containers.execute(delete(Dataset)) + db_session_with_containers.execute(delete(UploadFile)) + db_session_with_containers.execute(delete(TenantAccountJoin)) + db_session_with_containers.execute(delete(Tenant)) + db_session_with_containers.execute(delete(Account)) db_session_with_containers.commit() # Clear Redis cache @@ -302,28 +303,40 @@ class TestCleanDatasetTask: # Verify results # Check that dataset-related data was cleaned up - documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all() + documents = db_session_with_containers.scalars(select(Document).where(Document.dataset_id == dataset.id)).all() assert len(documents) == 0 - segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all() + segments = db_session_with_containers.scalars( + select(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id) + ).all() assert len(segments) == 0 # Check that metadata and bindings were cleaned up - metadata = db_session_with_containers.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all() + metadata = db_session_with_containers.scalars( + select(DatasetMetadata).where(DatasetMetadata.dataset_id == dataset.id) + ).all() assert len(metadata) == 0 - bindings = db_session_with_containers.query(DatasetMetadataBinding).filter_by(dataset_id=dataset.id).all() + bindings = db_session_with_containers.scalars( + select(DatasetMetadataBinding).where(DatasetMetadataBinding.dataset_id == dataset.id) + ).all() assert len(bindings) == 0 # Check that process rules and queries were cleaned up - process_rules = db_session_with_containers.query(DatasetProcessRule).filter_by(dataset_id=dataset.id).all() + process_rules = db_session_with_containers.scalars( + select(DatasetProcessRule).where(DatasetProcessRule.dataset_id == dataset.id) + ).all() assert len(process_rules) == 0 - queries = db_session_with_containers.query(DatasetQuery).filter_by(dataset_id=dataset.id).all() + queries = db_session_with_containers.scalars( + select(DatasetQuery).where(DatasetQuery.dataset_id == dataset.id) + ).all() assert len(queries) == 0 # Check that app dataset joins were cleaned up - app_joins = db_session_with_containers.query(AppDatasetJoin).filter_by(dataset_id=dataset.id).all() + app_joins = db_session_with_containers.scalars( + select(AppDatasetJoin).where(AppDatasetJoin.dataset_id == dataset.id) + ).all() assert len(app_joins) == 0 # Verify index processor was called @@ -414,24 +427,32 @@ class TestCleanDatasetTask: # Verify results # Check that all documents were deleted - remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all() + remaining_documents = db_session_with_containers.scalars( + select(Document).where(Document.dataset_id == dataset.id) + ).all() assert len(remaining_documents) == 0 # Check that all segments were deleted - remaining_segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all() + remaining_segments = db_session_with_containers.scalars( + select(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id) + ).all() assert len(remaining_segments) == 0 # Check that all upload files were deleted - remaining_files = db_session_with_containers.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).all() + remaining_files = db_session_with_containers.scalars( + select(UploadFile).where(UploadFile.id.in_(upload_file_ids)) + ).all() assert len(remaining_files) == 0 # Check that metadata and bindings were cleaned up - remaining_metadata = db_session_with_containers.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all() + remaining_metadata = db_session_with_containers.scalars( + select(DatasetMetadata).where(DatasetMetadata.dataset_id == dataset.id) + ).all() assert len(remaining_metadata) == 0 - remaining_bindings = ( - db_session_with_containers.query(DatasetMetadataBinding).filter_by(dataset_id=dataset.id).all() - ) + remaining_bindings = db_session_with_containers.scalars( + select(DatasetMetadataBinding).where(DatasetMetadataBinding.dataset_id == dataset.id) + ).all() assert len(remaining_bindings) == 0 # Verify index processor was called @@ -485,12 +506,14 @@ class TestCleanDatasetTask: # Check that all data was cleaned up - remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all() + remaining_documents = db_session_with_containers.scalars( + select(Document).where(Document.dataset_id == dataset.id) + ).all() assert len(remaining_documents) == 0 - remaining_segments = ( - db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all() - ) + remaining_segments = db_session_with_containers.scalars( + select(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id) + ).all() assert len(remaining_segments) == 0 # Recreate data for next test case @@ -538,11 +561,15 @@ class TestCleanDatasetTask: # Verify results - even with vector cleanup failure, documents and segments should be deleted # Check that documents were still deleted despite vector cleanup failure - remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all() + remaining_documents = db_session_with_containers.scalars( + select(Document).where(Document.dataset_id == dataset.id) + ).all() assert len(remaining_documents) == 0 # Check that segments were still deleted despite vector cleanup failure - remaining_segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all() + remaining_segments = db_session_with_containers.scalars( + select(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id) + ).all() assert len(remaining_segments) == 0 # Verify that index processor was called and failed @@ -622,18 +649,22 @@ class TestCleanDatasetTask: # Verify results # Check that all documents were deleted - remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all() + remaining_documents = db_session_with_containers.scalars( + select(Document).where(Document.dataset_id == dataset.id) + ).all() assert len(remaining_documents) == 0 # Check that all segments were deleted - remaining_segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all() + remaining_segments = db_session_with_containers.scalars( + select(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id) + ).all() assert len(remaining_segments) == 0 # Check that all image files were deleted from database image_file_ids = [f.id for f in image_files] - remaining_image_files = ( - db_session_with_containers.query(UploadFile).where(UploadFile.id.in_(image_file_ids)).all() - ) + remaining_image_files = db_session_with_containers.scalars( + select(UploadFile).where(UploadFile.id.in_(image_file_ids)) + ).all() assert len(remaining_image_files) == 0 # Verify that storage.delete was called for each image file @@ -738,24 +769,32 @@ class TestCleanDatasetTask: # Verify results # Check that all documents were deleted - remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all() + remaining_documents = db_session_with_containers.scalars( + select(Document).where(Document.dataset_id == dataset.id) + ).all() assert len(remaining_documents) == 0 # Check that all segments were deleted - remaining_segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all() + remaining_segments = db_session_with_containers.scalars( + select(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id) + ).all() assert len(remaining_segments) == 0 # Check that all upload files were deleted - remaining_files = db_session_with_containers.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).all() + remaining_files = db_session_with_containers.scalars( + select(UploadFile).where(UploadFile.id.in_(upload_file_ids)) + ).all() assert len(remaining_files) == 0 # Check that all metadata and bindings were deleted - remaining_metadata = db_session_with_containers.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all() + remaining_metadata = db_session_with_containers.scalars( + select(DatasetMetadata).where(DatasetMetadata.dataset_id == dataset.id) + ).all() assert len(remaining_metadata) == 0 - remaining_bindings = ( - db_session_with_containers.query(DatasetMetadataBinding).filter_by(dataset_id=dataset.id).all() - ) + remaining_bindings = db_session_with_containers.scalars( + select(DatasetMetadataBinding).where(DatasetMetadataBinding.dataset_id == dataset.id) + ).all() assert len(remaining_bindings) == 0 # Verify performance expectations @@ -826,7 +865,9 @@ class TestCleanDatasetTask: # Check that upload file was still deleted from database despite storage failure # Note: When storage operations fail, the upload file may not be deleted # This demonstrates that the cleanup process continues even with storage errors - remaining_files = db_session_with_containers.query(UploadFile).filter_by(id=upload_file.id).all() + remaining_files = db_session_with_containers.scalars( + select(UploadFile).where(UploadFile.id == upload_file.id) + ).all() # The upload file should still be deleted from the database even if storage cleanup fails # However, this depends on the specific implementation of clean_dataset_task if len(remaining_files) > 0: @@ -976,19 +1017,27 @@ class TestCleanDatasetTask: # Verify results # Check that all documents were deleted - remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all() + remaining_documents = db_session_with_containers.scalars( + select(Document).where(Document.dataset_id == dataset.id) + ).all() assert len(remaining_documents) == 0 # Check that all segments were deleted - remaining_segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all() + remaining_segments = db_session_with_containers.scalars( + select(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id) + ).all() assert len(remaining_segments) == 0 # Check that all upload files were deleted - remaining_files = db_session_with_containers.query(UploadFile).filter_by(id=upload_file_id).all() + remaining_files = db_session_with_containers.scalars( + select(UploadFile).where(UploadFile.id == upload_file_id) + ).all() assert len(remaining_files) == 0 # Check that all metadata was deleted - remaining_metadata = db_session_with_containers.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all() + remaining_metadata = db_session_with_containers.scalars( + select(DatasetMetadata).where(DatasetMetadata.dataset_id == dataset.id) + ).all() assert len(remaining_metadata) == 0 # Verify that storage.delete was called diff --git a/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py b/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py index 926c839c8b6..fa3ac12cf00 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py @@ -11,6 +11,8 @@ from unittest.mock import Mock, patch import pytest from faker import Faker +from sqlalchemy import ColumnElement, func, select +from sqlalchemy.orm import Session from core.rag.index_processor.constant.index_type import IndexStructureType from models.dataset import Dataset, Document, DocumentSegment @@ -20,6 +22,14 @@ from tasks.clean_notion_document_task import clean_notion_document_task from tests.test_containers_integration_tests.helpers import generate_valid_password +def _count_documents(session: Session, condition: ColumnElement[bool]) -> int: + return session.scalar(select(func.count()).select_from(Document).where(condition)) or 0 + + +def _count_segments(session: Session, condition: ColumnElement[bool]) -> int: + return session.scalar(select(func.count()).select_from(DocumentSegment).where(condition)) or 0 + + class TestCleanNotionDocumentTask: """Integration tests for clean_notion_document_task using testcontainers.""" @@ -145,24 +155,14 @@ class TestCleanNotionDocumentTask: db_session_with_containers.commit() # Verify data exists before cleanup - assert db_session_with_containers.query(Document).filter(Document.id.in_(document_ids)).count() == 3 - assert ( - db_session_with_containers.query(DocumentSegment) - .filter(DocumentSegment.document_id.in_(document_ids)) - .count() - == 6 - ) + assert _count_documents(db_session_with_containers, Document.id.in_(document_ids)) == 3 + assert _count_segments(db_session_with_containers, DocumentSegment.document_id.in_(document_ids)) == 6 # Execute cleanup task clean_notion_document_task(document_ids, dataset.id) # Verify segments are deleted - assert ( - db_session_with_containers.query(DocumentSegment) - .filter(DocumentSegment.document_id.in_(document_ids)) - .count() - == 0 - ) + assert _count_segments(db_session_with_containers, DocumentSegment.document_id.in_(document_ids)) == 0 # Verify index processor was called mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value @@ -322,12 +322,7 @@ class TestCleanNotionDocumentTask: # The task properly handles various index types and document configurations. # Verify segments are deleted - assert ( - db_session_with_containers.query(DocumentSegment) - .filter(DocumentSegment.document_id == document.id) - .count() - == 0 - ) + assert _count_segments(db_session_with_containers, DocumentSegment.document_id == document.id) == 0 # Reset mock for next iteration mock_index_processor_factory.reset_mock() @@ -410,10 +405,7 @@ class TestCleanNotionDocumentTask: clean_notion_document_task([document.id], dataset.id) # Verify segments are deleted - assert ( - db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).count() - == 0 - ) + assert _count_segments(db_session_with_containers, DocumentSegment.document_id == document.id) == 0 # Note: This test successfully verifies that segments without index_node_ids # are properly deleted from the database. @@ -499,11 +491,8 @@ class TestCleanNotionDocumentTask: db_session_with_containers.commit() # Verify all data exists before cleanup - assert db_session_with_containers.query(Document).filter(Document.dataset_id == dataset.id).count() == 5 - assert ( - db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset.id).count() - == 10 - ) + assert _count_documents(db_session_with_containers, Document.dataset_id == dataset.id) == 5 + assert _count_segments(db_session_with_containers, DocumentSegment.dataset_id == dataset.id) == 10 # Clean up only first 3 documents documents_to_clean = [doc.id for doc in documents[:3]] @@ -513,22 +502,12 @@ class TestCleanNotionDocumentTask: clean_notion_document_task(documents_to_clean, dataset.id) # Verify only specified documents' segments are deleted - assert ( - db_session_with_containers.query(DocumentSegment) - .filter(DocumentSegment.document_id.in_(documents_to_clean)) - .count() - == 0 - ) + assert _count_segments(db_session_with_containers, DocumentSegment.document_id.in_(documents_to_clean)) == 0 # Verify remaining documents and segments are intact remaining_docs = [doc.id for doc in documents[3:]] - assert db_session_with_containers.query(Document).filter(Document.id.in_(remaining_docs)).count() == 2 - assert ( - db_session_with_containers.query(DocumentSegment) - .filter(DocumentSegment.document_id.in_(remaining_docs)) - .count() - == 4 - ) + assert _count_documents(db_session_with_containers, Document.id.in_(remaining_docs)) == 2 + assert _count_segments(db_session_with_containers, DocumentSegment.document_id.in_(remaining_docs)) == 4 # Note: This test successfully verifies partial document cleanup operations. # The database operations work correctly, isolating only the specified documents. @@ -612,19 +591,13 @@ class TestCleanNotionDocumentTask: db_session_with_containers.commit() # Verify all segments exist before cleanup - assert ( - db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).count() - == 4 - ) + assert _count_segments(db_session_with_containers, DocumentSegment.document_id == document.id) == 4 # Execute cleanup task clean_notion_document_task([document.id], dataset.id) # Verify all segments are deleted regardless of status - assert ( - db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).count() - == 0 - ) + assert _count_segments(db_session_with_containers, DocumentSegment.document_id == document.id) == 0 # Note: This test successfully verifies database operations. # IndexProcessor verification would require more sophisticated mocking. @@ -794,12 +767,9 @@ class TestCleanNotionDocumentTask: db_session_with_containers.commit() # Verify all data exists before cleanup + assert _count_documents(db_session_with_containers, Document.dataset_id == dataset.id) == num_documents assert ( - db_session_with_containers.query(Document).filter(Document.dataset_id == dataset.id).count() - == num_documents - ) - assert ( - db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset.id).count() + _count_segments(db_session_with_containers, DocumentSegment.dataset_id == dataset.id) == num_documents * num_segments_per_doc ) @@ -808,10 +778,7 @@ class TestCleanNotionDocumentTask: clean_notion_document_task(all_document_ids, dataset.id) # Verify all segments are deleted - assert ( - db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset.id).count() - == 0 - ) + assert _count_segments(db_session_with_containers, DocumentSegment.dataset_id == dataset.id) == 0 # Note: This test successfully verifies bulk document cleanup operations. # The database efficiently handles large-scale deletions. @@ -906,8 +873,8 @@ class TestCleanNotionDocumentTask: # Verify all data exists before cleanup # Note: There may be documents from previous tests, so we check for at least 3 - assert db_session_with_containers.query(Document).count() >= 3 - assert db_session_with_containers.query(DocumentSegment).count() >= 9 + assert db_session_with_containers.scalar(select(func.count()).select_from(Document)) >= 3 + assert db_session_with_containers.scalar(select(func.count()).select_from(DocumentSegment)) >= 9 # Clean up documents from only the first dataset target_dataset = datasets[0] @@ -918,22 +885,12 @@ class TestCleanNotionDocumentTask: clean_notion_document_task([target_document.id], target_dataset.id) # Verify only documents' segments from target dataset are deleted - assert ( - db_session_with_containers.query(DocumentSegment) - .filter(DocumentSegment.document_id == target_document.id) - .count() - == 0 - ) + assert _count_segments(db_session_with_containers, DocumentSegment.document_id == target_document.id) == 0 # Verify documents from other datasets remain intact remaining_docs = [doc.id for doc in all_documents[1:]] - assert db_session_with_containers.query(Document).filter(Document.id.in_(remaining_docs)).count() == 2 - assert ( - db_session_with_containers.query(DocumentSegment) - .filter(DocumentSegment.document_id.in_(remaining_docs)) - .count() - == 6 - ) + assert _count_documents(db_session_with_containers, Document.id.in_(remaining_docs)) == 2 + assert _count_segments(db_session_with_containers, DocumentSegment.document_id.in_(remaining_docs)) == 6 # Note: This test successfully verifies multi-tenant isolation. # Only documents from the target dataset are affected, maintaining tenant separation. @@ -1028,11 +985,9 @@ class TestCleanNotionDocumentTask: db_session_with_containers.commit() # Verify all data exists before cleanup - assert db_session_with_containers.query(Document).filter(Document.dataset_id == dataset.id).count() == len( - document_statuses - ) + assert _count_documents(db_session_with_containers, Document.dataset_id == dataset.id) == len(document_statuses) assert ( - db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset.id).count() + _count_segments(db_session_with_containers, DocumentSegment.dataset_id == dataset.id) == len(document_statuses) * 2 ) @@ -1041,10 +996,7 @@ class TestCleanNotionDocumentTask: clean_notion_document_task(all_document_ids, dataset.id) # Verify all segments are deleted regardless of status - assert ( - db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset.id).count() - == 0 - ) + assert _count_segments(db_session_with_containers, DocumentSegment.dataset_id == dataset.id) == 0 # Note: This test successfully verifies cleanup of documents in various states. # All documents are deleted regardless of their indexing status. @@ -1142,20 +1094,14 @@ class TestCleanNotionDocumentTask: db_session_with_containers.commit() # Verify data exists before cleanup - assert db_session_with_containers.query(Document).filter(Document.id == document.id).count() == 1 - assert ( - db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).count() - == 3 - ) + assert _count_documents(db_session_with_containers, Document.id == document.id) == 1 + assert _count_segments(db_session_with_containers, DocumentSegment.document_id == document.id) == 3 # Execute cleanup task clean_notion_document_task([document.id], dataset.id) # Verify segments are deleted - assert ( - db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).count() - == 0 - ) + assert _count_segments(db_session_with_containers, DocumentSegment.document_id == document.id) == 0 # Note: This test successfully verifies cleanup of documents with rich metadata. # The task properly handles complex document structures and metadata fields. diff --git a/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py index 9f8e37fc9ec..9084667c314 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py @@ -11,6 +11,7 @@ from uuid import uuid4 import pytest from faker import Faker +from sqlalchemy import delete from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from extensions.ext_redis import redis_client @@ -28,12 +29,12 @@ class TestCreateSegmentToIndexTask: """Clean up database and Redis before each test to ensure isolation.""" # Clear all test data using fixture session - db_session_with_containers.query(DocumentSegment).delete() - db_session_with_containers.query(Document).delete() - db_session_with_containers.query(Dataset).delete() - db_session_with_containers.query(TenantAccountJoin).delete() - db_session_with_containers.query(Tenant).delete() - db_session_with_containers.query(Account).delete() + db_session_with_containers.execute(delete(DocumentSegment)) + db_session_with_containers.execute(delete(Document)) + db_session_with_containers.execute(delete(Dataset)) + db_session_with_containers.execute(delete(TenantAccountJoin)) + db_session_with_containers.execute(delete(Tenant)) + db_session_with_containers.execute(delete(Account)) db_session_with_containers.commit() # Clear Redis cache diff --git a/api/tests/test_containers_integration_tests/tasks/test_dataset_indexing_task.py b/api/tests/test_containers_integration_tests/tasks/test_dataset_indexing_task.py index 13ea94348a6..684097851b6 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_dataset_indexing_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_dataset_indexing_task.py @@ -6,6 +6,7 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker +from sqlalchemy import select from core.indexing_runner import DocumentIsPausedError from core.rag.index_processor.constant.index_type import IndexTechniqueType @@ -175,7 +176,7 @@ class TestDatasetIndexingTaskIntegration: def _query_document(self, db_session_with_containers, document_id: str) -> Document | None: """Return the latest persisted document state.""" - return db_session_with_containers.query(Document).where(Document.id == document_id).first() + return db_session_with_containers.scalar(select(Document).where(Document.id == document_id).limit(1)) def _assert_documents_parsing(self, db_session_with_containers, document_ids: Sequence[str]) -> None: """Assert all target documents are persisted in parsing status.""" diff --git a/api/tests/test_containers_integration_tests/tasks/test_deal_dataset_vector_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_deal_dataset_vector_index_task.py index d457b59d588..48fec441c58 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_deal_dataset_vector_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_deal_dataset_vector_index_task.py @@ -11,6 +11,7 @@ from unittest.mock import ANY, Mock, patch import pytest from faker import Faker +from sqlalchemy import select from core.rag.index_processor.constant.index_type import IndexStructureType from models.dataset import Dataset, Document, DocumentSegment @@ -221,7 +222,9 @@ class TestDealDatasetVectorIndexTask: deal_dataset_vector_index_task(dataset.id, "add") # Verify document status was updated to indexing then completed - updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first() + updated_document = db_session_with_containers.scalar( + select(Document).where(Document.id == document.id).limit(1) + ) assert updated_document.indexing_status == IndexingStatus.COMPLETED # Verify index processor load method was called @@ -322,7 +325,9 @@ class TestDealDatasetVectorIndexTask: deal_dataset_vector_index_task(dataset.id, "update") # Verify document status was updated to indexing then completed - updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first() + updated_document = db_session_with_containers.scalar( + select(Document).where(Document.id == document.id).limit(1) + ) assert updated_document.indexing_status == IndexingStatus.COMPLETED # Verify index processor clean and load methods were called @@ -431,7 +436,9 @@ class TestDealDatasetVectorIndexTask: deal_dataset_vector_index_task(dataset.id, "add") # Verify document status was updated to indexing then completed - updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first() + updated_document = db_session_with_containers.scalar( + select(Document).where(Document.id == document.id).limit(1) + ) assert updated_document.indexing_status == IndexingStatus.COMPLETED # Verify that no index processor load was called since no segments exist @@ -564,7 +571,9 @@ class TestDealDatasetVectorIndexTask: deal_dataset_vector_index_task(dataset.id, "add") # Verify document status was updated to error - updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first() + updated_document = db_session_with_containers.scalar( + select(Document).where(Document.id == document.id).limit(1) + ) assert updated_document.indexing_status == IndexingStatus.ERROR assert "Test exception during indexing" in updated_document.error @@ -635,7 +644,9 @@ class TestDealDatasetVectorIndexTask: deal_dataset_vector_index_task(dataset.id, "add") # Verify document status was updated to indexing then completed - updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first() + updated_document = db_session_with_containers.scalar( + select(Document).where(Document.id == document.id).limit(1) + ) assert updated_document.indexing_status == IndexingStatus.COMPLETED # Verify index processor was initialized with custom index type @@ -711,7 +722,9 @@ class TestDealDatasetVectorIndexTask: deal_dataset_vector_index_task(dataset.id, "add") # Verify document status was updated to indexing then completed - updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first() + updated_document = db_session_with_containers.scalar( + select(Document).where(Document.id == document.id).limit(1) + ) assert updated_document.indexing_status == IndexingStatus.COMPLETED # Verify index processor was initialized with the document's index type @@ -815,7 +828,9 @@ class TestDealDatasetVectorIndexTask: # Verify all documents were processed for document in documents: - updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first() + updated_document = db_session_with_containers.scalar( + select(Document).where(Document.id == document.id).limit(1) + ) assert updated_document.indexing_status == IndexingStatus.COMPLETED # Verify index processor load was called multiple times @@ -917,7 +932,9 @@ class TestDealDatasetVectorIndexTask: deal_dataset_vector_index_task(dataset.id, "add") # Verify final document status - updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first() + updated_document = db_session_with_containers.scalar( + select(Document).where(Document.id == document.id).limit(1) + ) assert updated_document.indexing_status == IndexingStatus.COMPLETED def test_deal_dataset_vector_index_task_with_disabled_documents( @@ -1027,12 +1044,14 @@ class TestDealDatasetVectorIndexTask: deal_dataset_vector_index_task(dataset.id, "add") # Verify only enabled document was processed - updated_enabled_document = db_session_with_containers.query(Document).filter_by(id=enabled_document.id).first() + updated_enabled_document = db_session_with_containers.scalar( + select(Document).where(Document.id == enabled_document.id).limit(1) + ) assert updated_enabled_document.indexing_status == IndexingStatus.COMPLETED # Verify disabled document status remains unchanged - updated_disabled_document = ( - db_session_with_containers.query(Document).filter_by(id=disabled_document.id).first() + updated_disabled_document = db_session_with_containers.scalar( + select(Document).where(Document.id == disabled_document.id).limit(1) ) assert updated_disabled_document.indexing_status == IndexingStatus.COMPLETED # Should not change @@ -1148,12 +1167,14 @@ class TestDealDatasetVectorIndexTask: deal_dataset_vector_index_task(dataset.id, "add") # Verify only active document was processed - updated_active_document = db_session_with_containers.query(Document).filter_by(id=active_document.id).first() + updated_active_document = db_session_with_containers.scalar( + select(Document).where(Document.id == active_document.id).limit(1) + ) assert updated_active_document.indexing_status == IndexingStatus.COMPLETED # Verify archived document status remains unchanged - updated_archived_document = ( - db_session_with_containers.query(Document).filter_by(id=archived_document.id).first() + updated_archived_document = db_session_with_containers.scalar( + select(Document).where(Document.id == archived_document.id).limit(1) ) assert updated_archived_document.indexing_status == IndexingStatus.COMPLETED # Should not change @@ -1269,14 +1290,14 @@ class TestDealDatasetVectorIndexTask: deal_dataset_vector_index_task(dataset.id, "add") # Verify only completed document was processed - updated_completed_document = ( - db_session_with_containers.query(Document).filter_by(id=completed_document.id).first() + updated_completed_document = db_session_with_containers.scalar( + select(Document).where(Document.id == completed_document.id).limit(1) ) assert updated_completed_document.indexing_status == IndexingStatus.COMPLETED # Verify incomplete document status remains unchanged - updated_incomplete_document = ( - db_session_with_containers.query(Document).filter_by(id=incomplete_document.id).first() + updated_incomplete_document = db_session_with_containers.scalar( + select(Document).where(Document.id == incomplete_document.id).limit(1) ) assert updated_incomplete_document.indexing_status == IndexingStatus.INDEXING # Should not change diff --git a/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py index 3e9a0c8f7f0..6e03bd93514 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py @@ -9,6 +9,7 @@ The task is responsible for removing document segments from the search index whe from unittest.mock import MagicMock, patch from faker import Faker +from sqlalchemy import select from sqlalchemy.orm import Session from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType @@ -471,9 +472,9 @@ class TestDisableSegmentsFromIndexTask: db_session_with_containers.refresh(segments[1]) # Check that segments are re-enabled after error - updated_segments = ( - db_session_with_containers.query(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)).all() - ) + updated_segments = db_session_with_containers.scalars( + select(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)) + ).all() for segment in updated_segments: assert segment.enabled is True diff --git a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_sync_task.py b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_sync_task.py index d4021143eff..b6e7e6e5c92 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_sync_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_sync_task.py @@ -12,10 +12,11 @@ from unittest.mock import Mock, patch from uuid import uuid4 import pytest +from sqlalchemy import delete, func, select, update from core.indexing_runner import DocumentIsPausedError, IndexingRunner from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType -from models import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models import Account, AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole, TenantStatus from models.dataset import Dataset, Document, DocumentSegment from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus from tasks.document_indexing_sync_task import document_indexing_sync_task @@ -30,12 +31,12 @@ class DocumentIndexingSyncTaskTestDataFactory: email=f"{uuid4()}@example.com", name=f"user-{uuid4()}", interface_language="en-US", - status="active", + status=AccountStatus.ACTIVE, ) db_session_with_containers.add(account) db_session_with_containers.flush() - tenant = Tenant(name=f"tenant-{account.id}", status="normal") + tenant = Tenant(name=f"tenant-{account.id}", status=TenantStatus.NORMAL) db_session_with_containers.add(tenant) db_session_with_containers.flush() @@ -254,8 +255,8 @@ class TestDocumentIndexingSyncTask: """Test that task raises error when data_source_info is empty.""" # Arrange context = self._create_notion_sync_context(db_session_with_containers, data_source_info=None) - db_session_with_containers.query(Document).where(Document.id == context["document"].id).update( - {"data_source_info": None} + db_session_with_containers.execute( + update(Document).where(Document.id == context["document"].id).values(data_source_info=None) ) db_session_with_containers.commit() @@ -274,8 +275,8 @@ class TestDocumentIndexingSyncTask: # Assert db_session_with_containers.expire_all() - updated_document = ( - db_session_with_containers.query(Document).where(Document.id == context["document"].id).first() + updated_document = db_session_with_containers.scalar( + select(Document).where(Document.id == context["document"].id).limit(1) ) assert updated_document is not None assert updated_document.indexing_status == IndexingStatus.ERROR @@ -294,13 +295,13 @@ class TestDocumentIndexingSyncTask: # Assert db_session_with_containers.expire_all() - updated_document = ( - db_session_with_containers.query(Document).where(Document.id == context["document"].id).first() + updated_document = db_session_with_containers.scalar( + select(Document).where(Document.id == context["document"].id).limit(1) ) - remaining_segments = ( - db_session_with_containers.query(DocumentSegment) + remaining_segments = db_session_with_containers.scalar( + select(func.count()) + .select_from(DocumentSegment) .where(DocumentSegment.document_id == context["document"].id) - .count() ) assert updated_document is not None assert updated_document.indexing_status == IndexingStatus.COMPLETED @@ -319,13 +320,13 @@ class TestDocumentIndexingSyncTask: # Assert db_session_with_containers.expire_all() - updated_document = ( - db_session_with_containers.query(Document).where(Document.id == context["document"].id).first() + updated_document = db_session_with_containers.scalar( + select(Document).where(Document.id == context["document"].id).limit(1) ) - remaining_segments = ( - db_session_with_containers.query(DocumentSegment) + remaining_segments = db_session_with_containers.scalar( + select(func.count()) + .select_from(DocumentSegment) .where(DocumentSegment.document_id == context["document"].id) - .count() ) assert updated_document is not None @@ -354,7 +355,7 @@ class TestDocumentIndexingSyncTask: context = self._create_notion_sync_context(db_session_with_containers) def _delete_dataset_before_clean() -> str: - db_session_with_containers.query(Dataset).where(Dataset.id == context["dataset"].id).delete() + db_session_with_containers.execute(delete(Dataset).where(Dataset.id == context["dataset"].id)) db_session_with_containers.commit() return "2024-01-02T00:00:00Z" @@ -367,8 +368,8 @@ class TestDocumentIndexingSyncTask: # Assert db_session_with_containers.expire_all() - updated_document = ( - db_session_with_containers.query(Document).where(Document.id == context["document"].id).first() + updated_document = db_session_with_containers.scalar( + select(Document).where(Document.id == context["document"].id).limit(1) ) assert updated_document is not None assert updated_document.indexing_status == IndexingStatus.PARSING @@ -386,13 +387,13 @@ class TestDocumentIndexingSyncTask: # Assert db_session_with_containers.expire_all() - updated_document = ( - db_session_with_containers.query(Document).where(Document.id == context["document"].id).first() + updated_document = db_session_with_containers.scalar( + select(Document).where(Document.id == context["document"].id).limit(1) ) - remaining_segments = ( - db_session_with_containers.query(DocumentSegment) + remaining_segments = db_session_with_containers.scalar( + select(func.count()) + .select_from(DocumentSegment) .where(DocumentSegment.document_id == context["document"].id) - .count() ) assert updated_document is not None assert updated_document.indexing_status == IndexingStatus.PARSING @@ -410,8 +411,8 @@ class TestDocumentIndexingSyncTask: # Assert db_session_with_containers.expire_all() - updated_document = ( - db_session_with_containers.query(Document).where(Document.id == context["document"].id).first() + updated_document = db_session_with_containers.scalar( + select(Document).where(Document.id == context["document"].id).limit(1) ) assert updated_document is not None assert updated_document.indexing_status == IndexingStatus.PARSING @@ -428,8 +429,8 @@ class TestDocumentIndexingSyncTask: # Assert db_session_with_containers.expire_all() - updated_document = ( - db_session_with_containers.query(Document).where(Document.id == context["document"].id).first() + updated_document = db_session_with_containers.scalar( + select(Document).where(Document.id == context["document"].id).limit(1) ) assert updated_document is not None assert updated_document.indexing_status == IndexingStatus.ERROR diff --git a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_update_task.py b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_update_task.py index d94abf2b40a..a9a8c0f30cb 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_update_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_update_task.py @@ -2,6 +2,7 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker +from sqlalchemy import func, select from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from models import Account, Tenant, TenantAccountJoin, TenantAccountRole @@ -123,13 +124,13 @@ class TestDocumentIndexingUpdateTask: db_session_with_containers.expire_all() # Assert document status updated before reindex - updated = db_session_with_containers.query(Document).where(Document.id == document.id).first() + updated = db_session_with_containers.scalar(select(Document).where(Document.id == document.id).limit(1)) assert updated.indexing_status == IndexingStatus.PARSING assert updated.processing_started_at is not None # Segments should be deleted - remaining = ( - db_session_with_containers.query(DocumentSegment).where(DocumentSegment.document_id == document.id).count() + remaining = db_session_with_containers.scalar( + select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id == document.id) ) assert remaining == 0 @@ -167,8 +168,8 @@ class TestDocumentIndexingUpdateTask: mock_external_dependencies["runner_instance"].run.assert_called_once() # Segments should remain (since clean failed before DB delete) - remaining = ( - db_session_with_containers.query(DocumentSegment).where(DocumentSegment.document_id == document.id).count() + remaining = db_session_with_containers.scalar( + select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id == document.id) ) assert remaining > 0 diff --git a/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py b/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py index 6a8e1869580..39c58987fd0 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py @@ -2,6 +2,7 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker +from sqlalchemy import select from core.indexing_runner import DocumentIsPausedError from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType @@ -317,7 +318,7 @@ class TestDuplicateDocumentIndexingTasks: # Verify documents were updated to parsing status # Re-query documents from database since _duplicate_document_indexing_task uses a different session for doc_id in document_ids: - updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() + updated_document = db_session_with_containers.scalar(select(Document).where(Document.id == doc_id).limit(1)) assert updated_document.indexing_status == IndexingStatus.PARSING assert updated_document.processing_started_at is not None @@ -362,14 +363,14 @@ class TestDuplicateDocumentIndexingTasks: # Verify segments were deleted from database # Re-query segments from database using captured IDs to avoid stale ORM instances for seg_id in segment_ids: - deleted_segment = ( - db_session_with_containers.query(DocumentSegment).where(DocumentSegment.id == seg_id).first() + deleted_segment = db_session_with_containers.scalar( + select(DocumentSegment).where(DocumentSegment.id == seg_id).limit(1) ) assert deleted_segment is None # Verify documents were updated to parsing status for doc_id in document_ids: - updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() + updated_document = db_session_with_containers.scalar(select(Document).where(Document.id == doc_id).limit(1)) assert updated_document.indexing_status == IndexingStatus.PARSING assert updated_document.processing_started_at is not None @@ -438,7 +439,7 @@ class TestDuplicateDocumentIndexingTasks: # Verify only existing documents were updated # Re-query documents from database since _duplicate_document_indexing_task uses a different session for doc_id in existing_document_ids: - updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() + updated_document = db_session_with_containers.scalar(select(Document).where(Document.id == doc_id).limit(1)) assert updated_document.indexing_status == IndexingStatus.PARSING assert updated_document.processing_started_at is not None @@ -485,7 +486,7 @@ class TestDuplicateDocumentIndexingTasks: # Verify documents were still updated to parsing status before the exception # Re-query documents from database since _duplicate_document_indexing_task close the session for doc_id in document_ids: - updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() + updated_document = db_session_with_containers.scalar(select(Document).where(Document.id == doc_id).limit(1)) assert updated_document.indexing_status == IndexingStatus.PARSING assert updated_document.processing_started_at is not None @@ -543,7 +544,7 @@ class TestDuplicateDocumentIndexingTasks: # Assert: Verify error handling # Re-query documents from database since _duplicate_document_indexing_task uses a different session for doc_id in document_ids: - updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() + updated_document = db_session_with_containers.scalar(select(Document).where(Document.id == doc_id).limit(1)) assert updated_document.indexing_status == IndexingStatus.ERROR assert updated_document.error is not None assert "batch upload" in updated_document.error.lower() @@ -585,7 +586,7 @@ class TestDuplicateDocumentIndexingTasks: # Assert: Verify error handling # Re-query documents from database since _duplicate_document_indexing_task uses a different session for doc_id in document_ids: - updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() + updated_document = db_session_with_containers.scalar(select(Document).where(Document.id == doc_id).limit(1)) assert updated_document.indexing_status == IndexingStatus.ERROR assert updated_document.error is not None assert "limit" in updated_document.error.lower() @@ -649,7 +650,7 @@ class TestDuplicateDocumentIndexingTasks: # Verify documents were processed for doc_id in document_ids: - updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() + updated_document = db_session_with_containers.scalar(select(Document).where(Document.id == doc_id).limit(1)) assert updated_document.indexing_status == IndexingStatus.PARSING @patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue", autospec=True) @@ -692,7 +693,7 @@ class TestDuplicateDocumentIndexingTasks: # Verify documents were processed for doc_id in document_ids: - updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() + updated_document = db_session_with_containers.scalar(select(Document).where(Document.id == doc_id).limit(1)) assert updated_document.indexing_status == IndexingStatus.PARSING @patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue", autospec=True) @@ -736,7 +737,7 @@ class TestDuplicateDocumentIndexingTasks: # Verify documents were processed for doc_id in document_ids: - updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() + updated_document = db_session_with_containers.scalar(select(Document).where(Document.id == doc_id).limit(1)) assert updated_document.indexing_status == IndexingStatus.PARSING @patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue", autospec=True) @@ -851,7 +852,7 @@ class TestDuplicateDocumentIndexingTasks: # Assert for doc_id in document_ids: - updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() + updated_document = db_session_with_containers.scalar(select(Document).where(Document.id == doc_id).limit(1)) assert updated_document.is_paused is True assert updated_document.indexing_status == IndexingStatus.PARSING assert updated_document.display_status == "paused" diff --git a/api/tests/test_containers_integration_tests/tasks/test_mail_email_code_login_task.py b/api/tests/test_containers_integration_tests/tasks/test_mail_email_code_login_task.py index c0ddc27286d..83437119988 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_mail_email_code_login_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_mail_email_code_login_task.py @@ -14,6 +14,7 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker +from sqlalchemy import delete from libs.email_i18n import EmailType from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole @@ -41,9 +42,9 @@ class TestSendEmailCodeLoginMailTask: from extensions.ext_redis import redis_client # Clear all test data - db_session_with_containers.query(TenantAccountJoin).delete() - db_session_with_containers.query(Tenant).delete() - db_session_with_containers.query(Account).delete() + db_session_with_containers.execute(delete(TenantAccountJoin)) + db_session_with_containers.execute(delete(Tenant)) + db_session_with_containers.execute(delete(Account)) db_session_with_containers.commit() # Clear Redis cache diff --git a/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py b/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py index a16f3ff773b..95a867dbb5c 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py @@ -3,16 +3,14 @@ from datetime import UTC, datetime from unittest.mock import patch import pytest -from graphon.enums import WorkflowExecutionStatus -from graphon.nodes.human_input.entities import HumanInputNodeData -from graphon.runtime import GraphRuntimeState, VariablePool +from sqlalchemy import delete from configs import dify_config from core.app.app_config.entities import WorkflowUIBasedAppConfig from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext from core.repositories.human_input_repository import FormCreateParams, HumanInputFormRepositoryImpl -from core.workflow.human_input_compat import ( +from core.workflow.human_input_adapter import ( EmailDeliveryConfig, EmailDeliveryMethod, EmailRecipients, @@ -20,6 +18,9 @@ from core.workflow.human_input_compat import ( MemberRecipient, ) from extensions.ext_storage import storage +from graphon.enums import WorkflowExecutionStatus +from graphon.nodes.human_input.entities import HumanInputNodeData +from graphon.runtime import GraphRuntimeState, VariablePool from models.account import Account, AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom from models.human_input import HumanInputDelivery, HumanInputForm, HumanInputFormRecipient @@ -30,14 +31,14 @@ from tasks.mail_human_input_delivery_task import dispatch_human_input_email_task @pytest.fixture(autouse=True) def cleanup_database(db_session_with_containers): - db_session_with_containers.query(HumanInputFormRecipient).delete() - db_session_with_containers.query(HumanInputDelivery).delete() - db_session_with_containers.query(HumanInputForm).delete() - db_session_with_containers.query(WorkflowPause).delete() - db_session_with_containers.query(WorkflowRun).delete() - db_session_with_containers.query(TenantAccountJoin).delete() - db_session_with_containers.query(Tenant).delete() - db_session_with_containers.query(Account).delete() + db_session_with_containers.execute(delete(HumanInputFormRecipient)) + db_session_with_containers.execute(delete(HumanInputDelivery)) + db_session_with_containers.execute(delete(HumanInputForm)) + db_session_with_containers.execute(delete(WorkflowPause)) + db_session_with_containers.execute(delete(WorkflowRun)) + db_session_with_containers.execute(delete(TenantAccountJoin)) + db_session_with_containers.execute(delete(Tenant)) + db_session_with_containers.execute(delete(Account)) db_session_with_containers.commit() diff --git a/api/tests/test_containers_integration_tests/tasks/test_mail_invite_member_task.py b/api/tests/test_containers_integration_tests/tasks/test_mail_invite_member_task.py index 212fbd26cd4..d34828c4b14 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_mail_invite_member_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_mail_invite_member_task.py @@ -17,6 +17,7 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker +from sqlalchemy import delete, select from extensions.ext_redis import redis_client from libs.email_i18n import EmailType @@ -44,9 +45,9 @@ class TestMailInviteMemberTask: def cleanup_database(self, db_session_with_containers): """Clean up database before each test to ensure isolation.""" # Clear all test data - db_session_with_containers.query(TenantAccountJoin).delete() - db_session_with_containers.query(Tenant).delete() - db_session_with_containers.query(Account).delete() + db_session_with_containers.execute(delete(TenantAccountJoin)) + db_session_with_containers.execute(delete(Tenant)) + db_session_with_containers.execute(delete(Account)) db_session_with_containers.commit() # Clear Redis cache @@ -491,10 +492,10 @@ class TestMailInviteMemberTask: assert tenant.name is not None # Verify tenant relationship exists - tenant_join = ( - db_session_with_containers.query(TenantAccountJoin) - .filter_by(tenant_id=tenant.id, account_id=pending_account.id) - .first() + tenant_join = db_session_with_containers.scalar( + select(TenantAccountJoin) + .where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == pending_account.id) + .limit(1) ) assert tenant_join is not None assert tenant_join.role == TenantAccountRole.NORMAL diff --git a/api/tests/test_containers_integration_tests/tasks/test_remove_app_and_related_data_task.py b/api/tests/test_containers_integration_tests/tasks/test_remove_app_and_related_data_task.py index 96cf9cebf5e..b43b6228701 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_remove_app_and_related_data_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_remove_app_and_related_data_task.py @@ -2,11 +2,12 @@ import uuid from unittest.mock import ANY, call, patch import pytest -from graphon.variables.segments import StringSegment -from graphon.variables.types import SegmentType +from sqlalchemy import delete, func, select from core.db.session_factory import session_factory from extensions.storage.storage_type import StorageType +from graphon.variables.segments import StringSegment +from graphon.variables.types import SegmentType from libs.datetime_utils import naive_utc_now from models import Tenant from models.enums import CreatorUserRole @@ -20,11 +21,11 @@ from tasks.remove_app_and_related_data_task import ( @pytest.fixture(autouse=True) def cleanup_database(db_session_with_containers): - db_session_with_containers.query(WorkflowDraftVariable).delete() - db_session_with_containers.query(WorkflowDraftVariableFile).delete() - db_session_with_containers.query(UploadFile).delete() - db_session_with_containers.query(App).delete() - db_session_with_containers.query(Tenant).delete() + db_session_with_containers.execute(delete(WorkflowDraftVariable)) + db_session_with_containers.execute(delete(WorkflowDraftVariableFile)) + db_session_with_containers.execute(delete(UploadFile)) + db_session_with_containers.execute(delete(App)) + db_session_with_containers.execute(delete(Tenant)) db_session_with_containers.commit() @@ -127,21 +128,21 @@ class TestDeleteDraftVariablesBatch: result = delete_draft_variables_batch(app1.id, batch_size=100) assert result == 150 - app1_remaining = db_session_with_containers.query(WorkflowDraftVariable).where( - WorkflowDraftVariable.app_id == app1.id + app1_remaining_count = db_session_with_containers.scalar( + select(func.count()).select_from(WorkflowDraftVariable).where(WorkflowDraftVariable.app_id == app1.id) ) - app2_remaining = db_session_with_containers.query(WorkflowDraftVariable).where( - WorkflowDraftVariable.app_id == app2.id + app2_remaining_count = db_session_with_containers.scalar( + select(func.count()).select_from(WorkflowDraftVariable).where(WorkflowDraftVariable.app_id == app2.id) ) - assert app1_remaining.count() == 0 - assert app2_remaining.count() == 100 + assert app1_remaining_count == 0 + assert app2_remaining_count == 100 def test_delete_draft_variables_batch_empty_result(self, db_session_with_containers): """Test deletion when no draft variables exist for the app.""" result = delete_draft_variables_batch(str(uuid.uuid4()), 1000) assert result == 0 - assert db_session_with_containers.query(WorkflowDraftVariable).count() == 0 + assert db_session_with_containers.scalar(select(func.count()).select_from(WorkflowDraftVariable)) == 0 @patch("tasks.remove_app_and_related_data_task._delete_draft_variable_offload_data") @patch("tasks.remove_app_and_related_data_task.logger") @@ -190,12 +191,16 @@ class TestDeleteDraftVariableOffloadData: expected_storage_calls = [call(storage_key) for storage_key in upload_file_keys] mock_storage.delete.assert_has_calls(expected_storage_calls, any_order=True) - remaining_var_files = db_session_with_containers.query(WorkflowDraftVariableFile).where( - WorkflowDraftVariableFile.id.in_(file_ids) + remaining_var_files_count = db_session_with_containers.scalar( + select(func.count()) + .select_from(WorkflowDraftVariableFile) + .where(WorkflowDraftVariableFile.id.in_(file_ids)) ) - remaining_upload_files = db_session_with_containers.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)) - assert remaining_var_files.count() == 0 - assert remaining_upload_files.count() == 0 + remaining_upload_files_count = db_session_with_containers.scalar( + select(func.count()).select_from(UploadFile).where(UploadFile.id.in_(upload_file_ids)) + ) + assert remaining_var_files_count == 0 + assert remaining_upload_files_count == 0 @patch("extensions.ext_storage.storage") @patch("tasks.remove_app_and_related_data_task.logging") @@ -217,9 +222,13 @@ class TestDeleteDraftVariableOffloadData: assert result == 1 mock_logging.exception.assert_called_once_with("Failed to delete storage object %s", storage_keys[0]) - remaining_var_files = db_session_with_containers.query(WorkflowDraftVariableFile).where( - WorkflowDraftVariableFile.id.in_(file_ids) + remaining_var_files_count = db_session_with_containers.scalar( + select(func.count()) + .select_from(WorkflowDraftVariableFile) + .where(WorkflowDraftVariableFile.id.in_(file_ids)) ) - remaining_upload_files = db_session_with_containers.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)) - assert remaining_var_files.count() == 0 - assert remaining_upload_files.count() == 0 + remaining_upload_files_count = db_session_with_containers.scalar( + select(func.count()).select_from(UploadFile).where(UploadFile.id.in_(upload_file_ids)) + ) + assert remaining_var_files_count == 0 + assert remaining_upload_files_count == 0 diff --git a/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py b/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py index 159ab51304a..b00d827e37c 100644 --- a/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py +++ b/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py @@ -24,16 +24,16 @@ from dataclasses import dataclass from datetime import timedelta import pytest -from graphon.entities import WorkflowExecution -from graphon.enums import WorkflowExecutionStatus -from sqlalchemy import delete, select +from sqlalchemy import delete, func, select from sqlalchemy.orm import Session, selectinload, sessionmaker from extensions.ext_storage import storage +from graphon.entities import WorkflowExecution +from graphon.enums import WorkflowExecutionStatus from libs.datetime_utils import naive_utc_now from models import Account from models import WorkflowPause as WorkflowPauseModel -from models.account import Tenant, TenantAccountJoin, TenantAccountRole +from models.account import AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole, TenantStatus from models.model import UploadFile from models.workflow import Workflow, WorkflowRun from repositories.sqlalchemy_api_workflow_run_repository import ( @@ -181,7 +181,7 @@ class TestWorkflowPauseIntegration: tenant = Tenant( name="Test Tenant", - status="normal", + status=TenantStatus.NORMAL, ) db_session_with_containers.add(tenant) db_session_with_containers.commit() @@ -190,7 +190,7 @@ class TestWorkflowPauseIntegration: email="test@example.com", name="Test User", interface_language="en-US", - status="active", + status=AccountStatus.ACTIVE, ) db_session_with_containers.add(account) db_session_with_containers.commit() @@ -679,9 +679,12 @@ class TestWorkflowPauseIntegration: # Verify only 3 were deleted remaining_count = ( - self.session.query(WorkflowPauseModel) - .filter(WorkflowPauseModel.id.in_([pe.id for pe in pause_entities])) - .count() + self.session.scalar( + select(func.count(WorkflowPauseModel.id)).where( + WorkflowPauseModel.id.in_([pe.id for pe in pause_entities]) + ) + ) + or 0 ) assert remaining_count == 2 @@ -693,7 +696,7 @@ class TestWorkflowPauseIntegration: tenant2 = Tenant( name="Test Tenant 2", - status="normal", + status=TenantStatus.NORMAL, ) self.session.add(tenant2) self.session.commit() @@ -702,7 +705,7 @@ class TestWorkflowPauseIntegration: email="test2@example.com", name="Test User 2", interface_language="en-US", - status="active", + status=AccountStatus.ACTIVE, ) self.session.add(account2) self.session.commit() diff --git a/api/tests/test_containers_integration_tests/trigger/conftest.py b/api/tests/test_containers_integration_tests/trigger/conftest.py index e3832fb2ef2..272bee9630d 100644 --- a/api/tests/test_containers_integration_tests/trigger/conftest.py +++ b/api/tests/test_containers_integration_tests/trigger/conftest.py @@ -11,6 +11,7 @@ from collections.abc import Generator from typing import Any import pytest +from sqlalchemy import delete from sqlalchemy.orm import Session from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole @@ -40,9 +41,9 @@ def tenant_and_account(db_session_with_containers: Session) -> Generator[tuple[T yield tenant, account # Cleanup - db_session_with_containers.query(TenantAccountJoin).filter_by(tenant_id=tenant.id).delete() - db_session_with_containers.query(Account).filter_by(id=account.id).delete() - db_session_with_containers.query(Tenant).filter_by(id=tenant.id).delete() + db_session_with_containers.execute(delete(TenantAccountJoin).where(TenantAccountJoin.tenant_id == tenant.id)) + db_session_with_containers.execute(delete(Account).where(Account.id == account.id)) + db_session_with_containers.execute(delete(Tenant).where(Tenant.id == tenant.id)) db_session_with_containers.commit() @@ -93,14 +94,14 @@ def app_model( ) from models.workflow import Workflow - db_session_with_containers.query(WorkflowTriggerLog).filter_by(app_id=app.id).delete() - db_session_with_containers.query(WorkflowSchedulePlan).filter_by(app_id=app.id).delete() - db_session_with_containers.query(WorkflowWebhookTrigger).filter_by(app_id=app.id).delete() - db_session_with_containers.query(WorkflowPluginTrigger).filter_by(app_id=app.id).delete() - db_session_with_containers.query(AppTrigger).filter_by(app_id=app.id).delete() - db_session_with_containers.query(TriggerSubscription).filter_by(tenant_id=tenant.id).delete() - db_session_with_containers.query(Workflow).filter_by(app_id=app.id).delete() - db_session_with_containers.query(App).filter_by(id=app.id).delete() + db_session_with_containers.execute(delete(WorkflowTriggerLog).where(WorkflowTriggerLog.app_id == app.id)) + db_session_with_containers.execute(delete(WorkflowSchedulePlan).where(WorkflowSchedulePlan.app_id == app.id)) + db_session_with_containers.execute(delete(WorkflowWebhookTrigger).where(WorkflowWebhookTrigger.app_id == app.id)) + db_session_with_containers.execute(delete(WorkflowPluginTrigger).where(WorkflowPluginTrigger.app_id == app.id)) + db_session_with_containers.execute(delete(AppTrigger).where(AppTrigger.app_id == app.id)) + db_session_with_containers.execute(delete(TriggerSubscription).where(TriggerSubscription.tenant_id == tenant.id)) + db_session_with_containers.execute(delete(Workflow).where(Workflow.app_id == app.id)) + db_session_with_containers.execute(delete(App).where(App.id == app.id)) db_session_with_containers.commit() diff --git a/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py b/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py index 7539bae6855..55aec498781 100644 --- a/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py +++ b/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py @@ -10,7 +10,7 @@ from typing import Any import pytest from flask import Flask, Response from flask.testing import FlaskClient -from graphon.enums import BuiltinNodeTypes +from sqlalchemy import select from sqlalchemy.orm import Session from configs import dify_config @@ -24,6 +24,7 @@ from core.trigger.debug import event_selectors from core.trigger.debug.event_bus import TriggerDebugEventBus from core.trigger.debug.event_selectors import PluginTriggerDebugEventPoller, WebhookTriggerDebugEventPoller from core.trigger.debug.events import PluginTriggerDebugEvent, build_plugin_pool_key +from graphon.enums import BuiltinNodeTypes from libs.datetime_utils import naive_utc_now from models.account import Account, Tenant from models.enums import AppTriggerStatus, AppTriggerType, CreatorUserRole, WorkflowTriggerStatus @@ -227,7 +228,9 @@ def test_webhook_trigger_creates_trigger_log( assert response.status_code == 200 db_session_with_containers.expire_all() - logs = db_session_with_containers.query(WorkflowTriggerLog).filter_by(app_id=app_model.id).all() + logs = db_session_with_containers.scalars( + select(WorkflowTriggerLog).where(WorkflowTriggerLog.app_id == app_model.id) + ).all() assert logs, "Webhook trigger should create trigger log" @@ -611,7 +614,9 @@ def test_schedule_trigger_creates_trigger_log( # Verify WorkflowTriggerLog was created db_session_with_containers.expire_all() - logs = db_session_with_containers.query(WorkflowTriggerLog).filter_by(app_id=app_model.id).all() + logs = db_session_with_containers.scalars( + select(WorkflowTriggerLog).where(WorkflowTriggerLog.app_id == app_model.id) + ).all() assert logs, "Schedule trigger should create WorkflowTriggerLog" assert logs[0].trigger_type == AppTriggerType.TRIGGER_SCHEDULE assert logs[0].root_node_id == schedule_node_id @@ -786,11 +791,12 @@ def test_plugin_trigger_full_chain_with_db_verification( # Verify database records exist db_session_with_containers.expire_all() - plugin_triggers = ( - db_session_with_containers.query(WorkflowPluginTrigger) - .filter_by(app_id=app_model.id, node_id=plugin_node_id) - .all() - ) + plugin_triggers = db_session_with_containers.scalars( + select(WorkflowPluginTrigger).where( + WorkflowPluginTrigger.app_id == app_model.id, + WorkflowPluginTrigger.node_id == plugin_node_id, + ) + ).all() assert plugin_triggers, "WorkflowPluginTrigger record should exist" assert plugin_triggers[0].provider_id == provider_id assert plugin_triggers[0].event_name == "test_event" diff --git a/api/tests/unit_tests/configs/test_dify_config.py b/api/tests/unit_tests/configs/test_dify_config.py index d6933e2180f..bad246a4bb1 100644 --- a/api/tests/unit_tests/configs/test_dify_config.py +++ b/api/tests/unit_tests/configs/test_dify_config.py @@ -145,7 +145,7 @@ def test_inner_api_config_exist(monkeypatch: pytest.MonkeyPatch): def test_db_extras_options_merging(monkeypatch: pytest.MonkeyPatch): - """Test that DB_EXTRAS options are properly merged with default timezone setting""" + """Test that DB_EXTRAS options are merged with the default timezone startup option.""" # Set environment variables monkeypatch.setenv("DB_TYPE", "postgresql") monkeypatch.setenv("DB_USERNAME", "postgres") @@ -158,15 +158,28 @@ def test_db_extras_options_merging(monkeypatch: pytest.MonkeyPatch): # Create config config = DifyConfig() - # Get engine options - engine_options = config.SQLALCHEMY_ENGINE_OPTIONS - - # Verify options contains both search_path and timezone - options = engine_options["connect_args"]["options"] + options = config.SQLALCHEMY_ENGINE_OPTIONS["connect_args"]["options"] assert "search_path=myschema" in options assert "timezone=UTC" in options +def test_db_session_timezone_override_can_disable_app_level_timezone_injection(monkeypatch: pytest.MonkeyPatch): + monkeypatch.setenv("DB_TYPE", "postgresql") + monkeypatch.setenv("DB_USERNAME", "postgres") + monkeypatch.setenv("DB_PASSWORD", "postgres") + monkeypatch.setenv("DB_HOST", "localhost") + monkeypatch.setenv("DB_PORT", "5432") + monkeypatch.setenv("DB_DATABASE", "dify") + monkeypatch.setenv("DB_EXTRAS", "options=-c search_path=myschema") + monkeypatch.setenv("DB_SESSION_TIMEZONE_OVERRIDE", "") + + config = DifyConfig() + + assert config.SQLALCHEMY_ENGINE_OPTIONS["connect_args"] == { + "options": "-c search_path=myschema", + } + + def test_pubsub_redis_url_default(monkeypatch: pytest.MonkeyPatch): os.environ.clear() @@ -223,6 +236,41 @@ def test_pubsub_redis_url_required_when_default_unavailable(monkeypatch: pytest. _ = DifyConfig().normalized_pubsub_redis_url +def test_dify_config_exposes_redis_key_prefix_default(monkeypatch: pytest.MonkeyPatch): + os.environ.clear() + + monkeypatch.setenv("CONSOLE_API_URL", "https://example.com") + monkeypatch.setenv("CONSOLE_WEB_URL", "https://example.com") + monkeypatch.setenv("DB_TYPE", "postgresql") + monkeypatch.setenv("DB_USERNAME", "postgres") + monkeypatch.setenv("DB_PASSWORD", "postgres") + monkeypatch.setenv("DB_HOST", "localhost") + monkeypatch.setenv("DB_PORT", "5432") + monkeypatch.setenv("DB_DATABASE", "dify") + + config = DifyConfig(_env_file=None) + + assert config.REDIS_KEY_PREFIX == "" + + +def test_dify_config_reads_redis_key_prefix_from_env(monkeypatch: pytest.MonkeyPatch): + os.environ.clear() + + monkeypatch.setenv("CONSOLE_API_URL", "https://example.com") + monkeypatch.setenv("CONSOLE_WEB_URL", "https://example.com") + monkeypatch.setenv("DB_TYPE", "postgresql") + monkeypatch.setenv("DB_USERNAME", "postgres") + monkeypatch.setenv("DB_PASSWORD", "postgres") + monkeypatch.setenv("DB_HOST", "localhost") + monkeypatch.setenv("DB_PORT", "5432") + monkeypatch.setenv("DB_DATABASE", "dify") + monkeypatch.setenv("REDIS_KEY_PREFIX", "enterprise-a") + + config = DifyConfig(_env_file=None) + + assert config.REDIS_KEY_PREFIX == "enterprise-a" + + @pytest.mark.parametrize( ("broker_url", "expected_host", "expected_port", "expected_username", "expected_password", "expected_db"), [ diff --git a/api/tests/unit_tests/controllers/console/app/test_app_import_api.py b/api/tests/unit_tests/controllers/console/app/test_app_import_api.py new file mode 100644 index 00000000000..9c4678aed31 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/app/test_app_import_api.py @@ -0,0 +1,139 @@ +"""Unit tests for console app import endpoints.""" + +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +from controllers.console.app import app_import as app_import_module +from services.app_dsl_service import ImportStatus + + +def _unwrap(func): + bound_self = getattr(func, "__self__", None) + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + if bound_self is not None: + return func.__get__(bound_self, bound_self.__class__) + return func + + +class _Result: + def __init__(self, status: ImportStatus, app_id: str | None = "app-1"): + self.status = status + self.app_id = app_id + + def model_dump(self, mode: str = "json"): + return {"status": self.status, "app_id": self.app_id} + + +def _install_features(monkeypatch: pytest.MonkeyPatch, enabled: bool) -> None: + features = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=enabled)) + monkeypatch.setattr(app_import_module.FeatureService, "get_system_features", lambda: features) + + +def _mock_session(monkeypatch: pytest.MonkeyPatch) -> MagicMock: + fake_session = MagicMock() + fake_session.__enter__.return_value = fake_session + fake_session.__exit__.return_value = None + monkeypatch.setattr(app_import_module, "db", SimpleNamespace(engine=object())) + monkeypatch.setattr(app_import_module, "Session", lambda *_args, **_kwargs: fake_session) + return fake_session + + +class TestAppImportApi: + @pytest.fixture + def api(self): + return app_import_module.AppImportApi() + + def test_import_post_returns_failed_status_and_rolls_back(self, api, app, monkeypatch: pytest.MonkeyPatch) -> None: + method = _unwrap(api.post) + + _install_features(monkeypatch, enabled=False) + session = _mock_session(monkeypatch) + monkeypatch.setattr( + app_import_module.AppDslService, + "import_app", + lambda *_args, **_kwargs: _Result(ImportStatus.FAILED, app_id=None), + ) + monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1")) + + with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}): + response, status = method() + + session.rollback.assert_called_once_with() + session.commit.assert_not_called() + assert status == 400 + assert response["status"] == ImportStatus.FAILED + + def test_import_post_returns_pending_status_and_commits(self, api, app, monkeypatch: pytest.MonkeyPatch) -> None: + method = _unwrap(api.post) + + _install_features(monkeypatch, enabled=False) + session = _mock_session(monkeypatch) + monkeypatch.setattr( + app_import_module.AppDslService, + "import_app", + lambda *_args, **_kwargs: _Result(ImportStatus.PENDING), + ) + monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1")) + + with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}): + response, status = method() + + session.commit.assert_called_once_with() + session.rollback.assert_not_called() + assert status == 202 + assert response["status"] == ImportStatus.PENDING + + def test_import_post_updates_webapp_auth_when_enabled(self, api, app, monkeypatch: pytest.MonkeyPatch) -> None: + method = _unwrap(api.post) + + _install_features(monkeypatch, enabled=True) + session = _mock_session(monkeypatch) + monkeypatch.setattr( + app_import_module.AppDslService, + "import_app", + lambda *_args, **_kwargs: _Result(ImportStatus.COMPLETED, app_id="app-123"), + ) + update_access = MagicMock() + monkeypatch.setattr(app_import_module.EnterpriseService.WebAppAuth, "update_app_access_mode", update_access) + monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1")) + + with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}): + response, status = method() + + session.commit.assert_called_once_with() + session.rollback.assert_not_called() + update_access.assert_called_once_with("app-123", "private") + assert status == 200 + assert response["status"] == ImportStatus.COMPLETED + + +class TestAppImportConfirmApi: + @pytest.fixture + def api(self): + return app_import_module.AppImportConfirmApi() + + def test_import_confirm_returns_failed_status_and_rolls_back( + self, api, app, monkeypatch: pytest.MonkeyPatch + ) -> None: + method = _unwrap(api.post) + + session = _mock_session(monkeypatch) + monkeypatch.setattr( + app_import_module.AppDslService, + "confirm_import", + lambda *_args, **_kwargs: _Result(ImportStatus.FAILED), + ) + monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1")) + + with app.test_request_context("/console/api/apps/imports/import-1/confirm", method="POST"): + response, status = method(import_id="import-1") + + session.rollback.assert_called_once_with() + session.commit.assert_not_called() + assert status == 400 + assert response["status"] == ImportStatus.FAILED diff --git a/api/tests/unit_tests/controllers/console/app/test_app_response_models.py b/api/tests/unit_tests/controllers/console/app/test_app_response_models.py index 2ac3dc037dd..35d07a987d0 100644 --- a/api/tests/unit_tests/controllers/console/app/test_app_response_models.py +++ b/api/tests/unit_tests/controllers/console/app/test_app_response_models.py @@ -138,12 +138,15 @@ def app_models(app_module): def patch_signed_url(monkeypatch, app_module): """Ensure icon URL generation uses a deterministic helper for tests.""" - def _fake_signed_url(key: str | None) -> str | None: - if not key: + def _fake_build_icon_url(_icon_type, key: str | None) -> str | None: + if key is None: + return None + icon_type = str(_icon_type).lower() + if icon_type != "image": return None return f"signed:{key}" - monkeypatch.setattr(app_module.file_helpers, "get_signed_file_url", _fake_signed_url) + monkeypatch.setattr(app_module, "build_icon_url", _fake_build_icon_url) def _ts(hour: int = 12) -> datetime: diff --git a/api/tests/unit_tests/controllers/console/app/test_audio.py b/api/tests/unit_tests/controllers/console/app/test_audio.py index c52bc02420e..2d218dac7e5 100644 --- a/api/tests/unit_tests/controllers/console/app/test_audio.py +++ b/api/tests/unit_tests/controllers/console/app/test_audio.py @@ -4,7 +4,6 @@ import io from types import SimpleNamespace import pytest -from graphon.model_runtime.errors.invoke import InvokeError from werkzeug.datastructures import FileStorage from werkzeug.exceptions import InternalServerError @@ -21,6 +20,7 @@ from controllers.console.app.error import ( UnsupportedAudioTypeError, ) from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError +from graphon.model_runtime.errors.invoke import InvokeError from services.audio_service import AudioService from services.errors.app_model_config import AppModelConfigBrokenError from services.errors.audio import ( diff --git a/api/tests/unit_tests/controllers/console/app/test_conversation_api.py b/api/tests/unit_tests/controllers/console/app/test_conversation_api.py index 11b3b3470d6..24b7e39f737 100644 --- a/api/tests/unit_tests/controllers/console/app/test_conversation_api.py +++ b/api/tests/unit_tests/controllers/console/app/test_conversation_api.py @@ -33,12 +33,17 @@ def test_completion_conversation_list_returns_paginated_result(app, monkeypatch: monkeypatch.setattr(conversation_module, "parse_time_range", lambda *_args, **_kwargs: (None, None)) paginate_result = MagicMock() + paginate_result.page = 1 + paginate_result.per_page = 20 + paginate_result.total = 0 + paginate_result.has_next = False + paginate_result.items = [] monkeypatch.setattr(conversation_module.db, "paginate", lambda *_args, **_kwargs: paginate_result) with app.test_request_context("/console/api/apps/app-1/completion-conversations", method="GET"): response = method(app_model=SimpleNamespace(id="app-1")) - assert response is paginate_result + assert response == {"page": 1, "limit": 20, "total": 0, "has_more": False, "data": []} def test_completion_conversation_list_invalid_time_range(app, monkeypatch: pytest.MonkeyPatch) -> None: @@ -71,12 +76,17 @@ def test_chat_conversation_list_advanced_chat_calls_paginate(app, monkeypatch: p monkeypatch.setattr(conversation_module, "parse_time_range", lambda *_args, **_kwargs: (None, None)) paginate_result = MagicMock() + paginate_result.page = 1 + paginate_result.per_page = 20 + paginate_result.total = 0 + paginate_result.has_next = False + paginate_result.items = [] monkeypatch.setattr(conversation_module.db, "paginate", lambda *_args, **_kwargs: paginate_result) with app.test_request_context("/console/api/apps/app-1/chat-conversations", method="GET"): response = method(app_model=SimpleNamespace(id="app-1", mode=AppMode.ADVANCED_CHAT)) - assert response is paginate_result + assert response == {"page": 1, "limit": 20, "total": 0, "has_more": False, "data": []} def test_get_conversation_updates_read_at(monkeypatch: pytest.MonkeyPatch) -> None: diff --git a/api/tests/unit_tests/controllers/console/app/test_conversation_read_timestamp.py b/api/tests/unit_tests/controllers/console/app/test_conversation_read_timestamp.py deleted file mode 100644 index f588ab261d8..00000000000 --- a/api/tests/unit_tests/controllers/console/app/test_conversation_read_timestamp.py +++ /dev/null @@ -1,42 +0,0 @@ -from datetime import datetime -from types import SimpleNamespace -from unittest.mock import MagicMock, patch - -from controllers.console.app.conversation import _get_conversation - - -def test_get_conversation_mark_read_keeps_updated_at_unchanged(): - app_model = SimpleNamespace(id="app-id") - account = SimpleNamespace(id="account-id") - conversation = MagicMock() - conversation.id = "conversation-id" - - with ( - patch( - "controllers.console.app.conversation.current_account_with_tenant", - return_value=(account, None), - autospec=True, - ), - patch( - "controllers.console.app.conversation.naive_utc_now", - return_value=datetime(2026, 2, 9, 0, 0, 0), - autospec=True, - ), - patch("controllers.console.app.conversation.db.session", autospec=True) as mock_session, - ): - mock_session.scalar.return_value = conversation - - _get_conversation(app_model, "conversation-id") - - statement = mock_session.execute.call_args[0][0] - compiled = statement.compile() - sql_text = str(compiled).lower() - compact_sql_text = sql_text.replace(" ", "") - params = compiled.params - - assert "updated_at=current_timestamp" not in compact_sql_text - assert "updated_at=conversations.updated_at" in compact_sql_text - assert "read_at=:read_at" in compact_sql_text - assert "read_account_id=:read_account_id" in compact_sql_text - assert params["read_at"] == datetime(2026, 2, 9, 0, 0, 0) - assert params["read_account_id"] == "account-id" diff --git a/api/tests/unit_tests/controllers/console/app/test_conversation_variables_api.py b/api/tests/unit_tests/controllers/console/app/test_conversation_variables_api.py new file mode 100644 index 00000000000..1a412aff29b --- /dev/null +++ b/api/tests/unit_tests/controllers/console/app/test_conversation_variables_api.py @@ -0,0 +1,108 @@ +from __future__ import annotations + +from contextlib import nullcontext +from datetime import UTC, datetime +from types import SimpleNamespace + +import pytest +from pydantic import ValidationError + +from controllers.console.app import conversation_variables as conversation_variables_module +from graphon.variables.types import SegmentType + + +def _unwrap(func): + bound_self = getattr(func, "__self__", None) + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + if bound_self is not None: + return func.__get__(bound_self, bound_self.__class__) + return func + + +def test_get_conversation_variables_returns_paginated_response(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = conversation_variables_module.ConversationVariablesApi() + method = _unwrap(api.get) + + created_at = datetime(2026, 1, 1, tzinfo=UTC) + updated_at = datetime(2026, 1, 2, tzinfo=UTC) + row = SimpleNamespace( + created_at=created_at, + updated_at=updated_at, + to_variable=lambda: SimpleNamespace( + model_dump=lambda: { + "id": "var-1", + "name": "my_var", + "value_type": "string", + "value": "value", + "description": "desc", + } + ), + ) + session = SimpleNamespace(scalars=lambda _stmt: SimpleNamespace(all=lambda: [row])) + monkeypatch.setattr(conversation_variables_module, "db", SimpleNamespace(engine=object())) + monkeypatch.setattr( + conversation_variables_module, + "sessionmaker", + lambda *_args, **_kwargs: SimpleNamespace(begin=lambda: nullcontext(session)), + ) + + with app.test_request_context( + "/console/api/apps/app-1/conversation-variables", + method="GET", + query_string={"conversation_id": "conv-1"}, + ): + response = method(app_model=SimpleNamespace(id="app-1")) + + assert response["page"] == 1 + assert response["limit"] == 100 + assert response["total"] == 1 + assert response["has_more"] is False + assert response["data"][0]["id"] == "var-1" + assert response["data"][0]["created_at"] == int(created_at.timestamp()) + assert response["data"][0]["updated_at"] == int(updated_at.timestamp()) + + +def test_get_conversation_variables_normalizes_value_type_and_value(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = conversation_variables_module.ConversationVariablesApi() + method = _unwrap(api.get) + + row = SimpleNamespace( + created_at=None, + updated_at=None, + to_variable=lambda: SimpleNamespace( + model_dump=lambda: { + "id": "var-2", + "name": "my_var_2", + "value_type": SegmentType.INTEGER, + "value": 42, + "description": None, + } + ), + ) + session = SimpleNamespace(scalars=lambda _stmt: SimpleNamespace(all=lambda: [row])) + monkeypatch.setattr(conversation_variables_module, "db", SimpleNamespace(engine=object())) + monkeypatch.setattr( + conversation_variables_module, + "sessionmaker", + lambda *_args, **_kwargs: SimpleNamespace(begin=lambda: nullcontext(session)), + ) + + with app.test_request_context( + "/console/api/apps/app-1/conversation-variables", + method="GET", + query_string={"conversation_id": "conv-1"}, + ): + response = method(app_model=SimpleNamespace(id="app-1")) + + assert response["data"][0]["value_type"] == "number" + assert response["data"][0]["value"] == "42" + + +def test_get_conversation_variables_requires_conversation_id(app) -> None: + api = conversation_variables_module.ConversationVariablesApi() + method = _unwrap(api.get) + + with app.test_request_context("/console/api/apps/app-1/conversation-variables", method="GET"): + with pytest.raises(ValidationError): + method(app_model=SimpleNamespace(id="app-1")) diff --git a/api/tests/unit_tests/controllers/console/app/test_mcp_server_response.py b/api/tests/unit_tests/controllers/console/app/test_mcp_server_response.py new file mode 100644 index 00000000000..1af15d8dc63 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/app/test_mcp_server_response.py @@ -0,0 +1,138 @@ +import datetime +from types import SimpleNamespace +from unittest.mock import PropertyMock, patch + +from flask import Flask + +from controllers.console import console_ns +from controllers.console.app.mcp_server import AppMCPServerController, AppMCPServerResponse + + +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +class _ValidatedResponse: + def __init__(self, payload): + self._payload = payload + + def model_dump(self, mode="json"): + return self._payload + + +class TestAppMCPServerResponse: + def test_parameters_json_string_parsed(self): + data = { + "id": "s1", + "name": "test", + "server_code": "code", + "description": "desc", + "status": "active", + "parameters": '{"key": "value"}', + } + resp = AppMCPServerResponse.model_validate(data) + assert resp.parameters == {"key": "value"} + + def test_parameters_invalid_json_returns_original(self): + data = { + "id": "s1", + "name": "test", + "server_code": "code", + "description": "desc", + "status": "active", + "parameters": "not-valid-json", + } + resp = AppMCPServerResponse.model_validate(data) + assert resp.parameters == "not-valid-json" + + def test_parameters_dict_passthrough(self): + data = { + "id": "s1", + "name": "test", + "server_code": "code", + "description": "desc", + "status": "active", + "parameters": {"already": "parsed"}, + } + resp = AppMCPServerResponse.model_validate(data) + assert resp.parameters == {"already": "parsed"} + + def test_parameters_json_array_parsed(self): + data = { + "id": "s1", + "name": "test", + "server_code": "code", + "description": "desc", + "status": "active", + "parameters": '["a", "b"]', + } + resp = AppMCPServerResponse.model_validate(data) + assert resp.parameters == ["a", "b"] + + def test_timestamps_normalized(self): + dt = datetime.datetime(2024, 1, 1, 0, 0, 0, tzinfo=datetime.UTC) + data = { + "id": "s1", + "name": "test", + "server_code": "code", + "description": "desc", + "status": "active", + "parameters": {}, + "created_at": dt, + "updated_at": dt, + } + resp = AppMCPServerResponse.model_validate(data) + assert resp.created_at == int(dt.timestamp()) + assert resp.updated_at == int(dt.timestamp()) + + def test_timestamps_none(self): + data = { + "id": "s1", + "name": "test", + "server_code": "code", + "description": "desc", + "status": "active", + "parameters": {}, + } + resp = AppMCPServerResponse.model_validate(data) + assert resp.created_at is None + assert resp.updated_at is None + + +class TestAppMCPServerController: + def test_get_returns_empty_dict_when_server_missing(self): + api = AppMCPServerController() + method = unwrap(api.get) + + with patch("controllers.console.app.mcp_server.db.session.scalar", return_value=None): + response = method(api, app_model=SimpleNamespace(id="app-1")) + + assert response == {} + + def test_post_returns_201(self): + api = AppMCPServerController() + method = unwrap(api.post) + payload = {"parameters": {"timeout": 30}} + app = Flask(__name__) + app.config["TESTING"] = True + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload), + patch("controllers.console.app.mcp_server.current_account_with_tenant", return_value=(None, "tenant-1")), + patch("controllers.console.app.mcp_server.db.session.add"), + patch("controllers.console.app.mcp_server.db.session.commit"), + patch("controllers.console.app.mcp_server.AppMCPServer.generate_server_code", return_value="server-code"), + patch( + "controllers.console.app.mcp_server.AppMCPServerResponse.model_validate", + return_value=_ValidatedResponse({"id": "server-1"}), + ), + ): + response, status_code = method( + api, app_model=SimpleNamespace(id="app-1", name="Demo App", description="App description") + ) + + assert response == {"id": "server-1"} + assert status_code == 201 diff --git a/api/tests/unit_tests/controllers/console/app/test_message_api.py b/api/tests/unit_tests/controllers/console/app/test_message_api.py index a76e9588296..c984dbef5d7 100644 --- a/api/tests/unit_tests/controllers/console/app/test_message_api.py +++ b/api/tests/unit_tests/controllers/console/app/test_message_api.py @@ -1,5 +1,7 @@ from __future__ import annotations +from datetime import UTC, datetime + import pytest from controllers.console.app import message as message_module @@ -120,3 +122,24 @@ def test_suggested_questions_response(app, monkeypatch: pytest.MonkeyPatch) -> N response = message_module.SuggestedQuestionsResponse(data=["What is AI?", "How does ML work?"]) assert len(response.data) == 2 assert response.data[0] == "What is AI?" + + +def test_message_detail_response_normalizes_aliases_and_timestamp(app, monkeypatch: pytest.MonkeyPatch) -> None: + """Test MessageDetailResponse normalizes alias fields and datetime timestamps.""" + created_at = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC) + response = message_module.MessageDetailResponse.model_validate( + { + "id": "550e8400-e29b-41d4-a716-446655440000", + "conversation_id": "550e8400-e29b-41d4-a716-446655440001", + "inputs": {"foo": "bar"}, + "query": "hello", + "re_sign_file_url_answer": "world", + "from_source": "user", + "status": "normal", + "created_at": created_at, + "message_metadata_dict": {"token_usage": 3}, + } + ) + assert response.answer == "world" + assert response.metadata == {"token_usage": 3} + assert response.created_at == int(created_at.timestamp()) diff --git a/api/tests/unit_tests/controllers/console/app/test_workflow.py b/api/tests/unit_tests/controllers/console/app/test_workflow.py index 36076368805..e91c0a05972 100644 --- a/api/tests/unit_tests/controllers/console/app/test_workflow.py +++ b/api/tests/unit_tests/controllers/console/app/test_workflow.py @@ -1,15 +1,16 @@ from __future__ import annotations +import json from datetime import datetime from types import SimpleNamespace from unittest.mock import Mock import pytest -from graphon.file import File, FileTransferMethod, FileType from werkzeug.exceptions import HTTPException, NotFound from controllers.console.app import workflow as workflow_module from controllers.console.app.error import DraftWorkflowNotExist, DraftWorkflowNotSync +from graphon.file import File, FileTransferMethod, FileType def _unwrap(func): @@ -30,7 +31,7 @@ def test_parse_file_with_config(monkeypatch: pytest.MonkeyPatch) -> None: file_list = [ File( tenant_id="t1", - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="http://u", ) @@ -258,6 +259,63 @@ def test_restore_published_workflow_to_draft_returns_400_for_invalid_structure( assert exc.value.description == "invalid workflow graph" +def test_get_published_workflows_marshals_items_before_session_closes(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = workflow_module.PublishedAllWorkflowApi() + handler = _unwrap(api.get) + + session_state = {"open": False} + + class _SessionContext: + def __enter__(self): + session_state["open"] = True + return object() + + def __exit__(self, exc_type, exc, tb): + session_state["open"] = False + return False + + class _SessionMaker: + def begin(self): + return _SessionContext() + + class _Workflow: + @property + def id(self): + assert session_state["open"] is True + return "w1" + + monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object())) + monkeypatch.setattr(workflow_module, "sessionmaker", lambda *_args, **_kwargs: _SessionMaker()) + monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1")) + monkeypatch.setattr( + workflow_module, + "WorkflowService", + lambda: SimpleNamespace( + get_all_published_workflow=lambda **_kwargs: ([_Workflow()], False), + ), + ) + + def _fake_marshal(items, fields): + assert session_state["open"] is True + return [{"id": item.id} for item in items] + + monkeypatch.setattr(workflow_module, "marshal", _fake_marshal) + + with app.test_request_context( + "/apps/app/workflows", + method="GET", + query_string={"page": 1, "limit": 10, "user_id": "", "named_only": "false"}, + ): + response = handler(api, app_model=SimpleNamespace(id="app", workflow_id="wf-1")) + + assert response == { + "items": [{"id": "w1"}], + "page": 1, + "limit": 10, + "has_more": False, + } + + def test_draft_workflow_get_not_found(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr( workflow_module, "WorkflowService", lambda: SimpleNamespace(get_draft_workflow=lambda **_k: None) @@ -290,3 +348,87 @@ def test_advanced_chat_run_conversation_not_exists(app, monkeypatch: pytest.Monk ): with pytest.raises(NotFound): handler(api, app_model=SimpleNamespace(id="app")) + + +def test_workflow_online_users_filters_inaccessible_workflow(app, monkeypatch: pytest.MonkeyPatch) -> None: + app_id_1 = "11111111-1111-1111-1111-111111111111" + app_id_2 = "22222222-2222-2222-2222-222222222222" + signed_avatar_url = "https://files.example.com/signed/avatar-1" + sign_avatar = Mock(return_value=signed_avatar_url) + monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "tenant-1")) + monkeypatch.setattr( + workflow_module, + "WorkflowService", + lambda: SimpleNamespace(get_accessible_app_ids=lambda app_ids, tenant_id: {app_id_1}), + ) + monkeypatch.setattr(workflow_module.file_helpers, "get_signed_file_url", sign_avatar) + + workflow_module.redis_client.hgetall.side_effect = lambda key: ( + { + b"sid-1": json.dumps( + { + "user_id": "u-1", + "username": "Alice", + "avatar": "avatar-file-id", + "sid": "sid-1", + } + ) + } + if key == f"{workflow_module.WORKFLOW_ONLINE_USERS_PREFIX}{app_id_1}" + else {} + ) + + api = workflow_module.WorkflowOnlineUsersApi() + handler = _unwrap(api.get) + + with app.test_request_context( + f"/apps/workflows/online-users?app_ids={app_id_1},{app_id_2}", + method="GET", + ): + response = handler(api) + + assert response == { + "data": [ + { + "app_id": app_id_1, + "users": [ + { + "user_id": "u-1", + "username": "Alice", + "avatar": signed_avatar_url, + "sid": "sid-1", + } + ], + } + ] + } + workflow_module.redis_client.hgetall.assert_called_once_with( + f"{workflow_module.WORKFLOW_ONLINE_USERS_PREFIX}{app_id_1}" + ) + sign_avatar.assert_called_once_with("avatar-file-id") + + +def test_workflow_online_users_rejects_excessive_workflow_ids(app, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "tenant-1")) + accessible_app_ids = Mock(return_value=set()) + monkeypatch.setattr( + workflow_module, + "WorkflowService", + lambda: SimpleNamespace(get_accessible_app_ids=accessible_app_ids), + ) + + excessive_ids = ",".join(f"wf-{index}" for index in range(workflow_module.MAX_WORKFLOW_ONLINE_USERS_QUERY_IDS + 1)) + + api = workflow_module.WorkflowOnlineUsersApi() + handler = _unwrap(api.get) + + with app.test_request_context( + f"/apps/workflows/online-users?app_ids={excessive_ids}", + method="GET", + ): + with pytest.raises(HTTPException) as exc: + handler(api) + + assert exc.value.code == 400 + assert "Maximum" in exc.value.description + accessible_app_ids.assert_not_called() diff --git a/api/tests/unit_tests/controllers/console/app/test_workflow_app_log_api.py b/api/tests/unit_tests/controllers/console/app/test_workflow_app_log_api.py new file mode 100644 index 00000000000..a9853f25b01 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/app/test_workflow_app_log_api.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +from datetime import UTC, datetime + +from controllers.console.app import workflow_app_log as workflow_app_log_module +from graphon.enums import WorkflowExecutionStatus + + +def test_workflow_app_log_query_parses_bool_and_datetime(): + query = workflow_app_log_module.WorkflowAppLogQuery.model_validate( + { + "detail": "true", + "created_at__before": "2026-01-02T03:04:05Z", + "page": "2", + "limit": "10", + } + ) + + assert query.detail is True + assert query.created_at__before == datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC) + assert query.page == 2 + assert query.limit == 10 + + +def test_workflow_app_log_pagination_response_normalizes_nested_fields(): + created_at = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC) + response = workflow_app_log_module.WorkflowAppLogPaginationResponse.model_validate( + { + "page": 1, + "limit": 20, + "total": 1, + "has_more": False, + "data": [ + { + "id": "log-1", + "workflow_run": { + "id": "run-1", + "status": WorkflowExecutionStatus.SUCCEEDED, + "created_at": created_at, + "finished_at": created_at, + }, + "details": {"trigger_metadata": {}}, + "created_by_account": {"id": "acc-1", "name": "acc", "email": "acc@example.com"}, + "created_at": created_at, + } + ], + } + ).model_dump(mode="json") + + assert response["data"][0]["workflow_run"]["status"] == "succeeded" + assert response["data"][0]["workflow_run"]["created_at"] == int(created_at.timestamp()) + assert response["data"][0]["created_at"] == int(created_at.timestamp()) + + +def test_workflow_archived_log_pagination_response_normalizes_nested_fields(): + created_at = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC) + response = workflow_app_log_module.WorkflowArchivedLogPaginationResponse.model_validate( + { + "page": 1, + "limit": 20, + "total": 1, + "has_more": False, + "data": [ + { + "id": "archived-1", + "workflow_run": { + "id": "run-1", + "status": WorkflowExecutionStatus.FAILED, + }, + "trigger_metadata": {"type": "trigger-plugin"}, + "created_by_end_user": { + "id": "eu-1", + "type": "anonymous", + "is_anonymous": True, + "session_id": "session-1", + }, + "created_at": created_at, + } + ], + } + ).model_dump(mode="json") + + assert response["data"][0]["workflow_run"]["status"] == "failed" + assert response["data"][0]["created_at"] == int(created_at.timestamp()) diff --git a/api/tests/unit_tests/controllers/console/app/test_workflow_comment_api.py b/api/tests/unit_tests/controllers/console/app/test_workflow_comment_api.py new file mode 100644 index 00000000000..85afcf0e60a --- /dev/null +++ b/api/tests/unit_tests/controllers/console/app/test_workflow_comment_api.py @@ -0,0 +1,201 @@ +from __future__ import annotations + +from contextlib import nullcontext +from dataclasses import dataclass +from datetime import datetime +from types import SimpleNamespace +from unittest.mock import MagicMock, PropertyMock, patch + +import pytest +from flask import Flask +from werkzeug.exceptions import Forbidden + +from controllers.console import console_ns +from controllers.console import wraps as console_wraps +from controllers.console.app import workflow_comment as workflow_comment_module +from controllers.console.app import wraps as app_wraps +from libs import login as login_lib +from models.account import Account, AccountStatus, TenantAccountRole + + +def _make_account(role: TenantAccountRole) -> Account: + account = Account(name="tester", email="tester@example.com") + account.status = AccountStatus.ACTIVE + account.role = role + account.id = "account-123" # type: ignore[assignment] + account._current_tenant = SimpleNamespace(id="tenant-123") # type: ignore[attr-defined] + account._get_current_object = lambda: account # type: ignore[attr-defined] + return account + + +def _make_app() -> SimpleNamespace: + return SimpleNamespace(id="app-123", tenant_id="tenant-123", status="normal", mode="workflow") + + +def _patch_console_guards(monkeypatch: pytest.MonkeyPatch, account: Account, app_model: SimpleNamespace) -> None: + monkeypatch.setattr(login_lib.dify_config, "LOGIN_DISABLED", True) + monkeypatch.setattr(login_lib, "current_user", account) + monkeypatch.setattr(login_lib, "current_account_with_tenant", lambda: (account, account.current_tenant_id)) + monkeypatch.setattr(login_lib, "check_csrf_token", lambda *_, **__: None) + monkeypatch.setattr(console_wraps, "current_account_with_tenant", lambda: (account, account.current_tenant_id)) + monkeypatch.setattr(console_wraps.dify_config, "EDITION", "CLOUD") + monkeypatch.setattr(app_wraps, "current_account_with_tenant", lambda: (account, account.current_tenant_id)) + monkeypatch.setattr(app_wraps, "_load_app_model", lambda _app_id: app_model) + monkeypatch.setattr(workflow_comment_module, "current_user", account) + + +def _patch_write_services(monkeypatch: pytest.MonkeyPatch) -> None: + for method_name in ( + "create_comment", + "update_comment", + "delete_comment", + "resolve_comment", + "validate_comment_access", + "create_reply", + "update_reply", + "delete_reply", + ): + monkeypatch.setattr(workflow_comment_module.WorkflowCommentService, method_name, MagicMock()) + + +def _patch_payload(payload: dict[str, object] | None): + if payload is None: + return nullcontext() + return patch.object( + type(console_ns), + "payload", + new_callable=PropertyMock, + return_value=payload, + ) + + +@dataclass(frozen=True) +class WriteCase: + resource_cls: type + method_name: str + path: str + kwargs: dict[str, str] + payload: dict[str, object] | None = None + + +@pytest.mark.parametrize( + "case", + [ + WriteCase( + resource_cls=workflow_comment_module.WorkflowCommentListApi, + method_name="post", + path="/console/api/apps/app-123/workflow/comments", + kwargs={"app_id": "app-123"}, + payload={"content": "hello", "position_x": 1.0, "position_y": 2.0, "mentioned_user_ids": []}, + ), + WriteCase( + resource_cls=workflow_comment_module.WorkflowCommentDetailApi, + method_name="put", + path="/console/api/apps/app-123/workflow/comments/comment-1", + kwargs={"app_id": "app-123", "comment_id": "comment-1"}, + payload={"content": "hello", "position_x": 1.0, "position_y": 2.0, "mentioned_user_ids": []}, + ), + WriteCase( + resource_cls=workflow_comment_module.WorkflowCommentDetailApi, + method_name="delete", + path="/console/api/apps/app-123/workflow/comments/comment-1", + kwargs={"app_id": "app-123", "comment_id": "comment-1"}, + ), + WriteCase( + resource_cls=workflow_comment_module.WorkflowCommentResolveApi, + method_name="post", + path="/console/api/apps/app-123/workflow/comments/comment-1/resolve", + kwargs={"app_id": "app-123", "comment_id": "comment-1"}, + ), + WriteCase( + resource_cls=workflow_comment_module.WorkflowCommentReplyApi, + method_name="post", + path="/console/api/apps/app-123/workflow/comments/comment-1/replies", + kwargs={"app_id": "app-123", "comment_id": "comment-1"}, + payload={"content": "reply", "mentioned_user_ids": []}, + ), + WriteCase( + resource_cls=workflow_comment_module.WorkflowCommentReplyDetailApi, + method_name="put", + path="/console/api/apps/app-123/workflow/comments/comment-1/replies/reply-1", + kwargs={"app_id": "app-123", "comment_id": "comment-1", "reply_id": "reply-1"}, + payload={"content": "reply", "mentioned_user_ids": []}, + ), + WriteCase( + resource_cls=workflow_comment_module.WorkflowCommentReplyDetailApi, + method_name="delete", + path="/console/api/apps/app-123/workflow/comments/comment-1/replies/reply-1", + kwargs={"app_id": "app-123", "comment_id": "comment-1", "reply_id": "reply-1"}, + ), + ], +) +def test_write_endpoints_require_edit_permission(app: Flask, monkeypatch: pytest.MonkeyPatch, case: WriteCase) -> None: + app.config.setdefault("RESTX_MASK_HEADER", "X-Fields") + account = _make_account(TenantAccountRole.NORMAL) + app_model = _make_app() + _patch_console_guards(monkeypatch, account, app_model) + _patch_write_services(monkeypatch) + + with app.test_request_context(case.path, method=case.method_name.upper(), json=case.payload): + with _patch_payload(case.payload): + handler = getattr(case.resource_cls(), case.method_name) + with pytest.raises(Forbidden): + handler(**case.kwargs) + + +def test_create_comment_allows_editor(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: + app.config.setdefault("RESTX_MASK_HEADER", "X-Fields") + account = _make_account(TenantAccountRole.EDITOR) + app_model = _make_app() + _patch_console_guards(monkeypatch, account, app_model) + + create_comment_mock = MagicMock(return_value={"id": "comment-1"}) + monkeypatch.setattr(workflow_comment_module.WorkflowCommentService, "create_comment", create_comment_mock) + payload = {"content": "hello", "position_x": 1.0, "position_y": 2.0, "mentioned_user_ids": []} + + with app.test_request_context("/console/api/apps/app-123/workflow/comments", method="POST", json=payload): + with _patch_payload(payload): + result = workflow_comment_module.WorkflowCommentListApi().post(app_id="app-123") + + if isinstance(result, tuple): + response = result[0] + else: + response = result + assert response["id"] == "comment-1" + create_comment_mock.assert_called_once_with( + tenant_id="tenant-123", + app_id="app-123", + created_by="account-123", + content="hello", + position_x=1.0, + position_y=2.0, + mentioned_user_ids=[], + ) + + +def test_update_comment_omits_mentions_when_payload_does_not_include_them( + app: Flask, monkeypatch: pytest.MonkeyPatch +) -> None: + app.config.setdefault("RESTX_MASK_HEADER", "X-Fields") + account = _make_account(TenantAccountRole.EDITOR) + app_model = _make_app() + _patch_console_guards(monkeypatch, account, app_model) + + update_comment_mock = MagicMock(return_value={"id": "comment-1", "updated_at": datetime(2024, 1, 1, 12, 0, 0)}) + monkeypatch.setattr(workflow_comment_module.WorkflowCommentService, "update_comment", update_comment_mock) + payload = {"content": "hello", "position_x": 10.0, "position_y": 20.0} + + with app.test_request_context("/console/api/apps/app-123/workflow/comments/comment-1", method="PUT", json=payload): + with _patch_payload(payload): + workflow_comment_module.WorkflowCommentDetailApi().put(app_id="app-123", comment_id="comment-1") + + update_comment_mock.assert_called_once_with( + tenant_id="tenant-123", + app_id="app-123", + comment_id="comment-1", + user_id="account-123", + content="hello", + position_x=10.0, + position_y=20.0, + mentioned_user_ids=None, + ) diff --git a/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py b/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py index e11102acb1e..c4a81484466 100644 --- a/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py +++ b/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py @@ -6,14 +6,14 @@ from unittest.mock import Mock import pytest from flask import Flask -from graphon.entities.pause_reason import HumanInputRequired -from graphon.enums import WorkflowExecutionStatus -from graphon.nodes.human_input.entities import FormInput, UserAction -from graphon.nodes.human_input.enums import FormInputType from controllers.console import wraps as console_wraps from controllers.console.app import workflow_run as workflow_run_module from controllers.web.error import NotFoundError +from graphon.entities.pause_reason import HumanInputRequired +from graphon.enums import WorkflowExecutionStatus +from graphon.nodes.human_input.entities import FormInput, UserAction +from graphon.nodes.human_input.enums import FormInputType from libs import login as login_lib from models.account import Account, AccountStatus, TenantAccountRole from models.workflow import WorkflowRun diff --git a/api/tests/unit_tests/controllers/console/app/test_workflow_trigger_api.py b/api/tests/unit_tests/controllers/console/app/test_workflow_trigger_api.py new file mode 100644 index 00000000000..5363aa154f7 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/app/test_workflow_trigger_api.py @@ -0,0 +1,54 @@ +from __future__ import annotations + +from datetime import UTC, datetime +from types import SimpleNamespace + +from controllers.console.app import workflow_trigger as workflow_trigger_module + + +def test_parser_models_validate(): + parser = workflow_trigger_module.Parser(node_id="node-1") + enable_parser = workflow_trigger_module.ParserEnable( + trigger_id="550e8400-e29b-41d4-a716-446655440000", enable_trigger=True + ) + + assert parser.node_id == "node-1" + assert enable_parser.enable_trigger is True + + +def test_workflow_trigger_response_serializes_datetime(): + created_at = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC) + trigger = SimpleNamespace( + id="trigger-1", + trigger_type="trigger-plugin", + title="Trigger", + node_id="node-1", + provider_name="provider", + icon="https://example.com/icon", + status="enabled", + created_at=created_at, + updated_at=created_at, + ) + + payload = workflow_trigger_module.WorkflowTriggerResponse.model_validate(trigger, from_attributes=True).model_dump( + mode="json" + ) + assert payload["id"] == "trigger-1" + assert payload["created_at"] == "2026-01-02T03:04:05Z" + assert payload["updated_at"] == "2026-01-02T03:04:05Z" + + +def test_webhook_trigger_response_serializes_datetime(): + created_at = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC) + webhook = { + "id": "webhook-1", + "webhook_id": "whk-1", + "webhook_url": "https://example.com/hook", + "webhook_debug_url": "https://example.com/hook/debug", + "node_id": "node-1", + "created_at": created_at, + } + + payload = workflow_trigger_module.WebhookTriggerResponse.model_validate(webhook).model_dump(mode="json") + assert payload["webhook_id"] == "whk-1" + assert payload["created_at"] == "2026-01-02T03:04:05Z" diff --git a/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py b/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py index 740da1f1df1..22b80b748e6 100644 --- a/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py +++ b/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py @@ -5,7 +5,6 @@ from unittest.mock import MagicMock, patch import pytest from flask_restx import marshal -from graphon.variables.types import SegmentType from controllers.console.app.workflow_draft_variable import ( _WORKFLOW_DRAFT_VARIABLE_FIELDS, @@ -16,6 +15,7 @@ from controllers.console.app.workflow_draft_variable import ( ) from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID from factories.variable_factory import build_segment +from graphon.variables.types import SegmentType from libs.datetime_utils import naive_utc_now from libs.uuid_utils import uuidv7 from models.workflow import WorkflowDraftVariable, WorkflowDraftVariableFile @@ -314,8 +314,8 @@ def test_workflow_file_variable_with_signed_url(): # Create a File object with LOCAL_FILE transfer method (which generates signed URLs) test_file = File( - id="test_file_id", - type=FileType.IMAGE, + file_id="test_file_id", + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="test_upload_file_id", filename="test.jpg", @@ -370,8 +370,8 @@ def test_workflow_file_variable_remote_url(): # Create a File object with REMOTE_URL transfer method test_file = File( - id="test_file_id", - type=FileType.IMAGE, + file_id="test_file_id", + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://example.com/test.jpg", filename="test.jpg", diff --git a/api/tests/unit_tests/controllers/console/auth/test_login_logout.py b/api/tests/unit_tests/controllers/console/auth/test_login_logout.py index 560971206f6..0cf97da8785 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_login_logout.py +++ b/api/tests/unit_tests/controllers/console/auth/test_login_logout.py @@ -14,18 +14,20 @@ from unittest.mock import MagicMock, patch import pytest from flask import Flask from flask_restx import Api +from werkzeug.exceptions import Unauthorized from controllers.console.auth.error import ( AuthenticationFailedError, EmailPasswordLoginLimitError, InvalidEmailError, ) -from controllers.console.auth.login import LoginApi, LogoutApi +from controllers.console.auth.login import EmailCodeLoginApi, LoginApi, LogoutApi from controllers.console.error import ( AccountBannedError, AccountInFreezeError, WorkspacesLimitExceeded, ) +from services.entities.auth_entities import LoginFailureReason from services.errors.account import AccountLoginError, AccountPasswordError @@ -34,6 +36,11 @@ def encode_password(password: str) -> str: return base64.b64encode(password.encode("utf-8")).decode() +def encode_code(code: str) -> str: + """Helper to encode verification code as Base64 for testing.""" + return base64.b64encode(code.encode("utf-8")).decode() + + class TestLoginApi: """Test cases for the LoginApi endpoint.""" @@ -197,12 +204,17 @@ class TestLoginApi: mock_get_invitation.return_value = None # Act & Assert - with app.test_request_context( - "/login", method="POST", json={"email": "test@example.com", "password": encode_password("password")} - ): - login_api = LoginApi() - with pytest.raises(EmailPasswordLoginLimitError): - login_api.post() + with patch("controllers.console.auth.login.logger.warning") as mock_log_warning: + with app.test_request_context( + "/login", method="POST", json={"email": "test@example.com", "password": encode_password("password")} + ): + login_api = LoginApi() + with pytest.raises(EmailPasswordLoginLimitError): + login_api.post() + + assert mock_log_warning.call_count == 1 + assert mock_log_warning.call_args.args[1] == "test@example.com" + assert mock_log_warning.call_args.args[2] == LoginFailureReason.LOGIN_RATE_LIMITED @patch("controllers.console.wraps.db") @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", True) @@ -220,12 +232,17 @@ class TestLoginApi: mock_is_frozen.return_value = True # Act & Assert - with app.test_request_context( - "/login", method="POST", json={"email": "frozen@example.com", "password": encode_password("password")} - ): - login_api = LoginApi() - with pytest.raises(AccountInFreezeError): - login_api.post() + with patch("controllers.console.auth.login.logger.warning") as mock_log_warning: + with app.test_request_context( + "/login", method="POST", json={"email": "frozen@example.com", "password": encode_password("password")} + ): + login_api = LoginApi() + with pytest.raises(AccountInFreezeError): + login_api.post() + + assert mock_log_warning.call_count == 1 + assert mock_log_warning.call_args.args[1] == "frozen@example.com" + assert mock_log_warning.call_args.args[2] == LoginFailureReason.ACCOUNT_IN_FREEZE @patch("controllers.console.wraps.db") @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False) @@ -257,14 +274,20 @@ class TestLoginApi: mock_authenticate.side_effect = AccountPasswordError("Invalid password") # Act & Assert - with app.test_request_context( - "/login", method="POST", json={"email": "test@example.com", "password": encode_password("WrongPass123!")} - ): - login_api = LoginApi() - with pytest.raises(AuthenticationFailedError): - login_api.post() + with patch("controllers.console.auth.login.logger.warning") as mock_log_warning: + with app.test_request_context( + "/login", + method="POST", + json={"email": "test@example.com", "password": encode_password("WrongPass123!")}, + ): + login_api = LoginApi() + with pytest.raises(AuthenticationFailedError): + login_api.post() mock_add_rate_limit.assert_called_once_with("test@example.com") + assert mock_log_warning.call_count == 1 + assert mock_log_warning.call_args.args[1] == "test@example.com" + assert mock_log_warning.call_args.args[2] == LoginFailureReason.INVALID_CREDENTIALS @patch("controllers.console.wraps.db") @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False) @@ -288,12 +311,19 @@ class TestLoginApi: mock_authenticate.side_effect = AccountLoginError("Account is banned") # Act & Assert - with app.test_request_context( - "/login", method="POST", json={"email": "banned@example.com", "password": encode_password("ValidPass123!")} - ): - login_api = LoginApi() - with pytest.raises(AccountBannedError): - login_api.post() + with patch("controllers.console.auth.login.logger.warning") as mock_log_warning: + with app.test_request_context( + "/login", + method="POST", + json={"email": "banned@example.com", "password": encode_password("ValidPass123!")}, + ): + login_api = LoginApi() + with pytest.raises(AccountBannedError): + login_api.post() + + assert mock_log_warning.call_count == 1 + assert mock_log_warning.call_args.args[1] == "banned@example.com" + assert mock_log_warning.call_args.args[2] == LoginFailureReason.ACCOUNT_BANNED @patch("controllers.console.wraps.db") @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False) @@ -417,6 +447,36 @@ class TestLoginApi: mock_add_rate_limit.assert_not_called() mock_reset_rate_limit.assert_called_once_with("upper@example.com") + @patch("controllers.console.wraps.db") + @patch("controllers.console.auth.login.AccountService.get_email_code_login_data") + @patch("controllers.console.auth.login.AccountService.revoke_email_code_login_token") + @patch("controllers.console.auth.login._get_account_with_case_fallback") + def test_email_code_login_logs_banned_account( + self, + mock_get_account, + mock_revoke_token, + mock_get_token_data, + mock_db, + app, + ): + mock_db.session.query.return_value.first.return_value = MagicMock() + mock_get_token_data.return_value = {"email": "User@Example.com", "code": "123456"} + mock_get_account.side_effect = Unauthorized("Account is banned.") + + with patch("controllers.console.auth.login.logger.warning") as mock_log_warning: + with app.test_request_context( + "/email-code-login/validity", + method="POST", + json={"email": "User@Example.com", "code": encode_code("123456"), "token": "token-123"}, + ): + with pytest.raises(AccountBannedError): + EmailCodeLoginApi().post() + + mock_revoke_token.assert_called_once_with("token-123") + assert mock_log_warning.call_count == 1 + assert mock_log_warning.call_args.args[1] == "user@example.com" + assert mock_log_warning.call_args.args[2] == LoginFailureReason.ACCOUNT_BANNED + class TestLogoutApi: """Test cases for the LogoutApi endpoint.""" diff --git a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_auth.py b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_auth.py index 9c9f8da87c1..5136922e88e 100644 --- a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_auth.py +++ b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_auth.py @@ -1,7 +1,6 @@ from unittest.mock import MagicMock, patch import pytest -from graphon.model_runtime.errors.validate import CredentialsValidateFailedError from werkzeug.exceptions import Forbidden, NotFound from controllers.console import console_ns @@ -18,6 +17,7 @@ from controllers.console.datasets.rag_pipeline.datasource_auth import ( DatasourceUpdateProviderNameApi, ) from core.plugin.impl.oauth import OAuthHandler +from graphon.model_runtime.errors.validate import CredentialsValidateFailedError from services.datasource_provider_service import DatasourceProviderService from services.plugin.oauth_service import OAuthProxyService diff --git a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_draft_variable.py b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_draft_variable.py index 6ef8ccfdbd3..63950736c54 100644 --- a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_draft_variable.py +++ b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_draft_variable.py @@ -2,7 +2,6 @@ from unittest.mock import MagicMock, patch import pytest from flask import Response -from graphon.variables.types import SegmentType from controllers.console import console_ns from controllers.console.app.error import DraftWorkflowNotExist @@ -16,6 +15,7 @@ from controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable impor ) from controllers.web.error import InvalidArgumentError, NotFoundError from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID +from graphon.variables.types import SegmentType from models.account import Account diff --git a/api/tests/unit_tests/controllers/console/datasets/test_datasets.py b/api/tests/unit_tests/controllers/console/datasets/test_datasets.py index 8555900f4ec..9465936f280 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_datasets.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_datasets.py @@ -1555,7 +1555,17 @@ class TestDatasetApiKeyApi: method = unwrap(api.get) mock_key_1 = MagicMock(spec=ApiToken) + mock_key_1.id = "key-1" + mock_key_1.type = "dataset" + mock_key_1.token = "ds-abc" + mock_key_1.last_used_at = None + mock_key_1.created_at = None mock_key_2 = MagicMock(spec=ApiToken) + mock_key_2.id = "key-2" + mock_key_2.type = "dataset" + mock_key_2.token = "ds-def" + mock_key_2.last_used_at = None + mock_key_2.created_at = None with ( app.test_request_context("/"), @@ -1570,13 +1580,26 @@ class TestDatasetApiKeyApi: ): response = method(api) - assert "items" in response - assert response["items"] == [mock_key_1, mock_key_2] + assert "data" in response + assert len(response["data"]) == 2 + assert response["data"][0]["id"] == "key-1" + assert response["data"][0]["token"] == "ds-abc" + assert response["data"][1]["id"] == "key-2" + assert response["data"][1]["token"] == "ds-def" def test_post_create_api_key_success(self, app): api = DatasetApiKeyApi() method = unwrap(api.post) + mock_token = MagicMock() + mock_token.id = "new-key-id" + mock_token.last_used_at = None + mock_token.created_at = datetime.datetime(2024, 1, 1, 0, 0, 0, tzinfo=datetime.UTC) + + mock_api_token_cls = MagicMock() + mock_api_token_cls.return_value = mock_token + mock_api_token_cls.generate_api_key.return_value = "dataset-abc123" + with ( app.test_request_context("/"), patch( @@ -1588,8 +1611,8 @@ class TestDatasetApiKeyApi: return_value=3, ), patch( - "controllers.console.datasets.datasets.ApiToken.generate_api_key", - return_value="dataset-abc123", + "controllers.console.datasets.datasets.ApiToken", + mock_api_token_cls, ), patch( "controllers.console.datasets.datasets.db.session.add", @@ -1603,9 +1626,11 @@ class TestDatasetApiKeyApi: response, status = method(api) assert status == 200 - assert isinstance(response, ApiToken) - assert response.token == "dataset-abc123" - assert response.type == "dataset" + assert isinstance(response, dict) + assert response["id"] == "new-key-id" + assert response["token"] == "dataset-abc123" + assert response["type"] == "dataset" + assert response["created_at"] is not None def test_post_exceed_max_keys(self, app): api = DatasetApiKeyApi() @@ -1747,6 +1772,21 @@ class TestDatasetApiBaseUrlApi: assert response["api_base_url"] == "http://localhost:5000/v1" + def test_get_api_base_url_no_double_v1(self, app): + api = DatasetApiBaseUrlApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets.dify_config.SERVICE_API_URL", + "https://example.com/v1", + ), + ): + response = method(api) + + assert response["api_base_url"] == "https://example.com/v1" + class TestDatasetRetrievalSettingApi: def test_get_success(self, app): diff --git a/api/tests/unit_tests/controllers/console/datasets/test_datasets_document.py b/api/tests/unit_tests/controllers/console/datasets/test_datasets_document.py index ce2278de4f8..d9b02ac4533 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_datasets_document.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_datasets_document.py @@ -1,3 +1,4 @@ +from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest @@ -215,17 +216,23 @@ class TestDatasetDocumentListApi: method = unwrap(api.post) payload = {"indexing_technique": "economy"} + created_dataset = SimpleNamespace(id="ds-1", name="Dataset", indexing_technique="economy") + created_document = SimpleNamespace(id="doc-1", name="Document", doc_metadata_details=None) with ( app.test_request_context("/", json=payload), patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets_document.DatasetService.get_dataset", + return_value=created_dataset, + ), patch( "controllers.console.datasets.datasets_document.DocumentService.document_create_args_validate", return_value=None, ), patch( "controllers.console.datasets.datasets_document.DocumentService.save_document_with_dataset_id", - return_value=([MagicMock()], "batch-1"), + return_value=([created_document], "batch-1"), ), ): response = method(api, "ds-1") diff --git a/api/tests/unit_tests/controllers/console/datasets/test_external.py b/api/tests/unit_tests/controllers/console/datasets/test_external.py index 161d0c41e83..514bbbe0402 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_external.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_external.py @@ -1,3 +1,4 @@ +from importlib import import_module from unittest.mock import MagicMock, PropertyMock, patch import pytest @@ -11,6 +12,7 @@ from controllers.console.datasets.external import ( BedrockRetrievalApi, ExternalApiTemplateApi, ExternalApiTemplateListApi, + ExternalApiUseCheckApi, ExternalDatasetCreateApi, ExternalKnowledgeHitTestingApi, ) @@ -19,6 +21,8 @@ from services.external_knowledge_service import ExternalDatasetService from services.hit_testing_service import HitTestingService from services.knowledge_service import ExternalDatasetTestService +external_controller = import_module("controllers.console.datasets.external") + def unwrap(func): while hasattr(func, "__wrapped__"): @@ -44,10 +48,11 @@ def current_user(): @pytest.fixture(autouse=True) -def mock_auth(mocker, current_user): - mocker.patch( - "controllers.console.datasets.external.current_account_with_tenant", - return_value=(current_user, "tenant-1"), +def mock_auth(monkeypatch, current_user): + monkeypatch.setattr( + external_controller, + "current_account_with_tenant", + lambda: (current_user, "tenant-1"), ) @@ -136,6 +141,26 @@ class TestExternalApiTemplateApi: method(api, "api-id") +class TestExternalApiUseCheckApi: + def test_get_scopes_usage_check_to_current_tenant(self, app): + api = ExternalApiUseCheckApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch.object( + ExternalDatasetService, + "external_knowledge_api_use_check", + return_value=(True, 2), + ) as mock_use_check, + ): + response, status = method(api, "api-id") + + assert status == 200 + assert response == {"is_using": True, "count": 2} + mock_use_check.assert_called_once_with("api-id", "tenant-1") + + class TestExternalDatasetCreateApi: def test_create_success(self, app): api = ExternalDatasetCreateApi() diff --git a/api/tests/unit_tests/controllers/console/datasets/test_hit_testing.py b/api/tests/unit_tests/controllers/console/datasets/test_hit_testing.py index 726c0a5cf39..09ed2aaf697 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_hit_testing.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_hit_testing.py @@ -99,6 +99,57 @@ class TestHitTestingApi: assert "records" in result assert result["records"] == [] + def test_hit_testing_success_with_optional_record_fields(self, app, dataset, dataset_id): + api = HitTestingApi() + method = unwrap(api.post) + + payload = { + "query": "what is vector search", + } + records = [ + { + "segment": None, + "child_chunks": [], + "score": None, + "tsne_position": None, + "files": [], + "summary": None, + } + ] + + with ( + app.test_request_context("/"), + patch.object( + type(console_ns), + "payload", + new_callable=PropertyMock, + return_value=payload, + ), + patch.object( + HitTestingPayload, + "model_validate", + return_value=MagicMock(model_dump=lambda **_: payload), + ), + patch.object( + HitTestingApi, + "get_and_validate_dataset", + return_value=dataset, + ), + patch.object( + HitTestingApi, + "hit_testing_args_check", + ), + patch.object( + HitTestingApi, + "perform_hit_testing", + return_value={"query": payload["query"], "records": records}, + ), + ): + result = method(api, dataset_id) + + assert result["query"] == payload["query"] + assert result["records"] == records + def test_hit_testing_dataset_not_found(self, app, dataset_id): api = HitTestingApi() method = unwrap(api.post) diff --git a/api/tests/unit_tests/controllers/console/datasets/test_hit_testing_base.py b/api/tests/unit_tests/controllers/console/datasets/test_hit_testing_base.py index 710c9be684c..e4acd91b76c 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_hit_testing_base.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_hit_testing_base.py @@ -1,7 +1,6 @@ from unittest.mock import MagicMock, patch import pytest -from graphon.model_runtime.errors.invoke import InvokeError from werkzeug.exceptions import Forbidden, InternalServerError, NotFound import services @@ -21,6 +20,7 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) +from graphon.model_runtime.errors.invoke import InvokeError from models.account import Account from services.dataset_service import DatasetService from services.hit_testing_service import HitTestingService diff --git a/api/tests/unit_tests/controllers/console/explore/test_audio.py b/api/tests/unit_tests/controllers/console/explore/test_audio.py index 66c9ba48c59..b4b57022e28 100644 --- a/api/tests/unit_tests/controllers/console/explore/test_audio.py +++ b/api/tests/unit_tests/controllers/console/explore/test_audio.py @@ -2,7 +2,6 @@ from io import BytesIO from unittest.mock import MagicMock, patch import pytest -from graphon.model_runtime.errors.invoke import InvokeError from werkzeug.exceptions import InternalServerError import controllers.console.explore.audio as audio_module @@ -20,6 +19,7 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) +from graphon.model_runtime.errors.invoke import InvokeError from services.errors.audio import ( AudioTooLargeServiceError, NoAudioUploadedServiceError, diff --git a/api/tests/unit_tests/controllers/console/explore/test_message.py b/api/tests/unit_tests/controllers/console/explore/test_message.py index 2e4ca4f2a44..145cc9cdd7e 100644 --- a/api/tests/unit_tests/controllers/console/explore/test_message.py +++ b/api/tests/unit_tests/controllers/console/explore/test_message.py @@ -1,7 +1,6 @@ from unittest.mock import MagicMock, patch import pytest -from graphon.model_runtime.errors.invoke import InvokeError from werkzeug.exceptions import InternalServerError, NotFound import controllers.console.explore.message as module @@ -22,6 +21,7 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) +from graphon.model_runtime.errors.invoke import InvokeError from services.errors.conversation import ConversationNotExistsError from services.errors.message import ( FirstMessageNotExistsError, diff --git a/api/tests/unit_tests/controllers/console/explore/test_recommended_app.py b/api/tests/unit_tests/controllers/console/explore/test_recommended_app.py index 02c7507ea72..76c863577a9 100644 --- a/api/tests/unit_tests/controllers/console/explore/test_recommended_app.py +++ b/api/tests/unit_tests/controllers/console/explore/test_recommended_app.py @@ -1,6 +1,7 @@ from unittest.mock import MagicMock, patch import controllers.console.explore.recommended_app as module +from models.model import AppMode, IconType def unwrap(func): @@ -90,3 +91,48 @@ class TestRecommendedAppApi: service_mock.assert_called_once_with("11111111-1111-1111-1111-111111111111") assert result == result_data + + +class TestRecommendedAppResponseModels: + def test_recommended_app_info_response_computes_icon_url(self): + with patch.object(module, "build_icon_url", return_value="https://signed/icon.png"): + payload = module.RecommendedAppInfoResponse.model_validate( + { + "id": "app-1", + "name": "App", + "mode": AppMode.CHAT, + "icon": "icon.png", + "icon_type": IconType.IMAGE, + "icon_background": "#fff", + } + ).model_dump(mode="json") + + assert payload["icon_url"] == "https://signed/icon.png" + + def test_recommended_app_list_response_serialization(self): + response = module.RecommendedAppListResponse.model_validate( + { + "recommended_apps": [ + { + "app": { + "id": "app-1", + "name": "App", + "mode": "chat", + "icon": "icon.png", + "icon_type": "emoji", + "icon_background": "#fff", + }, + "app_id": "app-1", + "description": "desc", + "category": "cat", + "position": 1, + "is_listed": True, + "can_trial": False, + } + ], + "categories": ["cat"], + } + ).model_dump(mode="json") + + assert response["recommended_apps"][0]["app_id"] == "app-1" + assert response["categories"] == ["cat"] diff --git a/api/tests/unit_tests/controllers/console/explore/test_trial.py b/api/tests/unit_tests/controllers/console/explore/test_trial.py index 04beb31389c..3625056af98 100644 --- a/api/tests/unit_tests/controllers/console/explore/test_trial.py +++ b/api/tests/unit_tests/controllers/console/explore/test_trial.py @@ -3,7 +3,6 @@ from unittest.mock import MagicMock, patch from uuid import uuid4 import pytest -from graphon.model_runtime.errors.invoke import InvokeError from werkzeug.exceptions import Forbidden, InternalServerError, NotFound import controllers.console.explore.trial as module @@ -26,6 +25,7 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) +from graphon.model_runtime.errors.invoke import InvokeError from models import Account from models.account import TenantStatus from models.model import AppMode @@ -94,7 +94,7 @@ class TestTrialAppWorkflowRunApi: with app.test_request_context("/"): with pytest.raises(NotWorkflowAppError): - method(MagicMock(mode=AppMode.CHAT)) + method(api, MagicMock(mode=AppMode.CHAT)) def test_success(self, app, trial_app_workflow, account): api = module.TrialAppWorkflowRunApi() @@ -106,7 +106,7 @@ class TestTrialAppWorkflowRunApi: patch.object(module.AppGenerateService, "generate", return_value=MagicMock()), patch.object(module.RecommendedAppService, "add_trial_app_record"), ): - result = method(trial_app_workflow) + result = method(api, trial_app_workflow) assert result is not None @@ -124,7 +124,7 @@ class TestTrialAppWorkflowRunApi: ), ): with pytest.raises(ProviderNotInitializeError): - method(trial_app_workflow) + method(api, trial_app_workflow) def test_workflow_quota_exceeded(self, app, trial_app_workflow, account): api = module.TrialAppWorkflowRunApi() @@ -140,7 +140,7 @@ class TestTrialAppWorkflowRunApi: ), ): with pytest.raises(ProviderQuotaExceededError): - method(trial_app_workflow) + method(api, trial_app_workflow) def test_workflow_model_not_support(self, app, trial_app_workflow, account): api = module.TrialAppWorkflowRunApi() @@ -156,7 +156,7 @@ class TestTrialAppWorkflowRunApi: ), ): with pytest.raises(ProviderModelCurrentlyNotSupportError): - method(trial_app_workflow) + method(api, trial_app_workflow) def test_workflow_invoke_error(self, app, trial_app_workflow, account): api = module.TrialAppWorkflowRunApi() @@ -172,7 +172,7 @@ class TestTrialAppWorkflowRunApi: ), ): with pytest.raises(CompletionRequestError): - method(trial_app_workflow) + method(api, trial_app_workflow) def test_workflow_rate_limit_error(self, app, trial_app_workflow, account): api = module.TrialAppWorkflowRunApi() @@ -188,7 +188,7 @@ class TestTrialAppWorkflowRunApi: ), ): with pytest.raises(InvokeRateLimitHttpError): - method(trial_app_workflow) + method(api, trial_app_workflow) def test_workflow_value_error(self, app, trial_app_workflow, account): api = module.TrialAppWorkflowRunApi() @@ -204,7 +204,7 @@ class TestTrialAppWorkflowRunApi: ), ): with pytest.raises(ValueError): - method(trial_app_workflow) + method(api, trial_app_workflow) def test_workflow_generic_exception(self, app, trial_app_workflow, account): api = module.TrialAppWorkflowRunApi() @@ -220,7 +220,7 @@ class TestTrialAppWorkflowRunApi: ), ): with pytest.raises(InternalServerError): - method(trial_app_workflow) + method(api, trial_app_workflow) class TestTrialChatApi: @@ -566,7 +566,7 @@ class TestTrialMessageSuggestedQuestionApi: with app.test_request_context("/"): with pytest.raises(NotChatAppError): - method(api, MagicMock(mode="completion"), str(uuid4())) + method(MagicMock(mode="completion"), str(uuid4())) def test_success(self, app, trial_app_chat, account): api = module.TrialMessageSuggestedQuestionApi() @@ -581,7 +581,7 @@ class TestTrialMessageSuggestedQuestionApi: return_value=["q1", "q2"], ), ): - result = method(api, trial_app_chat, str(uuid4())) + result = method(trial_app_chat, str(uuid4())) assert result == {"data": ["q1", "q2"]} @@ -599,7 +599,7 @@ class TestTrialMessageSuggestedQuestionApi: ), ): with pytest.raises(NotFound): - method(api, trial_app_chat, str(uuid4())) + method(trial_app_chat, str(uuid4())) class TestTrialAppParameterApi: @@ -931,7 +931,7 @@ class TestTrialAppWorkflowTaskStopApi: with app.test_request_context("/"): with pytest.raises(NotWorkflowAppError): - method(trial_app_chat, str(uuid4())) + method(api, trial_app_chat, str(uuid4())) def test_success(self, app, trial_app_workflow, account): api = module.TrialAppWorkflowTaskStopApi() @@ -944,7 +944,7 @@ class TestTrialAppWorkflowTaskStopApi: patch.object(module.AppQueueManager, "set_stop_flag_no_user_check") as mock_set_flag, patch.object(module.GraphEngineManager, "send_stop_command") as mock_send_cmd, ): - result = method(trial_app_workflow, task_id) + result = method(api, trial_app_workflow, task_id) assert result == {"result": "success"} mock_set_flag.assert_called_once_with(task_id) diff --git a/api/tests/unit_tests/controllers/console/tag/test_tags.py b/api/tests/unit_tests/controllers/console/tag/test_tags.py index e89b89c8b11..2be5a21f289 100644 --- a/api/tests/unit_tests/controllers/console/tag/test_tags.py +++ b/api/tests/unit_tests/controllers/console/tag/test_tags.py @@ -1,9 +1,11 @@ +from types import SimpleNamespace from unittest.mock import MagicMock, PropertyMock, patch import pytest from flask import Flask from werkzeug.exceptions import Forbidden +import controllers.console.tag.tags as module from controllers.console import console_ns from controllers.console.tag.tags import ( TagBindingCreateApi, @@ -83,13 +85,20 @@ class TestTagListApi: ), patch( "controllers.console.tag.tags.TagService.get_tags", - return_value=[{"id": "1", "name": "tag"}], + return_value=[ + SimpleNamespace( + id="1", + name="tag", + type=TagType.KNOWLEDGE, + binding_count=1, + ) + ], ), ): result, status = method(api) assert status == 200 - assert isinstance(result, list) + assert result == [{"id": "1", "name": "tag", "type": "knowledge", "binding_count": "1"}] def test_post_success(self, app, admin_user, tag, payload_patch): api = TagListApi() @@ -113,6 +122,7 @@ class TestTagListApi: assert status == 200 assert result["name"] == "test-tag" + assert result["binding_count"] == "0" def test_post_forbidden(self, app, readonly_user, payload_patch): api = TagListApi() @@ -158,7 +168,7 @@ class TestTagUpdateDeleteApi: result, status = method(api, "tag-1") assert status == 200 - assert result["binding_count"] == 3 + assert result["binding_count"] == "3" def test_patch_forbidden(self, app, readonly_user, payload_patch): api = TagUpdateDeleteApi() @@ -277,3 +287,13 @@ class TestTagBindingDeleteApi: ): with pytest.raises(Forbidden): method(api) + + +class TestTagResponseModel: + def test_tag_response_normalizes_enum_type(self): + payload = module.TagResponse.model_validate( + {"id": "tag-1", "name": "tag", "type": TagType.KNOWLEDGE, "binding_count": 1} + ).model_dump(mode="json") + + assert payload["type"] == "knowledge" + assert payload["binding_count"] == "1" diff --git a/api/tests/unit_tests/controllers/console/test_workspace_account.py b/api/tests/unit_tests/controllers/console/test_workspace_account.py index 9afc1c41661..26ff264f18c 100644 --- a/api/tests/unit_tests/controllers/console/test_workspace_account.py +++ b/api/tests/unit_tests/controllers/console/test_workspace_account.py @@ -11,7 +11,7 @@ from controllers.console.workspace.account import ( ChangeEmailSendEmailApi, CheckEmailUnique, ) -from models import Account +from models import Account, AccountStatus from services.account_service import AccountService @@ -20,7 +20,7 @@ def app(): app = Flask(__name__) app.config["TESTING"] = True app.config["RESTX_MASK_HEADER"] = "X-Fields" - app.login_manager = SimpleNamespace(_load_user=lambda: None) + app.login_manager = SimpleNamespace(load_user_from_request_context=lambda: None) return app @@ -33,7 +33,7 @@ def _build_account(email: str, account_id: str = "acc", tenant: object | None = account = Account(name=account_id, email=email) account.email = email account.id = account_id - account.status = "active" + account.status = AccountStatus.ACTIVE account._current_tenant = tenant_obj return account @@ -68,7 +68,10 @@ class TestChangeEmailSend: mock_features.return_value = SimpleNamespace(enable_change_email=True) mock_account = _build_account("current@example.com", "acc1") mock_current_account.return_value = (mock_account, None) - mock_get_change_data.return_value = {"email": "current@example.com"} + mock_get_change_data.return_value = { + "email": "current@example.com", + AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED, + } mock_send_email.return_value = "token-abc" with app.test_request_context( @@ -85,12 +88,55 @@ class TestChangeEmailSend: email="new@example.com", old_email="current@example.com", language="en-US", - phase="new_email", + phase=AccountService.CHANGE_EMAIL_PHASE_NEW, ) mock_extract_ip.assert_called_once() mock_is_ip_limit.assert_called_once_with("127.0.0.1") mock_csrf.assert_called_once() + @patch("controllers.console.wraps.db") + @patch("controllers.console.workspace.account.current_account_with_tenant") + @patch("controllers.console.workspace.account.AccountService.get_change_email_data") + @patch("controllers.console.workspace.account.AccountService.send_change_email_email") + @patch("controllers.console.workspace.account.AccountService.is_email_send_ip_limit", return_value=False) + @patch("controllers.console.workspace.account.extract_remote_ip", return_value="127.0.0.1") + @patch("libs.login.check_csrf_token", return_value=None) + @patch("controllers.console.wraps.FeatureService.get_system_features") + def test_should_reject_new_email_phase_when_token_phase_is_not_old_verified( + self, + mock_features, + mock_csrf, + mock_extract_ip, + mock_is_ip_limit, + mock_send_email, + mock_get_change_data, + mock_current_account, + mock_db, + app, + ): + """GHSA-4q3w-q5mc-45rq: a phase-1 token must not unlock the new-email send step.""" + from controllers.console.auth.error import InvalidTokenError + + _mock_wraps_db(mock_db) + mock_features.return_value = SimpleNamespace(enable_change_email=True) + mock_account = _build_account("current@example.com", "acc1") + mock_current_account.return_value = (mock_account, None) + mock_get_change_data.return_value = { + "email": "current@example.com", + AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_OLD, + } + + with app.test_request_context( + "/account/change-email", + method="POST", + json={"email": "New@Example.com", "language": "en-US", "phase": "new_email", "token": "token-123"}, + ): + _set_logged_in_user(_build_account("tester@example.com", "tester")) + with pytest.raises(InvalidTokenError): + ChangeEmailSendEmailApi().post() + + mock_send_email.assert_not_called() + class TestChangeEmailValidity: @patch("controllers.console.wraps.db") @@ -122,7 +168,12 @@ class TestChangeEmailValidity: mock_account = _build_account("user@example.com", "acc2") mock_current_account.return_value = (mock_account, None) mock_is_rate_limit.return_value = False - mock_get_data.return_value = {"email": "user@example.com", "code": "1234", "old_email": "old@example.com"} + mock_get_data.return_value = { + "email": "user@example.com", + "code": "1234", + "old_email": "old@example.com", + AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_OLD, + } mock_generate_token.return_value = (None, "new-token") with app.test_request_context( @@ -138,11 +189,169 @@ class TestChangeEmailValidity: mock_add_rate.assert_not_called() mock_revoke_token.assert_called_once_with("token-123") mock_generate_token.assert_called_once_with( - "user@example.com", code="1234", old_email="old@example.com", additional_data={} + "user@example.com", + code="1234", + old_email="old@example.com", + additional_data={ + AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED, + }, ) mock_reset_rate.assert_called_once_with("user@example.com") mock_csrf.assert_called_once() + @patch("controllers.console.wraps.db") + @patch("controllers.console.workspace.account.current_account_with_tenant") + @patch("controllers.console.workspace.account.AccountService.reset_change_email_error_rate_limit") + @patch("controllers.console.workspace.account.AccountService.generate_change_email_token") + @patch("controllers.console.workspace.account.AccountService.revoke_change_email_token") + @patch("controllers.console.workspace.account.AccountService.add_change_email_error_rate_limit") + @patch("controllers.console.workspace.account.AccountService.get_change_email_data") + @patch("controllers.console.workspace.account.AccountService.is_change_email_error_rate_limit") + @patch("libs.login.check_csrf_token", return_value=None) + @patch("controllers.console.wraps.FeatureService.get_system_features") + def test_should_upgrade_new_phase_token_to_new_verified( + self, + mock_features, + mock_csrf, + mock_is_rate_limit, + mock_get_data, + mock_add_rate, + mock_revoke_token, + mock_generate_token, + mock_reset_rate, + mock_current_account, + mock_db, + app, + ): + _mock_wraps_db(mock_db) + mock_features.return_value = SimpleNamespace(enable_change_email=True) + mock_current_account.return_value = (_build_account("old@example.com", "acc"), None) + mock_is_rate_limit.return_value = False + mock_get_data.return_value = { + "email": "new@example.com", + "code": "1234", + "old_email": "old@example.com", + AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_NEW, + } + mock_generate_token.return_value = (None, "new-verified-token") + + with app.test_request_context( + "/account/change-email/validity", + method="POST", + json={"email": "new@example.com", "code": "1234", "token": "token-123"}, + ): + _set_logged_in_user(_build_account("tester@example.com", "tester")) + response = ChangeEmailCheckApi().post() + + assert response == {"is_valid": True, "email": "new@example.com", "token": "new-verified-token"} + mock_generate_token.assert_called_once_with( + "new@example.com", + code="1234", + old_email="old@example.com", + additional_data={ + AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED, + }, + ) + + @patch("controllers.console.wraps.db") + @patch("controllers.console.workspace.account.current_account_with_tenant") + @patch("controllers.console.workspace.account.AccountService.reset_change_email_error_rate_limit") + @patch("controllers.console.workspace.account.AccountService.generate_change_email_token") + @patch("controllers.console.workspace.account.AccountService.revoke_change_email_token") + @patch("controllers.console.workspace.account.AccountService.add_change_email_error_rate_limit") + @patch("controllers.console.workspace.account.AccountService.get_change_email_data") + @patch("controllers.console.workspace.account.AccountService.is_change_email_error_rate_limit") + @patch("libs.login.check_csrf_token", return_value=None) + @patch("controllers.console.wraps.FeatureService.get_system_features") + def test_should_reject_validity_when_token_phase_is_unknown( + self, + mock_features, + mock_csrf, + mock_is_rate_limit, + mock_get_data, + mock_add_rate, + mock_revoke_token, + mock_generate_token, + mock_reset_rate, + mock_current_account, + mock_db, + app, + ): + """A token whose phase marker is a string but not a known transition must be rejected.""" + from controllers.console.auth.error import InvalidTokenError + + _mock_wraps_db(mock_db) + mock_features.return_value = SimpleNamespace(enable_change_email=True) + mock_current_account.return_value = (_build_account("old@example.com", "acc"), None) + mock_is_rate_limit.return_value = False + mock_get_data.return_value = { + "email": "user@example.com", + "code": "1234", + "old_email": "old@example.com", + AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: "something_else", + } + + with app.test_request_context( + "/account/change-email/validity", + method="POST", + json={"email": "user@example.com", "code": "1234", "token": "token-123"}, + ): + _set_logged_in_user(_build_account("tester@example.com", "tester")) + with pytest.raises(InvalidTokenError): + ChangeEmailCheckApi().post() + + mock_revoke_token.assert_not_called() + mock_generate_token.assert_not_called() + + @patch("controllers.console.wraps.db") + @patch("controllers.console.workspace.account.current_account_with_tenant") + @patch("controllers.console.workspace.account.AccountService.reset_change_email_error_rate_limit") + @patch("controllers.console.workspace.account.AccountService.generate_change_email_token") + @patch("controllers.console.workspace.account.AccountService.revoke_change_email_token") + @patch("controllers.console.workspace.account.AccountService.add_change_email_error_rate_limit") + @patch("controllers.console.workspace.account.AccountService.get_change_email_data") + @patch("controllers.console.workspace.account.AccountService.is_change_email_error_rate_limit") + @patch("libs.login.check_csrf_token", return_value=None) + @patch("controllers.console.wraps.FeatureService.get_system_features") + def test_should_reject_validity_when_token_has_no_phase( + self, + mock_features, + mock_csrf, + mock_is_rate_limit, + mock_get_data, + mock_add_rate, + mock_revoke_token, + mock_generate_token, + mock_reset_rate, + mock_current_account, + mock_db, + app, + ): + """A token minted without a phase marker (e.g. a hand-crafted token) must not validate.""" + from controllers.console.auth.error import InvalidTokenError + + _mock_wraps_db(mock_db) + mock_features.return_value = SimpleNamespace(enable_change_email=True) + mock_current_account.return_value = (_build_account("old@example.com", "acc"), None) + mock_is_rate_limit.return_value = False + mock_get_data.return_value = { + "email": "user@example.com", + "code": "1234", + "old_email": "old@example.com", + } + + with app.test_request_context( + "/account/change-email/validity", + method="POST", + json={"email": "user@example.com", "code": "1234", "token": "token-123"}, + ): + _set_logged_in_user(_build_account("tester@example.com", "tester")) + with pytest.raises(InvalidTokenError): + ChangeEmailCheckApi().post() + + mock_revoke_token.assert_not_called() + mock_generate_token.assert_not_called() + class TestChangeEmailReset: @patch("controllers.console.wraps.db") @@ -175,7 +384,11 @@ class TestChangeEmailReset: mock_current_account.return_value = (current_user, None) mock_is_freeze.return_value = False mock_check_unique.return_value = True - mock_get_data.return_value = {"old_email": "OLD@example.com"} + mock_get_data.return_value = { + "email": "new@example.com", + "old_email": "OLD@example.com", + AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED, + } mock_account_after_update = _build_account("new@example.com", "acc3-updated") mock_update_account.return_value = mock_account_after_update @@ -194,6 +407,155 @@ class TestChangeEmailReset: mock_send_notify.assert_called_once_with(email="new@example.com") mock_csrf.assert_called_once() + @patch("controllers.console.wraps.db") + @patch("controllers.console.workspace.account.current_account_with_tenant") + @patch("controllers.console.workspace.account.AccountService.send_change_email_completed_notify_email") + @patch("controllers.console.workspace.account.AccountService.update_account_email") + @patch("controllers.console.workspace.account.AccountService.revoke_change_email_token") + @patch("controllers.console.workspace.account.AccountService.get_change_email_data") + @patch("controllers.console.workspace.account.AccountService.check_email_unique") + @patch("controllers.console.workspace.account.AccountService.is_account_in_freeze") + @patch("libs.login.check_csrf_token", return_value=None) + @patch("controllers.console.wraps.FeatureService.get_system_features") + def test_should_reject_reset_when_token_phase_is_not_new_verified( + self, + mock_features, + mock_csrf, + mock_is_freeze, + mock_check_unique, + mock_get_data, + mock_revoke_token, + mock_update_account, + mock_send_notify, + mock_current_account, + mock_db, + app, + ): + """GHSA-4q3w-q5mc-45rq PoC: phase-1 token must not be usable against /reset.""" + from controllers.console.auth.error import InvalidTokenError + + _mock_wraps_db(mock_db) + mock_features.return_value = SimpleNamespace(enable_change_email=True) + current_user = _build_account("old@example.com", "acc3") + mock_current_account.return_value = (current_user, None) + mock_is_freeze.return_value = False + mock_check_unique.return_value = True + # Simulate a token straight out of step #1 (phase=old_email) — exactly + # the replay used in the advisory PoC. + mock_get_data.return_value = { + "email": "old@example.com", + "old_email": "old@example.com", + AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_OLD, + } + + with app.test_request_context( + "/account/change-email/reset", + method="POST", + json={"new_email": "attacker@example.com", "token": "token-from-step1"}, + ): + _set_logged_in_user(_build_account("tester@example.com", "tester")) + with pytest.raises(InvalidTokenError): + ChangeEmailResetApi().post() + + mock_revoke_token.assert_not_called() + mock_update_account.assert_not_called() + mock_send_notify.assert_not_called() + + @patch("controllers.console.wraps.db") + @patch("controllers.console.workspace.account.current_account_with_tenant") + @patch("controllers.console.workspace.account.AccountService.send_change_email_completed_notify_email") + @patch("controllers.console.workspace.account.AccountService.update_account_email") + @patch("controllers.console.workspace.account.AccountService.revoke_change_email_token") + @patch("controllers.console.workspace.account.AccountService.get_change_email_data") + @patch("controllers.console.workspace.account.AccountService.check_email_unique") + @patch("controllers.console.workspace.account.AccountService.is_account_in_freeze") + @patch("libs.login.check_csrf_token", return_value=None) + @patch("controllers.console.wraps.FeatureService.get_system_features") + def test_should_reject_reset_when_token_email_differs_from_payload_new_email( + self, + mock_features, + mock_csrf, + mock_is_freeze, + mock_check_unique, + mock_get_data, + mock_revoke_token, + mock_update_account, + mock_send_notify, + mock_current_account, + mock_db, + app, + ): + """A verified token for address A must not be replayed to change to address B.""" + from controllers.console.auth.error import InvalidTokenError + + _mock_wraps_db(mock_db) + mock_features.return_value = SimpleNamespace(enable_change_email=True) + current_user = _build_account("old@example.com", "acc3") + mock_current_account.return_value = (current_user, None) + mock_is_freeze.return_value = False + mock_check_unique.return_value = True + mock_get_data.return_value = { + "email": "verified@example.com", + "old_email": "old@example.com", + AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED, + } + + with app.test_request_context( + "/account/change-email/reset", + method="POST", + json={"new_email": "attacker@example.com", "token": "token-verified"}, + ): + _set_logged_in_user(_build_account("tester@example.com", "tester")) + with pytest.raises(InvalidTokenError): + ChangeEmailResetApi().post() + + mock_revoke_token.assert_not_called() + mock_update_account.assert_not_called() + mock_send_notify.assert_not_called() + + +class TestAccountServiceSendChangeEmailEmail: + """Service-level coverage for the phase-bound changes in `send_change_email_email`.""" + + def test_should_raise_value_error_for_invalid_phase(self): + with pytest.raises(ValueError, match="phase must be one of"): + AccountService.send_change_email_email( + email="user@example.com", + old_email="user@example.com", + phase="old_email_verified", + ) + + @patch("services.account_service.send_change_mail_task") + @patch("services.account_service.AccountService.change_email_rate_limiter") + @patch("services.account_service.AccountService.generate_change_email_token") + def test_should_stamp_phase_into_generated_token( + self, + mock_generate_token, + mock_rate_limiter, + mock_mail_task, + ): + mock_rate_limiter.is_rate_limited.return_value = False + mock_generate_token.return_value = ("123456", "the-token") + + returned = AccountService.send_change_email_email( + email="user@example.com", + old_email="user@example.com", + language="en-US", + phase=AccountService.CHANGE_EMAIL_PHASE_NEW, + ) + + assert returned == "the-token" + mock_generate_token.assert_called_once_with( + "user@example.com", + None, + old_email="user@example.com", + additional_data={ + AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_NEW, + }, + ) + mock_mail_task.delay.assert_called_once() + mock_rate_limiter.increment_rate_limit.assert_called_once_with("user@example.com") + class TestAccountDeletionFeedback: @patch("controllers.console.wraps.db") @@ -233,15 +595,20 @@ class TestCheckEmailUnique: def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup(): - session = MagicMock() + mock_session = MagicMock() first = MagicMock() first.scalar_one_or_none.return_value = None second = MagicMock() expected_account = MagicMock() second.scalar_one_or_none.return_value = expected_account - session.execute.side_effect = [first, second] + mock_session.execute.side_effect = [first, second] - result = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com", session=session) + mock_factory = MagicMock() + mock_factory.create_session.return_value.__enter__ = MagicMock(return_value=mock_session) + mock_factory.create_session.return_value.__exit__ = MagicMock(return_value=False) + + with patch("services.account_service.session_factory", mock_factory): + result = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com") assert result is expected_account - assert session.execute.call_count == 2 + assert mock_session.execute.call_count == 2 diff --git a/api/tests/unit_tests/controllers/console/test_workspace_members.py b/api/tests/unit_tests/controllers/console/test_workspace_members.py index 368892b9228..239fec84308 100644 --- a/api/tests/unit_tests/controllers/console/test_workspace_members.py +++ b/api/tests/unit_tests/controllers/console/test_workspace_members.py @@ -12,7 +12,7 @@ from models.account import Account, TenantAccountRole def app(): flask_app = Flask(__name__) flask_app.config["TESTING"] = True - flask_app.login_manager = SimpleNamespace(_load_user=lambda: None) + flask_app.login_manager = SimpleNamespace(load_user_from_request_context=lambda: None) return flask_app diff --git a/api/tests/unit_tests/controllers/console/workspace/test_load_balancing_config.py b/api/tests/unit_tests/controllers/console/workspace/test_load_balancing_config.py index 9c42ee9529a..b2f949c6e2b 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_load_balancing_config.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_load_balancing_config.py @@ -11,9 +11,10 @@ from unittest.mock import MagicMock import pytest from flask import Flask from flask.views import MethodView +from werkzeug.exceptions import Forbidden + from graphon.model_runtime.entities.model_entities import ModelType from graphon.model_runtime.errors.validate import CredentialsValidateFailedError -from werkzeug.exceptions import Forbidden if not hasattr(builtins, "MethodView"): builtins.MethodView = MethodView # type: ignore[attr-defined] diff --git a/api/tests/unit_tests/controllers/console/workspace/test_model_providers.py b/api/tests/unit_tests/controllers/console/workspace/test_model_providers.py index fb9eec98cb5..168479af1ea 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_model_providers.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_model_providers.py @@ -1,7 +1,6 @@ from unittest.mock import MagicMock, patch import pytest -from graphon.model_runtime.errors.validate import CredentialsValidateFailedError from pydantic_core import ValidationError from werkzeug.exceptions import Forbidden @@ -14,6 +13,7 @@ from controllers.console.workspace.model_providers import ( ModelProviderValidateApi, PreferredProviderTypeUpdateApi, ) +from graphon.model_runtime.errors.validate import CredentialsValidateFailedError VALID_UUID = "123e4567-e89b-12d3-a456-426614174000" INVALID_UUID = "123" diff --git a/api/tests/unit_tests/controllers/console/workspace/test_models.py b/api/tests/unit_tests/controllers/console/workspace/test_models.py index c829327bc7a..f0d32f81fb5 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_models.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_models.py @@ -2,8 +2,6 @@ from unittest.mock import MagicMock, patch import pytest from flask import Flask -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.errors.validate import CredentialsValidateFailedError from controllers.console.workspace.models import ( DefaultModelApi, @@ -16,6 +14,8 @@ from controllers.console.workspace.models import ( ModelProviderModelParameterRuleApi, ModelProviderModelValidateApi, ) +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.errors.validate import CredentialsValidateFailedError def unwrap(func): diff --git a/api/tests/unit_tests/controllers/console/workspace/test_workspace.py b/api/tests/unit_tests/controllers/console/workspace/test_workspace.py index b2d13dbbdf3..e82a29f0454 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_workspace.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_workspace.py @@ -18,6 +18,7 @@ from controllers.console.workspace.workspace import ( CustomConfigWorkspaceApi, SwitchWorkspaceApi, TenantApi, + TenantInfoResponse, TenantListApi, WebappLogoWorkspaceApi, WorkspaceInfoApi, @@ -435,6 +436,23 @@ class TestTenantApi: assert status == 200 +class TestTenantInfoResponse: + def test_tenant_info_response_normalizes_enum_and_datetime(self): + created_at = naive_utc_now() + payload = TenantInfoResponse.model_validate( + { + "id": "t1", + "status": TenantStatus.NORMAL, + "plan": CloudPlan.TEAM, + "created_at": created_at, + } + ).model_dump(mode="json") + + assert payload["status"] == "normal" + assert payload["plan"] == "team" + assert payload["created_at"] == int(created_at.timestamp()) + + class TestSwitchWorkspaceApi: def test_switch_success(self, app): api = SwitchWorkspaceApi() diff --git a/api/tests/unit_tests/controllers/inner_api/app/test_dsl.py b/api/tests/unit_tests/controllers/inner_api/app/test_dsl.py index 4a5f91cc5d5..71381e6a2b4 100644 --- a/api/tests/unit_tests/controllers/inner_api/app/test_dsl.py +++ b/api/tests/unit_tests/controllers/inner_api/app/test_dsl.py @@ -18,6 +18,7 @@ from controllers.inner_api.app.dsl import ( InnerAppDSLImportPayload, _get_active_account, ) +from models.account import AccountStatus from services.app_dsl_service import ImportStatus @@ -63,7 +64,7 @@ class TestGetActiveAccount: @patch("controllers.inner_api.app.dsl.db") def test_returns_active_account(self, mock_db): mock_account = MagicMock() - mock_account.status = "active" + mock_account.status = AccountStatus.ACTIVE mock_db.session.scalar.return_value = mock_account result = _get_active_account("user@example.com") @@ -74,7 +75,7 @@ class TestGetActiveAccount: @patch("controllers.inner_api.app.dsl.db") def test_returns_none_for_inactive_account(self, mock_db): mock_account = MagicMock() - mock_account.status = "banned" + mock_account.status = AccountStatus.BANNED mock_db.session.scalar.return_value = mock_account result = _get_active_account("banned@example.com") @@ -103,13 +104,15 @@ class TestEnterpriseAppDSLImport: @pytest.fixture def _mock_import_deps(self): """Patch db, Session, and AppDslService for import handler tests.""" + mock_session = MagicMock() + mock_session.__enter__ = MagicMock(return_value=mock_session) + mock_session.__exit__ = MagicMock(return_value=False) with ( patch("controllers.inner_api.app.dsl.db"), - patch("controllers.inner_api.app.dsl.Session") as mock_session, + patch("controllers.inner_api.app.dsl.Session", return_value=mock_session), patch("controllers.inner_api.app.dsl.AppDslService") as mock_dsl_cls, ): - mock_session.return_value.__enter__ = MagicMock(return_value=MagicMock()) - mock_session.return_value.__exit__ = MagicMock(return_value=False) + self._mock_session = mock_session self._mock_dsl = MagicMock() mock_dsl_cls.return_value = self._mock_dsl yield @@ -145,6 +148,8 @@ class TestEnterpriseAppDSLImport: assert status_code == 200 assert body["status"] == "completed" mock_account.set_tenant_id.assert_called_once_with("ws-123") + self._mock_session.commit.assert_called_once_with() + self._mock_session.rollback.assert_not_called() @pytest.mark.usefixtures("_mock_import_deps") @patch("controllers.inner_api.app.dsl._get_active_account") @@ -160,6 +165,8 @@ class TestEnterpriseAppDSLImport: assert status_code == 202 assert body["status"] == "pending" + self._mock_session.commit.assert_called_once_with() + self._mock_session.rollback.assert_not_called() @pytest.mark.usefixtures("_mock_import_deps") @patch("controllers.inner_api.app.dsl._get_active_account") @@ -175,6 +182,8 @@ class TestEnterpriseAppDSLImport: assert status_code == 400 assert body["status"] == "failed" + self._mock_session.rollback.assert_called_once_with() + self._mock_session.commit.assert_not_called() @patch("controllers.inner_api.app.dsl._get_active_account") def test_import_account_not_found_returns_404(self, mock_get_account, api_instance, app: Flask): diff --git a/api/tests/unit_tests/controllers/inner_api/plugin/test_plugin_wraps.py b/api/tests/unit_tests/controllers/inner_api/plugin/test_plugin_wraps.py index 957d7fbd9be..d1b09c3a584 100644 --- a/api/tests/unit_tests/controllers/inner_api/plugin/test_plugin_wraps.py +++ b/api/tests/unit_tests/controllers/inner_api/plugin/test_plugin_wraps.py @@ -2,6 +2,7 @@ Unit tests for inner_api plugin decorators """ +from typing import Any from unittest.mock import MagicMock, patch import pytest @@ -40,17 +41,22 @@ class TestTenantUserPayload: class TestGetUser: """Test get_user function""" + @patch("controllers.inner_api.plugin.wraps.select") @patch("controllers.inner_api.plugin.wraps.EndUser") @patch("controllers.inner_api.plugin.wraps.sessionmaker") @patch("controllers.inner_api.plugin.wraps.db") - def test_should_return_existing_user_by_id(self, mock_db, mock_sessionmaker, mock_enduser_class, app: Flask): + def test_should_return_existing_user_by_id( + self, mock_db, mock_sessionmaker, mock_enduser_class, mock_select, app: Flask + ): """Test returning existing user when found by ID""" # Arrange mock_user = MagicMock() mock_user.id = "user123" mock_session = MagicMock() mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session - mock_session.get.return_value = mock_user + mock_session.scalar.return_value = mock_user + mock_query = MagicMock() + mock_select.return_value.where.return_value.limit.return_value = mock_query # Act with app.app_context(): @@ -58,13 +64,45 @@ class TestGetUser: # Assert assert result == mock_user - mock_session.get.assert_called_once() + mock_session.scalar.assert_called_once() + @patch("controllers.inner_api.plugin.wraps.select") + @patch("controllers.inner_api.plugin.wraps.EndUser") + @patch("controllers.inner_api.plugin.wraps.sessionmaker") + @patch("controllers.inner_api.plugin.wraps.db") + def test_should_not_resolve_non_anonymous_users_across_tenants( + self, + mock_db, + mock_sessionmaker, + mock_enduser_class, + mock_select, + app: Flask, + ): + """Test that explicit user IDs remain scoped to the current tenant.""" + # Arrange + mock_session = MagicMock() + mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session + mock_session.scalar.return_value = None + mock_new_user = MagicMock() + mock_new_user.tenant_id = "tenant-current" + mock_enduser_class.return_value = mock_new_user + + # Act + with app.app_context(): + result = get_user("tenant-current", "foreign-user-id") + + # Assert + assert result == mock_new_user + mock_session.get.assert_not_called() + mock_session.scalar.assert_called_once() + mock_session.add.assert_called_once_with(mock_new_user) + + @patch("controllers.inner_api.plugin.wraps.select") @patch("controllers.inner_api.plugin.wraps.EndUser") @patch("controllers.inner_api.plugin.wraps.sessionmaker") @patch("controllers.inner_api.plugin.wraps.db") def test_should_return_existing_anonymous_user_by_session_id( - self, mock_db, mock_sessionmaker, mock_enduser_class, app: Flask + self, mock_db, mock_sessionmaker, mock_enduser_class, mock_select, app: Flask ): """Test returning existing anonymous user by session_id""" # Arrange @@ -72,8 +110,9 @@ class TestGetUser: mock_user.session_id = "anonymous_session" mock_session = MagicMock() mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session - # non-anonymous path uses session.get(); anonymous uses session.scalar() - mock_session.get.return_value = mock_user + mock_session.scalar.return_value = mock_user + mock_query = MagicMock() + mock_select.return_value.where.return_value.limit.return_value = mock_query # Act with app.app_context(): @@ -82,17 +121,22 @@ class TestGetUser: # Assert assert result == mock_user + @patch("controllers.inner_api.plugin.wraps.select") @patch("controllers.inner_api.plugin.wraps.EndUser") @patch("controllers.inner_api.plugin.wraps.sessionmaker") @patch("controllers.inner_api.plugin.wraps.db") - def test_should_create_new_user_when_not_found(self, mock_db, mock_sessionmaker, mock_enduser_class, app: Flask): + def test_should_create_new_user_when_not_found( + self, mock_db, mock_sessionmaker, mock_enduser_class, mock_select, app: Flask + ): """Test creating new user when not found in database""" # Arrange mock_session = MagicMock() mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session - mock_session.get.return_value = None + mock_session.scalar.return_value = None mock_new_user = MagicMock() mock_enduser_class.return_value = mock_new_user + mock_query = MagicMock() + mock_select.return_value.where.return_value.limit.return_value = mock_query # Act with app.app_context(): @@ -133,7 +177,7 @@ class TestGetUser: # Arrange mock_session = MagicMock() mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session - mock_session.get.side_effect = Exception("Database error") + mock_session.scalar.side_effect = Exception("Database error") # Act & Assert with app.app_context(): @@ -232,11 +276,11 @@ class TestGetUserTenant: class PluginTestPayload: """Simple test payload class""" - def __init__(self, data: dict): + def __init__(self, data: dict[str, Any]): self.value = data.get("value") @classmethod - def model_validate(cls, data: dict): + def model_validate(cls, data: dict[str, Any]): return cls(data) @@ -277,7 +321,7 @@ class TestPluginData: # Arrange class InvalidPayload: @classmethod - def model_validate(cls, data: dict): + def model_validate(cls, data: dict[str, Any]): raise Exception("Validation failed") @plugin_data(payload_type=InvalidPayload) diff --git a/api/tests/unit_tests/controllers/inner_api/workspace/test_workspace.py b/api/tests/unit_tests/controllers/inner_api/workspace/test_workspace.py index 56a8f949631..7d2193adc69 100644 --- a/api/tests/unit_tests/controllers/inner_api/workspace/test_workspace.py +++ b/api/tests/unit_tests/controllers/inner_api/workspace/test_workspace.py @@ -20,6 +20,7 @@ from controllers.inner_api.workspace.workspace import ( WorkspaceCreatePayload, WorkspaceOwnerlessPayload, ) +from models.account import TenantStatus class TestWorkspaceCreatePayload: @@ -98,7 +99,7 @@ class TestEnterpriseWorkspace: mock_tenant.id = "tenant-id" mock_tenant.name = "My Workspace" mock_tenant.plan = "sandbox" - mock_tenant.status = "normal" + mock_tenant.status = TenantStatus.NORMAL mock_tenant.created_at = now mock_tenant.updated_at = now mock_tenant_svc.create_tenant.return_value = mock_tenant @@ -162,7 +163,7 @@ class TestEnterpriseWorkspaceNoOwnerEmail: mock_tenant.name = "My Workspace" mock_tenant.encrypt_public_key = "pub-key" mock_tenant.plan = "sandbox" - mock_tenant.status = "normal" + mock_tenant.status = TenantStatus.NORMAL mock_tenant.custom_config = None mock_tenant.created_at = now mock_tenant.updated_at = now diff --git a/api/tests/unit_tests/controllers/service_api/app/test_app.py b/api/tests/unit_tests/controllers/service_api/app/test_app.py index 1507bf7a5fc..f48ace427d8 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_app.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_app.py @@ -10,6 +10,7 @@ from flask import Flask from controllers.service_api.app.app import AppInfoApi, AppMetaApi, AppParameterApi from controllers.service_api.app.error import AppUnavailableError +from models.account import TenantStatus from models.model import App, AppMode from tests.unit_tests.conftest import setup_mock_tenant_account_query @@ -62,7 +63,7 @@ class TestAppParameterApi: mock_validate_token.return_value = mock_api_token mock_tenant = Mock() - mock_tenant.status = "normal" + mock_tenant.status = TenantStatus.NORMAL # Mock DB queries for app and tenant mock_db.session.get.side_effect = [ @@ -110,7 +111,7 @@ class TestAppParameterApi: mock_validate_token.return_value = mock_api_token mock_tenant = Mock() - mock_tenant.status = "normal" + mock_tenant.status = TenantStatus.NORMAL mock_db.session.get.side_effect = [ mock_app_model, @@ -151,7 +152,7 @@ class TestAppParameterApi: mock_validate_token.return_value = mock_api_token mock_tenant = Mock() - mock_tenant.status = "normal" + mock_tenant.status = TenantStatus.NORMAL mock_db.session.get.side_effect = [ mock_app_model, @@ -190,7 +191,7 @@ class TestAppParameterApi: mock_validate_token.return_value = mock_api_token mock_tenant = Mock() - mock_tenant.status = "normal" + mock_tenant.status = TenantStatus.NORMAL mock_db.session.get.side_effect = [ mock_app_model, @@ -253,7 +254,7 @@ class TestAppMetaApi: mock_validate_token.return_value = mock_api_token mock_tenant = Mock() - mock_tenant.status = "normal" + mock_tenant.status = TenantStatus.NORMAL mock_db.session.get.side_effect = [ mock_app_model, @@ -321,7 +322,7 @@ class TestAppInfoApi: mock_validate_token.return_value = mock_api_token mock_tenant = Mock() - mock_tenant.status = "normal" + mock_tenant.status = TenantStatus.NORMAL mock_db.session.get.side_effect = [ mock_app_model, @@ -378,7 +379,7 @@ class TestAppInfoApi: mock_validate_token.return_value = mock_api_token mock_tenant = Mock() - mock_tenant.status = "normal" + mock_tenant.status = TenantStatus.NORMAL mock_db.session.get.side_effect = [ mock_app, @@ -424,7 +425,7 @@ class TestAppInfoApi: mock_validate_token.return_value = mock_api_token mock_tenant = Mock() - mock_tenant.status = "normal" + mock_tenant.status = TenantStatus.NORMAL mock_db.session.get.side_effect = [ mock_app, @@ -476,7 +477,7 @@ class TestAppInfoApi: mock_validate_token.return_value = mock_api_token mock_tenant = Mock() - mock_tenant.status = "normal" + mock_tenant.status = TenantStatus.NORMAL mock_db.session.get.side_effect = [ mock_app, diff --git a/api/tests/unit_tests/controllers/service_api/app/test_audio.py b/api/tests/unit_tests/controllers/service_api/app/test_audio.py index 5a8cb4619f4..c16ebad7399 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_audio.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_audio.py @@ -13,7 +13,6 @@ from types import SimpleNamespace from unittest.mock import Mock, patch import pytest -from graphon.model_runtime.errors.invoke import InvokeError from werkzeug.datastructures import FileStorage from werkzeug.exceptions import InternalServerError @@ -30,6 +29,7 @@ from controllers.service_api.app.error import ( UnsupportedAudioTypeError, ) from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError +from graphon.model_runtime.errors.invoke import InvokeError from services.audio_service import AudioService from services.errors.app_model_config import AppModelConfigBrokenError from services.errors.audio import ( @@ -95,30 +95,6 @@ class TestTextToAudioPayload: assert payload.streaming is True -# --------------------------------------------------------------------------- -# AudioService Interface Tests -# --------------------------------------------------------------------------- - - -class TestAudioServiceInterface: - """Test AudioService method interfaces exist.""" - - def test_transcript_asr_method_exists(self): - """Test that AudioService.transcript_asr exists.""" - assert hasattr(AudioService, "transcript_asr") - assert callable(AudioService.transcript_asr) - - def test_transcript_tts_method_exists(self): - """Test that AudioService.transcript_tts exists.""" - assert hasattr(AudioService, "transcript_tts") - assert callable(AudioService.transcript_tts) - - -# --------------------------------------------------------------------------- -# Audio Service Tests -# --------------------------------------------------------------------------- - - class TestAudioServiceInterface: """Test suite for AudioService interface methods.""" diff --git a/api/tests/unit_tests/controllers/service_api/app/test_completion.py b/api/tests/unit_tests/controllers/service_api/app/test_completion.py index 57681d8f5bc..3364c07e623 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_completion.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_completion.py @@ -16,7 +16,6 @@ from types import SimpleNamespace from unittest.mock import Mock, patch import pytest -from graphon.model_runtime.errors.invoke import InvokeError from pydantic import ValidationError from werkzeug.exceptions import BadRequest, NotFound @@ -35,6 +34,7 @@ from controllers.service_api.app.error import ( NotChatAppError, ) from core.errors.error import QuotaExceededError +from graphon.model_runtime.errors.invoke import InvokeError from models.model import App, AppMode, EndUser from services.app_generate_service import AppGenerateService from services.app_task_service import AppTaskService diff --git a/api/tests/unit_tests/controllers/service_api/app/test_conversation.py b/api/tests/unit_tests/controllers/service_api/app/test_conversation.py index dbd06677d8a..4fb8ecf7844 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_conversation.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_conversation.py @@ -15,6 +15,7 @@ Focus on: import sys import uuid +from datetime import UTC, datetime from types import SimpleNamespace from unittest.mock import Mock, patch @@ -29,11 +30,16 @@ from controllers.service_api.app.conversation import ( ConversationRenameApi, ConversationRenamePayload, ConversationVariableDetailApi, + ConversationVariableInfiniteScrollPaginationResponse, + ConversationVariableResponse, ConversationVariablesApi, ConversationVariablesQuery, ConversationVariableUpdatePayload, ) from controllers.service_api.app.error import NotChatAppError +from fields._value_type_serializer import serialize_value_type +from graphon.variables import StringSegment +from graphon.variables.types import SegmentType from models.model import App, AppMode, EndUser from services.conversation_service import ConversationService from services.errors.conversation import ( @@ -261,6 +267,72 @@ class TestConversationVariableUpdatePayload: assert payload.value == nested +class TestConversationVariableResponseModels: + def test_variable_response_normalizes_value_type_and_timestamps(self): + created_at = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC) + response = ConversationVariableResponse.model_validate( + { + "id": "550e8400-e29b-41d4-a716-446655440000", + "name": "foo", + "value_type": SegmentType.INTEGER, + "value": 1, + "description": "desc", + "created_at": created_at, + "updated_at": created_at, + } + ) + assert response.value_type == "number" + assert response.value == "1" + assert response.created_at == int(created_at.timestamp()) + assert response.updated_at == int(created_at.timestamp()) + + def test_variable_response_normalizes_string_value_type_alias(self): + response = ConversationVariableResponse.model_validate( + { + "id": "550e8400-e29b-41d4-a716-446655440000", + "name": "foo", + "value_type": SegmentType.INTEGER.value, + } + ) + + assert response.value_type == "number" + + def test_variable_response_normalizes_callable_exposed_type(self): + response = ConversationVariableResponse.model_validate( + { + "id": "550e8400-e29b-41d4-a716-446655440000", + "name": "foo", + "value_type": SimpleNamespace(exposed_type=lambda: SegmentType.STRING.exposed_type()), + } + ) + + assert response.value_type == "string" + + def test_serialize_value_type_supports_segments_and_mappings(self): + assert serialize_value_type(StringSegment(value="hello")) == "string" + assert serialize_value_type({"value_type": SegmentType.INTEGER}) == "number" + + def test_variable_pagination_response(self): + response = ConversationVariableInfiniteScrollPaginationResponse.model_validate( + { + "limit": 1, + "has_more": False, + "data": [ + { + "id": "550e8400-e29b-41d4-a716-446655440000", + "name": "foo", + "value_type": "string", + "value": "bar", + } + ], + } + ) + assert response.limit == 1 + assert response.has_more is False + assert len(response.data) == 1 + assert response.data[0].name == "foo" + + class TestConversationAppModeValidation: """Test app mode validation for conversation endpoints.""" @@ -549,6 +621,44 @@ class TestConversationVariablesApiController: with pytest.raises(NotFound): handler(api, app_model=app_model, end_user=end_user, c_id="00000000-0000-0000-0000-000000000001") + def test_success_serializes_response(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + created_at = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC) + monkeypatch.setattr( + ConversationService, + "get_conversational_variable", + lambda *_args, **_kwargs: SimpleNamespace( + limit=1, + has_more=False, + data=[ + { + "id": "550e8400-e29b-41d4-a716-446655440000", + "name": "foo", + "value_type": SegmentType.INTEGER, + "value": 1, + "created_at": created_at, + "updated_at": created_at, + } + ], + ), + ) + + api = ConversationVariablesApi() + handler = _unwrap(api.get) + app_model = SimpleNamespace(mode=AppMode.CHAT.value) + end_user = SimpleNamespace() + + with app.test_request_context( + "/conversations/1/variables?limit=20", + method="GET", + ): + result = handler(api, app_model=app_model, end_user=end_user, c_id="00000000-0000-0000-0000-000000000001") + + assert result["limit"] == 1 + assert result["has_more"] is False + assert result["data"][0]["value_type"] == "number" + assert result["data"][0]["value"] == "1" + assert result["data"][0]["created_at"] == int(created_at.timestamp()) + class TestConversationVariableDetailApiController: def test_update_type_mismatch(self, app, monkeypatch: pytest.MonkeyPatch) -> None: @@ -602,3 +712,41 @@ class TestConversationVariableDetailApiController: c_id="00000000-0000-0000-0000-000000000001", variable_id="00000000-0000-0000-0000-000000000002", ) + + def test_update_success_serializes_response(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + created_at = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC) + monkeypatch.setattr( + ConversationService, + "update_conversation_variable", + lambda *_args, **_kwargs: { + "id": "550e8400-e29b-41d4-a716-446655440000", + "name": "foo", + "value_type": SegmentType.INTEGER, + "value": 1, + "created_at": created_at, + "updated_at": created_at, + }, + ) + + api = ConversationVariableDetailApi() + handler = _unwrap(api.put) + app_model = SimpleNamespace(mode=AppMode.CHAT.value) + end_user = SimpleNamespace() + + with app.test_request_context( + "/conversations/1/variables/2", + method="PUT", + json={"value": 1}, + ): + result = handler( + api, + app_model=app_model, + end_user=end_user, + c_id="00000000-0000-0000-0000-000000000001", + variable_id="00000000-0000-0000-0000-000000000002", + ) + + assert result["id"] == "550e8400-e29b-41d4-a716-446655440000" + assert result["value_type"] == "number" + assert result["value"] == "1" + assert result["created_at"] == int(created_at.timestamp()) diff --git a/api/tests/unit_tests/controllers/service_api/app/test_workflow.py b/api/tests/unit_tests/controllers/service_api/app/test_workflow.py index cfa21bf2dd6..da09ec13cea 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_workflow.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_workflow.py @@ -15,11 +15,11 @@ Focus on: import sys import uuid +from datetime import UTC, datetime from types import SimpleNamespace from unittest.mock import Mock, patch import pytest -from graphon.enums import WorkflowExecutionStatus from werkzeug.exceptions import BadRequest, NotFound from controllers.service_api.app.error import NotWorkflowAppError @@ -36,6 +36,7 @@ from controllers.service_api.app.workflow import ( WorkflowTaskStopApi, ) from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError +from graphon.enums import WorkflowExecutionStatus from models.model import App, AppMode from services.app_generate_service import AppGenerateService from services.errors.app import IsDraftWorkflowError, WorkflowNotFoundError @@ -43,6 +44,22 @@ from services.errors.llm import InvokeRateLimitError from services.workflow_app_service import WorkflowAppService +def _make_mock_workflow_run(run_id: str = "run-1"): + run = Mock() + run.id = run_id + run.workflow_id = "wf-1" + run.status = WorkflowExecutionStatus.SUCCEEDED + run.inputs = {"input": "value"} + run.outputs_dict = {"output": "value"} + run.error = None + run.total_steps = 1 + run.total_tokens = 10 + run.created_at = datetime(2026, 1, 1, tzinfo=UTC) + run.finished_at = datetime(2026, 1, 1, tzinfo=UTC) + run.elapsed_time = 0.1 + return run + + class TestWorkflowRunPayload: """Test suite for WorkflowRunPayload Pydantic model.""" @@ -359,7 +376,7 @@ class TestWorkflowRunDetailApi: handler(api, app_model=app_model, workflow_run_id="run") def test_success(self, monkeypatch: pytest.MonkeyPatch) -> None: - run = SimpleNamespace(id="run") + run = _make_mock_workflow_run(run_id="run") repo = SimpleNamespace(get_workflow_run_by_id=lambda **_kwargs: run) workflow_module = sys.modules["controllers.service_api.app.workflow"] monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object())) @@ -373,7 +390,10 @@ class TestWorkflowRunDetailApi: handler = _unwrap(api.get) app_model = SimpleNamespace(mode=AppMode.WORKFLOW.value, tenant_id="t1", id="a1") - assert handler(api, app_model=app_model, workflow_run_id="run") == run + result = handler(api, app_model=app_model, workflow_run_id="run") + assert result["id"] == "run" + assert result["workflow_id"] == "wf-1" + assert result["status"] == "succeeded" class TestWorkflowRunApi: @@ -490,7 +510,7 @@ class TestWorkflowAppLogApi: monkeypatch.setattr( WorkflowAppService, "get_paginate_workflow_app_logs", - lambda *_args, **_kwargs: {"items": [], "total": 0}, + lambda *_args, **_kwargs: {"page": 1, "limit": 20, "total": 0, "has_more": False, "data": []}, ) api = WorkflowAppLogApi() @@ -500,7 +520,7 @@ class TestWorkflowAppLogApi: with app.test_request_context("/workflows/logs", method="GET"): response = handler(api, app_model=app_model) - assert response == {"items": [], "total": 0} + assert response == {"page": 1, "limit": 20, "total": 0, "has_more": False, "data": []} # ============================================================================= @@ -527,9 +547,8 @@ def mock_workflow_app(): class TestWorkflowRunDetailApiGet: """Test suite for WorkflowRunDetailApi.get() endpoint. - ``get`` is wrapped by ``@validate_app_token`` (preserves ``__wrapped__``) - and ``@service_api_ns.marshal_with``. We call the unwrapped method - directly; ``marshal_with`` is a no-op when calling directly. + ``get`` is wrapped by ``@validate_app_token`` (preserves ``__wrapped__``), + and we call the unwrapped method directly in tests. """ @patch("controllers.service_api.app.workflow.DifyAPIRepositoryFactory") @@ -542,9 +561,7 @@ class TestWorkflowRunDetailApiGet: mock_workflow_app, ): """Test successful workflow run detail retrieval.""" - mock_run = Mock() - mock_run.id = "run-1" - mock_run.status = "succeeded" + mock_run = _make_mock_workflow_run(run_id="run-1") mock_repo = Mock() mock_repo.get_workflow_run_by_id.return_value = mock_run mock_repo_factory.create_api_workflow_run_repository.return_value = mock_repo @@ -558,7 +575,8 @@ class TestWorkflowRunDetailApiGet: api = WorkflowRunDetailApi() result = _unwrap(api.get)(api, app_model=mock_workflow_app, workflow_run_id=mock_run.id) - assert result == mock_run + assert result["id"] == mock_run.id + assert result["status"] == "succeeded" @patch("controllers.service_api.app.workflow.db") def test_get_workflow_run_wrong_app_mode(self, mock_db, app): @@ -622,8 +640,7 @@ class TestWorkflowTaskStopApiPost: class TestWorkflowAppLogApiGet: """Test suite for WorkflowAppLogApi.get() endpoint. - ``get`` is wrapped by ``@validate_app_token`` and - ``@service_api_ns.marshal_with``. + ``get`` is wrapped by ``@validate_app_token``. """ @patch("controllers.service_api.app.workflow.WorkflowAppService") @@ -637,6 +654,10 @@ class TestWorkflowAppLogApiGet: ): """Test successful workflow log retrieval.""" mock_pagination = Mock() + mock_pagination.page = 1 + mock_pagination.limit = 20 + mock_pagination.total = 0 + mock_pagination.has_more = False mock_pagination.data = [] mock_svc_instance = Mock() mock_svc_instance.get_paginate_workflow_app_logs.return_value = mock_pagination @@ -661,4 +682,4 @@ class TestWorkflowAppLogApiGet: api = WorkflowAppLogApi() result = _unwrap(api.get)(api, app_model=mock_workflow_app) - assert result == mock_pagination + assert result == {"page": 1, "limit": 20, "total": 0, "has_more": False, "data": []} diff --git a/api/tests/unit_tests/controllers/service_api/app/test_workflow_fields.py b/api/tests/unit_tests/controllers/service_api/app/test_workflow_fields.py index 4b8e3a738cb..eda270258d5 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_workflow_fields.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_workflow_fields.py @@ -1,8 +1,7 @@ from types import SimpleNamespace -from graphon.enums import WorkflowExecutionStatus - from controllers.service_api.app.workflow import WorkflowRunOutputsField, WorkflowRunStatusField +from graphon.enums import WorkflowExecutionStatus def test_workflow_run_status_field_with_enum() -> None: diff --git a/api/tests/unit_tests/controllers/service_api/dataset/test_document.py b/api/tests/unit_tests/controllers/service_api/dataset/test_document.py index 12d5e7345d2..288659b192d 100644 --- a/api/tests/unit_tests/controllers/service_api/dataset/test_document.py +++ b/api/tests/unit_tests/controllers/service_api/dataset/test_document.py @@ -22,6 +22,8 @@ import pytest from werkzeug.exceptions import Forbidden, NotFound from controllers.service_api.dataset.document import ( + DeprecatedDocumentAddByTextApi, + DeprecatedDocumentUpdateByTextApi, DocumentAddByFileApi, DocumentAddByTextApi, DocumentApi, @@ -1005,7 +1007,7 @@ class TestDocumentAddByTextApi: # Act with app.test_request_context( - f"/datasets/{mock_dataset.id}/document/create_by_text", + f"/datasets/{mock_dataset.id}/document/create-by-text", method="POST", json={ "name": "Test Document", @@ -1037,7 +1039,7 @@ class TestDocumentAddByTextApi: # Act & Assert with app.test_request_context( - f"/datasets/{mock_dataset.id}/document/create_by_text", + f"/datasets/{mock_dataset.id}/document/create-by-text", method="POST", json={"name": "Test Document", "text": "Content"}, headers={"Authorization": "Bearer test_token"}, @@ -1066,7 +1068,7 @@ class TestDocumentAddByTextApi: # Act & Assert with app.test_request_context( - f"/datasets/{mock_dataset.id}/document/create_by_text", + f"/datasets/{mock_dataset.id}/document/create-by-text", method="POST", json={"name": "Test Document", "text": "Content"}, headers={"Authorization": "Bearer test_token"}, @@ -1093,6 +1095,20 @@ class TestArchivedDocumentImmutableError: assert error.code == 403 +class TestDocumentTextRouteDeprecation: + """Test that legacy underscore text routes stay marked deprecated.""" + + def test_create_by_text_legacy_alias_is_deprecated(self): + """Ensure only the legacy create-by-text alias is marked deprecated.""" + assert DeprecatedDocumentAddByTextApi.post.__apidoc__["deprecated"] is True + assert DocumentAddByTextApi.post.__apidoc__.get("deprecated") is not True + + def test_update_by_text_legacy_alias_is_deprecated(self): + """Ensure only the legacy update-by-text alias is marked deprecated.""" + assert DeprecatedDocumentUpdateByTextApi.post.__apidoc__["deprecated"] is True + assert DocumentUpdateByTextApi.post.__apidoc__.get("deprecated") is not True + + # ============================================================================= # Endpoint tests for DocumentUpdateByTextApi, DocumentAddByFileApi, # DocumentUpdateByFileApi. @@ -1162,7 +1178,7 @@ class TestDocumentUpdateByTextApiPost: doc_id = str(uuid.uuid4()) with app.test_request_context( - f"/datasets/{mock_dataset.id}/documents/{doc_id}/update_by_text", + f"/datasets/{mock_dataset.id}/documents/{doc_id}/update-by-text", method="POST", json={"name": "Updated Doc", "text": "New content"}, headers={"Authorization": "Bearer test_token"}, @@ -1195,7 +1211,7 @@ class TestDocumentUpdateByTextApiPost: doc_id = str(uuid.uuid4()) with app.test_request_context( - f"/datasets/{mock_dataset.id}/documents/{doc_id}/update_by_text", + f"/datasets/{mock_dataset.id}/documents/{doc_id}/update-by-text", method="POST", json={"name": "Doc", "text": "Content"}, headers={"Authorization": "Bearer test_token"}, diff --git a/api/tests/unit_tests/controllers/service_api/test_site.py b/api/tests/unit_tests/controllers/service_api/test_site.py deleted file mode 100644 index c0b40d070a5..00000000000 --- a/api/tests/unit_tests/controllers/service_api/test_site.py +++ /dev/null @@ -1,270 +0,0 @@ -""" -Unit tests for Service API Site controller -""" - -import uuid -from unittest.mock import Mock, patch - -import pytest -from werkzeug.exceptions import Forbidden - -from controllers.service_api.app.site import AppSiteApi -from models.account import TenantStatus -from models.model import App, Site -from tests.unit_tests.conftest import setup_mock_tenant_account_query - - -class TestAppSiteApi: - """Test suite for AppSiteApi""" - - @pytest.fixture - def mock_app_model(self): - """Create a mock App model with tenant.""" - app = Mock(spec=App) - app.id = str(uuid.uuid4()) - app.tenant_id = str(uuid.uuid4()) - app.status = "normal" - app.enable_api = True - - mock_tenant = Mock() - mock_tenant.id = app.tenant_id - mock_tenant.status = TenantStatus.NORMAL - app.tenant = mock_tenant - - return app - - @pytest.fixture - def mock_site(self): - """Create a mock Site model.""" - site = Mock(spec=Site) - site.id = str(uuid.uuid4()) - site.app_id = str(uuid.uuid4()) - site.title = "Test Site" - site.icon = "icon-url" - site.icon_background = "#ffffff" - site.description = "Site description" - site.copyright = "Copyright 2024" - site.privacy_policy = "Privacy policy text" - site.custom_disclaimer = "Custom disclaimer" - site.default_language = "en-US" - site.prompt_public = True - site.show_workflow_steps = True - site.use_icon_as_answer_icon = False - site.chat_color_theme = "light" - site.chat_color_theme_inverted = False - site.icon_type = "image" - site.created_at = "2024-01-01T00:00:00" - site.updated_at = "2024-01-01T00:00:00" - return site - - @patch("controllers.service_api.wraps.user_logged_in") - @patch("controllers.service_api.app.site.db") - @patch("controllers.service_api.wraps.current_app") - @patch("controllers.service_api.wraps.validate_and_get_api_token") - @patch("controllers.service_api.wraps.db") - def test_get_site_success( - self, - mock_wraps_db, - mock_validate_token, - mock_current_app, - mock_db, - mock_user_logged_in, - app, - mock_app_model, - mock_site, - ): - """Test successful retrieval of site configuration.""" - # Arrange - mock_current_app.login_manager = Mock() - - # Mock authentication - mock_api_token = Mock() - mock_api_token.app_id = mock_app_model.id - mock_api_token.tenant_id = mock_app_model.tenant_id - mock_validate_token.return_value = mock_api_token - - mock_tenant = Mock() - mock_tenant.status = TenantStatus.NORMAL - mock_app_model.tenant = mock_tenant - - # Mock wraps.db for authentication - mock_wraps_db.session.get.side_effect = [ - mock_app_model, - mock_tenant, - ] - - mock_account = Mock() - mock_account.current_tenant = mock_tenant - setup_mock_tenant_account_query(mock_wraps_db, mock_tenant, mock_account) - - # Mock site.db for site query - mock_db.session.scalar.return_value = mock_site - - # Act - with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test_token"}): - api = AppSiteApi() - response = api.get() - - # Assert - assert response["title"] == "Test Site" - assert response["icon"] == "icon-url" - assert response["description"] == "Site description" - mock_db.session.scalar.assert_called_once() - - @patch("controllers.service_api.wraps.user_logged_in") - @patch("controllers.service_api.app.site.db") - @patch("controllers.service_api.wraps.current_app") - @patch("controllers.service_api.wraps.validate_and_get_api_token") - @patch("controllers.service_api.wraps.db") - def test_get_site_not_found( - self, - mock_wraps_db, - mock_validate_token, - mock_current_app, - mock_db, - mock_user_logged_in, - app, - mock_app_model, - ): - """Test that Forbidden is raised when site is not found.""" - # Arrange - mock_current_app.login_manager = Mock() - - # Mock authentication - mock_api_token = Mock() - mock_api_token.app_id = mock_app_model.id - mock_api_token.tenant_id = mock_app_model.tenant_id - mock_validate_token.return_value = mock_api_token - - mock_tenant = Mock() - mock_tenant.status = TenantStatus.NORMAL - mock_app_model.tenant = mock_tenant - - mock_wraps_db.session.get.side_effect = [ - mock_app_model, - mock_tenant, - ] - - mock_account = Mock() - mock_account.current_tenant = mock_tenant - setup_mock_tenant_account_query(mock_wraps_db, mock_tenant, mock_account) - - # Mock site query to return None - mock_db.session.scalar.return_value = None - - # Act & Assert - with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test_token"}): - api = AppSiteApi() - with pytest.raises(Forbidden): - api.get() - - @patch("controllers.service_api.wraps.user_logged_in") - @patch("controllers.service_api.app.site.db") - @patch("controllers.service_api.wraps.current_app") - @patch("controllers.service_api.wraps.validate_and_get_api_token") - @patch("controllers.service_api.wraps.db") - def test_get_site_tenant_archived( - self, - mock_wraps_db, - mock_validate_token, - mock_current_app, - mock_db, - mock_user_logged_in, - app, - mock_app_model, - mock_site, - ): - """Test that Forbidden is raised when tenant is archived.""" - # Arrange - mock_current_app.login_manager = Mock() - - # Mock authentication - mock_api_token = Mock() - mock_api_token.app_id = mock_app_model.id - mock_api_token.tenant_id = mock_app_model.tenant_id - mock_validate_token.return_value = mock_api_token - - mock_tenant = Mock() - mock_tenant.status = TenantStatus.NORMAL - - mock_wraps_db.session.get.side_effect = [ - mock_app_model, - mock_tenant, - ] - - mock_account = Mock() - mock_account.current_tenant = mock_tenant - setup_mock_tenant_account_query(mock_wraps_db, mock_tenant, mock_account) - - # Mock site query - mock_db.session.scalar.return_value = mock_site - - # Set tenant status to archived AFTER authentication - mock_app_model.tenant.status = TenantStatus.ARCHIVE - - # Act & Assert - with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test_token"}): - api = AppSiteApi() - with pytest.raises(Forbidden): - api.get() - - @patch("controllers.service_api.wraps.user_logged_in") - @patch("controllers.service_api.app.site.db") - @patch("controllers.service_api.wraps.current_app") - @patch("controllers.service_api.wraps.validate_and_get_api_token") - @patch("controllers.service_api.wraps.db") - def test_get_site_queries_by_app_id( - self, mock_wraps_db, mock_validate_token, mock_current_app, mock_db, mock_user_logged_in, app, mock_app_model - ): - """Test that site is queried using the app model's id.""" - # Arrange - mock_current_app.login_manager = Mock() - - # Mock authentication - mock_api_token = Mock() - mock_api_token.app_id = mock_app_model.id - mock_api_token.tenant_id = mock_app_model.tenant_id - mock_validate_token.return_value = mock_api_token - - mock_tenant = Mock() - mock_tenant.status = TenantStatus.NORMAL - mock_app_model.tenant = mock_tenant - - mock_wraps_db.session.get.side_effect = [ - mock_app_model, - mock_tenant, - ] - - mock_account = Mock() - mock_account.current_tenant = mock_tenant - setup_mock_tenant_account_query(mock_wraps_db, mock_tenant, mock_account) - - mock_site = Mock(spec=Site) - mock_site.id = str(uuid.uuid4()) - mock_site.app_id = mock_app_model.id - mock_site.title = "Test Site" - mock_site.icon = "icon-url" - mock_site.icon_background = "#ffffff" - mock_site.description = "Site description" - mock_site.copyright = "Copyright 2024" - mock_site.privacy_policy = "Privacy policy text" - mock_site.custom_disclaimer = "Custom disclaimer" - mock_site.default_language = "en-US" - mock_site.prompt_public = True - mock_site.show_workflow_steps = True - mock_site.use_icon_as_answer_icon = False - mock_site.chat_color_theme = "light" - mock_site.chat_color_theme_inverted = False - mock_site.icon_type = "image" - mock_site.created_at = "2024-01-01T00:00:00" - mock_site.updated_at = "2024-01-01T00:00:00" - mock_db.session.scalar.return_value = mock_site - - # Act - with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test_token"}): - api = AppSiteApi() - api.get() - - # Assert - # The query was executed successfully (site returned), which validates the correct query was made - mock_db.session.scalar.assert_called_once() diff --git a/api/tests/unit_tests/controllers/web/test_audio.py b/api/tests/unit_tests/controllers/web/test_audio.py index cbfc8fa6130..a6ca441801b 100644 --- a/api/tests/unit_tests/controllers/web/test_audio.py +++ b/api/tests/unit_tests/controllers/web/test_audio.py @@ -8,7 +8,6 @@ from unittest.mock import MagicMock, patch import pytest from flask import Flask -from graphon.model_runtime.errors.invoke import InvokeError from controllers.web.audio import AudioApi, TextApi from controllers.web.error import ( @@ -22,6 +21,7 @@ from controllers.web.error import ( UnsupportedAudioTypeError, ) from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError +from graphon.model_runtime.errors.invoke import InvokeError from services.errors.audio import ( AudioTooLargeServiceError, NoAudioUploadedServiceError, diff --git a/api/tests/unit_tests/controllers/web/test_completion.py b/api/tests/unit_tests/controllers/web/test_completion.py index 49039d03fe1..4f8d848637d 100644 --- a/api/tests/unit_tests/controllers/web/test_completion.py +++ b/api/tests/unit_tests/controllers/web/test_completion.py @@ -7,7 +7,6 @@ from unittest.mock import MagicMock, patch import pytest from flask import Flask -from graphon.model_runtime.errors.invoke import InvokeError from controllers.web.completion import ChatApi, ChatStopApi, CompletionApi, CompletionStopApi from controllers.web.error import ( @@ -19,6 +18,7 @@ from controllers.web.error import ( ProviderQuotaExceededError, ) from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError +from graphon.model_runtime.errors.invoke import InvokeError def _completion_app() -> SimpleNamespace: diff --git a/api/tests/unit_tests/controllers/web/test_message_endpoints.py b/api/tests/unit_tests/controllers/web/test_message_endpoints.py index 89ab93d8d4d..da88b109a81 100644 --- a/api/tests/unit_tests/controllers/web/test_message_endpoints.py +++ b/api/tests/unit_tests/controllers/web/test_message_endpoints.py @@ -129,12 +129,6 @@ class TestMessageSuggestedQuestionApi: with pytest.raises(NotChatAppError): MessageSuggestedQuestionApi().get(_completion_app(), _end_user(), msg_id) - def test_wrong_mode_raises(self, app: Flask) -> None: - msg_id = uuid4() - with app.test_request_context(f"/messages/{msg_id}/suggested-questions"): - with pytest.raises(NotChatAppError): - MessageSuggestedQuestionApi().get(_completion_app(), _end_user(), msg_id) - @patch("controllers.web.message.MessageService.get_suggested_questions_after_answer") def test_happy_path(self, mock_suggest: MagicMock, app: Flask) -> None: msg_id = uuid4() diff --git a/api/tests/unit_tests/controllers/web/test_pydantic_models.py b/api/tests/unit_tests/controllers/web/test_pydantic_models.py index dcf81337123..bceb65b89f6 100644 --- a/api/tests/unit_tests/controllers/web/test_pydantic_models.py +++ b/api/tests/unit_tests/controllers/web/test_pydantic_models.py @@ -198,7 +198,7 @@ class TestMessageListQuery: assert q.limit == 20 def test_invalid_conversation_id(self) -> None: - with pytest.raises(ValidationError, match="not a valid uuid"): + with pytest.raises(ValidationError, match="must be a valid UUID"): MessageListQuery(conversation_id="bad") def test_limit_bounds(self) -> None: @@ -216,7 +216,7 @@ class TestMessageListQuery: def test_invalid_first_id(self) -> None: cid = str(uuid4()) - with pytest.raises(ValidationError, match="not a valid uuid"): + with pytest.raises(ValidationError, match="must be a valid UUID"): MessageListQuery(conversation_id=cid, first_id="invalid") diff --git a/api/tests/unit_tests/controllers/web/test_web_login.py b/api/tests/unit_tests/controllers/web/test_web_login.py index 0661c02578c..a01587d64a1 100644 --- a/api/tests/unit_tests/controllers/web/test_web_login.py +++ b/api/tests/unit_tests/controllers/web/test_web_login.py @@ -4,9 +4,12 @@ from unittest.mock import MagicMock, patch import pytest from flask import Flask +from jwt import InvalidTokenError +from werkzeug.exceptions import Unauthorized import services.errors.account from controllers.web.login import EmailCodeLoginApi, EmailCodeLoginSendEmailApi, LoginApi, LoginStatusApi, LogoutApi +from services.entities.auth_entities import LoginFailureReason def encode_code(code: str) -> str: @@ -115,13 +118,18 @@ class TestLoginApi: def test_login_banned_account(self, mock_auth: MagicMock, app: Flask) -> None: from controllers.console.error import AccountBannedError - with app.test_request_context( - "/web/login", - method="POST", - json={"email": "user@example.com", "password": base64.b64encode(b"Valid1234").decode()}, - ): - with pytest.raises(AccountBannedError): - LoginApi().post() + with patch("controllers.web.login.logger.warning") as mock_log_warning: + with app.test_request_context( + "/web/login", + method="POST", + json={"email": "user@example.com", "password": base64.b64encode(b"Valid1234").decode()}, + ): + with pytest.raises(AccountBannedError): + LoginApi().post() + + assert mock_log_warning.call_count == 1 + assert mock_log_warning.call_args.args[1] == "user@example.com" + assert mock_log_warning.call_args.args[2] == LoginFailureReason.ACCOUNT_BANNED @patch( "controllers.web.login.WebAppAuthService.authenticate", @@ -130,13 +138,87 @@ class TestLoginApi: def test_login_wrong_password(self, mock_auth: MagicMock, app: Flask) -> None: from controllers.console.auth.error import AuthenticationFailedError - with app.test_request_context( - "/web/login", - method="POST", - json={"email": "user@example.com", "password": base64.b64encode(b"Valid1234").decode()}, - ): - with pytest.raises(AuthenticationFailedError): - LoginApi().post() + with patch("controllers.web.login.logger.warning") as mock_log_warning: + with app.test_request_context( + "/web/login", + method="POST", + json={"email": "user@example.com", "password": base64.b64encode(b"Valid1234").decode()}, + ): + with pytest.raises(AuthenticationFailedError): + LoginApi().post() + + assert mock_log_warning.call_count == 1 + assert mock_log_warning.call_args.args[1] == "user@example.com" + assert mock_log_warning.call_args.args[2] == LoginFailureReason.INVALID_CREDENTIALS + + @patch( + "controllers.web.login.WebAppAuthService.authenticate", + side_effect=services.errors.account.AccountNotFoundError(), + ) + def test_login_account_not_found(self, mock_auth: MagicMock, app: Flask) -> None: + from controllers.console.auth.error import AuthenticationFailedError + + with patch("controllers.web.login.logger.warning") as mock_log_warning: + with app.test_request_context( + "/web/login", + method="POST", + json={"email": "missing@example.com", "password": base64.b64encode(b"Valid1234").decode()}, + ): + with pytest.raises(AuthenticationFailedError): + LoginApi().post() + + assert mock_log_warning.call_count == 1 + assert mock_log_warning.call_args.args[1] == "missing@example.com" + assert mock_log_warning.call_args.args[2] == LoginFailureReason.ACCOUNT_NOT_FOUND + + @patch("controllers.web.login.WebAppAuthService.get_email_code_login_data", return_value=None) + def test_email_code_login_logs_invalid_token(self, mock_get_token_data: MagicMock, app: Flask) -> None: + with patch("controllers.web.login.logger.warning") as mock_log_warning: + with app.test_request_context( + "/web/email-code-login/validity", + method="POST", + json={"email": "user@example.com", "code": encode_code("123456"), "token": "token-123"}, + ): + with pytest.raises(InvalidTokenError): + EmailCodeLoginApi().post() + + mock_get_token_data.assert_called_once_with("token-123") + assert mock_log_warning.call_count == 1 + assert mock_log_warning.call_args.args[1] == "user@example.com" + assert mock_log_warning.call_args.args[2] == LoginFailureReason.INVALID_EMAIL_CODE_TOKEN + + @patch("controllers.web.login.WebAppAuthService.revoke_email_code_login_token") + @patch( + "controllers.web.login.WebAppAuthService.get_user_through_email", + side_effect=Unauthorized("Account is banned."), + ) + @patch( + "controllers.web.login.WebAppAuthService.get_email_code_login_data", + return_value={"email": "User@Example.com", "code": "123456"}, + ) + def test_email_code_login_logs_banned_account( + self, + mock_get_token_data: MagicMock, + mock_get_user: MagicMock, + mock_revoke_token: MagicMock, + app: Flask, + ) -> None: + from controllers.console.error import AccountBannedError + + with patch("controllers.web.login.logger.warning") as mock_log_warning: + with app.test_request_context( + "/web/email-code-login/validity", + method="POST", + json={"email": "User@Example.com", "code": encode_code("123456"), "token": "token-123"}, + ): + with pytest.raises(AccountBannedError): + EmailCodeLoginApi().post() + + mock_get_token_data.assert_called_once_with("token-123") + mock_revoke_token.assert_called_once_with("token-123") + assert mock_log_warning.call_count == 1 + assert mock_log_warning.call_args.args[1] == "user@example.com" + assert mock_log_warning.call_args.args[2] == LoginFailureReason.ACCOUNT_BANNED class TestLoginStatusApi: diff --git a/api/tests/unit_tests/core/agent/test_cot_agent_runner.py b/api/tests/unit_tests/core/agent/test_cot_agent_runner.py index bc7aea0ef92..cde8820e006 100644 --- a/api/tests/unit_tests/core/agent/test_cot_agent_runner.py +++ b/api/tests/unit_tests/core/agent/test_cot_agent_runner.py @@ -2,11 +2,11 @@ import json from unittest.mock import MagicMock import pytest -from graphon.model_runtime.entities.llm_entities import LLMUsage from core.agent.cot_agent_runner import CotAgentRunner from core.agent.entities import AgentScratchpadUnit from core.agent.errors import AgentMaxIterationError +from graphon.model_runtime.entities.llm_entities import LLMUsage class DummyRunner(CotAgentRunner): diff --git a/api/tests/unit_tests/core/agent/test_cot_chat_agent_runner.py b/api/tests/unit_tests/core/agent/test_cot_chat_agent_runner.py index 97206019b9f..ea8cc8aa864 100644 --- a/api/tests/unit_tests/core/agent/test_cot_chat_agent_runner.py +++ b/api/tests/unit_tests/core/agent/test_cot_chat_agent_runner.py @@ -1,9 +1,9 @@ from unittest.mock import MagicMock, patch import pytest -from graphon.model_runtime.entities.message_entities import TextPromptMessageContent from core.agent.cot_chat_agent_runner import CotChatAgentRunner +from graphon.model_runtime.entities.message_entities import TextPromptMessageContent from tests.unit_tests.core.agent.conftest import ( DummyAgentConfig, DummyAppConfig, diff --git a/api/tests/unit_tests/core/agent/test_cot_completion_agent_runner.py b/api/tests/unit_tests/core/agent/test_cot_completion_agent_runner.py index defc8b4b642..2f5873d865f 100644 --- a/api/tests/unit_tests/core/agent/test_cot_completion_agent_runner.py +++ b/api/tests/unit_tests/core/agent/test_cot_completion_agent_runner.py @@ -1,6 +1,8 @@ import json import pytest + +from core.agent.cot_completion_agent_runner import CotCompletionAgentRunner from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, ImagePromptMessageContent, @@ -8,8 +10,6 @@ from graphon.model_runtime.entities.message_entities import ( UserPromptMessage, ) -from core.agent.cot_completion_agent_runner import CotCompletionAgentRunner - # ----------------------------- # Fixtures # ----------------------------- diff --git a/api/tests/unit_tests/core/agent/test_fc_agent_runner.py b/api/tests/unit_tests/core/agent/test_fc_agent_runner.py index a44a0650eb4..17ab5babcbd 100644 --- a/api/tests/unit_tests/core/agent/test_fc_agent_runner.py +++ b/api/tests/unit_tests/core/agent/test_fc_agent_runner.py @@ -3,6 +3,11 @@ from typing import Any from unittest.mock import MagicMock import pytest + +from core.agent.errors import AgentMaxIterationError +from core.agent.fc_agent_runner import FunctionCallAgentRunner +from core.app.apps.base_app_queue_manager import PublishFrom +from core.app.entities.queue_entities import QueueMessageFileEvent from graphon.model_runtime.entities.llm_entities import LLMUsage from graphon.model_runtime.entities.message_entities import ( DocumentPromptMessageContent, @@ -11,11 +16,6 @@ from graphon.model_runtime.entities.message_entities import ( UserPromptMessage, ) -from core.agent.errors import AgentMaxIterationError -from core.agent.fc_agent_runner import FunctionCallAgentRunner -from core.app.apps.base_app_queue_manager import PublishFrom -from core.app.entities.queue_entities import QueueMessageFileEvent - # ============================== # Dummy Helper Classes # ============================== diff --git a/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_model_config_converter.py b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_model_config_converter.py index 5ee66da94ab..186b4a501d3 100644 --- a/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_model_config_converter.py +++ b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_model_config_converter.py @@ -2,8 +2,6 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest -from graphon.model_runtime.entities.llm_entities import LLMMode -from graphon.model_runtime.entities.model_entities import ModelPropertyKey from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter from core.entities.model_entities import ModelStatus @@ -12,6 +10,8 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) +from graphon.model_runtime.entities.llm_entities import LLMMode +from graphon.model_runtime.entities.model_entities import ModelPropertyKey class TestModelConfigConverter: diff --git a/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_variables_manager.py b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_variables_manager.py index e2f3c16335f..d9fe7004ff7 100644 --- a/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_variables_manager.py +++ b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_variables_manager.py @@ -1,9 +1,9 @@ import pytest -from graphon.variables.input_entities import VariableEntityType from core.app.app_config.easy_ui_based_app.variables.manager import ( BasicVariablesConfigManager, ) +from graphon.variables.input_entities import VariableEntityType class TestBasicVariablesConfigManagerConvert: diff --git a/api/tests/unit_tests/core/app/app_config/features/file_upload/test_manager.py b/api/tests/unit_tests/core/app/app_config/features/file_upload/test_manager.py index 8bde9c1f979..11b53dd0f97 100644 --- a/api/tests/unit_tests/core/app/app_config/features/file_upload/test_manager.py +++ b/api/tests/unit_tests/core/app/app_config/features/file_upload/test_manager.py @@ -1,8 +1,7 @@ +from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from graphon.file import FileTransferMethod, FileUploadConfig, ImageConfig from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent -from core.app.app_config.features.file_upload.manager import FileUploadConfigManager - def test_convert_with_vision(): config = { diff --git a/api/tests/unit_tests/core/app/app_config/test_entities.py b/api/tests/unit_tests/core/app/app_config/test_entities.py index 000f83cd5a2..f2bc3076dac 100644 --- a/api/tests/unit_tests/core/app/app_config/test_entities.py +++ b/api/tests/unit_tests/core/app/app_config/test_entities.py @@ -1,10 +1,10 @@ import pytest -from graphon.variables.input_entities import VariableEntity, VariableEntityType from core.app.app_config.entities import ( DatasetRetrieveConfigEntity, PromptTemplateEntity, ) +from graphon.variables.input_entities import VariableEntity, VariableEntityType class TestAppConfigEntities: diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py index 061719d15a5..45d4b0e3212 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py @@ -3,12 +3,12 @@ from unittest.mock import MagicMock, patch from uuid import uuid4 -from graphon.variables import SegmentType from sqlalchemy.orm import Session from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom from factories import variable_factory +from graphon.variables import SegmentType from models import ConversationVariable, Workflow MINIMAL_GRAPH = { @@ -134,6 +134,7 @@ class TestAdvancedChatAppRunnerConversationVariables: # Patch the necessary components with ( + patch("core.app.apps.advanced_chat.app_runner.sessionmaker") as mock_sessionmaker, patch("core.app.apps.advanced_chat.app_runner.Session") as mock_session_class, patch("core.app.apps.advanced_chat.app_runner.select") as mock_select, patch("core.app.apps.advanced_chat.app_runner.db") as mock_db, @@ -150,7 +151,9 @@ class TestAdvancedChatAppRunnerConversationVariables: patch("core.app.apps.advanced_chat.app_runner.RedisChannel") as mock_redis_channel_class, ): # Setup mocks - mock_session_class.return_value.__enter__.return_value = mock_session + mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session + mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False) + mock_session_class.return_value.__enter__.return_value = MagicMock() mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock() # App exists mock_db.engine = MagicMock() @@ -177,7 +180,6 @@ class TestAdvancedChatAppRunnerConversationVariables: # Note: Since we're mocking ConversationVariable.from_variable, # we can't directly check the id, but we can verify add_all was called assert mock_session.add_all.called, "Session add_all should have been called" - assert mock_session.commit.called, "Session commit should have been called" def test_no_variables_creates_all(self): """Test that all conversation variables are created when none exist in DB.""" @@ -278,6 +280,7 @@ class TestAdvancedChatAppRunnerConversationVariables: # Patch the necessary components with ( + patch("core.app.apps.advanced_chat.app_runner.sessionmaker") as mock_sessionmaker, patch("core.app.apps.advanced_chat.app_runner.Session") as mock_session_class, patch("core.app.apps.advanced_chat.app_runner.select") as mock_select, patch("core.app.apps.advanced_chat.app_runner.db") as mock_db, @@ -295,7 +298,9 @@ class TestAdvancedChatAppRunnerConversationVariables: patch("core.app.apps.advanced_chat.app_runner.RedisChannel") as mock_redis_channel_class, ): # Setup mocks - mock_session_class.return_value.__enter__.return_value = mock_session + mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session + mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False) + mock_session_class.return_value.__enter__.return_value = MagicMock() mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock() # App exists mock_db.engine = MagicMock() @@ -326,7 +331,6 @@ class TestAdvancedChatAppRunnerConversationVariables: # Verify that all variables were created assert len(added_items) == 2, "Should have added both variables" assert mock_session.add_all.called, "Session add_all should have been called" - assert mock_session.commit.called, "Session commit should have been called" def test_all_variables_exist_no_changes(self): """Test that no changes are made when all variables already exist in DB.""" @@ -429,6 +433,7 @@ class TestAdvancedChatAppRunnerConversationVariables: # Patch the necessary components with ( + patch("core.app.apps.advanced_chat.app_runner.sessionmaker") as mock_sessionmaker, patch("core.app.apps.advanced_chat.app_runner.Session") as mock_session_class, patch("core.app.apps.advanced_chat.app_runner.select") as mock_select, patch("core.app.apps.advanced_chat.app_runner.db") as mock_db, @@ -445,7 +450,9 @@ class TestAdvancedChatAppRunnerConversationVariables: patch("core.app.apps.advanced_chat.app_runner.RedisChannel") as mock_redis_channel_class, ): # Setup mocks - mock_session_class.return_value.__enter__.return_value = mock_session + mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session + mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False) + mock_session_class.return_value.__enter__.return_value = MagicMock() mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock() # App exists mock_db.engine = MagicMock() @@ -465,4 +472,3 @@ class TestAdvancedChatAppRunnerConversationVariables: # Verify that no variables were added assert not mock_session.add_all.called, "Session add_all should not have been called" - assert mock_session.commit.called, "Session commit should still be called" diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_input_moderation.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_input_moderation.py index 079df0b4e61..5d8faee8971 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_input_moderation.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_input_moderation.py @@ -93,6 +93,16 @@ def _patch_common_run_deps(runner: AdvancedChatAppRunner): scalar=lambda *a, **k: MagicMock(), ), ), + sessionmaker=MagicMock( + return_value=MagicMock( + begin=MagicMock( + return_value=MagicMock( + __enter__=lambda s: MagicMock(scalars=MagicMock(return_value=MagicMock(all=lambda: []))), + __exit__=lambda *a, **k: False, + ), + ), + ), + ), select=MagicMock(), db=MagicMock(engine=MagicMock()), RedisChannel=MagicMock(), diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_response_converter.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_response_converter.py index e9fdeefee4d..f2df35d7d02 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_response_converter.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_response_converter.py @@ -1,7 +1,5 @@ from collections.abc import Generator -from graphon.enums import WorkflowNodeExecutionStatus - from core.app.apps.advanced_chat.generate_response_converter import AdvancedChatAppGenerateResponseConverter from core.app.entities.task_entities import ( ChatbotAppBlockingResponse, @@ -12,6 +10,7 @@ from core.app.entities.task_entities import ( NodeStartStreamResponse, PingStreamResponse, ) +from graphon.enums import WorkflowNodeExecutionStatus class TestAdvancedChatGenerateResponseConverter: diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline.py index a6d85989556..99a386cd45e 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline.py @@ -6,8 +6,6 @@ from types import SimpleNamespace from unittest import mock import pytest -from graphon.entities.pause_reason import HumanInputRequired -from graphon.enums import WorkflowExecutionStatus from core.app.apps.advanced_chat import generate_task_pipeline as pipeline_module from core.app.entities.app_invoke_entities import InvokeFrom @@ -19,6 +17,8 @@ from core.app.entities.queue_entities import ( QueueWorkflowSucceededEvent, ) from core.app.entities.task_entities import StreamEvent +from graphon.entities.pause_reason import HumanInputRequired +from graphon.enums import WorkflowExecutionStatus from models.enums import MessageStatus from models.execution_extra_content import HumanInputContent from models.model import AppMode, EndUser diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py index 82b2e51019f..29fd63c063a 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py @@ -4,8 +4,6 @@ from contextlib import contextmanager from types import SimpleNamespace import pytest -from graphon.enums import BuiltinNodeTypes -from graphon.runtime import GraphRuntimeState, VariablePool from core.app.app_config.entities import AppAdditionalFeatures, WorkflowUIBasedAppConfig from core.app.apps.advanced_chat.generate_task_pipeline import ( @@ -49,6 +47,8 @@ from core.app.entities.task_entities import ( ) from core.base.tts.app_generator_tts_publisher import AudioTrunk from core.workflow.system_variables import build_system_variables +from graphon.enums import BuiltinNodeTypes +from graphon.runtime import GraphRuntimeState, VariablePool from libs.datetime_utils import naive_utc_now from models.enums import MessageStatus from models.model import AppMode, EndUser diff --git a/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_generator.py b/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_generator.py index 7dc43581501..80f7f94b1ac 100644 --- a/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_generator.py +++ b/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_generator.py @@ -1,12 +1,12 @@ import contextlib import pytest -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from pydantic import ValidationError from core.app.apps.agent_chat.app_generator import AgentChatAppGenerator from core.app.apps.exc import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError class DummyAccount: diff --git a/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_runner.py b/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_runner.py index 08250bc3b6f..4567b354804 100644 --- a/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_runner.py +++ b/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_runner.py @@ -1,10 +1,10 @@ import pytest -from graphon.model_runtime.entities.llm_entities import LLMMode -from graphon.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey from core.agent.entities import AgentEntity from core.app.apps.agent_chat.app_runner import AgentChatAppRunner from core.moderation.base import ModerationError +from graphon.model_runtime.entities.llm_entities import LLMMode +from graphon.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey @pytest.fixture diff --git a/api/tests/unit_tests/core/app/apps/chat/test_app_generator_and_runner.py b/api/tests/unit_tests/core/app/apps/chat/test_app_generator_and_runner.py index 68bcffb0e83..8f3c41701b3 100644 --- a/api/tests/unit_tests/core/app/apps/chat/test_app_generator_and_runner.py +++ b/api/tests/unit_tests/core/app/apps/chat/test_app_generator_and_runner.py @@ -2,7 +2,6 @@ from types import SimpleNamespace from unittest.mock import Mock, patch import pytest -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from core.app.apps.chat.app_generator import ChatAppGenerator from core.app.apps.chat.app_runner import ChatAppRunner @@ -10,6 +9,7 @@ from core.app.apps.exc import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import QueueAnnotationReplyEvent from core.moderation.base import ModerationError +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from models.model import AppMode diff --git a/api/tests/unit_tests/core/app/apps/chat/test_base_app_runner_multimodal.py b/api/tests/unit_tests/core/app/apps/chat/test_base_app_runner_multimodal.py index f255d2c7df2..b3ea1a464f8 100644 --- a/api/tests/unit_tests/core/app/apps/chat/test_base_app_runner_multimodal.py +++ b/api/tests/unit_tests/core/app/apps/chat/test_base_app_runner_multimodal.py @@ -4,13 +4,13 @@ from unittest.mock import MagicMock, patch from uuid import uuid4 import pytest -from graphon.file import FileTransferMethod, FileType -from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent from core.app.apps.base_app_queue_manager import PublishFrom from core.app.apps.base_app_runner import AppRunner from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import QueueMessageFileEvent +from graphon.file import FileTransferMethod, FileType +from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent from models.enums import CreatorUserRole diff --git a/api/tests/unit_tests/core/app/apps/common/test_graph_runtime_state_support.py b/api/tests/unit_tests/core/app/apps/common/test_graph_runtime_state_support.py index 4a94a2b4f1b..201923e0e40 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_graph_runtime_state_support.py +++ b/api/tests/unit_tests/core/app/apps/common/test_graph_runtime_state_support.py @@ -1,11 +1,11 @@ from types import SimpleNamespace import pytest -from graphon.runtime import GraphRuntimeState, VariablePool from core.app.apps.common.graph_runtime_state_support import GraphRuntimeStateSupport from core.workflow.system_variables import build_system_variables from core.workflow.variable_pool_initializer import add_variables_to_pool +from graphon.runtime import GraphRuntimeState, VariablePool def _make_state(workflow_run_id: str | None) -> GraphRuntimeState: diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py index 328cd12f12e..dd6cd0e919d 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py +++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py @@ -1,10 +1,9 @@ from collections.abc import Mapping, Sequence +from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter from graphon.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType from graphon.variables.segments import ArrayFileSegment, FileSegment -from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter - class TestWorkflowResponseConverterFetchFilesFromVariableValue: """Test class for WorkflowResponseConverter._fetch_files_from_variable_value method""" @@ -12,8 +11,8 @@ class TestWorkflowResponseConverterFetchFilesFromVariableValue: def create_test_file(self, file_id: str = "test_file_1") -> File: """Create a test File object""" return File( - id=file_id, - type=FileType.DOCUMENT, + file_id=file_id, + file_type=FileType.DOCUMENT, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="related_123", filename=f"{file_id}.txt", diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_human_input.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_human_input.py index bc11bf41744..1bef6f69cdd 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_human_input.py +++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_human_input.py @@ -1,13 +1,12 @@ from datetime import UTC, datetime from types import SimpleNamespace -from graphon.entities import WorkflowStartReason -from graphon.runtime import GraphRuntimeState, VariablePool - from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import QueueHumanInputFormFilledEvent, QueueHumanInputFormTimeoutEvent from core.workflow.system_variables import build_system_variables +from graphon.entities import WorkflowStartReason +from graphon.runtime import GraphRuntimeState, VariablePool def _build_converter(): diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_resumption.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_resumption.py index c9e146ff126..936ac37e55a 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_resumption.py +++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_resumption.py @@ -1,11 +1,10 @@ from types import SimpleNamespace -from graphon.entities import WorkflowStartReason -from graphon.runtime import GraphRuntimeState, VariablePool - from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.system_variables import build_system_variables +from graphon.entities import WorkflowStartReason +from graphon.runtime import GraphRuntimeState, VariablePool def _build_converter() -> WorkflowResponseConverter: diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py index 0fde7565d24..b3c0eb74faa 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py +++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py @@ -10,8 +10,6 @@ from typing import Any from unittest.mock import Mock import pytest -from graphon.entities import WorkflowStartReason -from graphon.enums import BuiltinNodeTypes from core.app.app_config.entities import WorkflowUIBasedAppConfig from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter @@ -27,6 +25,8 @@ from core.app.entities.queue_entities import ( QueueNodeSucceededEvent, ) from core.workflow.system_variables import build_system_variables +from graphon.entities import WorkflowStartReason +from graphon.enums import BuiltinNodeTypes from libs.datetime_utils import naive_utc_now from models import Account from models.model import AppMode diff --git a/api/tests/unit_tests/core/app/apps/completion/test_app_runner.py b/api/tests/unit_tests/core/app/apps/completion/test_app_runner.py index 619d66085a4..aa2085177e1 100644 --- a/api/tests/unit_tests/core/app/apps/completion/test_app_runner.py +++ b/api/tests/unit_tests/core/app/apps/completion/test_app_runner.py @@ -2,11 +2,11 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest -from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent import core.app.apps.completion.app_runner as module from core.app.apps.completion.app_runner import CompletionAppRunner from core.moderation.base import ModerationError +from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent @pytest.fixture diff --git a/api/tests/unit_tests/core/app/apps/completion/test_completion_completion_app_generator.py b/api/tests/unit_tests/core/app/apps/completion/test_completion_completion_app_generator.py index 96af9fbdeef..f2e35f9900b 100644 --- a/api/tests/unit_tests/core/app/apps/completion/test_completion_completion_app_generator.py +++ b/api/tests/unit_tests/core/app/apps/completion/test_completion_completion_app_generator.py @@ -3,13 +3,13 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from pydantic import ValidationError import core.app.apps.completion.app_generator as module from core.app.apps.completion.app_generator import CompletionAppGenerator from core.app.apps.exc import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from services.errors.app import MoreLikeThisDisabledError from services.errors.message import MessageNotExistsError diff --git a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_generate_response_converter.py b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_generate_response_converter.py index 6cdcab29abe..cfe797aa764 100644 --- a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_generate_response_converter.py +++ b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_generate_response_converter.py @@ -1,7 +1,5 @@ from collections.abc import Generator -from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus - from core.app.apps.pipeline.generate_response_converter import WorkflowAppGenerateResponseConverter from core.app.entities.task_entities import ( AppStreamResponse, @@ -12,6 +10,7 @@ from core.app.entities.task_entities import ( WorkflowAppBlockingResponse, WorkflowAppStreamResponse, ) +from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus def test_convert_blocking_full_and_simple_response(): diff --git a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_queue_manager.py b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_queue_manager.py index 4fe82efcb33..9db83f5531e 100644 --- a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_queue_manager.py +++ b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_queue_manager.py @@ -1,5 +1,4 @@ import pytest -from graphon.model_runtime.entities.llm_entities import LLMResult import core.app.apps.pipeline.pipeline_queue_manager as module from core.app.apps.base_app_queue_manager import PublishFrom @@ -14,6 +13,7 @@ from core.app.entities.queue_entities import ( QueueWorkflowPartialSuccessEvent, QueueWorkflowSucceededEvent, ) +from graphon.model_runtime.entities.llm_entities import LLMResult def test_publish_sets_stop_listen_and_raises_on_stopped(mocker): diff --git a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_runner.py b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_runner.py index c8ae288e6fe..618c8fd76f3 100644 --- a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_runner.py +++ b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_runner.py @@ -22,11 +22,11 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest -from graphon.graph_events import GraphRunFailedEvent import core.app.apps.pipeline.pipeline_runner as module from core.app.apps.pipeline.pipeline_runner import PipelineRunner from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from graphon.graph_events import GraphRunFailedEvent def _build_app_generate_entity() -> SimpleNamespace: diff --git a/api/tests/unit_tests/core/app/apps/test_base_app_generator.py b/api/tests/unit_tests/core/app/apps/test_base_app_generator.py index 6167be3bbdb..b0f8b423e1e 100644 --- a/api/tests/unit_tests/core/app/apps/test_base_app_generator.py +++ b/api/tests/unit_tests/core/app/apps/test_base_app_generator.py @@ -1,7 +1,7 @@ import pytest -from graphon.variables.input_entities import VariableEntity, VariableEntityType from core.app.apps.base_app_generator import BaseAppGenerator +from graphon.variables.input_entities import VariableEntity, VariableEntityType def test_validate_inputs_with_zero(): @@ -476,9 +476,8 @@ class TestBaseAppGeneratorExtras: assert converted[1] == "event: ping\n\n" def test_get_draft_var_saver_factory_debugger(self): - from graphon.enums import BuiltinNodeTypes - from core.app.entities.app_invoke_entities import InvokeFrom + from graphon.enums import BuiltinNodeTypes from models import Account base_app_generator = BaseAppGenerator() diff --git a/api/tests/unit_tests/core/app/apps/test_base_app_runner.py b/api/tests/unit_tests/core/app/apps/test_base_app_runner.py index 1dee7fdab66..17de39ca99f 100644 --- a/api/tests/unit_tests/core/app/apps/test_base_app_runner.py +++ b/api/tests/unit_tests/core/app/apps/test_base_app_runner.py @@ -4,15 +4,6 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest -from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage -from graphon.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - ImagePromptMessageContent, - PromptMessageRole, - TextPromptMessageContent, -) -from graphon.model_runtime.entities.model_entities import ModelPropertyKey -from graphon.model_runtime.errors.invoke import InvokeBadRequestError from core.app.app_config.entities import ( AdvancedChatMessageEntity, @@ -23,6 +14,15 @@ from core.app.app_config.entities import ( from core.app.apps.base_app_runner import AppRunner from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import QueueAgentMessageEvent, QueueLLMChunkEvent, QueueMessageEndEvent +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage +from graphon.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + ImagePromptMessageContent, + PromptMessageRole, + TextPromptMessageContent, +) +from graphon.model_runtime.entities.model_entities import ModelPropertyKey +from graphon.model_runtime.errors.invoke import InvokeBadRequestError from models.model import AppMode diff --git a/api/tests/unit_tests/core/app/apps/test_pause_resume.py b/api/tests/unit_tests/core/app/apps/test_pause_resume.py index a126bc85f75..6104b8d6ca2 100644 --- a/api/tests/unit_tests/core/app/apps/test_pause_resume.py +++ b/api/tests/unit_tests/core/app/apps/test_pause_resume.py @@ -4,9 +4,14 @@ from types import ModuleType, SimpleNamespace from typing import Any import graphon.nodes.human_input.entities # noqa: F401 +from core.app.apps.advanced_chat import app_generator as adv_app_gen_module +from core.app.apps.workflow import app_generator as wf_app_gen_module +from core.app.entities.app_invoke_entities import InvokeFrom +from core.workflow import node_factory as node_factory_module +from core.workflow.node_factory import DifyNodeFactory +from core.workflow.system_variables import build_system_variables from graphon.entities import WorkflowStartReason from graphon.entities.base_node_data import BaseNodeData, RetryConfig -from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter from graphon.entities.pause_reason import SchedulingPause from graphon.enums import BuiltinNodeTypes, NodeType, WorkflowNodeExecutionStatus from graphon.graph import Graph @@ -25,12 +30,6 @@ from graphon.nodes.base.node import Node from graphon.nodes.end.entities import EndNodeData from graphon.nodes.start.entities import StartNodeData from graphon.runtime import GraphRuntimeState, VariablePool - -from core.app.apps.advanced_chat import app_generator as adv_app_gen_module -from core.app.apps.workflow import app_generator as wf_app_gen_module -from core.app.entities.app_invoke_entities import InvokeFrom -from core.workflow.node_factory import DifyNodeFactory -from core.workflow.system_variables import build_system_variables from tests.workflow_test_utils import build_test_graph_init_params if "core.ops.ops_trace_manager" not in sys.modules: @@ -56,8 +55,21 @@ class _StubToolNode(Node[_StubToolNodeData]): def version(cls) -> str: return "1" - def init_node_data(self, data): - self._node_data = _StubToolNodeData.model_validate(data) + def __init__( + self, + node_id: str, + config: _StubToolNodeData, + *, + graph_init_params, + graph_runtime_state, + **_kwargs: Any, + ) -> None: + super().__init__( + node_id=node_id, + config=config, + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + ) def _get_error_strategy(self): return self._node_data.error_strategy @@ -90,21 +102,14 @@ class _StubToolNode(Node[_StubToolNodeData]): def _patch_tool_node(mocker): - original_create_node = DifyNodeFactory.create_node + original_resolve_node_class = node_factory_module.resolve_workflow_node_class - def _patched_create_node(self, node_config: dict[str, object] | NodeConfigDict) -> Node: - typed_node_config = NodeConfigDictAdapter.validate_python(node_config) - node_data = typed_node_config["data"] - if node_data.type == BuiltinNodeTypes.TOOL: - return _StubToolNode( - id=str(typed_node_config["id"]), - config=typed_node_config, - graph_init_params=self.graph_init_params, - graph_runtime_state=self.graph_runtime_state, - ) - return original_create_node(self, typed_node_config) + def _patched_resolve_node_class(*, node_type: NodeType, node_version: str) -> type[Node]: + if node_type == BuiltinNodeTypes.TOOL: + return _StubToolNode + return original_resolve_node_class(node_type=node_type, node_version=node_version) - mocker.patch.object(DifyNodeFactory, "create_node", _patched_create_node) + mocker.patch.object(node_factory_module, "resolve_workflow_node_class", side_effect=_patched_resolve_node_class) def _node_data(node_type: NodeType, data: BaseNodeData) -> dict[str, object]: diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_core.py b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_core.py index de5bca161c7..58c7bfa4bc2 100644 --- a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_core.py +++ b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_core.py @@ -4,6 +4,23 @@ from datetime import UTC, datetime from types import SimpleNamespace import pytest + +from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.app.entities.queue_entities import ( + QueueAgentLogEvent, + QueueIterationCompletedEvent, + QueueLoopCompletedEvent, + QueueNodeExceptionEvent, + QueueNodeFailedEvent, + QueueNodeRetryEvent, + QueueNodeSucceededEvent, + QueueTextChunkEvent, + QueueWorkflowPausedEvent, + QueueWorkflowStartedEvent, + QueueWorkflowSucceededEvent, +) +from core.workflow.system_variables import default_system_variables from graphon.entities.pause_reason import HumanInputRequired from graphon.enums import BuiltinNodeTypes from graphon.graph_events import ( @@ -24,23 +41,6 @@ from graphon.node_events import NodeRunResult from graphon.runtime import GraphRuntimeState, VariablePool from graphon.variables.variables import StringVariable -from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom -from core.app.entities.queue_entities import ( - QueueAgentLogEvent, - QueueIterationCompletedEvent, - QueueLoopCompletedEvent, - QueueNodeExceptionEvent, - QueueNodeFailedEvent, - QueueNodeRetryEvent, - QueueNodeSucceededEvent, - QueueTextChunkEvent, - QueueWorkflowPausedEvent, - QueueWorkflowStartedEvent, - QueueWorkflowSucceededEvent, -) -from core.workflow.system_variables import default_system_variables - class TestWorkflowBasedAppRunner: def test_resolve_user_from(self): diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_notifications.py b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_notifications.py index aa789d9ff3a..10fb2271f49 100644 --- a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_notifications.py +++ b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_notifications.py @@ -1,11 +1,11 @@ from unittest.mock import MagicMock import pytest -from graphon.entities.pause_reason import HumanInputRequired -from graphon.graph_events import GraphRunPausedEvent from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner from core.app.entities.queue_entities import QueueWorkflowPausedEvent +from graphon.entities.pause_reason import HumanInputRequired +from graphon.graph_events import GraphRunPausedEvent class _DummyQueueManager: diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py index 9e30faecf23..620a153204d 100644 --- a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py +++ b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py @@ -4,14 +4,14 @@ from typing import Any from unittest.mock import MagicMock, patch import pytest -from graphon.entities.graph_config import NodeConfigDictAdapter -from graphon.runtime import GraphRuntimeState, VariablePool from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.workflow.app_runner import WorkflowAppRunner from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.workflow.system_variables import default_system_variables +from graphon.entities.graph_config import NodeConfigDictAdapter +from graphon.runtime import GraphRuntimeState, VariablePool from models.workflow import Workflow diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py b/api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py index 8a717e1dccd..a3ab379b66d 100644 --- a/api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py +++ b/api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py @@ -3,11 +3,6 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest -from graphon.entities import WorkflowStartReason -from graphon.entities.pause_reason import HumanInputRequired -from graphon.graph_events import GraphRunPausedEvent -from graphon.nodes.human_input.entities import FormInput, UserAction -from graphon.nodes.human_input.enums import FormInputType from core.app.apps.common import workflow_response_converter from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter @@ -16,6 +11,11 @@ from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import QueueWorkflowPausedEvent from core.app.entities.task_entities import HumanInputRequiredResponse, WorkflowPauseStreamResponse from core.workflow.system_variables import build_system_variables +from graphon.entities import WorkflowStartReason +from graphon.entities.pause_reason import HumanInputRequired +from graphon.graph_events import GraphRunPausedEvent +from graphon.nodes.human_input.entities import FormInput, UserAction +from graphon.nodes.human_input.enums import FormInputType from models.account import Account from models.human_input import RecipientType diff --git a/api/tests/unit_tests/core/app/apps/workflow/test_generate_response_converter.py b/api/tests/unit_tests/core/app/apps/workflow/test_generate_response_converter.py index b768e813bd7..7dd7ffd7277 100644 --- a/api/tests/unit_tests/core/app/apps/workflow/test_generate_response_converter.py +++ b/api/tests/unit_tests/core/app/apps/workflow/test_generate_response_converter.py @@ -1,7 +1,5 @@ from collections.abc import Generator -from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus - from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter from core.app.entities.task_entities import ( ErrorStreamResponse, @@ -11,6 +9,7 @@ from core.app.entities.task_entities import ( WorkflowAppBlockingResponse, WorkflowAppStreamResponse, ) +from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus class TestWorkflowGenerateResponseConverter: diff --git a/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline.py b/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline.py index 29df903aa8b..1f6e7e12ef1 100644 --- a/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline.py +++ b/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline.py @@ -2,15 +2,14 @@ import time from contextlib import contextmanager from unittest.mock import MagicMock -from graphon.entities import WorkflowStartReason -from graphon.runtime import GraphRuntimeState - from core.app.app_config.entities import WorkflowUIBasedAppConfig from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.app.entities.queue_entities import QueueWorkflowStartedEvent from core.workflow.system_variables import build_system_variables +from graphon.entities import WorkflowStartReason +from graphon.runtime import GraphRuntimeState from models.account import Account from models.model import AppMode from tests.workflow_test_utils import build_test_variable_pool diff --git a/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py b/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py index dabd2594b43..99433478d3e 100644 --- a/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py +++ b/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py @@ -2,10 +2,9 @@ from __future__ import annotations from contextlib import contextmanager from types import SimpleNamespace +from unittest.mock import MagicMock import pytest -from graphon.enums import BuiltinNodeTypes, WorkflowExecutionStatus -from graphon.runtime import GraphRuntimeState, VariablePool from core.app.app_config.entities import AppAdditionalFeatures, WorkflowUIBasedAppConfig from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline @@ -46,6 +45,8 @@ from core.app.entities.task_entities import ( ) from core.base.tts.app_generator_tts_publisher import AudioTrunk from core.workflow.system_variables import build_system_variables, system_variables_to_mapping +from graphon.enums import BuiltinNodeTypes, WorkflowExecutionStatus +from graphon.runtime import GraphRuntimeState, VariablePool from libs.datetime_utils import naive_utc_now from models.enums import CreatorUserRole from models.model import AppMode, EndUser @@ -610,33 +611,33 @@ class TestWorkflowGenerateTaskPipeline: def test_database_session_rolls_back_on_error(self, monkeypatch): pipeline = _make_pipeline() - calls = {"commit": 0, "rollback": 0} - - class _Session: - def __init__(self, *args, **kwargs): - _ = args, kwargs + calls = {"enter": 0, "exit_exc": None} + class _BeginContext: def __enter__(self): - return self + calls["enter"] += 1 + return MagicMock() def __exit__(self, exc_type, exc, tb): + calls["exit_exc"] = exc_type return False - def commit(self): - calls["commit"] += 1 + class _Sessionmaker: + def __init__(self, *args, **kwargs): + pass - def rollback(self): - calls["rollback"] += 1 + def begin(self): + return _BeginContext() - monkeypatch.setattr("core.app.apps.workflow.generate_task_pipeline.Session", _Session) + monkeypatch.setattr("core.app.apps.workflow.generate_task_pipeline.sessionmaker", _Sessionmaker) monkeypatch.setattr("core.app.apps.workflow.generate_task_pipeline.db", SimpleNamespace(engine=object())) with pytest.raises(RuntimeError, match="db error"): with pipeline._database_session(): raise RuntimeError("db error") - assert calls["commit"] == 0 - assert calls["rollback"] == 1 + assert calls["enter"] == 1 + assert calls["exit_exc"] is RuntimeError def test_node_retry_and_started_handlers_cover_none_and_value(self): pipeline = _make_pipeline() diff --git a/api/tests/unit_tests/core/app/entities/test_task_entities.py b/api/tests/unit_tests/core/app/entities/test_task_entities.py index 014a0cba729..7c797806411 100644 --- a/api/tests/unit_tests/core/app/entities/test_task_entities.py +++ b/api/tests/unit_tests/core/app/entities/test_task_entities.py @@ -1,11 +1,10 @@ -from graphon.enums import WorkflowNodeExecutionStatus - from core.app.entities.task_entities import ( NodeFinishStreamResponse, NodeRetryStreamResponse, NodeStartStreamResponse, StreamEvent, ) +from graphon.enums import WorkflowNodeExecutionStatus class TestTaskEntities: diff --git a/api/tests/unit_tests/core/app/layers/test_conversation_variable_persist_layer.py b/api/tests/unit_tests/core/app/layers/test_conversation_variable_persist_layer.py index a78c1b428fb..ba55e8f6950 100644 --- a/api/tests/unit_tests/core/app/layers/test_conversation_variable_persist_layer.py +++ b/api/tests/unit_tests/core/app/layers/test_conversation_variable_persist_layer.py @@ -1,6 +1,9 @@ from collections.abc import Sequence from unittest.mock import Mock +from core.app.layers.conversation_variable_persist_layer import ConversationVariablePersistenceLayer +from core.workflow.system_variables import SystemVariableKey +from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus from graphon.graph_engine.command_channels import CommandChannel from graphon.graph_events import NodeRunSucceededEvent, NodeRunVariableUpdatedEvent @@ -8,10 +11,6 @@ from graphon.node_events import NodeRunResult from graphon.runtime import ReadOnlyGraphRuntimeState from graphon.variables import StringVariable from graphon.variables.segments import Segment, StringSegment - -from core.app.layers.conversation_variable_persist_layer import ConversationVariablePersistenceLayer -from core.workflow.system_variables import SystemVariableKey -from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID from libs.datetime_utils import naive_utc_now diff --git a/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py b/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py index 035e64325bb..539944d683d 100644 --- a/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py +++ b/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py @@ -4,6 +4,16 @@ from time import time from unittest.mock import Mock import pytest + +from core.app.app_config.entities import WorkflowUIBasedAppConfig +from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity +from core.app.layers.pause_state_persist_layer import ( + PauseStatePersistenceLayer, + WorkflowResumptionContext, + _AdvancedChatAppGenerateEntityWrapper, + _WorkflowGenerateEntityWrapper, +) +from core.workflow.system_variables import SystemVariableKey from graphon.entities.pause_reason import SchedulingPause from graphon.graph_engine.entities.commands import GraphEngineCommand from graphon.graph_engine.layers.base import GraphEngineLayerNotInitializedError @@ -15,16 +25,6 @@ from graphon.graph_events import ( ) from graphon.runtime import ReadOnlyVariablePool from graphon.variables.segments import Segment - -from core.app.app_config.entities import WorkflowUIBasedAppConfig -from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity -from core.app.layers.pause_state_persist_layer import ( - PauseStatePersistenceLayer, - WorkflowResumptionContext, - _AdvancedChatAppGenerateEntityWrapper, - _WorkflowGenerateEntityWrapper, -) -from core.workflow.system_variables import SystemVariableKey from models.model import AppMode from repositories.factory import DifyAPIRepositoryFactory diff --git a/api/tests/unit_tests/core/app/layers/test_suspend_layer.py b/api/tests/unit_tests/core/app/layers/test_suspend_layer.py index 95931f4f8ba..12d49be0f1a 100644 --- a/api/tests/unit_tests/core/app/layers/test_suspend_layer.py +++ b/api/tests/unit_tests/core/app/layers/test_suspend_layer.py @@ -1,6 +1,5 @@ -from graphon.graph_events import GraphRunPausedEvent - from core.app.layers.suspend_layer import SuspendLayer +from graphon.graph_events import GraphRunPausedEvent class TestSuspendLayer: diff --git a/api/tests/unit_tests/core/app/layers/test_timeslice_layer.py b/api/tests/unit_tests/core/app/layers/test_timeslice_layer.py index 7cf6eb4f310..1ac9a4d8c0c 100644 --- a/api/tests/unit_tests/core/app/layers/test_timeslice_layer.py +++ b/api/tests/unit_tests/core/app/layers/test_timeslice_layer.py @@ -1,8 +1,7 @@ from unittest.mock import Mock, patch -from graphon.graph_engine.entities.commands import CommandType, GraphEngineCommand - from core.app.layers.timeslice_layer import TimeSliceLayer +from graphon.graph_engine.entities.commands import CommandType, GraphEngineCommand from services.workflow.entities import WorkflowScheduleCFSPlanEntity from services.workflow.scheduler import SchedulerCommand diff --git a/api/tests/unit_tests/core/app/layers/test_trigger_post_layer.py b/api/tests/unit_tests/core/app/layers/test_trigger_post_layer.py index aa9285789b3..d3bd15b6f34 100644 --- a/api/tests/unit_tests/core/app/layers/test_trigger_post_layer.py +++ b/api/tests/unit_tests/core/app/layers/test_trigger_post_layer.py @@ -2,11 +2,10 @@ from datetime import UTC, datetime, timedelta from types import SimpleNamespace from unittest.mock import Mock, patch -from graphon.graph_events import GraphRunFailedEvent, GraphRunSucceededEvent -from graphon.runtime import VariablePool - from core.app.layers.trigger_post_layer import TriggerPostLayer from core.workflow.system_variables import build_system_variables +from graphon.graph_events import GraphRunFailedEvent, GraphRunSucceededEvent +from graphon.runtime import VariablePool from models.enums import WorkflowTriggerStatus diff --git a/api/tests/unit_tests/core/app/task_pipeline/test_based_generate_task_pipeline.py b/api/tests/unit_tests/core/app/task_pipeline/test_based_generate_task_pipeline.py index 58aa7d74782..c246f7b7836 100644 --- a/api/tests/unit_tests/core/app/task_pipeline/test_based_generate_task_pipeline.py +++ b/api/tests/unit_tests/core/app/task_pipeline/test_based_generate_task_pipeline.py @@ -2,11 +2,11 @@ from types import SimpleNamespace from unittest.mock import Mock import pytest -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.app.entities.queue_entities import QueueErrorEvent from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline from core.errors.error import QuotaExceededError +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from models.enums import MessageStatus diff --git a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline.py b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline.py index 4aaa10a81af..1c1bf391d3e 100644 --- a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline.py +++ b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline.py @@ -2,8 +2,6 @@ from types import SimpleNamespace from unittest.mock import ANY, Mock, patch import pytest -from graphon.model_runtime.entities.llm_entities import LLMResult as RuntimeLLMResult -from graphon.model_runtime.entities.message_entities import TextPromptMessageContent from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.entities.app_invoke_entities import ChatAppGenerateEntity @@ -28,6 +26,8 @@ from core.app.entities.task_entities import ( from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline from core.base.tts import AppGeneratorTTSPublisher from core.ops.ops_trace_manager import TraceQueueManager +from graphon.model_runtime.entities.llm_entities import LLMResult as RuntimeLLMResult +from graphon.model_runtime.entities.message_entities import TextPromptMessageContent from models.model import AppMode diff --git a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline_core.py b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline_core.py index f7e7b7e20ef..a20d89d8077 100644 --- a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline_core.py +++ b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline_core.py @@ -5,9 +5,6 @@ from types import SimpleNamespace from unittest.mock import Mock import pytest -from graphon.file import FileTransferMethod -from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage -from graphon.model_runtime.entities.message_entities import AssistantPromptMessage, TextPromptMessageContent from core.app.app_config.entities import ( AppAdditionalFeatures, @@ -41,6 +38,9 @@ from core.app.entities.task_entities import ( ) from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline from core.base.tts import AudioTrunk +from graphon.file import FileTransferMethod +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage +from graphon.model_runtime.entities.message_entities import AssistantPromptMessage, TextPromptMessageContent from models.model import AppMode @@ -505,13 +505,7 @@ class TestEasyUiBasedGenerateTaskPipeline: def __exit__(self, exc_type, exc, tb): return False - def query(self, *args, **kwargs): - return self - - def where(self, *args, **kwargs): - return self - - def first(self): + def scalar(self, *args, **kwargs): return agent_thought monkeypatch.setattr( @@ -1182,13 +1176,7 @@ class TestEasyUiBasedGenerateTaskPipeline: def __exit__(self, exc_type, exc, tb): return False - def query(self, *args, **kwargs): - return self - - def where(self, *args, **kwargs): - return self - - def first(self): + def scalar(self, *args, **kwargs): return None monkeypatch.setattr("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session", _Session) diff --git a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_message_end_files.py b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_message_end_files.py index 31b73130666..595d716666c 100644 --- a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_message_end_files.py +++ b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_message_end_files.py @@ -17,11 +17,11 @@ import uuid from unittest.mock import MagicMock, Mock, patch import pytest -from graphon.file import FileTransferMethod, FileType from sqlalchemy.orm import Session from core.app.entities.task_entities import MessageEndStreamResponse from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline +from graphon.file import FileTransferMethod, FileType from models.model import MessageFile, UploadFile diff --git a/api/tests/unit_tests/core/app/task_pipeline/test_message_cycle_manager_optimization.py b/api/tests/unit_tests/core/app/task_pipeline/test_message_cycle_manager_optimization.py index 07ee75ed35f..92fe3cbec67 100644 --- a/api/tests/unit_tests/core/app/task_pipeline/test_message_cycle_manager_optimization.py +++ b/api/tests/unit_tests/core/app/task_pipeline/test_message_cycle_manager_optimization.py @@ -9,7 +9,7 @@ from flask import Flask, current_app from core.app.entities.queue_entities import QueueAnnotationReplyEvent, QueueRetrieverResourcesEvent from core.app.entities.task_entities import MessageStreamResponse, StreamEvent, TaskStateMetadata from core.app.task_pipeline.message_cycle_manager import MessageCycleManager -from core.rag.entities.citation_metadata import RetrievalSourceMetadata +from core.rag.entities import RetrievalSourceMetadata from models.model import AppMode diff --git a/api/tests/unit_tests/core/app/test_easy_ui_model_config_manager.py b/api/tests/unit_tests/core/app/test_easy_ui_model_config_manager.py index 29df7eea863..21c761c579b 100644 --- a/api/tests/unit_tests/core/app/test_easy_ui_model_config_manager.py +++ b/api/tests/unit_tests/core/app/test_easy_ui_model_config_manager.py @@ -1,10 +1,9 @@ from types import SimpleNamespace from unittest.mock import patch -from graphon.model_runtime.entities.model_entities import ModelPropertyKey - from core.app.app_config.easy_ui_based_app.model_config.manager import ModelConfigManager from core.app.app_config.entities import ModelConfigEntity +from graphon.model_runtime.entities.model_entities import ModelPropertyKey from models.provider_ids import ModelProviderID diff --git a/api/tests/unit_tests/core/app/workflow/layers/test_persistence.py b/api/tests/unit_tests/core/app/workflow/layers/test_persistence.py index dc2d82ccd6c..5c50cb78dae 100644 --- a/api/tests/unit_tests/core/app/workflow/layers/test_persistence.py +++ b/api/tests/unit_tests/core/app/workflow/layers/test_persistence.py @@ -2,14 +2,14 @@ from datetime import UTC, datetime from unittest.mock import Mock import pytest -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus, WorkflowType -from graphon.node_events import NodeRunResult from core.app.workflow.layers.persistence import ( PersistenceWorkflowInfo, WorkflowPersistenceLayer, _NodeRuntimeSnapshot, ) +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus, WorkflowType +from graphon.node_events import NodeRunResult def _build_layer() -> WorkflowPersistenceLayer: diff --git a/api/tests/unit_tests/core/app/workflow/test_file_runtime.py b/api/tests/unit_tests/core/app/workflow/test_file_runtime.py index 7be9d6ac1eb..701863b9272 100644 --- a/api/tests/unit_tests/core/app/workflow/test_file_runtime.py +++ b/api/tests/unit_tests/core/app/workflow/test_file_runtime.py @@ -8,13 +8,13 @@ from unittest.mock import MagicMock, patch from urllib.parse import parse_qs, urlparse import pytest -from graphon.file import File, FileTransferMethod, FileType from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.app.file_access import DatabaseFileAccessController, FileAccessScope from core.app.workflow import file_runtime from core.app.workflow.file_runtime import DifyWorkflowFileRuntime, bind_dify_workflow_file_runtime from core.workflow.file_reference import build_file_reference +from graphon.file import File, FileTransferMethod, FileType from models import ToolFile, UploadFile @@ -26,8 +26,8 @@ def _build_file( extension: str | None = None, ) -> File: return File( - id="file-id", - type=FileType.IMAGE, + file_id="file-id", + file_type=FileType.IMAGE, transfer_method=transfer_method, reference=reference, remote_url=remote_url, @@ -351,7 +351,7 @@ def test_runtime_helper_wrappers_delegate_to_config_and_io(monkeypatch: pytest.M assert runtime.multimodal_send_format == "url" - with patch.object(file_runtime.ssrf_proxy, "get", return_value="response") as mock_get: + with patch.object(file_runtime.graphon_ssrf_proxy, "get", return_value="response") as mock_get: assert runtime.http_get("http://example", follow_redirects=False) == "response" mock_get.assert_called_once_with("http://example", follow_redirects=False) diff --git a/api/tests/unit_tests/core/app/workflow/test_node_factory.py b/api/tests/unit_tests/core/app/workflow/test_node_factory.py index 8497261d453..30a068f4c5a 100644 --- a/api/tests/unit_tests/core/app/workflow/test_node_factory.py +++ b/api/tests/unit_tests/core/app/workflow/test_node_factory.py @@ -1,15 +1,15 @@ from types import SimpleNamespace import pytest -from graphon.enums import BuiltinNodeTypes from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context from core.workflow.node_factory import DifyNodeFactory +from graphon.enums import BuiltinNodeTypes class DummyNode: - def __init__(self, *, id, config, graph_init_params, graph_runtime_state, **kwargs): - self.id = id + def __init__(self, *, node_id, config, graph_init_params, graph_runtime_state, **kwargs): + self.id = node_id self.config = config self.graph_init_params = graph_init_params self.graph_runtime_state = graph_runtime_state diff --git a/api/tests/unit_tests/core/app/workflow/test_observability_layer_extra.py b/api/tests/unit_tests/core/app/workflow/test_observability_layer_extra.py index a47d3db6f5b..82552470a95 100644 --- a/api/tests/unit_tests/core/app/workflow/test_observability_layer_extra.py +++ b/api/tests/unit_tests/core/app/workflow/test_observability_layer_extra.py @@ -2,9 +2,8 @@ from __future__ import annotations from types import SimpleNamespace -from graphon.enums import BuiltinNodeTypes - from core.app.workflow.layers.observability import ObservabilityLayer +from graphon.enums import BuiltinNodeTypes class TestObservabilityLayerExtras: diff --git a/api/tests/unit_tests/core/app/workflow/test_persistence_layer.py b/api/tests/unit_tests/core/app/workflow/test_persistence_layer.py index d8a68f6d000..cacb4dd4fa9 100644 --- a/api/tests/unit_tests/core/app/workflow/test_persistence_layer.py +++ b/api/tests/unit_tests/core/app/workflow/test_persistence_layer.py @@ -4,6 +4,10 @@ from datetime import UTC, datetime from types import SimpleNamespace import pytest + +from core.app.entities.app_invoke_entities import WorkflowAppGenerateEntity +from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer +from core.workflow.system_variables import SystemVariableKey, build_system_variables from graphon.entities import WorkflowNodeExecution from graphon.entities.pause_reason import SchedulingPause from graphon.enums import ( @@ -29,10 +33,6 @@ from graphon.graph_events import ( from graphon.node_events import NodeRunResult from graphon.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper, VariablePool -from core.app.entities.app_invoke_entities import WorkflowAppGenerateEntity -from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer -from core.workflow.system_variables import SystemVariableKey, build_system_variables - class _RepoRecorder: def __init__(self) -> None: diff --git a/api/tests/unit_tests/core/base/test_app_generator_tts_publisher.py b/api/tests/unit_tests/core/base/test_app_generator_tts_publisher.py index 5ff9774b525..7b433ab57b9 100644 --- a/api/tests/unit_tests/core/base/test_app_generator_tts_publisher.py +++ b/api/tests/unit_tests/core/base/test_app_generator_tts_publisher.py @@ -301,6 +301,7 @@ class TestAppGeneratorTTSPublisher: publisher = AppGeneratorTTSPublisher("tenant", "voice1") publisher.executor = MagicMock() + from core.app.entities.queue_entities import QueueAgentMessageEvent from graphon.model_runtime.entities.llm_entities import LLMResultChunk, LLMResultChunkDelta from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, @@ -308,8 +309,6 @@ class TestAppGeneratorTTSPublisher: TextPromptMessageContent, ) - from core.app.entities.queue_entities import QueueAgentMessageEvent - chunk = LLMResultChunk( model="model", delta=LLMResultChunkDelta( @@ -337,11 +336,10 @@ class TestAppGeneratorTTSPublisher: publisher = AppGeneratorTTSPublisher("tenant", "voice1") publisher.executor = MagicMock() + from core.app.entities.queue_entities import QueueAgentMessageEvent from graphon.model_runtime.entities.llm_entities import LLMResultChunk, LLMResultChunkDelta from graphon.model_runtime.entities.message_entities import AssistantPromptMessage - from core.app.entities.queue_entities import QueueAgentMessageEvent - chunk = LLMResultChunk( model="model", delta=LLMResultChunkDelta( diff --git a/api/tests/unit_tests/core/datasource/test_datasource_manager.py b/api/tests/unit_tests/core/datasource/test_datasource_manager.py index b0c72ee42f5..deeac49bbc2 100644 --- a/api/tests/unit_tests/core/datasource/test_datasource_manager.py +++ b/api/tests/unit_tests/core/datasource/test_datasource_manager.py @@ -2,15 +2,15 @@ import types from collections.abc import Generator import pytest -from graphon.enums import WorkflowNodeExecutionStatus -from graphon.file import File, FileTransferMethod, FileType -from graphon.node_events import StreamChunkEvent, StreamCompletedEvent from contexts.wrapper import RecyclableContextVar from core.datasource.datasource_manager import DatasourceManager from core.datasource.entities.datasource_entities import DatasourceMessage, DatasourceProviderType from core.datasource.errors import DatasourceProviderNotFoundError from core.workflow.file_reference import parse_file_reference +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.file import File, FileTransferMethod, FileType +from graphon.node_events import StreamChunkEvent, StreamCompletedEvent def _gen_messages_text_only(text: str) -> Generator[DatasourceMessage, None, None]: @@ -430,7 +430,7 @@ def test_stream_node_events_builds_file_and_variables_from_messages(mocker): mocker.patch("core.datasource.datasource_manager.session_factory.create_session", return_value=_Session()) mocker.patch("core.datasource.datasource_manager.get_file_type_by_mime_type", return_value=FileType.IMAGE) built = File( - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.TOOL_FILE, related_id="tool_file_1", extension=".png", @@ -530,7 +530,7 @@ def test_stream_node_events_online_drive_sets_variable_pool_file_and_outputs(moc mocker.patch.object(DatasourceManager, "stream_online_results", return_value=_gen_messages_text_only("ignored")) file_in = File( - type=FileType.DOCUMENT, + file_type=FileType.DOCUMENT, transfer_method=FileTransferMethod.TOOL_FILE, related_id="tf", extension=".pdf", @@ -632,16 +632,6 @@ def test_get_upload_file_by_id_builds_file(mocker): source_url="http://x", ) - class _Q: - def __init__(self, row): - self._row = row - - def where(self, *_args, **_kwargs): - return self - - def first(self): - return self._row - class _S: def __init__(self, row): self._row = row @@ -652,8 +642,8 @@ def test_get_upload_file_by_id_builds_file(mocker): def __exit__(self, *exc): return False - def query(self, *_): - return _Q(self._row) + def scalar(self, *_args, **_kwargs): + return self._row mocker.patch("core.datasource.datasource_manager.session_factory.create_session", return_value=_S(fake_row)) @@ -665,13 +655,6 @@ def test_get_upload_file_by_id_builds_file(mocker): def test_get_upload_file_by_id_raises_when_missing(mocker): - class _Q: - def where(self, *_args, **_kwargs): - return self - - def first(self): - return None - class _S: def __enter__(self): return self @@ -679,8 +662,8 @@ def test_get_upload_file_by_id_raises_when_missing(mocker): def __exit__(self, *exc): return False - def query(self, *_): - return _Q() + def scalar(self, *_args, **_kwargs): + return None mocker.patch("core.datasource.datasource_manager.session_factory.create_session", return_value=_S()) diff --git a/api/tests/unit_tests/core/datasource/utils/test_message_transformer.py b/api/tests/unit_tests/core/datasource/utils/test_message_transformer.py index fbaf6d497d7..0fca43cd0b1 100644 --- a/api/tests/unit_tests/core/datasource/utils/test_message_transformer.py +++ b/api/tests/unit_tests/core/datasource/utils/test_message_transformer.py @@ -1,10 +1,10 @@ from unittest.mock import MagicMock, patch import pytest -from graphon.file import File, FileTransferMethod, FileType from core.datasource.entities.datasource_entities import DatasourceMessage from core.datasource.utils.message_transformer import DatasourceFileMessageTransformer +from graphon.file import File, FileTransferMethod, FileType from models.tools import ToolFile diff --git a/api/tests/unit_tests/core/entities/test_entities_execution_extra_content.py b/api/tests/unit_tests/core/entities/test_entities_execution_extra_content.py index ff9fd0d8f3f..ef8f360dbfe 100644 --- a/api/tests/unit_tests/core/entities/test_entities_execution_extra_content.py +++ b/api/tests/unit_tests/core/entities/test_entities_execution_extra_content.py @@ -1,12 +1,11 @@ -from graphon.nodes.human_input.entities import FormInput, UserAction -from graphon.nodes.human_input.enums import FormInputType - from core.entities.execution_extra_content import ( ExecutionExtraContentDomainModel, HumanInputContent, HumanInputFormDefinition, HumanInputFormSubmissionData, ) +from graphon.nodes.human_input.entities import FormInput, UserAction +from graphon.nodes.human_input.enums import FormInputType from models.execution_extra_content import ExecutionContentType diff --git a/api/tests/unit_tests/core/entities/test_entities_model_entities.py b/api/tests/unit_tests/core/entities/test_entities_model_entities.py index 2acd278a31e..aeca2e3afdd 100644 --- a/api/tests/unit_tests/core/entities/test_entities_model_entities.py +++ b/api/tests/unit_tests/core/entities/test_entities_model_entities.py @@ -8,9 +8,6 @@ drive provider mapping behavior. """ import pytest -from graphon.model_runtime.entities.common_entities import I18nObject -from graphon.model_runtime.entities.model_entities import FetchFrom, ModelType -from graphon.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity from core.entities.model_entities import ( DefaultModelEntity, @@ -19,6 +16,9 @@ from core.entities.model_entities import ( ProviderModelWithStatusEntity, SimpleModelProviderEntity, ) +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.model_entities import FetchFrom, ModelType +from graphon.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity def _build_model_with_status(status: ModelStatus) -> ProviderModelWithStatusEntity: @@ -46,7 +46,7 @@ def test_simple_model_provider_entity_maps_from_provider_entity() -> None: # Assert assert simple_provider.provider == "openai" - assert simple_provider.label.en_US == "OpenAI" + assert simple_provider.label.en_us == "OpenAI" assert simple_provider.supported_model_types == [ModelType.LLM] diff --git a/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py b/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py index 8cf0409c4c2..a28143026ff 100644 --- a/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py +++ b/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py @@ -6,17 +6,6 @@ from typing import Any from unittest.mock import Mock, patch import pytest -from graphon.model_runtime.entities.common_entities import I18nObject -from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType -from graphon.model_runtime.entities.provider_entities import ( - ConfigurateMethod, - CredentialFormSchema, - FieldModelSchema, - FormType, - ModelCredentialSchema, - ProviderCredentialSchema, - ProviderEntity, -) from constants import HIDDEN_VALUE from core.entities.model_entities import ModelStatus @@ -35,6 +24,17 @@ from core.entities.provider_entities import ( SystemConfiguration, SystemConfigurationStatus, ) +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType +from graphon.model_runtime.entities.provider_entities import ( + ConfigurateMethod, + CredentialFormSchema, + FieldModelSchema, + FormType, + ModelCredentialSchema, + ProviderCredentialSchema, + ProviderEntity, +) from models.enums import CredentialSourceType from models.provider import ProviderType from models.provider_ids import ModelProviderID @@ -345,22 +345,26 @@ def test_validate_provider_credentials_handles_hidden_secret_value() -> None: ) ] ) - session = Mock() - session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace(encrypted_config="encrypted-old-key") + mock_session = Mock() + mock_session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace( + encrypted_config="encrypted-old-key" + ) mock_factory = Mock() mock_factory.provider_credentials_validate.return_value = {"openai_api_key": "restored-key", "region": "us"} - with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory): - with patch("core.entities.provider_configuration.encrypter.decrypt_token", return_value="restored-key"): - with patch( - "core.entities.provider_configuration.encrypter.encrypt_token", - side_effect=lambda tenant_id, value: f"enc::{value}", - ): - validated = configuration.validate_provider_credentials( - credentials={"openai_api_key": HIDDEN_VALUE, "region": "us"}, - credential_id="credential-1", - session=session, - ) + with _patched_session(mock_session): + with patch( + "core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory + ): + with patch("core.entities.provider_configuration.encrypter.decrypt_token", return_value="restored-key"): + with patch( + "core.entities.provider_configuration.encrypter.encrypt_token", + side_effect=lambda tenant_id, value: f"enc::{value}", + ): + validated = configuration.validate_provider_credentials( + credentials={"openai_api_key": HIDDEN_VALUE, "region": "us"}, + credential_id="credential-1", + ) assert validated["openai_api_key"] == "enc::restored-key" assert validated["region"] == "us" @@ -370,23 +374,15 @@ def test_validate_provider_credentials_handles_hidden_secret_value() -> None: ) -def test_validate_provider_credentials_opens_session_when_not_passed() -> None: +def test_validate_provider_credentials_without_credential_id() -> None: configuration = _build_provider_configuration() - mock_session = Mock() mock_factory = Mock() mock_factory.provider_credentials_validate.return_value = {"region": "us"} - with patch("core.entities.provider_configuration.Session") as mock_session_cls: - with patch("core.entities.provider_configuration.db") as mock_db: - mock_db.engine = Mock() - mock_session_cls.return_value.__enter__.return_value = mock_session - with patch( - "core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory - ): - validated = configuration.validate_provider_credentials(credentials={"region": "us"}) + with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory): + validated = configuration.validate_provider_credentials(credentials={"region": "us"}) assert validated == {"region": "us"} - mock_session_cls.assert_called_once() def test_switch_preferred_provider_type_returns_early_when_no_change_or_unsupported() -> None: @@ -717,18 +713,22 @@ def test_check_provider_credential_name_exists_and_model_setting_lookup() -> Non def test_validate_provider_credentials_handles_invalid_original_json() -> None: configuration = _build_provider_configuration() configuration.provider.provider_credential_schema = _build_secret_provider_schema() - session = Mock() - session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace(encrypted_config="{invalid-json") + mock_session = Mock() + mock_session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace( + encrypted_config="{invalid-json" + ) mock_factory = Mock() mock_factory.provider_credentials_validate.return_value = {"openai_api_key": "new-key"} - with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory): - with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-key"): - validated = configuration.validate_provider_credentials( - credentials={"openai_api_key": HIDDEN_VALUE}, - credential_id="cred-1", - session=session, - ) + with _patched_session(mock_session): + with patch( + "core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory + ): + with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-key"): + validated = configuration.validate_provider_credentials( + credentials={"openai_api_key": HIDDEN_VALUE}, + credential_id="cred-1", + ) assert validated == {"openai_api_key": "enc-key"} @@ -1060,37 +1060,35 @@ def test_get_custom_model_credential_uses_specific_id_or_configuration_fallback( def test_validate_custom_model_credentials_supports_hidden_reuse_and_sessionless_path() -> None: configuration = _build_provider_configuration() configuration.provider.model_credential_schema = _build_secret_model_schema() - session = Mock() - session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace( + mock_session = Mock() + mock_session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace( encrypted_config='{"openai_api_key":"enc"}' ) mock_factory = Mock() mock_factory.model_credentials_validate.return_value = {"openai_api_key": "raw"} - with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory): - with patch("core.entities.provider_configuration.encrypter.decrypt_token", return_value="raw"): - with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"): - validated = configuration.validate_custom_model_credentials( - model_type=ModelType.LLM, - model="gpt-4o", - credentials={"openai_api_key": HIDDEN_VALUE}, - credential_id="cred-1", - session=session, - ) - assert validated == {"openai_api_key": "enc-new"} - - session = Mock() - mock_factory = Mock() - mock_factory.model_credentials_validate.return_value = {"region": "us"} - with _patched_session(session): + with _patched_session(mock_session): with patch( "core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory ): - validated = configuration.validate_custom_model_credentials( - model_type=ModelType.LLM, - model="gpt-4o", - credentials={"region": "us"}, - ) + with patch("core.entities.provider_configuration.encrypter.decrypt_token", return_value="raw"): + with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"): + validated = configuration.validate_custom_model_credentials( + model_type=ModelType.LLM, + model="gpt-4o", + credentials={"openai_api_key": HIDDEN_VALUE}, + credential_id="cred-1", + ) + assert validated == {"openai_api_key": "enc-new"} + + mock_factory2 = Mock() + mock_factory2.model_credentials_validate.return_value = {"region": "us"} + with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory2): + validated = configuration.validate_custom_model_credentials( + model_type=ModelType.LLM, + model="gpt-4o", + credentials={"region": "us"}, + ) assert validated == {"region": "us"} @@ -1570,18 +1568,20 @@ def test_get_specific_provider_credential_logs_when_decrypt_fails() -> None: def test_validate_provider_credentials_uses_empty_original_when_record_missing() -> None: configuration = _build_provider_configuration() configuration.provider.provider_credential_schema = _build_secret_provider_schema() - session = Mock() - session.execute.return_value.scalar_one_or_none.return_value = None + mock_session = Mock() + mock_session.execute.return_value.scalar_one_or_none.return_value = None mock_factory = Mock() mock_factory.provider_credentials_validate.return_value = {"openai_api_key": "raw"} - with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory): - with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"): - validated = configuration.validate_provider_credentials( - credentials={"openai_api_key": HIDDEN_VALUE}, - credential_id="cred-1", - session=session, - ) + with _patched_session(mock_session): + with patch( + "core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory + ): + with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"): + validated = configuration.validate_provider_credentials( + credentials={"openai_api_key": HIDDEN_VALUE}, + credential_id="cred-1", + ) assert validated == {"openai_api_key": "enc-new"} @@ -1692,20 +1692,24 @@ def test_get_specific_custom_model_credential_logs_when_decrypt_fails() -> None: def test_validate_custom_model_credentials_handles_invalid_original_json() -> None: configuration = _build_provider_configuration() configuration.provider.model_credential_schema = _build_secret_model_schema() - session = Mock() - session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace(encrypted_config="{invalid-json") + mock_session = Mock() + mock_session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace( + encrypted_config="{invalid-json" + ) mock_factory = Mock() mock_factory.model_credentials_validate.return_value = {"openai_api_key": "raw"} - with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory): - with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"): - validated = configuration.validate_custom_model_credentials( - model_type=ModelType.LLM, - model="gpt-4o", - credentials={"openai_api_key": HIDDEN_VALUE}, - credential_id="cred-1", - session=session, - ) + with _patched_session(mock_session): + with patch( + "core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory + ): + with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"): + validated = configuration.validate_custom_model_credentials( + model_type=ModelType.LLM, + model="gpt-4o", + credentials={"openai_api_key": HIDDEN_VALUE}, + credential_id="cred-1", + ) assert validated == {"openai_api_key": "enc-new"} diff --git a/api/tests/unit_tests/core/entities/test_entities_provider_entities.py b/api/tests/unit_tests/core/entities/test_entities_provider_entities.py index 8685d162831..a159d3ad4d0 100644 --- a/api/tests/unit_tests/core/entities/test_entities_provider_entities.py +++ b/api/tests/unit_tests/core/entities/test_entities_provider_entities.py @@ -1,5 +1,4 @@ import pytest -from graphon.model_runtime.entities.model_entities import ModelType from core.entities.parameter_entities import AppSelectorScope from core.entities.provider_entities import ( @@ -9,6 +8,7 @@ from core.entities.provider_entities import ( ProviderQuotaType, ) from core.tools.entities.common_entities import I18nObject +from graphon.model_runtime.entities.model_entities import ModelType def test_provider_quota_type_value_of_returns_enum_member() -> None: diff --git a/api/tests/unit_tests/core/external_data_tool/test_base.py b/api/tests/unit_tests/core/external_data_tool/test_base.py index 216cda83c56..63e887f904e 100644 --- a/api/tests/unit_tests/core/external_data_tool/test_base.py +++ b/api/tests/unit_tests/core/external_data_tool/test_base.py @@ -1,3 +1,5 @@ +from typing import Any + import pytest from core.extension.extensible import ExtensionModule @@ -12,10 +14,10 @@ class TestExternalDataTool: # Create a concrete subclass to test init class ConcreteTool(ExternalDataTool): @classmethod - def validate_config(cls, tenant_id: str, config: dict): + def validate_config(cls, tenant_id: str, config: dict[str, Any]): return super().validate_config(tenant_id, config) - def query(self, inputs: dict, query: str | None = None) -> str: + def query(self, inputs: dict[str, Any], query: str | None = None) -> str: return super().query(inputs, query) tool = ConcreteTool(tenant_id="tenant_1", app_id="app_1", variable="var_1", config={"key": "value"}) @@ -28,10 +30,10 @@ class TestExternalDataTool: # Create a concrete subclass to test init class ConcreteTool(ExternalDataTool): @classmethod - def validate_config(cls, tenant_id: str, config: dict): + def validate_config(cls, tenant_id: str, config: dict[str, Any]): pass - def query(self, inputs: dict, query: str | None = None) -> str: + def query(self, inputs: dict[str, Any], query: str | None = None) -> str: return "" tool = ConcreteTool(tenant_id="tenant_1", app_id="app_1", variable="var_1") @@ -43,10 +45,10 @@ class TestExternalDataTool: def test_validate_config_raises_not_implemented(self): class ConcreteTool(ExternalDataTool): @classmethod - def validate_config(cls, tenant_id: str, config: dict): + def validate_config(cls, tenant_id: str, config: dict[str, Any]): return super().validate_config(tenant_id, config) - def query(self, inputs: dict, query: str | None = None) -> str: + def query(self, inputs: dict[str, Any], query: str | None = None) -> str: return "" with pytest.raises(NotImplementedError): @@ -55,10 +57,10 @@ class TestExternalDataTool: def test_query_raises_not_implemented(self): class ConcreteTool(ExternalDataTool): @classmethod - def validate_config(cls, tenant_id: str, config: dict): + def validate_config(cls, tenant_id: str, config: dict[str, Any]): pass - def query(self, inputs: dict, query: str | None = None) -> str: + def query(self, inputs: dict[str, Any], query: str | None = None) -> str: return super().query(inputs, query) tool = ConcreteTool(tenant_id="tenant_1", app_id="app_1", variable="var_1") diff --git a/api/tests/unit_tests/core/file/test_models.py b/api/tests/unit_tests/core/file/test_models.py index bb6e40e2247..8cb09385755 100644 --- a/api/tests/unit_tests/core/file/test_models.py +++ b/api/tests/unit_tests/core/file/test_models.py @@ -3,9 +3,9 @@ from graphon.file import File, FileTransferMethod, FileType def test_file(): file = File( - id="test-file", + file_id="test-file", tenant_id="test-tenant-id", - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.TOOL_FILE, related_id="test-related-id", filename="image.png", @@ -25,27 +25,21 @@ def test_file(): assert file.size == 67 -def test_file_model_validate_accepts_legacy_tenant_id(): - data = { - "id": "test-file", - "tenant_id": "test-tenant-id", - "type": "image", - "transfer_method": "tool_file", - "related_id": "test-related-id", - "filename": "image.png", - "extension": ".png", - "mime_type": "image/png", - "size": 67, - "storage_key": "test-storage-key", - "url": "https://example.com/image.png", - # Extra legacy fields - "tool_file_id": "tool-file-123", - "upload_file_id": "upload-file-456", - "datasource_file_id": "datasource-file-789", - } +def test_file_constructor_accepts_legacy_tenant_id(): + file = File( + file_id="test-file", + tenant_id="test-tenant-id", + file_type=FileType.IMAGE, + transfer_method=FileTransferMethod.TOOL_FILE, + tool_file_id="tool-file-123", + filename="image.png", + extension=".png", + mime_type="image/png", + size=67, + storage_key="test-storage-key", + url="https://example.com/image.png", + ) - file = File.model_validate(data) - - assert file.related_id == "test-related-id" + assert file.related_id == "tool-file-123" assert file.storage_key == "test-storage-key" assert "tenant_id" not in file.model_dump() diff --git a/api/tests/unit_tests/core/helper/code_executor/jinja2/test_jinja2_formatter.py b/api/tests/unit_tests/core/helper/code_executor/jinja2/test_jinja2_formatter.py new file mode 100644 index 00000000000..60002a757d5 --- /dev/null +++ b/api/tests/unit_tests/core/helper/code_executor/jinja2/test_jinja2_formatter.py @@ -0,0 +1,24 @@ +from pytest_mock import MockerFixture + +from core.helper.code_executor.jinja2.jinja2_formatter import Jinja2Formatter + + +def test_format_returns_result_value_as_string(mocker: MockerFixture) -> None: + execute_mock = mocker.patch( + "core.helper.code_executor.jinja2.jinja2_formatter.CodeExecutor.execute_workflow_code_template", + return_value={"result": 123}, + ) + + formatted = Jinja2Formatter.format("Hello {{ name }}", {"name": "Dify"}) + + assert formatted == "123" + execute_mock.assert_called_once() + + +def test_format_returns_empty_string_when_result_missing(mocker: MockerFixture) -> None: + mocker.patch( + "core.helper.code_executor.jinja2.jinja2_formatter.CodeExecutor.execute_workflow_code_template", + return_value={}, + ) + + assert Jinja2Formatter.format("Hello", {"name": "Dify"}) == "" diff --git a/api/tests/unit_tests/core/helper/code_executor/test_code_executor.py b/api/tests/unit_tests/core/helper/code_executor/test_code_executor.py new file mode 100644 index 00000000000..e09dd03489d --- /dev/null +++ b/api/tests/unit_tests/core/helper/code_executor/test_code_executor.py @@ -0,0 +1,110 @@ +from __future__ import annotations + +from typing import Any, cast +from unittest.mock import MagicMock + +import pytest +from pytest_mock import MockerFixture + +from core.helper.code_executor import code_executor as code_executor_module + + +def test_execute_workflow_code_template_raises_for_unsupported_language() -> None: + with pytest.raises(code_executor_module.CodeExecutionError, match="Unsupported language"): + code_executor_module.CodeExecutor.execute_workflow_code_template(cast(Any, "ruby"), "print(1)", {}) + + +def test_execute_workflow_code_template_uses_transformer(mocker: MockerFixture) -> None: + transformer = MagicMock() + transformer.transform_caller.return_value = ("runner-script", "preload-script") + transformer.transform_response.return_value = {"result": "ok"} + execute_mock = mocker.patch.object( + code_executor_module.CodeExecutor, + "execute_code", + return_value='<>{"result":"ok"}<>', + ) + mocker.patch.dict(code_executor_module.CodeExecutor.code_template_transformers, {"fake": transformer}, clear=False) + + result = code_executor_module.CodeExecutor.execute_workflow_code_template(cast(Any, "fake"), "code", {"a": 1}) + + assert result == {"result": "ok"} + transformer.transform_caller.assert_called_once_with("code", {"a": 1}) + execute_mock.assert_called_once_with("fake", "preload-script", "runner-script") + + +def test_execute_code_raises_service_unavailable_for_503(mocker: MockerFixture) -> None: + response = MagicMock() + response.status_code = 503 + client = MagicMock() + client.post.return_value = response + mocker.patch("core.helper.code_executor.code_executor.get_pooled_http_client", return_value=client) + + with pytest.raises(code_executor_module.CodeExecutionError, match="service is unavailable"): + code_executor_module.CodeExecutor.execute_code(cast(Any, "python3"), preload="", code="print(1)") + + +def test_execute_code_returns_stdout_on_success(mocker: MockerFixture) -> None: + response = MagicMock() + response.status_code = 200 + response.json.return_value = {"code": 0, "message": "ok", "data": {"stdout": "done", "error": None}} + client = MagicMock() + client.post.return_value = response + mocker.patch("core.helper.code_executor.code_executor.get_pooled_http_client", return_value=client) + + assert code_executor_module.CodeExecutor.execute_code(cast(Any, "python3"), preload="", code="print(1)") == "done" + + +def test_execute_code_raises_for_non_200_status(mocker: MockerFixture) -> None: + response = MagicMock() + response.status_code = 500 + client = MagicMock() + client.post.return_value = response + mocker.patch("core.helper.code_executor.code_executor.get_pooled_http_client", return_value=client) + + with pytest.raises(code_executor_module.CodeExecutionError, match="likely a network issue"): + code_executor_module.CodeExecutor.execute_code(cast(Any, "python3"), preload="", code="print(1)") + + +def test_execute_code_raises_when_client_post_fails(mocker: MockerFixture) -> None: + client = MagicMock() + client.post.side_effect = RuntimeError("timeout") + mocker.patch("core.helper.code_executor.code_executor.get_pooled_http_client", return_value=client) + + with pytest.raises(code_executor_module.CodeExecutionError, match="likely a network issue"): + code_executor_module.CodeExecutor.execute_code(cast(Any, "python3"), preload="", code="print(1)") + + +def test_execute_code_raises_when_response_json_is_invalid(mocker: MockerFixture) -> None: + response = MagicMock() + response.status_code = 200 + response.json.side_effect = ValueError("bad json") + client = MagicMock() + client.post.return_value = response + mocker.patch("core.helper.code_executor.code_executor.get_pooled_http_client", return_value=client) + + with pytest.raises(code_executor_module.CodeExecutionError, match="Failed to parse response"): + code_executor_module.CodeExecutor.execute_code(cast(Any, "python3"), preload="", code="print(1)") + + +def test_execute_code_raises_when_sandbox_returns_error_code(mocker: MockerFixture) -> None: + response = MagicMock() + response.status_code = 200 + response.json.return_value = {"code": 1, "message": "boom", "data": {"stdout": "", "error": None}} + client = MagicMock() + client.post.return_value = response + mocker.patch("core.helper.code_executor.code_executor.get_pooled_http_client", return_value=client) + + with pytest.raises(code_executor_module.CodeExecutionError, match="Got error code: 1"): + code_executor_module.CodeExecutor.execute_code(cast(Any, "python3"), preload="", code="print(1)") + + +def test_execute_code_raises_when_response_contains_runtime_error(mocker: MockerFixture) -> None: + response = MagicMock() + response.status_code = 200 + response.json.return_value = {"code": 0, "message": "ok", "data": {"stdout": "", "error": "runtime failed"}} + client = MagicMock() + client.post.return_value = response + mocker.patch("core.helper.code_executor.code_executor.get_pooled_http_client", return_value=client) + + with pytest.raises(code_executor_module.CodeExecutionError, match="runtime failed"): + code_executor_module.CodeExecutor.execute_code(cast(Any, "python3"), preload="", code="print(1)") diff --git a/api/tests/unit_tests/core/helper/code_executor/test_code_node_provider.py b/api/tests/unit_tests/core/helper/code_executor/test_code_node_provider.py new file mode 100644 index 00000000000..47761a32ac8 --- /dev/null +++ b/api/tests/unit_tests/core/helper/code_executor/test_code_node_provider.py @@ -0,0 +1,29 @@ +from core.helper.code_executor.code_node_provider import CodeNodeProvider + + +class _DummyProvider(CodeNodeProvider): + @staticmethod + def get_language() -> str: + return "dummy" + + @classmethod + def get_default_code(cls) -> str: + return "def main():\n return {'result': 'ok'}" + + +def test_is_accept_language() -> None: + assert _DummyProvider.is_accept_language("dummy") is True + assert _DummyProvider.is_accept_language("other") is False + + +def test_get_default_config_contains_expected_shape() -> None: + config = _DummyProvider.get_default_config() + + assert config["type"] == "code" + assert config["config"]["code_language"] == "dummy" + assert config["config"]["code"] == _DummyProvider.get_default_code() + assert config["config"]["variables"] == [ + {"variable": "arg1", "value_selector": []}, + {"variable": "arg2", "value_selector": []}, + ] + assert config["config"]["outputs"] == {"result": {"type": "string", "children": None}} diff --git a/api/tests/unit_tests/core/helper/code_executor/test_template_transformer.py b/api/tests/unit_tests/core/helper/code_executor/test_template_transformer.py new file mode 100644 index 00000000000..5b54b8e6474 --- /dev/null +++ b/api/tests/unit_tests/core/helper/code_executor/test_template_transformer.py @@ -0,0 +1,81 @@ +import json +from base64 import b64decode +from collections.abc import Mapping +from typing import Any + +import pytest + +from core.helper.code_executor.template_transformer import TemplateTransformer + + +class _DummyTransformer(TemplateTransformer): + @classmethod + def get_runner_script(cls) -> str: + return f"CODE={cls._code_placeholder};INPUTS={cls._inputs_placeholder}" + + +def test_serialize_code_encodes_to_base64() -> None: + encoded = _DummyTransformer.serialize_code("print('hi')") + + assert b64decode(encoded.encode()).decode() == "print('hi')" + + +def test_assemble_runner_script_embeds_code_and_inputs() -> None: + script = _DummyTransformer.assemble_runner_script("x = 1", {"a": "b"}) + + assert "CODE=x = 1" in script + payload = script.split("INPUTS=", maxsplit=1)[1] + assert json.loads(b64decode(payload.encode()).decode()) == {"a": "b"} + + +def test_transform_caller_returns_runner_and_empty_preload() -> None: + runner, preload = _DummyTransformer.transform_caller("x = 2", {"k": "v"}) + + assert "CODE=x = 2" in runner + assert preload == "" + + +def test_serialize_inputs_encodes_payload() -> None: + payload = _DummyTransformer.serialize_inputs({"foo": "bar"}) + + assert json.loads(b64decode(payload.encode()).decode()) == {"foo": "bar"} + + +def test_transform_response_parses_json_result_and_converts_scientific_notation() -> None: + response = '<>{"value": "1e+3", "nested": {"x": "2E-2"}, "arr": ["3e+1"]}<>' + + result: Mapping[str, Any] = _DummyTransformer.transform_response(response) + + assert result == {"value": 1000.0, "nested": {"x": 0.02}, "arr": [30.0]} + + +def test_transform_response_raises_for_invalid_json() -> None: + with pytest.raises(ValueError, match="Failed to parse JSON response"): + _DummyTransformer.transform_response("<>{invalid json}<>") + + +def test_transform_response_raises_for_non_dict_result() -> None: + with pytest.raises(ValueError, match="Result must be a dict"): + _DummyTransformer.transform_response("<>[1,2,3]<>") + + +def test_transform_response_raises_for_non_string_keys(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("json.loads", lambda _: {1: "x"}) + + with pytest.raises(ValueError, match="Result keys must be strings"): + _DummyTransformer.transform_response('<>{"ignored": true}<>') + + +def test_transform_response_raises_for_unexpected_errors(monkeypatch: pytest.MonkeyPatch) -> None: + def _raise_unexpected(_: str) -> Any: + raise RuntimeError("boom") + + monkeypatch.setattr("json.loads", _raise_unexpected) + + with pytest.raises(ValueError, match="Unexpected error during response transformation"): + _DummyTransformer.transform_response('<>{"a":1}<>') + + +def test_transform_response_raises_for_missing_result_tag() -> None: + with pytest.raises(ValueError, match="no result tag found"): + _DummyTransformer.transform_response("plain output") diff --git a/api/tests/unit_tests/core/helper/test_credential_utils.py b/api/tests/unit_tests/core/helper/test_credential_utils.py new file mode 100644 index 00000000000..dd10f81b02a --- /dev/null +++ b/api/tests/unit_tests/core/helper/test_credential_utils.py @@ -0,0 +1,138 @@ +from types import SimpleNamespace +from typing import cast + +import pytest +from pytest_mock import MockerFixture + +from core.entities import PluginCredentialType +from core.helper.credential_utils import check_credential_policy_compliance, is_credential_exists + + +def test_check_credential_policy_compliance_returns_when_feature_disabled( + mocker: MockerFixture, +) -> None: + mocker.patch( + "services.feature_service.FeatureService.get_system_features", + return_value=SimpleNamespace(plugin_manager=SimpleNamespace(enabled=False)), + ) + check_call = mocker.patch( + "services.enterprise.plugin_manager_service.PluginManagerService.check_credential_policy_compliance" + ) + + check_credential_policy_compliance("cred-1", "openai", PluginCredentialType.MODEL) + + check_call.assert_not_called() + + +def test_check_credential_policy_compliance_raises_when_credential_missing( + mocker: MockerFixture, +) -> None: + mocker.patch( + "services.feature_service.FeatureService.get_system_features", + return_value=SimpleNamespace(plugin_manager=SimpleNamespace(enabled=True)), + ) + mocker.patch("core.helper.credential_utils.is_credential_exists", return_value=False) + + with pytest.raises(ValueError, match="Credential with id cred-1 for provider openai not found."): + check_credential_policy_compliance("cred-1", "openai", PluginCredentialType.TOOL) + + +def test_check_credential_policy_compliance_calls_plugin_manager_with_request( + mocker: MockerFixture, +) -> None: + mocker.patch( + "services.feature_service.FeatureService.get_system_features", + return_value=SimpleNamespace(plugin_manager=SimpleNamespace(enabled=True)), + ) + mocker.patch("core.helper.credential_utils.is_credential_exists", return_value=True) + check_call = mocker.patch( + "services.enterprise.plugin_manager_service.PluginManagerService.check_credential_policy_compliance" + ) + + check_credential_policy_compliance("cred-1", "openai", PluginCredentialType.MODEL) + + check_call.assert_called_once() + request_arg = check_call.call_args.args[0] + assert request_arg.dify_credential_id == "cred-1" + assert request_arg.provider == "openai" + assert request_arg.credential_type == PluginCredentialType.MODEL + + +def test_check_credential_policy_compliance_skips_existence_check_when_disabled( + mocker: MockerFixture, +) -> None: + mocker.patch( + "services.feature_service.FeatureService.get_system_features", + return_value=SimpleNamespace(plugin_manager=SimpleNamespace(enabled=True)), + ) + exists_call = mocker.patch("core.helper.credential_utils.is_credential_exists") + check_call = mocker.patch( + "services.enterprise.plugin_manager_service.PluginManagerService.check_credential_policy_compliance" + ) + + check_credential_policy_compliance( + credential_id="cred-1", + provider="openai", + credential_type=PluginCredentialType.MODEL, + check_existence=False, + ) + + exists_call.assert_not_called() + check_call.assert_called_once() + + +def test_check_credential_policy_compliance_returns_when_credential_id_empty( + mocker: MockerFixture, +) -> None: + mocker.patch( + "services.feature_service.FeatureService.get_system_features", + return_value=SimpleNamespace(plugin_manager=SimpleNamespace(enabled=True)), + ) + exists_call = mocker.patch("core.helper.credential_utils.is_credential_exists") + check_call = mocker.patch( + "services.enterprise.plugin_manager_service.PluginManagerService.check_credential_policy_compliance" + ) + + check_credential_policy_compliance("", "openai", PluginCredentialType.MODEL) + + exists_call.assert_not_called() + check_call.assert_not_called() + + +@pytest.mark.parametrize( + ("credential_type", "scalar_result", "expected"), + [ + (PluginCredentialType.MODEL, "model-credential", True), + (PluginCredentialType.MODEL, None, False), + (PluginCredentialType.TOOL, "tool-credential", True), + (PluginCredentialType.TOOL, None, False), + ], +) +def test_is_credential_exists_by_type( + mocker: MockerFixture, + credential_type: PluginCredentialType, + scalar_result: str | None, + expected: bool, +) -> None: + mocker.patch("extensions.ext_database.db", new=SimpleNamespace(engine=object())) + session_cls = mocker.patch("sqlalchemy.orm.Session") + session = session_cls.return_value.__enter__.return_value + session.scalar.return_value = scalar_result + + result = is_credential_exists("cred-1", credential_type) + + assert result is expected + session.scalar.assert_called_once() + + +def test_is_credential_exists_returns_false_for_unknown_type( + mocker: MockerFixture, +) -> None: + mocker.patch("extensions.ext_database.db", new=SimpleNamespace(engine=object())) + session_cls = mocker.patch("sqlalchemy.orm.Session") + session = session_cls.return_value.__enter__.return_value + + result = is_credential_exists("cred-1", cast(PluginCredentialType, "unknown")) + + assert result is False + session.scalar.assert_not_called() diff --git a/api/tests/unit_tests/core/helper/test_download.py b/api/tests/unit_tests/core/helper/test_download.py new file mode 100644 index 00000000000..0755c25826b --- /dev/null +++ b/api/tests/unit_tests/core/helper/test_download.py @@ -0,0 +1,53 @@ +from collections.abc import Iterator + +import pytest +from pytest_mock import MockerFixture + +from core.helper.download import download_with_size_limit + + +class _StubResponse: + def __init__(self, status_code: int, chunks: list[bytes]) -> None: + self.status_code = status_code + self._chunks = chunks + + def iter_bytes(self) -> Iterator[bytes]: + return iter(self._chunks) + + +def test_download_with_size_limit_returns_content(mocker: MockerFixture) -> None: + response = _StubResponse(status_code=200, chunks=[b"ab", b"cd", b"ef"]) + mock_get = mocker.patch("core.helper.download.ssrf_proxy.get", return_value=response) + + content = download_with_size_limit("https://example.com/a.txt", max_download_size=6, timeout=10) + + assert content == b"abcdef" + mock_get.assert_called_once_with("https://example.com/a.txt", follow_redirects=True, timeout=10) + + +def test_download_with_size_limit_raises_for_404(mocker: MockerFixture) -> None: + mocker.patch("core.helper.download.ssrf_proxy.get", return_value=_StubResponse(status_code=404, chunks=[])) + + with pytest.raises(ValueError, match="file not found"): + download_with_size_limit("https://example.com/missing.txt", max_download_size=10) + + +def test_download_with_size_limit_raises_when_size_exceeds_limit( + mocker: MockerFixture, +) -> None: + response = _StubResponse(status_code=200, chunks=[b"abc", b"de"]) + mocker.patch("core.helper.download.ssrf_proxy.get", return_value=response) + + with pytest.raises(ValueError, match="Max file size reached"): + download_with_size_limit("https://example.com/large.bin", max_download_size=4) + + +def test_download_with_size_limit_accepts_content_equal_to_limit( + mocker: MockerFixture, +) -> None: + response = _StubResponse(status_code=200, chunks=[b"ab", b"cd"]) + mocker.patch("core.helper.download.ssrf_proxy.get", return_value=response) + + content = download_with_size_limit("https://example.com/exact.bin", max_download_size=4) + + assert content == b"abcd" diff --git a/api/tests/unit_tests/core/helper/test_http_client_pooling.py b/api/tests/unit_tests/core/helper/test_http_client_pooling.py new file mode 100644 index 00000000000..c29962f1b1a --- /dev/null +++ b/api/tests/unit_tests/core/helper/test_http_client_pooling.py @@ -0,0 +1,41 @@ +from __future__ import annotations + +from unittest.mock import MagicMock + +import httpx + +from core.helper.http_client_pooling import HttpClientPoolFactory + + +def test_get_or_create_reuses_client_for_same_key() -> None: + factory = HttpClientPoolFactory() + first_client = MagicMock(spec=httpx.Client) + second_client = MagicMock(spec=httpx.Client) + clients = [first_client, second_client] + + def _builder() -> httpx.Client: + return clients.pop(0) + + assert factory.get_or_create("shared", _builder) is first_client + assert factory.get_or_create("shared", _builder) is first_client + + +def test_get_or_create_creates_distinct_clients_for_distinct_keys() -> None: + factory = HttpClientPoolFactory() + client_a = MagicMock(spec=httpx.Client) + client_b = MagicMock(spec=httpx.Client) + + assert factory.get_or_create("a", lambda: client_a) is client_a + assert factory.get_or_create("b", lambda: client_b) is client_b + + +def test_close_all_closes_pooled_clients_and_allows_recreate() -> None: + factory = HttpClientPoolFactory() + first_client = MagicMock(spec=httpx.Client) + replacement_client = MagicMock(spec=httpx.Client) + + assert factory.get_or_create("x", lambda: first_client) is first_client + factory.close_all() + + first_client.close.assert_called_once() + assert factory.get_or_create("x", lambda: replacement_client) is replacement_client diff --git a/api/tests/unit_tests/core/helper/test_marketplace.py b/api/tests/unit_tests/core/helper/test_marketplace.py new file mode 100644 index 00000000000..bd561b16373 --- /dev/null +++ b/api/tests/unit_tests/core/helper/test_marketplace.py @@ -0,0 +1,110 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock + +from pytest_mock import MockerFixture + +from core.helper.marketplace import ( + batch_fetch_plugin_by_ids, + batch_fetch_plugin_manifests, + download_plugin_pkg, + fetch_global_plugin_manifest, + get_plugin_pkg_url, + record_install_plugin_event, +) + + +def test_get_plugin_pkg_url_contains_unique_identifier() -> None: + url = get_plugin_pkg_url("plugin@1.0.0") + + assert "api/v1/plugins/download" in url + assert "unique_identifier=plugin@1.0.0" in url + + +def test_download_plugin_pkg_delegates_with_configured_size(mocker: MockerFixture) -> None: + mocked_download = mocker.patch("core.helper.marketplace.download_with_size_limit", return_value=b"pkg") + mocker.patch("core.helper.marketplace.dify_config.PLUGIN_MAX_PACKAGE_SIZE", 1234) + + result = download_plugin_pkg("plugin.a.b") + + assert result == b"pkg" + mocked_download.assert_called_once() + called_url, called_limit = mocked_download.call_args.args + assert "unique_identifier=plugin.a.b" in called_url + assert called_limit == 1234 + + +def test_batch_fetch_plugin_by_ids_returns_empty_for_empty_input(mocker: MockerFixture) -> None: + post_mock = mocker.patch("core.helper.marketplace.httpx.post") + + assert batch_fetch_plugin_by_ids([]) == [] + post_mock.assert_not_called() + + +def test_batch_fetch_plugin_by_ids_returns_plugins_from_response(mocker: MockerFixture) -> None: + response = MagicMock() + response.json.return_value = {"data": {"plugins": [{"id": "p1"}]}} + response.raise_for_status.return_value = None + post_mock = mocker.patch("core.helper.marketplace.httpx.post", return_value=response) + + plugins = batch_fetch_plugin_by_ids(["p1"]) + + assert plugins == [{"id": "p1"}] + post_mock.assert_called_once() + response.raise_for_status.assert_called_once() + + +def test_batch_fetch_plugin_manifests_returns_empty_for_empty_input(mocker: MockerFixture) -> None: + post_mock = mocker.patch("core.helper.marketplace.httpx.post") + + assert batch_fetch_plugin_manifests([]) == [] + post_mock.assert_not_called() + + +def test_batch_fetch_plugin_manifests_validates_and_returns_plugins(mocker: MockerFixture) -> None: + response = MagicMock() + response.raise_for_status.return_value = None + response.json.return_value = {"data": {"plugins": [{"id": "p1"}, {"id": "p2"}]}} + post_mock = mocker.patch("core.helper.marketplace.httpx.post", return_value=response) + validate_mock = mocker.patch( + "core.helper.marketplace.MarketplacePluginDeclaration.model_validate", + side_effect=["manifest-1", "manifest-2"], + ) + + result = batch_fetch_plugin_manifests(["p1", "p2"]) + + assert result == ["manifest-1", "manifest-2"] + post_mock.assert_called_once() + assert validate_mock.call_count == 2 + response.raise_for_status.assert_called_once() + + +def test_record_install_plugin_event_posts_and_checks_status(mocker: MockerFixture) -> None: + response = MagicMock() + response.raise_for_status.return_value = None + post_mock = mocker.patch("core.helper.marketplace.httpx.post", return_value=response) + + record_install_plugin_event("plugin.a") + + post_mock.assert_called_once() + response.raise_for_status.assert_called_once() + + +def test_fetch_global_plugin_manifest_caches_each_plugin(mocker: MockerFixture) -> None: + response = MagicMock() + response.raise_for_status.return_value = None + response.json.return_value = {"plugins": [{"id": "a"}, {"id": "b"}]} + mocker.patch("core.helper.marketplace.httpx.get", return_value=response) + + snapshot_a = SimpleNamespace(plugin_id="plugin-a", model_dump_json=lambda: '{"id":"a"}') + snapshot_b = SimpleNamespace(plugin_id="plugin-b", model_dump_json=lambda: '{"id":"b"}') + validate_mock = mocker.patch( + "core.helper.marketplace.MarketplacePluginSnapshot.model_validate", + side_effect=[snapshot_a, snapshot_b], + ) + setex_mock = mocker.patch("core.helper.marketplace.redis_client.setex") + + fetch_global_plugin_manifest("prefix:", 60) + + assert validate_mock.call_count == 2 + setex_mock.assert_any_call(name="prefix:plugin-a", time=60, value='{"id":"a"}') + setex_mock.assert_any_call(name="prefix:plugin-b", time=60, value='{"id":"b"}') diff --git a/api/tests/unit_tests/core/helper/test_moderation.py b/api/tests/unit_tests/core/helper/test_moderation.py new file mode 100644 index 00000000000..a0dfa86d209 --- /dev/null +++ b/api/tests/unit_tests/core/helper/test_moderation.py @@ -0,0 +1,158 @@ +from types import SimpleNamespace +from typing import cast + +import pytest +from pytest_mock import MockerFixture + +from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity +from core.helper.moderation import check_moderation +from graphon.model_runtime.errors.invoke import InvokeBadRequestError +from models.provider import ProviderType + + +def _build_model_config(provider: str = "openai") -> SimpleNamespace: + return SimpleNamespace( + provider=provider, + provider_model_bundle=SimpleNamespace( + configuration=SimpleNamespace(using_provider_type=ProviderType.SYSTEM), + ), + ) + + +def test_check_moderation_returns_false_when_feature_not_enabled(mocker: MockerFixture) -> None: + mocker.patch( + "core.helper.moderation.hosting_configuration", + SimpleNamespace(moderation_config=None, provider_map={}), + ) + + assert ( + check_moderation( + "tenant-1", + cast(ModelConfigWithCredentialsEntity, _build_model_config()), + "hello", + ) + is False + ) + + +def test_check_moderation_returns_false_when_hosting_credentials_missing(mocker: MockerFixture) -> None: + openai_provider = "langgenius/openai/openai" + mocker.patch( + "core.helper.moderation.hosting_configuration", + SimpleNamespace( + moderation_config=SimpleNamespace(enabled=True, providers={"openai"}), + provider_map={openai_provider: SimpleNamespace(enabled=True, credentials=None)}, + ), + ) + + assert ( + check_moderation( + "tenant-1", + cast(ModelConfigWithCredentialsEntity, _build_model_config()), + "hello", + ) + is False + ) + + +def test_check_moderation_returns_true_when_model_accepts_text(mocker: MockerFixture) -> None: + openai_provider = "langgenius/openai/openai" + hosting_openai = SimpleNamespace(enabled=True, credentials={"api_key": "k"}) + mocker.patch( + "core.helper.moderation.hosting_configuration", + SimpleNamespace( + moderation_config=SimpleNamespace(enabled=True, providers={"openai"}), + provider_map={openai_provider: hosting_openai}, + ), + ) + mocker.patch("core.helper.moderation.secrets.choice", return_value="chunk") + + moderation_model = SimpleNamespace(invoke=lambda **invoke_kwargs: invoke_kwargs["text"] == "chunk") + factory = SimpleNamespace(get_model_type_instance=lambda **_factory_kwargs: moderation_model) + mocker.patch("core.helper.moderation.create_plugin_model_provider_factory", return_value=factory) + + assert ( + check_moderation( + "tenant-1", + cast(ModelConfigWithCredentialsEntity, _build_model_config()), + "abc", + ) + is True + ) + + +def test_check_moderation_returns_true_when_text_is_empty(mocker: MockerFixture) -> None: + openai_provider = "langgenius/openai/openai" + hosting_openai = SimpleNamespace(enabled=True, credentials={"api_key": "k"}) + mocker.patch( + "core.helper.moderation.hosting_configuration", + SimpleNamespace( + moderation_config=SimpleNamespace(enabled=True, providers={"openai"}), + provider_map={openai_provider: hosting_openai}, + ), + ) + factory_mock = mocker.patch("core.helper.moderation.create_plugin_model_provider_factory") + choice_mock = mocker.patch("core.helper.moderation.secrets.choice") + + assert ( + check_moderation( + "tenant-1", + cast(ModelConfigWithCredentialsEntity, _build_model_config()), + "", + ) + is True + ) + factory_mock.assert_not_called() + choice_mock.assert_not_called() + + +def test_check_moderation_returns_false_when_model_rejects_text(mocker: MockerFixture) -> None: + openai_provider = "langgenius/openai/openai" + hosting_openai = SimpleNamespace(enabled=True, credentials={"api_key": "k"}) + mocker.patch( + "core.helper.moderation.hosting_configuration", + SimpleNamespace( + moderation_config=SimpleNamespace(enabled=True, providers={"openai"}), + provider_map={openai_provider: hosting_openai}, + ), + ) + mocker.patch("core.helper.moderation.secrets.choice", return_value="chunk") + + moderation_model = SimpleNamespace(invoke=lambda **_invoke_kwargs: False) + factory = SimpleNamespace(get_model_type_instance=lambda **_factory_kwargs: moderation_model) + mocker.patch("core.helper.moderation.create_plugin_model_provider_factory", return_value=factory) + + assert ( + check_moderation( + "tenant-1", + cast(ModelConfigWithCredentialsEntity, _build_model_config()), + "abc", + ) + is False + ) + + +def test_check_moderation_raises_bad_request_when_provider_call_fails(mocker: MockerFixture) -> None: + openai_provider = "langgenius/openai/openai" + hosting_openai = SimpleNamespace(enabled=True, credentials={"api_key": "k"}) + mocker.patch( + "core.helper.moderation.hosting_configuration", + SimpleNamespace( + moderation_config=SimpleNamespace(enabled=True, providers={"openai"}), + provider_map={openai_provider: hosting_openai}, + ), + ) + mocker.patch("core.helper.moderation.secrets.choice", return_value="chunk") + + failing_model = SimpleNamespace( + invoke=lambda **_invoke_kwargs: (_ for _ in ()).throw(RuntimeError("boom")), + ) + factory = SimpleNamespace(get_model_type_instance=lambda **_factory_kwargs: failing_model) + mocker.patch("core.helper.moderation.create_plugin_model_provider_factory", return_value=factory) + + with pytest.raises(InvokeBadRequestError, match="Rate limit exceeded, please try again later."): + check_moderation( + "tenant-1", + cast(ModelConfigWithCredentialsEntity, _build_model_config()), + "abc", + ) diff --git a/api/tests/unit_tests/core/helper/test_name_generator.py b/api/tests/unit_tests/core/helper/test_name_generator.py new file mode 100644 index 00000000000..37a87260f15 --- /dev/null +++ b/api/tests/unit_tests/core/helper/test_name_generator.py @@ -0,0 +1,33 @@ +from dataclasses import dataclass + +from pytest_mock import MockerFixture + +from core.helper.name_generator import generate_incremental_name, generate_provider_name +from core.plugin.entities.plugin_daemon import CredentialType + + +@dataclass +class _Provider: + name: str + + +def test_generate_incremental_name_uses_next_highest_suffix() -> None: + names = ["API KEY 1", "API KEY 3", "API KEY 2", "other", "", "API KEY x"] + + assert generate_incremental_name(names, "API KEY") == "API KEY 4" + + +def test_generate_incremental_name_returns_default_when_no_matches() -> None: + assert generate_incremental_name(["custom", " ", ""], "AUTH") == "AUTH 1" + + +def test_generate_provider_name_uses_credential_display_name() -> None: + providers = [_Provider(name="API KEY 1"), _Provider(name="API KEY 2")] + + assert generate_provider_name(providers, CredentialType.API_KEY) == "API KEY 3" + + +def test_generate_provider_name_falls_back_on_generation_error(mocker: MockerFixture) -> None: + mocker.patch("core.helper.name_generator.generate_incremental_name", side_effect=RuntimeError("boom")) + + assert generate_provider_name([], CredentialType.OAUTH2, fallback_context="ctx") == "AUTH 1" diff --git a/api/tests/unit_tests/core/helper/test_ssrf_proxy.py b/api/tests/unit_tests/core/helper/test_ssrf_proxy.py index 3b5c5e65978..d9fed9ae2ad 100644 --- a/api/tests/unit_tests/core/helper/test_ssrf_proxy.py +++ b/api/tests/unit_tests/core/helper/test_ssrf_proxy.py @@ -1,11 +1,17 @@ from unittest.mock import MagicMock, patch +import httpx import pytest from core.helper.ssrf_proxy import ( SSRF_DEFAULT_MAX_RETRIES, + SSRFProxy, _get_user_provided_host_header, + _to_graphon_http_response, + graphon_ssrf_proxy, make_request, + max_retries_exceeded_error, + request_error, ) @@ -174,3 +180,56 @@ class TestFollowRedirectsParameter: call_kwargs = mock_client.request.call_args.kwargs assert call_kwargs.get("follow_redirects") is True + + +def test_to_graphon_http_response_preserves_httpx_response_fields() -> None: + response = httpx.Response( + 201, + headers={"X-Test": "1"}, + content=b"payload", + request=httpx.Request("GET", "https://example.com/resource"), + ) + + wrapped = _to_graphon_http_response(response) + + assert wrapped.status_code == 201 + assert wrapped.headers == {"x-test": "1", "content-length": "7"} + assert wrapped.content == b"payload" + assert wrapped.url == "https://example.com/resource" + assert wrapped.reason_phrase == "Created" + assert wrapped.text == "payload" + + +def test_ssrf_proxy_exposes_expected_error_types() -> None: + proxy = SSRFProxy() + + assert proxy.max_retries_exceeded_error is max_retries_exceeded_error + assert proxy.request_error is request_error + assert graphon_ssrf_proxy.max_retries_exceeded_error is max_retries_exceeded_error + assert graphon_ssrf_proxy.request_error is request_error + + +@pytest.mark.parametrize("method_name", ["get", "head", "post", "put", "delete", "patch"]) +def test_graphon_ssrf_proxy_wraps_module_requests(method_name: str) -> None: + response = httpx.Response( + 200, + headers={"X-Test": "1"}, + content=b"ok", + request=httpx.Request("GET", "https://example.com/resource"), + ) + + with patch(f"core.helper.ssrf_proxy.{method_name}", return_value=response) as mock_method: + wrapped = getattr(graphon_ssrf_proxy, method_name)( + "https://example.com/resource", + max_retries=3, + headers={"X-Test": "1"}, + ) + + mock_method.assert_called_once_with( + url="https://example.com/resource", + max_retries=3, + headers={"X-Test": "1"}, + ) + assert wrapped.status_code == 200 + assert wrapped.url == "https://example.com/resource" + assert wrapped.content == b"ok" diff --git a/api/tests/unit_tests/core/helper/test_tool_parameter_cache.py b/api/tests/unit_tests/core/helper/test_tool_parameter_cache.py new file mode 100644 index 00000000000..3c8b44d0101 --- /dev/null +++ b/api/tests/unit_tests/core/helper/test_tool_parameter_cache.py @@ -0,0 +1,71 @@ +import json + +from pytest_mock import MockerFixture + +from core.helper.tool_parameter_cache import ToolParameterCache, ToolParameterCacheType + + +def test_tool_parameter_cache_get_returns_decoded_dict(mocker: MockerFixture) -> None: + redis_client_mock = mocker.patch("core.helper.tool_parameter_cache.redis_client") + cache = ToolParameterCache( + tenant_id="tenant", + provider="provider", + tool_name="tool", + cache_type=ToolParameterCacheType.PARAMETER, + identity_id="identity", + ) + payload = {"k": "v", "n": 1} + cache_key = cache.cache_key + + redis_client_mock.get.return_value = json.dumps(payload).encode("utf-8") + + assert cache.get() == payload + redis_client_mock.get.assert_called_once_with(cache_key) + + +def test_tool_parameter_cache_get_returns_none_for_invalid_json(mocker: MockerFixture) -> None: + redis_client_mock = mocker.patch("core.helper.tool_parameter_cache.redis_client") + cache = ToolParameterCache( + tenant_id="tenant", + provider="provider", + tool_name="tool", + cache_type=ToolParameterCacheType.PARAMETER, + identity_id="identity", + ) + + redis_client_mock.get.return_value = b"{invalid-json" + + assert cache.get() is None + + +def test_tool_parameter_cache_get_returns_none_when_key_is_missing(mocker: MockerFixture) -> None: + redis_client_mock = mocker.patch("core.helper.tool_parameter_cache.redis_client") + cache = ToolParameterCache( + tenant_id="tenant", + provider="provider", + tool_name="tool", + cache_type=ToolParameterCacheType.PARAMETER, + identity_id="identity", + ) + + redis_client_mock.get.return_value = None + + assert cache.get() is None + + +def test_tool_parameter_cache_set_and_delete(mocker: MockerFixture) -> None: + redis_client_mock = mocker.patch("core.helper.tool_parameter_cache.redis_client") + cache = ToolParameterCache( + tenant_id="tenant", + provider="provider", + tool_name="tool", + cache_type=ToolParameterCacheType.PARAMETER, + identity_id="identity", + ) + + params = {"a": "b"} + cache.set(params) + cache.delete() + + redis_client_mock.setex.assert_called_once_with(cache.cache_key, 86400, json.dumps(params)) + redis_client_mock.delete.assert_called_once_with(cache.cache_key) diff --git a/api/tests/unit_tests/core/llm_generator/output_parser/test_structured_output.py b/api/tests/unit_tests/core/llm_generator/output_parser/test_structured_output.py index b45f6fd9a77..6ed9ddb476c 100644 --- a/api/tests/unit_tests/core/llm_generator/output_parser/test_structured_output.py +++ b/api/tests/unit_tests/core/llm_generator/output_parser/test_structured_output.py @@ -2,20 +2,6 @@ import json from unittest.mock import MagicMock, patch import pytest -from graphon.model_runtime.entities.llm_entities import ( - LLMResult, - LLMResultChunk, - LLMResultChunkDelta, - LLMResultWithStructuredOutput, - LLMUsage, -) -from graphon.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - SystemPromptMessage, - TextPromptMessageContent, - UserPromptMessage, -) -from graphon.model_runtime.entities.model_entities import AIModelEntity, ParameterRule, ParameterType from core.llm_generator.output_parser.errors import OutputParserError from core.llm_generator.output_parser.structured_output import ( @@ -30,6 +16,20 @@ from core.llm_generator.output_parser.structured_output import ( remove_additional_properties, ) from core.model_manager import ModelInstance +from graphon.model_runtime.entities.llm_entities import ( + LLMResult, + LLMResultChunk, + LLMResultChunkDelta, + LLMResultWithStructuredOutput, + LLMUsage, +) +from graphon.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + SystemPromptMessage, + TextPromptMessageContent, + UserPromptMessage, +) +from graphon.model_runtime.entities.model_entities import AIModelEntity, ParameterRule, ParameterType class TestStructuredOutput: diff --git a/api/tests/unit_tests/core/llm_generator/test_llm_generator.py b/api/tests/unit_tests/core/llm_generator/test_llm_generator.py index 62e714deb61..2716f4712c3 100644 --- a/api/tests/unit_tests/core/llm_generator/test_llm_generator.py +++ b/api/tests/unit_tests/core/llm_generator/test_llm_generator.py @@ -2,12 +2,12 @@ import json from unittest.mock import MagicMock, patch import pytest -from graphon.model_runtime.entities.llm_entities import LLMMode, LLMResult -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.app.app_config.entities import ModelConfig from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload from core.llm_generator.llm_generator import LLMGenerator +from graphon.model_runtime.entities.llm_entities import LLMMode, LLMResult +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError class TestLLMGenerator: @@ -346,13 +346,13 @@ class TestLLMGenerator: def test_instruction_modify_workflow_app_not_found(self): with patch("extensions.ext_database.db.session") as mock_session: - mock_session.return_value.query.return_value.where.return_value.first.return_value = None + mock_session.return_value.scalar.return_value = None with pytest.raises(ValueError, match="App not found."): LLMGenerator.instruction_modify_workflow("t", "f", "n", "c", "i", MagicMock(), "o", MagicMock()) def test_instruction_modify_workflow_no_workflow(self): with patch("extensions.ext_database.db.session") as mock_session: - mock_session.return_value.query.return_value.where.return_value.first.return_value = MagicMock() + mock_session.return_value.scalar.return_value = MagicMock() workflow_service = MagicMock() workflow_service.get_draft_workflow.return_value = None with pytest.raises(ValueError, match="Workflow not found for the given app model."): @@ -360,7 +360,7 @@ class TestLLMGenerator: def test_instruction_modify_workflow_success(self, mock_model_instance, model_config_entity): with patch("extensions.ext_database.db.session") as mock_session: - mock_session.return_value.query.return_value.where.return_value.first.return_value = MagicMock() + mock_session.return_value.scalar.return_value = MagicMock() workflow = MagicMock() workflow.graph_dict = {"graph": {"nodes": [{"id": "node_id", "data": {"type": "llm"}}]}} diff --git a/api/tests/unit_tests/core/mcp/auth/test_auth_flow.py b/api/tests/unit_tests/core/mcp/auth/test_auth_flow.py index fe533e62af8..1f5fdd2657f 100644 --- a/api/tests/unit_tests/core/mcp/auth/test_auth_flow.py +++ b/api/tests/unit_tests/core/mcp/auth/test_auth_flow.py @@ -862,6 +862,15 @@ class TestAuthOrchestration: result = discover_protected_resource_metadata(None, "https://api.example.com") assert result is None + # JSONDecodeError (non-JSON 200 response) + mock_get.side_effect = None + bad_json_response = Mock() + bad_json_response.status_code = 200 + bad_json_response.json.side_effect = json.JSONDecodeError("Expecting value", "", 0) + mock_get.return_value = bad_json_response + result = discover_protected_resource_metadata(None, "https://api.example.com") + assert result is None + @patch("core.helper.ssrf_proxy.get") def test_discover_oauth_authorization_server_metadata(self, mock_get): # Success @@ -892,6 +901,14 @@ class TestAuthOrchestration: result = discover_oauth_authorization_server_metadata(None, "https://api.example.com") assert result is None + # JSONDecodeError (non-JSON 200 response) + bad_json_response = Mock() + bad_json_response.status_code = 200 + bad_json_response.json.side_effect = json.JSONDecodeError("Expecting value", "", 0) + mock_get.return_value = bad_json_response + result = discover_oauth_authorization_server_metadata(None, "https://api.example.com") + assert result is None + def test_get_effective_scope(self): prm = ProtectedResourceMetadata( resource="https://api.example.com", @@ -997,6 +1014,24 @@ class TestAuthOrchestration: supported, url = check_support_resource_discovery("https://api") assert supported is False + # Case 6: JSONDecodeError (non-JSON 200 response) + mock_get.side_effect = None + bad_json_res = Mock() + bad_json_res.status_code = 200 + bad_json_res.json.side_effect = json.JSONDecodeError("Expecting value", "", 0) + mock_get.return_value = bad_json_res + supported, url = check_support_resource_discovery("https://api") + assert supported is False + assert url == "" + + # Case 7: Empty authorization_servers array (IndexError) + empty_res = Mock() + empty_res.status_code = 200 + empty_res.json.return_value = {"authorization_servers": []} + mock_get.return_value = empty_res + supported, url = check_support_resource_discovery("https://api") + assert supported is False + def test_discover_oauth_metadata(self): with patch("core.mcp.auth.auth_flow.discover_protected_resource_metadata") as mock_prm: with patch("core.mcp.auth.auth_flow.discover_oauth_authorization_server_metadata") as mock_asm: diff --git a/api/tests/unit_tests/core/mcp/client/test_streamable_http.py b/api/tests/unit_tests/core/mcp/client/test_streamable_http.py index 81f8da9a621..bbbffa2e695 100644 --- a/api/tests/unit_tests/core/mcp/client/test_streamable_http.py +++ b/api/tests/unit_tests/core/mcp/client/test_streamable_http.py @@ -971,6 +971,23 @@ class TestHandlePostRequestNew: assert isinstance(item, SessionMessage) assert isinstance(item.message.root, JSONRPCError) assert item.message.root.id == 77 + assert item.message.root.error.message == "Session terminated by server" + + def test_404_on_initialization_includes_url_in_error(self): + t = _new_transport(url="http://example.com/mcp/server/abc123/mcp") + q: queue.Queue = queue.Queue() + msg = _make_request_msg("initialize", 1) + ctx = self._make_ctx(t, q, message=msg) + mock_resp = MagicMock() + mock_resp.status_code = 404 + ctx.client.stream = self._stream_ctx(mock_resp) + t._handle_post_request(ctx) + item = q.get_nowait() + assert isinstance(item, SessionMessage) + assert isinstance(item.message.root, JSONRPCError) + assert item.message.root.error.code == 32600 + assert "404 Not Found" in item.message.root.error.message + assert "http://example.com/mcp/server/abc123/mcp" in item.message.root.error.message def test_404_for_notification_no_error_sent(self): t = _new_transport() diff --git a/api/tests/unit_tests/core/mcp/server/test_streamable_http.py b/api/tests/unit_tests/core/mcp/server/test_streamable_http.py index 9a815fb94d0..57456085c34 100644 --- a/api/tests/unit_tests/core/mcp/server/test_streamable_http.py +++ b/api/tests/unit_tests/core/mcp/server/test_streamable_http.py @@ -3,7 +3,6 @@ from unittest.mock import Mock, patch import jsonschema import pytest -from graphon.variables.input_entities import VariableEntity, VariableEntityType from core.app.features.rate_limiting.rate_limit import RateLimitGenerator from core.mcp import types @@ -19,6 +18,7 @@ from core.mcp.server.streamable_http import ( prepare_tool_arguments, process_mapping_response, ) +from graphon.variables.input_entities import VariableEntity, VariableEntityType from models.model import App, AppMCPServer, AppMode, EndUser diff --git a/api/tests/unit_tests/core/mcp/test_entities.py b/api/tests/unit_tests/core/mcp/test_entities.py index 3fede55916e..e99c38285ca 100644 --- a/api/tests/unit_tests/core/mcp/test_entities.py +++ b/api/tests/unit_tests/core/mcp/test_entities.py @@ -4,9 +4,7 @@ from unittest.mock import Mock from core.mcp.entities import ( SUPPORTED_PROTOCOL_VERSIONS, - LifespanContextT, RequestContext, - SessionT, ) from core.mcp.session.base_session import BaseSession from core.mcp.types import LATEST_PROTOCOL_VERSION, RequestParams @@ -198,42 +196,3 @@ class TestRequestContext: assert "RequestContext" in repr_str assert "test-123" in repr_str assert "MockSession" in repr_str - - -class TestTypeVariables: - """Test type variables defined in the module.""" - - def test_session_type_var(self): - """Test SessionT type variable.""" - - # Create a custom session class - class CustomSession(BaseSession): - pass - - # Use in generic context - def process_session(session: SessionT) -> SessionT: - return session - - mock_session = Mock(spec=CustomSession) - result = process_session(mock_session) - assert result == mock_session - - def test_lifespan_context_type_var(self): - """Test LifespanContextT type variable.""" - - # Use in generic context - def process_lifespan(context: LifespanContextT) -> LifespanContextT: - return context - - # Test with different types - str_context = "string-context" - assert process_lifespan(str_context) == str_context - - dict_context = {"key": "value"} - assert process_lifespan(dict_context) == dict_context - - class CustomContext: - pass - - custom_context = CustomContext() - assert process_lifespan(custom_context) == custom_context diff --git a/api/tests/unit_tests/core/memory/test_token_buffer_memory.py b/api/tests/unit_tests/core/memory/test_token_buffer_memory.py index 9a5fb319d7b..f459250b8ef 100644 --- a/api/tests/unit_tests/core/memory/test_token_buffer_memory.py +++ b/api/tests/unit_tests/core/memory/test_token_buffer_memory.py @@ -4,6 +4,8 @@ from unittest.mock import MagicMock, patch from uuid import uuid4 import pytest + +from core.memory.token_buffer_memory import TokenBufferMemory from graphon.model_runtime.entities import ( AssistantPromptMessage, ImagePromptMessageContent, @@ -11,8 +13,6 @@ from graphon.model_runtime.entities import ( TextPromptMessageContent, UserPromptMessage, ) - -from core.memory.token_buffer_memory import TokenBufferMemory from models.model import AppMode # --------------------------------------------------------------------------- diff --git a/api/tests/unit_tests/core/model_runtime/test_model_provider_factory.py b/api/tests/unit_tests/core/model_runtime/test_model_provider_factory.py index 6a672fdfd57..c4fd9705625 100644 --- a/api/tests/unit_tests/core/model_runtime/test_model_provider_factory.py +++ b/api/tests/unit_tests/core/model_runtime/test_model_provider_factory.py @@ -1,6 +1,7 @@ from unittest.mock import Mock import pytest + from graphon.model_runtime.entities.common_entities import I18nObject from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType from graphon.model_runtime.entities.provider_entities import ( @@ -12,12 +13,12 @@ from graphon.model_runtime.entities.provider_entities import ( ProviderCredentialSchema, ProviderEntity, ) -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from graphon.model_runtime.model_providers.__base.moderation_model import ModerationModel -from graphon.model_runtime.model_providers.__base.rerank_model import RerankModel -from graphon.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel -from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel -from graphon.model_runtime.model_providers.__base.tts_model import TTSModel +from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel +from graphon.model_runtime.model_providers.base.moderation_model import ModerationModel +from graphon.model_runtime.model_providers.base.rerank_model import RerankModel +from graphon.model_runtime.model_providers.base.speech2text_model import Speech2TextModel +from graphon.model_runtime.model_providers.base.text_embedding_model import TextEmbeddingModel +from graphon.model_runtime.model_providers.base.tts_model import TTSModel from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory diff --git a/api/tests/unit_tests/core/moderation/test_content_moderation.py b/api/tests/unit_tests/core/moderation/test_content_moderation.py index 3a97ad5c5d2..4c668ee96b7 100644 --- a/api/tests/unit_tests/core/moderation/test_content_moderation.py +++ b/api/tests/unit_tests/core/moderation/test_content_moderation.py @@ -10,6 +10,7 @@ This module tests all aspects of the content moderation system including: - Configuration validation """ +from typing import Any from unittest.mock import MagicMock, Mock, patch import pytest @@ -28,7 +29,7 @@ class TestKeywordsModeration: """Test suite for custom keyword-based content moderation.""" @pytest.fixture - def keywords_config(self) -> dict: + def keywords_config(self) -> dict[str, Any]: """ Fixture providing a standard keywords moderation configuration. @@ -48,7 +49,7 @@ class TestKeywordsModeration: } @pytest.fixture - def keywords_moderation(self, keywords_config: dict) -> KeywordsModeration: + def keywords_moderation(self, keywords_config: dict[str, Any]) -> KeywordsModeration: """ Fixture providing a KeywordsModeration instance. @@ -64,7 +65,7 @@ class TestKeywordsModeration: config=keywords_config, ) - def test_validate_config_success(self, keywords_config: dict): + def test_validate_config_success(self, keywords_config: dict[str, Any]): """Test successful validation of keywords moderation configuration.""" # Should not raise any exception KeywordsModeration.validate_config("test-tenant", keywords_config) @@ -274,7 +275,7 @@ class TestOpenAIModeration: """Test suite for OpenAI-based content moderation.""" @pytest.fixture - def openai_config(self) -> dict: + def openai_config(self) -> dict[str, Any]: """ Fixture providing OpenAI moderation configuration. @@ -293,7 +294,7 @@ class TestOpenAIModeration: } @pytest.fixture - def openai_moderation(self, openai_config: dict) -> OpenAIModeration: + def openai_moderation(self, openai_config: dict[str, Any]) -> OpenAIModeration: """ Fixture providing an OpenAIModeration instance. @@ -309,7 +310,7 @@ class TestOpenAIModeration: config=openai_config, ) - def test_validate_config_success(self, openai_config: dict): + def test_validate_config_success(self, openai_config: dict[str, Any]): """Test successful validation of OpenAI moderation configuration.""" # Should not raise any exception OpenAIModeration.validate_config("test-tenant", openai_config) diff --git a/api/tests/unit_tests/core/ops/test_base_trace_instance.py b/api/tests/unit_tests/core/ops/test_base_trace_instance.py index a8bee7dfa72..ac65d134546 100644 --- a/api/tests/unit_tests/core/ops/test_base_trace_instance.py +++ b/api/tests/unit_tests/core/ops/test_base_trace_instance.py @@ -76,10 +76,7 @@ def test_get_service_account_with_tenant_tenant_not_found(mock_db_session): mock_account = MagicMock(spec=Account) mock_account.id = "creator_id" - mock_db_session.scalar.side_effect = [mock_app, mock_account] - - # session.query(TenantAccountJoin).filter_by(...).first() returns None - mock_db_session.query.return_value.filter_by.return_value.first.return_value = None + mock_db_session.scalar.side_effect = [mock_app, mock_account, None] config = MagicMock(spec=BaseTracingConfig) instance = ConcreteTraceInstance(config) @@ -97,11 +94,10 @@ def test_get_service_account_with_tenant_success(mock_db_session): mock_account.id = "creator_id" mock_account.set_tenant_id = MagicMock() - mock_db_session.scalar.side_effect = [mock_app, mock_account] - mock_tenant_join = MagicMock(spec=TenantAccountJoin) mock_tenant_join.tenant_id = "tenant_id" - mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_tenant_join + + mock_db_session.scalar.side_effect = [mock_app, mock_account, mock_tenant_join] config = MagicMock(spec=BaseTracingConfig) instance = ConcreteTraceInstance(config) diff --git a/api/tests/unit_tests/core/ops/test_config_entity.py b/api/tests/unit_tests/core/ops/test_config_entity.py index 2cbff54c42a..69650c85cc2 100644 --- a/api/tests/unit_tests/core/ops/test_config_entity.py +++ b/api/tests/unit_tests/core/ops/test_config_entity.py @@ -1,16 +1,11 @@ -import pytest -from pydantic import ValidationError +from dify_trace_aliyun.config import AliyunConfig +from dify_trace_arize_phoenix.config import ArizeConfig, PhoenixConfig +from dify_trace_langfuse.config import LangfuseConfig +from dify_trace_langsmith.config import LangSmithConfig +from dify_trace_opik.config import OpikConfig +from dify_trace_weave.config import WeaveConfig -from core.ops.entities.config_entity import ( - AliyunConfig, - ArizeConfig, - LangfuseConfig, - LangSmithConfig, - OpikConfig, - PhoenixConfig, - TracingProviderEnum, - WeaveConfig, -) +from core.ops.entities.config_entity import TracingProviderEnum class TestTracingProviderEnum: @@ -27,349 +22,8 @@ class TestTracingProviderEnum: assert TracingProviderEnum.ALIYUN == "aliyun" -class TestArizeConfig: - """Test cases for ArizeConfig""" - - def test_valid_config(self): - """Test valid Arize configuration""" - config = ArizeConfig( - api_key="test_key", space_id="test_space", project="test_project", endpoint="https://custom.arize.com" - ) - assert config.api_key == "test_key" - assert config.space_id == "test_space" - assert config.project == "test_project" - assert config.endpoint == "https://custom.arize.com" - - def test_default_values(self): - """Test default values are set correctly""" - config = ArizeConfig() - assert config.api_key is None - assert config.space_id is None - assert config.project is None - assert config.endpoint == "https://otlp.arize.com" - - def test_project_validation_empty(self): - """Test project validation with empty value""" - config = ArizeConfig(project="") - assert config.project == "default" - - def test_project_validation_none(self): - """Test project validation with None value""" - config = ArizeConfig(project=None) - assert config.project == "default" - - def test_endpoint_validation_empty(self): - """Test endpoint validation with empty value""" - config = ArizeConfig(endpoint="") - assert config.endpoint == "https://otlp.arize.com" - - def test_endpoint_validation_with_path(self): - """Test endpoint validation normalizes URL by removing path""" - config = ArizeConfig(endpoint="https://custom.arize.com/api/v1") - assert config.endpoint == "https://custom.arize.com" - - def test_endpoint_validation_invalid_scheme(self): - """Test endpoint validation rejects invalid schemes""" - with pytest.raises(ValidationError, match="URL scheme must be one of"): - ArizeConfig(endpoint="ftp://invalid.com") - - def test_endpoint_validation_no_scheme(self): - """Test endpoint validation rejects URLs without scheme""" - with pytest.raises(ValidationError, match="URL scheme must be one of"): - ArizeConfig(endpoint="invalid.com") - - -class TestPhoenixConfig: - """Test cases for PhoenixConfig""" - - def test_valid_config(self): - """Test valid Phoenix configuration""" - config = PhoenixConfig(api_key="test_key", project="test_project", endpoint="https://custom.phoenix.com") - assert config.api_key == "test_key" - assert config.project == "test_project" - assert config.endpoint == "https://custom.phoenix.com" - - def test_default_values(self): - """Test default values are set correctly""" - config = PhoenixConfig() - assert config.api_key is None - assert config.project is None - assert config.endpoint == "https://app.phoenix.arize.com" - - def test_project_validation_empty(self): - """Test project validation with empty value""" - config = PhoenixConfig(project="") - assert config.project == "default" - - def test_endpoint_validation_with_path(self): - """Test endpoint validation with path""" - config = PhoenixConfig(endpoint="https://app.phoenix.arize.com/s/dify-integration") - assert config.endpoint == "https://app.phoenix.arize.com/s/dify-integration" - - def test_endpoint_validation_without_path(self): - """Test endpoint validation without path""" - config = PhoenixConfig(endpoint="https://app.phoenix.arize.com") - assert config.endpoint == "https://app.phoenix.arize.com" - - -class TestLangfuseConfig: - """Test cases for LangfuseConfig""" - - def test_valid_config(self): - """Test valid Langfuse configuration""" - config = LangfuseConfig(public_key="public_key", secret_key="secret_key", host="https://custom.langfuse.com") - assert config.public_key == "public_key" - assert config.secret_key == "secret_key" - assert config.host == "https://custom.langfuse.com" - - def test_valid_config_with_path(self): - host = "https://custom.langfuse.com/api/v1" - config = LangfuseConfig(public_key="public_key", secret_key="secret_key", host=host) - assert config.public_key == "public_key" - assert config.secret_key == "secret_key" - assert config.host == host - - def test_default_values(self): - """Test default values are set correctly""" - config = LangfuseConfig(public_key="public", secret_key="secret") - assert config.host == "https://api.langfuse.com" - - def test_missing_required_fields(self): - """Test that required fields are enforced""" - with pytest.raises(ValidationError): - LangfuseConfig() - - with pytest.raises(ValidationError): - LangfuseConfig(public_key="public") - - with pytest.raises(ValidationError): - LangfuseConfig(secret_key="secret") - - def test_host_validation_empty(self): - """Test host validation with empty value""" - config = LangfuseConfig(public_key="public", secret_key="secret", host="") - assert config.host == "https://api.langfuse.com" - - -class TestLangSmithConfig: - """Test cases for LangSmithConfig""" - - def test_valid_config(self): - """Test valid LangSmith configuration""" - config = LangSmithConfig(api_key="test_key", project="test_project", endpoint="https://custom.smith.com") - assert config.api_key == "test_key" - assert config.project == "test_project" - assert config.endpoint == "https://custom.smith.com" - - def test_default_values(self): - """Test default values are set correctly""" - config = LangSmithConfig(api_key="key", project="project") - assert config.endpoint == "https://api.smith.langchain.com" - - def test_missing_required_fields(self): - """Test that required fields are enforced""" - with pytest.raises(ValidationError): - LangSmithConfig() - - with pytest.raises(ValidationError): - LangSmithConfig(api_key="key") - - with pytest.raises(ValidationError): - LangSmithConfig(project="project") - - def test_endpoint_validation_https_only(self): - """Test endpoint validation only allows HTTPS""" - with pytest.raises(ValidationError, match="URL scheme must be one of"): - LangSmithConfig(api_key="key", project="project", endpoint="http://insecure.com") - - -class TestOpikConfig: - """Test cases for OpikConfig""" - - def test_valid_config(self): - """Test valid Opik configuration""" - config = OpikConfig( - api_key="test_key", - project="test_project", - workspace="test_workspace", - url="https://custom.comet.com/opik/api/", - ) - assert config.api_key == "test_key" - assert config.project == "test_project" - assert config.workspace == "test_workspace" - assert config.url == "https://custom.comet.com/opik/api/" - - def test_default_values(self): - """Test default values are set correctly""" - config = OpikConfig() - assert config.api_key is None - assert config.project is None - assert config.workspace is None - assert config.url == "https://www.comet.com/opik/api/" - - def test_project_validation_empty(self): - """Test project validation with empty value""" - config = OpikConfig(project="") - assert config.project == "Default Project" - - def test_url_validation_empty(self): - """Test URL validation with empty value""" - config = OpikConfig(url="") - assert config.url == "https://www.comet.com/opik/api/" - - def test_url_validation_missing_suffix(self): - """Test URL validation requires /api/ suffix""" - with pytest.raises(ValidationError, match="URL should end with /api/"): - OpikConfig(url="https://custom.comet.com/opik/") - - def test_url_validation_invalid_scheme(self): - """Test URL validation rejects invalid schemes""" - with pytest.raises(ValidationError, match="URL must start with https:// or http://"): - OpikConfig(url="ftp://custom.comet.com/opik/api/") - - -class TestWeaveConfig: - """Test cases for WeaveConfig""" - - def test_valid_config(self): - """Test valid Weave configuration""" - config = WeaveConfig( - api_key="test_key", - entity="test_entity", - project="test_project", - endpoint="https://custom.wandb.ai", - host="https://custom.host.com", - ) - assert config.api_key == "test_key" - assert config.entity == "test_entity" - assert config.project == "test_project" - assert config.endpoint == "https://custom.wandb.ai" - assert config.host == "https://custom.host.com" - - def test_default_values(self): - """Test default values are set correctly""" - config = WeaveConfig(api_key="key", project="project") - assert config.entity is None - assert config.endpoint == "https://trace.wandb.ai" - assert config.host is None - - def test_missing_required_fields(self): - """Test that required fields are enforced""" - with pytest.raises(ValidationError): - WeaveConfig() - - with pytest.raises(ValidationError): - WeaveConfig(api_key="key") - - with pytest.raises(ValidationError): - WeaveConfig(project="project") - - def test_endpoint_validation_https_only(self): - """Test endpoint validation only allows HTTPS""" - with pytest.raises(ValidationError, match="URL scheme must be one of"): - WeaveConfig(api_key="key", project="project", endpoint="http://insecure.wandb.ai") - - def test_host_validation_optional(self): - """Test host validation is optional but validates when provided""" - config = WeaveConfig(api_key="key", project="project", host=None) - assert config.host is None - - config = WeaveConfig(api_key="key", project="project", host="") - assert config.host == "" - - config = WeaveConfig(api_key="key", project="project", host="https://valid.host.com") - assert config.host == "https://valid.host.com" - - def test_host_validation_invalid_scheme(self): - """Test host validation rejects invalid schemes when provided""" - with pytest.raises(ValidationError, match="URL scheme must be one of"): - WeaveConfig(api_key="key", project="project", host="ftp://invalid.host.com") - - -class TestAliyunConfig: - """Test cases for AliyunConfig""" - - def test_valid_config(self): - """Test valid Aliyun configuration""" - config = AliyunConfig( - app_name="test_app", - license_key="test_license_key", - endpoint="https://custom.tracing-analysis-dc-hz.aliyuncs.com", - ) - assert config.app_name == "test_app" - assert config.license_key == "test_license_key" - assert config.endpoint == "https://custom.tracing-analysis-dc-hz.aliyuncs.com" - - def test_default_values(self): - """Test default values are set correctly""" - config = AliyunConfig(license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com") - assert config.app_name == "dify_app" - - def test_missing_required_fields(self): - """Test that required fields are enforced""" - with pytest.raises(ValidationError): - AliyunConfig() - - with pytest.raises(ValidationError): - AliyunConfig(license_key="test_license") - - with pytest.raises(ValidationError): - AliyunConfig(endpoint="https://tracing-analysis-dc-hz.aliyuncs.com") - - def test_app_name_validation_empty(self): - """Test app_name validation with empty value""" - config = AliyunConfig( - license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com", app_name="" - ) - assert config.app_name == "dify_app" - - def test_endpoint_validation_empty(self): - """Test endpoint validation with empty value""" - config = AliyunConfig(license_key="test_license", endpoint="") - assert config.endpoint == "https://tracing-analysis-dc-hz.aliyuncs.com" - - def test_endpoint_validation_with_path(self): - """Test endpoint validation preserves path for Aliyun endpoints""" - config = AliyunConfig( - license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com/api/v1/traces" - ) - assert config.endpoint == "https://tracing-analysis-dc-hz.aliyuncs.com/api/v1/traces" - - def test_endpoint_validation_invalid_scheme(self): - """Test endpoint validation rejects invalid schemes""" - with pytest.raises(ValidationError, match="URL must start with https:// or http://"): - AliyunConfig(license_key="test_license", endpoint="ftp://invalid.tracing-analysis-dc-hz.aliyuncs.com") - - def test_endpoint_validation_no_scheme(self): - """Test endpoint validation rejects URLs without scheme""" - with pytest.raises(ValidationError, match="URL must start with https:// or http://"): - AliyunConfig(license_key="test_license", endpoint="invalid.tracing-analysis-dc-hz.aliyuncs.com") - - def test_license_key_required(self): - """Test that license_key is required and cannot be empty""" - with pytest.raises(ValidationError): - AliyunConfig(license_key="", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com") - - def test_valid_endpoint_format_examples(self): - """Test valid endpoint format examples from comments""" - valid_endpoints = [ - # cms2.0 public endpoint - "https://proj-xtrace-123456-cn-heyuan.cn-heyuan.log.aliyuncs.com/apm/trace/opentelemetry", - # cms2.0 intranet endpoint - "https://proj-xtrace-123456-cn-heyuan.cn-heyuan-intranet.log.aliyuncs.com/apm/trace/opentelemetry", - # xtrace public endpoint - "http://tracing-cn-heyuan.arms.aliyuncs.com", - # xtrace intranet endpoint - "http://tracing-cn-heyuan-internal.arms.aliyuncs.com", - ] - - for endpoint in valid_endpoints: - config = AliyunConfig(license_key="test_license", endpoint=endpoint) - assert config.endpoint == endpoint - - class TestConfigIntegration: - """Integration tests for configuration classes""" + """Cross-provider configuration sanity checks""" def test_all_configs_can_be_instantiated(self): """Test that all config classes can be instantiated with valid data""" @@ -388,7 +42,6 @@ class TestConfigIntegration: def test_url_normalization_consistency(self): """Test that URL normalization works consistently across configs""" - # Test that paths are removed from endpoints arize_config = ArizeConfig(endpoint="https://arize.com/api/v1/test") phoenix_with_path_config = PhoenixConfig(endpoint="https://app.phoenix.arize.com/s/dify-integration") phoenix_without_path_config = PhoenixConfig(endpoint="https://app.phoenix.arize.com") diff --git a/api/tests/unit_tests/core/plugin/test_backwards_invocation_model.py b/api/tests/unit_tests/core/plugin/test_backwards_invocation_model.py index 543b278715d..c24d3ac0120 100644 --- a/api/tests/unit_tests/core/plugin/test_backwards_invocation_model.py +++ b/api/tests/unit_tests/core/plugin/test_backwards_invocation_model.py @@ -1,10 +1,9 @@ from types import SimpleNamespace from unittest.mock import patch -from graphon.model_runtime.entities.message_entities import UserPromptMessage - from core.plugin.backwards_invocation.model import PluginModelBackwardsInvocation from core.plugin.entities.request import RequestInvokeSummary +from graphon.model_runtime.entities.message_entities import UserPromptMessage def test_system_model_helpers_forward_user_id(): diff --git a/api/tests/unit_tests/core/plugin/test_model_runtime_adapter.py b/api/tests/unit_tests/core/plugin/test_model_runtime_adapter.py index f8d0e127b1b..88bf5555949 100644 --- a/api/tests/unit_tests/core/plugin/test_model_runtime_adapter.py +++ b/api/tests/unit_tests/core/plugin/test_model_runtime_adapter.py @@ -6,15 +6,15 @@ from types import SimpleNamespace from unittest.mock import Mock, sentinel import pytest -from graphon.model_runtime.entities.common_entities import I18nObject -from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType -from graphon.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity from core.plugin.entities.plugin_daemon import PluginModelProviderEntity from core.plugin.impl import model_runtime as model_runtime_module from core.plugin.impl.model import PluginModelClient from core.plugin.impl.model_runtime import TENANT_SCOPE_SCHEMA_CACHE_USER_ID, PluginModelRuntime from core.plugin.impl.model_runtime_factory import create_plugin_model_runtime +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType +from graphon.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity def _build_model_schema() -> AIModelEntity: @@ -56,7 +56,7 @@ class TestPluginModelRuntime: assert len(providers) == 1 assert providers[0].provider == "langgenius/openai/openai" assert providers[0].provider_name == "openai" - assert providers[0].label.en_US == "OpenAI" + assert providers[0].label.en_us == "OpenAI" client.fetch_model_providers.assert_called_once_with("tenant") def test_fetch_model_providers_only_exposes_short_name_for_canonical_provider(self) -> None: diff --git a/api/tests/unit_tests/core/plugin/test_plugin_entities.py b/api/tests/unit_tests/core/plugin/test_plugin_entities.py index a812b01c5bd..f1c4c7e7009 100644 --- a/api/tests/unit_tests/core/plugin/test_plugin_entities.py +++ b/api/tests/unit_tests/core/plugin/test_plugin_entities.py @@ -4,12 +4,6 @@ from enum import StrEnum import pytest from flask import Response -from graphon.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - SystemPromptMessage, - ToolPromptMessage, - UserPromptMessage, -) from pydantic import ValidationError from core.plugin.entities.endpoint import EndpointEntityWithInstance @@ -31,6 +25,12 @@ from core.plugin.entities.request import ( ) from core.plugin.utils.http_parser import serialize_response from core.tools.entities.common_entities import I18nObject +from graphon.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + SystemPromptMessage, + ToolPromptMessage, + UserPromptMessage, +) class TestEndpointEntity: diff --git a/api/tests/unit_tests/core/plugin/test_plugin_runtime.py b/api/tests/unit_tests/core/plugin/test_plugin_runtime.py index a3b1e5f6b0e..704b82adc00 100644 --- a/api/tests/unit_tests/core/plugin/test_plugin_runtime.py +++ b/api/tests/unit_tests/core/plugin/test_plugin_runtime.py @@ -17,14 +17,6 @@ from unittest.mock import MagicMock, patch import httpx import pytest -from graphon.model_runtime.errors.invoke import ( - InvokeAuthorizationError, - InvokeBadRequestError, - InvokeConnectionError, - InvokeRateLimitError, - InvokeServerUnavailableError, -) -from graphon.model_runtime.errors.validate import CredentialsValidateFailedError from pydantic import BaseModel from core.plugin.entities.plugin_daemon import ( @@ -45,6 +37,14 @@ from core.plugin.impl.exc import ( ) from core.plugin.impl.plugin import PluginInstaller from core.plugin.impl.tool import PluginToolManager +from graphon.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from graphon.model_runtime.errors.validate import CredentialsValidateFailedError @pytest.fixture(autouse=True) diff --git a/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py b/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py index 90730dff5a4..00a42077860 100644 --- a/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py +++ b/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py @@ -1,12 +1,12 @@ from collections.abc import Generator import pytest -from graphon.file import File, FileTransferMethod, FileType from core.agent.entities import AgentInvokeMessage from core.plugin.utils.chunk_merger import FileChunk, merge_blob_chunks from core.plugin.utils.converter import convert_parameters_to_plugin_format from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolSelector +from graphon.file import File, FileTransferMethod, FileType class TestChunkMerger: @@ -466,7 +466,7 @@ class TestConverter: def test_convert_parameters_to_plugin_format_with_single_file_and_selector(self): file_param = File( tenant_id="tenant-1", - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://example.com/file.png", storage_key="", @@ -499,14 +499,14 @@ class TestConverter: def test_convert_parameters_to_plugin_format_with_lists_and_passthrough_values(self): file_one = File( tenant_id="tenant-1", - type=FileType.DOCUMENT, + file_type=FileType.DOCUMENT, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://example.com/a.txt", storage_key="", ) file_two = File( tenant_id="tenant-1", - type=FileType.DOCUMENT, + file_type=FileType.DOCUMENT, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://example.com/b.txt", storage_key="", diff --git a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py index 2b280dd6746..e536c0831fd 100644 --- a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py @@ -2,6 +2,13 @@ from typing import cast from unittest.mock import MagicMock, patch import pytest + +from configs import dify_config +from core.app.app_config.entities import ModelConfigEntity +from core.memory.token_buffer_memory import TokenBufferMemory +from core.prompt.advanced_prompt_transform import AdvancedPromptTransform +from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig +from core.prompt.utils.prompt_template_parser import PromptTemplateParser from graphon.file import File, FileTransferMethod, FileType from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, @@ -11,13 +18,6 @@ from graphon.model_runtime.entities.message_entities import ( TextPromptMessageContent, UserPromptMessage, ) - -from configs import dify_config -from core.app.app_config.entities import ModelConfigEntity -from core.memory.token_buffer_memory import TokenBufferMemory -from core.prompt.advanced_prompt_transform import AdvancedPromptTransform -from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig -from core.prompt.utils.prompt_template_parser import PromptTemplateParser from models.model import Conversation @@ -134,9 +134,9 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg files = [ File( - id="file1", + file_id="file1", tenant_id="tenant1", - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://example.com/image1.jpg", storage_key="", @@ -245,9 +245,9 @@ def test_completion_prompt_jinja2_with_files(): completion_template = CompletionModelPromptTemplate(text="Hi {{name}}", edition_type="jinja2") file = File( - id="file1", + file_id="file1", tenant_id="tenant1", - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://example.com/image.jpg", storage_key="", @@ -379,9 +379,9 @@ def test_chat_prompt_memory_with_files_and_query(): memory = MagicMock(spec=TokenBufferMemory) prompt_template = [ChatModelMessage(text="sys", role=PromptMessageRole.SYSTEM)] file = File( - id="file1", + file_id="file1", tenant_id="tenant1", - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://example.com/image.jpg", storage_key="", @@ -413,9 +413,9 @@ def test_chat_prompt_files_without_query_updates_last_user_or_appends_new(): transform = AdvancedPromptTransform() model_config_mock = MagicMock(spec=ModelConfigEntity) file = File( - id="file1", + file_id="file1", tenant_id="tenant1", - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://example.com/image.jpg", storage_key="", @@ -463,9 +463,9 @@ def test_chat_prompt_files_with_query_branch(): transform = AdvancedPromptTransform() model_config_mock = MagicMock(spec=ModelConfigEntity) file = File( - id="file1", + file_id="file1", tenant_id="tenant1", - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://example.com/image.jpg", storage_key="", diff --git a/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py index 4a54649b289..28966242d81 100644 --- a/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py @@ -1,19 +1,18 @@ from unittest.mock import MagicMock -from graphon.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - SystemPromptMessage, - ToolPromptMessage, - UserPromptMessage, -) -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel - from core.app.entities.app_invoke_entities import ( ModelConfigWithCredentialsEntity, ) from core.entities.provider_configuration import ProviderModelBundle from core.memory.token_buffer_memory import TokenBufferMemory from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform +from graphon.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + SystemPromptMessage, + ToolPromptMessage, + UserPromptMessage, +) +from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel from models.model import Conversation diff --git a/api/tests/unit_tests/core/prompt/test_prompt_message.py b/api/tests/unit_tests/core/prompt/test_prompt_message.py index a4b3960b0a2..5d865d934cf 100644 --- a/api/tests/unit_tests/core/prompt/test_prompt_message.py +++ b/api/tests/unit_tests/core/prompt/test_prompt_message.py @@ -1,3 +1,5 @@ +from core.prompt.simple_prompt_transform import ModelMode +from core.prompt.utils.prompt_message_util import PromptMessageUtil from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, AudioPromptMessageContent, @@ -7,9 +9,6 @@ from graphon.model_runtime.entities.message_entities import ( UserPromptMessage, ) -from core.prompt.simple_prompt_transform import ModelMode -from core.prompt.utils.prompt_message_util import PromptMessageUtil - def test_build_prompt_message_with_prompt_message_contents(): prompt = UserPromptMessage(content=[TextPromptMessageContent(data="Hello, World!")]) diff --git a/api/tests/unit_tests/core/prompt/test_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_prompt_transform.py index e35ce2c48a9..5308c8e7b3b 100644 --- a/api/tests/unit_tests/core/prompt/test_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_prompt_transform.py @@ -2,16 +2,16 @@ from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest -from graphon.model_runtime.entities.model_entities import ModelPropertyKey from core.prompt.prompt_transform import PromptTransform +from graphon.model_runtime.entities.model_entities import ModelPropertyKey # from core.app.app_config.entities import ModelConfigEntity # from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle # from graphon.model_runtime.entities.message_entities import UserPromptMessage # from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey, ParameterRule # from graphon.model_runtime.entities.provider_entities import ProviderEntity -# from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +# from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel # from core.prompt.prompt_transform import PromptTransform diff --git a/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py index 3f188cfbb4b..0dc74b33dfb 100644 --- a/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py @@ -2,12 +2,6 @@ from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest -from graphon.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - ImagePromptMessageContent, - TextPromptMessageContent, - UserPromptMessage, -) from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.memory.token_buffer_memory import TokenBufferMemory @@ -24,6 +18,12 @@ from core.prompt.prompt_templates.advanced_prompt_templates import ( CONTEXT, ) from core.prompt.simple_prompt_transform import SimplePromptTransform +from graphon.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + ImagePromptMessageContent, + TextPromptMessageContent, + UserPromptMessage, +) from models.model import AppMode, Conversation diff --git a/api/tests/unit_tests/core/rag/data_post_processor/test_data_post_processor.py b/api/tests/unit_tests/core/rag/data_post_processor/test_data_post_processor.py index 006b4e7345e..1f3247590c4 100644 --- a/api/tests/unit_tests/core/rag/data_post_processor/test_data_post_processor.py +++ b/api/tests/unit_tests/core/rag/data_post_processor/test_data_post_processor.py @@ -1,13 +1,12 @@ from unittest.mock import MagicMock, patch -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError - from core.rag.data_post_processor.data_post_processor import DataPostProcessor from core.rag.data_post_processor.reorder import ReorderRunner from core.rag.index_processor.constant.query_type import QueryType from core.rag.models.document import Document from core.rag.rerank.rerank_type import RerankMode +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError def _doc(content: str) -> Document: diff --git a/api/tests/unit_tests/core/rag/datasource/keyword/jieba/test_jieba.py b/api/tests/unit_tests/core/rag/datasource/keyword/jieba/test_jieba.py index bbdd4769146..136ac0c72a4 100644 --- a/api/tests/unit_tests/core/rag/datasource/keyword/jieba/test_jieba.py +++ b/api/tests/unit_tests/core/rag/datasource/keyword/jieba/test_jieba.py @@ -1,5 +1,6 @@ import json from types import SimpleNamespace +from typing import Any from unittest.mock import MagicMock import pytest @@ -57,7 +58,7 @@ class _FakeSelect: return self -def _dataset_keyword_table(data_source_type: str = "database", keyword_table_dict: dict | None = None): +def _dataset_keyword_table(data_source_type: str = "database", keyword_table_dict: dict[str, Any] | None = None): return SimpleNamespace( data_source_type=data_source_type, keyword_table_dict=keyword_table_dict, diff --git a/api/tests/unit_tests/core/rag/datasource/test_datasource_retrieval.py b/api/tests/unit_tests/core/rag/datasource/test_datasource_retrieval.py index 5dbd62580aa..0baf85c3149 100644 --- a/api/tests/unit_tests/core/rag/datasource/test_datasource_retrieval.py +++ b/api/tests/unit_tests/core/rag/datasource/test_datasource_retrieval.py @@ -1,4 +1,5 @@ from types import SimpleNamespace +from typing import Any from unittest.mock import MagicMock, Mock, call, patch from uuid import uuid4 @@ -20,7 +21,7 @@ def create_mock_document( doc_id: str, score: float = 0.8, provider: str = "dify", - additional_metadata: dict | None = None, + additional_metadata: dict[str, Any] | None = None, ) -> Document: """ Create a mock Document object for testing. @@ -119,6 +120,14 @@ class _FakeSummaryQuery: return self._summaries +class _FakeScalarsResult: + def __init__(self, data: list) -> None: + self._data = data + + def all(self) -> list: + return self._data + + class _FakeSession: def __init__(self, execute_payloads: list[list], summaries: list) -> None: self._payloads = list(execute_payloads) @@ -128,8 +137,8 @@ class _FakeSession: data = self._payloads.pop(0) if self._payloads else [] return _FakeExecuteResult(data) - def query(self, model): - return _FakeSummaryQuery(self._summaries) + def scalars(self, stmt): + return _FakeScalarsResult(self._summaries) class _FakeSessionContext: @@ -228,7 +237,7 @@ class TestRetrievalServiceInternals: assert mock_retrieve.call_count == 2 @patch("core.rag.datasource.retrieval_service.ExternalDatasetService.fetch_external_knowledge_retrieval") - @patch("core.rag.datasource.retrieval_service.MetadataCondition.model_validate") + @patch("core.rag.datasource.retrieval_service.MetadataFilteringCondition.model_validate") @patch("core.rag.datasource.retrieval_service.db.session.scalar") def test_external_retrieve_with_metadata_conditions(self, mock_scalar, mock_validate, mock_fetch): mock_scalar.return_value = SimpleNamespace(tenant_id="tenant-1") @@ -265,14 +274,14 @@ class TestRetrievalServiceInternals: def test_get_dataset_queries_by_id(self, mock_session_class): expected_dataset = Mock(spec=Dataset) mock_session = Mock() - mock_session.query.return_value.where.return_value.first.return_value = expected_dataset + mock_session.scalar.return_value = expected_dataset mock_session_class.return_value.__enter__.return_value = mock_session with patch.object(retrieval_service_module, "db", SimpleNamespace(engine=Mock())): result = RetrievalService._get_dataset("dataset-123") assert result == expected_dataset - mock_session.query.assert_called_once() + mock_session.scalar.assert_called_once() @patch("core.rag.datasource.retrieval_service.Keyword") @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") @@ -1046,12 +1055,8 @@ class TestRetrievalServiceInternals: size=42, ) binding = SimpleNamespace(segment_id="segment-1", attachment_id="upload-1") - upload_query = Mock() - upload_query.where.return_value.first.return_value = upload_file - binding_query = Mock() - binding_query.where.return_value.first.return_value = binding session = Mock() - session.query.side_effect = [upload_query, binding_query] + session.scalar.side_effect = [upload_file, binding] result = RetrievalService.get_segment_attachment_info("dataset-id", "tenant-id", "upload-1", session) @@ -1076,32 +1081,26 @@ class TestRetrievalServiceInternals: mime_type="image/png", size=42, ) - upload_query = Mock() - upload_query.where.return_value.first.return_value = upload_file - binding_query = Mock() - binding_query.where.return_value.first.return_value = None session = Mock() - session.query.side_effect = [upload_query, binding_query] + session.scalar.side_effect = [upload_file, None] result = RetrievalService.get_segment_attachment_info("dataset-id", "tenant-id", "upload-1", session) assert result is None def test_get_segment_attachment_info_returns_none_when_upload_file_missing(self): - upload_query = Mock() - upload_query.where.return_value.first.return_value = None session = Mock() - session.query.return_value = upload_query + session.scalar.return_value = None result = RetrievalService.get_segment_attachment_info("dataset-id", "tenant-id", "upload-1", session) assert result is None def test_get_segment_attachment_infos_returns_empty_when_upload_files_missing(self): - upload_query = Mock() - upload_query.where.return_value.all.return_value = [] + scalars_result = Mock() + scalars_result.all.return_value = [] session = Mock() - session.query.return_value = upload_query + session.scalars.return_value = scalars_result result = RetrievalService.get_segment_attachment_infos(["upload-1"], session) @@ -1115,12 +1114,12 @@ class TestRetrievalServiceInternals: mime_type="image/png", size=42, ) - upload_query = Mock() - upload_query.where.return_value.all.return_value = [upload_file] - binding_query = Mock() - binding_query.where.return_value.all.return_value = [] + upload_scalars = Mock() + upload_scalars.all.return_value = [upload_file] + binding_scalars = Mock() + binding_scalars.all.return_value = [] session = Mock() - session.query.side_effect = [upload_query, binding_query] + session.scalars.side_effect = [upload_scalars, binding_scalars] result = RetrievalService.get_segment_attachment_infos(["upload-1"], session) @@ -1144,12 +1143,12 @@ class TestRetrievalServiceInternals: ) binding = SimpleNamespace(attachment_id="upload-1", segment_id="segment-1") - upload_query = Mock() - upload_query.where.return_value.all.return_value = [upload_file_1, upload_file_2] - binding_query = Mock() - binding_query.where.return_value.all.return_value = [binding] + upload_scalars = Mock() + upload_scalars.all.return_value = [upload_file_1, upload_file_2] + binding_scalars = Mock() + binding_scalars.all.return_value = [binding] session = Mock() - session.query.side_effect = [upload_query, binding_query] + session.scalars.side_effect = [upload_scalars, binding_scalars] result = RetrievalService.get_segment_attachment_infos(["upload-1", "upload-2"], session) diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/pgvector/__init__.py b/api/tests/unit_tests/core/rag/datasource/vdb/pgvector/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/test_vector_factory.py b/api/tests/unit_tests/core/rag/datasource/vdb/test_vector_factory.py index 4e9ceddda9a..dc21d378a24 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/test_vector_factory.py +++ b/api/tests/unit_tests/core/rag/datasource/vdb/test_vector_factory.py @@ -21,6 +21,9 @@ def _register_fake_factory_module(monkeypatch, module_path: str, class_name: str def vector_factory_module(): import importlib + from core.rag.datasource.vdb import vector_backend_registry as reg + + reg.clear_vector_factory_cache() import core.rag.datasource.vdb.vector_factory as module return importlib.reload(module) @@ -41,61 +44,62 @@ def test_gen_index_struct_dict(vector_factory_module): @pytest.mark.parametrize( ("vector_type", "module_path", "class_name"), [ - ("CHROMA", "core.rag.datasource.vdb.chroma.chroma_vector", "ChromaVectorFactory"), - ("MILVUS", "core.rag.datasource.vdb.milvus.milvus_vector", "MilvusVectorFactory"), + ("CHROMA", "dify_vdb_chroma.chroma_vector", "ChromaVectorFactory"), + ("MILVUS", "dify_vdb_milvus.milvus_vector", "MilvusVectorFactory"), ( "ALIBABACLOUD_MYSQL", - "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector", + "dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector", "AlibabaCloudMySQLVectorFactory", ), - ("MYSCALE", "core.rag.datasource.vdb.myscale.myscale_vector", "MyScaleVectorFactory"), - ("PGVECTOR", "core.rag.datasource.vdb.pgvector.pgvector", "PGVectorFactory"), - ("VASTBASE", "core.rag.datasource.vdb.pyvastbase.vastbase_vector", "VastbaseVectorFactory"), - ("PGVECTO_RS", "core.rag.datasource.vdb.pgvecto_rs.pgvecto_rs", "PGVectoRSFactory"), - ("QDRANT", "core.rag.datasource.vdb.qdrant.qdrant_vector", "QdrantVectorFactory"), - ("RELYT", "core.rag.datasource.vdb.relyt.relyt_vector", "RelytVectorFactory"), + ("MYSCALE", "dify_vdb_myscale.myscale_vector", "MyScaleVectorFactory"), + ("PGVECTOR", "dify_vdb_pgvector.pgvector", "PGVectorFactory"), + ("VASTBASE", "dify_vdb_vastbase.vastbase_vector", "VastbaseVectorFactory"), + ("PGVECTO_RS", "dify_vdb_pgvecto_rs.pgvecto_rs", "PGVectoRSFactory"), + ("QDRANT", "dify_vdb_qdrant.qdrant_vector", "QdrantVectorFactory"), + ("RELYT", "dify_vdb_relyt.relyt_vector", "RelytVectorFactory"), ( "ELASTICSEARCH", - "core.rag.datasource.vdb.elasticsearch.elasticsearch_vector", + "dify_vdb_elasticsearch.elasticsearch_vector", "ElasticSearchVectorFactory", ), ( "ELASTICSEARCH_JA", - "core.rag.datasource.vdb.elasticsearch.elasticsearch_ja_vector", + "dify_vdb_elasticsearch.elasticsearch_ja_vector", "ElasticSearchJaVectorFactory", ), - ("TIDB_VECTOR", "core.rag.datasource.vdb.tidb_vector.tidb_vector", "TiDBVectorFactory"), - ("WEAVIATE", "core.rag.datasource.vdb.weaviate.weaviate_vector", "WeaviateVectorFactory"), - ("TENCENT", "core.rag.datasource.vdb.tencent.tencent_vector", "TencentVectorFactory"), - ("ORACLE", "core.rag.datasource.vdb.oracle.oraclevector", "OracleVectorFactory"), + ("TIDB_VECTOR", "dify_vdb_tidb_vector.tidb_vector", "TiDBVectorFactory"), + ("WEAVIATE", "dify_vdb_weaviate.weaviate_vector", "WeaviateVectorFactory"), + ("TENCENT", "dify_vdb_tencent.tencent_vector", "TencentVectorFactory"), + ("ORACLE", "dify_vdb_oracle.oraclevector", "OracleVectorFactory"), ( "OPENSEARCH", - "core.rag.datasource.vdb.opensearch.opensearch_vector", + "dify_vdb_opensearch.opensearch_vector", "OpenSearchVectorFactory", ), - ("ANALYTICDB", "core.rag.datasource.vdb.analyticdb.analyticdb_vector", "AnalyticdbVectorFactory"), - ("COUCHBASE", "core.rag.datasource.vdb.couchbase.couchbase_vector", "CouchbaseVectorFactory"), - ("BAIDU", "core.rag.datasource.vdb.baidu.baidu_vector", "BaiduVectorFactory"), - ("VIKINGDB", "core.rag.datasource.vdb.vikingdb.vikingdb_vector", "VikingDBVectorFactory"), - ("UPSTASH", "core.rag.datasource.vdb.upstash.upstash_vector", "UpstashVectorFactory"), + ("ANALYTICDB", "dify_vdb_analyticdb.analyticdb_vector", "AnalyticdbVectorFactory"), + ("COUCHBASE", "dify_vdb_couchbase.couchbase_vector", "CouchbaseVectorFactory"), + ("BAIDU", "dify_vdb_baidu.baidu_vector", "BaiduVectorFactory"), + ("VIKINGDB", "dify_vdb_vikingdb.vikingdb_vector", "VikingDBVectorFactory"), + ("UPSTASH", "dify_vdb_upstash.upstash_vector", "UpstashVectorFactory"), ( "TIDB_ON_QDRANT", - "core.rag.datasource.vdb.tidb_on_qdrant.tidb_on_qdrant_vector", + "dify_vdb_tidb_on_qdrant.tidb_on_qdrant_vector", "TidbOnQdrantVectorFactory", ), - ("LINDORM", "core.rag.datasource.vdb.lindorm.lindorm_vector", "LindormVectorStoreFactory"), - ("OCEANBASE", "core.rag.datasource.vdb.oceanbase.oceanbase_vector", "OceanBaseVectorFactory"), - ("SEEKDB", "core.rag.datasource.vdb.oceanbase.oceanbase_vector", "OceanBaseVectorFactory"), - ("OPENGAUSS", "core.rag.datasource.vdb.opengauss.opengauss", "OpenGaussFactory"), - ("TABLESTORE", "core.rag.datasource.vdb.tablestore.tablestore_vector", "TableStoreVectorFactory"), + ("LINDORM", "dify_vdb_lindorm.lindorm_vector", "LindormVectorStoreFactory"), + ("OCEANBASE", "dify_vdb_oceanbase.oceanbase_vector", "OceanBaseVectorFactory"), + ("SEEKDB", "dify_vdb_oceanbase.oceanbase_vector", "OceanBaseVectorFactory"), + ("OPENGAUSS", "dify_vdb_opengauss.opengauss", "OpenGaussFactory"), + ("TABLESTORE", "dify_vdb_tablestore.tablestore_vector", "TableStoreVectorFactory"), ( "HUAWEI_CLOUD", - "core.rag.datasource.vdb.huawei.huawei_cloud_vector", + "dify_vdb_huawei_cloud.huawei_cloud_vector", "HuaweiCloudVectorFactory", ), - ("MATRIXONE", "core.rag.datasource.vdb.matrixone.matrixone_vector", "MatrixoneVectorFactory"), - ("CLICKZETTA", "core.rag.datasource.vdb.clickzetta.clickzetta_vector", "ClickzettaVectorFactory"), - ("IRIS", "core.rag.datasource.vdb.iris.iris_vector", "IrisVectorFactory"), + ("MATRIXONE", "dify_vdb_matrixone.matrixone_vector", "MatrixoneVectorFactory"), + ("CLICKZETTA", "dify_vdb_clickzetta.clickzetta_vector", "ClickzettaVectorFactory"), + ("IRIS", "dify_vdb_iris.iris_vector", "IrisVectorFactory"), + ("HOLOGRES", "dify_vdb_hologres.hologres_vector", "HologresVectorFactory"), ], ) def test_get_vector_factory_supported(vector_factory_module, monkeypatch, vector_type, module_path, class_name): @@ -111,6 +115,34 @@ def test_get_vector_factory_unsupported(vector_factory_module): vector_factory_module.Vector.get_vector_factory("unknown") +class _PluginChromaFactory: + """Stub used only for entry-point override test.""" + + +def test_get_vector_factory_entry_point_overrides_builtin(vector_factory_module, monkeypatch): + from importlib.metadata import EntryPoint + + from core.rag.datasource.vdb import vector_backend_registry as reg + + reg.clear_vector_factory_cache() + ep = EntryPoint( + name="chroma", + value=f"{__name__}:_PluginChromaFactory", + group="dify.vector_backends", + ) + + class _FakeGroups: + def select(self, *, group: str): + if group == "dify.vector_backends": + return (ep,) + return () + + monkeypatch.setattr(reg, "entry_points", lambda: _FakeGroups()) + + result_cls = vector_factory_module.Vector.get_vector_factory(vector_factory_module.VectorType.CHROMA) + assert result_cls is _PluginChromaFactory + + def test_vector_init_uses_default_and_custom_attributes(vector_factory_module): dataset = SimpleNamespace(id="dataset-1") @@ -121,7 +153,18 @@ def test_vector_init_uses_default_and_custom_attributes(vector_factory_module): default_vector = vector_factory_module.Vector(dataset) custom_vector = vector_factory_module.Vector(dataset, attributes=["doc_id"]) - assert default_vector._attributes == ["doc_id", "dataset_id", "document_id", "doc_hash", "doc_type"] + # `is_summary` and `original_chunk_id` must be in the default return-properties + # projection so summary index retrieval works on backends that honor the list + # as an explicit projection (e.g. Weaviate). See #34884. + assert default_vector._attributes == [ + "doc_id", + "dataset_id", + "document_id", + "doc_hash", + "doc_type", + "is_summary", + "original_chunk_id", + ] assert custom_vector._attributes == ["doc_id"] assert default_vector._embeddings == "embeddings" assert default_vector._vector_processor == "processor" diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/tidb_on_qdrant/__init__.py b/api/tests/unit_tests/core/rag/datasource/vdb/tidb_on_qdrant/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/weaviate/__init__.py b/api/tests/unit_tests/core/rag/datasource/vdb/weaviate/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/api/tests/unit_tests/core/rag/docstore/test_dataset_docstore.py b/api/tests/unit_tests/core/rag/docstore/test_dataset_docstore.py index a7b7c1595b6..007a76aa66c 100644 --- a/api/tests/unit_tests/core/rag/docstore/test_dataset_docstore.py +++ b/api/tests/unit_tests/core/rag/docstore/test_dataset_docstore.py @@ -721,6 +721,30 @@ class TestDatasetDocumentStoreMultimodelBinding: mock_db.session.add.assert_not_called() + def test_add_multimodel_documents_binding_with_none_document_id(self): + """Test that no bindings are added when document_id is None.""" + + mock_dataset = MagicMock(spec=Dataset) + mock_dataset.id = "test-dataset-id" + mock_dataset.tenant_id = "tenant-1" + + mock_attachment = MagicMock(spec=AttachmentDocument) + mock_attachment.metadata = {"doc_id": "attachment-1"} + + with patch("core.rag.docstore.dataset_docstore.db") as mock_db: + mock_session = MagicMock() + mock_db.session = mock_session + + store = DatasetDocumentStore( + dataset=mock_dataset, + user_id="test-user-id", + document_id=None, + ) + + store.add_multimodel_documents_binding("seg-1", [mock_attachment]) + + mock_db.session.add.assert_not_called() + class TestDatasetDocumentStoreAddDocumentsUpdateChild: """Tests for add_documents when updating existing documents with children.""" diff --git a/api/tests/unit_tests/core/rag/embedding/test_cached_embedding.py b/api/tests/unit_tests/core/rag/embedding/test_cached_embedding.py index 35631861868..051a1455aef 100644 --- a/api/tests/unit_tests/core/rag/embedding/test_cached_embedding.py +++ b/api/tests/unit_tests/core/rag/embedding/test_cached_embedding.py @@ -12,11 +12,11 @@ from unittest.mock import Mock, patch import numpy as np import pytest -from graphon.model_runtime.entities.model_entities import ModelPropertyKey -from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult, EmbeddingUsage from sqlalchemy.exc import IntegrityError from core.rag.embedding.cached_embedding import CacheEmbedding +from graphon.model_runtime.entities.model_entities import ModelPropertyKey +from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult, EmbeddingUsage from models.dataset import Embedding diff --git a/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py b/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py index 408cf14a51b..4b8175b0b42 100644 --- a/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py +++ b/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py @@ -49,6 +49,10 @@ from unittest.mock import Mock, patch import numpy as np import pytest +from sqlalchemy.exc import IntegrityError + +from core.entities.embedding_type import EmbeddingInputType +from core.rag.embedding.cached_embedding import CacheEmbedding from graphon.model_runtime.entities.model_entities import ModelPropertyKey from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult, EmbeddingUsage from graphon.model_runtime.errors.invoke import ( @@ -56,10 +60,6 @@ from graphon.model_runtime.errors.invoke import ( InvokeConnectionError, InvokeRateLimitError, ) -from sqlalchemy.exc import IntegrityError - -from core.entities.embedding_type import EmbeddingInputType -from core.rag.embedding.cached_embedding import CacheEmbedding from models.dataset import Embedding diff --git a/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py b/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py index 64eb89590a3..0220fb6d4aa 100644 --- a/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py +++ b/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py @@ -1,12 +1,14 @@ """Primarily used for testing merged cell scenarios""" +import gc import io import os import tempfile +import warnings from collections import UserDict from pathlib import Path from types import SimpleNamespace -from unittest.mock import MagicMock +from unittest.mock import AsyncMock, MagicMock import pytest from docx import Document @@ -354,15 +356,46 @@ def test_init_expands_home_path_and_invalid_local_path(monkeypatch, tmp_path): WordExtractor("not-a-file", "tenant", "user") -def test_del_closes_temp_file(): +def test_close_closes_temp_file(): extractor = object.__new__(WordExtractor) + extractor._closed = False extractor.temp_file = MagicMock() - WordExtractor.__del__(extractor) + extractor.close() extractor.temp_file.close.assert_called_once() +def test_close_is_idempotent(): + extractor = object.__new__(WordExtractor) + extractor._closed = False + extractor.temp_file = MagicMock() + + extractor.close() + extractor.close() + + extractor.temp_file.close.assert_called_once() + + +def test_close_handles_async_close_mock(): + extractor = object.__new__(WordExtractor) + extractor._closed = False + extractor.temp_file = MagicMock() + extractor.temp_file.close = AsyncMock() + + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + extractor.close() + gc.collect() + + extractor.temp_file.close.assert_called_once() + assert not [ + warning + for warning in caught + if issubclass(warning.category, RuntimeWarning) and "AsyncMockMixin._execute_mock_call" in str(warning.message) + ] + + def test_extract_images_handles_invalid_external_cases(monkeypatch): class FakeTargetRef: def __contains__(self, item): diff --git a/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py b/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py index d4b987c8325..4ba4d54fa09 100644 --- a/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py +++ b/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py @@ -1,15 +1,16 @@ from types import SimpleNamespace +from typing import Any from unittest.mock import Mock, patch import pytest -from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage -from graphon.model_runtime.entities.message_entities import AssistantPromptMessage, ImagePromptMessageContent -from graphon.model_runtime.entities.model_entities import ModelFeature from core.entities.knowledge_entities import PreviewDetail from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor from core.rag.models.document import AttachmentDocument, Document +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage +from graphon.model_runtime.entities.message_entities import AssistantPromptMessage, ImagePromptMessageContent +from graphon.model_runtime.entities.model_entities import ModelFeature class TestParagraphIndexProcessor: @@ -71,7 +72,9 @@ class TestParagraphIndexProcessor: with pytest.raises(ValueError, match="No rules found in process rule"): processor.transform([Document(page_content="text", metadata={})], process_rule={"mode": "custom"}) - def test_transform_validates_segmentation(self, processor: ParagraphIndexProcessor, process_rule: dict) -> None: + def test_transform_validates_segmentation( + self, processor: ParagraphIndexProcessor, process_rule: dict[str, Any] + ) -> None: rules_without_segmentation = SimpleNamespace(segmentation=None) with patch( @@ -84,7 +87,9 @@ class TestParagraphIndexProcessor: process_rule={"mode": "custom", "rules": {"enabled": True}}, ) - def test_transform_builds_split_documents(self, processor: ParagraphIndexProcessor, process_rule: dict) -> None: + def test_transform_builds_split_documents( + self, processor: ParagraphIndexProcessor, process_rule: dict[str, Any] + ) -> None: source_document = Document(page_content="source", metadata={"dataset_id": "dataset-1", "document_id": "doc-1"}) splitter = Mock() splitter.split_documents.return_value = [ diff --git a/api/tests/unit_tests/core/rag/indexing/processor/test_parent_child_index_processor.py b/api/tests/unit_tests/core/rag/indexing/processor/test_parent_child_index_processor.py index d363a0804d0..8ef0e046ef6 100644 --- a/api/tests/unit_tests/core/rag/indexing/processor/test_parent_child_index_processor.py +++ b/api/tests/unit_tests/core/rag/indexing/processor/test_parent_child_index_processor.py @@ -4,10 +4,10 @@ from unittest.mock import MagicMock, Mock, patch import pytest from core.entities.knowledge_entities import PreviewDetail +from core.rag.entities import ParentMode from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.index_processor.processor.parent_child_index_processor import ParentChildIndexProcessor from core.rag.models.document import AttachmentDocument, ChildDocument, Document -from services.entities.knowledge_entities.knowledge_entities import ParentMode class TestParentChildIndexProcessor: @@ -258,10 +258,10 @@ class TestParentChildIndexProcessor: session.commit.assert_called_once() def test_clean_deletes_summaries_when_requested(self, processor: ParentChildIndexProcessor, dataset: Mock) -> None: - segment_query = Mock() - segment_query.filter.return_value.all.return_value = [SimpleNamespace(id="seg-1")] + scalars_result = Mock() + scalars_result.all.return_value = [SimpleNamespace(id="seg-1")] session = Mock() - session.query.return_value = segment_query + session.scalars.return_value = scalars_result session_ctx = MagicMock() session_ctx.__enter__.return_value = session session_ctx.__exit__.return_value = False diff --git a/api/tests/unit_tests/core/rag/indexing/processor/test_qa_index_processor.py b/api/tests/unit_tests/core/rag/indexing/processor/test_qa_index_processor.py index 98c47bec8f9..bfae9001b77 100644 --- a/api/tests/unit_tests/core/rag/indexing/processor/test_qa_index_processor.py +++ b/api/tests/unit_tests/core/rag/indexing/processor/test_qa_index_processor.py @@ -1,4 +1,5 @@ from types import SimpleNamespace +from typing import Any from unittest.mock import MagicMock, Mock, patch import pandas as pd @@ -77,7 +78,7 @@ class TestQAIndexProcessor: processor.transform([Document(page_content="text", metadata={})], process_rule={"mode": "custom"}) def test_transform_preview_calls_formatter_once( - self, processor: QAIndexProcessor, process_rule: dict, fake_flask_app + self, processor: QAIndexProcessor, process_rule: dict[str, Any], fake_flask_app ) -> None: document = Document(page_content="raw text", metadata={"dataset_id": "dataset-1", "document_id": "doc-1"}) split_node = Document(page_content=".question", metadata={}) @@ -119,7 +120,7 @@ class TestQAIndexProcessor: mock_format.assert_called_once() def test_transform_non_preview_uses_thread_batches( - self, processor: QAIndexProcessor, process_rule: dict, fake_flask_app + self, processor: QAIndexProcessor, process_rule: dict[str, Any], fake_flask_app ) -> None: documents = [ Document(page_content="doc-1", metadata={"document_id": "doc-1", "dataset_id": "dataset-1"}), @@ -220,10 +221,10 @@ class TestQAIndexProcessor: self, processor: QAIndexProcessor, dataset: Mock ) -> None: mock_segment = SimpleNamespace(id="seg-1") - mock_query = Mock() - mock_query.filter.return_value.all.return_value = [mock_segment] + scalars_result = Mock() + scalars_result.all.return_value = [mock_segment] mock_session = Mock() - mock_session.query.return_value = mock_query + mock_session.scalars.return_value = scalars_result session_context = MagicMock() session_context.__enter__.return_value = mock_session session_context.__exit__.return_value = False diff --git a/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py b/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py index 57f92660a03..14630083d3e 100644 --- a/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py +++ b/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py @@ -53,7 +53,6 @@ from typing import Any from unittest.mock import MagicMock, Mock, patch import pytest -from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy.orm.exc import ObjectDeletedError from core.errors.error import ProviderTokenNotInitError @@ -64,6 +63,7 @@ from core.indexing_runner import ( ) from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.models.document import ChildDocument, Document +from graphon.model_runtime.entities.model_entities import ModelType from libs.datetime_utils import naive_utc_now from models.dataset import Dataset, DatasetProcessRule from models.dataset import Document as DatasetDocument diff --git a/api/tests/unit_tests/core/rag/rerank/test_reranker.py b/api/tests/unit_tests/core/rag/rerank/test_reranker.py index c279b00d3bc..8bc7dbf70db 100644 --- a/api/tests/unit_tests/core/rag/rerank/test_reranker.py +++ b/api/tests/unit_tests/core/rag/rerank/test_reranker.py @@ -17,7 +17,6 @@ from types import SimpleNamespace from unittest.mock import MagicMock, Mock, patch import pytest -from graphon.model_runtime.entities.rerank_entities import RerankDocument, RerankResult from core.model_manager import ModelInstance from core.rag.index_processor.constant.doc_type import DocType @@ -29,6 +28,7 @@ from core.rag.rerank.rerank_factory import RerankRunnerFactory from core.rag.rerank.rerank_model import RerankModelRunner from core.rag.rerank.rerank_type import RerankMode from core.rag.rerank.weight_rerank import WeightRerankRunner +from graphon.model_runtime.entities.rerank_entities import RerankDocument, RerankResult def create_mock_model_instance() -> ModelInstance: diff --git a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py index fee7b168ad0..89830f75174 100644 --- a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py +++ b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py @@ -1,18 +1,13 @@ import threading from contextlib import contextmanager, nullcontext from types import SimpleNamespace +from typing import Any from unittest.mock import MagicMock, Mock, patch from uuid import uuid4 import pytest from flask import Flask, current_app -from graphon.model_runtime.entities.llm_entities import LLMUsage -from graphon.model_runtime.entities.model_entities import ModelFeature -from sqlalchemy import column -from core.app.app_config.entities import ( - Condition as AppCondition, -) from core.app.app_config.entities import ( DatasetEntity, DatasetRetrieveConfigEntity, @@ -29,6 +24,7 @@ from core.entities.agent_entities import PlanningStrategy from core.entities.model_entities import ModelStatus from core.rag.data_post_processor.data_post_processor import WeightsDict from core.rag.datasource.retrieval_service import RetrievalService +from core.rag.entities import Condition as AppCondition from core.rag.index_processor.constant.doc_type import DocType from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.models.document import Document @@ -37,6 +33,8 @@ from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.workflow.nodes.knowledge_retrieval import exc from core.workflow.nodes.knowledge_retrieval.retrieval import KnowledgeRetrievalRequest +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.model_runtime.entities.model_entities import ModelFeature from models.dataset import Dataset from models.enums import CreatorUserRole @@ -48,7 +46,7 @@ def create_mock_document( doc_id: str, score: float = 0.8, provider: str = "dify", - additional_metadata: dict | None = None, + additional_metadata: dict[str, Any] | None = None, ) -> Document: """ Create a mock Document object for testing. @@ -2024,7 +2022,7 @@ def create_mock_document_methods( doc_id: str, score: float = 0.8, provider: str = "dify", - additional_metadata: dict | None = None, + additional_metadata: dict[str, Any] | None = None, ) -> Document: """ Create a mock Document object for testing. @@ -4041,21 +4039,9 @@ class TestDatasetRetrievalAdditionalHelpers: def test_get_available_datasets(self, retrieval: DatasetRetrieval) -> None: session = Mock() - subquery_query = Mock() - subquery_query.where.return_value = subquery_query - subquery_query.group_by.return_value = subquery_query - subquery_query.having.return_value = subquery_query - subquery_query.subquery.return_value = SimpleNamespace( - c=SimpleNamespace( - dataset_id=column("dataset_id"), available_document_count=column("available_document_count") - ) - ) - - dataset_query = Mock() - dataset_query.outerjoin.return_value = dataset_query - dataset_query.where.return_value = dataset_query - dataset_query.all.return_value = [SimpleNamespace(id="d1"), None, SimpleNamespace(id="d2")] - session.query.side_effect = [subquery_query, dataset_query] + scalars_result = Mock() + scalars_result.all.return_value = [SimpleNamespace(id="d1"), None, SimpleNamespace(id="d2")] + session.scalars.return_value = scalars_result session_ctx = MagicMock() session_ctx.__enter__.return_value = session @@ -4106,7 +4092,7 @@ def _doc( dataset_id: str = "dataset-1", document_id: str = "document-1", doc_id: str = "node-1", - extra: dict | None = None, + extra: dict[str, Any] | None = None, ) -> Document: metadata = { "score": score, @@ -4904,22 +4890,21 @@ class TestInternalHooksCoverage: _scalars(segments), _scalars(bindings), ] - query = Mock() - query.where.return_value = query - session.query.return_value = query session_ctx = MagicMock() session_ctx.__enter__.return_value = session session_ctx.__exit__.return_value = False + sessionmaker_ctx = MagicMock() + sessionmaker_ctx.begin.return_value = session_ctx + with ( patch("core.rag.retrieval.dataset_retrieval.db", SimpleNamespace(engine=Mock())), - patch("core.rag.retrieval.dataset_retrieval.Session", return_value=session_ctx), + patch("core.rag.retrieval.dataset_retrieval.sessionmaker", return_value=sessionmaker_ctx), patch.object(retrieval, "_send_trace_task") as mock_trace, ): retrieval._on_retrieval_end(flask_app=app, documents=docs, message_id="m1", timer={"cost": 1}) - query.update.assert_called_once() - session.commit.assert_called_once() + session.execute.assert_called_once() mock_trace.assert_called_once() def test_retriever_variants(self, retrieval: DatasetRetrieval) -> None: diff --git a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_methods.py b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_methods.py index 48782515d05..90feb4cf019 100644 --- a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_methods.py +++ b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_methods.py @@ -1,3 +1,4 @@ +from typing import Any from unittest.mock import MagicMock, Mock, patch from uuid import uuid4 @@ -55,7 +56,7 @@ def create_mock_document( doc_id: str, score: float = 0.8, provider: str = "dify", - additional_metadata: dict | None = None, + additional_metadata: dict[str, Any] | None = None, ) -> Document: """ Create a mock Document object for testing. diff --git a/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_function_call_router.py b/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_function_call_router.py index 5a2ecb82204..43c521dcfd4 100644 --- a/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_function_call_router.py +++ b/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_function_call_router.py @@ -1,8 +1,7 @@ from unittest.mock import Mock -from graphon.model_runtime.entities.llm_entities import LLMUsage - from core.rag.retrieval.router.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter +from graphon.model_runtime.entities.llm_entities import LLMUsage class TestFunctionCallMultiDatasetRouter: diff --git a/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_react_route.py b/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_react_route.py index 539ac0f849f..c56528cf55e 100644 --- a/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_react_route.py +++ b/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_react_route.py @@ -1,13 +1,12 @@ from types import SimpleNamespace from unittest.mock import Mock, patch +from core.rag.retrieval.output_parser.react_output import ReactAction, ReactFinish +from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter from graphon.model_runtime.entities.llm_entities import LLMUsage from graphon.model_runtime.entities.message_entities import PromptMessageRole from graphon.model_runtime.entities.model_entities import ModelType -from core.rag.retrieval.output_parser.react_output import ReactAction, ReactFinish -from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter - class TestReactMultiDatasetRouter: def test_invoke_returns_none_when_no_tools(self) -> None: diff --git a/api/tests/unit_tests/core/repositories/test_celery_workflow_execution_repository.py b/api/tests/unit_tests/core/repositories/test_celery_workflow_execution_repository.py index e229d5fc1a5..3d3322094eb 100644 --- a/api/tests/unit_tests/core/repositories/test_celery_workflow_execution_repository.py +++ b/api/tests/unit_tests/core/repositories/test_celery_workflow_execution_repository.py @@ -9,10 +9,10 @@ from unittest.mock import Mock, patch from uuid import uuid4 import pytest -from graphon.entities import WorkflowExecution -from graphon.enums import WorkflowType from core.repositories.celery_workflow_execution_repository import CeleryWorkflowExecutionRepository +from graphon.entities import WorkflowExecution +from graphon.enums import WorkflowType from libs.datetime_utils import naive_utc_now from models import Account, EndUser from models.enums import WorkflowRunTriggeredFrom diff --git a/api/tests/unit_tests/core/repositories/test_celery_workflow_node_execution_repository.py b/api/tests/unit_tests/core/repositories/test_celery_workflow_node_execution_repository.py index 7dbf78d0f0c..05b4f3a053c 100644 --- a/api/tests/unit_tests/core/repositories/test_celery_workflow_node_execution_repository.py +++ b/api/tests/unit_tests/core/repositories/test_celery_workflow_node_execution_repository.py @@ -9,14 +9,14 @@ from unittest.mock import Mock, patch from uuid import uuid4 import pytest + +from core.repositories.celery_workflow_node_execution_repository import CeleryWorkflowNodeExecutionRepository +from core.repositories.factory import OrderConfig from graphon.entities.workflow_node_execution import ( WorkflowNodeExecution, WorkflowNodeExecutionStatus, ) from graphon.enums import BuiltinNodeTypes - -from core.repositories.celery_workflow_node_execution_repository import CeleryWorkflowNodeExecutionRepository -from core.repositories.factory import OrderConfig from libs.datetime_utils import naive_utc_now from models import Account, EndUser from models.workflow import WorkflowNodeExecutionTriggeredFrom diff --git a/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py b/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py index 0fc82dda530..18ae9fafc8e 100644 --- a/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py +++ b/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py @@ -7,11 +7,6 @@ from datetime import datetime from types import SimpleNamespace import pytest -from graphon.nodes.human_input.entities import ( - FormDefinition, - UserAction, -) -from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from core.repositories.human_input_repository import ( HumanInputFormRecord, @@ -19,13 +14,18 @@ from core.repositories.human_input_repository import ( HumanInputFormSubmissionRepository, _WorkspaceMemberInfo, ) -from core.workflow.human_input_compat import ( +from core.workflow.human_input_adapter import ( EmailDeliveryConfig, EmailDeliveryMethod, EmailRecipients, ExternalRecipient, MemberRecipient, ) +from graphon.nodes.human_input.entities import ( + FormDefinition, + UserAction, +) +from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from libs.datetime_utils import naive_utc_now from models.human_input import ( EmailExternalRecipientPayload, diff --git a/api/tests/unit_tests/core/repositories/test_human_input_repository.py b/api/tests/unit_tests/core/repositories/test_human_input_repository.py index 8ff0e405874..4248782d936 100644 --- a/api/tests/unit_tests/core/repositories/test_human_input_repository.py +++ b/api/tests/unit_tests/core/repositories/test_human_input_repository.py @@ -9,8 +9,6 @@ from typing import Any from unittest.mock import MagicMock import pytest -from graphon.nodes.human_input.entities import HumanInputNodeData, UserAction -from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from core.repositories.human_input_repository import ( FormCreateParams, @@ -23,7 +21,7 @@ from core.repositories.human_input_repository import ( _InvalidTimeoutStatusError, _WorkspaceMemberInfo, ) -from core.workflow.human_input_compat import ( +from core.workflow.human_input_adapter import ( EmailDeliveryConfig, EmailDeliveryMethod, EmailRecipients, @@ -31,6 +29,8 @@ from core.workflow.human_input_compat import ( MemberRecipient, WebAppDeliveryMethod, ) +from graphon.nodes.human_input.entities import HumanInputNodeData, UserAction +from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from libs.datetime_utils import naive_utc_now from models.human_input import HumanInputFormRecipient, RecipientType diff --git a/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_execution_repository.py b/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_execution_repository.py index e5c3e854875..a08c5729cb1 100644 --- a/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_execution_repository.py +++ b/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_execution_repository.py @@ -3,12 +3,12 @@ from unittest.mock import MagicMock from uuid import uuid4 import pytest -from graphon.entities import WorkflowExecution -from graphon.enums import WorkflowExecutionStatus, WorkflowType from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository +from graphon.entities import WorkflowExecution +from graphon.enums import WorkflowExecutionStatus, WorkflowType from models import Account, CreatorUserRole, EndUser, WorkflowRun from models.enums import WorkflowRunTriggeredFrom diff --git a/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_node_execution_repository.py b/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_node_execution_repository.py index 5b4d26b7808..6af7b02d4cc 100644 --- a/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_node_execution_repository.py +++ b/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_node_execution_repository.py @@ -10,12 +10,6 @@ from unittest.mock import MagicMock, Mock import psycopg2.errors import pytest -from graphon.entities import WorkflowNodeExecution -from graphon.enums import ( - BuiltinNodeTypes, - WorkflowNodeExecutionMetadataKey, - WorkflowNodeExecutionStatus, -) from sqlalchemy import Engine, create_engine from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import sessionmaker @@ -29,6 +23,12 @@ from core.repositories.sqlalchemy_workflow_node_execution_repository import ( _find_first, _replace_or_append_offload, ) +from graphon.entities import WorkflowNodeExecution +from graphon.enums import ( + BuiltinNodeTypes, + WorkflowNodeExecutionMetadataKey, + WorkflowNodeExecutionStatus, +) from models import Account, EndUser from models.enums import ExecutionOffLoadType from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload, WorkflowNodeExecutionTriggeredFrom diff --git a/api/tests/unit_tests/core/repositories/test_workflow_node_execution_conflict_handling.py b/api/tests/unit_tests/core/repositories/test_workflow_node_execution_conflict_handling.py index 84fe522388e..abdbc72085f 100644 --- a/api/tests/unit_tests/core/repositories/test_workflow_node_execution_conflict_handling.py +++ b/api/tests/unit_tests/core/repositories/test_workflow_node_execution_conflict_handling.py @@ -4,17 +4,17 @@ from unittest.mock import MagicMock, Mock import psycopg2.errors import pytest -from graphon.entities.workflow_node_execution import ( - WorkflowNodeExecution, - WorkflowNodeExecutionStatus, -) -from graphon.enums import BuiltinNodeTypes from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import sessionmaker from core.repositories.sqlalchemy_workflow_node_execution_repository import ( SQLAlchemyWorkflowNodeExecutionRepository, ) +from graphon.entities.workflow_node_execution import ( + WorkflowNodeExecution, + WorkflowNodeExecutionStatus, +) +from graphon.enums import BuiltinNodeTypes from libs.datetime_utils import naive_utc_now from models import Account, WorkflowNodeExecutionTriggeredFrom diff --git a/api/tests/unit_tests/core/repositories/test_workflow_node_execution_truncation.py b/api/tests/unit_tests/core/repositories/test_workflow_node_execution_truncation.py index 27729e7f06b..5af1376a0a9 100644 --- a/api/tests/unit_tests/core/repositories/test_workflow_node_execution_truncation.py +++ b/api/tests/unit_tests/core/repositories/test_workflow_node_execution_truncation.py @@ -11,17 +11,17 @@ from datetime import UTC, datetime from typing import Any from unittest.mock import MagicMock -from graphon.entities.workflow_node_execution import ( - WorkflowNodeExecution, - WorkflowNodeExecutionStatus, -) -from graphon.enums import BuiltinNodeTypes from sqlalchemy import Engine from configs import dify_config from core.repositories.sqlalchemy_workflow_node_execution_repository import ( SQLAlchemyWorkflowNodeExecutionRepository, ) +from graphon.entities.workflow_node_execution import ( + WorkflowNodeExecution, + WorkflowNodeExecutionStatus, +) +from graphon.enums import BuiltinNodeTypes from models import Account, WorkflowNodeExecutionTriggeredFrom from models.enums import ExecutionOffLoadType from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload diff --git a/api/tests/unit_tests/core/test_file.py b/api/tests/unit_tests/core/test_file.py index ac65d0c02bc..eab0176f419 100644 --- a/api/tests/unit_tests/core/test_file.py +++ b/api/tests/unit_tests/core/test_file.py @@ -1,15 +1,14 @@ import json from graphon.file import File, FileTransferMethod, FileType, FileUploadConfig - from models.workflow import Workflow def test_file_to_dict(): file = File( - id="file1", + file_id="file1", tenant_id="tenant1", - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://example.com/image1.jpg", storage_key="storage_key", diff --git a/api/tests/unit_tests/core/test_model_manager.py b/api/tests/unit_tests/core/test_model_manager.py index f5efb78b614..afea9144c06 100644 --- a/api/tests/unit_tests/core/test_model_manager.py +++ b/api/tests/unit_tests/core/test_model_manager.py @@ -2,12 +2,12 @@ from unittest.mock import MagicMock, patch import pytest import redis -from graphon.model_runtime.entities.model_entities import ModelType from pytest_mock import MockerFixture from core.entities.provider_entities import ModelLoadBalancingConfiguration from core.model_manager import LBModelManager from extensions.ext_redis import redis_client +from graphon.model_runtime.entities.model_entities import ModelType @pytest.fixture diff --git a/api/tests/unit_tests/core/test_provider_configuration.py b/api/tests/unit_tests/core/test_provider_configuration.py index 331166fe63c..b19a21d7f44 100644 --- a/api/tests/unit_tests/core/test_provider_configuration.py +++ b/api/tests/unit_tests/core/test_provider_configuration.py @@ -1,15 +1,6 @@ from unittest.mock import Mock, patch import pytest -from graphon.model_runtime.entities.common_entities import I18nObject -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.entities.provider_entities import ( - ConfigurateMethod, - CredentialFormSchema, - FormOption, - FormType, - ProviderEntity, -) from core.entities.provider_configuration import ProviderConfiguration, SystemConfigurationStatus from core.entities.provider_entities import ( @@ -21,6 +12,15 @@ from core.entities.provider_entities import ( RestrictModel, SystemConfiguration, ) +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.entities.provider_entities import ( + ConfigurateMethod, + CredentialFormSchema, + FormOption, + FormType, + ProviderEntity, +) from models.provider import Provider, ProviderType diff --git a/api/tests/unit_tests/core/test_provider_manager.py b/api/tests/unit_tests/core/test_provider_manager.py index ee261724597..a5a542c94f1 100644 --- a/api/tests/unit_tests/core/test_provider_manager.py +++ b/api/tests/unit_tests/core/test_provider_manager.py @@ -2,12 +2,12 @@ from types import SimpleNamespace from unittest.mock import MagicMock, Mock, PropertyMock, patch import pytest -from graphon.model_runtime.entities.common_entities import I18nObject -from graphon.model_runtime.entities.model_entities import ModelType from pytest_mock import MockerFixture from core.entities.provider_entities import ModelSettings from core.provider_manager import ProviderManager +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.model_entities import ModelType from models.provider import LoadBalancingModelConfig, ProviderModelSetting, TenantDefaultModel from models.provider_ids import ModelProviderID @@ -372,6 +372,78 @@ def test_get_configurations_binds_manager_runtime_to_provider_configuration( provider_configuration.bind_model_runtime.assert_called_once_with(manager._model_runtime) +def test_get_configurations_reuses_cached_result_for_same_tenant(mocker: MockerFixture, mock_provider_entity): + manager = _build_provider_manager(mocker) + provider_configuration = Mock() + provider_factory = Mock() + provider_factory.get_providers.return_value = [mock_provider_entity] + custom_configuration = SimpleNamespace(provider=None, models=[]) + system_configuration = SimpleNamespace(enabled=False, quota_configurations=[], current_quota_type=None) + + with ( + patch.object(manager, "_get_all_providers", return_value={"openai": []}) as mock_get_all_providers, + patch.object(manager, "_init_trial_provider_records", return_value={"openai": []}), + patch.object(manager, "_get_all_provider_models", return_value={"openai": []}), + patch.object(manager, "_get_all_preferred_model_providers", return_value={}), + patch.object(manager, "_get_all_provider_model_settings", return_value={}), + patch.object(manager, "_get_all_provider_load_balancing_configs", return_value={}), + patch.object(manager, "_get_all_provider_model_credentials", return_value={}), + patch.object(manager, "_to_custom_configuration", return_value=custom_configuration), + patch.object(manager, "_to_system_configuration", return_value=system_configuration), + patch.object(manager, "_to_model_settings", return_value=[]), + patch("core.provider_manager.ModelProviderFactory", return_value=provider_factory) as mock_factory_cls, + patch( + "core.provider_manager.ProviderConfiguration", + return_value=provider_configuration, + ) as mock_provider_configuration, + ): + first = manager.get_configurations("tenant-id") + second = manager.get_configurations("tenant-id") + + assert first is second + mock_get_all_providers.assert_called_once_with("tenant-id") + mock_factory_cls.assert_called_once_with(model_runtime=manager._model_runtime) + mock_provider_configuration.assert_called_once() + provider_configuration.bind_model_runtime.assert_called_once_with(manager._model_runtime) + + +def test_clear_configurations_cache_rebuilds_requested_tenant(mocker: MockerFixture, mock_provider_entity): + manager = _build_provider_manager(mocker) + provider_factory = Mock() + provider_factory.get_providers.return_value = [mock_provider_entity] + custom_configuration = SimpleNamespace(provider=None, models=[]) + system_configuration = SimpleNamespace(enabled=False, quota_configurations=[], current_quota_type=None) + provider_configuration_first = Mock() + provider_configuration_second = Mock() + + with ( + patch.object(manager, "_get_all_providers", return_value={"openai": []}) as mock_get_all_providers, + patch.object(manager, "_init_trial_provider_records", return_value={"openai": []}), + patch.object(manager, "_get_all_provider_models", return_value={"openai": []}), + patch.object(manager, "_get_all_preferred_model_providers", return_value={}), + patch.object(manager, "_get_all_provider_model_settings", return_value={}), + patch.object(manager, "_get_all_provider_load_balancing_configs", return_value={}), + patch.object(manager, "_get_all_provider_model_credentials", return_value={}), + patch.object(manager, "_to_custom_configuration", return_value=custom_configuration), + patch.object(manager, "_to_system_configuration", return_value=system_configuration), + patch.object(manager, "_to_model_settings", return_value=[]), + patch("core.provider_manager.ModelProviderFactory", return_value=provider_factory), + patch( + "core.provider_manager.ProviderConfiguration", + side_effect=[provider_configuration_first, provider_configuration_second], + ) as mock_provider_configuration, + ): + first = manager.get_configurations("tenant-id") + manager.clear_configurations_cache("tenant-id") + second = manager.get_configurations("tenant-id") + + assert first is not second + assert mock_get_all_providers.call_count == 2 + assert mock_provider_configuration.call_count == 2 + provider_configuration_first.bind_model_runtime.assert_called_once_with(manager._model_runtime) + provider_configuration_second.bind_model_runtime.assert_called_once_with(manager._model_runtime) + + def test_get_provider_model_bundle_returns_selected_model_type_instance(mocker: MockerFixture): manager = _build_provider_manager(mocker) provider_configuration = Mock() diff --git a/api/tests/unit_tests/core/tools/test_builtin_tool_base.py b/api/tests/unit_tests/core/tools/test_builtin_tool_base.py index 5d744f88c9b..1ff81f61200 100644 --- a/api/tests/unit_tests/core/tools/test_builtin_tool_base.py +++ b/api/tests/unit_tests/core/tools/test_builtin_tool_base.py @@ -6,13 +6,13 @@ from typing import Any from unittest.mock import patch import pytest -from graphon.model_runtime.entities.message_entities import UserPromptMessage from core.app.entities.app_invoke_entities import InvokeFrom from core.tools.__base.tool_runtime import ToolRuntime from core.tools.builtin_tool.tool import BuiltinTool from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolEntity, ToolIdentity, ToolInvokeMessage, ToolProviderType +from graphon.model_runtime.entities.message_entities import UserPromptMessage class _BuiltinDummyTool(BuiltinTool): diff --git a/api/tests/unit_tests/core/tools/test_builtin_tools_extra.py b/api/tests/unit_tests/core/tools/test_builtin_tools_extra.py index ee0ce51eec9..c7829fc0d71 100644 --- a/api/tests/unit_tests/core/tools/test_builtin_tools_extra.py +++ b/api/tests/unit_tests/core/tools/test_builtin_tools_extra.py @@ -6,8 +6,6 @@ from datetime import date from types import SimpleNamespace import pytest -from graphon.file import FileType -from graphon.model_runtime.entities.model_entities import ModelPropertyKey from core.app.entities.app_invoke_entities import InvokeFrom from core.tools.__base.tool_runtime import ToolRuntime @@ -29,6 +27,8 @@ from core.tools.builtin_tool.tool import BuiltinTool from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolEntity, ToolIdentity, ToolInvokeMessage from core.tools.errors import ToolInvokeError +from graphon.file import FileType +from graphon.model_runtime.entities.model_entities import ModelPropertyKey def _build_builtin_tool(tool_cls: type[BuiltinTool]) -> BuiltinTool: diff --git a/api/tests/unit_tests/core/tools/test_custom_tool.py b/api/tests/unit_tests/core/tools/test_custom_tool.py index 79b8eaaa870..f35546b0251 100644 --- a/api/tests/unit_tests/core/tools/test_custom_tool.py +++ b/api/tests/unit_tests/core/tools/test_custom_tool.py @@ -1,6 +1,7 @@ from __future__ import annotations from types import SimpleNamespace +from typing import Any import httpx import pytest @@ -14,7 +15,7 @@ from core.tools.entities.tool_entities import ToolEntity, ToolIdentity, ToolInvo from core.tools.errors import ToolInvokeError, ToolParameterValidationError, ToolProviderCredentialValidationError -def _build_tool(*, openapi: dict | None = None) -> ApiTool: +def _build_tool(*, openapi: dict[str, Any] | None = None) -> ApiTool: entity = ToolEntity( identity=ToolIdentity( author="author", diff --git a/api/tests/unit_tests/core/tools/test_tool_engine.py b/api/tests/unit_tests/core/tools/test_tool_engine.py index 40c107667cb..cd16557ef64 100644 --- a/api/tests/unit_tests/core/tools/test_tool_engine.py +++ b/api/tests/unit_tests/core/tools/test_tool_engine.py @@ -260,6 +260,28 @@ def test_agent_invoke_engine_meta_error(): assert error_meta.error == "meta failure" +def test_convert_tool_response_excludes_variable_messages(): + """Regression test for issue #34723. + + WorkflowTool._invoke yields VARIABLE, TEXT, and suppressed-JSON messages. + _convert_tool_response_to_str must skip VARIABLE messages so that the + returned string contains only the TEXT representation and not a + duplicated, garbled Pydantic repr of the same data. + """ + tool = _build_tool() + outputs = {"reports": "hello"} + messages = [ + tool.create_variable_message(variable_name="reports", variable_value="hello"), + tool.create_text_message('{"reports": "hello"}'), + tool.create_json_message(outputs, suppress_output=True), + ] + + result = ToolEngine._convert_tool_response_to_str(messages) + + assert result == '{"reports": "hello"}' + assert "variable_name" not in result + + def test_agent_invoke_tool_invoke_error(): tool = _build_tool(with_llm_parameter=True) callback = Mock() diff --git a/api/tests/unit_tests/core/tools/test_tool_file_manager.py b/api/tests/unit_tests/core/tools/test_tool_file_manager.py index 7fcebde3c55..ccffdf16d10 100644 --- a/api/tests/unit_tests/core/tools/test_tool_file_manager.py +++ b/api/tests/unit_tests/core/tools/test_tool_file_manager.py @@ -12,9 +12,9 @@ from unittest.mock import MagicMock, Mock, patch import httpx import pytest -from graphon.file import FileTransferMethod from core.tools.tool_file_manager import ToolFileManager +from graphon.file import FileTransferMethod def _setup_tool_file_signing(monkeypatch: pytest.MonkeyPatch) -> dict[str, str]: @@ -129,7 +129,7 @@ def test_get_file_binary_returns_none_when_not_found() -> None: # Arrange manager = ToolFileManager() session = Mock() - session.query.return_value.where.return_value.first.return_value = None + session.scalar.return_value = None # Act with _patch_session_factory(session): @@ -144,7 +144,7 @@ def test_get_file_binary_returns_bytes_when_found() -> None: manager = ToolFileManager() tool_file = SimpleNamespace(file_key="k1", mimetype="text/plain") session = Mock() - session.query.return_value.where.return_value.first.return_value = tool_file + session.scalar.return_value = tool_file # Act with patch("core.tools.tool_file_manager.storage") as storage: @@ -160,11 +160,7 @@ def test_get_file_binary_by_message_file_id_when_messagefile_missing() -> None: # Arrange manager = ToolFileManager() session = Mock() - first_query = Mock() - second_query = Mock() - first_query.where.return_value.first.return_value = None - second_query.where.return_value.first.return_value = None - session.query.side_effect = [first_query, second_query] + session.scalar.side_effect = [None, None] # Act with _patch_session_factory(session): @@ -179,11 +175,7 @@ def test_get_file_binary_by_message_file_id_when_url_is_none() -> None: manager = ToolFileManager() message_file = SimpleNamespace(url=None) session = Mock() - first_query = Mock() - second_query = Mock() - first_query.where.return_value.first.return_value = message_file - second_query.where.return_value.first.return_value = None - session.query.side_effect = [first_query, second_query] + session.scalar.side_effect = [message_file, None] # Act with _patch_session_factory(session): @@ -199,11 +191,7 @@ def test_get_file_binary_by_message_file_id_returns_bytes_when_found() -> None: message_file = SimpleNamespace(url="https://x/files/tools/tool123.png") tool_file = SimpleNamespace(file_key="k2", mimetype="image/png") session = Mock() - first_query = Mock() - second_query = Mock() - first_query.where.return_value.first.return_value = message_file - second_query.where.return_value.first.return_value = tool_file - session.query.side_effect = [first_query, second_query] + session.scalar.side_effect = [message_file, tool_file] # Act with patch("core.tools.tool_file_manager.storage") as storage: @@ -219,7 +207,7 @@ def test_get_file_generator_returns_none_when_toolfile_missing() -> None: # Arrange manager = ToolFileManager() session = Mock() - session.query.return_value.where.return_value.first.return_value = None + session.scalar.return_value = None # Act with _patch_session_factory(session): @@ -242,7 +230,7 @@ def test_get_file_generator_returns_stream_when_found() -> None: size=12, ) session = Mock() - session.query.return_value.where.return_value.first.return_value = tool_file + session.scalar.return_value = tool_file # Act with patch("core.tools.tool_file_manager.storage") as storage: diff --git a/api/tests/unit_tests/core/tools/test_tool_label_manager.py b/api/tests/unit_tests/core/tools/test_tool_label_manager.py index 8c0e7e9419e..e13f430f9b1 100644 --- a/api/tests/unit_tests/core/tools/test_tool_label_manager.py +++ b/api/tests/unit_tests/core/tools/test_tool_label_manager.py @@ -2,7 +2,7 @@ from __future__ import annotations from types import SimpleNamespace from typing import Any -from unittest.mock import PropertyMock, patch +from unittest.mock import MagicMock, PropertyMock, patch import pytest @@ -12,11 +12,13 @@ from core.tools.tool_label_manager import ToolLabelManager from core.tools.workflow_as_tool.provider import WorkflowToolProviderController +# Create a mock class for testing abstract/base classes class _ConcreteBuiltinToolProviderController(BuiltinToolProviderController): def _validate_credentials(self, user_id: str, credentials: dict[str, Any]): return None +# Factory function to create a "lightweight" controller for testing def _api_controller(provider_id: str = "api-1") -> ApiToolProviderController: controller = object.__new__(ApiToolProviderController) controller.provider_id = provider_id @@ -29,6 +31,7 @@ def _workflow_controller(provider_id: str = "wf-1") -> WorkflowToolProviderContr return controller +# Test pure logic: filtering and deduplication def test_tool_label_manager_filter_tool_labels(): filtered = ToolLabelManager.filter_tool_labels(["search", "search", "invalid", "news"]) assert set(filtered) == {"search", "news"} @@ -36,22 +39,68 @@ def test_tool_label_manager_filter_tool_labels(): def test_tool_label_manager_update_tool_labels_db(): + """ + Test the database update logic for tool labels. + Focus: Verify that labels are filtered, de-duplicated, and safely handled within a database session. + """ + # 1. Setup expected data from the controller controller = _api_controller("api-1") - with patch("core.tools.tool_label_manager.db") as mock_db: + expected_id = controller.provider_id + expected_type = controller.provider_type + + # 2. Patching External Dependencies + # - We patch 'db' to prevent Flask from trying to access a real database. + # - We patch 'sessionmaker' to intercept and control the creation of SQLAlchemy sessions. + with ( + patch("core.tools.tool_label_manager.db"), + patch("core.tools.tool_label_manager.sessionmaker") as mock_sessionmaker, + ): + # 3. Constructing the "Mocking Chain" + # In the business logic, we use: with sessionmaker(db.engine).begin() as _session: + # We need to link our 'mock_session' to the end of this complex context manager chain: + # Step A: sessionmaker(db.engine) -> returns an object (mock_sessionmaker.return_value) + # Step B: .begin() -> returns a context manager (begin.return_value) + # Step C: with ... as _session: -> calls __enter__(), and _session gets the __enter__.return_value + mock_session = MagicMock() + mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session + + # 4. Trigger the logic under test + # Input: ["search", "search", "invalid"] + # Logic: + # - "invalid" should be filtered out (not in default_tool_label_name_list). + # - The duplicate "search" should be merged (unique labels). ToolLabelManager.update_tool_labels(controller, ["search", "search", "invalid"]) - mock_db.session.execute.assert_called_once() - # only one valid unique label should be inserted. - assert mock_db.session.add.call_count == 1 - mock_db.session.commit.assert_called_once() + # 5. Behavior Assertion: DELETE operation + # Verify that the manager first attempts to clear existing labels for this specific tool. + # This ensures the update is idempotent. + mock_session.execute.assert_called_once() + + # 6. Behavior Assertion: INSERT operation + # Verify that only ONE valid label ("search") was added after filtering and deduplication. + # If call_count == 1, it proves filter_tool_labels() worked as expected. + assert mock_session.add.call_count == 1 + + # 7. State Assertion: Data Integrity & Isolation + # Inspect the actual object passed to session.add() to ensure it has correct properties. + # This confirms that the data isolation (tool_id + tool_type) we refactored is active. + call_args = mock_session.add.call_args + added_label = call_args[0][0] # Retrieve the ToolLabelBinding instance + + assert added_label.label_name == "search", "The label name should be 'search' after filtering." + assert added_label.tool_id == expected_id, "The tool_id must match the provider_id for correct binding." + assert added_label.tool_type == expected_type, "Isolation failed: tool_type must be verified during update." +# Test error handling def test_tool_label_manager_update_tool_labels_unsupported(): with pytest.raises(ValueError, match="Unsupported tool type"): ToolLabelManager.update_tool_labels(object(), ["search"]) # type: ignore[arg-type] +# Test retrieval logic def test_tool_label_manager_get_tool_labels_for_builtin_and_db(): + # Mocking a property (@property) using PropertyMock with patch.object( _ConcreteBuiltinToolProviderController, "tool_labels", @@ -62,29 +111,67 @@ def test_tool_label_manager_get_tool_labels_for_builtin_and_db(): assert ToolLabelManager.get_tool_labels(builtin) == ["search", "news"] api = _api_controller("api-1") - with patch("core.tools.tool_label_manager.db") as mock_db: - mock_db.session.scalars.return_value.all.return_value = ["search", "news"] - labels = ToolLabelManager.get_tool_labels(api) - assert labels == ["search", "news"] + with ( + patch("core.tools.tool_label_manager.db"), + patch("core.tools.tool_label_manager.sessionmaker") as mock_sessionmaker, + ): + mock_session = MagicMock() + mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session + # Inject mock data into the query result: session.scalars(stmt).all() + mock_session.scalars.return_value.all.return_value = ["search", "news"] + + labels = ToolLabelManager.get_tool_labels(api) + assert labels == ["search", "news"] + + +def test_tool_label_manager_get_tool_labels_unsupported(): + """ + Negative Test: Ensure get_tool_labels raises ValueError for unsupported controller types. + This protects the internal API contract against accidental regressions during refactoring. + """ + # Passing a generic object() which doesn't match Api, Workflow, or Builtin controllers. with pytest.raises(ValueError, match="Unsupported tool type"): ToolLabelManager.get_tool_labels(object()) # type: ignore[arg-type] +# Test batch processing and mapping def test_tool_label_manager_get_tools_labels_batch(): assert ToolLabelManager.get_tools_labels([]) == {} api = _api_controller("api-1") wf = _workflow_controller("wf-1") + + # SimpleNamespace is a quick way to simulate SQLAlchemy row objects records = [ SimpleNamespace(tool_id="api-1", label_name="search"), SimpleNamespace(tool_id="api-1", label_name="news"), SimpleNamespace(tool_id="wf-1", label_name="utilities"), ] - with patch("core.tools.tool_label_manager.db") as mock_db: - mock_db.session.scalars.return_value.all.return_value = records + + with ( + patch("core.tools.tool_label_manager.db"), + patch("core.tools.tool_label_manager.sessionmaker") as mock_sessionmaker, + ): + mock_session = MagicMock() + mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session + + # Simulating the batch query result + mock_session.scalars.return_value.all.return_value = records + labels = ToolLabelManager.get_tools_labels([api, wf]) + + # Verify the final dictionary mapping assert labels == {"api-1": ["search", "news"], "wf-1": ["utilities"]} + +def test_tool_label_manager_get_tools_labels_unsupported(): + """ + Negative Test: Ensure get_tools_labels raises ValueError if the list contains + unsupported controller types, even alongside valid ones. + """ + api = _api_controller("api-1") + + # Passing a list with one valid controller and one invalid object() with pytest.raises(ValueError, match="Unsupported tool type"): ToolLabelManager.get_tools_labels([api, object()]) # type: ignore[list-item] diff --git a/api/tests/unit_tests/core/tools/test_tool_manager.py b/api/tests/unit_tests/core/tools/test_tool_manager.py index 31b68f0b3f3..9ebaa0417bf 100644 --- a/api/tests/unit_tests/core/tools/test_tool_manager.py +++ b/api/tests/unit_tests/core/tools/test_tool_manager.py @@ -637,7 +637,7 @@ def test_list_default_builtin_providers_for_postgres_and_mysql(): for scheme in ("postgresql", "mysql"): session = Mock() session.execute.return_value.all.return_value = [SimpleNamespace(id="id-1"), SimpleNamespace(id="id-2")] - session.query.return_value.where.return_value.all.return_value = provider_records + session.scalars.return_value = iter(provider_records) with patch("core.tools.tool_manager.dify_config", SimpleNamespace(SQLALCHEMY_DATABASE_URI_SCHEME=scheme)): with patch("core.tools.tool_manager.db") as mock_db: diff --git a/api/tests/unit_tests/core/tools/utils/test_message_transformer.py b/api/tests/unit_tests/core/tools/utils/test_message_transformer.py index 6454a5bcd1f..5f34135af4b 100644 --- a/api/tests/unit_tests/core/tools/utils/test_message_transformer.py +++ b/api/tests/unit_tests/core/tools/utils/test_message_transformer.py @@ -1,3 +1,5 @@ +from typing import Any + import pytest import core.tools.utils.message_transformer as mt @@ -13,7 +15,7 @@ class _FakeToolFile: class _FakeToolFileManager: """Fake ToolFileManager to capture the mimetype passed in.""" - last_call: dict | None = None + last_call: dict[str, Any] | None = None def __init__(self, *args, **kwargs): pass diff --git a/api/tests/unit_tests/core/tools/utils/test_model_invocation_utils.py b/api/tests/unit_tests/core/tools/utils/test_model_invocation_utils.py index 52f262e1cf1..44785f939ca 100644 --- a/api/tests/unit_tests/core/tools/utils/test_model_invocation_utils.py +++ b/api/tests/unit_tests/core/tools/utils/test_model_invocation_utils.py @@ -10,9 +10,12 @@ from __future__ import annotations from decimal import Decimal from types import SimpleNamespace +from typing import Any from unittest.mock import Mock, patch import pytest + +from core.tools.utils.model_invocation_utils import InvokeModelError, ModelInvocationUtils from graphon.model_runtime.entities.model_entities import ModelPropertyKey from graphon.model_runtime.errors.invoke import ( InvokeAuthorizationError, @@ -22,10 +25,8 @@ from graphon.model_runtime.errors.invoke import ( InvokeServerUnavailableError, ) -from core.tools.utils.model_invocation_utils import InvokeModelError, ModelInvocationUtils - -def _mock_model_instance(*, schema: dict | None = None) -> SimpleNamespace: +def _mock_model_instance(*, schema: dict[str, Any] | None = None) -> SimpleNamespace: model_type_instance = Mock() model_type_instance.get_model_schema.return_value = ( SimpleNamespace(model_properties=schema or {}) if schema is not None else None diff --git a/api/tests/unit_tests/core/tools/utils/test_parser.py b/api/tests/unit_tests/core/tools/utils/test_parser.py index 40f91b12a00..032b1377a42 100644 --- a/api/tests/unit_tests/core/tools/utils/test_parser.py +++ b/api/tests/unit_tests/core/tools/utils/test_parser.py @@ -1,4 +1,5 @@ from json.decoder import JSONDecodeError +from typing import Any from unittest.mock import Mock, patch import pytest @@ -259,8 +260,8 @@ def test_parse_openapi_to_tool_bundle_server_env_and_refs(app): }, } - extra_info: dict = {} - warning: dict = {} + extra_info: dict[str, Any] = {} + warning: dict[str, Any] = {} with app.test_request_context(headers={"X-Request-Env": "prod"}): bundles = ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi, extra_info=extra_info, warning=warning) @@ -298,7 +299,7 @@ def test_parse_swagger_to_openapi_branches(): } ) - warning: dict = {"seed": True} + warning: dict[str, Any] = {"seed": True} converted = ApiBasedToolSchemaParser.parse_swagger_to_openapi( { "servers": [{"url": "https://x"}], diff --git a/api/tests/unit_tests/core/tools/utils/test_workflow_configuration_sync.py b/api/tests/unit_tests/core/tools/utils/test_workflow_configuration_sync.py index 0e3a7e623a8..43f3fbd5c99 100644 --- a/api/tests/unit_tests/core/tools/utils/test_workflow_configuration_sync.py +++ b/api/tests/unit_tests/core/tools/utils/test_workflow_configuration_sync.py @@ -1,9 +1,9 @@ import pytest -from graphon.variables.input_entities import VariableEntity, VariableEntityType from core.tools.entities.tool_entities import ToolParameter, WorkflowToolParameterConfiguration from core.tools.errors import WorkflowToolHumanInputNotSupportedError from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils +from graphon.variables.input_entities import VariableEntity, VariableEntityType def test_ensure_no_human_input_nodes_passes_for_non_human_input(): diff --git a/api/tests/unit_tests/core/tools/workflow_as_tool/test_provider.py b/api/tests/unit_tests/core/tools/workflow_as_tool/test_provider.py index 2607861b59d..5a585c609af 100644 --- a/api/tests/unit_tests/core/tools/workflow_as_tool/test_provider.py +++ b/api/tests/unit_tests/core/tools/workflow_as_tool/test_provider.py @@ -4,7 +4,6 @@ from types import SimpleNamespace from unittest.mock import MagicMock, Mock, patch import pytest -from graphon.variables.input_entities import VariableEntity, VariableEntityType from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ( @@ -14,6 +13,7 @@ from core.tools.entities.tool_entities import ( ToolProviderType, ) from core.tools.workflow_as_tool.provider import WorkflowToolProviderController +from graphon.variables.input_entities import VariableEntity, VariableEntityType def _controller() -> WorkflowToolProviderController: @@ -43,7 +43,7 @@ def test_get_db_provider_tool_builds_entity(): controller = _controller() session = Mock() workflow = SimpleNamespace(graph_dict={"nodes": []}, features_dict={}) - session.query.return_value.where.return_value.first.return_value = workflow + session.scalar.return_value = workflow app = SimpleNamespace(id="app-1") db_provider = SimpleNamespace( id="provider-1", @@ -136,7 +136,7 @@ def test_from_db_builds_controller(): parameter_configurations=[], ) session = _mock_session_with_begin() - session.query.return_value.where.return_value.first.return_value = db_provider + session.scalar.return_value = db_provider session.get.side_effect = [app, user] fake_cm = MagicMock() fake_cm.__enter__.return_value = session @@ -163,7 +163,7 @@ def test_get_tools_returns_empty_when_provider_missing(): mock_db.engine = object() with patch("core.tools.workflow_as_tool.provider.Session") as session_cls: session = _mock_session_with_begin() - session.query.return_value.where.return_value.first.return_value = None + session.scalar.return_value = None session_cls.return_value.__enter__.return_value = session assert controller.get_tools("tenant-1") == [] @@ -189,7 +189,7 @@ def test_get_tools_raises_when_app_missing(): mock_db.engine = object() with patch("core.tools.workflow_as_tool.provider.Session") as session_cls: session = _mock_session_with_begin() - session.query.return_value.where.return_value.first.return_value = db_provider + session.scalar.return_value = db_provider session.get.return_value = None session_cls.return_value.__enter__.return_value = session with pytest.raises(ValueError, match="app not found"): diff --git a/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py b/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py index c20edd74004..72a73dd9368 100644 --- a/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py +++ b/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py @@ -11,7 +11,6 @@ from typing import Any from unittest.mock import MagicMock, Mock, patch import pytest -from graphon.file import FILE_MODEL_IDENTITY, FileTransferMethod, FileType from core.app.entities.app_invoke_entities import InvokeFrom from core.tools.__base.tool_runtime import ToolRuntime @@ -25,6 +24,7 @@ from core.tools.entities.tool_entities import ( ) from core.tools.errors import ToolInvokeError from core.tools.workflow_as_tool.tool import WorkflowTool +from graphon.file import FILE_MODEL_IDENTITY, FileTransferMethod, FileType class StubScalars: diff --git a/api/tests/unit_tests/core/trigger/debug/test_debug_event_selectors.py b/api/tests/unit_tests/core/trigger/debug/test_debug_event_selectors.py index 78622b78b6b..fb7dc528386 100644 --- a/api/tests/unit_tests/core/trigger/debug/test_debug_event_selectors.py +++ b/api/tests/unit_tests/core/trigger/debug/test_debug_event_selectors.py @@ -8,10 +8,10 @@ and select_trigger_debug_events orchestrator. from __future__ import annotations from datetime import datetime +from typing import Any from unittest.mock import MagicMock, patch import pytest -from graphon.enums import BuiltinNodeTypes, NodeType from core.plugin.entities.request import TriggerInvokeEventResponse from core.trigger.constants import ( @@ -27,10 +27,11 @@ from core.trigger.debug.event_selectors import ( select_trigger_debug_events, ) from core.trigger.debug.events import PluginTriggerDebugEvent, WebhookDebugEvent +from graphon.enums import BuiltinNodeTypes, NodeType from tests.unit_tests.core.trigger.conftest import VALID_PROVIDER_ID -def _make_poller_args(node_config: dict | None = None) -> dict: +def _make_poller_args(node_config: dict[str, Any] | None = None) -> dict[str, Any]: return { "tenant_id": "t1", "user_id": "u1", diff --git a/api/tests/unit_tests/core/variables/test_segment.py b/api/tests/unit_tests/core/variables/test_segment.py index 7406b88270b..9e07ea1b6db 100644 --- a/api/tests/unit_tests/core/variables/test_segment.py +++ b/api/tests/unit_tests/core/variables/test_segment.py @@ -1,23 +1,30 @@ import dataclasses +from typing import Annotated import orjson import pytest +from pydantic import BaseModel, Discriminator, Tag + +from core.helper import encrypter +from core.workflow.system_variables import build_bootstrap_variables, build_system_variables +from core.workflow.variable_pool_initializer import add_variables_to_pool from graphon.file import File, FileTransferMethod, FileType from graphon.runtime import VariablePool from graphon.variables.segment_group import SegmentGroup from graphon.variables.segments import ( ArrayAnySegment, + ArrayBooleanSegment, ArrayFileSegment, ArrayNumberSegment, ArrayObjectSegment, ArrayStringSegment, + BooleanSegment, FileSegment, FloatSegment, IntegerSegment, NoneSegment, ObjectSegment, Segment, - SegmentUnion, StringSegment, get_segment_discriminator, ) @@ -42,11 +49,26 @@ from graphon.variables.variables import ( StringVariable, Variable, ) -from pydantic import BaseModel +from models.utils.file_input_compat import rebuild_serialized_graph_files_without_lookup -from core.helper import encrypter -from core.workflow.system_variables import build_bootstrap_variables, build_system_variables -from core.workflow.variable_pool_initializer import add_variables_to_pool +type SegmentUnion = Annotated[ + ( + Annotated[NoneSegment, Tag(SegmentType.NONE)] + | Annotated[StringSegment, Tag(SegmentType.STRING)] + | Annotated[FloatSegment, Tag(SegmentType.FLOAT)] + | Annotated[IntegerSegment, Tag(SegmentType.INTEGER)] + | Annotated[ObjectSegment, Tag(SegmentType.OBJECT)] + | Annotated[FileSegment, Tag(SegmentType.FILE)] + | Annotated[BooleanSegment, Tag(SegmentType.BOOLEAN)] + | Annotated[ArrayAnySegment, Tag(SegmentType.ARRAY_ANY)] + | Annotated[ArrayStringSegment, Tag(SegmentType.ARRAY_STRING)] + | Annotated[ArrayNumberSegment, Tag(SegmentType.ARRAY_NUMBER)] + | Annotated[ArrayObjectSegment, Tag(SegmentType.ARRAY_OBJECT)] + | Annotated[ArrayFileSegment, Tag(SegmentType.ARRAY_FILE)] + | Annotated[ArrayBooleanSegment, Tag(SegmentType.ARRAY_BOOLEAN)] + ), + Discriminator(get_segment_discriminator), +] def _build_variable_pool( @@ -123,7 +145,7 @@ def create_test_file( ) -> File: """Factory function to create File objects for testing""" return File( - type=file_type, + file_type=file_type, transfer_method=transfer_method, filename=filename, extension=extension, @@ -160,7 +182,7 @@ class TestSegmentDumpAndLoad: assert restored == model def test_all_segments_serialization(self): - """Test serialization/deserialization of all segment types""" + """Test file-aware segment serialization through Dify's model boundary.""" # Create one instance of each segment type test_file = create_test_file() @@ -181,7 +203,7 @@ class TestSegmentDumpAndLoad: # Test serialization and deserialization model = _Segments(segments=all_segments) json_str = model.model_dump_json() - loaded = _Segments.model_validate_json(json_str) + loaded = _Segments.model_validate(rebuild_serialized_graph_files_without_lookup(orjson.loads(json_str))) # Verify all segments are preserved assert len(loaded.segments) == len(all_segments) @@ -202,7 +224,7 @@ class TestSegmentDumpAndLoad: assert loaded_segment.value == original.value def test_all_variables_serialization(self): - """Test serialization/deserialization of all variable types""" + """Test file-aware variable serialization through Dify's model boundary.""" # Create one instance of each variable type test_file = create_test_file() @@ -223,7 +245,7 @@ class TestSegmentDumpAndLoad: # Test serialization and deserialization model = _Variables(variables=all_variables) json_str = model.model_dump_json() - loaded = _Variables.model_validate_json(json_str) + loaded = _Variables.model_validate(rebuild_serialized_graph_files_without_lookup(orjson.loads(json_str))) # Verify all variables are preserved assert len(loaded.variables) == len(all_variables) diff --git a/api/tests/unit_tests/core/variables/test_segment_type.py b/api/tests/unit_tests/core/variables/test_segment_type.py index 37ecd2890bb..d4e862220ae 100644 --- a/api/tests/unit_tests/core/variables/test_segment_type.py +++ b/api/tests/unit_tests/core/variables/test_segment_type.py @@ -1,4 +1,5 @@ import pytest + from graphon.variables.segment_group import SegmentGroup from graphon.variables.segments import StringSegment from graphon.variables.types import ArrayValidation, SegmentType diff --git a/api/tests/unit_tests/core/variables/test_segment_type_validation.py b/api/tests/unit_tests/core/variables/test_segment_type_validation.py index 09254e17a30..317fe99d37e 100644 --- a/api/tests/unit_tests/core/variables/test_segment_type_validation.py +++ b/api/tests/unit_tests/core/variables/test_segment_type_validation.py @@ -9,6 +9,7 @@ from dataclasses import dataclass from typing import Any import pytest + from graphon.file import File, FileTransferMethod, FileType from graphon.variables.segment_group import SegmentGroup from graphon.variables.segments import ( @@ -34,7 +35,7 @@ def create_test_file( """Factory function to create File objects for testing.""" return File( tenant_id="test-tenant", - type=file_type, + file_type=file_type, transfer_method=transfer_method, filename=filename, extension=extension, diff --git a/api/tests/unit_tests/core/variables/test_variables.py b/api/tests/unit_tests/core/variables/test_variables.py index 75b01bf42e9..dae5e1ce984 100644 --- a/api/tests/unit_tests/core/variables/test_variables.py +++ b/api/tests/unit_tests/core/variables/test_variables.py @@ -1,4 +1,6 @@ import pytest +from pydantic import ValidationError + from graphon.variables import ( ArrayFileVariable, ArrayVariable, @@ -10,7 +12,6 @@ from graphon.variables import ( StringVariable, ) from graphon.variables.variables import VariableBase -from pydantic import ValidationError def test_frozen_variables(): diff --git a/api/tests/unit_tests/core/workflow/graph_engine/layers/conftest.py b/api/tests/unit_tests/core/workflow/graph_engine/layers/conftest.py index 41627f5e0be..025d79b25df 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/layers/conftest.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/layers/conftest.py @@ -5,12 +5,13 @@ Shared fixtures for ObservabilityLayer tests. from unittest.mock import MagicMock, patch import pytest -from graphon.enums import BuiltinNodeTypes from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace.export import SimpleSpanProcessor from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter from opentelemetry.trace import set_tracer_provider +from graphon.enums import BuiltinNodeTypes + @pytest.fixture def memory_span_exporter(): @@ -61,9 +62,8 @@ def mock_llm_node(): @pytest.fixture def mock_tool_node(): """Create a mock Tool Node with tool-specific attributes.""" - from graphon.nodes.tool.entities import ToolNodeData - from core.tools.entities.tool_entities import ToolProviderType + from graphon.nodes.tool.entities import ToolNodeData node = MagicMock() node.id = "test-tool-node-id" diff --git a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py index 99d131737e9..5d6667257f6 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py @@ -3,17 +3,16 @@ from datetime import datetime from types import SimpleNamespace from unittest.mock import MagicMock, patch +from core.app.entities.app_invoke_entities import DifyRunContext, InvokeFrom, UserFrom +from core.app.workflow.layers.llm_quota import LLMQuotaLayer +from core.errors.error import QuotaExceededError +from core.model_manager import ModelInstance from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus from graphon.graph_engine.entities.commands import CommandType from graphon.graph_events import NodeRunSucceededEvent from graphon.model_runtime.entities.llm_entities import LLMUsage from graphon.node_events import NodeRunResult -from core.app.entities.app_invoke_entities import DifyRunContext, InvokeFrom, UserFrom -from core.app.workflow.layers.llm_quota import LLMQuotaLayer -from core.errors.error import QuotaExceededError -from core.model_manager import ModelInstance - def _build_dify_context() -> DifyRunContext: return DifyRunContext( diff --git a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_observability.py b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_observability.py index 9cf72763ee2..919f15efd09 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_observability.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_observability.py @@ -13,10 +13,10 @@ Test coverage: from unittest.mock import patch import pytest -from graphon.enums import BuiltinNodeTypes from opentelemetry.trace import StatusCode from core.app.workflow.layers.observability import ObservabilityLayer +from graphon.enums import BuiltinNodeTypes class TestObservabilityLayerInitialization: diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py index 88989db8565..9f3e3b00b9e 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py @@ -1,18 +1,18 @@ -""" -Mock node factory for testing workflows with third-party service dependencies. +"""Mock node factory for third-party-service workflow tests. -This module provides a MockNodeFactory that automatically detects and mocks nodes -requiring external services (LLM, Agent, Tool, Knowledge Retrieval, HTTP Request). +The factory follows the same config adaptation path as production +`DifyNodeFactory.create_node()`, but swaps selected node classes for mock +implementations before instantiation. """ from typing import TYPE_CHECKING, Any +from core.workflow.human_input_adapter import adapt_node_config_for_graph +from core.workflow.node_factory import DifyNodeFactory from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter from graphon.enums import BuiltinNodeTypes, NodeType from graphon.nodes.base.node import Node -from core.workflow.node_factory import DifyNodeFactory - from .test_mock_nodes import ( MockAgentNode, MockCodeNode, @@ -83,20 +83,20 @@ class MockNodeFactory(DifyNodeFactory): :param node_config: Node configuration dictionary :return: Node instance (real or mocked) """ - typed_node_config = NodeConfigDictAdapter.validate_python(node_config) + typed_node_config = NodeConfigDictAdapter.validate_python(adapt_node_config_for_graph(node_config)) + node_id = typed_node_config["id"] node_data = typed_node_config["data"] node_type = node_data.type # Check if this node type should be mocked if node_type in self._mock_node_types: - node_id = typed_node_config["id"] - # Create mock node instance mock_class = self._mock_node_types[node_type] + resolved_node_data = self._validate_resolved_node_data(mock_class, node_data) if node_type == BuiltinNodeTypes.CODE: mock_instance = mock_class( - id=node_id, - config=typed_node_config, + node_id=node_id, + config=resolved_node_data, graph_init_params=self.graph_init_params, graph_runtime_state=self.graph_runtime_state, mock_config=self.mock_config, @@ -105,8 +105,8 @@ class MockNodeFactory(DifyNodeFactory): ) elif node_type == BuiltinNodeTypes.HTTP_REQUEST: mock_instance = mock_class( - id=node_id, - config=typed_node_config, + node_id=node_id, + config=resolved_node_data, graph_init_params=self.graph_init_params, graph_runtime_state=self.graph_runtime_state, mock_config=self.mock_config, @@ -121,8 +121,8 @@ class MockNodeFactory(DifyNodeFactory): BuiltinNodeTypes.PARAMETER_EXTRACTOR, }: mock_instance = mock_class( - id=node_id, - config=typed_node_config, + node_id=node_id, + config=resolved_node_data, graph_init_params=self.graph_init_params, graph_runtime_state=self.graph_runtime_state, mock_config=self.mock_config, @@ -131,8 +131,8 @@ class MockNodeFactory(DifyNodeFactory): ) else: mock_instance = mock_class( - id=node_id, - config=typed_node_config, + node_id=node_id, + config=resolved_node_data, graph_init_params=self.graph_init_params, graph_runtime_state=self.graph_runtime_state, mock_config=self.mock_config, @@ -141,7 +141,7 @@ class MockNodeFactory(DifyNodeFactory): return mock_instance # For non-mocked node types, use parent implementation - return super().create_node(typed_node_config) + return super().create_node(node_config) def should_mock_node(self, node_type: NodeType) -> bool: """ diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py index 8b7fbd1b303..f9819c47ec9 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py @@ -10,6 +10,10 @@ from collections.abc import Generator, Mapping from typing import TYPE_CHECKING, Any, Optional from unittest.mock import MagicMock +from core.model_manager import ModelInstance +from core.workflow.node_runtime import DifyToolNodeRuntime +from core.workflow.nodes.agent import AgentNode +from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from graphon.model_runtime.entities.llm_entities import LLMUsage from graphon.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent @@ -27,11 +31,6 @@ from graphon.nodes.template_transform import TemplateTransformNode from graphon.nodes.tool import ToolNode from graphon.template_rendering import Jinja2TemplateRenderer, TemplateRenderError -from core.model_manager import ModelInstance -from core.workflow.node_runtime import DifyToolNodeRuntime -from core.workflow.nodes.agent import AgentNode -from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode - if TYPE_CHECKING: from graphon.entities import GraphInitParams from graphon.runtime import GraphRuntimeState @@ -56,13 +55,14 @@ class MockNodeMixin: def __init__( self, - id: str, - config: Mapping[str, Any], + node_id: str, + config: Any, + *, graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", mock_config: Optional["MockConfig"] = None, **kwargs: Any, - ): + ) -> None: if isinstance(self, (LLMNode, QuestionClassifierNode, ParameterExtractorNode)): kwargs.setdefault("credentials_provider", MagicMock(spec=CredentialsProvider)) kwargs.setdefault("model_factory", MagicMock(spec=ModelFactory)) @@ -97,7 +97,7 @@ class MockNodeMixin: kwargs.setdefault("message_transformer", MagicMock()) super().__init__( - id=id, + node_id=node_id, config=config, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py index 8311a1e847a..75bc6d05f7e 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py @@ -4,6 +4,13 @@ from dataclasses import dataclass from datetime import datetime, timedelta from typing import Any, Protocol +from core.repositories.human_input_repository import ( + FormCreateParams, + HumanInputFormEntity, + HumanInputFormRepository, +) +from core.workflow.node_runtime import DifyHumanInputNodeRuntime +from core.workflow.system_variables import build_system_variables from graphon.entities import WorkflowStartReason from graphon.graph import Graph from graphon.graph_engine import GraphEngine, GraphEngineConfig @@ -23,14 +30,6 @@ from graphon.nodes.human_input.human_input_node import HumanInputNode from graphon.nodes.start.entities import StartNodeData from graphon.nodes.start.start_node import StartNode from graphon.runtime import GraphRuntimeState, VariablePool - -from core.repositories.human_input_repository import ( - FormCreateParams, - HumanInputFormEntity, - HumanInputFormRepository, -) -from core.workflow.node_runtime import DifyHumanInputNodeRuntime -from core.workflow.system_variables import build_system_variables from libs.datetime_utils import naive_utc_now from tests.workflow_test_utils import build_test_graph_init_params @@ -140,8 +139,8 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()} start_node = StartNode( - id=start_config["id"], - config=start_config, + node_id=start_config["id"], + config=StartNodeData(title="Start", variables=[]), graph_init_params=graph_init_params, graph_runtime_state=runtime_state, ) @@ -155,8 +154,8 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor human_a_config = {"id": "human_a", "data": human_data.model_dump()} human_a = HumanInputNode( - id=human_a_config["id"], - config=human_a_config, + node_id=human_a_config["id"], + config=human_data, graph_init_params=graph_init_params, graph_runtime_state=runtime_state, form_repository=repo, @@ -165,8 +164,8 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor human_b_config = {"id": "human_b", "data": human_data.model_dump()} human_b = HumanInputNode( - id=human_b_config["id"], - config=human_b_config, + node_id=human_b_config["id"], + config=human_data, graph_init_params=graph_init_params, graph_runtime_state=runtime_state, form_repository=repo, @@ -183,8 +182,8 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor ) end_config = {"id": "end", "data": end_data.model_dump()} end_node = EndNode( - id=end_config["id"], - config=end_config, + node_id=end_config["id"], + config=end_data, graph_init_params=graph_init_params, graph_runtime_state=runtime_state, ) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py b/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py index b11f9576777..7d23b63049c 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py @@ -19,6 +19,11 @@ from functools import lru_cache from pathlib import Path from typing import Any +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom +from core.tools.utils.yaml_utils import _load_yaml_file +from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id +from core.workflow.system_variables import build_bootstrap_variables, build_system_variables +from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool from graphon.entities import GraphInitParams from graphon.graph import Graph from graphon.graph_engine import GraphEngine, GraphEngineConfig @@ -39,12 +44,6 @@ from graphon.variables import ( StringVariable, ) -from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom -from core.tools.utils.yaml_utils import _load_yaml_file -from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id -from core.workflow.system_variables import build_bootstrap_variables, build_system_variables -from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool - from .test_mock_config import MockConfig from .test_mock_factory import MockNodeFactory diff --git a/api/tests/unit_tests/core/workflow/nodes/agent/test_message_transformer.py b/api/tests/unit_tests/core/workflow/nodes/agent/test_message_transformer.py index cbc920705ca..1f4509af9a1 100644 --- a/api/tests/unit_tests/core/workflow/nodes/agent/test_message_transformer.py +++ b/api/tests/unit_tests/core/workflow/nodes/agent/test_message_transformer.py @@ -1,9 +1,8 @@ from unittest.mock import patch -from graphon.enums import BuiltinNodeTypes - from core.tools.utils.message_transformer import ToolFileMessageTransformer from core.workflow.nodes.agent.message_transformer import AgentMessageTransformer +from graphon.enums import BuiltinNodeTypes def test_transform_passes_conversation_id_to_tool_file_message_transformer() -> None: diff --git a/api/tests/unit_tests/core/workflow/nodes/agent/test_runtime_support.py b/api/tests/unit_tests/core/workflow/nodes/agent/test_runtime_support.py index 59dd763b59d..c86de7f6e63 100644 --- a/api/tests/unit_tests/core/workflow/nodes/agent/test_runtime_support.py +++ b/api/tests/unit_tests/core/workflow/nodes/agent/test_runtime_support.py @@ -1,9 +1,8 @@ from types import SimpleNamespace from unittest.mock import Mock, patch -from graphon.model_runtime.entities.model_entities import ModelType - from core.workflow.nodes.agent.runtime_support import AgentRuntimeSupport +from graphon.model_runtime.entities.model_entities import ModelType def test_fetch_model_reuses_single_model_assembly(): diff --git a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py index 7195471eb6b..76b4cd1ef45 100644 --- a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py +++ b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py @@ -2,15 +2,15 @@ import time import uuid from unittest.mock import MagicMock -from graphon.enums import WorkflowNodeExecutionStatus -from graphon.graph import Graph -from graphon.nodes.answer.answer_node import AnswerNode -from graphon.runtime import GraphRuntimeState, VariablePool - from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.workflow.node_factory import DifyNodeFactory from core.workflow.system_variables import build_system_variables from extensions.ext_database import db +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.graph import Graph +from graphon.nodes.answer.answer_node import AnswerNode +from graphon.nodes.answer.entities import AnswerNodeData +from graphon.runtime import GraphRuntimeState, VariablePool from tests.workflow_test_utils import build_test_graph_init_params @@ -67,20 +67,15 @@ def test_execute_answer(): graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") - node_config = { - "id": "answer", - "data": { - "title": "123", - "type": "answer", - "answer": "Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.", - }, - } - node = AnswerNode( - id=str(uuid.uuid4()), + node_id=str(uuid.uuid4()), graph_init_params=init_params, graph_runtime_state=graph_runtime_state, - config=node_config, + config=AnswerNodeData( + title="123", + type="answer", + answer="Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.", + ), ) # Mock db.session.close() diff --git a/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py b/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py index 343bcd39193..ec4cef1955a 100644 --- a/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py @@ -1,10 +1,10 @@ import pytest + +from core.workflow.node_factory import get_node_type_classes_mapping from graphon.entities.base_node_data import BaseNodeData from graphon.enums import BuiltinNodeTypes, NodeType from graphon.nodes.base.node import Node -from core.workflow.node_factory import get_node_type_classes_mapping - # Ensures that all production node classes are imported and registered. _ = get_node_type_classes_mapping() diff --git a/api/tests/unit_tests/core/workflow/nodes/base/test_get_node_type_classes_mapping.py b/api/tests/unit_tests/core/workflow/nodes/base/test_get_node_type_classes_mapping.py index b9371a34f44..ef0df55995e 100644 --- a/api/tests/unit_tests/core/workflow/nodes/base/test_get_node_type_classes_mapping.py +++ b/api/tests/unit_tests/core/workflow/nodes/base/test_get_node_type_classes_mapping.py @@ -1,6 +1,7 @@ import types from collections.abc import Mapping +from core.workflow.node_factory import get_node_type_classes_mapping from graphon.entities.base_node_data import BaseNodeData from graphon.enums import BuiltinNodeTypes, NodeType from graphon.nodes.base.node import Node @@ -13,8 +14,6 @@ from graphon.nodes.variable_assigner.v2.node import ( VariableAssignerNode as VariableAssignerV2, ) -from core.workflow.node_factory import get_node_type_classes_mapping - def test_variable_assigner_latest_prefers_highest_numeric_version(): # Act diff --git a/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py b/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py index d155124c501..ce0c9b79c68 100644 --- a/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py +++ b/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py @@ -1,3 +1,4 @@ +from configs import dify_config from graphon.nodes.code.code_node import CodeNode from graphon.nodes.code.entities import CodeLanguage, CodeNodeData from graphon.nodes.code.exc import ( @@ -8,8 +9,6 @@ from graphon.nodes.code.exc import ( from graphon.nodes.code.limits import CodeNodeLimits from graphon.variables.types import SegmentType -from configs import dify_config - CodeNode._limits = CodeNodeLimits( max_string_length=dify_config.CODE_MAX_STRING_LENGTH, max_number=dify_config.CODE_MAX_NUMBER, diff --git a/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py b/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py index fb03ae9998d..d7ef7817329 100644 --- a/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py @@ -1,8 +1,8 @@ -from graphon.enums import WorkflowNodeExecutionStatus -from graphon.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent - from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY from core.workflow.nodes.datasource.datasource_node import DatasourceNode +from core.workflow.nodes.datasource.entities import DatasourceNodeData +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent class _VarSeg: @@ -78,19 +78,16 @@ def test_datasource_node_delegates_to_manager_stream(mocker): mocker.patch("core.workflow.nodes.datasource.datasource_node.DatasourceManager", new=_Mgr) node = DatasourceNode( - id="n", - config={ - "id": "n", - "data": { - "type": "datasource", - "version": "1", - "title": "Datasource", - "provider_type": "plugin", - "provider_name": "p", - "plugin_id": "plug", - "datasource_name": "ds", - }, - }, + node_id="n", + config=DatasourceNodeData( + type="datasource", + version="1", + title="Datasource", + provider_type="plugin", + provider_name="p", + plugin_id="plug", + datasource_name="ds", + ), graph_init_params=gp, graph_runtime_state=gs, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py index a5026b40cf6..be7cc073dba 100644 --- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py +++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py @@ -1,4 +1,8 @@ import pytest + +from configs import dify_config +from core.helper.ssrf_proxy import ssrf_proxy +from core.workflow.system_variables import default_system_variables from graphon.file.file_manager import file_manager from graphon.nodes.http_request import ( BodyData, @@ -12,10 +16,6 @@ from graphon.nodes.http_request.exc import AuthorizationConfigError from graphon.nodes.http_request.executor import Executor from graphon.runtime import VariablePool -from configs import dify_config -from core.helper.ssrf_proxy import ssrf_proxy -from core.workflow.system_variables import default_system_variables - HTTP_REQUEST_CONFIG = HttpRequestNodeConfig( max_connect_timeout=dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT, max_read_timeout=dify_config.HTTP_REQUEST_MAX_READ_TIMEOUT, diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py index 4705b3f76ec..2e89a2da3ce 100644 --- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py @@ -3,17 +3,17 @@ from typing import Any import httpx import pytest -from graphon.enums import WorkflowNodeExecutionStatus -from graphon.file.file_manager import file_manager -from graphon.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, HttpRequestNode, HttpRequestNodeConfig -from graphon.nodes.http_request.entities import HttpRequestNodeTimeout, Response -from graphon.runtime import GraphRuntimeState, VariablePool from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.helper.ssrf_proxy import ssrf_proxy from core.tools.tool_file_manager import ToolFileManager from core.workflow.node_runtime import DifyFileReferenceFactory from core.workflow.system_variables import build_system_variables +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.file.file_manager import file_manager +from graphon.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, HttpRequestNode, HttpRequestNodeConfig +from graphon.nodes.http_request.entities import HttpRequestNodeData, HttpRequestNodeTimeout, Response +from graphon.runtime import GraphRuntimeState, VariablePool from tests.workflow_test_utils import build_test_graph_init_params HTTP_REQUEST_CONFIG = HttpRequestNodeConfig( @@ -66,8 +66,8 @@ def test_get_default_config_uses_injected_http_request_config(): assert default_config["retry_config"]["max_retries"] == 7 -def test_get_default_config_with_malformed_http_request_config_raises_value_error(): - with pytest.raises(ValueError, match="http_request_config must be an HttpRequestNodeConfig instance"): +def test_get_default_config_with_malformed_http_request_config_raises_type_error(): + with pytest.raises(TypeError, match="http_request_config must be an HttpRequestNodeConfig instance"): HttpRequestNode.get_default_config(filters={HTTP_REQUEST_CONFIG_FILTER_KEY: "invalid"}) @@ -114,8 +114,8 @@ def _build_http_node( start_at=time.perf_counter(), ) return HttpRequestNode( - id="http-node", - config=node_config, + node_id="http-node", + config=HttpRequestNodeData.model_validate(node_data), graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, http_request_config=HTTP_REQUEST_CONFIG, diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py index d16e1233ac9..07430498e55 100644 --- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py +++ b/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py @@ -1,7 +1,6 @@ +from core.workflow.human_input_adapter import EmailDeliveryConfig, EmailRecipients from graphon.runtime import VariablePool -from core.workflow.human_input_compat import EmailDeliveryConfig, EmailRecipients - def test_render_body_template_replaces_variable_values(): config = EmailDeliveryConfig( diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py index a2cdbbf132e..0659984c763 100644 --- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py +++ b/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py @@ -10,6 +10,28 @@ from typing import Any from unittest.mock import MagicMock import pytest +from pydantic import ValidationError + +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY +from core.repositories.human_input_repository import ( + FormCreateParams, + HumanInputFormEntity, + HumanInputFormRecipientEntity, + HumanInputFormRepository, +) +from core.workflow.human_input_adapter import ( + DeliveryMethodType, + EmailDeliveryConfig, + EmailDeliveryMethod, + EmailRecipients, + EmailRecipientType, + ExternalRecipient, + MemberRecipient, + WebAppDeliveryMethod, + _WebAppDeliveryConfig, +) +from core.workflow.node_runtime import DifyHumanInputNodeRuntime +from core.workflow.system_variables import build_system_variables from graphon.entities import GraphInitParams from graphon.node_events import PauseRequestedEvent from graphon.node_events.node import StreamCompletedEvent @@ -28,28 +50,6 @@ from graphon.nodes.human_input.enums import ( ) from graphon.nodes.human_input.human_input_node import HumanInputNode from graphon.runtime import GraphRuntimeState, VariablePool -from pydantic import ValidationError - -from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY -from core.repositories.human_input_repository import ( - FormCreateParams, - HumanInputFormEntity, - HumanInputFormRecipientEntity, - HumanInputFormRepository, -) -from core.workflow.human_input_compat import ( - DeliveryMethodType, - EmailDeliveryConfig, - EmailDeliveryMethod, - EmailRecipients, - EmailRecipientType, - ExternalRecipient, - MemberRecipient, - WebAppDeliveryMethod, - _WebAppDeliveryConfig, -) -from core.workflow.node_runtime import DifyHumanInputNodeRuntime -from core.workflow.system_variables import build_system_variables from libs.datetime_utils import naive_utc_now @@ -136,6 +136,26 @@ class InMemoryHumanInputFormRepository(HumanInputFormRepository): entity.status_value = HumanInputFormStatus.SUBMITTED +def _build_human_input_node( + *, + node_id: str, + node_data: HumanInputNodeData | Mapping[str, Any], + graph_init_params: GraphInitParams, + graph_runtime_state: GraphRuntimeState, + runtime: DifyHumanInputNodeRuntime, +) -> HumanInputNode: + typed_node_data = ( + node_data if isinstance(node_data, HumanInputNodeData) else HumanInputNodeData.model_validate(node_data) + ) + return HumanInputNode( + node_id=node_id, + config=typed_node_data, + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + runtime=runtime, + ) + + class TestDeliveryMethod: """Test DeliveryMethod entity.""" @@ -239,7 +259,7 @@ class TestUserAction: data[field_name] = value with pytest.raises(ValidationError) as exc_info: - UserAction(**data) + UserAction.model_validate(data) errors = exc_info.value.errors() assert any(error["loc"] == (field_name,) and error["type"] == "string_too_long" for error in errors) @@ -465,9 +485,9 @@ class TestHumanInputNodeVariableResolution: runtime = DifyHumanInputNodeRuntime(graph_init_params.run_context) runtime._build_form_repository = MagicMock(return_value=mock_repo) # type: ignore[attr-defined] - node = HumanInputNode( - id=config["id"], - config=config, + node = _build_human_input_node( + node_id=config["id"], + node_data=config["data"], graph_init_params=graph_init_params, graph_runtime_state=runtime_state, runtime=runtime, @@ -530,9 +550,9 @@ class TestHumanInputNodeVariableResolution: runtime = DifyHumanInputNodeRuntime(graph_init_params.run_context) runtime._build_form_repository = MagicMock(return_value=mock_repo) # type: ignore[attr-defined] - node = HumanInputNode( - id=config["id"], - config=config, + node = _build_human_input_node( + node_id=config["id"], + node_data=config["data"], graph_init_params=graph_init_params, graph_runtime_state=runtime_state, runtime=runtime, @@ -595,9 +615,9 @@ class TestHumanInputNodeVariableResolution: runtime = DifyHumanInputNodeRuntime(graph_init_params.run_context) runtime._build_form_repository = MagicMock(return_value=mock_repo) # type: ignore[attr-defined] - node = HumanInputNode( - id=config["id"], - config=config, + node = _build_human_input_node( + node_id=config["id"], + node_data=config["data"], graph_init_params=graph_init_params, graph_runtime_state=runtime_state, runtime=runtime, @@ -671,9 +691,9 @@ class TestHumanInputNodeVariableResolution: runtime = DifyHumanInputNodeRuntime(graph_init_params.run_context) runtime._build_form_repository = MagicMock(return_value=mock_repo) # type: ignore[attr-defined] - node = HumanInputNode( - id=config["id"], - config=config, + node = _build_human_input_node( + node_id=config["id"], + node_data=config["data"], graph_init_params=graph_init_params, graph_runtime_state=runtime_state, runtime=runtime, @@ -770,9 +790,9 @@ class TestHumanInputNodeRenderedContent: form_repository = InMemoryHumanInputFormRepository() runtime = DifyHumanInputNodeRuntime(graph_init_params.run_context) runtime._build_form_repository = MagicMock(return_value=form_repository) # type: ignore[attr-defined] - node = HumanInputNode( - id=config["id"], - config=config, + node = _build_human_input_node( + node_id=config["id"], + node_data=config["data"], graph_init_params=graph_init_params, graph_runtime_state=runtime_state, runtime=runtime, diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py index 52802c7ce1e..4a9438b14f5 100644 --- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py +++ b/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py @@ -1,6 +1,9 @@ import datetime from types import SimpleNamespace +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom +from core.workflow.node_runtime import DifyHumanInputNodeRuntime +from core.workflow.system_variables import default_system_variables from graphon.entities import GraphInitParams from graphon.enums import BuiltinNodeTypes from graphon.graph_events import ( @@ -8,13 +11,10 @@ from graphon.graph_events import ( NodeRunHumanInputFormTimeoutEvent, NodeRunStartedEvent, ) +from graphon.nodes.human_input.entities import HumanInputNodeData from graphon.nodes.human_input.enums import HumanInputFormStatus from graphon.nodes.human_input.human_input_node import HumanInputNode from graphon.runtime import GraphRuntimeState, VariablePool - -from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom -from core.workflow.node_runtime import DifyHumanInputNodeRuntime -from core.workflow.system_variables import default_system_variables from libs.datetime_utils import naive_utc_now @@ -26,6 +26,28 @@ class _FakeFormRepository: return self._form +def _create_human_input_node( + *, + config: dict, + graph_init_params: GraphInitParams, + graph_runtime_state: GraphRuntimeState, + repo: _FakeFormRepository, +) -> HumanInputNode: + node_data = ( + config["data"] + if isinstance(config["data"], HumanInputNodeData) + else HumanInputNodeData.model_validate(config["data"]) + ) + return HumanInputNode( + node_id=config["id"], + config=node_data, + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + form_repository=repo, + runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context), + ) + + def _build_node(form_content: str = "Please enter your name:\n\n{{#$output.name#}}") -> HumanInputNode: system_variables = default_system_variables() graph_runtime_state = GraphRuntimeState( @@ -81,13 +103,11 @@ def _build_node(form_content: str = "Please enter your name:\n\n{{#$output.name# ) repo = _FakeFormRepository(fake_form) - return HumanInputNode( - id="node-1", + return _create_human_input_node( config=config, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, - form_repository=repo, - runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context), + repo=repo, ) @@ -146,13 +166,11 @@ def _build_timeout_node() -> HumanInputNode: ) repo = _FakeFormRepository(fake_form) - return HumanInputNode( - id="node-1", + return _create_human_input_node( config=config, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, - form_repository=repo, - runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context), + repo=repo, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py index bbfe350f7e4..8ffce39cd62 100644 --- a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py +++ b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py @@ -2,7 +2,10 @@ from collections.abc import Mapping from typing import Any import pytest + +from core.workflow.system_variables import default_system_variables from graphon.entities import GraphInitParams +from graphon.nodes.iteration.entities import IterationNodeData from graphon.nodes.iteration.exc import IterationGraphNotFoundError from graphon.nodes.iteration.iteration_node import IterationNode from graphon.runtime import ( @@ -11,8 +14,6 @@ from graphon.runtime import ( GraphRuntimeState, VariablePool, ) - -from core.workflow.system_variables import default_system_variables from tests.workflow_test_utils import build_test_graph_init_params @@ -44,17 +45,14 @@ def _build_iteration_node( ) -> IterationNode: init_params = build_test_graph_init_params(graph_config=graph_config) return IterationNode( - id="iteration-node", - config={ - "id": "iteration-node", - "data": { - "type": "iteration", - "title": "Iteration", - "iterator_selector": ["start", "items"], - "output_selector": ["iteration-node", "output"], - "start_node_id": start_node_id, - }, - }, + node_id="iteration-node", + config=IterationNodeData( + type="iteration", + title="Iteration", + iterator_selector=["start", "items"], + output_selector=["iteration-node", "output"], + start_node_id=start_node_id, + ), graph_init_params=init_params, graph_runtime_state=runtime_state, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py b/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py index f8802138b58..f254fc3d09a 100644 --- a/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py @@ -3,9 +3,6 @@ import uuid from unittest.mock import Mock import pytest -from graphon.enums import WorkflowNodeExecutionStatus -from graphon.runtime import GraphRuntimeState, VariablePool -from graphon.variables.segments import StringSegment from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.rag.index_processor.constant.index_type import IndexTechniqueType @@ -19,6 +16,9 @@ from core.workflow.nodes.knowledge_index.protocols import ( SummaryIndexServiceProtocol, ) from core.workflow.system_variables import SystemVariableKey, build_system_variables +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.variables.segments import StringSegment from tests.workflow_test_utils import build_test_graph_init_params @@ -93,6 +93,25 @@ def sample_chunks(): } +def _build_node( + *, + node_id: str, + node_data: KnowledgeIndexNodeData | dict[str, object], + graph_init_params, + graph_runtime_state, +) -> KnowledgeIndexNode: + return KnowledgeIndexNode( + node_id=node_id, + config=( + node_data + if isinstance(node_data, KnowledgeIndexNodeData) + else KnowledgeIndexNodeData.model_validate(node_data) + ), + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + ) + + class TestKnowledgeIndexNode: """ Test suite for KnowledgeIndexNode. @@ -115,9 +134,9 @@ class TestKnowledgeIndexNode: } # Act - node = KnowledgeIndexNode( - id=node_id, - config=config, + node = _build_node( + node_id=node_id, + node_data=config["data"], graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -143,9 +162,9 @@ class TestKnowledgeIndexNode: "data": sample_node_data.model_dump(), } - node = KnowledgeIndexNode( - id=node_id, - config=config, + node = _build_node( + node_id=node_id, + node_data=config["data"], graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -176,9 +195,9 @@ class TestKnowledgeIndexNode: "data": sample_node_data.model_dump(), } - node = KnowledgeIndexNode( - id=node_id, - config=config, + node = _build_node( + node_id=node_id, + node_data=config["data"], graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -212,9 +231,9 @@ class TestKnowledgeIndexNode: "data": sample_node_data.model_dump(), } - node = KnowledgeIndexNode( - id=node_id, - config=config, + node = _build_node( + node_id=node_id, + node_data=config["data"], graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -269,9 +288,9 @@ class TestKnowledgeIndexNode: "data": sample_node_data.model_dump(), } - node = KnowledgeIndexNode( - id=node_id, - config=config, + node = _build_node( + node_id=node_id, + node_data=config["data"], graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -332,9 +351,9 @@ class TestKnowledgeIndexNode: "data": sample_node_data.model_dump(), } - node = KnowledgeIndexNode( - id=node_id, - config=config, + node = _build_node( + node_id=node_id, + node_data=config["data"], graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -383,9 +402,9 @@ class TestKnowledgeIndexNode: "data": sample_node_data.model_dump(), } - node = KnowledgeIndexNode( - id=node_id, - config=config, + node = _build_node( + node_id=node_id, + node_data=config["data"], graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -440,9 +459,9 @@ class TestKnowledgeIndexNode: "data": sample_node_data.model_dump(), } - node = KnowledgeIndexNode( - id=node_id, - config=config, + node = _build_node( + node_id=node_id, + node_data=config["data"], graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -498,9 +517,9 @@ class TestKnowledgeIndexNode: "data": sample_node_data.model_dump(), } - node = KnowledgeIndexNode( - id=node_id, - config=config, + node = _build_node( + node_id=node_id, + node_data=config["data"], graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -536,9 +555,9 @@ class TestKnowledgeIndexNode: "data": sample_node_data.model_dump(), } - node = KnowledgeIndexNode( - id=node_id, - config=config, + node = _build_node( + node_id=node_id, + node_data=config["data"], graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -583,9 +602,9 @@ class TestKnowledgeIndexNode: "data": sample_node_data.model_dump(), } - node = KnowledgeIndexNode( - id=node_id, - config=config, + node = _build_node( + node_id=node_id, + node_data=config["data"], graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -623,9 +642,9 @@ class TestInvokeKnowledgeIndex: "data": sample_node_data.model_dump(), } - node = KnowledgeIndexNode( - id=node_id, - config=config, + node = _build_node( + node_id=node_id, + node_data=config["data"], graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py b/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py index ab64be59ad2..e923ee761ba 100644 --- a/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py @@ -3,10 +3,6 @@ import uuid from unittest.mock import Mock import pytest -from graphon.enums import WorkflowNodeExecutionStatus -from graphon.model_runtime.entities.llm_entities import LLMUsage -from graphon.runtime import GraphRuntimeState, VariablePool -from graphon.variables import StringSegment from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.workflow.nodes.knowledge_retrieval.entities import ( @@ -18,9 +14,17 @@ from core.workflow.nodes.knowledge_retrieval.entities import ( SingleRetrievalConfig, ) from core.workflow.nodes.knowledge_retrieval.exc import RateLimitExceededError -from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode +from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import ( + KnowledgeRetrievalNode, + _normalize_metadata_filter_scalar, + _normalize_metadata_filter_sequence_item, +) from core.workflow.nodes.knowledge_retrieval.retrieval import RAGRetrievalProtocol, Source from core.workflow.system_variables import build_system_variables +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.variables import StringSegment from tests.workflow_test_utils import build_test_graph_init_params @@ -85,6 +89,12 @@ def sample_node_data(): ) +def test_metadata_filter_normalizers_preserve_numeric_scalars_and_stringify_other_values() -> None: + assert _normalize_metadata_filter_scalar(3) == 3 + assert _normalize_metadata_filter_scalar(True) == "True" + assert _normalize_metadata_filter_sequence_item(4) == "4" + + class TestKnowledgeRetrievalNode: """ Test suite for KnowledgeRetrievalNode. @@ -106,8 +116,8 @@ class TestKnowledgeRetrievalNode: # Act node = KnowledgeRetrievalNode( - id=node_id, - config=config, + node_id=node_id, + config=KnowledgeRetrievalNodeData.model_validate(config["data"]), graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -135,8 +145,8 @@ class TestKnowledgeRetrievalNode: } node = KnowledgeRetrievalNode( - id=node_id, - config=config, + node_id=node_id, + config=KnowledgeRetrievalNodeData.model_validate(config["data"]), graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -194,8 +204,8 @@ class TestKnowledgeRetrievalNode: mock_rag_retrieval.llm_usage = LLMUsage.empty_usage() node = KnowledgeRetrievalNode( - id=node_id, - config=config, + node_id=node_id, + config=KnowledgeRetrievalNodeData.model_validate(config["data"]), graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -238,8 +248,8 @@ class TestKnowledgeRetrievalNode: mock_rag_retrieval.llm_usage = LLMUsage.empty_usage() node = KnowledgeRetrievalNode( - id=node_id, - config=config, + node_id=node_id, + config=KnowledgeRetrievalNodeData.model_validate(config["data"]), graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -274,8 +284,8 @@ class TestKnowledgeRetrievalNode: } node = KnowledgeRetrievalNode( - id=node_id, - config=config, + node_id=node_id, + config=KnowledgeRetrievalNodeData.model_validate(config["data"]), graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -309,8 +319,8 @@ class TestKnowledgeRetrievalNode: } node = KnowledgeRetrievalNode( - id=node_id, - config=config, + node_id=node_id, + config=KnowledgeRetrievalNodeData.model_validate(config["data"]), graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -350,8 +360,8 @@ class TestKnowledgeRetrievalNode: mock_rag_retrieval.llm_usage = LLMUsage.empty_usage() node = KnowledgeRetrievalNode( - id=node_id, - config=config, + node_id=node_id, + config=KnowledgeRetrievalNodeData.model_validate(config["data"]), graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -389,8 +399,8 @@ class TestKnowledgeRetrievalNode: mock_rag_retrieval.llm_usage = LLMUsage.empty_usage() node = KnowledgeRetrievalNode( - id=node_id, - config=config, + node_id=node_id, + config=KnowledgeRetrievalNodeData.model_validate(config["data"]), graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -470,8 +480,8 @@ class TestFetchDatasetRetriever: config = {"id": node_id, "data": node_data.model_dump()} node = KnowledgeRetrievalNode( - id=node_id, - config=config, + node_id=node_id, + config=KnowledgeRetrievalNodeData.model_validate(config["data"]), graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -507,8 +517,8 @@ class TestFetchDatasetRetriever: } node = KnowledgeRetrievalNode( - id=node_id, - config=config, + node_id=node_id, + config=KnowledgeRetrievalNodeData.model_validate(config["data"]), graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -562,8 +572,8 @@ class TestFetchDatasetRetriever: } node = KnowledgeRetrievalNode( - id=node_id, - config=config, + node_id=node_id, + config=KnowledgeRetrievalNodeData.model_validate(config["data"]), graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -610,8 +620,8 @@ class TestFetchDatasetRetriever: mock_graph_runtime_state.variable_pool.add(["start", "query"], StringSegment(value="readme")) node = KnowledgeRetrievalNode( - id=node_id, - config=config, + node_id=node_id, + config=KnowledgeRetrievalNodeData.model_validate(config["data"]), graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -671,8 +681,8 @@ class TestFetchDatasetRetriever: node_id = str(uuid.uuid4()) config = {"id": node_id, "data": node_data.model_dump()} node = KnowledgeRetrievalNode( - id=node_id, - config=config, + node_id=node_id, + config=KnowledgeRetrievalNodeData.model_validate(config["data"]), graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py b/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py index fdf1706765a..388654f2797 100644 --- a/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py +++ b/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py @@ -1,23 +1,42 @@ +from types import SimpleNamespace from unittest.mock import MagicMock import pytest + +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY from graphon.entities import GraphInitParams from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus +from graphon.nodes.list_operator.entities import ListOperatorNodeData from graphon.nodes.list_operator.node import ListOperatorNode from graphon.runtime import GraphRuntimeState from graphon.variables import ArrayNumberSegment, ArrayStringSegment -from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY - class TestListOperatorNode: """Comprehensive tests for ListOperatorNode.""" + @staticmethod + def _build_node(*, config, graph_init_params, graph_runtime_state): + return ListOperatorNode( + node_id="test", + config=config if isinstance(config, ListOperatorNodeData) else ListOperatorNodeData.model_validate(config), + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + ) + + @staticmethod + def _filter_by(comparison_operator: str, value: str) -> dict[str, object]: + return { + "enabled": True, + "conditions": [{"comparison_operator": comparison_operator, "value": value}], + } + @pytest.fixture def mock_graph_runtime_state(self): """Create mock GraphRuntimeState.""" mock_state = MagicMock(spec=GraphRuntimeState) mock_variable_pool = MagicMock() + mock_variable_pool.convert_template.side_effect = lambda value: SimpleNamespace(text=value) mock_state.variable_pool = mock_variable_pool return mock_state @@ -45,9 +64,8 @@ class TestListOperatorNode: def _create_node(config, mock_variable): mock_graph_runtime_state.variable_pool.get.return_value = mock_variable - return ListOperatorNode( - id="test", - config={"id": "test", "data": config}, + return self._build_node( + config=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -64,9 +82,8 @@ class TestListOperatorNode: "limit": {"enabled": False}, } - node = ListOperatorNode( - id="test", - config={"id": "test", "data": config}, + node = self._build_node( + config=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -109,9 +126,8 @@ class TestListOperatorNode: mock_var = ArrayStringSegment(value=[]) mock_graph_runtime_state.variable_pool.get.return_value = mock_var - node = ListOperatorNode( - id="test", - config={"id": "test", "data": config}, + node = self._build_node( + config=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -128,11 +144,7 @@ class TestListOperatorNode: config = { "title": "Test", "variable": ["sys", "items"], - "filter_by": { - "enabled": True, - "condition": "contains", - "value": "app", - }, + "filter_by": self._filter_by("contains", "app"), "order_by": {"enabled": False}, "limit": {"enabled": False}, } @@ -140,9 +152,8 @@ class TestListOperatorNode: mock_var = ArrayStringSegment(value=["apple", "banana", "pineapple", "cherry"]) mock_graph_runtime_state.variable_pool.get.return_value = mock_var - node = ListOperatorNode( - id="test", - config={"id": "test", "data": config}, + node = self._build_node( + config=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -157,11 +168,7 @@ class TestListOperatorNode: config = { "title": "Test", "variable": ["sys", "items"], - "filter_by": { - "enabled": True, - "condition": "not contains", - "value": "app", - }, + "filter_by": self._filter_by("not contains", "app"), "order_by": {"enabled": False}, "limit": {"enabled": False}, } @@ -169,9 +176,8 @@ class TestListOperatorNode: mock_var = ArrayStringSegment(value=["apple", "banana", "pineapple", "cherry"]) mock_graph_runtime_state.variable_pool.get.return_value = mock_var - node = ListOperatorNode( - id="test", - config={"id": "test", "data": config}, + node = self._build_node( + config=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -186,11 +192,7 @@ class TestListOperatorNode: config = { "title": "Test", "variable": ["sys", "numbers"], - "filter_by": { - "enabled": True, - "condition": ">", - "value": "5", - }, + "filter_by": self._filter_by(">", "5"), "order_by": {"enabled": False}, "limit": {"enabled": False}, } @@ -198,9 +200,8 @@ class TestListOperatorNode: mock_var = ArrayNumberSegment(value=[1, 3, 5, 7, 9, 11]) mock_graph_runtime_state.variable_pool.get.return_value = mock_var - node = ListOperatorNode( - id="test", - config={"id": "test", "data": config}, + node = self._build_node( + config=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -226,9 +227,8 @@ class TestListOperatorNode: mock_var = ArrayStringSegment(value=["cherry", "apple", "banana"]) mock_graph_runtime_state.variable_pool.get.return_value = mock_var - node = ListOperatorNode( - id="test", - config={"id": "test", "data": config}, + node = self._build_node( + config=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -254,9 +254,8 @@ class TestListOperatorNode: mock_var = ArrayStringSegment(value=["cherry", "apple", "banana"]) mock_graph_runtime_state.variable_pool.get.return_value = mock_var - node = ListOperatorNode( - id="test", - config={"id": "test", "data": config}, + node = self._build_node( + config=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -282,9 +281,8 @@ class TestListOperatorNode: mock_var = ArrayStringSegment(value=["apple", "banana", "cherry", "date"]) mock_graph_runtime_state.variable_pool.get.return_value = mock_var - node = ListOperatorNode( - id="test", - config={"id": "test", "data": config}, + node = self._build_node( + config=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -299,11 +297,7 @@ class TestListOperatorNode: config = { "title": "Test", "variable": ["sys", "numbers"], - "filter_by": { - "enabled": True, - "condition": ">", - "value": "3", - }, + "filter_by": self._filter_by(">", "3"), "order_by": { "enabled": True, "value": "desc", @@ -317,9 +311,8 @@ class TestListOperatorNode: mock_var = ArrayNumberSegment(value=[1, 2, 3, 4, 5, 6, 7, 8, 9]) mock_graph_runtime_state.variable_pool.get.return_value = mock_var - node = ListOperatorNode( - id="test", - config={"id": "test", "data": config}, + node = self._build_node( + config=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -341,9 +334,8 @@ class TestListOperatorNode: mock_graph_runtime_state.variable_pool.get.return_value = None - node = ListOperatorNode( - id="test", - config={"id": "test", "data": config}, + node = self._build_node( + config=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -366,9 +358,8 @@ class TestListOperatorNode: mock_var = ArrayStringSegment(value=["first", "middle", "last"]) mock_graph_runtime_state.variable_pool.get.return_value = mock_var - node = ListOperatorNode( - id="test", - config={"id": "test", "data": config}, + node = self._build_node( + config=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -384,11 +375,7 @@ class TestListOperatorNode: config = { "title": "Test", "variable": ["sys", "items"], - "filter_by": { - "enabled": True, - "condition": "start with", - "value": "app", - }, + "filter_by": self._filter_by("start with", "app"), "order_by": {"enabled": False}, "limit": {"enabled": False}, } @@ -396,9 +383,8 @@ class TestListOperatorNode: mock_var = ArrayStringSegment(value=["apple", "application", "banana", "apricot"]) mock_graph_runtime_state.variable_pool.get.return_value = mock_var - node = ListOperatorNode( - id="test", - config={"id": "test", "data": config}, + node = self._build_node( + config=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -413,11 +399,7 @@ class TestListOperatorNode: config = { "title": "Test", "variable": ["sys", "items"], - "filter_by": { - "enabled": True, - "condition": "end with", - "value": "le", - }, + "filter_by": self._filter_by("end with", "le"), "order_by": {"enabled": False}, "limit": {"enabled": False}, } @@ -425,9 +407,8 @@ class TestListOperatorNode: mock_var = ArrayStringSegment(value=["apple", "banana", "pineapple", "table"]) mock_graph_runtime_state.variable_pool.get.return_value = mock_var - node = ListOperatorNode( - id="test", - config={"id": "test", "data": config}, + node = self._build_node( + config=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -442,11 +423,7 @@ class TestListOperatorNode: config = { "title": "Test", "variable": ["sys", "numbers"], - "filter_by": { - "enabled": True, - "condition": "=", - "value": "5", - }, + "filter_by": self._filter_by("=", "5"), "order_by": {"enabled": False}, "limit": {"enabled": False}, } @@ -454,9 +431,8 @@ class TestListOperatorNode: mock_var = ArrayNumberSegment(value=[1, 3, 5, 5, 7, 9]) mock_graph_runtime_state.variable_pool.get.return_value = mock_var - node = ListOperatorNode( - id="test", - config={"id": "test", "data": config}, + node = self._build_node( + config=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -471,11 +447,7 @@ class TestListOperatorNode: config = { "title": "Test", "variable": ["sys", "numbers"], - "filter_by": { - "enabled": True, - "condition": "≠", - "value": "5", - }, + "filter_by": self._filter_by("≠", "5"), "order_by": {"enabled": False}, "limit": {"enabled": False}, } @@ -483,9 +455,8 @@ class TestListOperatorNode: mock_var = ArrayNumberSegment(value=[1, 3, 5, 7, 9]) mock_graph_runtime_state.variable_pool.get.return_value = mock_var - node = ListOperatorNode( - id="test", - config={"id": "test", "data": config}, + node = self._build_node( + config=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -511,9 +482,8 @@ class TestListOperatorNode: mock_var = ArrayNumberSegment(value=[9, 3, 7, 1, 5]) mock_graph_runtime_state.variable_pool.get.return_value = mock_var - node = ListOperatorNode( - id="test", - config={"id": "test", "data": config}, + node = self._build_node( + config=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py index c784f805c01..212ad07bd38 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py @@ -1,6 +1,8 @@ from unittest import mock import pytest + +from core.model_manager import ModelInstance from graphon.file import File, FileTransferMethod, FileType from graphon.model_runtime.entities import ( ImagePromptMessageContent, @@ -33,8 +35,6 @@ from graphon.nodes.llm.exc import ( from graphon.runtime import VariablePool from graphon.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment -from core.model_manager import ModelInstance - def _build_model_schema( *, @@ -71,8 +71,8 @@ def _build_image_file( mime_type: str = "image/png", ) -> File: return File( - id=file_id, - type=FileType.IMAGE, + file_id=file_id, + file_type=FileType.IMAGE, filename=f"{file_id}{extension}", transfer_method=FileTransferMethod.REMOTE_URL, remote_url=remote_url, @@ -95,6 +95,8 @@ def variable_pool() -> VariablePool: def _fetch_prompt_messages_with_mocked_content(content): variable_pool = VariablePool.empty() model_instance = mock.MagicMock(spec=ModelInstance) + model_schema = mock.MagicMock() + model_schema.supports_prompt_content_type.side_effect = lambda content_type: content_type == "text" prompt_template = [ LLMNodeChatModelMessage( text="You are a classifier.", @@ -106,7 +108,7 @@ def _fetch_prompt_messages_with_mocked_content(content): with ( mock.patch( "graphon.nodes.llm.llm_utils.fetch_model_schema", - return_value=mock.MagicMock(features=[]), + return_value=model_schema, ), mock.patch( "graphon.nodes.llm.llm_utils.handle_list_messages", diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py index 7841bf05ad9..c707cf28cde 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py @@ -5,6 +5,19 @@ from collections.abc import Sequence from unittest import mock import pytest + +from core.app.entities.app_invoke_entities import DifyRunContext, InvokeFrom, ModelConfigWithCredentialsEntity, UserFrom +from core.app.llm.model_access import ( + DifyCredentialsProvider, + DifyModelFactory, + build_dify_model_access, + fetch_model_config, +) +from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle +from core.entities.provider_entities import CustomConfiguration, SystemConfiguration +from core.plugin.impl.model_runtime_factory import create_plugin_model_runtime +from core.prompt.entities.advanced_prompt_entities import MemoryConfig +from core.workflow.system_variables import default_system_variables from graphon.entities import GraphInitParams from graphon.file import File, FileTransferMethod, FileType from graphon.model_runtime.entities.common_entities import I18nObject @@ -67,19 +80,6 @@ from graphon.nodes.llm.runtime_protocols import PromptMessageSerializerProtocol from graphon.runtime import GraphRuntimeState, VariablePool from graphon.template_rendering import TemplateRenderError from graphon.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment - -from core.app.entities.app_invoke_entities import DifyRunContext, InvokeFrom, ModelConfigWithCredentialsEntity, UserFrom -from core.app.llm.model_access import ( - DifyCredentialsProvider, - DifyModelFactory, - build_dify_model_access, - fetch_model_config, -) -from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle -from core.entities.provider_entities import CustomConfiguration, SystemConfiguration -from core.plugin.impl.model_runtime_factory import create_plugin_model_runtime -from core.prompt.entities.advanced_prompt_entities import MemoryConfig -from core.workflow.system_variables import default_system_variables from models.provider import ProviderType from tests.workflow_test_utils import build_test_graph_init_params @@ -140,8 +140,8 @@ def _build_image_file( mime_type: str = "image/png", ) -> File: return File( - id=file_id, - type=FileType.IMAGE, + file_id=file_id, + file_type=FileType.IMAGE, filename=f"{file_id}{extension}", transfer_method=FileTransferMethod.REMOTE_URL, remote_url=remote_url, @@ -205,14 +205,10 @@ def llm_node( mock_credentials_provider = mock.MagicMock(spec=CredentialsProvider) mock_model_factory = mock.MagicMock(spec=ModelFactory) mock_prompt_message_serializer = mock.MagicMock(spec=PromptMessageSerializerProtocol) - node_config = { - "id": "1", - "data": llm_node_data.model_dump(), - } http_client = mock.MagicMock() node = LLMNode( - id="1", - config=node_config, + node_id="1", + config=llm_node_data, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, credentials_provider=mock_credentials_provider, @@ -403,8 +399,8 @@ def test_dify_model_access_adapters_call_managers(): def test_fetch_files_with_file_segment(): file = File( - id="1", - type=FileType.IMAGE, + file_id="1", + file_type=FileType.IMAGE, filename="test.jpg", transfer_method=FileTransferMethod.LOCAL_FILE, related_id="1", @@ -420,16 +416,16 @@ def test_fetch_files_with_file_segment(): def test_fetch_files_with_array_file_segment(): files = [ File( - id="1", - type=FileType.IMAGE, + file_id="1", + file_type=FileType.IMAGE, filename="test1.jpg", transfer_method=FileTransferMethod.LOCAL_FILE, related_id="1", storage_key="", ), File( - id="2", - type=FileType.IMAGE, + file_id="2", + file_type=FileType.IMAGE, filename="test2.jpg", transfer_method=FileTransferMethod.LOCAL_FILE, related_id="2", @@ -1174,14 +1170,10 @@ def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_stat mock_credentials_provider = mock.MagicMock(spec=CredentialsProvider) mock_model_factory = mock.MagicMock(spec=ModelFactory) mock_prompt_message_serializer = mock.MagicMock(spec=PromptMessageSerializerProtocol) - node_config = { - "id": "1", - "data": llm_node_data.model_dump(), - } http_client = mock.MagicMock() node = LLMNode( - id="1", - config=node_config, + node_id="1", + config=llm_node_data, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, credentials_provider=mock_credentials_provider, @@ -1203,8 +1195,8 @@ class TestLLMNodeSaveMultiModalImageOutput: mime_type="image/png", ) mock_file = File( - id=str(uuid.uuid4()), - type=FileType.IMAGE, + file_id=str(uuid.uuid4()), + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.TOOL_FILE, related_id=str(uuid.uuid4()), filename="test-file.png", @@ -1233,8 +1225,8 @@ class TestLLMNodeSaveMultiModalImageOutput: mime_type="image/jpg", ) mock_file = File( - id=str(uuid.uuid4()), - type=FileType.IMAGE, + file_id=str(uuid.uuid4()), + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.TOOL_FILE, related_id=str(uuid.uuid4()), filename="test-file.png", @@ -1291,8 +1283,8 @@ class TestSaveMultimodalOutputAndConvertResultToMarkdown: image_b64_data = base64.b64encode(image_raw_data).decode() mock_saved_file = File( - id=str(uuid.uuid4()), - type=FileType.IMAGE, + file_id=str(uuid.uuid4()), + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.TOOL_FILE, filename="test.png", extension=".png", @@ -1457,7 +1449,6 @@ def test_invoke_llm_dispatches_to_expected_model_method(structured_output_enable file_saver=file_saver, file_outputs=[], node_id="node-1", - node_type=LLMNode.node_type, reasoning_format="separated", ) ) @@ -1514,7 +1505,6 @@ def test_handle_invoke_result_streaming_collects_text_metrics_and_structured_out file_saver=mock.MagicMock(spec=LLMFileSaver), file_outputs=[], node_id="node-1", - node_type=LLMNode.node_type, model_instance=_build_prepared_llm_mock(), reasoning_format="separated", request_start_time=1.0, @@ -1552,7 +1542,6 @@ def test_handle_invoke_result_wraps_structured_output_parse_errors(): file_saver=mock.MagicMock(spec=LLMFileSaver), file_outputs=[], node_id="node-1", - node_type=LLMNode.node_type, model_instance=model_instance, ) ) diff --git a/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_parameter_extractor_node.py b/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_parameter_extractor_node.py index 1c362a0a037..8f8ec49f141 100644 --- a/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_parameter_extractor_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_parameter_extractor_node.py @@ -6,6 +6,8 @@ from dataclasses import dataclass from typing import Any import pytest + +from factories.variable_factory import build_segment_with_type from graphon.model_runtime.entities import LLMMode from graphon.nodes.llm import ModelConfig, VisionConfig from graphon.nodes.parameter_extractor.entities import ParameterConfig, ParameterExtractorNodeData @@ -18,8 +20,6 @@ from graphon.nodes.parameter_extractor.exc import ( from graphon.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode from graphon.variables.types import SegmentType -from factories.variable_factory import build_segment_with_type - @dataclass class ValidTestCase: diff --git a/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py b/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py index d86e0efe023..892f6cc5860 100644 --- a/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py +++ b/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py @@ -1,6 +1,8 @@ from unittest.mock import MagicMock import pytest + +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from graphon.enums import BuiltinNodeTypes, ErrorStrategy, WorkflowNodeExecutionStatus from graphon.graph import Graph from graphon.nodes.base.entities import VariableSelector @@ -8,11 +10,31 @@ from graphon.nodes.template_transform.entities import TemplateTransformNodeData from graphon.nodes.template_transform.template_transform_node import TemplateTransformNode from graphon.runtime import GraphRuntimeState from graphon.template_rendering import TemplateRenderError - -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from tests.workflow_test_utils import build_test_graph_init_params +def _build_template_transform_node( + *, + node_data, + graph_init_params, + graph_runtime_state, + node_id: str = "test_node", + **kwargs, +) -> TemplateTransformNode: + typed_node_data = ( + node_data + if isinstance(node_data, TemplateTransformNodeData) + else TemplateTransformNodeData.model_validate(node_data) + ) + return TemplateTransformNode( + node_id=node_id, + config=typed_node_data, + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + **kwargs, + ) + + class TestTemplateTransformNode: """Comprehensive test suite for TemplateTransformNode.""" @@ -59,9 +81,8 @@ class TestTemplateTransformNode: def test_node_initialization(self, basic_node_data, mock_graph_runtime_state, graph_init_params): """Test that TemplateTransformNode initializes correctly.""" mock_renderer = MagicMock() - node = TemplateTransformNode( - id="test_node", - config={"id": "test_node", "data": basic_node_data}, + node = _build_template_transform_node( + node_data=basic_node_data, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, jinja2_template_renderer=mock_renderer, @@ -75,9 +96,8 @@ class TestTemplateTransformNode: def test_get_title(self, basic_node_data, mock_graph_runtime_state, graph_init_params): """Test _get_title method.""" mock_renderer = MagicMock() - node = TemplateTransformNode( - id="test_node", - config={"id": "test_node", "data": basic_node_data}, + node = _build_template_transform_node( + node_data=basic_node_data, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, jinja2_template_renderer=mock_renderer, @@ -88,9 +108,8 @@ class TestTemplateTransformNode: def test_get_description(self, basic_node_data, mock_graph_runtime_state, graph_init_params): """Test _get_description method.""" mock_renderer = MagicMock() - node = TemplateTransformNode( - id="test_node", - config={"id": "test_node", "data": basic_node_data}, + node = _build_template_transform_node( + node_data=basic_node_data, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, jinja2_template_renderer=mock_renderer, @@ -108,9 +127,8 @@ class TestTemplateTransformNode: } mock_renderer = MagicMock() - node = TemplateTransformNode( - id="test_node", - config={"id": "test_node", "data": node_data}, + node = _build_template_transform_node( + node_data=node_data, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, jinja2_template_renderer=mock_renderer, @@ -143,9 +161,8 @@ class TestTemplateTransformNode: mock_renderer = MagicMock() with pytest.raises(ValueError, match="max_output_length must be a positive integer"): - TemplateTransformNode( - id="test_node", - config={"id": "test_node", "data": basic_node_data}, + _build_template_transform_node( + node_data=basic_node_data, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, jinja2_template_renderer=mock_renderer, @@ -170,9 +187,8 @@ class TestTemplateTransformNode: mock_renderer = MagicMock() mock_renderer.render_template.return_value = "Hello Alice, you are 30 years old!" - node = TemplateTransformNode( - id="test_node", - config={"id": "test_node", "data": basic_node_data}, + node = _build_template_transform_node( + node_data=basic_node_data, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, jinja2_template_renderer=mock_renderer, @@ -198,9 +214,8 @@ class TestTemplateTransformNode: mock_renderer = MagicMock() mock_renderer.render_template.return_value = "Value: " - node = TemplateTransformNode( - id="test_node", - config={"id": "test_node", "data": node_data}, + node = _build_template_transform_node( + node_data=node_data, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, jinja2_template_renderer=mock_renderer, @@ -218,9 +233,8 @@ class TestTemplateTransformNode: mock_renderer = MagicMock() mock_renderer.render_template.side_effect = TemplateRenderError("Template syntax error") - node = TemplateTransformNode( - id="test_node", - config={"id": "test_node", "data": basic_node_data}, + node = _build_template_transform_node( + node_data=basic_node_data, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, jinja2_template_renderer=mock_renderer, @@ -238,9 +252,8 @@ class TestTemplateTransformNode: mock_renderer = MagicMock() mock_renderer.render_template.return_value = "This is a very long output that exceeds the limit" - node = TemplateTransformNode( - id="test_node", - config={"id": "test_node", "data": basic_node_data}, + node = _build_template_transform_node( + node_data=basic_node_data, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, jinja2_template_renderer=mock_renderer, @@ -260,9 +273,8 @@ class TestTemplateTransformNode: mock_renderer = MagicMock() mock_renderer.render_template.return_value = "1234567890" - node = TemplateTransformNode( - id="test_node", - config={"id": "test_node", "data": basic_node_data}, + node = _build_template_transform_node( + node_data=basic_node_data, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, jinja2_template_renderer=mock_renderer, @@ -302,9 +314,8 @@ class TestTemplateTransformNode: mock_renderer = MagicMock() mock_renderer.render_template.return_value = "apple, banana, orange (Total: 3)" - node = TemplateTransformNode( - id="test_node", - config={"id": "test_node", "data": node_data}, + node = _build_template_transform_node( + node_data=node_data, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, jinja2_template_renderer=mock_renderer, @@ -375,8 +386,8 @@ class TestTemplateTransformNode: ) assert mapping == { - "node_123.var1": ["sys", "input1"], - "node_123.empty_selector": [], + "node_123.var1": ("sys", "input1"), + "node_123.empty_selector": (), } def test_extract_variable_selector_to_variable_mapping_ignores_invalid_entries(self): @@ -409,9 +420,8 @@ class TestTemplateTransformNode: mock_renderer = MagicMock() mock_renderer.render_template.return_value = "This is a static message." - node = TemplateTransformNode( - id="test_node", - config={"id": "test_node", "data": node_data}, + node = _build_template_transform_node( + node_data=node_data, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, jinja2_template_renderer=mock_renderer, @@ -448,9 +458,8 @@ class TestTemplateTransformNode: mock_renderer = MagicMock() mock_renderer.render_template.return_value = "Total: $31.5" - node = TemplateTransformNode( - id="test_node", - config={"id": "test_node", "data": node_data}, + node = _build_template_transform_node( + node_data=node_data, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, jinja2_template_renderer=mock_renderer, @@ -477,9 +486,8 @@ class TestTemplateTransformNode: mock_renderer = MagicMock() mock_renderer.render_template.return_value = "Name: John Doe, Email: john@example.com" - node = TemplateTransformNode( - id="test_node", - config={"id": "test_node", "data": node_data}, + node = _build_template_transform_node( + node_data=node_data, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, jinja2_template_renderer=mock_renderer, @@ -507,9 +515,8 @@ class TestTemplateTransformNode: mock_renderer = MagicMock() mock_renderer.render_template.return_value = "Tags: #python #ai #workflow " - node = TemplateTransformNode( - id="test_node", - config={"id": "test_node", "data": node_data}, + node = _build_template_transform_node( + node_data=node_data, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, jinja2_template_renderer=mock_renderer, diff --git a/api/tests/unit_tests/core/workflow/nodes/template_transform/test_template_transform_node.py b/api/tests/unit_tests/core/workflow/nodes/template_transform/test_template_transform_node.py index bd22a8e318c..a846efbb43d 100644 --- a/api/tests/unit_tests/core/workflow/nodes/template_transform/test_template_transform_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/template_transform/test_template_transform_node.py @@ -1,14 +1,15 @@ from unittest.mock import MagicMock import pytest + +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from graphon.nodes.base.entities import VariableSelector +from graphon.nodes.template_transform.entities import TemplateTransformNodeData from graphon.nodes.template_transform.template_transform_node import ( DEFAULT_TEMPLATE_TRANSFORM_MAX_OUTPUT_LENGTH, TemplateTransformNode, ) from graphon.runtime import GraphRuntimeState - -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from tests.workflow_test_utils import build_test_graph_init_params from .template_transform_node_spec import TestTemplateTransformNode # noqa: F401 @@ -37,15 +38,13 @@ def mock_graph_runtime_state(): def test_node_uses_default_max_output_length_when_not_overridden(graph_init_params, mock_graph_runtime_state): node = TemplateTransformNode( - id="test_node", - config={ - "id": "test_node", - "data": { - "title": "Template Transform", - "variables": [], - "template": "hello", - }, - }, + node_id="test_node", + config=TemplateTransformNodeData( + title="Template Transform", + type="template-transform", + variables=[], + template="hello", + ), graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, jinja2_template_renderer=MagicMock(), @@ -70,5 +69,5 @@ def test_extract_variable_selector_to_variable_mapping_accepts_mixed_valid_entri assert mapping == { "node_123.validated": ["sys", "input1"], - "node_123.raw": ["sys", "input2"], + "node_123.raw": ("sys", "input2"), } diff --git a/api/tests/unit_tests/core/workflow/nodes/test_base_node.py b/api/tests/unit_tests/core/workflow/nodes/test_base_node.py index e11ebf6eb8b..364408ead61 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_base_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_base_node.py @@ -1,16 +1,15 @@ from collections.abc import Mapping import pytest -from graphon.entities import GraphInitParams -from graphon.entities.base_node_data import BaseNodeData -from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter -from graphon.enums import BuiltinNodeTypes -from graphon.nodes.base.node import Node -from graphon.runtime import GraphRuntimeState, VariablePool from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.workflow.node_runtime import resolve_dify_run_context from core.workflow.system_variables import build_system_variables +from graphon.entities import GraphInitParams +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import BuiltinNodeTypes +from graphon.nodes.base.node import Node +from graphon.runtime import GraphRuntimeState, VariablePool from tests.workflow_test_utils import build_test_graph_init_params @@ -42,17 +41,19 @@ def _build_context(graph_config: Mapping[str, object]) -> tuple[GraphInitParams, return init_params, runtime_state -def _build_node_config() -> NodeConfigDict: - return NodeConfigDictAdapter.validate_python( - { - "id": "node-1", - "data": { - "type": BuiltinNodeTypes.ANSWER, - "title": "Sample", - "foo": "bar", - }, - } - ) +def _build_node_config() -> dict[str, object]: + return { + "id": "node-1", + "data": _SampleNodeData( + type=BuiltinNodeTypes.ANSWER, + title="Sample", + foo="bar", + ), + } + + +def _build_node_data() -> _SampleNodeData: + return _build_node_config()["data"] # type: ignore[return-value] def test_node_hydrates_data_during_initialization(): @@ -60,8 +61,8 @@ def test_node_hydrates_data_during_initialization(): init_params, runtime_state = _build_context(graph_config) node = _SampleNode( - id="node-1", - config=_build_node_config(), + node_id="node-1", + config=_build_node_data(), graph_init_params=init_params, graph_runtime_state=runtime_state, ) @@ -86,8 +87,8 @@ def test_node_accepts_invoke_from_enum(): ) node = _SampleNode( - id="node-1", - config=_build_node_config(), + node_id="node-1", + config=_build_node_data(), graph_init_params=init_params, graph_runtime_state=runtime_state, ) @@ -117,13 +118,7 @@ def test_missing_generic_argument_raises_type_error(): def test_base_node_data_keeps_dict_style_access_compatibility(): - node_data = _SampleNodeData.model_validate( - { - "type": BuiltinNodeTypes.ANSWER, - "title": "Sample", - "foo": "bar", - } - ) + node_data = _SampleNodeData(type=BuiltinNodeTypes.ANSWER, title="Sample", foo="bar") assert node_data["foo"] == "bar" assert node_data.get("foo") == "bar" @@ -133,21 +128,19 @@ def test_base_node_data_keeps_dict_style_access_compatibility(): def test_node_hydration_preserves_compatibility_extra_fields(): graph_config: dict[str, object] = {} init_params, runtime_state = _build_context(graph_config) - node_config = NodeConfigDictAdapter.validate_python( - { - "id": "node-1", - "data": { - "type": BuiltinNodeTypes.ANSWER, - "title": "Sample", - "foo": "bar", - "compat_flag": True, - }, - } - ) + node_config = { + "id": "node-1", + "data": _SampleNodeData( + type=BuiltinNodeTypes.ANSWER, + title="Sample", + foo="bar", + compat_flag=True, + ), + } node = _SampleNode( - id="node-1", - config=node_config, + node_id="node-1", + config=node_config["data"], graph_init_params=init_params, graph_runtime_state=runtime_state, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py b/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py index 555ff0c9452..dd75b325935 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py @@ -4,23 +4,25 @@ from unittest.mock import Mock, patch import pandas as pd import pytest from docx.oxml.text.paragraph import CT_P + +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from graphon.entities import GraphInitParams from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus from graphon.file import File, FileTransferMethod from graphon.node_events import NodeRunResult from graphon.nodes.document_extractor import DocumentExtractorNode, DocumentExtractorNodeData +from graphon.nodes.document_extractor.exc import TextExtractionError, UnsupportedFileTypeError from graphon.nodes.document_extractor.node import ( _extract_text_from_docx, _extract_text_from_excel, + _extract_text_from_file, _extract_text_from_pdf, _extract_text_from_plain_text, _normalize_docx_zip, ) -from graphon.variables import ArrayFileSegment +from graphon.variables import ArrayFileSegment, FileSegment from graphon.variables.segments import ArrayStringSegment from graphon.variables.variables import StringVariable - -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from tests.workflow_test_utils import build_test_graph_init_params @@ -44,11 +46,10 @@ def document_extractor_node(graph_init_params): title="Test Document Extractor", variable_selector=["node_id", "variable_name"], ) - node_config = {"id": "test_node_id", "data": node_data.model_dump()} http_client = Mock() node = DocumentExtractorNode( - id="test_node_id", - config=node_config, + node_id="test_node_id", + config=node_data, graph_init_params=graph_init_params, graph_runtime_state=Mock(), http_client=http_client, @@ -341,7 +342,7 @@ def test_extract_text_from_excel_sheet_parse_error(mock_excel_file): # Mock ExcelFile mock_excel_instance = Mock() mock_excel_instance.sheet_names = ["GoodSheet", "BadSheet"] - mock_excel_instance.parse.side_effect = [df, Exception("Parse error")] + mock_excel_instance.parse.side_effect = [df, TypeError("Parse error")] mock_excel_file.return_value = mock_excel_instance file_content = b"fake_excel_mixed_content" @@ -386,7 +387,7 @@ def test_extract_text_from_excel_all_sheets_fail(mock_excel_file): # Mock ExcelFile mock_excel_instance = Mock() mock_excel_instance.sheet_names = ["BadSheet1", "BadSheet2"] - mock_excel_instance.parse.side_effect = [Exception("Error 1"), Exception("Error 2")] + mock_excel_instance.parse.side_effect = [TypeError("Error 1"), TypeError("Error 2")] mock_excel_file.return_value = mock_excel_instance file_content = b"fake_excel_all_bad_sheets" @@ -397,6 +398,12 @@ def test_extract_text_from_excel_all_sheets_fail(mock_excel_file): assert mock_excel_instance.parse.call_count == 2 +@patch("pandas.ExcelFile", side_effect=RuntimeError("broken workbook")) +def test_extract_text_from_excel_wraps_workbook_open_errors(mock_excel_file): + with pytest.raises(TextExtractionError, match="Failed to extract text from Excel file: broken workbook"): + _extract_text_from_excel(b"broken") + + @patch("pandas.ExcelFile") def test_extract_text_from_excel_numeric_type_column(mock_excel_file): """Test extracting text from Excel file with numeric column names.""" @@ -420,6 +427,103 @@ def test_extract_text_from_excel_numeric_type_column(mock_excel_file): assert expected_manual == result +@pytest.mark.parametrize( + ("extension", "mime_type"), + [ + (".xlsx", "text/plain"), + (None, "application/vnd.ms-excel"), + ], +) +def test_extract_text_from_file_routes_excel_inputs(document_extractor_node, extension, mime_type): + file = Mock(spec=File) + file.extension = extension + file.mime_type = mime_type + + with ( + patch( + "graphon.nodes.document_extractor.node._download_file_content", + return_value=b"excel", + ), + patch( + "graphon.nodes.document_extractor.node._extract_text_from_excel", + return_value="excel text", + ) as mock_extract, + ): + result = _extract_text_from_file( + document_extractor_node.http_client, + file, + unstructured_api_config=document_extractor_node._unstructured_api_config, + ) + + assert result == "excel text" + mock_extract.assert_called_once_with(b"excel") + + +def test_extract_text_from_file_rejects_missing_extension_and_mime_type(document_extractor_node): + file = Mock(spec=File) + file.extension = None + file.mime_type = None + + with patch( + "graphon.nodes.document_extractor.node._download_file_content", + return_value=b"unknown", + ): + with pytest.raises(UnsupportedFileTypeError, match="Unable to determine file type"): + _extract_text_from_file( + document_extractor_node.http_client, + file, + unstructured_api_config=document_extractor_node._unstructured_api_config, + ) + + +def test_run_list_file_extraction_error_returns_failed(document_extractor_node, mock_graph_runtime_state): + document_extractor_node.graph_runtime_state = mock_graph_runtime_state + file_list = Mock(spec=ArrayFileSegment) + file_list.value = [Mock(spec=File)] + mock_graph_runtime_state.variable_pool.get.return_value = file_list + + with patch( + "graphon.nodes.document_extractor.node._extract_text_from_file", + side_effect=TextExtractionError("bad file"), + ): + result = document_extractor_node._run() + + assert result.status == WorkflowNodeExecutionStatus.FAILED + assert result.error == "bad file" + + +def test_run_single_file_segment_extraction_error_returns_failed(document_extractor_node, mock_graph_runtime_state): + document_extractor_node.graph_runtime_state = mock_graph_runtime_state + file_segment = Mock(spec=FileSegment) + file_segment.value = Mock(spec=File) + mock_graph_runtime_state.variable_pool.get.return_value = file_segment + + with patch( + "graphon.nodes.document_extractor.node._extract_text_from_file", + side_effect=TextExtractionError("single file failed"), + ): + result = document_extractor_node._run() + + assert result.status == WorkflowNodeExecutionStatus.FAILED + assert result.error == "single file failed" + + +def test_run_single_file_segment_returns_string_output(document_extractor_node, mock_graph_runtime_state): + document_extractor_node.graph_runtime_state = mock_graph_runtime_state + file_segment = Mock(spec=FileSegment) + file_segment.value = Mock(spec=File) + mock_graph_runtime_state.variable_pool.get.return_value = file_segment + + with patch( + "graphon.nodes.document_extractor.node._extract_text_from_file", + return_value="single file text", + ): + result = document_extractor_node._run() + + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs == {"text": "single file text"} + + def _make_docx_zip(use_backslash: bool) -> bytes: """Helper to build a minimal in-memory DOCX zip. diff --git a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py index 1b14f0ab133..aa9a1360b0e 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py @@ -3,6 +3,11 @@ import uuid from unittest.mock import MagicMock, Mock import pytest + +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom +from core.workflow.node_factory import DifyNodeFactory +from core.workflow.system_variables import build_system_variables +from extensions.ext_database import db from graphon.enums import WorkflowNodeExecutionStatus from graphon.file import File, FileTransferMethod, FileType from graphon.graph import Graph @@ -11,14 +16,23 @@ from graphon.nodes.if_else.if_else_node import IfElseNode from graphon.runtime import GraphRuntimeState, VariablePool from graphon.utils.condition.entities import Condition, SubCondition, SubVariableCondition from graphon.variables import ArrayFileSegment - -from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom -from core.workflow.node_factory import DifyNodeFactory -from core.workflow.system_variables import build_system_variables -from extensions.ext_database import db from tests.workflow_test_utils import build_test_graph_init_params +def _build_if_else_node( + *, + node_data: IfElseNodeData | dict[str, object], + init_params, + graph_runtime_state, +) -> IfElseNode: + return IfElseNode( + node_id=str(uuid.uuid4()), + graph_init_params=init_params, + graph_runtime_state=graph_runtime_state, + config=node_data if isinstance(node_data, IfElseNodeData) else IfElseNodeData.model_validate(node_data), + ) + + def test_execute_if_else_result_true(): graph_config = {"edges": [], "nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}]} @@ -61,9 +75,8 @@ def test_execute_if_else_result_true(): ) graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") - node_config = { - "id": "if-else", - "data": { + node = _build_if_else_node( + node_data={ "title": "123", "type": "if-else", "logical_operator": "and", @@ -104,13 +117,8 @@ def test_execute_if_else_result_true(): {"comparison_operator": "not null", "variable_selector": ["start", "not_null"]}, ], }, - } - - node = IfElseNode( - id=str(uuid.uuid4()), - graph_init_params=init_params, + init_params=init_params, graph_runtime_state=graph_runtime_state, - config=node_config, ) # Mock db.session.close() @@ -155,9 +163,8 @@ def test_execute_if_else_result_false(): ) graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") - node_config = { - "id": "if-else", - "data": { + node = _build_if_else_node( + node_data={ "title": "123", "type": "if-else", "logical_operator": "or", @@ -174,13 +181,8 @@ def test_execute_if_else_result_false(): }, ], }, - } - - node = IfElseNode( - id=str(uuid.uuid4()), - graph_init_params=init_params, + init_params=init_params, graph_runtime_state=graph_runtime_state, - config=node_config, ) # Mock db.session.close() @@ -222,11 +224,6 @@ def test_array_file_contains_file_name(): ], ) - node_config = { - "id": "if-else", - "data": node_data.model_dump(), - } - # Create properly configured mock for graph_init_params graph_init_params = Mock() graph_init_params.workflow_id = "test_workflow" @@ -242,17 +239,12 @@ def test_array_file_contains_file_name(): } } - node = IfElseNode( - id=str(uuid.uuid4()), - graph_init_params=graph_init_params, - graph_runtime_state=Mock(), - config=node_config, - ) + node = _build_if_else_node(node_data=node_data, init_params=graph_init_params, graph_runtime_state=Mock()) node.graph_runtime_state.variable_pool.get.return_value = ArrayFileSegment( value=[ File( - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="1", filename="ab", @@ -334,11 +326,10 @@ def test_execute_if_else_boolean_conditions(condition: Condition): "logical_operator": "and", "conditions": [condition.model_dump()], } - node = IfElseNode( - id=str(uuid.uuid4()), - graph_init_params=init_params, + node = _build_if_else_node( + node_data=node_data, + init_params=init_params, graph_runtime_state=graph_runtime_state, - config={"id": "if-else", "data": node_data}, ) # Mock db.session.close() @@ -400,14 +391,10 @@ def test_execute_if_else_boolean_false_conditions(): ], } - node = IfElseNode( - id=str(uuid.uuid4()), - graph_init_params=init_params, + node = _build_if_else_node( + node_data=node_data, + init_params=init_params, graph_runtime_state=graph_runtime_state, - config={ - "id": "if-else", - "data": node_data, - }, ) # Mock db.session.close() @@ -472,11 +459,10 @@ def test_execute_if_else_boolean_cases_structure(): } ], } - node = IfElseNode( - id=str(uuid.uuid4()), - graph_init_params=init_params, + node = _build_if_else_node( + node_data=node_data, + init_params=init_params, graph_runtime_state=graph_runtime_state, - config={"id": "if-else", "data": node_data}, ) # Mock db.session.close() diff --git a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py index d28c3e01e5f..465a4c0ff4e 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py @@ -1,6 +1,8 @@ from unittest.mock import MagicMock import pytest + +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom from graphon.enums import WorkflowNodeExecutionStatus from graphon.file import File, FileTransferMethod, FileType from graphon.nodes.list_operator.entities import ( @@ -16,7 +18,14 @@ from graphon.nodes.list_operator.exc import InvalidKeyError from graphon.nodes.list_operator.node import ListOperatorNode, _get_file_extract_string_func from graphon.variables import ArrayFileSegment -from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom + +def _build_list_operator_node(node_data: ListOperatorNodeData, graph_init_params) -> ListOperatorNode: + return ListOperatorNode( + node_id="test_node_id", + config=node_data, + graph_init_params=graph_init_params, + graph_runtime_state=MagicMock(), + ) @pytest.fixture @@ -35,10 +44,6 @@ def list_operator_node(): "title": "Test Title", } node_data = ListOperatorNodeData.model_validate(config) - node_config = { - "id": "test_node_id", - "data": node_data.model_dump(), - } # Create properly configured mock for graph_init_params graph_init_params = MagicMock() graph_init_params.workflow_id = "test_workflow" @@ -54,12 +59,7 @@ def list_operator_node(): } } - node = ListOperatorNode( - id="test_node_id", - config=node_config, - graph_init_params=graph_init_params, - graph_runtime_state=MagicMock(), - ) + node = _build_list_operator_node(node_data, graph_init_params) node.graph_runtime_state = MagicMock() node.graph_runtime_state.variable_pool = MagicMock() return node @@ -70,28 +70,28 @@ def test_filter_files_by_type(list_operator_node): files = [ File( filename="image1.jpg", - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="related1", storage_key="", ), File( filename="document1.pdf", - type=FileType.DOCUMENT, + file_type=FileType.DOCUMENT, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="related2", storage_key="", ), File( filename="image2.png", - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="related3", storage_key="", ), File( filename="audio1.mp3", - type=FileType.AUDIO, + file_type=FileType.AUDIO, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="related4", storage_key="", @@ -136,7 +136,7 @@ def test_filter_files_by_type(list_operator_node): def test_get_file_extract_string_func(): # Create a File object file = File( - type=FileType.DOCUMENT, + file_type=FileType.DOCUMENT, transfer_method=FileTransferMethod.LOCAL_FILE, filename="test_file.txt", extension=".txt", @@ -156,7 +156,7 @@ def test_get_file_extract_string_func(): # Test with empty values empty_file = File( - type=FileType.DOCUMENT, + file_type=FileType.DOCUMENT, transfer_method=FileTransferMethod.LOCAL_FILE, filename=None, extension=None, diff --git a/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py b/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py index 833c3030521..5655f807379 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py @@ -2,16 +2,16 @@ import json import time import pytest +from pydantic import ValidationError as PydanticValidationError + +from core.workflow.system_variables import build_system_variables +from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID from graphon.nodes.start.entities import StartNodeData from graphon.nodes.start.start_node import StartNode from graphon.runtime import GraphRuntimeState from graphon.variables import build_segment, segment_to_variable from graphon.variables.input_entities import VariableEntity, VariableEntityType from graphon.variables.variables import Variable -from pydantic import ValidationError as PydanticValidationError - -from core.workflow.system_variables import build_system_variables -from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID from tests.workflow_test_utils import build_test_graph_init_params, build_test_variable_pool @@ -22,10 +22,7 @@ def make_start_node(user_inputs, variables): inputs=user_inputs, ) - config = { - "id": "start", - "data": StartNodeData(title="Start", variables=variables).model_dump(), - } + node_data = StartNodeData(title="Start", variables=variables) graph_runtime_state = GraphRuntimeState( variable_pool=variable_pool, @@ -33,8 +30,8 @@ def make_start_node(user_inputs, variables): ) return StartNode( - id="start", - config=config, + node_id="start", + config=node_data, graph_init_params=build_test_graph_init_params( workflow_id="wf", graph_config={}, @@ -109,7 +106,7 @@ def test_json_object_invalid_json_string(): node = make_start_node(user_inputs, variables) - with pytest.raises(ValueError, match="JSON object for 'profile' must be an object"): + with pytest.raises(TypeError, match="JSON object for 'profile' must be an object"): node._run() @@ -248,25 +245,22 @@ def test_start_node_outputs_full_variable_pool_snapshot(): inputs={"profile": {"age": 20, "name": "Tom"}}, ) - config = { - "id": "start", - "data": StartNodeData( - title="Start", - variables=[ - VariableEntity( - variable="profile", - label="profile", - type=VariableEntityType.JSON_OBJECT, - required=True, - ) - ], - ).model_dump(), - } + node_data = StartNodeData( + title="Start", + variables=[ + VariableEntity( + variable="profile", + label="profile", + type=VariableEntityType.JSON_OBJECT, + required=True, + ) + ], + ) graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) node = StartNode( - id="start", - config=config, + node_id="start", + config=node_data, graph_init_params=build_test_graph_init_params( workflow_id="wf", graph_config={}, diff --git a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py index 15870148027..284af683196 100644 --- a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py @@ -8,14 +8,15 @@ from typing import TYPE_CHECKING, Any from unittest.mock import MagicMock import pytest + +from core.workflow.system_variables import build_system_variables from graphon.file import File, FileTransferMethod, FileType from graphon.model_runtime.entities.llm_entities import LLMUsage from graphon.node_events import StreamChunkEvent, StreamCompletedEvent +from graphon.nodes.tool.entities import ToolNodeData from graphon.nodes.tool_runtime_entities import ToolRuntimeHandle, ToolRuntimeMessage from graphon.runtime import GraphRuntimeState, VariablePool from graphon.variables.segments import ArrayFileSegment - -from core.workflow.system_variables import build_system_variables from tests.workflow_test_utils import build_test_graph_init_params if TYPE_CHECKING: # pragma: no cover - imported for type checking only @@ -108,8 +109,8 @@ def tool_node(monkeypatch) -> ToolNode: runtime = _StubToolRuntime() node = ToolNode( - id="node-instance", - config=config, + node_id="node-instance", + config=ToolNodeData.model_validate(config["data"]), graph_init_params=init_params, graph_runtime_state=graph_runtime_state, tool_file_manager_factory=tool_file_manager_factory, @@ -118,13 +119,13 @@ def tool_node(monkeypatch) -> ToolNode: return node -def _collect_events(generator: Generator) -> tuple[list[Any], LLMUsage]: +def _collect_events(generator: Generator) -> list[Any]: events: list[Any] = [] try: while True: events.append(next(generator)) - except StopIteration as stop: - return events, stop.value + except StopIteration: + return events def _run_transform(tool_node: ToolNode, message: ToolRuntimeMessage) -> tuple[list[Any], LLMUsage]: @@ -135,12 +136,15 @@ def _run_transform(tool_node: ToolNode, message: ToolRuntimeMessage) -> tuple[li node_id=tool_node._node_id, tool_runtime=ToolRuntimeHandle(raw=object()), ) - return _collect_events(generator) + events = _collect_events(generator) + completed_events = [event for event in events if isinstance(event, StreamCompletedEvent)] + assert completed_events + return events, completed_events[-1].node_run_result.llm_usage def test_link_messages_with_file_populate_files_output(tool_node: ToolNode): file_obj = File( - type=FileType.DOCUMENT, + file_type=FileType.DOCUMENT, transfer_method=FileTransferMethod.TOOL_FILE, related_id="file-id", filename="demo.pdf", @@ -195,7 +199,7 @@ def test_plain_link_messages_remain_links(tool_node: ToolNode): def test_image_link_messages_use_tool_file_id_metadata(tool_node: ToolNode): file_obj = File( - type=FileType.DOCUMENT, + file_type=FileType.DOCUMENT, transfer_method=FileTransferMethod.TOOL_FILE, related_id="file-id", filename="demo.pdf", diff --git a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node_runtime.py b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node_runtime.py index c4dfc5a1792..438af211f36 100644 --- a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node_runtime.py +++ b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node_runtime.py @@ -6,11 +6,6 @@ from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest -from graphon.model_runtime.entities.llm_entities import LLMUsage -from graphon.nodes.tool.entities import ToolNodeData, ToolProviderType -from graphon.nodes.tool.exc import ToolRuntimeInvocationError -from graphon.nodes.tool_runtime_entities import ToolRuntimeHandle, ToolRuntimeMessage -from graphon.runtime import VariablePool from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler from core.plugin.impl.exc import PluginDaemonClientSideError, PluginInvokeError @@ -22,6 +17,11 @@ from core.tools.tool_manager import ToolManager from core.tools.utils.message_transformer import ToolFileMessageTransformer from core.workflow.node_runtime import DifyToolNodeRuntime from core.workflow.system_variables import build_system_variables +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.nodes.tool.entities import ToolNodeData, ToolProviderType +from graphon.nodes.tool.exc import ToolRuntimeInvocationError +from graphon.nodes.tool_runtime_entities import ToolRuntimeHandle, ToolRuntimeMessage +from graphon.runtime import VariablePool from tests.workflow_test_utils import build_test_graph_init_params, build_test_variable_pool diff --git a/api/tests/unit_tests/core/workflow/nodes/trigger_plugin/test_trigger_event_node.py b/api/tests/unit_tests/core/workflow/nodes/trigger_plugin/test_trigger_event_node.py index 952e798430f..e3b5e3b5915 100644 --- a/api/tests/unit_tests/core/workflow/nodes/trigger_plugin/test_trigger_event_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/trigger_plugin/test_trigger_event_node.py @@ -1,13 +1,12 @@ from collections.abc import Mapping -from graphon.entities import GraphInitParams -from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter -from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus -from graphon.runtime import GraphRuntimeState - from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE +from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData from core.workflow.nodes.trigger_plugin.trigger_event_node import TriggerEventNode from core.workflow.system_variables import build_system_variables +from graphon.entities import GraphInitParams +from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from graphon.runtime import GraphRuntimeState from tests.workflow_test_utils import build_test_graph_init_params, build_test_variable_pool @@ -28,29 +27,24 @@ def _build_context(graph_config: Mapping[str, object]) -> tuple[GraphInitParams, return init_params, runtime_state -def _build_node_config() -> NodeConfigDict: - return NodeConfigDictAdapter.validate_python( - { - "id": "node-1", - "data": { - "type": TRIGGER_PLUGIN_NODE_TYPE, - "title": "Trigger Event", - "plugin_id": "plugin-id", - "provider_id": "provider-id", - "event_name": "event-name", - "subscription_id": "subscription-id", - "plugin_unique_identifier": "plugin-unique-identifier", - "event_parameters": {}, - }, - } +def _build_node_data() -> TriggerEventNodeData: + return TriggerEventNodeData( + type=TRIGGER_PLUGIN_NODE_TYPE, + title="Trigger Event", + plugin_id="plugin-id", + provider_id="provider-id", + event_name="event-name", + subscription_id="subscription-id", + plugin_unique_identifier="plugin-unique-identifier", + event_parameters={}, ) def test_trigger_event_node_run_populates_trigger_info_metadata() -> None: init_params, runtime_state = _build_context(graph_config={}) node = TriggerEventNode( - id="node-1", - config=_build_node_config(), + node_id="node-1", + config=_build_node_data(), graph_init_params=init_params, graph_runtime_state=runtime_state, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_exceptions.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_exceptions.py index f1132af02b5..617554ee179 100644 --- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_exceptions.py +++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_exceptions.py @@ -1,5 +1,4 @@ import pytest -from graphon.entities.exc import BaseNodeError from core.workflow.nodes.trigger_webhook.exc import ( WebhookConfigError, @@ -7,6 +6,7 @@ from core.workflow.nodes.trigger_webhook.exc import ( WebhookNotFoundError, WebhookTimeoutError, ) +from graphon.entities.exc import BaseNodeError def test_webhook_node_error_inheritance(): diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py index cccd3fb6767..07d03bec058 100644 --- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py +++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py @@ -6,12 +6,9 @@ to FileVariable objects, fixing the "Invalid variable type: ObjectVariable" erro when passing files to downstream LLM nodes. """ +from typing import Any from unittest.mock import Mock, patch -from graphon.entities import GraphInitParams -from graphon.enums import WorkflowNodeExecutionStatus -from graphon.runtime import GraphRuntimeState, VariablePool - from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom from core.workflow.nodes.trigger_webhook.entities import ( ContentType, @@ -21,6 +18,9 @@ from core.workflow.nodes.trigger_webhook.entities import ( ) from core.workflow.nodes.trigger_webhook.node import TriggerWebhookNode from core.workflow.system_variables import default_system_variables +from graphon.entities import GraphInitParams +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.runtime import GraphRuntimeState, VariablePool from tests.workflow_test_utils import build_test_variable_pool @@ -30,11 +30,6 @@ def create_webhook_node( tenant_id: str = "test-tenant", ) -> TriggerWebhookNode: """Helper function to create a webhook node with proper initialization.""" - node_config = { - "id": "webhook-node-1", - "data": webhook_data.model_dump(), - } - graph_init_params = GraphInitParams( workflow_id="test-workflow", graph_config={}, @@ -56,8 +51,8 @@ def create_webhook_node( ) node = TriggerWebhookNode( - id="webhook-node-1", - config=node_config, + node_id="webhook-node-1", + config=webhook_data, graph_init_params=graph_init_params, graph_runtime_state=runtime_state, ) @@ -66,10 +61,6 @@ def create_webhook_node( runtime_state.app_config = Mock() runtime_state.app_config.tenant_id = tenant_id - # Provide compatibility alias expected by node implementation - # Some nodes reference `self.node_id`; expose it as an alias to `self.id` for tests - node.node_id = node.id - return node @@ -97,7 +88,7 @@ def create_test_file_dict( } -def build_webhook_variable_pool(inputs: dict) -> VariablePool: +def build_webhook_variable_pool(inputs: dict[str, Any]) -> VariablePool: return build_test_variable_pool( variables=default_system_variables(), node_id="webhook-node-1", @@ -105,7 +96,7 @@ def build_webhook_variable_pool(inputs: dict) -> VariablePool: ) -def expected_factory_mapping(file_dict: dict) -> dict: +def expected_factory_mapping(file_dict: dict[str, Any]) -> dict[str, Any]: return {**file_dict, "upload_file_id": file_dict["related_id"]} diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py index 34c66a4f9f3..b839490d3c7 100644 --- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py @@ -1,11 +1,7 @@ +from typing import Any from unittest.mock import patch import pytest -from graphon.entities import GraphInitParams -from graphon.enums import WorkflowNodeExecutionStatus -from graphon.file import File, FileTransferMethod, FileType -from graphon.runtime import GraphRuntimeState, VariablePool -from graphon.variables import FileVariable, StringVariable from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom from core.trigger.constants import TRIGGER_WEBHOOK_NODE_TYPE @@ -18,16 +14,16 @@ from core.workflow.nodes.trigger_webhook.entities import ( ) from core.workflow.nodes.trigger_webhook.node import TriggerWebhookNode from core.workflow.system_variables import default_system_variables +from graphon.entities import GraphInitParams +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.file import File, FileTransferMethod, FileType +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.variables import FileVariable, StringVariable from tests.workflow_test_utils import build_test_variable_pool def create_webhook_node(webhook_data: WebhookData, variable_pool: VariablePool) -> TriggerWebhookNode: """Helper function to create a webhook node with proper initialization.""" - node_config = { - "id": "1", - "data": webhook_data.model_dump(), - } - graph_init_params = GraphInitParams( workflow_id="1", graph_config={}, @@ -47,8 +43,8 @@ def create_webhook_node(webhook_data: WebhookData, variable_pool: VariablePool) start_at=0, ) node = TriggerWebhookNode( - id="1", - config=node_config, + node_id="1", + config=webhook_data, graph_init_params=graph_init_params, graph_runtime_state=runtime_state, ) @@ -56,13 +52,10 @@ def create_webhook_node(webhook_data: WebhookData, variable_pool: VariablePool) # Provide tenant_id for conversion path runtime_state.app_config = type("_AppCfg", (), {"tenant_id": "1"})() - # Compatibility alias for some nodes referencing `self.node_id` - node.node_id = node.id - return node -def build_webhook_variable_pool(inputs: dict) -> VariablePool: +def build_webhook_variable_pool(inputs: dict[str, Any]) -> VariablePool: return build_test_variable_pool( variables=default_system_variables(), node_id="1", @@ -224,7 +217,7 @@ def test_webhook_node_run_with_file_params(): """Test webhook node execution with file parameter extraction.""" # Create mock file objects file1 = File( - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="file1", filename="image.jpg", @@ -233,7 +226,7 @@ def test_webhook_node_run_with_file_params(): ) file2 = File( - type=FileType.DOCUMENT, + file_type=FileType.DOCUMENT, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="file2", filename="document.pdf", @@ -268,8 +261,19 @@ def test_webhook_node_run_with_file_params(): # Mock the node's file reference boundary to avoid DB-dependent validation on upload_file_id with patch.object(node._file_reference_factory, "build_from_mapping") as mock_file_factory: - def _to_file(*, mapping): - return File.model_validate(mapping) + def _to_file(*, mapping: dict[str, Any]) -> File: + return File( + file_id=mapping.get("id"), + file_type=FileType(mapping["type"]), + transfer_method=FileTransferMethod(mapping["transfer_method"]), + related_id=mapping.get("related_id"), + filename=mapping.get("filename"), + extension=mapping.get("extension"), + mime_type=mapping.get("mime_type"), + size=mapping.get("size", -1), + storage_key=mapping.get("storage_key", ""), + remote_url=mapping.get("url"), + ) mock_file_factory.side_effect = _to_file result = node._run() @@ -283,7 +287,7 @@ def test_webhook_node_run_with_file_params(): def test_webhook_node_run_mixed_parameters(): """Test webhook node execution with mixed parameter types.""" file_obj = File( - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="file1", filename="test.jpg", @@ -316,8 +320,19 @@ def test_webhook_node_run_mixed_parameters(): # Mock the node's file reference boundary to avoid DB-dependent validation on upload_file_id with patch.object(node._file_reference_factory, "build_from_mapping") as mock_file_factory: - def _to_file(*, mapping): - return File.model_validate(mapping) + def _to_file(*, mapping: dict[str, Any]) -> File: + return File( + file_id=mapping.get("id"), + file_type=FileType(mapping["type"]), + transfer_method=FileTransferMethod(mapping["transfer_method"]), + related_id=mapping.get("related_id"), + filename=mapping.get("filename"), + extension=mapping.get("extension"), + mime_type=mapping.get("mime_type"), + size=mapping.get("size", -1), + storage_key=mapping.get("storage_key", ""), + remote_url=mapping.get("url"), + ) mock_file_factory.side_effect = _to_file result = node._run() diff --git a/api/tests/unit_tests/core/workflow/test_human_input_adapter.py b/api/tests/unit_tests/core/workflow/test_human_input_adapter.py new file mode 100644 index 00000000000..8b5fceeb379 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/test_human_input_adapter.py @@ -0,0 +1,350 @@ +from types import SimpleNamespace + +import pytest +from pydantic import BaseModel + +from core.workflow.human_input_adapter import ( + DeliveryMethodType, + EmailDeliveryConfig, + EmailDeliveryMethod, + EmailRecipients, + WebAppDeliveryMethod, + _WebAppDeliveryConfig, + adapt_human_input_node_data_for_graph, + adapt_node_config_for_graph, + adapt_node_data_for_graph, + is_human_input_webapp_enabled, + parse_human_input_delivery_methods, +) +from graphon.enums import BuiltinNodeTypes +from graphon.nodes.base.variable_template_parser import VariableTemplateParser + + +def test_email_delivery_config_helpers_render_and_sanitize_text() -> None: + variable_pool = SimpleNamespace( + convert_template=lambda body: SimpleNamespace(text=body.replace("{{#node.value#}}", "42")) + ) + + rendered = EmailDeliveryConfig.render_body_template( + body="Open {{#url#}} and use {{#node.value#}}", + url="https://example.com", + variable_pool=variable_pool, + ) + sanitized = EmailDeliveryConfig.sanitize_subject("Hello\r\n Team") + html = EmailDeliveryConfig.render_markdown_body( + "**Hello** [mail](mailto:test@example.com)" + ) + + assert rendered == "Open https://example.com and use 42" + assert sanitized == "Hello alert(1) Team" + assert "Hello" in html + assert " Team") - html = EmailDeliveryConfig.render_markdown_body( - "**Hello** [mail](mailto:test@example.com)" - ) - - assert rendered == "Open https://example.com and use 42" - assert sanitized == "Hello alert(1) Team" - assert "Hello" in html - assert " + ) +} + +export default memo(CreateAppAttributionBootstrap) diff --git a/web/app/components/custom/custom-page/__tests__/index.spec.tsx b/web/app/components/custom/custom-page/__tests__/index.spec.tsx index 5514d21b502..1f3655a9f81 100644 --- a/web/app/components/custom/custom-page/__tests__/index.spec.tsx +++ b/web/app/components/custom/custom-page/__tests__/index.spec.tsx @@ -1,9 +1,10 @@ +import type { ReactElement } from 'react' import type { AppContextValue } from '@/context/app-context' -import type { SystemFeatures } from '@/types/feature' -import { render, screen } from '@testing-library/react' +import { screen } from '@testing-library/react' import userEvent from '@testing-library/user-event' import { beforeEach, describe, expect, it, vi } from 'vitest' import { createMockProviderContextValue } from '@/__mocks__/provider-context' +import { renderWithSystemFeatures } from '@/__tests__/utils/mock-system-features' import { contactSalesUrl, defaultPlan } from '@/app/components/billing/config' import { Plan } from '@/app/components/billing/type' import { @@ -12,12 +13,19 @@ import { useAppContext, userProfilePlaceholder, } from '@/context/app-context' -import { useGlobalPublicStore } from '@/context/global-public-context' import { useModalContext } from '@/context/modal-context' import { useProviderContext } from '@/context/provider-context' -import { defaultSystemFeatures } from '@/types/feature' import CustomPage from '../index' +const render = (ui: ReactElement) => renderWithSystemFeatures(ui, { + systemFeatures: { + branding: { + enabled: true, + workspace_logo: 'https://example.com/workspace-logo.png', + }, + }, +}) + const { mockToast } = vi.hoisted(() => { const mockToast = Object.assign(vi.fn(), { success: vi.fn(), @@ -44,17 +52,13 @@ vi.mock('@/context/app-context', async (importOriginal) => { useAppContext: vi.fn(), } }) -vi.mock('@/context/global-public-context', () => ({ - useGlobalPublicStore: vi.fn(), -})) -vi.mock('@/app/components/base/ui/toast', () => ({ +vi.mock('@langgenius/dify-ui/toast', () => ({ toast: mockToast, })) const mockUseProviderContext = vi.mocked(useProviderContext) const mockUseModalContext = vi.mocked(useModalContext) const mockUseAppContext = vi.mocked(useAppContext) -const mockUseGlobalPublicStore = vi.mocked(useGlobalPublicStore) const createProviderContext = ({ enableBilling = false, @@ -93,15 +97,6 @@ const createAppContextValue = (): AppContextValue => ({ isValidatingCurrentWorkspace: false, }) -const createSystemFeatures = (): SystemFeatures => ({ - ...defaultSystemFeatures, - branding: { - ...defaultSystemFeatures.branding, - enabled: true, - workspace_logo: 'https://example.com/workspace-logo.png', - }, -}) - describe('CustomPage', () => { const setShowPricingModal = vi.fn() @@ -113,10 +108,6 @@ describe('CustomPage', () => { setShowPricingModal, } as unknown as ReturnType) mockUseAppContext.mockReturnValue(createAppContextValue()) - mockUseGlobalPublicStore.mockImplementation(selector => selector({ - systemFeatures: createSystemFeatures(), - setSystemFeatures: vi.fn(), - })) }) // Integration coverage for the page and its child custom brand section. diff --git a/web/app/components/custom/custom-page/index.tsx b/web/app/components/custom/custom-page/index.tsx index 840c2bb7c4c..cd85a912303 100644 --- a/web/app/components/custom/custom-page/index.tsx +++ b/web/app/components/custom/custom-page/index.tsx @@ -15,12 +15,12 @@ const CustomPage = () => { return (
{showBillingTip && ( -
+
{t('upgradeTip.title', { ns: 'custom' })}
{t('upgradeTip.des', { ns: 'custom' })}
-
setShowPricingModal()}>{t('upgradeBtn.encourageShort', { ns: 'billing' })}
+
setShowPricingModal()}>{t('upgradeBtn.encourageShort', { ns: 'billing' })}
)} diff --git a/web/app/components/custom/custom-web-app-brand/components/chat-preview-card.tsx b/web/app/components/custom/custom-web-app-brand/components/chat-preview-card.tsx index 5700a04e41a..1ea03c7bc68 100644 --- a/web/app/components/custom/custom-web-app-brand/components/chat-preview-card.tsx +++ b/web/app/components/custom/custom-web-app-brand/components/chat-preview-card.tsx @@ -1,5 +1,5 @@ -import Button from '@/app/components/base/button' -import { cn } from '@/utils/classnames' +import { Button } from '@langgenius/dify-ui/button' +import { cn } from '@langgenius/dify-ui/cn' import PoweredByBrand from './powered-by-brand' type ChatPreviewCardProps = { @@ -22,28 +22,28 @@ const ChatPreviewCard = ({
-
Chatflow App
+
Chatflow App
-
+
-
+
-
+
@@ -61,14 +61,14 @@ const ChatPreviewCard = ({
-
+
-
Hello! How can I assist you today?
+
Hello! How can I assist you today?
-
Talk to Dify
+
Talk to Dify
diff --git a/web/app/components/custom/custom-web-app-brand/components/powered-by-brand.tsx b/web/app/components/custom/custom-web-app-brand/components/powered-by-brand.tsx index 8a0feffbc45..51c0db53d75 100644 --- a/web/app/components/custom/custom-web-app-brand/components/powered-by-brand.tsx +++ b/web/app/components/custom/custom-web-app-brand/components/powered-by-brand.tsx @@ -20,7 +20,7 @@ const PoweredByBrand = ({ return ( <> -
POWERED BY
+
POWERED BY
{previewLogo ? logo : } diff --git a/web/app/components/custom/custom-web-app-brand/components/workflow-preview-card.tsx b/web/app/components/custom/custom-web-app-brand/components/workflow-preview-card.tsx index 276f77ce711..b4f30827bc2 100644 --- a/web/app/components/custom/custom-web-app-brand/components/workflow-preview-card.tsx +++ b/web/app/components/custom/custom-web-app-brand/components/workflow-preview-card.tsx @@ -1,5 +1,5 @@ -import Button from '@/app/components/base/button' -import { cn } from '@/utils/classnames' +import { Button } from '@langgenius/dify-ui/button' +import { cn } from '@langgenius/dify-ui/cn' import PoweredByBrand from './powered-by-brand' type WorkflowPreviewCardProps = { @@ -22,29 +22,29 @@ const WorkflowPreviewCard = ({
-
Workflow App
+
Workflow App
-
RUN ONCE
-
RUN BATCH
+
RUN ONCE
+
RUN BATCH
-
+
diff --git a/web/app/components/custom/custom-web-app-brand/hooks/__tests__/use-web-app-brand.spec.tsx b/web/app/components/custom/custom-web-app-brand/hooks/__tests__/use-web-app-brand.spec.tsx index a2730ffd27a..99cbc03b329 100644 --- a/web/app/components/custom/custom-web-app-brand/hooks/__tests__/use-web-app-brand.spec.tsx +++ b/web/app/components/custom/custom-web-app-brand/hooks/__tests__/use-web-app-brand.spec.tsx @@ -1,9 +1,10 @@ import type { ChangeEvent } from 'react' import type { AppContextValue } from '@/context/app-context' import type { SystemFeatures } from '@/types/feature' -import { act, renderHook } from '@testing-library/react' +import { act } from '@testing-library/react' import { beforeEach, describe, expect, it, vi } from 'vitest' import { createMockProviderContextValue } from '@/__mocks__/provider-context' +import { renderHookWithSystemFeatures } from '@/__tests__/utils/mock-system-features' import { getImageUploadErrorMessage, imageUpload } from '@/app/components/base/image-uploader/utils' import { defaultPlan } from '@/app/components/billing/config' import { Plan } from '@/app/components/billing/type' @@ -13,12 +14,22 @@ import { useAppContext, userProfilePlaceholder, } from '@/context/app-context' -import { useGlobalPublicStore } from '@/context/global-public-context' import { useProviderContext } from '@/context/provider-context' import { updateCurrentWorkspace } from '@/service/common' -import { defaultSystemFeatures } from '@/types/feature' import useWebAppBrand from '../use-web-app-brand' +let currentBrandingOverrides: Partial = {} +const renderHook = (callback: (props: Props) => Result) => + renderHookWithSystemFeatures(callback, { + systemFeatures: { + branding: { + enabled: true, + workspace_logo: 'https://example.com/workspace-logo.png', + ...currentBrandingOverrides, + }, + }, + }) + const { mockNotify, mockToast } = vi.hoisted(() => { const mockNotify = vi.fn() const mockToast = Object.assign(mockNotify, { @@ -33,7 +44,7 @@ const { mockNotify, mockToast } = vi.hoisted(() => { return { mockNotify, mockToast } }) -vi.mock('@/app/components/base/ui/toast', () => ({ +vi.mock('@langgenius/dify-ui/toast', () => ({ toast: mockToast, })) vi.mock('@/service/common', () => ({ @@ -49,9 +60,6 @@ vi.mock('@/context/app-context', async (importOriginal) => { vi.mock('@/context/provider-context', () => ({ useProviderContext: vi.fn(), })) -vi.mock('@/context/global-public-context', () => ({ - useGlobalPublicStore: vi.fn(), -})) vi.mock('@/app/components/base/image-uploader/utils', () => ({ imageUpload: vi.fn(), getImageUploadErrorMessage: vi.fn(), @@ -60,7 +68,6 @@ vi.mock('@/app/components/base/image-uploader/utils', () => ({ const mockUpdateCurrentWorkspace = vi.mocked(updateCurrentWorkspace) const mockUseAppContext = vi.mocked(useAppContext) const mockUseProviderContext = vi.mocked(useProviderContext) -const mockUseGlobalPublicStore = vi.mocked(useGlobalPublicStore) const mockImageUpload = vi.mocked(imageUpload) const mockGetImageUploadErrorMessage = vi.mocked(getImageUploadErrorMessage) @@ -80,16 +87,6 @@ const createProviderContext = ({ }) } -const createSystemFeatures = (brandingOverrides: Partial = {}): SystemFeatures => ({ - ...defaultSystemFeatures, - branding: { - ...defaultSystemFeatures.branding, - enabled: true, - workspace_logo: 'https://example.com/workspace-logo.png', - ...brandingOverrides, - }, -}) - const createAppContextValue = (overrides: Partial = {}): AppContextValue => { const { currentWorkspace: currentWorkspaceOverride, ...restOverrides } = overrides const workspaceOverrides: Partial = currentWorkspaceOverride ?? {} @@ -122,21 +119,16 @@ const createAppContextValue = (overrides: Partial = {}): AppCon describe('useWebAppBrand', () => { let appContextValue: AppContextValue - let systemFeatures: SystemFeatures beforeEach(() => { vi.clearAllMocks() appContextValue = createAppContextValue() - systemFeatures = createSystemFeatures() + currentBrandingOverrides = {} mockUpdateCurrentWorkspace.mockResolvedValue(appContextValue.currentWorkspace) mockUseAppContext.mockImplementation(() => appContextValue) mockUseProviderContext.mockReturnValue(createProviderContext()) - mockUseGlobalPublicStore.mockImplementation(selector => selector({ - systemFeatures, - setSystemFeatures: vi.fn(), - })) mockGetImageUploadErrorMessage.mockReturnValue('upload error') }) @@ -174,10 +166,7 @@ describe('useWebAppBrand', () => { }) it('should fall back to an empty workspace logo when branding is disabled', () => { - systemFeatures = createSystemFeatures({ - enabled: false, - workspace_logo: '', - }) + currentBrandingOverrides = { enabled: false, workspace_logo: '' } const { result } = renderHook(() => useWebAppBrand()) diff --git a/web/app/components/custom/custom-web-app-brand/hooks/use-web-app-brand.ts b/web/app/components/custom/custom-web-app-brand/hooks/use-web-app-brand.ts index 44810de1fda..e24edab4212 100644 --- a/web/app/components/custom/custom-web-app-brand/hooks/use-web-app-brand.ts +++ b/web/app/components/custom/custom-web-app-brand/hooks/use-web-app-brand.ts @@ -1,13 +1,14 @@ import type { ChangeEvent } from 'react' +import { toast } from '@langgenius/dify-ui/toast' +import { useSuspenseQuery } from '@tanstack/react-query' import { useState } from 'react' import { useTranslation } from 'react-i18next' import { getImageUploadErrorMessage, imageUpload } from '@/app/components/base/image-uploader/utils' -import { toast } from '@/app/components/base/ui/toast' import { Plan } from '@/app/components/billing/type' import { useAppContext } from '@/context/app-context' -import { useGlobalPublicStore } from '@/context/global-public-context' import { useProviderContext } from '@/context/provider-context' import { updateCurrentWorkspace } from '@/service/common' +import { systemFeaturesQueryOptions } from '@/service/system-features' const MAX_LOGO_FILE_SIZE = 5 * 1024 * 1024 const CUSTOM_CONFIG_URL = '/workspaces/custom-config' @@ -19,7 +20,7 @@ const useWebAppBrand = () => { const [fileId, setFileId] = useState('') const [imgKey, setImgKey] = useState(() => Date.now()) const [uploadProgress, setUploadProgress] = useState(0) - const systemFeatures = useGlobalPublicStore(s => s.systemFeatures) + const { data: systemFeatures } = useSuspenseQuery(systemFeaturesQueryOptions()) const isSandbox = enableBilling && plan.type === Plan.sandbox const uploading = uploadProgress > 0 && uploadProgress < 100 const webappLogo = currentWorkspace.custom_config?.replace_webapp_logo || '' diff --git a/web/app/components/custom/custom-web-app-brand/index.tsx b/web/app/components/custom/custom-web-app-brand/index.tsx index 02a6419f188..53e6ff36ec0 100644 --- a/web/app/components/custom/custom-web-app-brand/index.tsx +++ b/web/app/components/custom/custom-web-app-brand/index.tsx @@ -1,8 +1,8 @@ +import { Button } from '@langgenius/dify-ui/button' +import { cn } from '@langgenius/dify-ui/cn' +import { Switch } from '@langgenius/dify-ui/switch' import { useTranslation } from 'react-i18next' -import Button from '@/app/components/base/button' import Divider from '@/app/components/base/divider' -import Switch from '@/app/components/base/switch' -import { cn } from '@/utils/classnames' import ChatPreviewCard from './components/chat-preview-card' import WorkflowPreviewCard from './components/workflow-preview-card' import useWebAppBrand from './hooks/use-web-app-brand' @@ -31,19 +31,19 @@ const CustomWebAppBrand = () => { return (
-
+
{t('webapp.removeBrand', { ns: 'custom' })}
-
{t('webapp.changeLogo', { ns: 'custom' })}
-
{t('webapp.changeLogoTip', { ns: 'custom' })}
+
{t('webapp.changeLogo', { ns: 'custom' })}
+
{t('webapp.changeLogoTip', { ns: 'custom' })}
{(!uploadDisabled && webappLogo && !webappBrandRemoved) && ( @@ -55,7 +55,7 @@ const CustomWebAppBrand = () => { > {t('restore', { ns: 'custom' })} -
+
)} { @@ -64,7 +64,7 @@ const CustomWebAppBrand = () => { className="relative mr-2" disabled={uploadDisabled} > - + { (webappLogo || fileId) ? t('change', { ns: 'custom' }) @@ -87,7 +87,7 @@ const CustomWebAppBrand = () => { className="relative mr-2" disabled={true} > - + {t('uploading', { ns: 'custom' })} ) @@ -118,8 +118,8 @@ const CustomWebAppBrand = () => { {uploadProgress === -1 && (
{t('uploadedFail', { ns: 'custom' })}
)} -
-
{t('overview.appInfo.preview', { ns: 'appOverview' })}
+
+
{t('overview.appInfo.preview', { ns: 'appOverview' })}
diff --git a/web/app/components/custom/custom-web-app-brand/style.module.css b/web/app/components/custom/custom-web-app-brand/style.module.css deleted file mode 100644 index bdc7d7cfbfe..00000000000 --- a/web/app/components/custom/custom-web-app-brand/style.module.css +++ /dev/null @@ -1,3 +0,0 @@ -.mask { - background: linear-gradient(273deg, rgba(255, 255, 255, 0.00) 51.75%, rgba(255, 255, 255, 0.80) 115.32%); -} diff --git a/web/app/components/datasets/chunk.tsx b/web/app/components/datasets/chunk.tsx index 74184c8d07b..947f2859c4e 100644 --- a/web/app/components/datasets/chunk.tsx +++ b/web/app/components/datasets/chunk.tsx @@ -2,7 +2,7 @@ import type { FC, PropsWithChildren } from 'react' import type { QA } from '@/models/datasets' import { SelectionMod } from '../base/icons/src/public/knowledge' -export type ChunkLabelProps = { +type ChunkLabelProps = { label: string characterCount: number } @@ -27,7 +27,7 @@ export const ChunkLabel: FC = (props) => { ) } -export type ChunkContainerProps = ChunkLabelProps & PropsWithChildren +type ChunkContainerProps = ChunkLabelProps & PropsWithChildren export const ChunkContainer: FC = (props) => { const { label, characterCount, children } = props @@ -41,7 +41,7 @@ export const ChunkContainer: FC = (props) => { ) } -export type QAPreviewProps = { +type QAPreviewProps = { qa: QA } @@ -50,11 +50,11 @@ export const QAPreview: FC = (props) => { return (
- +

{qa.question}

- +

{qa.answer}

diff --git a/web/app/components/datasets/common/credential-icon.tsx b/web/app/components/datasets/common/credential-icon.tsx index a97cf7d4101..f993eca5853 100644 --- a/web/app/components/datasets/common/credential-icon.tsx +++ b/web/app/components/datasets/common/credential-icon.tsx @@ -1,6 +1,6 @@ +import { cn } from '@langgenius/dify-ui/cn' import * as React from 'react' import { useCallback, useMemo, useState } from 'react' -import { cn } from '@/utils/classnames' type CredentialIconProps = { avatarUrl?: string @@ -59,7 +59,7 @@ export const CredentialIcon: React.FC = ({ )} style={{ width: `${size}px`, height: `${size}px` }} > - + {firstLetter}
diff --git a/web/app/components/datasets/common/document-picker/__tests__/index.spec.tsx b/web/app/components/datasets/common/document-picker/__tests__/index.spec.tsx index f8f0ce6e124..1251eab9fbe 100644 --- a/web/app/components/datasets/common/document-picker/__tests__/index.spec.tsx +++ b/web/app/components/datasets/common/document-picker/__tests__/index.spec.tsx @@ -5,34 +5,7 @@ import * as React from 'react' import { ChunkingMode, DataSourceType } from '@/models/datasets' import DocumentPicker from '../index' -// Mock portal-to-follow-elem - always render content for testing -vi.mock('@/app/components/base/portal-to-follow-elem', () => ({ - PortalToFollowElem: ({ children, open }: { - children: React.ReactNode - open?: boolean - }) => ( -
- {children} -
- ), - PortalToFollowElemTrigger: ({ children, onClick }: { - children: React.ReactNode - onClick?: () => void - }) => ( -
- {children} -
- ), - // Always render content to allow testing document selection - PortalToFollowElemContent: ({ children, className }: { - children: React.ReactNode - className?: string - }) => ( -
- {children} -
- ), -})) +vi.mock('@langgenius/dify-ui/popover', () => import('@/__mocks__/base-ui-popover')) // Mock useDocumentList hook with controllable return value let mockDocumentListData: { data: SimpleDocumentDetail[] } | undefined @@ -152,6 +125,10 @@ const renderComponent = (props: Partial { + fireEvent.click(screen.getByTestId('popover-trigger')) +} + describe('DocumentPicker', () => { beforeEach(() => { vi.clearAllMocks() @@ -165,7 +142,7 @@ describe('DocumentPicker', () => { it('should render without crashing', () => { renderComponent() - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + expect(screen.getByTestId('popover')).toBeInTheDocument() }) it('should render document name when provided', () => { @@ -273,7 +250,7 @@ describe('DocumentPicker', () => { onChange, }) - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + expect(screen.getByTestId('popover')).toBeInTheDocument() }) it('should handle value with all fields', () => { @@ -318,13 +295,13 @@ describe('DocumentPicker', () => { it('should initialize with popup closed', () => { renderComponent() - expect(screen.getByTestId('portal-elem')).toHaveAttribute('data-open', 'false') + expect(screen.getByTestId('popover')).toHaveAttribute('data-open', 'false') }) it('should open popup when trigger is clicked', () => { renderComponent() - const trigger = screen.getByTestId('portal-trigger') + const trigger = screen.getByTestId('popover-trigger') fireEvent.click(trigger) // Verify click handler is called @@ -430,7 +407,7 @@ describe('DocumentPicker', () => { ) // The component should use the new callback - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + expect(screen.getByTestId('popover')).toBeInTheDocument() }) it('should memoize handleChange callback with useCallback', () => { @@ -440,7 +417,7 @@ describe('DocumentPicker', () => { renderComponent({ onChange }) // Verify component renders correctly, callback memoization is internal - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + expect(screen.getByTestId('popover')).toBeInTheDocument() }) }) @@ -518,7 +495,7 @@ describe('DocumentPicker', () => { it('should toggle popup when trigger is clicked', () => { renderComponent() - const trigger = screen.getByTestId('portal-trigger') + const trigger = screen.getByTestId('popover-trigger') fireEvent.click(trigger) // Trigger click should be handled @@ -591,7 +568,7 @@ describe('DocumentPicker', () => { renderComponent() // When loading, component should still render without crashing - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + expect(screen.getByTestId('popover')).toBeInTheDocument() }) it('should fetch documents on mount', () => { @@ -611,7 +588,7 @@ describe('DocumentPicker', () => { renderComponent() // Component should render without crashing - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + expect(screen.getByTestId('popover')).toBeInTheDocument() }) it('should handle undefined data response', () => { @@ -620,7 +597,7 @@ describe('DocumentPicker', () => { renderComponent() // Should not crash - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + expect(screen.getByTestId('popover')).toBeInTheDocument() }) }) @@ -732,13 +709,13 @@ describe('DocumentPicker', () => { renderComponent() // Should not crash - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + expect(screen.getByTestId('popover')).toBeInTheDocument() }) it('should handle rapid toggle clicks', () => { renderComponent() - const trigger = screen.getByTestId('portal-trigger') + const trigger = screen.getByTestId('popover-trigger') // Rapid clicks fireEvent.click(trigger) @@ -795,7 +772,7 @@ describe('DocumentPicker', () => { renderComponent() // Should not crash - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + expect(screen.getByTestId('popover')).toBeInTheDocument() }) it('should handle document list mapping with various data_source_detail_dict states', () => { @@ -819,7 +796,7 @@ describe('DocumentPicker', () => { renderComponent() // Should not crash during mapping - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + expect(screen.getByTestId('popover')).toBeInTheDocument() }) }) @@ -829,13 +806,13 @@ describe('DocumentPicker', () => { it('should handle empty datasetId', () => { renderComponent({ datasetId: '' }) - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + expect(screen.getByTestId('popover')).toBeInTheDocument() }) it('should handle UUID format datasetId', () => { renderComponent({ datasetId: '123e4567-e89b-12d3-a456-426614174000' }) - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + expect(screen.getByTestId('popover')).toBeInTheDocument() }) }) @@ -926,6 +903,7 @@ describe('DocumentPicker', () => { const onChange = vi.fn() renderComponent({ onChange }) + openPopover() fireEvent.click(screen.getByText('Document 2')) @@ -939,6 +917,7 @@ describe('DocumentPicker', () => { mockDocumentListData = { data: docs } renderComponent() + openPopover() // Documents should be rendered in the list expect(screen.getByText('Document 1')).toBeInTheDocument() @@ -978,14 +957,14 @@ describe('DocumentPicker', () => { // The mapping: d.data_source_detail_dict?.upload_file?.extension || '' // Should extract 'pdf' from the document - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + expect(screen.getByTestId('popover')).toBeInTheDocument() }) it('should render trigger with SearchInput integration', () => { renderComponent() // The trigger is always rendered - expect(screen.getByTestId('portal-trigger')).toBeInTheDocument() + expect(screen.getByTestId('popover-trigger')).toBeInTheDocument() }) it('should integrate FileIcon component', () => { @@ -1001,7 +980,7 @@ describe('DocumentPicker', () => { }) // FileIcon should render an SVG icon for the file extension - const trigger = screen.getByTestId('portal-trigger') + const trigger = screen.getByTestId('popover-trigger') expect(trigger.querySelector('svg')).toBeInTheDocument() }) }) @@ -1010,9 +989,10 @@ describe('DocumentPicker', () => { describe('Visual States', () => { it('should render portal content for document selection', () => { renderComponent() + openPopover() - // Portal content is rendered in our mock for testing - expect(screen.getByTestId('portal-content')).toBeInTheDocument() + // Popover content is rendered after opening the trigger in our mock + expect(screen.getByTestId('popover-content')).toBeInTheDocument() }) }) }) diff --git a/web/app/components/datasets/common/document-picker/__tests__/preview-document-picker.spec.tsx b/web/app/components/datasets/common/document-picker/__tests__/preview-document-picker.spec.tsx index 7178e9f60cb..c7eb2c740c1 100644 --- a/web/app/components/datasets/common/document-picker/__tests__/preview-document-picker.spec.tsx +++ b/web/app/components/datasets/common/document-picker/__tests__/preview-document-picker.spec.tsx @@ -3,34 +3,7 @@ import { fireEvent, render, screen } from '@testing-library/react' import * as React from 'react' import PreviewDocumentPicker from '../preview-document-picker' -// Mock portal-to-follow-elem - always render content for testing -vi.mock('@/app/components/base/portal-to-follow-elem', () => ({ - PortalToFollowElem: ({ children, open }: { - children: React.ReactNode - open?: boolean - }) => ( -
- {children} -
- ), - PortalToFollowElemTrigger: ({ children, onClick }: { - children: React.ReactNode - onClick?: () => void - }) => ( -
- {children} -
- ), - // Always render content to allow testing document selection - PortalToFollowElemContent: ({ children, className }: { - children: React.ReactNode - className?: string - }) => ( -
- {children} -
- ), -})) +vi.mock('@langgenius/dify-ui/popover', () => import('@/__mocks__/base-ui-popover')) // Factory function to create mock DocumentItem const createMockDocumentItem = (overrides: Partial = {}): DocumentItem => ({ @@ -67,6 +40,10 @@ const renderComponent = (props: Partial { + fireEvent.click(screen.getByTestId('popover-trigger')) +} + describe('PreviewDocumentPicker', () => { beforeEach(() => { vi.clearAllMocks() @@ -77,7 +54,7 @@ describe('PreviewDocumentPicker', () => { it('should render without crashing', () => { renderComponent() - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + expect(screen.getByTestId('popover')).toBeInTheDocument() }) it('should render document name from value prop', () => { @@ -110,7 +87,7 @@ describe('PreviewDocumentPicker', () => { files: [], // Use empty files to avoid duplicate icons }) - const trigger = screen.getByTestId('portal-trigger') + const trigger = screen.getByTestId('popover-trigger') expect(trigger.querySelector('svg')).toBeInTheDocument() }) @@ -120,7 +97,7 @@ describe('PreviewDocumentPicker', () => { files: [], // Use empty files to avoid duplicate icons }) - const trigger = screen.getByTestId('portal-trigger') + const trigger = screen.getByTestId('popover-trigger') expect(trigger.querySelector('svg')).toBeInTheDocument() }) }) @@ -131,22 +108,21 @@ describe('PreviewDocumentPicker', () => { const props = createDefaultProps() render() - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + expect(screen.getByTestId('popover')).toBeInTheDocument() }) it('should apply className to trigger element', () => { renderComponent({ className: 'custom-class' }) - const trigger = screen.getByTestId('portal-trigger') - const innerDiv = trigger.querySelector('.custom-class') - expect(innerDiv).toBeInTheDocument() + const trigger = screen.getByTestId('popover-trigger') + expect(trigger).toHaveClass('custom-class') }) it('should handle empty files array', () => { // Component should render without crashing with empty files renderComponent({ files: [] }) - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + expect(screen.getByTestId('popover')).toBeInTheDocument() }) it('should handle single file', () => { @@ -155,7 +131,7 @@ describe('PreviewDocumentPicker', () => { files: [createMockDocumentItem({ id: 'single-doc', name: 'Single File' })], }) - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + expect(screen.getByTestId('popover')).toBeInTheDocument() }) it('should handle multiple files', () => { @@ -164,7 +140,7 @@ describe('PreviewDocumentPicker', () => { files: createMockDocumentList(5), }) - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + expect(screen.getByTestId('popover')).toBeInTheDocument() }) it('should use value.extension for file icon', () => { @@ -172,7 +148,7 @@ describe('PreviewDocumentPicker', () => { value: createMockDocumentItem({ name: 'test.docx', extension: 'docx' }), }) - const trigger = screen.getByTestId('portal-trigger') + const trigger = screen.getByTestId('popover-trigger') expect(trigger.querySelector('svg')).toBeInTheDocument() }) }) @@ -182,13 +158,13 @@ describe('PreviewDocumentPicker', () => { it('should initialize with popup closed', () => { renderComponent() - expect(screen.getByTestId('portal-elem')).toHaveAttribute('data-open', 'false') + expect(screen.getByTestId('popover')).toHaveAttribute('data-open', 'false') }) it('should toggle popup when trigger is clicked', () => { renderComponent() - const trigger = screen.getByTestId('portal-trigger') + const trigger = screen.getByTestId('popover-trigger') fireEvent.click(trigger) expect(trigger).toBeInTheDocument() @@ -196,9 +172,10 @@ describe('PreviewDocumentPicker', () => { it('should render portal content for document selection', () => { renderComponent() + openPopover() - // Portal content is always rendered in our mock for testing - expect(screen.getByTestId('portal-content')).toBeInTheDocument() + // Popover content is rendered after opening the trigger in our mock + expect(screen.getByTestId('popover-content')).toBeInTheDocument() }) }) @@ -242,7 +219,7 @@ describe('PreviewDocumentPicker', () => { , ) - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + expect(screen.getByTestId('popover')).toBeInTheDocument() }) }) @@ -265,7 +242,7 @@ describe('PreviewDocumentPicker', () => { , ) - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + expect(screen.getByTestId('popover')).toBeInTheDocument() }) }) @@ -274,7 +251,7 @@ describe('PreviewDocumentPicker', () => { it('should toggle popup when trigger is clicked', () => { renderComponent() - const trigger = screen.getByTestId('portal-trigger') + const trigger = screen.getByTestId('popover-trigger') fireEvent.click(trigger) expect(trigger).toBeInTheDocument() @@ -283,6 +260,7 @@ describe('PreviewDocumentPicker', () => { it('should render document list with files', () => { const files = createMockDocumentList(3) renderComponent({ files }) + openPopover() // Documents should be visible in the list expect(screen.getByText('Document 1')).toBeInTheDocument() @@ -295,6 +273,7 @@ describe('PreviewDocumentPicker', () => { const files = createMockDocumentList(3) renderComponent({ files, onChange }) + openPopover() fireEvent.click(screen.getByText('Document 2')) @@ -306,7 +285,7 @@ describe('PreviewDocumentPicker', () => { it('should handle rapid toggle clicks', () => { renderComponent() - const trigger = screen.getByTestId('portal-trigger') + const trigger = screen.getByTestId('popover-trigger') // Rapid clicks fireEvent.click(trigger) @@ -337,14 +316,14 @@ describe('PreviewDocumentPicker', () => { // Renders placeholder for missing name expect(screen.getByText('--')).toBeInTheDocument() // Portal wrapper renders - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + expect(screen.getByTestId('popover')).toBeInTheDocument() }) it('should handle empty files array', () => { renderComponent({ files: [] }) // Component should render without crashing - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + expect(screen.getByTestId('popover')).toBeInTheDocument() }) it('should handle very long document names', () => { @@ -374,7 +353,7 @@ describe('PreviewDocumentPicker', () => { render() // Component should render without crashing - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + expect(screen.getByTestId('popover')).toBeInTheDocument() }) it('should handle large number of files', () => { @@ -382,7 +361,7 @@ describe('PreviewDocumentPicker', () => { renderComponent({ files: manyFiles }) // Component should accept large files array - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + expect(screen.getByTestId('popover')).toBeInTheDocument() }) it('should handle files with same name but different extensions', () => { @@ -393,7 +372,7 @@ describe('PreviewDocumentPicker', () => { renderComponent({ files }) // Component should handle duplicate names - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + expect(screen.getByTestId('popover')).toBeInTheDocument() }) }) @@ -427,7 +406,7 @@ describe('PreviewDocumentPicker', () => { files: [createMockDocumentItem({ name: 'Single' })], }) - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + expect(screen.getByTestId('popover')).toBeInTheDocument() }) it('should handle two files', () => { @@ -435,7 +414,7 @@ describe('PreviewDocumentPicker', () => { files: createMockDocumentList(2), }) - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + expect(screen.getByTestId('popover')).toBeInTheDocument() }) it('should handle many files', () => { @@ -443,7 +422,7 @@ describe('PreviewDocumentPicker', () => { files: createMockDocumentList(50), }) - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + expect(screen.getByTestId('popover')).toBeInTheDocument() }) }) @@ -451,23 +430,22 @@ describe('PreviewDocumentPicker', () => { it('should apply custom className', () => { renderComponent({ className: 'my-custom-class' }) - const trigger = screen.getByTestId('portal-trigger') - expect(trigger.querySelector('.my-custom-class')).toBeInTheDocument() + const trigger = screen.getByTestId('popover-trigger') + expect(trigger).toHaveClass('my-custom-class') }) it('should work without className', () => { renderComponent({ className: undefined }) - expect(screen.getByTestId('portal-trigger')).toBeInTheDocument() + expect(screen.getByTestId('popover-trigger')).toBeInTheDocument() }) it('should handle multiple class names', () => { renderComponent({ className: 'class-one class-two' }) - const trigger = screen.getByTestId('portal-trigger') - const element = trigger.querySelector('.class-one') - expect(element).toBeInTheDocument() - expect(element).toHaveClass('class-two') + const trigger = screen.getByTestId('popover-trigger') + expect(trigger).toHaveClass('class-one') + expect(trigger).toHaveClass('class-two') }) }) @@ -480,7 +458,7 @@ describe('PreviewDocumentPicker', () => { files: [], // Use empty files to avoid duplicate icons }) - const trigger = screen.getByTestId('portal-trigger') + const trigger = screen.getByTestId('popover-trigger') expect(trigger.querySelector('svg')).toBeInTheDocument() }) }) @@ -491,6 +469,7 @@ describe('PreviewDocumentPicker', () => { it('should render all documents in the list', () => { const files = createMockDocumentList(5) renderComponent({ files }) + openPopover() // All documents should be visible files.forEach((file) => { @@ -503,6 +482,7 @@ describe('PreviewDocumentPicker', () => { const files = createMockDocumentList(3) renderComponent({ files, onChange }) + openPopover() fireEvent.click(screen.getByText('Document 1')) @@ -528,6 +508,7 @@ describe('PreviewDocumentPicker', () => { onChange={vi.fn()} />, ) + openPopover() expect(screen.getByText(/dataset\.preprocessDocument/)).toBeInTheDocument() }) }) @@ -537,9 +518,8 @@ describe('PreviewDocumentPicker', () => { it('should apply hover styles on trigger', () => { renderComponent() - const trigger = screen.getByTestId('portal-trigger') - const innerDiv = trigger.querySelector('.hover\\:bg-state-base-hover') - expect(innerDiv).toBeInTheDocument() + const trigger = screen.getByTestId('popover-trigger') + expect(trigger).toHaveClass('hover:bg-state-base-hover') }) it('should have truncate class for long names', () => { @@ -568,6 +548,7 @@ describe('PreviewDocumentPicker', () => { const files = createMockDocumentList(3) renderComponent({ files, onChange }) + openPopover() fireEvent.click(screen.getByText('Document 1')) @@ -582,10 +563,12 @@ describe('PreviewDocumentPicker', () => { ] renderComponent({ files: customFiles, onChange }) + openPopover() fireEvent.click(screen.getByText('Custom File 1')) expect(onChange).toHaveBeenCalledWith(customFiles[0]) + openPopover() fireEvent.click(screen.getByText('Custom File 2')) expect(onChange).toHaveBeenCalledWith(customFiles[1]) }) @@ -597,8 +580,11 @@ describe('PreviewDocumentPicker', () => { renderComponent({ files, onChange }) // Select multiple documents sequentially + openPopover() fireEvent.click(screen.getByText('Document 1')) + openPopover() fireEvent.click(screen.getByText('Document 3')) + openPopover() fireEvent.click(screen.getByText('Document 2')) expect(onChange).toHaveBeenCalledTimes(3) diff --git a/web/app/components/datasets/common/document-picker/document-list.tsx b/web/app/components/datasets/common/document-picker/document-list.tsx index 574792ee145..d2d8d1966c6 100644 --- a/web/app/components/datasets/common/document-picker/document-list.tsx +++ b/web/app/components/datasets/common/document-picker/document-list.tsx @@ -1,9 +1,9 @@ 'use client' import type { FC } from 'react' import type { DocumentItem } from '@/models/datasets' +import { cn } from '@langgenius/dify-ui/cn' import * as React from 'react' import { useCallback } from 'react' -import { cn } from '@/utils/classnames' import FileIcon from '../document-file-icon' type Props = { diff --git a/web/app/components/datasets/common/document-picker/index.tsx b/web/app/components/datasets/common/document-picker/index.tsx index d1dbf1644a2..0566b590de7 100644 --- a/web/app/components/datasets/common/document-picker/index.tsx +++ b/web/app/components/datasets/common/document-picker/index.tsx @@ -1,6 +1,12 @@ 'use client' import type { FC } from 'react' import type { DocumentItem, ParentMode, SimpleDocumentDetail } from '@/models/datasets' +import { cn } from '@langgenius/dify-ui/cn' +import { + Popover, + PopoverContent, + PopoverTrigger, +} from '@langgenius/dify-ui/popover' import { RiArrowDownSLine } from '@remixicon/react' import { useBoolean } from 'ahooks' import * as React from 'react' @@ -8,15 +14,9 @@ import { useCallback, useMemo, useState } from 'react' import { useTranslation } from 'react-i18next' import { GeneralChunk, ParentChildChunk } from '@/app/components/base/icons/src/vender/knowledge' import Loading from '@/app/components/base/loading' -import { - PortalToFollowElem, - PortalToFollowElemContent, - PortalToFollowElemTrigger, -} from '@/app/components/base/portal-to-follow-elem' import SearchInput from '@/app/components/base/search-input' import { ChunkingMode } from '@/models/datasets' import { useDocumentList } from '@/service/knowledge/use-document' -import { cn } from '@/utils/classnames' import FileIcon from '../document-file-icon' import DocumentList from './document-list' @@ -61,7 +61,6 @@ const DocumentPicker: FC = ({ const [open, { set: setOpen, - toggle: togglePopup, }] = useBoolean(false) const ArrowIcon = RiArrowDownSLine @@ -77,34 +76,40 @@ const DocumentPicker: FC = ({ }, [parentMode, t]) return ( - - -
- -
-
- - {' '} - {name || '--'} - - -
-
- - - {isGeneralMode && t('chunkingMode.general', { ns: 'dataset' })} - {isQAMode && t('chunkingMode.qa', { ns: 'dataset' })} - {isParentChild && `${t('chunkingMode.parentChild', { ns: 'dataset' })} · ${parentModeLabel}`} - + + +
+
+ + {' '} + {name || '--'} + + +
+
+ + + {isGeneralMode && t('chunkingMode.general', { ns: 'dataset' })} + {isQAMode && t('chunkingMode.qa', { ns: 'dataset' })} + {isParentChild && `${t('chunkingMode.parentChild', { ns: 'dataset' })} · ${parentModeLabel}`} + +
-
- - + )} + /> +
{documentsList @@ -125,9 +130,8 @@ const DocumentPicker: FC = ({
)}
- - -
+ + ) } export default React.memo(DocumentPicker) diff --git a/web/app/components/datasets/common/document-picker/preview-document-picker.tsx b/web/app/components/datasets/common/document-picker/preview-document-picker.tsx index 74e8ea96734..597ceda9a5f 100644 --- a/web/app/components/datasets/common/document-picker/preview-document-picker.tsx +++ b/web/app/components/datasets/common/document-picker/preview-document-picker.tsx @@ -1,18 +1,18 @@ 'use client' import type { FC } from 'react' import type { DocumentItem } from '@/models/datasets' +import { cn } from '@langgenius/dify-ui/cn' +import { + Popover, + PopoverContent, + PopoverTrigger, +} from '@langgenius/dify-ui/popover' import { RiArrowDownSLine } from '@remixicon/react' import { useBoolean } from 'ahooks' import * as React from 'react' import { useCallback } from 'react' import { useTranslation } from 'react-i18next' import Loading from '@/app/components/base/loading' -import { - PortalToFollowElem, - PortalToFollowElemContent, - PortalToFollowElemTrigger, -} from '@/app/components/base/portal-to-follow-elem' -import { cn } from '@/utils/classnames' import FileIcon from '../document-file-icon' import DocumentList from './document-list' @@ -35,7 +35,6 @@ const PreviewDocumentPicker: FC = ({ const [open, { set: setOpen, - toggle: togglePopup, }] = useBoolean(false) const ArrowIcon = RiArrowDownSLine @@ -45,29 +44,34 @@ const PreviewDocumentPicker: FC = ({ }, [onChange, setOpen]) return ( - - -
- -
-
- - {' '} - {name || '--'} - - + + +
+
+ + {' '} + {name || '--'} + + +
-
- - + )} + /> +
- {files?.length > 1 &&
{t('preprocessDocument', { ns: 'dataset', num: files.length })}
} + {files?.length > 1 &&
{t('preprocessDocument', { ns: 'dataset', num: files.length })}
} {files?.length > 0 ? ( = ({
)}
- - -
+ + ) } export default React.memo(PreviewDocumentPicker) diff --git a/web/app/components/datasets/common/document-status-with-action/__tests__/auto-disabled-document.spec.tsx b/web/app/components/datasets/common/document-status-with-action/__tests__/auto-disabled-document.spec.tsx index fcaca86e89d..ba058860d3a 100644 --- a/web/app/components/datasets/common/document-status-with-action/__tests__/auto-disabled-document.spec.tsx +++ b/web/app/components/datasets/common/document-status-with-action/__tests__/auto-disabled-document.spec.tsx @@ -1,6 +1,6 @@ +import { toast } from '@langgenius/dify-ui/toast' import { fireEvent, render, screen, waitFor } from '@testing-library/react' import { beforeEach, describe, expect, it, vi } from 'vitest' -import { toast } from '@/app/components/base/ui/toast' import { useAutoDisabledDocuments } from '@/service/knowledge/use-document' import AutoDisabledDocument from '../auto-disabled-document' @@ -30,7 +30,7 @@ vi.mock('@/service/knowledge/use-document', () => ({ useInvalidDisabledDocument: vi.fn(() => mockInvalidDisabledDocument), })) -vi.mock('@/app/components/base/ui/toast', () => ({ +vi.mock('@langgenius/dify-ui/toast', () => ({ toast: { success: mockToastSuccess, }, diff --git a/web/app/components/datasets/common/document-status-with-action/auto-disabled-document.tsx b/web/app/components/datasets/common/document-status-with-action/auto-disabled-document.tsx index c6c7e03bd1c..c150ac823f8 100644 --- a/web/app/components/datasets/common/document-status-with-action/auto-disabled-document.tsx +++ b/web/app/components/datasets/common/document-status-with-action/auto-disabled-document.tsx @@ -1,9 +1,9 @@ 'use client' import type { FC } from 'react' +import { toast } from '@langgenius/dify-ui/toast' import * as React from 'react' import { useCallback } from 'react' import { useTranslation } from 'react-i18next' -import { toast } from '@/app/components/base/ui/toast' import { useAutoDisabledDocuments, useDocumentEnable, useInvalidDisabledDocument } from '@/service/knowledge/use-document' import StatusWithAction from './status-with-action' diff --git a/web/app/components/datasets/common/document-status-with-action/status-with-action.tsx b/web/app/components/datasets/common/document-status-with-action/status-with-action.tsx index 9d42d40b1a7..3c2536fa5e0 100644 --- a/web/app/components/datasets/common/document-status-with-action/status-with-action.tsx +++ b/web/app/components/datasets/common/document-status-with-action/status-with-action.tsx @@ -1,9 +1,9 @@ 'use client' import type { FC } from 'react' +import { cn } from '@langgenius/dify-ui/cn' import { RiAlertFill, RiCheckboxCircleFill, RiErrorWarningFill, RiInformation2Fill } from '@remixicon/react' import * as React from 'react' import Divider from '@/app/components/base/divider' -import { cn } from '@/utils/classnames' type Status = 'success' | 'error' | 'warning' | 'info' type Props = { @@ -46,7 +46,7 @@ const StatusAction: FC = ({ }) => { const { Icon, color } = getIcon(type) return ( -
+
= ({
{description}
{onAction && ( <> - +
{actionText}
)} diff --git a/web/app/components/datasets/common/image-list/__tests__/index.spec.tsx b/web/app/components/datasets/common/image-list/__tests__/index.spec.tsx index 3b04b2708e0..e1fde328bb7 100644 --- a/web/app/components/datasets/common/image-list/__tests__/index.spec.tsx +++ b/web/app/components/datasets/common/image-list/__tests__/index.spec.tsx @@ -72,7 +72,7 @@ describe('ImageList', () => { it('should render without crashing', () => { const images = createMockImages(3) const { container } = render() - expect(container.firstChild).toBeInTheDocument() + expect(container.firstChild)!.toBeInTheDocument() }) it('should render all images when count is below limit', () => { @@ -87,7 +87,8 @@ describe('ImageList', () => { const images = createMockImages(15) render() // More button should be visible - expect(screen.getByText(/\+6/)).toBeInTheDocument() + // More button should be visible + expect(screen.getByText(/\+6/))!.toBeInTheDocument() }) }) @@ -97,33 +98,35 @@ describe('ImageList', () => { const { container } = render( , ) - expect(container.firstChild).toHaveClass('custom-class') + expect(container.firstChild)!.toHaveClass('custom-class') }) it('should use default limit of 9', () => { const images = createMockImages(12) render() // Should show "+3" for remaining images - expect(screen.getByText(/\+3/)).toBeInTheDocument() + // Should show "+3" for remaining images + expect(screen.getByText(/\+3/))!.toBeInTheDocument() }) it('should respect custom limit', () => { const images = createMockImages(10) render() // Should show "+5" for remaining images - expect(screen.getByText(/\+5/)).toBeInTheDocument() + // Should show "+5" for remaining images + expect(screen.getByText(/\+5/))!.toBeInTheDocument() }) it('should handle size prop sm', () => { const images = createMockImages(2) const { container } = render() - expect(container.firstChild).toBeInTheDocument() + expect(container.firstChild)!.toBeInTheDocument() }) it('should handle size prop md', () => { const images = createMockImages(2) const { container } = render() - expect(container.firstChild).toBeInTheDocument() + expect(container.firstChild)!.toBeInTheDocument() }) }) @@ -135,6 +138,37 @@ describe('ImageList', () => { const moreButton = screen.getByText(/\+6/) fireEvent.click(moreButton) + // More button should disappear + // More button should disappear + // More button should disappear + // More button should disappear + // More button should disappear + // More button should disappear + // More button should disappear + // More button should disappear + // More button should disappear + // More button should disappear + // More button should disappear + // More button should disappear + // More button should disappear + // More button should disappear + // More button should disappear + // More button should disappear + // More button should disappear + // More button should disappear + // More button should disappear + // More button should disappear + // More button should disappear + // More button should disappear + // More button should disappear + // More button should disappear + // More button should disappear + // More button should disappear + // More button should disappear + // More button should disappear + // More button should disappear + // More button should disappear + // More button should disappear // More button should disappear expect(screen.queryByText(/\+6/)).not.toBeInTheDocument() }) @@ -146,9 +180,10 @@ describe('ImageList', () => { // Find and click an image thumbnail const thumbnails = document.querySelectorAll('[class*="cursor-pointer"]') if (thumbnails.length > 0) { - fireEvent.click(thumbnails[0]) + fireEvent.click(thumbnails[0]!) // Preview should open - expect(screen.getByTestId('image-previewer')).toBeInTheDocument() + // Preview should open + expect(screen.getByTestId('image-previewer'))!.toBeInTheDocument() } }) @@ -159,12 +194,43 @@ describe('ImageList', () => { // Open preview const thumbnails = document.querySelectorAll('[class*="cursor-pointer"]') if (thumbnails.length > 0) { - fireEvent.click(thumbnails[0]) + fireEvent.click(thumbnails[0]!) // Close preview const closeButton = screen.getByTestId('close-preview') fireEvent.click(closeButton) + // Preview should be closed + // Preview should be closed + // Preview should be closed + // Preview should be closed + // Preview should be closed + // Preview should be closed + // Preview should be closed + // Preview should be closed + // Preview should be closed + // Preview should be closed + // Preview should be closed + // Preview should be closed + // Preview should be closed + // Preview should be closed + // Preview should be closed + // Preview should be closed + // Preview should be closed + // Preview should be closed + // Preview should be closed + // Preview should be closed + // Preview should be closed + // Preview should be closed + // Preview should be closed + // Preview should be closed + // Preview should be closed + // Preview should be closed + // Preview should be closed + // Preview should be closed + // Preview should be closed + // Preview should be closed + // Preview should be closed // Preview should be closed expect(screen.queryByTestId('image-previewer')).not.toBeInTheDocument() } @@ -174,7 +240,7 @@ describe('ImageList', () => { describe('Edge Cases', () => { it('should handle empty images array', () => { const { container } = render() - expect(container.firstChild).toBeInTheDocument() + expect(container.firstChild)!.toBeInTheDocument() }) it('should not open preview when clicked image not found in list (index === -1)', () => { @@ -185,7 +251,8 @@ describe('ImageList', () => { fireEvent.click(firstThumb) // Preview should open for valid image - expect(screen.getByTestId('image-previewer')).toBeInTheDocument() + // Preview should open for valid image + expect(screen.getByTestId('image-previewer'))!.toBeInTheDocument() // Close preview fireEvent.click(screen.getByTestId('close-preview')) @@ -197,7 +264,7 @@ describe('ImageList', () => { const validThumb = screen.getByTestId('file-thumb-https://example.com/image-1.png') fireEvent.click(validThumb) - expect(screen.getByTestId('image-previewer')).toBeInTheDocument() + expect(screen.getByTestId('image-previewer'))!.toBeInTheDocument() }) it('should return early when file sourceUrl is not found in limitedImages (index === -1)', () => { @@ -216,6 +283,37 @@ describe('ImageList', () => { }) } + // Preview should NOT open because the file was not found in limitedImages + // Preview should NOT open because the file was not found in limitedImages + // Preview should NOT open because the file was not found in limitedImages + // Preview should NOT open because the file was not found in limitedImages + // Preview should NOT open because the file was not found in limitedImages + // Preview should NOT open because the file was not found in limitedImages + // Preview should NOT open because the file was not found in limitedImages + // Preview should NOT open because the file was not found in limitedImages + // Preview should NOT open because the file was not found in limitedImages + // Preview should NOT open because the file was not found in limitedImages + // Preview should NOT open because the file was not found in limitedImages + // Preview should NOT open because the file was not found in limitedImages + // Preview should NOT open because the file was not found in limitedImages + // Preview should NOT open because the file was not found in limitedImages + // Preview should NOT open because the file was not found in limitedImages + // Preview should NOT open because the file was not found in limitedImages + // Preview should NOT open because the file was not found in limitedImages + // Preview should NOT open because the file was not found in limitedImages + // Preview should NOT open because the file was not found in limitedImages + // Preview should NOT open because the file was not found in limitedImages + // Preview should NOT open because the file was not found in limitedImages + // Preview should NOT open because the file was not found in limitedImages + // Preview should NOT open because the file was not found in limitedImages + // Preview should NOT open because the file was not found in limitedImages + // Preview should NOT open because the file was not found in limitedImages + // Preview should NOT open because the file was not found in limitedImages + // Preview should NOT open because the file was not found in limitedImages + // Preview should NOT open because the file was not found in limitedImages + // Preview should NOT open because the file was not found in limitedImages + // Preview should NOT open because the file was not found in limitedImages + // Preview should NOT open because the file was not found in limitedImages // Preview should NOT open because the file was not found in limitedImages expect(screen.queryByTestId('image-previewer')).not.toBeInTheDocument() }) @@ -223,7 +321,7 @@ describe('ImageList', () => { it('should handle single image', () => { const images = createMockImages(1) const { container } = render() - expect(container.firstChild).toBeInTheDocument() + expect(container.firstChild)!.toBeInTheDocument() }) it('should not show More button when images count equals limit', () => { @@ -236,13 +334,45 @@ describe('ImageList', () => { const images = createMockImages(5) render() // Should show "+5" for all images - expect(screen.getByText(/\+5/)).toBeInTheDocument() + // Should show "+5" for all images + expect(screen.getByText(/\+5/))!.toBeInTheDocument() }) it('should handle limit larger than images count', () => { const images = createMockImages(5) render() // Should not show More button + // Should not show More button + // Should not show More button + // Should not show More button + // Should not show More button + // Should not show More button + // Should not show More button + // Should not show More button + // Should not show More button + // Should not show More button + // Should not show More button + // Should not show More button + // Should not show More button + // Should not show More button + // Should not show More button + // Should not show More button + // Should not show More button + // Should not show More button + // Should not show More button + // Should not show More button + // Should not show More button + // Should not show More button + // Should not show More button + // Should not show More button + // Should not show More button + // Should not show More button + // Should not show More button + // Should not show More button + // Should not show More button + // Should not show More button + // Should not show More button + // Should not show More button expect(screen.queryByText(/\+/)).not.toBeInTheDocument() }) }) diff --git a/web/app/components/datasets/common/image-list/index.tsx b/web/app/components/datasets/common/image-list/index.tsx index 48a990f94db..b04dd5afbde 100644 --- a/web/app/components/datasets/common/image-list/index.tsx +++ b/web/app/components/datasets/common/image-list/index.tsx @@ -1,8 +1,8 @@ import type { ImageInfo } from '../image-previewer' import type { FileEntity } from '@/app/components/base/file-thumb' +import { cn } from '@langgenius/dify-ui/cn' import { useCallback, useMemo, useState } from 'react' import FileThumb from '@/app/components/base/file-thumb' -import { cn } from '@/utils/classnames' import ImagePreviewer from '../image-previewer' import More from './more' diff --git a/web/app/components/datasets/common/image-list/more.tsx b/web/app/components/datasets/common/image-list/more.tsx index be6b53a5a57..255e5e4d877 100644 --- a/web/app/components/datasets/common/image-list/more.tsx +++ b/web/app/components/datasets/common/image-list/more.tsx @@ -32,7 +32,7 @@ const More = ({ count, onClick }: MoreProps) => {
-
+
) } diff --git a/web/app/components/datasets/common/image-previewer/__tests__/index.spec.tsx b/web/app/components/datasets/common/image-previewer/__tests__/index.spec.tsx index 44f0e95c08d..1ff5cb33f8d 100644 --- a/web/app/components/datasets/common/image-previewer/__tests__/index.spec.tsx +++ b/web/app/components/datasets/common/image-previewer/__tests__/index.spec.tsx @@ -65,7 +65,8 @@ describe('ImagePreviewer', () => { }) // Should render in portal - expect(document.body.querySelector('.image-previewer')).toBeInTheDocument() + // Should render in portal + expect(document.body.querySelector('.image-previewer'))!.toBeInTheDocument() }) it('should render close button', async () => { @@ -77,7 +78,8 @@ describe('ImagePreviewer', () => { }) // Esc text should be visible - expect(screen.getByText('Esc')).toBeInTheDocument() + // Esc text should be visible + expect(screen.getByText('Esc'))!.toBeInTheDocument() }) it('should show loading state initially', async () => { @@ -92,7 +94,8 @@ describe('ImagePreviewer', () => { }) // Loading component should be visible - expect(document.body.querySelector('.image-previewer')).toBeInTheDocument() + // Loading component should be visible + expect(document.body.querySelector('.image-previewer'))!.toBeInTheDocument() }) }) @@ -107,7 +110,8 @@ describe('ImagePreviewer', () => { await waitFor(() => { // Should start at second image - expect(screen.getByText('image2.png')).toBeInTheDocument() + // Should start at second image + expect(screen.getByText('image2.png'))!.toBeInTheDocument() }) }) @@ -120,7 +124,7 @@ describe('ImagePreviewer', () => { }) await waitFor(() => { - expect(screen.getByText('image1.png')).toBeInTheDocument() + expect(screen.getByText('image1.png'))!.toBeInTheDocument() }) }) }) @@ -151,7 +155,7 @@ describe('ImagePreviewer', () => { }) await waitFor(() => { - expect(screen.getByText('image1.png')).toBeInTheDocument() + expect(screen.getByText('image1.png'))!.toBeInTheDocument() }) // Find and click next button (right arrow) @@ -166,7 +170,7 @@ describe('ImagePreviewer', () => { }) await waitFor(() => { - expect(screen.getByText('image2.png')).toBeInTheDocument() + expect(screen.getByText('image2.png'))!.toBeInTheDocument() }) } }) @@ -180,7 +184,7 @@ describe('ImagePreviewer', () => { }) await waitFor(() => { - expect(screen.getByText('image2.png')).toBeInTheDocument() + expect(screen.getByText('image2.png'))!.toBeInTheDocument() }) // Find and click prev button (left arrow) @@ -195,7 +199,7 @@ describe('ImagePreviewer', () => { }) await waitFor(() => { - expect(screen.getByText('image1.png')).toBeInTheDocument() + expect(screen.getByText('image1.png'))!.toBeInTheDocument() }) } }) @@ -213,7 +217,7 @@ describe('ImagePreviewer', () => { btn.className.includes('left-8'), ) - expect(prevButton).toBeDisabled() + expect(prevButton)!.toBeDisabled() }) it('should disable next button at last image', async () => { @@ -229,7 +233,7 @@ describe('ImagePreviewer', () => { btn.className.includes('right-8'), ) - expect(nextButton).toBeDisabled() + expect(nextButton)!.toBeDisabled() }) }) @@ -258,7 +262,7 @@ describe('ImagePreviewer', () => { }) await waitFor(() => { - expect(screen.getByText(/Failed to load image/)).toBeInTheDocument() + expect(screen.getByText(/Failed to load image/))!.toBeInTheDocument() }) }) @@ -275,7 +279,7 @@ describe('ImagePreviewer', () => { await waitFor(() => { // Retry button should be visible const retryButton = document.querySelector('button.rounded-full') - expect(retryButton).toBeInTheDocument() + expect(retryButton)!.toBeInTheDocument() }) }) }) @@ -290,7 +294,7 @@ describe('ImagePreviewer', () => { }) await waitFor(() => { - expect(screen.getByText('image1.png')).toBeInTheDocument() + expect(screen.getByText('image1.png'))!.toBeInTheDocument() }) const buttons = document.querySelectorAll('button') @@ -306,7 +310,7 @@ describe('ImagePreviewer', () => { // Should still be at first image await waitFor(() => { - expect(screen.getByText('image1.png')).toBeInTheDocument() + expect(screen.getByText('image1.png'))!.toBeInTheDocument() }) } }) @@ -320,7 +324,7 @@ describe('ImagePreviewer', () => { }) await waitFor(() => { - expect(screen.getByText('image3.png')).toBeInTheDocument() + expect(screen.getByText('image3.png'))!.toBeInTheDocument() }) const buttons = document.querySelectorAll('button') @@ -336,7 +340,7 @@ describe('ImagePreviewer', () => { // Should still be at last image await waitFor(() => { - expect(screen.getByText('image3.png')).toBeInTheDocument() + expect(screen.getByText('image3.png'))!.toBeInTheDocument() }) } }) @@ -366,7 +370,7 @@ describe('ImagePreviewer', () => { // Wait for error state await waitFor(() => { - expect(screen.getByText(/Failed to load image/)).toBeInTheDocument() + expect(screen.getByText(/Failed to load image/))!.toBeInTheDocument() }) const retryButton = document.querySelector('button.rounded-full') @@ -393,7 +397,7 @@ describe('ImagePreviewer', () => { }) await waitFor(() => { - expect(screen.getByText(/Failed to load image/)).toBeInTheDocument() + expect(screen.getByText(/Failed to load image/))!.toBeInTheDocument() }) // Find and click the retry button (not the nav buttons) @@ -402,7 +406,7 @@ describe('ImagePreviewer', () => { btn.className.includes('rounded-full') && !btn.className.includes('left-8') && !btn.className.includes('right-8'), ) - expect(retryButton).toBeInTheDocument() + expect(retryButton)!.toBeInTheDocument() if (retryButton) { mockFetch.mockClear() @@ -451,7 +455,7 @@ describe('ImagePreviewer', () => { describe('Edge Cases', () => { it('should handle single image', async () => { const onClose = vi.fn() - const images = [createMockImages()[0]] + const images = [createMockImages()[0]!] await act(async () => { render() @@ -466,8 +470,8 @@ describe('ImagePreviewer', () => { btn.className.includes('right-8'), ) - expect(prevButton).toBeDisabled() - expect(nextButton).toBeDisabled() + expect(prevButton)!.toBeDisabled() + expect(nextButton)!.toBeDisabled() }) it('should stop event propagation on container click', async () => { @@ -500,7 +504,8 @@ describe('ImagePreviewer', () => { await waitFor(() => { // Should display dimensions (800 × 600 from MockImage) - expect(screen.getByText(/800.*600/)).toBeInTheDocument() + // Should display dimensions (800 × 600 from MockImage) + expect(screen.getByText(/800.*600/))!.toBeInTheDocument() }) }) @@ -514,7 +519,8 @@ describe('ImagePreviewer', () => { await waitFor(() => { // Should display formatted file size - expect(screen.getByText('image1.png')).toBeInTheDocument() + // Should display formatted file size + expect(screen.getByText('image1.png'))!.toBeInTheDocument() }) }) }) diff --git a/web/app/components/datasets/common/image-previewer/index.tsx b/web/app/components/datasets/common/image-previewer/index.tsx index ec153e874cb..42b4ce33a9f 100644 --- a/web/app/components/datasets/common/image-previewer/index.tsx +++ b/web/app/components/datasets/common/image-previewer/index.tsx @@ -1,8 +1,8 @@ +import { Button } from '@langgenius/dify-ui/button' import { RiArrowLeftLine, RiArrowRightLine, RiCloseLine, RiRefreshLine } from '@remixicon/react' import { useCallback, useEffect, useMemo, useRef, useState } from 'react' import { createPortal } from 'react-dom' import { useHotkeys } from 'react-hotkeys-hook' -import Button from '@/app/components/base/button' import Loading from '@/app/components/base/loading' import { formatFileSize } from '@/utils/format' @@ -137,7 +137,7 @@ const ImagePreviewer = ({ return { ...prev, [image.url]: { - ...prev[image.url], + ...prev[image.url]!, status: 'loading', }, } @@ -151,11 +151,11 @@ const ImagePreviewer = ({ return createPortal(
e.stopPropagation()} tabIndex={-1} > -
+
- {cachedImages[currentImage.url].status === 'loading' && ( + {cachedImages[currentImage!.url]!.status === 'loading' && ( )} - {cachedImages[currentImage.url].status === 'error' && ( -
- {`Failed to load image: ${currentImage.url}. Please try again.`} + {cachedImages[currentImage!.url]!.status === 'error' && ( +
+ {`Failed to load image: ${currentImage!.url}. Please try again.`}
)} - {cachedImages[currentImage.url].status === 'loaded' && ( + {cachedImages[currentImage!.url]!.status === 'loaded' && (
{currentImage.name} -
- {currentImage.name} +
+ {currentImage!.name} · - {`${cachedImages[currentImage.url].width} ×  ${cachedImages[currentImage.url].height}`} + {`${cachedImages[currentImage!.url]!.width} ×  ${cachedImages[currentImage!.url]!.height}`} · - {formatFileSize(currentImage.size)} + {formatFileSize(currentImage!.size)}
)} @@ -173,7 +173,7 @@ describe('RetrievalParamConfig', () => { />, ) - expect(screen.getByTestId('rerank-switch')).toBeInTheDocument() + expect(screen.getByTestId('rerank-switch'))!.toBeInTheDocument() }) it('should render model selector when reranking is enabled', () => { @@ -186,7 +186,7 @@ describe('RetrievalParamConfig', () => { />, ) - expect(screen.getByTestId('model-selector')).toBeInTheDocument() + expect(screen.getByTestId('model-selector'))!.toBeInTheDocument() }) it('should not render model selector when reranking is disabled', () => { @@ -212,8 +212,8 @@ describe('RetrievalParamConfig', () => { />, ) - expect(screen.getByTestId('top-k-item')).toBeInTheDocument() - expect(screen.getByTestId('top-k-item')).toHaveAttribute('data-value', '5') + expect(screen.getByTestId('top-k-item'))!.toBeInTheDocument() + expect(screen.getByTestId('top-k-item'))!.toHaveAttribute('data-value', '5') }) it('should render score threshold item when reranking is enabled', () => { @@ -226,7 +226,7 @@ describe('RetrievalParamConfig', () => { />, ) - expect(screen.getByTestId('score-threshold-item')).toBeInTheDocument() + expect(screen.getByTestId('score-threshold-item'))!.toBeInTheDocument() }) it('should toggle reranking enable', () => { @@ -349,7 +349,7 @@ describe('RetrievalParamConfig', () => { />, ) - expect(screen.getByText('datasetSettings.form.retrievalSetting.multiModalTip')).toBeInTheDocument() + expect(screen.getByText('datasetSettings.form.retrievalSetting.multiModalTip'))!.toBeInTheDocument() }) it('should not show multimodal tip when showMultiModalTip is false', () => { @@ -378,7 +378,7 @@ describe('RetrievalParamConfig', () => { />, ) - expect(screen.getByTestId('rerank-switch')).toBeInTheDocument() + expect(screen.getByTestId('rerank-switch'))!.toBeInTheDocument() }) it('should hide score threshold when reranking is disabled for full text search', () => { @@ -410,7 +410,7 @@ describe('RetrievalParamConfig', () => { />, ) - expect(screen.getByTestId('score-threshold-item')).toBeInTheDocument() + expect(screen.getByTestId('score-threshold-item'))!.toBeInTheDocument() }) }) @@ -451,7 +451,7 @@ describe('RetrievalParamConfig', () => { />, ) - expect(screen.getByTestId('top-k-item')).toBeInTheDocument() + expect(screen.getByTestId('top-k-item'))!.toBeInTheDocument() }) it('should not render score threshold for keyword search', () => { @@ -496,7 +496,7 @@ describe('RetrievalParamConfig', () => { />, ) - expect(screen.getByText('dataset.weightedScore.title')).toBeInTheDocument() + expect(screen.getByText('dataset.weightedScore.title'))!.toBeInTheDocument() }) it('should have RerankingModel option', () => { @@ -508,7 +508,7 @@ describe('RetrievalParamConfig', () => { />, ) - expect(screen.getByText('common.modelProvider.rerankModel.key')).toBeInTheDocument() + expect(screen.getByText('common.modelProvider.rerankModel.key'))!.toBeInTheDocument() }) it('should show model selector when RerankingModel mode is selected', () => { @@ -520,7 +520,7 @@ describe('RetrievalParamConfig', () => { />, ) - expect(screen.getByTestId('model-selector')).toBeInTheDocument() + expect(screen.getByTestId('model-selector'))!.toBeInTheDocument() }) it('should show WeightedScore component when WeightedScore mode is selected', () => { @@ -547,7 +547,7 @@ describe('RetrievalParamConfig', () => { />, ) - expect(screen.getByTestId('weighted-score')).toBeInTheDocument() + expect(screen.getByTestId('weighted-score'))!.toBeInTheDocument() expect(screen.queryByTestId('model-selector')).not.toBeInTheDocument() }) @@ -565,7 +565,7 @@ describe('RetrievalParamConfig', () => { fireEvent.click(weightedScoreCard!) expect(mockOnChange).toHaveBeenCalled() - const calledWith = mockOnChange.mock.calls[0][0] + const calledWith = mockOnChange.mock.calls[0]![0] expect(calledWith.reranking_mode).toBe(RerankingModeEnum.WeightedScore) expect(calledWith.weights).toBeDefined() }) @@ -645,7 +645,7 @@ describe('RetrievalParamConfig', () => { fireEvent.click(screen.getByTestId('change-weights-btn')) expect(mockOnChange).toHaveBeenCalled() - const calledWith = mockOnChange.mock.calls[0][0] + const calledWith = mockOnChange.mock.calls[0]![0] expect(calledWith.weights.vector_setting.vector_weight).toBe(0.6) expect(calledWith.weights.keyword_setting.keyword_weight).toBe(0.4) }) @@ -659,8 +659,8 @@ describe('RetrievalParamConfig', () => { />, ) - expect(screen.getByTestId('top-k-item')).toBeInTheDocument() - expect(screen.getByTestId('score-threshold-item')).toBeInTheDocument() + expect(screen.getByTestId('top-k-item'))!.toBeInTheDocument() + expect(screen.getByTestId('score-threshold-item'))!.toBeInTheDocument() }) it('should update top_k for hybrid search', () => { @@ -724,7 +724,7 @@ describe('RetrievalParamConfig', () => { />, ) - expect(screen.getByText('datasetSettings.form.retrievalSetting.multiModalTip')).toBeInTheDocument() + expect(screen.getByText('datasetSettings.form.retrievalSetting.multiModalTip'))!.toBeInTheDocument() }) it('should not show multimodal tip for hybrid search with WeightedScore', () => { @@ -799,7 +799,7 @@ describe('RetrievalParamConfig', () => { />, ) - expect(screen.getByTestId('tooltip')).toBeInTheDocument() + expect(screen.getByTestId('tooltip'))!.toBeInTheDocument() }) }) @@ -814,7 +814,7 @@ describe('RetrievalParamConfig', () => { />, ) - expect(screen.getByText('common.modelProvider.rerankModel.key')).toBeInTheDocument() + expect(screen.getByText('common.modelProvider.rerankModel.key'))!.toBeInTheDocument() }) }) @@ -838,7 +838,7 @@ describe('RetrievalParamConfig', () => { fireEvent.click(weightedScoreCard!) expect(mockOnChange).toHaveBeenCalled() - const calledWith = mockOnChange.mock.calls[0][0] + const calledWith = mockOnChange.mock.calls[0]![0] expect(calledWith.weights).toBeDefined() expect(calledWith.weights.weight_type).toBe(WeightedScoreEnum.Customized) }) @@ -872,7 +872,7 @@ describe('RetrievalParamConfig', () => { fireEvent.click(weightedScoreCard!) expect(mockOnChange).toHaveBeenCalled() - const calledWith = mockOnChange.mock.calls[0][0] + const calledWith = mockOnChange.mock.calls[0]![0] expect(calledWith.weights.vector_setting.vector_weight).toBe(0.8) }) }) diff --git a/web/app/components/datasets/common/retrieval-param-config/index.tsx b/web/app/components/datasets/common/retrieval-param-config/index.tsx index acb77acfa60..8514e1fae2e 100644 --- a/web/app/components/datasets/common/retrieval-param-config/index.tsx +++ b/web/app/components/datasets/common/retrieval-param-config/index.tsx @@ -1,8 +1,11 @@ 'use client' import type { FC } from 'react' import type { RetrievalConfig } from '@/types/app' -import * as React from 'react' +import { cn } from '@langgenius/dify-ui/cn' +import { Switch } from '@langgenius/dify-ui/switch' +import { toast } from '@langgenius/dify-ui/toast' +import * as React from 'react' import { useCallback, useMemo } from 'react' import { useTranslation } from 'react-i18next' import WeightedScore from '@/app/components/app/configuration/dataset-config/params-config/weighted-score' @@ -10,9 +13,7 @@ import { AlertTriangle } from '@/app/components/base/icons/src/vender/solid/aler import ScoreThresholdItem from '@/app/components/base/param-item/score-threshold-item' import TopKItem from '@/app/components/base/param-item/top-k-item' import RadioCard from '@/app/components/base/radio-card' -import Switch from '@/app/components/base/switch' import Tooltip from '@/app/components/base/tooltip' -import { toast } from '@/app/components/base/ui/toast' import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' import { useCurrentProviderAndModel, useModelListAndDefaultModel } from '@/app/components/header/account-setting/model-provider-page/hooks' import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector' @@ -22,7 +23,6 @@ import { WeightedScoreEnum, } from '@/models/datasets' import { RETRIEVE_METHOD } from '@/types/app' -import { cn } from '@/utils/classnames' import ProgressIndicator from '../../create/assets/progress-indicator.svg' import Reranking from '../../create/assets/rerank.svg' @@ -121,12 +121,12 @@ const RetrievalParamConfig: FC = ({ {canToggleRerankModalEnable && ( )}
- {t('modelProvider.rerankModel.key', { ns: 'common' })} + {t('modelProvider.rerankModel.key', { ns: 'common' })} {t('modelProvider.rerankModel.tip', { ns: 'common' })}
@@ -152,11 +152,11 @@ const RetrievalParamConfig: FC = ({ /> {showMultiModalTip && (
-
+
- + {t('form.retrievalSetting.multiModalTip', { ns: 'datasetSettings' })}
@@ -246,11 +246,11 @@ const RetrievalParamConfig: FC = ({ ...value.weights!, vector_setting: { ...value.weights!.vector_setting, - vector_weight: v.value[0], + vector_weight: v.value[0]!, }, keyword_setting: { ...value.weights!.keyword_setting, - keyword_weight: v.value[1], + keyword_weight: v.value[1]!, }, }, }) @@ -276,11 +276,11 @@ const RetrievalParamConfig: FC = ({ /> {showMultiModalTip && (
-
+
- + {t('form.retrievalSetting.multiModalTip', { ns: 'datasetSettings' })}
diff --git a/web/app/components/datasets/create-from-pipeline/__tests__/footer.spec.tsx b/web/app/components/datasets/create-from-pipeline/__tests__/footer.spec.tsx index 7f1bc0e00cf..24c8b0338dc 100644 --- a/web/app/components/datasets/create-from-pipeline/__tests__/footer.spec.tsx +++ b/web/app/components/datasets/create-from-pipeline/__tests__/footer.spec.tsx @@ -124,7 +124,7 @@ describe('Footer', () => { expect(footerDiv).toHaveClass('absolute', 'bottom-0', 'left-0', 'right-0', 'z-10') }) - it('should have backdrop blur effect', () => { + it('should have backdrop blur-sm effect', () => { const { container } = render(